pytorch学习(五)、损失函数,反向传播和优化器的使用

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录


一、损失函数

  • 损失函数:就是在中学我们学到的回归问题中 y 真 − y 测的绝对值 y真 - y测的绝对值 y真−y测的绝对值
  • 损失函数loss一定是越小越好,也就是预测的数据越接近真实值越好
  • 在下面代码里就是 ∣ y − x ∣ / l e n ( x ) |y - x| / len(x) ∣y−x∣/len(x)
python 复制代码
import torch
from torch.nn import L1Loss
#创建一个数据x
inputs = torch.tensor([1,2,3],dtype=torch.float32)
# 创建一个目标y
targets = torch.tensor([1,2,5],dtype=torch.float32)
# 创建一个损失函数
loss = L1Loss()
# 计算损失
result = loss(inputs,targets)

print(result)
'''
tensor(0.6667)
'''

二、反向传播

  • forward:前向传播
  • backward:反向传播(代码第四十一行)

反向传播就是为更新输出提供一定的依据。

python 复制代码
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader
# 下载数据集
dataset = torchvision.datasets.CIFAR10(root='./dataset/CIFAR10',transform=torchvision.transforms.ToTensor(),download=True)
# 加载数据集
dataloader = DataLoader(dataset,batch_size=1)
# 构建网络
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x
# 实例化网络
model = Model()
# 创建损失函数
loss = nn.CrossEntropyLoss()
# 读取数据
for data in dataloader:
    imgs,targets = data
    # 使用网络
    output = model(imgs)
    # 计算损失
    result_loss = loss(output,targets)
    # print(result_loss)
    # 进行反向传播(也就是梯度从未到有)
    result_loss.backward()

三、优化器

  • 也可以说是梯度下降的,将得到的反向传播的数据进行更新优化,使损失函数逐渐减小来提高模型准确率。
  • 例如:optim = torch.optim.SGD(model.parameters(),lr=0.01) 随机梯度下降
  • 每种优化器前面两个参数几乎都一样,第一个参数放自己模型的训练参数,第二个参数是学习率,学习率不能太大会造成模型的不稳定,也不能太小会降低模型的运行速度。
python 复制代码
import torch
from torch import nn
import torchvision
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10(root='./dataset/CIFAR10',train=False,transform=torchvision.transforms.ToTensor(),download=True)
dataloader = DataLoader(dataset,batch_size=64)

class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(3, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 32, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 5, padding=2),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(1024, 64),
            nn.Linear(64, 10)
        )
    def forward(self,x):
        x = self.model1(x)
        return x

model = Model()
# 创建优化器
optim = torch.optim.SGD(model.parameters(),lr=0.01)
# 创建损失函数
loss = nn.CrossEntropyLoss()
# 迭代次数
for epoch in range(20):
	# 总损失
    running_loss = 0.0
    for data in dataloader:
        imgs,targets = data
        output = model(imgs)
        result_loss = loss(output,targets)
        optim.zero_grad()
        # 进行反向传播
        result_loss.backward()
        # 进行优化梯度随机下降
        optim.step()
        running_loss = running_loss + result_loss
    print(running_loss)
'''
tensor(360.6862, grad_fn=<AddBackward0>)
tensor(357.4413, grad_fn=<AddBackward0>)
tensor(344.1332, grad_fn=<AddBackward0>)
tensor(323.9127, grad_fn=<AddBackward0>)
tensor(315.2773, grad_fn=<AddBackward0>)
tensor(307.1144, grad_fn=<AddBackward0>)
tensor(297.0642, grad_fn=<AddBackward0>)
'''
相关推荐
小A1594 分钟前
STM32完全学习——SPI接口的FLASH(DMA模式)
stm32·嵌入式硬件·学习
桃花键神22 分钟前
AI可信论坛亮点:合合信息分享视觉内容安全技术前沿
人工智能
岁岁岁平安26 分钟前
spring学习(spring-DI(字符串或对象引用注入、集合注入)(XML配置))
java·学习·spring·依赖注入·集合注入·基本数据类型注入·引用数据类型注入
武昌库里写JAVA29 分钟前
Java成长之路(一)--SpringBoot基础学习--SpringBoot代码测试
java·开发语言·spring boot·学习·课程设计
qq_5895681034 分钟前
数据可视化echarts学习笔记
学习·信息可视化·echarts
野蛮的大西瓜43 分钟前
开源呼叫中心中,如何将ASR与IVR菜单结合,实现动态的IVR交互
人工智能·机器人·自动化·音视频·信息与通信
CountingStars6191 小时前
目标检测常用评估指标(metrics)
人工智能·目标检测·目标跟踪
tangjunjun-owen1 小时前
第四节:GLM-4v-9b模型的tokenizer源码解读
人工智能·glm-4v-9b·多模态大模型教程
冰蓝蓝1 小时前
深度学习中的注意力机制:解锁智能模型的新视角
人工智能·深度学习
兔C1 小时前
微信小程序的轮播图学习报告
学习·微信小程序·小程序