正/余弦位置编码 Sinusoidal Encoding

1 公式

其中,pos 为词汇在句子中的位置索引;i 为特征维度索引,2i 代表偶数位置,2i + 1 代表奇数位置;d 为特征维度,在传统 Transformer 模型中默认为 512。

2 代码实现

"你", "今天", "好", "漂亮", "啊"\],句子 token 长度为 5,若特征维度为 512,则位置编码矩阵的形状为 (5, 512)。

python 复制代码
# 正/余弦位置编码
class SinusoidalEncoding(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 max_len: int = 5000,
                 p: float = 0.1) -> None:
        """
        PE(pos, 2i) = sin(pos / 10000 ** (2i / d_model))
        PE(pos, 2i + 1) = cos(pos / 10000 ** (2i / d_model))

        Args:
            d_model: 特征维度,默认为 512
            max_len: 最大句子长度,默认为 5000
            p: 丢弃率,默认为 0.1
        """
        super(SinusoidalEncoding, self).__init__()
        # 初始化位置向量
        pe = torch.zeros((max_len, d_model))
        position = torch.arange(0, max_len).unsqueeze(1)
        # 1 / 10000 ** (2 * k / d)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000) / d_model))
        # 偶数位置
        pe[:, 0::2] = torch.sin(position * div_term)
        # 奇数位置
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # 缓存
        self.register_buffer("pe", pe)
        # 丢弃层
        self.dropout = nn.Dropout(p)

    def forward(self, x: Tensor) -> Tensor:
        # 截取与词向量相同形状的位置向量,并相加,但位置向量不参与学习,(N, S, E) -> (N, S, E)
        x += Variable(self.pe[:, :x.size(1)], requires_grad=False)

        return self.dropout(x)
相关推荐
深蓝海拓10 小时前
使用@property将类方法包装为属性
开发语言·python
gorgeous(๑>؂<๑)10 小时前
【CVPR26-索尼】EW-DETR:通过增量低秩检测Transformer实现动态世界目标检测
人工智能·深度学习·目标检测·计算机视觉·transformer
福运常在11 小时前
股票数据API(19)次新股池数据
java·python·maven
多看书少吃饭12 小时前
Vue3 + Java + Python 打造企业级大模型知识库(含 SSE 流式对话完整源码)
java·python·状态模式
Z.风止12 小时前
Large Model-learning(2)
开发语言·笔记·python·leetcode
蓝天守卫者联盟112 小时前
玩具喷涂废气治理厂家:行业现状、技术路径与选型指南
大数据·运维·人工智能·python
m0_7381207212 小时前
我的创作纪念日0328
java·网络·windows·python·web安全·php
red1giant_star12 小时前
浅析文件类漏洞原理与分类——含payload合集与检测与防护思路
python·安全
tryCbest12 小时前
Python之Flask开发框架(第一篇) — 从安装到第一个应用
开发语言·python·flask
zhangzeyuaaa12 小时前
Python getter/setter 正确用法详解
开发语言·python