最简单的线性回归神经网络

数据:

# 线性回归
import torch
import numpy as np
import matplotlib.pyplot as plt

# 随机种子,确保每次运行结果一致
torch.manual_seed(42)

# 生成训练数据
X = torch.randn(100, 3)  # 100 个样本,每个样本 3 个特征
true_w = torch.tensor([2.0, 3.0, 4.5] )  # 假设真实权重
true_b = 4.0  # 偏置项
Y = X @ true_w + true_b + torch.randn(100) * 0.2  # 加入一些噪声

# 打印部分数据
print(X[:5])
print(Y[:5])

模型:

import torch.nn as nn

# 定义线性回归模型
class LinearRegressionModel(nn.Module):
    def __init__(self):
        super(LinearRegressionModel, self).__init__()
        # 定义一个线性层,输入为2个特征,输出为1个预测值
        self.linear = nn.Linear(3, 1)  # 输入维度2,输出维度1
    
    def forward(self, x):
        return self.linear(x)  # 前向传播,返回预测结果

# 创建模型实例
model = LinearRegressionModel()

# 损失函数(均方误差)
criterion = nn.MSELoss()

# 优化器(使用 SGD 或 Adam)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # 学习率设置为0.01

训练:

# 训练模型
num_epochs = 1000  # 训练 1000 轮
for epoch in range(num_epochs):
    model.train()  # 设置模型为训练模式

    # 前向传播
    predictions = model(X)  # 模型输出预测值
    loss = criterion(predictions.squeeze(), Y)  # 计算损失(注意预测值需要压缩为1D)

    # 反向传播
    optimizer.zero_grad()  # 清空之前的梯度
    loss.backward()  # 计算梯度
    optimizer.step()  # 更新模型参数

    # 打印损失
    if (epoch + 1) % 100 == 0:
        print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item():.4f}')

# 查看训练后的权重和偏置
print(f'Predicted weight: {model.linear.weight.data.numpy()}')
print(f'Predicted bias: {model.linear.bias.data.numpy()}')

# 在新数据上做预测
with torch.no_grad():  # 评估时不需要计算梯度
    predictions = model(X)

# 可视化预测与实际值
plt.scatter(X[:, 0], Y, color='blue', label='True values')
plt.scatter(X[:, 0], predictions, color='red', label='Predictions')
plt.legend()
plt.show()
相关推荐
feifeikon19 分钟前
PyTorch DAY2: 搭建神经网络
人工智能·pytorch·神经网络
迪小莫学AI40 分钟前
高效解决 LeetCode 2270: 分割数组的方案数
算法·leetcode·职场和发展
egoist20231 小时前
数据结构之顺序结构二叉树(超详解)
c语言·开发语言·数据结构·学习·算法·二叉树·向上/下调整算法
MichaelIp1 小时前
Pytorch基础教程:从零实现手写数字分类
人工智能·pytorch·python·深度学习·神经网络·机器学习·分类
pzx_0012 小时前
【论文阅读】基于空间相关性与Stacking集成学习的风电功率预测方法
论文阅读·人工智能·算法·机器学习·bootstrap·集成学习
梅茜Mercy2 小时前
蓝桥杯备赛:顺序表和单链表相关算法题详解(上)
算法·职场和发展·蓝桥杯
廖显东-ShirDon 讲编程2 小时前
《零基础Go语言算法实战》【题目 4-3】请用 Go 语言编写一个验证栈序列是否为空的算法
算法·程序员·go语言·web编程·go web
墨绿色的摆渡人3 小时前
用 Python 从零开始创建神经网络(二十一):保存和加载模型及其参数
人工智能·python·深度学习·神经网络·机器学习
游王子3 小时前
机器学习(1):线性回归概念
人工智能·机器学习·线性回归
圆圆滚滚小企鹅。3 小时前
刷题记录 回溯算法-10:93. 复原 IP 地址
数据结构·python·算法·leetcode