使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
AI前沿资讯10 分钟前
2026 AI 3D工具推荐:V2Fun如何重新定义“一站式角色创作”
人工智能·3d
水上冰石11 分钟前
Vibe Coding即氛围编程,直觉编程概念介绍
人工智能
Xxtaoaooo24 分钟前
用 JiuwenSwarm 搭建论文写作 Agent 团队:文献检索、大纲生成、语法润色与引用格式避坑
人工智能·论文写作·智能体·jiuwenswarm·agent 团队
云边云科技_云网融合31 分钟前
企业出海的 “数字丝绸之路“:SD-WAN 如何重构全球网络竞争力
大数据·运维·网络·人工智能
超级架构师1 小时前
Huiwen Han — Preprints Public Inventory v10.15
人工智能
技术小黑1 小时前
CNN算法实战系列03 | DenseNet121算法实战与解析
pytorch·深度学习·算法·cnn
189228048612 小时前
NV243美光MT29F32T08GWLBHD6-24QJES:B
大数据·服务器·人工智能·科技·缓存
z小猫不吃鱼2 小时前
02 Transformer 基础:Self-Attention 原理详解
人工智能·深度学习·transformer
是Dream呀2 小时前
vLLM适配昇腾NPU:DeepSeek-V3 PD分离部署完整流程
人工智能
Java后端的Ai之路2 小时前
CodeBuddy-Rules配置
人工智能·python·ai编程