[Triton笔记4]低内存 Dropout

DropOut介绍

Dropout(随机失活)是深度学习中一种强大的正则化技术。为了理解其数学原理,我们可以将其拆解为训练阶段推理(评估)阶段两个部分来讨论。

1. 核心数学定义

假设我们有一个输入向量 x ∈ R n x \in \mathbb{R}^n x∈Rn。在应用 Dropout 时,我们引入一个随机向量 r ∈ { 0 , 1 } n r \in \{0, 1\}^n r∈{0,1}n,其中每个元素 r i r_i ri 服从伯努利分布(Bernoulli distribution):

r i ∼ Bernoulli ( 1 − p ) r_i \sim \text{Bernoulli}(1-p) ri∼Bernoulli(1−p)

这里, p p p 是失活概率(即置为 0 的概率)。

训练阶段 (Training)

在训练过程中,输出向量 y y y 的计算方式如下:

y = x ⊙ r y = x \odot r y=x⊙r

其中 ⊙ \odot ⊙ 表示逐元素乘法(Hadamard product)。这意味着对于每个神经元,以概率 p p p 将其置为 0,以概率 1 − p 1-p 1−p 将其保留。

推理阶段 (Inference)

在测试或部署时,我们希望利用网络学习到的全部信息。此时,我们不再进行随机失活。如果不做任何处理,直接使用训练好的权重,推理时的期望输出会比训练时大。

为了保证训练和推理阶段在期望值(Expectation)上的一致性,我们需要进行缩放。

2. 为什么需要缩放?

我们可以从期望值的角度来推导。

对于输入 x x x 的任意一个元素 x i x_i xi,在训练阶段,经过 Dropout 后的输出值 y i y_i yi 是一个随机变量:

  • y i = x i y_i = x_i yi=xi (概率为 1 − p 1-p 1−p)

  • y i = 0 y_i = 0 yi=0 (概率为 p p p)

其数学期望为:

E [ y i ] = x i ⋅ ( 1 − p ) + 0 ⋅ p = x i ( 1 − p ) E[y_i] = x_i \cdot (1-p) + 0 \cdot p = x_i(1-p) E[yi]=xi⋅(1−p)+0⋅p=xi(1−p)

这意味着在训练时,神经元的输出平均只有原始值的 ( 1 − p ) (1-p) (1−p) 倍。为了使推理阶段的输出与训练阶段的期望输出在量级上保持一致,我们在推理阶段有两种处理方案:

方案 A:倒置 Dropout (Inverted Dropout) ------ PyTorch 采用的方法

在训练时,我们将保留下来的神经元放大 1 1 − p \frac{1}{1-p} 1−p1 倍:

y t r a i n = 1 1 − p ( x ⊙ r ) y_{train} = \frac{1}{1-p} (x \odot r) ytrain=1−p1(x⊙r)

这样,训练时的期望值就变回了:

E [ y t r a i n ] = 1 1 − p ⋅ x i ( 1 − p ) = x i E[y_{train}] = \frac{1}{1-p} \cdot x_i(1-p) = x_i E[ytrain]=1−p1⋅xi(1−p)=xi

优点: 在推理阶段,我们什么都不用做,直接使用网络即可。这极大简化了部署和推理的代码逻辑。

方案 B:推理缩放

如果在训练时不进行上述放大,那么在推理阶段,我们需要手动将权重或输出乘上 ( 1 − p ) (1-p) (1−p),以匹配训练时的缩放水平。这种方式现在较少使用。

3. 总结:一致性原则

Dropout 的数学本质是通过稀疏化输入来强制模型学习鲁棒特征(防止共同适应,Co-adaptation)。

  • 正则化效果: 通过每次训练只激活一部分神经元,模型无法依赖特定的特征组合,这相当于训练了一个由多个子网络组成的"集成模型(Ensemble)"。

  • 范数保持(Norm Consistency): 通过引入倒置缩放因子 1 1 − p \frac{1}{1-p} 1−p1,确保了无论丢弃多少神经元,特征层的激活值在训练和推理时的分布尺度保持一致,从而避免了激活值数值漂移对后续层(如 Softmax 或 Batch Normalization)的影响。

Baseline

首先看一下 baseline 的实现。

python 复制代码
import tabulate
import torch


import triton
import triton.language as tl


@triton.jit
def _dropout(
    x_ptr,      # 输入指针
    x_keep_ptr, # pointer to a mask of 0s and 1s 由 0 和 1 组成的掩码的指针
    output_ptr, # pointer to the output 输出指针
    n_elements, # number of elements in the `x` tensor `x` 张量的元素数量
    p,          # probability that an element of `x` is changed to zero 元素 `x` 被设置为 0 的概率
    BLOCK_SIZE: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    block_start = pid * BLOCK_SIZE
    offsets = block_start + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    # Load data
    # 加载数据
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)
    # The line below is the crucial part, described in the paragraph above!
    # 下一行是上段描述的关键部分
    output = tl.where(x_keep, x / (1 - p), 0.0)
    # Write-back output
    # 写回输出
    tl.store(output_ptr + offsets, output, mask=mask)


def dropout(x, x_keep, p):
    output = torch.empty_like(x)
    assert x.is_contiguous()
    n_elements = x.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']), )
    _dropout[grid](x, x_keep, output, n_elements, p, BLOCK_SIZE=1024)
    return output


# Input tensor
# 输入张量
x = torch.randn(size=(10, )).cuda()
# Dropout mask
# Dropout 掩码
p = 0.5
x_keep = (torch.rand(size=(10, )) > p).to(torch.int32).cuda()
#
output = dropout(x, x_keep=x_keep, p=p)
print(tabulate.tabulate([
    ["input"] + x.tolist(),
    ["keep mask"] + x_keep.tolist(),
    ["output"] + output.tolist(),
]))

运行结果

bash 复制代码
---------  ---------  --------  --------  --------  --------  --------  --------  --------  --------  --------
input      -0.285416  0.932678  -1.73046  0.353095  -0.62449  0.469328  0.489296  -1.01694  0.310098  -1.38205
keep mask   0         1          1        1          0        1         0          1        0          1
output      0         1.86536   -3.46091  0.70619    0        0.938655  0         -2.03389  0         -2.76411
---------  ---------  --------  --------  --------  --------  --------  --------  --------  --------  --------

这段代码展示了在 Triton 中实现 Dropout 的一种直接(内存密集型)方式。它的核心逻辑是将"掩码生成"与"数据计算"分离开来。

我们可以从以下几个关键维度来理解这段 Baseline 代码的原理:

1. 处理流程的"显式分离"

在这个 Baseline 中,Dropout 的过程被分为两步:

  • 掩码准备(外部生成): 在调用 dropout 函数之前,你已经在 Python 端通过 torch.rand 生成了一个与输入 x 同形状的 x_keep 张量(位掩码)。这通常会消耗额外的显存,并且需要将这个掩码从内存(或显存)中读取出来。

  • Triton 计算(核函数执行): Triton 的 _dropout 函数只负责"消费"这个已经存在的掩码。

2. Triton 核函数逻辑剖析

_dropout 函数中,Triton 采用了分块处理(Tiling)的方法:

  • 并行化分块: 通过 pid = tl.program_id(axis=0),代码将输入张量切分成多个 BLOCK_SIZE 大小的块。每个 GPU 线程块(Block)独立处理一部分数据,这保证了极高的并行效率。

  • 内存加载:

    python 复制代码
    x = tl.load(x_ptr + offsets, mask=mask)
    x_keep = tl.load(x_keep_ptr + offsets, mask=mask)

    这里使用了 tl.load 从显存中加载数据。注意,为了处理边界问题(即数组长度不是 BLOCK_SIZE 的整数倍),使用了 mask=mask 参数。

  • 核心逻辑运算:

python 复制代码
    output = tl.where(x_keep, x / (1 - p), 0.0)

这是该函数最核心的部分,即倒置 Dropout (Inverted Dropout) 的实现:

  • tl.where(condition, x, y):类似于 C 语言的 condition ? x : y
  • 如果 x_keep 为 1(True),则执行 x / (1 - p)。这正如你之前问到的数学原理,通过除以 ( 1 − p ) (1-p) (1−p) 来缩放保留的神经元。
  • 如果 x_keep 为 0(False),则直接填入 0.0,实现"失活"。

3. 该 Baseline 的"局限性"(为什么要优化?)

虽然这段代码逻辑清晰,但它存在一个明显的性能瓶颈,即题目中提到的**"低内存"改进的动机**:

  1. 显存带宽浪费: 你必须预先分配一个和 x 一样大的 x_keep 张量。这意味着在执行时,不仅要读输入 x,还要读一个巨大的 x_keep。对于大规模模型,这种额外的显存读写操作会严重拖慢运行速度。
  2. 显存占用: x_keep 占用了额外的空间。如果输入张量非常巨大,这可能会导致显存溢出。
  3. 调度开销: 外部生成 x_keep 涉及额外的 CPU-GPU 交互和启动 kernel 的开销。

总结

这个 Baseline 的原理是基于查找表的执行模式

它把随机性来源(随机掩码)当成普通的输入数据来处理。这种方法是"预计算"的,因此逻辑简单且易于调试。

接下来的进阶方向:

真正的"低内存"实现会舍弃 x_keep_ptr 参数,改为在 kernel 内部使用一个 seed,利用 tl.rand(seed, offsets) 直接在 GPU 寄存器中即时生成随机数,从而彻底省去存储和读取 x_keep 的开销。

种子化 Dropout

在之前的 Baseline 中,x_keep 是一个显存中的张量,这导致了额外的显存读写(Read/Write)。

在这里,我们描述一种替代实现,它具有以下优点:

  1. 更小的内存占用。
  2. 较少的数据移动。
  3. 简化了在多次调用内核函数时持久化随机性的管理。

生成 Triton 中的伪随机数很简单!在本教程中,我们将使用 triton.language.rand 函数,该函数基于给定的种子和一组 int32 偏移量生成一个块的均匀分布的 float32 值,范围在 (0, 1) 内。

注意 Triton 的 PRNG 实现基于 Philox 算法(详见 [SALMON2011])。

1. 核心变化:从"存"到"算"

在之前的 Baseline 中,x_keep 是一个显存中的张量,这导致了额外的显存读写(Read/Write)。而在 _seeded_dropout 中:

python 复制代码
random = tl.rand(seed, offsets)
x_keep = random > p
  • 即时计算(Just-in-Time Generation): 每一对线程都使用 tl.rand 动态生成随机值。因为它是基于 seedoffsets(位置索引)计算出来的,只要这两个值确定,生成的随机数序列就是确定的。

  • 内存零占用: 你不再需要预先分配一个形状为 (10,)x_keep 张量,程序运行中只有输入 x 和输出 output 占据显存。

2. 为什么这是"低内存"的?

  • 带宽开销: 传统 Dropout 需要从显存加载 mask 数据,这会消耗巨大的显存带宽。对于显存受限的 GPU 计算(如大型 Transformer 层),带宽往往是瓶颈。

  • Philox 算法: Triton 使用的是 Philox 算法,这是一种专门针对并行计算设计的计数器驱动(Counter-based)PRNG(伪随机数生成器)。它的特点是:给定一个输入 (seed, offset),它就能通过一系列高效的位运算直接输出对应的随机数,完全不需要保存随机状态。这使得它可以在 GPU 的寄存器(Register)中直接完成计算,无需访问慢速的显存。

triton.language.rand

python 复制代码
@jit
def rand(seed, offset, n_rounds: tl.constexpr = N_ROUNDS_DEFAULT):

1. 参数含义 (Parameters)

  • seed (Scalar):

    • 含义 : 这是一个整数(通常为 int32int64)。

    • 作用 : 它定义了随机数序列的起始状态 。在并行计算中,seed 确保了只要该值固定,无论何时何地运行,生成的"随机"序列都是完全一致的(这对于调试、模型保存和分布式训练中的一致性至关重要)。

  • offset (Block):

    • 含义 : 一个与数据索引对应的块(例如在当前 _seeded_dropout 中,就是 block_start + tl.arange(0, BLOCK_SIZE))。

    • 作用 : 它充当了 Philox 算法中的"计数器(Counter)"。由于每个线程(或线程块)处理的数据位置不同,通过 offset 保证了即使共享同一个 seed,每个位置上的元素也会生成不同且相互独立的随机数。

  • n_rounds (constexpr, 默认值):

    • 含义: 一个编译期常量,决定了 Philox 算法内部迭代轮数。

    • 作用: 轮数越多,随机分布的统计质量越好(越接近均匀分布),但计算开销也越大。通常使用默认值即可在性能和统计特性之间取得平衡。

2. 返回值 (Return Value)

  • 返回值 : 一个与 offset 形状相同的块(Block),其中的元素类型为 float32

  • 分布特性 : 结果遵循标准均匀分布 ,即每个值都在区间 [ 0 , 1 ) [0, 1) [0,1) 内。

相关推荐
凌波粒6 小时前
深度学习入门(鱼书)第2章笔记——感知机
人工智能·笔记·深度学习
RainCityLucky7 小时前
Java Swing 自定义组件库分享(七)
java·笔记·后端
清平乐的技术专栏7 小时前
【Kafka笔记】(一)认识 Kafka
笔记·分布式·kafka
Fuyo_11197 小时前
C++中的活字印刷术——模板·初阶
开发语言·c++·笔记
大明者省7 小时前
Ubuntu22.04 宝塔面板与 XFCE 远程桌面端口兼容性分析
运维·服务器·数据库·笔记
哆哆啦ss8 小时前
使用 Obsidian + GitHub Actions + GitHub Pages 搭建内容发布流
笔记
清平乐的技术专栏8 小时前
【Kafka笔记】(四)Kafka 三种消费模式
笔记·分布式·kafka
LuminousCPP8 小时前
数据结构 - 线性表第三篇:基于顺序表实现 C 语言通讯录(基础功能篇)
c语言·数据结构·经验分享·笔记·算法
Szime8 小时前
深智微华润微代理端整理:FS32K144国产化替代三年BCM选型验证避坑笔记
笔记