TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式

这份文档核心讲解了 TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式 :一种是「极简直接法」(tf.ragged.constant),另一种是「灵活定制法」(工厂方法,从扁平值+行分区规则构造)。以下是通俗化拆解,全程贴合原文例子和逻辑:

一、核心前提

构造 RaggedTensor 的核心是「描述可变长度的嵌套结构」:要么直接给嵌套列表(让 TF 自动识别长度),要么先给扁平值+「行分区规则」(告诉 TF 如何把扁平值拆成可变长度的行)。

二、方法1:最简方式 → tf.ragged.constant

核心逻辑

直接把「嵌套的 Python 列表/NumPy 数组」传给 tf.ragged.constant,TF 会自动识别每一层的可变长度,生成对应的 RaggedTensor(无需手动定义长度规则)。

适用场景

已有现成的嵌套列表/数组,且能直接看出行/层的长度分布(比如手动整理的句子、段落数据)。

原文例子拆解
示例1:二维 RaggedTensor(句子,每行是单词列表)
python 复制代码
# 输入:2个句子(第一句6个单词,第二句5个单词,长度可变)
sentences = tf.ragged.constant([
    ["Let's", "build", "some", "ragged", "tensors", "!"],
    ["We", "can", "use", "tf.ragged.constant", "."]])
print(sentences)
# 输出:二维RaggedTensor,保留每行的可变长度
<tf.RaggedTensor [[b"Let's", b'build', b'some', b'ragged', b'tensors', b'!'],
 [b'We', b'can', b'use', b'tf.ragged.constant', b'.']]>
  • 关键:字符串会自动转成字节型(b前缀),不影响使用;每行长度可不同(6个 vs 5个)。
示例2:三维 RaggedTensor(段落,每段包含多个句子,句子长度可变)
python 复制代码
# 输入:2个段落 → 每个段落包含2个句子 → 句子长度可变
paragraphs = tf.ragged.constant([
    [['I', 'have', 'a', 'cat'], ['His', 'name', 'is', 'Mat']],  # 段落1:2个句子(各4个单词)
    [['Do', 'you', 'want', 'to', 'come', 'visit'], ["I'm", 'free', 'tomorrow']],  # 段落2:句子1(6词)、句子2(3词)
])
print(paragraphs)
# 输出:三维RaggedTensor,每层长度都可变
<tf.RaggedTensor [[[b'I', b'have', b'a', b'cat'], [b'His', b'name', b'is', b'Mat']],
 [[b'Do', b'you', b'want', b'to', b'come', b'visit'],
  [b"I'm", b'free', b'tomorrow']]]>

三、方法2:灵活定制 → 工厂方法(扁平值+行分区张量)

核心逻辑

当数据是「扁平的一维列表」,但知道「每个值属于哪一行/每行长度/每行起止索引」时,用工厂方法把扁平值拆成可变长度的 RaggedTensor。

适用场景

数据是扁平存储的(比如从文件/数据库读取的一维值列表),但有额外的"行分区规则"(比如记录了每个值所属的行号、每行的长度)。

原文3种工厂方法拆解(目标都是构造 [[3,1,4,1], [], [5,9], [2]]

所有方法的输入核心:values(扁平值列表) + 「行分区规则」,以下是一一对应解释:

工厂方法 行分区规则(参数) 原文例子拆解
from_value_rowids value_rowids(每个值的行号) values=3,1,4,1,5,9,2 value_rowids=0,0,0,0,2,2,3 → 3/1/4/1属于行0,5/9属于行2,2属于行3,行1无值→空行
from_row_lengths row_lengths(每行的长度) values=3,1,4,1,5,9,2 row_lengths=4,0,2,1 → 行0长度4(取前4个值),行1长度0(空),行2长度2(取5/9),行3长度1(取2)
from_row_splits row_splits(每行起止索引) values=3,1,4,1,5,9,2 row_splits=0,4,4,6,7 → 行0:04(3/1/4/1),行1:44(空),行2:46(5/9),行3:67(2)
原文代码验证(3种方法结果完全一致)
python 复制代码
# 方法1:from_value_rowids
print(tf.RaggedTensor.from_value_rowids(
    values=[3, 1, 4, 1, 5, 9, 2],
    value_rowids=[0, 0, 0, 0, 2, 2, 3]))
# 输出:<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>

# 方法2:from_row_lengths
print(tf.RaggedTensor.from_row_lengths(
    values=[3, 1, 4, 1, 5, 9, 2],
    row_lengths=[4, 0, 2, 1]))
# 输出同上

# 方法3:from_row_splits
print(tf.RaggedTensor.from_row_splits(
    values=[3, 1, 4, 1, 5, 9, 2],
    row_splits=[0, 4, 4, 6, 7]))
# 输出同上

四、关键注意事项(原文补充)

  1. 校验机制:默认情况下,工厂方法会校验「行分区规则」和「values长度」是否匹配(比如 row_lengths 总和是否等于 values 长度),避免构造出非法的 RaggedTensor;

  2. 性能优化 :如果能100%保证输入的行分区规则合法(比如提前校验过),可以加 validate=False 跳过校验,提升构造速度:

    python 复制代码
    # 跳过校验,加速构造
    tf.RaggedTensor.from_row_lengths(values=[3,1,4,1,5,9,2], row_lengths=[4,0,2,1], validate=False)

核心总结

构造方法 优点 缺点 适用场景
tf.ragged.constant 写法极简,直观 依赖现成的嵌套列表 手动整理的嵌套数据(句子、段落)
工厂方法(from_*) 灵活,适配扁平数据源 需要定义行分区规则 扁平存储的数据(文件/数据库读取)

两种方法最终都能构造出 RaggedTensor,核心是根据「数据的原始形态」选择:有嵌套列表用 constant,有扁平值+行规则用工厂方法。

相关推荐
一次旅行4 小时前
HyperTool:突破传统工具调用限制,让Agent更高效执行复杂任务
人工智能
陈天伟教授4 小时前
图解人工智能(58)人工智能应用-围棋国手
人工智能·语音识别·机器翻译
闻道参看5 小时前
2026年AI优质企业培训系统综合测评:合规管控/数据量化
人工智能
老虾头5 小时前
科技贴近烟火:本地化 AI,赋能各行各业日常经营
人工智能
毒爪的小新5 小时前
Linux 环境极速部署 vLLM:从零搭建生产级大模型推理服务
linux·人工智能·ai·语言模型·vllm
老大白菜5 小时前
25美元,DIY开源可穿戴智能AI眼镜:Arduino+乐鑫ESP32+DeepSeek项目
人工智能
岁月宁静6 小时前
RAG 文档摄入全链路,从原理到生产落地
vue.js·人工智能·python
小和尚同志6 小时前
AI 自动化测试探索(一):Playwright MCP
前端·人工智能·aigc
硅谷秋水6 小时前
面向长上下文自动驾驶的规划对齐Token压缩
人工智能·深度学习·机器学习·计算机视觉·自动驾驶