Installing block_sparse_attn

Contents

[Installing block_sparse_attn](#installing-block_sparse_attn)

[block_sparse_attn installed successfully](#block_sparse_attn-installed-successfully)

[flash_attention](#flash_attention)


Installing block_sparse_attn

Fix: install an attention backend first, either:

```bash
pip install xformers
```

or:

```bash
pip install flash-attn
```

block_sparse_attn installed successfully

```bash
cd Block-Sparse-Attention
rm -rf build *.egg-info
export TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9"
pip install -e . --no-build-isolation
```
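The TORCH_CUDA_ARCH_LIST value above is a semicolon-separated list of CUDA compute capabilities (8.0 = A100, 8.6 = RTX 30xx, 8.9 = RTX 40xx). A minimal sketch of how such a string could be built from capability tuples; the helper name `arch_string` is mine, not part of any library:

```python
def arch_string(capabilities):
    """Format (major, minor) compute-capability pairs the way
    TORCH_CUDA_ARCH_LIST expects, e.g. [(8, 0), (8, 6)] -> "8.0;8.6"."""
    return ";".join(f"{major}.{minor}" for major, minor in capabilities)

# With torch installed, the local GPU's capability could be read via
# torch.cuda.get_device_capability() and passed in here.
print(arch_string([(8, 0), (8, 6), (8, 9)]))  # -> 8.0;8.6;8.9
```

Restricting the list to the architectures actually present keeps the extension build shorter than compiling for every supported arch.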

Error test:

```bash
python -c "from block_sparse_attn import block_sparse_attn_func"
```

This fails with:

```
ImportError: libc10.so: cannot open shared object file: No such file or directory
```

Fix: import torch before block_sparse_attn, so torch's shared libraries are already loaded:

```bash
python -c "import torch; from block_sparse_attn import block_sparse_attn_func"
```
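Why this works: libc10.so ships inside the installed torch package (under torch/lib), a directory that is not on the dynamic loader's default search path, and `import torch` loads those libraries into the process before the extension needs them. A small sketch of where the file lives; the helper `bundled_lib_dir` is illustrative, not a real API:

```python
import os

def bundled_lib_dir(package_init_path):
    """Return the lib/ directory next to a package's __init__.py;
    for torch this is where libc10.so is shipped."""
    return os.path.join(os.path.dirname(package_init_path), "lib")

# With torch installed, bundled_lib_dir(torch.__file__) would point at
# .../site-packages/torch/lib, which contains libc10.so.
print(bundled_lib_dir("/opt/venv/site-packages/torch/__init__.py"))
```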

flash_attention

Imports (torch, torch.nn.functional and einops.rearrange are added here because the function below uses them):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

from block_sparse_attn import block_sparse_attn_func
```
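The try/except pattern above can be generalized into one helper that returns an availability flag for any backend module; the name `try_import` is mine, not part of any of these libraries:

```python
import importlib

def try_import(module_name):
    """Return True if module_name can be imported, mirroring the
    FLASH_ATTN_*_AVAILABLE flags above."""
    try:
        importlib.import_module(module_name)
        return True
    except ModuleNotFoundError:
        return False

print(try_import("os"))  # stdlib module, always importable -> True
```

Note that block_sparse_attn itself is imported unconditionally above, so it remains a hard dependency: if it is missing, the module fails at import time rather than falling back.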
```python
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int,
                    compatibility_mode=False, attention_mask=None, return_KV=False):
    # return_KV is accepted for API compatibility but is unused below.
    if attention_mask is not None:
        # Block-sparse path (assumes batch size 1): flatten to the varlen
        # layout that block_sparse_attn_func expects.
        seqlen = q.shape[1]
        seqlen_kv = k.shape[1]
        q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
        cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
        cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
        head_mask_type = torch.tensor([1] * num_heads, device=q.device, dtype=torch.int32)
        streaming_info = None
        base_blockmask = attention_mask
        max_seqlen_q_ = seqlen
        max_seqlen_k_ = seqlen_kv
        p_dropout = 0.0
        x = block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            streaming_info,
            base_blockmask,
            max_seqlen_q_, max_seqlen_k_,
            p_dropout,
            deterministic=False,
            softmax_scale=None,
            is_causal=False,
            exact_streaming=False,
            return_attn_probs=False,
        ).unsqueeze(0)  # restore the batch dimension
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif compatibility_mode:
        # Pure-PyTorch fallback via scaled_dot_product_attention.
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        if isinstance(x, tuple):  # FA3 may also return the LSE
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        # Last-resort fallback, same as compatibility_mode.
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
```
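The cu_seqlens_q / cu_seqlens_k tensors built in the block-sparse branch are cumulative sequence offsets in the varlen convention. A plain-Python sketch of the same construction generalized to several packed sequences; the helper name `cu_seqlens` is mine:

```python
def cu_seqlens(lengths):
    """Cumulative sequence offsets: [0, l1, l1+l2, ...].
    The single-sequence case [0, seqlen] matches cu_seqlens_q above."""
    offsets = [0]
    for length in lengths:
        offsets.append(offsets[-1] + length)
    return offsets

print(cu_seqlens([128]))          # -> [0, 128], the single-batch case
print(cu_seqlens([128, 64, 32]))  # -> [0, 128, 192, 224]
```

Each consecutive pair (offsets[i], offsets[i+1]) delimits one sequence in the flattened (b s) n d layout produced by the rearrange calls.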