From 369a4dbca3438c8ea70bf9a365bf24c691d5e91d Mon Sep 17 00:00:00 2001 From: "hongliang.yuan" Date: Fri, 20 Dec 2024 14:20:37 +0800 Subject: [PATCH] reset cv/semantic_segmentation/torchvision/pytorch --- .../pytorch/dataloader/segmentation.py | 39 ++-- .../torchvision/pytorch/train.py | 211 +++++++++++------- .../torchvision/pytorch/utils.py | 37 ++- 3 files changed, 180 insertions(+), 107 deletions(-) diff --git a/cv/semantic_segmentation/torchvision/pytorch/dataloader/segmentation.py b/cv/semantic_segmentation/torchvision/pytorch/dataloader/segmentation.py index 04ffe316..3791b8b6 100644 --- a/cv/semantic_segmentation/torchvision/pytorch/dataloader/segmentation.py +++ b/cv/semantic_segmentation/torchvision/pytorch/dataloader/segmentation.py @@ -1,25 +1,18 @@ -# Copyright (c) 2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. +# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd. # All Rights Reserved. # -# Licensed under the Apache License, Version 2.0 (the "License"); you may -# not use this file except in compliance with the License. You may obtain -# a copy of the License at +# Licensed under the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. You may obtain +# a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# Copyright (c) 2022 Iluvatar CoreX. All rights reserved. -# Copyright Declaration: This software, including all of its code and documentation, -# except for the third-party software it contains, is a copyrighted work of Shanghai Iluvatar CoreX -# Semiconductor Co., Ltd. and its affiliates ("Iluvatar CoreX") in accordance with the PRC Copyright -# Law and relevant international treaties, and all rights contained therein are enjoyed by Iluvatar -# CoreX. No user of this software shall have any right, ownership or interest in this software and -# any use of this software shall be in compliance with the terms and conditions of the End User -# License Agreement. +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+
 
 
 import torchvision
@@ -37,14 +30,12 @@ Examples:
 """
 
 
-def get_transform(train):
-    base_size = 520
-    crop_size = 480
-    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(base_size)
+def get_transform(train, base_size, crop_size):
+    return presets.SegmentationPresetTrain(base_size, crop_size) if train else presets.SegmentationPresetEval(crop_size)
 
 
-def get_dataset(dir_path, name, image_set):
-    transform = get_transform(image_set == 'train')
+def get_dataset(dir_path, name, image_set, base_size=540, crop_size=512):
+    transform = get_transform(image_set == 'train', base_size, crop_size)
     # name = 'camvid'
     def sbd(*args, **kwargs):
         return torchvision.datasets.SBDataset(*args, mode='segmentation', **kwargs)
diff --git a/cv/semantic_segmentation/torchvision/pytorch/train.py b/cv/semantic_segmentation/torchvision/pytorch/train.py
index 371fdad0..9f997a3a 100644
--- a/cv/semantic_segmentation/torchvision/pytorch/train.py
+++ b/cv/semantic_segmentation/torchvision/pytorch/train.py
@@ -1,16 +1,21 @@
-# Copyright (c) 2022-2024, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
-# All Rights Reserved.
 # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
+# Copyright (c) 2022, Shanghai Iluvatar CoreX Semiconductor Co., Ltd.
+# All Rights Reserved.
 
 import datetime
 import os
-import sys
 import time
-import math
+
 import torch
-from torch import nn
 import torch.utils.data
 import torchvision
+from torch import nn
+import torch.nn.functional as TF
+
+try:
+    from apex import amp as apex_amp
+except ImportError:
+    apex_amp = None
 
 import utils
 from dataloader.segmentation import get_dataset
@@ -21,25 +26,16 @@ try:
 except:
     autocast = None
     scaler = None
-import ssl
-ssl._create_default_https_context = ssl._create_unverified_context
-
-import torchvision.models.resnet
-print("WARN: Using pretrained weights from torchvision-0.9.")
-torchvision.models.resnet.model_urls = {
-    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
-    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
-    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
-    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
-    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
-    'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
-    'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
-    'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
-    'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
-}
 
 
 def criterion(inputs, target):
+    if isinstance(inputs, (tuple, list)):
+        inputs = {str(i): x for i, x in enumerate(inputs)}
+        inputs["out"] = inputs.pop("0")
+
+    if not isinstance(inputs, dict):
+        inputs = dict(out=inputs)
+
     losses = {}
     for name, x in inputs.items():
         losses[name] = nn.functional.cross_entropy(x, target, ignore_index=255)
@@ -47,7 +43,8 @@ def criterion(inputs, target):
     if len(losses) == 1:
         return losses['out']
 
-    return losses['out'] + 0.5 * losses['aux']
+    loss = losses.pop("out")
+    return loss + 0.5 * sum(losses.values())
 
 
 def evaluate(model, data_loader, device, num_classes):
@@ -59,7 +56,13 @@
         for image, target in metric_logger.log_every(data_loader, 100, header):
             image, target = image.to(device), target.to(device)
             output = model(image)
-            output = output['out']
+            if isinstance(output, dict):
+                output = output['out']
+            if isinstance(output, (tuple, list)):
+                output = output[0]
+
+            if output.shape[2:] != image.shape[2:]:
+                output = TF.interpolate(output, size=image.shape[2:], mode="bilinear", align_corners=False)
 
             confmat.update(target.flatten(), output.argmax(1).flatten())
 
@@ -68,7 +71,10 @@ def evaluate(model, data_loader, device, num_classes):
     return confmat
 
 
-def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, print_freq, amp=False):
+def train_one_epoch(model, criterion, optimizer,
+                    data_loader, lr_scheduler,
+                    device, epoch, print_freq,
+                    use_amp=False, use_nhwc=False):
     model.train()
     metric_logger = utils.MetricLogger(delimiter="  ")
     metric_logger.add_meter('lr', utils.SmoothedValue(window_size=1, fmt='{value}'))
@@ -78,33 +84,22 @@ def train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, devi
     all_fps = []
     for image, target in metric_logger.log_every(data_loader, print_freq, header):
         start_time = time.time()
-        image, target = image.to(device), target.to(device)
+        image, target = image.to(device, non_blocking=True), target.to(device, non_blocking=True)
 
-        if autocast is None or not amp:
-            output = model(image)
-            loss = criterion(output, target)
-        else:
-            with autocast():
-                output = model(image)
-                loss = criterion(output, target)
+        output = model(image)
+        loss = criterion(output, target)
 
-        optimizer.zero_grad()
-        if scaler is not None and amp:
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
+        if use_amp:
+            with apex_amp.scale_loss(loss, optimizer) as scaled_loss:
+                scaled_loss.backward()
         else:
             loss.backward()
-            optimizer.step()
 
-        torch.cuda.synchronize()
-        end_time = time.time()
+        optimizer.step()
+        optimizer.zero_grad()
+        end_time = time.time()
 
         lr_scheduler.step()
 
-        loss_value = loss.item()
-        if not math.isfinite(loss_value):
-            print("Loss is {}, stopping training".format(loss_value))
-            sys.exit(1)
 
         metric_logger.update(loss=loss.item(), lr=optimizer.param_groups[0]["lr"])
         fps = image.shape[0] / (end_time - start_time) * utils.get_world_size()
@@ -113,6 +108,7 @@
 
     print(header, 'Avg img/s:', sum(all_fps) / len(all_fps))
 
 
+
 def main(args):
     if args.output_dir:
         utils.mkdir(args.output_dir)
@@ -124,8 +120,16 @@
 
     torch.backends.cudnn.benchmark = True
 
-    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train")
-    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val")
+    dataset, num_classes = get_dataset(args.data_path, args.dataset, "train",
+                                       crop_size=args.crop_size, base_size=args.base_size)
+    dataset_test, _ = get_dataset(args.data_path, args.dataset, "val",
+                                  crop_size=args.crop_size, base_size=args.base_size)
+    args.num_classes = num_classes
+
+    if args.nhwc:
+        collate_fn = utils.nhwc_collate_fn(fp16=args.amp, padding_channel=args.padding_channel)
+    else:
+        collate_fn = utils.collate_fn
 
     if args.distributed:
         train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
@@ -137,36 +141,52 @@ def main(args):
     data_loader = torch.utils.data.DataLoader(
         dataset, batch_size=args.batch_size,
         sampler=train_sampler, num_workers=args.workers,
-        collate_fn=utils.collate_fn, drop_last=True)
+        collate_fn=collate_fn, drop_last=True)
 
     data_loader_test = torch.utils.data.DataLoader(
         dataset_test, batch_size=1,
         sampler=test_sampler, num_workers=args.workers,
-        collate_fn=utils.collate_fn)
+        collate_fn=collate_fn)
 
-    model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes,
-                                                                 aux_loss=args.aux_loss,
-                                                                 pretrained=args.pretrained)
+    if hasattr(args, "model_cls"):
+        model = args.model_cls(args)
+    else:
+        model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes,
+                                                                     aux_loss=args.aux_loss,
+                                                                     pretrained=args.pretrained)
+    if args.padding_channel:
+        if hasattr(model, "backbone") and hasattr(model.backbone, "conv1"):
+            model.backbone.conv1 = utils.padding_conv_channel_to_4(model.backbone.conv1)
+        else:
+            print("WARN: Cannot pad the first conv layer to 4 input channels.")
     model.to(device)
     if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
 
-    model_without_ddp = model
-    if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
-        model_without_ddp = model.module
+    if args.nhwc:
+        model = model.cuda().to(memory_format=torch.channels_last)
 
     params_to_optimize = [
-        {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
-        {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
+        {"params": [p for p in model.parameters() if p.requires_grad]},
     ]
-    if args.aux_loss:
-        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
-        params_to_optimize.append({"params": params, "lr": args.lr * 10})
+
     optimizer = torch.optim.SGD(
         params_to_optimize,
         lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
 
+    if args.amp:
+        model, optimizer = apex_amp.initialize(model, optimizer, opt_level="O2",
+                                               loss_scale=args.loss_scale,
+                                               master_weights=True)
+
+    model_without_ddp = model
+    if args.distributed:
+        model = torch.nn.parallel.DistributedDataParallel(
+            model, device_ids=[args.gpu],
+            find_unused_parameters=args.find_unused_parameters
+        )
+        model_without_ddp = model.module
+
     lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
         optimizer,
         lambda x: (1 - x / (len(data_loader) * args.epochs)) ** 0.9)
@@ -189,22 +209,24 @@ def main(args):
         epoch_start_time = time.time()
         if args.distributed:
             train_sampler.set_epoch(epoch)
-        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq, args.amp)
+        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq,
+                        args.amp, args.nhwc)
         confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
         print(confmat)
-        checkpoint = {
-            'model': model_without_ddp.state_dict(),
-            'optimizer': optimizer.state_dict(),
-            'lr_scheduler': lr_scheduler.state_dict(),
-            'epoch': epoch,
-            'args': args
-        }
-        utils.save_on_master(
-            checkpoint,
-            os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
-        utils.save_on_master(
-            checkpoint,
-            os.path.join(args.output_dir, 'checkpoint.pth'))
+        if args.output_dir is not None:
+            checkpoint = {
+                'model': model_without_ddp.state_dict(),
+                'optimizer': optimizer.state_dict(),
+                'lr_scheduler': lr_scheduler.state_dict(),
+                'epoch': epoch,
+                'args': args
+            }
+            utils.save_on_master(
+                checkpoint,
+                os.path.join(args.output_dir, 'model_{}.pth'.format(epoch)))
+            utils.save_on_master(
+                checkpoint,
+                os.path.join(args.output_dir, 'checkpoint.pth'))
         epoch_total_time = time.time() - epoch_start_time
         epoch_total_time_str = str(datetime.timedelta(seconds=int(epoch_total_time)))
         print('epoch time {}'.format(epoch_total_time_str))
@@ -220,7 +242,7 @@ def get_args_parser(add_help=True):
 
     parser.add_argument('--data-path', default='/datasets01/COCO/022719/', help='dataset path')
     parser.add_argument('--dataset', default='camvid', help='dataset name')
-    parser.add_argument('--model', default='fcn_resnet101', help='model')
+    parser.add_argument('--model', default='deeplabv3_resnet50', help='model')
     parser.add_argument('--aux-loss', action='store_true', help='auxiliar loss')
     parser.add_argument('--device', default='cuda', help='device')
     parser.add_argument('-b', '--batch-size', default=8, type=int)
@@ -236,7 +258,7 @@ def get_args_parser(add_help=True):
                         metavar='W', help='weight decay (default: 1e-4)',
                         dest='weight_decay')
     parser.add_argument('--print-freq', default=10, type=int, help='print frequency')
-    parser.add_argument('--output-dir', default='.', help='path where to save')
+    parser.add_argument('--output-dir', default=None, help='path where to save')
     parser.add_argument('--resume', default='', help='resume from checkpoint')
     parser.add_argument('--start-epoch', default=0, type=int, metavar='N',
                         help='start epoch')
@@ -253,20 +275,45 @@ def get_args_parser(add_help=True):
         action="store_true",
     )
     # distributed training parameters
-    parser.add_argument('--local_rank', '--local-rank', default=-1, type=int,
+    parser.add_argument('--local_rank', default=-1, type=int,
                         help='Local rank')
     parser.add_argument('--world-size', default=1, type=int,
                         help='number of distributed processes')
     parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training')
     parser.add_argument('--amp', action='store_true', help='Automatic Mixed Precision training')
+    parser.add_argument('--padding-channel', action='store_true', help='Pad the image channels to 4')
+    parser.add_argument('--loss_scale', default="dynamic", type=str)
+    parser.add_argument('--nhwc', action='store_true', help='Use NHWC')
+    parser.add_argument('--find_unused_parameters', action='store_true')
+    parser.add_argument('--crop-size', default=512, type=int)
+    parser.add_argument('--base-size', default=540, type=int)
 
     return parser
 
 
+def check_args(args):
+    try:
+        args.loss_scale = float(args.loss_scale)
+    except ValueError: pass
+
+    if args.padding_channel:
+        if not args.nhwc:
+            print("Enabling NHWC because channel padding is enabled.")
+            args.nhwc = True
+
+    if args.amp:
+        if apex_amp is None:
+            raise RuntimeError("apex is not installed; cannot enable AMP.")
+
+
+def train_model(model_cls=None):
+    args = get_args_parser().parse_args()
+    check_args(args)
+    if model_cls is not None:
+        args.model_cls = model_cls
+    main(args)
+
+
 if __name__ == "__main__":
     args = get_args_parser().parse_args()
-    try:
-        from dltest import show_training_arguments
-        show_training_arguments(args)
-    except:
-        pass
+    check_args(args)
     main(args)
diff --git a/cv/semantic_segmentation/torchvision/pytorch/utils.py b/cv/semantic_segmentation/torchvision/pytorch/utils.py
index 88d59a6a..48b348ba 100644
--- a/cv/semantic_segmentation/torchvision/pytorch/utils.py
+++ b/cv/semantic_segmentation/torchvision/pytorch/utils.py
@@ -69,4 +69,39 @@ def collate_fn(batch):
     images, targets = list(zip(*batch))
     batched_imgs = cat_list(images, fill_value=0)
     batched_targets = cat_list(targets, fill_value=255)
-    return batched_imgs, batched_targets
\ No newline at end of file
+    return batched_imgs, batched_targets
+
+
+def nhwc_collate_fn(fp16=False, padding_channel=False):
+    dtype = torch.float32
+    if fp16:
+        dtype = torch.float16
+    def _collate_fn(batch):
+        batch = collate_fn(batch)
+        if not padding_channel:
+            return batch
+        batch = list(batch)
+        image = batch[0]
+        zeros = image.new_zeros(image.shape[0], image.shape[2], image.shape[3], 1)
+        image = torch.cat([image.permute(0, 2, 3, 1), zeros], dim=-1).permute(0, 3, 1, 2)
+        image = image.to(memory_format=torch.channels_last, dtype=dtype)
+        batch[0] = image
+        return batch
+
+    return _collate_fn
+
+
+def padding_conv_channel_to_4(conv: torch.nn.Conv2d):
+    new_conv = torch.nn.Conv2d(
+        4, conv.out_channels,
+        kernel_size=conv.kernel_size,
+        stride=conv.stride,
+        padding=conv.padding,
+        dilation=conv.dilation,
+        bias=conv.bias is not None
+    )
+    weight_shape = conv.weight.shape
+    padding_weight = conv.weight.new_zeros(weight_shape[0], 1, *weight_shape[2:])
+    new_conv.weight = torch.nn.Parameter(torch.cat([conv.weight, padding_weight], dim=1))
+    new_conv.bias = conv.bias
+    return new_conv
-- 
Gitee
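
Note on the channel-padding path added in utils.py: padding_conv_channel_to_4
zero-fills the fourth weight column of the first conv, and nhwc_collate_fn
appends an all-zero fourth channel to each image batch, so the padded model
computes the same outputs as the original 3-channel one. A minimal sketch of
that invariant (assumes the patched utils.py is importable; the rest is plain
PyTorch):

    import torch
    from utils import padding_conv_channel_to_4

    conv3 = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
    conv4 = padding_conv_channel_to_4(conv3)

    x = torch.randn(2, 3, 32, 32)
    # Mirror nhwc_collate_fn: append an all-zero fourth channel.
    x4 = torch.cat([x, x.new_zeros(2, 1, 32, 32)], dim=1)

    # The padded weight column is zero, so both convs agree.
    assert torch.allclose(conv3(x), conv4(x4))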
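The reworked criterion in train.py accepts a dict, a tuple/list, or a bare
tensor: tuple element 0 is remapped to the 'out' key, and every remaining
head is added at weight 0.5 (the old code hard-coded a single 'aux' key).
A hypothetical check of that equivalence, assuming the patched train.py and
its dependencies are on PYTHONPATH:

    import torch
    from train import criterion

    target = torch.randint(0, 12, (2, 64, 64))
    main_out = torch.randn(2, 12, 64, 64)
    aux_out = torch.randn(2, 12, 64, 64)

    # A dict with named heads and a tuple with the main head first
    # produce the same loss.
    loss_dict = criterion({"out": main_out, "aux": aux_out}, target)
    loss_tuple = criterion((main_out, aux_out), target)
    assert torch.allclose(loss_dict, loss_tuple)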
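The new train_model() entry point exposes the model_cls hook: main() stores
num_classes on args before the factory is called, so a custom builder can
bypass the torchvision name lookup. A hypothetical driver script
(build_model is illustrative, not part of the patch):

    import torchvision
    from train import train_model

    def build_model(args):
        # args.num_classes is filled in by main() from the chosen dataset.
        return torchvision.models.segmentation.fcn_resnet50(
            num_classes=args.num_classes, aux_loss=args.aux_loss)

    if __name__ == "__main__":
        train_model(model_cls=build_model)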