Triton-Lang在Transformer优化加速中的实践 | 得物技术

一、前言

众所周知,英伟达(Nvidia)自2006年推出CUDA以来,经过近20年的发展,尤其是经历了以卷积为代表的深度学习和近两年以Transformer为基础的LLM的推动,CUDA编程基本上成为了GPU编程的代名词。CUDA作为GPU的编程语言,不仅使用户能充分发挥Nvidia GPU的高性能的并行计算能力,也逐渐构筑了一个包括硬件、驱动、开发库和编程技巧的完备生态链,从而使CUDA成为了人工智能、高性能计算和云计算中的核心依赖。

(图片来源:Triton-lang documentation )

Triton是OpenAI 推出的以python为编程语言基础,专门为深度学习研发和高性能计算而设计的编程语言和编译器,旨在简化和优化GPU编程的复杂操作,降低高性能优化的门槛。

在大模型推理优化领域,已有很多优秀的工作开始应用Triton编写高效算子,例如近期被众多大模型推理框架集成的Attention算子FlashAttention、推理加速框架lightllm、训练加速框架的Unsloth等。

Triton的初期版本以CUDA为起点而开发,为没有CUDA基础的编程者提供快速编写高效CUDA kernel的方案,而随着迭代已逐渐支持其他芯片和编程工具,如AMD的ROCm,并在继续支持其他的芯片,如Intel的CPU。因而,除了简化高性能计算,同时Triton也在试图构建一个"CUDA-free"的更高层的kernel编写方案,打破"天下苦CUDA久矣"的局面,把复杂的对底层芯片的交互,交给其IR和底层的编译器。

综上,可以说Triton是起于CUDA,又不止于CUDA。几个词可以简单总结Triton的特点和发展方向:

  • 门槛低
  • 高效
  • 多平台

二、GPU基础

在学习Triton的编程设计前,还是需要了解GPU一些简单的基础架构知识和GPU编程的基础概念。

以下左图是引自NVIDIA经典Ampere架构的GA100(A100)的datasheet的整体架构示意图,展现其所有128个SMs(Streaming Multiprocessors)和各级缓存、HBM(高性能内存)和NvLink(Nvidia卡间互联)等;而右图是A100的单个SM(Streaming MultiProcessor, 多核流处理器) 的结构。

(图片来源:Nvidia-ampere-architecture-whitepaper )

从硬件的角度来讲,

  • SP (Streaming Processor 线程处理器) 是CUDA 编程模型的最基本单位。每个SP都有自己的registers (寄存器) 和 local memory (局部内存, L0 cache)。寄存器和局部内存只能被自己访问,不同的线程处理器之间彼此独立。
  • 由多个线程处理器 (SP) 和一块共享内存(shared memory, L1 cache)构成了一个SM。多核处理器里边的多个SP互相并行,且互不影响。每个SM内都有自己的共享内存,shared memory 可以被线程块内所有线程访问。

从软件的角度来讲,

  • thread(线程):一个CUDA程序被分成多个threads执行。
  • block 或 thread block (线程块):多个threads群组成一个block,同一个block中的threads可以同步,也可以通过shared memory 传递数据。
  • grid(网格):多个blocks会再构成grid。
  • warp:GPU执行程序时的调度单位。

对应关系:

  • 一个SP可以执行一个thread。
  • CUDA的device在执行任务时,会把任务分成一个个的block分配给SM执行, 而每个block又会以warp为单位执行(Nvidia把32个threads组成一个warp, warp即是SM调度和运行的基本单元,所有SP执行同一指令,但每个thread使用各自的data)。
  • 一个warp需要占用一个SM,多个warps则会轮流进入SM处理。

(图片来源:OpenAI official introduction )

将上述结构大致抽象成3个组成部分DRAM, SRAM和ALU, 其中DRAM即各个HBMs(即俗称的显存),SRAM指各级缓存,ALU即计算单元(GPU中的SM),而当用户优化CUDA代码时需要考虑:

  • DRAM读写时的内存合并:以保证充分利用GPU的内存带宽;
  • 数据必须手动分配至各级SRAM:以尽可能地避免共享内存冲突;
  • 计算流程必须在SM内部和外部谨慎合理地设计、分配和调度:以促进并行线程的计算效率。

而在编程设计时充分考虑以上,即使是对于富有经验的CUDA编程者也颇具挑战,因而Triton希望底层编译器对多数的调度细节能自动优化,而用户只需要考虑一些顶层的逻辑设计,即SMs层级的,例如矩阵分片,SM之间数据同步等问题。

其官网介绍给出了一个对比,

(表格来源:OpenAI official introduction)

通俗而言,相比于CUDA,使用Triton,你不必控制所有内容,因为有些事情可以留给工具自动优化;用Triton编写的模块可能不一定优于顶级的CUDA算子,但是性能通常能优于普通的CUDA kernel;而前者的门槛大大低于后者。

因而Triton的编程设计过程,其关键在于SM层级的并行处理过程的设计,即画好SM层级的网格图以表示算子的计算过程。

三、Triton 编程实例

向量求和

内核函数

向量求和对于Triton是一个"Hello World"式的示例。使用Pytorch,对于两个同长度的vector,直接相加,非常简单。

java 复制代码
size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y

而对于Triton,需要编写一个内核函数(kernel)和一个调用函数(wrapper),调用时的并行网格图如下:

kernel 函数代码如下:

java 复制代码
import triton.language as tl

@triton.jit
def add_kernel(x_ptr,  # 第一个输入向量的指针
               y_ptr,  # 第二个输入向量的指针
               output_ptr,  # 输出向量的指针
               n_elements,  # 向量长度
               BLOCK_SIZE: tl.constexpr,  # 每个线程块处理的元素数量
               ):
    # 有多个'程序'处理不同的数据, 用pid标识当前是哪个程序
    pid = tl.program_id(axis=0)  
    # 计算当前程序所需要的数据的偏置
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 创建一个掩码以防止内存操作超出范围
    mask = offsets < n_elements
    # 从 DRAM 加载 x 和 y
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # 将计算结果output写回 DRAM
    tl.store(output_ptr + offsets, output, mask=mask)
  • @triton.jit装饰器用于定义内核函数,在程序执行时即时编译并在GPU上执行。

  • x_ptr, y_ptr, output_ptr 分别是两个输入向量和一个输出向量的指针,n_elements表示向量长度,BLOCK_SIZE 的数据类型为 tl.constexpr,表示一个编译时的常量,定义了每个线程块处理数据时的数据长度。

  • 向量相加虽然简单,但是基本体现了内核函数通常的编写流程,定义维度 -> 计算偏置 -> 设置掩码 -> 读取数据 -> 计算过程 -> 写回数据。

    • 定义维度:当前程序(线程块)通过tl.program_id 获取自己的pid, 该程序id标识了当前程序的唯一性。tl.program_id和块大小(BLOCK_SIZE)也决定了并行处理时对整个数据块的划分,比如在这个向量数据的处理时,axis=0表示一维的划分,再比如矩阵乘法的操作,当我们用分块矩阵的思路设计内核时,则是在二维层面的操作。
    • 计算偏置:得到当前程序的id时,我们需要从整个数据块拿取当前程序所需的那块数据,所以需要通过id和块大小(BLOCK_SIZE)计算offsets。需要注意的是,这里的offsets是一个list,即是当前需要的数据的所有索引。
    • 设置掩码:因为数据的长度通常无法被我们预设的块大小整除,比如下图示例中的最后一块,所以需要设置mask,防止内存操作超出范围。
    • 读取数据:根据输入数据的指针、偏置和掩码,从DRAM(显存) 读取数据到当前程序所在的SRAM(缓存)。
    • 计算过程:在这里定义我们所需要的计算流程,例如将两段数据 x和y相加。
    • 写回数据:处理完数据后,同样根据输出数据的指针、偏置和掩码,把结果output从SRAM写回DRAM。

线程块在GPU的计算模型里又被称为CTA(Cooperative Thread Array),以上的计算过程相当于一个CTA处理单个block。

而当缓存受限时,我们也可以在单个CTA中处理多个blocks, 如下图和相应的写法:

java 复制代码
@triton.jit
def add_kernel(x_ptr, y_ptr, o_ptr, n_elements, num_blocks_per_CTA, BLOCK_SIZE: tl.constexpr,):
    pid = tl.program_id(axis=0)  
    program_offsets = pid * num_blocks_per_CTA * BLOCK_SIZE 
    offsets = program_offsets + tl.arange(0, BLOCK_SIZE)
    
    for i in range(num_blocks_per_CTA):
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.store(o_ptr + offsets, output, mask=mask)
        offsets += BLOCK_SIZE

接口函数

有了内核函数,我们需要再写一个wrapper,就可以调用内核(好比Pytorch的torch.Add api, 即加号"+")。

java 复制代码
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # SPMD启动网格,表示并行运行的内核实例数。
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # 调用内核函数
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # 我们返回一个指向z的句柄,但是,由于`torch.cuda.synchronize()`尚未被调用,内核此时仍在异步运行。
    return output

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_triton = add(x, y)

这里需要注意两点:

  • 内核程序的运行需要启动一个网格,Triton以SPMD(单程序多数据,与SIMD类似)的方式执行程序。网格grid 与内核函数中一开始我们获取的程序标识(id)相对应,在向量处理这个示例中,它是一个一维的网格,数据格式可以是Callable(metaparameters) -> Tuple[int] ,如上面代码(triton.cdiv是Triton封装的除法,cdiv表示ceiling division),也可以直接是Tuple[int],如(n_elements + BLOCK_SIZE-1)//BLOCK_SIZE。
  • 我们看上述调用内核函数的格式,可以看到,内核函数可以被grid索引,每次索引可以得到一个GPU内核,启动一个程序。x,y,output这些张量作为参数传入内核函数的同时,被隐式地转化为指向各自张量第一个元素的指针。

性能测试

Triton自带性能测试函数,可以帮助衡量自己设计的算子与baseline之间的差距。装饰器@triton.testing.perf_report用于装饰benchmark函数,而triton.testing.Benchmark函数定义了plot折线图的属性,在benchmark函数里面我们可以定义指标来比较不同算子之间的性能,如Triton和Pytorch算子之间在不同size计算下的吞吐差距。

java 复制代码
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2**i for i in range(12, 28, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
        line_names=['Triton', 'Torch'],  # Label name for the lines.
        styles=[('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='vector-add-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
    gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(print_data=True, show_plots=True, save_path="./")

在左图,我们可以看到在向量维度较小时,torch的计算更快,而当维度较大时,Triton算子的计算更快一些。

相比向量的计算可能无法体现自定义算子的优势,右图展示了Triton 官方教程中定义的矩阵乘法算子的性能,可以看到其和cuBLAS编写的算子相比已能够达到持平的性能。

Vector Add Triton vs. Pytorch

Matrix multiplication Triton vs cuBLAS

调试

早期的Triton核函数开发不支持调试,目前已经支持pdb和各种python ide的断点调试,只需设置环境变量即可。

java 复制代码
os.environ["TRITON_INTERPRET"]=1

矩阵乘法

两个矩阵相乘,

在Pytorch中,两个矩阵相乘可以直接以torch.matmul(A, B)计算得到。而进一步对其稍作优化,我们立刻能想到的通常是分块矩阵。用Pytorch表示具体的流程:

java 复制代码
# Pytorch
import torch
from typing import Tuple

@torch.jit.script
def block_matrix_multiplication(A: torch.Tensor, B:torch.Tensor, 
                                M: int, N: int, K: int, 
                                BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int) -> torch.Tensor:
    C = torch.zeros((M, N), dtype=torch.float32)
    for m in range(0, M, BLOCK_SIZE_M):
        for n in range(0, N, BLOCK_SIZE_N):
            acc = torch.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=torch.float32)
            for k in range(0, K, BLOCK_SIZE_K):
                a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
                b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
                acc += torch.matmul(a, b)
            C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
    return C

# 用法示例
A = torch.rand(100, 100)
B = torch.rand(100, 100)
result = block_matrix_multiplication(A, B, 100, 100, 100, 16, 16, 16)
  • A, B表示两个输入矩阵,其二维尺寸分别为(M, K)和(K, N); BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K 分别为分块时M, N, K三个维度的分块尺寸。
  • 主程序用3层循环计算,外边2层循环分别是依次遍历A和B两个矩阵的列和行;而最里边的循环,则是对于输出矩阵,每个对应的分块(M, N), 需要A和B相对应的列和行的分块依次相乘后累加。

从以上逻辑可以看出,这是一个行主序(Row Major)的分块矩阵乘法顺序。

我们可以画出其网格图,以下橙色的blocks 表示单个CTA的计算过程。

按照网格图,将上述的计算过程改写为Triton的内核函数:

java 复制代码
# Triton kernel
@triton.jit
det matmul_kernel(A, B, C, M, N, K, 
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # 2d grid
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offsets_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offsets_m < M
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offsets_n < N
    
    # 2d tile
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, K, BLOCK_K):
        offsets_k = start_k + tl.arange(0, BLOCK_K)
        mask_k = offsets_k < K
        
        a_ptrs = A + offsets_m[:, None]*K + offsets_k[None, :]
        mask_a = mask_m[:, None] & mask_k[None, :]
        b_ptrs = B + offsets_k[:, None]*N + offsets_n[None, :]
        mask_b = mask_k[:, None] & mask_n[None, :]
    
        a = tl.load(a_ptrs, mask=mask_a, other=0)
        b = tl.load(b_ptrs, mask=mask_b, other=0)
        acc += tl.dot(a, b)
    c_ptrs = C + offsets_m[:, None]*N + offsets_n
    mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(c_ptrs, acc, mask = mask_c) 
    
# grid = (tl.cdiv(M, BLOCK_M), tl.cdiv(N, BLOCK_N), 1)

以上代码,我们关注几点:

  • 网格:采用二维网格,作为程序标识,并计算其输入分块矩阵和输出分块矩阵的偏置,以及分块矩阵的掩码矩阵;
  • 累加:由上图可知,输出矩阵的每个分块,分别由其对应的K//BLOCK_K个矩阵相乘的结果累加得到。

行主序和列主序的代码和计算顺序如下,虽说CUDA是并行计算的程序,但是当我们将矩阵分为很多的程序执行时,如果我们的GPU并没有足够的SM来同时执行所有程序,因而这些程序是先后被加载入SM计算的。而CUDA默认是以列主序存储数据的,所以有时候列主序的程序性能要优于行主序。

  • 行主序:
java 复制代码
for m in range(0, m, BLOCK_M):
    for n in range(0, n, BLOCK_N): 
        CTA(m, n) ....
  • 列主序:
java 复制代码
for n in range(0, N, BLOCK_N):
    for m in range(0, M, BLOCK_M): 
        CTA(m, n) ....

而Triton的官网给出了一个基于L2 cache 优化的方案。其思路是以减少访存次数来提高cache的命中率。我们可以从下图比较其与通常的乘法算子的区别。

通常的列主序(column-major-ordering)

分组后的列主序 (grouped-column-major-ordering)

从上可以看出,左边计算4个CTA时,需要读取1列和4行,总共要进行5次读取;而右边的操作,只需要读取2列和2行,共4次读取。实际计算中,矩阵的行列维度数值都较大,分组后的计算在访存上会有一定的优化,而实际中在例如A100的硬件上这样的优化也能有10%的提升。

以下是官网优化示例给出的核心代码,相比于上述的二维索引,引入group之后采用一维索引,而代码的本质则是将这个一维索引pid转化为二维索引(pid_m, pid_n),而在这个变化中,我们重新定义了结果矩阵的计算顺序(即上图,右图中区别于左图的元素计算顺序)。

java 复制代码
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

我们用两个9x9的矩阵相乘来说明这个索引的过程:

  • 首先num_pid_m,num_pid_n分别计算了两个矩阵的M, N 维度各自块的个数;如下图中num_pid_m=9,num_pid_n=9。
  • GROUP_SIZE_M定义了group的维度;如下图中的GROUP_SIZE_M=3。
  • num_pid_in_group 为单个group中块的个数;如下图中 GROUP_SIZE_M * num_pid_n=27。
  • group_id 则计算得到 pid 是在哪个group中,即group的 id;如下图中 pid // num_pid_in_group=1。
  • first_pid_m 计算的是这个pid所在的group的第一个块的pid_m的值,以方便后续为算pid最终的pid_m, pid_n 提供偏置;如下图中first_pid_m=3 ,即图中第27块的pid_m。
  • group_size_m 则是计算了这个pid所在的group的行数,这是为了避免M无法为GROUP_SIZE_M所整除时,最后一个group的行数小于GROUP_SIZE_M;如下图的group的行数值为GROUP_SIZE_M,若当图中的行数为8时,可以想像最后一行的group_size_m为2。
  • 最后两行代码则是计算 pid 的真正坐标 (pid_m, pid_n)。例如下图的pid=33, 则pid_m=3+(33%3)=3, pid_n=(33%27)//3=2。

旋转位置编码

旋转位置编码(RoPE, Rotate Position Encoding)是Transformer 进入大模型应用时代后的重要算子,在Llama,ChatGLM 等主流的大模型中都有应用。关于旋转位置编码的原理和作用可以参考原论文和作者博客。其计算过程可以简要表示成以下的旋转变换,

以下是一个Huggingface中Llama RoPE的前向计算流程。

  • d 表示 embedding的维度,则位置编码的相位频率表示如下,m = [0, 2, 4, ..., d/2] , f = (1/10000)^m,!
  • n 表示token的个数,

以上是计算cos和sin两个旋转变换矩阵的过程,而矩阵q和k在做注意力乘法前先做简单处理。

java 复制代码
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

最后是旋转变换:

java 复制代码
q_embed = (q*cos) + (rotate_half(q)*sin)
k_embed = (k*cos) + (rotate_half(k)*sin)

我们来看Triton kernel对前向过程的实现:

给定矩阵Q,Cos,Sin, 它们的维度分别为

h 是注意力模块中的head数,为方便说明,可令b=1;需要实现的计算过程是!

为计算方便,现将Q reshape为(n,hd)。

考虑并行的维度时,按 token 并行是首先能想到的,n_rows = n,再考虑到有限的缓存,可以将另一个维度按group分,可以定义一个常量GROUP_SIZE=4(当然这里也可以设计成autotune模式,自动选择合适值),然后可以计算得到embedding维度的group数量n_groups ,并定义我们的网格 grid。

java 复制代码
div, mod = divmod(n, GROUP_SIZE)
n_groups = div + (mod != 0)
grid = (n_rows, n_groups)

下图阐释了整体和单元的计算过程。并行程序按两个维度计算,每个CTA的计算过程中有个次数为GROUP_SIZE的循环累加,各自累加计算得到q*cos和 rotate_half(q)*sin,再相加。

java 复制代码
def _rope_embedding(
    Q,     Q_row_stride,
    cos, cos_row_stride,
    sin, sin_row_stride,
    seqlen,
    head_dim      : tl.constexpr,
    n_heads       : tl.constexpr,
    BLOCK_SIZE    : tl.constexpr,
):
    """
        Calculates the RoPE Embedding quickly
        RoPE is Q * cos + rotate_half(Q) * sin
        See our blog post for more info
    """
    GROUP_SIZE = 4
    row_position  = tl.program_id(0)
    group_head_position = tl.program_id(1)
    col_offsets  = tl.arange(0, BLOCK_SIZE)
    half_head_dim = head_dim // 2
    mask = col_offsets < half_head_dim

    sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)

    head_start = group_head_position * GROUP_SIZE
    head_end = min((head_start + GROUP_SIZE), n_heads)

    for k in range(head_start, head_end):
        offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
        offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim

        Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
        Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)

        tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
        tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)

四、总结

参考Triton的官方示例、以及其他社区的开源加速工具,我们还能看到其他许多算子,诸如rmsNorm, Softmax, Flash-Attention 等的具体加速方案和并行思路,以及他们的反向传导过程。而这些算子则组成了Transformer 的注意力模块的整体结构。通过对各个模块的并行优化,可以实现对注意力计算的推理和训练加速。

Transformer 的注意力结构及相关算子

在实际应用中,我们可以自己编写和优化自定义算子,也可以引用社区优秀的开源加速算子库,而且Pytorch在2.0+版本后已将Triton集成到其编译器,用torch.compile()可以直接对已加载模型编译,帮助自动优化可优化的算子。

参考:

Triton-lang 介绍 (openai.com/index/trito...)

Triton-tutorial(triton-lang.org/main/gettin...)

Nvidia GA100 datasheet (images.nvidia.cn/aem-dam/en-...)

CUDA programming(docs.nvidia.com/cuda/cuda-c...)

文 / xujiong

关注得物技术,每周更新技术干货

要是觉得文章对你有帮助的话,欢迎评论转发点赞~

未经得物技术许可严禁转载,否则依法追究法律责任。

相关推荐
自信的小螺丝钉14 分钟前
Leetcode 279. 完全平方数 动态规划 完全背包问题
算法·leetcode·动态规划
努力的泽泽17 分钟前
【动态规划-矩阵】5.下降路径最小和
算法·矩阵·动态规划
说私域22 分钟前
社群团购项目运营策略的深度剖析:融合链动2+1模式、AI智能名片与S2B2C商城小程序的综合应用
大数据·人工智能·小程序
IT古董43 分钟前
【机器学习】主动学习-增加标签的操作方法-流式选择性采样(Stream-based selective sampling)
人工智能·学习·机器学习
被制作时长两年半的个人练习生1 小时前
【AscendC】tiling方案设计不当引起的一个时隐时现的bug
人工智能·bug·算子开发·ascendc
KeyPan1 小时前
【机器学习:十九、反向传播】
人工智能·深度学习·机器学习
埃菲尔铁塔_CV算法3 小时前
双线性插值算法:原理、实现、优化及在图像处理和多领域中的广泛应用与发展趋势(二)
c++·人工智能·算法·机器学习·计算机视觉
程序猿阿伟3 小时前
《AI赋能鸿蒙Next,打造极致沉浸感游戏》
人工智能·游戏·harmonyos
叫我龙翔3 小时前
【算法日记】从零开始认识动态规划(一)
c++·算法·动态规划·代理模式
AC100AC3 小时前
[NOIP2007 提高组] 矩阵取数游戏
算法·游戏·矩阵