文章目录
- [1 view()函数](#1 view()函数)
-
- [1.1 基本用法](#1.1 基本用法)
- [2 view_as()函数](#2 view_as()函数)
- [3 reshape()函数](#3 reshape()函数)
- [4 permute()函数](#4 permute()函数)
- [5 transpose() 函数](#5 transpose() 函数)
- [6 squeeze()函数 和 unsqueeze()函数](#6 squeeze()函数 和 unsqueeze()函数)
1 view()函数
1.1 基本用法
view是将一个张量改变形状
函数原型
torch.Tensor.view(*shape) → Tensor
其中参数shape 可以是一个整数元组,或者是一个 系列整数
示例:两种不同参数比较
python
import torch
# 创建一个3x4的张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
# 参数用整数元组,变形为2x6的张量
y = x.view((2, 6))
print(y)
# Output:
# tensor([[ 1, 2, 3, 4, 5, 6],
# [ 7, 8, 9, 10, 11, 12]])
# 参数用系列整数值,将其变形为1x12的张量
z = x.view(1, 12)
print(z)
# Output:
# tensor([[ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])
要求 变化后元素总的个数和变化前相同,对应上面的例子,变化前后都得是12个元素~(3,4)(2,6)(1,12)
注意上面的x,y,z都是共享底层内存的,怎么理解呢?x,y,z本质还是一个东西,y,z并不是x的副本
就是只要改变x,y,z中的其中一个,其他的张量都会受到影响改变
比如如下
python
# 修改视图中的元素,原始张量也会受到影响
y[0, 0] = 99
print(x)
# Output:
# tensor([[99, 2, 3, 4],
# [ 5, 6, 7, 8],
# [ 9, 10, 11, 12]])
print(y)
# Output:
# tensor([[99, 2, 3, 4, 5, 6],
# [ 7, 8, 9, 10, 11, 12]])
print(z)
# Output:
# tensor([[99, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]])
还有一种常见的写法 -1,意味着这个值不是固定的,取决于其他维度,保证乘积不变即可
比如如下
python
import torch
# 创建一个3x4的张量
x = torch.tensor([[1, 2, 3, 4],
[5, 6, 7, 8],
[9, 10, 11, 12]])
y=x.view(-1,4)
print(y.shape)
z=x.view(6,-1)
print(z.shape)
输出
torch.Size([3, 4])
torch.Size([6, 2])
2 view_as()函数
view 函数,不需要指定形状,只需要指定要保持的那个对应的张量即可
view_as 可以将一个张量的形状改为与另外一个张量相同
python
import torch
# 创建一个形状为(4, 2)的张量
x = torch.randn(4, 2)
print(x.shape) # 输出:torch.Size([4, 2])
# 创建一个形状为(2, 2, 2)的张量
y = torch.randn(2, 2, 2)
print(y.shape) # 输出:torch.Size([2, 2, 2])
# 使用view_as方法将x的形状改变为与y相同的形状
x = x.view_as(y)
print(x.shape) # 输出:torch.Size([2, 2, 2])
注意区别
view
方法需要你明确指定新的形状。例如,如果你有一个形状为(4, 2)的张量,你可以使用view(2, 4)
来将其形状改变为(2, 4)
而view_as
方法则需要一个目标张量,它会将原始张量的形状改变为与目标张量相同的形状。
相同点是生成的 新的张量和原来张量都是共享底层内存的
3 reshape()函数
reshape使用整体和view差不多
reshape和view,大概率情况下会共享底层内存,但是在不连续的张量情况下(不连续发生在切片或者转置的时候),这时候会建立新的副本,这时候必须用reshape
例子
python
import torch
# 创建一个不连续的张量
tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])
strided_tensor = tensor[:, ::2] # 通过切片操作创建不连续的张量
# 使用 reshape 函数
reshaped_tensor =strided_tensor.reshape(4,1)
reshaped_tensor[0, 0] = 0
print("Original Tensor:", tensor)
print("Strided Tensor:", strided_tensor)
print("Reshaped Tensor:", reshaped_tensor)
输出
Original Tensor: tensor([[1, 2, 3],
[4, 5, 6]])
Strided Tensor: tensor([[1, 3],
[4, 6]])
Reshaped Tensor: tensor([[0],
[3],
[4],
[6]])
如果这时候继续用view
会报如下错
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
4 permute()函数
这段代码是在使用PyTorch处理张量时对张量的维度进行重新排序的操作。
X.permute(1,0,2)
是将X
的维度进行重新排序。permute
方法接收一组维度索引,然后按照这个索引的顺序重新排列张量的维度。
例如,如果X
是一个三维张量,其维度是(batch_size, seq_length, feature_size)
,那么X.permute(1,0,2)
将返回一个张量(也是共享底层内存)其维度是(seq_length, batch_size, feature_size)
。
这种操作在处理序列数据时非常常见,因为某些模型(如RNN、LSTM、GRU等)在输入数据时,需要序列长度(seq_length)在前,批量大小(batch_size)在后。所以,我们通常会使用permute
或者transpose
方法来调整维度的顺序。
示例如下
python
import torch
# 创建一个3维张量
x = torch.randn(2, 3, 4)
# 使用permute进行维度的置换
y = x.permute(1, 0, 2)
print(y.size()) # Output: torch.Size([3, 2, 4])
# 可以使用负数表示从最后一个维度开始的相对索引
z = x.permute(2, 1, 0)
print(z.size()) # Output: torch.Size([4, 3, 2])
5 transpose() 函数
在PyTorch中,transpose()
函数用于交换张量(tensor)的维度。该函数返回一个新的张量,其维度顺序是原始张量维度的重新排列。
函数签名如下:
python
torch.transpose(input, dim0, dim1) -> Tensor
input
: 输入的张量。dim0
: 第一个维度的索引。dim1
: 第二个维度的索引。
这个函数将input
张量的dim0
和dim1
两个维度进行交换。例如,如果input
张量的形状是(a, b, c)
,并且你使用transpose(input, 0, 1)
,则返回的张量的形状将是(b, a, c)
,即交换了第一个和第二个维度。
以下是一个简单的示例:
python
import torch
# 创建一个3x4的张量
x = torch.rand((3, 4))
# 使用transpose函数交换维度
y = torch.transpose(x, 0, 1)
print("原始张量:", x)
print("交换维度后的张量:", y)
请注意,transpose()
函数生成的张量并不会和原始张量共享内存,并不会修改原始张量,而是返回一个新的张量副本。如果你希望在原地操作(修改原始张量),可以使用transpose_()
方法:
python
# 在原地操作,修改原始张量
x.transpose_(0, 1)
print("原地操作后的张量:", x)
这里的下划线表示原地操作。
6 squeeze()函数 和 unsqueeze()函数
在PyTorch中,squeeze
和unsqueeze
是用于操作张量形状的函数,用于增加或减少维度。
-
squeeze
函数:torch.squeeze(input, dim=None, out=None)
函数用于删除张量中维度为1的轴。如果指定了dim
参数,则只会在指定轴上删除大小为1的维度,否则会删除所有大小为1的维度。- 参数:
input
: 输入的张量。dim
(可选): 要挤压的维度,如果指定,则只删除指定维度上的大小为1的轴。out
(可选): 输出张量,如果指定,则将结果存储在此张量中。
- 返回值:挤压后的张量。
示例:
pythonimport torch x = torch.randn(1, 3, 1, 4) y = torch.squeeze(x) # 在所有大小为1的维度上进行挤压 z = torch.squeeze(x, dim=2) # 只在维度2上挤压大小为1的轴 print(x.shape) # 输出: torch.Size([1, 3, 1, 4]) print(y.shape) # 输出: torch.Size([3, 4]) print(z.shape) # 输出: torch.Size([1, 3, 4])
-
unsqueeze
函数:torch.unsqueeze(input, dim)
函数用于在张量的指定位置插入维度为1的轴。- 参数:
input
: 输入的张量。dim
: 插入维度为1的轴的位置。
- 返回值:插入维度为1的轴后的张量。
示例:
pythonimport torch x = torch.randn(3, 4) y = torch.unsqueeze(x, dim=1) # 在维度1上插入大小为1的轴 print(x.shape) # 输出: torch.Size([3, 4]) print(y.shape) # 输出: torch.Size([3, 1, 4])
增加维度还可以通过None的方式增加
python
import torch
# 二维张量
x = torch.tensor([[1, 2, 3], [4, 5, 6]])
# 在第一个轴上增加维度
y = x[:, None, :]
# 或者使用 torch.unsqueeze
# y = torch.unsqueeze(x, dim=1)
print(x.shape) # 输出: torch.Size([2, 3])
print(y.shape) # 输出: torch.Size([2, 1, 3])