多头注意力中的张量重塑

view

PyTorch 的view() 是张量「重塑(Reshape)」函数,用于改变张量的维度形状但不改变数据本身

在多头注意力中,view()的核心作用是将总隐藏维度拆分为「注意力头数 × 单头维度」,实现多头并行计算

核心规则

python 复制代码
tensor.view(*shape)

作用:将张量重塑为指定的shape,要求「新形状的元素总数 = 原张量的元素总数」(否则报错)

核心特性:

不改变张量的底层数据,仅改变维度的 "视图"(轻量操作,无数据拷贝)

重塑后的张量与原张量共享内存(修改一个,另一个也会变)

支持用-1自动推导某一维度的大小(仅能有一个-1)

python 复制代码
import torch

# 一维张量重塑为二维
x = torch.arange(12)  # shape=(12,),元素总数=12
x_view1 = x.view(3, 4)  # shape=(3,4),3×4=12
x_view2 = x.view(4, -1)  # -1自动推导为3,shape=(4,3)
print(x_view1.shape)  # torch.Size([3,4])
print(x_view2.shape)  # torch.Size([4,3])

# 三维张量重塑(核心:元素总数不变)
y = torch.randn(2, 6, 768)  # 2×6×768=9216
y_view = y.view(2, 6, 12, 64)  # 2×6×12×64=9216
print(y_view.shape)  # torch.Size([2,6,12,64])
  1. 关键注意事项
    报错场景:新形状元素总数≠原总数 → x.view(3,5)(12≠15)会报错
    -1的用法:仅能指定一个-1,用于自动计算维度(如view(2,-1,64))
    内存连续性:若张量内存不连续(如经过transpose/permute),需先调用contiguous()再view,否则报错

多头注意力中view的核心作用

将总隐藏维度d_model拆分为num_heads × d_k(单头维度),view()是实现这一拆分的关键

完整流程如下:

步骤 1:先明确核心参数(以 BERT-base 为例)

batch_size=2(批次)、seq_len=6(序列长度);

d_model=768(总隐藏维度)、num_heads=12(注意力头数)、d_k=64(单头维度,768=12×64);

输入querys:shape=(2,6,768)(经W_query线性变换后的输出)。

步骤 2:用view()拆分注意力头

python 复制代码
# 1. 原始querys:[batch, seq_len, d_model] = [2,6,768]
querys = torch.randn(2, 6, 768)

# 2. 拆分为多头:[2,6,12,64](batch, seq_len, num_heads, d_k)
querys_heads = querys.view(2, 6, 12, 64)
print(querys_heads.shape)  # torch.Size([2,6,12,64])

# 3. 转置调整维度(为后续批量矩阵乘法):[2,12,6,64]
# 注:transpose后内存不连续,需contiguous()才能再view
querys_heads = querys_heads.transpose(1, 2).contiguous()
print(querys_heads.shape)  # torch.Size([2,12,6,64])

步骤 3:注意力计算后,用view()合并多头

python 复制代码
# 假设注意力计算后的输出:[2,12,6,64](batch, num_heads, seq_len, d_k)
attn_output = torch.randn(2, 12, 6, 64)

# 1. 先转置回原维度:[2,6,12,64]
attn_output = attn_output.transpose(1, 2).contiguous()

# 2. 合并多头:[2,6,768](还原为总隐藏维度)
attn_output_merged = attn_output.view(2, 6, 768)
print(attn_output_merged.shape)  # torch.Size([2,6,768])

多头注意力完整实战代码

python 复制代码
import torch
import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model=768, num_heads=12):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads  # 64,用//保证整除
        
        # 定义Q/K/V线性层
        self.W_query = nn.Linear(d_model, d_model)
        self.W_key = nn.Linear(d_model, d_model)
        self.W_value = nn.Linear(d_model, d_model)
        
    def forward(self, x):
        # x: [batch_size, seq_len, d_model] = [2,6,768]
        batch_size, seq_len = x.shape[0], x.shape[1]
        
        # 1. 线性变换:Q/K/V均为[2,6,768]
        Q = self.W_query(x)
        K = self.W_key(x)
        V = self.W_value(x)
        
        # 2. 拆分为多头:[2,6,12,64]
        Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k)
        K = K.view(batch_size, seq_len, self.num_heads, self.d_k)
        V = V.view(batch_size, seq_len, self.num_heads, self.d_k)
        
        # 3. 转置:[2,12,6,64](batch, num_heads, seq_len, d_k)
        # 必须contiguous(),否则后续view会报错
        Q = Q.transpose(1, 2).contiguous()
        K = K.transpose(1, 2).contiguous()
        V = V.transpose(1, 2).contiguous()
        
        # 4. 计算注意力分数:Q @ K^T → [2,12,6,6]
        K_T = K.transpose(2, 3)  # [2,12,64,6]
        attn_scores = Q @ K_T  # [2,12,6,6]
        
        # 5. softmax归一化(省略,核心看view)
        attn_weights = torch.softmax(attn_scores, dim=-1)
        
        # 6. 加权求和:[2,12,6,64]
        attn_output = attn_weights @ V
        
        # 7. 转置+合并多头:[2,6,768]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, seq_len, self.d_model)
        
        return attn_output

测试代码

py 复制代码
mha = MultiHeadAttention(d_model=768, num_heads=12)
x = torch.randn(2, 6, 768)
output = mha(x)
print(output.shape)  # 输出:torch.Size([2,6,768])

view vs reshape

新手常混淆view和reshape,二者均用于重塑张量,但核心差异如下:

特性 view() reshape()

内存共享 与原张量共享内存(无拷贝) 优先共享内存,不连续则拷贝新内存

内存连续性 要求张量内存连续(否则报错) 自动处理内存不连续,无需contiguous()

适用场景 内存连续的张量(如线性层输出) 内存不连续的张量(如 transpose 后)

大模型开发建议

若确定张量内存连续(如线性层输出、原始输入),用view()(更高效)

若张量经过transpose/permute(如多头注意力中的转置),用reshape()(无需手动contiguous())

permute:重排,置换

示例:

python 复制代码
# 替代:transpose后直接reshape,更简洁
Q = Q.transpose(1, 2).reshape(batch_size, self.num_heads, seq_len, self.d_k)

总结

view()核心作用:改变张量维度形状,不改变数据,要求元素总数不变,支持-1自动推导维度

多头注意力中view()的核心用法

拆分:将[batch, seq_len, d_model]拆为[batch, seq_len, num_heads, d_k]

合并:注意力计算后,将[batch, seq_len, num_heads, d_k]合并回[batch, seq_len, d_model]

关键注意:transpose/permute后需contiguous()才能用view(),或直接用reshape()更便捷

contiguous

contiguous()用于将「内存不连续」的张量转换为「内存连续」的张量,保证张量的元素在内存中按维度顺序紧密排列

是view()等操作的前置必要条件

张量在计算机内存中是一维线性存储的

"连续" 指的是:张量的元素在内存中的排列顺序,和按「维度顺序(如从 0 维到最后一维)」遍历张量得到的顺序完全一致

直观示例(二维张量)

假设有张量x = torch.tensor([[1,2,3], [4,5,6]])(shape=(2,3)):

连续内存布局:内存中存储顺序是 1 → 2 → 3 → 4 → 5 → 6(按 "行优先" 遍历,先遍历 0 维,再遍历 1 维)

若对x做转置x.T,得到[[1,4], [2,5], [3,6]]

转置后的张量逻辑上是 3 行 2 列,但内存中仍存储为1→2→3→4→5→6(PyTorch 的transpose/permute仅修改 "维度视图",不拷贝数据)

此时按转置后的维度遍历(行优先),期望顺序是1→4→2→5→3→6,但内存实际顺序不符 → 转置后的张量是内存不连续的

为什么张量会变得 "不连续"?

PyTorch 中以下操作会导致张量内存不连续(核心是 "只改视图,不改内存"):

  • 维度变换类:transpose()、permute()(最常见,如多头注意力中的维度交换)
  • 索引 / 切片类:非连续切片(如x[:, ::2])、高级索引
  • 其他操作:narrow()、expand()(部分场景)
    这些操作的设计初衷是 "轻量"------ 避免不必要的数据拷贝,提升效率,但代价是破坏了内存连续性

contiguous()的核心作用

contiguous()会创建一个新的内存连续的张量:

  • 新张量与原张量数据相同,但内存排列会按照 "当前维度顺序" 重新整理
  • 新张量与原张量不再共享内存(是数据拷贝操作)
  • 只有内存连续的张量,才能调用view()(view()要求张量元素在内存中是连续的,否则无法正确重塑维度)

实战示例(结合多头注意力的经典场景)

python 复制代码
import torch

# 1. 创建连续张量
x = torch.randn(2, 6, 768)  # shape=(2,6,768),内存连续
print(x.is_contiguous())  # 输出:True

# 2. 转置后内存不连续
x_trans = x.transpose(1, 2)  # 交换1、2维,shape=(2,768,6)
print(x_trans.is_contiguous())  # 输出:False

# 3. 直接调用view()会报错
try:
    x_trans.view(2, 768, 12, 5)  # 768×6=12×60?不,768×6=4608=12×384,这里故意错,核心看报错
except Exception as e:
    print("报错:", e)  # 报错:view size is not compatible with input tensor's size and stride...

# 4. 先contiguous()再view(),正常运行
x_contig = x_trans.contiguous()
print(x_contig.is_contiguous())  # 输出:True
x_view = x_contig.view(2, 768, 12, 64)  # 2×768×12×64=2×768×768=1179648,和2×768×6=9216?哦,修正:x_trans.shape=(2,768,6),总元素=2×768×6=9216;view为2,768,12,0.5?重新来:
x_view = x_contig.view(2, 768, 12, 0.5)  # 故意错,实际应保证总元素一致:
x_view = x_contig.view(2, 768, 12, 0.5) → 正确示例:
x_contig = x_trans.contiguous()
x_view = x_contig.view(2, 12, 64, 6)  # 2×12×64×6=2×768×6=9216,总元素一致
print(x_view.shape)  # 输出:torch.Size([2, 12, 64, 6])

大模型开发中的核心应用场景(必掌握)

contiguous()几乎只在「transpose()/permute() + view()」的组合中使用,尤其是多头注意力层:

python 复制代码
# 多头注意力中拆分注意力头的标准流程
Q = torch.randn(2, 6, 768)  # [batch, seq_len, d_model]
Q = Q.view(2, 6, 12, 64)    # 拆分为多头:[2,6,12,64]
Q = Q.transpose(1, 2)       # 交换维度:[2,12,6,64] → 内存不连续
Q = Q.contiguous()          # 转为连续内存

后续可安全调用view()(若需要)

py 复制代码
Q = Q.view(2, 12, 6, 64)    # 正常运行

关键注意事项

contiguous()是数据拷贝操作:会消耗内存和时间,非必要时不要调用(比如仅做矩阵乘法,无需连续内存)

替代方案:reshape()会自动处理内存连续性(优先共享内存,不连续则自动拷贝),因此:

若只需重塑维度,用reshape()替代contiguous()+view()更简洁

示例:Q.transpose(1,2).reshape(2,12,6,64)(无需手动contiguous())

判断是否连续:用tensor.is_contiguous()快速检查,返回True则为连续

总结

contiguous()的核心:将内存不连续的张量转为连续,保证view()等操作能正常执行

触发场景:张量经过transpose()/permute()后,若要调用view(),必须先contiguous()

大模型实战建议:优先用reshape()替代contiguous()+view(),减少代码量且更安全

相关推荐
寻星探路6 小时前
【深度长文】万字攻克网络原理:从 HTTP 报文解构到 HTTPS 终极加密逻辑
java·开发语言·网络·python·http·ai·https
聆风吟º8 小时前
CANN runtime 全链路拆解:AI 异构计算运行时的任务管理与功能适配技术路径
人工智能·深度学习·神经网络·cann
User_芊芊君子8 小时前
CANN大模型推理加速引擎ascend-transformer-boost深度解析:毫秒级响应的Transformer优化方案
人工智能·深度学习·transformer
ValhallaCoder8 小时前
hot100-二叉树I
数据结构·python·算法·二叉树
智驱力人工智能9 小时前
小区高空抛物AI实时预警方案 筑牢社区头顶安全的实践 高空抛物检测 高空抛物监控安装教程 高空抛物误报率优化方案 高空抛物监控案例分享
人工智能·深度学习·opencv·算法·安全·yolo·边缘计算
人工不智能5779 小时前
拆解 BERT:Output 中的 Hidden States 到底藏了什么秘密?
人工智能·深度学习·bert
猫头虎9 小时前
如何排查并解决项目启动时报错Error encountered while processing: java.io.IOException: closed 的问题
java·开发语言·jvm·spring boot·python·开源·maven
h64648564h9 小时前
CANN 性能剖析与调优全指南:从 Profiling 到 Kernel 级优化
人工智能·深度学习
心疼你的一切9 小时前
解密CANN仓库:AIGC的算力底座、关键应用与API实战解析
数据仓库·深度学习·aigc·cann
八零后琐话10 小时前
干货:程序员必备性能分析工具——Arthas火焰图
开发语言·python