使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
小程故事多_804 分钟前
LangChain1.0系列:中间件深度解析,让 AI智能体上下文控制不失控
人工智能·中间件·langchain
中国国际健康产业博览会28 分钟前
2026第35届中国国际健康产业博览会探索健康与科技的完美结合!
大数据·人工智能
数字化脑洞实验室37 分钟前
选择AI决策解决方案需要注意哪些安全和数据隐私问题?
人工智能·安全
Guheyunyi42 分钟前
安全风险监测系统核心技术
运维·网络·人工智能·安全
golang学习记1 小时前
再见了,claude code
人工智能
杀生丸学AI1 小时前
【动态高斯重建】论文集合:从4DGT到OMG4、4DSioMo
人工智能·3d·aigc·三维重建·视觉大模型·动态高斯
CareyWYR1 小时前
每周AI论文速递(251110-251114)
人工智能
mit6.8241 小时前
[AI tradingOS] 市场数据系统 | 多交易所交易接口 | 适配器模式
人工智能·区块链
ar01231 小时前
AR远程协助公司哪家好?国内外优秀AR技术公司解析
人工智能·ar
zhishidi1 小时前
大模型个性化推荐面试指南
人工智能·面试