From 3c9ea1cd7c946a1fb2b872642a06dab87033370c Mon Sep 17 00:00:00 2001 From: gitee Date: Tue, 19 Sep 2023 12:58:38 +0800 Subject: [PATCH] fix --- .../api_accuracy_checker/run_ut/data_generate.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 28ca86793f..aea5675cef 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -67,11 +67,10 @@ def gen_real_tensor(data_path, convert_type): """ data_path = os.path.realpath(data_path) check_file_or_directory_path(data_path) - if not data_path.endswith('.npy'): - print_error_log(f"The file: {data_path} is not a numpy file.") + if not data_path.endswith('.pt'): + print_error_log(f"The file: {data_path} is not a pt file.") raise CompareException.INVALID_FILE_ERROR - data_np = np.load(data_path) - data = torch.from_numpy(data_np) + data = torch.load(data_path) if convert_type: ori_dtype = Const.CONVERT.get(convert_type)[0] dist_dtype = Const.CONVERT.get(convert_type)[1] -- Gitee