Pytorch基础:torch.expand() 和 torch.repeat()

在torch中,如果要改变某一个tensor的维度,可以利用viewexpandrepeattransposepermute等方法,这里对这些方法的一些容易混淆的地方做个总结。

expand和repeat函数是pytorch中常用于进行张量数据复制维度扩展的函数,但其工作机制差别很大,本文对这两个函数进行对比。

1. torch.expand()

  • 作用: expand()函数可以将张量广播到新的形状。
  • 注意: 只能对维度值为1的维度进行扩展无需扩展的维度,维度值不变,对应位置可写上原始维度大小或直接写作-1;且扩展的Tensor不会分配新的内存,只是原来的基础上创建新的视图并返回,返回的张量内存不连续的。类似于numpy中的broadcast_to函数的作用。如果希望张量内存连续,可以调用contiguous函数。

expand函数用于将张量中单数维的数据扩展到指定的size。

首先解释下什么叫单数维(singleton dimensions),张量在某个维度上的size为1,则称为单数维。比如zeros(2,3,4)不存在单数维,而zeros(2,1,4)在第二个维度(即维度1)上为单数维。expand函数仅仅能作用于这些单数维的维度上

参数*sizes用于逐个指定各个维度扩展后的大小(也可以理解为拓展的次数),对于不需要或者无法(即非单数维)进行扩展的维度,对应位置可写上原始维度大小或直接写作-1

expand函数可能导致原始张量的升维,其作用在张量前面的维度上(在tensor的低维增加更多维度),因此通过expand函数可将张量数据复制多份(可理解为沿着第一个batch的维度上)。

py 复制代码
import torch
 
a = torch.tensor([1, 0, 2])     # a -> torch.Size([3])
b1 = a.expand(2, -1)            # 第一个维度为升维,第二个维度保持原样
'''
b1为 -> torch.Size([3, 2])
tensor([[1, 0, 2],
        [1, 0, 2]])
'''
 
a = torch.tensor([[1], [0], [2]])   # a -> torch.Size([3, 1])
b2 = a.expand(-1, 2)                 # 保持第一个维度,第二个维度只有一个元素,可扩展
'''
b2 -> torch.Size([3, 2])
b2为  tensor([[1, 1],
             [0, 0],
             [2, 2]])
'''
 
a = torch.Tensor([[1, 2, 3]])   # a -> torch.Size([1, 3])
b3 = a.expand(4, 3)              # 也可写为a.expand(4, -1)  对于某一个维度上的值为1的维度,
                                # 可以在该维度上进行tensor的复制,若大于1则不行
'''
b3 -> torch.Size([4, 3])
tensor(
	[[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.],
	[1.,2.,3.]]
)
'''
 
a = torch.Tensor([[1, 2, 3], [4, 5, 6]])  # a -> torch.Size([2, 3])
b4 = a.expand(4, 6)  # 最高几个维度的参数必须和原始shape保持一致,否则报错
'''
RuntimeError: The expanded size of the tensor (6) must match 
the existing size (3) at non-singleton dimension 1.
'''
 
b5 = a.expand(1, 2, 3)  # 可以在tensor的低维增加更多维度
'''
b5 -> torch.Size([1,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
b6 = a.expand(2, 2, 3)  # 可以在tensor的低维增加更多维度,同时在新增加的低维度上进行tensor的复制
'''
b5 -> torch.Size([2,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
 
b7 = a.expand(2, 3, 2)  # 不可在更高维增加维度,否则报错
'''
RuntimeError: The expanded size of the tensor (2) must match the 
existing size (3) at non-singleton dimension 2.
'''
 
b8 = a.expand(2, -1, -1)  # 最高几个维度的参数可以用-1,表示和原始维度一致
'''
b8 -> torch.Size([2,2, 3])
tensor(
	[[[1.,2.,3.],
	 [4.,5.,6.]],
	 [[1.,2.,3.],
	 [4.,5.,6.]]]
)
'''
 
# expand返回的张量与原版张量具有相同内存地址
print(b8.storage())  # 存储区的数据,说明expand后的a,aa,aaa,aaaa是共享storage的,
# 只是tensor的头信息区设置了不同的数据展示格式,从而使得a,aa,aaa,aaaa呈现不同的tensor形式
'''
1.0
2.0
3.0
4.0
5.0
6.0
'''

1.1 expand_as

可视为expand的另一种表达,其size通过函数传递的目标张量的size来定义。

shell 复制代码
import torch
a = torch.tensor([1, 0, 2])
b = torch.zeros(2, 3)
c = a.expand_as(b)  # a照着b的维度大小进行拓展
# c为 tensor([[1, 0, 2],
#        [1, 0, 2]])

2 tensor.repeat()

沿着特定维度扩展张量,并返回扩展后的张量

  • 作用:和expand()作用类似,均是将tensor广播到新的形状。
  • 注意:不允许使用维度-1,1即为不变
py 复制代码
import torch
 
if __name__ == '__main__':
    x = torch.rand(2, 3)
    y1 = x.repeat(4, 2)
    print(y1.shape)  # torch.Size([8, 6])

3. 两者内存占用的区别

  • torch.expand 不会占用额外空间,只是在存在的张量上创建一个新的视图

  • torch.repeat 和 torch.expand 不同,它是拷贝了数据,会占用额外的空间

示例如下:

py 复制代码
import torch
 
if __name__ == '__main__':
    x = torch.rand(1, 3)
    y1 = x.expand(4, 3)
    y2 = x.repeat(2, 3)
    print(x.storage().data_ptr(), y1.storage().data_ptr())  # 52364352 52364352
    print(x.storage().data_ptr(), y2.storage().data_ptr())  # 52364352 8852096
相关推荐
XianxinMao3 分钟前
2024大模型双向突破:MoE架构创新与小模型崛起
人工智能·架构
Francek Chen15 分钟前
【深度学习基础】多层感知机 | 模型选择、欠拟合和过拟合
人工智能·pytorch·深度学习·神经网络·多层感知机·过拟合
Channing Lewis25 分钟前
python生成随机字符串
服务器·开发语言·python
pchmi1 小时前
C# OpenCV机器视觉:红外体温检测
人工智能·数码相机·opencv·计算机视觉·c#·机器视觉·opencvsharp
资深设备全生命周期管理1 小时前
以Python 做服务器,N Robot 做客户端,小小UI,拿捏
服务器·python·ui
洪小帅1 小时前
Django 的 `Meta` 类和外键的使用
数据库·python·django·sqlite
认知作战壳吉桔1 小时前
中国认知作战研究中心:从认知战角度分析2007年iPhone发布
大数据·人工智能·新质生产力·认知战·认知战研究中心
夏沫mds1 小时前
web3py+flask+ganache的智能合约教育平台
python·flask·web3·智能合约
去往火星1 小时前
opencv在图片上添加中文汉字(c++以及python)
开发语言·c++·python
Bran_Liu1 小时前
【LeetCode 刷题】栈与队列-队列的应用
数据结构·python·算法·leetcode