block_sparse_attn installation

Contents

[block_sparse_attn installation](#block_sparse_attn-installation)

[block_sparse_attn installed successfully](#block_sparse_attn-installed-successfully)

[flash_attention](#flash_attention)


block_sparse_attn installation

Solution:

```bash
pip install xformers
```

or:

```bash
pip install flash-attn
```

block_sparse_attn installed successfully

Clean any previous build artifacts, pin the CUDA architectures to compile for, then do an editable install without build isolation so the extension builds against the already-installed PyTorch:

```bash
cd Block-Sparse-Attention
rm -rf build *.egg-info
export TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9"
pip install -e . --no-build-isolation
```
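The `TORCH_CUDA_ARCH_LIST` value above covers Ampere (8.0/8.6) and Ada (8.9) GPUs. To check which compute capability your own card reports before building, a quick probe (assumes PyTorch is installed):

```python
import torch

# Report the compute capability of the first visible CUDA device
# (e.g. "8.6" for an RTX 30-series card). That value is what belongs
# in TORCH_CUDA_ARCH_LIST when compiling the extension from source.
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    arch = f"{major}.{minor}"
else:
    arch = "no CUDA device"
print(arch)
```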

Smoke test:

```bash
python -c "from block_sparse_attn import block_sparse_attn_func"
```

Error:

```text
ImportError: libc10.so: cannot open shared object file: No such file or directory
```

Solution: import torch first, so that PyTorch's shared libraries (including libc10.so) are loaded before the extension:

```bash
python -c "import torch; from block_sparse_attn import block_sparse_attn_func"
```
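libc10.so ships inside the installed PyTorch package itself, so an alternative (an untested sketch, not from the original post) is to put torch's bundled `lib` directory on the dynamic loader path instead of importing torch first. The directory can be located like this:

```python
import os
import torch

# libc10.so lives in <torch package dir>/lib. Exporting this directory
# via LD_LIBRARY_PATH lets the loader find it without "import torch"
# running first, e.g. in your shell:
#   export LD_LIBRARY_PATH="<printed path>:$LD_LIBRARY_PATH"
torch_lib = os.path.join(os.path.dirname(torch.__file__), "lib")
print(torch_lib)
```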

flash_attention

Imports (`torch`, `torch.nn.functional`, and `einops.rearrange` are required by the function below; the optional backends are probed at import time):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

from block_sparse_attn import block_sparse_attn_func
```
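`block_sparse_attn_func` takes a `base_blockmask` describing which attention blocks to keep. As a hedged sketch (the 128×128 block size and the `(num_heads, q_blocks, k_blocks)` boolean layout are assumptions here; check the Block-Sparse-Attention README for the exact contract), a dense "keep every block" mask could be built like this:

```python
import math
import torch

# Hypothetical helper: an all-ones block mask for base_blockmask,
# assuming 128x128 attention blocks and a (heads, q_blocks, k_blocks)
# boolean layout. Zeroing entries would skip those blocks entirely.
def full_block_mask(num_heads, seqlen_q, seqlen_k, block_size=128, device="cpu"):
    q_blocks = math.ceil(seqlen_q / block_size)
    k_blocks = math.ceil(seqlen_k / block_size)
    return torch.ones(num_heads, q_blocks, k_blocks, dtype=torch.bool, device=device)

mask = full_block_mask(num_heads=8, seqlen_q=1024, seqlen_k=2048)
print(mask.shape)  # torch.Size([8, 8, 16])
```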
```python
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int,
                    compatibility_mode=False, attention_mask=None, return_KV=False):
    if attention_mask is not None:
        # Block-sparse path: block_sparse_attn_func uses the varlen (packed)
        # layout, so flatten batch and sequence into one dimension and pass
        # cumulative sequence lengths.
        seqlen = q.shape[1]
        seqlen_kv = k.shape[1]
        q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
        cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
        cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
        # 1 = block-sparse attention for every head
        head_mask_type = torch.tensor([1] * num_heads, device=q.device, dtype=torch.int32)
        streaming_info = None
        base_blockmask = attention_mask
        max_seqlen_q_ = seqlen
        max_seqlen_k_ = seqlen_kv
        p_dropout = 0.0
        x = block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            streaming_info,
            base_blockmask,
            max_seqlen_q_, max_seqlen_k_,
            p_dropout,
            deterministic=False,
            softmax_scale=None,
            is_causal=False,
            exact_streaming=False,
            return_attn_probs=False,
        ).unsqueeze(0)  # restore the batch dim dropped by the packed layout
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif compatibility_mode:
        # Pure-PyTorch fallback via scaled_dot_product_attention
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        if isinstance(x, tuple):  # FA3 may return (out, lse)
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        # Last resort: same SDPA fallback as compatibility_mode
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
```