diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py index f2361a90b76acead7c432271dbbcc22cd073c812..6e63b7cf884869f6b69bbcc4b130613436b2e60e 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/common/utils.py @@ -72,6 +72,7 @@ class Const: OFF = 'OFF' BACKWARD = 'backward' FORWARD = 'forward' + PRE_FORWARD = "pre_forward" # dump mode ALL = "all" @@ -96,13 +97,15 @@ class Const: FILE_PATTERN = r'^[a-zA-Z0-9_./-]+$' FILE_NAME_LENGTH = 255 DIRECTORY_LENGTH = 4096 - + DISTRIBUTED_PREFIX_LENGTH = 60 # env dump path ASCEND_WORK_PATH = "ASCEND_WORK_PATH" DUMP_DIR = "dump_data" MAX_SEED_VALUE = 2**32 - 1 + INPLACE_LIST = ["broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter"] + class CompareConst: """ @@ -669,3 +672,11 @@ def check_path_before_create(path): if not re.match(Const.FILE_PATTERN, os.path.realpath(path)): print_error_log('The file path {} contains special characters.'.format(path)) raise CompareException(CompareException.INVALID_PATH_ERROR) + + +def check_inplace_op(prefix): + if len(prefix) > Const.DISTRIBUTED_PREFIX_LENGTH: + return False + match_op = re.findall(r"Distributed_(.+?)_\d", prefix) + op_name = match_op[0] if match_op else None + return op_name in Const.INPLACE_LIST diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py index 5238754cce48d36e940c600d2d14d402a1cbbe43..2a5ad3658c3e621471fef1ba186b0ca56d057d12 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/dump/dump.py @@ -32,7 +32,7 @@ else: is_gpu = False from .utils import DumpUtil, check_if_in_api_list, make_dump_data_dir, get_tensor_rank, create_dirs_if_not_exist -from ..common.utils import print_warn_log, Const, print_info_log, modify_dump_path +from ..common.utils import print_warn_log, Const, print_info_log, modify_dump_path, check_inplace_op from ..dump.utils import check_writable from ..common.file_check_util import FileOpen, change_mode, FileCheckConst, check_path_pattern_vaild, check_path_length @@ -104,11 +104,11 @@ def json_dump_condition(prefix): return (Const.BACKWARD in prefix and backward_threading_id == cur_threading_id) or 'forward' in prefix -def dump_tensor(x, prefix, dump_step, dump_file_name): +def dump_tensor(x, prefix, dump_step): global data_info if isinstance(x, (tuple, list)) and x: for i, item in enumerate(x): - dump_tensor(item, "{}.{}".format(prefix, i), dump_step, dump_file_name) + dump_tensor(item, "{}.{}".format(prefix, i), dump_step) return elif isinstance(x, torch.Tensor): if x.is_meta: @@ -117,20 +117,20 @@ def dump_tensor(x, prefix, dump_step, dump_file_name): if x.numel() == 0 or len(x.shape) == 0 or not x.is_floating_point(): if DumpUtil.dump_filter_switch == Const.OFF: data_info = get_not_float_tensor_info(x) - dump_data(dump_file_name, dump_step, prefix, data_info) + dump_data(dump_step, prefix, data_info) else: return else: data_info = get_float_tensor_info(x) - dump_data(dump_file_name, dump_step, prefix, data_info) + dump_data(dump_step, prefix, data_info) elif DumpUtil.dump_filter_switch == Const.OFF: if isinstance(x, bool) or isinstance(x, int) or isinstance(x, float): data_info = get_scalar_data_info(x) - dump_data(dump_file_name, dump_step, prefix, data_info) + dump_data(dump_step, prefix, data_info) -def dump_data(dump_file_name, dump_step, prefix, data_info): +def dump_data(dump_step, prefix, data_info): global api_list thread_lock.acquire() try: @@ -149,7 +149,10 @@ def dump_data(dump_file_name, dump_step, prefix, data_info): thread_lock.release() -def dump_stack_info(name_template, dump_file): +def dump_stack_info(name_template): + if check_inplace_op(name_template) and Const.PRE_FORWARD in name_template: + return + stack_str = [] try: for (_, path, line, func, code, _) in inspect.stack()[4:]: @@ -172,17 +175,24 @@ def dump_stack_info(name_template, dump_file): api_list.append([prefix, stack_str]) -def dump_api_tensor(dump_step, in_feat, name_template, out_feat, dump_file): +def dump_api_tensor(dump_step, in_feat, name_template, out_feat): + if check_inplace_op(name_template): + if Const.PRE_FORWARD in name_template: + name_template = name_template.replace(Const.PRE_FORWARD, Const.FORWARD) + else: + dump_tensor(in_feat, name_template.format("output"), dump_step) + return + if Const.BACKWARD in name_template and Const.BACKWARD in DumpUtil.dump_mode: if 'input' in DumpUtil.dump_mode: - dump_tensor(out_feat, name_template.format("input"), dump_step, dump_file) + dump_tensor(out_feat, name_template.format("input"), dump_step) if 'output' in DumpUtil.dump_mode: - dump_tensor(in_feat, name_template.format("output"), dump_step, dump_file) + dump_tensor(in_feat, name_template.format("output"), dump_step) elif Const.BACKWARD not in name_template and Const.FORWARD in DumpUtil.dump_mode: if 'input' in DumpUtil.dump_mode: - dump_tensor(in_feat, name_template.format("input"), dump_step, dump_file) + dump_tensor(in_feat, name_template.format("input"), dump_step) if 'output' in DumpUtil.dump_mode: - dump_tensor(out_feat, name_template.format("output"), dump_step, dump_file) + dump_tensor(out_feat, name_template.format("output"), dump_step) def rename_(): @@ -247,16 +257,16 @@ def dump_acc_cmp(name, in_feat, out_feat, dump_step, module): name_prefix = name name_template = f"{name_prefix}" + "_{}" if DumpUtil.dump_switch_mode in [Const.ALL, Const.API_LIST]: - dump_api_tensor(dump_step, in_feat, name_template, out_feat, dump_file) + dump_api_tensor(dump_step, in_feat, name_template, out_feat) elif DumpUtil.dump_switch_mode == Const.API_STACK: - dump_api_tensor(dump_step, in_feat, name_template, out_feat, dump_file) - dump_stack_info(name_template, dump_file) + dump_api_tensor(dump_step, in_feat, name_template, out_feat) + dump_stack_info(name_template) elif DumpUtil.check_switch_scope(name_prefix): if DumpUtil.dump_switch_mode == Const.ACL: acl_dump(module, name, name_prefix) elif DumpUtil.dump_switch_mode != Const.STACK: - dump_api_tensor(dump_step, in_feat, name_template, out_feat, dump_file) - dump_stack_info(name_template, dump_file) + dump_api_tensor(dump_step, in_feat, name_template, out_feat) + dump_stack_info(name_template) def acl_dump(module, module_name, name_prefix): @@ -336,7 +346,7 @@ def acc_cmp_dump(name, **kwargs): if not pid: return RuntimeError("Not get the specified process pid.") - def acc_cmp_hook(module, in_feat, out_feat): + def acc_cmp_hook(module, in_feat, out_feat=None): if pid == os.getpid(): dump_acc_cmp(name, in_feat, out_feat, dump_step, module) if hasattr(module, "input_args"): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/hook_module.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/hook_module.py index 83f7dcacef0cf06f107763528df84ff291e35ccd..55b97208a6d4981889074d8d70a095f7d56b1d23 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/hook_module.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/hook_module.py @@ -33,21 +33,21 @@ class HOOKModule(nn.Module): self.has_overflow = False self.input_args = tuple() self.input_kwargs = dict() + self.prefix = "" if not g_stop_hook: - prefix = "" if hasattr(self, "prefix_op_name_"): - prefix = self.prefix_op_name_ + self.prefix = self.prefix_op_name_ - if prefix not in HOOKModule.module_count: - HOOKModule.module_count[prefix] = 1 - prefix += '0_' + if self.prefix not in HOOKModule.module_count: + HOOKModule.module_count[self.prefix] = 1 + self.prefix += '0_' else: - HOOKModule.module_count[prefix] += 1 - prefix = prefix + str(HOOKModule.module_count[prefix] - 1) + '_' + HOOKModule.module_count[self.prefix] += 1 + self.prefix = self.prefix + str(HOOKModule.module_count[self.prefix] - 1) + '_' - self.register_forward_hook(hook(prefix + "forward")) - self.register_backward_hook(hook(prefix + "backward")) + self.register_forward_hook(hook(self.prefix + "forward")) + self.register_backward_hook(hook(self.prefix + "backward")) def __call__(self, *input, **kwargs): changed = False diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py index 7aa21770cefe569f040e60cbad25604af826b627..ed28dbe5d6712a76e5d0fc249a24911cf3e50b63 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/hook_module/wrap_distributed.py @@ -21,7 +21,7 @@ import torch.distributed as dist import yaml from .hook_module import HOOKModule -from ..common.utils import torch_device_guard +from ..common.utils import torch_device_guard, Const from ..common.file_check_util import FileOpen @@ -51,6 +51,8 @@ class DistributedOPTemplate(HOOKModule): self.op_name_ = op_name self.prefix_op_name_ = "Distributed_" + str(op_name) + "_" super().__init__(hook) + if self.op_name_ in Const.INPLACE_LIST: + self.register_forward_pre_hook(hook(self.prefix + Const.PRE_FORWARD)) @torch_device_guard def forward(self, *args, **kwargs): diff --git a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/utils.py b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/utils.py index 1a26c5ac04e770e76fd7a3b01bb88b34d4bcaa20..6af477753d09e45ab5bf30e993e9112d7db7ca10 100644 --- a/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/utils.py +++ b/debug/accuracy_tools/ptdbg_ascend/src/python/ptdbg_ascend/overflow_check/utils.py @@ -1,6 +1,6 @@ import torch -from ..common.utils import Const, check_switch_valid +from ..common.utils import Const, check_switch_valid, check_inplace_op from ..dump.dump import dump_stack_info, get_scalar_data_info, dump_data, \ get_not_float_tensor_info, get_float_tensor_info from ..dump.utils import DumpUtil, make_dump_data_dir @@ -44,30 +44,37 @@ def set_overflow_check_switch(switch, filter_switch=Const.OFF): def dump_overflow(module_name, in_feat, out_feat, dump_file): name_template = f"{module_name}" + "_{}" DumpUtil.dump_data_dir = make_dump_data_dir(dump_file) - dump_stack_info(name_template, dump_file) + dump_stack_info(name_template) + if check_inplace_op(name_template): + if Const.PRE_FORWARD in name_template: + name_template = name_template.replace(Const.PRE_FORWARD, Const.FORWARD) + else: + _dump_tensor_completely(in_feat, name_template.format("output")) + return + if "forward" in name_template: - _dump_tensor_completely(in_feat, name_template.format("input"), dump_file) - _dump_tensor_completely(out_feat, name_template.format("output"), dump_file) + _dump_tensor_completely(in_feat, name_template.format("input")) + _dump_tensor_completely(out_feat, name_template.format("output")) else: - _dump_tensor_completely(in_feat, name_template.format("output"), dump_file) - _dump_tensor_completely(out_feat, name_template.format("input"), dump_file) + _dump_tensor_completely(in_feat, name_template.format("output")) + _dump_tensor_completely(out_feat, name_template.format("input")) -def _dump_tensor_completely(x, prefix, dump_file_name): +def _dump_tensor_completely(x, prefix): dump_flag = Const.DUMP_RATIO_MAX + 1 if isinstance(x, (tuple, list)) and x: for i, item in enumerate(x): - _dump_tensor_completely(item, "{}.{}".format(prefix, i), dump_file_name) + _dump_tensor_completely(item, "{}.{}".format(prefix, i)) elif isinstance(x, torch.Tensor): if x.numel() == 0 or len(x.shape) == 0 or not x.is_floating_point(): if OverFlowUtil.overflow_filter_switch == Const.OFF: data_info = get_not_float_tensor_info(x) - dump_data(dump_file_name, dump_flag, prefix, data_info) + dump_data(dump_flag, prefix, data_info) else: data_info = get_float_tensor_info(x) - dump_data(dump_file_name, dump_flag, prefix, data_info) + dump_data(dump_flag, prefix, data_info) elif OverFlowUtil.overflow_filter_switch == Const.OFF: if isinstance(x, bool) or isinstance(x, int) or isinstance(x, float): data_info = get_scalar_data_info(x) - dump_data(dump_file_name, dump_flag, prefix, data_info) + dump_data(dump_flag, prefix, data_info)