TransnormerLLM 中 FlashLinearAttention 的纯pytorch实现

Github 仓库:https://github.com/One-sixth/flash-linear-attention-pytorch

flash-linear-attention-pytorch

纯 Pytorch 实现 TransnormerLLM 中快速线性注意力算子。

用于学习目的。

如果你希望用于训练模型,你可能要修改为 CUDA 或 Triton 的实现,不然会很慢。

注意

这个算子有精度问题,误差较大,是正常的。

这是因为注意力矩阵没有激活函数,导致注意力矩阵的值很大。

在使用 float16 类型时需要特别小心。

这是一个简单的缓解方法:限制 q 和 k 的值,从而减少float16溢出的可能性。

python 复制代码
q = q / q.norm(-1, keepdim=True)
k = k / k.norm(-1, keepdim=True)
o = linear_attention(q, k, v, m)

使用方法

python 复制代码
import torch
from flash_linear_attention_ops import flash_linear_attention, normal_linear_attention


batch_size = 16
seq_len = 1024
dim = 64
n_head = 12
device = 'cuda'
dtype = torch.float32


Q = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
K = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
V = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
M = torch.randint(0, 2, (1, 1, seq_len, seq_len), device=device, dtype=dtype)

O_flash = flash_linear_attention(Q, K, V, M)
O_normal = normal_linear_attention(Q, K, V, M)

print('O_flash.shape', O_flash.shape)
print('O_normal.shape', O_normal.shape)

print('O diff', (O_flash - O_normal).abs().max().item())

参考引用

https://github.com/OpenNLPLab/TransnormerLLM

https://github.com/shreyansh26/FlashAttention-PyTorch

相关推荐
知乎的哥廷根数学学派4 小时前
面向可信机械故障诊断的自适应置信度惩罚深度校准算法(Pytorch)
人工智能·pytorch·python·深度学习·算法·机器学习·矩阵
且去填词4 小时前
DeepSeek :基于 Schema 推理与自愈机制的智能 ETL
数据仓库·人工智能·python·语言模型·etl·schema·deepseek
人工干智能4 小时前
OpenAI Assistants API 中 client.beta.threads.messages.create方法,兼谈一星*和两星**解包
python·llm
databook5 小时前
当条形图遇上极坐标:径向与圆形条形图的视觉革命
python·数据分析·数据可视化
Hcoco_me5 小时前
大模型面试题61:Flash Attention中online softmax(在线softmax)的实现方式
人工智能·深度学习·自然语言处理·transformer·vllm
阿部多瑞 ABU5 小时前
`chenmo` —— 可编程元叙事引擎 V2.3+
linux·人工智能·python·ai写作
acanab5 小时前
VScode python插件
ide·vscode·python
知乎的哥廷根数学学派6 小时前
基于生成对抗U-Net混合架构的隧道衬砌缺陷地质雷达数据智能反演与成像方法(以模拟信号为例,Pytorch)
开发语言·人工智能·pytorch·python·深度学习·机器学习
WangYaolove13146 小时前
Python基于大数据的电影市场预测分析(源码+文档)
python·django·毕业设计·源码
知乎的哥廷根数学学派7 小时前
基于自适应多尺度小波核编码与注意力增强的脉冲神经网络机械故障诊断(Pytorch)
人工智能·pytorch·python·深度学习·神经网络·机器学习