使用pytorch,冻结resnet50前几层进行迁移学习

在PyTorch中,冻结ResNet50模型的前几层可以通过以下步骤进行:

python 复制代码
import torch
import torchvision.models as models

# 加载预训练的ResNet50模型
model = models.resnet50(pretrained=True)

# 冻结需要保持不变的层,通常是前几个卷积层
for name, param in model.named_parameters():
    if 'conv1' in name or 'bn1' in name or 'layer1' in name or 'layer2' in name:
        param.requires_grad = False

# 修改最后一层进行微调
num_classes = 10  # 假设输出类别数为10
model.fc = torch.nn.Linear(model.fc.in_features, num_classes)

# 将模型移到GPU上(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# 编译和训练模型
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# 训练模型
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # 打印每个epoch的损失值
    print(f"Epoch {epoch+1} Loss: {running_loss/len(train_loader)}")

在这个例子中,我们加载了预训练的ResNet50模型,并将指定的层参数设置为不需要梯度更新。具体来说,我们冻结了conv1bn1layer1layer2这些层的参数。然后,通过修改最后一层(全连接层)来适应自己的数据集。接下来,将模型移动到GPU上(如果可用),定义损失函数和优化器,并进行模型训练。

请根据你自己的数据集和任务适当调整代码。

相关推荐
AI完全体4 分钟前
【AI知识点】偏差-方差权衡(Bias-Variance Tradeoff)
人工智能·深度学习·神经网络·机器学习·过拟合·模型复杂度·偏差-方差
GZ_TOGOGO16 分钟前
【2024最新】华为HCIE认证考试流程
大数据·人工智能·网络协议·网络安全·华为
sp_fyf_202416 分钟前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-02
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·数据挖掘
新缸中之脑18 分钟前
Ollama 运行视觉语言模型LLaVA
人工智能·语言模型·自然语言处理
胡耀超1 小时前
知识图谱入门——3:工具分类与对比(知识建模工具:Protégé、 知识抽取工具:DeepDive、知识存储工具:Neo4j)
人工智能·知识图谱
陈苏同学1 小时前
4. 将pycharm本地项目同步到(Linux)服务器上——深度学习·科研实践·从0到1
linux·服务器·ide·人工智能·python·深度学习·pycharm
吾名招财1 小时前
yolov5-7.0模型DNN加载函数及参数详解(重要)
c++·人工智能·yolo·dnn
羊小猪~~2 小时前
深度学习项目----用LSTM模型预测股价(包含LSTM网络简介,代码数据均可下载)
pytorch·python·rnn·深度学习·机器学习·数据分析·lstm
鼠鼠龙年发大财2 小时前
【鼠鼠学AI代码合集#7】概率
人工智能
龙的爹23332 小时前
论文 | Model-tuning Via Prompts Makes NLP Models Adversarially Robust
人工智能·gpt·深度学习·语言模型·自然语言处理·prompt