Pytorch笔记之回归

文章目录


前言

以线性回归为例,记录Pytorch的基本使用方法。


一、导入库

python 复制代码
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch.autograd import Variable # 定义求导变量
from torch import nn, optim # 定义网络模型和优化器

二、数据处理

将数据类型转为tensor,第一维度变为batch_size

python 复制代码
# 构建数据
x = np.random.rand(100)
noise = np.random.normal(0, 0.01, x.shape)
y = 0.1 * x + 0.2 + noise
# 数据处理
x_data = torch.FloatTensor(x.reshape(-1, 1))
y_data = torch.FloatTensor(y.reshape(-1, 1))
inputs = Variable(x_data)
target = Variable(y_data)

三、构建模型

1、继承nn.Module,定义一个线性回归模型。在__init__中定义连接层,定义前向传播的方法

2、实例化模型,定义损失函数与优化器

python 复制代码
# 继承模型
class LinearRegression(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(1, 1)
    def forward(self, x):
        out = self.fc(x)
        return out
# 定义模型
print('模型参数')
model = LinearRegression()
mse_loss = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
for name, param in model.named_parameters():
    print('{}:{}'.format(name, param))

四、迭代训练

1、梯度清零:optimizer.zero_grad()

2、反向传播计算梯度值:loss.backward()

3、执行参数更新:optimizer.step()

循环迭代,定期输出损失值

python 复制代码
print('损失值')
for i in range(1001):
    out = model.forward(inputs)
    loss = mse_loss(out, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if i % 200 == 0:
        print(i, loss.item())

五、结果预测

绘制样本的散点图与预测值的折线图

python 复制代码
print('结果预测')
y_pred = model(x_data)
plt.plot(x, y, 'b.')
plt.plot(x, y_pred.data.numpy(), 'r-')
plt.show()

总结

使用Pytorch进行训练主要的三步:

(1)数据处理:将数据维度转换为(batch, *),数据类型转换为可训练的tensor;

(2)构建模型:继承nn.Module,定义连接层与运算方法,实例化,定义损失函数与优化器;

(3)迭代训练:循环迭代,依次执行梯度清零、梯度计算、参数更新。

相关推荐
冰暮流星几秒前
javascript之history对象介绍
前端·笔记
jialiguo1 小时前
博客摘录「 尚硅谷Vue3入门到实战,最新版Vue3+TypeScript前端开发教程」2024年8月7日
笔记
三无推导1 小时前
ComfyUI 安装部署教程:Windows 下快速搭建可视化 AI 绘图工作流,零基础也能跑通
人工智能·pytorch·windows·stable diffusion·aigc·ai绘画·持续部署
風清掦2 小时前
【STM32学习笔记-14】WDG看门狗 - 14.2 WWDG窗口看门狗
笔记·stm32·单片机·嵌入式硬件·学习·fpga开发
晓梦林3 小时前
bughush靶场学习笔记
笔记·学习
独隅3 小时前
PyTorch自动微分模块:从原理到实战一
人工智能·pytorch·python
sakiko_4 小时前
Swift学习笔记34-MVC架构,SwiftUI与UIkit混编练习
笔记·学习·swiftui·mvc·swift
Afans_fire4 小时前
多渠道广告归因:3种逻辑解决效果分配难题
笔记·内容运营·广告投放·广告营销·徐州巨量星河
泉飒4 小时前
qt软件无法打开编译
笔记·工业视觉
穗余5 小时前
2026 AI x Web3 School共学营笔记-Day10-Women Builders in AI × Web3
人工智能·笔记·web3