LLM-pruner源码解析

1.超参数

模型剪枝的超参数

模型

模型检查点和日志的保存地址

剪枝比例,这里默认0.5

剪枝类型,这里模型L2

模型生成时的超参数

温度

top_p

最大序列长度

逐通道,逐块,逐层,这个逐层我不记得在论文里面提过啊

layer:保留前n层

注意力模块和线性层的开始层和结束层

剪枝的迭代次数

组的计算策略:这里采用的求和

是否使用全局剪枝

泰勒:包括在向量维度上的,在元素维度上的一阶、二阶(混合参数不知道指的是啥)

向量维度:

更细的元素维度:

这几个生成参数也没啥好说的,加载设备

确定torch的版本

官方给的bloom的配置是7B1的模型

我这里用的3B后面还要根据分析的结果改一下

3B有30层,

bloom7b有24层 这个4,20代表从0开始的排序是20还是从1开始的排序是20呀

按照配置获取的超参数

2.程序运行逻辑

第一步先固定下随机种子

设置日志

获得tokenizer

获得模型,这个类是llm-pruner自己写的,

有个问题:为啥要自己重新写一个加载类呢

自定义加载类

这段代码可以不用看直接看对q,k,v重新排序,这里自己写的加载类和transformers自带的没有区别,猜测应该是大佬防止模块名不一致,自己又重新写了一遍

下面套的类比较多这里为了区别,运行到哪一类提前说一下是继承顺序

**BloomForCausallm:**BloomForCausallm类先继承BloomPreTrainedModel类

复制代码
from transformers.models.bloom.configuration_bloom import BloomConfig

**BloomForCausallm继承BloomPreTrainedModel:**BloomPreTrainedModel类继承PreTrainedModel类

复制代码
from transformers.modeling_utils import PreTrainedModel
from transformers.models.bloom.configuration_bloom import BloomConfig

**BloomForCausallm调用BloomModel:**实例化BloomModel类,获取模型结构并初始化权重

BloomForCausallm调用BloomModel,BloomModel继承BloomPreTrainedModel: 这个BloomModel类也是继承BloomPreTrainedModel类

BloomForCausallm调用BloomModel,BloomModel调用BloomBlock **:**BloomModel类调用BloomBlock类

BloomForCausallm调用BloomModel,BloomModel调用BloomAttention **:**BloomAttention

**BloomForCausallm调用BloomModel,BloomModel调用BloomMLP:**BloomMLP

**BloomForCausallm调用BloomModel,BloomModel调用BloomGelu:**BloomGelu

**BloomForCausallm调用BloomModel,BloomModel继承PreTrainedModel中的self.post_init()实例方法:**BloomModel中的post_init,这个方法是PreTrainedModel中的方法

BloomForCausallm调用BloomModel,BloomModel继承PreTrainedModel中的self.post_init()实例方法,self.post_init()调用PreTrainedModel中的self.init_weight()实例方法和self._back_compatibity_gradient_ckeckpointing()实例方法

init_weights:如果需要,修剪并可能初始化权重。如果使用自定义的 `PreTrainedModel`,你需要在 `_init_weights` 中实现任何初始化逻辑。

**BloomForCausallm:**对于 self.transformer,post_init方法中的判定都是否,所以还是保持原来的参数不变,并没有对self.transformer包含的模块进行重新初始化

**BloomForCausallm:**self.lm_head

**BloomForCausallm:**self.post_init()判断都是否,没有对网络进行重新初始化

BloomForCausallm继承BloomPreTrainedModel: 我还是没明白什么时候调用的_init_weight方法,只对self.lm_head进行标准化

q,k,v重排

从这里开始需要看,将q,k,v的顺序进行重排

torch.view 函数的追踪在某些情况下比较复杂,因此,查询、键、值的索引映射有时会遇到问题。为了避免这些问题,函数通过分离查询、键、值的方式来重新组织权重和偏置。


将模型转化为fp16

复制代码
from LLMPruner.templates.prompts import prompts

model.generate也是自己写的,先不看这里直接看pruner

pruner

选择pruner类型,将模块参数全都转化为要求梯度,组依赖关系的计算方式选择求和,给定提示prompt

选择taylor方式,一阶,求和,实例化TaylorImportance类

复制代码
from LLMPruner.pruner import hf_llama_pruner as Pruner

Pruner这个类是自己定义的,如果模型不一样对应的类也不一样,具体怎么根据自己的模型改还得继续向下看

在baichuan中

组依赖关系求和,不进行归一化,一阶泰勒

import LLMPruner.torch_pruning as tp

逐模块计算

获取参数

从开始层到结束层 将q,k,v的pruner比例平分

"ch_sparsity_dict": { model.transformer.h[i].self_attention.query_key_value: args.pruning_ratio / 3 for i in range(args.block_attention_layer_start, args.block_attention_layer_end) },

复制代码
"root_instances": [model.transformer.h[i].mlp.dense_h_to_4h for i in range(args.block_mlp_layer_start, args.block_mlp_layer_end)] +
                  [model.transformer.h[i].self_attention.query_key_value for i in range(args.block_attention_layer_start, args.block_attention_layer_end)],

开始pruner

复制代码
import LLMPruner.torch_pruning as tp 

获取MetaPruner类的实例属性

对DependencyGraph类实例化

复制代码
from ... import ops, dependency
复制代码
from . import _helpers, utils, ops

这几个函数都是调用的ops中的类,CUSTOMIZED定制为None

已经注册的编辑器,更新编辑器。定制pruner,忽略的层

from .pruner import function

刚开始的编辑器

更新之后的编辑器

按照编辑器中的值获取pruner输入通道的函数,定制里面是空值

按照编辑器中的值获取pruner输出通道的函数,定制里面是空值

每个pruner类型获取pruner输入/出通道的函数,这里只是举例子

_op_id为0,记录pruner的历史,现在为空

调用build_dependency方法,输入参数包括模型,前面设置的提示词,刚才参数中的前向函数model(example_inputs),剩下的三个output_transform,unwrapped_parameters,customized_pruners都是None或空值

build_dependency 的实例属性包括:获取模型,命名的模块

customized_pruners,CUSTOMIZED_PRUNERS之前已经知道了两个字典都是空

self._module2name有命名的模块有337个

检测没有包装的参数

新建已包装模块的列表,获取注册表中的pruner

获得每个模块的类型,如果操作类型在pruner模块类型中且不是逐元素(后面不用看后面是空),把参数加入到已包装参数的类中

已包装的参数有336个

新建unwrapped_detected和_param_to_name

遍历看看哪个是未包装的,加入unwrapped_detected

最后的运行结果,没有不被包装的,同样_param_to_name里面也是空的

从 unwrapped_detected 列表中移除所有出现在 unwrapped_parameters 列表中的元素,最终的结果存回 unwrapped_detected 变量,这个unwrapped_parameters是类属性,为空列表

如果 unwrapped_detected不为空的话相关的处理手段,对 unwrapped_detected 列表中的每个元素进行处理,找出其 最后一个大于1的维度,并将该元素与对应的维度信息存入 unwrapped_parameters 列表。最后,将 unwrapped_parameters 保存到 self.unwrapped_parameters

开始追踪计算图了,输入有模型,提示输入,前向,output_transform为None

model.eval()

初始化gradfn2module和visited

获得之前pruner编辑器中的内容

还是之前的那18个

如果模型模块不在忽略的层中且在注册表内,在模型的每一层上注册一个前向钩子(forward_hook),这些钩子将在模型进行前向传播时触发并调用 _record_grad_fn 函数

visited 是一个字典,键是模块对象,值是该模块被调用的次数。每当某个模块被前向传播执行时,visited[module] 增加 1

当前模块是否是 nn.Linear 层,并且该层的输出张量维度是否为 3(例如,(batch_size, seq_len, hidden_dim))。如果条件成立,设置 self._2d_4d = False,表示模型输出的维度不再是 2D 或 4D,而是 3D。

有些层(如 LSTM, GRU)的输出是一个元组,通常包括输出张量和一些附加信息(如隐藏状态)。这个条件会检查输出是否是一个元组,如果是,则提取元组的第一个元素作为最终的输出

PackedSequence 是 PyTorch 中用于表示 RNN 变长序列的输出格式。此条件用于检查 outputs 是否是一个 PackedSequence 对象,如果是,则将其 .data 提取出来。.data 是实际的张量数据

outputs.grad_fn 是 outputs 张量的梯度计算函数(grad_fn),它记录了张量如何计算出来。这里的 gradfn2module 是一个字典,将梯度计算函数 (grad_fn) 映射到对应的模块 module

之前自定义了前向函数 forward_fn,同时会调用之前注册的hook函数

前向完成后移除掉hook函数

在前向的过程中填充记录模块调用次数的visited 字典和记录计算梯度函数的gradfn2module列表

针对递归模型或层,找到被调用多次的模块记录到reused列表中

这里没有被条用多次的,是空的列表

这里output_transform是None,如果有的话对模型的输出结果进行转换

from . import _helpers, utils, ops

对于utils.flatten_as_list()

如果 obj 是一个张量,则将其包装成一个列表并返回

检查 obj 是否是一个列表(list)或元组(tuple)。如果是,创建一个空的 flattened_list 用于存储展开后的元素。然后,递归地调用 flatten_as_list 来展开列表或元组中的每个元素(sub_obj)。使用 extend 方法将每次递归得到的展开结果合并到 flattened_list 中。最终返回这个完全展开的列表

检查 obj 是否是一个字典(dict)。如果是,创建一个空的 flattened_list 用于存储展开后的元素。

然后,递归地调用 flatten_as_list 来展开字典中每个键对应的值(sub_obj)。使用 extend 将每次递归得到的展开结果合并到 flattened_list 中。最终返回这个完全展开的列表

如果 obj 既不是 torch.Tensor、列表、元组或字典,那么直接返回 obj 本身。这部分适用于基本类型(如整数、浮点数、字符串等)。

调用是实力属性_trace_computational_graph追踪计算图

输入包括module2node字典,模块梯度计算函数,记录模块计算函数的gradfn2module字典,和纪律被多次调用的模块的字典

非递归计算图构建

processing_stack: 用于存储待处理的梯度函数(grad_fn)节点,类似栈(stack)的数据结构。

visited: 用于跟踪已经处理过的 grad_fn,避免重复处理。

visited_as_output_node: 用于追踪作为输出节点的计算图节点。

在每次循环中,弹出栈顶的梯度函数(grad_fn)并开始处理。

如果当前 grad_fn 已经处理过,则跳过(防止重复计算)

调用create_node_if_not_exists

如果 module 已经存在,并且该 module 已经在 module2node 字典中关联了一个节点(即已存在对应的计算图节点),并且该 module 不在 reused 中(表示该节点没有被标记为"已重用"),那么直接返回现有的节点 module2node[module

如果 module 为空(表示这是一个新模块,之前没有创建过),则会根据 grad_fn 创建一个新的模块并与其关联

如果 grad_fn 没有 name 属性(说明它是一个不常见的或自定义的操作),则将其视为一个 逐元素操作(如加法、减法等),并使用 ops._ElementWiseOp 创建一个新的操作模块 module,并给这个模块分配一个唯一的 op_id。self._op_id 会在每次创建模块后自增。

如果 verbose 为 True,则发出警告,提示遇到了一个未知操作,默认将其视为逐元素操作。

如果 grad_fn.name() 包含特定的字符串(如 "catbackward"、"split"、"view" 等),则根据操作类型创建对应的模块(例如 ops._ConcatOp 表示拼接操作,ops._SplitOp 表示拆分操作,ops._ReshapeOp 表示形状变化操作等)。

如果没有匹配到特定类型的操作,则默认将其视为 逐元素操作。

创建好模块后,会将 grad_fn 与新创建的模块存储到 gradfn2module 字典中,以便以后查找。

如果 module 还没有在 module2node 字典中找到对应的节点,则创建一个新的节点 Node 对象。

该节点包含以下信息:

module: 关联的操作模块

grad_fn: 关联的梯度计算函数

name: 从 _module2name 字典中获取模块的名称,如果没有,则为 None。

如果该模块是自定义的修剪器(CUSTOMIZED_PRUNERS),则将节点类型设置为 CUSTOMIZED。

将新节点添加到 module2node 字典中,以便后续访问。

如果 module 已经有对应的节点,则直接使用已存在的节点

hasattr() 是 Python 内置的一个函数,用来检查一个对象是否具有指定的属性。

检查当前的 grad_fn(计算图中的节点)是否有 next_functions 属性。

grad_fn.next_functions 是一个可迭代对象,每个元素表示当前梯度函数(操作)依赖的输入(上游节点)。遍历 next_functions 列表中的每个元素,来处理每个输入。

如果 f[0] 为 None,表示该输入没有有效的梯度函数,因此跳过这个输入

这行代码检查 f[0](即当前输入的 grad_fn)是否有 name 属性,并且其名称是否包含 "accumulategrad"(表示该输入是一个叶子变量)。这种叶子变量通常对应于模型参数(如权重或偏置),它们不是由其他操作计算得到的,而是计算图中的输入

如果 f[0] 是叶子变量,进一步检查它是否属于未包装的参数(即 unwrapped_parameters)。

如果找到了匹配的参数,gradfn2module[f[0]] = p 将 grad_fn 映射到该参数(p)。同时,使用 self._module2name 为该参数生成一个名称 "UnwrappedParameter_j (shape)",并将其赋值为 grad_fn 的名称。

如果没有找到匹配的参数,跳过当前输入

调用 create_node_if_not_exists(f[0]) 为输入 f[0] 创建一个节点

node.add_input(input_node, allow_dumplicated=False) 将当前的 input_node作为输入添加到 node中。allow_dumplicated=False 表示不允许重复连接相同的输入。

input_node.add_output(node, allow_dumplicated=False) 将 ndoe\作为输出添加到 input_node中。

f[0] 被添加到 processing_stack 中,表示该输入已经被处理

visited.add(grad_fn) 将当前的 grad_fn 标记为已访问,表示该节点已经被处理过。

visited_as_output_node.add(node) 将当前的 node 标记为已访问的输出节点,防止后续重复处理

对于没有包装的节点

最后返回模块和节点之间的关系

打个节点,下次再看

相关推荐
文心快码 Baidu Comate3 分钟前
吉利汽车x文心快码:AI最佳实践案例
人工智能·汽车·编程·ai编程·文心快码·智能编程助手
一尘之中25 分钟前
如何做好一份技术文档?
人工智能·学习
掘金安东尼40 分钟前
RAG BM25 算法和重排,微调以外的手段
人工智能·llm
Matlab仿真实验室1 小时前
基于Matlab实现车牌识别系统(源码+图像)
开发语言·网络·人工智能·算法·计算机视觉·matlab·车牌识别系统
Mr.谢尔比2 小时前
李宏毅机器学习课程知识点摘要(6-13集)
人工智能·pytorch·深度学习·神经网络·机器学习·计算机视觉
Elastic 中国社区官方博客2 小时前
使用 Jina Embeddings v2 在 Elasticsearch 中进行后期分块
大数据·人工智能·elasticsearch·搜索引擎·ai·全文检索·jina
深度学习lover2 小时前
<项目代码>YOLOv8 停车场空位识别<目标检测>
人工智能·python·yolo·目标检测·计算机视觉·停车场空位识别
Eric.Lee20212 小时前
数据集-目标检测系列- 安全背心 检测数据集 safety_vests >> DataBall
人工智能·python·yolo·目标检测·计算机视觉·安全背心检测
sp_fyf_20242 小时前
【大语言模型】ACL2024论文-16 基于地图制图的罗马尼亚自然语言推理语料库的新型课程学习方法
人工智能·深度学习·机器学习·语言模型·数据挖掘·学习方法
W Y2 小时前
【智能制造-46】人机工程(工厂自动化)
人工智能·自动化·制造·人机工程学·人机工程