【Python/Pytorch - 网络模型】-- TV Loss损失函数

文章目录

文章目录

  • [00 写在前面](#00 写在前面)
  • [01 基于Pytorch版本的TV Loss代码](#01 基于Pytorch版本的TV Loss代码)
  • [02 论文下载](#02 论文下载)

00 写在前面

在医学图像重建过程中,经常在代价方程中加入TV 正则项,该正则项作为去噪项,对于重建可以起到很大帮助作用。但是对于一些纹理细节要求较高的任务,加入TV 正则项,在一定程度上可能会降低纹理细节。

对于连续函数,其表达式为:

对于图片而言,即为离散的数值,求每一个像素和横向下一个像素的差的平方,加上纵向下一个像素的差的平方,再开β/2次根:

01 基于Pytorch版本的TV Loss代码

python 复制代码
import torch
from torch.autograd import Variable


class TVLoss(torch.nn.Module):
    """
    TV loss
    """

    def __init__(self, weight=1):
        super(TVLoss, self).__init__()
        self.weight = weight

    def forward(self, x):
        batch_size = x.size()[0]
        h_x = x.size()[2]
        w_x = x.size()[3]
        count_h = self._tensor_size(x[:, :, 1:, :])
        count_w = self._tensor_size(x[:, :, :, 1:])
        h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
        w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
        return self.weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size

    def _tensor_size(self, t):
        return t.size()[1] * t.size()[2] * t.size()[3]


if __name__ == "__main__":
    x = Variable(
        torch.FloatTensor([[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[1, 2, 3], [4, 5, 6], [7, 8, 9]]]).view(1, 2, 3, 3),
        requires_grad=True)
    tv = TVLoss()
    result = tv(x)
    print(result)

02 论文下载

Understanding Deep Image Representations by Inverting Them

相关推荐
刘简爱学习3 小时前
用于病理图像多类分割的弱监督状态空间模型PathMamba
人工智能·深度学习·计算机视觉
心勤则明3 小时前
使用 Spring AI Alibaba MCP 结合 Nacos 实现企业级智能体应用
java·人工智能·spring
70asunflower3 小时前
AI Infra 架构全景介绍
人工智能·架构
wggmrlee3 小时前
模型训练流程
人工智能
2601_949816164 小时前
使用python进行PostgreSQL 数据库连接
数据库·python·postgresql
l1t4 小时前
在aarch64 Linux环境编译安装CinderX
linux·python
逆境不可逃4 小时前
【用AI学Agent】Agent入门前置:大模型基础(开发向)
人工智能·深度学习·机器学习
热爱生活的猴子4 小时前
PyTorch导出ONNX报错(ShapeInferenceError)问题笔记(含dynamo=False作用解析)
人工智能·pytorch·笔记
新缸中之脑4 小时前
用Kreuzberg提取文档结构
人工智能
Gauss松鼠会4 小时前
【GaussDB】GaussDB技术解读之AI大模型在智能运维场景的应用
运维·人工智能·gaussdb