用Hugging Face Transformers,高效部署:多显卡量化感知训练并转换为ONNX格式的多标签分类模型

文章目录

要在多显卡上进行量化感知训练(QAT),然后将量化后的模型转换为 ONNX 格式并部署到移动设备,可以按照以下步骤进行:

环境准备

确保你已经安装了必要的库:

bash 复制代码
pip install torch torchvision transformers onnx

数据准备

假设你有一个数据集,每张图片对应多个标签,数据格式类似于:

  • images/
    • image1.jpg
    • image2.jpg
    • ...
  • labels.csv

labels.csv 文件内容示例如下:

filename,label1,label2,...
image1.jpg,1,0,...
image2.jpg,0,1,...
...

数据集定义

定义一个自定义的数据集类来加载图像和对应的标签:

python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import os

class MultiLabelDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.labels.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        labels = torch.tensor(self.labels.iloc[idx, 1:].astype('float32'))

        if self.transform:
            image = self.transform(image)

        return image, labels

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = MultiLabelDataset(csv_file='labels.csv', root_dir='images', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

模型定义

加载预训练的 EfficientNet-B0 模型并修改其用于多标签分类:

python 复制代码
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch.nn as nn
import torch

feature_extractor = AutoFeatureExtractor.from_pretrained("google/efficientnet-b0")
model = AutoModelForImageClassification.from_pretrained("google/efficientnet-b0")

num_labels = 5  # 根据你的标签数量修改
model.classifier = nn.Sequential(
    nn.Linear(model.classifier.in_features, num_labels),
    nn.Sigmoid()
)

# 设置量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
model_prepared = torch.quantization.prepare_qat(model_fused)

多显卡训练

使用 DataParallel 进行多显卡训练:

python 复制代码
import torch.optim as optim

criterion = nn.BCELoss()
optimizer = optim.Adam(model_prepared.parameters(), lr=0.001)

num_epochs = 10
model_prepared = torch.nn.DataParallel(model_prepared)  # 包装模型

model_prepared.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.cuda(), labels.cuda()  # 将数据移动到 GPU 上
        optimizer.zero_grad()
        
        outputs = model_prepared(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}")

print("Finished Training")

# 转换为量化模型
model_quantized = torch.quantization.convert(model_prepared.module)  # 取消 DataParallel 包装

模型保存与 ONNX 转换

将量化后的模型转换为 ONNX 格式:

python 复制代码
import torch.onnx

# 创建一个示例输入
example_input = torch.randn(1, 3, 224, 224).cuda()

# 导出为 ONNX 模型
torch.onnx.export(
    model_quantized,
    example_input,
    "efficientnet_b0_quantized.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=12,
    do_constant_folding=True
)

验证 ONNX 模型

使用 ONNX Runtime 验证转换后的模型:

bash 复制代码
pip install onnxruntime
python 复制代码
import onnx
import onnxruntime as ort

# 加载 ONNX 模型
onnx_model = onnx.load("efficientnet_b0_quantized.onnx")
onnx.checker.check_model(onnx_model)

# 使用 ONNX Runtime 运行模型
ort_session = ort.InferenceSession("efficientnet_b0_quantized.onnx")

# 准备输入数据
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

inputs = {ort_session.get_inputs()[0].name: to_numpy(example_input)}
outputs = ort_session.run(None, inputs)

print(outputs)

部署到移动设备

将 ONNX 模型文件 (efficientnet_b0_quantized.onnx) 部署到移动设备上,并使用适合的 ONNX 推理库(如 ONNX Runtime for Mobile)进行推理。在移动设备上,可以使用 ONNX Runtime for Mobile 或其他支持 ONNX 的库来加载和运行模型。

通过这些步骤,你可以在多显卡上进行量化感知训练,并将量化后的模型转换为 ONNX 格式,以便在移动设备上进行高效的推理。

相关推荐
YSGZJJ20 分钟前
股指期货的套保策略如何精准选择和规避风险?
人工智能·区块链
无脑敲代码,bug漫天飞22 分钟前
COR 损失函数
人工智能·机器学习
HPC_fac130520678161 小时前
以科学计算为切入点:剖析英伟达服务器过热难题
服务器·人工智能·深度学习·机器学习·计算机视觉·数据挖掘·gpu算力
小陈phd4 小时前
OpenCV从入门到精通实战(九)——基于dlib的疲劳监测 ear计算
人工智能·opencv·计算机视觉
Guofu_Liao5 小时前
大语言模型---LoRA简介;LoRA的优势;LoRA训练步骤;总结
人工智能·语言模型·自然语言处理·矩阵·llama
ZHOU_WUYI9 小时前
3.langchain中的prompt模板 (few shot examples in chat models)
人工智能·langchain·prompt
如若1239 小时前
主要用于图像的颜色提取、替换以及区域修改
人工智能·opencv·计算机视觉
老艾的AI世界9 小时前
AI翻唱神器,一键用你喜欢的歌手翻唱他人的曲目(附下载链接)
人工智能·深度学习·神经网络·机器学习·ai·ai翻唱·ai唱歌·ai歌曲
DK221519 小时前
机器学习系列----关联分析
人工智能·机器学习
Robot2519 小时前
Figure 02迎重大升级!!人形机器人独角兽[Figure AI]商业化加速
人工智能·机器人·微信公众平台