Pytorch中expand()和repeat()函数使用详解和实战示例

在 PyTorch 中,expand()repeat() 都用于张量维度的扩展与复制,但它们的原理和内存使用方式不同,适用于不同场景。


1、 expand():广播扩展(不复制数据)

功能:

expand() 返回一个视图(view),通过 广播机制 将张量在指定维度"扩展",不复制内存

限制:

只能在原始维度为 1 的轴上扩展,不能创建新的维度,也不能在非 1 的维度扩展。


示例:

python 复制代码
import torch

x = torch.tensor([[1], [2], [3]])  # shape: (3, 1)

# 扩展第二维到4
y = x.expand(3, 4)

print("x:\n", x)
print("y:\n", y)

输出:

复制代码
x:
 tensor([[1],
         [2],
         [3]])
y:
 tensor([[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]])

注意:

虽然看起来 y 是复制的,但它没有真正复制数据,只是广播视图,不占用额外内存。


2、repeat():真实复制张量内容

功能:

repeat() 沿指定维度进行实际数据复制 ,得到一个新的张量。与 expand() 不同,它复制数据,占用更多内存,但灵活性更强。


示例:

python 复制代码
x = torch.tensor([[1], [2], [3]])  # shape: (3, 1)

# 在第0维重复1次,第1维重复4次
y = x.repeat(1, 4)  # shape: (3, 4)

print("x:\n", x)
print("y:\n", y)

输出:

复制代码
x:
 tensor([[1],
         [2],
         [3]])
y:
 tensor([[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]])

3、 对比总结

特性 expand() repeat()
是否复制内存 ❌ 否,仅创建视图 ✅ 是,真正复制数据
是否支持广播 ✅ 只能在维度为 1 的轴广播 ❌ 直接复制,不依赖维度是否为1
内存开销 小(共享内存) 大(复制数据)
灵活性 限制多,效率高 更灵活但效率低
常用于场景 batch 中广播参数、掩码构造等 构造 tile 模式张量、数据重复

4、 实战对比:

python 复制代码
x = torch.tensor([1, 2, 3])  # shape: (3,)

# reshape 成 (3,1) 才能 expand
x1 = x.view(3, 1).expand(3, 4)
x2 = x.view(3, 1).repeat(1, 4)

print("expand:\n", x1)
print("repeat:\n", x2)

输出:

复制代码
expand:
 tensor([[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]])
repeat:
 tensor([[1, 1, 1, 1],
         [2, 2, 2, 2],
         [3, 3, 3, 3]])

虽然结果相同,但 expand() 的效率更高(无数据复制)。

如果希望根据具体场景(如节省内存 vs 需要独立复制)选择操作,建议:

  • 广播 → expand()
  • 真正数据复制 → repeat()

5、实战示例

下面分别通过两个实际应用场景来展示 expand()repeat() 的用法:


1. BERT 中 attention mask 构造

BERT 中的 self-attention 通常需要一个 attention_mask,它的形状是 (batch_size, 1, 1, seq_len)(batch_size, 1, seq_len, seq_len),需要通过 expand() 扩展广播。


示例场景:

你有一个 batch 的文本序列,每个位置为 1 表示有效,0 表示 padding:

python 复制代码
import torch

# 假设一个 batch 的 attention mask (batch_size=2, seq_len=4)
mask = torch.tensor([[1, 1, 1, 0],
                     [1, 1, 0, 0]])  # shape: (2, 4)

我们需要将其变成 (2, 1, 1, 4) 用于 broadcasting:

python 复制代码
# 加维度:batch_size x 1 x 1 x seq_len
mask_expanded = mask.unsqueeze(1).unsqueeze(2)  # shape: (2,1,1,4)

# 假设 query length 也为 4,我们要 broadcast 到 (2, 1, 4, 4)
attn_mask = mask_expanded.expand(-1, 1, 4, -1)  # -1 表示保持原维度

print(attn_mask.shape)
print(attn_mask)

输出:
复制代码
torch.Size([2, 1, 4, 4])
tensor([[[[1, 1, 1, 0],
          [1, 1, 1, 0],
          [1, 1, 1, 0],
          [1, 1, 1, 0]]],

        [[[1, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 1, 0, 0]]]])

这就是 BERT 中构造多头注意力掩码时典型使用 expand() 的方式。


2. 图像 tile:使用 repeat() 扩展图像张量

例如我们有一个灰度图像 1x1x28x28(batch_size=1, channel=1),我们希望将这个图像 横向复制 2 次、纵向复制 3 次,形成一个大的拼接图像。


示例代码:
python 复制代码
import torch

# 创建一个伪图像:1个通道,28x28
img = torch.arange(28*28).reshape(1, 1, 28, 28).float()

# 纵向重复3次,横向重复2次
# repeat参数: batch, channel, height_repeat, width_repeat
tiled_img = img.repeat(1, 1, 3, 2)  # 形状: (1, 1, 84, 56)

print(tiled_img.shape)

输出:
复制代码
torch.Size([1, 1, 84, 56])

此操作实际复制了图像数据,每个像素在目标 tensor 中占有真实空间,可以用于拼接生成大图或训练 tile-based 图像模型。


3、 总结对比

场景 操作函数 原因
BERT attention mask expand() 避免内存复制,适合广播掩码
图像 tile 复制 repeat() 必须真实复制图像内容

6、 多头注意力中 mask 对多头扩展的写法(补充资料)

多头注意力(Multi-head Attention) 中,为了让不同的头共享相同的 attention_mask,我们通常需要将 mask 的 shape 从 (batch_size, seq_len) 扩展成 (batch_size, num_heads, seq_len, seq_len)

这个扩展操作通常组合使用 unsqueeze()expand()repeat(),根据具体实现选择是否复制内存。


1、场景设定

我们有:

  • batch size = B
  • sequence length = L
  • number of heads = H

输入的原始 attention mask:

python 复制代码
# 原始 padding mask:B × L
mask = torch.tensor([
    [1, 1, 1, 0],
    [1, 1, 0, 0]
])  # shape: (2, 4)

2、 多头注意力中扩展 mask 的方法

方法 1:使用 unsqueeze + expand不复制数据,更高效)
python 复制代码
B, L, H = 2, 4, 8  # batch, seq_len, num_heads

# 1. 原始 mask shape: (B, L)
# 2. 先加两个维度:B × 1 × 1 × L
mask = mask.unsqueeze(1).unsqueeze(2)  # shape: (2, 1, 1, 4)

# 3. 扩展到多头:B × H × L × L
mask = mask.expand(B, H, L, L)

print(mask.shape)

输出:

复制代码
torch.Size([2, 8, 4, 4])

这个扩展方式非常适合 BERT 等 Transformer 模型中的 attention_mask 构造,不会额外占用内存。


方法 2:使用 repeat复制数据

如果你希望生成独立的副本(比如 mask 后面会被修改),可以用 repeat

python 复制代码
# 先加两个维度:B × 1 × 1 × L
mask = mask.unsqueeze(1).unsqueeze(2)  # (2,1,1,4)

# 重复到多头:B × H × L × L
mask = mask.repeat(1, H, L, 1)  # batch维不变,头重复H次,query和key维重复

3、用在 Attention 权重前:

python 复制代码
# 假设 attention logits shape: (B, H, L, L)
attn_logits = torch.randn(B, H, L, L)

# mask == 0 的地方,我们不希望注意力流动,可设为 -inf(或 -1e9)
attn_logits = attn_logits.masked_fill(mask == 0, float('-inf'))

然后将 attn_logits 送入 softmax。


4、 总结:选择 expand or repeat

场景 推荐操作 原因
构造 attention mask(只读) expand() 高效、无数据复制
构造后需要修改 repeat() 每个位置是独立内存