摘要:想知道如何让你的神经网络模型性能从"还行"飙升到"优秀"吗?本文以一个经典的"手机价格区间预测"案例为切入点,采用模块化的代码讲解方式,带你从零构建一个基础的 PyTorch 分类网络,并一步步实施数据标准化、网络加深、优化器更换等关键优化策略,亲眼见证模型准确率从 64% 提升至 91% 的全过程。
🚀 一、前言:从能用到好用的蜕变之旅
大家好!在深度学习的道路上,我们常常能快速搭建出一个可以运行的模型,但它的性能却不尽如人意。真正的挑战在于如何优化它,让它从"能用"变成"好用"。
今天,我们将通过一个非常接地气的项目------手机价格区间预测,来一场精彩的模型性能提升之旅。本文将分为两大篇章:
-
基础篇:手把手带你用 PyTorch 构建一个基础的全连接神经网络,并看看它的初始表现。
-
进阶优化篇:揭示模型性能不佳的症结,并祭出四大优化"法宝",让模型的准确率实现质的飞跃。
本文最大的特点是代码分模块展示和讲解,让你清晰地了解每一部分代码的作用,轻松上手。准备好了吗?Let's Go!
📊 二、案例背景:手机价格区间预测
🎯 我们的任务 :帮助虚拟的"小明手机店"解决定价难题。我们需要根据手机的20项硬件特征(如RAM、电池、摄像头像素等)自动预测其所属的价格区间(共4个等级:0低价, 1中价, 2高价, 3旗舰价)。

这是一个典型的多分类问题,非常适合用神经网络来解决。
🔨 三、基础篇:从零构建分类网络 (Baseline Model)
我们先搭建一个基础版本,把它作为我们的基线模型 (Baseline),看看它能做到什么程度。
💻 3.1 环境准备 & 导包
首先,导入所有必需的库。
python
# 核心框架
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
# 数据处理与划分
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
# 辅助工具
import time
📊 3.2 数据加载与预处理
用 pandas 读取数据,并用 scikit-learn 将其划分为训练集和测试集。
python
def create_dataset_base():
"""基础版数据加载函数"""
data = pd.read_csv('./data/手机价格预测.csv')
x, y = data.iloc[:, :-1], data.iloc[:, -1]
x = x.astype(np.float32)
x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=3, stratify=y)
train_dataset = TensorDataset(torch.tensor(x_train.values), torch.tensor(y_train.values))
test_dataset = TensorDataset(torch.tensor(x_test.values), torch.tensor(y_test.values))
return train_dataset, test_dataset, x_train.shape[1], len(np.unique(y))
🧠 3.3 构建神经网络模型
定义一个简单的三层全连接网络。
python
class PhonePriceModelBase(nn.Module):
"""基础版神经网络模型"""
def __init__(self, input_dim, output_dim):
super().__init__()
self.linear1 = nn.Linear(input_dim, 128)
self.linear2 = nn.Linear(128, 256)
self.output = nn.Linear(256, output_dim)
def forward(self, x):
x = torch.relu(self.linear1(x))
x = torch.relu(self.linear2(x))
x = self.output(x) # 输出层不加激活,CrossEntropyLoss内部包含了Softmax
return x
⚙️ 3.4 编写训练与评估函数
训练函数负责模型的迭代学习,评估函数则检验模型在未知数据上的表现。
python
# 训练函数
def train_base(train_dataset, input_dim, output_dim):
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
model = PhonePriceModelBase(input_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001) # 使用SGD优化器
epochs = 50
print("--- 🚀 开始基础模型训练 ---")
for epoch in range(epochs):
model.train()
for x, y in train_loader:
y_pred = model(x)
loss = criterion(y_pred, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.save(model.state_dict(), './model/phone_base.pth')
print("--- ✅ 基础模型训练完成 ---")
# 评估函数
def evaluate_base(test_dataset, input_dim, output_dim):
model = PhonePriceModelBase(input_dim, output_dim)
model.load_state_dict(torch.load('./model/phone_base.pth'))
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)
correct = 0
model.eval()
with torch.no_grad():
for x, y in test_loader:
y_pred = torch.argmax(model(x), dim=1)
correct += (y_pred == y).sum().item()
accuracy = correct / len(test_dataset)
return accuracy
▶️ 3.5 运行并查看结果
💡 结果分析
-
基础模型在测试集上的准确率为 64.25%。
-
这个结果比随机猜测(25%)好得多,证明模型学到了一些规律。但对于一个商业应用来说,这个准确率显然是不够的。
📈 四、进阶优化篇:让性能一飞冲天!
针对基线模型的不足,我们将从四个方面进行一次"豪华升级"。
📏 4.1 优化一:数据标准化 (Standard Scaling)
❓ 为什么? 特征间的数值范围差异巨大(如电池容量4000 vs 核心数8),会干扰模型的学习效率。标准化能将所有特征"拉"到同一水平线上,让训练更平稳。
python
from sklearn.preprocessing import StandardScaler
def create_dataset_optimized():
data = pd.read_csv('./data/手机价格预测.csv')
x, y = data.iloc[:, :-1], data.iloc[:, -1]
x = x.astype(np.float32)
y = y.astype(np.int64) # CrossEntropyLoss需要LongTensor类型的标签
x_train, x_valid, y_train, y_valid = train_test_split(x, y, train_size=0.8, random_state=88, stratify=y)
# === ✨ 优化点 ①:数据标准化 ===
transfer = StandardScaler()
x_train = transfer.fit_transform(x_train)
x_valid = transfer.transform(x_valid)
train_dataset = TensorDataset(torch.from_numpy(x_train), torch.tensor(y_train.values))
valid_dataset = TensorDataset(torch.from_numpy(x_valid), torch.tensor(y_valid.values))
return train_dataset, valid_dataset, x_train.shape[1], len(np.unique(y))
🧱 4.2 优化二:深化网络结构 (Deeper Network)
❓ 为什么? 更深、更宽的网络拥有更强的学习能力,能捕捉数据中更复杂、更抽象的模式,就像一个更聪明的大脑。
python
class PhonePriceModelOptimized(nn.Module):
def __init__(self, input_dim, output_dim):
super().__init__()
# === ✨ 优化点 ②:使用nn.Sequential构建更深的网络 ===
self.layers = nn.Sequential(
nn.Linear(input_dim, 128), nn.ReLU(),
nn.Linear(128, 256), nn.ReLU(),
nn.Linear(256, 512), nn.ReLU(),
nn.Linear(512, 128), nn.ReLU(),
nn.Linear(128, output_dim)
)
def forward(self, x):
return self.layers(x)
🚀 4.3 优化三 & 四:更换优化器与调整学习率 (Adam & Learning Rate)
❓ 为什么? Adam 优化器是一种更智能的"导航系统",它能自适应地调整学习率,通常比传统的 SGD 收敛更快、效果更好。配合 Adam,一个更小的学习率能让模型在接近最优解时"慢下来",进行精细微调。
python
def train_optimized(train_dataset, input_dim, class_num):
dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)
model = PhonePriceModelOptimized(input_dim, class_num)
criterion = nn.CrossEntropyLoss()
# === ✨ 优化点 ③ & ④:使用Adam和更小的学习率 ===
optimizer = optim.Adam(model.parameters(), lr=1e-4)
# ... (训练循环与之前类似) ...
torch.save(model.state_dict(), './model/phone-price-model2.pth')
# 评估函数逻辑不变,只需加载新模型即可
def test_optimized(valid_dataset, input_dim, class_num):
# ... (加载PhonePriceModelOptimized模型并评估) ...
# 此处省略重复代码,与 evaluate_base 类似
🏆 4.4 见证奇迹的时刻
💡 优化后结果分析
-
经过四大优化策略的加持,新模型在测试集上的准确率达到了惊人的 91.25%!
-
这意味着,对于100台新手机,我们的模型大约能正确预测其中91台的价格区间,这已经具备了很高的商业应用价值。
🎉 五、成果对比与总结
让我们将两个模型的结果放在一起,感受优化的力量!
| 模型版本 | 核心配置 | 准确率 | 提升幅度 |
|---|---|---|---|
| 基础版 | 3层网络 / SGD | 64.25% | - |
| 🚀 优化版 | 5层网络 / 数据标准化 / Adam | 91.25% | +27% |
📝 六、总结与回顾
本次实战,我们不仅完整地走了一遍从数据到模型的全流程,更重要的是,我们学会了如何像一位"模型调优师"一样去思考和实践。
✅ 核心 takeaways:
-
数据优先:永远不要低估数据预处理(如标准化)的重要性,它是模型成功的基石。
-
结构决定上限:一个好的网络结构(如适当加深)为模型提供了更高的学习潜力。
-
优化器是关键:选择合适的优化器(如Adam)和学习率,能让训练事半功倍。
-
迭代思维:深度学习是一个不断实验、分析、优化的循环过程。
🛠️ 七、源码与数据集获取
光说不练假把式!为了方便大家亲手复现本文的实验,我已经将完整的项目源码和 手机价格预测.csv 数据集打包好啦!
获取方式非常简单:
👉 第一步:请先**【关注】**我的 CSDN 账号。您的关注是我持续创作高质量内容的最佳动力!
📨 第二步:关注后,请给我发送私信,内容为关键词: `手机价格预测`
收到私信后,我会尽快将源码和数据集的下载方式发给你哦!期待你的实践与反馈!
🙏 课程推荐与参考 (Reference)
再次强调,本文核心知识点均提炼自 B 站黑马程序员 的精品课程,老师讲得非常详细,大家一定要去支持!
- 课程名称: AI大模型《神经网络与深度学习》全套视频课程,涵盖Pytorch深度学习框架、BP神经网络、CNN图像分类算法及RNN文本生成算法
现在,你也可以尝试调整网络结构、训练轮数,看看能否冲击更高的准确率!如果觉得这篇文章对你有帮助,请不要吝啬你的 👍 点赞、⭐ 收藏和 💬 评论!