Python训练营打卡Day42

知识点回顾

  1. 回调函数
  2. lambda函数
  3. hook函数的模块钩子和张量钩子
  4. Grad-CAM的示例

1. 回调函数(Callback Function)

回调函数是作为参数传递给另一个函数的函数,目的是在某个事件发生后执行。

python 复制代码
def fetch_data(callback):
    # 模拟数据获取
    data = {"name": "Alice", "age": 30}
    callback(data)

def process_data(data):
    print(f"处理数据: {data['name']}, {data['age']}岁")

# 使用回调函数
fetch_data(process_data)

2. Lambda 函数(匿名函数)

Lambda 函数是一种轻量级的匿名函数,适用于简单操作。

python 复制代码
# 常规函数
def add(a, b):
    return a + b

# 等效的lambda函数
add_lambda = lambda a, b: a + b

# 使用lambda函数
result = add_lambda(5, 3)
print(f"Lambda结果: {result}")

# 在高阶函数中使用lambda
numbers = [1, 2, 3, 4, 5]
squared = list(map(lambda x: x**2, numbers))
print(f"平方结果: {squared}")

3. Hook 函数

Hook 函数允许在不修改原始代码的情况下注入自定义逻辑,常见的有模块钩子和张量钩子。

python 复制代码
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 16, kernel_size=3)
        self.relu = nn.ReLU()
        
    def forward(self, x):
        x = self.conv(x)
        x = self.relu(x)
        return x

# 定义钩子函数
def activation_hook(module, input, output):
    print(f"{module.__class__.__name__}输出形状: {output.shape}")

model = MyModel()
# 注册钩子到ReLU模块
hook_handle = model.relu.register_forward_hook(activation_hook)

# 测试模型
x = torch.randn(1, 3, 32, 32)
output = model(x)

# 移除钩子
hook_handle.remove()
python 复制代码
import torch

# 创建张量并启用梯度
x = torch.tensor(2.0, requires_grad=True)
y = x**2

# 定义张量钩子
def print_grad(grad):
    print(f"梯度值: {grad}")

# 注册钩子
hook_handle = y.register_hook(print_grad)

# 反向传播
y.backward()

# 移除钩子
hook_handle.remove()

4. Grad-CAM 示例

Grad-CAM (Gradient-weighted Class Activation Mapping) 是一种可视化深度神经网络决策依据的技术。

python 复制代码
import torch
import torch.nn.functional as F
import numpy as np
import cv2
from torchvision import models, transforms
from PIL import Image
import matplotlib.pyplot as plt

# 加载预训练模型
model = models.resnet50(pretrained=True)
target_layer = model.layer4[-1]  # 最后一个卷积层

# 存储特征图和梯度
features = None
grads = None

# 特征钩子
def forward_hook(module, input, output):
    global features
    features = output.detach()

# 梯度钩子
def backward_hook(module, grad_in, grad_out):
    global grads
    grads = grad_out[0].detach()

# 注册钩子
hook_f = target_layer.register_forward_hook(forward_hook)
hook_b = target_layer.register_backward_hook(backward_hook)

# 预处理图像
def preprocess_image(img_path):
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], 
                             std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(img_path).convert('RGB')
    return transform(image).unsqueeze(0), image

# 加载图像
img_path = 'cat_dog.jpg'  # 替换为你的图像路径
input_tensor, orig_img = preprocess_image(img_path)

# 设置模型为评估模式
model.eval()

# 前向传播
output = model(input_tensor)
pred_class = output.argmax()

# 反向传播
model.zero_grad()
one_hot = torch.zeros_like(output)
one_hot[0, pred_class] = 1
output.backward(gradient=one_hot, retain_graph=True)

# 计算权重 (全局平均池化梯度)
weights = torch.mean(grads, dim=(2, 3), keepdim=True)

# 加权组合特征图
cam = torch.sum(weights * features, dim=1).squeeze()
cam = F.relu(cam)  # 应用ReLU去除负值

# 归一化
if torch.max(cam) > 0:
    cam = cam / torch.max(cam)

# 调整CAM尺寸与原图匹配
cam_np = cam.detach().cpu().numpy()
cam_resized = cv2.resize(cam_np, (orig_img.width, orig_img.height))

# 转换为热力图
heatmap = cv2.applyColorMap(np.uint8(255 * cam_resized), cv2.COLORMAP_JET)
heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

# 叠加热力图到原图
superimposed_img = heatmap * 0.4 + np.array(orig_img)
superimposed_img = np.uint8(superimposed_img)

# 显示结果
plt.figure(figsize=(12, 4))
plt.subplot(131)
plt.imshow(orig_img)
plt.title('原始图像')
plt.axis('off')

plt.subplot(132)
plt.imshow(cam_resized, cmap='jet')
plt.title('激活映射')
plt.axis('off')

plt.subplot(133)
plt.imshow(superimposed_img)
plt.title('Grad-CAM结果')
plt.axis('off')

plt.tight_layout()
plt.show()

# 移除钩子
hook_f.remove()
hook_b.remove()

@浙大疏锦行

相关推荐
德育处主任Pro44 分钟前
『React』Fragment的用法及简写形式
前端·javascript·react.js
CodeBlossom1 小时前
javaweb -html -CSS
前端·javascript·html
CodeCraft Studio1 小时前
【案例分享】如何借助JS UI组件库DHTMLX Suite构建高效物联网IIoT平台
javascript·物联网·ui
朝新_1 小时前
【多线程初阶】阻塞队列 & 生产者消费者模型
java·开发语言·javaee
立莹Sir1 小时前
Calendar类日期设置进位问题
java·开发语言
打小就很皮...2 小时前
HBuilder 发行Android(apk包)全流程指南
前端·javascript·微信小程序
风逸hhh2 小时前
python打卡day46@浙大疏锦行
开发语言·python
火兮明兮3 小时前
Python训练第四十三天
开发语言·python
ascarl20103 小时前
准确--k8s cgroup问题排查
java·开发语言
dancing9994 小时前
cocos3.X的oops框架oops-plugin-excel-to-json改进兼容多表单导出功能
前端·javascript·typescript·游戏程序