diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py index 0db84b45bf785cc0146160087e211c6673a02d6b..627bd769766863683d2191509a54f5e28660a623 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -71,6 +71,8 @@ class BaseAPIInfo: def analyze_builtin(self, arg): single_arg = {} + if self.is_save_data: + self.args_num += 1 if isinstance(arg, slice): single_arg.update({'type' : "slice"}) single_arg.update({'value' : [arg.start, arg.stop, arg.step]}) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 707d6cbed9dbc19885470e916b4c49615356c657..b70720f58969d9f8c7e3a2bc86979ddeb9866e09 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -123,14 +123,11 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) api_full_name = api_full_name.replace("*", ".") for element in data_info.in_fwd_data_list: UtAPIInfo(api_full_name + '.forward.input', element) - if data_info.bench_out is not None: - UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) - UtAPIInfo(api_full_name + '.forward.output.npu', data_info.npu_out) - if data_info.grad_in is not None: - UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) - if data_info.bench_grad_out is not None: - UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out) - UtAPIInfo(api_full_name + '.backward.output.npu', data_info.npu_grad_out) + UtAPIInfo(api_full_name + '.forward.output.bench', data_info.bench_out) + UtAPIInfo(api_full_name + '.forward.output.npu', data_info.npu_out) + UtAPIInfo(api_full_name + '.backward.input', data_info.grad_in) + UtAPIInfo(api_full_name + '.backward.output.bench', data_info.bench_grad_out) + UtAPIInfo(api_full_name + '.backward.output.npu', data_info.npu_grad_out)