diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/LICENSE b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..cd31b288c87e9274a266266b27e453e493ca2244 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/LICENSE @@ -0,0 +1,29 @@ +BSD 3-Clause License + +Copyright (c) 2019, Charles Shang +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/README.md b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/README.md new file mode 100644 index 0000000000000000000000000000000000000000..0ddcf182a2ad9b7305c0e66970fb34161331f866 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/README.md @@ -0,0 +1,60 @@ +## Deformable Convolutional Networks V2 with Pytorch + +### Build +```bash + ./make.sh # build + python test.py # run examples and gradient check +``` + +### An Example +- deformable conv +```python + from dcn_v2 import DCN + input = torch.randn(2, 64, 128, 128).cuda() + # wrap all things (offset and mask) in DCN + dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda() + output = dcn(input) + print(output.shape) +``` +- deformable roi pooling +```python + from dcn_v2 import DCNPooling + input = torch.randn(2, 32, 64, 64).cuda() + batch_inds = torch.randint(2, (20, 1)).cuda().float() + x = torch.randint(256, (20, 1)).cuda().float() + y = torch.randint(256, (20, 1)).cuda().float() + w = torch.randint(64, (20, 1)).cuda().float() + h = torch.randint(64, (20, 1)).cuda().float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + + # mdformable pooling (V2) + # wrap all things (offset and mask) in DCNPooling + dpooling = DCNPooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1).cuda() + + dout = dpooling(input, rois) +``` + +### Known Issues: + +- [x] Gradient check w.r.t offset (solved) +- [ ] Backward is not reentrant (minor) + +This is an adaption of the official [Deformable-ConvNets](https://github.com/msracver/Deformable-ConvNets/tree/master/DCNv2_op). + +I have ran the gradient check for many times with DOUBLE type. Every tensor **except offset** passes. +However, when I set the offset to 0.5, it passes. I'm still wondering what cause this problem. Is it because some +non-differential points? + +Update: all gradient check passes with double precision. + +Another issue is that it raises `RuntimeError: Backward is not reentrant`. However, the error is very small (`<1e-7` for +float `<1e-15` for double), +so it may not be a serious problem (?) + +Please post an issue or PR if you have any comments. + \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/__init__.py b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/build.py b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/build.py new file mode 100644 index 0000000000000000000000000000000000000000..b93f2a90811c5f41e396836a8ddcd10d911042a6 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/build.py @@ -0,0 +1,43 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/dcn_v2.c'] +headers = ['src/dcn_v2.h'] +defines = [] +with_cuda = False + +extra_objects = [] +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/dcn_v2_cuda.c'] + headers += ['src/dcn_v2_cuda.h'] + defines += [('WITH_CUDA', None)] + extra_objects += ['src/cuda/dcn_v2_im2col_cuda.cu.o'] + extra_objects += ['src/cuda/dcn_v2_psroi_pooling_cuda.cu.o'] + with_cuda = True +else: + raise ValueError('CUDA is not available') + +extra_compile_args = ['-fopenmp', '-std=c99'] + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +sources = [os.path.join(this_file, fname) for fname in sources] +headers = [os.path.join(this_file, fname) for fname in headers] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.dcn_v2', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=extra_compile_args +) + +if __name__ == '__main__': + ffi.build() diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/build_double.py b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/build_double.py new file mode 100644 index 0000000000000000000000000000000000000000..02f3912820e5d37d75d9c40fa765457d8303bdf9 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/build_double.py @@ -0,0 +1,43 @@ +import os +import torch +from torch.utils.ffi import create_extension + + +sources = ['src/dcn_v2_double.c'] +headers = ['src/dcn_v2_double.h'] +defines = [] +with_cuda = False + +extra_objects = [] +if torch.cuda.is_available(): + print('Including CUDA code.') + sources += ['src/dcn_v2_cuda_double.c'] + headers += ['src/dcn_v2_cuda_double.h'] + defines += [('WITH_CUDA', None)] + extra_objects += ['src/cuda/dcn_v2_im2col_cuda_double.cu.o'] + extra_objects += ['src/cuda/dcn_v2_psroi_pooling_cuda_double.cu.o'] + with_cuda = True +else: + raise ValueError('CUDA is not available') + +extra_compile_args = ['-fopenmp', '-std=c99'] + +this_file = os.path.dirname(os.path.realpath(__file__)) +print(this_file) +sources = [os.path.join(this_file, fname) for fname in sources] +headers = [os.path.join(this_file, fname) for fname in headers] +extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] + +ffi = create_extension( + '_ext.dcn_v2_double', + headers=headers, + sources=sources, + define_macros=defines, + relative_to=__file__, + with_cuda=with_cuda, + extra_objects=extra_objects, + extra_compile_args=extra_compile_args +) + +if __name__ == '__main__': + ffi.build() diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/dcn_v2.py b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/dcn_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..e1bb7001cafd91315a5a95951738e02013b5a2b4 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/dcn_v2.py @@ -0,0 +1,171 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +import math +from torch import nn +from torch.nn.modules.utils import _pair + +from .dcn_v2_func import DCNv2Function +from .dcn_v2_func import DCNv2PoolingFunction + +class DCNv2(nn.Module): + + def __init__(self, in_channels, out_channels, + kernel_size, stride, padding, dilation=1, deformable_groups=1): + super(DCNv2, self).__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.deformable_groups = deformable_groups + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size)) + self.bias = nn.Parameter(torch.Tensor(out_channels)) + self.reset_parameters() + + def reset_parameters(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. / math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + self.bias.data.zero_() + + def forward(self, input, offset, mask): + func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups) + return func(input, offset, mask, self.weight, self.bias) + + +class DCN(DCNv2): + + def __init__(self, in_channels, out_channels, + kernel_size, stride, padding, + dilation=1, deformable_groups=1): + super(DCN, self).__init__(in_channels, out_channels, + kernel_size, stride, padding, dilation, deformable_groups) + + self.conv_offset_mask = nn.Conv2d(self.in_channels, + self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1], + kernel_size=self.kernel_size, + stride=(self.stride, self.stride), + padding=(self.padding, self.padding), + bias=True) + self.init_offset() + + def init_offset(self): + self.conv_offset_mask.weight.data.zero_() + self.conv_offset_mask.bias.data.zero_() + + def forward(self, input): + out = self.conv_offset_mask(input) + o1, o2, mask = torch.chunk(out, 3, dim=1) + offset = torch.cat((o1, o2), dim=1) + mask = torch.sigmoid(mask) + func = DCNv2Function(self.stride, self.padding, self.dilation, self.deformable_groups) + return func(input, offset, mask, self.weight, self.bias) + + +class DCNv2Pooling(nn.Module): + + def __init__(self, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0): + super(DCNv2Pooling, self).__init__() + self.spatial_scale = spatial_scale + self.pooled_size = pooled_size + self.output_dim = output_dim + self.no_trans = no_trans + self.group_size = group_size + self.part_size = pooled_size if part_size is None else part_size + self.sample_per_part = sample_per_part + self.trans_std = trans_std + self.func = DCNv2PoolingFunction(self.spatial_scale, + self.pooled_size, + self.output_dim, + self.no_trans, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) + + def forward(self, data, rois, offset): + + if self.no_trans: + offset = data.new() + return self.func(data, rois, offset) + +class DCNPooling(DCNv2Pooling): + + def __init__(self, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0, + deform_fc_dim=1024): + super(DCNPooling, self).__init__(spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size, + part_size, + sample_per_part, + trans_std) + + self.deform_fc_dim = deform_fc_dim + + if not no_trans: + self.func_offset = DCNv2PoolingFunction(self.spatial_scale, + self.pooled_size, + self.output_dim, + True, + self.group_size, + self.part_size, + self.sample_per_part, + self.trans_std) + self.offset_fc = nn.Sequential( + nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_dim, self.deform_fc_dim), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 2) + ) + self.offset_fc[4].weight.data.zero_() + self.offset_fc[4].bias.data.zero_() + self.mask_fc = nn.Sequential( + nn.Linear(self.pooled_size * self.pooled_size * self.output_dim, self.deform_fc_dim), + nn.ReLU(inplace=True), + nn.Linear(self.deform_fc_dim, self.pooled_size * self.pooled_size * 1), + nn.Sigmoid() + ) + self.mask_fc[2].weight.data.zero_() + self.mask_fc[2].bias.data.zero_() + + def forward(self, data, rois): + if self.no_trans: + offset = data.new() + else: + n = rois.shape[0] + offset = data.new() + x = self.func_offset(data, rois, offset) + offset = self.offset_fc(x.view(n, -1)) + offset = offset.view(n, 2, self.pooled_size, self.pooled_size) + mask = self.mask_fc(x.view(n, -1)) + mask = mask.view(n, 1, self.pooled_size, self.pooled_size) + feat = self.func(data, rois, offset) * mask + return feat + return self.func(data, rois, offset) diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/dcn_v2_func.py b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/dcn_v2_func.py new file mode 100644 index 0000000000000000000000000000000000000000..7e98f49f47bd4c87d12d313fa855712d671b96a3 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/dcn_v2_func.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import torch +from torch.autograd import Function + +from ._ext import dcn_v2 as _backend +# from _ext import dcn_v2_double as _backend + + +class DCNv2Function(Function): + + def __init__(self, stride, padding, dilation=1, deformable_groups=1): + super(DCNv2Function, self).__init__() + self.stride = stride + self.padding = padding + self.dilation = dilation + self.deformable_groups = deformable_groups + + def forward(self, input, offset, mask, weight, bias): + if not input.is_cuda: + raise NotImplementedError + if weight.requires_grad or mask.requires_grad or offset.requires_grad or input.requires_grad: + self.save_for_backward(input, offset, mask, weight, bias) + output = input.new(*self._infer_shape(input, weight)) + self._bufs = [input.new(), input.new()] + _backend.dcn_v2_cuda_forward(input, weight, + bias, self._bufs[0], + offset, mask, + output, self._bufs[1], + weight.shape[2], weight.shape[3], + self.stride, self.stride, + self.padding, self.padding, + self.dilation, self.dilation, + self.deformable_groups) + return output + + def backward(self, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + input, offset, mask, weight, bias = self.saved_tensors + grad_input = input.new(*input.size()).zero_() + grad_offset = offset.new(*offset.size()).zero_() + grad_mask = mask.new(*mask.size()).zero_() + grad_weight = weight.new(*weight.size()).zero_() + grad_bias = bias.new(*bias.size()).zero_() + _backend.dcn_v2_cuda_backward(input, weight, + bias, self._bufs[0], + offset, mask, + self._bufs[1], + grad_input, grad_weight, + grad_bias, grad_offset, + grad_mask, grad_output, + weight.shape[2], weight.shape[3], + self.stride, self.stride, + self.padding, self.padding, + self.dilation, self.dilation, + self.deformable_groups) + + return grad_input, grad_offset, grad_mask, grad_weight, grad_bias + + def _infer_shape(self, input, weight): + n = input.size(0) + channels_out = weight.size(0) + height, width = input.shape[2:4] + kernel_h, kernel_w = weight.shape[2:4] + height_out = (height + 2 * self.padding - + (self.dilation * (kernel_h - 1) + 1)) // self.stride + 1 + width_out = (width + 2 * self.padding - (self.dilation * + (kernel_w - 1) + 1)) // self.stride + 1 + return (n, channels_out, height_out, width_out) + + +class DCNv2PoolingFunction(Function): + + def __init__(self, + spatial_scale, + pooled_size, + output_dim, + no_trans, + group_size=1, + part_size=None, + sample_per_part=4, + trans_std=.0): + super(DCNv2PoolingFunction, self).__init__() + self.spatial_scale = spatial_scale + self.pooled_size = pooled_size + self.output_dim = output_dim + self.no_trans = no_trans + self.group_size = group_size + self.part_size = pooled_size if part_size is None else part_size + self.sample_per_part = sample_per_part + self.trans_std = trans_std + + assert self.trans_std >= 0.0 and self.trans_std <= 1.0 + + def forward(self, data, rois, offset): + if not data.is_cuda: + raise NotImplementedError + + output = data.new(*self._infer_shape(data, rois)) + output_count = data.new(*self._infer_shape(data, rois)) + _backend.dcn_v2_psroi_pooling_cuda_forward(data, rois, offset, + output, output_count, + self.no_trans, self.spatial_scale, + self.output_dim, self.group_size, + self.pooled_size, self.part_size, + self.sample_per_part, self.trans_std) + + if data.requires_grad or rois.requires_grad or offset.requires_grad: + self.save_for_backward(data, rois, offset, output_count) + + return output + + def backward(self, grad_output): + if not grad_output.is_cuda: + raise NotImplementedError + + data, rois, offset, output_count = self.saved_tensors + grad_input = data.new(*data.size()).zero_() + grad_offset = offset.new(*offset.size()).zero_() + + _backend.dcn_v2_psroi_pooling_cuda_backward(grad_output, + data, + rois, + offset, + output_count, + grad_input, + grad_offset, + self.no_trans, + self.spatial_scale, + self.output_dim, + self.group_size, + self.pooled_size, + self.part_size, + self.sample_per_part, + self.trans_std) + return grad_input, None, grad_offset + + def _infer_shape(self, data, rois): + # _, c, h, w = data.shape[:4] + c = data.shape[1] + n = rois.shape[0] + return (n, self.output_dim, self.pooled_size, self.pooled_size) diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/make.sh b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/make.sh new file mode 100755 index 0000000000000000000000000000000000000000..d489f7ca221ba1d21227e7156bbafddd3680236a --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/make.sh @@ -0,0 +1,14 @@ +#!/usr/bin/env bash +cd src/cuda + +# compile dcn +nvcc -c -o dcn_v2_im2col_cuda.cu.o dcn_v2_im2col_cuda.cu -x cu -Xcompiler -fPIC +nvcc -c -o dcn_v2_im2col_cuda_double.cu.o dcn_v2_im2col_cuda_double.cu -x cu -Xcompiler -fPIC + +# compile dcn-roi-pooling +nvcc -c -o dcn_v2_psroi_pooling_cuda.cu.o dcn_v2_psroi_pooling_cuda.cu -x cu -Xcompiler -fPIC +nvcc -c -o dcn_v2_psroi_pooling_cuda_double.cu.o dcn_v2_psroi_pooling_cuda_double.cu -x cu -Xcompiler -fPIC + +cd - +python build.py +python build_double.py diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..ab22b1bd4496fb9444d8599238860f0ac31f3a1b --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda.cu @@ -0,0 +1,387 @@ +#include "dcn_v2_im2col_cuda.h" +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + + +__device__ float dmcn_im2col_bilinear(const float *bottom_data, const int data_width, + const int height, const int width, float h, float w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + float lh = h - h_low; + float lw = w - w_low; + float hh = 1 - lh, hw = 1 - lw; + + float v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + float v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + float v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + float v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + float w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + float val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +__device__ float dmcn_get_gradient_weight(float argmax_h, float argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + float weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +__device__ float dmcn_get_coordinate_weight(float argmax_h, float argmax_w, + const int height, const int width, const float *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + float weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const float *data_im, const float *data_offset, const float *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + float *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + float *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const float* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const float *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const float *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const float *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float val = static_cast(0); + const float h_im = h_in + i * dilation_h + offset_h; + const float w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const float map_h = i * dilation_h + offset_h; + //const float map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const float *data_col, const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + float *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + const float cur_inv_h_data = h_in + i * dilation_h + offset_h; + const float cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const float cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + float weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const float *data_col, const float *data_im, + const float *data_offset, const float *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + float *grad_offset, float *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + float val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const float *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const float *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const float *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const float *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const float offset_h = data_offset_ptr[data_offset_h_ptr]; + const float offset_w = data_offset_ptr[data_offset_w_ptr]; + const float mask = data_mask_ptr[data_mask_hw_ptr]; + float inv_h = h_in + i * dilation_h + offset_h; + float inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const float weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda(cudaStream_t stream, + const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* data_col) { + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + modulated_deformable_im2col_gpu_kernel + <<>>( + num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +void modulated_deformable_col2im_cuda(cudaStream_t stream, + const float* data_col, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float* grad_im){ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + modulated_deformable_col2im_gpu_kernel + <<>>( + num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } + +} + +void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const float* data_col, const float* data_im, const float* data_offset, const float* data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + float* grad_offset, float* grad_mask) { + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + modulated_deformable_col2im_coord_gpu_kernel + <<>>( + num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset, grad_mask); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda.h b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..3457e961f983b4d9312507260ec4bb89e0a40ea6 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda.h @@ -0,0 +1,100 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.h + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1811.11168 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu + */ + +/***************** Adapted by Charles Shang *********************/ + +#ifndef DCN_V2_IM2COL_CUDA +#define DCN_V2_IM2COL_CUDA + +#ifdef __cplusplus +extern "C" +{ +#endif + + void modulated_deformable_im2col_cuda(cudaStream_t stream, + const float *data_im, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float *data_col); + + void modulated_deformable_col2im_cuda(cudaStream_t stream, + const float *data_col, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, float *grad_im); + + void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const float *data_col, const float *data_im, const float *data_offset, const float *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + float *grad_offset, float *grad_mask); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda_double.cu b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda_double.cu new file mode 100644 index 0000000000000000000000000000000000000000..29cb048da68136263b61a3f2b3946a5523cb6c65 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda_double.cu @@ -0,0 +1,399 @@ +#include "dcn_v2_im2col_cuda_double.h" +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 512; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 +#else +__device__ double atomicAdd(double* address, double val) +{ + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); +} +#endif + +__device__ double dmcn_im2col_bilinear(const double *bottom_data, const int data_width, + const int height, const int width, double h, double w) +{ + int h_low = floor(h); + int w_low = floor(w); + int h_high = h_low + 1; + int w_high = w_low + 1; + + double lh = h - h_low; + double lw = w - w_low; + double hh = 1 - lh, hw = 1 - lw; + + double v1 = 0; + if (h_low >= 0 && w_low >= 0) + v1 = bottom_data[h_low * data_width + w_low]; + double v2 = 0; + if (h_low >= 0 && w_high <= width - 1) + v2 = bottom_data[h_low * data_width + w_high]; + double v3 = 0; + if (h_high <= height - 1 && w_low >= 0) + v3 = bottom_data[h_high * data_width + w_low]; + double v4 = 0; + if (h_high <= height - 1 && w_high <= width - 1) + v4 = bottom_data[h_high * data_width + w_high]; + + double w1 = hh * hw, w2 = hh * lw, w3 = lh * hw, w4 = lh * lw; + + double val = (w1 * v1 + w2 * v2 + w3 * v3 + w4 * v4); + return val; +} + +__device__ double dmcn_get_gradient_weight(double argmax_h, double argmax_w, + const int h, const int w, const int height, const int width) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + double weight = 0; + if (h == argmax_h_low && w == argmax_w_low) + weight = (h + 1 - argmax_h) * (w + 1 - argmax_w); + if (h == argmax_h_low && w == argmax_w_high) + weight = (h + 1 - argmax_h) * (argmax_w + 1 - w); + if (h == argmax_h_high && w == argmax_w_low) + weight = (argmax_h + 1 - h) * (w + 1 - argmax_w); + if (h == argmax_h_high && w == argmax_w_high) + weight = (argmax_h + 1 - h) * (argmax_w + 1 - w); + return weight; +} + +__device__ double dmcn_get_coordinate_weight(double argmax_h, double argmax_w, + const int height, const int width, const double *im_data, + const int data_width, const int bp_dir) +{ + if (argmax_h <= -1 || argmax_h >= height || argmax_w <= -1 || argmax_w >= width) + { + //empty + return 0; + } + + int argmax_h_low = floor(argmax_h); + int argmax_w_low = floor(argmax_w); + int argmax_h_high = argmax_h_low + 1; + int argmax_w_high = argmax_w_low + 1; + + double weight = 0; + + if (bp_dir == 0) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += -1 * (argmax_w - argmax_w_low) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += (argmax_w_low + 1 - argmax_w) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_w - argmax_w_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + else if (bp_dir == 1) + { + if (argmax_h_low >= 0 && argmax_w_low >= 0) + weight += -1 * (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_low]; + if (argmax_h_low >= 0 && argmax_w_high <= width - 1) + weight += (argmax_h_low + 1 - argmax_h) * im_data[argmax_h_low * data_width + argmax_w_high]; + if (argmax_h_high <= height - 1 && argmax_w_low >= 0) + weight += -1 * (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_low]; + if (argmax_h_high <= height - 1 && argmax_w_high <= width - 1) + weight += (argmax_h - argmax_h_low) * im_data[argmax_h_high * data_width + argmax_w_high]; + } + + return weight; +} + +__global__ void modulated_deformable_im2col_gpu_kernel(const int n, + const double *data_im, const double *data_offset, const double *data_mask, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int num_channels, const int deformable_group, + const int height_col, const int width_col, + double *data_col) +{ + CUDA_KERNEL_LOOP(index, n) + { + // index index of output matrix + const int w_col = index % width_col; + const int h_col = (index / width_col) % height_col; + const int b_col = (index / width_col / height_col) % batch_size; + const int c_im = (index / width_col / height_col) / batch_size; + const int c_col = c_im * kernel_h * kernel_w; + + // compute deformable group index + const int deformable_group_index = c_im / channel_per_deformable_group; + + const int h_in = h_col * stride_h - pad_h; + const int w_in = w_col * stride_w - pad_w; + + double *data_col_ptr = data_col + ((c_col * batch_size + b_col) * height_col + h_col) * width_col + w_col; + //const double* data_im_ptr = data_im + ((b_col * num_channels + c_im) * height + h_in) * width + w_in; + const double *data_im_ptr = data_im + (b_col * num_channels + c_im) * height * width; + const double *data_offset_ptr = data_offset + (b_col * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + + const double *data_mask_ptr = data_mask + (b_col * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + for (int i = 0; i < kernel_h; ++i) + { + for (int j = 0; j < kernel_w; ++j) + { + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_col) * width_col + w_col; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_col) * width_col + w_col; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_col) * width_col + w_col; + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + double val = static_cast(0); + const double h_im = h_in + i * dilation_h + offset_h; + const double w_im = w_in + j * dilation_w + offset_w; + //if (h_im >= 0 && w_im >= 0 && h_im < height && w_im < width) { + if (h_im > -1 && w_im > -1 && h_im < height && w_im < width) + { + //const double map_h = i * dilation_h + offset_h; + //const double map_w = j * dilation_w + offset_w; + //const int cur_height = height - h_in; + //const int cur_width = width - w_in; + //val = dmcn_im2col_bilinear(data_im_ptr, width, cur_height, cur_width, map_h, map_w); + val = dmcn_im2col_bilinear(data_im_ptr, width, height, width, h_im, w_im); + } + *data_col_ptr = val * mask; + data_col_ptr += batch_size * height_col * width_col; + //data_col_ptr += height_col * width_col; + } + } + } +} + +__global__ void modulated_deformable_col2im_gpu_kernel(const int n, + const double *data_col, const double *data_offset, const double *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int deformable_group, + const int height_col, const int width_col, + double *grad_im) +{ + CUDA_KERNEL_LOOP(index, n) + { + const int j = (index / width_col / height_col / batch_size) % kernel_w; + const int i = (index / width_col / height_col / batch_size / kernel_w) % kernel_h; + const int c = index / width_col / height_col / batch_size / kernel_w / kernel_h; + // compute the start and end of the output + + const int deformable_group_index = c / channel_per_deformable_group; + + int w_out = index % width_col; + int h_out = (index / width_col) % height_col; + int b = (index / width_col / height_col) % batch_size; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + + const double *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const double *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + const int data_offset_h_ptr = ((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out; + const int data_offset_w_ptr = ((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out; + const int data_mask_hw_ptr = ((i * kernel_w + j) * height_col + h_out) * width_col + w_out; + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + const double cur_inv_h_data = h_in + i * dilation_h + offset_h; + const double cur_inv_w_data = w_in + j * dilation_w + offset_w; + + const double cur_top_grad = data_col[index] * mask; + const int cur_h = (int)cur_inv_h_data; + const int cur_w = (int)cur_inv_w_data; + for (int dy = -2; dy <= 2; dy++) + { + for (int dx = -2; dx <= 2; dx++) + { + if (cur_h + dy >= 0 && cur_h + dy < height && + cur_w + dx >= 0 && cur_w + dx < width && + abs(cur_inv_h_data - (cur_h + dy)) < 1 && + abs(cur_inv_w_data - (cur_w + dx)) < 1) + { + int cur_bottom_grad_pos = ((b * channels + c) * height + cur_h + dy) * width + cur_w + dx; + double weight = dmcn_get_gradient_weight(cur_inv_h_data, cur_inv_w_data, cur_h + dy, cur_w + dx, height, width); + atomicAdd(grad_im + cur_bottom_grad_pos, weight * cur_top_grad); + } + } + } + } +} + +__global__ void modulated_deformable_col2im_coord_gpu_kernel(const int n, + const double *data_col, const double *data_im, + const double *data_offset, const double *data_mask, + const int channels, const int height, const int width, + const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int channel_per_deformable_group, + const int batch_size, const int offset_channels, const int deformable_group, + const int height_col, const int width_col, + double *grad_offset, double *grad_mask) +{ + CUDA_KERNEL_LOOP(index, n) + { + double val = 0, mval = 0; + int w = index % width_col; + int h = (index / width_col) % height_col; + int c = (index / width_col / height_col) % offset_channels; + int b = (index / width_col / height_col) / offset_channels; + // compute the start and end of the output + + const int deformable_group_index = c / (2 * kernel_h * kernel_w); + const int col_step = kernel_h * kernel_w; + int cnt = 0; + const double *data_col_ptr = data_col + deformable_group_index * channel_per_deformable_group * batch_size * width_col * height_col; + const double *data_im_ptr = data_im + (b * deformable_group + deformable_group_index) * channel_per_deformable_group / kernel_h / kernel_w * height * width; + const double *data_offset_ptr = data_offset + (b * deformable_group + deformable_group_index) * 2 * kernel_h * kernel_w * height_col * width_col; + const double *data_mask_ptr = data_mask + (b * deformable_group + deformable_group_index) * kernel_h * kernel_w * height_col * width_col; + + const int offset_c = c - deformable_group_index * 2 * kernel_h * kernel_w; + + for (int col_c = (offset_c / 2); col_c < channel_per_deformable_group; col_c += col_step) + { + const int col_pos = (((col_c * batch_size + b) * height_col) + h) * width_col + w; + const int bp_dir = offset_c % 2; + + int j = (col_pos / width_col / height_col / batch_size) % kernel_w; + int i = (col_pos / width_col / height_col / batch_size / kernel_w) % kernel_h; + int w_out = col_pos % width_col; + int h_out = (col_pos / width_col) % height_col; + int w_in = w_out * stride_w - pad_w; + int h_in = h_out * stride_h - pad_h; + const int data_offset_h_ptr = (((2 * (i * kernel_w + j)) * height_col + h_out) * width_col + w_out); + const int data_offset_w_ptr = (((2 * (i * kernel_w + j) + 1) * height_col + h_out) * width_col + w_out); + const int data_mask_hw_ptr = (((i * kernel_w + j) * height_col + h_out) * width_col + w_out); + const double offset_h = data_offset_ptr[data_offset_h_ptr]; + const double offset_w = data_offset_ptr[data_offset_w_ptr]; + const double mask = data_mask_ptr[data_mask_hw_ptr]; + double inv_h = h_in + i * dilation_h + offset_h; + double inv_w = w_in + j * dilation_w + offset_w; + if (inv_h <= -1 || inv_w <= -1 || inv_h >= height || inv_w >= width) + { + inv_h = inv_w = -2; + } + else + { + mval += data_col_ptr[col_pos] * dmcn_im2col_bilinear(data_im_ptr + cnt * height * width, width, height, width, inv_h, inv_w); + } + const double weight = dmcn_get_coordinate_weight( + inv_h, inv_w, + height, width, data_im_ptr + cnt * height * width, width, bp_dir); + val += weight * data_col_ptr[col_pos] * mask; + cnt += 1; + } + // KERNEL_ASSIGN(grad_offset[index], offset_req, val); + grad_offset[index] = val; + if (offset_c % 2 == 0) + // KERNEL_ASSIGN(grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w], mask_req, mval); + grad_mask[(((b * deformable_group + deformable_group_index) * kernel_h * kernel_w + offset_c / 2) * height_col + h) * width_col + w] = mval; + } +} + +void modulated_deformable_im2col_cuda(cudaStream_t stream, + const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, double *data_col) +{ + // num_axes should be smaller than block size + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * batch_size * height_col * width_col; + modulated_deformable_im2col_gpu_kernel<<>>( + num_kernels, data_im, data_offset, data_mask, height_im, width_im, kernel_h, kenerl_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, channel_per_deformable_group, + batch_size, channels, deformable_group, height_col, width_col, data_col); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_im2col_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_cuda(cudaStream_t stream, + const double *data_col, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, double *grad_im) +{ + + const int channel_per_deformable_group = channels / deformable_group; + const int num_kernels = channels * kernel_h * kernel_w * batch_size * height_col * width_col; + modulated_deformable_col2im_gpu_kernel<<>>( + num_kernels, data_col, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_h, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, deformable_group, height_col, width_col, grad_im); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_cuda: %s\n", cudaGetErrorString(err)); + } +} + +void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const double *data_col, const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kernel_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + double *grad_offset, double *grad_mask) +{ + const int num_kernels = batch_size * height_col * width_col * 2 * kernel_h * kernel_w * deformable_group; + const int channel_per_deformable_group = channels * kernel_h * kernel_w / deformable_group; + modulated_deformable_col2im_coord_gpu_kernel<<>>( + num_kernels, data_col, data_im, data_offset, data_mask, channels, height_im, width_im, + kernel_h, kernel_w, pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, channel_per_deformable_group, + batch_size, 2 * kernel_h * kernel_w * deformable_group, deformable_group, height_col, width_col, + grad_offset, grad_mask); + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in modulated_deformable_col2im_coord_cuda: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda_double.h b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda_double.h new file mode 100644 index 0000000000000000000000000000000000000000..a46169235d01594663c94279a4ced02f33d1aa5f --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_im2col_cuda_double.h @@ -0,0 +1,100 @@ +/*! + ******************* BEGIN Caffe Copyright Notice and Disclaimer **************** + * + * COPYRIGHT + * + * All contributions by the University of California: + * Copyright (c) 2014-2017 The Regents of the University of California (Regents) + * All rights reserved. + * + * All other contributions: + * Copyright (c) 2014-2017, the respective contributors + * All rights reserved. + * + * Caffe uses a shared copyright model: each contributor holds copyright over + * their contributions to Caffe. The project versioning records all such + * contribution and copyright details. If a contributor wants to further mark + * their specific copyright on a particular contribution, they should indicate + * their copyright solely in the commit message of the change when it is + * committed. + * + * LICENSE + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND + * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED + * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES + * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; + * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND + * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT + * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS + * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + * CONTRIBUTION AGREEMENT + * + * By contributing to the BVLC/caffe repository through pull-request, comment, + * or otherwise, the contributor releases their content to the + * license and copyright terms herein. + * + ***************** END Caffe Copyright Notice and Disclaimer ******************** + * + * Copyright (c) 2018 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file modulated_deformable_im2col.h + * \brief Function definitions of converting an image to + * column matrix based on kernel, padding, dilation, and offset. + * These functions are mainly used in deformable convolution operators. + * \ref: https://arxiv.org/abs/1811.11168 + * \author Yuwen Xiong, Haozhi Qi, Jifeng Dai, Xizhou Zhu, Han Hu + */ + +/***************** Adapted by Charles Shang *********************/ + +#ifndef DCN_V2_IM2COL_CUDA_DOUBLE +#define DCN_V2_IM2COL_CUDA_DOUBLE + +#ifdef __cplusplus +extern "C" +{ +#endif + + void modulated_deformable_im2col_cuda(cudaStream_t stream, + const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, double *data_col); + + void modulated_deformable_col2im_cuda(cudaStream_t stream, + const double *data_col, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, double *grad_im); + + void modulated_deformable_col2im_coord_cuda(cudaStream_t stream, + const double *data_col, const double *data_im, const double *data_offset, const double *data_mask, + const int batch_size, const int channels, const int height_im, const int width_im, + const int height_col, const int width_col, const int kernel_h, const int kenerl_w, + const int pad_h, const int pad_w, const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group, + double *grad_offset, double *grad_mask); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..295657c05fd1713acfa56d337b311892fdf6fd3c --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.cu @@ -0,0 +1,353 @@ +/*! + * Copyright (c) 2017 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file deformable_psroi_pooling.cu + * \brief + * \author Yi Li, Guodong Zhang, Jifeng Dai +*/ +/***************** Adapted by Charles Shang *********************/ +#include "dcn_v2_psroi_pooling_cuda.h" +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +__device__ float bilinear_interp( + const float *data, + const float x, + const float y, + const int width, + const int height) +{ + int x1 = floor(x); + int x2 = ceil(x); + int y1 = floor(y); + int y2 = ceil(y); + float dist_x = (float)(x - x1); + float dist_y = (float)(y - y1); + float value11 = data[y1 * width + x1]; + float value12 = data[y2 * width + x1]; + float value21 = data[y1 * width + x2]; + float value22 = data[y2 * width + x2]; + float value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; + return value; +} + +__global__ void DeformablePSROIPoolForwardKernel( + const int count, + const float *bottom_data, + const float spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const float *bottom_rois, const float *bottom_trans, + const int no_trans, + const float trans_std, + const int sample_per_part, + const int output_dim, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class, + float *top_data, + float *top_count) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const float *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + float roi_start_w = (float)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + float roi_start_h = (float)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + float roi_end_w = (float)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + float roi_end_h = (float)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + float roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + float roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + float bin_size_h = roi_height / (float)(pooled_height); + float bin_size_w = roi_width / (float)(pooled_width); + + float sub_bin_size_h = bin_size_h / (float)(sample_per_part); + float sub_bin_size_w = bin_size_w / (float)(sample_per_part); + + int part_h = floor((float)(ph) / pooled_height * part_size); + int part_w = floor((float)(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + float trans_x = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; + float trans_y = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; + + float wstart = (float)(pw)*bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + float hstart = (float)(ph)*bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + float sum = 0; + int count = 0; + int gw = floor((float)(pw)*group_size / pooled_width); + int gh = floor((float)(ph)*group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + const float *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + float w = wstart + iw * sub_bin_size_w; + float h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + float val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); + sum += val; + count++; + } + } + top_data[index] = count == 0 ? (float)(0) : sum / count; + top_count[index] = count; + } +} + +__global__ void DeformablePSROIPoolBackwardAccKernel( + const int count, + const float *top_diff, + const float *top_count, + const int num_rois, + const float spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int output_dim, + float *bottom_data_diff, float *bottom_trans_diff, + const float *bottom_data, + const float *bottom_rois, + const float *bottom_trans, + const int no_trans, + const float trans_std, + const int sample_per_part, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const float *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + float roi_start_w = (float)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + float roi_start_h = (float)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + float roi_end_w = (float)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + float roi_end_h = (float)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + float roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + float roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + float bin_size_h = roi_height / (float)(pooled_height); + float bin_size_w = roi_width / (float)(pooled_width); + + float sub_bin_size_h = bin_size_h / (float)(sample_per_part); + float sub_bin_size_w = bin_size_w / (float)(sample_per_part); + + int part_h = floor((float)(ph) / pooled_height * part_size); + int part_w = floor((float)(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + float trans_x = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; + float trans_y = no_trans ? (float)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; + + float wstart = (float)(pw)*bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + float hstart = (float)(ph)*bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + if (top_count[index] <= 0) + { + continue; + } + float diff_val = top_diff[index] / top_count[index]; + const float *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; + float *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; + int gw = floor((float)(pw)*group_size / pooled_width); + int gh = floor((float)(ph)*group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + float w = wstart + iw * sub_bin_size_w; + float h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + // backward on feature + int x0 = floor(w); + int x1 = ceil(w); + int y0 = floor(h); + int y1 = ceil(h); + float dist_x = w - x0, dist_y = h - y0; + float q00 = (1 - dist_x) * (1 - dist_y); + float q01 = (1 - dist_x) * dist_y; + float q10 = dist_x * (1 - dist_y); + float q11 = dist_x * dist_y; + int bottom_index_base = c * height * width; + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); + + if (no_trans) + { + continue; + } + float U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + float U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + float U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + float U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + float diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; + diff_x *= roi_width; + float diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; + diff_y *= roi_height; + + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); + } + } + } +} + +void DeformablePSROIPoolForward(cudaStream_t stream, + const float *data, + const float *bbox, + const float *trans, + float *out, + float *top_count, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + + const float *bottom_data = data; + const float *bottom_rois = bbox; + const float *bottom_trans = no_trans ? NULL : trans; + float *top_data = out; + float *top_count_data = top_count; + + const int pooled_height = pooled_size; + const int pooled_width = pooled_size; + const int count = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + DeformablePSROIPoolForwardKernel<<>>( + count, bottom_data, spatial_scale, channels, height, width, pooled_height, pooled_width, + bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, output_dim, + group_size, part_size, num_classes, channels_each_class, top_data, top_count_data); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); + } +} + +void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, + const float *out_grad, + const float *data, + const float *bbox, + const float *trans, + const float *top_count, + float *in_grad, + float *trans_grad, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + // LOG(INFO) << "DeformablePSROIPoolBackward"; + const float *top_diff = out_grad; + const float *bottom_data = data; + const float *bottom_rois = bbox; + const float *bottom_trans = no_trans ? NULL : trans; + float *bottom_data_diff = in_grad; + float *bottom_trans_diff = no_trans ? NULL : trans_grad; + const float *top_count_data = top_count; + + const int num_rois = num_bbox; + const int pooled_height = pooled_size; + const int pooled_width = pooled_size; + const int count = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + DeformablePSROIPoolBackwardAccKernel<<>>( + count, top_diff, top_count_data, num_rois, spatial_scale, channels, height, width, + pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff, + bottom_data, bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, + group_size, part_size, num_classes, channels_each_class); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.h b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..5fa2c6c2355b6c7aab63efd93fac0bfe81bc90a5 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda.h @@ -0,0 +1,66 @@ +/*! + * Copyright (c) 2017 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file deformable_psroi_pooling.cu + * \brief + * \author Yi Li, Guodong Zhang, Jifeng Dai +*/ +/***************** Adapted by Charles Shang *********************/ + +#ifndef DCN_V2_PSROI_POOLING_CUDA +#define DCN_V2_PSROI_POOLING_CUDA + +#ifdef __cplusplus +extern "C" +{ +#endif + + void DeformablePSROIPoolForward(cudaStream_t stream, + const float *data, + const float *bbox, + const float *trans, + float *out, + float *top_count, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); + + void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, + const float *out_grad, + const float *data, + const float *bbox, + const float *trans, + const float *top_count, + float *in_grad, + float *trans_grad, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda_double.cu b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda_double.cu new file mode 100644 index 0000000000000000000000000000000000000000..ce05cc96d65a6bb88c3f8c9eddaa49504c5d332f --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda_double.cu @@ -0,0 +1,368 @@ +/*! + * Copyright (c) 2017 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file deformable_psroi_pooling.cu + * \brief + * \author Yi Li, Guodong Zhang, Jifeng Dai +*/ +/***************** Adapted by Charles Shang *********************/ +#include "dcn_v2_psroi_pooling_cuda_double.h" +#include +#include +#include + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +const int CUDA_NUM_THREADS = 1024; +inline int GET_BLOCKS(const int N) +{ + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +#if !defined(__CUDA_ARCH__) || __CUDA_ARCH__ >= 600 +#else +__device__ double atomicAdd(double* address, double val) +{ + unsigned long long int* address_as_ull = (unsigned long long int*)address; + unsigned long long int old = *address_as_ull, assumed; + do { + assumed = old; + old = atomicCAS(address_as_ull, assumed, + __double_as_longlong(val + __longlong_as_double(assumed))); + } while (assumed != old); + return __longlong_as_double(old); +} +#endif + +__device__ double bilinear_interp( + const double *data, + const double x, + const double y, + const int width, + const int height) +{ + int x1 = floor(x); + int x2 = ceil(x); + int y1 = floor(y); + int y2 = ceil(y); + double dist_x = (double)(x - x1); + double dist_y = (double)(y - y1); + double value11 = data[y1 * width + x1]; + double value12 = data[y2 * width + x1]; + double value21 = data[y1 * width + x2]; + double value22 = data[y2 * width + x2]; + double value = (1 - dist_x) * (1 - dist_y) * value11 + (1 - dist_x) * dist_y * value12 + dist_x * (1 - dist_y) * value21 + dist_x * dist_y * value22; + return value; +} + +__global__ void DeformablePSROIPoolForwardKernel( + const int count, + const double *bottom_data, + const double spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const double *bottom_rois, const double *bottom_trans, + const int no_trans, + const double trans_std, + const int sample_per_part, + const int output_dim, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class, + double *top_data, + double *top_count) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const double *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + double roi_start_w = (double)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + double roi_start_h = (double)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + double roi_end_w = (double)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + double roi_end_h = (double)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + double roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + double roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + double bin_size_h = roi_height / (double)(pooled_height); + double bin_size_w = roi_width / (double)(pooled_width); + + double sub_bin_size_h = bin_size_h / (double)(sample_per_part); + double sub_bin_size_w = bin_size_w / (double)(sample_per_part); + + int part_h = floor((double)(ph) / pooled_height * part_size); + int part_w = floor((double)(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + double trans_x = no_trans ? (double)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; + double trans_y = no_trans ? (double)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; + + double wstart = (double)(pw)*bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + double hstart = (double)(ph)*bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + double sum = 0; + int count = 0; + int gw = floor((double)(pw)*group_size / pooled_width); + int gh = floor((double)(ph)*group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + const double *offset_bottom_data = bottom_data + (roi_batch_ind * channels) * height * width; + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + double w = wstart + iw * sub_bin_size_w; + double h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + double val = bilinear_interp(offset_bottom_data + c * height * width, w, h, width, height); + sum += val; + count++; + } + } + top_data[index] = count == 0 ? (double)(0) : sum / count; + top_count[index] = count; + } +} + +__global__ void DeformablePSROIPoolBackwardAccKernel( + const int count, + const double *top_diff, + const double *top_count, + const int num_rois, + const double spatial_scale, + const int channels, + const int height, const int width, + const int pooled_height, const int pooled_width, + const int output_dim, + double *bottom_data_diff, double *bottom_trans_diff, + const double *bottom_data, + const double *bottom_rois, + const double *bottom_trans, + const int no_trans, + const double trans_std, + const int sample_per_part, + const int group_size, + const int part_size, + const int num_classes, + const int channels_each_class) +{ + CUDA_KERNEL_LOOP(index, count) + { + // The output is in order (n, ctop, ph, pw) + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int ctop = (index / pooled_width / pooled_height) % output_dim; + int n = index / pooled_width / pooled_height / output_dim; + + // [start, end) interval for spatial sampling + const double *offset_bottom_rois = bottom_rois + n * 5; + int roi_batch_ind = offset_bottom_rois[0]; + double roi_start_w = (double)(round(offset_bottom_rois[1])) * spatial_scale - 0.5; + double roi_start_h = (double)(round(offset_bottom_rois[2])) * spatial_scale - 0.5; + double roi_end_w = (double)(round(offset_bottom_rois[3]) + 1.) * spatial_scale - 0.5; + double roi_end_h = (double)(round(offset_bottom_rois[4]) + 1.) * spatial_scale - 0.5; + + // Force too small ROIs to be 1x1 + double roi_width = max(roi_end_w - roi_start_w, 0.1); //avoid 0 + double roi_height = max(roi_end_h - roi_start_h, 0.1); + + // Compute w and h at bottom + double bin_size_h = roi_height / (double)(pooled_height); + double bin_size_w = roi_width / (double)(pooled_width); + + double sub_bin_size_h = bin_size_h / (double)(sample_per_part); + double sub_bin_size_w = bin_size_w / (double)(sample_per_part); + + int part_h = floor((double)(ph) / pooled_height * part_size); + int part_w = floor((double)(pw) / pooled_width * part_size); + int class_id = ctop / channels_each_class; + double trans_x = no_trans ? (double)(0) : bottom_trans[(((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w] * trans_std; + double trans_y = no_trans ? (double)(0) : bottom_trans[(((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w] * trans_std; + + double wstart = (double)(pw)*bin_size_w + roi_start_w; + wstart += trans_x * roi_width; + double hstart = (double)(ph)*bin_size_h + roi_start_h; + hstart += trans_y * roi_height; + + if (top_count[index] <= 0) + { + continue; + } + double diff_val = top_diff[index] / top_count[index]; + const double *offset_bottom_data = bottom_data + roi_batch_ind * channels * height * width; + double *offset_bottom_data_diff = bottom_data_diff + roi_batch_ind * channels * height * width; + int gw = floor((double)(pw)*group_size / pooled_width); + int gh = floor((double)(ph)*group_size / pooled_height); + gw = min(max(gw, 0), group_size - 1); + gh = min(max(gh, 0), group_size - 1); + + for (int ih = 0; ih < sample_per_part; ih++) + { + for (int iw = 0; iw < sample_per_part; iw++) + { + double w = wstart + iw * sub_bin_size_w; + double h = hstart + ih * sub_bin_size_h; + // bilinear interpolation + if (w < -0.5 || w > width - 0.5 || h < -0.5 || h > height - 0.5) + { + continue; + } + w = min(max(w, 0.), width - 1.); + h = min(max(h, 0.), height - 1.); + int c = (ctop * group_size + gh) * group_size + gw; + // backward on feature + int x0 = floor(w); + int x1 = ceil(w); + int y0 = floor(h); + int y1 = ceil(h); + double dist_x = w - x0, dist_y = h - y0; + double q00 = (1 - dist_x) * (1 - dist_y); + double q01 = (1 - dist_x) * dist_y; + double q10 = dist_x * (1 - dist_y); + double q11 = dist_x * dist_y; + int bottom_index_base = c * height * width; + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x0, q00 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x0, q01 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y0 * width + x1, q10 * diff_val); + atomicAdd(offset_bottom_data_diff + bottom_index_base + y1 * width + x1, q11 * diff_val); + + if (no_trans) + { + continue; + } + double U00 = offset_bottom_data[bottom_index_base + y0 * width + x0]; + double U01 = offset_bottom_data[bottom_index_base + y1 * width + x0]; + double U10 = offset_bottom_data[bottom_index_base + y0 * width + x1]; + double U11 = offset_bottom_data[bottom_index_base + y1 * width + x1]; + double diff_x = (U11 * dist_y + U10 * (1 - dist_y) - U01 * dist_y - U00 * (1 - dist_y)) * trans_std * diff_val; + diff_x *= roi_width; + double diff_y = (U11 * dist_x + U01 * (1 - dist_x) - U10 * dist_x - U00 * (1 - dist_x)) * trans_std * diff_val; + diff_y *= roi_height; + + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2) * part_size + part_h) * part_size + part_w, diff_x); + atomicAdd(bottom_trans_diff + (((n * num_classes + class_id) * 2 + 1) * part_size + part_h) * part_size + part_w, diff_y); + } + } + } +} + +void DeformablePSROIPoolForward(cudaStream_t stream, + const double *data, + const double *bbox, + const double *trans, + double *out, + double *top_count, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const double spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const double trans_std) +{ + + const double *bottom_data = data; + const double *bottom_rois = bbox; + const double *bottom_trans = no_trans ? NULL : trans; + double *top_data = out; + double *top_count_data = top_count; + + const int pooled_height = pooled_size; + const int pooled_width = pooled_size; + const int count = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + DeformablePSROIPoolForwardKernel<<>>( + count, bottom_data, spatial_scale, channels, height, width, pooled_height, pooled_width, + bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, output_dim, + group_size, part_size, num_classes, channels_each_class, top_data, top_count_data); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); + } +} + +void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, + const double *out_grad, + const double *data, + const double *bbox, + const double *trans, + const double *top_count, + double *in_grad, + double *trans_grad, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const double spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const double trans_std) +{ + // LOG(INFO) << "DeformablePSROIPoolBackward"; + const double *top_diff = out_grad; + const double *bottom_data = data; + const double *bottom_rois = bbox; + const double *bottom_trans = no_trans ? NULL : trans; + double *bottom_data_diff = in_grad; + double *bottom_trans_diff = no_trans ? NULL : trans_grad; + const double *top_count_data = top_count; + + const int num_rois = num_bbox; + const int pooled_height = pooled_size; + const int pooled_width = pooled_size; + const int count = num_bbox * output_dim * pooled_height * pooled_width; + const int num_classes = no_trans ? 1 : channels_trans / 2; + const int channels_each_class = no_trans ? output_dim : output_dim / num_classes; + + DeformablePSROIPoolBackwardAccKernel<<>>( + count, top_diff, top_count_data, num_rois, spatial_scale, channels, height, width, + pooled_height, pooled_width, output_dim, bottom_data_diff, bottom_trans_diff, + bottom_data, bottom_rois, bottom_trans, no_trans, trans_std, sample_per_part, + group_size, part_size, num_classes, channels_each_class); + + cudaError_t err = cudaGetLastError(); + if (err != cudaSuccess) + { + printf("error in DeformablePSROIPoolForward: %s\n", cudaGetErrorString(err)); + } +} \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda_double.h b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda_double.h new file mode 100644 index 0000000000000000000000000000000000000000..8a16f72c7c8990066d15a98c345b85e9451e5d2f --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/cuda/dcn_v2_psroi_pooling_cuda_double.h @@ -0,0 +1,66 @@ +/*! + * Copyright (c) 2017 Microsoft + * Licensed under The MIT License [see LICENSE for details] + * \file deformable_psroi_pooling.cu + * \brief + * \author Yi Li, Guodong Zhang, Jifeng Dai +*/ +/***************** Adapted by Charles Shang *********************/ + +#ifndef DCN_V2_PSROI_POOLING_CUDA_DOUBLE +#define DCN_V2_PSROI_POOLING_CUDA_DOUBLE + +#ifdef __cplusplus +extern "C" +{ +#endif + + void DeformablePSROIPoolForward(cudaStream_t stream, + const double *data, + const double *bbox, + const double *trans, + double *out, + double *top_count, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const double spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const double trans_std); + + void DeformablePSROIPoolBackwardAcc(cudaStream_t stream, + const double *out_grad, + const double *data, + const double *bbox, + const double *trans, + const double *top_count, + double *in_grad, + double *trans_grad, + const int batch, + const int channels, + const int height, + const int width, + const int num_bbox, + const int channels_trans, + const int no_trans, + const double spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const double trans_std); + +#ifdef __cplusplus +} +#endif + +#endif \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2.c b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2.c new file mode 100644 index 0000000000000000000000000000000000000000..b440d3f93af121ac1b74ace75c803d11b5c483ab --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2.c @@ -0,0 +1,30 @@ +#include +#include +#include + +void dcn_v2_forward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + printf("only implemented in GPU"); +} + void dcn_v2_backward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + THFloatTensor *grad_input, THFloatTensor *grad_weight, + THFloatTensor *grad_bias, THFloatTensor *grad_offset, + THFloatTensor *grad_mask, THFloatTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + printf("only implemented in GPU"); +} \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2.h b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2.h new file mode 100644 index 0000000000000000000000000000000000000000..1a97ff0f94ec2186db5656cc2e8da79a8735c57f --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2.h @@ -0,0 +1,20 @@ +void dcn_v2_forward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void dcn_v2_backward(THFloatTensor *input, THFloatTensor *weight, + THFloatTensor *bias, THFloatTensor *ones, + THFloatTensor *offset, THFloatTensor *mask, + THFloatTensor *output, THFloatTensor *columns, + THFloatTensor *grad_input, THFloatTensor *grad_weight, + THFloatTensor *grad_bias, THFloatTensor *grad_offset, + THFloatTensor *grad_mask, THFloatTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda.c b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda.c new file mode 100644 index 0000000000000000000000000000000000000000..1503b5df52749277d31e7550fe16e51195096e46 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda.c @@ -0,0 +1,335 @@ +#include +#include "cuda/dcn_v2_im2col_cuda.h" +#include "cuda/dcn_v2_psroi_pooling_cuda.h" + +extern THCState *state; + +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +void dcn_v2_cuda_forward(THCudaTensor *input, THCudaTensor *weight, + THCudaTensor *bias, THCudaTensor *ones, + THCudaTensor *offset, THCudaTensor *mask, + THCudaTensor *output, THCudaTensor *columns, + int kernel_h, int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + THCAssertSameGPU(THCudaTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns)); + THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); + + const int batch = THCudaTensor_size(state, input, 0); + const int channels = THCudaTensor_size(state, input, 1); + const int height = THCudaTensor_size(state, input, 2); + const int width = THCudaTensor_size(state, input, 3); + + const int channels_out = THCudaTensor_size(state, weight, 0); + const int channels_kernel = THCudaTensor_size(state, weight, 1); + const int kernel_h_ = THCudaTensor_size(state, weight, 2); + const int kernel_w_ = THCudaTensor_size(state, weight, 3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel) + THError("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (THCudaTensor_nDimension(state, ones) != 2 || + THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out) + { + // Resize plane and fill with ones... + THCudaTensor_resize2d(state, ones, height_out, width_out); + THCudaTensor_fill(state, ones, 1); + } + + // resize output + THCudaTensor_resize4d(state, output, batch, channels_out, height_out, width_out); + // resize temporary columns + THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out); + + THCudaTensor *input_n = THCudaTensor_new(state); + THCudaTensor *offset_n = THCudaTensor_new(state); + THCudaTensor *mask_n = THCudaTensor_new(state); + THCudaTensor *output_n = THCudaTensor_new(state); + + for (int b = 0; b < batch; b++) + { + THCudaTensor_select(state, input_n, input, 0, b); + THCudaTensor_select(state, offset_n, offset, 0, b); + THCudaTensor_select(state, mask_n, mask, 0, b); + THCudaTensor_select(state, output_n, output, 0, b); + + // Do Bias first: + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + // (N x 1) (1 x M) + long m_ = channels_out; + long n_ = height_out * width_out; + long k_ = 1; + THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, + THCudaTensor_data(state, ones), k_, + THCudaTensor_data(state, bias), k_, 0.0f, + THCudaTensor_data(state, output_n), n_); + + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + THCudaTensor_data(state, input_n), THCudaTensor_data(state, offset_n), + THCudaTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + deformable_group, THCudaTensor_data(state, columns)); + + //(k * m) x (m * n) + // Y = WC + long m = channels_out; + long n = height_out * width_out; + long k = channels * kernel_h * kernel_w; + THCudaBlas_Sgemm(state, 'n', 'n', n, m, k, 1.0f, + THCudaTensor_data(state, columns), n, + THCudaTensor_data(state, weight), k, 1.0f, + THCudaTensor_data(state, output_n), n); + } + THCudaTensor_free(state, input_n); + THCudaTensor_free(state, offset_n); + THCudaTensor_free(state, mask_n); + THCudaTensor_free(state, output_n); +} + +void dcn_v2_cuda_backward(THCudaTensor *input, THCudaTensor *weight, + THCudaTensor *bias, THCudaTensor *ones, + THCudaTensor *offset, THCudaTensor *mask, + THCudaTensor *columns, + THCudaTensor *grad_input, THCudaTensor *grad_weight, + THCudaTensor *grad_bias, THCudaTensor *grad_offset, + THCudaTensor *grad_mask, THCudaTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + THCAssertSameGPU(THCudaTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns, + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output)); + THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THArgCheck(THCudaTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); + + const int batch = THCudaTensor_size(state, input, 0); + const int channels = THCudaTensor_size(state, input, 1); + const int height = THCudaTensor_size(state, input, 2); + const int width = THCudaTensor_size(state, input, 3); + + const int channels_out = THCudaTensor_size(state, weight, 0); + const int channels_kernel = THCudaTensor_size(state, weight, 1); + const int kernel_h_ = THCudaTensor_size(state, weight, 2); + const int kernel_w_ = THCudaTensor_size(state, weight, 3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel) + THError("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (THCudaTensor_nDimension(state, ones) != 2 || + THCudaTensor_size(state, ones, 0) * THCudaTensor_size(state, ones, 1) < height_out * width_out) + { + // Resize plane and fill with ones... + THCudaTensor_resize2d(state, ones, height_out, width_out); + THCudaTensor_fill(state, ones, 1.0f); + } + + THCudaTensor_resize4d(state, grad_input, batch, channels, height, width); + THCudaTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out); + + THCudaTensor *input_n = THCudaTensor_new(state); + THCudaTensor *offset_n = THCudaTensor_new(state); + THCudaTensor *mask_n = THCudaTensor_new(state); + + THCudaTensor *grad_output_n = THCudaTensor_new(state); + THCudaTensor *grad_input_n = THCudaTensor_new(state); + THCudaTensor *grad_offset_n = THCudaTensor_new(state); + THCudaTensor *grad_mask_n = THCudaTensor_new(state); + + for (int b = 0; b < batch; b++) + { + THCudaTensor_select(state, input_n, input, 0, b); + THCudaTensor_select(state, offset_n, offset, 0, b); + THCudaTensor_select(state, mask_n, mask, 0, b); + THCudaTensor_select(state, grad_output_n, grad_output, 0, b); + THCudaTensor_select(state, grad_input_n, grad_input, 0, b); + THCudaTensor_select(state, grad_offset_n, grad_offset, 0, b); + THCudaTensor_select(state, grad_mask_n, grad_mask, 0, b); + + long m = channels * kernel_h * kernel_w; + long n = height_out * width_out; + long k = channels_out; + + THCudaBlas_Sgemm(state, 'n', 't', n, m, k, 1.0f, + THCudaTensor_data(state, grad_output_n), n, + THCudaTensor_data(state, weight), m, 0.0f, + THCudaTensor_data(state, columns), n); + + // gradient w.r.t. input coordinate data + modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), + THCudaTensor_data(state, columns), + THCudaTensor_data(state, input_n), + THCudaTensor_data(state, offset_n), + THCudaTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaTensor_data(state, grad_offset_n), + THCudaTensor_data(state, grad_mask_n)); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), + THCudaTensor_data(state, columns), + THCudaTensor_data(state, offset_n), + THCudaTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaTensor_data(state, grad_input_n)); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and group + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + THCudaTensor_data(state, input_n), + THCudaTensor_data(state, offset_n), + THCudaTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaTensor_data(state, columns)); + long m_ = channels_out; + long n_ = channels * kernel_h * kernel_w; + long k_ = height_out * width_out; + + THCudaBlas_Sgemm(state, 't', 'n', n_, m_, k_, 1.0f, + THCudaTensor_data(state, columns), k_, + THCudaTensor_data(state, grad_output_n), k_, 1.0f, + THCudaTensor_data(state, grad_weight), n_); + + // gradient w.r.t. bias + // long m_ = channels_out; + // long k__ = height_out * width_out; + THCudaBlas_Sgemv(state, + 't', + k_, m_, 1.0f, + THCudaTensor_data(state, grad_output_n), k_, + THCudaTensor_data(state, ones), 1, 1.0f, + THCudaTensor_data(state, grad_bias), 1); + } + + THCudaTensor_free(state, input_n); + THCudaTensor_free(state, offset_n); + THCudaTensor_free(state, mask_n); + + THCudaTensor_free(state, grad_output_n); + THCudaTensor_free(state, grad_input_n); + THCudaTensor_free(state, grad_offset_n); + THCudaTensor_free(state, grad_mask_n); +} + +void dcn_v2_psroi_pooling_cuda_forward(THCudaTensor * input, THCudaTensor * bbox, + THCudaTensor * trans, + THCudaTensor * out, THCudaTensor * top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THCAssertSameGPU(THCudaTensor_checkGPU(state, 5, input, bbox, trans, out, top_count)); + + const int batch = THCudaTensor_size(state, input, 0); + const int channels = THCudaTensor_size(state, input, 1); + const int height = THCudaTensor_size(state, input, 2); + const int width = THCudaTensor_size(state, input, 3); + const int channels_trans = no_trans? 2 : THCudaTensor_size(state, trans, 1); + + const int num_bbox = THCudaTensor_size(state, bbox, 0); + if (num_bbox != THCudaTensor_size(state, out, 0)) + THError("Output shape and bbox number wont match: (%d vs %d).", + THCudaTensor_size(state, out, 0), num_bbox); + + DeformablePSROIPoolForward(THCState_getCurrentStream(state), + THCudaTensor_data(state, input), + THCudaTensor_data(state, bbox), + THCudaTensor_data(state, trans), + THCudaTensor_data(state, out), + THCudaTensor_data(state, top_count), + batch, channels, height, width, + num_bbox, + channels_trans, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); +} + +void dcn_v2_psroi_pooling_cuda_backward(THCudaTensor * out_grad, + THCudaTensor * input, THCudaTensor * bbox, + THCudaTensor * trans, THCudaTensor * top_count, + THCudaTensor * input_grad, THCudaTensor * trans_grad, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std) +{ + THArgCheck(THCudaTensor_isContiguous(state, out_grad), 0, "out_grad tensor has to be contiguous"); + THArgCheck(THCudaTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THCAssertSameGPU(THCudaTensor_checkGPU(state, 7, input, bbox, trans, out_grad, top_count, + input_grad, trans_grad)); + + const int batch = THCudaTensor_size(state, input, 0); + const int channels = THCudaTensor_size(state, input, 1); + const int height = THCudaTensor_size(state, input, 2); + const int width = THCudaTensor_size(state, input, 3); + const int channels_trans = no_trans? 2 : THCudaTensor_size(state, trans, 1); + + const int num_bbox = THCudaTensor_size(state, bbox, 0); + if (num_bbox != THCudaTensor_size(state, out_grad, 0)) + THError("Output shape and bbox number wont match: (%d vs %d).", + THCudaTensor_size(state, out_grad, 0), num_bbox); + + DeformablePSROIPoolBackwardAcc(THCState_getCurrentStream(state), + THCudaTensor_data(state, out_grad), + THCudaTensor_data(state, input), + THCudaTensor_data(state, bbox), + THCudaTensor_data(state, trans), + THCudaTensor_data(state, top_count), + THCudaTensor_data(state, input_grad), + THCudaTensor_data(state, trans_grad), + batch, channels, height, width, num_bbox, + channels_trans, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); +} \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda.h b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda.h new file mode 100644 index 0000000000000000000000000000000000000000..70a27a8efa3d9e96cfb92ca0dc7962d493a72de9 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda.h @@ -0,0 +1,60 @@ +// #ifndef DCN_V2_CUDA +// #define DCN_V2_CUDA + +// #ifdef __cplusplus +// extern "C" +// { +// #endif + +void dcn_v2_cuda_forward(THCudaTensor *input, THCudaTensor *weight, + THCudaTensor *bias, THCudaTensor *ones, + THCudaTensor *offset, THCudaTensor *mask, + THCudaTensor *output, THCudaTensor *columns, + int kernel_h, int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void dcn_v2_cuda_backward(THCudaTensor *input, THCudaTensor *weight, + THCudaTensor *bias, THCudaTensor *ones, + THCudaTensor *offset, THCudaTensor *mask, + THCudaTensor *columns, + THCudaTensor *grad_input, THCudaTensor *grad_weight, + THCudaTensor *grad_bias, THCudaTensor *grad_offset, + THCudaTensor *grad_mask, THCudaTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); + +void dcn_v2_psroi_pooling_cuda_forward(THCudaTensor * input, THCudaTensor * bbox, + THCudaTensor * trans, + THCudaTensor * out, THCudaTensor * top_count, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); + +void dcn_v2_psroi_pooling_cuda_backward(THCudaTensor * out_grad, + THCudaTensor * input, THCudaTensor * bbox, + THCudaTensor * trans, THCudaTensor * top_count, + THCudaTensor * input_grad, THCudaTensor * trans_grad, + const int no_trans, + const float spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const float trans_std); + +// #ifdef __cplusplus +// } +// #endif + +// #endif \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda_double.c b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda_double.c new file mode 100644 index 0000000000000000000000000000000000000000..021ef12c1755f77a9f63c9a76cac1f53f810faeb --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda_double.c @@ -0,0 +1,358 @@ +#include +#include "cuda/dcn_v2_im2col_cuda_double.h" +#include "cuda/dcn_v2_psroi_pooling_cuda_double.h" + +extern THCState *state; + +// author: Charles Shang +// https://github.com/torch/cunn/blob/master/lib/THCUNN/generic/SpatialConvolutionMM.cu + +void dcn_v2_cuda_forward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, + THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, + THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, + THCudaDoubleTensor *output, THCudaDoubleTensor *columns, + int kernel_h, int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 8, input, weight, bias, ones, offset, mask, output, columns)); + THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THArgCheck(THCudaDoubleTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); + + input = THCudaDoubleTensor_newContiguous(state, input); + offset = THCudaDoubleTensor_newContiguous(state, offset); + mask = THCudaDoubleTensor_newContiguous(state, mask); + weight = THCudaDoubleTensor_newContiguous(state, weight); + + const int batch = THCudaDoubleTensor_size(state, input, 0); + const int channels = THCudaDoubleTensor_size(state, input, 1); + const int height = THCudaDoubleTensor_size(state, input, 2); + const int width = THCudaDoubleTensor_size(state, input, 3); + + const int channels_out = THCudaDoubleTensor_size(state, weight, 0); + const int channels_kernel = THCudaDoubleTensor_size(state, weight, 1); + const int kernel_h_ = THCudaDoubleTensor_size(state, weight, 2); + const int kernel_w_ = THCudaDoubleTensor_size(state, weight, 3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel) + THError("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (THCudaDoubleTensor_nDimension(state, ones) != 2 || + THCudaDoubleTensor_size(state, ones, 0) * THCudaDoubleTensor_size(state, ones, 1) < height_out * width_out) + { + // Resize plane and fill with ones... + THCudaDoubleTensor_resize2d(state, ones, height_out, width_out); + THCudaDoubleTensor_fill(state, ones, 1); + } + + // resize output + THCudaDoubleTensor_resize4d(state, output, batch, channels_out, height_out, width_out); + // resize temporary columns + THCudaDoubleTensor_resize2d(state, columns, channels * kernel_h * kernel_w, 1 * height_out * width_out); + + THCudaDoubleTensor *input_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *offset_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *mask_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *output_n = THCudaDoubleTensor_new(state); + + for (int b = 0; b < batch; b++) + { + THCudaDoubleTensor_select(state, input_n, input, 0, b); + THCudaDoubleTensor_select(state, offset_n, offset, 0, b); + THCudaDoubleTensor_select(state, mask_n, mask, 0, b); + THCudaDoubleTensor_select(state, output_n, output, 0, b); + + // Do Bias first: + // M,N,K are dims of matrix A and B + // (see http://docs.nvidia.com/cuda/cublas/#cublas-lt-t-gt-gemm) + // (N x 1) (1 x M) + long m_ = channels_out; + long n_ = height_out * width_out; + long k_ = 1; + THCudaBlas_Dgemm(state, 't', 'n', n_, m_, k_, 1.0, + THCudaDoubleTensor_data(state, ones), k_, + THCudaDoubleTensor_data(state, bias), k_, 0.0, + THCudaDoubleTensor_data(state, output_n), n_); + + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, input_n), THCudaDoubleTensor_data(state, offset_n), + THCudaDoubleTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, dilation_h, dilation_w, + deformable_group, THCudaDoubleTensor_data(state, columns)); + + //(k * m) x (m * n) + // Y = WC + long m = channels_out; + long n = height_out * width_out; + long k = channels * kernel_h * kernel_w; + THCudaBlas_Dgemm(state, 'n', 'n', n, m, k, 1.0f, + THCudaDoubleTensor_data(state, columns), n, + THCudaDoubleTensor_data(state, weight), k, 1.0f, + THCudaDoubleTensor_data(state, output_n), n); + } + THCudaDoubleTensor_free(state, input_n); + THCudaDoubleTensor_free(state, offset_n); + THCudaDoubleTensor_free(state, mask_n); + THCudaDoubleTensor_free(state, output_n); + + THCudaDoubleTensor_free(state, input); + THCudaDoubleTensor_free(state, offset); + THCudaDoubleTensor_free(state, mask); + THCudaDoubleTensor_free(state, weight); +} + +void dcn_v2_cuda_backward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, + THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, + THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, + THCudaDoubleTensor *columns, + THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_weight, + THCudaDoubleTensor *grad_bias, THCudaDoubleTensor *grad_offset, + THCudaDoubleTensor *grad_mask, THCudaDoubleTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 13, input, weight, bias, ones, offset, mask, columns, + grad_input, grad_weight, grad_bias, grad_offset, grad_mask, grad_output)); + THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THArgCheck(THCudaDoubleTensor_isContiguous(state, weight), 2, "weight tensor has to be contiguous"); + + input = THCudaDoubleTensor_newContiguous(state, input); + offset = THCudaDoubleTensor_newContiguous(state, offset); + mask = THCudaDoubleTensor_newContiguous(state, mask); + weight = THCudaDoubleTensor_newContiguous(state, weight); + grad_output = THCudaDoubleTensor_newContiguous(state, grad_output); + + const int batch = THCudaDoubleTensor_size(state, input, 0); + const int channels = THCudaDoubleTensor_size(state, input, 1); + const int height = THCudaDoubleTensor_size(state, input, 2); + const int width = THCudaDoubleTensor_size(state, input, 3); + + const int channels_out = THCudaDoubleTensor_size(state, weight, 0); + const int channels_kernel = THCudaDoubleTensor_size(state, weight, 1); + const int kernel_h_ = THCudaDoubleTensor_size(state, weight, 2); + const int kernel_w_ = THCudaDoubleTensor_size(state, weight, 3); + if (kernel_h_ != kernel_h || kernel_w_ != kernel_w) + THError("Input shape and kernel shape wont match: (%d x %d vs %d x %d).", + kernel_h_, kernel_w, kernel_h_, kernel_w_); + if (channels != channels_kernel) + THError("Input shape and kernel channels wont match: (%d vs %d).", + channels, channels_kernel); + + const int height_out = (height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1; + const int width_out = (width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1; + + if (THCudaDoubleTensor_nDimension(state, ones) != 2 || + THCudaDoubleTensor_size(state, ones, 0) * THCudaDoubleTensor_size(state, ones, 1) < height_out * width_out) + { + // Resize plane and fill with ones... + THCudaDoubleTensor_resize2d(state, ones, height_out, width_out); + THCudaDoubleTensor_fill(state, ones, 1); + } + + // THCudaDoubleTensor_resize4d(state, grad_input, batch, channels, height, width); + THCudaDoubleTensor_resize2d(state, columns, channels * kernel_h * kernel_w, height_out * width_out); + + THCudaDoubleTensor *input_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *offset_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *mask_n = THCudaDoubleTensor_new(state); + + THCudaDoubleTensor *grad_output_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *grad_input_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *grad_offset_n = THCudaDoubleTensor_new(state); + THCudaDoubleTensor *grad_mask_n = THCudaDoubleTensor_new(state); + + for (int b = 0; b < batch; b++) + { + THCudaDoubleTensor_select(state, input_n, input, 0, b); + THCudaDoubleTensor_select(state, offset_n, offset, 0, b); + THCudaDoubleTensor_select(state, mask_n, mask, 0, b); + THCudaDoubleTensor_select(state, grad_output_n, grad_output, 0, b); + THCudaDoubleTensor_select(state, grad_input_n, grad_input, 0, b); + THCudaDoubleTensor_select(state, grad_offset_n, grad_offset, 0, b); + THCudaDoubleTensor_select(state, grad_mask_n, grad_mask, 0, b); + + long m = channels * kernel_h * kernel_w; + long n = height_out * width_out; + long k = channels_out; + + THCudaBlas_Dgemm(state, 'n', 't', n, m, k, 1.0, + THCudaDoubleTensor_data(state, grad_output_n), n, + THCudaDoubleTensor_data(state, weight), m, 0.0, + THCudaDoubleTensor_data(state, columns), n); + + // gradient w.r.t. input offset and mask data + modulated_deformable_col2im_coord_cuda(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, columns), + THCudaDoubleTensor_data(state, input_n), + THCudaDoubleTensor_data(state, offset_n), + THCudaDoubleTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaDoubleTensor_data(state, grad_offset_n), + THCudaDoubleTensor_data(state, grad_mask_n)); + // gradient w.r.t. input data + modulated_deformable_col2im_cuda(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, columns), + THCudaDoubleTensor_data(state, offset_n), + THCudaDoubleTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaDoubleTensor_data(state, grad_input_n)); + + // gradient w.r.t. weight, dWeight should accumulate across the batch and group + modulated_deformable_im2col_cuda(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, input_n), + THCudaDoubleTensor_data(state, offset_n), + THCudaDoubleTensor_data(state, mask_n), + 1, channels, height, width, + height_out, width_out, kernel_h, kernel_w, + pad_h, pad_w, stride_h, stride_w, + dilation_h, dilation_w, deformable_group, + THCudaDoubleTensor_data(state, columns)); + long m_ = channels_out; + long n_ = channels * kernel_h * kernel_w; + long k_ = height_out * width_out; + + THCudaBlas_Dgemm(state, 't', 'n', n_, m_, k_, 1.0, + THCudaDoubleTensor_data(state, columns), k_, + THCudaDoubleTensor_data(state, grad_output_n), k_, 1.0, + THCudaDoubleTensor_data(state, grad_weight), n_); + + // gradient w.r.t. bias + // long m_ = channels_out; + // long k__ = height_out * width_out; + THCudaBlas_Dgemv(state, + 't', + k_, m_, 1.0, + THCudaDoubleTensor_data(state, grad_output_n), k_, + THCudaDoubleTensor_data(state, ones), 1, 1.0, + THCudaDoubleTensor_data(state, grad_bias), 1); + } + + THCudaDoubleTensor_free(state, input_n); + THCudaDoubleTensor_free(state, offset_n); + THCudaDoubleTensor_free(state, mask_n); + + THCudaDoubleTensor_free(state, grad_output_n); + THCudaDoubleTensor_free(state, grad_input_n); + THCudaDoubleTensor_free(state, grad_offset_n); + THCudaDoubleTensor_free(state, grad_mask_n); + + THCudaDoubleTensor_free(state, input); + THCudaDoubleTensor_free(state, offset); + THCudaDoubleTensor_free(state, mask); + THCudaDoubleTensor_free(state, weight); + THCudaDoubleTensor_free(state, grad_output); +} + + +void dcn_v2_psroi_pooling_cuda_forward(THCudaDoubleTensor * input, THCudaDoubleTensor * bbox, + THCudaDoubleTensor * trans, + THCudaDoubleTensor * out, THCudaDoubleTensor * top_count, + const int no_trans, + const double spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const double trans_std) +{ + THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 5, input, bbox, trans, out, top_count)); + + const int batch = THCudaDoubleTensor_size(state, input, 0); + const int channels = THCudaDoubleTensor_size(state, input, 1); + const int height = THCudaDoubleTensor_size(state, input, 2); + const int width = THCudaDoubleTensor_size(state, input, 3); + const int channels_trans = no_trans? 2 : THCudaDoubleTensor_size(state, trans, 1); + + const int num_bbox = THCudaDoubleTensor_size(state, bbox, 0); + if (num_bbox != THCudaDoubleTensor_size(state, out, 0)) + THError("Output shape and bbox number wont match: (%d vs %d).", + THCudaDoubleTensor_size(state, out, 0), num_bbox); + + DeformablePSROIPoolForward(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, input), + THCudaDoubleTensor_data(state, bbox), + THCudaDoubleTensor_data(state, trans), + THCudaDoubleTensor_data(state, out), + THCudaDoubleTensor_data(state, top_count), + batch, channels, height, width, + num_bbox, + channels_trans, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); +} + +void dcn_v2_psroi_pooling_cuda_backward(THCudaDoubleTensor * out_grad, + THCudaDoubleTensor * input, THCudaDoubleTensor * bbox, + THCudaDoubleTensor * trans, THCudaDoubleTensor * top_count, + THCudaDoubleTensor * input_grad, THCudaDoubleTensor * trans_grad, + const int no_trans, + const double spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const double trans_std) +{ + THArgCheck(THCudaDoubleTensor_isContiguous(state, out_grad), 0, "out_grad tensor has to be contiguous"); + THArgCheck(THCudaDoubleTensor_isContiguous(state, input), 1, "input tensor has to be contiguous"); + THCAssertSameGPU(THCudaDoubleTensor_checkGPU(state, 7, input, bbox, trans, out_grad, top_count, + input_grad, trans_grad)); + + const int batch = THCudaDoubleTensor_size(state, input, 0); + const int channels = THCudaDoubleTensor_size(state, input, 1); + const int height = THCudaDoubleTensor_size(state, input, 2); + const int width = THCudaDoubleTensor_size(state, input, 3); + const int channels_trans = no_trans? 2 : THCudaDoubleTensor_size(state, trans, 1); + + const int num_bbox = THCudaDoubleTensor_size(state, bbox, 0); + if (num_bbox != THCudaDoubleTensor_size(state, out_grad, 0)) + THError("Output shape and bbox number wont match: (%d vs %d).", + THCudaDoubleTensor_size(state, out_grad, 0), num_bbox); + + DeformablePSROIPoolBackwardAcc(THCState_getCurrentStream(state), + THCudaDoubleTensor_data(state, out_grad), + THCudaDoubleTensor_data(state, input), + THCudaDoubleTensor_data(state, bbox), + THCudaDoubleTensor_data(state, trans), + THCudaDoubleTensor_data(state, top_count), + THCudaDoubleTensor_data(state, input_grad), + THCudaDoubleTensor_data(state, trans_grad), + batch, channels, height, width, num_bbox, + channels_trans, + no_trans, + spatial_scale, + output_dim, + group_size, + pooled_size, + part_size, + sample_per_part, + trans_std); +} \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda_double.h b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda_double.h new file mode 100644 index 0000000000000000000000000000000000000000..826cb2bbf2c9f52b07284f08320200277fc95c91 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_cuda_double.h @@ -0,0 +1,61 @@ +// #ifndef DCN_V2_CUDA +// #define DCN_V2_CUDA + +// #ifdef __cplusplus +// extern "C" +// { +// #endif + +void dcn_v2_cuda_forward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, + THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, + THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, + THCudaDoubleTensor *output, THCudaDoubleTensor *columns, + int kernel_h, int kernel_w, + const int stride_h, const int stride_w, + const int pad_h, const int pad_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void dcn_v2_cuda_backward(THCudaDoubleTensor *input, THCudaDoubleTensor *weight, + THCudaDoubleTensor *bias, THCudaDoubleTensor *ones, + THCudaDoubleTensor *offset, THCudaDoubleTensor *mask, + THCudaDoubleTensor *columns, + THCudaDoubleTensor *grad_input, THCudaDoubleTensor *grad_weight, + THCudaDoubleTensor *grad_bias, THCudaDoubleTensor *grad_offset, + THCudaDoubleTensor *grad_mask, THCudaDoubleTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); + +void dcn_v2_psroi_pooling_cuda_forward(THCudaDoubleTensor * input, THCudaDoubleTensor * bbox, + THCudaDoubleTensor * trans, + THCudaDoubleTensor * out, THCudaDoubleTensor * top_count, + const int no_trans, + const double spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const double trans_std); + +void dcn_v2_psroi_pooling_cuda_backward(THCudaDoubleTensor * out_grad, + THCudaDoubleTensor * input, THCudaDoubleTensor * bbox, + THCudaDoubleTensor * trans, THCudaDoubleTensor * top_count, + THCudaDoubleTensor * input_grad, THCudaDoubleTensor * trans_grad, + const int no_trans, + const double spatial_scale, + const int output_dim, + const int group_size, + const int pooled_size, + const int part_size, + const int sample_per_part, + const double trans_std); + + +// #ifdef __cplusplus +// } +// #endif + +// #endif \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_double.c b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_double.c new file mode 100644 index 0000000000000000000000000000000000000000..2b865452097bae25b2670a0db2d997c405a4a8b6 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_double.c @@ -0,0 +1,30 @@ +#include +#include +#include + +void dcn_v2_forward(THDoubleTensor *input, THDoubleTensor *weight, + THDoubleTensor *bias, THDoubleTensor *ones, + THDoubleTensor *offset, THDoubleTensor *mask, + THDoubleTensor *output, THDoubleTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group) +{ + printf("only implemented in GPU"); +} +void dcn_v2_backward(THDoubleTensor *input, THDoubleTensor *weight, + THDoubleTensor *bias, THDoubleTensor *ones, + THDoubleTensor *offset, THDoubleTensor *mask, + THDoubleTensor *output, THDoubleTensor *columns, + THDoubleTensor *grad_input, THDoubleTensor *grad_weight, + THDoubleTensor *grad_bias, THDoubleTensor *grad_offset, + THDoubleTensor *grad_mask, THDoubleTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group) +{ + printf("only implemented in GPU"); +} \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_double.h b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_double.h new file mode 100644 index 0000000000000000000000000000000000000000..eda1f4c48acdcd58d3cdf7b169a6067921ab1428 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/src/dcn_v2_double.h @@ -0,0 +1,20 @@ +void dcn_v2_forward(THDoubleTensor *input, THDoubleTensor *weight, + THDoubleTensor *bias, THDoubleTensor *ones, + THDoubleTensor *offset, THDoubleTensor *mask, + THDoubleTensor *output, THDoubleTensor *columns, + const int pad_h, const int pad_w, + const int stride_h, const int stride_w, + const int dilation_h, const int dilation_w, + const int deformable_group); +void dcn_v2_backward(THDoubleTensor *input, THDoubleTensor *weight, + THDoubleTensor *bias, THDoubleTensor *ones, + THDoubleTensor *offset, THDoubleTensor *mask, + THDoubleTensor *output, THDoubleTensor *columns, + THDoubleTensor *grad_input, THDoubleTensor *grad_weight, + THDoubleTensor *grad_bias, THDoubleTensor *grad_offset, + THDoubleTensor *grad_mask, THDoubleTensor *grad_output, + int kernel_h, int kernel_w, + int stride_h, int stride_w, + int pad_h, int pad_w, + int dilation_h, int dilation_w, + int deformable_group); \ No newline at end of file diff --git a/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/test.py b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/test.py new file mode 100644 index 0000000000000000000000000000000000000000..3a8b2e4b7fa6c3239e78ef427f10b06746060525 --- /dev/null +++ b/cv/detection/centernet/pytorch/src/lib/models/networks/DCNv2/test.py @@ -0,0 +1,277 @@ +#!/usr/bin/env python +from __future__ import absolute_import +from __future__ import print_function +from __future__ import division + +import time +import torch +import torch.nn as nn +from torch.autograd import gradcheck + +from dcn_v2 import DCNv2 +from dcn_v2_func import DCNv2Function +from dcn_v2 import DCNv2Pooling +from dcn_v2_func import DCNv2PoolingFunction + +deformable_groups = 1 +N, inC, inH, inW = 2, 2, 4, 4 +outC = 2 +kH, kW = 3, 3 + +def conv_identify(weight, bias): + weight.data.zero_() + bias.data.zero_() + o, i, h, w = weight.shape + y = h//2 + x = w//2 + for p in range(i): + for q in range(o): + if p == q: + weight.data[q, p, y, x] = 1.0 + +def check_zero_offset(): + conv_offset = nn.Conv2d(inC, deformable_groups * 2 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True).cuda() + + conv_mask = nn.Conv2d(inC, deformable_groups * 1 * kH * kW, + kernel_size=(kH, kW), + stride=(1, 1), + padding=(1, 1), + bias=True).cuda() + + dcn_v2 = DCNv2(inC, outC, (kH, kW), + stride=1, padding=1, dilation=1, + deformable_groups=deformable_groups).cuda() + + conv_offset.weight.data.zero_() + conv_offset.bias.data.zero_() + conv_mask.weight.data.zero_() + conv_mask.bias.data.zero_() + conv_identify(dcn_v2.weight, dcn_v2.bias) + + input = torch.randn(N, inC, inH, inW).cuda() + offset = conv_offset(input) + mask = conv_mask(input) + mask = torch.sigmoid(mask) + output = dcn_v2(input, offset, mask) + output *= 2 + d = (input - output).abs().max() + if d < 1e-10: + print('Zero offset passed') + else: + print('Zero offset failed') + +def check_gradient_dconv_double(): + + input = torch.randn(N, inC, inH, inW, dtype=torch.float64).cuda() + input.requires_grad = True + + offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW, dtype=torch.float64).cuda() + # offset.data.zero_() + # offset.data -= 0.00001 + offset.requires_grad = True + + mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW, dtype=torch.float64).cuda() + # mask.data.zero_() + mask.requires_grad = True + mask = torch.sigmoid(mask) + + weight = torch.randn(outC, inC, kH, kW, dtype=torch.float64).cuda() + weight.requires_grad = True + + bias = torch.rand(outC, dtype=torch.float64).cuda() + bias.requires_grad = True + + func = DCNv2Function(stride=1, padding=1, dilation=1, deformable_groups=deformable_groups) + + print(gradcheck(func, (input, offset, mask, weight, bias), eps=1e-6, atol=1e-5, rtol=1e-3)) + +def check_gradient_dconv(): + + input = torch.randn(N, inC, inH, inW).cuda() + input.requires_grad = True + + offset = torch.randn(N, deformable_groups * 2 * kW * kH, inH, inW).cuda() + # offset.data.zero_() + # offset.data -= 0.5 + offset.requires_grad = True + + mask = torch.rand(N, deformable_groups * 1 * kW * kH, inH, inW).cuda() + # mask.data.zero_() + mask.requires_grad = True + mask = torch.sigmoid(mask) + + weight = torch.randn(outC, inC, kH, kW).cuda() + weight.requires_grad = True + + bias = torch.rand(outC).cuda() + bias.requires_grad = True + + func = DCNv2Function(stride=1, padding=1, dilation=1, deformable_groups=deformable_groups) + + print(gradcheck(func, (input, offset, mask, weight, bias), eps=1e-3, atol=1e-3, rtol=1e-2)) + +def check_pooling_zero_offset(): + from dcn_v2 import DCNv2Pooling + input = torch.randn(2, 16, 64, 64).cuda().zero_() + input[0, :, 16:26, 16:26] = 1. + input[1, :, 10:20, 20:30] = 2. + rois = torch.tensor([ + [0, 65, 65, 103, 103], + [1, 81, 41, 119, 79], + ]).cuda().float() + pooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=True, + group_size=1, + trans_std=0.1).cuda() + + out = pooling(input, rois, input.new()) + s = ', '.join(['%f' % out[i, :, :, :].mean().item() for i in range(rois.shape[0])]) + print(s) + + dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=16, + no_trans=False, + group_size=1, + trans_std=0.1).cuda() + offset = torch.randn(20, 2, 7, 7).cuda().zero_() + dout = dpooling(input, rois, offset) + s = ', '.join(['%f' % dout[i, :, :, :].mean().item() for i in range(rois.shape[0])]) + print(s) + +def check_gradient_dpooling(): + input = torch.randn(2, 3, 5, 5).cuda() * 0.01 + N = 4 + batch_inds = torch.randint(2, (N, 1)).cuda().float() + x = torch.rand((N, 1)).cuda().float() * 15 + y = torch.rand((N, 1)).cuda().float() * 15 + w = torch.rand((N, 1)).cuda().float() * 10 + h = torch.rand((N, 1)).cuda().float() * 10 + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + offset = torch.randn(N, 2, 3, 3).cuda() + dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=3, + output_dim=3, + no_trans=False, + group_size=1, + trans_std=0.0).cuda() + input.requires_grad = True + offset.requires_grad = True + print('check_gradient_dpooling', gradcheck(dpooling, (input, rois, offset), eps=1e-4)) + + +def example_dconv(): + from dcn_v2 import DCN + input = torch.randn(2, 64, 128, 128).cuda() + # wrap all things (offset and mask) in DCN + dcn = DCN(64, 64, kernel_size=(3,3), stride=1, padding=1, deformable_groups=2).cuda() + output = dcn(input) + targert = output.new(*output.size()) + targert.data.uniform_(-0.01, 0.01) + error = (targert - output).mean() + error.backward() + print(output.shape) + +def example_dpooling(): + from dcn_v2 import DCNv2Pooling + input = torch.randn(2, 32, 64, 64).cuda() + batch_inds = torch.randint(2, (20, 1)).cuda().float() + x = torch.randint(256, (20, 1)).cuda().float() + y = torch.randint(256, (20, 1)).cuda().float() + w = torch.randint(64, (20, 1)).cuda().float() + h = torch.randint(64, (20, 1)).cuda().float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + offset = torch.randn(20, 2, 7, 7).cuda() + input.requires_grad = True + offset.requires_grad = True + + # normal roi_align + pooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=True, + group_size=1, + trans_std=0.1).cuda() + + # deformable pooling + dpooling = DCNv2Pooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1).cuda() + + out = pooling(input, rois, offset) + dout = dpooling(input, rois, offset) + print(out.shape) + print(dout.shape) + + target_out = out.new(*out.size()) + target_out.data.uniform_(-0.01, 0.01) + target_dout = dout.new(*dout.size()) + target_dout.data.uniform_(-0.01, 0.01) + e = (target_out - out).mean() + e.backward() + e = (target_dout - dout).mean() + e.backward() + +def example_mdpooling(): + from dcn_v2 import DCNPooling + input = torch.randn(2, 32, 64, 64).cuda() + input.requires_grad = True + batch_inds = torch.randint(2, (20, 1)).cuda().float() + x = torch.randint(256, (20, 1)).cuda().float() + y = torch.randint(256, (20, 1)).cuda().float() + w = torch.randint(64, (20, 1)).cuda().float() + h = torch.randint(64, (20, 1)).cuda().float() + rois = torch.cat((batch_inds, x, y, x + w, y + h), dim=1) + + # mdformable pooling (V2) + dpooling = DCNPooling(spatial_scale=1.0 / 4, + pooled_size=7, + output_dim=32, + no_trans=False, + group_size=1, + trans_std=0.1).cuda() + + dout = dpooling(input, rois) + target = dout.new(*dout.size()) + target.data.uniform_(-0.1, 0.1) + error = (target - dout).mean() + error.backward() + print(dout.shape) + +if __name__ == '__main__': + + example_dconv() + example_dpooling() + example_mdpooling() + + check_pooling_zero_offset() + # zero offset check + if inC == outC: + check_zero_offset() + + check_gradient_dpooling() + + # # gradient check + # try: + # check_gradient_double() + # except TypeError: + # print('''****** You can swith to double precision in dcn_v2_func.py by (un)commenting these two lines: + # ****** from _ext import dcn_v2 as _backend + # ****** from _ext import dcn_v2_double as _backend''') + # print('****** Your tensor may not be **double** type') + # print('****** Switching to **float** type') + # + # check_gradient() + # finally: + # print('****** Note: backward is not reentrant error may not be a serious problem, ' + # '****** since the max error is less than 1e-7\n' + # '****** Still looking for what trigger this problem') \ No newline at end of file