# Installing block_sparse_attn

Contents

[Installing block_sparse_attn](#installing-block_sparse_attn)

[block_sparse_attn installed successfully](#block_sparse_attn-installed-successfully)

[flash_attention](#flash_attention)


## Installing block_sparse_attn

Solution:

```bash
pip install xformers
```

or:

```bash
pip install flash-attn
```
## block_sparse_attn installed successfully

```bash
cd Block-Sparse-Attention
rm -rf build *.egg-info                    # clear stale build artifacts
export TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9"  # target GPU architectures
pip install -e . --no-build-isolation
```
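`TORCH_CUDA_ARCH_LIST` should match your GPU's compute capability (8.0 = A100, 8.6 = RTX 30xx, 8.9 = RTX 40xx). With torch installed, `torch.cuda.get_device_capability()` returns the `(major, minor)` pair; the helper below (illustrative, not part of the original instructions) formats such pairs into the string the variable expects:

```python
def format_arch_list(capabilities):
    """Format (major, minor) compute-capability pairs into the
    semicolon-separated string TORCH_CUDA_ARCH_LIST expects."""
    return ";".join(f"{major}.{minor}" for major, minor in capabilities)

# e.g. an A100 (8.0) plus an RTX 4090 (8.9):
print(format_arch_list([(8, 0), (8, 9)]))  # → 8.0;8.9
```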

Verify the install:

```bash
python -c "from block_sparse_attn import block_sparse_attn_func"
```

This fails with:

```
ImportError: libc10.so: cannot open shared object file: No such file or directory
```

Fix: import torch first. `libc10.so` ships inside torch's package (under `torch/lib`), and importing torch loads it into the process, so the extension can then resolve it:

```bash
python -c "import torch; from block_sparse_attn import block_sparse_attn_func"
```
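An alternative to importing torch first is to put torch's bundled library directory on `LD_LIBRARY_PATH` (a sketch; assumes torch is installed in the active environment):

```shell
# libc10.so ships in <site-packages>/torch/lib; prepend that directory.
TORCH_LIB="$(python3 -c 'import torch, os; print(os.path.join(os.path.dirname(torch.__file__), "lib"))' 2>/dev/null || true)"
export LD_LIBRARY_PATH="${TORCH_LIB}${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"
# Now the direct import should resolve libc10.so:
#   python3 -c "from block_sparse_attn import block_sparse_attn_func"
```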

## flash_attention

Imports (each optional attention backend is probed with a try/except so that missing packages degrade gracefully; `torch`, `torch.nn.functional`, and `einops.rearrange` are added here because the function below uses them):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

from block_sparse_attn import block_sparse_attn_func
```
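The same guarded-import pattern generalizes to any optional dependency. A tiny helper (illustrative only, not part of the original code) that probes whether a backend is importable:

```python
import importlib


def backend_available(module_name: str) -> bool:
    """Return True if the optional backend module can be imported."""
    try:
        importlib.import_module(module_name)
        return True
    except ModuleNotFoundError:
        return False


# A stdlib module is importable; a made-up name is not:
print(backend_available("json"), backend_available("no_such_backend_xyz"))  # → True False
```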
The dispatcher itself tries the backends in priority order: block-sparse (when a mask is supplied), then FlashAttention-3, FlashAttention-2, SageAttention, and finally PyTorch's `scaled_dot_product_attention`:

```python
def flash_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, num_heads: int,
                    compatibility_mode=False, attention_mask=None, return_KV=False):
    if attention_mask is not None:
        # Block-sparse path. cu_seqlens describe a single packed sequence,
        # so this branch assumes batch size 1.
        seqlen = q.shape[1]
        seqlen_kv = k.shape[1]
        q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
        cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
        cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
        # 1 = block-sparse attention for every head
        head_mask_type = torch.tensor([1] * num_heads, device=q.device, dtype=torch.int32)
        streaming_info = None
        base_blockmask = attention_mask  # block-level mask, not per-token
        max_seqlen_q_ = seqlen
        max_seqlen_k_ = seqlen_kv
        p_dropout = 0.0
        x = block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            streaming_info,
            base_blockmask,
            max_seqlen_q_, max_seqlen_k_,
            p_dropout,
            deterministic=False,
            softmax_scale=None,
            is_causal=False,
            exact_streaming=False,
            return_attn_probs=False,
        ).unsqueeze(0)  # restore the batch dimension
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif compatibility_mode:
        # Pure-PyTorch fallback, forced by the caller.
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        # FA3 may return (out, lse); keep only the output tensor.
        x = flash_attn_interface.flash_attn_func(q, k, v)
        if isinstance(x, tuple):
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        # Last resort: PyTorch's built-in scaled dot-product attention.
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
```
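Note that `base_blockmask` is a block-level mask: one entry per (query-block, key-block) pair rather than per token. Assuming a block granularity of 128 tokens (treat the exact value as an assumption; check Block-Sparse-Attention's documentation for your build), sizing the mask reduces to ceiling division:

```python
BLOCK_SIZE = 128  # assumed block granularity of the kernels


def num_blocks(seqlen: int, block_size: int = BLOCK_SIZE) -> int:
    """Number of blocks needed to cover a sequence (ceiling division)."""
    return (seqlen + block_size - 1) // block_size


# A 1000-token query attending over a 4096-token key sequence
# needs an 8 x 32 block mask:
print(num_blocks(1000), num_blocks(4096))  # → 8 32
```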