tensorflow数据操作----学习笔记(二)

广播机制

在上一篇笔记中《tensorflow安装及数据操作----学习笔记(一)》,知道了如何在相同形状上的两个张量执行按元素操作,如果是不同形状的两个张量,我们可以通过调用广播机制来执行按元素操作。

python 复制代码
a = tf.reshape(tf.range(3), (3, 1))
b = tf.reshape(tf.range(2), (1, 2))
print(a)
print(b)
"""
tf.Tensor(
[[0]
 [1]
 [2]], shape=(3, 1), dtype=int32)
tf.Tensor([[0 1]], shape=(1, 2), dtype=int32)
"""

# a和b是两个不同形状的张量,分别为3x1和1x2的矩阵,将两个矩阵广播为一个更大的3x2的矩阵,a将复制列,b将复制行,然后再按元素相加
print(a + b)
"""
tf.Tensor(
[[0 1]
 [1 2]
 [2 3]], shape=(3, 2), dtype=int32)
"""

索引和切片

与Python数组一样,张量中的元素可以通过索引访问。第一个元素索引是0,最后一个元素索引是-1,可以指定范围访问。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
print(X)		# 打印张量
print(X[-1])	# 打印最后一个元素
print(X[0:2])	# 打印第一个到第三个之间的元素元素,不包含第三个
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
tf.Tensor([ 8.  9. 10. 11.], shape=(4,), dtype=float32)
tf.Tensor(
[[0. 1. 2. 3.]
 [4. 5. 6. 7.]], shape=(2, 4), dtype=float32)
 """

TensorFlow中的Tensors(张量)是不可变的,也不能被赋值。TensorFlow中的Variables是支持赋值的可变容器。TensorFlow中的梯度不会通过Variable反向传播。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
X_var = tf.Variable(X)
X_var[2, 3].assign(1)	# 将张量中第三行第四列指定为1
print(X)
print(X_var)
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10.,  1.]], dtype=float32)>
"""

如果我们想为多个元素赋值相同的值,我们只需要索引所有元素,然后为它们赋值。 例如,[0:2, :]访问第1行和第2行,其中":"代表沿轴1(列)的所有元素。 虽然我们讨论的是矩阵的索引,但这也适用于向量和超过2个维度的张量。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
X_var = tf.Variable(X)
X_var[0:2, :].assign(tf.ones(X_var[0:2,:].shape, dtype = tf.float32) * 12)
print(X)
print(X_var)
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[12., 12., 12., 12.],
       [12., 12., 12., 12.],
       [ 8.,  9., 10., 11.]], dtype=float32)>
"""

节省内存

运行一些操作可能会导致为新结果分配内存。 例如,如果我们用Y = X + Y,我们将取消引用Y指向的张量,而是指向新分配的内存处的张量。

在下面的例子中,我们用Python的id()函数演示了这一点, 它给我们提供了内存中引用对象的确切地址。 运行Y = Y + X后,我们会发现id(Y)指向另一个位置。 这是因为Python首先计算Y + X,为结果分配新的内存,然后使Y指向内存中的这个新位置。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
Y = tf.constant([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
print("befor_id: ", id(Y))
Y = Y + X
print("after_id: ", id(Y))
"""
befor_id: 140225696089184
after_id: 140225696221664
"""

但是这样是不可取的,这样会不必要的分配内存。更新时,可能会无意中还是引用的旧的参数。所以我们可以通过assign将一个操作的结果分配给一个Variable。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
Y = tf.constant([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
Z = tf.Variable(tf.zeros_like(Y))
print('id(Z):', id(Z))
Z.assign(X + Y)
print('id(Z):', id(Z))
"""
id(Z): 140657476148192
id(Z): 140657476148192
"""

即使你将状态持久存储在Variable中, 你也可能希望避免为不是模型参数的张量过度分配内存,从而进一步减少内存使用量。

由于TensorFlow的Tensors是不可变的,而且梯度不会通过Variable流动, 因此TensorFlow没有提供一种明确的方式来原地运行单个操作。

但是,TensorFlow提供了tf.function修饰符, 将计算封装在TensorFlow图中,该图在运行前经过编译和优化。 这允许TensorFlow删除未使用的值,并复用先前分配的且不再需要的值。 这样可以最大限度地减少TensorFlow计算的内存开销。

python 复制代码
@tf.function
def computation(X, Y):
    Z = tf.zeros_like(Y)  # 这个未使用的值将被删除
    A = X + Y  # 当不再需要时,分配将被复用
    B = A + Y
    C = B + Y
    return C + Y

computation(X, Y)

转换为其他Python对象

TensorFlow的张量和NumPy张量互相转换很容易,转换后的结果不共享内存。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
A = X.numpy()
B = tf.constant(A)
print(type(A))
print(id(A))
print(type(B))
print(id(B))
"""
<class 'numpy.ndarray'>
140005241484944
<class 'tensorflow.python.framework.ops.EagerTensor'>
140005241373792
"""

# 要将size为1的张量转换为Python标量,可以调用item函数或者Python内置函数。
a = tf.constant([3.5]).numpy()
print(a)
print(a.item())
print(float(a))
print(int(a))
"""
[3.5]
3.5
3.5
3
"""
相关推荐
时光追逐者18 分钟前
MongoDB从入门到实战之MongoDB快速入门(附带学习路线图)
数据库·学习·mongodb
一弓虽23 分钟前
SpringBoot 学习
java·spring boot·后端·学习
晓数1 小时前
【硬核干货】JetBrains AI Assistant 干货笔记
人工智能·笔记·jetbrains·ai assistant
我的golang之路果然有问题2 小时前
速成GO访问sql,个人笔记
经验分享·笔记·后端·sql·golang·go·database
genggeng不会代码2 小时前
用于协同显著目标检测的小组协作学习 2021 GCoNet(总结)
学习
lwewan2 小时前
26考研——存储系统(3)
c语言·笔记·考研
搞机小能手2 小时前
六个能够白嫖学习资料的网站
笔记·学习·分类
nongcunqq3 小时前
爬虫练习 js 逆向
笔记·爬虫
汐汐咯3 小时前
终端运行java出现???
笔记
The_cute_cat5 小时前
25.4.22学习总结
学习