文章目录
-
- MultiHeadAttentionFormal的实现
- 操作详解
-
- [1. 🔍 attention_mask](#1. 🔍 attention_mask)
- [2. 🔍 matmul](#2. 🔍 matmul)
-
- [✅ 其他实现方式](#✅ 其他实现方式)
-
- [1. 使用 `@` 运算符(推荐简洁写法)](#1. 使用
@
运算符(推荐简洁写法)) - [2. 使用 `torch.einsum()`(爱因斯坦求和约定)](#2. 使用
torch.einsum()
(爱因斯坦求和约定)) - [3. 使用 `torch.bmm()`(批量矩阵乘法)](#3. 使用
torch.bmm()
(批量矩阵乘法)) - [4. 使用 `unsqueeze` + `squeeze` 控制维度(兼容高维)](#4. 使用
unsqueeze
+squeeze
控制维度(兼容高维)) - [5. 使用 `F.linear()` 实现投影(不常用)](#5. 使用
F.linear()
实现投影(不常用))
- [1. 使用 `@` 运算符(推荐简洁写法)](#1. 使用
- [📌 对比总结表](#📌 对比总结表)
- [💡 示例对比(均等效)](#💡 示例对比(均等效))
- [3. 🔍 transpose](#3. 🔍 transpose)
-
- [📌 定义](#📌 定义)
- [🧠 在多头注意力中的典型应用场景](#🧠 在多头注意力中的典型应用场景)
- [✅ 其他实现方式](#✅ 其他实现方式)
-
- [1. 使用 `permute(*dims)` ------ 更灵活的维度重排](#1. 使用
permute(*dims)
—— 更灵活的维度重排) - [2. 使用 `swapaxes(dim0, dim1)` ------ 与 transpose 等效](#2. 使用
swapaxes(dim0, dim1)
—— 与 transpose 等效)
- [1. 使用 `permute(*dims)` ------ 更灵活的维度重排](#1. 使用
- [📌 总结对比表](#📌 总结对比表)
- [💡 示例说明](#💡 示例说明)
- [🛠 实际应用建议](#🛠 实际应用建议)
- [4. 🔍 view()](#4. 🔍 view())
-
- [🔄 其他等效实现方式](#🔄 其他等效实现方式)
-
- [1. `torch.reshape(tensor, shape)`](#1.
torch.reshape(tensor, shape)
) - [2. 使用 `flatten(start_dim, end_dim)` 合并维度](#2. 使用
flatten(start_dim, end_dim)
合并维度) - [3. 使用 `einops.rearrange`(推荐用于可读性)](#3. 使用
einops.rearrange
(推荐用于可读性))
- [1. `torch.reshape(tensor, shape)`](#1.
- [✅ 总结对比](#✅ 总结对比)
- [💡 实际应用建议](#💡 实际应用建议)
- [5. 🔍 masked_fill()](#5. 🔍 masked_fill())
-
- [🧠 函数定义](#🧠 函数定义)
- 示例解析
- [✅ 实际案例演示](#✅ 实际案例演示)
- [⚠️ 注意事项](#⚠️ 注意事项)
- [💡 应用场景](#💡 应用场景)
- [✅ 总结](#✅ 总结)
- [📌 最佳实践建议](#📌 最佳实践建议)
- 参考材料
MultiHeadAttentionFormal的实现
python
import torch
import torch.nn as nn
import math
class MultiHeadAttentionFormal(nn.Module):
def __init__(self, hidden_dim, head_num, attention_dropout=0.1):
super().__init__()
self.hidden_dim = hidden_dim
self.head_num = head_num
self.head_dim = hidden_dim // head_num # head_num * head_dim = hidden_dim
self.q_proj = nn.Linear(hidden_dim, hidden_dim) # (hidden_dim, head_dim * head_num)
self.k_proj = nn.Linear(hidden_dim, hidden_dim)
self.v_proj = nn.Linear(hidden_dim, hidden_dim)
self.output = nn.Linear(hidden_dim, hidden_dim)
self.attention_dropout = nn.Dropout(attention_dropout)
def forward(self, x, attention_mask=None):
# X (batch_size, seq_len, hidden_dim)
batch_size, seq_len, _ = x.shape
# Q/K/V的shape: (batch_size, seq_len, hidden_dim)
Q = self.q_proj(x)
K = self.k_proj(x)
V = self.v_proj(x)
# (batch_size, seq_len, hidden_dim),其中 hidden_dim = head_num * head_dim
# -> (batch_size, seq_len, head_num, head_dim)
# -> (batch_size, head_num, seq_len, head_dim)
q_state = Q.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
k_state = K.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
v_state = V.view(batch_size, seq_len, self.head_num, self.head_dim).transpose(1, 2)
# k_state的转置
# (batch_size, head_num, seq_len, head_dim)
# -> (batch_size, head_num, head_dim, seq_len)
# 相乘的结果,shape为(batch_size, head_num, seq_len, seq_len)
atten_weight = torch.matmul(q_state, k_state.transpose(-2, -1)) / math.sqrt(self.head_dim)
print("stage1, atten_weight.shape: ", atten_weight.shape)
if attention_mask is not None:
atten_weight = atten_weight.masked_fill(attention_mask==0, float("-inf"))
print("stage2, atten_weight.shape: ", atten_weight.shape)
atten_weight = torch.softmax(atten_weight, dim=-1)
print("stage3, atten_weight.shape: ", atten_weight.shape)
atten_weight = self.attention_dropout(atten_weight)
print("stage4, atten_weight.shape: ", atten_weight.shape)
# atten_weight: (batch_size, head_num, seq_len, seq_len)
# v_state: (batch_size, head_num, seq_len, head_dim)
# => (batch_size, head_num, seq_len, head_dim)
output_mid = torch.matmul(atten_weight, v_state)
print("stage1, output_mid.shape: ", output_mid.shape, "v_state.shape: ", v_state.shape)
# transpose后,张量的内存可能变得不连续,所以需要用contiguous把内存连续化;view()、reshape()、flatten()、torch.nn.Linear、torch.matmul 等操作对输入张量有连续性的要求。
output_mid = output_mid.transpose(1, 2).contiguous()
print("stage2, output_mid.shape: ", output_mid.shape)
output_mid = output_mid.view(batch_size, seq_len, self.hidden_dim)
print("stage3, output_mid.shape: ", output_mid.shape)
output = self.output(output_mid)
return output
attention_mask = torch.tensor(
[
[1,1],
[1,0],
[1,0]
]
).unsqueeze(1).unsqueeze(2).expand(3, 8, 2, 2)
# batch_size, seq_len, hidden_dim
X = torch.rand(3, 2, 128)
net = MultiHeadAttentionFormal(128, 8) # hidden_dim = 128, head_num = 8 -> head_dim = 16
net(X, attention_mask)
操作详解
1. 🔍 attention_mask
首先是创建一个随机张量,shape为(batch_size, seq_len)
python
attention_mask = torch.tensor([
[1, 1],
[1, 0],
[1, 0]
])
这是一个形状为 (3, 2) 的张量。
每一行表示一个样本(batch)的 attention mask:
1 表示该位置是有效的;
0 表示该位置是 padding,需要被屏蔽
然后增加维度
python
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
---------------------
tensor([[[[1, 1]]],
[[[1, 0]]],
[[[1, 0]]]])
第一次 unsqueeze(1):增加第 1 维度(head_num),形状变为 (3, 1, 2)
第二次 unsqueeze(2):增加第 2 维度(seq_len),形状变为 (3, 1, 1, 2)
此时维度含义为:(batch_size, 1, 1, seq_len)
注意:此时还没有考虑 head_num,只是准备好了 mask 的基本结构
现在扩展到head_num
py
attention_mask = attention_mask.expand(3, 8, 2, 2)
-------------
tensor([[[[1, 1], [1, 1]], # 头1
[[1, 1], [1, 1]], # 头2
[[1, 1], [1, 1]], # 头3
[[1, 1], [1, 1]], # 头4
[[1, 1], [1, 1]], # 头5
[[1, 1], [1, 1]], # 头6
[[1, 1], [1, 1]], # 头7
[[1, 1], [1, 1]], # 头8
]
[[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]]
],
[[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]],
[[1, 0], [1, 0]]
]
])
expand() 是 PyTorch 中用于广播张量的方法,不会复制数据,而是共享内存。
将 (3, 1, 1, 2) 扩展为 (3, 8, 2, 2):
3:batch size
8:head_num,每个 head 都使用相同的 mask
2:query 的序列长度(seq_len)
2:key/value 的序列长度(seq_len)
attention_mask.shape为(batch_size, head_num, seq_len, seq_len)
2. 🔍 matmul
以这行代码为例:
python
output_mid = torch.matmul(atten_weight, v_state)
其中:
atten_weight.shape = (batch_size, head_num, seq_len, seq_len)
,即注意力权重矩阵(通常是 softmax 后的结果)v_state.shape = (batch_size, head_num, seq_len, head_dim)
,即 value 的状态
这个操作本质上是将 attention weight 与 value 进行矩阵乘法,得到加权后的输出。
✅ 其他实现方式
1. 使用 @
运算符(推荐简洁写法)
python
output_mid = atten_weight @ v_state
- 等价于
torch.matmul
- 更加 Pythonic,代码更简洁
- 支持广播机制
2. 使用 torch.einsum()
(爱因斯坦求和约定)
python
output_mid = torch.einsum('bhij,bhjd->bhid', atten_weight, v_state)
- 非常灵活,适用于多头注意力、交叉注意力等复杂结构
- 显式控制每个维度的运算规则,可读性略差但表达能力更强
- 在调试或构建复杂模型时非常有用
3. 使用 torch.bmm()
(批量矩阵乘法)
python
# 将 batch 和 head 合并成一个大 batch 维度
batch_size, head_num, seq_len, _ = atten_weight.shape
atten_weight_flat = atten_weight.view(-1, seq_len, seq_len) # shape: (B*H, T, T)
v_state_flat = v_state.view(-1, seq_len, head_dim) # shape: (B*H, T, D)
output_flat = torch.bmm(atten_weight_flat, v_state_flat) # shape: (B*H, T, D)
output_mid = output_flat.view(batch_size, head_num, seq_len, head_dim)
- 只支持 3D 张量,不支持自动广播
- 性能接近
matmul
,但需要手动处理维度变形
4. 使用 unsqueeze
+ squeeze
控制维度(兼容高维)
python
output_mid = torch.matmul(
atten_weight.unsqueeze(-2), v_state.unsqueeze(-1)
).squeeze(-1)
- 通过添加/删除维度来精确控制 matmul 操作维度
- 适合在图像、视频等 attention 中使用
5. 使用 F.linear()
实现投影(不常用)
虽然不是标准做法,但如果 atten_weight
是某种投影权重矩阵,也可以用线性层模拟。但在 attention 中通常不适用。
📌 对比总结表
方法 | 输入要求 | 是否支持 batch | 是否支持 broadcasting | 推荐用于 Attention |
---|---|---|---|---|
torch.matmul |
任意维度 | ✅ | ✅ | ✅✅✅ |
@ |
任意维度 | ✅ | ✅ | ✅✅✅(简洁) |
torch.einsum |
需要指定索引 | ✅ | ✅ | ✅✅✅(多头) |
torch.bmm |
必须为 3D | ✅ | ❌ | ✅(简单 attention) |
unsqueeze + matmul |
手动控制维度 | ✅ | ✅ | ✅(特殊场景) |
💡 示例对比(均等效)
python
# 原始写法
output_mid = torch.matmul(atten_weight, v_state)
# 使用 @ 符号
output_mid = atten_weight @ v_state
# 使用 einsum
output_mid = torch.einsum('bhij,bhjd->bhid', atten_weight, v_state)
# 使用 bmm(需 flatten + reshape)
batch_size, head_num, seq_len, _ = atten_weight.shape
atten_weight_flat = atten_weight.view(-1, seq_len, seq_len)
v_state_flat = v_state.view(-1, seq_len, head_dim)
output_flat = torch.bmm(atten_weight_flat, v_state_flat)
output_mid = output_flat.view(batch_size, head_num, seq_len, -1)
3. 🔍 transpose
python
output_mid = output_mid.transpose(1, 2)
这行代码的作用是交换张量的第 1
维和第 2
维。用于处理多头注意力(Multi-Head Attention)中张量形状的调整。
📌 定义
python
torch.Tensor.transpose(dim0, dim1) -> Tensor
- 功能:返回一个新的张量,其中指定的两个维度被交换。
- 参数 :
dim0
: 第一个维度dim1
: 第二个维度
⚠️ 注意:这个操作不会复制数据,而是返回原始张量的一个视图(view)。如果后续需要使用 view()
或 reshape()
,可能需要调用 .contiguous()
来确保内存连续。
🧠 在多头注意力中的典型应用场景
python
# 假设 input shape: (batch_size, head_num, seq_len, head_dim)
output_mid = output_mid.transpose(1, 2)
原始形状:
python
output_mid.shape = (batch_size, head_num, seq_len, head_dim)
转置后形状:
python
output_mid.shape = (batch_size, seq_len, head_num, head_dim)
然后一般会进行 view()
操作来合并 head_num
和 head_dim
,得到最终输出:
python
output_mid = output_mid.contiguous().view(batch_size, seq_len, -1)
# 最终 shape: (batch_size, seq_len, hidden_dim)
这是将多头注意力结果重新拼接回原始隐藏层大小的关键步骤。
✅ 其他实现方式
除了使用 transpose()
,还有以下几种方法可以实现类似效果:
1. 使用 permute(*dims)
------ 更灵活的维度重排
python
output_mid = output_mid.permute(0, 2, 1, 3)
-
permute()
可以一次重排多个维度 -
示例前后的 shape 对应关系:
python# 原 shape: (batch_size, head_num, seq_len, head_dim) # 新 shape: (batch_size, seq_len, head_num, head_dim)
✅ 推荐用于更复杂的维度变换场景
2. 使用 swapaxes(dim0, dim1)
------ 与 transpose 等效
python
output_mid = output_mid.swapaxes(1, 2)
- 与
transpose()
功能相同 - 更语义化,适合阅读时强调"交换"而非"转置"
📌 总结对比表
方法 | 支持任意维 | 是否返回 view | 是否支持链式操作 | 推荐用途 |
---|---|---|---|---|
transpose() |
❌ 仅限两个维度 | ✅ | ✅ | 简单交换两个维度 |
permute() |
✅ 多维支持 | ✅ | ✅ | 高阶张量维度重排(推荐) |
swapaxes() |
✅ | ✅ | ✅ | 强调"交换",语义更强 |
💡 示例说明
假设输入为:
python
output_mid.shape = (3, 8, 2, 16) # batch_size=3, head_num=8, seq_len=2, head_dim=16
使用 transpose(1, 2)
:
python
output_mid = output_mid.transpose(1, 2)
# output_mid.shape
(batch_size, head_num, seq_len, head_dim)
=> (batch_size, seq_len, head_num, head_dim)
(3, 8, 2, 16)
=> (3, 2, 8, 16)
使用 permute(0, 2, 1, 3)
:
python
output_mid = output_mid.permute(0, 2, 1, 3)
# output_mid.shape => (3, 2, 8, 16)
两者等价,但 permute()
更具通用性。
🛠 实际应用建议
- 如果只是交换两个维度 →
transpose()
- 如果涉及多维重排 →
permute()
- 如果要合并/拆分某些维度 →
permute()
+contiguous()
+view()
4. 🔍 view()
在 PyTorch 中,view()
是一个用于 改变张量形状(reshape) 的函数。它不会修改张量的数据,只是重新解释其形状。
语法:
python
tensor.view(shape)
示例代码:
python
output_mid = output_mid.view(batch_size, seq_len, self.hidden_dim)
前提条件:
output_mid
当前的 shape 是(batch_size, seq_len, head_num, head_dim)
head_num * head_dim == hidden_dim
- 所以 view 后变为
(batch_size, seq_len, hidden_dim)
作用:
将多头注意力中每个 head 的输出拼接起来,恢复成原始的 hidden_dim
维度。
比如:
python
# 假设 batch_size=3, seq_len=2, head_num=8, head_dim=16
output_mid.shape = (3, 8, 2, 16) # transpose + contiguous 后
output_mid = output_mid.view(3, 2, 128) # 8*16 = 128
⚠️ 注意:使用
view()
前必须保证张量是连续的(contiguous),否则会报错。所以前面通常有.contiguous()
调用。
🔄 其他等效实现方式
除了 view()
,还有以下几种方式可以实现类似功能:
1. torch.reshape(tensor, shape)
与 view()
类似,但更灵活,可以在非连续内存上运行。
python
output_mid = output_mid.reshape(batch_size, seq_len, self.hidden_dim)
✅ 推荐使用这个替代 view()
,因为不需要关心是否是连续内存。
2. 使用 flatten(start_dim, end_dim)
合并维度
python
output_mid = output_mid.transpose(1, 2).flatten(start_dim=2, end_dim=3)
这相当于把第 2 和第 3 维合并,效果等同于 reshape 或 view。
3. 使用 einops.rearrange
(推荐用于可读性)
来自 einops
库(einop库安装及介绍),提供更直观的维度操作方式:
python
from einops import rearrange
output_mid = rearrange(output_mid, 'b h s d -> b s (h d)')
优点:
- 更易读
- 不需要关心是否连续
- 可扩展性强(支持更多复杂变换)
✅ 总结对比
方法 | 是否要求连续 | 易读性 | 灵活性 | 推荐场景 |
---|---|---|---|---|
view() |
❌ 必须连续 | ⬇️ 差 | ⬇️ 一般 | 小规模调试 |
reshape() |
✅ 不要求 | ⬆️ 好 | ⬆️ 强 | 通用替换 view |
flatten() |
✅ 不要求 | ⬆️ 好 | ⬆️ 强 | 多维合并 |
einops.rearrange() |
✅ 不要求 | ⬆️ 很好 | ⬆️ 非常强 | 工程项目 |
💡 实际应用建议
如果你在写正式项目或模型工程化,推荐使用:
python
from einops import rearrange
output_mid = rearrange(output_mid, 'b h s d -> b s (h d)')
或者安全版本(不依赖连续内存):
python
output_mid = output_mid.transpose(1, 2)
output_mid = output_mid.flatten(2) # (b, s, h*d)
这样不仅代码清晰,也避免了对 .contiguous()
的依赖问题。
5. 🔍 masked_fill()
在 PyTorch 中,masked_fill()
是一个非常常用的函数,用于 根据布尔掩码(mask)对张量的某些位置进行填充。它常用于 NLP 任务中,比如 Transformer 模型中的 attention mask 处理。
🧠 函数定义
python
torch.Tensor.masked_fill(mask, value)
参数说明:
mask
: 一个布尔类型的张量(True/False),形状必须与原张量相同。value
: 要填充的值,可以是标量或广播兼容的张量。
行为:
- 对于
mask
中为True
的位置,将原张量对应位置的值替换为value
。 False
的位置保持不变。
示例解析
python
atten_weight = atten_weight.masked_fill(attention_mask == 0, float("-inf"))
解释:
attention_mask == 0
:- 这是一个布尔操作,生成一个和
attention_mask
形状相同的布尔张量。 - 所有等于 0 的位置变成
True
,表示这些位置是 pad 或无效 token,不应该参与 attention 计算。
- 这是一个布尔操作,生成一个和
float("-inf")
:- 将这些被 mask 的位置填入负无穷大。
- 在后续 softmax 中,
exp(-inf)
会变成 0,从而实现"忽略这些位置"的效果。
✅ 实际案例演示
输入示例:
python
import torch
# 原始 attention 权重 (模拟)
atten_weight = torch.tensor([
[0.1, 0.2, 0.3, 0.4],
[0.5, 0.6, 0.7, 0.8]
])
# attention mask (pad 位置为 0)
attention_mask = torch.tensor([
[1, 1, 0, 0],
[1, 0, 0, 0]
])
# 应用 masked_fill
atten_weight = atten_weight.masked_fill(attention_mask == 0, float("-inf"))
print(atten_weight)
输出结果:
text
tensor([[ 0.1000, 0.2000, -inf, -inf],
[ 0.5000, -inf, -inf, -inf]])
后续 softmax 结果:
python
import torch.nn.functional as F
F.softmax(atten_weight, dim=-1)
输出:
text
tensor([[0.4621, 0.5379, 0.0000, 0.0000],
[1.0000, 0.0000, 0.0000, 0.0000]])
可以看到,mask 为 0 的位置在 softmax 后变成了 0,不会影响最终注意力分布。
⚠️ 注意事项
-
mask 张量的 shape 必须与目标张量一致:
-
如果你有一个
(batch_size, seq_len)
的 mask,而atten_weight
是(batch_size, head_num, seq_len, seq_len)
,你需要通过unsqueeze
和expand
调整 mask 的维度。 -
示例:
pythonattention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # -> (batch_size, 1, 1, seq_len) attention_mask = attention_mask.expand(batch_size, num_heads, seq_len, seq_len)
-
-
不能直接使用 int 类型的 mask:
masked_fill
只接受布尔类型作为 mask,所以要确保使用了比较操作如==
,!=
等。
💡 应用场景
场景 | 描述 |
---|---|
padding mask | 防止模型关注到 padding 的 token |
look-ahead mask | 防止 decoder 在预测时看到未来 token |
自定义屏蔽机制 | 如屏蔽某些特定词、句子结构等 |
✅ 总结
方法 | 作用 | 推荐指数 |
---|---|---|
masked_fill(mask == 0, -inf) |
屏蔽不需要关注的位置 | ⭐⭐⭐⭐⭐ |
F.softmax(..., dim=-1) |
使屏蔽位置变为 0 | ⭐⭐⭐⭐ |
mask 维度适配 | 使用 unsqueeze + expand 调整 mask 到与 attn weight 相同 |
⭐⭐⭐⭐⭐ |
📌 最佳实践建议
python
# 假设 attention_mask: (batch_size, seq_len)
# attn_weights: (batch_size, num_heads, seq_len_q, seq_len_k)
# Step 1: 添加两个维度,使其匹配 attn_weights 的 shape
attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) # -> (B, 1, 1, S)
# Step 2: 扩展 mask 使得其与 attn_weights 形状完全一致
attention_mask = attention_mask.expand_as(attn_weights) # -> same shape as attn_weights
# Step 3: 应用 mask,填入 -inf
attn_weights = attn_weights.masked_fill(attention_mask == 0, float('-inf'))
这样就能保证每个 head 和 query 的位置都能正确屏蔽掉 pad 或无效 token。