class UniqueDeviceAction(argparse.Action):
    """Argparse action that validates the device ids given on the command line.

    The parsed ids must be pairwise distinct and each must lie in the
    inclusive range ``[MIN_DEVICE_ID, MAX_DEVICE_ID]``. On any violation the
    parser's ``error`` method is called, which reports the message and exits.
    On success the parsed value is stored on the namespace under ``self.dest``.
    """

    # Inclusive bounds accepted for a single device id.
    MIN_DEVICE_ID = 0
    MAX_DEVICE_ID = 4095

    def __call__(self, parser, namespace, values, option_string=None):
        """Validate ``values`` and store them on ``namespace``.

        Args:
            parser: The ``ArgumentParser`` invoking this action; used to report errors.
            namespace: The namespace object receiving the parsed value.
            values: The converted argument value(s); a list when the option is
                declared with ``nargs``, otherwise a single value.
            option_string: The option string that triggered the action, if any.
        """
        # argparse passes a bare value (not a list) when the option is declared
        # without ``nargs``; normalise it so the checks below work for both forms.
        device_ids = values if isinstance(values, (list, tuple)) else [values]

        if len(set(device_ids)) != len(device_ids):
            parser.error("device id must be unique")

        for device_id in device_ids:
            if not self.MIN_DEVICE_ID <= device_id <= self.MAX_DEVICE_ID:
                parser.error(
                    f"the argument 'device_id' must be in range "
                    f"[{self.MIN_DEVICE_ID}, {self.MAX_DEVICE_ID}], but got {device_id}"
                )

        # Store exactly what argparse handed over, so callers see the usual shape.
        setattr(namespace, self.dest, values)
print("Device IDs:", args.device_id) # 输出设备 ID 列表 + # 获取 device_id 参数并逐个打印 + + print("Device IDs:") + for device_id in args.device_id: + print(device_id) diff --git a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py index a4d5e40e388a45692c7d5ff8cdbe745f649294fd..07e3e9989af3afe78f13e52c5b4cf9370a946863 100644 --- a/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py +++ b/debug/accuracy_tools/msprobe/mindspore/api_accuracy_checker/data_manager.py @@ -90,8 +90,17 @@ class DataManager: self.initialize_api_names_set(result_csv_path) else: # 默认情况下,设置输出路径为空,等待首次写入时初始化 - self.detail_out_path = None - self.result_out_path = None + self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME)) + self.detail_out_path = os.path.join( + self.csv_dir, + os.path.basename(self.result_out_path).replace("result", "details") + ) + + if self.detail_out_path and os.path.exists(self.detail_out_path): + check_file_or_directory_path(self.detail_out_path) + + if self.result_out_path and os.path.exists(self.result_out_path): + check_file_or_directory_path(self.result_out_path) def initialize_api_names_set(self, result_csv_path): """读取现有的 CSV 文件并存储已经出现的 API 名称到集合中""" @@ -146,17 +155,6 @@ class DataManager: def save_results(self, api_name_str): if self.is_first_write: - self.result_out_path = os.path.join(self.csv_dir, add_time_as_suffix(MsCompareConst.RESULT_CSV_FILE_NAME)) - self.detail_out_path = os.path.join( - self.csv_dir, - os.path.basename(self.result_out_path).replace("result", "details") - ) - if self.detail_out_path and os.path.exists(self.detail_out_path): - check_file_or_directory_path(self.detail_out_path) - - if self.result_out_path and os.path.exists(self.result_out_path): - check_file_or_directory_path(self.result_out_path) - # 直接写入表头 logger.info("Writing CSV headers for the first time.") 
def mul_api_checker_main(args):
    """Entry point of the multi-device API accuracy check.

    Validates the command-line arguments, builds a ``MultiApiAccuracyChecker``
    from them, parses the API info file and runs the comparison.

    Args:
        args: Parsed command-line namespace; ``args.api_info_file`` is the
            file handed to the checker's ``parse`` method.
    """
    check_args(args)

    multi_checker = MultiApiAccuracyChecker(args)
    api_info_file = args.api_info_file
    multi_checker.parse(api_info_file)
    multi_checker.run_and_compare()
class MultiApiAccuracyChecker(ApiAccuracyChecker):
    """Run the API accuracy check in parallel, one worker process per device id.

    The entries of ``self.api_infos`` are divided into contiguous, near-equal
    slices. Each worker process sets its device id on the MindSpore context,
    checks its own slice and writes results through a ``MultiDataManager``.
    """

    def __init__(self, args):
        """Store the arguments and build the state shared by the worker processes.

        Args:
            args: Parsed command-line namespace. ``out_path``, ``result_csv_path``
                and ``device_id`` are read from it.
        """
        # NOTE(review): the parent initializer is intentionally not invoked (the
        # call was commented out in the original code) - presumably to avoid
        # building the single-process data manager. Confirm that no other
        # attribute set by the parent is needed by the inherited helpers.
        self.api_infos = dict()
        self.args = args

        # A manager-backed flag so that every worker process observes whether
        # the CSV headers have already been written.
        self.manager = Manager()
        self.is_first_write = self.manager.Value('b', True)
        self.multi_data_manager = MultiDataManager(args.out_path, args.result_csv_path, self.is_first_write)

        logger.info(f"Device IDs: {args.device_id}")

    @staticmethod
    def _select_partition(items, partition_count, index):
        """Return slice ``index`` of ``partition_count`` contiguous, near-equal slices.

        The first ``len(items) % partition_count`` slices hold one extra item,
        which is the same split ``numpy.array_split`` produces, but the items
        stay in a plain list instead of being converted to an object array.
        """
        base_size, extra = divmod(len(items), partition_count)
        start = index * base_size + min(index, extra)
        stop = start + base_size + (1 if index < extra else 0)
        return items[start:stop]

    def _run_direction(self, api_name_str, api_info, direction, label):
        """Run and record the check of one API for one direction.

        Args:
            api_name_str: Name of the API being checked.
            api_info: The API description passed on to the inherited helpers.
            direction: ``Const.FORWARD`` or ``Const.BACKWARD``.
            label: Human-readable direction name used in log messages.
        """
        try:
            logger.debug(f"Executing {label} check for {api_name_str}")
            inputs_aggregation = self.prepare_api_input_aggregation(api_info, direction)
            output_list = self.run_and_compare_helper(api_info, api_name_str, inputs_aggregation, direction)
            self.multi_data_manager.record(output_list)
        except Exception as e:
            # Deliberate best effort: a failing API must not abort the rest of
            # this worker's slice, so the error is reported and execution goes on.
            logger.warning(f"Error in {label} check for {api_name_str}: {e}")

    def _check_single_api(self, api_name_str, api_info):
        """Run the forward check, the backward check when available, then save results."""
        self._run_direction(api_name_str, api_info, Const.FORWARD, "forward")
        if api_info.check_backward_info():
            self._run_direction(api_name_str, api_info, Const.BACKWARD, "backward")
        self.multi_data_manager.save_results(api_name_str)

    def process_on_device(self, device_id, partitioned_api_infos, index):
        """Worker body: check the slice of APIs assigned to one device.

        Args:
            device_id: Device id set on the MindSpore context for this worker.
            partitioned_api_infos: The complete list of ``(api_name, api_info)``
                pairs; the worker selects its own slice from it.
            index: Position of this worker among ``self.args.device_id``; it
                selects the slice and the row of the progress bar.
        """
        context.set_context(device_id=device_id)
        logger.info(f"Running on device {device_id} with process index {index}")

        partition_count = len(self.args.device_id)
        current_partition = self._select_partition(partitioned_api_infos, partition_count, index)
        logger.info(f"Total APIs: {len(partitioned_api_infos)}")
        logger.info(f"Number of partitions: {partition_count}")
        logger.info(f"Current partition for device {device_id}, index {index}: {len(current_partition)} APIs")

        skipped_due_to_non_unique = 0
        skipped_due_to_no_forward = 0
        successful_tasks = 0

        # One progress bar per worker, placed on its own row through ``position``.
        with tqdm(total=len(current_partition), desc=f"Device {device_id}", position=index) as pbar:
            for idx, (api_name_str, api_info) in enumerate(current_partition):
                logger.debug(
                    f"Processing API: {api_name_str}, Index: {idx} "
                    f"(Total tasks: {len(current_partition)}), Device: {device_id}")

                if not self.multi_data_manager.is_unique_api(api_name_str):
                    logger.debug(f"API {api_name_str} is not unique, skipping.")
                    skipped_due_to_non_unique += 1
                elif not api_info.check_forward_info():
                    logger.debug(
                        f"api: {api_name_str} is lack of forward information, skipping forward and backward check.")
                    skipped_due_to_no_forward += 1
                else:
                    self._check_single_api(api_name_str, api_info)
                    successful_tasks += 1
                pbar.update(1)

        logger.info(f"Total skipped due to non-unique API: {skipped_due_to_non_unique}")
        logger.info(f"Total skipped due to lack of forward info: {skipped_due_to_no_forward}")
        logger.info(f"Total successful tasks processed: {successful_tasks}")

    def run_and_compare(self):
        """Start one worker process per device id and wait for all of them."""
        device_ids = self.args.device_id
        partitioned_api_infos = list(self.api_infos.items())
        logger.info(f"Original partitioned_api_infos (Total tasks: {len(partitioned_api_infos)}):")

        workers = []
        for index, device_id in enumerate(device_ids):
            process = multiprocessing.Process(target=self.process_on_device,
                                              args=(device_id, partitioned_api_infos, index))
            process.start()
            workers.append((device_id, process))

        for device_id, process in workers:
            process.join()
            # A non-zero exit code means the worker died before finishing its slice.
            if process.exitcode != 0:
                logger.warning(f"Worker process for device {device_id} exited with code {process.exitcode}.")
class MultiDataManager(DataManager):
    """``DataManager`` variant whose CSV output is shared by several worker processes.

    Header writing is coordinated through a shared "first write" flag, and
    every save is serialised with an inter-process lock. ``self.results`` is
    only touched by the process that owns this object, so it needs no locking.
    """

    def __init__(self, csv_dir, result_csv_path, shared_is_first_write):
        """Initialise the manager.

        Args:
            csv_dir: Forwarded to ``DataManager.__init__``.
            result_csv_path: Forwarded to ``DataManager.__init__``.
            shared_is_first_write: Shared flag object exposing ``.value``; it is
                true until the CSV headers have been written by some process.
        """
        super().__init__(csv_dir, result_csv_path)
        # Local import keeps this change self-contained; the file header is untouched.
        import multiprocessing

        self.shared_is_first_write = shared_is_first_write
        # The workers are separate processes, so a ``threading.Lock`` would not
        # exclude them from each other. A ``multiprocessing.Lock`` created here,
        # before the workers are started, is inherited and shared by them.
        self.lock = multiprocessing.Lock()

    def record(self, output_list):
        """Append check outputs to ``self.results``.

        Args:
            output_list: Iterable of ``(api_real_name, forward_or_backward,
                basic_info, compare_result_dict)`` tuples, or ``None`` (ignored).
        """
        if output_list is None:
            return
        for output in output_list:
            api_real_name, forward_or_backward, basic_info, compare_result_dict = output
            key = (api_real_name, forward_or_backward)
            self.results.setdefault(key, []).append((basic_info, compare_result_dict))
            logger.debug(f"Updated self.results for key {key}: {self.results[key]}")
        logger.debug(f"Complete self.results after recording: {self.results}")

    def save_results(self, api_name_str):
        """Write the recorded results to the detail and result CSV files, then clear them.

        The whole operation runs under ``self.lock``. The CSV headers are
        written only by the first caller that still sees the shared flag set.

        Args:
            api_name_str: Name of the API whose results are saved (used in logs).
        """
        with self.lock:
            if self.is_first_write:
                if self.shared_is_first_write.value:
                    # Mark the headers as written for every process before writing them.
                    self.shared_is_first_write.value = False
                    logger.info("Writing CSV headers for the first time.")
                    write_csv_header(self.detail_out_path, get_detail_csv_header)
                    write_csv_header(self.result_out_path, get_result_csv_header)
                # The shared flag never returns to true, so remember locally that
                # the check is done and skip the shared lookup on later saves.
                self.is_first_write = False

            # Write the detailed output and the result summary.
            logger.debug("Starting to write detailed output to CSV.")
            self.to_detail_csv(self.detail_out_path)
            logger.debug(f"Detailed output for {api_name_str} written to {self.detail_out_path}.")

            logger.debug("Starting to write result summary to CSV.")
            self.to_result_csv(self.result_out_path)
            logger.debug(f"Result summary for {api_name_str} written to {self.result_out_path}.")

            # Drop the saved records so the next call starts from an empty state.
            self.clear_results()

    def clear_results(self):
        """Remove every entry from ``self.results``."""
        logger.debug("Clearing self.results data.")
        self.results.clear()
api_checker_main api_checker_main(args) + elif sys.argv[3] == "multi_run_ut": + from msprobe.mindspore.api_accuracy_checker.main import mul_api_checker_main + mul_api_checker_main(args) elif sys.argv[3] == "graph": _ms_graph_service_command(args)