【实战】自然语言处理--长文本分类(2)BERTSplitLSTM算法

BERTSplitLSTM算法

1. 定义

bert_overlap_split_bilstm 是一种面向长文档级文本分类的混合模型,其核心思想是:

  1. BERT(Transformer)对文本段落进行上下文编码;
  2. 对长文本进行 Overlap Split(重叠切分),保证跨段上下文连续;
  3. 将各段 BERT 输出的向量序列输入 BiLSTM,进一步建模段间时序依赖;
  4. 最后通过 Attention + 全连接层 汇总文档表示并分类。

这种结构兼具 Transformer 的全局语义理解能力与 LSTM 的序列依赖建模优势,适合长度超出 BERT 单次最大输入的长文本。

2. 原理

  1. BERT 编码

    • 对每个文本片段(segment)分别调用预训练 BERT,输出其 pooler_output(即 [CLS] 向量),获得固定维度的段级表示。
  2. 重叠切分(Overlap Split)

    • 将原始文档的 token 序列以 segment_len 为段长、overlap 为重叠步长滑窗切分。
    • 保证前后相邻两个段有一定 token 重合,缓解截断带来的信息丢失。
  3. BiLSTM 建模

    • 将所有段的 BERT 表示按文档顺序拼接成形状 [batch, num_seg, hid_dim] 的张量。
    • 输入双向 LSTM,捕捉段与段之间的依赖关系。
  4. 段级 Attention

    • 对 LSTM 输出序列计算注意力权重,动态聚焦对分类贡献最大的段。
    • 权重加权后汇总得到文档级向量 doc_repr
  5. 全连接分类

    • doc_repr 经过 Dropout 和 Linear 层,输出各类别的 logits。

3. 模型结构图示

复制代码
原始文档 (tokens)  
      │  
  Overlap Split  
      ↓  
 ┌───────────────────────┐   ...   ┌───────────────────────┐
 │ segment 1 (≤ seg_len) │       │ segment k (≤ seg_len) │
 └───────────────────────┘       └───────────────────────┘
      │                               │
  [BERT 编码]                       [BERT 编码]
      ↓                               ↓
  pooled_output  ←--------------- ... ---------------→  pooled_output  
      │                               │
      └──────────→ 拼接为 [batch, num_seg, hid] ←──────────┘
                             │
                         BiLSTM  
                             ↓
                       LSTM 输出序列  
                             │
                          Attention  
                             ↓
                   文档级向量 doc_repr  
                             │
                          FC 分类  
                             ↓
                     输出 logits → Softmax  

4. 关键代码解读

4.1 Overlap Split

python 复制代码
def split_segments(self, input_ids, type_ids, attn_mask):
    doc_len = input_ids.size(1)
    step = self.seg_len - self.overlap
    starts = list(range(0, doc_len, step))
    segments = []
    for s in starts:
        end = min(s + self.seg_len, doc_len)
        segments.append({
            'input_ids':   input_ids[:, s:end],
            'token_type_ids': type_ids[:, s:end],
            'attention_mask': attn_mask[:, s:end]
        })
    return segments
  • seg_len:每段最大长度(如 200)
  • overlap:相邻段重叠长度(如 50)
  • step = seg_len - overlap 作为滑窗步长

4.2 前向计算

python 复制代码
# 拆分后对每段调用 BERT
reprs = []
for seg in segs:
    outputs = self.bert(**seg, return_dict=True)
    pooled = outputs.pooler_output        # [batch, hid]
    reprs.append(pooled.unsqueeze(1))     # 增加段维度

reprs = torch.cat(reprs, dim=1)           # [batch, num_seg, hid]

# BiLSTM 编码
lstm_out, _ = self.lstm(reprs)            # [batch, num_seg, hid]

# 段级 Attention
weights = self.attn(lstm_out).transpose(1, 2)   # [batch, 1, num_seg]
doc_repr = torch.bmm(weights, lstm_out).squeeze(1)  # [batch, hid]

# 分类
logits = self.fc(doc_repr)

5. 训练过程

  1. 数据准备

    • TSV 格式:label \t content
    • 调用 prepare_data.read_data 读取,BERT Tokenizer 分批编码(最大长度 doc_maxlen,此处可设置大于单段长度以保证切分后不丢)
    • 构建 TensorDatasetDataLoader
  2. 模型初始化

    python 复制代码
    model = BertLSTMWithOverlap(
        pretrained_dir=bert_model_dir,
        num_classes=num_classes,
        segment_len=segment_len,
        overlap=overlap
    )
    model.to(device)
  3. 冻结 BERT(可选)

    python 复制代码
    if feature_extract:
        for p in model.bert.parameters(): 
            p.requires_grad = False
  4. Optimizer & Scheduler

    python 复制代码
    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR, weight_decay=1e-4
    )
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda ep: 0.1 if ep > EPOCHS*0.8 else 1
    )
  5. 训练循环

    • 每个 epoch
      • 遍历训练集 batch,调用 train_one_step 计算 loss、反向传播、更新参数
      • 遍历验证集 batch,调用 eval_one_step 评估
      • 记录并保存最佳模型
  6. 结果可视化

    • 使用 Matplotlib 绘制 lossacc 曲线,保存到指定目录。

6. 模型作用与应用场景

  • 长文本分类:如新闻分级、法律文档分类、学术论文主题判别
  • 文档检索候选排序:将长文档编码为紧凑表示后用于相似度计算
  • 多标签分类:调整输出层即能扩展到多标签场景
  • 信息抽取:可替换最后的分类头为序列标注模块,用于实体识别等

7. 后续优化方向

  1. 动态切分:根据句边界或段落边界智能切分,减少人为超参数依赖。
  2. Transformer 级联:用轻量级 Transformer 层替代 BiLSTM,提高并行效率。
  3. 多尺度注意力:在段内和段间分别建模注意力,捕捉细粒度与全局信息。
  4. 蒸馏与剪枝:在推理阶段用小型 BERT 或量化模型,降低部署成本。

代码分析

配置代码

1. 路径相关

  • project_dir

    • 值:'/root/autodl-tmp/NLP/text_class/'
    • 作用:项目根目录,后续所有相对路径都基于这个基础目录拼接而成。
  • data_base_dir

    • 值:project_dir + 'data/thucnews/'
    • 作用:数据存放目录,cnews.train.txtcnews.val.txtcnews.test.txt 等文件都在这里。
  • bert_model_dir

    • 值:'/root/autodl-tmp/NLP/text_class/pretrained_models/dienstag/chinese-roberta-wwm-ext-large/'
    • 作用:预训练 BERT/RoBERTa 模型所在目录,用于 BertTokenizerFast.from_pretrainedBertModel.from_pretrained
  • save_dir

    • 值:'./save/20250718/'
    • 作用:训练过程中保存 checkpoint(.pt 文件)的目录,并以日期命名,便于版本管理。
  • imgs_dir

    • 值:'./imgs/20250718/'
    • 作用:绘制的训练曲线(loss、acc)等图像保存目录,同样以日期区分。

2. 训练策略

  • feature_extract

    • 类型:bool
    • 值:True
    • 含义:
      • True 时,冻结 BERT(不更新其参数),只用作"特征提取器";
      • False 时,对 BERT 进行微调,BERT 的权重也参与梯度更新。
    • 何时切换:
      • 数据量少或想加快训练时,可先 True,后续微调再设为 False
  • train_from_scrach

    • 类型:bool
    • 值:True
    • 含义:
      • True 时,从头开始训练,不加载已有 checkpoint;
      • False 时,会检查 save_dir 是否有 last_new_checkpoint 文件,并尝试加载继续训练。
  • last_new_checkpoint

    • 类型:str
    • 值:'best_epoch004_acc0.9327.pt'
    • 含义:
      • train_from_scrach=Falsesave_dir 下存在此文件,就会自动加载这份权重继续训练。

3. 标签映射

  • labels

    • 值:['体育', '娱乐', '家居', '房产', '教育', '时尚', '时政', '游戏', '科技', '财经']
    • 含义:本次分类任务的 10 个类别名称列表,对应 THUCNews 中选择的 10 类。
  • label2id

    • 构造方式:{l:i for i,l in enumerate(labels)}
    • 作用:将中文标签映射为整数 ID,例如 { '体育':0, '娱乐':1, ... },方便模型输出与损失计算。
  • id2label

    • 构造方式:{i:l for i,l in enumerate(labels)}
    • 作用:将模型预测的整数 ID 转回中文标签,便于结果展示与分析。
  • num_classes

    • 值:len(labels)
    • 作用:类别数,作为最后一层全连接层的输出维度。

4. 优化器与训练超参

  • LR

    • 值:5e-4
    • 含义:Adam 优化器的初始学习率。
  • EPOCHS

    • 值:4
    • 含义:训练轮数,总共遍历训练集 4 次。
  • batch_size

    • 值:1500
    • 含义:每个 DataLoader 中的批量大小。
    • 注意:如果显存不足,可适当调小。

5. 文档切分参数

  • doc_maxlen

    • 值:1000
    • 含义:整体文档在调用 tokenizer 时的最大输入长度
    • 建议:该值应≥segment_len,以保证切分后不丢失尾部。
  • segment_len

    • 值:150
    • 含义:每个切分段的最大 token 长度(即 BERT 单次能处理的长度)。
  • overlap

    • 值:50
    • 含义:相邻两段之间的重叠 token 数量,用于保留上下文连续性
    • 计算:每段起始位置步长为 segment_len -- overlap = 100

预处理代码

python 复制代码
import time
import pandas as pd
import torch
from transformers import BertTokenizerFast

# 配置文件导入
from config import bert_model_dir, doc_maxlen, batch_size, data_base_dir, label2id

# 初始化 tokenizer
tokenizer = BertTokenizerFast.from_pretrained(bert_model_dir)
  1. 依赖导入

    • time:用于测量编码耗时。
    • pandas as pd:读取 TSV 文本数据并做基本清洗。
    • torch:张量操作、构造 TensorDataset、创建 DataLoader
    • BertTokenizerFast:HuggingFace 的快速 tokenizer 实现,支持批量编码。
    • config.py 导入五个关键配置:
      • bert_model_dir:预训练模型路径;
      • doc_maxlen:编码时最大 token 长度;
      • batch_size:DataLoader 的 batch 大小;
      • data_base_dir:数据目录(这里暂未使用,但通常和 filepath 配合);
      • label2id:标签到整数 ID 的映射字典。
  2. Tokenizer 实例化

    python 复制代码
    tokenizer = BertTokenizerFast.from_pretrained(bert_model_dir)
    • 加载指定目录下的预训练 BERT 分词器(包含词表、特殊 token 定义等)。

read_data(filepath: str)

python 复制代码
def read_data(filepath: str):
    """
    读取 tsv 数据,返回文本列表和标签张量
    filepath: 文件路径,格式为 label \t content
    """
    # 使用 pandas 读取 tsv 文件
    df = pd.read_csv(filepath, sep='\t', names=['label', 'content'], encoding='utf-8')
    df = df.dropna()

    texts = df['content'].tolist()
    labels = [label2id[label] for label in df['label']]
    labels = torch.tensor(labels, dtype=torch.long)

    print(f"加载数据: 样本数={len(texts)}, 样本示例长度={len(texts[0])} 字符, 标签样本={labels.shape}")
    return texts, labels
  • 输入

    • filepath:TSV 文件路径,文件每行形如 标签\t文本内容
  • 步骤解析

    1. pd.read_csv(..., sep='\t', names=['label','content'])

      • 用 pandas 读取带分隔符的文本文件,并手动指定两列名称。
    2. df.dropna()

      • 丢弃任何含有空值的行,避免后续分词或映射报错。
    3. texts = df['content'].tolist()

      • 提取文本列为 Python 列表,每个元素都是一个字符串。
    4. labels = [label2id[label] for label in df['label']]

      • 通过配置中的 label2id 将中文标签转换成整数 ID。
    5. labels = torch.tensor(..., dtype=torch.long)

      • 转为 PyTorch 的长整型张量,便于放入模型计算损失。
    6. 打印日志:样本总数、第一条文本长度、标签张量形状,帮助排查问题。

  • 输出

    • textsList[str],所有文本;
    • labelstorch.LongTensor,形状 (样本数,)

bert_encode(texts: list[str])

python 复制代码
def bert_encode(texts: list[str]):
    """
    使用 tokenizer 对文本进行分批编码,返回字典形式的 tensor 对象
    """
    start = time.time()
    print("开始编码...")
    # 调用 tokenizer,支持批量
    inputs = tokenizer(
        texts,
        add_special_tokens=True,
        max_length=doc_maxlen,
        padding='longest',
        truncation=True,
        return_tensors='pt'
    )
    end = time.time()
    elapsed = end - start
    print(f"编码完成,耗时 {elapsed//60:.0f} 分 {elapsed%60:.2f} 秒")
    return inputs
  • 功能

    批量将纯文本列表转成 BERT 可接受的张量格式,包含:

    • input_ids:token 索引;
    • token_type_ids:句子 A/B 区分,单文档通常全 0;
    • attention_mask:标记有效 token(1)与填充(0)。
  • 关键参数

    • add_special_tokens=True:添加 [CLS][SEP] 等;
    • max_length=doc_maxlen:截断或填充到固定长度;
    • padding='longest':批内最长序列长度填充,而非固定最大,以节省计算;
    • truncation=True:超长截断;
    • return_tensors='pt':直接返回 PyTorch 张量。
  • 性能监控

    • time.time() 记录整个批量编码耗时,并打印分钟+秒级别耗时。
  • 返回

    • inputs:字典,包含上述三种 tensor,形状均为 (batch, seq_len)

load_data(filepath: str, shuffle: bool=False)

python 复制代码
def load_data(filepath: str, shuffle: bool = False) -> torch.utils.data.DataLoader:
    """
    读取文件并封装为 DataLoader
    filepath: 数据文件路径
    shuffle: 是否打乱顺序
    返回: DataLoader({input_ids, token_type_ids, attention_mask}, labels)
    """
    texts, labels = read_data(filepath)
    inputs = bert_encode(texts)

    # 构造 TensorDataset,需要所有字段为 tensor
    dataset = torch.utils.data.TensorDataset(
        inputs['input_ids'],
        inputs['token_type_ids'],
        inputs['attention_mask'],
        labels
    )
    # 创建 DataLoader
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=2,
        pin_memory=True
    )
    return loader
  • 流程

    1. 调用 read_data 拿到 textslabels
    2. 调用 bert_encode 拿到三大输入张量;
    3. TensorDataset 将所有 tensor 合并,一个样本就是 (input_ids, token_type_ids, attention_mask, label) 四元组;
    4. DataLoader 封装并行加载:
      • batch_size:每批样本数;
      • shuffle:训练集一般为 True,验证/测试集设 False
      • num_workers=2:两个子进程并行读取,加快 IO;
      • pin_memory=True:提高 GPU 端数据传输效率。
  • 返回

    • 一个 DataLoader 实例,每次迭代返回一个 batch,形式为 (input_ids, token_type_ids, attention_mask, labels),即可直接送入模型训练或评估函数。

模型结构

1.整体结构

python 复制代码
class BertLSTMWithOverlap(torch.nn.Module):
    """BERT Overlap Split + 双向 LSTM + Attention + 全连接分类"""
    ...
  • 继承自 torch.nn.Module,是一整个可训练/推理的模型。
  • 文档字符串简要说明了它集成了四部分:
    1. BERT(Transformer 编码器)
    2. Overlap Split(重叠切分长文档)
    3. 双向 LSTM(段间序列建模)
    4. Attention + 全连接分类

2.初始化方法 __init__

python 复制代码
def __init__(
    self, pretrained_dir: str, num_classes: int,
    segment_len: int = 200, overlap: int = 50, dropout_p: float = 0.5
):
    super().__init__()
    self.seg_len = segment_len
    self.overlap = overlap
  1. 参数说明

    • pretrained_dir:BERT 预训练模型所在目录;
    • num_classes:要分类的类别总数;
    • segment_len:每段的最大 token 数;
    • overlap:段与段之间的重叠 token 数;
    • dropout_p:分类层 Dropout 比例。
  2. 成员变量

    • self.seg_lenself.overlap:后面切分时会用到。
2.1 加载 BERT
python 复制代码
# 加载 BERT 配置与模型
config = BertConfig.from_json_file(os.path.join(pretrained_dir, 'config.json'))
self.bert = BertModel.from_pretrained(pretrained_dir, config=config)
  • 先从目录里的 config.json 读取模型配置,确保与本地文件一一对应;
  • 再用 from_pretrained 加载模型权重。
2.2 冻结 BERT 参数(可选)
python 复制代码
# 特征提取模式:冻结 BERT 参数
if feature_extract:
    for p in self.bert.parameters():
        p.requires_grad = False
  • feature_extract 来自全局配置(config.py),若为 True,则冻结所有 BERT 参数,BERT 只用来提取特征,不参与梯度更新。
  • 冻结可以加快训练、减少显存占用,尤其在数据量有限时常用。
2.3 构建上层网络
python 复制代码
d_model = config.hidden_size

# 双向 LSTM
self.lstm = torch.nn.LSTM(
    input_size=d_model,
    hidden_size=d_model // 2,
    bidirectional=True,
    batch_first=True
)
  • d_model:BERT 的隐藏层维度(例如 1024);
  • 双向 LSTM
    • 输入维度 = d_model
    • 隐藏维度 = d_model // 2,因为双向拼接后又恢复到 d_model
    • batch_first=True,输入输出的 shape 都是 (batch, seq, hid)
python 复制代码
# 段级注意力
self.attn = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_model),
    torch.nn.Tanh(),
    torch.nn.Linear(d_model, 1, bias=False),
    torch.nn.Softmax(dim=1)
)
  • 段级 Attention
    1. 先把每个 LSTM 输出向量映射到同维度,再做 Tanh
    2. 用一维线性层算出"得分",不加 bias;
    3. Softmax(dim=1) 把所有段的得分归一化,得到权重矩阵。
python 复制代码
# 分类层
self.fc = torch.nn.Sequential(
    torch.nn.Dropout(p=dropout_p),
    torch.nn.Linear(d_model, num_classes)
)
  • 分类头 :先 Dropout,再把 d_model 维度映射到类别数上,得到最终 logits。

3.文档切分方法 split_segments

python 复制代码
def split_segments(self, input_ids, type_ids, attn_mask):
    """根据 segment_len 与 overlap 拆分文档为多个片段"""
    doc_len = input_ids.size(1)
    step = self.seg_len - self.overlap
    starts = list(range(0, doc_len, step))
    segments = []
    for s in starts:
        end = min(s + self.seg_len, doc_len)
        segments.append({
            'input_ids': input_ids[:, s:end],
            'token_type_ids': type_ids[:, s:end],
            'attention_mask': attn_mask[:, s:end]
        })
    return segments
  1. doc_len:输入序列的总长度(包括所有 padding)。
  2. step:滑窗步长 = 段长 − 重叠
  3. starts:所有窗口起始位置列表;
  4. 遍历切分:每个窗口切出一段字典,包含三种 BERT 输入张量。
  5. 返回 :一个段列表,后续 forward 会遍历每段分别送入 BERT。

4.前向计算 forward

python 复制代码
def forward(self, input_ids, type_ids, attn_mask):
    # 1. 切分
    segs = self.split_segments(input_ids, type_ids, attn_mask)

    # 2. 每段 BERT 编码并收集 pooled_output
    reprs = []
    for seg in segs:
        outputs = self.bert(**seg, return_dict=True)
        pooled  = outputs.pooler_output        # [batch, d_model]
        reprs.append(pooled.unsqueeze(1))     # 扩增段维度

    reprs = torch.cat(reprs, dim=1)           # [batch, num_seg, d_model]

    # 3. BiLSTM 建模段序列
    lstm_out, _ = self.lstm(reprs)            # [batch, num_seg, d_model]

    # 4. Attention 汇聚
    weights = self.attn(lstm_out).transpose(1, 2)   # [batch, 1, num_seg]
    doc_repr = torch.bmm(weights, lstm_out).squeeze(1)  # [batch, d_model]

    # 5. 分类
    logits = self.fc(doc_repr)  
    return logits
  1. 切分 :得到 segs,列表长度 = 窗口数 (num_seg);
  2. BERT 编码 :对每个段独立调用 self.bert(**seg),取出 pooler_output (即 [CLS] 向量),并在第二维增加一维,以便后面拼接;
  3. 拼接reprs 形状变为 (batch_size, num_seg, d_model)
  4. LSTM :建模段与段之间的时序关系,输出同样形状的 lstm_out
  5. Attention :对每个段计算权重,weights 形状 (batch,1,num_seg),然后用 bmm(批量矩阵乘)把权重和 lstm_out 相乘,得到文档级别表示 doc_repr
  6. 分类 :把 doc_repr 送入 self.fc,输出 (batch_size, num_classes) 的 logits。

训练过程

1. 单步训练:train_one_step

python 复制代码
def train_one_step(model, batch, optimizer):
    model.train()                                   # ① 切换模型到训练模式(启用 Dropout、BatchNorm 等)
    optimizer.zero_grad()                           # ② 清空上一轮累积的梯度
    inputs = [t.to(device) for t in batch[:3]]      # ③ 将 input_ids、token_type_ids、attention_mask 移到 GPU/CPU
    labels = batch[3].to(device)                    # ④ 将标签也移到同一设备
    logits = model(*inputs)                         # ⑤ 前向计算:得到本 batch 的 logits
    loss = loss_fn(logits, labels)                  # ⑥ 计算交叉熵损失
    preds = torch.argmax(logits, dim=1).cpu().numpy()      # ⑦ 取最大概率下标作为预测,送回 CPU 并转 numpy
    acc = metric_fn(preds, labels.cpu().numpy())            # ⑧ 计算准确率指标
    loss.backward()                                 # ⑨ 反向传播,计算梯度
    optimizer.step()                                # ⑩ 优化器更新参数
    return loss.item(), acc                         # ⑪ 返回该 batch 的损失值和准确率
  • model.train()

    告诉 PyTorch 这是训练阶段,启用 Dropout 和 BatchNorm 的"训练模式"。

  • optimizer.zero_grad()

    PyTorch 默认累积梯度,需要在每次更新前先清零。

  • ③--④ 数据搬运

    把输入和标签搬到同一个设备(GPU 或 CPU),确保后续计算不报错。

  • ⑤ 前向计算
    model(*inputs) 等同于 model(input_ids, token_type_ids, attention_mask),执行你在 forward 里定义的重叠切分+BERT+LSTM+Attention+FC 整个流程。

  • ⑥ 损失函数

    交叉熵(CrossEntropyLoss)常用于多分类。

  • ⑦--⑧ 预测与指标

    • torch.argmax 得到每条样本最可能的类别;
    • .cpu().numpy() 把结果搬回 CPU,并转为 NumPy;
    • metric_fn(这里是 accuracy_score)计算准确率。
  • ⑨--⑩ 反向传播 & 更新

    • loss.backward() 计算所有可学习参数的梯度;
    • optimizer.step() 根据梯度更新参数。
  • ⑪ 返回

    用于在主循环里汇总整个训练集的平均 loss 和 acc。

2. 单步验证:eval_one_step

python 复制代码
def eval_one_step(model, batch):
    model.eval()                                    # ① 切换到评估模式(禁用 Dropout、BatchNorm 固定)
    with torch.no_grad():                           # ② 关闭梯度计算,节省显存和计算
        inputs = [t.to(device) for t in batch[:3]]  # ③ 搬运输入
        labels = batch[3].to(device)                # ④ 搬运标签
        logits = model(*inputs)                     # ⑤ 前向计算
        loss = loss_fn(logits, labels)              # ⑥ 计算损失
        preds = torch.argmax(logits, dim=1).cpu().numpy()  
        acc = metric_fn(preds, labels.cpu().numpy())
    return loss.item(), acc                        # ⑦ 返回损失与准确率
  • 与训练不同的点:

    • model.eval():关闭训练专用行为(如 Dropout);
    • torch.no_grad():不计算梯度,仅做前向,减小显存占用并加速。
  • 同样返回损失和准确率,用于验证集统计。

3. 整体训练流程:train_model

python 复制代码
def train_model(model, train_loader, val_loader, optimizer, scheduler=None):
    print(f"训练集 loader 总 batch 数: {len(train_loader)}")
    print(f"验证集 loader 总 batch 数: {len(val_loader)}")

    start = time.time()            # ① 记录训练开始时间
    best_acc = 0.0                 # ② 用于追踪验证集上最优准确率

    for epoch in range(1, EPOCHS+1):
        print_time_bar()           # ③ 打印时间分隔线

        # ------ 训练阶段 ------  
        train_losses, train_accs = [], []
        for batch in train_loader:
            l, a = train_one_step(model, batch, optimizer)
            train_losses.append(l)
            train_accs.append(a)

        # ------ 验证阶段 ------  
        val_losses, val_accs = [], []
        for batch in val_loader:
            l, a = eval_one_step(model, batch)
            val_losses.append(l)
            val_accs.append(a)

        # 学习率调度  
        if scheduler:
            scheduler.step()       # 根据策略(如后期降 LR)更新学习率

        # ------ 指标汇总与日志 ------  
        mean_train_loss = sum(train_losses) / len(train_losses)
        mean_train_acc  = sum(train_accs) / len(train_accs)
        mean_val_loss   = sum(val_losses)   / len(val_losses)
        mean_val_acc    = sum(val_accs)     / len(val_accs)
        history.loc[epoch] = [  # 记录到 DataFrame
            epoch,
            mean_train_loss, mean_train_acc,
            mean_val_loss,   mean_val_acc
        ]
        print(f"Epoch {epoch}: loss={mean_train_loss:.4f}, acc={mean_train_acc:.4f}, \
val_loss={mean_val_loss:.4f}, val_acc={mean_val_acc:.4f}")

        # ------ 保存最佳模型 ------  
        if mean_val_acc > best_acc:
            best_acc = mean_val_acc
            path = os.path.join(save_dir, f"best_epoch{epoch:03d}_acc{best_acc:.4f}.pt")
            state = {
                'epoch': epoch,
                'model': copy.deepcopy(model.state_dict()),
                'optimizer': optimizer.state_dict()
            }
            torch.save(state, path)  # 只保留在验证集上最优的 checkpoint

    elapsed = time.time() - start   # ④ 计算总耗时
    print_time_bar()
    print(f"训练完成,共耗时 {elapsed//3600:.0f}h {(elapsed%3600)//60:.0f}m {elapsed%60:.2f}s,最佳 val_acc={best_acc:.4f}")
    return history                 # 返回包含整个训练历史的 DataFrame
  1. 批次数打印

    直观展示训练/验证用的数据大小。

  2. best_acc

    用于条件式保存最优模型。

  3. Epoch 循环

    • 训练 :遍历所有 train_loader,累积各 batch 的 loss/acc;
    • 验证 :遍历 val_loader,同样累积;
    • 调度器:在每个 epoch 末尾更新学习率(如线性衰减、分段调度等);
    • 汇总与记录 :算出平均指标,打印并存入 history
    • Checkpoint:当验证集准确率更好时保存模型权重和优化器状态。
  4. 总耗时统计

    训练结束后报告总用时,帮助评估资源消耗。

补充问题

bert_encode输出结果是什么?

当你把一批原始文本传给 bert_encode(texts),它会帮你把这些"人能看懂的句子"转成"机器能处理的数字张量",并打包成一个字典,主要包含三个部分:

  1. input_ids(输入 ID)

    • 通俗说,就是把每个词(或子词)都换成一个数字编号。
    • 比如句子"我爱北京"会被分成"[CLS] 我 爱 北 京 [SEP]",然后对应编号 [101, 2769, 4263, 1266, 1872, 102]
    • 最终形状是 (batch_size, seq_len),每行就是一句话的词 ID 列表。
  2. token_type_ids(句子类型 ID)

    • 如果你一次只传一句话,所有值都填 0;
    • 在问答或句子对任务里,第一句标 0,第二句标 1,方便模型区分"第一句话"和"第二句话"。
    • 形状也是 (batch_size, seq_len),对应每个 token 的类型。
  3. attention_mask(注意力掩码)

    • 通俗点就是告诉模型"哪些位置是真正的词,哪些是补齐的空白"。
    • 真正的词位置标 1,padding(为了对齐批次长度塞进去的空位置)标 0。
    • 这样 BERT 在计算时就只关注"真实内容"部分,不去浪费计算在空白上。
    • 形状同样是 (batch_size, seq_len)

pooler_output作用?

pooler_output 是 HuggingFace 实现的 BertModel 在前向传播中给出的"句子级"向量,具体作用和特性如下:

  1. 什么是 pooler_output

    • 它对应于 Transformer 最后一层的第一个 token(即 [CLS])在经过一个额外的全连接层(dense)+ tanh 激活后的输出。
    • 设计初衷是把 [CLS] 位置的上下文信息"浓缩"成一个固定维度的向量,常用于分类、回归或下游任务的输入。
  2. 有没有降维?

    • 默认情况下,不降维
      • dense 层的输入维度和输出维度都等于 hidden_size,也就是 BERT 模型本身的隐藏层维度(如 768、1024 等)。
      • 因此 pooler_output 的 shape 总是 [batch_size, hidden_size],与 last_hidden_state[:,0](未经过额外变换的 [CLS] 向量)在维度上保持一致。
  3. 有没有做表示转换?

    • 是的,会做一次仿射 + 激活的转换:
      1. 仿射变换output = inputs @ W + b,其中 WbBertPooler 模块里的可学习参数,W 形状为 (hidden_size, hidden_size)
      2. tanh 激活 :给这个仿射结果加上非线性,丰富表达能力,并且将输出压缩到 (-1,1) 区间。
  4. 参数如何定义维度?

    • 全都跟着 BertConfig.hidden_size

      • 在初始化 BertModel 时,会读取 config.hidden_size,并用它来构造:
        • Transformer 每层的维度
        • BertPooler.dense 层的权重矩阵大小。
    • 如果你想改维度(比如降至 512 或升至 2048),需要手动在上游 重写这一层,或者在 forward 后接一个额外的降维/升维层。

self.attn为什么选择使用这种结构?

这里的 self.attn 并不是在做 Transformer 那种"多头自注意力"(Multi‑Head Self‑Attention),而是一个段级的"加性注意力"(又称 Bahdanau Attention 或者 Attention Pooling),它的作用只是给每个段一个可学习的权重,然后在段维度上做加权平均,从而得到整个文档的表示。

python 复制代码
self.attn = torch.nn.Sequential(
    torch.nn.Linear(d_model, d_model),
    torch.nn.Tanh(),
    torch.nn.Linear(d_model, 1, bias=False),
    torch.nn.Softmax(dim=1)
)

逐步来看它为什么这么写、有什么好处:

  1. 简单高效

    • 我们只需要对 BiLSTM 输出的每个段向量 h_i ∈ ℝ^{d_model} 计算一个标量分数,然后做 softmax 得到权重。
    • 这种"先映射 + Tanh,再线性投一维、最后 softmax"正是经典的加性注意力(additive attention)。
    • 它参数少、计算快,完全能满足"选出对分类最重要的那几段"这个需求。
  2. 比调用多头注意力更轻量

    • torch.nn.MultiheadAttention 那种模块需要构造 query/key/value 三套映射,还要做多头拆分、缩放点积、再拼回去,对段级 pooling 来说就太重了。
    • 我们这里根本不需要学习复杂的"段与段之间的交互",只想学一个权重向量对每个段打分------所以用自定义的小网络更合适。
  3. 可控性强,易于理解和调试

    • 全部写在一个 Sequential 里,结构一目了然:
      1. Linear(d_model, d_model) + Tanh ------ 给 h_i 加上一层非线性变换,增强表达能力;
      2. Linear(d_model,1) ------ 把变换后的向量投射到一个分数;
      3. Softmax(dim=1) ------ 把所有段的分数正规化成概率分布。
    • 如果用库函数,调参、改结构(比如加层、换激活)都没这么直接。
  4. 恰好符合"Attention Pooling"需求

    • 很多文本分类、信息检索任务里,用到的 Attention 只是把多条向量加权求和,这就叫 Attention Pooling。
    • 既不是序列到序列(seq2seq)的编码‐解码注意力,也不是 Transformer 里的多头自注意力,二者都更复杂也不必要。

BiLSTM + Attention 聚合如何设计的?

  1. 动态聚焦

    • 不同文档、不同任务,重要段落不一样。Attention 为每个段分配可学习的权重,让模型"聚焦"到最能帮助分类的部分------
      • 比如新闻里最关键词的段落、评论中最情绪化的句子。
  2. 结构简单、参数少

    • Attention Pooling 只用两层小的全连接,就能实现加权求和,无需引入复杂的多头机制,既高效又易调。
  3. 顺序感与全局感兼顾

    • BiLSTM 捕获了段间时序;Attention 则从整体上衡量每段重要性。
    • 二者结合,既不丢掉段序信息,也能跳出平均池化的"所有段一视同仁"弊端。
  4. 输出固定维度

    • 不管原始文档切成多少段,最终 doc_repr 始终是 [batch_size, d_model],方便接下来的全连接分类层或其他下游网络。
相关推荐
WWZZ20254 小时前
快速上手大模型:深度学习2(实践:深度学习基础、线性回归)
人工智能·深度学习·算法·计算机视觉·机器人·大模型·slam
初级炼丹师(爱说实话版)4 小时前
算法面经常考题整理(1)机器学习
人工智能·算法·机器学习
被AI抢饭碗的人4 小时前
算法题(246):负环(bellman_ford算法)
算法
大数据张老师5 小时前
数据结构——折半查找
数据结构·算法·查找·折半查找
熬了夜的程序员6 小时前
【LeetCode】87. 扰乱字符串
算法·leetcode·职场和发展·排序算法
是码农一枚6 小时前
全域感知,主动预警:视频汇聚平台EasyCVR打造水库大坝智慧安防视频监控智能分析方案
算法
MicroTech20256 小时前
微算法科技(NASDAQ MLGO)探索自适应差分隐私机制(如AdaDP),根据任务复杂度动态调整噪声
人工智能·科技·算法
是码农一枚6 小时前
全域互联,统一管控:EasyCVR构建多区域视频监控“一网统管”新范式
算法
听情歌落俗6 小时前
c++通讯录管理系统
开发语言·c++·算法