Probabilistic NLG

TensorFlow implementation of Stochastic Wasserstein Autoencoder for Probabilistic Sentence Generation (NAACL 2019).


Stochastic Wasserstein Autoencoder for Probabilistic Sentence Generation

This is the official codebase for the following paper, implemented in TensorFlow:

Hareesh Bahuleyan, Lili Mou, Hao Zhou, Olga Vechtomova. Stochastic Wasserstein Autoencoder for Probabilistic Sentence Generation. NAACL 2019. https://arxiv.org/pdf/1806.08462.pdf

Overview

This package contains the code for two tasks:

  • SNLI generation (snli: autoencoder models)
  • Dialog generation (dialog: encoder-decoder models)

For the above tasks, the code for the following models has been made available (an illustrative sketch of the MMD penalty used by the Wasserstein models follows the list):

  1. Variational autoencoder (vae) / Variational encoder-decoder (ved)
  2. Deterministic Wasserstein autoencoder (wae-det) / Deterministic Wasserstein encoder-decoder (wed-det)
  3. Stochastic Wasserstein autoencoder (wae-stochastic) / Stochastic Wasserstein encoder-decoder (wed-stochastic)
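The key difference among these variants is how the latent codes are matched to the prior: the VAE/VED uses a KL term, while the Wasserstein models penalize the discrepancy between the encoded codes and prior samples, typically estimated with a maximum mean discrepancy (MMD) term weighted by a coefficient (exposed later as --lambda_val, with the kernel chosen via --kernel). The following NumPy sketch of an IMQ-kernel MMD estimator is for intuition only; it is not the repository's TensorFlow implementation, and the kernel constants are a common choice rather than the exact values used here.

import numpy as np

def imq_kernel(x, y, scales=(0.1, 0.2, 0.5, 1.0, 2.0, 5.0, 10.0)):
    # Inverse multiquadratic kernel summed over several scales:
    # k(x, y) = sum_s C_s / (C_s + ||x - y||^2), with C_s = 2 * z_dim * s.
    sq_dists = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)
    z_dim = x.shape[1]
    k = np.zeros_like(sq_dists)
    for s in scales:
        c = 2.0 * z_dim * s
        k += c / (c + sq_dists)
    return k

def mmd_penalty(z_encoded, z_prior):
    # Biased estimate of MMD^2 between encoded latent codes and prior samples.
    return (imq_kernel(z_encoded, z_encoded).mean()
            + imq_kernel(z_prior, z_prior).mean()
            - 2.0 * imq_kernel(z_encoded, z_prior).mean())

# Toy usage: 128 latent codes of dimension 100 (matching --latent_dim=100).
z_enc = 0.3 + 1.5 * np.random.randn(128, 100)  # stand-in for encoder outputs
z_pri = np.random.randn(128, 100)              # samples from the standard normal prior
penalty = 3.0 * mmd_penalty(z_enc, z_pri)      # weighted as in --lambda_val=3.0
print(penalty)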

Datasets

The models described in the paper are evaluated on two datasets. An additional dataset is also available for running dialog generation experiments.

The data has been preprocessed, and the train-val-test split is provided in the data/ directory of the respective task.

Requirements

  • numpy==1.16.0
  • pandas==0.22.0
  • gensim==3.7.0
  • nltk==3.2.3
  • Keras==2.0.8
  • tqdm==4.19.1
  • tensorflow-gpu==1.3.0
  • sklearn
  • matplotlib

Instructions

  1. Create a virtual environment using conda:
conda create -n nlg python=3.6.1
  2. Activate the virtual environment and install the required packages:
source activate nlg
cd probabilistic_nlg/
pip install -r requirements.txt
  3. Generate word2vec vectors, which are required for initializing the word embeddings (for the dialog generation task, the dataset must be specified as a command-line argument):
cd snli/
python w2v_generator.py
  4. Train the desired model. Configurations can be set in the model_config.py file or passed as command-line arguments. For example:
cd wae-det
vim model_config.py # Make necessary edits or specify the hyperparams as command line arguments as below
python train.py --lstm_hidden_units=100 --vocab_size=30000 --latent_dim=100 --batch_size=128 --n_epochs=20 --kernel=IMQ --lambda_val=3.0
  • Model checkpoints are stored in the models/ directory, and TensorBoard summaries are written to the summary_logs/ directory (see the command below for launching TensorBoard). As training progresses, the validation-set metrics are dumped into bleu_log.txt and the bleu/ directory. The model configuration and the outputs generated during training are written to a text file within runs/
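To monitor training, TensorBoard can be pointed at the summary directory (assuming TensorBoard is installed alongside TensorFlow):

tensorboard --logdir=summary_logs/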
  5. Run predict.py, specifying the desired checkpoint (--ckpt), to (1) generate sentences given test set inputs; (2) generate sentences by randomly sampling from the latent space; and (3) linearly interpolate between sentences in the latent space. By default, for vae and wae-stochastic, sampling from the latent space is carried out within one standard deviation of the mean. Note that predict.py also outputs BLEU scores; when computing BLEU, it is best to simply use the mean (i.e., no sampling), which is done by setting --z_temp=0.0. The random_sample_save(checkpoint, num_batches=3) function call within predict.py automatically saves sentences generated by latent space sampling to samples/sample.txt. An example invocation is shown below.
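For instance, a run that loads a saved checkpoint and decodes from the posterior mean might look like the following (the checkpoint value is a placeholder; use one of the checkpoints produced under models/):

python predict.py --ckpt=<checkpoint> --z_temp=0.0

With --z_temp=0.0 the sampled noise is effectively scaled to zero, so the decoder is conditioned on the mean of the latent distribution, which is the setting recommended above for reporting BLEU.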

  6. To compute the metrics for evaluating the latent space (AvgLen, UnigramKL, Entropy) proposed in the paper, run evaluate_latent_space.py, specifying the path to the reference sentence set (i.e., the training corpus) and the path to the generated sentence samples (~100k samples are recommended). For example:

python evaluate_latent_space.py -ref='snli/data/snli_sentences_all.txt' -gen='snli/wae-det/samples/sample.txt'
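For intuition, the sketch below shows one plausible way such unigram-level statistics could be computed in plain Python. The direction of the KL divergence and the smoothing constant are assumptions in this sketch and may differ from what evaluate_latent_space.py actually implements.

from collections import Counter
import math

def latent_space_stats(ref_path, gen_path, eps=1e-10):
    # Illustrative AvgLen / UnigramKL / Entropy of a generated-sample file
    # relative to a reference corpus, using whitespace tokenization.
    def read_sents(path):
        with open(path, encoding='utf-8') as f:
            return [line.strip().split() for line in f if line.strip()]

    ref_sents = read_sents(ref_path)
    gen_sents = read_sents(gen_path)

    # AvgLen: average number of tokens per generated sentence.
    avg_len = sum(len(s) for s in gen_sents) / len(gen_sents)

    def unigram_dist(sents):
        counts = Counter(tok for sent in sents for tok in sent)
        total = sum(counts.values())
        return {w: c / total for w, c in counts.items()}

    p_ref = unigram_dist(ref_sents)
    q_gen = unigram_dist(gen_sents)

    # UnigramKL: KL divergence between the generated and reference unigram
    # distributions (direction and smoothing are assumptions of this sketch).
    unigram_kl = sum(q * math.log(q / p_ref.get(w, eps)) for w, q in q_gen.items())

    # Entropy: Shannon entropy of the generated unigram distribution.
    entropy = -sum(q * math.log(q) for q in q_gen.values())
    return avg_len, unigram_kl, entropy

print(latent_space_stats('snli/data/snli_sentences_all.txt', 'snli/wae-det/samples/sample.txt'))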

Citation

If you found this code useful in your research, please cite:

@inproceedings{probabilisticNLG2019,
  title={Stochastic Wasserstein Autoencoder for Probabilistic Sentence Generation},
  author={Bahuleyan, Hareesh and Mou, Lili and Zhou, Hao and Vechtomova, Olga},
  booktitle={Proceedings of the Annual Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT)},
  year={2019}
}