Deep Markov Models
This repository contains theano code for implementing Deep Markov Models. The code is documented and should be easy to modify for your own applications.
The code uses variational inference during learning: it maximizes a variational lower bound (the evidence lower bound, or ELBO) on the likelihood of the observed data.
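For reference, here is the standard form of this bound in the notation used below: latent states z_{1:T}, observations x_{1:T}, emission p(x_t | z_t), transition p(z_t | z_{t-1}) and inference network q(z_{1:T} | x_{1:T}), with p(x_{1:T} | z_{1:T}) = ∏_t p(x_t | z_t) and p(z_{1:T}) = p(z_1) ∏_{t≥2} p(z_t | z_{t-1}). The second line assumes the inference network factorizes as q(z_1 | x_{1:T}) ∏_{t≥2} q(z_t | z_{t-1}, x_{1:T}); treat it as a summary of the objective, not a transcription of the code.

```latex
\begin{aligned}
\log p(x_{1:T}) \;\ge\; \mathcal{L}
 &= \mathbb{E}_{q(z_{1:T}\mid x_{1:T})}\big[\log p(x_{1:T}\mid z_{1:T})\big]
    - \mathrm{KL}\big(q(z_{1:T}\mid x_{1:T}) \,\|\, p(z_{1:T})\big) \\
 &= \sum_{t=1}^{T} \mathbb{E}_{q(z_t\mid x_{1:T})}\big[\log p(x_t\mid z_t)\big]
    - \mathrm{KL}\big(q(z_1\mid x_{1:T}) \,\|\, p(z_1)\big) \\
 &\quad - \sum_{t=2}^{T} \mathbb{E}_{q(z_{t-1}\mid x_{1:T})}\Big[\mathrm{KL}\big(q(z_t\mid z_{t-1}, x_{1:T}) \,\|\, p(z_t\mid z_{t-1})\big)\Big]
\end{aligned}
```

The first term rewards reconstructing each observation from its latent state; the KL terms keep the inference network close to the learned transition model at every time step.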
The latent variables z1...zT and the observations x1...xT together describe the generative process for the data. The emission function p(x_t|z_t) and the transition function p(z_t|z_{t-1}) are parameterized by deep neural networks. The variational distribution q(z1..zT | x1...xT) represents the inference network.

This package has the following requirements:
python2.7
I used the following ~/.theanorc configuration file:

[global]
floatX=float32
mode=FAST_RUN

[nvcc]
fastmath=True

[cuda]
root=/usr/local/cuda
You can choose whether the model runs on the GPU or the CPU by setting THEANO_FLAGS. See the Theano configuration documentation for details.
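A minimal sketch of switching devices from Python, assuming the standard THEANO_FLAGS environment variable: a comma-separated list of key=value settings that Theano reads once, when it is first imported, and that takes precedence over ~/.theanorc. The helper name set_theano_flags is hypothetical, and the GPU device name depends on your Theano version (gpu for the older CUDA backend, cuda for the newer gpuarray backend).

```python
import os


def set_theano_flags(device="cpu", floatX="float32", mode="FAST_RUN"):
    """Hypothetical helper: build a THEANO_FLAGS string and export it.

    Call this before `import theano`; Theano parses THEANO_FLAGS at import
    time, and these values override the ones in ~/.theanorc.
    """
    flags = "device=%s,floatX=%s,mode=%s" % (device, floatX, mode)
    os.environ["THEANO_FLAGS"] = flags
    return flags


flags = set_theano_flags(device="cpu")
print(flags)  # device=cpu,floatX=float32,mode=FAST_RUN
# import theano  # import only after the flags have been set
```

From a shell, the same effect comes from prefixing the command, for example THEANO_FLAGS='device=cpu,floatX=float32' python your_script.py, where your_script.py is a placeholder for whichever script you run.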
model_th: This folder contains raw Theano code implementing the model. See the folder for details on how the DMM was implemented and for pointers to the relevant portions of the code.
dmm_data: This folder contains code to load the polyphonic music data and a synthetic dataset. Add or change code in load.py (dmm_data/load.py) to run the model on your own data.
ipynb: This folder contains IPython notebooks with examples of loading your own data and running the model on it.
parse_args.py: This file contains the hyperparameters used by the model. Run python parse_args.py -h for an explanation of how the various parameter choices change the generative model and the inference network.
expt: Experimental setup for running the DMM on the polyphonic music dataset.
expt_template: Experimental setup for running the DMM on synthetic real-valued observations.
expt/
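When adapting dmm_data/load.py to your own data, one common issue is that sequences (such as the pieces in the polyphonic music data) have different lengths. A typical way to batch them is to pad every sequence to the longest length T and keep a binary mask marking the real time steps. The sketch below assumes that approach; the function name pad_sequences and the N x T x D list-plus-mask layout are hypothetical and not necessarily the interface load.py expects, so check that file for the exact keys and shapes.

```python
def pad_sequences(sequences, pad_value=0.0):
    """Hypothetical helper: pad sequences to N x T x D and build an N x T mask.

    `sequences` is a list of N sequences; each is a list of T_i observation
    vectors of a shared dimension D. The first sequence must be non-empty.
    """
    if not sequences or not sequences[0]:
        raise ValueError("need at least one non-empty sequence")
    dim = len(sequences[0][0])
    max_len = max(len(seq) for seq in sequences)
    data, mask = [], []
    for seq in sequences:
        if any(len(x) != dim for x in seq):
            raise ValueError("all observations must have dimension %d" % dim)
        n_pad = max_len - len(seq)
        # copy the observed time steps, then append padding vectors
        data.append([list(x) for x in seq] + [[pad_value] * dim for _ in range(n_pad)])
        # 1.0 marks a real time step, 0.0 marks padding to ignore in the loss
        mask.append([1.0] * len(seq) + [0.0] * n_pad)
    return data, mask


# Two binary sequences of length 3 and 1, each observation of dimension D = 2.
seqs = [[[1, 0], [0, 1], [1, 1]], [[0, 1]]]
data, mask = pad_sequences(seqs)
print(mask)  # [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]
```

Multiplying per-time-step log-likelihood and KL terms by the mask before summing keeps the padded steps from contributing to the training objective.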
Please cite the following paper if you find the code useful in your research:
@inproceedings{krishnan2016structured,
title={Structured Inference Networks for Nonlinear State Space Models},
author={Krishnan, Rahul G and Shalit, Uri and Sontag, David},
booktitle={AAAI},
year={2017}
}
This paper subsumes the work in Deep Kalman Filters (https://arxiv.org/abs/1511.05121).