diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py index 17243a7415881198c1841afa119cc242292bfdc1..73178f463fa9dbaaa70da46b31bd4607a900ef98 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/algorithm.py @@ -45,8 +45,8 @@ def cosine_standard(compare_result): def cosine_sim(cpu_output, npu_output): - n_value = npu_output.cpu().detach().numpy().reshape(-1) - b_value = cpu_output.detach().numpy().reshape(-1) + n_value = npu_output.cpu().detach().numpy().flatten() + b_value = cpu_output.detach().numpy().flatten() cos = CompareConst.NA np.seterr(divide="ignore", invalid="ignore") if len(n_value) == 1: @@ -54,15 +54,12 @@ def cosine_sim(cpu_output, npu_output): return get_max_rel_err(n_value, b_value) if n_value.dtype == np.uint8: return compare_uint8_data(n_value, b_value) - n_max = np.max(np.abs(n_value)) - b_max = np.max(np.abs(b_value)) + n_max, b_max = np.max(np.abs(n_value)), np.max(np.abs(b_value)) if n_max <= np.finfo(float).eps and b_max <= np.finfo(float).eps: return cos, True - elif n_max <= np.finfo(float).eps: - print_warn_log("All the data is Zero in npu dump data. Compare by relative error.") + elif n_max <= np.finfo(float).eps or b_max <= np.finfo(float).eps: + print_warn_log("All the data is Zero in either npu dump data or bench dump data. Compare by relative error.") return get_max_rel_err(n_value, b_value) - elif b_max <= np.finfo(float).eps: - print_warn_log("All the data is Zero in bench dump data. Compare by relative error.") else: n_value = n_value.astype(float) / n_max b_value = b_value.astype(float) / b_max diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py index 51dbd75d9c87ff42d6730c51af017cf6b6e03fe8..db37c4e4ce034c2b2752a9bdbfdada7d45357e35 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/dump_scope.py @@ -1 +1,6 @@ -# dump范围控制 ———— 李天 \ No newline at end of file +# dump范围控制 ———— 李天 +import torch +from api_accuracy_checker.dump.utils import iter_tracer +from torch.utils.data import Dataset, DataLoader +from torch.utils.data.dataloader import _BaseDataLoaderIter +_BaseDataLoaderIter.__next__ = iter_tracer(torch.utils.data.dataloader._BaseDataLoaderIter.__next__) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 93af6f0981aa0bd3d3c42606a50f04b79bc1c37b..db4417f088af760ab85d3349b32a7e998006f356 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -13,3 +13,37 @@ def write_npy(file_path, tensor): np.save(file_path, tensor) full_path = os.path.abspath(file_path) return full_path + +def set_dump_switch(switch): + DumpUtil.set_dump_switch(switch) + +class DumpUtil(object): + dump_switch = None + target_iter_range = 1 + call_num = 0 + + @staticmethod + def set_dump_switch(switch): + DumpUtil.dump_switch = switch + + @staticmethod + def get_dump_switch(): + return DumpUtil.dump_switch == "ON" + + @staticmethod + def incr_iter_num_maybe_exit(): + if DumpUtil.call_num == DumpUtil.target_iter_range : + DumpUtil.dump_switch = "ON" + elif DumpUtil.call_num > DumpUtil.target_iter_range: + raise Exception("Model pretest: exit after iteration {}".format(DumpUtil.target_iter_range)) + else: + DumpUtil.dump_switch = "OFF" + DumpUtil.call_num += 1 + +def iter_tracer(func): + def func_wrapper(*args, **kwargs): + DumpUtil.dump_switch = "OFF " + result = func(*args, **kwargs) + DumpUtil.incr_iter_num_maybe_exit() + return result + return func_wrapper