深度学习:transpose_qkv()与transpose_output()

transpose_qkv 函数的主要作用是将输入的张量重新排列,使其适合多头注意力的计算。具体来说,它将输入张量的形状从 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)

详细步骤

  • 输入形状

    假设输入的张量形状为 (batch_size, seq_len, num_hiddens),其中:

    batch_size 是批次大小。

    seq_len 是序列长度。

    num_hiddens 是隐藏层的维度。

  • 拆分多头

    多头注意力机制将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens // num_heads。

  • 重新排列

    通过重新排列张量的维度,将 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)。

具体实现

假设 transpose_qkv 函数的实现如下:

csharp 复制代码
def transpose_qkv(X, num_heads):
    # X: (batch_size, seq_len, num_hiddens)
    batch_size, seq_len, num_hiddens = X.shape
    num_hiddens_per_head = num_hiddens // num_heads
    
    # 将 num_hiddens 维度拆分成 num_heads 个头
    X = X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 将 batch_size 和 num_heads 合并
    X = X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head):
      将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens_per_head。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 num_heads 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head):
      将 batch_size 和 num_heads 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size * num_heads, seq_len, num_hiddens_per_head)。

总结

transpose_qkv 函数通过以下步骤将输入张量重新排列,使其适合多头注意力的计算:

  • 将 num_hiddens 维度拆分成 num_heads 个头。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 batch_size 和 num_heads 维度,使得每个头的数据连续排列。

最终,transpose_qkv 函数返回形状为 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 的张量,以便进行多头注意力计算。

transpose_output 函数的主要作用是将多头注意力的输出重新排列,使其适合后续的处理。具体来说,它将输入张量的形状从 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 转换为 (batch_size, seq_len, num_hiddens)

具体实现

假设 transpose_output 函数的实现如下:

csharp 复制代码
def transpose_output(X, num_heads):
    # X: (batch_size * num_heads, seq_len, num_hiddens_per_head)
    batch_size_times_num_heads, seq_len, num_hiddens_per_head = X.shape
    batch_size = batch_size_times_num_heads // num_heads
    
    # 将 batch_size 和 num_heads 拆分
    X = X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 将 num_heads 和 num_hiddens_per_head 合并
    X = X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head):
      将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 seq_len 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head):
      将 num_heads 和 num_hiddens_per_head 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_hiddens)。

总结

transpose_output 函数通过以下步骤将多头注意力的输出重新排列,使其适合后续的处理:

  • 将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 num_heads 和 num_hiddens_per_head 维度,使得每个头的数据连续排列。

最终,transpose_output 函数返回形状为 (batch_size, seq_len, num_hiddens) 的张量,以便进行后续的处理。

相关推荐
yvestine13 分钟前
自然语言处理——文本表示
人工智能·python·算法·自然语言处理·文本表示
zzc92121 分钟前
MATLAB仿真生成无线通信网络拓扑推理数据集
开发语言·网络·数据库·人工智能·python·深度学习·matlab
点赋科技22 分钟前
沙市区举办资本市场赋能培训会 点赋科技分享智能消费新实践
大数据·人工智能
HeteroCat29 分钟前
一周年工作总结:做了一年的AI工作我都干了什么?
人工智能
编程有点难36 分钟前
Python训练打卡Day43
开发语言·python·深度学习
2301_8050545643 分钟前
Python训练营打卡Day48(2025.6.8)
pytorch·python·深度学习
YSGZJJ43 分钟前
股指期货技术分析与短线操作方法介绍
大数据·人工智能
Guheyunyi1 小时前
监测预警系统重塑隧道安全新范式
大数据·运维·人工智能·科技·安全
码码哈哈爱分享1 小时前
[特殊字符] Whisper 模型介绍(OpenAI 语音识别系统)
人工智能·whisper·语音识别