从 Transformer 理论到文本分类:BERT 微调实战总结

1. 前言

Transformer 模型自《Attention Is All You Need》提出以来,已成为 NLP 的基石架构。

BERT 作为双向 Transformer 编码器的代表,通过双向编码Masked LM + NSP 预训练目标,更成为行业应用任务的标准起点:

  • 金融文本风控(如欺诈检测、信用违约预测)
  • 保险核保与理赔审核
  • 医疗文档的自动摘要与诊断建议生成

本文系统性梳理了 BERT 模型的结构与预训练过程,从理论到实战,完整复现了 BERT 模型的微调流程,并结合优化实验,探索了模型在小样本任务中的表现与调参策略。

2. 模型与任务背景

2.1 任务说明:

判断一句话是否为 "值得核查的声明"(Claim vs Non-Claim)。

例如:

  • "The global temperature has increased by 1°C in the past century." → ✅ Claim
  • "I like sunny days." → ❌ Non-Claim

2.2 模型选择:

本实验选用 bert-base-uncased 作为基座模型。其主要结构如下:

  • Encoder 层数:12
  • Attention Heads:12
  • Hidden Size:768
  • 参数规模:1.1 亿

在分类任务中,仅需在 [CLS] token 向量后接入一个全连接层进行二分类训练,使用 Hugging Face Transformers 框架进行微调。

2.3 实验设计

实验环境
  • 框架:PyTorch + Transformers
  • 显卡:NVIDIA T4/A100
  • Batch size:16
  • Epochs:2
  • 学习率:2e-5
  • 数据集:20,000+

3. 训练优化与参数分析

在微调过程中,对模型收敛行为进行了多组实验与参数调优。

通过实践发现,不同超参数对训练稳定性与最终性能影响显著,主要结论如下:

3.1 学习率(Learning Rate)

  • 过高(例如 3e-5)会导致训练损失曲线剧烈抖动,模型难以稳定收敛。
  • 过低则可能陷入局部最优,模型参数无法充分更新。
  • 综合考虑收敛速度与稳定性,2e-5 在当前数据集上表现最佳。

3.2 优化器(Optimizer)

采用 AdamW + Cosine Scheduler 能有效平衡训练前后期的学习率动态。

  • 训练初期:Cosine 策略可加快学习并促使模型快速逼近低损失区域。
  • 后期:学习率自然下降,帮助模型平稳收敛。

3.3 Warmup 策略

引入 线性 warmup 机制,使学习率在初始阶段从 0 平滑提升至设定值。

这一过程能显著降低早期梯度震荡,使曲线更加平稳。

3.4 梯度裁剪(Max Grad Norm)

为防止权重更新过大导致模型不稳定,引入梯度裁剪机制。设置的梯度的最大范数(L2 norm ∣ ∣ x ∣ ∣ 2 = s q r t ( s u m ( ∣ x i ∣ 2 ) ) ||x||₂ = sqrt(sum(|x_i|²)) ∣∣x∣∣2=sqrt(sum(∣xi∣2))),当梯度超过此阈值时,会按以下公式缩放:

梯度更新规则如下:
g r a d = g r a d × min ⁡ ( m a x _ n o r m t o t a l _ n o r m + 1 e − 6 , 1 ) grad = grad \times \min(\frac{max\_norm}{total\_norm + 1e-6}, 1) grad=grad×min(total_norm+1e−6max_norm,1)

在本实验中,将 max_grad_norm 设置为 1.0 与 10 均未观察到明显过拟合,推测在参数量较大或数据复杂度更高的情况下,梯度爆炸风险才更显著。

3.5 权重衰减(Weight Decay)

用于防止过拟合。

AdamW 优化器在更新参数时,对非偏置项与 LayerNorm 权重施加衰减:
p a r a m = p a r a m × ( 1 − l r × w e i g h t d e c a y ) param = param \times (1 - lr \times weight decay) param=param×(1−lr×weightdecay)

在本实验中,weight_decay=1e-2 效果良好,若不设置或设置过高,均会导致模型在验证集上出现过拟合现象。

附:参数配置

python 复制代码
training_args = TrainingArguments(
        output_dir=model_args.output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=1,
        eval_strategy="steps",
        eval_steps=50,
        logging_steps=50,
        save_strategy="steps",
        save_steps=100,
        max_grad_norm=5.0,
        warmup_steps=400,
        learning_rate=2e-5,
        weight_decay=1e-2,
        lr_scheduler_type="cosine",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=True,
        hub_model_id=model_args.hub_model_id,
        fp16=True,
        report_to="wandb",
    )

4. 实验结果与分析

Metric Value
Accuracy 0.91
F1-score 0.91
Eval Loss 0.22

使用 Weights & Biases (wandb) 记录了训练与验证曲线:

  • Loss 曲线:收敛平稳,2 epoch 内达到最优。
  • Accuracy / F1 曲线:F1 与 Accuracy 同步上升,模型无明显过拟合。
  • confusion matrix:对「claim」类判定更准确,少量误分类为非声明。

模型学习到的特征主要包括:

  • 事实陈述的语义模式(如"X increased by Y");
  • 动词与数值信息;
  • 主体+谓语+量化描述结构。

5. 对比实验

除使用 Trainer API 的标准微调外,还自实现了完整训练循环(full_training_loop.ipynb),以便深入理解优化器行为、梯度裁剪与调度策略。

  • 自定义 DataLoader
  • 梯度累积与优化步骤
  • 混合使用 adamW 与 SGD 优化器

结果对比:

模式 F1-score 评价
Hugging Face Trainer 0.91 收敛快,稳定性好
手动训练循环 0.91 收敛块,稳定性好

附:部分代码实现

python 复制代码
for epoch in range(num_epochs):
    model.train()

    train_metric = Accumulator(2)
    for i, train_batch in enumerate(train_dataloader):
        step_count += 1
        outputs = model(**train_batch)
        train_loss = outputs.loss
        train_metric.add(train_loss.item(), 1)

        accelerator.backward(train_loss)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        if step_count > num_adam_steps:
            sgd.step()
            sgd_scheduler.step()
            sgd.zero_grad()
        else:
            adamW.step()
            adamW_scheduler.step()
            adamW.zero_grad()

        progress_bar.update(1)

        # evaluation
        if (i + 1) % 50 == 0 or i == len(train_dataloader) - 1:
            train_loss_avg = train_metric[0] / train_metric[1]

            accuracy_metric = evaluate.load("accuracy")
            f1_metric = evaluate.load("f1")

            model.eval()
            eval_metric = Accumulator(2)
            for j, eval_batch in enumerate(eval_dataloader):
                with torch.no_grad():
                    outputs = model(**eval_batch)

                eval_loss = outputs.loss
                logits = outputs.logits
                predictions = torch.argmax(logits, dim=-1)

                eval_metric.add(eval_loss.item(), 1)
                accuracy_metric.add_batch(predictions=predictions.cpu().numpy(), references=eval_batch["labels"].cpu().numpy())
                f1_metric.add_batch(predictions=predictions.cpu().numpy(), references=eval_batch["labels"].cpu().numpy())

            eval_loss_avg = eval_metric[0] / eval_metric[1]
            accuracy = accuracy_metric.compute()
            f1 = f1_metric.compute()

            print(f"\nsteps: {step_count}, eval_loss_avg:{eval_loss_avg}, train loss:{train_loss_avg}, acc: {accuracy['accuracy']}, f1: {f1['f1']}")

            table_data.append([
                step_count,
                f"{train_loss_avg}",
                f"{eval_loss_avg}",
                f"{accuracy['accuracy']}",
                f"{f1['f1']}"
            ])

            wandb.log({"train/loss": train_loss_avg,
                       "eval/loss": eval_loss_avg,
                       "eval/accuracy": accuracy,
                       "eval/f1": f1,
                       })
            train_metric.reset()

print(tabulate(table_data, headers="firstrow", tablefmt="grid"))
run.finish()

6. 推理服务部署

部署方式:

  • 封装成 FastAPI 微服务;
  • 提供 /predict 接口用于文本分类;
  • 通过 Docker 容器化部署。
bash 复制代码
docker build -t claim-detection-service:latest .
docker run -p 8000:8000 claim-detection-service:latest

可通过 http://localhost:8000/docs 访问 Swagger UI 进行预测。

7. 从 BERT 到 RAG:下一步探索

本次微调实验主要聚焦在语义分类任务

下一阶段计划探索如何将该模型融入 RAG(Retrieval-Augmented Generation) 流程中。
后续方向:

  1. 领域微调(Domain Adaptation);
  2. 使用 LangChain 构建 RAG 管线(文档检索 + claim 识别 + LLM 回答);
  3. 探索 LangGraph 实现可解释的 claim 追溯链(Explainable AI)。

附:资源与参考

GitHub 源码
Hugging Face Model Card
Attention Is All You Need (Vaswani et al., 2017)
Hugging Face Doc
Dive Into Deep Learning

相关推荐
小徐xxx4 小时前
Softmax回归(分类问题)学习记录
深度学习·分类·回归·softmax·学习记录
AAD555888994 小时前
YOLOv8-MAN-Faster电容器缺陷检测:七类组件识别与分类系统
yolo·分类·数据挖掘
JicasdC123asd4 小时前
【工业检测】基于YOLO13-C3k2-EIEM的铸造缺陷检测与分类系统_1
人工智能·算法·分类
kebijuelun5 小时前
ERNIE 5.0:统一自回归多模态与弹性训练
人工智能·算法·语言模型·transformer
子夜江寒5 小时前
基于 LSTM 的中文情感分类项目解析
人工智能·分类·lstm
Niuguangshuo7 小时前
DALL-E 2:从CLIP潜变量到高质量图像生成的突破
人工智能·深度学习·transformer
是小蟹呀^8 小时前
Focal Loss:解决长尾图像分类中“多数类太强势”的损失函数
人工智能·机器学习·分类
@鱼香肉丝没有鱼8 小时前
Transformer底层原理—Encoder结构
人工智能·深度学习·transformer
2501_941329729 小时前
基于Centernet的甜菜幼苗生长状态识别与分类系统
人工智能·分类·数据挖掘