使用pytorch搭建ResNet并基于迁移学习训练

这里的迁移学习方法是载入预训练权重的方法

python 复制代码
    net = resnet34()
    # load pretrain weights
    # download url: https://download.pytorch.org/models/resnet34-333f7ec4.pth
    model_weight_path = "./resnet34-pre.pth"
    assert os.path.exists(model_weight_path), "file {} does not exist.".format(model_weight_path)
    net.load_state_dict(torch.load(model_weight_path, map_location='cpu'))
    # for param in net.parameters():
    #     param.requires_grad = False

    # change fc layer structure
    in_channel = net.fc.in_features
    net.fc = nn.Linear(in_channel, 5)

这里的迁移学习方法是载入预训练权重的方法net = resnet34():注意这里没有传入参数num_classes 因为后面才载入所有的参数,会覆盖我们设定的classes

change fc layer structure

in_channel = net.fc.in_features # fc 为全连接层 in_features为特征矩阵的深度

net.fc = nn.Linear(in_channel, 5)

如果不想使用迁移学习的方法,则注释阴影部分,在net = resnet34()中传入num_classes参数

相关推荐
灵犀物润9 分钟前
机器宠物建模的第一步:基础形体搭建(Blocking)
人工智能·机器人·宠物
人机与认知实验室10 分钟前
触摸大语言模型的边界
人工智能·深度学习·机器学习·语言模型·自然语言处理
神的孩子都在歌唱25 分钟前
PostgreSQL 向量检索方式(pgvector)
数据库·人工智能·postgresql
ARM+FPGA+AI工业主板定制专家1 小时前
基于Jetson+GMSL AI相机的工业高动态视觉感知方案
人工智能·机器学习·fpga开发·自动驾驶
新智元1 小时前
刚刚,谷歌深夜上新 Veo 3.1!网友狂刷 2.75 亿条,Sora 2 要小心了
人工智能·openai
yuzhuanhei1 小时前
Segment Anything(SAM)
人工智能
做科研的周师兄1 小时前
【机器学习入门】7.4 随机森林:一文吃透随机森林——从原理到核心特点
人工智能·学习·算法·随机森林·机器学习·支持向量机·数据挖掘
lll上1 小时前
三步对接gpt-5-pro!地表强AI模型实测
人工智能·gpt
喜欢吃豆1 小时前
一份关于语言模型对齐的技术论述:从基于PPO的RLHF到直接偏好优化
人工智能·语言模型·自然语言处理·大模型·强化学习
超龄超能程序猿2 小时前
Spring AI Alibaba 与 Ollama对话历史的持久化
java·人工智能·spring