零基础吃透:RaggedTensor的底层编码原理
RaggedTensor的核心设计是**"扁平化存储+行分区描述"** ------ 不直接存储嵌套列表(低效),而是将所有有效元素扁平存储在values张量中,再通过row_partition(行分区)描述"如何将扁平值拆分为可变长度的行"。以下从「核心编码结构」「四种行分区编码」「多不规则维度」「不规则秩与扁平值」「均匀维度编码」五大模块,拆解底层原理和用法。
一、RaggedTensor的核心编码结构
核心公式
RaggedTensor = values(扁平张量) + row_partition(行分区规则)
values:所有有效元素按顺序拼接成的一维/多维扁平张量(无嵌套,无空值);row_partition:描述"如何将values拆分为可变长度行"的规则,支持4种编码方式(下文详解)。
基础示例(row_splits编码)
python
import tensorflow as tf
# 构造RaggedTensor:values+row_splits
rt = tf.RaggedTensor.from_row_splits(
values=[3, 1, 4, 1, 5, 9, 2], # 所有有效元素的扁平列表
row_splits=[0, 4, 4, 6, 7] # 行拆分点:[起始, 行0结束, 行1结束, 行2结束, 行3结束]
)
print("构造的RaggedTensor:", rt)
输出 :<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>
拆分点逻辑(关键!)
row_splits的每个值是values的索引,定义每行的元素范围:
| 行索引 | 拆分点范围 | values切片 | 行内容 |
|---|---|---|---|
| 0 | 0 → 4 | values[0:4] → [3,1,4,1] | [3,1,4,1] |
| 1 | 4 → 4 | values[4:4] → [] | [](空行) |
| 2 | 4 → 6 | values[4:6] → [5,9] | [5,9] |
| 3 | 6 → 7 | values[6:7] → [2] | [2] |
二、四种row_partition编码方式(行分区规则)
TF内部管理行分区的编码方式,不同编码适配不同场景(效率/兼容性),以下是4种核心编码的原理、示例和优缺点:
1. row_splits(拆分点编码)
定义
一维整型向量,每个值表示values中"行的结束索引",长度=行数+1(首元素必为0,末元素必为values长度)。
示例(复用上文)
python
rt = tf.RaggedTensor.from_row_splits(values=[3,1,4,1,5,9,2], row_splits=[0,4,4,6,7])
print("row_splits:", rt.row_splits.numpy()) # 直接访问拆分点
输出 :[0 4 4 6 7]
核心优缺点
✅ 优点:恒定时间索引/切片 (直接通过拆分点定位行),适合频繁索引的场景;
❌ 缺点:空行仍需占用拆分点位置,存储大量空行时效率低。
2. value_rowids(值的行索引编码)
定义
一维整型向量,长度=values长度,每个值表示"对应values元素所属的行索引"。
示例
python
# 构造:values=[3,1,4,1,5,9,2],对应行索引[0,0,0,0,2,2,3]
rt = tf.RaggedTensor.from_value_rowids(
values=[3,1,4,1,5,9,2],
value_rowids=[0,0,0,0,2,2,3],
nrows=4 # 总行数(包含空行1)
)
print("value_rowids构造的RaggedTensor:", rt)
print("value_rowids:", rt.value_rowids.numpy())
输出:
value_rowids构造的RaggedTensor: <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>
value_rowids: [0 0 0 0 2 2 3]
核心优缺点
✅ 优点:
- 存储大量空行时高效(仅存储有效元素的行索引,空行无开销);
- 兼容
tf.segment_sum等分段运算(输入格式匹配);
❌ 缺点:索引单行需遍历value_rowids,效率低于row_splits。
3. row_lengths(行长度编码)
定义
一维整型向量,长度=行数,每个值表示"对应行的元素长度"(空行长度为0)。
示例
python
# 构造:行长度[4,0,2,1] → 对应行0:4个元素,行1:0个,行2:2个,行3:1个
rt = tf.RaggedTensor.from_row_lengths(
values=[3,1,4,1,5,9,2],
row_lengths=[4,0,2,1]
)
print("row_lengths构造的RaggedTensor:", rt)
print("row_lengths:", rt.row_lengths.numpy())
输出:
row_lengths构造的RaggedTensor: <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>
row_lengths: [4 0 2 1]
核心优缺点
✅ 优点:拼接(concat)效率高 (拼接时仅需合并row_lengths,无需修改values);
❌ 缺点:空行仍需存储长度0,大量空行时效率低于value_rowids。
4. uniform_row_length(均匀行长度编码)
定义
整型标量,表示"所有行的长度相同"(仅用于"非内层维度为均匀"的场景)。
示例(见下文"均匀非内层维度")
核心优缺点
✅ 优点:存储效率极高(仅需一个标量,无需向量);
❌ 缺点:仅适用于所有行长度相同的场景,通用性差。
四种编码方式对比表
| 编码方式 | 存储形式 | 核心优势 | 适用场景 |
|---|---|---|---|
| row_splits | 拆分点向量 | 索引/切片高效 | 频繁单行查询、切片操作 |
| value_rowids | 行索引向量 | 大量空行存储高效、兼容segment运算 | 含大量空行的数据集、分段求和/均值 |
| row_lengths | 行长度向量 | 拼接/拆分高效 | 频繁拼接多个RaggedTensor |
| uniform_row_length | 长度标量 | 存储效率最高 | 非内层维度所有行长度相同的场景 |
三、多个不规则维度的编码
核心原理
多不规则维度的RaggedTensor通过嵌套RaggedTensor 编码:外层RaggedTensor的values是内层RaggedTensor,每一层嵌套对应一个不规则维度(ragged_rank+1)。
方式1:嵌套from_row_splits构造
python
# 外层RaggedTensor:values是内层RaggedTensor,row_splits=[0,1,1,5]
# 内层RaggedTensor:values=[10-19],row_splits=[0,3,3,5,9,10]
rt = tf.RaggedTensor.from_row_splits(
values=tf.RaggedTensor.from_row_splits(
values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
row_splits=[0, 3, 3, 5, 9, 10]
),
row_splits=[0, 1, 1, 5]
)
print("多不规则维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩(ragged_rank):", rt.ragged_rank)
输出:
多不规则维度RaggedTensor: <tf.RaggedTensor [[[10, 11, 12]], [], [[], [13, 14], [15, 16, 17, 18], [19]]]>
形状: (3, None, None)
不规则秩(ragged_rank): 2
方式2:from_nested_row_splits(更简洁)
直接传入"嵌套拆分点列表",无需手动嵌套构造:
python
rt = tf.RaggedTensor.from_nested_row_splits(
flat_values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19], # 最内层扁平值
nested_row_splits=([0, 1, 1, 5], [0, 3, 3, 5, 9, 10]) # 外层+内层拆分点
)
print("from_nested_row_splits构造:", rt)
输出:与嵌套构造完全一致。
四、不规则秩(ragged_rank)与扁平值(flat_values)
1. 不规则秩(ragged_rank)
定义:RaggedTensor的嵌套深度 (即values被分区的次数),等于nested_row_splits的长度。
- ragged_rank=1:一维不规则(如
[[1,2], [3]]); - ragged_rank=2:二维不规则(如
[[[1], []], [[2,3]]])。
2. 扁平值(flat_values)
定义:最内层的非嵌套张量(所有有效元素的扁平存储),是RaggedTensor的"数据核心"。
示例(ragged_rank=3)
python
# 4维结构:[batch, (paragraph), (sentence), (word)]
conversations = tf.ragged.constant(
[[[["I", "like", "ragged", "tensors."]],
[["Oh", "yeah?"], ["What", "can", "you", "use", "them", "for?"]],
[["Processing", "variable", "length", "data!"]]],
[[["I", "like", "cheese."], ["Do", "you?"]],
[["Yes."], ["I", "do."]]]])
print("形状:", conversations.shape)
print("不规则秩:", conversations.ragged_rank)
print("flat_values(前10个元素):", conversations.flat_values.numpy()[:10])
输出:
形状: (2, None, None, None)
不规则秩: 3
flat_values(前10个元素): [b'I' b'like' b'ragged' b'tensors.' b'Oh' b'yeah?' b'What' b'can' b'you' b'use']
关键解读
- ragged_rank=3:因为有3层不规则维度(paragraph、sentence、word);
- flat_values:所有单词的一维张量(共24个元素),是整个RaggedTensor的底层数据存储。
五、均匀维度的编码
RaggedTensor允许部分维度为"均匀"(长度固定),分为「均匀内层维度」和「均匀非内层维度」,编码方式不同。
1. 均匀内层维度
定义
最内层维度(flat_values)是多维密集张量(长度固定),外层为不规则维度。
示例
python
rt = tf.RaggedTensor.from_row_splits(
values=[[1, 3], [0, 0], [1, 3], [5, 3], [3, 3], [1, 2]], # 内层是2列的密集张量
row_splits=[0, 3, 4, 6] # 外层不规则:行0=3个元素,行1=1个,行2=2个
)
print("均匀内层维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩:", rt.ragged_rank)
print("flat_values形状:", rt.flat_values.shape)
print("flat_values:\n", rt.flat_values)
输出:
均匀内层维度RaggedTensor: <tf.RaggedTensor [[[1, 3], [0, 0], [1, 3]], [[5, 3]], [[3, 3], [1, 2]]]>
形状: (3, None, 2)
不规则秩: 1
flat_values形状: (6, 2)
flat_values:
[[1 3]
[0 0]
[1 3]
[5 3]
[3 3]
[1 2]]
关键解读
- 形状
(3, None, 2):3行(均匀)、每行元素数可变(None)、每个元素是2列(均匀); - flat_values是
(6,2)的密集张量(内层维度固定为2),外层通过row_splits拆分为可变长度行。
2. 均匀非内层维度
定义
非内层维度为均匀(所有行长度相同),通过uniform_row_length编码行分区。
示例
python
rt = tf.RaggedTensor.from_uniform_row_length(
values=tf.RaggedTensor.from_row_splits(
values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
row_splits=[0, 3, 5, 9, 10] # 内层不规则
),
uniform_row_length=2 # 外层均匀:每行固定2个元素
)
print("均匀非内层维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩:", rt.ragged_rank)
输出:
均匀非内层维度RaggedTensor: <tf.RaggedTensor [[[10, 11, 12], [13, 14]], [[15, 16, 17, 18], [19]]]>
形状: (2, 2, None)
不规则秩: 2
关键解读
- 形状
(2, 2, None):2行(均匀)、每行固定2个元素(均匀)、每个元素长度可变(None); - 外层通过
uniform_row_length=2编码(仅存一个标量),内层通过row_splits编码不规则维度。
核心总结
1. 编码核心逻辑
RaggedTensor的底层是"扁平存储+分区规则",避免嵌套列表的低效存储,四种行分区编码适配不同场景;
2. 多维度扩展
多不规则维度通过嵌套RaggedTensor实现,ragged_rank表示嵌套深度,flat_values是最内层扁平数据;
3. 均匀维度兼容
支持部分维度均匀(内层/非内层),分别通过"多维flat_values"和"uniform_row_length"编码;
4. 性能优化
- 频繁索引 → 选row_splits;
- 大量空行 → 选value_rowids;
- 频繁拼接 → 选row_lengths;
- 行长度均匀 → 选uniform_row_length。
理解RaggedTensor的编码原理,能帮助你在高性能场景(如大规模可变长度数据处理)中选择最优的构造/操作方式,避免性能瓶颈。