pytest学习-pytorch单元测试

pytorch单元测试

希望测试pytorch各种算子、block、网络等在不同硬件平台,不同软件版本下的计算误差、耗时、内存占用等指标.

本文基于torch.testing._internal

一.公共模块[common.py]

python 复制代码
import torch
from torch import nn
import math
import torch.nn.functional as F
import time
import os
import socket
import sys
from datetime import datetime
import numpy as np
import collections
import math
import json
import copy
import traceback
import subprocess
import unittest
import torch
import inspect
from torch.testing._internal.common_utils import TestCase, run_tests,parametrize,instantiate_parametrized_tests
from torch.testing._internal.common_distributed import MultiProcessTestCase
import torch.distributed as dist

os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = "29500"
os.environ["RANDOM_SEED"] = "0" 

device="cpu"
device_type="cpu"
device_name="cpu"

try:
    if torch.cuda.is_available():     
        device_name=torch.cuda.get_device_name().replace(" ","")
        device="cuda:0"
        device_type="cuda"
        ccl_backend='nccl'
except:
    pass

host_name=socket.gethostname()    
sdk_version=os.getenv("SDK_VERSION","")   						 #从环境变量中获取sdk版本号
metric_data_root=os.getenv("TORCH_UT_METRICS_DATA","./ut_data")  #日志存放的目录
device_count=torch.cuda.device_count()

if not os.path.exists(metric_data_root):
    os.makedirs(metric_data_root)

def device_warmup(device):
    '''设备warmup,确保设备已经正常工作,排除设备初始化的耗时'''
    left = torch.rand([128,512], dtype = torch.float16).to(device)
    right = torch.rand([512,128], dtype = torch.float16).to(device)
    out=torch.matmul(left,right)
    torch.cuda.synchronize()

torch.manual_seed(1) 
np.random.seed(1)

def loop_decorator(loops,rank=0):
    '''循环装饰器,用于统计函数的执行时间,内存占用等'''
    def decorator(func):
        def wrapper(*args,**kwargs):
            latency=[]
            memory_allocated_t0=torch.cuda.memory_allocated(rank)
            for _ in range(loops):
                input_copy=[x.clone() for x in args]
                beg= datetime.now().timestamp() * 1e6
                pred= func(*input_copy)
                gt=kwargs["golden"]
                torch.cuda.synchronize()
                end=datetime.now().timestamp() * 1e6
                mse = torch.mean(torch.pow(pred.cpu().float()- gt.cpu().float(), 2)).item()
                latency.append(end-beg)
            memory_allocated_t1=torch.cuda.memory_allocated(rank)
            avg_latency=np.mean(latency[len(latency)//2:]).round(3)
            first_latency=latency[0]
            return { "first_latency":first_latency,"avg_latency":avg_latency,
                      "memory_allocated":memory_allocated_t1-memory_allocated_t0,
                      "mse":mse}
        return wrapper
    return decorator

class TorchUtMetrics:
    '''用于统计测试结果,比较之前的最小值'''
    def __init__(self,ut_name,thresold=0.2,rank=0):
        self.ut_name=f"{ut_name}_{rank}"
        self.thresold=thresold
        self.rank=rank
        self.data={"ut_name":self.ut_name,"metrics":[]}
        self.metrics_path=os.path.join(metric_data_root,f"{self.ut_name}_{self.rank}.jon")
        try:
            with open(self.metrics_path,"r") as f:
                self.data=json.loads(f.read())
        except:
            pass

    def __enter__(self):
        self.beg= datetime.now().timestamp() * 1e6
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):        
        self.report()
        self.save_data()

    def save_data(self):
        with open(self.metrics_path,"w") as f:
            f.write(json.dumps(self.data,indent=4))

    def set_metrics(self,metrics):
        self.end=datetime.now().timestamp() * 1e6
        item=collections.OrderedDict()
        item["time"]=datetime.now().strftime('%Y-%m-%d %H:%M:%S.%f')
        item["sdk_version"]=sdk_version
        item["device_name"]=device_name
        item["host_name"]=host_name
        item["metrics"]=metrics
        item["metrics"]["e2e_time"]=self.end-self.beg
        self.cur_item=item
        self.data["metrics"].append(self.cur_item)

    def get_metric_names(self):
        return self.data["metrics"][0]["metrics"].keys()

    def get_min_metric(self,metric_name,devicename=None):
        min_value=0
        min_value_index=-1
        for idx,item in enumerate(self.data["metrics"]):
            if devicename and (devicename!=item['device_name']):                
                continue            
            val=float(item["metrics"][metric_name])
            if min_value_index==-1 or val<min_value:
                min_value=val
                min_value_index=idx
        return min_value,min_value_index

    def get_metric_info(self,index):
        metrics=self.data["metrics"][index]
        return f'{metrics["device_name"]}@{metrics["sdk_version"]}'

    def report(self):
        assert len(self.data["metrics"])>0
        for metric_name in self.get_metric_names():
            min_value,min_value_index=self.get_min_metric(metric_name)
            min_value_same_dev,min_value_index_same_dev=self.get_min_metric(metric_name,device_name)
            cur_value=float(self.cur_item["metrics"][metric_name])
            print(f"-------------------------------{metric_name}-------------------------------")
            print(f"{cur_value}#{device_name}@{sdk_version}")
            if min_value_index_same_dev>=0:
                print(f"{min_value_same_dev}#{self.get_metric_info(min_value_index_same_dev)}")
            if min_value_index>=0:
                print(f"{min_value}#{self.get_metric_info(min_value_index)}")

二.普通算子测试[test_clone.py]

python 复制代码
from common import *
class TestCaseClone(TestCase):
    #如果不满足条件,则跳过这个测试
    @unittest.skipIf(device_count>1, "Not enough devices") 
    def test_todo(self):
        print(".TODO")

    #框架会自动遍历以下参数组合
    @parametrize("shape", [(10240,20480),(128,256)])
    @parametrize("dtype", [torch.float16,torch.float32])
    def test_clone(self,shape,dtype):
        
        #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
        @loop_decorator(loops=5)
        def run(input_dev):
            output=input_dev.clone()
            return output
        
        #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2) as m:
            input_host=torch.ones(shape,dtype=dtype)*np.random.rand()
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=input_host.cpu())
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        
instantiate_parametrized_tests(TestCaseClone)

if __name__ == "__main__":
    run_tests()

三.集合通信测试[test_ccl.py]

python 复制代码
from common import *
class TestCCL(MultiProcessTestCase):
    '''CCL测试用例'''
    def _create_process_group_vccl(self, world_size, store):
        dist.init_process_group(
            ccl_backend, world_size=world_size, rank=self.rank, store=store
        )        
        pg = dist.distributed_c10d._get_default_group()
        return pg

    def setUp(self):
        super().setUp()
        self._spawn_processes()

    def tearDown(self):
        super().tearDown()
        try:
            os.remove(self.file_name)
        except OSError:
            pass

    @property
    def world_size(self):
        return 4
      
    #框架会自动遍历以下参数组合
    @unittest.skipIf(device_count<4, "Not enough devices") 
    @parametrize("op",[dist.ReduceOp.SUM])
    @parametrize("shape", [(1024,8192)])
    @parametrize("dtype", [torch.int64])
    def test_allreduce(self,op,shape,dtype):
        if self.rank >= self.world_size:
            return
        
        store = dist.FileStore(self.file_name, self.world_size)
        pg = self._create_process_group_vccl(self.world_size, store)
        if not torch.distributed.is_initialized():
            return
    
        torch.cuda.set_device(self.rank)
        device = torch.device(device_type,self.rank)
        device_warmup(device)
        #让这个函数循环执行loops次,统计第一次执行的耗时、后半段的平均时间、整个执行过程总的GPU内存使用量
        @loop_decorator(loops=5,rank=self.rank)
        def run(input_dev):
            dist.all_reduce(input_dev, op=op)
            return input_dev
        
        #记录整个测试的总耗时,保存统计量,输出摘要(self._testMethodName:测试方法,result:函数返回值,metrics:统计量)
        with TorchUtMetrics(ut_name=self._testMethodName,thresold=0.2,rank=self.rank) as m:
            input_host=torch.ones(shape,dtype=dtype)*(100+self.rank)
            gt=[torch.ones(shape,dtype=dtype)*(100+i) for i in range(self.world_size)]
            gt_=gt[0]
            for i in range(1,self.world_size):
                gt_=gt_+gt[i]
            input_dev=input_host.to(device)
            metrics=run(input_dev,golden=gt_)
            m.set_metrics(metrics)
            assert(metrics["mse"]==0)
        dist.destroy_process_group(pg)
    
instantiate_parametrized_tests(TestCCL)

if __name__ == "__main__":
    run_tests()

四.测试命令

bash 复制代码
# 运行所有的测试
pytest -v -s -p no:warnings --html=torch_report.html --self-contained-html --capture=sys ./

# 运行某一个测试
python3 test_clone.py -k "test_clone_shape_(128, 256)_float32"

五.测试报告

相关推荐
玄斎2 小时前
MySQL 单表操作通关指南:建库 / 建表 / 插入 / 增删改查
运维·服务器·数据库·学习·程序人生·mysql·oracle
im_AMBER4 小时前
Leetcode 78 识别数组中的最大异常值 | 镜像对之间最小绝对距离
笔记·学习·算法·leetcode
其美杰布-富贵-李5 小时前
HDF5文件学习笔记
数据结构·笔记·学习
d111111111d6 小时前
在STM32函数指针是什么,怎么使用还有典型应用场景。
笔记·stm32·单片机·嵌入式硬件·学习·算法
嗷嗷哦润橘_7 小时前
AI Agent学习:MetaGPT之我的工作
人工智能·学习·flask
知识分享小能手8 小时前
CentOS Stream 9入门学习教程,从入门到精通,Linux日志分析工具及应用 —语法详解与实战案例(17)
linux·学习·centos
2301_783360138 小时前
【学习笔记】关于RNA_seq和Ribo_seq技术的对比和BAM生成
笔记·学习
qq_397731518 小时前
Objective-C 学习笔记(第9章)
笔记·学习·objective-c
ujainu9 小时前
Python学习第一天:保留字和标识符
python·学习·标识符·保留字
sheji34169 小时前
【开题答辩全过程】以 基于Java的应急安全学习平台的设计与实现为例,包含答辩的问题和答案
java·开发语言·学习