深度学习:transpose_qkv()与transpose_output()

transpose_qkv 函数的主要作用是将输入的张量重新排列,使其适合多头注意力的计算。具体来说,它将输入张量的形状从 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)

详细步骤

  • 输入形状

    假设输入的张量形状为 (batch_size, seq_len, num_hiddens),其中:

    batch_size 是批次大小。

    seq_len 是序列长度。

    num_hiddens 是隐藏层的维度。

  • 拆分多头

    多头注意力机制将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens // num_heads。

  • 重新排列

    通过重新排列张量的维度,将 (batch_size, seq_len, num_hiddens) 转换为 (batch_size * num_heads, seq_len, num_hiddens // num_heads)。

具体实现

假设 transpose_qkv 函数的实现如下:

csharp 复制代码
def transpose_qkv(X, num_heads):
    # X: (batch_size, seq_len, num_hiddens)
    batch_size, seq_len, num_hiddens = X.shape
    num_hiddens_per_head = num_hiddens // num_heads
    
    # 将 num_hiddens 维度拆分成 num_heads 个头
    X = X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 将 batch_size 和 num_heads 合并
    X = X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, seq_len, num_heads, num_hiddens_per_head):
      将 num_hiddens 维度拆分成 num_heads 个头,每个头的维度为 num_hiddens_per_head。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 num_heads 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size * num_heads, seq_len, num_hiddens_per_head):
      将 batch_size 和 num_heads 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size * num_heads, seq_len, num_hiddens_per_head)。

总结

transpose_qkv 函数通过以下步骤将输入张量重新排列,使其适合多头注意力的计算:

  • 将 num_hiddens 维度拆分成 num_heads 个头。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 batch_size 和 num_heads 维度,使得每个头的数据连续排列。

最终,transpose_qkv 函数返回形状为 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 的张量,以便进行多头注意力计算。

transpose_output 函数的主要作用是将多头注意力的输出重新排列,使其适合后续的处理。具体来说,它将输入张量的形状从 (batch_size * num_heads, seq_len, num_hiddens // num_heads) 转换为 (batch_size, seq_len, num_hiddens)

具体实现

假设 transpose_output 函数的实现如下:

csharp 复制代码
def transpose_output(X, num_heads):
    # X: (batch_size * num_heads, seq_len, num_hiddens_per_head)
    batch_size_times_num_heads, seq_len, num_hiddens_per_head = X.shape
    batch_size = batch_size_times_num_heads // num_heads
    
    # 将 batch_size 和 num_heads 拆分
    X = X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head)
    
    # 交换维度,使得每个头的数据连续排列
    X = X.permute(0, 2, 1, 3)  # (batch_size, seq_len, num_heads, num_hiddens_per_head)
    
    # 将 num_heads 和 num_hiddens_per_head 合并
    X = X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head)
    
    return X
  • 解释
    1. 拆分维度:
      X.reshape(batch_size, num_heads, seq_len, num_hiddens_per_head):
      将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。
      此时,X 的形状为 (batch_size, num_heads, seq_len, num_hiddens_per_head)。
    2. 交换维度:
      X.permute(0, 2, 1, 3):
      将 seq_len 维度移到第二个位置,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_heads, num_hiddens_per_head)。
    3. 合并维度:
      X.reshape(batch_size, seq_len, num_heads * num_hiddens_per_head):
      将 num_heads 和 num_hiddens_per_head 合并,使得每个头的数据连续排列。
      此时,X 的形状为 (batch_size, seq_len, num_hiddens)。

总结

transpose_output 函数通过以下步骤将多头注意力的输出重新排列,使其适合后续的处理:

  • 将 batch_size * num_heads 维度拆分成 batch_size 和 num_heads。

  • 交换维度,使得每个头的数据连续排列。

  • 合并 num_heads 和 num_hiddens_per_head 维度,使得每个头的数据连续排列。

最终,transpose_output 函数返回形状为 (batch_size, seq_len, num_hiddens) 的张量,以便进行后续的处理。

相关推荐
启途AI1 分钟前
2026免费好用的AIPPT工具榜:智能演示文稿制作新纪元
人工智能·powerpoint·ppt
TH_18 分钟前
35、AI自动化技术与职业变革探讨
运维·人工智能·自动化
楚来客19 分钟前
AI基础概念之八:Transformer算法通俗解析
人工智能·算法·transformer
风送雨19 分钟前
FastMCP 2.0 服务端开发教学文档(下)
服务器·前端·网络·人工智能·python·ai
效率客栈老秦35 分钟前
Python Trae提示词开发实战(8):数据采集与清洗一体化方案让效率提升10倍
人工智能·python·ai·提示词·trae
小和尚同志40 分钟前
虽然 V0 很强大,但是ScreenshotToCode 依旧有市场
人工智能·aigc
HyperAI超神经44 分钟前
【vLLM 学习】Rlhf
人工智能·深度学习·学习·机器学习·vllm
芯盾时代1 小时前
石油化工行业网络风险解决方案
网络·人工智能·信息安全
线束线缆组件品替网1 小时前
Weidmüller 工业以太网线缆技术与兼容策略解析
网络·人工智能·电脑·硬件工程·材料工程
lambo mercy1 小时前
深度学习3:新冠病毒感染人数预测
人工智能·深度学习