diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index e832348fb7957f8d7dc624faa04697ee78615470..7473836bc302a27fc2f06063fba85539a87a0033 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -105,7 +105,8 @@ class Const: LEVEL_L1 = "L1" LEVEL_L2 = "L2" LEVEL_MIX = "mix" - LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX] + LEVEL_DEBUG = "debug" + LEVEL_LIST = [LEVEL_L0, LEVEL_L1, LEVEL_L2, LEVEL_MIX, LEVEL_DEBUG] ATTR_NAME_PREFIX = "wrap_" ATTR_NAME_PREFIX_LEN = len(ATTR_NAME_PREFIX) KERNEL_DUMP = "kernel_dump" @@ -123,7 +124,7 @@ class Const: CPU_LOWERCASE = 'cpu' CUDA_LOWERCASE = 'cuda' DISTRIBUTED = 'Distributed' - DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive", + DUMP_PREFIX = ["Distributed", "Functional", "Torch", "Tensor", "Mint", "MintFunctional", "Primitive", "Aten", "VF", "NPU", "Jit"] # struct json param diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py index 0506d0591e943b633b380193cf722a331191b178..d0c3751ac605fabc099312d6caa08dcaea320a7d 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py @@ -147,3 +147,15 @@ class DataCollector: def update_iter(self, current_iter): self.data_processor.update_iter(current_iter) + + def debug_data_collect_forward(self, variable, name_with_count): + self.update_api_or_module_name(name_with_count) + data_info = self.data_processor.analyze_debug_forward(variable) + self.data_writer.update_debug({name_with_count: data_info}) + + def debug_data_collect_backward(self, variable, grad_name_with_count): + # prepare all None nested data structure + all_none_data_info = self.data_processor.analyze_element_to_all_none(variable) + self.data_writer.update_debug({grad_name_with_count: all_none_data_info}) + + self.data_processor.analyze_debug_backward(variable, grad_name_with_count, self.data_writer.cache_debug) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py index 07bc4d6dbc085f3e1d9c94d6a480696e84cd40b3..1a98a9ccffc516ff8f17914948ccbd330433a830 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py @@ -17,6 +17,8 @@ import inspect import os from dataclasses import dataclass, is_dataclass from typing import Tuple, Dict, Optional, Any +from functools import partial +import copy import numpy as np @@ -77,11 +79,12 @@ class ModuleBackwardOutputs: class TensorStatInfo: - def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None): + def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None, name=None): self.max = max_val self.min = min_val self.mean = mean_val self.norm = norm_val + self.name = name class BaseDataProcessor: @@ -102,6 +105,7 @@ class BaseDataProcessor: self.current_iter = 0 self._return_forward_new_output = False self._forward_new_output = None + self.custom_dump_data_name = None if hasattr(config, "data_mode"): self.allowed_data_mode = self._get_allowed_data_mode(config.data_mode) @@ -219,7 +223,7 @@ class BaseDataProcessor: return cls.apply_transform_dict(args_dict, transform, depth) elif isinstance(args, (list, tuple)): result_list = cls.apply_transform_list(args, transform, depth) - return type(args)(result_list) + return (result_list) elif isinstance(args, dict): return cls.apply_transform_dict(args, transform, depth) elif args is not None: @@ -227,7 +231,7 @@ class BaseDataProcessor: return None else: return None - + @classmethod def apply_transform_dict(cls, args, transform, depth): result_dict = {} @@ -279,6 +283,67 @@ class BaseDataProcessor: def analyze_element(self, element): return self.recursive_apply_transform(element, self.analyze_single_element) + def analyze_element_to_all_none(self, element): + return self.recursive_apply_transform(element, self.analyze_element_to_none) + + def analyze_element_to_none(self, element, suffix_stack): + return None + + def analyze_hook_single_element(self, element, suffix_stack, hook_fn): + if hasattr(element, "register_hook"): + if hasattr(element, "requires_grad") and not element.requires_grad: + return + indexes = copy.deepcopy(suffix_stack) + wrap_hook_fn = partial(hook_fn, index=indexes) + def real_hook_fn(grad): + return wrap_hook_fn(grad) + element.register_hook(real_hook_fn) + + @staticmethod + def get_and_set_nested_value(data_structure, index, value): + ''' + Args: + data_structure: nested data structure + index: List[str] + value: value to be set + ''' + if not index: + raise ValueError(f"index need to be non empty when set value to nested data structure") + current_level = data_structure + for i in index[:-1]: + if isinstance(current_level, list): + current_level = current_level[int(i)] + elif isinstance(current_level, dict): + current_level = current_level[i] + else: + raise ValueError(f"Unsupported type in nested structure: {type(current_level)}") + + if isinstance(current_level, list): + current_level[int(index[-1])] = value + elif isinstance(current_level, dict): + current_level[index[-1]] = value + else: + raise ValueError(f"Unsupported type for final assignment: {type(current_level)}") + + def analyze_debug_forward(self, variable): + self.api_data_category = Const.OUTPUT + data_info = self.analyze_element(variable) + return data_info + + def analyze_debug_backward(self, variable, grad_name_with_count, nested_data_structure): + def hook_fn(grad, index): + self.api_data_category = Const.OUTPUT + file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX + suffix = Const.SEP.join(index) + self.custom_dump_data_name = (grad_name_with_count + Const.SEP + Const.OUTPUT + Const.SEP + suffix + file_format) + grad_data_info = self.analyze_element(grad) + self.custom_dump_data_name = None + full_index = [grad_name_with_count] + index + self.get_and_set_nested_value(nested_data_structure, full_index, grad_data_info) + return grad + wrap_analyze_hook_single_element = partial(self.analyze_hook_single_element, hook_fn=hook_fn) + self.recursive_apply_transform(variable, wrap_analyze_hook_single_element) + def analyze_forward(self, name, module, module_input_output: ModuleForwardInputsOutputs): api_info_struct = {} # check whether data_mode contains forward or input @@ -360,6 +425,10 @@ class BaseDataProcessor: return api_info_struct def get_save_file_path(self, suffix): + if self.custom_dump_data_name: + dump_data_name = self.custom_dump_data_name + file_path = os.path.join(self.data_writer.dump_tensor_data_dir, dump_data_name) + return dump_data_name, file_path file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + suffix + file_format) diff --git a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py index e99235977798102e120478c53216e982b2c82e5e..8539a2544808b7a1b78ccfafd802de5887e28470 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/json_writer.py @@ -15,6 +15,7 @@ import csv import os +import copy from msprobe.core.common.const import Const, FileCheckConst from msprobe.core.common.file_utils import change_mode, FileOpen, save_json @@ -29,10 +30,12 @@ class DataWriter: self.construct_file_path = None self.free_benchmark_file_path = None self.dump_tensor_data_dir = None + self.debug_file_path = None self.flush_size = 1000 self.cache_data = {} self.cache_stack = {} self.cache_construct = {} + self.cache_debug = {} @staticmethod def write_data_to_csv(result: list, result_header: tuple, file_path: str): @@ -53,6 +56,7 @@ class DataWriter: self.cache_data = {} self.cache_stack = {} self.cache_construct = {} + self.cache_debug = {} def initialize_json_file(self, **kwargs): if not self.cache_data: @@ -63,14 +67,20 @@ class DataWriter: save_json(self.stack_file_path, self.cache_stack, indent=1) if not self.cache_construct: save_json(self.construct_file_path, self.cache_construct, indent=1) + if self.debug_file_path and not self.cache_debug: + debug_dict = copy.deepcopy(kwargs) + debug_dict.update({"dump_data_dir": self.dump_tensor_data_dir, Const.DATA: {}}) + self.cache_debug = debug_dict + save_json(self.debug_file_path, self.cache_debug, indent=1) def update_dump_paths(self, dump_file_path, stack_file_path, construct_file_path, dump_data_dir, - free_benchmark_file_path): + free_benchmark_file_path, debug_file_path=None): self.dump_file_path = dump_file_path self.stack_file_path = stack_file_path self.construct_file_path = construct_file_path self.dump_tensor_data_dir = dump_data_dir self.free_benchmark_file_path = free_benchmark_file_path + self.debug_file_path = debug_file_path def flush_data_periodically(self): dump_data = self.cache_data.get(Const.DATA) @@ -98,6 +108,9 @@ class DataWriter: def update_construct(self, new_data): self.cache_construct.update(new_data) + def update_debug(self, new_data): + self.cache_debug.update(new_data) + def write_data_json(self, file_path): logger.info(f"dump.json is at {os.path.dirname(os.path.dirname(file_path))}. ") save_json(file_path, self.cache_data, indent=1) @@ -108,6 +121,9 @@ class DataWriter: def write_construct_info_json(self, file_path): save_json(file_path, self.cache_construct, indent=1) + def write_debug_info_json(self, file_path): + save_json(file_path, self.cache_debug, indent=1) + def write_json(self): if self.cache_data: self.write_data_json(self.dump_file_path) @@ -115,3 +131,5 @@ class DataWriter: self.write_stack_info_json(self.stack_file_path) if self.cache_construct: self.write_construct_info_json(self.construct_file_path) + if self.cache_debug: + self.write_debug_info_json(self.debug_file_path) diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index d928e0a3a62504f07ffb8807a43d405773921861..5d40d9a7e75ce85ee5a68c918d9ac0392431818b 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -145,6 +145,20 @@ class PrecisionDebugger: return instance.gm.monitor(opt) + @classmethod + def save(cls, variable, name, save_backward=True): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.task not in [Const.TENSOR, Const.STATISTICS]: + return + + instance.config.execution_mode = cls._get_execution_mode() + if cls._need_service(): + if not instance.service: + instance.service = Service(instance.config) + instance.service.save(variable, name, save_backward) + @classmethod def _need_service(cls): instance = cls._instance diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 9217bd3efeb1436f33ad57661e2f7ac439660906..b4e0603e6b1e47eedb302f3362db91c86a41007e 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -59,6 +59,8 @@ class Service: self.start_call = False self.check_level_valid() self.should_stop_service = False + if self.config.level == Const.LEVEL_DEBUG: + self.init_for_debug_level() @staticmethod def check_model_valid(model): @@ -150,6 +152,8 @@ class Service: primitive.__class__ = new_primitive def step(self): + if self.config.level == Const.LEVEL_DEBUG: + return self.current_iter += 1 self.data_collector.update_iter(self.current_iter) self.primitive_hook_service.primitive_counters.clear() @@ -157,6 +161,8 @@ class Service: JitDump.jit_count = defaultdict(int) def start(self, model=None): + if self.config.level == Const.LEVEL_DEBUG: + return self.start_call = True if self.should_stop_service: return @@ -220,6 +226,8 @@ class Service: JitDump.jit_dump_switch = False def stop(self): + if self.config.level == Const.LEVEL_DEBUG: + return if self.should_stop_service: return logger.info(f"{Const.TOOL_NAME}: debugger.stop() is set successfully. " @@ -237,6 +245,49 @@ class Service: self.data_collector.write_json() JitDump.jit_dump_switch = False + def save(self, variable, name, save_backward): + if self.config.level != Const.LEVEL_DEBUG: + return + count = self.debug_variable_counter[name] + self.debug_variable_counter[name] += 1 + + name_with_count = f"{name}.{count}" + grad_name_with_count = f"{name}_grad.{count}" + + # forward save + self.data_collector.debug_data_collect_forward(variable, name_with_count) + + # backward save + if save_backward: + self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) + + def init_for_debug_level(self): + try: + self.current_rank = get_rank_if_initialized() + except DistributedNotInitializedError: + self.current_rank = None + + # dir: dump_path -- rank{} -- debug.json + create_directory(self.config.dump_path) + self.dump_iter_dir = self.config.dump_path + cur_rank = self.current_rank if self.current_rank is not None else '' + dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") + create_directory(dump_dir) + if self.config.task in self.data_collector.tasks_need_tensor_data: + dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") + create_directory(dump_data_dir) + else: + dump_data_dir = None + + dump_file_path = os.path.join(dump_dir, "dump.json") + stack_file_path = os.path.join(dump_dir, "stack.json") + construct_file_path = os.path.join(dump_dir, "construct.json") + debug_file_path = os.path.join(dump_dir, "debug.json") + self.data_collector.update_dump_paths( + dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None, debug_file_path) + + self.debug_variable_counter = defaultdict(int) + def need_end_service(self): if self.config.step and self.current_iter > max(self.config.step): return True diff --git a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py index 72890761e9550adc1709fced46c38ddbf3f6e8d8..bd03ef5aea0aa6389ee171ad70dfd9db877af36d 100644 --- a/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/pytorch/debugger/precision_debugger.py @@ -123,6 +123,15 @@ class PrecisionDebugger: instance.service.start(instance.model, instance.api_origin) instance.api_origin = False + @classmethod + def save(cls, variable, name, save_backward=True): + instance = cls._instance + if not instance: + raise Exception(MsgConst.NOT_CREATED_INSTANCE) + if instance.task not in [Const.TENSOR, Const.STATISTICS]: + return + instance.service.save(variable, name, save_backward) + # 指定代码段dump前反向结束符,之后的计算过程数据将被忽略,无法被dump @classmethod def forward_backward_dump_end(cls): diff --git a/debug/accuracy_tools/msprobe/pytorch/service.py b/debug/accuracy_tools/msprobe/pytorch/service.py index 04f9136e43005d18f9eede6a84a0f2e89d546268..54071502fe9b0eadc71530ea1ca523c36b37bf93 100644 --- a/debug/accuracy_tools/msprobe/pytorch/service.py +++ b/debug/accuracy_tools/msprobe/pytorch/service.py @@ -15,6 +15,7 @@ import functools import os +from collections import defaultdict from collections import namedtuple import torch @@ -53,6 +54,8 @@ class Service: self.dump_iter_dir = None self.should_stop_service = False self.attl = None + if self.config.level == Const.LEVEL_DEBUG: + self.init_for_debug_level() @staticmethod def forward_backward_dump_end(): @@ -141,6 +144,8 @@ class Service: return HookFn(pre_forward_hook_fn, forward_hook_fn, backward_hook_fn, forward_hook_torch_version_below_2_fn) def start(self, model, api_origin=False): + if self.config.level == Const.LEVEL_DEBUG: + return if self.need_stop_service(): return @@ -169,7 +174,7 @@ class Service: def stop(self): if self.should_stop_service: return - if self.config.level == "L2": + if self.config.level in [Const.LEVEL_DEBUG, Const.LEVEL_L2]: return if self.config.step and self.current_iter not in self.config.step: return @@ -182,6 +187,8 @@ class Service: self.data_collector.write_json() def step(self): + if self.config.level == Const.LEVEL_DEBUG: + return if self.should_stop_service: return self.current_iter += 1 @@ -236,6 +243,52 @@ class Service: self.data_collector.update_dump_paths( dump_file_path, stack_file_path, construct_file_path, dump_data_dir, free_benchmark_file_path) + def save(self, variable, name, save_backward): + if self.config.level != Const.LEVEL_DEBUG: + return + count = self.debug_variable_counter[name] + self.debug_variable_counter[name] += 1 + + name_with_count = f"{name}.{count}" + grad_name_with_count = f"{name}_grad.{count}" + + # forward save + self.data_collector.debug_data_collect_forward(variable, name_with_count) + + # backward save + if save_backward: + self.data_collector.debug_data_collect_backward(variable, grad_name_with_count) + + + def init_for_debug_level(self): + try: + self.current_rank = get_rank_if_initialized() + except DistributedNotInitializedError: + self.current_rank = None + + # dir: dump_path -- rank{} -- debug.json + create_directory(self.config.dump_path) + self.dump_iter_dir = self.config.dump_path + cur_rank = self.current_rank if self.current_rank is not None else '' + dump_dir = os.path.join(self.dump_iter_dir, f"rank{cur_rank}") + create_directory(dump_dir) + if self.config.task in self.data_collector.tasks_need_tensor_data: + dump_data_dir = os.path.join(dump_dir, "dump_tensor_data") + create_directory(dump_data_dir) + else: + dump_data_dir = None + + dump_file_path = os.path.join(dump_dir, "dump.json") + stack_file_path = os.path.join(dump_dir, "stack.json") + construct_file_path = os.path.join(dump_dir, "construct.json") + free_benchmark_file_path = os.path.join(self.config.dump_path, "free_benchmark.csv") + debug_file_path = os.path.join(dump_dir, "debug.json") + self.data_collector.update_dump_paths( + dump_file_path, stack_file_path, construct_file_path, dump_data_dir, + free_benchmark_file_path, debug_file_path) + + self.debug_variable_counter = defaultdict(int) + def register_hook_new(self): logger.info_on_rank_0("The {} hook function is successfully mounted to the model.".format(self.config.task)) if self.config.level in ["L0", "mix"]: