From 681b0608d230ad34c50c1f1dbff8f67f2bc7b8da Mon Sep 17 00:00:00 2001
From: pxp1 <958876660@qq.com>
Date: Thu, 3 Apr 2025 15:46:13 +0800
Subject: [PATCH] Adapt mindspeed_rl dump
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../msprobe/pytorch/__init__.py               |   1 +
 .../pytorch/hook_module/support_wrap_ops.yaml |   4 +-
 .../msprobe/pytorch/hook_module/wrap_aten.py  |   7 +-
 .../accuracy_tools/msprobe/pytorch/service.py |   1 +
 .../msprobe/pytorch/single_save/__init__.py   |   0
 .../pytorch/single_save/single_save.py        | 346 ++++++++++++++++++
 6 files changed, 354 insertions(+), 5 deletions(-)
 create mode 100644 debug/accuracy_tools/msprobe/pytorch/single_save/__init__.py
 create mode 100644 debug/accuracy_tools/msprobe/pytorch/single_save/single_save.py

diff --git a/debug/accuracy_tools/msprobe/pytorch/__init__.py b/debug/accuracy_tools/msprobe/pytorch/__init__.py
index ce84e6b35b..1d16893774 100644
--- a/debug/accuracy_tools/msprobe/pytorch/__init__.py
+++ b/debug/accuracy_tools/msprobe/pytorch/__init__.py
@@ -18,6 +18,7 @@ from .compare.distributed_compare import compare_distributed
 from .compare.pt_compare import compare
 from .common.utils import seed_all
 from .debugger.precision_debugger import PrecisionDebugger, module_dump, module_dump_end
+from .single_save.single_save import SingleSave, SingleComparator
 
 torch_version_above_or_equal_2 = torch.__version__.split('+')[0] >= '2.0'
 if torch_version_above_or_equal_2:
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
index c57f8c182a..71fe38d3fc 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/support_wrap_ops.yaml
@@ -75,7 +75,7 @@ functional:
   - rrelu
   - rrelu_
   - logsigmoid
-  - gelu
+  # - gelu
   - hardshrink
   - tanhshrink
   - softsign
@@ -90,7 +90,7 @@ functional:
   - hardsigmoid
   - linear
   - bilinear
-  - silu
+  # - silu
   - hardswish
   - embedding
   - embedding_bag
diff --git a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
index 637bae33ad..54537b7700 100644
--- a/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
+++ b/debug/accuracy_tools/msprobe/pytorch/hook_module/wrap_aten.py
@@ -36,9 +36,10 @@ for f in dir(torch.ops.aten):
 
 
 def get_aten_ops():
-    global wrap_aten_ops
-    _all_aten_ops = dir(torch.ops.aten)
-    return set(wrap_aten_ops) & set(_all_aten_ops)
+    # global wrap_aten_ops
+    # _all_aten_ops = dir(torch.ops.aten)
+    # return set(wrap_aten_ops) & set(_all_aten_ops)
+    return set()
 
 
 class HOOKAtenOP(object):
diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py
index 69bd385fd7..455f161307 100644
--- a/debug/accuracy_tools/msprobe/pytorch/service.py
+++ b/debug/accuracy_tools/msprobe/pytorch/service.py
@@ -299,6 +299,7 @@ class Service:
         if self.config.task == Const.TENSOR:
             self.data_collector.data_processor.dump_async_data()
         self.data_collector.write_json()
+        self.reset_status()
 
     def step(self):
         if self.config.level == Const.LEVEL_DEBUG:
diff --git a/debug/accuracy_tools/msprobe/pytorch/single_save/__init__.py b/debug/accuracy_tools/msprobe/pytorch/single_save/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/debug/accuracy_tools/msprobe/pytorch/single_save/single_save.py b/debug/accuracy_tools/msprobe/pytorch/single_save/single_save.py
new file mode 100644
index 0000000000..d58d4aea38
--- /dev/null
+++ b/debug/accuracy_tools/msprobe/pytorch/single_save/single_save.py
@@ -0,0 +1,346 @@
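+# Illustrative usage sketch for the two classes added below; the dump
+# paths and tensor names are placeholders, not part of this patch.
+#
+#     from msprobe.pytorch import SingleSave, SingleComparator
+#
+#     SingleSave("./dump_output")                 # create the singleton dumper
+#     for batch in loader:
+#         SingleSave.save({"logits": logits})     # cache tensor stats per step
+#         SingleSave.step()                       # flush cache, advance step
+#
+#     SingleComparator().compare("./dump_a", "./dump_b", tag="data")
+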
+import os
+import json
+import torch
+import pandas as pd
+from tqdm import tqdm
+import multiprocessing
+
+from msprobe.pytorch.common.utils import get_rank_if_initialized, save_pt, logger
+from msprobe.core.common.file_utils import create_directory, check_file_or_directory_path
+from msprobe.pytorch.compare.distributed_compare import compare_distributed
+
+
+support_nested_data_type = (list, tuple, dict)
+support_inner_data_type = (torch.Tensor, )
+
+
+class SingleSave:
+    _instance = None
+
+    def __new__(cls, dump_path):
+        if cls._instance is None:
+            cls._instance = super(SingleSave, cls).__new__(cls)
+            cls._instance.dump_path = dump_path
+            try:
+                cls._instance.rank = get_rank_if_initialized()
+            except Exception:
+                cls._instance.rank = ""
+            cls._instance.dump_dir = create_directory(cls._instance.dump_path)
+            cls._instance.step_count = 0
+            cls._instance.cache_dict = {}
+
+        return cls._instance
+
+    def __init__(self, dump_path):
+        self.dump_path = dump_path
+
+    @staticmethod
+    def save_dict_to_json(data_dict, file_path):
+        """
+        Save a dict as a JSON file.
+
+        :param data_dict: the dict to save
+        :param file_path: path of the output file
+        """
+        try:
+            # ensure_ascii=False keeps non-ASCII characters readable;
+            # indent=4 pretty-prints the output for easier inspection.
+            with open(file_path, 'w', encoding='utf-8') as file:
+                json.dump(data_dict, file, ensure_ascii=False, indent=4)
+            logger.info(f"SingleSave: result save to {file_path}")
+        except Exception as e:
+            logger.error(f"SingleSave: save json error: {e}")
+
+    @staticmethod
+    def _analyze_tensor_data(data, data_name=None, save_dir=None):
+        '''
+        data: torch.Tensor
+        return:
+            result_data: with keys {"max", "min", "mean", "norm", "shape"}
+        '''
+        if not data.is_floating_point() or data.dtype == torch.float64:
+            data = data.to(torch.float32)
+        result_data = {}
+        result_data["max"] = torch.max(data).item()
+        result_data["min"] = torch.min(data).item()
+        result_data["mean"] = torch.mean(data).item()
+        result_data["norm"] = torch.norm(data).item()
+        result_data["shape"] = list(data.shape)
+        if save_dir is not None and data_name is not None:
+            real_save_path = os.path.join(save_dir, data_name + ".pt")
+            save_pt(data, real_save_path)
+        return result_data
+
+    @classmethod
+    def _analyze_list_tuple_data(cls, data, data_name=None, save_dir=None):
+        lst = []
+        for index, element in enumerate(data):
+            analyze_func = cls._get_analyze_func_for_inner_variable(type(element))
+            if analyze_func is None:
+                raise TypeError(f"SingleSave: Unsupported type: {type(element)}")
+            element_name = data_name + "." + str(index)
+            lst.append(analyze_func(element, element_name, save_dir))
+        return lst
+
+    @classmethod
+    def _analyze_dict_data(cls, data, data_name=None, save_dir=None):
+        result_data = {}
+        for key, value in data.items():
+            analyze_func = cls._get_analyze_func_for_inner_variable(type(value))
+            if analyze_func is None:
+                raise TypeError(f"SingleSave: Unsupported type: {type(value)}")
+            key_name = data_name + "." + str(key)
+            result_data[key] = analyze_func(value, key_name, save_dir)
+        return result_data
+
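+    # On-disk layout produced by save_ex/step (derived from the path joins
+    # below; "grad_output" and the indices are illustrative):
+    #
+    #     <dump_path>/data/grad_output/step0/rank0/micro_step0/grad_output.0.pt
+    #     <dump_path>/data/grad_output/step0/rank0/micro_step0/grad_output.json
+    #
+    # The micro_step level is present only when save() sees the same key more
+    # than once within a step, or save_ex() is called with micro_batch set.
+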
+    @classmethod
+    def _get_analyze_func_for_inner_variable(cls, real_type):
+        func_mapping = {
+            torch.Tensor: cls._analyze_tensor_data,
+        }
+        return func_mapping.get(real_type)
+
+    @classmethod
+    def save_config(cls, data, dump_path):
+        create_directory(dump_path)
+        dump_file = os.path.join(dump_path, 'hyperparameter.json')
+        torch.save(data, dump_file)
+        data = torch.load(dump_file)
+        cls.save_dict_to_json(eval(str(data)), dump_file)
+        return
+
+    @classmethod
+    def save_ex(cls, data, micro_batch=None):
+        '''
+        data: dict{str: Union[torch.Tensor, tuple, list]}
+
+        return: void
+        '''
+        instance = cls._instance
+
+        if not isinstance(data, dict):
+            logger.warning("SingleSave data type not valid, "
+                           "should be dict. "
+                           "Skip current save process.")
+            return
+        for key, value in data.items():
+            if not isinstance(key, str) or not isinstance(value, support_nested_data_type + support_inner_data_type):
+                logger.warning("SingleSave inner variable type not valid, "
+                               "key should be string, "
+                               f"value should be one of {support_nested_data_type + support_inner_data_type}, "
+                               f"but got {type(key)}-{type(value)} pair.")
+                continue
+            real_dump_dir = os.path.join(instance.dump_path, "data", key,
+                                         f"step{instance.step_count}", f"rank{instance.rank}")
+            if micro_batch is not None:
+                real_dump_dir = os.path.join(real_dump_dir, f"micro_step{micro_batch}")
+            create_directory(real_dump_dir)
+
+            if isinstance(value, torch.Tensor):
+                result = cls._analyze_tensor_data(value, key, real_dump_dir)
+            elif isinstance(value, (tuple, list)):
+                result = cls._analyze_list_tuple_data(value, key, real_dump_dir)
+            elif isinstance(value, dict):
+                result = cls._analyze_dict_data(value, key, real_dump_dir)
+
+            result_json = {"data": result}
+            json_path = os.path.join(real_dump_dir, key + ".json")
+            cls.save_dict_to_json(result_json, json_path)
+
+    @classmethod
+    def step(cls):
+        instance = cls._instance
+        for key, value in instance.cache_dict.items():
+            if not value["have_micro_batch"]:
+                cls.save_ex({key: value["data"][0]})
+            else:
+                for i, data in enumerate(value["data"]):
+                    cls.save_ex({key: data}, micro_batch=i)
+        instance.cache_dict = {}
+        instance.step_count += 1
+
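+    # Typical per-iteration flow (illustrative; names are placeholders):
+    #
+    #     for micro_batch in micro_batches:
+    #         SingleSave.save({"attn_out": attn_out})  # cached per micro batch
+    #     SingleSave.step()                            # dump cache, step_count += 1
+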
" + "Skip current save process.") + return + for key, value in data.items(): + if key not in instance.cache_dict: + instance.cache_dict[key] = { + "have_micro_batch": False, + "data": [value] + } + else: + instance.cache_dict[key]["have_micro_batch"] = True + instance.cache_dict[key]["data"].append(value) + + + +class SingleComparator: + + def compare(self, dir1, dir2, tag, output_path="./msprobe_compare_output", num_processes=8): + check_file_or_directory_path(dir1, isdir=True) + check_file_or_directory_path(dir2, isdir=True) + if tag == "data": + self.compare_data(dir1, dir2, os.path.join(output_path, "data"), num_processes) + else: + self.compare_model_data(dir1, dir2, tag, output_path) + + def compare_model_data(self, dir1, dir2, tag, output_path): + step_dir1 = os.path.join(dir1, tag) + step_dir2 = os.path.join(dir2, tag) + output_dir = os.path.join(output_path, tag) + compare_distributed(step_dir1, step_dir2, output_dir) + + def compare_tensors(self, tensor1, tensor2): + """ + 比较两个张量,计算最大绝对误差、最大相对误差和相同元素的百分比 + """ + # 计算每个维度上的最小尺寸 + min_shape = [min(s1, s2) for s1, s2 in zip(tensor1.shape, tensor2.shape)] + # 截取张量到相同的形状 + sliced_tensor1 = tensor1[tuple(slice(0, s) for s in min_shape)] + sliced_tensor2 = tensor2[tuple(slice(0, s) for s in min_shape)] + + abs_error = torch.abs(sliced_tensor1 - sliced_tensor2) + max_abs_error = abs_error.max().item() + relative_error = abs_error / torch.max(torch.abs(sliced_tensor1), torch.abs(sliced_tensor2)) + relative_error[torch.isnan(relative_error)] = 0 + max_relative_error = relative_error.max().item() + + same_elements = (sliced_tensor1 == sliced_tensor2).sum().item() + total_elements = sliced_tensor1.numel() + same_percentage = (same_elements / total_elements) * 100 + + # 展平张量 + flat_tensor1 = sliced_tensor1.flatten() + flat_tensor2 = sliced_tensor2.flatten() + + # 计算从第几个元素开始对不上 + first_mismatch = torch.nonzero(flat_tensor1 != flat_tensor2, as_tuple=False) + first_mismatch_index = first_mismatch[0].item() if first_mismatch.numel() > 0 else None + + # 计算误差在千分之一内的元素占比 + error_within_thousandth = (abs_error <= 0.001 * torch.max(torch.abs(sliced_tensor1), torch.abs(sliced_tensor2))).sum().item() + percentage_within_thousandth = (error_within_thousandth / total_elements) * 100 + + # 计算误差在百分之一内的元素占比 + error_within_hundredth = (abs_error <= 0.01 * torch.max(torch.abs(sliced_tensor1), torch.abs(sliced_tensor2))).sum().item() + percentage_within_hundredth = (error_within_hundredth / total_elements) * 100 + + return max_abs_error, max_relative_error, same_percentage, first_mismatch_index, percentage_within_thousandth, percentage_within_hundredth + + def get_steps(self, tag_path): + for step_folder in os.listdir(tag_path): + if step_folder.startswith('step'): + step = int(step_folder[4:]) + yield step, os.path.join(tag_path, step_folder) + + def get_ranks(self, step_path): + for rank_folder in os.listdir(step_path): + if rank_folder.startswith('rank'): + rank = int(rank_folder[4:]) + yield rank, os.path.join(step_path, rank_folder) + + def get_micro_steps(self, rank_path): + for micro_step_folder in os.listdir(rank_path): + if micro_step_folder.startswith('micro_step'): + micro_step = int(micro_step_folder[10:]) + yield micro_step, os.path.join(rank_path, micro_step_folder) + else: + yield 0, rank_path + + def get_tensors(self, micro_step_path): + for file in os.listdir(micro_step_path): + if file.endswith('.pt'): + try: + parts = file.rsplit('.', 2) + if len(parts) > 1 and parts[-2].isdigit(): + tensor_id = int(parts[-2]) + else: + tensor_id = 0 + except 
+    def get_tensors(self, micro_step_path):
+        for file in os.listdir(micro_step_path):
+            if file.endswith('.pt'):
+                try:
+                    parts = file.rsplit('.', 2)
+                    if len(parts) > 1 and parts[-2].isdigit():
+                        tensor_id = int(parts[-2])
+                    else:
+                        tensor_id = 0
+                except ValueError:
+                    tensor_id = 0
+                yield tensor_id, os.path.join(micro_step_path, file)
+
+    def get_tensor_paths(self, dir_path):
+        """
+        Collect the paths of all tensor files under dir_path that follow the
+        expected directory structure.
+        """
+        tensor_paths = {}
+        data_path = os.path.join(dir_path, 'data')
+        if not os.path.exists(data_path):
+            return tensor_paths
+        for tag in os.listdir(data_path):
+            tag_path = os.path.join(data_path, tag)
+            if not os.path.isdir(tag_path):
+                continue
+            for step, step_path in self.get_steps(tag_path):
+                for rank, rank_path in self.get_ranks(step_path):
+                    for micro_step, micro_step_path in self.get_micro_steps(rank_path):
+                        for tensor_id, tensor_path in self.get_tensors(micro_step_path):
+                            tensor_paths.setdefault(tag, []).append(
+                                (step, rank, micro_step, tensor_id, tensor_path))
+        return tensor_paths
+
+    def compare_single_tag(self, tag, tensor_paths1, tensor_paths2, output_dir):
+        try:
+            data = []
+            paths1 = tensor_paths1.get(tag, [])
+            paths2 = tensor_paths2.get(tag, [])
+            path_dict1 = {(step, rank, micro_step, tensor_id): path
+                          for step, rank, micro_step, tensor_id, path in paths1}
+            path_dict2 = {(step, rank, micro_step, tensor_id): path
+                          for step, rank, micro_step, tensor_id, path in paths2}
+            common_keys = set(path_dict1.keys()) & set(path_dict2.keys())
+            for key in common_keys:
+                try:
+                    tensor1 = torch.load(path_dict1[key])
+                    tensor2 = torch.load(path_dict2[key])
+                    result = self.compare_tensors(tensor1, tensor2)
+                    (max_abs_error, max_relative_error, same_percentage, first_mismatch_index,
+                     percentage_within_thousandth, percentage_within_hundredth) = result
+                    step, rank, micro_step, tensor_id = key
+                    data.append([step, rank, micro_step, tensor_id,
+                                 list(tensor1.shape), list(tensor2.shape),
+                                 same_percentage, first_mismatch_index,
+                                 max_abs_error, max_relative_error,
+                                 percentage_within_thousandth, percentage_within_hundredth])
+                except Exception as e:
+                    logger.error(f"Error comparing {path_dict1[key]} and {path_dict2[key]}: {e}")
+
+            df = pd.DataFrame(data, columns=[
+                'step', 'rank', 'micro_step', 'id', 'shape1', 'shape2',
+                'same element pct(%)', 'first mismatch index',
+                'max abs error', 'max relative error',
+                'pct within 1/1000 error(%)', 'pct within 1/100 error(%)'])
+            df = df.sort_values(by=['step', 'rank', 'micro_step', 'id'])
+            # Build the full path of the per-tag output file.
+            output_file_path = os.path.join(output_dir, f'{tag}.xlsx')
+            df.to_excel(output_file_path, index=False)
+        except Exception as e:
+            logger.error(f"Error processing tag {tag}: {e}")
+
+    def compare_data(self, dir1, dir2, output_dir, num_processes=8):
+        """
+        Compare the tensor files in two directories and write the results to
+        Excel files in output_dir, one file per tag.
+        """
+        # Make sure the output directory exists, creating it if necessary.
+        if not os.path.exists(output_dir):
+            os.makedirs(output_dir)
+
+        tensor_paths1 = self.get_tensor_paths(dir1)
+        tensor_paths2 = self.get_tensor_paths(dir2)
+
+        all_tags = set(tensor_paths1.keys()) | set(tensor_paths2.keys())
+
+        with multiprocessing.Pool(processes=num_processes) as pool:
+            args = [(tag, tensor_paths1, tensor_paths2, output_dir) for tag in all_tags]
+            try:
+                results = pool.starmap_async(self.compare_single_tag, args)
+                with tqdm(total=len(all_tags), desc="Processing data") as pbar:
+                    # _number_left is a private Pool attribute; it is used here
+                    # only to drive the progress bar.
+                    while not results.ready():
+                        pbar.n = len(all_tags) - results._number_left
+                        pbar.refresh()
+                        results.wait(1)  # avoid a busy loop while workers run
+                results.get()
+            except Exception as e:
+                logger.error(f"Multiprocessing error: {e}")
+            finally:
+                pool.close()
+                pool.join()
\ No newline at end of file
-- 
Gitee