tensorflow 零基础吃透:RaggedTensor 在 Keras 和 tf.Example 中的实战用法

零基础吃透:RaggedTensor在Keras和tf.Example中的实战用法

这份内容会拆解 RaggedTensor 两大核心实战场景------Keras 深度学习模型输入tf.Example 可变长度特征解析,全程用通俗语言+逐行代码解释,帮你理解"为什么用RaggedTensor""怎么用""核心API原理"。

一、场景1:Keras中使用RaggedTensor(训练LSTM判断句子是否是问题)

核心目标

长度不同的句子 (比如"What makes you think she is a witch?"有8个词,"A newt?"只有2个词)训练LSTM模型,判断每个句子是否是疑问句。

✅ 关键优势:用RaggedTensor直接输入可变长度句子,无需补0,避免冗余计算,模型原生支持处理。

完整代码+逐行拆解(带原理)

python 复制代码
import tensorflow as tf

# ===================== 步骤1:定义任务数据 =====================
# 输入:4个长度不同的句子(可变长度文本)
sentences = tf.constant(
    ['What makes you think she is a witch?',  # 8个词
     'She turned me into a newt.',           # 6个词
     'A newt?',                              # 2个词
     'Well, I got better.'])                 # 5个词
# 标签:每个句子是否是疑问句(True=是,False=否)
is_question = tf.constant([True, False, True, False])

# ===================== 步骤2:预处理(字符串→RaggedTensor) =====================
# 超参数:哈希桶数量(把单词转成0~999的整数)
hash_buckets = 1000
# 步骤2.1:按空格切分句子→得到RaggedTensor(每个句子的单词列表,长度可变)
words = tf.strings.split(sentences, ' ')
# 步骤2.2:单词→哈希编号(解决字符串无法输入模型的问题)→ 仍为RaggedTensor
hashed_words = tf.strings.to_hash_bucket_fast(words, hash_buckets)
# 查看预处理结果(验证是RaggedTensor)
print("预处理后的单词编号(RaggedTensor):")
print(hashed_words)
关键解释:
  • tf.strings.split(sentences, ' '):把每个句子按空格切分成单词列表,返回RaggedTensor(比如A newt?[b'A', b'newt?']);
  • tf.strings.to_hash_bucket_fast:把字符串单词转成0~999的整数,保留RaggedTensor结构(长度不变)。
python 复制代码
# ===================== 步骤3:构建Keras模型(核心:支持RaggedTensor) =====================
keras_model = tf.keras.Sequential([
    # 输入层:关键!设置ragged=True,接收RaggedTensor输入
    tf.keras.layers.Input(shape=[None], dtype=tf.int64, ragged=True),
    # 嵌入层:把单词编号→16维向量(原生支持RaggedTensor)
    tf.keras.layers.Embedding(hash_buckets, 16),
    # LSTM层:处理可变长度序列(无需补0,自动按实际长度计算)
    tf.keras.layers.LSTM(32, use_bias=False),
    # 全连接层+激活函数:提取特征
    tf.keras.layers.Dense(32),
    tf.keras.layers.Activation(tf.nn.relu),
    # 输出层:预测是否是疑问句(1维输出)
    tf.keras.layers.Dense(1)
])

# ===================== 步骤4:编译+训练+预测 =====================
# 编译模型:二分类任务用binary_crossentropy损失,优化器选rmsprop
keras_model.compile(loss='binary_crossentropy', optimizer='rmsprop')
# 训练模型:直接传入RaggedTensor(hashed_words)和标签,无需转密集张量
keras_model.fit(hashed_words, is_question, epochs=5)
# 预测:输入RaggedTensor,输出每个句子的预测值(越接近1越可能是疑问句)
print("\n模型预测结果:")
print(keras_model.predict(hashed_words))

核心API解析(为什么能支持RaggedTensor?)

API/参数 作用原理
Input(ragged=True) 声明输入是RaggedTensor,允许输入维度为[None](可变长度),Keras层会适配处理
Embedding 原生支持RaggedTensor输入,按"实际单词数"生成嵌入向量,不生成冗余的补0向量
LSTM 处理RaggedTensor时,自动按每个句子的实际长度计算序列特征,忽略补0(这里根本没补)

运行结果解读

复制代码
Epoch 1/5 → loss:2.5281;Epoch 5/5 → loss:1.6017(损失下降,模型在学习)
预测结果:[[0.0526], [0.0006], [0.0392], [0.0021]]
  • 预测值越接近1,模型认为是疑问句的概率越高;
  • 第一句(疑问句)预测值0.0526,第三句(疑问句)0.0392,比第二/四句高,符合标签规律(模型初步学到了特征)。

关键优势(对比补0的密集张量)

  1. 无需补0:不用把所有句子补到最长长度(8个词),节省内存和计算;
  2. 逻辑简洁:预处理和模型输入全程保留原始句子长度,避免填充值干扰模型学习;
  3. 原生兼容:Keras核心层(Embedding/LSTM/Dense)都支持RaggedTensor,无需额外转换。

二、场景2:tf.Example中解析可变长度特征为RaggedTensor

核心背景

tf.Example 是TensorFlow官方的protobuf数据格式 (一种高效的序列化格式),常用于存储训练数据,尤其适合存储「可变长度特征」(比如有的样本有2个颜色,有的有1个;有的样本长度特征为空)。

✅ 核心需求:把tf.Example中存储的可变长度特征,直接解析为RaggedTensor(不用手动处理空值/补0)。

完整代码+逐行拆解

python 复制代码
import tensorflow as tf
# 导入protobuf文本解析工具(把文本格式的protobuf转成Example对象)
import google.protobuf.text_format as pbtext

# ===================== 步骤1:定义函数,构建tf.Example =====================
def build_tf_example(s):
  # 步骤:把文本格式的protobuf字符串→tf.train.Example对象→序列化(转成字节串)
  return pbtext.Merge(s, tf.train.Example()).SerializeToString()

# 构建4个tf.Example样本(每个样本的colors/lengths特征长度不同)
example_batch = [
  # 样本1:colors=["red","blue"](2个),lengths=[7](1个)
  build_tf_example(r'''
    features {
      feature {key: "colors" value {bytes_list {value: ["red", "blue"]} } }
      feature {key: "lengths" value {int64_list {value: [7]} } } }'''),
  # 样本2:colors=["orange"](1个),lengths=[](空)
  build_tf_example(r'''
    features {
      feature {key: "colors" value {bytes_list {value: ["orange"]} } }
      feature {key: "lengths" value {int64_list {value: []} } } }'''),
  # 样本3:colors=["black","yellow"](2个),lengths=[1,3](2个)
  build_tf_example(r'''
    features {
      feature {key: "colors" value {bytes_list {value: ["black", "yellow"]} } }
      feature {key: "lengths" value {int64_list {value: [1, 3]} } } }'''),
  # 样本4:colors=["green"](1个),lengths=[3,5,2](3个)
  build_tf_example(r'''
    features {
      feature {key: "colors" value {bytes_list {value: ["green"]} } }
      feature {key: "lengths" value {int64_list {value: [3, 5, 2]} } } }''')]
python 复制代码
# ===================== 步骤2:定义特征规范(关键:用RaggedFeature) =====================
feature_specification = {
    # 声明colors特征是字符串型RaggedTensor
    'colors': tf.io.RaggedFeature(tf.string),
    # 声明lengths特征是int64型RaggedTensor
    'lengths': tf.io.RaggedFeature(tf.int64),
}

# ===================== 步骤3:解析tf.Example→RaggedTensor =====================
# 解析序列化的example_batch,按特征规范返回RaggedTensor
feature_tensors = tf.io.parse_example(example_batch, feature_specification)

# 打印解析结果
print("\n解析后的可变长度特征(RaggedTensor):")
for name, value in feature_tensors.items():
  print("{}={}".format(name, value))

核心API解析

API 作用原理
tf.train.Example TensorFlow的protobuf数据格式,支持存储可变长度的列表特征(bytes_list/int64_list)
pbtext.Merge 把文本格式的protobuf字符串,转成tf.train.Example对象
SerializeToString() 把Example对象序列化成字节串(方便存储/传输)
tf.io.RaggedFeature 声明特征是可变长度的,解析后返回RaggedTensor(而非补0的密集张量)
tf.io.parse_example 批量解析序列化的Example字节串,按特征规范返回对应张量(这里是RaggedTensor)

运行结果解读

复制代码
colors=<tf.RaggedTensor [[b'red', b'blue'], [b'orange'], [b'black', b'yellow'], [b'green']]>
lengths=<tf.RaggedTensor [[7], [], [1, 3], [3, 5, 2]]>
  • colors:4个样本的颜色特征,长度分别为2、1、2、1,直接用RaggedTensor存储,无空值;
  • lengths:4个样本的长度特征,长度分别为1、0、2、3,空样本(第二个)直接存为空列表,无需补0。

关键优势(对比普通解析)

如果不用tf.io.RaggedFeature,解析可变长度特征会返回补0的密集张量 (比如lengths会被解析成[[7,0,0], [0,0,0], [1,3,0], [3,5,2]]),而RaggedTensor:

  1. 保留原始长度:空特征就是空列表,无需补0;
  2. 无冗余:只存储有效元素,节省内存;
  3. 后续兼容:解析后的RaggedTensor可直接传入Keras模型/TF运算,无需额外转换。

三、核心总结(RaggedTensor在Keras/tf.Example中的价值)

场景 核心用法 关键优势
Keras模型 Input层设置ragged=True,直接输入RaggedTensor 处理可变长度序列(文本),无需补0,模型学习更高效
tf.Example解析 用tf.io.RaggedFeature声明可变长度特征 解析可变长度特征时保留原始结构,无冗余补0

通用关键结论

  1. RaggedTensor是TensorFlow处理「可变长度数据」的"一站式解决方案":从数据解析(tf.Example)→模型输入(Keras)→模型运算(LSTM/Embedding)全程兼容;
  2. 相比"补0+密集张量",RaggedTensor既节省内存,又避免填充值干扰模型学习,是处理非均匀长度数据的最优选择;
  3. 核心API记忆:
    • Keras:Input(ragged=True)
    • tf.Example:tf.io.RaggedFeature
相关推荐
杭州泽沃电子科技有限公司2 小时前
汽轮机在线监测:老牌火电的“智慧心脏”如何打赢“双碳”攻坚战?
运维·人工智能·智能监测·发电
陈奕昆2 小时前
n8n实战营Day3课时3:库存物流联动·全流程测试与异常调试
人工智能·python·n8n
珂朵莉MM2 小时前
第七届全球校园人工智能算法精英大赛-算法巅峰赛产业命题赛第3赛季优化题--碳中和
人工智能·算法
jinxinyuuuus2 小时前
AI 硬件助手:LLM的比较推理与自动化决策理由生成
人工智能·自动化
智界前沿2 小时前
AI数字人公司推荐,集之互动如何在医疗、政务、汽车等关键领域打造“标杆案例”
人工智能·汽车·政务
水如烟2 小时前
孤能子视角:“人本关系线“耦合––焦耳、功率、效率、个人学习
人工智能
FIT2CLOUD飞致云2 小时前
重要发布丨新增支持工作流知识库和数据源工具,MaxKB开源企业级智能体平台v2.4.0版本发布
人工智能·ai·开源·1panel·maxkb
The Straggling Crow2 小时前
RAGFlow
人工智能
沃达德软件2 小时前
智慧警务实战模型与算法
大数据·人工智能·算法·数据挖掘·数据分析