用Hugging Face Transformers,高效部署:多显卡量化感知训练并转换为ONNX格式的多标签分类模型

文章目录

要在多显卡上进行量化感知训练(QAT),然后将量化后的模型转换为 ONNX 格式并部署到移动设备,可以按照以下步骤进行:

环境准备

确保你已经安装了必要的库:

bash 复制代码
pip install torch torchvision transformers onnx

数据准备

假设你有一个数据集,每张图片对应多个标签,数据格式类似于:

  • images/
    • image1.jpg
    • image2.jpg
    • ...
  • labels.csv

labels.csv 文件内容示例如下:

复制代码
filename,label1,label2,...
image1.jpg,1,0,...
image2.jpg,0,1,...
...

数据集定义

定义一个自定义的数据集类来加载图像和对应的标签:

python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import os

class MultiLabelDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.labels.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        labels = torch.tensor(self.labels.iloc[idx, 1:].astype('float32'))

        if self.transform:
            image = self.transform(image)

        return image, labels

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = MultiLabelDataset(csv_file='labels.csv', root_dir='images', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

模型定义

加载预训练的 EfficientNet-B0 模型并修改其用于多标签分类:

python 复制代码
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch.nn as nn
import torch

feature_extractor = AutoFeatureExtractor.from_pretrained("google/efficientnet-b0")
model = AutoModelForImageClassification.from_pretrained("google/efficientnet-b0")

num_labels = 5  # 根据你的标签数量修改
model.classifier = nn.Sequential(
    nn.Linear(model.classifier.in_features, num_labels),
    nn.Sigmoid()
)

# 设置量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
model_prepared = torch.quantization.prepare_qat(model_fused)

多显卡训练

使用 DataParallel 进行多显卡训练:

python 复制代码
import torch.optim as optim

criterion = nn.BCELoss()
optimizer = optim.Adam(model_prepared.parameters(), lr=0.001)

num_epochs = 10
model_prepared = torch.nn.DataParallel(model_prepared)  # 包装模型

model_prepared.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.cuda(), labels.cuda()  # 将数据移动到 GPU 上
        optimizer.zero_grad()
        
        outputs = model_prepared(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}")

print("Finished Training")

# 转换为量化模型
model_quantized = torch.quantization.convert(model_prepared.module)  # 取消 DataParallel 包装

模型保存与 ONNX 转换

将量化后的模型转换为 ONNX 格式:

python 复制代码
import torch.onnx

# 创建一个示例输入
example_input = torch.randn(1, 3, 224, 224).cuda()

# 导出为 ONNX 模型
torch.onnx.export(
    model_quantized,
    example_input,
    "efficientnet_b0_quantized.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=12,
    do_constant_folding=True
)

验证 ONNX 模型

使用 ONNX Runtime 验证转换后的模型:

bash 复制代码
pip install onnxruntime
python 复制代码
import onnx
import onnxruntime as ort

# 加载 ONNX 模型
onnx_model = onnx.load("efficientnet_b0_quantized.onnx")
onnx.checker.check_model(onnx_model)

# 使用 ONNX Runtime 运行模型
ort_session = ort.InferenceSession("efficientnet_b0_quantized.onnx")

# 准备输入数据
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

inputs = {ort_session.get_inputs()[0].name: to_numpy(example_input)}
outputs = ort_session.run(None, inputs)

print(outputs)

部署到移动设备

将 ONNX 模型文件 (efficientnet_b0_quantized.onnx) 部署到移动设备上,并使用适合的 ONNX 推理库(如 ONNX Runtime for Mobile)进行推理。在移动设备上,可以使用 ONNX Runtime for Mobile 或其他支持 ONNX 的库来加载和运行模型。

通过这些步骤,你可以在多显卡上进行量化感知训练,并将量化后的模型转换为 ONNX 格式,以便在移动设备上进行高效的推理。

相关推荐
常丛丛10 小时前
5.6 LangGraph-Edges理解-Agent图的道路系统
人工智能
雪隐10 小时前
个人电脑玩AI-08让5060 Ti给你打工——我拿 Unlimited-OCR扫了 600 页书,然后悟了
人工智能·后端
Coffeeee10 小时前
Prompt要花心思写,与 AI 对话的七个技巧
人工智能·aigc·ai编程
蝎子莱莱爱打怪11 小时前
Claude Code 官宣新升级:子智能体默认后台跑,你边聊它边干活
人工智能
武子康11 小时前
调查研究-206 DeepSeek DSpark 深度解析:大模型推理加速,正在从“模型能力”转向“系统工程”
人工智能·agent·deepseek
甲维斯11 小时前
最佳work模型sonnet5来了,直接就能用!
人工智能
IT_陈寒12 小时前
React hooks 闭包陷阱把我的状态吃掉了,原来问题出在这里
前端·人工智能·后端
冬奇Lab1 天前
Workflow 系列(03):状态管理——持久化、幂等性与版本绑定
人工智能·工作流引擎
冬奇Lab1 天前
每日一个开源项目(第146篇):openpilot - 开源自动驾驶辅助系统,曾在 Consumer Reports 评测中超过特斯拉 Autopilot
人工智能·开源·自动驾驶