RaggedTensor 处理可变长度文本序列的核心实战场景

这份示例展示了 RaggedTensor 处理可变长度文本序列的核心实战场景:为一批长度不同的查询句子,构造「一元词嵌入(单个词的向量)+ 二元词嵌入(相邻词对的向量)」,最终计算每个句子的平均嵌入向量。全程利用 RaggedTensor 适配"句子长度可变"的特点,无需手动填充(如补0),保证计算简洁高效。

前置准备(补全原文缺失的导入)

运行示例需先导入依赖,原文未写但必须补充:

python 复制代码
import tensorflow as tf
import math  # 用于嵌入表初始化的sqrt计算

核心目标

给 3 个长度不同的句子(4词、1词、5词),生成每个句子的平均嵌入向量(融合单个词和相邻词对的语义),最终输出 3 个 4 维向量(对应 3 个句子)。

逐行解析(贴合原文标号①-⑥)

步骤1:定义输入(可变长度句子的 RaggedTensor)
python 复制代码
queries = tf.ragged.constant([['Who', 'is', 'Dan', 'Smith'],
                              ['Pause'],
                              ['Will', 'it', 'rain', 'later', 'today']])
  • 结构:3 行(3 个句子),每行长度分别为 4、1、5(可变长度),RaggedTensor 原生支持这种结构,无需填充;
  • 作用:存储一批可变长度的查询句子,作为后续处理的原始输入。
步骤2:创建嵌入表(Embedding Table)
python 复制代码
num_buckets = 1024  # 哈希桶数量(把字符串单词映射到0~1023的整数)
embedding_size = 4   # 每个词的嵌入向量维度(4维)
embedding_table = tf.Variable(
    tf.random.truncated_normal([num_buckets, embedding_size],
                       stddev=1.0 / math.sqrt(embedding_size)))
  • 形状:[1024, 4] → 1024 个"哈希桶",每个桶对应一个 4 维的嵌入向量;
  • 初始化:截断正态分布(避免数值过大),保证嵌入向量初始值稳定;
  • 作用:作为"单词→向量"的映射表,后续通过哈希桶 ID 查向量。
步骤3:计算一元词嵌入(单个词的向量,①)
python 复制代码
# 步骤3.1:字符串单词→哈希桶ID(把字符串转成整数,才能查嵌入表)
word_buckets = tf.strings.to_hash_bucket_fast(queries, num_buckets)
# 步骤3.2:根据桶ID查嵌入表,得到每个单词的4维向量
word_embeddings = tf.nn.embedding_lookup(embedding_table, word_buckets)  # ①
  • word_buckets 结构:和 queries 同形状的 RaggedTensor(3 行,长度 4/1/5),每个元素是 0~1023 的整数;
  • word_embeddings 结构:[3, (4/1/5), 4] → 3 个句子,每个句子对应长度的单词,每个单词 4 维向量;
  • 核心:RaggedTensor 直接支持 tf.strings.to_hash_bucket_fasttf.nn.embedding_lookup,无需转换为普通张量。
步骤4:给每个句子加首尾标记(②)
python 复制代码
# 创建形状[3,1]的标记张量(每个句子1个'#')
marker = tf.fill([queries.nrows(), 1], '#')
# 横向拼接:开头+原句子+结尾 → 每个句子首尾加'#'
padded = tf.concat([marker, queries, marker], axis=1)  # ②
  • queries.nrows():获取 RaggedTensor 的行数(3 行),避免手动写死行数;
  • marker 形状:[3,1] → 3 行 1 列,每个元素是 #
  • padded 结构:RaggedTensor,每行长度 = 原长度 + 2(首尾各加 1 个 #):
    • 句子1:['#', 'Who', 'is', 'Dan', 'Smith', '#'](长度 6);
    • 句子2:['#', 'Pause', '#'](长度 3);
    • 句子3:['#', 'Will', 'it', 'rain', 'later', 'today', '#'](长度 7);
  • 作用:为后续构造"首尾边界的二元词"(如 #+WhoSmith+#)做准备。
步骤5:构造二元词(相邻词对,③)
python 复制代码
# 取"除最后一个元素外的所有元素"和"除第一个元素外的所有元素",拼接成相邻词对
bigrams = tf.strings.join([padded[:, :-1], padded[:, 1:]], separator='+')  # ③
  • padded[:, :-1]:每行去掉最后一个元素(比如句子1:['#', 'Who', 'is', 'Dan', 'Smith']);
  • padded[:, 1:]:每行去掉第一个元素(比如句子1:['Who', 'is', 'Dan', 'Smith', '#']);
  • tf.strings.join(..., separator='+'):把相邻元素拼接成二元词(比如 # + Who → #+WhoWho+is,直到 Smith+#);
  • bigrams 结构:RaggedTensor,每行长度 = 拼接后长度 -1(句子1:5 个二元词,句子2:2 个,句子3:6 个);
  • 作用:生成每个句子的相邻词对,捕捉"词之间的上下文关系"。
步骤6:计算二元词嵌入(词对的向量,④)
python 复制代码
# 步骤6.1:二元词→哈希桶ID(字符串转整数)
bigram_buckets = tf.strings.to_hash_bucket_fast(bigrams, num_buckets)
# 步骤6.2:根据桶ID查嵌入表,得到每个二元词的4维向量
bigram_embeddings = tf.nn.embedding_lookup(embedding_table, bigram_buckets)  # ④
  • 逻辑和一元词嵌入完全一致;
  • bigram_embeddings 结构:[3, (5/2/6), 4] → 3 个句子,每个句子对应数量的二元词,每个二元词 4 维向量。
步骤7:合并一元+二元嵌入(⑤)
python 复制代码
all_embeddings = tf.concat([word_embeddings, bigram_embeddings], axis=1)  # ⑤
  • axis=1:横向拼接(句子内的元素维度),把每个句子的"一元词向量"和"二元词向量"合并;
  • all_embeddings 结构:
    • 句子1:4 个一元词 + 5 个二元词 = 9 个 4 维向量;
    • 句子2:1 个一元词 + 2 个二元词 = 3 个 4 维向量;
    • 句子3:5 个一元词 + 6 个二元词 = 11 个 4 维向量;
  • 核心:RaggedTensor 的 concat(axis=1) 原生支持可变长度行的横向拼接,普通张量(需固定列数)无法做到这一点。
步骤8:计算每个句子的平均嵌入(⑥)
python 复制代码
avg_embedding = tf.reduce_mean(all_embeddings, axis=1)  # ⑥
print(avg_embedding)
  • axis=1:对每个句子的所有嵌入向量(一元+二元)求均值(按句子内的元素维度求平均);
  • 输出形状:[3, 4] → 普通张量(不再是 RaggedTensor),3 个句子各 1 个 4 维平均向量;
  • 结果:和原文一致,每个句子的平均向量融合了"单个词"和"相邻词对"的语义,可用于后续的文本分类、检索等任务。

核心总结(RaggedTensor 的价值)

这个示例的关键是 全程用 RaggedTensor 处理可变长度句子,相比普通张量的优势:

  1. 无需填充:不用把所有句子补0到最长长度(如补到7词),避免冗余计算;
  2. 运算兼容:所有核心运算(concat/strings.join/embedding_lookup/reduce_mean)都原生支持 RaggedTensor,流程简洁;
  3. 结构保真:始终保留句子的原始长度,计算均值时只针对"有效元素",不会被填充值干扰。

最终输出的平均嵌入向量,既捕捉了单个词的语义,又捕捉了词之间的上下文关系,是文本处理中常见的特征工程手段,而 RaggedTensor 是处理可变长度文本的核心工具。

相关推荐
小鸡吃米…5 天前
基于 TensorFlow 的图像识别
人工智能·python·tensorflow
小鸡吃米…5 天前
TensorFlow - 构建计算图
人工智能·python·tensorflow
A懿轩A5 天前
【2026 最新】TensorFlow 安装配置详细指南 同时讲解安装CPU和GPU版本 小白也能轻松上手!逐步带图超详细展示(Windows 版)
人工智能·windows·python·深度学习·tensorflow
小鸡吃米…6 天前
TensorFlow 实现异或(XOR)运算
人工智能·python·tensorflow·neo4j
小鸡吃米…6 天前
TensorFlow 实现梯度下降优化
人工智能·python·tensorflow·neo4j
甄心爱学习6 天前
【LR逻辑回归】原理以及tensorflow实现
算法·tensorflow·逻辑回归
小鸡吃米…7 天前
TensorFlow 实现多层感知机学习
人工智能·python·tensorflow
小鸡吃米…7 天前
TensorFlow 优化器
人工智能·python·tensorflow
小鸡吃米…8 天前
TensorFlow 模型导出
python·tensorflow·neo4j
Jonathan Star9 天前
Ant Design (antd) Form 组件中必填项的星号(*)从标签左侧移到右侧
人工智能·python·tensorflow