diff --git a/.gitignore b/.gitignore
index c70c40e0f527c8c20a6bf994bcb8070b95e13e27..01a2222429c34b3a10f5311f2316fa5a8b685b18 100644
--- a/.gitignore
+++ b/.gitignore
@@ -142,4 +142,11 @@ cython_debug/
 att_advisor*.html
 *.xlsx
 operator_tuning_file*.cfg
-.ipynb_checkpoints/
\ No newline at end of file
+.ipynb_checkpoints/
+.idea/vcs.xml
+.idea/inspectionProfiles/profiles_settings.xml
+.idea/misc.xml
+.idea/modules.xml
+.idea/mstt_primitive.iml
+.idea/.gitignore
+.gitignore
diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py
index c1a453a21a6c2f8f30f22812214e2a6e4fc53932..eff7b8be8adfee36275027a56f23925aae63b14d 100644
--- a/debug/accuracy_tools/msprobe/core/common/const.py
+++ b/debug/accuracy_tools/msprobe/core/common/const.py
@@ -16,6 +16,7 @@ class Const:
     OFF = 'OFF'
     BACKWARD = 'backward'
     FORWARD = 'forward'
+    PRIMITIVE_PREFIX = 'Primitive'
     DEFAULT_LIST = []
     DEFAULT_PATH = './'
     WHITE_LIST = 'white_list'
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
index db437539afeb98050ce59aad87a1e79d98b84085..de2b93c206d4db32bf4c3daeba00fa8f708bd6bf 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_collector.py
@@ -106,6 +106,22 @@ class DataCollector:
             raise Exception("[msprobe] exit")
         self.handle_data(name, data_info)
 
+    def backward_input_data_collect(self, name, module, pid, module_input_output):
+        self.update_construct(name)
+        if not self.check_scope_and_pid(self.scope, name, pid):
+            return
+
+        data_info = self.data_processor.analyze_backward_input(name, module, module_input_output)
+        self.handle_data(name, data_info)
+
+    def backward_output_data_collect(self, name, module, pid, module_input_output):
+        self.update_construct(name)
+        if not self.check_scope_and_pid(self.scope, name, pid):
+            return
+
+        data_info = self.data_processor.analyze_backward_output(name, module, module_input_output)
+        self.handle_data(name, data_info)
+
     def update_construct(self, name):
         if self.config.level not in DataCollector.level_without_construct:
             self.data_writer.update_construct({name: self.module_processor.api_parent_node})
diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
index 2fbc86b5656c3bcfe14b2fe9fe6bb295451e9466..13134d6198034c79fac19e993229dd6c4ca9caa5 100644
--- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
+++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/base.py
@@ -40,6 +40,21 @@ class ModuleBackwardInputsOutputs:
     def grad_output_tuple(self):
         return convert_tuple(self.grad_output)
 
+@dataclass
+class ModuleBackwardInputs:
+    grad_input: Optional[Tuple]
+
+    @property
+    def grad_input_tuple(self):
+        return convert_tuple(self.grad_input)
+
+@dataclass
+class ModuleBackwardOutputs:
+    grad_output: Optional[Tuple]
+
+    @property
+    def grad_output_tuple(self):
+        return convert_tuple(self.grad_output)
 
 class TensorStatInfo:
     def __init__(self, max_val=None, min_val=None, mean_val=None, norm_val=None):
@@ -228,6 +243,32 @@
 
         return api_info_struct
 
+    def analyze_backward_input(self, name, module, module_input_output: ModuleBackwardInputs):
+        """
+        Analyze and save backward input gradients.
+ """ + api_info_struct = {} + if self.is_dump_for_data_mode(Const.BACKWARD, Const.INPUT): + api_info_struct[name] = {} + self.api_data_category = Const.OUTPUT + # self.api_data_category = Const.INPUT + output_info_list = self.analyze_element(module_input_output.grad_input_tuple) + api_info_struct[name][Const.GRAD_OUTPUT] = output_info_list + return api_info_struct + + def analyze_backward_output(self, name, module, module_input_output: ModuleBackwardInputsOutputs): + """ + Analyze and save backward output gradients. + """ + api_info_struct = {} + if self.is_dump_for_data_mode(Const.BACKWARD, Const.OUTPUT): + api_info_struct[name] = {} + self.api_data_category = Const.INPUT + # self.api_data_category = Const.OUTPUT + input_info_list = self.analyze_element(module_input_output.grad_output_tuple) + api_info_struct[name][Const.GRAD_INPUT] = input_info_list + return api_info_struct + def get_save_file_path(self, suffix): file_format = Const.PT_SUFFIX if self.config.framework == Const.PT_FRAMEWORK else Const.NUMPY_SUFFIX dump_data_name = (self.current_api_or_module_name + Const.SEP + self.api_data_category + Const.SEP + diff --git a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py index c208df7d900683197fc24081b42835716ce7605f..b28817e4aa7869d0a62fae42fa9df96c2543c574 100644 --- a/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py +++ b/debug/accuracy_tools/msprobe/core/data_dump/data_processor/mindspore_processor.py @@ -74,8 +74,9 @@ class MindsporeDataProcessor(BaseDataProcessor): if data.numel() == 0: return tensor_stat elif data.dtype == ms.bool_: - tensor_stat.max = self.mint_ops_func["max"](data).item() - tensor_stat.min = self.mint_ops_func["min"](data).item() + data_np = data.asnumpy() + tensor_stat.max = np.max(data_np) + tensor_stat.min = np.min(data_np) elif not data.shape: tensor_stat.max = tensor_stat.min = tensor_stat.mean = tensor_stat.norm = data.item() elif data.dtype == ms.complex64 or data.dtype == ms.complex128: diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 5475dc3586c35687fec63b51f265ac83c0d33a87..40b44c57ec9f0833502ca8150112f88519ea96de 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -41,7 +41,7 @@ class PrecisionDebugger: return MsConst.PYNATIVE_MODE @classmethod - def start(cls): + def start(cls, target=None): instance = cls._instance if not instance: raise Exception("No instance of PrecisionDebugger found.") @@ -50,7 +50,7 @@ class PrecisionDebugger: if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.level == MsConst.API: if not instance.service: instance.service = Service(instance.config) - instance.service.start() + instance.service.start(target) else: if not instance.first_start: handler = TaskHandlerFactory.create(instance.config) diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 50776aaf1097339e7c6d98944db7ddf2d2238c5f..b795ec1034290f1c2bd7ca7b77b5d6bf689dce16 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -19,6 +19,9 @@ from pathlib import Path import functools from collections import defaultdict +from 
+from mindspore import ops
+from mindspore import nn
 from msprobe.core.data_dump.data_collector import build_data_collector
 from msprobe.core.data_dump.scope import BaseScope
 from msprobe.mindspore.common.utils import get_rank_if_initialized
@@ -27,7 +30,9 @@
 from msprobe.mindspore.common.log import logger
 from msprobe.core.common.utils import Const
 from msprobe.core.common.exceptions import DistributedNotInitializedError
 from msprobe.mindspore.dump.hook_cell.api_registry import api_register
-from msprobe.core.data_dump.data_processor.base import ModuleForwardInputsOutputs, ModuleBackwardInputsOutputs
+from msprobe.core.data_dump.data_processor.base import (ModuleBackwardInputsOutputs, ModuleForwardInputsOutputs,
+                                                        ModuleBackwardInputs, ModuleBackwardOutputs)
+from msprobe.core.common.exceptions import MsprobeException
 from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell
 
@@ -41,6 +46,7 @@ class Service:
         self.current_iter = 0
         self.first_start = True
         self.current_rank = None
+        self.primitive_counters = {}
         self.dump_iter_dir = None
         self.start_call = False
 
@@ -79,13 +85,154 @@
 
         return wrap_forward_hook, wrap_backward_hook
 
+    def wrap_primitive(self, origin_func, primitive_name):
+        service_instance = self
+
+        def create_backward_hook(captured_grads, num_tensors, updated_primitive_name, hook_type):
+            def backward_hook(grad):
+                captured_grads.append(grad)
+                try:
+                    if len(captured_grads) == num_tensors and hook_type == Const.INPUT:
+                        backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+                        service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+                        new_module_input_output = ModuleBackwardInputs(grad_input=tuple(captured_grads))
+                        service_instance.data_collector.backward_input_data_collect(
+                            backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+                        )
+                        captured_grads.clear()
+                    elif len(captured_grads) == num_tensors and hook_type == Const.OUTPUT:
+                        backward_primitive_name = f"{updated_primitive_name}.{Const.BACKWARD}"
+                        service_instance.data_collector.visit_and_clear_overflow_status(backward_primitive_name)
+                        new_module_input_output = ModuleBackwardOutputs(grad_output=tuple(captured_grads))
+                        service_instance.data_collector.backward_output_data_collect(
+                            backward_primitive_name, service_instance, os.getpid(), new_module_input_output
+                        )
+                        captured_grads.clear()
+                except Exception as exception:
+                    raise Exception(
+                        "This is a primitive op {hook_type}_backward dump error: {exception},"
+                        " updated_primitive_name: {updated_primitive_name}".format(
+                            hook_type=hook_type, exception=exception, updated_primitive_name=updated_primitive_name
+                        )
+                    ) from exception
+
+            return backward_hook
+
+        def hook_primitive_inputs(args, captured_grads_input, updated_primitive_name):
+            hooked_inputs = []
+            num_tensors = sum(isinstance(arg, Tensor) for arg in args)
+            input_backward_hook = create_backward_hook(captured_grads_input, num_tensors, updated_primitive_name,
+                                                       Const.INPUT)
+            for arg in args:
+                if isinstance(arg, Tensor):
+                    arg_hooked = ops.HookBackward(input_backward_hook)(arg)
+                    hooked_inputs.append(arg_hooked)
+                else:
+                    hooked_inputs.append(arg)
+            return hooked_inputs
+
+        def hook_primitive_outputs(out, captured_grads_output, updated_primitive_name):
+            if isinstance(out, tuple):
+                num_output_tensors = sum(isinstance(tensor, Tensor) for tensor in out)
+            else:
+                num_output_tensors = 1
+            output_backward_hook = create_backward_hook(captured_grads_output, num_output_tensors,
+                                                        updated_primitive_name, Const.OUTPUT)
+
+            if isinstance(out, Tensor):
+                return ops.HookBackward(output_backward_hook)(out)
+            elif isinstance(out, tuple):
+                hooked_outputs = []
+                for tensor in out:
+                    if isinstance(tensor, Tensor):
+                        hooked_outputs.append(ops.HookBackward(output_backward_hook)(tensor))
+                    else:
+                        hooked_outputs.append(tensor)
+                return tuple(hooked_outputs)
+            return out
+
+        def wrapped_primitive_call(instance_self, *args, **kwargs):
+            service_instance.update_primitive_counters(primitive_name)
+            current_count = service_instance.primitive_counters[primitive_name]
+            updated_primitive_name = f"{Const.PRIMITIVE_PREFIX}.{primitive_name}.{current_count}"
+
+            if not service_instance.switch:
+                return origin_func(*args, **kwargs)
+
+            captured_grads_input, captured_grads_output = [], []
+
+            try:
+                hooked_inputs = hook_primitive_inputs(args, captured_grads_input, updated_primitive_name)
+            except Exception as exception:
+                raise Exception("This is a primitive op dump error during input hooking: {},"
+                                " primitive_name: {}".format(exception, primitive_name)) from exception
+
+            try:
+                out = origin_func(*hooked_inputs, **kwargs)
+            except Exception as exception:
+                raise Exception("This is a primitive op dump error during function call: {},"
+                                " primitive_name: {}".format(exception, primitive_name)) from exception
+
+            forward_primitive_name = f"{updated_primitive_name}.{Const.FORWARD}"
+            service_instance.data_collector.visit_and_clear_overflow_status(forward_primitive_name)
+            if service_instance.data_collector:
+                module_input_output = ModuleForwardInputsOutputs(args=hooked_inputs, kwargs=kwargs, output=out)
+                try:
+                    service_instance.data_collector.forward_data_collect(forward_primitive_name, instance_self,
+                                                                         os.getpid(), module_input_output)
+                except Exception as exception:
+                    raise Exception("This is a primitive op dump error during forward data collection: {},"
+                                    " primitive_name: {}".format(exception, primitive_name)) from exception
+
+                if service_instance.data_collector.if_return_forward_new_output():
+                    out = service_instance.data_collector.get_forward_new_output()
+
+            try:
+                out = hook_primitive_outputs(out, captured_grads_output, updated_primitive_name)
+            except Exception as exception:
+                raise Exception("This is a primitive op dump error during output hooking: {},"
+                                " primitive_name: {}".format(exception, primitive_name)) from exception
+
+            return out
+
+        return wrapped_primitive_call
+
+    def update_primitive_counters(self, primitive_name):
+        if primitive_name not in self.primitive_counters:
+            self.primitive_counters[primitive_name] = 0
+        else:
+            self.primitive_counters[primitive_name] += 1
+
+    def register_hooks(self):
+        primitive_set = set()
+        for _, cell in self.model.cells_and_names():
+            for pname, primitive in cell._primitives.items():
+                primitive_set.add((pname, primitive))
+
+        for pname, primitive in primitive_set:
+            NewPrimitive = type('NewPrimitive', (primitive.__class__,),
+                                {'__call__': self.wrap_primitive(primitive.__call__, pname)})
+            primitive.__class__ = NewPrimitive
+
     def step(self):
         self.current_iter += 1
         self.data_collector.update_iter(self.current_iter)
         HOOKCell.cell_count = defaultdict(int)
+        self.primitive_counters.clear()
+
+    @staticmethod
+    def check_model_valid(model):
+        if not model or isinstance(model, nn.Cell):
+            return model
+        raise MsprobeException(
+            MsprobeException.INVALID_PARAM_ERROR, "The 'model' parameter must be a mindspore.nn.Cell instance."
+        )
 
     def start(self, model=None):
-        self.model = model
+        self.model = Service.check_model_valid(model)
         self.start_call = True
         logger.info("msprobe: debugger.start() is set successfully")
         if self.config.step and self.current_iter > max(self.config.step):
@@ -150,3 +297,5 @@
         if self.config.level == "L1":
             api_register.initialize_hook(functools.partial(self.build_hook, BaseScope.Module_Type_API))
             api_register.api_set_hook_func()
+        if self.model:
+            self.register_hooks()
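
Editor's note, not part of the diff: a minimal standalone sketch of the mechanism wrap_primitive relies on. mindspore.ops.HookBackward wraps a tensor and calls the given hook with the gradient of that tensor during backpropagation in PyNative mode, which is how captured_grads_input / captured_grads_output get filled before being handed to backward_input_data_collect / backward_output_data_collect. The names grad_hook and forward below are illustrative only.

import numpy as np
import mindspore as ms
from mindspore import Tensor, ops

ms.set_context(mode=ms.PYNATIVE_MODE)  # HookBackward only fires in PyNative mode

captured = []

def grad_hook(grad):
    # called during backprop with the gradient w.r.t. the hooked tensor
    captured.append(grad)

def forward(x, y):
    x = ops.HookBackward(grad_hook)(x)  # same wrapping as hook_primitive_inputs does per Tensor argument
    return x * y

x = Tensor(np.array([1.0, 2.0], np.float32))
y = Tensor(np.array([3.0, 4.0], np.float32))
ms.grad(forward, grad_position=(0, 1))(x, y)
print(captured)  # one gradient per hooked tensor, mirroring captured_grads_input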
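Editor's note, hedged usage sketch of how the new start(target) path is expected to be driven. The config_path constructor argument, MyNet, train_step and the dataset loop are assumptions for illustration, not defined by this diff; stop()/step() mirror the existing debugger API.

import mindspore as ms
from msprobe.mindspore import PrecisionDebugger

ms.set_context(mode=ms.PYNATIVE_MODE)                       # the service path used here requires PyNative mode
debugger = PrecisionDebugger(config_path="./config.json")   # assumed existing constructor

net = MyNet()                                               # any mindspore.nn.Cell; anything else now raises MsprobeException
debugger.start(net)                                         # target is forwarded to Service.start, which calls register_hooks()

for data, label in dataset:                                 # placeholder training loop
    loss = train_step(net, data, label)                     # Primitive.<name>.<count>.forward/backward entries are dumped here

debugger.stop()
debugger.step()                                             # also clears primitive_counters for the next iteration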