零基础吃透:SavedModel与RaggedTensor的结合使用
核心背景(先理清)
SavedModel 是 TensorFlow 官方的模型序列化格式 ,能完整保存模型的「权重+计算图+签名」,支持跨平台部署(如TensorFlow Serving、TFLite)、离线复用;
RaggedTensor 与 SavedModel 兼容的核心规则:
- TF 2.3+:通过「具体函数(Concrete Function)」原生支持 RaggedTensor,无需额外处理;
- TF 2.3 前:需将 RaggedTensor 拆解为「values(元素值)+ row_splits(行分割点)」两个张量,再保存/加载。
下面结合你提供的两个核心示例(Keras模型、自定义tf.Module模型),逐行解析代码逻辑、原理和关键注意事项。
前置准备(确保代码可运行)
python
import tensorflow as tf
import tempfile # 用于创建临时目录存储SavedModel
示例1:保存/加载支持RaggedTensor的Keras模型
步骤1:先重建之前的Keras模型(衔接上下文)
python
# 1. 定义数据(复用之前的句子分类任务)
sentences = tf.constant(['What makes you think she is a witch?', 'She turned me into a newt.', 'A newt?', 'Well, I got better.'])
is_question = tf.constant([True, False, True, False])
# 2. 预处理:字符串→RaggedTensor(单词哈希编码)
hash_buckets = 1000
words = tf.strings.split(sentences, ' ')
hashed_words = tf.strings.to_hash_bucket_fast(words, hash_buckets) # RaggedTensor
# 3. 构建适配RaggedTensor的Keras模型(修正后版本,避免LSTM报错)
keras_model = tf.keras.Sequential([
tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True), # 声明Ragged输入
tf.keras.layers.Embedding(hash_buckets, 16),
tf.keras.layers.LSTM(32, use_bias=False),
tf.keras.layers.Dense(32),
tf.keras.layers.Activation(tf.nn.relu),
tf.keras.layers.Dense(1)
])
keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
keras_model.fit(hashed_words.to_tensor(default_value=0), is_question, epochs=1) # 补0后训练
步骤2:保存Keras模型(支持RaggedTensor输入)
python
# 创建临时目录(避免手动管理路径)
keras_module_path = tempfile.mkdtemp()
# 保存模型:自动将Keras模型转换为SavedModel格式,包含RaggedTensor的处理逻辑
tf.saved_model.save(keras_model, keras_module_path)
print(f"Keras模型已保存到:{keras_module_path}")
步骤3:加载SavedModel并调用(传入RaggedTensor)
python
# 加载保存的模型
imported_model = tf.saved_model.load(keras_module_path)
# 直接传入RaggedTensor(hashed_words),模型透明处理
result = imported_model(hashed_words)
print("\n加载后模型的预测结果:")
print(result)
运行结果+关键解读
WARNING:absl:Function `_wrapped_model` contains input name(s) args_0 with unsupported characters which will be renamed to args_0_1 in the SavedModel.
INFO:tensorflow:Assets written to: /tmp/xxx/assets
<tf.Tensor: shape=(4, 1), dtype=float32, numpy=
array([[0.05265009],
[0.000567 ],
[0.03915224],
[0.0021234 ]], dtype=float32)>
- 警告含义(非错误) :
SavedModel对输入名称的字符有规范,自动将不兼容的名称(如args_0)重命名为args_0_1,不影响模型功能; - 核心兼容逻辑 :
Keras模型的Input层声明了ragged=True,保存为SavedModel时,TF会自动生成适配RaggedTensor的具体函数(Concrete Function),加载后可直接传入RaggedTensor; - 结果一致性 :
加载后模型的预测结果与原Keras模型完全一致,说明RaggedTensor的处理逻辑被完整保存。
示例2:保存/加载支持RaggedTensor的自定义tf.Module模型
自定义tf.Module是TF原生的模型封装方式(比Keras更灵活),但需手动构建具体函数(指定RaggedTensor的输入签名),否则SavedModel无法正确处理RaggedTensor。
步骤1:定义自定义tf.Module(含RaggedTensor运算)
python
class CustomModule(tf.Module):
def __init__(self, variable_value):
super(CustomModule, self).__init__()
# 定义可训练变量(会被SavedModel保存)
self.v = tf.Variable(variable_value, dtype=tf.float32)
# 用@tf.function装饰:编译为计算图,支持RaggedTensor
@tf.function
def grow(self, x):
# 核心逻辑:RaggedTensor的每个元素 × 变量v
return x * self.v
# 实例化模块(变量v=100.0)
module = CustomModule(100.0)
步骤2:预构建具体函数(关键!适配RaggedTensor)
python
# 必须先构建具体函数:指定输入为RaggedTensorSpec(二维、float32、可变长度)
# 作用:让SavedModel记录RaggedTensor的输入签名,避免加载后调用报错
concrete_func = module.grow.get_concrete_function(
tf.RaggedTensorSpec(shape=[None, None], dtype=tf.float32)
)
print("预构建的具体函数:", concrete_func)
步骤3:保存自定义模块
python
# 创建临时目录
custom_module_path = tempfile.mkdtemp()
# 保存模块:包含变量v、grow函数的计算图、RaggedTensor的输入签名
tf.saved_model.save(module, custom_module_path)
print(f"\n自定义模型已保存到:{custom_module_path}")
步骤4:加载并调用(传入RaggedTensor)
python
# 加载保存的模块
imported_module = tf.saved_model.load(custom_module_path)
# 传入RaggedTensor调用grow函数
ragged_input = tf.ragged.constant([[1.0, 4.0, 3.0], [2.0]], dtype=tf.float32)
result = imported_module.grow(ragged_input)
print("\n自定义模型调用结果:")
print(result)
运行结果+关键解读
INFO:tensorflow:Assets written to: /tmp/yyy/assets
<tf.RaggedTensor [[100.0, 400.0, 300.0], [200.0]]>
- 为什么必须预构建具体函数?
自定义tf.Module的@tf.function函数默认是"动态跟踪"的,未指定输入签名时,SavedModel无法确定输入类型(密集张量/RaggedTensor);
用get_concrete_function+tf.RaggedTensorSpec预构建后,SavedModel会固化RaggedTensor的处理逻辑,加载后可直接调用。 - 运算逻辑验证 :
RaggedTensor的每个元素 × 变量v(100.0),结果保留RaggedTensor结构:[1.0,4.0,3.0] × 100 → [100.0,400.0,300.0];[2.0] × 100 → [200.0]。
关键注意事项(避坑核心)
1. 版本要求(重中之重)
-
TF 2.3+:具体函数原生支持RaggedTensor,无需额外处理;
-
TF 2.3 前:SavedModel的签名不支持RaggedTensor,需手动拆解RaggedTensor为两个张量:
python# 低版本兼容:拆解RaggedTensor rt = tf.ragged.constant([[1.0,2.0], [3.0]]) rt_values = rt.values # 所有元素值:[1.0,2.0,3.0] rt_row_splits = rt.row_splits # 行分割点:[0,2,3] # 保存/加载时传递values+row_splits,加载后用tf.RaggedTensor.from_row_splits重构
2. 具体函数(Concrete Function)的必要性
- Keras模型:TF自动为
ragged=True的Input层生成具体函数,无需手动构建; - 自定义tf.Module:必须用
get_concrete_function(tf.RaggedTensorSpec)预构建,否则加载后调用RaggedTensor会报错。
3. tf.RaggedTensorSpec的作用
- 定义RaggedTensor的输入签名(形状、 dtype),让SavedModel明确输入约束;
- 示例中
shape=[None, None]表示二维RaggedTensor,两个维度长度均可变; - 若需固定某一维(如固定批次大小为32),可写
shape=[32, None]。
4. SavedModel的核心组成(了解即可)
保存后的SavedModel目录包含:
assets/:静态资源(如词汇表);variables/:模型权重(如CustomModule的变量v);saved_model.pb:计算图+签名(包含RaggedTensor的处理逻辑)。
核心总结(SavedModel + RaggedTensor)
| 模型类型 | 保存关键步骤 | 加载后调用方式 |
|---|---|---|
| Keras模型 | Input层设置ragged=True,直接tf.saved_model.save |
直接传入RaggedTensor |
| 自定义tf.Module | 用get_concrete_function(tf.RaggedTensorSpec)预构建具体函数,再保存 |
传入符合RaggedTensorSpec的RaggedTensor |
核心原则
- SavedModel对RaggedTensor的支持依赖「具体函数(Concrete Function)」,需确保保存前生成适配RaggedTensor的具体函数;
- TF 2.3+是兼容的最低版本,低版本需拆解RaggedTensor为分量张量;
- Keras模型的兼容性更"傻瓜化"(自动处理),自定义tf.Module需手动指定输入签名。
这套方案是生产环境中部署"处理可变长度数据(文本、序列)"模型的标准流程,既保留RaggedTensor无冗余的优势,又能利用SavedModel实现模型的序列化和部署。