transformers阅读——Llama模型

学习一下 transformers 库中,Llama 模型的代码,学习过程中写下这篇笔记,一来加深印象,二来可以多次回顾。

笔者小白,里面错误之处请不吝指出。

层归一化 LlamaRMSNorm

transformers 中对于 LlamaRMSNorm 类的定义如下:

python 复制代码
class LlamaRMSNorm(nn.Module):  
    def __init__(self, hidden_size, eps=1e-6):  
    """  
    LlamaRMSNorm is equivalent to T5LayerNorm  
    """  
    super().__init__()  
    self.weight = nn.Parameter(torch.ones(hidden_size))  
    self.variance_epsilon = eps  
  
def forward(self, hidden_states):  
    input_dtype = hidden_states.dtype  
    hidden_states = hidden_states.to(torch.float32)  
    variance = hidden_states.pow(2).mean(-1, keepdim=True)  
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  
    return self.weight * hidden_states.to(input_dtype)

这里采用了 RMS(Root Mean Square) 归一化,其中 RMS 计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R M S ( x ) = 1 n ∑ x i 2 RMS(x)=\sqrt{\frac{1}{n}\sum{x_i^2}} </math>RMS(x)=n1∑xi2

则 RMSNorm 归一化的计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R M S ( x ) = x R M S ( x ) + ϵ ∗ W RMS(x)=\frac{x}{\sqrt{RMS(x)+\epsilon}} * W </math>RMS(x)=RMS(x)+ϵ x∗W

加上一个小常数,确保分母不为零,保持数据稳定性。

旋转位置编码 LlamaRotaryEmbedding

  • 绝对位置编码:计算高效,效果欠佳
  • 相对位置编码:满足 NLP 领域在序列长度方向上具有平移不变性,计算效率低。
  • 旋转位置编码:采用绝对位置编码达到相位置编码的效果

transformers 中对于 LlamaRotaryEmbedding 类的定义如下,它用于实现旋转位置嵌入:

python 复制代码
class LlamaRotaryEmbedding(nn.Module):  
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):  
        super().__init__()  

        self.dim = dim  
        self.max_position_embeddings = max_position_embeddings  
        self.base = base  
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))  
        self.register_buffer("inv_freq", inv_freq, persistent=False)  

        # Build here to make `torch.jit.trace` work.  
        self._set_cos_sin_cache(  
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()  
        )  
  
    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)  

    def forward(self, x, seq_len=None):  
        # x: [bs, num_attention_heads, seq_len, head_size]  
        if seq_len > self.max_seq_len_cached:  
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)  

        return (  
            self.cos_cached[:seq_len].to(dtype=x.dtype),  
            self.sin_cached[:seq_len].to(dtype=x.dtype),  
        )

其中定义的变量意义如下:

  • dim:表示模型输出维度
  • max_position_embeddings:最大编码长度,默认为2048
  • base:基数,默认为10000

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 实现公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i n v _ f r e q = 1 b a s e i / d i m inv\_freq=\frac{1}{base^{i/dim}} </math>inv_freq=basei/dim1

在上面代码中,t 的维度为[max_position_embeddings], inv_freq 的维度为[dim/2]。

经过 torch.einsum("i,j->ij", t, self.inv_freq) 之后维度为 [max_position_embeddings, dim/2]。

然后经过 emb = torch.cat((freqs, freqs), dim=-1) 操作,维度变为 [max_position_embeddings, dim]。

二维情况下旋转矩阵如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R ( k ) = ( c o s k θ − s i n k θ s i n k θ c o s k θ ) R(k)=\begin{pmatrix} cosk\theta & -sink\theta \\ sink\theta & cosk\theta \\ \end{pmatrix} </math>R(k)=(coskθsinkθ−sinkθcoskθ)

旋转位置编码计算公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R ( k ) x = ( c o s k θ 0 c o s k θ 0 c o s k θ 1 c o s k θ 1 . . . c o s k θ d / 2 − 1 c o s k θ d / 2 − 1 ) ∘ ( x 0 x 1 x 2 x 3 . . . x d − 1 x d ) + ( s i n k θ 0 s i n k θ 0 s i n k θ 1 s i n k θ 1 . . . s i n k θ d / 2 − 1 s i n k θ d / 2 − 1 ) ∘ ( − x 1 x 0 − x 3 x 2 . . . − x d x d − 1 ) R(k)x= \begin{pmatrix} cos{k\theta_0} \\ cos{k\theta_0} \\ cos{k\theta_1} \\ cos{k\theta_1} \\ ... \\ cos{k\theta_{d/2-1}} \\ cos{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ ... \\ x_{d-1} \\ x_d \end{pmatrix} + \begin{pmatrix} sin{k\theta_0} \\ sin{k\theta_0} \\ sin{k\theta_1} \\ sin{k\theta_1} \\ ... \\ sin{k\theta_{d/2-1}} \\ sin{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ ... \\ -x_d \\ x_{d-1} \end{pmatrix} </math>R(k)x=⎝ ⎛coskθ0coskθ0coskθ1coskθ1...coskθd/2−1coskθd/2−1⎠ ⎞∘⎝ ⎛x0x1x2x3...xd−1xd⎠ ⎞+⎝ ⎛sinkθ0sinkθ0sinkθ1sinkθ1...sinkθd/2−1sinkθd/2−1⎠ ⎞∘⎝ ⎛−x1x0−x3x2...−xdxd−1⎠ ⎞

在使用 LLM 时,我们希望对上下文长度进行拓展,以便能进行多轮对话,由此有下面几种方法:

外推法:直接沿用当前公式计算计算更长位置的编码。

这种方法比较简单,但是存在相关性衰减问题,如果模型训练语料在 2k 长度左右,模型能够学习到 2k 长度左右的 token 之间相关性关系的规律。

如果直接将此规律沿用到 5k 上下文,可能导致在中间某个位置相关性衰减到零,从而无法捕捉两个 token 之间的相关性。

线性内插:线性内插会改变编码公式,将 token 之间的距离等比例缩小。

例如在 2k 上下文情况下,两个 token 之间距离为 16,那么在 32k 上下文下,这两个 token 之间距离缩短到 1。

对于短距离的衰减规律,可能造成非常大的变化,但是线性内插没有改变模型学习到的衰减规律的应用范围,不考虑微调的话,其效果一般好于直接外推方案。

transformers 中对于线性内插的实现如下:

python 复制代码
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):  
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""  

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):  
        self.scaling_factor = scaling_factor  
        super().__init__(dim, max_position_embeddings, base, device)  

    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  
        t = t / self.scaling_factor  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

可以看到,在 t = t / self.scaling_factor 这行代码中,除以一个缩放因子,从而达到线性缩放的效果。

动态 NTK 扩展:外推法对于长距离的 token 不能很好计算相关性,线性内插对于短距离 token 计算相关性会产生很大变化,因此可以综合两者进行扩展。

为了在短距离情况下具有外推特性,长距离情况下具有内插特性,可以设置一个与位置序号有关频率因子,动态调整。

transformers 中对于动态 NTK 扩展的实现如下:

python 复制代码
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):  
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""  

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):  
        self.scaling_factor = scaling_factor  
        super().__init__(dim, max_position_embeddings, base, device)  

    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  

        if seq_len > self.max_position_embeddings:  
            base = self.base * (  
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)  
            ) ** (self.dim / (self.dim - 2))  
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))  
            self.register_buffer("inv_freq", inv_freq, persistent=False)  

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

可以看到,如果长度超过 max_position_embeddings,对于 base 做出了如下公式操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> b a s e = b a s e ∗ ( f a c t o r ∗ s e q _ l e n m a x _ l e n − ( f a c t o r − 1 ) ) d i m d i m − 2 base=base*(factor*\frac{seq\_len}{max\_len}-(factor-1))^{\frac{dim}{dim-2}} </math>base=base∗(factor∗max_lenseq_len−(factor−1))dim−2dim

如果 seq_len > max_position_embeddings,在 factor = 1 的情况下,base 变大。

显然 base > 1,则 inv_freq 值变小,这样将短距离的规律扩展到了长距离。

具体计算位置编码的代码如下:

python 复制代码
def rotate_half(x):  
    """Rotates half the hidden dims of the input."""  
    x1 = x[..., : x.shape[-1] // 2]  
    x2 = x[..., x.shape[-1] // 2 :]  
    return torch.cat((-x2, x1), dim=-1)  
  
  
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb  
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):  
    cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]  
    sin = sin[position_ids].unsqueeze(1)  
    q_embed = (q * cos) + (rotate_half(q) * sin)  
    k_embed = (k * cos) + (rotate_half(k) * sin)  
    return q_embed, k_embed

rotate_half() 中,将输入 x 沿着最后一维分隔成两部分,然后进行拼接。

Llama 中对 Q 的旋转位置编码按照如下方式计算:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R ( k ) Q = ( c o s k θ 0 c o s k θ 2 . . . c o s k θ d − 2 c o s k θ 0 c o s k θ 2 . . . c o s k θ d − 2 ) ∘ ( q 0 q 1 . . . q d / 2 − 1 q d / 2 q d / 2 + 1 . . . q d − 1 ) + ( s i n k θ 0 s i n k θ 2 . . . s i n k θ d − 2 s i n k θ 0 s i n k θ 2 . . . s i n k θ d ) ∘ ( − q d / 2 − q d / 2 + 1 . . . − q d − 1 q 0 q 1 . . . q d − 1 ) R(k)Q= \begin{pmatrix} cos{k\theta_0} \\ cos{k\theta_2} \\ ... \\ cos{k\theta_{d-2}} \\ cos{k\theta_0} \\ cos{k\theta_2} \\ ... \\ cos{k\theta_{d-2}} \end{pmatrix} \circ \begin{pmatrix} q_0 \\ q_1 \\ ... \\ q_{d/2-1} \\ q_{d/2} \\ q_{d/2+1} \\ ... \\ q_{d-1} \end{pmatrix} + \begin{pmatrix} sin{k\theta_0} \\ sin{k\theta_2} \\ ... \\ sin{k\theta_{d-2}} \\ sin{k\theta_0} \\ sin{k\theta_2} \\ ... \\ sin{k\theta_{d}} \end{pmatrix} \circ \begin{pmatrix} -q_{d/2} \\ -q_{d/2+1} \\ ... \\ -q_{d-1} \\ q_0 \\ q_1 \\ ... \\ q_{d-1} \end{pmatrix} </math>R(k)Q=⎝ ⎛coskθ0coskθ2...coskθd−2coskθ0coskθ2...coskθd−2⎠ ⎞∘⎝ ⎛q0q1...qd/2−1qd/2qd/2+1...qd−1⎠ ⎞+⎝ ⎛sinkθ0sinkθ2...sinkθd−2sinkθ0sinkθ2...sinkθd⎠ ⎞∘⎝ ⎛−qd/2−qd/2+1...−qd−1q0q1...qd−1⎠ ⎞

这里只对 Q 和 K 加入位置编码信息。

前馈网络 LlamaMLP

transformers 中对于前馈网络的定义如下:

python 复制代码
class LlamaMLP(nn.Module):  
    def __init__(self, config):  
        super().__init__()  
        self.config = config  
        self.hidden_size = config.hidden_size  
        self.intermediate_size = config.intermediate_size  
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  
        self.act_fn = ACT2FN[config.hidden_act]  
  
    def forward(self, x):  
        if self.config.pretraining_tp > 1:  
            slice = self.intermediate_size // self.config.pretraining_tp  
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)  
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)  
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)  

            gate_proj = torch.cat(  
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1  
            )  
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)  

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)  
            down_proj = [  
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)  
            ]  
            down_proj = sum(down_proj)  
        else:  
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))  

        return down_proj

__init__() 函数中,定义了 hidden_sizeintermediate_size 控制模型尺寸。

同时定义了三个全连接层:

  • gate_proj:将 hidden_size 投影到 intermediate_size
  • up_proj:将 hidden_size 投影到 intermediate_size
  • down_proj:将 intermediate_size 投影到 hidden_size

这里会将输入通过 up_proj 先从 hidden_size 维度转换到 intermediate_size 维度,然后通过 down_proj 从 intermediate_size 维度转换到 hidden_size 维度。

同时里面采用 gate_proj 配合激活函数,实现了一个门控作用。

forward() 函数中会根据 config.pretraining_tp 选择不同的执行策略。这里是将三个全连接层切分为若干块,分别与输入 x 进行映射操作,得到多个子投影,然后将多个子投影拼接起来。

多头注意力 LlamaAttention

transformers 中对于多头注意力的定义如下:

python 复制代码
class LlamaAttention(nn.Module):  
    """Multi-headed attention from 'Attention Is All You Need' paper"""  

    def __init__(self, config: LlamaConfig):  
        super().__init__()  
        self.config = config  
        self.hidden_size = config.hidden_size  
        self.num_heads = config.num_attention_heads  
        self.head_dim = self.hidden_size // self.num_heads  
        self.num_key_value_heads = config.num_key_value_heads  
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  
        self.max_position_embeddings = config.max_position_embeddings  
        self.rope_theta = config.rope_theta  

        if (self.head_dim * self.num_heads) != self.hidden_size:  
        raise ValueError(  
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"  
            f" and `num_heads`: {self.num_heads})."  
        )  
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)  
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)  
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)  
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)  
        self._init_rope()  

    def _init_rope(self):  
        if self.config.rope_scaling is None:  
            self.rotary_emb = LlamaRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                base=self.rope_theta,  
            )  
        else:  
            scaling_type = self.config.rope_scaling["type"]  
            scaling_factor = self.config.rope_scaling["factor"]  
        if scaling_type == "linear":  
            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                scaling_factor=scaling_factor,  
                base=self.rope_theta,  
            )  
        elif scaling_type == "dynamic":  
            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                scaling_factor=scaling_factor,  
                base=self.rope_theta,  
            )  
        else:  
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

这里主要定义了下面几种属性:

  • hidden_size:隐藏层的大小
  • num_heads:注意力头的数量
  • head_dim:每个注意力头的维度,它通过 hidden_size // num_heads 得到
  • num_key_value_heads:键值注意力头的数量
  • num_key_value_groups:键值注意力头分组数量,它通过 num_heads // num_key_value_heads 得到
  • rope_theta:即前面 RoPE 的 base

此外还定义了四个线性变换的全连接层,分别用于计算查询(Q)、键(K)、值(V)和输出(O)。

注意到键值 注意力头的数量与查询注意力头的数量不同。

键值注意力头数量可以是查询注意力头数量的几分之一,这样可以减少参数规模。

多头注意力的计算代码如下:

python 复制代码
def forward(  
    self,  
    hidden_states: torch.Tensor,  
    attention_mask: Optional[torch.Tensor] = None,  
    position_ids: Optional[torch.LongTensor] = None,  
    past_key_value: Optional[Tuple[torch.Tensor]] = None,  
    output_attentions: bool = False,  
    use_cache: bool = False,  
    padding_mask: Optional[torch.LongTensor] = None,  
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  
    bsz, q_len, _ = hidden_states.size()  

    if self.config.pretraining_tp > 1:  
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp  
        query_slices = self.q_proj.weight.split(  
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0  
        )  
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)  
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)  

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]  
        query_states = torch.cat(query_states, dim=-1)  

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]  
        key_states = torch.cat(key_states, dim=-1)  

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]  
        value_states = torch.cat(value_states, dim=-1)  

    else:  
        query_states = self.q_proj(hidden_states)  
        key_states = self.k_proj(hidden_states)  
        value_states = self.v_proj(hidden_states)  

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)  
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  

        kv_seq_len = key_states.shape[-2]  
    if past_key_value is not None:  
        kv_seq_len += past_key_value[0].shape[-2]  
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)  
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)  

    if past_key_value is not None:  
        # reuse k, v, self_attention  
        key_states = torch.cat([past_key_value[0], key_states], dim=2)  
        value_states = torch.cat([past_key_value[1], value_states], dim=2)  

        past_key_value = (key_states, value_states) if use_cache else None  

        key_states = repeat_kv(key_states, self.num_key_value_groups)  
        value_states = repeat_kv(value_states, self.num_key_value_groups)  

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)  

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):  
        raise ValueError(  
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"  
            f" {attn_weights.size()}"  
        )  

    if attention_mask is not None:  
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):  
            raise ValueError(  
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"  
            )  
        attn_weights = attn_weights + attention_mask  

    # upcast attention to fp32  
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  
    attn_output = torch.matmul(attn_weights, value_states)  

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):  
    raise ValueError(  
    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"  
    f" {attn_output.size()}"  
    )  

    attn_output = attn_output.transpose(1, 2).contiguous()  

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)  

    if self.config.pretraining_tp > 1:  
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)  
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)  
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])  
    else:  
        attn_output = self.o_proj(attn_output)  

    if not output_attentions:  
        attn_weights = None  

    return attn_output, attn_weights, past_key_value

多头注意力基本与《Attention Is All You Need》中一致,计算公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V </math>Attention(Q,K,V)=softmax(dk QKT)V

在 llama 中每进行一次注意力计算,都会对 Q 和 K 计算一次位置编码(RoPE)。

因为 K 和 V 注意力头数是 Q 的几分之一,所以在计算前首先进行 repeat 操作,对应代码如下:

python 复制代码
key_states = repeat_kv(key_states, self.num_key_value_groups)  
value_states = repeat_kv(value_states, self.num_key_value_groups)

计算注意力的代码如下:

python 复制代码
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

attn_weights = attn_weights + attention_mask # 可选操作

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  
attn_output = torch.matmul(attn_weights, value_states)

最终 attn_output 经过 o_proj 的线性变换之后输出。

与前馈网络类似,如果 config 中设置 pretraining_tp,会对输入进行切片后分块操作。

解码层 LlamaDecoderLayer

transfromers 中对解码层的定义如下:

python 复制代码
class LlamaDecoderLayer(nn.Module):  
    def __init__(self, config: LlamaConfig):  
    super().__init__()  
    self.hidden_size = config.hidden_size  
    self.self_attn = (  
        LlamaAttention(config=config)  
        if not getattr(config, "_flash_attn_2_enabled", False)  
        else LlamaFlashAttention2(config=config)  
    )  
    self.mlp = LlamaMLP(config)  
    self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  
    self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

解码层由 AttentionLayerMLP 和两个 LayerNorm 组成。

前向计算代码如下:

python 复制代码
def forward(  
    self,  
    hidden_states: torch.Tensor,  
    attention_mask: Optional[torch.Tensor] = None,  
    position_ids: Optional[torch.LongTensor] = None,  
    past_key_value: Optional[Tuple[torch.Tensor]] = None,  
    output_attentions: Optional[bool] = False,  
    use_cache: Optional[bool] = False,  
    padding_mask: Optional[torch.LongTensor] = None,  
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:  
    """  
    Args:  
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`  
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size  
        `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.  
        output_attentions (`bool`, *optional*):  
        Whether or not to return the attentions tensors of all attention layers. See `attentions` under  
        returned tensors for more detail.  
        use_cache (`bool`, *optional*):  
        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding  
        (see `past_key_values`).  
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states  
    """  

    residual = hidden_states  

    hidden_states = self.input_layernorm(hidden_states)  

    # Self Attention  
    hidden_states, self_attn_weights, present_key_value = self.self_attn(  
        hidden_states=hidden_states,  
        attention_mask=attention_mask,  
        position_ids=position_ids,  
        past_key_value=past_key_value,  
        output_attentions=output_attentions,  
        use_cache=use_cache,  
        padding_mask=padding_mask,  
    )  
    hidden_states = residual + hidden_states  

    # Fully Connected  
    residual = hidden_states  
    hidden_states = self.post_attention_layernorm(hidden_states)  
    hidden_states = self.mlp(hidden_states)  
    hidden_states = residual + hidden_states  

    outputs = (hidden_states,)  

    if output_attentions:  
        outputs += (self_attn_weights,)  

    if use_cache:  
        outputs += (present_key_value,)  

    return outputs

在解码器层中,输入 hidden_states 依次经历如下计算:

  1. 经过 input_layernorm 进行层归一化。
  2. 计算一次自注意力。
  3. 做一次残差连接。
  4. 经过 post_attention_layernorm 进行层归一化。
  5. 经过 mlp,并将结果与步骤3结果做一次残差连接。

模型 LlamaModel

transformers 中对模型定义如下:

python 复制代码
class LlamaModel(LlamaPreTrainedModel):  
    """  
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]  

    Args:  
        config: LlamaConfig  
    """  

    def __init__(self, config: LlamaConfig):  
        super().__init__(config)  
        self.padding_idx = config.pad_token_id  
        self.vocab_size = config.vocab_size  

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)  
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])  
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  

        self.gradient_checkpointing = False  
        # Initialize weights and apply final processing  
        self.post_init()

Llama 模型是由若干个解码层堆叠而成。

在前向传播时设置 gradient_checkpointing=True 可以节约显存。

但是这个参数不能和 use_cache=True 同时设置,这两个参数不兼容。

python 复制代码
if self.gradient_checkpointing and self.training:  
    if use_cache:  
        logger.warning_once(  
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  
        )  
    use_cache = False

在前向传播中自定义了前向传播函数:

python 复制代码
def create_custom_forward(module):  
    def custom_forward(*inputs):  
        # None for past_key_value  
        return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)  
  
    return custom_forward

使用 torch.utils.checkpoint.checkpoint() 函数,它允许将前向传播的一部分分成小块以减小内存占用,并且可以在反向传播时实现显存优化。前提是设置 gradient_checkpointing=True

python 复制代码
layer_outputs = torch.utils.checkpoint.checkpoint(  
    create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids  
)

代码中的 decode_layer 为前文中提到的解码器层。

经过多层解码器层后,将输出经过 RMSNorm 层,得到最终结果。

语言模型 LlamaForCausalLM

transformers 中对语言模型的定义如下:

python 复制代码
class LlamaForCausalLM(LlamaPreTrainedModel):  
    _tied_weights_keys = ["lm_head.weight"]  

    def __init__(self, config):  
        super().__init__(config)  
        self.model = LlamaModel(config)  
        self.vocab_size = config.vocab_size  
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  

        # Initialize weights and apply final processing  
        self.post_init()

实质是在前文提到的 LlamaModel 基础上加入一个 llm_head 来生成结果。

前向传播核心计算代码如下:

python 复制代码
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)  
outputs = self.model(  
    input_ids=input_ids,  
    attention_mask=attention_mask,  
    position_ids=position_ids,  
    past_key_values=past_key_values,  
    inputs_embeds=inputs_embeds,  
    use_cache=use_cache,  
    output_attentions=output_attentions,  
    output_hidden_states=output_hidden_states,  
    return_dict=return_dict,  
)  
  
hidden_states = outputs[0]  
if self.config.pretraining_tp > 1:  
    lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  
    logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  
    logits = torch.cat(logits, dim=-1)  
else:  
    logits = self.lm_head(hidden_states)  
logits = logits.float()

如果输入 labels 会自动计算交叉熵损失。

分类模型 LlamaForSequenceClassification

分类模型也是由 LlamaModel 加上一个 score 的线性层构成。

在计算损失的时候,会根据不同的类型,选择不同的损失函数:

python 复制代码
if self.config.problem_type == "regression":  
    loss_fct = MSELoss()  
    if self.num_labels == 1:  
        loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())  
    else:  
        loss = loss_fct(pooled_logits, labels)  
elif self.config.problem_type == "single_label_classification":  
    loss_fct = CrossEntropyLoss()  
    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))  
elif self.config.problem_type == "multi_label_classification":  
    loss_fct = BCEWithLogitsLoss()  
    loss = loss_fct(pooled_logits, labels)

总结

以 LlamaModel 为例总结数据流向:

  • 输入的如果是 input_ids,会首先计算 inputs_embeds,然后作为 hidden_states,经过若干个 LlamaDecoderLayer、LlamaRMSNorm 组合后输出。
  • 在 LlamaDecoderLayer 中,经历如下步骤:
    1. 先记录原始输入,然后对于输入的 hidden_states 先做一次 LlamaRMSNorm。
    2. 对步骤1的结果做一次 LlamaAttention。
    3. 将步骤2的结果与原始输入做一次残差连接,并记录这次结果。
    4. 将步骤3的结果做一次 LlamaAttention。
    5. 将步骤4的结果做一次 LlamaMLP。
    6. 将步骤5的结果与步骤3的结果做一次残差连接,将结果输出。
  • 在 LlamaAttention 中,经历如下步骤:
    1. 将输入的 hidden_states 做 Q、K、V 变换。
    2. 计算 Q、K 的旋转位置编码。
    3. 根据公式计算自注意力。
    4. 注意力经过线性变换后输出。
  • 在 LlamaMLP 中,经历如下步骤:
    1. 原始输入经过线性变换,得到上投影。
    2. 原始输入经过门函数和激活函数,得到门控投影。
    3. 将步骤1的上投影和步骤2的门控投影对应元素相乘。
    4. 将步骤3的结果经过线性变换,得到下投影,输出这个结果。

学习参考

在学习过程中,参考了下面的文章或视频,感谢各位大佬分享自己的知识:

相关推荐
云卓SKYDROID1 分钟前
除草机器人算法以及技术详解!
算法·机器人·科普·高科技·云卓科技·算法技术
半盏茶香25 分钟前
【C语言】分支和循环详解(下)猜数字游戏
c语言·开发语言·c++·算法·游戏
徐子童29 分钟前
双指针算法习题解答
算法
想要打 Acm 的小周同学呀38 分钟前
LRU缓存算法
java·算法·缓存
劲夫学编程2 小时前
leetcode:杨辉三角
算法·leetcode·职场和发展
毕竟秋山澪2 小时前
孤岛的总面积(Dfs C#
算法·深度优先
浮生如梦_4 小时前
Halcon基于laws纹理特征的SVM分类
图像处理·人工智能·算法·支持向量机·计算机视觉·分类·视觉检测
励志成为嵌入式工程师6 小时前
c语言简单编程练习9
c语言·开发语言·算法·vim
捕鲸叉6 小时前
创建线程时传递参数给线程
开发语言·c++·算法
A charmer6 小时前
【C++】vector 类深度解析:探索动态数组的奥秘
开发语言·c++·算法