diff --git a/profiler/cluster_analyse/cluster_analysis.py b/profiler/cluster_analyse/cluster_analysis.py
index 68eae526fb05479bc8b93f3bfc51037df221dc25..24454622119acbb223c70dfea65d3b792b00444c 100644
--- a/profiler/cluster_analyse/cluster_analysis.py
+++ b/profiler/cluster_analyse/cluster_analysis.py
@@ -14,7 +14,6 @@
 # limitations under the License.
 
 import argparse
-import glob
 import os
 
 from cluster_data_preprocess.pytorch_data_preprocessor import PytorchDataPreprocessor
@@ -29,8 +28,6 @@ from analysis.analysis_facade import AnalysisFacade
 class Interface:
     ASCEND_PT = "ascend_pt"
     ASCEND_MS = "ascend_ms"
-    DB_RESULT_INFO = "*.db"
-    ALL_RESULT_INFO = "*.*"
 
     def __init__(self, params: dict):
         self.collection_path = PathManager.get_realpath(params.get(Constant.COLLECTION_PATH))
@@ -41,25 +38,6 @@ class Interface:
         self.communication_ops = []
         self.matrix_ops = []
 
-    def check_db_or_other_files(self, data_map: dict) -> tuple:
-        type_db_count = 0
-        type_text_count = 0
-        for _, folder_path in data_map.items():
-            folder_path = os.path.join(folder_path, Constant.SINGLE_OUTPUT)
-            db_files = glob.glob(os.path.join(folder_path, self.DB_RESULT_INFO))
-            all_files = glob.glob(os.path.join(folder_path, self.ALL_RESULT_INFO))
-            if all_files and db_files and len(all_files) != len(db_files):
-                return False, None
-            if db_files:
-                type_db_count += 1
-            else:
-                type_text_count += 1
-        if type_db_count == len(data_map):
-            return True, Constant.DB
-        if type_text_count == len(data_map):
-            return True, Constant.TEXT
-        return False, None
-
     def allocate_prof_data(self):
         ascend_pt_dirs = []
         ascend_ms_dirs = []
@@ -69,24 +47,26 @@ class Interface:
                     ascend_pt_dirs.append(os.path.join(root, dir_name))
                 if dir_name.endswith(self.ASCEND_MS):
                     ascend_ms_dirs.append(os.path.join(root, dir_name))
-        pt_data_map = PytorchDataPreprocessor(ascend_pt_dirs).get_data_map()
+        pytorch_processor = PytorchDataPreprocessor(ascend_pt_dirs)
+        pt_data_map = pytorch_processor.get_data_map()
+        data_type = pytorch_processor.get_data_type()
         ms_data_map = MindsporeDataPreprocessor(ascend_ms_dirs).get_data_map()
         if pt_data_map and ms_data_map:
             print("[ERROR] Can not analyze pytorch and mindspore meantime.")
-            return []
+            # Keep the (data_map, data_type) shape: run() unpacks two values on every path.
+            return [], Constant.INVALID
-        return pt_data_map if pt_data_map else ms_data_map
+        return (pt_data_map, data_type) if pt_data_map else (ms_data_map, Constant.TEXT)
 
     def run(self):
         PathManager.check_input_directory_path(self.collection_path)
         PathManager.check_path_owner_consistent(self.collection_path)
         FileManager.create_output_dir(self.collection_path)
-        data_map = self.allocate_prof_data()
+        data_map, data_type = self.allocate_prof_data()
         if not data_map:
             print("[WARNING] Can not get rank info or profiling data.")
             return
-        is_valid, data_type = self.check_db_or_other_files(data_map)
-        if not is_valid:
-            print("[WARNING] The current folder contains both DB and other files. Please check.")
+        if data_type == Constant.INVALID:
+            print("[ERROR] The current folder contains both DB and other files. Please check.")
             return
         params = {
             Constant.COLLECTION_PATH: self.collection_path,
diff --git a/profiler/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py b/profiler/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py
index f1e4c062a7c05656980f0767a3180154e91942ae..7b5561284550f9f2776ccdcbb363cc8f1c7f2fbb 100644
--- a/profiler/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py
+++ b/profiler/cluster_analyse/cluster_data_preprocess/pytorch_data_preprocessor.py
@@ -12,9 +12,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+import glob
 from collections import defaultdict
 import os
+
+from common_func.constant import Constant
 from common_func.file_manager import FileManager
 from common_func.path_manager import PathManager
 
@@ -22,9 +24,13 @@ from common_func.path_manager import PathManager
 class PytorchDataPreprocessor:
     PROFILER_INFO_HEAD = 'profiler_info_'
     PROFILER_INFO_EXTENSION = '.json'
+    JSON_RESULT_INFO = "*.json"
+    CSV_RESULT_INFO = "*.csv"
 
     def __init__(self, path_list: str):
         self.path_list = path_list
+        self.db_count = 0
+        self.text_count = 0
 
     def get_data_map(self) -> dict:
         rank_id_map = defaultdict(list)
@@ -33,6 +39,21 @@ class PytorchDataPreprocessor:
             if rank_id < 0:
                 print('[Error]fail to get rankid or rankid invalid.')
                 continue
+            folder_path = os.path.join(dir_name, Constant.SINGLE_OUTPUT)
+            db_files = glob.glob(os.path.join(folder_path, Constant.DB_COMMUNICATION_ANALYZER))
+            text_files = (glob.glob(os.path.join(folder_path, self.JSON_RESULT_INFO)) +
+                          glob.glob(os.path.join(folder_path, self.CSV_RESULT_INFO)))
+            if text_files and db_files:
+                print(f"[ERROR] Rank {rank_id} has both db and text files")
+                self.db_count, self.text_count = 1, 1
+                break
+            if db_files:
+                self.db_count += 1
+            elif text_files:
+                self.text_count += 1
+            else:
+                print(f"[WARNING] Rank {rank_id} has no valid files")
+                continue
             rank_id_map[rank_id].append(dir_name)
 
         ret_dict = dict()
@@ -55,3 +76,13 @@ class PytorchDataPreprocessor:
                     rank_id = -1
                 return rank_id
         return -1
+
+    def get_data_type(self):
+        """Return DB or TEXT if every counted rank agrees, otherwise INVALID."""
+        if self.db_count != 0 and self.text_count != 0:
+            return Constant.INVALID
+        if self.db_count != 0:
+            return Constant.DB
+        if self.text_count != 0:
+            return Constant.TEXT
+        return Constant.INVALID
diff --git a/profiler/cluster_analyse/common_func/constant.py b/profiler/cluster_analyse/common_func/constant.py
index 71caee40db8b58ff263ad5d7311e797684883f3d..200244aff4a039ec25dead6b2a9f92248b496f1e 100644
--- a/profiler/cluster_analyse/common_func/constant.py
+++ b/profiler/cluster_analyse/common_func/constant.py
@@ -84,6 +84,7 @@ class Constant(object):
     # result files type
     TEXT = "text"
     DB = "db"
+    INVALID = "invalid"
 
     # db name
     DB_COMMUNICATION_ANALYZER = "analysis.db"