pytorch修改ConvNeXt-T网络

使用迁移学习,修改ConvNeXt-T网络,对特征进行融合

python 复制代码
import torch
import torch.nn as nn
import torchvision.models as models


class CustomConvNeXtT(nn.Module):
    def __init__(self, in_channels=3, num_classes=2, chunk=2, csv_shape=107, CSV=True):
        super(CustomConvNeXtT, self).__init__()
        self.chunk = chunk
        self.num_classes = num_classes
        self.CSV = CSV

        # 加载预训练的ConvNeXt-Tiny模型
        convnext = models.convnext_tiny(pretrained=True)

        # 冻结预训练模型的所有参数
        for name, param in convnext.named_parameters():
            param.requires_grad = False

        # 将修改后的模型赋值给自定义的ConvNeXt-T网络
        self.model = convnext

        # 修改第一个卷积层的输入通道数
        self.model.features[0][0] = nn.Conv2d(in_channels, 96, kernel_size=4, stride=4)

        # 获取特征提取器的输出特征维度
        num_ftrs = self.model.classifier[2].in_features

        # 修改分类头部
        self.model.classifier = nn.Sequential(
            nn.LayerNorm(num_ftrs * self.chunk + (csv_shape if CSV else 0), eps=1e-6, elementwise_affine=True),
            nn.Linear(num_ftrs * self.chunk + (csv_shape if CSV else 0), num_classes)
        )

    def extract_features(self, x):
        x = self.model.features(x)
        x = self.model.avgpool(x)
        x = torch.flatten(x, 1)
        return x

    def forward(self, data_DCE, data_T2, csv):
        data_DCE = self.extract_features(data_DCE)
        data_T2 = self.extract_features(data_T2)

        if not self.CSV:
            csv = torch.ones_like(csv)

        x = torch.cat((data_DCE, data_T2, csv), dim=1)
        print(f"Feature size after concatenation: {x.size()}")  # 打印特征拼接后的尺寸

        output = self.model.classifier(x)
        return output


if __name__ == '__main__':
    net = CustomConvNeXtT(in_channels=3, num_classes=2, chunk=2, csv_shape=107, CSV=True)
    for name, param in net.named_parameters():
        print(name, ":", param.requires_grad)

    data_DCE = torch.randn(64, 3, 224, 224)
    data_T2 = torch.randn(64, 3, 224, 224)
    csv = torch.randn(64, 107)

    output = net(data_DCE, data_T2, csv)
    print("输出特征尺寸:", output.size())
相关推荐
TechWayfarer10 分钟前
IP归属地API实战指南:用IP数据云解析日志挖掘用户地域分布
大数据·开发语言·网络·python·tcp/ip
Cloud_Shy61815 分钟前
Python 数据分析基础入门:《Excel Python:飞速搞定数据分析与处理》学习笔记系列(第十一章 Python 包跟踪器 中篇)
数据库·python·sql·数据分析·excel·web
端平入洛28 分钟前
Python 可变对象与引用穿透:为什么改了"里面的东西"外面也变了?
python
文歌子38 分钟前
TorchGeo 入门:用 PyTorch 处理遥感数据,从零搭建卫星图像分类模型
深度学习
woon38 分钟前
从“涂掉红色”到“删除 PDF 对象”:一次 PDF 去印章脚本改造实践
python
老纪1 小时前
c++怎么利用std--variant处理多种二进制子协议包的自动分支解析【进阶】
jvm·数据库·python
茗创科技1 小时前
Nat Hum Behav | 特征选择会导致基于脑影像的机器学习生物标志物产生迥异的神经生物学解释
python·深度学习·机器学习·matlab·脑网络
IT策士1 小时前
Django 从 0 到 1 打造完整电商平台:Django 模型进阶与数据迁移
python·django·sqlite
OsDepK1 小时前
AudioSplit音频多轨免费分离工具即将发布
ide·git·python·音视频·集成学习
dr_yingli2 小时前
MedGemma皮肤肿瘤6分类LLM fineturn流程
人工智能·深度学习