PyTorch小技巧:使用Hook可视化网络层激活(各层输出)

这篇文章将演示如何可视化PyTorch激活层。可视化激活,即模型内各层的输出,对于理解深度神经网络如何处理视觉信息至关重要,这有助于诊断模型行为并激发改进。

我们先安装必要的库:

复制代码
 pip install torch torchvision matplotlib

加载CIFAR-10数据集并可视化一些图像。这有助于理解模型处理的输入。

复制代码
 importtorchvision
 importtorchvision.transformsastransforms
 importmatplotlib.pyplotasplt
 
 # Transformations for the images
 transform=transforms.Compose([
     transforms.Resize(256),
     transforms.CenterCrop(224),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
 ])
 
 # Load CIFAR-10 dataset
 trainset=torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
 trainloader=torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True)
 
 # Function to show images
 defimshow(img):
     img=img.numpy().transpose((1, 2, 0))
     mean=np.array([0.485, 0.456, 0.406])
     std=np.array([0.229, 0.224, 0.225])
     img=std*img+mean  # unnormalize
     plt.imshow(img)
     plt.show()
 
 # Get some images
 dataiter=iter(trainloader)
 images, labels=next(dataiter)
 
 # Display images
 imshow(torchvision.utils.make_grid(images))

看着很模糊的原因是我们使用的CIFAR-10图像32x32的,很小 。因为对于小图像,处理速度很快,所以CIFAR-10称为研究的首选。

然后我们加载一个预训练的ResNet模型,并在特定的层上设置钩子函数,以在向前传递期间捕获激活。

复制代码
 import torch
 from torchvision.models import resnet18
 
 # Load pretrained ResNet18
 model = resnet18(pretrained=True)
 model.eval()  # Set the model to evaluation mode
 
 # Hook setup
 activations = {}
 def get_activation(name):
     def hook(model, input, output):
         activations[name] = output.detach()
     return hook
 
 # Register hooks
 model.layer1[0].conv1.register_forward_hook(get_activation('layer1_0_conv1'))
 model.layer4[0].conv1.register_forward_hook(get_activation('layer4_0_conv1'))

这样,在通过模型处理图像时就能捕获到激活。

复制代码
 # Run the model
 with torch.no_grad():
     output = model(images)

通过上面钩子函数我们获得了激活下面就可以进行可视化

复制代码
 # Visualization function for activations
 def plot_activations(layer, num_cols=4, num_activations=16):
     num_kernels = layer.shape[1]
     fig, axes = plt.subplots(nrows=(num_activations + num_cols - 1) // num_cols, ncols=num_cols, figsize=(12, 12))
     for i, ax in enumerate(axes.flat):
         if i < num_kernels:
             ax.imshow(layer[0, i].cpu().numpy(), cmap='twilight')
             ax.axis('off')
     plt.tight_layout()
     plt.show()
 # Display a subset of activations
 plot_activations(activations['layer1_0_conv1'], num_cols=4, num_activations=16)

结果如下:

复制代码
 plot_activations(activations['layer4_0_conv1'], num_cols=4, num_activations=16)

PyTorch的钩子函数(hooks)是一种非常有用的特性,它们允许你在训练的前向传播和反向传播过程中插入自定义操作。这对于调试、修改梯度或者理解网络的内部运作非常有帮助。

利用 PyTorch 钩子函数来可视化网络中的激活是一种很好的方式,尤其是想要理解不同层如何响应不同输入的情况下。在这个过程中,我们可以捕捉到网络各层的输出,并将其可视化以获得直观的理解。

可视化激活有助于理解卷积神经网络中的各个层如何响应输入图像中的不同特征。通过可视化不同的层,可以评估早期层是否捕获边缘和纹理等基本特征,而较深的层是否捕获更复杂的特征。这些知识对于诊断问题、调整层架构和改进整体模型性能是非常宝贵的。

https://avoid.overfit.cn/post/c63b9b1130fe425ea5b7d0bedf209b2e

相关推荐
Jay Kay1 分钟前
ReLU 新生:从死亡困境到强势回归
人工智能·数据挖掘·回归
Blossom.11811 分钟前
使用Python和Flask构建简单的机器学习API
人工智能·python·深度学习·目标检测·机器学习·数据挖掘·flask
无声旅者43 分钟前
AI 模型分类全解:特性与选择指南
人工智能·ai·ai大模型
Love__Tay1 小时前
【学习笔记】Python金融基础
开发语言·笔记·python·学习·金融
Grassto1 小时前
Cursor Rules 使用
人工智能
MYH5161 小时前
深度学习在非线性场景中的核心应用领域及向量/张量数据处理案例,结合工业、金融等领域的实际落地场景分析
人工智能·深度学习
Lilith的AI学习日记1 小时前
什么是预训练?深入解读大模型AI的“高考集训”
开发语言·人工智能·深度学习·神经网络·机器学习·ai编程
聚客AI2 小时前
PyTorch玩转CNN:卷积操作可视化+五大经典网络复现+分类项目
人工智能·pytorch·神经网络
程序员岳焱2 小时前
深度剖析:Spring AI 与 LangChain4j,谁才是 Java 程序员的 AI 开发利器?
java·人工智能·后端
有风南来2 小时前
算术图片验证码(四则运算)+selenium
自动化测试·python·selenium·算术图片验证码·四则运算验证码·加减乘除图片验证码