刘二大人第5讲-pytorch实现线性回归-有代码

5. pytorch实现线性回归

用pytorch提供的工具实现一次线性模型

步骤如下:

  1. 准备dataset
  2. 设计模型:继承nn.Module类
  3. 定义Loss和optimizer优化器:用pytorch的api
  4. 训练:forward算Loss,backward算Gradient,用Gradient Descent算法update

5.1 步骤一:准备dataset

采用mini-batch风格

5.2 步骤二:定义模型

将模型定义成一个class,继承nn.Module

继承Module的对象会自动生成计算图实现backward,其中必须实现的方法有__init__()和forward()

  • 其中__init__()构造函数用于初始化对象
    • 调用父类构造函数必写
    • nn.Linear是一个继承自nn.Module的类
      • 构造函数(每个输入样本的维度,每个输出样本的维度,bias=True),其中bias若设为False,将不学bias
      • 有两个成员变量weight和bias
        • w的形状是out_features*in_features
        • b的形状是out_features
  • forward()定义前馈中执行的计算
    • 补充:python中对象()实际是执行他的__call()__方法,而Linear类中__call()__会去调用forward()

5.3 步骤三:实例化Loss和Optimizer

python 复制代码
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

注:model.parameters()能调用这个module中所有成员的parameters()方法拿到parameters,此处拿到module中linear的w和b

Loss这里采用MSELoss,同样是一个继承自nn.Module的类,具体文档见下

注:新版的torch中size_average and reduce 被弃用了,使用reduction='sum'求和,使用reduction='mean'求平均值

Optimizer这里采用SGD,不继承自nn.Module不会构建计算图,lr为learning rate,具体文档见下

5.4 步骤四:训练

python 复制代码
for epoch in range(1000):
    # 1. 求y_hat
    y_pred = model(x_data)
    # 2. 求loss
    loss = criterion(y_pred, y_data)
    # 3. 清空gradient
    optimizer.zero_grad()
    # 4. 反向传播求梯度
    loss.backward()
    # 5. 更新参数
    optimizer.step()
  1. 求 <math xmlns="http://www.w3.org/1998/Math/MathML"> y ^ \hat{y} </math>y^
  2. 求Loss
  3. 梯度归零
  4. backward求梯度
  5. step更新

5.5 完整代码

python 复制代码
import matplotlib
# 使用TkAgg后端以确保在不同环境下都能正常显示图形
matplotlib.use('TkAgg')
from matplotlib import pyplot as plt
import numpy as np
import torch
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D

# 设置中文字体以确保中文正常显示
plt.rcParams["font.family"] = ["SimHei"]
plt.rcParams['axes.unicode_minus'] = False  # 解决负号显示问题

# 1. 准备dataset,采用mini-batch风格
x_data = torch.Tensor([[1.0], [2.0], [3.0]])
y_data = torch.Tensor([[2.0], [4.0], [6.0]])

# 2. 定义模型
class LinearModel(torch.nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.linear = torch.nn.Linear(1, 1)

    def forward(self, x):
        y_pred = self.linear(x)
        return y_pred

# 3. 实例化模型、Loss、Optimizer
model = LinearModel()
criterion = torch.nn.MSELoss(reduction="sum")
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 4. 训练模型1000次
# 记录训练过程
w_list = []
b_list = []
l_list = []

for epoch in range(1000):
    # 1. 求y_hat
    y_pred = model(x_data)
    # 2. 求loss
    loss = criterion(y_pred, y_data)
    # 3. 清空gradient
    optimizer.zero_grad()
    # 4. 反向传播求梯度
    loss.backward()
    # 5. 更新参数
    optimizer.step()
    # 记录训练过程
    w_list.append(model.linear.weight.item())
    b_list.append(model.linear.bias.item())
    l_list.append(loss.item())

# 生成网格数据计算理论损失曲面
w = np.linspace(0, 4, 100)
b = np.linspace(-2, 2, 100)
W, B = np.meshgrid(w, b)
Z = np.zeros_like(W)

# 计算每个网格点的损失
for i in range(len(w)):
    for j in range(len(b)):
        # 计算预测值
        y_pred = W[i, j] * x_data.numpy() + B[i, j]
        # 计算MSE损失
        Z[i, j] = np.sum((y_pred - y_data.numpy())**2)

# 创建3D图形
fig = plt.figure(figsize=(12, 5))

# 绘制损失曲面和训练路径
ax1 = fig.add_subplot(121, projection='3d')
surf = ax1.plot_surface(W, B, Z, cmap=cm.coolwarm, alpha=0.8, linewidth=0)
ax1.plot(w_list, b_list, l_list, 'r-o', markersize=2, linewidth=1, label='训练路径')
ax1.set_xlabel('权重 (w)')
ax1.set_ylabel('偏置 (b)')
ax1.set_zlabel('损失 (Loss)')
ax1.set_title('损失曲面与训练路径')
ax1.view_init(elev=30, azim=45)
fig.colorbar(surf, ax=ax1, shrink=0.5, aspect=5)

# 绘制等高线图和训练路径投影
ax2 = fig.add_subplot(122)
contour = ax2.contour(W, B, Z, 50, cmap=cm.coolwarm)
ax2.plot(w_list, b_list, 'r-o', markersize=2, linewidth=1, label='训练路径')
ax2.set_xlabel('权重 (w)')
ax2.set_ylabel('偏置 (b)')
ax2.set_title('损失函数等高线图')
fig.colorbar(contour, ax=ax2, shrink=0.5, aspect=5)
ax2.grid(True)

plt.tight_layout()
plt.show()

图像如下:

w收敛于2.0,b收敛于0.0

相关推荐
程序员佳佳12 分钟前
2025年大模型终极横评:GPT-5.2、Banana Pro与DeepSeek V3.2实战硬核比拼(附统一接入方案)
服务器·数据库·人工智能·python·gpt·api
鲨莎分不晴21 分钟前
【前沿技术】Offline RL 全解:当强化学习失去“试错”的权利
人工智能·算法·机器学习
工业机器视觉设计和实现36 分钟前
lenet改vgg成功后,我们再改为最简单的resnet
人工智能
jiayong2342 分钟前
Spring AI Alibaba 深度解析(三):实战示例与最佳实践
java·人工智能·spring
北邮刘老师1 小时前
【智能体互联协议解析】需要“智能体名字系统”(ANS)吗?
网络·人工智能·大模型·智能体·智能体互联网
梁辰兴1 小时前
AI解码千年甲骨文,指尖触碰的文明觉醒!
人工智能·ai·ai+·文明·甲骨文·ai赋能·梁辰兴
阿里云大数据AI技术1 小时前
# Hologres Dynamic Table:高效增量刷新,构建实时统一数仓的核心利器
人工智能·数据分析
光羽隹衡1 小时前
机械学习逻辑回归——银行贷款案例
算法·机器学习·逻辑回归
JxWang052 小时前
pandas计算某列每行带有分隔符的数据中包含特定值的次数
人工智能
能源系统预测和优化研究2 小时前
创新点解读:基于非线性二次分解的Ridge-RF-XGBoost时间序列预测(附代码实现)
人工智能·深度学习·算法