BERTSplitLSTM算法
1. 定义
bert_overlap_split_bilstm 是一种面向长文档级文本分类的混合模型,其核心思想是:
- 用 BERT(Transformer)对文本段落进行上下文编码;
- 对长文本进行 Overlap Split(重叠切分),保证跨段上下文连续;
- 将各段 BERT 输出的向量序列输入 BiLSTM,进一步建模段间时序依赖;
- 最后通过 Attention + 全连接层 汇总文档表示并分类。
这种结构兼具 Transformer 的全局语义理解能力与 LSTM 的序列依赖建模优势,适合长度超出 BERT 单次最大输入的长文本。
2. 原理
-
BERT 编码
- 对每个文本片段(segment)分别调用预训练 BERT,输出其 pooler_output(即
[CLS]向量),获得固定维度的段级表示。
- 对每个文本片段(segment)分别调用预训练 BERT,输出其 pooler_output(即
-
重叠切分(Overlap Split)
- 将原始文档的 token 序列以
segment_len为段长、overlap为重叠步长滑窗切分。 - 保证前后相邻两个段有一定 token 重合,缓解截断带来的信息丢失。
- 将原始文档的 token 序列以
-
BiLSTM 建模
- 将所有段的 BERT 表示按文档顺序拼接成形状
[batch, num_seg, hid_dim]的张量。 - 输入双向 LSTM,捕捉段与段之间的依赖关系。
- 将所有段的 BERT 表示按文档顺序拼接成形状
-
段级 Attention
- 对 LSTM 输出序列计算注意力权重,动态聚焦对分类贡献最大的段。
- 权重加权后汇总得到文档级向量
doc_repr。
-
全连接分类
- 对
doc_repr经过 Dropout 和 Linear 层,输出各类别的 logits。
- 对
3. 模型结构图示
原始文档 (tokens)
│
Overlap Split
↓
┌───────────────────────┐ ... ┌───────────────────────┐
│ segment 1 (≤ seg_len) │ │ segment k (≤ seg_len) │
└───────────────────────┘ └───────────────────────┘
│ │
[BERT 编码] [BERT 编码]
↓ ↓
pooled_output ←--------------- ... ---------------→ pooled_output
│ │
└──────────→ 拼接为 [batch, num_seg, hid] ←──────────┘
│
BiLSTM
↓
LSTM 输出序列
│
Attention
↓
文档级向量 doc_repr
│
FC 分类
↓
输出 logits → Softmax
4. 关键代码解读
4.1 Overlap Split
python
def split_segments(self, input_ids, type_ids, attn_mask):
doc_len = input_ids.size(1)
step = self.seg_len - self.overlap
starts = list(range(0, doc_len, step))
segments = []
for s in starts:
end = min(s + self.seg_len, doc_len)
segments.append({
'input_ids': input_ids[:, s:end],
'token_type_ids': type_ids[:, s:end],
'attention_mask': attn_mask[:, s:end]
})
return segments
seg_len:每段最大长度(如 200)overlap:相邻段重叠长度(如 50)step = seg_len - overlap作为滑窗步长
4.2 前向计算
python
# 拆分后对每段调用 BERT
reprs = []
for seg in segs:
outputs = self.bert(**seg, return_dict=True)
pooled = outputs.pooler_output # [batch, hid]
reprs.append(pooled.unsqueeze(1)) # 增加段维度
reprs = torch.cat(reprs, dim=1) # [batch, num_seg, hid]
# BiLSTM 编码
lstm_out, _ = self.lstm(reprs) # [batch, num_seg, hid]
# 段级 Attention
weights = self.attn(lstm_out).transpose(1, 2) # [batch, 1, num_seg]
doc_repr = torch.bmm(weights, lstm_out).squeeze(1) # [batch, hid]
# 分类
logits = self.fc(doc_repr)
5. 训练过程
-
数据准备
- TSV 格式:
label \t content - 调用
prepare_data.read_data读取,BERT Tokenizer 分批编码(最大长度doc_maxlen,此处可设置大于单段长度以保证切分后不丢) - 构建
TensorDataset与DataLoader。
- TSV 格式:
-
模型初始化
pythonmodel = BertLSTMWithOverlap( pretrained_dir=bert_model_dir, num_classes=num_classes, segment_len=segment_len, overlap=overlap ) model.to(device) -
冻结 BERT(可选)
pythonif feature_extract: for p in model.bert.parameters(): p.requires_grad = False -
Optimizer & Scheduler
pythonoptimizer = torch.optim.Adam( filter(lambda p: p.requires_grad, model.parameters()), lr=LR, weight_decay=1e-4 ) scheduler = torch.optim.lr_scheduler.LambdaLR( optimizer, lr_lambda=lambda ep: 0.1 if ep > EPOCHS*0.8 else 1 ) -
训练循环
- 每个
epoch:- 遍历训练集 batch,调用
train_one_step计算 loss、反向传播、更新参数 - 遍历验证集 batch,调用
eval_one_step评估 - 记录并保存最佳模型
- 遍历训练集 batch,调用
- 每个
-
结果可视化
- 使用 Matplotlib 绘制
loss与acc曲线,保存到指定目录。
- 使用 Matplotlib 绘制
6. 模型作用与应用场景
- 长文本分类:如新闻分级、法律文档分类、学术论文主题判别
- 文档检索候选排序:将长文档编码为紧凑表示后用于相似度计算
- 多标签分类:调整输出层即能扩展到多标签场景
- 信息抽取:可替换最后的分类头为序列标注模块,用于实体识别等
7. 后续优化方向
- 动态切分:根据句边界或段落边界智能切分,减少人为超参数依赖。
- Transformer 级联:用轻量级 Transformer 层替代 BiLSTM,提高并行效率。
- 多尺度注意力:在段内和段间分别建模注意力,捕捉细粒度与全局信息。
- 蒸馏与剪枝:在推理阶段用小型 BERT 或量化模型,降低部署成本。
代码分析
配置代码
1. 路径相关
-
project_dir- 值:
'/root/autodl-tmp/NLP/text_class/' - 作用:项目根目录,后续所有相对路径都基于这个基础目录拼接而成。
- 值:
-
data_base_dir- 值:
project_dir + 'data/thucnews/' - 作用:数据存放目录,
cnews.train.txt、cnews.val.txt、cnews.test.txt等文件都在这里。
- 值:
-
bert_model_dir- 值:
'/root/autodl-tmp/NLP/text_class/pretrained_models/dienstag/chinese-roberta-wwm-ext-large/' - 作用:预训练 BERT/RoBERTa 模型所在目录,用于
BertTokenizerFast.from_pretrained和BertModel.from_pretrained。
- 值:
-
save_dir- 值:
'./save/20250718/' - 作用:训练过程中保存 checkpoint(
.pt文件)的目录,并以日期命名,便于版本管理。
- 值:
-
imgs_dir- 值:
'./imgs/20250718/' - 作用:绘制的训练曲线(loss、acc)等图像保存目录,同样以日期区分。
- 值:
2. 训练策略
-
feature_extract- 类型:
bool - 值:
True - 含义:
True时,冻结 BERT(不更新其参数),只用作"特征提取器";False时,对 BERT 进行微调,BERT 的权重也参与梯度更新。
- 何时切换:
- 数据量少或想加快训练时,可先
True,后续微调再设为False。
- 数据量少或想加快训练时,可先
- 类型:
-
train_from_scrach- 类型:
bool - 值:
True - 含义:
True时,从头开始训练,不加载已有 checkpoint;False时,会检查save_dir是否有last_new_checkpoint文件,并尝试加载继续训练。
- 类型:
-
last_new_checkpoint- 类型:
str - 值:
'best_epoch004_acc0.9327.pt' - 含义:
- 当
train_from_scrach=False且save_dir下存在此文件,就会自动加载这份权重继续训练。
- 当
- 类型:
3. 标签映射
-
labels- 值:
['体育', '娱乐', '家居', '房产', '教育', '时尚', '时政', '游戏', '科技', '财经'] - 含义:本次分类任务的 10 个类别名称列表,对应 THUCNews 中选择的 10 类。
- 值:
-
label2id- 构造方式:
{l:i for i,l in enumerate(labels)} - 作用:将中文标签映射为整数 ID,例如
{ '体育':0, '娱乐':1, ... },方便模型输出与损失计算。
- 构造方式:
-
id2label- 构造方式:
{i:l for i,l in enumerate(labels)} - 作用:将模型预测的整数 ID 转回中文标签,便于结果展示与分析。
- 构造方式:
-
num_classes- 值:
len(labels) - 作用:类别数,作为最后一层全连接层的输出维度。
- 值:
4. 优化器与训练超参
-
LR- 值:
5e-4 - 含义:Adam 优化器的初始学习率。
- 值:
-
EPOCHS- 值:
4 - 含义:训练轮数,总共遍历训练集 4 次。
- 值:
-
batch_size- 值:
1500 - 含义:每个
DataLoader中的批量大小。 - 注意:如果显存不足,可适当调小。
- 值:
5. 文档切分参数
-
doc_maxlen- 值:
1000 - 含义:整体文档在调用 tokenizer 时的最大输入长度。
- 建议:该值应≥
segment_len,以保证切分后不丢失尾部。
- 值:
-
segment_len- 值:
150 - 含义:每个切分段的最大 token 长度(即 BERT 单次能处理的长度)。
- 值:
-
overlap- 值:
50 - 含义:相邻两段之间的重叠 token 数量,用于保留上下文连续性
- 计算:每段起始位置步长为
segment_len -- overlap = 100。
- 值:
预处理代码
python
import time
import pandas as pd
import torch
from transformers import BertTokenizerFast
# 配置文件导入
from config import bert_model_dir, doc_maxlen, batch_size, data_base_dir, label2id
# 初始化 tokenizer
tokenizer = BertTokenizerFast.from_pretrained(bert_model_dir)
-
依赖导入
time:用于测量编码耗时。pandas as pd:读取 TSV 文本数据并做基本清洗。torch:张量操作、构造TensorDataset、创建DataLoader。BertTokenizerFast:HuggingFace 的快速 tokenizer 实现,支持批量编码。- 从
config.py导入五个关键配置:bert_model_dir:预训练模型路径;doc_maxlen:编码时最大 token 长度;batch_size:DataLoader 的 batch 大小;data_base_dir:数据目录(这里暂未使用,但通常和filepath配合);label2id:标签到整数 ID 的映射字典。
-
Tokenizer 实例化
pythontokenizer = BertTokenizerFast.from_pretrained(bert_model_dir)- 加载指定目录下的预训练 BERT 分词器(包含词表、特殊 token 定义等)。
read_data(filepath: str)
python
def read_data(filepath: str):
"""
读取 tsv 数据,返回文本列表和标签张量
filepath: 文件路径,格式为 label \t content
"""
# 使用 pandas 读取 tsv 文件
df = pd.read_csv(filepath, sep='\t', names=['label', 'content'], encoding='utf-8')
df = df.dropna()
texts = df['content'].tolist()
labels = [label2id[label] for label in df['label']]
labels = torch.tensor(labels, dtype=torch.long)
print(f"加载数据: 样本数={len(texts)}, 样本示例长度={len(texts[0])} 字符, 标签样本={labels.shape}")
return texts, labels
-
输入
filepath:TSV 文件路径,文件每行形如标签\t文本内容。
-
步骤解析
-
pd.read_csv(..., sep='\t', names=['label','content'])- 用 pandas 读取带分隔符的文本文件,并手动指定两列名称。
-
df.dropna()- 丢弃任何含有空值的行,避免后续分词或映射报错。
-
texts = df['content'].tolist()- 提取文本列为 Python 列表,每个元素都是一个字符串。
-
labels = [label2id[label] for label in df['label']]- 通过配置中的
label2id将中文标签转换成整数 ID。
- 通过配置中的
-
labels = torch.tensor(..., dtype=torch.long)- 转为 PyTorch 的长整型张量,便于放入模型计算损失。
-
打印日志:样本总数、第一条文本长度、标签张量形状,帮助排查问题。
-
-
输出
texts:List[str],所有文本;labels:torch.LongTensor,形状(样本数,)。
bert_encode(texts: list[str])
python
def bert_encode(texts: list[str]):
"""
使用 tokenizer 对文本进行分批编码,返回字典形式的 tensor 对象
"""
start = time.time()
print("开始编码...")
# 调用 tokenizer,支持批量
inputs = tokenizer(
texts,
add_special_tokens=True,
max_length=doc_maxlen,
padding='longest',
truncation=True,
return_tensors='pt'
)
end = time.time()
elapsed = end - start
print(f"编码完成,耗时 {elapsed//60:.0f} 分 {elapsed%60:.2f} 秒")
return inputs
-
功能
批量将纯文本列表转成 BERT 可接受的张量格式,包含:
input_ids:token 索引;token_type_ids:句子 A/B 区分,单文档通常全 0;attention_mask:标记有效 token(1)与填充(0)。
-
关键参数
add_special_tokens=True:添加[CLS]、[SEP]等;max_length=doc_maxlen:截断或填充到固定长度;padding='longest':批内最长序列长度填充,而非固定最大,以节省计算;truncation=True:超长截断;return_tensors='pt':直接返回 PyTorch 张量。
-
性能监控
- 用
time.time()记录整个批量编码耗时,并打印分钟+秒级别耗时。
- 用
-
返回
inputs:字典,包含上述三种 tensor,形状均为(batch, seq_len)。
load_data(filepath: str, shuffle: bool=False)
python
def load_data(filepath: str, shuffle: bool = False) -> torch.utils.data.DataLoader:
"""
读取文件并封装为 DataLoader
filepath: 数据文件路径
shuffle: 是否打乱顺序
返回: DataLoader({input_ids, token_type_ids, attention_mask}, labels)
"""
texts, labels = read_data(filepath)
inputs = bert_encode(texts)
# 构造 TensorDataset,需要所有字段为 tensor
dataset = torch.utils.data.TensorDataset(
inputs['input_ids'],
inputs['token_type_ids'],
inputs['attention_mask'],
labels
)
# 创建 DataLoader
loader = torch.utils.data.DataLoader(
dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=2,
pin_memory=True
)
return loader
-
流程
- 调用
read_data拿到texts与labels; - 调用
bert_encode拿到三大输入张量; - 用
TensorDataset将所有 tensor 合并,一个样本就是(input_ids, token_type_ids, attention_mask, label)四元组; - 用
DataLoader封装并行加载:batch_size:每批样本数;shuffle:训练集一般为True,验证/测试集设False;num_workers=2:两个子进程并行读取,加快 IO;pin_memory=True:提高 GPU 端数据传输效率。
- 调用
-
返回
- 一个
DataLoader实例,每次迭代返回一个 batch,形式为(input_ids, token_type_ids, attention_mask, labels),即可直接送入模型训练或评估函数。
- 一个
模型结构
1.整体结构
python
class BertLSTMWithOverlap(torch.nn.Module):
"""BERT Overlap Split + 双向 LSTM + Attention + 全连接分类"""
...
- 继承自
torch.nn.Module,是一整个可训练/推理的模型。 - 文档字符串简要说明了它集成了四部分:
- BERT(Transformer 编码器)
- Overlap Split(重叠切分长文档)
- 双向 LSTM(段间序列建模)
- Attention + 全连接分类
2.初始化方法 __init__
python
def __init__(
self, pretrained_dir: str, num_classes: int,
segment_len: int = 200, overlap: int = 50, dropout_p: float = 0.5
):
super().__init__()
self.seg_len = segment_len
self.overlap = overlap
-
参数说明
pretrained_dir:BERT 预训练模型所在目录;num_classes:要分类的类别总数;segment_len:每段的最大 token 数;overlap:段与段之间的重叠 token 数;dropout_p:分类层 Dropout 比例。
-
成员变量
self.seg_len、self.overlap:后面切分时会用到。
2.1 加载 BERT
python
# 加载 BERT 配置与模型
config = BertConfig.from_json_file(os.path.join(pretrained_dir, 'config.json'))
self.bert = BertModel.from_pretrained(pretrained_dir, config=config)
- 先从目录里的
config.json读取模型配置,确保与本地文件一一对应; - 再用
from_pretrained加载模型权重。
2.2 冻结 BERT 参数(可选)
python
# 特征提取模式:冻结 BERT 参数
if feature_extract:
for p in self.bert.parameters():
p.requires_grad = False
feature_extract来自全局配置(config.py),若为True,则冻结所有 BERT 参数,BERT 只用来提取特征,不参与梯度更新。- 冻结可以加快训练、减少显存占用,尤其在数据量有限时常用。
2.3 构建上层网络
python
d_model = config.hidden_size
# 双向 LSTM
self.lstm = torch.nn.LSTM(
input_size=d_model,
hidden_size=d_model // 2,
bidirectional=True,
batch_first=True
)
d_model:BERT 的隐藏层维度(例如 1024);- 双向 LSTM :
- 输入维度 =
d_model; - 隐藏维度 =
d_model // 2,因为双向拼接后又恢复到d_model; batch_first=True,输入输出的 shape 都是(batch, seq, hid)。
- 输入维度 =
python
# 段级注意力
self.attn = torch.nn.Sequential(
torch.nn.Linear(d_model, d_model),
torch.nn.Tanh(),
torch.nn.Linear(d_model, 1, bias=False),
torch.nn.Softmax(dim=1)
)
- 段级 Attention :
- 先把每个 LSTM 输出向量映射到同维度,再做
Tanh; - 用一维线性层算出"得分",不加 bias;
- 用
Softmax(dim=1)把所有段的得分归一化,得到权重矩阵。
- 先把每个 LSTM 输出向量映射到同维度,再做
python
# 分类层
self.fc = torch.nn.Sequential(
torch.nn.Dropout(p=dropout_p),
torch.nn.Linear(d_model, num_classes)
)
- 分类头 :先 Dropout,再把
d_model维度映射到类别数上,得到最终 logits。
3.文档切分方法 split_segments
python
def split_segments(self, input_ids, type_ids, attn_mask):
"""根据 segment_len 与 overlap 拆分文档为多个片段"""
doc_len = input_ids.size(1)
step = self.seg_len - self.overlap
starts = list(range(0, doc_len, step))
segments = []
for s in starts:
end = min(s + self.seg_len, doc_len)
segments.append({
'input_ids': input_ids[:, s:end],
'token_type_ids': type_ids[:, s:end],
'attention_mask': attn_mask[:, s:end]
})
return segments
doc_len:输入序列的总长度(包括所有 padding)。step:滑窗步长 = 段长 − 重叠starts:所有窗口起始位置列表;- 遍历切分:每个窗口切出一段字典,包含三种 BERT 输入张量。
- 返回 :一个段列表,后续
forward会遍历每段分别送入 BERT。
4.前向计算 forward
python
def forward(self, input_ids, type_ids, attn_mask):
# 1. 切分
segs = self.split_segments(input_ids, type_ids, attn_mask)
# 2. 每段 BERT 编码并收集 pooled_output
reprs = []
for seg in segs:
outputs = self.bert(**seg, return_dict=True)
pooled = outputs.pooler_output # [batch, d_model]
reprs.append(pooled.unsqueeze(1)) # 扩增段维度
reprs = torch.cat(reprs, dim=1) # [batch, num_seg, d_model]
# 3. BiLSTM 建模段序列
lstm_out, _ = self.lstm(reprs) # [batch, num_seg, d_model]
# 4. Attention 汇聚
weights = self.attn(lstm_out).transpose(1, 2) # [batch, 1, num_seg]
doc_repr = torch.bmm(weights, lstm_out).squeeze(1) # [batch, d_model]
# 5. 分类
logits = self.fc(doc_repr)
return logits
- 切分 :得到
segs,列表长度 = 窗口数 (num_seg); - BERT 编码 :对每个段独立调用
self.bert(**seg),取出pooler_output(即[CLS]向量),并在第二维增加一维,以便后面拼接; - 拼接 :
reprs形状变为(batch_size, num_seg, d_model); - LSTM :建模段与段之间的时序关系,输出同样形状的
lstm_out; - Attention :对每个段计算权重,
weights形状(batch,1,num_seg),然后用bmm(批量矩阵乘)把权重和lstm_out相乘,得到文档级别表示doc_repr; - 分类 :把
doc_repr送入self.fc,输出(batch_size, num_classes)的 logits。
训练过程
1. 单步训练:train_one_step
python
def train_one_step(model, batch, optimizer):
model.train() # ① 切换模型到训练模式(启用 Dropout、BatchNorm 等)
optimizer.zero_grad() # ② 清空上一轮累积的梯度
inputs = [t.to(device) for t in batch[:3]] # ③ 将 input_ids、token_type_ids、attention_mask 移到 GPU/CPU
labels = batch[3].to(device) # ④ 将标签也移到同一设备
logits = model(*inputs) # ⑤ 前向计算:得到本 batch 的 logits
loss = loss_fn(logits, labels) # ⑥ 计算交叉熵损失
preds = torch.argmax(logits, dim=1).cpu().numpy() # ⑦ 取最大概率下标作为预测,送回 CPU 并转 numpy
acc = metric_fn(preds, labels.cpu().numpy()) # ⑧ 计算准确率指标
loss.backward() # ⑨ 反向传播,计算梯度
optimizer.step() # ⑩ 优化器更新参数
return loss.item(), acc # ⑪ 返回该 batch 的损失值和准确率
-
①
model.train()告诉 PyTorch 这是训练阶段,启用 Dropout 和 BatchNorm 的"训练模式"。
-
②
optimizer.zero_grad()PyTorch 默认累积梯度,需要在每次更新前先清零。
-
③--④ 数据搬运
把输入和标签搬到同一个设备(GPU 或 CPU),确保后续计算不报错。
-
⑤ 前向计算
model(*inputs)等同于model(input_ids, token_type_ids, attention_mask),执行你在forward里定义的重叠切分+BERT+LSTM+Attention+FC 整个流程。 -
⑥ 损失函数
交叉熵(
CrossEntropyLoss)常用于多分类。 -
⑦--⑧ 预测与指标
torch.argmax得到每条样本最可能的类别;.cpu().numpy()把结果搬回 CPU,并转为 NumPy;metric_fn(这里是accuracy_score)计算准确率。
-
⑨--⑩ 反向传播 & 更新
loss.backward()计算所有可学习参数的梯度;optimizer.step()根据梯度更新参数。
-
⑪ 返回
用于在主循环里汇总整个训练集的平均 loss 和 acc。
2. 单步验证:eval_one_step
python
def eval_one_step(model, batch):
model.eval() # ① 切换到评估模式(禁用 Dropout、BatchNorm 固定)
with torch.no_grad(): # ② 关闭梯度计算,节省显存和计算
inputs = [t.to(device) for t in batch[:3]] # ③ 搬运输入
labels = batch[3].to(device) # ④ 搬运标签
logits = model(*inputs) # ⑤ 前向计算
loss = loss_fn(logits, labels) # ⑥ 计算损失
preds = torch.argmax(logits, dim=1).cpu().numpy()
acc = metric_fn(preds, labels.cpu().numpy())
return loss.item(), acc # ⑦ 返回损失与准确率
-
与训练不同的点:
model.eval():关闭训练专用行为(如 Dropout);torch.no_grad():不计算梯度,仅做前向,减小显存占用并加速。
-
同样返回损失和准确率,用于验证集统计。
3. 整体训练流程:train_model
python
def train_model(model, train_loader, val_loader, optimizer, scheduler=None):
print(f"训练集 loader 总 batch 数: {len(train_loader)}")
print(f"验证集 loader 总 batch 数: {len(val_loader)}")
start = time.time() # ① 记录训练开始时间
best_acc = 0.0 # ② 用于追踪验证集上最优准确率
for epoch in range(1, EPOCHS+1):
print_time_bar() # ③ 打印时间分隔线
# ------ 训练阶段 ------
train_losses, train_accs = [], []
for batch in train_loader:
l, a = train_one_step(model, batch, optimizer)
train_losses.append(l)
train_accs.append(a)
# ------ 验证阶段 ------
val_losses, val_accs = [], []
for batch in val_loader:
l, a = eval_one_step(model, batch)
val_losses.append(l)
val_accs.append(a)
# 学习率调度
if scheduler:
scheduler.step() # 根据策略(如后期降 LR)更新学习率
# ------ 指标汇总与日志 ------
mean_train_loss = sum(train_losses) / len(train_losses)
mean_train_acc = sum(train_accs) / len(train_accs)
mean_val_loss = sum(val_losses) / len(val_losses)
mean_val_acc = sum(val_accs) / len(val_accs)
history.loc[epoch] = [ # 记录到 DataFrame
epoch,
mean_train_loss, mean_train_acc,
mean_val_loss, mean_val_acc
]
print(f"Epoch {epoch}: loss={mean_train_loss:.4f}, acc={mean_train_acc:.4f}, \
val_loss={mean_val_loss:.4f}, val_acc={mean_val_acc:.4f}")
# ------ 保存最佳模型 ------
if mean_val_acc > best_acc:
best_acc = mean_val_acc
path = os.path.join(save_dir, f"best_epoch{epoch:03d}_acc{best_acc:.4f}.pt")
state = {
'epoch': epoch,
'model': copy.deepcopy(model.state_dict()),
'optimizer': optimizer.state_dict()
}
torch.save(state, path) # 只保留在验证集上最优的 checkpoint
elapsed = time.time() - start # ④ 计算总耗时
print_time_bar()
print(f"训练完成,共耗时 {elapsed//3600:.0f}h {(elapsed%3600)//60:.0f}m {elapsed%60:.2f}s,最佳 val_acc={best_acc:.4f}")
return history # 返回包含整个训练历史的 DataFrame
-
批次数打印
直观展示训练/验证用的数据大小。
-
best_acc用于条件式保存最优模型。
-
Epoch 循环
- 训练 :遍历所有
train_loader,累积各 batch 的 loss/acc; - 验证 :遍历
val_loader,同样累积; - 调度器:在每个 epoch 末尾更新学习率(如线性衰减、分段调度等);
- 汇总与记录 :算出平均指标,打印并存入
history; - Checkpoint:当验证集准确率更好时保存模型权重和优化器状态。
- 训练 :遍历所有
-
总耗时统计
训练结束后报告总用时,帮助评估资源消耗。
补充问题
bert_encode输出结果是什么?
当你把一批原始文本传给 bert_encode(texts),它会帮你把这些"人能看懂的句子"转成"机器能处理的数字张量",并打包成一个字典,主要包含三个部分:
-
input_ids(输入 ID)- 通俗说,就是把每个词(或子词)都换成一个数字编号。
- 比如句子"我爱北京"会被分成"[CLS] 我 爱 北 京 [SEP]",然后对应编号
[101, 2769, 4263, 1266, 1872, 102]。 - 最终形状是
(batch_size, seq_len),每行就是一句话的词 ID 列表。
-
token_type_ids(句子类型 ID)- 如果你一次只传一句话,所有值都填 0;
- 在问答或句子对任务里,第一句标 0,第二句标 1,方便模型区分"第一句话"和"第二句话"。
- 形状也是
(batch_size, seq_len),对应每个 token 的类型。
-
attention_mask(注意力掩码)- 通俗点就是告诉模型"哪些位置是真正的词,哪些是补齐的空白"。
- 真正的词位置标 1,padding(为了对齐批次长度塞进去的空位置)标 0。
- 这样 BERT 在计算时就只关注"真实内容"部分,不去浪费计算在空白上。
- 形状同样是
(batch_size, seq_len)。
pooler_output作用?
pooler_output 是 HuggingFace 实现的 BertModel 在前向传播中给出的"句子级"向量,具体作用和特性如下:
-
什么是
pooler_output?- 它对应于 Transformer 最后一层的第一个 token(即
[CLS])在经过一个额外的全连接层(dense)+tanh激活后的输出。 - 设计初衷是把
[CLS]位置的上下文信息"浓缩"成一个固定维度的向量,常用于分类、回归或下游任务的输入。
- 它对应于 Transformer 最后一层的第一个 token(即
-
有没有降维?
- 默认情况下,不降维 :
dense层的输入维度和输出维度都等于hidden_size,也就是 BERT 模型本身的隐藏层维度(如 768、1024 等)。- 因此
pooler_output的 shape 总是[batch_size, hidden_size],与last_hidden_state[:,0](未经过额外变换的[CLS]向量)在维度上保持一致。
- 默认情况下,不降维 :
-
有没有做表示转换?
- 是的,会做一次仿射 + 激活的转换:
- 仿射变换 :
output = inputs @ W + b,其中W和b是BertPooler模块里的可学习参数,W形状为(hidden_size, hidden_size)。 tanh激活 :给这个仿射结果加上非线性,丰富表达能力,并且将输出压缩到(-1,1)区间。
- 仿射变换 :
- 是的,会做一次仿射 + 激活的转换:
-
参数如何定义维度?
-
全都跟着
BertConfig.hidden_size:- 在初始化
BertModel时,会读取config.hidden_size,并用它来构造:- Transformer 每层的维度
BertPooler.dense层的权重矩阵大小。
- 在初始化
-
如果你想改维度(比如降至 512 或升至 2048),需要手动在上游 重写这一层,或者在
forward后接一个额外的降维/升维层。
-
self.attn为什么选择使用这种结构?
这里的 self.attn 并不是在做 Transformer 那种"多头自注意力"(Multi‑Head Self‑Attention),而是一个段级的"加性注意力"(又称 Bahdanau Attention 或者 Attention Pooling),它的作用只是给每个段一个可学习的权重,然后在段维度上做加权平均,从而得到整个文档的表示。
python
self.attn = torch.nn.Sequential(
torch.nn.Linear(d_model, d_model),
torch.nn.Tanh(),
torch.nn.Linear(d_model, 1, bias=False),
torch.nn.Softmax(dim=1)
)
逐步来看它为什么这么写、有什么好处:
-
简单高效
- 我们只需要对 BiLSTM 输出的每个段向量
h_i ∈ ℝ^{d_model}计算一个标量分数,然后做 softmax 得到权重。 - 这种"先映射 + Tanh,再线性投一维、最后 softmax"正是经典的加性注意力(additive attention)。
- 它参数少、计算快,完全能满足"选出对分类最重要的那几段"这个需求。
- 我们只需要对 BiLSTM 输出的每个段向量
-
比调用多头注意力更轻量
torch.nn.MultiheadAttention那种模块需要构造 query/key/value 三套映射,还要做多头拆分、缩放点积、再拼回去,对段级 pooling 来说就太重了。- 我们这里根本不需要学习复杂的"段与段之间的交互",只想学一个权重向量对每个段打分------所以用自定义的小网络更合适。
-
可控性强,易于理解和调试
- 全部写在一个
Sequential里,结构一目了然:Linear(d_model, d_model)+Tanh------ 给h_i加上一层非线性变换,增强表达能力;Linear(d_model,1)------ 把变换后的向量投射到一个分数;Softmax(dim=1)------ 把所有段的分数正规化成概率分布。
- 如果用库函数,调参、改结构(比如加层、换激活)都没这么直接。
- 全部写在一个
-
恰好符合"Attention Pooling"需求
- 很多文本分类、信息检索任务里,用到的 Attention 只是把多条向量加权求和,这就叫 Attention Pooling。
- 既不是序列到序列(seq2seq)的编码‐解码注意力,也不是 Transformer 里的多头自注意力,二者都更复杂也不必要。
BiLSTM + Attention 聚合如何设计的?
-
动态聚焦
- 不同文档、不同任务,重要段落不一样。Attention 为每个段分配可学习的权重,让模型"聚焦"到最能帮助分类的部分------
- 比如新闻里最关键词的段落、评论中最情绪化的句子。
- 不同文档、不同任务,重要段落不一样。Attention 为每个段分配可学习的权重,让模型"聚焦"到最能帮助分类的部分------
-
结构简单、参数少
- Attention Pooling 只用两层小的全连接,就能实现加权求和,无需引入复杂的多头机制,既高效又易调。
-
顺序感与全局感兼顾
- BiLSTM 捕获了段间时序;Attention 则从整体上衡量每段重要性。
- 二者结合,既不丢掉段序信息,也能跳出平均池化的"所有段一视同仁"弊端。
-
输出固定维度
- 不管原始文档切成多少段,最终
doc_repr始终是[batch_size, d_model],方便接下来的全连接分类层或其他下游网络。
- 不管原始文档切成多少段,最终