Pytorch分布式训练print()使用技巧

在分布式训练场景中,有时我们可能会需要使用print函数(虽然大部分情况下大多会用logging进行信息输出)在终端打印相关信息。但由于同时运行多个进程,如果不进行限制,每个进程都会打印信息,不但影响观感,而且可能会造成阻塞。

通常的解决方法是利用if条件语句进行限制,只在主进程中进行打印,如下:

python 复制代码
# 当前为主进程
if args.rank == 0:
    print('Train message')

但最近在学习目标检测模型DINO源码时,我发现作者采用重写内置print函数 的方式实现了相同的功能,即只在主进程中启用print函数,在其他进程中禁用print函数。

函数源码如下:

python 复制代码
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    # 得到内置的print函数
    builtin_print = __builtin__.print

    
    # 重写print函数
    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        # 在主进程或者强制条件下才调用内置print输出
        if is_master or force:
            builtin_print(*args, **kwargs)

    # 用重写后的print函数替换内置的print函数
    __builtin__.print = print

该方法具体的调用位置是在初始化多进程组之后,示例如下:

python 复制代码
import torch

args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.dist_backend = 'nccl'
args.dist_url = 'env://'
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
# 只在主进程启用print
setup_for_distributed(args.rank == 0)

实测好用,且思路清奇,果然学习永无止境。在此做一个学习记录,也分享给需要的人。

相关推荐
KimLiu15 分钟前
LCODER之Python:使用Django搭建服务端
后端·python·django
胡耀超18 分钟前
3.Python高级数据结构与文本处理
服务器·数据结构·人工智能·windows·python·大模型
索迪迈科技26 分钟前
GPS汽车限速器有哪些功能?主要运用在哪里?
人工智能·行车记录仪·车辆安全·监控管理·gps定位
1373i35 分钟前
【Python】pytorch安装(使用conda)
pytorch·python·conda
keyinf038 分钟前
python网络爬取个人学习指南-(五)
python
Niuguangshuo1 小时前
深度学习基本模块:Conv2D 二维卷积层
人工智能·深度学习
b***25111 小时前
深圳比斯特|多维度分选:圆柱电池品质管控的自动化解决方案
大数据·人工智能
kida_yuan1 小时前
【从零开始】12. 一切回归原点
python·架构·nlp