From 596dfa1b01cc7f89a2daf7cb3a342fc630e5bd8a Mon Sep 17 00:00:00 2001 From: l30044004 Date: Thu, 13 Jun 2024 20:34:25 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E5=A4=84=E7=90=86=E5=8E=9F=E7=94=9F?= =?UTF-8?q?=E5=8F=8D=E5=90=91hook=E5=A4=84=E7=90=86=E9=9D=9Etensor?= =?UTF-8?q?=E6=95=B0=E6=8D=AE=E6=8A=A5=E9=94=99=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../atat/pytorch/module_processer.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py index 434f95910d..c728bf2b68 100644 --- a/debug/accuracy_tools/atat/pytorch/module_processer.py +++ b/debug/accuracy_tools/atat/pytorch/module_processer.py @@ -3,6 +3,7 @@ import torch from torch.utils.hooks import BackwardHook from .functional.scope import ModuleRangeScope from .common.utils import Const +from ..core.log import print_warn_log class ModuleProcesser: @@ -18,8 +19,23 @@ class ModuleProcesser: self.scope = None BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) + BackwardHook.setup_output_hook = ModuleProcesser.wrapper_setup_output_hook(BackwardHook.setup_output_hook) self.module_count = {} + @staticmethod + def wrapper_setup_output_hook(func): + @wraps(func) + def decorated(*args, **kwargs): + # BackwardHook中的setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1] + if not isinstance(args[1], (torch.Tensor, tuple)): + print_warn_log("For backward hooks to be called, " + f"module output should be a Tensor or a tuple of Tensors but received {type(args[1])}, " + "therefore skipping dump this data.") + return args[1] + return func(*args, **kwargs) + + return decorated + @staticmethod def clone_return_value(func): @wraps(func) -- Gitee From 9a924530410c9fac3edc533beb8eca8e99f2da93 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Mon, 17 Jun 2024 19:02:12 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E9=80=82=E9=85=8D=E6=AF=94=E5=AF=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../atat/pytorch/compare/acc_compare.py | 11 ++++++++++- debug/accuracy_tools/atat/pytorch/module_processer.py | 7 ++----- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index b6a06ede3a..b0e868f7e7 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -712,7 +712,7 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True): parsed_item = item parsed_item['full_op_name'] = full_op_name item_list.append(parsed_item) - else: + elif 'type' in item: parsed_item = {} if item['type'] == 'slice': parsed_item['full_op_name'] = full_op_name @@ -736,6 +736,15 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True): parsed_item['Norm'] = item['value'] parsed_item['data_name'] = '-1' item_list.append(parsed_item) + else: + for key, value in item.items(): + if isinstance(value, dict): + parsed_item = value + parts = full_op_name.split('.') + parts.insert(-1, key) + full_op_name_new = '.'.join(parts) + parsed_item['full_op_name'] = full_op_name_new + item_list.append(parsed_item) else: for j in range(len(item)): op_item_parse(item[j], full_op_name, j, top_bool=False) diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py index c728bf2b68..9a68622ea0 100644 --- a/debug/accuracy_tools/atat/pytorch/module_processer.py +++ b/debug/accuracy_tools/atat/pytorch/module_processer.py @@ -3,7 +3,6 @@ import torch from torch.utils.hooks import BackwardHook from .functional.scope import ModuleRangeScope from .common.utils import Const -from ..core.log import print_warn_log class ModuleProcesser: @@ -26,11 +25,9 @@ class ModuleProcesser: def wrapper_setup_output_hook(func): @wraps(func) def decorated(*args, **kwargs): - # BackwardHook中的setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1] + # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是非tensor数据不传入 + # setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1] if not isinstance(args[1], (torch.Tensor, tuple)): - print_warn_log("For backward hooks to be called, " - f"module output should be a Tensor or a tuple of Tensors but received {type(args[1])}, " - "therefore skipping dump this data.") return args[1] return func(*args, **kwargs) -- Gitee From e66ca4fa894d91041ef822e60a2a34551fab082e Mon Sep 17 00:00:00 2001 From: l30044004 Date: Tue, 18 Jun 2024 15:39:15 +0800 Subject: [PATCH 3/4] =?UTF-8?q?=E6=8F=90=E5=8F=96=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E5=A2=9E=E5=8A=A0=E8=AF=B4=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../atat/pytorch/compare/acc_compare.py | 40 +++++++++++++++---- 1 file changed, 32 insertions(+), 8 deletions(-) diff --git a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py index b0e868f7e7..0fd0082a11 100644 --- a/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py +++ b/debug/accuracy_tools/atat/pytorch/compare/acc_compare.py @@ -737,20 +737,44 @@ def op_item_parse(item, op_name, index, item_list=[], top_bool=True): parsed_item['data_name'] = '-1' item_list.append(parsed_item) else: - for key, value in item.items(): - if isinstance(value, dict): - parsed_item = value - parts = full_op_name.split('.') - parts.insert(-1, key) - full_op_name_new = '.'.join(parts) - parsed_item['full_op_name'] = full_op_name_new - item_list.append(parsed_item) + resolve_api_special_parameters(item, full_op_name, item_list) else: for j in range(len(item)): op_item_parse(item[j], full_op_name, j, top_bool=False) return item_list +def resolve_api_special_parameters(data_dict, full_op_name, item_list): + """ + Function Description: + 解析下面格式的数据, 是api参数的一种特殊格式 + { + "last_hidden_state": { + "type": "torch.Tensor", + "dtype": "torch.bfloat16", + ... + }, + "loss": { + "type": "torch.Tensor", + "dtype": "torch.float32", + ... + } + } + Parameter: + data_dict: 字典格式的数据 + full_op_name: 参数的全名字符串 + item_list: 参数信息集合 + """ + for key, value in data_dict.items(): + if isinstance(value, dict): + parsed_item = value + parts = full_op_name.split(".") + parts.insert(-1, key) + full_op_name_new = ".".join(parts) + parsed_item['full_op_name'] = full_op_name_new + item_list.append(parsed_item) + + def read_op(op_data, op_name): op_parsed_list = [] if 'forward' in op_name: -- Gitee From 9b2fb528c3764b9a529661798304481b6af9a9e4 Mon Sep 17 00:00:00 2001 From: l30044004 Date: Wed, 19 Jun 2024 14:17:03 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=96=B9=E6=B3=95?= =?UTF-8?q?=E5=91=BD=E5=90=8D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- debug/accuracy_tools/atat/pytorch/module_processer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/atat/pytorch/module_processer.py b/debug/accuracy_tools/atat/pytorch/module_processer.py index 9a68622ea0..b2d3d798f6 100644 --- a/debug/accuracy_tools/atat/pytorch/module_processer.py +++ b/debug/accuracy_tools/atat/pytorch/module_processer.py @@ -18,20 +18,20 @@ class ModuleProcesser: self.scope = None BackwardHook.setup_input_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_input_hook) BackwardHook.setup_output_hook = ModuleProcesser.clone_return_value(BackwardHook.setup_output_hook) - BackwardHook.setup_output_hook = ModuleProcesser.wrapper_setup_output_hook(BackwardHook.setup_output_hook) + BackwardHook.setup_output_hook = ModuleProcesser.filter_tensor_and_tuple(BackwardHook.setup_output_hook) self.module_count = {} @staticmethod - def wrapper_setup_output_hook(func): + def filter_tensor_and_tuple(func): @wraps(func) - def decorated(*args, **kwargs): + def wrap_by_filter_tensor_and_tuple(*args, **kwargs): # setup_output_hook传入非tensor数据,工具后续dump会报错,处理方式是非tensor数据不传入 # setup_output_hook定义为setup_output_hook(self, args),因此处理第二个位置参数,即*args[1] if not isinstance(args[1], (torch.Tensor, tuple)): return args[1] return func(*args, **kwargs) - return decorated + return wrap_by_filter_tensor_and_tuple @staticmethod def clone_return_value(func): -- Gitee