ADVANCE Day35

@浙大疏锦行

📘 Day 35 实战作业:最后一公里 ------ 可视化、保存与推理

1. 作业综述

核心目标

完成深度学习项目的闭环。

训练不是终点,应用 才是。我们需要学会查看模型结构,将训练好的模型保存为文件(.pth),并重新加载它来进行预测(推理)。

涉及知识点

  • 结构检视 : print(model)model.named_parameters()
  • 模型持久化 : torch.savetorch.load
  • 状态字典 : 理解 state_dict 的核心作用。
  • 推理模式 : model.eval()torch.no_grad() 的重要性。

场景类比

  • 训练: 像是读书上课,把知识(权重)装进脑子。
  • 保存: 像是把脑子里的知识写成一本"秘籍"(.pth文件)。
  • 加载: 别人拿到秘籍,修炼一下,也拥有了同样的功力。
  • 推理: 用这身功力去解决实际问题(考试/打架)。

步骤 1:模型解剖

场景描述

我们在代码里写了 nn.Linear,但模型内部到底有多少参数?

比如一个 4 -> 10 的全连接层,参数量是 10 × 4 10 \times 4 10×4 (权重) + 10 10 10 (偏置) = 50 50 50 个。

我们需要学会查看这些细节。

任务

  1. 定义并实例化之前的 MLP 模型。
  2. 直接打印模型对象(查看层结构)。
  3. 遍历 named_parameters(),打印每一层的参数形状。
py 复制代码
import torch
import torch.nn as nn

# --- 1. 快速复现模型 (复习) ---
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.fc1 = nn.Linear(4, 10)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(10, 3)

    def forward(self, x):
        out = self.fc1(x)
        out = self.relu(out)
        out = self.fc2(out)
        return out

model = MLP()

# --- 2. 宏观:看结构 ---
print("=== 模型结构图 ===")
print(model)

# --- 3. 微观:看参数 ---
print("\n=== 参数细节 ===")
total_params = 0
for name, param in model.named_parameters():
    print(f"层: {name} | 形状: {param.shape}")
    # 累加参数数量 (numel = number of elements)
    total_params += param.numel()

print(f"\n🔥 模型总参数量: {total_params}")
# 算一下:fc1 (4*10 + 10) + fc2 (10*3 + 3) = 50 + 33 = 83。对上了吗?
复制代码
=== 模型结构图 ===
MLP(
  (fc1): Linear(in_features=4, out_features=10, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=10, out_features=3, bias=True)
)

=== 参数细节 ===
层: fc1.weight | 形状: torch.Size([10, 4])
层: fc1.bias | 形状: torch.Size([10])
层: fc2.weight | 形状: torch.Size([3, 10])
层: fc2.bias | 形状: torch.Size([3])

🔥 模型总参数量: 83

步骤 2:模型的保存与加载

核心概念

PyTorch 推荐只保存参数 (权重和偏置),而不是保存整个模型对象。

这些参数存储在一个字典里,叫 state_dict

  • 保存torch.save(model.state_dict(), "best_model.pth")
  • 加载 :先实例化一个空模型,然后 model.load_state_dict(...)

任务

  1. 模拟训练(这里直接保存初始模型即可)。
  2. 将模型参数保存到 iris_model.pth
  3. 删除原模型,创建一个新模型,加载参数,验证是否复活。
py 复制代码
import os

# --- 1. 保存 (Save) ---
save_path = "iris_model.pth"
print(f"💾 正在保存模型参数到: {save_path} ...")

# 这里的 state_dict() 就是那本"秘籍"
torch.save(model.state_dict(), save_path)
print("✅ 保存成功!")

# 检查文件是否存在
print(f"文件存在性检查: {os.path.exists(save_path)}")


# --- 2. 加载 (Load) ---
print("\n🔄 正在模拟加载过程...")

# 假设我们在另一台电脑上,首先需要定义同样的模型结构(空壳)
new_model = MLP()

# 此时 new_model 的参数是随机初始化的
# 我们把保存的参数加载进去
# weights_only=True 是为了安全(防止pickle注入),新版本推荐加上
state_dict = torch.load(save_path, weights_only=True)
new_model.load_state_dict(state_dict)

print("✅ 模型加载完毕!")
print("新模型 fc1 偏置的前5个值:", new_model.fc1.bias[:5].detach().numpy())
复制代码
💾 正在保存模型参数到: iris_model.pth ...
✅ 保存成功!
文件存在性检查: True

🔄 正在模拟加载过程...
✅ 模型加载完毕!
新模型 fc1 偏置的前5个值: [ 0.01631749 -0.33050632 -0.18811053 -0.42050672 -0.4696836 ]

步骤 3:推理模式 (Inference)

场景描述

模型训练好并加载后,就可以上线使用了。

在推理(预测)阶段,有两个关键动作:

  1. model.eval(): 告诉模型"我要考试了",关闭 Dropout 和 BatchNorm 等训练专用的层。
  2. torch.no_grad(): 告诉 PyTorch "不需要算梯度",这样能省大量内存并加速。

任务

  1. 准备一条新的测试数据。
  2. 切换到推理模式。
  3. 预测这条数据属于哪一类鸢尾花。
py 复制代码
# --- 1. 准备一条假数据 ---
# 假设有4个特征:花萼长、宽,花瓣长、宽
# 注意:输入必须是 Tensor,且通常需要加一个 batch 维度 (1, 4)
sample_data = torch.tensor([[5.1, 3.5, 1.4, 0.2]]) 

print(f"输入数据形状: {sample_data.shape}")

# --- 2. 推理流程 (标准范式) ---
# A. 切换评估模式
new_model.eval()

# B. 关闭梯度计算上下文
with torch.no_grad():
    # 前向传播
    outputs = new_model(sample_data)
    
    # 获取预测结果
    # outputs 是 (1, 3) 的概率分布(Logits)
    print(f"模型原始输出 (Logits): {outputs}")
    
    # 转化为概率 (Softmax)
    probs = torch.softmax(outputs, dim=1)
    print(f"预测概率: {probs}")
    
    # 取概率最大的类别索引
    predicted_class = torch.argmax(probs, dim=1).item()

# --- 3. 结果解读 ---
class_names = ['Setosa', 'Versicolor', 'Virginica']
print(f"\n🔮 最终预测类别: {predicted_class} -> {class_names[predicted_class]}")
复制代码
输入数据形状: torch.Size([1, 4])
模型原始输出 (Logits): tensor([[ 0.0609, -0.5886, -0.3129]])
预测概率: tensor([[0.4524, 0.2363, 0.3113]])

🔮 最终预测类别: 0 -> Setosa

🎓 Day 35 总结:深度学习基础通关!

恭喜你!完成了从 Numpy 手搓感知机,到 PyTorch 搭建、训练、保存、推理的全过程。

回顾今天的重点

  1. 参数量 : 以后看到论文里的 "10B parameters" (100亿参数),你就知道那是 numel() 累加出来的。
  2. State Dict : 模型文件本质上就是一个 Python 字典,存着 {'fc1.weight': tensor(...), ...}
  3. Eval Mode : 预测时不加 model.eval()no_grad() 是新手最容易犯的错误,可能导致结果不准或显存爆炸。

Next Level (预告) :

从明天开始,我们将不再处理简单的表格数据。

我们将进入 计算机视觉 (Computer Vision) 的世界,去处理真正的 图像数据。卷积神经网络 (CNN)、ResNet、迁移学习......激动人心的旅程才刚刚开始!

准备好你的 GPU,我们明天见! 🚀

相关推荐
来两个炸鸡腿21 小时前
【Datawhale组队学习202601】Base-NLP task04 参数高效微调
人工智能·学习·自然语言处理
YH12312359h21 小时前
YOLO11-LSKNet钢材表面缺陷检测与分类系统详解
人工智能·分类·数据挖掘
理智.62921 小时前
Cursor 中使用 Conda 虚拟环境常见问题与彻底解决方案(踩坑实录)
python·conda
aiguangyuan1 天前
中文分词与文本分析实战指南
人工智能·python·nlp
小二·1 天前
Python Web 开发进阶实战:量子机器学习实验平台 —— 在 Flask + Vue 中集成 Qiskit 构建混合量子-经典 AI 应用
前端·人工智能·python
AC赳赳老秦1 天前
Confluence + DeepSeek:构建自动化、智能化的企业知识库文档生成与维护体系
大数据·运维·人工智能·自动化·jenkins·数据库架构·deepseek
DS随心转小程序1 天前
ChatGPT和Gemini公式
人工智能·chatgpt·aigc·word·豆包·deepseek·ds随心转
one____dream1 天前
【网安】pwn-ret2shellcode
python·安全·网络安全·网络攻击模型
王然-HUDDM1 天前
技术领跑:HUDDM-7D系统L4级功能安全预研验证
人工智能·嵌入式硬件·安全·车载系统·汽车
Light601 天前
亚马逊“沃尔玛化”战略大转身:一场生鲜零售的自我革命与中国启示
人工智能·零售·数字化转型·亚马逊·新零售·沃尔玛·生鲜供应链