
1. 项目概述
- 目标:在国产信创硬件上训练长文本分类模型,并部署 API 提供推理服务
- 任务类型:多类别/二分类 NLP 问题
- 输入数据:长文本(如 2000+ token)
- 输出:文本类别预测
- 硬件环境 :
- 2 × Ascend 910B2 NPU
- 鲲鹏 ARM64 CPU
- 昆仑信创操作系统(如 openEuler / 麒麟)
- 软件环境 :
-
Python >= 3.9
-
PyTorch 2.2.1(Ascend 镜像):
bashpip install torch==2.2.1 -f https://ascend-pytorch-mirror.huawei.com/whl/torch/ -
Transformers
-
NumPy, pandas, scikit-learn
-
2. 数据处理
2.1 文本切分
- 长文本超过 BERT 最大长度(如 512)时,使用 BERT Split :
- 将文本按句子或固定长度切分为多个片段
- 每个片段通过 BERT 编码
- 拼接或平均片段的 hidden states 作为文本表示
- 可选:文本重叠切分,保证上下文连续性
2.2 数据集示例
python
import pandas as pd
from sklearn.model_selection import train_test_split
df = pd.read_csv('long_text_dataset.csv') # columns: text, label
train_texts, val_texts, train_labels, val_labels = train_test_split(
df['text'].tolist(), df['label'].tolist(), test_size=0.1, random_state=42
)
2.3 Tokenizer
python
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")
def encode_texts(texts, max_len=512):
encoded_list = []
for text in texts:
# 分段处理
segments = [text[i:i+max_len] for i in range(0, len(text), max_len)]
encoded_segments = [tokenizer(s, padding='max_length', truncation=True, return_tensors='pt') for s in segments]
encoded_list.append(encoded_segments)
return encoded_list
3. 模型设计:BERTSplitLSTM
3.1 结构说明
-
BERT Encoder:
- 每个文本片段使用 BERT 编码
- 输出
[CLS]或最后隐藏层
-
片段合并:
- 将片段向量按顺序拼接或送入 LSTM
-
LSTM:
- 捕捉跨片段的长文本上下文
- 双向 LSTM 可选
-
分类层:
- 全连接 + softmax
- 输出文本类别
3.2 PyTorch 示例
python
import torch
import torch.nn as nn
from transformers import BertModel
class BERTSplitLSTM(nn.Module):
def __init__(self, bert_model_name='bert-base-chinese', lstm_hidden=256, num_classes=10):
super().__init__()
self.bert = BertModel.from_pretrained(bert_model_name)
self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
hidden_size=lstm_hidden,
num_layers=1,
batch_first=True,
bidirectional=True)
self.fc = nn.Linear(2*lstm_hidden, num_classes)
def forward(self, segments_batch):
# segments_batch: list of segments tensors, shape [batch, seg_len, hidden_size]
segment_outputs = []
for segments in segments_batch:
seg_embs = []
for seg in segments:
output = self.bert(**seg).last_hidden_state[:,0,:] # CLS token
seg_embs.append(output)
seg_embs = torch.stack(seg_embs, dim=1) # [batch, n_segments, hidden_size]
lstm_out, _ = self.lstm(seg_embs)
final_output = lstm_out[:,-1,:]
segment_outputs.append(final_output)
return self.fc(torch.cat(segment_outputs, dim=0))
4. 训练配置
-
损失函数 :
CrossEntropyLoss -
优化器 :
AdamW(带权重衰减) -
学习率策略:线性 warmup + decay
-
批大小:根据显存,双卡 910B2 可尝试 batch=4~8
-
梯度累积:长文本可使用梯度累积降低显存占用
-
混合精度训练:
pythonscaler = torch.cuda.amp.GradScaler()
4.1 训练示例
python
from torch.utils.data import DataLoader
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
for epoch in range(epochs):
for batch in train_loader:
optimizer.zero_grad()
with torch.cuda.amp.autocast():
outputs = model(batch['segments'])
loss = criterion(outputs, batch['labels'])
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
5. 模型部署
5.1 模型保存
python
torch.save(model.state_dict(), "bert_split_lstm_finetune.pt")
5.2 转换 OM(Ascend)
bash
# 导出 ONNX
python export_to_onnx.py --model_path bert_split_lstm_finetune.pt --output bert_split_lstm.onnx
# ONNX → OM
atc --model=bert_split_lstm.onnx --framework=5 --output=bert_split_lstm.om --soc_version=Ascend910B2 --input_shape="input_ids:1,512"
5.3 API 部署
- 方法 :
- 使用 FastAPI
- 支持多进程 + 多线程 + 批量请求
python
from fastapi import FastAPI
import torch
app = FastAPI()
model = load_om_model("bert_split_lstm.om", device='ascend', card_ids=[0,1])
@app.post("/predict")
async def predict(text: str):
segments = encode_texts([text])
pred = model(segments)
return {"label": pred.argmax(dim=-1).item()}
6. 性能优化
- 多卡并行:910B2 ×2 NPU
- 批量推理:增加吞吐
- 多线程/异步:利用 CPU 做数据预处理
- 量化/半精度训练:降低显存,提升速度
- 预热模型:推理前跑几次 batch
7. 验证与上线
- 小规模文本测试模型准确性
- 大批量文本测试吞吐和延迟
- 监控 NPU 显存、CPU、推理延迟