diff --git a/debug/accuracy_tools/msprobe/core/common/const.py b/debug/accuracy_tools/msprobe/core/common/const.py index b59536aa5d4c6356737cd148e4ebe9fc9e6f2c79..c1a453a21a6c2f8f30f22812214e2a6e4fc53932 100644 --- a/debug/accuracy_tools/msprobe/core/common/const.py +++ b/debug/accuracy_tools/msprobe/core/common/const.py @@ -81,12 +81,12 @@ class Const: INT_TYPE = [np.int32, np.int64] NPU = 'NPU' DISTRIBUTED = 'Distributed' - + INPLACE_LIST = [ "broadcast", "all_reduce", "reduce", "all_gather", "gather", "scatter", "reduce_scatter", "_reduce_scatter_base", "_all_gather_base", "send", "recv", "irecv", "isend", "all_to_all_single", "all_to_all" ] - + CONVERT = { "int32_to_int64": ["torch.int32", "torch.int64"], } @@ -253,3 +253,17 @@ class OverflowConst: OVERFLOW_DEBUG_MODE_ENABLE = "OVERFLOW_DEBUG_MODE_ENABLE" OVERFLOW_ORIGINAL_MODE = 0 OVERFLOW_DEBUG_MODE = 1 + + +class MsConst: + CELL = "cell" + API = "api" + KERNEL = "kernel" + TOOL_LEVEL_DICT = { + "L0": CELL, + "L1": API, + "L2": KERNEL + } + PYNATIVE_MODE = "pynative" + GRAPH_GE_MODE = "graph_ge" + GRAPH_KBYK_MODE = "graph_kbyk" diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py index 0f0cdd905ebd3296e49abde768c5378121b5bd66..c702dedac0d8ebe448e672c793fe489c70725191 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/debugger_config.py @@ -1,22 +1,19 @@ import os + from msprobe.core.common.utils import Const +from msprobe.core.common.const import MsConst class DebuggerConfig: - convert_map = { - "L0": "cell", - "L1": "api", - "L2": 'kernel' - } - def __init__(self, common_config, task_config): + self.execution_mode = None self.dump_path = common_config.dump_path self.task = common_config.task self.rank = [] if not common_config.rank else common_config.rank self.step = [] if not common_config.step else common_config.step if not common_config.level: common_config.level = "L1" - self.level = DebuggerConfig.convert_map[common_config.level] + self.level = MsConst.TOOL_LEVEL_DICT.get(common_config.level, MsConst.API) self.level_ori = common_config.level self.list = [] if not task_config.list else task_config.list self.scope = [] if not task_config.scope else task_config.scope diff --git a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py index 7082fc13e77beb7a3964c073c9ee5bb05ea2dd56..b4087012429ae2621b12e0d8a24fb3dd6bb0df0c 100644 --- a/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py +++ b/debug/accuracy_tools/msprobe/mindspore/debugger/precision_debugger.py @@ -1,9 +1,12 @@ import os + import mindspore as ms + from msprobe.mindspore.service import Service from msprobe.mindspore.ms_config import parse_json_config from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.task_handler_factory import TaskHandlerFactory +from msprobe.core.common.const import MsConst class PrecisionDebugger: @@ -14,6 +17,8 @@ class PrecisionDebugger: cls._instance = super().__new__(cls) cls._instance.initialized = False cls._instance.config = None + cls.service = None + cls.first_start = False return cls._instance def __init__(self, config_path=None): @@ -24,18 +29,34 @@ class PrecisionDebugger: common_config, task_config = parse_json_config(config_path) self.config = DebuggerConfig(common_config, task_config) self.initialized = True - self.service = Service(self.config) + + @staticmethod + def _get_execution_mode(): + if ms.get_context("mode") == ms.GRAPH_MODE: + if ms.context.get_jit_config().get("jit_level") == "O2" or ms.get_context("jit_level") == "O2": + return MsConst.GRAPH_GE_MODE + else: + return MsConst.GRAPH_KBYK_MODE + else: + return MsConst.PYNATIVE_MODE @classmethod def start(cls): instance = cls._instance if not instance: raise Exception("No instance of PrecisionDebugger found.") - if ms.get_context("mode") == ms.PYNATIVE_MODE and instance.config.level_ori == "L1": + + instance.config.execution_mode = instance._get_execution_mode() + if instance.config.execution_mode == MsConst.PYNATIVE_MODE and instance.config.level == MsConst.API: + if not instance.service: + instance.service = Service(instance.config) instance.service.start() else: - handler = TaskHandlerFactory.create(instance.config) - handler.handle() + if not instance.first_start: + handler = TaskHandlerFactory.create(instance.config) + handler.handle() + + instance.first_start = True @classmethod def stop(cls): diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py index 2c4579b0e75fe1573f387f696c3d9e4efd4945e3..a09905c5b02810a6e2ff79d5bd6a38f30c17e316 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/dump_tool_factory.py @@ -1,38 +1,38 @@ +from msprobe.core.common.const import MsConst from msprobe.mindspore.debugger.debugger_config import DebuggerConfig -from msprobe.mindspore.dump.api_kbk_dump import ApiKbkDump +from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump class DumpToolFactory: tools = { - "cell": { - "kbk": None, - "graph": None, - "pynative": None + MsConst.CELL: { + MsConst.GRAPH_KBYK_MODE: None, + MsConst.GRAPH_GE_MODE: None, + MsConst.PYNATIVE_MODE: None }, - "api": { - "kbk": ApiKbkDump, - "graph": None, - "pynative": None + MsConst.API: { + MsConst.GRAPH_KBYK_MODE: None, + MsConst.GRAPH_GE_MODE: None, + MsConst.PYNATIVE_MODE: None }, - "kernel": { - "kbk": None, - "graph": KernelGraphDump, - "pynative": None + MsConst.KERNEL: { + MsConst.GRAPH_KBYK_MODE: KernelKbykDump, + MsConst.GRAPH_GE_MODE: KernelGraphDump, + MsConst.PYNATIVE_MODE: None } } @staticmethod def create(config: DebuggerConfig): + if config.level == MsConst.CELL: + raise Exception("Cell dump is not supported now.") + if config.level == MsConst.API: + raise Exception("API dump is not supported in graph mode.") tool = DumpToolFactory.tools.get(config.level) if not tool: - raise Exception("valid level is needed.") - if config.level == "api": - tool = tool.get("kbk") - elif config.level == "kernel": - tool = tool.get("graph") - elif config.level == "cell": - raise Exception("Cell dump in not supported now.") + raise Exception("Valid level is needed.") + tool = tool.get(config.execution_mode) if not tool: - raise Exception("Data dump in not supported in this mode.") - return tool(config) \ No newline at end of file + raise Exception(f"Data dump is not supported in {config.execution_mode} when dump level is {config.level}.") + return tool(config) diff --git a/debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py similarity index 40% rename from debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py rename to debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py index 5c7af45d79060c00ce198f19a589d46bacf1f756..b815b9b1345dfe965139d54898eabe3439553571 100644 --- a/debug/accuracy_tools/msprobe/mindspore/dump/api_kbk_dump.py +++ b/debug/accuracy_tools/msprobe/mindspore/dump/kernel_kbyk_dump.py @@ -6,50 +6,59 @@ from msprobe.core.common.log import logger from msprobe.core.common.file_check import FileOpen -class ApiKbkDump: +COMMON_SETTINGS = "common_dump_settings" +E2E_SETTINGS = "e2e_dump_settings" + + +class KernelKbykDump: def __init__(self, config: DebuggerConfig): self.dump_json = dict() - self.dump_json["common_dump_settings"] = dict() - self.dump_json["common_dump_settings"]["dump_mode"] = 0 - self.dump_json["common_dump_settings"]["path"] = "" - self.dump_json["common_dump_settings"]["net_name"] = "Net" - self.dump_json["common_dump_settings"]["iteration"] = "all" - self.dump_json["common_dump_settings"]["saved_data"] = "statistic" - self.dump_json["common_dump_settings"]["input_output"] = 0 - self.dump_json["common_dump_settings"]["kernels"] = [] - self.dump_json["common_dump_settings"]["support_device"] = [0,1,2,3,4,5,6,7] - self.dump_json["e2e_dump_settings"] = dict() - self.dump_json["e2e_dump_settings"]["enable"] = True - self.dump_json["e2e_dump_settings"]["trans_flag"] = True + common_set = dict() + e2e_set = dict() + common_set = dict() + common_set["dump_mode"] = 0 + common_set["path"] = "" + common_set["net_name"] = "Net" + common_set["iteration"] = "all" + common_set["saved_data"] = "statistic" + common_set["input_output"] = 0 + common_set["kernels"] = [] + common_set["support_device"] = [0, 1, 2, 3, 4, 5, 6, 7] + e2e_set = dict() + e2e_set["enable"] = True + e2e_set["trans_flag"] = True if len(config.list) > 0: - self.dump_json["common_dump_settings"]["dump_mode"] = 1 - self.dump_json["common_dump_settings"]["kernels"] = config.list - self.dump_json["common_dump_settings"]["path"] = config.dump_path + common_set["dump_mode"] = 1 + common_set["kernels"] = config.list + common_set["path"] = config.dump_path if len(config.step) > 0: step_str = "" for s in config.step: step_str += (str(s) + '|') - self.dump_json["common_dump_settings"]["iteration"] = step_str[:-1] + common_set["iteration"] = step_str[:-1] if len(config.rank) > 0: - self.dump_json["common_dump_settings"]["support_device"] = config.rank + common_set["support_device"] = config.rank if config.task == "tensor": - self.dump_json["common_dump_settings"]["saved_data"] = "tensor" + common_set["saved_data"] = "tensor" if len(config.data_mode) == 1: if config.data_mode[0] == "input": - self.dump_json["common_dump_settings"]["input_output"] = 1 + common_set["input_output"] = 1 if config.data_mode[0] == "output": - self.dump_json["common_dump_settings"]["input_output"] = 2 + common_set["input_output"] = 2 + + self.dump_json[COMMON_SETTINGS] = common_set + self.dump_json[E2E_SETTINGS] = e2e_set def handle(self): - json_path = self.dump_json["common_dump_settings"]["path"] + json_path = self.dump_json[COMMON_SETTINGS]["path"] make_dump_path_if_not_exists(json_path) - json_path = os.path.join(json_path, "api_kbk_dump.json") + json_path = os.path.join(json_path, "kernel_kbyk_dump.json") with FileOpen(json_path, 'w') as f: json.dump(self.dump_json, f) logger.info(json_path + " has been created.") - os.environ["GRAPH_OP_RUN"] = "1" + os.environ["MINDSPORE_DUMP_CONFIG"] = json_path if "MS_ACL_DUMP_CFG_PATH" in os.environ: del os.environ["MS_ACL_DUMP_CFG_PATH"] diff --git a/debug/accuracy_tools/msprobe/mindspore/service.py b/debug/accuracy_tools/msprobe/mindspore/service.py index 87287aabdfcb2f3066d054e35c0c828dc0aaf365..50776aaf1097339e7c6d98944db7ddf2d2238c5f 100644 --- a/debug/accuracy_tools/msprobe/mindspore/service.py +++ b/debug/accuracy_tools/msprobe/mindspore/service.py @@ -14,6 +14,7 @@ # ============================================================================ import os +import copy from pathlib import Path import functools from collections import defaultdict @@ -33,9 +34,9 @@ from msprobe.mindspore.dump.hook_cell.hook_cell import HOOKCell class Service: def __init__(self, config): self.model = None - self.config = config + self.config = copy.deepcopy(config) self.config.level = self.config.level_ori - self.data_collector = build_data_collector(config) + self.data_collector = build_data_collector(self.config) self.switch = False self.current_iter = 0 self.first_start = True diff --git a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py index 7b7e6fd889c775a4491e824c1f73e6021cb99350..c39552b7fbf3e2f481719b4f08a7fe9c84fdd0a6 100644 --- a/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/mindspore/task_handler_factory.py @@ -1,3 +1,4 @@ +from msprobe.core.common.const import Const from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.dump_tool_factory import DumpToolFactory from msprobe.mindspore.overflow_check.overflow_check_tool_factory import OverflowCheckToolFactory @@ -5,9 +6,9 @@ from msprobe.mindspore.overflow_check.overflow_check_tool_factory import Overflo class TaskHandlerFactory: tasks = { - "tensor": DumpToolFactory, - "statistics": DumpToolFactory, - "overflow_check": OverflowCheckToolFactory + Const.TENSOR: DumpToolFactory, + Const.STATISTICS: DumpToolFactory, + Const.OVERFLOW_CHECK: OverflowCheckToolFactory } @staticmethod diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py index fb88d7bbbf328b0b8a61b11d41808b756881510e..94e93fdf235b85550b0fd716302183e7c7b0a948 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_dump_tool_factory.py @@ -39,12 +39,12 @@ class TestDumpToolFactory(TestCase): config.level = "module" with self.assertRaises(Exception) as context: DumpToolFactory.create(config) - self.assertEqual(str(context.exception), "valid level is needed.") + self.assertEqual(str(context.exception), "Valid level is needed.") config.level = "cell" with self.assertRaises(Exception) as context: DumpToolFactory.create(config) - self.assertEqual(str(context.exception), "Cell dump in not supported now.") + self.assertEqual(str(context.exception), "Cell dump is not supported now.") config.level = "kernel" dumper = DumpToolFactory.create(config) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py similarity index 76% rename from debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py rename to debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py index 7411018ff08507f0ab867b6394aa1c08b5f26469..fc62af1ac46be9e67f3b27165897efa9ca18089b 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_api_kbk_dump.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_kernel_kbyk_dump.py @@ -21,7 +21,7 @@ from unittest.mock import patch from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig -from msprobe.mindspore.dump.api_kbk_dump import ApiKbkDump +from msprobe.mindspore.dump.kernel_kbyk_dump import KernelKbykDump class TestApiKbkDump(TestCase): @@ -38,14 +38,13 @@ class TestApiKbkDump(TestCase): common_config = CommonConfig(json_config) task_config = BaseConfig(json_config) config = DebuggerConfig(common_config, task_config) - dumper = ApiKbkDump(config) + dumper = KernelKbykDump(config) self.assertEqual(dumper.dump_json["common_dump_settings"]["iteration"], "0|2") os.environ["MS_ACL_DUMP_CFG_PATH"] = "path" - with patch("msprobe.mindspore.dump.api_kbk_dump.make_dump_path_if_not_exists"), \ - patch("msprobe.mindspore.dump.api_kbk_dump.FileOpen"), \ - patch("msprobe.mindspore.dump.api_kbk_dump.json.dump"), \ - patch("msprobe.mindspore.dump.api_kbk_dump.logger.info"): + with patch("msprobe.mindspore.dump.kernel_kbyk_dump.make_dump_path_if_not_exists"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.FileOpen"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.json.dump"), \ + patch("msprobe.mindspore.dump.kernel_kbyk_dump.logger.info"): dumper.handle() - self.assertEqual(os.environ.get("GRAPH_OP_RUN"), "1") self.assertEqual(os.environ.get("MS_ACL_DUMP_CFG_PATH"), None) diff --git a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py index 41be7b1db6c7d723aaeec1607f564ac3d772b404..67cacb86c999208e3d52969f57e7b54fdd31f688 100644 --- a/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py +++ b/debug/accuracy_tools/msprobe/test/mindspore_ut/test_task_handler_factory.py @@ -21,6 +21,7 @@ from msprobe.core.common_config import CommonConfig, BaseConfig from msprobe.mindspore.debugger.debugger_config import DebuggerConfig from msprobe.mindspore.dump.kernel_graph_dump import KernelGraphDump from msprobe.mindspore.task_handler_factory import TaskHandlerFactory +from msprobe.core.common.const import MsConst class TestTaskHandlerFactory(TestCase): @@ -43,6 +44,7 @@ class TestTaskHandlerFactory(TestCase): common_config = CommonConfig(json_config) task_config = BaseConfig(json_config) config = DebuggerConfig(common_config, task_config) + config.execution_mode = MsConst.GRAPH_GE_MODE handler = TaskHandlerFactory.create(config) self.assertTrue(isinstance(handler, KernelGraphDump))