SelfAttenion自注意力机制

my-attention

考虑别的token对当前token的语义影响

第一种情况, 维度缩减

输入x= 4x6

dk=3

wq = 6x3

wk = 6x3

wv = 6x3

q = x@ wq = 4x6*6x3 = 4x3

k = x@ wk = 4x6*6x3 = 4x3

r=q@k.T = 4x3*3x4=4x4

缩放

r = r/sqrt(dk)=4x4

a = softmax®=4x4

v = x@ wv = 4x6*6x3 = 4x3

out = a@v = 4x4 * 4x3 = 4x3

第二种情况, 维度不缩减

输入x= 4x6

输出维度为6

dk=6

随机生成qkv

wq = 6x6

wk = 6x6

wv = 6x6

q = x@ wq = 4x6*6x6 = 4x6

k = x@ wk = 4x6*6x6 = 4x6

r=q@k.T = 4x6*6x4=4x4

缩放

r = r/sqrt(dk)=4x4

归一化

a = softmax®=4x4

原始值增加权重

v = x@ wv = 4x6*6x6 = 4x6

out = a@v = 4x4 * 4x6 = 4x6

保证输出结果的维度和要求要一致

下面是用代码实现了一下自注意力机制

复制代码
import math

import torch
from torch import nn

x = torch.randn(16, 64, 512)

d_model = 512
h_num = 8


class Self_Attention(nn.Module):
    def __init__(self, d_model, h_num):
        # 调用父类构造函数
        super(Self_Attention, self).__init__()

        self.d_model = d_model
        self.h_num = h_num
        self.softmax = nn.Softmax(dim=-1)
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x):
        B, L, D = x.shape

        h_d = self.d_model // self.h_num

        q, k, v = self.w_q(x), self.w_k(x), self.w_v(x)
        q = q.view(B, L, self.h_num, h_d).transpose(1, 2)
        k = k.view(B, L, self.h_num, h_d).transpose(1, 2)
        v = v.view(B, L, self.h_num, h_d).transpose(1, 2)

        r = q @ k.transpose(2, 3) / math.sqrt(h_d)

        mask = torch.tril(torch.ones(L, L, dtype=bool))

        r = r.masked_fill(~mask, -10000) 

        a = self.softmax(r)

        o = a @ v
        o = o.transpose(1, 2).contiguous().view(B, L, self.d_model)

        return self.w_o(o)


attention = Self_Attention(d_model, h_num)
y = attention(x)
print(y.shape)
print(y)
相关推荐
DogDaoDao4 小时前
【第 05 篇】Python的字典与集合
开发语言·python·集合·字典
王小王-1234 小时前
基于深度学习的个性化音乐推荐系统的设计与开发
人工智能·深度学习·mysql·vue·推荐算法·个性化音乐推荐系统·音乐预测
现代野蛮人4 小时前
【深度学习】 —— 几种优化器对比实验
人工智能·深度学习·分类·tensorflow
涛声依旧-底层原理研究所5 小时前
混合检索 + 重排:让 AI Agent 拥有「既全又准」的认知骨架
人工智能·python
努力写A题的小菜鸡5 小时前
01-PyTorch加载数据初认识(dataset运用)
人工智能·pytorch·python
abcy0712135 小时前
python fastapi celery hdfs 异步上传
python·hdfs·fastapi
Dxy12393102165 小时前
Python多线程如何操作全局变量:从踩坑到最佳实践
python
SilentSamsara5 小时前
RAG 系统入门:LangChain/LlamaIndex + Chroma 向量数据库的检索增强实战
数据库·人工智能·python·青少年编程·langchain
YOLO视觉与编程5 小时前
jetson orin nano烧录jetpack7.2系统
人工智能·深度学习·yolo·目标检测·机器学习
码云骑士5 小时前
06-Python装饰器从入门到源码(上)-闭包与自由变量
开发语言·python