tensorflow 零基础吃透:RaggedTensor 的底层编码原理

零基础吃透:RaggedTensor的底层编码原理

RaggedTensor的核心设计是**"扁平化存储+行分区描述"** ------ 不直接存储嵌套列表(低效),而是将所有有效元素扁平存储在values张量中,再通过row_partition(行分区)描述"如何将扁平值拆分为可变长度的行"。以下从「核心编码结构」「四种行分区编码」「多不规则维度」「不规则秩与扁平值」「均匀维度编码」五大模块,拆解底层原理和用法。

一、RaggedTensor的核心编码结构

核心公式

复制代码
RaggedTensor = values(扁平张量) + row_partition(行分区规则)
  • values:所有有效元素按顺序拼接成的一维/多维扁平张量(无嵌套,无空值);
  • row_partition:描述"如何将values拆分为可变长度行"的规则,支持4种编码方式(下文详解)。

基础示例(row_splits编码)

python 复制代码
import tensorflow as tf

# 构造RaggedTensor:values+row_splits
rt = tf.RaggedTensor.from_row_splits(
    values=[3, 1, 4, 1, 5, 9, 2],  # 所有有效元素的扁平列表
    row_splits=[0, 4, 4, 6, 7]      # 行拆分点:[起始, 行0结束, 行1结束, 行2结束, 行3结束]
)
print("构造的RaggedTensor:", rt)

输出<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>

拆分点逻辑(关键!)

row_splits的每个值是values的索引,定义每行的元素范围:

行索引 拆分点范围 values切片 行内容
0 0 → 4 values[0:4] → [3,1,4,1] [3,1,4,1]
1 4 → 4 values[4:4] → [] [](空行)
2 4 → 6 values[4:6] → [5,9] [5,9]
3 6 → 7 values[6:7] → [2] [2]

二、四种row_partition编码方式(行分区规则)

TF内部管理行分区的编码方式,不同编码适配不同场景(效率/兼容性),以下是4种核心编码的原理、示例和优缺点:

1. row_splits(拆分点编码)

定义

一维整型向量,每个值表示values中"行的结束索引",长度=行数+1(首元素必为0,末元素必为values长度)。

示例(复用上文)
python 复制代码
rt = tf.RaggedTensor.from_row_splits(values=[3,1,4,1,5,9,2], row_splits=[0,4,4,6,7])
print("row_splits:", rt.row_splits.numpy())  # 直接访问拆分点

输出[0 4 4 6 7]

核心优缺点

✅ 优点:恒定时间索引/切片 (直接通过拆分点定位行),适合频繁索引的场景;

❌ 缺点:空行仍需占用拆分点位置,存储大量空行时效率低。

2. value_rowids(值的行索引编码)

定义

一维整型向量,长度=values长度,每个值表示"对应values元素所属的行索引"。

示例
python 复制代码
# 构造:values=[3,1,4,1,5,9,2],对应行索引[0,0,0,0,2,2,3]
rt = tf.RaggedTensor.from_value_rowids(
    values=[3,1,4,1,5,9,2],
    value_rowids=[0,0,0,0,2,2,3],
    nrows=4  # 总行数(包含空行1)
)
print("value_rowids构造的RaggedTensor:", rt)
print("value_rowids:", rt.value_rowids.numpy())

输出

复制代码
value_rowids构造的RaggedTensor: <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>
value_rowids: [0 0 0 0 2 2 3]
核心优缺点

✅ 优点:

  • 存储大量空行时高效(仅存储有效元素的行索引,空行无开销);
  • 兼容tf.segment_sum等分段运算(输入格式匹配);
    ❌ 缺点:索引单行需遍历value_rowids,效率低于row_splits。

3. row_lengths(行长度编码)

定义

一维整型向量,长度=行数,每个值表示"对应行的元素长度"(空行长度为0)。

示例
python 复制代码
# 构造:行长度[4,0,2,1] → 对应行0:4个元素,行1:0个,行2:2个,行3:1个
rt = tf.RaggedTensor.from_row_lengths(
    values=[3,1,4,1,5,9,2],
    row_lengths=[4,0,2,1]
)
print("row_lengths构造的RaggedTensor:", rt)
print("row_lengths:", rt.row_lengths.numpy())

输出

复制代码
row_lengths构造的RaggedTensor: <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>
row_lengths: [4 0 2 1]
核心优缺点

✅ 优点:拼接(concat)效率高 (拼接时仅需合并row_lengths,无需修改values);

❌ 缺点:空行仍需存储长度0,大量空行时效率低于value_rowids。

4. uniform_row_length(均匀行长度编码)

定义

整型标量,表示"所有行的长度相同"(仅用于"非内层维度为均匀"的场景)。

示例(见下文"均匀非内层维度")
核心优缺点

✅ 优点:存储效率极高(仅需一个标量,无需向量);

❌ 缺点:仅适用于所有行长度相同的场景,通用性差。

四种编码方式对比表

编码方式 存储形式 核心优势 适用场景
row_splits 拆分点向量 索引/切片高效 频繁单行查询、切片操作
value_rowids 行索引向量 大量空行存储高效、兼容segment运算 含大量空行的数据集、分段求和/均值
row_lengths 行长度向量 拼接/拆分高效 频繁拼接多个RaggedTensor
uniform_row_length 长度标量 存储效率最高 非内层维度所有行长度相同的场景

三、多个不规则维度的编码

核心原理

多不规则维度的RaggedTensor通过嵌套RaggedTensor 编码:外层RaggedTensor的values是内层RaggedTensor,每一层嵌套对应一个不规则维度(ragged_rank+1)。

方式1:嵌套from_row_splits构造

python 复制代码
# 外层RaggedTensor:values是内层RaggedTensor,row_splits=[0,1,1,5]
# 内层RaggedTensor:values=[10-19],row_splits=[0,3,3,5,9,10]
rt = tf.RaggedTensor.from_row_splits(
    values=tf.RaggedTensor.from_row_splits(
        values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        row_splits=[0, 3, 3, 5, 9, 10]
    ),
    row_splits=[0, 1, 1, 5]
)
print("多不规则维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩(ragged_rank):", rt.ragged_rank)

输出

复制代码
多不规则维度RaggedTensor: <tf.RaggedTensor [[[10, 11, 12]], [], [[], [13, 14], [15, 16, 17, 18], [19]]]>
形状: (3, None, None)
不规则秩(ragged_rank): 2

方式2:from_nested_row_splits(更简洁)

直接传入"嵌套拆分点列表",无需手动嵌套构造:

python 复制代码
rt = tf.RaggedTensor.from_nested_row_splits(
    flat_values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  # 最内层扁平值
    nested_row_splits=([0, 1, 1, 5], [0, 3, 3, 5, 9, 10])  # 外层+内层拆分点
)
print("from_nested_row_splits构造:", rt)

输出:与嵌套构造完全一致。

四、不规则秩(ragged_rank)与扁平值(flat_values)

1. 不规则秩(ragged_rank)

定义:RaggedTensor的嵌套深度 (即values被分区的次数),等于nested_row_splits的长度。

  • ragged_rank=1:一维不规则(如[[1,2], [3]]);
  • ragged_rank=2:二维不规则(如[[[1], []], [[2,3]]])。

2. 扁平值(flat_values)

定义:最内层的非嵌套张量(所有有效元素的扁平存储),是RaggedTensor的"数据核心"。

示例(ragged_rank=3)

python 复制代码
# 4维结构:[batch, (paragraph), (sentence), (word)]
conversations = tf.ragged.constant(
    [[[["I", "like", "ragged", "tensors."]],
      [["Oh", "yeah?"], ["What", "can", "you", "use", "them", "for?"]],
      [["Processing", "variable", "length", "data!"]]],
     [[["I", "like", "cheese."], ["Do", "you?"]],
      [["Yes."], ["I", "do."]]]])

print("形状:", conversations.shape)
print("不规则秩:", conversations.ragged_rank)
print("flat_values(前10个元素):", conversations.flat_values.numpy()[:10])

输出

复制代码
形状: (2, None, None, None)
不规则秩: 3
flat_values(前10个元素): [b'I' b'like' b'ragged' b'tensors.' b'Oh' b'yeah?' b'What' b'can' b'you' b'use']
关键解读
  • ragged_rank=3:因为有3层不规则维度(paragraph、sentence、word);
  • flat_values:所有单词的一维张量(共24个元素),是整个RaggedTensor的底层数据存储。

五、均匀维度的编码

RaggedTensor允许部分维度为"均匀"(长度固定),分为「均匀内层维度」和「均匀非内层维度」,编码方式不同。

1. 均匀内层维度

定义

最内层维度(flat_values)是多维密集张量(长度固定),外层为不规则维度。

示例
python 复制代码
rt = tf.RaggedTensor.from_row_splits(
    values=[[1, 3], [0, 0], [1, 3], [5, 3], [3, 3], [1, 2]],  # 内层是2列的密集张量
    row_splits=[0, 3, 4, 6]  # 外层不规则:行0=3个元素,行1=1个,行2=2个
)
print("均匀内层维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩:", rt.ragged_rank)
print("flat_values形状:", rt.flat_values.shape)
print("flat_values:\n", rt.flat_values)

输出

复制代码
均匀内层维度RaggedTensor: <tf.RaggedTensor [[[1, 3], [0, 0], [1, 3]], [[5, 3]], [[3, 3], [1, 2]]]>
形状: (3, None, 2)
不规则秩: 1
flat_values形状: (6, 2)
flat_values:
 [[1 3]
 [0 0]
 [1 3]
 [5 3]
 [3 3]
 [1 2]]
关键解读
  • 形状(3, None, 2):3行(均匀)、每行元素数可变(None)、每个元素是2列(均匀);
  • flat_values是(6,2)的密集张量(内层维度固定为2),外层通过row_splits拆分为可变长度行。

2. 均匀非内层维度

定义

非内层维度为均匀(所有行长度相同),通过uniform_row_length编码行分区。

示例
python 复制代码
rt = tf.RaggedTensor.from_uniform_row_length(
    values=tf.RaggedTensor.from_row_splits(
        values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        row_splits=[0, 3, 5, 9, 10]  # 内层不规则
    ),
    uniform_row_length=2  # 外层均匀:每行固定2个元素
)
print("均匀非内层维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩:", rt.ragged_rank)

输出

复制代码
均匀非内层维度RaggedTensor: <tf.RaggedTensor [[[10, 11, 12], [13, 14]], [[15, 16, 17, 18], [19]]]>
形状: (2, 2, None)
不规则秩: 2
关键解读
  • 形状(2, 2, None):2行(均匀)、每行固定2个元素(均匀)、每个元素长度可变(None);
  • 外层通过uniform_row_length=2编码(仅存一个标量),内层通过row_splits编码不规则维度。

核心总结

1. 编码核心逻辑

RaggedTensor的底层是"扁平存储+分区规则",避免嵌套列表的低效存储,四种行分区编码适配不同场景;

2. 多维度扩展

多不规则维度通过嵌套RaggedTensor实现,ragged_rank表示嵌套深度,flat_values是最内层扁平数据;

3. 均匀维度兼容

支持部分维度均匀(内层/非内层),分别通过"多维flat_values"和"uniform_row_length"编码;

4. 性能优化

  • 频繁索引 → 选row_splits;
  • 大量空行 → 选value_rowids;
  • 频繁拼接 → 选row_lengths;
  • 行长度均匀 → 选uniform_row_length。

理解RaggedTensor的编码原理,能帮助你在高性能场景(如大规模可变长度数据处理)中选择最优的构造/操作方式,避免性能瓶颈。

相关推荐
NAGNIP2 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab3 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab3 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP7 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年7 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼7 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS8 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区9 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈9 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang9 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx