# 一个Java老鸟的TensorFlow入门——从计算图到GradientTape

一个Java老鸟的TensorFlow入门------从计算图到GradientTape

写了20年Java,突然要学TensorFlow,第一反应是:这东西怎么这么绕?TF 1.x的计算图、Session、placeholder,跟Java的思维方式完全不一样。后来TF 2.x出了GradientTape,终于顺畅了。这篇记录我从零开始学TensorFlow的过程,不是教程,是一个老程序员的踩坑笔记。


一、TF 1.x:先建图,再跑图

第一个程序:常量加法

python 复制代码
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

a = tf.constant(3.0, name="node1")
b = tf.constant(4.0, name="node2")
c = tf.add(a, b)

with tf.Session() as sess:
    print(sess.run(c))  # 7.0

Java程序员的困惑:为什么不直接3.0 + 4.0

因为TF 1.x的设计思路是"先画蓝图,再施工"。tf.constant(3.0)不是在算3.0,是在图里画了一个节点。tf.add(a, b)也不是在算加法,是在图里画了一条从a、b到c的边。直到sess.run(c),施工队才开始干活。

这种"声明式"编程在Java里也有类似的东西------SQL。你写SELECT * FROM t WHERE id = 1,也不是在执行,是在描述你想要什么,数据库引擎去执行。TF 1.x的计算图也是这个意思。

第二个程序:变量与累加

python 复制代码
value1 = tf.Variable(0.0)
const1 = tf.constant(1.0)
sum1 = tf.Variable(0.0)

new_value1 = tf.add(value1, const1)
value1 = value1.assign(new_value1)
sum1 = tf.assign_add(sum1, value1)

sess = tf.Session()
init = tf.global_variables_initializer()
sess.run(init)

for i in range(10):
    result = sess.run([value1, sum1])
    print("第%d次, 累加:%d, 和:%d" % (i+1, result[0], result[1]))

这里有两点跟Java不一样:

  1. 变量要初始化 ------tf.global_variables_initializer(),不调这个,变量是空的。Java里int i = 0就完了,TF里要显式告诉Session"请初始化所有变量"。
  2. 赋值是操作不是语句 ------value1.assign(new_value1)返回的是一个操作节点,不是立刻赋值。得sess.run()才生效。

第三个程序:占位符(placeholder)

python 复制代码
a = tf.placeholder(tf.float32, name="a")
b = tf.placeholder(tf.float32, name="b")
c = tf.add(a, b)
d = tf.multiply(a, b)

with tf.Session() as sess:
    result = sess.run([c, d], feed_dict={a: [1.0, 2.0, 3.0], b: [4.0, 5.0, 6.0]})
    print(result[0])  # [5.0, 7.0, 9.0]
    print(result[1])  # [4.0, 10.0, 18.0]

placeholder就是方法的参数 。先在图里留个坑,运行的时候用feed_dict填数据。Java程序员可以理解为接口定义------你声明了参数类型,调用时传具体值。

还顺手把计算图写到了TensorBoard日志:

python 复制代码
writer = tf.summary.FileWriter("e:\\log", tf.get_default_graph())

打开TensorBoard可以看到可视化计算图------节点和边的拓扑结构。调试时很有用。


二、TF 1.x的痛点

学了三个例子之后,我感觉到几个不舒服的地方:

  1. 所有东西都得在图里 ------想打个中间变量的值?sess.run()。想看类型?图里没有运行时类型。
  2. 调试困难------图建好了,跑不了断点。出错了报错信息跟图节点名相关,不是Python代码行号。
  3. 代码啰嗦------建图、初始化、Session、feed_dict,干个加法要写一堆。

这不是TF的问题,是"声明式"编程的代价。SQL也有类似问题------复杂SQL调试起来也很难。


三、TF 2.x:终于像正常代码了

GradientTape做多项式回归

python 复制代码
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(0)
X = np.linspace(-1, 1, 100)
Y = 0.5 * X**2 + 0.5 * X + 2 + np.random.normal(0, 0.05, (100,))
X_train, Y_train = X[:70], Y[:70]
X_test, Y_test = X[70:], Y[70:]

W1 = tf.Variable(np.random.randn())
W2 = tf.Variable(np.random.randn())
b = tf.Variable(np.random.randn())

def linear_regression(x):
    return W1 * x**2 + W2 * x + b

optimizer = tf.optimizers.SGD(learning_rate=0.01)

for step in range(100):
    with tf.GradientTape() as tape:
        pred = linear_regression(X_train)
        loss = tf.reduce_mean(tf.square(pred - Y_train))
    gradients = tape.gradient(loss, [W1, W2, b])
    optimizer.apply_gradients(zip(gradients, [W1, W2, b]))
    if (step + 1) % 20 == 0:
        print("Step: %i, loss: %f, W1: %f, W2: %f, b: %f"
              % (step+1, loss, W1.numpy(), W2.numpy(), b.numpy()))

对比TF 1.x,变化巨大:

  1. 不需要Session了------直接执行,像正常Python代码
  2. 不需要建图了 ------GradientTape自动记录前向计算过程
  3. 调试方便 ------W1.numpy()随时可以看值,不需要sess.run()
  4. 代码量少了一半

GradientTape的核心思想 :用with tf.GradientTape() as tape包住前向计算,TF自动记录所有操作。然后tape.gradient(loss, [参数])自动求导。不需要手写反向传播,不需要理解链式法则的推导过程。

Java程序员可以类比:TF 1.x像JDBC(手动管理连接、Statement、ResultSet),TF 2.x像MyBatis(框架帮你搞定底层,你只写业务逻辑)。


四、Keras:加载现成数据集

python 复制代码
from keras.api.datasets import mnist, imdb

(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
print(train_images.shape)  # (60000, 28, 28)

(train_datas, train_labels), (_, _) = imdb.load_data()
word_index = imdb.get_word_index()
reverse_word_index = dict([(value, key) for (key, value) in word_index.items()])
decode_view = ''.join(reverse_word_index.get(i-3, '?') for i in train_datas[3])
print(decode_view)

Keras内置了常用数据集,mnist.load_data()直接下载手写数字,imdb.load_data()直接下载电影评论。IMDB的数据已经转成了词索引,通过word_index反查可以还原原始文本。

这一步没什么技术含量,但省了很多数据准备的时间。学习阶段用现成数据集,项目阶段用自己的数据------这个节奏是对的。


五、总结:一个Java老兵的TF学习路径

阶段 我做了什么 关键收获
TF 1.x常量 建图、Session、run() 理解"声明式"编程
TF 1.x变量 Variable、assign、初始化 变量是图的一部分
TF 1.x占位符 placeholder、feed_dict 参数化计算图
TF 2.x GradientTape 自动求导、多项式回归 终于像正常代码了
Keras数据集 MNIST、IMDB加载 数据准备的起点

最大的体会 :如果你现在开始学TensorFlow,直接学TF 2.x。TF 1.x的计算图概念了解一下就行(很多老教程和老项目还在用),但写代码用2.x。GradientTape + Eager Execution,学习曲线平很多。

环境搭建我踩的坑

  • Python版本:用3.9-3.11,太新可能TF不支持
  • TensorFlow安装:pip install tensorflow,GPU版装tensorflow-gpu(需要CUDA和cuDNN,很折腾,学习阶段CPU够用)
  • 如果只是学基础,CPU版就行,MNIST和线性回归秒跑完

相关阅读:

相关推荐
itzixiao2 小时前
L1-055 谁是赢家(10 分)[java][python]
java·python·算法
IT利刃出鞘2 小时前
Java反射--PropertyDescriptor的使用
java·开发语言
所愿ღ2 小时前
SSM框架-Spring1
java·开发语言·笔记·spring
invicinble2 小时前
对于泛型的设计思路
java
A_aspectJ2 小时前
【Java基础开发】基于 Java Swing 开发的简易计算器 - 支持键盘
java·开发语言
2501_913061342 小时前
网络原理知识(7)
java·网络·面试
南境十里·墨染春水2 小时前
linux学习进程 线程同步——读写锁
java·jvm·学习
ZWZhangYu2 小时前
MCP 实战:从协议原理到 Java 自定义工具服务落地
java·开发语言·人工智能
Flittly2 小时前
【SpringSecurity新手村系列】(5)RBAC角色权限与账户状态校验
java·spring boot·笔记·安全·spring·ai