block_sparse_attn 安装

目录

[block_sparse_attn 安装](#block_sparse_attn 安装)

[block_sparse_attn 安装成功](#block_sparse_attn 安装成功)

flash_attention


block_sparse_attn 安装

解决方法:

pip install xformers

或:

pip install flash-attn

bash 复制代码
pip install flash_attn

block_sparse_attn 安装成功

bash 复制代码
cd Block-Sparse-Attention

rm -rf build *.egg-info

export TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9"

pip install -e . --no-build-isolation

错误测试:

bash 复制代码
python -c "from block_sparse_attn import block_sparse_attn_func"

报错:

bash 复制代码
ImportError: libc10.so: cannot open shared object file: No such file or directory

解决方法:

bash 复制代码
python -c "import torch;from block_sparse_attn import block_sparse_attn_func"

flash_attention

引用头文件:

python 复制代码
try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

from block_sparse_attn import block_sparse_attn_func
python 复制代码
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int, compatibility_mode=False, attention_mask=None, return_KV=False):
    if attention_mask is not None:
        seqlen = q.shape[1]
        seqlen_kv = k.shape[1]
        q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
        cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
        cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
        head_mask_type = torch.tensor([1]*num_heads, device=q.device, dtype=torch.int32)
        streaming_info = None
        base_blockmask = attention_mask
        max_seqlen_q_ = seqlen
        max_seqlen_k_ = seqlen_kv
        p_dropout = 0.0
        x = block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            streaming_info,
            base_blockmask,
            max_seqlen_q_, max_seqlen_k_,
            p_dropout,
            deterministic=False,
            softmax_scale=None,
            is_causal=False,
            exact_streaming=False,
            return_attn_probs=False,
        ).unsqueeze(0)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif compatibility_mode:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        if isinstance(x, tuple):
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
相关推荐
曲幽39 分钟前
别再用网页翻译看源码了!你的私人翻译神器LibreTranslate,部署避坑指南来了
python·docker·web·pot·translate·libretranslate·arogstranslate
用户556918817532 小时前
#从脚本到独立程序:Python + Playwright 批量抓取的完整踩坑记录
python·自动化运维
兵慌码乱16 小时前
基于 MediaPipe 与 PySide2 的手势交互音乐控制系统实现:轻量化视觉交互全流程解析
python·opencv·计算机视觉·人机交互·手势识别·mediapipe·pyside2
luckdewei19 小时前
FastAPI 资产管理系统实战:复杂 ORM 关联、Alembic 迁移与 N+1 查询优化
python
aqi001 天前
15天学会AI应用开发(八)使用向量数据库实现RAG功能
人工智能·python·大模型·ai编程·ai应用
Csvn1 天前
`functools.lru_cache` —— 一行代码搞定缓存加速
后端·python
金銀銅鐵2 天前
[Python] 从《千字文》中随机挑选汉字
后端·python
cup112 天前
[技术复盘] Windows Python 打包实战:Nuitka 环境踩坑总结与 CI 自动化构建全指南
python·ai·环境变量·ci·nuitka·skill