使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
咚咚王者几秒前
人工智能之数据分析 Pandas:第四章 常用函数
人工智能·数据分析·pandas
菩提树下的凡夫3 分钟前
Yolov11的空标注负样本技术在模型训练中的应用
人工智能·深度学习·yolo
夕小瑶4 分钟前
DeepSeek V3.2的隐藏更新,却意外暴露了MiniMax
人工智能
kebijuelun5 分钟前
Nemotron-Flash: Towards Latency-Optimal Hybrid Small Language Models
人工智能·语言模型·自然语言处理
三炭先生6 分钟前
计算机视觉算法--第一章:概述
人工智能·算法·计算机视觉
唯道行8 分钟前
计算机图形学·21 梁友栋-Barsky直线裁剪算法与三维直线裁剪
人工智能·算法·机器学习·计算机视觉·计算机图形学·opengl
阿杰学AI10 分钟前
AI核心知识32——大语言模型之多模态语音(简洁且通俗易懂版)
人工智能·ai·语言模型·自然语言处理·aigc·语音识别·多模态语音
九河云11 分钟前
智能家居生态数字化:设备联动场景化编程与用户习惯学习系统建设
人工智能·学习·智能家居
阿恩.77014 分钟前
国际会议:评职称、申博、考研的硬核加分项
人工智能·经验分享·笔记·计算机网络·能源
严文文-Chris14 分钟前
【机器学习三大范式对比总结】
人工智能·机器学习