这段内容的核心是 衔接 tf.Module 讲解 Keras 的核心设计 ------tf.keras.layers.Layer(Keras层)和 tf.keras.Model(Keras模型)均继承自 tf.Module,因此完全具备 tf.Module 的变量收集、子模块管理等特性,同时新增了「标准化生命周期」「训练/评估工具」「灵活模型定义方式」等高层功能,让模型构建、训练、部署更简洁。
下面按「Keras层(从tf.Module迁移+build步骤)→ Keras模型(子类化+函数式API)」的逻辑,逐模块拆解代码和核心特性,对比你已学的 tf.Module,突出Keras的优势:
一、核心前提:Keras 与 tf.Module 的关系
| 组件 | 继承关系 | 核心特性(继承自tf.Module)+ 新增功能 |
|---|---|---|
| tf.keras.layers.Layer | 直接继承 tf.Module | 自动收集变量/子模块 + 标准化生命周期(build/call)、训练参数支持、配置保存 |
| tf.keras.Model | 继承自 tf.keras.layers.Layer | 所有Layer特性 + 训练/评估循环、模型保存/加载、多输入多输出支持 |
简单说:Keras 是 tf.Module 的「增强版工具包」,保留了底层核心能力,同时解决了 tf.Module 需手动管理变量创建、训练逻辑等问题。
二、第一部分:Keras层(tf.keras.layers.Layer)
Keras层是构建模型的基础单元,我们从「简单层→灵活层」逐步解析,重点看和 tf.Module 的差异及 build 步骤的作用。
1. 基础Keras层:从tf.Module迁移
之前我们用 tf.Module 写了 Dense 层,现在把它改成Keras层,只需2个关键修改:
python
class MyDense(tf.keras.layers.Layer): # 1. 父类从tf.Module改为tf.keras.layers.Layer
# 2. 加**kwargs:支持Keras层的额外参数(如name、trainable、dtype等)
def __init__(self, in_features, out_features, **kwargs):
super().__init__(**kwargs) # 传递额外参数给父类
# 变量创建(暂时还在__init__,后面会迁移到build)
self.w = tf.Variable(
tf.random.normal([in_features, out_features]), name='w')
self.b = tf.Variable(tf.zeros([out_features]), name='b')
# 3. __call__ 改成 call():Keras的标准计算接口
def call(self, x):
y = tf.matmul(x, self.w) + self.b
return tf.nn.relu(y)
# 创建层实例(支持传name等Keras参数)
simple_layer = MyDense(name="simple", in_features=3, out_features=3)
关键差异解释:
- 父类变更:
tf.keras.layers.Layer继承了tf.Module,所以simple_layer.variables「自动收集变量」的特性不变; **kwargs作用:Keras层支持很多内置参数(如trainable=False冻结层、dtype=tf.float64指定数据类型),**kwargs能兼容这些参数,避免报错;call()方法:Keras的__call__方法有内置逻辑(比如处理training参数、应用正则化、计算损失),用户只需在call()中写核心计算逻辑,调用层时仍用simple_layer(x)(底层会自动调用call())。
调用效果(和tf.Module一致):
python
simple_layer([[2.0, 2.0, 2.0]])
# 输出:tf.Tensor([[3.358..., 11.478..., 0.602...]], ...)
2. 灵活Keras层:用 build 步骤延迟创建变量
之前我们用 is_built 标记实现"延迟变量创建",Keras层内置了 build 生命周期方法,更简洁、标准化:
python
class FlexibleDense(tf.keras.layers.Layer):
def __init__(self, out_features, **kwargs):
super().__init__(**kwargs)
self.out_features = out_features # 只指定输出特征数,不指定输入
# 核心:build() 方法------第一次调用层时自动执行,传入输入形状
def build(self, input_shape): # input_shape:输入张量的形状(不含batch维度)
# 从input_shape提取输入特征数(input_shape[-1]是最后一维,即特征数)
self.w = tf.Variable(
tf.random.normal([input_shape[-1], self.out_features]), name='w')
self.b = tf.Variable(tf.zeros([out_features]), name='b')
# 注:build执行后,Keras会自动标记层为"已构建",无需手动管理is_built
def call(self, inputs): # 计算逻辑和之前一致
return tf.matmul(inputs, self.w) + self.b
核心逻辑拆解:
-
初始化时无变量:刚创建层时,
build未执行,变量为空;pythonflexible_dense = FlexibleDense(out_features=3) flexible_dense.variables # 输出:[](无变量) -
第一次调用触发
build:python# 输入shape=(2,3)(2个样本,3个特征)→ input_shape=(3,)(不含batch维度) flexible_dense(tf.constant([[2.0,2.0,2.0], [3.0,3.0,3.0]])) # build自动执行:input_shape[-1]=3 → w形状=(3,3),b形状=(3) -
变量已创建,后续复用:
pythonflexible_dense.variables # 输出:[w, b](2个变量) -
输入形状不兼容会报错(Keras自动检查):
python# 输入shape=(1,4)(4个特征),而w形状是(3,3)→矩阵乘法不兼容 flexible_dense(tf.constant([[2.0,2.0,2.0,2.0]])) # 报错:Matrix size-incompatible(输入列数4≠w的行数3)
build 步骤的优势(对比tf.Module的is_built):
- 标准化:不用手动管理
is_built标记,Keras自动处理; - 输入校验:自动检查输入形状与变量的兼容性,避免运行时错误;
- 扩展性:
build中可添加其他初始化逻辑(如正则化、约束),Keras层会自动兼容。
3. Keras层的额外实用功能(tf.Module没有)
文中提到Keras层有很多增强功能,核心常用的有:
-
training参数 :区分训练/推断模式(比如Dropout层训练时随机失活,推断时不生效);pythondef call(self, inputs, training=None): x = tf.matmul(inputs, self.w) + self.b # 训练时应用Dropout,推断时不应用 x = tf.keras.layers.Dropout(0.2)(x, training=training) return tf.nn.relu(x) -
get_config()/from_config():保存/恢复层的配置(无需重新写代码);python# 保存配置 config = flexible_dense.get_config() # 从配置重建层(变量会重新初始化,需加载权重) new_layer = FlexibleDense.from_config(config) -
可选损失/指标:支持在层内定义损失(如正则化损失),训练时自动汇总。
三、第二部分:Keras模型(tf.keras.Model)
Keras模型是「层的容器」,继承自 tf.keras.layers.Layer,因此支持嵌套层、自动收集变量,同时新增了训练循环、模型可视化、保存加载等核心功能。Keras模型有两种定义方式:子类化Model 和 函数式API,我们分别解析。
1. 子类化Keras模型(和tf.Module的SequentialModule类似)
直接继承 tf.keras.Model,逻辑和之前的 SequentialModule 几乎一致,只需修改父类和 __call__ 为 call():
python
class MySequentialModel(tf.keras.Model):
def __init__(self, name=None, **kwargs):
super().__init__(name=name, **kwargs)
# 嵌套Keras层(FlexibleDense是Keras层)
self.dense_1 = FlexibleDense(out_features=3)
self.dense_2 = FlexibleDense(out_features=2)
def call(self, x): # 计算流程:x→dense_1→dense_2
x = self.dense_1(x)
return self.dense_2(x)
# 创建模型实例
my_sequential_model = MySequentialModel(name="the_model")
核心特性(和tf.Module对比):
-
变量/子模块收集不变:
pythonmy_sequential_model.variables # 输出:dense_1和dense_2的w、b(共4个变量) my_sequential_model.submodules # 输出:(dense_1, dense_2) -
新增训练/评估功能:支持
model.fit()(训练)、model.evaluate()(评估)、model.predict()(预测),无需手动写梯度下降循环; -
保存/加载更方便:支持
model.save()(直接保存为SavedModel)、tf.keras.models.load_model()(无需原始类)。
调用效果:
python
my_sequential_model(tf.constant([[2.0,2.0,2.0]]))
# 输出:tf.Tensor([[-7.720..., -11.065...]], ...)
2. 函数式API(Keras特色,快速构建模型)
对于"层按顺序执行"或"多输入多输出"的模型,函数式API无需子类化,直接用「层的链式调用」构建,更简洁且支持可视化。
代码解析:
python
# 1. 定义输入层:指定输入形状(不含batch维度,None表示任意batch大小)
inputs = tf.keras.Input(shape=[3,]) # 输入:(None, 3) → None=任意样本数,3=特征数
# 2. 链式调用层:输入→dense_1→dense_2
x = FlexibleDense(3)(inputs) # 第一层:输出3个特征
x = FlexibleDense(2)(x) # 第二层:输出2个特征
# 3. 包装成模型:指定输入和输出
my_functional_model = tf.keras.Model(inputs=inputs, outputs=x)
核心优势1:自动可视化模型结构(summary())
python
my_functional_model.summary()
输出结果解析:
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 3)] 0 # 输入层,无参数
flexible_dense_3 (FlexibleD (None, 3) 12 # 3*3(w)+3(b)=12
ense)
flexible_dense_4 (FlexibleD (None, 2) 8 # 3*2(w)+2(b)=8
ense)
=================================================================
Total params: 20 # 总参数:12+8=20
Trainable params: 20 # 可训练参数(默认所有变量可训练)
Non-trainable params: 0 # 无不可训练参数
_________________________________________________________________
- 价值:快速查看模型的层顺序、输出形状、参数数量,方便调试(比如参数过多可能过拟合,过少可能欠拟合)。
核心优势2:无需手动管理输入形状
函数式API通过 Input(shape=[3,]) 预先指定输入特征数,第一次创建层时就会触发 build,无需等待调用模型:
python
my_functional_model(tf.constant([[2.0,2.0,2.0]]))
# 输出:tf.Tensor([[17.506..., -3.895...]], ...)
关键注意:
Input(shape=[3,])的shape不含batch维度(batch维度用None表示,支持任意样本数);- 函数式API适合「静态结构」的模型(层的连接关系固定),子类化Model适合「动态结构」的模型(比如根据输入条件分支)。
四、关键注意事项(避免踩坑)
- Keras层/模型中不要混合tf.Module :嵌套在Keras层中的原始
tf.Module不会被Keras收集变量(影响训练和保存),应统一使用Keras层; - 子类化Model无需指定InputLayer :如果子类化
tf.keras.Model时写了InputLayer,会被忽略,输入形状由第一次调用模型时的输入决定; - 函数式API的输入形状可留空 :比如
Input(shape=[None,])表示输入特征数不固定(适合NLP中变长文本的词向量)。
五、总结:Keras的核心价值
Keras 是基于 tf.Module 的高层API,核心优势是「标准化、简化模型开发全流程」:
- 层的标准化 :用
build/call替代手动管理变量创建,支持训练参数、配置保存; - 模型的灵活定义:子类化Model适合动态逻辑,函数式API适合静态结构+可视化;
- 内置训练/评估工具 :
model.fit()封装梯度下降、批量处理,无需手动写训练循环; - 无缝兼容tf.Module特性:自动收集变量、子模块,支持SavedModel保存/加载,跨环境部署。
简单说:tf.Module 是底层骨架,Keras是装修好的房子------保留了骨架的稳定性,同时提供了家具(训练工具)、装修(可视化)、售后(配置保存),让你不用关注底层细节,专注于模型逻辑。