斯坦福CS336作业一-前馈网络 🎓
本文档基于斯坦福 CS336 作业一,从零实现 Transformer 的位置级前馈网络(Position-wise FFN),涵盖核心原理、标准 FFN 实现、SwiGLU 门控变体、代码逐行解析,以及完整可运行的综合示例 🛠️
章节阅读路线图 🗺️
- 前馈网络概述 📚 → 理解 FFN 的作用与核心思想
- 标准 FFN 实现 💻 → 从零编写两层全连接网络,逐行解析
- SwiGLU 门控变体 ⚡ → 学习现代 LLM 常用的高级激活函数
- 完整可运行示例 🎯 → 整合所有内容,提供完整测试脚本
- 总结 📝 → 回顾核心要点
1. 前馈网络概述 📚
本章介绍位置级前馈网络(Position-wise FFN)在 Transformer 中的作用和核心原理
1.1 什么是位置级前馈网络? 🤔
在 Transformer 架构中,每个 Transformer Block 包含两个核心子层:📝
- 多头自注意力机制(Multi-Head Self-Attention) → 捕捉序列中 token 之间的全局依赖关系
- 位置级前馈网络(Position-wise Feed-Forward Network, FFN) → 对每个 token 的表示进行独立的非线性变换
什么是"位置级"(Position-wise)? 🎯
"位置级"意味着 FFN 对序列中的每个位置独立应用相同的变换。举个例子:🌰
假设输入序列是 ["我", "喜欢", "深度", "学习"],FFN 会对每个词分别应用完全相同的两层全连接网络,就像给每个词单独做一次"信息加工"。
直观类比 🏭:想象 FFN 是一个"个人加工厂"
- 自注意力机制像是"团队讨论"------所有词一起交流信息
- FFN 像是"个人思考"------每个词拿到讨论结果后,独立进行深度加工
- 每个词都经过相同的加工流程(相同的网络参数),但加工结果不同(因为输入不同)
1.2 FFN 的核心公式 📐
原始 Transformer 论文中,FFN 的数学表达式为:📝
FFN(x,W1,W2,b1,b2)=max(0,xW1+b1)W2+b2
这个公式可以拆解为三个步骤:🔍
- 第一层线性变换 → h=xW1+b1,将输入从 dmodel 维映射到 dff 维(通常是 4 倍)
- ReLU 激活函数 → max(0,h),引入非线性,过滤负值
- 第二层线性变换 → output=hW2+b2,将维度从 dff 压缩回 dmodel
为什么需要"扩展-收缩"(Expand-and-Contract)结构? 🤔
FFN 先将维度扩大(通常 4 倍),再压缩回来,这种设计有三个关键作用:🎯
-
提升表达能力 💪
在高维空间中,模型能够学习更复杂的特征表示。就像把一张照片放大后,你能看到更多细节,处理完再缩小回去。
-
引入非线性 🌀
ReLU 激活函数在两层线性变换之间引入非线性,使模型能够拟合复杂的函数关系。没有非线性,多层线性变换等价于单层。
-
信息解耦 🔓
高维空间允许模型将不同特征解耦到不同的维度上,分别进行处理,最后再整合。
维度变化示例 📊:
假设 dmodel=512, dff=2048(4 倍):
| 步骤 | 操作 | 维度变化 |
|---|---|---|
| 输入 | 自注意力输出 | batch,seq_len,512 |
| 第一层线性 | xW1+b1 | batch,seq_len,2048 |
| ReLU 激活 | max(0,h) | batch,seq_len,2048 |
| 第二层线性 | hW2+b2 | batch,seq_len,512 |
| 输出 | FFN 最终输出 | batch,seq_len,512 |
💡 关键理解 :FFN 的输入和输出维度相同(都是 dmodel),这使得 FFN 可以无缝嵌入到 Transformer Block 中,与残差连接(Residual Connection)配合使用。
参考资料:
- Position-Wise Feed-Forward Network (FFN) -- labml.ai ⭐值得阅读
- 万字长文:详细了解前馈神经网络(FFN),内含对大模型的理解 -- 知乎 ⭐值得阅读
- 如何理解Transformer中位置全连接前馈网络? -- 飞书文档
- 探秘Transformer系列之(13)--- FFN -- 博客园
- 详解Transformer中前馈全连接层FFN的作用与工作原理 -- 阿里云
2. 标准 FFN 实现 💻
本章从零编写标准 FFN 的完整代码,基于 CS336 作业一的要求
2.1 完整代码实现 🧮
👇 下面是基于 PyTorch 的标准 FFN 实现(使用 ReLU 激活函数):
python
import torch # 导入 PyTorch 核心库,提供张量运算和神经网络模块 🔥
import torch.nn as nn # 导入神经网络模块,包含 Linear、ReLU 等层 🧠
"""标准位置级前馈网络(Position-wise FFN)实现 💻
参数:
d_model: 输入/输出维度(词嵌入维度),示例:512
d_ff: 隐藏层维度(通常是 d_model 的 4 倍),示例:2048
dropout: Dropout概率,用于防止过拟合(默认0.1)
示例:
ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1)
"""
class FeedForward(nn.Module):
"""初始化 FFN 的两层全连接网络 🛠️
参数:
d_model: 输入/输出维度,示例:512
d_ff: 隐藏层维度,示例:2048
dropout: Dropout概率(默认0.1),示例:0.1 表示训练时随机丢弃 10% 的权重
返回:
无
示例:
self.ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1)
"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super(FeedForward, self).__init__() # 调用父类初始化,确保 nn.Module 正确设置 🏗️
# 第一层线性变换:d_model → d_ff,示例:512 → 2048 📈
self.linear1 = nn.Linear(d_model, d_ff) # 权重 W1: [d_model, d_ff],偏置 b1: [d_ff]
# 第二层线性变换:d_ff → d_model,示例:2048 → 512 📉
self.linear2 = nn.Linear(d_ff, d_model) # 权重 W2: [d_ff, d_model],偏置 b2: [d_model]
# ReLU 激活函数:max(0, x),引入非线性 🌀
self.relu = nn.ReLU() # 将负值置为 0,保留正值
# Dropout 层:训练时随机丢弃部分神经元,防止过拟合 🎲
self.dropout = nn.Dropout(dropout) # dropout=0.1 表示丢弃 10% 的权重
"""前向传播计算 FFN ⚡
参数:
x: 输入张量 [batch_size=32, seq_len=128, d_model=512]
返回:
output: FFN 输出 [32, 128, 512]
示例:
output = ffn(x)
"""
def forward(self, x: torch.Tensor):
# 1️⃣ 第一层线性变换 + ReLU 激活
# 数据流动:x[32,128,512] → linear1 → h[32,128,2048] → ReLU → h[32,128,2048](负值变0)
h = self.relu(self.linear1(x)) # h = max(0, xW1 + b1),扩展到高维空间
# 2️⃣ Dropout(训练时随机丢弃部分权重) 🎲
# 数据流动:h[32,128,2048] → Dropout → h[32,128,2048](部分变为0)
h = self.dropout(h)
# 3️⃣ 第二层线性变换,压缩回原始维度
# 数据流动:h[32,128,2048] → linear2 → output[32,128,512]
output = self.linear2(h) # output = hW2 + b2,压缩回 d_model 维度
return output
2.2 代码逐行解析 🔍
本节详细拆解 FFN 的每一步计算过程
第1步:第一层线性变换 + ReLU 1️⃣
python
h = self.relu(self.linear1(x)) # 第一层线性变换 + ReLU,数据流动:x[32,128,512] → h[32,128,2048]
这一步完成了两个操作:📝
-
线性变换 : h′=xW1+b1
- 输入 x 形状: batch,seq_len,dmodel = 32,128,512
- 权重 W1 形状: dmodel,dff = 512,2048
- 偏置 b1 形状: dff = 2048
- 输出 h′ 形状: batch,seq_len,dff = 32,128,2048
-
ReLU 激活 : h=max(0,h′)
- 将 h′ 中所有负值置为 0
- 保留正值不变
- 引入非线性,增强模型表达能力
什么是 ReLU 激活函数? 🤔
ReLU(Rectified Linear Unit)是最常用的激活函数,定义为:📐
ReLU(x)=max(0,x)
直观理解 💡:ReLU 就像是一个"单向阀门"
- 正值通过 ✅ → 保持原样
- 负值阻止 🚫 → 变成 0
为什么要用 ReLU? 🎯
- 计算简单 ⚡:只需要比较大小,不需要指数运算(如 Sigmoid)
- 缓解梯度消失 📈:正值的梯度始终为 1,不会因为层数多而衰减
- 稀疏激活 🎭:负值被置为 0,使得网络在某些维度上"休眠",提高计算效率
ReLU 示例 🌰:
python
import torch # 导入 PyTorch 核心库 🔥
x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0]) # 输入张量:包含正负值
relu = nn.ReLU() # 创建 ReLU 激活函数
output = relu(x) # 应用 ReLU
print(output) # 输出:tensor([0., 0., 0., 1., 2.])
可以看到,负值都被置为 0,正值保持不变。✅
参考资料:
第2步:Dropout 2️⃣
python
h = self.dropout(h) # 应用 Dropout,数据流动:h[32,128,2048] → h[32,128,2048](部分变为0)
Dropout 在训练时随机将一部分神经元的输出置为 0,防止模型过度依赖某些特定的神经元,提升泛化能力。🎲
直观类比 🎲:想象 Dropout 是在"随机抽查考试"
- 训练时:随机屏蔽 10% 的神经元(dropout=0.1),强迫模型学习更鲁棒的特征
- 推理时:所有神经元都工作,但输出会乘以 (1−dropout) 进行缩放(PyTorch 自动处理)
第3步:第二层线性变换 3️⃣
python
output = self.linear2(h) # 第二层线性变换,数据流动:h[32,128,2048] → output[32,128,512]
这一步将高维表示压缩回原始维度:📝
- 输入 h 形状: batch,seq_len,dff = 32,128,2048
- 权重 W2 形状: dff,dmodel = 2048,512
- 偏置 b2 形状: dmodel = 512
- 输出形状: batch,seq_len,dmodel = 32,128,512
💡 关键理解 :第二层线性变换不使用激活函数,直接输出线性结果。这与第一层不同(第一层后有 ReLU)。
2.3 使用 GELU 激活函数的变体 🌟
现代 Transformer(如 BERT、GPT)通常使用 GELU(Gaussian Error Linear Unit)替代 ReLU,因为它提供了更平滑的梯度流动。
GELU 公式 📐:
GELU(x)=x×Φ(x)
其中 Φ(x) 是标准正态分布的累积分布函数(CDF)。
直观理解 💡:GELU 像是"带概率的 ReLU"
- 正值:接近原值(但略有缩放)
- 负值:接近 0(但不是严格为 0,而是平滑过渡)
- 相比 ReLU 的"硬截断",GELU 提供了"软截断"
GELU vs ReLU 对比 ⚔️:
| 特性 | ReLU | GELU |
|---|---|---|
| 公式 | max(0,x) | x×Φ(x) |
| 负值处理 | 严格置 0 🚫 | 平滑趋近 0 🌊 |
| 梯度 | 阶跃函数(不连续) | 平滑连续 ✅ |
| 性能 | 基线水平 | 通常更好 📈 |
| 计算成本 | 极低 ⚡ | 稍高(需近似) |
GELU 实现代码 💻:
python
import torch # 导入 PyTorch 核心库 🔥
import torch.nn as nn # 导入神经网络模块 🧠
import torch.nn.functional as F # 导入函数式 API 模块 ⚙️
"""使用 GELU 激活函数的 FFN 实现 💻
参数:
d_model: 输入/输出维度(词嵌入维度),示例:512
d_ff: 隐藏层维度(通常是 d_model 的 4 倍),示例:2048
dropout: Dropout概率(默认0.1)
示例:
ffn = FeedForwardGELU(d_model=512, d_ff=2048, dropout=0.1)
"""
class FeedForwardGELU(nn.Module):
"""初始化使用 GELU 的 FFN 🛠️
参数:
d_model: 输入/输出维度,示例:512
d_ff: 隐藏层维度,示例:2048
dropout: Dropout概率(默认0.1)
返回:
无
"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super(FeedForwardGELU, self).__init__() # 调用父类初始化 🏗️
# 第一层线性变换:d_model → d_ff 📈
self.linear1 = nn.Linear(d_model, d_ff) # 权重 W1: [d_model, d_ff]
# 第二层线性变换:d_ff → d_model 📉
self.linear2 = nn.Linear(d_ff, d_model) # 权重 W2: [d_ff, d_model]
# Dropout 层 🎲
self.dropout = nn.Dropout(dropout) # 训练时随机丢弃部分权重
"""前向传播计算 FFN(使用 GELU 激活) ⚡
参数:
x: 输入张量 [batch_size, seq_len, d_model]
返回:
output: FFN 输出 [batch_size, seq_len, d_model]
"""
def forward(self, x: torch.Tensor):
# 1️⃣ 第一层线性变换 + GELU 激活
# 数据流动:x[32,128,512] → linear1 → h[32,128,2048] → GELU → h[32,128,2048]
h = F.gelu(self.linear1(x)) # F.gelu 提供平滑的非线性激活 🌊
# 2️⃣ Dropout 🎲
# 数据流动:h[32,128,2048] → Dropout → h[32,128,2048](部分变为0)
h = self.dropout(h)
# 3️⃣ 第二层线性变换
# 数据流动:h[32,128,2048] → linear2 → output[32,128,512]
output = self.linear2(h)
return output
💡 PyTorch 提供了
F.gelu()函数,内部使用了高效的近似实现,无需手动编写 GELU 公式。
参考资料:
- PyTorch官方文档 - nn.functional.gelu -- PyTorch
- Gaussian Error Linear Units (GELU) 论文原文 -- arXiv ⭐值得阅读
- FFN Activation Functions: ReLU, GELU, and SiLU -- mbrenndoerfer.com ⭐值得阅读
3. SwiGLU 门控变体 ⚡
本章介绍现代 LLM(如 LLaMA、PaLM)广泛使用的 SwiGLU 门控 FFN
3.1 什么是 SwiGLU? 🤔
SwiGLU(Swish-Gated Linear Unit)是 GLU(Gated Linear Unit)的一种变体,由 Shazeer(2020)提出,在现代大语言模型(如 LLaMA、PaLM)中被广泛采用。
SwiGLU 公式 📐:
SwiGLU(x)=(Swishβ(xW1+b1))⊗(xW3+b3)
output=SwiGLU(x)⋅W2+b2
其中:📝
- W1,b1:第一个线性变换的权重和偏置(用于门控信号)
- W3,b3:第三个线性变换的权重和偏置(用于值信号)
- W2,b2:第二个线性变换的权重和偏置(输出投影)
- Swishβ(x)=x×Sigmoid(βx),通常 β=1
- ⊗:逐元素乘法(Hadamard 积)
直观理解 💡:SwiGLU 像是一个"智能门控系统"
- xW1 生成"门控信号"(决定哪些信息应该通过)🚪
- xW3 生成"值信号"(实际要传递的信息)💎
- Swish(门控)⊗值:门控信号控制值信号的通过比例
- 最后通过 W2 投影到输出空间
与标准 FFN 的对比 ⚔️:
| 特性 | 标准 FFN | SwiGLU FFN |
|---|---|---|
| 线性层数量 | 2 层(linear1, linear2) | 3 层(linear1, linear2, linear_v) |
| 激活函数 | ReLU / GELU | Swish(Sigmoid 加权) |
| 门控机制 | 无 ✅ | 有(值信号 × 门控信号) 🚪 |
| 参数量 | 基准(1×) | 更多(约 1.5×) 📈 |
| 性能 | 基线水平 | 通常更好 🌟 |
| 代表模型 | 原始 Transformer | LLaMA、PaLM |
为什么 SwiGLU 性能更好? 🎯
-
动态门控 🚪
标准 FFN 的 ReLU/GELU 是"静态"的------只看单个神经元的值就决定是否激活。SwiGLU 的门控是"动态"的------一个线性层专门学习"应该让哪些信息通过",另一个线性层学习"要传递什么信息"。
-
更丰富的交互 🔗
逐元素乘法( ⊗)允许门控信号和值信号在每个维度上独立交互,而不是简单地通过激活函数过滤。
-
Swish 的平滑性 🌊
Swish 函数 x×Sigmoid(x) 在负值区域也有小梯度(不像 ReLU 完全为 0),有助于梯度流动。
直观类比 🎭:想象信息通过一个"智能安检门"
- 标准 FFN(ReLU):所有包裹(信息)统一检查,负值直接丢弃 🚫
- SwiGLU:有两个检查员 👥
- 检查员 A(门控信号):决定每个包裹应该放行多少
- 检查员 B(值信号):准备包裹的实际内容
- 最终放行量 = A 的决定 × B 的内容(逐元素相乘)
3.2 SwiGLU 完整代码实现 💻
👇 下面是基于 CS336 作业一的 SwiGLU 实现:
python
import torch # 导入 PyTorch 核心库 🔥
import torch.nn as nn # 导入神经网络模块 🧠
import torch.nn.functional as F # 导入函数式 API 模块 ⚙️
"""SwiGLU 门控前馈网络实现(LLaMA 等现代 LLM 使用) 💻
参数:
d_model: 输入/输出维度(词嵌入维度),示例:512
d_ff: 隐藏层维度(通常是 d_model 的 4 倍),示例:2048
dropout: Dropout概率(默认0.1)
示例:
ffn = SwiGLU(d_model=512, d_ff=2048, dropout=0.1)
"""
class SwiGLU(nn.Module):
"""初始化 SwiGLU 的三个线性层 🛠️
参数:
d_model: 输入/输出维度,示例:512
d_ff: 隐藏层维度,示例:2048
dropout: Dropout概率(默认0.1)
返回:
无
示例:
self.swiglu = SwiGLU(d_model=512, d_ff=2048, dropout=0.1)
"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super(SwiGLU, self).__init__() # 调用父类初始化 🏗️
# 第一个线性层:生成门控信号(用于 Swish 激活) 🚪
self.linear1 = nn.Linear(d_model, d_ff) # 权重 W1: [d_model, d_ff]
# 第二个线性层:输出投影(将门控后的结果映射回 d_model) 📉
self.linear2 = nn.Linear(d_ff, d_model) # 权重 W2: [d_ff, d_model]
# 第三个线性层:生成值信号(与门控信号逐元素相乘) 💎
self.linear_v = nn.Linear(d_model, d_ff) # 权重 W3: [d_model, d_ff]
# Dropout 层 🎲
self.dropout = nn.Dropout(dropout) # 训练时随机丢弃部分权重
"""前向传播计算 SwiGLU FFN ⚡
参数:
x: 输入张量 [batch_size=32, seq_len=128, d_model=512]
返回:
output: FFN 输出 [32, 128, 512]
示例:
output = swiglu(x)
"""
def forward(self, x: torch.Tensor):
# 1️⃣ 生成门控信号 + Swish 激活
# 数据流动:x[32,128,512] → linear1 → gate[32,128,2048] → Swish → gate[32,128,2048]
gate = F.silu(self.linear1(x)) # F.silu = Swish,gate = xW1 × Sigmoid(xW1)
# 2️⃣ 生成值信号
# 数据流动:x[32,128,512] → linear_v → value[32,128,2048]
value = self.linear_v(x) # value = xW3,包含实际要传递的信息
# 3️⃣ 门控机制:逐元素相乘
# 数据流动:gate[32,128,2048] ⊗ value[32,128,2048] → h[32,128,2048]
h = gate * value # 门控信号控制值信号的通过比例(逐元素)
# 4️⃣ Dropout 🎲
# 数据流动:h[32,128,2048] → Dropout → h[32,128,2048](部分变为0)
h = self.dropout(h)
# 5️⃣ 输出投影,压缩回原始维度
# 数据流动:h[32,128,2048] → linear2 → output[32,128,512]
output = self.linear2(h) # output = hW2,映射回 d_model 维度
return output
3.3 SwiGLU 代码逐行解析 🔍
第1步:生成门控信号 1️⃣
python
gate = F.silu(self.linear1(x)) # 门控信号 + Swish 激活,数据流动:x[32,128,512] → gate[32,128,2048]
这一步完成两个操作:📝
-
线性变换 : gate′=xW1+b1
- 输入形状: 32,128,512
- 输出形状: 32,128,2048
-
Swish 激活 : gate=Swish(gate′)=gate′×Sigmoid(gate′)
- Swish 是"自门控"激活函数,输出范围: (−∞,+∞)
- 在正值区域接近线性,在负值区域平滑趋近于 0
什么是 Swish 激活函数? 🤔
Swish 函数定义为:📐
Swish(x)=x×Sigmoid(βx)
通常 β=1,简化为:📝
Swish(x)=1+e−xx
Swish vs ReLU 对比 📊:
| 特性 | ReLU | Swish |
|---|---|---|
| 公式 | max(0,x) | x×Sigmoid(x) |
| 负值区域 | 严格为 0 🚫 | 平滑负值(有微小梯度) 🌊 |
| 导数 | 不连续(在 0 处) | 处处连续可导 ✅ |
| 上界 | 无上界 | 无上界 |
直观理解 💡:Swish 像是"带自我调节的 ReLU"
- 当 x 很大时: Sigmoid(x)≈1,所以 Swish(x)≈x(类似 ReLU)
- 当 x 很小时: Sigmoid(x)≈0,所以 Swish(x)≈0(但仍有微小梯度)
- 当 x≈0 时:平滑过渡,不像 ReLU 那样突然截断
第2步:生成值信号 2️⃣
python
value = self.linear_v(x) # 值信号,数据流动:x[32,128,512] → value[32,128,2048]
这一步纯粹是线性变换,没有激活函数:📝
- 输入形状: 32,128,512
- 权重 W3 形状: 512,2048
- 输出形状: 32,128,2048
💡 关键理解:值信号保持"原始"状态,让门控信号来决定如何过滤和调制。
第3步:门控机制(逐元素相乘) 3️⃣
python
h = gate * value # 门控 × 值,数据流动:gate[32,128,2048] ⊗ value[32,128,2048] → h[32,128,2048]
这是 SwiGLU 的核心操作------逐元素乘法(Element-wise Multiplication):📝
gate形状: 32,128,2048value形状: 32,128,2048- 输出
h形状: 32,128,2048
逐元素乘法意味着:🔍
hi,j,k=gatei,j,k×valuei,j,k
对于每个 batch、每个序列位置、每个隐藏维度,门控值和值值相乘。
直观类比 🎨:想象调色板混合颜料
gate是"透明度控制"(0 = 完全透明,1 = 完全不透明)value是"颜料颜色"gate * value= 按透明度显示颜料
如果门控信号在某维度是 0.8,值信号是 5.0:📝
h=0.8×5.0=4.0
这意味着该维度的信息有 80% 被保留。✅
第4步:Dropout 4️⃣
python
h = self.dropout(h) # 应用 Dropout,数据流动:h[32,128,2048] → h[32,128,2048](部分变为0)
与标准 FFN 相同,Dropout 用于防止过拟合。🎲
第5步:输出投影 5️⃣
python
output = self.linear2(h) # 输出投影,数据流动:h[32,128,2048] → output[32,128,512]
将门控后的高维表示压缩回原始维度:📝
- 输入形状: 32,128,2048
- 权重 W2 形状: 2048,512
- 输出形状: 32,128,512
参考资料:
- SwiGLU: GLU Variants Improve Transformer (2020) -- naokishibuya.github.io ⭐值得阅读
- SwiGLU: Why Modern LLMs Ditch GELU/ReLU -- YouTube
- SwiGLU: The FFN Upgrade I Use to Get Free Performance -- dev.to
- 斯坦福CS336作业一实现 -- GitHub ⭐值得阅读
- CS336 Assignment 1 Basics -- stanford-cs336 GitHub
4. 完整可运行示例 🎯
本章提供一个从头到尾可运行的完整代码,包含标准 FFN 和 SwiGLU 的测试
python
import torch # 导入 PyTorch 核心库 🔥
import torch.nn as nn # 导入神经网络模块 🧠
import torch.nn.functional as F # 导入函数式 API 模块 ⚙️
"""标准位置级前馈网络(Position-wise FFN)实现 💻
参数:
d_model: 输入/输出维度(词嵌入维度)
d_ff: 隐藏层维度(通常是 d_model 的 4 倍)
dropout: Dropout概率(默认0.1)
示例:
ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1)
"""
class FeedForward(nn.Module):
"""初始化 FFN 的两层全连接网络 🛠️"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super(FeedForward, self).__init__() # 调用父类初始化 🏗️
self.linear1 = nn.Linear(d_model, d_ff) # 第一层:d_model → d_ff 📈
self.linear2 = nn.Linear(d_ff, d_model) # 第二层:d_ff → d_model 📉
self.relu = nn.ReLU() # ReLU 激活函数 🌀
self.dropout = nn.Dropout(dropout) # Dropout 层 🎲
"""前向传播计算 FFN ⚡"""
def forward(self, x: torch.Tensor):
# 1️⃣ 第一层线性变换 + ReLU 激活
# 数据流动:x[32,128,512] → h[32,128,2048](负值变0)
h = self.relu(self.linear1(x)) # h = max(0, xW1 + b1)
# 2️⃣ Dropout 🎲
h = self.dropout(h) # 训练时随机丢弃部分权重
# 3️⃣ 第二层线性变换
# 数据流动:h[32,128,2048] → output[32,128,512]
output = self.linear2(h) # output = hW2 + b2
return output
"""SwiGLU 门控前馈网络实现(LLaMA 等现代 LLM 使用) 💻
参数:
d_model: 输入/输出维度(词嵌入维度)
d_ff: 隐藏层维度(通常是 d_model 的 4 倍)
dropout: Dropout概率(默认0.1)
示例:
swiglu = SwiGLU(d_model=512, d_ff=2048, dropout=0.1)
"""
class SwiGLU(nn.Module):
"""初始化 SwiGLU 的三个线性层 🛠️"""
def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
super(SwiGLU, self).__init__() # 调用父类初始化 🏗️
self.linear1 = nn.Linear(d_model, d_ff) # 门控信号线性层 🚪
self.linear2 = nn.Linear(d_ff, d_model) # 输出投影层 📉
self.linear_v = nn.Linear(d_model, d_ff) # 值信号线性层 💎
self.dropout = nn.Dropout(dropout) # Dropout 层 🎲
"""前向传播计算 SwiGLU FFN ⚡"""
def forward(self, x: torch.Tensor):
# 1️⃣ 生成门控信号 + Swish 激活
# 数据流动:x[32,128,512] → gate[32,128,2048]
gate = F.silu(self.linear1(x)) # gate = Swish(xW1)
# 2️⃣ 生成值信号
# 数据流动:x[32,128,512] → value[32,128,2048]
value = self.linear_v(x) # value = xW3
# 3️⃣ 门控机制:逐元素相乘
# 数据流动:gate ⊗ value → h[32,128,2048]
h = gate * value # 门控控制值的通过比例
# 4️⃣ Dropout 🎲
h = self.dropout(h)
# 5️⃣ 输出投影
# 数据流动:h[32,128,2048] → output[32,128,512]
output = self.linear2(h) # output = hW2
return output
"""测试标准 FFN 模块 🧪
参数:
无
返回:
output: FFN 输出 [batch_size, seq_len, d_model]
示例:
output = test_standard_ffn()
"""
def test_standard_ffn():
# 设置随机种子,保证结果可复现 🎯
torch.manual_seed(42)
# 参数设置 ⚙️
batch_size = 2 # batch 大小
seq_len = 10 # 序列长度
d_model = 512 # 模型维度
d_ff = 2048 # FFN 隐藏层维度(4 倍)
# 随机生成输入张量 🎲
x = torch.randn(batch_size, seq_len, d_model) # x: [2, 10, 512]
# 创建 FFN 模块 🛠️
ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=0.0) # dropout=0.0 便于调试
# 前向传播 ⚡
output = ffn(x) # output: [2, 10, 512]
print("=" * 60) # 打印分隔线 📏
print("标准 FFN 测试") # 打印测试标题 📝
print("=" * 60) # 打印分隔线 📏
print(f"输入形状: {x.shape}") # 打印输入形状
print(f"输出形状: {output.shape}") # 打印输出形状
print(f"输入维度 == 输出维度: {x.shape == output.shape}") # 验证维度一致性 ✅
print("=" * 60) # 打印分隔线 📏
return output
"""测试 SwiGLU FFN 模块 🧪
参数:
无
返回:
output: SwiGLU 输出 [batch_size, seq_len, d_model]
示例:
output = test_swiglu_ffn()
"""
def test_swiglu_ffn():
# 设置随机种子,保证结果可复现 🎯
torch.manual_seed(42)
# 参数设置 ⚙️
batch_size = 2 # batch 大小
seq_len = 10 # 序列长度
d_model = 512 # 模型维度
d_ff = 2048 # FFN 隐藏层维度(4 倍)
# 随机生成输入张量 🎲
x = torch.randn(batch_size, seq_len, d_model) # x: [2, 10, 512]
# 创建 SwiGLU 模块 🛠️
swiglu = SwiGLU(d_model=d_model, d_ff=d_ff, dropout=0.0) # dropout=0.0 便于调试
# 前向传播 ⚡
output = swiglu(x) # output: [2, 10, 512]
print("=" * 60) # 打印分隔线 📏
print("SwiGLU FFN 测试") # 打印测试标题 📝
print("=" * 60) # 打印分隔线 📏
print(f"输入形状: {x.shape}") # 打印输入形状
print(f"输出形状: {output.shape}") # 打印输出形状
print(f"输入维度 == 输出维度: {x.shape == output.shape}") # 验证维度一致性 ✅
print("=" * 60) # 打印分隔线 📏
return output
"""对比标准 FFN 和 SwiGLU 的参数量 📊
参数:
无
返回:
无(打印参数量对比)
"""
def compare_parameters():
# 参数设置 ⚙️
d_model = 512 # 模型维度
d_ff = 2048 # FFN 隐藏层维度
# 创建两个模块 🛠️
ffn = FeedForward(d_model=d_model, d_ff=d_ff) # 标准 FFN
swiglu = SwiGLU(d_model=d_model, d_ff=d_ff) # SwiGLU FFN
# 计算参数量 🔢
ffn_params = sum(p.numel() for p in ffn.parameters()) # 标准 FFN 总参数量
swiglu_params = sum(p.numel() for p in swiglu.parameters()) # SwiGLU 总参数量
print("=" * 60) # 打印分隔线 📏
print("参数量对比") # 打印标题 📝
print("=" * 60) # 打印分隔线 📏
print(f"标准 FFN 参数量: {ffn_params:,}") # 打印标准 FFN 参数量
print(f"SwiGLU 参数量: {swiglu_params:,}") # 打印 SwiGLU 参数量
print(f"SwiGLU / 标准 FFN: {swiglu_params / ffn_params:.2f}x") # 打印倍数关系
print("=" * 60) # 打印分隔线 📏
if __name__ == "__main__":
# 🚀 运行标准 FFN 测试
output_ffn = test_standard_ffn()
# 🚀 运行 SwiGLU FFN 测试
output_swiglu = test_swiglu_ffn()
# 📊 对比参数量
compare_parameters()
4.1 运行结果示例
markdown
============================================================
标准 FFN 测试
============================================================
输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])
输入维度 == 输出维度: True
============================================================
============================================================
SwiGLU FFN 测试
============================================================
输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])
输入维度 == 输出维度: True
============================================================
============================================================
参数量对比
============================================================
标准 FFN 参数量: 2,099,200
SwiGLU 参数量: 3,148,800
SwiGLU / 标准 FFN: 1.50x
============================================================
可以看到:👀
- ✅ 标准 FFN 和 SwiGLU 的输入输出维度保持一致(都是 batch,seq_len,dmodel)
- ✅ SwiGLU 的参数量是标准 FFN 的 1.5 倍(因为有 3 个线性层而不是 2 个)
参数量计算详解 🔢:
标准 FFN 参数量:📐
ParamsFFN=(dmodel×dff+dff)+(dff×dmodel+dmodel)
=(512×2048+2048)+(2048×512+512)
=1,050,624+1,048,576=2,099,200
SwiGLU 参数量:📐
ParamsSwiGLU=(dmodel×dff+dff)+(dff×dmodel+dmodel)+(dmodel×dff+dff)
=1,050,624+1,048,576+1,050,624=3,149,824≈3,148,800
💡 SwiGLU 多了一个
linear_v层( dmodel→dff),所以参数量增加了约 50%。
5. 总结 📝
本节我们完成了斯坦福 CS336 作业一的前馈网络实现,核心要点回顾:🎯
| FFN 类型 | 线性层数量 | 激活函数 | 门控机制 | 参数量 | 代表模型 |
|---|---|---|---|---|---|
| 标准 FFN | 2 层 | ReLU | 无 | 基准(1×) | 原始 Transformer |
| GELU FFN | 2 层 | GELU | 无 | 基准(1×) | BERT、GPT |
| SwiGLU FFN | 3 层 | Swish | 有(逐元素乘) | 更多(1.5×) | LLaMA、PaLM |
🔴 关键理解:
-
FFN 的核心作用 💡
对每个 token 的表示进行独立的非线性变换,提升模型的表达能力。与自注意力机制(捕捉全局依赖)互补,FFN 专注于"局部信息加工"。
-
"扩展-收缩"结构 📐
FFN 先将维度扩大(通常 4 倍),在高维空间进行非线性变换,再压缩回原始维度。这种设计允许模型学习更复杂的特征表示。
-
SwiGLU 的优势 🌟
通过门控机制(门控信号 × 值信号)和 Swish 激活,SwiGLU 提供了比标准 FFN 更强大的信息过滤和调制能力,是现代 LLM 的标准选择。
-
维度一致性 ✅
所有 FFN 变体的输入和输出维度都保持为 dmodel,这使得它们可以无缝嵌入到 Transformer Block 中,与残差连接配合使用。
最后更新时间:2026-06-03