TPU JAX Flax stable-diffusion text-to-image

Maybe this is the first-ever model trained with TPUs and converted to 🧨 PyTorch 🎊🎉 <br/> Trained with Google Cloud TPUs.

Runtime: 3h 26m 44s
Steps: 18000
Precision: bf16
Learning Rate: 1e-6
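For a sense of throughput, the runtime and step count above work out to roughly 0.69 s per training step. The sketch below does that arithmetic in plain Python. It assumes the reported runtime covers only the 18000 steps. The card doesn't say whether it also includes XLA compilation, checkpointing, or the PyTorch conversion, so treat the result as an upper bound on per-step time.

```python
# Training summary values copied from this card.
RUNTIME = (3, 26, 44)  # hours, minutes, seconds
STEPS = 18_000
LEARNING_RATE = 1e-6
PRECISION = "bf16"


def runtime_seconds(h: int, m: int, s: int) -> int:
    """Convert an (hours, minutes, seconds) runtime to total seconds."""
    return h * 3600 + m * 60 + s


def seconds_per_step(total_seconds: int, steps: int) -> float:
    """Average wall-clock time per step.

    Assumption: the runtime covers only the training steps; if it also
    includes compilation or checkpointing, the true per-step time is lower.
    """
    return total_seconds / steps


total = runtime_seconds(*RUNTIME)
per_step = seconds_per_step(total, STEPS)
print(f"{total} s total, {per_step:.3f} s/step, {STEPS / total:.2f} steps/s")
# prints: 12404 s total, 0.689 s/step, 1.45 steps/s
```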

plushies