TensorFlow 中不规则张量(RaggedTensor)

这份文档核心介绍了 TensorFlow 中不规则张量(RaggedTensor) 的定义、适用场景和核心用法,解决了普通张量"必须固定形状"的痛点,专门用于处理非均匀长度/嵌套可变结构的数据。以下是通俗化的核心概述:

一、什么是不规则张量(RaggedTensor)?

普通 TensorFlow 张量要求所有维度的长度固定(比如 [[1,2],[3,4]] 是合法的,[[1,2],[3]] 不合法),而 RaggedTensor 是 TensorFlow 对「嵌套可变长度列表」的原生支持,允许同一维度下的元素长度不一致(比如 [[3,1,4,1], [], [5,9,2], [6], []]),是处理非均匀形状数据的核心工具。

二、适用场景(解决哪些问题?)

专门应对普通张量无法存储的「可变长度/分层结构」数据,典型场景:

  1. 可变长度特征(如电影的演员名单:有的电影3个演员,有的5个);
  2. 批量可变长度序列(如句子:有的5个单词,有的8个;视频剪辑:有的10帧,有的20帧);
  3. 分层数据(如文本文档:文档→节→段落→句子→单词,每层长度都可变);
  4. 结构化数据的字段(如协议缓冲区中长度不固定的字段)。

三、核心功能:兼容大量 TensorFlow 运算

超过100种TF原生运算支持 RaggedTensor,无需额外转换,直接使用,包括:

运算类型 示例(文档原版) 效果说明
数学运算 tf.add(digits, 3)/tf.reduce_mean(digits, axis=1) 标量加法(所有元素+3)、按行求均值(空行返回nan)
数组运算 tf.concat([digits, [[5,3]]], axis=0)/tf.tile(digits, [1,2]) 纵向拼接(新增一行[5,3])、维度复制(每行元素重复2次)
字符串操作 tf.strings.substr(words, 0, 2) 对字符串型RaggedTensor截取前2个字符(如"So"→"So","thanks"→"th")
控制流/映射 tf.map_fn(tf.math.square, digits) 逐元素执行函数(如平方:3→9,1→1)

基础示例(文档原版)

python 复制代码
# 定义数字型、字符串型RaggedTensor
digits = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
words = tf.ragged.constant([["So", "long"], ["thanks", "for", "all", "the", "fish"]])

# 数学运算:所有元素+3
print(tf.add(digits, 3))  # 输出:[[6,4,7,4], [], [8,12,5], [9], []]
# 数组运算:每行元素重复2次
print(tf.tile(digits, [1, 2]))  # 输出:[[3,1,4,1,3,1,4,1], [], [5,9,2,5,9,2], [6,6], []]

四、常用操作(和普通张量用法对齐,易上手)

1. Python 风格索引/切片

和普通列表/张量的索引逻辑一致,支持精准切片可变长度的行:

python 复制代码
print(digits[0])        # 取第一行 → [3,1,4,1](普通张量)
print(digits[:, :2])    # 取每行前2个元素 → [[3,1], [], [5,9], [6], []]
print(digits[:, -2:])   # 取每行最后2个元素 → [[4,1], [], [9,2], [6], []]
2. 重载算术/比较运算符

支持逐元素运算,可直接和标量、另一个 RaggedTensor 计算:

python 复制代码
# 和标量运算:所有元素+3
print(digits + 3)  # 同tf.add(digits, 3)
# 和另一个同结构RaggedTensor逐元素相加
print(digits + tf.ragged.constant([[1,2,3,4], [], [5,6,7], [8], []]))
# 输出:[[4,3,7,5], [], [10,15,9], [14], []]
3. 自定义逐元素转换(tf.ragged.map_flat_values)

对 RaggedTensor 的所有元素应用自定义函数,无需关心长度:

python 复制代码
# 定义函数:x*2+1
times_two_plus_one = lambda x: x * 2 + 1
# 逐元素应用函数
print(tf.ragged.map_flat_values(times_two_plus_one, digits))
# 输出:[[7,3,9,3], [], [11,19,5], [13], []]

五、格式转换(和Python/NumPy互通)

RaggedTensor 可便捷转换为常用格式,满足数据查看/导出需求:

  1. 转嵌套 Python list(最直观,保留可变长度结构):

    python 复制代码
    digits.to_list()  # 输出:[[3, 1, 4, 1], [], [5, 9, 2], [6], []]
  2. 转 NumPy array(返回object类型数组,每个元素是独立的小数组):

    python 复制代码
    digits.numpy()
    # 输出:array([array([3, 1, 4, 1]), array([]), array([5,9,2]), array([6]), array([])], dtype=object)

核心总结

RaggedTensor 是 TensorFlow 为「非均匀形状数据」量身打造的张量类型,核心优势:

  • 无需手动填充(如用0补全可变长度序列),保留原始数据结构;
  • 兼容绝大多数 TF 原生运算,用法和普通张量对齐,学习成本低;
  • 支持灵活的索引、自定义转换和格式互通,覆盖可变长度数据的存储、处理、导出全流程。
相关推荐
丝斯201116 小时前
AI学习笔记整理(67)——大模型的Benchmark(基准测试)
人工智能·笔记·学习
咚咚王者16 小时前
人工智能之核心技术 深度学习 第七章 扩散模型(Diffusion Models)
人工智能·深度学习
github.com/starRTC16 小时前
Claude Code中英文系列教程25:非交互式运行 Claude Code
人工智能·ai编程
逄逄不是胖胖16 小时前
《动手学深度学习》-60translate实现
人工智能·python·深度学习
loui robot16 小时前
规划与控制之局部路径规划算法local_planner
人工智能·算法·自动驾驶
玄同76516 小时前
Llama.cpp 全实战指南:跨平台部署本地大模型的零门槛方案
人工智能·语言模型·自然语言处理·langchain·交互·llama·ollama
格林威16 小时前
Baumer相机金属焊缝缺陷识别:提升焊接质量检测可靠性的 7 个关键技术,附 OpenCV+Halcon 实战代码!
人工智能·数码相机·opencv·算法·计算机视觉·视觉检测·堡盟相机
独处东汉16 小时前
freertos开发空气检测仪之按键输入事件管理系统设计与实现
人工智能·stm32·单片机·嵌入式硬件·unity
你大爷的,这都没注册了16 小时前
AI提示词,zero-shot,few-shot 概念
人工智能
AC赳赳老秦16 小时前
DeepSeek 辅助科研项目申报:可行性报告与经费预算框架的智能化撰写指南
数据库·人工智能·科技·mongodb·ui·rabbitmq·deepseek