PyTorch 中用于 主机(CPU)与设备(GPU)同步 的函数 torch.cuda.synchronize()

PyTorch 中用于 主机(CPU)与设备(GPU)同步 的函数 torch.cuda.synchronize()

flyfish

完整代码在文末

GPU 是 异步执行 的

CPU 发送指令给 GPU(比如矩阵乘法、卷积)

CPU 不会等 GPU 算完,直接跑去执行下一行代码

GPU 在后台默默计算,计算完了再通知 CPU

torch.cuda.synchronize() 到底做什么?

强制 CPU 停下来,等待 GPU 把之前所有的计算任务全部执行完毕,再继续运行后面的代码。

无同步:CPU 发完指令就溜,GPU 后台干活

加同步:CPU 原地等待,直到 GPU 干完活

python 复制代码
torch.cuda.synchronize(device=None)

它会阻塞当前 CPU 线程 ,直到指定 CUDA 设备上所有 streams 中的所有 kernels(计算任务)全部完成为止。

device 参数可选(默认为当前设备 torch.cuda.current_device())。

调用底层 CUDA 的 cudaDeviceSynchronize()(或类似机制),强制等待 GPU 把之前下发的所有工作做完。

PyTorch(以及几乎所有现代 CUDA 编程)默认采用 异步执行 模型:

当在 Python 代码里写 tensor = model(input)torch.mm(a, b) 等 GPU 操作时:
CPU 只负责下发指令 (把 kernel 丢到 CUDA stream 里),函数几乎立刻返回。
GPU 在后台真正执行 计算。

CPU 和 GPU 是并行的,CPU 可以继续往下跑 Python 代码,而 GPU 还在算。避免 CPU 频繁等待 GPU。

如果直接用 time.time() 测量时间,会只测到下发指令的时间,而不是 GPU 真正计算的时间,导致计时严重偏小。

什么场景使用 torch.cuda.synchronize()

  1. 准确测量 GPU 执行时间(Benchmark / Profiling)

    python 复制代码
    torch.cuda.synchronize()          # 先清空之前的残留任务(可选但推荐)
    start = time.time()
    
    # 模型/操作
    output = model(input)
    
    torch.cuda.synchronize()          # 关键!等 GPU 全部算完
    end = time.time()
    print(end - start)

    更推荐的现代写法(使用 CUDA Event,精度更高,避免多余阻塞):

    python 复制代码
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    
    start_event.record()
    # ... 执行操作 ...
    end_event.record()
    
    torch.cuda.synchronize()          # 必须等 event 被记录
    print(start_event.elapsed_time(end_event))  # 单位:毫秒
  2. 调试时需要 GPU 计算真正完成后再看结果

    比如想 print(tensor) 或把结果转到 CPU(.cpu().item() 有时会隐式同步,但不总是可靠)。

    确认某个 kernel 是否真的执行完(排查异步 bug)。

  3. 使用多个 CUDA Stream 时

    默认 stream 里操作通常不需要手动 sync(PyTorch 会自动处理依赖)。

    但如果用了自定义 stream(torch.cuda.Stream),不同 stream 之间可能需要显式同步。

  4. 某些需要严格顺序的场景

    比如在训练循环中,每一步都想确保前一步完全结束(一般不推荐,会严重降低性能)。

什么场景不使用 torch.cuda.synchronize()

正常训练/推理时 时几乎永远不要在循环里每一步都加 torch.cuda.synchronize(),它会强制 CPU 等待 GPU,破坏异步并行,大幅降低整体吞吐量

性能测试时,也只在测量开始前和结束时各加一次,不要每 iteration 都加,除非故意想测包含同步开销的时间。

使用 synchronize() 测量

py 复制代码
import torch
import time

# 检查是否有 GPU
if not torch.cuda.is_available():
    print("CUDA 不可用,请检查环境")
    exit()

device = torch.device("cuda:0")
print(f"使用设备: {device}")

# 创建一些较大的数据用于测试
x = torch.randn(4096, 4096, device=device)
y = torch.randn(4096, 4096, device=device)

# ====================== 使用 synchronize() 测量 ======================
print("\n=== 使用 torch.cuda.synchronize() 测量 ===")

# 预热(非常重要!第一次运行会有额外开销)
for _ in range(10):
    _ = torch.mm(x, y)
torch.cuda.synchronize()

# 开始正式计时
torch.cuda.synchronize()                    # 确保之前所有操作完成
start_time = time.time()

# 要测量的操作(这里用矩阵乘法为例)
for i in range(100):
    z = torch.mm(x, y)                      # GPU 操作

torch.cuda.synchronize()                    # 关键!等待所有 GPU 操作完成
end_time = time.time()

elapsed = (end_time - start_time) * 1000    # 转换为毫秒
print(f"使用 synchronize() 测量 100 次矩阵乘法耗时: {elapsed:.2f} ms")
print(f"平均每次: {elapsed/100:.2f} ms")

使用 CUDA Event 测量

py 复制代码
import torch

# 检查 GPU
if not torch.cuda.is_available():
    print("CUDA 不可用")
    exit()

device = torch.device("cuda:0")
print(f"使用设备: {device}")

x = torch.randn(4096, 4096, device=device)
y = torch.randn(4096, 4096, device=device)

# ====================== 使用 CUDA Event 测量(推荐) ======================
print("\n=== 使用 CUDA Event 测量(推荐) ===")

# 预热
for _ in range(10):
    _ = torch.mm(x, y)
torch.cuda.synchronize()

# 创建 Event
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# 开始计时
start_event.record()                        # 记录开始点

# 要测量的操作
for i in range(100):
    z = torch.mm(x, y)

end_event.record()                          # 记录结束点

# 必须调用 synchronize() 才能读取时间
torch.cuda.synchronize()                    # 等待两个 Event 都被记录完成

elapsed_ms = start_event.elapsed_time(end_event)   # 返回毫秒
print(f"使用 CUDA Event 测量 100 次矩阵乘法耗时: {elapsed_ms:.2f} ms")
print(f"平均每次: {elapsed_ms/100:.2f} ms")
相关推荐
qq_334563552 小时前
MySQL如何实现数据库审计日志记录_开启通用日志与插件审计
jvm·数据库·python
无风听海2 小时前
Python Union语法深度解析
python
阿里巴啦2 小时前
一个 Python 视频处理工具链实战:下载、转录、摘要、字幕、诊断全打通 (已开源)
人工智能·python·whisper·视频下载·视频处理工具
m0_640309302 小时前
如何大幅提升 Google Sheets 数据库更新脚本的执行效率
jvm·数据库·python
Greyson12 小时前
CSS如何实现单选按钮自定义样式_利用伪元素隐藏默认UI
jvm·数据库·python
2401_835956812 小时前
Go语言怎么防SQL注入_Go语言SQL注入防护教程【深入】
jvm·数据库·python
郝学胜-神的一滴2 小时前
Softmax 从入门到精通:多分类激活函数的优雅解法
人工智能·python·算法·机器学习·分类·数据挖掘
m0_514520572 小时前
宝塔面板怎样实现数据库的多地异地自动备份_结合阿里云OSS与定时任务插件
jvm·数据库·python
qq_334563552 小时前
golang如何优化磁盘IO性能_golang磁盘IO性能优化思路
jvm·数据库·python