diff --git a/debug/accuracy_tools/grad_tool/common/base_comparator.py b/debug/accuracy_tools/grad_tool/common/base_comparator.py index d3254ae71f9a8fccb8608088462c4733c166814d..7cdb87e44889d83e2680a634cff082b9bd04d9e6 100644 --- a/debug/accuracy_tools/grad_tool/common/base_comparator.py +++ b/debug/accuracy_tools/grad_tool/common/base_comparator.py @@ -8,6 +8,9 @@ import matplotlib.pyplot as plt from grad_tool.common.constant import GradConst from grad_tool.common.utils import write_csv, check_file_or_directory_path, print_info_log, create_directory +from api_accuracy_checker.common.utils import print_error_log +from ptdbg_ascend.src.python.ptdbg_ascend.common import file_check_util +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst class BaseComparator(ABC): @@ -85,8 +88,17 @@ class BaseComparator(ABC): picture_dir = os.path.join(output_dir, "similarities_picture") if not os.path.isdir(picture_dir): create_directory(picture_dir) - plt.savefig(os.path.join(picture_dir, f"{key}_similarities.png")) - plt.close() + file_path= os.path.join(picture_dir, f"{key}_similarities.png") + if os.path.exists(file_path): + raise ValueError(f"File {file_path} already exists") + try: + plt.savefig(file_path) + plt.close() + except Exception as e: + error_message = "An unexpected error occurred: %s when savfig to %s" % (str(e), file_path) + print_error_log(error_message) + full_path = os.path.realpath(file_path) + file_check_util.change_mode(full_path, FileCheckConst.DATA_FILE_AUTHORITY) head_tuple = tuple(['step'] + [str(step) for step in steps]) write_csv(os.path.join(output_dir, "similarities.csv"), [[key] + value], head_tuple) diff --git a/debug/accuracy_tools/grad_tool/common/utils.py b/debug/accuracy_tools/grad_tool/common/utils.py index f40f8688c2458fa17a5dc2db1ac999c9dc9ab878..43b63676e35c175ccabcd490c218dda02f221062 100644 --- a/debug/accuracy_tools/grad_tool/common/utils.py +++ b/debug/accuracy_tools/grad_tool/common/utils.py @@ -7,6 +7,7 @@ import yaml import pandas as pd from grad_tool.common.constant import GradConst +from msprobe.core.common.file_check import FileOpen def _print_log(level, msg, end='\n'): @@ -114,7 +115,7 @@ class ListCache(list): def get_config(filepath): - with open(filepath, 'r') as file: + with FileOpen(filepath, 'r') as file: config = yaml.safe_load(file) return config diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py index 895b8f2ae68e94fc5e3228566e466a5a597e5ee9..966b8e64108681d3edabfeb6da0f79e74e0b7ea3 100644 --- a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py +++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py @@ -16,6 +16,7 @@ from grad_tool.common.utils import ListCache, print_warn_log from grad_tool.common.utils import create_directory, check_file_or_directory_path, write_csv from grad_tool.grad_ms.global_context import grad_context from grad_tool.grad_ms.global_context import GlobalContext +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker def get_rank_id(): @@ -170,6 +171,8 @@ class CSVGenerator(Process): max_try = 10 while max_try: try: + file_path_checker = FileChecker(file_path, FileCheckConst.DIR) + file_path = file_path_checker.common_check() stat_data = np.load(file_path) return stat_data except Exception as err: @@ -177,7 +180,7 @@ class CSVGenerator(Process): max_try -= 1 time.sleep(0.1) return stat_data - + def gen_csv_line(self, file_path: str, stat_data) -> None: shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX]) file_name = os.path.basename(file_path) diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index daeda8898798559eac970e4163664c08953fc341..18bfad542dde055bf65e65495a0bf6b2c236ddb6 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -67,7 +67,7 @@ class Service: if not self.switch: return if self.data_collector: - module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_input, grad_output=grad_output) + module_input_output = ModuleBackwardInputsOutputs(grad_input=grad_output, grad_output=grad_input) self.data_collector.backward_data_collect(api_or_module_name, module, pid, module_input_output) pid = os.getpid()