diff --git a/debug/accuracy_tools/grad_tool/common/__init__.py b/debug/accuracy_tools/grad_tool/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/grad_tool/common/base_comparator.py b/debug/accuracy_tools/grad_tool/common/base_comparator.py
new file mode 100644
index 0000000000000000000000000000000000000000..36ecc320c63e8b034ba60a3f0e6d1f49c65999c2
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/common/base_comparator.py
@@ -0,0 +1,131 @@
+import os
+from typing import Dict, List
+from abc import ABC, abstractmethod
+
+from tqdm import tqdm
+import matplotlib.pyplot as plt
+
+from grad_tool.common.constant import GradConst
+from grad_tool.common.utils import write_csv, check_file_or_directory_path, print_info_log, create_directory
+
+
+class BaseComparator(ABC):
+
+    @classmethod
+    def compare_distributed(cls, path1: str, path2: str, output_dir: str):
+        ranks = cls._get_matched_dirs(path1, path2, "rank")
+        print_info_log(f"the following ranks will be compared: {ranks}")
+        if not ranks:
+            raise RuntimeError("no matched ranks for comparison, please dump data in same configuration")
+        if not os.path.isdir(output_dir):
+            create_directory(output_dir)
+        for rank in tqdm(ranks, desc="rank"):
+            print_info_log(f"now comparing rank {rank}:")
+            cls.compare(os.path.join(path1, f"rank_{rank}"),
+                        os.path.join(path2, f"rank_{rank}"),
+                        os.path.join(output_dir, f"rank_{rank}"))
+
+    @classmethod
+    def compare(cls, path1: str, path2: str, output_dir: str):
+        steps = cls._get_matched_dirs(path1, path2, "step")
+        if not steps:
+            raise RuntimeError("no matched steps for comparison, please dump data in same configuration")
+        similarities = cls._calculate_separated_similarities(path1, path2, steps)
+        if not os.path.isdir(output_dir):
+            create_directory(output_dir)
+        cls._save_similarities(similarities, steps, output_dir)
+
+    @classmethod
+    def _get_matched_dirs(cls, path1: str, path2: str, dir_prefix):
+        check_file_or_directory_path(path1, file_type=GradConst.DIR)
+        check_file_or_directory_path(path2, file_type=GradConst.DIR)
+        dirs = []
+        for dirname in os.listdir(path1):
+            splits = dirname.split('_')
+            if len(splits) < 2 or splits[0] != dir_prefix or not splits[1].isdigit():
+                continue
+
+            folder2 = os.path.join(path2, dirname)
+            if not os.path.isdir(folder2):
+                continue
+            dirs.append(int(splits[1]))
+        dirs = sorted(dirs)
+        return dirs
+
+    @classmethod
+    def _save_similarities(cls, similarities: Dict[str, List[float]], steps: List[int], output_dir: str):
+        if not similarities:
+            raise ValueError("length of similarities is 0")
+        for key, value in tqdm(similarities.items(), desc="save similarities (by param)"):
+            if len(value) != len(steps):
+                raise RuntimeError(f"similarities length of {key}:{len(value)} not equal steps:{len(steps)}")
+            plt.plot(steps, value)
+            plt.xlabel('steps')
+            plt.ylabel('similarities')
+            plt.title(f'{key}_similarities')
+            picture_dir = os.path.join(output_dir, "similarities_picture")
+            if not os.path.isdir(picture_dir):
+                create_directory(picture_dir)
+            plt.savefig(os.path.join(picture_dir, f"{key}_similarities.png"))
+            plt.close()
+            head_tuple = tuple(['step'] + [str(step) for step in steps])
+            write_csv(os.path.join(output_dir, "similarities.csv"), [[key] + value], head_tuple)
+
+    @classmethod
+    def _calculate_separated_similarities(cls, path1, path2, steps):
+        similarities = {}
+        print_info_log(f"{len(steps)} steps will be compared")
+        for step in tqdm(steps, desc="calculate similarities (by step)"):
+            grad_files = cls._get_matched_grad_files(path1, path2, step)
+            same_count_summary = 0
+            total_count_summary = 0
+            for grad_file in grad_files:
+                grad1 = os.path.join(path1, f"step_{step}", grad_file)
+                grad2 = os.path.join(path2, f"step_{step}", grad_file)
+                same_count, total_count = cls._calculate_similarity(grad1, grad2)
+                same_count_summary += same_count
+                total_count_summary += total_count
+                idx = grad_file.rfind(".")
+                param_name = grad_file[:idx]
+                if param_name not in similarities:
+                    similarities[param_name] = []
+                if total_count == 0:
+                    similarities[param_name].append(0)
+                else:
+                    similarities[param_name].append(same_count / total_count)
+            if GradConst.SUMMARY not in similarities:
+                similarities[GradConst.SUMMARY] = []
+            if total_count_summary == 0:
+                similarities[GradConst.SUMMARY].append(0)
+            else:
+                similarities[GradConst.SUMMARY].append(same_count_summary / total_count_summary)
+        return similarities
+
+    @classmethod
+    def _get_matched_grad_files(cls, path1: str, path2: str, step: int):
+        path1 = os.path.join(path1, f"step_{step}")
+        path2 = os.path.join(path2, f"step_{step}")
+        check_file_or_directory_path(path1, file_type=GradConst.DIR)
+        check_file_or_directory_path(path2, file_type=GradConst.DIR)
+        grad_files = []
+        for grad_file in os.listdir(path1):
+            splits = grad_file.split('.')
+            if len(splits) < 2 or splits[-1] not in GradConst.GRAD_FILE_SUFFIX:
+                continue
+            folder2 = os.path.join(path2, grad_file)
+            if not os.path.exists(folder2):
+                continue
+            grad_files.append(grad_file)
+        return sorted(grad_files)
+
+    @classmethod
+    def _calculate_similarity(cls, grad_file1: str, grad_file2: str):
+        npy1, npy2 = cls._load_grad_files(grad_file1, grad_file2)
+        same_count = (npy1 == npy2).sum()
+        total_count = npy1.size
+        return same_count, total_count
+
+    @classmethod
+    @abstractmethod
+    def _load_grad_files(cls, grad_file1: str, grad_file2: str):
+        raise NotImplementedError("_load_grad_files is not implemented.")
diff --git a/debug/accuracy_tools/grad_tool/common/base_monitor.py b/debug/accuracy_tools/grad_tool/common/base_monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..1faeba79a3b99655be1053127cd82d73849e4a8b
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/common/base_monitor.py
@@ -0,0 +1,13 @@
+from abc import ABC, abstractmethod
+
+from grad_tool.common.utils import get_config
+
+
+class BaseMonitor(ABC):
+
+    def __init__(self, config_file):
+        self.config = get_config(config_file)
+
+    @abstractmethod
+    def monitor(self, module):
+        raise NotImplementedError("monitor is not implemented.")
diff --git a/debug/accuracy_tools/grad_tool/common/constant.py b/debug/accuracy_tools/grad_tool/common/constant.py
new file mode 100644
index 0000000000000000000000000000000000000000..902f54f5e65d4503c3197930491ea2882b2d42b8
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/common/constant.py
@@ -0,0 +1,48 @@
+
+class GradConst:
+
+    FRAMEWORKS = {"PyTorch", "MindSpore"}
+    PYTORCH = "PyTorch"
+    MINDSPORE = "MindSpore"
+
+    GRAD_FILE_SUFFIX = {"npy", "pt"}
+    NPY_SUFFIX = "npy"
+    PT_SUFFIX = "pt"
+
+    # for callback
+    CURRENT_STEP = "current_step"
+
+    PARAM_LIST = "param_list"
+    RANK = "rank"
+    STEP = "step"
+    BOUNDS = "bounds"
+    OUTPUT_PATH = "output_path"
+
+    # level const
+    LEVEL = "level"
+    LEVEL0 = "L0"
+    LEVEL1 = "L1"
+    LEVEL2 = "L2"
+    LEVEL3 = "L3"
+    SUPPORTED_LEVEL = {"L0", "L1", "L2", "L3"}
+
+    # numpy encoding
+    STEP_IDX = 0
+    SHAPE_DIM_IDX = 4
+    MAX_SIZE = 10 * 1024 * 1024 * 1024
+
+    # direction suffix
+    DIR_SUFFIX = "dir.npy"
+
+    # file safety
+    DATA_DIR_AUTHORITY = 0o750
+    DATA_FILE_AUTHORITY = 0o640
+    DIRECTORY_LENGTH = 4096
+    FILE_NAME_LENGTH = 255
+    FILE_VALID_PATTERN = r"^[a-zA-Z0-9_.:/-]+$"
+    DIR = "dir"
+    FILE = "file"
+
+    STEP_FINISH = "step_finish"
+
+    SUMMARY = "summary"
diff --git a/debug/accuracy_tools/grad_tool/common/utils.py b/debug/accuracy_tools/grad_tool/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b63e95578a13b54a7de5f9ca3886d10b72418e0a
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/common/utils.py
@@ -0,0 +1,215 @@
+import os
+import re
+import sys
+import time
+import yaml
+
+import pandas as pd
+
+from grad_tool.common.constant import GradConst
+
+
+def _print_log(level, msg, end='\n'):
+    current_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(int(time.time())))
+    pid = os.getpid()
+    print(current_time + "(" + str(pid) + ")-[" + level + "]" + msg, end=end)
+    sys.stdout.flush()
+
+
+def print_info_log(info_msg, end='\n'):
+    """
+    Function Description:
+        print info log.
+    Parameter:
+        info_msg: the info message.
+    """
+    _print_log("INFO", info_msg, end=end)
+
+
+def print_error_log(error_msg):
+    """
+    Function Description:
+        print error log.
+    Parameter:
+        error_msg: the error message.
+    """
+    _print_log("ERROR", error_msg)
+
+
+def print_warn_log(warn_msg):
+    """
+    Function Description:
+        print warn log.
+    Parameter:
+        warn_msg: the warning message.
+    """
+    _print_log("WARNING", warn_msg)
+
+
+def write_csv(filepath, content_list, header):
+    if not os.path.exists(filepath):
+        make_file_safety(filepath)
+        data_frame = pd.DataFrame(columns=header)
+        data_frame.to_csv(filepath, index=False)
+
+    check_file_or_directory_path(filepath)
+    new_data = pd.DataFrame(list(content for content in content_list))
+    new_data.to_csv(filepath, mode='a+', header=False, index=False)
+
+
+def make_file_safety(file_path: str, permission=0o640):
+    if os.path.islink(file_path):
+        raise RuntimeError(f"Invalid soft link path: {file_path}")
+    file_real_path = os.path.realpath(file_path)
+    if os.path.exists(file_real_path):
+        return
+    parent_path = os.path.dirname(file_real_path)
+    if not os.path.exists(parent_path):
+        os.makedirs(parent_path, mode=GradConst.DATA_DIR_AUTHORITY, exist_ok=True)
+    if not os.access(parent_path, os.W_OK):
+        raise PermissionError(f"The path {parent_path} is not writable!")
+    try:
+        os.close(os.open(file_real_path, os.O_WRONLY | os.O_CREAT, permission))
+    except OSError as e:
+        raise RuntimeError("Can't create file: " + file_real_path) from e
+    os.chmod(file_real_path, permission)
+
+
+def data_in_list_target(data, lst):
+    return not lst or data in lst
+
+
+def check_numeral_list_ascend(lst):
+    if any(not isinstance(item, (int, float)) for item in lst):
+        raise ValueError("The input list should only contain numbers")
+    if lst != sorted(lst):
+        raise ValueError("The input list should be ascending")
+
+
+class ListCache(list):
+    threshold = 1000
+
+    def __init__(self, *args):
+        super().__init__(*args)
+        self._output_file = None
+
+    def __del__(self):
+        self.flush()
+
+    def flush(self):
+        if len(self) == 0:
+            return
+        if not self._output_file:
+            print_warn_log("dump file path is not set")
+            return
+        write_csv(self._output_file, self, [])
+        print_info_log(f"write {len(self)} items to {self._output_file}.")
+        self.clear()
+
+    def append(self, data):
+        list.append(self, data)
+        if len(self) >= ListCache.threshold:
+            self.flush()
+
+    def set_output_file(self, output_file):
+        self._output_file = output_file
+
+
+def get_config(filepath):
+    with open(filepath, 'r') as file:
+        config = yaml.safe_load(file)
+    return config
+
+
+def check_link(path):
+    abs_path = os.path.abspath(path)
+    if os.path.islink(abs_path):
+        raise RuntimeError("The path is a soft link.")
+
+
+def check_path_length(path, name_length=None):
+    file_max_name_length = name_length if name_length else GradConst.FILE_NAME_LENGTH
+    if len(path) > GradConst.DIRECTORY_LENGTH or \
+            len(os.path.basename(path)) > file_max_name_length:
+        raise RuntimeError("The file path length exceeds limit.")
+
+
+def check_path_pattern_valid(path):
+    if not re.match(GradConst.FILE_VALID_PATTERN, path):
+        raise RuntimeError("The file path contains special characters.")
+
+
+def check_path_readability(path):
+    if not os.access(path, os.R_OK):
+        raise RuntimeError("The file path is not readable.")
+
+
+def check_path_writability(path):
+    if not os.access(path, os.W_OK):
+        raise RuntimeError("The file path is not writable.")
+
+
+def check_path_owner_consistent(path):
+    file_owner = os.stat(path).st_uid
+    if file_owner != os.getuid():
+        raise RuntimeError("The file path may be insecure because it does not belong to you.")
+
+
+def _user_interactive_confirm(message):
+    while True:
+        check_message = input(message + " Enter 'c' to continue or enter 'e' to exit: ")
+        if check_message == "c":
+            break
+        elif check_message == "e":
+            print_warn_log("User canceled.")
+            raise RuntimeError("User canceled.")
+        else:
+            print("Invalid input, please enter 'c' or 'e'.")
+
+
+def check_file_size(file_path, max_size=GradConst.MAX_SIZE):
+    file_size = os.path.getsize(file_path)
+    if file_size >= max_size:
+        _user_interactive_confirm(f'The size of file path {file_path} exceeds {max_size} bytes. '
+                                  f'Do you want to continue?')
+
+
+def check_path_type(file_path, file_type):
+    if file_type == GradConst.FILE:
+        if not os.path.isfile(file_path):
+            raise RuntimeError("The path should be a file!")
+    if file_type == GradConst.DIR:
+        if not os.path.isdir(file_path):
+            raise RuntimeError("The path should be a directory!")
+
+
+def check_path_exists(path):
+    if not os.path.exists(path):
+        raise RuntimeError("The file path does not exist.")
+
+
+def path_valid_check(path):
+    check_path_length(path)
+    check_path_pattern_valid(path)
+
+
+def check_file_or_directory_path(path, file_type=GradConst.FILE):
+    check_link(path)
+    check_path_exists(path)
+    check_path_length(path)
+    check_path_pattern_valid(path)
+    check_path_owner_consistent(path)
+    check_path_type(path, file_type)
+    if file_type == GradConst.FILE:
+        check_path_readability(path)
+        check_file_size(path)
+    else:
+        check_path_writability(path)
+
+
+def create_directory(dir_path):
+    dir_path = os.path.realpath(dir_path)
+    try:
+        os.makedirs(dir_path, mode=GradConst.DATA_DIR_AUTHORITY, exist_ok=True)
+    except OSError as ex:
+        raise RuntimeError("Failed to create directory. Please check the path permission or disk space.") from ex
diff --git a/debug/accuracy_tools/grad_tool/grad_comparator.py b/debug/accuracy_tools/grad_tool/grad_comparator.py
index f2daf988cc4b45c48373f9f00a03d826da802513..868c03c255bedeae7c6daf9e8ad27f00f2614033 100644
--- a/debug/accuracy_tools/grad_tool/grad_comparator.py
+++ b/debug/accuracy_tools/grad_tool/grad_comparator.py
@@ -1,127 +1,28 @@
 import os
 
 import torch
-from tqdm import tqdm
-import matplotlib.pyplot as plt
-from grad_tool.utils import write_csv, path_check, print_info_log, create_directory
+import numpy as np
+
+from grad_tool.common.constant import GradConst
 
-class GradComparator:
-    @staticmethod
-    def compare_distributed(path1: str, path2: str, output_dir):
-        ranks = GradComparator._get_matched_dirs(path1, path2, "rank")
-        print_info_log(f"the following ranks will be compared: {ranks}")
-        if not ranks:
-            raise Exception("no matched ranks for comparison, please dump data in same configuration")
-        if not os.path.isdir(output_dir):
-            create_directory(output_dir)
-        for rank in tqdm(ranks, desc="rank"):
-            print_info_log(f"now comparing rank {rank}:")
-            GradComparator.compare(os.path.join(path1, f"rank_{rank}"),
-                                   os.path.join(path2, f"rank_{rank}"),
-                                   os.path.join(output_dir, f"rank_{rank}"))
-
-    @staticmethod
-    def compare(path1: str, path2: str, output_dir):
-        steps = GradComparator._get_matched_dirs(path1, path2, "step")
-        if not steps:
-            raise Exception("no matched steps for comparison, please dump data in same configuration")
-        similarities = GradComparator._calculate_separated_similarities(path1, path2, steps)
-        if not os.path.isdir(output_dir):
-            create_directory(output_dir)
-        GradComparator._save_similarities(similarities, steps, output_dir)
-
-    @staticmethod
-    def _calculate_separated_similarities(path1, path2, steps):
-        similarities = {}
-        print_info_log(f"{len(steps)} steps will be compared")
-        for step in tqdm(steps, desc="culculate similarities (by step)"):
-            pt_files = GradComparator._get_matched_pt_files(path1, path2, step)
-            same_count_summary = 0
-            total_count_summary = 0
-            for pt_file in pt_files:
-                pt1 = os.path.join(path1, f"step_{step}", pt_file)
-                pt2 = os.path.join(path2, f"step_{step}", pt_file)
-                same_count, total_count = GradComparator._calculate_similarity(pt1, pt2)
-                same_count_summary += same_count
-                total_count_summary += total_count
-                param_name = pt_file[:-3]
-                if param_name not in similarities:
-                    similarities[param_name] = []
-                if total_count == 0:
-                    similarities[param_name].append(0)
-                else:
-                    similarities[param_name].append(same_count / total_count)
-            if "summary" not in similarities:
-                similarities["summary"] = []
-            if total_count_summary == 0:
-                similarities["summary"].append(0)
-            else:
-                similarities["summary"].append(same_count_summary / total_count_summary)
-        return similarities
 
-    @staticmethod
-    def _get_matched_dirs(path1: str, path2: str, dir_prefix):
-        path_check(path1, isdir=True)
-        path_check(path2, isdir=True)
-        dirs = []
-        for dirname in os.listdir(path1):
-            splits = dirname.split('_')
-            if not splits or splits[0] != dir_prefix or not splits[1].isdigit():
-                continue
-
-            folder2 = os.path.join(path2, dirname)
-            if not os.path.isdir(folder2):
-                continue
-            dirs.append(int(splits[1]))
-        dirs = sorted(dirs)
-        return dirs
-
-    @staticmethod
-    def _get_matched_pt_files(path1: str, path2: str, step: int):
-        path1 = os.path.join(path1, f"step_{step}")
-        path2 = os.path.join(path2, f"step_{step}")
-        path_check(path1, isdir=True)
-        path_check(path2, isdir=True)
-        pt_files = []
-        for dirname in os.listdir(path1):
-            splits = dirname.split('.')
-            if len(splits) < 1 or splits[-1] != 'pt':
-                continue
-            folder2 = os.path.join(path2, dirname)
-            if not os.path.exists(folder2):
-                continue
-            pt_files.append(dirname)
-        return sorted(pt_files)
+class GradComparator:
 
     @staticmethod
-    def _save_similarities(similarities: [float], steps: [int], output_dir: str):
-        if not similarities:
-            raise Exception(f"length of similarities is 0")
-        for key, value in tqdm(similarities.items(), desc="save similarities (by param)"):
-            if len(value) != len(steps):
-                raise Exception(f"similarities length of {key}:{len(value)} not equal steps:{len(steps)}")
-            plt.plot(steps, value)
-            plt.xlabel('steps')
-            plt.ylabel('similarities')
-            plt.title(f'{key}_similarities')
-            picture_dir = os.path.join(output_dir, "similarities_picture")
-            if not os.path.isdir(picture_dir):
-                create_directory(picture_dir)
-            plt.savefig(os.path.join(picture_dir, f"{key}_similarities.png"))
-            plt.close()
-            head_tuple = tuple(['step'] + [str(step) for step in steps])
-            write_csv(os.path.join(output_dir, "similarities.csv"), [[key] + value], head_tuple)
+    def compare(path1: str, path2: str, output_dir: str, framework="PyTorch"):
+        if framework not in GradConst.FRAMEWORKS:
+            raise RuntimeError(f"{framework} is not supported! Choose from {GradConst.FRAMEWORKS}.")
+        if framework == GradConst.PYTORCH:
+            from grad_tool.grad_pt.grad_comparator import PtGradComparator as grad_comparator
+        else:
+            from grad_tool.grad_ms.grad_comparator import MsGradComparator as grad_comparator
+        grad_comparator.compare(path1, path2, output_dir)
 
     @staticmethod
-    def _calculate_similarity(pt_file1: str, pt_file2: str):
-        tensor1 = torch.load(pt_file1, map_location=torch.device("cpu"))
-        tensor2 = torch.load(pt_file2, map_location=torch.device("cpu"))
-        if tensor1.shape != tensor2.shape:
-            raise Exception(f"tensor shape is not equal: {pt_file1}, {pt_file2}")
-        if tensor1.dtype != torch.bool:
-            raise Exception(f"tensor type is not bool: {pt_file1}")
-        if tensor2.dtype != torch.bool:
-            raise Exception(f"tensor type is not bool: {pt_file2}")
-        same_count = (tensor1 == tensor2).sum().item()
-        total_count = tensor1.numel()
-        return same_count, total_count
+    def compare_distributed(path1: str, path2: str, output_dir: str, framework="PyTorch"):
+        if framework not in GradConst.FRAMEWORKS:
+            raise RuntimeError(f"{framework} is not supported! Choose from {GradConst.FRAMEWORKS}.")
+        if framework == GradConst.PYTORCH:
+            from grad_tool.grad_pt.grad_comparator import PtGradComparator as grad_comparator
+        else:
+            from grad_tool.grad_ms.grad_comparator import MsGradComparator as grad_comparator
+        grad_comparator.compare_distributed(path1, path2, output_dir)
diff --git a/debug/accuracy_tools/grad_tool/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_monitor.py
index 9324d60932e80d2aa2c2063f15c41ee26b6b2529..8b6c6bb4a3c70ede7b0a58b776a70be265544a5e 100644
--- a/debug/accuracy_tools/grad_tool/grad_monitor.py
+++ b/debug/accuracy_tools/grad_tool/grad_monitor.py
@@ -1,68 +1,18 @@
-import os
-from collections import defaultdict
-import torch
-from torch.optim.optimizer import register_optimizer_step_pre_hook
-from grad_tool.level_adapter import Level, LevelAdapter
-from grad_tool.grad_stat_csv import GradStatCsv
-from grad_tool.utils import get_config, check_numeral_list_ascend, data_in_list_target,\
-    write_csv, get_rank_id, print_info_log, create_directory, print_warn_log, print_rank_0
+from grad_tool.common.constant import GradConst
+from grad_tool.common.utils import print_warn_log
 
 
 class GradientMonitor:
-    default_bounds = [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10]
 
-    def __init__(self, config_filepath):
-        config = get_config(config_filepath)
-        self._level_adp: Level = LevelAdapter.level_adapter(config.get("level"))
-        self._param_list = config.get('param_list')
-        self._target_ranks = config.get("rank")
-        print_info_log(f"target rank {self._target_ranks}")
-        self._target_step = config.get("step")
-        print_info_log(f"target step {self._target_step}")
-        self._bounds = config.get("bounds")
-        if not self._bounds or len(self._bounds) == 0:
-            self._bounds = GradientMonitor.default_bounds
-        check_numeral_list_ascend(self._bounds)
-        self._output_path = config.get("output_path")
-        if not os.path.isdir(self._output_path):
-            create_directory(self._output_path)
+    def __init__(self, config_path, framework="PyTorch") -> None:
+        self.framework = framework
+        if self.framework not in GradConst.FRAMEWORKS:
+            raise RuntimeError(f"{self.framework} is not supported! Choose from {GradConst.FRAMEWORKS}.")
+        if self.framework == GradConst.PYTORCH:
+            from grad_tool.grad_pt.grad_monitor import PtGradientMonitor as grad_monitor
         else:
-            print_warn_log(f"the file in {self._output_path} will be recoverd")
-        self._step = -1
-        self._param2name = defaultdict(str)
-
-    def _rank_in_targets(self):
-        if not hasattr(self, "_rank"):
-            raise AttributeError("grad monitor need attribute {_rank}")
-        return not torch.distributed.is_initialized() or data_in_list_target(getattr(self, "_rank"), self._target_ranks)
+            from grad_tool.grad_ms.grad_monitor import MsGradientMonitor as grad_monitor
+        self.grad_monitor = grad_monitor(config_path)
 
-    def _hook_optimizer(self):
-        def optimizer_pre_step_hook(optimizer, args, kargs):
-            self._step += 1
-            if not data_in_list_target(self._step, self._target_step):
-                return
-            output_lines = []
-            for param, param_name in self._param2name.items():
-                if not data_in_list_target(param_name, self._param_list):
-                    continue
-                grad = param.main_grad if hasattr(param, "main_grad") else param.grad
-                grad_info = GradStatCsv.generate_csv_line(
-                    level=self._level_adp,
-                    param_name=param_name,
-                    grad=grad,
-                    bounds=self._bounds)
-                output_lines.append(grad_info)
-                self._level_adp.save_grad_direction(param_name, grad, f'{self._output_path}/rank_{self._rank}/step_{self._step}')
-            output_path = os.path.join(self._output_path, f"rank_{getattr(self, '_rank')}", f"grad_summary_{self._step}.csv")
-            write_csv(output_path, output_lines, GradStatCsv.generate_csv_header(level=self._level_adp, bounds=self._bounds))
-        register_optimizer_step_pre_hook(optimizer_pre_step_hook)
-
-    def monitor(self, model):
-        print_rank_0("> parameter names:")
-        for name, param in model.named_parameters():
-            self._param2name[param] = name
-            print_rank_0(f"\t{name}")
-        setattr(self, "_rank", get_rank_id())
-        if torch.distributed.is_initialized() and not data_in_list_target(getattr(self, "_rank"), self._target_ranks):
-            return
-        self._hook_optimizer()
+    def monitor(self, module):
+        self.grad_monitor.monitor(module)
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/__init__.py b/debug/accuracy_tools/grad_tool/grad_ms/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/global_context.py b/debug/accuracy_tools/grad_tool/grad_ms/global_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..91806ee6a62a5f7d6245c634f4ca07d9c164dcba
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/grad_ms/global_context.py
@@ -0,0 +1,78 @@
+import os
+import threading
+from typing import Dict
+
+from grad_tool.common.constant import GradConst
+from grad_tool.common.utils import print_warn_log, path_valid_check, create_directory
+
+
+class GlobalContext:
+
+    _instance = None
+    _instance_lock = threading.Lock()
+    _setting = {
+        GradConst.LEVEL: GradConst.LEVEL0,
+        GradConst.PARAM_LIST: None,
+        GradConst.RANK: None,
+        GradConst.STEP: [0, 0],
+        GradConst.CURRENT_STEP: 0,
+        GradConst.BOUNDS: [-10., -1., -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1., 10.],
+        GradConst.OUTPUT_PATH: "./grad_stat"
+    }
+
+    def __new__(cls, *args, **kwargs):
+        # double-checked locking: re-check after acquiring the lock
+        if cls._instance is None:
+            with cls._instance_lock:
+                if cls._instance is None:
+                    cls._instance = object.__new__(cls)
+        return cls._instance
+
+    def init_context(self, config_dict: Dict):
+        if config_dict.get(GradConst.LEVEL, None) in GradConst.SUPPORTED_LEVEL:
+            self._setting[GradConst.LEVEL] = config_dict.get(GradConst.LEVEL)
+        self._set_input_list(config_dict, GradConst.PARAM_LIST, str)
+        self._set_input_list(config_dict, GradConst.RANK, int)
+        self._set_input_list(config_dict, GradConst.STEP, int)
+        step_list = self._setting.get(GradConst.STEP)
+        if len(step_list) != 2:
+            raise ValueError("Two integers are required for step in MindSpore mode.")
+        self._set_input_list(config_dict, GradConst.BOUNDS, float)
+        output_path = config_dict.get(GradConst.OUTPUT_PATH)
+        if output_path:
+            try:
+                path_valid_check(output_path)
+            except RuntimeError as err:
+                print_warn_log(f"Invalid output_path, use default output_path. The error message is {err}.")
+                output_path = None
+        if output_path:
+            self._setting[GradConst.OUTPUT_PATH] = output_path
+        if not os.path.isdir(self._setting.get(GradConst.OUTPUT_PATH)):
+            create_directory(self._setting.get(GradConst.OUTPUT_PATH))
+        else:
+            print_warn_log("The output_path exists, the data will be overwritten.")
+
+    def get_context(self, key: str):
+        if key not in self._setting:
+            print_warn_log(f"Unrecognized {key}.")
+        return self._setting.get(key)
+
+    def update_step(self):
+        self._setting[GradConst.CURRENT_STEP] += 1
+
+    def _set_input_list(self, config_dict: Dict, name: str, dtype: type):
+        value = config_dict.get(name)
+        if dtype == int:
+            type_str = "integer"
+        elif dtype == float:
+            type_str = "float"
+        else:
+            type_str = "string"
+        if value and isinstance(value, list):
+            if any(not isinstance(item, dtype) for item in value):
+                print_warn_log(f"Invalid {name} which must be None or list of {type_str}; use default value.")
+                return
+            self._setting[name] = value
+
+
+grad_context = GlobalContext()
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e495297a4564c69d3af6fcc3378ae13225d29ca5
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_analyzer.py
@@ -0,0 +1,216 @@
+import os
+import shutil
+import time
+from typing import List
+from multiprocessing import Process
+
+import numpy as np
+import mindspore as ms
+from mindspore.communication import get_rank
+from mindspore.common.parameter import Parameter
+
+from grad_tool.common.constant import GradConst
+from grad_tool.common.utils import ListCache, print_warn_log
+from grad_tool.common.utils import create_directory, check_file_or_directory_path, write_csv
+from grad_tool.grad_ms.global_context import grad_context, GlobalContext
+
+
+def get_rank_id():
+    try:
+        rank_id = get_rank()
+    except Exception:
+        rank_id = 0
+    return rank_id
+
+
+class GradAnalyzer:
+
+    @staticmethod
+    def dump(dump_dir: str, g_name: str, dump_step: Parameter, grad: ms.Tensor):
+        '''
+        Dump gradient statistic data.
+        level0: [step, max, min, norm, shape_dim, shape]
+        level1: [step, max, min, norm, shape_dim, shape, dist_dim, dist]
+        level2: [step, max, min, norm, shape_dim, shape] + grad_bool_data
+        level3: [step, max, min, norm, shape_dim, shape, dist_dim, dist] + grad_bool_data
+        '''
+        dump_path = os.path.join(dump_dir, g_name)
+        dump_dir_path = dump_path + "_dir"
+        save_op = ms.ops.TensorDump()
+        level = grad_context.get_context(GradConst.LEVEL)
+
+        if level == GradConst.LEVEL0 or level == GradConst.LEVEL2:
+            level_stat = GradAnalyzer.calculate_level0(dump_step, grad)
+        else:
+            level_stat = GradAnalyzer.calculate_level1(dump_step, grad)
+
+        save_op(dump_path, level_stat)
+        if level == GradConst.LEVEL2 or level == GradConst.LEVEL3:
+            grad_direction = GradAnalyzer.calculate_direction(grad)
+            save_op(dump_dir_path, grad_direction)
+
+    @staticmethod
+    def calculate_level0(dump_step: Parameter, grad: ms.Tensor):
+        # bfloat16 statistics are cast to float32 before stacking
+        is_bf16 = grad.dtype == ms.bfloat16
+        max_val = grad.max().float() if is_bf16 else grad.max()
+        min_val = grad.min().float() if is_bf16 else grad.min()
+        norm_val = grad.norm().float() if is_bf16 else grad.norm()
+        shape = grad.shape
+        extrem_stat = ms.ops.stack([dump_step[0].astype(max_val.dtype), max_val, min_val, norm_val])
+        shape_stat = ms.Tensor([len(shape)] + list(shape)).astype(max_val.dtype)
+        level0_stat = ms.ops.concat((extrem_stat, shape_stat), axis=0)
+        return level0_stat
+
+    @staticmethod
+    def calculate_level1(dump_step: Parameter, grad: ms.Tensor):
+        level0_stat = GradAnalyzer.calculate_level0(dump_step, grad)
+        bounds = grad_context.get_context(GradConst.BOUNDS)
+        zero_grad = (grad == 0).sum()
+        dist_dim = ms.Tensor([len(bounds) + 2]).astype(level0_stat.dtype)
+        bucket_result = ms.ops.bucketize(grad, bounds).astype(ms.int8)
+        dist_stat = [(bucket_result == i).sum() for i in range(len(bounds) + 1)]
+        dist_stat.append(zero_grad)
+        dist_stat = ms.ops.stack(dist_stat, axis=0).astype(level0_stat.dtype)
+        element_num = dist_stat.sum() - dist_stat[-1]
+        if element_num != 0:
+            dist_stat = dist_stat / element_num
+        level1_stat = ms.ops.concat((level0_stat, dist_dim, dist_stat), axis=0)
+        return level1_stat
+
+    @staticmethod
+    def calculate_direction(grad: ms.Tensor):
+        return grad > 0
+
+
+class CSVGenerator(Process):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.dump_dir = None
+        self.save_dir = None
+        self.level = GradConst.LEVEL0
+        self.cache_list = ListCache()
+        self.current_step = None
+        self.bounds = [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10]
+
+    def init(self, context: GlobalContext):
+        rank_id = get_rank_id()
+        output_path = context.get_context(GradConst.OUTPUT_PATH)
+        self.level = context.get_context(GradConst.LEVEL)
+        self.bounds = context.get_context(GradConst.BOUNDS)
+        step_range = context.get_context(GradConst.STEP)
+        self.step_end = 0 if step_range is None else step_range[1]
+        self.dump_dir = f"{output_path}/rank_{rank_id}/Dump/"
+        self.save_dir = f"{output_path}/rank_{rank_id}/"
+        self.current_step = None
+        self.finish_flag = False
+
+    def run(self):
+        while not self.finish_flag:
+            if not os.path.exists(self.dump_dir):
+                time.sleep(0.1)
+                continue
+            npy_files = os.listdir(self.dump_dir)
+            if not npy_files:
+                time.sleep(0.1)
+                continue
+            npy_files.sort(key=lambda x: int(x.split("_")[0]))
+            self.traverse_files(npy_files)
+            shutil.rmtree(self.dump_dir)
+
+    def traverse_files(self, npy_files: List):
+        for npy_file in npy_files:
+            file_path = os.path.join(self.dump_dir, npy_file)
+            while not os.path.exists(file_path):
+                time.sleep(0.01)
+            check_file_or_directory_path(file_path)
+            if GradConst.STEP_FINISH in npy_file:
+                self.cache_list.flush()
+                os.remove(file_path)
+                if self.current_step == self.step_end:
+                    self.finish_flag = True
+            elif file_path.split("_")[-1] == GradConst.DIR_SUFFIX:
+                prefix_idx = len(npy_file.split("_")[0])
+                new_name = npy_file[prefix_idx + 1:].replace("_" + GradConst.DIR_SUFFIX, "." + GradConst.NPY_SUFFIX)
+                if not new_name:
+                    raise RuntimeError("Invalid dump data name.")
+                if self.current_step is None:
+                    raise RuntimeError("Current record step is None.")
+                step_dir = os.path.join(self.save_dir, f"step_{self.current_step}")
+                if not os.path.exists(step_dir):
+                    create_directory(step_dir)
+                dst_file = os.path.join(step_dir, new_name)
+                shutil.move(file_path, dst_file)
+            elif file_path.split(".")[-1] == GradConst.NPY_SUFFIX:
+                stat_data = self.load_npy_data(file_path)
+                if stat_data is None:
+                    continue
+                step = int(stat_data[GradConst.STEP_IDX])
+                if self.current_step is None or step != self.current_step:
+                    self.current_step = step
+                    self.create_csv_file()
+                self.gen_csv_line(file_path, stat_data)
+                os.remove(file_path)
+
+    def load_npy_data(self, file_path: str):
+        stat_data = None
+        max_try = 10
+        while max_try:
+            try:
+                stat_data = np.load(file_path)
+                return stat_data
+            except Exception:
+                print_warn_log("load numpy file failed, retry...")
+                max_try -= 1
+                time.sleep(0.1)
+        return stat_data
+
+    def gen_csv_line(self, file_path: str, stat_data) -> None:
+        shape_dim = int(stat_data[GradConst.SHAPE_DIM_IDX])
+        file_name = os.path.basename(file_path)
+        prefix_idx = len(file_name.split("_")[0])
+        param_name = file_name[(prefix_idx + 1) : -(len(GradConst.NPY_SUFFIX) + 1)]
+        if not param_name:
+            raise RuntimeError("Invalid gradient statistic file name.")
+        csv_line = [param_name]
+        if self.level == GradConst.LEVEL1 or self.level == GradConst.LEVEL3:
+            csv_line.extend(self.get_dist_data(shape_dim, stat_data))
+        csv_line.extend(self.get_extrem_data(shape_dim, stat_data))
+        self.cache_list.append(csv_line)
+
+    def get_dist_data(self, shape_dim: int, stat_data: np.ndarray):
+        return list(stat_data[(shape_dim + GradConst.SHAPE_DIM_IDX + 2):])
+
+    def get_extrem_data(self, shape_dim: int, stat_data: np.ndarray):
+        extrem_data = list(stat_data[(GradConst.STEP_IDX + 1):(GradConst.STEP_IDX + 4)])
+        shape_data = stat_data[(GradConst.SHAPE_DIM_IDX + 1):(GradConst.SHAPE_DIM_IDX + shape_dim + 1)]
+        shape_data = list(shape_data.astype(int))
+        extrem_data.append(shape_data)
+        return extrem_data
+
+    def create_csv_file(self):
+        headers = ["Param_name"]
+        if self.level == GradConst.LEVEL1 or self.level == GradConst.LEVEL3:
+            headers.extend(self.get_dist_header())
+        headers.extend(self.get_extrem_headers())
+        output_path = f"{self.save_dir}/grad_summary_{self.current_step}.csv"
+        write_csv(output_path, [], headers)
+        self.cache_list.set_output_file(output_path)
+        self.cache_list.clear()
+
+    def get_extrem_headers(self) -> List[str]:
+        return ["Max", "Min", "Norm", "Shape"]
+
+    def get_dist_header(self) -> List[str]:
+        intervals = []
+        for i, _ in enumerate(self.bounds):
+            if i == 0:
+                intervals.append(f"(-inf, {self.bounds[i]}]")
+            else:
+                intervals.append(f"({self.bounds[i-1]}, {self.bounds[i]}]")
+        intervals.extend([f"({self.bounds[-1]}, inf)", "=0"])
+        return intervals
+
+
+csv_generator = CSVGenerator()
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_comparator.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_comparator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bfeda4387e061e346312d011a832fd8ff3d6e3a
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_comparator.py
@@ -0,0 +1,22 @@
+import torch
+import numpy as np
+
+from grad_tool.common.base_comparator import BaseComparator
+
+
+class MsGradComparator(BaseComparator):
+
+    @classmethod
+    def _load_grad_files(cls, grad_file1: str, grad_file2: str):
+        grad1_suffix = grad_file1.split(".")[-1]
+        grad2_suffix = grad_file2.split(".")[-1]
+        grad1 = torch.load(grad_file1).numpy() if grad1_suffix == "pt" else np.load(grad_file1)
+        grad2 = torch.load(grad_file2).numpy() if grad2_suffix == "pt" else np.load(grad_file2)
+
+        if grad1.shape != grad2.shape:
+            raise RuntimeError(f"numpy shape is not equal: {grad_file1}, {grad_file2}")
+        if grad1.dtype != bool:
+            raise TypeError(f"numpy type is not bool: {grad_file1}")
+        if grad2.dtype != bool:
+            raise TypeError(f"numpy type is not bool: {grad_file2}")
+        return grad1, grad2
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_ms/grad_monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e762fe36c9c9eab297d869e4d0bc54c770fef9a
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/grad_ms/grad_monitor.py
@@ -0,0 +1,16 @@
+from grad_tool.common.base_monitor import BaseMonitor
+from grad_tool.grad_ms.global_context import grad_context
+from grad_tool.grad_ms.grad_analyzer import csv_generator
+from grad_tool.grad_ms.hook import hook_optimizer
+
+
+class MsGradientMonitor(BaseMonitor):
+
+    def __init__(self, config_file: str):
+        super(MsGradientMonitor, self).__init__(config_file)
+        grad_context.init_context(self.config)
+        csv_generator.init(grad_context)
+
+    def monitor(self, module):
+        hook_optimizer(module)
+        csv_generator.start()
diff --git a/debug/accuracy_tools/grad_tool/grad_ms/hook.py b/debug/accuracy_tools/grad_tool/grad_ms/hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..00afb2b79cb74246b211a14a88f72d8fe4727bfb
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/grad_ms/hook.py
@@ -0,0 +1,47 @@
+import os
+import shutil
+
+import mindspore as ms
+from mindspore.common.api import jit
+from mindspore.nn.optim.optimizer import Optimizer
+from mindspore.common.parameter import Parameter
+from mindspore.common.initializer import initializer
+
+from grad_tool.common.constant import GradConst
+from grad_tool.common.utils import print_warn_log
+from grad_tool.grad_ms.global_context import grad_context
+from grad_tool.grad_ms.grad_analyzer import GradAnalyzer, get_rank_id
+
+
+def hook_optimizer(opt: Optimizer):
+    func = opt.construct
+    g_names = [param.name for param in opt._parameters]
+    step_range = grad_context.get_context(GradConst.STEP)
+    step_start = step_range[0]
+    step_end = step_range[1]
+    param_list = grad_context.get_context(GradConst.PARAM_LIST)
+    rank_list = grad_context.get_context(GradConst.RANK)
+    rank_id = get_rank_id()
+    output_path = grad_context.get_context(GradConst.OUTPUT_PATH)
+    dump_dir = f"{output_path}/rank_{rank_id}/Dump/"
+    save_dir = f"{output_path}/rank_{rank_id}/"
+    step_finish_flag = f"{output_path}/rank_{rank_id}/Dump/{GradConst.STEP_FINISH}"
+    if os.path.exists(save_dir):
+        print_warn_log(f"Delete existing path {save_dir}.")
+        shutil.rmtree(save_dir)
+
+    @jit
+    def new_construct(self, gradients):
+        if step_start <= self.dump_step[0] <= step_end:
+            for index, grad_value in enumerate(gradients):
+                if param_list and g_names[index] not in param_list:
+                    continue
+                GradAnalyzer.dump(dump_dir, g_names[index], self.dump_step, grad_value)
+            ms.ops.TensorDump()(step_finish_flag, self.dump_step)
+        self.assignadd(self.dump_step, self.global_step_increase_tensor)
+        out = func(gradients)
+        return out
+
+    if rank_list is None or rank_id in rank_list:
+        opt.dump_step = Parameter(initializer(0, [1], ms.int32), name="dump_step")
+        opt.construct = new_construct.__get__(opt, type(opt))
diff --git a/debug/accuracy_tools/grad_tool/grad_pt/__init__.py b/debug/accuracy_tools/grad_tool/grad_pt/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/grad_tool/grad_pt/grad_comparator.py b/debug/accuracy_tools/grad_tool/grad_pt/grad_comparator.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1229b93de77ef20ba6a9f021cbe6c5e9cd1a5ec
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/grad_pt/grad_comparator.py
@@ -0,0 +1,18 @@
+import torch
+
+from grad_tool.common.base_comparator import BaseComparator
+
+
+class PtGradComparator(BaseComparator):
+
+    @classmethod
+    def _load_grad_files(cls, grad_file1: str, grad_file2: str):
+        tensor1 = torch.load(grad_file1, map_location=torch.device("cpu"))
+        tensor2 = torch.load(grad_file2, map_location=torch.device("cpu"))
+        if tensor1.shape != tensor2.shape:
+            raise RuntimeError(f"tensor shape is not equal: {grad_file1}, {grad_file2}")
+        if tensor1.dtype != torch.bool:
+            raise TypeError(f"tensor type is not bool: {grad_file1}")
+        if tensor2.dtype != torch.bool:
+            raise TypeError(f"tensor type is not bool: {grad_file2}")
+        return tensor1.numpy(), tensor2.numpy()
diff --git a/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..246ea337b0caa996521971aecfdefeb92712f335
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/grad_pt/grad_monitor.py
@@ -0,0 +1,71 @@
+import os
+from collections import defaultdict
+
+import torch
+from torch.optim.optimizer import register_optimizer_step_pre_hook
+from grad_tool.common.base_monitor import BaseMonitor
+from grad_tool.grad_pt.level_adapter import Level, LevelAdapter
+from grad_tool.grad_pt.grad_stat_csv import GradStatCsv
+from grad_tool.common.utils import check_numeral_list_ascend, data_in_list_target,\
+    write_csv, print_info_log, create_directory, print_warn_log
+from grad_tool.grad_pt.utils import get_rank_id, print_rank_0
+
+
+class PtGradientMonitor(BaseMonitor):
+    default_bounds = [-10, -1, -0.1, -0.01, -0.001, 0, 0.001, 0.01, 0.1, 1, 10]
+
+    def __init__(self, config_filepath):
+        super(PtGradientMonitor, self).__init__(config_filepath)
+        self._level_adp: Level = LevelAdapter.level_adapter(self.config.get("level"))
+        self._param_list = self.config.get('param_list')
+        self._target_ranks = self.config.get("rank")
+        print_info_log(f"target rank {self._target_ranks}")
+        self._target_step = self.config.get("step")
+        print_info_log(f"target step {self._target_step}")
+        self._bounds = self.config.get("bounds")
+        if not self._bounds or len(self._bounds) == 0:
+            self._bounds = PtGradientMonitor.default_bounds
+        check_numeral_list_ascend(self._bounds)
+        self._output_path = self.config.get("output_path")
+        if not os.path.isdir(self._output_path):
+            create_directory(self._output_path)
+        else:
+            print_warn_log(f"the files in {self._output_path} will be overwritten")
+        self._step = -1
+        self._param2name = defaultdict(str)
+
+    def _rank_in_targets(self):
+        if not hasattr(self, "_rank"):
+            raise AttributeError("grad monitor needs attribute _rank")
+        return not torch.distributed.is_initialized() or data_in_list_target(getattr(self, "_rank"), self._target_ranks)
+
+    def _hook_optimizer(self):
+        def optimizer_pre_step_hook(optimizer, args, kwargs):
+            self._step += 1
+            if not data_in_list_target(self._step, self._target_step):
+                return
+            output_lines = []
+            for param, param_name in self._param2name.items():
+                if not data_in_list_target(param_name, self._param_list):
+                    continue
+                grad = param.main_grad if hasattr(param, "main_grad") else param.grad
+                grad_info = GradStatCsv.generate_csv_line(
+                    level=self._level_adp,
+                    param_name=param_name,
+                    grad=grad,
+                    bounds=self._bounds)
+                output_lines.append(grad_info)
+                self._level_adp.save_grad_direction(param_name, grad, f'{self._output_path}/rank_{self._rank}/step_{self._step}')
+            output_path = os.path.join(self._output_path, f"rank_{getattr(self, '_rank')}", f"grad_summary_{self._step}.csv")
+            write_csv(output_path, output_lines, GradStatCsv.generate_csv_header(level=self._level_adp, bounds=self._bounds))
+        register_optimizer_step_pre_hook(optimizer_pre_step_hook)
+
+    def monitor(self, model):
+        print_rank_0("> parameter names:")
+        for name, param in model.named_parameters():
+            self._param2name[param] = name
+            print_rank_0(f"\t{name}")
+        setattr(self, "_rank", get_rank_id())
+        if torch.distributed.is_initialized() and not data_in_list_target(getattr(self, "_rank"), self._target_ranks):
+            return
+        self._hook_optimizer()
diff --git a/debug/accuracy_tools/grad_tool/grad_stat_csv.py b/debug/accuracy_tools/grad_tool/grad_pt/grad_stat_csv.py
similarity index 97%
rename from debug/accuracy_tools/grad_tool/grad_stat_csv.py
rename to debug/accuracy_tools/grad_tool/grad_pt/grad_stat_csv.py
index 76a341c1fc1fea296e900ba1220c973d784bf743..c15f6612dc2f23c0cf787e55f8ebfb65577bd71e 100644
--- a/debug/accuracy_tools/grad_tool/grad_stat_csv.py
+++ b/debug/accuracy_tools/grad_tool/grad_pt/grad_stat_csv.py
@@ -1,17 +1,17 @@
 import hashlib
 import torch
-from grad_tool.level_adapter import Level
+from grad_tool.grad_pt.level_adapter import Level
 
 
 class GradExtremeOps:
     @staticmethod
     def tensor_max(tensor):
         return torch._C._VariableFunctionsClass.max(tensor).cpu().detach().float().numpy().tolist()
-    
+
     @staticmethod
     def tensor_min(tensor):
         return torch._C._VariableFunctionsClass.min(tensor).cpu().detach().float().numpy().tolist()
-    
+
     @staticmethod
     def tensor_norm(tensor):
         return torch._C._VariableFunctionsClass.norm(tensor).cpu().detach().float().numpy().tolist()
@@ -29,7 +29,7 @@ class GradStatOps:
     @staticmethod
     def md5_header(**kwargs):
         return ["MD5"]
-    
+
     @staticmethod
     def intervals_header(**kwargs):
         level: Level = kwargs.get("level")
@@ -43,7 +43,7 @@ class GradStatOps:
     @staticmethod
     def shape_header(**kwargs):
         return ["shape"]
-    
+
     @staticmethod
     def md5_content(**kwargs):
         grad = kwargs.get("grad")
@@ -62,7 +62,7 @@ class GradStatOps:
     def extremes_content(**kwargs):
         grad = kwargs.get("grad")
         return [f(grad) for f in GradExtremes.extremes.values()]
-    
+
     @staticmethod
     def shape_content(**kwargs):
         grad = kwargs.get("grad")
@@ -88,7 +88,7 @@ class GradStatCsv:
             "content": GradStatOps.shape_content
         },
     }
-    
+
     @staticmethod
     def generate_csv_header(**kwargs):
         header = ["param_name"]
@@ -102,4 +102,3 @@ class GradStatCsv:
     for func in GradStatCsv.CSV.values():
         line.extend(func["content"](**kwargs))
     return line
-    
\ No newline at end of file
diff --git a/debug/accuracy_tools/grad_tool/level_adapter.py b/debug/accuracy_tools/grad_tool/grad_pt/level_adapter.py
similarity index 98%
rename from debug/accuracy_tools/grad_tool/level_adapter.py
rename to debug/accuracy_tools/grad_tool/grad_pt/level_adapter.py
index e03bf2b110dc8041b356d83843e365e91236d32a..1d906a18cdb9a842af081f90e1be144450aa5bc5 100644
--- a/debug/accuracy_tools/grad_tool/level_adapter.py
+++ b/debug/accuracy_tools/grad_tool/grad_pt/level_adapter.py
@@ -1,7 +1,7 @@
 import os
 from abc import ABC, abstractmethod
 import torch
-from grad_tool.utils import print_info_log
+from grad_tool.common.utils import print_info_log
 
 
 class LevelOps:
diff --git a/debug/accuracy_tools/grad_tool/grad_pt/utils.py b/debug/accuracy_tools/grad_tool/grad_pt/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbccab7b480c6d38fea7c22b6779913f47e91a22
--- /dev/null
+++ b/debug/accuracy_tools/grad_tool/grad_pt/utils.py
@@ -0,0 +1,17 @@
+import os
+import torch
+import torch.distributed as dist
+
+
+def get_rank_id():
+    if torch.distributed.is_initialized():
+        return torch.distributed.get_rank()
+    return os.getpid()
+
+
+def print_rank_0(message):
+    if dist.is_initialized():
+        if dist.get_rank() == 0:
+            print(message)
+    else:
+        print(message)
diff --git a/debug/accuracy_tools/grad_tool/test/run_ut.py b/debug/accuracy_tools/grad_tool/test/run_ut.py
index c73949697941d84782c4983aa484c06b1a7cbcc2..c4dcb55e1dea6b8421f26e53e93dcfc251c9e161 100644
--- a/debug/accuracy_tools/grad_tool/test/run_ut.py
+++ b/debug/accuracy_tools/grad_tool/test/run_ut.py
@@ -3,21 +3,40 @@
 import shutil
 import subprocess
 import sys
 
+
+FRAMEWORKS = set()
+PYTORCH = "PyTorch"
+MINDSPORE = "MindSpore"
+
+try:
+    import torch
+    import torch_npu
+    FRAMEWORKS.add(PYTORCH)
+except ImportError:
+    pass
+
+try:
+    import mindspore
+    FRAMEWORKS.add(MINDSPORE)
+except ImportError:
+    pass
+
 def run_ut():
     cur_dir = os.path.realpath(os.path.dirname(__file__))
     top_dir = os.path.realpath(os.path.dirname(cur_dir))
     ut_path = os.path.join(cur_dir, "ut/")
-    src_dir = top_dir
     report_dir = os.path.join(cur_dir, "report")
+    xml_path = os.path.join(report_dir, "final.xml")
+    cov_path = os.path.join(report_dir, "coverage.xml")
 
     if os.path.exists(report_dir):
         shutil.rmtree(report_dir)
     os.makedirs(report_dir)
 
-    cmd = ["python3", "-m", "pytest", ut_path, "--junitxml=" + report_dir + "/final.xml",
-           "--cov=" + src_dir, "--cov-branch", "--cov-report=xml:" + report_dir + "/coverage.xml"]
-
+    cmd = ["python3", "-m", "pytest", ut_path, "--junitxml=" + xml_path,
+           "--cov=" + ut_path, "--cov-branch", "--cov-report=xml:" + cov_path]
+
     result_ut = subprocess.Popen(cmd, shell=False, stdout=subprocess.PIPE, stderr=subprocess.STDOUT)
 
     while result_ut.poll() is None:
diff --git a/debug/accuracy_tools/grad_tool/test/ut/test_grad_csv.py b/debug/accuracy_tools/grad_tool/test/ut/test_grad_csv.py
index a4b6d9dd6aee13b0797a7696c92bc76c1746186b..1f3544e06ee643e1374786bb25ba09cb5c32f5b0 100644
--- a/debug/accuracy_tools/grad_tool/test/ut/test_grad_csv.py
+++ b/debug/accuracy_tools/grad_tool/test/ut/test_grad_csv.py
@@ -2,8 +2,8 @@
 import unittest
 import os
 import torch
-from grad_tool.grad_stat_csv import GradStatCsv
-from grad_tool.level_adapter import LevelAdapter
+from grad_tool.grad_pt.grad_stat_csv import GradStatCsv
+from grad_tool.grad_pt.level_adapter import LevelAdapter
 
 
 grad_tensor = torch.tensor([[-2, 2], [0.2, 0.3]])
@@ -13,19 +13,19 @@
 class TestGradCSV(unittest.TestCase):
     def test_level_L0_header(self):
        self.assertEqual(['param_name', 'MD5', 'max', 'min', 'norm', 'shape'],
                          GradStatCsv.generate_csv_header(level=LevelAdapter.level_adapter("L0"), bounds=[-1, 0, 1]))
-    
+
     def test_level_L1_header(self):
         self.assertEqual(['param_name', 'MD5', '(-inf, -1]', '(-1, 0]', '(0, 1]', '(1, inf)', '=0', 'max', 'min', 'norm', 'shape'],
                          GradStatCsv.generate_csv_header(level=LevelAdapter.level_adapter("L1"), bounds=[-1, 0, 1]))
-    
+
     def test_level_L2_header(self):
         self.assertEqual(['param_name', 'MD5', 'max', 'min', 'norm', 'shape'],
                          GradStatCsv.generate_csv_header(level=LevelAdapter.level_adapter("L2"), bounds=[-1, 0, 1]))
-    
+
     def test_level_L3_header(self):
         self.assertEqual(['param_name', 'MD5', '(-inf, -1]', '(-1, 0]', '(0, 1]', '(1, inf)', '=0', 'max', 'min', 'norm', 'shape'],
                          GradStatCsv.generate_csv_header(level=LevelAdapter.level_adapter("L3"), bounds=[-1, 0, 1]))
-    
+
     def test_level_L0_content(self):
         generated_csv_line = GradStatCsv.generate_csv_line(
             level=LevelAdapter.level_adapter("L0"),
@@ -34,7 +34,7 @@ class TestGradCSV(unittest.TestCase):
             bounds=[-1, 0, 1])
         self.assertEqual(['model.conv2d', '678a6c7d9d9716682b56fda097d0936c', 2.0, -2.0, 2.851315498352051, [2, 2]],
                          generated_csv_line)
-    
+
     def test_level_L1_content(self):
         generated_csv_line = GradStatCsv.generate_csv_line(
             level=LevelAdapter.level_adapter("L1"),
@@ -43,7 +43,7 @@ class TestGradCSV(unittest.TestCase):
             bounds=[-1, 0, 1])
         self.assertEqual(['model.conv2d', '678a6c7d9d9716682b56fda097d0936c', 0.25, 0.0, 0.5, 0.25, 0.0, 2.0, -2.0, 2.851315498352051, [2, 2]],
                          generated_csv_line)
-    
+
     def test_level_L2_content(self):
         generated_csv_line = GradStatCsv.generate_csv_line(
             level=LevelAdapter.level_adapter("L2"),
@@ -52,7 +52,7 @@ class TestGradCSV(unittest.TestCase):
             bounds=[-1, 0, 1])
         self.assertEqual(['model.conv2d', '678a6c7d9d9716682b56fda097d0936c', 2.0, -2.0, 2.851315498352051, [2, 2]],
                          generated_csv_line)
-    
+
     def test_level_L3_content(self):
         generated_csv_line = GradStatCsv.generate_csv_line(
             level=LevelAdapter.level_adapter("L3"),
diff --git a/debug/accuracy_tools/grad_tool/test/ut/test_grad_monitor.py b/debug/accuracy_tools/grad_tool/test/ut/test_grad_monitor.py
index f233997e73e66f46104db6eeff7fa722e83ccbb7..4c33717f7cf5aaa162a566d9b9cd1f7e3165a9c5 100644
--- a/debug/accuracy_tools/grad_tool/test/ut/test_grad_monitor.py
+++ b/debug/accuracy_tools/grad_tool/test/ut/test_grad_monitor.py
@@ -29,12 +29,12 @@ class TestModule(nn.Module):
         super().__init__()
         self.linear = nn.Linear(10, 5)
         self.relu = nn.ReLU()
-    
+
     def forward(self, x):
         x1 = self.linear(x)
         x2 = self.relu(x1)
         return x2
-    
+
 
 def test_grad_monitor():
     gm = GradientMonitor(os.path.join(base_dir, "resources/test_grad_monitor.yaml"))
@@ -53,19 +53,19 @@
     return gm
 
 
-def test_save_grad():
+def test_grad_monitor_1():
     gm = GradientMonitor(os.path.join(base_dir, "resources/test_save_grad.yaml"))
     loss_fun = nn.CrossEntropyLoss()
     test_module = TestModule()
     nn.init.constant_(test_module.linear.weight, 1.0)
     nn.init.constant_(test_module.linear.bias, 1.0)
+    gm.monitor(test_module)
     optimizer = torch.optim.SGD(test_module.parameters(), lr=1e-2)
-    for input_data, label in zip([x + 0.1 for x in inputs], labels):
+    for input_data, label in zip(inputs, labels):
         output = test_module(input_data)
         loss = loss_fun(output, label)
         optimizer.zero_grad()
         loss.backward()
-        gm.save_grad(test_module)
         optimizer.step()
     return gm
 
@@ -73,7 +73,7 @@ class TestGradMonitor(unittest.TestCase):
     def test_compare(self):
         gm1 = test_grad_monitor()
-        gm2 = test_save_grad()
+        gm2 = test_grad_monitor_1()
         compare_output_path = os.path.join(os.path.dirname(gm1._output_path), "grad_compare")
         GradComparator.compare_distributed(gm1._output_path, gm2._output_path, compare_output_path)
         items = os.listdir(compare_output_path)
diff --git a/debug/accuracy_tools/grad_tool/test/ut_ms/test_grad_monitor.py b/debug/accuracy_tools/grad_tool/test/ut_ms/test_grad_monitor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/debug/accuracy_tools/grad_tool/utils.py b/debug/accuracy_tools/grad_tool/utils.py
deleted file mode 100644
index a7db58be29a9fbf55c04a8283cb68d09f9289796..0000000000000000000000000000000000000000
--- a/debug/accuracy_tools/grad_tool/utils.py
+++ /dev/null
@@ -1,75 +0,0 @@
-import os
-import yaml
-import torch
-import torch.distributed as dist
-import pandas as pd
-from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileOpen, create_directory, \
-    FileChecker, FileCheckConst
-from ptdbg_ascend.src.python.ptdbg_ascend.common.utils import check_file_or_directory_path, print_info_log, \
-    print_warn_log
-
-
-def get_config(filepath):
-    with FileOpen(filepath, 'r') as file:
-        config = yaml.safe_load(file)
-    return config
-
-
-def write_csv(filepath, content_list, header):
-    if not os.path.exists(filepath):
-        make_file_safety(filepath)
-        data_frame = pd.DataFrame(columns=header)
-        data_frame.to_csv(filepath, index=False)
-
-    filepath_checker = FileChecker(filepath, FileCheckConst.FILE)
-    filepath_checker.common_check()
-    new_data = pd.DataFrame(list(content for content in content_list))
-    new_data.to_csv(filepath, mode='a+', header=False, index=False)
-    print_info_log(f"write {len(content_list)} items to {filepath}")
-
-
-def make_file_safety(file_path: str, permission=0o640):
-    if os.path.islink(file_path):
-        raise RuntimeError(f"Invalid soft link path: {file_path}")
-    file_real_path = os.path.realpath(file_path)
-    if os.path.exists(file_real_path):
-        return
-    parent_path = os.path.dirname(file_real_path)
-    if not os.path.exists(parent_path):
-        create_directory(parent_path)
-    if not os.access(parent_path, os.W_OK):
-        raise PermissionError(f"The path {parent_path} is not writable!")
-    try:
-        os.close(os.open(file_real_path, os.O_WRONLY | os.O_CREAT, permission))
-    except OSError as e:
-        raise RuntimeError("Can't create file: " + file_real_path) from e
-    os.chmod(file_real_path, permission)
-
-
-def data_in_list_target(data, lst):
-    return not lst or len(lst) == 0 or data in lst
-
-
-def check_numeral_list_ascend(lst):
-    if any(not isinstance(item, (int, float)) for item in lst):
-        raise Exception("The input list should only contain numbers")
-    if lst != sorted(lst):
-        raise Exception("The input list should be ascending")
-
-
-def get_rank_id():
-    if torch.distributed.is_initialized():
-        return torch.distributed.get_rank()
-    return os.getpid()
-
-
-def path_check(path, isdir=False):
-    check_file_or_directory_path(path, isdir)
-
-
-def print_rank_0(message):
-    if dist.is_initialized():
-        if dist.get_rank() == 0:
-            print(message)
-    else:
-        print(message)
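
Usage sketch (illustrative, not part of the patch): the reworked entry points dispatch on the `framework` argument and keep the old call shape. `my_model`, `my_optimizer`, `config.yaml`, and the dump directories below are hypothetical placeholders; comparison assumes gradients were dumped at a level that saves direction data (L2/L3).

    from grad_tool.grad_monitor import GradientMonitor
    from grad_tool.grad_comparator import GradComparator

    # PyTorch: PtGradientMonitor registers an optimizer pre-step hook and
    # writes grad_summary_<step>.csv under <output_path>/rank_<rank>/
    monitor = GradientMonitor("config.yaml", framework="PyTorch")
    monitor.monitor(my_model)

    # MindSpore: MsGradientMonitor wraps the optimizer's construct and
    # starts the CSVGenerator process that converts TensorDump files to CSV
    ms_monitor = GradientMonitor("config.yaml", framework="MindSpore")
    ms_monitor.monitor(my_optimizer)

    # Compare the boolean gradient-direction dumps of two runs across all matched ranks
    GradComparator.compare_distributed("./dump_run1", "./dump_run2", "./grad_compare", framework="PyTorch")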