目录
-
- [1. 概述](#1. 概述)
-
- [1.1 为什么需要 KeyedJaggedTensor?](#1.1 为什么需要 KeyedJaggedTensor?)
- [1.2 数据形状理解](#1.2 数据形状理解)
- [2. 核心构造方法](#2. 核心构造方法)
-
- [2.1 from_offsets_sync 方法](#2.1 from_offsets_sync 方法)
- [2.2 from_jt_dict 方法](#2.2 from_jt_dict 方法)
- [2.3 直接构造方法](#2.3 直接构造方法)
- [3. 核心 API 和使用示例](#3. 核心 API 和使用示例)
-
- [3.1 基本属性访问](#3.1 基本属性访问)
- [3.2 特征访问](#3.2 特征访问)
- [3.3 与 EmbeddingBagCollection 集成](#3.3 与 EmbeddingBagCollection 集成)
- [4. DataLoader 中的集成](#4. DataLoader 中的集成)
- [5. 注意事项和最佳实践](#5. 注意事项和最佳实践)
-
- [5.1 内存效率考虑](#5.1 内存效率考虑)
- [5.2 数据验证](#5.2 数据验证)
- [5.3 性能优化](#5.3 性能优化)
- [5.4 分布式训练注意事项](#5.4 分布式训练注意事项)
- [6. 高级用法](#6. 高级用法)
-
- [6.1 特征转换和增强](#6.1 特征转换和增强)
- [6.2 与 TorchRec 其他组件的集成](#6.2 与 TorchRec 其他组件的集成)
- [7. 调试和可视化](#7. 调试和可视化)
-
- [7.1 打印和调试](#7.1 打印和调试)
- [7.2 可视化工具](#7.2 可视化工具)
- [8. 常见问题和解决方案](#8. 常见问题和解决方案)
-
- [8.1 常见错误](#8.1 常见错误)
- [8.2 性能问题](#8.2 性能问题)
- [9. 完整示例:端到端推荐系统](#9. 完整示例:端到端推荐系统)
- [10. 总结](#10. 总结)
1. 概述
KeyedJaggedTensor (KJT) 是 TorchRec v1.0 中用于表示稀疏特征的核心数据结构之一,专门设计用于高效管理多个变长特征序列。 在推荐系统中,它能够同时处理用户历史点击(click_ids)、搜索词(search_ids)等多种稀疏特征,是构建大规模推荐系统的基础组件。
1.1 为什么需要 KeyedJaggedTensor?
在推荐系统中,我们经常需要处理:
- 多个稀疏特征(用户ID、商品ID、类别ID等)
- 每个特征在不同样本中的长度不同(用户历史点击数不同)
- 高效的内存布局和计算性能
KeyedJaggedTensor 正是为了解决这些问题而设计的,它可以将多个 JaggedTensor 高效地组织在一起。
1.2 数据形状理解
KeyedJaggedTensor 的形状表示为 (F, B, L[f][i]),其中:
- F: 特征(键)的数量
- B: 批量大小
- L[f][i]: 特征 f 在批量索引 i 处的长度
2. 核心构造方法
KeyedJaggedTensor 提供了多种构造方法,适应不同的使用场景:
2.1 from_offsets_sync 方法
这是最常用的构造方法,通过指定偏移量来构建 KJT。
python
import torch
from torchrec import KeyedJaggedTensor
# 创建一个简单的 KeyedJaggedTensor
kjt = KeyedJaggedTensor.from_offsets_sync(
keys=["feature1", "feature2"],
values=torch.tensor([0, 1, 2, 3, 4, 5, 6]),
offsets=torch.tensor([0, 2, 4, 7]) # 偏移量定义
)
这里的 offsets 参数定义了每个特征在 values 张量中的起始位置。
2.2 from_jt_dict 方法
当已有多个 JaggedTensor 时,可以使用此方法将它们组合成一个 KeyedJaggedTensor。
python
from torchrec import JaggedTensor
# 创建单个特征的 JaggedTensor
feature1_jt = JaggedTensor(
values=torch.tensor([1, 2, 3]),
lengths=torch.tensor([2, 1]) # 两个样本,第一个样本2个值,第二个样本1个值
)
feature2_jt = JaggedTensor(
values=torch.tensor([4, 5]),
lengths=torch.tensor([1, 1])
)
# 从字典构造 KeyedJaggedTensor
kjt = KeyedJaggedTensor.from_jt_dict({
"feature1": feature1_jt,
"feature2": feature2_jt
})
这种方法在 DataLoader 的 collate 函数中特别有用,可以将按特征组织的数据聚合起来。
2.3 直接构造方法
也可以直接通过 values 和 lengths 参数构造:
python
kjt = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.tensor([1, 2, 3, 4, 5]),
lengths=torch.tensor([[2, 1], [1, 1]]) # 形状: [num_features, batch_size]
)
3. 核心 API 和使用示例
3.1 基本属性访问
python
# 获取所有键
print(kjt.keys()) # ['feature1', 'feature2']
# 获取所有值
print(kjt.values()) # tensor([1, 2, 3, 4, 5])
# 获取长度信息
print(kjt.lengths()) # tensor([[2, 1], [1, 1]])
# 获取偏移量
print(kjt.offsets()) # 如果未计算会自动计算
KeyedJaggedTensor 提供了 offsets() 方法,如果偏移量尚未计算,它会自动计算并返回。
3.2 特征访问
python
# 通过键名访问特定特征
feature1_data = kjt["feature1"]
print(feature1_data) # 返回对应的 JaggedTensor
# 检查特征是否存在
if "feature1" in kjt:
print("feature1 exists")
3.3 与 EmbeddingBagCollection 集成
KeyedJaggedTensor 的主要用途是与 EmbeddingBagCollection 配合使用,进行嵌入查找:
python
from torchrec import EmbeddingBagCollection, EmbeddingConfig
# 定义嵌入配置
ebc = EmbeddingBagCollection(
tables=[
EmbeddingConfig(
name="feature1_table",
embedding_dim=16,
num_embeddings=100,
feature_names=["feature1"]
),
EmbeddingConfig(
name="feature2_table",
embedding_dim=16,
num_embeddings=100,
feature_names=["feature2"]
)
]
)
# 前向传播
embeddings = ebc(kjt) # KJT 作为输入
下面,我们将使用 TorchRec 数据类型 KeyedJaggedTensor 和 EmbeddingBagCollection 进行简单的正向传播。
4. DataLoader 中的集成
在实际训练中,KeyedJaggedTensor 通常在 DataLoader 的 collate 函数中构造:
python
import torch
from torch.utils.data import DataLoader, Dataset
from torchrec import KeyedJaggedTensor, JaggedTensor
class RecDataset(Dataset):
def __init__(self, data):
self.data = data # 格式: [{'feature1': [1,2], 'feature2': [3]}, ...]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
def collate_fn(batch):
"""将 batch 转换为 KeyedJaggedTensor"""
# 收集所有特征名
all_keys = set()
for sample in batch:
all_keys.update(sample.keys())
# 为每个特征创建 JaggedTensor
jt_dict = {}
for key in all_keys:
values = []
lengths = []
for sample in batch:
if key in sample:
values.extend(sample[key])
lengths.append(len(sample[key]))
else:
lengths.append(0) # 处理缺失特征
jt_dict[key] = JaggedTensor(
values=torch.tensor(values, dtype=torch.int64),
lengths=torch.tensor(lengths, dtype=torch.int32)
)
# 从字典构造 KeyedJaggedTensor
return KeyedJaggedTensor.from_jt_dict(jt_dict)
# 使用示例
data = [
{'user_id': [1], 'item_history': [101, 102], 'category': [5]},
{'user_id': [2], 'item_history': [201], 'category': [6]},
{'user_id': [3], 'item_history': [301, 302, 303], 'category': [7]}
]
dataset = RecDataset(data)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)
for batch_kjt in dataloader:
print("Batch KJT:", batch_kjt)
# 这里可以传入模型进行训练
5. 注意事项和最佳实践
5.1 内存效率考虑
- 避免频繁构造: KeyedJaggedTensor 的构造涉及张量拼接,应在 DataLoader 中完成,而不是在模型内部
- 预分配内存: 对于固定结构的数据,考虑预分配内存以减少运行时开销
- 异步数据加载: 使用 DataLoader 的 num_workers 参数启用多进程数据加载
5.2 数据验证
- 检查特征一致性: 确保 batch 内所有样本的特征键一致,或正确处理缺失特征
- 验证长度合法性: 确保 lengths 和 values 的维度匹配
- 数据类型检查: values 通常为 torch.int64,lengths 为 torch.int32
python
def validate_kjt(kjt):
# 检查特征数量
num_features = len(kjt.keys())
# 检查长度维度
lengths = kjt.lengths()
assert lengths.dim() == 2, f"Lengths should be 2D, got {lengths.dim()}D"
assert lengths.shape[0] == num_features, "Lengths shape mismatch"
# 检查 values 长度
total_length = lengths.sum().item()
assert total_length == len(kjt.values()), "Values length mismatch"
5.3 性能优化
- 批处理大小: 选择合适的 batch_size 以平衡内存使用和计算效率
- 特征分组: 将相关特征分组处理,减少 KJT 的复杂度
- GPU 加速: 在 GPU 上构造 KJT 时,确保所有输入张量都在同一设备上
5.4 分布式训练注意事项
在分布式环境中使用 KeyedJaggedTensor 时:
- 分片策略 : 了解不同分片策略(如
RWShardingStrategy,DPShardingStrategy)对 KJT 的影响 - 通信开销: 大规模 KJT 在设备间传输时会产生通信开销,需要优化
- 负载均衡: 确保不同设备上的特征分布均衡,避免热点问题
6. 高级用法
6.1 特征转换和增强
python
def add_feature_interactions(kjt):
"""添加特征交互"""
# 从 KJT 获取原始特征
user_ids = kjt["user_id"].to_padded_dense(0)
item_ids = kjt["item_id"].to_padded_dense(0)
# 创建交互特征
interaction_ids = user_ids * 1000 + item_ids
# 转换回 JaggedTensor 格式
interaction_jt = JaggedTensor(
values=interaction_ids.flatten(),
lengths=torch.ones(len(interaction_ids), dtype=torch.int32)
)
# 添加到原始 KJT
return KeyedJaggedTensor.from_jt_dict({
**{key: kjt[key] for key in kjt.keys()},
"user_item_interaction": interaction_jt
})
6.2 与 TorchRec 其他组件的集成
KeyedJaggedTensor 可以与 TorchRec 的其他组件无缝集成:
python
from torchrec import DenseArch, InteractionArch, OverArch
# 完整的 DLRM 模型示例
class DLRMModel(torch.nn.Module):
def __init__(self, embedding_bag_collection, dense_arch, interaction_arch, over_arch):
super().__init__()
self.ebc = embedding_bag_collection
self.dense_arch = dense_arch
self.interaction_arch = interaction_arch
self.over_arch = over_arch
def forward(self, dense_features, sparse_features):
"""
dense_features: 稠密特征张量
sparse_features: KeyedJaggedTensor 格式的稀疏特征
"""
# 嵌入查找
sparse_embeddings = self.ebc(sparse_features)
# 稠密特征处理
dense_embeddings = self.dense_arch(dense_features)
# 特征交互
interaction_output = self.interaction_arch(dense_embeddings, sparse_embeddings)
# 最终预测
logits = self.over_arch(interaction_output)
return logits
7. 调试和可视化
7.1 打印和调试
python
def print_kjt_info(kjt):
print(f"KeyedJaggedTensor Info:")
print(f"Keys: {kjt.keys()}")
print(f"Values shape: {kjt.values().shape}")
print(f"Lengths shape: {kjt.lengths().shape}")
print(f"Batch size: {kjt.lengths().shape[1]}")
print("\nFeature details:")
for i, key in enumerate(kjt.keys()):
feature_lengths = kjt.lengths()[i]
total_values = feature_lengths.sum().item()
print(f" {key}:")
print(f" Lengths: {feature_lengths.tolist()}")
print(f" Total values: {total_values}")
# 使用示例
print_kjt_info(kjt)
7.2 可视化工具
可以将 KeyedJaggedTensor 转换为 Pandas DataFrame 以便可视化:
python
import pandas as pd
def kjt_to_dataframe(kjt):
"""将 KeyedJaggedTensor 转换为 DataFrame 便于可视化"""
batch_size = kjt.lengths().shape[1]
data = []
for batch_idx in range(batch_size):
row = {"batch_idx": batch_idx}
for key in kjt.keys():
feature_values = kjt[key].to_padded_dense(batch_idx)
row[key] = feature_values.tolist()
data.append(row)
return pd.DataFrame(data)
# 使用示例
df = kjt_to_dataframe(kjt)
print(df)
8. 常见问题和解决方案
8.1 常见错误
错误1: 特征键不一致
python
# 错误示例
batch = [
{'feature1': [1, 2]},
{'feature2': [3, 4]} # 缺少 feature1
]
# 解决方案: 在 collate_fn 中处理缺失特征
def robust_collate_fn(batch):
all_keys = set()
for sample in batch:
all_keys.update(sample.keys())
jt_dict = {}
for key in all_keys:
values = []
lengths = []
for sample in batch:
if key in sample:
values.extend(sample[key])
lengths.append(len(sample[key]))
else:
lengths.append(0) # 为缺失特征填充0长度
# 处理没有值的情况
if values:
jt_dict[key] = JaggedTensor(
values=torch.tensor(values),
lengths=torch.tensor(lengths)
)
else:
jt_dict[key] = JaggedTensor(
values=torch.tensor([], dtype=torch.int64),
lengths=torch.zeros(len(batch), dtype=torch.int32)
)
return KeyedJaggedTensor.from_jt_dict(jt_dict)
错误2: 数据类型不匹配
python
# 错误示例
kjt = KeyedJaggedTensor(
keys=["f1"],
values=torch.tensor([1.0, 2.0]), # float 类型,应该用 int
lengths=torch.tensor()
)
# 解决方案: 确保正确的数据类型
kjt = KeyedJaggedTensor(
keys=["f1"],
values=torch.tensor([1, 2], dtype=torch.int64), # 明确指定 int64
lengths=torch.tensor(, dtype=torch.int32) # 明确指定 int32
)
8.2 性能问题
问题: 大 batch 时内存不足
python
# 优化方案1: 减小 batch size
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)
# 优化方案2: 使用梯度累积
accumulation_steps = 4
for i, batch_kjt in enumerate(dataloader):
logits = model(batch_kjt)
loss = criterion(logits, labels)
loss = loss / accumulation_steps # 归一化损失
loss.backward()
if (i + 1) % accumulation_steps == 0:
optimizer.step()
optimizer.zero_grad()
9. 完整示例:端到端推荐系统
python
import torch
import torch.nn as nn
from torchrec import KeyedJaggedTensor, JaggedTensor, EmbeddingBagCollection, EmbeddingConfig, DenseArch, InteractionArch, OverArch
# 1. 数据准备
class MovieLensDataset(torch.utils.data.Dataset):
def __init__(self):
# 模拟数据
self.data = [
{'user_id': [1], 'movie_id': [101, 102], 'genre': [5, 6]},
{'user_id': [2], 'movie_id': [201], 'genre': [7]},
{'user_id': [3], 'movie_id': [301, 302, 303], 'genre': [8, 9, 10]},
{'user_id': [4], 'movie_id': [401], 'genre': [11]}
]
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx]
# 2. 自定义 collate 函数
def collate_fn(batch):
all_keys = set()
for sample in batch:
all_keys.update(sample.keys())
jt_dict = {}
for key in sorted(all_keys):
values = []
lengths = []
for sample in batch:
if key in sample:
values.extend(sample[key])
lengths.append(len(sample[key]))
else:
lengths.append(0)
jt_dict[key] = JaggedTensor(
values=torch.tensor(values, dtype=torch.int64),
lengths=torch.tensor(lengths, dtype=torch.int32)
)
return KeyedJaggedTensor.from_jt_dict(jt_dict)
# 3. 模型定义
class RecModel(nn.Module):
def __init__(self):
super().__init__()
# 嵌入表配置
self.ebc = EmbeddingBagCollection(
tables=[
EmbeddingConfig(
name="user_table",
embedding_dim=16,
num_embeddings=1000,
feature_names=["user_id"]
),
EmbeddingConfig(
name="movie_table",
embedding_dim=32,
num_embeddings=10000,
feature_names=["movie_id"]
),
EmbeddingConfig(
name="genre_table",
embedding_dim=8,
num_embeddings=50,
feature_names=["genre"]
)
]
)
# 稠密网络 (为简化,这里假设没有稠密特征)
self.dense_arch = nn.Sequential(
nn.Linear(1, 16), # 假设有一个虚拟的稠密特征
nn.ReLU()
)
# 交互层
self.interaction_arch = InteractionArch(
dense_feature_dim=16,
sparse_feature_names=["user_id", "movie_id", "genre"]
)
# 输出层
self.over_arch = OverArch(
in_features=16 + 32 + 8 + 16, # 假设的维度
layer_sizes=[64, 32, 1]
)
def forward(self, sparse_features, dense_features=None):
if dense_features is None:
dense_features = torch.ones(sparse_features.lengths().shape[1], 1)
# 嵌入查找
sparse_embeddings = self.ebc(sparse_features)
# 稠密特征处理
dense_embeddings = self.dense_arch(dense_features)
# 特征交互
interaction_output = self.interaction_arch(dense_embeddings, sparse_embeddings)
# 最终预测
logits = self.over_arch(interaction_output)
return logits
# 4. 训练循环
def train():
# 创建数据集和数据加载器
dataset = MovieLensDataset()
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=2,
collate_fn=collate_fn,
shuffle=True
)
# 创建模型
model = RecModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.BCEWithLogitsLoss()
# 训练
for epoch in range(3):
for batch_idx, sparse_features in enumerate(dataloader):
# 模拟标签
batch_size = sparse_features.lengths().shape[1]
labels = torch.randint(0, 2, (batch_size, 1), dtype=torch.float)
# 前向传播
logits = model(sparse_features)
# 计算损失
loss = criterion(logits, labels)
# 反向传播
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
# 打印 KJT 信息用于调试
print("\nKeyedJaggedTensor Info:")
print(f"Keys: {sparse_features.keys()}")
print(f"Batch size: {sparse_features.lengths().shape[1]}")
print(f"Values shape: {sparse_features.values().shape}")
if __name__ == "__main__":
train()
10. 总结
KeyedJaggedTensor 是 TorchRec v1.0 中处理稀疏特征的核心数据结构,它通过高效的内存布局和灵活的 API 设计,为大规模推荐系统提供了强大的支持。
核心要点回顾:
- 高效表示: 能够高效地表示多个稀疏特征,可以将其视为多个 JaggedTensor 的集合
- 多种构造方法 : 提供
from_offsets_sync、from_jt_dict等多种构造方法,适应不同场景 - 与 EmbeddingBagCollection 无缝集成: 可直接作为 EmbeddingBagCollection 的输入进行嵌入查找
- DataLoader 友好: 通过自定义 collate 函数,可以轻松地从原始数据构建 KJT
- 分布式支持: 在分布式训练环境中表现良好,支持各种分片策略
最佳实践建议:
- 在 DataLoader 中构建 KJT: 避免在模型内部频繁构造 KJT
- 验证数据一致性: 确保 batch 内特征键的一致性和数据类型的正确性
- 监控内存使用: 大规模 KJT 可能消耗大量内存,需要合理设置 batch size
- 利用 TorchRec 工具: 使用内置的验证和调试工具确保 KJT 的正确性
- 渐进式开发: 从小规模数据开始,逐步扩展到完整系统
通过掌握 KeyedJaggedTensor,你将能够构建高效、可扩展的推荐系统,充分利用 TorchRec v1.0 的强大功能。在实际应用中,建议结合官方文档和社区资源,持续优化和改进你的实现。