从 Transformer 理论到文本分类:BERT 微调实战总结

1. 前言

Transformer 模型自《Attention Is All You Need》提出以来,已成为 NLP 的基石架构。

BERT 作为双向 Transformer 编码器的代表,通过双向编码Masked LM + NSP 预训练目标,更成为行业应用任务的标准起点:

  • 金融文本风控(如欺诈检测、信用违约预测)
  • 保险核保与理赔审核
  • 医疗文档的自动摘要与诊断建议生成

本文系统性梳理了 BERT 模型的结构与预训练过程,从理论到实战,完整复现了 BERT 模型的微调流程,并结合优化实验,探索了模型在小样本任务中的表现与调参策略。

2. 模型与任务背景

2.1 任务说明:

判断一句话是否为 "值得核查的声明"(Claim vs Non-Claim)。

例如:

  • "The global temperature has increased by 1°C in the past century." → ✅ Claim
  • "I like sunny days." → ❌ Non-Claim

2.2 模型选择:

本实验选用 bert-base-uncased 作为基座模型。其主要结构如下:

  • Encoder 层数:12
  • Attention Heads:12
  • Hidden Size:768
  • 参数规模:1.1 亿

在分类任务中,仅需在 [CLS] token 向量后接入一个全连接层进行二分类训练,使用 Hugging Face Transformers 框架进行微调。

2.3 实验设计

实验环境
  • 框架:PyTorch + Transformers
  • 显卡:NVIDIA T4/A100
  • Batch size:16
  • Epochs:2
  • 学习率:2e-5
  • 数据集:20,000+

3. 训练优化与参数分析

在微调过程中,对模型收敛行为进行了多组实验与参数调优。

通过实践发现,不同超参数对训练稳定性与最终性能影响显著,主要结论如下:

3.1 学习率(Learning Rate)

  • 过高(例如 3e-5)会导致训练损失曲线剧烈抖动,模型难以稳定收敛。
  • 过低则可能陷入局部最优,模型参数无法充分更新。
  • 综合考虑收敛速度与稳定性,2e-5 在当前数据集上表现最佳。

3.2 优化器(Optimizer)

采用 AdamW + Cosine Scheduler 能有效平衡训练前后期的学习率动态。

  • 训练初期:Cosine 策略可加快学习并促使模型快速逼近低损失区域。
  • 后期:学习率自然下降,帮助模型平稳收敛。

3.3 Warmup 策略

引入 线性 warmup 机制,使学习率在初始阶段从 0 平滑提升至设定值。

这一过程能显著降低早期梯度震荡,使曲线更加平稳。

3.4 梯度裁剪(Max Grad Norm)

为防止权重更新过大导致模型不稳定,引入梯度裁剪机制。设置的梯度的最大范数(L2 norm ∣ ∣ x ∣ ∣ 2 = s q r t ( s u m ( ∣ x i ∣ 2 ) ) ||x||₂ = sqrt(sum(|x_i|²)) ∣∣x∣∣2=sqrt(sum(∣xi∣2))),当梯度超过此阈值时,会按以下公式缩放:

梯度更新规则如下:
g r a d = g r a d × min ⁡ ( m a x _ n o r m t o t a l _ n o r m + 1 e − 6 , 1 ) grad = grad \times \min(\frac{max\_norm}{total\_norm + 1e-6}, 1) grad=grad×min(total_norm+1e−6max_norm,1)

在本实验中,将 max_grad_norm 设置为 1.0 与 10 均未观察到明显过拟合,推测在参数量较大或数据复杂度更高的情况下,梯度爆炸风险才更显著。

3.5 权重衰减(Weight Decay)

用于防止过拟合。

AdamW 优化器在更新参数时,对非偏置项与 LayerNorm 权重施加衰减:
p a r a m = p a r a m × ( 1 − l r × w e i g h t d e c a y ) param = param \times (1 - lr \times weight decay) param=param×(1−lr×weightdecay)

在本实验中,weight_decay=1e-2 效果良好,若不设置或设置过高,均会导致模型在验证集上出现过拟合现象。

附:参数配置

python 复制代码
training_args = TrainingArguments(
        output_dir=model_args.output_dir,
        num_train_epochs=2,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        gradient_accumulation_steps=1,
        eval_strategy="steps",
        eval_steps=50,
        logging_steps=50,
        save_strategy="steps",
        save_steps=100,
        max_grad_norm=5.0,
        warmup_steps=400,
        learning_rate=2e-5,
        weight_decay=1e-2,
        lr_scheduler_type="cosine",
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        greater_is_better=True,
        push_to_hub=True,
        hub_model_id=model_args.hub_model_id,
        fp16=True,
        report_to="wandb",
    )

4. 实验结果与分析

Metric Value
Accuracy 0.91
F1-score 0.91
Eval Loss 0.22

使用 Weights & Biases (wandb) 记录了训练与验证曲线:

  • Loss 曲线:收敛平稳,2 epoch 内达到最优。
  • Accuracy / F1 曲线:F1 与 Accuracy 同步上升,模型无明显过拟合。
  • confusion matrix:对「claim」类判定更准确,少量误分类为非声明。

模型学习到的特征主要包括:

  • 事实陈述的语义模式(如"X increased by Y");
  • 动词与数值信息;
  • 主体+谓语+量化描述结构。

5. 对比实验

除使用 Trainer API 的标准微调外,还自实现了完整训练循环(full_training_loop.ipynb),以便深入理解优化器行为、梯度裁剪与调度策略。

  • 自定义 DataLoader
  • 梯度累积与优化步骤
  • 混合使用 adamW 与 SGD 优化器

结果对比:

模式 F1-score 评价
Hugging Face Trainer 0.91 收敛快,稳定性好
手动训练循环 0.91 收敛块,稳定性好

附:部分代码实现

python 复制代码
for epoch in range(num_epochs):
    model.train()

    train_metric = Accumulator(2)
    for i, train_batch in enumerate(train_dataloader):
        step_count += 1
        outputs = model(**train_batch)
        train_loss = outputs.loss
        train_metric.add(train_loss.item(), 1)

        accelerator.backward(train_loss)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
        if step_count > num_adam_steps:
            sgd.step()
            sgd_scheduler.step()
            sgd.zero_grad()
        else:
            adamW.step()
            adamW_scheduler.step()
            adamW.zero_grad()

        progress_bar.update(1)

        # evaluation
        if (i + 1) % 50 == 0 or i == len(train_dataloader) - 1:
            train_loss_avg = train_metric[0] / train_metric[1]

            accuracy_metric = evaluate.load("accuracy")
            f1_metric = evaluate.load("f1")

            model.eval()
            eval_metric = Accumulator(2)
            for j, eval_batch in enumerate(eval_dataloader):
                with torch.no_grad():
                    outputs = model(**eval_batch)

                eval_loss = outputs.loss
                logits = outputs.logits
                predictions = torch.argmax(logits, dim=-1)

                eval_metric.add(eval_loss.item(), 1)
                accuracy_metric.add_batch(predictions=predictions.cpu().numpy(), references=eval_batch["labels"].cpu().numpy())
                f1_metric.add_batch(predictions=predictions.cpu().numpy(), references=eval_batch["labels"].cpu().numpy())

            eval_loss_avg = eval_metric[0] / eval_metric[1]
            accuracy = accuracy_metric.compute()
            f1 = f1_metric.compute()

            print(f"\nsteps: {step_count}, eval_loss_avg:{eval_loss_avg}, train loss:{train_loss_avg}, acc: {accuracy['accuracy']}, f1: {f1['f1']}")

            table_data.append([
                step_count,
                f"{train_loss_avg}",
                f"{eval_loss_avg}",
                f"{accuracy['accuracy']}",
                f"{f1['f1']}"
            ])

            wandb.log({"train/loss": train_loss_avg,
                       "eval/loss": eval_loss_avg,
                       "eval/accuracy": accuracy,
                       "eval/f1": f1,
                       })
            train_metric.reset()

print(tabulate(table_data, headers="firstrow", tablefmt="grid"))
run.finish()

6. 推理服务部署

部署方式:

  • 封装成 FastAPI 微服务;
  • 提供 /predict 接口用于文本分类;
  • 通过 Docker 容器化部署。
bash 复制代码
docker build -t claim-detection-service:latest .
docker run -p 8000:8000 claim-detection-service:latest

可通过 http://localhost:8000/docs 访问 Swagger UI 进行预测。

7. 从 BERT 到 RAG:下一步探索

本次微调实验主要聚焦在语义分类任务

下一阶段计划探索如何将该模型融入 RAG(Retrieval-Augmented Generation) 流程中。
后续方向:

  1. 领域微调(Domain Adaptation);
  2. 使用 LangChain 构建 RAG 管线(文档检索 + claim 识别 + LLM 回答);
  3. 探索 LangGraph 实现可解释的 claim 追溯链(Explainable AI)。

附:资源与参考

GitHub 源码
Hugging Face Model Card
Attention Is All You Need (Vaswani et al., 2017)
Hugging Face Doc
Dive Into Deep Learning

相关推荐
weixin_456904275 小时前
Transformer架构发展历史
深度学习·架构·transformer
X.AI66612 小时前
YouTube评论情感分析项目84%正确率:基于BERT的实战复现与原理解析
人工智能·深度学习·bert
njsgcs14 小时前
读取文件夹内的pdf装换成npg给vlm分类人工确认然后填入excel vlmapi速度挺快 qwen3-vl-plus webbrowser.open
分类·pdf·excel
油泼辣子多加1 天前
【实战】自然语言处理--长文本分类(1)DPCNN算法
算法·自然语言处理·分类
盼小辉丶1 天前
视觉Transformer实战 | Transformer详解与实现
pytorch·深度学习·transformer·1024程序员节
jerryinwuhan2 天前
TableTime:将时序分类重构为表格理解任务,更有效对齐LLM语义空间
重构·分类·数据挖掘
2401_841495642 天前
【机器学习】k近邻法
人工智能·python·机器学习·分类··knn·k近邻算法
Light602 天前
深度学习 × 计算机视觉 × Kaggle(上):从理论殿堂起步 ——像素、特征与模型的进化之路
人工智能·深度学习·计算机视觉·卷积神经网络·transformer·特征学习
机器学习之心2 天前
未发表,三大创新!OCSSA-VMD-Transformer-Adaboost特征提取+编码器+集成学习轴承故障诊断
深度学习·transformer·集成学习·ocssa-vmd