【动手学深度学习】笔记1:简单的线性回归

根据我们之前的对话,我为你整理了一份线性回归从零实现的学习笔记。这份笔记涵盖了数据生成、小批量迭代器、模型定义、损失函数、SGD优化器以及完整训练流程。你可以把它保存下来,经常复习。


线性回归从零实现 · 学习笔记

一、生成合成数据

python 复制代码
def synthetic_data(w, b, num_examples):
    """生成 y = X w + b + 噪声"""
    X = torch.normal(0, 1, (num_examples, len(w)))   # 特征矩阵
    y = torch.matmul(X, w) + b                       # 线性部分
    y += torch.normal(0, 0.01, y.shape)              # 加噪声
    return X, y.reshape((-1, 1))                     # y 转为列向量
  • 形状理解X 形状 (样本数, 特征数)w 形状 (特征数,)y 形状 (样本数,)。通过 reshape((-1,1)) 变成列向量 (样本数, 1)
  • 为什么要加噪声:模拟真实数据的不完美,让模型学会忽略小扰动。
  • 为什么要 reshape:便于后续矩阵运算,避免广播歧义。

二、手动实现小批量数据迭代器

python 复制代码
def data_iter(batch_size, features, labels):
    num_examples = len(features)
    indices = list(range(num_examples))
    random.shuffle(indices)                # 每个 epoch 打乱顺序
    for i in range(0, num_examples, batch_size):
        batch_indices = torch.tensor(indices[i: min(i+batch_size, num_examples)])
        yield features[batch_indices], labels[batch_indices]
  • indices :原始样本的索引列表 [0,1,2,...],打乱后实现随机抽取。
  • batch_indices:当前批次对应的索引(张量形式)。
  • features[batch_indices] :按索引抽取子集,形状 (batch_size, 特征数)
  • yield:生成器,每次返回一个批次,并保存状态,下次继续。节省内存。

三、初始化模型参数

python 复制代码
w = torch.normal(0, 0.01, size=(2, 1), requires_grad=True)
b = torch.zeros(1, requires_grad=True)
  • w 形状 (特征数, 1),便于矩阵乘法 X @ w 得到 (batch_size, 1)
  • requires_grad=True:告诉 PyTorch 需要计算梯度。

四、定义线性回归模型

python 复制代码
def linreg(X, w, b):
    return torch.matmul(X, w) + b
  • 等价于 X @ w + b
  • 广播机制:b 会自动加到每个样本的预测值上。

五、定义损失函数(均方误差的一半)

python 复制代码
def squared_loss(y_hat, y):
    return (y_hat - y.reshape(y_hat.shape)) ** 2 / 2
  • 返回向量 (batch_size, 1),每个元素是一个样本的半平方误差
  • 除以 2 是为了求导后系数为 1(导数 = y_hat - y)。
  • 训练时会对这个向量 .sum() 再反向传播。

六、定义优化算法(小批量 SGD)

python 复制代码
def sgd(params, lr, batch_size):
    with torch.no_grad():
        for param in params:
            param -= lr * param.grad / batch_size
            param.grad.zero_()
  • with torch.no_grad():禁用梯度追踪,手动更新参数时不要构建计算图。
  • param.grad 是累积的梯度(因为之前 l.sum().backward() 求和后反向,梯度是总损失的导数)。
  • 除以 batch_size:将总梯度转换为平均梯度,使更新步长与批量大小无关。
  • param.grad.zero_():梯度清零,否则下一批会累加。

七、训练循环

python 复制代码
lr = 0.03
num_epochs = 3
net = linreg
loss = squared_loss

for epoch in range(num_epochs):
    for X, y in data_iter(batch_size, features, labels):
        l = loss(net(X, w, b), y)        # 小批量损失向量
        l.sum().backward()               # 反向传播计算梯度
        sgd([w, b], lr, batch_size)      # 更新参数
    with torch.no_grad():
        train_l = loss(net(features, w, b), labels)
        print(f'epoch {epoch+1}, loss: {float(train_l.mean()):f}')
  • 内层循环:每个 epoch 内遍历所有小批量。
  • l.sum().backward():因为 l 是向量,需先求和成标量再反向传播。梯度会累加到 w.gradb.grad
  • 外层循环结束时:打印整个数据集上的平均损失,观察训练进展。

八、关键概念总结

概念 说明
特征矩阵 X 形状 (num_examples, num_features),每行一个样本
权重 w 形状 (num_features, 1),每个特征对应一个权重
偏置 b 标量或 (1,),所有样本共享
预测值 X @ w + b,形状 (num_examples, 1)
损失函数 向量输出(每个样本一个损失),求和后反向传播
梯度下降 param -= lr * (param.grad / batch_size),然后清零梯度
生成器 (yield) 按需产生数据,节省内存

九、常见问题自查

  1. 为什么 y 要 reshape 成列向量?

    为了与 y_hat 形状一致,便于矩阵运算,避免广播错误。

  2. 为什么要除以 batch_size

    因为 l.sum().backward() 产生的梯度是总损失 的梯度,而我们需要平均梯度,所以除以批量大小。

  3. 为什么损失函数要除以 2?

    为了让求导后的梯度表达式为 (y_hat - y),没有系数 2,纯粹是为了数学简洁。

  4. with torch.no_grad() 有什么用?

    告诉 PyTorch 不要跟踪这部分操作,避免构建不需要的计算图,节省内存和计算。

  5. 为什么要手动 param.grad.zero_()

    梯度默认会累加,不清零的话下一次反向传播会将新旧梯度相加,导致错误。


十、扩展学习路径

  • 使用 PyTorch 高层 API:nn.Linear, MSELoss, optim.SGD
  • 增加验证集,绘制训练曲线
  • 尝试不同的学习率和批量大小,观察收敛情况
  • 扩展到多项式回归、逻辑回归

笔记到此结束。 你可以根据自己的理解,在空白处补充例子或疑问。需要我进一步解释某个部分,随时可以继续提问。

相关推荐
RainCity15 小时前
Java Swing 自定义组件库分享(十二)
java·笔记·后端
饼干哥哥4 天前
开源Skills|搭建亚马逊动态关键词库系统,每天抓SSS级机会词
人工智能·深度学习·数据分析
武子康5 天前
调查研究-191 SenseVoice 不只是 ASR:把语音从“转文字“升级成“理解状态“
人工智能·深度学习·openai
武子康7 天前
调查研究-189 Kronos 调研:金融 K 线基础模型,是真突破,还是量化圈的新玩具?
人工智能·深度学习·openai
LinXunFeng8 天前
Obsidian - 使用 Share Note 分享笔记并自部署
前端·笔记·github
xiao5kou4chang6kai412 天前
MATLAB机器学习、深度学习--从数据预处理到模型训练
深度学习·机器学习·matlab·数据预处理
renhongxia112 天前
世界模型作为AGI落地底层底座的作用
人工智能·深度学习·生成对抗网络·自然语言处理·知识图谱·agi
计算机科研狗@OUC12 天前
(cvpr26) AIMDepth: Asymmetric Image-Event Mamba for Monocular Depth Estimation
人工智能·深度学习·计算机视觉
闪闪发亮的小星星12 天前
高斯光以及高斯光公式解释
笔记
cqbzcsq12 天前
CellFlow虚拟细胞论文阅读
论文阅读·人工智能·笔记·学习·生物信息