diff --git a/accuracy_tools/msprobe/base/__init__.py b/accuracy_tools/msprobe/base/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f57ae818ae0b4236e394ff1e5ff3c88b8401712f --- /dev/null +++ b/accuracy_tools/msprobe/base/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from msprobe.base.cmd import BaseCommand, Command +from msprobe.base.component.manager import BaseComponent, Component, ConsumerComp, ProducerComp, Scheduler +from msprobe.base.config import SIZE_1M, BaseConfig, Dict2Class +from msprobe.base.service.manager import BaseService, Service diff --git a/accuracy_tools/msprobe/base/cmd.py b/accuracy_tools/msprobe/base/cmd.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0022bce7da5a545f93bb2fce7eb29edfddc410 --- /dev/null +++ b/accuracy_tools/msprobe/base/cmd.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025-2025 Huawei Technologies Co., Ltd. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class Command:
    """
    Process-wide registry of command classes, grouped by the parent command they hang under.

    Registration is done with the ``register`` decorator; ``get`` returns the
    ``{name: command_class}`` mapping recorded for a given parent.
    """

    _cmd_map = {}  # Internal storage: {parent_cmd: {name: command_class}}
    _instance = None

    def __new__(cls, *args, **kwargs):
        # Hand back the cached instance when there is one; otherwise create and cache it.
        cached = cls._instance
        if cached:
            return cached
        created = super(Command, cls).__new__(cls)
        cls._instance = created
        return created

    @classmethod
    def register(cls, parent_cmd, name):
        """Return a class decorator that files the decorated class under ``parent_cmd`` as ``name``."""

        def decorator(command_cls):
            children = cls._cmd_map.setdefault(parent_cmd, {})
            children[name] = command_cls
            return command_cls

        return decorator

    @classmethod
    def get(cls, parent_cmd):
        """Return the commands registered under ``parent_cmd``, or an empty dict when there are none."""
        try:
            return cls._cmd_map[parent_cmd]
        except KeyError:
            return {}


class BaseCommand(ABC):
    """
    Abstract base for commands that build a nested argparse parser tree.

    ``self.subcommand_level`` is read and incremented here but not initialised by
    this class; it must be set before ``service_key`` or ``build_parser`` is used.
    """

    def __init__(self):
        self.formatter_class = RawTextHelpFormatter

    @property
    def service_key(self):
        """
        The command-line token at index ``subcommand_level`` of ``argv``.

        Returns None when ``argv`` is too short; raises MsprobeException when the
        level is not a positive integer.
        """
        level = self.subcommand_level
        if not (isinstance(level, int) and level > 0):
            raise MsprobeException(MsgConst.INVALID_ARGU, "Subcommand level must be a positive integer.")
        if len(argv) > level:
            return argv[level]
        return None

    @abstractmethod
    def add_arguments(self, parse):
        pass

    def build_parser(self, parent_parser, parent_cmd_class):
        """
        Recursively attach sub-parsers for every command registered under ``parent_cmd_class``.

        Raises MsprobeException once ``subcommand_level`` exceeds ``MsgConst.MAX_RECURSION_DEPTH``.
        """
        depth_limit = MsgConst.MAX_RECURSION_DEPTH
        if self.subcommand_level > depth_limit:
            raise MsprobeException(
                MsgConst.RISK_ALERT, f"Maximum recursion depth of {MsgConst.MAX_RECURSION_DEPTH} exceeded."
            )
        registered = Command.get(parent_cmd_class)
        if not registered:
            # Leaf command: nothing further to attach.
            return
        self.subcommand_level += 1
        subparsers = parent_parser.add_subparsers(dest=f"L{self.subcommand_level}command")
        for name, cmd_class in registered.items():
            help_text = CmdConst.HELP_TASK_MAP.get(name)
            cmd_parser = subparsers.add_parser(name=name, help=help_text, formatter_class=self.formatter_class)
            cmd_class.add_arguments(cmd_parser)
            self.build_parser(cmd_parser, cmd_class)
class BaseComponent(object):
    """
    Base class for components managed by the Scheduler.

    Methods that need to be implemented:
        activate: Called when service.start() is invoked.
        deactivate: Called when service.stop() is invoked.

    ``do_activate`` / ``do_deactivate`` wrap those hooks so that each runs at most
    once per activation state change.
    """

    def __init__(self, priority=100):
        self.activated = False
        self.priority = priority

    @property
    def is_activated(self):
        """True between a successful do_activate() and the matching do_deactivate()."""
        return self.activated

    def activate(self, *args, **kwargs):
        pass

    def deactivate(self, *args, **kwargs):
        pass

    def do_activate(self):
        """Run ``activate`` once; a no-op when the component is already activated."""
        if self.activated:
            return
        self.activate()
        self.activated = True

    def do_deactivate(self):
        """Run ``deactivate`` once; a no-op when the component is not activated."""
        if not self.activated:
            return
        self.deactivate()
        self.activated = False


class ProducerComp(BaseComponent, ABC):
    """
    A ProducerComp can generate data.
    If the data is passively generated (e.g., when a consumer applies the data), implement "load_data".
    If the data is actively generated (e.g., when an interest event occurs),
    call "publish" to send it to subscribers.
    """

    def __init__(self, priority):
        super().__init__(priority)
        self.output_buffer = deque()  # FIFO of [producer, data, msg_id] packages
        self.subscribers = set()

    @property
    def is_ready(self):
        """True when at least one package is waiting in the output buffer."""
        return len(self.output_buffer) > 0

    @abstractmethod
    def load_data(self):
        pass

    def publish(self, data, msg_id=0):
        """
        Wrap the data and pack it into the output buffer.

        The package is ``[self, data, msg_id]``; the producer is also queued on the
        Scheduler singleton (the loop itself is not started here).
        """
        self.output_buffer.append([self, data, msg_id])
        Scheduler().enqueue([self])

    def on_subscribe(self, comp):
        """Record ``comp`` as a subscriber. Raises MsprobeException if it is not a ConsumerComp."""
        if not isinstance(comp, ConsumerComp):
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "Only ConsumerComp can subscribe to ProducerComp.")
        self.subscribers.add(comp)

    def retrieve(self):
        """Pop and return the oldest buffered package, or None when the buffer is empty."""
        if self.output_buffer:
            return self.output_buffer.popleft()
        return None

    def do_load_data(self):
        """Call ``load_data`` and publish its result, unless data is already buffered or the result is falsy."""
        if self.output_buffer:
            return
        data = self.load_data()
        if data:
            self.publish(data)

    def get_subscribers(self):
        return self.subscribers


class ConsumerComp(BaseComponent, ABC):
    """
    A ConsumerComp can consume data.
    Call "subscribe" to subscribe data from a ProducerComp.
    Implement "consume" to process data.
    """

    def __init__(self, priority):
        super().__init__(priority)
        # {producer: latest package from it, or None while still waiting}
        self.dependencies = {}

    def subscribe(self, comp):
        """
        Subscribe this consumer to the producer ``comp``.

        Raises:
            MsprobeException: ``comp`` is not a ProducerComp, this consumer is
                already activated, or the subscription would create a dependency cycle.
        """
        if not isinstance(comp, ProducerComp):
            # The original message had producer and consumer swapped.
            raise MsprobeException(MsgConst.INVALID_DATA_TYPE, "ConsumerComp can only subscribe to a ProducerComp.")
        if self.is_activated:
            raise MsprobeException(
                MsgConst.INVALID_DATA_TYPE, f"Component {comp} must be subscribed before activation."
            )
        if self.is_cycle(comp):
            raise MsprobeException(MsgConst.RISK_ALERT, "Cycle dependency detected! Subscription denied.")
        comp.on_subscribe(self)
        if comp not in self.dependencies:
            self.dependencies[comp] = None

    @abstractmethod
    def consume(self, packages):
        pass

    def is_cycle(self, comp, visited=None, stack=None):
        """
        Return True when making this consumer depend on ``comp`` would close a cycle.

        Performs a depth-first walk from ``comp`` along existing ``dependencies``
        edges. ``visited`` and ``stack`` are recursion state; callers normally omit them.
        """
        if visited is None:
            visited = set()
        if stack is None:
            # The prospective edge is self -> comp, so `self` must already be on the
            # path: reaching it again (including comp is self) means a cycle. Without
            # this seed only cycles that already exist in the graph could be found.
            stack = {self}
        if comp in stack:
            return True
        if comp in visited:
            return False
        visited.add(comp)
        stack.add(comp)
        if isinstance(comp, ConsumerComp):
            for producer in comp.dependencies:
                if self.is_cycle(producer, visited, stack):
                    return True
        stack.remove(comp)
        return False

    def on_receive(self, package):
        """Store ``package`` as the pending data of its producer (``package[0]``)."""
        try:
            self.dependencies[package[0]] = package
        except Exception as e:
            raise MsprobeException(
                MsgConst.PARSING_FAILED,
                "The first element in the data (self.output_buffer) published by the producer must be itself.",
            ) from e

    def get_empty_dependencies(self):
        """Return the producers for which no package has been received yet."""
        return [producer for producer, package in self.dependencies.items() if package is None]

    def do_consume(self):
        """
        Encapsulate the data in "dependencies" and invoke it using "consume".

        Does nothing while any dependency is still waiting for data. All pending
        packages are cleared before ``consume`` is called.
        """
        if self.get_empty_dependencies():
            return
        packages = []
        for key in self.dependencies:
            packages.append(self.dependencies[key])
            self.dependencies[key] = None
        self.consume(packages)


class Component:
    """Name-to-class registry for component types."""

    _component_type_map = {}

    @classmethod
    def register(cls, name):
        return register(name, cls._component_type_map)

    @classmethod
    def get(cls, name):
        """Return the component type registered under ``name``, or None."""
        return cls._component_type_map.get(name)


class Scheduler:
    """
    Singleton that reference-counts components and moves packages from producers
    to their subscribed consumers through a FIFO work queue.
    """

    _instance = None
    _lock = RLock()

    def __new__(cls):
        with cls._lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # __init__ runs on every Scheduler() call; only the first one sets up state.
        if self._initialized:
            return
        self.comp_ref = {}  # {component: reference count}
        self.queue = deque()
        self.enqueued = set()  # mirrors `queue` for O(1) duplicate checks
        self.in_loop = False
        self._initialized = True

    def add(self, components):
        """Take a reference on each component, activating and queueing new ones, then run the loop."""
        for comp in components:
            if comp in self.comp_ref:
                self.comp_ref[comp] += 1
            else:
                self.comp_ref[comp] = 1
                comp.do_activate()
                self.enqueue([comp])
        self.run_loop()

    def remove(self, components):
        """Drop one reference per component; deactivate and forget it when the count reaches zero."""
        for comp in components:
            if comp not in self.comp_ref:
                continue
            if self.comp_ref[comp] > 1:
                self.comp_ref[comp] -= 1
            else:
                comp.do_deactivate()
                del self.comp_ref[comp]

    def enqueue(self, comps):
        """Append each component to the work queue unless it is already queued."""
        for comp in comps:
            if comp not in self.enqueued:
                self.queue.append(comp)
                self.enqueued.add(comp)

    def run_loop(self):
        """Drain the work queue. Calls made while a drain is already in progress return immediately."""
        if self.in_loop:
            return
        self.in_loop = True
        try:
            while self.queue:
                comp = self.queue.popleft()
                self.enqueued.remove(comp)
                # A component may be both; it is then scheduled in both roles.
                if isinstance(comp, ConsumerComp):
                    self._schedule_consumer(comp)
                if isinstance(comp, ProducerComp):
                    self._schedule_producer(comp)
        finally:
            self.in_loop = False

    def _schedule_producer(self, comp: ProducerComp):
        """Deliver the producer's oldest package to every subscriber and queue those subscribers."""
        if not comp.is_ready:
            return
        package = comp.retrieve()
        if not package:
            return
        subscribers = comp.get_subscribers()
        if not subscribers:
            return
        for subscriber in subscribers:
            subscriber.on_receive(package)
            self.enqueue([subscriber])

    def _schedule_consumer(self, comp: ConsumerComp):
        """Let the consumer consume when all inputs are present; otherwise ask missing producers for data."""
        if not comp.dependencies:
            # A consumer with no subscriptions has nothing to consume. Without this
            # guard it would "consume" an empty list and be re-enqueued forever.
            return
        dependencies = comp.get_empty_dependencies()
        if not dependencies:
            comp.do_consume()
            # Queue again so the next round of data is requested from the producers.
            self.enqueue([comp])
            return
        for dependency in dependencies:
            dependency.do_load_data()
            if dependency.is_ready:
                self.enqueue([dependency])
SIZE_1M = 1_048_576  # 1024 * 1024


class BaseConfig(ABC):
    """
    Abstract configuration loaded from a JSON file.

    Any call to ``check_config`` is transparently preceded by the common field
    validation and by extraction of the task-specific section (see ``__getattribute__``).
    """

    def __init__(self, config_path, task="", framework="", step: list = None, level: list = None):
        self.config_path = config_path
        self.config = load_json(self.config_path)
        self.task_config = {}
        self.task = task
        self.framework = framework
        self.step = step
        self.level = level

    def __getattribute__(self, name):
        attr = object.__getattribute__(self, name)
        # Only a callable ``check_config`` is intercepted; everything else passes through.
        if name != "check_config" or not callable(attr):
            return attr

        def wrapper(*args, **kwargs):
            self._common_check()
            self._get_task_dict()
            return attr(*args, **kwargs)

        return wrapper

    @abstractmethod
    def check_config(self):
        pass

    def _get_task_dict(self):
        """Load the section named by the task field into ``task_config``; raise if it is missing or empty."""
        section = self.config.get(self.config.get(CfgConst.TASK))
        self.task_config = section
        if section:
            return
        raise MsprobeException(
            MsgConst.REQUIRED_ARGU_MISSING, f'Missing dictionary for key "{self.config.get(CfgConst.TASK)}".'
        )

    def _common_check(self):
        """Validate the common top-level fields, preferring constructor overrides over file values."""
        logger.info("Validating configuration file parameters.")
        cfg = self.config
        apply = self._update_config
        apply(cfg, CfgConst.TASK, valid_task, self.task or cfg.get(CfgConst.TASK, None))
        apply(cfg, CfgConst.FRAMEWORK, valid_framework, self.framework or cfg.get(CfgConst.FRAMEWORK, None))
        apply(cfg, CfgConst.STEP, valid_step_or_rank, self.step or cfg.get(CfgConst.STEP, []))
        apply(cfg, CfgConst.RANK, valid_step_or_rank, cfg.get(CfgConst.RANK, []))
        apply(cfg, CfgConst.LEVEL, valid_level, self.level or cfg.get(CfgConst.LEVEL, [CfgConst.LEVEL_API]))
        apply(cfg, CfgConst.LOG_LEVEL, valid_log_level, cfg.get(CfgConst.LOG_LEVEL, "info"))
        apply(cfg, CfgConst.SEED, valid_seed, cfg.get(CfgConst.SEED, None))
        apply(cfg, CfgConst.BUFFER_SIZE, valid_buffer_size, cfg.get(CfgConst.BUFFER_SIZE, SIZE_1M))

    def _update_config(self, dic: dict, key: str, check_fun, value: str):
        """Store ``check_fun(value)`` under ``key`` in ``dic``."""
        validated = check_fun(value)
        dic[key] = validated


class Dict2Class:
    """
    Expose a (possibly nested) dict as object attributes.

    The section whose key equals the value of the task field is removed from
    ``data`` (the caller's dict is modified) and its entries are set directly on
    this object; remaining nested dicts become nested Dict2Class objects.
    """

    def __init__(self, data: dict, depth: int = 0):
        if depth > MsgConst.MAX_RECURSION_DEPTH:
            raise MsprobeException(
                MsgConst.RISK_ALERT, f"Maximum recursion depth of {MsgConst.MAX_RECURSION_DEPTH} exceeded."
            )
        task_name = data.get(CfgConst.TASK)
        if task_name in data:
            task_section = data.pop(task_name)
            for key, value in task_section.items():
                if key == "input" and len(value) == 2:
                    # A two-element "input" is additionally split into two attributes.
                    self.input_shape = value[0]
                    self.input_path = value[1]
                setattr(self, key, value)
        for key, value in data.items():
            converted = Dict2Class(value, depth + 1) if isinstance(value, dict) else value
            setattr(self, key, converted)

    @classmethod
    def __getattr__(cls, item):
        # Reached only when normal attribute lookup fails.
        raise MsprobeException(MsgConst.ATTRIBUTE_ERROR, f"{cls.__name__} object has no attribute {item}.")
class Service:
    """
    Facade that picks a registered service class and forwards attribute access to an instance of it.

    The service name comes from the ``serv_name`` keyword, unless ``cmd_namespace``
    carries a config path: then the task (from the ``task`` keyword or the config
    file) is mapped to a service name through ``_TASK_SERVICE_MAP``.
    """

    _services_map = {}

    def __init__(self, *args, **kwargs):
        cmd_namespace = kwargs.get("cmd_namespace")
        serv_name = kwargs.get("serv_name")
        if hasattr(cmd_namespace, CfgConst.CONFIG_PATH):
            if not kwargs.get(CfgConst.TASK):
                config = load_json(cmd_namespace.config_path)
                task = valid_task(config.get(CfgConst.TASK))
            else:
                task = valid_task(kwargs.get(CfgConst.TASK))
            serv_name = _TASK_SERVICE_MAP.get(task)
        self.service_class = self.get(serv_name)
        self.service_instance = self.service_class(*args, **kwargs)

    def __getattr__(self, name):
        # __getattr__ runs only after normal lookup fails. Reading
        # self.service_instance here while it is unset (failed __init__, object
        # built via __new__/copy/unpickle) would re-enter __getattr__ without end,
        # so look it up in the instance dict and fail with AttributeError instead.
        try:
            target = self.__dict__["service_instance"]
        except KeyError:
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") from None
        return getattr(target, name)

    @classmethod
    def register(cls, name):
        return register(name, cls._services_map)

    @classmethod
    def get(cls, name):
        """Return the service class registered under ``name``, or None."""
        return cls._services_map.get(name)


class BaseService(ABC):
    """
    Abstract service that owns a list of components and drives them through the Scheduler.

    Subclasses implement ``construct`` and may override the ``init_start``,
    ``ignore_actuator`` and ``finalize_start`` hooks and the ``is_skip`` property.
    """

    def __init__(self):
        self.comps = []
        self.current_step = 0
        self.scheduler = Scheduler()

    @property
    def is_skip(self):
        """When True, start(), step() and stop() return without doing anything."""
        return False

    @property
    def current_rank(self):
        """The current rank as an int, or None when it cannot be obtained or converted."""
        try:
            return int(get_current_rank())
        except Exception:
            # Best effort: any failure is reported as "no rank".
            return None

    @abstractmethod
    def construct(self):
        pass

    def start(self, *args, **kwargs):
        """
        Service startup workflow:
        1. Configure services (init_start).
        2. Build components (construct).
        3. Collect every BaseComponent attribute not yet tracked, passing each to ignore_actuator.
        4. Sort components by priority and hand them to the scheduler.
        5. Post-processing (finalize_start).
        """
        if self.is_skip:
            return
        self.init_start()
        self.construct()
        for attr in self.__dict__.values():
            if isinstance(attr, BaseComponent) and (attr not in self.comps):
                self.comps.append(attr)
                self.ignore_actuator(attr)
        self.comps.sort(key=lambda x: x.priority)
        self.scheduler.add(self.comps)
        self.finalize_start()

    def init_start(self):
        pass

    def ignore_actuator(self, attr):
        pass

    def finalize_start(self):
        pass

    def step(self, *args, **kwargs):
        """Advance ``current_step`` by one unless the service is skipped."""
        if self.is_skip:
            return
        self.current_step += 1

    def stop(self, *args, **kwargs):
        """Release all tracked components from the scheduler and forget them."""
        if self.is_skip:
            return
        self.scheduler.remove(self.comps)
        self.comps.clear()