(深度学习记录)第TR4周:Pytorch复现Transformer

🏡我的环境:

  • 语言环境:Python3.11.4
  • 编译器:Jupyter Notebook
  • torcch版本:2.0.1
python 复制代码
import torch
import torch.nn as nn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
 
        self.hid_dim = hid_dim
        self.n_heads = n_heads
 
        # hid_dim必须整除
        assert hid_dim % n_heads == 0
        # 定义wq
        self.w_q = nn.Linear(hid_dim, hid_dim)
        # 定义wk
        self.w_k = nn.Linear(hid_dim, hid_dim)
        # 定义wv
        self.w_v = nn.Linear(hid_dim, hid_dim)
 
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.do = nn.Dropout(dropout)
 
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim//n_heads]))
 
    def forward(self, query, key, value, mask=None):
        # Q与KV在句子长度这一个维度上数值可以不一样
        bsz = query.shape[0]
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
 
        # 将QKV拆成多组,方案是将向量直接拆开了
        # (64, 12, 300) -> (64, 12, 6, 50) -> (64, 6, 12, 50)
        # (64, 10, 300) -> (64, 10, 6, 50) -> (64, 6, 10, 50)
        # (64, 10, 300) -> (64, 10, 6, 50) -> (64, 6, 10, 50)
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)
 
        # 第1步,Q x K / scale
        # (64, 6, 12, 50) x (64, 6, 50, 10) -> (64, 6, 12, 10)
        attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
 
        # 需要mask掉的地方,attention设置的很小很小
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)
 
        # 第2步,做softmax 再dropout得到attention
        attention = self.do(torch.softmax(attention, dim=-1))
 
 
        # 第3步,attention结果与k相乘,得到多头注意力的结果
        # (64, 6, 12, 10) x (64, 6, 10, 50) -> (64, 6, 12, 50)
        x = torch.matmul(attention, V)
 
        # 把结果转回去
        # (64, 6, 12, 50) -> (64, 12, 6, 50)
        x = x.permute(0, 2, 1, 3).contiguous()
 
        # 把结果合并
        # (64, 12, 6, 50) -> (64, 12, 300)
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        x = self.fc(x)
        return x
 
query = torch.rand(64, 12, 300)
key = torch.rand(64, 10, 300)
value = torch.rand(64, 10, 300)
attention = MultiHeadAttention(hid_dim=300, n_heads=6, dropout=0.1)
output = attention(query, key, value)
print(output.shape)

多头注意力机制拓展了模型关注不同位置的能力,赋予Attention层多个"子表示空间"。

相关推荐
β添砖java1 小时前
深度学习(13)PyTorch神经网络基础
人工智能·深度学习
victory04312 小时前
论文设计和撰写1
人工智能·深度学习·机器学习
沪漂阿龙4 小时前
OpenAI Agents SDK 深度解析(三):执行层——Agent 的“幕后指挥部”
人工智能·深度学习
还是奇怪4 小时前
AI 提示词工程入门:用好的语言与模型高效对话
大数据·人工智能·语言模型·自然语言处理·transformer
数智工坊4 小时前
【SAM-DETR论文阅读】:基于语义对齐匹配的DETR极速收敛检测框架
网络·论文阅读·人工智能·深度学习·transformer
童园管理札记4 小时前
【续】数字时代:学前教育的新改革
经验分享·深度学习·职场和发展·微信公众平台
西西弗Sisyphus5 小时前
Transformer 编码器堆叠的 Encoder 层之间,和多头注意力模块内部各独立单注意力头之间,在 QKV 上处理的区别
transformer
AI医影跨模态组学6 小时前
如何将纵向CT影像组学特征与局部晚期胃癌化疗时空异质性及耐药演化建立关联,并进一步解释其与化疗响应、淋巴结转移及生存预后的机制联系
人工智能·深度学习·论文·医学·医学影像·影像组学
硅谷秋水8 小时前
ClawVM:有状态工具LLM智体的Harness管理型虚拟内存
人工智能·深度学习·语言模型
春风有信8 小时前
【DM】DDPM与DDIM的数学原理
人工智能·深度学习·机器学习