From 08395b95917b9794be2c347f4658daf91bbe48a4 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 11 Sep 2023 11:10:37 +0800 Subject: [PATCH 1/6] update --- .../api_accuracy_checker/common/config.py | 88 +++++-------------- .../api_accuracy_checker/config.yaml | 4 +- .../api_accuracy_checker/dump/dump.py | 3 +- .../api_accuracy_checker/dump/dump_scope.py | 5 +- 4 files changed, 30 insertions(+), 70 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 9fe21ccb3f..c8b1a46070 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -7,81 +7,37 @@ class Config: check_file_or_directory_path(yaml_file, False) with open(yaml_file, 'r') as file: config = yaml.safe_load(file) - self.dump_path = self.validate_dump_path(config['dump_path']) - self.jit_compile = self.validate_jit_compile(config['jit_compile']) - self.compile_option = self.validate_compile_option(config['compile_option']) - self.compare_algorithm = self.validate_compare_algorithm(config['compare_algorithm']) - self.real_data = self.validate_real_data(config['real_data']) - self.dump_step = self.validate_dump_step(config['dump_step']) - self.error_data_path = self.validate_error_data_path(config['error_data_path']) - - def validate_dump_path(self, dump_path): - if not isinstance(dump_path, str): - raise ValueError("dump_path mast be string type") - return dump_path - - def validate_jit_compile(self, jit_compile): - if not isinstance(jit_compile, bool): - raise ValueError("jit_compile mast be bool type") - return jit_compile - - def validate_compile_option(self, compile_option): - if not isinstance(compile_option, str): - raise ValueError("compile_option mast be string type") - return compile_option - - def validate_compare_algorithm(self, compare_algorithm): - if not isinstance(compare_algorithm, str): - raise ValueError("compare_algorithm mast be string type") - return compare_algorithm - - def validate_real_data(self, real_data): - if not isinstance(real_data, bool): - raise ValueError("real_data mast be bool type") - return real_data - - def validate_dump_step(self, dump_step): - if not isinstance(dump_step, int): - raise ValueError("dump_step mast be int type") - return dump_step - - def validate_error_data_path(self, error_data_path): - if not isinstance(error_data_path, str): - raise ValueError("error_data_path mast be string type") - return error_data_path - + self.config = {key: self.validate(key, value) for key, value in config.items()} + + def validate(self, key, value): + validators = { + 'dump_path': str, + 'jit_compile': bool, + 'compile_option': str, + 'compare_algorithm': str, + 'real_data': bool, + 'dump_step': int, + 'error_data_path': str, + 'dataloader': bool, + 'target_iter': int + } + if not isinstance(value, validators[key]): + raise ValueError(f"{key} must be {validators[key].__name__} type") + if isinstance(value, int) and value <= 0: + raise ValueError(f"{key} must be greater than 0") + return value def __str__(self): - return ( - f"dump_path={self.dump_path}\n" - f"jit_compile={self.jit_compile}\n" - f"compile_option={self.compile_option}\n" - f"compare_algorithm={self.compare_algorithm}\n" - f"real_data={self.real_data}\n" - f"dump_step={self.dump_step}\n" - ) + return '\n'.join(f"{key}={value}" for key, value in self.config.items()) def update_config(self, **kwargs): for key, value in kwargs.items(): - if hasattr(self, key): - if key == 'dump_path': - self.validate_dump_path(value) - elif key == 'jit_compile': - self.validate_jit_compile(value) - elif key == 'compile_option': - self.validate_compile_option(value) - elif key == 'compare_algorithm': - self.validate_compare_algorithm(value) - elif key == 'real_data': - self.validate_real_data(value) - elif key == 'dump_step': - self.validate_dump_step(value) - setattr(self, key, value) + if key in self.config: + self.config[key] = self.validate(key, value) else: raise ValueError(f"Invalid key '{key}'") - cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) yaml_path = os.path.join(cur_path, "config.yaml") msCheckerConfig = Config(yaml_path) \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 2b22a9d9f9..582443ba3e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -4,4 +4,6 @@ compile_option: -O3 compare_algorithm: cosine_similarity real_data: False dump_step: 1000 -error_data_path: './' \ No newline at end of file +error_data_path: './' +dataloader: True +target_iter: 1 \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index ce7895a442..45548bb44c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -19,6 +19,7 @@ from api_accuracy_checker.dump.api_info import ForwardAPIInfo, BackwardAPIInfo from api_accuracy_checker.dump.info_dump import write_api_info_json, initialize_output_json from api_accuracy_checker.common.utils import print_error_log from api_accuracy_checker.hook_module.register_hook import initialize_hook +from api_accuracy_checker.common.config import msCheckerConfig def set_dump_switch(switch): @@ -29,7 +30,7 @@ def set_dump_switch(switch): class DumpUtil(object): dump_switch = None - target_iter = 1 + target_iter = msCheckerConfig.target_iter call_num = 0 @staticmethod diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 497ad22d5a..5bb6dcaca3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -2,7 +2,7 @@ import torch from torch.utils.data.dataloader import _BaseDataLoaderIter from api_accuracy_checker.dump.dump import DumpUtil - +from api_accuracy_checker.common.config import msCheckerConfig def iter_tracer(func): def func_wrapper(*args, **kwargs): @@ -12,4 +12,5 @@ def iter_tracer(func): return result return func_wrapper -_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file +if msCheckerConfig.dataloader: + _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file -- Gitee From 00481d4b08c7256f8f14a8e4cd7a4cb1dd020e05 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Mon, 11 Sep 2023 12:04:15 +0800 Subject: [PATCH 2/6] update --- debug/accuracy_tools/api_accuracy_checker/common/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index c8b1a46070..548c8f3422 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -23,8 +23,8 @@ class Config: } if not isinstance(value, validators[key]): raise ValueError(f"{key} must be {validators[key].__name__} type") - if isinstance(value, int) and value <= 0: - raise ValueError(f"{key} must be greater than 0") + if key == 'target_iter' and value < 0: + raise ValueError("target_iter must be greater than 0") return value def __str__(self): -- Gitee From 7bb96278c868632fc04e7c34c38488b71aabb4e1 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 12 Sep 2023 09:20:30 +0800 Subject: [PATCH 3/6] update --- debug/accuracy_tools/api_accuracy_checker/common/config.py | 2 +- debug/accuracy_tools/api_accuracy_checker/config.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index 548c8f3422..b7e1a150b9 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -18,7 +18,7 @@ class Config: 'real_data': bool, 'dump_step': int, 'error_data_path': str, - 'dataloader': bool, + 'enable_dataloader': bool, 'target_iter': int } if not isinstance(value, validators[key]): diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 582443ba3e..46f0ed8d41 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -5,5 +5,5 @@ compare_algorithm: cosine_similarity real_data: False dump_step: 1000 error_data_path: './' -dataloader: True +enable_dataloader: True target_iter: 1 \ No newline at end of file -- Gitee From 167cf8db3ad39aa545a3eb25bc33a32e7e8de243 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 12 Sep 2023 11:13:51 +0800 Subject: [PATCH 4/6] bug fix --- debug/accuracy_tools/api_accuracy_checker/common/config.py | 3 +++ debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index b7e1a150b9..07dd4e6bfc 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -27,6 +27,9 @@ class Config: raise ValueError("target_iter must be greater than 0") return value + def __getattr__(self, item): + return self.config[item] + def __str__(self): return '\n'.join(f"{key}={value}" for key, value in self.config.items()) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 5bb6dcaca3..9f8dc8325c 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -12,5 +12,5 @@ def iter_tracer(func): return result return func_wrapper -if msCheckerConfig.dataloader: +if msCheckerConfig.enable_dataloader: _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file -- Gitee From dadbfa5b5d12f9361d8fe806327f95000a05b4d6 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 12 Sep 2023 11:30:56 +0800 Subject: [PATCH 5/6] bugfix --- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 45548bb44c..5f9ce453bc 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -30,7 +30,6 @@ def set_dump_switch(switch): class DumpUtil(object): dump_switch = None - target_iter = msCheckerConfig.target_iter call_num = 0 @staticmethod @@ -43,10 +42,10 @@ class DumpUtil(object): @staticmethod def incr_iter_num_maybe_exit(): - if DumpUtil.call_num == DumpUtil.target_iter: + if DumpUtil.call_num == msCheckerConfig.target_iter: set_dump_switch("ON") - elif DumpUtil.call_num > DumpUtil.target_iter: - raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.target_iter)) + elif DumpUtil.call_num > msCheckerConfig.target_iter: + raise Exception("Model pretest: exit after iteration {}".format(msCheckerConfig.target_iter)) else: set_dump_switch("OFF") DumpUtil.call_num += 1 -- Gitee From 4ae4a28207fa9cb9d467f17ac2e8d02b4f5bf31b Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 12 Sep 2023 16:16:12 +0800 Subject: [PATCH 6/6] bug fix --- debug/accuracy_tools/api_accuracy_checker/dump/dump.py | 4 ++-- debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py | 3 +-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py index 5f9ce453bc..08385882d3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump.py @@ -42,9 +42,9 @@ class DumpUtil(object): @staticmethod def incr_iter_num_maybe_exit(): - if DumpUtil.call_num == msCheckerConfig.target_iter: + if DumpUtil.call_num == msCheckerConfig.target_iter or not msCheckerConfig.enable_dataloader: set_dump_switch("ON") - elif DumpUtil.call_num > msCheckerConfig.target_iter: + elif DumpUtil.call_num > msCheckerConfig.target_iter and msCheckerConfig.enable_dataloader: raise Exception("Model pretest: exit after iteration {}".format(msCheckerConfig.target_iter)) else: set_dump_switch("OFF") diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 9f8dc8325c..17f94da193 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -12,5 +12,4 @@ def iter_tracer(func): return result return func_wrapper -if msCheckerConfig.enable_dataloader: - _BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file +_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) \ No newline at end of file -- Gitee