pytorch量化训练

训练时量化(Quantization-aware Training, QAT)是一种在模型训练过程中,通过模拟低精度量化效应来增强模型对量化操作的鲁棒性的技术。与后训练量化不同,QAT 允许模型在训练过程中考虑到量化引入的误差,从而在实际部署时使用低精度进行推理时能够维持更高的性能。

1. 假量化节点插入(Fake Quantization Nodes)

在训练过程中,通过在网络中插入假量化节点来模拟量化和反量化的过程。这些节点在前向传播过程中将权重和激活值量化到指定的数值范围和精度(如INT8),然后再反量化回浮点数,以进行后续的计算。通过这种方式,模型可以适应量化带来的信息损失。

bash 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub

class QuantAwareNet(nn.Module):
    def __init__(self):
        super(QuantAwareNet, self).__init__()
        self.quant = QuantStub() # 新插入内容
        self.dequant = DeQuantStub() # 新插入内容
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = self.quant(x) # 新插入内容
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dequant(x) # 新插入内容
        return x

2. 量化配置

在PyTorch中,量化配置(QConfig)指定了模型量化过程中将使用的量化方案和算子。通过使用不同的QConfig,你可以控制如何量化模型中的权重和激活,这对于模型的性能和精度具有重要影响。

2.1 量化配置函数 get_default_qat_qconfig

get_default_qat_qconfig 是PyTorch提供的一个函数,用于获取用于量化感知训练(QAT)的默认量化配置。这个函数的一个重要参数是后端,通常是 'fbgemm' 或 'qnnpack':

  • 'fbgemm': 主要用于服务器和桌面平台上的x86架构,支持INT8量化。
  • 'qnnpack': 适用于移动设备,也支持INT8量化,优化了ARM架构。
bash 复制代码
from torch.quantization import get_default_qconfig
qconfig = get_default_qconfig('fbgemm')

这个函数会设置一个QConfig,其中包括针对权重和激活的量化方案。对于QAT,权重通常在前向过程中进行伪量化,而激活则在训练时进行动态量化。

2.2 可以设置的其他配置选项

PyTorch允许用户自定义QConfig,以适应特定的需求或实验不同的量化方案。自定义QConfig通常涉及以下部分:

2.2.1 量化方案:

  • torch.quantization.default_observer:
    默认的观察者,用于激活,基于移动平均和最小最大值自动调整量化参数。
  • torch.quantization.default_per_channel_weight_observer:
    用于权重的通道级观察者,每个输出通道有独立的量化参数。

2.2.2 量化和反量化函数:

  • torch.quantization.FakeQuantize: 实现伪量化和反量化,模拟量化的效果而不改变底层数据类型。

创建自定义的QConfig:

bash 复制代码
from torch.quantization import QConfig, default_observer, default_per_channel_weight_observer

custom_qconfig = QConfig(
    activation=default_observer.with_args(dtype=torch.qint8),
    weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)

2.3 使用自定义QConfig

可以应用到模型的特定部分或整个模型上

bash 复制代码
model.fc1.qconfig = custom_qconfig  # 应用到模型的一个特定层 

bash 复制代码
# 应用到整个模型
from torch.quantization import prepare_qat  
model.qconfig = custom_qconfig 
model = prepare_qat(model, inplace=True) 

3. 量化感知训练

bash 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert

# 实例化模型
model = MyQuantizedModel() 

# 指定量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# 准备量化感知训练,
model = prepare_qat(model)

# 训练配置
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练循环
for epoch in range(num_epochs):
    for data, target in dataloader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

    # 转换模型为完全量化
    if epoch == num_epochs - 1:
        model = convert(model.eval(), inplace=True)

4. 量化推理测试

bash 复制代码
import torch
from torch.quantization import convert

def test_quantized_model(model, dataloader, device='cpu'):
    model = convert(model.eval(), inplace=True)
    model.to(device)  # 确保模型在正确的设备上

    correct = 0
    total = 0

    with torch.no_grad():  # 关闭梯度计算,因为我们只做推理
        for data, targets in dataloader:
            data, targets = data.to(device), targets.to(device)  # 移动数据到相应设备
            outputs = model(data)  # 前向推理
            _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')

# 'test_loader' 是用于测试的 DataLoader
# 测试模型
# test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'

5.完整参考代码

bash 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
    
import torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert


# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class QuantizedCNN(nn.Module):
    def __init__(self):
        super(QuantizedCNN, self).__init__()
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dequant = DeQuantStub()

    def forward(self, x):
        # x = self.quant(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        return x
    
    


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')

# # 准备模型进行量化感知训练
model = prepare_qat(model, inplace=True)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

    # 切换到评估模式进行测试
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

    # 在最后一个epoch后完成量化
    if epoch == num_epochs - 1:
        model = convert(model.eval(), inplace=True)
        print("Model quantization completed.")
相关推荐
ai产品老杨3 分钟前
部署神经网络时计算图的优化方法
人工智能·深度学习·神经网络·安全·机器学习·开源
fanxbl9576 分钟前
深入探索离散 Hopfield 神经网络
人工智能·神经网络
TaoYuan__19 分钟前
深度学习概览
人工智能·深度学习
云起无垠24 分钟前
第74期 | GPTSecurity周报
人工智能·安全·网络安全
workflower33 分钟前
AI+自动驾驶
人工智能·机器学习·自动驾驶
爱技术的小伙子43 分钟前
【ChatGPT】 让ChatGPT模拟客户服务对话与应答策略
人工智能·chatgpt
OptimaAI1 小时前
【 LLM论文日更|检索增强:大型语言模型是强大的零样本检索器 】
人工智能·深度学习·语言模型·自然语言处理·nlp
谢眠1 小时前
机器学习day4-朴素贝叶斯分类和决策树
人工智能·机器学习
HelpHelp同学1 小时前
教育机构内部知识库:教学资源的集中管理与优化
人工智能·知识库软件·搭建知识库·知识管理工具
深度学习lover1 小时前
<项目代码>YOLOv8 番茄识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·番茄识别