使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
数智工坊几秒前
周志华《Machine Learning》学习笔记--第五章--神经网络
人工智能·笔记·神经网络·学习·机器学习
虹科网络安全4 分钟前
艾体宝产品|从知识孤岛到智能知识中心:Arango 如何重塑企业知识图谱
人工智能·知识图谱·arango
189228048618 分钟前
NV041固态MT29F16T08GSLCEM9-QBES:C
人工智能·算法·microsoft·缓存·性能优化
STRUGGLE_xlf10 分钟前
Agent 基础
人工智能·agent
博览鸿蒙11 分钟前
[特殊字符]AI+FPGA 全栈学习大纲【就业版】定位
人工智能·学习·fpga开发
极客侃科技11 分钟前
哪款AI的同声传译好用?天禧AI 4.0多语种同传表现出众
人工智能
雪隐12 分钟前
AI股票小助手05-用 Flask 把 MiniQMT 变成 REST API
人工智能·后端
霸道流氓气质18 分钟前
Spring AI Ollama 连接超时问题排查与解决:OkHttp 读超时配置全指南
人工智能·spring·okhttp
道友可好23 分钟前
Spec Kit:GitHub 官方出品,规范即代码
前端·人工智能·后端
weixin_5051544631 分钟前
打通工业安全治理“最后一公分”:Bowell 发布 Runtime 治理平台
大数据·人工智能·安全·3d·数字孪生·数据可视化