transformers 阅读:Llama 模型

正文

学习一下 transformers 库中,Llama 模型的代码,学习过程中写下这篇笔记,一来加深印象,二来可以多次回顾。

笔者小白,里面错误之处请不吝指出。

层归一化 LlamaRMSNorm

transformers 中对于 LlamaRMSNorm 类的定义如下:

class LlamaRMSNorm(nn.Module):  
    def __init__(self, hidden_size, eps=1e-6):  
    """  
    LlamaRMSNorm is equivalent to T5LayerNorm  
    """  
    super().__init__()  
    self.weight = nn.Parameter(torch.ones(hidden_size))  
    self.variance_epsilon = eps  
  
def forward(self, hidden_states):  
    input_dtype = hidden_states.dtype  
    hidden_states = hidden_states.to(torch.float32)  
    variance = hidden_states.pow(2).mean(-1, keepdim=True)  
    hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)  
    return self.weight * hidden_states.to(input_dtype)

这里采用了 RMS(Root Mean Square) 归一化,其中 RMS 计算公式为:

RMS(x)=1n∑xi2RMS(x)=\sqrt{\frac{1}{n}\sum{x_i^2}}RMS(x)=n1​∑xi2​​

则 RMSNorm 归一化的计算公式为:

RMS(x)=xRMS(x)+ϵ∗WRMS(x)=\frac{x}{\sqrt{RMS(x)+\epsilon}} * WRMS(x)=RMS(x)+ϵ​x​∗W

加上一个小常数,确保分母不为零,保持数据稳定性。

旋转位置编码 LlamaRotaryEmbedding

  • 绝对位置编码:计算高效,效果欠佳
  • 相对位置编码:满足 NLP 领域在序列长度方向上具有平移不变性,计算效率低。
  • 旋转位置编码:采用绝对位置编码达到相位置编码的效果

transformers 中对于 LlamaRotaryEmbedding 类的定义如下,它用于实现旋转位置嵌入:

class LlamaRotaryEmbedding(nn.Module):  
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):  
        super().__init__()  

        self.dim = dim  
        self.max_position_embeddings = max_position_embeddings  
        self.base = base  
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))  
        self.register_buffer("inv_freq", inv_freq, persistent=False)  

        # Build here to make `torch.jit.trace` work.  
        self._set_cos_sin_cache(  
            seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()  
        )  
  
    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)  

    def forward(self, x, seq_len=None):  
        # x: [bs, num_attention_heads, seq_len, head_size]  
        if seq_len > self.max_seq_len_cached:  
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)  

        return (  
            self.cos_cached[:seq_len].to(dtype=x.dtype),  
            self.sin_cached[:seq_len].to(dtype=x.dtype),  
        )

其中定义的变量意义如下:

  • dim:表示模型输出维度
  • max_position_embeddings:最大编码长度,默认为2048
  • base:基数,默认为10000

inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim)) 实现公式为:

inv_freq=1base2i/diminv\_freq=\frac{1}{base^{2i/dim}}inv_freq=base2i/dim1​

在上面代码中,t 的维度为[max_position_embeddings], inv_freq 的维度为[dim/2]。

经过 torch.einsum("i,j->ij", t, self.inv_freq) 之后维度为 [max_position_embeddings, dim/2]。

然后经过 emb = torch.cat((freqs, freqs), dim=-1) 操作,维度变为 [max_position_embeddings, dim]。

二维情况下旋转矩阵如下:

R(k)=(coskθ−sinkθsinkθcoskθ)R(k)=\begin{pmatrix} cosk\theta & -sink\theta \\ sink\theta & cosk\theta \\ \end{pmatrix}R(k)=(coskθsinkθ​−sinkθcoskθ​)

旋转位置编码计算公式如下:

R(k)x=(coskθ0coskθ0coskθ1coskθ1...coskθd/2−1coskθd/2−1)∘(x0x1x2x3...xd−2xd−1)+(sinkθ0sinkθ0sinkθ1sinkθ1...sinkθd/2−1sinkθd/2−1)∘(−x1x0−x3x2...−xd−1xd−2)R(k)x= \begin{pmatrix} cos{k\theta_0} \\ cos{k\theta_0} \\ cos{k\theta_1} \\ cos{k\theta_1} \\ ... \\ cos{k\theta_{d/2-1}} \\ cos{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ ... \\ x_{d-2} \\ x_{d-1} \end{pmatrix} + \begin{pmatrix} sin{k\theta_0} \\ sin{k\theta_0} \\ sin{k\theta_1} \\ sin{k\theta_1} \\ ... \\ sin{k\theta_{d/2-1}} \\ sin{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ ... \\ -x_{d-1} \\ x_{d-2} \end{pmatrix} R(k)x=⎝⎛​coskθ0​coskθ0​coskθ1​coskθ1​...coskθd/2−1​coskθd/2−1​​⎠⎞​∘⎝⎛​x0​x1​x2​x3​...xd−2​xd−1​​⎠⎞​+⎝⎛​sinkθ0​sinkθ0​sinkθ1​sinkθ1​...sinkθd/2−1​sinkθd/2−1​​⎠⎞​∘⎝⎛​−x1​x0​−x3​x2​...−xd−1​xd−2​​⎠⎞​

在使用 LLM 时,我们希望对上下文长度进行拓展,以便能进行多轮对话,由此有下面几种方法:

外推法:直接沿用当前公式计算计算更长位置的编码。

这种方法比较简单,但是存在相关性衰减问题,如果模型训练语料在 2k 长度左右,模型能够学习到 2k 长度左右的 token 之间相关性关系的规律。

如果直接将此规律沿用到 5k 上下文,可能导致在中间某个位置相关性衰减到零,从而无法捕捉两个 token 之间的相关性。

线性内插:线性内插会改变编码公式,将 token 之间的距离等比例缩小。

例如在 2k 上下文情况下,两个 token 之间距离为 16,那么在 32k 上下文下,这两个 token 之间距离缩短到 1。

对于短距离的衰减规律,可能造成非常大的变化,但是线性内插没有改变模型学习到的衰减规律的应用范围,不考虑微调的话,其效果一般好于直接外推方案。

transformers 中对于线性内插的实现如下:

class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):  
    """LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""  

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):  
        self.scaling_factor = scaling_factor  
        super().__init__(dim, max_position_embeddings, base, device)  

    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  
        t = t / self.scaling_factor  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

可以看到,在 t = t / self.scaling_factor 这行代码中,除以一个缩放因子,从而达到线性缩放的效果。

动态 NTK 扩展:外推法对于长距离的 token 不能很好计算相关性,线性内插对于短距离 token 计算相关性会产生很大变化,因此可以综合两者进行扩展。

为了在短距离情况下具有外推特性,长距离情况下具有内插特性,可以设置一个与位置序号有关频率因子,动态调整。

transformers 中对于动态 NTK 扩展的实现如下:

class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):  
    """LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""  

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):  
        self.scaling_factor = scaling_factor  
        super().__init__(dim, max_position_embeddings, base, device)  

    def _set_cos_sin_cache(self, seq_len, device, dtype):  
        self.max_seq_len_cached = seq_len  

        if seq_len > self.max_position_embeddings:  
            base = self.base * (  
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)  
            ) ** (self.dim / (self.dim - 2))  
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))  
            self.register_buffer("inv_freq", inv_freq, persistent=False)  

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)  

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)  
        # Different from paper, but it uses a different permutation in order to obtain the same calculation  
        emb = torch.cat((freqs, freqs), dim=-1)  
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)  
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

可以看到,如果长度超过 max_position_embeddings,对于 base 做出了如下公式操作:

base=base∗(factor∗seq_lenmax_len−(factor−1))dimdim−2base=base*(factor*\frac{seq\_len}{max\_len}-(factor-1))^{\frac{dim}{dim-2}}base=base∗(factor∗max_lenseq_len​−(factor−1))dim−2dim​

如果 seq_len > max_position_embeddings,在 factor = 1 的情况下,base 变大。

显然 base > 1,则 inv_freq 值变小,这样将短距离的规律扩展到了长距离。

具体计算位置编码的代码如下:

def rotate_half(x):  
    """Rotates half the hidden dims of the input."""  
    x1 = x[..., : x.shape[-1] // 2]  
    x2 = x[..., x.shape[-1] // 2 :]  
    return torch.cat((-x2, x1), dim=-1)  
  
  
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb  
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):  
    cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]  
    sin = sin[position_ids].unsqueeze(1)  
    q_embed = (q * cos) + (rotate_half(q) * sin)  
    k_embed = (k * cos) + (rotate_half(k) * sin)  
    return q_embed, k_embed

rotate_half() 中,将输入 x 沿着最后一维分隔成两部分,然后进行拼接。

Llama 中对 Q 的旋转位置编码按照如下方式计算:

R(k)Q=(coskθ0coskθ1...coskθd/2−1coskθ0coskθ1...coskθd/2−1)∘(q0q1...qd/2−1qd/2qd/2+1...qd−1)+(sinkθ0sinkθ1...sinkθd/2−1sinkθ0sinkθ1...sinkθd/2−1)∘(−qd/2−qd/2+1...−qd−1q0q1...qd−1)R(k)Q= \begin{pmatrix} cos{k\theta_0} \\ cos{k\theta_1} \\ ... \\ cos{k\theta_{d/2-1}} \\ cos{k\theta_0} \\ cos{k\theta_1} \\ ... \\ cos{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} q_0 \\ q_1 \\ ... \\ q_{d/2-1} \\ q_{d/2} \\ q_{d/2+1} \\ ... \\ q_{d-1} \end{pmatrix} + \begin{pmatrix} sin{k\theta_0} \\ sin{k\theta_1} \\ ... \\ sin{k\theta_{d/2-1}} \\ sin{k\theta_0} \\ sin{k\theta_1} \\ ... \\ sin{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} -q_{d/2} \\ -q_{d/2+1} \\ ... \\ -q_{d-1} \\ q_0 \\ q_1 \\ ... \\ q_{d-1} \end{pmatrix} R(k)Q=⎝⎛​coskθ0​coskθ1​...coskθd/2−1​coskθ0​coskθ1​...coskθd/2−1​​⎠⎞​∘⎝⎛​q0​q1​...qd/2−1​qd/2​qd/2+1​...qd−1​​⎠⎞​+⎝⎛​sinkθ0​sinkθ1​...sinkθd/2−1​sinkθ0​sinkθ1​...sinkθd/2−1​​⎠⎞​∘⎝⎛​−qd/2​−qd/2+1​...−qd−1​q0​q1​...qd−1​​⎠⎞​

这里只对 Q 和 K 加入位置编码信息。

前馈网络 LlamaMLP

transformers 中对于前馈网络的定义如下:

class LlamaMLP(nn.Module):  
    def __init__(self, config):  
        super().__init__()  
        self.config = config  
        self.hidden_size = config.hidden_size  
        self.intermediate_size = config.intermediate_size  
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)  
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)  
        self.act_fn = ACT2FN[config.hidden_act]  
  
    def forward(self, x):  
        if self.config.pretraining_tp > 1:  
            slice = self.intermediate_size // self.config.pretraining_tp  
            gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)  
            up_proj_slices = self.up_proj.weight.split(slice, dim=0)  
            down_proj_slices = self.down_proj.weight.split(slice, dim=1)  

            gate_proj = torch.cat(  
                [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1  
            )  
            up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)  

            intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)  
            down_proj = [  
                F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)  
            ]  
            down_proj = sum(down_proj)  
        else:  
            down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))  

        return down_proj

__init__() 函数中,定义了 hidden_sizeintermediate_size 控制模型尺寸。

同时定义了三个全连接层:

  • gate_proj:将 hidden_size 投影到 intermediate_size
  • up_proj:将 hidden_size 投影到 intermediate_size
  • down_proj:将 intermediate_size 投影到 hidden_size

这里会将输入通过 up_proj 先从 hidden_size 维度转换到 intermediate_size 维度,然后通过 down_proj 从 intermediate_size 维度转换到 hidden_size 维度。

同时里面采用 gate_proj 配合激活函数,实现了一个门控作用。

forward() 函数中会根据 config.pretraining_tp 选择不同的执行策略。这里是将三个全连接层切分为若干块,分别与输入 x 进行映射操作,得到多个子投影,然后将多个子投影拼接起来。

多头注意力 LlamaAttention

transformers 中对于多头注意力的定义如下:

class LlamaAttention(nn.Module):  
    """Multi-headed attention from 'Attention Is All You Need' paper"""  

    def __init__(self, config: LlamaConfig):  
        super().__init__()  
        self.config = config  
        self.hidden_size = config.hidden_size  
        self.num_heads = config.num_attention_heads  
        self.head_dim = self.hidden_size // self.num_heads  
        self.num_key_value_heads = config.num_key_value_heads  
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads  
        self.max_position_embeddings = config.max_position_embeddings  
        self.rope_theta = config.rope_theta  

        if (self.head_dim * self.num_heads) != self.hidden_size:  
        raise ValueError(  
            f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"  
            f" and `num_heads`: {self.num_heads})."  
        )  
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)  
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)  
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)  
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)  
        self._init_rope()  

    def _init_rope(self):  
        if self.config.rope_scaling is None:  
            self.rotary_emb = LlamaRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                base=self.rope_theta,  
            )  
        else:  
            scaling_type = self.config.rope_scaling["type"]  
            scaling_factor = self.config.rope_scaling["factor"]  
        if scaling_type == "linear":  
            self.rotary_emb = LlamaLinearScalingRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                scaling_factor=scaling_factor,  
                base=self.rope_theta,  
            )  
        elif scaling_type == "dynamic":  
            self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(  
                self.head_dim,  
                max_position_embeddings=self.max_position_embeddings,  
                scaling_factor=scaling_factor,  
                base=self.rope_theta,  
            )  
        else:  
            raise ValueError(f"Unknown RoPE scaling type {scaling_type}")

这里主要定义了下面几种属性:

  • hidden_size:隐藏层的大小
  • num_heads:注意力头的数量
  • head_dim:每个注意力头的维度,它通过 hidden_size // num_heads 得到
  • num_key_value_heads:键值注意力头的数量
  • num_key_value_groups:键值注意力头分组数量,它通过 num_heads // num_key_value_heads 得到
  • rope_theta:即前面 RoPE 的 base

此外还定义了四个线性变换的全连接层,分别用于计算查询(Q)、键(K)、值(V)和输出(O)。

注意到键值 注意力头的数量与查询注意力头的数量不同。

键值注意力头数量可以是查询注意力头数量的几分之一,这样可以减少参数规模。

多头注意力的计算代码如下:

def forward(  
    self,  
    hidden_states: torch.Tensor,  
    attention_mask: Optional[torch.Tensor] = None,  
    position_ids: Optional[torch.LongTensor] = None,  
    past_key_value: Optional[Tuple[torch.Tensor]] = None,  
    output_attentions: bool = False,  
    use_cache: bool = False,  
    padding_mask: Optional[torch.LongTensor] = None,  
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:  
    bsz, q_len, _ = hidden_states.size()  

    if self.config.pretraining_tp > 1:  
        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp  
        query_slices = self.q_proj.weight.split(  
            (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0  
        )  
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)  
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)  

        query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]  
        query_states = torch.cat(query_states, dim=-1)  

        key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]  
        key_states = torch.cat(key_states, dim=-1)  

        value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]  
        value_states = torch.cat(value_states, dim=-1)  

    else:  
        query_states = self.q_proj(hidden_states)  
        key_states = self.k_proj(hidden_states)  
        value_states = self.v_proj(hidden_states)  

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)  
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)  

        kv_seq_len = key_states.shape[-2]  
    if past_key_value is not None:  
        kv_seq_len += past_key_value[0].shape[-2]  
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)  
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)  

    if past_key_value is not None:  
        # reuse k, v, self_attention  
        key_states = torch.cat([past_key_value[0], key_states], dim=2)  
        value_states = torch.cat([past_key_value[1], value_states], dim=2)  

        past_key_value = (key_states, value_states) if use_cache else None  

        key_states = repeat_kv(key_states, self.num_key_value_groups)  
        value_states = repeat_kv(value_states, self.num_key_value_groups)  

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)  

    if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):  
        raise ValueError(  
            f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"  
            f" {attn_weights.size()}"  
        )  

    if attention_mask is not None:  
        if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):  
            raise ValueError(  
                f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"  
            )  
        attn_weights = attn_weights + attention_mask  

    # upcast attention to fp32  
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  
    attn_output = torch.matmul(attn_weights, value_states)  

    if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):  
    raise ValueError(  
    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"  
    f" {attn_output.size()}"  
    )  

    attn_output = attn_output.transpose(1, 2).contiguous()  

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)  

    if self.config.pretraining_tp > 1:  
        attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)  
        o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)  
        attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])  
    else:  
        attn_output = self.o_proj(attn_output)  

    if not output_attentions:  
        attn_weights = None  

    return attn_output, attn_weights, past_key_value

多头注意力基本与《Attention Is All You Need》中一致,计算公式如下:

Attention(Q,K,V)=softmax(QKTdk)VAttention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})VAttention(Q,K,V)=softmax(dk​​QKT​)V

在 llama 中每进行一次注意力计算,都会对 Q 和 K 计算一次位置编码(RoPE)。

因为 K 和 V 注意力头数是 Q 的几分之一,所以在计算前首先进行 repeat 操作,对应代码如下:

key_states = repeat_kv(key_states, self.num_key_value_groups)  
value_states = repeat_kv(value_states, self.num_key_value_groups)

计算注意力的代码如下:

attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

attn_weights = attn_weights + attention_mask # 可选操作

attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)  
attn_output = torch.matmul(attn_weights, value_states)

最终 attn_output 经过 o_proj 的线性变换之后输出。

与前馈网络类似,如果 config 中设置 pretraining_tp,会对输入进行切片后分块操作。

解码层 LlamaDecoderLayer

transfromers 中对解码层的定义如下:

class LlamaDecoderLayer(nn.Module):  
    def __init__(self, config: LlamaConfig):  
    super().__init__()  
    self.hidden_size = config.hidden_size  
    self.self_attn = (  
        LlamaAttention(config=config)  
        if not getattr(config, "_flash_attn_2_enabled", False)  
        else LlamaFlashAttention2(config=config)  
    )  
    self.mlp = LlamaMLP(config)  
    self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  
    self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

解码层由 AttentionLayerMLP 和两个 LayerNorm 组成。

前向计算代码如下:

def forward(  
    self,  
    hidden_states: torch.Tensor,  
    attention_mask: Optional[torch.Tensor] = None,  
    position_ids: Optional[torch.LongTensor] = None,  
    past_key_value: Optional[Tuple[torch.Tensor]] = None,  
    output_attentions: Optional[bool] = False,  
    use_cache: Optional[bool] = False,  
    padding_mask: Optional[torch.LongTensor] = None,  
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:  
    """  
    Args:  
        hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`  
        attention_mask (`torch.FloatTensor`, *optional*): attention mask of size  
        `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.  
        output_attentions (`bool`, *optional*):  
        Whether or not to return the attentions tensors of all attention layers. See `attentions` under  
        returned tensors for more detail.  
        use_cache (`bool`, *optional*):  
        If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding  
        (see `past_key_values`).  
        past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states  
    """  

    residual = hidden_states  

    hidden_states = self.input_layernorm(hidden_states)  

    # Self Attention  
    hidden_states, self_attn_weights, present_key_value = self.self_attn(  
        hidden_states=hidden_states,  
        attention_mask=attention_mask,  
        position_ids=position_ids,  
        past_key_value=past_key_value,  
        output_attentions=output_attentions,  
        use_cache=use_cache,  
        padding_mask=padding_mask,  
    )  
    hidden_states = residual + hidden_states  

    # Fully Connected  
    residual = hidden_states  
    hidden_states = self.post_attention_layernorm(hidden_states)  
    hidden_states = self.mlp(hidden_states)  
    hidden_states = residual + hidden_states  

    outputs = (hidden_states,)  

    if output_attentions:  
        outputs += (self_attn_weights,)  

    if use_cache:  
        outputs += (present_key_value,)  

    return outputs

在解码器层中,输入 hidden_states 依次经历如下计算:

  1. 经过 input_layernorm 进行层归一化。
  2. 计算一次自注意力。
  3. 做一次残差连接。
  4. 经过 post_attention_layernorm 进行层归一化。
  5. 经过 mlp,并将结果与步骤3结果做一次残差连接。

模型 LlamaModel

transformers 中对模型定义如下:

class LlamaModel(LlamaPreTrainedModel):  
    """  
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]  

    Args:  
        config: LlamaConfig  
    """  

    def __init__(self, config: LlamaConfig):  
        super().__init__(config)  
        self.padding_idx = config.pad_token_id  
        self.vocab_size = config.vocab_size  

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)  
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])  
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)  

        self.gradient_checkpointing = False  
        # Initialize weights and apply final processing  
        self.post_init()

Llama 模型是由若干个解码层堆叠而成。

在前向传播时设置 gradient_checkpointing=True 可以节约显存。

但是这个参数不能和 use_cache=True 同时设置,这两个参数不兼容。

if self.gradient_checkpointing and self.training:  
    if use_cache:  
        logger.warning_once(  
            "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."  
        )  
    use_cache = False

在前向传播中自定义了前向传播函数:

def create_custom_forward(module):  
    def custom_forward(*inputs):  
        # None for past_key_value  
        return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)  
  
    return custom_forward

使用 torch.utils.checkpoint.checkpoint() 函数,它允许将前向传播的一部分分成小块以减小内存占用,并且可以在反向传播时实现显存优化。前提是设置 gradient_checkpointing=True

layer_outputs = torch.utils.checkpoint.checkpoint(  
    create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids  
)

代码中的 decode_layer 为前文中提到的解码器层。

经过多层解码器层后,将输出经过 RMSNorm 层,得到最终结果。

语言模型 LlamaForCausalLM

transformers 中对语言模型的定义如下:

class LlamaForCausalLM(LlamaPreTrainedModel):  
    _tied_weights_keys = ["lm_head.weight"]  

    def __init__(self, config):  
        super().__init__(config)  
        self.model = LlamaModel(config)  
        self.vocab_size = config.vocab_size  
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  

        # Initialize weights and apply final processing  
        self.post_init()

实质是在前文提到的 LlamaModel 基础上加入一个 llm_head 来生成结果。

前向传播核心计算代码如下:

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)  
outputs = self.model(  
    input_ids=input_ids,  
    attention_mask=attention_mask,  
    position_ids=position_ids,  
    past_key_values=past_key_values,  
    inputs_embeds=inputs_embeds,  
    use_cache=use_cache,  
    output_attentions=output_attentions,  
    output_hidden_states=output_hidden_states,  
    return_dict=return_dict,  
)  
  
hidden_states = outputs[0]  
if self.config.pretraining_tp > 1:  
    lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)  
    logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]  
    logits = torch.cat(logits, dim=-1)  
else:  
    logits = self.lm_head(hidden_states)  
logits = logits.float()

如果输入 labels 会自动计算交叉熵损失。

分类模型 LlamaForSequenceClassification

分类模型也是由 LlamaModel 加上一个 score 的线性层构成。

在计算损失的时候,会根据不同的类型,选择不同的损失函数:

if self.config.problem_type == "regression":  
    loss_fct = MSELoss()  
    if self.num_labels == 1:  
        loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())  
    else:  
        loss = loss_fct(pooled_logits, labels)  
elif self.config.problem_type == "single_label_classification":  
    loss_fct = CrossEntropyLoss()  
    loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))  
elif self.config.problem_type == "multi_label_classification":  
    loss_fct = BCEWithLogitsLoss()  
    loss = loss_fct(pooled_logits, labels)

总结

以 LlamaModel 为例总结数据流向:

  • 输入的如果是 input_ids,会首先计算 inputs_embeds,然后作为 hidden_states,经过若干个 LlamaDecoderLayer、LlamaRMSNorm 组合后输出。
  • 在 LlamaDecoderLayer 中,经历如下步骤:
    1. 先记录原始输入,然后对于输入的 hidden_states 先做一次 LlamaRMSNorm。
    2. 对步骤1的结果做一次 LlamaAttention。
    3. 将步骤2的结果与原始输入做一次残差连接,并记录这次结果。
    4. 将步骤3的结果做一次 LlamaRMSNorm。
    5. 将步骤4的结果做一次 LlamaMLP。
    6. 将步骤5的结果与步骤3的结果做一次残差连接,将结果输出。
  • 在 LlamaAttention 中,经历如下步骤:
    1. 将输入的 hidden_states 做 Q、K、V 变换。
    2. 计算 Q、K 的旋转位置编码。
    3. 根据公式计算自注意力。
    4. 注意力经过线性变换后输出。
  • 在 LlamaMLP 中,经历如下步骤:
    1. 原始输入经过线性变换,得到上投影。
    2. 原始输入经过门函数和激活函数,得到门控投影。
    3. 将步骤1的上投影和步骤2的门控投影对应元素相乘。
    4. 将步骤3的结果经过线性变换,得到下投影,输出这个结果。

那么,我们该如何学习大模型?

作为一名热心肠的互联网老兵,我决定把宝贵的AI知识分享给大家。 至于能学习到多少就看你的学习毅力和能力了 。我已将重要的AI大模型资料包括AI大模型入门学习思维导图、精品AI大模型学习书籍手册、视频教程、实战学习等录播视频免费分享出来。

一、大模型全套的学习路线

学习大型人工智能模型,如GPT-3、BERT或任何其他先进的神经网络模型,需要系统的方法和持续的努力。既然要系统的学习大模型,那么学习路线是必不可少的,下面的这份路线能帮助你快速梳理知识,形成自己的体系。

L1级别:AI大模型时代的华丽登场

L2级别:AI大模型API应用开发工程

L3级别:大模型应用架构进阶实践

L4级别:大模型微调与私有化部署

一般掌握到第四个级别,市场上大多数岗位都是可以胜任,但要还不是天花板,天花板级别要求更加严格,对于算法和实战是非常苛刻的。建议普通人掌握到L4级别即可。

以上的AI大模型学习路线,不知道为什么发出来就有点糊 ,高清版可以微信扫描下方CSDN官方认证二维码免费领取【保证100%免费

二、640套AI大模型报告合集

这套包含640份报告的合集,涵盖了AI大模型的理论研究、技术实现、行业应用等多个方面。无论您是科研人员、工程师,还是对AI大模型感兴趣的爱好者,这套报告合集都将为您提供宝贵的信息和启示。

三、大模型经典PDF籍

随着人工智能技术的飞速发展,AI大模型已经成为了当今科技领域的一大热点。这些大型预训练模型,如GPT-3、BERT、XLNet等,以其强大的语言理解和生成能力,正在改变我们对人工智能的认识。 那以下这些PDF籍就是非常不错的学习资源。

四、AI大模型商业化落地方案

作为普通人,入局大模型时代需要持续学习和实践,不断提高自己的技能和认知水平,同时也需要有责任感和伦理意识,为人工智能的健康发展贡献力量。

相关推荐
中关村科金6 分钟前
中关村科金智能客服机器人如何解决客户个性化需求与标准化服务之间的矛盾?
人工智能·机器人·在线客服·智能客服机器人·中关村科金
逸_9 分钟前
Product Hunt 今日热榜 | 2024-12-25
人工智能
Luke Ewin15 分钟前
基于3D-Speaker进行区分说话人项目搭建过程报错记录 | 通话录音说话人区分以及语音识别 | 声纹识别以及语音识别 | pyannote-audio
人工智能·语音识别·声纹识别·通话录音区分说话人
DashVector29 分钟前
如何通过HTTP API检索Doc
数据库·人工智能·http·阿里云·数据库开发·向量检索
说私域33 分钟前
无人零售及开源 AI 智能名片 S2B2C 商城小程序的深度剖析
人工智能·小程序·零售
Calvin88082841 分钟前
Android Studio 的革命性更新:Project Quartz 和 Gemini,开启 AI 开发新时代!
android·人工智能·android studio
Jamence1 小时前
【深度学习数学知识】-贝叶斯公式
人工智能·深度学习·概率论
feifeikon1 小时前
机器学习DAY4续:梯度提升与 XGBoost (完)
人工智能·深度学习·机器学习
凡人的AI工具箱2 小时前
每天40分玩转Django:实操多语言博客
人工智能·后端·python·django·sqlite
Jackilina_Stone2 小时前
【自动驾驶】3 激光雷达③
人工智能·自动驾驶