diff --git a/debug/accuracy_tools/atat/pytorch/dump/dump.py b/debug/accuracy_tools/atat/pytorch/dump/dump.py
new file mode 100644
index 0000000000000000000000000000000000000000..f890360000b7f6d6d7fb7804d779b4a24199a356
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/dump/dump.py
@@ -0,0 +1,455 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved.
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import inspect
+import json
+import os
+import threading
+from pathlib import Path
+
+import numpy as np
+import torch
+
+try:
+    import torch_npu
+except ImportError:
+    is_gpu = True
+else:
+    is_gpu = False
+
+from atat.core.utils import (print_warn_log, Const, print_info_log, modify_dump_path, check_inplace_op, CompareConst,
+                             print_error_log)
+from atat.core.file_check_util import FileOpen, change_mode, FileCheckConst
+from atat.pytorch.common.utils import get_md5_for_tensor
+from ..dump.utils import check_writable
+from .utils import (DumpUtil, check_if_in_api_list, make_dump_data_dir, get_tensor_rank, create_dirs_if_not_exist,
+                    CompareException, check_single_rank_folder)
+
+
+forward_init_status = False
+backward_init_status = False
+
+thread_lock = threading.Lock()
+pkl_name = ""
+rank = os.getpid() + 100000
+multi_output_apis = ["_sort_", "npu_flash_attention"]
+module_count = {}
+
+
+class APIList(list):
+    threshold = 1000
+
+    def __init__(self, *args):
+        self.dump_count = 0
+        self.pkl_mode_changed = False
+        super().__init__(*args)
+
+    def flush(self):
+        pkl_path = get_pkl_file_path()
+        if len(self) == 0 or pkl_path == "":
+            return
+        with FileOpen(pkl_path, 'a') as f:
+            try:
+                f.write('\n'.join(json.dumps(item) for item in self))
+                f.write('\n')
+            except IOError as ex:
+                raise Exception("write to disk failed") from ex
+        self.dump_count += 1
+        print_info_log(f"write {len(self)} items to {pkl_path} the {self.dump_count} time")
+        if not self.pkl_mode_changed:
+            change_mode(pkl_path, FileCheckConst.DATA_FILE_AUTHORITY)
+            self.pkl_mode_changed = True
+        self.clear()
+
+    def append(self, data):
+        list.append(self, data)
+        if len(self) >= APIList.threshold:
+            self.flush()
+
+
+api_list = APIList()
+
+
+class DataInfo(object):
+    def __init__(self, save_data, summary_data, dtype, shape, md5=None):
+        if md5 is None:
+            md5 = []
+        self.save_data = save_data
+        self.summary_data = summary_data
+        self.dtype = dtype
+        self.shape = shape
+        self.md5 = md5
+
+
+def get_not_float_tensor_info(data):
+    if DumpUtil.summary_mode == "md5":
+        return DataInfo([], [], str(data.dtype), tuple(data.shape), get_md5_for_tensor(data))
+    if data.numel() == 0 or data.dtype == torch.bool:
+        tensor_max = []
+        tensor_min = []
+        tensor_mean = []
+    elif len(data.shape) == 0:
+        item = data.float().item()
+        tensor_max = item
+        tensor_min = item
+        tensor_mean = item
+    else:
+        tensor_max = torch._C._VariableFunctionsClass.max(data).float().item()
+        tensor_min = torch._C._VariableFunctionsClass.min(data).float().item()
+        tensor_mean = torch._C._VariableFunctionsClass.mean(data.float()).float().item()
+    return get_tensor_data_info(data, tensor_max, tensor_min, tensor_mean, CompareConst.NAN)
+
+
+def get_scalar_data_info(data):
+    summary_data = [data, data, data, data]
+    return DataInfo(data, summary_data, str(type(data)), str([]))
+
+
+def get_float_tensor_info(data):
+    if DumpUtil.summary_mode == "md5":
+        return DataInfo([], [], str(data.dtype), tuple(data.shape), get_md5_for_tensor(data))
+    tensor_max = torch._C._VariableFunctionsClass.max(data).float().item()
+    tensor_min = torch._C._VariableFunctionsClass.min(data).float().item()
+    tensor_mean = torch._C._VariableFunctionsClass.mean(data).float().item()
+    tensor_norm = torch._C._VariableFunctionsClass.norm(data).float().item()
+    return get_tensor_data_info(data, tensor_max, tensor_min, tensor_mean, tensor_norm)
+
+
+def get_tensor_data_info(data, *tensor_args):
+    summary_data = []
+    summary_data.extend([*tensor_args])
+    if DumpUtil.summary_mode == "all":
+        saved_tensor = data.contiguous().cpu().detach()
+        if data.dtype == torch.bfloat16:
+            saved_numpy = saved_tensor.to(torch.float32).numpy()
+        else:
+            saved_numpy = saved_tensor.numpy()
+        return DataInfo(saved_numpy, summary_data, str(data.dtype), tuple(data.shape))
+    return DataInfo([], summary_data, str(data.dtype), tuple(data.shape))
+
+
+def dump_tensor(x, prefix, dump_step):
+    if isinstance(x, (tuple, list)) and x:
+        for i, item in enumerate(x):
+            dump_tensor(item, "{}.{}".format(prefix, i), dump_step)
+        return
+    elif isinstance(x, torch.Tensor):
+        if x.is_meta:
+            print_info_log(f"Meta tensor {prefix} is skipped.")
+            return
+        x_clone = x.clone().detach()
+        if x_clone.numel() == 0 or len(x_clone.shape) == 0 or not x_clone.is_floating_point():
+            if DumpUtil.dump_filter_switch == Const.OFF:
+                data_info = get_not_float_tensor_info(x_clone)
+                dump_data_by_rank_count(dump_step, prefix, data_info)
+            else:
+                return
+        else:
+            data_info = get_float_tensor_info(x_clone)
+            dump_data_by_rank_count(dump_step, prefix, data_info)
+
+    elif DumpUtil.dump_filter_switch == Const.OFF:
+        if isinstance(x, bool) or isinstance(x, int) or isinstance(x, float):
+            data_info = get_scalar_data_info(x)
+            dump_data_by_rank_count(dump_step, prefix, data_info)
+
+
+def append_pkl_data(dump_step, prefix, data_info):
+    global api_list
+    thread_lock.acquire()
+    api_list.append([prefix, dump_step, data_info.md5, data_info.dtype, data_info.shape, data_info.summary_data])
+    thread_lock.release()
+
+
+def dump_data(prefix, data_info):
+    if DumpUtil.summary_mode != "all":
+        return
+    output_path = os.path.join(DumpUtil.dump_data_dir, f'{prefix}.npy')
+    try:
+        np.save(output_path, data_info.save_data)
+        change_mode(output_path, FileCheckConst.DATA_FILE_AUTHORITY)
+    except Exception as e:
+        print_warn_log("Dump data failed, error: {}".format(e))
+
+
+def thread_dump_data(prefix, data_info):
+    DumpUtil.dump_thread_pool.submit(dump_data, prefix, data_info)
+
+
+def dump_data_by_rank_count(dump_step, prefix, data_info):
+    print_info_log(f"ptdbg is analyzing rank{rank} api: {prefix}" + " " * 10, end='\r')
+    if DumpUtil.is_single_rank and DumpUtil.dump_thread_pool:
+        thread_dump_data(prefix, data_info)
+    else:
+        dump_data(prefix, data_info)
+    append_pkl_data(dump_step, prefix, data_info)
+
+
+def dump_stack_info(name_template):
+    if check_inplace_op(name_template) and Const.PRE_FORWARD in name_template:
+        return
+
+    stack_str = []
+    try:
+        for (_, path, line, func, code, _) in inspect.stack()[4:]:
+            if code:
+                stack_line = [path, str(line), func, code[0].strip() if code else code]
+            else:
+                stack_line = [path, str(line), func, code]
+            stack_str.append(stack_line)
+    except Exception as e:
+        print_warn_log("Dump stack info failed, error: {}".format(e))
+        stack_str.append('')
+
+    prefix = name_template.format("stack_info")
+    if DumpUtil.dump_switch_mode in Const.DUMP_MODE:
+        complement_set = set(['forward', 'backward', 'input', 'output']) - set(DumpUtil.dump_mode)
+        if not any(mode in prefix for mode in complement_set):
+            api_list.append([prefix, stack_str])
+    else:
+        api_list.append([prefix, stack_str])
+
+
+def dump_api_tensor(dump_step, in_feat, name_template, out_feat):
+    if check_inplace_op(name_template):
+        if Const.PRE_FORWARD in name_template:
+            name_template = name_template.replace(Const.PRE_FORWARD, Const.FORWARD)
+        else:
+            if Const.BACKWARD in name_template and Const.BACKWARD in DumpUtil.dump_mode:
+                return
+            elif Const.BACKWARD not in name_template and Const.FORWARD in DumpUtil.dump_mode:
+                if "output" in DumpUtil.dump_mode:
+                    dump_tensor(in_feat, name_template.format("output"), dump_step)
+                if "input" in DumpUtil.dump_mode:
+                    return
+
+    if Const.BACKWARD in name_template and Const.BACKWARD in DumpUtil.dump_mode:
+        if 'input' in DumpUtil.dump_mode:
+            dump_tensor(out_feat, name_template.format("input"), dump_step)
+        if 'output' in DumpUtil.dump_mode:
+            dump_tensor(in_feat, name_template.format("output"), dump_step)
+    elif Const.BACKWARD not in name_template and Const.FORWARD in DumpUtil.dump_mode:
+        if 'input' in DumpUtil.dump_mode:
+            dump_tensor(in_feat, name_template.format("input"), dump_step)
+        if 'output' in DumpUtil.dump_mode:
+            dump_tensor(out_feat, name_template.format("output"), dump_step)
+
+
+def rename_():
+    global rank
+    global pkl_name
+    if rank is not None and pkl_name is not None:
+        dir_name = os.path.join(DumpUtil.dump_root, "step{}".format(DumpUtil.iter_num), "rank{}".format(os.getpid() + 100000))
+        new_name = os.path.join(DumpUtil.dump_root, "step{}".format(DumpUtil.iter_num), "rank{}".format(rank))
+        if not os.path.exists(new_name) and os.path.exists(dir_name):
+            _, file_name = os.path.split(pkl_name)
+            os.rename(dir_name, new_name)
+            pkl_name = os.path.join(new_name, file_name)
+
+
+def dump_acc_cmp(name, in_feat, out_feat, dump_step, module):
+    if not DumpUtil.get_dump_switch():
+        return
+    if DumpUtil.dump_switch_mode == Const.API_LIST and not check_if_in_api_list(name):
+        return
+    if DumpUtil.dump_switch_mode in [Const.LIST, Const.ACL, Const.RANGE, Const.STACK] and not DumpUtil.check_switch_scope(name):
+        return
+    dump_file = DumpUtil.get_dump_path()
+    dump_file = modify_dump_path(dump_file, DumpUtil.dump_switch_mode)
+    global rank
+    dump_dir, dump_filename = os.path.split(dump_file)
+    dump_dir = os.path.join(dump_dir, "step{}".format(DumpUtil.iter_num))
+    if not os.path.exists(dump_dir):
+        Path(dump_dir).mkdir(mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True)
+    dump_file = os.path.join(dump_dir, dump_filename)
+    rank_this = get_tensor_rank(in_feat, out_feat)
+    DumpUtil.dump_root = os.path.dirname(DumpUtil.dump_path)
+    if rank_this is not None and rank != rank_this:
+        rank = rank_this
+        rename_()
+    if not DumpUtil.dump_init_enable:
+        if '.pkl' in dump_filename:
+            npy_dir = dump_filename[:-4]
+        else:
+            npy_dir = dump_filename
+        DumpUtil.dump_data_dir = os.path.join(DumpUtil.dump_root, "step{}".format(DumpUtil.iter_num), "rank{}".format(rank), npy_dir)
+    if DumpUtil.target_rank is not None:
+        if rank != DumpUtil.target_rank:
+            return
+    dump_file = create_dirs_if_not_exist(rank, dump_file)
+    global pkl_name
+    pkl_name = dump_file
+    if DumpUtil.dump_init_enable:
+        DumpUtil.dump_init_enable = False
+        DumpUtil.dump_data_dir = make_dump_data_dir(dump_file) \
+            if DumpUtil.dump_switch_mode not in [Const.STACK, Const.ACL] and DumpUtil.summary_mode == "all" else ""
+        if os.path.exists(dump_file) and not os.path.isdir(dump_file):
+            check_writable(dump_file)
+            try:
+                os.remove(dump_file)
+            except FileNotFoundError as e:
+                print_warn_log("The file does not exist, error: {}".format(e))
+
+    name_prefix = name
+    name_template = f"{name_prefix}" + "_{}"
+    if DumpUtil.is_single_rank is None:
+        DumpUtil.is_single_rank = check_single_rank_folder(dump_dir)
+    if DumpUtil.dump_switch_mode in [Const.ALL, Const.API_LIST]:
+        dump_api_tensor(dump_step, in_feat, name_template, out_feat)
+    elif DumpUtil.dump_switch_mode == Const.API_STACK:
+        dump_api_tensor(dump_step, in_feat, name_template, out_feat)
+        dump_stack_info(name_template)
+    else:
+        if DumpUtil.dump_switch_mode == Const.ACL:
+            acl_dump(module, name, name_prefix)
+        elif DumpUtil.dump_switch_mode != Const.STACK:
+            dump_api_tensor(dump_step, in_feat, name_template, out_feat)
+        dump_stack_info(name_template)
+
+
+def acl_dump(module, module_name, name_prefix):
+    if name_prefix in DumpUtil.backward_input:
+        dump_mode_backward_acl_dump(module, module_name, DumpUtil.backward_input.get(name_prefix))
+    else:
+        forward_acl_dump(module, module_name)
+
+
+def Op_Need_Trigger(module_name):
+    if 'Tensor___getitem___' in module_name:
+        return True
+    return False
+
+
+def forward_acl_dump(module, module_name):
+    global forward_init_status
+    global backward_init_status
+    if not forward_init_status and not backward_init_status:
+        forward_init_status = True
+        torch_npu.npu.synchronize()
+        torch_npu.npu.init_dump()
+        torch_npu.npu.set_dump(DumpUtil.dump_config)
+        torch_npu.npu.synchronize()
+        if Op_Need_Trigger(module_name):
+            module.forward(*module.input_args, **module.input_kwargs).cpu()
+        else:
+            module.forward(*module.input_args, **module.input_kwargs)
+        torch_npu.npu.synchronize()
+        torch_npu.npu.finalize_dump()
+        torch_npu.npu.synchronize()
+    del module.input_args
+    del module.input_kwargs
+    forward_init_status = False
+    print_info_log("Dump %s op file." % module_name)
+
+
+def acl_backward_dump_status(output, grad, module_name):
+    if isinstance(output, torch.Tensor):
+        output.backward(grad, retain_graph=True)
+        return True
+
+    for api_name in multi_output_apis:
+        if api_name in module_name:
+            output[0].backward(grad, retain_graph=True)
+            return True
+    return False
+
+
+def dump_mode_backward_acl_dump(module, module_name, grad_path):
+    global forward_init_status
+    global backward_init_status
+    module_name = module_name.replace(Const.FORWARD, Const.BACKWARD)
+    if not forward_init_status and not backward_init_status:
+        forward_init_status = True
+        module.input_args = list(module.input_args)
+        for i, data in enumerate(module.input_args):
+            if isinstance(data, torch.Tensor) and data.grad_fn:
+                module.input_args[i] = data.detach().requires_grad_()
+        output = module.forward(*module.input_args, **module.input_kwargs)
+        grad = torch.tensor(np.load(grad_path)).to("npu").requires_grad_()
+        torch_npu.npu.init_dump()
+        torch_npu.npu.set_dump(DumpUtil.dump_config)
+        torch_npu.npu.synchronize()
+        if not acl_backward_dump_status(output, grad, module_name):
" + "you can manually construct a single API backward case for ACL dump.".format(module_name)) + torch_npu.npu.synchronize() + torch_npu.npu.finalize_dump() + del module.input_args + del module.input_kwargs + forward_init_status = False + print_info_log("Dump %s op file." % module_name) + + +def module_count_func(name, name_template): + module_name = name.split("_")[-3] + if Const.FORWARD in name_template: + if module_name not in module_count: + module_count[module_name] = [0, [0]] + else: + if module_count[module_name][-1] and \ + module_count[module_name][0] != module_count[module_name][-1][-1]: + module_count[module_name][-1].pop() + module_count[module_name][0] += 1 + module_count[module_name][-1].append(module_count[module_name][0]) + index = module_count[module_name][0] + else: + backward_stack = module_count[module_name][-1] if module_name in module_count else [] + if not backward_stack: + print_warn_log("The backward stack of {} is empty.".format(module_name)) + index = "abnormal" + else: + index = backward_stack.pop() + return index + + +def acc_cmp_dump(name, **kwargs): + dump_step = kwargs.get('dump_step', 1) + pid = kwargs.get('pid') + name_template = name + if not pid: + return RuntimeError("Not get the specified process pid.") + + def acc_cmp_hook(module, in_feat, out_feat=None): + nonlocal name, name_template + if "_{}_" in name_template: + try: + index = module_count_func(name, name_template) + except IndexError as e: + print_error_log(f"Get module {name_template} index failed.") + raise CompareException(CompareException.INDEX_OUT_OF_BOUNDS_ERROR) from e + name = name_template.format(index) + if pid == os.getpid(): + dump_acc_cmp(name, in_feat, out_feat, dump_step, module) + if hasattr(module, "input_args"): + del module.input_args + if hasattr(module, "input_kwargs"): + del module.input_kwargs + + return acc_cmp_hook + + +def write_to_disk(): + api_list.flush() + + +def get_pkl_file_path(): + return pkl_name + + +def reset_module_count(): + global module_count + module_count = {} diff --git a/debug/accuracy_tools/atat/pytorch/dump/dump_module.py b/debug/accuracy_tools/atat/pytorch/dump/dump_module.py new file mode 100644 index 0000000000000000000000000000000000000000..79a8259f1b4fd2588d4cff0d3e63757e8185a78f --- /dev/null +++ b/debug/accuracy_tools/atat/pytorch/dump/dump_module.py @@ -0,0 +1,29 @@ +import os +import torch.nn as nn +from atat.core.utils import print_error_log, DumpException +from .dump import acc_cmp_dump +from ..hook_module.api_registry import api_register + +module_count = {} + + +def module_dump(module, dump_name): + if not isinstance(module, nn.Module): + print_error_log("The parameter:module in module_dump is not a Module subclass.") + raise DumpException(DumpException.INVALID_PARAM_ERROR) + if not isinstance(dump_name, str): + print_error_log("The parameter:dump_name in module_dump is not a str type.") + raise DumpException(DumpException.INVALID_PARAM_ERROR) + pid = os.getpid() + api_register.api_originality() + if dump_name not in module_count: + module_count[dump_name] = 0 + else: + module_count[dump_name] += 1 + dump_name = dump_name + '_' + str(module_count.get(dump_name)) + "_" + module.register_forward_hook(acc_cmp_dump(dump_name + "forward", pid=pid)) + module.register_backward_hook(acc_cmp_dump(dump_name + "backward", pid=pid)) + + +def module_dump_end(): + api_register.api_modularity() diff --git a/debug/accuracy_tools/atat/pytorch/dump/utils.py b/debug/accuracy_tools/atat/pytorch/dump/utils.py new file mode 100644 index 
diff --git a/debug/accuracy_tools/atat/pytorch/dump/utils.py b/debug/accuracy_tools/atat/pytorch/dump/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e58f35606a4a4f9cf9e7ae732beeedb7777cdef
--- /dev/null
+++ b/debug/accuracy_tools/atat/pytorch/dump/utils.py
@@ -0,0 +1,357 @@
+import os
+import re
+import shutil
+from pathlib import Path
+import torch
+import torch.distributed as dist
+
+from atat.core.utils import print_error_log, CompareException, DumpException, Const, get_time, print_info_log, \
+    check_mode_valid, check_switch_valid, check_dump_mode_valid, check_summary_only_valid, generate_compare_script, \
+    check_file_valid, make_dump_path_if_not_exists, check_path_before_create, check_summary_mode_valid
+from atat.core.file_check_util import FileChecker, FileCheckConst, check_path_length, check_path_pattern_vaild
+from atat.pytorch.common.utils import check_is_npu
+
+from ..dump import dump
+
+dump_count = 0
+range_begin_flag, range_end_flag = False, False
+
+
+def check_list_or_acl_mode(name_prefix):
+    global dump_count
+    for item in DumpUtil.dump_switch_scope:
+        if name_prefix.startswith(item):
+            dump_count = dump_count + 1
+            return True
+    return False
+
+
+def check_range_mode(name_prefix):
+    global range_begin_flag
+    global range_end_flag
+    if name_prefix.startswith(DumpUtil.dump_switch_scope[0]):
+        range_begin_flag = True
+        return True
+    if name_prefix.startswith(DumpUtil.dump_switch_scope[1]):
+        range_end_flag = True
+        return True
+    if range_begin_flag and not range_end_flag:
+        return True
+    return False
+
+
+def check_stack_mode(name_prefix):
+    if len(DumpUtil.dump_switch_scope) == 0:
+        return True
+    elif len(DumpUtil.dump_switch_scope) == 1:
+        return name_prefix.startswith(DumpUtil.dump_switch_scope[0])
+    elif len(DumpUtil.dump_switch_scope) == 2:
+        return check_range_mode(name_prefix)
+    else:
+        print_error_log("dump scope is invalid, Please set the scope mode in"
+                        " set_dump_switch with 'all', 'list', 'range', 'stack', 'acl', 'api_list'!")
+    return False
+
+
+class DumpConfig:
+    def __init__(self, mode=None, scope=None, api_list=None, filter_switch=None, dump_mode=None, summary_only=False, summary_mode="all"):
+        self.mode = mode
+        self.scope = scope
+        self.api_list = api_list
+        self.filter_switch = filter_switch
+        self.dump_mode = dump_mode
+        self.summary_only = summary_only
+        self.summary_mode = summary_mode
+
+
+class DumpUtil(object):
+    dump_root = None
+    dump_data_dir = None
+    dump_path = None
+    dump_switch = None
+    dump_switch_mode = Const.ALL  # all, api_stack, list, stack...
+    dump_switch_scope = []
+    dump_init_enable = False
+    dump_api_list = []
+    dump_filter_switch = None
+    dump_mode = ['forward', 'backward', 'input', 'output']
+    backward_input = {}
+    dump_dir_tag = 'ptdbg_dump'
+    dump_config = None
+    dataloader_iter = 0
+    target_iter = None
+    iter_num = 0
+    target_rank = None
+    summary_only = False
+    need_replicate = False
+    summary_mode = "all"
+    is_single_rank = None
+    dump_thread_pool = None
+
+
+    @staticmethod
+    def set_dump_path(save_path):
+        DumpUtil.dump_path = save_path
+        DumpUtil.dump_init_enable = True
+
+    @staticmethod
+    def set_acl_config(acl_config):
+        if not acl_config:
+            raise ValueError("acl_config must be configured when mode is 'acl'")
+        acl_config_checker = FileChecker(acl_config, FileCheckConst.FILE, FileCheckConst.READ_ABLE,
+                                         FileCheckConst.JSON_SUFFIX)
+        acl_config = acl_config_checker.common_check()
+        DumpUtil.dump_config = acl_config
+
+    @staticmethod
+    def set_dump_switch(switch, dump_config):
+        DumpUtil.dump_switch = switch
+        if dump_config.mode is not None:
+            DumpUtil.dump_switch_mode = dump_config.mode
+            DumpUtil.dump_init_enable = True
+        if dump_config.scope is not None:
+            DumpUtil.dump_switch_scope = dump_config.scope
+        if dump_config.api_list is not None:
+            DumpUtil.dump_api_list = [api.lower() for api in dump_config.api_list]
+        if dump_config.filter_switch is not None:
+            DumpUtil.dump_filter_switch = dump_config.filter_switch
+        if dump_config.dump_mode is not None:
+            DumpUtil.dump_mode = dump_config.dump_mode if isinstance(dump_config.dump_mode, list) else [dump_config.dump_mode]
+
+        if dump_config.mode == Const.ACL:
+            DumpUtil.dump_switch_scope = [api_name.replace("backward", "forward") for api_name in dump_config.scope]
+
+        DumpUtil.summary_only = dump_config.summary_only
+        DumpUtil.summary_mode = dump_config.summary_mode
+
+    check_mapper = {
+        Const.LIST: check_list_or_acl_mode,
+        Const.ACL: check_list_or_acl_mode,
+        Const.RANGE: check_range_mode,
+        Const.STACK: check_stack_mode
+    }
+
+    @staticmethod
+    def check_switch_scope(name_prefix):
+        if DumpUtil.dump_switch_mode in DumpUtil.check_mapper:
+            check_func = DumpUtil.check_mapper[DumpUtil.dump_switch_mode]
+            return check_func(name_prefix)
+        return False
+
+    @staticmethod
+    def get_dump_path():
+        if DumpUtil.dump_path:
+            return DumpUtil.dump_path
+
+        if DumpUtil.dump_switch_mode == Const.ALL:
+            raise RuntimeError("get_dump_path: the file path is empty,"
+                               " you must use set_dump_path to set a valid dump path!!!")
+        else:
+            dir_path = os.path.realpath("./")
+            dump_file_name = "scope_dump_{}_{}_{}.pkl".format(
+                DumpUtil.dump_switch_mode, DumpUtil.dump_switch_scope[0], get_time())
+            DumpUtil.dump_path = os.path.join(dir_path, dump_file_name)
+            return DumpUtil.dump_path
+
+    @staticmethod
+    def get_dump_switch():
+        return DumpUtil.dump_switch == "ON"
+
+
+def set_dump_path(fpath=None, dump_tag='ptdbg_dump'):
+    fpath = load_env_dump_path(fpath)
+    check_file_valid(fpath)
+    if not re.match(Const.FILE_PATTERN, dump_tag):
+        print_error_log('The file path {} contains special characters.'.format(dump_tag))
+        raise CompareException(CompareException.INVALID_PATH_ERROR)
+    real_path = os.path.realpath(fpath)
+    make_dump_path_if_not_exists(real_path)
+    fpath_checker = FileChecker(real_path, FileCheckConst.DIR, FileCheckConst.WRITE_ABLE)
+    fpath_checker.common_check()
+    DumpUtil.set_dump_path(real_path)
+    DumpUtil.dump_dir_tag = dump_tag
+
+
+def get_tensor_rank(in_feat, out_feat):
+    if dist.is_initialized():
+        return dist.get_rank()
+
+    def get_tensor_rank_single(x):
+        if isinstance(x, (list, tuple)):
+            if len(x) > 0:
+                return get_tensor_rank_single(x[0])
+            return None
+        elif isinstance(x, torch.Tensor):
+            device = x.device
+            if device.type == 'cpu':
+                return None
+            else:
+                return device.index
+        return None
+    in_rank = get_tensor_rank_single(in_feat)
+    if in_rank is None:
+        out_rank = get_tensor_rank_single(out_feat)
+        if out_rank is None:
+            return None
+        return out_rank
+    return in_rank
+
+
+def create_dirs_if_not_exist(rank, dump_file):
+    dump_path, file_name = os.path.split(dump_file)
+    rank_dir = os.path.join(dump_path, f"rank{rank}")
+    dump_file = os.path.join(rank_dir, file_name)
+    if not os.path.isdir(rank_dir):
+        check_path_pattern_vaild(dump_file)
+        check_path_length(dump_file, name_length=200)
+        Path(rank_dir).mkdir(mode=FileCheckConst.DATA_DIR_AUTHORITY, exist_ok=True)
+    return dump_file
+
+
+def generate_dump_path_str():
+    if DumpUtil.dump_switch_mode == 'acl':
+        if DumpUtil.dump_config == '':
+            print_error_log("Please provide dump config for register hook before turning on dump switch!")
+            raise DumpException(DumpException.NONE_ERROR)
+        dump_path = f"according to dump config {DumpUtil.dump_config}"
+    else:
+        dump_dir, dump_file = os.path.split(DumpUtil.dump_path)
+        if not dump_file.endswith(".pkl"):
+            dump_dir = DumpUtil.dump_path
+        dump_path = f"to {dump_dir}"
+    return dump_path
+
+
+def set_dump_switch(switch, mode=Const.ALL, scope=None, api_list=None, filter_switch=Const.OFF, dump_mode=None,
+                    summary_only=False):
+    if scope is None:
+        scope = []
+    if api_list is None:
+        api_list = []
+    if dump_mode is None:
+        dump_mode = [Const.ALL]
+    check_switch_valid(switch)
+    if not DumpUtil.dump_path:
+        set_dump_path()
+    dump_config = DumpConfig(summary_only=summary_only)
+    DumpUtil.set_dump_switch(switch, dump_config)
+    dump_path_str = generate_dump_path_str()
+    if switch == "OFF":
+        dump.write_to_disk()
+        if check_is_npu() and DumpUtil.dump_switch_mode in [Const.ALL, Const.API_STACK, Const.LIST, Const.RANGE, Const.API_LIST]:
+            generate_compare_script(DumpUtil.dump_data_dir, dump.get_pkl_file_path(), DumpUtil.dump_switch_mode)
+    set_dump_switch_print_info(switch, mode, dump_path_str)
+    set_dump_switch_config(mode=mode, scope=scope, api_list=api_list, filter_switch=filter_switch, dump_mode=dump_mode,
+                           summary_only=summary_only)
+
+
+def set_dump_switch_config(mode=Const.ALL, scope=None, api_list=None, filter_switch=Const.OFF, dump_mode=None,
+                           summary_only=False, summary_mode="all"):
+    if scope is None:
+        scope = []
+    if api_list is None:
+        api_list = []
+    if dump_mode is None:
+        dump_mode = [Const.ALL]
+    try:
+        check_summary_mode_valid(summary_mode)
+        check_mode_valid(mode, scope, api_list)
+        check_switch_valid(filter_switch)
+        dump_mode = check_dump_mode_valid(dump_mode)
+        summary_only = check_summary_only_valid(summary_only)
+    except (CompareException, AssertionError) as err:
+        print_error_log(str(err))
+        raise CompareException(CompareException.INVALID_PARAM_ERROR) from err
+    switch = DumpUtil.dump_switch
+    dump_config = DumpConfig(mode, scope, api_list, filter_switch, dump_mode, summary_only, summary_mode)
+    DumpUtil.set_dump_switch("OFF", dump_config)
+    DumpUtil.dump_switch = switch
+
+
+def set_dump_switch_print_info(switch, mode, dump_path_str):
+    global dump_count
+    if switch == "ON":
+        print_info_log(f"Dump switch is turned on. Dump data will be saved {dump_path_str}. ")
+        if mode == Const.LIST:
+            dump_count = 0
+    else:
") + if mode == Const.LIST: + print_info_log("The number of matched dump is {}".format(dump_count)) + + +def check_if_in_api_list(name): + if not DumpUtil.dump_api_list: + return False + for api in DumpUtil.dump_api_list: + if api.lower() in name.lower(): + return True + return False + + +def set_backward_input(backward_input): + for index, api_name in enumerate(DumpUtil.dump_switch_scope): + DumpUtil.backward_input[api_name] = backward_input[index] + + +def make_dump_data_dir(dump_file_name): + dump_path, file_name = os.path.split(os.path.realpath(dump_file_name)) + name_body, name_extension = os.path.splitext(file_name) + output_dir = os.path.join(dump_path, f"{name_body}") + check_path_before_create(output_dir) + if not os.path.exists(output_dir): + Path(output_dir).mkdir(mode=0o750, exist_ok=True) + else: + shutil.rmtree(output_dir, ignore_errors=True) + Path(output_dir).mkdir(mode=0o750, exist_ok=True) + return output_dir + + +def make_dump_dirs(): + dump_file_name, dump_file_name_body = "dump.pkl", "dump" + dump_root_dir = load_env_dump_path(DumpUtil.dump_path) + tag_dir = os.path.join(dump_root_dir, DumpUtil.dump_dir_tag) + check_path_length(tag_dir) + check_path_pattern_vaild(tag_dir) + Path(tag_dir).mkdir(mode=0o750, parents=True, exist_ok=True) + DumpUtil.dump_dir = tag_dir + dump_file_path = os.path.join(tag_dir, dump_file_name) + DumpUtil.set_dump_path(dump_file_path) + + +def check_writable(dump_file): + if not os.access(dump_file, os.W_OK): + print_error_log( + 'The path {} does not have permission to write. Please check the path permission'.format( + dump_file)) + raise DumpException(DumpException.INVALID_PATH_ERROR) + + +def load_env_dump_path(dump_path): + if not dump_path: + dump_path = os.getenv(Const.ASCEND_WORK_PATH) + if dump_path: + try: + dump_path = os.path.join(str(dump_path), Const.DUMP_DIR) + except TypeError as err: + print_error_log("Generating dump path from environment variables ASCEND_WORK_PATH failed.") + raise DumpException(DumpException.INVALID_PATH_ERROR) from err + else: + print_error_log("Dump path is None, you can configure it in the following ways:\n" + "1. Configure set_dump_path function.\n" + "2. Configure the dump_path parameter of PrecisionDebugger.\n" + "3. Set environment variables ASCEND_WORK_PATH.") + raise DumpException(DumpException.INVALID_PATH_ERROR) + return dump_path + + +def check_single_rank_folder(dump_path): + rank_folder_pattern = re.compile(r'^rank\d+$') + rank_folder_count = 0 + for item in os.listdir(dump_path): + full_path = os.path.join(dump_path, item) + if os.path.isdir(full_path) and rank_folder_pattern.match(item): + rank_folder_count += 1 + if rank_folder_count > 1: + return False + return rank_folder_count == 1