TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式

这份文档核心讲解了 TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式 :一种是「极简直接法」(tf.ragged.constant),另一种是「灵活定制法」(工厂方法,从扁平值+行分区规则构造)。以下是通俗化拆解,全程贴合原文例子和逻辑:

一、核心前提

构造 RaggedTensor 的核心是「描述可变长度的嵌套结构」:要么直接给嵌套列表(让 TF 自动识别长度),要么先给扁平值+「行分区规则」(告诉 TF 如何把扁平值拆成可变长度的行)。

二、方法1:最简方式 → tf.ragged.constant

核心逻辑

直接把「嵌套的 Python 列表/NumPy 数组」传给 tf.ragged.constant,TF 会自动识别每一层的可变长度,生成对应的 RaggedTensor(无需手动定义长度规则)。

适用场景

已有现成的嵌套列表/数组,且能直接看出行/层的长度分布(比如手动整理的句子、段落数据)。

原文例子拆解
示例1:二维 RaggedTensor(句子,每行是单词列表)
python 复制代码
# 输入:2个句子(第一句6个单词,第二句5个单词,长度可变)
sentences = tf.ragged.constant([
    ["Let's", "build", "some", "ragged", "tensors", "!"],
    ["We", "can", "use", "tf.ragged.constant", "."]])
print(sentences)
# 输出:二维RaggedTensor,保留每行的可变长度
<tf.RaggedTensor [[b"Let's", b'build', b'some', b'ragged', b'tensors', b'!'],
 [b'We', b'can', b'use', b'tf.ragged.constant', b'.']]>
  • 关键:字符串会自动转成字节型(b前缀),不影响使用;每行长度可不同(6个 vs 5个)。
示例2:三维 RaggedTensor(段落,每段包含多个句子,句子长度可变)
python 复制代码
# 输入:2个段落 → 每个段落包含2个句子 → 句子长度可变
paragraphs = tf.ragged.constant([
    [['I', 'have', 'a', 'cat'], ['His', 'name', 'is', 'Mat']],  # 段落1:2个句子(各4个单词)
    [['Do', 'you', 'want', 'to', 'come', 'visit'], ["I'm", 'free', 'tomorrow']],  # 段落2:句子1(6词)、句子2(3词)
])
print(paragraphs)
# 输出:三维RaggedTensor,每层长度都可变
<tf.RaggedTensor [[[b'I', b'have', b'a', b'cat'], [b'His', b'name', b'is', b'Mat']],
 [[b'Do', b'you', b'want', b'to', b'come', b'visit'],
  [b"I'm", b'free', b'tomorrow']]]>

三、方法2:灵活定制 → 工厂方法(扁平值+行分区张量)

核心逻辑

当数据是「扁平的一维列表」,但知道「每个值属于哪一行/每行长度/每行起止索引」时,用工厂方法把扁平值拆成可变长度的 RaggedTensor。

适用场景

数据是扁平存储的(比如从文件/数据库读取的一维值列表),但有额外的"行分区规则"(比如记录了每个值所属的行号、每行的长度)。

原文3种工厂方法拆解(目标都是构造 [[3,1,4,1], [], [5,9], [2]]

所有方法的输入核心:values(扁平值列表) + 「行分区规则」,以下是一一对应解释:

工厂方法 行分区规则(参数) 原文例子拆解
from_value_rowids value_rowids(每个值的行号) values=[3,1,4,1,5,9,2] value_rowids=[0,0,0,0,2,2,3] → 3/1/4/1属于行0,5/9属于行2,2属于行3,行1无值→空行
from_row_lengths row_lengths(每行的长度) values=[3,1,4,1,5,9,2] row_lengths=[4,0,2,1] → 行0长度4(取前4个值),行1长度0(空),行2长度2(取5/9),行3长度1(取2)
from_row_splits row_splits(每行起止索引) values=[3,1,4,1,5,9,2] row_splits=[0,4,4,6,7] → 行0:04(3/1/4/1),行1:44(空),行2:46(5/9),行3:67(2)
原文代码验证(3种方法结果完全一致)
python 复制代码
# 方法1:from_value_rowids
print(tf.RaggedTensor.from_value_rowids(
    values=[3, 1, 4, 1, 5, 9, 2],
    value_rowids=[0, 0, 0, 0, 2, 2, 3]))
# 输出:<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>

# 方法2:from_row_lengths
print(tf.RaggedTensor.from_row_lengths(
    values=[3, 1, 4, 1, 5, 9, 2],
    row_lengths=[4, 0, 2, 1]))
# 输出同上

# 方法3:from_row_splits
print(tf.RaggedTensor.from_row_splits(
    values=[3, 1, 4, 1, 5, 9, 2],
    row_splits=[0, 4, 4, 6, 7]))
# 输出同上

四、关键注意事项(原文补充)

  1. 校验机制:默认情况下,工厂方法会校验「行分区规则」和「values长度」是否匹配(比如 row_lengths 总和是否等于 values 长度),避免构造出非法的 RaggedTensor;

  2. 性能优化 :如果能100%保证输入的行分区规则合法(比如提前校验过),可以加 validate=False 跳过校验,提升构造速度:

    python 复制代码
    # 跳过校验,加速构造
    tf.RaggedTensor.from_row_lengths(values=[3,1,4,1,5,9,2], row_lengths=[4,0,2,1], validate=False)

核心总结

构造方法 优点 缺点 适用场景
tf.ragged.constant 写法极简,直观 依赖现成的嵌套列表 手动整理的嵌套数据(句子、段落)
工厂方法(from_*) 灵活,适配扁平数据源 需要定义行分区规则 扁平存储的数据(文件/数据库读取)

两种方法最终都能构造出 RaggedTensor,核心是根据「数据的原始形态」选择:有嵌套列表用 constant,有扁平值+行规则用工厂方法。

相关推荐
霖大侠1 小时前
VISION TRANSFORMER ADAPTER FOR DENSE PREDICTIONS
人工智能·深度学习·transformer
青稞社区.1 小时前
VLA 的强化学习后训练框架π_RL详解
人工智能
CNRio1 小时前
数字经济健康发展的双维路径:技术伦理与产业价值的重构
大数据·人工智能·重构
Hello娃的1 小时前
【神经网络】反向传播BP算法
人工智能·神经网络·算法
IT_陈寒1 小时前
Java并发编程避坑指南:这5个隐藏陷阱让你的性能暴跌50%!
前端·人工智能·后端
桃子叔叔1 小时前
CoOp:Visual-Language Model从静态模板到动态学习新范式
人工智能·学习·语言模型
测试人社区—小叶子1 小时前
边缘计算与AI:下一代智能应用的核心架构
运维·网络·人工智能·python·架构·边缘计算
非著名架构师1 小时前
破解“AI幻觉”,锁定真实风险:专业气象模型如何为企业提供可信的极端天气决策依据?
人工智能·深度学习·机器学习·数据分析·风光功率预测·高精度气象数据·高精度天气预报数据
jinxinyuuuus1 小时前
快手在线去水印:短链解析、API逆向与视频流的元数据重构
前端·人工智能·算法·重构
忆~遂愿1 小时前
昇腾 Triton-Ascend 开源实战:架构解析、环境搭建与配置速查
人工智能·python·深度学习·机器学习·自然语言处理