torch训练出的模型的组成以及模型训练后的使用和分析办法

下面这篇会偏「工程实践向」,假定你已经会基本的 Python、会跑简单的 nn.Module + DataLoader 的训练脚本,但对「模型本身到底是个啥」「训练完之后能怎么用、怎么分析」还不太有体系认识。


一、Torch 训练出来的模型,本质上是什么?

从工程角度说:

一个 torch 训练出来的模型 = 一段可执行的计算图代码(nn.Module) + 一组训练好的参数(state_dict

1.1 nn.Module:模型的结构 & 前向逻辑

你通常是这样定义一个模型的:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F

class MLP(nn.Module):
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

这里的 MLP

  • 继承自 nn.Module
  • 内部包含若干子模块 / 层(nn.Linear 等);
  • forward 方法描述了前向传播的计算流程(计算图的定义)。

训练过程中发生了什么?

  1. 每次前向:调用 model(x),实际上执行 model.forward(x)
  2. PyTorch 在前向时构建动态计算图(autograd graph),记录算子、张量之间的依赖关系;
  3. 反向时:loss.backward(),PyTorch 根据计算图自动求梯度;
  4. 优化器 optimizer.step() 用梯度更新每一层的参数。

重要点

「模型结构」是 Python 代码(nn.Module 的子类),不是数据;

「模型参数」是数据(若干 torch.Tensor),由训练得到。


1.2 state_dict:模型的权重参数快照

训练完成后,你通常会这样保存:

python 复制代码
torch.save(model.state_dict(), "model.pt")

state_dict 是一个 OrderedDict,key 是模块路径,value 是对应的权重张量:

python 复制代码
state_dict = model.state_dict()
print(state_dict.keys())
# 例如:
# odict_keys(['fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias'])

所以,可以把训练好的模型拆成两部分理解:

  1. 结构定义class MLP(nn.Module): ...
  2. 参数快照model.pt / model.pth 中保存的 state_dict

恢复模型的典型写法:

python 复制代码
model = MLP(in_dim, hidden_dim, out_dim)
state = torch.load("model.pt", map_location="cpu")  # or "cuda"
model.load_state_dict(state)
model.eval()  # 切 eval 模式

1.3 train() / eval():模型的「行为模式」

Torch 中有些层在训练和推理时行为不同:

  • nn.Dropout:训练时随机置零;推理时不随机。
  • nn.BatchNorm*:训练时用 batch 统计量,更新 running mean/var;推理时用 running mean/var。

所以:

python 复制代码
model.train()  # 启用训练模式
model.eval()   # 启用推理模式

保存出来的参数在任何模式下都一样,但运行时的行为会因此变得不同。


1.4 引入设备概念:CPU / GPU / 多 GPU

模型的参数其实就是一堆张量,因此也要放到设备上:

python 复制代码
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

x = x.to(device)
y = model(x)

训练出的模型「特征」之一就是:它跟设备绑定不是绝对的,权重本身是独立的,加载后可以 .to(cpu/cuda) 切换设备。

只是某些部署格式(比如导出到某些 C++ 推理引擎)会在导出时就固定计算后端。


二、训练好的模型能做什么?(应用层面)

训练好模型后,最基本的场景有:

  1. 直接推理(inference):给输入,输出预测;
  2. 集成进项目 / 服务:API 服务、批处理、工具脚本等;
  3. 迁移学习 / 微调:在新数据集上继续训练;
  4. 导出为其他框架 / 推理引擎格式:ONNX、TorchScript、TensorRT 等;
  5. 可视化与分析:特征可视化、梯度分析、模型解释等。

下面分别展开,同时结合相关工具。


三、在 Python 中直接使用模型推理

3.1 最简单的推理流程

典型模板:

python 复制代码
model = MLP(in_dim, hidden_dim, out_dim)
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.eval()

with torch.no_grad():
    x = torch.randn(2, in_dim)  # batch_size=2
    y = model(x)
    # 做 argmax、阈值判断等后处理

关键点:

  • model.eval() 关闭 dropout / 切换 BN 行为;
  • torch.no_grad() 关闭 autograd,减少显存和计算开销;
  • 数据类型要匹配:float32 / int64 等。

3.2 将模型封装为一个函数或类接口

在实际工程中,一般会封装推理接口:

python 复制代码
class Predictor:
    def __init__(self, ckpt_path, device="cpu"):
        self.device = torch.device(device)
        self.model = MLP(in_dim, hidden_dim, out_dim).to(self.device)
        self.model.load_state_dict(torch.load(ckpt_path, map_location=self.device))
        self.model.eval()

    @torch.no_grad()
    def __call__(self, x: torch.Tensor):
        x = x.to(self.device)
        logits = self.model(x)
        prob = torch.softmax(logits, dim=-1)
        pred = torch.argmax(prob, dim=-1)
        return {"logits": logits, "prob": prob, "pred": pred}

这样在 web 服务、脚本里调用都比较统一。


四、把模型集成进实际项目:部署 & 服务化

4.1 在 Flask / FastAPI 中部署

假设你写了一个分类器,可以用 FastAPI:

python 复制代码
from fastapi import FastAPI
import torch
from pydantic import BaseModel

app = FastAPI()
predictor = Predictor("model.pt", device="cuda:0")

class Item(BaseModel):
    features: list[float]  # 简单起见

@app.post("/predict")
def predict(item: Item):
    x = torch.tensor(item.features, dtype=torch.float32).unsqueeze(0)
    out = predictor(x)
    return {
        "pred": int(out["pred"][0].item()),
        "prob": out["prob"][0].tolist()
    }

这样模型就变成一个 HTTP API 服务了,可以被前端、其他服务调用。

相关工具:

  • Web 框架:FastAPI、Flask、Django 等;
  • 容器化:Docker + gunicorn/uvicorn;
  • 负载均衡与监控:Kubernetes、Prometheus 等(偏 DevOps)。

4.2 批量离线推理

做离线分析或大规模数据预测:

python 复制代码
from torch.utils.data import DataLoader, Dataset

dataset = ...
loader = DataLoader(dataset, batch_size=256)

model.eval()
all_preds = []
with torch.no_grad():
    for x in loader:
        x = x.to(device)
        logits = model(x)
        pred = logits.argmax(dim=-1)
        all_preds.append(pred.cpu())

all_preds = torch.cat(all_preds)

五、迁移学习 / 微调现有模型

很多时候你不会从头训练模型,而是:

  1. 加载预训练模型(ImageNet、BERT 等);
  2. 替换最后几层;
  3. 在自己的数据集上 fine-tune。

PyTorch 生态支持丰富的预训练模型:

  • torchvision.models:ResNet、ViT、YOLOvX(部分)等图像模型;
  • torchaudio.models:Wav2Vec2 等;
  • torchtext / transformers(Hugging Face):BERT / GPT / LLama 等 NLP/多模态模型。

示例:使用 torchvision 的预训练 ResNet 做分类:

python 复制代码
import torch
import torch.nn as nn
import torchvision.models as models

num_classes = 10

model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
# 替换最后的全连接层
in_features = model.fc.in_features
model.fc = nn.Linear(in_features, num_classes)

迁移学习/微调流程的关键:

  • 冻结部分层的梯度:for p in model.layer1.parameters(): p.requires_grad = False
  • 使用较小学习率;
  • 考虑不同层不同学习率(参数组)。

六、将模型导出到其他格式:TorchScript / ONNX / TensorRT 等

单纯 Python 下推理,依赖解释器和动态计算图,不利于生产环境高效部署;因此 PyTorch 提供了多种导出方式。

6.1 TorchScript:静态图 / C++ 端部署

TorchScript 是 PyTorch 原生的「可序列化、可独立运行」模型格式。

两种方式:

  1. Tracing(跟踪):给定一个样例输入,让 PyTorch 跟踪前向执行过程,导出图;
  2. Scripting(脚本):分析 Python 源码中用到的一部分语法,转换为 TorchScript IR。

简例(Tracing):

python 复制代码
model = MLP(in_dim, hidden_dim, out_dim)
model.load_state_dict(torch.load("model.pt"))
model.eval()

example = torch.randn(1, in_dim)
traced = torch.jit.trace(model, example)
traced.save("model_traced.pt")

在 C++/LibTorch 中可以直接加载 model_traced.pt 进行推理。

适用途径:

  • 部署在 C++ 服务;
  • 较少的 Python 依赖;
  • 可以做一些 graph-level 的优化。

局限:Tracing 对控制流(if/loop)不太鲁棒;Scripting 有一定语法限制。


6.2 ONNX:跨框架 / 推理引擎通用格式

ONNX(Open Neural Network Exchange)是一个开放的模型交换格式,很多推理引擎支持(ONNX Runtime、TensorRT、OpenVINO 等)。

导出示例:

python 复制代码
dummy_input = torch.randn(1, in_dim)
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    opset_version=17,  # 一般用较新的版本
    dynamic_axes={"input": {0: "batch"}, "output": {0: "batch"}}
)

然后可以使用:

  • onnx + onnxruntime 在 Python/C++ 中高性能推理;
  • trtexec / TensorRT 将 ONNX 转为 TensorRT 引擎,在 NVIDIA GPU 上加速。

ONNX 的价值:

  • 跨框架:从 PyTorch 导出后,可以在非 PyTorch 环境跑;
  • 适合端侧或高性能服务部署。

6.3 其他:TensorRT、TFLite、CoreML、OpenVINO 等

一般流程:PyTorch → ONNX → 各种后端工具

  • TensorRT:适合 NVIDIA GPU 高吞吐量低延迟服务;
  • TFLite:移动端 / 嵌入式;
  • CoreML:Apple 生态(iOS/macOS);
  • OpenVINO:Intel 硬件加速。

对你来说关键是理解:

PyTorch 训练好的模型,本质是一个「高层次描述」,可以被编译/转换成各种底层推理引擎格式。


七、分析与理解模型:从黑盒到「半透明盒」

下面进入你问的另外一个重点:「有哪些工具可以对训练出来的模型进行应用和分析?」

大致分几类:

  1. 结构 & 参数层级的分析:看模型长啥样、参数多少;
  2. 训练过程可视化 / 调试:loss、metric、梯度等;
  3. 特征与中间表示可视化:看模型中间层提取了什么特征;
  4. 解释性 / 可解释 AI 工具:Grad-CAM、SHAP、LIME 等;
  5. 性能分析:耗时、显存、运算复杂度、算子 profiling。

7.1 模型结构和参数信息

(1)print(model) / model.named_parameters()

最原始的方式:

python 复制代码
print(model)

for name, param in model.named_parameters():
    print(name, param.shape, param.requires_grad)
(2)torchsummary / torchinfo

用第三方工具给出总结:

bash 复制代码
pip install torchinfo
python 复制代码
from torchinfo import summary

model = MLP(in_dim, hidden_dim, out_dim)
summary(model, input_size=(32, in_dim))  # batch_size=32

输出典型信息:

  • 每层的输出尺寸;
  • 参数数量;
  • 总参数量;
  • 内存占用估算。

7.2 TensorBoard / torch.utils.tensorboard

PyTorch 官方兼容 TensorBoard,可以监控训练过程:

python 复制代码
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="runs/exp1")

for step in range(num_steps):
    loss = ...
    writer.add_scalar("train/loss", loss.item(), step)
    # 也可以 add_histogram, add_image 等
writer.close()

启动:

bash 复制代码
tensorboard --logdir runs

看到:

  • Loss / metric 曲线;
  • 权重/梯度分布变化;
  • 中间层 feature map 可视化(对 CNN 很有帮助)。

7.3 使用 Hook 观察中间层特征

PyTorch 的 hook 机制允许你在前向/后向过程中插入回调,捕获中间 tensor。

示例(捕获某层输出):

python 复制代码
features = {}

def hook_fn(module, input, output):
    features['fc1_out'] = output.detach()

handle = model.fc1.register_forward_hook(hook_fn)

with torch.no_grad():
    x = torch.randn(1, in_dim)
    y = model(x)

handle.remove()

print(features['fc1_out'].shape)

用途:

  • 观察中间层的激活值(看是否梯度消失/爆炸、是否过饱和);
  • 做特征可视化(比如 CNN 的 feature maps);
  • 提取某层编码,用于下游任务(如特征检索)。

7.4 模型解释:Grad-CAM / SHAP / Captum 等

(1)Grad-CAM(图像任务)

Grad-CAM 可以显示 CNN 在预测时「看到了图像的哪些区域」。

PyTorch 生态有很多现成实现,比如 pytorch-grad-cam

bash 复制代码
pip install grad-cam

使用示例(简化):

python 复制代码
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

target_layers = [model.layer4[-1]]  # 最后一个卷积层
cam = GradCAM(model=model, target_layers=target_layers, use_cuda=True)

targets = [ClassifierOutputTarget(class_idx)]
grayscale_cam = cam(input_tensor=input_image, targets=targets)[0, :]
# 然后叠加到原图做可视化

(2)Captum(官方出品的模型可解释性库)

bash 复制代码
pip install captum

支持:

  • Integrated Gradients
  • DeepLift
  • Saliency
  • Layer-wise Relevance Propagation 等

例:对一个分类模型做特征重要性评估(tabular / NLP 都可以用):

python 复制代码
from captum.attr import IntegratedGradients

ig = IntegratedGradients(model)

attr = ig.attribute(inputs=x, target=class_idx)  # 得到每个输入维度的贡献

(3)SHAP / LIME(更偏统计学习、模型无关)

  • SHAP:计算每个特征在预测中的「边际贡献」;
  • LIME:通过局部线性拟合近似模型的行为。

对 PyTorch 模型可以用 shap.DeepExplainer(对深度学习优化过)。

更多适合 tabular/NLP 等应用场景。


7.5 性能与资源分析:Profiler、CUDA 工具

(1)PyTorch Profiler

官方推荐用法:

python 复制代码
import torch
from torch.profiler import profile, record_function, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=torch.profiler.schedule(wait=1, warmup=1, active=3, repeat=2),
    on_trace_ready=torch.profiler.tensorboard_trace_handler("./log"),
    record_shapes=True,
    profile_memory=True
) as prof:
    for step, (x, y) in enumerate(dataloader):
        if step >= 10:
            break
        with record_function("train_step"):
            loss = train_step(x, y)
        prof.step()

然后用 TensorBoard 查看:

  • 每个算子耗时;
  • CPU/GPU 时间;
  • 内存/显存占用;
  • 算子调用栈。

(2)第三方:torchinfo, ptflops

  • ptflops:估算 FLOPs / MACs;
  • torch.utils.benchmark:性能对比基准。
bash 复制代码
pip install ptflops
python 复制代码
from ptflops import get_model_complexity_info

with torch.cuda.device(0):
    macs, params = get_model_complexity_info(model, (3, 224, 224), as_strings=True,
                                             print_per_layer_stat=True)
    print('MACs:', macs, 'Params:', params)

八、生态工具:高层封装与大模型库

除了 PyTorch 原生,还有很多围绕它构建的工具,专门解决「训练 --> 应用/分析」的工程问题。

8.1 PyTorch Lightning / Lightning AI

  • 把训练循环(for epoch in range...)等样板代码封装起来;
  • 内置日志、checkpoint、早停、分布式训练等;
  • 更适合大型项目,结构清晰。

训练出的模型依旧是 nn.Module + checkpoint,只是代码组织更规范,更容易与日志、分析工具集成。


8.2 Hugging Face 生态(Transformers / Datasets / Accelerate)

如果你做 NLP / 多模态 / 大模型:

  • Transformers
    • 提供大量预训练模型(BERT、GPT、CLIP、LLM 等);
    • 模型本身底层还是 PyTorch(或可选 TensorFlow/JAX),但推理接口统一且强大;
    • 内置很多分析/可视化工具(如 Attention visualization)。
  • Accelerate
    • 简化分布式训练与推理;
  • Optimum
    • 专注模型优化与部署(ONNX Runtime、TensorRT、OpenVINO 等)。

它们都把「模型文件 + 配置 + tokenizer 等」打包成可重用资源,用起来比自己从头组更方便。


九、一个简单的「路线图」:从新手到能分析模型

最后给你一个可以按步骤实践的路线:

  1. 理解模型 = 结构 (nn.Module) + 参数 (state_dict)

    • 写一个小 MLP/CNN,自行保存 / 加载;
    • 熟悉 model.eval(), with torch.no_grad()
  2. 写一个干净的推理脚本

    • 将模型封装成 Predictor 类;
    • 支持命令行参数:ckpt 路径、设备、输入文件等。
  3. torchinfo.summary 看一下自己模型的结构

    • 理解每一层输出尺寸和参数量。
  4. 接入 TensorBoard 监控训练过程

    • 记录 loss / acc;
    • 尝试记录某些层的权重直方图和梯度分布。
  5. 写一个简单的 Hook,抓取中间层输出

    • 对 CNN,画几张中间 feature map(把 feature map normalize 到 [0,1]后用 matplotlib.imshow)。
  6. 导出为 ONNX 并用 onnxruntime 做一次推理

    • 比较 PyTorch 和 ONNX Runtime 的输出差异(应在 1e-5 内);
    • 感受从「训练」到「部署格式」的转换。
  7. 尝试一次 Grad-CAM 或 Captum 的 Integrated Gradients

    • 如果做图像:Grad-CAM;
    • 如果做 tabular:用 Captum 对特征做 attribution。
  8. 用 PyTorch Profiler 跑一次 profile

    • 找出时间耗费最大的几个算子层;
    • 调整 batch size、layer 结构,看性能变化。

完成这些,你就对「训练出来的模型是什么」「如何使用」「如何分析」会有一套完整的、可操作的认知。

相关推荐
QuiteCoder2 小时前
深度学习的范式演进、架构前沿与通用人工智能之路
人工智能·深度学习
周名彥2 小时前
### 天脑体系V∞·13824D完全体终极架构与全域落地研究报告 (生物计算与隐私计算融合版)
人工智能·神经网络·去中心化·量子计算·agi
MoonBit月兔2 小时前
年终 Meetup:走进腾讯|AI 原生编程与 Code Agent 实战交流会
大数据·开发语言·人工智能·腾讯云·moonbit
大模型任我行3 小时前
人大:熵引导的LLM有限数据训练
人工智能·语言模型·自然语言处理·论文笔记
weixin_468466853 小时前
YOLOv13结合代码原理详细解析及模型安装与使用
人工智能·深度学习·yolo·计算机视觉·图像识别·目标识别·yolov13
蹦蹦跳跳真可爱5893 小时前
Python----大模型(GPT-2模型训练加速,训练策略)
人工智能·pytorch·python·gpt·embedding
xwill*3 小时前
π∗0.6: a VLA That Learns From Experience
人工智能·pytorch·python
jiayong233 小时前
知识库概念与核心价值01
java·人工智能·spring·知识库
雨轩剑3 小时前
做 AI 功能不难,难的是把 App 发布上架
人工智能·开源软件