Pytorch分布式训练print()使用技巧

在分布式训练场景中,有时我们可能会需要使用print函数(虽然大部分情况下大多会用logging进行信息输出)在终端打印相关信息。但由于同时运行多个进程,如果不进行限制,每个进程都会打印信息,不但影响观感,而且可能会造成阻塞。

通常的解决方法是利用if条件语句进行限制,只在主进程中进行打印,如下:

python 复制代码
# 当前为主进程
if args.rank == 0:
    print('Train message')

但最近在学习目标检测模型DINO源码时,我发现作者采用重写内置print函数 的方式实现了相同的功能,即只在主进程中启用print函数,在其他进程中禁用print函数。

函数源码如下:

python 复制代码
def setup_for_distributed(is_master):
    """
    This function disables printing when not in master process
    """
    import builtins as __builtin__

    # 得到内置的print函数
    builtin_print = __builtin__.print

    
    # 重写print函数
    def print(*args, **kwargs):
        force = kwargs.pop('force', False)
        # 在主进程或者强制条件下才调用内置print输出
        if is_master or force:
            builtin_print(*args, **kwargs)

    # 用重写后的print函数替换内置的print函数
    __builtin__.print = print

该方法具体的调用位置是在初始化多进程组之后,示例如下:

python 复制代码
import torch

args.rank = int(os.environ["RANK"])
args.world_size = int(os.environ['WORLD_SIZE'])
args.dist_backend = 'nccl'
args.dist_url = 'env://'
torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                                         world_size=args.world_size, rank=args.rank)
# 只在主进程启用print
setup_for_distributed(args.rank == 0)

实测好用,且思路清奇,果然学习永无止境。在此做一个学习记录,也分享给需要的人。

相关推荐
2501_911067669 小时前
光能筑底,智联全城——叁仟智慧太阳能路灯杆重构城市基础设施新生态
大数据·人工智能·重构
OpenCSG9 小时前
AgenticOps x CSGHub:智能体时代的工程化革命,让企业 AI 落地可控可规模化
人工智能
hrrrrb10 小时前
【算法设计与分析】随机化算法
人工智能·python·算法
D___H10 小时前
Part10_编写自己的解释器
python
Zero_to_zero123410 小时前
Claude code系列(一):claude安装、入门及基础操作指令
人工智能·python
szcsun510 小时前
机器学习(二)-线性回归实战
人工智能·机器学习·线性回归
Yeats_Liao10 小时前
异步推理架构:CPU-NPU流水线设计与并发效率提升
python·深度学习·神经网络·架构·开源
普通网友10 小时前
Android16 adb投屏工具Scrcpy介绍。
人工智能
搬砖者(视觉算法工程师)10 小时前
语义分割:基于 TensorFlow 对 FCN 与迁移学习的探究
人工智能