tensorflow数据操作----学习笔记(二)

广播机制

在上一篇笔记中《tensorflow安装及数据操作----学习笔记(一)》,知道了如何在相同形状上的两个张量执行按元素操作,如果是不同形状的两个张量,我们可以通过调用广播机制来执行按元素操作。

python 复制代码
a = tf.reshape(tf.range(3), (3, 1))
b = tf.reshape(tf.range(2), (1, 2))
print(a)
print(b)
"""
tf.Tensor(
[[0]
 [1]
 [2]], shape=(3, 1), dtype=int32)
tf.Tensor([[0 1]], shape=(1, 2), dtype=int32)
"""

# a和b是两个不同形状的张量,分别为3x1和1x2的矩阵,将两个矩阵广播为一个更大的3x2的矩阵,a将复制列,b将复制行,然后再按元素相加
print(a + b)
"""
tf.Tensor(
[[0 1]
 [1 2]
 [2 3]], shape=(3, 2), dtype=int32)
"""

索引和切片

与Python数组一样,张量中的元素可以通过索引访问。第一个元素索引是0,最后一个元素索引是-1,可以指定范围访问。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
print(X)		# 打印张量
print(X[-1])	# 打印最后一个元素
print(X[0:2])	# 打印第一个到第三个之间的元素元素,不包含第三个
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
tf.Tensor([ 8.  9. 10. 11.], shape=(4,), dtype=float32)
tf.Tensor(
[[0. 1. 2. 3.]
 [4. 5. 6. 7.]], shape=(2, 4), dtype=float32)
 """

TensorFlow中的Tensors(张量)是不可变的,也不能被赋值。TensorFlow中的Variables是支持赋值的可变容器。TensorFlow中的梯度不会通过Variable反向传播。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
X_var = tf.Variable(X)
X_var[2, 3].assign(1)	# 将张量中第三行第四列指定为1
print(X)
print(X_var)
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10.,  1.]], dtype=float32)>
"""

如果我们想为多个元素赋值相同的值,我们只需要索引所有元素,然后为它们赋值。 例如,[0:2, :]访问第1行和第2行,其中":"代表沿轴1(列)的所有元素。 虽然我们讨论的是矩阵的索引,但这也适用于向量和超过2个维度的张量。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
X_var = tf.Variable(X)
X_var[0:2, :].assign(tf.ones(X_var[0:2,:].shape, dtype = tf.float32) * 12)
print(X)
print(X_var)
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[12., 12., 12., 12.],
       [12., 12., 12., 12.],
       [ 8.,  9., 10., 11.]], dtype=float32)>
"""

节省内存

运行一些操作可能会导致为新结果分配内存。 例如,如果我们用Y = X + Y,我们将取消引用Y指向的张量,而是指向新分配的内存处的张量。

在下面的例子中,我们用Python的id()函数演示了这一点, 它给我们提供了内存中引用对象的确切地址。 运行Y = Y + X后,我们会发现id(Y)指向另一个位置。 这是因为Python首先计算Y + X,为结果分配新的内存,然后使Y指向内存中的这个新位置。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
Y = tf.constant([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
print("befor_id: ", id(Y))
Y = Y + X
print("after_id: ", id(Y))
"""
befor_id: 140225696089184
after_id: 140225696221664
"""

但是这样是不可取的,这样会不必要的分配内存。更新时,可能会无意中还是引用的旧的参数。所以我们可以通过assign将一个操作的结果分配给一个Variable。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
Y = tf.constant([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
Z = tf.Variable(tf.zeros_like(Y))
print('id(Z):', id(Z))
Z.assign(X + Y)
print('id(Z):', id(Z))
"""
id(Z): 140657476148192
id(Z): 140657476148192
"""

即使你将状态持久存储在Variable中, 你也可能希望避免为不是模型参数的张量过度分配内存,从而进一步减少内存使用量。

由于TensorFlow的Tensors是不可变的,而且梯度不会通过Variable流动, 因此TensorFlow没有提供一种明确的方式来原地运行单个操作。

但是,TensorFlow提供了tf.function修饰符, 将计算封装在TensorFlow图中,该图在运行前经过编译和优化。 这允许TensorFlow删除未使用的值,并复用先前分配的且不再需要的值。 这样可以最大限度地减少TensorFlow计算的内存开销。

python 复制代码
@tf.function
def computation(X, Y):
    Z = tf.zeros_like(Y)  # 这个未使用的值将被删除
    A = X + Y  # 当不再需要时,分配将被复用
    B = A + Y
    C = B + Y
    return C + Y

computation(X, Y)

转换为其他Python对象

TensorFlow的张量和NumPy张量互相转换很容易,转换后的结果不共享内存。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
A = X.numpy()
B = tf.constant(A)
print(type(A))
print(id(A))
print(type(B))
print(id(B))
"""
<class 'numpy.ndarray'>
140005241484944
<class 'tensorflow.python.framework.ops.EagerTensor'>
140005241373792
"""

# 要将size为1的张量转换为Python标量,可以调用item函数或者Python内置函数。
a = tf.constant([3.5]).numpy()
print(a)
print(a.item())
print(float(a))
print(int(a))
"""
[3.5]
3.5
3.5
3
"""
相关推荐
weiabc38 分钟前
学习electron
javascript·学习·electron
fengbizhe1 小时前
笔试-笔记2
c++·笔记
余为民同志1 小时前
mini-lsm通关笔记Week2Day4
笔记
墨染风华不染尘1 小时前
python之开发笔记
开发语言·笔记·python
徐霞客3201 小时前
Qt入门1——认识Qt的几个常用头文件和常用函数
开发语言·c++·笔记·qt
HackKong1 小时前
小白怎样入门网络安全?
网络·学习·安全·web安全·网络安全·黑客
澜世2 小时前
2024小迪安全基础入门第三课
网络·笔记·安全·网络安全
Bald Baby2 小时前
JWT的使用
java·笔记·学习·servlet
心怀梦想的咸鱼2 小时前
UE5 第一人称射击项目学习(四)
学习·ue5
AI完全体2 小时前
【AI日记】24.11.22 学习谷歌数据分析初级课程-第2/3课
学习·数据分析