TransnormerLLM 中 FlashLinearAttention 的纯pytorch实现

Github 仓库:https://github.com/One-sixth/flash-linear-attention-pytorch

flash-linear-attention-pytorch

纯 Pytorch 实现 TransnormerLLM 中快速线性注意力算子。

用于学习目的。

如果你希望用于训练模型,你可能要修改为 CUDA 或 Triton 的实现,不然会很慢。

注意

这个算子有精度问题,误差较大,是正常的。

这是因为注意力矩阵没有激活函数,导致注意力矩阵的值很大。

在使用 float16 类型时需要特别小心。

这是一个简单的缓解方法:限制 q 和 k 的值,从而减少float16溢出的可能性。

python 复制代码
q = q / q.norm(-1, keepdim=True)
k = k / k.norm(-1, keepdim=True)
o = linear_attention(q, k, v, m)

使用方法

python 复制代码
import torch
from flash_linear_attention_ops import flash_linear_attention, normal_linear_attention


batch_size = 16
seq_len = 1024
dim = 64
n_head = 12
device = 'cuda'
dtype = torch.float32


Q = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
K = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
V = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
M = torch.randint(0, 2, (1, 1, seq_len, seq_len), device=device, dtype=dtype)

O_flash = flash_linear_attention(Q, K, V, M)
O_normal = normal_linear_attention(Q, K, V, M)

print('O_flash.shape', O_flash.shape)
print('O_normal.shape', O_normal.shape)

print('O diff', (O_flash - O_normal).abs().max().item())

参考引用

https://github.com/OpenNLPLab/TransnormerLLM

https://github.com/shreyansh26/FlashAttention-PyTorch

相关推荐
火云洞红孩儿6 小时前
告别界面孤岛:PyMe如何用一站式流程重塑Python GUI开发?
开发语言·python
抓个马尾女孩6 小时前
为什么self-attention除以根号dk而不是其他值
人工智能·深度学习·机器学习·transformer
攻城狮7号6 小时前
不懂代码也能造?TRAE+GLM-4.6 手把手教你搭心理咨询智能客服小程序
python·小程序·uni-app·vue·trae·glm我的编程搭子·glm-4.6
叫我辉哥e16 小时前
新手进阶Python:办公看板集成ERP跨系统同步+自动备份+AI异常复盘
开发语言·人工智能·python
布局呆星6 小时前
闭包与装饰器
开发语言·python
全栈测试笔记7 小时前
异步函数与异步生成器
linux·服务器·前端·数据库·python
木头左7 小时前
基于Backtrader框架的指数期权备兑策略实现与分析
python
素心如月桠7 小时前
cmd 输入 python --version 输出为空(windows11系统安装python后执行python --version没反应)
python
飞Link8 小时前
深度解析 HyperLPR:高性能中文车牌识别框架从入门到实战
python
QQ588501988 小时前
Python_uniapp-心理健康测评服务微信小程序的设计与实现
python·微信小程序·uni-app