ADVANCE Day35

@浙大疏锦行

📘 Day 35 实战作业:最后一公里 ------ 可视化、保存与推理

1. 作业综述

核心目标

完成深度学习项目的闭环。

训练不是终点,应用 才是。我们需要学会查看模型结构,将训练好的模型保存为文件(.pth),并重新加载它来进行预测(推理)。

涉及知识点

  • 结构检视 : print(model)model.named_parameters()
  • 模型持久化 : torch.savetorch.load
  • 状态字典 : 理解 state_dict 的核心作用。
  • 推理模式 : model.eval()torch.no_grad() 的重要性。

场景类比

  • 训练: 像是读书上课,把知识(权重)装进脑子。
  • 保存: 像是把脑子里的知识写成一本"秘籍"(.pth文件)。
  • 加载: 别人拿到秘籍,修炼一下,也拥有了同样的功力。
  • 推理: 用这身功力去解决实际问题(考试/打架)。

步骤 1:模型解剖

场景描述

我们在代码里写了 nn.Linear,但模型内部到底有多少参数?

比如一个 4 -> 10 的全连接层,参数量是 10 × 4 10 \times 4 10×4 (权重) + 10 10 10 (偏置) = 50 50 50 个。

我们需要学会查看这些细节。

任务

  1. 定义并实例化之前的 MLP 模型。
  2. 直接打印模型对象(查看层结构)。
  3. 遍历 named_parameters(),打印每一层的参数形状。
py 复制代码
import torch
import torch.nn as nn

# --- 1. 快速复现模型 (复习) ---
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

model = MLP()

# --- 2. 宏观:看结构 ---
print("=== 模型结构图 ===")
print(model)

# --- 3. 微观:看参数 ---
print("\n=== 参数细节 ===")
total_params = 0
for name, param in model.named_parameters():
    print(f"层: {name} | 形状: {param.shape}")
    # 累加参数数量 (numel = number of elements)
    total_params += param.numel()

print(f"\n🔥 模型总参数量: {total_params}")
# 算一下:fc1 (4*10 + 10) + fc2 (10*3 + 3) = 50 + 33 = 83。对上了吗?
复制代码
=== 模型结构图 ===
MLP(
  (fc1): Linear(in_features=4, out_features=10, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=10, out_features=3, bias=True)
)

=== 参数细节 ===
层: fc1.weight | 形状: torch.Size([10, 4])
层: fc1.bias | 形状: torch.Size([10])
层: fc2.weight | 形状: torch.Size([3, 10])
层: fc2.bias | 形状: torch.Size([3])

🔥 模型总参数量: 83

步骤 2:模型的保存与加载

核心概念

PyTorch 推荐只保存参数 (权重和偏置),而不是保存整个模型对象。

这些参数存储在一个字典里,叫 state_dict

  • 保存torch.save(model.state_dict(), "best_model.pth")
  • 加载 :先实例化一个空模型,然后 model.load_state_dict(...)

任务

  1. 模拟训练(这里直接保存初始模型即可)。
  2. 将模型参数保存到 iris_model.pth
  3. 删除原模型,创建一个新模型,加载参数,验证是否复活。
py 复制代码
import os

# --- 1. 保存 (Save) ---
save_path = "iris_model.pth"
print(f"💾 正在保存模型参数到: {save_path} ...")

# 这里的 state_dict() 就是那本"秘籍"
torch.save(model.state_dict(), save_path)
print("✅ 保存成功!")

# 检查文件是否存在
print(f"文件存在性检查: {os.path.exists(save_path)}")


# --- 2. 加载 (Load) ---
print("\n🔄 正在模拟加载过程...")

# 假设我们在另一台电脑上,首先需要定义同样的模型结构(空壳)
new_model = MLP()

# 此时 new_model 的参数是随机初始化的
# 我们把保存的参数加载进去
# weights_only=True 是为了安全(防止pickle注入),新版本推荐加上
state_dict = torch.load(save_path, weights_only=True)
new_model.load_state_dict(state_dict)

print("✅ 模型加载完毕!")
print("新模型 fc1 偏置的前5个值:", new_model.fc1.bias[:5].detach().numpy())
复制代码
💾 正在保存模型参数到: iris_model.pth ...
✅ 保存成功!
文件存在性检查: True

🔄 正在模拟加载过程...
✅ 模型加载完毕!
新模型 fc1 偏置的前5个值: [ 0.01631749 -0.33050632 -0.18811053 -0.42050672 -0.4696836 ]

步骤 3:推理模式 (Inference)

场景描述

模型训练好并加载后,就可以上线使用了。

在推理(预测)阶段,有两个关键动作:

  1. model.eval(): 告诉模型"我要考试了",关闭 Dropout 和 BatchNorm 等训练专用的层。
  2. torch.no_grad(): 告诉 PyTorch "不需要算梯度",这样能省大量内存并加速。

任务

  1. 准备一条新的测试数据。
  2. 切换到推理模式。
  3. 预测这条数据属于哪一类鸢尾花。
py 复制代码
# --- 1. 准备一条假数据 ---
# 假设有4个特征:花萼长、宽,花瓣长、宽
# 注意:输入必须是 Tensor,且通常需要加一个 batch 维度 (1, 4)
sample_data = torch.tensor([[5.1, 3.5, 1.4, 0.2]]) 

print(f"输入数据形状: {sample_data.shape}")

# --- 2. 推理流程 (标准范式) ---
# A. 切换评估模式
new_model.eval()

# B. 关闭梯度计算上下文
with torch.no_grad():
    # 前向传播
    outputs = new_model(sample_data)
    
    # 获取预测结果
    # outputs 是 (1, 3) 的概率分布(Logits)
    print(f"模型原始输出 (Logits): {outputs}")
    
    # 转化为概率 (Softmax)
    probs = torch.softmax(outputs, dim=1)
    print(f"预测概率: {probs}")
    
    # 取概率最大的类别索引
    predicted_class = torch.argmax(probs, dim=1).item()

# --- 3. 结果解读 ---
class_names = ['Setosa', 'Versicolor', 'Virginica']
print(f"\n🔮 最终预测类别: {predicted_class} -> {class_names[predicted_class]}")
复制代码
输入数据形状: torch.Size([1, 4])
模型原始输出 (Logits): tensor([[ 0.0609, -0.5886, -0.3129]])
预测概率: tensor([[0.4524, 0.2363, 0.3113]])

🔮 最终预测类别: 0 -> Setosa

🎓 Day 35 总结:深度学习基础通关!

恭喜你!完成了从 Numpy 手搓感知机,到 PyTorch 搭建、训练、保存、推理的全过程。

回顾今天的重点

  1. 参数量 : 以后看到论文里的 "10B parameters" (100亿参数),你就知道那是 numel() 累加出来的。
  2. State Dict : 模型文件本质上就是一个 Python 字典,存着 {'fc1.weight': tensor(...), ...}
  3. Eval Mode : 预测时不加 model.eval()no_grad() 是新手最容易犯的错误,可能导致结果不准或显存爆炸。

Next Level (预告) :

从明天开始,我们将不再处理简单的表格数据。

我们将进入 计算机视觉 (Computer Vision) 的世界,去处理真正的 图像数据。卷积神经网络 (CNN)、ResNet、迁移学习......激动人心的旅程才刚刚开始!

准备好你的 GPU,我们明天见! 🚀

相关推荐
LYFlied1 分钟前
AI大时代下前端跨端解决方案的现状与演进路径
前端·人工智能
深蓝电商API3 分钟前
图片验证码识别:pytesseract+opencv入门
人工智能·opencv·计算机视觉·pytesseract
.Katherine௰3 分钟前
AI数字人模拟面试机器人
人工智能
光影少年4 分钟前
AI 前端 / 高级前端
前端·人工智能·状态模式
亓才孓7 分钟前
[Properties]写配置文件前,必须初始化Properties(引用变量没执行有效对象,调用方法会报空指针错误)
开发语言·python
zhangshuang-peta9 分钟前
OpenCode vs Claude Code vs OpenAI Codex:AI编程助手全面对比
人工智能·ai agent·mcp·peta
Bruk.Liu12 分钟前
(LangChain 实战14):基于 ChatMessageHistory 自定义实现对话记忆功能
人工智能·python·langchain·agent
代码改善世界13 分钟前
CANN中的AI算子开发:ops-nn仓库深度解读
人工智能
大江东去浪淘尽千古风流人物27 分钟前
【VLN】VLN(Vision-and-Language Navigation视觉语言导航)算法本质,范式难点及解决方向(1)
人工智能·python·算法
云飞云共享云桌面27 分钟前
高性能图形工作站的资源如何共享给10个SolidWorks研发设计用
linux·运维·服务器·前端·网络·数据库·人工智能