使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
sanduo1128 分钟前
AI 原生(AI-Native)&架构极简主义
人工智能·架构·ai-native
zhengfei61133 分钟前
【开源渗透工具】——一个开源的多模态大型语言模型红队框架OpenRT
人工智能·语言模型·开源
WJSKad123536 分钟前
工业零件识别与分类:基于lad_r50-paa-r101_fpn_2xb8_coco_1x模型实现
人工智能·分类·数据挖掘
千汇数据的老司机44 分钟前
靠资源拿项目VS靠技术拿项目,二者的深刻区分。
大数据·人工智能·谈单
聚城云-GeecityCloud1 小时前
物业行业:在矛盾与转型中回归服务本质
人工智能·数据挖掘·回归
a3158238061 小时前
基于大语言模型的新闻判断技术
人工智能·语言模型·自然语言处理
亚里随笔1 小时前
超越LoRA:参数高效强化学习方法的全面评估与突破
人工智能·深度学习·机器学习·lora·rl
computersciencer2 小时前
机器学习入门:什么是机器学习
人工智能·机器学习
Java后端的Ai之路2 小时前
【机器学习】- CatBoost模型参数详细说明
人工智能·机器学习·catboost·模型参数
java1234_小锋2 小时前
AI蒸馏技术:让AI更智能、更高效
人工智能·ai·ai蒸馏