Pytorch--Hooks For Module

文章目录


1.register_module_forward_pre_hook

在 PyTorch 中,register_module_forward_pre_hook 是一个方法,用于向模型的模块注册前向传播预钩子(forward pre-hook)。预钩子是在模块的前向传播之前被调用的函数,允许在模块接收输入之前对输入进行修改或记录。

c 复制代码
import torch
import torch.nn as nn

# 定义一个前向传播预钩子函数
def forward_pre_hook(module, input):
    print("Forward pre-hook called for module:", module)
    print("Input shape:", input[0].shape)

# 创建一个模型类
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = MyModel()

# 注册前向传播预钩子
model.register_module_forward_pre_hook(forward_pre_hook)

# 输入数据
input_data = torch.randn(1, 10)

# 前向传播
output = model(input_data)
python 复制代码
Forward pre-hook called for module: Linear(in_features=10, out_features=10, bias=True)
Input shape: torch.Size([1, 10])

2.register_module_forward_hook

在 PyTorch 中,register_module_forward_hook 是一个方法,用于向模型的模块注册前向传播钩子(forward hook)。钩子是在模块的前向传播过程中被调用的函数,可以用于获取中间特征、对特征进行修改或记录等操作。

python 复制代码
import torch
import torch.nn as nn

# 定义一个前向传播钩子函数
def forward_hook(module, input, output):
    print("Forward hook called for module:", module)
    print("Input shape:", input[0].shape)
    print("Output shape:", output.shape)

# 创建一个模型类
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = MyModel()

# 注册前向传播钩子
model.register_forward_hook(forward_hook)

# 输入数据
input_data = torch.randn(1, 10)

# 前向传播
output = model(input_data)
python 复制代码
Forward hook called for module: Linear(in_features=10, out_features=10, bias=True)
Input shape: torch.Size([1, 10])
Output shape: torch.Size([1, 10])

3.register_module_backward_hook

在 PyTorch 中,register_module_backward_hook 是一个方法,用于向模型的模块注册反向传播钩子(backward hook)。钩子是在模块的反向传播过程中被调用的函数,可以用于获取梯度、对梯度进行修改或记录等操作。

python 复制代码
import torch
import torch.nn as nn

# 定义一个反向传播钩子函数
def backward_hook(module, grad_input, grad_output):
    print("Backward hook called for module:", module)
    print("Grad input shape:", grad_input[0].shape)
    print("Grad output shape:", grad_output[0].shape)

# 创建一个模型类
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.linear = nn.Linear(10, 10)

    def forward(self, x):
        return self.linear(x)

# 创建模型实例
model = MyModel()

# 注册反向传播钩子
model.register_backward_hook(backward_hook)

# 输入数据
input_data = torch.randn(1, 10)
target = torch.randn(1, 10)

# 前向传播和反向传播
output = model(input_data)
loss = nn.MSELoss()(output, target)
loss.backward()
python 复制代码
Backward hook called for module: Linear(in_features=10, out_features=10, bias=True)
Grad input shape: torch.Size([1, 10])
Grad output shape: torch.Size([1, 10])

相关推荐
文火冰糖的硅基工坊4 分钟前
[嵌入式系统-100]:常见的IoT(物联网)开发板
人工智能·物联网·架构
刘晓倩32 分钟前
实战任务二:用扣子空间通过任务提示词制作精美PPT
人工智能
shut up35 分钟前
LangChain - 如何使用阿里云百炼平台的Qwen-plus模型构建一个桌面文件查询AI助手 - 超详细
人工智能·python·langchain·智能体
Hy行者勇哥37 分钟前
公司全场景运营中 PPT 的类型、功能与作用详解
大数据·人工智能
宝贝儿好1 小时前
【python】第五章:python-GUI编程
python·pyqt
FIN66681 小时前
昂瑞微:实现精准突破,攻坚射频“卡脖子”难题
前端·人工智能·安全·前端框架·信息与通信
FIN66681 小时前
昂瑞微冲刺科创板:硬科技与资本市场的双向奔赴
前端·人工智能·科技·前端框架·智能
m0_677034351 小时前
机器学习-推荐系统(下)
人工智能·机器学习
XIAO·宝1 小时前
深度学习------专题《神经网络完成手写数字识别》
人工智能·深度学习·神经网络
流年染指悲伤、1 小时前
2024年最新技术趋势分析:AI、前端与后端开发新动向
人工智能·前端开发·后端开发·2024·技术趋势