TransnormerLLM 中 FlashLinearAttention 的纯pytorch实现

Github 仓库:https://github.com/One-sixth/flash-linear-attention-pytorch

flash-linear-attention-pytorch

纯 Pytorch 实现 TransnormerLLM 中快速线性注意力算子。

用于学习目的。

如果你希望用于训练模型,你可能要修改为 CUDA 或 Triton 的实现,不然会很慢。

注意

这个算子有精度问题,误差较大,是正常的。

这是因为注意力矩阵没有激活函数,导致注意力矩阵的值很大。

在使用 float16 类型时需要特别小心。

这是一个简单的缓解方法:限制 q 和 k 的值,从而减少float16溢出的可能性。

python 复制代码
q = q / q.norm(-1, keepdim=True)
k = k / k.norm(-1, keepdim=True)
o = linear_attention(q, k, v, m)

使用方法

python 复制代码
import torch
from flash_linear_attention_ops import flash_linear_attention, normal_linear_attention


batch_size = 16
seq_len = 1024
dim = 64
n_head = 12
device = 'cuda'
dtype = torch.float32


Q = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
K = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
V = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
M = torch.randint(0, 2, (1, 1, seq_len, seq_len), device=device, dtype=dtype)

O_flash = flash_linear_attention(Q, K, V, M)
O_normal = normal_linear_attention(Q, K, V, M)

print('O_flash.shape', O_flash.shape)
print('O_normal.shape', O_normal.shape)

print('O diff', (O_flash - O_normal).abs().max().item())

参考引用

https://github.com/OpenNLPLab/TransnormerLLM

https://github.com/shreyansh26/FlashAttention-PyTorch

相关推荐
AI模块工坊5 分钟前
CVPR 即插即用 | 当RetNet遇见ViT:一场来自曼哈顿的注意力革命,中科院刷新SOTA性能榜!
人工智能·深度学习·计算机视觉·transformer
*才华有限公司*23 分钟前
基于BERT的文本分类模型训练全流程:从环境搭建到显存优化实战
python
Lxinccode1 小时前
python(59) : 多线程调用大模型ocr提取图片文本
开发语言·python·图片提取文字·批量提取文件·多线程ocr
梁辰兴1 小时前
PyCharm使用了Conda的虚拟环境创建的的Python项目,下载库(包)到该项目的虚拟环境中
python·pycharm·conda·错误·异常·异常报错
自由日记1 小时前
python简单线性回归
开发语言·python·线性回归
Halo_tjn2 小时前
Set集合专项实验
java·开发语言·前端·python
vvoennvv3 小时前
【Python TensorFlow】 BiTCN-LSTM双向时间序列卷积长短期记忆神经网络时序预测算法(附代码)
python·神经网络·tensorflow·lstm·tcn
q***42053 小时前
python的sql解析库-sqlparse
数据库·python·sql
大数据追光猿4 小时前
LangChain / LangGraph / AutoGPT / CrewAI / AutoGen 五大框架对比
经验分享·笔记·python·langchain·agent
wang_yb4 小时前
别急着转投 Polars!Pandas 3.0 带着“黑科技”杀回来了
python·databook