TensorFlow 中不规则张量(RaggedTensor)

这份文档核心介绍了 TensorFlow 中不规则张量(RaggedTensor) 的定义、适用场景和核心用法,解决了普通张量"必须固定形状"的痛点,专门用于处理非均匀长度/嵌套可变结构的数据。以下是通俗化的核心概述:

一、什么是不规则张量(RaggedTensor)?

普通 TensorFlow 张量要求所有维度的长度固定(比如 [[1,2],[3,4]] 是合法的,[[1,2],[3]] 不合法),而 RaggedTensor 是 TensorFlow 对「嵌套可变长度列表」的原生支持,允许同一维度下的元素长度不一致(比如 [[3,1,4,1], [], [5,9,2], [6], []]),是处理非均匀形状数据的核心工具。

二、适用场景(解决哪些问题?)

专门应对普通张量无法存储的「可变长度/分层结构」数据,典型场景:

  1. 可变长度特征(如电影的演员名单:有的电影3个演员,有的5个);
  2. 批量可变长度序列(如句子:有的5个单词,有的8个;视频剪辑:有的10帧,有的20帧);
  3. 分层数据(如文本文档:文档→节→段落→句子→单词,每层长度都可变);
  4. 结构化数据的字段(如协议缓冲区中长度不固定的字段)。

三、核心功能:兼容大量 TensorFlow 运算

超过100种TF原生运算支持 RaggedTensor,无需额外转换,直接使用,包括:

运算类型 示例(文档原版) 效果说明
数学运算 tf.add(digits, 3)/tf.reduce_mean(digits, axis=1) 标量加法(所有元素+3)、按行求均值(空行返回nan)
数组运算 tf.concat([digits, [[5,3]]], axis=0)/tf.tile(digits, [1,2]) 纵向拼接(新增一行[5,3])、维度复制(每行元素重复2次)
字符串操作 tf.strings.substr(words, 0, 2) 对字符串型RaggedTensor截取前2个字符(如"So"→"So","thanks"→"th")
控制流/映射 tf.map_fn(tf.math.square, digits) 逐元素执行函数(如平方:3→9,1→1)

基础示例(文档原版)

python 复制代码
# 定义数字型、字符串型RaggedTensor
digits = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
words = tf.ragged.constant([["So", "long"], ["thanks", "for", "all", "the", "fish"]])

# 数学运算:所有元素+3
print(tf.add(digits, 3))  # 输出:[[6,4,7,4], [], [8,12,5], [9], []]
# 数组运算:每行元素重复2次
print(tf.tile(digits, [1, 2]))  # 输出:[[3,1,4,1,3,1,4,1], [], [5,9,2,5,9,2], [6,6], []]

四、常用操作(和普通张量用法对齐,易上手)

1. Python 风格索引/切片

和普通列表/张量的索引逻辑一致,支持精准切片可变长度的行:

python 复制代码
print(digits[0])        # 取第一行 → [3,1,4,1](普通张量)
print(digits[:, :2])    # 取每行前2个元素 → [[3,1], [], [5,9], [6], []]
print(digits[:, -2:])   # 取每行最后2个元素 → [[4,1], [], [9,2], [6], []]
2. 重载算术/比较运算符

支持逐元素运算,可直接和标量、另一个 RaggedTensor 计算:

python 复制代码
# 和标量运算:所有元素+3
print(digits + 3)  # 同tf.add(digits, 3)
# 和另一个同结构RaggedTensor逐元素相加
print(digits + tf.ragged.constant([[1,2,3,4], [], [5,6,7], [8], []]))
# 输出:[[4,3,7,5], [], [10,15,9], [14], []]
3. 自定义逐元素转换(tf.ragged.map_flat_values)

对 RaggedTensor 的所有元素应用自定义函数,无需关心长度:

python 复制代码
# 定义函数:x*2+1
times_two_plus_one = lambda x: x * 2 + 1
# 逐元素应用函数
print(tf.ragged.map_flat_values(times_two_plus_one, digits))
# 输出:[[7,3,9,3], [], [11,19,5], [13], []]

五、格式转换(和Python/NumPy互通)

RaggedTensor 可便捷转换为常用格式,满足数据查看/导出需求:

  1. 转嵌套 Python list(最直观,保留可变长度结构):

    python 复制代码
    digits.to_list()  # 输出:[[3, 1, 4, 1], [], [5, 9, 2], [6], []]
  2. 转 NumPy array(返回object类型数组,每个元素是独立的小数组):

    python 复制代码
    digits.numpy()
    # 输出:array([array([3, 1, 4, 1]), array([]), array([5,9,2]), array([6]), array([])], dtype=object)

核心总结

RaggedTensor 是 TensorFlow 为「非均匀形状数据」量身打造的张量类型,核心优势:

  • 无需手动填充(如用0补全可变长度序列),保留原始数据结构;
  • 兼容绝大多数 TF 原生运算,用法和普通张量对齐,学习成本低;
  • 支持灵活的索引、自定义转换和格式互通,覆盖可变长度数据的存储、处理、导出全流程。
相关推荐
东坡肘子1 分钟前
AT 的人生未必比 MT 更好 -- 肘子的 Swift 周报 #118
人工智能·swiftui·swift
雅欣鱼子酱3 小时前
USB Type-C PD取电(诱骗,诱电,SINK),筋膜枪专用取电芯片
网络·人工智能·芯片·电子元器件
kisshuan123968 小时前
【深度学习】使用RetinaNet+X101-32x4d_FPN_GHM模型实现茶芽检测与识别_1
人工智能·深度学习
Learn Beyond Limits9 小时前
解构语义:从词向量到神经分类|Decoding Semantics: Word Vectors and Neural Classification
人工智能·算法·机器学习·ai·分类·数据挖掘·nlp
崔庆才丨静觅9 小时前
0代码生成4K高清图!ACE Data Platform × SeeDream 专属方案:小白/商家闭眼冲
人工智能·api
qq_3564483710 小时前
机器学习基本概念与梯度下降
人工智能
水如烟10 小时前
孤能子视角:关系性学习,“喂饭“的小孩认知
人工智能
徐_长卿10 小时前
2025保姆级微信AI群聊机器人教程:教你如何本地打造私人和群聊机器人
人工智能·机器人
XyX——10 小时前
【福利教程】一键解锁 ChatGPT / Gemini / Spotify 教育权益!TG 机器人全自动验证攻略
人工智能·chatgpt·机器人