diff --git a/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py b/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py index 77a7095abbba9af1bcd3750714dc73c34a15925d..f2b9c01b18026ae7fcf95772c67bb2e7fad07144 100644 --- a/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py +++ b/profiler/msprof_analyze/cluster_analyse/recipes/mstx2commop/mstx2commop.py @@ -15,7 +15,10 @@ import json import os +import shutil + import pandas as pd +from msprof_analyze.prof_common.path_manager import PathManager from msprof_analyze.cluster_analyse.recipes.base_recipe_analysis import BaseRecipeAnalysis from msprof_analyze.prof_common.db_manager import DBManager @@ -50,6 +53,7 @@ class Mstx2Commop(BaseRecipeAnalysis): logger.info("Mstx2Commop init.") self.communication_op = None self.string_ids_insert = None + self.set_output = Constant.CLUSTER_ANALYSIS_OUTPUT_PATH in params # 是否设置了output_path参数 @property def base_dir(self): @@ -170,9 +174,26 @@ class Mstx2Commop(BaseRecipeAnalysis): communication_op.set_index('opId', inplace=True) string_ids_insert = list(map(list, zip(special_id_list, special_primal_list))) - DBManager.insert_data_into_db(data_map.get(Constant.PROFILER_DB_PATH), TABLE_STRING_IDS, string_ids_insert) + new_profiler_db = self._prepare_output_profiler_db(data_map.get(Constant.PROFILER_DB_PATH)) + + DBManager.insert_data_into_db(new_profiler_db, TABLE_STRING_IDS, string_ids_insert) - self.dump_data(data=communication_op, file_name=data_map.get(Constant.PROFILER_DB_PATH), - table_name=TABLE_COMMUNICATION_OP, custom_db_path=data_map.get(Constant.PROFILER_DB_PATH)) + self.dump_data(data=communication_op, file_name="", table_name=TABLE_COMMUNICATION_OP, + custom_db_path=new_profiler_db) return data_map.get(Constant.RANK_ID) + + def _prepare_output_profiler_db(self, profiler_db_path): + """ + copy profiler_db to output if not exist + """ + output_dir = os.path.join(self._cluster_analysis_output_path, self._recipe_name) + relative_db_path = os.path.relpath(profiler_db_path, start=self._collection_dir) + relative_dir = os.path.dirname(relative_db_path) + + new_path = os.path.join(output_dir, relative_dir) + new_db_path = os.path.join(output_dir, relative_db_path) + PathManager.make_dir_safety(new_path) + shutil.copyfile(profiler_db_path, new_db_path) + return new_db_path +