tensorflow 零基础吃透:RaggedTensor 的底层编码原理

零基础吃透:RaggedTensor的底层编码原理

RaggedTensor的核心设计是**"扁平化存储+行分区描述"** ------ 不直接存储嵌套列表(低效),而是将所有有效元素扁平存储在values张量中,再通过row_partition(行分区)描述"如何将扁平值拆分为可变长度的行"。以下从「核心编码结构」「四种行分区编码」「多不规则维度」「不规则秩与扁平值」「均匀维度编码」五大模块,拆解底层原理和用法。

一、RaggedTensor的核心编码结构

核心公式

复制代码
RaggedTensor = values(扁平张量) + row_partition(行分区规则)
  • values:所有有效元素按顺序拼接成的一维/多维扁平张量(无嵌套,无空值);
  • row_partition:描述"如何将values拆分为可变长度行"的规则,支持4种编码方式(下文详解)。

基础示例(row_splits编码)

python 复制代码
import tensorflow as tf

# 构造RaggedTensor:values+row_splits
rt = tf.RaggedTensor.from_row_splits(
    values=[3, 1, 4, 1, 5, 9, 2],  # 所有有效元素的扁平列表
    row_splits=[0, 4, 4, 6, 7]      # 行拆分点:[起始, 行0结束, 行1结束, 行2结束, 行3结束]
)
print("构造的RaggedTensor:", rt)

输出<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>

拆分点逻辑(关键!)

row_splits的每个值是values的索引,定义每行的元素范围:

行索引 拆分点范围 values切片 行内容
0 0 → 4 values[0:4] → [3,1,4,1] [3,1,4,1]
1 4 → 4 values[4:4] → [] [](空行)
2 4 → 6 values[4:6] → [5,9] [5,9]
3 6 → 7 values[6:7] → [2] [2]

二、四种row_partition编码方式(行分区规则)

TF内部管理行分区的编码方式,不同编码适配不同场景(效率/兼容性),以下是4种核心编码的原理、示例和优缺点:

1. row_splits(拆分点编码)

定义

一维整型向量,每个值表示values中"行的结束索引",长度=行数+1(首元素必为0,末元素必为values长度)。

示例(复用上文)
python 复制代码
rt = tf.RaggedTensor.from_row_splits(values=[3,1,4,1,5,9,2], row_splits=[0,4,4,6,7])
print("row_splits:", rt.row_splits.numpy())  # 直接访问拆分点

输出[0 4 4 6 7]

核心优缺点

✅ 优点:恒定时间索引/切片 (直接通过拆分点定位行),适合频繁索引的场景;

❌ 缺点:空行仍需占用拆分点位置,存储大量空行时效率低。

2. value_rowids(值的行索引编码)

定义

一维整型向量,长度=values长度,每个值表示"对应values元素所属的行索引"。

示例
python 复制代码
# 构造:values=[3,1,4,1,5,9,2],对应行索引[0,0,0,0,2,2,3]
rt = tf.RaggedTensor.from_value_rowids(
    values=[3,1,4,1,5,9,2],
    value_rowids=[0,0,0,0,2,2,3],
    nrows=4  # 总行数(包含空行1)
)
print("value_rowids构造的RaggedTensor:", rt)
print("value_rowids:", rt.value_rowids.numpy())

输出

复制代码
value_rowids构造的RaggedTensor: <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>
value_rowids: [0 0 0 0 2 2 3]
核心优缺点

✅ 优点:

  • 存储大量空行时高效(仅存储有效元素的行索引,空行无开销);
  • 兼容tf.segment_sum等分段运算(输入格式匹配);
    ❌ 缺点:索引单行需遍历value_rowids,效率低于row_splits。

3. row_lengths(行长度编码)

定义

一维整型向量,长度=行数,每个值表示"对应行的元素长度"(空行长度为0)。

示例
python 复制代码
# 构造:行长度[4,0,2,1] → 对应行0:4个元素,行1:0个,行2:2个,行3:1个
rt = tf.RaggedTensor.from_row_lengths(
    values=[3,1,4,1,5,9,2],
    row_lengths=[4,0,2,1]
)
print("row_lengths构造的RaggedTensor:", rt)
print("row_lengths:", rt.row_lengths.numpy())

输出

复制代码
row_lengths构造的RaggedTensor: <tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>
row_lengths: [4 0 2 1]
核心优缺点

✅ 优点:拼接(concat)效率高 (拼接时仅需合并row_lengths,无需修改values);

❌ 缺点:空行仍需存储长度0,大量空行时效率低于value_rowids。

4. uniform_row_length(均匀行长度编码)

定义

整型标量,表示"所有行的长度相同"(仅用于"非内层维度为均匀"的场景)。

示例(见下文"均匀非内层维度")
核心优缺点

✅ 优点:存储效率极高(仅需一个标量,无需向量);

❌ 缺点:仅适用于所有行长度相同的场景,通用性差。

四种编码方式对比表

编码方式 存储形式 核心优势 适用场景
row_splits 拆分点向量 索引/切片高效 频繁单行查询、切片操作
value_rowids 行索引向量 大量空行存储高效、兼容segment运算 含大量空行的数据集、分段求和/均值
row_lengths 行长度向量 拼接/拆分高效 频繁拼接多个RaggedTensor
uniform_row_length 长度标量 存储效率最高 非内层维度所有行长度相同的场景

三、多个不规则维度的编码

核心原理

多不规则维度的RaggedTensor通过嵌套RaggedTensor 编码:外层RaggedTensor的values是内层RaggedTensor,每一层嵌套对应一个不规则维度(ragged_rank+1)。

方式1:嵌套from_row_splits构造

python 复制代码
# 外层RaggedTensor:values是内层RaggedTensor,row_splits=[0,1,1,5]
# 内层RaggedTensor:values=[10-19],row_splits=[0,3,3,5,9,10]
rt = tf.RaggedTensor.from_row_splits(
    values=tf.RaggedTensor.from_row_splits(
        values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        row_splits=[0, 3, 3, 5, 9, 10]
    ),
    row_splits=[0, 1, 1, 5]
)
print("多不规则维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩(ragged_rank):", rt.ragged_rank)

输出

复制代码
多不规则维度RaggedTensor: <tf.RaggedTensor [[[10, 11, 12]], [], [[], [13, 14], [15, 16, 17, 18], [19]]]>
形状: (3, None, None)
不规则秩(ragged_rank): 2

方式2:from_nested_row_splits(更简洁)

直接传入"嵌套拆分点列表",无需手动嵌套构造:

python 复制代码
rt = tf.RaggedTensor.from_nested_row_splits(
    flat_values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],  # 最内层扁平值
    nested_row_splits=([0, 1, 1, 5], [0, 3, 3, 5, 9, 10])  # 外层+内层拆分点
)
print("from_nested_row_splits构造:", rt)

输出:与嵌套构造完全一致。

四、不规则秩(ragged_rank)与扁平值(flat_values)

1. 不规则秩(ragged_rank)

定义:RaggedTensor的嵌套深度 (即values被分区的次数),等于nested_row_splits的长度。

  • ragged_rank=1:一维不规则(如[[1,2], [3]]);
  • ragged_rank=2:二维不规则(如[[[1], []], [[2,3]]])。

2. 扁平值(flat_values)

定义:最内层的非嵌套张量(所有有效元素的扁平存储),是RaggedTensor的"数据核心"。

示例(ragged_rank=3)

python 复制代码
# 4维结构:[batch, (paragraph), (sentence), (word)]
conversations = tf.ragged.constant(
    [[[["I", "like", "ragged", "tensors."]],
      [["Oh", "yeah?"], ["What", "can", "you", "use", "them", "for?"]],
      [["Processing", "variable", "length", "data!"]]],
     [[["I", "like", "cheese."], ["Do", "you?"]],
      [["Yes."], ["I", "do."]]]])

print("形状:", conversations.shape)
print("不规则秩:", conversations.ragged_rank)
print("flat_values(前10个元素):", conversations.flat_values.numpy()[:10])

输出

复制代码
形状: (2, None, None, None)
不规则秩: 3
flat_values(前10个元素): [b'I' b'like' b'ragged' b'tensors.' b'Oh' b'yeah?' b'What' b'can' b'you' b'use']
关键解读
  • ragged_rank=3:因为有3层不规则维度(paragraph、sentence、word);
  • flat_values:所有单词的一维张量(共24个元素),是整个RaggedTensor的底层数据存储。

五、均匀维度的编码

RaggedTensor允许部分维度为"均匀"(长度固定),分为「均匀内层维度」和「均匀非内层维度」,编码方式不同。

1. 均匀内层维度

定义

最内层维度(flat_values)是多维密集张量(长度固定),外层为不规则维度。

示例
python 复制代码
rt = tf.RaggedTensor.from_row_splits(
    values=[[1, 3], [0, 0], [1, 3], [5, 3], [3, 3], [1, 2]],  # 内层是2列的密集张量
    row_splits=[0, 3, 4, 6]  # 外层不规则:行0=3个元素,行1=1个,行2=2个
)
print("均匀内层维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩:", rt.ragged_rank)
print("flat_values形状:", rt.flat_values.shape)
print("flat_values:\n", rt.flat_values)

输出

复制代码
均匀内层维度RaggedTensor: <tf.RaggedTensor [[[1, 3], [0, 0], [1, 3]], [[5, 3]], [[3, 3], [1, 2]]]>
形状: (3, None, 2)
不规则秩: 1
flat_values形状: (6, 2)
flat_values:
 [[1 3]
 [0 0]
 [1 3]
 [5 3]
 [3 3]
 [1 2]]
关键解读
  • 形状(3, None, 2):3行(均匀)、每行元素数可变(None)、每个元素是2列(均匀);
  • flat_values是(6,2)的密集张量(内层维度固定为2),外层通过row_splits拆分为可变长度行。

2. 均匀非内层维度

定义

非内层维度为均匀(所有行长度相同),通过uniform_row_length编码行分区。

示例
python 复制代码
rt = tf.RaggedTensor.from_uniform_row_length(
    values=tf.RaggedTensor.from_row_splits(
        values=[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        row_splits=[0, 3, 5, 9, 10]  # 内层不规则
    ),
    uniform_row_length=2  # 外层均匀:每行固定2个元素
)
print("均匀非内层维度RaggedTensor:", rt)
print("形状:", rt.shape)
print("不规则秩:", rt.ragged_rank)

输出

复制代码
均匀非内层维度RaggedTensor: <tf.RaggedTensor [[[10, 11, 12], [13, 14]], [[15, 16, 17, 18], [19]]]>
形状: (2, 2, None)
不规则秩: 2
关键解读
  • 形状(2, 2, None):2行(均匀)、每行固定2个元素(均匀)、每个元素长度可变(None);
  • 外层通过uniform_row_length=2编码(仅存一个标量),内层通过row_splits编码不规则维度。

核心总结

1. 编码核心逻辑

RaggedTensor的底层是"扁平存储+分区规则",避免嵌套列表的低效存储,四种行分区编码适配不同场景;

2. 多维度扩展

多不规则维度通过嵌套RaggedTensor实现,ragged_rank表示嵌套深度,flat_values是最内层扁平数据;

3. 均匀维度兼容

支持部分维度均匀(内层/非内层),分别通过"多维flat_values"和"uniform_row_length"编码;

4. 性能优化

  • 频繁索引 → 选row_splits;
  • 大量空行 → 选value_rowids;
  • 频繁拼接 → 选row_lengths;
  • 行长度均匀 → 选uniform_row_length。

理解RaggedTensor的编码原理,能帮助你在高性能场景(如大规模可变长度数据处理)中选择最优的构造/操作方式,避免性能瓶颈。

相关推荐
大佐不会说日语~6 小时前
Spring AI Alibaba 对话记忆丢失问题:Redis 缓存过期后如何恢复 AI 上下文
java·人工智能·spring boot·redis·spring·缓存
渡我白衣7 小时前
计算机组成原理(6):进位计数制
c++·人工智能·深度学习·神经网络·机器学习·硬件工程
古城小栈7 小时前
Spring AI 1.1:快速接入主流 LLM,实现智能问答与文本生成
java·人工智能·spring boot·spring
tap.AI7 小时前
图片转文字技术(二)AI翻译的核心技术解析-从神经网络到多模态融合
人工智能·深度学习·神经网络
东坡肘子7 小时前
周日小插曲 -- 肘子的 Swift 周报 #115
人工智能·swiftui·swift
jifengzhiling7 小时前
卡尔曼增益:动态权重,最优估计
人工智能·算法·机器学习
emfuture7 小时前
传统劳动密集型加工厂,面对日益普及的自动化技术,应如何实现转型升级?
大数据·人工智能·智能制造·工业互联网
Zzz 小生7 小时前
Github-Lobe Chat:下一代开源AI聊天框架,重新定义人机交互体验
人工智能·开源·github·人机交互
说私域7 小时前
新零售第一阶段传统零售商的困境突破与二次增长路径——基于定制开发AI智能名片S2B2C商城小程序的实践研究
人工智能·小程序·零售