基于Stable Diffusion XL模型进行文本生成图像的训练
flyfish
环境变量部分
bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
export DATASET_NAME="lambdalabs/naruto-blip-captions"
MODEL_NAME
:指定预训练模型的名称或路径。这里使用的是stabilityai/stable-diffusion-xl-base-1.0
,也就是Stable Diffusion XL的基础版本1.0。VAE_NAME
:指定变分自编码器(VAE)的名称或路径。madebyollin/sdxl-vae-fp16-fix
是针对Stable Diffusion XL的一个经过修复的VAE模型,适用于半精度(FP16)计算。DATASET_NAME
:指定训练所使用的数据集名称或路径。这里使用的是lambdalabs/naruto-blip-captions
,是一个包含火影忍者相关图像及其描述的数据集。
accelerate launch
命令参数部分
bash
accelerate launch train_text_to_image_sdxl.py \
这行代码使用 accelerate
工具来启动 train_text_to_image_sdxl.py
脚本,accelerate
可以帮助我们在多GPU、TPU等环境下进行分布式训练。
脚本参数部分
--pretrained_model_name_or_path=$MODEL_NAME
:指定预训练模型的名称或路径,这里使用前面定义的MODEL_NAME
环境变量。--pretrained_vae_model_name_or_path=$VAE_NAME
:指定预训练VAE模型的名称或路径,使用前面定义的VAE_NAME
环境变量。--dataset_name=$DATASET_NAME
:指定训练数据集的名称或路径,使用前面定义的DATASET_NAME
环境变量。--enable_xformers_memory_efficient_attention
:启用xformers
库的内存高效注意力机制,能减少训练过程中的内存占用。--resolution=512 --center_crop --random_flip
:--resolution=512
:将输入图像的分辨率统一调整为512x512像素。--center_crop
:对图像进行中心裁剪,使其达到指定的分辨率。--random_flip
:在训练过程中随机对图像进行水平翻转,以增加数据的多样性。
--proportion_empty_prompts=0.2
:设置空提示(没有文本描述)的样本在训练数据中的比例为20%。--train_batch_size=1
:每个训练批次包含的样本数量为1。--gradient_accumulation_steps=4 --gradient_checkpointing
:--gradient_accumulation_steps=4
:梯度累积步数为4,即每4个批次的梯度进行一次更新,这样可以在有限的内存下模拟更大的批次大小。--gradient_checkpointing
:启用梯度检查点机制,通过减少内存使用来支持更大的模型和批次大小。
--max_train_steps=10000
:最大训练步数为10000步。--use_8bit_adam
:使用8位Adam优化器,能减少内存占用。--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0
:--learning_rate=1e-06
:学习率设置为1e-6。--lr_scheduler="constant"
:学习率调度器设置为常数,即训练过程中学习率保持不变。--lr_warmup_steps=0
:学习率预热步数为0,即不进行学习率预热。
--mixed_precision="fp16"
:使用半精度(FP16)混合精度训练,能减少内存使用并加快训练速度。--report_to="wandb"
:将训练过程中的指标报告到Weights & Biases(WandB)平台,方便进行可视化和监控。--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5
:--validation_prompt="a cute Sundar Pichai creature"
:指定验证时使用的文本提示,这里是"一个可爱的桑达尔·皮查伊形象"。--validation_epochs 5
:每5个训练轮次进行一次验证。
--checkpointing_steps=5000
:每5000步保存一次模型的检查点。--output_dir="sdxl-naruto-model"
:指定训练好的模型的输出目录为sdxl-naruto-model
。--push_to_hub
:将训练好的模型推送到Hugging Face模型库。
离线环境运行
bash
# 假设已经把模型、VAE和数据集下载到本地了
# 这里假设模型在当前目录下的 sdxl-base-1.0 文件夹
# VAE 在 sdxl-vae-fp16-fix 文件夹
# 数据集在 naruto-blip-captions 文件夹
# 定义本地路径
MODEL_NAME="./sdxl-base-1.0"
VAE_NAME="./sdxl-vae-fp16-fix"
DATASET_NAME="./naruto-blip-captions"
# 移除需要外网连接的参数
accelerate launch train_text_to_image_sdxl.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--pretrained_vae_model_name_or_path=$VAE_NAME \
--dataset_name=$DATASET_NAME \
--enable_xformers_memory_efficient_attention \
--resolution=512 --center_crop --random_flip \
--proportion_empty_prompts=0.2 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 --gradient_checkpointing \
--max_train_steps=10000 \
--use_8bit_adam \
--learning_rate=1e-06 --lr_scheduler="constant" --lr_warmup_steps=0 \
--mixed_precision="fp16" \
--validation_prompt="a cute Sundar Pichai creature" --validation_epochs 5 \
--checkpointing_steps=5000 \
--output_dir="sdxl-naruto-model"
移除需要外网连接的参数 :去掉 --report_to="wandb"
和 --push_to_hub
参数,因为 wandb
需要外网连接来上传训练指标,--push_to_hub
则需要外网连接把模型推送到Hugging Face模型库。