pytorch量化训练

训练时量化(Quantization-aware Training, QAT)是一种在模型训练过程中,通过模拟低精度量化效应来增强模型对量化操作的鲁棒性的技术。与后训练量化不同,QAT 允许模型在训练过程中考虑到量化引入的误差,从而在实际部署时使用低精度进行推理时能够维持更高的性能。

1. 假量化节点插入(Fake Quantization Nodes)

在训练过程中,通过在网络中插入假量化节点来模拟量化和反量化的过程。这些节点在前向传播过程中将权重和激活值量化到指定的数值范围和精度(如INT8),然后再反量化回浮点数,以进行后续的计算。通过这种方式,模型可以适应量化带来的信息损失。

bash 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub

class QuantAwareNet(nn.Module):
    def __init__(self):
        super(QuantAwareNet, self).__init__()
        self.quant = QuantStub() # 新插入内容
        self.dequant = DeQuantStub() # 新插入内容
        self.fc1 = nn.Linear(784, 256)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(256, 10)
    
    def forward(self, x):
        x = self.quant(x) # 新插入内容
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.dequant(x) # 新插入内容
        return x

2. 量化配置

在PyTorch中,量化配置(QConfig)指定了模型量化过程中将使用的量化方案和算子。通过使用不同的QConfig,你可以控制如何量化模型中的权重和激活,这对于模型的性能和精度具有重要影响。

2.1 量化配置函数 get_default_qat_qconfig

get_default_qat_qconfig 是PyTorch提供的一个函数,用于获取用于量化感知训练(QAT)的默认量化配置。这个函数的一个重要参数是后端,通常是 'fbgemm' 或 'qnnpack':

  • 'fbgemm': 主要用于服务器和桌面平台上的x86架构,支持INT8量化。
  • 'qnnpack': 适用于移动设备,也支持INT8量化,优化了ARM架构。
bash 复制代码
from torch.quantization import get_default_qconfig
qconfig = get_default_qconfig('fbgemm')

这个函数会设置一个QConfig,其中包括针对权重和激活的量化方案。对于QAT,权重通常在前向过程中进行伪量化,而激活则在训练时进行动态量化。

2.2 可以设置的其他配置选项

PyTorch允许用户自定义QConfig,以适应特定的需求或实验不同的量化方案。自定义QConfig通常涉及以下部分:

2.2.1 量化方案:

  • torch.quantization.default_observer:
    默认的观察者,用于激活,基于移动平均和最小最大值自动调整量化参数。
  • torch.quantization.default_per_channel_weight_observer:
    用于权重的通道级观察者,每个输出通道有独立的量化参数。

2.2.2 量化和反量化函数:

  • torch.quantization.FakeQuantize: 实现伪量化和反量化,模拟量化的效果而不改变底层数据类型。

创建自定义的QConfig:

bash 复制代码
from torch.quantization import QConfig, default_observer, default_per_channel_weight_observer

custom_qconfig = QConfig(
    activation=default_observer.with_args(dtype=torch.qint8),
    weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)

2.3 使用自定义QConfig

可以应用到模型的特定部分或整个模型上

bash 复制代码
model.fc1.qconfig = custom_qconfig  # 应用到模型的一个特定层 

bash 复制代码
# 应用到整个模型
from torch.quantization import prepare_qat  
model.qconfig = custom_qconfig 
model = prepare_qat(model, inplace=True) 

3. 量化感知训练

bash 复制代码
import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert

# 实例化模型
model = MyQuantizedModel() 

# 指定量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# 准备量化感知训练,
model = prepare_qat(model)

# 训练配置
optimizer = optim.SGD(model.parameters(), lr=0.01)

# 训练循环
for epoch in range(num_epochs):
    for data, target in dataloader:
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

    # 转换模型为完全量化
    if epoch == num_epochs - 1:
        model = convert(model.eval(), inplace=True)

4. 量化推理测试

bash 复制代码
import torch
from torch.quantization import convert

def test_quantized_model(model, dataloader, device='cpu'):
    model = convert(model.eval(), inplace=True)
    model.to(device)  # 确保模型在正确的设备上

    correct = 0
    total = 0

    with torch.no_grad():  # 关闭梯度计算,因为我们只做推理
        for data, targets in dataloader:
            data, targets = data.to(device), targets.to(device)  # 移动数据到相应设备
            outputs = model(data)  # 前向推理
            _, predicted = torch.max(outputs.data, 1)  # 获取预测结果
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = 100 * correct / total
    print(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')

# 'test_loader' 是用于测试的 DataLoader
# 测试模型
# test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'

5.完整参考代码

bash 复制代码
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
    
import torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert


# 定义数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  # 标准化
])

# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


class QuantizedCNN(nn.Module):
    def __init__(self):
        super(QuantizedCNN, self).__init__()
        self.quant = QuantStub()
        self.conv1 = nn.Conv2d(3, 16, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5)
        self.fc1 = nn.Linear(32 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        self.dequant = DeQuantStub()

    def forward(self, x):
        # x = self.quant(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        x = self.dequant(x)
        return x
    
    


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')

# # 准备模型进行量化感知训练
model = prepare_qat(model, inplace=True)

optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()

# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data[0].to(device), data[1].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        if i % 2000 == 1999:
            print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
            running_loss = 0.0

    # 切换到评估模式进行测试
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in testloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    print('Accuracy of the network on the 10000 test images: %d %%' % (
        100 * correct / total))

    # 在最后一个epoch后完成量化
    if epoch == num_epochs - 1:
        model = convert(model.eval(), inplace=True)
        print("Model quantization completed.")
相关推荐
orion-orion8 分钟前
概率论沉思录:初等假设检验
人工智能·概率论·科学哲学
FL16238631299 分钟前
医学数据集肺肿瘤分割数据集labelme格式687张1类别
深度学习
人工智能研究所37 分钟前
MaskGCT——开源文本转语音模型,可模仿任何人说话声音
人工智能·文本转语音·文本转音频
宸码44 分钟前
【项目实战】ISIC 数据集上的实验揭秘:UNet + SENet、Spatial Attention 和 CBAM 的最终表现
人工智能·python·深度学习·神经网络·机器学习·计算机视觉
老板多放点香菜1 小时前
AI、大数据、机器学习、深度学习、神经网络之间的关系
大数据·人工智能·深度学习·神经网络·机器学习
volcanical2 小时前
MoCo 对比自监督学习
人工智能·学习·机器学习
视觉语言导航2 小时前
ACL-2024 | MapGPT:基于地图引导提示和自适应路径规划机制的视觉语言导航
人工智能·具身智能
四口鲸鱼爱吃盐2 小时前
Pytorch | 从零构建Vgg对CIFAR10进行分类
人工智能·pytorch·分类
bielaile_leisigoule2 小时前
智能与人工智能控制:机器学习、深度学习、模糊逻辑、强化学习
人工智能
aiblog2 小时前
AIGC:图像风格迁移技术实现猜想
人工智能·aigc