diff --git "a/ACL_PyTorch/docs/ONNX/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" "b/ACL_PyTorch/docs/ONNX/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" index 20f5e6090084b0e98b58a3fbed6317a8bf48d173..c1b4f812913f3c5b5925f8a708d291be742e0716 100644 --- "a/ACL_PyTorch/docs/ONNX/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" +++ "b/ACL_PyTorch/docs/ONNX/04.\347\246\273\347\272\277\346\216\250\347\220\206.md" @@ -152,7 +152,7 @@ print(inferencer.e2e_inference("./ILSVRC2012_val_00006083.jpeg")) 2. 使用数据集测试模型精度 ```python - def evaluate(inferencer, image_dir, label_file): + def evaluate(inferencer, img_dir, label_file): groundtruth = {} with open(label_file, 'r') as f: for line in f: @@ -161,7 +161,7 @@ print(inferencer.e2e_inference("./ILSVRC2012_val_00006083.jpeg")) num_total, num_right = 0, 0 for i, image_name in tqdm.tqdm(enumerate(os.listdir(img_dir))): - label = gt_dict[image_name] + label = groundtruth[image_name] image_path = osp.join(img_dir, image_name) pred = inferencer.e2e_inference(image_path)['class_id'] num_total += 1