From 3e03fa8f2236a150fa050066ae9196edeb34fb2a Mon Sep 17 00:00:00 2001 From: stby Date: Mon, 8 Apr 2024 19:19:12 +0800 Subject: [PATCH 1/3] init --- profiler/cluster_custom_analyse/README.md | 0 profiler/cluster_custom_analyse/__init__.py | 14 +++++ .../analysis/__init__.py | 14 +++++ .../analysis/acl_api_sum.py | 28 ++++++++++ .../analysis/base_analysis.py | 27 +++++++++ .../cluster_custom_analysis.py | 55 +++++++++++++++++++ .../common_func/__init__.py | 14 +++++ .../common_func/analysis_loader.py | 37 +++++++++++++ .../common_func/constant.py | 20 +++++++ 9 files changed, 209 insertions(+) create mode 100644 profiler/cluster_custom_analyse/README.md create mode 100644 profiler/cluster_custom_analyse/__init__.py create mode 100644 profiler/cluster_custom_analyse/analysis/__init__.py create mode 100644 profiler/cluster_custom_analyse/analysis/acl_api_sum.py create mode 100644 profiler/cluster_custom_analyse/analysis/base_analysis.py create mode 100644 profiler/cluster_custom_analyse/cluster_custom_analysis.py create mode 100644 profiler/cluster_custom_analyse/common_func/__init__.py create mode 100644 profiler/cluster_custom_analyse/common_func/analysis_loader.py create mode 100644 profiler/cluster_custom_analyse/common_func/constant.py diff --git a/profiler/cluster_custom_analyse/README.md b/profiler/cluster_custom_analyse/README.md new file mode 100644 index 000000000..e69de29bb diff --git a/profiler/cluster_custom_analyse/__init__.py b/profiler/cluster_custom_analyse/__init__.py new file mode 100644 index 000000000..a0e9f748f --- /dev/null +++ b/profiler/cluster_custom_analyse/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/analysis/__init__.py b/profiler/cluster_custom_analyse/analysis/__init__.py new file mode 100644 index 000000000..a0e9f748f --- /dev/null +++ b/profiler/cluster_custom_analyse/analysis/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/analysis/acl_api_sum.py b/profiler/cluster_custom_analyse/analysis/acl_api_sum.py new file mode 100644 index 000000000..071b99d86 --- /dev/null +++ b/profiler/cluster_custom_analyse/analysis/acl_api_sum.py @@ -0,0 +1,28 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from analysis.base_analysis import BaseAnalysis + +class AclApiSum(BaseAnalysis): + + @staticmethod + def _mapper_func(): + pass + + def mapper_func(self, context): + pass + + def reducer_func(self, mapper_res): + pass \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/analysis/base_analysis.py b/profiler/cluster_custom_analyse/analysis/base_analysis.py new file mode 100644 index 000000000..60ea7e712 --- /dev/null +++ b/profiler/cluster_custom_analyse/analysis/base_analysis.py @@ -0,0 +1,27 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +class BaseAnalysis: + def __init__(self, params): + self._params = params + self._output_dir = None + self._output_files = {} + self._analysis_dict = {} + + def __enter__(self): + return self + + def run(self, context): + self._analysis_dict = {} \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/cluster_custom_analysis.py b/profiler/cluster_custom_analyse/cluster_custom_analysis.py new file mode 100644 index 000000000..65467bc95 --- /dev/null +++ b/profiler/cluster_custom_analyse/cluster_custom_analysis.py @@ -0,0 +1,55 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import os +import sys + +from common_func import analysis_loader + +def print_analyses_list(): + pass + +def run_custom_analysis(analysis_name, analysis_args): + analysis_class = analysis_loader.get_class_from_name(analysis_name) + if not analysis_class: + print("[ERROR] unknown analysis.") + return None + + args_parsed = get_analysis_args(analysis_class, analysis_args) + #TODO try + with Context.create_context(args_parsed.mode) as context: + with analysis_class(args_parsed) as analysis: + analysis.run(context) + return analysis + +def main(): + parser = argparse.ArgumentParser(description="cluster custome analysis module") + parser.add_argument('--analysis-help', action='store_true', help='Print available analyses') + + args_parsed, args_remained = parser.parse_known_args() + + if args_parsed.analysis_help: + print_analyses_list() + return + + if not args_remained: + print("[ERROR] No analysis specified.") + return + + analysis = run_custom_analysis(args_remained[0], args_remained[1:]) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/common_func/__init__.py b/profiler/cluster_custom_analyse/common_func/__init__.py new file mode 100644 index 000000000..a0e9f748f --- /dev/null +++ b/profiler/cluster_custom_analyse/common_func/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/common_func/analysis_loader.py b/profiler/cluster_custom_analyse/common_func/analysis_loader.py new file mode 100644 index 000000000..7676d7587 --- /dev/null +++ b/profiler/cluster_custom_analyse/common_func/analysis_loader.py @@ -0,0 +1,37 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import importlib +import inspect +import sys + +from common_func.constant import Constant +from analysis.base_analysis import BaseAnalysis + +def is_analysis_class(obj): + return inspect.isclass(obj) and issubclass(obj, BaseAnalysis) + +def get_class_from_name(analysis_name : str): + sys.path.append(Constant.ANALYSIS_PATH) + analysis_path = f"analysis.{analysis_name}" + module = None + try: + module = importlib.import_module(analysis_path) + except Exception as e: + print(f"[ERROR] {analysis_path} not find:{e}") + specific_analysis = inspect.getmembers(module, is_analysis_class) + if not specific_analysis: + print(f"[ERROR] {analysis_name} not found.") + return specific_analysis[1] \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/common_func/constant.py b/profiler/cluster_custom_analyse/common_func/constant.py new file mode 100644 index 000000000..ad6bfc915 --- /dev/null +++ b/profiler/cluster_custom_analyse/common_func/constant.py @@ -0,0 +1,20 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os + +class Constant(object): + CLUSTER_CUSTOM_ANALYSE_PATH = os.path.abspath(os.path.dirname(__file__)) + ANALYSIS_PATH = os.path.join(CLUSTER_CUSTOM_ANALYSE_PATH, 'analysis') \ No newline at end of file -- Gitee From 6bf667719ad664b0a236c60ca72111e983436b56 Mon Sep 17 00:00:00 2001 From: stby Date: Tue, 23 Apr 2024 11:36:38 +0800 Subject: [PATCH 2/3] put in cluster_analyse --- .../analysis/analysis_facade.py | 17 +++- .../cluster_analyse/analysis/base_analysis.py | 33 +++++++ .../analysis/cann_api_sum.py} | 31 +++++- profiler/cluster_analyse/cluster_analysis.py | 74 +++++++++++---- .../common_func/analysis_loader.py | 7 +- .../cluster_analyse/common_func/constant.py | 8 ++ .../cluster_analyse/common_func/context.py | 94 +++++++++++++++++++ profiler/cluster_custom_analyse/README.md | 0 profiler/cluster_custom_analyse/__init__.py | 14 --- .../analysis/__init__.py | 14 --- .../analysis/base_analysis.py | 27 ------ .../cluster_custom_analysis.py | 55 ----------- .../common_func/__init__.py | 14 --- .../common_func/constant.py | 20 ---- 14 files changed, 237 insertions(+), 171 deletions(-) rename profiler/{cluster_custom_analyse/analysis/acl_api_sum.py => cluster_analyse/analysis/cann_api_sum.py} (55%) rename profiler/{cluster_custom_analyse => cluster_analyse}/common_func/analysis_loader.py (86%) create mode 100644 profiler/cluster_analyse/common_func/context.py delete mode 100644 profiler/cluster_custom_analyse/README.md delete mode 100644 profiler/cluster_custom_analyse/__init__.py delete mode 100644 profiler/cluster_custom_analyse/analysis/__init__.py delete mode 100644 profiler/cluster_custom_analyse/analysis/base_analysis.py delete mode 100644 profiler/cluster_custom_analyse/cluster_custom_analysis.py delete mode 100644 profiler/cluster_custom_analyse/common_func/__init__.py delete mode 100644 profiler/cluster_custom_analyse/common_func/constant.py diff --git a/profiler/cluster_analyse/analysis/analysis_facade.py b/profiler/cluster_analyse/analysis/analysis_facade.py index 06be6002e..2e3e81cc5 100644 --- a/profiler/cluster_analyse/analysis/analysis_facade.py +++ b/profiler/cluster_analyse/analysis/analysis_facade.py @@ -18,10 +18,11 @@ from multiprocessing import Process from analysis.communication_analysis import CommunicationAnalysis from analysis.comm_matrix_analysis import CommMatrixAnalysis from analysis.step_trace_time_analysis import StepTraceTimeAnalysis - +from common_func.context import Context +from common_func.constant import Constant class AnalysisFacade: - analysis_module = {CommunicationAnalysis, StepTraceTimeAnalysis, CommMatrixAnalysis} + default_module = {CommunicationAnalysis, StepTraceTimeAnalysis, CommMatrixAnalysis} def __init__(self, params: dict): self.params = params @@ -29,10 +30,20 @@ class AnalysisFacade: def cluster_analyze(self): # 多个profiler用多进程处理 process_list = [] - for analysis in self.analysis_module: + for analysis in self.default_module: process = Process(target=analysis(self.params).run) process.start() process_list.append(process) for process in process_list: process.join() + + def recipe_analyze(self): + print("recipe analysis launched.") + try: + with Context.create_context(self.params.get(Constant.PARALLEL_MODE)) as context: + with self.params.get(Constant.RECIPE_CLASS)(self.params) as recipe: + recipe.run(context) + return recipe + except Exception as e: + print("[ERROR] recipe analysis launched failed.") \ No newline at end of file diff --git a/profiler/cluster_analyse/analysis/base_analysis.py b/profiler/cluster_analyse/analysis/base_analysis.py index cc803813d..8af86473a 100644 --- a/profiler/cluster_analyse/analysis/base_analysis.py +++ b/profiler/cluster_analyse/analysis/base_analysis.py @@ -2,6 +2,7 @@ from abc import abstractmethod from common_func.constant import Constant from utils.data_transfer_adapter import DataTransferAdapter from common_func.file_manager import FileManager +import os class BaseAnalysis: @@ -75,3 +76,35 @@ class BaseAnalysis: for rank_tup, group_dict in self.comm_ops_struct.items(): for step_id, communication_ops in group_dict.items(): self.compute_total_info(communication_ops) + + +class BaseRecipeAnalysis: + def __init__(self, params): + self._params = params + self._collection_dir = params.get(Constant.COLLECTION_PATH) + self._data_map = params.get(Constant.DATA_MAP) + self._recipe_name = params.get(Constant.RECIPE_NAME) + self._mode = params.get(Constant.PARALLEL_MODE) + self._analysis_dict = {} + + def __enter__(self): + return self + + def run(self, context): + self._analysis_dict = { + "Mode": self.get_mode(), + "RecipeName": self.get_recipe_name() + } + + def _get_rank_db(self): + db_paths = [os.path.join(rank_path, + Constant.CLUSTER_ANALYSIS_OUTPUT, + f"ascend_pytorch_profiler_{rank_id}.db") + for rank_id, rank_path in self._data_map.items()] + return db_paths + + def get_mode(self): + return self._mode + + def get_recipe_name(self): + return self._recipe_name \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/analysis/acl_api_sum.py b/profiler/cluster_analyse/analysis/cann_api_sum.py similarity index 55% rename from profiler/cluster_custom_analyse/analysis/acl_api_sum.py rename to profiler/cluster_analyse/analysis/cann_api_sum.py index 071b99d86..6c30fcb28 100644 --- a/profiler/cluster_custom_analyse/analysis/acl_api_sum.py +++ b/profiler/cluster_analyse/analysis/cann_api_sum.py @@ -13,16 +13,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -from analysis.base_analysis import BaseAnalysis +from analysis.base_analysis import BaseRecipeAnalysis + +class CannApiSum(BaseRecipeAnalysis): + def __init__(self, params): + super().__init__(params) + print("CannApiSum init.") -class AclApiSum(BaseAnalysis): - @staticmethod def _mapper_func(): pass def mapper_func(self, context): - pass + return context.map( + self._mapper_func, + self._get_rank_db(), + xx + ) def reducer_func(self, mapper_res): + pass + + def run(self, context): + super().run(context) + + mapper_res = self.mapper_func(context) + self.reducer_func(mapper_res) + + self.save_notebook() + self.save_analysis_file() + + + def save_notebook(self): + pass + + def save_analysis_file(self): pass \ No newline at end of file diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py index 244546221..9caa3ff32 100644 --- a/profiler/cluster_analyse/cluster_analysis.py +++ b/profiler/cluster_analyse/cluster_analysis.py @@ -22,9 +22,29 @@ from communication_group.communication_group_generator import CommunicationGroup from common_func.constant import Constant from common_func.file_manager import FileManager from common_func.path_manager import PathManager +from common_func import analysis_loader from analysis.analysis_facade import AnalysisFacade +def get_analysis_args(analysis_class, analysis_args): + parser = argparse.ArgumentParser(description="custom analysis args") + parser.add_argument("--parallel_mode", type=str, help="context mode", default="concurrent") + return parser.parse_args(analysis_args) + +def parse_recipe_params(analysis_name, analysis_args): + analysis_class = analysis_loader.get_class_from_name(analysis_name) + if not analysis_class: + print("[ERROR] unknown analysis.") + return None + + args_parsed = get_analysis_args(analysis_class, analysis_args) + recipe_params = { + Constant.RECIPE_NAME: analysis_class[0], + Constant.RECIPE_CLASS: analysis_class[1], + Constant.PARALLEL_MODE: args_parsed.parallel_mode + } + return recipe_params + class Interface: ASCEND_PT = "ascend_pt" ASCEND_MS = "ascend_ms" @@ -37,6 +57,9 @@ class Interface: self.collective_group_dict = {} self.communication_ops = [] self.matrix_ops = [] + self.recipe_name = params.get(Constant.RECIPE_NAME) + self.recipe_class = params.get(Constant.RECIPE_CLASS) + self.recipe_parallel_mode = params.get(Constant.PARALLEL_MODE) def allocate_prof_data(self): ascend_pt_dirs = [] @@ -61,31 +84,48 @@ class Interface: PathManager.check_path_owner_consistent(self.collection_path) FileManager.create_output_dir(self.collection_path) data_map, data_type = self.allocate_prof_data() - if not data_map: - print("[WARNING] Can not get rank info or profiling data.") - return - if data_type == Constant.INVALID: - print("[ERROR] The current folder contains both DB and other files. Please check.") - return - params = { - Constant.COLLECTION_PATH: self.collection_path, - Constant.DATA_MAP: data_map, - Constant.ANALYSIS_MODE: self.analysis_mode, - Constant.DATA_TYPE: data_type - } - comm_data_dict = CommunicationGroupGenerator(params).generate() - params[Constant.COMM_DATA_DICT] = comm_data_dict - AnalysisFacade(params).cluster_analyze() + # if not data_map: + # print("[WARNING] Can not get rank info or profiling data.") + # return + # if data_type == Constant.INVALID: + # print("[ERROR] The current folder contains both DB and other files. Please check.") + # return + if self.analysis_mode == "recipe": + params = { + Constant.COLLECTION_PATH: self.collection_path, + Constant.DATA_MAP: data_map, + Constant.RECIPE_NAME: self.recipe_name, + Constant.RECIPE_CLASS: self.recipe_class, + Constant.PARALLEL_MODE: self.recipe_parallel_mode + } + AnalysisFacade(params).recipe_analyze() + else: + params = { + Constant.COLLECTION_PATH: self.collection_path, + Constant.DATA_MAP: data_map, + Constant.ANALYSIS_MODE: self.analysis_mode, + Constant.DATA_TYPE: data_type + } + comm_data_dict = CommunicationGroupGenerator(params).generate() + params[Constant.COMM_DATA_DICT] = comm_data_dict + AnalysisFacade(params).cluster_analyze() if __name__ == "__main__": parser = argparse.ArgumentParser(description="cluster analysis module") parser.add_argument('-d', '--collection_path', type=str, required=True, help="profiling data path") - parser.add_argument('-m', '--mode', choices=['all', 'communication_time', 'communication_matrix'], + parser.add_argument('-m', '--mode', choices=['all', 'communication_time', 'communication_matrix', 'recipe'], default='all', help="different analysis mode") - args_parsed = parser.parse_args() + #TODO 扩充mode的内容,把当前确定的mode类型改成不限定,all替换成default,以函数的形式实现 + args_parsed, args_remained = parser.parse_known_args() parameter = { Constant.COLLECTION_PATH: args_parsed.collection_path, Constant.ANALYSIS_MODE: args_parsed.mode } + if args_parsed.mode == 'recipe': + if not args_remained: + print("[ERROR] No recipe analysis specified.") + else: + parameter.update(parse_recipe_params(args_remained[0], args_remained[1:])) + # print(parameter) Interface(parameter).run() diff --git a/profiler/cluster_custom_analyse/common_func/analysis_loader.py b/profiler/cluster_analyse/common_func/analysis_loader.py similarity index 86% rename from profiler/cluster_custom_analyse/common_func/analysis_loader.py rename to profiler/cluster_analyse/common_func/analysis_loader.py index 7676d7587..374b87f59 100644 --- a/profiler/cluster_custom_analyse/common_func/analysis_loader.py +++ b/profiler/cluster_analyse/common_func/analysis_loader.py @@ -18,10 +18,10 @@ import inspect import sys from common_func.constant import Constant -from analysis.base_analysis import BaseAnalysis +from analysis.base_analysis import BaseRecipeAnalysis def is_analysis_class(obj): - return inspect.isclass(obj) and issubclass(obj, BaseAnalysis) + return inspect.isclass(obj) and issubclass(obj, BaseRecipeAnalysis) and obj != BaseRecipeAnalysis def get_class_from_name(analysis_name : str): sys.path.append(Constant.ANALYSIS_PATH) @@ -31,7 +31,8 @@ def get_class_from_name(analysis_name : str): module = importlib.import_module(analysis_path) except Exception as e: print(f"[ERROR] {analysis_path} not find:{e}") + specific_analysis = inspect.getmembers(module, is_analysis_class) if not specific_analysis: print(f"[ERROR] {analysis_name} not found.") - return specific_analysis[1] \ No newline at end of file + return specific_analysis[0] \ No newline at end of file diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index 3b4126de7..c486bec7a 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -103,3 +103,11 @@ class Constant(object): CONFIG = "config" EXPER_CONFIG = "experimental_config" EXPORT_TYPE = "_export_type" + + # recipe config + RECIPE_NAME = "recipe_name" + RECIPE_CLASS = "recipe_class" + PARALLEL_MODE = "parallel_mode" + + SINGLE_MODE = "single" + CONCURRENT_MODE = "concurrent" \ No newline at end of file diff --git a/profiler/cluster_analyse/common_func/context.py b/profiler/cluster_analyse/common_func/context.py new file mode 100644 index 000000000..18e906846 --- /dev/null +++ b/profiler/cluster_analyse/common_func/context.py @@ -0,0 +1,94 @@ +# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from functools import partial +from concurrent import futures +from common_func.constant import Constant + + +class Context(object): + """abstract base class""" + + ctx_map = None + + @classmethod + def create_context(cls, mode=Constant.CONCURRENT_MODE): + if cls.ctx_map is None: + keys = [Constant.SINGLE_MODE, Constant.CONCURRENT_MODE] + values = [SingleContext, ConcurrentContext] + cls.ctx_map = dict(zip(keys, values)) + + if mode not in cls.ctx_map: + raise NotImplementedError("mode must be in {}".format(keys)) + + return cls.ctx_map[mode]() + + def __init__(self): + print("[INFO] context {} initialized.".format(self._mode)) + + def __enter__(self): + return self + + def __exit__(self): + pass + + def launch(self, func, *args, **kwargs): + raise NotImplementedError + + def map(self, func, *iterables, **kwargs): + raise NotImplementedError + + +class SingleContext(Context): + + def __init__(self): + self._mode = Constant.SINGLE_MODE + super().__init__() + def launch(self, func, *args, **kwargs): + return func(*args, **kwargs) + + def map(self, func, *iterables, **kwargs): + partial_func = partial(func, **kwargs) + return list(map(partial(self.launch, partial_func), *iterables)) + + +class ConcurrentContext(Context): + + def __init__(self, executor=None): + self._mode = Constant.CONCURRENT_MODE + super().__init__() + self._custom = executor is None + self._executor = executor or futures.ProcessPoolExecutor(max_workers=os.cpu_count()) + + def __enter__(self): + if self._executor is None: + raise RuntimeError("executor is None") + return self + + def __exit__(self): + self.close() + + def close(self): + if self._custom: + self._executor.shutdown(wait=True) + self._executor = None + + def launch(self, func, *args, **kwargs): + return self._executor.submit(func, *args, **kwargs).result() + + def map(self, func, *iterables, **kwargs): + partial_func = partial(func, **kwargs) + return list(self._executor.map(partial_func, *iterables)) \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/README.md b/profiler/cluster_custom_analyse/README.md deleted file mode 100644 index e69de29bb..000000000 diff --git a/profiler/cluster_custom_analyse/__init__.py b/profiler/cluster_custom_analyse/__init__.py deleted file mode 100644 index a0e9f748f..000000000 --- a/profiler/cluster_custom_analyse/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/analysis/__init__.py b/profiler/cluster_custom_analyse/analysis/__init__.py deleted file mode 100644 index a0e9f748f..000000000 --- a/profiler/cluster_custom_analyse/analysis/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/analysis/base_analysis.py b/profiler/cluster_custom_analyse/analysis/base_analysis.py deleted file mode 100644 index 60ea7e712..000000000 --- a/profiler/cluster_custom_analyse/analysis/base_analysis.py +++ /dev/null @@ -1,27 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -class BaseAnalysis: - def __init__(self, params): - self._params = params - self._output_dir = None - self._output_files = {} - self._analysis_dict = {} - - def __enter__(self): - return self - - def run(self, context): - self._analysis_dict = {} \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/cluster_custom_analysis.py b/profiler/cluster_custom_analyse/cluster_custom_analysis.py deleted file mode 100644 index 65467bc95..000000000 --- a/profiler/cluster_custom_analyse/cluster_custom_analysis.py +++ /dev/null @@ -1,55 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse -import os -import sys - -from common_func import analysis_loader - -def print_analyses_list(): - pass - -def run_custom_analysis(analysis_name, analysis_args): - analysis_class = analysis_loader.get_class_from_name(analysis_name) - if not analysis_class: - print("[ERROR] unknown analysis.") - return None - - args_parsed = get_analysis_args(analysis_class, analysis_args) - #TODO try - with Context.create_context(args_parsed.mode) as context: - with analysis_class(args_parsed) as analysis: - analysis.run(context) - return analysis - -def main(): - parser = argparse.ArgumentParser(description="cluster custome analysis module") - parser.add_argument('--analysis-help', action='store_true', help='Print available analyses') - - args_parsed, args_remained = parser.parse_known_args() - - if args_parsed.analysis_help: - print_analyses_list() - return - - if not args_remained: - print("[ERROR] No analysis specified.") - return - - analysis = run_custom_analysis(args_remained[0], args_remained[1:]) - -if __name__ == "__main__": - main() \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/common_func/__init__.py b/profiler/cluster_custom_analyse/common_func/__init__.py deleted file mode 100644 index a0e9f748f..000000000 --- a/profiler/cluster_custom_analyse/common_func/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file diff --git a/profiler/cluster_custom_analyse/common_func/constant.py b/profiler/cluster_custom_analyse/common_func/constant.py deleted file mode 100644 index ad6bfc915..000000000 --- a/profiler/cluster_custom_analyse/common_func/constant.py +++ /dev/null @@ -1,20 +0,0 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. -# All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import os - -class Constant(object): - CLUSTER_CUSTOM_ANALYSE_PATH = os.path.abspath(os.path.dirname(__file__)) - ANALYSIS_PATH = os.path.join(CLUSTER_CUSTOM_ANALYSE_PATH, 'analysis') \ No newline at end of file -- Gitee From c6d70f5a7a690dc05519aa255e07d664e4401597 Mon Sep 17 00:00:00 2001 From: stby Date: Tue, 7 May 2024 16:53:16 +0800 Subject: [PATCH 3/3] add stats_export --- .../analysis/analysis_facade.py | 7 +- .../cluster_analyse/analysis/base_analysis.py | 26 ++++++-- .../cluster_analyse/analysis/cann_api_sum.py | 19 ++++-- profiler/cluster_analyse/cluster_analysis.py | 25 +++----- .../cluster_statistics_export/__init__.py | 14 ++++ .../cann_api_sum_export.py | 64 +++++++++++++++++++ .../cluster_statistics_export/stats_export.py | 38 +++++++++++ .../common_func/analysis_loader.py | 2 +- .../cluster_analyse/common_func/constant.py | 9 ++- .../cluster_analyse/common_func/context.py | 29 ++------- 10 files changed, 179 insertions(+), 54 deletions(-) create mode 100644 profiler/cluster_analyse/cluster_statistics_export/__init__.py create mode 100644 profiler/cluster_analyse/cluster_statistics_export/cann_api_sum_export.py create mode 100644 profiler/cluster_analyse/cluster_statistics_export/stats_export.py diff --git a/profiler/cluster_analyse/analysis/analysis_facade.py b/profiler/cluster_analyse/analysis/analysis_facade.py index 2e3e81cc5..74ced48ce 100644 --- a/profiler/cluster_analyse/analysis/analysis_facade.py +++ b/profiler/cluster_analyse/analysis/analysis_facade.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# Copyright (c) 2024, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -39,11 +39,10 @@ class AnalysisFacade: process.join() def recipe_analyze(self): - print("recipe analysis launched.") + print("[INFO] Recipe analysis launched.") try: with Context.create_context(self.params.get(Constant.PARALLEL_MODE)) as context: with self.params.get(Constant.RECIPE_CLASS)(self.params) as recipe: recipe.run(context) - return recipe except Exception as e: - print("[ERROR] recipe analysis launched failed.") \ No newline at end of file + print("[ERROR] Recipe analysis launched failed.") \ No newline at end of file diff --git a/profiler/cluster_analyse/analysis/base_analysis.py b/profiler/cluster_analyse/analysis/base_analysis.py index 8af86473a..ce4216639 100644 --- a/profiler/cluster_analyse/analysis/base_analysis.py +++ b/profiler/cluster_analyse/analysis/base_analysis.py @@ -1,3 +1,18 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from abc import abstractmethod from common_func.constant import Constant from utils.data_transfer_adapter import DataTransferAdapter @@ -81,15 +96,18 @@ class BaseAnalysis: class BaseRecipeAnalysis: def __init__(self, params): self._params = params - self._collection_dir = params.get(Constant.COLLECTION_PATH) - self._data_map = params.get(Constant.DATA_MAP) - self._recipe_name = params.get(Constant.RECIPE_NAME) - self._mode = params.get(Constant.PARALLEL_MODE) + self._collection_dir = params.get(Constant.COLLECTION_PATH, "") + self._data_map = params.get(Constant.DATA_MAP, {}) + self._recipe_name = params.get(Constant.RECIPE_NAME, "") + self._mode = params.get(Constant.PARALLEL_MODE, "") self._analysis_dict = {} def __enter__(self): return self + def __exit__(self, exc_type, exc_val, exc_tb): + if self._params is not None and exc_type is not None: + print(f"[ERROR] Failed to exit analysis: {exc_val}") def run(self, context): self._analysis_dict = { "Mode": self.get_mode(), diff --git a/profiler/cluster_analyse/analysis/cann_api_sum.py b/profiler/cluster_analyse/analysis/cann_api_sum.py index 6c30fcb28..0265fb41d 100644 --- a/profiler/cluster_analyse/analysis/cann_api_sum.py +++ b/profiler/cluster_analyse/analysis/cann_api_sum.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# Copyright (c) 2024, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,21 +14,28 @@ # limitations under the License. from analysis.base_analysis import BaseRecipeAnalysis - +from common_func.constant import Constant +from cluster_statistics_export.cann_api_sum_export import CannApiSumExport class CannApiSum(BaseRecipeAnalysis): def __init__(self, params): super().__init__(params) - print("CannApiSum init.") + print("[INFO] CannApiSum init.") @staticmethod - def _mapper_func(): - pass + def _mapper_func(db_path, params): + df = CannApiSumExport(db_path, params.get(Constant.RECIPE_NAME)).read_export_db() + + if df is None or df.empty: + print("[WARNING] There is no stats data.") + return None + + return df def mapper_func(self, context): return context.map( self._mapper_func, self._get_rank_db(), - xx + params=self._params ) def reducer_func(self, mapper_res): diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py index 9caa3ff32..beb582b95 100644 --- a/profiler/cluster_analyse/cluster_analysis.py +++ b/profiler/cluster_analyse/cluster_analysis.py @@ -34,7 +34,7 @@ def get_analysis_args(analysis_class, analysis_args): def parse_recipe_params(analysis_name, analysis_args): analysis_class = analysis_loader.get_class_from_name(analysis_name) if not analysis_class: - print("[ERROR] unknown analysis.") + print("[ERROR] undefined analysis.") return None args_parsed = get_analysis_args(analysis_class, analysis_args) @@ -84,12 +84,12 @@ class Interface: PathManager.check_path_owner_consistent(self.collection_path) FileManager.create_output_dir(self.collection_path) data_map, data_type = self.allocate_prof_data() - # if not data_map: - # print("[WARNING] Can not get rank info or profiling data.") - # return - # if data_type == Constant.INVALID: - # print("[ERROR] The current folder contains both DB and other files. Please check.") - # return + if not data_map: + print("[WARNING] Can not get rank info or profiling data.") + return + if data_type == Constant.INVALID: + print("[ERROR] The current folder contains both DB and other files. Please check.") + return if self.analysis_mode == "recipe": params = { Constant.COLLECTION_PATH: self.collection_path, @@ -114,18 +114,13 @@ class Interface: if __name__ == "__main__": parser = argparse.ArgumentParser(description="cluster analysis module") parser.add_argument('-d', '--collection_path', type=str, required=True, help="profiling data path") - parser.add_argument('-m', '--mode', choices=['all', 'communication_time', 'communication_matrix', 'recipe'], + parser.add_argument('-m', '--mode', choices=Constant.ALL_FEATURE_LIST, default='all', help="different analysis mode") - #TODO 扩充mode的内容,把当前确定的mode类型改成不限定,all替换成default,以函数的形式实现 args_parsed, args_remained = parser.parse_known_args() parameter = { Constant.COLLECTION_PATH: args_parsed.collection_path, Constant.ANALYSIS_MODE: args_parsed.mode } - if args_parsed.mode == 'recipe': - if not args_remained: - print("[ERROR] No recipe analysis specified.") - else: - parameter.update(parse_recipe_params(args_remained[0], args_remained[1:])) - # print(parameter) + if args_parsed.mode not in Constant.COMM_FEATURE_LIST: + parameter.update(parse_recipe_params(args_parsed.mode, args_remained)) Interface(parameter).run() diff --git a/profiler/cluster_analyse/cluster_statistics_export/__init__.py b/profiler/cluster_analyse/cluster_statistics_export/__init__.py new file mode 100644 index 000000000..7101187a2 --- /dev/null +++ b/profiler/cluster_analyse/cluster_statistics_export/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/profiler/cluster_analyse/cluster_statistics_export/cann_api_sum_export.py b/profiler/cluster_analyse/cluster_statistics_export/cann_api_sum_export.py new file mode 100644 index 000000000..9b1047978 --- /dev/null +++ b/profiler/cluster_analyse/cluster_statistics_export/cann_api_sum_export.py @@ -0,0 +1,64 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from cluster_statistics_export.stats_export import StatsExport + +QUERY = """ +WITH + summary as ( + SELECT + name, + sum(endNs - startNs) AS duration, + count (*) AS num, + avg(endNs - startNs) AS avg_duration, + min(endNs - startNs) AS min_duration, + median(endNs - startNs) AS median_duration, + max(endNs - startNs) AS max_duration, + stdev(endNs - startNs) AS stddev, + lower_quartile(endNs - startNs) AS q1, + upper_quartile(endNs - startNs) AS q3 + FROM + CANN_API + GROUP BY name + ), + totals AS ( + SELECT sum(duration) AS total + FROM summary + ) +SELECT + round(summary.duration * 100.0 / totals.total, 2) AS "duration_ratio: %", + summary.duration AS "Total Time: ns", + summary.num AS "Total Count", + round(summary.avg_duration, 1) AS "Average: ns", + summary.min_duration, 1 AS "Min: ns", + round(summary.median_duration, 1) AS "Med: ns", + summary.max_duration, 1 AS "Max: ns", + round(summary.stddev, 1) AS "StdDev: ns" + summary.q1 AS "Q1" + summary.q3 AS "Q3" +FROM + summary +LEFT JOIN + STRING_IDS AS ids + ON ids.id == summary.name +ORDER BY 2 DESC + """ +class CannApiSumExport(StatsExport): + + + def __init__(self, db_path, recipe_name): + super().__init__(db_path, recipe_name) + self._query = QUERY + print("[INFO] CannApiSumExport init.") \ No newline at end of file diff --git a/profiler/cluster_analyse/cluster_statistics_export/stats_export.py b/profiler/cluster_analyse/cluster_statistics_export/stats_export.py new file mode 100644 index 000000000..4e0a98beb --- /dev/null +++ b/profiler/cluster_analyse/cluster_statistics_export/stats_export.py @@ -0,0 +1,38 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. +# All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pandas as pd + +from common_func.db_manager import DBManager + +class StatsExport: + + def __init__(self, db_path, recipe_name): + self._db_path = db_path + self._recipe_name = recipe_name + self._query = None + + def get_query(self): + return self._query + + def read_export_db(self): + query = self.get_query() + if query is None: + print(f"[ERROR] query is None.") + return + conn, cursor = DBManager.create_connect_db(self._db_path) + data = pd.read_sql(query, conn) + DBManager.destroy_db_connect(conn, cursor) + return data \ No newline at end of file diff --git a/profiler/cluster_analyse/common_func/analysis_loader.py b/profiler/cluster_analyse/common_func/analysis_loader.py index 374b87f59..f7761e589 100644 --- a/profiler/cluster_analyse/common_func/analysis_loader.py +++ b/profiler/cluster_analyse/common_func/analysis_loader.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# Copyright (c) 2024, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py index c486bec7a..61922f2a3 100644 --- a/profiler/cluster_analyse/common_func/constant.py +++ b/profiler/cluster_analyse/common_func/constant.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os class Constant(object): # dir name @@ -108,6 +109,10 @@ class Constant(object): RECIPE_NAME = "recipe_name" RECIPE_CLASS = "recipe_class" PARALLEL_MODE = "parallel_mode" + CLUSTER_CUSTOM_ANALYSE_PATH = os.path.abspath(os.path.dirname(__file__)) + ANALYSIS_PATH = os.path.join(CLUSTER_CUSTOM_ANALYSE_PATH, 'analysis') - SINGLE_MODE = "single" - CONCURRENT_MODE = "concurrent" \ No newline at end of file + CONCURRENT_MODE = "concurrent" + + COMM_FEATURE_LIST = ['all', 'communication_time', 'communication_matrix'] + ALL_FEATURE_LIST = ['all', 'communication_time', 'communication_matrix', 'cann_api_sum'] \ No newline at end of file diff --git a/profiler/cluster_analyse/common_func/context.py b/profiler/cluster_analyse/common_func/context.py index 18e906846..3a9f1c2ad 100644 --- a/profiler/cluster_analyse/common_func/context.py +++ b/profiler/cluster_analyse/common_func/context.py @@ -1,4 +1,4 @@ -# Copyright (c) 2023, Huawei Technologies Co., Ltd. +# Copyright (c) 2024, Huawei Technologies Co., Ltd. # All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -27,8 +27,8 @@ class Context(object): @classmethod def create_context(cls, mode=Constant.CONCURRENT_MODE): if cls.ctx_map is None: - keys = [Constant.SINGLE_MODE, Constant.CONCURRENT_MODE] - values = [SingleContext, ConcurrentContext] + keys = [Constant.CONCURRENT_MODE] + values = [ConcurrentContext] cls.ctx_map = dict(zip(keys, values)) if mode not in cls.ctx_map: @@ -42,8 +42,10 @@ class Context(object): def __enter__(self): return self - def __exit__(self): - pass + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + if exc_type is not None: + print(f"[ERROR] Failed to exit context: {exc_val}") def launch(self, func, *args, **kwargs): raise NotImplementedError @@ -51,20 +53,6 @@ class Context(object): def map(self, func, *iterables, **kwargs): raise NotImplementedError - -class SingleContext(Context): - - def __init__(self): - self._mode = Constant.SINGLE_MODE - super().__init__() - def launch(self, func, *args, **kwargs): - return func(*args, **kwargs) - - def map(self, func, *iterables, **kwargs): - partial_func = partial(func, **kwargs) - return list(map(partial(self.launch, partial_func), *iterables)) - - class ConcurrentContext(Context): def __init__(self, executor=None): @@ -78,9 +66,6 @@ class ConcurrentContext(Context): raise RuntimeError("executor is None") return self - def __exit__(self): - self.close() - def close(self): if self._custom: self._executor.shutdown(wait=True) -- Gitee