tensorflow 零基础吃透:RaggedTensor 在 Keras 和 tf.Example 中的实战用法

零基础吃透:RaggedTensor在Keras和tf.Example中的实战用法

这份内容会拆解 RaggedTensor 两大核心实战场景------Keras 深度学习模型输入tf.Example 可变长度特征解析,全程用通俗语言+逐行代码解释,帮你理解"为什么用RaggedTensor""怎么用""核心API原理"。

一、场景1:Keras中使用RaggedTensor(训练LSTM判断句子是否是问题)

核心目标

长度不同的句子 (比如"What makes you think she is a witch?"有8个词,"A newt?"只有2个词)训练LSTM模型,判断每个句子是否是疑问句。

✅ 关键优势:用RaggedTensor直接输入可变长度句子,无需补0,避免冗余计算,模型原生支持处理。

完整代码+逐行拆解(带原理)

python 复制代码
import tensorflow as tf

# ===================== 步骤1:定义任务数据 =====================
# 输入:4个长度不同的句子(可变长度文本)
sentences = tf.constant(
    ['What makes you think she is a witch?',  # 8个词
     'She turned me into a newt.',           # 6个词
     'A newt?',                              # 2个词
     'Well, I got better.'])                 # 5个词
# 标签:每个句子是否是疑问句(True=是,False=否)
is_question = tf.constant([True, False, True, False])

# ===================== 步骤2:预处理(字符串→RaggedTensor) =====================
# 超参数:哈希桶数量(把单词转成0~999的整数)
hash_buckets = 1000
# 步骤2.1:按空格切分句子→得到RaggedTensor(每个句子的单词列表,长度可变)
words = tf.strings.split(sentences, ' ')
# 步骤2.2:单词→哈希编号(解决字符串无法输入模型的问题)→ 仍为RaggedTensor
hashed_words = tf.strings.to_hash_bucket_fast(words, hash_buckets)
# 查看预处理结果(验证是RaggedTensor)
print("预处理后的单词编号(RaggedTensor):")
print(hashed_words)
关键解释:
  • tf.strings.split(sentences, ' '):把每个句子按空格切分成单词列表,返回RaggedTensor(比如A newt?[b'A', b'newt?']);
  • tf.strings.to_hash_bucket_fast:把字符串单词转成0~999的整数,保留RaggedTensor结构(长度不变)。
python 复制代码
# ===================== 步骤3:构建Keras模型(核心:支持RaggedTensor) =====================
keras_model = tf.keras.Sequential([
    # 输入层:关键!设置ragged=True,接收RaggedTensor输入
    tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),
    # 嵌入层:把单词编号→16维向量(原生支持RaggedTensor)
    tf.keras.layers.Embedding(hash_buckets, 16),
    # LSTM层:处理可变长度序列(无需补0,自动按实际长度计算)
    tf.keras.layers.LSTM(32, use_bias=False),
    # 全连接层+激活函数:提取特征
    tf.keras.layers.Dense(32),
    tf.keras.layers.Activation(tf.nn.relu),
    # 输出层:预测是否是疑问句(1维输出)
    tf.keras.layers.Dense(1)
])

# ===================== 步骤4:编译+训练+预测 =====================
# 编译模型:二分类任务用binary_crossentropy损失,优化器选rmsprop
keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
# 训练模型:直接传入RaggedTensor(hashed_words)和标签,无需转密集张量
keras_model.fit(hashed_words, is_question, epochs=5)
# 预测:输入RaggedTensor,输出每个句子的预测值(越接近1越可能是疑问句)
print("\n模型预测结果:")
print(keras_model.predict(hashed_words))

核心API解析(为什么能支持RaggedTensor?)

API/参数 作用原理
Input(ragged=True) 声明输入是RaggedTensor,允许输入维度为[None](可变长度),Keras层会适配处理
Embedding 原生支持RaggedTensor输入,按"实际单词数"生成嵌入向量,不生成冗余的补0向量
LSTM 处理RaggedTensor时,自动按每个句子的实际长度计算序列特征,忽略补0(这里根本没补)

运行结果解读

复制代码
Epoch 1/5 → loss:2.5281;Epoch 5/5 → loss:1.6017(损失下降,模型在学习)
预测结果:[[0.0526], [0.0006], [0.0392], [0.0021]]
  • 预测值越接近1,模型认为是疑问句的概率越高;
  • 第一句(疑问句)预测值0.0526,第三句(疑问句)0.0392,比第二/四句高,符合标签规律(模型初步学到了特征)。

关键优势(对比补0的密集张量)

  1. 无需补0:不用把所有句子补到最长长度(8个词),节省内存和计算;
  2. 逻辑简洁:预处理和模型输入全程保留原始句子长度,避免填充值干扰模型学习;
  3. 原生兼容:Keras核心层(Embedding/LSTM/Dense)都支持RaggedTensor,无需额外转换。

二、场景2:tf.Example中解析可变长度特征为RaggedTensor

核心背景

tf.Example 是TensorFlow官方的protobuf数据格式 (一种高效的序列化格式),常用于存储训练数据,尤其适合存储「可变长度特征」(比如有的样本有2个颜色,有的有1个;有的样本长度特征为空)。

✅ 核心需求:把tf.Example中存储的可变长度特征,直接解析为RaggedTensor(不用手动处理空值/补0)。

完整代码+逐行拆解

python 复制代码
import tensorflow as tf
# 导入protobuf文本解析工具(把文本格式的protobuf转成Example对象)
import google.protobuf.text_format as pbtext

# ===================== 步骤1:定义函数,构建tf.Example =====================
def build_tf_example(s):
  # 步骤:把文本格式的protobuf字符串→tf.train.Example对象→序列化(转成字节串)
  return pbtext.Merge(s, tf.train.Example()).SerializeToString()

# 构建4个tf.Example样本(每个样本的colors/lengths特征长度不同)
example_batch = [
  # 样本1:colors=["red","blue"](2个),lengths=[7](1个)
  build_tf_example(r'''
    features {
      feature {key: "colors" value {bytes_list {value: ["red", "blue"]} } }
      feature {key: "lengths" value {int64_list {value: [7]} } } }'''),
  # 样本2:colors=["orange"](1个),lengths=[](空)
  build_tf_example(r'''
    features {
      feature {key: "colors" value {bytes_list {value: ["orange"]} } }
      feature {key: "lengths" value {int64_list {value: []} } } }'''),
  # 样本3:colors=["black","yellow"](2个),lengths=[1,3](2个)
  build_tf_example(r'''
    features {
      feature {key: "colors" value {bytes_list {value: ["black", "yellow"]} } }
      feature {key: "lengths" value {int64_list {value: [1, 3]} } } }'''),
  # 样本4:colors=["green"](1个),lengths=[3,5,2](3个)
  build_tf_example(r'''
    features {
      feature {key: "colors" value {bytes_list {value: ["green"]} } }
      feature {key: "lengths" value {int64_list {value: [3, 5, 2]} } } }''')]
python 复制代码
# ===================== 步骤2:定义特征规范(关键:用RaggedFeature) =====================
feature_specification = {
    # 声明colors特征是字符串型RaggedTensor
    'colors': tf.io.RaggedFeature(tf.string),
    # 声明lengths特征是int64型RaggedTensor
    'lengths': tf.io.RaggedFeature(tf.int64),
}

# ===================== 步骤3:解析tf.Example→RaggedTensor =====================
# 解析序列化的example_batch,按特征规范返回RaggedTensor
feature_tensors = tf.io.parse_example(example_batch, feature_specification)

# 打印解析结果
print("\n解析后的可变长度特征(RaggedTensor):")
for name, value in feature_tensors.items():
  print("{}={}".format(name, value))

核心API解析

API 作用原理
tf.train.Example TensorFlow的protobuf数据格式,支持存储可变长度的列表特征(bytes_list/int64_list)
pbtext.Merge 把文本格式的protobuf字符串,转成tf.train.Example对象
SerializeToString() 把Example对象序列化成字节串(方便存储/传输)
tf.io.RaggedFeature 声明特征是可变长度的,解析后返回RaggedTensor(而非补0的密集张量)
tf.io.parse_example 批量解析序列化的Example字节串,按特征规范返回对应张量(这里是RaggedTensor)

运行结果解读

复制代码
colors=<tf.RaggedTensor [[b'red', b'blue'], [b'orange'], [b'black', b'yellow'], [b'green']]>
lengths=<tf.RaggedTensor [[7], [], [1, 3], [3, 5, 2]]>
  • colors:4个样本的颜色特征,长度分别为2、1、2、1,直接用RaggedTensor存储,无空值;
  • lengths:4个样本的长度特征,长度分别为1、0、2、3,空样本(第二个)直接存为空列表,无需补0。

关键优势(对比普通解析)

如果不用tf.io.RaggedFeature,解析可变长度特征会返回补0的密集张量 (比如lengths会被解析成[[7,0,0], [0,0,0], [1,3,0], [3,5,2]]),而RaggedTensor:

  1. 保留原始长度:空特征就是空列表,无需补0;
  2. 无冗余:只存储有效元素,节省内存;
  3. 后续兼容:解析后的RaggedTensor可直接传入Keras模型/TF运算,无需额外转换。

三、核心总结(RaggedTensor在Keras/tf.Example中的价值)

场景 核心用法 关键优势
Keras模型 Input层设置ragged=True,直接输入RaggedTensor 处理可变长度序列(文本),无需补0,模型学习更高效
tf.Example解析 用tf.io.RaggedFeature声明可变长度特征 解析可变长度特征时保留原始结构,无冗余补0

通用关键结论

  1. RaggedTensor是TensorFlow处理「可变长度数据」的"一站式解决方案":从数据解析(tf.Example)→模型输入(Keras)→模型运算(LSTM/Embedding)全程兼容;
  2. 相比"补0+密集张量",RaggedTensor既节省内存,又避免填充值干扰模型学习,是处理非均匀长度数据的最优选择;
  3. 核心API记忆:
    • Keras:Input(ragged=True)
    • tf.Example:tf.io.RaggedFeature
相关推荐
NAGNIP9 小时前
一文搞懂深度学习中的通用逼近定理!
人工智能·算法·面试
冬奇Lab10 小时前
一天一个开源项目(第36篇):EverMemOS - 跨 LLM 与平台的长时记忆 OS,让 Agent 会记忆更会推理
人工智能·开源·资讯
冬奇Lab10 小时前
OpenClaw 源码深度解析(一):Gateway——为什么需要一个"中枢"
人工智能·开源·源码阅读
AngelPP14 小时前
OpenClaw 架构深度解析:如何把 AI 助手搬到你的个人设备上
人工智能
宅小年14 小时前
Claude Code 换成了Kimi K2.5后,我再也回不去了
人工智能·ai编程·claude
九狼14 小时前
Flutter URL Scheme 跨平台跳转
人工智能·flutter·github
ZFSS15 小时前
Kimi Chat Completion API 申请及使用
前端·人工智能
天翼云开发者社区16 小时前
春节复工福利就位!天翼云息壤2500万Tokens免费送,全品类大模型一键畅玩!
人工智能·算力服务·息壤
知识浅谈16 小时前
教你如何用 Gemini 将课本图片一键转为精美 PPT
人工智能
Ray Liang16 小时前
被低估的量化版模型,小身材也能干大事
人工智能·ai·ai助手·mindx