TransnormerLLM 中 FlashLinearAttention 的纯pytorch实现

Github 仓库:https://github.com/One-sixth/flash-linear-attention-pytorch

flash-linear-attention-pytorch

纯 Pytorch 实现 TransnormerLLM 中快速线性注意力算子。

用于学习目的。

如果你希望用于训练模型,你可能要修改为 CUDA 或 Triton 的实现,不然会很慢。

注意

这个算子有精度问题,误差较大,是正常的。

这是因为注意力矩阵没有激活函数,导致注意力矩阵的值很大。

在使用 float16 类型时需要特别小心。

这是一个简单的缓解方法:限制 q 和 k 的值,从而减少float16溢出的可能性。

python 复制代码
q = q / q.norm(-1, keepdim=True)
k = k / k.norm(-1, keepdim=True)
o = linear_attention(q, k, v, m)

使用方法

python 复制代码
import torch
from flash_linear_attention_ops import flash_linear_attention, normal_linear_attention


batch_size = 16
seq_len = 1024
dim = 64
n_head = 12
device = 'cuda'
dtype = torch.float32


Q = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
K = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
V = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
M = torch.randint(0, 2, (1, 1, seq_len, seq_len), device=device, dtype=dtype)

O_flash = flash_linear_attention(Q, K, V, M)
O_normal = normal_linear_attention(Q, K, V, M)

print('O_flash.shape', O_flash.shape)
print('O_normal.shape', O_normal.shape)

print('O diff', (O_flash - O_normal).abs().max().item())

参考引用

https://github.com/OpenNLPLab/TransnormerLLM

https://github.com/shreyansh26/FlashAttention-PyTorch

相关推荐
铁皮哥8 分钟前
【后端/Agent 开发】给你的项目配置一套 .claude/ 工作流:别再裸用 Claude Code 了!
java·windows·python·spring·github·maven·生活
m0_6315298225 分钟前
CSS如何利用CSS变量进行渐变色管理_提升渐变配置的灵活性
jvm·数据库·python
2301_8180084439 分钟前
数据库模型设计实战:如何正向工程从模型建表_规范化项目开发流程
jvm·数据库·python
科研前沿43 分钟前
多视角相机驱动的室内人员空间定位技术白皮书
大数据·人工智能·python·科技·数码相机·音视频
覆东流1 小时前
第10天:python元组
开发语言·后端·python
万事大吉CC1 小时前
【5】Django 的模板语言:页面架构设计
后端·python·django
罗西的思考1 小时前
【GUI-Agent】阿里通义MAI-UI 代码阅读(1)— 总体
人工智能·机器学习·ui·transformer
码界奇点2 小时前
基于Python的微信公众号爬虫系统设计与实现
开发语言·爬虫·python·毕业设计·web·源代码管理
2401_846339562 小时前
Vue 3 中集成 Three.js 场景的完整实现指南
jvm·数据库·python
落雪寒窗-2 小时前
Python开发个人日常记录
开发语言·python