使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
Anycall.Q几秒前
RE-IMAGEN(ICLR 2023)
人工智能·计算机视觉·imagen
CS创新实验室12 分钟前
AI 领域的 Harness Engineering:概念、实践与前景综述
人工智能·机器学习·aigc·harness
Gary jie26 分钟前
OpenClaw4月更新的梦境记忆巩固系统
人工智能·深度学习·opencv·目标检测·机器学习·长短时记忆网络
beyond阿亮26 分钟前
Claude Code零基础入门安装使用指南
人工智能·ai·claude code
赵侃侃爱分享26 分钟前
AI怎么定义网络安全
人工智能·安全·web安全
ZhiqianXia27 分钟前
Pytorch 学习笔记(8): PyTorch FX
pytorch·笔记·学习
key_3_feng30 分钟前
MCP协议:解锁AI模型与外部世界的高效协作
大数据·人工智能·mcp
Linux猿30 分钟前
高通量藻类细胞检测数据集,YOLO目标检测|附数据集下载
人工智能·yolo·目标检测·目标跟踪·yolo目标检测·yolo目标检测数据集·高通量藻类细胞检测数据集
薛定猫AI32 分钟前
【技术干货】用 design.md 驯服 AI 生成前端:从 Awesome Design 到工程化落地实践
前端·人工智能
枫叶林FYL34 分钟前
第1章 具身智能的本质与哲学基础
人工智能·机器学习