diff --git a/debug/accuracy_tools/api_accuracy_checker/common/base_api.py b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py new file mode 100644 index 0000000000000000000000000000000000000000..3689a4c1d1bb6b61b14c6775049097d774730f37 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/common/base_api.py @@ -0,0 +1,128 @@ +import os +import torch +from api_accuracy_checker.common.utils import print_error_log, write_pt + + +class BaseAPIInfo: + def __init__(self, api_name, is_forward, is_save_data, save_path, forward_path, backward_path): + self.rank = os.getpid() + self.api_name = api_name + self.torch_object_key = {'device': self.analyze_device_in_kwargs, 'dtype': self.analyze_dtype_in_kwargs} + self.is_forward = is_forward + self.args_num = 0 + self.is_save_data = is_save_data + self.save_path = save_path + self.forward_path = forward_path + self.backward_path = backward_path + + def analyze_element(self, element): + if isinstance(element, (list, tuple)): + out = [] + for item in element: + out.append(self.analyze_element(item)) + elif isinstance(element, dict): + out = {} + for key, value in element.items(): + if key in self.torch_object_key.keys(): + fun = self.torch_object_key[key] + out[key] = fun(value) + else: + out[key] = self.analyze_element(value) + + elif isinstance(element, torch.Tensor): + out = self.analyze_tensor(element) + + elif self.is_builtin_class(element): + out = self.analyze_builtin(element) + else: + msg = f"Type {type(element)} is unsupported at analyze_element" + print_error_log(msg) + + raise NotImplementedError(msg) + return out + + + def analyze_tensor(self, arg): + single_arg = {} + if not self.is_save_data: + + single_arg.update({'type' : 'torch.Tensor'}) + single_arg.update({'dtype' : str(arg.dtype)}) + single_arg.update({'shape' : arg.shape}) + single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg,'max'), str(arg.dtype))}) + single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg,'min'), str(arg.dtype))}) + single_arg.update({'requires_grad': arg.requires_grad}) + + else: + api_args = self.api_name + '*' + str(self.args_num) + if self.is_forward: + forward_real_data_path = os.path.join(self.save_path, self.forward_path) + + file_path = os.path.join(forward_real_data_path, f'{api_args}.pt') + else: + backward_real_data_path = os.path.join(self.save_path, self.backward_path) + file_path = os.path.join(backward_real_data_path, f'{api_args}.pt') + self.args_num += 1 + pt_path = write_pt(file_path, arg.contiguous().cpu().detach()) + single_arg.update({'type' : 'torch.Tensor'}) + single_arg.update({'datapath' : pt_path}) + single_arg.update({'requires_grad': arg.requires_grad}) + return single_arg + + def analyze_builtin(self, arg): + single_arg = {} + if isinstance(arg, slice): + single_arg.update({'type' : "slice"}) + single_arg.update({'value' : [arg.start, arg.stop, arg.step]}) + else: + single_arg.update({'type' : self.get_type_name(str(type(arg)))}) + single_arg.update({'value' : arg}) + return single_arg + + def transfer_types(self, data, dtype): + if 'int' in dtype or 'bool' in dtype: + return int(data) + else: + return float(data) + + def is_builtin_class(self, element): + if element is None or isinstance(element, (bool,int,float,str,slice)): + return True + return False + + def analyze_device_in_kwargs(self, element): + single_arg = {} + single_arg.update({'type' : 'torch.device'}) + if not isinstance(element, str): + + if hasattr(element, "index"): + device_value = element.type + ":" + str(element.index) + single_arg.update({'value' : device_value}) + else: + device_value = element.type + else: + single_arg.update({'value' : element}) + return single_arg + + def analyze_dtype_in_kwargs(self, element): + single_arg = {} + single_arg.update({'type' : 'torch.dtype'}) + single_arg.update({'value' : str(element)}) + return single_arg + + def get_tensor_extremum(self, data, operator): + if data.dtype is torch.bool: + if operator == 'max': + return True in data + elif operator == 'min': + return False not in data + if operator == 'max': + return torch._C._VariableFunctionsClass.max(data).item() + else: + return torch._C._VariableFunctionsClass.min(data).item() + + def get_type_name(self, name): + + left = name.index("'") + right = name.rindex("'") + return name[left + 1 : right] diff --git a/debug/accuracy_tools/api_accuracy_checker/common/config.py b/debug/accuracy_tools/api_accuracy_checker/common/config.py index c931c686318908fa8b64330f7f1e72102e9330c8..9fe21ccb3f3b7f1c8c52ba50f24f03fec093f545 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/config.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/config.py @@ -13,6 +13,7 @@ class Config: self.compare_algorithm = self.validate_compare_algorithm(config['compare_algorithm']) self.real_data = self.validate_real_data(config['real_data']) self.dump_step = self.validate_dump_step(config['dump_step']) + self.error_data_path = self.validate_error_data_path(config['error_data_path']) def validate_dump_path(self, dump_path): if not isinstance(dump_path, str): @@ -43,6 +44,11 @@ class Config: if not isinstance(dump_step, int): raise ValueError("dump_step mast be int type") return dump_step + + def validate_error_data_path(self, error_data_path): + if not isinstance(error_data_path, str): + raise ValueError("error_data_path mast be string type") + return error_data_path def __str__(self): diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 4d9db1aeccd6875472b14ec5e1fa4cd4a3488530..32ee339f236388316c36925c4088ea3bb8617fb7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -588,4 +588,19 @@ def cross_entropy_process(api_info_dict): if 'args' in api_info_dict and len(api_info_dict['args']) > 1 and 'Min' in api_info_dict['args'][1]: if api_info_dict['args'][1]['Min'] <= 0: api_info_dict['args'][1]['Min'] = 0 #The second argument in cross_entropy should be -100 or not less than 0. - return api_info_dict \ No newline at end of file + return api_info_dict + +def initialize_save_path(save_path, dir_name): + data_path = os.path.join(save_path, dir_name) + if os.path.exists(data_path): + raise ValueError(f"file {data_path} already exists, please remove it first") + else: + os.mkdir(data_path, mode = 0o750) + check_file_or_directory_path(data_path, True) + +def write_pt(file_path, tensor): + if os.path.exists(file_path): + raise ValueError(f"File {file_path} already exists") + torch.save(tensor, file_path) + full_path = os.path.abspath(file_path) + return full_path \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 7a1c069e2eff91940d26a7bf4b74bfc54554a04e..e4055d62b47dbb8a1542ebf3be5fb302c1a8d6ba 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -111,6 +111,7 @@ class Comparator: self.test_result_cnt['forward_fail_num'] += 1 else: self.test_result_cnt['backward_fail_num'] += 1 + return is_fwd_success, is_bwd_success def _compare_core_wrapper(self, bench_out, npu_out): detailed_result_total = [] diff --git a/debug/accuracy_tools/api_accuracy_checker/config.yaml b/debug/accuracy_tools/api_accuracy_checker/config.yaml index 7e2cd46fe2477ba2bfb0a16ed972e1beab6d8da0..2b22a9d9f93bcd3f80eae803c7ffd4994c2da111 100644 --- a/debug/accuracy_tools/api_accuracy_checker/config.yaml +++ b/debug/accuracy_tools/api_accuracy_checker/config.yaml @@ -3,4 +3,5 @@ jit_compile: True compile_option: -O3 compare_algorithm: cosine_similarity real_data: False -dump_step: 1000 \ No newline at end of file +dump_step: 1000 +error_data_path: './' \ No newline at end of file diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index 5d7fb97e27620a700fe7fc47b34c25a28213face..763e33a2f3045ede76c1246c354a80ab656b5bd7 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -1,135 +1,14 @@ # 定义API INFO,保存基本信息,用于后续结构体的落盘,注意考虑random场景及真实数据场景 -import os import inspect -import torch -import torch_npu from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.common.utils import print_error_log -from api_accuracy_checker.common.config import msCheckerConfig -from api_accuracy_checker.dump.utils import write_npy - -class APIInfo: - def __init__(self, api_name, is_forward): - self.rank = os.getpid() - self.api_name = api_name - self.save_real_data = msCheckerConfig.real_data - self.torch_object_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} - self.is_forward = is_forward - self.args_num = 0 - - def analyze_element(self, element): - if isinstance(element, (list, tuple)): - out = [] - for item in element: - out.append(self.analyze_element(item)) - elif isinstance(element, dict): - out = {} - for key, value in element.items(): - if key in self.torch_object_key.keys(): - fun = self.torch_object_key[key] - out[key] = fun(value) - else: - out[key] = self.analyze_element(value) - - elif isinstance(element, torch.Tensor): - out = self.analyze_tensor(element, self.save_real_data) - - elif self.is_builtin_class(element): - out = self.analyze_builtin(element) - else: - msg = f"Type {type(element)} is unsupported at analyze_element" - print_error_log(msg) - - raise NotImplementedError(msg) - return out - - - def analyze_tensor(self, arg, save_real_data): - single_arg = {} - if not save_real_data: - - single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'dtype' : str(arg.dtype)}) - single_arg.update({'shape' : arg.shape}) - single_arg.update({'Max' : self.transfer_types(self.get_tensor_extremum(arg,'max'), str(arg.dtype))}) - single_arg.update({'Min' : self.transfer_types(self.get_tensor_extremum(arg,'min'), str(arg.dtype))}) - single_arg.update({'requires_grad': arg.requires_grad}) - - else: - dump_path = msCheckerConfig.dump_path - api_args = self.api_name + '*' + str(self.args_num) - if self.is_forward: - forward_real_data_path = os.path.join(dump_path, 'forward_real_data') - - file_path = os.path.join(forward_real_data_path, f'{api_args}.npy') - else: - backward_real_data_path = os.path.join(dump_path, 'backward_real_data') - file_path = os.path.join(backward_real_data_path, f'{api_args}.npy') - self.args_num += 1 - npy_path = write_npy(file_path, arg.contiguous().cpu().detach().numpy()) - single_arg.update({'type' : 'torch.Tensor'}) - single_arg.update({'datapath' : npy_path}) - single_arg.update({'requires_grad': arg.requires_grad}) - return single_arg - - def analyze_builtin(self, arg): - single_arg = {} - if isinstance(arg, slice): - single_arg.update({'type' : "slice"}) - single_arg.update({'value' : [arg.start, arg.stop, arg.step]}) - else: - single_arg.update({'type' : self.get_type_name(str(type(arg)))}) - single_arg.update({'value' : arg}) - return single_arg - - def transfer_types(self, data, dtype): - if 'int' in dtype or 'bool' in dtype: - return int(data) - else: - return float(data) - - def is_builtin_class(self, element): - if element is None or isinstance(element, (bool,int,float,str,slice)): - return True - return False - - def analyze_device_in_kwargs(self, element): - single_arg = {} - single_arg.update({'type' : 'torch.device'}) - if not isinstance(element, str): - - if hasattr(element, "index"): - device_value = element.type + ":" + str(element.index) - single_arg.update({'value' : device_value}) - else: - device_value = element.type - else: - single_arg.update({'value' : element}) - return single_arg - - def analyze_dtype_in_kwargs(self, element): - single_arg = {} - single_arg.update({'type' : 'torch.dtype'}) - single_arg.update({'value' : str(element)}) - return single_arg - - def get_tensor_extremum(self, data, operator): - if data.dtype is torch.bool: - if operator == 'max': - return True in data - elif operator == 'min': - return False not in data - if operator == 'max': - return torch._C._VariableFunctionsClass.max(data).item() - else: - return torch._C._VariableFunctionsClass.min(data).item() - - def get_type_name(self, name): +from api_accuracy_checker.common.config import BaseAPIInfo - left = name.index("'") - right = name.rindex("'") - return name[left + 1 : right] +class APIInfo(BaseAPIInfo): + def __init__(self, api_name, is_forward, is_save_data=msCheckerConfig.real_data, + save_path=msCheckerConfig.dump_path, forward_path='forward_real_data', + backward_path='backward_real_data'): + super().__init__(api_name, is_forward, is_save_data, save_path, forward_path, backward_path) class ForwardAPIInfo(APIInfo): diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 7790518399e3848d8856b479caae7d6ec3939801..437934cd71b166d84077bf3bc0357cd6435c56fb 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -5,7 +5,7 @@ import threading import numpy as np from .api_info import ForwardAPIInfo, BackwardAPIInfo -from ..common.utils import check_file_or_directory_path +from ..common.utils import check_file_or_directory_path, initialize_save_path from ..common.config import msCheckerConfig lock = threading.Lock() @@ -48,25 +48,15 @@ def write_json(file_path, data, indent=None): fcntl.flock(f, fcntl.LOCK_UN) lock.release() + def initialize_output_json(): dump_path = os.path.realpath(msCheckerConfig.dump_path) - check_file_or_directory_path(dump_path,True) + check_file_or_directory_path(dump_path, True) files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] if msCheckerConfig.real_data: - forward_real_data_path = os.path.join(dump_path, 'forward_real_data') - if os.path.exists(forward_real_data_path): - raise ValueError(f"file {forward_real_data_path} already exists, please remove it first") - else: - os.mkdir(forward_real_data_path, mode = 0o750) - check_file_or_directory_path(forward_real_data_path, True) - - backward_real_data_path = os.path.join(dump_path, 'backward_real_data') - if os.path.exists(backward_real_data_path): - raise ValueError(f"file {backward_real_data_path} already exists, please remove it first") - else: - os.mkdir(backward_real_data_path, mode = 0o750) - check_file_or_directory_path(backward_real_data_path, True) + initialize_save_path(dump_path, 'forward_real_data') + initialize_save_path(dump_path, 'backward_real_data') for file in files: file_path = os.path.join(dump_path, file) if os.path.exists(file_path): - raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") \ No newline at end of file + raise ValueError(f"file {file_path} already exists, please remove it first or use a new dump path") diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index e5737b511068814d85ffb3d3eb5654288a46c93f..f4f1de8dbaa4785b49421ba44d05501b621cad4b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -7,11 +7,13 @@ import torch from tqdm import tqdm from api_accuracy_checker.run_ut.data_generate import gen_api_params, gen_args from api_accuracy_checker.common.utils import print_info_log, print_warn_log, get_json_contents, api_info_preprocess, \ - print_error_log + print_error_log, check_file_or_directory_path, initialize_save_path from api_accuracy_checker.compare.compare import Comparator from api_accuracy_checker.hook_module.wrap_tensor import TensorOPTemplate from api_accuracy_checker.hook_module.wrap_functional import FunctionalOPTemplate from api_accuracy_checker.hook_module.wrap_torch import TorchOPTemplate +from ut_api_info import UtAPIInfo +from api_accuracy_checker.common.config import msCheckerConfig NO_GRAD_APIS = ["hardtanh"] @@ -71,9 +73,12 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare = Comparator(out_path) for api_full_name, api_info_dict in tqdm(forward_content.items()): try: - grad_out, npu_grad_out, npu_out, out = run_torch_api(api_full_name, api_setting_dict, backward_content, - api_info_dict) - compare.compare_output(api_full_name, out, npu_out, grad_out, npu_grad_out) + data_info = run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict) + is_fwd_success, is_bwd_success = compare.compare_output(api_full_name, data_info.bench_out, + data_info.npu_out, data_info.bench_grad_out, + data_info.npu_grad_out) + if save_error_data: + do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success) except Exception as err: [_, api_name, _] = api_full_name.split("*") if "not implemented for 'Half'" in str(err): @@ -89,9 +94,27 @@ def run_ut(forward_file, backward_file, out_path, save_error_data): compare.write_compare_csv() +def do_save_error_data(api_full_name, data_info, is_fwd_success, is_bwd_success): + if not is_fwd_success or not is_bwd_success: + for element in data_info.in_fwd_data_list: + UtAPIInfo(api_full_name + '*forward*input', element) + if data_info.bench_out is not None: + UtAPIInfo(api_full_name + '*forward*output*bench', data_info.bench_out) + UtAPIInfo(api_full_name + '*forward*output*npu', data_info.npu_out) + if data_info.grad_in is not None: + UtAPIInfo(api_full_name + '*backward*input', data_info.grad_in) + if data_info.bench_grad_out is not None: + UtAPIInfo(api_full_name + '*backward*output*bench', data_info.bench_grad_out) + UtAPIInfo(api_full_name + '*backward*output*npu', data_info.npu_grad_out) + + + def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_dict): + in_fwd_data_list = [] [api_type, api_name, _] = api_full_name.split("*") args, inplace, kwargs, need_grad = get_api_info(api_info_dict, api_name) + in_fwd_data_list.append(args) + in_fwd_data_list.append(kwargs) need_backward = api_full_name in backward_content and api_name[-1] != "_" and inplace is not True need_backward = need_backward and need_grad if inplace or not need_grad: @@ -104,14 +127,16 @@ def run_torch_api(api_full_name, api_setting_dict, backward_content, api_info_di npu_out = exec_api(api_type, api_name, npu_args, npu_kwargs) grad_input_index = api_setting_dict.get(api_name) grad_index = None + grad = None if grad_input_index is not None: grad_index = grad_input_index.get('grad_index') if need_backward: - grad_out, npu_grad_out = run_backward(api_full_name, args, backward_content, grad_index, npu_args, npu_out, out) + grad_out, npu_grad_out, grad, npu_grad = run_backward(api_full_name, args, backward_content, grad_index, npu_args, + npu_out, out) if grad_index is not None: - return grad_out, npu_grad_out, npu_out[grad_index], out[grad_index] - return grad_out, npu_grad_out, npu_out, out + return UtDataInfo(grad_out, npu_grad_out, npu_out[grad_index], out[grad_index], grad, in_fwd_data_list) + return UtDataInfo(grad_out, npu_grad_out, npu_out, out, grad, in_fwd_data_list) def get_api_info(api_info_dict, api_name): @@ -150,7 +175,13 @@ def run_backward(api_full_name, args, backward_content, grad_index, npu_args, np if isinstance(arg, torch.Tensor): npu_args_grad.append(arg.grad) npu_grad_out = npu_args_grad - return grad_out, npu_grad_out + return grad_out, npu_grad_out, grad, npu_grad + + +def initialize_save_error_data(): + error_data_path = os.path.realpath(msCheckerConfig.error_data_path) + check_file_or_directory_path(error_data_path, True) + initialize_save_path(error_data_path, 'error_data') def _run_ut_parser(parser): @@ -190,9 +221,20 @@ def _run_ut(): raise ValueError("The forward_input_file and backward_input_file should be a json file!") out_path = os.path.realpath(args.out_path) if args.out_path else "./" save_error_data = args.save_error_data + if save_error_data: + initialize_save_error_data() run_ut(forward_file, backward_file, out_path, save_error_data) +class UtDataInfo: + def __init__(self, bench_grad_out, npu_grad_out, npu_out, bench_out, grad_in, in_fwd_data_list): + self.bench_grad_out = bench_grad_out + self.npu_grad_out = npu_grad_out + self.npu_out = npu_out + self.bench_out = bench_out + self.grad_in = grad_in + self.in_fwd_data_list = in_fwd_data_list + if __name__ == '__main__': _run_ut() print_info_log("UT task completed.") diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py new file mode 100644 index 0000000000000000000000000000000000000000..4d9d074fc6441972a34062c83f8c537fa6046654 --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/ut_api_info.py @@ -0,0 +1,9 @@ +from api_accuracy_checker.common.config import msCheckerConfig +from api_accuracy_checker.common.config import BaseAPIInfo + + +class UtAPIInfo(BaseAPIInfo): + def __init__(self, api_name, element, is_forward=True, is_save_data=True, save_path=msCheckerConfig.error_data_path, + forward_path='ut_error_data', backward_path='ut_error_data'): + super().__init__(api_name, is_forward, is_save_data, save_path, forward_path, backward_path) + self.analyze_element(element)