09ba-斯坦福CS336作业一-前馈网络

斯坦福CS336作业一-前馈网络 🎓

本文档基于斯坦福 CS336 作业一,从零实现 Transformer 的位置级前馈网络(Position-wise FFN),涵盖核心原理、标准 FFN 实现、SwiGLU 门控变体、代码逐行解析,以及完整可运行的综合示例 🛠️

章节阅读路线图 🗺️

  1. 前馈网络概述 📚 → 理解 FFN 的作用与核心思想
  2. 标准 FFN 实现 💻 → 从零编写两层全连接网络,逐行解析
  3. SwiGLU 门控变体 ⚡ → 学习现代 LLM 常用的高级激活函数
  4. 完整可运行示例 🎯 → 整合所有内容,提供完整测试脚本
  5. 总结 📝 → 回顾核心要点

1. 前馈网络概述 📚

本章介绍位置级前馈网络(Position-wise FFN)在 Transformer 中的作用和核心原理

1.1 什么是位置级前馈网络? 🤔

在 Transformer 架构中,每个 Transformer Block 包含两个核心子层:📝

  1. 多头自注意力机制(Multi-Head Self-Attention) → 捕捉序列中 token 之间的全局依赖关系
  2. 位置级前馈网络(Position-wise Feed-Forward Network, FFN) → 对每个 token 的表示进行独立的非线性变换

什么是"位置级"(Position-wise)? 🎯

"位置级"意味着 FFN 对序列中的每个位置独立应用相同的变换。举个例子:🌰

假设输入序列是 ["我", "喜欢", "深度", "学习"],FFN 会对每个词分别应用完全相同的两层全连接网络,就像给每个词单独做一次"信息加工"。

直观类比 🏭:想象 FFN 是一个"个人加工厂"

  • 自注意力机制像是"团队讨论"------所有词一起交流信息
  • FFN 像是"个人思考"------每个词拿到讨论结果后,独立进行深度加工
  • 每个词都经过相同的加工流程(相同的网络参数),但加工结果不同(因为输入不同)

1.2 FFN 的核心公式 📐

原始 Transformer 论文中,FFN 的数学表达式为:📝
FFN(x,W1,W2,b1,b2)=max⁡(0,xW1+b1)W2+b2 \text{FFN}(x, W_1, W_2, b_1, b_2) = \max(0, xW_1 + b_1)W_2 + b_2 FFN(x,W1,W2,b1,b2)=max(0,xW1+b1)W2+b2

这个公式可以拆解为三个步骤:🔍

  1. 第一层线性变换 h=xW1+b1 h = xW_1 + b_1 h=xW1+b1,将输入从 dmodel d_{\text{model}} dmodel 维映射到 dff d_{ff} dff 维(通常是 4 倍)
  2. ReLU 激活函数 max⁡(0,h)\max(0, h) max(0,h),引入非线性,过滤负值
  3. 第二层线性变换 output=hW2+b2output = hW_2 + b_2 output=hW2+b2,将维度从 dff d_{ff} dff 压缩回 dmodel d_{\text{model}} dmodel

为什么需要"扩展-收缩"(Expand-and-Contract)结构? 🤔

FFN 先将维度扩大(通常 4 倍),再压缩回来,这种设计有三个关键作用:🎯

  1. 提升表达能力 💪

    在高维空间中,模型能够学习更复杂的特征表示。就像把一张照片放大后,你能看到更多细节,处理完再缩小回去。

  2. 引入非线性 🌀

    ReLU 激活函数在两层线性变换之间引入非线性,使模型能够拟合复杂的函数关系。没有非线性,多层线性变换等价于单层。

  3. 信息解耦 🔓

    高维空间允许模型将不同特征解耦到不同的维度上,分别进行处理,最后再整合。

维度变化示例 📊:

假设 dmodel=512 d_{\text{model}} = 512 dmodel=512, dff =2048 d_{ff} = 2048 dff=2048(4 倍):

步骤 操作 维度变化
输入 自注意力输出 batch,seq_len,512\\text{batch}, \\text{seq\\_len}, 512 batch,seq_len,512
第一层线性 xW1+b1 xW_1 + b_1 xW1+b1 batch,seq_len,2048\\text{batch}, \\text{seq\\_len}, 2048 batch,seq_len,2048
ReLU 激活 max⁡(0,h)\max(0, h) max(0,h) batch,seq_len,2048\\text{batch}, \\text{seq\\_len}, 2048 batch,seq_len,2048
第二层线性 hW2+b2 hW_2 + b_2 hW2+b2 batch,seq_len,512\\text{batch}, \\text{seq\\_len}, 512 batch,seq_len,512
输出 FFN 最终输出 batch,seq_len,512\\text{batch}, \\text{seq\\_len}, 512 batch,seq_len,512

💡 关键理解 :FFN 的输入和输出维度相同(都是 dmodel d_{\text{model}} dmodel),这使得 FFN 可以无缝嵌入到 Transformer Block 中,与残差连接(Residual Connection)配合使用。


参考资料:


2. 标准 FFN 实现 💻

本章从零编写标准 FFN 的完整代码,基于 CS336 作业一的要求

2.1 完整代码实现 🧮

👇 下面是基于 PyTorch 的标准 FFN 实现(使用 ReLU 激活函数):

python 复制代码
import torch                                              # 导入 PyTorch 核心库,提供张量运算和神经网络模块 🔥
import torch.nn as nn                                     # 导入神经网络模块,包含 Linear、ReLU 等层 🧠


"""标准位置级前馈网络(Position-wise FFN)实现 💻

参数:
    d_model: 输入/输出维度(词嵌入维度),示例:512
    d_ff: 隐藏层维度(通常是 d_model 的 4 倍),示例:2048
    dropout: Dropout概率,用于防止过拟合(默认0.1)
    
示例:
    ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1)
"""
class FeedForward(nn.Module):
    """初始化 FFN 的两层全连接网络 🛠️
    
    参数:
        d_model: 输入/输出维度,示例:512
        d_ff: 隐藏层维度,示例:2048
        dropout: Dropout概率(默认0.1),示例:0.1 表示训练时随机丢弃 10% 的权重
        
    返回:
        无
        
    示例:
        self.ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1)
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super(FeedForward, self).__init__()               # 调用父类初始化,确保 nn.Module 正确设置 🏗️
        
        # 第一层线性变换:d_model → d_ff,示例:512 → 2048 📈
        self.linear1 = nn.Linear(d_model, d_ff)           # 权重 W1: [d_model, d_ff],偏置 b1: [d_ff]
        
        # 第二层线性变换:d_ff → d_model,示例:2048 → 512 📉
        self.linear2 = nn.Linear(d_ff, d_model)           # 权重 W2: [d_ff, d_model],偏置 b2: [d_model]
        
        # ReLU 激活函数:max(0, x),引入非线性 🌀
        self.relu = nn.ReLU()                             # 将负值置为 0,保留正值
        
        # Dropout 层:训练时随机丢弃部分神经元,防止过拟合 🎲
        self.dropout = nn.Dropout(dropout)                # dropout=0.1 表示丢弃 10% 的权重
    
    """前向传播计算 FFN ⚡
    
    参数:
        x: 输入张量 [batch_size=32, seq_len=128, d_model=512]
        
    返回:
        output: FFN 输出 [32, 128, 512]
        
    示例:
        output = ffn(x)
    """
    def forward(self, x: torch.Tensor):
        # 1️⃣ 第一层线性变换 + ReLU 激活
        # 数据流动:x[32,128,512] → linear1 → h[32,128,2048] → ReLU → h[32,128,2048](负值变0)
        h = self.relu(self.linear1(x))                    # h = max(0, xW1 + b1),扩展到高维空间
        
        # 2️⃣ Dropout(训练时随机丢弃部分权重) 🎲
        # 数据流动:h[32,128,2048] → Dropout → h[32,128,2048](部分变为0)
        h = self.dropout(h)
        
        # 3️⃣ 第二层线性变换,压缩回原始维度
        # 数据流动:h[32,128,2048] → linear2 → output[32,128,512]
        output = self.linear2(h)                          # output = hW2 + b2,压缩回 d_model 维度
        
        return output

2.2 代码逐行解析 🔍

本节详细拆解 FFN 的每一步计算过程

第1步:第一层线性变换 + ReLU 1️⃣

python 复制代码
h = self.relu(self.linear1(x))                            # 第一层线性变换 + ReLU,数据流动:x[32,128,512] → h[32,128,2048]

这一步完成了两个操作:📝

  1. 线性变换 h′=xW1+b1 h' = xW_1 + b_1 h′=xW1+b1

    • 输入 xx x 形状: batch,seq_len,dmodel\\text{batch}, \\text{seq\\_len}, d_{\\text{model}} batch,seq_len,dmodel = 32,128,51232, 128, 512 32,128,512
    • 权重 W1 W_1 W1 形状: dmodel, dff d_{\\text{model}}, d_{ff} dmodel,dff = 512,2048512, 2048 512,2048
    • 偏置 b1 b_1 b1 形状: dff d_{ff} dff = 20482048 2048
    • 输出 h′h' h′ 形状: batch,seq_len, dff \\text{batch}, \\text{seq\\_len}, d_{ff} batch,seq_len,dff = 32,128,204832, 128, 2048 32,128,2048
  2. ReLU 激活 h=max⁡(0,h′)h = \max(0, h') h=max(0,h′)

    • h′h' h′ 中所有负值置为 0
    • 保留正值不变
    • 引入非线性,增强模型表达能力

什么是 ReLU 激活函数? 🤔

ReLU(Rectified Linear Unit)是最常用的激活函数,定义为:📐
ReLU(x)=max⁡(0,x)\text{ReLU}(x) = \max(0, x) ReLU(x)=max(0,x)

直观理解 💡:ReLU 就像是一个"单向阀门"

  • 正值通过 ✅ → 保持原样
  • 负值阻止 🚫 → 变成 0

为什么要用 ReLU? 🎯

  1. 计算简单 ⚡:只需要比较大小,不需要指数运算(如 Sigmoid)
  2. 缓解梯度消失 📈:正值的梯度始终为 1,不会因为层数多而衰减
  3. 稀疏激活 🎭:负值被置为 0,使得网络在某些维度上"休眠",提高计算效率

ReLU 示例 🌰:

python 复制代码
import torch                                              # 导入 PyTorch 核心库 🔥

x = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0])            # 输入张量:包含正负值
relu = nn.ReLU()                                          # 创建 ReLU 激活函数
output = relu(x)                                          # 应用 ReLU

print(output)                                             # 输出:tensor([0., 0., 0., 1., 2.])

可以看到,负值都被置为 0,正值保持不变。✅


参考资料:

第2步:Dropout 2️⃣

python 复制代码
h = self.dropout(h)                                       # 应用 Dropout,数据流动:h[32,128,2048] → h[32,128,2048](部分变为0)

Dropout 在训练时随机将一部分神经元的输出置为 0,防止模型过度依赖某些特定的神经元,提升泛化能力。🎲

直观类比 🎲:想象 Dropout 是在"随机抽查考试"

  • 训练时:随机屏蔽 10% 的神经元(dropout=0.1),强迫模型学习更鲁棒的特征
  • 推理时:所有神经元都工作,但输出会乘以 (1−dropout)(1 - \text{dropout}) (1−dropout) 进行缩放(PyTorch 自动处理)

第3步:第二层线性变换 3️⃣

python 复制代码
output = self.linear2(h)                                  # 第二层线性变换,数据流动:h[32,128,2048] → output[32,128,512]

这一步将高维表示压缩回原始维度:📝

  • 输入 hh h 形状: batch,seq_len, dff \\text{batch}, \\text{seq\\_len}, d_{ff} batch,seq_len,dff = 32,128,204832, 128, 2048 32,128,2048
  • 权重 W2 W_2 W2 形状: dff ,dmodel d_{ff}, d_{\\text{model}} dff,dmodel = 2048,5122048, 512 2048,512
  • 偏置 b2 b_2 b2 形状: dmodel d_{\\text{model}} dmodel = 512512 512
  • 输出形状: batch,seq_len,dmodel\\text{batch}, \\text{seq\\_len}, d_{\\text{model}} batch,seq_len,dmodel = 32,128,51232, 128, 512 32,128,512

💡 关键理解 :第二层线性变换不使用激活函数,直接输出线性结果。这与第一层不同(第一层后有 ReLU)。


2.3 使用 GELU 激活函数的变体 🌟

现代 Transformer(如 BERT、GPT)通常使用 GELU(Gaussian Error Linear Unit)替代 ReLU,因为它提供了更平滑的梯度流动。

GELU 公式 📐:
GELU(x)=x×Φ(x)\text{GELU}(x) = x \times \Phi(x) GELU(x)=x×Φ(x)

其中 Φ(x)\Phi(x) Φ(x) 是标准正态分布的累积分布函数(CDF)。

直观理解 💡:GELU 像是"带概率的 ReLU"

  • 正值:接近原值(但略有缩放)
  • 负值:接近 0(但不是严格为 0,而是平滑过渡)
  • 相比 ReLU 的"硬截断",GELU 提供了"软截断"

GELU vs ReLU 对比 ⚔️:

特性 ReLU GELU
公式 max⁡(0,x)\max(0, x) max(0,x) x×Φ(x)x \times \Phi(x) x×Φ(x)
负值处理 严格置 0 🚫 平滑趋近 0 🌊
梯度 阶跃函数(不连续) 平滑连续 ✅
性能 基线水平 通常更好 📈
计算成本 极低 ⚡ 稍高(需近似)

GELU 实现代码 💻:

python 复制代码
import torch                                              # 导入 PyTorch 核心库 🔥
import torch.nn as nn                                     # 导入神经网络模块 🧠
import torch.nn.functional as F                           # 导入函数式 API 模块 ⚙️


"""使用 GELU 激活函数的 FFN 实现 💻

参数:
    d_model: 输入/输出维度(词嵌入维度),示例:512
    d_ff: 隐藏层维度(通常是 d_model 的 4 倍),示例:2048
    dropout: Dropout概率(默认0.1)
    
示例:
    ffn = FeedForwardGELU(d_model=512, d_ff=2048, dropout=0.1)
"""
class FeedForwardGELU(nn.Module):
    """初始化使用 GELU 的 FFN 🛠️
    
    参数:
        d_model: 输入/输出维度,示例:512
        d_ff: 隐藏层维度,示例:2048
        dropout: Dropout概率(默认0.1)
        
    返回:
        无
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super(FeedForwardGELU, self).__init__()           # 调用父类初始化 🏗️
        
        # 第一层线性变换:d_model → d_ff 📈
        self.linear1 = nn.Linear(d_model, d_ff)           # 权重 W1: [d_model, d_ff]
        
        # 第二层线性变换:d_ff → d_model 📉
        self.linear2 = nn.Linear(d_ff, d_model)           # 权重 W2: [d_ff, d_model]
        
        # Dropout 层 🎲
        self.dropout = nn.Dropout(dropout)                # 训练时随机丢弃部分权重
    
    """前向传播计算 FFN(使用 GELU 激活) ⚡
    
    参数:
        x: 输入张量 [batch_size, seq_len, d_model]
        
    返回:
        output: FFN 输出 [batch_size, seq_len, d_model]
    """
    def forward(self, x: torch.Tensor):
        # 1️⃣ 第一层线性变换 + GELU 激活
        # 数据流动:x[32,128,512] → linear1 → h[32,128,2048] → GELU → h[32,128,2048]
        h = F.gelu(self.linear1(x))                       # F.gelu 提供平滑的非线性激活 🌊
        
        # 2️⃣ Dropout 🎲
        # 数据流动:h[32,128,2048] → Dropout → h[32,128,2048](部分变为0)
        h = self.dropout(h)
        
        # 3️⃣ 第二层线性变换
        # 数据流动:h[32,128,2048] → linear2 → output[32,128,512]
        output = self.linear2(h)
        
        return output

💡 PyTorch 提供了 F.gelu() 函数,内部使用了高效的近似实现,无需手动编写 GELU 公式。


参考资料:


3. SwiGLU 门控变体 ⚡

本章介绍现代 LLM(如 LLaMA、PaLM)广泛使用的 SwiGLU 门控 FFN

3.1 什么是 SwiGLU? 🤔

SwiGLU(Swish-Gated Linear Unit)是 GLU(Gated Linear Unit)的一种变体,由 Shazeer(2020)提出,在现代大语言模型(如 LLaMA、PaLM)中被广泛采用。

SwiGLU 公式 📐:
SwiGLU(x)= (Swishβ(xW1+b1)) ⊗(xW3+b3)\text{SwiGLU}(x) = \left(\text{Swish}_\beta(xW_1 + b_1)\right) \otimes (xW_3 + b_3) SwiGLU(x)=(Swishβ(xW1+b1))⊗(xW3+b3)
output=SwiGLU(x)⋅W2+b2\text{output} = \text{SwiGLU}(x) \cdot W_2 + b_2 output=SwiGLU(x)⋅W2+b2

其中:📝

  • W1,b1 W_1, b_1 W1,b1:第一个线性变换的权重和偏置(用于门控信号)
  • W3,b3 W_3, b_3 W3,b3:第三个线性变换的权重和偏置(用于值信号)
  • W2,b2 W_2, b_2 W2,b2:第二个线性变换的权重和偏置(输出投影)
  • Swishβ(x)=x×Sigmoid(βx) \text{Swish}_\beta(x) = x \times \text{Sigmoid}(\beta x) Swishβ(x)=x×Sigmoid(βx),通常 β=1\beta = 1 β=1
  • ⊗\otimes ⊗:逐元素乘法(Hadamard 积)

直观理解 💡:SwiGLU 像是一个"智能门控系统"

  • xW1 xW_1 xW1 生成"门控信号"(决定哪些信息应该通过)🚪
  • xW3 xW_3 xW3 生成"值信号"(实际要传递的信息)💎
  • Swish(门控)⊗值\text{Swish}(\text{门控}) \otimes \text{值} Swish(门控)⊗值:门控信号控制值信号的通过比例
  • 最后通过 W2 W_2 W2 投影到输出空间

与标准 FFN 的对比 ⚔️:

特性 标准 FFN SwiGLU FFN
线性层数量 2 层(linear1, linear2) 3 层(linear1, linear2, linear_v)
激活函数 ReLU / GELU Swish(Sigmoid 加权)
门控机制 无 ✅ 有(值信号 × 门控信号) 🚪
参数量 基准(1×) 更多(约 1.5×) 📈
性能 基线水平 通常更好 🌟
代表模型 原始 Transformer LLaMA、PaLM

为什么 SwiGLU 性能更好? 🎯

  1. 动态门控 🚪

    标准 FFN 的 ReLU/GELU 是"静态"的------只看单个神经元的值就决定是否激活。SwiGLU 的门控是"动态"的------一个线性层专门学习"应该让哪些信息通过",另一个线性层学习"要传递什么信息"。

  2. 更丰富的交互 🔗

    逐元素乘法( ⊗\otimes ⊗)允许门控信号和值信号在每个维度上独立交互,而不是简单地通过激活函数过滤。

  3. Swish 的平滑性 🌊

    Swish 函数 x×Sigmoid(x)x \times \text{Sigmoid}(x) x×Sigmoid(x) 在负值区域也有小梯度(不像 ReLU 完全为 0),有助于梯度流动。

直观类比 🎭:想象信息通过一个"智能安检门"

  • 标准 FFN(ReLU):所有包裹(信息)统一检查,负值直接丢弃 🚫
  • SwiGLU:有两个检查员 👥
    • 检查员 A(门控信号):决定每个包裹应该放行多少
    • 检查员 B(值信号):准备包裹的实际内容
    • 最终放行量 = A 的决定 × B 的内容(逐元素相乘)

3.2 SwiGLU 完整代码实现 💻

👇 下面是基于 CS336 作业一的 SwiGLU 实现:

python 复制代码
import torch                                              # 导入 PyTorch 核心库 🔥
import torch.nn as nn                                     # 导入神经网络模块 🧠
import torch.nn.functional as F                           # 导入函数式 API 模块 ⚙️


"""SwiGLU 门控前馈网络实现(LLaMA 等现代 LLM 使用) 💻

参数:
    d_model: 输入/输出维度(词嵌入维度),示例:512
    d_ff: 隐藏层维度(通常是 d_model 的 4 倍),示例:2048
    dropout: Dropout概率(默认0.1)
    
示例:
    ffn = SwiGLU(d_model=512, d_ff=2048, dropout=0.1)
"""
class SwiGLU(nn.Module):
    """初始化 SwiGLU 的三个线性层 🛠️
    
    参数:
        d_model: 输入/输出维度,示例:512
        d_ff: 隐藏层维度,示例:2048
        dropout: Dropout概率(默认0.1)
        
    返回:
        无
        
    示例:
        self.swiglu = SwiGLU(d_model=512, d_ff=2048, dropout=0.1)
    """
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super(SwiGLU, self).__init__()                    # 调用父类初始化 🏗️
        
        # 第一个线性层:生成门控信号(用于 Swish 激活) 🚪
        self.linear1 = nn.Linear(d_model, d_ff)           # 权重 W1: [d_model, d_ff]
        
        # 第二个线性层:输出投影(将门控后的结果映射回 d_model) 📉
        self.linear2 = nn.Linear(d_ff, d_model)           # 权重 W2: [d_ff, d_model]
        
        # 第三个线性层:生成值信号(与门控信号逐元素相乘) 💎
        self.linear_v = nn.Linear(d_model, d_ff)          # 权重 W3: [d_model, d_ff]
        
        # Dropout 层 🎲
        self.dropout = nn.Dropout(dropout)                # 训练时随机丢弃部分权重
    
    """前向传播计算 SwiGLU FFN ⚡
    
    参数:
        x: 输入张量 [batch_size=32, seq_len=128, d_model=512]
        
    返回:
        output: FFN 输出 [32, 128, 512]
        
    示例:
        output = swiglu(x)
    """
    def forward(self, x: torch.Tensor):
        # 1️⃣ 生成门控信号 + Swish 激活
        # 数据流动:x[32,128,512] → linear1 → gate[32,128,2048] → Swish → gate[32,128,2048]
        gate = F.silu(self.linear1(x))                    # F.silu = Swish,gate = xW1 × Sigmoid(xW1)
        
        # 2️⃣ 生成值信号
        # 数据流动:x[32,128,512] → linear_v → value[32,128,2048]
        value = self.linear_v(x)                          # value = xW3,包含实际要传递的信息
        
        # 3️⃣ 门控机制:逐元素相乘
        # 数据流动:gate[32,128,2048] ⊗ value[32,128,2048] → h[32,128,2048]
        h = gate * value                                  # 门控信号控制值信号的通过比例(逐元素)
        
        # 4️⃣ Dropout 🎲
        # 数据流动:h[32,128,2048] → Dropout → h[32,128,2048](部分变为0)
        h = self.dropout(h)
        
        # 5️⃣ 输出投影,压缩回原始维度
        # 数据流动:h[32,128,2048] → linear2 → output[32,128,512]
        output = self.linear2(h)                          # output = hW2,映射回 d_model 维度
        
        return output

3.3 SwiGLU 代码逐行解析 🔍

第1步:生成门控信号 1️⃣

python 复制代码
gate = F.silu(self.linear1(x))                            # 门控信号 + Swish 激活,数据流动:x[32,128,512] → gate[32,128,2048]

这一步完成两个操作:📝

  1. 线性变换 gate′=xW1+b1gate' = xW_1 + b_1 gate′=xW1+b1

    • 输入形状: 32,128,51232, 128, 512 32,128,512
    • 输出形状: 32,128,204832, 128, 2048 32,128,2048
  2. Swish 激活 gate=Swish(gate′)=gate′×Sigmoid(gate′)gate = \text{Swish}(gate') = gate' \times \text{Sigmoid}(gate') gate=Swish(gate′)=gate′×Sigmoid(gate′)

    • Swish 是"自门控"激活函数,输出范围: (−∞,+∞)(-\infty, +\infty) (−∞,+∞)
    • 在正值区域接近线性,在负值区域平滑趋近于 0

什么是 Swish 激活函数? 🤔

Swish 函数定义为:📐
Swish(x)=x×Sigmoid(βx)\text{Swish}(x) = x \times \text{Sigmoid}(\beta x) Swish(x)=x×Sigmoid(βx)

通常 β=1\beta = 1 β=1,简化为:📝
Swish(x)= x1+e−x \text{Swish}(x) = \frac{x}{1 + e^{-x}} Swish(x)=1+e−xx

Swish vs ReLU 对比 📊:

特性 ReLU Swish
公式 max⁡(0,x)\max(0, x) max(0,x) x×Sigmoid(x)x \times \text{Sigmoid}(x) x×Sigmoid(x)
负值区域 严格为 0 🚫 平滑负值(有微小梯度) 🌊
导数 不连续(在 0 处) 处处连续可导 ✅
上界 无上界 无上界

直观理解 💡:Swish 像是"带自我调节的 ReLU"

  • xx x 很大时: Sigmoid(x)≈1\text{Sigmoid}(x) \approx 1 Sigmoid(x)≈1,所以 Swish(x)≈x\text{Swish}(x) \approx x Swish(x)≈x(类似 ReLU)
  • xx x 很小时: Sigmoid(x)≈0\text{Sigmoid}(x) \approx 0 Sigmoid(x)≈0,所以 Swish(x)≈0\text{Swish}(x) \approx 0 Swish(x)≈0(但仍有微小梯度)
  • x≈0x \approx 0 x≈0 时:平滑过渡,不像 ReLU 那样突然截断

第2步:生成值信号 2️⃣

python 复制代码
value = self.linear_v(x)                                  # 值信号,数据流动:x[32,128,512] → value[32,128,2048]

这一步纯粹是线性变换,没有激活函数:📝

  • 输入形状: 32,128,51232, 128, 512 32,128,512
  • 权重 W3 W_3 W3 形状: 512,2048512, 2048 512,2048
  • 输出形状: 32,128,204832, 128, 2048 32,128,2048

💡 关键理解:值信号保持"原始"状态,让门控信号来决定如何过滤和调制。

第3步:门控机制(逐元素相乘) 3️⃣

python 复制代码
h = gate * value                                          # 门控 × 值,数据流动:gate[32,128,2048] ⊗ value[32,128,2048] → h[32,128,2048]

这是 SwiGLU 的核心操作------逐元素乘法(Element-wise Multiplication):📝

  • gate 形状: 32,128,204832, 128, 2048 32,128,2048
  • value 形状: 32,128,204832, 128, 2048 32,128,2048
  • 输出 h 形状: 32,128,204832, 128, 2048 32,128,2048

逐元素乘法意味着:🔍
hi,j,k=gatei,j,k×valuei,j,khi, j, k = gatei, j, k \times valuei, j, k hi,j,k=gatei,j,k×valuei,j,k

对于每个 batch、每个序列位置、每个隐藏维度,门控值和值值相乘。

直观类比 🎨:想象调色板混合颜料

  • gate 是"透明度控制"(0 = 完全透明,1 = 完全不透明)
  • value 是"颜料颜色"
  • gate * value = 按透明度显示颜料

如果门控信号在某维度是 0.8,值信号是 5.0:📝
h=0.8×5.0=4.0h = 0.8 \times 5.0 = 4.0 h=0.8×5.0=4.0

这意味着该维度的信息有 80% 被保留。✅

第4步:Dropout 4️⃣

python 复制代码
h = self.dropout(h)                                       # 应用 Dropout,数据流动:h[32,128,2048] → h[32,128,2048](部分变为0)

与标准 FFN 相同,Dropout 用于防止过拟合。🎲

第5步:输出投影 5️⃣

python 复制代码
output = self.linear2(h)                                  # 输出投影,数据流动:h[32,128,2048] → output[32,128,512]

将门控后的高维表示压缩回原始维度:📝

  • 输入形状: 32,128,204832, 128, 2048 32,128,2048
  • 权重 W2 W_2 W2 形状: 2048,5122048, 512 2048,512
  • 输出形状: 32,128,51232, 128, 512 32,128,512

参考资料:


4. 完整可运行示例 🎯

本章提供一个从头到尾可运行的完整代码,包含标准 FFN 和 SwiGLU 的测试

python 复制代码
import torch                                              # 导入 PyTorch 核心库 🔥
import torch.nn as nn                                     # 导入神经网络模块 🧠
import torch.nn.functional as F                           # 导入函数式 API 模块 ⚙️


"""标准位置级前馈网络(Position-wise FFN)实现 💻

参数:
    d_model: 输入/输出维度(词嵌入维度)
    d_ff: 隐藏层维度(通常是 d_model 的 4 倍)
    dropout: Dropout概率(默认0.1)
    
示例:
    ffn = FeedForward(d_model=512, d_ff=2048, dropout=0.1)
"""
class FeedForward(nn.Module):
    """初始化 FFN 的两层全连接网络 🛠️"""
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super(FeedForward, self).__init__()               # 调用父类初始化 🏗️
        self.linear1 = nn.Linear(d_model, d_ff)           # 第一层:d_model → d_ff 📈
        self.linear2 = nn.Linear(d_ff, d_model)           # 第二层:d_ff → d_model 📉
        self.relu = nn.ReLU()                             # ReLU 激活函数 🌀
        self.dropout = nn.Dropout(dropout)                # Dropout 层 🎲
    
    """前向传播计算 FFN ⚡"""
    def forward(self, x: torch.Tensor):
        # 1️⃣ 第一层线性变换 + ReLU 激活
        # 数据流动:x[32,128,512] → h[32,128,2048](负值变0)
        h = self.relu(self.linear1(x))                    # h = max(0, xW1 + b1)
        
        # 2️⃣ Dropout 🎲
        h = self.dropout(h)                               # 训练时随机丢弃部分权重
        
        # 3️⃣ 第二层线性变换
        # 数据流动:h[32,128,2048] → output[32,128,512]
        output = self.linear2(h)                          # output = hW2 + b2
        
        return output


"""SwiGLU 门控前馈网络实现(LLaMA 等现代 LLM 使用) 💻

参数:
    d_model: 输入/输出维度(词嵌入维度)
    d_ff: 隐藏层维度(通常是 d_model 的 4 倍)
    dropout: Dropout概率(默认0.1)
    
示例:
    swiglu = SwiGLU(d_model=512, d_ff=2048, dropout=0.1)
"""
class SwiGLU(nn.Module):
    """初始化 SwiGLU 的三个线性层 🛠️"""
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super(SwiGLU, self).__init__()                    # 调用父类初始化 🏗️
        self.linear1 = nn.Linear(d_model, d_ff)           # 门控信号线性层 🚪
        self.linear2 = nn.Linear(d_ff, d_model)           # 输出投影层 📉
        self.linear_v = nn.Linear(d_model, d_ff)          # 值信号线性层 💎
        self.dropout = nn.Dropout(dropout)                # Dropout 层 🎲
    
    """前向传播计算 SwiGLU FFN ⚡"""
    def forward(self, x: torch.Tensor):
        # 1️⃣ 生成门控信号 + Swish 激活
        # 数据流动:x[32,128,512] → gate[32,128,2048]
        gate = F.silu(self.linear1(x))                    # gate = Swish(xW1)
        
        # 2️⃣ 生成值信号
        # 数据流动:x[32,128,512] → value[32,128,2048]
        value = self.linear_v(x)                          # value = xW3
        
        # 3️⃣ 门控机制:逐元素相乘
        # 数据流动:gate ⊗ value → h[32,128,2048]
        h = gate * value                                  # 门控控制值的通过比例
        
        # 4️⃣ Dropout 🎲
        h = self.dropout(h)
        
        # 5️⃣ 输出投影
        # 数据流动:h[32,128,2048] → output[32,128,512]
        output = self.linear2(h)                          # output = hW2
        
        return output


"""测试标准 FFN 模块 🧪

参数:
    无
    
返回:
    output: FFN 输出 [batch_size, seq_len, d_model]
    
示例:
    output = test_standard_ffn()
"""
def test_standard_ffn():
    # 设置随机种子,保证结果可复现 🎯
    torch.manual_seed(42)
    
    # 参数设置 ⚙️
    batch_size = 2                                        # batch 大小
    seq_len = 10                                          # 序列长度
    d_model = 512                                         # 模型维度
    d_ff = 2048                                           # FFN 隐藏层维度(4 倍)
    
    # 随机生成输入张量 🎲
    x = torch.randn(batch_size, seq_len, d_model)         # x: [2, 10, 512]
    
    # 创建 FFN 模块 🛠️
    ffn = FeedForward(d_model=d_model, d_ff=d_ff, dropout=0.0)  # dropout=0.0 便于调试
    
    # 前向传播 ⚡
    output = ffn(x)                                       # output: [2, 10, 512]
    
    print("=" * 60)                                       # 打印分隔线 📏
    print("标准 FFN 测试")                                 # 打印测试标题 📝
    print("=" * 60)                                       # 打印分隔线 📏
    print(f"输入形状: {x.shape}")                         # 打印输入形状
    print(f"输出形状: {output.shape}")                    # 打印输出形状
    print(f"输入维度 == 输出维度: {x.shape == output.shape}")  # 验证维度一致性 ✅
    print("=" * 60)                                       # 打印分隔线 📏
    
    return output


"""测试 SwiGLU FFN 模块 🧪

参数:
    无
    
返回:
    output: SwiGLU 输出 [batch_size, seq_len, d_model]
    
示例:
    output = test_swiglu_ffn()
"""
def test_swiglu_ffn():
    # 设置随机种子,保证结果可复现 🎯
    torch.manual_seed(42)
    
    # 参数设置 ⚙️
    batch_size = 2                                        # batch 大小
    seq_len = 10                                          # 序列长度
    d_model = 512                                         # 模型维度
    d_ff = 2048                                           # FFN 隐藏层维度(4 倍)
    
    # 随机生成输入张量 🎲
    x = torch.randn(batch_size, seq_len, d_model)         # x: [2, 10, 512]
    
    # 创建 SwiGLU 模块 🛠️
    swiglu = SwiGLU(d_model=d_model, d_ff=d_ff, dropout=0.0)  # dropout=0.0 便于调试
    
    # 前向传播 ⚡
    output = swiglu(x)                                    # output: [2, 10, 512]
    
    print("=" * 60)                                       # 打印分隔线 📏
    print("SwiGLU FFN 测试")                              # 打印测试标题 📝
    print("=" * 60)                                       # 打印分隔线 📏
    print(f"输入形状: {x.shape}")                         # 打印输入形状
    print(f"输出形状: {output.shape}")                    # 打印输出形状
    print(f"输入维度 == 输出维度: {x.shape == output.shape}")  # 验证维度一致性 ✅
    print("=" * 60)                                       # 打印分隔线 📏
    
    return output


"""对比标准 FFN 和 SwiGLU 的参数量 📊

参数:
    无
    
返回:
    无(打印参数量对比)
"""
def compare_parameters():
    # 参数设置 ⚙️
    d_model = 512                                         # 模型维度
    d_ff = 2048                                           # FFN 隐藏层维度
    
    # 创建两个模块 🛠️
    ffn = FeedForward(d_model=d_model, d_ff=d_ff)         # 标准 FFN
    swiglu = SwiGLU(d_model=d_model, d_ff=d_ff)           # SwiGLU FFN
    
    # 计算参数量 🔢
    ffn_params = sum(p.numel() for p in ffn.parameters())     # 标准 FFN 总参数量
    swiglu_params = sum(p.numel() for p in swiglu.parameters())  # SwiGLU 总参数量
    
    print("=" * 60)                                       # 打印分隔线 📏
    print("参数量对比")                                    # 打印标题 📝
    print("=" * 60)                                       # 打印分隔线 📏
    print(f"标准 FFN 参数量: {ffn_params:,}")            # 打印标准 FFN 参数量
    print(f"SwiGLU 参数量: {swiglu_params:,}")            # 打印 SwiGLU 参数量
    print(f"SwiGLU / 标准 FFN: {swiglu_params / ffn_params:.2f}x")  # 打印倍数关系
    print("=" * 60)                                       # 打印分隔线 📏


if __name__ == "__main__":
    # 🚀 运行标准 FFN 测试
    output_ffn = test_standard_ffn()
    
    # 🚀 运行 SwiGLU FFN 测试
    output_swiglu = test_swiglu_ffn()
    
    # 📊 对比参数量
    compare_parameters()

4.1 运行结果示例

markdown 复制代码
============================================================
标准 FFN 测试
============================================================
输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])
输入维度 == 输出维度: True
============================================================
============================================================
SwiGLU FFN 测试
============================================================
输入形状: torch.Size([2, 10, 512])
输出形状: torch.Size([2, 10, 512])
输入维度 == 输出维度: True
============================================================
============================================================
参数量对比
============================================================
标准 FFN 参数量: 2,099,200
SwiGLU 参数量: 3,148,800
SwiGLU / 标准 FFN: 1.50x
============================================================

可以看到:👀

  • ✅ 标准 FFN 和 SwiGLU 的输入输出维度保持一致(都是 batch,seq_len,dmodel\\text{batch}, \\text{seq\\_len}, d_{\\text{model}} batch,seq_len,dmodel
  • ✅ SwiGLU 的参数量是标准 FFN 的 1.5 倍(因为有 3 个线性层而不是 2 个)

参数量计算详解 🔢:

标准 FFN 参数量:📐
ParamsFFN=(dmodel× dff + dff )+( dff ×dmodel+dmodel) \text{Params}{\text{FFN}} = (d{\text{model}} \times d_{ff} + d_{ff}) + (d_{ff} \times d_{\text{model}} + d_{\text{model}}) ParamsFFN=(dmodel×dff+dff)+(dff×dmodel+dmodel)
=(512×2048+2048)+(2048×512+512)= (512 \times 2048 + 2048) + (2048 \times 512 + 512) =(512×2048+2048)+(2048×512+512)
=1,050,624+1,048,576=2,099,200= 1,050,624 + 1,048,576 = 2,099,200 =1,050,624+1,048,576=2,099,200

SwiGLU 参数量:📐
ParamsSwiGLU=(dmodel× dff + dff )+( dff ×dmodel+dmodel)+(dmodel× dff + dff ) \text{Params}{\text{SwiGLU}} = (d{\text{model}} \times d_{ff} + d_{ff}) + (d_{ff} \times d_{\text{model}} + d_{\text{model}}) + (d_{\text{model}} \times d_{ff} + d_{ff}) ParamsSwiGLU=(dmodel×dff+dff)+(dff×dmodel+dmodel)+(dmodel×dff+dff)
=1,050,624+1,048,576+1,050,624=3,149,824≈3,148,800= 1,050,624 + 1,048,576 + 1,050,624 = 3,149,824 \approx 3,148,800 =1,050,624+1,048,576+1,050,624=3,149,824≈3,148,800

💡 SwiGLU 多了一个 linear_v 层( dmodel→ dff d_{\text{model}} \rightarrow d_{ff} dmodel→dff),所以参数量增加了约 50%。


5. 总结 📝

本节我们完成了斯坦福 CS336 作业一的前馈网络实现,核心要点回顾:🎯

FFN 类型 线性层数量 激活函数 门控机制 参数量 代表模型
标准 FFN 2 层 ReLU 基准(1×) 原始 Transformer
GELU FFN 2 层 GELU 基准(1×) BERT、GPT
SwiGLU FFN 3 层 Swish 有(逐元素乘) 更多(1.5×) LLaMA、PaLM

🔴 关键理解

  1. FFN 的核心作用 💡

    对每个 token 的表示进行独立的非线性变换,提升模型的表达能力。与自注意力机制(捕捉全局依赖)互补,FFN 专注于"局部信息加工"。

  2. "扩展-收缩"结构 📐

    FFN 先将维度扩大(通常 4 倍),在高维空间进行非线性变换,再压缩回原始维度。这种设计允许模型学习更复杂的特征表示。

  3. SwiGLU 的优势 🌟

    通过门控机制(门控信号 × 值信号)和 Swish 激活,SwiGLU 提供了比标准 FFN 更强大的信息过滤和调制能力,是现代 LLM 的标准选择。

  4. 维度一致性

    所有 FFN 变体的输入和输出维度都保持为 dmodel d_{\text{model}} dmodel,这使得它们可以无缝嵌入到 Transformer Block 中,与残差连接配合使用。


最后更新时间:2026-06-03

相关推荐
武子康6 小时前
调查研究-175 Supermemory:AI 时代的 Memory API,不只是另一个向量数据库
人工智能·openai
寒山李白7 小时前
人工智能训练师报考指南
人工智能·ai·证书·职称·训练师
努力努力再努力FFF7 小时前
大学四年AI能力规划:从入门学习到简历表达
人工智能·学习
Litluecat7 小时前
配合多角色提示语3,学习AI漫剧(刚开始学)
人工智能·学习·ai·提示词·短剧·漫剧
xixingzhe27 小时前
AI开发工具-大需求
人工智能
沪漂阿龙7 小时前
create_agent:LangChain 新版 Agent 的核心入口
人工智能·架构·langchain
茉莉玫瑰花茶7 小时前
综合案例 - AI 智能租房助手 [ 5 ]
服务器·数据库·人工智能·python·ai
文艺倾年7 小时前
【强化学习】强化学习基本概念,20W字总结(一)
人工智能·python·语言模型·自然语言处理·面试·职场和发展·大模型
FserSuN7 小时前
压缩在智能中的作用
人工智能