自定义数据集使用框架的线性回归方法对其进行拟合

代码:

复制代码
# 导入必要的库
import torch
import numpy as np
import matplotlib.pyplot as plt

# 定义数据集:二维数据,其中第一列是特征 x,第二列是目标值 y
data = [[-0.5, 7.7],
        [1.8, 98.5],
        [0.9, 57.8],
        [0.4, 39.2],
        [-1.4, -15.7],
        [-1.4, -37.3],
        [-1.8, -49.1],
        [1.5, 75.6],
        [0.4, 34.0],
        [0.8, 62.3]]

# 将数据转换为 NumPy 数组
data = np.array(data)

# 提取特征(x_data)和目标值(y_data)
x_data = data[:, 0]  # 提取第一列作为特征
y_data = data[:, 1]  # 提取第二列作为目标

# 将 NumPy 数组转换为 PyTorch 张量(tensor),数据类型为浮点型
x_train = torch.tensor(x_data, dtype=torch.float32)  # 输入特征
y_train = torch.tensor(y_data, dtype=torch.float32)  # 目标值

# 导入损失函数模块,使用均方误差 (MSELoss)
import torch.nn as nn

criterion = nn.MSELoss()  # 均方误差损失


# 定义线性回归模型类(继承自 nn.Module)
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # 定义一个线性层(一个输入节点,一个输出节点)
        self.layers = nn.Linear(1, 1)  # 输入特征数为1,输出为1

    # 前向传播方法:输入 x 返回经过线性层的输出
    def forward(self, x):
        x = self.layers(x)  # 线性层处理输入
        return x  # 返回结果


# 创建模型实例
model = LinearModel()

# 定义优化器:随机梯度下降(SGD),学习率为 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练模型的 epoch 数量,设定为 500 次
epochs = 500

# 开始训练
for n in range(1, epochs + 1):
    # 预测:将 x_train 输入到模型中进行预测
    y_pred = model(x_train.unsqueeze(1))  # unsqueeze(1) 是将 x_train 从 [batch_size] 转换为 [batch_size, 1],以符合模型输入要求

    # 计算损失:损失是预测值和实际值之间的均方误差
    loss = criterion(y_pred.squeeze(1), y_train)  # squeeze(1) 是将 y_pred 从 [batch_size, 1] 转换为 [batch_size]

    # 优化器梯度清零
    optimizer.zero_grad()

    # 反向传播计算梯度
    loss.backward()

    # 使用优化器更新模型参数
    optimizer.step()

    # 每 10 个 epoch 或第 1 个 epoch 打印一次损失,并实时显示训练过程
    if n % 10 == 0 or n == 1:
        # 清除之前的图像
        plt.clf()

        # 绘制原始数据点,使用蓝色表示
        plt.scatter(x_data, y_data, color='blue')

        # 绘制当前预测的回归线,使用红色表示
        plt.plot(x_data, y_pred.detach().numpy(), color='red')  # detach() 是为了防止从计算图中分离,避免梯度计算

        # 设置图表的 x 和 y 轴标签
        plt.xlabel('X')
        plt.ylabel('Y')

        # 设置图表的标题,显示当前 epoch 数
        plt.title(f'Epoch {n}')

        # 暂停 0.1 秒,实时更新图表
        plt.pause(0.1)

# 训练完成后,显示最终的图表
plt.show()

结果:

相关推荐
勇闯逆流河21 分钟前
【数据结构】堆
c语言·数据结构·算法
pystraf1 小时前
LG P9844 [ICPC 2021 Nanjing R] Paimon Segment Tree Solution
数据结构·c++·算法·线段树·洛谷
飞川撸码2 小时前
【LeetCode 热题100】739:每日温度(详细解析)(Go语言版)
算法·leetcode·golang
yuhao__z2 小时前
代码随想录算法训练营第六十六天| 图论11—卡码网97. 小明逛公园,127. 骑士的攻击
算法
Echo``2 小时前
3:OpenCV—视频播放
图像处理·人工智能·opencv·算法·机器学习·视觉检测·音视频
Nobkins3 小时前
2021ICPC四川省赛个人补题ABDHKLM
开发语言·数据结构·c++·算法·图论
88号技师3 小时前
2025年6月一区SCI-不实野燕麦优化算法Animated Oat Optimization-附Matlab免费代码
开发语言·算法·matlab·优化算法
ysy16480672394 小时前
03算法学习_977、有序数组的平方
学习·算法
codists4 小时前
《算法导论(第4版)》阅读笔记:p83-p85
算法
Tiny番茄4 小时前
归一化函数 & 激活函数
人工智能·算法·机器学习