关于torch.backends.deterministic和torch.backends.cudnn.benchmark

TLDR:这是个关于torch.backends.cudnn设置的问题,不同组合的torch.backends.deterministic和torch.backends.cudnn.benchmark会产生不一样的结果,其中最快的组合(deterministic = False ,benchmark = True)比最慢的组合(deterministic = True ,benchmark = False)大约快了30倍。

现在先记录下方便以后有想法有能力了再总结回顾。

在跑BEAT的时候,有一处代码很好玩,other_tools.set_random_seed()

我加了点注释的代码如下:

python 复制代码
def set_random_seed(args):
    os.environ['PYTHONHASHSEED'] = str(args.random_seed)
    random.seed(args.random_seed)
    np.random.seed(args.random_seed)
    torch.manual_seed(args.random_seed)
    torch.cuda.manual_seed_all(args.random_seed)
    torch.cuda.manual_seed(args.random_seed)
    ## pay attention ,the below is training speed difference ,in camn:
    ## if set deterministic = True  ,benchmark = True ,it will cost almost 50-60 seconds for 10its
    ## if set deterministic = False ,benchmark = True ,it will cost almost 1-2 seconds for 10its
    ## if set deterministic = True  ,benchmark = False,it will cost 58 seconds for 10its
    ## if set deterministic = False  ,benchmark = False,it will cost almost 4-5 seconds for 10its
    torch.backends.cudnn.deterministic = args.deterministic #default: False
    torch.backends.cudnn.benchmark = args.benchmark         #default: False
    torch.backends.cudnn.enabled = args.cudnn_enabled       #default: True

只不过很有点意外的是当deterministic = True ,benchmark = True的时候居然这么慢,我起初以为设置好了benchmark=True后torch框架会自动选个最快的卷积算法,后续deterministic = True让这个卷积算法每次返回都是这个固定最快的。

上面是我以为的,下面根据结果(在注释的代码中)来分析

deterministic = True ,benchmark = True的情况,的确还是会选下卷积算法,比如把benchmark在比如设置为False的时候每次运行时间都是固定的,设置为True的时候还是会有点时间上的小波动,可见的确是选了下卷积的算法造成了结果的差异。当然具体怎么选的我暂且就不知道了,当然,选取最快的情况deterministic = False ,benchmark = True会有什么意向不到的结果我暂且也不清楚,网上很多说选取deterministic = True ,benchmark = False是为了保持结果的可复现性,我感觉这很扯就是,波动理应当极小极小(当然这是我目前的偏见)。

相关推荐
锐学AI12 分钟前
从零开始学LangChain(二):LangChain的核心组件 - Agents
人工智能·python
风送雨20 分钟前
多模态RAG工程开发教程(上)
python·langchain
棒棒的皮皮23 分钟前
【OpenCV】Python图像处理形态学之膨胀
图像处理·python·opencv·计算机视觉
小草cys26 分钟前
HarmonyOS Next调用高德api获取实时天气,api接口
开发语言·python·arkts·鸿蒙·harmony os
爬山算法26 分钟前
Netty(25)Netty的序列化和反序列化机制是什么?
开发语言·python
未知数Tel29 分钟前
Dify离线安装插件
python·阿里云·pip·dify
龘龍龙31 分钟前
Python基础学习(六)
开发语言·python·学习
热爱专研AI的学妹1 小时前
【搭建工作流教程】使用数眼智能 API 搭建 AI 智能体工作流教程(含可视化流程图)
大数据·数据库·人工智能·python·ai·语言模型·流程图
databook1 小时前
拒绝“凭感觉”:用回归分析看透数据背后的秘密
python·数据挖掘·数据分析
Psycho_MrZhang1 小时前
Flask 设计思想总结
后端·python·flask