如何在 PyTorch 中冻结模型权重以进行迁移学习:分步教程

一、说明

迁移学习是一种机器学习技术,其中预先训练的模型适用于新的但类似的问题。迁移学习的关键步骤之一是能够冻结预训练模型的层,以便在训练期间仅更新网络的某些部分。当您想要保留预训练模型已经学习的特征时,冻结至关重要。在本教程中,我们将使用一个简单的示例来演示在 PyTorch 中冻结权重以进行迁移学习的过程。

二、先决条件

如果您没有安装 torch 和 torchvision 库,我们可以在终端中执行以下操作:

aba 复制代码
pip install torch torchvision 

三、导入库

让我们从 Python 代码开始。首先,我们导入本教程的库:

aba 复制代码
import torch
import torch.nn as nn
import torchvision.models as models

四、加载预训练模型

我们将在此示例中使用预训练的 ResNet-18 模型:

aba 复制代码
# Load the pre-trained model
resnet18 = models.resnet18(pretrained=True)

五、冻结层

要冻结图层,我们将requires_grad属性设置为False。这可以防止 PyTorch 在反向传播期间计算这些层的梯度。

aba 复制代码
# Freeze all layers
for param in resnet18.parameters():
    param.requires_grad = False

六、解冻一些层

通常,为了获得最佳结果,我们会对网络中的后续层进行一些微调。我们可以这样做:

aba 复制代码
# Unfreeze last layer
for param in resnet18.fc.parameters():
    param.requires_grad = True

七、修改网络架构

我们将替换最后一个全连接层,以使模型适应具有不同数量的输出类(假设有 10 个类)的新问题。此外,这使我们能够将这个预训练网络用于分类以外的其他应用,例如分割。对于分割,我们用卷积层替换最后一层。对于此示例,我们继续执行包含 10 个类别的分类任务。

aba 复制代码
# Replace last layer
num_ftrs = resnet18.fc.in_features
resnet18.fc = nn.Linear(num_ftrs, 10)

八、训练修改后的模型

让我们定义一个简单的训练循环。出于演示目的,我们将使用随机数据。

aba 复制代码
# Create random data
inputs = torch.randn(5, 3, 224, 224)
labels = torch.randint(0, 10, (5,))

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(resnet18.fc.parameters(), lr=0.001, momentum=0.9)

# Training loop
for epoch in range(5):
    optimizer.zero_grad()
    outputs = resnet18(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    print(f'Epoch {epoch+1}/5, Loss: {loss.item()}')  

在此示例中,训练期间仅更新最后一层的权重。

九、结论

在 PyTorch 中冻结层非常简单明了。通过将该requires_grad属性设置为False,您可以防止在训练期间更新特定层,从而使您能够有效地利用预训练模型的强大功能。

了解如何在 PyTorch 中冻结和解冻层对于有效的迁移学习至关重要,因为它允许您利用预训练的模型来执行类似但不同的任务。通过这种简单而强大的技术,您可以在训练深度神经网络时节省时间和计算资源。

参考资料:请访问此处、GithubLinkedIn礼萨·卡兰塔尔

相关推荐
2601_95578198几秒前
HTML5 静态网站搭建 依托 OpenClaw 完成设计与部署
人工智能·教程分享·open claw部署·open claw本地部署
王莎莎-MinerU2 分钟前
从 OCR 到 Context Engineering:用 MinerU 搭一个可复现文档解析评测
人工智能·深度学习·机器学习·pdf·ocr·个人开发
叫我:松哥2 分钟前
基于卷积神经网络的静态手势语识别算法,在测试集上的识别准确率达到97.5%
人工智能·python·深度学习·神经网络·算法·cnn
ZHW_AI课题组3 分钟前
基于KNN的帕尔默企鹅种类预测分类
人工智能·机器学习·分类·数据挖掘
财迅通Ai4 分钟前
探路者:锚定端侧AI压缩黄金赛道,硬核科技开启成长新周期
人工智能·科技·探路者
小马哥crazymxm4 分钟前
Arxiv论文周选 (2026-W23)
论文阅读·人工智能·科技
虎妞05004 分钟前
PyTorch 2.0 生产级部署与性能优化指南
pytorch·深度学习·ai·模型部署·cuda
独自归家的兔5 分钟前
Claude Fable 5 与 Claude Mythos 5 全面解析及定价策略分析
人工智能·深度学习
BD好产品8 分钟前
2026年度AI接口聚合方案实测复盘:从多模型混战看企业级工程选型
人工智能
YOLO数据集集合9 分钟前
智能道路病害识别 公路巡检深度学习数据集实战 | 路面缺陷检测 无人机视觉 道路养护AI方案10299期
人工智能·深度学习·目标检测·无人机