人工智能之语言领域
第十九章 深度学习框架
文章目录
- 人工智能之语言领域
- [前言 深度学习框架](#前言 深度学习框架)
- [19.1 TensorFlow 与 Keras](#19.1 TensorFlow 与 Keras)
- [19.1.1 核心架构与 NLP 模型搭建流程](#19.1.1 核心架构与 NLP 模型搭建流程)
- [TensorFlow 2.x + Keras 架构](#TensorFlow 2.x + Keras 架构)
- [NLP 模型搭建步骤(以文本分类为例)](#NLP 模型搭建步骤(以文本分类为例))
- [19.1.2 Keras 高层 API 的快速开发优势](#19.1.2 Keras 高层 API 的快速开发优势)
- [代码示例:使用 Keras + Transformers 微调 BERT](#代码示例:使用 Keras + Transformers 微调 BERT)
- [19.2 PyTorch](#19.2 PyTorch)
- [19.3 框架选型与项目实战](#19.3 框架选型与项目实战)
- 框架互操作性
- 小结
- 资料
前言 深度学习框架
深度学习框架是构建自然语言处理(NLP)模型的基础设施 。目前,TensorFlow/Keras 与 PyTorch 是两大主流选择,它们在设计理念、开发体验和部署生态上各有千秋。本章将深入对比这两大框架的核心特性,讲解其在 NLP 任务中的典型用法,并通过 BERT 微调实战 展示 PyTorch 的完整开发流程,帮助你做出合理的技术选型。
19.1 TensorFlow 与 Keras
19.1.1 核心架构与 NLP 模型搭建流程
TensorFlow 由 Google 开发,早期以静态计算图(Static Graph) 为核心(TF 1.x),2019 年发布的 TensorFlow 2.x 默认启用 Eager Execution(动态图),大幅提升了开发体验。
TensorFlow 2.x + Keras 架构
- Keras:高层 API,集成于 TF 2.x,提供简洁接口
- tf.data:高效数据管道
- tf.keras.Model:模型定义基类
- SavedModel:统一部署格式
原始文本数据
tf.data.Dataset
(分批/缓存/预取)
Tokenizer
(e.g., BERT tokenizer)
Keras Model
(Sequential/Functional/Subclassing)
model.fit()
自动训练循环
SavedModel / HDF5
NLP 模型搭建步骤(以文本分类为例)
- 数据预处理 →
tf.data.Dataset - 文本向量化 →
tf.keras.layers.TextVectorization或 Hugging Face Tokenizer - 构建模型 →
tf.keras.Sequential - 编译与训练 →
model.compile(),model.fit()
19.1.2 Keras 高层 API 的快速开发优势
Keras 以"用户友好、模块化、可扩展"著称,特别适合快速原型开发。
代码示例:使用 Keras + Transformers 微调 BERT
python
import tensorflow as tf
from transformers import TFAutoModel, AutoTokenizer
# Step 1: 加载分词器与预训练模型
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
bert_model = TFAutoModel.from_pretrained("bert-base-chinese")
# Step 2: 构建分类模型(函数式 API)
input_ids = tf.keras.Input(shape=(128,), dtype=tf.int32, name="input_ids")
attention_mask = tf.keras.Input(shape=(128,), dtype=tf.int32, name="attention_mask")
# BERT 编码
outputs = bert_model(input_ids, attention_mask=attention_mask)
cls_output = outputs.last_hidden_state[:, 0, :] # [CLS] 向量
# 分类头
dropout = tf.keras.layers.Dropout(0.1)(cls_output)
logits = tf.keras.layers.Dense(2, activation="softmax")(dropout)
model = tf.keras.Model(inputs=[input_ids, attention_mask], outputs=logits)
# Step 3: 编译与训练
model.compile(
optimizer=tf.keras.optimizers.Adam(learning_rate=2e-5),
loss="sparse_categorical_crossentropy",
metrics=["accuracy"]
)
# 假设已准备 dataset (input_ids, attention_mask, labels)
# model.fit(dataset, epochs=3)
✅ Keras 优势:
代码简洁,接近伪代码
内置回调(EarlyStopping, ModelCheckpoint)
无缝集成 TensorBoard 可视化
❌ 局限:动态控制流支持弱于 PyTorch
社区研究代码多基于 PyTorch
19.2 PyTorch
19.2.1 动态图机制与 NLP 序列建模适配
PyTorch 由 Facebook(现 Meta)开发,原生采用动态计算图(Dynamic Computation Graph) ,即"Define-by-Run":
- 每次前向传播都构建新计算图
- 支持任意 Python 控制流(if/for/while)
- 调试如同普通 Python 程序(可打断点、print)
这对 NLP 序列建模(如 RNN、自定义注意力)极为友好。
动态图 vs 静态图
| 特性 | PyTorch(动态图) | TensorFlow 1.x(静态图) |
|---|---|---|
| 开发体验 | 直观、易调试 | 需先定义图,再执行 |
| 控制流 | 原生支持 | 需用 tf.cond/tf.while_loop |
| 灵活性 | 高(适合研究) | 低 |
💡 通俗比喻:
- PyTorch 像"现场演奏":边写谱子边演奏
- TF 1.x 像"录音棚":先写完整乐谱,再录制
19.2.2 PyTorch Lightning:轻量化训练框架
PyTorch Lightning 是 PyTorch 的轻量级封装,将训练逻辑解耦为清晰模块,避免样板代码。
核心思想
- Trainer:负责训练循环、GPU/TPU、日志等工程细节
- LightningModule:专注模型定义、损失、优化器
LightningModule
-
init
-
training_step
-
configure_optimizers
Trainer -
fit()
-
自动处理:
• GPU分配
• 梯度累积
• 日志记录
代码结构对比
| 原生 PyTorch | PyTorch Lightning |
|---|---|
| 手写 for epoch, for batch | trainer.fit(model) |
手动 .to(device) |
自动分配设备 |
| 手动梯度清零/反向/更新 | 自动处理 |
19.3 框架选型与项目实战
19.3.1 不同场景下的框架选择依据
| 场景 | 推荐框架 | 理由 |
|---|---|---|
| 学术研究 / 快速实验 | PyTorch | 动态图灵活,社区论文复现多 |
| 工业部署 / 移动端 | TensorFlow | TFLite、TF.js 生态成熟 |
| 大模型微调 | PyTorch + Hugging Face | 社区支持最好(PEFT/Accelerate) |
| Kaggle / 快速原型 | Keras (TF) | 代码简洁,内置功能全 |
| 生产服务(高吞吐) | TensorFlow Serving / TorchServe | 两者均支持,TF 更早普及 |
📊 2024 年趋势:
- 研究领域:PyTorch 占据绝对主导(>90% 论文)
- 工业部署:两者并存,PyTorch 通过 TorchScript/TorchServe 追赶
19.3.2 基于 PyTorch 的 BERT 模型微调实战
我们将使用 PyTorch + Hugging Face Transformers + PyTorch Lightning 完成中文文本分类任务。
完整代码实现
python
# Step 1: 安装依赖
# pip install torch transformers pytorch-lightning datasets
import torch
from torch.utils.data import DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
import pytorch_lightning as pl
from datasets import load_dataset
# Step 2: 定义 LightningModule
class BertClassifier(pl.LightningModule):
def __init__(self, model_name="bert-base-chinese", num_labels=2, lr=2e-5):
super().__init__()
self.save_hyperparameters()
self.model = BertForSequenceClassification.from_pretrained(
model_name, num_labels=num_labels
)
self.lr = lr
def forward(self, input_ids, attention_mask, labels=None):
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
labels=labels
)
return outputs
def training_step(self, batch, batch_idx):
outputs = self(**batch)
loss = outputs.loss
self.log("train_loss", loss, prog_bar=True)
return loss
def validation_step(self, batch, batch_idx):
outputs = self(**batch)
loss = outputs.loss
preds = torch.argmax(outputs.logits, dim=1)
acc = (preds == batch["labels"]).float().mean()
self.log("val_loss", loss, prog_bar=True)
self.log("val_acc", acc, prog_bar=True)
def configure_optimizers(self):
return torch.optim.AdamW(self.parameters(), lr=self.lr)
# Step 3: 数据准备
class TextDataModule(pl.LightningDataModule):
def __init__(self, model_name="bert-base-chinese", batch_size=16, max_length=128):
super().__init__()
self.tokenizer = BertTokenizer.from_pretrained(model_name)
self.batch_size = batch_size
self.max_length = max_length
def setup(self, stage=None):
# 加载数据集(以 IMDb 中文版为例)
dataset = load_dataset("imdb") # 实际中替换为中文数据集
self.train_data = dataset["train"].shuffle(seed=42).select(range(1000))
self.val_data = dataset["test"].shuffle(seed=42).select(range(200))
def tokenize(self, examples):
return self.tokenizer(
examples["text"],
truncation=True,
padding="max_length",
max_length=self.max_length
)
def train_dataloader(self):
ds = self.train_data.map(self.tokenize, batched=True)
ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
return DataLoader(ds, batch_size=self.batch_size, shuffle=True)
def val_dataloader(self):
ds = self.val_data.map(self.tokenize, batched=True)
ds.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
return DataLoader(ds, batch_size=self.batch_size)
# Step 4: 训练
if __name__ == "__main__":
# 初始化数据模块与模型
datamodule = TextDataModule()
model = BertClassifier()
# 初始化 Trainer
trainer = pl.Trainer(
max_epochs=3,
accelerator="auto", # 自动选择 CPU/GPU/TPU
devices="auto",
log_every_n_steps=10,
callbacks=[
pl.callbacks.EarlyStopping(monitor="val_loss", patience=2),
pl.callbacks.ModelCheckpoint(monitor="val_acc", mode="max")
]
)
# 开始训练
trainer.fit(model, datamodule)
关键优势解析
- 自动设备管理 :
accelerator="auto"自动使用 GPU - 无需手动训练循环 :
trainer.fit()封装所有细节 - 内置回调:早停、模型保存、日志记录
- 与 Hugging Face 无缝集成 :直接加载
BertForSequenceClassification
框架互操作性
-
ONNX(Open Neural Network Exchange) :可在 TF 与 PyTorch 间转换模型
python# PyTorch 导出 ONNX torch.onnx.export(model, dummy_input, "model.onnx") -
Hugging Face Transformers :同时支持 TF 与 PyTorch 模型(
TFAutoModel/AutoModel)
小结
- TensorFlow/Keras:适合工业部署、快速原型,静态图优化成熟
- PyTorch:研究首选,动态图灵活,生态活跃(尤其大模型)
- PyTorch Lightning:大幅提升 PyTorch 工程化能力,减少样板代码
在实际项目中,研究阶段用 PyTorch,部署阶段按需转 TF 或使用 TorchServe 是常见策略。掌握两大框架的核心思想,能让你在 NLP 开发中游刃有余。
资料
咚咚王
《Python 编程:从入门到实践》
《利用 Python 进行数据分析》
《算法导论中文第三版》
《概率论与数理统计(第四版) (盛骤) 》
《程序员的数学》
《线性代数应该这样学第 3 版》
《微积分和数学分析引论》
《(西瓜书)周志华-机器学习》
《TensorFlow 机器学习实战指南》
《Sklearn 与 TensorFlow 机器学习实用指南》
《模式识别(第四版)》
《深度学习 deep learning》伊恩·古德费洛著 花书
《Python 深度学习第二版(中文版)【纯文本】 (登封大数据 (Francois Choliet)) (Z-Library)》
《深入浅出神经网络与深度学习 +(迈克尔·尼尔森(Michael+Nielsen)》
《自然语言处理综论 第 2 版》
《Natural-Language-Processing-with-PyTorch》
《计算机视觉-算法与应用(中文版)》
《Learning OpenCV 4》
《AIGC:智能创作时代》杜雨 +&+ 张孜铭
《AIGC 原理与实践:零基础学大语言模型、扩散模型和多模态模型》
《从零构建大语言模型(中文版)》
《实战 AI 大模型》
《AI 3.0》