From 73761ed7dd6cd93e8ab6e88f39d717019ef01ac5 Mon Sep 17 00:00:00 2001 From: Cathy Wong Date: Thu, 9 Jun 2022 10:49:42 -0400 Subject: [PATCH] [MD] Transforms Unification - Minor Updates Update due to PR2763 new model HRNetW48_cls - research/cv/HRNetW48_cls/src/dataset.py --- research/cv/FaceAttribute/preprocess.py | 2 +- research/cv/FaceAttribute/src/dataset_train.py | 2 +- research/cv/HRNetW48_cls/src/dataset.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/research/cv/FaceAttribute/preprocess.py b/research/cv/FaceAttribute/preprocess.py index 1a2ebb578..70e226427 100644 --- a/research/cv/FaceAttribute/preprocess.py +++ b/research/cv/FaceAttribute/preprocess.py @@ -28,7 +28,7 @@ def eval_data_generator(args): dst_h = args.dst_h batch_size = 1 #attri_num = args.attri_num - transform_img = F2.Compose([F.Decode(True)), + transform_img = F2.Compose([F.Decode(True), F.Resize((dst_w, dst_h)), F.ToTensor(), F.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), is_hwc=False)]) diff --git a/research/cv/FaceAttribute/src/dataset_train.py b/research/cv/FaceAttribute/src/dataset_train.py index 79617bdb7..80d2cc297 100644 --- a/research/cv/FaceAttribute/src/dataset_train.py +++ b/research/cv/FaceAttribute/src/dataset_train.py @@ -28,7 +28,7 @@ def data_generator(args): batch_size = args.per_batch_size attri_num = args.attri_num max_epoch = args.max_epoch - transform_img = F2.Compose([F.Decode(True)), + transform_img = F2.Compose([F.Decode(True), F.Resize((dst_w, dst_h)), F.RandomHorizontalFlip(prob=0.5), F.ToTensor(), diff --git a/research/cv/HRNetW48_cls/src/dataset.py b/research/cv/HRNetW48_cls/src/dataset.py index 3a73f1c52..278153489 100644 --- a/research/cv/HRNetW48_cls/src/dataset.py +++ b/research/cv/HRNetW48_cls/src/dataset.py @@ -19,8 +19,8 @@ from mindspore import dataset as ds from mindspore.common import dtype as mstype from mindspore.communication.management import get_group_size from mindspore.communication.management import get_rank -from mindspore.dataset.transforms import c_transforms as C2 -from mindspore.dataset.vision import c_transforms as C +import mindspore.dataset.transforms as C2 +import mindspore.dataset.vision as C def create_imagenet( -- Gitee