diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index b6a06ede3a2684294ab27d99d3897eba384cfa7f..0fd0082a1109b470db9ea5c9b1ef1fd1bf437b0f 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -712,7 +712,7 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True): parsed_item = item parsed_item['full_op_name'] = full_op_name item_list.append(parsed_item) - else: + elif 'type' in item: parsed_item = {} if item['type'] == 'slice': parsed_item['full_op_name'] = full_op_name @@ -736,12 +736,45 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True): parsed_item['Norm'] = item['value'] parsed_item['data_name'] = '-1' item_list.append(parsed_item) + else: + resolve_api_special_parameters(item, full_op_name, item_list) else: for j in range(len(item)): op_item_parse(item[j], full_op_name, j, top_bool=False) return item_list +def resolve_api_special_parameters(data_dict, full_op_name, item_list): + """ + Function Description: + 解析下面格式的数据, 是api参数的一种特殊格式 + { + "last_hidden_state": { + "type": "torch.Tensor", + "dtype": "torch.bfloat16", + ... + }, + "loss": { + "type": "torch.Tensor", + "dtype": "torch.float32", + ... + } + } + Parameter: + data_dict: 字典格式的数据 + full_op_name: 参数的全名字符串 + item_list: 参数信息集合 + """ + for key, value in data_dict.items(): + if isinstance(value, dict): + parsed_item = value + parts = full_op_name.split(".") + parts.insert(-1, key) + full_op_name_new = ".".join(parts) + parsed_item['full_op_name'] = full_op_name_new + item_list.append(parsed_item) + + def read_op(op_data, op_name): op_parsed_list = [] if 'forward' in op_name: diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py index 434f95910dafd57587d93ad22cb0a0c825083aba..b2d3d798f622336a17c6336c62f5731353a8eeea 100644 --- a/debug/accuracy_tools/atat/pytorch/module_processer.py +++ b/debug/accuracy_tools/atat/pytorch/module_processer.py @@ -18,8 +18,21 @@ class ModuleProcesser: self.scope = None BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) + BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook) self.module_count = {} + @staticmethod + def filter_tensor_and_tuple(func): + @wraps(func) + def wrap_by_filter_tensor_and_tuple(*args, **kwargs): + # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是非tensor数据不传入 + # setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1] + if not isinstance(args[1], (torch.Tensor, tuple)): + return args[1] + return func(*args, **kwargs) + + return wrap_by_filter_tensor_and_tuple + @staticmethod def clone_return_value(func): @wraps(func)