零基础吃透:RaggedTensor与其他张量类型的转换
RaggedTensor 提供了原生方法,可与 TensorFlow 另外两种核心张量类型(密集张量 tf.Tensor、稀疏张量 tf.SparseTensor)双向转换,覆盖"补0/去填充""稀疏存储/可变长度处理"等核心场景。以下按「转换方向」拆解每个方法的用法、原理和适用场景。
前置准备(可运行代码)
python
import tensorflow as tf
# 基础RaggedTensor示例(文本序列)
ragged_sentences = tf.ragged.constant([
['Hi'], # 行0:1个元素
['Welcome', 'to', 'the', 'fair'], # 行1:4个元素
['Have', 'fun']]) # 行2:2个元素
转换1:RaggedTensor → 密集张量(tf.Tensor)
核心方法:ragged_tensor.to_tensor(default_value, shape=None)
将不规则张量补全为固定形状的密集张量,不足长度的位置填充指定默认值。
| 参数 | 作用 |
|---|---|
default_value |
必选:填充值(如空字符串''、0、-1等),需与张量元素类型匹配 |
shape |
可选:目标密集张量的形状(如[None, 10]表示行数可变、列数固定10);不指定则补到「最大行长度」 |
示例代码
python
# 转换为密集张量:补空字符串,列数固定为10
dense_sentences = ragged_sentences.to_tensor(default_value='', shape=[None, 10])
print("Ragged → 密集张量(shape=[None,10]):")
print(dense_sentences)
# 不指定shape:自动补到最大行长度(4列)
dense_sentences_auto = ragged_sentences.to_tensor(default_value='')
print("\nRagged → 密集张量(自动补最大长度):")
print(dense_sentences_auto)
运行结果
Ragged → 密集张量(shape=[None,10]):
tf.Tensor(
[[b'Hi' b'' b'' b'' b'' b'' b'' b'' b'' b'']
[b'Welcome' b'to' b'the' b'fair' b'' b'' b'' b'' b'' b'']
[b'Have' b'fun' b'' b'' b'' b'' b'' b'' b'' b'']], shape=(3, 10), dtype=string)
Ragged → 密集张量(自动补最大长度):
tf.Tensor(
[[b'Hi' b'' b'' b'']
[b'Welcome' b'to' b'the' b'fair']
[b'Have' b'fun' b'' b'']], shape=(3, 4), dtype=string)
核心解读
- 补全逻辑:每行按目标长度填充
default_value,比如行0仅1个元素,补9个空字符串到10列; - 适用场景:需将RaggedTensor传入不支持不规则张量的层/函数(如老版本LSTM、部分第三方库)。
转换2:密集张量(tf.Tensor)→ RaggedTensor
核心方法:tf.RaggedTensor.from_tensor(tensor, padding)
从含填充值的密集张量中剔除填充值,恢复为可变长度的RaggedTensor(仅保留有效元素)。
| 参数 | 作用 |
|---|---|
tensor |
必选:含填充值的密集张量(如补0/-1的序列) |
padding |
必选:填充值(如-1、0),张量中该值会被剔除,仅保留非填充值 |
示例代码
python
# 含填充值的密集张量(-1为填充)
dense_x = tf.constant([[1, 3, -1, -1], [2, -1, -1, -1], [4, 5, 8, 9]])
# 转换为RaggedTensor:剔除-1
ragged_x = tf.RaggedTensor.from_tensor(dense_x, padding=-1)
print("密集张量 → Ragged(剔除填充值-1):")
print(ragged_x)
运行结果
<tf.RaggedTensor [[1, 3], [2], [4, 5, 8, 9]]>
核心解读
- 去填充逻辑:遍历每行,删除所有
padding指定的值,仅保留有效元素; - 适用场景:从"补0/补-1的密集序列"恢复原始可变长度结构(如文本预处理后去填充)。
转换3:RaggedTensor → 稀疏张量(tf.SparseTensor)
核心方法:ragged_tensor.to_sparse()
将不规则张量转换为稀疏张量,仅存储非空元素的坐标+值,节省内存(无填充值存储)。
示例代码
python
# Ragged → SparseTensor
sparse_sentences = ragged_sentences.to_sparse()
print("Ragged → SparseTensor:")
print("indices(非空元素坐标):", sparse_sentences.indices)
print("values(非空元素值):", sparse_sentences.values)
print("dense_shape(对应密集张量形状):", sparse_sentences.dense_shape)
运行结果
indices(非空元素坐标): tf.Tensor(
[[0 0]
[1 0]
[1 1]
[1 2]
[1 3]
[2 0]
[2 1]], shape=(7, 2), dtype=int64)
values(非空元素值): tf.Tensor([b'Hi' b'Welcome' b'to' b'the' b'fair' b'Have' b'fun'], shape=(7,), dtype=string)
dense_shape(对应密集张量形状): tf.Tensor([3 4], shape=(2,), dtype=int64)
核心解读
- 转换逻辑:
indices:所有非空元素的「行+列」坐标(如[0,0]对应行0列0的Hi);values:按坐标顺序排列的非空元素值;dense_shape:对应密集张量的形状(行数=3,最大列数=4);
- 适用场景:存储/传输大规模可变长度数据(SparseTensor仅存有效元素,比密集张量省内存)。
转换4:稀疏张量(tf.SparseTensor)→ RaggedTensor
核心方法:tf.RaggedTensor.from_sparse(sparse_tensor)
将稀疏张量按行整理为可变长度的RaggedTensor,空行保留为空列表。
示例代码
python
# 定义SparseTensor(3行3列,非空元素:[0,0]='a'、[2,0]='b'、[2,1]='c')
st = tf.SparseTensor(
indices=[[0, 0], [2, 0], [2, 1]], # 非空元素坐标
values=['a', 'b', 'c'], # 非空元素值
dense_shape=[3, 3] # 对应密集张量形状
)
# Sparse → RaggedTensor
ragged_from_sparse = tf.RaggedTensor.from_sparse(st)
print("Sparse → RaggedTensor:")
print(ragged_from_sparse)
运行结果
<tf.RaggedTensor [[b'a'], [], [b'b', b'c']]>
核心解读
- 转换逻辑:
- 按行分组SparseTensor的非空元素,每行的元素按列索引升序排列;
- 无元素的行(如行1)保留为空列表;
- 适用场景:将稀疏存储的可变长度数据转换为RaggedTensor,利用其更友好的API(如切片、运算符、Keras适配)。
核心注意事项(避坑)
1. 类型匹配
- 转换时
default_value/padding的类型必须与张量元素类型一致(如字符串张量不能用0填充); - SparseTensor转RaggedTensor时,
indices需按行/列索引有序(TF会自动排序,无需手动处理)。
2. 形状约束
to_tensor(shape=...)指定的形状不能小于RaggedTensor的最大长度(如最大列数4,不能指定shape=[None,3],会报错);from_tensor仅支持"末尾维度有填充值"的密集张量(如二维张量仅列维度补填充,行维度需固定)。
3. 空值处理
- RaggedTensor的空行(
[])转换为密集张量时,会填充default_value; - SparseTensor的空行转换为RaggedTensor时,保留为空列表(无填充)。
转换场景总结
| 转换方向 | 核心方法 | 适用场景 |
|---|---|---|
| Ragged → Tensor | to_tensor(default_value, shape) |
适配不支持Ragged的层/函数(如老版LSTM) |
| Tensor → Ragged | from_tensor(tensor, padding) |
去除密集张量的填充值,恢复可变长度 |
| Ragged → Sparse | to_sparse() |
大规模数据存储/传输(节省内存) |
| Sparse → Ragged | from_sparse(sparse_tensor) |
利用Ragged的友好API(切片、运算符、Keras) |
这四类转换覆盖了RaggedTensor在数据预处理、模型输入、存储传输等全流程的适配需求,是处理可变长度数据的核心工具。