cuda编程笔记(29)-- CUDA Graph

CUDA Graph 是目前 NVIDIA 官方推荐的训练加速技术之一 ,它能显著降低 CPU 启动开销 ,提高训练循环中 kernel 启动效率与吞吐量

CUDA Graph 的背景与动机

在普通的深度学习训练过程中,一个训练 step 通常包含如下操作:

复制代码
Forward -> Loss -> Backward -> Optimizer Step

每个操作都需要 CPU 向 GPU 发出一系列 kernel 启动命令,例如:

  • 矩阵乘法 kernel;

  • 激活函数 kernel;

  • 梯度计算 kernel;

  • 参数更新 kernel。

这些操作虽然在 GPU 上执行极快(几微秒级),但每次启动都需要 CPU 与 GPU 之间的同步与调度,这带来较大的启动延迟

当 batch size 很小,或模型较小时,CPU 启动开销反而成为瓶颈

CUDA Graph 的核心思想

CUDA Graph 的思路是:

"把一整段固定计算过程记录成图(Graph),然后多次执行(replay)。"

也就是说:

  1. 你先"记录"一次训练步骤中所有 GPU 操作;

  2. 然后每次训练迭代,只需调用一次执行命令,整个图就能在 GPU 内部直接执行,无需 CPU 再次参与。

这相当于预编译计算图,极大减少 CPU 调度与 CUDA kernel 启动延迟。

API介绍

CUDA Graph API 主要涉及以下核心类型:

类型 含义
cudaGraph_t 表示一个图(Graph),由若干节点(kernel/memcpy/memset 等)和依赖关系组成。
cudaGraphNode_t 表示图中的一个节点(Node)。
cudaGraphExec_t 表示图的"可执行版本"(Executable Graph),通过 cudaGraphInstantiate 创建,可直接执行。

1️⃣ cudaGraphCreate

cpp 复制代码
cudaError_t cudaGraphCreate(cudaGraph_t *pGraph, unsigned int flags);

功能:

创建一个空的 CUDA 图。

参数:

  • pGraph:输出参数,用于返回创建的图对象;

  • flags:目前必须为 0(保留字段)。

2️⃣ cudaGraphDestroy

cpp 复制代码
cudaError_t cudaGraphDestroy(cudaGraph_t graph);

功能:

销毁图对象,释放相关资源。

向图中添加节点

1️⃣ cudaGraphAddKernelNode

cpp 复制代码
cudaError_t cudaGraphAddKernelNode(
    cudaGraphNode_t *pGraphNode,
    cudaGraph_t graph,
    const cudaGraphNode_t *pDependencies,
    size_t numDependencies,
    const struct cudaKernelNodeParams *pNodeParams
);

功能:

向图中添加一个 Kernel 执行节点

参数解释:

参数 含义
pGraphNode 输出参数,返回创建的节点句柄
graph 所属图对象
pDependencies 当前节点依赖的节点数组(即必须先执行完这些节点)
numDependencies 依赖节点数量
pNodeParams 内核参数结构体(见下)

cudaKernelNodeParams 结构体

cpp 复制代码
typedef struct cudaKernelNodeParams {
    void *func;             // kernel 函数指针(必须使用 void* 转换)
    dim3 gridDim;           // grid 大小
    dim3 blockDim;          // block 大小
    unsigned int sharedMemBytes; // 动态共享内存大小(字节)
    void **kernelParams;    // 参数数组(void* 指针数组)
    void **extra;           // 备用字段(一般为 NULL)
} cudaKernelNodeParams;

示例:

cpp 复制代码
__global__ void myKernel(float *data) {
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    data[idx] *= 2.0f;
}

cudaGraph_t graph;
cudaGraphNode_t kernelNode;
cudaGraphCreate(&graph, 0);

float *d_data;
cudaMalloc(&d_data, 1024 * sizeof(float));

void *kernelArgs[] = { &d_data };
cudaKernelNodeParams params = {};
params.func = (void*)myKernel;
params.gridDim = dim3(32);
params.blockDim = dim3(32);
params.sharedMemBytes = 0;
params.kernelParams = kernelArgs;
params.extra = NULL;

cudaGraphAddKernelNode(&kernelNode, graph, NULL, 0, &params);

2️⃣ cudaGraphAddMemcpyNode

cpp 复制代码
cudaError_t cudaGraphAddMemcpyNode(
    cudaGraphNode_t *pGraphNode,
    cudaGraph_t graph,
    const cudaGraphNode_t *pDependencies,
    size_t numDependencies,
    const struct cudaMemcpy3DParms *pCopyParams
);

功能:

添加一个 内存拷贝节点(支持 host ↔ device、device ↔ device)。

结构体: cudaMemcpy3DParms

和普通的 cudaMemcpy3D() 参数一致,支持指定源地址、目标地址、大小、方向等。

3️⃣ cudaGraphAddMemsetNode

cpp 复制代码
cudaError_t cudaGraphAddMemsetNode(
    cudaGraphNode_t *pGraphNode,
    cudaGraph_t graph,
    const cudaGraphNode_t *pDependencies,
    size_t numDependencies,
    const struct cudaMemsetParams *pMemsetParams
);

用于在图中添加内存初始化(memset)节点。

建立依赖关系

cpp 复制代码
cudaError_t cudaGraphAddDependencies(
    cudaGraph_t graph,
    const cudaGraphNode_t *from,
    const cudaGraphNode_t *to,
    size_t numDependencies
);

功能:

在图中添加边(依赖关系),指定某些节点要在其他节点之后执行。

cpp 复制代码
cudaGraphAddDependencies(graph, &memcpyNode, &kernelNode, 1);

表示 kernelNode 依赖 memcpyNode。

实例化可执行图

cpp 复制代码
cudaError_t cudaGraphInstantiate(
    cudaGraphExec_t *pGraphExec,
    cudaGraph_t graph,
    cudaGraphNode_t *pErrorNode,
    char *pLogBuffer,
    size_t bufferSize
);

功能:

将一个静态图编译成 可执行图对象(GraphExec),后续可直接执行。

参数 含义
pGraphExec 输出参数,可执行图对象
graph 原始图
pErrorNode 若实例化失败,返回出错节点
pLogBuffer / bufferSize 存放错误信息的日志缓冲区

执行图

cpp 复制代码
cudaError_t cudaGraphLaunch(
    cudaGraphExec_t graphExec,
    cudaStream_t stream
);

功能:

在指定的 CUDA 流上执行整个图。

更新图(动态调整)

cpp 复制代码
cudaError_t cudaGraphExecUpdate(
    cudaGraphExec_t graphExec,
    cudaGraph_t newGraph,
    cudaGraphNode_t *pErrorNode,
    char *pLogBuffer,
    size_t bufferSize
);

这比重新实例化整个图更高效。

销毁图对象

cpp 复制代码
cudaGraphExecDestroy(graphExec);
cudaGraphDestroy(graph);

CUDA 流捕获(Stream Capture) 的 API

它是 CUDA Graph 高层接口里最常用的方式,尤其适合训练或者推理任务。流捕获可以自动记录流上的 kernel、memcpy、memset 等操作,然后生成一个可执行图,省去手动添加节点和依赖的步骤。

开始流捕获:cudaStreamBeginCapture

cpp 复制代码
cudaError_t cudaStreamBeginCapture(
    cudaStream_t stream,
    cudaStreamCaptureMode mode
);

功能:

在指定流上开始捕获操作,之后该流上执行的 kernel/memcpy/memset 等操作都会被记录到一个图中,而不是立即执行。

参数:

参数 说明
stream 需要捕获的 CUDA 流
mode 捕获模式,通常用 cudaStreamCaptureModeGlobalcudaStreamCaptureModeThreadLocal

模式说明:

  • cudaStreamCaptureModeGlobal:流捕获期间,其他流的操作可能被依赖。

  • cudaStreamCaptureModeThreadLocal:捕获只在当前线程流有效,不会影响其他线程。

结束流捕获:cudaStreamEndCapture

cpp 复制代码
cudaError_t cudaStreamEndCapture(
    cudaStream_t stream,
    cudaGraph_t *pGraph
);

功能:

结束捕获并生成一个 CUDA Graph 对象

参数:

参数 说明
stream 捕获的流
pGraph 输出参数,返回捕获到的 CUDA Graph 对象

捕获得到的图可以像普通图一样实例化并执行:

cpp 复制代码
cudaGraphExec_t graphExec;
cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0);
cudaGraphLaunch(graphExec, stream);
cudaStreamSynchronize(stream);

如果捕获的图只是部分参数变化,可以用:

cpp 复制代码
cudaGraphExecUpdate(graphExec, newGraph, nullptr, nullptr, 0);

相比重新实例化,更新开销更小。

注意事项

  1. 捕获期间所有操作都是记录而不是立即执行

    • 如果捕获中调用 cudaDeviceSynchronize 或其他同步操作,会报错。
  2. 内存操作

    • 流捕获可以自动记录 cudaMemcpy / cudaMemset,无需手动添加节点。
  3. 嵌套捕获

    • CUDA 11+ 支持嵌套流捕获,但需要小心依赖关系。

示例代码

cpp 复制代码
cudaStream_t stream;
cudaStreamCreate(&stream);

cudaStreamBeginCapture(stream, cudaStreamCaptureModeGlobal);

// kernel 或 memcpy 操作
myKernel<<<grid, block, 0, stream>>>(d_data);
cudaMemcpyAsync(d_out, d_data, size, cudaMemcpyDeviceToDevice, stream);

cudaGraph_t graph;
cudaStreamEndCapture(stream, &graph);

cudaGraphExec_t graphExec;
cudaGraphInstantiate(&graphExec, graph, nullptr, nullptr, 0);
cudaGraphLaunch(graphExec, stream);
cudaStreamSynchronize(stream);

cudaGraphExecDestroy(graphExec);
cudaGraphDestroy(graph);
cudaStreamDestroy(stream);

这样就可以 自动把流上的操作打包成一个可执行图,执行效率比每次 kernel launch 更高,特别适合深度学习训练循环。

相关推荐
Larry_Yanan4 小时前
QML学习笔记(四十一)QML的ColorDialog和FontDialog
笔记·学习
润 下4 小时前
C语言——深入解析C语言指针:从基础到实践从入门到精通(四)
c语言·开发语言·人工智能·经验分享·笔记·程序人生·其他
koo3644 小时前
李宏毅机器学习笔记25
人工智能·笔记·机器学习
hzp6664 小时前
Magnus:面向大规模机器学习工作负载的综合数据管理方法
人工智能·深度学习·机器学习·大模型·llm·数据湖·大数据存储
m0_678693334 小时前
深度学习笔记39-CGAN|生成手势图像 | 可控制生成(Pytorch)
深度学习·学习·生成对抗网络
将车2444 小时前
C++实现二叉树搜索树
开发语言·数据结构·c++·笔记·学习
日更嵌入式的打工仔4 小时前
存储同步管理器SyncManager 归纳
笔记·单片机·嵌入式硬件
Larry_Yanan4 小时前
QML学习笔记(四十)QML的FileDialog和FolderDialog
笔记·qt·学习
还是大剑师兰特5 小时前
Transformer 面试题及详细答案120道(91-100)-- 理论与扩展
人工智能·深度学习·transformer·大剑师