From 9b6f0b7c244cdf7db7d79a4e406187b1659991e2 Mon Sep 17 00:00:00 2001
From: feng123www
Date: Wed, 21 Feb 2024 09:57:35 +0800
Subject: [PATCH] Fused operator identification and user software stack
 retrieval
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .../common_func_advisor/constant.py           |   8 +-
 .../common_func_advisor/trace_view_json.py    | 182 ++++++++++++++++
 .../compute_advice/compute_advice_base.py     |  28 ++-
 .../{analyser.py => csv_analyzer.py}          |  27 ++-
 .../compute_advice/npu_fused/json_analyzer.py |  55 +++++
 .../compute_advice/npu_fused_advice.py        |  42 ++--
 profiler/advisor/compute_perf_analysis.ipynb  |  99 ++++-----
 .../compute_advice/test_npufused_advice.py    | 204 ++++++++++++++----
 8 files changed, 523 insertions(+), 122 deletions(-)
 create mode 100644 profiler/advisor/advisor_backend/common_func_advisor/trace_view_json.py
 rename profiler/advisor/advisor_backend/compute_advice/npu_fused/{analyser.py => csv_analyzer.py} (81%)
 create mode 100644 profiler/advisor/advisor_backend/compute_advice/npu_fused/json_analyzer.py

diff --git a/profiler/advisor/advisor_backend/common_func_advisor/constant.py b/profiler/advisor/advisor_backend/common_func_advisor/constant.py
index 86b98c2873..34879db9f2 100644
--- a/profiler/advisor/advisor_backend/common_func_advisor/constant.py
+++ b/profiler/advisor/advisor_backend/common_func_advisor/constant.py
@@ -32,7 +32,7 @@ class Constant:
     SLOW_RANK = "slow rank"
     SLOW_LINK = "slow link"
     KERNEL = "kernel"
-    
+
     # compute
     NPU_FUSED = "npu_fused"
 
@@ -103,4 +103,8 @@ class Constant:
         ("Mul", "AsStrided", "Neg", "AsStrided", "ConcatD", "Mul", "Add"): "RotaryMul",
         ("Mul", "Slice", "Neg", "Slice", "ConcatD", "Mul", "Add"): "RotaryMul",
         ("MatMulV2", "Swish", "MatMulV2", "Mul", "MatMulV2"): "FFN",
-        ("Transpose", "Transpose", "GatherElement", "Transpose"): "GatherElement"}
+        ("Transpose", "Transpose", "GatherElement", "Transpose"): "GatherElement",
+        ("Slice", "Slice", "Swish", "Mul"): "torch_npu.npu_swiglu",
+        ("Cast", "Mul", "MaskedFill", "SoftmaxV2", "Cast"): "torch_npu.npu_scaled_masked_softmax",
+        ("Mul", "Slice", "Neg", "Slice", "ConcatD", "Mul"): "torch_npu.npu_rotary_mul",
+        ("Cast", "Square", "ReduceMeanD", "Add", "Rsqrt", "Mul", "Cast", "Mul"): "torch_npu.npu_rms_norm"}
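Each new PATTERN_DICT entry maps a device-side kernel sequence, as it appears in the Type
column of kernel_details.csv, to the torch_npu fused operator the advice will suggest. A
minimal sketch of what acting on the swiglu advice looks like in model code (assuming a
torch_npu build that exposes npu_swiglu with a dim argument; the helper names are
illustrative only):

    import torch
    import torch_npu

    def swiglu_eager(x: torch.Tensor) -> torch.Tensor:
        # On NPU this roughly lowers to the kernel sequence (Slice, Slice, Swish, Mul),
        # which the analyzer flags as fusable.
        a, b = x.chunk(2, dim=-1)
        return torch.nn.functional.silu(a) * b

    def swiglu_fused(x: torch.Tensor) -> torch.Tensor:
        # One fused kernel instead of four small ones.
        return torch_npu.npu_swiglu(x, dim=-1)
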
diff --git a/profiler/advisor/advisor_backend/common_func_advisor/trace_view_json.py b/profiler/advisor/advisor_backend/common_func_advisor/trace_view_json.py
new file mode 100644
index 0000000000..08ef028765
--- /dev/null
+++ b/profiler/advisor/advisor_backend/common_func_advisor/trace_view_json.py
@@ -0,0 +1,182 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from abc import abstractmethod
+from dataclasses import dataclass
+from dataclasses import field
+from typing import Dict
+from typing import List
+
+from common_func.file_manager import FileManager
+
+
+@dataclass
+class TraceObj:
+    ph: str = ""
+    bp: str = ""
+    cat: str = ""
+    name: str = ""
+    pid: int = 0
+    tid: int = 0
+    id: int = 0
+    ts: str = ""
+    dur: float = 0.0
+    args: dict = field(default_factory=dict)
+
+    @abstractmethod
+    def hash(self):
+        raise Exception("To be implemented")
+
+    def valid(self):
+        return self.name != ""
+
+    def check_hashable(self):
+        if not self.valid():
+            raise Exception("Illegal {} to hash".format(self.__class__.__name__))
+
+
+@dataclass
+class Process(TraceObj):
+    def hash(self):
+        self.check_hashable()
+        # msprof guarantees the name is unique
+        return self.args.get("name")
+
+
+@dataclass
+class Thread(TraceObj):
+    def hash(self):
+        self.check_hashable()
+        # msprof guarantees the name is unique
+        return self.args.get("name")
+
+
+@dataclass
+class DurationEvent(TraceObj):
+    def hash(self):
+        self.check_hashable()
+        return self.ts
+
+
+@dataclass
+class FlowEvent(TraceObj):
+    s_point_ts: str = ""
+    e_point_ts: str = ""
+
+    def hash(self):
+        self.check_hashable()
+        return self.e_point_ts
+
+
+class TraceViewJson:
+
+    def __init__(self, path):
+        self.processes: Dict[str, Process] = dict()
+        self.threads: Dict[str, Thread] = dict()
+        self.python_dur_events: Dict[str, DurationEvent] = dict()
+        self.cann_dur_events: Dict[str, DurationEvent] = dict()
+        self.ascend_hardware_dur_events: Dict[str, DurationEvent] = dict()
+        self.torch_2_npu_flow_events: Dict[str, FlowEvent] = dict()
+
+        traces = FileManager.read_json_file(path)
+        self._load_obj(traces)
+
+    def get_torch_2_npu_flow_event(self, end_time) -> FlowEvent:
+        if not self.torch_2_npu_flow_events or not self.torch_2_npu_flow_events.get(end_time):
+            print("[ERROR] Find flow event failed for ts: {}".format(end_time))
+            return FlowEvent()
+        return self.torch_2_npu_flow_events.get(end_time)
+
+    def get_python_dur_events_contain_ts(self, ts) -> List[DurationEvent]:
+        res = []
+        for event in self.python_dur_events.values():
+            if float(event.ts) <= float(ts) <= float(event.ts) + event.dur:
+                res.append(event)
+        return res
+
+    def _load_obj(self, traces):
+        self._load_format(traces)
+        if not self._check_format():
+            print("[ERROR] parse json failed for error format")
+            return
+        self._load_duration_events(traces)
+        self._load_torch_to_npu_flow_events(traces)
+
+    def _check_format(self):
+        # Only these two processes are needed by the current feature; extensible later
+        check_processes = ['Python', 'Ascend Hardware']
+        for check_process in check_processes:
+            if check_process in self.processes:
+                continue
+            print("[ERROR] {} process not found in json.".format(check_process))
+            return False
+        return True
+
+    # Load the pid/tid metadata headers
+    def _load_format(self, traces: List[Dict]):
+        for i, trace in enumerate(traces):
+            if trace.get('name') == 'process_name':
+                if not trace.get('args') or not trace.get('args').get('name') or not trace.get('pid'):
+                    continue
+                process = Process(**trace)
+                self.processes[process.hash()] = process
+            if trace.get('name') == 'thread_name':
+                if not trace.get('args') or not trace.get('args').get('name') or not trace.get('tid'):
+                    continue
+                thread = Thread(**trace)
+                self.threads[thread.hash()] = thread
+
+    def _load_duration_events(self, traces: List[Dict]):
+        def check_events(_trace):
+            return _trace.get('name') and _trace.get("ts") and _trace.get("dur")
+
+        python_pid = self.processes.get("Python").pid
+        cann_pid = self.processes.get("CANN").pid
+        ascend_hardware_pid = self.processes.get("Ascend Hardware").pid
+        for i, trace in enumerate(traces):
+            if trace.get('ph') != 'X':
+                continue
+            if not check_events(trace):
+                continue
+            event = DurationEvent(**trace)
+            if trace.get('pid') == python_pid:
+                self.python_dur_events[event.hash()] = event
+            elif trace.get('pid') == cann_pid:
+                self.cann_dur_events[event.hash()] = event
+            elif trace.get("pid") == ascend_hardware_pid:
+                self.ascend_hardware_dur_events[event.hash()] = event
+
+    def _load_torch_to_npu_flow_events(self, traces: List[Dict]):
+        def check_events(_trace):
+            return _trace.get('name') and _trace.get("id") and _trace.get("ts")
+
+        flow_events_table_by_id = dict()
+        for i, trace in enumerate(traces):
+            if trace.get('ph') != 's' and trace.get('ph') != 'f':
+                continue
+            if not check_events(trace):
+                continue
+            event = flow_events_table_by_id.get(trace.get("id"))
+            if not event:
+                event = FlowEvent(**trace)
+            if trace.get('ph') == 's':
+                event.s_point_ts = trace.get('ts')
+            else:
+                event.e_point_ts = trace.get('ts')
+            flow_events_table_by_id[event.id] = event
+
+        self.torch_2_npu_flow_events = {eve.hash(): eve for eve in flow_events_table_by_id.values()}
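TraceViewJson indexes torch-to-npu flow events by their end-point timestamp, so a kernel's
start time can be walked back to the framework-side events that launched it. A usage sketch
(assumes the advisor_backend modules are importable; the path is a placeholder):

    from common_func_advisor.trace_view_json import TraceViewJson

    tv = TraceViewJson("ascend_pt/ASCEND_PROFILER_OUTPUT/trace_view.json")
    flow = tv.get_torch_2_npu_flow_event("1699529623106750")  # kernel start ts
    if flow.valid():
        # Python duration events whose time range covers the flow start point
        frames = tv.get_python_dur_events_contain_ts(flow.s_point_ts)
        print([e.name for e in frames])
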
Hardware").pid + for i, trace in enumerate(traces): + if trace.get('ph') != 'X': + continue + if not check_events(trace): + continue + event = DurationEvent(**trace) + if trace.get('pid') == python_pid: + self.python_dur_events[event.hash()] = event + elif trace.get('pid') == cann_pid: + self.cann_dur_events[event.hash()] = event + elif trace.get("pid") == ascend_hardware_pid: + self.ascend_hardware_dur_events[event.hash()] = event + + def _load_torch_to_npu_flow_events(self, traces: List[Dict]): + def check_events(_trace): + return _trace.get('name') and _trace.get("id") and _trace.get("ts") + + flow_events_table_by_id = dict() + + python_pid = self.processes.get("Python") + for i, trace in enumerate(traces): + if trace.get('ph') != 's' and trace.get('ph') != 'f' and trace.get('pid') != python_pid: + continue + if not check_events(trace): + continue + event = flow_events_table_by_id.get(trace.get("id")) + if not event: + event = FlowEvent(**trace) + if trace.get('ph') == 's': + event.s_point_ts = trace.get('ts') + else: + event.e_point_ts = trace.get('ts') + flow_events_table_by_id[event.id] = event + + self.torch_2_npu_flow_events = {eve.hash(): eve for eve in flow_events_table_by_id.values()} diff --git a/profiler/advisor/advisor_backend/compute_advice/compute_advice_base.py b/profiler/advisor/advisor_backend/compute_advice/compute_advice_base.py index 84153727ec..8e09381876 100644 --- a/profiler/advisor/advisor_backend/compute_advice/compute_advice_base.py +++ b/profiler/advisor/advisor_backend/compute_advice/compute_advice_base.py @@ -18,6 +18,7 @@ from collections import defaultdict import os from advice_base import AdviceBase +from common_func.file_manager import FileManager class ComputeAdviceBase(AdviceBase): @@ -26,6 +27,7 @@ class ComputeAdviceBase(AdviceBase): self.kernel_details_path = "" self.has_preparse = False self.preparse_data = defaultdict(list) + self.call_stack = None def path_check(self): """ @@ -35,9 +37,11 @@ class ComputeAdviceBase(AdviceBase): print("[ERROR] Path: {} is not exist.".format(self.collection_path)) return False if os.path.isdir(self.collection_path) and self.collection_path.endswith("ascend_pt"): - self.kernel_details_path = os.path.join(self.collection_path, "ASCEND_PROFILER_OUTPUT", "kernel_details.csv") + self.kernel_details_path = os.path.join(self.collection_path, "ASCEND_PROFILER_OUTPUT", + "kernel_details.csv") if not os.path.exists(self.kernel_details_path): - print("[ERROR] kernel_details.csv is not exist in the Path: {}.".format(os.path.join(self.collection_path, "ASCEND_PROFILER_OUTPUT"))) + print("[ERROR] kernel_details.csv is not exist in the Path: {}.".format( + os.path.join(self.collection_path, "ASCEND_PROFILER_OUTPUT"))) return False elif os.path.isfile(self.collection_path) and os.path.basename(self.collection_path) == "kernel_details.csv": self.kernel_details_path = self.collection_path @@ -48,6 +52,26 @@ class ComputeAdviceBase(AdviceBase): self.preparse() return True + def has_callstack(self): + if self.call_stack is not None: + return self.call_stack + profiler_info_json_path = os.path.join(self.collection_path, "profiler_info.json") + self.trace_view_path = os.path.join(self.collection_path, self.ASCEND_PROFILER_OUTPUT, "trace_view.json") + if not os.path.exists(profiler_info_json_path) or not os.path.exists(self.trace_view_path): + self.call_stack = False + return self.call_stack + info = FileManager.read_json_file(profiler_info_json_path) + if not info.get("config") or not info.get("config").get("common_config") \ + or not 
info.get("config").get("common_config").get("with_stack"): + self.call_stack = False + return self.call_stack + activities = info.get("config").get("common_config").get("activities") + if not activities or "ProfilerActivity.CPU" not in activities: + self.call_stack = False + return self.call_stack + self.call_stack = info.get("config").get("common_config").get("with_stack") + return self.call_stack + @abstractmethod def run(self): """ diff --git a/profiler/advisor/advisor_backend/compute_advice/npu_fused/analyser.py b/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py similarity index 81% rename from profiler/advisor/advisor_backend/compute_advice/npu_fused/analyser.py rename to profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py index 74ea03b8d7..ccf69dfe97 100644 --- a/profiler/advisor/advisor_backend/compute_advice/npu_fused/analyser.py +++ b/profiler/advisor/advisor_backend/compute_advice/npu_fused/csv_analyzer.py @@ -16,17 +16,20 @@ import multiprocessing import pandas as pd +import numpy as np from common_func_advisor.constant import Constant from .op_perf import OpPerfFactory -class Analyser: +class CSVAnalyzer: def __init__(self, path) -> None: self._path = path def process(self): - df = pd.read_csv(self._path) + df = pd.read_csv(self._path, dtype={"Start Time(us)": str}) + + pool = multiprocessing.Pool(multiprocessing.cpu_count()) # 数据预解析 result = pool.map(self.update_op_row, df.iterrows()) @@ -36,12 +39,15 @@ class Analyser: # 分析是否存在可融合的算子 op_type_list = preparse_df["Type"].tolist() duration_list = preparse_df["Duration(us)"].tolist() + start_times = preparse_df["Start Time(us)"].tolist() + # 去除末尾的\t分隔符 + start_times = [start_time[:-1] for start_time in start_times] result_list = [] for pattern in Constant.PATTERN_DICT.keys(): - result_list.extend(self.find_all_sub_lists(op_type_list, duration_list, pattern)) + result_list.extend(self.find_all_sub_lists(op_type_list, duration_list, start_times, pattern)) data_frame = pd.DataFrame(result_list) data_frame.columns = ["pattern_name", "pattern", "len", "count", "duration sum(us)", "op durations(us)", - "index"] + "index", "first_timestamp"] return data_frame @staticmethod @@ -49,7 +55,7 @@ class Analyser: return OpPerfFactory.build(row[1]).update() @staticmethod - def find_all_sub_lists(op_type_list, duration_list, expect_sub_list): + def find_all_sub_lists(op_type_list, duration_list, start_times, expect_sub_list): # 创建一个空字典,用来存储子列表和它们的出现次数和起始位置 len_sub_list = len(expect_sub_list) expect_sub_list = tuple(expect_sub_list) @@ -61,20 +67,25 @@ class Analyser: continue # 如果子列表已经在字典中,就增加它的出现次数,否则就初始化为1 if sublist in sublist_dict: + # count sublist_dict[sublist][0] += 1 + # index sublist_dict[sublist][1].append(i) + # total duration sublist_dict[sublist][2] += sum(duration_list[i:i + len_sub_list]) + # duration zip_data = zip(sublist_dict[sublist][3], duration_list[i:i + len_sub_list]) sublist_dict[sublist][3] = [a + b for a, b in zip_data] else: sublist_dict[sublist] = [1, [i], sum(duration_list[i:i + len_sub_list]), - duration_list[i:i + len_sub_list], len_sub_list] + duration_list[i:i + len_sub_list], len_sub_list, start_times[i]] # 创建一个空列表,用来存储所有重复的子列表 repeated_sublists = [] - for sublist, (count, index, duration_sum, op_durations, sublist_len) in sublist_dict.items(): + for sublist, (count, index, duration_sum, op_durations, sublist_len, first_time) in sublist_dict.items(): pattern_name = Constant.PATTERN_DICT.get(sublist, "unknown") op_durations = [round(num, 2) for num in op_durations] - 
diff --git a/profiler/advisor/advisor_backend/compute_advice/npu_fused/json_analyzer.py b/profiler/advisor/advisor_backend/compute_advice/npu_fused/json_analyzer.py
new file mode 100644
index 0000000000..fd2a72ffa3
--- /dev/null
+++ b/profiler/advisor/advisor_backend/compute_advice/npu_fused/json_analyzer.py
@@ -0,0 +1,55 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import pandas as pd
+
+from common_func_advisor.trace_view_json import TraceViewJson
+
+
+class JSONAnalyzer(object):
+    def __init__(self, path):
+        self._path = path
+
+    def get_custom_code(self, data: pd.DataFrame, ts_col: str, output_col: str):
+        trace_json = TraceViewJson(self._path)
+        callstacks = pd.DataFrame(columns=[output_col])
+
+        for i, row in data.iterrows():
+            if ts_col not in data.columns.tolist():
+                print("[ERROR] No {} col found in data columns.".format(ts_col))
+                return callstacks
+            timestamp = row[ts_col]
+            flow_event = trace_json.get_torch_2_npu_flow_event(timestamp)
+            if not flow_event.valid():
+                print("[ERROR] Get flow event failed for pattern {}.".format(row['pattern']))
+                callstacks.loc[i] = ""
+                continue
+            flow_event_s_key = flow_event.s_point_ts
+            python_dur_events = trace_json.get_python_dur_events_contain_ts(flow_event_s_key)
+            if not python_dur_events:
+                print("[ERROR] No python dur event found for pattern {}.".format(row['pattern']))
+                callstacks.loc[i] = ""
+                continue
+            # Keep compatibility between the old and new call stack formats
+            if python_dur_events[0].args.get("Call stack"):
+                # Old format: the whole stack lives in the event args
+                callstack = python_dur_events[0].args.get("Call stack").split(";")
+            else:
+                python_dur_events.sort(key=lambda e: e.ts)
+                # New format: one python_function duration event per frame
+                callstack = [event.name for event in python_dur_events if event.cat == "python_function"]
+            callstack_str = "\n".join(callstack)
+            callstacks.loc[i] = callstack_str
+        return callstacks
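The two trace_view.json call-stack encodings handled above, taken from the unit tests at
the end of this patch: the old format carries the whole stack in the event args; the new
format emits one python_function duration event per frame:

    # old format: one event, frames joined by ';' in args
    old_event_args = {"Call stack": "/root/test/slice.py(116);\r\n/root/torch/module.py"}

    # new format: one duration event per frame, ordered by ts
    new_events = [
        {"ph": "X", "cat": "python_function", "name": "/root/test/slice.py(116)", "ts": "198", "dur": 120},
        {"ph": "X", "cat": "python_function", "name": "/root/torch/module.py", "ts": "197", "dur": 150},
    ]
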
diff --git a/profiler/advisor/advisor_backend/compute_advice/npu_fused_advice.py b/profiler/advisor/advisor_backend/compute_advice/npu_fused_advice.py
index 2aca8e47d7..2e40e77ce4 100644
--- a/profiler/advisor/advisor_backend/compute_advice/npu_fused_advice.py
+++ b/profiler/advisor/advisor_backend/compute_advice/npu_fused_advice.py
@@ -14,22 +14,27 @@
 # limitations under the License.
 
 import os
+import pandas as pd
 
 from compute_advice.compute_advice_base import ComputeAdviceBase
-from compute_advice.npu_fused.analyser import Analyser
-from common_func_advisor.constant import Constant
+from compute_advice.npu_fused.csv_analyzer import CSVAnalyzer
+from compute_advice.npu_fused.json_analyzer import JSONAnalyzer
 
 
 class NpuFusedAdvice(ComputeAdviceBase):
+    ASCEND_PT = 'ascend_pt'
+    ASCEND_PROFILER_OUTPUT = 'ASCEND_PROFILER_OUTPUT'
+    KERNEL_DETAIL_FILE = "kernel_details.csv"
+    TRACE_VIEW_FILE = "trace_view.json"
+
     def __init__(self, collection_path: str):
         super().__init__(collection_path)
         self.cur_data = dict()
         self.cur_bottleneck = str()
         self.cur_advice = str()
-
-        if collection_path.endswith(Constant.PT_PROF_SUFFIX):
-            self.collection_path = os.path.join(collection_path,
-                                                Constant.ASCEND_PROFILER_OUTPUT, Constant.KERNEL_DETAILS_CSV)
+        self.kernel_details_path = ""
+        self.trace_view_path = ""
+        self.call_stack = None
 
     def run(self):
         if not self.path_check():
@@ -39,20 +44,29 @@ class NpuFusedAdvice(ComputeAdviceBase):
         return self.output_format_data
 
     def process(self):
-        analyser = Analyser(self.collection_path)
-        self.cur_data = analyser.process()
-        self.cur_data = self.cur_data.sort_values(by='duration sum(us)', ascending=False)
-        filter_data = self.cur_data.get(self.cur_data.get("duration sum(us)", 0) > 0)
-        op_num = len(filter_data.index)
+        csv_analyzer = CSVAnalyzer(self.kernel_details_path)
+        all_pattern_data = csv_analyzer.process()
+        all_pattern_data = all_pattern_data.sort_values(by='duration sum(us)', ascending=False)
+        filter_data = all_pattern_data.get(all_pattern_data.get("duration sum(us)", 0) > 0)
+        if not self.has_callstack():
+            print("[Warning] No call stack info found, advice will be incomplete")
+            self.cur_data = filter_data
+        else:
+            json_analyzer = JSONAnalyzer(self.trace_view_path)
+            custom_code = json_analyzer.get_custom_code(filter_data, "first_timestamp", "custom code")
+            self.cur_data = pd.concat([filter_data, custom_code], axis=1)
+        op_num = len(self.cur_data.index)
         op_dur = filter_data["duration sum(us)"].sum()
         if op_num > 0:
             index = 0
-            self.cur_advice = "Advice:\n"
+            self.cur_advice = f"Advice {index}:\n"
             self.cur_bottleneck = f"The computing time of fusable op is {round(op_dur, 2)} ms."
-            for _, row in filter_data.iterrows():
+            for _, row in self.cur_data.iterrows():
                 cur_op = "[" + ", ".join(row.loc["pattern"]) + "]"
                 npu_fused_op = row.loc["pattern_name"]
-                self.cur_advice += f"Replace {cur_op} with {npu_fused_op}."
+                self.cur_advice += f"Replace {cur_op} with {npu_fused_op}. "
+                if self.call_stack:
+                    self.cur_advice += f"This pattern first happened in: \n{row['custom code']}"
                 if index != op_num - 1:
                     self.cur_advice += "\n"
                 index += 1
diff --git a/profiler/advisor/compute_perf_analysis.ipynb b/profiler/advisor/compute_perf_analysis.ipynb
index 795632f84a..7037f078f5 100644
--- a/profiler/advisor/compute_perf_analysis.ipynb
+++ b/profiler/advisor/compute_perf_analysis.ipynb
@@ -2,10 +2,17 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
-   "metadata": {},
+   "execution_count": 2,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-02-21T09:19:13.937531900Z",
+     "start_time": "2024-02-21T09:19:13.267899500Z"
+    }
+   },
    "outputs": [],
    "source": [
+    "import pandas as pd\n",
+    "\n",
     "from advisor_backend.interface import Interface\n",
     "import numpy as np"
    ]
@@ -16,64 +23,36 @@
    "source": [
    "# Operator tuning analysis\n",
    "## 1. Data preparation for operator analysis\n",
-    "The operator analysis tool currently supports the ascend_pt directory generated by the Ascend PyTorch Profiler, as well as the ascend_pt/ASCEND_PROFILER_OUTPUT/kernel_details.csv file\n",
+    "The operator analysis tool currently supports the ascend_pt directory generated by the Ascend PyTorch Profiler\n",
    "## 2. Problems the operator analysis solves\n",
-    "It currently detects fusable small operators in the model and gives optimization advice."
+    "It currently detects fusable small operators in the model and gives optimization advice.\n",
+    "\n",
+    "For more information on fused operators, see https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/700alpha003/processormodel/hardwaredesc_0001.html"
   ]
  },
 {
  "cell_type": "code",
-   "execution_count": 2,
-   "metadata": {},
+   "execution_count": 4,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-02-21T09:21:08.118763500Z",
+     "start_time": "2024-02-21T09:21:07.392583800Z"
+    }
+   },
  "outputs": [
   {
    "name": "stdout",
    "output_type": "stream",
    "text": [
-     "[INFO] Start to analyse the target file: [YOUR PATH]\n",
-     " pattern_name pattern len \\\n",
-     "0 bias_dropout_add (Add, DropOutDoMask, Add) 3 \n",
-     "7 AddLayerNorm (Add, LayerNormV3) 2 \n",
-     "10 FA (BatchMatMul, RealDiv, Add, Maximum, SoftmaxV2... 0 \n",
-     "16 FFN (MatMulV2, Swish, MatMulV2, Mul, MatMulV2) 0 \n",
-     "15 RotaryMul (Mul, Slice, Neg, Slice, ConcatD, Mul, Add) 0 \n",
-     "14 RotaryMul (Mul, AsStrided, Neg, AsStrided, ConcatD, Mul,... 0 \n",
-     "13 RotaryMul (Mul, Slice, Neg, Slice, ConcatD, Cast, Mul, Add) 0 \n",
-     "12 FA (BatchMatMulV2, RealDiv, Add, Cast, SoftmaxV2,... 0 \n",
-     "11 FA (BatchMatMulV2, RealDiv, Add, Cast, Maximum, C... 0 \n",
-     "9 RMSNorm (Cast, Square, MemSet, ReduceMean, Add, Rsqrt,... 0 \n",
-     "1 FA (BatchMatMul, Mul, Cast, Mul, MaskedFill, Soft... 0 \n",
-     "8 GeluAdd (Gelu, Add) 0 \n",
-     "6 AddLayerNorm (Add, LayerNorm) 0 \n",
-     "5 LayerNorm (Cast, LayerNorm, Cast) 0 \n",
-     "4 RMSNORM (Cast, Square, ReduceMeanD, Add, Rsqrt, Cast, ... 0 \n",
-     "3 FA (Transpose, BatchMatMulV2, Transpose, Transpos... 0 \n",
-     "2 FA (Transpose, Transpose, Transpose, Mul, Transpo... 0 \n",
-     "17 GatherElement (Transpose, Transpose, GatherElement, Transpose) 0 \n",
+     "[INFO] Start to analyse the target file: C:\\data\\ascend_pt\\ASCEND_PROFILER_OUTPUT\\kernel_details.csv\n",
+     " pattern_name pattern len count duration sum(us) op durations(us) index\n",
+     "18 torch_npu.npu_swiglu (Slice, Slice, Swish, Mul) 4 1 12.56 [3.14, 3.14, 3.14, 3.14] [0]\n",
      "\n",
-     " count duration sum(us) op durations(us) index \n",
-     "0 4 2178.16 [839.64, 464.04, 874.48] [52, 64, 87, 99] \n",
-     "7 4 2154.98 [874.48, 1280.5] [54, 66, 89, 101] \n",
-     "10 0 0.00 0 0 \n",
-     "16 0 0.00 0 0 \n",
-     "15 0 0.00 0 0 \n",
-     "14 0 0.00 0 0 \n",
-     "13 0 0.00 0 0 \n",
-     "12 0 0.00 0 0 \n",
-     "11 0 0.00 0 0 \n",
-     "9 0 0.00 0 0 \n",
-     "1 0 0.00 0 0 \n",
-     "8 0 0.00 0 0 \n",
-     "6 0 0.00 0 0 \n",
-     "5 0 0.00 0 0 \n",
-     "4 0 0.00 0 0 \n",
-     "3 0 0.00 0 0 \n",
-     "2 0 0.00 0 0 \n",
-     "17 0 0.00 0 0 \n",
-     "The computing time of fusable op is 4333.14 ms.\n",
-     "Advice:\n",
-     "Replace [Add, DropOutDoMask, Add] with bias_dropout_add.\n",
-     "Replace [Add, LayerNormV3] with AddLayerNorm.\n"
+     "\n",
+     "Advice 0:\n",
+     "Replace [Slice, Slice, Swish, Mul] with torch_npu.npu_swiglu. This pattern first happened in: \n",
+     "torch/nn/modules/module.py(1513): _call_impl\n",
+     "profiler_main.py(116):forward\n"
    ]
   }
  ],
  "source": [
   "compute_path = \"[YOUR PATH]\"\n",
   "interface = Interface(compute_path)\n",
   "data = interface.get_data('compute', 'npu_fused')\n",
-   "\n",
-   "print(data['data'])\n",
-   "print(data['bottleneck'])\n",
+   "pd.set_option('display.max_columns', None)\n",
+   "pd.set_option('display.width', 900)\n",
+   "print(data['data'].iloc[:, :-2])\n",
+   "print('\\n')\n",
   "print(data['advice'])"
  ]
+ },
+ {
+  "cell_type": "code",
+  "outputs": [],
+  "source": [
+   "\n",
+   "\n"
+  ],
+  "metadata": {
+   "collapsed": false
+  }
 }
 ],
 "metadata": {
  "kernelspec": {
-   "display_name": "qkd",
+   "name": "python3",
   "language": "python",
-   "name": "qkd"
+   "display_name": "Python 3 (ipykernel)"
  },
  "language_info": {
   "codemirror_mode": {
diff --git a/profiler/test/ut/advisor/advisor_backend/compute_advice/test_npufused_advice.py b/profiler/test/ut/advisor/advisor_backend/compute_advice/test_npufused_advice.py
index bfefbdc7b1..55fdb9836e 100644
--- a/profiler/test/ut/advisor/advisor_backend/compute_advice/test_npufused_advice.py
+++ b/profiler/test/ut/advisor/advisor_backend/compute_advice/test_npufused_advice.py
@@ -1,3 +1,4 @@
+import json
 import os
 import shutil
 import stat
@@ -10,62 +11,181 @@ from advisor_backend.interface import Interface
 
 class TestComputeAdvice(unittest.TestCase):
     TMP_DIR = "./ascend_pt"
+    OUTPUT_DIR = "./ascend_pt/ASCEND_PROFILER_OUTPUT"
     interface = None
     err_interface = None
 
-    @classmethod
-    def tearDownClass(cls) -> None:
-        super().tearDownClass()
+    def tearDown(self):
         if os.path.exists(TestComputeAdvice.TMP_DIR):
             shutil.rmtree(TestComputeAdvice.TMP_DIR)
 
-    @classmethod
-    def setUpClass(cls) -> None:
-        super().setUpClass()
+    def setUp(self):
+        if os.path.exists(TestComputeAdvice.TMP_DIR):
+            shutil.rmtree(TestComputeAdvice.TMP_DIR)
         if not os.path.exists(TestComputeAdvice.TMP_DIR):
             os.makedirs(TestComputeAdvice.TMP_DIR)
+        if not os.path.exists(TestComputeAdvice.OUTPUT_DIR):
+            os.makedirs(TestComputeAdvice.OUTPUT_DIR)
+
+    @classmethod
+    def get_basic_trace_view(cls):
+        # Python pid
+        py_pid_data = {"ph": "M", "name": "process_name", "tid": 0, "pid": 1, "args": {"name": "Python"}}
+        # Ascend Hardware pid
+        ascend_pid_data = {"ph": "M", "name": "process_name", "tid": 0, "pid": 4, "args": {"name": "Ascend Hardware"}}
+        # CANN pid
+        cann_pid_data = {"ph": "M", "name": "process_name", "tid": 0, "pid": 5, "args": {"name": "CANN"}}
+        # Ascend Hardware ops
+        ah_event1 = {"ph": "X", "name": "Slice1", "ts": "1699529623106750", "dur": 100, "tid": 3, "pid": 4, "args": {}}
+        ah_event2 = {"ph": "X", "name": "Slice2", "ts": "1699529623106751", "dur": 80, "tid": 3, "pid": 4, "args": {}}
+        # flow events
+        flow_event_s = {"ph": "s", "name": "link1", "id": 1, "tid": 3, "pid": 1, "ts": "200", "args": {}}
+        flow_event_e = {"ph": "f", "name": "link1", "id": 1, "tid": 3, "pid": 1, "ts": "1699529623106750", "args": {}}
+        return [py_pid_data, ascend_pid_data, cann_pid_data, ah_event1, ah_event2, flow_event_s, flow_event_e]
+
+    @classmethod
+    def create_profiler_info_json(cls):
+        info = {
+            "config": {
+                "common_config": {
+                    "with_stack": True,
+                    "activities": ["ProfilerActivity.CPU", "ProfilerActivity.NPU"]
+                }
+            }
+        }
+        with os.fdopen(os.open(f"{TestComputeAdvice.TMP_DIR}/profiler_info.json",
+                               os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp:
+            fp.write(json.dumps(info))
+
+    @classmethod
+    def create_old_version_trace_view(cls):
+        basic_info = cls.get_basic_trace_view()
+
+        # python ops
+        py_event1 = {"ph": "X", "cat": "python_function", "name": "aten::slice", "ts": "200", "dur": 100, "tid": 2,
+                     "pid": 1, "args": {"Call stack": "/root/test/slice.py(116);\r\n/root/torch/module.py"}}
+        py_event2 = {"ph": "X", "cat": "python_function", "name": "slice", "ts": "199", "dur": 200, "tid": 2, "pid": 1,
+                     "args": {"Call stack": "/root/test/slice.py(116);\r\n/root/torch/module.py"}}
+        raw_data = [
+            *basic_info, py_event1, py_event2
+        ]
+
+        with os.fdopen(os.open(f"{TestComputeAdvice.OUTPUT_DIR}/trace_view.json",
+                               os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp:
+            fp.write(json.dumps(raw_data))
+
+    @classmethod
+    def create_new_version_trace_view(cls):
+        basic_info = cls.get_basic_trace_view()
+        # python ops
+        py_event1 = {"ph": "X", "name": "aten::slice", "ts": "200", "dur": 100, "tid": 2, "pid": 1, "args": {}}
+        py_event2 = {"ph": "X", "name": "slice", "ts": "199", "dur": 105, "tid": 2, "pid": 1, "args": {}}
+        py_event3 = {"ph": "X", "cat": "python_function", "name": "/root/test/slice.py(116)", "ts": "198", "dur": 120,
+                     "tid": 2, "pid": 1, "args": {}}
+        py_event4 = {"ph": "X", "cat": "python_function", "name": "/root/torch/module.py", "ts": "197", "dur": 150,
+                     "tid": 2, "pid": 1, "args": {}}
+
+        raw_data = [
+            *basic_info, py_event1, py_event2, py_event3, py_event4
+        ]
+
+        with os.fdopen(os.open(f"{TestComputeAdvice.OUTPUT_DIR}/trace_view.json",
+                               os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp:
+            fp.write(json.dumps(raw_data))
+
+    @classmethod
+    def create_kernel_details(cls):
         # create csv files
-        csv_header = ['Step Id', 'Model ID', 'Task ID', 'Stream ID', 'Name', 'Type', 'Accelerator Core', 'Start Time(us)',
-                      'Duration(us)', 'Wait Time(us)', 'Block Dim', 'Mix Block Dim', 'Input Shapes', 'Input Data Types',
-                      'Input Formats', 'Output Shapes', 'Output Data Types', 'Output Formats', 'Context ID', 'aicore_time(us)',
-                      'aic_total_cycles', 'aic_mac_fp16_ratio', 'aic_mac_int8_ratio', 'aic_cube_fops', 'aic_vector_fops',
-                      'aiv_time(us)', 'aiv_total_cycles', 'aiv_vec_fp32_ratio', 'aiv_vec_fp16_ratio', 'aiv_vec_int32_ratio',
+        csv_header = ['Step Id', 'Model ID', 'Task ID', 'Stream ID', 'Name', 'Type', 'Accelerator Core',
+                      'Start Time(us)', 'Duration(us)', 'Wait Time(us)', 'Block Dim', 'Mix Block Dim',
+                      'Input Shapes', 'Input Data Types', 'Input Formats', 'Output Shapes', 'Output Data Types',
+                      'Output Formats', 'Context ID', 'aicore_time(us)', 'aic_total_cycles', 'aic_mac_fp16_ratio',
+                      'aic_mac_int8_ratio', 'aic_cube_fops', 'aic_vector_fops', 'aiv_time(us)', 'aiv_total_cycles',
+                      'aiv_vec_fp32_ratio', 'aiv_vec_fp16_ratio', 'aiv_vec_int32_ratio',
                       'aiv_vec_misc_ratio', 'aiv_cube_fops', 'aiv_vector_fops']
-        csv_row1 = [1, 4294967295, 1265, 16, 'Cast66', 'Cast', 'AI_VECTOR_CORE', 1699529623106750, 3.14, 261.56, 9, 0, '4,1025',
-                    'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A', 0, 0, 0, 0, 0, 0, 1.77, 29508, 0, 0, 0.0062,
-                    0, 0, 5856]
+        csv_row1 = [1, 4294967295, 1265, 16, 'Slice1', 'Slice', 'AI_VECTOR_CORE', "1699529623106750\t", 3.14,
+                    261.56, 9, 0, '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A',
+                    0, 0, 0, 0, 0, 0, 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856]
+        csv_row2 = [1, 4294967295, 1265, 16, 'Slice2', 'Slice', 'AI_VECTOR_CORE', "1699529623106751\t", 3.14,
+                    261.56, 9, 0, '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A',
+                    0, 0, 0, 0, 0, 0, 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856]
+        csv_row3 = [1, 4294967295, 1265, 16, 'Swish1', 'Swish', 'AI_VECTOR_CORE', "1699529623106752\t", 3.14,
+                    261.56, 9, 0, '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A',
+                    0, 0, 0, 0, 0, 0, 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856]
+        csv_row4 = [1, 4294967295, 1265, 16, 'Mul1', 'Mul', 'AI_VECTOR_CORE', "1699529623106753\t", 3.14,
+                    261.56, 9, 0, '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A',
+                    0, 0, 0, 0, 0, 0, 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856]
+        csv_row5 = [1, 4294967295, 1265, 16, 'Add1', 'Add', 'AI_VECTOR_CORE', "1699529623106754\t", 3.14,
+                    261.56, 9, 0, '4,1025', 'INT64', 'FORMAT_ND', '4,1025', 'INT32', 'FORMAT_ND', 'N/A',
+                    0, 0, 0, 0, 0, 0, 1.77, 29508, 0, 0, 0.0062, 0, 0, 5856]
-        with os.fdopen(os.open(f"{TestComputeAdvice.TMP_DIR}/err_file.csv",
+        with os.fdopen(os.open(f"{TestComputeAdvice.OUTPUT_DIR}/kernel_details.csv",
                                os.O_WRONLY | os.O_CREAT, stat.S_IWUSR | stat.S_IRUSR), 'w') as fp:
             csv_writer = csv.writer(fp)
             csv_writer.writerow(csv_header)
             csv_writer.writerow(csv_row1)
+            csv_writer.writerow(csv_row2)
+            csv_writer.writerow(csv_row3)
+            csv_writer.writerow(csv_row4)
+            csv_writer.writerow(csv_row5)
+
+    def test_run_should_return_empty_when_ascend_pt_path_not_exist(self):
+        interface = Interface("")
+        data = interface.get_data('compute', 'npu_fused')
+        self.assertEqual(0, len(data['advice']))
+        self.assertEqual(0, len(data['data']))
+
+    def test_run_should_return_empty_when_csv_and_json_not_exist(self):
+        interface = Interface(self.TMP_DIR)
+        data = interface.get_data('compute', 'npu_fused')
+        self.assertEqual(0, len(data['advice']))
+        self.assertEqual(0, len(data['data']))
+
+    def test_run_should_return_advice_without_callstack_when_json_not_exist(self):
+        self.create_kernel_details()
+        interface = Interface(self.TMP_DIR)
+        data = interface.get_data('compute', 'npu_fused')
+        res = ".py" in data['advice']
+        self.assertFalse(res)
+        self.assertEqual(1, len(data['data']))
+
+    def test_run_should_return_1_fused_op_and_its_callstack_when_json_is_in_old_version(self):
+        self.create_profiler_info_json()
+        self.create_kernel_details()
+        self.create_old_version_trace_view()
+        interface = Interface(self.TMP_DIR)
+        data = interface.get_data('compute', 'npu_fused')
+        res = ".py" in data['advice']
+        self.assertTrue(res)
+        self.assertEqual(1, len(data['data']))
+        self.assertEqual(1, len(data['data']['custom code']))
+        self.assertEqual(2, data['data']['custom code'].iloc[0].count("py"))
 
-        TestComputeAdvice.err_interface = Interface(os.path.join(TestComputeAdvice.TMP_DIR, "err_file.csv"))
-        TestComputeAdvice.interface = Interface(os.path.join(os.path.dirname(os.path.abspath(__file__)),
-                                                             "kernel_details.csv"))
-
-
-    def test_run(self):
-        dataset = TestComputeAdvice.err_interface.get_data('compute', 'npu_fused')
-        case_advice = dataset.get('advice')
-        case_bottleneck = dataset.get('bottleneck')
-        case_data = dataset.get('data')
-        self.assertEqual(0, len(case_advice))
-        self.assertEqual(0, len(case_bottleneck))
-        self.assertEqual(0, len(case_data))
-
-        dataset = TestComputeAdvice.interface.get_data('compute', 'npu_fused')
-        case_advice = dataset.get('advice')
-        case_bottleneck = dataset.get('bottleneck')
-        self.assertEqual(110, len(case_advice))
-        self.assertEqual(47, len(case_bottleneck))
-        case_data = dataset.get('data')
-
-        entry_data = case_data.iloc[0]
-        self.assertEqual('bias_dropout_add', entry_data.loc['pattern_name'])
-        self.assertEqual(3, entry_data.loc['len'])
-        self.assertEqual(4, entry_data.loc['count'])
-
-        entry_data = case_data.iloc[1]
-        self.assertEqual('AddLayerNorm', entry_data.loc['pattern_name'])
-        self.assertEqual(2, entry_data.loc['len'])
-        self.assertEqual(4, entry_data.loc['count'])
+    def test_run_should_return_1_fused_op_and_its_callstack_when_json_is_in_new_version(self):
+        self.create_profiler_info_json()
+        self.create_kernel_details()
+        self.create_new_version_trace_view()
+        interface = Interface(self.TMP_DIR)
+        data = interface.get_data('compute', 'npu_fused')
+        res = ".py" in data['advice']
+        self.assertTrue(res)
+        self.assertEqual(1, len(data['data']))
+        self.assertEqual(1, len(data['data']['custom code']))
+        self.assertEqual(2, data['data']['custom code'].iloc[0].count("py"))

-- 
Gitee
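End to end, the new path is driven the same way as before; a sketch mirroring the notebook
and the tests above (the ascend_pt path is a placeholder):

    from advisor_backend.interface import Interface

    interface = Interface("path/to/ascend_pt")
    data = interface.get_data('compute', 'npu_fused')
    print(data['bottleneck'])  # total time spent in fusable patterns
    print(data['advice'])      # replacement advice, plus call stacks when with_stack was enabled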