A PyTorch implementation of the hierarchical encoder-decoder architecture (HRED) introduced in Sordoni et al. (2015), used here to model conversation triples. This version of the model is built for the MovieTriples dataset.
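For orientation, here is a minimal sketch of the HRED computation (utterance encoder -> session encoder -> decoder). The module names and sizes below are illustrative, not the ones used in this repo:

```python
import torch
import torch.nn as nn

class HREDSketch(nn.Module):
    """Minimal HRED sketch; names and dimensions are illustrative,
    not this repo's actual modules."""
    def __init__(self, vocab_size, emb_dim=300, utt_hid=300, ses_hid=300):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, emb_dim)
        # Encodes each utterance independently into a fixed vector.
        self.utt_enc = nn.GRU(emb_dim, utt_hid, batch_first=True)
        # Encodes the sequence of utterance vectors (the dialogue context).
        self.ses_enc = nn.GRU(utt_hid, ses_hid, batch_first=True)
        # The decoder is conditioned on the session state via its initial hidden state.
        self.dec = nn.GRU(emb_dim, ses_hid, batch_first=True)
        self.out = nn.Linear(ses_hid, vocab_size)

    def forward(self, context, target):
        # context: (batch, n_utts, utt_len) token ids; target: (batch, tgt_len)
        b, n, t = context.size()
        _, utt_vecs = self.utt_enc(self.embed(context.view(b * n, t)))
        utt_vecs = utt_vecs[-1].view(b, n, -1)      # (batch, n_utts, utt_hid)
        _, ses_state = self.ses_enc(utt_vecs)       # (1, batch, ses_hid)
        dec_out, _ = self.dec(self.embed(target), ses_state)
        return self.out(dec_out)                    # (batch, tgt_len, vocab)
```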
Links to related papers - HRED, Modified HRED
Training commands
For the baseline seq2seq model - python3 train.py -n seq2seq -tf -bms 20 -bs 100 -e 80 -seshid 300 -uthid 300 -drp 0.4 -lr 0.0005 -s2s -pt 3
For the HRED model - python3 train.py -n HRED -tf -bms 20 -bs 100 -e 80 -seshid 300 -uthid 300 -drp 0.4 -lr 0.0005 -pt 3
For the Bi-HRED + language model objective (inverse sigmoid teacher forcing rate decay) - python3 train.py -n BiHRED+LM -bi -lm -nl 2 -lr 0.0003 -e 80 -seshid 300 -uthid 300 -bs 10 -pt 3
For the Bi-HRED + language model objective (full teacher forcing) - python3 train.py -n model3 -nl 2 -bi -lm -drp 0.4 -e 25 -seshid 300 -uthid 300 -lr 0.0001 -bs 100 -tf
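The flags above map to hyperparameters roughly as follows. This is a hypothetical argparse sketch reconstructed from the commands; every help string is an assumption, and train.py remains the authoritative reference:

```python
import argparse

# Hypothetical reconstruction of train.py's flags from the commands above;
# all meanings are assumptions, so consult train.py for the real definitions.
parser = argparse.ArgumentParser()
parser.add_argument('-n', help='run/model name (e.g. HRED, seq2seq)')
parser.add_argument('-tf', action='store_true', help='full teacher forcing')
parser.add_argument('-bi', action='store_true', help='bidirectional utterance encoder')
parser.add_argument('-lm', action='store_true', help='add language model objective')
parser.add_argument('-s2s', action='store_true', help='baseline seq2seq, no session encoder')
parser.add_argument('-bms', type=int, help='beam size for beam search decoding')
parser.add_argument('-bs', type=int, help='batch size')
parser.add_argument('-e', type=int, help='number of epochs')
parser.add_argument('-nl', type=int, help='number of RNN layers')
parser.add_argument('-seshid', type=int, help='session encoder hidden size')
parser.add_argument('-uthid', type=int, help='utterance encoder hidden size')
parser.add_argument('-drp', type=float, help='dropout probability')
parser.add_argument('-lr', type=float, help='learning rate')
parser.add_argument('-pt', type=int, help='patience for early stopping (assumed)')
args = parser.parse_args()
```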
At test time, we use beam search decoding with the beam size set to 20. To rerank candidates during beam search, we use the MMI anti-LM objective, following the method in the paper (Li et al., 2016).
Test command - just add the following flags to test the model: -test -mmi -bms 50
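Concretely, the anti-LM objective rescores each candidate T given source S as log p(T|S) - λ·log p(T), with the language-model penalty typically applied only to the first few tokens. The sketch below illustrates the idea; the function names, arguments, and the λ/γ values are assumptions, not this repo's API:

```python
def mmi_antilm_rerank(candidates, seq2seq_logprobs, lm_logprobs, lam=0.5, gamma=5):
    """Rerank beam candidates with the MMI anti-LM score of Li et al. (2016):
    score(T) = log p(T|S) - lam * log p(T), penalizing generic responses.
    The LM penalty is restricted to the first `gamma` tokens.
    All names and arguments here are illustrative, not this repo's API."""
    rescored = []
    for cand in candidates:
        s2s = sum(seq2seq_logprobs(cand))       # per-token log p(t_i | S, t_<i)
        lm = sum(lm_logprobs(cand)[:gamma])     # per-token log p(t_i | t_<i)
        rescored.append((s2s - lam * lm, cand))
    # Highest MMI score first.
    return [c for _, c in sorted(rescored, key=lambda x: x[0], reverse=True)]
```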
To perform a sanity check on the model, train it on a small subset of the dataset with the -toy flag. It should overfit, with a training error of 0.5.
To bootstrap the encoder with pretrained word embeddings, use the -embed flag. You should have an embeddings text file in the /data/embeddings folder.
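A minimal sketch of bootstrapping an nn.Embedding from such a text file; the file format ('word v1 v2 ... vN' per line) and vocab handling are assumptions about this repo:

```python
import numpy as np
import torch
import torch.nn as nn

def load_pretrained_embeddings(path, word2idx, emb_dim=300):
    """Build an embedding layer from a whitespace-separated text file
    ('word v1 v2 ... vN' per line). The format is an assumption about
    this repo; words missing from the file keep a random init."""
    weight = np.random.uniform(-0.1, 0.1, (len(word2idx), emb_dim))
    with open(path, encoding='utf-8') as f:
        for line in f:
            parts = line.rstrip().split(' ')
            word, vec = parts[0], parts[1:]
            if word in word2idx and len(vec) == emb_dim:
                weight[word2idx[word]] = np.asarray(vec, dtype=np.float32)
    embedding = nn.Embedding(len(word2idx), emb_dim)
    embedding.weight.data.copy_(torch.from_numpy(weight))
    return embedding
```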
The default evaluation metric is word perplexity. To evaluate the model on word-embedding-based metrics, run the command python metrics.py <ground_truth txt file> <model_output txt file> <pretrained word2vec embeddings binary>. The metrics script was taken from here. The word-embedding-based metrics use pretrained word2vec embeddings trained on the Google News corpus; download them from here and place them in the /data/embeddings folder.
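As an illustration, the embedding-average variant of these metrics (cosine similarity between the mean word vectors of reference and output) can be computed with gensim roughly as follows. This is a sketch that mirrors the command-line interface above, not the linked metrics.py, which may compute additional variants such as greedy matching and vector extrema:

```python
import sys
import numpy as np
from gensim.models import KeyedVectors

def embedding_average(sentence, kv):
    """Mean of the word2vec vectors of in-vocabulary tokens."""
    vecs = [kv[w] for w in sentence.split() if w in kv]
    return np.mean(vecs, axis=0) if vecs else np.zeros(kv.vector_size)

def cosine(a, b):
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    return float(np.dot(a, b) / denom) if denom else 0.0

if __name__ == '__main__':
    # Mirrors: python metrics.py <ground_truth> <model_output> <word2vec binary>
    truth_path, output_path, w2v_path = sys.argv[1:4]
    kv = KeyedVectors.load_word2vec_format(w2v_path, binary=True)
    scores = [cosine(embedding_average(t, kv), embedding_average(o, kv))
              for t, o in zip(open(truth_path), open(output_path))]
    print('Embedding Average:', sum(scores) / len(scores))
```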