tensorflow数据操作----学习笔记(二)

广播机制

在上一篇笔记中《tensorflow安装及数据操作----学习笔记(一)》,知道了如何在相同形状上的两个张量执行按元素操作,如果是不同形状的两个张量,我们可以通过调用广播机制来执行按元素操作。

python 复制代码
a = tf.reshape(tf.range(3), (3, 1))
b = tf.reshape(tf.range(2), (1, 2))
print(a)
print(b)
"""
tf.Tensor(
[[0]
 [1]
 [2]], shape=(3, 1), dtype=int32)
tf.Tensor([[0 1]], shape=(1, 2), dtype=int32)
"""

# a和b是两个不同形状的张量,分别为3x1和1x2的矩阵,将两个矩阵广播为一个更大的3x2的矩阵,a将复制列,b将复制行,然后再按元素相加
print(a + b)
"""
tf.Tensor(
[[0 1]
 [1 2]
 [2 3]], shape=(3, 2), dtype=int32)
"""

索引和切片

与Python数组一样,张量中的元素可以通过索引访问。第一个元素索引是0,最后一个元素索引是-1,可以指定范围访问。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
print(X)		# 打印张量
print(X[-1])	# 打印最后一个元素
print(X[0:2])	# 打印第一个到第三个之间的元素元素,不包含第三个
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
tf.Tensor([ 8.  9. 10. 11.], shape=(4,), dtype=float32)
tf.Tensor(
[[0. 1. 2. 3.]
 [4. 5. 6. 7.]], shape=(2, 4), dtype=float32)
 """

TensorFlow中的Tensors(张量)是不可变的,也不能被赋值。TensorFlow中的Variables是支持赋值的可变容器。TensorFlow中的梯度不会通过Variable反向传播。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
X_var = tf.Variable(X)
X_var[2, 3].assign(1)	# 将张量中第三行第四列指定为1
print(X)
print(X_var)
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[ 0.,  1.,  2.,  3.],
       [ 4.,  5.,  6.,  7.],
       [ 8.,  9., 10.,  1.]], dtype=float32)>
"""

如果我们想为多个元素赋值相同的值,我们只需要索引所有元素,然后为它们赋值。 例如,[0:2, :]访问第1行和第2行,其中":"代表沿轴1(列)的所有元素。 虽然我们讨论的是矩阵的索引,但这也适用于向量和超过2个维度的张量。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
X_var = tf.Variable(X)
X_var[0:2, :].assign(tf.ones(X_var[0:2,:].shape, dtype = tf.float32) * 12)
print(X)
print(X_var)
"""
tf.Tensor(
[[ 0.  1.  2.  3.]
 [ 4.  5.  6.  7.]
 [ 8.  9. 10. 11.]], shape=(3, 4), dtype=float32)
<tf.Variable 'Variable:0' shape=(3, 4) dtype=float32, numpy=
array([[12., 12., 12., 12.],
       [12., 12., 12., 12.],
       [ 8.,  9., 10., 11.]], dtype=float32)>
"""

节省内存

运行一些操作可能会导致为新结果分配内存。 例如,如果我们用Y = X + Y,我们将取消引用Y指向的张量,而是指向新分配的内存处的张量。

在下面的例子中,我们用Python的id()函数演示了这一点, 它给我们提供了内存中引用对象的确切地址。 运行Y = Y + X后,我们会发现id(Y)指向另一个位置。 这是因为Python首先计算Y + X,为结果分配新的内存,然后使Y指向内存中的这个新位置。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
Y = tf.constant([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
print("befor_id: ", id(Y))
Y = Y + X
print("after_id: ", id(Y))
"""
befor_id: 140225696089184
after_id: 140225696221664
"""

但是这样是不可取的,这样会不必要的分配内存。更新时,可能会无意中还是引用的旧的参数。所以我们可以通过assign将一个操作的结果分配给一个Variable。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
Y = tf.constant([[2.0, 1, 4, 3], [1, 2, 3, 4], [4, 3, 2, 1]])
Z = tf.Variable(tf.zeros_like(Y))
print('id(Z):', id(Z))
Z.assign(X + Y)
print('id(Z):', id(Z))
"""
id(Z): 140657476148192
id(Z): 140657476148192
"""

即使你将状态持久存储在Variable中, 你也可能希望避免为不是模型参数的张量过度分配内存,从而进一步减少内存使用量。

由于TensorFlow的Tensors是不可变的,而且梯度不会通过Variable流动, 因此TensorFlow没有提供一种明确的方式来原地运行单个操作。

但是,TensorFlow提供了tf.function修饰符, 将计算封装在TensorFlow图中,该图在运行前经过编译和优化。 这允许TensorFlow删除未使用的值,并复用先前分配的且不再需要的值。 这样可以最大限度地减少TensorFlow计算的内存开销。

python 复制代码
@tf.function
def computation(X, Y):
    Z = tf.zeros_like(Y)  # 这个未使用的值将被删除
    A = X + Y  # 当不再需要时,分配将被复用
    B = A + Y
    C = B + Y
    return C + Y

computation(X, Y)

转换为其他Python对象

TensorFlow的张量和NumPy张量互相转换很容易,转换后的结果不共享内存。

python 复制代码
X = tf.reshape(tf.range(12, dtype=tf.float32), (3, 4))
A = X.numpy()
B = tf.constant(A)
print(type(A))
print(id(A))
print(type(B))
print(id(B))
"""
<class 'numpy.ndarray'>
140005241484944
<class 'tensorflow.python.framework.ops.EagerTensor'>
140005241373792
"""

# 要将size为1的张量转换为Python标量,可以调用item函数或者Python内置函数。
a = tf.constant([3.5]).numpy()
print(a)
print(a.item())
print(float(a))
print(int(a))
"""
[3.5]
3.5
3.5
3
"""
相关推荐
楼田莉子1 小时前
C++算法题目分享:二叉搜索树相关的习题
数据结构·c++·学习·算法·leetcode·面试
十一10241 小时前
FX10/20 (CYUSB401X)开发笔记5 固件架构
笔记
FakeOccupational2 小时前
【电路笔记 通信】AXI4-Lite协议 FPGA实现 & Valid-Ready Handshake 握手协议
笔记·fpga开发
奶黄小甜包2 小时前
C语言零基础第18讲:自定义类型—结构体
c语言·数据结构·笔记·学习
rannn_1115 小时前
【MySQL学习|黑马笔记|Day7】触发器和锁(全局锁、表级锁、行级锁、)
笔记·后端·学习·mysql
喜欢吃燃面5 小时前
C++算法竞赛:位运算
开发语言·c++·学习·算法
传奇开心果编程5 小时前
【传奇开心果系列】Flet框架实现的家庭记账本示例自定义模板
python·学习·ui·前端框架·自动化
草莓熊Lotso5 小时前
《详解 C++ Date 类的设计与实现:从运算符重载到功能测试》
开发语言·c++·经验分享·笔记·其他
_Kayo_12 小时前
node.js 学习笔记3 HTTP
笔记·学习
CCCC131016315 小时前
嵌入式学习(day 28)线程
jvm·学习