从零手搓大语言模型:模型结构篇

本文为"从零手搓大语言模型"系列第 1 篇,基于 MiniMind 项目(64M 参数)的实践经验,系统梳理 Transformer Decoder-Only 架构的核心组件与设计原理。

一、整体架构概览

现代大语言模型(LLM)的核心架构为 Transformer Decoder-Only 结构。与 Encoder-Decoder 架构不同,Decoder-Only 仅执行单一任务:给定上文,预测下一个 token。GPT、Llama、Qwen、DeepSeek 等主流模型均采用此架构。

MiniMind 的模型结构如下:

yaml 复制代码
MiniMindForCausalLM(因果语言模型外壳)
├── embed_tokens:Token Embedding,将 token_id 映射为 768 维向量
├── 8 × MiniMindBlock(Transformer 层,逐层堆叠)
│   ├── RMSNorm → Attention(多头注意力) → 残差连接
│   └── RMSNorm → FeedForward(SwiGLU) → 残差连接
├── RMSNorm(最终归一化)
└── lm_head:线性投影层,将 768 维向量映射为 6400 维词表概率分布

二、RMSNorm:归一化层

2.1 为什么需要归一化

神经网络逐层计算时,矩阵乘法会不断放大数值尺度。若不加约束,经过多层叠加后数值将溢出或梯度消失,导致训练不稳定。归一化的作用是在每个计算模块入口将数值拉回标准范围。

2.2 RMSNorm 的实现

python 复制代码
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-5):
        self.weight = nn.Parameter(torch.ones(dim))  # 可学习的缩放因子

    def norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        return self.weight * self.norm(x.float())

RMSNorm 仅对均方根做缩放,不减去均值(相比 LayerNorm 少一步计算),在实践中效果相当但速度更快。

2.3 归一化不会丢失语义

归一化仅改变向量的"长度"(尺度),不改变向量的"方向"(维度间的比例关系)。语义信息主要编码在方向中,因此不会因归一化而丢失。此外,可学习的 weight 参数允许模型恢复必要的尺度信息。

2.4 三处归一化的职责

  • input_layernorm:Attention 计算前的尺度校准
  • post_attention_layernorm:FeedForward 计算前的尺度校准
  • final norm(MiniMindModel.norm):送入 lm_head 前的最终校准

三者功能相同,位置不同,确保每个计算模块接收到尺度正常的输入。

三、RoPE:旋转位置编码

3.1 问题背景

Transformer 的自注意力机制不具备天然的位置感知能力。"我打你"与"你打我"在不加位置信息时对模型而言是等价的。

3.2 核心思想

RoPE 将每个位置的 Q 和 K 向量按特定角度旋转。位置不同,旋转角度不同。两个词做点积时,结果仅与它们的相对距离有关。

3.3 多频率设计

768 维向量被拆分为 384 对,每对分配一个不同的旋转频率:

python 复制代码
freqs = 1.0 / (1000000 ** (torch.arange(0, dim, 2) / dim))
  • 高频维度:每步旋转角度大,适合捕捉近距离位置关系
  • 低频维度:每步旋转角度极小,适合区分远距离位置

384 个频率的组合确保了任意两个位置的编码都是唯一的,类似于钟表上秒针、分针、时针的组合能唯一标识时刻。

3.4 YaRN 长度外推

当推理序列超出训练长度时,高频维度的角度进入模型未见过的范围。YaRN 通过降低高频维度的频率(除以 factor),将远距离位置的角度"压缩"回模型熟悉的范围。

四、GQA:分组查询注意力

4.1 标准多头注意力(MHA)

Q、K、V 各有 8 个头,每个头独立计算注意力。推理时需要缓存所有 K、V 头的历史状态(KV Cache),显存开销大。

4.2 GQA 的优化

MiniMind 采用 Q 头数 = 8、KV 头数 = 4 的配置。每 2 个 Q 头共享 1 组 KV:

python 复制代码
self.q_proj = nn.Linear(768, 8 * 96)   # 8 个 Q 头
self.k_proj = nn.Linear(768, 4 * 96)   # 4 个 KV 头
self.v_proj = nn.Linear(768, 4 * 96)

计算前通过 repeat_kv 将 4 组 KV 复制扩展为 8 组以匹配 Q 头数。

4.3 收益

KV Cache 减半,推理速度和显存均改善,而模型效果几乎无损(Google 在 65B 规模实验中验证效果损失 < 0.5%)。

五、SwiGLU:门控前馈网络

5.1 FeedForward 的作用

Attention 处理词间关系,FeedForward 对每个位置独立做知识加工。其基本结构为:升维 → 激活 → 降维。

  • 升维(768 → 2432):将特征展开到更高维空间,便于分离不同模式
  • 激活:引入非线性,使模型能学到复杂规律
  • 降维(2432 → 768):将筛选后的信息压缩回原始维度

5.2 SwiGLU 的结构

python 复制代码
def forward(self, x):
    gate = F.silu(self.gate_proj(x))    # 门控信号:每个维度的"开关"
    info = self.up_proj(x)              # 信息流:实际内容
    return self.down_proj(gate * info)  # 门控 × 信息 → 降维

相比普通 ReLU 激活,SwiGLU 引入了门控机制,模型能精细控制"哪些信息通过、通过多少",而非 ReLU 的二值开关。

六、残差连接

每个 MiniMindBlock 包含两次残差连接:

python 复制代码
residual = hidden_states
hidden_states = self.self_attn(self.input_layernorm(hidden_states))
hidden_states += residual  # 第一次残差

hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))  # 第二次残差

每个子模块(Attention、FeedForward)独立拥有一条残差通道,确保:

  1. 梯度能沿直通路径回传,避免梯度消失
  2. 即使某层学习效果不佳,原始信息仍可传递至下一层

七、因果遮罩与 KV Cache

7.1 因果遮罩

Decoder-Only 架构要求生成时只能看到"过去",不能看到"未来"。实现上通过上三角矩阵将未来位置的注意力分数设为负无穷:

python 复制代码
scores[:, :, :, -seq_len:] += torch.full((seq_len, seq_len), float("-inf")).triu(1)

7.2 KV Cache

推理时,已生成 token 的 K、V 不会改变。将其缓存后,每步只需对新 token 做计算,避免重复运算,大幅提升生成速度。

八、模型参数设计

MiniMind-3 的参数配置:

参数 说明
hidden_size 768 向量维度
num_hidden_layers 8 Transformer 层数
num_attention_heads 8 Q 头数
num_key_value_heads 4 KV 头数(GQA)
vocab_size 6400 词表大小
intermediate_size 2432 FFN 中间维度
max_position_embeddings 32768 最大位置数

总参数量约 64M,为 GPT-3 的 1/2700。

九、小结

当前主流 LLM 的结构已高度趋同:Decoder-Only + RoPE + GQA + SwiGLU + RMSNorm + Pre-Norm。模型的本质是一组参数矩阵,训练的过程就是将这些矩阵从随机值调整为能够准确预测下一个 token 的有意义数值。结构设计的核心在于选择合适的计算组件并合理连接,使得梯度传播顺畅、训练稳定、推理高效。

相关推荐
杊页1 小时前
系列二:MVVM 深度实战与项目重构 | 第6篇 DataBinding & ViewBinding 实战落地:告别 findViewById 的“刀耕火种”
架构·mvvm
风一直吹2 小时前
Web 端 PvP 实时对战从零实现:匹配、同步、伤害全链路拆解
架构
Sunia2 小时前
《Agentx专栏》06-记忆系统:用Redis+Milvus给AI配上短期+长期双层记忆
java·架构
AI科技星2 小时前
依托Gε₀ = e²/(4παmₚ²)核心方程:全新公式推导+原创理论提炼+全维度精算验证
人工智能·线性代数·架构·概率论·学习方法
用户938515635072 小时前
前端必会:从 Fetch 到 DeepSeek,一篇搞懂 HTTP 请求的方方面面
javascript·架构
小谢小哥2 小时前
68-持续集成详解
java·后端·架构
A-刘晨阳2 小时前
数据库挂了服务就瘫?我用PostgreSQL主从流复制搭了高可用架构,cpolar打通远程访问
数据库·postgresql·架构
candyTong2 小时前
为什么 Agent Skill 不是通过向量 RAG 召回的?
架构
踩着两条虫2 小时前
开源 AI 低代码平台 VTJ.PRO 双版本齐发:核心引擎 v0.17.1 与在线平台 v2.4.1 正式上线,强化团队协作与 AI 资产管理
前端·人工智能·低代码·架构·开源