Pytorch 学习笔记(3) : torch.cuda

一、设备管理

API 功能说明
torch.cuda.is_available() 检查 CUDA 是否可用
torch.cuda.is_initialized() 检查 CUDA 是否已初始化
torch.cuda.init() 手动初始化 CUDA 状态
torch.cuda.device_count() 获取可用 GPU 数量
torch.cuda.current_device() 获取当前设备索引
torch.cuda.set_device(device) 设置当前设备
torch.cuda.get_device_name(device) 获取 GPU 名称
torch.cuda.get_device_capability(device) 获取 CUDA 计算能力
torch.cuda.get_device_properties(device) 获取设备详细属性

上下文管理器:

python 复制代码
with torch.cuda.device(0):      # 切换设备
    # 代码块

with torch.cuda.device_of(tensor):  # 切换到 tensor 所在设备
    # 代码块

二、流管理 (Stream)

API 功能说明
torch.cuda.current_stream(device) 获取当前流
torch.cuda.default_stream(device) 获取默认流
torch.cuda.set_stream(stream) 设置当前流
torch.cuda.synchronize(device) 等待所有流完成
torch.cuda.Stream() 创建 CUDA 流
torch.cuda.ExternalStream() 包装外部 CUDA 流

上下文管理器:

python 复制代码
with torch.cuda.stream(stream):  # 选择指定流
    # 代码块

三、内存管理 ⭐

API 功能说明
torch.cuda.memory_allocated(device) 当前张量占用内存(字节)
torch.cuda.max_memory_allocated(device) 最大内存占用
torch.cuda.memory_reserved(device) 缓存分配器管理的内存
torch.cuda.max_memory_reserved(device) 最大预留内存
torch.cuda.empty_cache() 释放未占用缓存内存
torch.cuda.memory_summary(device) 内存统计可读摘要
torch.cuda.memory_stats(device) 详细内存统计(字典)
torch.cuda.memory_snapshot() 内存分配器状态快照
torch.cuda.mem_get_info(device) 空闲/总显存(cudaMemGetInfo)
torch.cuda.reset_peak_memory_stats(device) 重置峰值统计
torch.cuda.set_per_process_memory_fraction(fraction, device) 设置进程内存使用比例

高级功能:

  • torch.cuda.MemPool - 内存池管理
  • torch.cuda.use_mem_pool(pool, device) - 上下文管理器,路由分配到指定池

四、随机数生成

API 功能说明
torch.cuda.manual_seed(seed) 为当前 GPU 设置种子
torch.cuda.manual_seed_all(seed) 为所有 GPU 设置种子
torch.cuda.seed() 随机种子(当前 GPU)
torch.cuda.seed_all() 随机种子(所有 GPU)
torch.cuda.initial_seed() 获取当前随机种子
torch.cuda.get_rng_state(device) 获取 RNG 状态
torch.cuda.set_rng_state(state, device) 设置 RNG 状态
torch.cuda.get_rng_state_all() 获取所有设备的 RNG 状态
torch.cuda.set_rng_state_all(states) 设置所有设备的 RNG 状态

五、多 GPU 通信 (comm)

API 功能说明
torch.cuda.comm.broadcast(tensor, devices) 广播张量到指定 GPU
torch.cuda.comm.broadcast_coalesced(tensors, devices) 广播张量序列
torch.cuda.comm.reduce_add(tensors, destination) 多 GPU 张量求和
torch.cuda.comm.scatter(tensor, devices, chunk_sizes, dim) 分散张量到多 GPU
torch.cuda.comm.gather(tensors, dim, destination) 从多 GPU 收集张量

六、CUDA Graphs(加速推理)

API 功能说明
torch.cuda.CUDAGraph() CUDA 图包装器
torch.cuda.graph(g) 捕获 CUDA 工作的上下文管理器
torch.cuda.make_graphed_callables(callables) 返回图化版本的可调用对象
torch.cuda.is_current_stream_capturing() 检查是否正在捕获
torch.cuda.graph_pool_handle() 获取图内存池标识

七、性能监控与调试

API 功能说明
torch.cuda.utilization(device) GPU 利用率(nvidia-smi)
torch.cuda.memory_usage(device) 显存读写时间百分比
torch.cuda.temperature(device) GPU 温度(摄氏度)
torch.cuda.power_draw(device) 功耗(毫瓦)
torch.cuda.clock_rate(device) SM 时钟频率(MHz)
torch.cuda.get_sync_debug_mode() 获取同步调试模式
torch.cuda.set_sync_debug_mode(mode) 设置同步调试模式

八、NVTX 性能分析标记

python 复制代码
torch.cuda.nvtx.mark(msg)           # 标记瞬时事件
torch.cuda.nvtx.range_push(msg)     # 压入范围
torch.cuda.nvtx.range_pop()         # 弹出范围

with torch.cuda.nvtx.range(msg):    # 上下文管理器
    # 代码块

九、异常类

异常 说明
torch.cuda.OutOfMemoryError 显存不足异常
torch.cuda.AcceleratorError 设备执行异常

十、特性检测

API 功能说明
torch.cuda.is_bf16_supported() 是否支持 bfloat16
torch.cuda.is_tf32_supported() 是否支持 tf32

📌 常用代码片段

python 复制代码
import torch

# 基础检查
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    
# 显存清理(训练后常用)
torch.cuda.empty_cache()

# 监控显存使用
print(torch.cuda.memory_summary())

# 设置随机种子(可复现性)
torch.cuda.manual_seed_all(42)

# 多流并行
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()

with torch.cuda.stream(stream1):
    result1 = model1(input1)

with torch.cuda.stream(stream2):
    result2 = model2(input2)

torch.cuda.synchronize()  # 等待完成

📚 参考 :PyTorch 2.11 官方文档 - torch.cuda

相关推荐
weiwei228442 天前
神经网络模型导出及开放标准格式ONNX
pytorch·onnx
LinXunFeng6 天前
Obsidian - 使用 Share Note 分享笔记并自部署
前端·笔记·github
通信小呆呆11 天前
当算法有了“五感”:多模态数据融合如何向人体感官协同学习?
人工智能·学习·算法·机器学习·机器人
程序猿追11 天前
那个右下角的小数字怎么“卡”住我打字——我用 HarmonyOS 自己写了一个字数限制输入框
pytorch·华为·harmonyos
H__Rick11 天前
自动对焦学习-3
人工智能·学习·计算机视觉
Daisy Lee11 天前
量化学习-第1章-什么是量化金融
学习·金融·datawhale
Alsn8611 天前
等待学习-学习目录:Docker 容器安全攻防
学习·安全·docker
YM52e11 天前
买菜计算器小应用 - HarmonyOS ArkUI 开发实战-PC版本
学习·华为·harmonyos·鸿蒙·鸿蒙系统
小雨下雨的雨11 天前
HarmonyOS ArkUI训练营入门-组件掌握系列-Animation 动画效果实现-PC版本
学习·华为·harmonyos·鸿蒙
闪闪发亮的小星星11 天前
高斯光以及高斯光公式解释
笔记