从 Transformer 理论到文本分类:BERT 微调实战总结

1. 前言

Transformer 模型自《Attention Is All You Need》提出以来,已成为 NLP 的基石架构。

BERT 作为双向 Transformer 编码器的代表,通过双向编码Masked LM + NSP 预训练目标,更成为行业应用任务的标准起点:

  • 金融文本风控(如欺诈检测、信用违约预测)
  • 保险核保与理赔审核
  • 医疗文档的自动摘要与诊断建议生成

本文系统性梳理了 BERT 模型的结构与预训练过程,从理论到实战,完整复现了 BERT 模型的微调流程,并结合优化实验,探索了模型在小样本任务中的表现与调参策略。

2. 模型与任务背景

2.1 任务说明:

判断一句话是否为 "值得核查的声明"(Claim vs Non-Claim)。

例如:

  • "The global temperature has increased by 1°C in the past century." → ✅ Claim
  • "I like sunny days." → ❌ Non-Claim

2.2 模型选择:

本实验选用 bert-base-uncased 作为基座模型。其主要结构如下:

  • Encoder 层数:12
  • Attention Heads:12
  • Hidden Size:768
  • 参数规模:1.1 亿

在分类任务中,仅需在 [CLS] token 向量后接入一个全连接层进行二分类训练,使用 Hugging Face Transformers 框架进行微调。

2.3 实验设计

实验环境
  • 框架:PyTorch + Transformers
  • 显卡:NVIDIA T4/A100
  • Batch size:16
  • Epochs:2
  • 学习率:2e-5
  • 数据集:20,000+

3. 训练优化与参数分析

在微调过程中,对模型收敛行为进行了多组实验与参数调优。

通过实践发现,不同超参数对训练稳定性与最终性能影响显著,主要结论如下:

3.1 学习率(Learning Rate)

  • 过高(例如 3e-5)会导致训练损失曲线剧烈抖动,模型难以稳定收敛。
  • 过低则可能陷入局部最优,模型参数无法充分更新。
  • 综合考虑收敛速度与稳定性,2e-5 在当前数据集上表现最佳。

3.2 优化器(Optimizer)

采用 AdamW + Cosine Scheduler 能有效平衡训练前后期的学习率动态。

  • 训练初期:Cosine 策略可加快学习并促使模型快速逼近低损失区域。
  • 后期:学习率自然下降,帮助模型平稳收敛。

3.3 Warmup 策略

引入 线性 warmup 机制,使学习率在初始阶段从 0 平滑提升至设定值。

这一过程能显著降低早期梯度震荡,使曲线更加平稳。

3.4 梯度裁剪(Max Grad Norm)

为防止权重更新过大导致模型不稳定,引入梯度裁剪机制。设置的梯度的最大范数(L2 norm ∣ ∣ x ∣ ∣ 2 = s q r t ( s u m ( ∣ x i ∣ 2 ) ) ||x||₂ = sqrt(sum(|x_i|²)) ∣∣x∣∣2=sqrt(sum(∣xi∣2))),当梯度超过此阈值时,会按以下公式缩放:

梯度更新规则如下:
g r a d = g r a d × min ⁡ ( m a x _ n o r m t o t a l _ n o r m + 1 e − 6 , 1 ) grad = grad \times \min(\frac{max\_norm}{total\_norm + 1e-6}, 1) grad=grad×min(total_norm+1e−6max_norm,1)

在本实验中,将 max_grad_norm 设置为 1.0 与 10 均未观察到明显过拟合,推测在参数量较大或数据复杂度更高的情况下,梯度爆炸风险才更显著。

3.5 权重衰减(Weight Decay)

用于防止过拟合。

AdamW 优化器在更新参数时,对非偏置项与 LayerNorm 权重施加衰减:
p a r a m = p a r a m × ( 1 − l r × w e i g h t d e c a y ) param = param \times (1 - lr \times weight decay) param=param×(1−lr×weightdecay)

在本实验中,weight_decay=1e-2 效果良好,若不设置或设置过高,均会导致模型在验证集上出现过拟合现象。

附:参数配置

python 复制代码
training_args = TrainingArguments(
        output_dir=model_args.output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=1,
        eval_strategy="steps",
        eval_steps=50,
        logging_steps=50,
        save_strategy="steps",
        save_steps=100,
        max_grad_norm=5.0,
        warmup_steps=400,
        learning_rate=2e-5,
        weight_decay=1e-2,
        lr_scheduler_type="cosine",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=True,
        hub_model_id=model_args.hub_model_id,
        fp16=True,
        report_to="wandb",
    )

4. 实验结果与分析

Metric Value
Accuracy 0.91
F1-score 0.91
Eval Loss 0.22

使用 Weights & Biases (wandb) 记录了训练与验证曲线:

  • Loss 曲线:收敛平稳,2 epoch 内达到最优。
  • Accuracy / F1 曲线:F1 与 Accuracy 同步上升,模型无明显过拟合。
  • confusion matrix:对「claim」类判定更准确,少量误分类为非声明。

模型学习到的特征主要包括:

  • 事实陈述的语义模式(如"X increased by Y");
  • 动词与数值信息;
  • 主体+谓语+量化描述结构。

5. 对比实验

除使用 Trainer API 的标准微调外,还自实现了完整训练循环(full_training_loop.ipynb),以便深入理解优化器行为、梯度裁剪与调度策略。

  • 自定义 DataLoader
  • 梯度累积与优化步骤
  • 混合使用 adamW 与 SGD 优化器

结果对比:

模式 F1-score 评价
Hugging Face Trainer 0.91 收敛快,稳定性好
手动训练循环 0.91 收敛块,稳定性好

附:部分代码实现

python 复制代码
for epoch in range(num_epochs):
    model.train()

    train_metric = Accumulator(2)
    for i, train_batch in enumerate(train_dataloader):
        step_count += 1
        outputs = model(**train_batch)
        train_loss = outputs.loss
        train_metric.add(train_loss.item(), 1)

        accelerator.backward(train_loss)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        if step_count > num_adam_steps:
            sgd.step()
            sgd_scheduler.step()
            sgd.zero_grad()
        else:
            adamW.step()
            adamW_scheduler.step()
            adamW.zero_grad()

        progress_bar.update(1)

        # evaluation
        if (i + 1) % 50 == 0 or i == len(train_dataloader) - 1:
            train_loss_avg = train_metric[0] / train_metric[1]

            accuracy_metric = evaluate.load("accuracy")
            f1_metric = evaluate.load("f1")

            model.eval()
            eval_metric = Accumulator(2)
            for j, eval_batch in enumerate(eval_dataloader):
                with torch.no_grad():
                    outputs = model(**eval_batch)

                eval_loss = outputs.loss
                logits = outputs.logits
                predictions = torch.argmax(logits, dim=-1)

                eval_metric.add(eval_loss.item(), 1)
                accuracy_metric.add_batch(predictions=predictions.cpu().numpy(), references=eval_batch["labels"].cpu().numpy())
                f1_metric.add_batch(predictions=predictions.cpu().numpy(), references=eval_batch["labels"].cpu().numpy())

            eval_loss_avg = eval_metric[0] / eval_metric[1]
            accuracy = accuracy_metric.compute()
            f1 = f1_metric.compute()

            print(f"\nsteps: {step_count}, eval_loss_avg:{eval_loss_avg}, train loss:{train_loss_avg}, acc: {accuracy['accuracy']}, f1: {f1['f1']}")

            table_data.append([
                step_count,
                f"{train_loss_avg}",
                f"{eval_loss_avg}",
                f"{accuracy['accuracy']}",
                f"{f1['f1']}"
            ])

            wandb.log({"train/loss": train_loss_avg,
                       "eval/loss": eval_loss_avg,
                       "eval/accuracy": accuracy,
                       "eval/f1": f1,
                       })
            train_metric.reset()

print(tabulate(table_data, headers="firstrow", tablefmt="grid"))
run.finish()

6. 推理服务部署

部署方式:

  • 封装成 FastAPI 微服务;
  • 提供 /predict 接口用于文本分类;
  • 通过 Docker 容器化部署。
bash 复制代码
docker build -t claim-detection-service:latest .
docker run -p 8000:8000 claim-detection-service:latest

可通过 http://localhost:8000/docs 访问 Swagger UI 进行预测。

7. 从 BERT 到 RAG:下一步探索

本次微调实验主要聚焦在语义分类任务

下一阶段计划探索如何将该模型融入 RAG(Retrieval-Augmented Generation) 流程中。
后续方向:

  1. 领域微调(Domain Adaptation);
  2. 使用 LangChain 构建 RAG 管线(文档检索 + claim 识别 + LLM 回答);
  3. 探索 LangGraph 实现可解释的 claim 追溯链(Explainable AI)。

附:资源与参考

GitHub 源码
Hugging Face Model Card
Attention Is All You Need (Vaswani et al., 2017)
Hugging Face Doc
Dive Into Deep Learning

相关推荐
DeepModel21 小时前
【分类算法】逻辑回归超详细讲解
分类·数据挖掘·逻辑回归
剑穗挂着新流苏3121 天前
Pytorch加载数据
python·深度学习·transformer
欧阳小猜1 天前
Transformer革命:从序列建模到通用人工智能的架构突破
人工智能·架构·transformer
小陈phd1 天前
多模态大模型学习笔记(二十一)—— 基于 Scaling Law方法 的大模型训练算力估算与 GPU 资源配置
笔记·深度学习·学习·自然语言处理·transformer
张张123y1 天前
#Transformer架构与微调技术深度解析
深度学习·架构·transformer
輕華1 天前
矿物成分数据智能分类实战(三):以平均值填充数据集的pytorch框架和MLP算法实现与性能分析
pytorch·分类·数据挖掘
前端摸鱼匠1 天前
面试题2:Transformer的Encoder、Decoder结构分别包含哪些核心组件?
人工智能·深度学习·ai·面试·职场和发展·transformer
油泼辣子多加1 天前
【DL】Transformer算法应用
人工智能·深度学习·算法·机器学习·transformer
小超同学你好2 天前
LangGraph 14. MCP:把“外部能力”标准化接入 LLM
人工智能·语言模型·transformer
_张一凡2 天前
【多模态模型学习】从零手撕一个Vision Transformer(ViT)模型实战篇
人工智能·深度学习·transformer