Llama的整体框架如下图所示:
官方源码中含有大量与分布式训练相关的代码,读起来比较晦涩难懂。因此我们自顶向下地对Llama进行解析和复现,并将其划分为顶层、中层和底层三个层次,如下:
Llama的整体组成
由上图可知,Llama整体由1个embedding层、n个Transformer层和1个RMSNorm层组成(此处省略了最后将隐藏状态映射到词表大小的输出线性层),因此顶层代码如下:
顶层
python
class Llama(torch.nn.Module):
    """Top-level LLaMA backbone: token embedding -> n_layers TransformerBlocks -> final RMSNorm.

    The forward pass returns the normalized hidden states; no projection onto
    the vocabulary (output head) is applied here.
    """

    # The annotation is quoted because ModelArgs is defined later in this file;
    # an unquoted name would raise NameError when this method is defined.
    def __init__(self, config: "ModelArgs"):
        super().__init__()
        self.config = config
        # Token embedding table of shape (vocab_size, dim).
        self.tok_embeddings = torch.nn.Embedding(self.config.vocab_size, self.config.dim)
        # Final RMSNorm applied after the last Transformer block.
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        # Stack of n_layers Transformer blocks.
        self.layers = torch.nn.ModuleList(
            TransformerBlock(config) for _ in range(self.config.n_layers)
        )

    def forward(self, tokens):
        """Embed ``tokens`` and run them through every Transformer block.

        Args:
            tokens: integer token ids; assumed (batch, seq_len), since the
                sequence length is read from axis 1 of the embeddings.

        Returns:
            Hidden states after the final RMSNorm, shape (batch, seq_len, dim).
        """
        h = self.tok_embeddings(tokens)
        seqlen = h.shape[1]
        # Causal (decoder) mask: -inf strictly above the diagonal, so position i
        # cannot attend to any j > i. It is built in h's dtype so that adding it
        # to half-precision attention scores does not silently upcast them.
        mask = torch.full(
            (seqlen, seqlen), float("-inf"), device=tokens.device, dtype=h.dtype
        )
        mask = torch.triu(mask, diagonal=1)
        for layer in self.layers:
            h = layer(h, mask)
        return self.norm(h)
中层
我们首先进行RMSNorm的复现
python
class RMSNorm(torch.nn.Module):
    """Root-mean-square layer normalization with a learnable per-feature scale.

    y = x / sqrt(mean(x**2, last axis) + eps) * weight
    """

    def __init__(self, dim, eps):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialized to 1 (identity scaling).
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def _norm(self, tensor):
        # Divide by the RMS over the last axis; eps guards against division by zero.
        return tensor * torch.rsqrt(tensor.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, tensor):
        # Normalize in float32 and cast back to the input dtype: in fp16, squaring
        # values above ~256 overflows to inf, which would zero out the output.
        output = self._norm(tensor.float()).type_as(tensor)
        return output * self.weight
然后复现Transformer层。每个Transformer层包含两个RMSNorm层、一个多头注意力(attention)层和一个前馈全连接层。
python
class TransformerBlock(torch.nn.Module):
    """One LLaMA decoder layer with pre-norm residual connections.

        h   = x + Attention(RMSNorm(x))
        out = h + FeedForward(RMSNorm(h))

    Each sub-layer sees a normalized input, while the residual stream that is
    carried between sub-layers and blocks stays unnormalized.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # Multi-head self-attention.
        self.attention = Attention(config)
        # Pre-norms for the attention and feed-forward sub-layers.
        self.attention_normal = RMSNorm(config.dim, config.norm_eps)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        # Feed-forward network. NOTE(review): the hidden size is fixed at 4 * dim here;
        # Meta's reference shrinks it to 2/3 of that and rounds up to
        # config.multiple_of (optionally scaled by config.ffn_dim_multiplier), which
        # matters when loading official weights.
        self.ffn = FeedForwad(self.config.dim, self.config.dim * 4)

    def forward(self, embeddings, mask):
        """Apply attention then feed-forward, each wrapped in a pre-norm residual.

        Args:
            embeddings: residual-stream input; assumed (batch, seq_len, dim).
            mask: additive attention mask forwarded unchanged to ``self.attention``.

        Returns:
            Tensor with the same shape as ``embeddings``.
        """
        # Attention sub-layer: add the result back to the *unnormalized* input.
        h = embeddings + self.attention(self.attention_normal(embeddings), mask)
        # Feed-forward sub-layer: same pattern on the updated residual stream.
        return h + self.ffn(self.ffn_norm(h))
底层
在多头注意力中,需要依次对token嵌入进行线性投影、多头拆分、旋转位置编码和注意力分数计算等操作:
python
class Attention(torch.nn.Module):
    """Multi-head causal self-attention with rotary position embeddings (RoPE).

    Reads ``config.dim``, ``config.n_heads``, ``config.rope_theta`` and, if
    present, ``config.n_kv_heads``. When ``n_kv_heads`` is set to fewer heads
    than ``n_heads``, grouped-query attention is used: each key/value head is
    shared by ``n_heads // n_kv_heads`` query heads.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.n_head = config.n_heads
        # Number of key/value heads; None (or missing) means one per query head.
        self.n_kv_head = getattr(config, "n_kv_heads", None) or self.n_head
        if self.n_head % self.n_kv_head != 0:
            raise ValueError(
                f"n_heads ({self.n_head}) must be divisible by n_kv_heads ({self.n_kv_head})"
            )
        # Per-head feature size (attribute name kept as `dim` for compatibility).
        self.dim = config.dim // self.n_head
        kv_dim = self.n_kv_head * self.dim
        # LLaMA's projections carry no bias.
        self.k = torch.nn.Linear(config.dim, kv_dim, bias=False)
        self.q = torch.nn.Linear(config.dim, config.dim, bias=False)
        self.v = torch.nn.Linear(config.dim, kv_dim, bias=False)
        # Output projection that mixes the concatenated heads back into model space.
        self.o = torch.nn.Linear(config.dim, config.dim, bias=False)

    def forward(self, embeddings, mask):
        """Attend over ``embeddings``.

        Args:
            embeddings: tensor of shape (bsz, seq_len, dim).
            mask: additive mask broadcastable to (bsz, n_heads, seq_len, seq_len)
                (e.g. a (seq_len, seq_len) causal mask), or None for no masking.

        Returns:
            Tensor of shape (bsz, seq_len, dim).
        """
        bsz, seq_len, _ = embeddings.shape
        # Project, then split into heads: (bsz, heads, seq_len, head_dim).
        q = self.q(embeddings).view(bsz, seq_len, self.n_head, self.dim).transpose(1, 2)
        k = self.k(embeddings).view(bsz, seq_len, self.n_kv_head, self.dim).transpose(1, 2)
        v = self.v(embeddings).view(bsz, seq_len, self.n_kv_head, self.dim).transpose(1, 2)
        # Rotate queries and keys by position-dependent angles; values are not rotated.
        q = compute_rotated_embedding(q, self.dim, seq_len, self.config.rope_theta)
        k = compute_rotated_embedding(k, self.dim, seq_len, self.config.rope_theta)
        # Grouped-query attention: repeat each K/V head for its group of query heads.
        n_rep = self.n_head // self.n_kv_head
        if n_rep > 1:
            k = k.repeat_interleave(n_rep, dim=1)
            v = v.repeat_interleave(n_rep, dim=1)
        scores = q @ k.transpose(-1, -2) / math.sqrt(self.dim)
        if mask is not None:
            scores = scores + mask
        # Softmax in float32 for numerical stability, then back to the activation dtype.
        probs = torch.nn.functional.softmax(scores.float(), dim=-1).type_as(q)
        out = probs @ v
        # Merge heads back to (bsz, seq_len, dim) and apply the output projection.
        out = out.transpose(1, 2).reshape(bsz, seq_len, -1)
        return self.o(out)
python
class FeedForwad(torch.nn.Module):
    """SwiGLU feed-forward network: linear3(silu(linear1(x)) * linear2(x)).

    The (misspelled) class name is kept because other code refers to it by this name.

    Args:
        dim: input/output feature size.
        hidden_dim: inner feature size of the gated projection.
        bias: whether the linear layers have a bias term. LLaMA uses no bias,
            hence the default of False.
    """

    def __init__(self, dim, hidden_dim, bias=False):
        super().__init__()
        # Gate projection (passed through SiLU).
        self.linear1 = torch.nn.Linear(dim, hidden_dim, bias=bias)
        # Up projection, multiplied element-wise by the gate.
        self.linear2 = torch.nn.Linear(dim, hidden_dim, bias=bias)
        # Down projection back to the model dimension.
        self.linear3 = torch.nn.Linear(hidden_dim, dim, bias=bias)

    def forward(self, embeddings):
        gate = torch.nn.functional.silu(self.linear1(embeddings))
        up_proj = self.linear2(embeddings) * gate
        return self.linear3(up_proj)
最后,我们复现旋转位置编码(RoPE)。至此,我们就捋清了Llama的全部结构!
python
def compute_rotated_embedding(embedding, dim, m, base):
    """Apply rotary position embedding (RoPE) to ``embedding``.

    Consecutive feature pairs (x[2i], x[2i+1]) are viewed as one complex number
    and multiplied by e^{i * pos * theta_i}, i.e. rotated in the complex plane.

    Args:
        embedding: tensor whose last two axes are (positions, dim); dim must be even.
        dim: size of the last axis (per-head feature size).
        m: number of positions to generate angles for; should be >= embedding.shape[-2].
        base: RoPE frequency base (e.g. config.rope_theta).

    Returns:
        Rotated tensor with the same shape, dtype and device as ``embedding``.
    """
    # Complex math is done in float32 (or float64 if that is the input dtype):
    # complex half precision has limited operator support.
    if embedding.dtype in (torch.float32, torch.float64):
        work_dtype = embedding.dtype
    else:
        work_dtype = torch.float32
    # Rotation angles of shape (m, dim // 2). compute_all_theta builds them on the
    # default device, so move them to the input's device before mixing tensors.
    all_theta = compute_all_theta(dim, m, base)
    all_theta = all_theta[: embedding.shape[-2]].to(device=embedding.device, dtype=work_dtype)
    # Unit complex numbers e^{i * theta}.
    rotation = torch.polar(torch.ones_like(all_theta), all_theta)
    # View feature pairs as complex numbers: (..., positions, dim // 2).
    pairs = embedding.to(work_dtype).reshape(*embedding.shape[:-1], -1, 2)
    complex_pairs = torch.view_as_complex(pairs)
    # Rotate, then flatten the (real, imag) pairs back into the feature axis.
    rotated = torch.view_as_real(complex_pairs * rotation)
    return rotated.reshape(*embedding.shape).to(embedding.dtype)
def compute_all_theta(dim, m, base, device=None):
    """Return RoPE rotation angles for every (position, frequency) pair.

    theta_i = base ** (-i / (dim // 2)) for i in [0, dim // 2), and the angle
    used at position p is p * theta_i.

    Args:
        dim: per-head feature size; RoPE rotates dim // 2 feature pairs.
        m: number of positions (0 .. m - 1).
        base: frequency base (e.g. config.rope_theta).
        device: optional device for the result; defaults to torch's default device.

    Returns:
        float32 tensor of shape (m, dim // 2).
    """
    half = dim // 2
    theta = 1.0 / (base ** (torch.arange(half, dtype=torch.float32, device=device) / half))
    # Float positions, so the outer product does not mix integer and float dtypes.
    positions = torch.arange(m, dtype=torch.float32, device=device)
    return torch.outer(positions, theta)
附录:llama的config参数
python
@dataclass
class ModelArgs:
    """Hyper-parameters for the LLaMA model.

    Raises:
        ValueError: on construction, if the head configuration is inconsistent
            (dim not divisible by n_heads, or n_kv_heads not dividing n_heads).
    """

    dim: int = 4096  # model (embedding) width
    n_layers: int = 32  # number of Transformer blocks
    n_heads: int = 32  # number of attention (query) heads
    n_kv_heads: Optional[int] = None  # key/value heads; None is treated as n_heads below
    vocab_size: int = -1  # -1 is a placeholder; set to the tokenizer's vocab size before use
    multiple_of: int = 256  # make SwiGLU hidden layer size multiple of large power of 2
    ffn_dim_multiplier: Optional[float] = None  # optional scale for the FFN hidden size
    norm_eps: float = 1e-5  # epsilon added inside RMSNorm's square root
    rope_theta: float = 500000  # RoPE frequency base
    max_batch_size: int = 32
    max_seq_len: int = 2048
    use_scaled_rope: bool = True

    def __post_init__(self):
        # Fail fast on configurations whose heads cannot evenly split the model width.
        if self.n_heads <= 0 or self.dim % self.n_heads != 0:
            raise ValueError(
                f"dim ({self.dim}) must be divisible by a positive n_heads ({self.n_heads})"
            )
        n_kv = self.n_heads if self.n_kv_heads is None else self.n_kv_heads
        if n_kv <= 0 or self.n_heads % n_kv != 0:
            raise ValueError(
                f"n_heads ({self.n_heads}) must be divisible by a positive n_kv_heads ({n_kv})"
            )