零基础吃透:创建tf.sparse.SparseTensor的核心方法
创建tf.sparse.SparseTensor是使用稀疏张量的基础,TensorFlow提供了直接构造 和从密集张量转换两种核心方式,同时可通过自定义函数美化打印结果(便于调试),也能轻松转回密集张量。以下结合示例拆解每个步骤的原理、用法和注意事项。
一、环境警告说明(先避坑)
代码中出现的GPU相关警告(如cuFFT/cuDNN/cuBLAS factory)是因为本地环境的GPU库重复注册/缺失,不影响稀疏张量的核心功能(CPU环境下可正常运行),无需处理即可继续。
二、方式1:直接构造SparseTensor(核心参数)
2.1 构造原理
直接通过tf.sparse.SparseTensor构造,需指定三个核心参数(COO格式):
| 参数名 | 要求 | 示例 |
|---|---|---|
indices |
二维张量(dtype=int64),每行是一个非零值的坐标,形状[N, rank] |
[[0,3], [2,4]](2个非零值,二维坐标) |
values |
一维张量,长度=N(与indices行数一致),存储非零值 | [10,20] |
dense_shape |
一维张量(dtype=int64),指定稀疏张量对应的密集形状,长度=rank | [3,10](3行10列的二维张量) |
2.2 示例代码
python
import tensorflow as tf
# 直接构造稀疏张量
st1 = tf.sparse.SparseTensor(
indices=[[0, 3], [2, 4]], # 非零值坐标:(0,3)=10,(2,4)=20
values=[10, 20], # 非零值列表
dense_shape=[3, 10] # 对应密集张量形状:3行10列
)
# 原生打印(显示三个核心组件)
print("原生打印SparseTensor:")
print(st1)
2.3 输出解读
SparseTensor(indices=tf.Tensor(
[[0 3]
[2 4]], shape=(2, 2), dtype=int64),
values=tf.Tensor([10 20], shape=(2,), dtype=int32),
dense_shape=tf.Tensor([ 3 10], shape=(2,), dtype=int64))
indices:二维int64张量,2行2列(2个非零值,二维坐标);values:一维int32张量,存储2个非零值;dense_shape:一维int64张量,指定密集形状为[3,10]。
三、美观打印SparseTensor(调试必备)
原生打印的格式不直观,可自定义函数将"坐标-值"一一对应打印,便于快速理解稀疏张量的内容。
3.1 自定义打印函数原理
遍历indices和values,逐个拼接"坐标: 值"的格式,最终输出结构化的字符串。
3.2 示例代码
python
def pprint_sparse_tensor(st):
# 初始化字符串,先打印密集形状
s = "<SparseTensor shape=%s \n values={" % (st.dense_shape.numpy().tolist(),)
# 遍历每个非零值的坐标和值
for (index, value) in zip(st.indices, st.values):
# 拼接坐标(列表格式)和值
s += f"\n %s: %s" % (index.numpy().tolist(), value.numpy().tolist())
return s + "}>"
# 美观打印st1
print("\n美观打印SparseTensor:")
print(pprint_sparse_tensor(st1))
3.3 输出解读
<SparseTensor shape=[3, 10]
values={
[0, 3]: 10
[2, 4]: 20}>
直观看到:
- 稀疏张量对应密集形状是
3行10列; - 非零值位置:
(0,3)为10,(2,4)为20; - 其余位置均为隐式零值。
四、方式2:从密集张量转换为SparseTensor
4.1 核心函数:tf.sparse.from_dense
自动提取密集张量中的非零值及其坐标,生成对应的SparseTensor(无需手动指定indices/values)。
4.2 示例代码
python
# 从密集张量创建稀疏张量
dense_tensor = [[1, 0, 0, 8], [0, 0, 0, 0], [0, 0, 3, 0]]
st2 = tf.sparse.from_dense(dense_tensor)
# 美观打印转换后的稀疏张量
print("从密集张量转换的SparseTensor:")
print(pprint_sparse_tensor(st2))
4.3 输出解读
<SparseTensor shape=[3, 4]
values={
[0, 0]: 1
[0, 3]: 8
[2, 2]: 3}>
- 密集张量的非零值:
(0,0)=1、(0,3)=8、(2,2)=3,其余为0; tf.sparse.from_dense自动过滤零值,仅保留非零值的坐标和值。
五、稀疏张量转回密集张量
5.1 核心函数:tf.sparse.to_dense
根据SparseTensor的indices/values/dense_shape,填充非零值,其余位置补0,生成密集张量。
5.2 示例代码
python
# 稀疏张量转回密集张量
st3 = tf.sparse.to_dense(st2)
print("\n稀疏张量转回的密集张量:")
print(st3)
5.3 输出解读
tf.Tensor(
[[1 0 0 8]
[0 0 0 0]
[0 0 3 0]], shape=(3, 4), dtype=int32)
与原始密集张量完全一致,验证了转换的可逆性。
六、关键注意事项(避坑核心)
1. 数据类型要求
indices和dense_shape的dtype必须是int64(TensorFlow强制要求,手动指定时若用int32会报错);values的dtype可自定义(int32/float32等),但需与业务场景匹配。
2. 索引格式要求
indices的每行长度必须等于dense_shape的长度(即张量的秩):- 二维张量的索引是
[行, 列](长度2); - 三维张量的索引是
[深度, 行, 列](长度3)。
- 二维张量的索引是
3. 显式零值的处理
-
tf.sparse.from_dense会自动过滤隐式零值 (未存储的零),但如果密集张量中主动存储0(显式零值),会被保留:python# 含显式零值的密集张量 dense_with_zero = [[1, 0, 0], [0, 0, 0], [2, 0, 3]] st_with_zero = tf.sparse.from_dense(dense_with_zero) print(pprint_sparse_tensor(st_with_zero)) # 仅保留1、2、3,过滤0
4. 空张量处理
-
若密集张量全为0,
tf.sparse.from_dense生成的SparseTensor的indices和values为空:pythondense_all_zero = [[0,0], [0,0]] st_all_zero = tf.sparse.from_dense(dense_all_zero) print(st_all_zero.indices.numpy()) # 空数组 [] print(st_all_zero.values.numpy()) # 空数组 []
七、核心总结
| 操作 | 函数/方法 | 核心用途 |
|---|---|---|
| 直接构造稀疏张量 | tf.sparse.SparseTensor |
手动指定非零值坐标和值(精准控制) |
| 密集→稀疏 | tf.sparse.from_dense |
自动提取非零值,快速生成稀疏张量 |
| 稀疏→密集 | tf.sparse.to_dense |
还原为密集张量,适配不支持稀疏的算子 |
| 美观打印稀疏张量 | 自定义pprint_sparse_tensor函数 |
调试时直观查看非零值的坐标和值 |
掌握这三种核心操作,就能灵活创建和转换稀疏张量,满足NLP/计算机视觉等场景下的稀疏数据处理需求。