使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
Suryxin.20 小时前
从0开始复现nano-vllm「ModelRunner.capture_cudagraph()」
人工智能·pytorch·深度学习·vllm
武汉唯众智创20 小时前
云边端协同落地:唯众AI实训平台技术架构实操解析
人工智能·人工智能实训·ai 实训平台·职教 ai 实训·职教院校实训方案·高校职校实训方案
大猫子的技术日记21 小时前
Playwright 自动化测试入门指南:Python 开发者的端到端实战
开发语言·人工智能·python
数琨创享TQMS质量数智化21 小时前
数琨创享:以数智化质量目标管理闭环赋能可量化、可追溯、可驱动的质量运营
大数据·人工智能·qms质量管理系统
laplace012321 小时前
Kv cache
人工智能·agent·claude·rag·skills
Maynor99621 小时前
OpenClaw 中转站配置完全指南
linux·运维·服务器·人工智能·飞书
Eric22321 小时前
CLI-Agent-Manager:面向 Vibe Coding 的多 Agent 统一管理面板
人工智能·后端·开源
如若12321 小时前
SoftGroup训练FORinstance森林点云数据集——从零到AP=0.506完整复现
人工智能·python·深度学习·神经网络·计算机视觉
InternLM21 小时前
LMDeploy重磅更新:从支撑模型到被模型反哺,推理引擎迈入协同进化时代!
人工智能·大模型·多模态大模型·大模型推理·书生大模型
AI周红伟21 小时前
周红伟老师《企业级 RAG+Agent+Skills+OpenClaw 智能体实战内训》的完整课程大纲(5 天 / 40 小时)
人工智能