diff --git a/torch_npu/profiler/analysis/prof_common_func/_constant.py b/torch_npu/profiler/analysis/prof_common_func/_constant.py index 37abe8cf0452b35fd70db1a5a5f409b827e32ec7..c84fb672e7dfdb0a0e92573e62b376c10d0ca2ce 100644 --- a/torch_npu/profiler/analysis/prof_common_func/_constant.py +++ b/torch_npu/profiler/analysis/prof_common_func/_constant.py @@ -317,7 +317,7 @@ class DbConstant(): TABLE_OPERATOR_MEMORY = "OP_MEMORY" TABLE_NPU_OP_MEM = "NPU_OP_MEM" TABLE_META_DATA = "META_DATA" - + # rank device map table name TABLE_RANK_DEVICE_MAP = "RANK_DEVICE_MAP" # host info diff --git a/torch_npu/profiler/analysis/prof_view/prof_db_parse/_fwk_api_db_parser.py b/torch_npu/profiler/analysis/prof_view/prof_db_parse/_fwk_api_db_parser.py index 2ae2ac6474e707e27f1717298b9de3cac8d725b9..b9b923d275289983c943fbcabb4832133fed676e 100644 --- a/torch_npu/profiler/analysis/prof_view/prof_db_parse/_fwk_api_db_parser.py +++ b/torch_npu/profiler/analysis/prof_view/prof_db_parse/_fwk_api_db_parser.py @@ -1,5 +1,7 @@ +from collections import defaultdict from enum import Enum from ...prof_common_func._db_manager import TorchDb +from ...prof_common_func._file_manager import FileManager from ...prof_common_func._id_manager import Str2IdManager, ConnectionIdManager, CallChainIdManager from ...prof_common_func._constant import Constant, DbConstant, TableColumnsManager from .._base_parser import BaseParser @@ -142,12 +144,13 @@ class FwkApiDbParser(BaseParser): if not cann_tx_apis: raise RuntimeWarning("Failed to get msprof_tx apis") mstx_mark_apis.sort(key=lambda x: x[TorchOpDataOri.START_NS.value]) - mstx_op_len = len(mstx_mark_apis) if task_enqueues and task_dequeues: - self.get_torch_op_connection_ids_with_task_queue(task_enqueues, task_dequeues, mstx_mark_apis, mstx_op_len, - cann_tx_apis) + self.get_torch_op_connection_ids_with_task_queue(task_enqueues, task_dequeues, mstx_mark_apis, cann_tx_apis) def get_torch_op_connection_ids_with_cann_api(self, task_enqueues: list, task_dequeues: list, torch_op_apis: list): + torch_op_apis = [api for api in torch_op_apis if not api[TorchOpDataOri.NAME.value].startswith('ProfilerStep#')] + if not torch_op_apis: + return sql = "select id from {} where value = 'launch'".format(DbConstant.TABLE_STRING_IDS) node_launch_str_ids = TorchDb().fetch_one_data(sql) node_launch_str_id = 0 @@ -162,71 +165,158 @@ class FwkApiDbParser(BaseParser): if not node_launch_apis: raise RuntimeWarning("Failed to get node launch apis") torch_op_apis.sort(key=lambda x: x[TorchOpDataOri.START_NS.value]) - torch_op_len = len(torch_op_apis) + task_dequeues.sort(key=lambda x: x[TaskQueueDataOri.START_NS.value]) if task_enqueues and task_dequeues: - self.get_torch_op_connection_ids_with_task_queue(task_enqueues, task_dequeues, torch_op_apis, torch_op_len, + self.get_torch_op_connection_ids_with_task_queue(task_enqueues, task_dequeues, torch_op_apis, node_launch_apis) else: - self.get_torch_op_connection_ids_without_task_queue(torch_op_apis, torch_op_len, node_launch_apis) - - def get_torch_op_connection_ids_with_task_queue(self, task_enqueues: list, task_dequeues: list, torch_op_apis: list, torch_op_len: int, node_lauch_apis: list): - enqueue_corr_ids = {task_enqueue[TaskQueueDataOri.CORRELATION_ID.value] for task_enqueue in task_enqueues} - dequeue_corr_ids = {task_dequeue[TaskQueueDataOri.CORRELATION_ID.value] for task_dequeue in task_dequeues} - enqueue_list = [] - for task_enqueue in task_enqueues: - if task_enqueue[TaskQueueDataOri.CORRELATION_ID.value] in dequeue_corr_ids: - enqueue_list.append(task_enqueue) - dequeue_list = [] - for task_dequeue in task_dequeues: - if task_dequeue[TaskQueueDataOri.CORRELATION_ID.value] in enqueue_corr_ids: - dequeue_list.append(task_dequeue) - last_dequeue_index = 0 - last_torch_op_index = 0 - dequeue_len = len(dequeue_list) - for node_launch_api in node_lauch_apis: - for idx in range(last_dequeue_index, dequeue_len): - if node_launch_api[CannNodeLaunchApiOri.START_NS.value] > dequeue_list[idx][TaskQueueDataOri.START_NS.value] and \ - node_launch_api[CannNodeLaunchApiOri.END_NS.value] < dequeue_list[idx][TaskQueueDataOri.END_NS.value]: - last_dequeue_index = idx - enqeue = enqueue_list[idx] - last_torch_op_index = self.get_torch_op_connection_ids_with_enqueue(torch_op_apis, - torch_op_len, - enqeue, - last_torch_op_index, - node_launch_api[CannNodeLaunchApiOri.CORRELATION_ID.value]) - break + self._get_torch_op_connection_ids_without_task_queue(torch_op_apis, node_launch_apis) + + def get_torch_op_connection_ids_with_task_queue(self, task_enqueues: list, task_dequeues: list, torch_op_apis: list, + node_launch_apis: list): + # 1. Match node launch and dequeue + dequeue_corrections_ids = self._match_node_launch_and_dequeue(node_launch_apis, task_dequeues) + + if not dequeue_corrections_ids: + return + + # 2. Match dequeue and enqueue + enqueue_dict = self._match_dequeue_and_enqueue(dequeue_corrections_ids, task_enqueues) + + # 3. Match enqueue and torch op + self._match_enqueue_and_torch_op(enqueue_dict, torch_op_apis) + + @staticmethod + def _match_node_launch_and_dequeue(node_launch_apis, task_dequeues): + dequeue_dict = defaultdict(list) + node_launch_dict = defaultdict(list) + for dequeue in task_dequeues: + dequeue_dict[dequeue[TaskQueueDataOri.GLOBAL_TID.value]].append(dequeue) + for node_launch in node_launch_apis: + node_launch_dict[node_launch[CannNodeLaunchApiOri.GLOBAL_TID.value]].append(node_launch) + common_keys = dequeue_dict.keys() & node_launch_dict.keys() + dequeue_dict = {k: dequeue_dict[k] for k in common_keys} + node_launch_dict = {k: node_launch_dict[k] for k in common_keys} + dequeue_corrections_ids = [] + for tid in common_keys: + dequeue_index = 0 + for node_launch in node_launch_dict[tid]: + while dequeue_index < len(dequeue_dict[tid]): + if dequeue_dict[tid][dequeue_index][TaskQueueDataOri.START_NS.value] > node_launch[ + CannNodeLaunchApiOri.START_NS.value]: + break + if dequeue_dict[tid][dequeue_index][TaskQueueDataOri.END_NS.value] < node_launch[ + CannNodeLaunchApiOri.START_NS.value]: + break + if ( + dequeue_dict[tid][dequeue_index][TaskQueueDataOri.START_NS.value] < node_launch[ + CannNodeLaunchApiOri.START_NS.value] + and dequeue_dict[tid][dequeue_index][TaskQueueDataOri.END_NS.value] > node_launch[ + CannNodeLaunchApiOri.END_NS.value] + ): + dequeue_correction_id = dequeue_dict[tid][dequeue_index][ + TaskQueueDataOri.CORRELATION_ID.value] + node_launch_correction_id = node_launch[CannNodeLaunchApiOri.CORRELATION_ID.value] + dequeue_corrections_ids.append([dequeue_correction_id, node_launch_correction_id]) + dequeue_index += 1 + return dequeue_corrections_ids - def get_torch_op_connection_ids_with_enqueue(self, torch_op_apis: list, torch_op_len: int, enqeue: list, last_torch_op_index: int, connection_id: int) -> int: - last_op_api = None - for idx in range(last_torch_op_index, torch_op_len): - if enqeue[TaskQueueDataOri.START_NS.value] > torch_op_apis[idx][TorchOpDataOri.END_NS.value]: - continue - if enqeue[TaskQueueDataOri.START_NS.value] > torch_op_apis[idx][TorchOpDataOri.START_NS.value] and enqeue[TaskQueueDataOri.END_NS.value] < torch_op_apis[idx][TorchOpDataOri.END_NS.value]: - last_op_api = torch_op_apis[idx] - last_torch_op_index = idx - elif last_op_api: - break - if last_op_api: - torch_op_apis[last_torch_op_index][TorchOpDataOri.CONNECTION_ID.value].append(connection_id) - return last_torch_op_index - - def get_torch_op_connection_ids_without_task_queue(self, torch_op_apis: list, torch_op_len: int, node_lauch_apis: list): - last_op_api = None - last_op_index = 0 - for node_launch_api in node_lauch_apis: - for idx in range(last_op_index, torch_op_len): - if torch_op_apis[idx][TorchOpDataOri.GLOBAL_TID.value] != node_launch_api[CannNodeLaunchApiOri.GLOBAL_TID.value]: - continue - if node_launch_api[CannNodeLaunchApiOri.START_NS.value] > torch_op_apis[idx][TorchOpDataOri.END_NS.value]: - continue - if node_launch_api[CannNodeLaunchApiOri.START_NS.value] > torch_op_apis[idx][TorchOpDataOri.START_NS.value] and \ - node_launch_api[CannNodeLaunchApiOri.END_NS.value] < torch_op_apis[idx][TorchOpDataOri.END_NS.value]: - last_op_api = torch_op_apis[idx] - last_op_index = idx - elif last_op_api: - torch_op_apis[last_op_index][TorchOpDataOri.CONNECTION_ID.value].append(node_launch_api[CannNodeLaunchApiOri.CORRELATION_ID.value]) - last_op_api = None + @staticmethod + def _match_dequeue_and_enqueue(dequeue_corrections_ids, task_enqueues): + dequeue_corrections_ids.sort(key=lambda x: x[0]) + task_enqueues.sort(key=lambda x: x[TaskQueueDataOri.CORRELATION_ID.value]) + enqueue_dict = defaultdict(list) + idx = 0 + for enqueue in task_enqueues: + while idx < len(dequeue_corrections_ids): + if enqueue[TaskQueueDataOri.CORRELATION_ID.value] < dequeue_corrections_ids[idx][0]: + break + if enqueue[TaskQueueDataOri.CORRELATION_ID.value] == dequeue_corrections_ids[idx][0]: + enqueue_dict[enqueue[TaskQueueDataOri.GLOBAL_TID.value]].append(( + dequeue_corrections_ids[idx][1], + enqueue[TaskQueueDataOri.START_NS.value], + enqueue[TaskQueueDataOri.END_NS.value] + )) + idx += 1 break + idx += 1 + return enqueue_dict + + @staticmethod + def _match_enqueue_and_torch_op(enqueue_dict, torch_op_apis): + torch_op_dict = defaultdict(list) + for torch_op_api in torch_op_apis: + torch_op_dict[torch_op_api[TorchOpDataOri.GLOBAL_TID.value]].append(torch_op_api) + common_keys = enqueue_dict.keys() & torch_op_dict.keys() + enqueue_dict = {k: enqueue_dict[k] for k in common_keys} + torch_op_dict = {k: torch_op_dict[k] for k in common_keys} + for tid in common_keys: + enqueues = enqueue_dict[tid] + torch_ops = torch_op_dict[tid] + torch_ops_len = len(torch_ops) + last_torch_op_index = 0 + for correction_id, enqueue_start_time, enqueue_end_time in enqueues: + last_torch_op_api = None + while last_torch_op_index < torch_ops_len: + current_op = torch_ops[last_torch_op_index] + op_start = current_op[TorchOpDataOri.START_NS.value] + op_end = current_op[TorchOpDataOri.END_NS.value] + if op_start > enqueue_start_time: + break + if op_end < enqueue_start_time: + last_torch_op_index += 1 + continue + if op_start < enqueue_start_time and op_end > enqueue_end_time: + last_torch_op_api = current_op + last_torch_op_index += 1 + else: + if last_torch_op_api: + break + last_torch_op_index += 1 + + if last_torch_op_api: + torch_ops[last_torch_op_index - 1][TorchOpDataOri.CONNECTION_ID.value].append(correction_id) + + @staticmethod + def _get_torch_op_connection_ids_without_task_queue(torch_op_apis: list, node_launch_apis: list): + torch_op_dict = defaultdict(list) + node_launch_dict = defaultdict(list) + for torch_op_api in torch_op_apis: + torch_op_dict[torch_op_api[TorchOpDataOri.GLOBAL_TID.value]].append(torch_op_api) + for node_launch in node_launch_apis: + node_launch_dict[node_launch[CannNodeLaunchApiOri.GLOBAL_TID.value]].append(node_launch) + common_keys = torch_op_dict.keys() & node_launch_dict.keys() + node_launch_dict = {k: node_launch_dict[k] for k in common_keys} + torch_op_dict = {k: torch_op_dict[k] for k in common_keys} + for tid in common_keys: + node_launch_apis = node_launch_dict[tid] + torch_ops = torch_op_dict[tid] + torch_ops_len = len(torch_ops) + for node_launch_api in node_launch_apis: + last_torch_op_api = None + last_torch_op_index = 0 + node_start = node_launch_api[CannNodeLaunchApiOri.START_NS.value] + node_end = node_launch_api[CannNodeLaunchApiOri.END_NS.value] + node_corr_id = node_launch_api[CannNodeLaunchApiOri.CORRELATION_ID.value] + while last_torch_op_index < torch_ops_len: + current_op = torch_ops[last_torch_op_index] + op_start = current_op[TorchOpDataOri.START_NS.value] + op_end = current_op[TorchOpDataOri.END_NS.value] + if op_start > node_start: + break + if node_start > op_end: + last_torch_op_index += 1 + continue + if node_start > op_start and node_end < op_end: + last_torch_op_api = current_op + last_torch_op_index += 1 + else: + if last_torch_op_api: + break + last_torch_op_index += 1 + + if last_torch_op_api: + torch_ops[last_torch_op_index - 1][TorchOpDataOri.CONNECTION_ID.value].append(node_corr_id) def set_start_string_id(self): Str2IdManager().set_start_id(DbConstant.START_STRING_ID_FWK_API)