DeepSeek开源周首日:发布大模型加速核心技术可变长度高效FlashMLA 加持H800算力解码性能狂飙升至3000GB/s

FlashMLA的核心技术特性包括对BF16精度的全面支持,以及采用块大小为64的页式键值缓存(Paged KV Cache)系统,实现更精确的内存管理。在性能表现方面,基于CUDA12.6平台,FlashMLA在H800SXM5GPU上创下了显著成绩:在内存受限场景下达到3000GB/s的处理速度,在计算受限场景下则实现580TFLOPS的算力水平。

1. 核心功能与特性

  • 性能提升

    FlashMLA在H800 SXM5 GPU(CUDA 12.6)上表现亮眼:

    • 内存受限场景下带宽达3000 GB/s
    • 计算受限场景下算力峰值达580 TFLOPS(BF16精度)
  • 关键技术优化

    • 变长序列处理:针对自然语言处理中的动态序列长度优化,提升长文本推理效率。
    • 分页KV缓存:块大小为64的分页机制,减少显存碎片化,提升内存利用率。
    • BF16支持:通过低精度计算降低内存占用,同时保持模型性能。
  • MLA架构创新

    相比传统注意力机制,MLA通过低秩压缩技术 将每次查询的KV缓存量减少93.3%,显著降低推理时的显存需求,尤其适合长上下文场景。


2. 技术背景与意义

  • 解决行业痛点

    Transformer模型在长序列推理时面临KV缓存膨胀 问题,导致显存占用高、硬件成本攀升。FlashMLA通过MLA架构和并行解码设计,将推理成本降低约80-90%,同时支持更高吞吐量

  • 开源生态价值

    FlashMLA开源代码库(GitHub链接)整合了FlashAttention-2/3和CUTLASS的技术实现,为开发者提供可复现的优化方案,加速AGI技术迭代。


3. 应用场景与部署

  • 适用场景

    • 大语言模型(LLM)推理加速,如对话AI、实时翻译、长文本生成等。
    • 需要低延迟、高吞吐的工业级NLP任务。
  • 部署要求

    • 硬件:Hopper架构GPU(如H800/H100)
    • 软件:CUDA 12.3+、PyTorch 2.0+

4. 对行业的影响

  • 成本革命

    DeepSeek通过MLA技术将模型训练和推理成本压缩至行业标杆水平。例如,其V3模型的训练成本仅600万美元(未含研发投入),而MLA的推理优化进一步降低商业化门槛。

  • 算力效率提升

    结合MoE(混合专家模型)架构和多Token预测技术,DeepSeek在单位算力下实现更高性能,推动行业从"堆算力"向"优化算法"转型。

  • 开源竞争格局

    此次开源被视为对Meta Llama、Mistral等项目的直接挑战,可能加速闭源与开源模型的性能差距缩小。


FlashMLA的发布标志着DeepSeek在高效计算领域的技术领先地位,其开源策略或将重塑大模型开发范式,推动更多低成本、高性能AI应用的涌现。

5.快速开始

安装

可以使用以下命令进行安装:

bash 复制代码
python setup.py install
基准测试

运行以下命令进行基准测试:

bash 复制代码
python tests/test_flash_mla.py
使用示例

在Python中可以这样使用:

python 复制代码
from flash_mla import get_mla_metadata, flash_mla_with_kvcache

tile_scheduler_metadata, num_splits = get_mla_metadata(cache_seqlens, s_q * h_q // h_kv, h_kv)

for i in range(num_layers):
    ...
    o_i, lse_i = flash_mla_with_kvcache(
        q_i, kvcache_i, block_table, cache_seqlens, dv,
        tile_scheduler_metadata, num_splits, causal=True,
    )
    ...

6.核心代码的详细解释

以下是对 FlashMLA/flash_mla/flash_mla_interface.py 文件中:

get_mla_metadata 函数

python 复制代码
def get_mla_metadata(
    cache_seqlens: torch.Tensor,
    num_heads_per_head_k: int,
    num_heads_k: int,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        cache_seqlens: (batch_size), dtype torch.int32.
        num_heads_per_head_k: Equals to seq_len_q * num_heads_q // num_heads_k.
        num_heads_k: num_heads_k.

    Return:
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), dtype torch.int32.
        num_splits: (batch_size + 1), dtype torch.int32.
    """
    return flash_mla_cuda.get_mla_metadata(cache_seqlens, num_heads_per_head_k, num_heads_k)
  • 功能:该函数用于获取MLA(Multi-Head Attention)的元数据。
  • 参数
    • cache_seqlens:一个形状为 (batch_size)torch.Tensor,数据类型为 torch.int32,表示缓存的序列长度。
    • num_heads_per_head_k:整数类型,其值等于 seq_len_q * num_heads_q // num_heads_k
    • num_heads_k:整数类型,表示 num_heads_k 的值。
  • 返回值
    • tile_scheduler_metadata:形状为 (num_sm_parts, TileSchedulerMetaDataSize)torch.Tensor,数据类型为 torch.int32
    • num_splits:形状为 (batch_size + 1)torch.Tensor,数据类型为 torch.int32
  • 实现细节 :该函数直接调用 flash_mla_cuda 模块中的 get_mla_metadata 函数,并将输入参数传递给它,然后返回该函数的结果。

flash_mla_with_kvcache 函数

python 复制代码
def flash_mla_with_kvcache(
    q: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Arguments:
        q: (batch_size, seq_len_q, num_heads_q, head_dim).
        k_cache: (num_blocks, page_block_size, num_heads_k, head_dim).
        block_table: (batch_size, max_num_blocks_per_seq), torch.int32.
        cache_seqlens: (batch_size), torch.int32.
        head_dim_v: Head_dim of v.
        tile_scheduler_metadata: (num_sm_parts, TileSchedulerMetaDataSize), torch.int32, return by get_mla_metadata.
        num_splits: (batch_size + 1), torch.int32, return by get_mla_metadata.
        softmax_scale: float. The scaling of QK^T before applying softmax. Default to 1 / sqrt(head_dim).
        causal: bool. Whether to apply causal attention mask.

    Return:
        out: (batch_size, seq_len_q, num_heads_q, head_dim_v).
        softmax_lse: (batch_size, num_heads_q, seq_len_q), torch.float32.
    """
    if softmax_scale is None:
        softmax_scale = q.shape[-1] ** (-0.5)
    out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
        q,
        k_cache,
        None,
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return out, softmax_lse
  • 功能:该函数用于执行带有键值缓存(KVCache)的MLA操作。
  • 参数
    • q:形状为 (batch_size, seq_len_q, num_heads_q, head_dim)torch.Tensor,表示查询张量。
    • k_cache:形状为 (num_blocks, page_block_size, num_heads_k, head_dim)torch.Tensor,表示键缓存张量。
    • block_table:形状为 (batch_size, max_num_blocks_per_seq)torch.Tensor,数据类型为 torch.int32,表示块表。
    • cache_seqlens:形状为 (batch_size)torch.Tensor,数据类型为 torch.int32,表示缓存的序列长度。
    • head_dim_v:整数类型,表示 v 的头维度。
    • tile_scheduler_metadata:形状为 (num_sm_parts, TileSchedulerMetaDataSize)torch.Tensor,数据类型为 torch.int32,由 get_mla_metadata 函数返回。
    • num_splits:形状为 (batch_size + 1)torch.Tensor,数据类型为 torch.int32,由 get_mla_metadata 函数返回。
    • softmax_scale:可选的浮点数,表示在应用softmax之前对 QK^T 进行缩放的比例,默认为 1 / sqrt(head_dim)
    • causal:布尔类型,表示是否应用因果注意力掩码,默认为 False
  • 返回值
    • out:形状为 (batch_size, seq_len_q, num_heads_q, head_dim_v)torch.Tensor,表示输出张量。
    • softmax_lse:形状为 (batch_size, num_heads_q, seq_len_q)torch.Tensor,数据类型为 torch.float32,表示softmax的对数和指数(LogSumExp)。
  • 实现细节
    • 如果 softmax_scale 未提供,则将其设置为 q 张量最后一个维度的平方根的倒数。
    • 调用 flash_mla_cuda 模块中的 fwd_kvcache_mla 函数,传递相应的参数,并将返回的结果赋值给 outsoftmax_lse
    • 最后返回 outsoftmax_lse

这些函数主要是作为Python接口,调用底层的CUDA实现(flash_mla_cuda 模块)来完成MLA操作和元数据的获取。

相关推荐
邹霍梁@开源软件GoodERP18 分钟前
【AI+智造】DeepSeek价值重构:当采购与物控遇上数字化转型的化学反应
运维·人工智能·制造
zhulu5061 小时前
PyTorch 源码学习:Dispatch & Autograd & Operators
人工智能·pytorch·学习
山海青风2 小时前
从零开始玩转TensorFlow:小明的机器学习故事 5
人工智能·机器学习·tensorflow
小森( ﹡ˆoˆ﹡ )2 小时前
DeepSeek 全面分析报告
人工智能·自然语言处理·nlp
刘大猫262 小时前
十、MyBatis的缓存
大数据·数据结构·人工智能
deephub3 小时前
用PyTorch从零构建 DeepSeek R1:模型架构和分步训练详解
人工智能·pytorch·python·深度学习·deepseek
阿正的梦工坊3 小时前
详解 @符号在 PyTorch 中的矩阵乘法规则
人工智能·pytorch·矩阵
人类群星闪耀时3 小时前
大数据平台上的机器学习模型部署:从理论到实
大数据·人工智能·机器学习
合方圆~小文4 小时前
跨境宠物摄像头是一种专为宠物主人设计的智能设备
java·数据库·人工智能·扩展屏应用开发