Pytorch中的损失函数及其用法

来自 千问

在 PyTorch 中,损失函数(Loss Function)是连接模型预测值与真实标签之间的"尺子",它衡量了模型犯错的程度。模型训练的目标就是通过优化算法(如 SGD、Adam)不断调整参数,使得这个"尺子"量出来的差距(损失值)越来越小。

损失函数主要位于 torch.nn 模块中。根据任务类型的不同(回归或分类),我们选择的损失函数也不同。

以下是 PyTorch 中最常用的几种损失函数及其用法详解:

1. 均方误差损失 (MSELoss)

这是回归任务中最常用的损失函数。

  • 公式:计算预测值与真实值之差的平方的平均值。
  • 特点:对异常值(Outliers)比较敏感,因为误差被平方放大了。
  • 适用场景:预测房价、温度、销量等连续数值。
python 复制代码
import torch
import torch.nn as nn

# 定义损失函数
loss_fn = nn.MSELoss()

# 模拟数据:假设我们预测了3个样本的值
pred = torch.tensor([2.5, 0.8, 1.2])  # 模型预测值
target = torch.tensor([3.0, 1.0, 1.0])  # 真实值

# 计算损失
loss = loss_fn(pred, target)
print(f"MSE Loss: {loss.item()}") 
# 手动计算:((2.5-3)^2 + (0.8-1)^2 + (1.2-1)^2) / 3 = 0.11

2. 交叉熵损失 (CrossEntropyLoss)

这是多分类任务(如图像分类)的首选损失函数。

  • 内部机制 :它结合了 LogSoftmaxNLLLoss(负对数似然损失)。
  • 特点
    • 输入是模型原始输出的 logits(不需要经过 Softmax 激活)。
    • 目标标签是类别索引(整数),而不是 one-hot 编码。
  • 适用场景:手写数字识别、猫狗分类等。
python 复制代码
# 定义损失函数
loss_fn = nn.CrossEntropyLoss()

# 模拟数据:3个样本,4个类别
logits = torch.tensor([
    [2.0, 1.0, 0.1, 0.5],  # 样本1的预测分数
    [0.5, 2.0, 0.3, 0.1],  # 样本2
    [0.1, 0.2, 3.0, 0.4]   # 样本3
])

# 真实标签:样本1是第0类,样本2是第1类,样本3是第2类
labels = torch.tensor([0, 1, 2])

# 计算损失
loss = loss_fn(logits, labels)
print(f"CrossEntropy Loss: {loss.item()}")

3. 二元交叉熵损失 (BCELoss / BCEWithLogitsLoss)

用于二分类多标签分类任务。

  • 区别
    • BCELoss:需要输入经过 Sigmoid 激活后的概率值(0~1之间)。
    • BCEWithLogitsLoss推荐使用。它将 Sigmoid 层内置到了损失函数中,数值计算更稳定。
  • 适用场景:判断图片中是否有猫(是/否),或者一张图片同时有多个标签(如既有"猫"又有"狗")。
python 复制代码
# 推荐使用 BCEWithLogitsLoss
loss_fn = nn.BCEWithLogitsLoss()

# 模拟数据:预测是否包含某个特征
pred = torch.tensor([1.0, -2.0, 0.5])  # 模型输出的 logits
target = torch.tensor([1.0, 0.0, 1.0]) # 真实标签(0或1)

loss = loss_fn(pred, target)
print(f"BCEWithLogits Loss: {loss.item()}")

4. L1 损失 (L1Loss)

也称为平均绝对误差(MAE)。

  • 公式:计算预测值与真实值之差的绝对值的平均。
  • 特点:对异常值比 MSE 更鲁棒(不敏感),因为没有平方操作。
  • 适用场景:数据中存在较多异常值的回归任务。
python 复制代码
loss_fn = nn.L1Loss()
pred = torch.tensor([2.5, 0.8, 1.2])
target = torch.tensor([3.0, 1.0, 1.0])

loss = loss_fn(pred, target)
print(f"L1 Loss: {loss.item()}") 
# 计算:(0.5 + 0.2 + 0.2) / 3 = 0.3

总结与选择建议

损失函数 PyTorch 类名 适用任务 关键特点
均方误差 nn.MSELoss() 回归任务 对大误差惩罚重,对异常值敏感
交叉熵 nn.CrossEntropyLoss() 多分类任务 输入为 logits,标签为整数索引
二元交叉熵 nn.BCEWithLogitsLoss() 二分类/多标签 内置 Sigmoid,数值稳定
L1 损失 nn.L1Loss() 回归任务 对异常值更鲁棒

在实际使用中,你通常会先实例化损失函数,然后在训练循环中计算损失并进行反向传播:

python 复制代码
# 伪代码:训练循环中的典型用法
criterion = nn.CrossEntropyLoss() # 1. 定义
optimizer = torch.optim.Adam(model.parameters())

for data, target in dataloader:
    optimizer.zero_grad()          # 梯度清零
    output = model(data)           # 2. 前向传播
    loss = criterion(output, target) # 3. 计算损失
    loss.backward()                # 4. 反向传播
    optimizer.step()               # 5. 更新参数
相关推荐
胡摩西2 分钟前
毫米级精准定位如何实现机器人自动回充:技术原理与工程实现
人工智能·机器学习·机器人·slam·室内定位·agv·roomaps
高洁017 分钟前
学习基于数字孪生的质量预测与控制
人工智能·python·深度学习·数据挖掘·transformer
上海蓝色星球9 分钟前
造价机器人CER V2.0正式上线!
大数据·人工智能·智慧城市·运维开发
CeshirenTester10 分钟前
2026春招规则彻底变了,应届生必须看懂这4个信号
人工智能
无心水10 分钟前
【OpenClaw:进阶开发】12、掌控每一个像素:OpenClaw + CDP 打造无界浏览器自动化
人工智能·cdp·openclaw·ai前沿·养龙虾·无界浏览器
xier_ran13 分钟前
【第一周】关键词解释:倒数排名融合(Reciprocal Rank Fusion, RRF)算法
开发语言·python·算法
HelloWorld__来都来了13 分钟前
如何用python爬取上市公司信息
开发语言·python
开朗觉觉17 分钟前
将json字符串转换为json对象
linux·服务器·python
飞升不如收破烂~19 分钟前
Transformer 架构:用「工厂流水线」讲透(无代码、纯人话)
人工智能·深度学习·transformer
2501_9481142421 分钟前
星链4SAPI + OpenClaw实战:给GPT-5.4与Claude 4.6装上“职业传送门”
python·gpt·架构