CUDA Stream、CUDA Event、CUDA Graph 用例

cuda_graph_examples

复制代码
"""
CUDA Graph 中的 Stream 和 Event 用法

CUDA Graph 将一系列 kernel 操作录制为图,然后可以重放,大幅减少 launch overhead。
Stream 和 Event 在 CUDA Graph 中的关键作用:
1. graph 必须在非默认 stream 上捕获 (capture)
2. graph 在调用 replay() 时的当前 stream 上重放 (可以是任意 stream)
3. Event 用于 graph 与外部 stream 之间的同步

主要 API:
- torch.cuda.CUDAGraph()                    创建 graph
- graph.capture_begin(pool=None)             开始捕获 (pool 可选, 用于多个 graph 共享内存池)
- graph.capture_end()                        结束捕获
- graph.replay()                             重放 graph
- graph.reset()                              重置 graph
- torch.cuda.graph_pool_handle()             返回一个 graph 内存池的不透明句柄 (可传给多个 graph 共享)
- torch.cuda.graph(graph, pool, stream)      捕获用的上下文管理器 (stream 默认为内部 side stream)
"""

import torch
import time


def example_basic_cuda_graph():
    """CUDA Graph basics: capture a matmul once, then replay it.

    Compares the average wall-clock time of ``g.replay()`` with eager
    ``a @ b``. Requires a CUDA device; only prints results.
    """
    print("=" * 60)
    print("1. CUDA Graph 基础: 捕获与重放")
    print("=" * 60)

    N = 4096
    iters = 100
    # Static inputs: a captured graph reads fixed addresses, so new data must be
    # copied into these tensors in place rather than rebinding the names.
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")

    g = torch.cuda.CUDAGraph()

    # Capture must happen on a non-default stream. A new stream is not ordered
    # after the default stream, so wait for the randn kernels producing a/b.
    capture_stream = torch.cuda.Stream()
    capture_stream.wait_stream(torch.cuda.current_stream())

    # Warm up on the same side stream that will be captured, so one-time lazy
    # initialization (e.g. cuBLAS setup) does not happen during capture.
    with torch.cuda.stream(capture_stream):
        for _ in range(3):
            c = a @ b
    capture_stream.synchronize()

    # === Capture ===
    with torch.cuda.stream(capture_stream):
        g.capture_begin()
        # `c` now refers to the graph's static output; each replay overwrites it.
        c = a @ b
        g.capture_end()

    # === Replay ===
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        g.replay()
    torch.cuda.synchronize()
    graph_time = (time.perf_counter() - start) / iters * 1000

    # Baseline: eager execution of the same op.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        _ = a @ b
    torch.cuda.synchronize()
    eager_time = (time.perf_counter() - start) / iters * 1000

    print(f"Graph 重放平均耗时: {graph_time:.4f} ms")
    print(f"Eager 执行平均耗时: {eager_time:.4f} ms")
    print(f"加速比: {eager_time / graph_time:.2f}x")
    print()


def example_graph_on_custom_stream():
    """Capture a graph on a user-created stream and replay it on that stream."""
    print("=" * 60)
    print("2. 自定义 Stream 上的 CUDA Graph")
    print("=" * 60)

    N = 2048
    s = torch.cuda.Stream()

    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")

    g = torch.cuda.CUDAGraph()

    # a/b were produced on the default stream; `s` is not implicitly ordered
    # after it, so make `s` wait before any of its work reads them.
    s.wait_stream(torch.cuda.current_stream())

    # Warm up on `s` before capturing on `s`.
    with torch.cuda.stream(s):
        c = a @ b
    s.synchronize()

    # Capture on the custom stream `s`.
    with torch.cuda.stream(s):
        g.capture_begin()
        c = a @ b
        g.capture_end()

    # Replay on `s`: replay() enqueues onto whichever stream is current.
    with torch.cuda.stream(s):
        g.replay()
    s.synchronize()

    print("Graph 在自定义 stream 上重放成功")
    print(f"c shape: {c.shape}, device: {c.device}")
    print()


def example_graph_with_event_sync():
    """Order follow-up work on another stream after a graph replay via an Event.

    The graph is captured and replayed on ``graph_stream``. An Event recorded
    after the replay lets ``compute_stream`` consume the graph output ``c``
    without blocking the CPU.
    """
    print("=" * 60)
    print("3. Graph + Event 同步")
    print("=" * 60)

    N = 2048

    graph_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()

    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")

    g = torch.cuda.CUDAGraph()

    # a/b come from the default stream; order graph_stream after it.
    graph_stream.wait_stream(torch.cuda.current_stream())

    # Warm-up on the stream that will capture.
    with torch.cuda.stream(graph_stream):
        c = a @ b
    graph_stream.synchronize()

    # Capture on graph_stream; `c` becomes the graph's static output.
    with torch.cuda.stream(graph_stream):
        g.capture_begin()
        c = a @ b
        g.capture_end()

    # Scenario: after the replay finishes, compute_stream does follow-up work.
    with torch.cuda.stream(graph_stream):
        g.replay()

    # Mark the point in graph_stream right after the replay.
    graph_done = torch.cuda.Event()
    graph_done.record(graph_stream)

    # compute_stream (not the CPU) waits for the replay, then reads `c`.
    with torch.cuda.stream(compute_stream):
        graph_done.wait(compute_stream)
        d = c.relu()

    compute_stream.synchronize()
    print("Graph 重放 + Event 同步完成")
    print(f"c shape: {c.shape}, d shape: {d.shape}")
    print(f"d 前 5 个元素: {d[0, :5]}")
    print()


def example_multiple_graphs_parallel():
    """Replay two independent graphs on two streams vs. back to back.

    Each timing is a single wall-clock measurement, so the numbers are noisy.
    """
    print("=" * 60)
    print("4. 多 Graph 并行重放")
    print("=" * 60)

    N = 2048

    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()

    # Two independent input sets.
    a1 = torch.randn(N, N, device="cuda")
    b1 = torch.randn(N, N, device="cuda")

    a2 = torch.randn(N, N, device="cuda")
    b2 = torch.randn(N, N, device="cuda")

    g1 = torch.cuda.CUDAGraph()
    g2 = torch.cuda.CUDAGraph()

    # Inputs were produced on the default stream; side streams are not
    # implicitly ordered after it.
    cur = torch.cuda.current_stream()
    s1.wait_stream(cur)
    s2.wait_stream(cur)

    # Warm-up on each capture stream.
    with torch.cuda.stream(s1):
        c1 = a1 @ b1
    with torch.cuda.stream(s2):
        c2 = a2 @ b2
    s1.synchronize()
    s2.synchronize()

    # Capture each graph on its own stream.
    with torch.cuda.stream(s1):
        g1.capture_begin()
        c1 = a1 @ b1
        g1.capture_end()

    with torch.cuda.stream(s2):
        g2.capture_begin()
        c2 = a2 @ b2
        g2.capture_end()

    # The first launch of a freshly captured graph may carry one-time setup
    # cost; keep it out of both timed regions.
    g1.replay()
    g2.replay()

    # Parallel replay: one graph per stream.
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.cuda.stream(s1):
        g1.replay()
    with torch.cuda.stream(s2):
        g2.replay()
    s1.synchronize()
    s2.synchronize()
    parallel_time = time.perf_counter() - start

    # Serial replay: both graphs on the current (default) stream.
    torch.cuda.synchronize()
    start = time.perf_counter()
    g1.replay()
    g2.replay()
    torch.cuda.synchronize()
    serial_time = time.perf_counter() - start

    print(f"并行重放时间: {parallel_time * 1000:.3f} ms")
    print(f"串行重放时间: {serial_time * 1000:.3f} ms")
    print()


def example_graph_mempool():
    """Share one graph memory pool between two captures via graph_pool_handle()."""
    print("=" * 60)
    print("5. Graph Memory Pool")
    print("=" * 60)

    N = 2048

    # graph_pool_handle() returns an opaque token identifying a graph memory
    # pool. Passing the same token to several captures makes them share it.
    pool = torch.cuda.graph_pool_handle()
    print(f"Mempool handle: {pool}")

    # A token can also be created while another device is current.
    if torch.cuda.device_count() > 1:
        with torch.cuda.device(1):
            pool1 = torch.cuda.graph_pool_handle()
            print(f"Device 1 mempool: {pool1}")

    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")

    g = torch.cuda.CUDAGraph()
    g2 = torch.cuda.CUDAGraph()

    # Capture must use a non-default stream, ordered after the default stream
    # that produced a/b.
    capture_stream = torch.cuda.Stream()
    capture_stream.wait_stream(torch.cuda.current_stream())

    # Warm up the captured ops on the capture stream.
    with torch.cuda.stream(capture_stream):
        c = a @ b
        d = c.relu()
    capture_stream.synchronize()

    with torch.cuda.stream(capture_stream):
        g.capture_begin(pool=pool)
        c = a @ b
        g.capture_end()

        # Second graph consumes g's output and allocates from the same pool.
        # Graphs that share a pool must be replayed in capture order and never
        # concurrently.
        g2.capture_begin(pool=pool)
        d = c.relu()
        g2.capture_end()

    g.replay()
    g2.replay()
    torch.cuda.synchronize()
    print("使用 mempool 的 graph 重放成功")
    print(f"共享 mempool 的第二个 graph 输出: d[0,0] = {d[0, 0].item():.4f}")
    print()


def example_graph_reset_and_reuse():
    """Reset a CUDAGraph and re-capture different work into the same object."""
    print("=" * 60)
    print("6. Graph Reset 和复用")
    print("=" * 60)

    N = 1024

    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")

    g = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()

    # Order the capture stream after the default stream that produced a/b.
    capture_stream.wait_stream(torch.cuda.current_stream())

    # Warm up every op captured below (matmul and relu) on the capture stream.
    with torch.cuda.stream(capture_stream):
        c = (a @ b).relu()
    capture_stream.synchronize()

    # First capture: matmul only (capture must be on a non-default stream).
    with torch.cuda.stream(capture_stream):
        g.capture_begin()
        c = a @ b
        g.capture_end()
    g.replay()
    torch.cuda.synchronize()
    print(f"第一次重放 (matmul): c[0,0] = {c[0, 0].item():.4f}")

    # reset() discards the captured graph so the same object can capture again.
    g.reset()

    # Second capture: matmul + relu.
    with torch.cuda.stream(capture_stream):
        g.capture_begin()
        c = (a @ b).relu()
        g.capture_end()
    g.replay()
    torch.cuda.synchronize()
    print(f"第二次重放 (matmul+relu): c[0,0] = {c[0, 0].item():.4f}")
    print()


def example_graph_with_stream_wait():
    """Make a graph replay wait (via wait_stream) for in-place input updates."""
    print("=" * 60)
    print("7. Graph Stream + wait_stream 同步")
    print("=" * 60)

    N = 2048

    graph_stream = torch.cuda.Stream()
    data_stream = torch.cuda.Stream()

    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")

    g = torch.cuda.CUDAGraph()

    # Both side streams touch a/b, which were produced on the default stream.
    cur = torch.cuda.current_stream()
    graph_stream.wait_stream(cur)
    data_stream.wait_stream(cur)

    # Warm-up
    with torch.cuda.stream(graph_stream):
        c = a @ b
    graph_stream.synchronize()

    # Capture
    with torch.cuda.stream(graph_stream):
        g.capture_begin()
        c = a @ b
        g.capture_end()

    # Preprocess on data_stream. The graph reads a/b by address, so in-place
    # updates are what the next replay sees.
    with torch.cuda.stream(data_stream):
        a.mul_(2.0)
        b.mul_(0.5)

    # graph_stream waits for data_stream's updates before replaying.
    with torch.cuda.stream(graph_stream):
        graph_stream.wait_stream(data_stream)
        g.replay()

    graph_stream.synchronize()
    print("wait_stream + graph 重放完成")
    print(f"c[0,0] = {c[0, 0].item():.4f}")
    print()


def example_graph_debug_capture():
    """Show torch.cuda.is_current_stream_capturing() before, during and after capture."""
    print("=" * 60)
    print("8. 检查 Graph 捕获状态")
    print("=" * 60)

    N = 512
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")

    g = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()

    # Order the capture stream after the default stream that produced a/b.
    capture_stream.wait_stream(torch.cuda.current_stream())

    # Warm-up
    with torch.cuda.stream(capture_stream):
        c = a @ b
    capture_stream.synchronize()

    print(f"捕获前: is_current_stream_capturing = {torch.cuda.is_current_stream_capturing()}")

    with torch.cuda.stream(capture_stream):
        g.capture_begin()
        # Printing is CPU-only work, so it is allowed inside a capture.
        print(f"捕获中: is_current_stream_capturing = {torch.cuda.is_current_stream_capturing()}")
        c = a @ b
        g.capture_end()

    print(f"捕获后: is_current_stream_capturing = {torch.cuda.is_current_stream_capturing()}")

    g.replay()
    torch.cuda.synchronize()
    print(f"Graph 重放成功, c[0,0] = {c[0, 0].item():.4f}")
    print()


if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA 不可用,请在有 GPU 的环境中运行")
        # exit() is a site-module helper for the interactive shell and may be
        # absent (e.g. `python -S`); SystemExit always works.
        raise SystemExit(1)

    print(f"设备: {torch.cuda.get_device_name()}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print()

    example_basic_cuda_graph()
    example_graph_on_custom_stream()
    example_graph_with_event_sync()
    example_multiple_graphs_parallel()
    example_graph_mempool()
    example_graph_reset_and_reuse()
    example_graph_with_stream_wait()
    example_graph_debug_capture()

event_basics

复制代码
"""
torch.cuda.Event 基础用法示例

CUDA Event 用于:
1. 记录 stream 中某个时间点
2. 测量两个 event 之间的时间间隔 (GPU 端计时)
3. 在不同 stream 之间做精确同步
4. 阻塞 CPU 直到 GPU 到达某个 event 点

主要 API:
- torch.cuda.Event(enable_timing, blocking, interprocess)  创建 event
- event.record(stream)           在指定 stream 中记录 event
- event.wait(stream)             让指定 stream 等待 event 完成
- event.synchronize()            阻塞 CPU 直到 event 完成
- event.query()                  非阻塞检查 event 是否完成
- start.elapsed_time(end)        两个 event 之间的 GPU 时间 (ms), 两者都需 enable_timing=True
- event.ipc_handle()             获取用于跨进程共享的 IPC handle (对端用 Event.from_ipc_handle 重建)
"""

import torch
import time


def example_event_basic():
    """Event basics: record / synchronize / query / elapsed_time."""
    print("=" * 60)
    print("1. Event 基础: record / synchronize / query")
    print("=" * 60)

    a = torch.randn(2000, 2000, device="cuda")

    # Untimed warm-up so the measurement excludes one-time cuBLAS initialization.
    _ = a @ a

    # Timing must be requested explicitly: torch.cuda.Event defaults to
    # enable_timing=False, blocking=False, interprocess=False.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    # record() without an argument enqueues the marker on the current stream.
    start_event.record()
    _ = a @ a
    end_event.record()

    # synchronize(): block the CPU until the GPU reaches end_event.
    end_event.synchronize()

    # elapsed_time(): GPU time between the two events, in milliseconds.
    elapsed = start_event.elapsed_time(end_event)
    print(f"矩阵乘法耗时: {elapsed:.3f} ms")

    # query(): non-blocking completion check (True here, we already synced).
    print(f"end_event 已完成: {end_event.query()}")
    print()


def example_event_stream_sync():
    """Synchronize two streams with an Event recorded on the producer stream."""
    print("=" * 60)
    print("2. Event 跨 Stream 同步")
    print("=" * 60)

    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()

    a = torch.randn(2000, 2000, device="cuda")

    # `a` was produced on the default stream; s1 is not implicitly ordered after
    # it, so make s1 wait before reading `a`.
    s1.wait_stream(torch.cuda.current_stream())

    # Compute on s1.
    with torch.cuda.stream(s1):
        b = a @ a

    # Mark the point in s1 right after `b` is produced.
    event = torch.cuda.Event()
    event.record(s1)

    # s2 waits for the event (i.e. for s1 to reach that point).
    with torch.cuda.stream(s2):
        event.wait(s2)  # s2 (not the CPU) waits until s1 reaches the event
        c = b @ b       # `b` is ready here

    s2.synchronize()
    print(f"b shape: {b.shape}, c shape: {c.shape}")
    print("Event 跨 stream 同步成功!")
    print()


def example_event_timing():
    """Time one square matmul per size with CUDA events (GPU-side clock)."""
    print("=" * 60)
    print("3. Event 精确计时")
    print("=" * 60)

    for dim in (512, 1024, 2048, 4096):
        lhs = torch.randn(dim, dim, device="cuda")
        rhs = torch.randn(dim, dim, device="cuda")

        # One untimed run absorbs lazy initialization / kernel selection.
        _ = lhs @ rhs
        torch.cuda.synchronize()

        tick = torch.cuda.Event(enable_timing=True)
        tock = torch.cuda.Event(enable_timing=True)

        tick.record()
        _ = lhs @ rhs
        tock.record()
        tock.synchronize()

        print(f"  [{dim}x{dim}] 矩阵乘法: {tick.elapsed_time(tock):.3f} ms")

    print()


def example_event_blocking():
    """Contrast blocking-sync and spin-wait events.

    In both cases ``synchronize()`` makes the CPU wait for the GPU. With
    ``blocking=True`` the event uses CUDA blocking synchronization: the waiting
    thread yields/sleeps. With ``blocking=False`` (the default) the thread
    busy-spins, which wakes up faster but occupies a CPU core.
    """
    print("=" * 60)
    print("4. Blocking Event")
    print("=" * 60)

    a = torch.randn(1000, 1000, device="cuda")

    blocking_event = torch.cuda.Event(blocking=True)
    non_blocking_event = torch.cuda.Event(blocking=False)

    print(f"blocking event: {blocking_event}")
    print(f"non-blocking event: {non_blocking_event}")

    # Record each event once, right after the work it should mark. (Recording
    # an event again simply moves the marker; only the latest record counts.)
    _ = a @ a
    blocking_event.record()
    _ = a @ a
    non_blocking_event.record()

    # Both calls wait until the GPU reaches the event; they differ only in how
    # the CPU thread waits (sleep vs. spin).
    blocking_event.synchronize()
    print("Blocking event 同步完成")
    non_blocking_event.synchronize()
    print("Non-blocking event 同步完成")
    print()


def example_event_interprocess():
    """Create an interprocess event and export its IPC handle."""
    print("=" * 60)
    print("5. Interprocess Event (IPC)")
    print("=" * 60)

    # Interprocess events must have timing disabled.
    ipc_event = torch.cuda.Event(
        enable_timing=False,
        blocking=True,
        interprocess=True,
    )

    ipc_event.record()
    ipc_event.synchronize()

    # The handle is plain bytes that can be sent to another process.
    handle = ipc_event.ipc_handle()
    print(f"IPC handle type: {type(handle)}")
    print(f"IPC handle bytes: {handle.hex()[:40]}...")

    # The receiving process rebuilds the event with
    # torch.cuda.Event.from_ipc_handle(device, handle) and can then wait on it.
    # torch.cuda.ipc_collect() is unrelated: it frees CUDA memory released by
    # IPC consumers.
    print("(另一个进程可用 torch.cuda.Event.from_ipc_handle(device, handle) 重建此 event)")
    print()


def example_multiple_events_pipeline():
    """Chain three streams (s1 -> s2 -> s3) with one Event per stage boundary."""
    print("=" * 60)
    print("6. 多 Event 流水线同步")
    print("=" * 60)

    # Three streams forming a pipeline: s1 -> s2 -> s3.
    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()
    s3 = torch.cuda.Stream()

    a = torch.randn(2000, 2000, device="cuda")

    # One event per stage boundary.
    e1_done = torch.cuda.Event()
    e2_done = torch.cuda.Event()

    # `a` came from the default stream; order the first stage after it.
    s1.wait_stream(torch.cuda.current_stream())

    # Stage 1 on s1.
    with torch.cuda.stream(s1):
        b = a @ a
    e1_done.record(s1)

    # Stage 2: s2 waits for stage 1.
    with torch.cuda.stream(s2):
        e1_done.wait(s2)
        c = b @ b
    e2_done.record(s2)

    # Stage 3: s3 waits for stage 2 (and thus transitively for stage 1).
    with torch.cuda.stream(s3):
        e2_done.wait(s3)
        d = c.relu()

    s3.synchronize()
    print("流水线完成: a -> b -> c -> d")
    print(f"  a: {a.shape}, b: {b.shape}, c: {c.shape}, d: {d.shape}")
    print()


def example_event_vs_stream_sync():
    """Compare GPU-side event timing with CPU wall-clock timing of one matmul."""
    print("=" * 60)
    print("7. Event 同步 vs Stream 同步对比")
    print("=" * 60)

    N = 4096
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")

    # Untimed warm-up. Without it the first (event-timed) run also pays
    # one-time cuBLAS initialization, skewing the difference below.
    _ = a @ b

    # Method 1: GPU-side event timing.
    torch.cuda.synchronize()
    e_start = torch.cuda.Event(enable_timing=True)
    e_end = torch.cuda.Event(enable_timing=True)
    e_start.record()
    _ = a @ b
    e_end.record()
    e_end.synchronize()
    event_time = e_start.elapsed_time(e_end)

    # Method 2: CPU wall clock; also includes launch, Python and sync overhead.
    torch.cuda.synchronize()
    cpu_start = time.perf_counter()
    _ = a @ b
    torch.cuda.synchronize()
    cpu_time = (time.perf_counter() - cpu_start) * 1000

    print(f"Event 计时 (GPU 端): {event_time:.3f} ms")
    print(f"CPU  计时 (含 launch): {cpu_time:.3f} ms")
    print(f"差异 (launch overhead): {cpu_time - event_time:.3f} ms")
    print()


def example_event_record_without_stream():
    """Show which stream Event.record() uses when no stream is passed.

    With no argument, record() uses the current stream: the default stream
    when called outside any stream context, or ``s`` inside
    ``with torch.cuda.stream(s)``.
    """
    print("=" * 60)
    print("8. Event record 默认行为")
    print("=" * 60)

    s = torch.cuda.Stream()
    a = torch.randn(1000, 1000, device="cuda")

    # Assuming the caller is not inside a stream context, the current stream
    # here is the default stream.
    e1 = torch.cuda.Event()
    e1.record()
    _ = a @ a

    # `a` came from the default stream; order `s` after it before reading `a`.
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        e2 = torch.cuda.Event()
        e2.record()  # no argument: the current stream, which is `s` here
        _ = a @ a

    e1.synchronize()
    e2.synchronize()
    print("e1 记录在: default stream")
    print(f"e2 记录在: {s}")
    print()


if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA 不可用,请在有 GPU 的环境中运行")
        # exit() is a site-module helper for the interactive shell and may be
        # absent (e.g. `python -S`); SystemExit always works.
        raise SystemExit(1)

    print(f"设备: {torch.cuda.get_device_name()}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print()

    example_event_basic()
    example_event_stream_sync()
    example_event_timing()
    example_event_blocking()
    example_event_interprocess()
    example_multiple_events_pipeline()
    example_event_vs_stream_sync()
    example_event_record_without_stream()

stream_basics

复制代码
"""
torch.cuda.Stream 基础用法示例

CUDA Stream 表示一个独立的命令队列,同一 stream 内的操作按序执行,
不同 stream 之间的操作可以并行执行。

主要 API:
- torch.cuda.Stream(device, priority)  创建 stream (priority 数值越小优先级越高, 默认 0)
- stream.wait_stream(other_stream)     让该 stream 等待另一个 stream 中已提交的操作 (不阻塞 CPU)
- torch.cuda.current_stream()          获取当前选中的 stream (未切换时即默认 stream)
- torch.cuda.default_stream()          获取默认 stream
- stream.synchronize()                 阻塞 CPU 直到 stream 中所有操作完成
- torch.cuda.synchronize()             阻塞 CPU 直到当前设备上所有 stream 的操作完成
"""

import torch
import time


def example_create_and_use_stream():
    """Create a stream and run a matmul + relu on it."""
    print("=" * 60)
    print("1. 创建 Stream 并执行操作")
    print("=" * 60)

    s = torch.cuda.Stream()
    a = torch.randn(1000, 1000, device="cuda")
    b = torch.randn(1000, 1000, device="cuda")

    # a/b were produced on the default stream, and a new stream is not ordered
    # after it: without this wait, `a @ b` may start before randn finishes.
    s.wait_stream(torch.cuda.current_stream())

    # Run the matmul and relu on stream s.
    with torch.cuda.stream(s):
        c = a @ b
        d = c.relu()

    # Block the CPU until s has finished.
    s.synchronize()
    print(f"结果 shape: {d.shape}, 设备: {d.device}")
    print()


def example_priority_stream():
    """Create streams with different priorities.

    CUDA stream priorities follow "lower number = higher priority"; 0 is the
    default. ``torch.cuda.Stream.priority_range()`` reports the supported range.
    """
    print("=" * 60)
    print("2. 优先级 Stream")
    print("=" * 60)

    # 0 is the default priority; a negative value is a higher priority.
    low_priority = torch.cuda.Stream(priority=0)
    high_priority = torch.cuda.Stream(priority=-1)

    print(f"低优先级 stream: {low_priority}")
    print(f"高优先级 stream: {high_priority}")
    print(f"优先级范围: {torch.cuda.Stream.priority_range()}")
    print()


def example_stream_wait_stream():
    """Order one stream after another with Stream.wait_stream."""
    print("=" * 60)
    print("3. Stream 间同步 (wait_stream)")
    print("=" * 60)

    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()

    a = torch.randn(2000, 2000, device="cuda")

    # `a` came from the default stream; s1 must wait for it before reading.
    s1.wait_stream(torch.cuda.current_stream())

    # Compute `b` on s1.
    with torch.cuda.stream(s1):
        b = a @ a

    # s2 waits for everything already enqueued on s1, then reads `b`.
    with torch.cuda.stream(s2):
        s2.wait_stream(s1)
        c = b @ b

    s2.synchronize()
    print(f"b shape: {b.shape}, c shape: {c.shape}")
    print()


def example_multiple_streams_parallel():
    """Launch two matmuls on two streams and compare with the default stream.

    These are single wall-clock measurements, so they are noisy. A large matmul
    may already saturate the GPU, in which case little overlap is visible.
    """
    print("=" * 60)
    print("4. 多 Stream 并行执行")
    print("=" * 60)

    N = 4096
    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()

    a1 = torch.randn(N, N, device="cuda")
    a2 = torch.randn(N, N, device="cuda")

    # Real warm-up. The side streams first wait for the default stream that
    # produced a1/a2. Then each stream (and the default stream, used by the
    # serial baseline) runs one untimed matmul to absorb one-time setup.
    cur = torch.cuda.current_stream()
    s1.wait_stream(cur)
    s2.wait_stream(cur)
    with torch.cuda.stream(s1):
        _ = a1 @ a1
    with torch.cuda.stream(s2):
        _ = a2 @ a2
    _ = a1 @ a1
    torch.cuda.synchronize()

    # Launch on both streams back to back.
    start = time.perf_counter()
    with torch.cuda.stream(s1):
        r1 = a1 @ a1
    with torch.cuda.stream(s2):
        r2 = a2 @ a2

    # Wait for both streams.
    s1.synchronize()
    s2.synchronize()
    parallel_time = time.perf_counter() - start

    # Baseline: both matmuls serially on the default stream.
    torch.cuda.synchronize()
    start = time.perf_counter()
    r1 = a1 @ a1
    r2 = a2 @ a2
    torch.cuda.synchronize()
    serial_time = time.perf_counter() - start

    print(f"并行执行时间: {parallel_time * 1000:.2f} ms")
    print(f"串行执行时间: {serial_time * 1000:.2f} ms")
    print()


def example_current_and_default_stream():
    """Show that torch.cuda.stream(...) switches the current stream and restores it."""
    print("=" * 60)
    print("5. current_stream / default_stream")
    print("=" * 60)

    side = torch.cuda.Stream()

    default_s = torch.cuda.default_stream()
    print(f"默认 stream: {default_s}")

    # Inside the context manager, the current stream is `side`.
    with torch.cuda.stream(side):
        inside = torch.cuda.current_stream()
        print(f"stream context 内的 current_stream: {inside}")
        print(f"与 s 相同: {inside == side}")

    # Leaving the context restores the previously current stream.
    after = torch.cuda.current_stream()
    print(f"退出 context 后的 current_stream: {after}")
    print(f"与默认 stream 相同: {after == default_s}")
    print()


def example_stream_query():
    """Poll Stream.query() (non-blocking) until the stream's work is done."""
    print("=" * 60)
    print("6. Stream query (非阻塞检查)")
    print("=" * 60)

    s = torch.cuda.Stream()
    a = torch.randn(2000, 2000, device="cuda")

    # `a` came from the default stream; order `s` after it before reading `a`.
    s.wait_stream(torch.cuda.current_stream())

    with torch.cuda.stream(s):
        _ = a @ a

    # query() returns True once all work queued on `s` has completed and
    # False while anything is still running; it never blocks.
    while not s.query():
        print("  Stream 仍在运行...")
        time.sleep(0.01)

    print("  Stream 已完成!")
    print()


def example_stream_synchronize_device():
    """Contrast Stream.synchronize() with device-wide torch.cuda.synchronize()."""
    print("=" * 60)
    print("7. 设备级同步 vs Stream 级同步")
    print("=" * 60)

    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()
    a = torch.randn(1000, 1000, device="cuda")

    # `a` came from the default stream; order both side streams after it.
    cur = torch.cuda.current_stream()
    s1.wait_stream(cur)
    s2.wait_stream(cur)

    with torch.cuda.stream(s1):
        _ = a @ a
    with torch.cuda.stream(s2):
        _ = a @ a

    # Wait for s1 only.
    s1.synchronize()
    print("s1 已同步 (s2 可能还在运行)")

    # Wait for every stream on the current device (not on all devices).
    torch.cuda.synchronize()
    print("设备级同步完成 (所有 stream 都已完成)")
    print()


def example_external_stream():
    """Wrap a raw cudaStream_t pointer with torch.cuda.ExternalStream.

    Use case: a C++ extension or third-party library (TensorRT, a cuBLAS
    handle, ...) hands you a raw cudaStream_t and you want PyTorch work to run
    on that same stream. ExternalStream does not own the stream; the
    underlying stream must stay alive while the wrapper is in use.
    """
    print("=" * 60)
    print("8. 外部 Stream (ExternalStream)")
    print("=" * 60)

    # Simulate a raw cudaStream_t pointer coming from external C code. In real
    # code it would come from cublasGetStream / cudaStreamCreate etc. `s` must
    # outlive `ext_s` because ExternalStream does not manage its lifetime.
    s = torch.cuda.Stream()
    raw_ptr = s.cuda_stream
    print(f"外部给的 cudaStream_t 指针: {raw_ptr}")

    # Wrap the pointer as a PyTorch Stream.
    ext_s = torch.cuda.ExternalStream(raw_ptr)
    print(f"包装后的 ExternalStream: {ext_s}")

    # Run PyTorch ops on the external stream. `a` is produced on the default
    # stream, so order ext_s after it first.
    a = torch.randn(500, 500, device="cuda")
    ext_s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(ext_s):
        b = a @ a
    ext_s.synchronize()
    print(f"在外部 stream 上完成计算, b shape: {b.shape}")
    print()

    # Real-world sketch:
    #   handle = cublasCreate()
    #   cublasGetStream(handle, &raw_stream)  # the stream cuBLAS uses
    #   s = torch.cuda.ExternalStream(raw_stream)
    #   with torch.cuda.stream(s):
    #       ... PyTorch and cuBLAS work share the same stream ...
    print("典型用途: 让 PyTorch 和 cuBLAS/TensorRT 等外部库共享同一个 stream")
    print()


if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA 不可用,请在有 GPU 的环境中运行")
        # exit() is a site-module helper for the interactive shell and may be
        # absent (e.g. `python -S`); SystemExit always works.
        raise SystemExit(1)

    print(f"设备: {torch.cuda.get_device_name()}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print()

    example_create_and_use_stream()
    example_priority_stream()
    example_stream_wait_stream()
    example_multiple_streams_parallel()
    example_current_and_default_stream()
    example_stream_query()
    example_stream_synchronize_device()
    example_external_stream()

stream_event_combined

复制代码
"""
Stream + Event 联合使用: 多流并行与流水线

展示如何结合 Stream 和 Event 实现:
1. 计算与数据传输重叠 (copy-compute overlap)
2. 多 stream 流水线 (pipeline parallelism)
3. 生产者-消费者模式
4. 精确的多流性能测量
"""

import torch
import time


def example_copy_compute_overlap():
    """Overlap H2D copies, compute and D2H copies using chunks and three streams.

    Serial baseline (default stream): copy ``a`` and ``b`` to the GPU,
    multiply, copy the result back. Overlapped version: split ``a`` into row
    chunks and run each chunk's H2D copy, matmul and D2H copy on separate
    streams chained by Events, so one chunk's transfers can overlap another
    chunk's compute.
    """
    print("=" * 60)
    print("1. 计算与数据传输重叠")
    print("=" * 60)

    N = 4096
    num_chunks = 4
    chunk_size = N // num_chunks

    # non_blocking copies only run asynchronously (and can overlap other work)
    # from page-locked host memory; from pageable memory they are staged and
    # effectively synchronous. Pin once, outside the timed regions.
    cpu_a = torch.randn(N, N).pin_memory()
    cpu_b = torch.randn(N, N).pin_memory()
    cpu_c = torch.empty(N, N).pin_memory()

    h2d_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    d2h_stream = torch.cuda.Stream()

    # Warm up with the chunk shape on the default stream (serial path) and on
    # compute_stream (overlapped path) so one-time setup is not timed.
    warm_a = torch.zeros(chunk_size, N, device="cuda")
    warm_b = torch.zeros(N, N, device="cuda")
    _ = warm_a @ warm_b
    compute_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(compute_stream):
        _ = warm_a @ warm_b
    torch.cuda.synchronize()
    del warm_a, warm_b

    # === Serial: default stream only ===
    torch.cuda.synchronize()
    start = time.perf_counter()

    gpu_a = cpu_a.to("cuda", non_blocking=False)
    gpu_b = cpu_b.to("cuda", non_blocking=False)
    serial_c = (gpu_a @ gpu_b).cpu()

    torch.cuda.synchronize()
    serial_time = time.perf_counter() - start
    del gpu_a, gpu_b

    # === Overlapped: H2D / compute / D2H on separate streams ===
    torch.cuda.synchronize()
    start = time.perf_counter()

    # `b` is needed by every chunk: copy it once on the H2D stream. Chunk 0's
    # copy_done event is recorded after this copy, so compute waits for it.
    with torch.cuda.stream(h2d_stream):
        gpu_b = cpu_b.to("cuda", non_blocking=True)
    # gpu_b is allocated on h2d_stream but read on compute_stream.
    gpu_b.record_stream(compute_stream)

    for i in range(num_chunks):
        rows = slice(i * chunk_size, (i + 1) * chunk_size)

        # H2D: a row slice of a contiguous pinned tensor is itself contiguous
        # and pinned, so it can be copied asynchronously as-is.
        with torch.cuda.stream(h2d_stream):
            gpu_chunk_a = cpu_a[rows].to("cuda", non_blocking=True)
        copy_done = torch.cuda.Event()
        copy_done.record(h2d_stream)

        # Compute: wait only for this chunk's copy.
        with torch.cuda.stream(compute_stream):
            copy_done.wait(compute_stream)
            # The chunk was allocated on h2d_stream. Without record_stream the
            # caching allocator could hand its memory to the next H2D copy
            # while this matmul is still reading it.
            gpu_chunk_a.record_stream(compute_stream)
            # (chunk_size, N) @ (N, N) = (chunk_size, N)
            gpu_chunk_c = gpu_chunk_a @ gpu_b
        compute_done = torch.cuda.Event()
        compute_done.record(compute_stream)

        # D2H: copy this chunk's result into the pinned output buffer.
        with torch.cuda.stream(d2h_stream):
            compute_done.wait(d2h_stream)
            gpu_chunk_c.record_stream(d2h_stream)
            cpu_c[rows].copy_(gpu_chunk_c, non_blocking=True)

    # The last D2H copy transitively waited on all earlier copies and compute.
    d2h_stream.synchronize()
    parallel_time = time.perf_counter() - start

    max_err = (cpu_c - serial_c).abs().max().item()
    print(f"串行时间: {serial_time * 1000:.2f} ms")
    print(f"并行时间: {parallel_time * 1000:.2f} ms")
    print(f"加速比: {serial_time / parallel_time:.2f}x")
    print(f"分块结果与串行结果最大误差: {max_err:.3e}")
    print()


def example_pipeline_parallelism():
    """Multi-stage pipeline: one stream per stage, Events between stages.

    Microbatch ``i`` starts stage ``j`` only after finishing stage ``j - 1``
    (``events[i][j - 1]``). Different microbatches can occupy different stages
    at the same time.
    """
    print("=" * 60)
    print("2. 多阶段流水线 (Pipeline Parallelism)")
    print("=" * 60)

    N = 2048
    num_stages = 3
    num_microbatches = 5
    steps_per_stage = 10

    # One stream per stage.
    streams = [torch.cuda.Stream() for _ in range(num_stages)]

    # Original inputs stay untouched so the serial baseline uses the same data.
    inputs = [torch.randn(N, N, device="cuda") for _ in range(num_microbatches)]

    def stage_fn(x):
        # Simulated stage compute: a few scaled matmuls.
        for _ in range(steps_per_stage):
            x = x @ x * 0.01
        return x

    # Warm up each stage stream (after ordering it behind the default stream
    # that produced `inputs`) and the default stream used by the baseline.
    cur = torch.cuda.current_stream()
    for s in streams:
        s.wait_stream(cur)
        with torch.cuda.stream(s):
            _ = inputs[0] @ inputs[0]
    _ = inputs[0] @ inputs[0]
    torch.cuda.synchronize()

    # events[i][j]: microbatch i finished stage j.
    events = [[torch.cuda.Event() for _ in range(num_stages)]
              for _ in range(num_microbatches)]

    torch.cuda.synchronize()
    start = time.perf_counter()

    acts = list(inputs)
    for j, stream in enumerate(streams):
        with torch.cuda.stream(stream):
            for i in range(num_microbatches):
                if j > 0:
                    # Stage j-1 was fully enqueued (and its events recorded) in
                    # the previous outer iteration, so record precedes wait.
                    events[i][j - 1].wait(stream)
                    # acts[i] was allocated on stream j-1 and is read here;
                    # stop the allocator from recycling it too early.
                    acts[i].record_stream(stream)
                acts[i] = stage_fn(acts[i])
                events[i][j].record(stream)

    for s in streams:
        s.synchronize()

    pipeline_time = time.perf_counter() - start

    # Serial baseline: the same total work on the default stream.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for x in inputs:
        for _ in range(num_stages):
            x = stage_fn(x)
    torch.cuda.synchronize()
    serial_time = time.perf_counter() - start

    print(f"流水线时间: {pipeline_time * 1000:.2f} ms")
    print(f"串行时间:   {serial_time * 1000:.2f} ms")
    print()


def example_producer_consumer():
    """Producer/consumer: one stream produces buffers, another consumes them."""
    print("=" * 60)
    print("3. 生产者-消费者模式")
    print("=" * 60)

    N = 2048
    num_items = 4

    producer_stream = torch.cuda.Stream()
    consumer_stream = torch.cuda.Stream()

    produced_events = []
    data_buffers = []

    a = torch.randn(N, N, device="cuda")

    # `a` came from the default stream; order the producer after it.
    producer_stream.wait_stream(torch.cuda.current_stream())

    # Producer: each item depends on the previous one (same stream = in order).
    for i in range(num_items):
        with torch.cuda.stream(producer_stream):
            data = a @ a if i == 0 else data_buffers[-1] @ a
            data_buffers.append(data)

        # Mark that this item is ready.
        ready = torch.cuda.Event()
        ready.record(producer_stream)
        produced_events.append(ready)

    # Consumer: wait for each item's event, then consume it.
    results = []
    for buf, ready in zip(data_buffers, produced_events):
        with torch.cuda.stream(consumer_stream):
            ready.wait(consumer_stream)
            # buf was allocated on producer_stream but is read here.
            buf.record_stream(consumer_stream)
            results.append(buf.relu())

    consumer_stream.synchronize()
    print(f"生产了 {num_items} 个数据, 消费了 {len(results)} 个结果")
    print(f"每个结果 shape: {results[0].shape}")
    print()


def example_precise_multistream_timing():
    """Time concurrent work on several streams with per-stream start/end Events."""
    print("=" * 60)
    print("4. 多 Stream 精确计时")
    print("=" * 60)

    N = 4096
    num_streams = 3
    streams = [torch.cuda.Stream() for _ in range(num_streams)]

    # Start/end events per stream.
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_streams)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_streams)]

    data = [torch.randn(N, N, device="cuda") for _ in range(num_streams)]

    # `data` lives on the default stream: order each side stream after it and
    # run one untimed matmul per stream so lazy per-stream setup is excluded.
    cur = torch.cuda.current_stream()
    for s, x in zip(streams, data):
        s.wait_stream(cur)
        with torch.cuda.stream(s):
            _ = x @ x
    torch.cuda.synchronize()

    # Launch on every stream, bracketing the work with events.
    for s, x, t0, t1 in zip(streams, data, start_events, end_events):
        with torch.cuda.stream(s):
            t0.record(s)
            _ = x @ x
            t1.record(s)

    # Per-stream durations.
    for i, (t0, t1) in enumerate(zip(start_events, end_events)):
        t1.synchronize()
        print(f"  Stream {i}: {t0.elapsed_time(t1):.3f} ms")

    # Concurrent streams may finish in any order, so the overall span is
    # (latest end) - (earliest start), both measured from start_events[0].
    ref = start_events[0]
    earliest_start = min(ref.elapsed_time(e) for e in start_events)
    latest_end = max(ref.elapsed_time(e) for e in end_events)
    total = latest_end - earliest_start
    print(f"  总耗时 (overlap): {total:.3f} ms")
    print()


def example_data_dependency_graph():
    """Express a small dependency DAG (A, B -> C -> D) with streams and Events."""
    print("=" * 60)
    print("5. 复杂数据依赖图")
    print("=" * 60)

    # Scenario:
    #   A_stream: compute A
    #   B_stream: compute B
    #   C_stream: compute C = A + B (depends on A and B)
    #   D_stream: compute D = C * 2 (depends on C)

    A_stream = torch.cuda.Stream()
    B_stream = torch.cuda.Stream()
    C_stream = torch.cuda.Stream()
    D_stream = torch.cuda.Stream()

    N = 2048
    x = torch.randn(N, N, device="cuda")
    y = torch.randn(N, N, device="cuda")

    # x/y came from the default stream; the root streams must wait for it.
    # C and D are ordered transitively through the events below.
    cur = torch.cuda.current_stream()
    A_stream.wait_stream(cur)
    B_stream.wait_stream(cur)

    # A and B can run in parallel.
    with torch.cuda.stream(A_stream):
        A = x @ x
    A_done = torch.cuda.Event()
    A_done.record(A_stream)

    with torch.cuda.stream(B_stream):
        B = y @ y
    B_done = torch.cuda.Event()
    B_done.record(B_stream)

    # C depends on both A and B.
    with torch.cuda.stream(C_stream):
        A_done.wait(C_stream)
        B_done.wait(C_stream)
        C = A + B
    C_done = torch.cuda.Event()
    C_done.record(C_stream)

    # D depends on C.
    with torch.cuda.stream(D_stream):
        C_done.wait(D_stream)
        D = C * 2

    D_stream.synchronize()
    print(f"A: {A.shape}, B: {B.shape}, C: {C.shape}, D: {D.shape}")
    print(f"D 前几个元素: {D[0, :5]}")
    print()


def example_ring_allreduce_simulation():
    """Simulate the stream/event pattern of one ring-AllReduce step.

    Ring AllReduce background:
    - N ranks form a ring: 0→1→2→3→0.
    - In each step every rank sends a buffer to the next rank and receives one
      from the previous rank, reducing it into its own data.
    - A real implementation splits the data into N chunks and runs N-1
      reduce-scatter steps followed by N-1 all-gather steps; this demo shows a
      single step on whole buffers.

    Streams/events used here:
    - send_streams[r]: copies rank r's data into the next rank's recv buffer.
    - recv_streams[r]: reduces the received buffer into rank r's data.
    - send_done[r]: marks that rank r's send has finished.
    """
    print("=" * 60)
    print("6. 模拟 Ring AllReduce 同步模式")
    print("=" * 60)

    num_ranks = 4
    N = 512

    # Each rank's own data.
    data = [torch.ones(N, N, device="cuda") * (i + 1) for i in range(num_ranks)]
    # recv_buf[i]: buffer in which rank i receives data from rank i-1.
    recv_buf = [torch.empty(N, N, device="cuda") for _ in range(num_ranks)]

    send_streams = [torch.cuda.Stream() for _ in range(num_ranks)]
    recv_streams = [torch.cuda.Stream() for _ in range(num_ranks)]

    send_done = [torch.cuda.Event() for _ in range(num_ranks)]
    recv_done = [torch.cuda.Event() for _ in range(num_ranks)]

    # data/recv_buf were produced on the default stream.
    cur = torch.cuda.current_stream()

    # Phase 1: enqueue every send and record its event.
    # An Event must be recorded BEFORE anyone waits on it: waiting on an event
    # that has never been recorded does not wait at all. Issuing sends and
    # recvs in one loop would let rank 0 wait on the not-yet-recorded
    # send_done[num_ranks - 1] and read an uninitialized recv_buf[0].
    for rank in range(num_ranks):
        next_rank = (rank + 1) % num_ranks
        s = send_streams[rank]
        s.wait_stream(cur)
        with torch.cuda.stream(s):
            recv_buf[next_rank].copy_(data[rank])
        send_done[rank].record(s)

    # Phase 2: each rank's recv waits for
    #   - the previous rank's send (recv_buf[rank] has been filled), and
    #   - its own send (data[rank] is about to be modified in place, so the
    #     send must have finished reading it).
    for rank in range(num_ranks):
        prev_rank = (rank - 1) % num_ranks
        s = recv_streams[rank]
        with torch.cuda.stream(s):
            send_done[prev_rank].wait(s)
            send_done[rank].wait(s)
            data[rank].add_(recv_buf[rank])
        recv_done[rank].record(s)

    # Wait for everything.
    for s in send_streams + recv_streams:
        s.synchronize()
    print()
    print("=== recv + reduce 完成 (每个 rank 累加了 prev rank 的数据) ===")
    for i in range(num_ranks):
        prev_i = (i - 1) % num_ranks
        expected = (i + 1) + (prev_i + 1)
        print(f"  rank {i}: data[0,0] = {data[i][0, 0].item():.0f} (期望: 自己的{i+1} + prev的{prev_i+1} = {expected})")

    print()
    print("关键点:")
    print("  - send 和 recv 在不同 stream 上,可以并行执行")
    print("  - 用 event 确保 recv 在对应 send 完成后才读取数据 (必须先 record 再 wait)")
    print("  - 真实 Ring AllReduce 分块进行 N-1 步 reduce-scatter + N-1 步 all-gather,这里只展示一步")
    print()


if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA 不可用,请在有 GPU 的环境中运行")
        # exit() is a site-module helper for the interactive shell and may be
        # absent (e.g. `python -S`); SystemExit always works.
        raise SystemExit(1)

    print(f"设备: {torch.cuda.get_device_name()}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print()

    example_copy_compute_overlap()
    example_pipeline_parallelism()
    example_producer_consumer()
    example_precise_multistream_timing()
    example_data_dependency_graph()
    example_ring_allreduce_simulation()

event.record() 到底是什么意思

event.record(s1) 的意思是:在 stream s1 的命令队列中插入一个标记点(marker)

具体来说:

  1. 它不阻塞任何东西 --- record() 调用本身是异步的,只是往 stream 里塞了一个"路标",CPU 立即返回。
  2. GPU 按序执行 --- 当 stream s1 中 record() 之前的所有操作都在 GPU 上执行完毕后,这个 event 就被标记为"已完成"。
  3. 之后用它做同步 --- 你可以用这个 event 来:
    • event.wait(s2) --- 让另一个 stream s2 停在原地,直到 s1 到达这个标记点 (只阻塞 GPU 上的 s2,不阻塞 CPU)

    • event.synchronize() --- 阻塞 CPU 线程,直到 s1 到达这个标记点

    • event.query() --- 非阻塞检查 GPU 是否已经执行到这个标记点

```python
# Stream s1 上做计算
with torch.cuda.stream(s1):
    b = a @ a

# 在 s1 上记录 event
event = torch.cuda.Event()
event.record(s1)

# Stream s2 等待 event (即等待 s1 完成)
with torch.cuda.stream(s2):
    event.wait(s2)  # s2 阻塞直到 s1 到达 event 点
    c = b @ b       # 此时 b 已就绪
```

这里的执行顺序是:

复制代码
时间线 →

s1:  [ a @ a ] ──→ [ event 标记 ] 
                        │
s2:                     └──→ event.wait(s2) 阻塞在此 ──→ [ b @ b ]
                                   ↑
                          等 s1 跑到 event 点后才放行

所以 event.record(s1) 本质上就是一个"栅栏标记"---在 stream 里画一条线,之后你可以让其他 stream 或 CPU 等这条线。注意顺序: 必须先 record 再 wait --- 对一个从未 record 过的 event 调用 wait() 不会等待任何东西,而是直接放行。