使用 PyTorch 实现线性回归:从零开始的完整指南

在机器学习中,线性回归是最基础且广泛使用的算法之一。它通过拟合数据点之间的线性关系,帮助我们理解和预测变量之间的关系。本文将通过一个简单的例子,展示如何使用 PyTorch 框架实现线性回归,并对自定义数据集进行拟合。

1. 线性回归简介

线性回归的目标是找到一个线性方程 y=wx+b,其中 w 是斜率,b 是截距,使得该方程能够尽可能地拟合给定的数据点。在实际应用中,我们通常使用最小二乘法来最小化预测值与真实值之间的误差。

2. 准备数据

首先,我们需要准备一个简单的数据集。在这个例子中,我们将使用一个包含 10 个数据点的自定义数据集:

python 复制代码
data = [
    [-0.5, 7.7],
    [1.8, 98.5],
    [0.9, 57.8],
    [0.4, 39.2],
    [-1.4, -15.7],
    [-1.4, -37.3],
    [-1.8, -49.1],
    [1.5, 75.6],
    [0.4, 34.0],
    [0.8, 62.3]
]

这些数据点表示输入特征 x 和目标变量 y 之间的关系。我们将使用 PyTorch 的张量(Tensor)来存储和处理这些数据。

3. 构建线性回归模型

接下来,我们需要定义一个线性回归模型。在 PyTorch 中,可以通过继承 nn.Module 来定义一个自定义模型。我们将使用一个简单的线性层来实现这个模型:

python 复制代码
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(1, 1)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

这个模型包含一个线性层,其输入维度为 1,输出维度也为 1,正好符合我们的问题需求。

4. 定义损失函数和优化器

为了训练模型,我们需要定义一个损失函数和一个优化器。在这里,我们使用均方误差(MSE)作为损失函数,使用随机梯度下降(SGD)作为优化器:

python 复制代码
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

5. 训练模型

现在,我们可以开始训练模型了。我们将数据集输入模型,计算损失,并通过反向传播更新模型参数。以下是完整的训练代码:

python 复制代码
epochs = 500
for n in range(1, epochs + 1):
    y_pred = model(x_train.unsqueeze(1))
    loss = criterion(y_pred.squeeze(1), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if n % 10 == 0 or n == 1:
        print(f"Epoch: {n}, Loss: {loss.item():.4f}")

在每个 epoch 中,我们计算模型的预测值,计算损失,并通过 loss.backward() 计算梯度,最后通过 optimizer.step() 更新模型参数。

6. 可视化结果

训练完成后,我们可以通过绘制原始数据点和拟合的直线来直观地展示模型的效果。以下是完整的可视化代码:

python 复制代码
plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号


# 绘制原始数据点
plt.scatter(x_data, y_data, color='blue', label='原始数据')

# 绘制拟合的直线
slope = model.layers[0].weight.item()
intercept = model.layers[0].bias.item()
x_fit = np.linspace(x_data.min(), x_data.max(), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, color='red', label='拟合直线')

# 添加图例和标签
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()

运行上述代码后,你将看到如下图像:

从图中可以看出,拟合的直线能够较好地反映数据点之间的线性关系。

7. 总结

通过本文的介绍,你已经学会了如何使用 PyTorch 实现线性回归,并对自定义数据集进行拟合。线性回归虽然简单,但在许多实际问题中都非常有效。希望这篇文章能够帮助你更好地理解和应用线性回归模型。


代码完整版

以下是完整的代码,供你参考和使用:

python 复制代码
import torch
import torch.nn as nn
import numpy as np
from matplotlib import pyplot as plt

# 设置 matplotlib 支持中文显示
plt.rcParams['font.sans-serif'] = ['SimHei']  # 指定中文字体为黑体
plt.rcParams['axes.unicode_minus'] = False  # 正确显示负号

# 定义输入数据
data = [
    [-0.5, 7.7],
    [1.8, 98.5],
    [0.9, 57.8],
    [0.4, 39.2],
    [-1.4, -15.7],
    [-1.4, -37.3],
    [-1.8, -49.1],
    [1.5, 75.6],
    [0.4, 34.0],
    [0.8, 62.3]
]

# 转换为 NumPy 数组
data = np.array(data)
# 提取 x_data 和 y_data
x_data = data[:, 0]
y_data = data[:, 1]

# 将 x_data 和 y_data 转化成 tensor
x_train = torch.tensor(x_data, dtype=torch.float32)
y_train = torch.tensor(y_data, dtype=torch.float32)

# 定义损失函数
criterion = nn.MSELoss()

# 定义线性回归模型
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        self.layers = nn.ModuleList([nn.Linear(1, 1)])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

model = LinearModel()

# 定义优化器
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练模型
epochs = 500
for n in range(1, epochs + 1):
    y_pred = model(x_train.unsqueeze(1))
    loss = criterion(y_pred.squeeze(1), y_train)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if n % 10 == 0 or n == 1:
        print(f"Epoch: {n}, Loss: {loss.item():.4f}")

# 绘制图像
# 绘制原始数据点
plt.scatter(x_data, y_data, color='blue', label='原始数据')

# 绘制拟合的直线
slope = model.layers[0].weight.item()
intercept = model.layers[0].bias.item()
x_fit = np.linspace(x_data.min(), x_data.max(), 100)
y_fit = slope * x_fit + intercept
plt.plot(x_fit, y_fit, color='red', label='拟合直线')

# 添加图例和标签
plt.xlabel('X')
plt.ylabel('Y')
plt.legend()
plt.title('线性回归拟合结果')
plt.show()
相关推荐
子燕若水1 小时前
Unreal Engine 5中的AI知识
人工智能
极限实验室2 小时前
Coco AI 实战(一):Coco Server Linux 平台部署
人工智能
杨过过儿2 小时前
【学习笔记】4.1 什么是 LLM
人工智能
巴伦是只猫2 小时前
【机器学习笔记Ⅰ】13 正则化代价函数
人工智能·笔记·机器学习
大千AI助手2 小时前
DTW模版匹配:弹性对齐的时间序列相似度度量算法
人工智能·算法·机器学习·数据挖掘·模版匹配·dtw模版匹配
AI生存日记3 小时前
百度文心大模型 4.5 系列全面开源 英特尔同步支持端侧部署
人工智能·百度·开源·open ai大模型
LCG元3 小时前
自动驾驶感知模块的多模态数据融合:时序同步与空间对齐的框架解析
人工智能·机器学习·自动驾驶
why技术3 小时前
Stack Overflow,轰然倒下!
前端·人工智能·后端
超龄超能程序猿4 小时前
(三)PS识别:基于噪声分析PS识别的技术实现
图像处理·人工智能·计算机视觉