203_深度学习的第一步:线性回归模型与 SGD 优化算法实战

线性回归试图学到一个线性模型,以尽可能准确地预测输出。在 PyTorch 中,我们可以通过简单的几行代码,实现从数据生成、模型构建到自动训练的全过程。

1. 线性回归的核心要素

线性回归模型可以表示为:

  • 训练数据 :特征 和标签
  • 模型参数 :权重 和偏差
  • 损失函数:均方误差(MSE Loss),用于衡量预测值与真实值之间的平方差。
  • 优化算法:小批量随机梯度下降(SGD),通过不断沿着梯度的反方向更新参数来最小化损失。

2. 核心代码:从零开始与简洁实现

对比手动实现(Scratch)与使用 PyTorch 官方库(nn.Module)的两种方式。以下是利用官方库的简洁实现

Python

复制代码
import torch
from torch import nn
from torch.utils import data

# 1. 生成或准备数据 (假设已有 features 和 labels)
def load_array(data_arrays, batch_size, is_train=True):
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

batch_size = 10
data_iter = load_array((features, labels), batch_size)

# 2. 定义模型结构
# Linear(输入特征数, 输出特征数)
net = nn.Sequential(nn.Linear(2, 1))

# 3. 初始化模型参数
net[0].weight.data.normal_(0, 0.01) # 权重初始化为均值为0,方差为0.01的正态分布
net[0].bias.data.fill_(0)          # 偏差初始化为0

# 4. 定义损失函数与优化器
loss = nn.MSELoss() # 均方误差损失
trainer = torch.optim.SGD(net.parameters(), lr=0.03) # 学习率为0.03的SGD

3. 训练过程:循环迭代

训练过程是不断从数据迭代器中获取批量数据,并更新参数的过程。

Python

复制代码
num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        # 前向传播:计算预测值与损失
        l = loss(net(X), y)
        
        # 反向传播:三步走
        trainer.zero_grad() # 1. 梯度清零
        l.backward()        # 2. 计算梯度
        trainer.step()      # 3. 更新参数
    
    # 打印每一轮后的总损失
    train_l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {train_l:f}')

4. 关键细节解析

为什么需要 DataLoader

在深度学习中,我们通常不一次性处理所有数据(内存压力大),也不一张一张处理(效率低)。DataLoader 帮助我们将数据分成一个个 Minibatch,这能在保证计算效率的同时,为优化过程引入一定的随机性,帮助模型跳出局部最优解。

net.parameters() 的作用

在定义优化器时,我们需要传入 net.parameters()。这告诉优化器:"你需要负责更新这个网络中所有的权重和偏差"。


5. 总结:深度学习的标准化样板

通过线性回归的学习,我们其实已经掌握了所有深度学习模型的通用模版:

  1. 数据流:Dataset -> DataLoader。
  2. 模型流:nn.Linear -> nn.Sequential。
  3. 计算流:Forward -> Loss -> Backward -> Step。

💡 学习小结

线性回归虽然简单,但它包含了深度学习的绝大部分基因。一旦你理解了权重如何根据梯度更新,你就已经推开了通往卷积神经网络(CNN)和循环神经网络(RNN)的大门。

相关推荐
newsxun21 分钟前
第十六届北京国际电影节东郎分会场启幕
人工智能
大嘴皮猴儿22 分钟前
从零开始学商品图翻译:小白也能快速掌握的多语言文字处理与上架技巧
大数据·ide·人工智能·macos·新媒体运营·xcode·自动翻译
思绪无限24 分钟前
YOLOv5至YOLOv12升级:行人跌倒检测系统的设计与实现(完整代码+界面+数据集项目)
深度学习·yolo·目标检测·yolov12·yolo全家桶·行人跌倒检测系统
大黄说说25 分钟前
AI大模型对内容创作的颠覆:机遇、版权争议与行业新规则
人工智能
captain_AIouo34 分钟前
OZON航海引领者Captain AI指引运营新航向
大数据·人工智能·经验分享·aigc
AI医影跨模态组学44 分钟前
PLOS Medicine 中山大学肿瘤防治中心蔡木炎等团队:基于多视角深度学习的组织病理学分析用于II期结直肠癌的预后与治疗分层
人工智能·深度学习·论文·医学·医学影像
起个名字总是说已存在1 小时前
github开源AI技能:Awesome DESIGN.md让页面设计无限可能
人工智能·开源·github
Aray12341 小时前
大模型推理全栈技术解析:从Transformer到RoPE/YaRN的上下文优化
人工智能·深度学习·transformer
ShingingSky1 小时前
给 Claude Code 加上 Windows 提醒——一个小功能,少操十份心
人工智能·设计
思绪无限1 小时前
YOLOv5至YOLOv12升级:行人车辆检测与计数识别系统的设计与实现(完整代码+界面+数据集项目)
人工智能·深度学习·yolo·目标检测·yolov12·yolo全家桶·行人车辆检测与计数