diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 2c3086184c3609c0b9042cc54e7bd33eec6db6ed..66c489c4638267352c9e70517e2c25cfdbf1ea6e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -1,6 +1,7 @@ import os import torch -from api_accuracy_checker.common.utils import print_error_log, write_pt +from api_accuracy_checker.common.utils import print_error_log, write_pt, create_directory +from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_path_before_create class BaseAPIInfo: @@ -55,11 +56,14 @@ class BaseAPIInfo: else: api_args = self.api_name + '.' + str(self.args_num) if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, self.forward_path) - + forward_real_data_path = os.path.join(self.save_path, self.forward_path, "rank" + str(self.rank)) + check_path_before_create(forward_real_data_path) + create_directory(forward_real_data_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(self.save_path, self.backward_path) + backward_real_data_path = os.path.join(self.save_path, self.backward_path, "rank" + str(self.rank)) + check_path_before_create(backward_real_data_path) + create_directory(backward_real_data_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') self.args_num += 1 pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index b031b92a18afbe429d0a043760e427589c2865e1..98499c59985e40109333200c94143a795fffcdaa 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -56,6 +56,9 @@ class Const: """ Class for const """ + DIRECTORY_LENGTH = 4096 + FILE_NAME_LENGTH = 255 + FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' MODEL_TYPE = ['.onnx', '.pb', '.om'] DIM_PATTERN = r"^(-?[0-9]+)(,-?[0-9]+)*" SEMICOLON = ";" @@ -606,7 +609,7 @@ def cross_entropy_process(api_info_dict): def initialize_save_path(save_path, dir_name): data_path = os.path.join(save_path, dir_name) if os.path.exists(data_path): - raise ValueError(f"file {data_path} already exists, please remove it first") + print_warn_log(f"{data_path} already exists, it will be overwritten") else: os.mkdir(data_path, mode=FileCheckConst.DATA_DIR_AUTHORITY) data_path_checker = FileChecker(data_path, FileCheckConst.DIR) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index f3e8a4cf494f8f3d8b36c189c7679314dc802733..e5e50979407537b325e84c4f34947a6dfc57b810 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -1,5 +1,6 @@ # 进行比对及结果展示 import os +import time from rich.table import Table from rich.console import Console from api_accuracy_checker.compare.algorithm import compare_core, cosine_sim, cosine_standard, get_max_rel_err, get_max_abs_err, \ @@ -9,8 +10,8 @@ from api_accuracy_checker.compare.compare_utils import CompareConst from api_accuracy_checker.common.config import msCheckerConfig class Comparator: - TEST_FILE_NAME = "accuracy_checking_result.csv" - DETAIL_TEST_FILE_NAME = "accuracy_checking_details.csv" + TEST_FILE_NAME = "accuracy_checking_result_" + time.strftime("%Y%m%d%H%M%S") + ".csv" + DETAIL_TEST_FILE_NAME = "accuracy_checking_details_" + time.strftime("%Y%m%d%H%M%S") + ".csv" # consts for result csv COLUMN_API_NAME = "API name" diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 536d607dd696b16e10ee5302aafbc64e975aa33d..bff827ac077ac809ed0667d4e8d542cb4d2c4fc2 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -2,6 +2,7 @@ import argparse import os import copy import sys +import time import torch_npu import yaml import torch @@ -213,7 +214,7 @@ def initialize_save_error_data(): error_data_path_checker = FileChecker(msCheckerConfig.error_data_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) error_data_path = error_data_path_checker.common_check() - initialize_save_path(error_data_path, 'ut_error_data') + initialize_save_path(error_data_path, 'ut_error_data' + time.strftime("%Y%m%d%H%M%S")) def _run_ut_parser(parser): diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py index ed764987d5a9287293f183c0bde1d86afd90ccae..a68057dfb41ca38ba79e1daa992a8f51ce4d64e4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/common/test_config.py @@ -17,5 +17,5 @@ class TestConfig(unittest.TestCase): def test_update_config(self): - self.config.update_config(dump_path='/new/path/to/dump', enable_dataloader=False) + self.config.update_config(dump_path='/new/path/to/dump') self.assertEqual(self.config.dump_path, '/new/path/to/dump') diff --git a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py index addba38e38446b177942a104b4194efe910b1f7c..b892a6077a3c26ae27343734aca8012e21d3fc2c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py +++ b/debug/accuracy_tools/api_accuracy_checker/test/ut/dump/test_dump_scopr.py @@ -10,12 +10,12 @@ class TestDumpScope(unittest.TestCase): wrapped_func = iter_tracer(dummy_func) result = wrapped_func() - self.assertEqual(DumpUtil.dump_switch, "ON") + self.assertEqual(DumpUtil.dump_switch, "OFF") self.assertEqual(result, "Hello, World!") def another_dummy_func(): return 123 wrapped_func = iter_tracer(another_dummy_func) result = wrapped_func() - self.assertEqual(DumpUtil.dump_switch, "ON") + self.assertEqual(DumpUtil.dump_switch, "OFF") self.assertEqual(result, 123)