半监督语义分割学习笔记

目录

[partial cross entropy loss](#partial cross entropy loss)


GitHub - LiheYoung/UniMatch: [CVPR 2023] Revisiting Weak-to-Strong Consistency in Semi-Supervised Semantic Segmentation

partial cross entropy loss

python 复制代码
import torch
import torch.nn.functional as F

def partial_cross_entropy_loss(inputs, targets, ignore_index=-1):
    """
    自定义部分交叉熵损失函数,忽略 ignore_index 指定的标签。
    
    :param inputs: 模型的输出,形状应为 (N, C, H, W),其中 N 是批量大小,C 是类别数,H 和 W 是高度和宽度。
    :param targets: 真实的标签,形状应为 (N, H, W)。
    :param ignore_index: 要忽略的标签值,默认为 -1。
    :return: 计算得到的损失。
    """
    # 计算 log softmax
    log_probs = F.log_softmax(inputs, dim=1)
    
    # 将 log_probs 和 targets 转换为适合 gather 的形状
    log_probs = log_probs.permute(0, 2, 3, 1)  # (N, H, W, C)
    log_probs = log_probs.reshape(-1, log_probs.shape[-1])  # (N*H*W, C)
    targets = targets.view(-1)  # (N*H*W)
    
    # 掩码未标记的数据点
    mask = targets != ignore_index
    log_probs = log_probs[mask]
    targets = targets[mask]
    
    # 只计算有标签的数据点的损失
    loss = F.nll_loss(log_probs, targets, reduction='mean')
    
    return loss
python 复制代码
# 假设模型的输出和真实标签
outputs = torch.randn(2, 3, 5, 5)  # 随机生成模拟输出(2个样本,3个类别,5x5的图像)
targets = torch.tensor([[[-1, 1, -1, 0, -1], 
                         [1, -1, 2, 2, 1], 
                         [-1, -1, 1, -1, 0], 
                         [2, 2, 2, -1, 1], 
                         [-1, 0, -1, 0, 1]], 
                        [[1, 0, -1, 1, -1], 
                         [2, 2, -1, 0, 0], 
                         [-1, 1, 1, 0, -1], 
                         [0, 0, 2, -1, 1], 
                         [2, -1, 0, -1, -1]]])  # 生成带有未标记区域的标签

# 计算损失
loss = partial_cross_entropy_loss(outputs, targets)
print(f"Loss: {loss.item()}")
相关推荐
寻丶幽风2 小时前
论文阅读笔记——双流网络
论文阅读·笔记·深度学习·视频理解·双流网络
令狐前生3 小时前
设计模式学习整理
学习·设计模式
湘-枫叶情缘4 小时前
解构认知边界:论万能方法的本体论批判与方法论重构——基于跨学科视阈的哲学-科学辩证
科技·学习·重构·生活·学习方法
inputA5 小时前
【LwIP源码学习6】UDP部分源码分析
c语言·stm32·单片机·嵌入式硬件·网络协议·学习·udp
海尔辛5 小时前
学习黑客5 分钟读懂Linux Permissions 101
linux·学习·安全
真的想上岸啊6 小时前
学习51单片机01(安装开发环境)
嵌入式硬件·学习·51单片机
sz66cm7 小时前
Linux基础 -- SSH 流式烧录与压缩传输笔记
linux·笔记·ssh
每次的天空7 小时前
Android学习总结之Glide自定义三级缓存(面试篇)
android·学习·glide
名誉寒冰8 小时前
# KVstorageBaseRaft-cpp 项目 RPC 模块源码学习
qt·学习·rpc
开发游戏的老王8 小时前
[虚幻官方教程学习笔记]深入理解实时渲染(An In-Depth Look at Real-Time Rendering)
笔记·学习·虚幻