tensorflow 零基础吃透:tf.sparse.SparseTensor 与核心 TensorFlow API 的协同使用

零基础吃透:tf.sparse.SparseTensor与核心TensorFlow API的协同使用

稀疏张量(tf.sparse.SparseTensor)可与TensorFlow绝大多数核心API透明兼容 (无需额外转换),包括tf.kerastf.datatf.functiontf.train.Example等,大幅降低稀疏数据在深度学习流水线中的使用成本。以下按API分类拆解用法、原理和关键注意事项,结合示例讲清实战细节。

前置准备(必运行)

python 复制代码
import tensorflow as tf

# 复用美观打印函数
def pprint_sparse_tensor(st):
  s = "<SparseTensor shape=%s \n values={" % (st.dense_shape.numpy().tolist(),)
  for (index, value) in zip(st.indices, st.values):
    s += f"\n  %s: %s" % (index.numpy().tolist(), value.numpy().tolist())
  return s + "}>"

# 核心示例稀疏张量(后续复用)
sparse_data = tf.sparse.SparseTensor(
    indices = [(0,0),(0,1),(0,2), (4,3),(5,0),(5,1)],
    values = [1]*6,
    dense_shape = (6,4)  # 6行4列
)

一、与tf.keras的协同使用

核心原理

tf.keras支持将稀疏张量作为模型输入/中间传递/输出,仅需在输入层指定sparse=True;需注意:tf.keras.layers.Dense等全连接层会将稀疏输入转换为密集张量输出(因全连接层需计算所有维度)。

1. 构建支持稀疏输入的Keras模型

python 复制代码
# 1. 定义稀疏输入层(shape=(4,),sparse=True)
x = tf.keras.Input(shape=(4,), sparse=True)
# 2. 全连接层(自动将稀疏输入转密集,输出密集张量)
y = tf.keras.layers.Dense(4)(x)
# 3. 构建模型
model = tf.keras.Model(inputs=x, outputs=y)

# 4. 传入稀疏张量做前向计算
forward_result = model(sparse_data)
print("模型前向计算结果(形状):", forward_result.shape)

# 5. 用predict预测(自动处理稀疏输入)
predict_result = model.predict(sparse_data, verbose=0)
print("\n模型预测结果(前3行):")
print(predict_result[:3])

输出解读

复制代码
模型前向计算结果(形状): (6, 4)

模型预测结果(前3行):
[[ 0.01870704  0.7702533   0.22425324 -1.9139588 ]
 [ 0.          0.          0.          0.        ]
 [ 0.          0.          0.          0.        ]]
  • 输入是6行4列的稀疏张量,输出是6行4列的密集张量;
  • 输入中全零的行(如第1/2/3行),输出也全为0(因无有效特征参与计算)。

关键注意事项

  • ✅ 输入层必须设sparse=True:否则会报错(无法将稀疏张量传入密集输入层);
  • ❌ Dense层输出必为密集张量:若需全程稀疏,需使用支持稀疏输出的自定义层;
  • ✅ 其他兼容层:tf.keras.layers.Embeddingtf.keras.layers.Conv2D(部分场景)也支持稀疏输入。

二、与tf.data的协同使用

tf.data是TensorFlow的输入流水线核心,稀疏张量可无缝集成,且保留稀疏性(无需转换为密集张量),大幅提升流水线效率。

1. 从稀疏张量构建Dataset

使用tf.data.Dataset.from_tensor_slices(与密集张量用法一致),按维度切片并保留稀疏性:

python 复制代码
# 构建数据集:按行切片(6个元素,每个元素是4列的稀疏张量)
dataset = tf.data.Dataset.from_tensor_slices(sparse_data)

# 遍历数据集元素
print("稀疏张量构建的Dataset元素:")
for idx, element in enumerate(dataset):
  print(f"元素{idx}:")
  print(pprint_sparse_tensor(element))
  print("-"*20)

输出解读

复制代码
元素0:
<SparseTensor shape=[4] 
 values={
  [0]: 1
  [1]: 1
  [2]: 1}>
--------------------
元素1:
<SparseTensor shape=[4] 
 values={}>
--------------------
...(元素2/3均为空,元素4/5有非零值)
  • 切片后每个元素是一维稀疏张量(shape=[4]);
  • 空行(如元素1)保留为空稀疏张量(无values)。

2. 批处理(batch)与解批(unbatch)

批处理:合并连续元素为批量稀疏张量
python 复制代码
# 按2个元素为一批,构建批量稀疏张量
batched_dataset = dataset.batch(2)
print("\n批处理(batch=2)后的Dataset:")
for idx, batch in enumerate(batched_dataset):
  print(f"批次{idx}:")
  print(pprint_sparse_tensor(batch))
  print("-"*20)
解批:还原为单个稀疏张量
python 复制代码
# 解批:批量张量→单个张量
unbatched_dataset = batched_dataset.unbatch()
print("\n解批后的Dataset(与原数据集一致):")
for idx, element in enumerate(unbatched_dataset):
  if idx < 3:  # 仅打印前3个
    print(f"元素{idx}:")
    print(pprint_sparse_tensor(element))

3. 数据集变换(map)

使用Dataset.map对稀疏张量做元素级变换(仅修改非零值,保留稀疏性):

python 复制代码
# 变换:非零值×2
transform_dataset = dataset.map(lambda x: x * 2)
print("\n变换后(非零值×2)的Dataset:")
for idx, element in enumerate(transform_dataset):
  if idx in [0,4,5]:  # 仅打印有非零值的元素
    print(f"元素{idx}:")
    print(pprint_sparse_tensor(element))

输出解读

复制代码
元素0:
<SparseTensor shape=[4] 
 values={
  [0]: 2
  [1]: 2
  [2]: 2}>
元素4:
<SparseTensor shape=[4] 
 values={
  [3]: 2}>
元素5:
<SparseTensor shape=[4] 
 values={
  [0]: 2
  [1]: 2}>
  • 仅非零值被×2,空元素仍为空;
  • 变换后仍为稀疏张量,无额外内存开销。

4. 变长形状批处理(dense_to_sparse_batch)

针对形状可变 的稀疏张量,使用tf.data.experimental.dense_to_sparse_batch批处理为统一形状的稀疏张量(替代普通batch):

python 复制代码
# 构造变长稀疏张量数据集(元素shape分别为[2], [3], [1])
var_len_sparse = [
  tf.sparse.SparseTensor([[0],[1]], [1,1], [2]),
  tf.sparse.SparseTensor([[0],[1],[2]], [1,1,1], [3]),
  tf.sparse.SparseTensor([[0]], [1], [1])
]
var_len_dataset = tf.data.Dataset.from_tensor_slices(var_len_sparse)

# 变长批处理:batch=2,统一shape=[3]
sparse_batched = var_len_dataset.apply(
    tf.data.experimental.dense_to_sparse_batch(batch_size=2, row_shape=[3])
)
print("\n变长批处理结果:")
for batch in sparse_batched:
  print(pprint_sparse_tensor(batch))

三、与tf.train.Example的协同使用

tf.train.Example是TensorFlow数据的标准protobuf编码格式,支持读取稀疏数据为SparseTensor

1. tf.io.VarLenFeature(读取变长稀疏数据)

适用于一维变长数据 (如文本序列),但官方推荐优先使用tf.io.RaggedFeature(更灵活):

python 复制代码
# 定义解析规则:读取变长int特征为稀疏张量
feature_description = {
    "sparse_feat": tf.io.VarLenFeature(dtype=tf.int32)
}

# 模拟tf.train.Example数据(省略构造过程)
# parsed_example = tf.io.parse_single_example(example_proto, feature_description)
# sparse_tensor = parsed_example["sparse_feat"]  # 输出SparseTensor

2. tf.io.SparseFeature(读取任意维度稀疏数据)

通过3个独立特征键存储indices/values/dense_shape,支持任意维度稀疏张量:

python 复制代码
# 定义解析规则:指定三个特征键对应稀疏张量的三个组件
feature_description = {
    "indices": tf.io.FixedLenFeature([], dtype=tf.string),
    "values": tf.io.FixedLenFeature([], dtype=tf.string),
    "dense_shape": tf.io.FixedLenFeature([], dtype=tf.string)
}

# 解析为稀疏张量(需反序列化)
# parsed = tf.io.parse_single_example(example_proto, feature_description)
# indices = tf.io.parse_tensor(parsed["indices"], tf.int64)
# values = tf.io.parse_tensor(parsed["values"], tf.int32)
# dense_shape = tf.io.parse_tensor(parsed["dense_shape"], tf.int64)
# sparse_tensor = tf.sparse.SparseTensor(indices, values, dense_shape)

四、与tf.function的协同使用

tf.function将Python函数编译为TensorFlow图,大幅提升性能,稀疏张量可透明兼容:

示例:稀疏-密集矩阵乘法(编译为图)

python 复制代码
# 装饰器编译为图函数
@tf.function
def sparse_matmul(x, y):
  return tf.sparse.sparse_dense_matmul(x, y)

# 构造输入
a = tf.sparse.SparseTensor(
    indices=[[0, 3], [2, 4]],
    values=[15, 25],
    dense_shape=[3, 10]
)
b = tf.sparse.to_dense(tf.sparse.transpose(a))  # 转置后转密集

# 调用编译后的函数
c = sparse_matmul(a, b)
print("\ntf.function编译后的稀疏矩阵乘法结果:")
print(c.numpy())

输出解读

复制代码
[[225   0   0]
 [  0   0   0]
 [  0   0 625]]
  • 第一次调用会编译图(稍慢),后续调用直接执行图(极快);
  • 稀疏张量的所有操作均在图中执行,无额外转换开销。

五、其他兼容API(简要说明)

除上述核心API外,稀疏张量还兼容以下高频操作:

API 作用 示例
tf.cast 转换稀疏张量数据类型 tf.cast(sparse_data, tf.float32)
tf.print 打印稀疏张量(含indices/values) tf.print(sparse_data)
tf.math.abs 非零值取绝对值 tf.math.abs(sparse_data)
tf.saved_model 保存/加载含稀疏张量的模型 model.save("sparse_model")
tf.io.serialize_sparse 序列化稀疏张量为字节流 tf.io.serialize_sparse(sparse_data)

核心避坑总结

1. 稀疏→密集的隐式转换

  • Dense层、matmul(稀疏×密集)等操作会隐式转换为密集张量,超稀疏场景可能导致OOM;
  • 解决方案:优先使用稀疏专用算子(如tf.sparse.sparse_dense_matmul)。

2. 形状兼容性

  • tf.data批处理时,非批轴的形状必须一致(如示例中所有元素shape=[4]);
  • 变长形状需用dense_to_sparse_batch,而非普通batch

3. tf.train.Example的选型

  • 一维变长数据:优先用tf.io.RaggedFeature(替代VarLenFeature);
  • 高维稀疏数据:用tf.io.SparseFeature存储三个组件。

4. tf.function的静态形状

  • tf.function编译时会固化稀疏张量的dense_shape,动态形状需用tf.TensorShape(None)兼容。

实战价值总结

稀疏张量与核心API的无缝兼容,使得:

  • 数据预处理:用tf.data高效处理超稀疏数据(如TF-IDF、高维特征);
  • 模型训练:用tf.keras直接以稀疏张量为输入,避免密集转换的内存浪费;
  • 性能优化:用tf.function编译稀疏操作,提升运行效率;
  • 数据存储:用tf.train.Example/tf.saved_model序列化稀疏数据,节省存储成本。

这是处理大规模稀疏数据(如推荐系统、NLP、计算机视觉)的核心能力,能大幅降低内存占用和计算开销。

相关推荐
SamtecChina20233 小时前
Electronica现场演示 | Samtec前面板解决方案
大数据·人工智能·算法·计算机外设
2401_841495643 小时前
【自然语言处理】字符编码与字频统计:中文信息处理的底层逻辑与实践维度
人工智能·自然语言处理·中文信息处理·西文字符编码的奠基·中文编码的演进·字符编码的实践价值·字频统计的作用与方法
雍凉明月夜3 小时前
视觉opencv学习笔记Ⅴ-数据增强(2)
人工智能·python·opencv·计算机视觉
JoannaJuanCV3 小时前
自动驾驶—CARLA仿真(24)sensor_synchronization demo
网络·人工智能·自动驾驶·carla
JoannaJuanCV3 小时前
自动驾驶—CARLA仿真(14)draw_skeleton demo
人工智能·机器学习·自动驾驶
测试人社区-千羽3 小时前
飞机自动驾驶系统测试:安全关键系统的全面验证框架
人工智能·安全·面试·职场和发展·自动化·自动驾驶·测试用例
Abona3 小时前
广义端到端(GE2E)自动驾驶技术综述:范式演进、核心挑战与破局路径
人工智能·机器学习·自动驾驶
CSDN官方博客3 小时前
CSDN社区镜像创作活动
大数据·运维·人工智能
棒棒的皮皮3 小时前
【OpenCV】Python图像处理几何变换之缩放
图像处理·python·opencv·计算机视觉