[pytorch] 8.损失函数和反向传播

损失函数

torch提供了很多损失函数,可查看官方文档Loss Functions部分

  • 作用:
    1. 计算实际输出和目标输出之间的差距
    2. 为更新输出提供一定的依据(反向传播),grad

损失函数用法差不多,这里以L1Loss和MSEloss为例

  • L1Loss
    注意传入的数据要为float类型,不然会报错,所以inputs和outputs处要加上类型转换
    L1Loss的参数reduction,设置了计算loss值的方式,默认为差距绝对值的均值,也可以设置为'sum',这是输出就为2
  • MSELoss 平方差损失函数
    先看要求的输入输出

    也是batch_size的那种形式
python 复制代码
import torch
from torch.nn import L1Loss
from torch.nn import MSELoss

inputs = torch.tensor([1,2,3],dtype = torch.float32)
outputs = torch.tensor([1,2,5],dtype = torch.float32)

inputs = torch.reshape(inputs, (1,1,1,3))
outputs = torch.reshape(outputs, (1,1,1,3))

# L1Loss()
loss = L1Loss()
result = loss(inputs, outputs)
print(result)

# MSELoss()
loss_mse = MSELoss()
result_mse = loss_mse(inputs, outputs)
print(result_mse)

反向传播

python 复制代码
from torch import nn
import torch
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset=dataset, batch_size=1)


class Test(nn.Module):
    def __init__(self):
        super(Test,self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2), # 计算同上
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten() ,
            nn.Linear(1024, 64),
            nn.Linear(64, 10),
        )
    
    def forward(self, x):
        x = self.model1(x)
        return x
        
net = Test()
loss = nn.CrossEntropyLoss()
for data in dataloader:
    imgs, targets = data
    output = net(imgs)
    resulr_loss = loss(output, targets)
    print(resulr_loss)

加上反向传播后:

python 复制代码
from torch import nn
import torch
from torch.utils.tensorboard import SummaryWriter
import torchvision
from torch.utils.data import DataLoader

dataset = torchvision.datasets.CIFAR10('./dataset', train=False, transform=torchvision.transforms.ToTensor())
dataloader = DataLoader(dataset=dataset, batch_size=1)


class Test(nn.Module):
    def __init__(self):
        super(Test,self).__init__()
        self.model1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=5, padding=2), # 计算同上
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=5, padding=2),
            nn.MaxPool2d(kernel_size=2),
            nn.Flatten() ,
            nn.Linear(1024, 64),
            nn.Linear(64, 10),
        )
    
    def forward(self, x):
        x = self.model1(x)
        return x
        # 这就不需要像之前那种一样一个一个调用了
    
    # 这样网络就写完了

net = Test()
loss = nn.CrossEntropyLoss()
for data in dataloader:
    imgs, targets = data
    output = net(imgs)
    result_loss = loss(output, targets)
    result_loss.backward()  # 注意不是loss.backward(),而是result_loss.backward()
    print('ok')

backward行打断点,进入调试界面可以查看网络内部的参数

weighr里面有grad

运行到backward之前,grad里是none

运行完之后,计算出梯度

后面可以使用优化器,利用计算出来的梯度,对神经网络进行更新

相关推荐
Ma04071310 分钟前
【机器学习】监督学习、无监督学习、半监督学习、自监督学习、弱监督学习、强化学习
人工智能·学习·机器学习
cooldream200911 分钟前
LlamaIndex 存储体系深度解析
人工智能·rag·llamaindex
CoovallyAIHub22 分钟前
如何在手机上轻松识别多种鸟类?我们发现了更简单的秘密……
深度学习·算法·计算机视觉
Elastic 中国社区官方博客29 分钟前
使用 A2A 协议和 MCP 在 Elasticsearch 中创建一个 LLM agent 新闻室:第二部分
大数据·数据库·人工智能·elasticsearch·搜索引擎·ai·全文检索
知识浅谈31 分钟前
我用Gemini3pro 造了个手控全息太阳系
人工智能
孤廖33 分钟前
终极薅羊毛指南:CLI工具免费调用MiniMax-M2/GLM-4.6/Kimi-K2-Thinking全流程
人工智能·经验分享·chatgpt·ai作画·云计算·无人机·文心一言
aneasystone本尊34 分钟前
学习 LiteLLM 的日志系统
人工智能
秋邱39 分钟前
价值升维!公益赋能 + 绿色技术 + 终身学习,构建可持续教育 AI 生态
网络·数据库·人工智能·redis·python·学习·docker
Mintopia41 分钟前
🎭 小众语言 AIGC:当 Web 端的低资源语言遇上“穷得只剩文化”的生成挑战
人工智能·aigc·全栈
安达发公司43 分钟前
安达发|告别手工排产!车间排产软件成为中央厨房的“最强大脑”
大数据·人工智能·aps高级排程·aps排程软件·安达发aps·车间排产软件