From 122d7490d2dc35878603beb61c900dbfb695aa12 Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 21 Sep 2023 11:28:47 +0800 Subject: [PATCH 1/2] fix --- .../api_accuracy_checker/run_ut/data_generate.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index aea5675cef..45f1096be6 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -67,10 +67,14 @@ def gen_real_tensor(data_path, convert_type): """ data_path = os.path.realpath(data_path) check_file_or_directory_path(data_path) - if not data_path.endswith('.pt'): - print_error_log(f"The file: {data_path} is not a pt file.") + if not data_path.endswith('.pt') or data_path.endswith('.npy'): + print_error_log(f"The file: {data_path} is not a pt or numpy file.") raise CompareException.INVALID_FILE_ERROR - data = torch.load(data_path) + if data_path.endswith('.pt'): + data = torch.load(data_path) + else: + data_np = np.load(data_path) + data = torch.from_numpy(data_np) if convert_type: ori_dtype = Const.CONVERT.get(convert_type)[0] dist_dtype = Const.CONVERT.get(convert_type)[1] -- Gitee From 74758a49e056f52bf63895ace189b56d631ddffe Mon Sep 17 00:00:00 2001 From: gitee Date: Thu, 21 Sep 2023 12:40:16 +0800 Subject: [PATCH 2/2] fix --- .../accuracy_tools/api_accuracy_checker/run_ut/data_generate.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 45f1096be6..9fdb8bbfaa 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -67,7 +67,7 @@ def gen_real_tensor(data_path, convert_type): """ data_path = os.path.realpath(data_path) check_file_or_directory_path(data_path) - if not data_path.endswith('.pt') or data_path.endswith('.npy'): + if not data_path.endswith('.pt') and not data_path.endswith('.npy'): print_error_log(f"The file: {data_path} is not a pt or numpy file.") raise CompareException.INVALID_FILE_ERROR if data_path.endswith('.pt'): -- Gitee