From b7c84a59899d4659b67297f4a50d14d62a8a5e40 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 4 Aug 2023 03:51:10 +0000 Subject: [PATCH 01/13] update debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py. Signed-off-by: sunyiming --- .../api_accuracy_checker/compare/algorithm.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index a79125d832b..429dcac6d53 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -47,25 +47,29 @@ def cosine_standard(compare_result): def cosine_sim(cpu_output, npu_output): n_value = npu_output.cpu().detach().numpy().reshape(-1) b_value = cpu_output.detach().numpy().reshape(-1) - cos = CompareConst.NA - np.seterr(divide="ignore", invalid="ignore") + cos = CompareConst.NA + np.seterr(divide='ignore', invalid='ignore') if len(n_value) == 1: - print_warn_log("All the data in npu dump data is scalar. Compare by relative error.") - return get_max_rel_err(n_value, b_value) + print_warn_log('All the data in npu dump data is scalar. Compare by relative error.') + return get_max_rel_err_scalar(n_value, b_value) + n_value = n_value / np.max(np.abs(n_value)) + b_value = b_value / np.max(np.abs(b_value)) num = n_value.dot(b_value) a_norm = np.linalg.norm(n_value) b_norm = np.linalg.norm(b_value) if a_norm <= np.finfo(float).eps and b_norm <= np.finfo(float).eps: - return cos, True - elif a_norm <= np.finfo(float).eps: - print_warn_log("All the data is Zero in npu dump data. Compare by relative error.") - return get_max_rel_err(n_value, b_value) + return cos, True + elif a_norm <= np.finfo(float).eps: + print_warn_log('All the data is Zero in npu dump data. Compare by relative error') + return get_max_rel_err_scalar(n_value, b_value) elif b_norm <= np.finfo(float).eps: - print_warn_log("All the data is Zero in bench dump data. Compare by relative error.") - else: + print_warn_log('All the data is Zero in bench dump data. Compare by relative error') + return get_max_rel_err_scalar(n_value, b_value) + else: cos = num / (a_norm * b_norm) + print(cos) if np.isnan(cos): - print_warn_log("Dump data has NaN when comparing with Cosine Similarity.") + print_warn_log('Cannot compare by Cosine Similarity, the dump data has NaN.') return cos, cos > 0.99 -- Gitee From fd47ccf77c4af0ec40b9745389471628f5b2ead0 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Fri, 4 Aug 2023 12:03:37 +0800 Subject: [PATCH 02/13] dataloader --- .../api_accuracy_checker/dump/dump_scope.py | 7 +++++- .../api_accuracy_checker/dump/utils.py | 24 +++++++++++++++++-- 2 files changed, 28 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 51dbd75d9c8..db37c4e4ce0 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -1 +1,6 @@ -# dump范围控制 ———— 李天 \ No newline at end of file +# dump范围控制 ———— 李天 +import torch +from api_accuracy_checker.dump.utils import iter_tracer +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter +_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 4a19785b61d..6f2eb4f22cf 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -27,7 +27,9 @@ def set_dump_switch(switch): DumpUtil.set_dump_switch(switch) class DumpUtil(object): - dump_switch = None + dump_switch = "None" + target_iter_range = 1 + call_num = 0 @staticmethod def set_dump_switch(switch): @@ -35,4 +37,22 @@ class DumpUtil(object): @staticmethod def get_dump_switch(): - return DumpUtil.dump_switch == "ON" \ No newline at end of file + return DumpUtil.dump_switch == "ON" + + @staticmethod + def incr_iter_num_maybe_exit(): + if DumpUtil.call_num == DumpUtil.target_iter_range : + DumpUtil.dump_switch = "ON" + elif DumpUtil.call_num > DumpUtil.target_iter_range: + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.target_iter_range)) + else: + DumpUtil.dump_switch = "OFF" + DumpUtil.call_num += 1 + +def iter_tracer(func): + def func_wrapper(*args, **kwargs): + DumpUtil.dump_switch = "OFF " + result = func(*args, **kwargs) + DumpUtil.incr_iter_num_maybe_exit() + return result + return func_wrapper \ No newline at end of file -- Gitee From 73aefeff97751963420167c1383c617cf97acff4 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Fri, 4 Aug 2023 12:07:01 +0800 Subject: [PATCH 03/13] dataloader --- .../api_accuracy_checker/common/ut.py | 104 ++++++++++++++++++ 1 file changed, 104 insertions(+) create mode 100644 debug/accuracy_tools/api_accuracy_checker/common/ut.py diff --git a/debug/accuracy_tools/api_accuracy_checker/common/ut.py b/debug/accuracy_tools/api_accuracy_checker/common/ut.py new file mode 100644 index 00000000000..95ca98af2c6 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/common/ut.py @@ -0,0 +1,104 @@ +import unittest + +import yaml +import os +from ..common.utils import check_file_or_directory_path + +class Config: + def __init__(self, yaml_file): + check_file_or_directory_path(yaml_file, False) + with open(yaml_file, 'r') as file: + config = yaml.safe_load(file) + self.dump_path = self.validate_dump_path(config['dump_path']) + self.jit_compile = self.validate_jit_compile(config['jit_compile']) + self.compile_option = self.validate_compile_option(config['compile_option']) + self.compare_algorithm = self.validate_compare_algorithm(config['compare_algorithm']) + self.real_data = self.validate_real_data(config['real_data']) + self.dump_step = self.validate_dump_step(config['dump_step']) + + def validate_dump_path(self, dump_path): + if not isinstance(dump_path, str): + raise ValueError("dump_path mast be string type") + return dump_path + + def validate_jit_compile(self, jit_compile): + if not isinstance(jit_compile, bool): + raise ValueError("jit_compile mast be bool type") + return jit_compile + + def validate_compile_option(self, compile_option): + if not isinstance(compile_option, str): + raise ValueError("compile_option mast be string type") + return compile_option + + def validate_compare_algorithm(self, compare_algorithm): + if not isinstance(compare_algorithm, str): + raise ValueError("compare_algorithm mast be string type") + return compare_algorithm + + def validate_real_data(self, real_data): + if not isinstance(real_data, bool): + raise ValueError("real_data mast be bool type") + return real_data + + def validate_dump_step(self, dump_step): + if not isinstance(dump_step, int): + raise ValueError("dump_step mast be int type") + return dump_step + + + def __str__(self): + return ( + f"dump_path={self.dump_path}\n" + f"jit_compile={self.jit_compile}\n" + f"compile_option={self.compile_option}\n" + f"compare_algorithm={self.compare_algorithm}\n" + f"real_data={self.real_data}\n" + f"dump_step={self.dump_step}\n" + ) + + def update_config(self, **kwargs): + for key, value in kwargs.items(): + if hasattr(self, key): + if key == 'dump_path': + self.validate_dump_path(value) + elif key == 'jit_compile': + self.validate_jit_compile(value) + elif key == 'compile_option': + self.validate_compile_option(value) + elif key == 'compare_algorithm': + self.validate_compare_algorithm(value) + elif key == 'real_data': + self.validate_real_data(value) + elif key == 'dump_step': + self.validate_dump_step(value) + setattr(self, key, value) + else: + raise ValueError(f"Invalid key '{key}'") + + + +cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) +yaml_path = os.path.join(cur_path, "config.yaml") +msCheckerConfig = Config(yaml_path) + +class ConfigTestCase(unittest.TestCase): + def setUp(self): + self.config = Config('./config.yaml') + # print(self.config) + + def test_config_loading(self): + self.assertEqual(self.config.dump_path, './api_info') + self.assertEqual(self.config.jit_compile, True) + + def test_config_update(self): + self.config.update_config(dump_path='./dump2') + self.assertEqual(self.config.dump_path, './dump2') + # self.config.update_config(dump='./dump2') + +if __name__ == '__main__': + # unittest.main() + config = Config('./config.yaml') + config.update_config(dump_pat='./dump2') + print(config) + \ No newline at end of file -- Gitee From 0003b5262201adde13741607becc981664dd815d Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 4 Aug 2023 04:07:52 +0000 Subject: [PATCH 04/13] Revert "dataloader" This reverts commit 73aefeff97751963420167c1383c617cf97acff4. --- .../api_accuracy_checker/common/ut.py | 104 ------------------ 1 file changed, 104 deletions(-) delete mode 100644 debug/accuracy_tools/api_accuracy_checker/common/ut.py diff --git a/debug/accuracy_tools/api_accuracy_checker/common/ut.py b/debug/accuracy_tools/api_accuracy_checker/common/ut.py deleted file mode 100644 index 95ca98af2c6..00000000000 --- a/debug/accuracy_tools/api_accuracy_checker/common/ut.py +++ /dev/null @@ -1,104 +0,0 @@ -import unittest - -import yaml -import os -from ..common.utils import check_file_or_directory_path - -class Config: - def __init__(self, yaml_file): - check_file_or_directory_path(yaml_file, False) - with open(yaml_file, 'r') as file: - config = yaml.safe_load(file) - self.dump_path = self.validate_dump_path(config['dump_path']) - self.jit_compile = self.validate_jit_compile(config['jit_compile']) - self.compile_option = self.validate_compile_option(config['compile_option']) - self.compare_algorithm = self.validate_compare_algorithm(config['compare_algorithm']) - self.real_data = self.validate_real_data(config['real_data']) - self.dump_step = self.validate_dump_step(config['dump_step']) - - def validate_dump_path(self, dump_path): - if not isinstance(dump_path, str): - raise ValueError("dump_path mast be string type") - return dump_path - - def validate_jit_compile(self, jit_compile): - if not isinstance(jit_compile, bool): - raise ValueError("jit_compile mast be bool type") - return jit_compile - - def validate_compile_option(self, compile_option): - if not isinstance(compile_option, str): - raise ValueError("compile_option mast be string type") - return compile_option - - def validate_compare_algorithm(self, compare_algorithm): - if not isinstance(compare_algorithm, str): - raise ValueError("compare_algorithm mast be string type") - return compare_algorithm - - def validate_real_data(self, real_data): - if not isinstance(real_data, bool): - raise ValueError("real_data mast be bool type") - return real_data - - def validate_dump_step(self, dump_step): - if not isinstance(dump_step, int): - raise ValueError("dump_step mast be int type") - return dump_step - - - def __str__(self): - return ( - f"dump_path={self.dump_path}\n" - f"jit_compile={self.jit_compile}\n" - f"compile_option={self.compile_option}\n" - f"compare_algorithm={self.compare_algorithm}\n" - f"real_data={self.real_data}\n" - f"dump_step={self.dump_step}\n" - ) - - def update_config(self, **kwargs): - for key, value in kwargs.items(): - if hasattr(self, key): - if key == 'dump_path': - self.validate_dump_path(value) - elif key == 'jit_compile': - self.validate_jit_compile(value) - elif key == 'compile_option': - self.validate_compile_option(value) - elif key == 'compare_algorithm': - self.validate_compare_algorithm(value) - elif key == 'real_data': - self.validate_real_data(value) - elif key == 'dump_step': - self.validate_dump_step(value) - setattr(self, key, value) - else: - raise ValueError(f"Invalid key '{key}'") - - - -cur_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) -yaml_path = os.path.join(cur_path, "config.yaml") -msCheckerConfig = Config(yaml_path) - -class ConfigTestCase(unittest.TestCase): - def setUp(self): - self.config = Config('./config.yaml') - # print(self.config) - - def test_config_loading(self): - self.assertEqual(self.config.dump_path, './api_info') - self.assertEqual(self.config.jit_compile, True) - - def test_config_update(self): - self.config.update_config(dump_path='./dump2') - self.assertEqual(self.config.dump_path, './dump2') - # self.config.update_config(dump='./dump2') - -if __name__ == '__main__': - # unittest.main() - config = Config('./config.yaml') - config.update_config(dump_pat='./dump2') - print(config) - \ No newline at end of file -- Gitee From c410a93203c1e24d06aa0927cf685606302ffd8e Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 4 Aug 2023 04:14:25 +0000 Subject: [PATCH 05/13] update debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py. Signed-off-by: sunyiming --- .../api_accuracy_checker/compare/algorithm.py | 21 +++++++++---------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 429dcac6d53..e66958fd9ea 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -47,29 +47,28 @@ def cosine_standard(compare_result): def cosine_sim(cpu_output, npu_output): n_value = npu_output.cpu().detach().numpy().reshape(-1) b_value = cpu_output.detach().numpy().reshape(-1) - cos = CompareConst.NA - np.seterr(divide='ignore', invalid='ignore') + cos = CompareConst.NA + np.seterr(divide="ignore", invalid="ignore") if len(n_value) == 1: - print_warn_log('All the data in npu dump data is scalar. Compare by relative error.') - return get_max_rel_err_scalar(n_value, b_value) + print_warn_log("All the data in npu dump data is scalar. Compare by relative error.") + return get_max_rel_err(n_value, b_value) n_value = n_value / np.max(np.abs(n_value)) b_value = b_value / np.max(np.abs(b_value)) num = n_value.dot(b_value) a_norm = np.linalg.norm(n_value) b_norm = np.linalg.norm(b_value) if a_norm <= np.finfo(float).eps and b_norm <= np.finfo(float).eps: - return cos, True - elif a_norm <= np.finfo(float).eps: - print_warn_log('All the data is Zero in npu dump data. Compare by relative error') - return get_max_rel_err_scalar(n_value, b_value) + return cos, True + elif a_norm <= np.finfo(float).eps: + print_warn_log("All the data is Zero in npu dump data. Compare by relative error.") + return get_max_rel_err(n_value, b_value) elif b_norm <= np.finfo(float).eps: - print_warn_log('All the data is Zero in bench dump data. Compare by relative error') - return get_max_rel_err_scalar(n_value, b_value) + print_warn_log("All the data is Zero in bench dump data. Compare by relative error.") else: cos = num / (a_norm * b_norm) print(cos) if np.isnan(cos): - print_warn_log('Cannot compare by Cosine Similarity, the dump data has NaN.') + print_warn_log("Dump data has NaN when comparing with Cosine Similarity.") return cos, cos > 0.99 -- Gitee From de4d91de090b9b9a0fb3a38a54400326b9e3c878 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 4 Aug 2023 04:15:06 +0000 Subject: [PATCH 06/13] update debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py. Signed-off-by: sunyiming --- debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index e66958fd9ea..5feb2156031 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -64,9 +64,8 @@ def cosine_sim(cpu_output, npu_output): return get_max_rel_err(n_value, b_value) elif b_norm <= np.finfo(float).eps: print_warn_log("All the data is Zero in bench dump data. Compare by relative error.") - else: + else: cos = num / (a_norm * b_norm) - print(cos) if np.isnan(cos): print_warn_log("Dump data has NaN when comparing with Cosine Similarity.") return cos, cos > 0.99 -- Gitee From 64e84e51ffb986d659edfb0d9b24702e7b051f8c Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 4 Aug 2023 04:15:58 +0000 Subject: [PATCH 07/13] update debug/accuracy_tools/api_accuracy_checker/dump/utils.py. Signed-off-by: sunyiming --- debug/accuracy_tools/api_accuracy_checker/dump/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 6f2eb4f22cf..d7845d73434 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -27,7 +27,7 @@ def set_dump_switch(switch): DumpUtil.set_dump_switch(switch) class DumpUtil(object): - dump_switch = "None" + dump_switch = None target_iter_range = 1 call_num = 0 -- Gitee From eed4d2b18ebc939b5c8ae8de5d53f132abf80370 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 4 Aug 2023 07:23:34 +0000 Subject: [PATCH 08/13] update debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py. Signed-off-by: sunyiming --- .../accuracy_tools/api_accuracy_checker/compare/algorithm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 5feb2156031..9e45e9a4144 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -52,8 +52,8 @@ def cosine_sim(cpu_output, npu_output): if len(n_value) == 1: print_warn_log("All the data in npu dump data is scalar. Compare by relative error.") return get_max_rel_err(n_value, b_value) - n_value = n_value / np.max(np.abs(n_value)) - b_value = b_value / np.max(np.abs(b_value)) + n_value = n_value / (np.max(np.abs(n_value)) + np.finfo(float).eps) + b_value = b_value / (np.max(np.abs(b_value)) + np.finfo(float).eps) num = n_value.dot(b_value) a_norm = np.linalg.norm(n_value) b_norm = np.linalg.norm(b_value) -- Gitee From b454e56416e4d03c1f7409e74a42021301f7e79d Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 4 Aug 2023 09:18:57 +0000 Subject: [PATCH 09/13] update debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py. Signed-off-by: sunyiming --- .../accuracy_tools/api_accuracy_checker/compare/algorithm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 9e45e9a4144..1248e965f79 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -52,8 +52,8 @@ def cosine_sim(cpu_output, npu_output): if len(n_value) == 1: print_warn_log("All the data in npu dump data is scalar. Compare by relative error.") return get_max_rel_err(n_value, b_value) - n_value = n_value / (np.max(np.abs(n_value)) + np.finfo(float).eps) - b_value = b_value / (np.max(np.abs(b_value)) + np.finfo(float).eps) + n_value = n_value / (np.max(np.abs(n_value)) + np.finfo(n_value.dtype).eps) + b_value = b_value / (np.max(np.abs(b_value)) + np.finfo(b_value.dtype).eps) num = n_value.dot(b_value) a_norm = np.linalg.norm(n_value) b_norm = np.linalg.norm(b_value) -- Gitee From eccf09da274c241a37d733bec538bd08e1020802 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Fri, 4 Aug 2023 09:21:04 +0000 Subject: [PATCH 10/13] update debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py. Signed-off-by: sunyiming --- .../accuracy_tools/api_accuracy_checker/compare/algorithm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 9d902eaa2e1..2d4cfea38bd 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -52,10 +52,10 @@ def cosine_sim(cpu_output, npu_output): if len(n_value) == 1: print_warn_log("All the data in npu dump data is scalar. Compare by relative error.") return get_max_rel_err(n_value, b_value) - n_value = n_value / (np.max(np.abs(n_value)) + np.finfo(n_value.dtype).eps) - b_value = b_value / (np.max(np.abs(b_value)) + np.finfo(b_value.dtype).eps) if n_value.dtype == np.uint8: return compare_uint8_data(n_value, b_value) + n_value = n_value / (np.max(np.abs(n_value)) + np.finfo(n_value.dtype).eps) + b_value = b_value / (np.max(np.abs(b_value)) + np.finfo(b_value.dtype).eps) num = n_value.dot(b_value) a_norm = np.linalg.norm(n_value) b_norm = np.linalg.norm(b_value) -- Gitee From 24e8e489a950720ff2987cd3033ed919ba7526c8 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Sat, 5 Aug 2023 10:59:52 +0800 Subject: [PATCH 11/13] update --- .../api_accuracy_checker/run_ut/run_ut.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 0c18da7d0e8..3c301066714 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -30,28 +30,27 @@ def exec_api(api_type, api_name, args, kwargs): return out +def arg_to_npu_recursive(arg_in): + if isinstance(arg_in, list) or isinstance(arg_in, tuple): + return [arg_to_npu_recursive(item) for item in arg_in] + elif isinstance(arg_in, torch.Tensor): + return arg_to_npu(arg_in) + else: + return arg_in + def generate_npu_params(cpu_args, cpu_kwargs, need_backward): npu_args = [] npu_kwargs = {} if need_backward: - for arg_in in cpu_args: - arg_in = arg_to_npu(arg_in) - npu_args.append(arg_in) - for key, value in cpu_kwargs.items(): - value = arg_to_npu(value) - npu_kwargs[key] = value + npu_args = [arg_to_npu_recursive(arg_in) for arg_in in cpu_args] + npu_kwargs = {key: arg_to_npu_recursive(value) for key, value in cpu_kwargs.items()} else: - for arg_in in cpu_args: - if isinstance(arg_in, torch.Tensor): - arg_in = arg_in.clone().detach().to("npu") - npu_args.append(arg_in) - for key, value in cpu_kwargs.items(): - if isinstance(value, torch.Tensor): - value = value.clone().detach().to("npu") - npu_kwargs[key] = value + npu_args = [arg_in.clone().detach().to("npu") if isinstance(arg_in, torch.Tensor) else arg_in for arg_in in cpu_args] + npu_kwargs = {key: value.clone().detach().to("npu") if isinstance(value, torch.Tensor) else value for key, value in cpu_kwargs.items()} return npu_args, npu_kwargs + def arg_to_npu(arg_in): if isinstance(arg_in, torch.Tensor) and arg_in.dtype in [torch.float, torch.float16, torch.float64] and arg_in.requires_grad: -- Gitee From 9cc2fc02ea355daaf07836c9a04d019eaba81785 Mon Sep 17 00:00:00 2001 From: sunyiming Date: Sat, 5 Aug 2023 03:30:33 +0000 Subject: [PATCH 12/13] Revert "update" This reverts commit 24e8e489a950720ff2987cd3033ed919ba7526c8. --- .../api_accuracy_checker/run_ut/run_ut.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 3c301066714..0c18da7d0e8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -30,27 +30,28 @@ def exec_api(api_type, api_name, args, kwargs): return out -def arg_to_npu_recursive(arg_in): - if isinstance(arg_in, list) or isinstance(arg_in, tuple): - return [arg_to_npu_recursive(item) for item in arg_in] - elif isinstance(arg_in, torch.Tensor): - return arg_to_npu(arg_in) - else: - return arg_in - def generate_npu_params(cpu_args, cpu_kwargs, need_backward): npu_args = [] npu_kwargs = {} if need_backward: - npu_args = [arg_to_npu_recursive(arg_in) for arg_in in cpu_args] - npu_kwargs = {key: arg_to_npu_recursive(value) for key, value in cpu_kwargs.items()} + for arg_in in cpu_args: + arg_in = arg_to_npu(arg_in) + npu_args.append(arg_in) + for key, value in cpu_kwargs.items(): + value = arg_to_npu(value) + npu_kwargs[key] = value else: - npu_args = [arg_in.clone().detach().to("npu") if isinstance(arg_in, torch.Tensor) else arg_in for arg_in in cpu_args] - npu_kwargs = {key: value.clone().detach().to("npu") if isinstance(value, torch.Tensor) else value for key, value in cpu_kwargs.items()} + for arg_in in cpu_args: + if isinstance(arg_in, torch.Tensor): + arg_in = arg_in.clone().detach().to("npu") + npu_args.append(arg_in) + for key, value in cpu_kwargs.items(): + if isinstance(value, torch.Tensor): + value = value.clone().detach().to("npu") + npu_kwargs[key] = value return npu_args, npu_kwargs - def arg_to_npu(arg_in): if isinstance(arg_in, torch.Tensor) and arg_in.dtype in [torch.float, torch.float16, torch.float64] and arg_in.requires_grad: -- Gitee From 2256f4de42493e4c3f9424944900c1b6317416b6 Mon Sep 17 00:00:00 2001 From: s30048155 Date: Tue, 8 Aug 2023 11:26:25 +0800 Subject: [PATCH 13/13] update cosine_sim --- .../api_accuracy_checker/compare/algorithm.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 17243a74158..73178f463fa 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -45,8 +45,8 @@ def cosine_standard(compare_result): def cosine_sim(cpu_output, npu_output): - n_value = npu_output.cpu().detach().numpy().reshape(-1) - b_value = cpu_output.detach().numpy().reshape(-1) + n_value = npu_output.cpu().detach().numpy().flatten() + b_value = cpu_output.detach().numpy().flatten() cos = CompareConst.NA np.seterr(divide="ignore", invalid="ignore") if len(n_value) == 1: @@ -54,15 +54,12 @@ def cosine_sim(cpu_output, npu_output): return get_max_rel_err(n_value, b_value) if n_value.dtype == np.uint8: return compare_uint8_data(n_value, b_value) - n_max = np.max(np.abs(n_value)) - b_max = np.max(np.abs(b_value)) + n_max, b_max = np.max(np.abs(n_value)), np.max(np.abs(b_value)) if n_max <= np.finfo(float).eps and b_max <= np.finfo(float).eps: return cos, True - elif n_max <= np.finfo(float).eps: - print_warn_log("All the data is Zero in npu dump data. Compare by relative error.") + elif n_max <= np.finfo(float).eps or b_max <= np.finfo(float).eps: + print_warn_log("All the data is Zero in either npu dump data or bench dump data. Compare by relative error.") return get_max_rel_err(n_value, b_value) - elif b_max <= np.finfo(float).eps: - print_warn_log("All the data is Zero in bench dump data. Compare by relative error.") else: n_value = n_value.astype(float) / n_max b_value = b_value.astype(float) / b_max -- Gitee