From 28875fa133fea57ef4fe37220d472f654e118b9e Mon Sep 17 00:00:00 2001 From: l30044004 Date: Sat, 19 Aug 2023 15:59:51 +0800 Subject: [PATCH 1/3] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/common/config.py | 6 +++ .../api_accuracy_checker/compare/compare.py | 1 + .../api_accuracy_checker/config.yaml | 3 +- .../api_accuracy_checker/dump/api_info.py | 25 +++++---- .../api_accuracy_checker/dump/info_dump.py | 33 ++++++------ .../api_accuracy_checker/run_ut/run_ut.py | 51 ++++++++++++++++--- 6 files changed, 83 insertions(+), 36 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index c931c686318..9fe21ccb3f3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -13,6 +13,7 @@ class Config: self.compare_algorithm = self.validate_compare_algorithm(config['compare_algorithm']) self.real_data = self.validate_real_data(config['real_data']) self.dump_step = self.validate_dump_step(config['dump_step']) + self.error_data_path = self.validate_error_data_path(config['error_data_path']) def validate_dump_path(self, dump_path): if not isinstance(dump_path, str): @@ -43,6 +44,11 @@ class Config: if not isinstance(dump_step, int): raise ValueError("dump_step mast be int type") return dump_step + + def validate_error_data_path(self, error_data_path): + if not isinstance(error_data_path, str): + raise ValueError("error_data_path mast be string type") + return error_data_path def __str__(self): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 7a1c069e2ef..e4055d62b47 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -111,6 +111,7 @@ class Comparator: self.test_result_cnt['forward_fail_num'] += 1 else: self.test_result_cnt['backward_fail_num'] += 1 + return is_fwd_success, is_bwd_success def _compare_core_wrapper(self, bench_out, npu_out): detailed_result_total = [] diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 7e2cd46fe24..2b22a9d9f93 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -3,4 +3,5 @@ jit_compile: True compile_option: -O3 compare_algorithm: cosine_similarity real_data: False -dump_step: 1000 \ No newline at end of file +dump_step: 1000 +error_data_path: './' \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 5d7fb97e276..81aa2976275 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -13,15 +13,15 @@ class APIInfo: self.rank = os.getpid() self.api_name = api_name self.save_real_data = msCheckerConfig.real_data - self.torch_object_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} + self.torch_object_key = {'device': self.analyze_device_in_kwargs, 'dtype': self.analyze_dtype_in_kwargs} self.is_forward = is_forward self.args_num = 0 - def analyze_element(self, element): + def analyze_element(self, element, is_save_data, save_path, forward_path='forward_real_data', backward_path='backward_real_data'): if isinstance(element, (list, tuple)): out = [] for item in element: - out.append(self.analyze_element(item)) + out.append(self.analyze_element(item, is_save_data, save_path, forward_path, backward_path)) elif isinstance(element, dict): out = {} for key, value in element.items(): @@ -29,10 +29,10 @@ class APIInfo: fun = self.torch_object_key[key] out[key] = fun(value) else: - out[key] = self.analyze_element(value) + out[key] = self.analyze_element(value, is_save_data, save_path, forward_path, backward_path) elif isinstance(element, torch.Tensor): - out = self.analyze_tensor(element, self.save_real_data) + out = self.analyze_tensor(element, is_save_data, save_path, forward_path, backward_path) elif self.is_builtin_class(element): out = self.analyze_builtin(element) @@ -44,9 +44,9 @@ class APIInfo: return out - def analyze_tensor(self, arg, save_real_data): + def analyze_tensor(self, arg, is_save_data, save_path, forward_path, backward_path): single_arg = {} - if not save_real_data: + if not is_save_data: single_arg.update({'type' : 'torch.Tensor'}) single_arg.update({'dtype' : str(arg.dtype)}) @@ -56,14 +56,13 @@ class APIInfo: single_arg.update({'requires_grad': arg.requires_grad}) else: - dump_path = msCheckerConfig.dump_path api_args = self.api_name + '*' + str(self.args_num) if self.is_forward: - forward_real_data_path = os.path.join(dump_path, 'forward_real_data') + forward_real_data_path = os.path.join(save_path, forward_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.npy') else: - backward_real_data_path = os.path.join(dump_path, 'backward_real_data') + backward_real_data_path = os.path.join(save_path, backward_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.npy') self.args_num += 1 npy_path = write_npy(file_path, arg.contiguous().cpu().detach().numpy()) @@ -139,8 +138,8 @@ class ForwardAPIInfo(APIInfo): self.analyze_api_call_stack() def analyze_api_input(self, args, kwargs): - args_info_list = self.analyze_element(args) - kwargs_info_dict = self.analyze_element(kwargs) + args_info_list = self.analyze_element(args, self.save_real_data, msCheckerConfig.dump_path) + kwargs_info_dict = self.analyze_element(kwargs, self.save_real_data, msCheckerConfig.dump_path) self.api_info_struct = {self.api_name: {"args":args_info_list, "kwargs":kwargs_info_dict}} def analyze_api_call_stack(self): @@ -160,5 +159,5 @@ class BackwardAPIInfo(APIInfo): self.analyze_api_input(grads) def analyze_api_input(self, grads): - grads_info_list = self.analyze_element(grads) + grads_info_list = self.analyze_element(grads, self.save_real_data, msCheckerConfig.dump_path) self.grad_info_struct = {self.api_name:grads_info_list} diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 7790518399e..0c8f0296449 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -48,25 +48,28 @@ def write_json(file_path, data, indent=None): fcntl.flock(f, fcntl.LOCK_UN) lock.release() +def initialize_save_path(save_path, dir_name): + data_path = os.path.join(save_path, dir_name) + if os.path.exists(data_path): + raise ValueError(f"file {data_path} already exists, please remove it first") + else: + os.mkdir(data_path, mode = 0o750) + check_file_or_directory_path(data_path, True) + def initialize_output_json(): dump_path = os.path.realpath(msCheckerConfig.dump_path) - check_file_or_directory_path(dump_path,True) + check_file_or_directory_path(dump_path, True) files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] if msCheckerConfig.real_data: - forward_real_data_path = os.path.join(dump_path, 'forward_real_data') - if os.path.exists(forward_real_data_path): - raise ValueError(f"file {forward_real_data_path} already exists, please remove it first") - else: - os.mkdir(forward_real_data_path, mode = 0o750) - check_file_or_directory_path(forward_real_data_path, True) - - backward_real_data_path = os.path.join(dump_path, 'backward_real_data') - if os.path.exists(backward_real_data_path): - raise ValueError(f"file {backward_real_data_path} already exists, please remove it first") - else: - os.mkdir(backward_real_data_path, mode = 0o750) - check_file_or_directory_path(backward_real_data_path, True) + initialize_save_path(dump_path, 'forward_real_data') + initialize_save_path(dump_path, 'backward_real_data') for file in files: file_path = os.path.join(dump_path, file) if os.path.exists(file_path): - raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") \ No newline at end of file + raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") + +def initialize_save_error_input_data(): + error_data_path = os.path.realpath(msCheckerConfig.error_data_path) + check_file_or_directory_path(error_data_path, True) + initialize_save_path(error_data_path, 'forward_error_data') + initialize_save_path(error_data_path, 'backward_error_data') \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e5737b51106..9e81d2fd965 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -12,6 +12,9 @@ from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate +from api_accuracy_checker.dump.api_info import APIInfo +from api_accuracy_checker.dump.info_dump import initialize_save_error_input_data +from api_accuracy_checker.common.config import msCheckerConfig NO_GRAD_APIS = ["hardtanh"] @@ -71,9 +74,13 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare = Comparator(out_path) for api_full_name, api_info_dict in tqdm(forward_content.items()): try: - grad_out, npu_grad_out, npu_out, out = run_torch_api(api_full_name, api_setting_dict, backward_content, - api_info_dict) - compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) + grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list = run_torch_api(api_full_name, + api_setting_dict, + backward_content, + api_info_dict) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) + if save_error_data: + do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd_success, is_bwd_success) except Exception as err: [_, api_name, _] = api_full_name.split("*") if "not implemented for 'Half'" in str(err): @@ -89,14 +96,39 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare.write_compare_csv() +def do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd_success, is_bwd_success): + if not is_fwd_success and len(in_fwd_data_list) > 0: + bench_api_info = APIInfo(api_full_name + '*bench', True) + npu_api_info = APIInfo(api_full_name + '*npu', True) + for i in range(len(in_fwd_data_list)): + if i < 2: + bench_api_info.analyze_element(in_fwd_data_list[i], True, msCheckerConfig.error_data_path, + forward_path='forward_error_data') + else: + npu_api_info.analyze_element(in_fwd_data_list[i], True, msCheckerConfig.error_data_path, + forward_path='forward_error_data') + if not is_bwd_success and len(in_bwd_data_list) > 0: + bench_api_info = APIInfo(api_full_name + '*bench', False) + npu_api_info = APIInfo(api_full_name + '*npu', False) + bench_api_info.analyze_element(in_fwd_data_list[0], True, msCheckerConfig.error_data_path, + forward_path='backward_error_data') + npu_api_info.analyze_element(in_fwd_data_list[1], True, msCheckerConfig.error_data_path, + forward_path='backward_error_data') + + def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): + in_fwd_data_list, in_bwd_data_list = [], [] [api_type, api_name, _] = api_full_name.split("*") args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) + in_fwd_data_list.append(args) + in_fwd_data_list.append(kwargs) need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True need_backward = need_backward and need_grad if inplace or not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) + in_fwd_data_list.append(npu_args) + in_fwd_data_list.append(npu_kwargs) grad_out, npu_grad_out = None, None if kwargs.get("device"): del kwargs["device"] @@ -108,10 +140,13 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di grad_index = grad_input_index.get('grad_index') if need_backward: - grad_out, npu_grad_out = run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out) + grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, args, backward_content, grad_index, npu_args, + npu_out, out) + in_bwd_data_list.append(grad) + in_bwd_data_list.append(npu_grad) if grad_index is not None: - return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index] - return grad_out, npu_grad_out, npu_out, out + return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], in_fwd_data_list, in_bwd_data_list + return grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list def get_api_info(api_info_dict, api_name): @@ -150,7 +185,7 @@ def run_backward(api_full_name, args, backward_content, grad_index, npu_args, np if isinstance(arg, torch.Tensor): npu_args_grad.append(arg.grad) npu_grad_out = npu_args_grad - return grad_out, npu_grad_out + return grad_out, npu_grad_out, grad, npu_grad def _run_ut_parser(parser): @@ -190,6 +225,8 @@ def _run_ut(): raise ValueError("The forward_input_file and backward_input_file should be a json file!") out_path = os.path.realpath(args.out_path) if args.out_path else "./" save_error_data = args.save_error_data + if save_error_data: + initialize_save_error_input_data() run_ut(forward_file, backward_file, out_path, save_error_data) -- Gitee From d27b57036cf4f0020869582fa49355def027a9f6 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Sat, 19 Aug 2023 16:35:12 +0800 Subject: [PATCH 2/3] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 9e81d2fd965..d5c902921f4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -111,9 +111,9 @@ def do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd bench_api_info = APIInfo(api_full_name + '*bench', False) npu_api_info = APIInfo(api_full_name + '*npu', False) bench_api_info.analyze_element(in_fwd_data_list[0], True, msCheckerConfig.error_data_path, - forward_path='backward_error_data') + backward_path='backward_error_data') npu_api_info.analyze_element(in_fwd_data_list[1], True, msCheckerConfig.error_data_path, - forward_path='backward_error_data') + backward_path='backward_error_data') def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): -- Gitee From a8efe7ab25e6aa85e85b00277ba778ff50ad052b Mon Sep 17 00:00:00 2001 From: l30044004 Date: Sat, 19 Aug 2023 16:42:05 +0800 Subject: [PATCH 3/3] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index d5c902921f4..cb6de537bc8 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -110,9 +110,9 @@ def do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd if not is_bwd_success and len(in_bwd_data_list) > 0: bench_api_info = APIInfo(api_full_name + '*bench', False) npu_api_info = APIInfo(api_full_name + '*npu', False) - bench_api_info.analyze_element(in_fwd_data_list[0], True, msCheckerConfig.error_data_path, + bench_api_info.analyze_element(in_bwd_data_list[0], True, msCheckerConfig.error_data_path, backward_path='backward_error_data') - npu_api_info.analyze_element(in_fwd_data_list[1], True, msCheckerConfig.error_data_path, + npu_api_info.analyze_element(in_bwd_data_list[1], True, msCheckerConfig.error_data_path, backward_path='backward_error_data') -- Gitee