pytorch实现线性回归

pytorch实现线性回归

代码

python 复制代码
import torch
import numpy as np
from torch.nn import init
from torch.utils import data
from torch import nn

# 数据集
num_inputs = 2
num_examples = 1000
true_w = [2, -3.4]
true_b = 4.2
features = torch.from_numpy(np.random.normal(0, 1, (num_examples, num_inputs))).type(torch.float32)  # 1000*2
labels = true_w[0] * features[:, 0] + true_w[1] * features[:, 1] + true_b
labels += torch.from_numpy(np.random.normal(0, 0.01, size=labels.size()))  # 噪声

batch_size = 10
# 将训练数据的特征和标签组合
dataset = data.TensorDataset(features, labels)
# 随机读取⼩批量
data_iter = data.DataLoader(dataset, batch_size, shuffle=True)

# 使用框架预定义好的层
net = nn.Sequential(nn.Linear(2, 1))  # 输入是二维,输出是一维

# 初始化模型参数
# net[0].weight.data.normal_(0, 0.01)
# net[0].bias.data.fill_(0)
init.normal_(net[0].weight, mean=0, std=0.01)
init.constant_(net[0].bias, val=0)

# 计算均方误差使用的是MELoss类,也称为L_2范数
loss = nn.MSELoss()
# 实例化SGD(随机梯度下降)实例
optimizer = torch.optim.SGD(net.parameters(), lr=0.03)

# 训练
num_epochs = 3
l = 0
for epoch in range(1, num_epochs + 1):
    for X, y in data_iter:
        output = net(X)
        l = loss(output, y.view(-1, 1))
        optimizer.zero_grad() # 梯度清零,等价于net.zero_grad()
        l.backward()
        optimizer.step()
    print('epoch %d, loss: %f' % (epoch, l.item()))

结果

相关推荐
埃菲尔铁塔_CV算法4 分钟前
人工智能图像算法:开启视觉新时代的钥匙
人工智能·算法
EasyCVR5 分钟前
EHOME视频平台EasyCVR视频融合平台使用OBS进行RTMP推流,WebRTC播放出现抖动、卡顿如何解决?
人工智能·算法·ffmpeg·音视频·webrtc·监控视频接入
打羽毛球吗️11 分钟前
机器学习中的两种主要思路:数据驱动与模型驱动
人工智能·机器学习
好喜欢吃红柚子28 分钟前
万字长文解读空间、通道注意力机制机制和超详细代码逐行分析(SE,CBAM,SGE,CA,ECA,TA)
人工智能·pytorch·python·计算机视觉·cnn
小馒头学python32 分钟前
机器学习是什么?AIGC又是什么?机器学习与AIGC未来科技的双引擎
人工智能·python·机器学习
神奇夜光杯41 分钟前
Python酷库之旅-第三方库Pandas(202)
开发语言·人工智能·python·excel·pandas·标准库及第三方库·学习与成长
正义的彬彬侠44 分钟前
《XGBoost算法的原理推导》12-14决策树复杂度的正则化项 公式解析
人工智能·决策树·机器学习·集成学习·boosting·xgboost
Debroon1 小时前
RuleAlign 规则对齐框架:将医生的诊断规则形式化并注入模型,无需额外人工标注的自动对齐方法
人工智能
羊小猪~~1 小时前
神经网络基础--什么是正向传播??什么是方向传播??
人工智能·pytorch·python·深度学习·神经网络·算法·机器学习
AI小杨1 小时前
【车道线检测】一、传统车道线检测:基于霍夫变换的车道线检测史诗级详细教程
人工智能·opencv·计算机视觉·霍夫变换·车道线检测