「赤兔」Chitu 框架深度解读(十四):核心算子优化

「赤兔」Chitu 框架深度解读(十四):核心算子优化 (RoPE, RMSNorm, SiLU, Sampling)

除了 Attention 和 GEMM 等计算密集型算子外,Transformer 模型中还包含许多其他关键算子,如位置编码(RoPE)、归一化(RMSNorm)、激活函数(SiLU/SwiGLU)和采样(Sampling)。对这些算子进行优化同样能带来显著的性能提升。「赤兔」Chitu 框架在其 ops/ 目录下提供了多种优化实现。

旋转位置编码 (RoPE - ops/rotary.py)

RoPE 是目前主流大模型使用的位置编码方式。其核心操作是将输入向量的一部分维度进行旋转。

  • apply_rotary_pos_emb: 核心函数接口。
  • 后端实现 :
    • apply_rotary_pos_emb_cuda : 调用 CUDA C++ 扩展 (chitu_backend.apply_rotary_pos_emb) 实现。这是 NVIDIA GPU 上的主要优化路径。代码位于 csrc/cuda/rotary/rotary_pos_emb_llama.cu
    • apply_rotary_pos_emb_npu : 调用昇腾 NPU 的特定算子 (torch_npu.npu_rotary_mul)。
    • apply_rotary_pos_emb_cpu : 调用 CPU C++ 扩展 (chitu_cpu_backend.apply_rotary_pos_emb)。代码位于 csrc/cpuinfer/rotary.cpp,使用了 llamafile 中的优化代码。
    • apply_rotary_pos_emb_triton : 调用 chitu.ops.triton_ops.rotary.apply_rotary_pos_emb 实现的 Triton 内核。
    • apply_rotary_pos_emb_torch: PyTorch 原生实现,作为 fallback。
  • 动态调度 : 函数内部会根据当前设备类型 (is_cuda, is_npu, is_cpu) 和配置选择合适的后端实现。
  • BatchedFreqsCIS (batched_freqs_cis.py) : 用于预计算和缓存 RoPE 所需的 cossin 值,避免在每次 Attention 计算时重复生成,特别是在批处理和 Paged KV Cache 场景下优化 freqs 的构造。

RMSNorm (ops/norm.py)

RMSNorm 是一种轻量级的归一化方法。

  • rms_norm: 核心函数接口。
  • 后端实现 :
    • rms_norm_cuda : 调用 CUDA C++ 扩展 (chitu_backend.rms_norm)。代码位于 csrc/cuda/norm/rms_norm.cu
    • rms_norm_npu : 调用昇腾 NPU 的 torch_npu.npu_rms_norm 算子。
    • rms_norm_cpu : 调用 CPU C++ 扩展 (chitu_cpu_backend.rms_norm)。代码位于 csrc/cpuinfer/rmsnorm.cpp
    • rms_norm_triton : 调用 chitu.ops.triton_ops.norm.rms_norm 实现的 Triton 内核。
    • rms_norm_torch: PyTorch 原生实现,作为 fallback。
  • 动态调度: 同样根据设备类型和配置选择后端。

SiLU / SwiGLU 激活 (ops/activation.py)

SiLU (Sigmoid Linear Unit) 及其变种 SwiGLU 是 Llama 等模型 MLP 层常用的激活函数,通常涉及两个线性层的输出和一个逐元素操作 (SiLU(gate) * up)。

  • silu_and_mul : 计算 SiLU(x) * y
  • fused_gate_up_proj_and_silu: 尝试融合 Gate 和 Up 两个线性投影以及后续的 SiLU 激活和乘法操作。
  • 后端实现 :
    • CUDA C++ : chitu_backend.silu_and_mul (可能实现融合)。
    • NPU : torch_npu.npu_silu 和逐元素乘法。融合实现可能依赖特定 NPU 算子。
    • CPU C++ : chitu_cpu_backend.silu_and_mul (csrc/cpuinfer/silu_and_mul.cpp)。
    • Triton : chitu.ops.triton_ops.activation.silu_and_mul
    • Torch: PyTorch 原生实现。
  • 动态调度: 根据设备类型和可用性选择实现。

Logits 处理与采样 (ops/sampling.py)

生成下一个 Token 涉及对模型输出的 Logits 进行处理(如应用温度、Top-K/Top-P、重复惩罚等)和采样。

  • apply_logits_constraints : 应用各种约束(Repetition Penalty, Frequency/Presence Penalty, Temperature, Top-K, Top-P)。
    • apply_penalties_cuda : 调用 CUDA C++ 扩展 (chitu_backend.apply_penalties) 处理 Frequency 和 Presence Penalty。代码位于 csrc/cuda/frequency_penalty/frequency_penalty.cu。其他约束通常在 PyTorch 或 Triton 中实现。
    • apply_logits_constraints_triton : 调用 chitu.ops.triton_ops.sampling.apply_logits_constraints 实现的 Triton 内核,可能融合多种约束处理。
    • apply_logits_constraints_torch: PyTorch 原生实现。
  • sample : 从处理后的 Logits/Probabilities 中采样下一个 Token ID。
    • sample_cuda: 调用 CUDA C++ 扩展。
    • sample_triton: 调用 Triton 内核实现。
    • sample_torch : PyTorch 原生实现 (torch.multinomial)。
    • Greedy Search: 也提供了 Argmax 实现。

特点:

  • 融合优化: 尝试将多种 Logits 处理步骤(惩罚、温度、Top-K/P)融合到单个 CUDA 或 Triton 内核中,减少内存读写和 Kernel Launch 开销。
  • 后端多样性: 为关键的惩罚计算和采样提供了 CUDA/Triton/Torch 实现。

总结

「赤兔」对 RoPE, RMSNorm, SiLU/SwiGLU, Sampling 等核心算子进行了细致的优化。通过提供包括 CUDA C++ 扩展、NPU 特定算子、CPU C++ 扩展、Triton 内核以及 PyTorch 原生实现在内的多种后端,并根据运行环境动态选择最优路径。「赤兔」确保了这些看似"辅助"的算子不会成为推理流程中的性能瓶颈。特别是对 RoPE 的预计算缓存、激活函数的融合尝试以及 Sampling 约束处理的融合优化,都体现了其追求端到端高性能的设计理念。# 「赤兔」Chitu 框架深度解读(十四):核心算子优化 (RoPE, RMSNorm, SiLU, Sampling)

除了 Attention 和 GEMM 等计算密集型算子外,Transformer 模型中还包含许多其他关键算子,如位置编码(RoPE)、归一化(RMSNorm)、激活函数(SiLU/SwiGLU)和采样(Sampling)。对这些算子进行优化同样能带来显著的性能提升。「赤兔」Chitu 框架在其 ops/ 目录下提供了多种优化实现。

旋转位置编码 (RoPE - ops/rotary.py)

RoPE 是目前主流大模型使用的位置编码方式。其核心操作是将输入向量的一部分维度进行旋转。

  • apply_rotary_pos_emb: 核心函数接口。
  • 后端实现 :
    • apply_rotary_pos_emb_cuda : 调用 CUDA C++ 扩展 (chitu_backend.apply_rotary_pos_emb) 实现。这是 NVIDIA GPU 上的主要优化路径。代码位于 csrc/cuda/rotary/rotary_pos_emb_llama.cu
    • apply_rotary_pos_emb_npu : 调用昇腾 NPU 的特定算子 (torch_npu.npu_rotary_mul)。
    • apply_rotary_pos_emb_cpu : 调用 CPU C++ 扩展 (chitu_cpu_backend.apply_rotary_pos_emb)。代码位于 csrc/cpuinfer/rotary.cpp,使用了 llamafile 中的优化代码。
    • apply_rotary_pos_emb_triton : 调用 chitu.ops.triton_ops.rotary.apply_rotary_pos_emb 实现的 Triton 内核。
    • apply_rotary_pos_emb_torch: PyTorch 原生实现,作为 fallback。
  • 动态调度 : 函数内部会根据当前设备类型 (is_cuda, is_npu, is_cpu) 和配置选择合适的后端实现。
  • BatchedFreqsCIS (batched_freqs_cis.py) : 用于预计算和缓存 RoPE 所需的 cossin 值,避免在每次 Attention 计算时重复生成,特别是在批处理和 Paged KV Cache 场景下优化 freqs 的构造。

RMSNorm (ops/norm.py)

RMSNorm 是一种轻量级的归一化方法。

  • rms_norm: 核心函数接口。
  • 后端实现 :
    • rms_norm_cuda : 调用 CUDA C++ 扩展 (chitu_backend.rms_norm)。代码位于 csrc/cuda/norm/rms_norm.cu
    • rms_norm_npu : 调用昇腾 NPU 的 torch_npu.npu_rms_norm 算子。
    • rms_norm_cpu : 调用 CPU C++ 扩展 (chitu_cpu_backend.rms_norm)。代码位于 csrc/cpuinfer/rmsnorm.cpp
    • rms_norm_triton : 调用 chitu.ops.triton_ops.norm.rms_norm 实现的 Triton 内核。
    • rms_norm_torch: PyTorch 原生实现,作为 fallback。
  • 动态调度: 同样根据设备类型和配置选择后端。

SiLU / SwiGLU 激活 (ops/activation.py)

SiLU (Sigmoid Linear Unit) 及其变种 SwiGLU 是 Llama 等模型 MLP 层常用的激活函数,通常涉及两个线性层的输出和一个逐元素操作 (SiLU(gate) * up)。

  • silu_and_mul : 计算 SiLU(x) * y
  • fused_gate_up_proj_and_silu: 尝试融合 Gate 和 Up 两个线性投影以及后续的 SiLU 激活和乘法操作。
  • 后端实现 :
    • CUDA C++ : chitu_backend.silu_and_mul (可能实现融合)。
    • NPU : torch_npu.npu_silu 和逐元素乘法。融合实现可能依赖特定 NPU 算子。
    • CPU C++ : chitu_cpu_backend.silu_and_mul (csrc/cpuinfer/silu_and_mul.cpp)。
    • Triton : chitu.ops.triton_ops.activation.silu_and_mul
    • Torch: PyTorch 原生实现。
  • 动态调度: 根据设备类型和可用性选择实现。

Logits 处理与采样 (ops/sampling.py)

生成下一个 Token 涉及对模型输出的 Logits 进行处理(如应用温度、Top-K/Top-P、重复惩罚等)和采样。

  • apply_logits_constraints : 应用各种约束(Repetition Penalty, Frequency/Presence Penalty, Temperature, Top-K, Top-P)。
    • apply_penalties_cuda : 调用 CUDA C++ 扩展 (chitu_backend.apply_penalties) 处理 Frequency 和 Presence Penalty。代码位于 csrc/cuda/frequency_penalty/frequency_penalty.cu。其他约束通常在 PyTorch 或 Triton 中实现。
    • apply_logits_constraints_triton : 调用 chitu.ops.triton_ops.sampling.apply_logits_constraints 实现的 Triton 内核,可能融合多种约束处理。
    • apply_logits_constraints_torch: PyTorch 原生实现。
  • sample : 从处理后的 Logits/Probabilities 中采样下一个 Token ID。
    • sample_cuda: 调用 CUDA C++ 扩展。
    • sample_triton: 调用 Triton 内核实现。
    • sample_torch : PyTorch 原生实现 (torch.multinomial)。
    • Greedy Search: 也提供了 Argmax 实现。

特点:

  • 融合优化: 尝试将多种 Logits 处理步骤(惩罚、温度、Top-K/P)融合到单个 CUDA 或 Triton 内核中,减少内存读写和 Kernel Launch 开销。
  • 后端多样性: 为关键的惩罚计算和采样提供了 CUDA/Triton/Torch 实现。

总结

「赤兔」对 RoPE, RMSNorm, SiLU/SwiGLU, Sampling 等核心算子进行了细致的优化。通过提供包括 CUDA C++ 扩展、NPU 特定算子、CPU C++ 扩展、Triton 内核以及 PyTorch 原生实现在内的多种后端,并根据运行环境动态选择最优路径。「赤兔」确保了这些看似"辅助"的算子不会成为推理流程中的性能瓶颈。特别是对 RoPE 的预计算缓存、激活函数的融合尝试以及 Sampling 约束处理的融合优化,都体现了其追求端到端高性能的设计理念。

相关推荐
CoderYanger2 分钟前
贪心算法:8.买卖股票的最佳时机
java·算法·leetcode·贪心算法·1024程序员节
爱笑的眼睛114 分钟前
SQLAlchemy 核心 API 深度解析:超越 ORM 的数据库工具包
java·人工智能·python·ai
知白守黑V8 分钟前
OWASP 2025 LLM 应用十大安全风险深度解析
人工智能·安全·ai agent·ai智能体·ai应用·ai安全·大模型安全
zhaodiandiandian9 分钟前
生成式AI重构内容创作生态:人机协同成核心竞争力
大数据·人工智能·重构
努力毕业的小土博^_^15 分钟前
【AI课程领学】基于SmolVLM2与Qwen3的多模态模型拼接实践:从零构建视觉语言模型(一)
人工智能·深度学习·神经网络·机器学习·语言模型·自然语言处理
Lululaurel19 分钟前
AI编程提示词工程实战指南:从入门到精通
人工智能·python·机器学习·ai·ai编程
财经三剑客30 分钟前
东风集团股份:11月生产量达21.6万辆 销量19.6万辆
大数据·人工智能·汽车
老蒋新思维33 分钟前
创客匠人峰会新解:高势能 IP 打造 ——AI 时代知识变现的十倍增长密码
大数据·网络·人工智能·tcp/ip·创始人ip·创客匠人·知识变现
Dev7z35 分钟前
基于神经网络的风电机组齿轮箱故障诊断研究与设计
人工智能·深度学习·神经网络
老蒋新思维35 分钟前
创客匠人峰会洞察:AI 时代教育知识变现的重构 —— 从 “刷题记忆” 到 “成长赋能” 的革命
大数据·人工智能·网络协议·tcp/ip·重构·创始人ip·创客匠人