通过 os.dup 和 sys.stdout.fileno() 捕获标准输出,判断 PyTorch 算子是否 fallback 到了 CPU
某些设备在运行 PyTorch 算子时,如果该算子不被支持,会自动 fallback 到 CPU 上执行,但输出 tensor 的 device 仍然显示为该设备而不是 cpu,因此无法直接从结果判断是否发生了 fallback,而我希望能获取到这个状态。本文通过重定向并捕获标准输出,根据输出内容中是否包含 "fallback" 字符串,来判断是否触发了 fallback。
一.代码
python
import threading
import sys
import os
class CheckFallback:
    """Capture file-descriptor-level stdout and detect the word "fallback".

    While active, the stdout file descriptor is redirected into a pipe that
    a background thread drains. On ``__exit__`` the original descriptor is
    restored, the captured bytes are decoded into ``stdout_messages`` and
    ``is_fallback`` is set to True if that text contains ``"fallback"``.

    Redirection starts in ``__init__`` (not ``__enter__``), so the object is
    meant to be created directly in a ``with`` statement. Captured output is
    swallowed; it is not echoed to the original stdout.

    NOTE: only Python's own ``sys.stdout`` buffer is flushed by this class.
    Text still held in some other (e.g. native-library) buffer when the
    block ends is not seen.

    Args:
        enable: When False the object is a no-op and ``is_fallback`` stays
            False.

    Attributes:
        is_fallback: Result flag, valid after the ``with`` block has exited.
        stdout_messages: Decoded captured output, filled in on exit.
    """

    def __init__(self, enable=True):
        self.is_fallback = False
        self.enable = enable
        self.stdout_messages = ''
        if self.enable:
            # Emit text buffered before the capture started so that it goes
            # to the real stdout instead of polluting the captured output.
            self._flush_python_stdout()
            self.stdout_fileno_origin = self._stdout_fileno()
            # Keep a duplicate of the real stdout so it can be restored.
            self.stdout_fileno_dup = os.dup(self.stdout_fileno_origin)
            self.stdout_pipe = os.pipe()
            os.dup2(self.stdout_pipe[1], self.stdout_fileno_origin)
            # The stdout descriptor is now the only write end of the pipe;
            # replacing it later is what delivers EOF to the reader.
            os.close(self.stdout_pipe[1])
            self._chunks = []
            self.running = True
            # Daemon thread: a capture that is never exited must not keep
            # the interpreter alive.
            self.task = threading.Thread(target=self.read_pipe, daemon=True)
            self.task.start()

    @staticmethod
    def _stdout_fileno():
        """Return the descriptor behind ``sys.stdout``, or 1 if it has none."""
        try:
            return sys.stdout.fileno()
        except (AttributeError, OSError, ValueError):
            # sys.stdout may be None or an object without a real descriptor.
            return 1

    @staticmethod
    def _flush_python_stdout():
        """Best-effort flush of Python's buffered ``sys.stdout``."""
        try:
            sys.stdout.flush()
        except (AttributeError, OSError, ValueError):
            # Deliberately ignored: a missing/closed stdout, or a stream that
            # cannot seek on the temporary pipe, must not break the capture.
            pass

    def read_pipe(self):
        """Drain the pipe until EOF, storing the raw byte chunks.

        Reading until EOF (instead of polling a flag) guarantees that data
        still queued in the pipe when the capture ends is not dropped.
        """
        read_fd = self.stdout_pipe[0]
        while True:
            msg = os.read(read_fd, 8192)
            if not msg:
                break
            self._chunks.append(msg)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Restore stdout and evaluate the captured output.

        Returns None, so exceptions raised inside the block propagate.
        """
        if self.enable:
            self.running = False
            # Push text print()-ed inside the block into the pipe before the
            # descriptor is switched back.
            self._flush_python_stdout()
            # dup2 atomically swaps the pipe's write end for the real stdout:
            # the descriptor number is never left unused, and the reader
            # thread sees EOF once the remaining data has been consumed.
            os.dup2(self.stdout_fileno_dup, self.stdout_fileno_origin)
            os.close(self.stdout_fileno_dup)
            self.task.join()
            os.close(self.stdout_pipe[0])
            # Decode once over the whole capture so a multi-byte character
            # split across two reads cannot raise a decode error.
            self.stdout_messages = b''.join(self._chunks).decode(
                'utf-8', errors='replace')
            # Check whether any fallback message was printed.
            if "fallback" in self.stdout_messages:
                self.is_fallback = True
import torch
# Demo 1: GELU on a large float16 tensor.
# "your_device" is a placeholder device string -- presumably to be replaced
# with the real backend name before running.
A = torch.ones((512, 65024), dtype=torch.float16).to("your_device")
with CheckFallback() as f:
    C = torch.ops.aten.gelu.default(A)
# The flag is only computed when the with-block exits, so read it afterwards.
print(f.is_fallback)
print(C.shape, C.device)

# Demo 2: element-wise pow on a small float16 tensor; here the tensor is
# created inside the captured region as well.
with CheckFallback() as f:
    A = torch.ones((1, 32), dtype=torch.float16).to("your_device")
    C = torch.ops.aten.pow(A, A)
print(f.is_fallback)
print(C.shape, C.device)