深度学习:transpose_qkv()与transpose_output()

transpose_qkv 函数的主要作用是将输入的张量重新排列,使其适合多头注意力的计算。具体来说,它将输入张量的形状从 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)

详细步骤

  • 输入形状

    假设输入的张量形状为 (batch_size, seq_len, num_hiddens),其中:

    batch_size 是批次大小。

    seq_len 是序列长度。

    num_hiddens 是隐藏层的维度。

  • 拆分多头

    多头注意力机制将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens // num_heads。

  • 重新排列

    通过重新排列张量的维度,将 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)。

具体实现

假设 transpose_qkv 函数的实现如下:

csharp 复制代码
def transpose_qkv(X, num_heads):
    # X: (batch_size, seq_len, num_hiddens)
    batch_size, seq_len, num_hiddens = X.shape
    num_hiddens_per_head = num_hiddens // num_heads
    
    # 将 num_hiddens 维度拆分成 num_heads 个头
    X = X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 将 batch_size 和 num_heads 合并
    X = X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head):
      将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens_per_head。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 num_heads 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head):
      将 batch_size 和 num_heads 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size * num_heads, seq_len, num_hiddens_per_head)。

总结

transpose_qkv 函数通过以下步骤将输入张量重新排列,使其适合多头注意力的计算:

  • 将 num_hiddens 维度拆分成 num_heads 个头。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 batch_size 和 num_heads 维度,使得每个头的数据连续排列。

最终,transpose_qkv 函数返回形状为 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 的张量,以便进行多头注意力计算。

transpose_output 函数的主要作用是将多头注意力的输出重新排列,使其适合后续的处理。具体来说,它将输入张量的形状从 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 转换为 (batch_size, seq_len, num_hiddens)

具体实现

假设 transpose_output 函数的实现如下:

csharp 复制代码
def transpose_output(X, num_heads):
    # X: (batch_size * num_heads, seq_len, num_hiddens_per_head)
    batch_size_times_num_heads, seq_len, num_hiddens_per_head = X.shape
    batch_size = batch_size_times_num_heads // num_heads
    
    # 将 batch_size 和 num_heads 拆分
    X = X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 将 num_heads 和 num_hiddens_per_head 合并
    X = X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head):
      将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 seq_len 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head):
      将 num_heads 和 num_hiddens_per_head 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_hiddens)。

总结

transpose_output 函数通过以下步骤将多头注意力的输出重新排列,使其适合后续的处理:

  • 将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 num_heads 和 num_hiddens_per_head 维度,使得每个头的数据连续排列。

最终,transpose_output 函数返回形状为 (batch_size, seq_len, num_hiddens) 的张量,以便进行后续的处理。

相关推荐
qq_416276429 分钟前
LOFAR物理频谱特征提取及实现
人工智能
Python图像识别38 分钟前
71_基于深度学习的布料瑕疵检测识别系统(yolo11、yolov8、yolov5+UI界面+Python项目源码+模型+标注好的数据集)
python·深度学习·yolo
余俊晖39 分钟前
如何构造一个文档解析的多模态大模型?MinerU2.5架构、数据、训练方法
人工智能·文档解析
Akamai中国2 小时前
Linebreak赋能实时化企业转型:专业系统集成商携手Akamai以实时智能革新企业运营
人工智能·云计算·云服务
LiJieNiub3 小时前
读懂目标检测:从基础概念到主流算法
人工智能·计算机视觉·目标跟踪
哥布林学者3 小时前
吴恩达深度学习课程一:神经网络和深度学习 第三周:浅层神经网络(二)
深度学习·ai
weixin_519535773 小时前
从ChatGPT到新质生产力:一份数据驱动的AI研究方向指南
人工智能·深度学习·机器学习·ai·chatgpt·数据分析·aigc
爱喝白开水a4 小时前
LangChain 基础系列之 Prompt 工程详解:从设计原理到实战模板_langchain prompt
开发语言·数据库·人工智能·python·langchain·prompt·知识图谱
takashi_void4 小时前
如何在本地部署大语言模型(Windows,Mac,Linux)三系统教程
linux·人工智能·windows·macos·语言模型·nlp
OpenCSG4 小时前
【活动预告】2025斗拱开发者大会,共探支付与AI未来
人工智能·ai·开源·大模型·支付安全