正/余弦位置编码 Sinusoidal Encoding

1 公式

其中,pos 为词汇在句子中的位置索引;i 为特征维度索引,2i 代表偶数位置,2i + 1 代表奇数位置;d 为特征维度,在传统 Transformer 模型中默认为 512。

2 代码实现

"你", "今天", "好", "漂亮", "啊"\],句子 token 长度为 5,若特征维度为 512,则位置编码矩阵的形状为 (5, 512)。

python 复制代码
# 正/余弦位置编码
class SinusoidalEncoding(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 max_len: int = 5000,
                 p: float = 0.1) -> None:
        """
        PE(pos, 2i) = sin(pos / 10000 ** (2i / d_model))
        PE(pos, 2i + 1) = cos(pos / 10000 ** (2i / d_model))

        Args:
            d_model: 特征维度,默认为 512
            max_len: 最大句子长度,默认为 5000
            p: 丢弃率,默认为 0.1
        """
        super(SinusoidalEncoding, self).__init__()
        # 初始化位置向量
        pe = torch.zeros((max_len, d_model))
        position = torch.arange(0, max_len).unsqueeze(1)
        # 1 / 10000 ** (2 * k / d)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000) / d_model))
        # 偶数位置
        pe[:, 0::2] = torch.sin(position * div_term)
        # 奇数位置
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # 缓存
        self.register_buffer("pe", pe)
        # 丢弃层
        self.dropout = nn.Dropout(p)

    def forward(self, x: Tensor) -> Tensor:
        # 截取与词向量相同形状的位置向量,并相加,但位置向量不参与学习,(N, S, E) -> (N, S, E)
        x += Variable(self.pe[:, :x.size(1)], requires_grad=False)

        return self.dropout(x)
相关推荐
曲幽4 小时前
数据库实战:FastAPI + SQLAlchemy 2.0 + Alembic 从零搭建,踩坑实录
python·fastapi·web·sqlalchemy·db·asyncio·alembic
用户8356290780519 小时前
Python 实现 PowerPoint 形状动画设置
后端·python
ponponon10 小时前
时代的眼泪,nameko 和 eventlet 停止维护后的项目自救,升级和替代之路
python
Flittly10 小时前
【从零手写 ClaudeCode:learn-claude-code 项目实战笔记】(5)Skills (技能加载)
python·agent
敏编程10 小时前
一天一个Python库:pyarrow - 大规模数据处理的利器
python
Flittly12 小时前
【从零手写 ClaudeCode:learn-claude-code 项目实战笔记】(4)Subagents (子智能体)
python·agent
明月_清风19 小时前
Python 装饰器前传:如果不懂“闭包”,你只是在复刻代码
后端·python
明月_清风19 小时前
打破“死亡环联”:深挖 Python 分代回收与垃圾回收(GC)机制
后端·python
ZhengEnCi1 天前
08c. 检索算法与策略-混合检索
后端·python·算法
明月_清风2 天前
Python 内存手术刀:sys.getrefcount 与引用计数的生死时速
后端·python