diff --git a/profiler/advisor/analyzer/computation/ai_core_performance/__init__.py b/profiler/advisor/analyzer/computation/ai_core_performance/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/advisor/analyzer/computation/ai_core_performance/ai_core_performance_analyzer.py b/profiler/advisor/analyzer/computation/ai_core_performance/ai_core_performance_analyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b6be77957d4fb00f1b47dfb0b683c37454bd2d
--- /dev/null
+++ b/profiler/advisor/analyzer/computation/ai_core_performance/ai_core_performance_analyzer.py
@@ -0,0 +1,54 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+
+from profiler.advisor.analyzer.base_analyzer import BaseAnalyzer
+from profiler.advisor.analyzer.computation.ai_core_performance.ai_core_performance_checker import \
+ AICorePerformanceChecker
+from profiler.advisor.dataset.profiling.profiling_dataset import ProfilingDataset
+from profiler.advisor.result.result import OptimizeResult
+from profiler.advisor.display.html.priority_background_color import PriorityBackgroundColor
+from profiler.advisor.display.html.render import HTMLRender
+
+logger = logging.getLogger()
+
+
+class AICorePerformanceAnalyzer(BaseAnalyzer):
+    """
+    Analyzer that runs AICorePerformanceChecker on the first profiling dataset
+    and exposes the outcome as an OptimizeResult plus rendered html.
+    """
+    dataset_cls_list = [ProfilingDataset]
+
+    def __init__(self, collection_path, n_processes: int = 1, **kwargs) -> None:
+        """
+        Initialise the base analyzer, then cache the profiling dataset and the output holders.
+        """
+        super().__init__(collection_path, n_processes, **kwargs)
+        dataset_key = ProfilingDataset.get_key()
+        self.profiling_dataset = self.get_first_data_by_key(self.dataset_list, dataset_key)
+        self.result = OptimizeResult()
+        self.html_render = HTMLRender()
+        # stays None until optimize() renders something
+        self.html = None
+
+    def optimize(self, **kwargs):
+        """
+        Filter the dataset, and only when the filter flags candidates run the checks,
+        record them into ``self.result`` and render the html.
+
+        Reads ``add_render_list`` (default True) and ``rank`` from kwargs.
+        :return: ``self.result`` (untouched when the filter finds nothing)
+        """
+        add_render_list = kwargs.get("add_render_list", True)
+        checker = AICorePerformanceChecker()
+        checker.data_filter(self.profiling_dataset)
+        if checker.ai_core_performance_issues:
+            checker.check_ai_core_performance(self.profiling_dataset)
+            checker.make_record(self.result)
+            self.html = checker.make_render(
+                self.html_render,
+                add_render_list,
+                priority=self.get_priority(),
+                rank=kwargs.get("rank"),
+            )
+        return self.result
+
+    def get_priority(self, max_mem_op_dur=None):
+        """
+        Return the priority colour for this analyzer; always low, ``max_mem_op_dur`` is not read.
+        """
+        return PriorityBackgroundColor.low
\ No newline at end of file
diff --git a/profiler/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py b/profiler/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3622ebdf0f64aaa19f3891c692a92e9b1343681
--- /dev/null
+++ b/profiler/advisor/analyzer/computation/ai_core_performance/ai_core_performance_checker.py
@@ -0,0 +1,542 @@
+# Copyright (c) 2024, Huawei Technologies Co., Ltd.
+# All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os
+from functools import reduce
+
+from profiler.advisor.dataset.profiling.profiling_dataset import ProfilingDataset
+from profiler.advisor.result.item import OptimizeItem, OptimizeRecord
+from profiler.advisor.result.result import OptimizeResult
+from profiler.prof_common.additional_args_manager import AdditionalArgsManager
+from profiler.prof_common.file_manager import FileManager
+
+logger = logging.getLogger()
+
+
+class AICorePerformanceChecker:
+ """
+ operator performance checker
+ """
+ _CHECKER = "AICorePerformanceChecker"
+ CUBE_OPERATOR_MEMORY_SIZE_MB = 100
+
+    def __init__(self):
+        """
+        Create the empty result and filter containers, then load the rule texts.
+        """
+        self.ai_core_performance_issues = False
+        self.result = {}
+        self.desc = ""
+        # op_name -> set of shape keys selected by data_filter, one dict per operator family
+        self.cube_dict, self.fa_dict, self.vector_dict = {}, {}, {}
+        self.fa_list = []
+        # runs last: it overwrites ``desc`` with the text from the rule file
+        self.load_aicore_perf_rules()
+
+    def load_aicore_perf_rules(self):
+        """
+        Load the localized texts from ``rules/<language>/aicore_performance.yaml``.
+
+        The rules directory is resolved four directory levels above this file. When the
+        rule file does not exist a warning is logged and an empty rule set is used, so
+        every text attribute below is still defined (as None) instead of reading a
+        path that is known to be missing.
+        """
+        language = AdditionalArgsManager().language
+        rule_path = os.path.join(
+            os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
+            "rules",
+            language,
+            "aicore_performance.yaml"
+        )
+
+        self.language = language
+        if os.path.exists(rule_path):
+            # ``or {}``: keeps ``.get`` usable should the reader hand back None for an empty
+            # file -- NOTE(review): FileManager behaviour is not visible here, confirm
+            self.aicore_rules = FileManager.read_yaml_file(rule_path) or {}
+        else:
+            logger.warning("Skip analyze ai core performance issues, because %s does not exist.", rule_path)
+            self.aicore_rules = {}
+
+        rules = self.aicore_rules
+        self._CUBE_PROBLEM = rules.get("cube_problem")
+        self._FA_PROBLEM = rules.get("fa_problem")
+        self._VECTOR_PROBLEM = rules.get("vector_problem")
+        self.desc = rules.get("description")
+        self._BOUND_DESC = rules.get("bound_description")
+        self._OPTI_DESC = rules.get("optimization_description")
+        self._AFFINITY_DESC = rules.get("affinity_description")
+        self._CUBE_AFFINITY_DESC = rules.get("cube_affinity_desc")
+        self._FA_AFFINITY_DESC_TYPE1 = rules.get("fa_affinity_desc_type1")
+        self._FA_AFFINITY_DESC_TYPE2 = rules.get("fa_affinity_desc_type2")
+        self._FA_AFFINITY_DESC_TYPE3 = rules.get("fa_affinity_desc_type3")
+        self.suggestion = rules.get("suggestion")
+        self._AFFINITY_SUGGESTION = rules.get("affinity_suggestion")
+        self._BOUND_SUGGESTION = rules.get("bound_suggestion")
+        self._OPTI_SUGGESTION = rules.get("optimization_suggestion")
+
+    def data_filter(self, profiling_dataset: ProfilingDataset):
+        """
+        Pre-select the cube, FA and vector operators worth checking.
+
+        Fills ``cube_dict``, ``fa_dict``/``fa_list`` and ``vector_dict`` (op_name -> set of
+        "input_shapes-output_shapes" keys) and sets ``ai_core_performance_issues`` when any
+        of them is non-empty. Returns without touching anything when ``check_task_list``
+        rejects the dataset.
+
+        :Param profiling_dataset: dataset whose ``op_summary.op_list`` holds the operators
+        """
+        if not self.check_task_list(profiling_dataset):
+            return
+
+        operator_list = profiling_dataset.op_summary.op_list
+        total_duration = sum(float(operator.task_duration) for operator in operator_list)
+        cube_memory_dict = {}
+        vector_type_dict = {}
+
+        for op in operator_list:
+            # the shape strings are stored with one wrapping character on each side
+            shapes = op.input_shapes[1:-1] + "-" + op.output_shapes[1:-1]
+            # preliminary filter cube operator
+            if op.task_type == "AI_CORE" and "matmul" in op.op_type.lower():
+                memory_by_shape = cube_memory_dict.setdefault(op.op_name, {})
+                memory_by_shape[shapes] = memory_by_shape.get(shapes, 0) + self.memory_size(op)
+                continue
+
+            # preliminary filter vector operator
+            if op.task_type in ("AI_VECTOR_CORE", "MIX_AIV"):
+                vector_type_dict.setdefault(op.op_type, set()).add(op)
+                continue
+
+            # filter fa operator; grad keys get a "-grad" suffix to keep them apart
+            if op.op_type == "FlashAttentionScore":
+                self.fa_dict.setdefault(op.op_name, set()).add(shapes)
+                self.fa_list.append(op)
+            elif op.op_type == "FlashAttentionScoreGrad":
+                self.fa_dict.setdefault(op.op_name, set()).add(shapes + "-grad")
+                self.fa_list.append(op)
+
+        # filter cube operator: keep shapes whose accumulated memory reaches the threshold
+        for op_name, memory_by_shape in cube_memory_dict.items():
+            for shapes, memory in memory_by_shape.items():
+                if memory >= self.CUBE_OPERATOR_MEMORY_SIZE_MB:
+                    self.cube_dict.setdefault(op_name, set()).add(shapes)
+
+        # filter vector operator: keep op types taking >= 1% of the total duration
+        # or whose summed duration reaches 1000000
+        for same_type_ops in vector_type_dict.values():
+            type_duration = sum(float(op.task_duration) for op in same_type_ops)
+            # total_duration is 0 when every task reports a zero duration; a plain
+            # division would raise ZeroDivisionError there
+            duration_ratio = type_duration / total_duration if total_duration else 0.0
+            if duration_ratio >= 0.01 or type_duration >= 1000000:
+                for op in same_type_ops:
+                    shapes = op.input_shapes[1:-1] + "-" + op.output_shapes[1:-1]
+                    self.vector_dict.setdefault(op.op_name, set()).add(shapes)
+
+        if any([self.cube_dict, self.fa_dict, self.vector_dict]):
+            self.ai_core_performance_issues = True
+
+    @staticmethod
+    def memory_size(operator):
+        """
+        Estimate the data volume of a cube operator from its shape strings.
+
+        Element counts of all input and output tensors are summed, multiplied by 2 and
+        divided by 1024 * 1024 -- presumably 2 bytes per element expressed in MB, TODO confirm.
+        Empty tensor slots (for example a trailing ';') are skipped instead of raising
+        ValueError, and every ';'-separated output tensor is counted.
+
+        :Param operator: object with ``input_shapes`` and ``output_shapes`` strings
+        :return: the estimated size as a float
+        """
+        def element_count(shape):
+            return reduce(lambda x, y: x * y, map(int, shape.split(",")))
+
+        memory = 0
+        for shape in operator.input_shapes[1:-1].split(";"):
+            if not shape:
+                # empty slot holds no data
+                continue
+            if "," not in shape:
+                # the extra one-dimensional input is the bias, pre-multiplied by 2
+                memory += int(shape) * 2
+                continue
+            memory += element_count(shape)
+        for shape in operator.output_shapes[1:-1].split(";"):
+            if shape:
+                memory += element_count(shape)
+
+        return memory * 2 / 1024 / 1024
+
+    def check_ai_core_performance(self, promoting_dataset: ProfilingDataset):
+        """
+        Run the cube, fa and vector checks and store their queues in ``self.result``.
+
+        A check that raises IndexError, ValueError or AttributeError is logged and its
+        result becomes an empty list. ``ai_core_performance_issues`` is cleared when no
+        queue of any category contains a finding.
+
+        :Param promoting_dataset: dataset of operator performance from kernel_details.csv
+        """
+        checks = (
+            ("cube", self.check_cube_operator),
+            ("fa", self.check_fa_operator),
+            ("vector", self.check_vector_operator),
+        )
+        for category, check in checks:
+            try:
+                self.result[category] = check(promoting_dataset)
+            except (IndexError, ValueError, AttributeError) as e:
+                logger.error(f"Failed to check ai core performance {category} operator, {e}.")
+                self.result[category] = []
+
+        # each category is a list of queues; a list of empty queues is itself truthy,
+        # so the queues have to be inspected one by one
+        has_findings = any(
+            queue
+            for category, _ in checks
+            for queue in self.result[category]
+        )
+        if not has_findings:
+            self.ai_core_performance_issues = False
+
+ def check_cube_operator(self, profiling_dataset: ProfilingDataset):
+ """
+ Classify the cube operators selected into ``self.cube_dict``.
+
+ Each (op_name, shape key) pair ends up in the affinity queue (an inner axis is not a
+ multiple of 256), the bound queue (averaged aic_mac_ratio >= 0.8 and/or aic_mte2_ratio
+ >= 0.95) or the optimization queue (largest gap to those thresholds, in percent).
+
+ :Param profiling_dataset: dataset whose ``op_summary.op_list`` holds the operators
+ :return: [optimization_queue, bound_queue, affinity_queue], each cut to its top 5
+ """
+ cube_dict = self.cube_dict
+ optimization_queue = []
+ bound_queue = []
+ affinity_queue = []
+ # keep only the operators whose name and shape key were selected into cube_dict
+ operator_list = [op for op in profiling_dataset.op_summary.op_list
+ if op.op_name in cube_dict
+ and op.input_shapes[1:-1] + "-" + op.output_shapes[1:-1] in cube_dict[op.op_name]]
+ suggestion = self._CUBE_AFFINITY_DESC
+ for op in cube_dict:
+ for shape in cube_dict[op]:
+ dtype = None
+ shape_duration = 0.
+ # check whether the inner axes of the input shapes are multiples of 256
+ if (len(shape.split("-")[0].split(";")[0].split(","))) == 4:
+ # NZ format
+ shapes = shape.split("-")[0].split(";")
+ b = int(shapes[0].split(",")[1])
+ c = int(shapes[0].split(",")[2])
+
+ f = int(shapes[1].split(",")[1])
+ g = int(shapes[1].split(",")[2])
+ affinity_flag = (b * c % 256 == 0) and (f * g % 256 == 0)
+ else:
+ # ND format
+ shapes = shape.split("-")[0].split(";")
+ l = int(shapes[0].split(",")[1])
+ k = int(shapes[1].split(",")[1])
+ affinity_flag = (l % 256 == 0) and (k % 256 == 0)
+ if not affinity_flag:
+ # non-affine shape: sum the duration and take the dtype of the matching operators
+ for operator in operator_list:
+ if (operator.op_name == op and
+ operator.input_shapes[1:-1] + "-" + operator.output_shapes[1:-1] == shape):
+ dtype = operator.input_data_types
+ shape_duration += float(operator.task_duration)
+ affinity_queue.append({
+ "op_name": op,
+ "shape": shape.split("-")[0],
+ "dtype": dtype,
+ "duration": shape_duration,
+ "suggestion": suggestion})
+ else:
+ shap_list = [operator for operator in operator_list if
+ operator.op_name == op and
+ operator.input_shapes[1:-1] + "-" + operator.output_shapes[1:-1] == shape]
+ shape_duration = sum(float(operator.task_duration) for operator in shap_list)
+ dtype = shap_list[0].input_data_types if shap_list else None
+ aic_mac_ratio, aic_mte2_ratio = 0., 0.
+ length = 0
+ for operator in shap_list:
+ try:
+ aic_mac_ratio += float(operator.aic_mac_ratio)
+ aic_mte2_ratio += float(operator.aic_mte2_ratio)
+ length += 1
+ except ValueError:
+ # skip operators whose ratio fields are not numeric
+ continue
+ if length == 0:
+ # no operator with usable ratios for this shape
+ continue
+ aic_mac_ratio = aic_mac_ratio / length
+ aic_mte2_ratio = aic_mte2_ratio / length
+ bound = ""
+ optimization = 0.
+ if aic_mac_ratio >= 0.8 and aic_mte2_ratio >= 0.95:
+ bound = "mac_and_mte2_bound"
+ elif aic_mac_ratio >= 0.8:
+ bound = "mac_bound"
+ elif aic_mte2_ratio >= 0.95:
+ bound = "mte2_bound"
+ else:
+ # largest gap to the two thresholds, expressed as a percentage
+ optimization = round(max(0.8 - aic_mac_ratio, 0.95 - aic_mte2_ratio) * 100, 2)
+ if bound:
+ bound_queue.append({
+ "op_name": op,
+ "shape": shape.split("-")[0],
+ "dtype": dtype,
+ "bound": bound,
+ "duration": shape_duration})
+ else:
+ optimization_queue.append({
+ "op_name": op,
+ "shape": shape.split("-")[0],
+ "dtype": dtype,
+ "optimization": optimization})
+ # top 5 per queue: optimization by gap, bound and affinity by summed duration
+ return [sorted(optimization_queue, key=lambda x: x["optimization"], reverse=True)[:5],
+ sorted(bound_queue, key=lambda x: x["duration"], reverse=True)[:5],
+ sorted(affinity_queue, key=lambda x: x["duration"], reverse=True)[:5]]
+
+    def check_fa_operator(self, profiling_dataset: ProfilingDataset):
+        """
+        Classify the FlashAttention operators collected by ``data_filter``.
+
+        For every (op_name, shape key) in ``self.fa_dict`` the matching operators of
+        ``self.fa_list`` are aggregated and the pair lands in one queue:
+        affinity (an axis is not a multiple of 128), bound (an averaged pipeline ratio
+        reaches its threshold) or optimization (largest gap to the thresholds, in percent).
+        Forward operators are judged on aic_mte2_ratio and aiv_vec_ratio, grad operators
+        (shape key ending in "-grad") on aic_mte2_ratio and aic_fixpipe_ratio.
+
+        :Param profiling_dataset: not read by this method
+        :return: [optimization_queue, bound_queue, affinity_queue], each cut to its top 5
+        :raises IndexError, ValueError: when a shape string lacks the expected axes or a
+            duration of a non-affine operator is not numeric
+        """
+        grad_suffix = "-grad"
+        mte2_threshold, second_threshold = 0.8, 0.75
+        optimization_queue = []
+        bound_queue = []
+        affinity_queue = []
+        for op in self.fa_dict:
+            for shape in self.fa_dict[op]:
+                is_grad = shape.endswith(grad_suffix)
+                # operators carry no "-grad" marker themselves, so compare against the bare key
+                shape_key = shape[:-len(grad_suffix)] if is_grad else shape
+                input_shape = shape.split("-")[0]
+                matched = [operator for operator in self.fa_list
+                           if operator.op_name == op
+                           and operator.input_shapes[1:-1] + "-" + operator.output_shapes[1:-1] == shape_key]
+
+                # select non-affine operators: affinity_flag becomes True when an axis
+                # is not a multiple of 128
+                affinity_flag = False
+                suggestion = ""
+                first_input = input_shape.split(";")[0].split(",")
+                if "varlen" in op.lower():
+                    # variable-length operator: only the third axis of the first input is checked
+                    if int(first_input[2]) % 128 != 0:
+                        affinity_flag = True
+                        suggestion = self._FA_AFFINITY_DESC_TYPE1
+                else:
+                    # fixed-length operator: check head dim (D) and sequence length (S)
+                    first_output = shape.split("-")[1].split(";")[0].split(",")
+                    seq_len = int(first_output[2])
+                    if len(first_input) == 3:
+                        # presumably hidden size divided by head count -- TODO confirm layout
+                        head_dim = int(first_input[2]) / int(first_output[1])
+                    else:
+                        head_dim = int(first_input[3])
+                    if head_dim % 128 != 0 and seq_len % 128 != 0:
+                        affinity_flag = True
+                        suggestion = self._FA_AFFINITY_DESC_TYPE3
+                    elif head_dim % 128 != 0:
+                        affinity_flag = True
+                        suggestion = self._FA_AFFINITY_DESC_TYPE1
+                    elif seq_len % 128 != 0:
+                        affinity_flag = True
+                        suggestion = self._FA_AFFINITY_DESC_TYPE2
+
+                if affinity_flag:
+                    # non-affine operator: report its summed duration
+                    affinity_queue.append({
+                        "op_name": op,
+                        "shape": input_shape,
+                        "dtype": matched[-1].input_data_types if matched else None,
+                        "suggestion": suggestion,
+                        "duration": sum((float(operator.task_duration) for operator in matched), 0.)})
+                    continue
+
+                # affine operator: decide between bound and optimization
+                second_name = "fixpipe" if is_grad else "vec"
+                mte2_sum, second_sum, shape_duration = 0., 0., 0.
+                dtype = None
+                length = 0
+                for operator in matched:
+                    try:
+                        # convert every field before accumulating so that one bad
+                        # field cannot leave a partial sum behind
+                        if is_grad:
+                            second = float(operator.aic_fixpipe_ratio)
+                        else:
+                            second = float(operator.aiv_vec_ratio)
+                        mte2 = float(operator.aic_mte2_ratio)
+                        duration = float(operator.task_duration)
+                    except ValueError:
+                        continue
+                    second_sum += second
+                    mte2_sum += mte2
+                    shape_duration += duration
+                    dtype = operator.input_data_types
+                    length += 1
+                if length == 0:
+                    continue
+                mte2_ratio = mte2_sum / length
+                second_ratio = second_sum / length
+
+                bound = ""
+                optimization = 0.
+                if mte2_ratio >= mte2_threshold and second_ratio >= second_threshold:
+                    bound = "mte2_and_" + second_name + "_bound"
+                elif mte2_ratio >= mte2_threshold:
+                    bound = "mte2_bound"
+                elif second_ratio >= second_threshold:
+                    bound = second_name + "_bound"
+                else:
+                    optimization = max(mte2_threshold - mte2_ratio, second_threshold - second_ratio)
+
+                if bound:
+                    bound_queue.append({
+                        "op_name": op,
+                        "shape": input_shape,
+                        "dtype": dtype,
+                        "bound": bound,
+                        "duration": shape_duration})
+                else:
+                    optimization_queue.append({
+                        "op_name": op,
+                        "shape": input_shape,
+                        "dtype": dtype,
+                        "optimization": round(optimization * 100, 2)})
+
+        return [sorted(optimization_queue, key=lambda x: x["optimization"], reverse=True)[:5],
+                sorted(bound_queue, key=lambda x: x["duration"], reverse=True)[:5],
+                sorted(affinity_queue, key=lambda x: x["duration"], reverse=True)[:5]]
+
+    def check_vector_operator(self, profiling_dataset: ProfilingDataset):
+        """
+        Classify the vector operators selected into ``self.vector_dict``.
+
+        For every (op_name, shape key) the aiv_vec/aiv_mte2/aiv_mte3 ratios of the matching
+        operators are averaged. The pair is reported as bound when the three averages sum
+        to >= 0.9 or a single one reaches 0.7; otherwise the largest gap to 0.7 is reported
+        as optimization space in percent.
+
+        :Param profiling_dataset: dataset whose ``op_summary.op_list`` holds the operators
+        :return: [optimization_queue, bound_queue], each cut to its top 5
+        """
+        single_threshold, total_threshold = 0.7, 0.9
+        optimization_queue = []
+        bound_queue = []
+        for op_name in self.vector_dict:
+            for shape in self.vector_dict[op_name]:
+                vec_sum, mte2_sum, mte3_sum, shape_duration = 0., 0., 0., 0.
+                length = 0
+                dtype = ""
+                for operator in profiling_dataset.op_summary.op_list:
+                    if (operator.op_name != op_name or
+                            operator.input_shapes[1:-1] + "-" + operator.output_shapes[1:-1] != shape):
+                        continue
+                    try:
+                        # convert every field before accumulating so that one bad
+                        # field cannot leave a partial sum behind
+                        vec = float(operator.aiv_vec_ratio)
+                        mte2 = float(operator.aiv_mte2_ratio)
+                        mte3 = float(operator.aiv_mte3_ratio)
+                        duration = float(operator.task_duration)
+                    except ValueError:
+                        continue
+                    vec_sum += vec
+                    mte2_sum += mte2
+                    mte3_sum += mte3
+                    shape_duration += duration
+                    dtype = operator.input_data_types
+                    length += 1
+                if length == 0:
+                    continue
+                aiv_vec_ratio = vec_sum / length
+                aiv_mte2_ratio = mte2_sum / length
+                aiv_mte3_ratio = mte3_sum / length
+
+                bound = ""
+                optimization = 0.
+                if aiv_vec_ratio + aiv_mte2_ratio + aiv_mte3_ratio >= total_threshold:
+                    bound = "vec_mte2_mte3_bound"
+                elif aiv_mte2_ratio >= single_threshold:
+                    bound = "mte2_bound"
+                elif aiv_mte3_ratio >= single_threshold:
+                    bound = "mte3_bound"
+                elif aiv_vec_ratio >= single_threshold:
+                    bound = "vec_bound"
+                else:
+                    optimization = max(single_threshold - aiv_vec_ratio,
+                                       single_threshold - aiv_mte2_ratio,
+                                       single_threshold - aiv_mte3_ratio)
+
+                if bound:
+                    bound_queue.append({
+                        "op_name": op_name,
+                        "shape": shape.split("-")[0],
+                        "bound": bound,
+                        "dtype": dtype,
+                        "duration": shape_duration})
+                else:
+                    optimization_queue.append({
+                        "op_name": op_name,
+                        "shape": shape.split("-")[0],
+                        "dtype": dtype,
+                        "optimization": round(optimization * 100, 2)})
+        return [sorted(optimization_queue, key=lambda x: x["optimization"], reverse=True)[:5],
+                sorted(bound_queue, key=lambda x: x["duration"], reverse=True)[:5]]
+
+    def make_record(self, result: OptimizeResult):
+        """
+        Add one record plus detail rows to ``result`` for every category with findings.
+
+        :return: the (falsy) issue flag when there is nothing to record, otherwise True
+        """
+        if not self.ai_core_performance_issues:
+            return self.ai_core_performance_issues
+
+        # the queues of a category are ordered optimization, bound, affinity;
+        # vector carries only the first two, which zip() below handles
+        queue_texts = (
+            (self._OPTI_SUGGESTION, self._OPTI_DESC),
+            (self._BOUND_SUGGESTION, self._BOUND_DESC),
+            (self._AFFINITY_SUGGESTION, self._AFFINITY_DESC),
+        )
+        categories = (
+            ("cube", self._CUBE_PROBLEM),
+            ("fa", self._FA_PROBLEM),
+            ("vector", self._VECTOR_PROBLEM),
+        )
+        for category, problem in categories:
+            queues = self.result[category]
+            if not any(queues):
+                continue
+            result.add(OptimizeRecord(OptimizeItem(problem, self.desc, [self.suggestion])))
+            result.add_detail(problem, headers=["Type", "Description and Suggestion"])
+            for issues, (suggestion_template, description) in zip(queues, queue_texts):
+                # one formatted line per finding, concatenated into a single cell
+                text = "".join(suggestion_template.format(**issue) for issue in issues)
+                if text:
+                    result.add_detail(problem, detail=[description, text])
+        return True
+
+    def make_render(self, html_render, add_render_list=True, **kwargs):
+        """
+        Render ``self.result`` with the ai_core_performance.html template.
+
+        Reads ``priority`` and ``rank`` from kwargs.
+        :return: the (falsy) issue flag when there is nothing to render, otherwise
+            whatever ``html_render.render_template`` returns
+        """
+        if self.ai_core_performance_issues:
+            render_args = {
+                "key": "computation",
+                "template_dir": "templates",
+                "template_name": "ai_core_performance.html",
+                "format_result": self.result,
+                "language": self.language,
+                "add_render_list": add_render_list,
+                "priority_background_color": kwargs.get("priority"),
+                "rank": kwargs.get("rank"),
+            }
+            return html_render.render_template(**render_args)
+        return self.ai_core_performance_issues
+
+    def check_task_list(self, profiling_dataset: ProfilingDataset) -> bool:
+        """
+        Tell whether the dataset carries what the checker needs.
+
+        Requires ``op_summary.op_list`` to exist, to be non-empty, and its first operator
+        to expose ``input_shapes`` and ``input_data_types``. Every rejection is logged.
+
+        :return: True when the dataset can be analysed, otherwise False
+        """
+        if not hasattr(profiling_dataset, "op_summary"):
+            logger.warning("Skip %s checker because of not containing %s", self._CHECKER, "op summary")
+            return False
+        if not hasattr(profiling_dataset.op_summary, "op_list"):
+            logger.warning("Skip %s checker because of not containing %s", self._CHECKER, "op_list")
+            return False
+        op_list = profiling_dataset.op_summary.op_list
+        if not op_list:
+            # an empty list would make the op_list[0] probe below raise IndexError
+            logger.warning("Skip %s checker because of not containing %s", self._CHECKER, "operators")
+            return False
+        first_op = op_list[0]
+        if not (hasattr(first_op, "input_shapes") and hasattr(first_op, "input_data_types")):
+            logger.warning("Skip %s checker because of not containing input datas", self._CHECKER)
+            return False
+        return True
diff --git a/profiler/advisor/common/analyzer_scopes.py b/profiler/advisor/common/analyzer_scopes.py
index 2cad1a3ce7b873ef66730a6c098152853c0b1155..40a8d99bcca8af0d415b01c08e173baec642c126 100644
--- a/profiler/advisor/common/analyzer_scopes.py
+++ b/profiler/advisor/common/analyzer_scopes.py
@@ -40,3 +40,4 @@ class SupportedScopes:
GC_ANALYSIS = "gc_analysis"
CONJECTURED_GC_ANALYSIS = "conjectured_analysis"
COMPARISON = "comparison"
+ AICORE_PERFORMANCE_ANALYSIS = "ai_core_performance_analysis"
diff --git a/profiler/advisor/display/html/templates/ai_core_performance.html b/profiler/advisor/display/html/templates/ai_core_performance.html
new file mode 100644
index 0000000000000000000000000000000000000000..77e5e0cb55200efdf5b854e03ac2844ddc631a8f
--- /dev/null
+++ b/profiler/advisor/display/html/templates/ai_core_performance.html
@@ -0,0 +1,159 @@
+{% if format_result|length > 0 %}
+
+
+
+ {% if language == "cn" %}
+ {% set title_ns = namespace(type='类别', desc='描述及建议', opti_set='性能优化算子集合', bound_set='bound算子集合', affinity_set='不亲和算子集合',
+ opti_refer=' 参考性能优化空间: ', bound_refer=' bound类型为: ', affinity_refer=' 不亲和类型为: ', title_desc='算子相关分析,参考如下: ') %}
+ {% else %}
+ {% set title_ns = namespace(type='Type', desc='Description and Suggestion', opti_set='set of performance optimization operators',
+ bound_set='set of bound operators', affinity_set='set of unaffine operators', opti_refer=' refer to Performance Optimization Space: ',
+ bound_refer=' bound type: ', affinity_refer=' type of disaffinity: ', title_desc=' Operator related analysis, referenced below: ') %}
+ {% endif %}
+ {% if format_result.cube[0]|length + format_result.cube[1]|length + format_result.cube[2]|length > 0 %}
+
MatMul{{ title_ns.title_desc }}
+
+
+
+ {{ title_ns.type }} |
+ {{ title_ns.desc }} |
+
+ {% set opti_ns = namespace(total_opti='') %}
+ {% for opti in format_result.cube[0] %}
+ {% if not loop.first %}
+ {% set opti_ns.total_opti = opti_ns.total_opti ~ "<br />" ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %}
+ {% else %}
+ {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %}
+ {% endif %}
+ {% endfor %}
+ {% if opti_ns.total_opti|length > 0 %}
+
+ {{ title_ns.opti_set }} |
+ {{ opti_ns.total_opti | safe }} |
+
+ {% endif %}
+ {% set bound_ns = namespace(total_bound='') %}
+ {% for bound in format_result.cube[1] %}
+ {% if not loop.first %}
+ {% set bound_ns.total_bound = bound_ns.total_bound ~ "<br />" ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %}
+ {% else %}
+ {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %}
+ {% endif %}
+ {% endfor %}
+ {% if bound_ns.total_bound|length > 0 %}
+
+ {{ title_ns.bound_set }} |
+ {{ bound_ns.total_bound | safe }} |
+
+ {% endif %}
+ {% set affinity_ns = namespace(total_affinity='') %}
+ {% for affinity in format_result.cube[2] %}
+ {% if not loop.first %}
+ {% set affinity_ns.total_affinity = affinity_ns.total_affinity ~ "<br />" ~ affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %}
+ {% else %}
+ {% set affinity_ns.total_affinity = affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %}
+ {% endif %}
+ {% endfor %}
+ {% if affinity_ns.total_affinity|length > 0 %}
+
+ {{ title_ns.affinity_set }} |
+ {{ affinity_ns.total_affinity | safe }} |
+
+ {% endif %}
+
+ {% endif %}
+
+ {% if format_result.fa[0]|length + format_result.fa[1]|length + format_result.fa[2]|length > 0 %}
+
FA{{ title_ns.title_desc }}
+
+
+
+ {{ title_ns.type }} |
+ {{ title_ns.desc }} |
+
+ {% set opti_ns = namespace(total_opti='') %}
+ {% for opti in format_result.fa[0] %}
+ {% if not loop.first %}
+ {% set opti_ns.total_opti = opti_ns.total_opti ~ "<br />" ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %}
+ {% else %}
+ {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %}
+ {% endif %}
+ {% endfor %}
+ {% if opti_ns.total_opti|length > 0 %}
+
+ {{ title_ns.opti_set }} |
+ {{ opti_ns.total_opti | safe }} |
+
+ {% endif %}
+ {% set bound_ns = namespace(total_bound='') %}
+ {% for bound in format_result.fa[1] %}
+ {% if not loop.first %}
+ {% set bound_ns.total_bound = bound_ns.total_bound ~ "<br />" ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %}
+ {% else %}
+ {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %}
+ {% endif %}
+ {% endfor %}
+ {% if bound_ns.total_bound|length > 0 %}
+
+ {{ title_ns.bound_set }} |
+ {{ bound_ns.total_bound | safe }} |
+
+ {% endif %}
+ {% set affinity_ns = namespace(total_affinity='') %}
+ {% for affinity in format_result.fa[2] %}
+ {% if not loop.first %}
+ {% set affinity_ns.total_affinity = affinity_ns.total_affinity ~ "<br />" ~ affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %}
+ {% else %}
+ {% set affinity_ns.total_affinity = affinity.op_name ~ " operator shape: " ~ affinity.shape ~ " dtype: " ~ affinity.dtype ~ title_ns.affinity_refer ~ affinity.suggestion %}
+ {% endif %}
+ {% endfor %}
+ {% if affinity_ns.total_affinity|length > 0 %}
+
+ {{ title_ns.affinity_set }} |
+ {{ affinity_ns.total_affinity | safe }} |
+
+ {% endif %}
+
+ {% endif %}
+
+ {% if format_result.vector[0]|length + format_result.vector[1]|length > 0 %}
+
Vector{{ title_ns.title_desc }}
+
+
+
+ {{ title_ns.type }} |
+ {{ title_ns.desc }} |
+
+ {% set opti_ns = namespace(total_opti='') %}
+ {% for opti in format_result.vector[0] %}
+ {% if not loop.first %}
+ {% set opti_ns.total_opti = opti_ns.total_opti ~ "<br />" ~ opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %}
+ {% else %}
+ {% set opti_ns.total_opti = opti.op_name ~ " operator shape: " ~ opti.shape ~ " dtype: " ~ opti.dtype ~ title_ns.opti_refer ~ opti.optimization ~ "%" %}
+ {% endif %}
+ {% endfor %}
+ {% if opti_ns.total_opti|length > 0 %}
+
+ {{ title_ns.opti_set }} |
+ {{ opti_ns.total_opti | safe }} |
+
+ {% endif %}
+ {% set bound_ns = namespace(total_bound='') %}
+ {% for bound in format_result.vector[1] %}
+ {% if not loop.first %}
+ {% set bound_ns.total_bound = bound_ns.total_bound ~ "<br />" ~ bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %}
+ {% else %}
+ {% set bound_ns.total_bound = bound.op_name ~ " operator shape: " ~ bound.shape ~ " dtype: " ~ bound.dtype ~ title_ns.bound_refer ~ bound.bound %}
+ {% endif %}
+ {% endfor %}
+ {% if bound_ns.total_bound|length > 0 %}
+
+ {{ title_ns.bound_set }} |
+ {{ bound_ns.total_bound | safe }} |
+
+ {% endif %}
+
+ {% endif %}
+
+
+{% endif %}
\ No newline at end of file
diff --git a/profiler/advisor/interface/interface.py b/profiler/advisor/interface/interface.py
index 7b9cb00fdfef103b90b9828c3fd6ff6aa6f1900a..ebcf5680673430a7473fd819557e81e05312bb2c 100644
--- a/profiler/advisor/interface/interface.py
+++ b/profiler/advisor/interface/interface.py
@@ -47,6 +47,8 @@ from profiler.advisor.analyzer.communication.alignment.byte_alignment_analyzer i
from profiler.advisor.analyzer.schedule.gc.gc_analyzer import GcAnalyzer
from profiler.advisor.analyzer.schedule.conjectured_gc.conjectured_gc_analyzer import ConjecturedGcAnalyzer
from profiler.advisor.analyzer.comparison.comparison_analyzer import ComparisonAnalyzer
+from profiler.advisor.analyzer.computation.ai_core_performance.ai_core_performance_analyzer import \
+ AICorePerformanceAnalyzer
logger = logging.getLogger()
@@ -76,7 +78,8 @@ class Interface:
SupportedScopes.OPERATOR_NO_BOUND_ANALYSIS: OperatorBoundAnalyzer,
SupportedScopes.BLOCK_DIM_ANALYSIS: BlockDimAnalyzer,
SupportedScopes.GRAPH: FusionOPAnalyzer,
- SupportedScopes.FREQ_ANALYSIS: AICoreFreqAnalyzer
+ SupportedScopes.FREQ_ANALYSIS: AICoreFreqAnalyzer,
+ SupportedScopes.AICORE_PERFORMANCE_ANALYSIS: AICorePerformanceAnalyzer
}),
COMMUNICATION: OrderedDict({SupportedScopes.PACKET: PacketAnalyzer,
SupportedScopes.COMMUNICATION_RETRANSMISSION_DETECTION: RDMARetransmissionAnalyzer,
diff --git a/profiler/advisor/rules/cn/aicore_performance.yaml b/profiler/advisor/rules/cn/aicore_performance.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8d44aaab2735efe163caee5f2baa884f7eed73b5
--- /dev/null
+++ b/profiler/advisor/rules/cn/aicore_performance.yaml
@@ -0,0 +1,15 @@
+cube_problem: "Cube算子性能分析"
+fa_problem: "FA算子性能分析"
+vector_problem: "Vector算子性能分析"
+description: "提供一些AICORE算子的参考瓶颈"
+bound_description: "bound算子集合"
+optimization_description: "性能优化算子集合"
+affinity_description: "不亲和算子集合"
+cube_affinity_desc: "内轴无法被256整除"
+fa_affinity_desc_type1: "D不能被128整除"
+fa_affinity_desc_type2: "S不能被128整除"
+fa_affinity_desc_type3: "D和S均不能被128整除"
+suggestion: "请根据亲和性、bound类型或优化空间尝试分析筛选出来的算子"
+affinity_suggestion: "{op_name}算子 shape: {shape} dtype: {dtype} 有不亲和特征: {suggestion}\n"
+bound_suggestion: "{op_name}算子 shape: {shape} dtype: {dtype} bound类型为: {bound} bound\n"
+optimization_suggestion: "{op_name}算子 shape: {shape} dtype: {dtype} 疑似有性能优化空间,参考性能优化空间: {optimization}%\n"
\ No newline at end of file
diff --git a/profiler/advisor/rules/en/aicore_performance.yaml b/profiler/advisor/rules/en/aicore_performance.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e85a919ab95bba21ec533ba5db8243d64d939184
--- /dev/null
+++ b/profiler/advisor/rules/en/aicore_performance.yaml
@@ -0,0 +1,15 @@
+cube_problem: "Cube operator performance analysis"
+fa_problem: "FA operator performance analysis"
+vector_problem: "Vector operator performance analysis"
+description: "Provide some reference bottlenecks for the AICORE operator"
+bound_description: "set of bound operators"
+optimization_description: "set of performance optimization operators"
+affinity_description: "set of unaffine operators"
+cube_affinity_desc: "The inner axis is not divisible by 256"
+fa_affinity_desc_type1: "D is not divisible by 128"
+fa_affinity_desc_type2: "S is not divisible by 128"
+fa_affinity_desc_type3: "Neither D nor S is divisible by 128"
+suggestion: "Please try to analyze the filtered operators based on affinity, bound type or optimization space"
+affinity_suggestion: "{op_name} Op shape: {shape} dtype: {dtype} with non-affinity characteristics: {suggestion}\n"
+bound_suggestion: "{op_name} Op shape: {shape} dtype: {dtype} bound type: {bound} bound\n"
+optimization_suggestion: "{op_name} Op shape: {shape} dtype: {dtype} suspect there is room for performance optimization, refer to Performance Optimization Space: {optimization}%\n"
\ No newline at end of file
diff --git a/profiler/cli/entrance.py b/profiler/cli/entrance.py
deleted file mode 100644
index c6d72837b7ec0d5c9943f7652378afb3de11808c..0000000000000000000000000000000000000000
--- a/profiler/cli/entrance.py
+++ /dev/null
@@ -1,68 +0,0 @@
-#!/usr/bin/python
-# -*- coding: utf-8 -*-
-# Copyright (c) 2024, Huawei Technologies Co., Ltd.
-# All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-import logging
-import click
-
-from profiler.cli.analyze_cli import analyze_cli
-from profiler.cli.complete_cli import auto_complete_cli
-from profiler.cli.compare_cli import compare_cli
-from profiler.cli.cluster_cli import cluster_cli
-from profiler.advisor.version import print_version_callback, cli_version
-
-logger = logging.getLogger()
-CONTEXT_SETTINGS = dict(help_option_names=['-H', '-h', '--help'],
- max_content_width=160)
-
-COMMAND_PRIORITY = {
- "advisor": 1,
- "compare": 2,
- "cluster": 3,
- "auto-completion": 4
-}
-
-
-class SpecialHelpOrder(click.Group):
-
- def __init__(self, *args, **kwargs):
- super(SpecialHelpOrder, self).__init__(*args, **kwargs)
-
- def list_commands_for_help(self, ctx):
- """
- reorder the list of commands when listing the help
- """
- commands = super(SpecialHelpOrder, self).list_commands(ctx)
- return [item[1] for item in sorted((COMMAND_PRIORITY.get(command, float('INF')),
- command) for command in commands)]
-
- def get_help(self, ctx):
- self.list_commands = self.list_commands_for_help
- return super(SpecialHelpOrder, self).get_help(ctx)
-
-
-@click.group(context_settings=CONTEXT_SETTINGS, cls=SpecialHelpOrder)
-@click.option('--version', '-V', '-v', is_flag=True,
- callback=print_version_callback, expose_value=False,
- is_eager=True, help=cli_version())
-def msprof_analyze_cli(**kwargs):
- pass
-
-
-msprof_analyze_cli.add_command(analyze_cli, name="advisor")
-msprof_analyze_cli.add_command(compare_cli, name="compare")
-msprof_analyze_cli.add_command(cluster_cli, name="cluster")
-msprof_analyze_cli.add_command(auto_complete_cli, name="auto-completion")
-
diff --git a/profiler/test/ut/advisor/compute_advice/data/kernel_details.csv b/profiler/test/ut/advisor/compute_advice/data/kernel_details.csv
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/profiler/test/ut/advisor/compute_advice/test_ai_core_performance_advice.py b/profiler/test/ut/advisor/compute_advice/test_ai_core_performance_advice.py
new file mode 100644
index 0000000000000000000000000000000000000000..61ae35d138ed9c66eb9459eec8c610e9e84cb9dd
--- /dev/null
+++ b/profiler/test/ut/advisor/compute_advice/test_ai_core_performance_advice.py
@@ -0,0 +1,74 @@
+import csv
+import os
+import shutil
+import stat
+
+import unittest
+from profiler.advisor.interface.interface import Interface
+from profiler.advisor.common.analyzer_scopes import SupportedScopes
+
+
+class TestAICorePerformanceAdvice(unittest.TestCase):  # Test case for the AICORE_PERFORMANCE_ANALYSIS scope of Interface.COMPUTATION
+ TMP_DIR = "./ascend_pt"  # scratch profiling directory, created in setUp and removed in tearDown
+ OUTPUT_DIR = "./ascend_pt/ASCEND_PROFILER_OUTPUT"  # sub-directory that receives kernel_details.csv
+ interface = None  # NOTE(review): never assigned in this class; the test uses a local variable instead
+ err_interface = None  # NOTE(review): never assigned in this class
+
+ def tearDown(self):  # remove TMP_DIR (if present) and the "mstt"-prefixed files after each test
+ if os.path.exists(TestAICorePerformanceAdvice.TMP_DIR):
+ shutil.rmtree(TestAICorePerformanceAdvice.TMP_DIR)
+ self.clear_htmls()
+
+ def setUp(self):  # recreate empty TMP_DIR and OUTPUT_DIR before each test
+ if os.path.exists(TestAICorePerformanceAdvice.TMP_DIR):
+ shutil.rmtree(TestAICorePerformanceAdvice.TMP_DIR)
+ if not os.path.exists(TestAICorePerformanceAdvice.TMP_DIR):
+ os.makedirs(TestAICorePerformanceAdvice.TMP_DIR)
+ if not os.path.exists(TestAICorePerformanceAdvice.OUTPUT_DIR):
+ os.makedirs(TestAICorePerformanceAdvice.OUTPUT_DIR)
+ self.clear_htmls()
+
+ @classmethod
+ def clear_htmls(cls):  # delete every entry named "mstt*" in the directory that contains this test file
+ current_path = os.path.dirname(os.path.abspath(__file__))
+ for filename in os.listdir(current_path):
+ # Check whether the file name starts with "mstt"
+ if filename.startswith("mstt"):
+ # Build the full path of the file
+ file_path = os.path.join(current_path, filename)
+ # Delete the file
+ os.remove(file_path)
+
+ @classmethod
+ def copy_kernel_details(cls, path):  # copy ./data/<path> to OUTPUT_DIR/kernel_details.csv; raises FileNotFoundError if the source is missing
+ # Define source and destination paths
+ source_csv_path = f"./data/{path}"  # NOTE(review): relative to the current working directory, whereas clear_htmls uses this file's directory
+ destination_csv_path = f"{TestAICorePerformanceAdvice.OUTPUT_DIR}/kernel_details.csv"
+
+ # Check if source CSV file exists
+ if not os.path.exists(source_csv_path):
+ raise FileNotFoundError(f"test data file not found:{source_csv_path}")
+
+ # Ensure the output directory exists
+ if not os.path.exists(TestAICorePerformanceAdvice.OUTPUT_DIR):
+ os.makedirs(TestAICorePerformanceAdvice.OUTPUT_DIR)
+
+ # Copy the CSV file from source to destination
+ shutil.copyfile(source_csv_path, destination_csv_path)
+
+ def test_ai_core_performance_total(self):  # run the analysis on the fixture CSV and check each result section
+ file_path = "kernel_details.csv"
+ self.copy_kernel_details(file_path)
+ interface = Interface(profiling_path=self.TMP_DIR)
+ dimension = Interface.COMPUTATION
+ scope = SupportedScopes.AICORE_PERFORMANCE_ANALYSIS
+ result = interface.get_result(dimension, scope, render_html=1, output_dict=False, profiling_path=self.TMP_DIR)
+ self.assertLess(1, len(result.data.get("Cube算子性能分析").get("data")[0]))  # each assertLess(1, len(x)) requires len(x) > 1
+ self.assertLess(1, len(result.data.get("Cube算子性能分析").get("data")[1]))
+ self.assertLess(1, len(result.data.get("Cube算子性能分析").get("data")[2]))
+ self.assertLess(1, len(result.data.get("FA算子性能分析").get("data")[0]))
+ self.assertLess(1, len(result.data.get("FA算子性能分析").get("data")[1]))
+ self.assertLess(1, len(result.data.get("FA算子性能分析").get("data")[2]))
+ self.assertLess(1, len(result.data.get("Vector算子性能分析").get("data")[0]))  # only indices 0 and 1 are checked for the Vector section
+ self.assertLess(1, len(result.data.get("Vector算子性能分析").get("data")[1]))
+ result.clear()