Implementation of Sequence Generative Adversarial Nets with Policy Gradient in PyTorch

SeqGAN-PyTorch

A PyTorch implementation of SeqGAN, from the paper "SeqGAN: Sequence Generative Adversarial Nets with Policy Gradient". The code runs the synthetic-data experiment described in the paper.
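
For readers new to the "policy gradient" part of the name, the idea can be sketched in a few lines of plain Python. This is a minimal illustration of a REINFORCE-style surrogate loss, not code taken from this repository; the function name `policy_gradient_loss` and the sample numbers are hypothetical, and the rewards are assumed to come from the discriminator.

```python
import math


def policy_gradient_loss(log_probs, rewards):
    """REINFORCE-style surrogate loss for one generated sequence.

    log_probs[t] is the log-probability the generator assigned to the
    token it sampled at step t. rewards[t] is the estimated reward for
    that token (assumed here to be a discriminator-based score).
    """
    if len(log_probs) != len(rewards):
        raise ValueError("log_probs and rewards must have the same length")
    # Minimizing this value raises the log-probability of high-reward tokens.
    return -sum(lp * r for lp, r in zip(log_probs, rewards))


# Hypothetical three-token sequence: sampled-token probabilities and rewards.
token_probs = [0.5, 0.25, 0.5]
rewards = [1.0, 0.0, 1.0]
loss = policy_gradient_loss([math.log(p) for p in token_probs], rewards)
print(round(loss, 4))  # 1.3863, i.e. 2 * ln(2)
```

The middle token has zero reward, so it contributes nothing to the loss; only tokens the discriminator scores highly are reinforced.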

Usage

$ python main.py

Please refer to main.py for supported arguments. You can also change model parameters there.

Dependencies

  • PyTorch 0.4.0+ (1.0 ready)
  • Python 3.5+
  • CUDA 8.0+ & cuDNN (for GPU use)
  • numpy
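
A quick way to compare an environment against the list above is a small check script. The sketch below is not part of the project; `check_environment` is a hypothetical helper that only reports what it finds and does not enforce the minimum versions.

```python
import importlib
import sys


def check_environment():
    """Collect interpreter and package versions relevant to the list above."""
    report = {"python": "{}.{}.{}".format(*sys.version_info[:3])}
    for name in ("torch", "numpy"):
        try:
            module = importlib.import_module(name)
            report[name] = getattr(module, "__version__", "unknown")
        except ImportError:
            report[name] = None  # package is not installed
    torch = sys.modules.get("torch")
    # torch.cuda.is_available() reports whether PyTorch can use a GPU.
    report["cuda"] = bool(torch and torch.cuda.is_available())
    return report


for key, value in check_environment().items():
    print("{}: {}".format(key, value))
```

A value of `None` for `torch` or `numpy` means the package could not be imported; `cuda: False` means training would run on the CPU.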

Hacks and Observations

  • Use Adam for the generator and SGD for the discriminator.
  • The discriminator should be neither too strong (its feedback stops being useful) nor too weak (it guesses randomly and cannot guide generation).
  • The GAN phase does not always produce a large drop in NLL; sometimes the drop is minimal, and sometimes NLL even increases.
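
The second and third observations are easier to act on when both quantities are tracked during training. The helpers below are a rough sketch in plain Python, not code from this project: `mean_nll` averages the negative log-likelihood per token (whether the repository averages per token or per sequence is not stated here), and the `low` and `high` thresholds in `discriminator_status` are illustrative assumptions.

```python
import math


def mean_nll(token_probs):
    """Average negative log-likelihood per token.

    token_probs holds, for each sequence, the probability a reference
    model (the oracle in the synthetic experiment) assigns to each token.
    """
    total, count = 0.0, 0
    for seq in token_probs:
        total -= sum(math.log(p) for p in seq)
        count += len(seq)
    return total / count


def discriminator_status(accuracy, low=0.55, high=0.95):
    """Classify discriminator accuracy on a balanced real/fake batch.

    The `low` and `high` thresholds are illustrative assumptions,
    not values taken from this project.
    """
    if accuracy < low:
        return "too weak"    # close to random guessing (0.5)
    if accuracy > high:
        return "too strong"  # rewards saturate, little signal left
    return "ok"


print(round(mean_nll([[0.5, 0.5], [0.25, 0.25]]), 4))
print(discriminator_status(0.75))
```

Logging both values once per adversarial round makes it visible whether a stalled or rising NLL coincides with a discriminator drifting out of the useful range.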

Sample Learning Curve

Learning curve of the generator after MLE training for 120 steps (1 epoch per step), followed by adversarial training for 150 rounds (1 epoch per round):

[Figure: generator learning curve]

Learning curve of the discriminator after MLE training for 50 steps (3 epochs per step), followed by adversarial training for 150 rounds (9 epochs per round):

[Figure: discriminator learning curve]

Acknowledgement

This code is based on Zhao Zijian's SeqGAN-PyTorch, Surag Nair's SeqGAN, and Lantao Yu's original TensorFlow implementation. Many thanks to Zhao Zijian, Surag Nair, and Lantao Yu!

README source: X-czh/SeqGAN-PyTorch