pytorch量化训练

训练时量化(Quantization-aware Training, QAT)是一种在模型训练过程中,通过模拟低精度量化效应来增强模型对量化操作的鲁棒性的技术。与后训练量化不同,QAT 允许模型在训练过程中考虑到量化引入的误差,从而在实际部署时使用低精度进行推理时能够维持更高的性能。

1. 假量化节点插入(Fake Quantization Nodes)

在训练过程中,通过在网络中插入假量化节点来模拟量化和反量化的过程。这些节点在前向传播过程中将权重和激活值量化到指定的数值范围和精度(如INT8),然后再反量化回浮点数,以进行后续的计算。通过这种方式,模型可以适应量化带来的信息损失。

bash 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub

class QuantAwareNet(nn.Module):
    def __init__(self):
        super(QuantAwareNet, self).__init__()
        self.quant = QuantStub() # 新插入内容
        self.dequant = DeQuantStub() # 新插入内容
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = self.quant(x) # 新插入内容
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dequant(x) # 新插入内容
        return x

2. 量化配置

在PyTorch中,量化配置(QConfig)指定了模型量化过程中将使用的量化方案和算子。通过使用不同的QConfig,你可以控制如何量化模型中的权重和激活,这对于模型的性能和精度具有重要影响。

2.1 量化配置函数 get_default_qat_qconfig

get_default_qat_qconfig 是PyTorch提供的一个函数,用于获取用于量化感知训练(QAT)的默认量化配置。这个函数的一个重要参数是后端,通常是 'fbgemm' 或 'qnnpack':

  • 'fbgemm': 主要用于服务器和桌面平台上的x86架构,支持INT8量化。
  • 'qnnpack': 适用于移动设备,也支持INT8量化,优化了ARM架构。
bash 复制代码
from torch.quantization import get_default_qconfig
qconfig = get_default_qconfig('fbgemm')

这个函数会设置一个QConfig,其中包括针对权重和激活的量化方案。对于QAT,权重通常在前向过程中进行伪量化,而激活则在训练时进行动态量化。

2.2 可以设置的其他配置选项

PyTorch允许用户自定义QConfig,以适应特定的需求或实验不同的量化方案。自定义QConfig通常涉及以下部分:

2.2.1 量化方案:

  • torch.quantization.default_observer:
    默认的观察者,用于激活,基于移动平均和最小最大值自动调整量化参数。
  • torch.quantization.default_per_channel_weight_observer:
    用于权重的通道级观察者,每个输出通道有独立的量化参数。

2.2.2 量化和反量化函数:

  • torch.quantization.FakeQuantize: 实现伪量化和反量化,模拟量化的效果而不改变底层数据类型。

创建自定义的QConfig:

bash 复制代码
from torch.quantization import QConfig, default_observer, default_per_channel_weight_observer

custom_qconfig = QConfig(
    activation=default_observer.with_args(dtype=torch.qint8),
    weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)

2.3 使用自定义QConfig

可以应用到模型的特定部分或整个模型上

bash 复制代码
model.fc1.qconfig = custom_qconfig  # 应用到模型的一个特定层 

bash 复制代码
# 应用到整个模型
from torch.quantization import prepare_qat  
model.qconfig = custom_qconfig 
model = prepare_qat(model, inplace=True) 

3. 量化感知训练

bash 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert

# 实例化模型
model = MyQuantizedModel() 

# 指定量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# 准备量化感知训练,
model = prepare_qat(model)

# 训练配置
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练循环
for epoch in range(num_epochs):
    for data, target in dataloader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

    # 转换模型为完全量化
    if epoch == num_epochs - 1:
        model = convert(model.eval(), inplace=True)

4. 量化推理测试

bash 复制代码
import torch
from torch.quantization import convert

def test_quantized_model(model, dataloader, device='cpu'):
    model = convert(model.eval(), inplace=True)
    model.to(device)  # 确保模型在正确的设备上

    correct = 0
    total = 0

    with torch.no_grad():  # 关闭梯度计算,因为我们只做推理
        for data, targets in dataloader:
            data, targets = data.to(device), targets.to(device)  # 移动数据到相应设备
            outputs = model(data)  # 前向推理
            _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')

# 'test_loader' 是用于测试的 DataLoader
# 测试模型
# test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'

5.完整参考代码

bash 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
    
import torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert


# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class QuantizedCNN(nn.Module):
    def __init__(self):
        super(QuantizedCNN, self).__init__()
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dequant = DeQuantStub()

    def forward(self, x):
        # x = self.quant(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        return x
    
    


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')

# # 准备模型进行量化感知训练
model = prepare_qat(model, inplace=True)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

    # 切换到评估模式进行测试
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

    # 在最后一个epoch后完成量化
    if epoch == num_epochs - 1:
        model = convert(model.eval(), inplace=True)
        print("Model quantization completed.")
相关推荐
码上掘金1 分钟前
基于深度学习的行人计数与人群密度分析系统设计与实现
人工智能·深度学习
北京软秦科技有限公司6 分钟前
灌封胶耐候测试报告为何更依赖“AI报告审核”?IACheck如何提升长期环境可靠性判断精度
人工智能
程序员果子9 分钟前
Agent设计手册:四层架构、工程约束、框架选型
人工智能·agent·智能体·agent框架
2401_8322981013 分钟前
SaaS 到 Agent-as-a-Service——OpenClaw 生态爆发,开启企业数字化新时代
人工智能
AI产品测评官20 分钟前
2026年AI招聘架构深潜:多Agent协同如何打造主动出击智能体代表?
人工智能·架构
captain_AIouo25 分钟前
Captain AI:全阶段适配不同规模OZON商家
大数据·人工智能·经验分享·aigc
HyperAI超神经35 分钟前
在线教程丨支持600+语言,小米开源OmniVoice:仅需3-10秒参考音频实现语音克隆
人工智能·音频识别·语音生成
段一凡-华北理工大学37 分钟前
【高炉炼铁领域炉温监测、预警、调控智能体设计与应用】~系列文章14:时序数据处理:捕捉温度的脉搏
人工智能·高炉炼铁·工业智能体·炉温监测·炉温预警
情绪总是阴雨天~41 分钟前
提示词工程实战:金融行业 Prompt 设计与大模型应用
人工智能·金融·prompt
汽车仪器仪表相关领域42 分钟前
Kvaser Air Bridge Light HS:免配置工业级无线 CAN 桥接器,70 米稳定传输,移动设备与动态场景的 CAN 互联理想之选
人工智能·功能测试·安全·单元测试·汽车·可用性测试