Transformer 代码剖析6 - 位置编码 (pytorch实现)

一、位置编码的数学原理与设计思想

1.1 核心公式解析

位置编码采用正弦余弦交替编码方案:
P E ( p o s , 2 i ) = sin ⁡ ( p o s 1000 0 2 i / d m o d e l ) P E ( p o s , 2 i + 1 ) = cos ⁡ ( p o s 1000 0 2 i / d m o d e l ) PE_{(pos,2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right) \\ PE_{(pos,2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right) PE(pos,2i)=sin(100002i/dmodelpos)PE(pos,2i+1)=cos(100002i/dmodelpos)

式中:

  • p o s pos pos:当前词在序列中的绝对位置
  • i i i:特征维度的索引( 0 ≤ i < d m o d e l / 2 0 \leq i < d_{model}/2 0≤i<dmodel/2)
  • 1000 0 2 i / d m o d e l 10000^{2i/d_{model}} 100002i/dmodel:频率控制项,形成指数衰减的频率分布

1.2 设计优势分析

1. 绝对位置感知: 每个位置生成唯一编码模式
2. 相对位置建模: 通过三角函数加法公式可推导任意两个位置的关联度
3. 多频特征捕捉: 不同频率的正余弦波组合形成丰富的表征空间
4. 值域归一化: 所有编码值分布在[-1,1]区间,与词嵌入维度保持数值一致性

(图示:不同维度上的位置编码波形,高频维度对应快速变化,低频维度对应缓慢变化)

二、代码架构与执行流程

2.1 类结构设计

PositionalEncoding __init__构造函数 创建零矩阵 配置梯度策略 构建位置索引 生成维度索引 计算正弦编码 计算余弦编码 forward前向传播 获取输入尺寸 返回截断编码

2.2 核心代码模块

python 复制代码
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len, device):
        super().__init__()
        # 编码矩阵初始化(关键参数说明)
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # 冻结梯度计算
        
        # 位置索引构建(维度变换演示)
        pos = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
        
        # 维度索引生成(步长控制逻辑)
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        
        # 编码计算过程(数学实现)
        self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))

    def forward(self, x):
        batch_size, seq_len = x.size()
        return self.encoding[:seq_len, :]

三、逐行代码深度解析

3.1 构造函数解析

python 复制代码
def __init__(self, d_model, max_len, device):
    super(PositionalEncoding, self).__init__()
  • 功能说明:继承PyTorch模块基类,初始化可训练参数
  • 参数详解:
    • d_model:编码维度(需与词嵌入维度一致)
    • max_len:预计算的最大序列长度(如512对应BERT标准配置)
    • device:硬件加速配置(实现跨平台兼容)
python 复制代码
    self.encoding = torch.zeros(max_len, d_model, device=device)
    self.encoding.requires_grad = False
  • 设计意图:创建静态编码矩阵,避免反向传播计算
  • 内存优化:通过requires_grad=False节省显存占用
  • 维度说明:矩阵形状为[max_len, d_model],例如max_len=512时生成512x512矩阵
python 复制代码
    pos = torch.arange(0, max_len, device=device)
    pos = pos.float().unsqueeze(dim=1)
  • 位置索引构建:生成[0,1,...,max_len-1]的连续位置序列
  • 维度变换:通过unsqueeze将1D张量转换为2D(max_len,1),便于广播计算
python 复制代码
    _2i = torch.arange(0, d_model, step=2, device=device).float()
  • 步长控制:step=2确保交替访问奇偶索引
  • 数值范围:当d_model=512时,生成[0,2,4,...,510]的索引序列
python 复制代码
    self.encoding[:, 0::2] = torch.sin(pos / (10000  (_2i / d_model)))
    self.encoding[:, 1::2] = torch.cos(pos / (10000  (_2i / d_model)))
  • 分片赋值:通过0::21::2实现奇偶列交替填充
  • 频率控制:10000 (_2i/d_model)生成指数衰减的频率系数

3.2 前向传播解析

python 复制代码
def forward(self, x):
    batch_size, seq_len = x.size()
    return self.encoding[:seq_len, :]
  • 动态适配:根据实际输入序列长度截取编码
  • 广播机制:自动扩展编码矩阵到批次维度(无需显式复制)
  • 数值叠加:后续与词嵌入进行element-wise相加操作

四、张量运算可视化演示

4.1 示例参数配置

假设:

  • d_model = 4
  • max_len = 3
  • device = 'cpu'

4.2 计算过程推演

步骤1:生成位置索引

复制代码
pos = [[0],
       [1],
       [2]]  # shape (3,1)

步骤2:创建维度索引

复制代码
_2i = [0, 2]  # d_model=4时step=2生成

步骤3:计算频率项

复制代码
频率项 = 10000^( (0/4), (2/4) ) 
       = [1, 10000^0.5] 
       ≈ [1, 100]

步骤4:计算位置编码

复制代码
sin项:
pos / [1, 100] = [[0/1, 0/100],
                 [1/1, 1/100],
                 [2/1, 2/100]]
               = [[0, 0],
                  [1, 0.01],
                  [2, 0.02]]
sin值:
[[0, 0],
 [0.8415, 0.00999983],
 [0.9093, 0.01999867]]

cos项计算同理...

最终编码矩阵:

复制代码
PE = [
  [sin(0), cos(0), sin(0), cos(0)],      # 位置0
  [sin(1), cos(0.01), sin(1), cos(0.01)],# 位置1
  [sin(2), cos(0.02), sin(2), cos(0.02)] # 位置2
]

五、工程实践与优化策略

5.1 配置参数建议

  1. max_len设定:应大于训练数据最大序列长度20%
  2. 设备兼容性:通过device参数统一管理计算设备
  3. 混合精度训练:可将编码矩阵转为half精度

5.2 性能优化技巧

  1. 预计算缓存:提前生成编码矩阵避免运行时计算
  2. 内存映射:对超长序列使用内存映射文件
  3. 稀疏矩阵:对长文本场景采用分块加载策略

六、与其他模块的协同工作

6.1 与词嵌入的集成

python 复制代码
class TransformerEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, max_len, device, dropout):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = PositionalEncoding(d_model, max_len, device)
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        tok_emb = self.tok_emb(x)
        pos_emb = self.pos_emb(x)
        return self.dropout(tok_emb + pos_emb)
  • 加法融合:通过element-wise相加实现信息融合
  • 梯度隔离:位置编码不参与梯度更新
  • 维度验证:确保tok_embpos_emb维度严格一致

七、典型应用场景分析

7.1 文本生成任务

  • 长序列处理:通过位置编码捕获远距离依赖
  • 解码器优化:在自回归生成时动态调整位置编码

7.2 语音识别系统

  • 时序建模:精确捕捉语音信号的时序特征
  • 多尺度编码:结合不同频率分量处理语音信号

八、扩展研究方向

  1. 相对位置编码:改进绝对位置编码的局限性
  2. 动态频率调整:根据输入数据自动调节频率参数
  3. 混合编码方案:结合可学习参数与固定编码
  4. 量子化压缩:对编码矩阵进行低比特量化

原项目代码(附)

python 复制代码
"""
@author : Hyunwoong
@when : 2019-10-22
@homepage : https://github.com/gusdnd852
"""

import torch
from torch import nn

# 定义一个名为PositionalEncoding的类,它继承自nn.Module,用于计算正弦位置编码。
class PositionalEncoding(nn.Module):
    """
    计算正弦位置编码的类。
    """

    def __init__(self, d_model, max_len, device):
        """
        PositionalEncoding类的构造函数。

        :param d_model: 模型的维度(即嵌入向量的大小)。
        :param max_len: 序列的最大长度。
        :param device: 硬件设备设置(CPU或GPU)。
        """

        super(PositionalEncoding, self).__init__()  # 调用父类nn.Module的构造函数。

        # 初始化一个与输入矩阵大小相同的零矩阵,用于存储位置编码,以便后续与输入矩阵相加。
        self.encoding = torch.zeros(max_len, d_model, device=device)
        self.encoding.requires_grad = False  # 我们不需要计算位置编码的梯度。

        # 创建一个从0到max_len-1的一维张量,表示序列中的位置索引。
        pos = torch.arange(0, max_len, device=device)
        # 将位置索引张量转换为浮点数,并增加一个维度,从1D变为2D,以表示每个位置的索引。
        pos = pos.float().unsqueeze(dim=1)
        # 1D => 2D,增加维度以表示单词的位置。

        # 创建一个从0到d_model-1,步长为2的一维浮点数张量,用于计算正弦和余弦函数的指数部分。
        _2i = torch.arange(0, d_model, step=2, device=device).float()
        # 'i'表示d_model的索引(例如,嵌入大小=50时,'i'的范围为[0,50])。
        # "step=2"意味着'i'每次增加2(相当于2*i)。

        # 使用正弦函数计算位置编码的偶数索引位置的值。
        self.encoding[:, 0::2] = torch.sin(pos / (10000 ** (_2i / d_model)))
        # 使用余弦函数计算位置编码的奇数索引位置的值。
        self.encoding[:, 1::2] = torch.cos(pos / (10000 ** (_2i / d_model)))
        # 计算位置编码,以考虑单词的位置信息。

    def forward(self, x):
        # self.encoding是预先计算好的位置编码矩阵。
        # [max_len = 512, d_model = 512],表示最大长度为512,维度为512的位置编码。

        # 获取输入x的批次大小和序列长度。
        batch_size, seq_len = x.size()
        # [batch_size = 128, seq_len = 30],表示批次大小为128,序列长度为30。

        # 返回与输入序列长度相匹配的位置编码。
        return self.encoding[:seq_len, :]
        # [seq_len = 30, d_model = 512],返回的形状为序列长度乘以维度。
        # 它将与输入嵌入(tok_emb)相加,tok_emb的形状通常为[128, 30, 512]。
相关推荐
奔跑吧邓邓子几秒前
DeepSeek 赋能智能零售,解锁动态定价新范式
人工智能·动态定价·智能零售·deepseek
火兮明兮10 分钟前
Python训练第四十五天
开发语言·python
鼓掌MVP14 分钟前
边缘计算应用实践心得
人工智能·边缘计算
zdy126357468814 分钟前
python43天
python·深度学习·机器学习
QYR_1116 分钟前
宠物车载安全座椅市场报告:解读行业趋势与投资前景
大数据·人工智能
wswlqsss19 分钟前
第四十五天打卡
人工智能·深度学习
Likeadust23 分钟前
视频汇聚平台EasyCVR“明厨亮灶”方案筑牢旅游景区餐饮安全品质防线
网络·人工智能·音视频
天翼云开发者社区37 分钟前
总决赛定档!“天翼云息壤杯”高校AI大赛巅峰之战即将打响!
人工智能·ai大赛
亚马逊云开发者1 小时前
Amazon Bedrock 助力 SolveX.AI 构建智能解题 Agent,打造头部教育科技应用
人工智能
搏博1 小时前
将图形可视化工具的 Python 脚本打包为 Windows 应用程序
开发语言·windows·python·matplotlib·数据可视化