PyTorch基本使用-自动微分模块

学习目的:掌握自动微分模块的使用

训练神经网络时,最常用的算法就是反向传播。在该算法中,参数(模型权重)会根据损失函数关于对应参数的梯度进行调整。为了计算这些梯度,PyTorch 内置了名为 torch.autograd的微分引擎。它支持任意计算图的自动梯度计算:

接下来我们使用这个结构进行自动微分模块的介绍。我们使用 backward 方法、grad 属性来实现梯度的计算和访问。

  • 当X为标量时梯度的计算

    python 复制代码
    import torch
    # 1. 当X为标量时梯度的计算
    def test01():
        x = torch.tensor(5)
        # 目标值
        y = torch.tensor(0.)
        # 设置要更新的权重和偏置的初始值
        w = torch.tensor(1.0,requires_grad=True,dtype=torch.float32)
        b = torch.tensor(3.0,requires_grad=True,dtype=torch.float32)
        #设置网络的输出值
        z = x*w + b #矩阵乘法
        # 设置损失函数,并进行损失计算
        loss = torch.nn.MSELoss()
        loss = loss(z,y)
        # 自动微分
        loss.backward()
        # 打印w,b变量的梯度
        # backward 函数计算的梯度值会存储在张量的grad 变量中
        print('W的梯度:',w.grad)
        print('B的梯度:',b.grad)
    
    test01()

    输出结果:

    tex 复制代码
    W的梯度: tensor(80.)
    B的梯度: tensor(16.)
  • 当X为多维张量时梯度计算

    python 复制代码
    import torch
    def test02():
        # 输入张量 2*5
        x = torch.ones(2,5)
        # 目标张量 2*3
        y = torch.zeros(2,3)
        # 设置要更新的权重和偏置的初始值
        w = torch.randn(5,3,requires_grad=True)
        b = torch.randn(3,requires_grad=True)
        #设置网络的输出值
        z = torch.matmul(x,w)+ b #矩阵乘法
        # 设置损失函数,并进行损失计算
        loss = torch.nn.MSELoss()
        loss = loss(z,y)
        # 自动微分
        loss.backward()
        # 打印w,b变量的梯度
        # backward 函数计算的梯度值会存储在张量的grad 变量中
        print('W的梯度:',w.grad)
        print('B的梯度:',b.grad)
    
    test02()

    输出结果:

    tex 复制代码
    W的梯度: tensor([[-1.7502,  0.8537,  0.6175],
            [-1.7502,  0.8537,  0.6175],
            [-1.7502,  0.8537,  0.6175],
            [-1.7502,  0.8537,  0.6175],
            [-1.7502,  0.8537,  0.6175]])
    B的梯度: tensor([-1.7502,  0.8537,  0.6175])
相关推荐
F20226974868 分钟前
使用 Python 爬取某网站简历模板(bs4/lxml+协程)
开发语言·python
cdg==吃蛋糕12 分钟前
pdf读取函数,可以读取本地pdf和url的在线pdf转换为文字
python·pdf
前程的前程也迷茫16 分钟前
flask程序线程问题
python·flask
睡觉狂魔er22 分钟前
自动驾驶控制与规划——Project 1: 车辆纵向控制
人工智能·机器学习·自动驾驶
goomind23 分钟前
YOLOv8实战bdd100k自动驾驶目标识别
人工智能·深度学习·yolo·计算机视觉·目标跟踪·自动驾驶·bdd100k
GIS 数据栈25 分钟前
自动驾驶领域常用的软件与工具
人工智能·机器学习·自动驾驶
博雅智信27 分钟前
人工智能-自动驾驶领域
人工智能·python·深度学习·yolo·机器学习·计算机视觉·自动驾驶
Jackilina_Stone28 分钟前
【自动驾驶】1 自动驾驶概述
人工智能·自动驾驶
数据龙傲天34 分钟前
大数据时代下的电商API接口创新应用
爬虫·python·数据分析·api
whaosoft-1431 小时前
51c深度学习~合集9
人工智能