8-pytorch-损失函数与反向传播

b站小土堆pytorch教程学习笔记
根据loss更新模型参数

1.计算实际输出与目标之间的差距

2.为我们更新输出提供一定的依据(反向传播)

1 MSEloss

python 复制代码
import torch
from torch.nn import L1Loss
from torch import nn

inputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)

inputs=torch.reshape(inputs,(-1,1,1,3))
targets=torch.reshape(targets,(-1,1,1,3))

loss=L1Loss()
result=loss(inputs,targets)

loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)

print(result)
print(result_mse)

tensor(0.6667)
tensor(1.3333)

2 Cross EntropyLoss

python 复制代码
x=torch.tensor([0.1,0.2,0.3])#需要reshape为要求的(batch_size,class)
y=torch.tensor([1])#target已经为要求的batch_size无需reshape
x=torch.reshape(x,(-1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(result_cross)

tensor(1.1019)

3 在具体的神经网络中使用loss

python 复制代码
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset=torchvision.datasets.CIFAR10('dataset',train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)
dataloader=DataLoader(dataset,batch_size=1)

class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x

loss=nn.CrossEntropyLoss()
han=Han()
for data in dataloader:
    imgs,target=data
    output=han(imgs)
    # print(target)
    # print(output)
    result_loss=loss(output,target)
    print(result_loss)

*tensor([7])

tensor([[ 0.0057, -0.0201, -0.0796, 0.0556, -0.0625, 0.0125, -0.0413, -0.0056,
0.0624, -0.1072]], grad_fn=)...

tensor(2.2664, grad_fn=)...

4 反向传播 优化器

  1. 定义优化器
  2. 将待更新的每个参数梯度清零
  3. 调用损失函数的反向传播函数求出每个节点的梯度
  4. 使用step函数对模型的每个参数调优
python 复制代码
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

dataset=torchvision.datasets.CIFAR10('dataset',train=False,
                                     transform=torchvision.transforms.ToTensor(),
                                     download=True)
dataloader=DataLoader(dataset,batch_size=64)

class Han(nn.Module):
    def __init__(self):
        super(Han, self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32,64,5,padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024,64),
            Linear(64,10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x

loss=nn.CrossEntropyLoss()
han=Han()
optim=torch.optim.SGD(han.parameters(),lr=0.01)

for epoch in range(5):
    running_loss=0.0#一个epoch结束的loss和
    for data in dataloader:
        imgs,target=data
        output=han(imgs)

        result_loss=loss(output,target)#每次迭代的loss
        optim.zero_grad()#将网络中每个可调节参数对应的梯度调为0
        result_loss.backward()#优化器需要每个参数的梯度,使用反向传播获得
        optim.step()#对每个参数调优
        running_loss=running_loss+result_loss
    print(running_loss)

Files already downloaded and verified
tensor(361.0316, grad_fn=)
tensor(357.6938, grad_fn=)
tensor(343.0560, grad_fn=)
tensor(321.8132, grad_fn=)
tensor(313.3173, grad_fn=)

相关推荐
chushiyunen3 分钟前
python中的异常处理
开发语言·python
观书喜夜长7 分钟前
大模型应用开发学习-基于 LangChain 框架实现的交互式问答脚本
python·学习
章鱼丸-10 分钟前
DAY32 官方文档的阅读
python
于慨16 分钟前
docker
python
Kel24 分钟前
深入剖析 openai-node 源码:一个工业级 TypeScript SDK 的架构之美
javascript·人工智能·架构
GinoWi24 分钟前
Chapter 7 Python中的函数
python
m0_5180194828 分钟前
使用Seaborn绘制统计图形:更美更简单
jvm·数据库·python
Hommy8828 分钟前
【剪映小助手-客户端】构建与部署
python·aigc·剪映小助手
GinoWi30 分钟前
Chapter 6 Python中的字典
python
岛雨QA33 分钟前
Skill学习指南🧑‍💻
人工智能·agent·ai编程