The derivation is as follows.

Apply Householder reflections to vectors $a$ and $b$ using unit reflection vectors $u$ and $v$ respectively:

$$a' = H_u a = a - 2(a^\top u)\,u, \qquad b' = H_v b = b - 2(b^\top v)\,v.$$

The inner product of $a'$ and $b'$ is then

$$\langle a', b' \rangle = a^\top H_u H_v\, b = a^\top (I - 2uu^\top)(I - 2vv^\top)\, b.$$

The product $H_u H_v$ has a clear geometric meaning: it is a rotation within the 2-D plane spanned by $u$ and $v$, by twice the angle between $u$ and $v$, and it leaves the orthogonal complement of that plane unchanged.

To make this inner product depend only on the relative position, the reflection vectors must all lie in a single 2-D plane and sweep through it at a uniform rate with position. Let $m$ and $n$ be an orthonormal basis of that plane; the reflection vector at position $i$ is then

$$u_i = \cos(i\theta)\, m + \sin(i\theta)\, n,$$

where $m$ and $n$ are learnable parameters.
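As a quick sanity check of the claim above (a minimal sketch, not part of the implementation below; the helper `householder` and the chosen values of `d`, `theta`, `i`, `j` are illustrative), one can verify numerically that composing two such reflections rotates vectors inside span{m, n} by twice the angle between the reflection vectors:

```python
import math
import torch

def householder(u):
    # H_u = I - 2 u u^T for a unit vector u
    u = u / u.norm()
    return torch.eye(u.numel()) - 2.0 * torch.outer(u, u)

d, theta = 8, 0.3
# Random orthonormal pair m, n spanning a 2-D plane
m, n = torch.linalg.qr(torch.randn(d, 2)).Q.unbind(-1)

# Reflection vectors at two positions i and j, swept uniformly in span{m, n}
i, j = 5, 2
u_i = math.cos(i * theta) * m + math.sin(i * theta) * n
u_j = math.cos(j * theta) * m + math.sin(j * theta) * n

# The composition of the two reflections acts as a rotation inside span{m, n}
R = householder(u_i) @ householder(u_j)
rotation_angle = torch.acos(torch.clamp(m @ (R @ m), -1.0, 1.0)).item()
print(rotation_angle, 2 * (i - j) * theta)  # both are about 1.8
```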
The code implementation is as follows:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional
class HouseholderRotaryEmbedding(nn.Module):
"""
Householder推广的Rotary Position Embedding (RoPE)实现。
接口设计与标准RoPE保持一致,方便集成到现有Transformer中。
"""
def __init__(self, dim: int, base: float = 10000.0):
"""
初始化Householder RoPE。
参数:
dim: 每个注意力头的维度
base: 用于计算频率的基础值,默认10000
"""
super().__init__()
assert dim % 2 == 0, f"维度必须为偶数,当前维度: {dim}"
self.dim = dim
self.base = base
# 初始化可学习的正交基向量
self.m = nn.Parameter(torch.randn(dim))
self.n = nn.Parameter(torch.randn(dim))
# 预计算频率(与原始RoPE一致)
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
self.register_buffer("inv_freq", inv_freq, persistent=False)
    def _get_cos_sin(self,
                     seq_len: int,
                     device: torch.device,
                     dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the cos/sin tables used for the positional encoding.

        Returns:
            cos: [seq_len, dim] cosine values
            sin: [seq_len, dim] sine values
        """
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Expand to the full head dimension
        freqs = freqs.repeat_interleave(2, dim=-1)  # [seq_len, dim]
        cos = torch.cos(freqs).to(dtype)  # [seq_len, dim]
        sin = torch.sin(freqs).to(dtype)  # [seq_len, dim]
        return cos, sin
    def _orthogonalize_basis(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Gram-Schmidt orthogonalize and normalize the learnable basis vectors.

        Returns:
            m_unit: unit-norm m vector
            n_unit: unit-norm n vector, orthogonal to m
        """
        # Gram-Schmidt: make n orthogonal to m
        m_norm_sq = torch.dot(self.m, self.m).clamp(min=1e-10)
        proj_coeff = torch.dot(self.n, self.m) / m_norm_sq
        n_ortho = self.n - proj_coeff * self.m
        # Normalize
        m_unit = F.normalize(self.m, p=2, dim=0)
        n_unit = F.normalize(n_ortho, p=2, dim=0)
        return m_unit, n_unit
    def _get_reflection_vectors(self,
                                seq_len: int,
                                device: torch.device,
                                dtype: torch.dtype) -> torch.Tensor:
        """
        Build the position-dependent Householder reflection vectors.

        Returns:
            u: [seq_len, dim] reflection vectors
        """
        # Orthonormal basis and the cos/sin tables
        m_unit, n_unit = self._orthogonalize_basis()
        cos, sin = self._get_cos_sin(seq_len, device, dtype)
        # Reflection vectors: u_i = cos(i*theta)*m + sin(i*theta)*n, applied element-wise per frequency
        u = cos * m_unit + sin * n_unit  # [seq_len, dim]
        # Renormalize the reflection vectors for numerical stability
        u = F.normalize(u, p=2, dim=-1, eps=1e-6)
        return u
    def forward(self,
                x: torch.Tensor,
                seq_len: Optional[int] = None) -> torch.Tensor:
        """
        Apply the Householder RoPE transform to an input tensor.

        Args:
            x: input tensor of shape [batch_size, seq_len, n_head, head_dim]
            seq_len: sequence length; inferred from x if not given

        Returns:
            Transformed tensor with the same shape as the input.
        """
        batch_size, seq_len_x, n_head, head_dim = x.shape
        seq_len = seq_len or seq_len_x
        if head_dim != self.dim:
            raise ValueError(f"Input head_dim {head_dim} does not match configured dim {self.dim}")
        # Reflection vectors for each position, truncated to the actual sequence length
        u = self._get_reflection_vectors(seq_len, x.device, x.dtype)[:seq_len_x]  # [seq_len_x, dim]
        # Broadcast over batch and head dimensions: [1, seq_len_x, 1, dim]
        u = u.unsqueeze(0).unsqueeze(2)
        # Householder transform H(x) = x - 2(x·u)u, applied position-wise
        dot_product = (x * u).sum(dim=-1, keepdim=True)  # [batch, seq_len_x, n_head, 1]
        x_transformed = x - 2.0 * dot_product * u
        return x_transformed
    def apply_rotary_pos_emb(self,
                             q: torch.Tensor,
                             k: torch.Tensor,
                             seq_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Standard RoPE-style interface that transforms queries and keys together.

        Args:
            q: query tensor of shape [batch_size, seq_len, n_head, head_dim]
            k: key tensor of shape [batch_size, seq_len, n_head, head_dim]
            seq_len: sequence length

        Returns:
            q_rotated: transformed queries
            k_rotated: transformed keys
        """
        return self.forward(q, seq_len), self.forward(k, seq_len)
    def apply_rotary_pos_emb_single(self,
                                    x: torch.Tensor,
                                    seq_len: Optional[int] = None) -> torch.Tensor:
        """
        Standard RoPE-style interface for a single tensor.

        Args:
            x: input tensor
            seq_len: sequence length

        Returns:
            Transformed tensor.
        """
        return self.forward(x, seq_len)
# ==================== Compatibility wrapper ====================
class HouseholderRotary(nn.Module):
"""
完全兼容标准RoPE接口的包装器。
"""
def __init__(self, dim: int, base: float = 10000.0):
super().__init__()
self.rope = HouseholderRotaryEmbedding(dim, base)
def forward(self,
q: torch.Tensor,
k: torch.Tensor,
seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
"""
标准RoPE前向传播。
参数:
q: 查询张量
k: 键张量
seq_dim: 序列维度的索引
返回:
旋转后的查询和键
"""
# 确保输入形状正确 [batch, seq_len, heads, dim]
if q.dim() == 4:
q_rotated = self.rope(q)
k_rotated = self.rope(k)
return q_rotated, k_rotated
else:
# 处理其他形状(如[batch, heads, seq_len, dim])
raise NotImplementedError("目前仅支持 [batch, seq_len, heads, dim] 形状")
# ==================== Usage example ====================
if __name__ == "__main__":
    # 1. Basic usage
    print("=== Basic usage ===")
    dim = 128
    seq_len = 50
    batch_size = 2
    n_heads = 4

    # Positional-encoding layer
    rope = HouseholderRotaryEmbedding(dim)

    # Dummy queries and keys
    q = torch.randn(batch_size, seq_len, n_heads, dim)
    k = torch.randn(batch_size, seq_len, n_heads, dim)

    # Apply the positional encoding
    q_rotated, k_rotated = rope.apply_rotary_pos_emb(q, k)
    print(f"Query shape: {q.shape}")
    print(f"Rotated query shape: {q_rotated.shape}")
    print(f"Rotated key shape: {k_rotated.shape}")

    # 2. Integration with a standard attention module
    print("\n=== Attention-module integration ===")
    class MultiHeadAttentionWithHouseholderRoPE(nn.Module):
        """Multi-head attention example with Householder RoPE."""

        def __init__(self, embed_dim, num_heads, dropout=0.0):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads

            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(embed_dim, embed_dim)
            self.v_proj = nn.Linear(embed_dim, embed_dim)
            self.out_proj = nn.Linear(embed_dim, embed_dim)
            self.dropout = nn.Dropout(dropout)

            # Householder RoPE
            self.rotary_emb = HouseholderRotaryEmbedding(self.head_dim)

        def forward(self, x, attention_mask=None):
            batch_size, seq_len, _ = x.shape

            # Project queries, keys, values
            q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
            k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
            v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)

            # Apply Householder RoPE
            q, k = self.rotary_emb.apply_rotary_pos_emb(q, k)

            # Transpose for attention: [batch, heads, seq_len, head_dim]
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)

            # Scaled dot-product attention
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            if attention_mask is not None:
                attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf'))
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.dropout(attn_weights)

            # Attention output
            attn_output = torch.matmul(attn_weights, v)

            # Back to [batch, seq_len, heads, head_dim], then merge heads
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)

            # Output projection
            output = self.out_proj(attn_output)
            return output, attn_weights
    # Build and test the attention module
    embed_dim = 256
    num_heads = 8
    attn = MultiHeadAttentionWithHouseholderRoPE(embed_dim, num_heads)

    # Test input
    test_input = torch.randn(batch_size, seq_len, embed_dim)
    output, attn_weights = attn(test_input)
    print(f"Input shape: {test_input.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Attention-weight shape: {attn_weights.shape}")

    # 3. Orthogonality check
    print("\n=== Orthogonality check ===")
    with torch.no_grad():
        m_unit, n_unit = rope._orthogonalize_basis()
        dot_product = torch.dot(m_unit, n_unit).item()
        m_norm = torch.norm(m_unit).item()
        n_norm = torch.norm(n_unit).item()
        print(f"||m||: {m_norm:.6f} (should be close to 1.0)")
        print(f"||n||: {n_norm:.6f} (should be close to 1.0)")
        print(f"m·n: {dot_product:.6f} (should be close to 0.0)")
```