Generative AI 新世界 | 文生图领域动手实践:预训练模型的微调

在上期文章,我们探讨了预训练模型的部署和推理,包括运行环境准备、角色权限配置、支持的主要推理参数、图像的压缩输出、提示工程 (Prompt Engineering)、反向提示 (Negative Prompting) 等内容。

亚马逊云科技开发者社区为开发者们提供全球的开发技术资源。这里有技术文档、开发案例、技术专栏、培训视频、活动与竞赛等。帮助中国开发者对接世界最前沿技术,观点,和项目,并将中国优秀开发者或技术推荐给全球云社区。如果你还没有关注/收藏,看到这里请一定不要匆匆划过,点这里让它成为你的技术宝库!

本期文章,我们将探讨如何在自定义数据集上来微调(fine-tuned)模型 ,该模型可以针对任何图像数据集进行微调。即使你手上只有几张自定义的图像提供做训练,模型也能输出比较理想的结果。

首先,让我们通过一篇论文的概括解读,来了解这种文生图模型的微调 (fine-tuned),背后的工作原理和理论基础知识。

DreamBooth 论文概述

这种文生图模型的微调(fine-tuned)理论基础来自于 DreamBooth 论文,如下图所示:

DreamBooth: Fine Tuning Text-to-Image Diffusion Models for Subject-DrivenGeneration

arxiv.org/pdf/2208.12...

在论文的开头,作者提出一个挑战性的问题:

虽然当时的文生图模型已经可以根据给定的 prompt 生成高质量的图片,但是这些模型并不能模仿给定参考图片中的物体要素,在不同情景中来生成新的图片。

举个例子。

我家里有一只叫做"小花"的可爱加菲猫,如下图:

我想让加菲猫"小花"带上一顶礼帽,如下图:

或者带上一副很酷炫的墨镜,如下图:

甚至想象下她刷牙的魔幻景象,如下图:

事实上,上面的这些加菲猫"小花"的照片(戴礼帽、戴墨镜、刷牙),都是大模型使用 DreamBooth 做微调后生成的。很有趣吧?在文末会提供生成这些魔幻照片的全部代码。

我们先看下 DreamBooth 论文阐述的背后原理。

DreamBooth 论文提出一个新颖的方法:将输入图片中的物体与一个特殊标识符绑定在一起,即用这个特殊标记符来表示输入图片中的物体。因此论文为微调模型设计了一种 prompt 格式:a [identifier] [class noun],即将所有输入图片的 prompt 都设置成这种形式,其中 identifier 是一个与输入图片中物体相关联的特殊标记符,class noun 是对物体的类别描述。

这里之所以在 prompt 中加入类别,是因为想利用预训练模型中关于该类别物品的先验知识,并将先验知识与特殊标记符相关信息进行融合,这样就可以在不同场景下生成不同姿势的目标物体。

简单来说就是:不要学了新的知识,就忘了旧的知识

论文提出的方法,大致如下图所示,即仅仅通过 3 到 5 张图片去微调文生图模型,使得模型能将输入图片中特定的物品和 prompt 中的特殊标记符关联起来了。

Source: dreambooth.github.io?trk=cndc-detail

关于特殊标记符的选择,论文提出通过在词表中选择罕见词来作为特殊标记符,这样避免了预训练模型对特殊标记符有很强烈的先验知识。

DreamBooth 论文提出一个新的微调方法:通过预先生成的一些图像,来保留先验损失权重;以此来解决过拟合与语言漂移问题。用模型自己生成的样本来监督模型,以便在 few-shot(小样本)微调开始后保留先验知识,如以下论文中提供的解释图所示:

Source: dreambooth.github.io/?trk=cndc-d...

给定大约 3-5 张拍摄对象的图像,我们分两步微调文本到图像的扩散:

  1. 使用输入图像与包含唯一标识符和主题所属类名称(例如:"A photo of a [T] dog")的文本提示配对;同时,我们应用特定于类的预先保存损失,它利用了模型之前的语义通过在文本提示中注入类名,来鼓励它生成属于受试者类的各种实例提示(例如:"A photo of a dog")。
  2. 使用从我们的输入图像集中拍摄的低分辨率和高分辨率图像,对超分辨率组件进行微调,这使我们能够保持对拍摄对象小细节的高保真度。

引入了先验损失的 loss 公式,如下所示:

通过这种 DreamBooth 方法,使得:输入训练集 + 提示词 [v] dog,然后还有用模型本身自己生成的 dog 图像,训练完成后得到了一个特殊标记符:[v]。通过这个特殊标记符 [v],就把这次训练的 dog 和其他本身学过的 dog 分开了。

最后得到惊艳的结果,比如给一只小熊带上太阳镜,如下图所示:

Source: dreambooth.github.io/?trk=cndc-d...

接下来,我们将完整用代码演示,如何给我家的加菲猫"小花"带上眼镜和礼帽。

Fine-tune 预训练模型在自有数据集上的微调

我们使用 Amazon SageMaker Studio 来实现在自有数据上的模型微调。

我首先将为我家的加菲猫"小花"拍摄几张照片,然后用这几张照片来微调模型;完成模型微调后,我们将使用 "a picture of Garfield cat with glasses" 这样的提示词,来直接为我家的加菲猫"小花"带上眼镜。

1 实例和环境准备

这个 Notebook 在带有 Python 3(Data Science)内核的 SageMaker Studio 中,使用 ml.t3.medium 实例上进行了测试。要对数据集的模型进行微调,您需要在账户中提供 ml.g4dn.2xlarge 实例类型。

完整的示例代码,可参考以下 GitHub 文档链接,从 "Fine-tune the pre-trained model on a custom dataset" 这一部分开始阅读代码:

github.com/aws/studio-...

你存放自定义照片的 s3 路径,应该看起来像这样:s3://bucket_name/input_directory/

请注意,后面的"/"为必填项。

以下是训练数据的示例格式:

css 复制代码
input_directory
    |---instance_image_1.png
    |---instance_image_2.png
    |---instance_image_3.png
    |---instance_image_4.png
    |---instance_image_5.png
    |---dataset_info.json
    |---class_data_dir
        |---class_image_1.png
        |---class_image_2.png
        |---class_image_3.png
        |---class_image_4.png

预先保存、实例提示和类提示(Prior preservation, instance prompt and class prompt) :预先保存是一种使用我们正在尝试训练的同一个类的其他图像的技术。例如,如果训练数据由特定狗的图像组成,并事先保存,则我们会合并普通犬的类别图像。它试图通过在为特定狗训练时显示不同狗的图像来避免过度拟合。类提示中缺少表示实例提示中存在的特定狗的标签。

例如,实例提示可能是 "A photo of a Garfield cat",类提示可能是 "A photo of a cat"。

您可以通过将超参数设置为 _prior_preservation = True 来启用预先保存。

以下为使用我家加菲猫"小花"的照片的 dataset_info.json 的文件示例:

bash 复制代码
$ cat dataset_info.json
{
  "instance_prompt": "A photo of a Garfield cat",
  "class_prompt": "A photo of a cat"
}

以下是我为了微调模型,而拍摄的我家加菲猫"小花"的照片。我只用了下面这六张照片,就实现了模型的微调。

我存放照片(即为微调模型提供的自定义训练图片)的 S3 桶参考路径如下:s3://sagemaker-us-east-1-xxxxxxxxxxxx/haowen-datasets/cat_finetuning/

其中 "sagemaker-us-east-1-xxxxxxxxxxxx" 需要更新为你自己定义的桶名。

最终完成微调后,模型存放的 S3 桶参考路径如下:s3://sagemaker-us-east-1-xxxxxxxxxxxx/jumpstart-example-sd-training/output

其中 "sagemaker-us-east-1-xxxxxxxxxxxx" 需要更新为你自己定义的桶名。

2 检索训练数据的 Artifacts

在这里,我们检索训练 docker 容器、训练算法源和预先训练的基础模型。

请注意,model_version= "*" 获取的是最新的模型版本号。以下代码选择了 Stable Diffusion V2.1 Base 的文生图大模型。

bash 复制代码
# Select a model 
train_model_id, train_model_version, train_scope = (
    "model-txt2img-stabilityai-stable-diffusion-v2-1-base",
    "*",
    "training",
)

以下代码选择了微调模型的实例是 ml.g4dn.2xlarge:

ini 复制代码
training_instance_type = "ml.g4dn.2xlarge"

以下代码获取 Docker Image:

ini 复制代码
# Retrieve the docker image
train_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    model_id=train_model_id,
    model_version=train_model_version,
    image_scope=train_scope,
    instance_type=training_instance_type,
)

以下代码获取训练脚本:

ini 复制代码
# Retrieve the training script. This contains all the necessary files including data processing, model training etc.
train_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope=train_scope
)

以下代码获取预训练模型的 tarball 包,用于之后的微调工作:

ini 复制代码
# Retrieve the pre-trained model tarball to further fine-tune
train_model_uri = model_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, model_scope=train_scope
)

3 设置训练参数

现在我们已经完成了所有需要的设置,我们已经准备好微调 Stable Diffusion 模型了。首先,让我们创建一个 sageMaker.estimator.Estimator 对象。该 Estimator 将启动训练作业。

模型的微调训练需要设置两种参数。

第一组参数是训练作业的参数。其中包括:

  1. 训练数据路径,这是存储输入数据的 S3 路径。即之前我们准备的 "s3://sagemaker-us-east-1-xxxxxxxxxxxx/haowen-datasets/cat_finetuning/" 这个路径;
  2. 输出路径,这是存储微调模型训练的输出 s3 路径。即之前我们准备的"s3://sagemaker-us-east-1-xxxxxxxxxxxx/jumpstart-example-sd-training/output" 这个路径;
  3. 训练实例类型,这表示运行模型微调训练的机器类型。我们在上面定义了训练实例类型,以获取正确的 train_image_uri。

第二组参数是特定于算法的训练超参数。对于算法特定的超参数,我们首先获取算法接受的训练超参数的 python 字典及其默认值,然后可以将其改写为自定义值。示例代码如下所示:

ini 复制代码
from sagemaker import hyperparameters

# Retrieve the default hyper-parameters for fine-tuning the model
hyperparameters = hyperparameters.retrieve_default(
    model_id=train_model_id, model_version=train_model_version
)

# [Optional] Override default hyperparameters with custom values
hyperparameters["max_steps"] = "400"
print(hyperparameters)

4 启动模型微调训练

我们首先使用所有必需的 assets 创建 estimator 对象,然后启动训练作业。

ini 复制代码
from sagemaker.estimator import Estimator
from sagemaker.utils import name_from_base
from sagemaker.tuner import HyperparameterTuner

training_job_name = name_from_base(f"jumpstart-example-{train_model_id}-transfer-learning")

# Create SageMaker Estimator instance
sd_estimator = Estimator(
    role=aws_role,
    image_uri=train_image_uri,
    source_dir=train_source_uri,
    model_uri=train_model_uri,
    entry_point="transfer_learning.py",  # Entry-point file in source_dir and present in train_source_uri.
    instance_count=1,
    instance_type=training_instance_type,
    max_run=360000,
    hyperparameters=hyperparameters,
    output_path=s3_output_location,
    base_job_name=training_job_name,
)

# Launch a SageMaker Training job by passing s3 path of the training data
sd_estimator.fit({"training": training_dataset_s3_path}, logs=True)

模型训练开始后,如果观察 SageMaker 的控制台,会发现:

  1. 训练任务的状态,从 "InProgress" 逐渐变成 "Completed";
  2. 超参调优的状态,从 "InProgress" 逐渐变成 "Completed"。

如下图所示:

经过那六张照片作为新的输入数据,微调后的模型重新训练完成后,就可以进入以下的模型部署阶段了。

5 微调后模型的部署

我们将遵循上一篇中介绍的模型部署的相同步骤,在训练好的模型上运行推理。我们首先检索用于部署端点的 jumpstart 工件。但是,我们部署的是经过微调的 sd_estimator 估算器,而不是上一篇中使用的 base_predictor 估算器。

ini 复制代码
inference_instance_type = "ml.g4dn.2xlarge"

# Retrieve the inference docker container uri
deploy_image_uri = image_uris.retrieve(
    region=None,
    framework=None,  # automatically inferred from model_id
    image_scope="inference",
    model_id=train_model_id,
    model_version=train_model_version,
    instance_type=inference_instance_type,
)
# Retrieve the inference script uri. This includes scripts for model loading, inference handling etc.
deploy_source_uri = script_uris.retrieve(
    model_id=train_model_id, model_version=train_model_version, script_scope="inference"
)

endpoint_name = name_from_base(f"jumpstart-example-FT-{train_model_id}-")

# Use the estimator from the previous step to deploy to a SageMaker endpoint
finetuned_predictor = sd_estimator.deploy(
    initial_instance_count=1,
    instance_type=inference_instance_type,
    entry_point="inference.py",  # entry point file in source_dir and present in deploy_source_uri
    image_uri=deploy_image_uri,
    source_dir=deploy_source_uri,
    endpoint_name=endpoint_name,
)

在等待新模型部署的过程中,可以回到 SageMaker 的控制台,在 Endpoints 项中刷新检查模型部署的情况。当 Status 从 "Creating" 变成 "Completed",就表示微调后的新模型已经部署完成可以开始进行推理了。如下图所示:

6 微调后模型的推理

下面进入激动人心的时刻,我们在微调后的模型上进行推理。

我输入的提示词是:"a photo of a Garfield cat with a hat"(一只带帽子的加菲猫)。

ini 复制代码
text = " a photo of a Garfield cat with a hat"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

模型的魔幻输出如下图所示。我们成功地给加菲猫"小花"带上礼帽了!

接着我们给加菲猫"小花"带上眼镜,我输入的提示词是:"a picture of Garfield cat with glasses":

ini 复制代码
text = " a picture of Garfield cat with glasses"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

模型的输出如下:

最后让加菲猫"小花"像人类一样去刷牙,我输入的提示词是:"a picture of Garfield cat brushing her teeth":

ini 复制代码
text = " a picture of Garfield cat brushing her teeth"
query_response = query(finetuned_predictor, text)
img, prmpt = parse_response(query_response)
display_img_and_prompt(img, prmpt)

神奇吧?加菲猫"小花"会自己刷牙了!

7 计算资源删除和清理

和以前一样,实验完成后别忘记清除相关的 endpoint 资源,以避免产生不必要的费用:

ini 复制代码
# Delete the SageMaker endpoint
finetuned_predictor.delete_model()
finetuned_predictor.delete_endpoint()

总结

本文我们学习了如何使用 Amazon SageMaker JumpStart 方便地微调文生图的 Stable Diffusion 模型。

Amazon SageMaker JumpStart 为预训练的模型提供了微调功能,本文的例子中,你只需使用六张训练图像即可根据自己的用例调整模型。这在创建个性化艺术品、独特的徽标、企业的 LOGO、或者其他需要自定义设计的场景时非常有用。

下一期的文章,我们将重新回到文本生成的大模型场景,探讨如何在 Amazon SageMaker JumpStart 上部署当今炙手可热的开源大语言模型。我们将以 Falcon 40B 开源大模型为例,逐行代码轻松部署高达 400 亿参数的这个大型语言模型。敬请期待。

请持续关注 Build On Cloud 专栏,了解更多面向开发者的技术分享和云开发动态!

作者 黄浩文

亚马逊云科技资深开发者布道师,专注于 AI/ML、Data Science 等。拥有 20 多年电信、移动互联网以及云计算等行业架构设计、技术及创业管理等丰富经验,曾就职于 Microsoft、Sun Microsystems、中国电信等企业,专注为游戏、电商、媒体和广告等企业客户提供 AI/ML、数据分析和企业数字化转型等解决方案咨询服务。

文章来源:dev.amazoncloud.cn/column/arti...

相关推荐
不去幼儿园1 小时前
【MARL】深入理解多智能体近端策略优化(MAPPO)算法与调参
人工智能·python·算法·机器学习·强化学习
想成为高手4991 小时前
生成式AI在教育技术中的应用:变革与创新
人工智能·aigc
YSGZJJ2 小时前
股指期货的套保策略如何精准选择和规避风险?
人工智能·区块链
无脑敲代码,bug漫天飞2 小时前
COR 损失函数
人工智能·机器学习
HPC_fac130520678163 小时前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
小陈phd6 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao7 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
ZHOU_WUYI11 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若12311 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界11 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲