1. Model Distillation:
Model distillation (knowledge distillation) is a technique for transferring the knowledge of a large, complex model (the teacher) into a smaller, more efficient model (the student). The goal is to keep most of the teacher's performance while sharply reducing compute and inference time, which makes deployment on edge devices (phones, IoT devices, etc.) practical. The example in this article distills bert-base-uncased (teacher) into distilbert-base-uncased (student). The core steps of model distillation are:
- Train the teacher model: train a high-performing but complex model (e.g. BERT, ResNet) on a large dataset.
- Generate soft labels: run the teacher over the training data to obtain probability distributions (soft labels).
- Train the student model: the student learns from both
  - the soft labels (via a KL-divergence loss), and
  - the ground-truth labels (via a cross-entropy loss).
- Tune the temperature: train at a high temperature, infer at the standard one. With temperature T > 1 the probability distribution becomes smoother, exposing information about the secondary classes; T = 1 is the standard softmax. Use a higher T during training and revert to T = 1 at inference, as the sketch below illustrates.
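To make the temperature effect concrete, here is a minimal, self-contained sketch (the logits are invented purely for illustration) of how dividing the logits by T > 1 flattens the softmax distribution:

```python
import torch

logits = torch.tensor([4.0, 1.0, 0.2])  # hypothetical logits for three classes

for T in (1.0, 2.0, 5.0):
    probs = torch.softmax(logits / T, dim=0)
    print(f"T={T}: {[round(p, 3) for p in probs.tolist()]}")

# As T grows the distribution becomes smoother, so the student also sees how the
# teacher ranks the non-top ("secondary") classes, not just the argmax class.
```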
2. Code Example
First, define a Dataset class:
```python
import torch
from torch.utils.data import Dataset

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts            # raw text samples
        self.labels = labels          # label for each text
        self.tokenizer = tokenizer    # tokenizer
        self.max_length = max_length  # maximum sequence length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        item = {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'labels': torch.tensor(label, dtype=torch.long)  # make sure the label is a tensor
        }
        return item
```
3. Prepare the training and test data:
```python
# Sample data for sentiment analysis (0: negative, 1: positive)
texts = [
    "This movie was fantastic, the acting was superb!",
    "A complete waste of time and money.",
    "The plot was average, but the special effects were decent.",
    "Highly recommended, one of the best films of the year!",
    "Terrible directing and script, very disappointing.",
    "A strong cast, but the story lacks depth.",
    "Gripping from start to finish, never a dull moment.",
    "Beautiful cinematography, but the plot is too predictable."
]
labels = [1, 0, 1, 1, 0, 0, 1, 0]

# Test data
test_texts = [
    "Not great, but not bad either",
    "An absolute masterpiece, flawless"
]
test_labels = [0, 1]
```
4. Define a configuration class for the distillation hyperparameters, then the teacher and student models:
```python
import torch
from torch.utils.data import DataLoader
from transformers import (BertTokenizer, BertForSequenceClassification,
                          DistilBertForSequenceClassification)

class Config:
    # Both models can be downloaded locally with
    # git clone https://hf-mirror.com/google-bert/bert-base-uncased and
    # git clone https://hf-mirror.com/distilbert/distilbert-base-uncased
    # (make sure git lfs is installed).
    teacher_model_name = "local path to bert-base-uncased"
    student_model_name = "local path to distilbert-base-uncased"
    number_labels = 2
    batch_size = 2
    learning_rate = 5e-5  # learning rate
    num_epochs = 10
    max_length = 64
    temperature = 2       # temperature, controls how smooth the soft labels are
    alpha = 0.5           # weight of the distillation (soft-label) loss
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = Config()
tokenizer = BertTokenizer.from_pretrained(config.teacher_model_name)

train_dataset = TextClassificationDataset(texts, labels, tokenizer, config.max_length)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, config.max_length)
test_loader = DataLoader(test_dataset, batch_size=1)

teacher_model = BertForSequenceClassification.from_pretrained(
    config.teacher_model_name, num_labels=config.number_labels).to(config.device)
# The student checkpoint is a DistilBERT model, so load it with the matching class.
student_model = DistilBertForSequenceClassification.from_pretrained(
    config.student_model_name, num_labels=config.number_labels).to(config.device)

for param in teacher_model.parameters():
    param.requires_grad = False  # freeze the teacher model's parameters

optimizer = torch.optim.AdamW(student_model.parameters(), lr=config.learning_rate)
```
We first load the pretrained models, then freeze all of the teacher model's parameters, and finally define an optimizer over the student's parameters only.
5. Define the loss function
```python
def distill_loss(student_logits, teacher_logits, labels, temperature, alpha):
    """
    Compute the knowledge-distillation loss.
    :param student_logits: logits produced by the student model
    :param teacher_logits: logits produced by the teacher model
    :param labels: ground-truth labels
    :param temperature: temperature parameter
    :param alpha: weight of the distillation (soft-label) loss
    :return: combined loss value
    """
    # Soft-label loss: KL divergence between the temperature-scaled distributions,
    # rescaled by T^2 to keep gradient magnitudes comparable across temperatures.
    soft_loss = torch.nn.KLDivLoss(reduction='batchmean')(
        torch.log_softmax(student_logits / temperature, dim=1),
        torch.softmax(teacher_logits / temperature, dim=1)
    ) * (temperature ** 2)
    # Hard-label loss: standard cross-entropy against the true labels.
    hard_loss = torch.nn.CrossEntropyLoss()(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
```
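As a quick sanity check before wiring the loss into the training loop, it can be exercised on made-up logits; the shapes and values below are purely illustrative:

```python
import torch

# Hypothetical batch of 2 samples and 2 classes, just to exercise distill_loss.
student_logits = torch.tensor([[1.0, 2.0], [0.5, -0.5]])
teacher_logits = torch.tensor([[0.8, 2.5], [0.2, -1.0]])
labels = torch.tensor([1, 0])

loss = distill_loss(student_logits, teacher_logits, labels, temperature=2.0, alpha=0.5)
print(loss.item())  # a single scalar mixing the soft- and hard-label terms
```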
6. Define the training (distillation) and evaluation routines:
```python
from tqdm import tqdm

def train(model, data_loader, optimizer):
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(config.device)
        attention_mask = batch['attention_mask'].to(config.device)
        labels = batch['labels'].to(config.device)

        # The teacher only produces soft targets, so no gradients are needed here.
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        student_outputs = model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        loss = distill_loss(student_logits, teacher_logits, labels, config.temperature, config.alpha)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(data_loader)

def evaluate(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)
            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            _, predicted = torch.max(logits, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct / total
```
7. Run training and evaluation:
```python
for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch + 1}/{config.num_epochs}")
    # Train
    train_loss = train(student_model, train_loader, optimizer)
    print(f"Train Loss: {train_loss:.4f}")
    # Evaluate
    accuracy = evaluate(student_model, test_loader)
    print(f"Test Accuracy: {accuracy:.2f}")
```
8. Use the Optuna framework to search for the best distillation hyperparameters:
```python
import optuna

def objective(trial):
    # Hyperparameter search space for the distillation run.
    params = {
        'temperature': trial.suggest_float('temperature', 1.0, 15.0),
        'alpha': trial.suggest_float('alpha', 0.1, 0.9),
        'learning_rate': trial.suggest_float('learning_rate', 1e-6, 5e-5, log=True),
        'num_epochs': 5,
    }
    student_model = DistilBertForSequenceClassification.from_pretrained(
        config.student_model_name,
        num_labels=config.number_labels)
    student_model.to(config.device)
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=params['learning_rate'])

    best_accuracy = 0.0
    for epoch in range(params['num_epochs']):
        student_model.train()
        for batch in train_loader:
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)

            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            student_outputs = student_model(input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            loss = distill_loss(student_logits, teacher_logits, labels,
                                params['temperature'], params['alpha'])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        accuracy = evaluate(student_model, test_loader)
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        if accuracy > best_accuracy:
            best_accuracy = accuracy
    return best_accuracy

# Create the Optuna study
study = optuna.create_study(
    direction='maximize',                  # we want to maximize accuracy
    sampler=optuna.samplers.TPESampler(),  # TPE sampler
    pruner=optuna.pruners.MedianPruner()   # median pruner, stops unpromising trials early
)

# Run the optimization
study.optimize(objective, n_trials=20, timeout=600)  # at most 20 trials or 10 minutes

# Report the best result
print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial
print(f"  Value (Accuracy): {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

def best_params_train():
    best_params = study.best_params
    final_model = DistilBertForSequenceClassification.from_pretrained(
        config.student_model_name,
        num_labels=config.number_labels
    ).to(config.device)
    optimizer = torch.optim.AdamW(final_model.parameters(), lr=best_params['learning_rate'])

    for epoch in range(5):
        final_model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Final Training:{epoch + 1}"):
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)

            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            student_outputs = final_model(input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            loss = distill_loss(student_logits, teacher_logits, labels,
                                best_params['temperature'], best_params['alpha'])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        # Evaluate after each epoch
        accuracy = evaluate(final_model, test_loader)
        print(f"Epoch {epoch + 1} - Loss: {total_loss / len(train_loader):.4f}, Accuracy: {accuracy:.4f}")

    # Save the final model
    final_model.save_pretrained('optimized_distilled_distilbert')
```
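Finally, once best_params_train() has been run and the model saved, the distilled student can be loaded on its own for inference. The sketch below assumes the 'optimized_distilled_distilbert' directory produced above and reuses the teacher's tokenizer; adjust the paths as needed:

```python
import torch
from transformers import BertTokenizer, DistilBertForSequenceClassification

# Paths are illustrative: the model directory comes from best_params_train(),
# and the tokenizer is the same one used during training.
tokenizer = BertTokenizer.from_pretrained(config.teacher_model_name)
model = DistilBertForSequenceClassification.from_pretrained('optimized_distilled_distilbert')
model.eval()

text = "One of the best films of the year!"
encoding = tokenizer(text, truncation=True, padding='max_length',
                     max_length=64, return_tensors='pt')
with torch.no_grad():
    logits = model(input_ids=encoding['input_ids'],
                   attention_mask=encoding['attention_mask']).logits
prediction = torch.argmax(logits, dim=1).item()
print("positive" if prediction == 1 else "negative")
```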