Table of Contents
- 1. Finding the Entry Point
- 2. Setting Up the libtorch Environment
- 3. Fusing the Three Matrix Multiplications
- 4. Naive & Safe Softmax
  - 1) naive softmax
  - 2) safe softmax
- 5. Online Softmax (Rescaled Version)
- 6. Fusing Online Softmax with the Value Dot Product
- 7. CUDA Kernel Walkthrough
1. Finding the Entry Point
- Background
  In attention, Q is multiplied by the transpose of K, a softmax is applied, and the result is multiplied by V. This is widely seen as expensive because of the memory traffic: the Q*K^T product is written back to GPU memory, read again for the softmax, written back again, and read once more for the multiplication with V.
- Why flash attention exists
  Flash attention sets out to remove these repeated memory round trips: with a single read and a single write-back it completes one matmul + one softmax + one matmul.
- Source code
  llama.cpp contains a flash attention implementation.
2. Setting Up the libtorch Environment
- Download link: the PyTorch official website
- Let CMake find the torch library: copy the corresponding cmake path into CMAKE_PREFIX_PATH
- Demonstration of building the code
- Plugin installation
- CMakeLists.txt:
```cmake
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(example-app LANGUAGES CXX CUDA)
# Set the build type to Debug
set(CMAKE_BUILD_TYPE Debug)
# Add debug flags
set(CMAKE_CXX_FLAGS_DEBUG "-g -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "-g -G -O0")
set(CMAKE_PREFIX_PATH "/home/kuanglixiang/learn_flashtn/libtorch_w/libtorch/share/cmake/Torch;${CMAKE_PREFIX_PATH}")
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
set(CMAKE_PREFIX_PATH "/home/kuanglixiang/learn_flashtn/libtorch_w/libtorch/share/cmake/Torch;${CMAKE_PREFIX_PATH}")
#add_executable(example-app example-app.cpp)
add_executable(flash_attn_main flash_atten_main.cpp flash.cu)
target_link_libraries(flash_attn_main "${TORCH_LIBRARIES}")
set_property(TARGET flash_attn_main PROPERTY CXX_STANDARD 17)
set_property(TARGET flash_attn_main PROPERTY CUDA_STANDARD 17)
# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
# if (MSVC)
#   file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
#   add_custom_command(TARGET example-app
#                      POST_BUILD
#                      COMMAND ${CMAKE_COMMAND} -E copy_if_different
#                      ${TORCH_DLLS}
#                      $<TARGET_FILE_DIR:example-app>)
# endif (MSVC)
```
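To sanity-check the setup before touching the CUDA kernel, a minimal libtorch program along the following lines can be built by pointing add_executable at it instead. This is only a sketch: check.cpp is a hypothetical file name and the code is plain, unfused attention, not the flash_atten_main.cpp from the repository.

```cpp
// check.cpp -- hypothetical sanity check for the libtorch setup (not part of the original repo)
#include <torch/torch.h>
#include <cmath>
#include <iostream>

int main() {
    // Plain, unfused attention: Q*K^T -> softmax -> *V
    torch::Tensor q = torch::randn({8, 64});
    torch::Tensor k = torch::randn({8, 64});
    torch::Tensor v = torch::randn({8, 64});

    torch::Tensor scores = torch::matmul(q, k.transpose(0, 1)) / std::sqrt(64.0);
    torch::Tensor attn   = torch::softmax(scores, /*dim=*/-1);
    torch::Tensor out    = torch::matmul(attn, v);

    std::cout << "output shape: " << out.sizes() << std::endl;  // expect [8, 64]
    return 0;
}
```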
3. Fusing the Three Matrix Multiplications

- Flow
  After the Q*K multiplication the result is not written back to global memory but kept in shared memory; V is then read, the second multiplication is performed, and only the final result is written back (the operation being fused is written out after this list).
- Difficulty
  The softmax result must be fully available before it can be multiplied by V.
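For reference, the operation being fused is standard scaled dot-product attention, with $d_k$ the head dimension:

$$
O = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_k}}\right) V
$$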
4. Naive & Safe Softmax
1) naive softmax

Drawback: the exponential is numerically unstable; the values easily overflow, and once the inputs exceed a certain range the computed precision degrades.
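For reference, the naive form exponentiates the raw inputs directly, which is exactly where the overflow comes from:

$$
\mathrm{softmax}(x)_i = \frac{e^{x_i}}{\sum_{j=1}^{n} e^{x_j}}
$$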
2) safe softmax

Drawback: it sweeps over the array three times (find the maximum, accumulate the exponentials, normalize), so it costs more than the naive two-pass version.
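Subtracting the global maximum keeps every exponent non-positive, which is what buys the stability at the price of the extra pass:

$$
\mathrm{softmax}(x)_i = \frac{e^{x_i - m}}{\sum_{j=1}^{n} e^{x_j - m}}, \qquad m = \max_j x_j
$$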
- Code for both versions

```cpp
// naive softmax
#include <vector>
#include <cmath>
#include <iostream>
#include <algorithm>

std::vector<float> naive_softmax(const std::vector<float>& input) {
    std::vector<float> output;
    output.reserve(input.size());
    float sum_exp = 0.0f;
    // Calculate the sum of exponentials directly on the raw inputs
    for (float val : input) {
        sum_exp += std::exp(val); // can overflow for large inputs
    }
    // Calculate the softmax values
    for (float val : input) {
        output.push_back(std::exp(val) / sum_exp);
    }
    return output;
}

std::vector<float> safe_softmax(const std::vector<float>& input) {
    std::vector<float> output;
    output.reserve(input.size());
    // Find the maximum value for numerical stability
    float max_val = *std::max_element(input.begin(), input.end());
    float sum_exp = 0.0f;
    // Calculate the sum of exponentials with numerical stability
    for (float val : input) {
        sum_exp += std::exp(val - max_val); // Subtract max_val for numerical stability
    }
    // Calculate the softmax values
    for (float val : input) {
        output.push_back(std::exp(val - max_val) / sum_exp);
    }
    return output;
}

int main() {
}
```
5. Online Softmax (Rescaled Version)
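In formula form, the update that the code below implements keeps a running maximum $m_i$ and a running denominator $d_i$, rescaling the partial sum whenever the maximum changes, so a single pass over the data suffices:

$$
m_i = \max(m_{i-1}, x_i), \qquad d_i = d_{i-1}\, e^{\,m_{i-1} - m_i} + e^{\,x_i - m_i}
$$

with $m_0 = -\infty$ and $d_0 = 0$; after the pass, $\mathrm{softmax}(x)_i = e^{\,x_i - m_n} / d_n$.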


```cpp
std::vector<float> online_softmax(const std::vector<float>& input) {
    std::vector<float> output;
    output.reserve(input.size());
    float max_val = -INFINITY;   // running maximum (INFINITY comes from <cmath>)
    float prev_max_value = 0.0f; // maximum from the previous step
    float sum_exp = 0.0f;        // running denominator
    // Single pass: update the max and rescale the partial sum whenever it changes
    for (size_t i = 0; i < input.size(); ++i) {
        max_val = std::max(max_val, input[i]);
        sum_exp = sum_exp * std::exp(prev_max_value - max_val) + std::exp(input[i] - max_val);
        prev_max_value = max_val;
    }
    // Second pass only normalizes; no further max/sum work is needed
    for (size_t i = 0; i < input.size(); i++)
    {
        output.push_back(std::exp(input[i] - max_val) / sum_exp);
    }
    return output;
}
```
- Output comparison of the three versions

```shell
(base) k@Skynet:~/learn_flashtn/flash-attention-minimal/build$ ./naive_softmax
9.58094e-05 0.000451403 3.56126e-09 0.00259765 0.996855
9.58094e-05 0.000451403 3.56127e-09 0.00259765 0.996855
9.58094e-05 0.000451403 3.56126e-09 0.00259765 0.996855
```
- Comparison function

```cpp
void print_vector(const std::vector<float>& vec) {
    for (const auto& val : vec) {
        std::cout << val << " ";
    }
    std::cout << std::endl;
}

int main() {
    std::vector<float> input = {1.0f, 2.55f, -9.2f, 4.3f, 10.25f};
    std::vector<float> result = safe_softmax(input);
    print_vector(result);
    std::vector<float> result2 = naive_softmax(input);
    print_vector(result2);
    std::vector<float> result3 = online_softmax(input);
    print_vector(result3);
    return 0;
}
```
6. Fusing Online Softmax with the Value Dot Product
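A minimal sketch of the idea, under my own simplifications (scalar values instead of value vectors, a single attention row, no tiling; the function name online_softmax_dot is mine, not from the original code): the online softmax recurrence from the previous section is extended with an accumulator for the weighted sum of values, which is rescaled together with the denominator, so the softmax and the multiplication by value finish in the same single pass.

```cpp
#include <vector>
#include <cmath>
#include <algorithm>
#include <iostream>

// One-pass softmax(scores) . values for a single row.
// acc holds the un-normalized weighted sum and is rescaled
// together with the denominator whenever the running max changes.
float online_softmax_dot(const std::vector<float>& scores,
                         const std::vector<float>& values) {
    float max_val = -INFINITY;  // running maximum of the scores
    float sum_exp = 0.0f;       // running softmax denominator
    float acc     = 0.0f;       // running sum of exp(score - max) * value
    for (size_t i = 0; i < scores.size(); ++i) {
        float new_max = std::max(max_val, scores[i]);
        float rescale = std::exp(max_val - new_max);  // 0 on the first step
        float w       = std::exp(scores[i] - new_max);
        sum_exp = sum_exp * rescale + w;
        acc     = acc * rescale + w * values[i];
        max_val = new_max;
    }
    return acc / sum_exp;  // normalize once at the very end
}

int main() {
    std::vector<float> scores = {1.0f, 2.55f, -9.2f, 4.3f, 10.25f};
    std::vector<float> values = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f};
    std::cout << online_softmax_dot(scores, values) << std::endl;
    return 0;
}
```

The result matches running safe_softmax first and then taking the dot product with the values, which is exactly the property that lets flash attention process a row block by block with a single read and a single write-back.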

