Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
Implementation of Imagen, Google's Text-to-Image Neural Network that beats DALL-E2, in Pytorch. It is the new SOTA for text-to-image synthesis.
Architecturally, it is actually much simpler than DALL-E2. It consists of a cascading DDPM conditioned on text embeddings from a large pretrained T5 model (attention network). It also contains dynamic clipping for improved classifier free guidance, noise level conditioning, and a memory efficient unet design.
It appears neither CLIP nor prior network is needed after all. And so research continues.
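For intuition, the dynamic clipping mentioned above (dynamic thresholding in the paper) amounts to clamping the predicted images to a per-sample percentile rather than a fixed [-1, 1] range, so that high guidance scales do not saturate the pixels. A minimal sketch of the idea (not the library's exact code; the percentile value is illustrative):

import torch

def dynamic_threshold(x0, percentile = 0.9):
    # take the given percentile of absolute pixel values per sample as the threshold s
    s = torch.quantile(x0.flatten(1).abs(), percentile, dim = 1)
    # never threshold below 1., so typical samples are left untouched
    s = s.clamp(min = 1.).view(-1, *((1,) * (x0.ndim - 1)))
    # clamp to [-s, s] and rescale back into [-1, 1]
    return x0.clamp(-s, s) / s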
Please join the LAION community if you are interested in helping out with the replication.
$ pip install imagen-pytorch
import torch
from imagen_pytorch import Unet, Imagen
# unet for imagen
unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True),
    layer_cross_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
    unets = (unet1, unet2),
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(4, 256, 768).cuda()
text_masks = torch.ones(4, 256).bool().cuda()
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
    loss = imagen(images, text_embeds = text_embeds, text_masks = text_masks, unet_number = i)
    loss.backward()
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm
images = imagen.sample(texts = [
    'a whale breaching from afar',
    'young girl blowing out candles on her birthday cake',
    'fireworks with blue and green sparkles'
], cond_scale = 3.)
images.shape # (3, 3, 256, 256)
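The text encodings above are just random tensors. For real training you would encode your captions with T5; a hedged sketch using the Hugging Face transformers library directly (the model name and max length are illustrative; t5-base produces the 768-dimensional embeddings assumed above):

import torch
from transformers import T5Tokenizer, T5EncoderModel

tokenizer = T5Tokenizer.from_pretrained('t5-base')
encoder = T5EncoderModel.from_pretrained('t5-base').cuda().eval()

captions = ['a whale breaching from afar'] * 4

tokens = tokenizer(captions, return_tensors = 'pt', padding = 'max_length', truncation = True, max_length = 256)

with torch.no_grad():
    # (4, 256, 768) hidden states from the frozen T5 encoder
    text_embeds = encoder(input_ids = tokens.input_ids.cuda(), attention_mask = tokens.attention_mask.cuda()).last_hidden_state

text_masks = tokens.attention_mask.bool().cuda() # (4, 256)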
For simpler training, you can directly supply text strings instead of precomputing text encodings. (Although for scaling purposes, you will definitely want to precompute the textual embeddings + mask)
The number of textual captions must match the batch size of the images if you go this route.
# mock images and text (get a lot of this)
texts = [
    'a child screaming at finding a worm within a half-eaten apple',
    'lizard running across the desert on two feet',
    'waking up to a psychedelic landscape',
    'seashells sparkling in the shallow waters'
]
images = torch.randn(4, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
    loss = imagen(images, texts = texts, unet_number = i)
    loss.backward()
With the ImagenTrainer wrapper class, the exponential moving averages for all of the U-nets in the cascading DDPM will be automatically taken care of when calling update.
import torch
from imagen_pytorch import Unet, Imagen, ImagenTrainer
# unet for imagen
unet1 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True, True)
)

unet2 = Unet(
    dim = 32,
    cond_dim = 512,
    dim_mults = (1, 2, 4, 8),
    num_resnet_blocks = (2, 4, 8, 8),
    layer_attns = (False, False, False, True),
    layer_cross_attns = (False, False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
    unets = (unet1, unet2),
    text_encoder_name = 't5-large',
    image_sizes = (64, 256),
    timesteps = 1000,
    cond_drop_prob = 0.1
).cuda()
# wrap imagen with the trainer class
trainer = ImagenTrainer(imagen)
# mock images (get a lot of this) and text encodings from large T5
text_embeds = torch.randn(64, 256, 1024).cuda()
text_masks = torch.ones(64, 256).bool().cuda()
images = torch.randn(64, 3, 256, 256).cuda()
# feed images into imagen, training each unet in the cascade
for i in (1, 2):
    loss = trainer(
        images,
        text_embeds = text_embeds,
        text_masks = text_masks,
        unet_number = i,
        max_batch_size = 4 # auto divide the batch of 64 up into batch size of 4 and accumulate gradients, so it all fits in memory
    )

    trainer.update(unet_number = i)
# do the above for many many many many steps
# now you can sample an image based on the text embeddings from the cascading ddpm
images = trainer.sample(texts = [
    'a puppy looking anxiously at a giant donut on the table',
    'the milky way galaxy in the style of monet'
], cond_scale = 3.)
images.shape # (2, 3, 256, 256)
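The exponential moving average that trainer.update maintains for each unet is conceptually just a slow copy of the online weights; a rough sketch of the idea (not the library's exact code, and the decay value is illustrative):

def ema_update(ema_model, online_model, beta = 0.995):
    # nudge every EMA parameter a small step toward the current online parameter
    for ema_p, online_p in zip(ema_model.parameters(), online_model.parameters()):
        ema_p.data.lerp_(online_p.data, 1. - beta)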
You can also train Imagen without text (unconditional image generation) as follows
import torch
from imagen_pytorch import Unet, Imagen, SRUnet256, ImagenTrainer
# unets for unconditional imagen
unet1 = Unet(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = 3,
    layer_attns = (False, True, True),
    layer_cross_attns = (False, True, True),
    use_linear_attn = True
)

unet2 = SRUnet256(
    dim = 32,
    dim_mults = (1, 2, 4),
    num_resnet_blocks = (2, 4, 8),
    layer_attns = (False, False, True),
    layer_cross_attns = (False, False, True)
)
# imagen, which contains the unets above (base unet and super resoluting ones)
imagen = Imagen(
    condition_on_text = False, # this must be set to False for unconditional Imagen
    unets = (unet1, unet2),
    image_sizes = (64, 128),
    timesteps = 1000
)
trainer = ImagenTrainer(imagen).cuda()
# now get a ton of images and feed it through the Imagen trainer
training_images = torch.randn(4, 3, 256, 256).cuda()
# train each unet in concert, or separately (recommended) to completion
for u in (1, 2):
    loss = trainer(training_images, unet_number = u)
    trainer.update(unet_number = u)
# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)
images = trainer.sample(batch_size = 16) # (16, 3, 128, 128)
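As the comment above recommends, you will usually want to train each unet separately to completion rather than alternating every step; a sketch of that pattern, continuing the example above (the step count is purely illustrative):

# finish training the base unet before moving on to the super resolution unet
for unet_number in (1, 2):
    for step in range(100_000): # illustrative number of steps
        loss = trainer(training_images, unet_number = unet_number)
        trainer.update(unet_number = unet_number)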
Imagen uses an algorithm called Classifier Free Guidance. When sampling, you apply a scale greater than 1.0 to the conditioning (text in this case). Researcher Netruk44 has reported 5-10 to be optimal, but anything greater than 10 tends to break.
trainer.sample(texts = [
    'a cloud in the shape of a roman gladiator'
], cond_scale = 5.) # <-- cond_scale is the conditioning scale, needs to be greater than 1.0 to be better than average
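For intuition, classifier free guidance amounts to blending the unet's conditional and unconditional predictions, extrapolating past the conditional one by the chosen scale. A rough sketch of the idea (not the library's exact code):

def classifier_free_guidance(pred_cond, pred_uncond, cond_scale = 5.):
    # cond_scale = 1. recovers the purely conditional prediction;
    # larger scales push the sample harder toward the text conditioning
    return pred_uncond + (pred_cond - pred_uncond) * cond_scale

This is also why cond_drop_prob is set during training: the unet must occasionally see no text conditioning so that an unconditional prediction is available at sampling time.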
Not at the moment, but one will likely be trained and open sourced within the year, if not sooner. If you would like to participate, you can join the community of artificial neural network trainers at LAION (the Discord link is in the Readme above) and start collaborating.
StabilityAI for the generous sponsorship, as well as my other sponsors out there
🤗 Huggingface for their amazing transformers library. The text encoder portion is pretty much taken care of because of them
Jorge Gomes for helping out with the T5 loading code and advice on the correct T5 version
Katherine Crowson, for her beautiful code, which helped me understand the continuous time version of gaussian diffusion
Marunine and Netruk44, for reviewing code, sharing experimental results, and helping with debugging
Marunine for providing a potential solution for a color shifting issue in the memory efficient u-nets. Thanks to Jacob for sharing experimental comparisons between the base and memory-efficient unets
You? It isn't done yet, chip in if you are a researcher or skilled ML engineer
@inproceedings{Saharia2022PhotorealisticTD,
title = {Photorealistic Text-to-Image Diffusion Models with Deep Language Understanding},
author = {Chitwan Saharia and William Chan and Saurabh Saxena and Lala Li and Jay Whang and Emily L. Denton and Seyed Kamyar Seyed Ghasemipour and Burcu Karagol Ayan and Seyedeh Sara Mahdavi and Raphael Gontijo Lopes and Tim Salimans and Jonathan Ho and David Fleet and Mohammad Norouzi},
year = {2022}
}
@article{Alayrac2022Flamingo,
title = {Flamingo: a Visual Language Model for Few-Shot Learning},
author = {Jean-Baptiste Alayrac et al},
year = {2022}
}
@article{Choi2022PerceptionPT,
title = {Perception Prioritized Training of Diffusion Models},
author = {Jooyoung Choi and Jungbeom Lee and Chaehun Shin and Sungwon Kim and Hyunwoo J. Kim and Sung-Hoon Yoon},
journal = {ArXiv},
year = {2022},
volume = {abs/2204.00227}
}
@inproceedings{Sankararaman2022BayesFormerTW,
title = {BayesFormer: Transformer with Uncertainty Estimation},
author = {Karthik Abinav Sankararaman and Sinong Wang and Han Fang},
year = {2022}
}
@article{So2021PrimerSF,
title = {Primer: Searching for Efficient Transformers for Language Modeling},
author = {David R. So and Wojciech Ma'nke and Hanxiao Liu and Zihang Dai and Noam M. Shazeer and Quoc V. Le},
journal = {ArXiv},
year = {2021},
volume = {abs/2109.08668}
}
@misc{cao2020global,
title = {Global Context Networks},
author = {Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
year = {2020},
eprint = {2012.13375},
archivePrefix = {arXiv},
primaryClass = {cs.CV}
}