diff --git a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py index aea5675cef30b85e5e2603f1ae5dbbc7227499dc..9fdb8bbfaa7f0bf6b839b9197026a247acfa120b 100644 --- a/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py +++ b/debug/accuracy_tools/api_accuracy_checker/run_ut/data_generate.py @@ -67,10 +67,14 @@ def gen_real_tensor(data_path, convert_type): """ data_path = os.path.realpath(data_path) check_file_or_directory_path(data_path) - if not data_path.endswith('.pt'): - print_error_log(f"The file: {data_path} is not a pt file.") + if not data_path.endswith('.pt') and not data_path.endswith('.npy'): + print_error_log(f"The file: {data_path} is not a pt or numpy file.") raise CompareException.INVALID_FILE_ERROR - data = torch.load(data_path) + if data_path.endswith('.pt'): + data = torch.load(data_path) + else: + data_np = np.load(data_path) + data = torch.from_numpy(data_np) if convert_type: ori_dtype = Const.CONVERT.get(convert_type)[0] dist_dtype = Const.CONVERT.get(convert_type)[1]