TransnormerLLM 中 FlashLinearAttention 的纯pytorch实现

Github 仓库:https://github.com/One-sixth/flash-linear-attention-pytorch

flash-linear-attention-pytorch

纯 Pytorch 实现 TransnormerLLM 中快速线性注意力算子。

用于学习目的。

如果你希望用于训练模型,你可能要修改为 CUDA 或 Triton 的实现,不然会很慢。

注意

这个算子有精度问题,误差较大,是正常的。

这是因为注意力矩阵没有激活函数,导致注意力矩阵的值很大。

在使用 float16 类型时需要特别小心。

这是一个简单的缓解方法:限制 q 和 k 的值,从而减少float16溢出的可能性。

python 复制代码
q = q / q.norm(-1, keepdim=True)
k = k / k.norm(-1, keepdim=True)
o = linear_attention(q, k, v, m)

使用方法

python 复制代码
import torch
from flash_linear_attention_ops import flash_linear_attention, normal_linear_attention


batch_size = 16
seq_len = 1024
dim = 64
n_head = 12
device = 'cuda'
dtype = torch.float32


Q = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
K = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
V = torch.randn(batch_size, n_head, seq_len, dim, requires_grad=True, dtype=dtype, device=device)
M = torch.randint(0, 2, (1, 1, seq_len, seq_len), device=device, dtype=dtype)

O_flash = flash_linear_attention(Q, K, V, M)
O_normal = normal_linear_attention(Q, K, V, M)

print('O_flash.shape', O_flash.shape)
print('O_normal.shape', O_normal.shape)

print('O diff', (O_flash - O_normal).abs().max().item())

参考引用

https://github.com/OpenNLPLab/TransnormerLLM

https://github.com/shreyansh26/FlashAttention-PyTorch

相关推荐
机器学习之心3 分钟前
MATLAB基于GWO优化Transformer多输入多输出回归预测与改进NSGA III的多目标优化
transformer·gwo-transformer·多输入多输出回归预测·改进nsgaiii的多目标优化
彼岸花开了吗6 分钟前
构建AI智能体:七十八、参数的艺术:如何在有限算力下实现高质量的AI诗歌创作
人工智能·python·llm
guoketg19 分钟前
Vision Transformer(ViT)的讲解和面试题目讲解
人工智能·python·深度学习·vit
小oo呆26 分钟前
【学习心得】Python的Pydantic(简介)
前端·javascript·python
岚天start27 分钟前
【日志监控方案】Python脚本获取关键字日志信息并推送钉钉告警
python·钉钉·日志监控
叫我:松哥29 分钟前
基于 Flask 框架开发的在线学习平台,集成人工智能技术,提供分类练习、随机练习、智能推荐等多种学习模式
人工智能·后端·python·学习·信息可视化·flask·推荐算法
rgeshfgreh29 分钟前
Python环境管理:uv极速对决Conda全能
python
幻云201030 分钟前
Python机器学习:从入门到精通
python
热爱专研AI的学妹38 分钟前
2026世界杯观赛工具自制指南:实时比分推送机器人搭建思路
开发语言·人工智能·python·业界资讯
热心不起来的市民小周41 分钟前
测测你的牌:基于 MobileNetV2 的车牌内容检测
python·深度学习·计算机视觉