diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index 6211032d0c5dc7fb1768fa8f8372001bcda0a181..ddf2c17ed2cbf74e9bcce5970de5c5830661a011 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -18,12 +18,7 @@ class Comparator: def __init__(self, result_csv_path, details_csv_path, is_continue_run_ut, test_result_cnt=None, stack_info_json_path=None): self.save_path = result_csv_path self.detail_save_path = details_csv_path - if not is_continue_run_ut: - if os.path.exists(self.save_path): - raise ValueError(f"file {self.save_path} already exists, please remove it first or use a new dump path") - if os.path.exists(self.detail_save_path): - raise ValueError( - f"file {self.detail_save_path} already exists, please remove it first or use a new dump path") + if not is_continue_run_ut and not os.path.exists(self.save_path) and not os.path.exists(self.detail_save_path): self.write_csv_title() if stack_info_json_path: self.stack_info = get_json_contents(stack_info_json_path) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa8fd47eb0573df69b6e667d1b2c92b6d59a33b --- /dev/null +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/multi_run_ut.py @@ -0,0 +1,108 @@ +import subprocess +import json +import os +import sys +import argparse +import glob +from collections import namedtuple +from ptdbg_ascend.src.python.ptdbg_ascend.common.file_check_util import FileCheckConst, FileChecker, \ + check_file_suffix, check_link, FileOpen +from api_accuracy_checker.compare.compare import Comparator +from api_accuracy_checker.run_ut.run_ut import get_statistics_from_result_csv, _run_ut_parser +from api_accuracy_checker.dump.info_dump import write_json + 
+ +def split_json_file(input_file, num_splits): + with FileOpen(input_file, 'r') as file: + data = json.load(file) + + items = list(data.items()) + total_items = len(items) + + chunk_size = total_items // num_splits + split_files = [] + + for i in range(num_splits): + start = i * chunk_size + end = (i + 1) * chunk_size if i < num_splits - 1 else total_items + split_filename = os.path.join("./", f"temp_part{i}.json") + for item in items[start:end]: + write_json(split_filename, {item[0]: item[1]}) + split_files.append(split_filename) + + return split_files + + +ParallelUTConfig = namedtuple('ParallelUTConfig', ['forward_files', 'backward_files', 'out_path', 'num_splits', 'save_error_data_flag', 'jit_compile_flag', 'device_id', 'csv_file']) + + +def run_parallel_ut(config): + processes = [] + + def create_cmd(fwd, bwd): + cmd = [ + sys.executable, 'run_ut.py', + '-forward', fwd, + '-backward' if bwd else '', + bwd if bwd else '', + '-o', config.out_path, + '-d' if config.device_id else '', + str(config.device_id) if config.device_id else '', + '-j' if config.jit_compile_flag else '', + '-s' if config.save_error_data_flag else '', + '-c' if config.csv_file else '', + config.csv_file if config.csv_file else '' + ] + return [arg for arg in cmd if arg] + + commands = [create_cmd(fwd, bwd) for fwd, bwd in zip(config.forward_files, config.backward_files)] + for cmd in commands: + processes.append(subprocess.Popen(cmd)) + + for process in processes: + process.wait() + + for file in config.forward_files: + os.remove(file) + try: + process_csv_and_print_results(config.out_path, config.csv_file) + except FileNotFoundError as e: + print(f"Error: {e}") + except Exception as e: + print(f"An unexpected error occurred: {e}") + + +def process_csv_and_print_results(out_path, csv_file=None): + csv_files = glob.glob(os.path.join(out_path, 'accuracy_checking_result_*.csv')) + if not csv_files: + raise FileNotFoundError("No CSV files found in the specified output path.") + 
latest_csv = max(csv_files, key=os.path.getmtime) + + comparator = Comparator(out_path, out_path, False) + comparator.test_result_cnt = get_statistics_from_result_csv(latest_csv) + comparator.print_pretest_result() + + +def main(): + parser = argparse.ArgumentParser(description='Run UT in parallel') + _run_ut_parser(parser) + parser.add_argument('-n', '--num_splits', type=int, choices=range(1, 21), default=10, help='Number of splits for parallel processing. Range: 1-20') + args = parser.parse_args() + + check_link(args.forward_input_file) + if args.backward_input_file: check_link(args.backward_input_file) + forward_file = os.path.realpath(args.forward_input_file) + backward_file = os.path.realpath(args.backward_input_file) if args.backward_input_file else None + check_file_suffix(forward_file, FileCheckConst.JSON_SUFFIX) + if backward_file: check_file_suffix(backward_file, FileCheckConst.JSON_SUFFIX) + out_path = os.path.realpath(args.out_path) if args.out_path else "./" + out_path_checker = FileChecker(out_path, FileCheckConst.DIR, ability=FileCheckConst.WRITE_ABLE) + out_path = out_path_checker.common_check() + forward_splits = split_json_file(forward_file, args.num_splits) + backward_splits = [backward_file] * args.num_splits if backward_file else [None] * args.num_splits + + config = ParallelUTConfig(forward_splits, backward_splits, out_path, args.num_splits, args.save_error_data, args.jit_compile, args.device_id, args.result_csv_path) + run_parallel_ut(config) + +if __name__ == '__main__': + main() diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index 6b1c178a963aa5f6df4f37bf2e3e7e62d8104d1b..b985814acf201a285b23a5aef75aae1e946e8f09 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -322,13 +322,13 @@ def _run_ut_parser(parser): parser.add_argument("-o", "--out_path", dest="out_path", default="", type=str, help=" The ut task 
result out path.", required=False) - parser.add_argument('-save_error_data', dest="save_error_data", action="store_true", + parser.add_argument("-s", "-save_error_data", "--save_error_data", dest="save_error_data", action="store_true", help=" Save compare failed api output.", required=False) parser.add_argument("-j", "--jit_compile", dest="jit_compile", action="store_true", help=" whether to turn on jit compile", required=False) parser.add_argument("-d", "--device", dest="device_id", type=int, help=" set device id to run ut", default=0, required=False) - parser.add_argument("-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, + parser.add_argument("-c", "-csv_path", "--result_csv_path", dest="result_csv_path", default="", type=str, help=" The path of accuracy_checking_result_{timestamp}.csv, " "when run ut is interrupted, enter the file path to continue run ut.", required=False)