diff --git a/torch_npu/profiler/analysis/prof_view/_memory_prepare_parser.py b/torch_npu/profiler/analysis/prof_view/_memory_prepare_parser.py
index 1111b6566c8b7d0692c4e2614b84454445a0a845..c8f947f004b11e4eb1249928b0f8a13c3c049337 100644
--- a/torch_npu/profiler/analysis/prof_view/_memory_prepare_parser.py
+++ b/torch_npu/profiler/analysis/prof_view/_memory_prepare_parser.py
@@ -50,21 +50,13 @@ class MemoryPrepareParser(BaseParser):
         self._enqueue_record_dict = {}  # {corrid: enqueue}
         self._dequeue_pids = set()
         self._dequeue_tids = set()
+        self._torch_ops_index = {}
+        self._dequeue_record_index = {}
+        self._record_to_name = {}
+        self.torch_ops_tid_dict = defaultdict(list)
         ProfilerLogger.init(self._profiler_path, "MemoryPrepareParser")
         self.logger = ProfilerLogger.get_instance()
 
-    @staticmethod
-    def _find_torch_ops_by_binary_search(ts: int, torch_ops: list):
-        right = len(torch_ops) - 1
-        left = 0
-        while right > left:
-            mid = left + ceil((right - left) / 2)
-            if ts >= torch_ops[mid].start_time:
-                left = mid
-            else:
-                right = mid - 1
-        return left
-
     def run(self, deps_data: dict):
         self.logger.info("MemoryPrepareParser start.")
         try:
@@ -88,8 +80,9 @@ class MemoryPrepareParser(BaseParser):
         self._init_queue_info()
         self._add_pta_memory_data()
 
-    def _find_matched_torch_op_name(self, mem_start_ts: int, torch_ops: list) -> str:
-        matched_torch_op_idx = self._find_torch_ops_by_binary_search(mem_start_ts, torch_ops)
+    def _find_matched_torch_op_name(self, mem_start_ts: int, tid: int) -> str:
+        matched_torch_op_idx = self._find_torch_ops_by_binary_search(mem_start_ts, tid)
+        torch_ops = self.torch_ops_tid_dict.get(tid, [])
         matched_torch_op = torch_ops[matched_torch_op_idx]
         while matched_torch_op.end_time < mem_start_ts:
             matched_torch_op = matched_torch_op.parent_node
@@ -163,22 +156,35 @@ class MemoryPrepareParser(BaseParser):
             ret_list.append(data_buf[:])
         return ret_list
 
-    def _find_dequeue_record_by_binary_search(self, ts: int, dequeue_records: list) -> int:
-        right = len(dequeue_records) - 1
-        left = 0
-        while right > left:
-            mid = left + ceil((right - left) / 2)
-            if ts >= dequeue_records[mid].ts:
-                left = mid
+    def _find_dequeue_record_by_binary_search(self, ts: int, record: MemoryUseBean) -> int:
+        dequeue_records = self._dequeue_record_dict[(record.pid, record.tid)]
+        if not self._dequeue_record_index.get((record.pid, record.tid)):
+            self._dequeue_record_index[(record.pid, record.tid)] = 0
+        dequeue_num = len(dequeue_records) - 1
+        while self._dequeue_record_index[(record.pid, record.tid)] <= dequeue_num:
+            if dequeue_records[self._dequeue_record_index[(record.pid, record.tid)]].ts <= ts:
+                self._dequeue_record_index[(record.pid, record.tid)] += 1
             else:
-                right = mid - 1
-        return left
+                break
+        return self._dequeue_record_index[(record.pid, record.tid)] - 1
+
+    def _find_torch_ops_by_binary_search(self, ts: int, tid: int):
+        torch_ops = self.torch_ops_tid_dict.get(tid, [])
+        if not self._torch_ops_index.get(tid):
+            self._torch_ops_index[tid] = 0
+        torch_ops_num = len(torch_ops) - 1
+        while self._torch_ops_index[tid] <= torch_ops_num:
+            if torch_ops[self._torch_ops_index[tid]].start_time <= ts:
+                self._torch_ops_index[tid] += 1
+            else:
+                break
+        return self._torch_ops_index[tid] - 1
 
     def _find_related_dequeue_record(self, record: MemoryUseBean) -> OpMarkBean:
         if not (record.pid in self._dequeue_pids and record.tid in self._dequeue_tids):
             return None
+        index = self._find_dequeue_record_by_binary_search(record.time_ns, record)
         dequeue_records = self._dequeue_record_dict[(record.pid, record.tid)]
-        index = self._find_dequeue_record_by_binary_search(record.time_ns, dequeue_records)
         if not (dequeue_records[index].ts <= record.time_ns < dequeue_records[index].ts + dequeue_records[index].dur):
             warn("Cannot find dequeue record matched memory record")
             return None
@@ -204,22 +210,33 @@ class MemoryPrepareParser(BaseParser):
             return ""
         return self._get_aten_op_name_by_enqueue_record(enqueue_record, torch_ops)
 
+    def _get_op_name_of_record(self, record: MemoryUseBean) -> str:
+        op_name = self._record_to_name.get(record)
+        if op_name is not None:
+            return op_name
+        dequeue_record = self._find_related_dequeue_record(record)
+        if dequeue_record is None:
+            op_name = self._find_matched_torch_op_name(record.time_ns, record.tid)
+        else:
+            op_name = self._find_real_op_name_of_record(dequeue_record, self.torch_ops_tid_dict.get(dequeue_record.tid, []))
+        self._record_to_name[record] = op_name
+        return op_name
+
     def _complete_record_entry(self, ptr_records: list, torch_ops: list) -> list:
         ret_list = list()
         cann_path = ProfilerPathManager.get_cann_path(self._profiler_path)
         device_ids = ProfilerPathManager.get_device_id(cann_path)
         device_tag = "NPU:" + str(device_ids[0]) if len(device_ids) == 1 else ""
-        torch_ops = [torch_op for torch_op in torch_ops if torch_op.name != "empty_tensor" and torch_op.name != "malloc_workspace"]
+        if not self.torch_ops_tid_dict:
+            torch_ops = [torch_op for torch_op in torch_ops if torch_op.name != "empty_tensor" and torch_op.name != "malloc_workspace"]
+            for torch_op in torch_ops:
+                self.torch_ops_tid_dict[torch_op.event.tid].append(torch_op)
         for records in ptr_records:
             combine_data = list()
             records_len = len(records)
             if not records or records_len > 3:
                 continue
-            dequeue_record = self._find_related_dequeue_record(records[0])
-            if dequeue_record is None:
-                op_name = self._find_matched_torch_op_name(records[0].time_ns, torch_ops)
-            else:
-                op_name = self._find_real_op_name_of_record(dequeue_record, torch_ops)
+            op_name = self._get_op_name_of_record(records[0])
             if records_len == 1:
                 if hasattr(records[0], 'component_type') and records[0].component_type == Constant.CACHING_TYPE:
                     self._incomplete_num += 2
@@ -257,17 +274,16 @@ class MemoryPrepareParser(BaseParser):
         cann_path = ProfilerPathManager.get_cann_path(self._profiler_path)
         device_ids = ProfilerPathManager.get_device_id(cann_path)
         device_index = device_ids[0] if len(device_ids) == 1 else -1
-        torch_ops = [torch_op for torch_op in torch_ops if torch_op.name != "empty_tensor" and torch_op.name != "malloc_workspace"]
+        if not self.torch_ops_tid_dict:
+            torch_ops = [torch_op for torch_op in torch_ops if torch_op.name != "empty_tensor" and torch_op.name != "malloc_workspace"]
+            for torch_op in torch_ops:
+                self.torch_ops_tid_dict[torch_op.event.tid].append(torch_op)
        for records in ptr_records:
             combine_data = list()
             records_len = len(records)
             if not records or records_len > 3:
                 continue
-            dequeue_record = self._find_related_dequeue_record(records[0])
-            if dequeue_record is None:
-                op_name = self._find_matched_torch_op_name(records[0].time_ns, torch_ops)
-            else:
-                op_name = self._find_real_op_name_of_record(dequeue_record, torch_ops)
+            op_name = self._get_op_name_of_record(records[0])
             if records_len == 1:
                 if hasattr(records[0], 'component_type') and records[0].component_type == Constant.CACHING_TYPE:
                     self._incomplete_num += 2
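
Note on the lookup change: despite keeping the `_by_binary_search` names, both lookups are now forward-only cursors (`_torch_ops_index`, `_dequeue_record_index`) kept per tid or per (pid, tid), which is correct only if each stream is queried in non-decreasing timestamp order; under that assumption the cost drops from O(log n) per record to amortized O(1) over the whole stream. A minimal standalone sketch of the pattern, with illustrative names (`CursorIndex`, `find_le`) that are not part of the parser:

from collections import defaultdict


class CursorIndex:
    """Per-key forward-only cursor over an ascending list of start times."""

    def __init__(self):
        self._cursor = defaultdict(int)  # key -> next candidate position

    def find_le(self, key, starts, ts):
        # Advance while the element at the cursor is <= ts; the match is
        # the last position that satisfied the test (-1 if none did).
        i = self._cursor[key]
        while i < len(starts) and starts[i] <= ts:
            i += 1
        self._cursor[key] = i
        return i - 1


idx = CursorIndex()
starts = [10, 20, 30, 40]        # e.g. per-tid op start times, ascending
for ts in (5, 12, 25, 25, 41):   # queries must arrive in non-decreasing order
    print(ts, "->", idx.find_le("tid-1", starts, ts))
# 5 -> -1, 12 -> 0, 25 -> 1, 25 -> 1, 41 -> 3

The sketch makes the trade-off explicit: a cursor never backs up, so a query with a smaller timestamp than its predecessor could return a stale index, a case the old binary search handled but the new scan does not.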