SelfAttention和MultiHeadAttion实现demo

#encoding:utf-8

from math import sqrt

import torch

import torch.nn as nn

class Self_Attention(nn.Module):

def init(self, input_dim, dim_k, dim_v):

super(Self_Attention, self). init()

self.q = nn.Linear(input_dim, dim_k)

self.k = nn.Linear(input_dim, dim_k)

self.v = nn.Linear(input_dim, dim_v)

self.norm_fact = 1 / sqrt(dim_k)

def forward(self, x):

print("x.shape:", x.shape)

print("q.shape:", self.q.shape)

Q = self.q(x)

print("Q.shape:", Q.shape)

K = self.k(x)

print("K.shape:", K.shape)

V = self.v(x)

print("V.shape:", V.shape)

atten = nn.Softmax(dim=-1)(torch.bmm(Q,K.permute(0,2,1))) * self.norm_fact

output = torch.bmm(atten, V)

return output

print("\n")

print("self attention:")

x = torch.randn(4,3,1024)

print(x)

print("input size:", x.size())

self_attention = Self_Attention(1024,128,5)

res = self_attention(x)

print("\n")

print(res)

print("output size:", res.size())

print("\n")

class Self_Attention_Muti_Head(nn.Module):

def init(self, input_dim, dim_k, dim_v, nums_head):

super(Self_Attention_Muti_Head, self).init()

assert dim_k % nums_head == 0

assert dim_v % nums_head == 0

self.q = nn.Linear(input_dim, dim_k)

self.k = nn.Linear(input_dim, dim_k)

self.v = nn.Linear(input_dim, dim_v)

self.nums_head = nums_head

self.dim_k = dim_k

self.dim_v = dim_v

self._norm_fact = 1 / sqrt(dim_k)

def forward(self, x):

Q = self.q(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k//self.nums_head)

K = self.k(x).reshape(-1, x.shape[0], x.shape[1], self.dim_k//self.nums_head)

V = self.v(x).reshape(-1, x.shape[0], x.shape[1], self.dim_v//self.nums_head)

print("x.shape:", x.shape)

print("Q.shape", Q.size())

atten = nn.Softmax(dim=-1)(torch.matmul(Q, K.permute(0,1,3,2)))

output = torch.matmul(atten, V).reshape(x.shape[0], x.shape[1], -1)

return output

print("\n")

print("multi head attention:")

x = torch.randn(4,3,1024)

print(x)

print(x.size())

self_attention = Self_Attention_Muti_Head(1024,128,6,2)

res = self_attention(x)

print("\n")

print(res)

print(res.size())


有个问题:

根据文献:https://arxiv.org/pdf/1911.02150.pdf,感觉这里说的Multi Head Attenion和 Group Query Attention意思是一样的:

这下面这张经典的图中的的Grouped-query意思是一样的:

哪里没理解到位?

相关推荐
计算机毕业编程指导师5 分钟前
大数据可视化毕设:Hadoop+Spark交通分析系统从零到上线 毕业设计 选题推荐 毕设选题 数据分析 机器学习 数据挖掘
大数据·hadoop·python·计算机·spark·毕业设计·城市交通
m0_6038887112 分钟前
FineInstructions Scaling Synthetic Instructions to Pre-Training Scale
人工智能·深度学习·机器学习·ai·论文速览
计算机毕业编程指导师13 分钟前
【计算机毕设选题】基于Spark的车辆排放分析:2026年热门大数据项目 毕业设计 选题推荐 毕设选题 数据分析 机器学习 数据挖掘
大数据·hadoop·python·计算机·spark·毕业设计·车辆排放
浔川python社19 分钟前
浔川社团关于产品数据情况的官方通告
python
生活很暖很治愈20 分钟前
GUI自动化测试[3]——控件&数鼠标操作
windows·python·功能测试·测试工具
EmmaXLZHONG21 分钟前
Reinforce Learning Concept Flow Chart (强化学习概念流程图)
人工智能·深度学习·机器学习·流程图
薛定谔的猫198223 分钟前
十三.调用 BERT 中文文本情感分析交互式推理模型训练好的
人工智能·深度学习·bert
老蒋每日coding32 分钟前
Python3基础练习题详解,从入门到熟练的 50 个实例(一)
开发语言·python
小鸡吃米…37 分钟前
机器学习 - 高斯判别分析(Gaussian Discriminant Analysis)
人工智能·深度学习·机器学习
HAPPY酷38 分钟前
构建即自由:一份为创造者设计的 Windows C++ 自动化构建指南
开发语言·c++·ide·windows·python·策略模式·visual studio