cuda_graph_examples
"""
CUDA Graph 中的 Stream 和 Event 用法
CUDA Graph 将一系列 kernel 操作录制为图,然后可以重放,大幅减少 launch overhead。
Stream 和 Event 在 CUDA Graph 中的关键作用:
1. graph 必须在非默认 stream 上捕获 (capture)
2. graph 在调用 replay() 时的当前 stream 上重放 (replay),不要求与捕获时是同一个 stream
3. Event 用于 graph 与外部 stream 之间的同步
主要 API:
- torch.cuda.CUDAGraph() 创建 graph
- graph.capture_begin(pool=None) 开始捕获 (pool 可选: 传入 graph_pool_handle() 返回的 handle 可与其他 graph 共享内存池)
- graph.capture_end() 结束捕获
- graph.replay() 重放 graph
- graph.reset() 重置 graph
- torch.cuda.graph_pool_handle() 获取当前设备的 mempool
- torch.cuda.graph(graph, pool=None, stream=None) 捕获用的上下文管理器 (可指定捕获所用的 stream)
"""
import torch
import time
def example_basic_cuda_graph():
    """CUDA Graph basics: capture a matmul once, then replay it.

    Times 100 graph replays against 100 eager launches of the same matmul
    and prints the average latency of each plus the speed-up.
    """
    print("=" * 60)
    print("1. CUDA Graph 基础: 捕获与重放")
    print("=" * 60)
    N = 4096
    iters = 100
    # Static inputs: a graph replays against fixed device addresses, so the
    # same tensors must be reused (and updated in place) for every replay.
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")
    g = torch.cuda.CUDAGraph()
    # Capture must happen on a non-default stream. Warm up on that same
    # stream so lazy one-time setup happens outside the capture, and order
    # the stream after the current stream that initialised a/b.
    capture_stream = torch.cuda.Stream()
    capture_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(capture_stream):
        c = a @ b
    capture_stream.synchronize()
    # === Capture ===
    # `c` is rebound to the graph's static output; every replay overwrites it.
    with torch.cuda.stream(capture_stream):
        g.capture_begin()
        c = a @ b
        g.capture_end()
    # === Replay ===
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        g.replay()
    torch.cuda.synchronize()
    graph_time = (time.perf_counter() - start) / iters * 1000
    # Baseline: launch the same matmul eagerly. A separate name keeps the
    # graph output `c` referenced.
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        eager_c = a @ b
    torch.cuda.synchronize()
    eager_time = (time.perf_counter() - start) / iters * 1000
    print(f"Graph 重放平均耗时: {graph_time:.4f} ms")
    print(f"Eager 执行平均耗时: {eager_time:.4f} ms")
    print(f"加速比: {eager_time / graph_time:.2f}x")
    print()
def example_graph_on_custom_stream():
    """Warm up, capture and replay a CUDA graph, all on one custom stream."""
    print("=" * 60)
    print("2. 自定义 Stream 上的 CUDA Graph")
    print("=" * 60)
    N = 2048
    s = torch.cuda.Stream()
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")
    g = torch.cuda.CUDAGraph()
    # a/b are initialised on the current (default) stream; s must not read
    # them before those kernels have finished.
    s.wait_stream(torch.cuda.current_stream())
    # Warm-up on the stream used for capture.
    with torch.cuda.stream(s):
        c = a @ b
    s.synchronize()
    # Capture on s; `c` becomes the graph's static output tensor.
    with torch.cuda.stream(s):
        g.capture_begin()
        c = a @ b
        g.capture_end()
    # Replay is enqueued on the current stream, here s.
    with torch.cuda.stream(s):
        g.replay()
    s.synchronize()
    print(f"Graph 在自定义 stream 上重放成功")
    print(f"c shape: {c.shape}, device: {c.device}")
    print()
def example_graph_with_event_sync():
    """Use an Event to run follow-up work on another stream after a replay."""
    print("=" * 60)
    print("3. Graph + Event 同步")
    print("=" * 60)
    N = 2048
    graph_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")
    g = torch.cuda.CUDAGraph()
    # graph_stream reads a/b, which were initialised on the current stream.
    graph_stream.wait_stream(torch.cuda.current_stream())
    # Warm-up on the capture stream.
    with torch.cuda.stream(graph_stream):
        c = a @ b
    graph_stream.synchronize()
    # Capture on graph_stream; `c` is the graph's static output.
    with torch.cuda.stream(graph_stream):
        g.capture_begin()
        c = a @ b
        g.capture_end()
    # Scenario: after the replay, follow-up work runs on compute_stream.
    with torch.cuda.stream(graph_stream):
        g.replay()
    # Marker placed right after the replay in graph_stream's queue.
    graph_done = torch.cuda.Event()
    graph_done.record(graph_stream)
    with torch.cuda.stream(compute_stream):
        # Work queued on compute_stream from here on waits for the replay.
        graph_done.wait(compute_stream)
        d = c.relu()  # needs c to be fully computed
    compute_stream.synchronize()
    print(f"Graph 重放 + Event 同步完成")
    print(f"c shape: {c.shape}, d shape: {d.shape}")
    print(f"d 前 5 个元素: {d[0, :5]}")
    print()
def example_multiple_graphs_parallel():
    """Replay two independent graphs concurrently on two streams.

    Compares one concurrent replay (g1 on s1, g2 on s2) against replaying
    both graphs back to back on the current stream.
    """
    print("=" * 60)
    print("4. 多 Graph 并行重放")
    print("=" * 60)
    N = 2048
    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()
    # Two independent sets of static inputs.
    a1 = torch.randn(N, N, device="cuda")
    b1 = torch.randn(N, N, device="cuda")
    a2 = torch.randn(N, N, device="cuda")
    b2 = torch.randn(N, N, device="cuda")
    g1 = torch.cuda.CUDAGraph()
    g2 = torch.cuda.CUDAGraph()
    # Both side streams read tensors initialised on the current stream.
    current = torch.cuda.current_stream()
    s1.wait_stream(current)
    s2.wait_stream(current)
    # Warm-up, each on the stream it will be captured on.
    with torch.cuda.stream(s1):
        c1 = a1 @ b1
    with torch.cuda.stream(s2):
        c2 = a2 @ b2
    s1.synchronize()
    s2.synchronize()
    # Capture each graph on its own stream.
    with torch.cuda.stream(s1):
        g1.capture_begin()
        c1 = a1 @ b1
        g1.capture_end()
    with torch.cuda.stream(s2):
        g2.capture_begin()
        c2 = a2 @ b2
        g2.capture_end()
    # One untimed replay of each graph. The first launch of a graph can carry
    # one-time cost (e.g. uploading it to the device). Without this replay,
    # that cost would land only in the parallel measurement, which runs first.
    g1.replay()
    g2.replay()
    # Parallel replay.
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.cuda.stream(s1):
        g1.replay()
    with torch.cuda.stream(s2):
        g2.replay()
    s1.synchronize()
    s2.synchronize()
    parallel_time = time.perf_counter() - start
    # Serial replay for comparison.
    torch.cuda.synchronize()
    start = time.perf_counter()
    g1.replay()
    g2.replay()
    torch.cuda.synchronize()
    serial_time = time.perf_counter() - start
    print(f"并行重放时间: {parallel_time * 1000:.3f} ms")
    print(f"串行重放时间: {serial_time * 1000:.3f} ms")
    print()
def example_graph_mempool():
    """Capture a graph into an explicitly obtained graph memory pool.

    ``torch.cuda.graph_pool_handle()`` returns an opaque pool token. Passing
    it to ``capture_begin(pool=...)`` hints that the graph may share memory
    with that pool (useful when several graphs should share one pool).
    """
    print("=" * 60)
    print("5. Graph Memory Pool")
    print("=" * 60)
    N = 2048
    # Pool token for the current device.
    pool = torch.cuda.graph_pool_handle()
    print(f"Mempool handle: {pool}")
    # Tokens can also be obtained for other devices.
    if torch.cuda.device_count() > 1:
        with torch.cuda.device(1):
            pool1 = torch.cuda.graph_pool_handle()
            print(f"Device 1 mempool: {pool1}")
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")
    g = torch.cuda.CUDAGraph()
    # Capture must use a non-default stream. Warm up on that same stream,
    # after the current stream has finished initialising a/b.
    capture_stream = torch.cuda.Stream()
    capture_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(capture_stream):
        c = a @ b
    capture_stream.synchronize()
    # Capture into the explicit pool.
    with torch.cuda.stream(capture_stream):
        g.capture_begin(pool=pool)
        c = a @ b
        g.capture_end()
    g.replay()
    torch.cuda.synchronize()
    print(f"使用 mempool 的 graph 重放成功")
    print()
def example_graph_reset_and_reuse():
    """Reset a CUDAGraph and re-capture a different workload into it."""
    print("=" * 60)
    print("6. Graph Reset 和复用")
    print("=" * 60)
    N = 1024
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")
    g = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()
    # capture_stream reads a/b, which were initialised on the current stream.
    capture_stream.wait_stream(torch.cuda.current_stream())
    # Warm up both workloads that will be captured below, on the capture stream.
    with torch.cuda.stream(capture_stream):
        c = a @ b
        _ = (a @ b).relu()
    capture_stream.synchronize()
    # First capture: matmul only (capture requires a non-default stream).
    with torch.cuda.stream(capture_stream):
        g.capture_begin()
        c = a @ b
        g.capture_end()
    g.replay()
    torch.cuda.synchronize()
    print(f"第一次重放 (matmul): c[0,0] = {c[0, 0].item():.4f}")
    # reset() discards the captured graph so the same object can capture again.
    g.reset()
    # Second capture: matmul + relu.
    with torch.cuda.stream(capture_stream):
        g.capture_begin()
        c = (a @ b).relu()
        g.capture_end()
    g.replay()
    torch.cuda.synchronize()
    print(f"第二次重放 (matmul+relu): c[0,0] = {c[0, 0].item():.4f}")
    print()
def example_graph_with_stream_wait():
    """Order a graph replay after in-place input updates on another stream."""
    print("=" * 60)
    print("7. Graph Stream + wait_stream 同步")
    print("=" * 60)
    N = 2048
    graph_stream = torch.cuda.Stream()
    data_stream = torch.cuda.Stream()
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")
    g = torch.cuda.CUDAGraph()
    # graph_stream reads a/b, which were initialised on the current stream.
    graph_stream.wait_stream(torch.cuda.current_stream())
    # Warm-up on the capture stream.
    with torch.cuda.stream(graph_stream):
        c = a @ b
    # graph_stream waited for the initialisation of a/b, so once it has
    # drained, a/b are also safe for data_stream to modify.
    graph_stream.synchronize()
    # Capture.
    with torch.cuda.stream(graph_stream):
        g.capture_begin()
        c = a @ b
        g.capture_end()
    # Preprocessing on data_stream. In-place updates keep the addresses the
    # graph captured, so the next replay of c = a @ b sees the new values.
    with torch.cuda.stream(data_stream):
        a.mul_(2.0)
        b.mul_(0.5)
    # The replay must not start before the preprocessing has finished.
    with torch.cuda.stream(graph_stream):
        graph_stream.wait_stream(data_stream)
        g.replay()
    graph_stream.synchronize()
    print(f"wait_stream + graph 重放完成")
    print(f"c[0,0] = {c[0, 0].item():.4f}")
    print()
def example_graph_debug_capture():
    """Print torch.cuda.is_current_stream_capturing() before/during/after capture."""
    print("=" * 60)
    print("8. 检查 Graph 捕获状态")
    print("=" * 60)
    N = 512
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")
    g = torch.cuda.CUDAGraph()
    capture_stream = torch.cuda.Stream()
    # capture_stream reads a/b, which were initialised on the current stream.
    capture_stream.wait_stream(torch.cuda.current_stream())
    # Warm-up on the capture stream.
    with torch.cuda.stream(capture_stream):
        c = a @ b
    capture_stream.synchronize()
    # Checks whether the *current* stream is being captured.
    print(f"捕获前: is_current_stream_capturing = {torch.cuda.is_current_stream_capturing()}")
    with torch.cuda.stream(capture_stream):
        g.capture_begin()
        print(f"捕获中: is_current_stream_capturing = {torch.cuda.is_current_stream_capturing()}")
        c = a @ b
        g.capture_end()
    print(f"捕获后: is_current_stream_capturing = {torch.cuda.is_current_stream_capturing()}")
    g.replay()
    torch.cuda.synchronize()
    print(f"Graph 重放成功, c[0,0] = {c[0, 0].item():.4f}")
    print()
if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA 不可用,请在有 GPU 的环境中运行")
        # exit() is injected by the `site` module for interactive use;
        # SystemExit is always available.
        raise SystemExit(1)
    print(f"设备: {torch.cuda.get_device_name()}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print()
    example_basic_cuda_graph()
    example_graph_on_custom_stream()
    example_graph_with_event_sync()
    example_multiple_graphs_parallel()
    example_graph_mempool()
    example_graph_reset_and_reuse()
    example_graph_with_stream_wait()
    example_graph_debug_capture()
event_basics
"""
torch.cuda.Event 基础用法示例
CUDA Event 用于:
1. 记录 stream 中某个时间点
2. 测量两个 event 之间的时间间隔 (GPU 端计时)
3. 在不同 stream 之间做精确同步
4. 阻塞 CPU 直到 GPU 到达某个 event 点
主要 API:
- torch.cuda.Event(enable_timing, blocking, interprocess) 创建 event
- event.record(stream) 在指定 stream 中记录 event
- event.wait(stream) 让指定 stream 等待 event 完成
- event.synchronize() 阻塞 CPU 直到 event 完成
- event.query() 非阻塞检查 event 是否完成
- torch.cuda.Event.elapsed_time 两个 event 之间的时间 (ms)
- event.ipc_handle() 获取用于跨进程共享的 IPC handle
"""
import torch
import time
def example_event_basic():
    """Event basics: record / synchronize / query / elapsed_time."""
    print("=" * 60)
    print("1. Event 基础: record / synchronize / query")
    print("=" * 60)
    a = torch.randn(2000, 2000, device="cuda")
    # Untimed warm-up. Event timing spans the GPU timeline between the two
    # records, so one-time host-side setup before the first kernel would
    # otherwise show up in the measured interval.
    _ = a @ a
    torch.cuda.synchronize()
    # enable_timing defaults to False; it must be requested explicitly for
    # elapsed_time() to work.
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    # record() without a stream argument records on the current stream.
    start_event.record()
    _ = a @ a
    end_event.record()
    # Block the CPU until end_event has been reached on the GPU.
    end_event.synchronize()
    # Milliseconds between the two events.
    elapsed = start_event.elapsed_time(end_event)
    print(f"矩阵乘法耗时: {elapsed:.3f} ms")
    # Non-blocking completion check.
    print(f"end_event 已完成: {end_event.query()}")
    print()
def example_event_stream_sync():
    """Use an Event recorded on s1 to make s2 wait for s1's result."""
    print("=" * 60)
    print("2. Event 跨 Stream 同步")
    print("=" * 60)
    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()
    a = torch.randn(2000, 2000, device="cuda")
    # s1 reads `a`, which was initialised on the current stream.
    s1.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s1):
        b = a @ a
    # Marker placed after the matmul in s1's queue.
    event = torch.cuda.Event()
    event.record(s1)
    with torch.cuda.stream(s2):
        # Work queued on s2 after this point waits for the marker; the CPU
        # itself is not blocked.
        event.wait(s2)
        c = b @ b  # b is ready by now
    s2.synchronize()
    print(f"b shape: {b.shape}, c shape: {c.shape}")
    print("Event 跨 stream 同步成功!")
    print()
def example_event_timing():
    """Time one square matmul per size with a pair of timing-enabled events."""
    print("=" * 60)
    print("3. Event 精确计时")
    print("=" * 60)
    for N in (512, 1024, 2048, 4096):
        lhs = torch.randn(N, N, device="cuda")
        rhs = torch.randn(N, N, device="cuda")
        # Untimed warm-up launch, then drain the device before measuring.
        warm = lhs @ rhs
        torch.cuda.synchronize()
        t0, t1 = (torch.cuda.Event(enable_timing=True) for _ in range(2))
        t0.record()
        timed = lhs @ rhs
        t1.record()
        # Host waits until the closing marker has been reached on the GPU.
        t1.synchronize()
        print(f" [{N}x{N}] 矩阵乘法: {t0.elapsed_time(t1):.3f} ms")
    print()
def example_event_blocking():
    """Contrast events created with blocking=True and blocking=False.

    For both kinds, ``synchronize()`` returns only after the event has
    completed. The flag changes *how* the host thread waits. A blocking event
    (CUDA's cudaEventBlockingSync) puts the thread to sleep until completion.
    Otherwise the wait strategy follows the device scheduling flags, which
    may mean spinning on the CPU.
    """
    print("=" * 60)
    print("4. Blocking Event")
    print("=" * 60)
    a = torch.randn(1000, 1000, device="cuda")
    blocking_event = torch.cuda.Event(blocking=True)
    non_blocking_event = torch.cuda.Event(blocking=False)
    print(f"blocking event: {blocking_event}")
    print(f"non-blocking event: {non_blocking_event}")
    _ = a @ a
    # Record each event once, after the matmul. (Recording the same event
    # twice just moves its marker; the earlier record is discarded.)
    non_blocking_event.record()
    blocking_event.record()
    # Host waits using the default strategy...
    non_blocking_event.synchronize()
    # ...or with blocking synchronization (thread sleeps instead of spinning).
    blocking_event.synchronize()
    print("Blocking event 同步完成")
    print()
def example_event_interprocess():
    """Create an interprocess event and export its IPC handle.

    CUDA requires interprocess events to have timing disabled, hence
    ``enable_timing=False``.
    """
    print("=" * 60)
    print("5. Interprocess Event (IPC)")
    print("=" * 60)
    ipc_event = torch.cuda.Event(
        enable_timing=False,
        blocking=True,
        interprocess=True,
    )
    ipc_event.record()
    ipc_event.synchronize()
    # Opaque handle (bytes) that can be sent to another process.
    handle = ipc_event.ipc_handle()
    print(f"IPC handle type: {type(handle)}")
    print(f"IPC handle bytes: {handle.hex()[:40]}...")
    # The receiving process rebuilds the event with
    # torch.cuda.Event.from_ipc_handle(device, handle).
    # torch.cuda.ipc_collect() is unrelated: it reclaims GPU memory of
    # tensors shared through CUDA IPC.
    print("(另一个进程可用 torch.cuda.Event.from_ipc_handle(device, handle) 获取此 event)")
    print()
def example_multiple_events_pipeline():
    """Chain three streams (s1 -> s2 -> s3) with one event per stage boundary."""
    print("=" * 60)
    print("6. 多 Event 流水线同步")
    print("=" * 60)
    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()
    s3 = torch.cuda.Stream()
    a = torch.randn(2000, 2000, device="cuda")
    # Stage boundaries.
    e1_done = torch.cuda.Event()
    e2_done = torch.cuda.Event()
    # s1 reads `a`, which was initialised on the current stream.
    s1.wait_stream(torch.cuda.current_stream())
    # Stage 1 on s1.
    with torch.cuda.stream(s1):
        b = a @ a
        e1_done.record(s1)
    # Stage 2 on s2, after stage 1.
    with torch.cuda.stream(s2):
        e1_done.wait(s2)
        c = b @ b
        e2_done.record(s2)
    # Stage 3 on s3, after stage 2.
    with torch.cuda.stream(s3):
        e2_done.wait(s3)
        d = c.relu()
    s3.synchronize()
    print(f"流水线完成: a -> b -> c -> d")
    print(f" a: {a.shape}, b: {b.shape}, c: {c.shape}, d: {d.shape}")
    print()
def example_event_vs_stream_sync():
    """Compare GPU-side event timing with host wall-clock timing of a matmul."""
    print("=" * 60)
    print("7. Event 同步 vs Stream 同步对比")
    print("=" * 60)
    N = 4096
    a = torch.randn(N, N, device="cuda")
    b = torch.randn(N, N, device="cuda")
    # Untimed warm-up so neither measurement includes one-time setup cost.
    _ = a @ b
    # Method 1: event timing (GPU timeline only).
    torch.cuda.synchronize()
    e_start = torch.cuda.Event(enable_timing=True)
    e_end = torch.cuda.Event(enable_timing=True)
    e_start.record()
    _ = a @ b
    e_end.record()
    e_end.synchronize()
    event_time = e_start.elapsed_time(e_end)
    # Method 2: host wall clock; also includes launch and synchronize overhead.
    torch.cuda.synchronize()
    cpu_start = time.perf_counter()
    _ = a @ b
    torch.cuda.synchronize()
    cpu_time = (time.perf_counter() - cpu_start) * 1000
    print(f"Event 计时 (GPU 端): {event_time:.3f} ms")
    print(f"CPU 计时 (含 launch): {cpu_time:.3f} ms")
    print(f"差异 (launch overhead): {cpu_time - event_time:.3f} ms")
    print()
def example_event_record_without_stream():
    """Show that Event.record() without a stream records on the current stream."""
    print("=" * 60)
    print("8. Event record 默认行为")
    print("=" * 60)
    s = torch.cuda.Stream()
    a = torch.randn(1000, 1000, device="cuda")
    # When called outside any torch.cuda.stream context (as from __main__),
    # the current stream is the default stream.
    default_cur = torch.cuda.current_stream()
    # s reads `a`, which is initialised on default_cur.
    s.wait_stream(default_cur)
    e1 = torch.cuda.Event()
    e1.record()  # no argument -> recorded on default_cur
    _ = a @ a
    with torch.cuda.stream(s):
        e2 = torch.cuda.Event()
        e2.record(s)  # explicit stream
        _ = a @ a
    e1.synchronize()
    e2.synchronize()
    print(f"e1 记录在: default stream")
    print(f"e2 记录在: {s}")
    print()
if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA 不可用,请在有 GPU 的环境中运行")
        # exit() is injected by the `site` module for interactive use;
        # SystemExit is always available.
        raise SystemExit(1)
    print(f"设备: {torch.cuda.get_device_name()}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print()
    example_event_basic()
    example_event_stream_sync()
    example_event_timing()
    example_event_blocking()
    example_event_interprocess()
    example_multiple_events_pipeline()
    example_event_vs_stream_sync()
    example_event_record_without_stream()
stream_basics
"""
torch.cuda.Stream 基础用法示例
CUDA Stream 表示一个独立的命令队列,同一 stream 内的操作按序执行,
不同 stream 之间的操作可以并行执行。
主要 API:
- torch.cuda.Stream(device, priority) 创建 stream
- stream.wait_stream(other_stream) 让当前 stream 等待另一个 stream
- torch.cuda.current_stream() 获取当前选中的 stream (未进入 stream 上下文时即默认 stream)
- torch.cuda.default_stream() 获取默认 stream
- stream.synchronize() 阻塞 CPU 直到 stream 中所有操作完成
- torch.cuda.synchronize() 阻塞 CPU 直到当前设备上所有 stream 中的操作完成
"""
import torch
import time
def example_create_and_use_stream():
    """Create a side stream and run a matmul followed by relu on it."""
    print("=" * 60)
    print("1. 创建 Stream 并执行操作")
    print("=" * 60)
    s = torch.cuda.Stream()
    a = torch.randn(1000, 1000, device="cuda")
    b = torch.randn(1000, 1000, device="cuda")
    # s reads a/b, which were initialised on the current stream.
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        c = a @ b
        d = c.relu()
    # Block the CPU until everything queued on s has finished.
    s.synchronize()
    print(f"结果 shape: {d.shape}, 设备: {d.device}")
    print()
def example_priority_stream():
    """Create streams at the lowest and the highest available priority.

    In CUDA a *lower* number means a *higher* priority.
    ``Stream.priority_range()`` returns ``(least, greatest)`` priority, for
    example ``(0, -1)``, so the high-priority stream takes the more negative
    value.
    """
    print("=" * 60)
    print("2. 优先级 Stream")
    print("=" * 60)
    least_priority, greatest_priority = torch.cuda.Stream.priority_range()
    low_priority = torch.cuda.Stream(priority=least_priority)
    high_priority = torch.cuda.Stream(priority=greatest_priority)
    print(f"低优先级 stream: {low_priority}")
    print(f"高优先级 stream: {high_priority}")
    print(f"优先级范围: {torch.cuda.Stream.priority_range()}")
    print()
def example_stream_wait_stream():
    """Inter-stream ordering with Stream.wait_stream."""
    print("=" * 60)
    print("3. Stream 间同步 (wait_stream)")
    print("=" * 60)
    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()
    a = torch.randn(2000, 2000, device="cuda")
    # s1 reads `a`, which was initialised on the current stream.
    s1.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s1):
        b = a @ a
    with torch.cuda.stream(s2):
        # Work queued on s2 from here waits for everything already queued on s1.
        s2.wait_stream(s1)
        c = b @ b  # needs b to be finished
    s2.synchronize()
    print(f"b shape: {b.shape}, c shape: {c.shape}")
    print()
def example_multiple_streams_parallel():
    """Run two independent matmuls on two streams vs. serially on one stream."""
    print("=" * 60)
    print("4. 多 Stream 并行执行")
    print("=" * 60)
    N = 4096
    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()
    a1 = torch.randn(N, N, device="cuda")
    a2 = torch.randn(N, N, device="cuda")
    # Real warm-up on the current stream and on both side streams. The first
    # matmul on a stream can trigger one-time setup (e.g. cuBLAS workspace
    # allocation) that would otherwise be charged to the first measurement.
    # The side streams also wait for the initialisation of a1/a2.
    current = torch.cuda.current_stream()
    _ = a1 @ a1
    for stream, x in ((s1, a1), (s2, a2)):
        stream.wait_stream(current)
        with torch.cuda.stream(stream):
            _ = x @ x
    torch.cuda.synchronize()
    # Launch on both streams.
    start = time.perf_counter()
    with torch.cuda.stream(s1):
        r1 = a1 @ a1
    with torch.cuda.stream(s2):
        r2 = a2 @ a2
    s1.synchronize()
    s2.synchronize()
    parallel_time = time.perf_counter() - start
    # Baseline: both matmuls on the current stream.
    torch.cuda.synchronize()
    start = time.perf_counter()
    r1 = a1 @ a1
    r2 = a2 @ a2
    torch.cuda.synchronize()
    serial_time = time.perf_counter() - start
    print(f"并行执行时间: {parallel_time * 1000:.2f} ms")
    print(f"串行执行时间: {serial_time * 1000:.2f} ms")
    print()
def example_current_and_default_stream():
    """Show how a stream context switches current_stream() and then restores it."""
    print("=" * 60)
    print("5. current_stream / default_stream")
    print("=" * 60)
    side = torch.cuda.Stream()
    dflt = torch.cuda.default_stream()
    print(f"默认 stream: {dflt}")
    # Within the context manager, the selected stream is `side`.
    with torch.cuda.stream(side):
        inside = torch.cuda.current_stream()
        print(f"stream context 内的 current_stream: {inside}")
        print(f"与 s 相同: {inside == side}")
    # Leaving the context restores the previously selected stream.
    outside = torch.cuda.current_stream()
    print(f"退出 context 后的 current_stream: {outside}")
    print(f"与默认 stream 相同: {outside == dflt}")
    print()
def example_stream_query():
    """Poll Stream.query() until all work queued on the stream has finished."""
    print("=" * 60)
    print("6. Stream query (非阻塞检查)")
    print("=" * 60)
    s = torch.cuda.Stream()
    a = torch.randn(2000, 2000, device="cuda")
    # s reads `a`, which was initialised on the current stream.
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        _ = a @ a
    # query() is True once everything queued on s has completed; False while
    # work is still running. It never blocks.
    while not s.query():
        print(" Stream 仍在运行...")
        time.sleep(0.01)
    print(" Stream 已完成!")
    print()
def example_stream_synchronize_device():
    """Contrast Stream.synchronize() with device-wide torch.cuda.synchronize()."""
    print("=" * 60)
    print("7. 设备级同步 vs Stream 级同步")
    print("=" * 60)
    s1 = torch.cuda.Stream()
    s2 = torch.cuda.Stream()
    a = torch.randn(1000, 1000, device="cuda")
    # Both streams read `a`, which was initialised on the current stream.
    current = torch.cuda.current_stream()
    s1.wait_stream(current)
    s2.wait_stream(current)
    with torch.cuda.stream(s1):
        _ = a @ a
    with torch.cuda.stream(s2):
        _ = a @ a
    # Waits for s1 only.
    s1.synchronize()
    print("s1 已同步 (s2 可能还在运行)")
    # Waits for every stream on the current device.
    torch.cuda.synchronize()
    print("设备级同步完成 (所有 stream 都已完成)")
    print()
def example_external_stream():
    """Run PyTorch work on a raw cudaStream_t via torch.cuda.ExternalStream.

    Scenario: a C++ extension or third-party library (e.g. TensorRT, or a
    cuBLAS handle) gives you a raw ``cudaStream_t`` pointer and PyTorch must
    run its operations on that stream. ExternalStream bridges the two.
    """
    print("=" * 60)
    print("8. 外部 Stream (ExternalStream)")
    print("=" * 60)
    # Simulate a raw cudaStream_t handed over by external code. In real code
    # it would come from cublasGetStream / cudaStreamCreate etc.
    s = torch.cuda.Stream()
    raw_ptr = s.cuda_stream
    print(f"外部给的 cudaStream_t 指针: {raw_ptr}")
    # ExternalStream does not own the stream; the underlying stream (here `s`)
    # must outlive it.
    ext_s = torch.cuda.ExternalStream(raw_ptr)
    print(f"包装后的 ExternalStream: {ext_s}")
    a = torch.randn(500, 500, device="cuda")
    # The external stream reads `a`, which was initialised on the current stream.
    ext_s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(ext_s):
        b = a @ a
    ext_s.synchronize()
    print(f"在外部 stream 上完成计算, b shape: {b.shape}")
    print()
    # Real-world sketch:
    # handle = cublasCreate()
    # cublasGetStream(handle, &raw_stream)  # obtain cuBLAS's stream
    # s = torch.cuda.ExternalStream(raw_stream)
    # with torch.cuda.stream(s):
    #     ... PyTorch and cuBLAS work share the same stream ...
    print("典型用途: 让 PyTorch 和 cuBLAS/TensorRT 等外部库共享同一个 stream")
    print()
if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA 不可用,请在有 GPU 的环境中运行")
        # exit() is injected by the `site` module for interactive use;
        # SystemExit is always available.
        raise SystemExit(1)
    print(f"设备: {torch.cuda.get_device_name()}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print()
    example_create_and_use_stream()
    example_priority_stream()
    example_stream_wait_stream()
    example_multiple_streams_parallel()
    example_current_and_default_stream()
    example_stream_query()
    example_stream_synchronize_device()
    example_external_stream()
stream_event_combined
"""
Stream + Event 联合使用: 多流并行与流水线
展示如何结合 Stream 和 Event 实现:
1. 计算与数据传输重叠 (copy-compute overlap)
2. 多 stream 流水线 (pipeline parallelism)
3. 生产者-消费者模式
4. 精确的多流性能测量
"""
import torch
import time
def example_copy_compute_overlap():
    """Overlap host-to-device copies with computation using two streams.

    Serial baseline (current stream): copy a and b to the GPU, compute
    a @ b, and copy the result back.

    Overlapped version: b is copied once. a is copied in row chunks on
    ``copy_stream`` while ``compute_stream`` multiplies the chunks that have
    already arrived and copies each result chunk back to the host. This lets
    the copy of chunk i+1 overlap the matmul of chunk i. Both versions do the
    same work.
    """
    print("=" * 60)
    print("1. 计算与数据传输重叠")
    print("=" * 60)
    N = 4096
    num_chunks = 4
    chunk_size = N // num_chunks
    # Host tensors are pinned (page-locked). A non_blocking copy is only
    # asynchronous with respect to the host when the host side is pinned, so
    # pageable buffers would prevent any overlap.
    cpu_a = torch.randn(N, N).pin_memory()
    cpu_b = torch.randn(N, N).pin_memory()
    cpu_c_overlap = torch.empty(N, N).pin_memory()
    copy_stream = torch.cuda.Stream()
    compute_stream = torch.cuda.Stream()
    # Untimed warm-up on the current stream and on compute_stream, so one-time
    # setup is charged to neither measurement.
    warm = cpu_b.to("cuda")
    _ = warm @ warm
    compute_stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(compute_stream):
        _ = warm @ warm
    torch.cuda.synchronize()
    # === Serial version (current stream) ===
    start = time.perf_counter()
    gpu_a = cpu_a.to("cuda", non_blocking=False)
    gpu_b = cpu_b.to("cuda", non_blocking=False)
    gpu_c = gpu_a @ gpu_b
    cpu_c = gpu_c.cpu()
    torch.cuda.synchronize()
    serial_time = time.perf_counter() - start
    # === Overlapped version ===
    torch.cuda.synchronize()
    start = time.perf_counter()
    # Copy the full b once, on the current stream...
    gpu_b = cpu_b.to("cuda", non_blocking=True)
    # ...and make compute_stream wait for that copy before any matmul reads it.
    compute_stream.wait_stream(torch.cuda.current_stream())
    for i in range(num_chunks):
        rows = slice(i * chunk_size, (i + 1) * chunk_size)
        # Async H2D copy of one (chunk_size, N) row block of a. A row slice
        # of a contiguous pinned tensor is itself contiguous and pinned.
        with torch.cuda.stream(copy_stream):
            gpu_chunk_a = cpu_a[rows].to("cuda", non_blocking=True)
        # Marks the end of this chunk's copy.
        copy_done = torch.cuda.Event()
        copy_done.record(copy_stream)
        with torch.cuda.stream(compute_stream):
            copy_done.wait(compute_stream)
            # gpu_chunk_a was allocated on copy_stream but is read here.
            # Without record_stream, rebinding it in the next iteration would
            # let the allocator hand its memory to the next copy while this
            # matmul may still be reading it.
            gpu_chunk_a.record_stream(compute_stream)
            # (chunk_size, N) @ (N, N) -> (chunk_size, N)
            gpu_chunk_c = gpu_chunk_a @ gpu_b
            # Async D2H of the result chunk into pinned host memory.
            cpu_c_overlap[rows].copy_(gpu_chunk_c, non_blocking=True)
    compute_stream.synchronize()
    copy_stream.synchronize()
    parallel_time = time.perf_counter() - start
    print(f"串行时间: {serial_time * 1000:.2f} ms")
    print(f"并行时间: {parallel_time * 1000:.2f} ms")
    print(f"加速比: {serial_time / parallel_time:.2f}x")
    print()
def example_pipeline_parallelism():
    """Multi-stage pipeline: one stream per stage, events between stages.

    Microbatch ``i`` starts stage ``j`` only after ``events[i][j - 1]`` (its
    own previous stage) has completed. Different microbatches can therefore
    occupy different stages at the same time. The result is compared against
    running all stages of every microbatch serially on the current stream.
    """
    print("=" * 60)
    print("2. 多阶段流水线 (Pipeline Parallelism)")
    print("=" * 60)
    N = 2048
    num_stages = 3
    num_microbatches = 5
    iters_per_stage = 10
    # One stream per stage.
    streams = [torch.cuda.Stream() for _ in range(num_stages)]
    inputs = [torch.randn(N, N, device="cuda") for _ in range(num_microbatches)]
    # events[i][j]: microbatch i has finished stage j.
    events = [[torch.cuda.Event() for _ in range(num_stages)]
              for _ in range(num_microbatches)]
    # Stage streams must not start before `inputs` is initialised on the
    # current stream. Also warm up every stage stream, and the current stream
    # used by the serial baseline, so one-time setup is excluded from timing.
    current = torch.cuda.current_stream()
    warm = inputs[0]
    _ = warm @ warm
    for stage_stream in streams:
        stage_stream.wait_stream(current)
        with torch.cuda.stream(stage_stream):
            _ = warm @ warm
    torch.cuda.synchronize()
    start = time.perf_counter()
    # Enqueue stage by stage; the events enforce per-microbatch ordering.
    for j, stage_stream in enumerate(streams):
        with torch.cuda.stream(stage_stream):
            for i in range(num_microbatches):
                # Wait until microbatch i has finished the previous stage.
                if j > 0:
                    events[i][j - 1].wait(stage_stream)
                x = inputs[i]
                # x was produced on another stream (the current stream for
                # stage 0, the previous stage's stream otherwise). Mark it as
                # used here so the allocator does not recycle its memory
                # once inputs[i] is rebound below.
                x.record_stream(stage_stream)
                # Simulated stage work; x is N x N, so x @ x keeps the shape.
                for _ in range(iters_per_stage):
                    x = x @ x * 0.01
                inputs[i] = x
                events[i][j].record(stage_stream)
    for stage_stream in streams:
        stage_stream.synchronize()
    pipeline_time = time.perf_counter() - start
    # Serial baseline on the current stream (same amount of work).
    torch.cuda.synchronize()
    start = time.perf_counter()
    for i in range(num_microbatches):
        x = inputs[i]
        for _ in range(num_stages * iters_per_stage):
            x = x @ x * 0.01
    torch.cuda.synchronize()
    serial_time = time.perf_counter() - start
    print(f"流水线时间: {pipeline_time * 1000:.2f} ms")
    print(f"串行时间: {serial_time * 1000:.2f} ms")
    print()
def example_producer_consumer():
    """Producer/consumer: one stream produces buffers, another consumes them.

    Each produced buffer gets its own event, so the consumer waits only for
    the specific item it is about to read.
    """
    print("=" * 60)
    print("3. 生产者-消费者模式")
    print("=" * 60)
    N = 2048
    num_items = 4
    producer_stream = torch.cuda.Stream()
    consumer_stream = torch.cuda.Stream()
    produced_events = []
    data_buffers = []
    a = torch.randn(N, N, device="cuda")
    # The producer reads `a`, which was initialised on the current stream.
    producer_stream.wait_stream(torch.cuda.current_stream())
    # Producer: each item is derived from the previous one.
    for i in range(num_items):
        with torch.cuda.stream(producer_stream):
            data = a @ a if i == 0 else data_buffers[-1] @ a
            data_buffers.append(data)
            # Marks this item as produced.
            ready = torch.cuda.Event()
            ready.record(producer_stream)
            produced_events.append(ready)
    # Consumer: wait for each item, then apply relu.
    results = []
    for ready, buf in zip(produced_events, data_buffers):
        with torch.cuda.stream(consumer_stream):
            ready.wait(consumer_stream)
            results.append(buf.relu())
    consumer_stream.synchronize()
    print(f"生产了 {num_items} 个数据, 消费了 {len(results)} 个结果")
    print(f"每个结果 shape: {results[0].shape}")
    print()
def example_precise_multistream_timing():
    """Time concurrent work on several streams with per-stream event pairs.

    Prints each stream's own duration and the overall span from the earliest
    start event to the latest end event.
    """
    print("=" * 60)
    print("4. 多 Stream 精确计时")
    print("=" * 60)
    N = 4096
    num_streams = 3
    streams = [torch.cuda.Stream() for _ in range(num_streams)]
    # Start/end events for each stream.
    start_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_streams)]
    end_events = [torch.cuda.Event(enable_timing=True) for _ in range(num_streams)]
    data = [torch.randn(N, N, device="cuda") for _ in range(num_streams)]
    # The streams read `data`, initialised on the current stream. Warm each
    # one up so one-time per-stream setup is not inside the timed interval.
    current = torch.cuda.current_stream()
    for stream, x in zip(streams, data):
        stream.wait_stream(current)
        with torch.cuda.stream(stream):
            _ = x @ x
    torch.cuda.synchronize()
    # Launch on every stream, bracketed by its events.
    for stream, t0, t1, x in zip(streams, start_events, end_events, data):
        with torch.cuda.stream(stream):
            t0.record(stream)
            _ = x @ x
            t1.record(stream)
    # Wait for each stream and report its own duration.
    for i in range(num_streams):
        end_events[i].synchronize()
        elapsed = start_events[i].elapsed_time(end_events[i])
        print(f" Stream {i}: {elapsed:.3f} ms")
    # Concurrent streams have no fixed order: start_events[0] is not
    # necessarily the earliest start, nor end_events[-1] the latest end.
    # Measure every event relative to start_events[0] (events that completed
    # before it give negative offsets) and take the full span.
    ref = start_events[0]
    first_start = min(ref.elapsed_time(e) for e in start_events)
    last_end = max(ref.elapsed_time(e) for e in end_events)
    total = last_end - first_start
    print(f" 总耗时 (overlap): {total:.3f} ms")
    print()
def example_data_dependency_graph():
    """Express a small dependency DAG across four streams with events.

    A = x @ x and B = y @ y run in parallel; C = A + B waits for both;
    D = C * 2 waits for C.
    """
    print("=" * 60)
    print("5. 复杂数据依赖图")
    print("=" * 60)
    A_stream = torch.cuda.Stream()
    B_stream = torch.cuda.Stream()
    C_stream = torch.cuda.Stream()
    D_stream = torch.cuda.Stream()
    N = 2048
    x = torch.randn(N, N, device="cuda")
    y = torch.randn(N, N, device="cuda")
    # A_stream/B_stream read x/y, which were initialised on the current stream.
    current = torch.cuda.current_stream()
    A_stream.wait_stream(current)
    B_stream.wait_stream(current)
    # A and B are independent and can run concurrently.
    with torch.cuda.stream(A_stream):
        A = x @ x
        A_done = torch.cuda.Event()
        A_done.record(A_stream)
    with torch.cuda.stream(B_stream):
        B = y @ y
        B_done = torch.cuda.Event()
        B_done.record(B_stream)
    # C depends on both A and B.
    with torch.cuda.stream(C_stream):
        A_done.wait(C_stream)
        B_done.wait(C_stream)
        C = A + B
        C_done = torch.cuda.Event()
        C_done.record(C_stream)
    # D depends on C.
    with torch.cuda.stream(D_stream):
        C_done.wait(D_stream)
        D = C * 2
    D_stream.synchronize()
    print(f"A: {A.shape}, B: {B.shape}, C: {C.shape}, D: {D.shape}")
    print(f"D 前几个元素: {D[0, :5]}")
    print()
def example_ring_allreduce_simulation():
    """Simulate one step of Ring AllReduce with streams and events.

    Ring AllReduce in brief:
    - N ranks form a ring: 0 -> 1 -> 2 -> 3 -> 0.
    - Each rank sends data to the next rank and receives (and accumulates)
      data from the previous rank.
    - The bandwidth-optimal algorithm splits data into N chunks and runs
      N-1 reduce-scatter steps followed by N-1 all-gather steps.

    One step is simulated here on a single GPU:
    - send_streams[r] copies rank r's data into the next rank's recv buffer.
    - recv_streams[r] accumulates the received buffer into rank r's data.
    - Events ensure a recv reads its buffer only after the matching send has
      written it. They also ensure a rank overwrites its own data only after
      its own send has finished reading it.
    """
    print("=" * 60)
    print("6. 模拟 Ring AllReduce 同步模式")
    print("=" * 60)
    num_ranks = 4
    N = 512
    # Each rank's own data: rank i holds the value i + 1 everywhere.
    data = [torch.ones(N, N, device="cuda") * (i + 1) for i in range(num_ranks)]
    # recv_buf[i]: where rank i receives the data sent by rank i - 1.
    recv_buf = [torch.empty(N, N, device="cuda") for _ in range(num_ranks)]
    send_streams = [torch.cuda.Stream() for _ in range(num_ranks)]
    recv_streams = [torch.cuda.Stream() for _ in range(num_ranks)]
    send_done = [torch.cuda.Event() for _ in range(num_ranks)]
    recv_done = [torch.cuda.Event() for _ in range(num_ranks)]
    current = torch.cuda.current_stream()
    # Phase 1: every rank sends. All sends (and their event records) must be
    # issued before any recv waits on them. Waiting on an event that has not
    # been recorded yet is a no-op, not a deferred wait, so a recv issued
    # first would read recv_buf before it was written.
    for rank in range(num_ranks):
        next_rank = (rank + 1) % num_ranks
        # data/recv_buf were initialised on the current stream.
        send_streams[rank].wait_stream(current)
        with torch.cuda.stream(send_streams[rank]):
            recv_buf[next_rank].copy_(data[rank])
            send_done[rank].record(send_streams[rank])
    # Phase 2: every rank receives and reduces.
    for rank in range(num_ranks):
        prev_rank = (rank - 1) % num_ranks
        with torch.cuda.stream(recv_streams[rank]):
            # The previous rank's data has landed in recv_buf[rank].
            send_done[prev_rank].wait(recv_streams[rank])
            # This rank's own send must finish reading data[rank] before it
            # is modified in place (write-after-read hazard).
            send_done[rank].wait(recv_streams[rank])
            data[rank].add_(recv_buf[rank])
            # Marks the end of this rank's reduce (available to a next step).
            recv_done[rank].record(recv_streams[rank])
    for s in send_streams + recv_streams:
        s.synchronize()
    print()
    print("=== recv + reduce 完成 (每个 rank 累加了 prev rank 的数据) ===")
    for i in range(num_ranks):
        prev_i = (i - 1) % num_ranks
        expected = (i + 1) + (prev_i + 1)
        print(f" rank {i}: data[0,0] = {data[i][0, 0].item():.0f} (期望: 自己的{i+1} + prev的{prev_i+1} = {expected})")
    print()
    print("关键点:")
    print(" - send 和 recv 在不同 stream 上,可以并行执行")
    print(" - 用 event 确保 recv 在对应 send 完成后才读取数据")
    print(" - 真实 Ring AllReduce 需要 2(N-1) 步 (N-1 步 reduce-scatter + N-1 步 all-gather),这里只展示一步")
    print()
if __name__ == "__main__":
    if not torch.cuda.is_available():
        print("CUDA 不可用,请在有 GPU 的环境中运行")
        # exit() is injected by the `site` module for interactive use;
        # SystemExit is always available.
        raise SystemExit(1)
    print(f"设备: {torch.cuda.get_device_name()}")
    print(f"CUDA 版本: {torch.version.cuda}")
    print()
    example_copy_compute_overlap()
    example_pipeline_parallelism()
    example_producer_consumer()
    example_precise_multistream_timing()
    example_data_dependency_graph()
    example_ring_allreduce_simulation()
## event.record 到底是什么意思

`event.record(s1)` 的意思是:在 stream `s1` 的命令队列中插入一个标记点(marker)。

具体来说:

- **它不阻塞任何东西** —— `record()` 调用本身是异步的,只是往 stream 里塞了一个"路标",CPU 立即返回。
- **GPU 按序执行** —— 当 `s1` 中在 `record()` 之前提交的所有操作都在 GPU 上执行完毕后,这个 event 才被标记为"已完成"。
- **之后用它做同步** —— 你可以用这个 event 来:
  - `event.wait(s2)` —— 让另一个 stream `s2` 停在原地(只阻塞 s2 上之后提交的 GPU 工作,不阻塞 CPU),直到 `s1` 到达这个标记点。注意:如果调用 `wait` 时该 event 还从未被 record,`wait` 不会等待任何东西;
  - `event.synchronize()` —— 阻塞 CPU 线程,直到 `s1` 到达这个标记点;
  - `event.query()` —— 非阻塞地检查 GPU 是否已经执行到这个标记点。

示例:

```python
# Stream s1 上做计算
with torch.cuda.stream(s1):
    b = a @ a
# 在 s1 上记录 event
event = torch.cuda.Event()
event.record(s1)
# Stream s2 等待 event (即等待 s1 完成)
with torch.cuda.stream(s2):
    event.wait(s2)  # s2 阻塞直到 s1 到达 event 点
    c = b @ b       # 此时 b 已就绪
```
这里的执行顺序是:

```
时间线 →
s1: [ a @ a ] ──→ [ event 标记 ]
                        │
s2:                     └──→ event.wait(s2) 阻塞在此 ──→ [ b @ b ]
                                    ↑
                        等 s1 跑到 event 点后才放行
```

所以 `event.record(s1)` 本质上就是一个"栅栏标记"——在 stream 里画一条线,之后你可以让其他 stream 或 CPU 等这条线。