Flash Attention Study Notes


1. Finding an Entry Point

  • Background
    An attention operation multiplies Q by the transpose of K, runs the result through a softmax, and multiplies that by V. Many people attribute the cost to memory traffic: the Q·Kᵀ product is written back to GPU memory, read again for the softmax, written back, and read once more for the multiply with V (a sketch of this three-pass pipeline follows the list).
  • Why it appeared
    Flash Attention sets out to eliminate this repeated memory traffic: read the inputs once and write the result once, completing one matmul + one softmax + one matmul in a single fused pass.
  • Source code
    llama.cpp contains a Flash Attention implementation.
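
To make the memory traffic concrete, here is a minimal libtorch sketch of the naive three-pass pipeline (the function name and layout are my own illustration, not from the notes): each step is its own kernel launch, so the intermediates S and P round-trip through global memory between steps.

```cpp
#include <torch/torch.h>

// Naive attention: three separate kernels, with the intermediates
// S = Q*K^T and P = softmax(S) materialized in global memory.
torch::Tensor naive_attention(const torch::Tensor& Q,
                              const torch::Tensor& K,
                              const torch::Tensor& V) {
    auto S = torch::matmul(Q, K.transpose(-2, -1));  // write S to GPU memory
    auto P = torch::softmax(S, /*dim=*/-1);          // read S, write P
    return torch::matmul(P, V);                      // read P, write O
}
```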

2. Setting Up the libtorch Environment

  • Download link

    PyTorch official website

  • Have CMake find the Torch library: copy the corresponding cmake path (the share/cmake/Torch directory inside the unpacked libtorch)

  • Build commands (a minimal smoke-test program is sketched after the CMakeLists below)

  • Plugin installation

  • CMake configuration

CMakeLists.txt:

```cmake
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)
project(example-app LANGUAGES CXX CUDA)

# Build in debug mode
set(CMAKE_BUILD_TYPE Debug)

# Add debug flags
set(CMAKE_CXX_FLAGS_DEBUG "-g -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "-g -G -O0")

set(CMAKE_PREFIX_PATH "/home/kuanglixiang/learn_flashtn/libtorch_w/libtorch/share/cmake/Torch;${CMAKE_PREFIX_PATH}")
find_package(Torch REQUIRED)
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
#add_executable(example-app example-app.cpp)
add_executable(flash_attn_main flash_atten_main.cpp flash.cu)
target_link_libraries(flash_attn_main "${TORCH_LIBRARIES}")
set_property(TARGET flash_attn_main PROPERTY CXX_STANDARD 17)
set_property(TARGET flash_attn_main PROPERTY CUDA_STANDARD 17)

# The following code block is suggested to be used on Windows.
# According to https://github.com/pytorch/pytorch/issues/25457,
# the DLLs need to be copied to avoid memory errors.
# if (MSVC)
#   file(GLOB TORCH_DLLS "${TORCH_INSTALL_PREFIX}/lib/*.dll")
#   add_custom_command(TARGET example-app
#                      POST_BUILD
#                      COMMAND ${CMAKE_COMMAND} -E copy_if_different
#                      ${TORCH_DLLS}
#                      $<TARGET_FILE_DIR:example-app>)
# endif (MSVC)
```
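
Before going further it is worth confirming the toolchain end to end. Here is a minimal smoke test of my own (not from the original notes) that could serve as the example-app.cpp referenced in the commented-out add_executable above:

```cpp
#include <torch/torch.h>
#include <iostream>

int main() {
    // A random tensor: verifies that libtorch links and runs.
    torch::Tensor t = torch::rand({2, 3});
    std::cout << t << std::endl;

    // Check that the CUDA backend is visible before debugging kernels.
    std::cout << "CUDA available: " << std::boolalpha
              << torch::cuda::is_available() << std::endl;
    return 0;
}
```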

3. Fusing the Three-Matrix Product

  • Flow
    The Q·Kᵀ product is not written back to GPU memory after the multiply; it stays in shared memory, V is read and multiplied in there, and only the final result is written back (the target computation is written out just after this list).
  • Difficulty
    V can only be multiplied in once the softmax result is available, and the softmax denominator (plus the row max used for stability) depends on the entire row, which a tile-at-a-time kernel never holds all at once.
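
For reference, the computation being fused is standard scaled dot-product attention (the 1/√d_k scale is part of the usual definition, though the notes above omit it):

$$
O=\operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$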

4. Naive & Safe Softmax

1) Naive softmax

Drawback: exp is numerically unstable. Values overflow easily: for 32-bit floats, e^x overflows once x exceeds roughly 88, and precision degrades well before that for large-magnitude inputs.

2) Safe softmax

Drawback: it makes three passes over the array (row max, sum of exponentials, normalization); the asymptotic complexity is still O(n), but it moves more data (both formulas are written out below).
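
Written out, the two variants are (subtracting the row max m changes nothing mathematically, since the factor e^{-m} cancels between numerator and denominator):

$$
\text{naive: } y_i=\frac{e^{x_i}}{\sum_j e^{x_j}} \qquad\qquad \text{safe: } y_i=\frac{e^{x_i-m}}{\sum_j e^{x_j-m}},\; m=\max_j x_j
$$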

  • Code for both
```cpp
//naive softmax
#include <vector>
#include <cmath>
#include <iostream>
#include <algorithm>

std::vector<float> naive_softmax(const std::vector<float>& input) {
    std::vector<float> output;
    output.reserve(input.size());

    float sum_exp = 0.0f;

    // Sum of exponentials. Note: no max subtraction here, so large
    // inputs overflow; this is exactly the instability described above.
    for (float val : input) {
        sum_exp += std::exp(val);
    }


    // Calculate the softmax values
    for (float val : input) {
        output.push_back(std::exp(val) / sum_exp);
    }

    return output;
}


std::vector<float> safe_softmax(const std::vector<float>& input) {
    std::vector<float> output;
    output.reserve(input.size());
    
    // Find the maximum value for numerical stability
    float max_val = *std::max_element(input.begin(), input.end());
    float sum_exp = 0.0f;
    
    // Calculate the sum of exponentials with numerical stability
    for (float val : input) {
        sum_exp += std::exp(val - max_val); // Subtract max_val for numerical stability
    }       
    // Calculate the softmax values
    for (float val : input) {
        output.push_back(std::exp(val - max_val) / sum_exp);
    }

    return output;
}


// The main() that drives all three variants appears in Section 5 below.
```

5. Online Softmax (redone version)
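
The idea: keep a running max m and a running denominator d in a single pass, and whenever a new max appears, rescale the denominator accumulated so far. Stating the standard recurrence here for reference (it is what the code below implements):

$$
m_i=\max(m_{i-1},x_i) \qquad d_i=d_{i-1}\,e^{\,m_{i-1}-m_i}+e^{\,x_i-m_i}
$$

After the full pass, m and d equal exactly the max and denominator that safe softmax computes in two separate passes.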

```cpp
#include <limits>  // std::numeric_limits, for the -inf initialization below

std::vector<float> online_softmax(const std::vector<float>& input) {
    std::vector<float> output;
    output.reserve(input.size());

    // Start both maxima at -inf: the first element then always becomes the
    // max, and the first rescale factor exp(prev - max) is exactly 0.
    // (The original -999.0f breaks silently for inputs below -999.)
    float max_val = -std::numeric_limits<float>::infinity();
    float prev_max_value = max_val;
    float sum_exp = 0.0f;  // running denominator, relative to the current max

    for (size_t i = 0; i < input.size(); ++i) {
        max_val = std::max(max_val, input[i]);
        // Rescale the sum accumulated under the old max, then add the new term.
        sum_exp = sum_exp * std::exp(prev_max_value - max_val)
                + std::exp(input[i] - max_val);
        prev_max_value = max_val;
    }

    for (size_t i = 0; i < input.size(); i++)
    {
        output.push_back(std::exp(input[i] - max_val) / sum_exp);
    }

    return output;
}
```
  • Output of the three versions compared (rows: safe, naive, online, matching the print order in main below)
```shell
(base) k@Skynet:~/learn_flashtn/flash-attention-minimal/build$ ./naive_softmax 
9.58094e-05 0.000451403 3.56126e-09 0.00259765 0.996855 
9.58094e-05 0.000451403 3.56127e-09 0.00259765 0.996855 
9.58094e-05 0.000451403 3.56126e-09 0.00259765 0.996855 
```
  • Comparison driver
```cpp
void print_vector(const std::vector<float>& vec) {
    for (const auto& val : vec) {
        std::cout << val << " ";
    }
    std::cout << std::endl;
}


int main() {
    std::vector<float> input = {1.0f, 2.55f, -9.2f, 4.3f,10.25f};
    std::vector<float> result = safe_softmax(input);
    print_vector(result);

    std::vector<float> result2 = naive_softmax(input);
    print_vector(result2);

    std::vector<float> result3 = online_softmax(input);
    print_vector(result3);

    return 0;
}
```

6. Fusing the Online Softmax with the Dot Product Against V
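
The notes end before this section, so what follows is my own minimal sketch of the idea, in the same style as the softmax code above: extend the single-pass recurrence so that, next to the running max and denominator, we also keep a running output accumulator that gets rescaled whenever the max changes. This is what allows multiplying by V before the softmax normalization is final; dividing by the denominator happens once, at the end.

```cpp
#include <vector>
#include <cmath>
#include <limits>

// Fused "online softmax . V" for a single query row.
// scores[i] is the pre-softmax logit q.k_i; V is a row-major (n x d)
// value matrix. One pass over i builds the softmax denominator and the
// weighted sum of V rows simultaneously.
std::vector<float> online_softmax_dot_v(const std::vector<float>& scores,
                                        const std::vector<std::vector<float>>& V) {
    const size_t d = V.empty() ? 0 : V[0].size();
    float m = -std::numeric_limits<float>::infinity();  // running max
    float denom = 0.0f;                                 // running sum of exp
    std::vector<float> acc(d, 0.0f);                    // running sum of exp * V row

    for (size_t i = 0; i < scores.size(); ++i) {
        float m_new = std::max(m, scores[i]);
        float rescale = std::exp(m - m_new);    // 0 on the first step (m = -inf)
        float w = std::exp(scores[i] - m_new);  // weight of the new element

        denom = denom * rescale + w;
        for (size_t j = 0; j < d; ++j)
            acc[j] = acc[j] * rescale + w * V[i][j];
        m = m_new;
    }

    for (size_t j = 0; j < d; ++j)
        acc[j] /= denom;  // normalize once at the end
    return acc;
}
```

Deferring the division is valid because every entry of acc is off by the same factor denom at every step, so a single normalization at the end recovers softmax(scores)·V exactly.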

7. CUDA Kernel Walkthrough

Path