自定义数据集使用框架的线性回归方法对其进行拟合

代码:

复制代码
# 导入必要的库
import torch
import numpy as np
import matplotlib.pyplot as plt

# 定义数据集:二维数据,其中第一列是特征 x,第二列是目标值 y
data = [[-0.5, 7.7],
        [1.8, 98.5],
        [0.9, 57.8],
        [0.4, 39.2],
        [-1.4, -15.7],
        [-1.4, -37.3],
        [-1.8, -49.1],
        [1.5, 75.6],
        [0.4, 34.0],
        [0.8, 62.3]]

# 将数据转换为 NumPy 数组
data = np.array(data)

# 提取特征(x_data)和目标值(y_data)
x_data = data[:, 0]  # 提取第一列作为特征
y_data = data[:, 1]  # 提取第二列作为目标

# 将 NumPy 数组转换为 PyTorch 张量(tensor),数据类型为浮点型
x_train = torch.tensor(x_data, dtype=torch.float32)  # 输入特征
y_train = torch.tensor(y_data, dtype=torch.float32)  # 目标值

# 导入损失函数模块,使用均方误差 (MSELoss)
import torch.nn as nn

criterion = nn.MSELoss()  # 均方误差损失


# 定义线性回归模型类(继承自 nn.Module)
class LinearModel(nn.Module):
    def __init__(self):
        super(LinearModel, self).__init__()
        # 定义一个线性层(一个输入节点,一个输出节点)
        self.layers = nn.Linear(1, 1)  # 输入特征数为1,输出为1

    # 前向传播方法:输入 x 返回经过线性层的输出
    def forward(self, x):
        x = self.layers(x)  # 线性层处理输入
        return x  # 返回结果


# 创建模型实例
model = LinearModel()

# 定义优化器:随机梯度下降(SGD),学习率为 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# 训练模型的 epoch 数量,设定为 500 次
epochs = 500

# 开始训练
for n in range(1, epochs + 1):
    # 预测:将 x_train 输入到模型中进行预测
    y_pred = model(x_train.unsqueeze(1))  # unsqueeze(1) 是将 x_train 从 [batch_size] 转换为 [batch_size, 1],以符合模型输入要求

    # 计算损失:损失是预测值和实际值之间的均方误差
    loss = criterion(y_pred.squeeze(1), y_train)  # squeeze(1) 是将 y_pred 从 [batch_size, 1] 转换为 [batch_size]

    # 优化器梯度清零
    optimizer.zero_grad()

    # 反向传播计算梯度
    loss.backward()

    # 使用优化器更新模型参数
    optimizer.step()

    # 每 10 个 epoch 或第 1 个 epoch 打印一次损失,并实时显示训练过程
    if n % 10 == 0 or n == 1:
        # 清除之前的图像
        plt.clf()

        # 绘制原始数据点,使用蓝色表示
        plt.scatter(x_data, y_data, color='blue')

        # 绘制当前预测的回归线,使用红色表示
        plt.plot(x_data, y_pred.detach().numpy(), color='red')  # detach() 是为了防止从计算图中分离,避免梯度计算

        # 设置图表的 x 和 y 轴标签
        plt.xlabel('X')
        plt.ylabel('Y')

        # 设置图表的标题,显示当前 epoch 数
        plt.title(f'Epoch {n}')

        # 暂停 0.1 秒,实时更新图表
        plt.pause(0.1)

# 训练完成后,显示最终的图表
plt.show()

结果:

相关推荐
草履虫建模10 小时前
力扣算法 1768. 交替合并字符串
java·开发语言·算法·leetcode·职场和发展·idea·基础
naruto_lnq12 小时前
分布式系统安全通信
开发语言·c++·算法
Jasmine_llq13 小时前
《P3157 [CQOI2011] 动态逆序对》
算法·cdq 分治·动态问题静态化+双向偏序统计·树状数组(高效统计元素大小关系·排序算法(预处理偏序和时间戳)·前缀和(合并单个贡献为总逆序对·动态问题静态化
爱吃rabbit的mq13 小时前
第09章:随机森林:集成学习的威力
算法·随机森林·集成学习
(❁´◡`❁)Jimmy(❁´◡`❁)14 小时前
Exgcd 学习笔记
笔记·学习·算法
YYuCChi14 小时前
代码随想录算法训练营第三十七天 | 52.携带研究材料(卡码网)、518.零钱兑换||、377.组合总和IV、57.爬楼梯(卡码网)
算法·动态规划
不能隔夜的咖喱15 小时前
牛客网刷题(2)
java·开发语言·算法
VT.馒头15 小时前
【力扣】2721. 并行执行异步函数
前端·javascript·算法·leetcode·typescript
进击的小头15 小时前
实战案例:51单片机低功耗场景下的简易滤波实现
c语言·单片机·算法·51单片机
咖丨喱17 小时前
IP校验和算法解析与实现
网络·tcp/ip·算法