Scalable Diffusion Models with Transformers (DiT)

Official PyTorch Implementation

Paper | Project Page | Run DiT-XL/2

This repo contains PyTorch model definitions, pre-trained weights and training/sampling code for our paper exploring diffusion models with transformers (DiTs). You can find more visualizations on our project page.

Scalable Diffusion Models with Transformers
William Peebles, Saining Xie

UC Berkeley, New York University

We train latent diffusion models, replacing the commonly-used U-Net backbone with a transformer that operates on latent patches. We analyze the scalability of our Diffusion Transformers (DiTs) through the lens of forward pass complexity as measured by Gflops. We find that DiTs with higher Gflops (through increased transformer depth/width or an increased number of input tokens) consistently have lower FID. In addition to good scalability properties, our DiT-XL/2 models outperform all prior diffusion models on the class-conditional ImageNet 512×512 and 256×256 benchmarks, achieving a state-of-the-art FID of 2.27 on the latter.
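The link between patch size and the number of input tokens can be made concrete with a little arithmetic. The sketch below assumes two things that come from the DiT paper's setup rather than from this README: the VAE downsamples each image by a factor of 8 (a 256x256 image becomes a 32x32 latent), and the "/2" in a name like DiT-XL/2 is the patch size p. The helper `num_tokens` is hypothetical and exists only for illustration.

```python
def num_tokens(image_size: int, patch_size: int, vae_downsample: int = 8) -> int:
    """Number of transformer input tokens for a square image.

    Assumes the VAE maps an image_size x image_size image to a latent whose
    side is image_size // vae_downsample, which is then cut into
    non-overlapping patch_size x patch_size patches (one token per patch).
    """
    latent_size = image_size // vae_downsample
    if latent_size % patch_size != 0:
        raise ValueError("latent size must be divisible by the patch size")
    return (latent_size // patch_size) ** 2


# Halving the patch size quadruples the number of tokens.
for name, p in [("DiT-XL/8", 8), ("DiT-XL/4", 4), ("DiT-XL/2", 2)]:
    print(name, num_tokens(256, p))
```

Under these assumptions, DiT-XL/2 sees 256 tokens at 256x256 and 1024 tokens at 512x512, a 4x increase that lines up with the jump from 119 to 525 Gflops in the results table.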

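This repository contains:

  • PyTorch model definitions for DiT
  • Pre-trained class-conditional DiT-XL/2 weights for ImageNet at 256x256 and 512x512
  • Training (train.py) and sampling (sample.py) scripts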

An implementation of DiT directly in Hugging Face diffusers can also be found here.

Setup

First, download and set up the repo:

git clone https://github.com/facebookresearch/DiT.git
cd DiT

We provide an environment.yml file that can be used to create a Conda environment. If you only want to run pre-trained models locally on CPU, you can remove the cudatoolkit and pytorch-cuda requirements from the file.

conda env create -f environment.yml
conda activate DiT

Sampling

Pre-trained DiT checkpoints. You can sample from our pre-trained DiT models with sample.py. The script automatically downloads the weights for whichever pre-trained model you select. It has various arguments to switch between the 256x256 and 512x512 models, adjust the number of sampling steps, change the classifier-free guidance scale, etc. For example, to sample from our 512x512 DiT-XL/2 model, you can use:

python sample.py --image-size 512 --seed 1
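The classifier-free guidance scale mentioned above controls how a class-conditional and an unconditional noise prediction are combined. Below is a minimal sketch of the standard formulation, eps = eps_uncond + s * (eps_cond - eps_uncond), written on plain Python lists. The function `apply_cfg` is hypothetical. The real sample.py works on batched tensors and may differ in details, such as which output channels guidance is applied to.

```python
from typing import List


def apply_cfg(eps_cond: List[float], eps_uncond: List[float], scale: float) -> List[float]:
    """Classifier-free guidance: eps_uncond + scale * (eps_cond - eps_uncond).

    scale == 1.0 returns the conditional prediction unchanged (no guidance);
    scale > 1.0 pushes the prediction further toward the class condition.
    """
    if len(eps_cond) != len(eps_uncond):
        raise ValueError("predictions must have the same length")
    return [u + scale * (c - u) for c, u in zip(eps_cond, eps_uncond)]


eps_c = [0.5, -0.25]   # conditional prediction (toy values)
eps_u = [0.25, 0.25]   # unconditional prediction (toy values)
print(apply_cfg(eps_c, eps_u, 1.0))  # equals eps_c: guidance is effectively off
print(apply_cfg(eps_c, eps_u, 4.0))  # extrapolates away from eps_u
```

A scale of 1 reduces to the plain conditional model. That is why "cfg-scale=1" means sampling without guidance.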

For convenience, our pre-trained DiT models can be downloaded directly here as well:

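| DiT Model | Image Resolution | FID-50K | Inception Score | Gflops |
|-----------|------------------|---------|-----------------|--------|
| XL/2      | 256x256          | 2.27    | 278.24          | 119    |
| XL/2      | 512x512          | 3.04    | 240.82          | 525    |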

Custom DiT checkpoints. If you've trained a new DiT model with train.py (see below), you can add the --ckpt argument to use your own checkpoint instead. For example, to sample from the EMA weights of a custom 256x256 DiT-L/4 model, run:

python sample.py --model DiT-L/4 --image-size 256 --ckpt /path/to/model.pt
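The EMA weights referred to above are an exponential moving average of the training weights, updated after each optimizer step. The following minimal sketch shows the per-parameter update, with plain floats standing in for tensors. The default decay of 0.9999 is a common choice and, we believe, what train.py uses, but treat it as an assumption. `ema_update` is a hypothetical name.

```python
from typing import Dict


def ema_update(ema: Dict[str, float], params: Dict[str, float], decay: float = 0.9999) -> None:
    """In-place EMA step for every parameter: ema <- decay * ema + (1 - decay) * param."""
    for name, value in params.items():
        ema[name] = decay * ema[name] + (1.0 - decay) * value


# Toy run with a small decay so the averaging is visible: 0.0 -> 0.5 -> 0.75 -> 0.875
ema = {"w": 0.0}
params = {"w": 1.0}
for _ in range(3):
    ema_update(ema, params, decay=0.5)
print(ema["w"])  # 0.875
```

With a decay close to 1, the EMA copy changes slowly and smooths out step-to-step noise in the weights. This is why sampling is usually done from the EMA weights rather than the raw training weights.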

Training DiT

We provide a training script for DiT in train.py. This script can be used to train class-conditional DiT models, but it can be easily modified to support other types of conditioning. To launch DiT-XL/2 (256x256) training with N GPUs on one node:

torchrun --nnodes=1 --nproc_per_node=N train.py --model DiT-XL/2 --data-path /path/to/imagenet/train
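With torchrun, each of the N processes gets its own rank. A global batch size and a global seed then have to be split across those processes. The sketch below shows one way to do this. It assumes the global batch size must divide evenly by the number of processes, and that each rank derives its seed as global_seed * world_size + rank. Both are our assumptions about how the "Global Training Seed" in the results table below is expanded per process, so check train.py for the actual logic. `per_rank_settings` is hypothetical.

```python
from typing import Tuple


def per_rank_settings(global_batch_size: int, global_seed: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Return (local_batch_size, seed) for one process in a data-parallel run."""
    if global_batch_size % world_size != 0:
        raise ValueError("global batch size must be divisible by the number of processes")
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    local_batch_size = global_batch_size // world_size
    # A distinct, deterministic seed per process, so ranks don't draw identical noise.
    seed = global_seed * world_size + rank
    return local_batch_size, seed


# Example: a global batch of 256 split over --nproc_per_node=8 with global seed 42
for rank in range(8):
    print(rank, per_rank_settings(256, 42, 8, rank))
```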

PyTorch Training Results

We've trained DiT-XL/2 and DiT-B/4 models from scratch with the PyTorch training script to verify that it reproduces the original JAX results up to several hundred thousand training iterations. Across our experiments, the PyTorch-trained models give similar (and sometimes slightly better) results compared to the JAX-trained models, within reasonable random variation. Some data points:

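| DiT Model | Train Steps | FID-50K (JAX Training) | FID-50K (PyTorch Training) | PyTorch Global Training Seed |
|-----------|-------------|------------------------|----------------------------|------------------------------|
| XL/2      | 400K        | 19.5                   | 18.1                       | 42                           |
| B/4       | 400K        | 68.4                   | 68.9                       | 42                           |
| B/4       | 400K        | 68.4                   | 68.3                       | 100                          |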

These models were trained at 256x256 resolution; we used 8x A100s to train XL/2 and 4x A100s to train B/4. Note that FID here is computed with 250 DDPM sampling steps, with the mse VAE decoder and without guidance (cfg-scale=1).
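Evaluating with 250 DDPM sampling steps means using a subset of the timesteps the model was trained on. This minimal sketch spaces K sampling steps evenly over T training timesteps. It assumes T = 1000, the standard DDPM setting, which this README does not state, and it uses simple rounding. The repository's own respacing code may pick slightly different indices. `spaced_timesteps` is hypothetical.

```python
from typing import List


def spaced_timesteps(num_train_steps: int, num_sample_steps: int) -> List[int]:
    """Pick num_sample_steps roughly evenly spaced indices in [0, num_train_steps - 1]."""
    if not 1 < num_sample_steps <= num_train_steps:
        raise ValueError("need 1 < num_sample_steps <= num_train_steps")
    stride = (num_train_steps - 1) / (num_sample_steps - 1)
    return [round(i * stride) for i in range(num_sample_steps)]


steps = spaced_timesteps(1000, 250)
print(len(steps), steps[:4], steps[-1])
```

Because the stride (about 4.01 here) is greater than 1, every rounded index is distinct. Both endpoints of the training schedule are always included.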

TF32 Note (important for A100 users). When we ran the above tests, TF32 matmuls were disabled per PyTorch's defaults. We've enabled them at the top of train.py and sample.py because they make training and sampling much faster on A100s (and should on other Ampere GPUs too), but note that the use of TF32 may lead to some differences compared to the above results.
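TF32 can change results slightly because it keeps float32's 8-bit exponent but only 10 explicit mantissa bits instead of 23. Matmul inputs therefore lose precision. The minimal sketch below emulates this in pure Python by clearing the low 13 mantissa bits of a float32. This is a simplification: hardware may round rather than truncate, so treat it only as an approximation of the effect. In PyTorch itself, TF32 is controlled by the flags `torch.backends.cuda.matmul.allow_tf32` and `torch.backends.cudnn.allow_tf32`. We assume these are what the top of train.py and sample.py set.

```python
import struct


def to_tf32(x: float) -> float:
    """Approximate TF32 by keeping 10 of float32's 23 mantissa bits (truncation)."""
    (bits,) = struct.unpack("<I", struct.pack("<f", x))  # float32 bit pattern
    bits &= 0xFFFFE000  # clear the low 13 mantissa bits
    (y,) = struct.unpack("<f", struct.pack("<I", bits))
    return y


print(to_tf32(1.0 + 2**-10))  # 2**-10 fits in 10 mantissa bits, so it survives
print(to_tf32(1.0 + 2**-12))  # 2**-12 does not, so this prints 1.0
```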

Enhancements

Training (and sampling) could likely be sped up significantly by:

  • using Flash Attention in the DiT model
  • using torch.compile in PyTorch 2.0

Basic features that would be nice to add:

  • Monitor FID and other metrics
  • Generate and save samples from the EMA model periodically
  • Resume training from a checkpoint
  • AMP/bfloat16 support

🔥 Feature Update: Check out chuanyangjin/fast-DiT (Fast Diffusion Models with Transformers) on GitHub to preview a selection of training-speed and memory-saving features, including gradient checkpointing, mixed precision training and pre-extracted VAE features. With these advancements, we have achieved a training speed of 0.84 steps/sec for DiT-XL/2 using just a single A100 GPU.
