【信创】华为昇腾NLP算法训练

1. 项目概述

  • 目标:在国产信创硬件上训练长文本分类模型,并部署 API 提供推理服务
  • 任务类型:多类别/二分类 NLP 问题
  • 输入数据:长文本(如 2000+ token)
  • 输出:文本类别预测
  • 硬件环境
    • 2 × Ascend 910B2 NPU
    • 鲲鹏 ARM64 CPU
    • 昆仑信创操作系统(如 openEuler / 麒麟)
  • 软件环境
    • Python >= 3.9

    • PyTorch 2.2.1(Ascend 镜像):

      bash 复制代码
      pip install torch==2.2.1 -f https://ascend-pytorch-mirror.huawei.com/whl/torch/
    • Transformers

    • NumPy, pandas, scikit-learn

2. 数据处理

2.1 文本切分

  • 长文本超过 BERT 最大长度(如 512)时,使用 BERT Split
    • 将文本按句子或固定长度切分为多个片段
    • 每个片段通过 BERT 编码
    • 拼接或平均片段的 hidden states 作为文本表示
  • 可选:文本重叠切分,保证上下文连续性

2.2 数据集示例

python 复制代码
import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.read_csv('long_text_dataset.csv')  # columns: text, label
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df['text'].tolist(), df['label'].tolist(), test_size=0.1, random_state=42
)

2.3 Tokenizer

python 复制代码
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-chinese")

def encode_texts(texts, max_len=512):
    encoded_list = []
    for text in texts:
        # 分段处理
        segments = [text[i:i+max_len] for i in range(0, len(text), max_len)]
        encoded_segments = [tokenizer(s, padding='max_length', truncation=True, return_tensors='pt') for s in segments]
        encoded_list.append(encoded_segments)
    return encoded_list

3. 模型设计:BERTSplitLSTM

3.1 结构说明

  1. BERT Encoder

    • 每个文本片段使用 BERT 编码
    • 输出 [CLS] 或最后隐藏层
  2. 片段合并

    • 将片段向量按顺序拼接或送入 LSTM
  3. LSTM

    • 捕捉跨片段的长文本上下文
    • 双向 LSTM 可选
  4. 分类层

    • 全连接 + softmax
    • 输出文本类别

3.2 PyTorch 示例

python 复制代码
import torch
import torch.nn as nn
from transformers import BertModel

class BERTSplitLSTM(nn.Module):
    def __init__(self, bert_model_name='bert-base-chinese', lstm_hidden=256, num_classes=10):
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.lstm = nn.LSTM(input_size=self.bert.config.hidden_size,
                            hidden_size=lstm_hidden,
                            num_layers=1,
                            batch_first=True,
                            bidirectional=True)
        self.fc = nn.Linear(2*lstm_hidden, num_classes)

    def forward(self, segments_batch):
        # segments_batch: list of segments tensors, shape [batch, seg_len, hidden_size]
        segment_outputs = []
        for segments in segments_batch:
            seg_embs = []
            for seg in segments:
                output = self.bert(**seg).last_hidden_state[:,0,:]  # CLS token
                seg_embs.append(output)
            seg_embs = torch.stack(seg_embs, dim=1)  # [batch, n_segments, hidden_size]
            lstm_out, _ = self.lstm(seg_embs)
            final_output = lstm_out[:,-1,:]
            segment_outputs.append(final_output)
        return self.fc(torch.cat(segment_outputs, dim=0))

4. 训练配置

  • 损失函数CrossEntropyLoss

  • 优化器AdamW(带权重衰减)

  • 学习率策略:线性 warmup + decay

  • 批大小:根据显存,双卡 910B2 可尝试 batch=4~8

  • 梯度累积:长文本可使用梯度累积降低显存占用

  • 混合精度训练

    python 复制代码
    scaler = torch.cuda.amp.GradScaler()

4.1 训练示例

python 复制代码
from torch.utils.data import DataLoader

train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

for epoch in range(epochs):
    for batch in train_loader:
        optimizer.zero_grad()
        with torch.cuda.amp.autocast():
            outputs = model(batch['segments'])
            loss = criterion(outputs, batch['labels'])
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

5. 模型部署

5.1 模型保存

python 复制代码
torch.save(model.state_dict(), "bert_split_lstm_finetune.pt")

5.2 转换 OM(Ascend)

bash 复制代码
# 导出 ONNX
python export_to_onnx.py --model_path bert_split_lstm_finetune.pt --output bert_split_lstm.onnx

# ONNX → OM
atc --model=bert_split_lstm.onnx --framework=5 --output=bert_split_lstm.om --soc_version=Ascend910B2 --input_shape="input_ids:1,512"

5.3 API 部署

  • 方法
    • 使用 FastAPI
    • 支持多进程 + 多线程 + 批量请求
python 复制代码
from fastapi import FastAPI
import torch

app = FastAPI()
model = load_om_model("bert_split_lstm.om", device='ascend', card_ids=[0,1])

@app.post("/predict")
async def predict(text: str):
    segments = encode_texts([text])
    pred = model(segments)
    return {"label": pred.argmax(dim=-1).item()}

6. 性能优化

  • 多卡并行:910B2 ×2 NPU
  • 批量推理:增加吞吐
  • 多线程/异步:利用 CPU 做数据预处理
  • 量化/半精度训练:降低显存,提升速度
  • 预热模型:推理前跑几次 batch

7. 验证与上线

  • 小规模文本测试模型准确性
  • 大批量文本测试吞吐和延迟
  • 监控 NPU 显存、CPU、推理延迟
相关推荐
杨_晨2 小时前
大模型微调训练FAQ - Batch Size与参数配置
人工智能·机器学习·ai·语言模型·batch
测试_AI_一辰2 小时前
Agent & RAG 测试工程 02:RAG 从最小闭环到可信
开发语言·前端·人工智能·github·ai编程
tudficdew2 小时前
C++中的策略模式实战
开发语言·c++·算法
查无此人byebye2 小时前
手写Multi-Head Attention多头注意力机制,Pytorch实现与原理详解
人工智能·pytorch·python·深度学习·transformer
Gavin在路上2 小时前
SpringAIAlibaba之深度剖析序列化过程中LinkedHashMap类型转换异常(十)
人工智能
naruto_lnq2 小时前
实时语音处理库
开发语言·c++·算法
wfeqhfxz25887822 小时前
击剑运动员姿态识别与关键部位检测_YOLOv26模型应用与优化
人工智能·yolo·目标跟踪
OpenCSG2 小时前
OpenCSG(开放传神)开源数据贡献解析:3大标杆数据集,筑牢中文AI基建
人工智能·开源
独自破碎E2 小时前
【数组】分糖果问题
java·开发语言·算法