diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py index 11c52ae5b32c4727167e8bb8e75a16e4cceb24fa..596c04b07b5698207feb45739a4f24c2ed9387f9 100644 --- a/profiler/cluster_analyse/cluster_analysis.py +++ b/profiler/cluster_analyse/cluster_analysis.py @@ -14,7 +14,11 @@ # limitations under the License. import argparse +import multiprocessing import os +import platform +import sys +from multiprocessing import freeze_support from cluster_data_preprocess.pytorch_data_preprocessor import PytorchDataPreprocessor from cluster_data_preprocess.mindspore_data_preprocessor import MindsporeDataPreprocessor @@ -26,11 +30,11 @@ from analysis.analysis_facade import AnalysisFacade COMM_FEATURE_LIST = ['all', 'communication_time', 'communication_matrix'] + class Interface: ASCEND_PT = "ascend_pt" ASCEND_MS = "ascend_ms" - def __init__(self, params: dict): self.collection_path = PathManager.get_realpath(params.get(Constant.COLLECTION_PATH)) self.analysis_mode = params.get(Constant.ANALYSIS_MODE) @@ -47,6 +51,7 @@ class Interface: if cluster_analysis_output_path: return PathManager.get_realpath(cluster_analysis_output_path) return self.collection_path + def allocate_prof_data(self): ascend_pt_dirs = [] ascend_ms_dirs = [] @@ -64,6 +69,7 @@ class Interface: print("[ERROR] Can not analyze pytorch and mindspore meantime.") return [] return (pt_data_map, data_type) if pt_data_map else (ms_data_map, Constant.TEXT) + def run(self): PathManager.check_input_directory_path(self.collection_path) PathManager.check_path_owner_consistent(self.collection_path) @@ -87,6 +93,7 @@ class Interface: params[Constant.COMM_DATA_DICT] = comm_data_dict AnalysisFacade(params).cluster_analyze() + def cluster_analysis_main(args=None): parser = argparse.ArgumentParser(description="cluster analysis module") parser.add_argument('-d', '--collection_path', type=str, required=True, help="profiling data path") @@ -102,5 +109,11 @@ def cluster_analysis_main(args=None): Interface(parameter).run() + if __name__ == "__main__": - cluster_analysis_main() + # 用于支持Windows和MacOS + if platform.system() == 'Darwin': + multiprocessing.set_start_method('fork') + if platform.system() == 'Windows': + freeze_support() + cluster_analysis_main(sys.argv)