第一步:先搞懂「计算图」到底是什么?(关键铺垫)
计算图(Computation Graph)是 TensorFlow 的核心概念,本质是:把一系列 TensorFlow 运算(比如 tf.matmul、+),用"节点"和"边"的形式提前定义好的「运算流程图」。
用一个生活比喻理解:
- 你要做一道菜"番茄炒蛋",步骤是:切番茄 → 打鸡蛋 → 翻炒番茄 → 倒入鸡蛋 → 加盐 → 出锅。
- 这整个「步骤流程」就是「计算图」:
- 节点:每个操作(切、打、炒、加盐);
- 边:操作之间的依赖(切番茄之后才能炒,炒番茄之后才能加鸡蛋);
- 食材/成品:张量(比如"切好的番茄"是一个张量,"炒好的番茄炒蛋"是最终张量)。
再对比两种做事方式(对应 TensorFlow 的两种执行模式):
| 执行模式 | 通俗理解(做饭例子) | 对应代码场景 |
|---|---|---|
| 即时执行(Eager) | 边想边做:切一个番茄 → 立刻炒 → 再打鸡蛋 → 再炒 | 普通 Python 函数(无 tf.function),一步一步执行运算,执行时才知道下一步做什么 |
| 图执行(Graph) | 先画好"做菜流程图" → 按图批量做(甚至让别人按图做) | tf.function 包装后的函数,先把运算流程(计算图)定义好,再执行(可重复用图) |
「计算图」的核心价值:
- 提速:图是"预编译"的,TensorFlow 会优化流程(比如合并重复运算、GPU/TPU 批量执行),比"边想边做"快得多(尤其模型大、调用次数多时);
- 可部署:计算图是独立于 Python 的"运算描述文件",可以保存下来,部署到手机、服务器等没有 Python 环境的设备上;
- 简洁:不用手动管理运算依赖,TensorFlow 自动处理节点顺序。
第二步:结合代码,理解 tf.function 到底在做什么?
现在回头看你给的代码,核心逻辑是:用 tf.function 把"即时执行的 Python 函数",转换成"基于计算图执行的 TensorFlow Function"------函数的「输入输出、计算逻辑完全不变」,但「执行方式从"一步一步算"变成了"按预定义图算"」。
第一块代码解析(基础用法)
python
# 1. 定义一个普通 Python 函数(即时执行模式)
def a_regular_function(x, y, b):
x = tf.matmul(x, y) # 运算1:矩阵乘法
x = x + b # 运算2:加法
return x
- 这个函数的执行方式:每次调用时,会「逐行即时执行」------先算
tf.matmul,得到结果再算x + b,像做饭时"边切边炒"; - 没有提前规划流程,每次调用都要重新"解析步骤"。
python
# 2. 用 tf.function 包装普通函数,得到 TensorFlow 的 Function 对象
a_function_that_uses_a_graph = tf.function(a_regular_function)
- 这一步是关键:
tf.function会"分析"a_regular_function里的所有 TensorFlow 运算(tf.matmul、+),自动构建出对应的「计算图」; - 这个图的结构是:
x → [tf.matmul] → 中间结果 → [+ b] → 最终 x(节点是两个运算,边是张量流转); - 包装后的
a_function_that_uses_a_graph不再是普通 Python 函数,而是"会用图执行的调用对象"。
python
# 3. 创建张量,调用两种函数
x1 = tf.constant([[1.0, 2.0]])
y1 = tf.constant([[2.0], [3.0]])
b1 = tf.constant(4.0)
orig_value = a_regular_function(x1, y1, b1).numpy() # 普通函数:即时执行
tf_function_value = a_function_that_uses_a_graph(x1, y1, b1).numpy() # Function:图执行
assert(orig_value == tf_function_value) # 结果完全一致
- 为什么结果一致?因为「计算逻辑没变」------都是先矩阵乘法再加法,只是执行方式不同;
- 区别在哪?第一次调用
a_function_that_uses_a_graph时,会先"编译图"(这一步稍慢),之后再调用时,直接用已编译好的图执行(速度极快);而普通函数每次调用都要重新执行步骤。
第二块代码解析(装饰器用法 + 嵌套函数)
python
def inner_function(x, y, b):
x = tf.matmul(x, y)
x = x + b
return x
@tf.function # 装饰器写法,等价于 outer_function = tf.function(outer_function)
def outer_function(x):
y = tf.constant([[2.0], [3.0]])
b = tf.constant(4.0)
return inner_function(x, y, b) # 调用内部函数
- 核心点:
tf.function会「递归捕获所有嵌套调用的 TensorFlow 运算」------不仅outer_function里的tf.constant,连inner_function里的tf.matmul、+都会被纳入同一个计算图; - 不用给
inner_function单独加@tf.function,外层包装后,整个调用链的运算都会被图执行; - 执行结果
array([[12.]])和普通函数一致,还是因为计算逻辑没变。
最后一句的意义:"不用定义 Placeholder 或 tf.Session"
这是对比 TensorFlow 1.x 的简化:
- TensorFlow 1.x 时,要手动定义计算图、用
Placeholder占位、用tf.Session()启动图执行,代码繁琐; - 现在有了
tf.function,你写的还是普通 Python 函数,只要用tf.function一包装,自动生成计算图,不用管占位符和会话------既保留了 Python 的简洁,又得到了计算图的优势。
第三步:总结核心要点(帮你"串起来")
- 「计算图」是 TensorFlow 对"运算流程"的预定义,目的是「提速、可部署」;
tf.function的作用:把"即时执行的 Python 函数",自动转换成"基于计算图执行的 Function";- 关键特性:
- 计算逻辑不变(结果和普通函数一致);
- 会递归处理嵌套函数,不用单独包装;
- 第一次调用编译图(稍慢),后续调用复用图(极快);
- 你之前学的
GradientTape、变量、梯度计算,都能和tf.function兼容------计算图会记录运算依赖,梯度带照样能追踪梯度(而且图执行下梯度计算更高效)。
现在再回头看这段内容,核心就是:TensorFlow 提供了 tf.function 这个工具,让你用写普通 Python 函数的方式,轻松获得计算图的优势,不用关心底层图的构建细节。