目录
[1. 什么是 Reshape?](#1. 什么是 Reshape?)
[2. Reshape 的核心作用](#2. Reshape 的核心作用)
[① 数据适配](#① 数据适配)
[② 维度对齐](#② 维度对齐)
[③ 特征重组](#③ 特征重组)
[3. Reshape 的数学表示](#3. Reshape 的数学表示)
[7. Reshape 的注意事项](#7. Reshape 的注意事项)
[① 数据连续性](#① 数据连续性)
[② 维度顺序](#② 维度顺序)
[③ 性能优化](#③ 性能优化)
[二、sequenze 和 unequenze](#二、sequenze 和 unequenze)
[1. Sequence(序列)是什么?](#1. Sequence(序列)是什么?)
[2. Unsequence(非序列)是什么?](#2. Unsequence(非序列)是什么?)
[1. 填充(Padding)](#1. 填充(Padding))
[三、concat、stack、expand 和 flatten](#三、concat、stack、expand 和 flatten)
[1. concat(拼接)](#1. concat(拼接))
[2. stack(堆叠)](#2. stack(堆叠))
[3. expand(扩展)](#3. expand(扩展))
[4. flatten(展平)](#4. flatten(展平))
[5. 对比总结](#5. 对比总结)
一、reshape
返回一个具有与输入相同的数据和元素数量,但具有指定形状的张量。如果可能的话,返回的张量将是输入的视图。否则,它将是一个副本。连续的输入和具有兼容步幅的输入可以进行重塑而无需复制,但您不应依赖于复制与视图行为。
1. 什么是 Reshape?
- 核心功能:改变张量的维度(形状),但不改变其元素内容和存储顺序。
- 数学本质:通过重新排列索引,将原张量映射到新的形状空间。
2. Reshape 的核心作用
① 数据适配
- 将数据转换为模型输入要求的形状(如
[batch_size, channels, height, width]
)。 - 示例 :将
[100, 784]
(MNIST 图像展平)转换为[100, 1, 28, 28]
。
② 维度对齐
- 在矩阵乘法、卷积等操作中,确保输入张量的维度匹配。
- 示例 :将
[3, 5, 5]
转换为[3, 1, 5, 5]
以适配卷积层。
③ 特征重组
- 提取特定维度的特征(如将
[batch, height, width, channels]
转换为[batch, channels, height, width]
)。
3. Reshape 的数学表示
- 输入形状 :
(N, C_in, H_in, W_in)
- 输出形状 :
(N, C_out, H_out, W_out)
- 关键约束:N×Cin×Hin×Win=N×Cout×Hout×Wout即总元素数量必须保持不变。
4.代码示例
python
# 输入:[2, 3, 5, 5]
x = torch.randn(2, 3, 5, 5)
# Reshape to [2, 15, 5]
y = x.reshape(2, -1, 5)
print(y.shape) # torch.Size([2, 15, 5])
5.permute()方法
- 功能:重新排列张量的轴顺序(不改变元素值)。
python
# 输入:[batch=2, channels=3, height=5, width=5]
x = torch.randn(2, 3, 5, 5)
# 将 channels 和 height 交换
y = x.permute(0, 2, 1, 3) # 输出形状:[2, 5, 3, 5]
print(y.shape)
6.view()方法
- 功能:返回一个与原张量共享内存的新视图(需数据连续)。
python
import torch
# 输入:[batch=2, channels=3, height=5, width=5]
x = torch.randn(2, 3, 5, 5)
# Reshape to [2, 15, 5](3 * 5=15)
y = x.view(2, -1, 5) # -1 表示自动计算剩余维度
print(y.shape) # torch.Size([2, 15, 5])
7. Reshape 的注意事项
① 数据连续性
- view():要求原张量数据连续,否则会报错。
- reshape():允许非连续数据,但会复制内存,可能影响性能。
② 维度顺序
- 使用
permute()
时需明确指定轴顺序,避免逻辑错误。
③ 性能优化
- 尽量使用
view()
而非reshape()
以复用内存。
二、sequenze 和 unequenze
1. Sequence(序列)是什么?
- 定义:按顺序排列的数据,每个元素之间存在时间或逻辑上的依赖关系。
- 常见场景 :
- 自然语言处理(NLP):句子、单词序列。
- 时间序列分析:股票价格、传感器数据。
- 语音识别:音频信号帧序列。
- 数学形式:X=[x1,x2,...,xT],其中 T 是序列长度。
2. Unsequence(非序列)是什么?
- 定义:无顺序依赖的数据,元素之间是独立或空间相关的。
- 常见场景 :
- 图像分类:二维像素矩阵。
- 无监督聚类:客户分群、文档分类。
- 图神经网络(GNN):节点间无固定顺序的图结构。
3. 序列 vs 非序列的核心差异
维度 | 序列 | 非序列 |
---|---|---|
数据依赖 | 时间/逻辑顺序敏感 | 无顺序依赖 |
典型任务 | 文本生成、语音识别、时间序列预测 | 图像分类、目标检测、聚类 |
常用模型 | RNN、LSTM、Transformer | CNN、GCN、全连接层 |
输入形状 | [batch, T, ...] (T为序列长度) |
[batch, C, H, W] (C为通道数) |
4. 序列数据的处理方法
1. 填充(Padding)
- 目的:将不同长度的序列统一到相同长度。
python
import torch.nn.utils.rnn as rnn_utils
# 输入序列:batch=2, 最大长度=5
sequences = [
torch.randn(3), # 序列1(长度3)
torch.randn(5) # 序列2(长度5)
]
# 填充到长度5,用0填充
padded = rnn_utils.pad_sequence(sequences, batch_first=True)
print(padded.shape) # torch.Size([2, 5, ...])
2.打包(Packing)
- 目的:仅保留有效数据,忽略填充部分,提升计算效率。
python
# 输入序列和长度掩码
lengths = [3, 5]
packed = rnn_utils.pack_padded_sequence(sequences, lengths, batch_first=True)
# 解包输出
output, output_lengths = rnn_utils.unpack_packed_sequence(packed)
三、concat、stack、expand 和 flatten
1. concat(拼接)
功能
沿指定维度将多个张量连接成一个更大的张量,不改变原有维度。
import torch
# 定义两个二维张量
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# 沿第0维(行方向)拼接
concatenated = torch.cat([a, b], dim=0)
print(concatenated)
# 输出:
# tensor([[1, 2],
# [3, 4],
# [5, 6],
# [7, 8]])
关键点
- 输入张量的其他维度必须一致。
- 结果形状:
(N+M, ...)
, 其中N
和M
是拼接张量的大小。
2. stack(堆叠)
功能
沿新维度将多个张量堆叠成更高维度的张量,新增一个维度。
python
import torch
a = torch.tensor([[1, 2], [3, 4]])
b = torch.tensor([[5, 6], [7, 8]])
# 沿新维度(第1维)堆叠
stacked = torch.stack([a, b], dim=1)
print(stacked)
# 输出:
# tensor([[[1, 2],
# [3, 4]],
# [[5, 6],
# [7, 8]]])
关键点
- 所有输入张量的形状必须完全相同。
- 结果形状:
(K, ...,)
,其中K
是堆叠的张量数量。
3. expand(扩展)
功能
通过广播机制,将张量在指定维度上重复元素,不复制数据(仅创建视图)。
python
import torch
# 原始张量:[1, 2]
x = torch.tensor([1, 2])
# 在第0维扩展2倍,得到 [1, 2, 1, 2]
expanded = x.expand(2, -1)
print(expanded) # tensor([1, 2, 1, 2])
# 在第1维扩展3倍,得到 [[1,1,1], [2,2,2]]
expanded_2d = x.unsqueeze(1).expand(-1, 3)
print(expanded_2d)
# tensor([[1, 1, 1],
# [2, 2, 2]])
关键点
expand
的参数需满足:new_dim_size >= original_dim_size
。- 需先通过
unsqueeze
创建新维度才能扩展。
4. flatten(展平)
功能
将多维张量压缩为一维或指定维度的连续数组,忽略其他维度。
python
import torch
# 原始张量:[2, 3, 4]
x = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
# 展平为一维
flattened = x.flatten()
print(flattened)
# 输出:
# tensor([1, 2, 3, 4, 5, 6, 7, 8])
# 展平到指定维度(保留第0维,合并后两维)
flattened_2d = x.flatten(start_dim=1)
print(flattened_2d)
# 输出:
# tensor([[1, 2, 3, 4],
# [5, 6, 7, 8]])
关键点
start_dim
指定从哪个维度开始展平,默认为0
。- 展平后张量的总元素数不变。
5. 对比总结
操作 | 核心功能 | 是否改变维度 | 内存消耗 | 典型场景 |
---|---|---|---|---|
concat | 沿指定维度拼接张量 | 否(保持原有维度) | 低(共享数据) | 数据合并(如特征拼接) |
stack | 新增维度堆叠张量 | 是(维度+1) | 中(复制数据) | 多模型输出堆叠(如图像分割) |
expand | 广播机制扩展元素 | 可能改变维度 | 极低(仅视图) | 扩展特征图尺寸(如上采样) |
flatten | 压缩多维张量为低维 | 是(降维) | 低(共享数据) | 全连接层输入适配 |
四、pointwise
Tensor 中逐元素进行的操作,也叫element wise 操作,大部分的activation 算子以及 add、sub、mul、div、sqrt 等都属于pointwise 类别。操作和numpy数组差不多
五、split和slice
将张量分割成多个块。每个块都是原始张量的视图。
python
import torch
# 创建一个示例张量
tensor = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
# 对张量进行切片
slice_tensor = tensor[2:7] # 从索引2到索引6(不包含7)
print(slice_tensor) # 输出: tensor([3, 4, 5, 6, 7])
# 使用步长对张量进行切片
step_slice_tensor = tensor[1:9:2] # 从索引1到索引8(不包含9),步长为2
print(step_slice_tensor) # 输出: tensor([2, 4, 6, 8])
# 省略起始索引和结束索引来选择整个张量
full_tensor = tensor[:]
print(full_tensor) # 输出: tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])