PyTorch核心概念:从梯度、计算图到连续性的全面解析(三)

文章目录

  • [Contiguous vs Non-Contiguous Tensor](#Contiguous vs Non-Contiguous Tensor)
    • [Tensor and View](#Tensor and View)
    • Strides
    • [非连续数据结构:Transpose( )](#非连续数据结构:Transpose( ))
    • [在 PyTorch 中检查Contiguous and Non-Contiguous](#在 PyTorch 中检查Contiguous and Non-Contiguous)
  • 参考文献

Contiguous vs Non-Contiguous Tensor

Tensor and View

View使用与原始张量相同的数据块,只是"view"其维度的方式不同
视图只不过是解释原始张量维度的另一种方法,而无需在内存中进行物理复制。例如,我们有一个 1x12 张量,即 [1,2,3,4,5,6,7,8,9,10,11,12],然后使用 .view(4,3) 来改变形状将张量转换为 4x3 结构

python 复制代码
x = torch.arange(1,13)
print(x)
>> tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

x = torch.arange(1,13)
y = x.view(4,3)
print(y)
>>
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

如果更改原始张量 x 中的数据,它也会反映在视图张量 y 中,因为视图张量 y 不是创建原始张量 x 的另一个副本,而是从与原始张量相同的内存地址读取数据X。反之亦然,视图张量中的值的更改将同时更改原始张量中的值,因为视图张量及其原始张量共享同一块内存块

python 复制代码
x = torch.arange(1,13)
y = x.view(4,3)
x[0] = 100
print(y)
>> 
tensor([[100,   2,   3],
        [  4,   5,   6],
        [  7,   8,   9],
        [ 10,  11,  12]])
        
x = torch.arange(1,13)
y = x.view(4,3)
y[-1,-1] = 1000
print(x)
>> tensor([   1,    2,    3,    4,    5,    6,    7,    8,    9,   10,   11, 1000])

可以以连续的方式查看不同维度的数据序列

一维张量A中的元素数量为T,经过view()处理之后的张量B,shape为(K,M,N),则需满足 K × M × N = T K\times M\times N=T K×M×N=T

Strides

python 复制代码
# x is a contiguous data. Recall that view() doesn't change data arrangement in the original 1D tensor
x = torch.arange(1,13).view(6,2)
x
>>
tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12]])
        
# Check stride
x.stride()
>> (2, 1)

步长 (2, 1) 告诉我们:我们需要跨过 1 个(维度 0)数字才能到达沿轴 0 的下一个数字,并且需要跨过 2 个(维度 1)数字才能到达沿轴 1 的下一个数字

python 复制代码
y = torch.arange(0,11).view(2,2,3)
y
>>
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],
        [[ 6,  7,  8],
         [ 9, 10, 11]]])
         
# Check stride
y.stride()
>> (6, 3, 1)

检索一维张量中 (A, B, C) 位置的公式如下: A × 6 + B × 3 + C × 1 A \times 6 + B \times 3 + C \times 1 A×6+B×3+C×1

非连续数据结构:Transpose( )

首先,Transpose(axis1, axis2) 只是"swapping the way axis1 and axis2 strides"

python 复制代码
# Initiate a contiguous tensor
x = torch.arange(0,12).view(2,2,3)
x
>>
tensor([[[ 0,  1,  2],
         [ 3,  4,  5]],
        [[ 6,  7,  8],
         [ 9, 10, 11]]])
         
x.stride()
>> (6,3,1)

# Now let's transpose axis 0 and 1, and see how the strides swap
y = x.transpose(0,2)
y
>>
tensor([[[ 0,  6],
         [ 3,  9]],
        [[ 1,  7],
         [ 4, 10]],
        [[ 2,  8],
         [ 5, 11]]])
         
y.stride()
>> (1,3,6)

y 是 x.transpose(0,2),它交换 x 张量在轴 0 和轴 2 上的stride,因此 y 的stride是 (1,3,6)。这意味着我们需要跳转 6 个数字才能获取第 0 轴的下一个数字,跳转 3 个数字才能获取第 1 轴的下一个数字,跳转 1 个数字才能获取第 2 轴的下一个数字(stride公式: A × 1 + B × 3 + C × 6 A \times 1+ B \times 3+C \times 6 A×1+B×3+C×6)

transpose的不同之处在于:现在数据序列不再遵循连续的顺序 。它不会从最内层维度逐一填充顺序数据,填满后跳转到下一个维度。现在它在最里面的维度跳跃了6个数字,所以它不是连续的
transpose( ) 具有不连续的数据结构,但仍然是视图而不是副本 ⇒ \Rightarrow ⇒它是一个不连续的"视图",改变了原始数据的stride方式

python 复制代码
# Change the value in a transpose tensor y
x = torch.arange(0,12).view(2,6)
y = x.transpose(0,1)
y[0,0] = 100
y
>>
tensor([[100,   6],
        [  1,   7],
        [  2,   8],
        [  3,   9],
        [  4,  10],
        [  5,  11]])
# Check the original tensor x
x
>>
tensor([[100,   1,   2,   3,   4,   5],
        [  6,   7,   8,   9,  10,  11]])

在 PyTorch 中检查Contiguous and Non-Contiguous

使用PyTorch中的 .is_contigious() 检查张量是否连续

python 复制代码
x = torch.arange(0,12).view(2,6)
x.is_contiguous()
>> True

y = x.transpose(0,1)
y.is_contiguous()
>> False

将不连续张量(或视图)转换为连续张量

使用PyTorch中的 .contigious() 将不连续的张量转换成连续的张量

python 复制代码
z = y.contiguous()
z.is_contiguous()
>> TRUE

** .contigious() 复制原始的"non-contiguous"张量,然后按照连续顺序将其保存到新的内存块中**

python 复制代码
# This is contiguous
x = torch.arange(1,13).view(2,3,2)
x.stride()
>> (6, 2, 1)

# This is non-contiguous
y = x.transpose(0,1)
y.stride()
>> (2, 6, 1)

# This is a converted contiguous tensor with new stride
z = y.contiguous()
z.stride()
>> (4, 2, 1)

print(z.shape)
>> (3, 2, 2)

# The stride across the first dimension is 2*2
# The stride across the second dimension is 2*1
# The stride across the third dimension is 1
(4, 2, 1)=>(2*2, 2*1, 1)

用来区分张量/视图是否连续的一种方法是观察stride中的 ( A , B , C ) (A, B, C) (A,B,C) 是否满足 A > B > C A > B > C A>B>C。如果不满足,则意味着至少有一个维度正在跳过的距离比其上方的维度更长,这使得它不连续

我们还可以观察转换后的连续张量 z 如何以新的顺序存储数据

python 复制代码
# y is a non-contiguous 'view' (remember view uses the original chunk of data in memory, but its strides implies 'non-contiguous', (2,6,1).
y.storage()
>>
 1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 
# Z is a 'contiguous' tensor (not a view, but a new copy of the original data. Notice the order of the data is different). It strides implies 'contiguous', (4,2,1)
z.storage()
>>
 1
 2
 7
 8
 3
 4
 9
 10
 5
 6
 11
 12

view() 和 reshape() 之间的区别

虽然这两个函数都可以改变张量的维度,但两者之间的主要区别是:

  1. view():不复制原始张量,使用与原始张量相同的数据块,仅适用于连续数据
  2. reshape():当数据连续时,尽可能返回视图;当数据不连续时,则将数据复制到连续的数据块中,作为副本,它会占用内存空间,而且新张量的变化不会影响原始张量中的原始数值
python 复制代码
# When data is contiguous
x = torch.arange(1,13)
x
>> tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12])

# Reshape returns a view with the new dimension
y = x.reshape(4,3)
y
>>
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
        
# How do we know it's a view? Because the element change in new tensor y would affect the value in x, and vice versa
y[0,0] = 100
y
>>
tensor([[100,   2,   3],
        [  4,   5,   6],
        [  7,   8,   9],
        [ 10,  11,  12]])
        
print(x)
>>
tensor([100,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12])

接下来,让我们看看 reshape() 如何处理非连续数据:

python 复制代码
# After transpose(), the data is non-contiguous
x = torch.arange(1,13).view(6,2).transpose(0,1)
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])
        
# Reshape() works fine on a non-contiguous data
y = x.reshape(4,3)
y
>>
tensor([[ 1,  3,  5],
        [ 7,  9, 11],
        [ 2,  4,  6],
        [ 8, 10, 12]])
        
# Change an element in y
y[0,0] = 100
y
>>
tensor([[100,   3,   5],
        [  7,   9,  11],
        [  2,   4,   6],
        [  8,  10,  12]])
        
# Check the original tensor, and nothing was changed
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])

最后,让我们看看 view() 是否可以处理非连续数据。No, it can't!

python 复制代码
# After transpose(), the data is non-contiguous
x = torch.arange(1,13).view(6,2).transpose(0,1)
x
>>
tensor([[ 1,  3,  5,  7,  9, 11],
        [ 2,  4,  6,  8, 10, 12]])
        
# Try to use view on the non-contiguous data
y = x.view(4,3)
y
>>
-------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
----> 1 y = x.view(4,3)
      2 y
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

总结

  1. view"使用与原始张量相同的内存块,因此该内存块中的任何更改都会影响所有视图以及与其关联的原始张量
  2. 视图可以是连续的或不连续的。一个不连续的张量视图可以转换为连续的张量视图,并且会复制不连续的视图张量到新的内存空间中,因此数据将不再与原始数据块关联
  3. stride位置公式:给定一个stride ( A , B , C ) (A,B,C) (A,B,C),索引 ( j , k , v ) (j, k, v) (j,k,v) 在 1D 数据数组中的位置为 ( A × j + B × k + C × v ) (A \times j + B \times k + C \times v) (A×j+B×k+C×v)
  4. view()reshape() 之间的区别:view() 不能应用于 '非连续的张量/视图,它返回一个视图;reshape() 可以应用于"连续"和"非连续"张量/视图

《PyTorch核心概念:从梯度、计算图到连续性的全面解析(一)》
《PyTorch核心概念:从梯度、计算图到连续性的全面解析(二)》

参考文献

1、Contiguous vs Non-Contiguous Tensor / View --- Understanding view(), reshape(), transpose()

相关推荐
一念之坤1 小时前
零基础学Python之数据结构 -- 01篇
数据结构·python
wxl7812272 小时前
如何使用本地大模型做数据分析
python·数据挖掘·数据分析·代码解释器
NoneCoder2 小时前
Python入门(12)--数据处理
开发语言·python
LKID体2 小时前
Python操作neo4j库py2neo使用(一)
python·oracle·neo4j
小尤笔记3 小时前
利用Python编写简单登录系统
开发语言·python·数据分析·python基础
FreedomLeo13 小时前
Python数据分析NumPy和pandas(四十、Python 中的建模库statsmodels 和 scikit-learn)
python·机器学习·数据分析·scikit-learn·statsmodels·numpy和pandas
007php0073 小时前
GoZero 上传文件File到阿里云 OSS 报错及优化方案
服务器·开发语言·数据库·python·阿里云·架构·golang
Tech Synapse3 小时前
Python网络爬虫实践案例:爬取猫眼电影Top100
开发语言·爬虫·python
一行玩python4 小时前
SQLAlchemy,ORM的Python标杆!
开发语言·数据库·python·oracle
数据小爬虫@4 小时前
利用Python爬虫获取淘宝店铺详情
开发语言·爬虫·python