从 Transformer 理论到文本分类:BERT 微调实战总结

1. 前言

Transformer 模型自《Attention Is All You Need》提出以来,已成为 NLP 的基石架构。

BERT 作为双向 Transformer 编码器的代表,通过双向编码Masked LM + NSP 预训练目标,更成为行业应用任务的标准起点:

  • 金融文本风控(如欺诈检测、信用违约预测)
  • 保险核保与理赔审核
  • 医疗文档的自动摘要与诊断建议生成

本文系统性梳理了 BERT 模型的结构与预训练过程,从理论到实战,完整复现了 BERT 模型的微调流程,并结合优化实验,探索了模型在小样本任务中的表现与调参策略。

2. 模型与任务背景

2.1 任务说明:

判断一句话是否为 "值得核查的声明"(Claim vs Non-Claim)。

例如:

  • "The global temperature has increased by 1°C in the past century." → ✅ Claim
  • "I like sunny days." → ❌ Non-Claim

2.2 模型选择:

本实验选用 bert-base-uncased 作为基座模型。其主要结构如下:

  • Encoder 层数:12
  • Attention Heads:12
  • Hidden Size:768
  • 参数规模:1.1 亿

在分类任务中,仅需在 [CLS] token 向量后接入一个全连接层进行二分类训练,使用 Hugging Face Transformers 框架进行微调。

2.3 实验设计

实验环境
  • 框架:PyTorch + Transformers
  • 显卡:NVIDIA T4/A100
  • Batch size:16
  • Epochs:2
  • 学习率:2e-5
  • 数据集:20,000+

3. 训练优化与参数分析

在微调过程中,对模型收敛行为进行了多组实验与参数调优。

通过实践发现,不同超参数对训练稳定性与最终性能影响显著,主要结论如下:

3.1 学习率(Learning Rate)

  • 过高(例如 3e-5)会导致训练损失曲线剧烈抖动,模型难以稳定收敛。
  • 过低则可能陷入局部最优,模型参数无法充分更新。
  • 综合考虑收敛速度与稳定性,2e-5 在当前数据集上表现最佳。

3.2 优化器(Optimizer)

采用 AdamW + Cosine Scheduler 能有效平衡训练前后期的学习率动态。

  • 训练初期:Cosine 策略可加快学习并促使模型快速逼近低损失区域。
  • 后期:学习率自然下降,帮助模型平稳收敛。

3.3 Warmup 策略

引入 线性 warmup 机制,使学习率在初始阶段从 0 平滑提升至设定值。

这一过程能显著降低早期梯度震荡,使曲线更加平稳。

3.4 梯度裁剪(Max Grad Norm)

为防止权重更新过大导致模型不稳定,引入梯度裁剪机制。设置的梯度的最大范数(L2 norm ∣ ∣ x ∣ ∣ 2 = s q r t ( s u m ( ∣ x i ∣ 2 ) ) ||x||₂ = sqrt(sum(|x_i|²)) ∣∣x∣∣2=sqrt(sum(∣xi∣2))),当梯度超过此阈值时,会按以下公式缩放:

梯度更新规则如下:
g r a d = g r a d × min ⁡ ( m a x _ n o r m t o t a l _ n o r m + 1 e − 6 , 1 ) grad = grad \times \min(\frac{max\_norm}{total\_norm + 1e-6}, 1) grad=grad×min(total_norm+1e−6max_norm,1)

在本实验中,将 max_grad_norm 设置为 1.0 与 10 均未观察到明显过拟合,推测在参数量较大或数据复杂度更高的情况下,梯度爆炸风险才更显著。

3.5 权重衰减(Weight Decay)

用于防止过拟合。

AdamW 优化器在更新参数时,对非偏置项与 LayerNorm 权重施加衰减:
p a r a m = p a r a m × ( 1 − l r × w e i g h t d e c a y ) param = param \times (1 - lr \times weight decay) param=param×(1−lr×weightdecay)

在本实验中,weight_decay=1e-2 效果良好,若不设置或设置过高,均会导致模型在验证集上出现过拟合现象。

附:参数配置

python 复制代码
training_args = TrainingArguments(
        output_dir=model_args.output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=1,
        eval_strategy="steps",
        eval_steps=50,
        logging_steps=50,
        save_strategy="steps",
        save_steps=100,
        max_grad_norm=5.0,
        warmup_steps=400,
        learning_rate=2e-5,
        weight_decay=1e-2,
        lr_scheduler_type="cosine",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=True,
        hub_model_id=model_args.hub_model_id,
        fp16=True,
        report_to="wandb",
    )

4. 实验结果与分析

Metric Value
Accuracy 0.91
F1-score 0.91
Eval Loss 0.22

使用 Weights & Biases (wandb) 记录了训练与验证曲线:

  • Loss 曲线:收敛平稳,2 epoch 内达到最优。
  • Accuracy / F1 曲线:F1 与 Accuracy 同步上升,模型无明显过拟合。
  • confusion matrix:对「claim」类判定更准确,少量误分类为非声明。

模型学习到的特征主要包括:

  • 事实陈述的语义模式(如"X increased by Y");
  • 动词与数值信息;
  • 主体+谓语+量化描述结构。

5. 对比实验

除使用 Trainer API 的标准微调外,还自实现了完整训练循环(full_training_loop.ipynb),以便深入理解优化器行为、梯度裁剪与调度策略。

  • 自定义 DataLoader
  • 梯度累积与优化步骤
  • 混合使用 adamW 与 SGD 优化器

结果对比:

模式 F1-score 评价
Hugging Face Trainer 0.91 收敛快,稳定性好
手动训练循环 0.91 收敛块,稳定性好

附:部分代码实现

python 复制代码
for epoch in range(num_epochs):
    model.train()

    train_metric = Accumulator(2)
    for i, train_batch in enumerate(train_dataloader):
        step_count += 1
        outputs = model(**train_batch)
        train_loss = outputs.loss
        train_metric.add(train_loss.item(), 1)

        accelerator.backward(train_loss)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        if step_count > num_adam_steps:
            sgd.step()
            sgd_scheduler.step()
            sgd.zero_grad()
        else:
            adamW.step()
            adamW_scheduler.step()
            adamW.zero_grad()

        progress_bar.update(1)

        # evaluation
        if (i + 1) % 50 == 0 or i == len(train_dataloader) - 1:
            train_loss_avg = train_metric[0] / train_metric[1]

            accuracy_metric = evaluate.load("accuracy")
            f1_metric = evaluate.load("f1")

            model.eval()
            eval_metric = Accumulator(2)
            for j, eval_batch in enumerate(eval_dataloader):
                with torch.no_grad():
                    outputs = model(**eval_batch)

                eval_loss = outputs.loss
                logits = outputs.logits
                predictions = torch.argmax(logits, dim=-1)

                eval_metric.add(eval_loss.item(), 1)
                accuracy_metric.add_batch(predictions=predictions.cpu().numpy(), references=eval_batch["labels"].cpu().numpy())
                f1_metric.add_batch(predictions=predictions.cpu().numpy(), references=eval_batch["labels"].cpu().numpy())

            eval_loss_avg = eval_metric[0] / eval_metric[1]
            accuracy = accuracy_metric.compute()
            f1 = f1_metric.compute()

            print(f"\nsteps: {step_count}, eval_loss_avg:{eval_loss_avg}, train loss:{train_loss_avg}, acc: {accuracy['accuracy']}, f1: {f1['f1']}")

            table_data.append([
                step_count,
                f"{train_loss_avg}",
                f"{eval_loss_avg}",
                f"{accuracy['accuracy']}",
                f"{f1['f1']}"
            ])

            wandb.log({"train/loss": train_loss_avg,
                       "eval/loss": eval_loss_avg,
                       "eval/accuracy": accuracy,
                       "eval/f1": f1,
                       })
            train_metric.reset()

print(tabulate(table_data, headers="firstrow", tablefmt="grid"))
run.finish()

6. 推理服务部署

部署方式:

  • 封装成 FastAPI 微服务;
  • 提供 /predict 接口用于文本分类;
  • 通过 Docker 容器化部署。
bash 复制代码
docker build -t claim-detection-service:latest .
docker run -p 8000:8000 claim-detection-service:latest

可通过 http://localhost:8000/docs 访问 Swagger UI 进行预测。

7. 从 BERT 到 RAG:下一步探索

本次微调实验主要聚焦在语义分类任务

下一阶段计划探索如何将该模型融入 RAG(Retrieval-Augmented Generation) 流程中。
后续方向:

  1. 领域微调(Domain Adaptation);
  2. 使用 LangChain 构建 RAG 管线(文档检索 + claim 识别 + LLM 回答);
  3. 探索 LangGraph 实现可解释的 claim 追溯链(Explainable AI)。

附:资源与参考

GitHub 源码
Hugging Face Model Card
Attention Is All You Need (Vaswani et al., 2017)
Hugging Face Doc
Dive Into Deep Learning

相关推荐
zhangfeng113314 小时前
spss 性别类似的二分类变量 多分类变量 做线性回归分析
分类·数据挖掘·线性回归
bst@微胖子17 小时前
HuggingFace项目实战之分类任务实战
pytorch·深度学习·分类
机器学习之心17 小时前
MATLAB基于GWO优化Transformer多输入多输出回归预测与改进NSGA III的多目标优化
transformer·gwo-transformer·多输入多输出回归预测·改进nsgaiii的多目标优化
Francek Chen18 小时前
【自然语言处理】应用06:针对序列级和词元级应用微调BERT
人工智能·pytorch·深度学习·自然语言处理·bert
ekkoalex18 小时前
强化学习中参数的设置
人工智能·深度学习·transformer
摸鱼仙人~19 小时前
BERT分类的上下文限制及解决方案
人工智能·分类·bert
悟道心19 小时前
5. 自然语言处理NLP - Transformer
人工智能·自然语言处理·transformer
摸鱼仙人~19 小时前
使用 BERT 系列模型实现 RAG Chunk 分类打标
人工智能·分类·bert
楚来客1 天前
AI基础概念之八:Transformer算法通俗解析
人工智能·算法·transformer
雍凉明月夜1 天前
深度学习网络笔记Ⅳ(Transformer + VIT)
笔记·深度学习·transformer