论文学习30:LViT: Language Meets Vision Transformerin Medical Image Segmentation

代码来源

https://github.com/HUANGLIZI/LViT

模块作用

深度学习已广泛应用于医学图像分割等领域。然而,由于数据标注成本过高,获取充足的高质量标注数据成为限制现有医学图像分割模型性能的关键。为了缓解这一限制,本文提出了一种新的文本增强医学图像分割模型 LViT(Language Meets Vision Transformer)。在这个 LViT 模型中,引入了医学文本标注来弥补图像数据的质量缺陷。此外,文本信息可以指导在半监督学习中生成更高质量的伪标签。文中还提出了一种指数伪标签迭代机制 (EPI),以帮助像素级注意力模块 (PLAM) 在半监督 LViT 环境中保留局部图像特征。

模块结构

LViT 采用 双 U 形结构,结合 U 形 CNN 分支(提取局部特征)和 U 形 ViT 分支(融合全局及跨模态信息),克服纯 CNN 全局上下文不足和纯 Transformer 计算成本高、忽略局部细节的局限。文本通过简单的嵌入层(BERT-Embed)处理,而非使用完整文本编码器,减少参数量。

模块代码

复制代码
class LViT(nn.Module):
    def __init__(self, config, n_channels=3, n_classes=1, img_size=224, vis=False):
        super().__init__()
        self.vis = vis
        self.n_channels = n_channels
        self.n_classes = n_classes
        in_channels = config.base_channel
        self.inc = ConvBatchNorm(n_channels, in_channels)
        self.downVit = VisionTransformer(config, vis, img_size=224, channel_num=64, patch_size=16, embed_dim=64)
        self.downVit1 = VisionTransformer(config, vis, img_size=112, channel_num=128, patch_size=8, embed_dim=128)
        self.downVit2 = VisionTransformer(config, vis, img_size=56, channel_num=256, patch_size=4, embed_dim=256)
        self.downVit3 = VisionTransformer(config, vis, img_size=28, channel_num=512, patch_size=2, embed_dim=512)
        self.upVit = VisionTransformer(config, vis, img_size=224, channel_num=64, patch_size=16, embed_dim=64)
        self.upVit1 = VisionTransformer(config, vis, img_size=112, channel_num=128, patch_size=8, embed_dim=128)
        self.upVit2 = VisionTransformer(config, vis, img_size=56, channel_num=256, patch_size=4, embed_dim=256)
        self.upVit3 = VisionTransformer(config, vis, img_size=28, channel_num=512, patch_size=2, embed_dim=512)
        self.down1 = DownBlock(in_channels, in_channels * 2, nb_Conv=2)
        self.down2 = DownBlock(in_channels * 2, in_channels * 4, nb_Conv=2)
        self.down3 = DownBlock(in_channels * 4, in_channels * 8, nb_Conv=2)
        self.down4 = DownBlock(in_channels * 8, in_channels * 8, nb_Conv=2)
        self.up4 = UpblockAttention(in_channels * 16, in_channels * 4, nb_Conv=2)
        self.up3 = UpblockAttention(in_channels * 8, in_channels * 2, nb_Conv=2)
        self.up2 = UpblockAttention(in_channels * 4, in_channels, nb_Conv=2)
        self.up1 = UpblockAttention(in_channels * 2, in_channels, nb_Conv=2)
        self.outc = nn.Conv2d(in_channels, n_classes, kernel_size=(1, 1), stride=(1, 1))
        self.last_activation = nn.Sigmoid()  # if using BCELoss
        self.multi_activation = nn.Softmax()
        self.reconstruct1 = Reconstruct(in_channels=64, out_channels=64, kernel_size=1, scale_factor=(16, 16))
        self.reconstruct2 = Reconstruct(in_channels=128, out_channels=128, kernel_size=1, scale_factor=(8, 8))
        self.reconstruct3 = Reconstruct(in_channels=256, out_channels=256, kernel_size=1, scale_factor=(4, 4))
        self.reconstruct4 = Reconstruct(in_channels=512, out_channels=512, kernel_size=1, scale_factor=(2, 2))
        self.pix_module1 = PixLevelModule(64)
        self.pix_module2 = PixLevelModule(128)
        self.pix_module3 = PixLevelModule(256)
        self.pix_module4 = PixLevelModule(512)
        self.text_module4 = nn.Conv1d(in_channels=768, out_channels=512, kernel_size=3, padding=1)
        self.text_module3 = nn.Conv1d(in_channels=512, out_channels=256, kernel_size=3, padding=1)
        self.text_module2 = nn.Conv1d(in_channels=256, out_channels=128, kernel_size=3, padding=1)
        self.text_module1 = nn.Conv1d(in_channels=128, out_channels=64, kernel_size=3, padding=1)

    def forward(self, x, text):
        x = x.float()  # x [4,3,224,224]
        x1 = self.inc(x)  # x1 [4, 64, 224, 224]
        text4 = self.text_module4(text.transpose(1, 2)).transpose(1, 2) 
        text3 = self.text_module3(text4.transpose(1, 2)).transpose(1, 2)
        text2 = self.text_module2(text3.transpose(1, 2)).transpose(1, 2)
        text1 = self.text_module1(text2.transpose(1, 2)).transpose(1, 2)
        y1 = self.downVit(x1, x1, text1)
        x2 = self.down1(x1)
        y2 = self.downVit1(x2, y1, text2)
        x3 = self.down2(x2)
        y3 = self.downVit2(x3, y2, text3)
        x4 = self.down3(x3)
        y4 = self.downVit3(x4, y3, text4)
        x5 = self.down4(x4)
        y4 = self.upVit3(y4, y4, text4, True)
        y3 = self.upVit2(y3, y4, text3, True)
        y2 = self.upVit1(y2, y3, text2, True)
        y1 = self.upVit(y1, y2, text1, True)
        x1 = self.reconstruct1(y1) + x1
        x2 = self.reconstruct2(y2) + x2
        x3 = self.reconstruct3(y3) + x3
        x4 = self.reconstruct4(y4) + x4
        x = self.up4(x5, x4)
        x = self.up3(x, x3)
        x = self.up2(x, x2)
        x = self.up1(x, x1)
        if self.n_classes == 1:
            logits = self.last_activation(self.outc(x))
        else:
            logits = self.outc(x)  # if not using BCEWithLogitsLoss or class>1
        return logits

总结

本文提出了一种新的视觉语言医学图像分割模型LViT,该模型利用医学文本标注来弥补图像数据的质量缺陷,并在半监督学习中指导生成更高质量的伪标签。为了评估LViT的性能,研究人员构建了多模态医学分割数据集(图像+文本),实验结果表明,这个模型在全监督和半监督环境下均具有卓越的分割性能。此外,研究人员还提供了一个关于早期食管癌诊疗的示例应用,以展示文本标注如何在实际场景中发挥作用。目前,这个模型是一个二维分割模型。在未来的工作中,研究人员将把模型扩展到三维,并在更多医学数据上进行实验,以进一步验证其通用性。此外,在当前版本的LViT 模型中,需要在推理阶段提供文本输入。因此,未来的工作可以是根据提供的图像信息自动生成文本注释。由于文本注释是结构化的,可以在 LViT 的未来版本中将文本注释生成问题转化为分类问题。这能够支持有或无文本输入的推理,从而增强模型的可用性。

相关推荐
海天一色y7 小时前
动手学深度学习
人工智能·深度学习
大熊背7 小时前
(3dnr)多帧视频图像去噪 (二)
计算机视觉·音视频·isp·3dnr
爱隐身的官人8 小时前
爬虫基础学习 - Beautifulsoup
爬虫·学习·beautifulsoup
diablobaal8 小时前
云计算学习100天-第44天-部署邮件服务器
学习·云计算
dbdr09018 小时前
Linux 入门到精通,真的不用背命令!零基础小白靠「场景化学习法」,3 个月拿下运维 offer,第二十四天
linux·运维·c语言·python·学习
diablobaal9 小时前
云计算学习100天-第34天 -zabbix监控2
学习·云计算·zabbix
艾伦~耶格尔9 小时前
【MySQL体系结构详解:一条SQL查询的旅程】
数据库·学习·mysql
fanged10 小时前
HarvardX TinyML小笔记2(番外3:数据工程)
深度学习
小王爱学人工智能10 小时前
OpenCV的轮廓检测
人工智能·opencv·计算机视觉