From caeff2c874b7b30aeb65f62222f3595564b53562 Mon Sep 17 00:00:00 2001
From: xiaxia3
Date: Thu, 10 Mar 2022 11:18:13 +0800
Subject: [PATCH] Migrate the 1.5 version affinity library modules to the 1.8
 version

---
 torch_npu/contrib/__init__.py                 |   0
 torch_npu/contrib/optimized_lib/__init__.py   |  27 +++
 .../contrib/optimized_lib/module/__init__.py  |  31 +++
 .../optimized_lib/module/activations.py       | 121 +++++++++++
 .../module/bidirectional_lstm.py              | 102 +++++++++
 .../optimized_lib/module/channel_shuffle.py   | 194 ++++++++++++++++++
 .../optimized_lib/module/crossentropy.py      |  63 ++++++
 .../optimized_lib/module/ps_roi_pooling.py    | 100 +++++++++
 .../contrib/optimized_lib/module/roi_align.py | 127 ++++++++++++
 9 files changed, 765 insertions(+)
 create mode 100644 torch_npu/contrib/__init__.py
 create mode 100644 torch_npu/contrib/optimized_lib/__init__.py
 create mode 100644 torch_npu/contrib/optimized_lib/module/__init__.py
 create mode 100644 torch_npu/contrib/optimized_lib/module/activations.py
 create mode 100644 torch_npu/contrib/optimized_lib/module/bidirectional_lstm.py
 create mode 100644 torch_npu/contrib/optimized_lib/module/channel_shuffle.py
 create mode 100644 torch_npu/contrib/optimized_lib/module/crossentropy.py
 create mode 100644 torch_npu/contrib/optimized_lib/module/ps_roi_pooling.py
 create mode 100644 torch_npu/contrib/optimized_lib/module/roi_align.py

diff --git a/torch_npu/contrib/__init__.py b/torch_npu/contrib/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/torch_npu/contrib/optimized_lib/__init__.py b/torch_npu/contrib/optimized_lib/__init__.py
new file mode 100644
index 0000000000..6b6ad7373d
--- /dev/null
+++ b/torch_npu/contrib/optimized_lib/__init__.py
@@ -0,0 +1,27 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .module import ChannelShuffle, LabelSmoothingCrossEntropy, ROIAlign, Mish, BiLSTM, PSROIPool, SiLU, Swish
+
+__all__ = [
+    # from module
+    "ChannelShuffle",
+    "LabelSmoothingCrossEntropy",
+    "ROIAlign",
+    "Mish",
+    "BiLSTM",
+    "PSROIPool",
+    "SiLU",
+    "Swish",
+]
diff --git a/torch_npu/contrib/optimized_lib/module/__init__.py b/torch_npu/contrib/optimized_lib/module/__init__.py
new file mode 100644
index 0000000000..9dacc435d0
--- /dev/null
+++ b/torch_npu/contrib/optimized_lib/module/__init__.py
@@ -0,0 +1,31 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from .channel_shuffle import ChannelShuffle
+from .crossentropy import LabelSmoothingCrossEntropy
+from .roi_align import ROIAlign
+from .activations import Mish, SiLU, Swish
+from .bidirectional_lstm import BiLSTM
+from .ps_roi_pooling import PSROIPool
+
+__all__ = [
+    "ChannelShuffle",
+    "LabelSmoothingCrossEntropy",
+    "ROIAlign",
+    "Mish",
+    "BiLSTM",
+    "PSROIPool",
+    "SiLU",
+    "Swish",
+]
diff --git a/torch_npu/contrib/optimized_lib/module/activations.py b/torch_npu/contrib/optimized_lib/module/activations.py
new file mode 100644
index 0000000000..0511622615
--- /dev/null
+++ b/torch_npu/contrib/optimized_lib/module/activations.py
@@ -0,0 +1,121 @@
+# Copyright (c) 2021, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import torch.nn as nn
+
+class Mish(nn.Module):
+    def __init__(self):
+        r"""Applies an NPU-based Mish operation.
+
+        Original CUDA implementation:
+        https://github.com/thomasbrandon/mish-cuda
+
+        Paper link:
+        [Mish: A Self Regularized Non-Monotonic Activation Function]
+        (https://www.bmvc2020-conference.com/assets/papers/0928.pdf)
+
+        Official PyTorch implementation:
+        https://github.com/digantamisra98/Mish/blob/master/Mish/Torch/mish.py
+
+        The calculation formula is as follows:
+            mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))
+
+        .. note::
+            Mish has been available in official PyTorch since version 1.9.0.
+            Currently, the PyTorch version adapted for NPU is 1.5.0,
+            so Mish needs to be defined as an additional module.
+
+        Examples::
+            >>> m = Mish()
+            >>> input_tensor = torch.randn(2, 32, 5, 5)
+            >>> output = m(input_tensor)
+        """
+        super(Mish, self).__init__()
+
+    def forward(self, x):
+        x = torch_npu.npu_mish(x)
+        return x
+
+class SiLU(nn.Module):
+    def __init__(self):
+        r"""Applies an NPU-based Sigmoid Linear Unit (SiLU) function, element-wise.
+        The SiLU function is also known as the swish function.
+
+        .. math::
+            \text{silu}(x) = x * \sigma(x), \text{where } \sigma(x) \text{ is the logistic sigmoid.}
+
+        .. note::
+            See `Gaussian Error Linear Units (GELUs) <https://arxiv.org/abs/1606.08415>`_
+            where the SiLU (Sigmoid Linear Unit) was originally coined, and see
+            `Sigmoid-Weighted Linear Units for Neural Network Function Approximation
+            in Reinforcement Learning <https://arxiv.org/abs/1702.03118>`_ and `Swish:
+            a Self-Gated Activation Function <https://arxiv.org/abs/1710.05941v1>`_
+            where the SiLU was experimented with later.
+
+            SiLU has been available in official PyTorch since version 1.7.0.
+            Currently, the PyTorch version adapted for NPU is 1.5.0,
+            so SiLU needs to be defined as an additional module.
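+
+            ``Swish`` is exported below as an alias of ``SiLU``, since several
+            repositories refer to the same function by that name.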
+
+        Examples::
+            >>> m = SiLU()
+            >>> input_tensor = torch.randn(2, 32, 5, 5)
+            >>> output = m(input_tensor)
+        """
+        super(SiLU, self).__init__()
+
+    def forward(self, x):
+        x = torch_npu.npu_silu(x)
+        return x
+
+Swish = SiLU
+
+if __name__ == '__main__':
+    torch.npu.set_device('npu:0')
+    input_tensor = torch.randn(2, 32, 4, 4)
+    input_tensor.requires_grad = True
+    model = Mish()
+
+    input_tensor = input_tensor.npu()
+    model = model.npu()
+
+    o = model(input_tensor)
+    l = o.sum()
+    l.backward()
+
+    o = model(input_tensor.half())
+    l = o.sum()
+    l.backward()
+
+    torch.npu.synchronize()
+    print('Mish test success.')
+
+    input_tensor = torch.randn(2, 32, 4, 4)
+    input_tensor.requires_grad = True
+    model = SiLU()
+
+    input_tensor = input_tensor.npu()
+    model = model.npu()
+
+    o = model(input_tensor)
+    l = o.sum()
+    l.backward()
+
+    o = model(input_tensor.half())
+    l = o.sum()
+    l.backward()
+
+    torch.npu.synchronize()
+    print('SiLU test success.')
diff --git a/torch_npu/contrib/optimized_lib/module/bidirectional_lstm.py b/torch_npu/contrib/optimized_lib/module/bidirectional_lstm.py
new file mode 100644
index 0000000000..e3ac0f52a9
--- /dev/null
+++ b/torch_npu/contrib/optimized_lib/module/bidirectional_lstm.py
@@ -0,0 +1,102 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+
+class BiLSTM(torch.nn.Module):
+    r"""Applies an NPU-compatible bidirectional LSTM operation to an input
+    sequence.
+
+    Since the NPU does not support setting bidirectional=True in torch.nn.LSTM,
+    this module reimplements a bidirectional LSTM by joining two unidirectional
+    LSTMs: one processes the sequence in order, the other processes it
+    reversed, and their outputs are concatenated.
+
+    Paper: [Bidirectional recurrent neural networks]
+    https://ieeexplore.ieee.org/document/650093
+
+    Args:
+        input_size: The number of expected features in the input `x`
+        hidden_size: The number of features in the hidden state `h`
+
+    Inputs: input
+        - **input** of shape `(seq_len, batch, input_size)`: tensor containing the features
+          of the input sequence. Unlike torch.nn.LSTM, this module's forward
+          does not take an `(h_0, c_0)` pair; the initial hidden and cell
+          states of both internal LSTMs default to zero.
+
+    Outputs: output
+        - **output** of shape `(seq_len, batch, 2 * hidden_size)`: tensor
+          containing the output features `(h_t)` from the last layer of the
+          two LSTMs, for each `t`.
+          The directions can be separated using
+          ``output.view(seq_len, batch, 2, hidden_size)``, with forward and
+          backward being direction `0` and `1` respectively. Unlike
+          torch.nn.LSTM, the final states `(h_n, c_n)` are not returned.
+
+    Examples::
+        >>> r = BiLSTM(512, 256)
+        >>> input_tensor = torch.randn(26, 2560, 512)
+        >>> output = r(input_tensor)
+    """
+    def __init__(self, input_size, hidden_size):
+        super(BiLSTM, self).__init__()
+
+        self.fw_rnn = torch.nn.LSTM(input_size, hidden_size, bidirectional=False)
+        self.bw_rnn = torch.nn.LSTM(input_size, hidden_size, bidirectional=False)
+
+    def forward(self, inputs):
+        # forward direction: run the first LSTM over the sequence as-is
+        input_fw = inputs
+        recurrent_fw, _ = self.fw_rnn(input_fw)
+        # backward direction: reverse the sequence, run the second LSTM,
+        # then reverse its outputs back to the original time order
+        input_bw = torch.flip(inputs, [0])
+        recurrent_bw, _ = self.bw_rnn(input_bw)
+        recurrent_bw = torch.flip(recurrent_bw, [0])
+        recurrent = torch.cat((recurrent_fw, recurrent_bw), 2)
+
+        return recurrent
+
+
+if __name__ == '__main__':
+    x = torch.randn(26, 2560, 512)
+    x.requires_grad = True
+
+    torch.npu.set_device(0)
+    x = x.npu()
+    rnn = BiLSTM(512, 256).npu()
+    x.retain_grad()
+    output = rnn(x)
+    print('test forward: ', output)
+    output.backward(torch.ones(x.size(), dtype=torch.float).npu())
+    x_grad = x.grad
+    print('test grad ', x_grad)
diff --git a/torch_npu/contrib/optimized_lib/module/channel_shuffle.py b/torch_npu/contrib/optimized_lib/module/channel_shuffle.py
new file mode 100644
index 0000000000..c182124e73
--- /dev/null
+++ b/torch_npu/contrib/optimized_lib/module/channel_shuffle.py
@@ -0,0 +1,194 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+import torch_npu
+import torch.nn as nn
+
+
+class ChannelShuffle(nn.Module):
+    r"""Applies an NPU-compatible channel shuffle operation.
+
+    The original implementation is
+    https://github.com/pytorch/vision/blob/master/torchvision/models/shufflenetv2.py#L21
+
+    To avoid contiguous operations, which are inefficient on NPU, the original
+    operations are replaced with a rewrite that has the same semantics: the two
+    non-contiguous operations, transpose and chunk, are expressed through
+    index_select with precomputed channel indices.
+
+    .. note::
+        Only groups=2 is implemented; adapt the code yourself for other group counts.
+
+    Args:
+        in_channels (int): The total number of channels in the input tensors
+        groups (int): The number of shuffle groups. Default: 2
+        split_shuffle (bool): Whether to execute the chunk after the shuffle. Default: True
+
+    Shape:
+        - Input: :math:`(N, C_{in}, H_{in}, W_{in})`, :math:`(N, C_{in}, H_{in}, W_{in})`
+        - Output: :math:`(N, C_{out}, H_{out}, W_{out})`
+
+    Examples::
+        >>> x1 = torch.randn(2, 32, 7, 7)
+        >>> x2 = torch.randn(2, 32, 7, 7)
+        >>> m = ChannelShuffle(64, split_shuffle=True)
+        >>> output = m(x1, x2)
+
+    """
+
+    def __init__(self, in_channels, groups=2, split_shuffle=True):
+        super(ChannelShuffle, self).__init__()
+        self.split_shuffle = split_shuffle
+        self.group_len = in_channels // groups
+
+        # build the shuffled channel-index permutation
+        self.out_channels = np.array(list(range(in_channels))).reshape(groups, self.group_len).transpose(1, 0).flatten()
+        self.out_channels = torch.from_numpy(self.out_channels).long()
+
+        # init the indices used in forward and backward
+        # Only groups=2 is implemented; adapt the code yourself for other group counts.
+        if self.split_shuffle:
+            self.fp_index1 = self.out_channels[:self.group_len]
+            self.fp_index2 = self.out_channels[self.group_len:]
+        else:
+            self.fp_index = self.out_channels
+        self.bp_index1 = torch.tensor(list(range(0, in_channels, 2)))
+        self.bp_index2 = torch.tensor(list(range(1, in_channels, 2)))
+
+        self.checked = False
+
+    def check_self(self, x):
+        r"""Ensure the cached index tensors are on the same device as the input.
+        """
+        if self.bp_index1.device == x.device:
+            self.checked = True
+            return
+
+        device = x.device
+
+        # int32 indices are required on NPU
+        if str(device).startswith('npu'):
+            if self.split_shuffle:
+                self.fp_index1 = self.fp_index1.int()
+                self.fp_index2 = self.fp_index2.int()
+            else:
+                self.fp_index = self.fp_index.int()
+            self.bp_index1 = self.bp_index1.int()
+            self.bp_index2 = self.bp_index2.int()
+
+        if self.split_shuffle:
+            self.fp_index1 = self.fp_index1.to(device)
+            self.fp_index2 = self.fp_index2.to(device)
+        else:
+            self.fp_index = self.fp_index.to(device)
+        self.bp_index1 = self.bp_index1.to(device)
+        self.bp_index2 = self.bp_index2.to(device)
+
+    def forward(self, x1, x2):
+        if not self.checked:
+            self.check_self(x1)
+        if self.split_shuffle:
+            if self.training:
+                output = IndexSelectHalfImplementation.apply(x1, x2, self.fp_index1, self.fp_index2, self.bp_index1,
+                                                             self.bp_index2)
+            else:
+                output = indexselect_half_implementation_forward(x1, x2, self.fp_index1, self.fp_index2)
+        else:
+            if self.training:
+                output = IndexSelectFullImplementation.apply(x1, x2, self.fp_index, self.bp_index1, self.bp_index2)
+            else:
+                output = indexselect_full_implementation_forward(x1, x2, self.fp_index)
+        return output
+
+def indexselect_full_implementation_forward(x1, x2, fp_index):
+    x = torch.cat([x1, x2], dim=1)
+    result = x.index_select(1, fp_index)
+    return result
+
+
+def indexselect_half_implementation_forward(x1, x2, fp_index1, fp_index2):
+    x = torch.cat([x1, x2], dim=1)
+    return x.index_select(1, fp_index1), x.index_select(1, fp_index2)
+
+
+class IndexSelectFullImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x1, x2, fp_index, bp_index1, bp_index2):
+        if str(x1.device).startswith('npu'):
+            # synchronize the stream to keep training stable
+            stream = torch.npu.current_stream()
+            stream.synchronize()
+
+        ctx.bp_index1 = bp_index1
+        ctx.bp_index2 = bp_index2
+        x = torch.cat([x1, x2], dim=1)
+        result = x.index_select(1, fp_index)
+        return result
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        if str(grad_output.device).startswith('npu'):
+            # synchronize the stream to keep training stable
+            stream = torch.npu.current_stream()
+            stream.synchronize()
+            # cast to NCHW to avoid an extra 5HD --> 4D transformation
+            grad_output.data = grad_output.data.npu_format_cast(0)
+
+        out1 = grad_output.index_select(1, ctx.bp_index1)
+        out2 = grad_output.index_select(1, ctx.bp_index2)
+        return out1, out2, None, None, None, None
+
+
+class IndexSelectHalfImplementation(torch.autograd.Function):
+    @staticmethod
+    def forward(ctx, x1, x2, fp_index1, fp_index2, bp_index1, bp_index2):
+        ctx.bp_index1 = bp_index1
+        ctx.bp_index2 = bp_index2
+        x = torch.cat([x1, x2], dim=1)
+        return x.index_select(1, fp_index1), x.index_select(1, fp_index2)
+
+    @staticmethod
+    def backward(ctx, grad_output1, grad_output2):
+        grad_output = torch.cat([grad_output1, grad_output2], 1)
+        out1 = grad_output.index_select(1, ctx.bp_index1)
+        out2 = grad_output.index_select(1, ctx.bp_index2)
+        return out1, out2, None, None, None, None
+
+
+def main():
+    device = 'cpu'
+
+    if device.startswith('npu'):
+        torch.npu.set_device(device)
+
+
+    def testcase(split_shuffle=True):
+        x = torch.randn(2, 32, 7, 7)
+        conv = torch.nn.Conv2d(32, 32, 1)
+        model = ChannelShuffle(64, split_shuffle=split_shuffle)
+
+        x = x.to(device)
+        conv = conv.to(device)
+        model = model.to(device)
+
+        x1 = conv(x)
+        x2 = conv(x)
+        output = model(x1, x2)
+        loss = sum([i.sum() for i in output]) if split_shuffle else output.sum()
+        loss.backward()
+
+
+    testcase(split_shuffle=True)
+    testcase(split_shuffle=False)
+
+if __name__ == '__main__':
+    main()
diff --git a/torch_npu/contrib/optimized_lib/module/crossentropy.py b/torch_npu/contrib/optimized_lib/module/crossentropy.py
new file mode 100644
index 0000000000..6eba982fb8
--- /dev/null
+++ b/torch_npu/contrib/optimized_lib/module/crossentropy.py
@@ -0,0 +1,63 @@
+# Copyright (c) 2020, Huawei Technologies. All rights reserved.
+#
+# Licensed under the BSD 3-Clause License (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://opensource.org/licenses/BSD-3-Clause
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch_npu
+import torch.nn as nn
+
+
+class LabelSmoothingCrossEntropy(nn.Module):
+    """CrossEntropy with label smoothing using the NPU API.
+
+    Paper: [Rethinking the Inception Architecture for Computer Vision]
+    https://arxiv.org/pdf/1512.00567.pdf
+
+    Args:
+        num_classes (int): number of classes used to build the one-hot label. Default: 1000.
+        smooth_factor (float): smoothing factor in [0, 1]. Default: 0.
+            A typical value for label smoothing is 0.1.
+
+    Returns:
+        Tensor: the label-smoothed cross-entropy loss, averaged over the batch.
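+
+    Examples (a minimal sketch; it assumes the tensors have already been
+    moved to an NPU device, mirroring the self-test at the bottom of this file)::
+        >>> m = LabelSmoothingCrossEntropy(10, 0.1)
+        >>> pred = torch.randn(2, 10).npu()
+        >>> target = torch.randint(0, 10, size=(2,)).npu()
+        >>> loss = m(pred, target)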
+ """ + + def __init__(self, num_classes=1000, smooth_factor=0.): + super(LabelSmoothingCrossEntropy, self).__init__() + self.on_value = 1.0 - smooth_factor + self.off_value = 1.0 * smooth_factor / (num_classes - 1) + + def forward(self, pred, target): + one_hot_label = torch_npu.npu_one_hot(target.int(), -1, pred.size(1), self.on_value, self.off_value) + loss = torch_npu.npu_softmax_cross_entropy_with_logits(pred, one_hot_label) + + loss = torch.mean(loss, [0], keepdim=False, dtype=torch.float32) + return loss + + +if __name__ == '__main__': + x = torch.randn(2, 10) + x.requires_grad = True + y = torch.randint(0, 10, size=(2,)) + + torch.npu.set_device(0) + x = x.npu() + y = y.npu() + m = LabelSmoothingCrossEntropy(10) + l = m(x, y) + l.backward() + print('test ce ok, loss is ', l) + + m = LabelSmoothingCrossEntropy(10, 0.1) + l = m(x, y) + l.backward() + print('test lsce ok, loss is ', l) diff --git a/torch_npu/contrib/optimized_lib/module/ps_roi_pooling.py b/torch_npu/contrib/optimized_lib/module/ps_roi_pooling.py new file mode 100644 index 0000000000..14f93e627a --- /dev/null +++ b/torch_npu/contrib/optimized_lib/module/ps_roi_pooling.py @@ -0,0 +1,100 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import torch +import torch_npu +from torch import nn + +class PSROIPool(nn.Module): + def __init__(self, pooled_height=7, pooled_width=7, spatial_scale=1 / 16.0, group_size=7, output_dim=22): + """ROIAlign using npu api. + + Origin implement is + https://github.com/RebornL/RFCN-pytorch.1.0/blob/master/lib/model/roi_layers/ps_roi_pool.py + + Args: + pooled_height (int): pooled_height + pooled_width (int): pooled_width + spatial_scale (float): scale the input boxes by this number + group_size (int): number of groups encoding position sensitive score maps + output_dim (int):number of output channels + + Note: + only pooled_height == pooled_width == group_size implemented. + + Examples:: + >>> model = PSROIPool(pooled_height=7, pooled_width=7, spatial_scale=1 / 16.0, group_size=7, output_dim=22) + + .. _R-FCN\: Object Detection via Region-based Fully Convolutional Networks + https://arxiv.org/abs/1605.06409 + """ + + super(PSROIPool, self).__init__() + + assert (pooled_height == pooled_width == group_size), \ + "only pooled_height == pooled_width == group_size supported." + + self.group_size = group_size + self.spatial_scale = spatial_scale + self.output_dim = output_dim + + def forward(self, features, rois): + ''' + rois needs to follow the specified format, please refer to get_random_rois function in this scripts. 
+ ''' + + return torch_npu.npu_ps_roi_pooling(features, + rois, + self.spatial_scale, + self.group_size, + self.output_dim) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "pooled_width=" + str(self.pooled_width) + tmpstr += ", pooled_height=" + str(self.pooled_height) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", group_size=" + str(self.group_size) + tmpstr += ", output_dim=" + str(self.output_dim) + tmpstr += ")" + return tmpstr + + +def get_random_rois(shape): + rois_init = torch.zeros(shape) + for i in range(shape[0]): + for j in range(shape[1]): + pi1 = torch.rand(1, 2).uniform_(0, 10) + pi2 = torch.rand(1, 2).uniform_(10, 100) + boxi = torch.cat((pi1, pi2), 1) + n = torch.tensor([[float(i)]]) + boxi = torch.cat((n, boxi), 1) + rois_init[i, j, :] = boxi + return rois_init + + +if __name__ == "__main__": + cls_feat = torch.randn(4, 1078, 84, 84).float() + cls_feat.requires_grad = True + rois_tensor = get_random_rois((4, 128, 5)).permute(0, 2, 1).float() + + model = PSROIPool(pooled_height=7, pooled_width=7, spatial_scale=1 / 16.0, group_size=7, output_dim=22) + + torch.npu.set_device(0) + cls_feat = cls_feat.npu() + rois_tensor = rois_tensor.npu() + + x = model(cls_feat, rois_tensor) # 512,22,7,7 + l = x.sum() + l.backward() diff --git a/torch_npu/contrib/optimized_lib/module/roi_align.py b/torch_npu/contrib/optimized_lib/module/roi_align.py new file mode 100644 index 0000000000..78674bdbf8 --- /dev/null +++ b/torch_npu/contrib/optimized_lib/module/roi_align.py @@ -0,0 +1,127 @@ +# Copyright (c) 2020 Huawei Technologies Co., Ltd +# Copyright (c) 2019, Facebook CORPORATION. +# All rights reserved. +# +# Licensed under the BSD 3-Clause License (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://opensource.org/licenses/BSD-3-Clause +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import torch
+import torch_npu
+from torch import nn
+
+from torch.nn.modules.utils import _pair
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+
+class _ROIAlign(Function):
+    @staticmethod
+    def forward(ctx, input_tensor, roi, output_size, spatial_scale, sampling_ratio, aligned):
+        ctx.save_for_backward(roi)
+        ctx.output_size = _pair(output_size)
+        ctx.spatial_scale = spatial_scale
+        ctx.sampling_ratio = sampling_ratio
+        ctx.input_shape = input_tensor.size()
+        ctx.aligned = aligned
+        if aligned:
+            roi_end_mode = 3
+        else:
+            roi_end_mode = 0
+        output = torch_npu.npu_roi_align(
+            input_tensor, roi, spatial_scale,
+            output_size[0], output_size[1], sampling_ratio, roi_end_mode)
+
+        return output
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, grad_output):
+        (rois,) = ctx.saved_tensors
+        output_size = ctx.output_size
+        spatial_scale = ctx.spatial_scale
+        sampling_ratio = ctx.sampling_ratio
+
+        grad_input = torch_npu.npu_roi_alignbk(
+            grad_output, rois, ctx.input_shape,
+            output_size[0], output_size[1],
+            spatial_scale, sampling_ratio)
+
+        return grad_input, None, None, None, None, None
+
+
+roi_align = _ROIAlign.apply
+
+
+# NOTE: torchvision's RoIAlign has a different default aligned=False
+class ROIAlign(nn.Module):
+    def __init__(self, output_size, spatial_scale, sampling_ratio, aligned=True):
+        """ROIAlign using the NPU API.
+
+        The original detectron2 implementation is
+        https://github.com/facebookresearch/detectron2/blob/master/detectron2/layers/roi_align.py#L7
+
+        The interface parameters are the same, but because the operator
+        implementation differs, the numerical results may differ slightly
+        from the CPU and GPU versions.
+
+        Args:
+            output_size (tuple): h, w
+            spatial_scale (float): scale the input boxes by this number
+            sampling_ratio (int): number of input samples to take for each output
+                sample. 0 to take samples densely.
+            aligned (bool): if False, use the legacy implementation in
+                Detectron. If True, align the results more perfectly.
+
+        Note:
+            The meaning of aligned=True:
+
+            Given a continuous coordinate c, its two neighboring pixel indices (in our
+            pixel model) are computed by floor(c - 0.5) and ceil(c - 0.5). For example,
+            c=1.3 has pixel neighbors with discrete indices [0] and [1] (which are sampled
+            from the underlying signal at continuous coordinates 0.5 and 1.5). But the original
+            roi_align (aligned=False) does not subtract the 0.5 when computing neighboring
+            pixel indices and therefore it uses pixels with a slightly incorrect alignment
+            (relative to our pixel model) when performing bilinear interpolation.
+
+            With `aligned=True`,
+            we first appropriately scale the ROI and then shift it by -0.5
+            prior to calling roi_align. This produces the correct neighbors; see
+            detectron2/tests/test_roi_align.py for verification.
+
+            The difference does not make a difference to the model's performance if
+            ROIAlign is used together with conv layers.
+        """
+        super(ROIAlign, self).__init__()
+        self.output_size = output_size
+        self.spatial_scale = spatial_scale
+        self.sampling_ratio = sampling_ratio
+        self.aligned = aligned
+
+    def forward(self, input_tensor, rois):
+        """
+        Args:
+            input_tensor: NCHW images
+            rois: Bx5 boxes. The first column is the index into N; the other
+                four columns are the box coordinates (x1, y1, x2, y2).
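+
+        A minimal sketch (the shapes and the single ROI are illustrative)::
+            >>> pool = ROIAlign((7, 7), spatial_scale=1 / 16.0, sampling_ratio=2)
+            >>> feats = torch.randn(1, 256, 64, 64).npu()
+            >>> rois = torch.tensor([[0., 0., 0., 160., 160.]]).npu()
+            >>> output = pool(feats, rois)  # -> (1, 256, 7, 7)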
+ """ + assert rois.dim() == 2 and rois.size(1) == 5 + return roi_align( + input_tensor.float(), rois, self.output_size, + self.spatial_scale, self.sampling_ratio, self.aligned + ) + + def __repr__(self): + tmpstr = self.__class__.__name__ + "(" + tmpstr += "output_size=" + str(self.output_size) + tmpstr += ", spatial_scale=" + str(self.spatial_scale) + tmpstr += ", sampling_ratio=" + str(self.sampling_ratio) + tmpstr += ", aligned=" + str(self.aligned) + tmpstr += ")" + return tmpstr -- Gitee