From 1b5154c2a8d05ed65ef075604dcc6ab2b016e4ed Mon Sep 17 00:00:00 2001
From: zhangchenrui
Date: Thu, 31 Oct 2024 16:05:29 +0800
Subject: [PATCH] update model SurroundOcc

---
 .../autonoumous_driving/SurroundOcc/README.md |  19 +-
 .../SurroundOcc/patch/mmcv/distributed.py     | 165 +++++
 .../patch/mmcv/modulated_deform_conv.py       | 154 ++++
 .../SurroundOcc/patch/mmdet/resnet.py         | 675 ++++++++++++++++++
 .../SurroundOcc/replace_patch.sh              |   3 +
 .../SurroundOcc/test/env_npu.sh               |   6 +
 .../SurroundOcc/tools/train.py                |   3 +-
 7 files changed, 1015 insertions(+), 10 deletions(-)
 create mode 100644 PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmcv/distributed.py
 create mode 100644 PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmcv/modulated_deform_conv.py
 create mode 100644 PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmdet/resnet.py

diff --git a/PyTorch/built-in/autonoumous_driving/SurroundOcc/README.md b/PyTorch/built-in/autonoumous_driving/SurroundOcc/README.md
index 2fa65c8c00..0fc6aa603b 100644
--- a/PyTorch/built-in/autonoumous_driving/SurroundOcc/README.md
+++ b/PyTorch/built-in/autonoumous_driving/SurroundOcc/README.md
@@ -52,7 +52,7 @@
 
 | 三方库  | 支持版本 |
 | :-----: | :------: |
-| PyTorch | 1.11 |
+| PyTorch | 2.1 |
 
 ### 安装昇腾环境
 
@@ -62,10 +62,10 @@
 
 | 软件类型 | 支持版本 |
 | :---------------: | :------: |
-| FrameworkPTAdaper | 6.0.RC2 |
-| CANN | 8.0.RC2 |
-| 昇腾NPU固件 | 24.1.RC2 |
-| 昇腾NPU驱动 | 24.1.RC2 |
+| FrameworkPTAdaper | 6.0.RC3 |
+| CANN | 8.0.RC3 |
+| 昇腾NPU固件 | 24.1.RC3 |
+| 昇腾NPU驱动 | 24.1.RC3 |
 
 - 安装mmdet3d
 
@@ -98,7 +98,7 @@ MMCV_WITH_OPS=1 pip install -e . -v
 ```
 
 - 安装mx_driving
-  - 参考mxDriving官方gitee仓README安装编译构建并安装mxDriving包：[参考链接](https://gitee.com/ascend/mxDriving)
+  - 参考mxDriving官方gitee仓README安装编译构建并安装mxDriving包，安装master分支：[参考链接](https://gitee.com/ascend/mxDriving)
 【注意】当前版本配套mxDriving RC3及以上版本，历史mxDriving版本需要model仓代码回退到git reset --hard 91ac141ecfe5872f4835eef6aa4662f46ede80c3
 
 - 在模型根目录下执行以下命令，安装模型对应PyTorch版本需要的依赖。
@@ -113,7 +113,6 @@
 bash replace_patch.sh --packages_path=location_path
 ```
 
-- 安装mxDriving加速库，安装方法参考[原仓](https://gitee.com/ascend/mxDriving)，安装后手动source环境变量或将其配置在test/env_npu.sh中。
 
 ### 准备数据集
 
@@ -166,7 +165,7 @@ SurroundOcc
 | 芯片 | 卡数 | global batch size | Precision | epoch | IoU | mIoU | 性能-单步迭代耗时(ms) |
 | ------------- | :--: | :---------------: | :-------: | :---: | :----: | :----: | :-------------------: |
 | 竞品A | 8p | 8 | fp32 | 12 | 0.3163 | 0.1999 | 1028 |
-| Atlas 800T A2 | 8p | 8 | fp32 | 12 | 0.3169 | 0.1985 | 1654 |
+| Atlas 800T A2 | 8p | 8 | fp32 | 12 | 0.3114 | 0.1896 | 1054 |
 
 # 公网地址说明
 
@@ -174,7 +173,9 @@
 
 # 变更说明
 
-2024.05.30：首次发布。
+2024.05.30：首次发布
+
+2024.10.30：性能优化
 
 # FAQ
 
diff --git a/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmcv/distributed.py b/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmcv/distributed.py
new file mode 100644
index 0000000000..20a44be3d3
--- /dev/null
+++ b/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmcv/distributed.py
@@ -0,0 +1,165 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+# Copyright 2024 Huawei Technologies Co., Ltd
+from typing import Any, List, Tuple
+
+import torch
+from torch.nn.parallel.distributed import (DistributedDataParallel,
+                                           _find_tensors)
+
+from mmcv import print_log
+from mmcv.utils import TORCH_VERSION, digit_version
+from .scatter_gather import ScatterInputs, scatter_kwargs
+
+
+class MMDistributedDataParallel(DistributedDataParallel):
+    """The DDP module that supports DataContainer.
+ + MMDDP has two main differences with PyTorch DDP: + + - It supports a custom type :class:`DataContainer` which allows more + flexible control of input data. + - It implement two APIs ``train_step()`` and ``val_step()``. + """ + + def to_kwargs(self, inputs: ScatterInputs, kwargs: ScatterInputs, + device_id: int) -> Tuple[tuple, tuple]: + # Use `self.to_kwargs` instead of `self.scatter` in pytorch1.8 + # to move all tensors to device_id + return scatter_kwargs(inputs, kwargs, [device_id], dim=self.dim) + + def scatter(self, inputs: ScatterInputs, kwargs: ScatterInputs, + device_ids: List[int]) -> Tuple[tuple, tuple]: + return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim) + + def train_step(self, *inputs, **kwargs): + """train_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.train_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='mmcv') + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + else: + if (getattr(self, 'require_forward_param_sync', False) + and self.require_forward_param_sync): + self._sync_params() + + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.train_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.train_step(*inputs, **kwargs) + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if (torch.is_grad_enabled() + and getattr(self, 'require_backward_grad_sync', False) + and self.require_backward_grad_sync): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output + + def val_step(self, *inputs, **kwargs): + """val_step() API for module wrapped by DistributedDataParallel. + + This method is basically the same as + ``DistributedDataParallel.forward()``, while replacing + ``self.module.forward()`` with ``self.module.val_step()``. + It is compatible with PyTorch 1.1 - 1.5. + """ + # In PyTorch >= 1.7, ``reducer._rebuild_buckets()`` is moved from the + # end of backward to the beginning of forward. 
+ if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.7') + and self.reducer._rebuild_buckets()): + print_log( + 'Reducer buckets have been rebuilt in this iteration.', + logger='mmcv') + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_pre_fwd(): + self._sync_buffers() + else: + if (getattr(self, 'require_forward_param_sync', False) + and self.require_forward_param_sync): + self._sync_params() + + if self.device_ids: + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + if len(self.device_ids) == 1: + output = self.module.val_step(*inputs[0], **kwargs[0]) + else: + outputs = self.parallel_apply( + self._module_copies[:len(inputs)], inputs, kwargs) + output = self.gather(outputs, self.output_device) + else: + output = self.module.val_step(*inputs, **kwargs) + + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) >= digit_version('1.11.0a0')): + if self._check_sync_bufs_post_fwd(): + self._sync_buffers() + + if (torch.is_grad_enabled() + and getattr(self, 'require_backward_grad_sync', False) + and self.require_backward_grad_sync): + if self.find_unused_parameters: + self.reducer.prepare_for_backward(list(_find_tensors(output))) + else: + self.reducer.prepare_for_backward([]) + else: + if ('parrots' not in TORCH_VERSION + and digit_version(TORCH_VERSION) > digit_version('1.2')): + self.require_forward_param_sync = False + return output + + def _run_ddp_forward(self, *inputs, **kwargs) -> Any: + """Processes inputs and runs ``self.module.forward``. + + Pytorch 1.12.0 performs ``self.module.forward`` in ``_run_ddp_forward`` + and deprecates using ``DistributedDataParallel.to_kwargs`` to + process inputs, which leads to inputs cannot be processed by + :meth:`MMDistributedDataParallel.to_kwargs` anymore. Therefore, + ``MMDistributedDataParallel`` overrides this method to call + :meth:`to_kwargs` explicitly. + + Returns: + Any: Forward result of :attr:`module`. + """ + module_to_run = self.module + + if self.device_ids: + inputs, kwargs = self.to_kwargs( # type: ignore + inputs, kwargs, self.device_ids[0]) + return module_to_run(*inputs[0], **kwargs[0]) # type: ignore + else: + return module_to_run(*inputs, **kwargs) diff --git a/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmcv/modulated_deform_conv.py b/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmcv/modulated_deform_conv.py new file mode 100644 index 0000000000..6667f55ad6 --- /dev/null +++ b/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmcv/modulated_deform_conv.py @@ -0,0 +1,154 @@ +# Copyright (c) OpenMMLab. All rights reserved. 
+# Copyright 2024 Huawei Technologies Co., Ltd +import math +from typing import Optional, Tuple, Union + +import torch +import torch_npu +import torch.nn as nn +from torch.nn.modules.utils import _pair, _single +from mmcv.utils import deprecated_api_warning +from mx_driving.fused import modulated_deform_conv2d, ModulatedDeformConv2dFunction + +from ..cnn import CONV_LAYERS +from ..utils import print_log + + +class ModulatedDeformConv2d(nn.Module): + + @deprecated_api_warning({"deformable_groups": "deform_groups"}, cls_name="ModulatedDeformConv2d") + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: Union[int, Tuple[int]], + stride: int = 1, + padding: int = 0, + dilation: int = 1, + groups: int = 1, + deform_groups: int = 1, + bias: Union[bool, str] = True, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = _pair(stride) + self.padding = _pair(padding) + self.dilation = _pair(dilation) + self.groups = groups + self.deform_groups = deform_groups + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias", None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1.0 / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + def forward(self, x: torch.Tensor, offset: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + return modulated_deform_conv2d( + x, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deform_groups, + ) + + +@CONV_LAYERS.register_module("DCNv2") +class ModulatedDeformConv2dPack(ModulatedDeformConv2d): + """A ModulatedDeformable Conv Encapsulation that acts as normal Conv + layers. + + Args: + in_channels (int): Same as nn.Conv2d. + out_channels (int): Same as nn.Conv2d. + kernel_size (int or tuple[int]): Same as nn.Conv2d. + stride (int): Same as nn.Conv2d, while tuple is not supported. + padding (int): Same as nn.Conv2d, while tuple is not supported. + dilation (int): Same as nn.Conv2d, while tuple is not supported. + groups (int): Same as nn.Conv2d. + bias (bool or str): If specified as `auto`, it will be decided by the + norm_cfg. Bias will be set as True if norm_cfg is None, otherwise + False. 
+ """ + + _version = 2 + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.conv_offset = nn.Conv2d( + self.in_channels, + self.deform_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding, + dilation=self.dilation, + bias=True, + ) + self.init_weights() + + def init_weights(self) -> None: + super().init_weights() + if hasattr(self, "conv_offset"): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore + out = self.conv_offset(x) + len1 = ((out.shape[1] + 2) // 3) * 2 + len2 = out.shape[1] - len1 + offset, mask = torch.split(out, [len1, len2], dim=1) + mask = torch.sigmoid(mask) + return modulated_deform_conv2d( + x, + offset, + mask, + self.weight, + self.bias, + self.stride, + self.padding, + self.dilation, + self.groups, + self.deform_groups, + ) + + # pylint: disable=huawei-too-many-arguments + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ): + version = local_metadata.get("version", None) + + if version is None or version < 2: + # the key is different in early versions + # In version < 2, ModulatedDeformConvPack + # loads previous benchmark models. + if prefix + "conv_offset.weight" not in state_dict and prefix[:-1] + "_offset.weight" in state_dict: + state_dict[prefix + "conv_offset.weight"] = state_dict.pop(prefix[:-1] + "_offset.weight") + if prefix + "conv_offset.bias" not in state_dict and prefix[:-1] + "_offset.bias" in state_dict: + state_dict[prefix + "conv_offset.bias"] = state_dict.pop(prefix[:-1] + "_offset.bias") + + if version is not None and version > 1: + print_log(f'ModulatedDeformConvPack {prefix.rstrip(".")} is upgraded to ' "version 2.", logger="root") + + super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs + ) diff --git a/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmdet/resnet.py b/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmdet/resnet.py new file mode 100644 index 0000000000..adc8e3281a --- /dev/null +++ b/PyTorch/built-in/autonoumous_driving/SurroundOcc/patch/mmdet/resnet.py @@ -0,0 +1,675 @@ +# Copyright (c) OpenMMLab. All rights reserved. +# Copyright 2024 Huawei Technologies Co., Ltd +import warnings + +import torch.nn as nn +import torch.utils.checkpoint as cp +from mmcv.cnn import build_conv_layer, build_norm_layer, build_plugin_layer +from mmcv.runner import BaseModule +from torch.nn.modules.batchnorm import _BatchNorm +import mx_driving.fused +import torch +import torch_npu + +from ..builder import BACKBONES +from ..utils import ResLayer + + +class BasicBlock(BaseModule): + expansion = 1 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + super(BasicBlock, self).__init__(init_cfg) + assert dcn is None, 'Not implemented yet.' + assert plugins is None, 'Not implemented yet.' 
+ + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + 3, + stride=stride, + padding=dilation, + dilation=dilation, + bias=False) + self.add_module(self.norm1_name, norm1) + self.conv2 = build_conv_layer( + conv_cfg, planes, planes, 3, padding=1, bias=False) + self.add_module(self.norm2_name, norm2) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + self.stride = stride + self.dilation = dilation + self.with_cp = with_cp + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.norm2(out) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +class Bottleneck(BaseModule): + expansion = 4 + + def __init__(self, + inplanes, + planes, + stride=1, + dilation=1, + downsample=None, + style='pytorch', + with_cp=False, + conv_cfg=None, + norm_cfg=dict(type='BN'), + dcn=None, + plugins=None, + init_cfg=None): + """Bottleneck block for ResNet. + + If style is "pytorch", the stride-two layer is the 3x3 conv layer, if + it is "caffe", the stride-two layer is the first 1x1 conv layer. 
+ """ + super(Bottleneck, self).__init__(init_cfg) + assert style in ['pytorch', 'caffe'] + assert dcn is None or isinstance(dcn, dict) + assert plugins is None or isinstance(plugins, list) + if plugins is not None: + allowed_position = ['after_conv1', 'after_conv2', 'after_conv3'] + assert all(p['position'] in allowed_position for p in plugins) + + self.inplanes = inplanes + self.planes = planes + self.stride = stride + self.dilation = dilation + self.style = style + self.with_cp = with_cp + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.dcn = dcn + self.with_dcn = dcn is not None + self.plugins = plugins + self.with_plugins = plugins is not None + + if self.with_plugins: + # collect plugins for conv1/conv2/conv3 + self.after_conv1_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv1' + ] + self.after_conv2_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv2' + ] + self.after_conv3_plugins = [ + plugin['cfg'] for plugin in plugins + if plugin['position'] == 'after_conv3' + ] + + if self.style == 'pytorch': + self.conv1_stride = 1 + self.conv2_stride = stride + else: + self.conv1_stride = stride + self.conv2_stride = 1 + + self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1) + self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2) + self.norm3_name, norm3 = build_norm_layer( + norm_cfg, planes * self.expansion, postfix=3) + + self.conv1 = build_conv_layer( + conv_cfg, + inplanes, + planes, + kernel_size=1, + stride=self.conv1_stride, + bias=False) + self.add_module(self.norm1_name, norm1) + fallback_on_stride = False + if self.with_dcn: + fallback_on_stride = dcn.pop('fallback_on_stride', False) + if not self.with_dcn or fallback_on_stride: + self.conv2 = build_conv_layer( + conv_cfg, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + else: + assert self.conv_cfg is None, 'conv_cfg must be None for DCN' + self.conv2 = build_conv_layer( + dcn, + planes, + planes, + kernel_size=3, + stride=self.conv2_stride, + padding=dilation, + dilation=dilation, + bias=False) + + self.add_module(self.norm2_name, norm2) + self.conv3 = build_conv_layer( + conv_cfg, + planes, + planes * self.expansion, + kernel_size=1, + bias=False) + self.add_module(self.norm3_name, norm3) + + self.relu = nn.ReLU(inplace=True) + self.downsample = downsample + + if self.with_plugins: + self.after_conv1_plugin_names = self.make_block_plugins( + planes, self.after_conv1_plugins) + self.after_conv2_plugin_names = self.make_block_plugins( + planes, self.after_conv2_plugins) + self.after_conv3_plugin_names = self.make_block_plugins( + planes * self.expansion, self.after_conv3_plugins) + + def make_block_plugins(self, in_channels, plugins): + """make plugins for block. + + Args: + in_channels (int): Input channels of plugin. + plugins (list[dict]): List of plugins cfg to build. + + Returns: + list[str]: List of the names of plugin. 
+ """ + assert isinstance(plugins, list) + plugin_names = [] + for plugin in plugins: + plugin = plugin.copy() + name, layer = build_plugin_layer( + plugin, + in_channels=in_channels, + postfix=plugin.pop('postfix', '')) + assert not hasattr(self, name), f'duplicate plugin {name}' + self.add_module(name, layer) + plugin_names.append(name) + return plugin_names + + def forward_plugin(self, x, plugin_names): + out = x + for name in plugin_names: + out = getattr(self, name)(out) + return out + + @property + def norm1(self): + """nn.Module: normalization layer after the first convolution layer""" + return getattr(self, self.norm1_name) + + @property + def norm2(self): + """nn.Module: normalization layer after the second convolution layer""" + return getattr(self, self.norm2_name) + + @property + def norm3(self): + """nn.Module: normalization layer after the third convolution layer""" + return getattr(self, self.norm3_name) + + def forward(self, x): + """Forward function.""" + + def _inner_forward(x): + identity = x + out = self.conv1(x) + out = self.norm1(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv1_plugin_names) + + out = self.conv2(out) + out = self.norm2(out) + out = self.relu(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv2_plugin_names) + + out = self.conv3(out) + out = self.norm3(out) + + if self.with_plugins: + out = self.forward_plugin(out, self.after_conv3_plugin_names) + + if self.downsample is not None: + identity = self.downsample(x) + + out += identity + + return out + + if self.with_cp and x.requires_grad: + out = cp.checkpoint(_inner_forward, x) + else: + out = _inner_forward(x) + + out = self.relu(out) + + return out + + +@BACKBONES.register_module() +class ResNet(BaseModule): + """ResNet backbone. + + Args: + depth (int): Depth of resnet, from {18, 34, 50, 101, 152}. + stem_channels (int | None): Number of stem channels. If not specified, + it will be the same as `base_channels`. Default: None. + base_channels (int): Number of base channels of res layer. Default: 64. + in_channels (int): Number of input image channels. Default: 3. + num_stages (int): Resnet stages. Default: 4. + strides (Sequence[int]): Strides of the first block of each stage. + dilations (Sequence[int]): Dilation of each stage. + out_indices (Sequence[int]): Output from which stages. + style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two + layer is the 3x3 conv layer, otherwise the stride-two layer is + the first 1x1 conv layer. + deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv + avg_down (bool): Use AvgPool instead of stride conv when + downsampling in the bottleneck. + frozen_stages (int): Stages to be frozen (stop grad and set eval mode). + -1 means not freezing any parameters. + norm_cfg (dict): Dictionary to construct and config norm layer. + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + plugins (list[dict]): List of plugins for stages, each dict contains: + + - cfg (dict, required): Cfg dict to build plugin. + - position (str, required): Position inside block to insert + plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'. + - stages (tuple[bool], optional): Stages to apply plugin, length + should be same as 'num_stages'. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some + memory while slowing down the training speed. 
+ zero_init_residual (bool): Whether to use zero init for last norm layer + in resblocks to let them behave as identity. + pretrained (str, optional): model pretrained path. Default: None + init_cfg (dict or list[dict], optional): Initialization config dict. + Default: None + + Example: + >>> from mmdet.models import ResNet + >>> import torch + >>> self = ResNet(depth=18) + >>> self.eval() + >>> inputs = torch.rand(1, 3, 32, 32) + >>> level_outputs = self.forward(inputs) + >>> for level_out in level_outputs: + ... print(tuple(level_out.shape)) + (1, 64, 8, 8) + (1, 128, 4, 4) + (1, 256, 2, 2) + (1, 512, 1, 1) + """ + + arch_settings = { + 18: (BasicBlock, (2, 2, 2, 2)), + 34: (BasicBlock, (3, 4, 6, 3)), + 50: (Bottleneck, (3, 4, 6, 3)), + 101: (Bottleneck, (3, 4, 23, 3)), + 152: (Bottleneck, (3, 8, 36, 3)) + } + + def __init__(self, + depth, + in_channels=3, + stem_channels=None, + base_channels=64, + num_stages=4, + strides=(1, 2, 2, 2), + dilations=(1, 1, 1, 1), + out_indices=(0, 1, 2, 3), + style='pytorch', + deep_stem=False, + avg_down=False, + frozen_stages=-1, + conv_cfg=None, + norm_cfg=dict(type='BN', requires_grad=True), + norm_eval=True, + dcn=None, + stage_with_dcn=(False, False, False, False), + plugins=None, + with_cp=False, + zero_init_residual=True, + pretrained=None, + init_cfg=None): + super(ResNet, self).__init__(init_cfg) + self.zero_init_residual = zero_init_residual + if depth not in self.arch_settings: + raise KeyError(f'invalid depth {depth} for resnet') + + block_init_cfg = None + assert not (init_cfg and pretrained), \ + 'init_cfg and pretrained cannot be specified at the same time' + if isinstance(pretrained, str): + warnings.warn('DeprecationWarning: pretrained is deprecated, ' + 'please use "init_cfg" instead') + self.init_cfg = dict(type='Pretrained', checkpoint=pretrained) + elif pretrained is None: + if init_cfg is None: + self.init_cfg = [ + dict(type='Kaiming', layer='Conv2d'), + dict( + type='Constant', + val=1, + layer=['_BatchNorm', 'GroupNorm']) + ] + block = self.arch_settings[depth][0] + if self.zero_init_residual: + if block is BasicBlock: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm2')) + elif block is Bottleneck: + block_init_cfg = dict( + type='Constant', + val=0, + override=dict(name='norm3')) + else: + raise TypeError('pretrained must be a str or None') + + self.depth = depth + if stem_channels is None: + stem_channels = base_channels + self.stem_channels = stem_channels + self.base_channels = base_channels + self.num_stages = num_stages + assert num_stages >= 1 and num_stages <= 4 + self.strides = strides + self.dilations = dilations + assert len(strides) == len(dilations) == num_stages + self.out_indices = out_indices + assert max(out_indices) < num_stages + self.style = style + self.deep_stem = deep_stem + self.avg_down = avg_down + self.frozen_stages = frozen_stages + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.with_cp = with_cp + self.norm_eval = norm_eval + self.dcn = dcn + self.stage_with_dcn = stage_with_dcn + if dcn is not None: + assert len(stage_with_dcn) == num_stages + self.plugins = plugins + self.block, stage_blocks = self.arch_settings[depth] + self.stage_blocks = stage_blocks[:num_stages] + self.inplanes = stem_channels + + self._make_stem_layer(in_channels, stem_channels) + + self.res_layers = [] + for i, num_blocks in enumerate(self.stage_blocks): + stride = strides[i] + dilation = dilations[i] + dcn = self.dcn if self.stage_with_dcn[i] else None + if plugins is not None: + 
stage_plugins = self.make_stage_plugins(plugins, i) + else: + stage_plugins = None + planes = base_channels * 2**i + res_layer = self.make_res_layer( + block=self.block, + inplanes=self.inplanes, + planes=planes, + num_blocks=num_blocks, + stride=stride, + dilation=dilation, + style=self.style, + avg_down=self.avg_down, + with_cp=with_cp, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + dcn=dcn, + plugins=stage_plugins, + init_cfg=block_init_cfg) + self.inplanes = planes * self.block.expansion + layer_name = f'layer{i + 1}' + self.add_module(layer_name, res_layer) + self.res_layers.append(layer_name) + + self._freeze_stages() + + self.feat_dim = self.block.expansion * base_channels * 2**( + len(self.stage_blocks) - 1) + + def make_stage_plugins(self, plugins, stage_idx): + """Make plugins for ResNet ``stage_idx`` th stage. + + Currently we support to insert ``context_block``, + ``empirical_attention_block``, ``nonlocal_block`` into the backbone + like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of + Bottleneck. + + An example of plugins format could be: + + Examples: + >>> plugins=[ + ... dict(cfg=dict(type='xxx', arg1='xxx'), + ... stages=(False, True, True, True), + ... position='after_conv2'), + ... dict(cfg=dict(type='yyy'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='1'), + ... stages=(True, True, True, True), + ... position='after_conv3'), + ... dict(cfg=dict(type='zzz', postfix='2'), + ... stages=(True, True, True, True), + ... position='after_conv3') + ... ] + >>> self = ResNet(depth=18) + >>> stage_plugins = self.make_stage_plugins(plugins, 0) + >>> assert len(stage_plugins) == 3 + + Suppose ``stage_idx=0``, the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->conv3->yyy->zzz1->zzz2 + + Suppose 'stage_idx=1', the structure of blocks in the stage would be: + + .. code-block:: none + + conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2 + + If stages is missing, the plugin would be applied to all stages. + + Args: + plugins (list[dict]): List of plugins cfg to build. The postfix is + required if multiple same type plugins are inserted. 
+ stage_idx (int): Index of stage to build + + Returns: + list[dict]: Plugins for current stage + """ + stage_plugins = [] + for plugin in plugins: + plugin = plugin.copy() + stages = plugin.pop('stages', None) + assert stages is None or len(stages) == self.num_stages + # whether to insert plugin into current stage + if stages is None or stages[stage_idx]: + stage_plugins.append(plugin) + + return stage_plugins + + def make_res_layer(self, **kwargs): + """Pack all blocks in a stage into a ``ResLayer``.""" + return ResLayer(**kwargs) + + @property + def norm1(self): + """nn.Module: the normalization layer named "norm1" """ + return getattr(self, self.norm1_name) + + def _make_stem_layer(self, in_channels, stem_channels): + if self.deep_stem: + self.stem = nn.Sequential( + build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels // 2, + kernel_size=3, + stride=2, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels // 2, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels // 2)[1], + nn.ReLU(inplace=True), + build_conv_layer( + self.conv_cfg, + stem_channels // 2, + stem_channels, + kernel_size=3, + stride=1, + padding=1, + bias=False), + build_norm_layer(self.norm_cfg, stem_channels)[1], + nn.ReLU(inplace=True)) + else: + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels, + stem_channels, + kernel_size=7, + stride=2, + padding=3, + bias=False) + self.norm1_name, norm1 = build_norm_layer( + self.norm_cfg, stem_channels, postfix=1) + self.add_module(self.norm1_name, norm1) + self.relu = nn.ReLU(inplace=True) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + if self.deep_stem: + self.stem.eval() + for param in self.stem.parameters(): + param.requires_grad = False + else: + self.norm1.eval() + for m in [self.conv1, self.norm1]: + for param in m.parameters(): + param.requires_grad = False + + for i in range(1, self.frozen_stages + 1): + m = getattr(self, f'layer{i}') + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def forward(self, x): + """Forward function.""" + if self.deep_stem: + x = self.stem(x) + else: + x = self.conv1(x) + x = self.norm1(x) + x = self.relu(x) + x = mx_driving.fused.npu_max_pool2d(x, 3, 2, 1) + outs = [] + for i, layer_name in enumerate(self.res_layers): + res_layer = getattr(self, layer_name) + x = res_layer(x) + if i in self.out_indices: + outs.append(x) + return tuple(outs) + + def train(self, mode=True): + """Convert the model into training mode while keep normalization layer + freezed.""" + super(ResNet, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + # trick: eval have effect on BatchNorm only + if isinstance(m, _BatchNorm): + m.eval() + + +@BACKBONES.register_module() +class ResNetV1d(ResNet): + r"""ResNetV1d variant described in `Bag of Tricks + `_. + + Compared with default ResNet(ResNetV1b), ResNetV1d replaces the 7x7 conv in + the input stem with three 3x3 convs. And in the downsampling block, a 2x2 + avg_pool with stride 2 is added before conv, whose stride is changed to 1. 
+ """ + + def __init__(self, **kwargs): + super(ResNetV1d, self).__init__( + deep_stem=True, avg_down=True, **kwargs) diff --git a/PyTorch/built-in/autonoumous_driving/SurroundOcc/replace_patch.sh b/PyTorch/built-in/autonoumous_driving/SurroundOcc/replace_patch.sh index 3266ac0295..30f65a843d 100644 --- a/PyTorch/built-in/autonoumous_driving/SurroundOcc/replace_patch.sh +++ b/PyTorch/built-in/autonoumous_driving/SurroundOcc/replace_patch.sh @@ -7,5 +7,8 @@ do done cp -f patch/torch/conv.py ${packages_path}/torch/nn/modules/conv.py +cp -f patch/mmdet/resnet.py ${packages_path}/mmdet/models/backbones/resnet.py cp -f patch/mmcv/optimizer.py mmcv/mmcv/runner/hooks/optimizer.py cp -f patch/mmcv/epoch_based_runner.py mmcv/mmcv/runner/epoch_based_runner.py +cp -f patch/mmcv/distributed.py mmcv/mmcv/parallel/distributed.py +cp -f patch/mmcv/modulated_deform_conv.py mmcv/mmcv/ops/modulated_deform_conv.py diff --git a/PyTorch/built-in/autonoumous_driving/SurroundOcc/test/env_npu.sh b/PyTorch/built-in/autonoumous_driving/SurroundOcc/test/env_npu.sh index 19880a441e..e7fdd1dc22 100644 --- a/PyTorch/built-in/autonoumous_driving/SurroundOcc/test/env_npu.sh +++ b/PyTorch/built-in/autonoumous_driving/SurroundOcc/test/env_npu.sh @@ -19,6 +19,12 @@ export ASCEND_SLOG_PRINT_TO_STDOUT=0 export ASCEND_GLOBAL_LOG_LEVEL=3 #设置Event日志开启标志,0-关闭/1-开启 export ASCEND_GLOBAL_EVENT_ENABLE=0 +# 设置是否开启 taskque, 0-关闭/1-开启 +export TASK_QUEUE_ENABLE=2 +# 绑核 +export CPU_AFFINITY_CONF=1 +# 设置是否开启 combined 标志, 0-关闭/1-开启 +export COMBINED_ENABLE=1 #设置device侧日志登记为error diff --git a/PyTorch/built-in/autonoumous_driving/SurroundOcc/tools/train.py b/PyTorch/built-in/autonoumous_driving/SurroundOcc/tools/train.py index 89644e04dc..b503cdcc86 100644 --- a/PyTorch/built-in/autonoumous_driving/SurroundOcc/tools/train.py +++ b/PyTorch/built-in/autonoumous_driving/SurroundOcc/tools/train.py @@ -97,13 +97,14 @@ def parse_args(): default='none', help='job launcher') parser.add_argument('--local_rank', type=int, default=0) + parser.add_argument('--local-rank', type=int, default=0) parser.add_argument( '--autoscale-lr', action='store_true', help='automatically scale lr with the number of gpus') args = parser.parse_args() if 'LOCAL_RANK' not in os.environ: - os.environ['LOCAL_RANK'] = str(args.local_rank) + os.environ['LOCAL_RANK'] = str(args.local-rank) if args.options and args.cfg_options: raise ValueError( -- Gitee