TensorFlow 中定义模型和层

这段内容的核心是讲解 TensorFlow 中 tf.Module 的本质与用法 ------它是构建模型和层的「基础骨架」,核心优势是 自动收集变量和子模块 ,让模型的变量管理、组合、保存加载变得简单。同时它也是 Keras 层(tf.keras.layers.Layer)和模型(tf.keras.Model)的父类,学会 tf.Module 就等于掌握了 TensorFlow 构建模型的底层逻辑。

下面按「核心定位→逐行解析代码→核心特性→复杂模型示例」的顺序,把每个细节讲透(包括陌生的类结构、API、变量收集逻辑):

一、先明确 tf.Module 的核心定位

tf.Module 是一个「带状态管理的容器」,这里的「状态」就是模型的可训练参数(tf.Variable)。它的核心价值有 3 个:

  1. 自动收集变量 :把分配给实例属性的 tf.Variable 自动归类(可训练/所有变量),不用手动维护变量列表;
  2. 支持子模块组合 :一个 tf.Module 可以包含其他 tf.Module 实例(比如模型包含多个层),并递归收集所有子模块的变量;
  3. 适配模型生命周期:为后续模型训练(梯度更新可训练变量)、保存/加载(只存变量,不存代码)提供基础。

简单比喻:tf.Module 像一个「智能工具箱」------你把工具(变量)和小工具箱(子模块)放进去,它会自动分类整理,要用的时候直接拿,不用自己翻找。

二、逐行解析第一个示例:简单模块(SimpleModule)

这个示例是最基础的 tf.Module 用法,实现一个简单的线性函数 y = a*x + b,我们逐行拆解:

1. 定义 SimpleModule 类(继承 tf.Module)
python 复制代码
class SimpleModule(tf.Module):
  def __init__(self, name=None):
    # 调用父类 tf.Module 的构造函数,传入模块名称(可选,用于区分不同模块)
    super().__init__(name=name)
    
    # 定义可训练变量:trainable=True(默认),后续训练会更新这个变量
    self.a_variable = tf.Variable(5.0, name="train_me")
    
    # 定义不可训练变量:trainable=False,后续训练不会更新(比如固定的偏置)
    self.non_trainable_variable = tf.Variable(5.0, trainable=False, name="do_not_train_me")
  
  # __call__ 方法:让模块实例可以像函数一样被调用(核心!)
  def __call__(self, x):
    # 线性计算:a_variable(5.0)* x + non_trainable_variable(5.0)
    return self.a_variable * x + self.non_trainable_variable
  • 关键知识点
    • class SimpleModule(tf.Module):Python 类继承,让 SimpleModule 拥有 tf.Module 的所有特性(变量收集、子模块管理);
    • super().__init__(name=name):必须调用父类构造函数,否则 tf.Module 的核心功能(如变量收集)无法生效;
    • tf.Variabletrainable 参数:决定变量是否参与训练(是否会被梯度更新),比如微调模型时可以冻结部分变量;
    • __call__ 方法:Python 特殊方法,定义后,模块实例(如 simple_module)可以直接像函数一样调用(simple_module(x)),不用写 simple_module.call(x),这是深度学习中"调用模型/层"的标准写法。
2. 创建模块实例并调用
python 复制代码
# 创建 SimpleModule 实例,名称为 "simple"
simple_module = SimpleModule(name="simple")

# 调用模块(像函数一样传入张量 x=5.0)
simple_module(tf.constant(5.0))
# 输出:tf.Tensor(30.0, ...) → 计算逻辑:5.0 * 5.0 + 5.0 = 30.0
  • 这里的调用本质是执行 __call__ 方法,输入张量 x,返回计算结果,和之前学的函数调用逻辑一致,但模块内部维护了自己的"状态"(两个变量)。
3. 查看模块的变量(核心特性:自动收集)
python 复制代码
# 查看所有可训练变量(trainable=True 的变量)
print("trainable variables:", simple_module.trainable_variables)
# 输出:(<tf.Variable 'train_me:0' ... numpy=5.0>,) → 只有 a_variable

# 查看所有变量(包括 trainable=False 的变量)
print("all variables:", simple_module.variables)
# 输出:(a_variable, non_trainable_variable) → 两个变量都包含
  • 核心价值tf.Module 自动收集了所有分配给 selftf.Variable,并按"可训练性"分类,后续训练时可以直接获取 trainable_variables 来计算梯度,不用手动遍历变量。

三、第二个示例:构建两层全连接模型(组合子模块)

这个示例展示了 tf.Module 的另一个核心特性------递归收集子模块:一个父模块可以包含多个子模块(比如模型包含多个层),父模块会自动收集所有子模块的变量,方便统一管理。

1. 定义全连接层(Dense 层)

全连接层(也叫密集层)是神经网络的基础层,计算逻辑:y = ReLU(x × w + b),其中 w(权重)和 b(偏置)是可训练变量。

python 复制代码
class Dense(tf.Module):
  def __init__(self, in_features, out_features, name=None):
    super().__init__(name=name)
    
    # 权重 w:形状 [输入特征数, 输出特征数],用随机正态分布初始化
    self.w = tf.Variable(
      tf.random.normal([in_features, out_features]), name='w')
    
    # 偏置 b:形状 [输出特征数],用全零初始化(常见初始化方式)
    self.b = tf.Variable(tf.zeros([out_features]), name='b')
  
  def __call__(self, x):
    # 第一步:矩阵乘法 x × w(输入特征 × 权重) + 偏置 b → 线性变换
    y = tf.matmul(x, self.w) + self.b
    # 第二步:ReLU 激活函数(引入非线性,让模型能拟合复杂关系)
    return tf.nn.relu(y)
  • API 解析
    • tf.random.normal([in_features, out_features]):生成形状为 [in_features, out_features] 的随机张量(服从正态分布),用于初始化权重;
    • tf.zeros([out_features]):生成全零张量,初始化偏置(偏置初始化为 0 是常用实践);
    • tf.matmul(x, self.w):矩阵乘法,必须满足 x 的列数 = self.w 的行数(这里 x 的特征数是 in_features,和 self.w 的第一维一致);
    • tf.nn.relu(y):ReLU 激活函数(y > 0 时返回 yy ≤ 0 时返回 0),之前学过的激活函数。
2. 定义完整模型(SequentialModule)

这个模型包含两个 Dense 子模块,按顺序执行(第一层输出作为第二层输入):

python 复制代码
class SequentialModule(tf.Module):
  def __init__(self, name=None):
    super().__init__(name=name)
    
    # 第一层 Dense:输入特征数=3,输出特征数=3
    self.dense_1 = Dense(in_features=3, out_features=3)
    # 第二层 Dense:输入特征数=3(和第一层输出一致),输出特征数=2
    self.dense_2 = Dense(in_features=3, out_features=2)
  
  def __call__(self, x):
    # 第一步:输入 x 经过第一层 dense_1
    x = self.dense_1(x)
    # 第二步:第一层输出经过第二层 dense_2,返回最终结果
    return self.dense_2(x)
  • 这里的 self.dense_1self.dense_2 是 Dense 类的实例(也是 tf.Module 的实例),SequentialModule 作为父模块,会自动收集这两个子模块的所有变量。
3. 创建模型并调用
python 复制代码
# 创建模型实例
my_model = SequentialModule(name="the_model")

# 调用模型:输入 shape=(1, 3) 的张量(1 个样本,3 个特征)
print("Model results:", my_model(tf.constant([[2.0, 2.0, 2.0]])))
# 输出:tf.Tensor([[0. 0.]], ...) → 结果随机(因为权重 w 是随机初始化的)
  • 调用逻辑:输入 [[2.0,2.0,2.0]] → 经过 dense_1(线性变换+ReLU)→ 输出 3 个特征 → 经过 dense_2(线性变换+ReLU)→ 输出 2 个特征,最终结果因权重随机而不同。
4. 查看模型的子模块和变量(递归收集特性)
python 复制代码
# 查看模型的所有子模块(dense_1 和 dense_2)
print("Submodules:", my_model.submodules)
# 输出:(<__main__.Dense object ...>, <__main__.Dense object ...>)

# 查看模型的所有变量(两个子模块的 w 和 b,共 4 个变量)
for var in my_model.variables:
  print(var, "\n")
  • 输出的 4 个变量分别是:
    1. dense_1 的偏置 b(shape=(3,));
    2. dense_1 的权重 w(shape=(3,3));
    3. dense_2 的偏置 b(shape=(2,));
    4. dense_2 的权重 w(shape=(3,2));
  • 这就是 tf.Module 的「递归收集」:父模块不仅收集自己的变量,还会遍历所有子模块,收集子模块的变量,让你用一个模型实例就能管理所有参数。

四、关键补充说明(衔接 Keras)

文中的注释提到:tf.Moduletf.keras.layers.Layertf.keras.Model 的基类,意味着:

  1. Keras 的层和模型本质上也是 tf.Module,具备「自动收集变量、子模块」的特性;
  2. 实践中可以选择用 tf.Module 原生构建,或用 Keras(更高层API),但不要混合使用(避免变量收集异常);
  3. 查看变量的方法(trainable_variablesvariables)在两者中完全一致。

五、总结:这段内容的核心价值

  1. 掌握模型构建的底层逻辑tf.Module 是 TensorFlow 模型/层的基础,核心是「状态(变量)管理+子模块组合」;
  2. 理解关键特性
    • 自动收集变量(按可训练性分类),方便训练时获取参数;
    • 递归收集子模块,支持构建复杂模型(如多层、多分支);
    • __call__ 方法让模型/层可以像函数一样调用,符合深度学习的使用习惯;
  3. 衔接后续知识 :学会 tf.Module 后,再学 Keras 模型构建会非常轻松,且能理解模型保存、加载、训练的底层原理(本质是管理 tf.Module 收集的变量)。

简单说:tf.Module 就是 TensorFlow 给模型/层设计的「智能容器」,帮你自动打理参数和子模块,不用手动维护复杂的变量列表,让你能专注于模型的计算逻辑(比如层的组合、运算流程)。

相关推荐
CNRio1 小时前
执AI之笔,绘时代新篇——清醒洞察智能革命的机遇密码
人工智能
Rose sait1 小时前
Visual Studio中配置 ONNX Runtime、OpenCV 和 OpenVINO 项目
人工智能·openvino
亚里随笔1 小时前
DeepSeek-V3.2:开源大语言模型的新里程碑,在推理与智能体任务中突破性能边界
人工智能·语言模型·自然语言处理·llm·rlhf·agentic
摘星编程1 小时前
基于 DevUI 与 MateChat 构建企业级 AI 智能助手的实践与探索
人工智能·华为云·状态模式
IT_陈寒1 小时前
Redis性能翻倍的5个冷门技巧:从缓存穿透到集群优化实战指南
前端·人工智能·后端
豪越大豪1 小时前
Al+新型智慧消防一体化安全管控平台!办公 + 训练 + 安防一起管
人工智能·深度学习·安全
沫儿笙1 小时前
柯马弧焊机器人气流智能调节
人工智能·物联网·机器人
love530love1 小时前
【SD WebUI踩坑】启动报错 Expecting value: line 1 column 1 (char 0) 的终极解决方案
人工智能·windows·python·github·stablediffusion
木棉知行者1 小时前
【第5篇】InceptionNeXT(CVPR2024):融合 Inception 思想与现代 CNN 设计的高效特征提取架构
人工智能·深度学习·计算机视觉·cnn