通过os.dup sys.stdout.fileno捕获标准输出,判断pytorch算子是否fallback到了cpu

通过os.dup sys.stdout.fileno捕获标准输出,判断pytorch算子是否fallback到了cpu

某种设备在运行pytorch算子时,如果不支持会自动fallback到cpu,输出的tensor.device却不是cpu,我希望能获取到这个状态。本文通过捕获标准输出,根据终端是否输出fallback字符串,判断是否触发了fallback

一.代码

python 复制代码
import threading
import sys
import os

class CheckFallback:
    def __init__(self,enable=True):        
        self.is_fallback=False
        self.enable=enable
        if self.enable:
            self.stdout_fileno_origin = sys.stdout.fileno()
            self.stdout_fileno_dup = os.dup(self.stdout_fileno_origin)
            self.stdout_pipe = os.pipe()
            os.dup2(self.stdout_pipe[1], self.stdout_fileno_origin)
            os.close(self.stdout_pipe[1])
            self.stdout_messages = ''
            self.running=True
            self.task = threading.Thread(target=self.read_pipe)
            self.task.start()

    def read_pipe(self):
        while self.running:
            msg = os.read(self.stdout_pipe[0], 8192)
            if msg:
                self.stdout_messages+=msg.decode('utf-8')
    
    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enable:
            self.running=False
            os.close(self.stdout_fileno_origin)
            self.task.join()
            os.close(self.stdout_pipe[0])
            os.dup2(self.stdout_fileno_dup, self.stdout_fileno_origin)
            os.close(self.stdout_fileno_dup)
            #检查终端是否有fallback信息输出
            if self.stdout_messages.find("fallback")>=0:
                self.is_fallback=True

import torch
A=torch.ones((512,65024),dtype=torch.float16).to("your_device")
with CheckFallback() as f:
    C=torch.ops.aten.gelu.default(A)    
print(f.is_fallback)
print(C.shape,C.device)

with CheckFallback() as f:
    A=torch.ones((1,32),dtype=torch.float16).to("your_device")
    C=torch.ops.aten.pow(A,A)
print(f.is_fallback)
print(C.shape,C.device)
相关推荐
阿_旭2 分钟前
基于YOLO11深度学习的运动品牌LOGO检测与识别系统【python源码+Pyqt5界面+数据集+训练代码】
人工智能·python·深度学习·毕业设计·logo检测
SomeB1oody3 分钟前
【Python机器学习】1.9. 逻辑回归实战(进阶):建立二阶边界模型
人工智能·python·机器学习·ai·逻辑回归
东临碣石827 分钟前
【AI论文】GEN3C: 基于3D信息的全球一致视频生成技术,实现精确相机控制
人工智能·数码相机·3d
go54631584659 分钟前
简单的 Python 示例,用于生成电影解说视频的第一人称独白解说文案
开发语言·python
YueiL12 分钟前
OpenCV 颜色空间:原理与操作指南
python·opencv
源码姑娘14 分钟前
基于OpenCV的车牌识别系统(源码+论文+部署教程)
人工智能·毕业设计
我码玄黄16 分钟前
大模型时代,为什么模型都是多少B?
人工智能·llm
机器学习小小白17 分钟前
【深入解析Inception网络:从V1到V3的理论演进与对比,包含pytorch实现Inception模块的代码】
pytorch·深度学习·神经网络·inception
Dmatteratall20 分钟前
目标检测热力图的生成代码(基于GridCam)生成的
人工智能·目标检测·计算机视觉
没学上了27 分钟前
逻辑回归机器学习
人工智能·深度学习·逻辑回归