使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
猿人谷14 小时前
不只是 CPU 阈值:STAR 如何用 GAT + Transformer 做容器级自动扩缩容?
人工智能·算法
说了很好15 小时前
PyTorch从零搭建DDPM:时间嵌入+UNet网络+扩散调度完整复现
人工智能
Bigfish_coding15 小时前
前端转agent-【python】-06 长期记忆(向量数据库 + 嵌入)
人工智能
小林ixn15 小时前
别再手写Prompt了!用AI Loop实现自动化自我迭代,效率提升10倍
人工智能·自动化运维
说了很好15 小时前
逐行注释DDPM源码:正向加噪、逆向去噪、MSE损失全流程复现
人工智能
Dilee15 小时前
Spring AI 1.1.7 接入 MCP:Filesystem Server 最小 Demo
人工智能·后端
Token炼金师15 小时前
大模型推理超参数原理详解
人工智能
Token炼金师15 小时前
大模型训练超参数:从Loss曲面到收敛策略的底层逻辑
人工智能
后端小肥肠15 小时前
Skill 囤了一堆却用不起来?我用 Codex 写了个整理神器
人工智能·agent
魏祖潇15 小时前
从"会聊天"到"能干活":用 OpenCode 给自己找个 AI 搭子
人工智能