这段内容的核心是讲解 TensorFlow 中 tf.Module 的本质与用法 ------它是构建模型和层的「基础骨架」,核心优势是 自动收集变量和子模块 ,让模型的变量管理、组合、保存加载变得简单。同时它也是 Keras 层(tf.keras.layers.Layer)和模型(tf.keras.Model)的父类,学会 tf.Module 就等于掌握了 TensorFlow 构建模型的底层逻辑。
下面按「核心定位→逐行解析代码→核心特性→复杂模型示例」的顺序,把每个细节讲透(包括陌生的类结构、API、变量收集逻辑):
一、先明确 tf.Module 的核心定位
tf.Module 是一个「带状态管理的容器」,这里的「状态」就是模型的可训练参数(tf.Variable)。它的核心价值有 3 个:
- 自动收集变量 :把分配给实例属性的
tf.Variable自动归类(可训练/所有变量),不用手动维护变量列表; - 支持子模块组合 :一个
tf.Module可以包含其他tf.Module实例(比如模型包含多个层),并递归收集所有子模块的变量; - 适配模型生命周期:为后续模型训练(梯度更新可训练变量)、保存/加载(只存变量,不存代码)提供基础。
简单比喻:tf.Module 像一个「智能工具箱」------你把工具(变量)和小工具箱(子模块)放进去,它会自动分类整理,要用的时候直接拿,不用自己翻找。
二、逐行解析第一个示例:简单模块(SimpleModule)
这个示例是最基础的 tf.Module 用法,实现一个简单的线性函数 y = a*x + b,我们逐行拆解:
1. 定义 SimpleModule 类(继承 tf.Module)
python
class SimpleModule(tf.Module):
def __init__(self, name=None):
# 调用父类 tf.Module 的构造函数,传入模块名称(可选,用于区分不同模块)
super().__init__(name=name)
# 定义可训练变量:trainable=True(默认),后续训练会更新这个变量
self.a_variable = tf.Variable(5.0, name="train_me")
# 定义不可训练变量:trainable=False,后续训练不会更新(比如固定的偏置)
self.non_trainable_variable = tf.Variable(5.0, trainable=False, name="do_not_train_me")
# __call__ 方法:让模块实例可以像函数一样被调用(核心!)
def __call__(self, x):
# 线性计算:a_variable(5.0)* x + non_trainable_variable(5.0)
return self.a_variable * x + self.non_trainable_variable
- 关键知识点 :
class SimpleModule(tf.Module):Python 类继承,让 SimpleModule 拥有tf.Module的所有特性(变量收集、子模块管理);super().__init__(name=name):必须调用父类构造函数,否则tf.Module的核心功能(如变量收集)无法生效;tf.Variable的trainable参数:决定变量是否参与训练(是否会被梯度更新),比如微调模型时可以冻结部分变量;__call__方法:Python 特殊方法,定义后,模块实例(如simple_module)可以直接像函数一样调用(simple_module(x)),不用写simple_module.call(x),这是深度学习中"调用模型/层"的标准写法。
2. 创建模块实例并调用
python
# 创建 SimpleModule 实例,名称为 "simple"
simple_module = SimpleModule(name="simple")
# 调用模块(像函数一样传入张量 x=5.0)
simple_module(tf.constant(5.0))
# 输出:tf.Tensor(30.0, ...) → 计算逻辑:5.0 * 5.0 + 5.0 = 30.0
- 这里的调用本质是执行
__call__方法,输入张量x,返回计算结果,和之前学的函数调用逻辑一致,但模块内部维护了自己的"状态"(两个变量)。
3. 查看模块的变量(核心特性:自动收集)
python
# 查看所有可训练变量(trainable=True 的变量)
print("trainable variables:", simple_module.trainable_variables)
# 输出:(<tf.Variable 'train_me:0' ... numpy=5.0>,) → 只有 a_variable
# 查看所有变量(包括 trainable=False 的变量)
print("all variables:", simple_module.variables)
# 输出:(a_variable, non_trainable_variable) → 两个变量都包含
- 核心价值 :
tf.Module自动收集了所有分配给self的tf.Variable,并按"可训练性"分类,后续训练时可以直接获取trainable_variables来计算梯度,不用手动遍历变量。
三、第二个示例:构建两层全连接模型(组合子模块)
这个示例展示了 tf.Module 的另一个核心特性------递归收集子模块:一个父模块可以包含多个子模块(比如模型包含多个层),父模块会自动收集所有子模块的变量,方便统一管理。
1. 定义全连接层(Dense 层)
全连接层(也叫密集层)是神经网络的基础层,计算逻辑:y = ReLU(x × w + b),其中 w(权重)和 b(偏置)是可训练变量。
python
class Dense(tf.Module):
def __init__(self, in_features, out_features, name=None):
super().__init__(name=name)
# 权重 w:形状 [输入特征数, 输出特征数],用随机正态分布初始化
self.w = tf.Variable(
tf.random.normal([in_features, out_features]), name='w')
# 偏置 b:形状 [输出特征数],用全零初始化(常见初始化方式)
self.b = tf.Variable(tf.zeros([out_features]), name='b')
def __call__(self, x):
# 第一步:矩阵乘法 x × w(输入特征 × 权重) + 偏置 b → 线性变换
y = tf.matmul(x, self.w) + self.b
# 第二步:ReLU 激活函数(引入非线性,让模型能拟合复杂关系)
return tf.nn.relu(y)
- API 解析 :
tf.random.normal([in_features, out_features]):生成形状为[in_features, out_features]的随机张量(服从正态分布),用于初始化权重;tf.zeros([out_features]):生成全零张量,初始化偏置(偏置初始化为 0 是常用实践);tf.matmul(x, self.w):矩阵乘法,必须满足x的列数 =self.w的行数(这里x的特征数是in_features,和self.w的第一维一致);tf.nn.relu(y):ReLU 激活函数(y > 0时返回y,y ≤ 0时返回 0),之前学过的激活函数。
2. 定义完整模型(SequentialModule)
这个模型包含两个 Dense 子模块,按顺序执行(第一层输出作为第二层输入):
python
class SequentialModule(tf.Module):
def __init__(self, name=None):
super().__init__(name=name)
# 第一层 Dense:输入特征数=3,输出特征数=3
self.dense_1 = Dense(in_features=3, out_features=3)
# 第二层 Dense:输入特征数=3(和第一层输出一致),输出特征数=2
self.dense_2 = Dense(in_features=3, out_features=2)
def __call__(self, x):
# 第一步:输入 x 经过第一层 dense_1
x = self.dense_1(x)
# 第二步:第一层输出经过第二层 dense_2,返回最终结果
return self.dense_2(x)
- 这里的
self.dense_1和self.dense_2是 Dense 类的实例(也是tf.Module的实例),SequentialModule作为父模块,会自动收集这两个子模块的所有变量。
3. 创建模型并调用
python
# 创建模型实例
my_model = SequentialModule(name="the_model")
# 调用模型:输入 shape=(1, 3) 的张量(1 个样本,3 个特征)
print("Model results:", my_model(tf.constant([[2.0, 2.0, 2.0]])))
# 输出:tf.Tensor([[0. 0.]], ...) → 结果随机(因为权重 w 是随机初始化的)
- 调用逻辑:输入
[[2.0,2.0,2.0]]→ 经过 dense_1(线性变换+ReLU)→ 输出 3 个特征 → 经过 dense_2(线性变换+ReLU)→ 输出 2 个特征,最终结果因权重随机而不同。
4. 查看模型的子模块和变量(递归收集特性)
python
# 查看模型的所有子模块(dense_1 和 dense_2)
print("Submodules:", my_model.submodules)
# 输出:(<__main__.Dense object ...>, <__main__.Dense object ...>)
# 查看模型的所有变量(两个子模块的 w 和 b,共 4 个变量)
for var in my_model.variables:
print(var, "\n")
- 输出的 4 个变量分别是:
- dense_1 的偏置 b(shape=(3,));
- dense_1 的权重 w(shape=(3,3));
- dense_2 的偏置 b(shape=(2,));
- dense_2 的权重 w(shape=(3,2));
- 这就是
tf.Module的「递归收集」:父模块不仅收集自己的变量,还会遍历所有子模块,收集子模块的变量,让你用一个模型实例就能管理所有参数。
四、关键补充说明(衔接 Keras)
文中的注释提到:tf.Module 是 tf.keras.layers.Layer 和 tf.keras.Model 的基类,意味着:
- Keras 的层和模型本质上也是
tf.Module,具备「自动收集变量、子模块」的特性; - 实践中可以选择用
tf.Module原生构建,或用 Keras(更高层API),但不要混合使用(避免变量收集异常); - 查看变量的方法(
trainable_variables、variables)在两者中完全一致。
五、总结:这段内容的核心价值
- 掌握模型构建的底层逻辑 :
tf.Module是 TensorFlow 模型/层的基础,核心是「状态(变量)管理+子模块组合」; - 理解关键特性 :
- 自动收集变量(按可训练性分类),方便训练时获取参数;
- 递归收集子模块,支持构建复杂模型(如多层、多分支);
__call__方法让模型/层可以像函数一样调用,符合深度学习的使用习惯;
- 衔接后续知识 :学会
tf.Module后,再学 Keras 模型构建会非常轻松,且能理解模型保存、加载、训练的底层原理(本质是管理tf.Module收集的变量)。
简单说:tf.Module 就是 TensorFlow 给模型/层设计的「智能容器」,帮你自动打理参数和子模块,不用手动维护复杂的变量列表,让你能专注于模型的计算逻辑(比如层的组合、运算流程)。