tensorflow 零基础吃透:TensorFlow 张量切片与数据插入(附目标检测 / NLP 实战场景)

零基础吃透:TensorFlow张量切片与数据插入(附目标检测/NLP实战场景)

张量切片(提取子部分)和数据插入是TensorFlow处理结构化数据的核心操作,广泛用于目标检测(特征路由、选框特征提取)NLP(单词遮盖、序列切片) 等场景。本文拆解「张量切片」「数据插入」两大核心模块,结合实战示例讲清tf.slice/tf.gather/tf.scatter_nd等API的用法、原理和场景适配。

前置准备(必运行)

python 复制代码
import tensorflow as tf
import numpy as np

# 消除GPU警告(不影响功能)
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

一、提取张量切片(核心场景:特征筛选/序列截取)

核心应用场景

  • 目标检测:从特征图中切片提取目标框对应的特征、按路由规则拆分样本特征;
  • NLP:截取句子的子序列、提取指定位置的单词(如遮盖任务选待遮盖单词)。

1. 基础切片:tf.slice 与 Python风格切片

tf.slice是TensorFlow原生切片API,Python风格切片([])更简洁,二者功能等价,支持一维/二维/高维张量。

1.1 一维张量切片
python 复制代码
# 基础一维张量
t1 = tf.constant([0, 1, 2, 3, 4, 5, 6, 7])

# 方式1:tf.slice(begin=起始索引,size=截取长度)
slice_result = tf.slice(t1, begin=[1], size=[3])
print("tf.slice结果:", slice_result.numpy())  # [1 2 3]

# 方式2:Python风格切片(start:stop,左闭右开)
python_slice1 = t1[1:4]  # 等价于begin=[1], size=[3]
python_slice2 = t1[-3:]  # 最后3个元素
print("Python切片[1:4]:", python_slice1.numpy())  # [1 2 3]
print("Python切片[-3:]:", python_slice2.numpy())  # [5 6 7]
1.2 二维张量切片(目标检测特征图切片)
python 复制代码
# 4行5列的二维张量(模拟目标检测特征图)
t2 = tf.constant([[0, 1, 2, 3, 4],
                  [5, 6, 7, 8, 9],
                  [10, 11, 12, 13, 14],
                  [15, 16, 17, 18, 19]])

# 切片:所有行除最后一行(:-1),列1到3(1:3)
slice_2d = t2[:-1, 1:3]
print("\n二维张量切片结果:")
print(slice_2d.numpy())

输出

复制代码
[[ 1  2]
 [ 6  7]
 [11 12]]
1.3 高维张量切片(3D特征图)
python 复制代码
# 2×2×4的三维张量(模拟批量特征图)
t3 = tf.constant([[[1, 3, 5, 7],
                   [9, 11, 13, 15]],
                  [[17, 19, 21, 23],
                   [25, 27, 29, 31]]])

# tf.slice:begin=[1,1,0](第1个批量、第1行、第0列),size=[1,1,2](截取1×1×2)
slice_3d = tf.slice(t3, begin=[1, 1, 0], size=[1, 1, 2])
print("\n三维张量切片结果:")
print(slice_3d.numpy())  # [[[25 27]]]

2. 跨步切片:tf.strided_slice(间隔截取)

tf.strided_slice支持「跨步」截取(类似NumPy的[start:stop:step]),适合按固定间隔提取元素(如NLP每隔k个单词采样)。

python 复制代码
# 一维张量跨步切片:起始0,结束8,步长3
strided_slice = tf.strided_slice(t1, begin=[0], end=[8], strides=[3])
print("\n跨步切片结果:", strided_slice.numpy())  # [0 3 6]

# 等价于Python风格:t1[::3]
print("Python跨步切片:", t1[::3].numpy())  # [0 3 6]

3. 单轴任意索引提取:tf.gather

tf.gather单个轴提取「非均匀索引」的元素(无需间隔),适合NLP提取指定位置的字符/单词、目标检测选指定样本。

python 复制代码
# 示例1:提取t1的0、3、6索引
gather_1d = tf.gather(t1, indices=[0, 3, 6])
print("\ntf.gather提取指定索引:", gather_1d.numpy())  # [0 3 6]

# 示例2:NLP提取任意字符(非均匀索引)
alphabet = tf.constant(list('abcdefghijklmnopqrstuvwxyz'))
gather_char = tf.gather(alphabet, indices=[2, 0, 19, 18])  # c、a、t、s
print("提取任意字符:", gather_char.numpy())  # [b'c' b'a' b't' b's']

4. 多轴任意索引提取:tf.gather_nd

tf.gather_nd多个轴提取元素(支持矩阵元素、高维张量任意位置),是目标检测「提取指定目标框特征」的核心API。

示例1:提取矩阵指定行
python 复制代码
# 5行2列矩阵
t4 = tf.constant([[0, 5], [1, 6], [2, 7], [3, 8], [4, 9]])
# 提取第2、3、0行(索引[[2], [3], [0]])
gather_2d_row = tf.gather_nd(t4, indices=[[2], [3], [0]])
print("\n提取矩阵指定行:")
print(gather_2d_row.numpy())

输出

复制代码
[[2 7]
 [3 8]
 [0 5]]
示例2:提取高维张量任意位置元素
python 复制代码
# 2×3×3的三维张量(模拟批量特征图)
t5 = np.reshape(np.arange(18), [2, 3, 3])

# 提取单个元素:[0,0,0]和[1,2,1]
gather_3d_elem = tf.gather_nd(t5, indices=[[0, 0, 0], [1, 2, 1]])
print("\n提取高维张量单个元素:", gather_3d_elem.numpy())  # [ 0 16]

# 提取多个子矩阵:[[0,0], [0,2]] 和 [[1,0], [1,2]]
gather_3d_mat = tf.gather_nd(t5, indices=[[[0, 0], [0, 2]], [[1, 0], [1, 2]]])
print("\n提取多个子矩阵:")
print(gather_3d_mat.numpy())

二、插入数据到张量(核心场景:遮盖/特征更新)

核心应用场景

  • NLP单词遮盖 :将选中的单词位置插入遮盖标记(<MASK>);
  • 目标检测:更新特征图中目标框的特征值、构造稀疏特征张量;
  • 矩阵操作:构造魔法方阵、单位矩阵。

1. 零初始化插入:tf.scatter_nd

tf.scatter_nd全零张量 的指定索引处插入数据,是模拟稀疏张量的核心方法(无需显式构造SparseTensor)。

python 复制代码
# 构造参数:indices=插入位置,updates=插入值,shape=目标张量形状
t6 = tf.constant([10])  # shape=[10]的全零张量
indices = tf.constant([[1], [3], [5], [7], [9]])  # 插入位置
data = tf.constant([2, 4, 6, 8, 10])  # 插入值

scatter_result = tf.scatter_nd(indices=indices, updates=data, shape=t6)
print("\ntf.scatter_nd插入结果:")
print(scatter_result.numpy())  # [ 0  2  0  4  0  6  0  8  0 10]
模拟稀疏张量(tf.gather_nd + tf.scatter_nd)
python 复制代码
# 步骤1:从t2提取指定位置的值
new_indices = tf.constant([[0, 2], [2, 1], [3, 3]])
t7 = tf.gather_nd(t2, indices=new_indices)  # 提取值:2、11、18

# 步骤2:将值插入全零张量(模拟稀疏张量)
t8 = tf.scatter_nd(indices=new_indices, updates=t7, shape=tf.constant([4, 5]))
print("\n模拟稀疏张量插入结果:")
print(t8.numpy())

# 等价于构造SparseTensor后转密集
t9 = tf.sparse.SparseTensor(
    indices=[[0, 2], [2, 1], [3, 3]],
    values=[2, 11, 18],
    dense_shape=[4, 5]
)
t10 = tf.sparse.to_dense(t9)
print("\nSparseTensor转密集结果(与上一致):")
print(t10.numpy())

输出

复制代码
[[ 0  0  2  0  0]
 [ 0  0  0  0  0]
 [ 0 11  0  0  0]
 [ 0  0  0 18  0]]

2. 已有张量插入(加减):tf.tensor_scatter_nd_add/sub

已有值的张量上,对指定索引执行「加法/减法」更新,适合动态修改特征值。

示例1:构造魔法方阵(加法更新)
python 复制代码
# 初始张量
t11 = tf.constant([[2, 7, 0],
                   [9, 0, 1],
                   [0, 3, 8]])

# 在指定位置加值,构造魔法方阵(每行/列和为15)
t12 = tf.tensor_scatter_nd_add(
    t11,
    indices=[[0, 2], [1, 1], [2, 0]],  # 插入位置
    updates=[6, 5, 4]  # 加6、加5、加4
)
print("\n魔法方阵(加法更新):")
print(t12.numpy())

输出

复制代码
[[2 7 6]
 [9 5 1]
 [4 3 8]]
示例2:构造单位矩阵(减法更新)
python 复制代码
# 从t11中减去指定值,得到单位矩阵
t13 = tf.tensor_scatter_nd_sub(
    t11,
    indices=[[0, 0], [0, 1], [1, 0], [1, 1], [1, 2], [2, 1], [2, 2]],
    updates=[1, 7, 9, -1, 1, 3, 7]
)
print("\n单位矩阵(减法更新):")
print(t13.numpy())

输出

复制代码
[[1 0 0]
 [0 1 0]
 [0 0 1]]

3. 极值插入:tf.tensor_scatter_nd_min/max

在指定索引处,将张量值更新为「当前值」与「插入值」的最小值/最大值,适合特征值的极值约束。

python 复制代码
# 初始张量
t14 = tf.constant([[-2, -7, 0],
                   [-9, 0, 1],
                   [0, -3, -8]])

# 最小值更新:指定位置取当前值和插入值的最小
t15 = tf.tensor_scatter_nd_min(
    t14,
    indices=[[0, 2], [1, 1], [2, 0]],
    updates=[-6, -5, -4]
)
print("\n最小值更新结果:")
print(t15.numpy())

# 最大值更新:指定位置取当前值和插入值的最大
t16 = tf.tensor_scatter_nd_max(
    t14,
    indices=[[0, 2], [1, 1], [2, 0]],
    updates=[6, 5, 4]
)
print("\n最大值更新结果:")
print(t16.numpy())

输出

复制代码
# 最小值更新
[[-2 -7 -6]
 [-9 -5  1]
 [-4 -3 -8]]

# 最大值更新
[[-2 -7  6]
 [-9  5  1]
 [ 4 -3 -8]]

三、实战场景落地(NLP单词遮盖示例)

python 复制代码
# 模拟NLP单词遮盖任务:
# 1. 构造句子张量(单词索引):["I", "like", "ragged", "tensors"] → [0,1,2,3]
sentence = tf.constant([0, 1, 2, 3], dtype=tf.int32)
# 2. 选择要遮盖的单词索引:2("ragged")
mask_indices = tf.constant([[2]])
# 3. 提取遮盖单词(用于标签)
mask_label = tf.gather_nd(sentence, mask_indices)
print("遮盖单词标签:", mask_label.numpy())  # [2]

# 4. 插入遮盖标记(<MASK>→索引99)
masked_sentence = tf.scatter_nd(
    indices=mask_indices,
    updates=[99],
    shape=tf.constant([4])  # 句子长度4
)
# 5. 合并:未遮盖位置保留原词,遮盖位置为99
final_sentence = tf.where(masked_sentence == 0, sentence, masked_sentence)
print("遮盖后的句子:", final_sentence.numpy())  # [ 0  1 99  3]

核心API总结表

操作类型 API 核心功能 典型场景
基础切片 tf.slice / Python[] 按范围截取张量子部分 特征图区域截取、序列子串提取
跨步切片 tf.strided_slice 按固定间隔截取元素 序列采样、特征降采样
单轴索引提取 tf.gather 单轴非均匀索引提取 NLP字符/单词提取、样本筛选
多轴索引提取 tf.gather_nd 多轴任意位置提取 目标框特征提取、高维张量取值
零初始化插入 tf.scatter_nd 全零张量指定位置插入数据 构造稀疏张量、NLP单词遮盖
已有张量加减 tf.tensor_scatter_nd_add/sub 已有张量指定位置加减值 魔法方阵、单位矩阵构造
极值插入 tf.tensor_scatter_nd_min/max 指定位置更新为极值 特征值极值约束、异常值修正

避坑关键

  1. 索引维度匹配tf.gather_nd的indices维度需与张量秩匹配(如二维张量的indices每行长度为2);
  2. 形状一致性tf.scatter_nd的shape需与插入后张量形状一致,否则报错;
  3. 原地更新?:所有插入操作均返回新张量(TensorFlow张量不可变),需接收返回值;
  4. 稀疏场景优选 :大量零值插入时,优先用tf.scatter_ndSparseTensor,避免密集张量浪费内存。

掌握这些操作,可灵活处理目标检测/NLP中的张量拆分、重组、更新需求,是TensorFlow进阶的核心基础。

相关推荐
MM_MS39 分钟前
Halcon变量控制类型、数据类型转换、字符串格式化、元组操作
开发语言·人工智能·深度学习·算法·目标检测·计算机视觉·视觉检测
余俊晖7 小时前
多页文档理解强化学习设计思路:DocR1奖励函数设计与数据构建思路
人工智能·语言模型·自然语言处理
AI大佬的小弟9 小时前
【小白第一课】大模型基础知识(1)---大模型到底是啥?
人工智能·自然语言处理·开源·大模型基础·大模型分类·什么是大模型·国内外主流大模型
柯南小海盗10 小时前
从“会聊天的AI”到“全能助手”:大语言模型科普
人工智能·语言模型·自然语言处理
ggaofeng10 小时前
运行调试大语言模型
人工智能·语言模型·自然语言处理
大模型任我行11 小时前
微软:小模型微调优化企业搜索
人工智能·语言模型·自然语言处理·论文笔记
莫非王土也非王臣11 小时前
TensorFlow中卷积神经网络相关函数
人工智能·cnn·tensorflow
2501_9361460413 小时前
【计算机视觉系列】:基于YOLOv8-RepHGNetV2的鱿鱼目标检测模型优化与实现
yolo·目标检测·计算机视觉
智算菩萨14 小时前
【Python自然语言处理】实战项目:词向量表示完整实现指南
开发语言·python·自然语言处理
开放知识图谱15 小时前
论文浅尝 | 图上生成:将大语言模型视为智能体与知识图谱以解决不完整知识图谱问答(EMNLP2024)
人工智能·语言模型·自然语言处理·知识图谱