view
PyTorch 的view() 是张量「重塑(Reshape)」函数,用于改变张量的维度形状但不改变数据本身
在多头注意力中,view()的核心作用是将总隐藏维度拆分为「注意力头数 × 单头维度」,实现多头并行计算
核心规则
python
tensor.view(*shape)
作用:将张量重塑为指定的shape,要求「新形状的元素总数 = 原张量的元素总数」(否则报错)
核心特性:
不改变张量的底层数据,仅改变维度的 "视图"(轻量操作,无数据拷贝)
重塑后的张量与原张量共享内存(修改一个,另一个也会变)
支持用-1自动推导某一维度的大小(仅能有一个-1)
python
import torch
# 一维张量重塑为二维
x = torch.arange(12) # shape=(12,),元素总数=12
x_view1 = x.view(3, 4) # shape=(3,4),3×4=12
x_view2 = x.view(4, -1) # -1自动推导为3,shape=(4,3)
print(x_view1.shape) # torch.Size([3,4])
print(x_view2.shape) # torch.Size([4,3])
# 三维张量重塑(核心:元素总数不变)
y = torch.randn(2, 6, 768) # 2×6×768=9216
y_view = y.view(2, 6, 12, 64) # 2×6×12×64=9216
print(y_view.shape) # torch.Size([2,6,12,64])
- 关键注意事项
报错场景:新形状元素总数≠原总数 → x.view(3,5)(12≠15)会报错
-1的用法:仅能指定一个-1,用于自动计算维度(如view(2,-1,64))
内存连续性:若张量内存不连续(如经过transpose/permute),需先调用contiguous()再view,否则报错
多头注意力中view的核心作用
将总隐藏维度d_model拆分为num_heads × d_k(单头维度),view()是实现这一拆分的关键
完整流程如下:
步骤 1:先明确核心参数(以 BERT-base 为例)
batch_size=2(批次)、seq_len=6(序列长度);
d_model=768(总隐藏维度)、num_heads=12(注意力头数)、d_k=64(单头维度,768=12×64);
输入querys:shape=(2,6,768)(经W_query线性变换后的输出)。
步骤 2:用view()拆分注意力头
python
# 1. 原始querys:[batch, seq_len, d_model] = [2,6,768]
querys = torch.randn(2, 6, 768)
# 2. 拆分为多头:[2,6,12,64](batch, seq_len, num_heads, d_k)
querys_heads = querys.view(2, 6, 12, 64)
print(querys_heads.shape) # torch.Size([2,6,12,64])
# 3. 转置调整维度(为后续批量矩阵乘法):[2,12,6,64]
# 注:transpose后内存不连续,需contiguous()才能再view
querys_heads = querys_heads.transpose(1, 2).contiguous()
print(querys_heads.shape) # torch.Size([2,12,6,64])
步骤 3:注意力计算后,用view()合并多头
python
# 假设注意力计算后的输出:[2,12,6,64](batch, num_heads, seq_len, d_k)
attn_output = torch.randn(2, 12, 6, 64)
# 1. 先转置回原维度:[2,6,12,64]
attn_output = attn_output.transpose(1, 2).contiguous()
# 2. 合并多头:[2,6,768](还原为总隐藏维度)
attn_output_merged = attn_output.view(2, 6, 768)
print(attn_output_merged.shape) # torch.Size([2,6,768])
多头注意力完整实战代码
python
import torch
import torch.nn as nn
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=768, num_heads=12):
super().__init__()
self.d_model = d_model
self.num_heads = num_heads
self.d_k = d_model // num_heads # 64,用//保证整除
# 定义Q/K/V线性层
self.W_query = nn.Linear(d_model, d_model)
self.W_key = nn.Linear(d_model, d_model)
self.W_value = nn.Linear(d_model, d_model)
def forward(self, x):
# x: [batch_size, seq_len, d_model] = [2,6,768]
batch_size, seq_len = x.shape[0], x.shape[1]
# 1. 线性变换:Q/K/V均为[2,6,768]
Q = self.W_query(x)
K = self.W_key(x)
V = self.W_value(x)
# 2. 拆分为多头:[2,6,12,64]
Q = Q.view(batch_size, seq_len, self.num_heads, self.d_k)
K = K.view(batch_size, seq_len, self.num_heads, self.d_k)
V = V.view(batch_size, seq_len, self.num_heads, self.d_k)
# 3. 转置:[2,12,6,64](batch, num_heads, seq_len, d_k)
# 必须contiguous(),否则后续view会报错
Q = Q.transpose(1, 2).contiguous()
K = K.transpose(1, 2).contiguous()
V = V.transpose(1, 2).contiguous()
# 4. 计算注意力分数:Q @ K^T → [2,12,6,6]
K_T = K.transpose(2, 3) # [2,12,64,6]
attn_scores = Q @ K_T # [2,12,6,6]
# 5. softmax归一化(省略,核心看view)
attn_weights = torch.softmax(attn_scores, dim=-1)
# 6. 加权求和:[2,12,6,64]
attn_output = attn_weights @ V
# 7. 转置+合并多头:[2,6,768]
attn_output = attn_output.transpose(1, 2).contiguous()
attn_output = attn_output.view(batch_size, seq_len, self.d_model)
return attn_output
测试代码
py
mha = MultiHeadAttention(d_model=768, num_heads=12)
x = torch.randn(2, 6, 768)
output = mha(x)
print(output.shape) # 输出:torch.Size([2,6,768])
view vs reshape
新手常混淆view和reshape,二者均用于重塑张量,但核心差异如下:
特性 view() reshape()
内存共享 与原张量共享内存(无拷贝) 优先共享内存,不连续则拷贝新内存
内存连续性 要求张量内存连续(否则报错) 自动处理内存不连续,无需contiguous()
适用场景 内存连续的张量(如线性层输出) 内存不连续的张量(如 transpose 后)
大模型开发建议
若确定张量内存连续(如线性层输出、原始输入),用view()(更高效)
若张量经过transpose/permute(如多头注意力中的转置),用reshape()(无需手动contiguous())
permute:重排,置换
示例:
python
# 替代:transpose后直接reshape,更简洁
Q = Q.transpose(1, 2).reshape(batch_size, self.num_heads, seq_len, self.d_k)
总结
view()核心作用:改变张量维度形状,不改变数据,要求元素总数不变,支持-1自动推导维度
多头注意力中view()的核心用法
拆分:将[batch, seq_len, d_model]拆为[batch, seq_len, num_heads, d_k]
合并:注意力计算后,将[batch, seq_len, num_heads, d_k]合并回[batch, seq_len, d_model]
关键注意:transpose/permute后需contiguous()才能用view(),或直接用reshape()更便捷
contiguous
contiguous()用于将「内存不连续」的张量转换为「内存连续」的张量,保证张量的元素在内存中按维度顺序紧密排列
是view()等操作的前置必要条件
张量在计算机内存中是一维线性存储的
"连续" 指的是:张量的元素在内存中的排列顺序,和按「维度顺序(如从 0 维到最后一维)」遍历张量得到的顺序完全一致
直观示例(二维张量)
假设有张量x = torch.tensor([[1,2,3], [4,5,6]])(shape=(2,3)):
连续内存布局:内存中存储顺序是 1 → 2 → 3 → 4 → 5 → 6(按 "行优先" 遍历,先遍历 0 维,再遍历 1 维)
若对x做转置x.T,得到[[1,4], [2,5], [3,6]]
转置后的张量逻辑上是 3 行 2 列,但内存中仍存储为1→2→3→4→5→6(PyTorch 的transpose/permute仅修改 "维度视图",不拷贝数据)
此时按转置后的维度遍历(行优先),期望顺序是1→4→2→5→3→6,但内存实际顺序不符 → 转置后的张量是内存不连续的
为什么张量会变得 "不连续"?
PyTorch 中以下操作会导致张量内存不连续(核心是 "只改视图,不改内存"):
- 维度变换类:transpose()、permute()(最常见,如多头注意力中的维度交换)
- 索引 / 切片类:非连续切片(如x[:, ::2])、高级索引
- 其他操作:narrow()、expand()(部分场景)
这些操作的设计初衷是 "轻量"------ 避免不必要的数据拷贝,提升效率,但代价是破坏了内存连续性
contiguous()的核心作用
contiguous()会创建一个新的内存连续的张量:
- 新张量与原张量数据相同,但内存排列会按照 "当前维度顺序" 重新整理
- 新张量与原张量不再共享内存(是数据拷贝操作)
- 只有内存连续的张量,才能调用view()(view()要求张量元素在内存中是连续的,否则无法正确重塑维度)
实战示例(结合多头注意力的经典场景)
python
import torch
# 1. 创建连续张量
x = torch.randn(2, 6, 768) # shape=(2,6,768),内存连续
print(x.is_contiguous()) # 输出:True
# 2. 转置后内存不连续
x_trans = x.transpose(1, 2) # 交换1、2维,shape=(2,768,6)
print(x_trans.is_contiguous()) # 输出:False
# 3. 直接调用view()会报错
try:
x_trans.view(2, 768, 12, 5) # 768×6=12×60?不,768×6=4608=12×384,这里故意错,核心看报错
except Exception as e:
print("报错:", e) # 报错:view size is not compatible with input tensor's size and stride...
# 4. 先contiguous()再view(),正常运行
x_contig = x_trans.contiguous()
print(x_contig.is_contiguous()) # 输出:True
x_view = x_contig.view(2, 768, 12, 64) # 2×768×12×64=2×768×768=1179648,和2×768×6=9216?哦,修正:x_trans.shape=(2,768,6),总元素=2×768×6=9216;view为2,768,12,0.5?重新来:
x_view = x_contig.view(2, 768, 12, 0.5) # 故意错,实际应保证总元素一致:
x_view = x_contig.view(2, 768, 12, 0.5) → 正确示例:
x_contig = x_trans.contiguous()
x_view = x_contig.view(2, 12, 64, 6) # 2×12×64×6=2×768×6=9216,总元素一致
print(x_view.shape) # 输出:torch.Size([2, 12, 64, 6])
大模型开发中的核心应用场景(必掌握)
contiguous()几乎只在「transpose()/permute() + view()」的组合中使用,尤其是多头注意力层:
python
# 多头注意力中拆分注意力头的标准流程
Q = torch.randn(2, 6, 768) # [batch, seq_len, d_model]
Q = Q.view(2, 6, 12, 64) # 拆分为多头:[2,6,12,64]
Q = Q.transpose(1, 2) # 交换维度:[2,12,6,64] → 内存不连续
Q = Q.contiguous() # 转为连续内存
后续可安全调用view()(若需要)
py
Q = Q.view(2, 12, 6, 64) # 正常运行
关键注意事项
contiguous()是数据拷贝操作:会消耗内存和时间,非必要时不要调用(比如仅做矩阵乘法,无需连续内存)
替代方案:reshape()会自动处理内存连续性(优先共享内存,不连续则自动拷贝),因此:
若只需重塑维度,用reshape()替代contiguous()+view()更简洁
示例:Q.transpose(1,2).reshape(2,12,6,64)(无需手动contiguous())
判断是否连续:用tensor.is_contiguous()快速检查,返回True则为连续
总结
contiguous()的核心:将内存不连续的张量转为连续,保证view()等操作能正常执行
触发场景:张量经过transpose()/permute()后,若要调用view(),必须先contiguous()
大模型实战建议:优先用reshape()替代contiguous()+view(),减少代码量且更安全