diff --git a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py index b95ec8fa3e47327e24d1ecd532ad0a13b46586ee..6211032d0c5dc7fb1768fa8f8372001bcda0a181 100644 --- a/debug/accuracy_tools/api_accuracy_checker/compare/compare.py +++ b/debug/accuracy_tools/api_accuracy_checker/compare/compare.py @@ -73,7 +73,7 @@ class Comparator: "Npu Name", "Bench Dtype", "NPU Dtype", "Shape", "Cosine Similarity", "Max Abs Error", - "Relative Error (hundredth)", + "Relative Error (dual hundredth)", "Relative Error (dual thousandth)", "Relative Error (dual ten thousandth)", "Error Rate", diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py index fb885168d662006b3fa639c2322113d6f1f6b386..c6c273fdd4f58ac0b4253246497fe1f03aaa5f1d 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/run_ut.py @@ -37,6 +37,21 @@ DETAILS_FILE_NAME = "accuracy_checking_details_" + current_time + ".csv" RunUTConfig = namedtuple('RunUTConfig', ['forward_content', 'backward_content', 'result_csv_path', 'details_csv_path', 'save_error_data', 'is_continue_run_ut', 'test_result_cnt']) +tqdm_params = { + 'smoothing': 0, # 平滑进度条的预计剩余时间,取值范围0到1 + 'desc': 'Processing', # 进度条前的描述文字 + 'leave': True, # 迭代完成后保留进度条的显示 + 'ncols': 75, # 进度条的固定宽度 + 'mininterval': 0.1, # 更新进度条的最小间隔秒数 + 'maxinterval': 1.0, # 更新进度条的最大间隔秒数 + 'miniters': 1, # 更新进度条之间的最小迭代次数 + 'ascii': None, # 根据环境自动使用ASCII或Unicode字符 + 'unit': 'it', # 迭代单位 + 'unit_scale': True, # 自动根据单位缩放 + 'dynamic_ncols': True, # 动态调整进度条宽度以适应控制台 + 'bar_format': '{l_bar}{bar}| {n}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]' # 自定义进度条输出格式 +} + def exec_api(api_type, api_name, args, kwargs): if api_type == "Functional": @@ -119,7 +134,7 @@ def run_ut(config): csv_reader = csv.reader(file) next(csv_reader) api_name_set = {row[0] for row in csv_reader} - for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items())): + for i, (api_full_name, api_info_dict) in enumerate(tqdm(config.forward_content.items(), **tqdm_params)): if api_full_name in api_name_set: continue try: