Installing block_sparse_attn

Contents

[Installing block_sparse_attn](#Installing block_sparse_attn)

[Installing block_sparse_attn successfully](#Installing block_sparse_attn successfully)

flash_attention


Installing block_sparse_attn

Solution:

pip install xformers

or:

pip install flash-attn


Installing block_sparse_attn successfully

cd Block-Sparse-Attention

rm -rf build *.egg-info

export TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9"

pip install -e . --no-build-isolation
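The TORCH_CUDA_ARCH_LIST value above targets compute capabilities 8.0, 8.6 and 8.9. As a rough sketch, a value for the machine at hand could be derived from the capabilities PyTorch reports. Both `format_arch_list` and `detect_capabilities` are hypothetical helper names used only for illustration; they are not part of Block-Sparse-Attention.

```python
def format_arch_list(capabilities):
    """Turn (major, minor) compute capabilities into a TORCH_CUDA_ARCH_LIST value."""
    unique = sorted(set(capabilities))  # drop duplicates, keep ascending order
    return ";".join(f"{major}.{minor}" for major, minor in unique)


def detect_capabilities():
    """Return the compute capability of every visible GPU, or [] if none is usable."""
    try:
        import torch
    except ImportError:
        return []
    if not torch.cuda.is_available():
        return []
    return [
        torch.cuda.get_device_capability(i)
        for i in range(torch.cuda.device_count())
    ]


# Fixed input so the example runs without a GPU.
print(format_arch_list([(8, 6), (8, 0), (8, 9), (8, 6)]))  # prints 8.0;8.6;8.9
```

Calling `format_arch_list(detect_capabilities())` on the build machine gives a string that can be exported in place of the hard-coded list. Listing fewer architectures generally shortens the build, at the cost of portability to other GPUs.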

A test that triggers an error:

python -c "from block_sparse_attn import block_sparse_attn_func"

Error:

ImportError: libc10.so: cannot open shared object file: No such file or directory

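Solution: import torch first, so that its shared libraries (including libc10.so) are already loaded when the extension is imported: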

python -c "import torch;from block_sparse_attn import block_sparse_attn_func"
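The same import-order rule can be wrapped in a small helper so that callers cannot get the order wrong. This is a minimal sketch; `import_with_prerequisites` is a hypothetical name, and the only assumption is the one shown above, namely that torch must be imported before block_sparse_attn.

```python
import importlib


def import_with_prerequisites(target, prerequisites=("torch",)):
    """Import each prerequisite module in order, then import and return the target."""
    for name in prerequisites:
        importlib.import_module(name)
    return importlib.import_module(target)


# Intended use (requires torch and block_sparse_attn to be installed):
# bsa = import_with_prerequisites("block_sparse_attn")
# block_sparse_attn_func = bsa.block_sparse_attn_func
```

In ordinary module code, simply placing `import torch` above the `block_sparse_attn` import has the same effect.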

flash_attention

Imports:

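# Import torch before block_sparse_attn so that libc10.so is already loaded.
import torch
import torch.nn.functional as F
from einops import rearrange

try: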
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

from block_sparse_attn import block_sparse_attn_func
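The try/except blocks above detect optional backends by actually importing them. As an alternative sketch, availability can be probed without importing anything by using the standard library's `importlib.util.find_spec`. Note the difference: `find_spec` only confirms that a module can be located, not that its compiled extension loads correctly. `is_available` and `BACKENDS` are hypothetical names introduced for this example.

```python
import importlib.util


def is_available(module_name):
    """Return True if a top-level module can be located, without importing it."""
    return importlib.util.find_spec(module_name) is not None


# The three optional backends probed in the imports above.
BACKENDS = ("flash_attn_interface", "flash_attn", "sageattention")
availability = {name: is_available(name) for name in BACKENDS}
print(availability)
```

The try/except form in the original remains the safer choice when a backend may be installed but broken, since it exercises the real import.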

Function definition:

def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False):
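    if attention_mask is not None:
        # Block-sparse path: attention_mask is passed through as base_blockmask.
        # q/k/v are flattened to (b * s, n, d) and cu_seqlens holds a single
        # [0, seqlen] pair, so this branch assumes batch size 1.
        seqlen = q.shape[1]
        seqlen_kv = k.shape[1]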
        q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
        cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
        cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
        head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32)
        streaming_info = None
        base_blockmask = attention_mask
        max_seqlen_q_ = seqlen
        max_seqlen_k_ = seqlen_kv
        p_dropout = 0.0
        x = block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            streaming_info,
            base_blockmask,
            max_seqlen_q_, max_seqlen_k_,
            p_dropout,
            deterministic=False,
            softmax_scale=None,
            is_causal=False,
            exact_streaming=False,
            return_attn_probs=False,
        ).unsqueeze(0)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif compatibility_mode:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        if isinstance(x, tuple):
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
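The block-sparse branch passes `cu_seqlens_q = [0, seqlen]`, which describes exactly one sequence. Assuming `block_sparse_attn_func` follows the variable-length convention of flash-attn, where `cu_seqlens` holds cumulative sequence boundaries over the flattened token dimension, a batch of several equal-length sequences would need one boundary per sequence plus a leading zero. The helper below is a hypothetical illustration of that arithmetic in plain Python, not code from the original function.

```python
def cu_seqlens_for_batch(batch_size, seqlen):
    """Cumulative boundaries for batch_size sequences of equal length seqlen."""
    return [i * seqlen for i in range(batch_size + 1)]


print(cu_seqlens_for_batch(1, 128))  # [0, 128]
print(cu_seqlens_for_batch(3, 4))    # [0, 4, 8, 12]
```

With batch size 1 this reproduces the `[0, seqlen]` pair used above. Inside the function the list would still have to be converted with `torch.tensor(..., device=q.device, dtype=torch.int32)`, and the output reshaped with the real batch size instead of `unsqueeze(0)`.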