Table of Contents

- [Installing block_sparse_attn](#installing-block_sparse_attn)
- [block_sparse_attn installed successfully](#block_sparse_attn-installed-successfully)
## Installing block_sparse_attn

Solution:

```bash
pip install xformers
```

or:

```bash
pip install flash-attn   # pip treats this the same as "pip install flash_attn"
```
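A quick sanity check that the prerequisite actually imports (a minimal sketch, assuming you chose flash-attn; swap in `xformers` otherwise):

```python
# Check that the prerequisite package installed correctly.
import flash_attn
print(flash_attn.__version__)
```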
## block_sparse_attn installed successfully

Build and install from source (adjust `TORCH_CUDA_ARCH_LIST` to match your GPU architectures):

```bash
cd Block-Sparse-Attention
rm -rf build *.egg-info
export TORCH_CUDA_ARCH_LIST="8.0;8.6;8.9"
pip install -e . --no-build-isolation
```
Import test:

```bash
python -c "from block_sparse_attn import block_sparse_attn_func"
```

This fails with:

```
ImportError: libc10.so: cannot open shared object file: No such file or directory
```

Solution: `libc10.so` is shipped with PyTorch, so import `torch` first so the library is already loaded when the CUDA extension is initialized:

```bash
python -c "import torch; from block_sparse_attn import block_sparse_attn_func"
```
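The same ordering applies inside your own scripts; a minimal sketch:

```python
# Import torch first so libc10.so (bundled with PyTorch) is loaded
# before the block_sparse_attn CUDA extension initializes.
import torch
from block_sparse_attn import block_sparse_attn_func
```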
## flash_attention

Required imports (`torch`, `torch.nn.functional`, and `einops.rearrange` are needed by the function below):

```python
import torch
import torch.nn.functional as F
from einops import rearrange

try:
    import flash_attn_interface
    FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_3_AVAILABLE = False

try:
    import flash_attn
    FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
    FLASH_ATTN_2_AVAILABLE = False

try:
    from sageattention import sageattn
    SAGE_ATTN_AVAILABLE = True
except ModuleNotFoundError:
    SAGE_ATTN_AVAILABLE = False

from block_sparse_attn import block_sparse_attn_func
```
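A small helper, not part of the original code, just to illustrate which backend the availability flags above will select (it assumes those flags are in scope and mirrors the branch order of `flash_attention` below):

```python
def report_backend() -> str:
    """Report the dense backend flash_attention would dispatch to."""
    if FLASH_ATTN_3_AVAILABLE:
        return "FlashAttention-3 (flash_attn_interface)"
    if FLASH_ATTN_2_AVAILABLE:
        return "FlashAttention-2 (flash_attn)"
    if SAGE_ATTN_AVAILABLE:
        return "SageAttention (sageattn)"
    return "PyTorch scaled_dot_product_attention fallback"

print(report_backend())
```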
The dispatch function itself: it uses the block-sparse kernel when an `attention_mask` (block mask) is supplied, and otherwise falls back to the best available dense backend.

```python
def flash_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    num_heads: int,
    compatibility_mode=False,
    attention_mask=None,
    return_KV=False,
):
    if attention_mask is not None:
        # Block-sparse path: pack (batch, seq) into a single varlen sequence
        # (assumes batch size 1) and pass attention_mask as the block mask.
        seqlen = q.shape[1]
        seqlen_kv = k.shape[1]
        q = rearrange(q, "b s (n d) -> (b s) n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> (b s) n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> (b s) n d", n=num_heads)
        cu_seqlens_q = torch.tensor([0, seqlen], device=q.device, dtype=torch.int32)
        cu_seqlens_k = torch.tensor([0, seqlen_kv], device=q.device, dtype=torch.int32)
        # type 1: every head uses the block-sparse mask
        head_mask_type = torch.tensor([1] * num_heads, device=q.device, dtype=torch.int32)
        streaming_info = None
        base_blockmask = attention_mask
        max_seqlen_q_ = seqlen
        max_seqlen_k_ = seqlen_kv
        p_dropout = 0.0
        x = block_sparse_attn_func(
            q, k, v,
            cu_seqlens_q, cu_seqlens_k,
            head_mask_type,
            streaming_info,
            base_blockmask,
            max_seqlen_q_, max_seqlen_k_,
            p_dropout,
            deterministic=False,
            softmax_scale=None,
            is_causal=False,
            exact_streaming=False,
            return_attn_probs=False,
        ).unsqueeze(0)  # restore the batch dimension dropped by varlen packing
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif compatibility_mode:
        # Plain PyTorch SDPA, for environments where no fused kernel applies.
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_3_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn_interface.flash_attn_func(q, k, v)
        if isinstance(x, tuple):  # some FA3 versions also return the LSE
            x = x[0]
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif FLASH_ATTN_2_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b s n d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b s n d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b s n d", n=num_heads)
        x = flash_attn.flash_attn_func(q, k, v)
        x = rearrange(x, "b s n d -> b s (n d)", n=num_heads)
    elif SAGE_ATTN_AVAILABLE:
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = sageattn(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    else:
        # Fallback: PyTorch SDPA.
        q = rearrange(q, "b s (n d) -> b n s d", n=num_heads)
        k = rearrange(k, "b s (n d) -> b n s d", n=num_heads)
        v = rearrange(v, "b s (n d) -> b n s d", n=num_heads)
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b n s d -> b s (n d)", n=num_heads)
    return x
```
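A hypothetical smoke test for the dense path (`attention_mask=None`); the shapes, dtype, and device below are illustrative and not from the original, and a CUDA GPU is required for the fused backends:

```python
# q/k/v are packed as (batch, seq, num_heads * head_dim); shapes are illustrative.
batch, seq, num_heads, head_dim = 1, 1024, 8, 64
q = torch.randn(batch, seq, num_heads * head_dim, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = flash_attention(q, k, v, num_heads=num_heads)
print(out.shape)  # expected: torch.Size([1, 1024, 512])
```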