这份文档核心讲解了 TensorFlow 不规则张量(RaggedTensor)的两种核心构造方式 :一种是「极简直接法」(tf.ragged.constant),另一种是「灵活定制法」(工厂方法,从扁平值+行分区规则构造)。以下是通俗化拆解,全程贴合原文例子和逻辑:
一、核心前提
构造 RaggedTensor 的核心是「描述可变长度的嵌套结构」:要么直接给嵌套列表(让 TF 自动识别长度),要么先给扁平值+「行分区规则」(告诉 TF 如何把扁平值拆成可变长度的行)。
二、方法1:最简方式 → tf.ragged.constant
核心逻辑
直接把「嵌套的 Python 列表/NumPy 数组」传给 tf.ragged.constant,TF 会自动识别每一层的可变长度,生成对应的 RaggedTensor(无需手动定义长度规则)。
适用场景
已有现成的嵌套列表/数组,且能直接看出行/层的长度分布(比如手动整理的句子、段落数据)。
原文例子拆解
示例1:二维 RaggedTensor(句子,每行是单词列表)
python
# 输入:2个句子(第一句6个单词,第二句5个单词,长度可变)
sentences = tf.ragged.constant([
["Let's", "build", "some", "ragged", "tensors", "!"],
["We", "can", "use", "tf.ragged.constant", "."]])
print(sentences)
# 输出:二维RaggedTensor,保留每行的可变长度
<tf.RaggedTensor [[b"Let's", b'build', b'some', b'ragged', b'tensors', b'!'],
[b'We', b'can', b'use', b'tf.ragged.constant', b'.']]>
- 关键:字符串会自动转成字节型(b前缀),不影响使用;每行长度可不同(6个 vs 5个)。
示例2:三维 RaggedTensor(段落,每段包含多个句子,句子长度可变)
python
# 输入:2个段落 → 每个段落包含2个句子 → 句子长度可变
paragraphs = tf.ragged.constant([
[['I', 'have', 'a', 'cat'], ['His', 'name', 'is', 'Mat']], # 段落1:2个句子(各4个单词)
[['Do', 'you', 'want', 'to', 'come', 'visit'], ["I'm", 'free', 'tomorrow']], # 段落2:句子1(6词)、句子2(3词)
])
print(paragraphs)
# 输出:三维RaggedTensor,每层长度都可变
<tf.RaggedTensor [[[b'I', b'have', b'a', b'cat'], [b'His', b'name', b'is', b'Mat']],
[[b'Do', b'you', b'want', b'to', b'come', b'visit'],
[b"I'm", b'free', b'tomorrow']]]>
三、方法2:灵活定制 → 工厂方法(扁平值+行分区张量)
核心逻辑
当数据是「扁平的一维列表」,但知道「每个值属于哪一行/每行长度/每行起止索引」时,用工厂方法把扁平值拆成可变长度的 RaggedTensor。
适用场景
数据是扁平存储的(比如从文件/数据库读取的一维值列表),但有额外的"行分区规则"(比如记录了每个值所属的行号、每行的长度)。
原文3种工厂方法拆解(目标都是构造 [[3,1,4,1], [], [5,9], [2]])
所有方法的输入核心:values(扁平值列表) + 「行分区规则」,以下是一一对应解释:
| 工厂方法 | 行分区规则(参数) | 原文例子拆解 |
|---|---|---|
| from_value_rowids | value_rowids(每个值的行号) | values=[3,1,4,1,5,9,2] value_rowids=[0,0,0,0,2,2,3] → 3/1/4/1属于行0,5/9属于行2,2属于行3,行1无值→空行 |
| from_row_lengths | row_lengths(每行的长度) | values=[3,1,4,1,5,9,2] row_lengths=[4,0,2,1] → 行0长度4(取前4个值),行1长度0(空),行2长度2(取5/9),行3长度1(取2) |
| from_row_splits | row_splits(每行起止索引) | values=[3,1,4,1,5,9,2] row_splits=[0,4,4,6,7] → 行0:04(3/1/4/1),行1:44(空),行2:46(5/9),行3:67(2) |
原文代码验证(3种方法结果完全一致)
python
# 方法1:from_value_rowids
print(tf.RaggedTensor.from_value_rowids(
values=[3, 1, 4, 1, 5, 9, 2],
value_rowids=[0, 0, 0, 0, 2, 2, 3]))
# 输出:<tf.RaggedTensor [[3, 1, 4, 1], [], [5, 9], [2]]>
# 方法2:from_row_lengths
print(tf.RaggedTensor.from_row_lengths(
values=[3, 1, 4, 1, 5, 9, 2],
row_lengths=[4, 0, 2, 1]))
# 输出同上
# 方法3:from_row_splits
print(tf.RaggedTensor.from_row_splits(
values=[3, 1, 4, 1, 5, 9, 2],
row_splits=[0, 4, 4, 6, 7]))
# 输出同上
四、关键注意事项(原文补充)
-
校验机制:默认情况下,工厂方法会校验「行分区规则」和「values长度」是否匹配(比如 row_lengths 总和是否等于 values 长度),避免构造出非法的 RaggedTensor;
-
性能优化 :如果能100%保证输入的行分区规则合法(比如提前校验过),可以加
validate=False跳过校验,提升构造速度:python# 跳过校验,加速构造 tf.RaggedTensor.from_row_lengths(values=[3,1,4,1,5,9,2], row_lengths=[4,0,2,1], validate=False)
核心总结
| 构造方法 | 优点 | 缺点 | 适用场景 |
|---|---|---|---|
| tf.ragged.constant | 写法极简,直观 | 依赖现成的嵌套列表 | 手动整理的嵌套数据(句子、段落) |
| 工厂方法(from_*) | 灵活,适配扁平数据源 | 需要定义行分区规则 | 扁平存储的数据(文件/数据库读取) |
两种方法最终都能构造出 RaggedTensor,核心是根据「数据的原始形态」选择:有嵌套列表用 constant,有扁平值+行规则用工厂方法。