目录
[partial cross entropy loss](#partial cross entropy loss)
partial cross entropy loss
python
import torch
import torch.nn.functional as F
def partial_cross_entropy_loss(inputs, targets, ignore_index=-1):
"""
自定义部分交叉熵损失函数,忽略 ignore_index 指定的标签。
:param inputs: 模型的输出,形状应为 (N, C, H, W),其中 N 是批量大小,C 是类别数,H 和 W 是高度和宽度。
:param targets: 真实的标签,形状应为 (N, H, W)。
:param ignore_index: 要忽略的标签值,默认为 -1。
:return: 计算得到的损失。
"""
# 计算 log softmax
log_probs = F.log_softmax(inputs, dim=1)
# 将 log_probs 和 targets 转换为适合 gather 的形状
log_probs = log_probs.permute(0, 2, 3, 1) # (N, H, W, C)
log_probs = log_probs.reshape(-1, log_probs.shape[-1]) # (N*H*W, C)
targets = targets.view(-1) # (N*H*W)
# 掩码未标记的数据点
mask = targets != ignore_index
log_probs = log_probs[mask]
targets = targets[mask]
# 只计算有标签的数据点的损失
loss = F.nll_loss(log_probs, targets, reduction='mean')
return loss
python
# 假设模型的输出和真实标签
outputs = torch.randn(2, 3, 5, 5) # 随机生成模拟输出(2个样本,3个类别,5x5的图像)
targets = torch.tensor([[[-1, 1, -1, 0, -1],
[1, -1, 2, 2, 1],
[-1, -1, 1, -1, 0],
[2, 2, 2, -1, 1],
[-1, 0, -1, 0, 1]],
[[1, 0, -1, 1, -1],
[2, 2, -1, 0, 0],
[-1, 1, 1, 0, -1],
[0, 0, 2, -1, 1],
[2, -1, 0, -1, -1]]]) # 生成带有未标记区域的标签
# 计算损失
loss = partial_cross_entropy_loss(outputs, targets)
print(f"Loss: {loss.item()}")