From 28875fa133fea57ef4fe37220d472f654e118b9e Mon Sep 17 00:00:00 2001 From: l30044004 Date: Sat, 19 Aug 2023 15:59:51 +0800 Subject: [PATCH 1/8] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/common/config.py | 6 +++ .../api_accuracy_checker/compare/compare.py | 1 + .../api_accuracy_checker/config.yaml | 3 +- .../api_accuracy_checker/dump/api_info.py | 25 +++++---- .../api_accuracy_checker/dump/info_dump.py | 33 ++++++------ .../api_accuracy_checker/run_ut/run_ut.py | 51 ++++++++++++++++--- 6 files changed, 83 insertions(+), 36 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index c931c68631..9fe21ccb3f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -13,6 +13,7 @@ class Config: self.compare_algorithm = self.validate_compare_algorithm(config['compare_algorithm']) self.real_data = self.validate_real_data(config['real_data']) self.dump_step = self.validate_dump_step(config['dump_step']) + self.error_data_path = self.validate_error_data_path(config['error_data_path']) def validate_dump_path(self, dump_path): if not isinstance(dump_path, str): @@ -43,6 +44,11 @@ class Config: if not isinstance(dump_step, int): raise ValueError("dump_step mast be int type") return dump_step + + def validate_error_data_path(self, error_data_path): + if not isinstance(error_data_path, str): + raise ValueError("error_data_path mast be string type") + return error_data_path def __str__(self): diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 7a1c069e2e..e4055d62b4 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -111,6 +111,7 @@ class Comparator: self.test_result_cnt['forward_fail_num'] += 1 else: self.test_result_cnt['backward_fail_num'] += 1 + return is_fwd_success, is_bwd_success def _compare_core_wrapper(self, bench_out, npu_out): detailed_result_total = [] diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 7e2cd46fe2..2b22a9d9f9 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -3,4 +3,5 @@ jit_compile: True compile_option: -O3 compare_algorithm: cosine_similarity real_data: False -dump_step: 1000 \ No newline at end of file +dump_step: 1000 +error_data_path: './' \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 5d7fb97e27..81aa297627 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -13,15 +13,15 @@ class APIInfo: self.rank = os.getpid() self.api_name = api_name self.save_real_data = msCheckerConfig.real_data - self.torch_object_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} + self.torch_object_key = {'device': self.analyze_device_in_kwargs, 'dtype': self.analyze_dtype_in_kwargs} self.is_forward = is_forward self.args_num = 0 - def analyze_element(self, element): + def analyze_element(self, element, is_save_data, save_path, forward_path='forward_real_data', backward_path='backward_real_data'): if isinstance(element, (list, tuple)): out = [] for item in element: - out.append(self.analyze_element(item)) + out.append(self.analyze_element(item, is_save_data, save_path, forward_path, backward_path)) elif isinstance(element, dict): out = {} for key, value in element.items(): @@ -29,10 +29,10 @@ class APIInfo: fun = self.torch_object_key[key] out[key] = fun(value) else: - out[key] = self.analyze_element(value) + out[key] = self.analyze_element(value, is_save_data, save_path, forward_path, backward_path) elif isinstance(element, torch.Tensor): - out = self.analyze_tensor(element, self.save_real_data) + out = self.analyze_tensor(element, is_save_data, save_path, forward_path, backward_path) elif self.is_builtin_class(element): out = self.analyze_builtin(element) @@ -44,9 +44,9 @@ class APIInfo: return out - def analyze_tensor(self, arg, save_real_data): + def analyze_tensor(self, arg, is_save_data, save_path, forward_path, backward_path): single_arg = {} - if not save_real_data: + if not is_save_data: single_arg.update({'type' : 'torch.Tensor'}) single_arg.update({'dtype' : str(arg.dtype)}) @@ -56,14 +56,13 @@ class APIInfo: single_arg.update({'requires_grad': arg.requires_grad}) else: - dump_path = msCheckerConfig.dump_path api_args = self.api_name + '*' + str(self.args_num) if self.is_forward: - forward_real_data_path = os.path.join(dump_path, 'forward_real_data') + forward_real_data_path = os.path.join(save_path, forward_path) file_path = os.path.join(forward_real_data_path, f'{api_args}.npy') else: - backward_real_data_path = os.path.join(dump_path, 'backward_real_data') + backward_real_data_path = os.path.join(save_path, backward_path) file_path = os.path.join(backward_real_data_path, f'{api_args}.npy') self.args_num += 1 npy_path = write_npy(file_path, arg.contiguous().cpu().detach().numpy()) @@ -139,8 +138,8 @@ class ForwardAPIInfo(APIInfo): self.analyze_api_call_stack() def analyze_api_input(self, args, kwargs): - args_info_list = self.analyze_element(args) - kwargs_info_dict = self.analyze_element(kwargs) + args_info_list = self.analyze_element(args, self.save_real_data, msCheckerConfig.dump_path) + kwargs_info_dict = self.analyze_element(kwargs, self.save_real_data, msCheckerConfig.dump_path) self.api_info_struct = {self.api_name: {"args":args_info_list, "kwargs":kwargs_info_dict}} def analyze_api_call_stack(self): @@ -160,5 +159,5 @@ class BackwardAPIInfo(APIInfo): self.analyze_api_input(grads) def analyze_api_input(self, grads): - grads_info_list = self.analyze_element(grads) + grads_info_list = self.analyze_element(grads, self.save_real_data, msCheckerConfig.dump_path) self.grad_info_struct = {self.api_name:grads_info_list} diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 7790518399..0c8f029644 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -48,25 +48,28 @@ def write_json(file_path, data, indent=None): fcntl.flock(f, fcntl.LOCK_UN) lock.release() +def initialize_save_path(save_path, dir_name): + data_path = os.path.join(save_path, dir_name) + if os.path.exists(data_path): + raise ValueError(f"file {data_path} already exists, please remove it first") + else: + os.mkdir(data_path, mode = 0o750) + check_file_or_directory_path(data_path, True) + def initialize_output_json(): dump_path = os.path.realpath(msCheckerConfig.dump_path) - check_file_or_directory_path(dump_path,True) + check_file_or_directory_path(dump_path, True) files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] if msCheckerConfig.real_data: - forward_real_data_path = os.path.join(dump_path, 'forward_real_data') - if os.path.exists(forward_real_data_path): - raise ValueError(f"file {forward_real_data_path} already exists, please remove it first") - else: - os.mkdir(forward_real_data_path, mode = 0o750) - check_file_or_directory_path(forward_real_data_path, True) - - backward_real_data_path = os.path.join(dump_path, 'backward_real_data') - if os.path.exists(backward_real_data_path): - raise ValueError(f"file {backward_real_data_path} already exists, please remove it first") - else: - os.mkdir(backward_real_data_path, mode = 0o750) - check_file_or_directory_path(backward_real_data_path, True) + initialize_save_path(dump_path, 'forward_real_data') + initialize_save_path(dump_path, 'backward_real_data') for file in files: file_path = os.path.join(dump_path, file) if os.path.exists(file_path): - raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") \ No newline at end of file + raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") + +def initialize_save_error_input_data(): + error_data_path = os.path.realpath(msCheckerConfig.error_data_path) + check_file_or_directory_path(error_data_path, True) + initialize_save_path(error_data_path, 'forward_error_data') + initialize_save_path(error_data_path, 'backward_error_data') \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e5737b5110..9e81d2fd96 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -12,6 +12,9 @@ from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate +from api_accuracy_checker.dump.api_info import APIInfo +from api_accuracy_checker.dump.info_dump import initialize_save_error_input_data +from api_accuracy_checker.common.config import msCheckerConfig NO_GRAD_APIS = ["hardtanh"] @@ -71,9 +74,13 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare = Comparator(out_path) for api_full_name, api_info_dict in tqdm(forward_content.items()): try: - grad_out, npu_grad_out, npu_out, out = run_torch_api(api_full_name, api_setting_dict, backward_content, - api_info_dict) - compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) + grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list = run_torch_api(api_full_name, + api_setting_dict, + backward_content, + api_info_dict) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) + if save_error_data: + do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd_success, is_bwd_success) except Exception as err: [_, api_name, _] = api_full_name.split("*") if "not implemented for 'Half'" in str(err): @@ -89,14 +96,39 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare.write_compare_csv() +def do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd_success, is_bwd_success): + if not is_fwd_success and len(in_fwd_data_list) > 0: + bench_api_info = APIInfo(api_full_name + '*bench', True) + npu_api_info = APIInfo(api_full_name + '*npu', True) + for i in range(len(in_fwd_data_list)): + if i < 2: + bench_api_info.analyze_element(in_fwd_data_list[i], True, msCheckerConfig.error_data_path, + forward_path='forward_error_data') + else: + npu_api_info.analyze_element(in_fwd_data_list[i], True, msCheckerConfig.error_data_path, + forward_path='forward_error_data') + if not is_bwd_success and len(in_bwd_data_list) > 0: + bench_api_info = APIInfo(api_full_name + '*bench', False) + npu_api_info = APIInfo(api_full_name + '*npu', False) + bench_api_info.analyze_element(in_fwd_data_list[0], True, msCheckerConfig.error_data_path, + forward_path='backward_error_data') + npu_api_info.analyze_element(in_fwd_data_list[1], True, msCheckerConfig.error_data_path, + forward_path='backward_error_data') + + def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): + in_fwd_data_list, in_bwd_data_list = [], [] [api_type, api_name, _] = api_full_name.split("*") args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) + in_fwd_data_list.append(args) + in_fwd_data_list.append(kwargs) need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True need_backward = need_backward and need_grad if inplace or not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) + in_fwd_data_list.append(npu_args) + in_fwd_data_list.append(npu_kwargs) grad_out, npu_grad_out = None, None if kwargs.get("device"): del kwargs["device"] @@ -108,10 +140,13 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di grad_index = grad_input_index.get('grad_index') if need_backward: - grad_out, npu_grad_out = run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out) + grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, args, backward_content, grad_index, npu_args, + npu_out, out) + in_bwd_data_list.append(grad) + in_bwd_data_list.append(npu_grad) if grad_index is not None: - return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index] - return grad_out, npu_grad_out, npu_out, out + return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], in_fwd_data_list, in_bwd_data_list + return grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list def get_api_info(api_info_dict, api_name): @@ -150,7 +185,7 @@ def run_backward(api_full_name, args, backward_content, grad_index, npu_args, np if isinstance(arg, torch.Tensor): npu_args_grad.append(arg.grad) npu_grad_out = npu_args_grad - return grad_out, npu_grad_out + return grad_out, npu_grad_out, grad, npu_grad def _run_ut_parser(parser): @@ -190,6 +225,8 @@ def _run_ut(): raise ValueError("The forward_input_file and backward_input_file should be a json file!") out_path = os.path.realpath(args.out_path) if args.out_path else "./" save_error_data = args.save_error_data + if save_error_data: + initialize_save_error_input_data() run_ut(forward_file, backward_file, out_path, save_error_data) -- Gitee From d27b57036cf4f0020869582fa49355def027a9f6 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Sat, 19 Aug 2023 16:35:12 +0800 Subject: [PATCH 2/8] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 9e81d2fd96..d5c902921f 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -111,9 +111,9 @@ def do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd bench_api_info = APIInfo(api_full_name + '*bench', False) npu_api_info = APIInfo(api_full_name + '*npu', False) bench_api_info.analyze_element(in_fwd_data_list[0], True, msCheckerConfig.error_data_path, - forward_path='backward_error_data') + backward_path='backward_error_data') npu_api_info.analyze_element(in_fwd_data_list[1], True, msCheckerConfig.error_data_path, - forward_path='backward_error_data') + backward_path='backward_error_data') def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): -- Gitee From a8efe7ab25e6aa85e85b00277ba778ff50ad052b Mon Sep 17 00:00:00 2001 From: l30044004 Date: Sat, 19 Aug 2023 16:42:05 +0800 Subject: [PATCH 3/8] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index d5c902921f..cb6de537bc 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -110,9 +110,9 @@ def do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd if not is_bwd_success and len(in_bwd_data_list) > 0: bench_api_info = APIInfo(api_full_name + '*bench', False) npu_api_info = APIInfo(api_full_name + '*npu', False) - bench_api_info.analyze_element(in_fwd_data_list[0], True, msCheckerConfig.error_data_path, + bench_api_info.analyze_element(in_bwd_data_list[0], True, msCheckerConfig.error_data_path, backward_path='backward_error_data') - npu_api_info.analyze_element(in_fwd_data_list[1], True, msCheckerConfig.error_data_path, + npu_api_info.analyze_element(in_bwd_data_list[1], True, msCheckerConfig.error_data_path, backward_path='backward_error_data') -- Gitee From ad45258015b1614b100551e4fd6fdb7dbd9f15ab Mon Sep 17 00:00:00 2001 From: l30044004 Date: Mon, 21 Aug 2023 19:07:55 +0800 Subject: [PATCH 4/8] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/dump/api_info.py | 17 +++++++++++++ .../api_accuracy_checker/run_ut/run_ut.py | 25 ++++++------------- 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 81aa297627..a6c40ad629 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -161,3 +161,20 @@ class BackwardAPIInfo(APIInfo): def analyze_api_input(self, grads): grads_info_list = self.analyze_element(grads, self.save_real_data, msCheckerConfig.dump_path) self.grad_info_struct = {self.api_name:grads_info_list} + + +class ForwardErrorAPIInfo(APIInfo): + def __init__(self, name, args, kwargs): + super().__init__(name, is_forward=True) + self.analyze_api_input(args, kwargs) + + def analyze_api_input(self, args, kwargs): + self.analyze_element(args, True, msCheckerConfig.error_data_path, forward_path='forward_error_data') + self.analyze_element(kwargs, True, msCheckerConfig.error_data_path, forward_path='forward_error_data') + + +class BackwardErrorAPIInfo(APIInfo): + def __init__(self, name, grads): + super().__init__(name, is_forward=False) + self.analyze_element(grads, True, msCheckerConfig.error_data_path, backward_path='backward_error_data') + diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index cb6de537bc..d3d31ed74e 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -12,7 +12,7 @@ from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate -from api_accuracy_checker.dump.api_info import APIInfo +from api_accuracy_checker.dump.api_info import ForwardErrorAPIInfo, BackwardErrorAPIInfo from api_accuracy_checker.dump.info_dump import initialize_save_error_input_data from api_accuracy_checker.common.config import msCheckerConfig @@ -97,23 +97,12 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): def do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd_success, is_bwd_success): - if not is_fwd_success and len(in_fwd_data_list) > 0: - bench_api_info = APIInfo(api_full_name + '*bench', True) - npu_api_info = APIInfo(api_full_name + '*npu', True) - for i in range(len(in_fwd_data_list)): - if i < 2: - bench_api_info.analyze_element(in_fwd_data_list[i], True, msCheckerConfig.error_data_path, - forward_path='forward_error_data') - else: - npu_api_info.analyze_element(in_fwd_data_list[i], True, msCheckerConfig.error_data_path, - forward_path='forward_error_data') - if not is_bwd_success and len(in_bwd_data_list) > 0: - bench_api_info = APIInfo(api_full_name + '*bench', False) - npu_api_info = APIInfo(api_full_name + '*npu', False) - bench_api_info.analyze_element(in_bwd_data_list[0], True, msCheckerConfig.error_data_path, - backward_path='backward_error_data') - npu_api_info.analyze_element(in_bwd_data_list[1], True, msCheckerConfig.error_data_path, - backward_path='backward_error_data') + if not is_fwd_success and len(in_fwd_data_list) == 4: + ForwardErrorAPIInfo(api_full_name + '*bench', in_fwd_data_list[0], in_fwd_data_list[1]) + ForwardErrorAPIInfo(api_full_name + '*npu', in_fwd_data_list[2], in_fwd_data_list[3]) + if not is_bwd_success and len(in_bwd_data_list) == 2: + BackwardErrorAPIInfo(api_full_name + '*bench', in_bwd_data_list[0]) + BackwardErrorAPIInfo(api_full_name + '*npu', in_bwd_data_list[1]) def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): -- Gitee From 6de3817fbd5f2ce55edf0f3c2c33a9d83ab0afc2 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Wed, 23 Aug 2023 10:30:25 +0800 Subject: [PATCH 5/8] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5/=E8=BE=93=E5=87=BA=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/dump/api_info.py | 63 +++++++++---------- .../api_accuracy_checker/dump/info_dump.py | 5 +- .../api_accuracy_checker/dump/utils.py | 8 +++ .../api_accuracy_checker/run_ut/run_ut.py | 57 ++++++++++------- 4 files changed, 75 insertions(+), 58 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index a6c40ad629..fbb354b846 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -6,22 +6,25 @@ import torch_npu from api_accuracy_checker.common.config import msCheckerConfig from api_accuracy_checker.common.utils import print_error_log from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.dump.utils import write_npy +from api_accuracy_checker.dump.utils import write_pt class APIInfo: - def __init__(self, api_name, is_forward): + def __init__(self, api_name, is_forward, is_save_data, save_path, forward_path, backward_path): self.rank = os.getpid() self.api_name = api_name - self.save_real_data = msCheckerConfig.real_data self.torch_object_key = {'device': self.analyze_device_in_kwargs, 'dtype': self.analyze_dtype_in_kwargs} self.is_forward = is_forward self.args_num = 0 + self.is_save_data = is_save_data + self.save_path = save_path + self.forward_path = forward_path + self.backward_path = backward_path - def analyze_element(self, element, is_save_data, save_path, forward_path='forward_real_data', backward_path='backward_real_data'): + def analyze_element(self, element): if isinstance(element, (list, tuple)): out = [] for item in element: - out.append(self.analyze_element(item, is_save_data, save_path, forward_path, backward_path)) + out.append(self.analyze_element(item)) elif isinstance(element, dict): out = {} for key, value in element.items(): @@ -29,10 +32,10 @@ class APIInfo: fun = self.torch_object_key[key] out[key] = fun(value) else: - out[key] = self.analyze_element(value, is_save_data, save_path, forward_path, backward_path) + out[key] = self.analyze_element(value) elif isinstance(element, torch.Tensor): - out = self.analyze_tensor(element, is_save_data, save_path, forward_path, backward_path) + out = self.analyze_tensor(element) elif self.is_builtin_class(element): out = self.analyze_builtin(element) @@ -44,9 +47,9 @@ class APIInfo: return out - def analyze_tensor(self, arg, is_save_data, save_path, forward_path, backward_path): + def analyze_tensor(self, arg): single_arg = {} - if not is_save_data: + if not self.is_save_data: single_arg.update({'type' : 'torch.Tensor'}) single_arg.update({'dtype' : str(arg.dtype)}) @@ -58,16 +61,16 @@ class APIInfo: else: api_args = self.api_name + '*' + str(self.args_num) if self.is_forward: - forward_real_data_path = os.path.join(save_path, forward_path) + forward_real_data_path = os.path.join(self.save_path, self.forward_path) - file_path = os.path.join(forward_real_data_path, f'{api_args}.npy') + file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') else: - backward_real_data_path = os.path.join(save_path, backward_path) - file_path = os.path.join(backward_real_data_path, f'{api_args}.npy') + backward_real_data_path = os.path.join(self.save_path, self.backward_path) + file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') self.args_num += 1 - npy_path = write_npy(file_path, arg.contiguous().cpu().detach().numpy()) + pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'datapath' : npy_path}) + single_arg.update({'datapath' : pt_path}) single_arg.update({'requires_grad': arg.requires_grad}) return single_arg @@ -133,13 +136,14 @@ class APIInfo: class ForwardAPIInfo(APIInfo): def __init__(self, name, args, kwargs): - super().__init__(name, is_forward=True) + super().__init__(name, True, msCheckerConfig.real_data, msCheckerConfig.dump_path, 'forward_real_data', + 'backward_real_data') self.analyze_api_input(args, kwargs) self.analyze_api_call_stack() def analyze_api_input(self, args, kwargs): - args_info_list = self.analyze_element(args, self.save_real_data, msCheckerConfig.dump_path) - kwargs_info_dict = self.analyze_element(kwargs, self.save_real_data, msCheckerConfig.dump_path) + args_info_list = self.analyze_element(args) + kwargs_info_dict = self.analyze_element(kwargs) self.api_info_struct = {self.api_name: {"args":args_info_list, "kwargs":kwargs_info_dict}} def analyze_api_call_stack(self): @@ -155,26 +159,17 @@ class ForwardAPIInfo(APIInfo): class BackwardAPIInfo(APIInfo): def __init__(self, name, grads): - super().__init__(name, is_forward=False) + super().__init__(name, False, msCheckerConfig.real_data, msCheckerConfig.dump_path, 'forward_real_data', + 'backward_real_data') self.analyze_api_input(grads) def analyze_api_input(self, grads): - grads_info_list = self.analyze_element(grads, self.save_real_data, msCheckerConfig.dump_path) + grads_info_list = self.analyze_element(grads) self.grad_info_struct = {self.api_name:grads_info_list} -class ForwardErrorAPIInfo(APIInfo): - def __init__(self, name, args, kwargs): - super().__init__(name, is_forward=True) - self.analyze_api_input(args, kwargs) - - def analyze_api_input(self, args, kwargs): - self.analyze_element(args, True, msCheckerConfig.error_data_path, forward_path='forward_error_data') - self.analyze_element(kwargs, True, msCheckerConfig.error_data_path, forward_path='forward_error_data') - - -class BackwardErrorAPIInfo(APIInfo): - def __init__(self, name, grads): - super().__init__(name, is_forward=False) - self.analyze_element(grads, True, msCheckerConfig.error_data_path, backward_path='backward_error_data') +class ErrorAPIInfo(APIInfo): + def __init__(self, name, element): + super().__init__(name, True, True, msCheckerConfig.error_data_path, 'error_data', '') + self.analyze_element(element) diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 0c8f029644..42056a9d7a 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -68,8 +68,7 @@ def initialize_output_json(): if os.path.exists(file_path): raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") -def initialize_save_error_input_data(): +def initialize_save_error_data(): error_data_path = os.path.realpath(msCheckerConfig.error_data_path) check_file_or_directory_path(error_data_path, True) - initialize_save_path(error_data_path, 'forward_error_data') - initialize_save_path(error_data_path, 'backward_error_data') \ No newline at end of file + initialize_save_path(error_data_path, 'error_data') diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 93af6f0981..707e5f6433 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -1,5 +1,6 @@ import os import numpy as np +import torch def create_folder(path): @@ -13,3 +14,10 @@ def write_npy(file_path, tensor): np.save(file_path, tensor) full_path = os.path.abspath(file_path) return full_path + +def write_pt(file_path, tensor): + if os.path.exists(file_path): + raise ValueError(f"File {file_path} already exists") + torch.save(tensor, file_path) + full_path = os.path.abspath(file_path) + return full_path diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index d3d31ed74e..95189fa563 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -12,8 +12,8 @@ from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate -from api_accuracy_checker.dump.api_info import ForwardErrorAPIInfo, BackwardErrorAPIInfo -from api_accuracy_checker.dump.info_dump import initialize_save_error_input_data +from api_accuracy_checker.dump.api_info import ErrorAPIInfo +from api_accuracy_checker.dump.info_dump import initialize_save_error_data from api_accuracy_checker.common.config import msCheckerConfig NO_GRAD_APIS = ["hardtanh"] @@ -74,13 +74,13 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare = Comparator(out_path) for api_full_name, api_info_dict in tqdm(forward_content.items()): try: - grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list = run_torch_api(api_full_name, - api_setting_dict, - backward_content, - api_info_dict) + grad_out, npu_grad_out, npu_out, out, error_data_info = run_torch_api(api_full_name, + api_setting_dict, + backward_content, + api_info_dict) is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) if save_error_data: - do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd_success, is_bwd_success) + do_save_error_data(api_full_name, error_data_info, is_fwd_success, is_bwd_success) except Exception as err: [_, api_name, _] = api_full_name.split("*") if "not implemented for 'Half'" in str(err): @@ -96,17 +96,23 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare.write_compare_csv() -def do_save_error_data(api_full_name, in_fwd_data_list, in_bwd_data_list, is_fwd_success, is_bwd_success): - if not is_fwd_success and len(in_fwd_data_list) == 4: - ForwardErrorAPIInfo(api_full_name + '*bench', in_fwd_data_list[0], in_fwd_data_list[1]) - ForwardErrorAPIInfo(api_full_name + '*npu', in_fwd_data_list[2], in_fwd_data_list[3]) - if not is_bwd_success and len(in_bwd_data_list) == 2: - BackwardErrorAPIInfo(api_full_name + '*bench', in_bwd_data_list[0]) - BackwardErrorAPIInfo(api_full_name + '*npu', in_bwd_data_list[1]) +def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): + if not is_fwd_success or not is_bwd_success: + for element in data_info.in_fwd_data_list: + ErrorAPIInfo(api_full_name + '*forward*input', element) + if len(data_info.out_fwd_data_list) == 2: + ErrorAPIInfo(api_full_name + '*forward*output*bench', data_info.out_fwd_data_list[0]) + ErrorAPIInfo(api_full_name + '*forward*output*npu', data_info.out_fwd_data_list[1]) + if len(data_info.in_bwd_data_list) == 1: + ErrorAPIInfo(api_full_name + '*backward*input', data_info.in_bwd_data_list[0]) + if len(data_info.out_bwd_data_list) == 2: + ErrorAPIInfo(api_full_name + '*backward*output*bench', data_info.out_bwd_data_list[0]) + ErrorAPIInfo(api_full_name + '*backward*output*npu', data_info.out_bwd_data_list[1]) + def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): - in_fwd_data_list, in_bwd_data_list = [], [] + in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, out_bwd_data_list = [], [], [], [] [api_type, api_name, _] = api_full_name.split("*") args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) in_fwd_data_list.append(args) @@ -116,13 +122,13 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di if inplace or not need_grad: print_warn_log("%s involves in-place operations, skip backward" % api_full_name) npu_args, npu_kwargs = generate_npu_params(args, kwargs, need_backward) - in_fwd_data_list.append(npu_args) - in_fwd_data_list.append(npu_kwargs) grad_out, npu_grad_out = None, None if kwargs.get("device"): del kwargs["device"] out = exec_api(api_type, api_name, args, kwargs) npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs) + out_fwd_data_list.append(out) + out_fwd_data_list.append(npu_out) grad_input_index = api_setting_dict.get(api_name) grad_index = None if grad_input_index is not None: @@ -132,10 +138,12 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out) in_bwd_data_list.append(grad) - in_bwd_data_list.append(npu_grad) + out_bwd_data_list.append(grad_out) + out_bwd_data_list.append(npu_grad_out) + error_data_info = ErrorDataInfo(in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, out_bwd_data_list) if grad_index is not None: - return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], in_fwd_data_list, in_bwd_data_list - return grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list + return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], error_data_info + return grad_out, npu_grad_out, npu_out, out, error_data_info def get_api_info(api_info_dict, api_name): @@ -215,10 +223,17 @@ def _run_ut(): out_path = os.path.realpath(args.out_path) if args.out_path else "./" save_error_data = args.save_error_data if save_error_data: - initialize_save_error_input_data() + initialize_save_error_data() run_ut(forward_file, backward_file, out_path, save_error_data) +class ErrorDataInfo: + def __init__(self, in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, out_bwd_data_list): + self.in_fwd_data_list = in_fwd_data_list + self.in_bwd_data_list = in_bwd_data_list + self.out_fwd_data_list = out_fwd_data_list + self.out_bwd_data_list = out_bwd_data_list + if __name__ == '__main__': _run_ut() print_info_log("UT task completed.") -- Gitee From 97f5c5ee034ccf03eb3d2312d67e956836a490fe Mon Sep 17 00:00:00 2001 From: l30044004 Date: Wed, 23 Aug 2023 19:57:12 +0800 Subject: [PATCH 6/8] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5/=E8=BE=93=E5=87=BA=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/common/base_api.py | 128 +++++++++++++++ .../api_accuracy_checker/common/utils.py | 17 +- .../api_accuracy_checker/dump/api_info.py | 148 +----------------- .../api_accuracy_checker/dump/info_dump.py | 14 +- .../api_accuracy_checker/dump/utils.py | 8 - .../api_accuracy_checker/run_ut/run_ut.py | 37 +++-- .../run_ut/ut_api_info.py | 14 ++ 7 files changed, 190 insertions(+), 176 deletions(-) create mode 100644 debug/accuracy_tools/api_accuracy_checker/common/base_api.py create mode 100644 debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py new file mode 100644 index 0000000000..3689a4c1d1 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -0,0 +1,128 @@ +import os +import torch +from api_accuracy_checker.common.utils import print_error_log, write_pt + + +class BaseAPIInfo: + def __init__(self, api_name, is_forward, is_save_data, save_path, forward_path, backward_path): + self.rank = os.getpid() + self.api_name = api_name + self.torch_object_key = {'device': self.analyze_device_in_kwargs, 'dtype': self.analyze_dtype_in_kwargs} + self.is_forward = is_forward + self.args_num = 0 + self.is_save_data = is_save_data + self.save_path = save_path + self.forward_path = forward_path + self.backward_path = backward_path + + def analyze_element(self, element): + if isinstance(element, (list, tuple)): + out = [] + for item in element: + out.append(self.analyze_element(item)) + elif isinstance(element, dict): + out = {} + for key, value in element.items(): + if key in self.torch_object_key.keys(): + fun = self.torch_object_key[key] + out[key] = fun(value) + else: + out[key] = self.analyze_element(value) + + elif isinstance(element, torch.Tensor): + out = self.analyze_tensor(element) + + elif self.is_builtin_class(element): + out = self.analyze_builtin(element) + else: + msg = f"Type {type(element)} is unsupported at analyze_element" + print_error_log(msg) + + raise NotImplementedError(msg) + return out + + + def analyze_tensor(self, arg): + single_arg = {} + if not self.is_save_data: + + single_arg.update({'type' : 'torch.Tensor'}) + single_arg.update({'dtype' : str(arg.dtype)}) + single_arg.update({'shape' : arg.shape}) + single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg,'max'), str(arg.dtype))}) + single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg,'min'), str(arg.dtype))}) + single_arg.update({'requires_grad': arg.requires_grad}) + + else: + api_args = self.api_name + '*' + str(self.args_num) + if self.is_forward: + forward_real_data_path = os.path.join(self.save_path, self.forward_path) + + file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') + else: + backward_real_data_path = os.path.join(self.save_path, self.backward_path) + file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') + self.args_num += 1 + pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) + single_arg.update({'type' : 'torch.Tensor'}) + single_arg.update({'datapath' : pt_path}) + single_arg.update({'requires_grad': arg.requires_grad}) + return single_arg + + def analyze_builtin(self, arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({'type' : "slice"}) + single_arg.update({'value' : [arg.start, arg.stop, arg.step]}) + else: + single_arg.update({'type' : self.get_type_name(str(type(arg)))}) + single_arg.update({'value' : arg}) + return single_arg + + def transfer_types(self, data, dtype): + if 'int' in dtype or 'bool' in dtype: + return int(data) + else: + return float(data) + + def is_builtin_class(self, element): + if element is None or isinstance(element, (bool,int,float,str,slice)): + return True + return False + + def analyze_device_in_kwargs(self, element): + single_arg = {} + single_arg.update({'type' : 'torch.device'}) + if not isinstance(element, str): + + if hasattr(element, "index"): + device_value = element.type + ":" + str(element.index) + single_arg.update({'value' : device_value}) + else: + device_value = element.type + else: + single_arg.update({'value' : element}) + return single_arg + + def analyze_dtype_in_kwargs(self, element): + single_arg = {} + single_arg.update({'type' : 'torch.dtype'}) + single_arg.update({'value' : str(element)}) + return single_arg + + def get_tensor_extremum(self, data, operator): + if data.dtype is torch.bool: + if operator == 'max': + return True in data + elif operator == 'min': + return False not in data + if operator == 'max': + return torch._C._VariableFunctionsClass.max(data).item() + else: + return torch._C._VariableFunctionsClass.min(data).item() + + def get_type_name(self, name): + + left = name.index("'") + right = name.rindex("'") + return name[left + 1 : right] diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 4d9db1aecc..32ee339f23 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -588,4 +588,19 @@ def cross_entropy_process(api_info_dict): if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: if api_info_dict['args'][1]['Min'] <= 0: api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. - return api_info_dict \ No newline at end of file + return api_info_dict + +def initialize_save_path(save_path, dir_name): + data_path = os.path.join(save_path, dir_name) + if os.path.exists(data_path): + raise ValueError(f"file {data_path} already exists, please remove it first") + else: + os.mkdir(data_path, mode = 0o750) + check_file_or_directory_path(data_path, True) + +def write_pt(file_path, tensor): + if os.path.exists(file_path): + raise ValueError(f"File {file_path} already exists") + torch.save(tensor, file_path) + full_path = os.path.abspath(file_path) + return full_path \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index fbb354b846..763e33a2f3 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -1,143 +1,19 @@ # 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 -import os import inspect -import torch -import torch_npu from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.common.utils import print_error_log -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.dump.utils import write_pt - -class APIInfo: - def __init__(self, api_name, is_forward, is_save_data, save_path, forward_path, backward_path): - self.rank = os.getpid() - self.api_name = api_name - self.torch_object_key = {'device': self.analyze_device_in_kwargs, 'dtype': self.analyze_dtype_in_kwargs} - self.is_forward = is_forward - self.args_num = 0 - self.is_save_data = is_save_data - self.save_path = save_path - self.forward_path = forward_path - self.backward_path = backward_path - - def analyze_element(self, element): - if isinstance(element, (list, tuple)): - out = [] - for item in element: - out.append(self.analyze_element(item)) - elif isinstance(element, dict): - out = {} - for key, value in element.items(): - if key in self.torch_object_key.keys(): - fun = self.torch_object_key[key] - out[key] = fun(value) - else: - out[key] = self.analyze_element(value) - - elif isinstance(element, torch.Tensor): - out = self.analyze_tensor(element) - - elif self.is_builtin_class(element): - out = self.analyze_builtin(element) - else: - msg = f"Type {type(element)} is unsupported at analyze_element" - print_error_log(msg) - - raise NotImplementedError(msg) - return out +from api_accuracy_checker.common.config import BaseAPIInfo - def analyze_tensor(self, arg): - single_arg = {} - if not self.is_save_data: - - single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'dtype' : str(arg.dtype)}) - single_arg.update({'shape' : arg.shape}) - single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg,'max'), str(arg.dtype))}) - single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg,'min'), str(arg.dtype))}) - single_arg.update({'requires_grad': arg.requires_grad}) - - else: - api_args = self.api_name + '*' + str(self.args_num) - if self.is_forward: - forward_real_data_path = os.path.join(self.save_path, self.forward_path) - - file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') - else: - backward_real_data_path = os.path.join(self.save_path, self.backward_path) - file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') - self.args_num += 1 - pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) - single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'datapath' : pt_path}) - single_arg.update({'requires_grad': arg.requires_grad}) - return single_arg - - def analyze_builtin(self, arg): - single_arg = {} - if isinstance(arg, slice): - single_arg.update({'type' : "slice"}) - single_arg.update({'value' : [arg.start, arg.stop, arg.step]}) - else: - single_arg.update({'type' : self.get_type_name(str(type(arg)))}) - single_arg.update({'value' : arg}) - return single_arg - - def transfer_types(self, data, dtype): - if 'int' in dtype or 'bool' in dtype: - return int(data) - else: - return float(data) - - def is_builtin_class(self, element): - if element is None or isinstance(element, (bool,int,float,str,slice)): - return True - return False - - def analyze_device_in_kwargs(self, element): - single_arg = {} - single_arg.update({'type' : 'torch.device'}) - if not isinstance(element, str): - - if hasattr(element, "index"): - device_value = element.type + ":" + str(element.index) - single_arg.update({'value' : device_value}) - else: - device_value = element.type - else: - single_arg.update({'value' : element}) - return single_arg - - def analyze_dtype_in_kwargs(self, element): - single_arg = {} - single_arg.update({'type' : 'torch.dtype'}) - single_arg.update({'value' : str(element)}) - return single_arg - - def get_tensor_extremum(self, data, operator): - if data.dtype is torch.bool: - if operator == 'max': - return True in data - elif operator == 'min': - return False not in data - if operator == 'max': - return torch._C._VariableFunctionsClass.max(data).item() - else: - return torch._C._VariableFunctionsClass.min(data).item() - - def get_type_name(self, name): - - left = name.index("'") - right = name.rindex("'") - return name[left + 1 : right] - +class APIInfo(BaseAPIInfo): + def __init__(self, api_name, is_forward, is_save_data=msCheckerConfig.real_data, + save_path=msCheckerConfig.dump_path, forward_path='forward_real_data', + backward_path='backward_real_data'): + super().__init__(api_name, is_forward, is_save_data, save_path, forward_path, backward_path) class ForwardAPIInfo(APIInfo): def __init__(self, name, args, kwargs): - super().__init__(name, True, msCheckerConfig.real_data, msCheckerConfig.dump_path, 'forward_real_data', - 'backward_real_data') + super().__init__(name, is_forward=True) self.analyze_api_input(args, kwargs) self.analyze_api_call_stack() @@ -159,17 +35,9 @@ class ForwardAPIInfo(APIInfo): class BackwardAPIInfo(APIInfo): def __init__(self, name, grads): - super().__init__(name, False, msCheckerConfig.real_data, msCheckerConfig.dump_path, 'forward_real_data', - 'backward_real_data') + super().__init__(name, is_forward=False) self.analyze_api_input(grads) def analyze_api_input(self, grads): grads_info_list = self.analyze_element(grads) self.grad_info_struct = {self.api_name:grads_info_list} - - -class ErrorAPIInfo(APIInfo): - def __init__(self, name, element): - super().__init__(name, True, True, msCheckerConfig.error_data_path, 'error_data', '') - self.analyze_element(element) - diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 42056a9d7a..437934cd71 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -5,7 +5,7 @@ import threading import numpy as np from .api_info import ForwardAPIInfo, BackwardAPIInfo -from ..common.utils import check_file_or_directory_path +from ..common.utils import check_file_or_directory_path, initialize_save_path from ..common.config import msCheckerConfig lock = threading.Lock() @@ -48,13 +48,6 @@ def write_json(file_path, data, indent=None): fcntl.flock(f, fcntl.LOCK_UN) lock.release() -def initialize_save_path(save_path, dir_name): - data_path = os.path.join(save_path, dir_name) - if os.path.exists(data_path): - raise ValueError(f"file {data_path} already exists, please remove it first") - else: - os.mkdir(data_path, mode = 0o750) - check_file_or_directory_path(data_path, True) def initialize_output_json(): dump_path = os.path.realpath(msCheckerConfig.dump_path) @@ -67,8 +60,3 @@ def initialize_output_json(): file_path = os.path.join(dump_path, file) if os.path.exists(file_path): raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") - -def initialize_save_error_data(): - error_data_path = os.path.realpath(msCheckerConfig.error_data_path) - check_file_or_directory_path(error_data_path, True) - initialize_save_path(error_data_path, 'error_data') diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py index 707e5f6433..93af6f0981 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/utils.py @@ -1,6 +1,5 @@ import os import numpy as np -import torch def create_folder(path): @@ -14,10 +13,3 @@ def write_npy(file_path, tensor): np.save(file_path, tensor) full_path = os.path.abspath(file_path) return full_path - -def write_pt(file_path, tensor): - if os.path.exists(file_path): - raise ValueError(f"File {file_path} already exists") - torch.save(tensor, file_path) - full_path = os.path.abspath(file_path) - return full_path diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 95189fa563..fde43cfc05 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -7,13 +7,12 @@ import torch from tqdm import tqdm from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ - print_error_log + print_error_log, check_file_or_directory_path, initialize_save_path from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate -from api_accuracy_checker.dump.api_info import ErrorAPIInfo -from api_accuracy_checker.dump.info_dump import initialize_save_error_data +from ut_api_info import ErrorAPIInfo from api_accuracy_checker.common.config import msCheckerConfig NO_GRAD_APIS = ["hardtanh"] @@ -74,13 +73,11 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare = Comparator(out_path) for api_full_name, api_info_dict in tqdm(forward_content.items()): try: - grad_out, npu_grad_out, npu_out, out, error_data_info = run_torch_api(api_full_name, - api_setting_dict, - backward_content, - api_info_dict) - is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) + data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info.out, data_info.npu_out, + data_info.grad_out, data_info.npu_grad_out) if save_error_data: - do_save_error_data(api_full_name, error_data_info, is_fwd_success, is_bwd_success) + do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: [_, api_name, _] = api_full_name.split("*") if "not implemented for 'Half'" in str(err): @@ -140,10 +137,11 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di in_bwd_data_list.append(grad) out_bwd_data_list.append(grad_out) out_bwd_data_list.append(npu_grad_out) - error_data_info = ErrorDataInfo(in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, out_bwd_data_list) if grad_index is not None: - return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], error_data_info - return grad_out, npu_grad_out, npu_out, out, error_data_info + return UtDataInfo(grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], in_fwd_data_list, + in_bwd_data_list, out_fwd_data_list, out_bwd_data_list) + return UtDataInfo(grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, + out_bwd_data_list) def get_api_info(api_info_dict, api_name): @@ -185,6 +183,12 @@ def run_backward(api_full_name, args, backward_content, grad_index, npu_args, np return grad_out, npu_grad_out, grad, npu_grad +def initialize_save_error_data(): + error_data_path = os.path.realpath(msCheckerConfig.error_data_path) + check_file_or_directory_path(error_data_path, True) + initialize_save_path(error_data_path, 'error_data') + + def _run_ut_parser(parser): parser.add_argument("-forward", "--forward_input_file", dest="forward_input_file", default="", help=" The api param tool forward result file: generate from api param tool, " @@ -227,8 +231,13 @@ def _run_ut(): run_ut(forward_file, backward_file, out_path, save_error_data) -class ErrorDataInfo: - def __init__(self, in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, out_bwd_data_list): +class UtDataInfo: + def __init__(self, grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, + out_bwd_data_list): + self.grad_out = grad_out + self.npu_grad_out = npu_grad_out + self.npu_out = npu_out + self.out = out self.in_fwd_data_list = in_fwd_data_list self.in_bwd_data_list = in_bwd_data_list self.out_fwd_data_list = out_fwd_data_list diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py new file mode 100644 index 0000000000..a6bbc8a0f6 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py @@ -0,0 +1,14 @@ +from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.common.config import BaseAPIInfo + + +class UtAPIInfo(BaseAPIInfo): + def __init__(self, api_name, is_forward, is_save_data=True, save_path=msCheckerConfig.error_data_path, + forward_path='ut_error_data', backward_path='ut_error_data'): + super().__init__(api_name, is_forward, is_save_data, save_path, forward_path, backward_path) + + +class ErrorAPIInfo(UtAPIInfo): + def __init__(self, api_name, element, is_forward=True): + super().__init__(api_name, is_forward) + self.analyze_element(element) -- Gitee From 903a58f47f40ced4f07f199f4f7a5e3d3e675b1e Mon Sep 17 00:00:00 2001 From: l30044004 Date: Wed, 23 Aug 2023 20:31:23 +0800 Subject: [PATCH 7/8] =?UTF-8?q?=E9=A2=84=E6=A3=80=E5=B7=A5=E5=85=B7?= =?UTF-8?q?=E9=92=88=E5=AF=B9=E6=9C=AA=E8=BE=BE=E5=88=B0=E9=A2=84=E8=AE=BE?= =?UTF-8?q?=E6=A0=87=E5=87=86=E7=9A=84api=E4=BF=9D=E5=AD=98=E8=BE=93?= =?UTF-8?q?=E5=85=A5/=E8=BE=93=E5=87=BA=E6=95=B0=E6=8D=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../api_accuracy_checker/run_ut/run_ut.py | 14 +++++++------- .../api_accuracy_checker/run_ut/ut_api_info.py | 7 +------ 2 files changed, 8 insertions(+), 13 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index fde43cfc05..cad59309ef 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -12,7 +12,7 @@ from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate -from ut_api_info import ErrorAPIInfo +from ut_api_info import UtAPIInfo from api_accuracy_checker.common.config import msCheckerConfig NO_GRAD_APIS = ["hardtanh"] @@ -96,15 +96,15 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): if not is_fwd_success or not is_bwd_success: for element in data_info.in_fwd_data_list: - ErrorAPIInfo(api_full_name + '*forward*input', element) + UtAPIInfo(api_full_name + '*forward*input', element) if len(data_info.out_fwd_data_list) == 2: - ErrorAPIInfo(api_full_name + '*forward*output*bench', data_info.out_fwd_data_list[0]) - ErrorAPIInfo(api_full_name + '*forward*output*npu', data_info.out_fwd_data_list[1]) + UtAPIInfo(api_full_name + '*forward*output*bench', data_info.out_fwd_data_list[0]) + UtAPIInfo(api_full_name + '*forward*output*npu', data_info.out_fwd_data_list[1]) if len(data_info.in_bwd_data_list) == 1: - ErrorAPIInfo(api_full_name + '*backward*input', data_info.in_bwd_data_list[0]) + UtAPIInfo(api_full_name + '*backward*input', data_info.in_bwd_data_list[0]) if len(data_info.out_bwd_data_list) == 2: - ErrorAPIInfo(api_full_name + '*backward*output*bench', data_info.out_bwd_data_list[0]) - ErrorAPIInfo(api_full_name + '*backward*output*npu', data_info.out_bwd_data_list[1]) + UtAPIInfo(api_full_name + '*backward*output*bench', data_info.out_bwd_data_list[0]) + UtAPIInfo(api_full_name + '*backward*output*npu', data_info.out_bwd_data_list[1]) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py index a6bbc8a0f6..4d9d074fc6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py @@ -3,12 +3,7 @@ from api_accuracy_checker.common.config import BaseAPIInfo class UtAPIInfo(BaseAPIInfo): - def __init__(self, api_name, is_forward, is_save_data=True, save_path=msCheckerConfig.error_data_path, + def __init__(self, api_name, element, is_forward=True, is_save_data=True, save_path=msCheckerConfig.error_data_path, forward_path='ut_error_data', backward_path='ut_error_data'): super().__init__(api_name, is_forward, is_save_data, save_path, forward_path, backward_path) - - -class ErrorAPIInfo(UtAPIInfo): - def __init__(self, api_name, element, is_forward=True): - super().__init__(api_name, is_forward) self.analyze_element(element) -- Gitee From aeac9794ead8ec8155052343e52d394d67902bad Mon Sep 17 00:00:00 2001 From: louyujing Date: Thu, 24 Aug 2023 02:21:04 +0000 Subject: [PATCH 8/8] update debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py. Signed-off-by: louyujing --- .../api_accuracy_checker/run_ut/run_ut.py | 46 ++++++++----------- 1 file changed, 19 insertions(+), 27 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index cad59309ef..f4f1de8dba 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -74,8 +74,9 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): for api_full_name, api_info_dict in tqdm(forward_content.items()): try: data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) - is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info.out, data_info.npu_out, - data_info.grad_out, data_info.npu_grad_out) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info.bench_out, + data_info.npu_out, data_info.bench_grad_out, + data_info.npu_grad_out) if save_error_data: do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: @@ -97,19 +98,19 @@ def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) if not is_fwd_success or not is_bwd_success: for element in data_info.in_fwd_data_list: UtAPIInfo(api_full_name + '*forward*input', element) - if len(data_info.out_fwd_data_list) == 2: - UtAPIInfo(api_full_name + '*forward*output*bench', data_info.out_fwd_data_list[0]) - UtAPIInfo(api_full_name + '*forward*output*npu', data_info.out_fwd_data_list[1]) - if len(data_info.in_bwd_data_list) == 1: - UtAPIInfo(api_full_name + '*backward*input', data_info.in_bwd_data_list[0]) - if len(data_info.out_bwd_data_list) == 2: - UtAPIInfo(api_full_name + '*backward*output*bench', data_info.out_bwd_data_list[0]) - UtAPIInfo(api_full_name + '*backward*output*npu', data_info.out_bwd_data_list[1]) + if data_info.bench_out is not None: + UtAPIInfo(api_full_name + '*forward*output*bench', data_info.bench_out) + UtAPIInfo(api_full_name + '*forward*output*npu', data_info.npu_out) + if data_info.grad_in is not None: + UtAPIInfo(api_full_name + '*backward*input', data_info.grad_in) + if data_info.bench_grad_out is not None: + UtAPIInfo(api_full_name + '*backward*output*bench', data_info.bench_grad_out) + UtAPIInfo(api_full_name + '*backward*output*npu', data_info.npu_grad_out) def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): - in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, out_bwd_data_list = [], [], [], [] + in_fwd_data_list = [] [api_type, api_name, _] = api_full_name.split("*") args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) in_fwd_data_list.append(args) @@ -124,24 +125,18 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di del kwargs["device"] out = exec_api(api_type, api_name, args, kwargs) npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs) - out_fwd_data_list.append(out) - out_fwd_data_list.append(npu_out) grad_input_index = api_setting_dict.get(api_name) grad_index = None + grad = None if grad_input_index is not None: grad_index = grad_input_index.get('grad_index') if need_backward: grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out) - in_bwd_data_list.append(grad) - out_bwd_data_list.append(grad_out) - out_bwd_data_list.append(npu_grad_out) if grad_index is not None: - return UtDataInfo(grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], in_fwd_data_list, - in_bwd_data_list, out_fwd_data_list, out_bwd_data_list) - return UtDataInfo(grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, - out_bwd_data_list) + return UtDataInfo(grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], grad, in_fwd_data_list) + return UtDataInfo(grad_out, npu_grad_out, npu_out, out, grad, in_fwd_data_list) def get_api_info(api_info_dict, api_name): @@ -232,16 +227,13 @@ def _run_ut(): class UtDataInfo: - def __init__(self, grad_out, npu_grad_out, npu_out, out, in_fwd_data_list, in_bwd_data_list, out_fwd_data_list, - out_bwd_data_list): - self.grad_out = grad_out + def __init__(self, bench_grad_out, npu_grad_out, npu_out, bench_out, grad_in, in_fwd_data_list): + self.bench_grad_out = bench_grad_out self.npu_grad_out = npu_grad_out self.npu_out = npu_out - self.out = out + self.bench_out = bench_out + self.grad_in = grad_in self.in_fwd_data_list = in_fwd_data_list - self.in_bwd_data_list = in_bwd_data_list - self.out_fwd_data_list = out_fwd_data_list - self.out_bwd_data_list = out_bwd_data_list if __name__ == '__main__': _run_ut() -- Gitee