PyTorch神经网络打印存储所有权重+激活值(运行时中间值)

很多时候嵌入式或者新硬件需要纯净的权重模型和激活值(运行时中间值),本文提供一种最简洁的方法。

假设已经有模型model和pt文件了,在当前目录下新建weights文件夹,运行这段代码,就可以得到模型的权重(文本形式和二进制形式)

python 复制代码
model.load_state_dict(state_dict)

global_index = 0
for name, param in model.named_parameters():
    print(name, param.size())
    print(param.data.numpy(),file=open(f"weights/{global_index}-{name}.txt", "w"))
    param.data.numpy().tofile(f"weights/{global_index}-{name}.bin")
    global_index += 1

对于二进制形式的文件,可以通过od -t f4 <binary file name> 查看其对应的浮点数值。f4表示fp32.

打印forward的中间值:(这么复杂是必要的)

python3 复制代码
global_index = 0
def hook_fn(module, input, output):
    global global_index
    module_name = str(module)
    module_name=module_name.replace(" ", "")
    module_name=module_name.replace("\n", "")
    # print(name)
    intermediate_outputs = {}
    # input is a tuple, output is a tensor
    for i, inp in enumerate(input):
        intermediate_outputs[f"{global_index}-{module_name}-input-{i}"] = inp
    intermediate_outputs[f"{global_index}-{module_name}-output"] = output
    module_name = module_name[0:200]  # make sure full path <= 255
    print(intermediate_outputs)
    print(f"Size input:",end=" ")
    if(type(input) == tuple):
        for i, inp in enumerate(input):
            if type(inp) == torch.Tensor:
                print(f"{i}-th Size: {inp.size()}", end=", ")
                inp.numpy().tofile(f"activations/{global_index}-{module_name}-input-{i}.bin")
            else:
                print(f"{i}-th : {inp}", end=", ")
    elif type(input) == torch.Tensor:
        print(f"Size: {input.size()}")
        input.numpy().tofile(f"activations/{global_index}-{module_name}-input.bin")
    print(f"Size output: {output.size()}")
    global_index += 1
    output.numpy().tofile(f"activations/{global_index}-{module_name}-output.bin")

def register_hooks(model):
    for name, layer in model.named_children():
        # print(name, layer) # dump all layers, > layers.txt
        # Register the hook to the current layer
        layer.register_forward_hook(hook_fn)
        # Recursively apply the same to all submodules
        register_hooks(layer)

register_hooks(model)

其中regster_hooks和以下等价(不需要recursive了)

python3 复制代码
def register_hooks(model):
    for name, layer in model.named_modules():
        # print(name, layer) # dump all layers
        layer.register_forward_hook(hook_fn)

其中nn.sequential作为一个整体,目前没办法拆开来看其内部的中间值。

相关推荐
春末的南方城市26 分钟前
FLUX的ID保持项目也来了! 字节开源PuLID-FLUX-v0.9.0,开启一致性风格写真新纪元!
人工智能·计算机视觉·stable diffusion·aigc·图像生成
zmjia11128 分钟前
AI大语言模型进阶应用及模型优化、本地化部署、从0-1搭建、智能体构建技术
人工智能·语言模型·自然语言处理
jndingxin42 分钟前
OpenCV视频I/O(14)创建和写入视频文件的类:VideoWriter介绍
人工智能·opencv·音视频
_.Switch1 小时前
Python Web 应用中的 API 网关集成与优化
开发语言·前端·后端·python·架构·log4j
一个闪现必杀技1 小时前
Python入门--函数
开发语言·python·青少年编程·pycharm
AI完全体1 小时前
【AI知识点】偏差-方差权衡(Bias-Variance Tradeoff)
人工智能·深度学习·神经网络·机器学习·过拟合·模型复杂度·偏差-方差
GZ_TOGOGO1 小时前
【2024最新】华为HCIE认证考试流程
大数据·人工智能·网络协议·网络安全·华为
sp_fyf_20241 小时前
计算机前沿技术-人工智能算法-大语言模型-最新研究进展-2024-10-02
人工智能·神经网络·算法·计算机视觉·语言模型·自然语言处理·数据挖掘
新缸中之脑1 小时前
Ollama 运行视觉语言模型LLaVA
人工智能·语言模型·自然语言处理
小鹿( ﹡ˆoˆ﹡ )1 小时前
探索IP协议的神秘面纱:Python中的网络通信
python·tcp/ip·php