使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
小锋学长生活大爆炸1 分钟前
【软件】AI Agent:无需电脑的手机自动化助手AutoGLM
运维·人工智能·智能手机·自动化·手机·agent·autoglm
ar01232 分钟前
AR巡检私有化本地化部署:企业数字化转型的关键一步
人工智能·ar
Hcoco_me5 分钟前
大模型面试题39:KV Cache 完全指南
人工智能·深度学习·自然语言处理·transformer·word2vec
小途软件5 分钟前
基于计算机视觉的课堂行为编码研究
人工智能·python·深度学习·计算机视觉·语言模型·自然语言处理·django
盼小辉丶5 分钟前
PyTorch实战——pix2pix详解与实现
pytorch·深度学习·生成模型
小途软件6 分钟前
基于计算机视觉的桥梁索力测试方法
人工智能·python·语言模型·自然语言处理·django
拓端研究室6 分钟前
2025医疗人工智能报告:AI应用、IVD市场、健康科技|附240+份报告PDF、数据、可视化模板汇总下载
大数据·人工智能·物联网
咚咚王者7 分钟前
人工智能之核心基础 机器学习 第七章 监督学习总结
人工智能·学习·机器学习
2501_941507947 分钟前
【人工智能】基于YOLO11-C3k2-LFE模型的LED灯目标检测与识别系统研究
人工智能·目标检测·计算机视觉
不爱学英文的码字机器8 分钟前
用 openJiuwen 构建 AI Agent:从 Hello World 到毒舌编辑器
人工智能·redis·编辑器