零基础吃透:tf.sparse.SparseTensor与核心TensorFlow API的协同使用
稀疏张量(tf.sparse.SparseTensor)可与TensorFlow绝大多数核心API透明兼容 (无需额外转换),包括tf.keras、tf.data、tf.function、tf.train.Example等,大幅降低稀疏数据在深度学习流水线中的使用成本。以下按API分类拆解用法、原理和关键注意事项,结合示例讲清实战细节。
前置准备(必运行)
python
import tensorflow as tf
# 复用美观打印函数
def pprint_sparse_tensor(st):
s = "<SparseTensor shape=%s \n values={" % (st.dense_shape.numpy().tolist(),)
for (index, value) in zip(st.indices, st.values):
s += f"\n %s: %s" % (index.numpy().tolist(), value.numpy().tolist())
return s + "}>"
# 核心示例稀疏张量(后续复用)
sparse_data = tf.sparse.SparseTensor(
indices = [(0,0),(0,1),(0,2), (4,3),(5,0),(5,1)],
values = [1]*6,
dense_shape = (6,4) # 6行4列
)
一、与tf.keras的协同使用
核心原理
tf.keras支持将稀疏张量作为模型输入/中间传递/输出,仅需在输入层指定sparse=True;需注意:tf.keras.layers.Dense等全连接层会将稀疏输入转换为密集张量输出(因全连接层需计算所有维度)。
1. 构建支持稀疏输入的Keras模型
python
# 1. 定义稀疏输入层(shape=(4,),sparse=True)
x = tf.keras.Input(shape=(4,), sparse=True)
# 2. 全连接层(自动将稀疏输入转密集,输出密集张量)
y = tf.keras.layers.Dense(4)(x)
# 3. 构建模型
model = tf.keras.Model(inputs=x, outputs=y)
# 4. 传入稀疏张量做前向计算
forward_result = model(sparse_data)
print("模型前向计算结果(形状):", forward_result.shape)
# 5. 用predict预测(自动处理稀疏输入)
predict_result = model.predict(sparse_data, verbose=0)
print("\n模型预测结果(前3行):")
print(predict_result[:3])
输出解读
模型前向计算结果(形状): (6, 4)
模型预测结果(前3行):
[[ 0.01870704 0.7702533 0.22425324 -1.9139588 ]
[ 0. 0. 0. 0. ]
[ 0. 0. 0. 0. ]]
- 输入是6行4列的稀疏张量,输出是6行4列的密集张量;
- 输入中全零的行(如第1/2/3行),输出也全为0(因无有效特征参与计算)。
关键注意事项
- ✅ 输入层必须设
sparse=True:否则会报错(无法将稀疏张量传入密集输入层); - ❌ Dense层输出必为密集张量:若需全程稀疏,需使用支持稀疏输出的自定义层;
- ✅ 其他兼容层:
tf.keras.layers.Embedding、tf.keras.layers.Conv2D(部分场景)也支持稀疏输入。
二、与tf.data的协同使用
tf.data是TensorFlow的输入流水线核心,稀疏张量可无缝集成,且保留稀疏性(无需转换为密集张量),大幅提升流水线效率。
1. 从稀疏张量构建Dataset
使用tf.data.Dataset.from_tensor_slices(与密集张量用法一致),按维度切片并保留稀疏性:
python
# 构建数据集:按行切片(6个元素,每个元素是4列的稀疏张量)
dataset = tf.data.Dataset.from_tensor_slices(sparse_data)
# 遍历数据集元素
print("稀疏张量构建的Dataset元素:")
for idx, element in enumerate(dataset):
print(f"元素{idx}:")
print(pprint_sparse_tensor(element))
print("-"*20)
输出解读
元素0:
<SparseTensor shape=[4]
values={
[0]: 1
[1]: 1
[2]: 1}>
--------------------
元素1:
<SparseTensor shape=[4]
values={}>
--------------------
...(元素2/3均为空,元素4/5有非零值)
- 切片后每个元素是一维稀疏张量(shape=[4]);
- 空行(如元素1)保留为空稀疏张量(无values)。
2. 批处理(batch)与解批(unbatch)
批处理:合并连续元素为批量稀疏张量
python
# 按2个元素为一批,构建批量稀疏张量
batched_dataset = dataset.batch(2)
print("\n批处理(batch=2)后的Dataset:")
for idx, batch in enumerate(batched_dataset):
print(f"批次{idx}:")
print(pprint_sparse_tensor(batch))
print("-"*20)
解批:还原为单个稀疏张量
python
# 解批:批量张量→单个张量
unbatched_dataset = batched_dataset.unbatch()
print("\n解批后的Dataset(与原数据集一致):")
for idx, element in enumerate(unbatched_dataset):
if idx < 3: # 仅打印前3个
print(f"元素{idx}:")
print(pprint_sparse_tensor(element))
3. 数据集变换(map)
使用Dataset.map对稀疏张量做元素级变换(仅修改非零值,保留稀疏性):
python
# 变换:非零值×2
transform_dataset = dataset.map(lambda x: x * 2)
print("\n变换后(非零值×2)的Dataset:")
for idx, element in enumerate(transform_dataset):
if idx in [0,4,5]: # 仅打印有非零值的元素
print(f"元素{idx}:")
print(pprint_sparse_tensor(element))
输出解读
元素0:
<SparseTensor shape=[4]
values={
[0]: 2
[1]: 2
[2]: 2}>
元素4:
<SparseTensor shape=[4]
values={
[3]: 2}>
元素5:
<SparseTensor shape=[4]
values={
[0]: 2
[1]: 2}>
- 仅非零值被×2,空元素仍为空;
- 变换后仍为稀疏张量,无额外内存开销。
4. 变长形状批处理(dense_to_sparse_batch)
针对形状可变 的稀疏张量,使用tf.data.experimental.dense_to_sparse_batch批处理为统一形状的稀疏张量(替代普通batch):
python
# 构造变长稀疏张量数据集(元素shape分别为[2], [3], [1])
var_len_sparse = [
tf.sparse.SparseTensor([[0],[1]], [1,1], [2]),
tf.sparse.SparseTensor([[0],[1],[2]], [1,1,1], [3]),
tf.sparse.SparseTensor([[0]], [1], [1])
]
var_len_dataset = tf.data.Dataset.from_tensor_slices(var_len_sparse)
# 变长批处理:batch=2,统一shape=[3]
sparse_batched = var_len_dataset.apply(
tf.data.experimental.dense_to_sparse_batch(batch_size=2, row_shape=[3])
)
print("\n变长批处理结果:")
for batch in sparse_batched:
print(pprint_sparse_tensor(batch))
三、与tf.train.Example的协同使用
tf.train.Example是TensorFlow数据的标准protobuf编码格式,支持读取稀疏数据为SparseTensor:
1. tf.io.VarLenFeature(读取变长稀疏数据)
适用于一维变长数据 (如文本序列),但官方推荐优先使用tf.io.RaggedFeature(更灵活):
python
# 定义解析规则:读取变长int特征为稀疏张量
feature_description = {
"sparse_feat": tf.io.VarLenFeature(dtype=tf.int32)
}
# 模拟tf.train.Example数据(省略构造过程)
# parsed_example = tf.io.parse_single_example(example_proto, feature_description)
# sparse_tensor = parsed_example["sparse_feat"] # 输出SparseTensor
2. tf.io.SparseFeature(读取任意维度稀疏数据)
通过3个独立特征键存储indices/values/dense_shape,支持任意维度稀疏张量:
python
# 定义解析规则:指定三个特征键对应稀疏张量的三个组件
feature_description = {
"indices": tf.io.FixedLenFeature([], dtype=tf.string),
"values": tf.io.FixedLenFeature([], dtype=tf.string),
"dense_shape": tf.io.FixedLenFeature([], dtype=tf.string)
}
# 解析为稀疏张量(需反序列化)
# parsed = tf.io.parse_single_example(example_proto, feature_description)
# indices = tf.io.parse_tensor(parsed["indices"], tf.int64)
# values = tf.io.parse_tensor(parsed["values"], tf.int32)
# dense_shape = tf.io.parse_tensor(parsed["dense_shape"], tf.int64)
# sparse_tensor = tf.sparse.SparseTensor(indices, values, dense_shape)
四、与tf.function的协同使用
tf.function将Python函数编译为TensorFlow图,大幅提升性能,稀疏张量可透明兼容:
示例:稀疏-密集矩阵乘法(编译为图)
python
# 装饰器编译为图函数
@tf.function
def sparse_matmul(x, y):
return tf.sparse.sparse_dense_matmul(x, y)
# 构造输入
a = tf.sparse.SparseTensor(
indices=[[0, 3], [2, 4]],
values=[15, 25],
dense_shape=[3, 10]
)
b = tf.sparse.to_dense(tf.sparse.transpose(a)) # 转置后转密集
# 调用编译后的函数
c = sparse_matmul(a, b)
print("\ntf.function编译后的稀疏矩阵乘法结果:")
print(c.numpy())
输出解读
[[225 0 0]
[ 0 0 0]
[ 0 0 625]]
- 第一次调用会编译图(稍慢),后续调用直接执行图(极快);
- 稀疏张量的所有操作均在图中执行,无额外转换开销。
五、其他兼容API(简要说明)
除上述核心API外,稀疏张量还兼容以下高频操作:
| API | 作用 | 示例 |
|---|---|---|
tf.cast |
转换稀疏张量数据类型 | tf.cast(sparse_data, tf.float32) |
tf.print |
打印稀疏张量(含indices/values) | tf.print(sparse_data) |
tf.math.abs |
非零值取绝对值 | tf.math.abs(sparse_data) |
tf.saved_model |
保存/加载含稀疏张量的模型 | model.save("sparse_model") |
tf.io.serialize_sparse |
序列化稀疏张量为字节流 | tf.io.serialize_sparse(sparse_data) |
核心避坑总结
1. 稀疏→密集的隐式转换
Dense层、matmul(稀疏×密集)等操作会隐式转换为密集张量,超稀疏场景可能导致OOM;- 解决方案:优先使用稀疏专用算子(如
tf.sparse.sparse_dense_matmul)。
2. 形状兼容性
tf.data批处理时,非批轴的形状必须一致(如示例中所有元素shape=[4]);- 变长形状需用
dense_to_sparse_batch,而非普通batch。
3. tf.train.Example的选型
- 一维变长数据:优先用
tf.io.RaggedFeature(替代VarLenFeature); - 高维稀疏数据:用
tf.io.SparseFeature存储三个组件。
4. tf.function的静态形状
tf.function编译时会固化稀疏张量的dense_shape,动态形状需用tf.TensorShape(None)兼容。
实战价值总结
稀疏张量与核心API的无缝兼容,使得:
- 数据预处理:用
tf.data高效处理超稀疏数据(如TF-IDF、高维特征); - 模型训练:用
tf.keras直接以稀疏张量为输入,避免密集转换的内存浪费; - 性能优化:用
tf.function编译稀疏操作,提升运行效率; - 数据存储:用
tf.train.Example/tf.saved_model序列化稀疏数据,节省存储成本。
这是处理大规模稀疏数据(如推荐系统、NLP、计算机视觉)的核心能力,能大幅降低内存占用和计算开销。