TensorFlow 中不规则张量(RaggedTensor)

这份文档核心介绍了 TensorFlow 中不规则张量(RaggedTensor) 的定义、适用场景和核心用法,解决了普通张量"必须固定形状"的痛点,专门用于处理非均匀长度/嵌套可变结构的数据。以下是通俗化的核心概述:

一、什么是不规则张量(RaggedTensor)?

普通 TensorFlow 张量要求所有维度的长度固定(比如 [[1,2],[3,4]] 是合法的,[[1,2],[3]] 不合法),而 RaggedTensor 是 TensorFlow 对「嵌套可变长度列表」的原生支持,允许同一维度下的元素长度不一致(比如 [[3,1,4,1], [], [5,9,2], [6], []]),是处理非均匀形状数据的核心工具。

二、适用场景(解决哪些问题?)

专门应对普通张量无法存储的「可变长度/分层结构」数据,典型场景:

  1. 可变长度特征(如电影的演员名单:有的电影3个演员,有的5个);
  2. 批量可变长度序列(如句子:有的5个单词,有的8个;视频剪辑:有的10帧,有的20帧);
  3. 分层数据(如文本文档:文档→节→段落→句子→单词,每层长度都可变);
  4. 结构化数据的字段(如协议缓冲区中长度不固定的字段)。

三、核心功能:兼容大量 TensorFlow 运算

超过100种TF原生运算支持 RaggedTensor,无需额外转换,直接使用,包括:

运算类型 示例(文档原版) 效果说明
数学运算 tf.add(digits, 3)/tf.reduce_mean(digits, axis=1) 标量加法(所有元素+3)、按行求均值(空行返回nan)
数组运算 tf.concat([digits, [[5,3]]], axis=0)/tf.tile(digits, [1,2]) 纵向拼接(新增一行[5,3])、维度复制(每行元素重复2次)
字符串操作 tf.strings.substr(words, 0, 2) 对字符串型RaggedTensor截取前2个字符(如"So"→"So","thanks"→"th")
控制流/映射 tf.map_fn(tf.math.square, digits) 逐元素执行函数(如平方:3→9,1→1)

基础示例(文档原版)

python 复制代码
# 定义数字型、字符串型RaggedTensor
digits = tf.ragged.constant([[3, 1, 4, 1], [], [5, 9, 2], [6], []])
words = tf.ragged.constant([["So", "long"], ["thanks", "for", "all", "the", "fish"]])

# 数学运算:所有元素+3
print(tf.add(digits, 3))  # 输出:[[6,4,7,4], [], [8,12,5], [9], []]
# 数组运算:每行元素重复2次
print(tf.tile(digits, [1, 2]))  # 输出:[[3,1,4,1,3,1,4,1], [], [5,9,2,5,9,2], [6,6], []]

四、常用操作(和普通张量用法对齐,易上手)

1. Python 风格索引/切片

和普通列表/张量的索引逻辑一致,支持精准切片可变长度的行:

python 复制代码
print(digits[0])        # 取第一行 → [3,1,4,1](普通张量)
print(digits[:, :2])    # 取每行前2个元素 → [[3,1], [], [5,9], [6], []]
print(digits[:, -2:])   # 取每行最后2个元素 → [[4,1], [], [9,2], [6], []]
2. 重载算术/比较运算符

支持逐元素运算,可直接和标量、另一个 RaggedTensor 计算:

python 复制代码
# 和标量运算:所有元素+3
print(digits + 3)  # 同tf.add(digits, 3)
# 和另一个同结构RaggedTensor逐元素相加
print(digits + tf.ragged.constant([[1,2,3,4], [], [5,6,7], [8], []]))
# 输出:[[4,3,7,5], [], [10,15,9], [14], []]
3. 自定义逐元素转换(tf.ragged.map_flat_values)

对 RaggedTensor 的所有元素应用自定义函数,无需关心长度:

python 复制代码
# 定义函数:x*2+1
times_two_plus_one = lambda x: x * 2 + 1
# 逐元素应用函数
print(tf.ragged.map_flat_values(times_two_plus_one, digits))
# 输出:[[7,3,9,3], [], [11,19,5], [13], []]

五、格式转换(和Python/NumPy互通)

RaggedTensor 可便捷转换为常用格式,满足数据查看/导出需求:

  1. 转嵌套 Python list(最直观,保留可变长度结构):

    python 复制代码
    digits.to_list()  # 输出:[[3, 1, 4, 1], [], [5, 9, 2], [6], []]
  2. 转 NumPy array(返回object类型数组,每个元素是独立的小数组):

    python 复制代码
    digits.numpy()
    # 输出:array([array([3, 1, 4, 1]), array([]), array([5,9,2]), array([6]), array([])], dtype=object)

核心总结

RaggedTensor 是 TensorFlow 为「非均匀形状数据」量身打造的张量类型,核心优势:

  • 无需手动填充(如用0补全可变长度序列),保留原始数据结构;
  • 兼容绝大多数 TF 原生运算,用法和普通张量对齐,学习成本低;
  • 支持灵活的索引、自定义转换和格式互通,覆盖可变长度数据的存储、处理、导出全流程。
相关推荐
模型时代几秒前
Anthropic明确拒绝在Claude中加入广告功能
人工智能·microsoft
夕小瑶4 分钟前
OpenClaw、Moltbook爆火,算力如何48小时内扩到1900张卡
人工智能
一枕眠秋雨>o<6 分钟前
透视算力:cann-tools如何让AI性能调优从玄学走向科学
人工智能
那个村的李富贵19 分钟前
昇腾CANN跨行业实战:五大新领域AI落地案例深度解析
人工智能·aigc·cann
集简云-软件连接神器23 分钟前
技术实战:集简云语聚AI实现小红书私信接入AI大模型全流程解析
人工智能·小红书·ai客服
松☆23 分钟前
深入理解CANN:面向AI加速的异构计算架构
人工智能·架构
rainbow72424423 分钟前
无基础学AI的入门核心,从基础工具和理论开始学
人工智能
子榆.27 分钟前
CANN 与主流 AI 框架集成:从 PyTorch/TensorFlow 到高效推理的无缝迁移指南
人工智能·pytorch·tensorflow
七月稻草人29 分钟前
CANN生态ops-nn:AIGC的神经网络算子加速内核
人工智能·神经网络·aigc
2501_9248787329 分钟前
数据智能驱动进化:AdAgent 多触点归因与自我学习机制详解
人工智能·逻辑回归·动态规划