From 3c5ea1030940e7ee5eb282f4c89dfae09b949407 Mon Sep 17 00:00:00 2001 From: Henry Date: Wed, 7 Aug 2024 11:48:07 +0800 Subject: [PATCH] pynative + jit --- .../mindspore/debugger/precision_debugger.py | 2 +- .../mindspore/dump/dump_tool_factory.py | 2 +- .../mindspore/dump/hook_cell/api_registry.py | 2 +- .../msprobe/mindspore/dump/jit_dump.py | 332 ++++++++++++++++++ .../msprobe/mindspore/service.py | 20 +- .../mindspore_ut/debugger/test_jit_dump.py | 125 +++++++ 6 files changed, 479 insertions(+), 4 deletions(-) create mode 100644 debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py create mode 100644 debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_jit_dump.py diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index d2a5e8d2de..0b3fec632d 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -61,7 +61,7 @@ class PrecisionDebugger: return instance.config.execution_mode = instance._get_execution_mode() - if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.task != Const.FREE_BENCHMARK: + if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.task != Const.FREE_BENCHMARK and instance.config.level != "kernel": if not instance.service: instance.service = Service(instance.config) instance.service.start(target) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py index 1e4b06a387..138dcb60d4 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py @@ -19,7 +19,7 @@ class DumpToolFactory: Const.KERNEL: { Const.GRAPH_KBYK_MODE: KernelKbykDump, Const.GRAPH_GE_MODE: KernelGraphDump, - Const.PYNATIVE_MODE: None + Const.PYNATIVE_MODE: KernelKbykDump } } diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py index 5508416fde..584bf91c1f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/hook_cell/api_registry.py @@ -98,7 +98,7 @@ class ApiRegistry: self.mint_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintOP, attr_name) for attr_name in dir(HOOKMintNNFunctionalOP): if attr_name.startswith(Const.ATTR_NAME_PREFIX): - self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name) + self.mint_func_ops_hook_attr[attr_name[Const.ATTR_NAME_PREFIX_LEN:]] = getattr(HOOKMintNNFunctionalOP, attr_name) api_register = ApiRegistry() diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py new file mode 100644 index 0000000000..d710106146 --- /dev/null +++ b/debug/accuracy_tools/msprobe/mindspore/dump/jit_dump.py @@ -0,0 +1,332 @@ +import os +import json +import stat +import math +import inspect +import hashlib +import numpy as np +import mindspore as ms +from mindspore import ops +from mindspore.common import mutable +from mindspore.common import dtype as mstype +from mindspore.common.api import _MindsporeFunctionExecutor +from mindspore.ops import constexpr +from mindspore.ops._primitive_cache import _get_cache_prim +from mindspore.ops.operations.math_ops import NPUGetFloatStatusV2, NPUClearFloatStatusV2 +from mindspore._c_expression import MSContext, PyNativeExecutor_ +from msprobe.mindspore.common.log import logger +from msprobe.mindspore.dump.hook_cell.api_registry import api_register + + + +global unsupport_count +class DataInfo(object): + def __init__(self, save_data, summary_data, dtype, shape, md5_nume, l2norm): + self.save_data = save_data + self.summary_data = summary_data + self.dtype = dtype + self.shape = shape + self.md5_nume = md5_nume + self.l2norm = l2norm + + +@constexpr +def _ascend_target(self): + return ms.context.get_context("device_target") == "Ascend" + + +@constexpr +def _ascend_910a_target(self): + return MSContext.get_instance().get_ascend_soc_version() == "ascend910" + + +@constexpr +def _ascend_910b_target(self): + return MSContext.get_instance().get_ascend_soc_version() == "ascend910b" + + +@constexpr +def _gpu_target(self): + return ms.context.get_context("device_target") == "GPU" + + +def _overflow(inputs): + if _gpu_target(): + return ops.FloatStatus()(inputs) + status = ops.isfinite(inputs) + return 1 - status.all() + + +def _all_finite(inputs, check_overflow_mode): + """all finite check""" + if _ascend_target(): + if (_ascend_910a_target()) or \ + (_ascend_910b_target() and check_overflow_mode == "SATURATION_MODE"): + status = ms.Tensor([0] * 8, mstype.int32) + status = ms.ops.depend(status, inputs) + get_status = _get_cache_prim(NPUGetFloatStatusV2)()(status) + status = ms.ops.depend(status, get_status) + clear_status = _get_cache_prim(NPUClearFloatStatusV2)()(status) + get_status = ms.ops.depend(get_status, clear_status) + status_finite = get_status.equal(ms.Tensor(0, mstype.int32)).all() + return status_finite + + status_finite = False + outputs = ms.ops.HyperMap()(ops.Partial()(_overflow), inputs) + flag_sum = ms.ops.addn(outputs).reshape(()) + status_finite = ms.ops.less(flag_sum, 1) + return status_finite + + +def all_finite(inputs): + inputs = mutable(inputs) + _check_overflow_mode = os.environ.get('MS_ASCEND_CHECK_OVERFLOW_MODE') + return _all_finite(inputs, _check_overflow_mode) + + +def check_overflow(out_feat): + return all_finite((out_feat,)) + + +def dump_jit(name, in_feat, out_feat, is_forward): + ori_args = str(type(name)) + index = ori_args.find("__main__.") + if index!= -1: + result = ori_args[(index + len("__main__.")):-2] + if is_forward: + name_template = "jit_" + result + "_forward" + else: + name_template = "jit_" + result + "_backward" + dump_path = JitDump.dump_config.dump_path + step_num = "step" + str(JitDump.dump_step) + step_path = os.path.join(dump_path, step_num) + file_path = os.path_join(step_path, "rank") + dump_stack_file = os.path.join(file_path, "jit_stack.json") + if JitDump.dump_config.task == "overflow": + if not check_overflow(out_feat): + dump_stack_info(name_template, dump_stack_file) + dump_api_tensor(in_feat, name_template, out_feat, file_path) + dump_stack_info(name_template, dump_stack_file) + dump_api_tensor(in_feat, name_template, out_feat, file_path) + + +def dump_api_tensor(in_feat, name_template, out_feat, dump_file): + if in_feat is not None: + dump_tensor(in_feat, name_template + "_input", JitDump.dump_step, dump_file, JitDump.dump_config.task) + dump_tensor(out_feat, name_template + "_output", JitDump.dump_step, dump_file, JitDump.dump_config.task) + + +def complex_squre(A): + if ops.is_complex(A): + return ops.conj(A) * A + return ops.square(A) + + +def myl2norm(A): + ndim = A.ndim + dim = tuple(range(ndim)) + ret = ops.sqrt(ops.reduce_sum(complex_squre(A), dim)) + return ret + + +def is_unsupport_type(data): + unsupport_dtype = [ms.uint16, ms.uint32, ms.uint64, ms.complex64, ms.complex128] + is_unsupported_type = False + if data.dtype in unsupport_dtype: + is_unsupported_type = True + if unsupport_count == 0: + logger.user_attention(f"On the Acend platform, {unsupport_dtype}data types are currently not supported for calculating statistical values.") + unsupport_count += 1 + return is_unsupported_type + return is_unsupported_type + + +def cal_l2norm(data): + Key_ops = "wrap_ops." + if is_unsupport_type(data): + l2norm = None + return l2norm + + l2norm = myl2norm(data) + if l2norm.dtype == ms.bfloat16: + l2norm = ops.Cast()(l2norm, dtype = ms.float32) + l2norm = l2norm.tolist() + return l2norm + + +def cal_max(data): + if is_unsupport_type(data): + tensor_max = None + return tensor_max + + if(hasattr(ms,'mint')): + max = ms.mint.max + else: + max = ops.max + tensor_max = max(data).astype(ms.float32).tolist() + return tensor_max + + +def cal_min(data): + if is_unsupport_type(data): + tensor_min = None + return tensor_min + + if(hasattr(ms,'mint')): + min = ms.mint.min + else: + min = ops.min + tensor_min = min(data).astype(ms.float32).tolist() + return tensor_min + + +def cal_mean(data): + if is_unsupport_type(data): + tensor_mean = None + return tensor_mean + + if(hasattr(ms,'mint')): + mean = ms.mint.mean + else: + mean = ops.mean + tensor_mean = mean(data).astype(ms.float32).tolist() + return tensor_mean + + +def get_not_float_tensor_info(data, compute_summary): + saved_tensor = data.asnumpy() + tensor_max, tensor_min, tensor_mean = math.nan, math.nan, math.nan + if compute_summary: + if saved_tensor.size == 0 or saved_tensor.dtype == np.bool_: + pass + elif len(saved_tensor.shape) == 0: + tensor_max = saved_tensor.astype(np.float32).tolist() + tensor_min = saved_tensor.astype(np.float32).tolist() + tensor_mean = saved_tensor.astype(np.float32).tolist() + else: + tensor_max = cal_max(data) + tensor_min = cal_min(data) + tensor_mean = cal_mean(data) + summary_data = [tensor_max, tensor_min, tensor_mean] + md5_nume = hashlib.md5(saved_tensor).hexdigest() + l2norm = cal_l2norm(data) + return DataInfo(saved_tensor, summary_data, str(data.dtype), tuple(data.shape), md5_nume, l2norm) + summary_data = [tensor_max, tensor_min, tensor_mean] + return DataInfo(saved_tensor, summary_data, str(data.dtype), tuple(data.shape), [], []) + + +def get_scalar_data_info(data, compute_summary): + if compute_summary: + summary_data = [data, data, data] + md5_nume = hashlib.md5(str(data).encode()).hexdigest() + l2norm = np.linalg.norm(data).item() + return DataInfo(data, summary_data, str(type(data)), [], md5_nume, l2norm) + else: + summary_data = [math.nan] * 3 + return DataInfo(data, summary_data, str(type(data)), [], [], []) + + +def get_float_tensor_info(data, compute_summary): + dtype = str(data.dtype) + tensor_max, tensor_min, tensor_mean = math.nan, math.nan, math.nan + if compute_summary: + tensor_max = cal_max(data) + tensor_min = cal_min(data) + tensor_mean = cal_mean(data) + summary_data = [tensor_max, tensor_min, tensor_mean] + l2norm = cal_l2norm(data) + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) + saved_tensor = data.asnumpy() + md5_nume = hashlib.md5(saved_tensor).hexdigest() + return DataInfo(saved_tensor, summary_data, dtype, tuple(data.shape), md5_nume, l2norm) + summary_data = [tensor_max, tensor_min, tensor_mean] + if data.dtype == mstype.bfloat16: + data = ops.Cast()(data, dtype=mstype.float32) + saved_tensor = data.asnumpy() + return DataInfo(saved_tensor, summary_data, dtype, tuple(data.shape), [], []) + + +def dump_tensor(x, prefix, jit_num, file_path, dump_type): + compute_summary = True if dump_type in ['tensor', 'statistics'] else False + + if isinstance(x, (tuple, list)) and x: + for i, item in enumerate(x): + dump_tensor(item, "{}.{}".format(prefix, i), jit_num, file_path, dump_type) + elif isinstance(x, ms.Tensor): + dump_flag = True + if x.numel() == 0 or len(x.shape) == 0 or not x.is_floating_point(): + data_info_func = get_not_float_tensor_info + else: + data_info_func = get_float_tensor_info + + if dump_flag: + data_info = data_info_func(x, compute_summary) + dump_data(file_path, prefix, data_info) + else: + if isinstance(x, bool) or isinstance(x, int) or isinstance(x, float): + data_info = get_scalar_data_info(x) + dump_data(file_path, prefix, data_info) + + +def dump_data(file_path, prefix, data_info): + statistics_file_name = os.path.join(file_path, "jit_dump.csv") + dump_data_path = os.path.join(file_path, "dump_tensor_data") + with os.fdopen(os.open(statistics_file_name, os.O_RDWR | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), + "a") as f: + if JitDump.dump_config.task == "tensor": + output_path = os.path.join(dump_data_path, f'{prefix}.npy') + np.save(output_path, data_info.save_data) + os.chmod(dump_data_path, 0o400) + json.dump([prefix, [], data_info.dtype, data_info.shape, data_info.summary_data, data_info.md5_nume, data_info.l2norm], f) + f.write('\n') + +def dump_stack_info(name_template, dump_file): + stack_str = [] + prefix = name_template.format("stack_info") + for (_, path, line, func, code, _) in inspect.stack()[3:]: + if code: + stack_line = " ".join([ + "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), + " ".join(["\n", code[0].strip() if code else code])])]) + else: + stack_line = " ".join([ + "File", ", ".join([path, " ".join(["line", str(line)]), " ".join(["in", func]), + " ".join(["\n", code])])]) + stack_str.append(stack_line) + + dump_stack_dic = {} + dump_stack_dic[prefix] = stack_str + json_str = json.dumps(dump_stack_dic, indent=4) + + with os.fdopen(os.open(dump_file, os.O_RDWR | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), "w") as f: + f.write(json_str) + +class JitDump(_MindsporeFunctionExecutor): + dump_config = None + dump_step = 0 + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._executor = PyNativeExecutor_.get_instance() + + def __call__(self, *args, **kwargs): + api_register.api_set_ori_func() + out = super().__call__(*args, **kwargs) + dump_jit(args[0], args[1], out, True) + api_register.api_set_hook_func() + return out + + @classmethod + def set_config(cls, value): + cls.dump_config = value + + @classmethod + def set_step(cls, value): + cls.dump_step = value + + def grad(self, obj, grad, weights, grad_position, *args, **kwargs): + api_register.api_set_ori_func() + output = self._executor.grad(grad, obj, weights, grad_position, *args, *(kwargs.values())) + dump_jit(obj, args, output, False) + api_register.api_set_hook_func() + return output \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index bd87effd97..4f4cf4558c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -19,9 +19,16 @@ from pathlib import Path import functools from collections import defaultdict +import mindspore as ms from mindspore.common.tensor import Tensor from mindspore import ops from mindspore import nn +try: + from mindspore.common._pijit_context import PIJitCaptureContext + pijit_label = True +except ImportError: + pijit_label = False + from msprobe.core.data_dump.data_collector import build_data_collector from msprobe.core.data_dump.scope import BaseScope @@ -36,6 +43,7 @@ from msprobe.core.data_dump.data_processor.base import ModuleBackwardInputsOutpu from msprobe.core.common.exceptions import MsprobeException from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell from msprobe.mindspore.cell_processor import CellProcessor +from msprobe.mindspore.dump.jit_dump import JitDump class Service: @@ -239,6 +247,7 @@ class Service: def step(self): self.current_iter += 1 + JitDump.set_step(self.current_iter) self.data_collector.update_iter(self.current_iter) HOOKCell.cell_count = defaultdict(int) CellProcessor.cell_count = {} @@ -304,6 +313,9 @@ class Service: construct_file_path = os.path.join(dump_dir, "construct.json") self.data_collector.update_dump_paths( dump_file_path, stack_file_path, construct_file_path, dump_data_dir, None) + + def empty(self, *args, **kwargs): + pass def register_hook_new(self): logger.info("The {} hook function is successfully mounted to the model.".format(self.config.task)) @@ -312,6 +324,12 @@ class Service: api_register.api_set_hook_func() if self.model: self.register_hooks() + JitDump.set_config(self.config) + ms.common.api._MindsporeFunctionExecutor = JitDump + ms.common.api._PyNativeExecutor.grad = JitDump.grad + if pijit_label: + PIJitCaptureContext.__enter__ = self.empty + PIJitCaptureContext.__exit__ = self.empty if self.config.level == "L0": if not self.model: @@ -332,4 +350,4 @@ class Service: cell.register_backward_pre_hook( self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.START)) cell.register_backward_hook( - self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) + self.cell_processor.node_hook(prefix + Const.BACKWARD, Const.STOP)) \ No newline at end of file diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_jit_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_jit_dump.py new file mode 100644 index 0000000000..b425fe910f --- /dev/null +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/debugger/test_jit_dump.py @@ -0,0 +1,125 @@ +import numpy as np +import unittest +from unittest.mock import patch, MagicMock +import mindspore as ms +import mindspore.common.dtype as mstype +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore.common.tensor import Tensor +from mindspore import jit +from msprobe.mindspore import PrecisionDebugger +from msprobe.core.common_config import CommonConfig, BaseConfig + +def conv(in_channels, out_channels, kernel_size, stride=1, padding=0, pad_mode="valid", has_bias=True): + return nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, +has_bias=has_bias, pad_mode=pad_mode) + +def fc_with_initialize(input_channels, out_channels, has_bias=True): + return nn.Dense(input_channels, out_channels, has_bias=has_bias) + +class DataNormTranspose(nn.Cell): + """Normalize an tensor image with mean and standard deviation. + + Given mean: (R, G, B) and std: (R, G, B), + will normalize each channel of the torch.*Tensor, i.e. + channel = (channel - mean) / std + + Args: + mean (sequence): Sequence of means for R, G, B channels respectively. + std (sequence): Sequence of standard deviations for R, G, B channels + respectively. + """ + + def __init__(self, dataset_name='imagenet'): + super(DataNormTranspose, self).__init__() + # Computed from random subset of ImageNet training images + if dataset_name == 'imagenet': + self.mean = Tensor(np.array([0.485 * 255, 0.456 * 255, 0.406 * 255]).reshape((1, 1, 1, 3)), mstype.float32) + self.std = Tensor(np.array([0.229 * 255, 0.224 * 255, 0.225 * 255]).reshape((1, 1, 1, 3)), mstype.float32) + else: + self.mean = Tensor(np.array([0.4914, 0.4822, 0.4465]).reshape((1, 1, 1, 3)), mstype.float32) + self.std = Tensor(np.array([0.2023, 0.1994, 0.2010]).reshape((1, 1, 1, 3)), mstype.float32) + + def construct(self, x): + x = (x - self.mean) / self.std + x = ops.transpose(x, (0, 3, 1, 2)) + return x + +class AlexNet(nn.Cell): + """ + Alexnet + """ + + def __init__(self, num_classes=10, channel=3, phase='train', include_top=True, dataset_name='imagenet'): + super(AlexNet, self).__init__() + self.data_trans = DataNormTranspose(dataset_name=dataset_name) + self.conv1 = conv(channel, 64, 11, stride=4, pad_mode="same", has_bias=True) + self.conv2 = conv(64, 128, 5, pad_mode="same", has_bias=True) + self.conv3 = conv(128, 192, 3, pad_mode="same", has_bias=True) + self.conv4 = conv(192, 256, 3, pad_mode="same", has_bias=True) + self.conv5 = conv(256, 256, 3, pad_mode="same", has_bias=True) + self.relu = nn.ReLU() + nn.BatchNorm2d + self.max_pool2d = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='valid') + self.include_top = include_top + if self.include_top: + dropout_ratio = 0.65 + if phase == 'test': + dropout_ratio = 1.0 + self.flatten = nn.Flatten() + self.fc1 = fc_with_initialize(6 * 6 * 256, 4096) + self.fc2 = fc_with_initialize(4096, 4096) + self.fc3 = fc_with_initialize(4096, num_classes) + self.dropout = nn.Dropout(p=1 - dropout_ratio) + + @jit + def construct(self, x): + """define network""" + x = self.data_trans(x) + x = self.conv1(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv2(x) + x = self.relu(x) + x = self.max_pool2d(x) + x = self.conv3(x) + x = self.relu(x) + x = self.conv4(x) + x = self.relu(x) + x = self.conv5(x) + x = self.relu(x) + x = self.max_pool2d(x) + if not self.include_top: + return x + x = self.flatten(x) + x = self.fc1(x) + x = self.relu(x) + x = self.dropout(x) + x = self.fc2(x) + x = self.relu(x) + x = self.dropout(x) + x = self.fc3(x) + x = ops.celu(x, 2.0) + return x + +if __name__ == "__main__": + json_config = { + "task": "statistics", + "dump_path": "/absolute_path", + "rank": [], + "step": [], + "level": "L1" + } + + common_config = CommonConfig(json_config) + task_config = BaseConfig(json_config) + mock_parse_json_config = MagicMock() + mock_parse_json_config.return_value = [common_config, task_config] + debugger = PrecisionDebugger() + ms.set_context(mode=ms.PYNATIVE_MODE) + net = AlexNet() + debugger.start() + ops.relu(ms.Tensor(np.random.random([1, 227, 227, 3]).astype(np.float32))) + grad_net = ms.grad(net, None, net.trainable_params()) + output = grad_net(ms.Tensor(np.random.random([1, 227, 227, 3]).astype(np.float32))) + debugger.stop() -- Gitee