文章目录
GLU 变种:ReGLU 、 GEGLU 、 SwiGLU
-
在 GLU 的基础上,陆续提出了若干"激活 + GLU "的混合门控单元。它们共享同一套"双线形投影 + 逐元素门控"范式,差别仅在于把 GLU 中的 Sigmoid 门控替换为其他非线性函数,从而在参数量几乎不变的前提下带来不同的归纳偏差与性能收益。
-
参考论文:GLU Variants Improve Transformer
1. ReGLU(ReLU-GLU)
- 核心思想:把 Sigmoid 换成 ReLU,让门控也具备稀疏性,计算更便宜,且保留 GLU 的残差特性。
函数表达式
ReGLU ( x ) = ( x W + b ) ⊗ ReLU ( x V + c ) \text{ReGLU}(x) = (xW+b)\,\otimes\,\text{ReLU}(xV+c) ReGLU(x)=(xW+b)⊗ReLU(xV+c)
代码
-
代码
pythonimport torch from torch import nn class ReGLU(nn.Module): def __init__(self, d_in, d_out): super().__init__() self.w_gate = nn.Linear(d_in, d_out, bias=False) self.w_up = nn.Linear(d_in, d_out, bias=False) self.w_down = nn.Linear(d_out, d_in, bias=False) def forward(self, x): gate = F.relu(self.w_gate(x)) up = self.w_up(x) return self.w_down(gate * up)
2. GEGLU(Gaussian Error GLU)
- 核心思想:用 GELU 取代 Sigmoid,兼顾稀疏与平滑,兼顾 ReLU 的低计算与 Swish 的高表达。
函数表达式
GEGLU ( x ) = ( x W + b ) ⊗ GELU ( x V + c ) \text{GEGLU}(x) = (xW+b)\,\otimes\,\text{GELU}(xV+c) GEGLU(x)=(xW+b)⊗GELU(xV+c)
代码
-
代码
pythonimport torch from torch import nn class GEGLU(nn.Module): def __init__(self, d_in, d_out): super().__init__() self.w_gate = nn.Linear(d_in, d_out, bias=False) self.w_up = nn.Linear(d_in, d_out, bias=False) self.w_down = nn.Linear(d_out, d_in, bias=False) def forward(self, x): gate = F.gelu(self.w_gate(x)) up = self.w_up(x) return self.w_down(gate * up)
3. SwiGLU(Swish-GLU)
- 核心思想:将 Swish 引入门控;Swish 本身具备 可学习/常数 β,在深层网络中表现优于 ReLU/GELU。
函数表达式
SwiGLU ( x ) = ( x W + b ) ⊗ Swish β ( x V + c ) Swish β ( z ) = z ⋅ σ ( β z ) \text{SwiGLU}(x) = (xW+b)\,\otimes\,\text{Swish}\beta(xV+c) \\ \text{Swish}\beta(z)=z\cdot\sigma(\beta z) SwiGLU(x)=(xW+b)⊗Swishβ(xV+c)Swishβ(z)=z⋅σ(βz)
代码
-
固定swish函数中的参数 β = 1 \beta = 1 β=1 (SiLU)
pythonimport troch from torch import nn class SwiGLU(nn.Module): def __init__(self, d_in, d_out, beta=1.0): super().__init__() self.beta = beta self.w_gate = nn.Linear(d_in, d_out, bias=False) self.w_up = nn.Linear(d_in, d_out, bias=False) self.w_down = nn.Linear(d_out, d_in, bias=False) def forward(self, x): gate = self.w_gate(x) gate = gate * torch.sigmoid(self.beta * gate) # Swish up = self.w_up(x) return self.w_down(gate * up)
合并代码
-
torch封装
pythonimport torch from torch import nn class GLUVariants(nn.Module): def __init__(self, d_in, d_out, variant="geglu"): super().__init__() self.variant = variant.lower() self.w_gate = nn.Linear(d_in, d_out, bias=False) self.w_up = nn.Linear(d_in, d_out, bias=False) self.w_down = nn.Linear(d_out, d_in, bias=False) def forward(self, x): gate = self.w_gate(x) up = self.w_up(x) if self.variant == "reglu": gate = F.relu(gate) elif self.variant == "geglu": gate = F.gelu(gate) elif self.variant == "swiglu": gate = gate * torch.sigmoid(gate) # β=1 else: gate = torch.sigmoid(gate) # fallback to GLU return self.w_down(gate * up)
输出
pythontorch.Size([8, 64, 512])
-
对比
特性 GLU ReGLU GEGLU SwiGLU 门控激活 Sigmoid ReLU GELU Swish 稀疏门控 否 是 部分 平滑稀疏 计算量 中 低 中 中 梯度平滑性 中 差 好 最好 实际效果(大模型) 基线 接近 GLU 略优于 GLU 最佳 是否需额外参数 否 否 否 可选 β