Fine-tune a FLUX.1 LoRA with your own images and deploy it with Azure Machine Learning

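Table of contents

Introduction

[Understanding the FLUX.1 model family](#understanding-the-flux1-model-family)

[What is Dreambooth?](#what-is-dreambooth)

Prerequisites

[Steps to fine-tune FLUX with Dreambooth](#steps-to-fine-tune-flux-with-dreambooth)

[Step 1: Set up the environment](#step-1-set-up-the-environment)

[Step 2: Load the libraries](#step-2-load-the-libraries)

[Step 3: Prepare the dataset](#step-3-prepare-the-dataset)

[3.1 Upload images to the datastore through an AML data asset (URI folder)](#31-upload-images-to-the-datastore-through-an-aml-data-asset-uri-folder)

[Step 4: Create the training environment](#step-4-create-the-training-environment)

[Step 5: Create the compute](#step-5-create-the-compute)

[Step 6: Submit the fine-tuning job](#step-6-submit-the-fine-tuning-job)

[Step 7: Download and register the fine-tuned model](#step-7-download-and-register-the-fine-tuned-model)

[Step 8: Deploy to a managed online endpoint](#step-8-deploy-to-a-managed-online-endpoint)

[Step 9: Create the inference environment for the online endpoint](#step-9-create-the-inference-environment-for-the-online-endpoint)

[Step 10: Create a deployment for the managed online endpoint](#step-10-create-a-deployment-for-the-managed-online-endpoint)

[Step 11: Test the deployment](#step-11-test-the-deployment)

Conclusion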


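Introduction

Artificial intelligence and machine learning continue to evolve rapidly, and generative AI models have made significant progress. One notable development is the FLUX.1 model suite from Black Forest Labs. These models push the boundaries of text-to-image synthesis, offering exceptional image detail, prompt adherence, and stylistic diversity. In this blog, we walk through fine-tuning a FLUX model with Dreambooth, a method that has become popular for its effectiveness at producing high-quality, customized AI-generated content.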

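Understanding the FLUX.1 model family

Black Forest Labs has released three variants of FLUX.1:

  1. FLUX.1 [pro]: The premium offering with top-tier image generation capabilities. It is a closed-weight model available through an API.
  2. FLUX.1 [dev]: An open-weight, guidance-distilled model for non-commercial use that offers efficient performance.
  3. FLUX.1 [schnell]: The fastest variant, designed for local development and personal use and released under the Apache 2.0 license.

For more information, see the official announcement from Black Forest Labs.

These models are built on a hybrid architecture of multimodal and parallel diffusion transformer blocks, scaled to 12 billion parameters. Black Forest Labs reports state-of-the-art performance that surpasses other leading models.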

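What is Dreambooth?

Dreambooth is a technique for fine-tuning a generative model on a small dataset, often just a handful of images of one subject, so that it produces highly customized outputs. It builds on the existing capabilities of a pretrained model and adapts the model to the specific details, style, or subject in the fine-tuning dataset. This makes the technique especially useful for applications that need personalized content generation.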
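
To make the idea concrete: Dreambooth trains on the subject's images under an instance prompt that contains a rare identifier (for example `photo of sks dog`). It can optionally add a prior-preservation term computed on generic images of the same class (for example `photo of a dog`), which keeps the model from forgetting what the class looks like in general. The sketch below shows only the structure of that objective, using plain Python lists. The helper names are illustrative, and it leaves out the latent tensors and timestep weighting used by the real Diffusers training scripts. In `train_dreambooth_lora_flux.py`, the prior term is switched on with `--with_prior_preservation` and scaled by `--prior_loss_weight`.

```python
def mse(pred, target):
    """Mean squared error between two equal-length lists of numbers."""
    if len(pred) != len(target):
        raise ValueError("pred and target must have the same length")
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


def dreambooth_loss(inst_pred, inst_target,
                    prior_pred=None, prior_target=None, prior_loss_weight=1.0):
    """Instance loss, plus an optional weighted prior-preservation loss.

    Simplified sketch: real training works on latent tensors and applies
    per-timestep weighting, which is omitted here.
    """
    loss = mse(inst_pred, inst_target)  # fit the subject's own images
    if prior_pred is not None:
        # also fit generic class images (generated by the original model)
        # so that the model's prior for the class is preserved
        loss += prior_loss_weight * mse(prior_pred, prior_target)
    return loss


# Instance images only
print(dreambooth_loss([0.5, 1.0], [0.0, 1.0]))  # 0.125
# Instance images plus a class-image batch, prior term weighted by 0.5
print(dreambooth_loss([0.5, 1.0], [0.0, 1.0], [2.0, 0.0], [1.0, 0.0], 0.5))  # 0.375
```

The training command in this walkthrough passes only instance data, so it corresponds to the first call. `--class_prompt` only takes effect together with `--with_prior_preservation`.
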

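Prerequisites

Before we fine-tune the **FLUX.1 [schnell]** model with Dreambooth, make sure you have the following:

  • Access to the **FLUX.1 [schnell]** model, which is available on Hugging Face.
  • A dataset of images for fine-tuning, together with text that describes them. For Dreambooth, a single instance prompt shared by all images is enough.
  • A compute environment with enough resources (such as GPUs) to handle training.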
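
Here is a rough sense of what "enough resources" means. The FLUX.1 transformer has about 12 billion parameters, so just holding its weights takes roughly 24 GB in bf16 (2 bytes per parameter) and 48 GB in fp32 (4 bytes per parameter). This back-of-the-envelope sketch covers the transformer weights only. Text encoders, the VAE, activations, and optimizer state all add to the real requirement, and the walkthrough trains on an 80 GB A100.

```python
def weight_memory_gb(num_params: float, bytes_per_param: int) -> float:
    """Memory to store the weights alone, in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / 1e9


FLUX_TRANSFORMER_PARAMS = 12e9  # ~12 billion parameters

for dtype, nbytes in [("fp32", 4), ("bf16", 2)]:
    print(f"{dtype}: {weight_memory_gb(FLUX_TRANSFORMER_PARAMS, nbytes):.0f} GB")
# fp32: 48 GB
# bf16: 24 GB
```
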

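Steps to fine-tune FLUX with Dreambooth

In this blog, we use Azure Machine Learning to fine-tune a text-to-image model that generates pictures of a specific dog from text prompts.

Before you begin, make sure you have the following:

  1. An Azure account with access to Azure Machine Learning.
  2. A basic understanding of Python and Jupyter notebooks.
  3. Familiarity with Hugging Face's Diffusers library.

Step 1: Set up the environment

First, set up your environment by installing the required libraries. You can use the following commands: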

```bash
pip install transformers diffusers accelerate
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# Azure ML SDK v2, authentication, and MLflow tracking used in the following steps
pip install azure-ai-ml azure-identity mlflow azureml-mlflow matplotlib
```

Step 2: Load the libraries

Import the libraries used throughout this walkthrough:

```python
import sys
sys.path.insert(0, '..')
import os
import shutil
import random
from azure.ai.ml import automl, Input, Output, MLClient, command, load_job
from azure.identity import DefaultAzureCredential, InteractiveBrowserCredential

from azure.ai.ml.entities import Data, Environment, AmlCompute
from azure.ai.ml.constants import AssetTypes
from azure.core.exceptions import ResourceNotFoundError

import matplotlib.pyplot as plt

import mlflow
from mlflow.tracking.client import MlflowClient
```

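Before diving into the code, you need to connect to your workspace. The workspace is the top-level resource for Azure Machine Learning. It provides a centralized place to work with all the artifacts you create when you use Azure Machine Learning.

We use `DefaultAzureCredential` to access the workspace. `DefaultAzureCredential` should be able to handle most scenarios. To learn about other available credentials, see the set up authentication documentation and the azure-identity reference documentation.

In the cell below, replace `AML_WORKSPACE_NAME`, `RESOURCE_GROUP`, and `SUBSCRIPTION_ID` with their respective values.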

```python
from azure.ai.ml import MLClient
from azure.identity import DefaultAzureCredential

credential = DefaultAzureCredential()
workspace_ml_client = None
try:
    workspace_ml_client = MLClient.from_config(credential)
    subscription_id = workspace_ml_client.subscription_id
    resource_group = workspace_ml_client.resource_group_name
    workspace_name = workspace_ml_client.workspace_name
except Exception as ex:
    print(ex)
    # Enter details of your AML workspace
    subscription_id = "SUBSCRIPTION_ID"
    resource_group = "RESOURCE_GROUP"
    workspace_name = "AML_WORKSPACE_NAME"

workspace_ml_client = MLClient(
    credential, subscription_id, resource_group, workspace_name
)
registry_ml_client = MLClient(
    credential,
    subscription_id,
    resource_group,
    registry_name="azureml",
)
```
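
`MLClient.from_config()` looks for a `config.json` file starting in the current directory (and, if needed, its parent directories). You can download this file from the workspace overview page in the Azure portal. If you prefer to create it yourself, here is a minimal sketch that writes the three keys the client reads. The helper name is illustrative, and the values are the same placeholders used above.

```python
import json
import tempfile
from pathlib import Path


def write_workspace_config(subscription_id: str, resource_group: str,
                           workspace_name: str, directory: str = ".") -> Path:
    """Write a config.json in the format read by MLClient.from_config()."""
    config = {
        "subscription_id": subscription_id,
        "resource_group": resource_group,
        "workspace_name": workspace_name,
    }
    path = Path(directory) / "config.json"
    path.write_text(json.dumps(config, indent=2))
    return path


# Demo in a temporary folder; in practice, write it next to your notebook
with tempfile.TemporaryDirectory() as tmp:
    cfg = write_workspace_config("SUBSCRIPTION_ID", "RESOURCE_GROUP",
                                 "AML_WORKSPACE_NAME", tmp)
    print(cfg.read_text())
```
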

```python
workspace = workspace_ml_client.workspace_name
ws_details = workspace_ml_client.workspaces.get(workspace)  # fetch once, reuse below
subscription_id = ws_details.id.split("/")[2]
resource_group = ws_details.resource_group

local_train_data = './train-data/monu/'  # Azure ML dataset will be created for training on this content
generated_images = './results/monu'

azureml_dataset_name = 'monu'  # Name of the dataset
train_target = 'gpu-cluster-big'
experiment_name = 'dreambooth-finetuning'
training_env_name = 'dreambooth-flux-train-envn'
inference_env_name = 'flux-inference-envn'
```

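
The `.split("/")[2]` call works because Azure Resource Manager IDs have a fixed shape, `/subscriptions/<id>/resourceGroups/<name>/providers/Microsoft.MachineLearningServices/workspaces/<name>`, so index 2 is the subscription ID. A slightly more defensive sketch looks the parts up by key rather than by position. The function name and the example ID are made up for illustration.

```python
def parse_workspace_id(resource_id: str) -> dict:
    """Extract subscription, resource group and workspace name from an ARM ID."""
    parts = resource_id.strip("/").split("/")
    # Segments alternate key/value: subscriptions/<id>/resourceGroups/<rg>/...
    pairs = {key.lower(): value for key, value in zip(parts[0::2], parts[1::2])}
    return {
        "subscription_id": pairs["subscriptions"],
        "resource_group": pairs["resourcegroups"],
        "workspace_name": pairs["workspaces"],
    }


example_id = (
    "/subscriptions/00000000-0000-0000-0000-000000000000"
    "/resourceGroups/my-rg/providers/Microsoft.MachineLearningServices"
    "/workspaces/my-ws"
)
print(parse_workspace_id(example_id))
print(example_id.split("/")[2])  # the positional lookup used in the walkthrough
```

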
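
Step 3: Prepare the dataset

Prepare the dataset by collecting the images of your subject in one folder. With Dreambooth, the training script pairs every image with the same instance prompt (set later with `--instance_prompt`), so the folder only needs the images. Here is an example structure: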

```
train-data/monu/
    image_1.jpg
    image_2.jpg
    ...
```
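
The Diffusers Dreambooth scripts load every entry in `--instance_data_dir` as an image. A stray non-image file (for example `.DS_Store`) can therefore make training fail. Below is a minimal standard-library check to run before uploading. The accepted extensions are an assumption, so adjust them to your data.

```python
from pathlib import Path

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}  # assumption: adjust as needed


def check_training_folder(folder: str) -> list:
    """Return the image files in `folder`; raise if anything else is found."""
    root = Path(folder)
    if not root.is_dir():
        raise FileNotFoundError(f"Training folder not found: {root}")
    entries = sorted(root.iterdir())
    others = [p.name for p in entries
              if not (p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS)]
    if others:
        raise ValueError(f"Remove non-image entries before uploading: {others}")
    if not entries:
        raise ValueError(f"No images found in {root}")
    return entries


# Example: images = check_training_folder("./train-data/monu/")
```
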

3.1 Upload images to the datastore through an AML data asset (URI folder)

To use the data for training in Azure ML, we upload it to the default Azure Blob Storage of the Azure ML workspace.

```python
# Register dataset
my_data = Data(
    path=local_train_data,
    type=AssetTypes.URI_FOLDER,
    description="Training images for Dreambooth finetuning",
    name=azureml_dataset_name
)
workspace_ml_client.data.create_or_update(my_data)
```

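
Step 4: Create the training environment

We need a `dreambooth-conda.yaml` file to create our custom environment. Save it as `environment/dreambooth-conda.yaml`, the path referenced when the environment is created.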

```yaml
name: dreambooth-flux-env
channels:
  - conda-forge
dependencies:
  - python=3.10
  - pip
  - pip:
      - 'git+https://github.com/huggingface/diffusers.git'
      - azureml-acft-accelerator==0.0.59
      - azureml_acft_common_components==0.0.59
      - azureml-acft-contrib-hf-nlp==0.0.59
      - azureml-evaluate-mlflow==0.0.59
      - azureml-metrics[text]==0.0.59
      - mltable==1.6.1
      - mpi4py==3.1.5
      - sentencepiece==0.1.99
      - transformers==4.44.0
      - datasets==2.17.1
      - optimum==1.17.1
      - accelerate>=0.31.0
      - onnxruntime==1.17.3
      - rouge-score==0.1.2
      - sacrebleu==2.4.0
      - bitsandbytes==0.43.3
      - einops==0.7.0
      - aiohttp==3.10.5
      - peft==0.8.2
      - deepspeed==0.15.0
      - trl==0.8.1
      - tiktoken==0.6.0
      - scipy==1.14.0
```
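
A package can end up listed twice with different constraints, for example `transformers>=4.41.2` and `transformers==4.44.0`. Pip then has to satisfy both, which hides which constraint was intended and fails outright if the two conflict. Here is a small sketch that flags such duplicates in a list of pip specs. The helper is illustrative, and it skips URL-based entries such as the `git+https://...` Diffusers line.

```python
import re


def find_duplicate_packages(requirements):
    """Return package names that appear more than once in a list of pip specs."""
    seen, duplicates = set(), set()
    for spec in requirements:
        if spec.startswith(("git+", "http://", "https://")):
            continue  # URL installs have no simple name to compare
        # Name = everything before the first version, extra, or marker character
        name = re.split(r"[<>=!~\[;\s]", spec, maxsplit=1)[0]
        name = re.sub(r"[-_.]+", "-", name).lower()  # PEP 503-style normalization
        if name in seen:
            duplicates.add(name)
        seen.add(name)
    return sorted(duplicates)


print(find_duplicate_packages([
    "transformers>=4.41.2",
    "azureml_acft_common_components==0.0.59",
    "transformers==4.44.0",
    "accelerate>=0.31.0",
]))
# ['transformers']
```
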

```python
environment = Environment(
    image="mcr.microsoft.com/azureml/curated/acft-hf-nlp-gpu:67",
    conda_file="environment/dreambooth-conda.yaml",
    name=training_env_name,
    description="Dreambooth training environment",
)

workspace_ml_client.environments.create_or_update(environment)
```

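
Step 5: Create the compute

To fine-tune the model in Azure Machine Learning studio, you first need to create a compute resource. Creating the compute takes 3-4 minutes.

For more information, see Azure Machine Learning in a Day.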

```python
try:
    _ = workspace_ml_client.compute.get(train_target)
    print("Found existing compute target.")
except ResourceNotFoundError:
    print("Creating a new compute target...")
    compute_config = AmlCompute(
        name=train_target,
        type="amlcompute",
        size="Standard_NC24ads_A100_v4",  # 1 x A100, 80 GB GPU memory each
        tier="low_priority",  # cheaper, but nodes can be preempted
        idle_time_before_scale_down=600,
        min_instances=0,
        max_instances=2,
    )
    # begin_create_or_update returns a poller; .result() waits for provisioning to finish
    workspace_ml_client.begin_create_or_update(compute_config).result()
```

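
Step 6: Submit the fine-tuning job

We will use the `black-forest-labs/FLUX.1-schnell` model in this notebook. By following this guide, you fine-tune a text-to-image model on Azure with Diffusers and Dreambooth. The resulting model can generate high-quality images of the dog from text descriptions. Feel free to experiment with different prompts and fine-tuning parameters to further explore the model's capabilities.

First, let's create the command line.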

```python
command_str = '''python prepare.py && accelerate launch train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path="black-forest-labs/FLUX.1-schnell" \
  --instance_data_dir=${{inputs.input_data}} \
  --output_dir="outputs/models" \
  --mixed_precision="bf16" \
  --instance_prompt="photo of sks dog" \
  --class_prompt="photo of a dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --learning_rate=1e-5 \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=2500 \
  --seed="0"'''
```
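
Here is how to read the batch-related flags. With `--train_batch_size=1` and `--gradient_accumulation_steps=4`, gradients from 4 images are accumulated before each optimizer update, which gives an effective batch size of 4 on a single GPU. In the Diffusers training scripts, `--max_train_steps` counts optimizer updates. That means 2,500 steps process about 10,000 samples, so each image in a small Dreambooth dataset is seen many times. The arithmetic below assumes one GPU, as on the `Standard_NC24ads_A100_v4` node.

```python
def effective_batch_size(train_batch_size: int, grad_accum_steps: int,
                         num_gpus: int = 1) -> int:
    """Number of samples whose gradients are combined in one optimizer update."""
    return train_batch_size * grad_accum_steps * num_gpus


def total_samples(max_train_steps: int, batch_size: int) -> int:
    """Samples processed over the run when each step is one optimizer update."""
    return max_train_steps * batch_size


# Values from command_str, on a single-GPU node
batch = effective_batch_size(train_batch_size=1, grad_accum_steps=4)
print(batch)                       # 4
print(total_samples(2500, batch))  # 10000
```
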

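As you can see, the command above needs two files: `train_dreambooth_lora_flux.py` and `prepare.py`. You can download `train_dreambooth_lora_flux.py` from the `examples/dreambooth` folder of the official Diffusers repository (huggingface/diffusers on GitHub). The environment installs Diffusers from the main branch, so take the script from the main branch as well to keep the two in sync.

Here is the code for `prepare.py`: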

```python
import os

# Note: this only affects the current Python process. The `accelerate launch`
# command that runs next is a separate process and does not see this value;
# set it in the job's environment_variables if training needs it.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:100"

from accelerate.utils import write_basic_config

# Write a default Accelerate config file for `accelerate launch` to use
write_basic_config()
```
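
Changes to `os.environ` made inside `prepare.py` apply only to that Python process and any processes it starts. The job runs `python prepare.py && accelerate launch ...`, so the training process starts from the shell's environment and does not inherit the `PYTORCH_CUDA_ALLOC_CONF` value set in `prepare.py`. The `write_basic_config()` call still takes effect because it writes a config file to disk. The standard-library demo below shows the difference, using a hypothetical variable name. To make a variable reach the training process, pass it through the job's `environment_variables`.

```python
import os
import subprocess
import sys

VAR = "DEMO_ALLOC_CONF"  # hypothetical name so the demo cannot clash with a real setting


def child_change_visible_in_parent() -> bool:
    """Set VAR in a child Python process, then check whether this process sees it."""
    os.environ.pop(VAR, None)  # start from a clean state
    subprocess.run(
        [sys.executable, "-c", f"import os; os.environ[{VAR!r}] = 'max_split_size_mb:100'"],
        check=True,
    )
    return VAR in os.environ


def parent_value_seen_by_child() -> str:
    """Pass VAR from this process to a child explicitly and read it back."""
    env = dict(os.environ, **{VAR: "max_split_size_mb:100"})
    result = subprocess.run(
        [sys.executable, "-c", f"import os; print(os.environ[{VAR!r}])"],
        env=env, capture_output=True, text=True, check=True,
    )
    return result.stdout.strip()


print(child_change_visible_in_parent())  # False: child changes stay in the child
print(parent_value_seen_by_child())      # max_split_size_mb:100
```
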

Your folder structure should look like this:

```
src/
    prepare.py
    train_dreambooth_lora_flux.py
```

Now let's initialize some variables.

```python
# Retrieve latest version of dataset
latest_version = [
    dataset.latest_version
    for dataset in workspace_ml_client.data.list()
    if dataset.name == azureml_dataset_name
][0]
dataset_asset = workspace_ml_client.data.get(name=azureml_dataset_name, version=latest_version)
print(f'Latest version of {azureml_dataset_name}: {latest_version}')
inputs = {"input_data": Input(type=AssetTypes.URI_FOLDER, path=f'azureml:{azureml_dataset_name}:{latest_version}')}
outputs = {"output_dir": Output(type=AssetTypes.URI_FOLDER)}
```

Next, we submit a job that combines the code above with the compute target and the environment we created.

```python
job = command(
    inputs=inputs,
    outputs=outputs,
    code="./src",
    command=command_str,
    environment=f"{training_env_name}@latest",  # "@latest" resolves to the newest registered version
    compute=train_target,
    experiment_name=experiment_name,
    display_name="flux-finetune-batchsize-1",
    environment_variables={
        'HF_TOKEN': 'Place Your HF Token Here',
        # Set here (not in prepare.py) so that the training process actually sees it
        'PYTORCH_CUDA_ALLOC_CONF': 'max_split_size_mb:100',
    },
)

returned_job = workspace_ml_client.jobs.create_or_update(job)
returned_job

# Optional: block until the job finishes, so that Step 7 can download its outputs
# workspace_ml_client.jobs.stream(returned_job.name)
```

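
Step 7: Download and register the fine-tuned model

After fine-tuning, evaluate the model to make sure it meets your requirements.

We will register the model from the output of the fine-tuning job. This tracks the lineage between the fine-tuned model and the fine-tuning job. The fine-tuning job, in turn, tracks the lineage of the base model, data, and training code. First, we download the job's output artifacts with MLflow; the job must have completed by then.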

```python
# Obtain the tracking URL from MLClient
MLFLOW_TRACKING_URI = workspace_ml_client.workspaces.get(name=workspace_ml_client.workspace_name).mlflow_tracking_uri

# Set the MLFLOW TRACKING URI
mlflow.set_tracking_uri(MLFLOW_TRACKING_URI)

# Initialize MLFlow client
mlflow_client = MlflowClient()
mlflow_run = mlflow_client.get_run(returned_job.name)
mlflow.artifacts.download_artifacts(run_id=mlflow_run.info.run_id,
                                    artifact_path="outputs/models/",  # Azure ML job output
                                    dst_path="./train-artifacts")  # local folder
```

Now let's make sure the LoRA weights file itself (`pytorch_lora_weights.safetensors`) is available locally:

```python
lora_path = "./train-artifacts/outputs/models/pytorch_lora_weights.safetensors"

# If an earlier download left a directory at this path, remove it first
if os.path.isdir(lora_path):
    shutil.rmtree(lora_path)

# Download only the LoRA weights file if it is not already present locally
if not os.path.isfile(lora_path):
    mlflow.artifacts.download_artifacts(run_id=mlflow_run.info.run_id,
                                        artifact_path="outputs/models/pytorch_lora_weights.safetensors",  # Azure ML job output
                                        dst_path="./train-artifacts")  # local folder
```
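
To sanity-check the download without loading the model, you can read the file's header. A `.safetensors` file starts with an 8-byte little-endian length, followed by a JSON header that lists each tensor's name, dtype, and shape. Below is a minimal standard-library sketch. The helper names are illustrative, and the `safetensors` package is the usual tool for actually loading the tensors.

```python
import json
import struct


def read_safetensors_header(path: str) -> dict:
    """Return the tensor index of a .safetensors file without loading tensor data.

    Layout: an 8-byte little-endian unsigned integer N, then N bytes of UTF-8 JSON
    mapping tensor names to {"dtype", "shape", "data_offsets"}.
    """
    with open(path, "rb") as f:
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len).decode("utf-8"))
    header.pop("__metadata__", None)  # optional metadata entry, not a tensor
    return header


def summarize(path: str, limit: int = 5) -> None:
    """Print the tensor count and the first few tensor names, dtypes and shapes."""
    header = read_safetensors_header(path)
    print(f"{len(header)} tensors")
    for name, info in sorted(header.items())[:limit]:
        print(name, info["dtype"], info["shape"])


# Example: summarize("./train-artifacts/outputs/models/pytorch_lora_weights.safetensors")
```
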

Finally, let's register the model.

```python
from azure.ai.ml.entities import Model
from azure.ai.ml.constants import AssetTypes

run_model = Model(
    path=f"azureml://jobs/{returned_job.name}/outputs/artifacts/paths/outputs/models/pytorch_lora_weights.safetensors",
    name="mano-dreambooth-flux-finetuned",
    description="Model created from run.",
    type=AssetTypes.CUSTOM_MODEL,
)
model = workspace_ml_client.models.create_or_update(run_model)
```

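
The `path` above follows the pattern `azureml://jobs/<job-name>/outputs/artifacts/paths/<relative-path>`, where `<relative-path>` is relative to the job's artifact root. Here that is the `outputs/models` folder the training command writes to. Below is a tiny helper that builds such URIs; the function name and job name are illustrative.

```python
def job_artifact_uri(job_name: str, relative_path: str) -> str:
    """Build an azureml:// URI for a file or folder in a job's default artifacts."""
    return (f"azureml://jobs/{job_name}/outputs/artifacts/paths/"
            f"{relative_path.lstrip('/')}")


print(job_artifact_uri("example-job-name",
                       "outputs/models/pytorch_lora_weights.safetensors"))
# azureml://jobs/example-job-name/outputs/artifacts/paths/outputs/models/pytorch_lora_weights.safetensors
```

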
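
Step 8: Deploy to a managed online endpoint

Now let's deploy this fine-tuned model as a managed online endpoint on AML. First, let's define some constants that we will use later during deployment.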

```python
endpoint_name = 'flux-endpoint-finetuned-a100'  # must be unique within the Azure region
deployment_name = 'flux'
instance_type = 'Standard_NC24ads_A100_v4'
score_file = 'score.py'
```

Let's create a managed online endpoint.

```python
from azure.ai.ml.entities import ManagedOnlineEndpoint

# create an online endpoint
endpoint = ManagedOnlineEndpoint(
    name=endpoint_name,
    description="this is the flux inference online endpoint",
    auth_mode="key"
)
# Wait for the endpoint to be provisioned before creating a deployment
workspace_ml_client.online_endpoints.begin_create_or_update(endpoint).result()
```

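
With `auth_mode="key"`, clients authenticate by sending one of the endpoint's keys as a bearer token in the `Authorization` header. Once a deployment is serving traffic, you can read the scoring URI with `workspace_ml_client.online_endpoints.get(endpoint_name).scoring_uri` and a key with `workspace_ml_client.online_endpoints.get_keys(endpoint_name).primary_key`. The standard-library sketch below prepares such a request without sending it. The URL, key, and `{"prompt": ...}` payload are hypothetical, and the payload must match whatever `score.py` expects.

```python
import json
import urllib.request


def build_scoring_request(scoring_uri: str, api_key: str,
                          payload: dict) -> urllib.request.Request:
    """Prepare (but do not send) a JSON POST request to a key-auth online endpoint."""
    return urllib.request.Request(
        scoring_uri,
        data=json.dumps(payload).encode("utf-8"),
        method="POST",
        headers={
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",
        },
    )


# Hypothetical values for illustration only
req = build_scoring_request(
    "https://example-endpoint.example-region.inference.ml.azure.com/score",
    "YOUR_ENDPOINT_KEY",
    {"prompt": "photo of sks dog on a beach"},
)
print(req.get_method(), req.full_url)
# To send it: response = urllib.request.urlopen(req); print(response.read())
```

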

Step 9: Create the inference environment for the online endpoint

First, we create a Dockerfile to use when building the environment.

language-bash 复制代码
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-bash">FROM mcr.microsoft.com/aifx/acpt/stable-ubuntu2004-cu121-py310-torch22x:biweekly.202408.3

<span style="color:#d4d0ab"># Install pip dependencies</span>
COPY requirements.txt <span style="color:#abe338">.</span>
RUN pip <span style="color:#ffd700">install</span> -r requirements.txt --no-cache-dir

<span style="color:#d4d0ab"># Inference requirements</span>
COPY --from<span style="color:#00e0e0">=</span>mcr.microsoft.com/azureml/o16n-base/python-assets:20230419.v1 /artifacts /var/
RUN /var/requirements/install_system_requirements.sh <span style="color:#00e0e0">&&</span> <span style="color:#fefefe">\</span>
    <span style="color:#ffd700">cp</span> /var/configuration/rsyslog.conf /etc/rsyslog.conf <span style="color:#00e0e0">&&</span> <span style="color:#fefefe">\</span>
    <span style="color:#ffd700">cp</span> /var/configuration/nginx.conf /etc/nginx/sites-available/app <span style="color:#00e0e0">&&</span> <span style="color:#fefefe">\</span>
    <span style="color:#ffd700">ln</span> -sf /etc/nginx/sites-available/app /etc/nginx/sites-enabled/app <span style="color:#00e0e0">&&</span> <span style="color:#fefefe">\</span>
    <span style="color:#ffd700">rm</span> -f /etc/nginx/sites-enabled/default
ENV <span style="color:#00e0e0">SVDIR</span><span style="color:#00e0e0">=</span>/var/runit
ENV <span style="color:#00e0e0">WORKER_TIMEOUT</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">400</span>
EXPOSE <span style="color:#00e0e0">5001</span> <span style="color:#00e0e0">8883</span> <span style="color:#00e0e0">8888</span>   

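<span style="color:#d4d0ab"># support Deepspeed launcher requirement of passwordless ssh login</span>
<span style="color:#d4d0ab"># (update and install in one layer so the package index is never stale, then clean the apt cache)</span>
RUN <span style="color:#ffd700">apt-get</span> update <span style="color:#00e0e0">&&</span> <span style="color:#fefefe">\</span>
    <span style="color:#ffd700">apt-get</span> <span style="color:#ffd700">install</span> -y openssh-server openssh-client <span style="color:#00e0e0">&&</span> <span style="color:#fefefe">\</span>
    <span style="color:#ffd700">rm</span> -rf /var/lib/apt/lists/*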
</code></span></span>

The requirements.txt file installed by this Dockerfile is as follows:

<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-applescript">azureml<span style="color:#00e0e0">-</span>core<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
azureml<span style="color:#00e0e0">-</span>dataset<span style="color:#00e0e0">-</span>runtime<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
azureml<span style="color:#00e0e0">-</span>defaults<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
azure<span style="color:#00e0e0">-</span>ml<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.0</span><span style="color:#00e0e0">.1</span>
azure<span style="color:#00e0e0">-</span>ml<span style="color:#00e0e0">-</span>component<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.9</span><span style="color:#00e0e0">.18</span>.post2
azureml<span style="color:#00e0e0">-</span>mlflow<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
azureml<span style="color:#00e0e0">-</span>contrib<span style="color:#00e0e0">-</span>services<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.57</span><span style="color:#00e0e0">.0</span>
torch<span style="color:#00e0e0">-</span>tb<span style="color:#00e0e0">-</span>profiler~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.4</span><span style="color:#00e0e0">.0</span>
azureml<span style="color:#00e0e0">-</span>inference<span style="color:#00e0e0">-</span>server<span style="color:#00e0e0">-</span>http
inference<span style="color:#00e0e0">-</span>schema
MarkupSafe<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">2.1</span><span style="color:#00e0e0">.2</span>
regex
pybind11
urllib3<span style="color:#00e0e0">>=</span><span style="color:#00e0e0">1.26</span><span style="color:#00e0e0">.18</span>
cryptography<span style="color:#00e0e0">>=</span><span style="color:#00e0e0">42.0</span><span style="color:#00e0e0">.4</span>
aiohttp<span style="color:#00e0e0">>=</span><span style="color:#00e0e0">3.8</span><span style="color:#00e0e0">.5</span>
py<span style="color:#00e0e0">-</span>spy<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.3</span><span style="color:#00e0e0">.12</span>
debugpy~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1.6</span><span style="color:#00e0e0">.3</span>
ipykernel~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">6.0</span>
tensorboard
psutil~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">5.8</span><span style="color:#00e0e0">.0</span>
matplotlib~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">3.5</span><span style="color:#00e0e0">.0</span>
tqdm~<span style="color:#00e0e0">=</span><span style="color:#00e0e0">4.66</span><span style="color:#00e0e0">.3</span>
py<span style="color:#00e0e0">-</span>cpuinfo<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">5.0</span><span style="color:#00e0e0">.0</span>
transformers<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">4.44</span><span style="color:#00e0e0">.2</span>
diffusers<span style="color:#00e0e0">=</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.30</span><span style="color:#00e0e0">.1</span>
accelerate<span style="color:#00e0e0">>=</span><span style="color:#00e0e0">0.31</span><span style="color:#00e0e0">.0</span>
sentencepiece
peft
bitsandbytes</code></span></span>
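In a long pinned list like this one, repeated entries are easy to miss. If a package is listed twice with conflicting pins, `pip install` fails to resolve. The hypothetical `find_duplicates` helper below is a small stdlib-only sketch for flagging repeated package names before the image is built. It normalizes names as PEP 503 describes (lowercase, with runs of `-`, `_` and `.` collapsed to `-`). It only looks at the leading package name of each line: comments and option lines such as `-r` or `-e` are skipped, and direct URL requirements are not handled.

```python
import re

# Leading distribution name of a simple "name<op>version" requirement line.
_NAME_RE = re.compile(r"^\s*([A-Za-z0-9][A-Za-z0-9._-]*)")


def normalize(name: str) -> str:
    """Normalize a distribution name as described in PEP 503."""
    return re.sub(r"[-_.]+", "-", name).lower()


def find_duplicates(lines):
    """Return {normalized_name: [1-based line numbers]} for names listed more than once."""
    seen = {}
    for lineno, raw in enumerate(lines, start=1):
        line = raw.split("#", 1)[0].strip()  # drop comments and surrounding whitespace
        if not line or line.startswith("-"):  # blank lines and options like -r / -e
            continue
        match = _NAME_RE.match(line)
        if match:
            seen.setdefault(normalize(match.group(1)), []).append(lineno)
    return {name: nums for name, nums in seen.items() if len(nums) > 1}


if __name__ == "__main__":
    with open("requirements.txt", encoding="utf-8") as handle:
        duplicates = find_duplicates(handle.read().splitlines())
    print(duplicates or "no duplicate package names")
```

Run it next to the requirements.txt file. An empty result means every package name appears only once.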

Make sure the folder structure matches the following layout:

`docker-contexts/python-and-pip/
    Dockerfile
    requirements.txt`
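`BuildContext(path=...)` points Azure Machine Learning at a local folder that is uploaded and built as a Docker image. By default, it expects a file named `Dockerfile` in that folder. A quick local check before calling `create_or_update` can catch a misplaced or misnamed file. The hypothetical `check_docker_context` helper below is a sketch that assumes the two-file layout described here and the `COPY requirements.txt` line from this step's Dockerfile.

```python
from pathlib import Path

# Files this walkthrough expects in the Docker build context.
REQUIRED_FILES = ("Dockerfile", "requirements.txt")


def check_docker_context(context_dir):
    """Return a list of problems found in a Docker build context folder (empty if it looks OK)."""
    root = Path(context_dir)
    if not root.is_dir():
        return [f"{root} is not a directory"]
    problems = [f"missing {name}" for name in REQUIRED_FILES if not (root / name).is_file()]
    dockerfile = root / "Dockerfile"
    # The Dockerfile in this step installs dependencies from requirements.txt.
    if dockerfile.is_file() and "COPY requirements.txt" not in dockerfile.read_text():
        problems.append("Dockerfile does not COPY requirements.txt")
    return problems


if __name__ == "__main__":
    issues = check_docker_context("docker-contexts/python-and-pip")
    print("OK" if not issues else "\n".join(issues))
```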

Finally, run the code below to create the inference environment for the Flux LoRA model:

<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-applescript"><span style="color:#00e0e0">from</span> azure<span style="color:#fefefe">.</span>ai<span style="color:#fefefe">.</span>ml<span style="color:#fefefe">.</span>entities <span style="color:#00e0e0">import</span> BuildContext<span style="color:#fefefe">,</span> Environment

env_docker_context <span style="color:#00e0e0">=</span> Environment<span style="color:#fefefe">(</span>
    build<span style="color:#00e0e0">=</span>BuildContext<span style="color:#fefefe">(</span>path<span style="color:#00e0e0">=</span><span style="color:#abe338">"docker-contexts/python-and-pip"</span><span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
    name<span style="color:#00e0e0">=</span>inference_env_name<span style="color:#fefefe">,</span>
    description<span style="color:#00e0e0">=</span><span style="color:#abe338">"Environment created from a Docker context."</span><span style="color:#fefefe">,</span>
<span style="color:#fefefe">)</span>
<span style="color:#d4d0ab"># Use the same MLClient instance as the other steps</span>
workspace_ml_client<span style="color:#fefefe">.</span>environments<span style="color:#fefefe">.</span>create_or_update<span style="color:#fefefe">(</span>env_docker_context<span style="color:#fefefe">)</span></code></span></span>
Step 10: Create a deployment for the managed online endpoint

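Now we deploy the model to the endpoint we created. The deployment needs a scoring script, so create a file named score.py and place it in a folder named assets. Azure Machine Learning calls the script's init() function once when the container starts and its run() function for every scoring request.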

`assets/
    score.py
`
<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python"><span style="color:#00e0e0">import</span> torch
<span style="color:#00e0e0">import</span> io
<span style="color:#00e0e0">import</span> os
<span style="color:#00e0e0">import</span> logging
<span style="color:#00e0e0">import</span> json
<span style="color:#00e0e0">import</span> math
<span style="color:#00e0e0">import</span> numpy <span style="color:#00e0e0">as</span> np
<span style="color:#00e0e0">from</span> base64 <span style="color:#00e0e0">import</span> b64encode
<span style="color:#00e0e0">import</span> requests
<span style="color:#00e0e0">from</span> PIL <span style="color:#00e0e0">import</span> Image<span style="color:#fefefe">,</span> ImageDraw
<span style="color:#00e0e0">from</span> safetensors<span style="color:#fefefe">.</span>torch <span style="color:#00e0e0">import</span> load_file
<span style="color:#00e0e0">from</span> azureml<span style="color:#fefefe">.</span>contrib<span style="color:#fefefe">.</span>services<span style="color:#fefefe">.</span>aml_response <span style="color:#00e0e0">import</span> AMLResponse

<span style="color:#00e0e0">from</span> transformers <span style="color:#00e0e0">import</span> pipeline
<span style="color:#00e0e0">from</span> diffusers <span style="color:#00e0e0">import</span> DiffusionPipeline<span style="color:#fefefe">,</span> StableDiffusionXLImg2ImgPipeline
<span style="color:#00e0e0">from</span> diffusers <span style="color:#00e0e0">import</span> AutoPipelineForText2Image<span style="color:#fefefe">,</span> FluxPipeline
<span style="color:#00e0e0">from</span> diffusers<span style="color:#fefefe">.</span>schedulers <span style="color:#00e0e0">import</span> EulerAncestralDiscreteScheduler
<span style="color:#00e0e0">from</span> diffusers <span style="color:#00e0e0">import</span> DPMSolverMultistepScheduler

device <span style="color:#00e0e0">=</span> torch<span style="color:#fefefe">.</span>device<span style="color:#fefefe">(</span><span style="color:#abe338">"cuda"</span> <span style="color:#00e0e0">if</span> torch<span style="color:#fefefe">.</span>cuda<span style="color:#fefefe">.</span>is_available<span style="color:#fefefe">(</span><span style="color:#fefefe">)</span> <span style="color:#00e0e0">else</span> <span style="color:#abe338">"cpu"</span><span style="color:#fefefe">)</span>
<span style="color:#00e0e0">def</span> <span style="color:#ffd700">init</span><span style="color:#fefefe">(</span><span style="color:#fefefe">)</span><span style="color:#fefefe">:</span>
    <span style="color:#abe338">"""
    This function is called when the container is initialized/started, typically after create/update of the deployment.
    You can write the logic here to perform init operations like caching the model in memory
    """</span>
    <span style="color:#00e0e0">global</span> pipe<span style="color:#fefefe">,</span> refiner
    weights_path <span style="color:#00e0e0">=</span> os<span style="color:#fefefe">.</span>path<span style="color:#fefefe">.</span>join<span style="color:#fefefe">(</span>
        os<span style="color:#fefefe">.</span>getenv<span style="color:#fefefe">(</span><span style="color:#abe338">"AZUREML_MODEL_DIR"</span><span style="color:#fefefe">)</span><span style="color:#fefefe">,</span> <span style="color:#abe338">"pytorch_lora_weights.safetensors"</span>
    <span style="color:#fefefe">)</span>
    <span style="color:#00e0e0">print</span><span style="color:#fefefe">(</span><span style="color:#abe338">"weights_path:"</span><span style="color:#fefefe">,</span> weights_path<span style="color:#fefefe">)</span>
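    pipe <span style="color:#00e0e0">=</span> FluxPipeline<span style="color:#fefefe">.</span>from_pretrained<span style="color:#fefefe">(</span><span style="color:#abe338">"black-forest-labs/FLUX.1-dev"</span><span style="color:#fefefe">,</span> torch_dtype<span style="color:#00e0e0">=</span>torch<span style="color:#fefefe">.</span>bfloat16<span style="color:#fefefe">)</span>
    <span style="color:#d4d0ab"># Attach the fine-tuned LoRA weights before configuring device placement</span>
    pipe<span style="color:#fefefe">.</span>load_lora_weights<span style="color:#fefefe">(</span>weights_path<span style="color:#fefefe">,</span> use_safetensors<span style="color:#00e0e0">=</span><span style="color:#00e0e0">True</span><span style="color:#fefefe">)</span>
    <span style="color:#d4d0ab"># Model CPU offload moves each sub-model to the GPU only while it runs; calling</span>
    <span style="color:#d4d0ab"># pipe.to("cuda") afterwards would undo its memory savings, so use one or the other.</span>
    <span style="color:#00e0e0">if</span> device<span style="color:#fefefe">.</span><span style="color:#abe338">type</span> <span style="color:#00e0e0">==</span> <span style="color:#abe338">"cuda"</span><span style="color:#fefefe">:</span>
        pipe<span style="color:#fefefe">.</span>enable_model_cpu_offload<span style="color:#fefefe">(</span><span style="color:#fefefe">)</span>
    <span style="color:#00e0e0">else</span><span style="color:#fefefe">:</span>
        pipe<span style="color:#fefefe">.</span>to<span style="color:#fefefe">(</span>device<span style="color:#fefefe">)</span>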
    <span style="color:#d4d0ab"># refiner = StableDiffusionXLImg2ImgPipeline.from_pretrained(</span>
    <span style="color:#d4d0ab">#                         "stabilityai/stable-diffusion-xl-refiner-1.0", </span>
    <span style="color:#d4d0ab">#                         torch_dtype=torch.float16, </span>
    <span style="color:#d4d0ab">#                         use_safetensors=True, </span>
    <span style="color:#d4d0ab">#                         variant="fp16"</span>
    <span style="color:#d4d0ab">#                     )</span>
    <span style="color:#d4d0ab"># refiner.to(device)</span>
    logging<span style="color:#fefefe">.</span>info<span style="color:#fefefe">(</span><span style="color:#abe338">"Init complete"</span><span style="color:#fefefe">)</span>


<span style="color:#00e0e0">def</span> <span style="color:#ffd700">get_image_object</span><span style="color:#fefefe">(</span>image_url<span style="color:#fefefe">)</span><span style="color:#fefefe">:</span>
    <span style="color:#abe338">"""
    This function takes an image URL and returns an Image object.
    """</span>
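    response <span style="color:#00e0e0">=</span> requests<span style="color:#fefefe">.</span>get<span style="color:#fefefe">(</span>image_url<span style="color:#fefefe">,</span> timeout<span style="color:#00e0e0">=</span><span style="color:#00e0e0">30</span><span style="color:#fefefe">)</span>
    response<span style="color:#fefefe">.</span>raise_for_status<span style="color:#fefefe">(</span><span style="color:#fefefe">)</span>
    <span style="color:#d4d0ab"># Open the downloaded bytes as an image first, then convert it to RGB</span>
    init_image <span style="color:#00e0e0">=</span> Image<span style="color:#fefefe">.</span><span style="color:#abe338">open</span><span style="color:#fefefe">(</span>io<span style="color:#fefefe">.</span>BytesIO<span style="color:#fefefe">(</span>response<span style="color:#fefefe">.</span>content<span style="color:#fefefe">)</span><span style="color:#fefefe">)</span><span style="color:#fefefe">.</span>convert<span style="color:#fefefe">(</span><span style="color:#abe338">"RGB"</span><span style="color:#fefefe">)</span>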
    <span style="color:#00e0e0">return</span> init_image

<span style="color:#00e0e0">def</span> <span style="color:#ffd700">prepare_response</span><span style="color:#fefefe">(</span>images<span style="color:#fefefe">)</span><span style="color:#fefefe">:</span>
    <span style="color:#abe338">"""
    This function takes a list of images and converts them to a dictionary of base64 encoded strings.
    """</span>
    ENCODING <span style="color:#00e0e0">=</span> <span style="color:#abe338">'utf-8'</span>
    dic_response <span style="color:#00e0e0">=</span> <span style="color:#fefefe">{</span><span style="color:#fefefe">}</span>
    <span style="color:#00e0e0">for</span> i<span style="color:#fefefe">,</span> image <span style="color:#00e0e0">in</span> <span style="color:#abe338">enumerate</span><span style="color:#fefefe">(</span>images<span style="color:#fefefe">)</span><span style="color:#fefefe">:</span>
        output <span style="color:#00e0e0">=</span> io<span style="color:#fefefe">.</span>BytesIO<span style="color:#fefefe">(</span><span style="color:#fefefe">)</span>
        image<span style="color:#fefefe">.</span>save<span style="color:#fefefe">(</span>output<span style="color:#fefefe">,</span> <span style="color:#abe338">format</span><span style="color:#00e0e0">=</span><span style="color:#abe338">"JPEG"</span><span style="color:#fefefe">)</span>
        base64_bytes <span style="color:#00e0e0">=</span> b64encode<span style="color:#fefefe">(</span>output<span style="color:#fefefe">.</span>getvalue<span style="color:#fefefe">(</span><span style="color:#fefefe">)</span><span style="color:#fefefe">)</span>
        base64_string <span style="color:#00e0e0">=</span> base64_bytes<span style="color:#fefefe">.</span>decode<span style="color:#fefefe">(</span>ENCODING<span style="color:#fefefe">)</span>
        dic_response<span style="color:#fefefe">[</span><span style="color:#abe338">f'image_</span><span style="color:#fefefe">{</span>i<span style="color:#fefefe">}</span><span style="color:#abe338">'</span><span style="color:#fefefe">]</span> <span style="color:#00e0e0">=</span> base64_string
    <span style="color:#00e0e0">return</span> dic_response

<span style="color:#00e0e0">def</span> <span style="color:#ffd700">design</span><span style="color:#fefefe">(</span>prompt<span style="color:#fefefe">,</span> image<span style="color:#00e0e0">=</span><span style="color:#00e0e0">None</span><span style="color:#fefefe">,</span> num_images_per_prompt<span style="color:#00e0e0">=</span><span style="color:#00e0e0">4</span><span style="color:#fefefe">,</span> negative_prompt<span style="color:#00e0e0">=</span><span style="color:#00e0e0">None</span><span style="color:#fefefe">,</span> strength<span style="color:#00e0e0">=</span><span style="color:#00e0e0">0.65</span><span style="color:#fefefe">,</span> guidance_scale<span style="color:#00e0e0">=</span><span style="color:#00e0e0">7.5</span><span style="color:#fefefe">,</span> num_inference_steps<span style="color:#00e0e0">=</span><span style="color:#00e0e0">50</span><span style="color:#fefefe">,</span> seed<span style="color:#00e0e0">=</span><span style="color:#00e0e0">None</span><span style="color:#fefefe">,</span> design_type<span style="color:#00e0e0">=</span><span style="color:#abe338">'TXT_TO_IMG'</span><span style="color:#fefefe">,</span> mask<span style="color:#00e0e0">=</span><span style="color:#00e0e0">None</span><span style="color:#fefefe">,</span> other_args<span style="color:#00e0e0">=</span><span style="color:#00e0e0">None</span><span style="color:#fefefe">)</span><span style="color:#fefefe">:</span>
    <span style="color:#abe338">"""
    This function takes various parameters like prompt, image, seed, design_type, etc., and generates images based on the specified design type. It returns a list of generated images.
    """</span>
    <span style="color:#d4d0ab"># Dedicated CPU generator: the same seed reproduces the same images (seed defaults to 0)</span>
    generator <span style="color:#00e0e0">=</span> torch<span style="color:#fefefe">.</span>Generator<span style="color:#fefefe">(</span><span style="color:#abe338">"cpu"</span><span style="color:#fefefe">)</span><span style="color:#fefefe">.</span>manual_seed<span style="color:#fefefe">(</span>seed <span style="color:#00e0e0">if</span> seed <span style="color:#00e0e0">is</span> <span style="color:#00e0e0">not</span> <span style="color:#00e0e0">None</span> <span style="color:#00e0e0">else</span> <span style="color:#00e0e0">0</span><span style="color:#fefefe">)</span>

    <span style="color:#00e0e0">print</span><span style="color:#fefefe">(</span><span style="color:#abe338">'other_args'</span><span style="color:#fefefe">,</span> other_args<span style="color:#fefefe">)</span>
    <span style="color:#d4d0ab"># FluxPipeline in diffusers 0.30.x takes no negative_prompt or strength argument;</span>
    <span style="color:#d4d0ab"># design() accepts them only to keep the request interface unchanged.</span>
    <span style="color:#d4d0ab"># output_type="pil" returns PIL images, which prepare_response() can save as JPEG.</span>
    images <span style="color:#00e0e0">=</span> pipe<span style="color:#fefefe">(</span>prompt<span style="color:#00e0e0">=</span>prompt<span style="color:#fefefe">,</span>
                  height<span style="color:#00e0e0">=</span><span style="color:#00e0e0">512</span><span style="color:#fefefe">,</span>
                  width<span style="color:#00e0e0">=</span><span style="color:#00e0e0">768</span><span style="color:#fefefe">,</span>
                  guidance_scale<span style="color:#00e0e0">=</span>guidance_scale<span style="color:#fefefe">,</span>
                  num_inference_steps<span style="color:#00e0e0">=</span>num_inference_steps<span style="color:#fefefe">,</span>
                  num_images_per_prompt<span style="color:#00e0e0">=</span>num_images_per_prompt<span style="color:#fefefe">,</span>
                  output_type<span style="color:#00e0e0">=</span><span style="color:#abe338">"pil"</span><span style="color:#fefefe">,</span>
                  generator<span style="color:#00e0e0">=</span>generator<span style="color:#fefefe">)</span><span style="color:#fefefe">.</span>images
    <span style="color:#d4d0ab">#image = refiner(prompt=prompt, image=image[None, :], generator=generator).images[0]    </span>
    <span style="color:#00e0e0">return</span> images


<span style="color:#00e0e0">def</span> <span style="color:#ffd700">run</span><span style="color:#fefefe">(</span>raw_data<span style="color:#fefefe">)</span><span style="color:#fefefe">:</span>
    <span style="color:#abe338">"""
     This function takes raw data as input, processes it, and calls the design function to generate images.
     It then prepares the response and returns it.
    """</span>
    logging<span style="color:#fefefe">.</span>info<span style="color:#fefefe">(</span><span style="color:#abe338">"Request received"</span><span style="color:#fefefe">)</span>
    <span style="color:#00e0e0">print</span><span style="color:#fefefe">(</span><span style="color:#abe338">f'raw data: </span><span style="color:#fefefe">{</span>raw_data<span style="color:#fefefe">}</span><span style="color:#abe338">'</span><span style="color:#fefefe">)</span>
    data <span style="color:#00e0e0">=</span> json<span style="color:#fefefe">.</span>loads<span style="color:#fefefe">(</span>raw_data<span style="color:#fefefe">)</span><span style="color:#fefefe">[</span><span style="color:#abe338">"data"</span><span style="color:#fefefe">]</span>
    <span style="color:#00e0e0">print</span><span style="color:#fefefe">(</span><span style="color:#abe338">f'data: </span><span style="color:#fefefe">{</span>data<span style="color:#fefefe">}</span><span style="color:#abe338">'</span><span style="color:#fefefe">)</span>

    <span style="color:#d4d0ab"># 'prompt' is required; optional keys fall back to the defaults of design()</span>
    prompt <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">[</span><span style="color:#abe338">'prompt'</span><span style="color:#fefefe">]</span>
    negative_prompt <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">.</span>get<span style="color:#fefefe">(</span><span style="color:#abe338">'negative_prompt'</span><span style="color:#fefefe">)</span>
    seed <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">.</span>get<span style="color:#fefefe">(</span><span style="color:#abe338">'seed'</span><span style="color:#fefefe">)</span>
    num_images_per_prompt <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">.</span>get<span style="color:#fefefe">(</span><span style="color:#abe338">'num_images_per_prompt'</span><span style="color:#fefefe">,</span> <span style="color:#00e0e0">4</span><span style="color:#fefefe">)</span>
    guidance_scale <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">.</span>get<span style="color:#fefefe">(</span><span style="color:#abe338">'guidance_scale'</span><span style="color:#fefefe">,</span> <span style="color:#00e0e0">7.5</span><span style="color:#fefefe">)</span>
    num_inference_steps <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">.</span>get<span style="color:#fefefe">(</span><span style="color:#abe338">'num_inference_steps'</span><span style="color:#fefefe">,</span> <span style="color:#00e0e0">50</span><span style="color:#fefefe">)</span>
    design_type <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">.</span>get<span style="color:#fefefe">(</span><span style="color:#abe338">'design_type'</span><span style="color:#fefefe">,</span> <span style="color:#abe338">'TXT_TO_IMG'</span><span style="color:#fefefe">)</span>

    image_url <span style="color:#00e0e0">=</span> <span style="color:#00e0e0">None</span>
    mask_url <span style="color:#00e0e0">=</span> <span style="color:#00e0e0">None</span>
    mask <span style="color:#00e0e0">=</span> <span style="color:#00e0e0">None</span>
    other_args <span style="color:#00e0e0">=</span> <span style="color:#00e0e0">None</span>
    image <span style="color:#00e0e0">=</span> <span style="color:#00e0e0">None</span>
    strength <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">.</span>get<span style="color:#fefefe">(</span><span style="color:#abe338">'strength'</span><span style="color:#fefefe">,</span> <span style="color:#00e0e0">0.65</span><span style="color:#fefefe">)</span>

    <span style="color:#00e0e0">if</span> <span style="color:#abe338">'mask_image'</span> <span style="color:#00e0e0">in</span> data<span style="color:#fefefe">:</span>
        mask_url <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">[</span><span style="color:#abe338">'mask_image'</span><span style="color:#fefefe">]</span>
        mask <span style="color:#00e0e0">=</span> get_image_object<span style="color:#fefefe">(</span>mask_url<span style="color:#fefefe">)</span>

    <span style="color:#00e0e0">if</span> <span style="color:#abe338">'other_args'</span> <span style="color:#00e0e0">in</span> data<span style="color:#fefefe">:</span>
        other_args <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">[</span><span style="color:#abe338">'other_args'</span><span style="color:#fefefe">]</span>


    <span style="color:#00e0e0">if</span> <span style="color:#abe338">'image_url'</span> <span style="color:#00e0e0">in</span> data<span style="color:#fefefe">:</span>
        image_url <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">[</span><span style="color:#abe338">'image_url'</span><span style="color:#fefefe">]</span>
        image <span style="color:#00e0e0">=</span> get_image_object<span style="color:#fefefe">(</span>image_url<span style="color:#fefefe">)</span>

    <span style="color:#00e0e0">if</span> <span style="color:#abe338">'strength'</span> <span style="color:#00e0e0">in</span> data<span style="color:#fefefe">:</span>
        strength <span style="color:#00e0e0">=</span> data<span style="color:#fefefe">[</span><span style="color:#abe338">'strength'</span><span style="color:#fefefe">]</span>

    <span style="color:#00e0e0">with</span> torch<span style="color:#fefefe">.</span>inference_mode<span style="color:#fefefe">(</span><span style="color:#fefefe">)</span><span style="color:#fefefe">:</span>
        images <span style="color:#00e0e0">=</span> design<span style="color:#fefefe">(</span>prompt<span style="color:#00e0e0">=</span>prompt<span style="color:#fefefe">,</span> image<span style="color:#00e0e0">=</span>image<span style="color:#fefefe">,</span> 
                        num_images_per_prompt<span style="color:#00e0e0">=</span>num_images_per_prompt<span style="color:#fefefe">,</span> 
                        negative_prompt<span style="color:#00e0e0">=</span>negative_prompt<span style="color:#fefefe">,</span> strength<span style="color:#00e0e0">=</span>strength<span style="color:#fefefe">,</span> 
                        guidance_scale<span style="color:#00e0e0">=</span>guidance_scale<span style="color:#fefefe">,</span> num_inference_steps<span style="color:#00e0e0">=</span>num_inference_steps<span style="color:#fefefe">,</span>
                        seed<span style="color:#00e0e0">=</span>seed<span style="color:#fefefe">,</span> design_type<span style="color:#00e0e0">=</span>design_type<span style="color:#fefefe">,</span> mask<span style="color:#00e0e0">=</span>mask<span style="color:#fefefe">,</span> other_args<span style="color:#00e0e0">=</span>other_args<span style="color:#fefefe">)</span>
    
    preped_response <span style="color:#00e0e0">=</span> prepare_response<span style="color:#fefefe">(</span>images<span style="color:#fefefe">)</span>
    resp <span style="color:#00e0e0">=</span> AMLResponse<span style="color:#fefefe">(</span>message<span style="color:#00e0e0">=</span>preped_response<span style="color:#fefefe">,</span> status_code<span style="color:#00e0e0">=</span><span style="color:#00e0e0">200</span><span style="color:#fefefe">,</span> json_str<span style="color:#00e0e0">=</span><span style="color:#00e0e0">True</span><span style="color:#fefefe">)</span>

    <span style="color:#00e0e0">return</span> resp

</code></span></span>
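Loading FLUX.1-dev needs a GPU, but the request-handling half of `run()` can be exercised locally. The sketch below is a hypothetical, stdlib-only `parse_request` helper. It reads the same keys as `run()` and treats `prompt` as required. For the optional keys, it falls back to the defaults in the `design()` signature, an assumption made for illustration.

```python
import json

# Defaults copied from the design() signature in score.py.
DEFAULTS = {
    "negative_prompt": None,
    "seed": None,
    "num_images_per_prompt": 4,
    "guidance_scale": 7.5,
    "num_inference_steps": 50,
    "design_type": "TXT_TO_IMG",
    "strength": 0.65,
}


def parse_request(raw_data):
    """Hypothetical helper: extract run()'s parameters from a raw JSON request string."""
    body = json.loads(raw_data)
    if "data" not in body:
        raise ValueError("request body must contain a top-level 'data' object")
    data = body["data"]
    if "prompt" not in data:
        raise ValueError("'data.prompt' is required")
    params = {key: data.get(key, default) for key, default in DEFAULTS.items()}
    params["prompt"] = data["prompt"]
    # Optional inputs that run() downloads with get_image_object() when present.
    params["image_url"] = data.get("image_url")
    params["mask_image"] = data.get("mask_image")
    params["other_args"] = data.get("other_args")
    return params


if __name__ == "__main__":
    raw = json.dumps({"data": {"prompt": "a photo of sks dog in a bucket", "seed": 7}})
    print(parse_request(raw))
```

A request that uses a different schema, such as a top-level `input_data` object, fails fast with a `ValueError` here. Inside the scoring script, the same request would surface as a `KeyError`.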

With score.py in place, we can create the deployment.

<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-applescript">deployment <span style="color:#00e0e0">=</span> ManagedOnlineDeployment<span style="color:#fefefe">(</span>
    name<span style="color:#00e0e0">=</span>deployment_name<span style="color:#fefefe">,</span>
    endpoint_name<span style="color:#00e0e0">=</span>endpoint_name<span style="color:#fefefe">,</span>
    model<span style="color:#00e0e0">=</span>model<span style="color:#fefefe">,</span>
    environment<span style="color:#00e0e0">=</span>env_docker_context<span style="color:#fefefe">,</span>
    code_configuration<span style="color:#00e0e0">=</span>CodeConfiguration<span style="color:#fefefe">(</span>
        code<span style="color:#00e0e0">=</span><span style="color:#abe338">"assets"</span><span style="color:#fefefe">,</span> scoring_script<span style="color:#00e0e0">=</span>score_file
    <span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
    instance_type<span style="color:#00e0e0">=</span>instance_type<span style="color:#fefefe">,</span>
    instance_count<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1</span><span style="color:#fefefe">,</span>
    request_settings<span style="color:#00e0e0">=</span>OnlineRequestSettings<span style="color:#fefefe">(</span>request_timeout_ms<span style="color:#00e0e0">=</span><span style="color:#00e0e0">90000</span><span style="color:#fefefe">,</span> max_queue_wait_ms<span style="color:#00e0e0">=</span><span style="color:#00e0e0">900000</span><span style="color:#fefefe">,</span> max_concurrent_requests_per_instance<span style="color:#00e0e0">=</span><span style="color:#00e0e0">5</span><span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
    liveness_probe<span style="color:#00e0e0">=</span>ProbeSettings<span style="color:#fefefe">(</span>
        failure_threshold<span style="color:#00e0e0">=</span><span style="color:#00e0e0">30</span><span style="color:#fefefe">,</span>
        success_threshold<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1</span><span style="color:#fefefe">,</span>
        <span style="color:#00e0e0">timeout</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">2</span><span style="color:#fefefe">,</span>
        period<span style="color:#00e0e0">=</span><span style="color:#00e0e0">10</span><span style="color:#fefefe">,</span>
        initial_delay<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1000</span><span style="color:#fefefe">,</span>
    <span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
    readiness_probe<span style="color:#00e0e0">=</span>ProbeSettings<span style="color:#fefefe">(</span>
        failure_threshold<span style="color:#00e0e0">=</span><span style="color:#00e0e0">10</span><span style="color:#fefefe">,</span>
        success_threshold<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1</span><span style="color:#fefefe">,</span>
        <span style="color:#00e0e0">timeout</span><span style="color:#00e0e0">=</span><span style="color:#00e0e0">10</span><span style="color:#fefefe">,</span>
        period<span style="color:#00e0e0">=</span><span style="color:#00e0e0">10</span><span style="color:#fefefe">,</span>
        initial_delay<span style="color:#00e0e0">=</span><span style="color:#00e0e0">1000</span><span style="color:#fefefe">,</span>
    <span style="color:#fefefe">)</span><span style="color:#fefefe">,</span>
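    <span style="color:#d4d0ab"># The gated FLUX.1-dev weights need a Hugging Face token. Never hard-code a real token;</span>
    <span style="color:#d4d0ab"># read it from a secret store or an environment variable instead.</span>
    environment_variables <span style="color:#00e0e0">=</span> <span style="color:#fefefe">{</span><span style="color:#abe338">'HF_TOKEN'</span><span style="color:#fefefe">:</span> <span style="color:#abe338">'YOUR_HF_TOKEN'</span><span style="color:#fefefe">}</span><span style="color:#fefefe">,</span>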
<span style="color:#fefefe">)</span>
workspace_ml_client.online_deployments.begin_create_or_update<span style="color:#fefefe">(</span>deployment<span style="color:#fefefe">)</span>.result<span style="color:#fefefe">(</span><span style="color:#fefefe">)</span></code></span></span>
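Besides the SDK's `invoke()`, an endpoint created with `auth_mode="key"` can be called over plain HTTPS by sending the endpoint key as a bearer token. The sketch below only builds such a request with the standard library; it does not send it. The scoring URI and key are placeholders. You would take them from the endpoint's details, for example `workspace_ml_client.online_endpoints.get(endpoint_name).scoring_uri` and `workspace_ml_client.online_endpoints.get_keys(endpoint_name).primary_key`. The optional `azureml-model-deployment` header pins the call to one named deployment instead of relying on the endpoint's traffic split.

```python
import json
import urllib.request


def build_scoring_request(scoring_uri, api_key, payload, deployment_name=None):
    """Build (but do not send) a POST request for a key-authenticated Azure ML online endpoint."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",
    }
    if deployment_name:
        # Route the request to a specific deployment behind the endpoint.
        headers["azureml-model-deployment"] = deployment_name
    body = json.dumps(payload).encode("utf-8")
    return urllib.request.Request(scoring_uri, data=body, headers=headers, method="POST")


if __name__ == "__main__":
    req = build_scoring_request(
        "https://<endpoint-name>.<region>.inference.ml.azure.com/score",  # placeholder URI
        "<endpoint-key>",  # placeholder; never hard-code real keys
        {"data": {"prompt": "a photo of sks dog in a bucket"}},
        deployment_name="<deployment-name>",  # placeholder
    )
    print(req.get_method(), req.full_url)
    # Sending it requires network access and real values, e.g.:
    # with urllib.request.urlopen(req, timeout=120) as resp:
    #     result = json.loads(resp.read())
```

Image generation can take tens of seconds, so a client timeout shorter than the deployment's `request_timeout_ms` would abort requests that the server would otherwise complete.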
Step 11: Test the deployment

Finally, we can test the endpoint.
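The scoring script's `prepare_response()` returns a JSON object with keys `image_0`, `image_1` and so on; each value is a base64-encoded JPEG. A stdlib-only sketch for writing such a response to disk could look like the following. The `save_images` name and the `generated` output folder are arbitrary choices. The check for the JPEG signature bytes `FF D8 FF` is only a sanity test.

```python
import base64
import json
from pathlib import Path

JPEG_MAGIC = b"\xff\xd8\xff"  # first bytes of every JPEG file


def save_images(response, out_dir="generated"):
    """Decode {'image_0': <base64 JPEG>, ...} (dict or JSON string) into files; return their paths."""
    images = json.loads(response) if isinstance(response, str) else response
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    # Sort by the numeric suffix so image_10 comes after image_9.
    for key in sorted(images, key=lambda k: int(k.rsplit("_", 1)[1])):
        raw = base64.b64decode(images[key])
        if not raw.startswith(JPEG_MAGIC):
            raise ValueError(f"{key} does not look like a JPEG")
        path = out / f"{key}.jpg"
        path.write_bytes(raw)
        paths.append(path)
    return paths
```

In this step's test code, `responses` has already been parsed with `json.loads`. Calling `save_images(responses)` on it would write one `.jpg` file per generated image.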

<span style="background-color:#2b2b2b"><span style="color:#f8f8f2"><code class="language-python"><span style="color:#d4d0ab"># Create request json</span>
<span style="color:#00e0e0">import</span> json

<span style="color:#d4d0ab"># score.py reads its parameters from a top-level "data" object;</span>
<span style="color:#d4d0ab"># the image size is fixed to 768x512 inside design().</span>
request_json <span style="color:#00e0e0">=</span> <span style="color:#fefefe">{</span>
    <span style="color:#abe338">"data"</span><span style="color:#fefefe">:</span> <span style="color:#fefefe">{</span>
        <span style="color:#abe338">"prompt"</span><span style="color:#fefefe">:</span> <span style="color:#abe338">"a photo of sks dog in a bucket"</span><span style="color:#fefefe">,</span>
        <span style="color:#abe338">"negative_prompt"</span><span style="color:#fefefe">:</span> <span style="color:#abe338">"blurry; three legs"</span><span style="color:#fefefe">,</span>
        <span style="color:#abe338">"seed"</span><span style="color:#fefefe">:</span> <span style="color:#00e0e0">None</span><span style="color:#fefefe">,</span>
        <span style="color:#abe338">"num_images_per_prompt"</span><span style="color:#fefefe">:</span> <span style="color:#00e0e0">2</span><span style="color:#fefefe">,</span>
        <span style="color:#abe338">"guidance_scale"</span><span style="color:#fefefe">:</span> <span style="color:#00e0e0">7.5</span><span style="color:#fefefe">,</span>
        <span style="color:#abe338">"num_inference_steps"</span><span style="color:#fefefe">:</span> <span style="color:#00e0e0">50</span><span style="color:#fefefe">,</span>
        <span style="color:#abe338">"design_type"</span><span style="color:#fefefe">:</span> <span style="color:#abe338">"TXT_TO_IMG"</span><span style="color:#fefefe">,</span>
        <span style="color:#abe338">"strength"</span><span style="color:#fefefe">:</span> <span style="color:#00e0e0">0.65</span><span style="color:#fefefe">,</span>
    <span style="color:#fefefe">}</span><span style="color:#fefefe">,</span>
<span style="color:#fefefe">}</span>

<span style="color:#00e0e0">import</span> json

# invoke() takes a path to a JSON file, so write the payload to disk first
request_file_name <span style="color:#00e0e0">=</span> <span style="color:#abe338">"sample_request_data.json"</span>
<span style="color:#00e0e0">with</span> <span style="color:#abe338">open</span><span style="color:#fefefe">(</span>request_file_name<span style="color:#fefefe">,</span> <span style="color:#abe338">"w"</span><span style="color:#fefefe">)</span> <span style="color:#00e0e0">as</span> request_file<span style="color:#fefefe">:</span>
    json<span style="color:#fefefe">.</span>dump<span style="color:#fefefe">(</span>request_json<span style="color:#fefefe">,</span> request_file<span style="color:#fefefe">)</span>
# Send the request to the named deployment behind the online endpoint
responses <span style="color:#00e0e0">=</span> workspace_ml_client<span style="color:#fefefe">.</span>online_endpoints<span style="color:#fefefe">.</span>invoke<span style="color:#fefefe">(</span>
    endpoint_name<span style="color:#00e0e0">=</span>online_endpoint_name<span style="color:#fefefe">,</span>
    deployment_name<span style="color:#00e0e0">=</span>deployment_name<span style="color:#fefefe">,</span>
    request_file<span style="color:#00e0e0">=</span>request_file_name<span style="color:#fefefe">,</span>
<span style="color:#fefefe">)</span>
responses <span style="color:#00e0e0">=</span> json<span style="color:#fefefe">.</span>loads<span style="color:#fefefe">(</span>responses<span style="color:#fefefe">)</span>  # invoke() returns the response body as a JSON string


<span style="color:#00e0e0">import</span> base64
<span style="color:#00e0e0">from</span> io <span style="color:#00e0e0">import</span> BytesIO
<span style="color:#00e0e0">from</span> PIL <span style="color:#00e0e0">import</span> Image
<span style="color:#00e0e0">from</span> IPython.display <span style="color:#00e0e0">import</span> display  # built into Jupyter; imported explicitly for other IPython front ends

# Each item in the response list holds a base64-encoded image under "generated_image"
<span style="color:#00e0e0">for</span> response <span style="color:#00e0e0">in</span> responses<span style="color:#fefefe">:</span>
    base64_string <span style="color:#00e0e0">=</span> response<span style="color:#fefefe">[</span><span style="color:#abe338">"generated_image"</span><span style="color:#fefefe">]</span>
    image_stream <span style="color:#00e0e0">=</span> BytesIO<span style="color:#fefefe">(</span>base64<span style="color:#fefefe">.</span>b64decode<span style="color:#fefefe">(</span>base64_string<span style="color:#fefefe">)</span><span style="color:#fefefe">)</span>
    image <span style="color:#00e0e0">=</span> Image<span style="color:#fefefe">.</span><span style="color:#abe338">open</span><span style="color:#fefefe">(</span>image_stream<span style="color:#fefefe">)</span>
    display<span style="color:#fefefe">(</span>image<span style="color:#fefefe">)</span></code></span></span>
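The cell above relies on the SDK's `invoke()` helper and on Jupyter's `display`. As a hedged sketch for calling the endpoint from plain Python instead, the helpers below rebuild the same payload, prepare (but do not send) an HTTP POST, and write the decoded images to disk. The function names `build_request`, `build_http_request` and `save_generated_images`, the `.png` file extension, and the assumption that the scoring script returns a JSON list of objects with a base64 `generated_image` field are all illustrative assumptions, not part of the original deployment. The `Authorization: Bearer <key>` header and the optional `azureml-model-deployment` header follow the Azure ML managed online endpoint convention for key authentication and for routing a call to one specific deployment.

```python
import base64
import json
import urllib.request
from pathlib import Path


def build_request(prompt, negative_prompt=None, height=512, width=512,
                  num_inference_steps=50, guidance_scale=7.5,
                  num_images_per_prompt=2):
    """Hypothetical helper: build a payload shaped like request_json above."""
    params = {
        "height": height,
        "width": width,
        "num_inference_steps": num_inference_steps,
        "guidance_scale": guidance_scale,
        "num_images_per_prompt": num_images_per_prompt,
    }
    if negative_prompt is not None:
        params["negative_prompt"] = negative_prompt
    return {
        "input_data": {"columns": ["prompt"], "index": [0], "data": [prompt]},
        "params": params,
    }


def build_http_request(scoring_uri, api_key, payload, deployment_name=None):
    """Hypothetical helper: prepare, but do not send, a POST to the endpoint."""
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_key}",  # key-based endpoint auth
    }
    if deployment_name:
        # Bypass the endpoint's traffic split and target one deployment.
        headers["azureml-model-deployment"] = deployment_name
    return urllib.request.Request(
        scoring_uri,
        data=json.dumps(payload).encode("utf-8"),
        headers=headers,
        method="POST",
    )


def save_generated_images(responses, out_dir="generated"):
    """Hypothetical helper: decode each 'generated_image' field to a file.

    Assumes the decoded bytes are PNG data; adjust the suffix if the
    scoring script encodes another format.
    """
    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    paths = []
    for i, response in enumerate(responses):
        path = out / f"image_{i}.png"
        path.write_bytes(base64.b64decode(response["generated_image"]))
        paths.append(path)
    return paths
```

To actually send the request, pass it to `urllib.request.urlopen(req)` and run `json.loads` on the body before handing the list to `save_generated_images`. The scoring URI and key can be read with `workspace_ml_client.online_endpoints.get(online_endpoint_name).scoring_uri` and `workspace_ml_client.online_endpoints.get_keys(online_endpoint_name).primary_key`.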

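Conclusion

Fine-tuning FLUX models with Dreambooth is an effective way to tailor generative AI models to a specific application. By following the steps in this blog, you can build on the strengths of the FLUX.1 [dev] model and adapt it to your own dataset to produce high-quality, personalized output. Whether you are working on a creative project, research, or a commercial application, this approach gives you a practical way to extend your AI capabilities.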
