diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c931c686318908fa8b64330f7f1e72102e9330c8 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -0,0 +1,81 @@ +import yaml +import os +from api_accuracy_checker.common.utils import check_file_or_directory_path + +class Config: + def __init__(self, yaml_file): + check_file_or_directory_path(yaml_file, False) + with open(yaml_file, 'r') as file: + config = yaml.safe_load(file) + self.dump_path = self.validate_dump_path(config['dump_path']) + self.jit_compile = self.validate_jit_compile(config['jit_compile']) + self.compile_option = self.validate_compile_option(config['compile_option']) + self.compare_algorithm = self.validate_compare_algorithm(config['compare_algorithm']) + self.real_data = self.validate_real_data(config['real_data']) + self.dump_step = self.validate_dump_step(config['dump_step']) + + def validate_dump_path(self, dump_path): + if not isinstance(dump_path, str): + raise ValueError("dump_path mast be string type") + return dump_path + + def validate_jit_compile(self, jit_compile): + if not isinstance(jit_compile, bool): + raise ValueError("jit_compile mast be bool type") + return jit_compile + + def validate_compile_option(self, compile_option): + if not isinstance(compile_option, str): + raise ValueError("compile_option mast be string type") + return compile_option + + def validate_compare_algorithm(self, compare_algorithm): + if not isinstance(compare_algorithm, str): + raise ValueError("compare_algorithm mast be string type") + return compare_algorithm + + def validate_real_data(self, real_data): + if not isinstance(real_data, bool): + raise ValueError("real_data mast be bool type") + return real_data + + def validate_dump_step(self, dump_step): + if not isinstance(dump_step, int): + raise ValueError("dump_step mast be int type") + return dump_step + + + def __str__(self): + return ( + f"dump_path={self.dump_path}\n" + f"jit_compile={self.jit_compile}\n" + f"compile_option={self.compile_option}\n" + f"compare_algorithm={self.compare_algorithm}\n" + f"real_data={self.real_data}\n" + f"dump_step={self.dump_step}\n" + ) + + def update_config(self, **kwargs): + for key, value in kwargs.items(): + if hasattr(self, key): + if key == 'dump_path': + self.validate_dump_path(value) + elif key == 'jit_compile': + self.validate_jit_compile(value) + elif key == 'compile_option': + self.validate_compile_option(value) + elif key == 'compare_algorithm': + self.validate_compare_algorithm(value) + elif key == 'real_data': + self.validate_real_data(value) + elif key == 'dump_step': + self.validate_dump_step(value) + setattr(self, key, value) + else: + raise ValueError(f"Invalid key '{key}'") + + + +cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +yaml_path = os.path.join(cur_path, "config.yaml") +msCheckerConfig = Config(yaml_path) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..38a1a3c47b1c857a9cea34b40ab56f32eed596c0 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -0,0 +1,6 @@ +dump_path: './api_info' +jit_compile: True +compile_option: -O3 +compare_algorithm: cosine_similarity +real_data: False +dump_step: 1000 \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py index ea27681ec2cf441506bd997539da8755673c978d..d2c12a0e52d3a93c9f0c1b9ff87a6d1feac5e033 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/__init__.py @@ -7,4 +7,4 @@ from api_accuracy_checker.dump.utils import set_dump_switch initialize_hook(pretest_hook) initialize_output_json() -__all__ = ['set_dump_switch'] \ No newline at end of file +__all__ = ['set_dump_switch', 'msCheckerConfig'] \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index cff12e6b84abd03489706880fb6cfd89f5334906..e8c085a0b885fc2e8b71d1f88e7ae6a68faebd6b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -11,7 +11,7 @@ class APIInfo: def __init__(self, api_name): self.rank = torch_npu.npu.current_device() self.api_name = api_name - self.save_real_data = DumpUtil.save_real_data + self.save_real_data = msCheckerConfig.real_data def analyze_element(self, element): if isinstance(element, (list, tuple)): diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 3a4ba69399528c5ca61b0aedf8212dab6f55e516..1914b27d8d7e029081e1e0515d1220d51b4e0683 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -5,13 +5,13 @@ import threading import numpy as np from .api_info import ForwardAPIInfo, BackwardAPIInfo -from .utils import DumpUtil from ..common.utils import check_file_or_directory_path +from ..common.config import msCheckerConfig lock = threading.Lock() def write_api_info_json(api_info): - dump_path = DumpUtil.dump_path + dump_path = msCheckerConfig.dump_path rank = api_info.rank if isinstance(api_info, ForwardAPIInfo): file_path = os.path.join(dump_path, f'forward_info_{rank}.json') @@ -26,13 +26,14 @@ def write_api_info_json(api_info): raise ValueError(f"Invalid api_info type {type(api_info)}") def write_json(file_path, data, indent=None): - check_file_or_directory_path(file_path,False) - with open(file_path, 'w') as f: - f.write("{\n}") - try: - lock.acquire() - with open(file_path, 'a+') as f: - fcntl.flock(f, fcntl.LOCK_EX) + check_file_or_directory_path(os.path.dirname(file_path),True) + if not os.path.exists(file_path): + with open(file_path, 'w') as f: + f.write("{\n}") + lock.acquire() + with open(file_path, 'a+') as f: + fcntl.flock(f, fcntl.LOCK_EX) + try: f.seek(0, os.SEEK_END) f.seek(f.tell() - 1, os.SEEK_SET) f.truncate() @@ -41,14 +42,14 @@ def write_json(file_path, data, indent=None): f.truncate() f.write(',\n') f.write(json.dumps(data, indent=indent)[1:-1] + '\n}') - except Exception as e: - raise ValueError(f"Json save failed:{e}") - finally: - fcntl.flock(f, fcntl.LOCK_UN) - lock.release() + except Exception as e: + raise ValueError(f"Json save failed:{e}") + finally: + fcntl.flock(f, fcntl.LOCK_UN) + lock.release() def initialize_output_json(): - dump_path = DumpUtil.dump_path + dump_path = os.path.realpath(msCheckerConfig.dump_path) check_file_or_directory_path(dump_path,True) files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] if msCheckerConfig.real_data: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 0dd469f47d477a79a125f00be4875a8435a492ea..4a19785b61dd32a7e11f2ea313fab52f8f8cd7d0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -27,24 +27,12 @@ def set_dump_switch(switch): DumpUtil.set_dump_switch(switch) class DumpUtil(object): - save_real_data = False - dump_path = './api_info' dump_switch = None - @staticmethod - def set_dump_path(save_path): - DumpUtil.dump_path = save_path - DumpUtil.dump_init_enable = True - @staticmethod def set_dump_switch(switch): DumpUtil.dump_switch = switch - @staticmethod - def get_dump_path(): - if DumpUtil.dump_path: - return DumpUtil.dump_path - @staticmethod def get_dump_switch(): return DumpUtil.dump_switch == "ON" \ No newline at end of file