(深度学习记录)第TR4周:Pytorch复现Transformer

🏡我的环境:

  • 语言环境:Python3.11.4
  • 编译器:Jupyter Notebook
  • torcch版本:2.0.1
python 复制代码
import torch
import torch.nn as nn
 
class MultiHeadAttention(nn.Module):
    def __init__(self, hid_dim, n_heads, dropout):
        super().__init__()
 
        self.hid_dim = hid_dim
        self.n_heads = n_heads
 
        # hid_dim必须整除
        assert hid_dim % n_heads == 0
        # 定义wq
        self.w_q = nn.Linear(hid_dim, hid_dim)
        # 定义wk
        self.w_k = nn.Linear(hid_dim, hid_dim)
        # 定义wv
        self.w_v = nn.Linear(hid_dim, hid_dim)
 
        self.fc = nn.Linear(hid_dim, hid_dim)
        self.do = nn.Dropout(dropout)
 
        self.scale = torch.sqrt(torch.FloatTensor([hid_dim//n_heads]))
 
    def forward(self, query, key, value, mask=None):
        # Q与KV在句子长度这一个维度上数值可以不一样
        bsz = query.shape[0]
        Q = self.w_q(query)
        K = self.w_k(key)
        V = self.w_v(value)
 
        # 将QKV拆成多组,方案是将向量直接拆开了
        # (64, 12, 300) -> (64, 12, 6, 50) -> (64, 6, 12, 50)
        # (64, 10, 300) -> (64, 10, 6, 50) -> (64, 6, 10, 50)
        # (64, 10, 300) -> (64, 10, 6, 50) -> (64, 6, 10, 50)
        Q = Q.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)
        K = K.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)
        V = V.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)
 
        # 第1步,Q x K / scale
        # (64, 6, 12, 50) x (64, 6, 50, 10) -> (64, 6, 12, 10)
        attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale
 
        # 需要mask掉的地方,attention设置的很小很小
        if mask is not None:
            attention = attention.masked_fill(mask == 0, -1e10)
 
        # 第2步,做softmax 再dropout得到attention
        attention = self.do(torch.softmax(attention, dim=-1))
 
 
        # 第3步,attention结果与k相乘,得到多头注意力的结果
        # (64, 6, 12, 10) x (64, 6, 10, 50) -> (64, 6, 12, 50)
        x = torch.matmul(attention, V)
 
        # 把结果转回去
        # (64, 6, 12, 50) -> (64, 12, 6, 50)
        x = x.permute(0, 2, 1, 3).contiguous()
 
        # 把结果合并
        # (64, 12, 6, 50) -> (64, 12, 300)
        x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))
        x = self.fc(x)
        return x
 
query = torch.rand(64, 12, 300)
key = torch.rand(64, 10, 300)
value = torch.rand(64, 10, 300)
attention = MultiHeadAttention(hid_dim=300, n_heads=6, dropout=0.1)
output = attention(query, key, value)
print(output.shape)

多头注意力机制拓展了模型关注不同位置的能力,赋予Attention层多个"子表示空间"。

相关推荐
写代码的二次猿2 小时前
安装openfold(顺利解决版)
开发语言·python·深度学习
SkyXZ3 小时前
人脸伪造判别分类网络CNN&Transformer
深度学习
飞Link4 小时前
深度解析 LSTM 神经网络架构与实战指南
人工智能·深度学习·神经网络·lstm
love530love5 小时前
Windows 11 源码编译 vLLM 0.16 完全指南(RTX 3090 / CUDA 12.8 / PyTorch 2.7.1)
人工智能·pytorch·windows·python·深度学习·vllm·vs 2022
放下华子我只抽RuiKe55 小时前
机器学习全景指南-基石篇——预测连续值的线性回归
人工智能·深度学习·神经网络·算法·机器学习·自然语言处理·线性回归
带娃的IT创业者5 小时前
专栏系列3.3《时序关联学习:r=0.733 背后的记忆形成》
人工智能·深度学习·神经网络·时序学习·nct·神经调质
快乐非自愿5 小时前
NIO核心原理深度解析:非阻塞I/O的块式设计与高并发实现逻辑
人工智能·深度学习·nio
八月瓜科技5 小时前
擎策·知海全球专利数据库 技术赋能检索 让科技创新少走弯路
大数据·数据库·人工智能·科技·深度学习·娱乐
兴通扫码设备5 小时前
ocr工业场景适配升级:深圳市兴通物联XTC8501智能相机接口与环境适应性技术解析
数据库·人工智能·深度学习·数码相机·计算机视觉
小陈phd5 小时前
多模态大模型学习笔记(十六)——Transformer 学习之 Decoder Only
人工智能·笔记·深度学习·学习·自然语言处理·transformer