
在分布式训练场景中,有时我们可能会需要使用print函数(虽然大部分情况下大多会用logging进行信息输出)在终端打印相关信息。但由于同时运行多个进程,如果不进行限制,每个进程都会打印信息,不但影响观感,而且可能会造成阻塞。
通常的解决方法是利用if条件语句进行限制,只在主进程中进行打印,如下:
python
# 当前为主进程
if args.rank == 0:
print('Train message')
但最近在学习目标检测模型DINO源码时,我发现作者采用重写内置print函数 的方式实现了相同的功能,即只在主进程中启用print函数,在其他进程中禁用print函数。
函数源码如下:
python
def setup_for_distributed(is_master):
"""
This function disables printing when not in master process
"""
import builtins as __builtin__
# 得到内置的print函数
builtin_print = __builtin__.print
# 重写print函数
def print(*args, **kwargs):
force = kwargs.pop('force', False)
# 在主进程或者强制条件下才调用内置print输出
if is_master or force:
builtin_print(*args, **kwargs)
# 用重写后的print函数替换内置的print函数
__builtin__.print = print
该方法具体的调用位置是在初始化多进程组之后,示例如下:
python
import torch
args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.dist_backend = 'nccl'
args.dist_url = 'env://'
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
world_size=args.world_size, rank=args.rank)
# 只在主进程启用print
setup_for_distributed(args.rank == 0)
实测好用,且思路清奇,果然学习永无止境。在此做一个学习记录,也分享给需要的人。