探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(九)Transformer架构
Llama Transformer架构
现在将所有单独的组件堆叠到Transformer中:
python
class ModelArgs:
dim: int = 4096
n_layers: int = 32
n_heads: int = 32 # Number of heads for the queries
n_kv_heads: Optional[int] = None # Number of heads for the keys and values. If None, defaults to n_heads
vocab_size: int = -1 # This will be set when we load the tokenizer
multiple_of: int = 256
ffn_dim_multiplier: Optional[float] = None # If None, defaults to 4.0
norm_eps: float = 1e-6 # only the eps value has been modified int llama 3
# Needed for KV cache
max_batch_size: int = 32
max_seq_len: int = 2048
device: str = None
python
class Transformer(nn.Module):
def __init__(self, args: ModelArgs) -> None:
super().__init__()
assert args.vocab_size != -1, "Vocab size must be set"
self.args = args
self.vocab_size = args.vocab_size
self.n_layers = args.n_layers
self.tok_embeddings = nn.Embedding(self.vocab_size, args.dim)
self.layers = nn.ModuleList()
for _ in range(args.n_layers):
self.layers.append(EncoderBlock(args))
self.norm = RMSNorm(args.dim, eps=args.norm_eps)
self.output = nn.Linear(args.dim, self.vocab_size, bias=False)
# To precompute the frequencies of the Rotary Positional Encodings
self.freqs_complex = precomputed_theta_pos_frequencies(self.args.dim // self.args.n_heads, self.args.max_seq_len * 2, device=self.args.device)
def forward(self, tokens: torch.Tensor, start_pos: int):
# (B, seq_len)
batch_size, seq_len = tokens.shape
assert seq_len == 1, "Only one token at a time can be processed"
# (B, seq_len) -> (B, seq_len, dim)
h = self.tok_embeddings(tokens)
# Retrive the pairs (m, theta) corresponding to the positions [start_pos, start_pos + seq_len]
freqs_complex = self.freqs_complex[start_pos:start_pos + seq_len]
# Consecutively apply all the encoder layers
for layer in self.layers:
h = layer(h, start_pos, freqs_complex)
h = self.norm(h)
output = self.output(h).float()
return output
data:image/s3,"s3://crabby-images/8e8bc/8e8bc0ac2a44327f20bb75a20015934f66777752" alt=""
系列博客
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(一)Llama3 模型 架构
https://duanzhihua.blog.csdn.net/article/details/138208650
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(二)RoPE位置编码
https://duanzhihua.blog.csdn.net/article/details/138212328
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(三)KV缓存
https://duanzhihua.blog.csdn.net/article/details/138213306
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(四)分组多查询注意力
https://duanzhihua.blog.csdn.net/article/details/138216050
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(五)RMS 均方根归一化
https://duanzhihua.blog.csdn.net/article/details/138216630
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(六)SwiGLU 激活函数
https://duanzhihua.blog.csdn.net/article/details/138217261
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(七)前馈神经网络
https://duanzhihua.blog.csdn.net/article/details/138218095
探索和构建 LLaMA 3 架构:深入探讨组件、编码和推理技术(八)Transformer块
https://mp.csdn.net/mp_blog/manage/article?spm=1011.2423.3001.5298