使用 PyTorch 实现线性回归:从零开始的完整指南

在机器学习中,线性回归是最基础且广泛使用的算法之一。它通过拟合数据点之间的线性关系,帮助我们理解和预测变量之间的关系。本文将通过一个简单的例子,展示如何使用 PyTorch 框架实现线性回归,并对自定义数据集进行拟合。

1. 线性回归简介

线性回归的目标是找到一个线性方程 y=wx+b,其中 w 是斜率,b 是截距,使得该方程能够尽可能地拟合给定的数据点。在实际应用中,我们通常使用最小二乘法来最小化预测值与真实值之间的误差。

2. 准备数据

首先,我们需要准备一个简单的数据集。在这个例子中,我们将使用一个包含 10 个数据点的自定义数据集:

python 复制代码
data = [
    [-0.5, 7.7],
    [1.8, 98.5],
    [0.9, 57.8],
    [0.4, 39.2],
    [-1.4, -15.7],
    [-1.4, -37.3],
    [-1.8, -49.1],
    [1.5, 75.6],
    [0.4, 34.0],
    [0.8, 62.3]
]

这些数据点表示输入特征 x 和目标变量 y 之间的关系。我们将使用 PyTorch 的张量(Tensor)来存储和处理这些数据。

3. 构建线性回归模型

接下来,我们需要定义一个线性回归模型。在 PyTorch 中,可以通过继承 nn.Module 来定义一个自定义模型。我们将使用一个简单的线性层来实现这个模型:

python 复制代码
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(1, 1)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

这个模型包含一个线性层,其输入维度为 1,输出维度也为 1,正好符合我们的问题需求。

4. 定义损失函数和优化器

为了训练模型,我们需要定义一个损失函数和一个优化器。在这里,我们使用均方误差(MSE)作为损失函数,使用随机梯度下降(SGD)作为优化器:

python 复制代码
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

5. 训练模型

现在,我们可以开始训练模型了。我们将数据集输入模型,计算损失,并通过反向传播更新模型参数。以下是完整的训练代码:

python 复制代码
epochs = 500
for n in range(1, epochs + 1):
    y_pred = model(x_train.unsqueeze(1))
    loss = criterion(y_pred.squeeze(1), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if n % 10 == 0 or n == 1:
        print(f"Epoch: {n}, Loss: {loss.item():.4f}")

在每个 epoch 中,我们计算模型的预测值,计算损失,并通过 loss.backward() 计算梯度,最后通过 optimizer.step() 更新模型参数。

6. 可视化结果

训练完成后,我们可以通过绘制原始数据点和拟合的直线来直观地展示模型的效果。以下是完整的可视化代码:

python 复制代码
plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号


# 绘制原始数据点
plt.scatter(x_data, y_data, color='blue', label='原始数据')

# 绘制拟合的直线
slope = model.layers[0].weight.item()
intercept = model.layers[0].bias.item()
x_fit = np.linspace(x_data.min(), x_data.max(), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, color='red', label='拟合直线')

# 添加图例和标签
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()

运行上述代码后,你将看到如下图像:

从图中可以看出,拟合的直线能够较好地反映数据点之间的线性关系。

7. 总结

通过本文的介绍,你已经学会了如何使用 PyTorch 实现线性回归,并对自定义数据集进行拟合。线性回归虽然简单,但在许多实际问题中都非常有效。希望这篇文章能够帮助你更好地理解和应用线性回归模型。


代码完整版

以下是完整的代码,供你参考和使用:

python 复制代码
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt

# 设置 matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号

# 定义输入数据
data = [
    [-0.5, 7.7],
    [1.8, 98.5],
    [0.9, 57.8],
    [0.4, 39.2],
    [-1.4, -15.7],
    [-1.4, -37.3],
    [-1.8, -49.1],
    [1.5, 75.6],
    [0.4, 34.0],
    [0.8, 62.3]
]

# 转换为 NumPy 数组
data = np.array(data)
# 提取 x_data 和 y_data
x_data = data[:, 0]
y_data = data[:, 1]

# 将 x_data 和 y_data 转化成 tensor
x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)

# 定义损失函数
criterion = nn.MSELoss()

# 定义线性回归模型
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(1, 1)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = LinearModel()

# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 500
for n in range(1, epochs + 1):
    y_pred = model(x_train.unsqueeze(1))
    loss = criterion(y_pred.squeeze(1), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if n % 10 == 0 or n == 1:
        print(f"Epoch: {n}, Loss: {loss.item():.4f}")

# 绘制图像
# 绘制原始数据点
plt.scatter(x_data, y_data, color='blue', label='原始数据')

# 绘制拟合的直线
slope = model.layers[0].weight.item()
intercept = model.layers[0].bias.item()
x_fit = np.linspace(x_data.min(), x_data.max(), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, color='red', label='拟合直线')

# 添加图例和标签
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()
相关推荐
袁庭新6 分钟前
2025年07月总结
人工智能·aigc·编程语言
2501_9248792615 分钟前
强反光干扰下漏检率↓79%!陌讯多模态融合算法在油罐车识别的边缘计算优化
人工智能·算法·计算机视觉·视觉检测·边缘计算
下页、再停留1 小时前
【PHP】接入百度AI开放平台人脸识别API,实现人脸对比
人工智能·百度·php
松果财经1 小时前
外卖“0元购”退场后,即时零售大战才刚开始
大数据·人工智能
说私域1 小时前
基于开源链动2+1模式AI智能名片S2B2C商城小程序的私域流量拉新策略研究
人工智能·小程序·开源
永洪科技1 小时前
永洪科技华西地区客户交流活动成功举办!以AI之力锚定增长确定性
大数据·人工智能·科技·数据分析·数据可视化
京东零售技术1 小时前
京东零售在智能供应链领域的前沿探索与技术实践
人工智能·百度·零售
小小小小小鹿2 小时前
Ai入门-结合rag搭建一个专属的ai学习助手
人工智能·llm
华东数交2 小时前
数本归源——数据资产化的需求
人工智能