【推荐系统】深度学习训练框架(十七):TorchRec之KeyedJaggedTensor

目录

    • [1. 概述](#1. 概述)
      • [1.1 为什么需要 KeyedJaggedTensor?](#1.1 为什么需要 KeyedJaggedTensor?)
      • [1.2 数据形状理解](#1.2 数据形状理解)
    • [2. 核心构造方法](#2. 核心构造方法)
      • [2.1 from_offsets_sync 方法](#2.1 from_offsets_sync 方法)
      • [2.2 from_jt_dict 方法](#2.2 from_jt_dict 方法)
      • [2.3 直接构造方法](#2.3 直接构造方法)
    • [3. 核心 API 和使用示例](#3. 核心 API 和使用示例)
      • [3.1 基本属性访问](#3.1 基本属性访问)
      • [3.2 特征访问](#3.2 特征访问)
      • [3.3 与 EmbeddingBagCollection 集成](#3.3 与 EmbeddingBagCollection 集成)
    • [4. DataLoader 中的集成](#4. DataLoader 中的集成)
    • [5. 注意事项和最佳实践](#5. 注意事项和最佳实践)
      • [5.1 内存效率考虑](#5.1 内存效率考虑)
      • [5.2 数据验证](#5.2 数据验证)
      • [5.3 性能优化](#5.3 性能优化)
      • [5.4 分布式训练注意事项](#5.4 分布式训练注意事项)
    • [6. 高级用法](#6. 高级用法)
      • [6.1 特征转换和增强](#6.1 特征转换和增强)
      • [6.2 与 TorchRec 其他组件的集成](#6.2 与 TorchRec 其他组件的集成)
    • [7. 调试和可视化](#7. 调试和可视化)
      • [7.1 打印和调试](#7.1 打印和调试)
      • [7.2 可视化工具](#7.2 可视化工具)
    • [8. 常见问题和解决方案](#8. 常见问题和解决方案)
      • [8.1 常见错误](#8.1 常见错误)
      • [8.2 性能问题](#8.2 性能问题)
    • [9. 完整示例:端到端推荐系统](#9. 完整示例:端到端推荐系统)
    • [10. 总结](#10. 总结)

1. 概述

KeyedJaggedTensor (KJT) 是 TorchRec v1.0 中用于表示稀疏特征的核心数据结构之一,专门设计用于高效管理多个变长特征序列。 在推荐系统中,它能够同时处理用户历史点击(click_ids)、搜索词(search_ids)等多种稀疏特征,是构建大规模推荐系统的基础组件。

1.1 为什么需要 KeyedJaggedTensor?

在推荐系统中,我们经常需要处理:

  • 多个稀疏特征(用户ID、商品ID、类别ID等)
  • 每个特征在不同样本中的长度不同(用户历史点击数不同)
  • 高效的内存布局和计算性能

KeyedJaggedTensor 正是为了解决这些问题而设计的,它可以将多个 JaggedTensor 高效地组织在一起。

1.2 数据形状理解

KeyedJaggedTensor 的形状表示为 (F, B, L[f][i]),其中:

  • F: 特征(键)的数量
  • B: 批量大小
  • L[f][i]: 特征 f 在批量索引 i 处的长度

2. 核心构造方法

KeyedJaggedTensor 提供了多种构造方法,适应不同的使用场景:

2.1 from_offsets_sync 方法

这是最常用的构造方法,通过指定偏移量来构建 KJT。

python 复制代码
import torch
from torchrec import KeyedJaggedTensor

# 创建一个简单的 KeyedJaggedTensor
kjt = KeyedJaggedTensor.from_offsets_sync(
    keys=["feature1", "feature2"],
    values=torch.tensor([0, 1, 2, 3, 4, 5, 6]),
    offsets=torch.tensor([0, 2, 4, 7])  # 偏移量定义
)

这里的 offsets 参数定义了每个特征在 values 张量中的起始位置。

2.2 from_jt_dict 方法

当已有多个 JaggedTensor 时,可以使用此方法将它们组合成一个 KeyedJaggedTensor。

python 复制代码
from torchrec import JaggedTensor

# 创建单个特征的 JaggedTensor
feature1_jt = JaggedTensor(
    values=torch.tensor([1, 2, 3]),
    lengths=torch.tensor([2, 1])  # 两个样本,第一个样本2个值,第二个样本1个值
)

feature2_jt = JaggedTensor(
    values=torch.tensor([4, 5]),
    lengths=torch.tensor([1, 1])
)

# 从字典构造 KeyedJaggedTensor
kjt = KeyedJaggedTensor.from_jt_dict({
    "feature1": feature1_jt,
    "feature2": feature2_jt
})

这种方法在 DataLoader 的 collate 函数中特别有用,可以将按特征组织的数据聚合起来。

2.3 直接构造方法

也可以直接通过 values 和 lengths 参数构造:

python 复制代码
kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5]),
    lengths=torch.tensor([[2, 1], [1, 1]])  # 形状: [num_features, batch_size]
)

3. 核心 API 和使用示例

3.1 基本属性访问

python 复制代码
# 获取所有键
print(kjt.keys())  # ['feature1', 'feature2']

# 获取所有值
print(kjt.values())  # tensor([1, 2, 3, 4, 5])

# 获取长度信息
print(kjt.lengths())  # tensor([[2, 1], [1, 1]])

# 获取偏移量
print(kjt.offsets())  # 如果未计算会自动计算

KeyedJaggedTensor 提供了 offsets() 方法,如果偏移量尚未计算,它会自动计算并返回。

3.2 特征访问

python 复制代码
# 通过键名访问特定特征
feature1_data = kjt["feature1"]
print(feature1_data)  # 返回对应的 JaggedTensor

# 检查特征是否存在
if "feature1" in kjt:
    print("feature1 exists")

3.3 与 EmbeddingBagCollection 集成

KeyedJaggedTensor 的主要用途是与 EmbeddingBagCollection 配合使用,进行嵌入查找:

python 复制代码
from torchrec import EmbeddingBagCollection, EmbeddingConfig

# 定义嵌入配置
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingConfig(
            name="feature1_table",
            embedding_dim=16,
            num_embeddings=100,
            feature_names=["feature1"]
        ),
        EmbeddingConfig(
            name="feature2_table",
            embedding_dim=16,
            num_embeddings=100,
            feature_names=["feature2"]
        )
    ]
)

# 前向传播
embeddings = ebc(kjt)  # KJT 作为输入

下面,我们将使用 TorchRec 数据类型 KeyedJaggedTensor 和 EmbeddingBagCollection 进行简单的正向传播。

4. DataLoader 中的集成

在实际训练中,KeyedJaggedTensor 通常在 DataLoader 的 collate 函数中构造:

python 复制代码
import torch
from torch.utils.data import DataLoader, Dataset
from torchrec import KeyedJaggedTensor, JaggedTensor

class RecDataset(Dataset):
    def __init__(self, data):
        self.data = data  # 格式: [{'feature1': [1,2], 'feature2': [3]}, ...]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    """将 batch 转换为 KeyedJaggedTensor"""
    # 收集所有特征名
    all_keys = set()
    for sample in batch:
        all_keys.update(sample.keys())
    
    # 为每个特征创建 JaggedTensor
    jt_dict = {}
    for key in all_keys:
        values = []
        lengths = []
        for sample in batch:
            if key in sample:
                values.extend(sample[key])
                lengths.append(len(sample[key]))
            else:
                lengths.append(0)  # 处理缺失特征
        
        jt_dict[key] = JaggedTensor(
            values=torch.tensor(values, dtype=torch.int64),
            lengths=torch.tensor(lengths, dtype=torch.int32)
        )
    
    # 从字典构造 KeyedJaggedTensor
    return KeyedJaggedTensor.from_jt_dict(jt_dict)

# 使用示例
data = [
    {'user_id': [1], 'item_history': [101, 102], 'category': [5]},
    {'user_id': [2], 'item_history': [201], 'category': [6]},
    {'user_id': [3], 'item_history': [301, 302, 303], 'category': [7]}
]

dataset = RecDataset(data)
dataloader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn)

for batch_kjt in dataloader:
    print("Batch KJT:", batch_kjt)
    # 这里可以传入模型进行训练

5. 注意事项和最佳实践

5.1 内存效率考虑

  • 避免频繁构造: KeyedJaggedTensor 的构造涉及张量拼接,应在 DataLoader 中完成,而不是在模型内部
  • 预分配内存: 对于固定结构的数据,考虑预分配内存以减少运行时开销
  • 异步数据加载: 使用 DataLoader 的 num_workers 参数启用多进程数据加载

5.2 数据验证

  • 检查特征一致性: 确保 batch 内所有样本的特征键一致,或正确处理缺失特征
  • 验证长度合法性: 确保 lengths 和 values 的维度匹配
  • 数据类型检查: values 通常为 torch.int64,lengths 为 torch.int32
python 复制代码
def validate_kjt(kjt):
    # 检查特征数量
    num_features = len(kjt.keys())
    
    # 检查长度维度
    lengths = kjt.lengths()
    assert lengths.dim() == 2, f"Lengths should be 2D, got {lengths.dim()}D"
    assert lengths.shape[0] == num_features, "Lengths shape mismatch"
    
    # 检查 values 长度
    total_length = lengths.sum().item()
    assert total_length == len(kjt.values()), "Values length mismatch"

5.3 性能优化

  • 批处理大小: 选择合适的 batch_size 以平衡内存使用和计算效率
  • 特征分组: 将相关特征分组处理,减少 KJT 的复杂度
  • GPU 加速: 在 GPU 上构造 KJT 时,确保所有输入张量都在同一设备上

5.4 分布式训练注意事项

在分布式环境中使用 KeyedJaggedTensor 时:

  • 分片策略 : 了解不同分片策略(如 RWShardingStrategy, DPShardingStrategy)对 KJT 的影响
  • 通信开销: 大规模 KJT 在设备间传输时会产生通信开销,需要优化
  • 负载均衡: 确保不同设备上的特征分布均衡,避免热点问题

6. 高级用法

6.1 特征转换和增强

python 复制代码
def add_feature_interactions(kjt):
    """添加特征交互"""
    # 从 KJT 获取原始特征
    user_ids = kjt["user_id"].to_padded_dense(0)
    item_ids = kjt["item_id"].to_padded_dense(0)
    
    # 创建交互特征
    interaction_ids = user_ids * 1000 + item_ids
    
    # 转换回 JaggedTensor 格式
    interaction_jt = JaggedTensor(
        values=interaction_ids.flatten(),
        lengths=torch.ones(len(interaction_ids), dtype=torch.int32)
    )
    
    # 添加到原始 KJT
    return KeyedJaggedTensor.from_jt_dict({
        **{key: kjt[key] for key in kjt.keys()},
        "user_item_interaction": interaction_jt
    })

6.2 与 TorchRec 其他组件的集成

KeyedJaggedTensor 可以与 TorchRec 的其他组件无缝集成:

python 复制代码
from torchrec import DenseArch, InteractionArch, OverArch

# 完整的 DLRM 模型示例
class DLRMModel(torch.nn.Module):
    def __init__(self, embedding_bag_collection, dense_arch, interaction_arch, over_arch):
        super().__init__()
        self.ebc = embedding_bag_collection
        self.dense_arch = dense_arch
        self.interaction_arch = interaction_arch
        self.over_arch = over_arch
    
    def forward(self, dense_features, sparse_features):
        """
        dense_features: 稠密特征张量
        sparse_features: KeyedJaggedTensor 格式的稀疏特征
        """
        # 嵌入查找
        sparse_embeddings = self.ebc(sparse_features)
        
        # 稠密特征处理
        dense_embeddings = self.dense_arch(dense_features)
        
        # 特征交互
        interaction_output = self.interaction_arch(dense_embeddings, sparse_embeddings)
        
        # 最终预测
        logits = self.over_arch(interaction_output)
        return logits

7. 调试和可视化

7.1 打印和调试

python 复制代码
def print_kjt_info(kjt):
    print(f"KeyedJaggedTensor Info:")
    print(f"Keys: {kjt.keys()}")
    print(f"Values shape: {kjt.values().shape}")
    print(f"Lengths shape: {kjt.lengths().shape}")
    print(f"Batch size: {kjt.lengths().shape[1]}")
    
    print("\nFeature details:")
    for i, key in enumerate(kjt.keys()):
        feature_lengths = kjt.lengths()[i]
        total_values = feature_lengths.sum().item()
        print(f"  {key}:")
        print(f"    Lengths: {feature_lengths.tolist()}")
        print(f"    Total values: {total_values}")

# 使用示例
print_kjt_info(kjt)

7.2 可视化工具

可以将 KeyedJaggedTensor 转换为 Pandas DataFrame 以便可视化:

python 复制代码
import pandas as pd

def kjt_to_dataframe(kjt):
    """将 KeyedJaggedTensor 转换为 DataFrame 便于可视化"""
    batch_size = kjt.lengths().shape[1]
    data = []
    
    for batch_idx in range(batch_size):
        row = {"batch_idx": batch_idx}
        for key in kjt.keys():
            feature_values = kjt[key].to_padded_dense(batch_idx)
            row[key] = feature_values.tolist()
        data.append(row)
    
    return pd.DataFrame(data)

# 使用示例
df = kjt_to_dataframe(kjt)
print(df)

8. 常见问题和解决方案

8.1 常见错误

错误1: 特征键不一致

python 复制代码
# 错误示例
batch = [
    {'feature1': [1, 2]},
    {'feature2': [3, 4]}  # 缺少 feature1
]

# 解决方案: 在 collate_fn 中处理缺失特征
def robust_collate_fn(batch):
    all_keys = set()
    for sample in batch:
        all_keys.update(sample.keys())
    
    jt_dict = {}
    for key in all_keys:
        values = []
        lengths = []
        for sample in batch:
            if key in sample:
                values.extend(sample[key])
                lengths.append(len(sample[key]))
            else:
                lengths.append(0)  # 为缺失特征填充0长度
        # 处理没有值的情况
        if values:
            jt_dict[key] = JaggedTensor(
                values=torch.tensor(values),
                lengths=torch.tensor(lengths)
            )
        else:
            jt_dict[key] = JaggedTensor(
                values=torch.tensor([], dtype=torch.int64),
                lengths=torch.zeros(len(batch), dtype=torch.int32)
            )
    
    return KeyedJaggedTensor.from_jt_dict(jt_dict)

错误2: 数据类型不匹配

python 复制代码
# 错误示例
kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([1.0, 2.0]),  # float 类型,应该用 int
    lengths=torch.tensor()
)

# 解决方案: 确保正确的数据类型
kjt = KeyedJaggedTensor(
    keys=["f1"],
    values=torch.tensor([1, 2], dtype=torch.int64),  # 明确指定 int64
    lengths=torch.tensor(, dtype=torch.int32)   # 明确指定 int32
)

8.2 性能问题

问题: 大 batch 时内存不足

python 复制代码
# 优化方案1: 减小 batch size
dataloader = DataLoader(dataset, batch_size=32, collate_fn=collate_fn)

# 优化方案2: 使用梯度累积
accumulation_steps = 4
for i, batch_kjt in enumerate(dataloader):
    logits = model(batch_kjt)
    loss = criterion(logits, labels)
    loss = loss / accumulation_steps  # 归一化损失
    loss.backward()
    
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()

9. 完整示例:端到端推荐系统

python 复制代码
import torch
import torch.nn as nn
from torchrec import KeyedJaggedTensor, JaggedTensor, EmbeddingBagCollection, EmbeddingConfig, DenseArch, InteractionArch, OverArch

# 1. 数据准备
class MovieLensDataset(torch.utils.data.Dataset):
    def __init__(self):
        # 模拟数据
        self.data = [
            {'user_id': [1], 'movie_id': [101, 102], 'genre': [5, 6]},
            {'user_id': [2], 'movie_id': [201], 'genre': [7]},
            {'user_id': [3], 'movie_id': [301, 302, 303], 'genre': [8, 9, 10]},
            {'user_id': [4], 'movie_id': [401], 'genre': [11]}
        ]
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]

# 2. 自定义 collate 函数
def collate_fn(batch):
    all_keys = set()
    for sample in batch:
        all_keys.update(sample.keys())
    
    jt_dict = {}
    for key in sorted(all_keys):
        values = []
        lengths = []
        for sample in batch:
            if key in sample:
                values.extend(sample[key])
                lengths.append(len(sample[key]))
            else:
                lengths.append(0)
        
        jt_dict[key] = JaggedTensor(
            values=torch.tensor(values, dtype=torch.int64),
            lengths=torch.tensor(lengths, dtype=torch.int32)
        )
    
    return KeyedJaggedTensor.from_jt_dict(jt_dict)

# 3. 模型定义
class RecModel(nn.Module):
    def __init__(self):
        super().__init__()
        
        # 嵌入表配置
        self.ebc = EmbeddingBagCollection(
            tables=[
                EmbeddingConfig(
                    name="user_table",
                    embedding_dim=16,
                    num_embeddings=1000,
                    feature_names=["user_id"]
                ),
                EmbeddingConfig(
                    name="movie_table",
                    embedding_dim=32,
                    num_embeddings=10000,
                    feature_names=["movie_id"]
                ),
                EmbeddingConfig(
                    name="genre_table",
                    embedding_dim=8,
                    num_embeddings=50,
                    feature_names=["genre"]
                )
            ]
        )
        
        # 稠密网络 (为简化,这里假设没有稠密特征)
        self.dense_arch = nn.Sequential(
            nn.Linear(1, 16),  # 假设有一个虚拟的稠密特征
            nn.ReLU()
        )
        
        # 交互层
        self.interaction_arch = InteractionArch(
            dense_feature_dim=16,
            sparse_feature_names=["user_id", "movie_id", "genre"]
        )
        
        # 输出层
        self.over_arch = OverArch(
            in_features=16 + 32 + 8 + 16,  # 假设的维度
            layer_sizes=[64, 32, 1]
        )
    
    def forward(self, sparse_features, dense_features=None):
        if dense_features is None:
            dense_features = torch.ones(sparse_features.lengths().shape[1], 1)
        
        # 嵌入查找
        sparse_embeddings = self.ebc(sparse_features)
        
        # 稠密特征处理
        dense_embeddings = self.dense_arch(dense_features)
        
        # 特征交互
        interaction_output = self.interaction_arch(dense_embeddings, sparse_embeddings)
        
        # 最终预测
        logits = self.over_arch(interaction_output)
        return logits

# 4. 训练循环
def train():
    # 创建数据集和数据加载器
    dataset = MovieLensDataset()
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        collate_fn=collate_fn,
        shuffle=True
    )
    
    # 创建模型
    model = RecModel()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.BCEWithLogitsLoss()
    
    # 训练
    for epoch in range(3):
        for batch_idx, sparse_features in enumerate(dataloader):
            # 模拟标签
            batch_size = sparse_features.lengths().shape[1]
            labels = torch.randint(0, 2, (batch_size, 1), dtype=torch.float)
            
            # 前向传播
            logits = model(sparse_features)
            
            # 计算损失
            loss = criterion(logits, labels)
            
            # 反向传播
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            print(f"Epoch {epoch}, Batch {batch_idx}, Loss: {loss.item():.4f}")
            
            # 打印 KJT 信息用于调试
            print("\nKeyedJaggedTensor Info:")
            print(f"Keys: {sparse_features.keys()}")
            print(f"Batch size: {sparse_features.lengths().shape[1]}")
            print(f"Values shape: {sparse_features.values().shape}")

if __name__ == "__main__":
    train()

10. 总结

KeyedJaggedTensor 是 TorchRec v1.0 中处理稀疏特征的核心数据结构,它通过高效的内存布局和灵活的 API 设计,为大规模推荐系统提供了强大的支持。

核心要点回顾:

  • 高效表示: 能够高效地表示多个稀疏特征,可以将其视为多个 JaggedTensor 的集合
  • 多种构造方法 : 提供 from_offsets_syncfrom_jt_dict 等多种构造方法,适应不同场景
  • 与 EmbeddingBagCollection 无缝集成: 可直接作为 EmbeddingBagCollection 的输入进行嵌入查找
  • DataLoader 友好: 通过自定义 collate 函数,可以轻松地从原始数据构建 KJT
  • 分布式支持: 在分布式训练环境中表现良好,支持各种分片策略

最佳实践建议:

  1. 在 DataLoader 中构建 KJT: 避免在模型内部频繁构造 KJT
  2. 验证数据一致性: 确保 batch 内特征键的一致性和数据类型的正确性
  3. 监控内存使用: 大规模 KJT 可能消耗大量内存,需要合理设置 batch size
  4. 利用 TorchRec 工具: 使用内置的验证和调试工具确保 KJT 的正确性
  5. 渐进式开发: 从小规模数据开始,逐步扩展到完整系统

通过掌握 KeyedJaggedTensor,你将能够构建高效、可扩展的推荐系统,充分利用 TorchRec v1.0 的强大功能。在实际应用中,建议结合官方文档和社区资源,持续优化和改进你的实现。

相关推荐
imooos2 小时前
使用小程序AI推理能力识别车牌号
人工智能·小程序
神州数码云基地2 小时前
首次开发陌生技术?用 AI 赋能前端提速开发!
前端·人工智能·开源·ai开发
weixin_446260852 小时前
用于构建和部署AI智能代理工作流的开源平台
人工智能
CoovallyAIHub2 小时前
从电影特效到体育科学,运动追踪只能靠“人眼”吗?
深度学习·算法·计算机视觉
一招定胜负2 小时前
支持向量机
人工智能·机器学习·支持向量机
paopao_wu2 小时前
深度学习3:理解神经网络
人工智能·深度学习·神经网络
梦帮科技2 小时前
量子计算+AI:下一代智能的终极形态?(第二部分)
人工智能·机器学习·ai编程·量子计算
周杰伦_Jay2 小时前
【深度拆解智能体技术底层逻辑】从架构到实现的完整解析
人工智能·机器学习·架构·开源·论文·peai2026
EXtreme352 小时前
【DL】从零构建智能:神经网络前向传播、反向传播与激活函数深度解密
人工智能·深度学习·神经网络·梯度下降·反向传播·链式法则