diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index 28ca86793f1db03647df525ea5037b64623dde35..aea5675cef30b85e5e2603f1ae5dbbc7227499dc 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -67,11 +67,10 @@ def gen_real_tensor(data_path, convert_type): """ data_path = os.path.realpath(data_path) check_file_or_directory_path(data_path) - if not data_path.endswith('.npy'): - print_error_log(f"The file: {data_path} is not a numpy file.") + if not data_path.endswith('.pt'): + print_error_log(f"The file: {data_path} is not a pt file.") raise CompareException.INVALID_FILE_ERROR - data_np = np.load(data_path) - data = torch.from_numpy(data_np) + data = torch.load(data_path) if convert_type: ori_dtype = Const.CONVERT.get(convert_type)[0] dist_dtype = Const.CONVERT.get(convert_type)[1]