Triton-Lang在Transformer优化加速中的实践 | 得物技术

一、前言

众所周知,英伟达(Nvidia)自2006年推出CUDA以来,经过近20年的发展,尤其是经历了以卷积为代表的深度学习和近两年以Transformer为基础的LLM的推动,CUDA编程基本上成为了GPU编程的代名词。CUDA作为GPU的编程语言,不仅使用户能充分发挥Nvidia GPU的高性能的并行计算能力,也逐渐构筑了一个包括硬件、驱动、开发库和编程技巧的完备生态链,从而使CUDA成为了人工智能、高性能计算和云计算中的核心依赖。

(图片来源:Triton-lang documentation )

Triton是OpenAI 推出的以python为编程语言基础,专门为深度学习研发和高性能计算而设计的编程语言和编译器,旨在简化和优化GPU编程的复杂操作,降低高性能优化的门槛。

在大模型推理优化领域,已有很多优秀的工作开始应用Triton编写高效算子,例如近期被众多大模型推理框架集成的Attention算子FlashAttention、推理加速框架lightllm、训练加速框架的Unsloth等。

Triton的初期版本以CUDA为起点而开发,为没有CUDA基础的编程者提供快速编写高效CUDA kernel的方案,而随着迭代已逐渐支持其他芯片和编程工具,如AMD的ROCm,并在继续支持其他的芯片,如Intel的CPU。因而,除了简化高性能计算,同时Triton也在试图构建一个"CUDA-free"的更高层的kernel编写方案,打破"天下苦CUDA久矣"的局面,把复杂的对底层芯片的交互,交给其IR和底层的编译器。

综上,可以说Triton是起于CUDA,又不止于CUDA。几个词可以简单总结Triton的特点和发展方向:

  • 门槛低
  • 高效
  • 多平台

二、GPU基础

在学习Triton的编程设计前,还是需要了解GPU一些简单的基础架构知识和GPU编程的基础概念。

以下左图是引自NVIDIA经典Ampere架构的GA100(A100)的datasheet的整体架构示意图,展现其所有128个SMs(Streaming Multiprocessors)和各级缓存、HBM(高性能内存)和NvLink(Nvidia卡间互联)等;而右图是A100的单个SM(Streaming MultiProcessor, 多核流处理器) 的结构。

(图片来源:Nvidia-ampere-architecture-whitepaper )

从硬件的角度来讲,

  • SP (Streaming Processor 线程处理器) 是CUDA 编程模型的最基本单位。每个SP都有自己的registers (寄存器) 和 local memory (局部内存, L0 cache)。寄存器和局部内存只能被自己访问,不同的线程处理器之间彼此独立。
  • 由多个线程处理器 (SP) 和一块共享内存(shared memory, L1 cache)构成了一个SM。多核处理器里边的多个SP互相并行,且互不影响。每个SM内都有自己的共享内存,shared memory 可以被线程块内所有线程访问。

从软件的角度来讲,

  • thread(线程):一个CUDA程序被分成多个threads执行。
  • block 或 thread block (线程块):多个threads群组成一个block,同一个block中的threads可以同步,也可以通过shared memory 传递数据。
  • grid(网格):多个blocks会再构成grid。
  • warp:GPU执行程序时的调度单位。

对应关系:

  • 一个SP可以执行一个thread。
  • CUDA的device在执行任务时,会把任务分成一个个的block分配给SM执行, 而每个block又会以warp为单位执行(Nvidia把32个threads组成一个warp, warp即是SM调度和运行的基本单元,所有SP执行同一指令,但每个thread使用各自的data)。
  • 一个warp需要占用一个SM,多个warps则会轮流进入SM处理。

(图片来源:OpenAI official introduction )

将上述结构大致抽象成3个组成部分DRAM, SRAM和ALU, 其中DRAM即各个HBMs(即俗称的显存),SRAM指各级缓存,ALU即计算单元(GPU中的SM),而当用户优化CUDA代码时需要考虑:

  • DRAM读写时的内存合并:以保证充分利用GPU的内存带宽;
  • 数据必须手动分配至各级SRAM:以尽可能地避免共享内存冲突;
  • 计算流程必须在SM内部和外部谨慎合理地设计、分配和调度:以促进并行线程的计算效率。

而在编程设计时充分考虑以上,即使是对于富有经验的CUDA编程者也颇具挑战,因而Triton希望底层编译器对多数的调度细节能自动优化,而用户只需要考虑一些顶层的逻辑设计,即SMs层级的,例如矩阵分片,SM之间数据同步等问题。

其官网介绍给出了一个对比,

(表格来源:OpenAI official introduction)

通俗而言,相比于CUDA,使用Triton,你不必控制所有内容,因为有些事情可以留给工具自动优化;用Triton编写的模块可能不一定优于顶级的CUDA算子,但是性能通常能优于普通的CUDA kernel;而前者的门槛大大低于后者。

因而Triton的编程设计过程,其关键在于SM层级的并行处理过程的设计,即画好SM层级的网格图以表示算子的计算过程。

三、Triton 编程实例

向量求和

内核函数

向量求和对于Triton是一个"Hello World"式的示例。使用Pytorch,对于两个同长度的vector,直接相加,非常简单。

java 复制代码
size = 1024
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_torch = x + y

而对于Triton,需要编写一个内核函数(kernel)和一个调用函数(wrapper),调用时的并行网格图如下:

kernel 函数代码如下:

java 复制代码
import triton.language as tl

@triton.jit
def add_kernel(x_ptr,  # 第一个输入向量的指针
               y_ptr,  # 第二个输入向量的指针
               output_ptr,  # 输出向量的指针
               n_elements,  # 向量长度
               BLOCK_SIZE: tl.constexpr,  # 每个线程块处理的元素数量
               ):
    # 有多个'程序'处理不同的数据, 用pid标识当前是哪个程序
    pid = tl.program_id(axis=0)  
    # 计算当前程序所需要的数据的偏置
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    # 创建一个掩码以防止内存操作超出范围
    mask = offsets < n_elements
    # 从 DRAM 加载 x 和 y
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    output = x + y
    # 将计算结果output写回 DRAM
    tl.store(output_ptr + offsets, output, mask=mask)
  • @triton.jit装饰器用于定义内核函数,在程序执行时即时编译并在GPU上执行。

  • x_ptr, y_ptr, output_ptr 分别是两个输入向量和一个输出向量的指针,n_elements表示向量长度,BLOCK_SIZE 的数据类型为 tl.constexpr,表示一个编译时的常量,定义了每个线程块处理数据时的数据长度。

  • 向量相加虽然简单,但是基本体现了内核函数通常的编写流程,定义维度 -> 计算偏置 -> 设置掩码 -> 读取数据 -> 计算过程 -> 写回数据。

    • 定义维度:当前程序(线程块)通过tl.program_id 获取自己的pid, 该程序id标识了当前程序的唯一性。tl.program_id和块大小(BLOCK_SIZE)也决定了并行处理时对整个数据块的划分,比如在这个向量数据的处理时,axis=0表示一维的划分,再比如矩阵乘法的操作,当我们用分块矩阵的思路设计内核时,则是在二维层面的操作。
    • 计算偏置:得到当前程序的id时,我们需要从整个数据块拿取当前程序所需的那块数据,所以需要通过id和块大小(BLOCK_SIZE)计算offsets。需要注意的是,这里的offsets是一个list,即是当前需要的数据的所有索引。
    • 设置掩码:因为数据的长度通常无法被我们预设的块大小整除,比如下图示例中的最后一块,所以需要设置mask,防止内存操作超出范围。
    • 读取数据:根据输入数据的指针、偏置和掩码,从DRAM(显存) 读取数据到当前程序所在的SRAM(缓存)。
    • 计算过程:在这里定义我们所需要的计算流程,例如将两段数据 x和y相加。
    • 写回数据:处理完数据后,同样根据输出数据的指针、偏置和掩码,把结果output从SRAM写回DRAM。

线程块在GPU的计算模型里又被称为CTA(Cooperative Thread Array),以上的计算过程相当于一个CTA处理单个block。

而当缓存受限时,我们也可以在单个CTA中处理多个blocks, 如下图和相应的写法:

java 复制代码
@triton.jit
def add_kernel(x_ptr, y_ptr, o_ptr, n_elements, num_blocks_per_CTA, BLOCK_SIZE: tl.constexpr,):
    pid = tl.program_id(axis=0)  
    program_offsets = pid * num_blocks_per_CTA * BLOCK_SIZE 
    offsets = program_offsets + tl.arange(0, BLOCK_SIZE)
    
    for i in range(num_blocks_per_CTA):
        mask = offsets < n_elements
        x = tl.load(x_ptr + offsets, mask=mask)
        y = tl.load(y_ptr + offsets, mask=mask)
        output = x + y
        tl.store(o_ptr + offsets, output, mask=mask)
        offsets += BLOCK_SIZE

接口函数

有了内核函数,我们需要再写一个wrapper,就可以调用内核(好比Pytorch的torch.Add api, 即加号"+")。

java 复制代码
def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    output = torch.empty_like(x)
    assert x.is_cuda and y.is_cuda and output.is_cuda
    n_elements = output.numel()
    # SPMD启动网格,表示并行运行的内核实例数。
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    # 调用内核函数
    add_kernel[grid](x, y, output, n_elements, BLOCK_SIZE=1024)
    # 我们返回一个指向z的句柄,但是,由于`torch.cuda.synchronize()`尚未被调用,内核此时仍在异步运行。
    return output

torch.manual_seed(0)
size = 98432
x = torch.rand(size, device='cuda')
y = torch.rand(size, device='cuda')
output_triton = add(x, y)

这里需要注意两点:

  • 内核程序的运行需要启动一个网格,Triton以SPMD(单程序多数据,与SIMD类似)的方式执行程序。网格grid 与内核函数中一开始我们获取的程序标识(id)相对应,在向量处理这个示例中,它是一个一维的网格,数据格式可以是Callable(metaparameters) -> Tuple[int] ,如上面代码(triton.cdiv是Triton封装的除法,cdiv表示ceiling division),也可以直接是Tuple[int],如(n_elements + BLOCK_SIZE-1)//BLOCK_SIZE。
  • 我们看上述调用内核函数的格式,可以看到,内核函数可以被grid索引,每次索引可以得到一个GPU内核,启动一个程序。x,y,output这些张量作为参数传入内核函数的同时,被隐式地转化为指向各自张量第一个元素的指针。

性能测试

Triton自带性能测试函数,可以帮助衡量自己设计的算子与baseline之间的差距。装饰器@triton.testing.perf_report用于装饰benchmark函数,而triton.testing.Benchmark函数定义了plot折线图的属性,在benchmark函数里面我们可以定义指标来比较不同算子之间的性能,如Triton和Pytorch算子之间在不同size计算下的吞吐差距。

java 复制代码
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=['size'],  # Argument names to use as an x-axis for the plot.
        x_vals=[2**i for i in range(12, 28, 1)],  # Different possible values for `x_name`.
        x_log=True,  # x axis is logarithmic.
        line_arg='provider',  # Argument name whose value corresponds to a different line in the plot.
        line_vals=['triton', 'torch'],  # Possible values for `line_arg`.
        line_names=['Triton', 'Torch'],  # Label name for the lines.
        styles=[('blue', '-'), ('green', '-')],  # Line styles.
        ylabel='GB/s',  # Label name for the y-axis.
        plot_name='vector-add-performance',  # Name for the plot. Used also as a file name for saving the plot.
        args={},  # Values for function arguments not in `x_names` and `y_name`.
    ))
def benchmark(size, provider):
    x = torch.rand(size, device='cuda', dtype=torch.float32)
    y = torch.rand(size, device='cuda', dtype=torch.float32)
    quantiles = [0.5, 0.2, 0.8]
    if provider == 'torch':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: x + y, quantiles=quantiles)
    if provider == 'triton':
        ms, min_ms, max_ms = triton.testing.do_bench(lambda: add(x, y), quantiles=quantiles)
    gbps = lambda ms: 3 * x.numel() * x.element_size() * 1e-9 / (ms * 1e-3)
    return gbps(ms), gbps(max_ms), gbps(min_ms)

benchmark.run(print_data=True, show_plots=True, save_path="./")

在左图,我们可以看到在向量维度较小时,torch的计算更快,而当维度较大时,Triton算子的计算更快一些。

相比向量的计算可能无法体现自定义算子的优势,右图展示了Triton 官方教程中定义的矩阵乘法算子的性能,可以看到其和cuBLAS编写的算子相比已能够达到持平的性能。

Vector Add Triton vs. Pytorch

Matrix multiplication Triton vs cuBLAS

调试

早期的Triton核函数开发不支持调试,目前已经支持pdb和各种python ide的断点调试,只需设置环境变量即可。

java 复制代码
os.environ["TRITON_INTERPRET"]=1

矩阵乘法

两个矩阵相乘,

在Pytorch中,两个矩阵相乘可以直接以torch.matmul(A, B)计算得到。而进一步对其稍作优化,我们立刻能想到的通常是分块矩阵。用Pytorch表示具体的流程:

java 复制代码
# Pytorch
import torch
from typing import Tuple

@torch.jit.script
def block_matrix_multiplication(A: torch.Tensor, B:torch.Tensor, 
                                M: int, N: int, K: int, 
                                BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, BLOCK_SIZE_K: int) -> torch.Tensor:
    C = torch.zeros((M, N), dtype=torch.float32)
    for m in range(0, M, BLOCK_SIZE_M):
        for n in range(0, N, BLOCK_SIZE_N):
            acc = torch.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=torch.float32)
            for k in range(0, K, BLOCK_SIZE_K):
                a = A[m : m+BLOCK_SIZE_M, k : k+BLOCK_SIZE_K]
                b = B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]
                acc += torch.matmul(a, b)
            C[m : m+BLOCK_SIZE_M, n : n+BLOCK_SIZE_N] = acc
    return C

# 用法示例
A = torch.rand(100, 100)
B = torch.rand(100, 100)
result = block_matrix_multiplication(A, B, 100, 100, 100, 16, 16, 16)
  • A, B表示两个输入矩阵,其二维尺寸分别为(M, K)和(K, N); BLOCK_SIZE_M, BLOCK_SIZE_N, BLOCK_SIZE_K 分别为分块时M, N, K三个维度的分块尺寸。
  • 主程序用3层循环计算,外边2层循环分别是依次遍历A和B两个矩阵的列和行;而最里边的循环,则是对于输出矩阵,每个对应的分块(M, N), 需要A和B相对应的列和行的分块依次相乘后累加。

从以上逻辑可以看出,这是一个行主序(Row Major)的分块矩阵乘法顺序。

我们可以画出其网格图,以下橙色的blocks 表示单个CTA的计算过程。

按照网格图,将上述的计算过程改写为Triton的内核函数:

java 复制代码
# Triton kernel
@triton.jit
det matmul_kernel(A, B, C, M, N, K, 
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr):
    # 2d grid
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offsets_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    mask_m = offsets_m < M
    offsets_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    mask_n = offsets_n < N
    
    # 2d tile
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for start_k in range(0, K, BLOCK_K):
        offsets_k = start_k + tl.arange(0, BLOCK_K)
        mask_k = offsets_k < K
        
        a_ptrs = A + offsets_m[:, None]*K + offsets_k[None, :]
        mask_a = mask_m[:, None] & mask_k[None, :]
        b_ptrs = B + offsets_k[:, None]*N + offsets_n[None, :]
        mask_b = mask_k[:, None] & mask_n[None, :]
    
        a = tl.load(a_ptrs, mask=mask_a, other=0)
        b = tl.load(b_ptrs, mask=mask_b, other=0)
        acc += tl.dot(a, b)
    c_ptrs = C + offsets_m[:, None]*N + offsets_n
    mask_c = mask_m[:, None] & mask_n[None, :]
    tl.store(c_ptrs, acc, mask = mask_c) 
    
# grid = (tl.cdiv(M, BLOCK_M), tl.cdiv(N, BLOCK_N), 1)

以上代码,我们关注几点:

  • 网格:采用二维网格,作为程序标识,并计算其输入分块矩阵和输出分块矩阵的偏置,以及分块矩阵的掩码矩阵;
  • 累加:由上图可知,输出矩阵的每个分块,分别由其对应的K//BLOCK_K个矩阵相乘的结果累加得到。

行主序和列主序的代码和计算顺序如下,虽说CUDA是并行计算的程序,但是当我们将矩阵分为很多的程序执行时,如果我们的GPU并没有足够的SM来同时执行所有程序,因而这些程序是先后被加载入SM计算的。而CUDA默认是以列主序存储数据的,所以有时候列主序的程序性能要优于行主序。

  • 行主序:
java 复制代码
for m in range(0, m, BLOCK_M):
    for n in range(0, n, BLOCK_N): 
        CTA(m, n) ....
  • 列主序:
java 复制代码
for n in range(0, N, BLOCK_N):
    for m in range(0, M, BLOCK_M): 
        CTA(m, n) ....

而Triton的官网给出了一个基于L2 cache 优化的方案。其思路是以减少访存次数来提高cache的命中率。我们可以从下图比较其与通常的乘法算子的区别。

通常的列主序(column-major-ordering)

分组后的列主序 (grouped-column-major-ordering)

从上可以看出,左边计算4个CTA时,需要读取1列和4行,总共要进行5次读取;而右边的操作,只需要读取2列和2行,共4次读取。实际计算中,矩阵的行列维度数值都较大,分组后的计算在访存上会有一定的优化,而实际中在例如A100的硬件上这样的优化也能有10%的提升。

以下是官网优化示例给出的核心代码,相比于上述的二维索引,引入group之后采用一维索引,而代码的本质则是将这个一维索引pid转化为二维索引(pid_m, pid_n),而在这个变化中,我们重新定义了结果矩阵的计算顺序(即上图,右图中区别于左图的元素计算顺序)。

java 复制代码
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

我们用两个9x9的矩阵相乘来说明这个索引的过程:

  • 首先num_pid_m,num_pid_n分别计算了两个矩阵的M, N 维度各自块的个数;如下图中num_pid_m=9,num_pid_n=9。
  • GROUP_SIZE_M定义了group的维度;如下图中的GROUP_SIZE_M=3。
  • num_pid_in_group 为单个group中块的个数;如下图中 GROUP_SIZE_M * num_pid_n=27。
  • group_id 则计算得到 pid 是在哪个group中,即group的 id;如下图中 pid // num_pid_in_group=1。
  • first_pid_m 计算的是这个pid所在的group的第一个块的pid_m的值,以方便后续为算pid最终的pid_m, pid_n 提供偏置;如下图中first_pid_m=3 ,即图中第27块的pid_m。
  • group_size_m 则是计算了这个pid所在的group的行数,这是为了避免M无法为GROUP_SIZE_M所整除时,最后一个group的行数小于GROUP_SIZE_M;如下图的group的行数值为GROUP_SIZE_M,若当图中的行数为8时,可以想像最后一行的group_size_m为2。
  • 最后两行代码则是计算 pid 的真正坐标 (pid_m, pid_n)。例如下图的pid=33, 则pid_m=3+(33%3)=3, pid_n=(33%27)//3=2。

旋转位置编码

旋转位置编码(RoPE, Rotate Position Encoding)是Transformer 进入大模型应用时代后的重要算子,在Llama,ChatGLM 等主流的大模型中都有应用。关于旋转位置编码的原理和作用可以参考原论文和作者博客。其计算过程可以简要表示成以下的旋转变换,

以下是一个Huggingface中Llama RoPE的前向计算流程。

  • d 表示 embedding的维度,则位置编码的相位频率表示如下,m = [0, 2, 4, ..., d/2] , f = (1/10000)^m,!
  • n 表示token的个数,

以上是计算cos和sin两个旋转变换矩阵的过程,而矩阵q和k在做注意力乘法前先做简单处理。

java 复制代码
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

最后是旋转变换:

java 复制代码
q_embed = (q*cos) + (rotate_half(q)*sin)
k_embed = (k*cos) + (rotate_half(k)*sin)

我们来看Triton kernel对前向过程的实现:

给定矩阵Q,Cos,Sin, 它们的维度分别为

h 是注意力模块中的head数,为方便说明,可令b=1;需要实现的计算过程是!

为计算方便,现将Q reshape为(n,hd)。

考虑并行的维度时,按 token 并行是首先能想到的,n_rows = n,再考虑到有限的缓存,可以将另一个维度按group分,可以定义一个常量GROUP_SIZE=4(当然这里也可以设计成autotune模式,自动选择合适值),然后可以计算得到embedding维度的group数量n_groups ,并定义我们的网格 grid。

java 复制代码
div, mod = divmod(n, GROUP_SIZE)
n_groups = div + (mod != 0)
grid = (n_rows, n_groups)

下图阐释了整体和单元的计算过程。并行程序按两个维度计算,每个CTA的计算过程中有个次数为GROUP_SIZE的循环累加,各自累加计算得到q*cos和 rotate_half(q)*sin,再相加。

java 复制代码
def _rope_embedding(
    Q,     Q_row_stride,
    cos, cos_row_stride,
    sin, sin_row_stride,
    seqlen,
    head_dim      : tl.constexpr,
    n_heads       : tl.constexpr,
    BLOCK_SIZE    : tl.constexpr,
):
    """
        Calculates the RoPE Embedding quickly
        RoPE is Q * cos + rotate_half(Q) * sin
        See our blog post for more info
    """
    GROUP_SIZE = 4
    row_position  = tl.program_id(0)
    group_head_position = tl.program_id(1)
    col_offsets  = tl.arange(0, BLOCK_SIZE)
    half_head_dim = head_dim // 2
    mask = col_offsets < half_head_dim

    sin1 = tl.load(sin + (row_position % seqlen)*sin_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)
    cos1 = tl.load(cos + (row_position % seqlen)*cos_row_stride + \
                   half_head_dim*0 + col_offsets, mask = mask, other = 0)

    head_start = group_head_position * GROUP_SIZE
    head_end = min((head_start + GROUP_SIZE), n_heads)

    for k in range(head_start, head_end):
        offs_q1 = row_position * Q_row_stride + k * head_dim + col_offsets
        offs_q2 = row_position * Q_row_stride + k * head_dim + col_offsets + half_head_dim

        Q1 = tl.load(Q + offs_q1, mask = mask, other = 0).to(sin1.dtype)
        Q2 = tl.load(Q + offs_q2, mask = mask, other = 0).to(sin1.dtype)

        tl.store(Q + offs_q1, Q1*cos1 - Q2*sin1, mask = mask)
        tl.store(Q + offs_q2, Q2*cos1 + Q1*sin1, mask = mask)

四、总结

参考Triton的官方示例、以及其他社区的开源加速工具,我们还能看到其他许多算子,诸如rmsNorm, Softmax, Flash-Attention 等的具体加速方案和并行思路,以及他们的反向传导过程。而这些算子则组成了Transformer 的注意力模块的整体结构。通过对各个模块的并行优化,可以实现对注意力计算的推理和训练加速。

Transformer 的注意力结构及相关算子

在实际应用中,我们可以自己编写和优化自定义算子,也可以引用社区优秀的开源加速算子库,而且Pytorch在2.0+版本后已将Triton集成到其编译器,用torch.compile()可以直接对已加载模型编译,帮助自动优化可优化的算子。

参考:

Triton-lang 介绍 (openai.com/index/trito...)

Triton-tutorial(triton-lang.org/main/gettin...)

Nvidia GA100 datasheet (images.nvidia.cn/aem-dam/en-...)

CUDA programming(docs.nvidia.com/cuda/cuda-c...)

文 / xujiong

关注得物技术,每周更新技术干货

要是觉得文章对你有帮助的话,欢迎评论转发点赞~

未经得物技术许可严禁转载,否则依法追究法律责任。

相关推荐
·云扬·41 分钟前
【Leetcode hot 100】101.对称二叉树
算法·leetcode·职场和发展
代码AI弗森2 小时前
从 IDE 到 CLI:AI 编程代理工具全景与落地指南(附对比矩阵与脚本化示例)
ide·人工智能·矩阵
007tg5 小时前
从ChatGPT家长控制功能看AI合规与技术应对策略
人工智能·chatgpt·企业数据安全
Memene摸鱼日报5 小时前
「Memene 摸鱼日报 2025.9.11」腾讯推出命令行编程工具 CodeBuddy Code, ChatGPT 开发者模式迎来 MCP 全面支持
人工智能·chatgpt·agi
linjoe995 小时前
【Deep Learning】Ubuntu配置深度学习环境
人工智能·深度学习·ubuntu
Greedy Alg6 小时前
LeetCode 142. 环形链表 II
算法
睡不醒的kun6 小时前
leetcode算法刷题的第三十二天
数据结构·c++·算法·leetcode·职场和发展·贪心算法·动态规划
先做个垃圾出来………6 小时前
残差连接的概念与作用
人工智能·算法·机器学习·语言模型·自然语言处理
AI小书房7 小时前
【人工智能通识专栏】第十三讲:图像处理
人工智能
fanstuck7 小时前
基于大模型的个性化推荐系统实现探索与应用
大数据·人工智能·语言模型·数据挖掘