diff --git a/application_example/maskrcnn/src/dataset/__init__.py b/application_example/maskrcnn/src/dataset/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/application_example/maskrcnn/src/dataset/dataset.py b/application_example/maskrcnn/src/dataset/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd983ce7ee52c129eee1962654aeef3cae2a6b80
--- /dev/null
+++ b/application_example/maskrcnn/src/dataset/dataset.py
@@ -0,0 +1,825 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""MaskRcnn dataset"""
+from __future__ import division
+
+import os
+
+import cv2
+import numpy as np
+from numpy import random
+from pycocotools.coco import COCO
+from pycocotools import mask as maskHelper
+import mindspore.dataset as de
+import mindspore.dataset.vision.c_transforms as C
+from mindspore.mindrecord import FileWriter
+
+from utils.config import config
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou'):
+    """Calculate the IoU between each bbox of bboxes1 and each bbox of bboxes2.
+
+    Args:
+        bboxes1(array): shape (n, 4)
+        bboxes2(array): shape (k, 4)
+        mode(str): 'iou' (intersection over union) or 'iof' (intersection
+            over foreground)
+
+    Returns:
+        ious(array), shape (n, k)
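+
+    Example (illustrative; boxes follow the inclusive "+1" pixel convention
+    assumed by this function):
+        >>> b1 = np.array([[0, 0, 9, 9]], dtype=np.float32)
+        >>> b2 = np.array([[5, 5, 14, 14]], dtype=np.float32)
+        >>> ious = bbox_overlaps(b1, b2)  # overlap 25, union 175 -> ~0.143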
+    """
+    bboxes1 = bboxes1.astype(np.float32)
+    bboxes2 = bboxes2.astype(np.float32)
+    rows = bboxes1.shape[0]
+    cols = bboxes2.shape[0]
+    ious = np.zeros((rows, cols), dtype=np.float32)
+    if rows * cols == 0:
+        return ious
+    exchange = False
+    if bboxes1.shape[0] > bboxes2.shape[0]:
+        bboxes1, bboxes2 = bboxes2, bboxes1
+        ious = np.zeros((cols, rows), dtype=np.float32)
+        exchange = True
+    area1 = (bboxes1[:, 2] - bboxes1[:, 0] + 1) * (bboxes1[:, 3] - bboxes1[:, 1] + 1)
+    area2 = (bboxes2[:, 2] - bboxes2[:, 0] + 1) * (bboxes2[:, 3] - bboxes2[:, 1] + 1)
+    for i in range(bboxes1.shape[0]):
+        x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
+        y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
+        x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
+        y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
+        overlap = np.maximum(x_end - x_start + 1, 0) * np.maximum(
+            y_end - y_start + 1, 0)
+        if mode == 'iou':
+            union = area1[i] + area2 - overlap
+        else:
+            union = area1[i] if not exchange else area2
+        ious[i, :] = overlap / union
+    if exchange:
+        ious = ious.T
+    return ious
+
+
+class PhotoMetricDistortion:
+    """
+    Random photometric distortion.
+
+    Randomly adjusts the brightness, contrast, saturation and hue of an
+    image, and randomly permutes its channels.
+
+    Args:
+        brightness_delta (int): brightness jitter bound. Default: 32.
+        contrast_range (Tuple): contrast jitter range. Default: (0.5, 1.5).
+        saturation_range (Tuple): saturation jitter range. Default: (0.5, 1.5).
+        hue_delta (int): hue jitter bound. Default: 18.
+    """
+    def __init__(self, brightness_delta=32, contrast_range=(0.5, 1.5), saturation_range=(0.5, 1.5), hue_delta=18):
+        self.brightness_delta = brightness_delta
+        self.contrast_lower, self.contrast_upper = contrast_range
+        self.saturation_lower, self.saturation_upper = saturation_range
+        self.hue_delta = hue_delta
+
+    def __call__(self, img, boxes, labels):
+        # random brightness
+        img = img.astype('float32')
+
+        if random.randint(2):
+            delta = random.uniform(-self.brightness_delta, self.brightness_delta)
+            img += delta
+
+        # mode == 0 --> do random contrast first
+        # mode == 1 --> do random contrast last
+        mode = random.randint(2)
+        if mode == 1:
+            if random.randint(2):
+                alpha = random.uniform(self.contrast_lower, self.contrast_upper)
+                img *= alpha
+
+        # convert color from BGR to HSV
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
+
+        # random saturation
+        if random.randint(2):
+            img[..., 1] *= random.uniform(self.saturation_lower, self.saturation_upper)
+
+        # random hue
+        if random.randint(2):
+            img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
+            img[..., 0][img[..., 0] > 360] -= 360
+            img[..., 0][img[..., 0] < 0] += 360
+
+        # convert color from HSV to BGR
+        img = cv2.cvtColor(img, cv2.COLOR_HSV2BGR)
+
+        # random contrast
+        if mode == 0:
+            if random.randint(2):
+                alpha = random.uniform(self.contrast_lower, self.contrast_upper)
+                img *= alpha
+
+        # randomly swap channels
+        if random.randint(2):
+            img = img[..., random.permutation(3)]
+
+        return img, boxes, labels
+
+
+class Expand:
+    """
+    Randomly expand an image, its boxes and its masks onto a larger canvas
+    filled with the channel mean.
+
+    Args:
+        mean (Tuple): fill value for the expanded canvas. Default: (0, 0, 0).
+        to_rgb (bool): if True, reverse the channel order of `mean` to match
+            the BGR image. Default: True.
+        ratio_range (Tuple): expansion ratio range. Default: (1, 4).
+    """
+    def __init__(self, mean=(0, 0, 0), to_rgb=True, ratio_range=(1, 4)):
+        if to_rgb:
+            self.mean = mean[::-1]
+        else:
+            self.mean = mean
+        self.min_ratio, self.max_ratio = ratio_range
+
+    def __call__(self, img, boxes, labels, mask):
+        if random.randint(2):
+            return img, boxes, labels, mask
+
+        h, w, c = img.shape
+        ratio = random.uniform(self.min_ratio, self.max_ratio)
+        expand_img = np.full((int(h * ratio), int(w * ratio), c), self.mean).astype(img.dtype)
+        left = int(random.uniform(0, w * ratio - w))
+        top = int(random.uniform(0, h * ratio - h))
+        expand_img[top:top + h, left:left + w] = img
+        img = expand_img
+        boxes += np.tile((left, top), 2)
+
+        mask_count, mask_h, mask_w = mask.shape
+        expand_mask = np.zeros((mask_count, int(mask_h * ratio), int(mask_w * ratio))).astype(mask.dtype)
+        expand_mask[:, top:top + h, left:left + w] = mask
+        mask = expand_mask
+
+        return img, boxes, labels, mask
+
+
+def rescale_with_tuple(img, scale):
+    """
+    Rescale an image to fit a scale bound, keeping its aspect ratio.
+
+    Args:
+        img(Array): image.
+        scale(Tuple): target size bound; the factor is the largest one for
+            which both the long and the short side still fit.
+
+    Returns:
+        rescaled_img, Array, rescaled image.
+        scale_factor, float, scaling factor.
+    """
+    height, width = img.shape[:2]
+    scale_factor = min(max(scale) / max(height, width), min(scale) / min(height, width))
+    new_size = int(width * float(scale_factor) + 0.5), int(height * float(scale_factor) + 0.5)
+    rescaled_img = cv2.resize(img, new_size, interpolation=cv2.INTER_LINEAR)
+
+    return rescaled_img, scale_factor
+
+
+def rescale_with_factor(img, scale_factor):
+    """
+    Rescale an image by a scalar factor.
+
+    Args:
+        img(Array): image.
+        scale_factor(float): scale coefficient.
+
+    Returns:
+        Array, resized image.
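+
+    Example (a minimal sketch with hypothetical shapes):
+        >>> img = np.zeros((100, 200, 3), dtype=np.uint8)
+        >>> rescale_with_factor(img, 0.5).shape
+        (50, 100, 3)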
+ """ + height, width = img.shape[:2] + new_size = int(width * float(scale_factor) + 0.5), int(height * float(scale_factor) + 0.5) + return cv2.resize(img, new_size, interpolation=cv2.INTER_NEAREST) + + +def rescale_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """ + Rescale operation for image + + Args: + img(Array): image + img_shape(Tuple): image shape + gt_bboxes(Array): GT bounding box + gt_label(Array): GT label + gt_num(int): GT number + gt_mask(Array): GT mask array. + + Returns: + Tuple, A tuple of (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + """ + img_data, scale_factor = rescale_with_tuple(img, (config.img_width, config.img_height)) + if img_data.shape[0] > config.img_height: + img_data, scale_factor2 = rescale_with_tuple(img_data, (config.img_height, config.img_height)) + scale_factor = scale_factor*scale_factor2 + + gt_bboxes = gt_bboxes * scale_factor + gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_data.shape[1] - 1) + gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_data.shape[0] - 1) + + gt_mask_data = np.array([rescale_with_factor(mask, scale_factor) for mask in gt_mask]) + + pad_h = config.img_height - img_data.shape[0] + pad_w = config.img_width - img_data.shape[1] + assert ((pad_h >= 0) and (pad_w >= 0)) + + pad_img_data = np.zeros((config.img_height, config.img_width, 3)).astype(img_data.dtype) + pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data + + mask_count, mask_h, mask_w = gt_mask_data.shape + pad_mask = np.zeros((mask_count, config.img_height, config.img_width)).astype(gt_mask_data.dtype) + pad_mask[:, 0:mask_h, 0:mask_w] = gt_mask_data + + img_shape = (config.img_height, config.img_width, 1.0) + img_shape = np.asarray(img_shape, dtype=np.float32) + + return (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num, pad_mask) + + +def rescale_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """ + rescale operation for image of eval + + Args: + img(Array): image + img_shape(Tuple): image shape + gt_bboxes(Array): GT bounding box + gt_label(Array): GT label + gt_num(int): GT number + gt_mask(Array): GT mask array. + + Returns: + Tuple, A tuple of (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + """ + img_data, scale_factor = rescale_with_tuple(img, (config.img_width, config.img_height)) + if img_data.shape[0] > config.img_height: + img_data, scale_factor2 = rescale_with_tuple(img_data, (config.img_height, config.img_height)) + scale_factor = scale_factor*scale_factor2 + + pad_h = config.img_height - img_data.shape[0] + pad_w = config.img_width - img_data.shape[1] + assert ((pad_h >= 0) and (pad_w >= 0)) + + pad_img_data = np.zeros((config.img_height, config.img_width, 3)).astype(img_data.dtype) + pad_img_data[0:img_data.shape[0], 0:img_data.shape[1], :] = img_data + + img_shape = np.append(img_shape, (scale_factor, scale_factor)) + img_shape = np.asarray(img_shape, dtype=np.float32) + + return (pad_img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + + +def resize_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """ + resize operation for image + + Args: + img(Array): image + img_shape(Tuple): image shape + gt_bboxes(Array): GT bounding box + gt_label(Array): GT label + gt_num(int): GT number + gt_mask(Array): GT mask array. 
+ + Returns: + Tuple, A tuple of (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + """ + img_data = img + h, w = img_data.shape[:2] + img_data = cv2.resize(img_data, (config.img_width, config.img_height), interpolation=cv2.INTER_LINEAR) + h_scale = config.img_height / h + w_scale = config.img_width / w + + scale_factor = np.array([w_scale, h_scale, w_scale, h_scale], dtype=np.float32) + img_shape = (config.img_height, config.img_width, 1.0) + img_shape = np.asarray(img_shape, dtype=np.float32) + + gt_bboxes = gt_bboxes * scale_factor + gt_bboxes[:, 0::2] = np.clip(gt_bboxes[:, 0::2], 0, img_shape[1] - 1) + gt_bboxes[:, 1::2] = np.clip(gt_bboxes[:, 1::2], 0, img_shape[0] - 1) + + gt_mask_data = np.array([cv2.resize(mask, (config.img_width, config.img_height), + interpolation=cv2.INTER_NEAREST) for mask in gt_mask]) + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask_data) + + +def resize_column_test(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """ + resize operation for image of eval + + Args: + img(Array): image + img_shape(Tuple): image shape + gt_bboxes(Array): GT bounding box + gt_label(Array): GT label + gt_num(int): GT number + gt_mask(Array): GT mask array. + + Returns: + Tuple, A tuple of (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + """ + img_data = img + h, w = img_data.shape[:2] + img_data = cv2.resize(img_data, (config.img_width, config.img_height), interpolation=cv2.INTER_LINEAR) + h_scale = config.img_height / h + w_scale = config.img_width / w + + img_shape = np.append(img_shape, (h_scale, w_scale)) + img_shape = np.asarray(img_shape, dtype=np.float32) + + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + + +def impad_to_multiple_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """ + impad operation for image + + Args: + img(Array): image + img_shape(Tuple): image shape + gt_bboxes(Array): GT bounding box + gt_label(Array): GT label + gt_num(int): GT number + gt_mask(Array): GT mask array. + + Returns: + Tuple, A tuple of (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + """ + img_data = cv2.copyMakeBorder(img, 0, config.img_height - img.shape[0], 0, + config.img_width - img.shape[1], + cv2.BORDER_CONSTANT, value=0) + img_data = img_data.astype(np.float32) + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + + +def imnormalize_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """ + imnormalize operation for image + + Args: + img(Array): image + img_shape(Tuple): image shape + gt_bboxes(Array): GT bounding box + gt_label(Array): GT label + gt_num(int): GT number + gt_mask(Array): GT mask array. + + Returns: + Tuple, A tuple of (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + """ + mean = np.asarray([123.675, 116.28, 103.53]) + std = np.asarray([58.395, 57.12, 57.375]) + img_data = img.copy().astype(np.float32) + cv2.cvtColor(img_data, cv2.COLOR_BGR2RGB, img_data) + cv2.subtract(img_data, np.float64(mean.reshape(1, -1)), img_data) + cv2.multiply(img_data, 1 / np.float64(std.reshape(1, -1)), img_data) + + img_data = img_data.astype(np.float32) + return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask) + + +def flip_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask): + """ + flip operation for image + + Args: + img(Array): image + img_shape(Tuple): image shape + gt_bboxes(Array): GT bounding box + gt_label(Array): GT label + gt_num(int): GT number + gt_mask(Array): GT mask array. 
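+
+    Note:
+        Boxes are flipped with the inclusive pixel convention used in this
+        file: x1' = w - x2 - 1 and x2' = w - x1 - 1. For example, with
+        w = 4 a box spanning x in [0, 1] maps to x in [2, 3]; y coordinates
+        are unchanged.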
+
+    Returns:
+        Tuple, A tuple of (img_data, img_shape, flipped, gt_label, gt_num, gt_mask_data)
+    """
+    img_data = img
+    img_data = np.flip(img_data, axis=1)
+    flipped = gt_bboxes.copy()
+    _, w, _ = img_data.shape
+
+    flipped[..., 0::4] = w - gt_bboxes[..., 2::4] - 1
+    flipped[..., 2::4] = w - gt_bboxes[..., 0::4] - 1
+
+    gt_mask_data = np.array([mask[:, ::-1] for mask in gt_mask])
+
+    return (img_data, img_shape, flipped, gt_label, gt_num, gt_mask_data)
+
+
+def transpose_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
+    """
+    transpose operation for image
+
+    Args:
+        img(Array): image
+        img_shape(Tuple): image shape
+        gt_bboxes(Array): GT bounding box
+        gt_label(Array): GT label
+        gt_num(int): GT number
+        gt_mask(Array): GT mask array.
+
+    Returns:
+        Tuple, A tuple of (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask_data)
+    """
+    img_data = img.transpose(2, 0, 1).copy()
+    img_data = img_data.astype(np.float32)
+    img_shape = img_shape.astype(np.float32)
+    gt_bboxes = gt_bboxes.astype(np.float32)
+    gt_label = gt_label.astype(np.int32)
+    gt_num = gt_num.astype(np.bool)
+    gt_mask_data = gt_mask.astype(np.bool)
+
+    return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask_data)
+
+
+def photo_crop_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
+    """
+    photometric distortion operation for image
+
+    Args:
+        img(Array): image
+        img_shape(Tuple): image shape
+        gt_bboxes(Array): GT bounding box
+        gt_label(Array): GT label
+        gt_num(int): GT number
+        gt_mask(Array): GT mask array.
+
+    Returns:
+        Tuple, A tuple of (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
+    """
+    random_photo = PhotoMetricDistortion()
+    img_data, gt_bboxes, gt_label = random_photo(img, gt_bboxes, gt_label)
+
+    return (img_data, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
+
+
+def expand_column(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask):
+    """
+    expand operation for image
+
+    Args:
+        img(Array): image
+        img_shape(Tuple): image shape
+        gt_bboxes(Array): GT bounding box
+        gt_label(Array): GT label
+        gt_num(int): GT number
+        gt_mask(Array): GT mask array.
+
+    Returns:
+        Tuple, A tuple of (img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
+    """
+    expand = Expand()
+    img, gt_bboxes, gt_label, gt_mask = expand(img, gt_bboxes, gt_label, gt_mask)
+
+    return (img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask)
+
+
+def pad_to_max(img, img_shape, gt_bboxes, gt_label, gt_num, gt_mask, instance_count):
+    """
+    pad the annotations up to the configured max instance count
+
+    Args:
+        img(Array): image
+        img_shape(Tuple): image shape
+        gt_bboxes(Array): GT bounding box
+        gt_label(Array): GT label
+        gt_num(int): GT number
+        gt_mask(Array): GT mask array.
+        instance_count(int): number of real instances in this image.
+
+    Returns:
+        Tuple, A tuple of (img, img_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert, gt_mask)
+    """
+    pad_max_number = config.max_instance_count
+    gt_box_new = np.pad(gt_bboxes, ((0, pad_max_number - instance_count), (0, 0)), mode="constant", constant_values=0)
+    gt_label_new = np.pad(gt_label, ((0, pad_max_number - instance_count)), mode="constant", constant_values=-1)
+    gt_iscrowd_new = np.pad(gt_num, ((0, pad_max_number - instance_count)), mode="constant", constant_values=1)
+    gt_iscrowd_new_revert = ~(gt_iscrowd_new.astype(np.bool))
+
+    return img, img_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert, gt_mask
+
+
+def preprocess_fn(image, box, mask, mask_shape, is_training):
+    """
+    Preprocess function for dataset.
+
+    Args:
+        image(Array): image
+        box(Array): GT boxes with class label and iscrowd flag appended, shape (n, 6)
+        mask(Array): GT mask array.
+ mask_shape(Tuple): mask shape + is_training(bool): param for training. + + Returns: + Array, augmented data. + """ + def _infer_data(image_bgr, image_shape, gt_box_new, gt_label_new, + gt_iscrowd_new_revert, gt_mask_new, instance_count): + """ + inference on data. + + Args: + image_bgr(Array): BGR image + image_shape(tuple): image shape. + gt_box_new(Array): processed GT bounding box. + gt_label_new(Array): processed GT label. + gt_iscrowd_new_revert(Array): a column of gt bounding box. + gt_mask_new(Array): processed GT mask. + instance_count(int): instance number + + Returns: + Tuple, processed data + """ + image_shape = image_shape[:2] + input_data = image_bgr, image_shape, gt_box_new, gt_label_new, gt_iscrowd_new_revert, gt_mask_new + + if config.keep_ratio: + input_data = rescale_column_test(*input_data) + else: + input_data = resize_column_test(*input_data) + input_data = imnormalize_column(*input_data) + + input_data = pad_to_max(*input_data, instance_count) + output_data = transpose_column(*input_data) + return output_data + + def _data_aug(image, box, mask, mask_shape, is_training): + """ + Data augmentation function. + + Args: + img(Array): image + box(Array): evaluated bounding box + mask(Array): GT mask array. + mask_shape(Tuple): mask shape + is_training(bool): param for training. + Returns: + Tuple, augmented data. + """ + image_bgr = image.copy() + image_bgr[:, :, 0] = image[:, :, 2] + image_bgr[:, :, 1] = image[:, :, 1] + image_bgr[:, :, 2] = image[:, :, 0] + image_shape = image_bgr.shape[:2] + instance_count = box.shape[0] + gt_box = box[:, :4] + gt_label = box[:, 4] + gt_iscrowd = box[:, 5] + gt_mask = mask.copy() + n, h, w = mask_shape + gt_mask = gt_mask.reshape(n, h, w) + assert n == box.shape[0] + + if not is_training: + return _infer_data(image_bgr, image_shape, gt_box, gt_label, gt_iscrowd, gt_mask, instance_count) + + flip = (np.random.rand() < config.flip_ratio) + expand = (np.random.rand() < config.expand_ratio) + + input_data = image_bgr, image_shape, gt_box, gt_label, gt_iscrowd, gt_mask + + if expand: + input_data = expand_column(*input_data) + if config.keep_ratio: + input_data = rescale_column(*input_data) + else: + input_data = resize_column(*input_data) + + input_data = imnormalize_column(*input_data) + if flip: + input_data = flip_column(*input_data) + + input_data = pad_to_max(*input_data, instance_count) + output_data = transpose_column(*input_data) + return output_data + + return _data_aug(image, box, mask, mask_shape, is_training) + + +def ann_to_mask(ann, height, width): + """ + Convert annotation to RLE and then to binary mask. + + Args: + ann(bool): annotations + height(int): mask height + width(int): mask width + + Returns: + Array, a mask. + """ + segm = ann['segmentation'] + if isinstance(segm, list): + rles = maskHelper.frPyObjects(segm, height, width) + rle = maskHelper.merge(rles) + elif isinstance(segm['counts'], list): + rle = maskHelper.frPyObjects(segm, height, width) + else: + rle = ann['segmentation'] + m = maskHelper.decode(rle) + return m + + +def create_coco_label(is_training): + """ + Get image path and annotation from COCO. + + Args: + is_training: param for training + + Returns: + Tuple, a tuple of (image_files, image_anno_dict, masks, masks_shape) + """ + coco_root = config.data_root + data_type = config.val_data_type + if is_training: + data_type = config.train_data_type + + # Classes need to train or test. 
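+    # Only category names listed in config.data_classes are kept; each kept
+    # name is mapped to a contiguous training id below (index 0 is presumably
+    # the background entry of data_classes).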
+    train_cls = config.data_classes
+
+    train_cls_dict = {}
+    for i, cls in enumerate(train_cls):
+        train_cls_dict[cls] = i
+
+    anno_json = os.path.join(coco_root, config.instance_set.format(data_type))
+    print(anno_json)
+
+    coco = COCO(anno_json)
+    classes_dict = {}
+    cat_ids = coco.loadCats(coco.getCatIds())
+    for cat in cat_ids:
+        classes_dict[cat["id"]] = cat["name"]
+
+    image_ids = coco.getImgIds()
+    image_files = []
+    image_anno_dict = {}
+    masks = {}
+    masks_shape = {}
+    images_num = len(image_ids)
+    print('images num:', images_num)
+    for ind, img_id in enumerate(image_ids):
+        image_info = coco.loadImgs(img_id)
+        file_name = image_info[0]["file_name"]
+        image_path = os.path.join(coco_root, data_type, file_name)
+        if not os.path.isfile(image_path):
+            print("{}/{}: {} is in annotations but does not exist".format(ind + 1, images_num, image_path))
+            continue
+        anno_ids = coco.getAnnIds(imgIds=img_id, iscrowd=None)
+        anno = coco.loadAnns(anno_ids)
+
+        annos = []
+        instance_masks = []
+        image_height = coco.imgs[img_id]["height"]
+        image_width = coco.imgs[img_id]["width"]
+        if (ind + 1) % 10 == 0:
+            print("{}/{}: parsing annotation for image={}".format(ind + 1, images_num, file_name))
+        if not is_training:
+            image_files.append(image_path)
+            image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
+            masks[image_path] = np.zeros([1, 1, 1], dtype=np.bool).tobytes()
+            masks_shape[image_path] = np.array([1, 1, 1], dtype=np.int32)
+        else:
+            for label in anno:
+                bbox = label["bbox"]
+                class_name = classes_dict[label["category_id"]]
+                if class_name in train_cls:
+                    # get coco mask
+                    m = ann_to_mask(label, image_height, image_width)
+                    if m.max() < 1:
+                        print("all-black mask, skipped")
+                        continue
+                    # Use a full-size mask for crowd annotations whose mask
+                    # does not match the image size
+                    if label['iscrowd'] and (m.shape[0] != image_height or
+                                             m.shape[1] != image_width):
+                        m = np.ones([image_height, image_width], dtype=np.bool)
+                    instance_masks.append(m)
+
+                    # get coco bbox
+                    x1, x2 = bbox[0], bbox[0] + bbox[2]
+                    y1, y2 = bbox[1], bbox[1] + bbox[3]
+                    annos.append([x1, y1, x2, y2] + [train_cls_dict[class_name]] + [int(label["iscrowd"])])
+                else:
+                    print("not in classes: ", class_name)
+
+            image_files.append(image_path)
+            if annos:
+                image_anno_dict[image_path] = np.array(annos)
+                instance_masks = np.stack(instance_masks, axis=0).astype(np.bool)
+                masks[image_path] = np.array(instance_masks).tobytes()
+                masks_shape[image_path] = np.array(instance_masks.shape, dtype=np.int32)
+            else:
+                print("no annotations for image ", file_name)
+                image_anno_dict[image_path] = np.array([0, 0, 0, 0, 0, 1])
+                masks[image_path] = np.zeros([1, image_height, image_width], dtype=np.bool).tobytes()
+                masks_shape[image_path] = np.array([1, image_height, image_width], dtype=np.int32)
+
+    return image_files, image_anno_dict, masks, masks_shape
+
+
+def data_to_mindrecord_byte_image(dataset="coco", is_training=True, prefix="maskrcnn.mindrecord", file_num=8):
+    """
+    Create MindRecord file.
+
+    Args:
+        dataset(str): dataset name. default: "coco"
+        is_training(bool): whether to convert the training split. default: True
+        prefix(str): a prefix of mindrecord files. default: "maskrcnn.mindrecord"
+        file_num(int): file number.
default: 8 + """ + mindrecord_dir = config.mindrecord_dir + mindrecord_path = os.path.join(mindrecord_dir, prefix) + print(mindrecord_path) + + writer = FileWriter(mindrecord_path, file_num) + if dataset == "coco": + image_files, image_anno_dict, masks, masks_shape = create_coco_label(is_training) + else: + print("Error unsupported other dataset") + return + + maskrcnn_json = { + "image": {"type": "bytes"}, + "annotation": {"type": "int32", "shape": [-1, 6]}, + "mask": {"type": "bytes"}, + "mask_shape": {"type": "int32", "shape": [-1]}, + } + writer.add_schema(maskrcnn_json, "maskrcnn_json") + + image_files_num = len(image_files) + for ind, image_name in enumerate(image_files): + with open(image_name, 'rb') as f: + img = f.read() + annos = np.array(image_anno_dict[image_name], dtype=np.int32) + mask = masks[image_name] + mask_shape = masks_shape[image_name] + row = {"image": img, "annotation": annos, + "mask": mask, "mask_shape": mask_shape} + if (ind + 1) % 10 == 0: + print("writing {}/{} into mindrecord".format(ind + 1, image_files_num)) + writer.write_raw_data([row]) + writer.commit() + + +def create_coco_dataset(mindrecord_file, batch_size=2, device_num=1, rank_id=0, + is_training=True, num_parallel_workers=8): + """ + Create MaskRcnn COCO dataset with MindDataset. + + Args: + mindrecord_file(str): mindrecord file path. + batch_size(int): batch size. default: 2. + device_num(int): processor number. default: 1. + rank_id(int): processor id. default: 0. + is_training(bool): param for training. default: True + num_parallel_workers(int): number of parallel workers. default: 8 + + Returns: + Tuple, the dataset. + """ + cv2.setNumThreads(0) + de.config.set_prefetch_size(8) + ds = de.MindDataset(mindrecord_file, + columns_list=["image", "annotation", "mask", "mask_shape"], + num_shards=device_num, shard_id=rank_id, + num_parallel_workers=4, shuffle=is_training) + decode = C.Decode() + ds = ds.map(operations=decode, input_columns=["image"]) + compose_map_func = (lambda image, annotation, mask, mask_shape: + preprocess_fn(image, annotation, mask, mask_shape, is_training)) + + if is_training: + ds = ds.map(operations=compose_map_func, + input_columns=["image", "annotation", "mask", "mask_shape"], + output_columns=["image", "image_shape", "box", "label", "valid_num", "mask"], + column_order=["image", "image_shape", "box", "label", "valid_num", "mask"], + python_multiprocessing=False, num_parallel_workers=num_parallel_workers) + ds = ds.batch(batch_size, drop_remainder=True, pad_info={"mask": ([config.max_instance_count, None, None], 0)}) + + else: + ds = ds.map(operations=compose_map_func, + input_columns=["image", "annotation", "mask", "mask_shape"], + output_columns=["image", "image_shape", "box", "label", "valid_num", "mask"], + column_order=["image", "image_shape", "box", "label", "valid_num", "mask"], + num_parallel_workers=num_parallel_workers) + ds = ds.batch(batch_size, drop_remainder=True) + + return ds diff --git a/application_example/maskrcnn/src/model/__init__.py b/application_example/maskrcnn/src/model/__init__.py index a19715213645bf636fd06d1563e5caac7ce9a85f..9846a1f2e0c9fbbf6b4c1ed22188e9614000d042 100644 --- a/application_example/maskrcnn/src/model/__init__.py +++ b/application_example/maskrcnn/src/model/__init__.py @@ -14,16 +14,16 @@ # ============================================================================ """MaskRcnn Init.""" -from src.model.resnet50 import ResNetFea, ResidualBlockUsing -from src.model.bbox_assign_sample import BboxAssignSample -from 
src.model.bbox_assign_sample_stage2 import BboxAssignSampleForRcnn -from src.model.fpn_neck import FeatPyramidNeck -from src.model.proposal_generator import Proposal -from src.model.rcnn_cls import RcnnCls -from src.model.rcnn_mask import RcnnMask -from src.model.rpn import RPN -from src.model.roi_align import SingleRoIExtractor -from src.model.anchor_generator import AnchorGenerator +from model.resnet50 import ResNetFea, ResidualBlockUsing +from model.bbox_assign_sample import BboxAssignSample +from model.bbox_assign_sample_stage2 import BboxAssignSampleForRcnn +from model.fpn_neck import FeatPyramidNeck +from model.proposal_generator import Proposal +from model.rcnn_cls import RcnnCls +from model.rcnn_mask import RcnnMask +from model.rpn import RPN +from model.roi_align import SingleRoIExtractor +from model.anchor_generator import AnchorGenerator __all__ = [ "ResNetFea", "BboxAssignSample", "BboxAssignSampleForRcnn", diff --git a/application_example/maskrcnn/src/model/bbox_assign_sample.py b/application_example/maskrcnn/src/model/bbox_assign_sample.py index 7ed9e0c4c0e992ff3f4405afe3f8357311046366..58e14ba8838eb3353ea300299598ab5eca8a2ea0 100644 --- a/application_example/maskrcnn/src/model/bbox_assign_sample.py +++ b/application_example/maskrcnn/src/model/bbox_assign_sample.py @@ -13,6 +13,7 @@ # limitations under the License. # ============================================================================ """MaskRcnn positive and negative sample screening for RPN.""" + import numpy as np import mindspore.nn as nn from mindspore.ops import operations as P @@ -88,7 +89,6 @@ class BboxAssignSample(nn.Cell): self.reshape = P.Reshape() self.equal = P.Equal() self.bounding_box_encode = P.BoundingBoxEncode(means=(0.0, 0.0, 0.0, 0.0), stds=(1.0, 1.0, 1.0, 1.0)) - self.scatter_nd_update = P.scatter_nd_update() self.scatter_nd = P.ScatterNd() self.logicalnot = P.LogicalNot() self.tile = P.Tile() diff --git a/application_example/maskrcnn/src/model/mask_rcnn_mobilenetv1.py b/application_example/maskrcnn/src/model/mask_rcnn_mobilenetv1.py index ae8adcd16ae0c6335f2344aacab144a1f802eca7..b4edddd4d802d0cd548d4679f9ba1b9c788c96ed 100644 --- a/application_example/maskrcnn/src/model/mask_rcnn_mobilenetv1.py +++ b/application_example/maskrcnn/src/model/mask_rcnn_mobilenetv1.py @@ -21,15 +21,15 @@ from mindspore.common.tensor import Tensor from mindspore.ops import functional as F from mindspore import context -from src.model.mobilenetv1 import MobileNetV1FeatureSelector -from src.model.bbox_assign_sample_stage2 import BboxAssignSampleForRcnn -from src.model.fpn_neck import FeatPyramidNeck -from src.model.proposal_generator import Proposal -from src.model.rcnn_cls import RcnnCls -from src.model.rcnn_mask import RcnnMask -from src.model.rpn import RPN -from src.model.roi_align import SingleRoIExtractor -from src.model.anchor_generator import AnchorGenerator +from model.mobilenetv1 import MobileNetV1FeatureSelector +from model.bbox_assign_sample_stage2 import BboxAssignSampleForRcnn +from model.fpn_neck import FeatPyramidNeck +from model.proposal_generator import Proposal +from model.rcnn_cls import RcnnCls +from model.rcnn_mask import RcnnMask +from model.rpn import RPN +from model.roi_align import SingleRoIExtractor +from model.anchor_generator import AnchorGenerator class MaskRcnnMobilenetv1(nn.Cell): diff --git a/application_example/maskrcnn/src/model/mask_rcnn_r50.py b/application_example/maskrcnn/src/model/mask_rcnn_r50.py index 
09f445c33bf3484c83c920b1c581109a6d91b49b..7f55e0277557fcd7f57a48ee4f55e497f712b051 100644
--- a/application_example/maskrcnn/src/model/mask_rcnn_r50.py
+++ b/application_example/maskrcnn/src/model/mask_rcnn_r50.py
@@ -20,15 +20,15 @@ from mindspore.ops import operations as P
 from mindspore.common.tensor import Tensor
 from mindspore.ops import functional as F
 
-from src.model.bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
-from src.model.fpn_neck import FeatPyramidNeck
-from src.model.proposal_generator import Proposal
-from src.model.rcnn_cls import RcnnCls
-from src.model.rcnn_mask import RcnnMask
-from src.model.rpn import RPN
-from src.model.roi_align import SingleRoIExtractor
-from src.model.anchor_generator import AnchorGenerator
-from src.model.resnet50 import ResNetFea, ResidualBlockUsing
+from model.bbox_assign_sample_stage2 import BboxAssignSampleForRcnn
+from model.fpn_neck import FeatPyramidNeck
+from model.proposal_generator import Proposal
+from model.rcnn_cls import RcnnCls
+from model.rcnn_mask import RcnnMask
+from model.rpn import RPN
+from model.roi_align import SingleRoIExtractor
+from model.anchor_generator import AnchorGenerator
+from model.resnet50 import ResNetFea, ResidualBlockUsing
 
 
 class MaskRcnnResnet50(nn.Cell):
diff --git a/application_example/maskrcnn/src/model/roi_align.py b/application_example/maskrcnn/src/model/roi_align.py
index b13ecd16d25cd3d9647dbb49616131e04854ce7b..90711ca4e5354151c724e87e41aa042890c1729c 100644
--- a/application_example/maskrcnn/src/model/roi_align.py
+++ b/application_example/maskrcnn/src/model/roi_align.py
@@ -1,4 +1,4 @@
-# Copyright 2022 Huawei Technologies Co., Ltd
+# Copyright 2020-2021 Huawei Technologies Co., Ltd
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,6 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """MaskRcnn ROIAlign module."""
+
 import numpy as np
 import mindspore.nn as nn
 import mindspore.common.dtype as mstype
@@ -27,32 +28,18 @@ class ROIAlign(nn.Cell):
     Extract RoI features from mulitiple feature map.
 
     Args:
-        out_size_h (int): RoI height.
-        out_size_w (int): RoI width.
-        spatial_scale (int): RoI spatial scale.
-        sample_num (int): RoI sample number. Default: 0.
-        roi_align_mode (int): RoI align mode. Default: 1.
-
-    Inputs:
-        - **features** (Tensor) - The input features, whose shape must be :math:'(N, C, H, W)'.
-        - **rois** (Tensor) - The shape is :math:'(rois_n, 5)'. With data type of float16 or float32.
-
-    Outputs:
-        Tensor, the shape is :math: '(rois_n, C, pooled_height, pooled_width)'.
-
-    Support Platform:
-        ``Ascend`` ``CPU`` ``GPU``
-
-    Examples:
-        >>> features = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
-        >>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32)
-        >>> roi_align = ops.ROIAlign(2, 2, 0.5, 2)
-        >>> output = roi_align(features, rois)
-        >>> print(output)
-        [[[[1.775 2.025]
-        [2.275 2.525]]]]
+        out_size_h (int): RoI height.
+        out_size_w (int): RoI width.
+        spatial_scale (int): RoI spatial scale.
+        sample_num (int): RoI sample number. Default: 0.
+        roi_align_mode (int): RoI align mode. Default: 1.
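+
+    Example (an illustrative sketch; exact outputs depend on the underlying
+    mindspore.ops.ROIAlign kernel):
+        >>> features = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mstype.float32)
+        >>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mstype.float32)
+        >>> roi_align = ROIAlign(2, 2, 0.5, 2)
+        >>> output = roi_align(features, rois)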
     """
-    def __init__(self, out_size_h, out_size_w, spatial_scale, sample_num=0, roi_align_mode=1):
+    def __init__(self,
+                 out_size_h,
+                 out_size_w,
+                 spatial_scale,
+                 sample_num=0,
+                 roi_align_mode=1):
         super(ROIAlign, self).__init__()
 
         self.out_size = (out_size_h, out_size_w)
@@ -68,8 +55,8 @@ class ROIAlign(nn.Cell):
 
     def __repr__(self):
         format_str = self.__class__.__name__
-        format_str += \
-            '(out_size={}, spatial_scale={}, sample_num={}'.format(self.out_size, self.spatial_scale, self.sample_num)
+        format_str += '(out_size={}, spatial_scale={}, sample_num={}'.format(
+            self.out_size, self.spatial_scale, self.sample_num)
         return format_str
 
 
@@ -77,49 +64,37 @@ class SingleRoIExtractor(nn.Cell):
     """
     Extract RoI features from a single level feature map.
 
-    If there are multiple input feature levels, each RoI is mapped to a level according to its scale.
+    If there are multiple input feature levels, each RoI is mapped to a level
+    according to its scale.
 
     Args:
         config (dict): Config
+        roi_layer (dict): Specify RoI layer type and arguments.
         out_channels (int): Output channels of RoI layers.
         featmap_strides (int): Strides of input feature maps.
-        batch_size (int): Batchsize. Default: 1.
-        finest_scale (int): Scale threshold of mapping to level 0. Default: 56.
-        mask (bool): Specify ROIAlign for cls or mask branch. Default: False.
-
-    Inputs:
-        - **rois** (Tensor) - The shape is :math:'(rois_n, 5)'. With data type of float16 or float32.
-        - **feat1** (Tensor) - The input features, whose shape must be :math:'(N, C, H, W)'.
-        - **feat2** (Tensor) - The input features, whose shape must be :math:'(N, C, H, W)'.
-        - **feat3** (Tensor) - The input features, whose shape must be :math:'(N, C, H, W)'.
-        - **feat4** (Tensor) - The input features, whose shape must be :math:'(N, C, H, W)'.
-
-    Outputs:
-        Tensor, the shape is :math:'(rois_n, C, pooled_height, pooled_width)'.
-
-    Support Platform:
-        ``Ascend`` ``CPU`` ``GPU``
-
-    Examples:
-        >>> fea1 = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
-        >>> fea2 = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
-        >>> fea3 = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
-        >>> fea4 = Tensor(np.array([[[[1., 2.], [3., 4.]]]]), mindspore.float32)
-        >>> rois = Tensor(np.array([[0, 0.2, 0.3, 0.2, 0.3]]), mindspore.float32)
-        >>> single_roi = ops.SingleRoIExtractor(conifg, 2, 1, 2, 2, mask)
-        >>> output = single_roi(rois, fea1, fea2, fea3, fea4)
+        batch_size (int): Batch size. Default: 1.
+        finest_scale (int): Scale threshold of mapping to level 0. Default: 56.
+ mask (bool): Specify ROIAlign for cls or mask branch """ - def __init__(self, config, roi_layer, out_channels, featmap_strides, batch_size=1, finest_scale=56, mask=False): + def __init__(self, + config, + roi_layer, + out_channels, + featmap_strides, + batch_size=1, + finest_scale=56, + mask=False): super(SingleRoIExtractor, self).__init__() cfg = config self.train_batch_size = batch_size self.out_channels = out_channels self.featmap_strides = featmap_strides self.num_levels = len(self.featmap_strides) - self.out_size = roi_layer.mask_out_size if mask else roi_layer.out_size + + self.out_size = config.roi_layer.mask_out_size if mask else config.roi_layer.out_size self.mask = mask - self.sample_num = roi_layer.sample_num + self.sample_num = config.roi_layer.sample_num self.roi_layers = self.build_roi_layers(self.featmap_strides) self.roi_layers = L.CellList(self.roi_layers) @@ -132,9 +107,9 @@ class SingleRoIExtractor(nn.Cell): self.equal = P.Equal() self.select = P.Select() - in_mode_16 = False - self.dtype = np.float16 if in_mode_16 else np.float32 - self.ms_dtype = mstype.float16 if in_mode_16 else mstype.float32 + _mode_16 = False + self.dtype = np.float16 if _mode_16 else np.float32 + self.ms_dtype = mstype.float16 if _mode_16 else mstype.float32 self.set_train_local(cfg, training=True) def set_train_local(self, config, training=True): @@ -143,24 +118,43 @@ class SingleRoIExtractor(nn.Cell): cfg = config # Init tensor - roi_sample_num = cfg.num_expected_pos_stage2 if self.mask else cfg.roi_sample_num - self.batch_size = roi_sample_num if self.training_local else cfg.rpn_max_num + roi_sample_num = \ + cfg.num_expected_pos_stage2 if self.mask else cfg.roi_sample_num + self.batch_size = \ + roi_sample_num if self.training_local else cfg.rpn_max_num self.batch_size = self.train_batch_size*self.batch_size \ if self.training_local else cfg.test_batch_size*self.batch_size - self.ones = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)) - finest_scale = np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * self.finest_scale_ + self.ones = \ + Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)) + finest_scale = \ + np.array(np.ones((self.batch_size, 1)), + dtype=self.dtype) * self.finest_scale_ self.finest_scale = Tensor(finest_scale) - self.epslion = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype)*self.dtype(1e-6)) - self.zeros = Tensor(np.array(np.zeros((self.batch_size, 1)), dtype=np.int32)) - self.max_levels = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=np.int32)*(self.num_levels-1)) - self.twos = Tensor(np.array(np.ones((self.batch_size, 1)), dtype=self.dtype) * 2) - self.res_ = Tensor(np.array(np.zeros((self.batch_size, self.out_channels, self.out_size, self.out_size)), - dtype=self.dtype)) + self.epslion = \ + Tensor(np.array(np.ones((self.batch_size, 1)), + dtype=self.dtype)*self.dtype(1e-6)) + self.zeros = \ + Tensor(np.array(np.zeros((self.batch_size, 1)), + dtype=np.int32)) + self.max_levels = \ + Tensor(np.array(np.ones((self.batch_size, 1)), + dtype=np.int32)*(self.num_levels-1)) + self.twos = \ + Tensor(np.array(np.ones((self.batch_size, 1)), + dtype=self.dtype) * 2) + self.res_ = \ + Tensor(np.array(np.zeros((self.batch_size, self.out_channels, + self.out_size, self.out_size)), + dtype=self.dtype)) def num_inputs(self): """input number.""" return len(self.featmap_strides) + def init_weights(self): + """initialize weights.""" + pass + def log2(self, value): """calculate log2.""" return self.log(value) / 
self.log(self.twos) @@ -169,8 +163,10 @@ class SingleRoIExtractor(nn.Cell): """build ROI layers.""" roi_layers = [] for s in featmap_strides: - layer_cls = ROIAlign(self.out_size, self.out_size, spatial_scale=1 / s, - sample_num=self.sample_num, roi_align_mode=0) + layer_cls = ROIAlign(self.out_size, self.out_size, + spatial_scale=1 / s, + sample_num=self.sample_num, + roi_align_mode=0) roi_layers.append(layer_cls) return roi_layers @@ -187,7 +183,7 @@ class SingleRoIExtractor(nn.Cell): num_levels (int): Total level number. Returns: - Tensor, Level index (0-based) of each RoI, shape (k, ) + Tensor: Level index (0-based) of each RoI, shape (k, ) """ scale = self.sqrt(rois[::, 3:4:1] - rois[::, 1:2:1] + self.ones) * \ self.sqrt(rois[::, 4:5:1] - rois[::, 2:3:1] + self.ones) @@ -208,8 +204,9 @@ class SingleRoIExtractor(nn.Cell): mask = self.equal(target_lvls, P.ScalarToArray()(i)) mask = P.Reshape()(mask, (-1, 1, 1, 1)) roi_feats_t = self.roi_layers[i](feats[i], rois) - mask = \ - self.cast(P.Tile()(self.cast(mask, mstype.int32), (1, 256, self.out_size, self.out_size)), mstype.bool_) + mask = self.cast(P.Tile()(self.cast(mask, mstype.int32), + (1, 256, self.out_size, self.out_size)), + mstype.bool_) res = self.select(mask, roi_feats_t, res) return res diff --git a/application_example/maskrcnn/src/model/rpn.py b/application_example/maskrcnn/src/model/rpn.py index 22ba2edaf0746a71950ca48a4df30abe194b0112..952cf173cbf985347112c88eb29998ac27dab557 100644 --- a/application_example/maskrcnn/src/model/rpn.py +++ b/application_example/maskrcnn/src/model/rpn.py @@ -21,7 +21,7 @@ from mindspore import Tensor from mindspore.ops import functional as F from mindspore.common.initializer import initializer -from src.model.bbox_assign_sample import BboxAssignSample +from model.bbox_assign_sample import BboxAssignSample class RpnRegClsBlock(nn.Cell): diff --git a/application_example/maskrcnn/src/train.py b/application_example/maskrcnn/src/train.py index 4444e8d20f608c277f2f2a3478987e4f4561d537..bded39d7a774b2caf0939a05f61721e47f7591fd 100644 --- a/application_example/maskrcnn/src/train.py +++ b/application_example/maskrcnn/src/train.py @@ -25,12 +25,12 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.nn import Momentum from mindspore.common import set_seed -from src.utils.config import config +from utils.config import config # when use maskrcnn mobilenetv1, just change the following backbone and defined network # from mask_rcnn_mobilenetv1 and network_define_maskrcnnmobilenetv1 -from src.model.mask_rcnn_r50 import MaskRcnnResnet50 -from src.utils.network_define_maskrcnnresnet50 import LossCallBack, WithLossCell, TrainOneStepCell, LossNet -from src.utils.lr_schedule import dynamic_lr +from model.mask_rcnn_r50 import MaskRcnnResnet50 +from utils.network_define import LossCallBack, WithLossCell, TrainOneStepCell, LossNet +from utils.lr_schedule import dynamic_lr from dataset.dataset import create_coco_dataset, data_to_mindrecord_byte_image diff --git a/application_example/maskrcnn/src/utils/__init__.py b/application_example/maskrcnn/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/application_example/maskrcnn/src/utils/config.py b/application_example/maskrcnn/src/utils/config.py index 4b7b00ed9030cea26021c3786f7fe0110a421517..3cda69190bcae564fb151ef52aaf5eac3283922e 100644 --- a/application_example/maskrcnn/src/utils/config.py +++ 
b/application_example/maskrcnn/src/utils/config.py @@ -18,6 +18,7 @@ Network config setting, will be used in train.py, eval.py and infer.py import argparse import ast +from easydict import EasyDict as ed def parse_args(): @@ -36,7 +37,7 @@ def parse_args(): parser.add_argument('--enable_profiling', default=False, type=ast.literal_eval, help="Whether to enable profiling.") # Dataset path - parser.add_argument('--data_root', default='/home/maskrcnn/cocodataset', type=str, + parser.add_argument('--data_root', default='../../../coco2017bk', type=str, help="File path of dataset in training.") # MaskRcnn training @@ -46,14 +47,14 @@ def parse_args(): parser.add_argument('--do_train', default=True, type=ast.literal_eval, help="Whether to do train.") parser.add_argument('--do_eval', default=False, type=ast.literal_eval, help="Whether to do eval.") parser.add_argument('--dataset', default='coco', type=str, help="Dataset name") - parser.add_argument('--pre_trained', default='./checkpoint/maskrcnn_coco2017_acc32.9.ckpt', + parser.add_argument('--pre_trained', default='../../../maskrcnnr5/checkpoint/resnet50_ascend_v180_imagenet2012_official_cv_top1acc76.97_top5acc93.44.ckpt', type=str, help="File path of pretrained checkpoint in training.") parser.add_argument('--device_id', default=0, type=int, help="Target device id.") parser.add_argument('--device_num', default=1, type=int, help="Target device number.") parser.add_argument('--rank_id', default=0, type=int, help="Target device rank id.") # MaskRcnn evaluation - parser.add_argument('--ann_file', default='/home/maskrcnn/cocodataset/annotations/instances_val2017.json', + parser.add_argument('--ann_file', default='../../../coco2017bk/annotations/instances_val2017.json', type=str, help="File path of cocodataset annotations.") parser.add_argument('--checkpoint_path', default='./checkpoint/maskrcnn_coco2017_acc32.9.ckpt', type=str, help="File path of pretrained checkpoint in evaluation.") @@ -134,7 +135,7 @@ def parse_args(): help="Whether to use sigmoid for classification.") # roi_align - parser.add_argument('--roi_layer', default=dict(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2), + parser.add_argument('--roi_layer', default=ed(type='RoIAlign', out_size=7, mask_out_size=14, sample_num=2), type=dict, help="The roi layer.") parser.add_argument('--roi_align_out_channels', default=256, type=int, help="The roi align output channels.") parser.add_argument('--roi_align_featmap_strides', default=[4, 8, 16, 32], type=int, @@ -183,7 +184,7 @@ def parse_args(): help="Whether to do rpn nms across levels.") parser.add_argument('--rpn_nms_pre', default=1000, type=int, help="The rpn nms preparation.") parser.add_argument('--rpn_nms_post', default=1000, type=int, help="The rpn nms post.") - parser.add_argument('--rpn_nms_num', default=1000, type=int, help="The rpn max number.") + parser.add_argument('--rpn_max_num', default=1000, type=int, help="The rpn max number.") parser.add_argument('--rpn_nms_thr', default=0.7, type=float, help="The rpn nms threshold.") parser.add_argument('--rpn_min_bbox_min_size', default=0, type=int, help="The min size of rpn min bounding box.") parser.add_argument('--test_score_thr', default=0.05, type=float, help="The test score threshold.") @@ -219,7 +220,7 @@ def parse_args(): help="File path of pretrained checkpoint to save.") # cocodataset - parser.add_argument('--mindrecord_dir', default='/home/maskrcnn/cocodataset/MindRecord_COCO', type=str, + parser.add_argument('--mindrecord_dir', 
default='../../../coco2017bk/MindRecord_COCO/MindRecord_COCO', type=str, help="File path of MindRecord to save/read.") parser.add_argument('--train_data_type', default='train2017', type=str, help="The data type for training (it is not necessary for other dataset.).") @@ -245,6 +246,6 @@ def parse_args(): help="The data classes for cocodataset (it is not necessary for other dataset.).") parser.add_argument('--num_classes', default=81, type=int, help="The number of classes for cocodataset (it is not necessary for other dataset.).") - return parser.parse_args(args=[]) + return parser.parse_args() config = parse_args() diff --git a/application_example/maskrcnn/src/utils/network_define.py b/application_example/maskrcnn/src/utils/network_define.py index f9b1fcdaf19d92de5dc0b226ccc0a8adee3d02fc..8e7705bc4c2e59229d0d9373c1e2a1ce9f510f01 100644 --- a/application_example/maskrcnn/src/utils/network_define.py +++ b/application_example/maskrcnn/src/utils/network_define.py @@ -116,7 +116,7 @@ class LossCallBack(Callback): class LossNet(nn.Cell): """MaskRcnn loss sum""" - def construct(self, x1, x2): + def construct(self, x1, x2, x3, x4, x5, x6, x7): return x1 + x2
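
A minimal end-to-end sketch of how the new dataset module is driven (the prefix, shard count and batch size below are illustrative; the real values come from utils/config.py and train.py):

    import os

    from utils.config import config
    from dataset.dataset import create_coco_dataset, data_to_mindrecord_byte_image

    prefix = "maskrcnn.mindrecord"
    # FileWriter shards the output, so the first file is named "<prefix>0".
    first_shard = os.path.join(config.mindrecord_dir, prefix + "0")
    if not os.path.exists(first_shard):
        data_to_mindrecord_byte_image(dataset="coco", is_training=True,
                                      prefix=prefix, file_num=8)

    # MindDataset picks up the sibling shards automatically.
    ds = create_coco_dataset(first_shard, batch_size=2, device_num=1,
                             rank_id=0, is_training=True)
    print("batches per epoch:", ds.get_dataset_size())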