tensorflow 零基础吃透:TensorFlow 稀疏张量(SparseTensor)的核心操作

零基础吃透:TensorFlow稀疏张量(SparseTensor)的核心操作

稀疏张量无法直接使用tf.math.add等密集张量的算术算子,必须通过tf.sparse包下的专用工具进行操作。本文拆解加法、矩阵乘法、拼接、切片、元素级运算五大核心操作,结合示例讲清原理、用法和版本兼容细节。

前置准备(必运行)

python 复制代码
import tensorflow as tf

# 复用之前的美观打印函数(调试必备)
def pprint_sparse_tensor(st):
  s = "<SparseTensor shape=%s \n values={" % (st.dense_shape.numpy().tolist(),)
  for (index, value) in zip(st.indices, st.values):
    s += f"\n  %s: %s" % (index.numpy().tolist(), value.numpy().tolist())
  return s + "}>"

# 示例稀疏张量(后续操作会复用)
st2 = tf.sparse.from_dense([[1, 0, 0, 8], [0, 0, 0, 0], [0, 0, 3, 0]])

一、稀疏张量加法(tf.sparse.add)

核心原理

仅对同形状稀疏张量的「相同坐标非零值」相加,不同坐标的非零值直接保留,最终输出仍为稀疏张量(仅存储非零结果)。

示例代码

python 复制代码
# 构造两个同形状的稀疏张量
st_a = tf.sparse.SparseTensor(
    indices=[[0, 2], [3, 4]],
    values=[31, 2], 
    dense_shape=[4, 10]  # 4行10列
)

st_b = tf.sparse.SparseTensor(
    indices=[[0, 2], [3, 0]],
    values=[56, 38],
    dense_shape=[4, 10]  # 必须与st_a形状一致
)

# 稀疏张量加法
st_sum = tf.sparse.add(st_a, st_b)
print("稀疏张量相加结果:")
print(pprint_sparse_tensor(st_sum))

输出解读

复制代码
<SparseTensor shape=[4, 10] 
 values={
  [0, 2]: 87  # st_a[0,2]=31 + st_b[0,2]=56
  [3, 0]: 38  # 仅st_b有该坐标,直接保留
  [3, 4]: 2   # 仅st_a有该坐标,直接保留
}>

关键注意事项

  • ❌ 形状不同会报错:必须保证dense_shape完全一致;
  • ✅ 结果仅保留非零值:若相加后某坐标值为0(如st_a[0,2]=-56 + st_b[0,2]=56),会被过滤出结果。

二、稀疏×密集矩阵乘法(tf.sparse.sparse_dense_matmul)

核心原理

稀疏张量作为矩阵(需满足矩阵乘法的形状规则),与密集矩阵相乘,无需转换为密集张量,大幅节省内存(超稀疏矩阵效率提升显著)。

示例代码

python 复制代码
# 构造2×2的稀疏矩阵(非零值:[0,1]=13,[1,0]=15,[1,1]=17)
st_c = tf.sparse.SparseTensor(
    indices=[[0, 1], [1, 0], [1, 1]],  # 注意:原代码的indices写法有误,修正为列表格式
    values=[13, 15, 17],
    dense_shape=(2, 2)
)

# 构造2×1的密集矩阵
mb = tf.constant([[4], [6]])

# 稀疏×密集矩阵乘法
product = tf.sparse.sparse_dense_matmul(st_c, mb)
print("\n稀疏×密集矩阵乘法结果:")
print(product)

计算逻辑(验证结果)

矩阵乘法规则:C × B = [ (0×4+13×6), (15×4+17×6) ]^T

  • 第一行:0×4 + 13×6 = 78
  • 第二行:15×4 + 17×6 = 60 + 102 = 162

输出解读

复制代码
tf.Tensor(
[[ 78]
 [162]], shape=(2, 1), dtype=int32)

关键注意事项

  • 形状规则:稀疏张量的列数 = 密集矩阵的行数(如2×2 × 2×1 合法);
  • 索引顺序:建议先通过tf.sparse.reorder排序稀疏张量索引,避免运算异常。

三、稀疏张量拼接(tf.sparse.concat)

核心原理

沿指定轴(如列轴axis=1)拼接多个稀疏张量,要求除拼接轴外的其他轴形状一致,最终输出合并后的稀疏张量。

示例代码

python 复制代码
# 构造3个待拼接的稀疏张量(行维度均为8,列维度不同)
sparse_pattern_A = tf.sparse.SparseTensor(
    indices = [[2,4], [3,3], [3,4], [4,3], [4,4], [5,4]],
    values = [1]*6,
    dense_shape = [8,5]  # 8行5列
)
sparse_pattern_B = tf.sparse.SparseTensor(
    indices = [[0,2], [1,1], [1,3], [2,0], [2,4], [2,5], [3,5], 
               [4,5], [5,0], [5,4], [5,5], [6,1], [6,3], [7,2]],
    values = [1]*14,
    dense_shape = [8,6]  # 8行6列
)
sparse_pattern_C = tf.sparse.SparseTensor(
    indices = [[3,0], [4,0]],
    values = [1]*2,
    dense_shape = [8,6]  # 8行6列
)

# 沿列轴(axis=1)拼接
sparse_pattern = tf.sparse.concat(
    axis=1,  # 列轴拼接(行轴保持8不变)
    sp_inputs=[sparse_pattern_A, sparse_pattern_B, sparse_pattern_C]
)

# 转换为密集张量查看拼接结果
print("\n拼接后的密集张量:")
print(tf.sparse.to_dense(sparse_pattern))

输出解读

拼接后形状为8×(5+6+6)=8×17,非零值按原位置分布在对应列区间:

  • A的非零值在列0~4;
  • B的非零值在列5~10;
  • C的非零值在列11~16。

关键注意事项

  • 拼接轴外的维度必须一致:如示例中所有张量的行维度均为8,仅列维度不同;
  • 拼接后非零值位置:原张量的列索引自动偏移(如B的列0→拼接后的列5)。

四、稀疏张量切片(tf.sparse.slice)

核心原理

沿指定轴截取稀疏张量的子区域,仅保留「切片范围内的非零值」,输出新的稀疏张量(形状为指定的size)。

函数参数

参数 作用
start 切片起始坐标(列表/张量),长度=张量秩(如[0,0]表示行0列0开始)
size 切片大小(列表/张量),长度=张量秩(如[8,5]表示截取8行5列)

示例代码

python 复制代码
# 对拼接后的张量切片(还原原张量)
sparse_slice_A = tf.sparse.slice(sparse_pattern_A, start = [0,0], size = [8,5])
sparse_slice_B = tf.sparse.slice(sparse_pattern_B, start = [0,5], size = [8,6])
sparse_slice_C = tf.sparse.slice(sparse_pattern_C, start = [0,10], size = [8,6])

# 打印切片结果(转密集张量)
print("\n切片A(8×5):")
print(tf.sparse.to_dense(sparse_slice_A))
print("\n切片B(8×1):")  # 原B的start=[0,5],size=[8,6]但仅列5有值,故输出8×1
print(tf.sparse.to_dense(sparse_slice_B))
print("\n切片C(8×0):")  # 无符合条件的非零值,输出空
print(tf.sparse.to_dense(sparse_slice_C))

输出解读

复制代码
切片A(8×5):
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 1]
 [0 0 0 1 1]
 [0 0 0 1 1]
 [0 0 0 0 1]
 [0 0 0 0 0]
 [0 0 0 0 0]]

切片B(8×1):
[[0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]]

切片C(8×0):
[]

关键注意事项

  • 切片范围外的非零值会被过滤:如切片B仅截取列5,原B的其他列非零值被丢弃;
  • 空切片:无符合条件的非零值时,输出shape=(8,0)的空稀疏张量。

五、元素级运算(仅修改非零值)

场景:对稀疏张量的所有非零值做统一运算(如+5)

方式1:TF2.4+ 专用(tf.sparse.map_values)

tf.sparse.map_values专门对稀疏张量的values(非零值)做元素级运算,零值保持不变。

python 复制代码
# 对st2的非零值+5
st2_plus_5 = tf.sparse.map_values(tf.add, st2, 5)
print("\nTF2.4+ 非零值+5(密集张量):")
print(tf.sparse.to_dense(st2_plus_5))
方式2:TF2.4前 兼容方案

手动构造新的SparseTensor,仅修改values,保留indicesdense_shape

python 复制代码
# 老版本兼容写法:直接修改values
st2_plus_5_compat = tf.sparse.SparseTensor(
    st2.indices,          # 保留原坐标
    st2.values + 5,       # 非零值+5
    st2.dense_shape       # 保留原形状
)
print("\n老版本兼容 非零值+5(密集张量):")
print(tf.sparse.to_dense(st2_plus_5_compat))

输出解读(两种方式结果一致)

复制代码
[[ 6  0  0 13]
 [ 0  0  0  0]
 [ 0  0  8  0]]
  • 仅非零值被修改:原1→68→133→8
  • 零值保持不变:符合稀疏张量的设计初衷(仅操作有效数据)。

核心操作总结表

操作 函数 核心要求 适用场景
稀疏加法 tf.sparse.add 张量形状完全一致 同形状稀疏张量逐坐标相加
稀疏-密集矩阵乘法 tf.sparse.sparse_dense_matmul 稀疏列数=密集行数 超稀疏矩阵与密集矩阵相乘
稀疏拼接 tf.sparse.concat 非拼接轴形状一致 合并多个稀疏张量的列/行
稀疏切片 tf.sparse.slice start/size长度=张量秩 截取稀疏张量的子区域
元素级运算 tf.sparse.map_values TF2.4+,仅修改非零值 对非零值做统一算术运算(+/-/*//)

避坑关键

  1. 形状匹配:所有稀疏张量操作的核心是「形状兼容」,形状不匹配会直接报错;
  2. 索引顺序 :运算前建议用tf.sparse.reorder排序索引,避免算子异常;
  3. 版本兼容tf.sparse.map_values仅TF2.4+支持,老版本需手动修改values
  4. 零值处理:所有操作均仅处理非零值,零值始终保持隐式存储(不占用内存)。

掌握这些操作,就能高效处理NLP(TF-IDF)、计算机视觉(稀疏像素)等场景下的超稀疏数据,大幅降低内存占用和计算开销。

相关推荐
夏洛克信徒4 小时前
AI纪元2025终章:开源革命、监管铁幕与人类主体性的觉醒
人工智能·开源
心疼你的一切4 小时前
深度学习入门_神经网络基础
人工智能·深度学习·神经网络·机器学习
撬动未来的支点4 小时前
【AI邪修·破壁行动】神经网络基础—核心数据结构—张量
人工智能·深度学习·神经网络
玖日大大4 小时前
TensorFlow 深度解析:从基础到实战的全维度指南
人工智能·python·tensorflow
丝瓜蛋汤5 小时前
Conan-embedding整理
人工智能·embedding
云老大TG:@yunlaoda3605 小时前
腾讯云国际站代理商的OCR有什么优势呢?
人工智能·ocr·腾讯云
小追兵5 小时前
AI 照片修复神器:如何用 AI 恢复老照片高清细节
人工智能
智驱力人工智能5 小时前
山区搜救无人机人员检测算法 技术攻坚与生命救援的融合演进 城市高空无人机人群密度分析 多模态融合无人机识别系统
人工智能·深度学习·算法·架构·无人机·边缘计算
我很哇塞耶5 小时前
英伟达开源发布最新AI模型!引入突破性专家混合架构,推理性能超越Qwen3和GPT,百万token上下文,模型数据集全开源!
人工智能·ai·大模型