一个Java老鸟的TensorFlow入门------从计算图到GradientTape
写了20年Java,突然要学TensorFlow,第一反应是:这东西怎么这么绕?TF 1.x的计算图、Session、placeholder,跟Java的思维方式完全不一样。后来TF 2.x出了GradientTape,终于顺畅了。这篇记录我从零开始学TensorFlow的过程,不是教程,是一个老程序员的踩坑笔记。
一、TF 1.x:先建图,再跑图
第一个程序:常量加法
python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()
a = tf.constant(3.0, name="node1")
b = tf.constant(4.0, name="node2")
c = tf.add(a, b)
with tf.Session() as sess:
print(sess.run(c)) # 7.0
Java程序员的困惑:为什么不直接3.0 + 4.0?
因为TF 1.x的设计思路是"先画蓝图,再施工"。tf.constant(3.0)不是在算3.0,是在图里画了一个节点。tf.add(a, b)也不是在算加法,是在图里画了一条从a、b到c的边。直到sess.run(c),施工队才开始干活。
这种"声明式"编程在Java里也有类似的东西------SQL。你写SELECT * FROM t WHERE id = 1,也不是在执行,是在描述你想要什么,数据库引擎去执行。TF 1.x的计算图也是这个意思。
第二个程序:变量与累加
python
value1 = tf.Variable(0.0)
const1 = tf.constant(1.0)
sum1 = tf.Variable(0.0)
new_value1 = tf.add(value1, const1)
value1 = value1.assign(new_value1)
sum1 = tf.assign_add(sum1, value1)
sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)
for i in range(10):
result = sess.run([value1, sum1])
print("第%d次, 累加:%d, 和:%d" % (i+1, result[0], result[1]))
这里有两点跟Java不一样:
- 变量要初始化 ------
tf.global_variables_initializer(),不调这个,变量是空的。Java里int i = 0就完了,TF里要显式告诉Session"请初始化所有变量"。 - 赋值是操作不是语句 ------
value1.assign(new_value1)返回的是一个操作节点,不是立刻赋值。得sess.run()才生效。
第三个程序:占位符(placeholder)
python
a = tf.placeholder(tf.float32, name="a")
b = tf.placeholder(tf.float32, name="b")
c = tf.add(a, b)
d = tf.multiply(a, b)
with tf.Session() as sess:
result = sess.run([c, d], feed_dict={a: [1.0, 2.0, 3.0], b: [4.0, 5.0, 6.0]})
print(result[0]) # [5.0, 7.0, 9.0]
print(result[1]) # [4.0, 10.0, 18.0]
placeholder就是方法的参数 。先在图里留个坑,运行的时候用feed_dict填数据。Java程序员可以理解为接口定义------你声明了参数类型,调用时传具体值。
还顺手把计算图写到了TensorBoard日志:
python
writer = tf.summary.FileWriter("e:\\log", tf.get_default_graph())
打开TensorBoard可以看到可视化计算图------节点和边的拓扑结构。调试时很有用。
二、TF 1.x的痛点
学了三个例子之后,我感觉到几个不舒服的地方:
- 所有东西都得在图里 ------想打个中间变量的值?
sess.run()。想看类型?图里没有运行时类型。 - 调试困难------图建好了,跑不了断点。出错了报错信息跟图节点名相关,不是Python代码行号。
- 代码啰嗦------建图、初始化、Session、feed_dict,干个加法要写一堆。
这不是TF的问题,是"声明式"编程的代价。SQL也有类似问题------复杂SQL调试起来也很难。
三、TF 2.x:终于像正常代码了
GradientTape做多项式回归
python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
np.random.seed(0)
X = np.linspace(-1, 1, 100)
Y = 0.5 * X**2 + 0.5 * X + 2 + np.random.normal(0, 0.05, (100,))
X_train, Y_train = X[:70], Y[:70]
X_test, Y_test = X[70:], Y[70:]
W1 = tf.Variable(np.random.randn())
W2 = tf.Variable(np.random.randn())
b = tf.Variable(np.random.randn())
def linear_regression(x):
return W1 * x**2 + W2 * x + b
optimizer = tf.optimizers.SGD(learning_rate=0.01)
for step in range(100):
with tf.GradientTape() as tape:
pred = linear_regression(X_train)
loss = tf.reduce_mean(tf.square(pred - Y_train))
gradients = tape.gradient(loss, [W1, W2, b])
optimizer.apply_gradients(zip(gradients, [W1, W2, b]))
if (step + 1) % 20 == 0:
print("Step: %i, loss: %f, W1: %f, W2: %f, b: %f"
% (step+1, loss, W1.numpy(), W2.numpy(), b.numpy()))
对比TF 1.x,变化巨大:
- 不需要Session了------直接执行,像正常Python代码
- 不需要建图了 ------
GradientTape自动记录前向计算过程 - 调试方便 ------
W1.numpy()随时可以看值,不需要sess.run() - 代码量少了一半
GradientTape的核心思想 :用with tf.GradientTape() as tape包住前向计算,TF自动记录所有操作。然后tape.gradient(loss, [参数])自动求导。不需要手写反向传播,不需要理解链式法则的推导过程。
Java程序员可以类比:TF 1.x像JDBC(手动管理连接、Statement、ResultSet),TF 2.x像MyBatis(框架帮你搞定底层,你只写业务逻辑)。
四、Keras:加载现成数据集
python
from keras.api.datasets import mnist, imdb
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(train_images.shape) # (60000, 28, 28)
(train_datas, train_labels), (_, _) = imdb.load_data()
word_index = imdb.get_word_index()
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
decode_view = ''.join(reverse_word_index.get(i-3, '?') for i in train_datas[3])
print(decode_view)
Keras内置了常用数据集,mnist.load_data()直接下载手写数字,imdb.load_data()直接下载电影评论。IMDB的数据已经转成了词索引,通过word_index反查可以还原原始文本。
这一步没什么技术含量,但省了很多数据准备的时间。学习阶段用现成数据集,项目阶段用自己的数据------这个节奏是对的。
五、总结:一个Java老兵的TF学习路径
| 阶段 | 我做了什么 | 关键收获 |
|---|---|---|
| TF 1.x常量 | 建图、Session、run() | 理解"声明式"编程 |
| TF 1.x变量 | Variable、assign、初始化 | 变量是图的一部分 |
| TF 1.x占位符 | placeholder、feed_dict | 参数化计算图 |
| TF 2.x GradientTape | 自动求导、多项式回归 | 终于像正常代码了 |
| Keras数据集 | MNIST、IMDB加载 | 数据准备的起点 |
最大的体会 :如果你现在开始学TensorFlow,直接学TF 2.x。TF 1.x的计算图概念了解一下就行(很多老教程和老项目还在用),但写代码用2.x。GradientTape + Eager Execution,学习曲线平很多。
环境搭建我踩的坑:
- Python版本:用3.9-3.11,太新可能TF不支持
- TensorFlow安装:
pip install tensorflow,GPU版装tensorflow-gpu(需要CUDA和cuDNN,很折腾,学习阶段CPU够用) - 如果只是学基础,CPU版就行,MNIST和线性回归秒跑完
相关阅读: