Transformer 代码剖析6 - 位置编码 (pytorch实现)

一、位置编码的数学原理与设计思想

1.1 核心公式解析

位置编码采用正弦余弦交替编码方案:
P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)PE(pos,2i+1)=cos(100002i/dmodelpos)

式中:

  • p o s pos pos:当前词在序列中的绝对位置
  • i i i:特征维度的索引( 0 ≤ i < d m o d e l / 2 0 \leq i < d_{model}/2 0≤i<dmodel/2)
  • 1000 0 2 i / d m o d e l 10000^{2i/d_{model}} 100002i/dmodel:频率控制项,形成指数衰减的频率分布

1.2 设计优势分析

1. 绝对位置感知: 每个位置生成唯一编码模式
2. 相对位置建模: 通过三角函数加法公式可推导任意两个位置的关联度
3. 多频特征捕捉: 不同频率的正余弦波组合形成丰富的表征空间
4. 值域归一化: 所有编码值分布在[-1,1]区间,与词嵌入维度保持数值一致性

(图示:不同维度上的位置编码波形,高频维度对应快速变化,低频维度对应缓慢变化)

二、代码架构与执行流程

2.1 类结构设计

PositionalEncoding __init__构造函数 创建零矩阵 配置梯度策略 构建位置索引 生成维度索引 计算正弦编码 计算余弦编码 forward前向传播 获取输入尺寸 返回截断编码

2.2 核心代码模块

python 复制代码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, device):
        super().__init__()
        # 编码矩阵初始化(关键参数说明)
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # 冻结梯度计算
        
        # 位置索引构建(维度变换演示)
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
        
        # 维度索引生成(步长控制逻辑)
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        
        # 编码计算过程(数学实现)
        self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))

    def forward(self, x):
        batch_size, seq_len = x.size()
        return self.encoding[:seq_len, :]

三、逐行代码深度解析

3.1 构造函数解析

python 复制代码
def __init__(self, d_model, max_len, device):
    super(PositionalEncoding, self).__init__()
  • 功能说明:继承PyTorch模块基类,初始化可训练参数
  • 参数详解:
    • d_model:编码维度(需与词嵌入维度一致)
    • max_len:预计算的最大序列长度(如512对应BERT标准配置)
    • device:硬件加速配置(实现跨平台兼容)
python 复制代码
    self.encoding = torch.zeros(max_len, d_model, device=device)
    self.encoding.requires_grad = False
  • 设计意图:创建静态编码矩阵,避免反向传播计算
  • 内存优化:通过requires_grad=False节省显存占用
  • 维度说明:矩阵形状为[max_len, d_model],例如max_len=512时生成512x512矩阵
python 复制代码
    pos = torch.arange(0, max_len, device=device)
    pos = pos.float().unsqueeze(dim=1)
  • 位置索引构建:生成[0,1,...,max_len-1]的连续位置序列
  • 维度变换:通过unsqueeze将1D张量转换为2D(max_len,1),便于广播计算
python 复制代码
    _2i = torch.arange(0, d_model, step=2, device=device).float()
  • 步长控制:step=2确保交替访问奇偶索引
  • 数值范围:当d_model=512时,生成[0,2,4,...,510]的索引序列
python 复制代码
    self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))
    self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))
  • 分片赋值:通过0::21::2实现奇偶列交替填充
  • 频率控制:10000 (_2i/d_model)生成指数衰减的频率系数

3.2 前向传播解析

python 复制代码
def forward(self, x):
    batch_size, seq_len = x.size()
    return self.encoding[:seq_len, :]
  • 动态适配:根据实际输入序列长度截取编码
  • 广播机制:自动扩展编码矩阵到批次维度(无需显式复制)
  • 数值叠加:后续与词嵌入进行element-wise相加操作

四、张量运算可视化演示

4.1 示例参数配置

假设:

  • d_model = 4
  • max_len = 3
  • device = 'cpu'

4.2 计算过程推演

步骤1:生成位置索引

复制代码
pos = [[0],
       [1],
       [2]]  # shape (3,1)

步骤2:创建维度索引

复制代码
_2i = [0, 2]  # d_model=4时step=2生成

步骤3:计算频率项

复制代码
频率项 = 10000^( (0/4), (2/4) ) 
       = [1, 10000^0.5] 
       ≈ [1, 100]

步骤4:计算位置编码

复制代码
sin项:
pos / [1, 100] = [[0/1, 0/100],
                 [1/1, 1/100],
                 [2/1, 2/100]]
               = [[0, 0],
                  [1, 0.01],
                  [2, 0.02]]
sin值:
[[0, 0],
 [0.8415, 0.00999983],
 [0.9093, 0.01999867]]

cos项计算同理...

最终编码矩阵:

复制代码
PE = [
  [sin(0), cos(0), sin(0), cos(0)],      # 位置0
  [sin(1), cos(0.01), sin(1), cos(0.01)],# 位置1
  [sin(2), cos(0.02), sin(2), cos(0.02)] # 位置2
]

五、工程实践与优化策略

5.1 配置参数建议

  1. max_len设定:应大于训练数据最大序列长度20%
  2. 设备兼容性:通过device参数统一管理计算设备
  3. 混合精度训练:可将编码矩阵转为half精度

5.2 性能优化技巧

  1. 预计算缓存:提前生成编码矩阵避免运行时计算
  2. 内存映射:对超长序列使用内存映射文件
  3. 稀疏矩阵:对长文本场景采用分块加载策略

六、与其他模块的协同工作

6.1 与词嵌入的集成

python 复制代码
class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, device, dropout):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.dropout(tok_emb + pos_emb)
  • 加法融合:通过element-wise相加实现信息融合
  • 梯度隔离:位置编码不参与梯度更新
  • 维度验证:确保tok_embpos_emb维度严格一致

七、典型应用场景分析

7.1 文本生成任务

  • 长序列处理:通过位置编码捕获远距离依赖
  • 解码器优化:在自回归生成时动态调整位置编码

7.2 语音识别系统

  • 时序建模:精确捕捉语音信号的时序特征
  • 多尺度编码:结合不同频率分量处理语音信号

八、扩展研究方向

  1. 相对位置编码:改进绝对位置编码的局限性
  2. 动态频率调整:根据输入数据自动调节频率参数
  3. 混合编码方案:结合可学习参数与固定编码
  4. 量子化压缩:对编码矩阵进行低比特量化

原项目代码(附)

python 复制代码
"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""

import torch
from torch import nn

# 定义一个名为PositionalEncoding的类,它继承自nn.Module,用于计算正弦位置编码。
class PositionalEncoding(nn.Module):
    """
    计算正弦位置编码的类。
    """

    def __init__(self, d_model, max_len, device):
        """
        PositionalEncoding类的构造函数。

        :param d_model: 模型的维度(即嵌入向量的大小)。
        :param max_len: 序列的最大长度。
        :param device: 硬件设备设置(CPU或GPU)。
        """

        super(PositionalEncoding, self).__init__()  # 调用父类nn.Module的构造函数。

        # 初始化一个与输入矩阵大小相同的零矩阵,用于存储位置编码,以便后续与输入矩阵相加。
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # 我们不需要计算位置编码的梯度。

        # 创建一个从0到max_len-1的一维张量,表示序列中的位置索引。
        pos = torch.arange(0, max_len, device=device)
        # 将位置索引张量转换为浮点数,并增加一个维度,从1D变为2D,以表示每个位置的索引。
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D,增加维度以表示单词的位置。

        # 创建一个从0到d_model-1,步长为2的一维浮点数张量,用于计算正弦和余弦函数的指数部分。
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        # 'i'表示d_model的索引(例如,嵌入大小=50时,'i'的范围为[0,50])。
        # "step=2"意味着'i'每次增加2(相当于2*i)。

        # 使用正弦函数计算位置编码的偶数索引位置的值。
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        # 使用余弦函数计算位置编码的奇数索引位置的值。
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        # 计算位置编码,以考虑单词的位置信息。

    def forward(self, x):
        # self.encoding是预先计算好的位置编码矩阵。
        # [max_len = 512, d_model = 512],表示最大长度为512,维度为512的位置编码。

        # 获取输入x的批次大小和序列长度。
        batch_size, seq_len = x.size()
        # [batch_size = 128, seq_len = 30],表示批次大小为128,序列长度为30。

        # 返回与输入序列长度相匹配的位置编码。
        return self.encoding[:seq_len, :]
        # [seq_len = 30, d_model = 512],返回的形状为序列长度乘以维度。
        # 它将与输入嵌入(tok_emb)相加,tok_emb的形状通常为[128, 30, 512]。
相关推荐
犬余7 分钟前
模型上下文协议(MCP):AI的“万能插座”
人工智能·mcp
忧陌60615 分钟前
Day22打卡-复习
python
芯盾时代44 分钟前
数据出境的安全合规思考
大数据·人工智能·安全·网络安全·信息与通信
Sylvan Ding1 小时前
PyTorch Lightning实战 - 训练 MNIST 数据集
人工智能·pytorch·python·lightning
大白技术控1 小时前
浙江大学 deepseek 公开课 第三季 第3期 - 陈喜群 教授 (附PPT下载) by 突破信息差
人工智能·互联网·deepseek·deepseek公开课·浙大deepseek公开课课件·deepseek公开课ppt·人工智能大模型
Silence4Allen1 小时前
大模型微调指南之 LLaMA-Factory 篇:一键启动LLaMA系列模型高效微调
人工智能·大模型·微调·llama-factory
江鸟19981 小时前
AI日报 · 2025年05月11日|传闻 OpenAI 考虑推出 ChatGPT “永久”订阅模式
人工智能·gpt·ai·chatgpt·github
weifont1 小时前
Ai大模型训练从零到1第一节(共81节)
人工智能
sbc-study1 小时前
大规模预训练范式(Large-scale Pre-training)
gpt·学习·transformer
kyle~1 小时前
C++匿名函数
开发语言·c++·人工智能