使用householder反射推广ROPE相对位置编码

推导如下:

对向量a和b,分别使用u和v进行householder反射,得到

则a'和b'的向量内积为

有明确的几何意义,为在u和v所张成的二维平面上旋转,旋转角度为u和v的夹角。

为了保证的值只与相对位置有关,每个反射向量必按照某个二维平面均匀分布。假设该二维平面的单位正交基为m和n。则。其中m和n是可学习参数。

代码实现如下:

python 复制代码
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Tuple, Optional

class HouseholderRotaryEmbedding(nn.Module):
    """
    Householder推广的Rotary Position Embedding (RoPE)实现。
    接口设计与标准RoPE保持一致,方便集成到现有Transformer中。
    """
    def __init__(self, dim: int, base: float = 10000.0):
        """
        初始化Householder RoPE。
        
        参数:
            dim: 每个注意力头的维度
            base: 用于计算频率的基础值,默认10000
        """
        super().__init__()
        assert dim % 2 == 0, f"维度必须为偶数,当前维度: {dim}"
        
        self.dim = dim
        self.base = base
        
        # 初始化可学习的正交基向量
        self.m = nn.Parameter(torch.randn(dim))
        self.n = nn.Parameter(torch.randn(dim))
        
        # 预计算频率(与原始RoPE一致)
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        
    def _get_cos_sin(self, 
                     seq_len: int, 
                     device: torch.device, 
                     dtype: torch.dtype) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        获取用于位置编码的cos和sin值。
        
        返回:
            cos: [seq_len, dim] 余弦值
            sin: [seq_len, dim] 正弦值
        """
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        
        # 扩展为完整维度
        freqs = freqs.repeat_interleave(2, dim=-1)  # [seq_len, dim]
        
        cos = torch.cos(freqs).to(dtype)  # [seq_len, dim]
        sin = torch.sin(freqs).to(dtype)  # [seq_len, dim]
        
        return cos, sin
    
    def _orthogonalize_basis(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        对可学习的基向量进行施密特正交化并归一化。
        
        返回:
            m_unit: 正交归一化的m向量
            n_unit: 正交归一化的n向量
        """
        # 施密特正交化:使n与m正交
        m_norm_sq = torch.dot(self.m, self.m).clamp(min=1e-10)
        proj_coeff = torch.dot(self.n, self.m) / m_norm_sq
        n_ortho = self.n - proj_coeff * self.m
        
        # 归一化
        m_unit = F.normalize(self.m, p=2, dim=0)
        n_unit = F.normalize(n_ortho, p=2, dim=0)
        
        return m_unit, n_unit
    
    def _get_reflection_vectors(self, 
                               seq_len: int, 
                               device: torch.device, 
                               dtype: torch.dtype) -> torch.Tensor:
        """
        生成Householder反射向量。
        
        返回:
            u: [seq_len, dim] 反射向量
        """
        # 获取正交基和三角函数值
        m_unit, n_unit = self._orthogonalize_basis()
        cos, sin = self._get_cos_sin(seq_len, device, dtype)
        
        # 生成反射向量:u_i = cos(iθ)·m + sin(iθ)·n
        u = cos * m_unit + sin * n_unit  # [seq_len, dim]
        
        # 归一化反射向量以确保数值稳定性
        u = F.normalize(u, p=2, dim=-1, eps=1e-6)
        
        return u
    
    def forward(self, 
                x: torch.Tensor, 
                seq_len: Optional[int] = None) -> torch.Tensor:
        """
        对输入张量应用Householder RoPE变换。
        
        参数:
            x: 输入张量,形状为 [batch_size, seq_len, n_head, head_dim]
                或 [batch_size, n_head, seq_len, head_dim]
            seq_len: 序列长度,如果不提供则从x的形状推断
        
        返回:
            变换后的张量,形状与输入相同
        """
        batch_size, seq_len_x, n_head, head_dim = x.shape
        seq_len = seq_len or seq_len_x
        
        if head_dim != self.dim:
            raise ValueError(f"输入维度{head_dim}与初始化维度{self.dim}不匹配")
        
        # 生成反射向量
        u = self._get_reflection_vectors(seq_len, x.device, x.dtype)  # [seq_len, dim]
        
        # 重塑输入以应用变换
        # 转换为 [batch_size * n_head, seq_len, head_dim]
        x_reshaped = x.reshape(-1, seq_len, head_dim)
        
        # 应用Householder变换: H(x) = x - 2(x·u)u
        # 计算点积: [batch_size * n_head, seq_len]
        dot_product = torch.einsum('bsd,sd->bs', x_reshaped, u)
        
        # 应用变换
        x_transformed = x_reshaped - 2.0 * dot_product.unsqueeze(-1) * u
        
        # 重塑回原始形状
        x_transformed = x_transformed.reshape(batch_size, seq_len_x, n_head, head_dim)
        
        return x_transformed
    
    def apply_rotary_pos_emb(self, 
                            q: torch.Tensor, 
                            k: torch.Tensor, 
                            seq_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        同时应用于查询和键的标准RoPE接口。
        
        参数:
            q: 查询张量,形状为 [batch_size, seq_len, n_head, head_dim]
            k: 键张量,形状为 [batch_size, seq_len, n_head, head_dim]
            seq_len: 序列长度
        
        返回:
            q_rotated: 旋转后的查询
            k_rotated: 旋转后的键
        """
        return self.forward(q, seq_len), self.forward(k, seq_len)
    
    def apply_rotary_pos_emb_single(self, 
                                   x: torch.Tensor, 
                                   seq_len: Optional[int] = None) -> torch.Tensor:
        """
        应用于单个张量的标准RoPE接口。
        
        参数:
            x: 输入张量
            seq_len: 序列长度
        
        返回:
            旋转后的张量
        """
        return self.forward(x, seq_len)


# ==================== 兼容性包装器 ====================

class HouseholderRotary(nn.Module):
    """
    完全兼容标准RoPE接口的包装器。
    """
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        self.rope = HouseholderRotaryEmbedding(dim, base)
        
    def forward(self, 
                q: torch.Tensor, 
                k: torch.Tensor, 
                seq_dim: int = -2) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        标准RoPE前向传播。
        
        参数:
            q: 查询张量
            k: 键张量
            seq_dim: 序列维度的索引
        
        返回:
            旋转后的查询和键
        """
        # 确保输入形状正确 [batch, seq_len, heads, dim]
        if q.dim() == 4:
            q_rotated = self.rope(q)
            k_rotated = self.rope(k)
            return q_rotated, k_rotated
        else:
            # 处理其他形状(如[batch, heads, seq_len, dim])
            raise NotImplementedError("目前仅支持 [batch, seq_len, heads, dim] 形状")


# ==================== 使用示例 ====================

if __name__ == "__main__":
    # 1. 基本使用示例
    print("=== 基本使用示例 ===")
    
    dim = 128
    seq_len = 50
    batch_size = 2
    n_heads = 4
    
    # 创建位置编码层
    rope = HouseholderRotaryEmbedding(dim)
    
    # 创建模拟的查询和键
    q = torch.randn(batch_size, seq_len, n_heads, dim)
    k = torch.randn(batch_size, seq_len, n_heads, dim)
    
    # 应用位置编码
    q_rotated, k_rotated = rope.apply_rotary_pos_emb(q, k)
    
    print(f"原始查询形状: {q.shape}")
    print(f"旋转后查询形状: {q_rotated.shape}")
    print(f"旋转后键形状: {k_rotated.shape}")
    
    # 2. 与标准注意力模块集成示例
    print("\n=== 与注意力模块集成示例 ===")
    
    class MultiHeadAttentionWithHouseholderRoPE(nn.Module):
        """集成Householder RoPE的多头注意力示例"""
        def __init__(self, embed_dim, num_heads, dropout=0.0):
            super().__init__()
            self.embed_dim = embed_dim
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            
            self.q_proj = nn.Linear(embed_dim, embed_dim)
            self.k_proj = nn.Linear(embed_dim, embed_dim)
            self.v_proj = nn.Linear(embed_dim, embed_dim)
            self.out_proj = nn.Linear(embed_dim, embed_dim)
            
            self.dropout = nn.Dropout(dropout)
            
            # Householder RoPE
            self.rotary_emb = HouseholderRotaryEmbedding(self.head_dim)
            
        def forward(self, x, attention_mask=None):
            batch_size, seq_len, _ = x.shape
            
            # 投影查询、键、值
            q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
            k = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
            v = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
            
            # 应用Householder RoPE
            q, k = self.rotary_emb.apply_rotary_pos_emb(q, k)
            
            # 转置以进行注意力计算 [batch, heads, seq_len, head_dim]
            q = q.transpose(1, 2)
            k = k.transpose(1, 2)
            v = v.transpose(1, 2)
            
            # 缩放点积注意力
            attn_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
            
            if attention_mask is not None:
                attn_scores = attn_scores.masked_fill(attention_mask == 0, float('-inf'))
            
            attn_weights = F.softmax(attn_scores, dim=-1)
            attn_weights = self.dropout(attn_weights)
            
            # 注意力输出
            attn_output = torch.matmul(attn_weights, v)
            
            # 转置回 [batch, seq_len, heads, head_dim] 并重塑
            attn_output = attn_output.transpose(1, 2).contiguous()
            attn_output = attn_output.view(batch_size, seq_len, self.embed_dim)
            
            # 最终投影
            output = self.out_proj(attn_output)
            
            return output, attn_weights
    
    # 创建并测试注意力模块
    embed_dim = 256
    num_heads = 8
    
    attn = MultiHeadAttentionWithHouseholderRoPE(embed_dim, num_heads)
    
    # 测试输入
    test_input = torch.randn(batch_size, seq_len, embed_dim)
    output, attn_weights = attn(test_input)
    
    print(f"输入形状: {test_input.shape}")
    print(f"输出形状: {output.shape}")
    print(f"注意力权重形状: {attn_weights.shape}")
    
    # 3. 检查正交性
    print("\n=== 正交性检查 ===")
    with torch.no_grad():
        m_unit, n_unit = rope._orthogonalize_basis()
        dot_product = torch.dot(m_unit, n_unit).item()
        m_norm = torch.norm(m_unit).item()
        n_norm = torch.norm(n_unit).item()
        
        print(f"m 范数: {m_norm:.6f} (应接近1.0)")
        print(f"n 范数: {n_norm:.6f} (应接近1.0)")
        print(f"m·n 点积: {dot_product:.6f} (应接近0.0)")
相关推荐
FserSuN1 小时前
Agent开发总结学习
人工智能·学习
LCG米1 小时前
从训练到部署:基于PyTorch与TensorFlow Lite的端侧AI花卉分类系统完整指南
人工智能·pytorch·tensorflow
冴羽1 小时前
太好看了!3 个动漫变真人 Nano Banana Pro 提示词
前端·人工智能·aigc
资深低代码开发平台专家1 小时前
通用编程时代正在向专用化分层演进
java·大数据·c语言·c++·python
悟纤1 小时前
Suno 创作《亲爱的你》歌词模式全流程制作 | 从零开始用Suno Ai | 第4篇
人工智能·suno·suno ai
TL滕1 小时前
从0开始学算法——第六天(进阶排序算法练习)
笔记·python·学习·算法·排序算法
mqiqe1 小时前
【AI】Weaviate向量数据库详细部署安装应用
数据库·人工智能
AI生成未来1 小时前
ICCV 2025 | 北大王选所推出AnyPortal:像素级操控视频背景,前景细节100%保留!
人工智能·扩散模型·视频编辑·视频生成
jixunwulian1 小时前
边缘计算网关在空压机数据采集与远程运维中的解决方案
运维·人工智能·边缘计算