一、需求分析
小明创办了一家手机公司,他不知道如何估算手机产品的价格。为了解决这个问题,他收集了多家公司的手机销售数据。该数据为二手手机的各个性能的数据,最后根据这些性能得到4个价格区间,作为这些二手手机售出的价格区间。主要包括:

我们需要帮助小明找出手机的功能(例如:RAM等)与其售价之间的某种关系。我们可以使用机器学习的方法来解决这个问题,也可以构建一个全连接的网络。
需要注意的是: 在这个问题中,我们不需要预测实际价格,而是一个价格范围,它的范围使用 0、1、2、3 来表示,所以该问题也是一个分类问题。接下来我们还是按照四个步骤来完成这个任务:
- 准备训练集数据
- 构建要使用的模型
- 模型训练
- 模型预测评估
二、构建数据集
python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
from torchsummary import summary
import torch.optim as optim
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import os
os.chdir(r'F:\Pycharm\works-space\神经网络')
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # ←←← 关键!放在最前面(解决报错)
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"] # 设置显示中文字体
mpl.rcParams["axes.unicode_minus"] = False # 设置正常显示符号
torch.manual_seed(66)
def create_dataset():
data_df = pd.read_csv(r'data/手机价格预测.csv')
# 目标值可用值: [0, 1, 2, 3]
print(f'data_df.shape = {data_df.shape}') # (2000, 21)
# print(data_df.head())
# battery_power blue clock_speed ... touch_screen wifi price_range
# 0 842 0 2.2 ... 0 1 1
# 1 1021 1 0.5 ... 1 0 2
# 2 563 1 0.5 ... 1 0 2
x, y = data_df.iloc[: , : -1], data_df.iloc[: , -1] # x: 所有特征列 y: 所有目标值
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=88)
# 需要把 DataFrame 转成张量
x_train = torch.tensor(data=x_train.to_numpy(dtype=np.float32), dtype=torch.float32)
x_test = torch.tensor(data=x_test.to_numpy(dtype=np.float32), dtype=torch.float32)
y_train = torch.tensor(data=y_train.to_numpy(dtype=np.float32), dtype=torch.long) # 标签, CrossEntropyLoss需要的是64位整数
y_test = torch.tensor(data=y_test.to_numpy(dtype=np.float32), dtype=torch.long) # 标签, CrossEntropyLoss需要的是64位整数
# 创建数据集: x_train 和 y_train 对应
# 创建数据集: x_test 和 y_test 对应
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
print(f'特征总数 = {x.shape[1]}') # (2000, 20) 所以 x.shape[1] = 20 = 特征总数
# print(len(y.value_counts())) # 输出 4 : 目标值的总类别数
print(f'总类别数 = {len(y.unique())}') # 输出 4 : 目标值的总类别数 【 unique去重: DataFrame 没有 unique(),只有 Series 有】
return train_dataset, test_dataset, x.shape[1], len(y.unique())
if __name__ == '__main__':
train_dataset, test_dataset, feature_count, target_category_count = create_dataset()
三、构建分类网络模型
自己写的:
python
class PhonePriceModel(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
# 隐藏层1 + 批量归一正则化 + Relu + dropout
self.linear1 = nn.Linear(in_features=in_features, out_features=128)
self.bn1 = nn.BatchNorm1d(num_features=128, track_running_stats=True) # 核心原则:放在"线性变换之后,非线性激活之前"
self.relu1 = nn.ReLU() # 可以不加这个, 因为 ReLu 函数都是一样的, 但是为了更好展示数据流向关系, 还是加上
self.dropout1 = nn.Dropout(p=0.3) # Dropout 应用于"无界"或"稀疏"的激活输出之后,尤其是那些容易导致神经元强依赖的非线性层之后。ReLU 及其变体之后(强烈推荐)
# (注意:早期有人把 Dropout 放在 BN 前,但现在普遍认为放在 ReLU 后更合理.
# 因为 BN 输出已经是归一化的,再经 ReLU 产生稀疏激活,此时加 Dropout 能有效打破神经元依赖)
# 隐藏层1 + 批量归一正则化 + Relu + dropout
self.linear2 = nn.Linear(in_features=128, out_features=256)
self.bn2 = nn.BatchNorm1d(num_features=256, track_running_stats=True) # 核心原则:放在"线性变换之后,非线性激活之前"
self.relu2 = nn.ReLU() # 可以不加这个, 因为 ReLu 函数都是一样的, 但是为了更好展示数据流向关系, 还是加上
self.dropout2 = nn.Dropout(p=0.5) # 越深 dropout 比例越高"的经验原则
# 输出
self.output_linear = nn.Linear(in_features=256, out_features=out_features) # 【 output n.输出 】
def forward(self, x):
# 隐藏层1 + 批量归一正则化 + Relu + dropout
y = self.linear1(x)
x = self.bn1(y)
active = self.relu1(x)
x_drop = self.dropout1(active)
# 隐藏层2 + 批量归一正则化 + Relu + dropout
y = self.linear2(x_drop)
x = self.bn2(y)
active = self.relu2(x)
x_drop = self.dropout2(active)
# 隐藏层3 + 输出
output = self.output_linear(x_drop) # 由于多分类任务, 理应使用 softmax, 但 CrossEntropyLoss 自带 softmax, 所以输出时不用 softmax
return output
# 别用 'test_model' 这个以 'test' 开头的名字
# 这个错误是因为 PyCharm 使用 pytest 来运行代码,而 pytest 会将所有以 test_ 开头的函数识别为测试函数。
# 测试函数不应该有参数,除非这些参数是 pytest fixtures。
def show_model(train_dataset, test_dataset, feature_count, target_category_count):
my_model = PhonePriceModel(feature_count, target_category_count)
summary(model=my_model, input_size=(feature_count, ))
# ----------------------------------------------------------------
# Layer (type) Output Shape Param #
# ================================================================
# Linear-1 [-1, 128] 2,688 计算: 128 * 20 + 128 = 2688
# BatchNorm1d-2 [-1, 128] 256
# ReLU-3 [-1, 128] 0
# Dropout-4 [-1, 128] 0
# Linear-5 [-1, 256] 33,024
# BatchNorm1d-6 [-1, 256] 512
# ReLU-7 [-1, 256] 0
# Dropout-8 [-1, 256] 0
# Linear-9 [-1, 4] 1,028
# ================================================================
# Total params: 37,508
# Trainable params: 37,508
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.00
# Forward/backward pass size (MB): 0.01
# Params size (MB): 0.14
# Estimated Total Size (MB): 0.15
# ----------------------------------------------------------------
if __name__ == '__main__':
train_dataset, test_dataset, feature_count, target_category_count = create_dataset()
show_model(train_dataset, test_dataset, feature_count, target_category_count)
四、模型训练
1、nn.Module.state_dict()
.state_dict()是nn.Module类的一个方法。- 它返回一个 Python 字典(dict) ,里面包含了模型中所有可学习的参数(learnable parameters) 。
- 比如:每一层的权重(weight)、偏置(bias)
- 还包括 BatchNorm 层的 running_mean、running_var 等状态(如果你用了
track_running_stats=True)
📌 重点:
.state_dict()不包含模型结构!只包含参数值!
举个例子,假设模型有:
linear1.weight(形状 [128, 20])linear1.bias(形状 [128])bn1.weight,bn1.bias,bn1.running_mean,bn1.running_var- ......等等
那么 my_model.state_dict() 就是一个字典,键是这些参数的名字,值是对应的张量:
python
{
'linear1.weight': tensor([[...], [...], ...]),
'linear1.bias': tensor([...]),
'bn1.weight': tensor([...]),
'bn1.bias': tensor([...]),
'bn1.running_mean': tensor([...]),
...
}
✅ 所以 my_model.state_dict() 的作用是:提取模型当前的所有参数值,打包成一个字典。
2、torch.save(obj, filepath)
这是 PyTorch 提供的一个 保存对象到磁盘 的函数。
-
功能:把任意 Python 对象(比如张量、模型参数、字典等)以二进制形式保存到指定路径。
-
语法:
torch.save(obj, filepath)obj:要保存的对象(必须是可序列化的)filepath:保存的文件路径(字符串)
✅ 所以这行代码的意思是:把某个对象保存到
'model/phone.pth'这个文件里。
但这个"某个对象"到底是什么呢?------就是中间那个 my_model.state_dict()。
python
torch.save(my_model.state_dict(), r'model/phone.pth')
含义:
将训练好的模型
my_model的所有参数(权重、偏置等)保存到当前目录下的model/phone.pth文件中。
为什么这么做?
- 不保存整个模型对象(因为模型类定义可能变化,依赖环境复杂)
- 只保存参数(轻量、通用、安全)
- 后续加载时,先定义相同的模型结构,再用
.load_state_dict()把参数"灌"进去即可
✅ 保存文件的后缀名
torch.save(obj, filepath) 本身不限制文件后缀名 ,你可以用任意后缀(比如 .txt、.bin、.model),但 PyTorch 社区和官方推荐使用以下几种标准后缀,以表达语义和便于协作:
✅ 推荐的常用后缀名:
| 后缀名 | 含义说明 |
|---|---|
.pth |
最常见!是 "PyTorch " 的缩写,广泛用于保存模型参数(state_dict)或完整模型。 |
.pt |
官方文档常用,简洁,含义同 .pth。PyTorch 官方教程和 torchvision 模型多用此格式。 |
.ckpt |
常见于训练中间检查点(checkpoint),尤其在 PyTorch Lightning 中流行。 |
📌 举例:
pythontorch.save(model.state_dict(), 'model.pth') # 👍 推荐 torch.save(model.state_dict(), 'model.pt') # 👍 也推荐 torch.save(checkpoint, 'epoch_10.ckpt') # 👍 用于训练中断恢复
❌ 不推荐的做法:
- 使用
.h5(这是 TensorFlow/Keras 的格式) - 使用
.pkl(虽然torch.save底层用的是 pickle,但语义不明确) - 不加后缀(如
'model'),不利于识别文件类型
🔍 技术细节补充:
torch.save()默认使用 Python 的pickle协议序列化对象(也可以选其他后端,但默认是 pickle)。- 文件内容是二进制的,不是文本,所以不能直接打开看。
- 后缀名不影响功能 ,只影响可读性和约定 。
→ 即:model.pth和model.txt在torch.load()时都能正常加载,只要内容是torch.save生成的。
✅ 最佳实践建议:
如果你保存的是 模型参数(state_dict),推荐用:
python
torch.save(model.state_dict(), 'phone_model.pth')
如果你保存的是 完整模型(不推荐):
python
torch.save(model, 'phone_model_full.pt') # 包含结构 + 参数,但依赖类定义
如果保存的是 训练检查点(含优化器、epoch 等):
python
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'loss': loss,
}, 'checkpoint.ckpt')
总结:
虽然
torch.save()对后缀名没有强制要求,但强烈建议使用.pth或.pt来表示 PyTorch 模型/参数文件,这是社区通用规范,清晰、专业、不易混淆。
3、不同模式(cpu、gpu)下的 torch.save
✅ 核心原则(先记住)
无论当前是 CPU 还是 GPU 模式,保存模型时,都应把参数转为 CPU 张量再保存。
这样生成的
.pth文件才是通用、可移植、无设备依赖的。
🧾 情况 1:当前在 CPU 模式 下保存
✅ 推荐代码:
python
# 假设 model 当前就在 CPU 上
torch.save(model.state_dict(), 'model.pth')
🔍 说明:
- 因为
model本来就在 CPU,所以state_dict()中的张量都是 CPU 张量; - 直接保存即可,不需要额外操作;
- 生成的文件可以在任何设备(CPU/GPU)上加载。
✅ 安全、简洁、正确。
🧾 情况 2:当前在 GPU 模式 下保存
❌ 错误写法(不要这样!):
python
# 危险!保存的是 GPU 张量
torch.save(model.state_dict(), 'model.pth') # ⚠️ 不推荐
→ 这样保存的模型只能在有 GPU 的环境加载,否则会报错!
✅ 正确写法(推荐):
方法 A:临时转 CPU 保存(不改变原模型)
python
# 把 state_dict 中每个张量转成 CPU,再保存
torch.save({k: v.cpu() for k, v in model.state_dict().items()}, 'model.pth')
✅ 优点:
- 原
model仍然留在 GPU 上,可以继续训练; - 保存的是纯 CPU 参数,通用性强。
方法 B:用 .cpu() 保存(会移动原模型)
python
# 注意:这会把 model 本身移到 CPU!
torch.save(model.cpu().state_dict(), 'model.pth')
# 如果之后还要训练,记得移回 GPU
model.to(device) # device 是你原来的设备,如 'cuda'
✅ 优点:代码简短。
⚠️ 注意:model.cpu() 是 in-place 操作,会改变原模型的设备位置!
📌 最佳实践总结
| 当前设备 | 推荐保存代码 |
|---|---|
| CPU | torch.save(model.state_dict(), 'model.pth') |
| GPU | torch.save({k: v.cpu() for k, v in model.state_dict().items()}, 'model.pth') torch.save(model.cpu().state_dict(), 'model.pth') |
💡 这个 GPU 写法既安全又不会影响原模型,强烈推荐!
🔁 补充:统一写法(不管当前是什么设备)
如果你不想判断当前是 CPU 还是 GPU,可以用一个通用函数:
python
def save_model_cpu(model, path):
"""将模型参数转为 CPU 后保存,适用于任何设备"""
cpu_state_dict = {k: v.cpu() for k, v in model.state_dict().items()}
torch.save(cpu_state_dict, path)
# 使用
save_model_cpu(model, 'model.pth')
这样,无论 model 在 CPU 还是 GPU,都能安全保存为通用格式。
✅ 验证:加载时是否方便?
用上述方法保存的 model.pth,加载时极其简单:
python
# 创建模型结构(默认在 CPU)
model = MyModel()
# 直接加载(无需 map_location!)
model.load_state_dict(torch.load('model.pth'))
# 按需移到目标设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
✅ 完美兼容所有环境!
🎯 终极结论
| 场景 | 你应该写的代码 |
|---|---|
| CPU 下保存 | torch.save(model.state_dict(), 'xxx.pth') |
| GPU 下保存 | torch.save({k: v.cpu() for k, v in model.state_dict().items()}, 'xxx.pth') |
🌟 记住:保存的是"参数值",不是"设备"。把参数存成 CPU 格式,天下通吃!
这样你就再也不用担心"模型在哪训练的""能不能在别人电脑跑"这类问题了!
4、torch.load(f='model/phone.pth')
功能:
- 从磁盘加载一个用
torch.save()保存的对象。 - 在你这个场景中,之前保存的是
my_model.state_dict()(一个字典),所以这里加载回来的就是一个字典。
参数说明:
f='model/phone.pth':指定要加载的文件路径。f是 "file" 的缩写,也可以直接写成位置参数:torch.load('model/phone.pth')
- 返回值:当初保存的那个对象(这里是
state_dict字典)
✅ 所以这一步的结果是:
python
state_dict = {
'linear1.weight': tensor(...),
'linear1.bias': tensor(...),
'bn1.weight': tensor(...),
...
}
5、不同模式(cpu、gpu)下的 torch.load
torch.load在不同保存情况和不同运行环境下,应该怎么写?
✅ 前提:模型是怎么保存的?(决定加载方式)
PyTorch 模型保存时,参数张量会携带设备信息。所以加载时是否出错,取决于:
- 模型是在 CPU 还是 GPU 上保存的?
- 你现在是在 CPU 还是 GPU 环境下加载?
我们分四种组合来看(✅ 表示安全,❌ 表示可能报错):
| 保存设备 | 当前加载环境 | 直接 torch.load() 是否安全? |
|---|---|---|
| CPU | CPU | ✅ 安全 |
| CPU | GPU | ✅ 安全(PyTorch 自动转) |
| GPU | GPU | ✅ 安全(同设备) |
| GPU | CPU | ❌ 会报错! |
🔥 唯一危险的情况:在 CPU 环境加载 GPU 保存的模型,且没用
map_location。
📥 所以,torch.load 的写法分两类
✅ 类型一:你确定模型是 CPU 格式保存的(推荐做法)
比如你自己保存时用了:
python
torch.save({k: v.cpu() for k, v in model.state_dict().items()}, 'model.pth')
# 或
torch.save(model.cpu().state_dict(), 'model.pth')
➤ 加载代码(最简单):
python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel() # 结构在 CPU
state_dict = torch.load('model.pth') # ✅ 直接加载,无需 map_location
model.load_state_dict(state_dict)
model.to(device) # 移到目标设备(CPU 或 GPU)
✅ 优点:代码简洁,兼容所有环境。
✅ 类型二:你不确定模型是 CPU 还是 GPU 保存的(通用安全写法)
这是最推荐的工业级写法,适用于加载任何来源的模型(自己训练的、别人给的、网上下载的)。
➤ 加载代码(万能模板):
python
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = MyModel().to(device) # 先把模型放到目标设备
# 关键:用 map_location 把参数加载到同一设备
state_dict = torch.load('model.pth', map_location=device)
model.load_state_dict(state_dict)
✅ 这样写:
- 如果模型是 CPU 保存的 → 自动转到
device;- 如果是 GPU 保存的 → 也转到
device(即使当前是 CPU 也不会崩);- 永远安全,永远不出错!
map_location是 PyTorch 中torch.load()函数的一个关键参数,它的作用是:指定在加载模型(或其他张量数据)时,将原本保存在某个设备(如 GPU)上的张量,"映射"到你当前希望使用的设备(如 CPU 或另一块 GPU)上。
🧠 为什么需要它?
PyTorch 在保存模型(比如
torch.save(model.state_dict(), 'model.pth'))时,会把每个张量所在的设备信息(device)也一起存进去。
- 如果你在 GPU 上训练并保存 ,那么
.pth文件里的张量就标记为cuda:0。- 如果你在 CPU 上训练并保存 ,张量就标记为
cpu。当你用
torch.load()加载时,PyTorch 默认会尝试把张量放回它原来所在的设备。👉 问题来了:
如果你现在没有 GPU(比如在普通笔记本上),却去加载一个 GPU 保存的模型,就会报错!
python# 报错示例(在 CPU 环境加载 GPU 模型) torch.load('gpu_model.pth') # RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False.
✅
map_location就是用来解决这个问题的!它告诉 PyTorch:
"别管这个模型原来在哪,统统加载到我指定的设备上!"
🔧 常见用法
- 加载到 CPU(最常用)
pythontorch.load('model.pth', map_location='cpu')
- 加载到当前可用设备(推荐写法)
pythondevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu') torch.load('model.pth', map_location=device)
- 从 GPU0 加载到 GPU1
pythontorch.load('model.pth', map_location={'cuda:0': 'cuda:1'})
- 强制全部转 CPU(即使有 GPU)
pythontorch.load('model.pth', map_location=torch.device('cpu'))
✅ 举个完整例子
pythonimport torch # 定义模型结构 model = MyModel() # 指定目标设备 device = torch.device('cpu') # 或 'cuda' # 安全加载:无论 model.pth 是 CPU 还是 GPU 保存的,都能正确加载到 device 上 state_dict = torch.load('model.pth', map_location=device) # 加载参数 model.load_state_dict(state_dict) # 把模型移到目标设备(其实参数已经在 device 上了,这步可省略,但习惯保留) model.to(device)
📌 总结一句话
map_location就是"加载时的设备重定向器"------它确保模型能从任何保存环境,安全加载到你当前的运行设备上,避免因 GPU/CPU 不匹配而崩溃。✅ 最佳实践:只要用
torch.load(),就加上map_location=device!
🧪 举个实际例子
场景:你在 Colab(GPU)训练,保存时忘了 .cpu()
python
# 错误保存(GPU 格式)
torch.save(model.state_dict(), 'bad_model.pth') # 参数在 cuda:0
现在你想在本地笔记本(只有 CPU)加载它:
❌ 错误加载(会崩溃):
python
model = MyModel()
params = torch.load('bad_model.pth') # RuntimeError!
✅ 正确加载(用 map_location):
python
device = torch.device('cpu')
model = MyModel().to(device)
params = torch.load('bad_model.pth', map_location=device) # ✅ 成功!
model.load_state_dict(params)
🛠 高级技巧:自动适配 + 安全加载(PyTorch ≥2.0)
python
def load_model_safe(model_class, path):
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model_class().to(device)
# 安全模式:只加载张量,防恶意代码
try:
state_dict = torch.load(path, map_location=device, weights_only=True)
except TypeError: # 兼容旧版 PyTorch
state_dict = torch.load(path, map_location=device)
model.load_state_dict(state_dict)
return model
# 使用
model = load_model_safe(MyModel, 'any_model.pth')
📌 终极总结:怎么写 torch.load?
| 需求 | 推荐写法 |
|---|---|
| 自己保存的模型(已转 CPU) | torch.load('model.pth')(无需 map_location) |
| 加载任意来源的模型(最安全) | torch.load('model.pth', map_location=device) |
| 追求最高安全性(防恶意模型) | torch.load(..., map_location=device, weights_only=True) |
💡 黄金法则
只要加上
map_location=device,torch.load就永远不会因为设备问题而失败。所以,养成习惯:永远写
map_location=device!
这样代码就能:
- 在服务器 GPU 训练 → 本地 CPU 测试;
- 在手机端部署;
- 分享给任何人使用;
- 跨平台、跨设备、零兼容问题!
6、nn.Module.load_state_dict(...)
谁调用的?
- 是你新创建的模型对象
model调用的。 - 这个
model必须是和当初保存时结构完全一致 的PhonePriceModel实例。
功能:
- 将传入的
state_dict(参数字典)"注入"到当前模型的对应层中。 - PyTorch 会根据字典的 key(如
'linear1.weight')自动匹配模型内部的参数名,并赋值。
注意事项:
- 模型结构必须一致!如果现在模型少了一层,或多了一个参数,就会报错。
- 默认情况下,如果
state_dict中有模型没有的 key,或者模型有但state_dict没有,会报错。- 可通过
strict=False放宽要求(但一般不建议)。
- 可通过
✅ 所以这一步的作用是:把硬盘上存的参数,重新装回模型里。
7、保存模型的张量,再加载:完整流程回顾
完整流程回顾(保存 → 加载)
- 训练并保存(你之前的代码):
python
my_model = PhonePriceModel(feature_count, target_category_count)
# ...训练过程...
torch.save(my_model.state_dict(), 'model/phone.pth') # 只存参数
- 后续加载(比如在另一个脚本或重启后):
python
# 第一步:重建模型结构(必须和训练时一模一样!)
model = PhonePriceModel(feature_count, target_category_count)
# 第二步:加载参数
model.load_state_dict(torch.load('model/phone.pth'))
# 第三步:设为评估模式(关闭 dropout、batchnorm 的训练行为)
model.eval()
✅ 此时
model就和训练结束时的my_model完全一样了!
四、常见错误 & 建议
❌ 错误1:没创建模型就直接 load_state_dict
python
# 错!model 未定义
model.load_state_dict(torch.load('model/phone.pth'))
✅ 正确:先实例化模型。
❌ 错误2:模型结构变了(比如改了网络层数)
- 保存时是 3 层,加载时是 4 层 → key 不匹配 → 报错。
- 解决:确保
PhonePriceModel的代码没变,或者版本管理好。
✅ 建议:加上 map_location(尤其在 CPU/GPU 切换时)
如果你在 GPU 上训练,但在 CPU 上加载,需要指定设备:
python
model.load_state_dict(
torch.load('model/phone.pth', map_location=torch.device('cpu'))
)
否则可能报错:Attempting to deserialize object on a CUDA device...
五、总结(一句话)
model.load_state_dict(torch.load('model/phone.pth'))的意思是:
从phone.pth文件中读取模型参数,并把这些参数加载到已创建好的model对象中,使其恢复到保存时的状态。
这是 PyTorch 推荐的标准模型加载方式 ------结构 + 参数分离,灵活又安全!
8、代码 & 趋势 & 解释
python
def model_train(train_dataset, feature_count, target_category_count):
my_model = PhonePriceModel(in_features=feature_count, out_features=target_category_count) # 创建模型
my_model.train() # 训练阶段
# 查看模型参数:
# print('模型参数: ')
# for name, parameter in my_model.named_parameters():
# print(name, parameter)
# 数据加载器
train_dataloader = DataLoader(dataset=train_dataset, batch_size=50, shuffle=True)
# 优化器
optimizer = optim.Adam(params=my_model.parameters(), lr=0.001, betas=(0.9, 0.99))
# 损失函数
criterion = nn.CrossEntropyLoss()
epochs = 50
loss_sum_list = [] # 记录每个 batch 里的总损失, 用于画图
loss_avg_list = [] # 记录每个 batch 里的平均损失, 用于画图
for epoch in range(epochs):
print(f'第 {epoch + 1} 次 epoch: ')
loss_sum = 0.0 # 统计当前 epoch 的总损失
batch_num = 0 # 统计用了多少个 batch, 用计算平均 损失
for x_train, y_true in train_dataloader:
optimizer.zero_grad() # 1. 清零梯度
y_predict = my_model(x_train) # 2. 前向传播
loss = criterion(y_predict, y_true) # 3. 计算损失值
loss.backward() # 4. 反向传播
optimizer.step() # 5. 梯度更新
loss_sum += loss.item() # 只计算数值, 用 item()
batch_num += 1
loss_sum_list.append(loss_sum) # 总损失 画折线图
avg_loss = loss_sum / batch_num
loss_avg_list.append(avg_loss) # 评价损失 画折线图
print(f'总损失值 = {loss_sum}') # loss 为什么会波动
# 第 1 次 epoch:
# 总损失值 = 24.427303969860077
# 第 2 次 epoch:
# 总损失值 = 14.080154180526733
# 第 3 次 epoch:
# 总损失值 = 11.578682631254196
# 第 4 次 epoch:
# 总损失值 = 11.567044869065285
# 第 5 次 epoch:
# 总损失值 = 11.530409216880798
# ...
# 第 48 次 epoch:
# 总损失值 = 7.979451455175877
# 第 49 次 epoch:
# 总损失值 = 9.96169776469469
# 第 50 次 epoch:
# 总损失值 = 9.308996006846428
# 总损失
plt.style.use('fivethirtyeight')
plt.figure(figsize=(13, 10))
plt.plot(range(1, epochs + 1), loss_sum_list)
plt.title('每个 epoch 总损失趋势')
plt.xlabel('epoch')
plt.ylabel('loss值')
plt.show()
# 平均损失: 平均损失 和 总损失 仅仅差个系数, 就是简单的多乘个 batch 而已, 和 总损失 的趋势是一模一样的
plt.style.use('fivethirtyeight')
plt.figure(figsize=(13, 10))
plt.plot(range(1, epochs + 1), loss_avg_list)
plt.title('每个 epoch 平均损失趋势')
plt.xlabel('epoch')
plt.ylabel('loss值')
plt.show()
all_params = my_model.state_dict()
torch.save(all_params, r'model/phone.pth') # 把模型里的可学习参数全都保存起来
if __name__ == '__main__':
train_dataset, test_dataset, feature_count, target_category_count = create_dataset()
show_model(train_dataset, test_dataset, feature_count, target_category_count)
# 模型参数保存好后, 就可以不调用这个函数重新训练了
model_train(train_dataset, feature_count, target_category_count)

"为什么 loss 会波动?"
从训练损失图来看,loss 在前几轮快速下降后,进入一个持续的、小幅上下震荡的过程(比如第10~50轮之间),甚至在第49轮出现明显反弹。
这其实是深度学习中非常常见且完全正常的现象。下面我们来系统解释:
✅ 一、根本原因:随机梯度下降(SGD)的本质是"近似"优化
你在用的是 Adam 优化器,它属于 基于 mini-batch 的随机梯度下降(SGD)变种。
🔹 每个 epoch 的 loss 是多个 batch 的平均值
python
for x_train, y_true in train_dataloader:
...
loss = criterion(y_predict, y_true)
loss_sum += loss.item()
👉 所以你记录的 loss_sum 实际上是:
当前 epoch 内所有 batch 的 loss 总和
而每个 batch 的数据是随机采样 的(因为 shuffle=True),所以:
- 不同 batch 的样本分布不同
- 模型对这些 batch 的预测误差也不同
- 因此每个 batch 的 loss 也会有差异
➡️ 即使模型已经收敛,batch-level 的 loss 仍然会有波动 ,导致最终的 epoch_loss 波动。
✅ 二、具体分析 loss 曲线
观察图表:
| 阶段 | 行为 | 原因 |
|---|---|---|
| Epoch 1~5 | 快速下降 | 模型刚开始学习,参数更新大,损失迅速降低 |
| Epoch 6~30 | 小幅震荡,缓慢下降 | 接近最优解区域,梯度变小,更新更精细 |
| Epoch 30~48 | 波动加剧,偶有回升 | 可能遇到局部困难样本或噪声 |
| Epoch 49 | 明显反弹 | 这是最典型的"过拟合迹象 "或"优化不稳定" |
✅ 三、为什么会出现"反弹"?可能的原因
📌 1. 学习率过高
- Adam 的默认学习率是 0.001,对你这个任务来说可能是偏高的。
- 当模型接近最优时,大的 step size 会导致"来回摆动",就像你在山顶附近走得太快会跌下去一样。
✅ 建议 :尝试把 lr=0.0001 或使用学习率调度器(如 ReduceLROnPlateau)
📌 2. Batch Size 较小(50)
- Batch Size 越小,每个 batch 的统计信息越不准确 → loss 波动越大
- 你用了
batch_size=50,对于 2000 样本的数据集,相当于每轮只有 40 个 batch,容易受个别难样本影响
✅ 建议:可以试试增大到 100 或 200,看是否平滑一些
📌 3. 数据本身存在噪声或类别不平衡
虽然你没提,但我们可以推测:
- 数据中可能存在某些手机特征与价格关系较弱
- 某些类别样本较少(比如 price_range=0 的样本很少)
- 导致模型在某些 batch 上表现差,从而拉高 loss
✅ 建议:检查各类别的样本数量分布
📌 4. Dropout 和 BN 的作用
- Dropout 在训练时是随机丢弃神经元的,每次 forward 都不一样
- BN 使用的是 batch 统计量(均值/方差),也会随 batch 改变
➡️ 这两个机制都会让模型在不同 batch 中行为略有差异 → loss 波动
✅ 四、如何判断这是"正常"还是"问题"?
| 判断标准 | 是否正常? | 说明 |
|---|---|---|
| ✅ Loss 整体趋势向下 | ✔️ 正常 | 即使有波动,只要长期下降就是好的 |
| ✅ 波动幅度不大(<10%) | ✔️ 正常 | 说明模型稳定 |
| ❌ 多次大幅反弹(如 +20%) | ⚠️ 警告 | 可能是学习率太高或过拟合 |
| ❌ 最终 loss 不降反升 | ⚠️ 警告 | 可能过拟合 |
👉 图中,整体是下降的 ,只是后期波动大,说明模型正在学习 ,但不够稳定。
✅ 五、改进方案(可选)
✅ 方案1:加学习率调度器
python
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=5, verbose=True)
# 在每个 epoch 后调用:
scheduler.step(avg_loss)
这样当 loss 停滞 5 轮时,自动将 lr 减半,避免震荡。
✅ 方案2:增加 batch size
python
train_dataloader = DataLoader(dataset=train_dataset, batch_size=100, shuffle=True)
减少 batch 数量,提高每个 batch 的代表性。
✅ 方案3:可视化 per-batch loss(进阶)
你可以打印每个 batch 的 loss,看看是不是某些特定 batch 导致了波动。
✅ 六、总结:loss 波动 ≠ 错误!
| 说法 | 是否正确 | 解释 |
|---|---|---|
| ✅ loss 波动是正常的 | ✔️ | 因为 SGD 是随机的 |
| ✅ 波动不代表模型不好 | ✔️ | 只要趋势向下即可 |
| ✅ 第49轮反弹可能是过拟合 | ⚠️ | 建议加入早停或降低学习率 |
| ✅ 应该关注"平均 loss"而非"单点" | ✔️ | 平滑后的曲线更能反映真实进展 |
🎯 最终结论:
** loss 波动是完全正常的!**
它反映了模型在 mini-batch 上的随机性,以及 Adam 优化器的动态特性。
只要:
- 整体趋势是下降的
- 最终测试准确率达到 95%+
- 没有严重发散
👉 就说明模型训练成功了!
现在看到的"波动"不是 bug,而是深度学习的"心跳" ❤️
五、模型评估
python
def estimate_model(test_dataset, feature_count, target_category_count):
my_model = PhonePriceModel(in_features=feature_count, out_features=target_category_count)
all_params = torch.load(r'model/phone.pth') # 读取保存的模型参数
my_model.load_state_dict(all_params) # 把参数注入到模型里
my_model.eval() # 切换到测试模式
# 数据加载器
test_dataloader = DataLoader(dataset=test_dataset, batch_size=50, shuffle=False)
# epochs = 50 # 注意, 这是测试模式, 是没有 epochs 的
right_cnt = 0 # 预测正确的个数
for x_test, y_test in test_dataloader:
y_logits = my_model(x_test) # 得到的是 logits得分
print(f'原始 logits: \n{y_logits}')
# tensor([[-11.5542, -1.9179, 2.5488, -1.2692],
# [ 2.1304, 1.7906, -7.7733, -15.1395],
# ...
# [-18.7275, -12.8467, 1.5550, 4.5225]],
# grad_fn=<AddmmBackward0>)
# PhonePriceModel 模型最后的输出层没用 softmax, 所以这里要用 softmax
# 原始得分(logits)越高,经过 softmax 函数转换后对应的概率就越大, 所以这里也可以不用 softmax, 直接看最大的 logits得分
# 为了完善流程, 还是用 softmax
y_predict_probability = torch.softmax(y_logits, dim=1) # 得到概率, 每个样本对应的 4 个概率的和为 1
print(f'softmax处理过后的概率: \n{y_predict_probability}')
# tensor([[7.2588e-07, 1.1114e-02, 9.6762e-01, 2.1261e-02], ≈ 0.0000007 + 0.0111 + 0.9676 + 0.0213 ≈ 1.0000
# [5.8414e-01, 4.1583e-01, 2.9200e-05, 1.8462e-08], ≈ 0.58414 + 0.41583 + 0.0000292 + 极小值 ≈ 0.999999 ≈ 1
# ...
# [7.6012e-11, 2.7220e-08, 4.8917e-02, 9.5108e-01]], ≈ 极小值 + 极小值 + 0.048917 + 0.95108 ≈ 1.0000
# grad_fn=<SoftmaxBackward0>)
y_predict = y_predict_probability.argmax(dim=1) # 返回最大值的索引 【《PyTorch框架使用》下《5 张量运算函数》下《1.基础统计类函数》】
print(f'预测的类型: \n{y_predict}') # tensor([2, 0, 1, 3, 2, ...])
print(f'真实值: \n{y_test}') # tensor([2, 0, 1, 3, 2, ...])
print(f'预测值 == 真实值 :\n{y_predict == y_test}') # tensor([True, True, True, True, ...])
print(f'当前 batch 中预测正确的个数 = {(y_predict == y_test).sum().item()}') # 50
count = (y_predict == y_test).sum().item() # 统计预测正确的个数
right_cnt += count
simple_size = len(test_dataset) # 总共有多少个测试样本
print(f'测试集总个数 = {simple_size}') # 400
print(f'预测正确总个数 = {right_cnt}') # 381
print(f'预测正确率 = {right_cnt / simple_size}') # 0.9525
if __name__ == '__main__':
train_dataset, test_dataset, feature_count, target_category_count = create_dataset()
show_model(train_dataset, test_dataset, feature_count, target_category_count)
# 模型参数保存好后, 就可以不调用这个函数重新训练了
model_train(train_dataset, feature_count, target_category_count)
estimate_model(test_dataset, feature_count, target_category_count)
六、性能优化
虽然当前代码已非常完善,但若想进一步提升,可考虑:
| 方向 | 建议 |
|---|---|
| 早停(Early Stopping) | 监控验证损失,防止过拟合(你第 49 轮 loss 反弹) |
| 学习率调度 | ReduceLROnPlateau 在 loss 停滞时降 lr |
| 随机种子统一 | 补充 np.random.seed(66) 和 random.seed(66) |
| 设备兼容性 | 添加 .to(device) 支持 GPU(虽本任务 CPU 足够) |
但这些都属于"锦上添花",当前代码已能稳定复现高精度结果。
七、整体代码
python
import torch
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
import torch.nn as nn
from torchsummary import summary
import torch.optim as optim
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import time
import os
os.chdir(r'F:\Pycharm\works-space\神经网络')
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE" # ←←← 关键!放在最前面(解决报错)
from pylab import mpl
mpl.rcParams["font.sans-serif"] = ["SimHei"] # 设置显示中文字体
mpl.rcParams["axes.unicode_minus"] = False # 设置正常显示符号
torch.manual_seed(66)
def create_dataset():
data_df = pd.read_csv(r'data/手机价格预测.csv')
# 目标值可用值: [0, 1, 2, 3]
print(f'data_df.shape = {data_df.shape}') # (2000, 21)
# print(data_df.head())
# battery_power blue clock_speed ... touch_screen wifi price_range
# 0 842 0 2.2 ... 0 1 1
# 1 1021 1 0.5 ... 1 0 2
# 2 563 1 0.5 ... 1 0 2
x, y = data_df.iloc[: , : -1], data_df.iloc[: , -1] # x: 所有特征列 y: 所有目标值
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=88)
# 需要把 DataFrame 转成张量
x_train = torch.tensor(data=x_train.to_numpy(dtype=np.float32), dtype=torch.float32)
x_test = torch.tensor(data=x_test.to_numpy(dtype=np.float32), dtype=torch.float32)
y_train = torch.tensor(data=y_train.to_numpy(dtype=np.float32), dtype=torch.long) # 标签, CrossEntropyLoss需要的是64位整数
y_test = torch.tensor(data=y_test.to_numpy(dtype=np.float32), dtype=torch.long) # 标签, CrossEntropyLoss需要的是64位整数
# 创建数据集: x_train 和 y_train 对应
# 创建数据集: x_test 和 y_test 对应
train_dataset = TensorDataset(x_train, y_train)
test_dataset = TensorDataset(x_test, y_test)
print(f'特征总数 = {x.shape[1]}') # (2000, 20) 所以 x.shape[1] = 20 = 特征总数
# print(len(y.value_counts())) # 输出 4 : 目标值的总类别数
print(f'总类别数 = {len(y.unique())}') # 输出 4 : 目标值的总类别数 【 unique去重: DataFrame 没有 unique(),只有 Series 有】
return train_dataset, test_dataset, x.shape[1], len(y.unique())
class PhonePriceModel(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
# 隐藏层1 + 批量归一正则化 + Relu + dropout
self.linear1 = nn.Linear(in_features=in_features, out_features=128)
self.bn1 = nn.BatchNorm1d(num_features=128, track_running_stats=True) # 核心原则:放在"线性变换之后,非线性激活之前"
self.relu1 = nn.ReLU() # 可以不加这个, 因为 ReLu 函数都是一样的, 但是为了更好展示数据流向关系, 还是加上
self.dropout1 = nn.Dropout(p=0.3) # Dropout 应用于"无界"或"稀疏"的激活输出之后,尤其是那些容易导致神经元强依赖的非线性层之后。ReLU 及其变体之后(强烈推荐)
# (注意:早期有人把 Dropout 放在 BN 前,但现在普遍认为放在 ReLU 后更合理.
# 因为 BN 输出已经是归一化的,再经 ReLU 产生稀疏激活,此时加 Dropout 能有效打破神经元依赖)
# 隐藏层1 + 批量归一正则化 + Relu + dropout
self.linear2 = nn.Linear(in_features=128, out_features=256)
self.bn2 = nn.BatchNorm1d(num_features=256, track_running_stats=True) # 核心原则:放在"线性变换之后,非线性激活之前"
self.relu2 = nn.ReLU() # 可以不加这个, 因为 ReLu 函数都是一样的, 但是为了更好展示数据流向关系, 还是加上
self.dropout2 = nn.Dropout(p=0.5) # 越深 dropout 比例越高"的经验原则
# 输出
self.output_linear = nn.Linear(in_features=256, out_features=out_features) # 【 output n.输出 】
def forward(self, x):
# 隐藏层1 + 批量归一正则化 + Relu + dropout
y = self.linear1(x)
x = self.bn1(y)
active = self.relu1(x)
x_drop = self.dropout1(active)
# 隐藏层2 + 批量归一正则化 + Relu + dropout
y = self.linear2(x_drop)
x = self.bn2(y)
active = self.relu2(x)
x_drop = self.dropout2(active)
# 隐藏层3 + 输出
output = self.output_linear(x_drop) # 由于多分类任务, 理应使用 softmax, 但 CrossEntropyLoss 自带 softmax, 所以输出时不用 softmax
return output
# 别用 'test_model' 这个以 'test' 开头的名字
# 这个错误是因为 PyCharm 使用 pytest 来运行代码,而 pytest 会将所有以 test_ 开头的函数识别为测试函数。
# 测试函数不应该有参数,除非这些参数是 pytest fixtures。
def show_model(train_dataset, test_dataset, feature_count, target_category_count):
my_model = PhonePriceModel(feature_count, target_category_count)
summary(model=my_model, input_size=(feature_count, ))
# ----------------------------------------------------------------
# Layer (type) Output Shape Param #
# ================================================================
# Linear-1 [-1, 128] 2,688 计算: 128 * 20 + 128 = 2688
# BatchNorm1d-2 [-1, 128] 256
# ReLU-3 [-1, 128] 0
# Dropout-4 [-1, 128] 0
# Linear-5 [-1, 256] 33,024
# BatchNorm1d-6 [-1, 256] 512
# ReLU-7 [-1, 256] 0
# Dropout-8 [-1, 256] 0
# Linear-9 [-1, 4] 1,028
# ================================================================
# Total params: 37,508
# Trainable params: 37,508
# Non-trainable params: 0
# ----------------------------------------------------------------
# Input size (MB): 0.00
# Forward/backward pass size (MB): 0.01
# Params size (MB): 0.14
# Estimated Total Size (MB): 0.15
# ----------------------------------------------------------------
def model_train(train_dataset, feature_count, target_category_count):
my_model = PhonePriceModel(in_features=feature_count, out_features=target_category_count) # 创建模型
my_model.train() # 训练阶段
# 查看模型参数:
# print('模型参数: ')
# for name, parameter in my_model.named_parameters():
# print(name, parameter)
# 数据加载器
train_dataloader = DataLoader(dataset=train_dataset, batch_size=50, shuffle=True)
# 优化器
optimizer = optim.Adam(params=my_model.parameters(), lr=0.001, betas=(0.9, 0.99))
# 损失函数
criterion = nn.CrossEntropyLoss()
epochs = 50
loss_sum_list = [] # 记录每个 batch 里的总损失, 用于画图
loss_avg_list = [] # 记录每个 batch 里的平均损失, 用于画图
for epoch in range(epochs):
print(f'第 {epoch + 1} 次 epoch: ')
loss_sum = 0.0 # 统计当前 epoch 的总损失
batch_num = 0 # 统计用了多少个 batch, 用计算平均 损失
for x_train, y_true in train_dataloader:
optimizer.zero_grad() # 1. 清零梯度
y_predict = my_model(x_train) # 2. 前向传播
loss = criterion(y_predict, y_true) # 3. 计算损失值
loss.backward() # 4. 反向传播
optimizer.step() # 5. 梯度更新
loss_sum += loss.item() # 只计算数值, 用 item()
batch_num += 1
loss_sum_list.append(loss_sum) # 总损失 画折线图
avg_loss = loss_sum / batch_num
loss_avg_list.append(avg_loss) # 评价损失 画折线图
print(f'总损失值 = {loss_sum}') # loss 为什么会波动
# 第 1 次 epoch:
# 总损失值 = 24.427303969860077
# 第 2 次 epoch:
# 总损失值 = 14.080154180526733
# 第 3 次 epoch:
# 总损失值 = 11.578682631254196
# 第 4 次 epoch:
# 总损失值 = 11.567044869065285
# 第 5 次 epoch:
# 总损失值 = 11.530409216880798
# ...
# 第 48 次 epoch:
# 总损失值 = 7.979451455175877
# 第 49 次 epoch:
# 总损失值 = 9.96169776469469
# 第 50 次 epoch:
# 总损失值 = 9.308996006846428
# 总损失
plt.style.use('fivethirtyeight')
plt.figure(figsize=(13, 10))
plt.plot(range(1, epochs + 1), loss_sum_list)
plt.title('每个 epoch 总损失趋势')
plt.xlabel('epoch')
plt.ylabel('loss值')
plt.show()
# 平均损失: 平均损失 和 总损失 仅仅差个系数, 就是简单的多乘个 batch 而已, 和 总损失 的趋势是一模一样的
plt.style.use('fivethirtyeight')
plt.figure(figsize=(13, 10))
plt.plot(range(1, epochs + 1), loss_avg_list)
plt.title('每个 epoch 平均损失趋势')
plt.xlabel('epoch')
plt.ylabel('loss值')
plt.show()
all_params = my_model.state_dict()
torch.save(all_params, r'model/phone.pth') # 把模型里的可学习参数全都保存起来
def estimate_model(test_dataset, feature_count, target_category_count):
my_model = PhonePriceModel(in_features=feature_count, out_features=target_category_count)
all_params = torch.load(r'model/phone.pth') # 读取保存的模型参数
my_model.load_state_dict(all_params) # 把参数注入到模型里
my_model.eval() # 切换到测试模式
# 数据加载器
test_dataloader = DataLoader(dataset=test_dataset, batch_size=50, shuffle=False)
# epochs = 50 # 注意, 这是测试模式, 是没有 epochs 的
right_cnt = 0 # 预测正确的个数
for x_test, y_test in test_dataloader:
y_logits = my_model(x_test) # 得到的是 logits得分
print(f'原始 logits: \n{y_logits}')
# tensor([[-11.5542, -1.9179, 2.5488, -1.2692],
# [ 2.1304, 1.7906, -7.7733, -15.1395],
# ...
# [-18.7275, -12.8467, 1.5550, 4.5225]],
# grad_fn=<AddmmBackward0>)
# PhonePriceModel 模型最后的输出层没用 softmax, 所以这里要用 softmax
# 原始得分(logits)越高,经过 softmax 函数转换后对应的概率就越大, 所以这里也可以不用 softmax, 直接看最大的 logits得分
# 为了完善流程, 还是用 softmax
y_predict_probability = torch.softmax(y_logits, dim=1) # 得到概率, 每个样本对应的 4 个概率的和为 1
print(f'softmax处理过后的概率: \n{y_predict_probability}')
# tensor([[7.2588e-07, 1.1114e-02, 9.6762e-01, 2.1261e-02], ≈ 0.0000007 + 0.0111 + 0.9676 + 0.0213 ≈ 1.0000
# [5.8414e-01, 4.1583e-01, 2.9200e-05, 1.8462e-08], ≈ 0.58414 + 0.41583 + 0.0000292 + 极小值 ≈ 0.999999 ≈ 1
# ...
# [7.6012e-11, 2.7220e-08, 4.8917e-02, 9.5108e-01]], ≈ 极小值 + 极小值 + 0.048917 + 0.95108 ≈ 1.0000
# grad_fn=<SoftmaxBackward0>)
y_predict = y_predict_probability.argmax(dim=1) # 返回最大值的索引 【《PyTorch框架使用》下《5 张量运算函数》下《1.基础统计类函数》】
print(f'预测的类型: \n{y_predict}') # tensor([2, 0, 1, 3, 2, ...])
print(f'真实值: \n{y_test}') # tensor([2, 0, 1, 3, 2, ...])
print(f'预测值 == 真实值 :\n{y_predict == y_test}') # tensor([True, True, True, True, ...])
print(f'当前 batch 中预测正确的个数 = {(y_predict == y_test).sum().item()}') # 50
count = (y_predict == y_test).sum().item() # 统计预测正确的个数
right_cnt += count
simple_size = len(test_dataset) # 总共有多少个测试样本
print(f'测试集总个数 = {simple_size}') # 400
print(f'预测正确总个数 = {right_cnt}') # 381
print(f'预测正确率 = {right_cnt / simple_size}') # 0.9525
if __name__ == '__main__':
train_dataset, test_dataset, feature_count, target_category_count = create_dataset()
show_model(train_dataset, test_dataset, feature_count, target_category_count)
# 模型参数保存好后, 就可以不调用这个函数重新训练了
model_train(train_dataset, feature_count, target_category_count)
estimate_model(test_dataset, feature_count, target_category_count)