diff --git a/profiler/cluster_analyse/analysis/analysis_facade.py b/profiler/cluster_analyse/analysis/analysis_facade.py
index 34228f97a25e70b4d4ff97541a0a9e7f77d4f365..67f557d3797e782f8b03ff885c1b4101d962fb3f 100644
--- a/profiler/cluster_analyse/analysis/analysis_facade.py
+++ b/profiler/cluster_analyse/analysis/analysis_facade.py
@@ -13,10 +13,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from multiprocessing import Process
+from multiprocessing import Pool
 from analysis.communication_analysis import CommunicationAnalysis
 from analysis.step_trace_time_analysis import StepTraceTimeAnalysis
 from analysis.communication_analysis import CommMatrixAnalysis
+from common_func.constant import Constant
 
 
 class AnalysisFacade:
@@ -27,11 +28,11 @@ class AnalysisFacade:
 
     def cluster_analyze(self):
         # 多个profiler用多进程处理
-        process_list = []
-        for analysis in self.analysis_module:
-            process = Process(target=analysis(self.param).run)
-            process.start()
-            process_list.append(process)
-
-        for process in process_list:
-            process.join()
+        with Pool(processes=len(self.analysis_module)) as pool:
+            results = [pool.apply_async(analysis(self.param).run) for analysis in self.analysis_module]
+            outputs = [result.get() for result in results]
+        for i in outputs:
+            if i != 1:
+                return
+        print(f"[INFO] Cluster analysis success! Result is saved to {self.param.get(Constant.COLLECTION_PATH)}")
+        
\ No newline at end of file
diff --git a/profiler/cluster_analyse/analysis/communication_analysis.py b/profiler/cluster_analyse/analysis/communication_analysis.py
index a3c51d46a9bf6a1be8bf117caf79e47732284cb4..fcb2149e5a93b92b30fd9a7f5a50a7a6b17a0b8c 100644
--- a/profiler/cluster_analyse/analysis/communication_analysis.py
+++ b/profiler/cluster_analyse/analysis/communication_analysis.py
@@ -85,6 +85,7 @@ class CommunicationAnalysis(BaseCommAnalysis):
         self.split_op_by_group()
         self.combine_ops_total_info()
         self.dump_data()
+        return 1
 
     def compute_total_info(self, comm_ops: dict):
         if not comm_ops:
@@ -161,6 +162,7 @@ class CommMatrixAnalysis(BaseCommAnalysis):
         self.split_op_by_group()
         self.combine_ops_total_info()
         self.dump_data()
+        return 1
 
     def compute_total_info(self, step_dict: dict):
         self.merge_same_links(step_dict)
diff --git a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py
index d24a7f1fe635e62c0857e276578463539a61ee76..ffac66c597157790fa245baad516b968ba0964ca 100644
--- a/profiler/cluster_analyse/analysis/step_trace_time_analysis.py
+++ b/profiler/cluster_analyse/analysis/step_trace_time_analysis.py
@@ -47,6 +47,7 @@ class StepTraceTimeAnalysis:
         self.load_step_trace_time_data()
         self.analyze_step_time()
         self.dump_data()
+        return 1
 
     def dump_data(self):
         if not self.step_data_list: