核心解析:tf.RaggedTensorSpec 作用与参数说明
tf.RaggedTensorSpec 是 TensorFlow 中用于描述不规则张量(RaggedTensor)的"规格/签名" 的类,常用来定义输入签名(如 tf.function、SavedModel、Keras 输入等场景),告诉 TensorFlow 待处理的 RaggedTensor 应满足的形状、数据类型、不规则维度等约束。
逐参数拆解:spec = tf.RaggedTensorSpec(shape=[2, None, None], dtype=tf.int32, ragged_rank=2)
-
shape=[2, None, None]- 定义 RaggedTensor 的"整体形状框架":
- 第一维固定为
2(表示最外层维度有且仅有 2 个元素); - 第二、三维为
None(表示这两个维度的长度是动态可变的,无固定值); - 结合
ragged_rank=2,最终张量的"固定维度"是第 0 维(长度 2),第 1、2 维为不规则维度。
- 第一维固定为
- 定义 RaggedTensor 的"整体形状框架":
-
dtype=tf.int32- 指定该 RaggedTensor 中存储的数据类型为 32 位整型(如
1、5、100等)。
- 指定该 RaggedTensor 中存储的数据类型为 32 位整型(如
-
ragged_rank=2-
核心参数:表示 RaggedTensor 的"不规则等级"(即有多少个连续的不规则维度);
-
此处
ragged_rank=2意味着:从第 1 维开始,连续 2 个维度(第 1、2 维)是不规则的(各元素的子维度长度可不同); -
示例符合该 spec 的 RaggedTensor 结构:
python# 外层固定2个元素,第1、2维长度可变 rt = tf.ragged.constant([ [[1, 2], [3]], # 第0个元素:第1维长度2,第2维分别为2、1 [[4], [5, 6, 7]] # 第1个元素:第1维长度2,第2维分别为1、3 ])
-
核心用途
该 spec 可用于:
- 定义
tf.function的输入签名,约束传入的 RaggedTensor 必须匹配此规格; - 定义 Keras 模型的输入层(适配不规则长度的张量,如变长文本、变长序列);
- 保存/加载 SavedModel 时,明确输入输出的张量类型约束。