# 关于TensorFlow的数据类型的文章,网上有很多。本笔记只记录代码。
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Show the installed TensorFlow version. A bare `tf.__version__` expression is
# only echoed by a REPL/notebook; when run as a script it must be printed.
print(tf.__version__)
# Constants: a float tensor, a bool tensor and a string tensor, plus a plain
# NumPy array (built with np.arange) for comparison.
a = tf.constant([1.5])
b = tf.constant([True, False])
c = tf.constant('Hello TensorFlow')
d = np.arange(4)

# Check whether each object is a tensor.
for obj in (a, b, c, d):
    print(tf.is_tensor(obj))

# Print the dtype of each object.
for obj in (a, b, c, d):
    print(obj.dtype)

# Compare each tensor's dtype against the corresponding TensorFlow dtype.
for tensor, expected in ((a, tf.float32), (b, tf.bool), (c, tf.string)):
    print(tensor.dtype == expected)
# Dtype conversion, starting from a NumPy array.
a = np.arange(5)

# Convert the NumPy array to a tensor: once with the inferred dtype, once
# with an explicit int64 dtype, then cast the first tensor to float32.
tensor_a = tf.convert_to_tensor(a)
tensor_a_int64 = tf.convert_to_tensor(a, dtype=tf.int64)
tensor_a_cast_float32 = tf.cast(tensor_a, dtype=tf.float32)

# Print the source array followed by each converted tensor, in that order.
for item in (a, tensor_a, tensor_a_int64, tensor_a_cast_float32):
    print(item)

# Cast an integer tensor to bool.
b = tf.cast(tf.constant([1, 0, 1, 0]), dtype=tf.bool)
print(b)
# TensorFlow Variable: a Variable whose `trainable` attribute is True can be
# differentiated and updated.
a = tf.range(5)
b = tf.Variable(a)

# Print the variable itself, whether it counts as a tensor, and its
# `trainable` flag.
for shown in (b, tf.is_tensor(b), b.trainable):
    print(shown)
# Convert a tensor back to a NumPy array.
a = tf.range(4)
a_numpy = a.numpy()

# Print the tensor and then its NumPy counterpart, one per line.
print(a, a_numpy, sep='\n')