【Pytorch】固定随机数种子

在对神经网络模型进行训练时,有时候会存在对训练过程进行复现的需求。然而,每次运行时 Pytorch、Numpy 中的随机性将使得该目的变得困难重重。在程序运行前固定所有随机数的种子有望解决这一问题。基于此,本文记录了 Pytorch 中的固定随机数种子的方法。

在使用 Pytorch 对模型进行训练时,通常涉及到随机数的模块包括:Python、Pytorch、Numpy、Cudnn。因此,在开始训练前,需要针对这些涉及随机数的模块进行随机数种子的固定。

1. Python

Python 本身涉及到的随机性主要是 Python 自带的 random 库随机化和 Hash 随机化问题,需要通过 os 库对其进行限制:

python 复制代码
import os, random
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
2. Numpy

在使用 Numpy 库取随机数时,需要对其随机数种子进行限制:

python 复制代码
import numpy as np
np.random.seed(seed)
3. Pytorch

当 Pytorch 使用 CPU 进行运算时,需要设定 CPU 支撑下的 Pytorch 随机数种子:

python 复制代码
import torch
torch.manual_seed(seed)

当 Pytorch 使用 GPU 进行运算时,需要设定 GPU 支撑下的 Pytorch 随机数种子:

python 复制代码
import torch
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # 使用多 GPU 时使用

需要特别注意的是:目前很多博客和知乎回答提出 torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed) 具有相同的作用。这个结论需要注意 Pytorch 版本。在笔者所用的 Pytorch 2.1 版本下,这两个函数的作用完全不同。参考官方文档:torch.cuda.manual_seedtorch.cuda.manual_seed_all(seed)

当 Pytorch 使用 Cudnn 进行加速运算时,还需要限制 Cudnn 在加速过程中涉及到的随机策略:

python 复制代码
import torch
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
总结

基于上述库的固定随机数方法总结为:

python 复制代码
def set_random_seed(seed: int) -> None:
	random.seed(seed)
	os.environ['PYTHONHASHSEED'] = str(seed)
	np.random.seed(seed)
	torch.manual_seed(seed)
	torch.cuda.manual_seed_all(seed)
	torch.backends.cudnn.benchmark = False
	torch.backends.cudnn.deterministic = True

seed = 114514
set_torch_seed(seed)

如果在实践中还调用了其他涉及随机性的第三方库,则需要根据上述思路对该固定随机数方法进行动态补充。

相关推荐
汽车仪器仪表相关领域几秒前
Debron OVM 1052 光学关门速度仪:汽车门盖检测的高精度便携工具 + 生产线适配 + 耐久性监测,整车制造与质量控制的黄金标准
人工智能·功能测试·单元测试·汽车·制造·可用性测试
网络工程小王几秒前
【LangGraph 状态持久化(Checkpoint)详解】学习笔记
jvm·人工智能·笔记·langchain
web守墓人1 分钟前
【神经网络】js版本的Pytorch,estorch重磅发布
前端·javascript·人工智能·pytorch·深度学习·神经网络
蔡俊锋1 分钟前
AI自动化不是接工具就行,得补缺点搭轨道
人工智能·ai 效率
DXM05214 分钟前
第11期:实战| ArcGIS Pro 遥感影像预处理
人工智能·arcgis·#arcpy·#arcgis 二次开发·#gis 自动化
Tutankaaa4 分钟前
交通安全知识竞赛:文明出行,安全相伴
大数据·人工智能·安全
knight_9___5 分钟前
大模型project面试2
人工智能
龙侠九重天6 分钟前
大型语言模型结构化输出:用 JSON Schema 约束大模型输出
人工智能·语言模型·自然语言处理·大模型·json
China_Yanhy6 分钟前
【云原生 AI 实战】EKS 搭建 GPU 超算集群:从零拉起节点到 PyTorchJob 分布式训练 (附 EFA 加速避坑指南)
人工智能·分布式·云原生
人工智能培训7 分钟前
知识图谱与检索增强的实战结合
人工智能·深度学习·神经网络·机器学习·生成对抗网络