学习一下 transformers 库中,Llama 模型的代码,学习过程中写下这篇笔记,一来加深印象,二来可以多次回顾。
笔者小白,里面错误之处请不吝指出。
层归一化 LlamaRMSNorm
transformers 中对于 LlamaRMSNorm 类的定义如下:
python
class LlamaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
LlamaRMSNorm is equivalent to T5LayerNorm
"""
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
def forward(self, hidden_states):
input_dtype = hidden_states.dtype
hidden_states = hidden_states.to(torch.float32)
variance = hidden_states.pow(2).mean(-1, keepdim=True)
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
return self.weight * hidden_states.to(input_dtype)
这里采用了 RMS(Root Mean Square) 归一化,其中 RMS 计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R M S ( x ) = 1 n ∑ x i 2 RMS(x)=\sqrt{\frac{1}{n}\sum{x_i^2}} </math>RMS(x)=n1∑xi2
则 RMSNorm 归一化的计算公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R M S ( x ) = x R M S ( x ) + ϵ ∗ W RMS(x)=\frac{x}{\sqrt{RMS(x)+\epsilon}} * W </math>RMS(x)=RMS(x)+ϵ x∗W
加上一个小常数,确保分母不为零,保持数据稳定性。
旋转位置编码 LlamaRotaryEmbedding
- 绝对位置编码:计算高效,效果欠佳
- 相对位置编码:满足 NLP 领域在序列长度方向上具有平移不变性,计算效率低。
- 旋转位置编码:采用绝对位置编码达到相位置编码的效果
transformers 中对于 LlamaRotaryEmbedding 类的定义如下,它用于实现旋转位置嵌入:
python
class LlamaRotaryEmbedding(nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
self.dim = dim
self.max_position_embeddings = max_position_embeddings
self.base = base
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
# Build here to make `torch.jit.trace` work.
self._set_cos_sin_cache(
seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
def forward(self, x, seq_len=None):
# x: [bs, num_attention_heads, seq_len, head_size]
if seq_len > self.max_seq_len_cached:
self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
return (
self.cos_cached[:seq_len].to(dtype=x.dtype),
self.sin_cached[:seq_len].to(dtype=x.dtype),
)
其中定义的变量意义如下:
- dim:表示模型输出维度
- max_position_embeddings:最大编码长度,默认为2048
- base:基数,默认为10000
inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
实现公式为:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> i n v _ f r e q = 1 b a s e i / d i m inv\_freq=\frac{1}{base^{i/dim}} </math>inv_freq=basei/dim1
在上面代码中,t
的维度为[max_position_embeddings], inv_freq
的维度为[dim/2]。
经过 torch.einsum("i,j->ij", t, self.inv_freq)
之后维度为 [max_position_embeddings, dim/2]。
然后经过 emb = torch.cat((freqs, freqs), dim=-1)
操作,维度变为 [max_position_embeddings, dim]。
二维情况下旋转矩阵如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R ( k ) = ( c o s k θ − s i n k θ s i n k θ c o s k θ ) R(k)=\begin{pmatrix} cosk\theta & -sink\theta \\ sink\theta & cosk\theta \\ \end{pmatrix} </math>R(k)=(coskθsinkθ−sinkθcoskθ)
旋转位置编码计算公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R ( k ) x = ( c o s k θ 0 c o s k θ 0 c o s k θ 1 c o s k θ 1 . . . c o s k θ d / 2 − 1 c o s k θ d / 2 − 1 ) ∘ ( x 0 x 1 x 2 x 3 . . . x d − 1 x d ) + ( s i n k θ 0 s i n k θ 0 s i n k θ 1 s i n k θ 1 . . . s i n k θ d / 2 − 1 s i n k θ d / 2 − 1 ) ∘ ( − x 1 x 0 − x 3 x 2 . . . − x d x d − 1 ) R(k)x= \begin{pmatrix} cos{k\theta_0} \\ cos{k\theta_0} \\ cos{k\theta_1} \\ cos{k\theta_1} \\ ... \\ cos{k\theta_{d/2-1}} \\ cos{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ ... \\ x_{d-1} \\ x_d \end{pmatrix} + \begin{pmatrix} sin{k\theta_0} \\ sin{k\theta_0} \\ sin{k\theta_1} \\ sin{k\theta_1} \\ ... \\ sin{k\theta_{d/2-1}} \\ sin{k\theta_{d/2-1}} \end{pmatrix} \circ \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ ... \\ -x_d \\ x_{d-1} \end{pmatrix} </math>R(k)x=⎝ ⎛coskθ0coskθ0coskθ1coskθ1...coskθd/2−1coskθd/2−1⎠ ⎞∘⎝ ⎛x0x1x2x3...xd−1xd⎠ ⎞+⎝ ⎛sinkθ0sinkθ0sinkθ1sinkθ1...sinkθd/2−1sinkθd/2−1⎠ ⎞∘⎝ ⎛−x1x0−x3x2...−xdxd−1⎠ ⎞
在使用 LLM 时,我们希望对上下文长度进行拓展,以便能进行多轮对话,由此有下面几种方法:
外推法:直接沿用当前公式计算计算更长位置的编码。
这种方法比较简单,但是存在相关性衰减问题,如果模型训练语料在 2k 长度左右,模型能够学习到 2k 长度左右的 token 之间相关性关系的规律。
如果直接将此规律沿用到 5k 上下文,可能导致在中间某个位置相关性衰减到零,从而无法捕捉两个 token 之间的相关性。
线性内插:线性内插会改变编码公式,将 token 之间的距离等比例缩小。
例如在 2k 上下文情况下,两个 token 之间距离为 16,那么在 32k 上下文下,这两个 token 之间距离缩短到 1。
对于短距离的衰减规律,可能造成非常大的变化,但是线性内插没有改变模型学习到的衰减规律的应用范围,不考虑微调的话,其效果一般好于直接外推方案。
transformers 中对于线性内插的实现如下:
python
class LlamaLinearScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
t = t / self.scaling_factor
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
可以看到,在 t = t / self.scaling_factor
这行代码中,除以一个缩放因子,从而达到线性缩放的效果。
动态 NTK 扩展:外推法对于长距离的 token 不能很好计算相关性,线性内插对于短距离 token 计算相关性会产生很大变化,因此可以综合两者进行扩展。
为了在短距离情况下具有外推特性,长距离情况下具有内插特性,可以设置一个与位置序号有关频率因子,动态调整。
transformers 中对于动态 NTK 扩展的实现如下:
python
class LlamaDynamicNTKScalingRotaryEmbedding(LlamaRotaryEmbedding):
"""LlamaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
self.scaling_factor = scaling_factor
super().__init__(dim, max_position_embeddings, base, device)
def _set_cos_sin_cache(self, seq_len, device, dtype):
self.max_seq_len_cached = seq_len
if seq_len > self.max_position_embeddings:
base = self.base * (
(self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
) ** (self.dim / (self.dim - 2))
inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
freqs = torch.einsum("i,j->ij", t, self.inv_freq)
# Different from paper, but it uses a different permutation in order to obtain the same calculation
emb = torch.cat((freqs, freqs), dim=-1)
self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)
可以看到,如果长度超过 max_position_embeddings
,对于 base
做出了如下公式操作:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> b a s e = b a s e ∗ ( f a c t o r ∗ s e q _ l e n m a x _ l e n − ( f a c t o r − 1 ) ) d i m d i m − 2 base=base*(factor*\frac{seq\_len}{max\_len}-(factor-1))^{\frac{dim}{dim-2}} </math>base=base∗(factor∗max_lenseq_len−(factor−1))dim−2dim
如果 seq_len > max_position_embeddings,在 factor = 1 的情况下,base 变大。
显然 base > 1,则 inv_freq 值变小,这样将短距离的规律扩展到了长距离。
具体计算位置编码的代码如下:
python
def rotate_half(x):
"""Rotates half the hidden dims of the input."""
x1 = x[..., : x.shape[-1] // 2]
x2 = x[..., x.shape[-1] // 2 :]
return torch.cat((-x2, x1), dim=-1)
# Copied from transformers.models.gpt_neox.modeling_gpt_neox.apply_rotary_pos_emb
def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
cos = cos[position_ids].unsqueeze(1) # [seq_len, dim] -> [batch_size, 1, seq_len, head_dim]
sin = sin[position_ids].unsqueeze(1)
q_embed = (q * cos) + (rotate_half(q) * sin)
k_embed = (k * cos) + (rotate_half(k) * sin)
return q_embed, k_embed
在 rotate_half()
中,将输入 x
沿着最后一维分隔成两部分,然后进行拼接。
Llama 中对 Q 的旋转位置编码按照如下方式计算:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> R ( k ) Q = ( c o s k θ 0 c o s k θ 2 . . . c o s k θ d − 2 c o s k θ 0 c o s k θ 2 . . . c o s k θ d − 2 ) ∘ ( q 0 q 1 . . . q d / 2 − 1 q d / 2 q d / 2 + 1 . . . q d − 1 ) + ( s i n k θ 0 s i n k θ 2 . . . s i n k θ d − 2 s i n k θ 0 s i n k θ 2 . . . s i n k θ d ) ∘ ( − q d / 2 − q d / 2 + 1 . . . − q d − 1 q 0 q 1 . . . q d − 1 ) R(k)Q= \begin{pmatrix} cos{k\theta_0} \\ cos{k\theta_2} \\ ... \\ cos{k\theta_{d-2}} \\ cos{k\theta_0} \\ cos{k\theta_2} \\ ... \\ cos{k\theta_{d-2}} \end{pmatrix} \circ \begin{pmatrix} q_0 \\ q_1 \\ ... \\ q_{d/2-1} \\ q_{d/2} \\ q_{d/2+1} \\ ... \\ q_{d-1} \end{pmatrix} + \begin{pmatrix} sin{k\theta_0} \\ sin{k\theta_2} \\ ... \\ sin{k\theta_{d-2}} \\ sin{k\theta_0} \\ sin{k\theta_2} \\ ... \\ sin{k\theta_{d}} \end{pmatrix} \circ \begin{pmatrix} -q_{d/2} \\ -q_{d/2+1} \\ ... \\ -q_{d-1} \\ q_0 \\ q_1 \\ ... \\ q_{d-1} \end{pmatrix} </math>R(k)Q=⎝ ⎛coskθ0coskθ2...coskθd−2coskθ0coskθ2...coskθd−2⎠ ⎞∘⎝ ⎛q0q1...qd/2−1qd/2qd/2+1...qd−1⎠ ⎞+⎝ ⎛sinkθ0sinkθ2...sinkθd−2sinkθ0sinkθ2...sinkθd⎠ ⎞∘⎝ ⎛−qd/2−qd/2+1...−qd−1q0q1...qd−1⎠ ⎞
这里只对 Q 和 K 加入位置编码信息。
前馈网络 LlamaMLP
transformers 中对于前馈网络的定义如下:
python
class LlamaMLP(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.intermediate_size = config.intermediate_size
self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
self.act_fn = ACT2FN[config.hidden_act]
def forward(self, x):
if self.config.pretraining_tp > 1:
slice = self.intermediate_size // self.config.pretraining_tp
gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
up_proj_slices = self.up_proj.weight.split(slice, dim=0)
down_proj_slices = self.down_proj.weight.split(slice, dim=1)
gate_proj = torch.cat(
[F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
)
up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
down_proj = [
F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
]
down_proj = sum(down_proj)
else:
down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
return down_proj
在 __init__()
函数中,定义了 hidden_size
和 intermediate_size
控制模型尺寸。
同时定义了三个全连接层:
gate_proj
:将 hidden_size 投影到 intermediate_sizeup_proj
:将 hidden_size 投影到 intermediate_sizedown_proj
:将 intermediate_size 投影到 hidden_size
这里会将输入通过 up_proj
先从 hidden_size 维度转换到 intermediate_size 维度,然后通过 down_proj
从 intermediate_size 维度转换到 hidden_size 维度。
同时里面采用 gate_proj
配合激活函数,实现了一个门控作用。
在 forward()
函数中会根据 config.pretraining_tp
选择不同的执行策略。这里是将三个全连接层切分为若干块,分别与输入 x
进行映射操作,得到多个子投影,然后将多个子投影拼接起来。
多头注意力 LlamaAttention
transformers 中对于多头注意力的定义如下:
python
class LlamaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: LlamaConfig):
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.hidden_size // self.num_heads
self.num_key_value_heads = config.num_key_value_heads
self.num_key_value_groups = self.num_heads // self.num_key_value_heads
self.max_position_embeddings = config.max_position_embeddings
self.rope_theta = config.rope_theta
if (self.head_dim * self.num_heads) != self.hidden_size:
raise ValueError(
f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
f" and `num_heads`: {self.num_heads})."
)
self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
self._init_rope()
def _init_rope(self):
if self.config.rope_scaling is None:
self.rotary_emb = LlamaRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
base=self.rope_theta,
)
else:
scaling_type = self.config.rope_scaling["type"]
scaling_factor = self.config.rope_scaling["factor"]
if scaling_type == "linear":
self.rotary_emb = LlamaLinearScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
elif scaling_type == "dynamic":
self.rotary_emb = LlamaDynamicNTKScalingRotaryEmbedding(
self.head_dim,
max_position_embeddings=self.max_position_embeddings,
scaling_factor=scaling_factor,
base=self.rope_theta,
)
else:
raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
这里主要定义了下面几种属性:
- hidden_size:隐藏层的大小
- num_heads:注意力头的数量
- head_dim:每个注意力头的维度,它通过 hidden_size // num_heads 得到
- num_key_value_heads:键值注意力头的数量
- num_key_value_groups:键值注意力头分组数量,它通过 num_heads // num_key_value_heads 得到
- rope_theta:即前面 RoPE 的 base
此外还定义了四个线性变换的全连接层,分别用于计算查询(Q)、键(K)、值(V)和输出(O)。
注意到键值 注意力头的数量与查询注意力头的数量不同。
键值注意力头数量可以是查询注意力头数量的几分之一,这样可以减少参数规模。
多头注意力的计算代码如下:
python
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
bsz, q_len, _ = hidden_states.size()
if self.config.pretraining_tp > 1:
key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
query_slices = self.q_proj.weight.split(
(self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
)
key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
query_states = torch.cat(query_states, dim=-1)
key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
key_states = torch.cat(key_states, dim=-1)
value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
value_states = torch.cat(value_states, dim=-1)
else:
query_states = self.q_proj(hidden_states)
key_states = self.k_proj(hidden_states)
value_states = self.v_proj(hidden_states)
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}"
)
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}"
)
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
if self.config.pretraining_tp > 1:
attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
else:
attn_output = self.o_proj(attn_output)
if not output_attentions:
attn_weights = None
return attn_output, attn_weights, past_key_value
多头注意力基本与《Attention Is All You Need》中一致,计算公式如下:
<math xmlns="http://www.w3.org/1998/Math/MathML" display="block"> A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt{d_k}})V </math>Attention(Q,K,V)=softmax(dk QKT)V
在 llama 中每进行一次注意力计算,都会对 Q 和 K 计算一次位置编码(RoPE)。
因为 K 和 V 注意力头数是 Q 的几分之一,所以在计算前首先进行 repeat 操作,对应代码如下:
python
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
计算注意力的代码如下:
python
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
attn_weights = attn_weights + attention_mask # 可选操作
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
最终 attn_output 经过 o_proj
的线性变换之后输出。
与前馈网络类似,如果 config 中设置 pretraining_tp
,会对输入进行切片后分块操作。
解码层 LlamaDecoderLayer
transfromers 中对解码层的定义如下:
python
class LlamaDecoderLayer(nn.Module):
def __init__(self, config: LlamaConfig):
super().__init__()
self.hidden_size = config.hidden_size
self.self_attn = (
LlamaAttention(config=config)
if not getattr(config, "_flash_attn_2_enabled", False)
else LlamaFlashAttention2(config=config)
)
self.mlp = LlamaMLP(config)
self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
解码层由 AttentionLayer 、MLP 和两个 LayerNorm 组成。
前向计算代码如下:
python
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: Optional[bool] = False,
use_cache: Optional[bool] = False,
padding_mask: Optional[torch.LongTensor] = None,
) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
"""
Args:
hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
`(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
output_attentions (`bool`, *optional*):
Whether or not to return the attentions tensors of all attention layers. See `attentions` under
returned tensors for more detail.
use_cache (`bool`, *optional*):
If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
(see `past_key_values`).
past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
"""
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
# Self Attention
hidden_states, self_attn_weights, present_key_value = self.self_attn(
hidden_states=hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_value,
output_attentions=output_attentions,
use_cache=use_cache,
padding_mask=padding_mask,
)
hidden_states = residual + hidden_states
# Fully Connected
residual = hidden_states
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states = self.mlp(hidden_states)
hidden_states = residual + hidden_states
outputs = (hidden_states,)
if output_attentions:
outputs += (self_attn_weights,)
if use_cache:
outputs += (present_key_value,)
return outputs
在解码器层中,输入 hidden_states 依次经历如下计算:
- 经过 input_layernorm 进行层归一化。
- 计算一次自注意力。
- 做一次残差连接。
- 经过 post_attention_layernorm 进行层归一化。
- 经过 mlp,并将结果与步骤3结果做一次残差连接。
模型 LlamaModel
transformers 中对模型定义如下:
python
class LlamaModel(LlamaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]
Args:
config: LlamaConfig
"""
def __init__(self, config: LlamaConfig):
super().__init__(config)
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.gradient_checkpointing = False
# Initialize weights and apply final processing
self.post_init()
Llama 模型是由若干个解码层堆叠而成。
在前向传播时设置 gradient_checkpointing=True
可以节约显存。
但是这个参数不能和 use_cache=True
同时设置,这两个参数不兼容。
python
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning_once(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
在前向传播中自定义了前向传播函数:
python
def create_custom_forward(module):
def custom_forward(*inputs):
# None for past_key_value
return module(*inputs, past_key_value, output_attentions, padding_mask=padding_mask)
return custom_forward
使用 torch.utils.checkpoint.checkpoint()
函数,它允许将前向传播的一部分分成小块以减小内存占用,并且可以在反向传播时实现显存优化。前提是设置 gradient_checkpointing=True
。
python
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(decoder_layer), hidden_states, attention_mask, position_ids
)
代码中的 decode_layer
为前文中提到的解码器层。
经过多层解码器层后,将输出经过 RMSNorm 层,得到最终结果。
语言模型 LlamaForCausalLM
transformers 中对语言模型的定义如下:
python
class LlamaForCausalLM(LlamaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]
def __init__(self, config):
super().__init__(config)
self.model = LlamaModel(config)
self.vocab_size = config.vocab_size
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
# Initialize weights and apply final processing
self.post_init()
实质是在前文提到的 LlamaModel 基础上加入一个 llm_head
来生成结果。
前向传播核心计算代码如下:
python
# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = self.model(
input_ids=input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
)
hidden_states = outputs[0]
if self.config.pretraining_tp > 1:
lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
logits = torch.cat(logits, dim=-1)
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
如果输入 labels 会自动计算交叉熵损失。
分类模型 LlamaForSequenceClassification
分类模型也是由 LlamaModel 加上一个 score
的线性层构成。
在计算损失的时候,会根据不同的类型,选择不同的损失函数:
python
if self.config.problem_type == "regression":
loss_fct = MSELoss()
if self.num_labels == 1:
loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
else:
loss = loss_fct(pooled_logits, labels)
elif self.config.problem_type == "single_label_classification":
loss_fct = CrossEntropyLoss()
loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
elif self.config.problem_type == "multi_label_classification":
loss_fct = BCEWithLogitsLoss()
loss = loss_fct(pooled_logits, labels)
总结
以 LlamaModel 为例总结数据流向:
- 输入的如果是 input_ids,会首先计算 inputs_embeds,然后作为 hidden_states,经过若干个 LlamaDecoderLayer、LlamaRMSNorm 组合后输出。
- 在 LlamaDecoderLayer 中,经历如下步骤:
- 先记录原始输入,然后对于输入的 hidden_states 先做一次 LlamaRMSNorm。
- 对步骤1的结果做一次 LlamaAttention。
- 将步骤2的结果与原始输入做一次残差连接,并记录这次结果。
- 将步骤3的结果做一次 LlamaAttention。
- 将步骤4的结果做一次 LlamaMLP。
- 将步骤5的结果与步骤3的结果做一次残差连接,将结果输出。
- 在 LlamaAttention 中,经历如下步骤:
- 将输入的 hidden_states 做 Q、K、V 变换。
- 计算 Q、K 的旋转位置编码。
- 根据公式计算自注意力。
- 注意力经过线性变换后输出。
- 在 LlamaMLP 中,经历如下步骤:
- 原始输入经过线性变换,得到上投影。
- 原始输入经过门函数和激活函数,得到门控投影。
- 将步骤1的上投影和步骤2的门控投影对应元素相乘。
- 将步骤3的结果经过线性变换,得到下投影,输出这个结果。
学习参考
在学习过程中,参考了下面的文章或视频,感谢各位大佬分享自己的知识: