Pytorch学习16_损失函数与反向传播

官方网址

torch.nn --- PyTorch 2.1 documentation

MSELoss

创建张量

复制代码
inputs=torch.tensor([1,2,3],dtype=torch.float32)
targets=torch.tensor([1,2,5],dtype=torch.float32)

创建inputs和targets张量,其数值分别为1, 2, 31,2,5,数据类型为float32

形状变换

复制代码
inputs=torch.reshape(inputs,(1,1,1,3))
targets=torch.reshape(targets,(1,1,1,3))
复制代码
对inputs和targets进行形状变换,将原始形状(3,)变为(1, 1, 1, 3)
这样的变换通常在深度学习中用于处理需要4D输入的模型,如卷积神经网络 (CNN) 的输入格式
此处 的(1,1,1,3)
其中,1表示批次大小(batch size),1表示通道数,1表示高度(height),3表示宽度(width)

例如:
形状变换前的数据结构:
inputs: [1.0, 2.0, 3.0],形状 (3,)


形状变换后的数据结构:
inputs: [[[[1.0, 2.0, 3.0]]]],形状 (1, 1, 1, 3)
复制代码
loss=L1Loss(reduction='sum')
result=loss(inputs,targets)

使用 PyTorch 中的 L1 损失函数(平均绝对误差损失),并通过 reduction='sum' 参数指定计算总和损失。然后,将 inputstargets 传递给损失函数,计算它们之间的损失值。

复制代码
loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)

使用 PyTorch 中的均方误差损失函数(nn.MSELoss)计算了两个张量 inputstargets 之间的均方误差(Mean Squared Error,MSE)。

解释具体步骤:

  1. loss_mse=nn.MSELoss():创建了一个均方误差损失函数的实例,该实例被存储在变量 loss_mse 中。

  2. result_mse=loss_mse(inputs,targets):使用创建的均方误差损失函数计算了 inputstargets 之间的均方误差,并将结果存储在变量 result_mse 中。

均方误差是回归问题中常用的损失函数,它计算了预测值与真实值之间的差异的平方的均值。在这里,inputs 可能是模型的输出,而 targets 则是真实的标签或目标值。result_mse 中的数值表示了两个张量之间的均方误差,数值越小表示模型的预测越接近真实值。

输出结果

交叉熵

复制代码
x=torch.tensor([0.1,0.2,0.3])
y=torch.tensor([1])
x=torch.reshape(x,(1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(f"result_cross:{result_cross}")

这段代码使用 PyTorch 中的交叉熵损失函数(nn.CrossEntropyLoss)计算了两个张量 xy 之间的交叉熵损失。

解释具体步骤:

  1. x=torch.tensor([0.1,0.2,0.3]):定义了一个包含三个元素的张量 x,这可能是模型的输出。

  2. y=torch.tensor([1]):定义了一个包含一个元素的张量 y,这可能是真实的类别标签。

  3. x=torch.reshape(x,(1,3)):将张量 x 的形状调整为 (1, 3),这是为了与交叉熵损失函数的要求相符。

  4. loss_cross=nn.CrossEntropyLoss():创建了一个交叉熵损失函数的实例,该实例被存储在变量 loss_cross 中。

  5. result_cross=loss_cross(x, y):使用创建的交叉熵损失函数计算了 xy 之间的交叉熵损失,并将结果存储在变量 result_cross 中。

在交叉熵损失中,x 通常是模型的输出,表示各个类别的得分,而 y 是真实的类别标签。result_cross 中的数值表示了两个张量之间的交叉熵损失,数值越小表示模型的预测越接近真实类别。

复制代码
import torch
from torch import nn
from torch.nn import L1Loss

inputs=torch.tensor([1,2,3],dtype=torch.float32)#张量,其数值为[1, 2, 3],数据类型为float32
targets=torch.tensor([1,2,5],dtype=torch.float32)

inputs=torch.reshape(inputs,(1,1,1,3))
targets=torch.reshape(targets,(1,1,1,3))

loss=L1Loss(reduction='sum')
result=loss(inputs,targets)

loss_mse=nn.MSELoss()
result_mse=loss_mse(inputs,targets)

print(f"result:{result}")
print(f"result_mse:{result_mse}")


x=torch.tensor([0.1,0.2,0.3])
y=torch.tensor([1])
x=torch.reshape(x,(1,3))
loss_cross=nn.CrossEntropyLoss()
result_cross=loss_cross(x,y)
print(f"result_cross:{result_cross}")

输出结果


复制代码
dataset=torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),download=True)

dataloader=DataLoader(batch_size=64)

加载 CIFAR-10 数据集并创建一个 DataLoader 对象,其中包含每个批次(batch)包含 64 个样本。

代码全文:

复制代码
import torchvision
from torch import nn
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import DataLoader

dataset=torchvision.datasets.CIFAR10("./dataset",train=False,transform=torchvision.transforms.ToTensor(),
                                     download=True)

dataloader=DataLoader(dataset,batch_size=64)

class Xuex(nn.Module):
    def __init__(self):
        super(Xuex,self).__init__()
        self.model1=Sequential(
            Conv2d(3,32,5,padding=2),
            MaxPool2d(2),
            Conv2d(32, 32, 5, padding=2),
            MaxPool2d(2),
            Conv2d(32, 64, 5, padding=2),
            MaxPool2d(2),
            Flatten(),
            Linear(1024, 64),
            Linear(64, 10)
        )

    def forward(self,x):
        x=self.model1(x)
        return x
xuexu=Xuex()
for data in dataloader:
    imgs,targets=data
    outputs=xuexu(imgs)
    print(f"output:{outputs}")
    print(f"targets:{targets}")
    break

输出结果

~

参考

【PyTorch深度学习快速入门教程(绝对通俗易懂!)【小土堆】】 https://www.bilibili.com/video/BV1hE411t7RN/?p=23\&share_source=copy_web\&vd_source=be33b1553b08cc7b94afdd6c8a50dc5a

相关推荐
朱大喜几秒前
matplotlib/Plotly/ECharts 可视化看板设计:从图表选型到交互体验的工程化实践
人工智能
HappyAcmen3 分钟前
5.通义向量模型调用
python
AOwhisky13 分钟前
Redis 学习笔记(第一期):概述、安装配置与核心理论
运维·数据库·redis·笔记·学习·云计算
云烟成雨TD18 分钟前
Agent Scope Java 2.x 系列【3】从零构建 ReActAgent
java·人工智能·agent
❀抽抽22 分钟前
证件照制作API接入指南:700+规格一键生成
大数据·网络·人工智能
Promise微笑24 分钟前
绝缘油介损(油介损)测试仪的深层机理、技术演进与精准诊断策略
大数据·网络·人工智能
开发者小布28 分钟前
Claude Code 国内配置完整指南:通过中转 API 实现稳定访问(macOS / Linux / Windows)
人工智能
大C聊AI34 分钟前
通用大模型纷纷收费,垂直场景AI工具的价值正在被重估
大数据·人工智能·机器学习·办公效率·ai 工具·智标领航·ai 辅助办公
苏州邦恩精密39 分钟前
2026江苏GOM三维扫描仪定制厂家找哪家?企业数字化转型视角
人工智能·机器学习·3d·自动化·制造