diff --git a/debug/accuracy_tools/api_accuracy_checker/common/utils.py b/debug/accuracy_tools/api_accuracy_checker/common/utils.py index 988bcaa8f040502ce04804f828fc7a43f4678fd3..167437da9d2a950d3e0ebdc834f8062902e18b80 100644 --- a/debug/accuracy_tools/api_accuracy_checker/common/utils.py +++ b/debug/accuracy_tools/api_accuracy_checker/common/utils.py @@ -83,7 +83,7 @@ class Const: } CONVERT_API = { - "fp16_to_fp32": ["conv2d", "batch_norm", "relu", "max_pool2d", "interpolate", "group_norm"] + "fp16_to_fp32": ["conv2d", "batch_norm", "relu", "max_pool2d", "interpolate", "group_norm", "layer_norm", "bmm", "tanh", "cross_entropy", "linear", "numel"] } class CompareConst: diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py index d33ed21d8134afc0a8e90453754dd53da38899eb..5d7fb97e27620a700fe7fc47b34c25a28213face 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/api_info.py @@ -9,11 +9,13 @@ from api_accuracy_checker.common.config import msCheckerConfig from api_accuracy_checker.dump.utils import write_npy class APIInfo: - def __init__(self, api_name): + def __init__(self, api_name, is_forward): self.rank = os.getpid() self.api_name = api_name self.save_real_data = msCheckerConfig.real_data self.torch_object_key = {'device' : self.analyze_device_in_kwargs, 'dtype' : self.analyze_dtype_in_kwargs} + self.is_forward = is_forward + self.args_num = 0 def analyze_element(self, element): if isinstance(element, (list, tuple)): @@ -55,8 +57,15 @@ class APIInfo: else: dump_path = msCheckerConfig.dump_path - real_data_path = os.path.join(dump_path, 'real_data') - file_path = os.path.join(real_data_path, self.api_name) + api_args = self.api_name + '*' + str(self.args_num) + if self.is_forward: + forward_real_data_path = os.path.join(dump_path, 'forward_real_data') + + file_path = os.path.join(forward_real_data_path, f'{api_args}.npy') + else: + backward_real_data_path = os.path.join(dump_path, 'backward_real_data') + file_path = os.path.join(backward_real_data_path, f'{api_args}.npy') + self.args_num += 1 npy_path = write_npy(file_path, arg.contiguous().cpu().detach().numpy()) single_arg.update({'type' : 'torch.Tensor'}) single_arg.update({'datapath' : npy_path}) @@ -125,7 +134,7 @@ class APIInfo: class ForwardAPIInfo(APIInfo): def __init__(self, name, args, kwargs): - super().__init__(name) + super().__init__(name, is_forward=True) self.analyze_api_input(args, kwargs) self.analyze_api_call_stack() @@ -147,7 +156,7 @@ class ForwardAPIInfo(APIInfo): class BackwardAPIInfo(APIInfo): def __init__(self, name, grads): - super().__init__(name) + super().__init__(name, is_forward=False) self.analyze_api_input(grads) def analyze_api_input(self, grads): diff --git a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py index 1914b27d8d7e029081e1e0515d1220d51b4e0683..7790518399e3848d8856b479caae7d6ec3939801 100644 --- a/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py +++ b/debug/accuracy_tools/api_accuracy_checker/dump/info_dump.py @@ -53,8 +53,19 @@ def initialize_output_json(): check_file_or_directory_path(dump_path,True) files = ['forward_info.json', 'backward_info.json', 'stack_info.json'] if msCheckerConfig.real_data: - real_data_path = os.path.join(dump_path, 'real_data') - check_file_or_directory_path(real_data_path, True) + forward_real_data_path = os.path.join(dump_path, 'forward_real_data') + if os.path.exists(forward_real_data_path): + raise ValueError(f"file {forward_real_data_path} already exists, please remove it first") + else: + os.mkdir(forward_real_data_path, mode = 0o750) + check_file_or_directory_path(forward_real_data_path, True) + + backward_real_data_path = os.path.join(dump_path, 'backward_real_data') + if os.path.exists(backward_real_data_path): + raise ValueError(f"file {backward_real_data_path} already exists, please remove it first") + else: + os.mkdir(backward_real_data_path, mode = 0o750) + check_file_or_directory_path(backward_real_data_path, True) for file in files: file_path = os.path.join(dump_path, file) if os.path.exists(file_path):