了解并实现一个Transformer Block

文章目录

  • [1. 前言](#1. 前言)
  • [2. Transformer Block](#2. Transformer Block)
  • [3. 代码实现](#3. 代码实现)
  • [4. 参考](#4. 参考)

1. 前言

什么是 Transformer?如果希望深入理解可以参考:
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

本文主要介绍常常听到的 Transformer Block 的概念,以及如何实现一个 Transformer Block。

2. Transformer Block

回顾一下 Transformer 的完整模型:

我们常说的 Transformer Block 对应图中解码器的上部分。为了具体展示流程,我们假设有一句话:"Every effort moves you" 作为输入,经过蓝色框中的 Transformer Block 之后输出,如下图:

图中蓝色的部分就是所谓的 Transformer Block。

3. 代码实现

BERT 源码已经实现了 Transformer 的细节,完整源码参考 Pytorch Bert,这里把 Transformer Block 实现的框架贴出来

python 复制代码
import torch.nn as nn

from .attention import MultiHeadedAttention
from .utils import SublayerConnection, PositionwiseFeedForward


class TransformerBlock(nn.Module):
    """
    Bidirectional Encoder = Transformer (self-attention)
    Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
    """

    def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
        """
        :param hidden: hidden size of transformer
        :param attn_heads: head sizes of multi-head attention
        :param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
        :param dropout: dropout rate
        """

        super().__init__()
        self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
        self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
        self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, mask):
        x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask))
        x = self.output_sublayer(x, self.feed_forward)
        return self.dropout(x)

4. 参考

《NLP深入学习:大模型背后的Transformer模型究竟是什么?(一)》
《NLP深入学习:大模型背后的Transformer模型究竟是什么?(二)》

欢迎关注本人,我是喜欢搞事的程序猿; 一起进步,一起学习;

欢迎关注知乎/CSDN:SmallerFL;

也欢迎关注我的wx公众号(精选高质量文章):一个比特定乾坤

相关推荐
格林威44 分钟前
常规线扫描镜头有哪些类型?能做什么?
人工智能·深度学习·数码相机·算法·计算机视觉·视觉检测·工业镜头
lyx33136967591 小时前
#深度学习基础:神经网络基础与PyTorch
pytorch·深度学习·神经网络·参数初始化
倔强青铜三1 小时前
苦练Python第63天:零基础玩转TOML配置读写,tomllib模块实战
人工智能·python·面试
B站计算机毕业设计之家2 小时前
智慧交通项目:Python+YOLOv8 实时交通标志系统 深度学习实战(TT100K+PySide6 源码+文档)✅
人工智能·python·深度学习·yolo·计算机视觉·智慧交通·交通标志
高工智能汽车2 小时前
棱镜观察|极氪销量遇阻?千里智驾左手服务吉利、右手对标华为
人工智能·华为
txwtech2 小时前
第6篇 OpenCV RotatedRect如何判断矩形的角度
人工智能·opencv·计算机视觉
正牌强哥2 小时前
Futures_ML——机器学习在期货量化交易中的应用与实践
人工智能·python·机器学习·ai·交易·akshare
倔强青铜三2 小时前
苦练Python第62天:零基础玩转CSV文件读写,csv模块实战
人工智能·python·面试
大模型真好玩3 小时前
低代码Agent开发框架使用指南(二)—Coze平台核心功能概览
人工智能·coze·deepseek
jerryinwuhan3 小时前
最短路径问题总结
开发语言·人工智能·python