diff --git a/docs/sample_code/golden/onnx_demo.py b/docs/sample_code/golden/onnx_demo.py index 959ecb81265f85e72e98ef8c241fd3ab9faf390f..0c2980c9a000c96f0b90653a9d306d2d8c9abba2 100644 --- a/docs/sample_code/golden/onnx_demo.py +++ b/docs/sample_code/golden/onnx_demo.py @@ -34,7 +34,8 @@ def generate_random_input_data_and_run_model(args): if args.inDataFile: # User specifies input input_data_files = args.inDataFile.split(':') - assert len(input_data_files) == len(input_tensors), "Shape of input data is not compatible with the model!" + if not len(input_data_files) == len(input_tensors): + raise ValueError("Shape of input data is not compatible with the model!") for i, input_tensor in enumerate(input_tensors): tensor_type = input_tensor.type[7:-1] if tensor_type == "float": @@ -48,12 +49,12 @@ def generate_random_input_data_and_run_model(args): if args.inputShape: stable_shape = ast.literal_eval(args.inputShape) shape_info = stable_shape[i] - assert len(input_data) == np.prod( - shape_info), "Shape of input data is not compatible with the input shape!" + if not len(input_data) == np.prod(shape_info): + raise ValueError("Shape of input data is not compatible with the input shape!") input_data = input_data.reshape(shape_info) else: - assert len(input_data) == np.prod( - input_tensor.shape), "Shape of input data is not compatible with the model!" + if not len(input_data) == np.prod(input_tensor.shape): + raise ValueError("Shape of input data is not compatible with the model!") input_data = input_data.reshape(input_tensor.shape) input_dict[input_tensor.name] = input_data else: # Generate random input and save