From e999fc76130a561acba053495a0f7416d442d2c7 Mon Sep 17 00:00:00 2001
From: GTS-AI-wangchao
Date: Mon, 28 Oct 2024 10:50:52 +0800
Subject: [PATCH] optimize tpvformer model performance

---
 .../TPVFormer_for_PyTorch/README.md           |  51 +-
 .../config/_base_/optimizer.py                |   2 +-
 .../config/tpv_lidarseg.py                    |   6 +
 .../config/tpv_lidarseg_dim64.py              |   6 +
 .../config/tpv_lidarseg_dim96.py              |   6 +
 .../mmcv_need/distributed.py                  | 165 +++++
 .../mmcv_need/modulated_deform_conv.py        | 154 ++++
 .../mmcv_need/optimizer.py                    | 561 +++++++++++++++
 .../mmcv_replace/ops/modulated_deform_conv.py | 439 ------------
 .../mmdet_need/resnet.py                      | 672 ++++++++++++++++++
 .../TPVFormer_for_PyTorch/requirements.txt    |   2 +-
 .../TPVFormer_for_PyTorch/test/env_npu.sh     |  13 +-
 .../tpvformer10/modules/encoder.py            |   5 +-
 .../tpvformer10/modules/tpvformer_layer.py    |  27 +-
 .../TPVFormer_for_PyTorch/train.py            |   3 +
 15 files changed, 1630 insertions(+), 482 deletions(-)
 create mode 100644 PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/distributed.py
 create mode 100644 PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/modulated_deform_conv.py
 create mode 100644 PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/optimizer.py
 delete mode 100644 PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_replace/ops/modulated_deform_conv.py
 create mode 100644 PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmdet_need/resnet.py

diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/README.md b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/README.md
index 3d10fa7817..99279e208a 100644
--- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/README.md
+++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/README.md
@@ -47,10 +47,10 @@
 
   | Software | Supported version |
   |:------------------:|:--------:|
-  | FrameworkPTAdapter | 6.0.RC2 |
-  | CANN | 8.0.RC2 |
-  | Ascend NPU firmware | 24.0.RC2 |
-  | Ascend NPU driver | 24.0.RC2 |
+  | FrameworkPTAdapter | 6.0.RC3 |
+  | CANN | 8.0.RC3 |
+  | Ascend NPU firmware | 24.1.RC3 |
+  | Ascend NPU driver | 24.1.RC3 |
 
 ## Install the model environment
 
@@ -58,17 +58,17 @@
 
   **Table 2** Version support
 
-  | Third-party library | Supported version |
-  |:--------------:|:------:|
-  | PyTorch | 2.1 |
-  | Mx_Driving-Accelerator | latest |
-  | mmcv | 1.x |
-  | mmdet | 2.28.2 |
-  | mmsegmentation | 0.30.0 |
+  | Third-party library | Supported version |
+  |:----------------------:|:------:|
+  | PyTorch | 2.1 |
+  | mx-driving | latest |
+  | mmcv | 1.x |
+  | mmdet | 2.28.2 |
+  | mmsegmentation | 0.30.0 |
 
 - Install Mx_Driving-Accelerator
 
-  Build and install Mx_Driving-Accelerator by following the instructions in the Ascend [mxDriving](https://gitee.com/ascend/mxDriving) repository
+  Build and install mx-driving by following the instructions in the Ascend [mxDriving](https://gitee.com/ascend/mxDriving) repository
+  [Note] The current release pairs with mxDriving RC3 and later; older mxDriving versions require rolling the model repository back with git reset --hard 91ac141ecfe5872f4835eef6aa4662f46ede80c3
 
 - Install basic dependencies
 
@@ -76,21 +76,29 @@
 
   Run the following commands in the root directory of the model source package to install the dependencies the model needs.
 
   ```
-  pip install opencv-python==4.9.0.80
+  pip install opencv-python==4.10.0.84
   pip install -r requirements.txt
   ```
 
-- Install mmcv
-
-  Get the [mmcv 1.x](https://github.com/open-mmlab/mmcv/tree/1.x) branch source from the mmcv repository and unpack it to `$YOURMMCVPATH`. Copy the files under `mmcv_replace` into `$YOURMMCVPATH/mmcv` to overwrite the originals, then run the following commands
+- Install mmcv 1.x from source
   ```
-  cd $YOURMMCVPATH
+  git clone -b 1.x https://github.com/open-mmlab/mmcv.git
+  cp -f mmcv_need/distributed.py mmcv/mmcv/parallel/distributed.py
+  cp -f mmcv_need/modulated_deform_conv.py mmcv/mmcv/ops/modulated_deform_conv.py
+  cp -f mmcv_need/optimizer.py mmcv/mmcv/runner/hooks/optimizer.py
+  cd mmcv
   MMCV_WITH_OPS=1 FORCE_NPU=1 python setup.py install
   ```
-- Install mmdet and mmsegmentation
+- Install mmdet 2.28.2 from source
+  ```
+  git clone -b v2.28.2 https://github.com/open-mmlab/mmdetection.git
+  cp -f mmdet_need/resnet.py mmdetection/mmdet/models/backbones/resnet.py
+  cd mmdetection
+  pip install -e .
+  ```
+- Install mmsegmentation
   ```
-  pip install mmdet==2.28.2
   pip install mmsegmentation==0.30.0
   ```
 
@@ -165,8 +173,8 @@ TPVFormer_for_PyTorch/data
 |:--------:|----|:------:|:----:|:----------:|
 | Competitor A | 1p | - | 0.71 | 1 |
 | Competitor A | 8p | 54.498 | 8.08 | 24 |
-| Atlas 800T A2 | 1p | - | 0.4 | 1 |
-| Atlas 800T A2 | 8p | 54.344 | 3.05 | 24 |
+| Atlas 800T A2 | 1p | - | 0.7 | 1 |
+| Atlas 800T A2 | 8p | 68.661 | 5.63 | 24 |
 
 # Public Network URL Statement
 
@@ -174,6 +182,7 @@
 # Change Notes
 
 2024.05.13: Initial release.
+2024.10.27: Performance optimization.
 
 ## FAQ
 None.
diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/_base_/optimizer.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/_base_/optimizer.py
index 90beb1b2bd..ff4b69ddc9 100644
--- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/_base_/optimizer.py
+++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/_base_/optimizer.py
@@ -1,5 +1,5 @@
 optimizer = dict(
-    type='AdamW',
+    type='NpuFusedAdamW',
     lr=2e-4,
     paramwise_cfg=dict(
         custom_keys={
diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg.py
index ef6adca9c1..871c6d27a2 100644
--- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg.py
+++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg.py
@@ -35,6 +35,9 @@ nbr_class = 17
 
 self_cross_layer = dict(
     type='TPVFormerLayer',
+    tpv_h=tpv_h_,
+    tpv_w=tpv_w_,
+    tpv_z=tpv_z_,
     attn_cfgs=[
         dict(
             type='TPVCrossViewHybridAttention',
@@ -76,6 +79,9 @@ self_cross_layer = dict(
 
 self_layer = dict(
     type='TPVFormerLayer',
+    tpv_h=tpv_h_,
+    tpv_w=tpv_w_,
+    tpv_z=tpv_z_,
     attn_cfgs=[
         dict(
             type='TPVCrossViewHybridAttention',
diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg_dim64.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg_dim64.py
index 797e33479f..534c8b8b3d 100644
--- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg_dim64.py
+++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg_dim64.py
@@ -35,6 +35,9 @@ nbr_class = 17
 
 self_cross_layer = dict(
     type='TPVFormerLayer',
+    tpv_h=tpv_h_,
+    tpv_w=tpv_w_,
+    tpv_z=tpv_z_,
     attn_cfgs=[
         dict(
             type='TPVCrossViewHybridAttention',
@@ -76,6 +79,9 @@ self_cross_layer = dict(
 
 self_layer = dict(
     type='TPVFormerLayer',
+    tpv_h=tpv_h_,
+    tpv_w=tpv_w_,
+    tpv_z=tpv_z_,
     attn_cfgs=[
         dict(
             type='TPVCrossViewHybridAttention',
diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg_dim96.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg_dim96.py
index c34389bccf..8aa84c74a3 100644
--- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg_dim96.py
+++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/config/tpv_lidarseg_dim96.py
@@ -35,6 +35,9 @@ nbr_class = 17
 
 self_cross_layer = dict(
     type='TPVFormerLayer',
+    tpv_h=tpv_h_,
+    tpv_w=tpv_w_,
+    tpv_z=tpv_z_,
     attn_cfgs=[
         dict(
             type='TPVCrossViewHybridAttention',
@@ -76,6 +79,9 @@ self_cross_layer = dict(
 
 self_layer = dict(
     type='TPVFormerLayer',
+    tpv_h=tpv_h_,
+    tpv_w=tpv_w_,
+    tpv_z=tpv_z_,
     attn_cfgs=[
         dict(
             type='TPVCrossViewHybridAttention',
diff --git
a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/distributed.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/distributed.py new file mode 100644 index 0000000000..20a44be3d3 --- /dev/null +++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/distributed.py @@ -0,0 +1,165 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright 2024 Huawei Technologies Co., Ltd +from typing import Any, List, Tuple + +import torch +from torch.nn.parallel.distributed import (DistributedDataParallel, + _find_tensors) + +from mmcv import print_log +from mmcv.utils import TORCH_VERSION, digit_version +from .scatter_gather import ScatterInputs, scatter_kwargs + + +class MMDistributedDataParallel(DistributedDataParallel): + """The DDP module that supports DataContainer. + + MMDDP has two main differences with PyTorch DDP: + + - It supports a custom type :class:`DataContainer` which allows more + flexible control of input data. + - It implement two APIs ``train_step()`` and ``val_step()``. + """ + + def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs, + device_id: int) -> Tuple[tuple, tuple]: + # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8 + # to move all tensors to device_id + return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim) + + def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs, + device_ids: List[int]) -> Tuple[tuple, tuple]: + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def train_step(self, *inputs, **kwargs): + """train_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.train_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. 
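+        # ``train_step`` bypasses ``DistributedDataParallel.forward()``, so
+        # the rebuild is triggered manually here to keep the reducer's
+        # gradient buckets in sync with the wrapped module.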
+ if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='mmcv') + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + else: + if (getattr(self, 'require_forward_param_sync', False) + and self.require_forward_param_sync): + self._sync_params() + + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.train_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.train_step(*inputs, **kwargs) + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if (torch.is_grad_enabled() + and getattr(self, 'require_backward_grad_sync', False) + and self.require_backward_grad_sync): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output + + def val_step(self, *inputs, **kwargs): + """val_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.val_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. 
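+        # As in ``train_step``, ``val_step`` bypasses ``forward()``, so the
+        # reducer buckets are rebuilt manually here as well.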
+ if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='mmcv') + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + else: + if (getattr(self, 'require_forward_param_sync', False) + and self.require_forward_param_sync): + self._sync_params() + + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.val_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.val_step(*inputs, **kwargs) + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if (torch.is_grad_enabled() + and getattr(self, 'require_backward_grad_sync', False) + and self.require_backward_grad_sync): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output + + def _run_ddp_forward(self, *inputs, **kwargs) -> Any: + """Processes inputs and runs ``self.module.forward``. + + Pytorch 1.12.0 performs ``self.module.forward`` in ``_run_ddp_forward`` + and deprecates using ``DistributedDataParallel.to_kwargs`` to + process inputs, which leads to inputs cannot be processed by + :meth:`MMDistributedDataParallel.to_kwargs` anymore. Therefore, + ``MMDistributedDataParallel`` overrides this method to call + :meth:`to_kwargs` explicitly. + + Returns: + Any: Forward result of :attr:`module`. + """ + module_to_run = self.module + + if self.device_ids: + inputs, kwargs = self.to_kwargs( # type: ignore + inputs, kwargs, self.device_ids[0]) + return module_to_run(*inputs[0], **kwargs[0]) # type: ignore + else: + return module_to_run(*inputs, **kwargs) diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/modulated_deform_conv.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/modulated_deform_conv.py new file mode 100644 index 0000000000..6667f55ad6 --- /dev/null +++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/modulated_deform_conv.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
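+# Modified: route modulated deformable conv through mx_driving's fused
+# NPU implementation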
+# Copyright 2024 Huawei Technologies Co., Ltd +import math +from typing import Optional, Tuple, Union + +import torch +import torch_npu +import torch.nn as nn +from torch.nn.modules.utils import _pair, _single +from mmcv.utils import deprecated_api_warning +from mx_driving.fused import modulated_deform_conv2d, ModulatedDeformConv2dFunction + +from ..cnn import CONV_LAYERS +from ..utils import print_log + + +class ModulatedDeformConv2d(nn.Module): + + @deprecated_api_warning({"deformable_groups": "deform_groups"}, cls_name="ModulatedDeformConv2d") + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]], + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + deform_groups: int = 1, + bias: Union[bool, str] = True, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deform_groups = deform_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias", None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1.0 / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x: torch.Tensor, offset: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + return modulated_deform_conv2d( + x, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deform_groups, + ) + + +@CONV_LAYERS.register_module("DCNv2") +class ModulatedDeformConv2dPack(ModulatedDeformConv2d): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv + layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int): Same as nn.Conv2d, while tuple is not supported. + padding (int): Same as nn.Conv2d, while tuple is not supported. + dilation (int): Same as nn.Conv2d, while tuple is not supported. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. 
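+        deform_groups (int): Number of deformable group partitions.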
+ """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=True, + ) + self.init_weights() + + def init_weights(self) -> None: + super().init_weights() + if hasattr(self, "conv_offset"): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore + out = self.conv_offset(x) + len1 = ((out.shape[1] + 2) // 3) * 2 + len2 = out.shape[1] - len1 + offset, mask = torch.split(out, [len1, len2], dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv2d( + x, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deform_groups, + ) + + # pylint: disable=huawei-too-many-arguments + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, ModulatedDeformConvPack + # loads previous benchmark models. + if prefix + "conv_offset.weight" not in state_dict and prefix[:-1] + "_offset.weight" in state_dict: + state_dict[prefix + "conv_offset.weight"] = state_dict.pop(prefix[:-1] + "_offset.weight") + if prefix + "conv_offset.bias" not in state_dict and prefix[:-1] + "_offset.bias" in state_dict: + state_dict[prefix + "conv_offset.bias"] = state_dict.pop(prefix[:-1] + "_offset.bias") + + if version is not None and version > 1: + print_log(f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to ' "version 2.", logger="root") + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/optimizer.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/optimizer.py new file mode 100644 index 0000000000..f664af6d7d --- /dev/null +++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_need/optimizer.py @@ -0,0 +1,561 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright 2024 Huawei Technologies Co., Ltd +# Modified: Replaced the fused optimizer by clip_grad_norm_fused_ +import copy +import logging +from collections import defaultdict +from itertools import chain +from typing import Optional, Union + +import torch.nn as nn +from torch import Tensor +from torch.nn.utils import clip_grad + +from mmcv.utils import (IS_NPU_AVAILABLE, TORCH_VERSION, _BatchNorm, + digit_version) +from ..dist_utils import allreduce_grads +from ..fp16_utils import LossScaler, wrap_fp16_model +from .hook import HOOKS, Hook + +try: + # If PyTorch version >= 1.6.0, torch.cuda.amp.GradScaler would be imported + # and used; otherwise, auto fp16 will adopt mmcv's implementation. + if IS_NPU_AVAILABLE: + from torch.npu.amp import GradScaler + else: + from torch.cuda.amp import GradScaler +except ImportError: + pass + + +@HOOKS.register_module() +class OptimizerHook(Hook): + """A hook contains custom operations for the optimizer. + + Args: + grad_clip (dict, optional): A config dict to control the clip_grad. + Default: None. 
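+            The dict is unpacked into the fused clipping call, e.g.
+            ``dict(max_norm=35, norm_type=2)``.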
+ detect_anomalous_params (bool): This option is only used for + debugging which will slow down the training speed. + Detect anomalous parameters that are not included in + the computational graph with `loss` as the root. + There are two cases + + - Parameters were not used during + forward pass. + - Parameters were not used to produce + loss. + Default: False. + """ + + def __init__(self, + grad_clip: Optional[dict] = None, + detect_anomalous_params: bool = False): + self.grad_clip = grad_clip + self.detect_anomalous_params = detect_anomalous_params + + def clip_grads(self, params, runner): + params = list( + filter(lambda p: p.requires_grad and p.grad is not None, params)) + if len(params) > 0: + return runner.optimizer.clip_grad_norm_fused_(**self.grad_clip) + + def after_train_iter(self, runner): + runner.optimizer.zero_grad() + if self.detect_anomalous_params: + self.detect_anomalous_parameters(runner.outputs['loss'], runner) + runner.outputs['loss'].backward() + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters(), runner) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + + def detect_anomalous_parameters(self, loss: Tensor, runner) -> None: + logger = runner.logger + parameters_in_graph = set() + visited = set() + + def traverse(grad_fn): + if grad_fn is None: + return + if grad_fn not in visited: + visited.add(grad_fn) + if hasattr(grad_fn, 'variable'): + parameters_in_graph.add(grad_fn.variable) + parents = grad_fn.next_functions + if parents is not None: + for parent in parents: + grad_fn = parent[0] + traverse(grad_fn) + + traverse(loss.grad_fn) + for n, p in runner.model.named_parameters(): + if p not in parameters_in_graph and p.requires_grad: + logger.log( + level=logging.ERROR, + msg=f'{n} with shape {p.size()} is not ' + f'in the computational graph \n') + + +@HOOKS.register_module() +class GradientCumulativeOptimizerHook(OptimizerHook): + """Optimizer Hook implements multi-iters gradient cumulating. + + Args: + cumulative_iters (int, optional): Num of gradient cumulative iters. + The optimizer will step every `cumulative_iters` iters. + Defaults to 1. + + Examples: + >>> # Use cumulative_iters to simulate a large batch size + >>> # It is helpful when the hardware cannot handle a large batch size. + >>> loader = DataLoader(data, batch_size=64) + >>> optim_hook = GradientCumulativeOptimizerHook(cumulative_iters=4) + >>> # almost equals to + >>> loader = DataLoader(data, batch_size=256) + >>> optim_hook = OptimizerHook() + """ + + def __init__(self, cumulative_iters: int = 1, **kwargs): + super().__init__(**kwargs) + + assert isinstance(cumulative_iters, int) and cumulative_iters > 0, \ + f'cumulative_iters only accepts positive int, but got ' \ + f'{type(cumulative_iters)} instead.' + + self.cumulative_iters = cumulative_iters + self.divisible_iters = 0 + self.remainder_iters = 0 + self.initialized = False + + def has_batch_norm(self, module: nn.Module) -> bool: + if isinstance(module, _BatchNorm): + return True + for m in module.children(): + if self.has_batch_norm(m): + return True + return False + + def _init(self, runner): + if runner.iter % self.cumulative_iters != 0: + runner.logger.warning( + 'Resume iter number is not divisible by cumulative_iters in ' + 'GradientCumulativeOptimizerHook, which means the gradient of ' + 'some iters is lost and the result may be influenced slightly.' 
+ ) + + if self.has_batch_norm(runner.model) and self.cumulative_iters > 1: + runner.logger.warning( + 'GradientCumulativeOptimizerHook may slightly decrease ' + 'performance if the model has BatchNorm layers.') + + self.divisible_iters = ( + runner.max_iters // self.cumulative_iters * self.cumulative_iters) + self.remainder_iters = runner.max_iters - self.divisible_iters + + self.initialized = True + + def _get_loss_factor(self, runner): + """Get loss division factor for the current iteration.""" + if runner.iter < runner.max_iters - self.remainder_iters: + loss_factor = self.cumulative_iters + else: + loss_factor = self.remainder_iters + runner.logger.warning( + f'Loss will be divided by {loss_factor} in the last ' + f'{self.remainder_iters} iterations because they are not ' + f'enough for {self.cumulative_iters} cumulative_iters.') + assert loss_factor > 0 + + return loss_factor + + def after_train_iter(self, runner): + if not self.initialized: + self._init(runner) + + loss = runner.outputs['loss'] / self._get_loss_factor(runner) + loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters(), runner) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + runner.optimizer.step() + runner.optimizer.zero_grad() + + +if (TORCH_VERSION != 'parrots' + and digit_version(TORCH_VERSION) >= digit_version('1.6.0')): + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): + """FP16 optimizer hook (using PyTorch's implementation). + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of GradScalar. + Defaults to 512. For Pytorch >= 1.6, mmcv uses official + implementation of GradScaler. If you use a dict version of + loss_scale to create GradScaler, please refer to: + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler + for the parameters. + + Examples: + >>> loss_scale = dict( + ... init_scale=65536.0, + ... growth_factor=2.0, + ... backoff_factor=0.5, + ... growth_interval=2000 + ... 
) + >>> optimizer_hook = Fp16OptimizerHook(loss_scale=loss_scale) + """ + + def __init__(self, + grad_clip: Optional[dict] = None, + coalesce: bool = True, + bucket_size_mb: int = -1, + loss_scale: Union[float, str, dict] = 512., + distributed: bool = True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + self._scale_update_param = None + if loss_scale == 'dynamic': + self.loss_scaler = GradScaler() + elif isinstance(loss_scale, float): + self._scale_update_param = loss_scale + self.loss_scaler = GradScaler(init_scale=loss_scale) + elif isinstance(loss_scale, dict): + self.loss_scaler = GradScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner) -> None: + """Preparing steps before Mixed Precision Training.""" + # wrap model mode to fp16 + wrap_fp16_model(runner.model) + # resume from state dict + if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: + scaler_state_dict = runner.meta['fp16']['loss_scaler'] + self.loss_scaler.load_state_dict(scaler_state_dict) + + def copy_grads_to_fp32(self, fp16_net: nn.Module, + fp32_weights: Tensor) -> None: + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net: nn.Module, + fp32_weights: Tensor) -> None: + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner) -> None: + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer to + https://pytorch.org/docs/stable/amp.html#torch.cuda.amp.GradScaler. + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients. + 3. Unscale the optimizer’s gradient tensors. + 4. Call optimizer.step() and update scale factor. + 5. Save loss_scaler state_dict for resume purpose. + """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + + self.loss_scaler.scale(runner.outputs['loss']).backward() + self.loss_scaler.unscale_(runner.optimizer) + # grad clip + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters(), runner) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update({'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook(GradientCumulativeOptimizerHook, + Fp16OptimizerHook): + """Fp16 optimizer Hook (using PyTorch's implementation) implements + multi-iters gradient cumulating. + + If you are using PyTorch >= 1.6, torch.cuda.amp is used as the backend, + to take care of the optimization procedure. 
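+
+        Examples:
+            >>> # Illustrative: pair dynamic loss scaling with 4-step
+            >>> # gradient accumulation.
+            >>> optim_hook = GradientCumulativeFp16OptimizerHook(
+            ...     loss_scale='dynamic', cumulative_iters=4)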
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def after_train_iter(self, runner) -> None: + if not self.initialized: + self._init(runner) + + loss = runner.outputs['loss'] / self._get_loss_factor(runner) + self.loss_scaler.scale(loss).backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + self.loss_scaler.unscale_(runner.optimizer) + + if self.grad_clip is not None: + grad_norm = self.clip_grads(runner.model.parameters(), runner) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + + # backward and update scaler + self.loss_scaler.step(runner.optimizer) + self.loss_scaler.update(self._scale_update_param) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() + +else: + + @HOOKS.register_module() + class Fp16OptimizerHook(OptimizerHook): # type: ignore + """FP16 optimizer hook (mmcv's implementation). + + The steps of fp16 optimizer is as follows. + 1. Scale the loss value. + 2. BP in the fp16 model. + 2. Copy gradients from fp16 model to fp32 weights. + 3. Update fp32 weights. + 4. Copy updated parameters from fp32 weights to fp16 model. + + Refer to https://arxiv.org/abs/1710.03740 for more details. + + Args: + loss_scale (float | str | dict): Scale factor configuration. + If loss_scale is a float, static loss scaling will be used with + the specified scale. If loss_scale is a string, it must be + 'dynamic', then dynamic loss scaling will be used. + It can also be a dict containing arguments of LossScaler. + Defaults to 512. + """ + + def __init__(self, + grad_clip: Optional[dict] = None, + coalesce: bool = True, + bucket_size_mb: int = -1, + loss_scale: Union[float, str, dict] = 512., + distributed: bool = True): + self.grad_clip = grad_clip + self.coalesce = coalesce + self.bucket_size_mb = bucket_size_mb + self.distributed = distributed + if loss_scale == 'dynamic': + self.loss_scaler = LossScaler(mode='dynamic') + elif isinstance(loss_scale, float): + self.loss_scaler = LossScaler( + init_scale=loss_scale, mode='static') + elif isinstance(loss_scale, dict): + self.loss_scaler = LossScaler(**loss_scale) + else: + raise ValueError('loss_scale must be of type float, dict, or ' + f'"dynamic", got {loss_scale}') + + def before_run(self, runner) -> None: + """Preparing steps before Mixed Precision Training. + + 1. Make a master copy of fp32 weights for optimization. + 2. Convert the main model from fp32 to fp16. 
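+            3. Re-map any existing optimizer state onto the fp32 copies so
+               momentum buffers survive the swap.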
+ """ + # keep a copy of fp32 weights + old_groups = runner.optimizer.param_groups + runner.optimizer.param_groups = copy.deepcopy( + runner.optimizer.param_groups) + state: defaultdict = defaultdict(dict) + p_map = { + old_p: p + for old_p, p in zip( + chain(*(g['params'] for g in old_groups)), + chain(*(g['params'] + for g in runner.optimizer.param_groups))) + } + for k, v in runner.optimizer.state.items(): + state[p_map[k]] = v + runner.optimizer.state = state + # convert model to fp16 + wrap_fp16_model(runner.model) + # resume from state dict + if 'fp16' in runner.meta and 'loss_scaler' in runner.meta['fp16']: + scaler_state_dict = runner.meta['fp16']['loss_scaler'] + self.loss_scaler.load_state_dict(scaler_state_dict) + + def copy_grads_to_fp32(self, fp16_net: nn.Module, + fp32_weights: Tensor) -> None: + """Copy gradients from fp16 model to fp32 weight copy.""" + for fp32_param, fp16_param in zip(fp32_weights, + fp16_net.parameters()): + if fp16_param.grad is not None: + if fp32_param.grad is None: + fp32_param.grad = fp32_param.data.new( + fp32_param.size()) + fp32_param.grad.copy_(fp16_param.grad) + + def copy_params_to_fp16(self, fp16_net: nn.Module, + fp32_weights: Tensor) -> None: + """Copy updated params from fp32 weight copy to fp16 model.""" + for fp16_param, fp32_param in zip(fp16_net.parameters(), + fp32_weights): + fp16_param.data.copy_(fp32_param.data) + + def after_train_iter(self, runner) -> None: + """Backward optimization steps for Mixed Precision Training. For + dynamic loss scaling, please refer `loss_scalar.py` + + 1. Scale the loss by a scale factor. + 2. Backward the loss to obtain the gradients (fp16). + 3. Copy gradients from the model to the fp32 weight copy. + 4. Scale the gradients back and update the fp32 weight copy. + 5. Copy back the params from fp32 weight copy to the fp16 model. + 6. Save loss_scaler state_dict for resume purpose. 
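+
+            If the scaled gradients overflow, the parameter update is
+            skipped and the loss scale is lowered for the next iteration.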
+ """ + # clear grads of last iteration + runner.model.zero_grad() + runner.optimizer.zero_grad() + # scale the loss value + scaled_loss = runner.outputs['loss'] * self.loss_scaler.loss_scale + scaled_loss.backward() + # copy fp16 grads in the model to fp32 params in the optimizer + + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights, runner) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + self.loss_scaler.update_scale(has_overflow) + if has_overflow: + runner.logger.warning('Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + @HOOKS.register_module() + class GradientCumulativeFp16OptimizerHook( # type: ignore + GradientCumulativeOptimizerHook, Fp16OptimizerHook): + """Fp16 optimizer Hook (using mmcv implementation) implements multi-iters gradient cumulating.""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def after_train_iter(self, runner) -> None: + if not self.initialized: + self._init(runner) + + loss = runner.outputs['loss'] / self._get_loss_factor(runner) + scaled_loss = loss * self.loss_scaler.loss_scale + scaled_loss.backward() + + if (self.every_n_iters(runner, self.cumulative_iters) + or self.is_last_iter(runner)): + + # copy fp16 grads in the model to fp32 params in the optimizer + fp32_weights = [] + for param_group in runner.optimizer.param_groups: + fp32_weights += param_group['params'] + self.copy_grads_to_fp32(runner.model, fp32_weights) + # allreduce grads + if self.distributed: + allreduce_grads(fp32_weights, self.coalesce, + self.bucket_size_mb) + + has_overflow = self.loss_scaler.has_overflow(fp32_weights) + # if has overflow, skip this iteration + if not has_overflow: + # scale the gradients back + for param in fp32_weights: + if param.grad is not None: + param.grad.div_(self.loss_scaler.loss_scale) + if self.grad_clip is not None: + grad_norm = self.clip_grads(fp32_weights, runner) + if grad_norm is not None: + # Add grad norm to the logger + runner.log_buffer.update( + {'grad_norm': float(grad_norm)}, + runner.outputs['num_samples']) + # update fp32 params + runner.optimizer.step() + # copy fp32 params to the fp16 model + self.copy_params_to_fp16(runner.model, fp32_weights) + else: + runner.logger.warning( + 'Check overflow, downscale loss scale ' + f'to {self.loss_scaler.cur_scale}') + + self.loss_scaler.update_scale(has_overflow) + + # save state_dict of loss_scaler + runner.meta.setdefault( + 'fp16', {})['loss_scaler'] = self.loss_scaler.state_dict() + + # clear grads + runner.model.zero_grad() + runner.optimizer.zero_grad() diff --git 
a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_replace/ops/modulated_deform_conv.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_replace/ops/modulated_deform_conv.py deleted file mode 100644 index 5d4551585b..0000000000 --- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmcv_replace/ops/modulated_deform_conv.py +++ /dev/null @@ -1,439 +0,0 @@ -# Copyright 2024 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# Copyright (c) OpenMMLab. All rights reserved. -import math -from typing import Optional, Tuple, Union - -import torch -import torch.nn as nn -from torch.autograd import Function -from torch.autograd.function import once_differentiable -from torch.nn.modules.utils import _pair, _single - -from mmcv.utils import IS_MLU_AVAILABLE, deprecated_api_warning -from ..cnn import CONV_LAYERS -from ..utils import ext_loader, print_log - -ext_module = ext_loader.load_ext( - '_ext', - ['modulated_deform_conv_forward', 'modulated_deform_conv_backward']) - - -class ModulatedDeformConv2dFunction(Function): - - @staticmethod - def symbolic(g, input, offset, mask, weight, bias, stride, padding, - dilation, groups, deform_groups): - input_tensors = [input, offset, mask, weight] - if bias is not None: - input_tensors.append(bias) - return g.op( - 'mmcv::MMCVModulatedDeformConv2d', - *input_tensors, - stride_i=stride, - padding_i=padding, - dilation_i=dilation, - groups_i=groups, - deform_groups_i=deform_groups) - - @staticmethod - def _calculate_sort_index(kernel_h, kernel_w, deformable_group): - split_num = deformable_group * 2 * kernel_h * kernel_w - sort_index = list(range(split_num)) - sort_index_fp = (sort_index[1::2] + sort_index[::2]) - sort_index_bp_dict = {i: idx for idx, i in enumerate(sort_index_fp)} - sort_index_bp = [sort_index_bp_dict[i] for i in sort_index] - sort_index_fp = torch.IntTensor(sort_index_fp) - sort_index_bp = torch.IntTensor(sort_index_bp) - sort_index_fp = sort_index_fp.npu() - sort_index_bp = sort_index_bp.npu() - return sort_index_fp, sort_index_bp - - @staticmethod - def _npu_forward(ctx, input_tensor, offset, mask, weight, bias): - _, _, kernel_h, kernel_w = weight.shape - conv2d_bias = bias if len(bias) > 0 else None - sort_index_fp, sort_index_bp = \ - ModulatedDeformConv2dFunction._calculate_sort_index( - kernel_w, kernel_h, ctx.deform_groups) - select_offset = offset.index_select(1, sort_index_fp) - offset_all = torch.cat([select_offset, mask], dim=1) - import torch_npu - output, offset_out = torch_npu.npu_deformable_conv2d( - input_tensor, - weight, - offset_all, - conv2d_bias, - kernel_size=[kernel_w, kernel_h], - stride=[1, 1, ctx.stride[0], ctx.stride[1]], - padding=[ - ctx.padding[0], ctx.padding[0], ctx.padding[1], ctx.padding[1] - ], - dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], - groups=ctx.groups, - deformable_groups=ctx.deform_groups, - modulated=True) - if weight.requires_grad or mask.requires_grad or offset.requires_grad \ - or input_tensor.requires_grad: - 
ctx.save_for_backward(input_tensor, weight, offset_out, offset_all, - sort_index_bp) - return output - - @staticmethod - def _npu_backward(ctx, grad_output): - input_tensor, weight, offset_out, offset_all, sort_index_bp = \ - ctx.saved_tensors - import torch_npu - grad_input, grad_weight, grad_offset_all, grad_bias = \ - torch_npu.npu_deformable_conv2dbk( - input_tensor, grad_output, offset_out, weight, offset_all, - kernel_size=[weight.shape[3], weight.shape[2]], - stride=[1, 1, ctx.stride[0], ctx.stride[1]], - padding=[ctx.padding[0], ctx.padding[0], ctx.padding[1], - ctx.padding[1]], - dilation=[1, 1, ctx.dilation[0], ctx.dilation[1]], - groups=ctx.groups, deformable_groups=ctx.deform_groups, - modulated=True) - grad_offset = grad_offset_all.index_select(1, sort_index_bp) - grad_mask = grad_offset_all[:, grad_offset.shape[1]:, :, :] - if not ctx.with_bias: - grad_bias = None - return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, - None, None, None, None, None, None, None, None) - - @staticmethod - def forward(ctx, - input: torch.Tensor, - offset: torch.Tensor, - mask: torch.Tensor, - weight: nn.Parameter, - bias: Optional[nn.Parameter] = None, - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - deform_groups: int = 1) -> torch.Tensor: - if input is not None and input.dim() != 4: - raise ValueError( - f'Expected 4D tensor as input, got {input.dim()}D tensor \ - instead.') - ctx.stride = _pair(stride) - ctx.padding = _pair(padding) - ctx.dilation = _pair(dilation) - ctx.groups = groups - ctx.deform_groups = deform_groups - ctx.with_bias = bias is not None - ctx.device = input.device.type - if not ctx.with_bias: - bias = input.new_empty(0) # fake tensor - # When pytorch version >= 1.6.0, amp is adopted for fp16 mode; - # amp won't cast the type of model (float32), but "offset" is cast - # to float16 by nn.Conv2d automatically, leading to the type - # mismatch with input (when it is float32) or weight. - # The flag for whether to use fp16 or amp is the type of "offset", - # we cast weight and input to temporarily support fp16 and amp - # whatever the pytorch version is. 
- input = input.type_as(offset) - weight = weight.type_as(input) - bias = bias.type_as(input) # type: ignore - mask = mask.type_as(input) - if ctx.device == 'npu': - output = ModulatedDeformConv2dFunction._npu_forward( - ctx, input, offset, mask, weight, bias) - return output - ctx.save_for_backward(input, offset, mask, weight, bias) - output = input.new_empty( - ModulatedDeformConv2dFunction._output_size(ctx, input, weight)) - ctx._bufs = [input.new_empty(0), input.new_empty(0)] - ext_module.modulated_deform_conv_forward( - input, - weight, - bias, - ctx._bufs[0], - offset, - mask, - output, - ctx._bufs[1], - kernel_h=weight.size(2), - kernel_w=weight.size(3), - stride_h=ctx.stride[0], - stride_w=ctx.stride[1], - pad_h=ctx.padding[0], - pad_w=ctx.padding[1], - dilation_h=ctx.dilation[0], - dilation_w=ctx.dilation[1], - group=ctx.groups, - deformable_group=ctx.deform_groups, - with_bias=ctx.with_bias) - return output - - @staticmethod - @once_differentiable - def backward(ctx, grad_output: torch.Tensor) -> tuple: - if ctx.device == 'npu': - return ModulatedDeformConv2dFunction._npu_backward( - ctx, grad_output) - input, offset, mask, weight, bias = ctx.saved_tensors - grad_input = torch.zeros_like(input) - grad_offset = torch.zeros_like(offset) - grad_mask = torch.zeros_like(mask) - grad_weight = torch.zeros_like(weight) - grad_bias = torch.zeros_like(bias) - grad_output = grad_output.contiguous() - ext_module.modulated_deform_conv_backward( - input, - weight, - bias, - ctx._bufs[0], - offset, - mask, - ctx._bufs[1], - grad_input, - grad_weight, - grad_bias, - grad_offset, - grad_mask, - grad_output, - kernel_h=weight.size(2), - kernel_w=weight.size(3), - stride_h=ctx.stride[0], - stride_w=ctx.stride[1], - pad_h=ctx.padding[0], - pad_w=ctx.padding[1], - dilation_h=ctx.dilation[0], - dilation_w=ctx.dilation[1], - group=ctx.groups, - deformable_group=ctx.deform_groups, - with_bias=ctx.with_bias) - if not ctx.with_bias: - grad_bias = None - - return (grad_input, grad_offset, grad_mask, grad_weight, grad_bias, - None, None, None, None, None) - - @staticmethod - def _output_size(ctx, input, weight): - channels = weight.size(0) - output_size = (input.size(0), channels) - for d in range(input.dim() - 2): - in_size = input.size(d + 2) - pad = ctx.padding[d] - kernel = ctx.dilation[d] * (weight.size(d + 2) - 1) + 1 - stride_ = ctx.stride[d] - output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, ) - if not all(map(lambda s: s > 0, output_size)): - raise ValueError( - 'convolution input is too small (output would be ' + - 'x'.join(map(str, output_size)) + ')') - return output_size - - -modulated_deform_conv2d = ModulatedDeformConv2dFunction.apply - - -class ModulatedDeformConv2d(nn.Module): - - @deprecated_api_warning({'deformable_groups': 'deform_groups'}, - cls_name='ModulatedDeformConv2d') - def __init__(self, - in_channels: int, - out_channels: int, - kernel_size: Union[int, Tuple[int]], - stride: int = 1, - padding: int = 0, - dilation: int = 1, - groups: int = 1, - deform_groups: int = 1, - bias: Union[bool, str] = True): - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = _pair(kernel_size) - self.stride = _pair(stride) - self.padding = _pair(padding) - self.dilation = _pair(dilation) - self.groups = groups - self.deform_groups = deform_groups - # enable compatibility with nn.Conv2d - self.transposed = False - self.output_padding = _single(0) - - self.weight = nn.Parameter( - torch.Tensor(out_channels, in_channels // groups, - 
*self.kernel_size)) - if bias: - self.bias = nn.Parameter(torch.Tensor(out_channels)) - else: - self.register_parameter('bias', None) - self.init_weights() - - def init_weights(self): - n = self.in_channels - for k in self.kernel_size: - n *= k - stdv = 1. / math.sqrt(n) - self.weight.data.uniform_(-stdv, stdv) - if self.bias is not None: - self.bias.data.zero_() - - def forward(self, x: torch.Tensor, offset: torch.Tensor, - mask: torch.Tensor) -> torch.Tensor: - return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, - self.stride, self.padding, - self.dilation, self.groups, - self.deform_groups) - - -@CONV_LAYERS.register_module('DCNv2') -class ModulatedDeformConv2dPack(ModulatedDeformConv2d): - """A ModulatedDeformable Conv Encapsulation that acts as normal Conv - layers. - - Args: - in_channels (int): Same as nn.Conv2d. - out_channels (int): Same as nn.Conv2d. - kernel_size (int or tuple[int]): Same as nn.Conv2d. - stride (int): Same as nn.Conv2d, while tuple is not supported. - padding (int): Same as nn.Conv2d, while tuple is not supported. - dilation (int): Same as nn.Conv2d, while tuple is not supported. - groups (int): Same as nn.Conv2d. - bias (bool or str): If specified as `auto`, it will be decided by the - norm_cfg. Bias will be set as True if norm_cfg is None, otherwise - False. - """ - - _version = 2 - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.conv_offset = nn.Conv2d( - self.in_channels, - self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1], - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - bias=True) - self.init_weights() - - def init_weights(self) -> None: - super().init_weights() - if hasattr(self, 'conv_offset'): - self.conv_offset.weight.data.zero_() - self.conv_offset.bias.data.zero_() - - def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore - out = self.conv_offset(x) - o1, o2, mask = torch.chunk(out, 3, dim=1) - offset = torch.cat((o1, o2), dim=1) - mask = torch.sigmoid(mask) - return modulated_deform_conv2d(x, offset, mask, self.weight, self.bias, - self.stride, self.padding, - self.dilation, self.groups, - self.deform_groups) - - def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, - missing_keys, unexpected_keys, error_msgs): - version = local_metadata.get('version', None) - - if version is None or version < 2: - # the key is different in early versions - # In version < 2, ModulatedDeformConvPack - # loads previous benchmark models. 
- if (prefix + 'conv_offset.weight' not in state_dict - and prefix[:-1] + '_offset.weight' in state_dict): - state_dict[prefix + 'conv_offset.weight'] = state_dict.pop( - prefix[:-1] + '_offset.weight') - if (prefix + 'conv_offset.bias' not in state_dict - and prefix[:-1] + '_offset.bias' in state_dict): - state_dict[prefix + - 'conv_offset.bias'] = state_dict.pop(prefix[:-1] + - '_offset.bias') - - if version is not None and version > 1: - print_log( - f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to ' - 'version 2.', - logger='root') - - super()._load_from_state_dict(state_dict, prefix, local_metadata, - strict, missing_keys, unexpected_keys, - error_msgs) - - -if IS_MLU_AVAILABLE: - import torchvision - from torchvision.ops import deform_conv2d as tv_deform_conv2d - - from mmcv.utils import digit_version - - @CONV_LAYERS.register_module('DCNv2', force=True) - class ModulatedDeformConv2dPack_MLU(ModulatedDeformConv2d): - """This class is the DCNv2 implementation of the MLU device. The MLU - backend support of the operator has been implemented in torchvision. - The mmcv registration mechanism is used for multiplexing here. The - torchvision implementation of DCNv2 is called. - - Args: - in_channels (int): Same as nn.Conv2d. - out_channels (int): Same as nn.Conv2d. - kernel_size (int or tuple[int]): Same as nn.Conv2d. - stride (int): Same as nn.Conv2d, while tuple is not supported. - padding (int): Same as nn.Conv2d, while tuple is not supported. - dilation (int): Same as nn.Conv2d, while tuple is not supported. - groups (int): Same as nn.Conv2d. - bias (bool or str): If specified as `auto`, it will be decided by - the norm_cfg. Bias will be set as True if norm_cfg is None, - otherwise False. - """ - - def __init__(self, *args, **kwargs): - assert digit_version(torchvision.__version__) >= digit_version( - '0.10.0a0'), 'the version of torchvision should be >= 0.10.0' - super().__init__(*args, **kwargs) - self.conv_offset = nn.Conv2d( - self.in_channels, - self.deform_groups * 3 * self.kernel_size[0] * - self.kernel_size[1], - kernel_size=self.kernel_size, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - bias=True) - self.init_weights() - - def init_weights(self): - super().init_weights() - if hasattr(self, 'conv_offset'): - self.conv_offset.weight.data.zero_() - self.conv_offset.bias.data.zero_() - - def forward(self, x): - out = self.conv_offset(x) - o1, o2, mask = torch.chunk(out, 3, dim=1) - offset = torch.cat((o1, o2), dim=1) - mask = torch.sigmoid(mask) - x = x.type_as(offset) - weight = self.weight.type_as(x) - mask = mask.type_as(x) - return tv_deform_conv2d( - x, - offset, - weight, - bias=self.bias, - stride=self.stride, - padding=self.padding, - dilation=self.dilation, - mask=mask) diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmdet_need/resnet.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmdet_need/resnet.py new file mode 100644 index 0000000000..2f9888b8b8 --- /dev/null +++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/mmdet_need/resnet.py @@ -0,0 +1,672 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
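+# Modified: fuse the Bottleneck residual add and ReLU via
+# mx_driving.fused.npu_add_relu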
+# Copyright 2024 Huawei Technologies Co., Ltd +import warnings +import torch +import torch_npu + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from mmcv.runner import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm +import mx_driving.fused + + +from ..builder import BACKBONES +from ..utils import ResLayer + + +class BasicBlock(BaseModule): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. 
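+
+        In this NPU port the residual addition and the following ReLU are
+        fused via ``mx_driving.fused.npu_add_relu`` (see ``forward``).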
+ """ + super(Bottleneck, self).__init__(init_cfg) + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. 
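+
+        Examples:
+            >>> # Illustrative plugin cfg (assumes mmcv's ContextBlock):
+            >>> plugins = [dict(type='ContextBlock', ratio=1. / 16)]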
+ """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + out = x + for name in plugin_names: + out = getattr(self, name)(out) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out = mx_driving.fused.npu_add_relu(out, identity) + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + return out + + +@BACKBONES.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + stem_channels (int | None): Number of stem channels. If not specified, + it will be the same as `base_channels`. Default: None. + base_channels (int): Number of base channels of res layer. Default: 64. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Example: + >>> from mmdet.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=None, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + self.zero_init_residual = zero_init_residual + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.depth = depth + if stem_channels is None: + stem_channels = base_channels + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + 
stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + init_cfg=block_init_cfg) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """Make plugins for ResNet ``stage_idx`` th stage. + + Currently we support to insert ``context_block``, + ``empirical_attention_block``, ``nonlocal_block`` into the backbone + like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be: + + Examples: + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose ``stage_idx=0``, the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->conv3->yyy->zzz1->zzz2 + + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. 
+        stage_idx (int): Index of stage to build
+
+        Returns:
+            list[dict]: Plugins for current stage
+        """
+        stage_plugins = []
+        for plugin in plugins:
+            plugin = plugin.copy()
+            stages = plugin.pop('stages', None)
+            assert stages is None or len(stages) == self.num_stages
+            # whether to insert plugin into current stage
+            if stages is None or stages[stage_idx]:
+                stage_plugins.append(plugin)
+
+        return stage_plugins
+
+    def make_res_layer(self, **kwargs):
+        """Pack all blocks in a stage into a ``ResLayer``."""
+        return ResLayer(**kwargs)
+
+    @property
+    def norm1(self):
+        """nn.Module: the normalization layer named "norm1" """
+        return getattr(self, self.norm1_name)
+
+    def _make_stem_layer(self, in_channels, stem_channels):
+        if self.deep_stem:
+            self.stem = nn.Sequential(
+                build_conv_layer(
+                    self.conv_cfg,
+                    in_channels,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=2,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels // 2,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+                nn.ReLU(inplace=True),
+                build_conv_layer(
+                    self.conv_cfg,
+                    stem_channels // 2,
+                    stem_channels,
+                    kernel_size=3,
+                    stride=1,
+                    padding=1,
+                    bias=False),
+                build_norm_layer(self.norm_cfg, stem_channels)[1],
+                nn.ReLU(inplace=True))
+        else:
+            self.conv1 = build_conv_layer(
+                self.conv_cfg,
+                in_channels,
+                stem_channels,
+                kernel_size=7,
+                stride=2,
+                padding=3,
+                bias=False)
+            self.norm1_name, norm1 = build_norm_layer(
+                self.norm_cfg, stem_channels, postfix=1)
+            self.add_module(self.norm1_name, norm1)
+            self.relu = nn.ReLU(inplace=True)
+
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
+            if self.deep_stem:
+                self.stem.eval()
+                for param in self.stem.parameters():
+                    param.requires_grad = False
+            else:
+                self.norm1.eval()
+                for m in [self.conv1, self.norm1]:
+                    for param in m.parameters():
+                        param.requires_grad = False
+
+        for i in range(1, self.frozen_stages + 1):
+            m = getattr(self, f'layer{i}')
+            m.eval()
+            for param in m.parameters():
+                param.requires_grad = False
+
+    def forward(self, x):
+        """Forward function."""
+        if self.deep_stem:
+            x = self.stem(x)
+        else:
+            x = self.conv1(x)
+            x = self.norm1(x)
+            x = self.relu(x)
+        x = mx_driving.fused.npu_max_pool2d(x, 3, 2, 1)
+        outs = []
+        for i, layer_name in enumerate(self.res_layers):
+            res_layer = getattr(self, layer_name)
+            x = res_layer(x)
+            if i in self.out_indices:
+                outs.append(x)
+        return tuple(outs)
+
+    def train(self, mode=True):
+        """Convert the model into training mode while keeping normalization layers frozen."""
+        super(ResNet, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                # trick: eval has an effect on BatchNorm only
+                if isinstance(m, _BatchNorm):
+                    m.eval()
+
+
+@BACKBONES.register_module()
+class ResNetV1d(ResNet):
+    r"""
+
+    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7
+    conv in the input stem with three 3x3 convs, and in the downsampling block
+    a 2x2 avg_pool with stride 2 is added before the conv, whose stride is
+    changed to 1.
+    """
+
+    def __init__(self, **kwargs):
+        super(ResNetV1d, self).__init__(
+            deep_stem=True, avg_down=True, **kwargs)
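Note for reviewers: the backbone above leans on two fused kernels from `mx_driving.fused`. As a rough, reference-only sketch (not the NPU implementation), their stock-PyTorch equivalents are assumed to be as follows; the `(x, 3, 2, 1)` argument order is read as kernel size, stride, padding, matching the standard ResNet stem pooling that the fused call replaces.

```python
import torch
import torch.nn.functional as F

# Assumed semantics of the fused calls used in resnet.py above:
#   mx_driving.fused.npu_add_relu(out, identity) ~ relu(out + identity)
#   mx_driving.fused.npu_max_pool2d(x, 3, 2, 1)  ~ max_pool2d(kernel_size=3, stride=2, padding=1)
# i.e. the residual add+ReLU and the stem max-pool, each fused into a single
# kernel to avoid an extra memory round-trip on the NPU.
def add_relu_reference(out: torch.Tensor, identity: torch.Tensor) -> torch.Tensor:
    return F.relu(out + identity)


def stem_max_pool_reference(x: torch.Tensor) -> torch.Tensor:
    return F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
```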
+ """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/requirements.txt b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/requirements.txt index d551a425c4..b172853f5e 100644 --- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/requirements.txt +++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/requirements.txt @@ -1,6 +1,6 @@ setuptools==65.7.0 torchvision==0.16.0 -opencv-python-headless==4.5.3.56 +opencv-python-headless==4.10.0.84 nuscenes-devkit==1.1.11 numba==0.58.1 numpy==1.23.1 diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/test/env_npu.sh b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/test/env_npu.sh index f539726029..f42bff3fe0 100644 --- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/test/env_npu.sh +++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/test/env_npu.sh @@ -14,23 +14,24 @@ else fi +#设置Shape数据缓存 +export HOST_CACHE_CAPACITY=20 #将Host日志输出到串口,0-关闭/1-开启 export ASCEND_SLOG_PRINT_TO_STDOUT=0 #设置默认日志级别,0-debug/1-info/2-warning/3-error export ASCEND_GLOBAL_LOG_LEVEL=3 #设置Event日志开启标志,0-关闭/1-开启 export ASCEND_GLOBAL_EVENT_ENABLE=0 -#设置是否开启taskque,0-关闭/1-开启 -export TASK_QUEUE_ENABLE=1 -#设置是否开启PTCopy,0-关闭/1-开启 -export PTCOPY_ENABLE=1 +#设置是否开启taskque,0-关闭/1-开启/2-流水优化 +export TASK_QUEUE_ENABLE=2 #设置是否开启combined标志,0-关闭/1-开启 export COMBINED_ENABLE=1 -#设置特殊场景是否需要重新编译,不需要修改 -export DYNAMIC_OP="ADD#MUL" +#设置是否开启均匀绑核,0-关闭/1-开启 +export CPU_AFFINITY_CONF=1 #HCCL白名单开关,1-关闭/0-开启 export HCCL_WHITELIST_DISABLE=1 export HCCL_IF_IP=$(hostname -I |awk '{print $1}') +export HCCL_CONNECT_TIMEOUT=1200 #设置device侧日志登记为error msnpureport -g error -d 0 diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/encoder.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/encoder.py index 0f65e44210..18fd2ee8ab 100644 --- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/encoder.py +++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/encoder.py @@ -1,3 +1,4 @@ +# Copyright 2024 Huawei Technologies Co., Ltd from mmcv.cnn.bricks.registry import TRANSFORMER_LAYER_SEQUENCE from mmcv.cnn.bricks.transformer import TransformerLayerSequence from mmcv.runner import force_fp32, auto_fp16 @@ -193,8 +194,8 @@ class TPVFormerEncoder(TransformerLayerSequence): lidar2img = lidar2img.view( 1, B, num_cam, 1, 4, 4).repeat(D, 1, 1, num_query, 1, 1) - reference_points_cam = torch.matmul(lidar2img.to(torch.float32), - reference_points.to(torch.float32)).squeeze(-1) + reference_points_cam = torch.mul(lidar2img.to(torch.float32), + reference_points.to(torch.float32).transpose(-1, -2)).sum(-1, keepdim=True).squeeze(-1) eps = 1e-5 tpv_mask = (reference_points_cam[..., 2:3] > eps) diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/tpvformer_layer.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/tpvformer_layer.py index e4ca80c2a0..114752f740 100644 --- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/tpvformer_layer.py +++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/tpvformer_layer.py @@ -1,4 +1,4 @@ - +# Copyright 2024 Huawei Technologies Co., Ltd import copy import warnings @@ -46,6 +46,9 @@ class TPVFormerLayer(BaseModule): """ def __init__(self, 
diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/tpvformer_layer.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/tpvformer_layer.py
index e4ca80c2a0..114752f740 100644
--- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/tpvformer_layer.py
+++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/tpvformer10/modules/tpvformer_layer.py
@@ -1,4 +1,4 @@
-
+# Copyright 2024 Huawei Technologies Co., Ltd
 import copy
 import warnings
 
@@ -46,6 +46,9 @@ class TPVFormerLayer(BaseModule):
     """
 
     def __init__(self,
+                 tpv_h,
+                 tpv_w,
+                 tpv_z,
                 attn_cfgs=None,
                 ffn_cfgs=dict(
                     type='FFN',
@@ -129,6 +132,15 @@ class TPVFormerLayer(BaseModule):
         for _ in range(num_norms):
             self.norms.append(build_norm_layer(norm_cfg, self.embed_dims)[1])
 
+        self.ss = torch.tensor([
+            [tpv_h, tpv_w],
+            [tpv_z, tpv_h],
+            [tpv_w, tpv_z]
+        ]).npu()
+        self.lsi = torch.tensor([
+            0, tpv_h * tpv_w, tpv_h * tpv_w + tpv_z * tpv_h
+        ]).npu()
+
     def forward(self,
                 query,
                 key=None,
@@ -169,15 +181,6 @@ class TPVFormerLayer(BaseModule):
         for layer in self.operation_order:
             # cross view hybrid attention
             if layer == 'self_attn':
-                ss = torch.tensor([
-                    [tpv_h, tpv_w],
-                    [tpv_z, tpv_h],
-                    [tpv_w, tpv_z]
-                ], device=query[0].device)
-                lsi = torch.tensor([
-                    0, tpv_h*tpv_w, tpv_h*tpv_w+tpv_z*tpv_h
-                ], device=query[0].device)
-
                 if not isinstance(query, (list, tuple)):
                     query = torch.split(
                         query, [tpv_h*tpv_w, tpv_z*tpv_h, tpv_w*tpv_z], dim=1)
@@ -187,8 +190,8 @@ class TPVFormerLayer(BaseModule):
                 identity if self.pre_norm else None,
                 query_pos=tpv_pos,
                 reference_points=ref_2d,
-                spatial_shapes=ss,
-                level_start_index=lsi,
+                spatial_shapes=self.ss,
+                level_start_index=self.lsi,
                 **kwargs)
             attn_index += 1
             query = torch.cat(query, dim=1)
diff --git a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/train.py b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/train.py
index 855a7073c9..6fc2c92260 100644
--- a/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/train.py
+++ b/PyTorch/contrib/autonoumous_driving/TPVFormer_for_PyTorch/train.py
@@ -30,6 +30,9 @@ from mmcv.runner import build_optimizer
 from mmseg.utils import get_root_logger
 from timm.scheduler import CosineLRScheduler
 
+torch_npu.npu.set_compile_mode(jit_compile=False)
+torch.npu.config.allow_internal_format = False
+
 import warnings
 warnings.filterwarnings("ignore")
-- 
Gitee
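A closing note on the `TPVFormerLayer` change: the constant `spatial_shapes` / `level_start_index` tensors are now built once in `__init__` and reused, instead of being re-created (and copied to the device) on every forward call. A minimal sketch of the same pattern, as a hypothetical standalone module; `register_buffer` is used here so the constants follow `.npu()`/`.to()` moves, whereas the patch calls `.npu()` directly:

```python
import torch

class SpatialShapeCache(torch.nn.Module):
    """Hypothetical module: hoist constant index tensors out of forward()."""

    def __init__(self, tpv_h: int, tpv_w: int, tpv_z: int):
        super().__init__()
        # Per-plane (rows, cols) shapes of the three TPV planes, and the
        # start offsets of each plane in the flattened query, computed once.
        spatial_shapes = torch.tensor(
            [[tpv_h, tpv_w], [tpv_z, tpv_h], [tpv_w, tpv_z]])
        level_start_index = torch.tensor(
            [0, tpv_h * tpv_w, tpv_h * tpv_w + tpv_z * tpv_h])
        # Buffers move with the module across devices and stay out of the
        # optimizer's parameter list.
        self.register_buffer('spatial_shapes', spatial_shapes, persistent=False)
        self.register_buffer('level_start_index', level_start_index, persistent=False)

    def forward(self):
        return self.spatial_shapes, self.level_start_index
```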