模型量化与知识蒸馏是深度学习模型轻量化的两大核心技术,广泛应用于移动端、嵌入式等低资源部署场景。二者核心逻辑完全不同,常搭配使用实现"高精度、低体积、高速度"的落地效果。本文融合理论与实战,精简冗余内容,搭配可直接运行的PyTorch极简代码,快速吃透两项技术。
前置环境:
bash
pip install torch torchvision
一、核心基础原理与通俗区别
1. 模型量化(Quantization)
核心定义 :不改变神经网络结构,仅压缩参数数值精度,将模型默认的FP32(32位浮点)参数转为INT8(8位整型)等低精度格式,属于数值压缩、无训练轻量化技术。
通俗理解 :原本用小数精准记录模型参数,量化后用整数近似记录,大幅降低显存占用、缩减模型体积、提升推理速度,仅存在极小的可控精度损失。工业主流为后训练量化(PTQ),无需重新训练,落地成本极低。
2. 知识蒸馏(Distillation)
核心定义 :依托"大模型教小模型"的逻辑,用精度高、参数量大的教师模型 ,训练结构简单、体量更小的学生模型 ,属于结构级、有训练精度迁移技术。
通俗理解:大模型不仅输出最终分类结果(硬标签),还输出类别概率分布(软标签),承载模型学习到的"暗知识"。学生模型同时学习真实标签和教师模型的推理逻辑,突破小模型的精度上限,实现小模型媲美大模型的效果。
3. 核心区别与组合逻辑
-
量化 :提速压缩、无需训练、轻微掉精度,优化推理速度与体积
-
蒸馏 :提升小模型精度、需要训练、无体积压缩,优化模型泛化能力
-
工业最优组合:先蒸馏提升小模型精度,再量化压缩提速,用蒸馏补偿量化的精度损失,实现1+1>2的轻量化效果
二、模型量化 极简代码实战(PTQ后训练量化)
1. 实战思路
搭建简易全连接模型,对比FP32原始模型与INT8量化模型的推理速度、精度差异,全程无需训练,仅通过数据校准完成量化。
2. 可运行代码
python
import torch
import torch.nn as nn
import time
# 搭建简易FP32原始模型
class SimpleNet(nn.Module):
def __init__(self):
super().__init__()
self.fc = nn.Linear(128, 10)
def forward(self, x):
return self.fc(x)
# 初始化模型与测试数据
model = SimpleNet().eval()
x = torch.randn(32, 128)
# 测试原始FP32模型推理速度
start = time.time()
for _ in range(1000):
out = model(x)
fp32_time = time.time() - start
print(f"FP32原始模型耗时: {fp32_time:.4f}s")
# 核心INT8量化流程
model.qconfig = torch.ao.quantization.get_default_qconfig("x86")
torch.ao.quantization.prepare(model, inplace=True)
with torch.no_grad():
model(x) # 数据校准
quant_model = torch.ao.quantization.convert(model, inplace=True)
# 测试量化后模型性能
start = time.time()
for _ in range(1000):
out_quant = quant_model(x)
int8_time = time.time() - start
print(f"INT8量化模型耗时: {int8_time:.4f}s")
print(f"推理加速比: {fp32_time/int8_time:.2f}x")
print(f"量化平均精度误差: {torch.abs(out - out_quant).mean():.6f}")
3. 实战结果总结
量化后模型推理速度提升1.5~3倍,精度误差微乎其微,无需训练、操作极简,是快速落地轻量化的首选方案。
三、知识蒸馏 极简代码实战(软标签蒸馏)
1. 实战思路
搭建复杂教师模型、轻量化学生模型,通过**硬标签(真实数据)+软标签(教师输出)**双损失训练,让小模型学习大模型的暗知识,提升泛化精度。
2. 可运行代码
python
import torch
import torch.nn as nn
import torch.nn.functional as F
# 蒸馏超参数
TEMPERATURE = 2.0 # 软化概率分布
ALPHA = 0.7 # 软标签损失权重
# 教师模型(大模型、高精度)
class TeacherNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(128, 256)
self.fc2 = nn.Linear(256, 128)
self.fc3 = nn.Linear(128, 10)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
return self.fc3(x)
# 学生模型(小模型、轻量化)
class StudentNet(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(128, 64)
self.fc2 = nn.Linear(64, 10)
def forward(self, x):
return F.relu(self.fc2(F.relu(self.fc1(x))))
# 初始化组件
teacher = TeacherNet().eval()
student = StudentNet()
optimizer = torch.optim.Adam(student.parameters(), lr=1e-3)
x = torch.randn(256, 128)
y_true = torch.randint(0, 10, (256,))
# 蒸馏训练流程
for epoch in range(20):
optimizer.zero_grad()
stu_logits = student(x)
# 硬标签损失(贴合真实结果)
loss_hard = F.cross_entropy(stu_logits, y_true)
# 软标签损失(贴合教师推理逻辑)
with torch.no_grad():
tea_logits = teacher(x)
tea_soft = F.softmax(tea_logits / TEMPERATURE, dim=1)
stu_soft = F.log_softmax(stu_logits / TEMPERATURE, dim=1)
loss_soft = F.kl_div(stu_soft, tea_soft, reduction="batchmean") * (TEMPERATURE ** 2)
# 融合损失更新模型
loss_total = ALPHA * loss_soft + (1 - ALPHA) * loss_hard
loss_total.backward()
optimizer.step()
if (epoch + 1) % 5 == 0:
print(f"Epoch{epoch+1} | 总损失:{loss_total.item():.4f}")
3. 核心要点
温度系数软化概率分布,挖掘类别隐性关联;双损失融合兼顾基础精度与泛化能力,让参数量仅为教师1/8的学生模型,精度远超原生训练的小模型。
四、蒸馏+量化 工业组合实战
1. 实战思路
先通过蒸馏得到高精度学生模型,再对学生模型做INT8量化,兼顾高精度、小体积、快推理,是工业部署标准方案。
2. 组合实战代码
python
import time
# 蒸馏后的学生模型量化
student.eval()
student.qconfig = torch.ao.quantization.get_default_qconfig("x86")
torch.ao.quantization.prepare(student, inplace=True)
with torch.no_grad():
student(x)
final_model = torch.ao.quantization.convert(student, inplace=True)
# 对比原生大模型与轻量化组合模型性能
test_x = torch.randn(1000, 128)
# 教师大模型推理
start = time.time()
with torch.no_grad():
teacher(test_x)
teacher_time = time.time() - start
# 蒸馏+量化模型推理
start = time.time()
with torch.no_grad():
final_model(test_x)
light_time = time.time() - start
print(f"教师大模型耗时: {teacher_time:.4f}s")
print(f"轻量化组合模型耗时: {light_time:.4f}s")
print(f"整体加速比: {teacher_time/light_time:.2f}x")
五、全文核心总结
-
量化:无损结构、无需训练,压缩数值精度实现提速瘦身,轻微精度损耗可忽略;
-
蒸馏:无损精度、需要训练,通过师生学习迁移知识,提升小模型泛化能力;
-
组合方案:先蒸馏保精度,后量化提速度,完美适配终端、嵌入式等低资源部署场景。