下面这篇会偏「工程实践向」,假定你已经会基本的 Python、会跑简单的 nn.Module + DataLoader 的训练脚本,但对「模型本身到底是个啥」「训练完之后能怎么用、怎么分析」还不太有体系认识。
一、Torch 训练出来的模型,本质上是什么?
从工程角度说:
一个 torch 训练出来的模型 = 一段可执行的计算图代码(
nn.Module) + 一组训练好的参数(state_dict)
1.1 nn.Module:模型的结构 & 前向逻辑
你通常是这样定义一个模型的:
python
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
def __init__(self, in_dim, hidden_dim, out_dim):
super().__init__()
self.fc1 = nn.Linear(in_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, out_dim)
def forward(self, x):
x = F.relu(self.fc1(x))
x = self.fc2(x)
return x
这里的 MLP:
- 继承自
nn.Module; - 内部包含若干子模块 / 层(
nn.Linear等); forward方法描述了前向传播的计算流程(计算图的定义)。
训练过程中发生了什么?
- 每次前向:调用
model(x),实际上执行model.forward(x); - PyTorch 在前向时构建动态计算图(autograd graph),记录算子、张量之间的依赖关系;
- 反向时:
loss.backward(),PyTorch 根据计算图自动求梯度; - 优化器
optimizer.step()用梯度更新每一层的参数。
重要点 :
「模型结构」是 Python 代码(nn.Module 的子类),不是数据;
「模型参数」是数据(若干 torch.Tensor),由训练得到。
1.2 state_dict:模型的权重参数快照
训练完成后,你通常会这样保存:
python
torch.save(model.state_dict(), "model.pt")
state_dict 是一个 OrderedDict,key 是模块路径,value 是对应的权重张量:
python
state_dict = model.state_dict()
print(state_dict.keys())
# 例如:
# odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])
所以,可以把训练好的模型拆成两部分理解:
- 结构定义 :
class MLP(nn.Module): ... - 参数快照 :
model.pt/model.pth中保存的state_dict
恢复模型的典型写法:
python
model = MLP(in_dim, hidden_dim, out_dim)
state = torch.load("model.pt", map_location="cpu") # or "cuda"
model.load_state_dict(state)
model.eval() # 切 eval 模式
1.3 train() / eval():模型的「行为模式」
Torch 中有些层在训练和推理时行为不同:
nn.Dropout:训练时随机置零;推理时不随机。nn.BatchNorm*:训练时用 batch 统计量,更新 running mean/var;推理时用 running mean/var。
所以:
python
model.train() # 启用训练模式
model.eval() # 启用推理模式
保存出来的参数在任何模式下都一样,但运行时的行为会因此变得不同。
1.4 引入设备概念:CPU / GPU / 多 GPU
模型的参数其实就是一堆张量,因此也要放到设备上:
python
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
x = x.to(device)
y = model(x)
训练出的模型「特征」之一就是:它跟设备绑定不是绝对的,权重本身是独立的,加载后可以
.to(cpu/cuda)切换设备。只是某些部署格式(比如导出到某些 C++ 推理引擎)会在导出时就固定计算后端。
二、训练好的模型能做什么?(应用层面)
训练好模型后,最基本的场景有:
- 直接推理(inference):给输入,输出预测;
- 集成进项目 / 服务:API 服务、批处理、工具脚本等;
- 迁移学习 / 微调:在新数据集上继续训练;
- 导出为其他框架 / 推理引擎格式:ONNX、TorchScript、TensorRT 等;
- 可视化与分析:特征可视化、梯度分析、模型解释等。
下面分别展开,同时结合相关工具。
三、在 Python 中直接使用模型推理
3.1 最简单的推理流程
典型模板:
python
model = MLP(in_dim, hidden_dim, out_dim)
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()
with torch.no_grad():
x = torch.randn(2, in_dim) # batch_size=2
y = model(x)
# 做 argmax、阈值判断等后处理
关键点:
model.eval()关闭 dropout / 切换 BN 行为;torch.no_grad()关闭 autograd,减少显存和计算开销;- 数据类型要匹配:
float32/int64等。
3.2 将模型封装为一个函数或类接口
在实际工程中,一般会封装推理接口:
python
class Predictor:
def __init__(self, ckpt_path, device="cpu"):
self.device = torch.device(device)
self.model = MLP(in_dim, hidden_dim, out_dim).to(self.device)
self.model.load_state_dict(torch.load(ckpt_path, map_location=self.device))
self.model.eval()
@torch.no_grad()
def __call__(self, x: torch.Tensor):
x = x.to(self.device)
logits = self.model(x)
prob = torch.softmax(logits, dim=-1)
pred = torch.argmax(prob, dim=-1)
return {"logits": logits, "prob": prob, "pred": pred}
这样在 web 服务、脚本里调用都比较统一。
四、把模型集成进实际项目:部署 & 服务化
4.1 在 Flask / FastAPI 中部署
假设你写了一个分类器,可以用 FastAPI:
python
from fastapi import FastAPI
import torch
from pydantic import BaseModel
app = FastAPI()
predictor = Predictor("model.pt", device="cuda:0")
class Item(BaseModel):
features: list[float] # 简单起见
@app.post("/predict")
def predict(item: Item):
x = torch.tensor(item.features, dtype=torch.float32).unsqueeze(0)
out = predictor(x)
return {
"pred": int(out["pred"][0].item()),
"prob": out["prob"][0].tolist()
}
这样模型就变成一个 HTTP API 服务了,可以被前端、其他服务调用。
相关工具:
- Web 框架:FastAPI、Flask、Django 等;
- 容器化:Docker + gunicorn/uvicorn;
- 负载均衡与监控:Kubernetes、Prometheus 等(偏 DevOps)。
4.2 批量离线推理
做离线分析或大规模数据预测:
python
from torch.utils.data import DataLoader, Dataset
dataset = ...
loader = DataLoader(dataset, batch_size=256)
model.eval()
all_preds = []
with torch.no_grad():
for x in loader:
x = x.to(device)
logits = model(x)
pred = logits.argmax(dim=-1)
all_preds.append(pred.cpu())
all_preds = torch.cat(all_preds)
五、迁移学习 / 微调现有模型
很多时候你不会从头训练模型,而是:
- 加载预训练模型(ImageNet、BERT 等);
- 替换最后几层;
- 在自己的数据集上 fine-tune。
PyTorch 生态支持丰富的预训练模型:
torchvision.models:ResNet、ViT、YOLOvX(部分)等图像模型;torchaudio.models:Wav2Vec2 等;torchtext/transformers(Hugging Face):BERT / GPT / LLama 等 NLP/多模态模型。
示例:使用 torchvision 的预训练 ResNet 做分类:
python
import torch
import torch.nn as nn
import torchvision.models as models
num_classes = 10
model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# 替换最后的全连接层
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)
迁移学习/微调流程的关键:
- 冻结部分层的梯度:
for p in model.layer1.parameters(): p.requires_grad = False - 使用较小学习率;
- 考虑不同层不同学习率(参数组)。
六、将模型导出到其他格式:TorchScript / ONNX / TensorRT 等
单纯 Python 下推理,依赖解释器和动态计算图,不利于生产环境高效部署;因此 PyTorch 提供了多种导出方式。
6.1 TorchScript:静态图 / C++ 端部署
TorchScript 是 PyTorch 原生的「可序列化、可独立运行」模型格式。
两种方式:
- Tracing(跟踪):给定一个样例输入,让 PyTorch 跟踪前向执行过程,导出图;
- Scripting(脚本):分析 Python 源码中用到的一部分语法,转换为 TorchScript IR。
简例(Tracing):
python
model = MLP(in_dim, hidden_dim, out_dim)
model.load_state_dict(torch.load("model.pt"))
model.eval()
example = torch.randn(1, in_dim)
traced = torch.jit.trace(model, example)
traced.save("model_traced.pt")
在 C++/LibTorch 中可以直接加载 model_traced.pt 进行推理。
适用途径:
- 部署在 C++ 服务;
- 较少的 Python 依赖;
- 可以做一些 graph-level 的优化。
局限:Tracing 对控制流(if/loop)不太鲁棒;Scripting 有一定语法限制。
6.2 ONNX:跨框架 / 推理引擎通用格式
ONNX(Open Neural Network Exchange)是一个开放的模型交换格式,很多推理引擎支持(ONNX Runtime、TensorRT、OpenVINO 等)。
导出示例:
python
dummy_input = torch.randn(1, in_dim)
torch.onnx.export(
model,
dummy_input,
"model.onnx",
input_names=["input"],
output_names=["output"],
opset_version=17, # 一般用较新的版本
dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
)
然后可以使用:
onnx+onnxruntime在 Python/C++ 中高性能推理;trtexec/ TensorRT 将 ONNX 转为 TensorRT 引擎,在 NVIDIA GPU 上加速。
ONNX 的价值:
- 跨框架:从 PyTorch 导出后,可以在非 PyTorch 环境跑;
- 适合端侧或高性能服务部署。
6.3 其他:TensorRT、TFLite、CoreML、OpenVINO 等
一般流程:PyTorch → ONNX → 各种后端工具
- TensorRT:适合 NVIDIA GPU 高吞吐量低延迟服务;
- TFLite:移动端 / 嵌入式;
- CoreML:Apple 生态(iOS/macOS);
- OpenVINO:Intel 硬件加速。
对你来说关键是理解:
PyTorch 训练好的模型,本质是一个「高层次描述」,可以被编译/转换成各种底层推理引擎格式。
七、分析与理解模型:从黑盒到「半透明盒」
下面进入你问的另外一个重点:「有哪些工具可以对训练出来的模型进行应用和分析?」
大致分几类:
- 结构 & 参数层级的分析:看模型长啥样、参数多少;
- 训练过程可视化 / 调试:loss、metric、梯度等;
- 特征与中间表示可视化:看模型中间层提取了什么特征;
- 解释性 / 可解释 AI 工具:Grad-CAM、SHAP、LIME 等;
- 性能分析:耗时、显存、运算复杂度、算子 profiling。
7.1 模型结构和参数信息
(1)print(model) / model.named_parameters()
最原始的方式:
python
print(model)
for name, param in model.named_parameters():
print(name, param.shape, param.requires_grad)
(2)torchsummary / torchinfo
用第三方工具给出总结:
bash
pip install torchinfo
python
from torchinfo import summary
model = MLP(in_dim, hidden_dim, out_dim)
summary(model, input_size=(32, in_dim)) # batch_size=32
输出典型信息:
- 每层的输出尺寸;
- 参数数量;
- 总参数量;
- 内存占用估算。
7.2 TensorBoard / torch.utils.tensorboard
PyTorch 官方兼容 TensorBoard,可以监控训练过程:
python
from torch.utils.tensorboard import SummaryWriter
writer = SummaryWriter(log_dir="runs/exp1")
for step in range(num_steps):
loss = ...
writer.add_scalar("train/loss", loss.item(), step)
# 也可以 add_histogram, add_image 等
writer.close()
启动:
bash
tensorboard --logdir runs
看到:
- Loss / metric 曲线;
- 权重/梯度分布变化;
- 中间层 feature map 可视化(对 CNN 很有帮助)。
7.3 使用 Hook 观察中间层特征
PyTorch 的 hook 机制允许你在前向/后向过程中插入回调,捕获中间 tensor。
示例(捕获某层输出):
python
features = {}
def hook_fn(module, input, output):
features['fc1_out'] = output.detach()
handle = model.fc1.register_forward_hook(hook_fn)
with torch.no_grad():
x = torch.randn(1, in_dim)
y = model(x)
handle.remove()
print(features['fc1_out'].shape)
用途:
- 观察中间层的激活值(看是否梯度消失/爆炸、是否过饱和);
- 做特征可视化(比如 CNN 的 feature maps);
- 提取某层编码,用于下游任务(如特征检索)。
7.4 模型解释:Grad-CAM / SHAP / Captum 等
(1)Grad-CAM(图像任务)
Grad-CAM 可以显示 CNN 在预测时「看到了图像的哪些区域」。
PyTorch 生态有很多现成实现,比如 pytorch-grad-cam:
bash
pip install grad-cam
使用示例(简化):
python
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
target_layers = [model.layer4[-1]] # 最后一个卷积层
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)
targets = [ClassifierOutputTarget(class_idx)]
grayscale_cam = cam(input_tensor=input_image, targets=targets)[0, :]
# 然后叠加到原图做可视化
(2)Captum(官方出品的模型可解释性库)
bash
pip install captum
支持:
- Integrated Gradients
- DeepLift
- Saliency
- Layer-wise Relevance Propagation 等
例:对一个分类模型做特征重要性评估(tabular / NLP 都可以用):
python
from captum.attr import IntegratedGradients
ig = IntegratedGradients(model)
attr = ig.attribute(inputs=x, target=class_idx) # 得到每个输入维度的贡献
(3)SHAP / LIME(更偏统计学习、模型无关)
- SHAP:计算每个特征在预测中的「边际贡献」;
- LIME:通过局部线性拟合近似模型的行为。
对 PyTorch 模型可以用 shap.DeepExplainer(对深度学习优化过)。
更多适合 tabular/NLP 等应用场景。
7.5 性能与资源分析:Profiler、CUDA 工具
(1)PyTorch Profiler
官方推荐用法:
python
import torch
from torch.profiler import profile, record_function, ProfilerActivity
with profile(
activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
record_shapes=True,
profile_memory=True
) as prof:
for step, (x, y) in enumerate(dataloader):
if step >= 10:
break
with record_function("train_step"):
loss = train_step(x, y)
prof.step()
然后用 TensorBoard 查看:
- 每个算子耗时;
- CPU/GPU 时间;
- 内存/显存占用;
- 算子调用栈。
(2)第三方:torchinfo, ptflops
ptflops:估算 FLOPs / MACs;torch.utils.benchmark:性能对比基准。
bash
pip install ptflops
python
from ptflops import get_model_complexity_info
with torch.cuda.device(0):
macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
print_per_layer_stat=True)
print('MACs:', macs, 'Params:', params)
八、生态工具:高层封装与大模型库
除了 PyTorch 原生,还有很多围绕它构建的工具,专门解决「训练 --> 应用/分析」的工程问题。
8.1 PyTorch Lightning / Lightning AI
- 把训练循环(
for epoch in range...)等样板代码封装起来; - 内置日志、checkpoint、早停、分布式训练等;
- 更适合大型项目,结构清晰。
训练出的模型依旧是 nn.Module + checkpoint,只是代码组织更规范,更容易与日志、分析工具集成。
8.2 Hugging Face 生态(Transformers / Datasets / Accelerate)
如果你做 NLP / 多模态 / 大模型:
- Transformers :
- 提供大量预训练模型(BERT、GPT、CLIP、LLM 等);
- 模型本身底层还是 PyTorch(或可选 TensorFlow/JAX),但推理接口统一且强大;
- 内置很多分析/可视化工具(如 Attention visualization)。
- Accelerate :
- 简化分布式训练与推理;
- Optimum :
- 专注模型优化与部署(ONNX Runtime、TensorRT、OpenVINO 等)。
它们都把「模型文件 + 配置 + tokenizer 等」打包成可重用资源,用起来比自己从头组更方便。
九、一个简单的「路线图」:从新手到能分析模型
最后给你一个可以按步骤实践的路线:
-
理解模型 = 结构 (
nn.Module) + 参数 (state_dict)- 写一个小 MLP/CNN,自行保存 / 加载;
- 熟悉
model.eval(),with torch.no_grad()。
-
写一个干净的推理脚本
- 将模型封装成
Predictor类; - 支持命令行参数:ckpt 路径、设备、输入文件等。
- 将模型封装成
-
用
torchinfo.summary看一下自己模型的结构- 理解每一层输出尺寸和参数量。
-
接入 TensorBoard 监控训练过程
- 记录 loss / acc;
- 尝试记录某些层的权重直方图和梯度分布。
-
写一个简单的 Hook,抓取中间层输出
- 对 CNN,画几张中间 feature map(把 feature map normalize 到 [0,1]后用
matplotlib.imshow)。
- 对 CNN,画几张中间 feature map(把 feature map normalize 到 [0,1]后用
-
导出为 ONNX 并用 onnxruntime 做一次推理
- 比较 PyTorch 和 ONNX Runtime 的输出差异(应在 1e-5 内);
- 感受从「训练」到「部署格式」的转换。
-
尝试一次 Grad-CAM 或 Captum 的 Integrated Gradients
- 如果做图像:Grad-CAM;
- 如果做 tabular:用 Captum 对特征做 attribution。
-
用 PyTorch Profiler 跑一次 profile
- 找出时间耗费最大的几个算子层;
- 调整 batch size、layer 结构,看性能变化。
完成这些,你就对「训练出来的模型是什么」「如何使用」「如何分析」会有一套完整的、可操作的认知。