# WaveGrad2 — Unofficial PyTorch Implementation

**WaveGrad 2: Iterative Refinement for Text-to-Speech Synthesis**

Unofficial PyTorch+Lightning implementation of Chen et al. (JHU, Google Brain), WaveGrad 2.

Update: Enjoy our pre-trained model with the Google Colab notebook!
## Datasets

The supported datasets are LJSpeech and AISHELL-3. We take LJSpeech as an example hereafter.
## Preprocessing

Edit `preprocess.yaml`, especially the `path` section.

```yaml
path:
  corpus_path: '/DATA1/LJSpeech-1.1' # LJSpeech corpus path
  lexicon_path: 'lexicon/librispeech-lexicon.txt'
  raw_path: './raw_data/LJSpeech'
  preprocessed_path: './preprocessed_data/LJSpeech'
```
Run `prepare_align.py` for some preparations.

```bash
python prepare_align.py -c preprocess.yaml
```
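For reference, the `path` section above is consumed by the preprocessing scripts roughly as follows (a minimal illustration assuming standard PyYAML loading, not the repo's exact code):

```python
import yaml  # PyYAML

# Load the preprocessing config and read the corpus location from it.
with open('preprocess.yaml') as f:
    config = yaml.safe_load(f)

print(config['path']['corpus_path'])  # '/DATA1/LJSpeech-1.1'
```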
Montreal Forced Aligner (MFA) is used to obtain the alignments between the utterances and the phoneme sequences. Alignments for the LJSpeech and AISHELL-3 datasets are provided here. Unzip the files into `preprocessed_data/LJSpeech/TextGrid/`. After that, run `preprocess.py`.

```bash
python preprocess.py -c preprocess.yaml
```
Alternately, you can align the corpus yourself with MFA:

```bash
./montreal-forced-aligner/bin/mfa_align raw_data/LJSpeech/ lexicon/librispeech-lexicon.txt english preprocessed_data/LJSpeech
```

or

```bash
./montreal-forced-aligner/bin/mfa_train_and_align raw_data/LJSpeech/ lexicon/librispeech-lexicon.txt preprocessed_data/LJSpeech
```

and then run `preprocess.py`.

```bash
python preprocess.py -c preprocess.yaml
```
## Training

Edit `hparameter.yaml`, especially the `train` section.

```yaml
train:
  batch_size: 12 # Dependent on GPU memory size
  adam:
    lr: 3e-4
    weight_decay: 1e-6
  decay:
    rate: 0.05
    start: 25000
    end: 100000
  num_workers: 16 # Dependent on CPU cores
  gpus: 2 # Number of GPUs
  loss_rate:
    dur: 1.0
```
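The README does not spell out how `decay.rate`, `decay.start`, and `decay.end` combine. Below is a minimal sketch of one plausible reading — geometric annealing from the base learning rate down to `rate * lr` between `start` and `end` — which is our assumption, not necessarily what `trainer.py` builds:

```python
import math
import torch

def decay_lambda(step, rate=0.05, start=25000, end=100000):
    """Multiplicative LR factor: 1.0 before `start`, `rate` after `end`,
    geometric interpolation in between. The annealing shape is an assumption."""
    if step < start:
        return 1.0
    if step >= end:
        return rate
    progress = (step - start) / (end - start)
    return math.exp(progress * math.log(rate))

model = torch.nn.Linear(8, 8)  # stand-in for the real model
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=1e-6)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=decay_lambda)
```

With `LambdaLR` stepped once per training step, this reading keeps the learning rate at `3e-4` until step 25000 and lands at `1.5e-5` at step 100000.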
Also check the `data` section in `hparameter.yaml`.

```yaml
data:
  lang: 'eng'
  text_cleaners: ['english_cleaners'] # korean_cleaners, english_cleaners, chinese_cleaners
  speakers: ['LJSpeech']
  train_dir: 'preprocessed_data/LJSpeech'
  train_meta: 'train.txt' # relative path of metadata file from train_dir
  val_dir: 'preprocessed_data/LJSpeech'
  val_meta: 'val.txt' # relative path of metadata file from val_dir
  lexicon_path: 'lexicon/librispeech-lexicon.txt'
```
Run `trainer.py`.

```bash
python trainer.py
```

The available command-line options are:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('-r', '--resume_from', type=int,
                    required=False, help="Resume checkpoint epoch number")
parser.add_argument('-s', '--restart', action="store_true",
                    required=False, help="Significant change occurred, use this")
parser.add_argument('-e', '--ema', action="store_true",
                    required=False, help="Start from EMA checkpoint")
args = parser.parse_args()
```
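For example, `python trainer.py -r 30` resumes from the epoch-30 checkpoint, and `python trainer.py -r 30 -e` resumes from the corresponding EMA checkpoint instead (the epoch number 30 is only illustrative).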
Monitor the training with TensorBoard:

```bash
tensorboard --logdir=./tensorboard --bind_all
```
## Inference

Run `inference.py`.

```bash
python inference.py -c <checkpoint_path> --text <'text'>
```

We also provide a Jupyter notebook, `wavegrad2_tester.ipynb`, with the inference code and visualizations of the resulting audio.
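For programmatic use, a rough sketch of checkpoint loading is below. The class name `Wavegrad2` is our guess for the LightningModule defined in `lightning_model.py` (check the file for the actual name); `load_from_checkpoint` itself is standard PyTorch Lightning.

```python
from lightning_model import Wavegrad2  # class name is an assumption

# Restore the trained module from a Lightning checkpoint and switch to eval mode.
model = Wavegrad2.load_from_checkpoint('path/to/checkpoint.ckpt')
model.eval()
# For text normalization, phoneme conversion, and the diffusion sampling loop,
# follow wavegrad2_tester.ipynb.
```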
## Note

Two decoder sizes are available: WaveGrad-Base and WaveGrad-Large. We implemented the WaveGrad-Large decoder for higher-MOS output. Note: it could differ from Google's implementation, since our number of parameters differs from the paper's value. Select the decoder in `hparameter.yaml`.

```yaml
wavegrad:
  is_large: True # if False, Base
  ...
  dilations: [[1,2,4,8],[1,2,4,8],[1,2,4,8],[1,2,4,8],[1,2,4,8]] # dilations for Large
  #dilations: [[1,2,4,8],[1,2,4,8],[1,2,1,2],[1,2,1,2],[1,2,1,2]] # dilations for Base
```
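As a quick sanity check on these lists (not code from this repo), the receptive field of a stack of dilated 1-D convolutions can be computed as below, assuming kernel size 3 as in the WaveGrad decoder blocks:

```python
def receptive_field(dilations, kernel_size=3):
    """Receptive field (in samples) of stacked dilated 1-D convolutions."""
    rf = 1
    for d in dilations:
        rf += (kernel_size - 1) * d  # each layer widens the field by (k-1)*d
    return rf

print(receptive_field([1, 2, 4, 8]))  # 31: every Large block
print(receptive_field([1, 2, 1, 2]))  # 13: the narrower Base blocks
```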
Since this repo is an unofficial implementation and the WaveGrad 2 paper does not provide several details, slight differences from the paper could exist. We list our modifications and arbitrary setups below:

- The WaveGrad-Large decoder's architecture could differ from Google's implementation.
- `train.batch_size: 12` for Base and `train.batch_size: 6` for Large, trained with 2 V100 (32GB) GPUs.
- `train.adam.lr: 3e-4` and `train.adam.weight_decay: 1e-6`.
- `train.decay` learning rate decay is applied during training.
- `train.loss_rate: 1`, i.e. `total_loss = 1 * L1_loss + 1 * duration_loss`.
- `ddpm.ddpm_noise_schedule: torch.linspace(1e-6, 0.01, hparams.ddpm.max_step)` (see the sketch after this list).
- `encoder.channel` is reduced to 512 from 1024 or 2048.
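A minimal sketch of that linear noise schedule and the standard DDPM quantities derived from it (textbook diffusion algebra, not code lifted from this repo; `max_step = 1000` is a placeholder for `hparams.ddpm.max_step`):

```python
import torch

max_step = 1000  # placeholder for hparams.ddpm.max_step
beta = torch.linspace(1e-6, 0.01, max_step)   # ddpm_noise_schedule
alpha = 1.0 - beta
alpha_bar = torch.cumprod(alpha, dim=0)       # cumulative product of alphas
noise_level = torch.sqrt(alpha_bar)           # sqrt(alpha_bar_t), WaveGrad's conditioning signal

def diffuse(y0, t):
    """Forward diffusion q(y_t | y_0) at integer step t."""
    eps = torch.randn_like(y0)
    y_t = noise_level[t] * y0 + torch.sqrt(1.0 - alpha_bar[t]) * eps
    return y_t, eps  # the decoder learns to predict eps (the L1 term above)
```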
## Repository Structure

```
├── Dockerfile
├── README.md
├── dataloader.py
├── docs
│   ├── spec.png
│   ├── tb.png
│   └── tblogger.png
├── hparameter.yaml
├── inference.py
├── lexicon
│   ├── librispeech-lexicon.txt
│   └── pinyin-lexicon-r.txt
├── lightning_model.py
├── model
│   ├── base.py
│   ├── downsampling.py
│   ├── encoder.py
│   ├── gaussian_upsampling.py
│   ├── interpolation.py
│   ├── layers.py
│   ├── linear_modulation.py
│   ├── nn.py
│   ├── resampling.py
│   ├── upsampling.py
│   └── window.py
├── prepare_align.py
├── preprocess.py
├── preprocess.yaml
├── preprocessor
│   ├── ljspeech.py
│   └── preprocessor.py
├── text
│   ├── __init__.py
│   ├── cleaners.py
│   ├── cmudict.py
│   ├── numbers.py
│   └── symbols.py
├── trainer.py
├── utils
│   ├── mel.py
│   ├── stft.py
│   ├── tblogger.py
│   └── utils.py
└── wavegrad2_tester.ipynb
```
## Author & References

This code is implemented by:

Special thanks to:

This implementation uses code from the following repositories:

The webpage for the audio samples uses a template from:

The audio samples on our webpage are partially derived from: