python
def precompute_pos_cis(dim: int, end: int = int(32 * 1024), theta: float = 1e6):
    """Precompute the complex rotation factors used by rotary position embeddings.

    Args:
        dim (int): Embedding width to rotate; one frequency is produced per
            pair of channels, i.e. ``dim // 2`` frequencies.
        end (int, optional): Number of positions to tabulate.
            Defaults to int(32 * 1024).
        theta (float, optional): Base of the geometric frequency schedule;
            larger values make the frequencies decay faster across channels.
            Defaults to 1e6.

    Returns:
        torch.Tensor: Complex tensor of shape ``(end, dim // 2)`` whose entry
        ``[p, k]`` is ``exp(1j * p * theta ** (-2 * k / dim))``.
    """
    half = dim // 2
    # Exponent 2k/dim for k = 0 .. half-1 gives a geometric ladder of frequencies.
    exponents = torch.arange(0, dim, 2)[:half].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    # Integer position indices 0 .. end-1, on the same device as the frequencies.
    positions = torch.arange(end, device=inv_freq.device)
    # angles[p, k] = p * inv_freq[k]
    angles = torch.outer(positions, inv_freq).float()
    # Unit magnitude + angle -> cos(angle) + 1j * sin(angle).
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude, angles)
解释:

python
def apply_rotary_emb(xq,xk, pos_cis):
"""RoPE位置编码
Args:
xq (_type_): Q矩阵 (batchsize, seqlen, heads, dim//2)
xk (_type_): K矩阵 (batchsize, seqlen, heads, dim//2)
pos_cis (_type_): 预处理后的旋转频率矩阵 (end, dim//2)
"""
def unite_shape(pos_cis, x):
"""对预处理后的旋转频率矩阵进行广播
Args:
pos_cis (_type_): 旋转频率矩阵
x (_type_): 输入
"""
ndim = x.ndim # 维度,4 (batchsize, seqlen, heads, dim//2)
assert 0 <= 1 < ndim
assert pos_cis.shape == (x.shape[1], x.shape[-1]) # (x.shape[1], x.shape[-1]) = (seqlen, dim//2)
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] # reshape x = (1, seqlen, 1, dim//2)
return pos_cis.view(*shape) # reshape pos_cis (1, 1, end. dim//2)
xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1 ,2))
pos_cis = unite_shape(*(pos_cis, xq_))
xq_out = torch.view_as_real(xq_ * pos_cis).flatten(3)
xk_out = torch.view_as_real(xk_ * pos_cis).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
解释:
