Prefix-Tuning Source Code Walkthrough

The source-code implementation of Prefix-Tuning in the PEFT package

Adapted from https://github.com/THUDM/P-tuning-v2/blob/main/model/prefix_encoder.py

```python
import torch
from transformers import PretrainedConfig


class PrefixEncoder(torch.nn.Module):
    r'''
    The torch.nn model to encode the prefix

    Input shape: (batch-size, prefix-length)

    Output shape: (batch-size, prefix-length, 2*layers*hidden)
    '''
    def __init__(self, config):
        super().__init__()
        self.prefix_projection = config.prefix_projection
        if self.prefix_projection:
            # Use a two-layer MLP to encode the prefix
            self.embedding = torch.nn.Embedding(config.prefix_length, config.hidden_size)
            self.trans = torch.nn.Sequential(
                torch.nn.Linear(config.hidden_size, config.encoder_hidden_size),
                torch.nn.Tanh(),
                torch.nn.Linear(config.encoder_hidden_size, config.num_hidden_layers * 2 * config.hidden_size)
            )
        else:
            self.embedding = torch.nn.Embedding(config.prefix_length, config.num_hidden_layers * 2 * config.hidden_size)

    def forward(self, prefix: torch.Tensor):
        if self.prefix_projection:
            prefix_tokens = self.embedding(prefix)
            past_key_values = self.trans(prefix_tokens)
        else:
            past_key_values = self.embedding(prefix)
        return past_key_values
    

if __name__ == "__main__":
    configs = {"prefix_length":20,
               "hidden_size":768,
               "encoder_hidden_size":768,
               "num_hidden_layers":12,
               "prefix_projection":False
               }
    

    prefix_encoder = PrefixEncoder(config=PretrainedConfig.from_dict(configs))
    print(prefix_encoder)

    batch_size = 8
    prefix = torch.arange(20).long().expand(batch_size, -1)
    print(prefix.shape)
    output = prefix_encoder(prefix)
    print(output.shape)
```
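
With prefix_projection=False, the embedding dimension is 12 * 2 * 768 = 18432, so running this demo should print something like:

```text
PrefixEncoder(
  (embedding): Embedding(20, 18432)
)
torch.Size([8, 20])
torch.Size([8, 20, 18432])
```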

Below we walk through the implementation using the T5-large model as an example.
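
For reference, a model in this configuration can be built with a few lines of PEFT (a minimal sketch following the seq2seq prefix-tuning task guide listed in the references; everything beyond num_virtual_tokens=20 is left at its default):

```python
from transformers import AutoModelForSeq2SeqLM
from peft import PrefixTuningConfig, TaskType, get_peft_model

# t5-large: token_dim=1024, num_layers=24, num_attention_heads=16
model = AutoModelForSeq2SeqLM.from_pretrained("t5-large")
peft_config = PrefixTuningConfig(task_type=TaskType.SEQ_2_SEQ_LM, num_virtual_tokens=20)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()
```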

If we ignore the "use a two-layer MLP to encode the prefix" branch, prefix tuning mainly consists of the following code:

```python
class PrefixEncoder(torch.nn.Module):
    def __init__(self, config):
        super().__init__()
        ...
        self.embedding = torch.nn.Embedding(num_virtual_tokens, num_layers * 2 * token_dim)  # num_virtual_tokens=20, token_dim=1024, num_layers=24

    def forward(self, prefix: torch.Tensor):
        past_key_values = self.embedding(prefix)
        return past_key_values
```

The resulting PrefixEncoder is stored as the prompt_encoder in peft -> peft_model.py:

```python
PrefixEncoder(
  (embedding): Embedding(20, 49152) # 1024*2*24
)
```
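
This embedding table is the only component that is trained: 20 * 49152 = 983,040 parameters, about 0.13% of T5-large's weights (the numbers reported in the PEFT seq2seq task guide).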

self.prompt_tokens is initialized as a vector of length 2 * 20, because T5 has both an encoder and a decoder and therefore needs the prefix twice:

```python
self.prompt_tokens[adapter_name] = torch.arange(
    config.num_virtual_tokens * config.num_transformer_submodules
).long()  # 20 * 2 = 40

# tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
#        18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
#        36, 37, 38, 39])
```
```python
prompt_tokens = (
            self.prompt_tokens[self.active_adapter]
            .unsqueeze(0)
            .expand(batch_size, -1)
            .to(prompt_encoder.embedding.weight.device)
        ) 
prompt_tokens = prompt_tokens[:, : peft_config.num_virtual_tokens]
# At this point prompt_tokens.shape = (batch_size=8, num_virtual_tokens=20)

past_key_values = prompt_encoder(prompt_tokens)
# past_key_values.shape: torch.Size([8, 20, 49152])
```

At this point past_key_values still packs all layers together; we need to break it up per layer:

```python
past_key_values = past_key_values.view(
    batch_size,                       # 8
    peft_config.num_virtual_tokens,   # 20
    peft_config.num_layers * 2,       # 24 * 2
    peft_config.num_attention_heads,  # 16
    peft_config.token_dim // peft_config.num_attention_heads,  # 1024 / 16
)
# torch.Size([8, 20, 48, 16, 64])
```
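
The view works because 49152 = (num_layers * 2) * num_attention_heads * head_dim = 48 * 16 * 64: each row of the embedding is simply every layer's key and value vectors laid out flat.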

Because there are both an encoder and a decoder, the tensor is duplicated once more:

```python
past_key_values = torch.cat([past_key_values, past_key_values], dim=2)
# torch.Size([8, 20, 96, 16, 64])

# permute -> torch.Size([96, 8, 16, 20, 64])
# then split into a tuple of length 24; each element has shape torch.Size([4, 8, 16, 20, 64])
past_key_values = past_key_values.permute([2, 0, 3, 1, 4]).split(
    peft_config.num_transformer_submodules * 2
)
```
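
As a standalone sanity check (dummy tensors only, independent of PEFT), the permute/split arithmetic can be verified like this:

```python
import torch

past_key_values = torch.randn(8, 20, 96, 16, 64)
# num_transformer_submodules * 2 = 4
chunks = past_key_values.permute([2, 0, 3, 1, 4]).split(4)
print(len(chunks))      # 24, one entry per layer
print(chunks[0].shape)  # torch.Size([4, 8, 16, 20, 64])
```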

In other words, past_key_values now holds the prefix embeddings for all 24 layers, each of shape `(num_transformer_submodules * 2, batch_size, num_attention_heads, num_virtual_tokens, token_dim / num_attention_heads)`.

Note that the * 2 here is because each layer needs both a key and a value.

Next, in transformers -> models -> t5 -> modeling_t5.py, the T5Attention class: the key step here is `hidden_states = torch.cat([past_key_value, hidden_states], dim=2)` inside the project function. Note that project is only applied to keys and values, never to queries.
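
Stripped of the surrounding machinery, the effect of that concatenation on the key tensor is just the following (a toy example using the shapes from this walkthrough):

```python
import torch

prefix_k = torch.randn(8, 16, 20, 64)  # past_key_value: 20 prefix positions per head
input_k = torch.randn(8, 16, 2, 64)    # keys projected from the 2 input tokens
key_states = torch.cat([prefix_k, input_k], dim=2)
print(key_states.shape)  # torch.Size([8, 16, 22, 64])
```

The relevant part of T5Attention.forward: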

```python
def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            if len(past_key_value) != 2:
                raise ValueError(
                    f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
                )
            real_seq_length += past_key_value[0].shape[2] if query_length is None else query_length

        key_length = real_seq_length if key_value_states is None else key_value_states.shape[1]

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """reshape"""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    # This is the key step: the prefix is concatenated in front of the current keys/values
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                elif past_key_value.shape[2] != key_value_states.shape[1]:
                    # checking that the `sequence_length` of the `past_key_value` is the same as
                    # the provided `key_value_states` to support prefix tuning
                    # cross-attn
                    # (batch_size, n_heads, seq_length, dim_per_head)
                    hidden_states = shape(proj_layer(key_value_states))
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states
```


query_states, key_states, and value_states are then computed. The attention scores come from query and key and have shape torch.Size([8, 16, 2, 22]), so each input token can attend to itself (and the other input tokens) as well as to the prefix.

```python
        # hidden_states shape: torch.Size([8, 2, 1024])
        # get query states
        query_states = shape(self.q(hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
        # query_states shape: torch.Size([8, 16, 2, 64])

        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states, past_key_value[0] if past_key_value is not None else None
        )
        # key_states shape: torch.Size([8, 16, 22, 64])
        value_states = project(
            hidden_states, self.v, key_value_states, past_key_value[1] if past_key_value is not None else None
        )
        # value_states shape: torch.Size([8, 16, 22, 64])
        
        # compute scores
        # torch.Size([8, 16, 2, 22])
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
```
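
The key length of 22 is num_virtual_tokens + seq_length = 20 + 2: each of the 2 query positions scores against the prefix as well as the actual input.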

What follows is the classic attention computation: multiplying attn_weights ([8, 16, 2, 22]) with value_states ([8, 16, 22, 64]) contracts away the 22 dimension, yielding an output for each input position.

```python
        # if key and values are already calculated
        # we want only the last query position bias
        # position_bias.shape: torch.Size([8, 16, 2, 22])
        scores += position_bias_masked

        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(
            scores
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)
        attn_output = unshape(torch.matmul(attn_weights, value_states))  # (batch_size, seq_length, dim) torch.Size([8, 2, 1024])
        attn_output = self.o(attn_output)

        present_key_value_state = (key_states, value_states) if (self.is_decoder and use_cache) else None
        outputs = (attn_output,) + (present_key_value_state,) + (position_bias,)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs
```
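
To make that last step concrete, here is the same contraction on dummy tensors (standalone, not part of the T5 code):

```python
import torch

attn_weights = torch.softmax(torch.randn(8, 16, 2, 22), dim=-1)
value_states = torch.randn(8, 16, 22, 64)
attn_output = torch.matmul(attn_weights, value_states)  # (8, 16, 2, 64)
attn_output = attn_output.transpose(1, 2).reshape(8, 2, 16 * 64)
print(attn_output.shape)  # torch.Size([8, 2, 1024]), one 1024-dim vector per input token
```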

References

https://huggingface.co/docs/peft/task_guides/seq2seq-prefix-tuning
