diff --git a/research/cv/unisiam/README.md b/research/cv/unisiam/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..595bd8a1e1842dc4890ea47f433ba61b7ac1f81d
--- /dev/null
+++ b/research/cv/unisiam/README.md
@@ -0,0 +1,72 @@
+# Self-Supervision Can Be a Good Few-Shot Learner
+
+This is a [MindSpore](https://www.mindspore.cn/) implementation of the ECCV 2022 paper [Self-Supervision Can Be a Good Few-Shot Learner (UniSiam)](https://arxiv.org/abs/2207.09176).
+
+## Contents
+
+- [Contents](#contents)
+    - [UniSiam Description](#unisiam-description)
+    - [Dataset](#dataset)
+    - [Environment Requirements](#environment-requirements)
+    - [Script description](#script-description)
+    - [Acknowledgements](#acknowledgements)
+
+## [UniSiam Description](#contents)
+
+Existing few-shot learning (FSL) methods rely on training with a large labeled dataset, which prevents them from leveraging abundant unlabeled data. From an information-theoretic perspective, we propose an effective unsupervised FSL method that learns representations with self-supervision. Following the InfoMax principle, our method learns comprehensive representations by capturing the intrinsic structure of the data. Specifically, we maximize the mutual information (MI) between instances and their representations with a low-bias MI estimator to perform self-supervised pre-training. Unlike supervised pre-training, which focuses on the discriminable features of the seen classes, our self-supervised model has less bias toward the seen classes, resulting in better generalization for unseen classes. We explain that supervised pre-training and self-supervised pre-training are actually maximizing different MI objectives. Extensive experiments are further conducted to analyze their FSL performance under various training settings. Surprisingly, the results show that self-supervised pre-training can outperform supervised pre-training under appropriate conditions. Compared with state-of-the-art FSL methods, our approach achieves comparable performance on widely used FSL benchmarks without any labels of the base classes.
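+
+For reference, the pre-training objective implemented in `model/unisiam.py` combines a SimSiam-style alignment term with a temperature-scaled uniformity term weighted by `lamb`. The snippet below is a minimal NumPy restatement of that loss, for illustration only (the function and variable names are ours, not part of the code):
+
+```python
+import numpy as np
+
+def unisiam_loss(z, p, lamb=0.1, temp=2.0):
+    """Sketch of the UniSiam loss; z/p are projector/predictor outputs of two
+    augmented views stacked along the batch axis, each of shape (2B, d)."""
+    def l2n(x):
+        return x / np.linalg.norm(x, axis=-1, keepdims=True)
+
+    b = z.shape[0] // 2
+    z1, z2 = z[:b], z[b:]
+    p1, p2 = p[:b], p[b:]
+
+    # alignment: negative cosine similarity between the predictor output of one
+    # view and the projection of the other view (stop-gradient in the real code)
+    def align(p_, z_):
+        return -(l2n(p_) * l2n(z_)).sum(axis=1).mean()
+    loss_pos = 0.5 * (align(p1, z2) + align(p2, z1))
+
+    # uniformity: log-mean-exp of pairwise similarities over the 2B-2 negatives
+    zn = l2n(z)
+    mask = np.tile(1.0 - np.eye(b), (2, 2))   # zero out self and the positive pair
+    sim = (zn @ zn.T) * mask                  # masked entries contribute exp(0)=1, hence the -2 below
+    loss_neg = np.log(np.mean((np.exp(sim / temp).sum(axis=1) - 2) / (2 * b - 2)))
+
+    return loss_pos + lamb * loss_neg
+```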
+
+If you find this work useful, please cite:
+
+```
+@inproceedings{Lu2022Self,
+    title={Self-Supervision Can Be a Good Few-Shot Learner},
+    author={Lu, Yuning and Wen, Liangjian and Liu, Jianzhuang and Liu, Yajing and Tian, Xinmei},
+    booktitle={European Conference on Computer Vision (ECCV)},
+    year={2022}
+}
+```
+
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervision-can-be-a-good-few-shot/unsupervised-few-shot-image-classification-on)](https://paperswithcode.com/sota/unsupervised-few-shot-image-classification-on?p=self-supervision-can-be-a-good-few-shot)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervision-can-be-a-good-few-shot/unsupervised-few-shot-image-classification-on-1)](https://paperswithcode.com/sota/unsupervised-few-shot-image-classification-on-1?p=self-supervision-can-be-a-good-few-shot)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervision-can-be-a-good-few-shot/unsupervised-few-shot-image-classification-on-2)](https://paperswithcode.com/sota/unsupervised-few-shot-image-classification-on-2?p=self-supervision-can-be-a-good-few-shot)
+[![PWC](https://img.shields.io/endpoint.svg?url=https://paperswithcode.com/badge/self-supervision-can-be-a-good-few-shot/unsupervised-few-shot-image-classification-on-3)](https://paperswithcode.com/sota/unsupervised-few-shot-image-classification-on-3?p=self-supervision-can-be-a-good-few-shot)
+
+## [Dataset](#contents)
+
+- mini-ImageNet
+    - Download the mini-ImageNet dataset from [Google Drive](https://drive.google.com/file/d/1BfEBMlrf5UT4aNOoJPaa83CgbGWZAAAk/view?usp=sharing) and unzip it.
+    - Download the [split files](https://github.com/twitter/meta-learning-lstm/tree/master/data/miniImagenet) of mini-ImageNet, which were created by [Ravi and Larochelle](https://openreview.net/pdf?id=rJY0-Kcll).
+    - Move the split files to the folder `./split/miniImageNet` (the layout expected by the data loader is sketched at the end of this README).
+
+## [Environment Requirements](#contents)
+
+- Hardware (GPU)
+    - Prepare a hardware environment with a GPU.
+- Framework
+    - [MindSpore 1.7](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore Tutorials](https://www.mindspore.cn/tutorials/en/master/index.html)
+    - [MindSpore Python API](https://www.mindspore.cn/docs/en/master/index.html)
+
+## [Script description](#contents)
+
+Run
+
+```
+python ./train.py --data_path [your DATA FOLDER] --dataset [DATASET NAME] --backbone [BACKBONE] [--OPTIONARG]
+```
+
+For example, to train the UniSiam model with a ResNet-18 backbone and strong data augmentation on mini-ImageNet (V100):
+
+```
+python train.py \
+    --dataset miniImageNet \
+    --backbone resnet18 \
+    --lrd_step \
+    --data_path [your mini-imagenet-folder] \
+    --save_path [your save-folder]
+```
+
+## [Acknowledgements](#contents)
+
+Some code is borrowed from [SimSiam](https://github.com/facebookresearch/simsiam), [SupContrast](https://github.com/HobbitLong/SupContrast), [(unofficial) SimCLR](https://github.com/AndrewAtanov/simclr-pytorch), and [RFS](https://github.com/WangYueFt/rfs).
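+
+For reference, the loader in `dataset/miniImageNet.py` joins `--data_path` with `images/<file name from the split csv>` and reads the split files from `./split/miniImageNet/`, so the expected layout is roughly the following (a sketch; only the `images/` folder name and the csv file names are fixed by the code):
+
+```
+[your mini-imagenet-folder]/
+└── images/            # all mini-ImageNet image files listed in the split csv files
+
+./split/miniImageNet/
+├── train.csv
+├── val.csv
+└── test.csv
+```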
diff --git a/research/cv/unisiam/config/unisiam/mini/r10.sh b/research/cv/unisiam/config/unisiam/mini/r10.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dffba0f70766e13ea4a8e02b35dbf00c9de6fb8e
--- /dev/null
+++ b/research/cv/unisiam/config/unisiam/mini/r10.sh
@@ -0,0 +1,6 @@
+python train.py \
+    --dataset miniImageNet \
+    --backbone resnet10 \
+    --lrd_step \
+    --data_path [your mini-imagenet-folder] \
+    --save_path [your save-folder]
\ No newline at end of file
diff --git a/research/cv/unisiam/config/unisiam/mini/r18.sh b/research/cv/unisiam/config/unisiam/mini/r18.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0f08fd38af28c8d8e398558ec9d62a54134ca6a2
--- /dev/null
+++ b/research/cv/unisiam/config/unisiam/mini/r18.sh
@@ -0,0 +1,6 @@
+python train.py \
+    --dataset miniImageNet \
+    --backbone resnet18 \
+    --lrd_step \
+    --data_path [your mini-imagenet-folder] \
+    --save_path [your save-folder]
\ No newline at end of file
diff --git a/research/cv/unisiam/dataset/miniImageNet.py b/research/cv/unisiam/dataset/miniImageNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..4735e7c24cd609d44452382ffcca115eff6a601d
--- /dev/null
+++ b/research/cv/unisiam/dataset/miniImageNet.py
@@ -0,0 +1,41 @@
+import os
+import csv
+
+import numpy as np
+
+
+class miniImageNet:
+    """mini-ImageNet as a plain iterable source: yields (raw image bytes, int label)."""
+
+    def __init__(self, data_path, split_path, partition='train'):
+        self.data_root = data_path
+        self.partition = partition
+
+        file_path = os.path.join(split_path, 'miniImageNet', '{}.csv'.format(self.partition))
+        self.imgs, self.labels = self._read_csv(file_path)
+
+    def _read_csv(self, file_path):
+        imgs = []
+        labels = []
+        labels_name = []
+        with open(file_path, 'r') as f:
+            reader = csv.reader(f)
+            for i, row in enumerate(reader):
+                if i == 0:
+                    continue
+                img, label = row[0], row[1]
+                img = os.path.join(self.data_root, 'images/{}'.format(img))
+                imgs.append(img)
+                if label not in labels_name:
+                    labels_name.append(label)
+                labels.append(labels_name.index(label))
+        return imgs, labels
+
+    def __getitem__(self, item):
+        img = np.fromfile(self.imgs[item], np.uint8)  # raw bytes; decoded later in the dataset pipeline
+        target = self.labels[item]
+        return img, target
+
+    def __len__(self):
+        return len(self.labels)
\ No newline at end of file
diff --git a/research/cv/unisiam/evaluate.py b/research/cv/unisiam/evaluate.py
new file mode 100644
index 0000000000000000000000000000000000000000..850df63f4d7164ea6b7fbe62545e9d33ff98c6c7
--- /dev/null
+++ b/research/cv/unisiam/evaluate.py
@@ -0,0 +1,71 @@
+import mindspore as ms
+
+import numpy as np
+import math
+
+from sklearn import metrics
+from sklearn.linear_model import LogisticRegression
+from sklearn.svm import LinearSVC
+
+
+def evaluate_fewshot(
+        encoder, loader, n_way=5, n_shots=[1, 5], n_query=15, n_tasks=3000, classifier='LR', power_norm=True):
+
+    # extract L2-normalized (optionally power-normalized) features for the whole split
+    encoder.set_train(False)
+    features_temp = np.zeros((loader.get_dataset_size()*loader.get_batch_size(), encoder.num_features), dtype=np.float32)
+    labels_temp = np.zeros(loader.get_dataset_size()*loader.get_batch_size(), dtype=np.int64)
+    start_idx = 0
+    for i, batchs in enumerate(loader):
+        images, labels = batchs[0], batchs[1]
+        features = encoder(images)
+        features /= ms.numpy.norm(features, axis=-1, keepdims=True)
+        if power_norm: features = features ** 0.5
+        bsz = features.shape[0]
+        features_temp[start_idx:start_idx+bsz, :] = features.asnumpy()
+        labels_temp[start_idx:start_idx+bsz] = labels.asnumpy()
+        start_idx += bsz
+
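+    # the last batch can be smaller than the nominal batch size, so trim the
+    # pre-allocated feature/label buffers to the number of rows actually filled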
features_temp = features_temp[:(loader.get_dataset_size()-1)*loader.get_batch_size()+bsz] + labels_temp = labels_temp[:(loader.get_dataset_size()-1)*loader.get_batch_size()+bsz] + + # few-shot evaluation + catlocs = [np.argwhere(labels_temp == c).reshape(-1) for c in range(labels_temp.max() + 1)] + + def get_select_index(n_cls, n_samples): + episode = [] + classes = np.random.choice(len(catlocs), n_cls, replace=False) + episode = [ np.random.choice(catlocs[c], n_samples, replace=False) for c in classes ] + return np.concatenate(episode).reshape((n_cls, n_samples)) + + accs = {} + for n_shot in n_shots: + accs[f'{n_shot}-shot'] = [] + for _ in range(n_tasks): + select_idx = get_select_index(n_way, n_shot+n_query) + sup_idx = select_idx[:, :n_shot].reshape(-1) + qry_idx = select_idx[:, n_shot:].reshape(-1) + sup_f, qry_f = features_temp[sup_idx].reshape(n_way*n_shot, -1), features_temp[qry_idx] + sup_y = np.arange(n_way)[:,None].repeat(n_shot,1).reshape(-1) + qry_y = np.arange(n_way)[:,None].repeat(n_query,1).reshape(-1) + if classifier == 'LR': + clf = LogisticRegression(penalty='l2', + random_state=0, + C=1.0, + solver='lbfgs', + max_iter=1000, + multi_class='multinomial') + elif classifier == 'SVM': + clf = LinearSVC(C=1.0) + clf.fit(sup_f, sup_y) + qry_pred = clf.predict(qry_f) + acc = metrics.accuracy_score(qry_y, qry_pred) + accs[f'{n_shot}-shot'].append(acc) + + for n_shot in n_shots: + acc = np.array(accs[f'{n_shot}-shot']) + mean = acc.mean() + std = acc.std() + c95 = 1.96*std/math.sqrt(acc.shape[0]) + print('classifier: {}, power_norm: {}, {}-way {}-shot acc: {:.2f}+{:.2f}'.format( + classifier, power_norm, n_way, n_shot, mean*100, c95*100)) + return \ No newline at end of file diff --git a/research/cv/unisiam/model/resnet.py b/research/cv/unisiam/model/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..4e3045bfa00327e8e6f2055d99041ed58656f3fb --- /dev/null +++ b/research/cv/unisiam/model/resnet.py @@ -0,0 +1,263 @@ +""" +copy from mindcv (https://github.com/mindspore-ecosystem/mindcv/blob/72610c30c78c4d375d7035184d7c056274a2381e/mindcv/models/resnet.py) +MindSpore implementation of `ResNet`. +Refer to Deep Residual Learning for Image Recognition. 
+""" + +from typing import Optional, Type, List, Union + +import math + +from mindspore import nn, Tensor +import mindspore as ms +from mindspore.common.initializer import HeNormal + + + +class BasicBlock(nn.Cell): + """define the basic block of resnet""" + expansion: int = 1 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1, + groups: int = 1, + base_width: int = 64, + norm: Optional[nn.Cell] = None, + down_sample: Optional[nn.Cell] = None + ) -> None: + super().__init__() + if norm is None: + norm = nn.BatchNorm2d + assert groups == 1, 'BasicBlock only supports groups=1' + assert base_width == 64, 'BasicBlock only supports base_width=64' + + self.conv1 = nn.Conv2d(in_channels, channels, kernel_size=3, + stride=stride, padding=1, pad_mode='pad') + self.bn1 = norm(channels) + self.relu = nn.ReLU() + self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, + stride=1, padding=1, pad_mode='pad') + self.bn2 = norm(channels) + self.down_sample = down_sample + + def construct(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.down_sample is not None: + identity = self.down_sample(x) + + out += identity + out = self.relu(out) + + return out + + +class Bottleneck(nn.Cell): + """ + Bottleneck here places the stride for downsampling at 3x3 convolution(self.conv2) as torchvision does, + while original implementation places the stride at the first 1x1 convolution(self.conv1) + """ + expansion: int = 4 + + def __init__(self, + in_channels: int, + channels: int, + stride: int = 1, + groups: int = 1, + base_width: int = 64, + norm: Optional[nn.Cell] = None, + down_sample: Optional[nn.Cell] = None + ) -> None: + super().__init__() + if norm is None: + norm = nn.BatchNorm2d + + width = int(channels * (base_width / 64.0)) * groups + + self.conv1 = nn.Conv2d(in_channels, width, kernel_size=1, stride=1) + self.bn1 = norm(width) + self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride, + padding=1, pad_mode='pad', group=groups) + self.bn2 = norm(width) + self.conv3 = nn.Conv2d(width, channels * self.expansion, + kernel_size=1, stride=1) + self.bn3 = norm(channels * self.expansion) + self.relu = nn.ReLU() + self.down_sample = down_sample + + def construct(self, x: Tensor) -> Tensor: + identity = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + out = self.relu(out) + + out = self.conv3(out) + out = self.bn3(out) + + if self.down_sample is not None: + identity = self.down_sample(x) + + out += identity + out = self.relu(out) + + return out + + +class ResNet(nn.Cell): + r"""ResNet model class, based on + `"Deep Residual Learning for Image Recognition" `_ + Args: + block: block of resnet. + layers: number of layers of each stage. + num_classes: number of classification classes. Default: 1000. + in_channels: number the channels of the input. Default: 3. + groups: number of groups for group conv in blocks. Default: 1. + base_width: base width of pre group hidden channel in blocks. Default: 64. + norm: normalization layer in blocks. Default: None. 
+ """ + + def __init__(self, + block: Type[Union[BasicBlock, Bottleneck]], + layers: List[int], + num_classes: int = 1000, + in_channels: int = 3, + groups: int = 1, + base_width: int = 64, + norm: Optional[nn.Cell] = None + ) -> None: + super().__init__() + if norm is None: + norm = nn.BatchNorm2d + + self.norm: nn.Cell = norm # add type hints to make pylint happy + self.input_channels = 64 + self.groups = groups + self.base_with = base_width + + self.conv1 = nn.Conv2d(in_channels, self.input_channels, kernel_size=7, + stride=2, pad_mode='pad', padding=3) + self.bn1 = norm(self.input_channels) + self.relu = nn.ReLU() + self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, pad_mode='same') + self.layer1 = self._make_layer(block, 64, layers[0]) + self.layer2 = self._make_layer(block, 128, layers[1], stride=2) + self.layer3 = self._make_layer(block, 256, layers[2], stride=2) + self.layer4 = self._make_layer(block, 512, layers[3], stride=2) + + self.pool = ms.ops.AdaptiveAvgPool2D((1,1)) + self.flatten = ms.ops.Flatten() + self.num_features = 512 * block.expansion + self.fc = nn.Dense(self.num_features, num_classes) + + for _, cell in self.cells_and_names(): + if isinstance(cell, nn.Conv2d): + cell.weight.set_data(ms.common.initializer.initializer( + ms.common.initializer.HeNormal(negative_slope=0, mode='fan_out', nonlinearity='relu'), + cell.weight.shape, cell.weight.dtype)) + elif isinstance(cell, (nn.BatchNorm2d, nn.GroupNorm)): + cell.gamma.set_data(ms.common.initializer.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) + cell.beta.set_data(ms.common.initializer.initializer("zeros", cell.beta.shape, cell.beta.dtype)) + elif isinstance(cell, (nn.Dense)): + cell.weight.set_data(ms.common.initializer.initializer( + ms.common.initializer.HeUniform(negative_slope=math.sqrt(5)), + cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(ms.common.initializer.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + + + def _make_layer(self, + block: Type[Union[BasicBlock, Bottleneck]], + channels: int, + block_nums: int, + stride: int = 1 + ) -> nn.SequentialCell: + """build model depending on cfgs""" + down_sample = None + + if stride != 1 or self.input_channels != channels * block.expansion: + down_sample = nn.SequentialCell([ + nn.Conv2d(self.input_channels, channels * block.expansion, kernel_size=1, stride=stride), + self.norm(channels * block.expansion) + ]) + + layers = [] + layers.append( + block( + self.input_channels, + channels, + stride=stride, + down_sample=down_sample, + groups=self.groups, + base_width=self.base_with, + norm=self.norm + ) + ) + self.input_channels = channels * block.expansion + + for _ in range(1, block_nums): + layers.append( + block( + self.input_channels, + channels, + groups=self.groups, + base_width=self.base_with, + norm=self.norm + ) + ) + + return nn.SequentialCell(layers) + + def _forward_impl(self, x: Tensor) -> Tensor: + """Network forward feature extraction.""" + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + x = self.max_pool(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + x = self.pool(x) + x = self.flatten(x) + if self.fc is not None: + x = self.fc(x) + return x + + def construct(self, x: Tensor) -> Tensor: + x = self._forward_impl(x) + return x + + +def resnet10(): + model = ResNet(BasicBlock, [1, 1, 1, 1]) + return model + + +def resnet18(): + model = ResNet(BasicBlock, [2, 2, 2, 2]) + return model + + +def resnet34(): + model = ResNet(BasicBlock, [3, 4, 6, 3]) + return model + + +def 
resnet50(): + model = ResNet(Bottleneck, [3, 4, 6, 3]) + return model + diff --git a/research/cv/unisiam/model/unisiam.py b/research/cv/unisiam/model/unisiam.py new file mode 100644 index 0000000000000000000000000000000000000000..d42c943e27e0f7c6456a56146370aa5953a0964b --- /dev/null +++ b/research/cv/unisiam/model/unisiam.py @@ -0,0 +1,101 @@ +import math + +import mindspore as ms +import mindspore.ops as ops +import mindspore.numpy as np + +from mindspore import nn + + +class UniSiam(nn.Cell): + def __init__(self, encoder, lamb=0.1, temp=2.0, dim_hidden=None, dist=False, dim_out=2048): + super().__init__() + self.encoder = encoder + self.encoder.fc = None + + dim_in = encoder.num_features + dim_hidden = dim_in if dim_hidden is None else dim_hidden + + self.proj = nn.SequentialCell([ + nn.Dense(dim_in, dim_hidden), + nn.BatchNorm1d(dim_hidden), + nn.ReLU(), + nn.Dense(dim_hidden, dim_hidden), + nn.BatchNorm1d(dim_hidden), + nn.ReLU(), + nn.Dense(dim_hidden, dim_hidden), + nn.BatchNorm1d(dim_hidden),]) + self.pred = nn.SequentialCell([ + nn.Dense(dim_hidden, dim_hidden//4), + nn.BatchNorm1d(dim_hidden//4), + nn.ReLU(), + nn.Dense(dim_hidden//4, dim_hidden)]) + + if dist: + self.pred_dist = nn.SequentialCell([ + nn.Dense(dim_in, dim_out), + nn.BatchNorm1d(dim_out), + nn.ReLU(), + nn.Dense(dim_out, dim_out), + nn.BatchNorm1d(dim_out), + nn.ReLU(), + nn.Dense(dim_out, dim_out), + nn.BatchNorm1d(dim_out), + nn.ReLU(), + nn.Dense(dim_out, dim_out//4), + nn.BatchNorm1d(dim_out//4), + nn.ReLU(), + nn.Dense(dim_out//4, dim_out)]) + + self.lamb = lamb + self.temp = temp + + for _, cell in self.cells_and_names(): + if isinstance(cell, (nn.BatchNorm2d)): + cell.gamma.set_data(ms.common.initializer.initializer("ones", cell.gamma.shape, cell.gamma.dtype)) + cell.beta.set_data(ms.common.initializer.initializer("zeros", cell.beta.shape, cell.beta.dtype)) + elif isinstance(cell, (nn.Dense)): + cell.weight.set_data(ms.common.initializer.initializer( + ms.common.initializer.HeUniform(negative_slope=math.sqrt(5)), + cell.weight.shape, cell.weight.dtype)) + cell.bias.set_data(ms.common.initializer.initializer("zeros", cell.bias.shape, cell.bias.dtype)) + + def construct(self, x, z_dist=None): + + f = self.encoder(x) + z = self.proj(f) + p = self.pred(z) + z1, z2 = np.split(z, axis=0, indices_or_sections=2) + p1, p2 = np.split(p, axis=0, indices_or_sections=2) + + loss_pos = (self.pos(p1, z2)+self.pos(p2,z1))/2 + loss_neg = self.neg(z) + loss = loss_pos + self.lamb * loss_neg + + if z_dist is not None: + p_dist = self.pred_dist(f) + loss_dist = self.pos(p_dist, z_dist) + loss = 0.5 * loss + 0.5 * loss_dist + + std = self.std(z) + + return loss, loss_pos, loss_neg, std + + def std(self, z): + return (z/np.norm(z, axis=1, keepdims=True)).std(axis=0).mean() + + def pos(self, p, z): + z = ops.stop_gradient(z) + z /= np.norm(z, axis=1, keepdims=True) + p /= np.norm(p, axis=1, keepdims=True) + return -(p*z).sum(axis=1).mean() + + + def neg(self, z): + batch_size = z.shape[0] //2 + n_neg = z.shape[0] - 2 + z /= np.norm(z, axis=-1, keepdims=True) + mask = 1-ops.eye(batch_size, batch_size, ms.float32) + mask = np.tile(mask, (2,2)) + out = np.matmul(z, z.T) * mask + return np.log(np.mean( (np.exp(out/self.temp).sum(axis=1)-2) / n_neg )) \ No newline at end of file diff --git a/research/cv/unisiam/split/.keep b/research/cv/unisiam/split/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/research/cv/unisiam/split/miniImageNet/.keep 
b/research/cv/unisiam/split/miniImageNet/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/research/cv/unisiam/train.py b/research/cv/unisiam/train.py new file mode 100644 index 0000000000000000000000000000000000000000..92e56e203f54501702d3671ad02be1e5078aeb42 --- /dev/null +++ b/research/cv/unisiam/train.py @@ -0,0 +1,235 @@ +from __future__ import print_function + +import os +import sys +import time +import math +import random +import argparse +import moxing as mox +import numpy as np + +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +import mindspore.dataset as ds +import mindspore.communication as comm +import mindspore.dataset.vision.c_transforms as c_transforms + +from model.unisiam import UniSiam +from model.resnet import resnet10, resnet18, resnet34, resnet50 +from dataset.miniImageNet import miniImageNet +from evaluate import evaluate_fewshot +from transform.build_transform import build_transform +from util import AverageMeter, adjust_learning_rate + + +def parse_option(): + parser = argparse.ArgumentParser('argument for training') + + parser.add_argument('--save_path', type=str, default='/home/ma-user/work/ufsl/code_mindspore/src/output', help='path for saving') + parser.add_argument('--data_path', type=str, default='/home/ma-user/work/dataset', help='path to dataset') + parser.add_argument('--dataset', type=str, default='miniImageNet', choices=['miniImageNet'], help='dataset') + parser.add_argument('--print_freq', type=int, default=10, help='print frequency') + parser.add_argument('--num_workers', type=int, default=6, help='num of workers to use') + parser.add_argument('--seed', type=int, default=42, help='random seed') + + # optimization setting + parser.add_argument('--lr', type=float, default=0.3, help='learning rate') + parser.add_argument('--wd', type=float, default=1e-4, help='weight decay') + parser.add_argument('--momentum', type=float, default=0.9, help='momentum') + parser.add_argument('--batch_size', type=int, default=256, help='batch_size') + parser.add_argument('--epochs', type=int, default=400, help='number of training epochs') + parser.add_argument('--lrd_step', action='store_true', help='decay learning rate per step') + + # self-supervision setting + parser.add_argument('--backbone', type=str, default='resnet18', choices=['resnet10', 'resnet18', 'resnet34', 'resnet50']) + parser.add_argument('--size', type=int, default=224, help='input size') + parser.add_argument('--temp', type=float, default=2.0, help='temperature for loss function') + parser.add_argument('--lamb', type=float, default=0.1, help='lambda for uniform loss') + parser.add_argument('--dim_hidden', type=int, default=None, help='hidden dim. 
of projection') + + # few-shot evaluation setting + parser.add_argument('--n_way', type=int, default=5, help='n_way') + parser.add_argument('--n_query', type=int, default=15, help='n_query') + parser.add_argument('--n_test_task', type=int, default=3000, help='total test few-shot episodes') + parser.add_argument('--test_batch_size', type=int, default=20, help='episode_batch_size') + + args = parser.parse_args() + + args.lr = args.lr * args.batch_size / 256 + if (args.save_path is not None) and (not os.path.isdir(args.save_path)): os.makedirs(args.save_path) + args.split_path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'split') + + return args + + +def build_train_loader(args, device_num=None, rank_id=None): + train_transform = build_transform(args.size) + + if args.dataset == 'miniImageNet': + train_dataset = miniImageNet( + data_path=args.data_path, + split_path=args.split_path, + partition='train') + else: + raise ValueError(args.dataset) + + def copy_column(x, y): + return x, x, y + + train_dataset = ds.GeneratorDataset( + train_dataset, ["image", "label"], shuffle=True, num_parallel_workers=args.num_workers, num_shards=device_num, shard_id=rank_id) + train_dataset = train_dataset.map( + operations=copy_column, input_columns=["image", "label"], output_columns=["image1", "image2", "label"], + column_order=["image1", "image2", "label"], num_parallel_workers=args.num_workers) + train_dataset = train_dataset.map(operations=train_transform, input_columns=["image1"], num_parallel_workers=args.num_workers, python_multiprocessing=True) + train_dataset = train_dataset.map(operations=train_transform, input_columns=["image2"], num_parallel_workers=args.num_workers, python_multiprocessing=True) + train_dataset = train_dataset.batch(args.batch_size) + + return train_dataset + + +def build_fewshot_loader(args, mode='test', device_num=None, rank_id=None): + + assert mode in ['train', 'val', 'test'] + + resize_dict = {160: 182, 224: 256, 288: 330, 320:366, 384:438} + resize_size = resize_dict[args.size] + print('Image Size: {}({})'.format(args.size, resize_size)) + + test_transform = [ + c_transforms.Decode(), + c_transforms.Resize(resize_size), + c_transforms.CenterCrop(args.size), + c_transforms.Normalize(mean=(0.485*255, 0.456*255, 0.406*255), std=(0.229*255, 0.224*255, 0.225*255)), + c_transforms.HWC2CHW(), + ] + print('test_transform: ', test_transform) + + if args.dataset == 'miniImageNet': + test_dataset = miniImageNet( + data_path=args.data_path, + split_path=args.split_path, + partition=mode) + else: + raise ValueError(args.dataset) + + test_dataset = ds.GeneratorDataset( + test_dataset, ["image", "label"], shuffle=False, num_parallel_workers=args.num_workers, num_shards=device_num, shard_id=rank_id) + test_dataset = test_dataset.map(operations=test_transform, input_columns=["image"], num_parallel_workers=args.num_workers, python_multiprocessing=True) + test_dataset = test_dataset.batch(args.batch_size) + + return test_dataset + + +def build_model(args): + model_dict = {'resnet10': resnet10, 'resnet18': resnet18, 'resnet34': resnet34, 'resnet50': resnet50} + encoder = model_dict[args.backbone]() + model = UniSiam(encoder=encoder, lamb=args.lamb, temp=args.temp, dim_hidden=args.dim_hidden) + print(model) + return model + + +class TrainOneStep(nn.Cell): + def __init__(self, model, optimizer): + super(TrainOneStep, self).__init__() + self.model = model + self.optimizer = optimizer + self.weights = optimizer.parameters + self.grad = ops.GradOperation(get_by_list=True) + + def 
construct(self, images1, images2): + + def forward_fn(images1, images2): + images = ops.Concat(axis=0)((images1, images2)) + loss, loss_pos, loss_neg, std = self.model(images) + return loss, loss_pos, loss_neg, std + + loss, loss_pos, loss_neg, std = forward_fn(images1, images2) + grads = self.grad(forward_fn, self.weights)(images1, images2) + loss = ops.depend(loss, self.optimizer(grads)) + return loss, loss_pos, loss_neg, std + + +def train_one_epoch(train_loader, model, epoch, args): + """one epoch training""" + + model.set_train() + + batch_time = AverageMeter() + data_time = AverageMeter() + loss_hist = AverageMeter() + loss_pos_hist = AverageMeter() + loss_neg_hist = AverageMeter() + std_hist = AverageMeter() + + end = time.time() + + n_iter = train_loader.get_dataset_size() + + for idx, batchs in enumerate(train_loader): + data_time.update(time.time() - end) + + bsz = batchs[0].shape[0] + if args.lrd_step: adjust_learning_rate(args, model.optimizer, idx*1.0/n_iter+epoch, args.epochs) + + loss, loss_pos, loss_neg, std = model(batchs[0], batchs[1]) + + loss_hist.update(loss.asnumpy(), bsz) + loss_pos_hist.update(loss_pos.asnumpy(), bsz) + loss_neg_hist.update(loss_neg.asnumpy(), bsz) + std_hist.update(std.asnumpy(), bsz) + batch_time.update(time.time() - end) + end = time.time() + + if (idx + 1) % args.print_freq == 0: + print('Train: [{0}][{1}/{2}]\t' + 'BT {batch_time.val:.3f} ({batch_time.avg:.3f})\t' + 'DT {data_time.val:.3f} ({data_time.avg:.3f})\t' + 'loss {loss.val:.3f} ({loss.avg:.3f})\t' + 'loss_pos {lossp.val:.3f} ({lossp.avg:.3f})\t' + 'loss_neg {lossn.val:.3f} ({lossn.avg:.3f})\t' + 'std {std.val:.3f} ({std.avg:.3f})'.format( + epoch, idx + 1, train_loader.get_dataset_size(), batch_time=batch_time, + data_time=data_time, loss=loss_hist, lossp=loss_pos_hist, lossn=loss_neg_hist, std=std_hist)) + sys.stdout.flush() + + return loss_hist.avg + + +def main(): + args = parse_option() + print("{}".format(args).replace(', ', ',\n')) + + ms.set_seed(args.seed) + np.random.seed(args.seed) + random.seed(args.seed) + + ms.set_context(mode=ms.GRAPH_MODE, device_target='GPU') + + train_loader = build_train_loader(args) + test_loader = build_fewshot_loader(args, 'test') + + model = build_model(args) + optimizer = nn.SGD(model.trainable_params(), learning_rate=args.lr, weight_decay=args.wd, momentum=0.9) + train_model = TrainOneStep(model, optimizer) + + for epoch in range(args.epochs): + + if not args.lrd_step: adjust_learning_rate(args, optimizer, epoch+1, args.epochs) + + time1 = time.time() + _ = train_one_epoch(train_loader, train_model, epoch, args) + time2 = time.time() + print('epoch {}, total time {:.2f}'.format(epoch, time2 - time1)) + + evaluate_fewshot( + model.encoder, test_loader, n_way=args.n_way, n_shots=[1,5], n_query=args.n_query, n_tasks=args.n_test_task, classifier='LR', power_norm=True) + + if args.save_path is not None: + save_file = os.path.join(args.save_path, 'last.pth') + ms.save_checkpoint(model, save_file) + +if __name__ == '__main__': + main() diff --git a/research/cv/unisiam/transform/build_transform.py b/research/cv/unisiam/transform/build_transform.py new file mode 100644 index 0000000000000000000000000000000000000000..7693eea4d1e7b2b8c7cde01e8a058e0592bc354a --- /dev/null +++ b/research/cv/unisiam/transform/build_transform.py @@ -0,0 +1,47 @@ +import random +import numpy as np +from PIL import ImageFilter + +from mindspore.dataset import vision +import mindspore.dataset.vision.c_transforms as c_transforms +import mindspore.dataset.vision.py_transforms as 
py_transforms +import mindspore.dataset.transforms as transforms + +from .rand_augmentation import rand_augment_transform + + +class GaussianBlur(object): + """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709""" + def __init__(self, sigma=[.1, 2.]): + self.sigma = sigma + + def __call__(self, x): + sigma = random.uniform(self.sigma[0], self.sigma[1]) + x = x.filter(ImageFilter.GaussianBlur(radius=sigma)) + return x + + +def build_transform(input_size): + + rgb_mean = (0.485, 0.456, 0.406) + ra_params = dict( + translate_const=int(input_size * 0.45), + img_mean=tuple([min(255, round(255 * x)) for x in rgb_mean]), + ) + train_transform = [ + c_transforms.RandomCropDecodeResize(size=input_size, scale=(0.2, 1.)), + transforms.c_transforms.RandomApply([c_transforms.RandomColorAdjust(0.4, 0.4, 0.4, 0.1)], 0.8), + py_transforms.ToPIL(), + py_transforms.RandomGrayscale(0.2), + transforms.py_transforms.RandomApply([GaussianBlur([.1, 2.])], 0.5), + rand_augment_transform('rand-n2-m10-mstd0.5', ra_params, use_cmc=False), + np.array, + c_transforms.RandomHorizontalFlip(), + c_transforms.RandomVerticalFlip(), + c_transforms.Normalize(mean=(0.485*255, 0.456*255, 0.406*255), std=(0.229*255, 0.224*255, 0.225*255)), + c_transforms.HWC2CHW(), + ] + + print('train transform: ', train_transform) + #train_transform = transforms.py_transforms.Compose(train_transform) + return train_transform \ No newline at end of file diff --git a/research/cv/unisiam/transform/rand_augmentation.py b/research/cv/unisiam/transform/rand_augmentation.py new file mode 100644 index 0000000000000000000000000000000000000000..f8269b072afa68fccea66e32ff62e4e1b575de6c --- /dev/null +++ b/research/cv/unisiam/transform/rand_augmentation.py @@ -0,0 +1,444 @@ +""" AutoAugment and RandAugment +""" +import random +import math +import re +from PIL import Image, ImageOps, ImageEnhance +import PIL +import numpy as np + + +_PIL_VER = tuple([int(x) for x in PIL.__version__.split('.')[:2]]) + +_FILL = (128, 128, 128) + +# This signifies the max integer that the controller RNN could predict for the +# augmentation scheme. +_MAX_LEVEL = 10. 
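+# Op strengths are specified on a 0.._MAX_LEVEL scale; the _*_level_to_arg helpers
+# below map that scale to each op's own parameter range (e.g. rotation degrees,
+# shear factor, posterize bits).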
+ +_HPARAMS_DEFAULT = dict( + translate_const=250, + img_mean=_FILL, +) + +_RANDOM_INTERPOLATION = (Image.BILINEAR, Image.BICUBIC) + + +def _interpolation(kwargs): + interpolation = kwargs.pop('resample', Image.BILINEAR) + if isinstance(interpolation, (list, tuple)): + return random.choice(interpolation) + else: + return interpolation + + +def _check_args_tf(kwargs): + if 'fillcolor' in kwargs and _PIL_VER < (5, 0): + kwargs.pop('fillcolor') + kwargs['resample'] = _interpolation(kwargs) + + +def shear_x(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, factor, 0, 0, 1, 0), **kwargs) + + +def shear_y(img, factor, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, factor, 1, 0), **kwargs) + + +def translate_x_rel(img, pct, **kwargs): + pixels = pct * img.size[0] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_rel(img, pct, **kwargs): + pixels = pct * img.size[1] + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def translate_x_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, pixels, 0, 1, 0), **kwargs) + + +def translate_y_abs(img, pixels, **kwargs): + _check_args_tf(kwargs) + return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, pixels), **kwargs) + + +def rotate(img, degrees, **kwargs): + _check_args_tf(kwargs) + if _PIL_VER >= (5, 2): + return img.rotate(degrees, **kwargs) + elif _PIL_VER >= (5, 0): + w, h = img.size + post_trans = (0, 0) + rotn_center = (w / 2.0, h / 2.0) + angle = -math.radians(degrees) + matrix = [ + round(math.cos(angle), 15), + round(math.sin(angle), 15), + 0.0, + round(-math.sin(angle), 15), + round(math.cos(angle), 15), + 0.0, + ] + + def transform(x, y, matrix): + (a, b, c, d, e, f) = matrix + return a * x + b * y + c, d * x + e * y + f + + matrix[2], matrix[5] = transform( + -rotn_center[0] - post_trans[0], -rotn_center[1] - post_trans[1], matrix + ) + matrix[2] += rotn_center[0] + matrix[5] += rotn_center[1] + return img.transform(img.size, Image.AFFINE, matrix, **kwargs) + else: + return img.rotate(degrees, resample=kwargs['resample']) + + +def auto_contrast(img, **__): + return ImageOps.autocontrast(img) + + +def invert(img, **__): + return ImageOps.invert(img) + + +def identity(img, **__): + return img + + +def equalize(img, **__): + return ImageOps.equalize(img) + + +def solarize(img, thresh, **__): + return ImageOps.solarize(img, thresh) + + +def solarize_add(img, add, thresh=128, **__): + lut = [] + for i in range(256): + if i < thresh: + lut.append(min(255, i + add)) + else: + lut.append(i) + if img.mode in ("L", "RGB"): + if img.mode == "RGB" and len(lut) == 256: + lut = lut + lut + lut + return img.point(lut) + else: + return img + + +def posterize(img, bits_to_keep, **__): + if bits_to_keep >= 8: + return img + return ImageOps.posterize(img, bits_to_keep) + + +def contrast(img, factor, **__): + return ImageEnhance.Contrast(img).enhance(factor) + + +def color(img, factor, **__): + return ImageEnhance.Color(img).enhance(factor) + + +def brightness(img, factor, **__): + return ImageEnhance.Brightness(img).enhance(factor) + + +def sharpness(img, factor, **__): + return ImageEnhance.Sharpness(img).enhance(factor) + + +def _randomly_negate(v): + """With 50% prob, negate the value""" + return -v if random.random() > 0.5 else v + + +def _rotate_level_to_arg(level, 
_hparams): + # range [-30, 30] + level = (level / _MAX_LEVEL) * 30. + level = _randomly_negate(level) + return level, + + +def _enhance_level_to_arg(level, _hparams): + # range [0.1, 1.9] + return (level / _MAX_LEVEL) * 1.8 + 0.1, + + +def _shear_level_to_arg(level, _hparams): + # range [-0.3, 0.3] + level = (level / _MAX_LEVEL) * 0.3 + level = _randomly_negate(level) + return level, + + +def _translate_abs_level_to_arg(level, hparams): + translate_const = hparams['translate_const'] + level = (level / _MAX_LEVEL) * float(translate_const) + level = _randomly_negate(level) + return level, + + +def _translate_rel_level_to_arg(level, _hparams): + # range [-0.45, 0.45] + level = (level / _MAX_LEVEL) * 0.45 + level = _randomly_negate(level) + return level, + + +def _posterize_original_level_to_arg(level, _hparams): + # As per original AutoAugment paper description + # range [4, 8], 'keep 4 up to 8 MSB of image' + return int((level / _MAX_LEVEL) * 4) + 4, + + +def _posterize_research_level_to_arg(level, _hparams): + # As per Tensorflow models research and UDA impl + # range [4, 0], 'keep 4 down to 0 MSB of original image' + return 4 - int((level / _MAX_LEVEL) * 4), + + +def _posterize_tpu_level_to_arg(level, _hparams): + # As per Tensorflow TPU EfficientNet impl + # range [0, 4], 'keep 0 up to 4 MSB of original image' + return int((level / _MAX_LEVEL) * 4), + + +def _solarize_level_to_arg(level, _hparams): + # range [0, 256] + return int((level / _MAX_LEVEL) * 256), + + +def _solarize_add_level_to_arg(level, _hparams): + # range [0, 110] + return int((level / _MAX_LEVEL) * 110), + + +LEVEL_TO_ARG = { + 'AutoContrast': None, + 'Equalize': None, + 'Invert': None, + 'Identity': None, + 'Rotate': _rotate_level_to_arg, + # There are several variations of the posterize level scaling in various Tensorflow/Google repositories/papers + 'PosterizeOriginal': _posterize_original_level_to_arg, + 'PosterizeResearch': _posterize_research_level_to_arg, + 'PosterizeTpu': _posterize_tpu_level_to_arg, + 'Solarize': _solarize_level_to_arg, + 'SolarizeAdd': _solarize_add_level_to_arg, + 'Color': _enhance_level_to_arg, + 'Contrast': _enhance_level_to_arg, + 'Brightness': _enhance_level_to_arg, + 'Sharpness': _enhance_level_to_arg, + 'ShearX': _shear_level_to_arg, + 'ShearY': _shear_level_to_arg, + 'TranslateX': _translate_abs_level_to_arg, + 'TranslateY': _translate_abs_level_to_arg, + 'TranslateXRel': _translate_rel_level_to_arg, + 'TranslateYRel': _translate_rel_level_to_arg, +} + + +NAME_TO_OP = { + 'AutoContrast': auto_contrast, + 'Equalize': equalize, + 'Invert': invert, + 'Identity': identity, + 'Rotate': rotate, + 'PosterizeOriginal': posterize, + 'PosterizeResearch': posterize, + 'PosterizeTpu': posterize, + 'Solarize': solarize, + 'SolarizeAdd': solarize_add, + 'Color': color, + 'Contrast': contrast, + 'Brightness': brightness, + 'Sharpness': sharpness, + 'ShearX': shear_x, + 'ShearY': shear_y, + 'TranslateX': translate_x_abs, + 'TranslateY': translate_y_abs, + 'TranslateXRel': translate_x_rel, + 'TranslateYRel': translate_y_rel, +} + + +class AutoAugmentOp: + + def __init__(self, name, prob=0.5, magnitude=10, hparams=None): + hparams = hparams or _HPARAMS_DEFAULT + self.aug_fn = NAME_TO_OP[name] + self.level_fn = LEVEL_TO_ARG[name] + self.prob = prob + self.magnitude = magnitude + self.hparams = hparams.copy() + self.kwargs = dict( + fillcolor=hparams['img_mean'] if 'img_mean' in hparams else _FILL, + resample=hparams['interpolation'] if 'interpolation' in hparams else _RANDOM_INTERPOLATION, + ) + + # If 
magnitude_std is > 0, we introduce some randomness + # in the usually fixed policy and sample magnitude from a normal distribution + # with mean `magnitude` and std-dev of `magnitude_std`. + # NOTE This is my own hack, being tested, not in papers or reference impls. + self.magnitude_std = self.hparams.get('magnitude_std', 0) + + def __call__(self, img): + if random.random() > self.prob: + return img + magnitude = self.magnitude + if self.magnitude_std and self.magnitude_std > 0: + magnitude = random.gauss(magnitude, self.magnitude_std) + magnitude = min(_MAX_LEVEL, max(0, magnitude)) # clip to valid range + level_args = self.level_fn(magnitude, self.hparams) if self.level_fn is not None else tuple() + return self.aug_fn(img, *level_args, **self.kwargs) + + +_RAND_TRANSFORMS = [ + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'PosterizeTpu', + 'Solarize', + 'SolarizeAdd', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # FIXME I implement this as random erasing separately +] + +_RAND_TRANSFORMS_CMC = [ + 'AutoContrast', + 'Identity', + 'Rotate', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateXRel', + 'TranslateYRel', + #'Cutout' # FIXME I implement this as random erasing separately +] + + +# These experimental weights are based loosely on the relative improvements mentioned in paper. +# They may not result in increased performance, but could likely be tuned to so. +_RAND_CHOICE_WEIGHTS_0 = { + 'Rotate': 0.3, + 'ShearX': 0.2, + 'ShearY': 0.2, + 'TranslateXRel': 0.1, + 'TranslateYRel': 0.1, + 'Color': .025, + 'Sharpness': 0.025, + 'AutoContrast': 0.025, + 'Solarize': .005, + 'SolarizeAdd': .005, + 'Contrast': .005, + 'Brightness': .005, + 'Equalize': .005, + 'PosterizeTpu': 0, + 'Invert': 0, +} + + +def _select_rand_weights(weight_idx=0, transforms=None): + transforms = transforms or _RAND_TRANSFORMS + assert weight_idx == 0 # only one set of weights currently + rand_weights = _RAND_CHOICE_WEIGHTS_0 + probs = [rand_weights[k] for k in transforms] + probs /= np.sum(probs) + return probs + + +def rand_augment_ops(magnitude=10, hparams=None, transforms=None): + """rand augment ops for RGB images""" + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS + return [AutoAugmentOp( + name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + + +def rand_augment_ops_cmc(magnitude=10, hparams=None, transforms=None): + """rand augment ops for CMC images (removing color ops)""" + hparams = hparams or _HPARAMS_DEFAULT + transforms = transforms or _RAND_TRANSFORMS_CMC + return [AutoAugmentOp( + name, prob=0.5, magnitude=magnitude, hparams=hparams) for name in transforms] + + +class RandAugment: + def __init__(self, ops, num_layers=2, choice_weights=None): + self.ops = ops + self.num_layers = num_layers + self.choice_weights = choice_weights + + def __call__(self, img): + # no replacement when using weighted choice + ops = np.random.choice( + self.ops, self.num_layers, replace=self.choice_weights is None, p=self.choice_weights) + for op in ops: + img = op(img) + return img + + +def rand_augment_transform(config_str, hparams, use_cmc=False): + """ + Create a RandAugment transform + :param config_str: String defining configuration of random augmentation. Consists of multiple sections separated by + dashes ('-'). The first section defines the specific variant of rand augment (currently only 'rand'). 
The remaining
+        sections, not order specific, determine
+            'm' - integer magnitude of rand augment
+            'n' - integer num layers (number of transform ops selected per image)
+            'w' - integer probability weight index (index of a set of weights to influence choice of op)
+            'mstd' - float std deviation of magnitude noise applied
+        Ex 'rand-m9-n3-mstd0.5' results in RandAugment with magnitude 9, num_layers 3, magnitude_std 0.5
+        'rand-mstd1-w0' results in magnitude_std 1.0, weights 0, default magnitude of 10 and num_layers 2
+    :param hparams: Other hparams (kwargs) for the RandAugmentation scheme
+    :param use_cmc: Flag indicating whether to remove the color-related ops (for CMC-style inputs).
+    :return: A RandAugment callable that is applied to PIL images
+    """
+    magnitude = _MAX_LEVEL  # default to _MAX_LEVEL for magnitude (currently 10)
+    num_layers = 2  # default to 2 ops per image
+    weight_idx = None  # default to no probability weights for op choice
+    config = config_str.split('-')
+    assert config[0] == 'rand'
+    config = config[1:]
+    for c in config:
+        cs = re.split(r'(\d.*)', c)
+        if len(cs) < 2:
+            continue
+        key, val = cs[:2]
+        if key == 'mstd':
+            # noise param injected via hparams for now
+            hparams.setdefault('magnitude_std', float(val))
+        elif key == 'm':
+            magnitude = int(val)
+        elif key == 'n':
+            num_layers = int(val)
+        elif key == 'w':
+            weight_idx = int(val)
+        else:
+            assert False, 'Unknown RandAugment config section'
+    if use_cmc:
+        ra_ops = rand_augment_ops_cmc(magnitude=magnitude, hparams=hparams)
+    else:
+        ra_ops = rand_augment_ops(magnitude=magnitude, hparams=hparams)
+    choice_weights = None if weight_idx is None else _select_rand_weights(weight_idx)
+    return RandAugment(ra_ops, num_layers, choice_weights=choice_weights)
\ No newline at end of file
diff --git a/research/cv/unisiam/util.py b/research/cv/unisiam/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e1ebd630415f4885166949bf38557a2090f79ff
--- /dev/null
+++ b/research/cv/unisiam/util.py
@@ -0,0 +1,30 @@
+import math
+
+import mindspore as ms
+from mindspore import ops
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value"""
+    def __init__(self):
+        self.reset()
+
+    def reset(self):
+        self.val = 0
+        self.avg = 0
+        self.sum = 0
+        self.count = 0
+
+    def update(self, val, n=1):
+        self.val = val
+        self.sum += val * n
+        self.count += n
+        self.avg = self.sum / self.count
+
+
+def adjust_learning_rate(args, optimizer, cur_iter, total_iter):
+    """Cosine decay of the learning rate from args.lr to args.lr * 1e-3 as cur_iter goes
+    from 0 to total_iter; train.py calls this per step when --lrd_step is set, otherwise once per epoch."""
+    lr = args.lr
+    eta_min = lr * 1e-3
+    lr = eta_min + (lr - eta_min) * (1 + math.cos(math.pi * cur_iter / total_iter)) / 2
+
+    ops.assign(optimizer.learning_rate, ms.Tensor(lr, ms.float32))
\ No newline at end of file