From 1165f324892555fc8fc78d88a6ddfdc7072a613e Mon Sep 17 00:00:00 2001 From: "shengquan.nian" Date: Fri, 17 Nov 2023 12:24:02 +0800 Subject: [PATCH] add repmlp --- cv/classification/repmlp/pytorch/LICENSE | 21 + cv/classification/repmlp/pytorch/README.md | 35 ++ cv/classification/repmlp/pytorch/config.py | 213 +++++++++ cv/classification/repmlp/pytorch/convert.py | 33 ++ cv/classification/repmlp/pytorch/cutout.py | 55 +++ .../repmlp/pytorch/data/__init__.py | 1 + .../repmlp/pytorch/data/build.py | 193 ++++++++ .../pytorch/data/cached_image_folder.py | 252 +++++++++++ .../repmlp/pytorch/data/samplers.py | 30 ++ .../repmlp/pytorch/data/zipreader.py | 104 +++++ cv/classification/repmlp/pytorch/logger.py | 42 ++ .../repmlp/pytorch/lr_scheduler.py | 102 +++++ .../repmlp/pytorch/main_repmlp.py | 417 ++++++++++++++++++ cv/classification/repmlp/pytorch/optimizer.py | 68 +++ cv/classification/repmlp/pytorch/randaug.py | 407 +++++++++++++++++ cv/classification/repmlp/pytorch/repmlpnet.py | 325 ++++++++++++++ cv/classification/repmlp/pytorch/test.py | 135 ++++++ cv/classification/repmlp/pytorch/utils.py | 193 ++++++++ 18 files changed, 2626 insertions(+) create mode 100644 cv/classification/repmlp/pytorch/LICENSE create mode 100644 cv/classification/repmlp/pytorch/README.md create mode 100644 cv/classification/repmlp/pytorch/config.py create mode 100644 cv/classification/repmlp/pytorch/convert.py create mode 100644 cv/classification/repmlp/pytorch/cutout.py create mode 100644 cv/classification/repmlp/pytorch/data/__init__.py create mode 100644 cv/classification/repmlp/pytorch/data/build.py create mode 100644 cv/classification/repmlp/pytorch/data/cached_image_folder.py create mode 100644 cv/classification/repmlp/pytorch/data/samplers.py create mode 100644 cv/classification/repmlp/pytorch/data/zipreader.py create mode 100644 cv/classification/repmlp/pytorch/logger.py create mode 100644 cv/classification/repmlp/pytorch/lr_scheduler.py create mode 100644 cv/classification/repmlp/pytorch/main_repmlp.py create mode 100644 cv/classification/repmlp/pytorch/optimizer.py create mode 100644 cv/classification/repmlp/pytorch/randaug.py create mode 100644 cv/classification/repmlp/pytorch/repmlpnet.py create mode 100644 cv/classification/repmlp/pytorch/test.py create mode 100644 cv/classification/repmlp/pytorch/utils.py diff --git a/cv/classification/repmlp/pytorch/LICENSE b/cv/classification/repmlp/pytorch/LICENSE new file mode 100644 index 00000000..d8754619 --- /dev/null +++ b/cv/classification/repmlp/pytorch/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2021 DingXiaoH + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/cv/classification/repmlp/pytorch/README.md b/cv/classification/repmlp/pytorch/README.md
new file mode 100644
index 00000000..cd6e73c4
--- /dev/null
+++ b/cv/classification/repmlp/pytorch/README.md
@@ -0,0 +1,35 @@
+# RepMLP
+## Model description
+RepMLP is a multi-layer-perceptron-style neural network building block for image recognition, composed of a series of fully-connected (FC) layers. Compared to convolutional layers, FC layers are more efficient and better at modeling long-range dependencies and positional patterns, but worse at capturing local structures, so they are usually less favored for image recognition. RepMLP constructs convolutional layers inside the block during training and merges them into the FC layers for inference.
+
+## Step 1: Installing
+
+```bash
+pip3 install timm yacs
+```
+
+## Step 2: Download data
+
+Sign up and log in to the [ImageNet official website](https://www.image-net.org/index.php), then choose 'Download' to download the whole ImageNet dataset. Replace `/path/to/imagenet` with your local ImageNet path in the training command below.
+
+The ImageNet dataset path structure should look like:
+
+```bash
+imagenet
+├── train
+│   └── n01440764
+│       ├── n01440764_10026.JPEG
+│       └── ...
+├── train_list.txt
+├── val
+│   └── n01440764
+│       ├── ILSVRC2012_val_00000293.JPEG
+│       └── ...
+└── val_list.txt
+```
+
+## Step 3: Run RepMLP
+
+```bash
+python3 -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repmlp.py --arch RepMLPNet-B256 --batch-size 32 --tag my_experiment --opts TRAIN.EPOCHS 100 TRAIN.BASE_LR 0.001 TRAIN.WEIGHT_DECAY 0.1 TRAIN.OPTIMIZER.NAME adamw TRAIN.OPTIMIZER.MOMENTUM 0.9 TRAIN.WARMUP_LR 5e-7 TRAIN.MIN_LR 0.0 TRAIN.WARMUP_EPOCHS 10 AUG.PRESET raug15 AUG.MIXUP 0.4 AUG.CUTMIX 1.0 DATA.IMG_SIZE 256 --data-path [/path/to/imagenet]
+```
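Because a convolution applied to a fixed-size partition is a linear map, its kernel can be rewritten as an equivalent fully-connected weight matrix and added to the FC kernel, which is the merge referred to in the model description above. The sketch below is an illustration only, not code from this patch: the helper `conv_to_fc` is hypothetical, and it ignores the bias, BN folding, and grouped convolutions that the actual `repmlpnet.py` handles.

```python
import torch
import torch.nn.functional as F

def conv_to_fc(conv_weight, in_channels, h, w, padding):
    """Rewrite a conv kernel as the equivalent FC weight matrix for an h x w partition."""
    # Push an identity basis through the convolution: the response to the i-th
    # one-hot input is the i-th column of the equivalent FC matrix.
    n = in_channels * h * w
    identity = torch.eye(n).reshape(n, in_channels, h, w)
    out = F.conv2d(identity, conv_weight, padding=padding)  # (n, out_channels, h, w)
    return out.reshape(n, -1).t()                           # (out_channels * h * w, n)

# Sanity check on a random partition: the conv and its FC rewrite give the same output.
c, h, w = 4, 8, 8
x = torch.randn(1, c, h, w)
conv_w = torch.randn(c, c, 3, 3)
fc_w = conv_to_fc(conv_w, c, h, w, padding=1)
out_conv = F.conv2d(x, conv_w, padding=1).reshape(-1)
out_fc = fc_w @ x.reshape(-1)
print(torch.allclose(out_conv, out_fc, atol=1e-4))  # True
```

In this patch the merge itself is performed by `model.locality_injection()`, which `convert.py` calls before saving deployment weights, e.g. `python3 convert.py /path/to/train_ckpt.pth /path/to/deploy.pth -a RepMLPNet-B256`.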
diff --git a/cv/classification/repmlp/pytorch/config.py b/cv/classification/repmlp/pytorch/config.py
new file mode 100644
index 00000000..30ef7f90
--- /dev/null
+++ b/cv/classification/repmlp/pytorch/config.py
@@ -0,0 +1,213 @@
+# --------------------------------------------------------
+# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081)
+# CVPR 2022
+# Github source: https://github.com/DingXiaoH/RepMLP
+# Licensed under The MIT License [see LICENSE for details]
+# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer)
+# --------------------------------------------------------
+
+import os
+import yaml
+from yacs.config import CfgNode as CN
+
+_C = CN()
+
+# Base config files
+_C.BASE = ['']
+
+# -----------------------------------------------------------------------------
+# Data settings
+# -----------------------------------------------------------------------------
+_C.DATA = CN()
+# Batch size for a single GPU, could be overwritten by command line argument
+_C.DATA.BATCH_SIZE = 128
+# Path to dataset, could be overwritten by command line argument
+_C.DATA.DATA_PATH = '/path/to/imgnet/'
+
+# Dataset name
+_C.DATA.DATASET = 'imagenet'
+# Input image size
+_C.DATA.IMG_SIZE = 224
+_C.DATA.TEST_SIZE = None
+_C.DATA.TEST_BATCH_SIZE = None
+# Interpolation to resize image (random, bilinear, bicubic)
+_C.DATA.INTERPOLATION = 'bilinear'
+# Use zipped dataset instead of folder dataset
+# could be overwritten by command line argument
+_C.DATA.ZIP_MODE = False
+# Cache Data in Memory, could be overwritten by command line argument
+_C.DATA.CACHE_MODE = 'part'
+# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
+_C.DATA.PIN_MEMORY = True
+# Number of data loading threads
+_C.DATA.NUM_WORKERS = 8
+
+# -----------------------------------------------------------------------------
+# Model settings
+# -----------------------------------------------------------------------------
+_C.MODEL = CN()
+# Model type
+_C.MODEL.ARCH = 'RepMLPNet-B224'
+# Checkpoint to resume, could be overwritten by command line argument
+_C.MODEL.RESUME = ''
+# Number of classes, overwritten in data preparation
+_C.MODEL.NUM_CLASSES = 1000
+# Label Smoothing
+_C.MODEL.LABEL_SMOOTHING = 0.1
+
+# -----------------------------------------------------------------------------
+# Training settings
+# -----------------------------------------------------------------------------
+_C.TRAIN = CN()
+_C.TRAIN.START_EPOCH = 0
+_C.TRAIN.EPOCHS = 3 #100
+_C.TRAIN.WARMUP_EPOCHS = 10
+_C.TRAIN.WEIGHT_DECAY = 0.1
+_C.TRAIN.BASE_LR = 0.002
+_C.TRAIN.WARMUP_LR = 0.0
+_C.TRAIN.MIN_LR = 0.0
+# Clip gradient norm
+_C.TRAIN.CLIP_GRAD = 5.0
+# Auto resume from latest checkpoint
+_C.TRAIN.AUTO_RESUME = True
+# Gradient accumulation steps
+# could be overwritten by command line argument
+_C.TRAIN.ACCUMULATION_STEPS = 0
+# Whether to use gradient checkpointing to save memory
+# could be overwritten by command line argument
+_C.TRAIN.USE_CHECKPOINT = False
+
+# LR scheduler
+_C.TRAIN.LR_SCHEDULER = CN()
+_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
+# Epoch interval to decay LR, used in StepLRScheduler
+_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
+# LR decay rate, used in StepLRScheduler
+_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
+
+# Optimizer
+_C.TRAIN.OPTIMIZER = CN()
+_C.TRAIN.OPTIMIZER.NAME = 'adamw'
+# Optimizer Epsilon
+_C.TRAIN.OPTIMIZER.EPS = 1e-8
+# Optimizer Betas
+_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
+# SGD momentum
+_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
+
+# For EMA model
+_C.TRAIN.EMA_ALPHA = 0.0
+_C.TRAIN.EMA_UPDATE_PERIOD = 8
+
+
+# -----------------------------------------------------------------------------
+# Augmentation settings
+# -----------------------------------------------------------------------------
+_C.AUG = CN()
+# Mixup alpha, mixup enabled if > 0
+_C.AUG.MIXUP = 0.0
+# Cutmix alpha, cutmix enabled if > 0
+_C.AUG.CUTMIX = 0.0
+# Cutmix min/max ratio, overrides alpha and enables cutmix if set
+_C.AUG.CUTMIX_MINMAX = None
+# Probability of performing mixup or cutmix when either/both is enabled
+_C.AUG.MIXUP_PROB = 1.0
+# Probability of switching to cutmix when both mixup and cutmix enabled
+_C.AUG.MIXUP_SWITCH_PROB = 0.5
+# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
+_C.AUG.MIXUP_MODE = 'batch'
+
+_C.AUG.PRESET = None  # If AUG.PRESET is set (e.g., 'raug15'), use the pre-defined preprocessing and ignore the settings below.
+# Color jitter factor
+_C.AUG.COLOR_JITTER = 0.4
+# Use AutoAugment policy.
"v0" or "original" +_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1' +# Random erase prob +_C.AUG.REPROB = 0.25 +# Random erase mode +_C.AUG.REMODE = 'pixel' +# Random erase count +_C.AUG.RECOUNT = 1 + + +# ----------------------------------------------------------------------------- +# Testing settings +# ----------------------------------------------------------------------------- +_C.TEST = CN() +# Whether to use center crop when testing +_C.TEST.CROP = False + +# ----------------------------------------------------------------------------- +# Misc +# ----------------------------------------------------------------------------- +# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2') +# overwritten by command line argument +_C.AMP_OPT_LEVEL = '' +# Path to output folder, overwritten by command line argument +_C.OUTPUT = '' +# Tag of experiment, overwritten by command line argument +_C.TAG = 'default' +# Frequency to save checkpoint +_C.SAVE_FREQ = 20 +# Frequency to logging info +_C.PRINT_FREQ = 10 +# Fixed random seed +_C.SEED = 0 +# Perform evaluation only, overwritten by command line argument +_C.EVAL_MODE = False +# Test throughput only, overwritten by command line argument +_C.THROUGHPUT_MODE = False +# local rank for DistributedDataParallel, given by command line argument +_C.LOCAL_RANK = 0 + + +def update_config(config, args): + config.defrost() + if args.opts: + config.merge_from_list(args.opts) + # merge from specific arguments + if args.arch: + config.MODEL.ARCH = args.arch + if args.batch_size: + config.DATA.BATCH_SIZE = args.batch_size + if args.data_path: + config.DATA.DATA_PATH = args.data_path + if args.zip: + config.DATA.ZIP_MODE = True + if args.cache_mode: + config.DATA.CACHE_MODE = args.cache_mode + if args.resume: + config.MODEL.RESUME = args.resume + if args.accumulation_steps: + config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps + if args.use_checkpoint: + config.TRAIN.USE_CHECKPOINT = True + if args.amp_opt_level: + config.AMP_OPT_LEVEL = args.amp_opt_level + if args.output: + config.OUTPUT = args.output + if args.tag: + config.TAG = args.tag + if args.eval: + config.EVAL_MODE = True + if args.throughput: + config.THROUGHPUT_MODE = True + + if config.DATA.TEST_SIZE is None: + config.DATA.TEST_SIZE = config.DATA.IMG_SIZE + if config.DATA.TEST_BATCH_SIZE is None: + config.DATA.TEST_BATCH_SIZE = config.DATA.BATCH_SIZE + # set local rank for distributed training + config.LOCAL_RANK = args.local_rank + # output folder + config.OUTPUT = os.path.join(config.OUTPUT, config.MODEL.ARCH, config.TAG) + config.freeze() + + +def get_config(args): + """Get a yacs CfgNode object with default values.""" + # Return a clone so that the defaults will not be altered + # This is for the "local variable" use pattern + config = _C.clone() + update_config(config, args) + + return config diff --git a/cv/classification/repmlp/pytorch/convert.py b/cv/classification/repmlp/pytorch/convert.py new file mode 100644 index 00000000..5fcb5580 --- /dev/null +++ b/cv/classification/repmlp/pytorch/convert.py @@ -0,0 +1,33 @@ +import argparse +import os +import torch +from repmlpnet import get_RepMLPNet_model + +parser = argparse.ArgumentParser(description='RepMLPNet Conversion') +parser.add_argument('load', metavar='LOAD', help='path to the source weights file') +parser.add_argument('save', metavar='SAVE', help='path to the target weights file') +parser.add_argument('-a', '--arch', metavar='ARCH', default='RepMLPNet-B224') + +def convert(): + args = parser.parse_args() + model = 
get_RepMLPNet_model(args.arch, deploy=False) + + if os.path.isfile(args.load): + print("=> loading checkpoint '{}'".format(args.load)) + checkpoint = torch.load(args.load, map_location='cpu') + if 'state_dict' in checkpoint: + checkpoint = checkpoint['state_dict'] + elif 'model' in checkpoint: + checkpoint = checkpoint['model'] + ckpt = {k.replace('module.', ''): v for k, v in checkpoint.items()} # strip the names + print(ckpt.keys()) + model.load_state_dict(ckpt) + else: + raise ValueError("=> no checkpoint found at '{}'".format(args.load)) + + model.locality_injection() + + torch.save(model.state_dict(), args.save) + +if __name__ == '__main__': + convert() \ No newline at end of file diff --git a/cv/classification/repmlp/pytorch/cutout.py b/cv/classification/repmlp/pytorch/cutout.py new file mode 100644 index 00000000..8592ffc0 --- /dev/null +++ b/cv/classification/repmlp/pytorch/cutout.py @@ -0,0 +1,55 @@ +import numpy as np + +class Cutout: + + def __init__(self, size=16) -> None: + self.size = size + + def _create_cutout_mask(self, img_height, img_width, num_channels, size): + """Creates a zero mask used for cutout of shape `img_height` x `img_width`. + Args: + img_height: Height of image cutout mask will be applied to. + img_width: Width of image cutout mask will be applied to. + num_channels: Number of channels in the image. + size: Size of the zeros mask. + Returns: + A mask of shape `img_height` x `img_width` with all ones except for a + square of zeros of shape `size` x `size`. This mask is meant to be + elementwise multiplied with the original image. Additionally returns + the `upper_coord` and `lower_coord` which specify where the cutout mask + will be applied. + """ + # assert img_height == img_width + + # Sample center where cutout mask will be applied + height_loc = np.random.randint(low=0, high=img_height) + width_loc = np.random.randint(low=0, high=img_width) + + size = int(size) + # Determine upper right and lower left corners of patch + upper_coord = (max(0, height_loc - size // 2), max(0, width_loc - size // 2)) + lower_coord = ( + min(img_height, height_loc + size // 2), + min(img_width, width_loc + size // 2), + ) + mask_height = lower_coord[0] - upper_coord[0] + mask_width = lower_coord[1] - upper_coord[1] + assert mask_height > 0 + assert mask_width > 0 + + mask = np.ones((img_height, img_width, num_channels)) + zeros = np.zeros((mask_height, mask_width, num_channels)) + mask[upper_coord[0]: lower_coord[0], upper_coord[1]: lower_coord[1], :] = zeros + return mask, upper_coord, lower_coord + + def __call__(self, pil_img): + pil_img = pil_img.copy() + img_height, img_width, num_channels = (*pil_img.size, 3) + _, upper_coord, lower_coord = self._create_cutout_mask( + img_height, img_width, num_channels, self.size + ) + pixels = pil_img.load() # create the pixel map + for i in range(upper_coord[0], lower_coord[0]): # for every col: + for j in range(upper_coord[1], lower_coord[1]): # For every row + pixels[i, j] = (125, 122, 113, 0) # set the colour accordingly + return pil_img \ No newline at end of file diff --git a/cv/classification/repmlp/pytorch/data/__init__.py b/cv/classification/repmlp/pytorch/data/__init__.py new file mode 100644 index 00000000..70c633ce --- /dev/null +++ b/cv/classification/repmlp/pytorch/data/__init__.py @@ -0,0 +1 @@ +from .build import build_loader \ No newline at end of file diff --git a/cv/classification/repmlp/pytorch/data/build.py b/cv/classification/repmlp/pytorch/data/build.py new file mode 100644 index 00000000..1b0108d5 --- /dev/null 
+++ b/cv/classification/repmlp/pytorch/data/build.py @@ -0,0 +1,193 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- +import os +import torch +import numpy as np +import torch.distributed as dist +from torchvision import datasets, transforms +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.data import Mixup +from timm.data import create_transform +try: + from timm.data.transforms import str_to_pil_interp as _pil_interp +except: + from timm.data.transforms import _pil_interp +from .cached_image_folder import CachedImageFolder +from .samplers import SubsetRandomSampler + + +def build_loader(config): + config.defrost() + dataset_train, config.MODEL.NUM_CLASSES = build_dataset(is_train=True, config=config) + config.freeze() + print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build train dataset") + dataset_val, _ = build_dataset(is_train=False, config=config) + print(f"local rank {config.LOCAL_RANK} / global rank {dist.get_rank()} successfully build val dataset") + + num_tasks = dist.get_world_size() + global_rank = dist.get_rank() + if config.DATA.ZIP_MODE and config.DATA.CACHE_MODE == 'part': + indices = np.arange(dist.get_rank(), len(dataset_train), dist.get_world_size()) + sampler_train = SubsetRandomSampler(indices) + else: + sampler_train = torch.utils.data.DistributedSampler( + dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True + ) + + if dataset_val is None: + sampler_val = None + else: + indices = np.arange(dist.get_rank(), len(dataset_val), dist.get_world_size()) #TODO + sampler_val = SubsetRandomSampler(indices) + + data_loader_train = torch.utils.data.DataLoader( + dataset_train, sampler=sampler_train, + batch_size=config.DATA.BATCH_SIZE, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=config.DATA.PIN_MEMORY, + drop_last=True, + ) + + if dataset_val is None: + data_loader_val = None + else: + data_loader_val = torch.utils.data.DataLoader( + dataset_val, sampler=sampler_val, + batch_size=config.DATA.TEST_BATCH_SIZE, + shuffle=False, + num_workers=config.DATA.NUM_WORKERS, + pin_memory=config.DATA.PIN_MEMORY, + drop_last=False + ) + + # setup mixup / cutmix + mixup_fn = None + mixup_active = config.AUG.MIXUP > 0 or config.AUG.CUTMIX > 0. 
or config.AUG.CUTMIX_MINMAX is not None + if mixup_active: + mixup_fn = Mixup( + mixup_alpha=config.AUG.MIXUP, cutmix_alpha=config.AUG.CUTMIX, cutmix_minmax=config.AUG.CUTMIX_MINMAX, + prob=config.AUG.MIXUP_PROB, switch_prob=config.AUG.MIXUP_SWITCH_PROB, mode=config.AUG.MIXUP_MODE, + label_smoothing=config.MODEL.LABEL_SMOOTHING, num_classes=config.MODEL.NUM_CLASSES) + + return dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn + + +def build_dataset(is_train, config): + if config.DATA.DATASET == 'imagenet': + transform = build_transform(is_train, config) + prefix = 'train' if is_train else 'val' + if config.DATA.ZIP_MODE: + ann_file = prefix + "_map.txt" + prefix = prefix + ".zip@/" + dataset = CachedImageFolder(config.DATA.DATA_PATH, ann_file, prefix, transform, + cache_mode=config.DATA.CACHE_MODE if is_train else 'part') + else: + # Data source on our machines. You will never need it. + nori_root = os.path.join('/home/dingxiaohan/ndp/', 'imagenet.train.nori.list' if is_train else 'imagenet.val.nori.list') + if os.path.exists(nori_root): + # Data source on our machines. You will never need it. + from nori_dataset import ImageNetNoriDataset + dataset = ImageNetNoriDataset(nori_root, transform=transform) + else: + import torchvision + print('use raw ImageNet data') + root = os.path.join(config.DATA.DATA_PATH, prefix) + dataset = datasets.ImageFolder(root, transform=transform) + nb_classes = 1000 + + elif config.DATA.DATASET == 'cf100': + mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343] + std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404] + if is_train: + transform = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(mean, std) + ]) + dataset = datasets.CIFAR100(root=config.DATA.DATA_PATH, train=True, download=True, transform=transform) + else: + transform = transforms.Compose( + [transforms.ToTensor(), + transforms.Normalize(mean, std)]) + dataset = datasets.CIFAR100(root=config.DATA.DATA_PATH, train=False, download=True, transform=transform) + nb_classes = 100 + + else: + raise NotImplementedError("We only support ImageNet and CIFAR-100 now.") + + return dataset, nb_classes + + +def build_transform(is_train, config): + resize_im = config.DATA.IMG_SIZE > 32 + if is_train: + # this should always dispatch to transforms_imagenet_train + + if config.AUG.PRESET is None: + transform = create_transform( + input_size=config.DATA.IMG_SIZE, + is_training=True, + color_jitter=config.AUG.COLOR_JITTER if config.AUG.COLOR_JITTER > 0 else None, + auto_augment=config.AUG.AUTO_AUGMENT if config.AUG.AUTO_AUGMENT != 'none' else None, + re_prob=config.AUG.REPROB, + re_mode=config.AUG.REMODE, + re_count=config.AUG.RECOUNT, + interpolation=config.DATA.INTERPOLATION, + ) + print('=============================== original AUG! 
', config.AUG.AUTO_AUGMENT) + if not resize_im: + # replace RandomResizedCropAndInterpolation with + # RandomCrop + transform.transforms[0] = transforms.RandomCrop(config.DATA.IMG_SIZE, padding=4) + + elif config.AUG.PRESET.strip() == 'raug15': + from randaug import RandAugPolicy + transform = transforms.Compose([ + transforms.RandomResizedCrop(config.DATA.IMG_SIZE), + transforms.RandomHorizontalFlip(), + RandAugPolicy(magnitude=15), + transforms.ToTensor(), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + print('---------------------- RAND AUG 15 distortion!') + elif config.AUG.PRESET.strip() == 'weak': + transform = transforms.Compose([ + transforms.RandomResizedCrop(config.DATA.IMG_SIZE), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + elif config.AUG.PRESET.strip() == 'none': + transform = transforms.Compose([ + transforms.Resize(config.DATA.IMG_SIZE, interpolation=_pil_interp(config.DATA.INTERPOLATION)), + transforms.CenterCrop(config.DATA.IMG_SIZE), + transforms.ToTensor(), + transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD), + ]) + else: + raise ValueError('???' + config.AUG.PRESET) + print(transform) + return transform + + t = [] + if resize_im: + if config.TEST.CROP: + size = int((256 / 224) * config.DATA.TEST_SIZE) + t.append(transforms.Resize(size, interpolation=_pil_interp(config.DATA.INTERPOLATION)), + # to maintain same ratio w.r.t. 224 images + ) + t.append(transforms.CenterCrop(config.DATA.TEST_SIZE)) + else: + # default for testing + t.append(transforms.Resize(config.DATA.TEST_SIZE, interpolation=_pil_interp(config.DATA.INTERPOLATION))) + t.append(transforms.CenterCrop(config.DATA.TEST_SIZE)) + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + trans = transforms.Compose(t) + return trans diff --git a/cv/classification/repmlp/pytorch/data/cached_image_folder.py b/cv/classification/repmlp/pytorch/data/cached_image_folder.py new file mode 100644 index 00000000..94fcde30 --- /dev/null +++ b/cv/classification/repmlp/pytorch/data/cached_image_folder.py @@ -0,0 +1,252 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- + +import io +import os +import time +import torch.distributed as dist +import torch.utils.data as data +from PIL import Image + +from .zipreader import is_zip_path, ZipReader + + +def has_file_allowed_extension(filename, extensions): + """Checks if a file is an allowed extension. 
+ Args: + filename (string): path to a file + Returns: + bool: True if the filename ends with a known image extension + """ + filename_lower = filename.lower() + return any(filename_lower.endswith(ext) for ext in extensions) + + +def find_classes(dir): + classes = [d for d in os.listdir(dir) if os.path.isdir(os.path.join(dir, d))] + classes.sort() + class_to_idx = {classes[i]: i for i in range(len(classes))} + return classes, class_to_idx + + +def make_dataset(dir, class_to_idx, extensions): + images = [] + dir = os.path.expanduser(dir) + for target in sorted(os.listdir(dir)): + d = os.path.join(dir, target) + if not os.path.isdir(d): + continue + + for root, _, fnames in sorted(os.walk(d)): + for fname in sorted(fnames): + if has_file_allowed_extension(fname, extensions): + path = os.path.join(root, fname) + item = (path, class_to_idx[target]) + images.append(item) + + return images + + +def make_dataset_with_ann(ann_file, img_prefix, extensions): + images = [] + with open(ann_file, "r") as f: + contents = f.readlines() + for line_str in contents: + path_contents = [c for c in line_str.split('\t')] + im_file_name = path_contents[0] + class_index = int(path_contents[1]) + + assert str.lower(os.path.splitext(im_file_name)[-1]) in extensions + item = (os.path.join(img_prefix, im_file_name), class_index) + + images.append(item) + + return images + + +class DatasetFolder(data.Dataset): + """A generic data loader where the samples are arranged in this way: :: + root/class_x/xxx.ext + root/class_x/xxy.ext + root/class_x/xxz.ext + root/class_y/123.ext + root/class_y/nsdf3.ext + root/class_y/asd932_.ext + Args: + root (string): Root directory path. + loader (callable): A function to load a sample given its path. + extensions (list[string]): A list of allowed extensions. + transform (callable, optional): A function/transform that takes in + a sample and returns a transformed version. + E.g, ``transforms.RandomCrop`` for images. + target_transform (callable, optional): A function/transform that takes + in the target and transforms it. 
+ Attributes: + samples (list): List of (sample path, class_index) tuples + """ + + def __init__(self, root, loader, extensions, ann_file='', img_prefix='', transform=None, target_transform=None, + cache_mode="no"): + # image folder mode + if ann_file == '': + _, class_to_idx = find_classes(root) + samples = make_dataset(root, class_to_idx, extensions) + # zip mode + else: + samples = make_dataset_with_ann(os.path.join(root, ann_file), + os.path.join(root, img_prefix), + extensions) + + if len(samples) == 0: + raise (RuntimeError("Found 0 files in subfolders of: " + root + "\n" + + "Supported extensions are: " + ",".join(extensions))) + + self.root = root + self.loader = loader + self.extensions = extensions + + self.samples = samples + self.labels = [y_1k for _, y_1k in samples] + self.classes = list(set(self.labels)) + + self.transform = transform + self.target_transform = target_transform + + self.cache_mode = cache_mode + if self.cache_mode != "no": + self.init_cache() + + def init_cache(self): + assert self.cache_mode in ["part", "full"] + n_sample = len(self.samples) + global_rank = dist.get_rank() + world_size = dist.get_world_size() + + samples_bytes = [None for _ in range(n_sample)] + start_time = time.time() + for index in range(n_sample): + if index % (n_sample // 10) == 0: + t = time.time() - start_time + print(f'global_rank {dist.get_rank()} cached {index}/{n_sample} takes {t:.2f}s per block') + start_time = time.time() + path, target = self.samples[index] + if self.cache_mode == "full": + samples_bytes[index] = (ZipReader.read(path), target) + elif self.cache_mode == "part" and index % world_size == global_rank: + samples_bytes[index] = (ZipReader.read(path), target) + else: + samples_bytes[index] = (path, target) + self.samples = samples_bytes + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (sample, target) where target is class_index of the target class. 
+ """ + path, target = self.samples[index] + sample = self.loader(path) + if self.transform is not None: + sample = self.transform(sample) + if self.target_transform is not None: + target = self.target_transform(target) + + return sample, target + + def __len__(self): + return len(self.samples) + + def __repr__(self): + fmt_str = 'Dataset ' + self.__class__.__name__ + '\n' + fmt_str += ' Number of datapoints: {}\n'.format(self.__len__()) + fmt_str += ' Root Location: {}\n'.format(self.root) + tmp = ' Transforms (if any): ' + fmt_str += '{0}{1}\n'.format(tmp, self.transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + tmp = ' Target Transforms (if any): ' + fmt_str += '{0}{1}'.format(tmp, self.target_transform.__repr__().replace('\n', '\n' + ' ' * len(tmp))) + return fmt_str + + +IMG_EXTENSIONS = ['.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif'] + + +def pil_loader(path): + # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835) + if isinstance(path, bytes): + img = Image.open(io.BytesIO(path)) + elif is_zip_path(path): + data = ZipReader.read(path) + img = Image.open(io.BytesIO(data)) + else: + with open(path, 'rb') as f: + img = Image.open(f) + return img.convert('RGB') + + +def accimage_loader(path): + import accimage + try: + return accimage.Image(path) + except IOError: + # Potentially a decoding problem, fall back to PIL.Image + return pil_loader(path) + + +def default_img_loader(path): + from torchvision import get_image_backend + if get_image_backend() == 'accimage': + return accimage_loader(path) + else: + return pil_loader(path) + + +class CachedImageFolder(DatasetFolder): + """A generic data loader where the images are arranged in this way: :: + root/dog/xxx.png + root/dog/xxy.png + root/dog/xxz.png + root/cat/123.png + root/cat/nsdf3.png + root/cat/asd932_.png + Args: + root (string): Root directory path. + transform (callable, optional): A function/transform that takes in an PIL image + and returns a transformed version. E.g, ``transforms.RandomCrop`` + target_transform (callable, optional): A function/transform that takes in the + target and transforms it. + loader (callable, optional): A function to load an image given its path. + Attributes: + imgs (list): List of (image path, class_index) tuples + """ + + def __init__(self, root, ann_file='', img_prefix='', transform=None, target_transform=None, + loader=default_img_loader, cache_mode="no"): + super(CachedImageFolder, self).__init__(root, loader, IMG_EXTENSIONS, + ann_file=ann_file, img_prefix=img_prefix, + transform=transform, target_transform=target_transform, + cache_mode=cache_mode) + self.imgs = self.samples + + def __getitem__(self, index): + """ + Args: + index (int): Index + Returns: + tuple: (image, target) where target is class_index of the target class. 
+ """ + path, target = self.samples[index] + image = self.loader(path) + if self.transform is not None: + img = self.transform(image) + else: + img = image + if self.target_transform is not None: + target = self.target_transform(target) + + return img, target diff --git a/cv/classification/repmlp/pytorch/data/samplers.py b/cv/classification/repmlp/pytorch/data/samplers.py new file mode 100644 index 00000000..738da77f --- /dev/null +++ b/cv/classification/repmlp/pytorch/data/samplers.py @@ -0,0 +1,30 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- + +import torch + + +class SubsetRandomSampler(torch.utils.data.Sampler): + r"""Samples elements randomly from a given list of indices, without replacement. + + Arguments: + indices (sequence): a sequence of indices + """ + + def __init__(self, indices): + self.epoch = 0 + self.indices = indices + + def __iter__(self): + return (self.indices[i] for i in torch.randperm(len(self.indices))) + + def __len__(self): + return len(self.indices) + + def set_epoch(self, epoch): + self.epoch = epoch diff --git a/cv/classification/repmlp/pytorch/data/zipreader.py b/cv/classification/repmlp/pytorch/data/zipreader.py new file mode 100644 index 00000000..2476fdc2 --- /dev/null +++ b/cv/classification/repmlp/pytorch/data/zipreader.py @@ -0,0 +1,104 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- + +import os +import zipfile +import io +import numpy as np +from PIL import Image +from PIL import ImageFile + +ImageFile.LOAD_TRUNCATED_IMAGES = True + + +def is_zip_path(img_or_path): + """judge if this is a zip path""" + return '.zip@' in img_or_path + + +class ZipReader(object): + """A class to read zipped files""" + zip_bank = dict() + + def __init__(self): + super(ZipReader, self).__init__() + + @staticmethod + def get_zipfile(path): + zip_bank = ZipReader.zip_bank + if path not in zip_bank: + zfile = zipfile.ZipFile(path, 'r') + zip_bank[path] = zfile + return zip_bank[path] + + @staticmethod + def split_zip_style_path(path): + pos_at = path.index('@') + assert pos_at != -1, "character '@' is not found from the given path '%s'" % path + + zip_path = path[0: pos_at] + folder_path = path[pos_at + 1:] + folder_path = str.strip(folder_path, '/') + return zip_path, folder_path + + @staticmethod + def list_folder(path): + zip_path, folder_path = ZipReader.split_zip_style_path(path) + + zfile = ZipReader.get_zipfile(zip_path) + folder_list = [] + for file_foler_name in zfile.namelist(): + file_foler_name = str.strip(file_foler_name, '/') + if file_foler_name.startswith(folder_path) and \ + len(os.path.splitext(file_foler_name)[-1]) == 0 and \ + file_foler_name != folder_path: + if len(folder_path) == 0: + folder_list.append(file_foler_name) + else: + 
folder_list.append(file_foler_name[len(folder_path) + 1:]) + + return folder_list + + @staticmethod + def list_files(path, extension=None): + if extension is None: + extension = ['.*'] + zip_path, folder_path = ZipReader.split_zip_style_path(path) + + zfile = ZipReader.get_zipfile(zip_path) + file_lists = [] + for file_foler_name in zfile.namelist(): + file_foler_name = str.strip(file_foler_name, '/') + if file_foler_name.startswith(folder_path) and \ + str.lower(os.path.splitext(file_foler_name)[-1]) in extension: + if len(folder_path) == 0: + file_lists.append(file_foler_name) + else: + file_lists.append(file_foler_name[len(folder_path) + 1:]) + + return file_lists + + @staticmethod + def read(path): + zip_path, path_img = ZipReader.split_zip_style_path(path) + zfile = ZipReader.get_zipfile(zip_path) + data = zfile.read(path_img) + return data + + @staticmethod + def imread(path): + zip_path, path_img = ZipReader.split_zip_style_path(path) + zfile = ZipReader.get_zipfile(zip_path) + data = zfile.read(path_img) + try: + im = Image.open(io.BytesIO(data)) + except: + print("ERROR IMG LOADED: ", path_img) + random_img = np.random.rand(224, 224, 3) * 255 + im = Image.fromarray(np.uint8(random_img)) + return im diff --git a/cv/classification/repmlp/pytorch/logger.py b/cv/classification/repmlp/pytorch/logger.py new file mode 100644 index 00000000..b08e312c --- /dev/null +++ b/cv/classification/repmlp/pytorch/logger.py @@ -0,0 +1,42 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- + +import os +import sys +import logging +import functools +from termcolor import colored + + +@functools.lru_cache() +def create_logger(output_dir, dist_rank=0, name=''): + # create logger + logger = logging.getLogger(name) + logger.setLevel(logging.DEBUG) + logger.propagate = False + + # create formatter + fmt = '[%(asctime)s %(name)s] (%(filename)s %(lineno)d): %(levelname)s %(message)s' + color_fmt = colored('[%(asctime)s %(name)s]', 'green') + \ + colored('(%(filename)s %(lineno)d)', 'yellow') + ': %(levelname)s %(message)s' + + # create console handlers for master process + if dist_rank == 0: + console_handler = logging.StreamHandler(sys.stdout) + console_handler.setLevel(logging.DEBUG) + console_handler.setFormatter( + logging.Formatter(fmt=color_fmt, datefmt='%Y-%m-%d %H:%M:%S')) + logger.addHandler(console_handler) + + # create file handlers + file_handler = logging.FileHandler(os.path.join(output_dir, f'log_rank{dist_rank}.txt'), mode='a') + file_handler.setLevel(logging.DEBUG) + file_handler.setFormatter(logging.Formatter(fmt=fmt, datefmt='%Y-%m-%d %H:%M:%S')) + logger.addHandler(file_handler) + + return logger diff --git a/cv/classification/repmlp/pytorch/lr_scheduler.py b/cv/classification/repmlp/pytorch/lr_scheduler.py new file mode 100644 index 00000000..490e32f3 --- /dev/null +++ b/cv/classification/repmlp/pytorch/lr_scheduler.py @@ -0,0 +1,102 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under 
The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- + +import torch +from timm.scheduler.cosine_lr import CosineLRScheduler +from timm.scheduler.step_lr import StepLRScheduler +from timm.scheduler.scheduler import Scheduler + + +def build_scheduler(config, optimizer, n_iter_per_epoch): + num_steps = int(config.TRAIN.EPOCHS * n_iter_per_epoch) + warmup_steps = int(config.TRAIN.WARMUP_EPOCHS * n_iter_per_epoch) + decay_steps = int(config.TRAIN.LR_SCHEDULER.DECAY_EPOCHS * n_iter_per_epoch) + + lr_scheduler = None + if config.TRAIN.LR_SCHEDULER.NAME == 'cosine': + lr_scheduler = CosineLRScheduler( + optimizer, + t_initial=num_steps, + lr_min=config.TRAIN.MIN_LR, + warmup_lr_init=config.TRAIN.WARMUP_LR, + warmup_t=warmup_steps, + cycle_limit=1, + t_in_epochs=False, + ) + elif config.TRAIN.LR_SCHEDULER.NAME == 'linear': + lr_scheduler = LinearLRScheduler( + optimizer, + t_initial=num_steps, + lr_min_rate=0.01, + warmup_lr_init=config.TRAIN.WARMUP_LR, + warmup_t=warmup_steps, + t_in_epochs=False, + ) + elif config.TRAIN.LR_SCHEDULER.NAME == 'step': + lr_scheduler = StepLRScheduler( + optimizer, + decay_t=decay_steps, + decay_rate=config.TRAIN.LR_SCHEDULER.DECAY_RATE, + warmup_lr_init=config.TRAIN.WARMUP_LR, + warmup_t=warmup_steps, + t_in_epochs=False, + ) + + return lr_scheduler + + +class LinearLRScheduler(Scheduler): + def __init__(self, + optimizer: torch.optim.Optimizer, + t_initial: int, + lr_min_rate: float, + warmup_t=0, + warmup_lr_init=0., + t_in_epochs=True, + noise_range_t=None, + noise_pct=0.67, + noise_std=1.0, + noise_seed=42, + initialize=True, + ) -> None: + super().__init__( + optimizer, param_group_field="lr", + noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed, + initialize=initialize) + + self.t_initial = t_initial + self.lr_min_rate = lr_min_rate + self.warmup_t = warmup_t + self.warmup_lr_init = warmup_lr_init + self.t_in_epochs = t_in_epochs + if self.warmup_t: + self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values] + super().update_groups(self.warmup_lr_init) + else: + self.warmup_steps = [1 for _ in self.base_values] + + def _get_lr(self, t): + if t < self.warmup_t: + lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps] + else: + t = t - self.warmup_t + total_t = self.t_initial - self.warmup_t + lrs = [v - ((v - v * self.lr_min_rate) * (t / total_t)) for v in self.base_values] + return lrs + + def get_epoch_values(self, epoch: int): + if self.t_in_epochs: + return self._get_lr(epoch) + else: + return None + + def get_update_values(self, num_updates: int): + if not self.t_in_epochs: + return self._get_lr(num_updates) + else: + return None diff --git a/cv/classification/repmlp/pytorch/main_repmlp.py b/cv/classification/repmlp/pytorch/main_repmlp.py new file mode 100644 index 00000000..b553fb60 --- /dev/null +++ b/cv/classification/repmlp/pytorch/main_repmlp.py @@ -0,0 +1,417 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- 
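+# Example launch command, an abridged form of Step 3 in the accompanying README (the dataset path is a placeholder):
+#   python3 -m torch.distributed.launch --nproc_per_node 8 --master_port 12349 main_repmlp.py \
+#     --arch RepMLPNet-B256 --batch-size 32 --tag my_experiment --data-path /path/to/imagenet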
+import time +import argparse +import datetime +import numpy as np +import torch +import torch.backends.cudnn as cudnn +import torch.distributed as dist +from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy +from timm.utils import accuracy, AverageMeter +from config import get_config +from data import build_loader +from lr_scheduler import build_scheduler +from logger import create_logger +from utils import load_checkpoint, save_checkpoint, get_grad_norm, auto_resume_helper, reduce_tensor, save_latest, update_model_ema, unwrap_model, load_weights +import copy +from optimizer import build_optimizer +from repmlpnet import get_RepMLPNet_model + +try: + # noinspection PyUnresolvedReferences + from apex import amp +except ImportError: + amp = None + + +def parse_option(): + parser = argparse.ArgumentParser('RepOpt-VGG training script built on the codebase of Swin Transformer', add_help=False) + parser.add_argument( + "--opts", + help="Modify config options by adding 'KEY VALUE' pairs. ", + default=None, + nargs='+', + ) + + # easy config modification + parser.add_argument('--arch', default=None, type=str, help='arch name') + parser.add_argument('--batch-size', default=128, type=int, help="batch size for single GPU") + parser.add_argument('--data-path', default='/path/to/imgnet/', type=str, help='path to dataset') + parser.add_argument('--zip', action='store_true', help='use zipped dataset instead of folder dataset') + parser.add_argument('--cache-mode', type=str, default='part', choices=['no', 'full', 'part'], + help='no: no cache, ' + 'full: cache all data, ' + 'part: sharding the dataset into nonoverlapping pieces and only cache one piece') + parser.add_argument('--resume', help='resume from checkpoint') + parser.add_argument('--accumulation-steps', type=int, help="gradient accumulation steps") + parser.add_argument('--use-checkpoint', action='store_true', + help="whether to use gradient checkpointing to save memory") + parser.add_argument('--amp-opt-level', type=str, default='O0', choices=['O0', 'O1', 'O2'], #TODO Note: use amp if you have it + help='mixed precision opt level, if O0, no amp is used') + parser.add_argument('--output', default='output', type=str, metavar='PATH', + help='root of output folder, the full path is // (default: output)') + parser.add_argument('--tag', help='tag of experiment') + parser.add_argument('--eval', action='store_true', help='Perform evaluation only') + parser.add_argument('--throughput', action='store_true', help='Test throughput only') + + # distributed training + parser.add_argument("--local_rank", type=int, default=0, help='local rank for DistributedDataParallel') + + args, unparsed = parser.parse_known_args() + + config = get_config(args) + + return args, config + + + + +def main(config): + dataset_train, dataset_val, data_loader_train, data_loader_val, mixup_fn = build_loader(config) + + logger.info(f"Creating model:{config.MODEL.ARCH}") + + model = get_RepMLPNet_model(config.MODEL.ARCH, deploy=False) + + optimizer = build_optimizer(config, model) + + model.cuda() + + if torch.cuda.device_count() > 1: + if config.AMP_OPT_LEVEL != "O0": + model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) + model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[config.LOCAL_RANK], + broadcast_buffers=False) + model_without_ddp = model.module + else: + if config.AMP_OPT_LEVEL != "O0": + model, optimizer = amp.initialize(model, optimizer, opt_level=config.AMP_OPT_LEVEL) + model_without_ddp = model + + 
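+    # Report the number of trainable parameters and, if the model defines flops(), its GFLOPs.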
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) + logger.info(f"number of params: {n_parameters}") + if hasattr(model_without_ddp, 'flops'): + flops = model_without_ddp.flops() + logger.info(f"number of GFLOPs: {flops / 1e9}") + + if config.THROUGHPUT_MODE: + throughput(data_loader_val, model, logger) + return + + if config.EVAL_MODE: + load_weights(model, config.MODEL.RESUME) + acc1, acc5, loss = validate(config, data_loader_val, model) + logger.info(f"Only eval. top-1 acc, top-5 acc, loss: {acc1:.3f}, {acc5:.3f}, {loss:.5f}") + return + + + lr_scheduler = build_scheduler(config, optimizer, len(data_loader_train)) + + if config.AUG.MIXUP > 0.: + # smoothing is handled with mixup label transform + criterion = SoftTargetCrossEntropy() + elif config.MODEL.LABEL_SMOOTHING > 0.: + criterion = LabelSmoothingCrossEntropy(smoothing=config.MODEL.LABEL_SMOOTHING) + else: + criterion = torch.nn.CrossEntropyLoss() + + + + max_accuracy = 0.0 + max_ema_accuracy = 0.0 + + if config.TRAIN.EMA_ALPHA > 0 and (not config.EVAL_MODE) and (not config.THROUGHPUT_MODE): + model_ema = copy.deepcopy(model) + else: + model_ema = None + + if config.TRAIN.AUTO_RESUME: + resume_file = auto_resume_helper(config.OUTPUT) + if resume_file: + if config.MODEL.RESUME: + logger.warning(f"auto-resume changing resume file from {config.MODEL.RESUME} to {resume_file}") + config.defrost() + config.MODEL.RESUME = resume_file + config.freeze() + logger.info(f'auto resuming from {resume_file}') + else: + logger.info(f'no checkpoint found in {config.OUTPUT}, ignoring auto resume') + + if (not config.THROUGHPUT_MODE) and config.MODEL.RESUME: + max_accuracy = load_checkpoint(config, model_without_ddp, optimizer, lr_scheduler, logger, model_ema=model_ema) + + + + logger.info("Start training") + start_time = time.time() + for epoch in range(config.TRAIN.START_EPOCH, config.TRAIN.EPOCHS): + data_loader_train.sampler.set_epoch(epoch) + + train_one_epoch(config, model, criterion, data_loader_train, optimizer, epoch, mixup_fn, lr_scheduler, model_ema=model_ema) + if dist.get_rank() == 0: + save_latest(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger, model_ema=model_ema) + if epoch % config.SAVE_FREQ == 0: + save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger, model_ema=model_ema) + + if epoch % config.SAVE_FREQ == 0 or epoch >= (config.TRAIN.EPOCHS - 10): + + if data_loader_val is not None: + acc1, acc5, loss = validate(config, data_loader_val, model) + logger.info(f"Accuracy of the network at epoch {epoch}: {acc1:.3f}%") + max_accuracy = max(max_accuracy, acc1) + logger.info(f'Max accuracy: {max_accuracy:.2f}%') + if max_accuracy == acc1 and dist.get_rank() == 0: + save_checkpoint(config, epoch, model_without_ddp, max_accuracy, optimizer, lr_scheduler, logger, + is_best=True, model_ema=model_ema) + + if model_ema is not None: + if data_loader_val is not None: + acc1, acc5, loss = validate(config, data_loader_val, model_ema) + logger.info(f"EMAAccuracy of the network at epoch {epoch} test images: {acc1:.3f}%") + max_ema_accuracy = max(max_ema_accuracy, acc1) + logger.info(f'EMAMax accuracy: {max_ema_accuracy:.2f}%') + if max_ema_accuracy == acc1 and dist.get_rank() == 0: + best_ema_path = os.path.join(config.OUTPUT, 'best_ema.pth') + logger.info(f"{best_ema_path} best EMA saving......") + torch.save(unwrap_model(model_ema).state_dict(), best_ema_path) + else: + latest_ema_path = os.path.join(config.OUTPUT, 'latest_ema.pth') + 
logger.info(f"{latest_ema_path} latest EMA saving......") + torch.save(unwrap_model(model_ema).state_dict(), latest_ema_path) + + total_time = time.time() - start_time + total_time_str = str(datetime.timedelta(seconds=int(total_time))) + logger.info('Training time {}'.format(total_time_str)) + + +def train_one_epoch(config, model, criterion, data_loader, optimizer, epoch, mixup_fn, lr_scheduler, model_ema=None): + model.train() + optimizer.zero_grad() + + num_steps = len(data_loader) + batch_time = AverageMeter() + loss_meter = AverageMeter() + norm_meter = AverageMeter() + + start = time.time() + end = time.time() + for idx, (samples, targets) in enumerate(data_loader): + samples = samples.cuda(non_blocking=True) + targets = targets.cuda(non_blocking=True) + + if mixup_fn is not None: + samples, targets = mixup_fn(samples, targets) + + outputs = model(samples) + + if config.TRAIN.ACCUMULATION_STEPS > 1: + + loss = criterion(outputs, targets) + loss = loss / config.TRAIN.ACCUMULATION_STEPS + if config.AMP_OPT_LEVEL != "O0": + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + if config.TRAIN.CLIP_GRAD: + grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) + else: + grad_norm = get_grad_norm(amp.master_params(optimizer)) + else: + loss.backward() + if config.TRAIN.CLIP_GRAD: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) + else: + grad_norm = get_grad_norm(model.parameters()) + if (idx + 1) % config.TRAIN.ACCUMULATION_STEPS == 0: + optimizer.step() + optimizer.zero_grad() + lr_scheduler.step_update(epoch * num_steps + idx) + + else: + if type(outputs) is dict: + loss = 0.0 + for name, pred in outputs.items(): + if 'aux' in name: + loss += 0.1 * criterion(pred, targets) + else: + loss += criterion(pred, targets) + else: + loss = criterion(outputs, targets) + + optimizer.zero_grad() + if config.AMP_OPT_LEVEL != "O0": + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + if config.TRAIN.CLIP_GRAD: + grad_norm = torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), config.TRAIN.CLIP_GRAD) + else: + grad_norm = get_grad_norm(amp.master_params(optimizer)) + else: + loss.backward() + if config.TRAIN.CLIP_GRAD: + grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), config.TRAIN.CLIP_GRAD) + else: + grad_norm = get_grad_norm(model.parameters()) + optimizer.step() + lr_scheduler.step_update(epoch * num_steps + idx) + + torch.cuda.synchronize() + + loss_meter.update(loss.item(), targets.size(0)) + norm_meter.update(grad_norm) + batch_time.update(time.time() - end) + + if model_ema is not None: + update_model_ema(config, dist.get_world_size(), model=model, model_ema=model_ema, cur_epoch=epoch, cur_iter=idx) + + end = time.time() + + if idx % config.PRINT_FREQ == 0: + lr = optimizer.param_groups[0]['lr'] + memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) + etas = batch_time.avg * (num_steps - idx) + logger.info( + f'Train: [{epoch}/{config.TRAIN.EPOCHS}][{idx}/{num_steps}]\t' + f'eta {datetime.timedelta(seconds=int(etas))} lr {lr:.6f}\t' + f'time {batch_time.val:.4f} ({batch_time.avg:.4f})\t' + f'loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'grad_norm {norm_meter.val:.4f} ({norm_meter.avg:.4f})\t' + f'mem {memory_used:.0f}MB') + epoch_time = time.time() - start + logger.info(f"EPOCH {epoch} training takes {datetime.timedelta(seconds=int(epoch_time))}") + + +@torch.no_grad() +def validate(config, data_loader, model): 
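+    # Distributed evaluation: per-batch loss and top-1/top-5 accuracy are reduced across ranks with reduce_tensor before being accumulated in the meters.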
+ criterion = torch.nn.CrossEntropyLoss() + model.eval() + + batch_time = AverageMeter() + loss_meter = AverageMeter() + acc1_meter = AverageMeter() + acc5_meter = AverageMeter() + + end = time.time() + for idx, (images, target) in enumerate(data_loader): + images = images.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + output = model(images) + + if type(output) is dict: + output = output['main'] + + # measure accuracy and record loss + loss = criterion(output, target) + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + + acc1 = reduce_tensor(acc1) + acc5 = reduce_tensor(acc5) + loss = reduce_tensor(loss) + + loss_meter.update(loss.item(), target.size(0)) + acc1_meter.update(acc1.item(), target.size(0)) + acc5_meter.update(acc5.item(), target.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if idx % config.PRINT_FREQ == 0: + memory_used = torch.cuda.max_memory_allocated() / (1024.0 * 1024.0) + logger.info( + f'Test: [{idx}/{len(data_loader)}]\t' + f'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + f'Loss {loss_meter.val:.4f} ({loss_meter.avg:.4f})\t' + f'Acc@1 {acc1_meter.val:.3f} ({acc1_meter.avg:.3f})\t' + f'Acc@5 {acc5_meter.val:.3f} ({acc5_meter.avg:.3f})\t' + f'Mem {memory_used:.0f}MB') + logger.info(f' * Acc@1 {acc1_meter.avg:.3f} Acc@5 {acc5_meter.avg:.3f}') + return acc1_meter.avg, acc5_meter.avg, loss_meter.avg + + +@torch.no_grad() +def throughput(data_loader, model, logger): + model.eval() + + for idx, (images, _) in enumerate(data_loader): + images = images.cuda(non_blocking=True) + + batch_size = images.shape[0] + for i in range(50): + model(images) + torch.cuda.synchronize() + logger.info(f"throughput averaged with 30 times") + tic1 = time.time() + for i in range(30): + model(images) + torch.cuda.synchronize() + tic2 = time.time() + throughput = 30 * batch_size / (tic2 - tic1) + logger.info(f"batch_size {batch_size} throughput {throughput}") + return + + +import os + +if __name__ == '__main__': + args, config = parse_option() + + if config.AMP_OPT_LEVEL != "O0": + assert amp is not None, "amp not installed!" 
+ + if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: + rank = int(os.environ["RANK"]) + world_size = int(os.environ['WORLD_SIZE']) + else: + rank = -1 + world_size = -1 + torch.cuda.set_device(config.LOCAL_RANK) + torch.distributed.init_process_group(backend='nccl', init_method='env://', world_size=world_size, rank=rank) + torch.distributed.barrier() + seed = config.SEED + dist.get_rank() + + torch.manual_seed(seed) + np.random.seed(seed) + cudnn.benchmark = True + + if not config.EVAL_MODE: + # linear scale the learning rate according to total batch size, may not be optimal + linear_scaled_lr = config.TRAIN.BASE_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 256.0 + linear_scaled_warmup_lr = config.TRAIN.WARMUP_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 256.0 + linear_scaled_min_lr = config.TRAIN.MIN_LR * config.DATA.BATCH_SIZE * dist.get_world_size() / 256.0 + # gradient accumulation also need to scale the learning rate + if config.TRAIN.ACCUMULATION_STEPS > 1: + linear_scaled_lr = linear_scaled_lr * config.TRAIN.ACCUMULATION_STEPS + linear_scaled_warmup_lr = linear_scaled_warmup_lr * config.TRAIN.ACCUMULATION_STEPS + linear_scaled_min_lr = linear_scaled_min_lr * config.TRAIN.ACCUMULATION_STEPS + config.defrost() + config.TRAIN.BASE_LR = linear_scaled_lr + config.TRAIN.WARMUP_LR = linear_scaled_warmup_lr + config.TRAIN.MIN_LR = linear_scaled_min_lr + config.freeze() + + print('==========================================') + print('real base lr: ', config.TRAIN.BASE_LR) + print('==========================================') + + os.makedirs(config.OUTPUT, exist_ok=True) + + logger = create_logger(output_dir=config.OUTPUT, dist_rank=0 if torch.cuda.device_count() == 1 else dist.get_rank(), name=f"{config.MODEL.ARCH}") + + if torch.cuda.device_count() == 1 or dist.get_rank() == 0: + path = os.path.join(config.OUTPUT, "config.json") + with open(path, "w") as f: + f.write(config.dump()) + logger.info(f"Full config saved to {path}") + + # print config + logger.info(config.dump()) + + main(config) diff --git a/cv/classification/repmlp/pytorch/optimizer.py b/cv/classification/repmlp/pytorch/optimizer.py new file mode 100644 index 00000000..a43e321c --- /dev/null +++ b/cv/classification/repmlp/pytorch/optimizer.py @@ -0,0 +1,68 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- + +from torch import optim as optim + + +def build_optimizer(config, model): + """ + Build optimizer, set weight decay of normalization to 0 by default. 
+ """ + skip = {} + skip_keywords = {} + if hasattr(model, 'no_weight_decay'): + skip = model.no_weight_decay() + if hasattr(model, 'no_weight_decay_keywords'): + skip_keywords = model.no_weight_decay_keywords() + echo = (config.LOCAL_RANK==0) + parameters = set_weight_decay(model, skip, skip_keywords, echo=echo) + opt_lower = config.TRAIN.OPTIMIZER.NAME.lower() + optimizer = None + if opt_lower == 'sgd': + optimizer = optim.SGD(parameters, momentum=config.TRAIN.OPTIMIZER.MOMENTUM, nesterov=True, + lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) + if echo: + print('================================== SGD nest, momentum = {}, wd = {}'.format(config.TRAIN.OPTIMIZER.MOMENTUM, config.TRAIN.WEIGHT_DECAY)) + elif opt_lower == 'adamw': + optimizer = optim.AdamW(parameters, eps=config.TRAIN.OPTIMIZER.EPS, betas=config.TRAIN.OPTIMIZER.BETAS, + lr=config.TRAIN.BASE_LR, weight_decay=config.TRAIN.WEIGHT_DECAY) + + return optimizer + + +def set_weight_decay(model, skip_list=(), skip_keywords=(), echo=False): + has_decay = [] + no_decay = [] + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if 'identity.weight' in name: + has_decay.append(param) + if echo: + print(f"{name} USE weight decay") + elif len(param.shape) == 1 or name.endswith(".bias") or (name in skip_list) or \ + check_keywords_in_name(name, skip_keywords): + no_decay.append(param) + if echo: + print(f"{name} has no weight decay") + else: + has_decay.append(param) + if echo: + print(f"{name} USE weight decay") + + return [{'params': has_decay}, + {'params': no_decay, 'weight_decay': 0.}] + + +def check_keywords_in_name(name, keywords=()): + isin = False + for keyword in keywords: + if keyword in name: + isin = True + return isin diff --git a/cv/classification/repmlp/pytorch/randaug.py b/cv/classification/repmlp/pytorch/randaug.py new file mode 100644 index 00000000..31600c33 --- /dev/null +++ b/cv/classification/repmlp/pytorch/randaug.py @@ -0,0 +1,407 @@ +import math +import random + +import numpy as np +import PIL +from PIL import Image, ImageEnhance, ImageOps + +from cutout import Cutout + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. 
+ +_HPARAMS_DEFAULT = dict( + translate_const=250, + img_mean=_FILL, +) + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + + +def cutout(img, factor, **kwargs): + _check_args_tf(kwargs) + return Cutout(size=factor)(img) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], + - rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def identity(img, **__): + return img + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, 
negate the value""" + return -v if random.random() > 0.5 else v + + +def _cutout_level_to_arg(level, _hparams): + # range [0, 40] + level = max(2, (level / _MAX_LEVEL) * 40.) + return level, + + +def _rotate_level_to_arg(level, _hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return level, + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return (level / _MAX_LEVEL) * 1.8 + 0.1, + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return level, + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams['translate_const'] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return level, + + +def _translate_rel_level_to_arg(level, _hparams): + # range [-0.45, 0.45] + level = (level / _MAX_LEVEL) * 0.45 + level = _randomly_negate(level) + return level, + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + return int((level / _MAX_LEVEL) * 4) + 4, + + +def _posterize_research_level_to_arg(level, _hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image' + return 4 - int((level / _MAX_LEVEL) * 4), + + +def _posterize_tpu_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + return int((level / _MAX_LEVEL) * 4), + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + return int((level / _MAX_LEVEL) * 256), + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return int((level / _MAX_LEVEL) * 110), + + +LEVEL_TO_ARG = { + 'AutoContrast': None, + 'Equalize': None, + 'Invert': None, + 'Identity': None, + 'Rotate': _rotate_level_to_arg, + 'PosterizeOriginal': _posterize_original_level_to_arg, + 'PosterizeResearch': _posterize_research_level_to_arg, + 'PosterizeTpu': _posterize_tpu_level_to_arg, + 'Solarize': _solarize_level_to_arg, + 'SolarizeAdd': _solarize_add_level_to_arg, + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': _translate_abs_level_to_arg, + 'TranslateY': _translate_abs_level_to_arg, + 'TranslateXRel': _translate_rel_level_to_arg, + 'TranslateYRel': _translate_rel_level_to_arg, + 'Cutout': _cutout_level_to_arg, +} + + +NAME_TO_OP = { + 'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Identity': identity, + 'Rotate': rotate, + 'PosterizeOriginal': posterize, + 'PosterizeResearch': posterize, + 'PosterizeTpu': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, + 'Cutout': cutout, +} + + +class AutoAugmentTransform(object): + """ + AutoAugment from Google. 
+ Implementation adapted from: + https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py + """ + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + """ + Args: + name (str): any type of transforms list in _RAND_TRANSFORMS. + prob (float): probability of perform current augmentation. + magnitude (int): intensity / magnitude of each augmentation. + hparams (dict): hyper-parameters required by each augmentation. + """ + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = dict( + fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, + resample=hparams['interpolation'] if 'interpolation' in hparams + else _RANDOM_INTERPOLATION, + ) + + # If magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get('magnitude_std', 0) + + def __call__(self, img: PIL.Image) -> PIL.Image: + if random.random() > self.prob: + return img + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + # NOTE: magnitude fixed and no boundary + # magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = self.level_fn( + magnitude, self.hparams) if self.level_fn is not None else tuple() + return self.aug_fn(img, *level_args, **self.kwargs) + # return np.array(self.aug_fn(Image.fromarray(img), *level_args, **self.kwargs)) + + # def apply_coords(self, coords: np.ndarray) -> np.ndarray: + # return coords + + +_RAND_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeTpu', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + 'Cutout' # FIXME I implement this as random erasing separately +] + +_RAND_TRANSFORMS_CMC = [ + 'AutoContrast', + 'Identity', + 'Rotate', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + # 'Cutout' # FIXME I implement this as random erasing separately +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. 
+_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'PosterizeTpu': 0, + 'Invert': 0, +} + + +class RandAugPolicy(object): + def __init__(self, layers=2, magnitude=10): + self.layers = layers + self.magnitude = magnitude + + def __call__(self, img): + for _ in range(self.layers): + trans = np.random.choice(_RAND_TRANSFORMS) + # NOTE: prob apply, fixed magnitude + # trans_op = AutoAugmentTransform(trans, prob=np.random.uniform(0.2, 0.8), magnitude=self.magnitude) + # NOTE: always apply, random magnitude + trans_op = AutoAugmentTransform(trans, prob=1.0, magnitude=np.random.choice(self.magnitude)) + img = trans_op(img) + assert img is not None, trans + return img diff --git a/cv/classification/repmlp/pytorch/repmlpnet.py b/cv/classification/repmlp/pytorch/repmlpnet.py new file mode 100644 index 00000000..8bc9af92 --- /dev/null +++ b/cv/classification/repmlp/pytorch/repmlpnet.py @@ -0,0 +1,325 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# -------------------------------------------------------- + +import torch.nn.functional as F +import torch.nn as nn +import torch.utils.checkpoint as checkpoint +import torch + +def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1): + result = nn.Sequential() + result.add_module('conv', nn.Conv2d(in_channels=in_channels, out_channels=out_channels, + kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, bias=False)) + result.add_module('bn', nn.BatchNorm2d(num_features=out_channels)) + return result + +def conv_bn_relu(in_channels, out_channels, kernel_size, stride, padding, groups=1): + result = conv_bn(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups) + result.add_module('relu', nn.ReLU()) + return result + +def fuse_bn(conv_or_fc, bn): + std = (bn.running_var + bn.eps).sqrt() + t = bn.weight / std + t = t.reshape(-1, 1, 1, 1) + + if len(t) == conv_or_fc.weight.size(0): + return conv_or_fc.weight * t, bn.bias - bn.running_mean * bn.weight / std + else: + repeat_times = conv_or_fc.weight.size(0) // len(t) + repeated = t.repeat_interleave(repeat_times, 0) + return conv_or_fc.weight * repeated, (bn.bias - bn.running_mean * bn.weight / std).repeat_interleave( + repeat_times, 0) + + +class GlobalPerceptron(nn.Module): + + def __init__(self, input_channels, internal_neurons): + super(GlobalPerceptron, self).__init__() + self.fc1 = nn.Conv2d(in_channels=input_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True) + self.fc2 = nn.Conv2d(in_channels=internal_neurons, out_channels=input_channels, kernel_size=1, stride=1, bias=True) + self.input_channels = input_channels + + def forward(self, inputs): + x = F.adaptive_avg_pool2d(inputs, output_size=(1, 1)) + x = self.fc1(x) + x = F.relu(x, inplace=True) + x = self.fc2(x) + x = F.sigmoid(x) + x = x.view(-1, self.input_channels, 1, 1) + return x + +class RepMLPBlock(nn.Module): + + def __init__(self, in_channels, out_channels, + h, w, + reparam_conv_k=None, + globalperceptron_reduce=4, + 
num_sharesets=1, + deploy=False): + super().__init__() + + self.C = in_channels + self.O = out_channels + self.S = num_sharesets + + self.h, self.w = h, w + + self.deploy = deploy + + assert in_channels == out_channels + self.gp = GlobalPerceptron(input_channels=in_channels, internal_neurons=in_channels // globalperceptron_reduce) + + self.fc3 = nn.Conv2d(self.h * self.w * num_sharesets, self.h * self.w * num_sharesets, 1, 1, 0, bias=deploy, groups=num_sharesets) + if deploy: + self.fc3_bn = nn.Identity() + else: + self.fc3_bn = nn.BatchNorm2d(num_sharesets) + + self.reparam_conv_k = reparam_conv_k + if not deploy and reparam_conv_k is not None: + for k in reparam_conv_k: + conv_branch = conv_bn(num_sharesets, num_sharesets, kernel_size=k, stride=1, padding=k//2, groups=num_sharesets) + self.__setattr__('repconv{}'.format(k), conv_branch) + + + def partition(self, x, h_parts, w_parts): + x = x.reshape(-1, self.C, h_parts, self.h, w_parts, self.w) + x = x.permute(0, 2, 4, 1, 3, 5) + return x + + def partition_affine(self, x, h_parts, w_parts): + fc_inputs = x.reshape(-1, self.S * self.h * self.w, 1, 1) + out = self.fc3(fc_inputs) + out = out.reshape(-1, self.S, self.h, self.w) + out = self.fc3_bn(out) + out = out.reshape(-1, h_parts, w_parts, self.S, self.h, self.w) + return out + + def forward(self, inputs): + # Global Perceptron + global_vec = self.gp(inputs) + + origin_shape = inputs.size() + h_parts = origin_shape[2] // self.h + w_parts = origin_shape[3] // self.w + + partitions = self.partition(inputs, h_parts, w_parts) + + # Channel Perceptron + fc3_out = self.partition_affine(partitions, h_parts, w_parts) + + # Local Perceptron + if self.reparam_conv_k is not None and not self.deploy: + conv_inputs = partitions.reshape(-1, self.S, self.h, self.w) + conv_out = 0 + for k in self.reparam_conv_k: + conv_branch = self.__getattr__('repconv{}'.format(k)) + conv_out += conv_branch(conv_inputs) + conv_out = conv_out.reshape(-1, h_parts, w_parts, self.S, self.h, self.w) + fc3_out += conv_out + + fc3_out = fc3_out.permute(0, 3, 1, 4, 2, 5) # N, O, h_parts, out_h, w_parts, out_w + out = fc3_out.reshape(*origin_shape) + out = out * global_vec + return out + + + def get_equivalent_fc3(self): + fc_weight, fc_bias = fuse_bn(self.fc3, self.fc3_bn) + if self.reparam_conv_k is not None: + largest_k = max(self.reparam_conv_k) + largest_branch = self.__getattr__('repconv{}'.format(largest_k)) + total_kernel, total_bias = fuse_bn(largest_branch.conv, largest_branch.bn) + for k in self.reparam_conv_k: + if k != largest_k: + k_branch = self.__getattr__('repconv{}'.format(k)) + kernel, bias = fuse_bn(k_branch.conv, k_branch.bn) + total_kernel += F.pad(kernel, [(largest_k - k) // 2] * 4) + total_bias += bias + rep_weight, rep_bias = self._convert_conv_to_fc(total_kernel, total_bias) + final_fc3_weight = rep_weight.reshape_as(fc_weight) + fc_weight + final_fc3_bias = rep_bias + fc_bias + else: + final_fc3_weight = fc_weight + final_fc3_bias = fc_bias + return final_fc3_weight, final_fc3_bias + + def local_inject(self): + self.deploy = True + # Locality Injection + fc3_weight, fc3_bias = self.get_equivalent_fc3() + # Remove Local Perceptron + if self.reparam_conv_k is not None: + for k in self.reparam_conv_k: + self.__delattr__('repconv{}'.format(k)) + self.__delattr__('fc3') + self.__delattr__('fc3_bn') + self.fc3 = nn.Conv2d(self.S * self.h * self.w, self.S * self.h * self.w, 1, 1, 0, bias=True, groups=self.S) + self.fc3_bn = nn.Identity() + self.fc3.weight.data = fc3_weight + self.fc3.bias.data = fc3_bias + 
+ def _convert_conv_to_fc(self, conv_kernel, conv_bias): + I = torch.eye(self.h * self.w).repeat(1, self.S).reshape(self.h * self.w, self.S, self.h, self.w).to(conv_kernel.device) + fc_k = F.conv2d(I, conv_kernel, padding=(conv_kernel.size(2)//2,conv_kernel.size(3)//2), groups=self.S) + fc_k = fc_k.reshape(self.h * self.w, self.S * self.h * self.w).t() + fc_bias = conv_bias.repeat_interleave(self.h * self.w) + return fc_k, fc_bias + + +# The common FFN Block used in many Transformer and MLP models. +class FFNBlock(nn.Module): + def __init__(self, in_channels, hidden_channels=None, out_channels=None, act_layer=nn.GELU): + super().__init__() + out_features = out_channels or in_channels + hidden_features = hidden_channels or in_channels + self.ffn_fc1 = conv_bn(in_channels, hidden_features, 1, 1, 0) + self.ffn_fc2 = conv_bn(hidden_features, out_features, 1, 1, 0) + self.act = act_layer() + + def forward(self, x): + x = self.ffn_fc1(x) + x = self.act(x) + x = self.ffn_fc2(x) + return x + + +class RepMLPNetUnit(nn.Module): + + def __init__(self, channels, h, w, reparam_conv_k, globalperceptron_reduce, ffn_expand=4, + num_sharesets=1, deploy=False): + super().__init__() + self.repmlp_block = RepMLPBlock(in_channels=channels, out_channels=channels, h=h, w=w, + reparam_conv_k=reparam_conv_k, globalperceptron_reduce=globalperceptron_reduce, + num_sharesets=num_sharesets, deploy=deploy) + self.ffn_block = FFNBlock(channels, channels * ffn_expand) + self.prebn1 = nn.BatchNorm2d(channels) + self.prebn2 = nn.BatchNorm2d(channels) + + def forward(self, x): + y = x + self.repmlp_block(self.prebn1(x)) # TODO use droppath? + z = y + self.ffn_block(self.prebn2(y)) + return z + + +class RepMLPNet(nn.Module): + + def __init__(self, + in_channels=3, num_class=1000, + patch_size=(4, 4), + num_blocks=(2,2,6,2), channels=(192,384,768,1536), + hs=(64,32,16,8), ws=(64,32,16,8), + sharesets_nums=(4,8,16,32), + reparam_conv_k=(3,), + globalperceptron_reduce=4, use_checkpoint=False, + deploy=False): + super().__init__() + num_stages = len(num_blocks) + assert num_stages == len(channels) + assert num_stages == len(hs) + assert num_stages == len(ws) + assert num_stages == len(sharesets_nums) + + self.conv_embedding = conv_bn_relu(in_channels, channels[0], kernel_size=patch_size, stride=patch_size, padding=0) + + stages = [] + embeds = [] + for stage_idx in range(num_stages): + stage_blocks = [RepMLPNetUnit(channels=channels[stage_idx], h=hs[stage_idx], w=ws[stage_idx], reparam_conv_k=reparam_conv_k, + globalperceptron_reduce=globalperceptron_reduce, ffn_expand=4, num_sharesets=sharesets_nums[stage_idx], + deploy=deploy) for _ in range(num_blocks[stage_idx])] + stages.append(nn.ModuleList(stage_blocks)) + if stage_idx < num_stages - 1: + embeds.append(conv_bn_relu(in_channels=channels[stage_idx], out_channels=channels[stage_idx + 1], kernel_size=2, stride=2, padding=0)) + + self.stages = nn.ModuleList(stages) + self.embeds = nn.ModuleList(embeds) + self.head_norm = nn.BatchNorm2d(channels[-1]) + self.head = nn.Linear(channels[-1], num_class) + + self.use_checkpoint = use_checkpoint + + def forward(self, x): + x = self.conv_embedding(x) + for i, stage in enumerate(self.stages): + for block in stage: + if self.use_checkpoint: + x = checkpoint.checkpoint(block, x) + else: + x = block(x) + if i < len(self.stages) - 1: + embed = self.embeds[i] + if self.use_checkpoint: + x = checkpoint.checkpoint(embed, x) + else: + x = embed(x) + x = self.head_norm(x) + x = F.adaptive_avg_pool2d(x, 1) + x = x.view(x.size(0), -1) + x = 
self.head(x) + return x + + def locality_injection(self): + for m in self.modules(): + if hasattr(m, 'local_inject'): + m.local_inject() + +def create_RepMLPNet_T224(deploy=False): + return RepMLPNet(channels=(64, 128, 256, 512), hs=(56,28,14,7), ws=(56,28,14,7), + num_blocks=(2,2,6,2), reparam_conv_k=(1, 3), sharesets_nums=(1,4,16,128), + deploy=deploy) +def create_RepMLPNet_T256(deploy=False): + return RepMLPNet(channels=(64, 128, 256, 512), hs=(64,32,16,8), ws=(64,32,16,8), + num_blocks=(2,2,6,2), reparam_conv_k=(1, 3), sharesets_nums=(1,4,16,128), + deploy=deploy) +def create_RepMLPNet_B224(deploy=False): + return RepMLPNet(channels=(96, 192, 384, 768), hs=(56,28,14,7), ws=(56,28,14,7), + num_blocks=(2,2,12,2), reparam_conv_k=(1, 3), sharesets_nums=(1,4,32,128), + deploy=deploy) +def create_RepMLPNet_B256(deploy=False): + return RepMLPNet(channels=(96, 192, 384, 768), hs=(64,32,16,8), ws=(64,32,16,8), + num_blocks=(2,2,12,2), reparam_conv_k=(1, 3), sharesets_nums=(1,4,32,128), + deploy=deploy) +def create_RepMLPNet_D256(deploy=False): + return RepMLPNet(channels=(80, 160, 320, 640), hs=(64,32,16,8), ws=(64,32,16,8), + num_blocks=(2,2,18,2), reparam_conv_k=(1, 3), sharesets_nums=(1,4,16,128), + deploy=deploy) +def create_RepMLPNet_L256(deploy=False): + return RepMLPNet(channels=(96, 192, 384, 768), hs=(64,32,16,8), ws=(64,32,16,8), + num_blocks=(2,2,18,2), reparam_conv_k=(1, 3), sharesets_nums=(1,4,32,256), + deploy=deploy) + +model_map = { + 'RepMLPNet-T256': create_RepMLPNet_T256, + 'RepMLPNet-T224': create_RepMLPNet_T224, + 'RepMLPNet-B224': create_RepMLPNet_B224, + 'RepMLPNet-B256': create_RepMLPNet_B256, + 'RepMLPNet-D256': create_RepMLPNet_D256, + 'RepMLPNet-L256': create_RepMLPNet_L256, +} + +def get_RepMLPNet_model(name, deploy=False): + if name not in model_map: + raise ValueError('Not yet supported. 
You may add some code to create the model here.') + model = model_map[name](deploy=deploy) + return model + + +# Verify the equivalency +if __name__ == '__main__': + model = create_RepMLPNet_B224() + model.eval() + + x = torch.randn(1, 3, 224, 224) + origin_y = model(x) + + model.locality_injection() + + print(model) + new_y = model(x) + print((new_y - origin_y).abs().sum()) \ No newline at end of file diff --git a/cv/classification/repmlp/pytorch/test.py b/cv/classification/repmlp/pytorch/test.py new file mode 100644 index 00000000..d501e510 --- /dev/null +++ b/cv/classification/repmlp/pytorch/test.py @@ -0,0 +1,135 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- +import argparse +import os +import time +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.backends.cudnn as cudnn +import torch.optim +import torch.utils.data +import torch.utils.data.distributed +import torchvision.transforms as transforms +import torchvision.datasets as datasets +from utils import load_weights, ProgressMeter, AverageMeter +from repmlpnet import get_RepMLPNet_model +from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD +from timm.utils import accuracy + +parser = argparse.ArgumentParser(description='PyTorch ImageNet Test') +parser.add_argument('data', metavar='DATA', help='path to dataset') +parser.add_argument('mode', metavar='MODE', default='train', choices=['train', 'deploy', 'check'], help='train, deploy, or check the equivalency?') +parser.add_argument('weights', metavar='WEIGHTS', help='path to the weights file') +parser.add_argument('-a', '--arch', metavar='ARCH', default='RepMLPNet-B224') +parser.add_argument('-b', '--batch-size', default=100, type=int, + metavar='N', + help='mini-batch size (default: 100) for test') +parser.add_argument('-r', '--resolution', default=224, type=int, + metavar='R', + help='resolution (default: 224) for test') +parser.add_argument('-j', '--workers', default=4, type=int, metavar='N', + help='number of data loading workers (default: 4)') + +def test(): + args = parser.parse_args() + model = get_RepMLPNet_model(name=args.arch, deploy=args.mode == 'deploy') + + num_params = 0 + for k, v in model.state_dict().items(): + print(k, v.shape) + num_params += v.nelement() + print('total params: ', num_params) + + if os.path.isfile(args.weights): + print("=> loading checkpoint '{}'".format(args.weights)) + load_weights(model, args.weights) + else: + raise ValueError("=> no checkpoint found at '{}'".format(args.weights)) + + if args.mode == 'check': # Note this. In "check" mode, we load the trained weights and convert afterwards. 
+ model.locality_injection() + + if not torch.cuda.is_available(): + print('using CPU, this will be slow.') + use_gpu = False + criterion = nn.CrossEntropyLoss() + else: + model = model.cuda() + use_gpu = True + criterion = nn.CrossEntropyLoss().cuda() + cudnn.benchmark = True + + t = [] + t.append(transforms.Resize(args.resolution)) + t.append(transforms.CenterCrop(args.resolution)) + t.append(transforms.ToTensor()) + t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)) + trans = transforms.Compose(t) + + if os.path.exists('/home/dingxiaohan/ndp/imagenet.val.nori.list'): + # This is the data source on our machine. For debugging only. You will never need it. + from noris_dataset import ImageNetNoriDataset + val_dataset = ImageNetNoriDataset('/home/dingxiaohan/ndp/imagenet.val.nori.list', trans) + else: + # Your ImageNet directory + valdir = os.path.join(args.data, 'val') + val_dataset = datasets.ImageFolder(valdir, trans) + + val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=args.batch_size, shuffle=False, + num_workers=args.workers, pin_memory=True) + + validate(val_loader, model, criterion, use_gpu) + + +def validate(val_loader, model, criterion, use_gpu): + batch_time = AverageMeter('Time', ':6.3f') + losses = AverageMeter('Loss', ':.4e') + top1 = AverageMeter('Acc@1', ':6.2f') + top5 = AverageMeter('Acc@5', ':6.2f') + progress = ProgressMeter( + len(val_loader), + [batch_time, losses, top1, top5], + prefix='Test: ') + + # switch to evaluate mode + model.eval() + + with torch.no_grad(): + end = time.time() + for i, (images, target) in enumerate(val_loader): + if use_gpu: + images = images.cuda(non_blocking=True) + target = target.cuda(non_blocking=True) + + # compute output + output = model(images) + loss = criterion(output, target) + + # measure accuracy and record loss + acc1, acc5 = accuracy(output, target, topk=(1, 5)) + losses.update(loss.item(), images.size(0)) + top1.update(acc1.item(), images.size(0)) + top5.update(acc5.item(), images.size(0)) + + # measure elapsed time + batch_time.update(time.time() - end) + end = time.time() + + if i % 10 == 0: + progress.display(i) + + print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'.format(top1=top1, top5=top5)) + + return top1.avg + + +if __name__ == '__main__': + test() \ No newline at end of file diff --git a/cv/classification/repmlp/pytorch/utils.py b/cv/classification/repmlp/pytorch/utils.py new file mode 100644 index 00000000..f7caac85 --- /dev/null +++ b/cv/classification/repmlp/pytorch/utils.py @@ -0,0 +1,193 @@ +# -------------------------------------------------------- +# RepMLPNet: Hierarchical Vision MLP with Re-parameterized Locality (https://arxiv.org/abs/2112.11081) +# CVPR 2022 +# Github source: https://github.com/DingXiaoH/RepMLP +# Licensed under The MIT License [see LICENSE for details] +# The training script is based on the code of Swin Transformer (https://github.com/microsoft/Swin-Transformer) +# -------------------------------------------------------- + +import os +import torch +import torch.distributed as dist +import numpy as np + +try: + # noinspection PyUnresolvedReferences + from apex import amp +except ImportError: + amp = None + + +def unwrap_model(model): + """Remove the DistributedDataParallel wrapper if present.""" + wrapped = isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel) + return model.module if wrapped else model + + +def load_checkpoint(config, model, optimizer, lr_scheduler, logger, model_ema=None): + logger.info(f"==============> 
Resuming from {config.MODEL.RESUME}....................")
+    if config.MODEL.RESUME.startswith('https'):
+        checkpoint = torch.hub.load_state_dict_from_url(
+            config.MODEL.RESUME, map_location='cpu', check_hash=True)
+    else:
+        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
+    msg = model.load_state_dict(checkpoint['model'], strict=False)
+    logger.info(msg)
+    max_accuracy = 0.0
+    if not config.EVAL_MODE and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
+        optimizer.load_state_dict(checkpoint['optimizer'])
+        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
+        config.defrost()
+        config.TRAIN.START_EPOCH = checkpoint['epoch'] + 1
+        config.freeze()
+        if 'amp' in checkpoint and config.AMP_OPT_LEVEL != "O0" and checkpoint['config'].AMP_OPT_LEVEL != "O0":
+            amp.load_state_dict(checkpoint['amp'])
+        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
+        if 'max_accuracy' in checkpoint:
+            max_accuracy = checkpoint['max_accuracy']
+        if model_ema is not None:
+            unwrap_model(model_ema).load_state_dict(checkpoint['ema'])
+            print('=================================================== EMA loaded')
+
+    del checkpoint
+    torch.cuda.empty_cache()
+    return max_accuracy
+
+def load_weights(model, path):
+    checkpoint = torch.load(path, map_location='cpu')
+    if 'model' in checkpoint:
+        checkpoint = checkpoint['model']
+    unwrap_model(model).load_state_dict(checkpoint, strict=False)
+    print('=================== loaded from', path)
+
+
+def save_latest(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, model_ema=None):
+    save_state = {'model': model.state_dict(),
+                  'optimizer': optimizer.state_dict(),
+                  'lr_scheduler': lr_scheduler.state_dict(),
+                  'max_accuracy': max_accuracy,
+                  'epoch': epoch,
+                  'config': config}
+    if config.AMP_OPT_LEVEL != "O0":
+        save_state['amp'] = amp.state_dict()
+    if model_ema is not None:
+        save_state['ema'] = unwrap_model(model_ema).state_dict()
+
+    save_path = os.path.join(config.OUTPUT, 'latest.pth')
+    logger.info(f"{save_path} saving......")
+    torch.save(save_state, save_path)
+    logger.info(f"{save_path} saved !!!")
+
+def save_checkpoint(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, is_best=False, model_ema=None):
+    save_state = {'model': model.state_dict(),
+                  'optimizer': optimizer.state_dict(),
+                  'lr_scheduler': lr_scheduler.state_dict(),
+                  'max_accuracy': max_accuracy,
+                  'epoch': epoch,
+                  'config': config}
+    if config.AMP_OPT_LEVEL != "O0":
+        save_state['amp'] = amp.state_dict()
+    if model_ema is not None:
+        save_state['ema'] = unwrap_model(model_ema).state_dict()
+
+    if is_best:
+        best_path = os.path.join(config.OUTPUT, 'best_ckpt.pth')
+        torch.save(save_state, best_path)
+
+    save_path = os.path.join(config.OUTPUT, f'ckpt_epoch_{epoch}.pth')
+    logger.info(f"{save_path} saving......")
+    torch.save(save_state, save_path)
+    logger.info(f"{save_path} saved !!!")
+
+
+def get_grad_norm(parameters, norm_type=2):
+    if isinstance(parameters, torch.Tensor):
+        parameters = [parameters]
+    parameters = list(filter(lambda p: p.grad is not None, parameters))
+    norm_type = float(norm_type)
+    total_norm = 0
+    for p in parameters:
+        param_norm = p.grad.data.norm(norm_type)
+        total_norm += param_norm.item() ** norm_type
+    total_norm = total_norm ** (1. / norm_type)
+    return total_norm
+
+
+def auto_resume_helper(output_dir):
+    checkpoints = os.listdir(output_dir)
+    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth') and 'ema' not in ckpt]
+    print(f"All checkpoints found in {output_dir}: {checkpoints}")
+    if len(checkpoints) > 0:
+        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints], key=os.path.getmtime)
+        print(f"The latest checkpoint found: {latest_checkpoint}")
+        resume_file = latest_checkpoint
+    else:
+        resume_file = None
+    return resume_file
+
+
+def reduce_tensor(tensor):
+    rt = tensor.clone()
+    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
+    rt /= dist.get_world_size()
+    return rt
+
+
+def update_model_ema(cfg, num_gpus, model, model_ema, cur_epoch, cur_iter):
+    """Update exponential moving average (ema) of model weights."""
+    update_period = cfg.TRAIN.EMA_UPDATE_PERIOD
+    if update_period is None or update_period == 0 or cur_iter % update_period != 0:
+        return
+    # Adjust alpha to be fairly independent of other parameters
+    total_batch_size = num_gpus * cfg.DATA.BATCH_SIZE
+    adjust = total_batch_size / cfg.TRAIN.EPOCHS * update_period
+    # print('ema adjust', adjust)
+    alpha = min(1.0, cfg.TRAIN.EMA_ALPHA * adjust)
+    # During warmup simply copy over weights instead of using ema
+    alpha = 1.0 if cur_epoch < cfg.TRAIN.WARMUP_EPOCHS else alpha
+    # Take ema of all parameters (not just named parameters)
+    params = unwrap_model(model).state_dict()
+    for name, param in unwrap_model(model_ema).state_dict().items():
+        param.copy_(param * (1.0 - alpha) + params[name] * alpha)
+
+
+class ProgressMeter(object):
+    def __init__(self, num_batches, meters, prefix=""):
+        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
+        self.meters = meters
+        self.prefix = prefix
+
+    def display(self, batch):
+        entries = [self.prefix + self.batch_fmtstr.format(batch)]
+        entries += [str(meter) for meter in self.meters]
+        print('\t'.join(entries))
+
+    def _get_batch_fmtstr(self, num_batches):
+        num_digits = len(str(num_batches // 1))
+        fmt = '{:' + str(num_digits) + 'd}'
+        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self, name, fmt=':f'):
+        self.name = name
+        self.fmt = fmt
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+    def __str__(self):
+        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
+        return fmtstr.format(**self.__dict__)
\ No newline at end of file
-- 
Gitee
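For reviewers who want to sanity-check the re-parameterization primitive that `repmlpnet.py` builds on, the sketch below fuses a training-time `conv_bn` branch with `fuse_bn` and compares it against the unfused branch. This is an illustrative sketch, not part of the patch: it assumes the patched `repmlpnet.py` is importable from the working directory, and the layer sizes and BatchNorm statistics are arbitrary values chosen for the check.

```python
import torch
import torch.nn.functional as F

from repmlpnet import conv_bn, fuse_bn

# Training-time branch: a depthwise 3x3 conv followed by BatchNorm, as built by conv_bn().
branch = conv_bn(in_channels=4, out_channels=4, kernel_size=3, stride=1, padding=1, groups=4)
branch.eval()  # fusion is exact only when BN uses its running statistics

# Give BN non-trivial statistics so the comparison is meaningful.
with torch.no_grad():
    branch.bn.running_mean.uniform_(-1.0, 1.0)
    branch.bn.running_var.uniform_(0.5, 2.0)
    branch.bn.weight.uniform_(0.5, 1.5)
    branch.bn.bias.uniform_(-1.0, 1.0)

# Fold BN into an equivalent conv kernel and bias.
fused_weight, fused_bias = fuse_bn(branch.conv, branch.bn)

x = torch.randn(2, 4, 8, 8)
with torch.no_grad():
    y_ref = branch(x)
    y_fused = F.conv2d(x, fused_weight, fused_bias, stride=1, padding=1, groups=4)

# Expected to be on the order of 1e-6 (floating-point error only).
print((y_ref - y_fused).abs().max().item())
```

`RepMLPBlock.get_equivalent_fc3` applies the same fusion to `fc3`/`fc3_bn` and to each reparam conv branch before folding everything into the single deploy-time FC during `local_inject`.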