From fac5a5f3af9f17ad3bb9f3e9cb2d73c6a954f5a3 Mon Sep 17 00:00:00 2001 From: qian-dan <756328797@qq.com> Date: Wed, 16 Jul 2025 16:32:36 +0800 Subject: [PATCH] modify onnx_demo Signed-off-by: qian-dan <756328797@qq.com> --- docs/sample_code/golden/onnx_demo.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/docs/sample_code/golden/onnx_demo.py b/docs/sample_code/golden/onnx_demo.py index 959ecb8126..0c2980c9a0 100644 --- a/docs/sample_code/golden/onnx_demo.py +++ b/docs/sample_code/golden/onnx_demo.py @@ -34,7 +34,8 @@ def generate_random_input_data_and_run_model(args): if args.inDataFile: # User specifies input input_data_files = args.inDataFile.split(':') - assert len(input_data_files) == len(input_tensors), "Shape of input data is not compatible with the model!" + if not len(input_data_files) == len(input_tensors): + raise ValueError("Shape of input data is not compatible with the model!") for i, input_tensor in enumerate(input_tensors): tensor_type = input_tensor.type[7:-1] if tensor_type == "float": @@ -48,12 +49,12 @@ def generate_random_input_data_and_run_model(args): if args.inputShape: stable_shape = ast.literal_eval(args.inputShape) shape_info = stable_shape[i] - assert len(input_data) == np.prod( - shape_info), "Shape of input data is not compatible with the input shape!" + if not len(input_data) == np.prod(shape_info): + raise ValueError("Shape of input data is not compatible with the input shape!") input_data = input_data.reshape(shape_info) else: - assert len(input_data) == np.prod( - input_tensor.shape), "Shape of input data is not compatible with the model!" + if not len(input_data) == np.prod(input_tensor.shape): + raise ValueError("Shape of input data is not compatible with the model!") input_data = input_data.reshape(input_tensor.shape) input_dict[input_tensor.name] = input_data else: # Generate random input and save -- Gitee