TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式

这份文档核心讲解了 TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式 :一种是「极简直接法」(tf.ragged.constant),另一种是「灵活定制法」(工厂方法,从扁平值+行分区规则构造)。以下是通俗化拆解,全程贴合原文例子和逻辑:

一、核心前提

构造 RaggedTensor 的核心是「描述可变长度的嵌套结构」:要么直接给嵌套列表(让 TF 自动识别长度),要么先给扁平值+「行分区规则」(告诉 TF 如何把扁平值拆成可变长度的行)。

二、方法1:最简方式 → tf.ragged.constant

核心逻辑

直接把「嵌套的 Python 列表/NumPy 数组」传给 tf.ragged.constant,TF 会自动识别每一层的可变长度,生成对应的 RaggedTensor(无需手动定义长度规则)。

适用场景

已有现成的嵌套列表/数组,且能直接看出行/层的长度分布(比如手动整理的句子、段落数据)。

原文例子拆解
示例1:二维 RaggedTensor(句子,每行是单词列表)
python 复制代码
# 输入:2个句子(第一句6个单词,第二句5个单词,长度可变)
sentences = tf.ragged.constant([
    ["Let's", "build", "some", "ragged", "tensors", "!"],
    ["We", "can", "use", "tf.ragged.constant", "."]])
print(sentences)
# 输出:二维RaggedTensor,保留每行的可变长度
<tf.RaggedTensor [[b"Let's", b'build', b'some', b'ragged', b'tensors', b'!'],
 [b'We', b'can', b'use', b'tf.ragged.constant', b'.']]>
  • 关键:字符串会自动转成字节型(b前缀),不影响使用;每行长度可不同(6个 vs 5个)。
示例2:三维 RaggedTensor(段落,每段包含多个句子,句子长度可变)
python 复制代码
# 输入:2个段落 → 每个段落包含2个句子 → 句子长度可变
paragraphs = tf.ragged.constant([
    [['I', 'have', 'a', 'cat'], ['His', 'name', 'is', 'Mat']],  # 段落1:2个句子(各4个单词)
    [['Do', 'you', 'want', 'to', 'come', 'visit'], ["I'm", 'free', 'tomorrow']],  # 段落2:句子1(6词)、句子2(3词)
])
print(paragraphs)
# 输出:三维RaggedTensor,每层长度都可变
<tf.RaggedTensor [[[b'I', b'have', b'a', b'cat'], [b'His', b'name', b'is', b'Mat']],
 [[b'Do', b'you', b'want', b'to', b'come', b'visit'],
  [b"I'm", b'free', b'tomorrow']]]>

三、方法2:灵活定制 → 工厂方法(扁平值+行分区张量)

核心逻辑

当数据是「扁平的一维列表」,但知道「每个值属于哪一行/每行长度/每行起止索引」时,用工厂方法把扁平值拆成可变长度的 RaggedTensor。

适用场景

数据是扁平存储的(比如从文件/数据库读取的一维值列表),但有额外的"行分区规则"(比如记录了每个值所属的行号、每行的长度)。

原文3种工厂方法拆解(目标都是构造 [[3,1,4,1], [], [5,9], [2]]

所有方法的输入核心:values(扁平值列表) + 「行分区规则」,以下是一一对应解释:

工厂方法 行分区规则(参数) 原文例子拆解
from_value_rowids value_rowids(每个值的行号) values=[3,1,4,1,5,9,2] value_rowids=[0,0,0,0,2,2,3] → 3/1/4/1属于行0,5/9属于行2,2属于行3,行1无值→空行
from_row_lengths row_lengths(每行的长度) values=[3,1,4,1,5,9,2] row_lengths=[4,0,2,1] → 行0长度4(取前4个值),行1长度0(空),行2长度2(取5/9),行3长度1(取2)
from_row_splits row_splits(每行起止索引) values=[3,1,4,1,5,9,2] row_splits=[0,4,4,6,7] → 行0:04(3/1/4/1),行1:44(空),行2:46(5/9),行3:67(2)
原文代码验证(3种方法结果完全一致)
python 复制代码
# 方法1:from_value_rowids
print(tf.RaggedTensor.from_value_rowids(
    values=[3, 1, 4, 1, 5, 9, 2],
    value_rowids=[0, 0, 0, 0, 2, 2, 3]))
# 输出:<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>

# 方法2:from_row_lengths
print(tf.RaggedTensor.from_row_lengths(
    values=[3, 1, 4, 1, 5, 9, 2],
    row_lengths=[4, 0, 2, 1]))
# 输出同上

# 方法3:from_row_splits
print(tf.RaggedTensor.from_row_splits(
    values=[3, 1, 4, 1, 5, 9, 2],
    row_splits=[0, 4, 4, 6, 7]))
# 输出同上

四、关键注意事项(原文补充)

  1. 校验机制:默认情况下,工厂方法会校验「行分区规则」和「values长度」是否匹配(比如 row_lengths 总和是否等于 values 长度),避免构造出非法的 RaggedTensor;

  2. 性能优化 :如果能100%保证输入的行分区规则合法(比如提前校验过),可以加 validate=False 跳过校验,提升构造速度:

    python 复制代码
    # 跳过校验,加速构造
    tf.RaggedTensor.from_row_lengths(values=[3,1,4,1,5,9,2], row_lengths=[4,0,2,1], validate=False)

核心总结

构造方法 优点 缺点 适用场景
tf.ragged.constant 写法极简,直观 依赖现成的嵌套列表 手动整理的嵌套数据(句子、段落)
工厂方法(from_*) 灵活,适配扁平数据源 需要定义行分区规则 扁平存储的数据(文件/数据库读取)

两种方法最终都能构造出 RaggedTensor,核心是根据「数据的原始形态」选择:有嵌套列表用 constant,有扁平值+行规则用工厂方法。

相关推荐
无代码专家18 小时前
低代码构建数据管理系统:选型逻辑与实践路径
人工智能·低代码
无代码专家18 小时前
低代码搭建项目管理平台:易用性导向的实践方案
人工智能·低代码
KKKlucifer18 小时前
AI赋能与全栈适配:安全运维新范式的演进与实践
人工智能·安全
许泽宇的技术分享18 小时前
当AI学会拍短剧:Huobao Drama全栈AI短剧生成平台深度解析
人工智能
爱喝可乐的老王18 小时前
机器学习监督学习模型--线性回归
人工智能·机器学习·线性回归
金融Tech趋势派18 小时前
2025企业微信私有化部署优秀服务商:微盛·企微管家方案解析
人工智能·企业微信·scrm
Gofarlic_oms118 小时前
跨国企业Cadence许可证全球统一管理方案
java·大数据·网络·人工智能·汽车
AAD5558889918 小时前
牛肝菌目标检测:基于YOLOv8-CFPT-P2345模型的创新实现与应用_1
人工智能·yolo·目标检测
幂链iPaaS18 小时前
制造业/零售电商ERP和MES系统集成指南
大数据·人工智能
gorgeous(๑>؂<๑)18 小时前
【中国科学院光电研究所-张建林组-AAAI26】追踪不稳定目标:基于外观引导的运动建模在无人机拍摄视频中实现稳健的多目标跟踪
人工智能·机器学习·计算机视觉·目标跟踪·无人机