正/余弦位置编码 Sinusoidal Encoding

1 公式

其中,pos 为词汇在句子中的位置索引;i 为特征维度索引,2i 代表偶数位置,2i + 1 代表奇数位置;d 为特征维度,在传统 Transformer 模型中默认为 512。

2 代码实现

"你", "今天", "好", "漂亮", "啊",句子 token 长度为 5,若特征维度为 512,则位置编码矩阵的形状为 (5, 512)。

python 复制代码
# 正/余弦位置编码
class SinusoidalEncoding(nn.Module):
    def __init__(self,
                 d_model: int = 512,
                 max_len: int = 5000,
                 p: float = 0.1) -> None:
        """
        PE(pos, 2i) = sin(pos / 10000 ** (2i / d_model))
        PE(pos, 2i + 1) = cos(pos / 10000 ** (2i / d_model))

        Args:
            d_model: 特征维度,默认为 512
            max_len: 最大句子长度,默认为 5000
            p: 丢弃率,默认为 0.1
        """
        super(SinusoidalEncoding, self).__init__()
        # 初始化位置向量
        pe = torch.zeros((max_len, d_model))
        position = torch.arange(0, max_len).unsqueeze(1)
        # 1 / 10000 ** (2 * k / d)
        div_term = torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000) / d_model))
        # 偶数位置
        pe[:, 0::2] = torch.sin(position * div_term)
        # 奇数位置
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # 缓存
        self.register_buffer("pe", pe)
        # 丢弃层
        self.dropout = nn.Dropout(p)

    def forward(self, x: Tensor) -> Tensor:
        # 截取与词向量相同形状的位置向量,并相加,但位置向量不参与学习,(N, S, E) -> (N, S, E)
        x += Variable(self.pe[:, :x.size(1)], requires_grad=False)

        return self.dropout(x)
相关推荐
不爱吃糖の糖糖15 小时前
RAG 07:RAG 高级范式与幻觉防控
人工智能·embedding
MageGojo15 小时前
10 种主题随机诗词:一个 API 解决小程序的诗词内容源
python·小程序·古诗词·api 接入
cooldream200915 小时前
使用 uv 管理 Python 虚拟环境:现代 Python 开发的高效实践
python·uv·mcp
zhangfeng113315 小时前
国家超算中心 系统自带模型 和pytorch 和cuda版本
人工智能·pytorch·python
m0_7381207215 小时前
渗透测试基础——黑盒测试下的Web漏洞挖掘与利用解析(二)
服务器·前端·python·网络协议·安全·网络安全
huan19911015 小时前
从机器翻译到智驾:规则派的黄昏与数据革命的终局 (七)
人工智能·自然语言处理·机器翻译
玫幽倩15 小时前
2025FIC取证决赛wp(手机取证)
python·智能手机·手机·电子取证·计算机取证·手机取证·fic
多彩电脑15 小时前
Kivy如何自定义事件
开发语言·python
java_cj15 小时前
LangChain初入门 - 简化LLM开发难度的利器
开发语言·python·langchain
sleven fung15 小时前
llama-cpp-python 本地部署入门
开发语言·python·算法·llama