张量的形状操作函数概括
张量的形状变换操作函数
reshape()
squeeze()
unsqueeze()
transpose()
permute()
view()
contiguous()
需要掌握的函数
reshape()、unsqueeze()、permute()、view()
reshape()
在不改变内容的前提下,对其形状做改变。
注意:转换后元素总的个数不能变
python
torch.random.manual_seed(10)
t1 = torch.randint(0,11,[2,3])
print(f"t1 = {t1}")
print(f"t1.shape = {t1.shape}")
t2 = t1.reshape(3,2)
print(f"t2 = {t2}")
print(f"t2.shape = {t2.shape}")

unsqueeze()
在指定的轴上增加一个(1)维度
python
t1 = torch.randint(0, 11, [2, 3])
t2 = t1.unsqueeze(0)
print(f"t2 = {t2}")
print(f"t2.shape = {t2.shape}")
t3 = t1.unsqueeze(1)
print(f"t3 = {t3}")
print(f"t3.shape = {t3.shape}")
t4 = t1.unsqueeze(2)
print(f"t4 = {t4}")
print(f"t4.shape = {t4.shape}")

squeeze()
删除所有为1的维度,等价于降维
python
t1 = torch.randint(0, 11, [2,1,3,1,1])
print(f"t1 = {t1}")
print(f"t1.shape = {t1.shape}")
t2 = t1.squeeze()
print(f"t2 = {t2}")
print(f"t2.shape = {t2.shape}")

transpose()和permute()
transpose() 一次只能交换2个维度
permute() 一次可以同时交换多个维度
python
t1 = torch.randint(0, 11, [2,3,4])
print(f"t1.shape = {t1.shape}")
t2 = t1.transpose(0,1)
print(f"t2.shape = {t2.shape}")
t3 = t1.permute(2,0,1)
print(f"t3.shape = {t3.shape}")

view()和contiguous()
view只修改连续的张量的形状(连续指的是内存的连续)
view可以改变原来的张量比如t1.view(),t1的形状也发生了改变
is_contiguous() 判断张量是否连续
contiguous() 将不连续的张量变成连续的
python
t1 = torch.randint(0, 11, [2,3])
t2 = t1.view(3,2)
print(f"t2.shape = {t2.shape}")
#通过transpose将张量变为不连续的
t1 = t1.transpose(1,0)
# print(f"t1.is_contiguous() = {t1.is_contiguous()}")
# t3 = t1.view(2,3)
# print(f"t3.shape = {t3.shape}")
#通过contiguous()变为连续的然后再转换
t1 = t1.contiguous()
print(f"t1.shape = {t1.shape}")
t4 = t1.view(2,3)
print(f"t4.shape = {t4.shape}")

张量的拼接
cat() 不改变维度数拼接张量,除了拼接的那个维度外其它的维度必须保持一致
stack() 会改变维度,拼接张量,所有的维度都必须保持一致
拼接张量可以是新维度,但是无论新旧维度,所有维度都必须保持一致
cat()
python
t1 = torch.randint(0,5,[3,4])
t2 = torch.randint(0,5,[2,4])
t3 = torch.cat([t1,t2],dim=0)
print(f"t3.shape = {t3.shape}")

stack()
python
t1 = torch.randint(0,5,[2,3])
t2 = torch.randint(0,5,[2,3])
t3 = torch.stack([t1,t2],dim=0)
print(f"t3.shape:{t3.shape}")
t4 = torch.stack([t1,t2],dim=1)
print(f"t4.shape:{t4.shape}")
t5 = torch.stack([t1,t2],dim=2)
print(f"t5.shape:{t5.shape}")
