深度学习算子CUDA优化实战:从GEMM到Transformer—Week4学习总结

深度学习算子CUDA优化实战:从GEMM到Transformer

副标题:系统掌握DL算子优化技术,构建高性能Transformer

经过前三周的CUDA基础学习,这周我们终于要进入深度学习领域的核心战场了。说实话,当我第一次看到Transformer的GEMM优化能提升10倍性能时,那种震撼感至今难忘。这周我们会深入三个最关键的深度学习算子:GEMM、Softmax和LayerNorm,最后把它们组装成一个完整的Transformer Layer。

一、为什么要优化这些算子?

在开始写代码之前,我想先聊聊为什么要专门优化这几个算子。

整个训练过程中可能有接近85%的时间都耗在矩阵乘法(GEMM)上。剩下的15%里,Softmax和LayerNorm又占了大头。换句话说,如果你能把这三个算子优化到极致,基本上就优化了整个模型

这也是为什么NVIDIA、AMD这些GPU厂商会在硬件层面专门为矩阵乘法设计Tensor Core这样的加速单元。不夸张地说,GEMM就是深度学习的基石:

  • 全连接层:直接就是矩阵乘法
  • 卷积层:可以转化成im2col + GEMM
  • Attention机制:QKV三个矩阵乘法,外加attention score的计算

所以这周的学习,我是抱着"啃硬骨头"的心态来的。

二、GEMM优化:与cuBLAS的性能对决

2.1 从Naive实现开始

最开始,我写了一个最简单的GEMM实现,每个线程计算输出矩阵的一个元素:

cpp 复制代码
__global__ void gemm_naive(float* A, float* B, float* C, 
                           int M, int N, int K) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    
    if (row < M && col < N) {
        float sum = 0.0f;
        for (int k = 0; k < K; k++) {
            sum += A[row * K + k] * B[k * N + col];
        }
        C[row * N + col] = sum;
    }
}

这个版本的问题显而易见:每个元素都要从Global Memory读取K次,完全没有数据复用。在1024x1024的矩阵上跑,耗时大概100ms,算下来只有20 GFLOPS,连cuBLAS性能的10%都不到。

2.2 Shared Memory优化:性能提升6倍

第一个优化思路是使用Shared Memory做Tiling(分块)。核心想法是:让一个tile内的所有线程共享读取的数据,从而实现数据复用

cpp 复制代码
#define TILE_SIZE 32

__global__ void gemm_shared(float* A, float* B, float* C,
                           int M, int N, int K) {
    __shared__ float As[TILE_SIZE][TILE_SIZE];
    __shared__ float Bs[TILE_SIZE][TILE_SIZE];
    
    int row = blockIdx.y * TILE_SIZE + threadIdx.y;
    int col = blockIdx.x * TILE_SIZE + threadIdx.x;
    
    float sum = 0.0f;
    
    // 分块计算
    for (int t = 0; t < (K + TILE_SIZE - 1) / TILE_SIZE; t++) {
        // 协作加载tile到shared memory
        if (row < M && t * TILE_SIZE + threadIdx.x < K)
            As[threadIdx.y][threadIdx.x] = A[row * K + t * TILE_SIZE + threadIdx.x];
        else
            As[threadIdx.y][threadIdx.x] = 0.0f;
            
        if (col < N && t * TILE_SIZE + threadIdx.y < K)
            Bs[threadIdx.y][threadIdx.x] = B[(t * TILE_SIZE + threadIdx.y) * N + col];
        else
            Bs[threadIdx.y][threadIdx.x] = 0.0f;
            
        __syncthreads();
        
        // 使用shared memory计算
        for (int k = 0; k < TILE_SIZE; k++) {
            sum += As[threadIdx.y][k] * Bs[k][threadIdx.x];
        }
        
        __syncthreads();
    }
    
    if (row < M && col < N) {
        C[row * N + col] = sum;
    }
}

这个版本把时间降到了15ms左右,达到了130 GFLOPS,大概是cuBLAS的65%。提升主要来自两点:

  1. 数据复用:每个元素被TILE_SIZE个线程复用,Global Memory访问次数从O(MNK)降到O(MNK/TILE_SIZE)
  2. Shared Memory带宽:访问shared memory比global memory快10倍以上

2.3 寄存器分块:逼近cuBLAS性能

但是这里还有优化空间。仔细观察会发现,每个线程在内层循环中反复访问shared memory。我们可以进一步让每个线程计算输出的一个小块(比如4x4),这样可以更好地利用寄存器:

cpp 复制代码
#define BM 128
#define BN 128
#define BK 8
#define TM 8
#define TN 8

__global__ void gemm_optimized(float* A, float* B, float* C,
                               int M, int N, int K) {
    __shared__ float As[BM][BK];
    __shared__ float Bs[BK][BN];
    
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int bx = blockIdx.x;
    int by = blockIdx.y;
    
    // 每个线程计算TM x TN个输出元素
    float regC[TM][TN] = {0.0f};
    float regA[TM];
    float regB[TN];
    
    // 外层循环:遍历K维度
    for (int ko = 0; ko < K; ko += BK) {
        // 协作加载tile (省略边界检查)
        // ...
        
        __syncthreads();
        
        // 内层循环:使用寄存器分块计算
        for (int ki = 0; ki < BK; ki++) {
            // 加载到寄存器
            for (int i = 0; i < TM; i++) {
                regA[i] = As[ty * TM + i][ki];
            }
            for (int j = 0; j < TN; j++) {
                regB[j] = Bs[ki][tx * TN + j];
            }
            
            // 寄存器级别的矩阵乘法
            for (int i = 0; i < TM; i++) {
                for (int j = 0; j < TN; j++) {
                    regC[i][j] += regA[i] * regB[j];
                }
            }
        }
        
        __syncthreads();
    }
    
    // 写回结果
    // ...
}

这个版本耗时降到了10ms,达到200 GFLOPS,终于和cuBLAS打平了!关键优化点在于:

  • 减少shared memory访问:从每次计算都访问shared memory,变成批量加载到寄存器
  • 指令级并行:寄存器分块让编译器可以更好地调度指令流水线

2.4 性能对比总结

实现版本 耗时 性能(GFLOPS) vs cuBLAS
Naive 100ms 20 10%
Shared Memory 15ms 130 65%
Register Tiling 10ms 200 100%

这个优化过程让我深刻理解了GPU的内存层次结构。每一层优化都是在减少对慢速内存的访问,增加对快速内存的复用。

三、Softmax优化:数值稳定性与性能的平衡

Softmax看起来是个简单的算子,但实际优化起来坑不少。

3.1 数值稳定性问题

最初我写了个naive版本:

cpp 复制代码
__global__ void softmax_naive(float* input, float* output, int N) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx >= N) return;
    
    float sum = 0.0f;
    for (int i = 0; i < N; i++) {
        sum += expf(input[i]);
    }
    output[idx] = expf(input[idx]) / sum;
}

结果在测试时遇到了infnan。问题出在哪?当输入值较大时,exp(x)会溢出:

cpp 复制代码
expf(1000.0f);  // 结果是inf!

解决方法是利用Softmax的性质:减去最大值不改变结果

复制代码
softmax(x) = softmax(x - max(x))

这样可以保证指数运算的输入都是非正数,避免溢出:

cpp 复制代码
__global__ void softmax_stable(float* input, float* output, int N) {
    // 第一遍:找最大值
    __shared__ float max_val;
    if (threadIdx.x == 0) {
        max_val = input[0];
        for (int i = 1; i < N; i++) {
            max_val = fmaxf(max_val, input[i]);
        }
    }
    __syncthreads();
    
    // 第二遍:计算exp和sum
    __shared__ float sum_val;
    if (threadIdx.x == 0) {
        sum_val = 0.0f;
        for (int i = 0; i < N; i++) {
            sum_val += expf(input[i] - max_val);
        }
    }
    __syncthreads();
    
    // 第三遍:归一化
    int idx = threadIdx.x;
    if (idx < N) {
        output[idx] = expf(input[idx] - max_val) / sum_val;
    }
}

但这个版本需要三次遍历,效率不高。

3.2 Online Softmax:一次遍历搞定

Online Softmax算法可以在一次遍历中同时维护最大值和指数和,这是我这周学到的最巧妙的算法之一:

cpp 复制代码
__global__ void softmax_online(float* input, float* output, int N) {
    int idx = threadIdx.x;
    
    __shared__ float s_max;
    __shared__ float s_sum;
    
    if (idx == 0) {
        float max_val = input[0];
        float sum = 0.0f;
        
        for (int i = 0; i < N; i++) {
            float x = input[i];
            float old_max = max_val;
            max_val = fmaxf(max_val, x);
            
            // 关键:更新sum时要考虑max的变化
            sum = sum * expf(old_max - max_val) + expf(x - max_val);
        }
        
        s_max = max_val;
        s_sum = sum;
    }
    __syncthreads();
    
    if (idx < N) {
        output[idx] = expf(input[idx] - s_max) / s_sum;
    }
}

核心思想是:当发现新的最大值时,要同步调整之前累积的sum。公式推导如下:

复制代码
设old_max为旧最大值,new_max为新最大值
旧的sum = Σ exp(xi - old_max)
新的sum = Σ exp(xi - new_max)
        = Σ exp(xi - old_max) * exp(old_max - new_max)
        = old_sum * exp(old_max - new_max) + exp(new_x - new_max)

这个优化让性能提升了约2倍,从三次遍历降到一次遍历。

四、LayerNorm优化:Welford算法的威力

LayerNorm在Transformer中无处不在,每个sub-layer后面都要来一次。其基本公式是:

复制代码
output = gamma * (x - mean) / sqrt(variance + epsilon) + beta

4.1 传统实现的问题

naive版本需要两次遍历:

cpp 复制代码
// 第一遍:计算均值
float mean = 0.0f;
for (int i = 0; i < N; i++) {
    mean += x[i];
}
mean /= N;

// 第二遍:计算方差
float variance = 0.0f;
for (int i = 0; i < N; i++) {
    variance += (x[i] - mean) * (x[i] - mean);
}
variance /= N;

这在数值稳定性上也有问题:当均值很大时,(x[i] - mean)可能导致精度损失。

4.2 Welford在线算法

Welford算法可以在一次遍历中同时计算均值和方差,且数值稳定:

cpp 复制代码
__device__ void welford_update(float x, int count, 
                               float& mean, float& M2) {
    float delta = x - mean;
    mean += delta / count;
    float delta2 = x - mean;
    M2 += delta * delta2;
}

__global__ void layernorm_welford(float* input, float* output,
                                  float* gamma, float* beta,
                                  int N, float epsilon) {
    int idx = threadIdx.x;
    
    __shared__ float s_mean;
    __shared__ float s_variance;
    
    // 单线程使用Welford算法
    if (idx == 0) {
        float mean = 0.0f;
        float M2 = 0.0f;
        
        for (int i = 0; i < N; i++) {
            welford_update(input[i], i + 1, mean, M2);
        }
        
        s_mean = mean;
        s_variance = M2 / N;
    }
    __syncthreads();
    
    // 并行归一化
    if (idx < N) {
        float normalized = (input[idx] - s_mean) / sqrtf(s_variance + epsilon);
        output[idx] = gamma[idx] * normalized + beta[idx];
    }
}

Welford算法的数学原理很优雅:

复制代码
mean_new = mean + (x - mean) / count
M2_new = M2 + (x - mean) * (x - mean_new)
variance = M2 / count

这里M2是"平方差之和"的累积量,通过增量更新避免了存储所有数据。

4.3 Kernel Fusion进一步优化

实际应用中,我们通常会把均值/方差计算和归一化融合到一个kernel里,避免多次kernel启动的overhead:

cpp 复制代码
__global__ void layernorm_fused(float* input, float* output,
                                float* gamma, float* beta,
                                int N, float epsilon) {
    extern __shared__ float shared[];
    float* s_data = shared;
    
    int tid = threadIdx.x;
    int idx = blockIdx.x * N + tid;
    
    // 加载到shared memory
    if (tid < N) {
        s_data[tid] = input[idx];
    }
    __syncthreads();
    
    // 并行计算局部统计量
    float local_sum = 0.0f;
    float local_sq_sum = 0.0f;
    for (int i = tid; i < N; i += blockDim.x) {
        float val = s_data[i];
        local_sum += val;
        local_sq_sum += val * val;
    }
    
    // Warp级别规约
    for (int offset = 16; offset > 0; offset /= 2) {
        local_sum += __shfl_down_sync(0xffffffff, local_sum, offset);
        local_sq_sum += __shfl_down_sync(0xffffffff, local_sq_sum, offset);
    }
    
    __shared__ float s_mean;
    __shared__ float s_variance;
    
    if (tid % 32 == 0) {
        atomicAdd(&s_mean, local_sum);
        atomicAdd(&s_variance, local_sq_sum);
    }
    __syncthreads();
    
    if (tid == 0) {
        s_mean /= N;
        s_variance = s_variance / N - s_mean * s_mean;
    }
    __syncthreads();
    
    // 归一化并写回
    if (tid < N) {
        float normalized = (s_data[tid] - s_mean) / sqrtf(s_variance + epsilon);
        output[idx] = gamma[tid] * normalized + beta[tid];
    }
}

这个fusion版本比分离的实现快2-3倍,主要收益来自:

  1. 减少kernel启动overhead
  2. 数据在shared memory中复用,减少global memory访问
  3. 更好的指令流水线利用

五、组装Transformer Layer:从零到一

有了前面三个优化算子,我们终于可以组装一个完整的Transformer Layer了。

5.1 Multi-Head Attention的实现

Attention的计算公式大家都很熟悉:

复制代码
Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V

在Multi-Head Attention中,我们把输入投影到多个子空间:

cpp 复制代码
class MultiHeadAttention {
private:
    int d_model;      // 512
    int num_heads;    // 8
    int d_k;          // 64 = d_model / num_heads
    
    float* W_q;       // Query权重
    float* W_k;       // Key权重  
    float* W_v;       // Value权重
    float* W_o;       // Output权重
    
public:
    void forward(float* input, float* output, int batch, int seq_len) {
        // 1. 线性投影得到Q, K, V
        float *Q, *K, *V;
        cudaMalloc(&Q, batch * seq_len * d_model * sizeof(float));
        cudaMalloc(&K, batch * seq_len * d_model * sizeof(float));
        cudaMalloc(&V, batch * seq_len * d_model * sizeof(float));
        
        gemm_optimized<<<...>>>(input, W_q, Q, ...);  // Q = XW_q
        gemm_optimized<<<...>>>(input, W_k, K, ...);  // K = XW_k
        gemm_optimized<<<...>>>(input, W_v, V, ...);  // V = XW_v
        
        // 2. 分割成多个head
        // Reshape: [batch, seq_len, d_model] -> [batch, num_heads, seq_len, d_k]
        
        // 3. 计算attention score
        float* scores;
        cudaMalloc(&scores, batch * num_heads * seq_len * seq_len * sizeof(float));
        
        // scores = QK^T / sqrt(d_k)
        gemm_optimized<<<...>>>(Q, K, scores, ...);
        scale_kernel<<<...>>>(scores, 1.0f / sqrtf(d_k), ...);
        
        // 4. 应用Softmax
        for (int b = 0; b < batch; b++) {
            for (int h = 0; h < num_heads; h++) {
                for (int i = 0; i < seq_len; i++) {
                    float* score_row = scores + (b*num_heads*seq_len*seq_len + 
                                                 h*seq_len*seq_len + i*seq_len);
                    softmax_online<<<1, 256>>>(score_row, score_row, seq_len);
                }
            }
        }
        
        // 5. 加权求和
        float* context;
        cudaMalloc(&context, batch * num_heads * seq_len * d_k * sizeof(float));
        gemm_optimized<<<...>>>(scores, V, context, ...);
        
        // 6. 合并多个head
        // Reshape: [batch, num_heads, seq_len, d_k] -> [batch, seq_len, d_model]
        
        // 7. 输出投影
        gemm_optimized<<<...>>>(context, W_o, output, ...);
        
        // 清理
        cudaFree(Q); cudaFree(K); cudaFree(V);
        cudaFree(scores); cudaFree(context);
    }
};

5.2 完整的Transformer Layer

一个完整的Transformer Layer包括:

  1. Multi-Head Attention
  2. Add & Norm (残差连接 + LayerNorm)
  3. Feed-Forward Network (两层全连接)
  4. Add & Norm
cpp 复制代码
class TransformerLayer {
private:
    MultiHeadAttention attention;
    FeedForward ffn;
    LayerNorm ln1, ln2;
    
public:
    void forward(float* input, float* output, int batch, int seq_len) {
        float *attn_out, *norm1_out, *ffn_out;
        int size = batch * seq_len * d_model;
        
        cudaMalloc(&attn_out, size * sizeof(float));
        cudaMalloc(&norm1_out, size * sizeof(float));
        cudaMalloc(&ffn_out, size * sizeof(float));
        
        // 1. Multi-Head Attention
        attention.forward(input, attn_out, batch, seq_len);
        
        // 2. Add & Norm
        residual_add<<<...>>>(input, attn_out, attn_out, size);
        ln1.forward(attn_out, norm1_out, batch * seq_len);
        
        // 3. Feed-Forward Network
        ffn.forward(norm1_out, ffn_out, batch * seq_len);
        
        // 4. Add & Norm
        residual_add<<<...>>>(norm1_out, ffn_out, ffn_out, size);
        ln2.forward(ffn_out, output, batch * seq_len);
        
        cudaFree(attn_out);
        cudaFree(norm1_out);
        cudaFree(ffn_out);
    }
};

5.3 性能优化要点

在实际运行中,我发现几个关键的优化点:

  1. Kernel Fusion:把Add和Norm融合,避免额外的内存读写
  2. 内存预分配:避免在forward中频繁malloc/free
  3. Stream并行:不同的attention head可以并行计算
  4. 混合精度:使用FP16存储,FP32计算

经过这些优化,单层Transformer的forward时间从初始的50ms降到了15ms左右。

六、Tensor Core:硬件加速的终极武器

前面的优化都是在软件层面做的。对于Volta架构及以后的GPU,我们还可以利用Tensor Core硬件单元。

6.1 Tensor Core简介

Tensor Core是NVIDIA专门为深度学习设计的矩阵计算单元,特点是:

  • 超高吞吐:一个时钟周期完成4x4x4的矩阵乘加
  • 支持混合精度:FP16输入,FP32累加
  • 10倍以上加速:相比传统CUDA Core

在Ampere架构(如A100)上,Tensor Core的理论性能可达312 TFLOPS (FP16),而普通CUDA Core只有19.5 TFLOPS (FP32)。

6.2 使用WMMA API

CUDA提供了WMMA (Warp-level Matrix Multiply Accumulate) API来使用Tensor Core:

cpp 复制代码
#include <mma.h>
using namespace nvcuda;

#define WMMA_M 16
#define WMMA_N 16
#define WMMA_K 16

__global__ void gemm_wmma(half* A, half* B, float* C,
                         int M, int N, int K) {
    // 声明fragment(寄存器中的矩阵片段)
    wmma::fragment<wmma::matrix_a, WMMA_M, WMMA_N, WMMA_K, 
                   half, wmma::row_major> a_frag;
    wmma::fragment<wmma::matrix_b, WMMA_M, WMMA_N, WMMA_K,
                   half, wmma::col_major> b_frag;
    wmma::fragment<wmma::accumulator, WMMA_M, WMMA_N, WMMA_K,
                   float> c_frag;
    
    int warpM = (blockIdx.x * blockDim.x + threadIdx.x) / 32;
    int warpN = blockIdx.y * blockDim.y + threadIdx.y;
    
    // 初始化累加器
    wmma::fill_fragment(c_frag, 0.0f);
    
    // K维度分块
    for (int i = 0; i < K; i += WMMA_K) {
        int aRow = warpM * WMMA_M;
        int aCol = i;
        int bRow = i;
        int bCol = warpN * WMMA_N;
        
        // 加载数据到fragment
        wmma::load_matrix_sync(a_frag, A + aRow * K + aCol, K);
        wmma::load_matrix_sync(b_frag, B + bRow * N + bCol, N);
        
        // 执行矩阵乘加
        wmma::mma_sync(c_frag, a_frag, b_frag, c_frag);
    }
    
    // 写回结果
    int cRow = warpM * WMMA_M;
    int cCol = warpN * WMMA_N;
    wmma::store_matrix_sync(C + cRow * N + cCol, c_frag, N, 
                           wmma::mem_row_major);
}

七、性能分析方法论

经过这周的学习,我总结了一套实用的CUDA性能优化流程:

7.1 Profile驱动的优化(PGO)

千万不要盲目优化!先用Nsight Compute找瓶颈:

bash 复制代码
ncu --set full -o profile_output ./my_program

重点看这几个指标:

  • SM Occupancy:低于50%说明线程块配置有问题
  • Memory Throughput:接近理论带宽说明memory bound
  • Compute Throughput:接近理论算力说明compute bound
  • Warp Stall Reasons:找到线程阻塞的原因

7.2 优化技术选择

根据瓶颈类型选择优化手段:

Memory Bound (内存受限)

  • Tiling/Shared Memory:减少global memory访问
  • Kernel Fusion:合并多个kernel
  • Memory Coalescing:优化访问模式

Compute Bound (计算受限)

  • Warp优化:避免分支divergence
  • Tensor Core:利用专用硬件
  • 混合精度:FP16计算,FP32存储

Latency Bound (延迟受限)

  • 提高Occupancy:增加并发warp数
  • ILP优化:展开循环,增加指令级并行
  • 异步操作:隐藏访存延迟

7.3 验证与迭代

每次优化后都要验证:

  1. 正确性检查:对比CPU参考实现
  2. 性能测试:多次运行取平均
  3. 分析改进:用Nsight看指标变化
  4. 持续迭代:找下一个瓶颈

我的经验是,不要指望一次优化就到位,要一步步来。每次抓主要矛盾,解决后再profile,找下一个瓶颈。

代码仓库:我的CUDA学习笔记 - Week 4

相关推荐
工程师老罗2 小时前
Pytorch如何验证模型?
人工智能·pytorch·深度学习
2301_765703142 小时前
C++中的职责链模式实战
开发语言·c++·算法
Hi_kenyon2 小时前
Skills精选
人工智能
顾西爵霞2 小时前
个人学习主页搭建指南:从毛坯房到精装户型
学习·html
StandbyTime2 小时前
《算法笔记》学习记录-第一章
c++·算法·算法笔记
hhhjhl2 小时前
flutter_for_openharmony逆向思维训练app实战+学习日历实现
学习·flutter
沈浩(种子思维作者)2 小时前
铁的居里点(770度就不被磁铁吸了)道理是什么?能不能精确计算出来?
人工智能·python·flask·量子计算
沛沛老爹2 小时前
Web开发者转型AI:多模态Agent视频分析技能开发实战
前端·人工智能·音视频
张小凡vip2 小时前
数据挖掘(九) --Anaconda 全面了解与安装指南
人工智能·数据挖掘