拆解大模型三:你不会真以为 Attention 是大模型的主角吧

拆解大模型三:你以为 Attention 是主角,其实它只占了一半参数

学完 Attention,很多人会有个错觉:

Transformer = Attention,其他都是配件。

错了。

一个标准的 Transformer Block 里,FFN(前馈网络)占的参数量比 Attention 还多。GPT-3 的 1750 亿参数里,差不多三分之二都在 FFN 里,不是 Attention。

但几乎所有教程都在 Attention 上花 80% 的篇幅,FFN 带一句"两层全连接"就过去了。

这篇把一个完整的 Transformer Block 从头拆到尾------Embedding、Attention、FFN、残差连接、LayerNorm,每个组件是什么、为什么要这样设计,最后用代码把它拼起来。


一、鸟瞰:一个 Token 是怎么穿过 Transformer 的

先建立整体图景,再钻细节。

一个 Token 从输入到输出,经历这几站:

scss 复制代码
Token ID
    ↓
Embedding(把 ID 变成向量)
    ↓
[重复 N 次 Transformer Block]
    ↓ 每个 Block 内部:
    ├─ LayerNorm
    ├─ Multi-Head Attention(+ 残差连接)
    ├─ LayerNorm
    └─ FFN(+ 残差连接)
    ↓
最后一层输出
    ↓
Linear + Softmax(预测下一个 Token 的概率)

GPT-3 把这个 Block 叠了 96 层。每一层都在对 Token 的向量表示做一次"更新",一层层叠下来,最终的向量里就融入了丰富的上下文信息。

接下来逐个击破。


二、Embedding:把 Token 变成向量

模型不能直接处理 Token ID(比如整数 42),需要先把它映射成一个连续的向量。

这个映射就是 Embedding:一张查找表,每个 Token ID 对应一个 d_model 维的向量,这些向量是训练出来的参数。

py 复制代码
import torch
import torch.nn as nn

vocab_size = 50000   # 词表大小
d_model = 512        # 向量维度

embedding = nn.Embedding(vocab_size, d_model)

# Token ID 序列
token_ids = torch.tensor([42, 17, 8, 305])   # shape: (seq_len,)
x = embedding(token_ids)                      # shape: (seq_len, d_model)

但这里有个问题:Attention 是"无序"的------它只管谁跟谁相似,不知道谁在前谁在后。

所以还需要一个 位置编码(Positional Encoding) ,告诉模型每个 Token 在序列里的位置:

py 复制代码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)  # 偶数维用 sin
        pe[:, 1::2] = torch.cos(position * div_term)  # 奇数维用 cos
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (seq_len, d_model)
        return x + self.pe[:x.size(0)]

原始 Transformer 用的是 sin/cos 函数生成位置编码,现代模型(比如 LLaMA)改用了可学习的 RoPE(旋转位置编码),但原理类似,都是把位置信息加进向量里。


三、残差连接:为什么深层网络不会"学坏"

在钻 Attention 和 FFN 之前,先说残差连接------因为不理解它,就很难理解为什么 Transformer 能叠这么深。

深层网络有个经典问题:梯度消失。网络叠得越深,反向传播时梯度乘了太多次小于 1 的数,到前面几层几乎变成 0,参数根本学不动。

残差连接的解法极其简单:

<math xmlns="http://www.w3.org/1998/Math/MathML"> output = f ( x ) + x \text{output} = f(x) + x </math>output=f(x)+x

把输入 <math xmlns="http://www.w3.org/1998/Math/MathML"> x x </math>x 直接加到这一层的输出上。

py 复制代码
# 有残差连接
output = layer(x) + x

# 没有残差连接
output = layer(x)

这有什么用?反向传播时,梯度可以沿着 + x 这条"高速公路"直接流到前面的层,不用经过 layer(x) 里的层层乘法,梯度消失的问题大幅缓解。

这个想法来自 2015 年的 ResNet(用于图像识别),Transformer 把它借过来了。正是有了残差连接,Transformer 才能叠几十甚至上百层。


四、LayerNorm:训练稳定性的保障

每次 Attention 或 FFN 之后,向量的数值范围可能变得很大或很小,导致训练不稳定。

LayerNorm 的作用是把向量归一化到均值为 0、方差为 1:

<math xmlns="http://www.w3.org/1998/Math/MathML"> x ^ = x − μ σ + ϵ \hat{x} = \frac{x - \mu}{\sigma + \epsilon} </math>x^=σ+ϵx−μ

然后再用两个可学习参数 <math xmlns="http://www.w3.org/1998/Math/MathML"> γ \gamma </math>γ、 <math xmlns="http://www.w3.org/1998/Math/MathML"> β \beta </math>β 做缩放和偏移:

<math xmlns="http://www.w3.org/1998/Math/MathML"> LayerNorm ( x ) = γ x ^ + β \text{LayerNorm}(x) = \gamma \hat{x} + \beta </math>LayerNorm(x)=γx^+β

py 复制代码
layer_norm = nn.LayerNorm(d_model)
x_normed = layer_norm(x)

注意 Transformer Block 里 LayerNorm 放在 Attention 和 FFN 之前(Pre-Norm),而原始论文是放在之后(Post-Norm)。现代大模型几乎都用 Pre-Norm,训练更稳定。


五、FFN:被严重低估的组件

终于到 FFN 了。

结构上非常简单------两层全连接,中间一个激活函数:

<math xmlns="http://www.w3.org/1998/Math/MathML"> FFN ( x ) = activation ( x W 1 + b 1 ) W 2 + b 2 \text{FFN}(x) = \text{activation}(x W_1 + b_1) W_2 + b_2 </math>FFN(x)=activation(xW1+b1)W2+b2

py 复制代码
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.linear2(self.activation(self.linear1(x)))

看起来人畜无害,但注意 d_ff 这个参数------通常设成 4 * d_model

GPT-3 的 d_model = 12288,那么 d_ff = 49152

两个矩阵 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 1 W_1 </math>W1 和 <math xmlns="http://www.w3.org/1998/Math/MathML"> W 2 W_2 </math>W2 的参数量:

<math xmlns="http://www.w3.org/1998/Math/MathML"> 12288 × 49152 + 49152 × 12288 ≈ 12 亿 12288 \times 49152 + 49152 \times 12288 \approx 12 \text{ 亿} </math>12288×49152+49152×12288≈12 亿

乘以 96 层:大约 1150 亿参数,全在 FFN 里。

那 FFN 在做什么?

研究者发现,FFN 里的神经元像是在存储"知识"。有研究(Geva et al. 2021)发现,FFN 的每一行权重对应一类模式------有的神经元专门响应"日本相关的词",有的专门响应"法律术语"。

Attention 负责"在序列里找关系",FFN 负责"根据当前 Token 的语义,调用存储在参数里的知识"。两者分工,缺一不可。


六、拼起来:完整的 Transformer Block

现在把所有组件拼成一个完整的 Block:

py 复制代码
import torch
import torch.nn as nn
import math

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.d_k = d_model // num_heads
        self.num_heads = num_heads

        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, T, C = x.shape   # batch, seq_len, d_model
        h = self.num_heads

        Q = self.W_Q(x).view(B, T, h, self.d_k).transpose(1, 2)  # (B, h, T, d_k)
        K = self.W_K(x).view(B, T, h, self.d_k).transpose(1, 2)
        V = self.W_V(x).view(B, T, h, self.d_k).transpose(1, 2)

        scores = Q @ K.transpose(-2, -1) / math.sqrt(self.d_k)    # (B, h, T, T)
        weights = torch.softmax(scores, dim=-1)
        out = weights @ V                                           # (B, h, T, d_k)

        out = out.transpose(1, 2).contiguous().view(B, T, C)       # (B, T, d_model)
        return self.W_O(out)


class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, d_ff),
            nn.GELU(),
            nn.Linear(d_ff, d_model),
        )

    def forward(self, x):
        return self.net(x)


class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff):
        super().__init__()
        self.norm1 = nn.LayerNorm(d_model)
        self.attn  = MultiHeadAttention(d_model, num_heads)
        self.norm2 = nn.LayerNorm(d_model)
        self.ffn   = FeedForward(d_model, d_ff)

    def forward(self, x):
        # Pre-Norm + 残差
        x = x + self.attn(self.norm1(x))   # Attention 子层
        x = x + self.ffn(self.norm2(x))    # FFN 子层
        return x


# 跑一下看看 shape 对不对
d_model, num_heads, d_ff = 512, 8, 2048
block = TransformerBlock(d_model, num_heads, d_ff)

x = torch.randn(2, 16, d_model)   # batch=2, seq_len=16
out = block(x)
print("输入 shape:", x.shape)
print("输出 shape:", out.shape)   # 应该和输入一样:(2, 16, 512)

输出:

css 复制代码
输入 shape: torch.Size([2, 16, 512])
输出 shape: torch.Size([2, 16, 512])

输入和输出的 shape 完全一样------这是有意设计的。Block 不改变形状,只更新向量里的内容。叠多少层,输入输出形状都不变,模型架构变得非常规整。


七、整个模型长什么样

把 Block 叠起来,加上头尾,就是一个完整的语言模型:

py 复制代码
class GPTModel(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, d_ff, num_layers, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = nn.Embedding(max_len, d_model)   # 可学习的位置编码
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

    def forward(self, token_ids):
        B, T = token_ids.shape
        positions = torch.arange(T, device=token_ids.device)

        x = self.embedding(token_ids) + self.pos_embedding(positions)

        for block in self.blocks:
            x = block(x)

        x = self.norm(x)
        logits = self.head(x)   # (B, T, vocab_size)
        return logits

logits 的最后一个时间步,就是模型对下一个 Token 的预测分布。过 Softmax 之后,就是我们在第一篇里说的那张概率表。


小结

颠覆一下认知:FFN 才是 Transformer 里参数最多的组件,不是 Attention。

Attention 负责"在序列里找关系"------谁该关注谁。 FFN 负责"调用知识"------根据当前语义激活存储在参数里的模式。 残差连接让深层网络的梯度能流动,是叠深的基础。 LayerNorm 稳定训练,现代模型几乎都用 Pre-Norm。

把这些拼在一起,就是一个 Transformer Block。叠 N 层,就是大模型的主体。


下一篇,我们跳出结构,看训练过程:

大模型是怎么从一堆随机参数,通过训练变得"有用"的?预训练、SFT、RLHF 分别在做什么?

相关推荐
手机不死我是天子1 天前
拆解大模型二:Transformer 最核心的设计,其实你高中就学过
人工智能·llm
数据智能老司机2 天前
构建自然语言与大语言模型(LLM)流水线——将组件整合起来:面向不同使用场景的 Haystack Pipeline
llm·agent
数据智能老司机2 天前
构建自然语言与大语言模型(LLM)流水线——使用自定义组件进行 Haystack Pipeline 开发
llm·agent
gustt2 天前
探索MCP协议:构建高效的LLM工具集成系统
llm·agent·mcp
神秘的猪头2 天前
🚀 React 开发者进阶:RAG 核心——手把手带你玩转 Milvus 向量数据库
数据库·后端·llm
哈里谢顿3 天前
LangGraph 框架完全指南:构建生产级 AI 工作流
langchain·llm
UIUV3 天前
Splitter学习笔记(含RAG相关流程与代码实践)
后端·langchain·llm
mCell3 天前
分享一个常用的文生图提示词
人工智能·llm·数据可视化
gustt3 天前
使用 LangChain 构建 AI 代理:自动化创建 React TodoList 应用
人工智能·llm·agent