Model Distillation: Distilling bert-base-uncased into distilbert-base-uncased

1. Model Distillation

Model distillation is a technique that transfers the knowledge of a complex model (the teacher) into a smaller, more efficient model (the student). Its core goal is to preserve model performance while significantly reducing compute usage and inference time, making deployment on edge devices (phones, IoT hardware) practical. The example in this article distills bert-base-uncased into distilbert-base-uncased. The core steps of model distillation are:

  1. Train the teacher model: train a high-performing but heavyweight model (e.g. BERT, ResNet) on large-scale data.

  2. Generate soft labels: run the teacher model over the training data to obtain its predicted probability distributions (the soft labels).

  3. Train the student model: the student learns from both:

    • the soft labels (via a KL-divergence loss);
    • the ground-truth labels (via a cross-entropy loss).
  4. Tune the temperature: train hot, infer cool. With temperature T > 1 the probability distribution becomes smoother, surfacing information about the secondary classes; with T = 1 it is the standard softmax. Use a higher T during training and restore T = 1 at inference time (see the sketch below).
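
To make the effect of the temperature concrete, here is a minimal sketch with made-up logits showing how dividing by a larger T flattens the softmax output:

python
import torch

logits = torch.tensor([3.0, 1.0, 0.5])  # made-up logits for a 3-class example
for T in (1.0, 2.0, 5.0):
    probs = torch.softmax(logits / T, dim=0)
    print(f"T={T}: {probs.tolist()}")
# As T grows, the distribution flattens, exposing the relative
# scores of the non-top classes (the "dark knowledge").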

2. Code Example

First, define a Dataset class for the text data:

python
# Imports used throughout this article.
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import (BertTokenizer, BertForSequenceClassification,
                          DistilBertForSequenceClassification)
from tqdm import tqdm
import optuna

class TextClassificationDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts            # raw text samples
        self.labels = labels          # label for each text
        self.tokenizer = tokenizer    # tokenizer used to encode the text
        self.max_length = max_length  # fixed sequence length for padding/truncation

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]

        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        item = {
            'input_ids': encoding['input_ids'].flatten(),  
            'attention_mask': encoding['attention_mask'].flatten(), 
            'labels': torch.tensor(label, dtype=torch.long)  # make sure the label is a tensor
        }

        return item

3. Prepare the training and test data:

python
# Sample data - sentiment analysis (0: negative, 1: positive)
texts = [
    "This movie was fantastic, the acting was superb!",
    "A complete waste of time and money.",
    "The plot was average, but the special effects were decent.",
    "Highly recommended, one of the best films of the year!",
    "Terrible directing and writing, very disappointing.",
    "A strong cast, but the story lacks depth.",
    "Gripping from start to finish, not a dull moment.",
    "Beautiful cinematography, but the plot is too predictable."
]
labels = [1, 0, 1, 1, 0, 0, 1, 0]

# Test data
test_texts = [
    "Not great, but not bad either",
    "An absolute masterpiece, flawless"
]
test_labels = [0, 1]

4. Define a configuration class for the distillation run, then the teacher and student models:

python
class Config:
    # Both models can be downloaded ahead of time, e.g. with
    # git clone https://hf-mirror.com/google-bert/bert-base-uncased and
    # git clone https://hf-mirror.com/distilbert/distilbert-base-uncased
    # (make sure git-lfs is installed).
    teacher_model_name = "local path to bert-base-uncased"
    student_model_name = "local path to distilbert-base-uncased"
    number_labels = 2

    batch_size = 2
    learning_rate = 5e-5  # learning rate
    num_epochs = 10
    max_length = 64

    temperature = 2  # temperature: controls how smooth the soft labels are
    alpha = 0.5      # weight of the soft-label (distillation) loss term

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
config = Config()
tokenizer = BertTokenizer.from_pretrained(config.teacher_model_name)

train_dataset = TextClassificationDataset(texts, labels, tokenizer, config.max_length)
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)

test_dataset = TextClassificationDataset(test_texts, test_labels, tokenizer, config.max_length)
test_loader = DataLoader(test_dataset, batch_size=1)
teacher_model = BertForSequenceClassification.from_pretrained(
    config.teacher_model_name, num_labels=config.number_labels).to(config.device)
student_model = DistilBertForSequenceClassification.from_pretrained(
    config.student_model_name, num_labels=config.number_labels).to(config.device)

for param in teacher_model.parameters():
    param.requires_grad = False  # freeze the teacher's parameters
optimizer = torch.optim.AdamW(student_model.parameters(), lr=config.learning_rate)

We first load the two pretrained models, freeze every parameter of the teacher, and define an optimizer over the student's parameters only. The student is loaded with DistilBertForSequenceClassification to match the DistilBERT checkpoint, and both models share the same BertTokenizer, since distilbert-base-uncased uses BERT's vocabulary.
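
As an optional sanity check, you can print the size gap between teacher and student and confirm that the teacher is frozen (the counts in the comments are approximate):

python
teacher_params = sum(p.numel() for p in teacher_model.parameters())
student_params = sum(p.numel() for p in student_model.parameters())
print(f"Teacher parameters: {teacher_params:,}")  # roughly 110M for bert-base-uncased
print(f"Student parameters: {student_params:,}")  # roughly 66M for distilbert-base-uncased
print(any(p.requires_grad for p in teacher_model.parameters()))  # False: teacher is frozen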

5. Define the loss function

python
def distill_loss(student_logits, teacher_logits, labels, temperature, alpha):
    """
    Compute the knowledge-distillation loss.
    :param student_logits: raw outputs of the student model
    :param teacher_logits: raw outputs of the teacher model
    :param labels: ground-truth labels
    :param temperature: temperature parameter
    :param alpha: weight of the soft-label (distillation) term
    :return: combined loss value
    """
    # KL divergence between the temperature-softened distributions,
    # scaled by T^2 to keep gradient magnitudes comparable across temperatures.
    soft_loss = torch.nn.KLDivLoss(reduction='batchmean')(
        torch.log_softmax(student_logits / temperature, dim=1),
        torch.softmax(teacher_logits / temperature, dim=1)
    ) * (temperature ** 2)

    # Standard cross-entropy against the ground-truth labels.
    hard_loss = torch.nn.CrossEntropyLoss()(student_logits, labels)
    return alpha * soft_loss + (1 - alpha) * hard_loss
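
Before training, a quick call on random tensors verifies the wiring (the shapes below are purely illustrative):

python
s = torch.randn(4, 2)           # fake student logits: batch of 4, 2 classes
t = torch.randn(4, 2)           # fake teacher logits
y = torch.tensor([0, 1, 1, 0])  # fake ground-truth labels
print(distill_loss(s, t, y, temperature=2.0, alpha=0.5))  # prints a scalar tensor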

6. Train (distill) the student model and evaluate it:

python
def train(model, data_loader, optimizer):
    model.train()
    total_loss = 0
    for batch in tqdm(data_loader, desc="Training"):
        input_ids = batch['input_ids'].to(config.device)
        attention_mask = batch['attention_mask'].to(config.device)
        labels = batch['labels'].to(config.device)

        optimizer.zero_grad()

        # The teacher only produces soft targets; no gradients are needed.
        with torch.no_grad():
            teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
            teacher_logits = teacher_outputs.logits

        student_outputs = model(input_ids, attention_mask=attention_mask)
        student_logits = student_outputs.logits

        loss = distill_loss(student_logits, teacher_logits, labels, config.temperature, config.alpha)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    return total_loss / len(data_loader)

def evaluate(model, data_loader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits

            _, predicted = torch.max(logits, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return correct/total

7. Run training and evaluation:

python
for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch + 1}/{config.num_epochs}")

    # Train
    train_loss = train(student_model, train_loader, optimizer)
    print(f"Train Loss: {train_loss:.4f}")

    # Evaluate
    accuracy = evaluate(student_model, test_loader)
    print(f"Test Accuracy: {accuracy:.2f}")

8. Use the Optuna framework to search for the best distillation hyperparameters:

python
def objective(trial):
    params = {
        'temperature': trial.suggest_float('temperature', 1.0, 15.0),
        'alpha': trial.suggest_float('alpha', 0.1, 0.9),
        'learning_rate': trial.suggest_float('learning_rate', 1e-6, 5e-5, log=True),
        'num_epochs': 5,
    }
    student_model = DistilBertForSequenceClassification.from_pretrained(
        config.student_model_name,
        num_labels=config.number_labels)
    student_model.to(config.device)
    optimizer = torch.optim.AdamW(student_model.parameters(), lr=params['learning_rate'])
    best_accuracy = 0.0
    for epoch in range(params['num_epochs']):
        student_model.train()
        for batch in train_loader:
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)

            optimizer.zero_grad()

            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            student_outputs = student_model(input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            loss = distill_loss(student_logits, teacher_logits, labels, params['temperature'], params['alpha'])
            loss.backward()
            optimizer.step()
        accuracy = evaluate(student_model, test_loader)
        trial.report(accuracy, epoch)
        if trial.should_prune():
            raise optuna.TrialPruned()
        if accuracy > best_accuracy:
            best_accuracy = accuracy
    return best_accuracy
    
 
# Create the Optuna study
study = optuna.create_study(
    direction='maximize',                  # we want to maximize accuracy
    sampler=optuna.samplers.TPESampler(),  # TPE sampler
    pruner=optuna.pruners.MedianPruner()   # median pruner: stops unpromising trials early
)

# Run the optimization
study.optimize(objective, n_trials=20, timeout=600)  # at most 20 trials or 10 minutes

# Print the best result
print("Number of finished trials: ", len(study.trials))
print("Best trial:")
trial = study.best_trial

print(f"  Value (Accuracy): {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")

def best_params_train():
    best_params = study.best_params
    final_model = DistilBertForSequenceClassification.from_pretrained(
        config.student_model_name,
        num_labels=config.number_labels
    ).to(config.device)
    optimizer = torch.optim.AdamW(final_model.parameters(), lr=best_params['learning_rate'])
    for epoch in range(5):
        final_model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"Final Training:{epoch + 1}"):
            input_ids = batch['input_ids'].to(config.device)
            attention_mask = batch['attention_mask'].to(config.device)
            labels = batch['labels'].to(config.device)


            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits

            student_outputs = final_model(input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits

            loss = distill_loss(student_logits, teacher_logits, labels,
                                best_params['temperature'], best_params['alpha'])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        # Evaluate after each epoch
        accuracy = evaluate(final_model, test_loader)
        print(f"Epoch {epoch + 1} - Loss: {total_loss / len(train_loader):.4f}, Accuracy: {accuracy:.4f}")

    # Save the final model once training is complete.
    final_model.save_pretrained('optimized_distilled_distilbert')
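
Calling best_params_train() retrains the student with the best hyperparameters found by Optuna and writes the result to optimized_distilled_distilbert. The saved checkpoint can then be reloaded for inference like any other Hugging Face model; here is a minimal sketch (the example sentence is made up):

python
best_params_train()

# Reload the distilled student and run a quick prediction.
model = DistilBertForSequenceClassification.from_pretrained(
    'optimized_distilled_distilbert').to(config.device)
model.eval()

inputs = tokenizer("One of the best films of the year!",
                   return_tensors='pt', truncation=True,
                   padding='max_length', max_length=config.max_length)
with torch.no_grad():
    logits = model(input_ids=inputs['input_ids'].to(config.device),
                   attention_mask=inputs['attention_mask'].to(config.device)).logits
print("positive" if logits.argmax(dim=1).item() == 1 else "negative")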