深度学习:transpose_qkv()与transpose_output()

transpose_qkv 函数的主要作用是将输入的张量重新排列,使其适合多头注意力的计算。具体来说,它将输入张量的形状从 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)

详细步骤

  • 输入形状

    假设输入的张量形状为 (batch_size, seq_len, num_hiddens),其中:

    batch_size 是批次大小。

    seq_len 是序列长度。

    num_hiddens 是隐藏层的维度。

  • 拆分多头

    多头注意力机制将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens // num_heads。

  • 重新排列

    通过重新排列张量的维度,将 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)。

具体实现

假设 transpose_qkv 函数的实现如下:

csharp 复制代码
def transpose_qkv(X, num_heads):
    # X: (batch_size, seq_len, num_hiddens)
    batch_size, seq_len, num_hiddens = X.shape
    num_hiddens_per_head = num_hiddens // num_heads
    
    # 将 num_hiddens 维度拆分成 num_heads 个头
    X = X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 将 batch_size 和 num_heads 合并
    X = X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head):
      将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens_per_head。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 num_heads 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head):
      将 batch_size 和 num_heads 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size * num_heads, seq_len, num_hiddens_per_head)。

总结

transpose_qkv 函数通过以下步骤将输入张量重新排列,使其适合多头注意力的计算:

  • 将 num_hiddens 维度拆分成 num_heads 个头。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 batch_size 和 num_heads 维度,使得每个头的数据连续排列。

最终,transpose_qkv 函数返回形状为 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 的张量,以便进行多头注意力计算。

transpose_output 函数的主要作用是将多头注意力的输出重新排列,使其适合后续的处理。具体来说,它将输入张量的形状从 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 转换为 (batch_size, seq_len, num_hiddens)

具体实现

假设 transpose_output 函数的实现如下:

csharp 复制代码
def transpose_output(X, num_heads):
    # X: (batch_size * num_heads, seq_len, num_hiddens_per_head)
    batch_size_times_num_heads, seq_len, num_hiddens_per_head = X.shape
    batch_size = batch_size_times_num_heads // num_heads
    
    # 将 batch_size 和 num_heads 拆分
    X = X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 将 num_heads 和 num_hiddens_per_head 合并
    X = X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head):
      将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 seq_len 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head):
      将 num_heads 和 num_hiddens_per_head 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_hiddens)。

总结

transpose_output 函数通过以下步骤将多头注意力的输出重新排列,使其适合后续的处理:

  • 将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 num_heads 和 num_hiddens_per_head 维度,使得每个头的数据连续排列。

最终,transpose_output 函数返回形状为 (batch_size, seq_len, num_hiddens) 的张量,以便进行后续的处理。

相关推荐
sp_fyf_202421 分钟前
【大语言模型】ACL2024论文-13 透过分裂投票的镜头:探索法律案例结果分类中的分歧、难度和校准
人工智能·深度学习·机器学习·语言模型·自然语言处理·分类·数据挖掘
说私域1 小时前
营销手段的变革:开源 AI 智能名片与 S2B2C 商城小程序在新趋势下的机遇与挑战
人工智能·小程序
AI小白龙*3 小时前
AI工业大模型报告:体系架构、关键技术与典型应用
人工智能·语言模型·程序员·架构·llm·大语言模型·ai大模型
Roun34 小时前
区块链赋能Web3:数据透明与隐私保护的新纪元
人工智能·web3·去中心化·区块链
瞻邈5 小时前
PETR/PETRv2/StreamPETR论文阅读
论文阅读·深度学习
神洛华6 小时前
datawhale11月组队学习 模型压缩技术3:2:4结构稀疏化BERT模型
深度学习·算法·bert
Xlbb.6 小时前
安全见闻1-5
前端·网络·人工智能·安全·网络安全
sniper_fandc7 小时前
深度学习基础—Beam search集束搜索
人工智能·深度学习
阿万音玲羽7 小时前
李沐《动手学深度学习》kaggle树叶分类(ResNet18无预训练)python代码实现
人工智能·python·深度学习·机器学习·分类·cnn