From a2b595da8573ff4e0020dd5a5d403946b95b9b1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E5=AE=8F=E7=AC=8B?= Date: Tue, 10 Sep 2024 16:47:23 +0800 Subject: [PATCH 1/2] =?UTF-8?q?=E3=80=90PyTorch=E3=80=91=E3=80=90contrib?= =?UTF-8?q?=E3=80=91=E3=80=90FCOS=E3=80=91=E6=80=A7=E8=83=BD=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cv/detection/FCOS/1.11_requirements.txt | 2 + .../cv/detection/FCOS/2.1_requirements.txt | 2 + PyTorch/contrib/cv/detection/FCOS/README.md | 76 +++++++++++++------ .../fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py | 13 ++-- .../FCOS/mmcv_need/epoch_based_runner.py | 2 +- .../cv/detection/FCOS/mmdet/apis/train.py | 7 +- .../detection/FCOS/mmdet/datasets/builder.py | 34 ++++++++- .../mmdet/models/dense_heads/fcos_head.py | 14 +++- .../FCOS/mmdet/models/losses/focal_loss.py | 4 +- .../cv/detection/FCOS/requirements/build.txt | 2 +- .../detection/FCOS/requirements/optional.txt | 2 +- .../contrib/cv/detection/FCOS/test/env_npu.sh | 2 + .../cv/detection/FCOS/test/train_full_1p.sh | 6 +- .../cv/detection/FCOS/test/train_full_8p.sh | 6 +- .../FCOS/test/train_performance_1p.sh | 7 +- .../FCOS/test/train_performance_4p.sh | 6 +- .../FCOS/test/train_performance_8p.sh | 6 +- .../contrib/cv/detection/FCOS/tools/train.py | 4 + 18 files changed, 144 insertions(+), 51 deletions(-) diff --git a/PyTorch/contrib/cv/detection/FCOS/1.11_requirements.txt b/PyTorch/contrib/cv/detection/FCOS/1.11_requirements.txt index dd21091f33..16926c4365 100644 --- a/PyTorch/contrib/cv/detection/FCOS/1.11_requirements.txt +++ b/PyTorch/contrib/cv/detection/FCOS/1.11_requirements.txt @@ -1,5 +1,7 @@ torchvision==0.12.0 Pillow==9.1.0 +seaborn==0.11.0 +pycocotools==2.0.7 -r requirements/build.txt -r requirements/optional.txt -r requirements/runtime.txt diff --git a/PyTorch/contrib/cv/detection/FCOS/2.1_requirements.txt b/PyTorch/contrib/cv/detection/FCOS/2.1_requirements.txt index 4093b09702..647932c916 100644 --- a/PyTorch/contrib/cv/detection/FCOS/2.1_requirements.txt +++ b/PyTorch/contrib/cv/detection/FCOS/2.1_requirements.txt @@ -1,5 +1,7 @@ torchvision==0.16.0 Pillow==9.1.0 +seaborn==0.11.0 +pycocotools==2.0.7 -r requirements/build.txt -r requirements/optional.txt -r requirements/runtime.txt diff --git a/PyTorch/contrib/cv/detection/FCOS/README.md b/PyTorch/contrib/cv/detection/FCOS/README.md index 1547a894ef..2e021bca4b 100644 --- a/PyTorch/contrib/cv/detection/FCOS/README.md +++ b/PyTorch/contrib/cv/detection/FCOS/README.md @@ -55,6 +55,8 @@ FCOS是一个全卷积的one-stage目标检测模型,相比其他目标检测 在模型源码包根目录下执行命令,安装模型对应PyTorch版本需要的依赖。 ``` + pip install cython==0.29.33 # 前置 + pip install -r 1.11_requirements.txt # PyTorch1.11版本 pip install -r 2.1_requirements.txt # PyTorch2.1版本 @@ -152,13 +154,7 @@ FCOS是一个全卷积的one-stage目标检测模型,相比其他目标检测 ## 训练模型 -1. 进入解压后的源码包根目录。 - - ``` - cd /${模型文件夹名称} - ``` - -2. 运行训练脚本。 +1. 运行训练脚本。 该模型支持单机单卡训练和单机8卡训练。 @@ -167,8 +163,8 @@ FCOS是一个全卷积的one-stage目标检测模型,相比其他目标检测 启动单卡训练。 ``` - bash ./test/train_full_1p.sh --data_path=/data/xxx/ # 精度训练 - bash ./test/train_performance_1p.sh --data_path=/data/xxx/ --data_shuffle=False # 性能训练 + bash ./test/train_full_1p.sh --data_path=/data/xxx/ --batch_size=4 --total_epochs=1 # 精度训练 + bash ./test/train_performance_1p.sh --data_path=/data/xxx/ --batch_size=4 --total_epochs=1 # 性能训练 ``` - 单机8卡训练 @@ -176,17 +172,16 @@ FCOS是一个全卷积的one-stage目标检测模型,相比其他目标检测 启动8卡训练。 ``` - bash ./test/train_full_8p.sh --data_path=/data/xxx/ # 精度训练 - bash ./test/train_performance_8p.sh --data_path=/data/xxx/ --data_shuffle=False # 性能训练 + bash ./test/train_full_8p.sh --data_path=/data/xxx/ --batch_size=4 --total_epochs=1 # 精度训练 + bash ./test/train_performance_8p.sh --data_path=/data/xxx/ --batch_size=4 --total_epochs=1 # 性能训练 ``` - - 多机多卡训练 多机多卡性能数据获取流程,在每个节点上执行: ``` - bash ./test/train_performance_multinodes.sh --data_path=数据集路径 --batch_size=单卡batch_size --nnodes=机器总数量 --node_rank=当前机器rank(0,1,2..) --local_addr=当前机器IP(需要和master_addr处于同一网段) --master_addr=主节点IP --data_shuffle=False + bash ./test/train_performance_multinodes.sh --data_path=数据集路径 --batch_size=单卡batch_size --nnodes=机器总数量 --node_rank=当前机器rank(0,1,2..) --local_addr=当前机器IP(需要和master_addr处于同一网段) --master_addr=主节点IP --data_shuffle=True ``` --data_path参数填写数据集路径,需写到数据集的一级目录。 @@ -195,9 +190,10 @@ FCOS是一个全卷积的one-stage目标检测模型,相比其他目标检测 ``` --data_path //数据集路径 - --device_id //npu卡号 - --batch-size //训练批次大小 - --data_shuffle //控制shuffle开关 + --device_id //npu卡号 + --batch-size //默认2,训练批次大小 + --data_shuffle //默认True,控制shuffle开关 + --total_epochs //默认1,训练次数 ``` 训练完成后,权重文件保存在当前路径下,并输出模型训练精度和性能信息。 @@ -206,20 +202,52 @@ FCOS是一个全卷积的one-stage目标检测模型,相比其他目标检测 **表 2** 训练结果展示表 - -| NAME | CPU_Type | Acc@1 | FPS | Epochs | AMP_Type | Loss_Scale | Torch_Version | +| NAME | FPS | Epochs | AMP_Type | Loss_Scale | Torch_Version | |:------:|:--------:|:-----:|:-----:|:------:| :------: |:----------:|:----------:| -| 1p-竞品V | X86 | 12.6 | 19.2 | 1 | O1 | dynamic | 1.5 | -| 8p-竞品V | X86 | 36.2 | 102.0 | 12 | O1 | dynamic | 1.5 | -| 1p-Npu | 非ARM | 16.4 | 3.19 | 1 | O1 | 32.0 | 1.8 | -| 8p-Npu | 非ARM | 36.2 | 44.81 | 12 | O1 | 32.0 | 1.8 | -| 8p-Npu | ARM | 36.2 | 35.69 | 12 | O1 | 32.0 | 1.8 | - +| 1p-竞品V | 19.2 | 1 | O1 | dynamic | 1.5 | +| 8p-竞品V | 102.0 | 12 | O1 | dynamic | 1.5 | +| 8p-竞品A | 197 | 12 | O1 | dynamic | 1.11 | +| 8p-Atlas 800T A2 | 197 | 12 | O1 | 32.0 | 1.11 | + +**表 3** 8p-竞品A 12 epochs 训练精度数据 + +``` +Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.354 +Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.551 +Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.376 +Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.206 +Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.389 +Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.452 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.527 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.527 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.527 +Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.341 +Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.575 +Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.672 +``` + +**表 4** 8p-Atlas 800T A2 12 epochs 训练精度数据 + +``` +Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.348 +Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=1000 ] = 0.534 +Average Precision (AP) @[ IoU=0.75 | area= all | maxDets=1000 ] = 0.367 +Average Precision (AP) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.192 +Average Precision (AP) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.383 +Average Precision (AP) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.450 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.512 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=300 ] = 0.512 +Average Recall (AR) @[ IoU=0.50:0.95 | area= all | maxDets=1000 ] = 0.512 +Average Recall (AR) @[ IoU=0.50:0.95 | area= small | maxDets=1000 ] = 0.313 +Average Recall (AR) @[ IoU=0.50:0.95 | area=medium | maxDets=1000 ] = 0.558 +Average Recall (AR) @[ IoU=0.50:0.95 | area= large | maxDets=1000 ] = 0.658 +``` # 版本说明 ## 变更 +2024.09.10: 优化性能。 2022.12.21:更新Readme。 diff --git a/PyTorch/contrib/cv/detection/FCOS/configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py b/PyTorch/contrib/cv/detection/FCOS/configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py index 62cedc7fe2..28fedd2755 100644 --- a/PyTorch/contrib/cv/detection/FCOS/configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py +++ b/PyTorch/contrib/cv/detection/FCOS/configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py @@ -98,7 +98,7 @@ train_pipeline = [ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True), dict(type='RandomFlip', flip_ratio=0.5), dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=1344), # change 32 to 1344 + dict(type='Pad', size_divisor=32), # old is 1344 dict(type='DefaultFormatBundle'), dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']), ] @@ -112,20 +112,20 @@ test_pipeline = [ dict(type='Resize', keep_ratio=True), dict(type='RandomFlip'), dict(type='Normalize', **img_norm_cfg), - dict(type='Pad', size_divisor=1344), # change 32 toto 1344 + dict(type='Pad', size_divisor=32), # old is 1344 dict(type='ImageToTensor', keys=['img']), dict(type='Collect', keys=['img']), ]) ] data = dict( - samples_per_gpu=2, # change 4 to 2 + samples_per_gpu=4, workers_per_gpu=4, train=dict(pipeline=train_pipeline), val=dict(pipeline=test_pipeline), test=dict(pipeline=test_pipeline)) # optimizer optimizer = dict( - lr=0.01, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.)) + type='SGD', lr=0.01, paramwise_cfg=dict(bias_lr_mult=2., bias_decay_mult=0.)) optimizer_config = dict( _delete_=True, grad_clip=dict(max_norm=35, norm_type=2)) # learning policy @@ -135,14 +135,13 @@ lr_config = dict( warmup_iters=500, warmup_ratio=1.0 / 3, step=[8, 11]) -total_epochs = 12 +total_epochs = 1 max_step = None # add for print log log_config = dict( - interval=1, + interval=500, hooks=[ dict(type='TextLoggerHook'), - # dict(type='TensorboardLoggerHook') ]) amp = True # add for apex diff --git a/PyTorch/contrib/cv/detection/FCOS/mmcv_need/epoch_based_runner.py b/PyTorch/contrib/cv/detection/FCOS/mmcv_need/epoch_based_runner.py index 4f1b66c24d..85832d6fb4 100644 --- a/PyTorch/contrib/cv/detection/FCOS/mmcv_need/epoch_based_runner.py +++ b/PyTorch/contrib/cv/detection/FCOS/mmcv_need/epoch_based_runner.py @@ -102,7 +102,7 @@ class EpochBasedRunner(BaseRunner): profile.end() self._iter += 1 # added by jyl - self.logger.info('FPS: ' + str(self.samples_per_gpu * self.num_of_gpus / self.iter_timer_hook.time_all * (self._max_iters - 5))) + self.logger.info('FPS: ' + str(self.samples_per_gpu * self.num_of_gpus * self._inner_iter / self.iter_timer_hook.time_all)) self.call_hook('after_train_epoch') self._epoch += 1 diff --git a/PyTorch/contrib/cv/detection/FCOS/mmdet/apis/train.py b/PyTorch/contrib/cv/detection/FCOS/mmdet/apis/train.py index a9992acbe1..52649f30af 100644 --- a/PyTorch/contrib/cv/detection/FCOS/mmdet/apis/train.py +++ b/PyTorch/contrib/cv/detection/FCOS/mmdet/apis/train.py @@ -97,12 +97,14 @@ def train_detector(model, len(cfg.npu_ids), dist=distributed, shuffle=cfg.data.shuffle, - seed=cfg.seed) for ds in dataset + seed=cfg.seed, + persistent_workers=True, + ) for ds in dataset ] # add apex optimizer = build_optimizer(model.npu(), cfg.optimizer) - model, optimizer = amp.initialize(model.npu(), optimizer, opt_level=cfg.opt_level, loss_scale=cfg.loss_scale) + model, optimizer = amp.initialize(model.npu(), optimizer, opt_level=cfg.opt_level, loss_scale=cfg.loss_scale, combine_grad=True) # put model on npus if distributed: @@ -128,6 +130,7 @@ def train_detector(model, logger=logger, meta=meta, max_iters=cfg.max_step, + samples_per_gpu=cfg.data.samples_per_gpu, num_of_gpus=world_size) # an ugly workaround to make .log and .log.json filenames the same runner.timestamp = timestamp diff --git a/PyTorch/contrib/cv/detection/FCOS/mmdet/datasets/builder.py b/PyTorch/contrib/cv/detection/FCOS/mmdet/datasets/builder.py index 16d9ae34b8..fc19b9e71a 100644 --- a/PyTorch/contrib/cv/detection/FCOS/mmdet/datasets/builder.py +++ b/PyTorch/contrib/cv/detection/FCOS/mmdet/datasets/builder.py @@ -1,3 +1,35 @@ +# BSD 3-Clause License +# +# Copyright (c) 2024 xxxx +# All rights reserved. +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ + import copy import platform import random @@ -128,7 +160,7 @@ def build_dataloader(dataset, sampler=sampler, num_workers=num_workers, collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), - pin_memory=False, + pin_memory=True, worker_init_fn=init_fn, **kwargs) diff --git a/PyTorch/contrib/cv/detection/FCOS/mmdet/models/dense_heads/fcos_head.py b/PyTorch/contrib/cv/detection/FCOS/mmdet/models/dense_heads/fcos_head.py index 343b36059a..9e569c32e3 100644 --- a/PyTorch/contrib/cv/detection/FCOS/mmdet/models/dense_heads/fcos_head.py +++ b/PyTorch/contrib/cv/detection/FCOS/mmdet/models/dense_heads/fcos_head.py @@ -604,13 +604,19 @@ class FCOSHead(AnchorFreeHead): # if there are still more than one objects for a location, # we choose the one with minimal area - areas[inside_gt_bbox_mask == 0] = INF - areas[inside_regress_range == 0] = INF + areas = areas.masked_fill(inside_gt_bbox_mask == 0, INF) + areas = areas.masked_fill(inside_regress_range == 0, INF) min_area, min_area_inds = areas.min(dim=1) labels = gt_labels[min_area_inds] - labels[min_area == INF] = self.num_classes # set as BG - bbox_targets = bbox_targets[range(num_points), min_area_inds] + labels = labels.masked_fill(min_area == INF, self.num_classes) # set as BG + + base_step = bbox_targets.shape[1] + base_end = bbox_targets.shape[0] * bbox_targets.shape[1] + min_area_base = torch.arange(0, base_end, base_step, device=min_area_inds.device) + min_area_inds = min_area_inds + min_area_base + bbox_targets = bbox_targets.view(-1, bbox_targets.shape[2]) + bbox_targets = bbox_targets.index_select(0, min_area_inds) return labels, bbox_targets diff --git a/PyTorch/contrib/cv/detection/FCOS/mmdet/models/losses/focal_loss.py b/PyTorch/contrib/cv/detection/FCOS/mmdet/models/losses/focal_loss.py index 5bcbf80f15..5851de40a7 100644 --- a/PyTorch/contrib/cv/detection/FCOS/mmdet/models/losses/focal_loss.py +++ b/PyTorch/contrib/cv/detection/FCOS/mmdet/models/losses/focal_loss.py @@ -83,8 +83,8 @@ def _sigmoid_focal_loss(pred, pred = torch_npu.npu_format_cast(pred, 0) p = torch.sigmoid(pred) - targets_zero = torch.zeros(p.shape[0], p.shape[1] + 1).int().npu() - target = targets_zero.scatter_(1, target.unsqueeze(1), 1).float()[:,:80].npu() + targets_zero = torch.zeros([p.shape[0], p.shape[1] + 1], dtype=torch.float, device='npu') + target = targets_zero.scatter_(1, target.unsqueeze(1), 1)[:,:80] ce_loss = F.binary_cross_entropy_with_logits(pred, target, reduction="none") p_t = p * target + (1 - p) * (1 - target) diff --git a/PyTorch/contrib/cv/detection/FCOS/requirements/build.txt b/PyTorch/contrib/cv/detection/FCOS/requirements/build.txt index 8155829859..2ee9fc17c5 100644 --- a/PyTorch/contrib/cv/detection/FCOS/requirements/build.txt +++ b/PyTorch/contrib/cv/detection/FCOS/requirements/build.txt @@ -1,3 +1,3 @@ # These must be installed before building mmdetection -cython +cython==0.29.33 numpy diff --git a/PyTorch/contrib/cv/detection/FCOS/requirements/optional.txt b/PyTorch/contrib/cv/detection/FCOS/requirements/optional.txt index 6654b5b96f..2f349474e7 100644 --- a/PyTorch/contrib/cv/detection/FCOS/requirements/optional.txt +++ b/PyTorch/contrib/cv/detection/FCOS/requirements/optional.txt @@ -1,4 +1,4 @@ -albumentations>=0.3.2 +albumentations<1.4.0 cityscapesscripts imagecorruptions mmlvis diff --git a/PyTorch/contrib/cv/detection/FCOS/test/env_npu.sh b/PyTorch/contrib/cv/detection/FCOS/test/env_npu.sh index c6d6103967..7f8dc6aaee 100644 --- a/PyTorch/contrib/cv/detection/FCOS/test/env_npu.sh +++ b/PyTorch/contrib/cv/detection/FCOS/test/env_npu.sh @@ -21,6 +21,8 @@ export ASCEND_GLOBAL_LOG_LEVEL=3 export ASCEND_GLOBAL_EVENT_ENABLE=0 #设置是否开启taskque,0-关闭/1-开启 export TASK_QUEUE_ENABLE=1 +#读取队列激活 +export READ_QUEUE_ACTIVATE #设置是否开启PTCopy,0-关闭/1-开启 export PTCOPY_ENABLE=1 #设置是否开启combined标志,0-关闭/1-开启 diff --git a/PyTorch/contrib/cv/detection/FCOS/test/train_full_1p.sh b/PyTorch/contrib/cv/detection/FCOS/test/train_full_1p.sh index 05eefed716..bb5741c8da 100644 --- a/PyTorch/contrib/cv/detection/FCOS/test/train_full_1p.sh +++ b/PyTorch/contrib/cv/detection/FCOS/test/train_full_1p.sh @@ -24,7 +24,7 @@ Network="FCOS" #训练batch_size,,需要模型审视修改 batch_size=2 device_id=0 - +total_epochs=1 #参数校验,不需要修改 for para in $* do @@ -32,6 +32,8 @@ do device_id=`echo ${para#*=}` elif [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --total_epochs* ]];then + total_epochs=`echo ${para#*=}` fi done @@ -76,7 +78,7 @@ fi #执行训练脚本,以下传参不需要修改,其他需要模型审视修改 PORT=29880 ./tools/dist_train.sh ./configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py 1 \ --npu-ids 0 \ - --cfg-options optimizer.lr=0.00125 data.samples_per_gpu=16 data_root=$data_path \ + --cfg-options optimizer.lr=0.00125 data.samples_per_gpu=16 total_epochs=$total_epochs data_root=$data_path \ --seed 0 \ --opt-level O1 \ --loss-scale 32.0 > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & diff --git a/PyTorch/contrib/cv/detection/FCOS/test/train_full_8p.sh b/PyTorch/contrib/cv/detection/FCOS/test/train_full_8p.sh index c7f1332c65..e65bc7def9 100644 --- a/PyTorch/contrib/cv/detection/FCOS/test/train_full_8p.sh +++ b/PyTorch/contrib/cv/detection/FCOS/test/train_full_8p.sh @@ -25,7 +25,7 @@ Network="FCOS" #训练batch_size,,需要模型审视修改 batch_size=16 device_id=0 - +total_epochs=1 #参数校验,不需要修改 for para in $* do @@ -33,6 +33,8 @@ do device_id=`echo ${para#*=}` elif [[ $para == --data_path* ]];then data_path=`echo ${para#*=}` + elif [[ $para == --total_epochs* ]];then + total_epochs=`echo ${para#*=}` fi done @@ -76,7 +78,7 @@ fi #执行训练脚本,以下传参不需要修改,其他需要模型审视修改 PORT=29888 ./tools/dist_train.sh ./configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py 8 \ --npu-ids 0 \ - --cfg-options optimizer.lr=0.01 data_root=$data_path \ + --cfg-options optimizer.lr=0.01 total_epochs=$total_epochs data_root=$data_path \ --seed 0 \ --opt-level O1 \ --loss-scale 32.0 > ${test_path_dir}/output/${ASCEND_DEVICE_ID}/train_${ASCEND_DEVICE_ID}.log 2>&1 & diff --git a/PyTorch/contrib/cv/detection/FCOS/test/train_performance_1p.sh b/PyTorch/contrib/cv/detection/FCOS/test/train_performance_1p.sh index ae3a009398..aee04d152f 100644 --- a/PyTorch/contrib/cv/detection/FCOS/test/train_performance_1p.sh +++ b/PyTorch/contrib/cv/detection/FCOS/test/train_performance_1p.sh @@ -25,7 +25,8 @@ Network="FCOS" #训练batch_size,,需要模型审视修改 batch_size=2 device_id=0 - +data_shuffle=True +total_epochs=1 #参数校验,不需要修改 for para in $* do @@ -37,6 +38,8 @@ do batch_size=`echo ${para#*=}` elif [[ $para == --data_shuffle* ]];then data_shuffle=`echo ${para#*=}` + elif [[ $para == --total_epochs* ]];then + total_epochs=`echo ${para#*=}` fi done @@ -82,7 +85,7 @@ fi #执行训练脚本,以下传参不需要修改,其他需要模型审视修改 PORT=29880 ./tools/dist_train.sh ./configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py 1 \ --npu-ids ${device_id} \ - --cfg-options optimizer.lr=0.00125 data.samples_per_gpu=${batch_size} total_epochs=1 max_step=2000 data_root=$data_path \ + --cfg-options optimizer.lr=0.00125 data.samples_per_gpu=${batch_size} total_epochs=$total_epochs data_root=$data_path \ --seed 0 \ --data-shuffle ${data_shuffle} \ --opt-level O1 \ diff --git a/PyTorch/contrib/cv/detection/FCOS/test/train_performance_4p.sh b/PyTorch/contrib/cv/detection/FCOS/test/train_performance_4p.sh index 253dfa7609..224664589c 100644 --- a/PyTorch/contrib/cv/detection/FCOS/test/train_performance_4p.sh +++ b/PyTorch/contrib/cv/detection/FCOS/test/train_performance_4p.sh @@ -25,6 +25,8 @@ Network="FCOS" #训练batch_size,,需要模型审视修改 batch_size=2 device_id=0 +data_shuffle=True +total_epochs=1 #参数校验,不需要修改 for para in $* do @@ -36,6 +38,8 @@ do batch_size=`echo ${para#*=}` elif [[ $para == --data_shuffle* ]];then data_shuffle=`echo ${para#*=}` + elif [[ $para == --total_epochs* ]];then + total_epochs=`echo ${para#*=}` fi done @@ -81,7 +85,7 @@ fi #执行训练脚本,以下传参不需要修改,其他需要模型审视修改 PORT=29888 ./tools/dist_train.sh ./configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py 4 \ --npu-ids 0 \ - --cfg-options optimizer.lr=0.01 data.samples_per_gpu=${batch_size} total_epochs=1 data_root=$data_path \ + --cfg-options optimizer.lr=0.01 data.samples_per_gpu=${batch_size} total_epochs=$total_epochs data_root=$data_path \ --seed 0 \ --data-shuffle ${data_shuffle} \ --opt-level O1 \ diff --git a/PyTorch/contrib/cv/detection/FCOS/test/train_performance_8p.sh b/PyTorch/contrib/cv/detection/FCOS/test/train_performance_8p.sh index e06f562e1d..8c6e67d5ef 100644 --- a/PyTorch/contrib/cv/detection/FCOS/test/train_performance_8p.sh +++ b/PyTorch/contrib/cv/detection/FCOS/test/train_performance_8p.sh @@ -25,6 +25,8 @@ Network="FCOS" #训练batch_size,,需要模型审视修改 batch_size=2 device_id=0 +data_shuffle=True +total_epochs=1 #参数校验,不需要修改 for para in $* do @@ -36,6 +38,8 @@ do batch_size=`echo ${para#*=}` elif [[ $para == --data_shuffle* ]];then data_shuffle=`echo ${para#*=}` + elif [[ $para == --total_epochs* ]];then + total_epochs=`echo ${para#*=}` fi done @@ -81,7 +85,7 @@ fi #执行训练脚本,以下传参不需要修改,其他需要模型审视修改 PORT=29888 ./tools/dist_train.sh ./configs/fcos/fcos_r50_caffe_fpn_4x4_1x_coco.py 8 \ --npu-ids 0 \ - --cfg-options data.samples_per_gpu=${batch_size} optimizer.lr=0.01 total_epochs=1 data_root=$data_path \ + --cfg-options optimizer.lr=0.01 data.samples_per_gpu=${batch_size} total_epochs=$total_epochs data_root=$data_path \ --seed 0 \ --data-shuffle ${data_shuffle} \ --opt-level O1 \ diff --git a/PyTorch/contrib/cv/detection/FCOS/tools/train.py b/PyTorch/contrib/cv/detection/FCOS/tools/train.py index b4e942a818..fe82c0dfde 100644 --- a/PyTorch/contrib/cv/detection/FCOS/tools/train.py +++ b/PyTorch/contrib/cv/detection/FCOS/tools/train.py @@ -39,6 +39,10 @@ import warnings import ast import mmcv import torch +import torch_npu +from torch_npu.contrib import transfer_to_npu +torch_npu.npu.set_compile_mode(jit_compile=False) + from mmcv import Config, DictAction from mmcv.runner import get_dist_info, init_dist from mmcv.utils import get_git_hash -- Gitee From be7902567631002dfff8740c79a7c1c735fc2cea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=83=A1=E5=AE=8F=E7=AC=8B?= Date: Mon, 4 Nov 2024 11:36:06 +0800 Subject: [PATCH 2/2] =?UTF-8?q?=E3=80=90PyTorch=E3=80=91=E3=80=90contrib?= =?UTF-8?q?=E3=80=91=E3=80=90FCOS=E3=80=91=E6=80=A7=E8=83=BD=E4=BC=98?= =?UTF-8?q?=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../cv/detection/YOLOV9_for_PyTorch/README.md | 45 ++++++++--------- .../cv/detection/YOLOV9_for_PyTorch/train.py | 3 +- .../YOLOV9_for_PyTorch/train_dual.py | 2 +- .../YOLOV9_for_PyTorch/train_triple.py | 3 +- .../YOLOV9_for_PyTorch/utils/dataloaders.py | 3 +- .../YOLOV9_for_PyTorch/utils/tal/assigner.py | 49 ++++++++++++++++--- .../cv/detection/FCOS/mmdet/apis/train.py | 8 ++- .../detection/FCOS/mmdet/datasets/builder.py | 34 +------------ 8 files changed, 76 insertions(+), 71 deletions(-) diff --git a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/README.md b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/README.md index e5756d9013..02b0c10a37 100644 --- a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/README.md +++ b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/README.md @@ -64,19 +64,19 @@ YOLOv9融合了深度学习技术和架构设计的进步,以在对象检测 **表 1** 版本支持表 - | Torch_Version | 三方库依赖版本 | - | :--------: | :----------------------------------------------------------: | - | PyTorch 1.11 | torchvision==0.12.0, torchvision_npu==0.12.0 | - | PyTorch 2.1 | torchvision==0.16.0, torchvision_npu==0.16.0 | + | Torch_Version | 三方库依赖版本 | 说明 | + | :--------: | :--------------------------------------------: | :--------: | + | PyTorch 1.11 | torchvision==0.12.0, torchvision_npu==0.12.0 | 优先 | + | PyTorch 2.1 | torchvision==0.16.0, torchvision_npu==0.16.0, attr, attrs | 对比1.11性能劣化3倍左右 | **表 2** 昇腾软件版本支持表 - | 软件类型 | 支持版本 | - | :----------------: | :----------------------------------------------------------: | - | FrameworkPTAdapter | 6.0.RC1/在研版本 | - | CANN | 8.0.RC1/在研版本 | - | 昇腾NPU固件 | 24.1.RC1/在研版本 | - | 昇腾NPU驱动 | 24.1.RC1/在研版本 | + | 软件类型 | 支持版本 | 说明 | + | :----------------: | :-------------: | :--------: | + | FrameworkPTAdapter | 6.0.RC1/在研版本 | | + | CANN | 8.0.RC2/在研版本 | 8.0.RC1性能略微劣化 | + | 昇腾NPU固件 | 24.1.RC1/在研版本 | | + | 昇腾NPU驱动 | 24.1.RC1/在研版本 | | - 安装依赖。 @@ -86,7 +86,7 @@ YOLOv9融合了深度学习技术和架构设计的进步,以在对象检测 pip install -r 1.11_requirements.txt # pip install -r 2.1_requirements.txt ``` - > **说明:** + > **说明:** > 只需执行一条对应的PyTorch版本依赖安装命令。 - 安装torchvision(可选)。 @@ -101,8 +101,8 @@ YOLOv9融合了深度学习技术和架构设计的进步,以在对象检测 pip uninstall torchvision pip install dist/torchvision*.whl ``` - > **说明:** - > 训练报错且报错信息为torchvision.ops.nms时,需要源码安装torchvision。 + > **说明:** + > 如若训练报错且报错信息为torchvision.ops.nms时,源码安装torchvision解决。 - 安装torchvision_npu。 @@ -177,27 +177,27 @@ YOLOv9融合了深度学习技术和架构设计的进步,以在对象检测 --cfg //训练配置 --name //训练名 --optimizer //优化器类型,可选'SGD'、'Adam'、'AdamW'、'RMSProp' - --patience //默认值100,此epochs之后,若AP值无法提升则提前终止训练 + --patience //默认值100,若AP值无法提升则提前终止训练,受限环境未验证有效性。 ``` - > **说明:** + > **说明:** > 训练完成后,权重文件保存在runs/train,并输出模型训练精度和性能信息。 - > 优化器选'AdamW',AP值会比'SGD'更好。 + > 优化器选AdamW精度更优,需同步修改hyp.scratch-high.yaml的lr0值,SGD=0.01、Adam/AdamW=0.001。 > 第一个epoch会有大量Gradient overflow,已知现象,不必在意。 ## 训练结果展示 **表 3** 训练性能 -| NAME | 卡数 | epochs | Train Times(1 epoch) | Val Times(1 epoch) | Torch_Version | -|:--------:| :---: | :---: | :---: | :---: | :---: | -| 竞品A | 8p | 100 | 05:12 | 00:33 | 2.4 | -| Atlas 800T A2 | 8p | 100 | 07:23 | 01:44 | 1.11 | +| NAME | 卡数 | epochs | batch | FPS(image/s) | Train Times(1 epoch) | Val Times(1 epoch) | Torch_Version | +|:--------:| :---: | :---: | :---: | :---: | :---: | :---: | :---: | +| 竞品A | 8p | 100 | 256 | 380 | 05:12 | 00:33 | 2.4 | +| Atlas 800T A2 | 8p | 100 | 256 | 293 | 06:44 | 01:44 | 1.11 | **表 4** 训练精度 Atlas 800T A2 -yolov9-c 100 epochs训练精度数据 +yolov9-c 100 epochs+SGD训练精度数据 ``` Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.468 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.628 @@ -214,7 +214,7 @@ yolov9-c 100 epochs训练精度数据 ``` 竞品A -yolov9-c 100 epochs训练精度数据 +yolov9-c 100 epochs+SGD训练精度数据 ``` Average Precision (AP) @[ IoU=0.50:0.95 | area= all | maxDets=100 ] = 0.471 Average Precision (AP) @[ IoU=0.50 | area= all | maxDets=100 ] = 0.629 @@ -237,6 +237,7 @@ yolov9-c 100 epochs训练精度数据 # 变更说明 +2024.10.13:性能略微提升。 2024.8.13:首次发布。 # FAQ diff --git a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train.py b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train.py index 9bc13911cd..5a99aef82e 100644 --- a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train.py +++ b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train.py @@ -1,3 +1,4 @@ +# Copyright 2024 Huawei Technologies Co., Ltd import argparse import math import os @@ -524,7 +525,7 @@ def main(opt, callbacks=Callbacks()): assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command' torch.cuda.set_device(LOCAL_RANK) device = torch.device('cuda', LOCAL_RANK) - dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo") + dist.init_process_group(backend="hccl" if dist.is_nccl_available() else "gloo") # Train if not opt.evolve: diff --git a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train_dual.py b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train_dual.py index c4f57dcd55..771e0e0770 100644 --- a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train_dual.py +++ b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train_dual.py @@ -539,7 +539,7 @@ def main(opt, callbacks=Callbacks()): assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command' torch.cuda.set_device(LOCAL_RANK) device = torch.device('cuda', LOCAL_RANK) - dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo") + dist.init_process_group(backend="hccl" if dist.is_nccl_available() else "gloo") # Train if not opt.evolve: diff --git a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train_triple.py b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train_triple.py index 2da45a4632..2b5cadb3f2 100644 --- a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train_triple.py +++ b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/train_triple.py @@ -1,3 +1,4 @@ +# Copyright 2024 Huawei Technologies Co., Ltd import argparse import math import os @@ -526,7 +527,7 @@ def main(opt, callbacks=Callbacks()): assert torch.cuda.device_count() > LOCAL_RANK, 'insufficient CUDA devices for DDP command' torch.cuda.set_device(LOCAL_RANK) device = torch.device('cuda', LOCAL_RANK) - dist.init_process_group(backend="nccl" if dist.is_nccl_available() else "gloo") + dist.init_process_group(backend="hccl" if dist.is_nccl_available() else "gloo") # Train if not opt.evolve: diff --git a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/utils/dataloaders.py b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/utils/dataloaders.py index 001f9fd00b..1ff7001842 100644 --- a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/utils/dataloaders.py +++ b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/utils/dataloaders.py @@ -38,6 +38,7 @@ VID_FORMATS = 'asf', 'avi', 'gif', 'm4v', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 't LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1)) RANK = int(os.getenv('RANK', -1)) PIN_MEMORY = str(os.getenv('PIN_MEMORY', True)).lower() == 'true' # global pin_memory for dataloaders +PERSISTENT_WORKERS = str(os.getenv('PERSISTENT_WORKERS', True)).lower() == 'true' # Get orientation exif tag for orientation in ExifTags.TAGS.keys(): @@ -149,7 +150,7 @@ def create_dataloader(path, collate_fn=LoadImagesAndLabels.collate_fn4 if quad else LoadImagesAndLabels.collate_fn, worker_init_fn=seed_worker, generator=generator, - persistent_workers=True, + persistent_workers=PERSISTENT_WORKERS, ), dataset diff --git a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/utils/tal/assigner.py b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/utils/tal/assigner.py index ddbf1ad977..4061d924e9 100644 --- a/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/utils/tal/assigner.py +++ b/PyTorch/built-in/cv/detection/YOLOV9_for_PyTorch/utils/tal/assigner.py @@ -1,6 +1,39 @@ +# BSD 3-Clause License +# +# Copyright (c) 2024 xxxx +# All rights reserved. +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# * Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# * Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# * Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# ============================================================================ + import torch import torch.nn as nn import torch.nn.functional as F +import torch_npu from utils.metrics import bbox_iou @@ -116,13 +149,15 @@ class TaskAlignedAssigner(nn.Module): return mask_pos, align_metric, overlaps def get_box_metrics(self, pd_scores, pd_bboxes, gt_labels, gt_bboxes): - - gt_labels = gt_labels.to(torch.long) # b, max_num_obj, 1 - ind = torch.zeros([2, self.bs, self.n_max_boxes], dtype=torch.long) # 2, b, max_num_obj - ind[0] = torch.arange(end=self.bs).view(-1, 1).repeat(1, self.n_max_boxes) # b, max_num_obj - ind[1] = gt_labels.squeeze(-1) # b, max_num_obj - # get the scores of each grid for each gt cls - bbox_scores = pd_scores[ind[0], :, ind[1]] # b, max_num_obj, h*w + bs = pd_scores.shape[0] + hw = pd_scores.shape[1] + ind0 = torch.arange(start=0, end=self.bs*self.num_classes, step=self.num_classes, device=gt_labels.device, dtype=torch.int) + ind0 = ind0.view(-1, 1).repeat(1, self.n_max_boxes).view(-1) # b*max_num_obj + ind1 = gt_labels.view(-1).to(torch.int) # b*max_num_obj + inds = ind0 + ind1 # b*max_num_obj + # b, h*w, 80 --> b*80, h*w --> b, max_num_obj, h*w + bbox_scores = torch_npu.npu_confusion_transpose(pd_scores, [0, 2, 1], (bs * self.num_classes, hw), True) + bbox_scores = bbox_scores.index_select(0, inds).view(bs, self.n_max_boxes, hw) overlaps = bbox_iou(gt_bboxes.unsqueeze(2), pd_bboxes.unsqueeze(1), xywh=False, CIoU=True).squeeze(3).clamp(0) align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) diff --git a/PyTorch/contrib/cv/detection/FCOS/mmdet/apis/train.py b/PyTorch/contrib/cv/detection/FCOS/mmdet/apis/train.py index 52649f30af..751e8726d8 100644 --- a/PyTorch/contrib/cv/detection/FCOS/mmdet/apis/train.py +++ b/PyTorch/contrib/cv/detection/FCOS/mmdet/apis/train.py @@ -1,8 +1,8 @@ # BSD 3-Clause License # -# Copyright (c) 2017 xxxx +# Copyright (c) 2024 xxxx # All rights reserved. -# Copyright 2021 Huawei Technologies Co., Ltd +# Copyright 2024 Huawei Technologies Co., Ltd # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are met: @@ -97,9 +97,7 @@ def train_detector(model, len(cfg.npu_ids), dist=distributed, shuffle=cfg.data.shuffle, - seed=cfg.seed, - persistent_workers=True, - ) for ds in dataset + seed=cfg.seed) for ds in dataset ] # add apex diff --git a/PyTorch/contrib/cv/detection/FCOS/mmdet/datasets/builder.py b/PyTorch/contrib/cv/detection/FCOS/mmdet/datasets/builder.py index fc19b9e71a..16d9ae34b8 100644 --- a/PyTorch/contrib/cv/detection/FCOS/mmdet/datasets/builder.py +++ b/PyTorch/contrib/cv/detection/FCOS/mmdet/datasets/builder.py @@ -1,35 +1,3 @@ -# BSD 3-Clause License -# -# Copyright (c) 2024 xxxx -# All rights reserved. -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Redistribution and use in source and binary forms, with or without -# modification, are permitted provided that the following conditions are met: -# -# * Redistributions of source code must retain the above copyright notice, this -# list of conditions and the following disclaimer. -# -# * Redistributions in binary form must reproduce the above copyright notice, -# this list of conditions and the following disclaimer in the documentation -# and/or other materials provided with the distribution. -# -# * Neither the name of the copyright holder nor the names of its -# contributors may be used to endorse or promote products derived from -# this software without specific prior written permission. -# -# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" -# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE -# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE -# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL -# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR -# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER -# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, -# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE -# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -# ============================================================================ - import copy import platform import random @@ -160,7 +128,7 @@ def build_dataloader(dataset, sampler=sampler, num_workers=num_workers, collate_fn=partial(collate, samples_per_gpu=samples_per_gpu), - pin_memory=True, + pin_memory=False, worker_init_fn=init_fn, **kwargs) -- Gitee