TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式

这份文档核心讲解了 TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式 :一种是「极简直接法」(tf.ragged.constant),另一种是「灵活定制法」(工厂方法,从扁平值+行分区规则构造)。以下是通俗化拆解,全程贴合原文例子和逻辑:

一、核心前提

构造 RaggedTensor 的核心是「描述可变长度的嵌套结构」:要么直接给嵌套列表(让 TF 自动识别长度),要么先给扁平值+「行分区规则」(告诉 TF 如何把扁平值拆成可变长度的行)。

二、方法1:最简方式 → tf.ragged.constant

核心逻辑

直接把「嵌套的 Python 列表/NumPy 数组」传给 tf.ragged.constant,TF 会自动识别每一层的可变长度,生成对应的 RaggedTensor(无需手动定义长度规则)。

适用场景

已有现成的嵌套列表/数组,且能直接看出行/层的长度分布(比如手动整理的句子、段落数据)。

原文例子拆解
示例1:二维 RaggedTensor(句子,每行是单词列表)
python 复制代码
# 输入:2个句子(第一句6个单词,第二句5个单词,长度可变)
sentences = tf.ragged.constant([
    ["Let's", "build", "some", "ragged", "tensors", "!"],
    ["We", "can", "use", "tf.ragged.constant", "."]])
print(sentences)
# 输出:二维RaggedTensor,保留每行的可变长度
<tf.RaggedTensor [[b"Let's", b'build', b'some', b'ragged', b'tensors', b'!'],
 [b'We', b'can', b'use', b'tf.ragged.constant', b'.']]>
  • 关键:字符串会自动转成字节型(b前缀),不影响使用;每行长度可不同(6个 vs 5个)。
示例2:三维 RaggedTensor(段落,每段包含多个句子,句子长度可变)
python 复制代码
# 输入:2个段落 → 每个段落包含2个句子 → 句子长度可变
paragraphs = tf.ragged.constant([
    [['I', 'have', 'a', 'cat'], ['His', 'name', 'is', 'Mat']],  # 段落1:2个句子(各4个单词)
    [['Do', 'you', 'want', 'to', 'come', 'visit'], ["I'm", 'free', 'tomorrow']],  # 段落2:句子1(6词)、句子2(3词)
])
print(paragraphs)
# 输出:三维RaggedTensor,每层长度都可变
<tf.RaggedTensor [[[b'I', b'have', b'a', b'cat'], [b'His', b'name', b'is', b'Mat']],
 [[b'Do', b'you', b'want', b'to', b'come', b'visit'],
  [b"I'm", b'free', b'tomorrow']]]>

三、方法2:灵活定制 → 工厂方法(扁平值+行分区张量)

核心逻辑

当数据是「扁平的一维列表」,但知道「每个值属于哪一行/每行长度/每行起止索引」时,用工厂方法把扁平值拆成可变长度的 RaggedTensor。

适用场景

数据是扁平存储的(比如从文件/数据库读取的一维值列表),但有额外的"行分区规则"(比如记录了每个值所属的行号、每行的长度)。

原文3种工厂方法拆解(目标都是构造 [[3,1,4,1], [], [5,9], [2]]

所有方法的输入核心:values(扁平值列表) + 「行分区规则」,以下是一一对应解释:

工厂方法 行分区规则(参数) 原文例子拆解
from_value_rowids value_rowids(每个值的行号) values=[3,1,4,1,5,9,2] value_rowids=[0,0,0,0,2,2,3] → 3/1/4/1属于行0,5/9属于行2,2属于行3,行1无值→空行
from_row_lengths row_lengths(每行的长度) values=[3,1,4,1,5,9,2] row_lengths=[4,0,2,1] → 行0长度4(取前4个值),行1长度0(空),行2长度2(取5/9),行3长度1(取2)
from_row_splits row_splits(每行起止索引) values=[3,1,4,1,5,9,2] row_splits=[0,4,4,6,7] → 行0:04(3/1/4/1),行1:44(空),行2:46(5/9),行3:67(2)
原文代码验证(3种方法结果完全一致)
python 复制代码
# 方法1:from_value_rowids
print(tf.RaggedTensor.from_value_rowids(
    values=[3, 1, 4, 1, 5, 9, 2],
    value_rowids=[0, 0, 0, 0, 2, 2, 3]))
# 输出:<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>

# 方法2:from_row_lengths
print(tf.RaggedTensor.from_row_lengths(
    values=[3, 1, 4, 1, 5, 9, 2],
    row_lengths=[4, 0, 2, 1]))
# 输出同上

# 方法3:from_row_splits
print(tf.RaggedTensor.from_row_splits(
    values=[3, 1, 4, 1, 5, 9, 2],
    row_splits=[0, 4, 4, 6, 7]))
# 输出同上

四、关键注意事项(原文补充)

  1. 校验机制:默认情况下,工厂方法会校验「行分区规则」和「values长度」是否匹配(比如 row_lengths 总和是否等于 values 长度),避免构造出非法的 RaggedTensor;

  2. 性能优化 :如果能100%保证输入的行分区规则合法(比如提前校验过),可以加 validate=False 跳过校验,提升构造速度:

    python 复制代码
    # 跳过校验,加速构造
    tf.RaggedTensor.from_row_lengths(values=[3,1,4,1,5,9,2], row_lengths=[4,0,2,1], validate=False)

核心总结

构造方法 优点 缺点 适用场景
tf.ragged.constant 写法极简,直观 依赖现成的嵌套列表 手动整理的嵌套数据(句子、段落)
工厂方法(from_*) 灵活,适配扁平数据源 需要定义行分区规则 扁平存储的数据(文件/数据库读取)

两种方法最终都能构造出 RaggedTensor,核心是根据「数据的原始形态」选择:有嵌套列表用 constant,有扁平值+行规则用工厂方法。

相关推荐
Java后端的Ai之路14 分钟前
【神经网络基础】-神经网络学习全过程(大白话版)
人工智能·深度学习·神经网络·学习
庚昀◟28 分钟前
用AI来“造AI”!Nexent部署本地智能体的沉浸式体验
人工智能·ai·nlp·持续部署
喜欢吃豆41 分钟前
OpenAI Realtime API 深度技术架构与实现指南——如何实现AI实时通话
人工智能·语言模型·架构·大模型
数据分析能量站43 分钟前
AI如何重塑个人生产力、组织架构和经济模式
人工智能
wscats2 小时前
Markdown 编辑器技术调研
前端·人工智能·markdown
AI科技星2 小时前
张祥前统一场论宇宙大统一方程的求导验证
服务器·人工智能·科技·线性代数·算法·生活
GIS数据转换器2 小时前
基于知识图谱的个性化旅游规划平台
人工智能·3d·无人机·知识图谱·旅游
EnoYao2 小时前
Markdown 编辑器技术调研
前端·javascript·人工智能
TMT星球2 小时前
曹操出行上市后首次战略并购,进军万亿to B商旅市场
人工智能·汽车
Coder_Boy_2 小时前
Spring AI 源码大白话解析
java·人工智能·spring