使用PyTorch实现线性回归的完整流程

代码逻辑:

复制代码
"""
线性回归模型的训练流程:
第一步:准备训练的数据,Prepare Dataset
第二步:构建数据加载器并进行模型构建 Design model using class
第三步:设置损失函数和优化器 Construct loss and optimizer
第四步:训练模型和预测 Training cycle
"""
#导入工具包
import torch
from torch.utils.data import TensorDataset,DataLoader
from torch.nn import Linear,MSELoss
from torch.optim import SGD
from sklearn.datasets import make_regression
import matplotlib.pyplot as plt

#创建训练的数据
def create_dataset():
    x,y,coef=make_regression(n_samples=100, #样本
                    n_features=1,  #特征向量维度
                    noise=10,  #噪声
                    coef=True, #权重
                    bias=1.5, #偏置
                    random_state=0  #随机值
                    )
    x = torch.tensor(x)
    y = torch.tensor(y)
    return x,y,coef
x,y,coef = create_dataset()
# plt.scatter(x,y)
# #构建拟合线性回归直线
# x1 = torch.linspace(x.min(),x.max(),1000)
# y1 = torch.tensor([float(v)*coef+1.5 for v in x1]) #强制转换数据类型
# plt.plot(x1,y1)
# plt.grid()
# plt.show()
#构建数据加载器并进行模型构建
dataset = TensorDataset(x,y)
dataloader = DataLoader(dataset,batch_size=16,shuffle=True)
model = Linear(in_features=1,out_features=1)
#构建损失函数(平方损失函数)和优化器(梯度下降)
loss = MSELoss()
#构造优化器
optinizer = SGD(params=model.parameters(),lr=0.01) #设置所有参数,学习率(步长)设置,0.001-0.01之间
#模型训练
"""
1、设置循环的轮次,epochs=100次
2、定义损失函数的变化存放在一个列表中 loss_epoch = []
3、每个轮次的损失,都默认值添加到 total_loss = 0.0 中
4、记录下训练的样本个数
5、遍历每个轮次,根据每个轮次遍历dataloader把数据送入到模型中进行计算损失,梯度清零,反向传播(自动微分),更新参数
6、获取每次轮次的损失,添加到loss_epoch列表中
7、绘制损失函数的变化曲线
8、绘制拟合直线
"""
#定义训练的轮次
epochs = 100
loss_epoch = []
total_loss = 0.0
train_sample = 0.0
for _ in range(epochs):
    for train_x,train_y in dataloader:
        #把batch的训练数据送入模型中
        y_pred = model(train_x.type(torch.float32))
        #计算损失
        loss_value = loss(y_pred,train_y.reshape(-1,1).type(torch.float32))
        #损失累加
        total_loss += loss_value.item()
        #样本累加
        train_sample += len(train_y)
        #梯度清零
        optinizer.zero_grad()
        #自动微分(反向传播)
        loss_value.backward()
        #更新参数
        optinizer.step()
    #获取每个batch的损失
    loss_epoch.append(total_loss/train_sample)

#绘制损失变化曲线
plt.plot(range(epochs),loss_epoch)
plt.show()
#绘制拟合直线
plt.scatter(x,y)
#构建拟合线性回归直线
x1 = torch.linspace(x.min(),x.max(),1000)
y0 = torch.tensor([float(v)*model.weight+ model.bias for v in x1])
y1 = torch.tensor([float(v)*coef+1.5 for v in x1]) #强制转换数据类型
#预测值
plt.plot(x1,y0,label = "预测值")
#真实值
plt.plot(x1,y1,label = "真实值")
plt.grid()
plt.show()

线性回归代码逻辑总结

线性回归是机器学习中最基础的模型之一,用于预测连续值。以下代码使用PyTorch实现了一个完整的线性回归训练流程,分为数据准备、模型构建、训练优化和结果展示四个部分。

数据准备

使用sklearn.datasets.make_regression生成100个带噪声的线性数据样本,特征维度为1。将数据转换为PyTorch张量并创建DataLoader实现批量加载(batch_size=16)。数据可视化部分被注释,但保留了散点图和真实回归线的绘制逻辑。

复制代码
# 生成数据
x, y, coef = make_regression(n_samples=100, n_features=1, noise=10, bias=1.5)
x = torch.tensor(x)
y = torch.tensor(y)

# 创建DataLoader
dataset = TensorDataset(x, y)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)
模型构建

使用PyTorch的Linear层定义单层线性模型,输入输出维度均为1,对应公式: $$ y = wx + b $$

复制代码
model = Linear(in_features=1, out_features=1)
训练配置

选择均方误差(MSE)作为损失函数,随机梯度下降(SGD)作为优化器,学习率设为0.01。关键对象初始化如下:

复制代码
loss_fn = MSELoss()
optimizer = SGD(model.parameters(), lr=0.01)
训练循环

执行100轮训练(epochs),每轮遍历所有批次数据。核心步骤包括:

  • 前向传播计算预测值

  • 计算损失并累加

  • 反向传播更新参数

  • 记录每轮平均损失

    for epoch in range(100):
    for batch_x, batch_y in dataloader:
    pred = model(batch_x.float())
    loss = loss_fn(pred, batch_y.float().view(-1,1))

    复制代码
          optimizer.zero_grad()
          loss.backward()
          optimizer.step()
结果可视化
  1. 损失曲线:展示训练过程中损失值的下降趋势

  2. 拟合对比:将模型预测的直线(蓝色)与真实回归线(橙色)叠加在数据散点图上

    预测值计算示例

    x_range = torch.linspace(x.min(), x.max(), 100)
    y_pred = model(x_range.unsqueeze(1)).detach()

深度学习核心概念通俗解释

模型本质

线性回归可视为最简单的神经网络------只有输入层和输出层。模型通过调整权重w和偏置b来拟合数据中的线性关系。

关键组件
  • 损失函数:衡量预测值与真实值的差距(如MSE)
  • 优化器:通过梯度下降调整参数,逐步减少损失
  • 反向传播:链式法则计算梯度,从输出层回溯到输入层
扩展思考
  1. 增加网络层数和激活函数可升级为深度神经网络
  2. 批量训练(batch)能平衡计算效率和梯度稳定性
  3. 学习率决定参数更新步长,需谨慎选择

该示例虽简单,但包含了深度学习的所有核心要素:数据流、参数优化、损失计算和模型评估。理解这个基础模板后,可以进一步探索更复杂的网络结构。

相关推荐
代码游侠3 分钟前
C语言核心概念复习(二)
c语言·开发语言·数据结构·笔记·学习·算法
XX風15 分钟前
2.1_binary_search_tree
算法·计算机视觉
不想写bug呀27 分钟前
买卖股票问题
算法·买卖股票问题
-Try hard-27 分钟前
完全二叉树、非完全二叉树、哈希表的创建与遍历
开发语言·算法·vim·散列表
茉莉玫瑰花茶1 小时前
C++ 17 详细特性解析(4)
开发语言·c++·算法
long3161 小时前
K‘ 未排序数组中的最小/最大元素 |期望线性时间
java·算法·排序算法·springboot·sorting algorithm
进击的小头1 小时前
FIR滤波器实战:音频信号降噪
c语言·python·算法·音视频
xqqxqxxq1 小时前
洛谷算法1-1 模拟与高精度(NOIP经典真题解析)java(持续更新)
java·开发语言·算法
razelan1 小时前
初级算法技巧 4
算法
砍树+c+v1 小时前
3a 感知机训练过程示例(手算拆解,代码实现)
人工智能·算法·机器学习