用Hugging Face Transformers,高效部署:多显卡量化感知训练并转换为ONNX格式的多标签分类模型

文章目录

要在多显卡上进行量化感知训练(QAT),然后将量化后的模型转换为 ONNX 格式并部署到移动设备,可以按照以下步骤进行:

环境准备

确保你已经安装了必要的库:

bash 复制代码
pip install torch torchvision transformers onnx

数据准备

假设你有一个数据集,每张图片对应多个标签,数据格式类似于:

  • images/
    • image1.jpg
    • image2.jpg
    • ...
  • labels.csv

labels.csv 文件内容示例如下:

复制代码
filename,label1,label2,...
image1.jpg,1,0,...
image2.jpg,0,1,...
...

数据集定义

定义一个自定义的数据集类来加载图像和对应的标签:

python 复制代码
import torch
from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import pandas as pd
import os

class MultiLabelDataset(Dataset):
    def __init__(self, csv_file, root_dir, transform=None):
        self.labels = pd.read_csv(csv_file)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        img_name = os.path.join(self.root_dir, self.labels.iloc[idx, 0])
        image = Image.open(img_name).convert('RGB')
        labels = torch.tensor(self.labels.iloc[idx, 1:].astype('float32'))

        if self.transform:
            image = self.transform(image)

        return image, labels

transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
dataset = MultiLabelDataset(csv_file='labels.csv', root_dir='images', transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)

模型定义

加载预训练的 EfficientNet-B0 模型并修改其用于多标签分类:

python 复制代码
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
import torch.nn as nn
import torch

feature_extractor = AutoFeatureExtractor.from_pretrained("google/efficientnet-b0")
model = AutoModelForImageClassification.from_pretrained("google/efficientnet-b0")

num_labels = 5  # 根据你的标签数量修改
model.classifier = nn.Sequential(
    nn.Linear(model.classifier.in_features, num_labels),
    nn.Sigmoid()
)

# 设置量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_fused = torch.quantization.fuse_modules(model, [['conv', 'bn', 'relu']])
model_prepared = torch.quantization.prepare_qat(model_fused)

多显卡训练

使用 DataParallel 进行多显卡训练:

python 复制代码
import torch.optim as optim

criterion = nn.BCELoss()
optimizer = optim.Adam(model_prepared.parameters(), lr=0.001)

num_epochs = 10
model_prepared = torch.nn.DataParallel(model_prepared)  # 包装模型

model_prepared.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for images, labels in dataloader:
        images, labels = images.cuda(), labels.cuda()  # 将数据移动到 GPU 上
        optimizer.zero_grad()
        
        outputs = model_prepared(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()

    print(f"Epoch {epoch+1}, Loss: {running_loss/len(dataloader)}")

print("Finished Training")

# 转换为量化模型
model_quantized = torch.quantization.convert(model_prepared.module)  # 取消 DataParallel 包装

模型保存与 ONNX 转换

将量化后的模型转换为 ONNX 格式:

python 复制代码
import torch.onnx

# 创建一个示例输入
example_input = torch.randn(1, 3, 224, 224).cuda()

# 导出为 ONNX 模型
torch.onnx.export(
    model_quantized,
    example_input,
    "efficientnet_b0_quantized.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes={"input": {0: "batch_size"}, "output": {0: "batch_size"}},
    opset_version=12,
    do_constant_folding=True
)

验证 ONNX 模型

使用 ONNX Runtime 验证转换后的模型:

bash 复制代码
pip install onnxruntime
python 复制代码
import onnx
import onnxruntime as ort

# 加载 ONNX 模型
onnx_model = onnx.load("efficientnet_b0_quantized.onnx")
onnx.checker.check_model(onnx_model)

# 使用 ONNX Runtime 运行模型
ort_session = ort.InferenceSession("efficientnet_b0_quantized.onnx")

# 准备输入数据
def to_numpy(tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()

inputs = {ort_session.get_inputs()[0].name: to_numpy(example_input)}
outputs = ort_session.run(None, inputs)

print(outputs)

部署到移动设备

将 ONNX 模型文件 (efficientnet_b0_quantized.onnx) 部署到移动设备上,并使用适合的 ONNX 推理库(如 ONNX Runtime for Mobile)进行推理。在移动设备上,可以使用 ONNX Runtime for Mobile 或其他支持 ONNX 的库来加载和运行模型。

通过这些步骤,你可以在多显卡上进行量化感知训练,并将量化后的模型转换为 ONNX 格式,以便在移动设备上进行高效的推理。

相关推荐
时序大模型5 分钟前
KDD2025 |DUET:时间 - 通道双聚类框架,多变量时序预测的 “全能选手”出现!
人工智能·机器学习·时间序列预测·时间序列·kdd2025
共绩算力29 分钟前
Ming Lite 万能模型对标 GPT-4o 的多模态能力
人工智能·共绩算力
猫先生Mr.Mao35 分钟前
2025年8月AGI月评|AI开源项目全解析:从智能体到3D世界,技术边界再突破
人工智能·开源·aigc·agi·ai资讯·分布式推理框架
深入理解GEE云计算1 小时前
遥感生态指数(RSEI):理论发展、方法论争与实践进展
javascript·人工智能·算法·机器学习
IT_陈寒1 小时前
从2秒到200ms:我是如何用JavaScript优化页面加载速度的🚀
前端·人工智能·后端
深度学习lover1 小时前
<项目代码>yolo织物缺陷识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·织物缺陷识别·项目代码
StarPrayers.1 小时前
Binary Classification& sigmoid 函数的逻辑回归&Decision Boundary
人工智能·分类·数据挖掘
渡我白衣1 小时前
C++:链接的两难 —— ODR中的强与弱符号机制
开发语言·c++·人工智能·深度学习·网络协议·算法·机器学习
大模型真好玩1 小时前
LangChain1.0速通指南(一)——LangChain1.0核心升级
人工智能·agent·mcp
私人珍藏库1 小时前
Parallels Desktop 26.1.1 for Mac 秋叶QiuChenly中文解锁直装版,最好用的macOS虚拟机
人工智能