Pytorch 学习笔记(3) : torch.cuda

一、设备管理

API 功能说明
torch.cuda.is_available() 检查 CUDA 是否可用
torch.cuda.is_initialized() 检查 CUDA 是否已初始化
torch.cuda.init() 手动初始化 CUDA 状态
torch.cuda.device_count() 获取可用 GPU 数量
torch.cuda.current_device() 获取当前设备索引
torch.cuda.set_device(device) 设置当前设备
torch.cuda.get_device_name(device) 获取 GPU 名称
torch.cuda.get_device_capability(device) 获取 CUDA 计算能力
torch.cuda.get_device_properties(device) 获取设备详细属性

上下文管理器:

python 复制代码
with torch.cuda.device(0):      # 切换设备
    # 代码块

with torch.cuda.device_of(tensor):  # 切换到 tensor 所在设备
    # 代码块

二、流管理 (Stream)

API 功能说明
torch.cuda.current_stream(device) 获取当前流
torch.cuda.default_stream(device) 获取默认流
torch.cuda.set_stream(stream) 设置当前流
torch.cuda.synchronize(device) 等待所有流完成
torch.cuda.Stream() 创建 CUDA 流
torch.cuda.ExternalStream() 包装外部 CUDA 流

上下文管理器:

python 复制代码
with torch.cuda.stream(stream):  # 选择指定流
    # 代码块

三、内存管理 ⭐

API 功能说明
torch.cuda.memory_allocated(device) 当前张量占用内存(字节)
torch.cuda.max_memory_allocated(device) 最大内存占用
torch.cuda.memory_reserved(device) 缓存分配器管理的内存
torch.cuda.max_memory_reserved(device) 最大预留内存
torch.cuda.empty_cache() 释放未占用缓存内存
torch.cuda.memory_summary(device) 内存统计可读摘要
torch.cuda.memory_stats(device) 详细内存统计(字典)
torch.cuda.memory_snapshot() 内存分配器状态快照
torch.cuda.mem_get_info(device) 空闲/总显存(cudaMemGetInfo)
torch.cuda.reset_peak_memory_stats(device) 重置峰值统计
torch.cuda.set_per_process_memory_fraction(fraction, device) 设置进程内存使用比例

高级功能:

  • torch.cuda.MemPool - 内存池管理
  • torch.cuda.use_mem_pool(pool, device) - 上下文管理器,路由分配到指定池

四、随机数生成

API 功能说明
torch.cuda.manual_seed(seed) 为当前 GPU 设置种子
torch.cuda.manual_seed_all(seed) 为所有 GPU 设置种子
torch.cuda.seed() 随机种子(当前 GPU)
torch.cuda.seed_all() 随机种子(所有 GPU)
torch.cuda.initial_seed() 获取当前随机种子
torch.cuda.get_rng_state(device) 获取 RNG 状态
torch.cuda.set_rng_state(state, device) 设置 RNG 状态
torch.cuda.get_rng_state_all() 获取所有设备的 RNG 状态
torch.cuda.set_rng_state_all(states) 设置所有设备的 RNG 状态

五、多 GPU 通信 (comm)

API 功能说明
torch.cuda.comm.broadcast(tensor, devices) 广播张量到指定 GPU
torch.cuda.comm.broadcast_coalesced(tensors, devices) 广播张量序列
torch.cuda.comm.reduce_add(tensors, destination) 多 GPU 张量求和
torch.cuda.comm.scatter(tensor, devices, chunk_sizes, dim) 分散张量到多 GPU
torch.cuda.comm.gather(tensors, dim, destination) 从多 GPU 收集张量

六、CUDA Graphs(加速推理)

API 功能说明
torch.cuda.CUDAGraph() CUDA 图包装器
torch.cuda.graph(g) 捕获 CUDA 工作的上下文管理器
torch.cuda.make_graphed_callables(callables) 返回图化版本的可调用对象
torch.cuda.is_current_stream_capturing() 检查是否正在捕获
torch.cuda.graph_pool_handle() 获取图内存池标识

七、性能监控与调试

API 功能说明
torch.cuda.utilization(device) GPU 利用率(nvidia-smi)
torch.cuda.memory_usage(device) 显存读写时间百分比
torch.cuda.temperature(device) GPU 温度(摄氏度)
torch.cuda.power_draw(device) 功耗(毫瓦)
torch.cuda.clock_rate(device) SM 时钟频率(MHz)
torch.cuda.get_sync_debug_mode() 获取同步调试模式
torch.cuda.set_sync_debug_mode(mode) 设置同步调试模式

八、NVTX 性能分析标记

python 复制代码
torch.cuda.nvtx.mark(msg)           # 标记瞬时事件
torch.cuda.nvtx.range_push(msg)     # 压入范围
torch.cuda.nvtx.range_pop()         # 弹出范围

with torch.cuda.nvtx.range(msg):    # 上下文管理器
    # 代码块

九、异常类

异常 说明
torch.cuda.OutOfMemoryError 显存不足异常
torch.cuda.AcceleratorError 设备执行异常

十、特性检测

API 功能说明
torch.cuda.is_bf16_supported() 是否支持 bfloat16
torch.cuda.is_tf32_supported() 是否支持 tf32

📌 常用代码片段

python 复制代码
import torch

# 基础检查
if torch.cuda.is_available():
    device = torch.device("cuda:0")
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    
# 显存清理(训练后常用)
torch.cuda.empty_cache()

# 监控显存使用
print(torch.cuda.memory_summary())

# 设置随机种子(可复现性)
torch.cuda.manual_seed_all(42)

# 多流并行
stream1 = torch.cuda.Stream()
stream2 = torch.cuda.Stream()

with torch.cuda.stream(stream1):
    result1 = model1(input1)

with torch.cuda.stream(stream2):
    result2 = model2(input2)

torch.cuda.synchronize()  # 等待完成

📚 参考 :PyTorch 2.11 官方文档 - torch.cuda

相关推荐
嵌入式小企鹅5 小时前
蓝牙学习系列(七):BLE GATT 数据模型详解
学习·蓝牙·ble·蓝牙协议栈·蓝牙开发·gatt
arvin_xiaoting6 小时前
OpenClaw学习总结_III_自动化系统_3:CronJobs详解
数据库·学习·自动化
少许极端6 小时前
算法奇妙屋(四十一)-贪心算法学习之路 8
学习·算法·贪心算法
练习时长一年7 小时前
我的开发笔记
笔记
arvin_xiaoting7 小时前
OpenClaw学习总结_III_自动化系统_2:Webhooks详解
运维·学习·自动化
生瓜硬劈..7 小时前
SQL 调优全解:从 20 s 到 200 ms 的 6 步实战笔记
java·笔记·sql
不早睡不改名@7 小时前
Netty源码分析---Reactor线程模型深度解析(一)
java·笔记·学习·netty
祁白_7 小时前
Bugku:备份是一个好习惯
笔记·学习·web安全·ctf
Kang.Charles7 小时前
【新手入门】UE第一人称示例工程学习
学习
野指针YZZ8 小时前
XV6操作系统:内存学习笔记
笔记·学习