训练时量化(Quantization-aware Training, QAT)是一种在模型训练过程中,通过模拟低精度量化效应来增强模型对量化操作的鲁棒性的技术。与后训练量化不同,QAT 允许模型在训练过程中考虑到量化引入的误差,从而在实际部署时使用低精度进行推理时能够维持更高的性能。
1. 假量化节点插入(Fake Quantization Nodes)
在训练过程中,通过在网络中插入假量化节点来模拟量化和反量化的过程。这些节点在前向传播过程中将权重和激活值量化到指定的数值范围和精度(如INT8),然后再反量化回浮点数,以进行后续的计算。通过这种方式,模型可以适应量化带来的信息损失。
bash
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
class QuantAwareNet(nn.Module):
def __init__(self):
super(QuantAwareNet, self).__init__()
self.quant = QuantStub() # 新插入内容
self.dequant = DeQuantStub() # 新插入内容
self.fc1 = nn.Linear(784, 256)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(256, 10)
def forward(self, x):
x = self.quant(x) # 新插入内容
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.dequant(x) # 新插入内容
return x
2. 量化配置
在PyTorch中,量化配置(QConfig)指定了模型量化过程中将使用的量化方案和算子。通过使用不同的QConfig,你可以控制如何量化模型中的权重和激活,这对于模型的性能和精度具有重要影响。
2.1 量化配置函数 get_default_qat_qconfig
get_default_qat_qconfig 是PyTorch提供的一个函数,用于获取用于量化感知训练(QAT)的默认量化配置。这个函数的一个重要参数是后端,通常是 'fbgemm' 或 'qnnpack':
- 'fbgemm': 主要用于服务器和桌面平台上的x86架构,支持INT8量化。
- 'qnnpack': 适用于移动设备,也支持INT8量化,优化了ARM架构。
bash
from torch.quantization import get_default_qconfig
qconfig = get_default_qconfig('fbgemm')
这个函数会设置一个QConfig,其中包括针对权重和激活的量化方案。对于QAT,权重通常在前向过程中进行伪量化,而激活则在训练时进行动态量化。
2.2 可以设置的其他配置选项
PyTorch允许用户自定义QConfig,以适应特定的需求或实验不同的量化方案。自定义QConfig通常涉及以下部分:
2.2.1 量化方案:
- torch.quantization.default_observer:
默认的观察者,用于激活,基于移动平均和最小最大值自动调整量化参数。 - torch.quantization.default_per_channel_weight_observer:
用于权重的通道级观察者,每个输出通道有独立的量化参数。
2.2.2 量化和反量化函数:
- torch.quantization.FakeQuantize: 实现伪量化和反量化,模拟量化的效果而不改变底层数据类型。
创建自定义的QConfig:
bash
from torch.quantization import QConfig, default_observer, default_per_channel_weight_observer
custom_qconfig = QConfig(
activation=default_observer.with_args(dtype=torch.qint8),
weight=default_per_channel_weight_observer.with_args(dtype=torch.qint8)
)
2.3 使用自定义QConfig
可以应用到模型的特定部分或整个模型上
bash
model.fc1.qconfig = custom_qconfig # 应用到模型的一个特定层
和
bash
# 应用到整个模型
from torch.quantization import prepare_qat
model.qconfig = custom_qconfig
model = prepare_qat(model, inplace=True)
3. 量化感知训练
bash
import torch
import torch.nn as nn
import torch.optim as optim
from torch.quantization import QuantStub, DeQuantStub, prepare_qat, convert
# 实例化模型
model = MyQuantizedModel()
# 指定量化配置
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
# 准备量化感知训练,
model = prepare_qat(model)
# 训练配置
optimizer = optim.SGD(model.parameters(), lr=0.01)
# 训练循环
for epoch in range(num_epochs):
for data, target in dataloader:
optimizer.zero_grad()
output = model(data)
loss = loss_fn(output, target)
loss.backward()
optimizer.step()
# 转换模型为完全量化
if epoch == num_epochs - 1:
model = convert(model.eval(), inplace=True)
4. 量化推理测试
bash
import torch
from torch.quantization import convert
def test_quantized_model(model, dataloader, device='cpu'):
model = convert(model.eval(), inplace=True)
model.to(device) # 确保模型在正确的设备上
correct = 0
total = 0
with torch.no_grad(): # 关闭梯度计算,因为我们只做推理
for data, targets in dataloader:
data, targets = data.to(device), targets.to(device) # 移动数据到相应设备
outputs = model(data) # 前向推理
_, predicted = torch.max(outputs.data, 1) # 获取预测结果
total += targets.size(0)
correct += (predicted == targets).sum().item()
accuracy = 100 * correct / total
print(f'Accuracy of the quantized model on the test data: {accuracy:.2f}%')
# 'test_loader' 是用于测试的 DataLoader
# 测试模型
# test_quantized_model(quantized_model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'
5.完整参考代码
bash
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
from torch.quantization import QuantStub, DeQuantStub
import torch.optim as optim
from torch.quantization import get_default_qconfig, prepare_qat, convert
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 标准化
])
# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
class QuantizedCNN(nn.Module):
def __init__(self):
super(QuantizedCNN, self).__init__()
self.quant = QuantStub()
self.conv1 = nn.Conv2d(3, 16, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5)
self.fc1 = nn.Linear(32 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
self.dequant = DeQuantStub()
def forward(self, x):
# x = self.quant(x)
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
x = self.dequant(x)
return x
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = QuantizedCNN().to(device)
model.qconfig = get_default_qconfig('qnnpack')
# # 准备模型进行量化感知训练
model = prepare_qat(model, inplace=True)
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
criterion = nn.CrossEntropyLoss()
# 训练循环
num_epochs = 10
for epoch in range(num_epochs):
model.train()
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data[0].to(device), data[1].to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 2000 == 1999:
print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 2000:.3f}')
running_loss = 0.0
# 切换到评估模式进行测试
model.eval()
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
# 在最后一个epoch后完成量化
if epoch == num_epochs - 1:
model = convert(model.eval(), inplace=True)
print("Model quantization completed.")