From 8a4c57706d8c3cb9ce5824fe440a25ca25732737 Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Mon, 9 Nov 2020 17:17:00 +0800 Subject: [PATCH 01/13] Add deformable_conv2d api --- tf_adapter/kernels/deformable_conv2d_ops.cc | 54 ++++++ tf_adapter/ops/npu_cube_ops.cc | 159 ++++++++++++++++++ .../python/npu_bridge/tbe/npu_cube_ops.py | 58 +++++++ 3 files changed, 271 insertions(+) create mode 100644 tf_adapter/kernels/deformable_conv2d_ops.cc create mode 100644 tf_adapter/ops/npu_cube_ops.cc create mode 100644 tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py diff --git a/tf_adapter/kernels/deformable_conv2d_ops.cc b/tf_adapter/kernels/deformable_conv2d_ops.cc new file mode 100644 index 000000000..5b278d5ec --- /dev/null +++ b/tf_adapter/kernels/deformable_conv2d_ops.cc @@ -0,0 +1,54 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +template +class DeforableConv2DOP : public tensorflow::OpKernel { + public: + explicit DeforableConv2DOP(tensorflow::OpKernelConstruction *context) + : OpKernel(context) { + LOG(INFO) << "new DeforableConv2DOP"; + } + ~DeforableConv2DOP() override = default; + void Compute(OpKernelContext *context) override { + LOG(INFO) << "DeforableConv2DOP Compute, num_inputs: " + << context->num_inputs(); + } + bool IsExpensive() override { return false; } +}; + +#define REGISTER_KERNEL(type) + REGISTER_KERNEL_BUILDER(Name("DeforableConv2D") \ + .Device(tensorflow::DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeforableConv2DOP); \ +REGISTER_KERNEL(float); +REGISTER_KERNEL(Eigen::half>); +#undef REGISTER_KERNEL +} // namespace tensorflow diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc new file mode 100644 index 000000000..22cac8622 --- /dev/null +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -0,0 +1,159 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" + +namespace tensorflow { +Status Conv2dInferShape(shape_inference::InferenceContext *c) { + std::string dt_format; + const std::set kVaildFormat = {"NHWC", "NCHW"}; + if (!c->GetAttr("data_format", &dt_format).ok()) { + dt_format = "NHWC"; + } + if (kVaildFormat.find(dt_format) == kVaildFormat.end()) { + return errors::InvalidArgument("Invalid data format string: ", + dt_format); + } + size_t pos_n = dt_format.find("N"); + size_t pos_c = dt_format.find("C"); + size_t pos_h = dt_format.find("H"); + size_t pos_w = dt_format.find("W"); + + const int rank = 4; + ShapeHandle x_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x_shape)); + DimensionHandle x_n_dim = c->Dim(x_shape, pos_n); + DimensionHandle x_c_dim = c->Dim(x_shape, pos_c); + DimensionHandle x_h_dim = c->Dim(x_shape, pos_c); + DimensionHandle x_w_dim = c->Dim(x_shape, pos_c); + + // Filter format is always HWCN + ShapeHandle filter_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); + DimensionHandle filter_h_dim = c->Dim(filter_shape, 0); + DimensionHandle filter_w_dim = c->Dim(filter_shape, 1); + DimensionHandle filter_c_dim = c->Dim(filter_shape, 2); + DimensionHandle filter_n_dim = c->Dim(filter_shape, 3); + + int32 groups; + TF_RETURN_IF_ERROR(c->GetAttr("groups", &groups)); + if (groups < 1) { + return errors::InvalidArgument("Groups must be >= 1"); + } + if (c->ValueKnown(x_c_dim) && c->ValueKnown(filter_c_dim)) { + int64 x_c = c->Value(x_c_dim); + int64 filter_c = c->Value(filter_c_dim); + if (x_c != filter_c * groups) + return errors::InvalidArgument( + "In_channels (", x_c, + ") should be equal to filter channels (", filter_c, + ")* groups (", groups, ")"); + + if (c->ValueKnown(filter_n_dim)) { + int64 filter_n = c->Value(filter_n_dim); + if (filter_n % groups != 0) + return errors::InvalidArgument( + "Out_channels (", filter_n, + ") should be divisiable by groups (", groups, ")"); + } + } + + std::vector dilations; + TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); + if (dilations.size() != 4) { + return errors::InvalidArgument("Dilations list should be 4D, actual is: ", + dilations.size()); + } + const int32 dil_h = dilations[pos_h]; + const int32 dil_w = dilations[pos_w]; + if (dil_h < 1 || dil_w < 1) { + return errors::InvalidArgument("Dilation rate must be >= 1"); + } + std::vector strides; + TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); + if (strides.size() != 4) { + return errors::InvalidArgument("Strides list should be 4D, actual is: ", + strides.size()); + } + const int32 str_h = strides[pos_h]; + const int32 str_w = strides[pos_w]; + if (str_h < 1 || str_w < 1) { + return errors::InvalidArgument("Stride must be > 0"); + } + + int32 pads; + TF_RETURN_IF_ERROR(c->GetAttr("pads", &pads)); + if (pads.size() != 4) { + return errors::InvalidArgument("Pads list should be 4D, actual is: ", + pads.size()); + } + int64 pad_t = pads[0]; + int64 pad_b = pads[1]; + int64 pad_l = pads[2]; + int64 pad_r = pads[3]; + if (pad_t < 0 || pad_b < 0 || pad_l < 0 || pad_r < 0) { + return errors::InvalidArgument("Pad must be > 0"); + } + + int64 dil_filter_h = dil_h * (c->Value(filter_h_dim) - 1) + 1; + int64 dil_filter_w = dil_w * (c->Value(filter_w_dim) - 1) + 1; + int64 x_h = c->Value(x_h_dim); + int64 x_w = c->Value(x_w_dim); + int64 out_h = (x_h + pad_t + pad_b - dil_filter_h) / str_h + 1; + int64 out_w = (x_w + pad_l + pad_r - dil_filter_w) / str_w + 1; + if (out_h < 0 || out_w < 0) { + return errors::InvalidArgument("Image size after padding should not be + smaller than filter size after dilation"); + } + + DimensionHandle out_h_dim = c->MakeDim(out_h); + DimensionHandle out_w_dim = c->MakeDim(out_w); + std::vector out_dims(rank); + out_dims[pos_n] = x_n_dim; + out_dims[pos_c] = filter_n_dim; + out_dims[pos_h] = out_h_dim; + out_dims[pos_w] = out_w_dim + c->set_output(0, c->MakeShape(out_dims)); + return Status::OK(); +} + +REGISTER_OP("DeforableConv2D") + .Input("x: T") + .Input("filter: T") + .Input("offsets: T") + .Input("bias: T") + .Output("y: T") + .Attr("T: {float16, float32}") + .Attr("strides: list(int) = [1,1,1,1]") + .Attr("pads: list(int) = [0,0,0,0]") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .Attr("dilations: int = 1") + .Attr("groups: int = 1") + .Attr("deformable_groups: int = 1") + .SetShapeFn(Conv2dInferShape); +} // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py new file mode 100644 index 000000000..9cdf3e08a --- /dev/null +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -0,0 +1,58 @@ +# Copyright 2016 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Ops for aicore cube.""" +from tensorflow.python.eager import context +from npu_bridge.helper import helper +gen_npu_ops = helper.get_gen_ops() + +def deformable_conv2d(x, + filter, + offsets, + bias, + strides=[1,1,1,1], + pads=[0,0,0,0], + data_format='NHWC', + dilations, + groups=1, + deformable_groups=1, + name=None): + if context.executing_eagerly(): + raise RuntimeError("tf.deformable_conv2d() is not compatible with " + "eager execution.") + + return gen_npu_ops.deformable_conv2d(x, + filter, + offsets, + bias, + strides, + pads, + data_format, + dilations, + groups, + deformable_groups, + name): -- Gitee From 5891fdfc483eccc11152eeac8dbf6528b65e7b12 Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Tue, 10 Nov 2020 10:52:40 +0800 Subject: [PATCH 02/13] Fix compile problems --- tf_adapter/kernels/deformable_conv2d_ops.cc | 28 ++++++++++--------- tf_adapter/ops/npu_cube_ops.cc | 22 +++++++++------ .../python/npu_bridge/tbe/npu_cube_ops.py | 6 ++-- 3 files changed, 32 insertions(+), 24 deletions(-) diff --git a/tf_adapter/kernels/deformable_conv2d_ops.cc b/tf_adapter/kernels/deformable_conv2d_ops.cc index 5b278d5ec..15b202704 100644 --- a/tf_adapter/kernels/deformable_conv2d_ops.cc +++ b/tf_adapter/kernels/deformable_conv2d_ops.cc @@ -28,27 +28,29 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { -template -class DeforableConv2DOP : public tensorflow::OpKernel { +namespace{ +template +class DeformableConv2DOp : public tensorflow::OpKernel { public: - explicit DeforableConv2DOP(tensorflow::OpKernelConstruction *context) - : OpKernel(context) { - LOG(INFO) << "new DeforableConv2DOP"; + explicit DeformableConv2DOp(tensorflow::OpKernelConstruction *context) + : OpKernel(context) { + LOG(INFO) << "new DeformableConv2DOp"; } - ~DeforableConv2DOP() override = default; + ~DeformableConv2DOp() override { LOG(INFO) << "del DeformableConv2DOp"; } void Compute(OpKernelContext *context) override { - LOG(INFO) << "DeforableConv2DOP Compute, num_inputs: " + LOG(INFO) << "DeformableConv2DOp Compute, num_inputs: " << context->num_inputs(); } bool IsExpensive() override { return false; } }; -#define REGISTER_KERNEL(type) - REGISTER_KERNEL_BUILDER(Name("DeforableConv2D") \ - .Device(tensorflow::DEVICE_CPU) \ - .TypeConstraint("T"), \ - DeforableConv2DOP); \ +#define REGISTER_KERNEL(type) \ +REGISTER_KERNEL_BUILDER(Name("DeformableConv2D") \ + .Device(tensorflow::DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableConv2DOp) REGISTER_KERNEL(float); -REGISTER_KERNEL(Eigen::half>); +REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL +} // namespace } // namespace tensorflow diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index 22cac8622..719b55ef1 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -29,6 +29,11 @@ limitations under the License. #include "tensorflow/core/framework/op.h" namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +namespace{ Status Conv2dInferShape(shape_inference::InferenceContext *c) { std::string dt_format; const std::set kVaildFormat = {"NHWC", "NCHW"}; @@ -105,8 +110,8 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { if (str_h < 1 || str_w < 1) { return errors::InvalidArgument("Stride must be > 0"); } - - int32 pads; + + std::vector pads; TF_RETURN_IF_ERROR(c->GetAttr("pads", &pads)); if (pads.size() != 4) { return errors::InvalidArgument("Pads list should be 4D, actual is: ", @@ -114,7 +119,7 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { } int64 pad_t = pads[0]; int64 pad_b = pads[1]; - int64 pad_l = pads[2]; + int64 pad_l = pads[2]; int64 pad_r = pads[3]; if (pad_t < 0 || pad_b < 0 || pad_l < 0 || pad_r < 0) { return errors::InvalidArgument("Pad must be > 0"); @@ -127,8 +132,8 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { int64 out_h = (x_h + pad_t + pad_b - dil_filter_h) / str_h + 1; int64 out_w = (x_w + pad_l + pad_r - dil_filter_w) / str_w + 1; if (out_h < 0 || out_w < 0) { - return errors::InvalidArgument("Image size after padding should not be - smaller than filter size after dilation"); + return errors::InvalidArgument("Image size after padding should not be " + "smaller than filter size after dilation"); } DimensionHandle out_h_dim = c->MakeDim(out_h); @@ -137,12 +142,12 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { out_dims[pos_n] = x_n_dim; out_dims[pos_c] = filter_n_dim; out_dims[pos_h] = out_h_dim; - out_dims[pos_w] = out_w_dim + out_dims[pos_w] = out_w_dim; c->set_output(0, c->MakeShape(out_dims)); return Status::OK(); } -REGISTER_OP("DeforableConv2D") +REGISTER_OP("DeformableConv2D") .Input("x: T") .Input("filter: T") .Input("offsets: T") @@ -156,4 +161,5 @@ REGISTER_OP("DeforableConv2D") .Attr("groups: int = 1") .Attr("deformable_groups: int = 1") .SetShapeFn(Conv2dInferShape); -} // namespace tensorflow \ No newline at end of file +} // namespace +} // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py index 9cdf3e08a..5ed98b420 100644 --- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -11,14 +11,14 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# +# # Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at -# +# # http://www.apache.org/licenses/LICENSE-2.0 -# +# # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -- Gitee From 4b10777813cc958eabe77b34d079e26ef469345b Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Tue, 10 Nov 2020 11:56:03 +0800 Subject: [PATCH 03/13] Fix debug problems --- tf_adapter/ops/npu_cube_ops.cc | 2 +- tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py | 8 ++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index 719b55ef1..17e67833a 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -157,7 +157,7 @@ REGISTER_OP("DeformableConv2D") .Attr("strides: list(int) = [1,1,1,1]") .Attr("pads: list(int) = [0,0,0,0]") .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") - .Attr("dilations: int = 1") + .Attr("dilations: list(int) = [1,1,1,1]") .Attr("groups: int = 1") .Attr("deformable_groups: int = 1") .SetShapeFn(Conv2dInferShape); diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py index 5ed98b420..cb649b85e 100644 --- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -37,13 +37,13 @@ def deformable_conv2d(x, strides=[1,1,1,1], pads=[0,0,0,0], data_format='NHWC', - dilations, + dilations=[1,1,1,1], groups=1, deformable_groups=1, name=None): if context.executing_eagerly(): - raise RuntimeError("tf.deformable_conv2d() is not compatible with " - "eager execution.") + raise RuntimeError("tf.deformable_conv2d() is not compatible with " + "eager execution.") return gen_npu_ops.deformable_conv2d(x, filter, @@ -55,4 +55,4 @@ def deformable_conv2d(x, dilations, groups, deformable_groups, - name): + name) -- Gitee From 137d68e7abef2fed410dc44ffa0b1289b5db8377 Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Tue, 10 Nov 2020 15:45:56 +0800 Subject: [PATCH 04/13] Add op comment --- .../python/npu_bridge/tbe/npu_cube_ops.py | 92 ++++++++++++++----- 1 file changed, 69 insertions(+), 23 deletions(-) diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py index cb649b85e..70173b7b6 100644 --- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -30,29 +30,75 @@ from tensorflow.python.eager import context from npu_bridge.helper import helper gen_npu_ops = helper.get_gen_ops() -def deformable_conv2d(x, - filter, - offsets, - bias, - strides=[1,1,1,1], - pads=[0,0,0,0], - data_format='NHWC', - dilations=[1,1,1,1], - groups=1, - deformable_groups=1, - name=None): - if context.executing_eagerly(): + +def deformable_conv2d( + x, + filters, + offsets, + bias, + strides=(1, 1, 1, 1), + pads=(0, 0, 0, 0), + data_format='NHWC', + dilations=(1, 1, 1, 1), + groups=1, + deformable_groups=1, + name=None): + + r"""Computes a 2-D deformable convolution given 4-D `inputs`、`filters` and + `offsets` tensors. + + Inputs: + x: A 4D `Tensor` of input `image`. With the `data_format` `NHWC`, the + data is stored in the order of: [batch, in_height, in_width, + in_channels]. Must be one of the following types: `float16`, `float32`. + filters: A 4D `Tensor` of learnable filters. Must have the same type as + `x`. The data is stored in the order of: `[filter_height, filter_width, + in_channels / groups, out_channels]`. + offsets: A 4D `Tensor` of x y coordinates offset and mask. With the + `data_format` `NHWC`, the data is stored in the order of: `[in_height, + in_width, deformable_groups * filter_height * filter_width * 3]`. Must + be one of the following types: `float16`, `float32`. + + Attributes: + strides: Required. An list of `4` `ints`. The stride of the sliding + window for each dimension of `image`. The dimension order is + interpreted according to the value of `data_format`. The `N` and `C` + dimensions must be set to 1. + pads: Required. An list of `4` `ints`. The number of pixels to add to + each `(pad_top, pad_bottom, pad_left, pad_right)` side of the `image`. + data_format: Optional. A `string` from: `"NHWC", "NCHW"`. Specify the + data format of the input and output data. Defaults to `"NHWC"`. + dilations: Optional. An list of `4` `ints`. The dilation factor for each + dimension of `image`. The dimension order is interpreted according to + the value of `data_format`. The `N` and `C` dimensions must be set to + 1. Defaults to `(1, 1, 1, 1)`. + groups: Optional. An `int`. The number of blocked connections from + `in_channels` to `out_channels`. `In_channels` and `out_channels` must + both be divisible by `groups`. Defaults to 1. + deformable_groups: Optional. An `int`. The number of deformable group + partitions. `In_channels` must be divisible by `deformable_groups`. + Defaults to 1. + name: Optional. A name for the operation. + + Returns: + A 4D `Tensor` of output feature map. Has the same type as `x`. With the + `data_format` `NHWC`, the data is stored in the order of: `[batch, + out_height, out_width, out_channels]`. + """ + + if context.executing_eagerly(): raise RuntimeError("tf.deformable_conv2d() is not compatible with " "eager execution.") - return gen_npu_ops.deformable_conv2d(x, - filter, - offsets, - bias, - strides, - pads, - data_format, - dilations, - groups, - deformable_groups, - name) + return gen_npu_ops.deformable_conv2d( + x=x, + filter=filters, + offsets=offsets, + bias=bias, + strides=strides, + pads=pads, + data_format=data_format, + dilations=dilations, + groups=groups, + deformable_groups=deformable_groups, + name=name) -- Gitee From c66d046d6c23856c5080d8dfa27e505737186d68 Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Tue, 10 Nov 2020 16:26:11 +0800 Subject: [PATCH 05/13] Fix indentation error --- tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py index 70173b7b6..3b4697cb1 100644 --- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -44,7 +44,7 @@ def deformable_conv2d( deformable_groups=1, name=None): - r"""Computes a 2-D deformable convolution given 4-D `inputs`、`filters` and + """Computes a 2-D deformable convolution given 4-D `inputs`、`filters` and `offsets` tensors. Inputs: @@ -86,7 +86,7 @@ def deformable_conv2d( out_height, out_width, out_channels]`. """ - if context.executing_eagerly(): + if context.executing_eagerly(): raise RuntimeError("tf.deformable_conv2d() is not compatible with " "eager execution.") -- Gitee From 029af1326ad39e485c734f09cd9c5fc928d6c18f Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Tue, 10 Nov 2020 17:21:03 +0800 Subject: [PATCH 06/13] Remove un-ascii symbol --- tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py index 3b4697cb1..82319d102 100644 --- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -44,7 +44,7 @@ def deformable_conv2d( deformable_groups=1, name=None): - """Computes a 2-D deformable convolution given 4-D `inputs`、`filters` and + """Computes a 2-D deformable convolution given 4-D `inputs`, `filters` and `offsets` tensors. Inputs: -- Gitee From 35c46a06a5be194f338a960b2d66f7846bd31b84 Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Wed, 11 Nov 2020 09:45:35 +0800 Subject: [PATCH 07/13] Change deformable_conv2d manual --- tf_adapter/ops/npu_cube_ops.cc | 11 ++++---- .../python/npu_bridge/tbe/npu_cube_ops.py | 28 +++++++++++++------ 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index 17e67833a..7c6e9e990 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -73,18 +73,19 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { if (c->ValueKnown(x_c_dim) && c->ValueKnown(filter_c_dim)) { int64 x_c = c->Value(x_c_dim); int64 filter_c = c->Value(filter_c_dim); - if (x_c != filter_c * groups) + if (x_c != filter_c * groups) { return errors::InvalidArgument( "In_channels (", x_c, ") should be equal to filter channels (", filter_c, ")* groups (", groups, ")"); - + } if (c->ValueKnown(filter_n_dim)) { int64 filter_n = c->Value(filter_n_dim); - if (filter_n % groups != 0) + if (filter_n % groups != 0) { return errors::InvalidArgument( "Out_channels (", filter_n, ") should be divisiable by groups (", groups, ")"); + } } } @@ -154,8 +155,8 @@ REGISTER_OP("DeformableConv2D") .Input("bias: T") .Output("y: T") .Attr("T: {float16, float32}") - .Attr("strides: list(int) = [1,1,1,1]") - .Attr("pads: list(int) = [0,0,0,0]") + .Attr("strides: list(int)") + .Attr("pads: list(int)") .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") .Attr("dilations: list(int) = [1,1,1,1]") .Attr("groups: int = 1") diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py index 82319d102..824dc03fa 100644 --- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -31,20 +31,20 @@ from npu_bridge.helper import helper gen_npu_ops = helper.get_gen_ops() -def deformable_conv2d( +def deformable_conv2d( # pylint: disable=redefined-builtin x, - filters, + filter, offsets, - bias, - strides=(1, 1, 1, 1), - pads=(0, 0, 0, 0), + bias=None, + strides=None, + pads=None, data_format='NHWC', dilations=(1, 1, 1, 1), groups=1, deformable_groups=1, name=None): - """Computes a 2-D deformable convolution given 4-D `inputs`, `filters` and + """Computes a 2-D deformable convolution given 4-D `x`, `filters` and `offsets` tensors. Inputs: @@ -58,6 +58,9 @@ def deformable_conv2d( `data_format` `NHWC`, the data is stored in the order of: `[in_height, in_width, deformable_groups * filter_height * filter_width * 3]`. Must be one of the following types: `float16`, `float32`. + bias: A optional 1D `Tensor` of additive biased to the filter outputs. + The data is stored in the order of: `[out_channels]`. Must be one of + the following types: `float16`, `float32`. Attributes: strides: Required. An list of `4` `ints`. The stride of the sliding @@ -66,8 +69,8 @@ def deformable_conv2d( dimensions must be set to 1. pads: Required. An list of `4` `ints`. The number of pixels to add to each `(pad_top, pad_bottom, pad_left, pad_right)` side of the `image`. - data_format: Optional. A `string` from: `"NHWC", "NCHW"`. Specify the - data format of the input and output data. Defaults to `"NHWC"`. + data_format: Optional. A `string` from: `NHWC`, `NCHW`. Specify the + data format of the input and output data. Defaults to `NHWC`. dilations: Optional. An list of `4` `ints`. The dilation factor for each dimension of `image`. The dimension order is interpreted according to the value of `data_format`. The `N` and `C` dimensions must be set to @@ -84,6 +87,13 @@ def deformable_conv2d( A 4D `Tensor` of output feature map. Has the same type as `x`. With the `data_format` `NHWC`, the data is stored in the order of: `[batch, out_height, out_width, out_channels]`. + + out_height = (in_height + pad_top + pad_bottom) - + (dilation_h * filter_height - 1) + 1)) + / stride_h + 1 + out_width = (in_width + pad_left + pad_right) - + (dilation_w * filter_width - 1) + 1)) + / stride_w + 1 """ if context.executing_eagerly(): @@ -92,7 +102,7 @@ def deformable_conv2d( return gen_npu_ops.deformable_conv2d( x=x, - filter=filters, + filter=filter, offsets=offsets, bias=bias, strides=strides, -- Gitee From 3888a4e33224751293e7b28923c8a0816b70a725 Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Wed, 11 Nov 2020 16:27:00 +0800 Subject: [PATCH 08/13] Add input check --- tf_adapter/ops/npu_cube_ops.cc | 145 +++++++++++++----- .../python/npu_bridge/tbe/npu_cube_ops.py | 6 +- 2 files changed, 112 insertions(+), 39 deletions(-) diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index 7c6e9e990..96d184e75 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -54,9 +54,18 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { TF_RETURN_IF_ERROR(c->WithRank(c->input(0), rank, &x_shape)); DimensionHandle x_n_dim = c->Dim(x_shape, pos_n); DimensionHandle x_c_dim = c->Dim(x_shape, pos_c); - DimensionHandle x_h_dim = c->Dim(x_shape, pos_c); - DimensionHandle x_w_dim = c->Dim(x_shape, pos_c); - + DimensionHandle x_h_dim = c->Dim(x_shape, pos_h); + DimensionHandle x_w_dim = c->Dim(x_shape, pos_w); + if (!c->ValueKnown(x_n_dim) || + !c->ValueKnown(x_c_dim) || + !c->ValueKnown(x_h_dim) || + !c->ValueKnown(x_w_dim)) { + return errors::InvalidArgument("Invalid x shape value"); + } + int64 x_n = c->Value(x_n_dim); + int64 x_c = c->Value(x_c_dim); + int64 x_h = c->Value(x_h_dim); + int64 x_w = c->Value(x_w_dim); // Filter format is always HWCN ShapeHandle filter_shape; TF_RETURN_IF_ERROR(c->WithRank(c->input(1), rank, &filter_shape)); @@ -64,77 +73,141 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { DimensionHandle filter_w_dim = c->Dim(filter_shape, 1); DimensionHandle filter_c_dim = c->Dim(filter_shape, 2); DimensionHandle filter_n_dim = c->Dim(filter_shape, 3); + if (!c->ValueKnown(filter_h_dim) || + !c->ValueKnown(filter_w_dim) || + !c->ValueKnown(filter_c_dim) || + !c->ValueKnown(filter_n_dim)) { + return errors::InvalidArgument("Invalid filter shape value"); + } + int64 filter_n = c->Value(filter_n_dim); + int64 filter_c = c->Value(filter_c_dim); + int64 filter_h = c->Value(filter_h_dim); + int64 filter_w = c->Value(filter_w_dim); + ShapeHandle offsets_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(2), rank, &offsets_shape)); + DimensionHandle offsets_n_dim = c->Dim(offsets_shape, pos_n); + DimensionHandle offsets_c_dim = c->Dim(offsets_shape, pos_c); + DimensionHandle offsets_h_dim = c->Dim(offsets_shape, pos_h); + DimensionHandle offsets_w_dim = c->Dim(offsets_shape, pos_w); + if (!c->ValueKnown(offsets_n_dim) || + !c->ValueKnown(offsets_c_dim) || + !c->ValueKnown(offsets_h_dim) || + !c->ValueKnown(offsets_w_dim)) { + return errors::InvalidArgument("Invalid offsets shape value"); + } + int64 offsets_n = c->Value(offsets_n_dim); + int64 offsets_c = c->Value(offsets_c_dim); + int64 offsets_h = c->Value(offsets_h_dim); + int64 offsets_w = c->Value(offsets_w_dim); + if (offsets_n != x_n) { + return errors::InvalidArgument( + "Offsets batch size (", offsets_n, + ") should be equal to x (", x_n, ")"); + } + if (offsets_h != x_h) { + return errors::InvalidArgument( + "Offsets height (", offsets_h, + ") should be equal to x (", x_h, ")"); + } + if (offsets_w != x_w) { + return errors::InvalidArgument( + "Offsets width (", offsets_w, + ") should be equal to x (", x_w, ")"); + } + int32 dfm_groups; + TF_RETURN_IF_ERROR(c->GetAttr("deformable_groups", &groups)); + if (dfm_groups < 1) { + return errors::InvalidArgument("Deformable_groups should be >= 1"); + } + int64 exp_offset_c = x_w * x_w * dfm_groups * 3; + if (offsets_c != x_w) { + return errors::InvalidArgument( + "Offsets channels (", offsets_w, + ") should be equal to ", exp_offset_c); + } + if (c->num_inputs() == 4) { + ShapeHandle bias_shape; + TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &bias_shape)); + DimensionHandle bias_n_dim = c->Dim(bias_shape, 0); + if (c->ValueKnown(bias_n_dim) && + c->Value(bias_n_dim) != filter_n) { + return errors::InvalidArgument( + "Bias size (", c->Value(bias_n_dim), + ") should be equal to out_channels (", filter_n, ")"); + } + } int32 groups; TF_RETURN_IF_ERROR(c->GetAttr("groups", &groups)); if (groups < 1) { - return errors::InvalidArgument("Groups must be >= 1"); + return errors::InvalidArgument("Groups should be >= 1"); } - if (c->ValueKnown(x_c_dim) && c->ValueKnown(filter_c_dim)) { - int64 x_c = c->Value(x_c_dim); - int64 filter_c = c->Value(filter_c_dim); - if (x_c != filter_c * groups) { - return errors::InvalidArgument( - "In_channels (", x_c, - ") should be equal to filter channels (", filter_c, - ")* groups (", groups, ")"); - } - if (c->ValueKnown(filter_n_dim)) { - int64 filter_n = c->Value(filter_n_dim); - if (filter_n % groups != 0) { - return errors::InvalidArgument( - "Out_channels (", filter_n, - ") should be divisiable by groups (", groups, ")"); - } - } + if (x_c != filter_c * groups) { + return errors::InvalidArgument( + "In_channels (", x_c, + ") should be equal to filter channels (", filter_c, + ") * groups (", groups, ")"); + } + if (x_c % dfm_groups != 0) { + return errors::InvalidArgument( + "In_channels (", filter_n, + ") should be divisiable by deformable_groups (", dfm_groups, ")"); + } + int64 filter_n = c->Value(filter_n_dim); + if (filter_n % groups != 0) { + return errors::InvalidArgument( + "Out_channels (", filter_n, + ") should be divisiable by groups (", groups, ")"); } std::vector dilations; TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); if (dilations.size() != 4) { - return errors::InvalidArgument("Dilations list should be 4D, actual is: ", - dilations.size()); + return errors::InvalidArgument( + "Dilations attribute should contain 4 values, but got: ", + dilations.size()); } const int32 dil_h = dilations[pos_h]; const int32 dil_w = dilations[pos_w]; if (dil_h < 1 || dil_w < 1) { - return errors::InvalidArgument("Dilation rate must be >= 1"); + return errors::InvalidArgument("Dilation rate should be >= 1"); } std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 4) { - return errors::InvalidArgument("Strides list should be 4D, actual is: ", - strides.size()); + return errors::InvalidArgument( + "Strides attribute should contain 4 values, but got: ", + strides.size()); } const int32 str_h = strides[pos_h]; const int32 str_w = strides[pos_w]; if (str_h < 1 || str_w < 1) { - return errors::InvalidArgument("Stride must be > 0"); + return errors::InvalidArgument("Stride should be > 0"); } std::vector pads; TF_RETURN_IF_ERROR(c->GetAttr("pads", &pads)); if (pads.size() != 4) { - return errors::InvalidArgument("Pads list should be 4D, actual is: ", - pads.size()); + return errors::InvalidArgument( + "Pads attribute should contain 4 values, but got: ", + pads.size()); } int64 pad_t = pads[0]; int64 pad_b = pads[1]; int64 pad_l = pads[2]; int64 pad_r = pads[3]; if (pad_t < 0 || pad_b < 0 || pad_l < 0 || pad_r < 0) { - return errors::InvalidArgument("Pad must be > 0"); + return errors::InvalidArgument("Pad should be > 0"); } - int64 dil_filter_h = dil_h * (c->Value(filter_h_dim) - 1) + 1; - int64 dil_filter_w = dil_w * (c->Value(filter_w_dim) - 1) + 1; - int64 x_h = c->Value(x_h_dim); - int64 x_w = c->Value(x_w_dim); + int64 dil_filter_h = dil_h * (filter_h - 1) + 1; + int64 dil_filter_w = dil_w * (filter_w - 1) + 1; int64 out_h = (x_h + pad_t + pad_b - dil_filter_h) / str_h + 1; int64 out_w = (x_w + pad_l + pad_r - dil_filter_w) / str_w + 1; if (out_h < 0 || out_w < 0) { - return errors::InvalidArgument("Image size after padding should not be " - "smaller than filter size after dilation"); + return errors::InvalidArgument( + "Image size after padding should not be smaller than filter size " + "after dilation"); } DimensionHandle out_h_dim = c->MakeDim(out_h); diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py index 824dc03fa..09529210b 100644 --- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -55,9 +55,9 @@ def deformable_conv2d( # pylint: disable=redefined-builtin `x`. The data is stored in the order of: `[filter_height, filter_width, in_channels / groups, out_channels]`. offsets: A 4D `Tensor` of x y coordinates offset and mask. With the - `data_format` `NHWC`, the data is stored in the order of: `[in_height, - in_width, deformable_groups * filter_height * filter_width * 3]`. Must - be one of the following types: `float16`, `float32`. + `data_format` `NHWC`, the data is stored in the order of: `[batch, + in_height, in_width, deformable_groups * filter_height * filter_width + * 3]`. Must be one of the following types: `float16`, `float32`. bias: A optional 1D `Tensor` of additive biased to the filter outputs. The data is stored in the order of: `[out_channels]`. Must be one of the following types: `float16`, `float32`. -- Gitee From 1a13830b22219e4e8a7b808c7dfaafb41e7e823f Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Wed, 11 Nov 2020 17:25:38 +0800 Subject: [PATCH 09/13] Fix error variable --- tf_adapter/ops/npu_cube_ops.cc | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index 96d184e75..208b02632 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -115,14 +115,14 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { ") should be equal to x (", x_w, ")"); } int32 dfm_groups; - TF_RETURN_IF_ERROR(c->GetAttr("deformable_groups", &groups)); + TF_RETURN_IF_ERROR(c->GetAttr("deformable_groups", &dfm_groups)); if (dfm_groups < 1) { return errors::InvalidArgument("Deformable_groups should be >= 1"); } - int64 exp_offset_c = x_w * x_w * dfm_groups * 3; - if (offsets_c != x_w) { + int64 exp_offset_c = dfm_groups * filter_h * filter_w * 3; + if (offsets_c != exp_offset_c) { return errors::InvalidArgument( - "Offsets channels (", offsets_w, + "Offsets channels (", offsets_c, ") should be equal to ", exp_offset_c); } if (c->num_inputs() == 4) { @@ -150,10 +150,9 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { } if (x_c % dfm_groups != 0) { return errors::InvalidArgument( - "In_channels (", filter_n, + "In_channels (", x_c, ") should be divisiable by deformable_groups (", dfm_groups, ")"); } - int64 filter_n = c->Value(filter_n_dim); if (filter_n % groups != 0) { return errors::InvalidArgument( "Out_channels (", filter_n, @@ -210,13 +209,11 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { "after dilation"); } - DimensionHandle out_h_dim = c->MakeDim(out_h); - DimensionHandle out_w_dim = c->MakeDim(out_w); std::vector out_dims(rank); out_dims[pos_n] = x_n_dim; out_dims[pos_c] = filter_n_dim; - out_dims[pos_h] = out_h_dim; - out_dims[pos_w] = out_w_dim; + out_dims[pos_h] = c->MakeDim(out_h); + out_dims[pos_w] = c->MakeDim(out_w); c->set_output(0, c->MakeShape(out_dims)); return Status::OK(); } -- Gitee From f864fb32b61a1509f62fabc8ef2359e15e10308e Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Wed, 11 Nov 2020 17:30:00 +0800 Subject: [PATCH 10/13] Remove unnecessary whitespace --- tf_adapter/ops/npu_cube_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index 208b02632..bc76adbf9 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -201,8 +201,8 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { int64 dil_filter_h = dil_h * (filter_h - 1) + 1; int64 dil_filter_w = dil_w * (filter_w - 1) + 1; - int64 out_h = (x_h + pad_t + pad_b - dil_filter_h) / str_h + 1; - int64 out_w = (x_w + pad_l + pad_r - dil_filter_w) / str_w + 1; + int64 out_h = (x_h + pad_t + pad_b - dil_filter_h) / str_h + 1; + int64 out_w = (x_w + pad_l + pad_r - dil_filter_w) / str_w + 1; if (out_h < 0 || out_w < 0) { return errors::InvalidArgument( "Image size after padding should not be smaller than filter size " -- Gitee From ac9ec10c01b9b517331e0b2616185160c8932956 Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Wed, 11 Nov 2020 18:24:58 +0800 Subject: [PATCH 11/13] Add more print info --- tf_adapter/ops/npu_cube_ops.cc | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index bc76adbf9..d357d3f44 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -171,6 +171,10 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { if (dil_h < 1 || dil_w < 1) { return errors::InvalidArgument("Dilation rate should be >= 1"); } + if (dilations[pos_n] != 1 || dilations[pos_c] != 1) { + return errors::InvalidArgument( + "Dilations N and C dimensions must be set to 1"); + } std::vector strides; TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides)); if (strides.size() != 4) { @@ -181,9 +185,12 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { const int32 str_h = strides[pos_h]; const int32 str_w = strides[pos_w]; if (str_h < 1 || str_w < 1) { - return errors::InvalidArgument("Stride should be > 0"); + return errors::InvalidArgument("Stride should be >= 1"); + } + if (strides[pos_n] != 1 || strides[pos_c] != 1) { + return errors::InvalidArgument( + "Stride N and C dimensions must be set to 1"); } - std::vector pads; TF_RETURN_IF_ERROR(c->GetAttr("pads", &pads)); if (pads.size() != 4) { @@ -196,17 +203,20 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { int64 pad_l = pads[2]; int64 pad_r = pads[3]; if (pad_t < 0 || pad_b < 0 || pad_l < 0 || pad_r < 0) { - return errors::InvalidArgument("Pad should be > 0"); + return errors::InvalidArgument("Pad should be >= 0"); } int64 dil_filter_h = dil_h * (filter_h - 1) + 1; int64 dil_filter_w = dil_w * (filter_w - 1) + 1; - int64 out_h = (x_h + pad_t + pad_b - dil_filter_h) / str_h + 1; - int64 out_w = (x_w + pad_l + pad_r - dil_filter_w) / str_w + 1; + int64 pad_image_h = x_h + pad_t + pad_b; + int64 pad_image_w = x_w + pad_l + pad_r; + int64 out_h = (pad_image_h - dil_filter_h) / str_h + 1; + int64 out_w = (pad_image_w - dil_filter_w) / str_w + 1; if (out_h < 0 || out_w < 0) { return errors::InvalidArgument( - "Image size after padding should not be smaller than filter size " - "after dilation"); + "Image size after padding (", pad_image_h, " ", pad_image_w, + ") should not be smaller than filter size after dilation (", + dil_filter_h, " ", dil_filter_w,")"); } std::vector out_dims(rank); -- Gitee From a57a91cca642fe76497c270ee5029a2d494de88b Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Wed, 11 Nov 2020 18:44:16 +0800 Subject: [PATCH 12/13] Change attr check order --- tf_adapter/ops/npu_cube_ops.cc | 47 +++++++++++++++++----------------- 1 file changed, 23 insertions(+), 24 deletions(-) diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index d357d3f44..c04130fe9 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -114,11 +114,32 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { "Offsets width (", offsets_w, ") should be equal to x (", x_w, ")"); } + int32 groups; + TF_RETURN_IF_ERROR(c->GetAttr("groups", &groups)); + if (groups < 1) { + return errors::InvalidArgument("Groups should be >= 1"); + } + if (x_c != filter_c * groups) { + return errors::InvalidArgument( + "In_channels (", x_c, + ") should be equal to filter channels (", filter_c, + ") * groups (", groups, ")"); + } + if (filter_n % groups != 0) { + return errors::InvalidArgument( + "Out_channels (", filter_n, + ") should be divisiable by groups (", groups, ")"); + } int32 dfm_groups; TF_RETURN_IF_ERROR(c->GetAttr("deformable_groups", &dfm_groups)); if (dfm_groups < 1) { return errors::InvalidArgument("Deformable_groups should be >= 1"); } + if (x_c % dfm_groups != 0) { + return errors::InvalidArgument( + "In_channels (", x_c, + ") should be divisiable by deformable_groups (", dfm_groups, ")"); + } int64 exp_offset_c = dfm_groups * filter_h * filter_w * 3; if (offsets_c != exp_offset_c) { return errors::InvalidArgument( @@ -137,28 +158,6 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { } } - int32 groups; - TF_RETURN_IF_ERROR(c->GetAttr("groups", &groups)); - if (groups < 1) { - return errors::InvalidArgument("Groups should be >= 1"); - } - if (x_c != filter_c * groups) { - return errors::InvalidArgument( - "In_channels (", x_c, - ") should be equal to filter channels (", filter_c, - ") * groups (", groups, ")"); - } - if (x_c % dfm_groups != 0) { - return errors::InvalidArgument( - "In_channels (", x_c, - ") should be divisiable by deformable_groups (", dfm_groups, ")"); - } - if (filter_n % groups != 0) { - return errors::InvalidArgument( - "Out_channels (", filter_n, - ") should be divisiable by groups (", groups, ")"); - } - std::vector dilations; TF_RETURN_IF_ERROR(c->GetAttr("dilations", &dilations)); if (dilations.size() != 4) { @@ -214,9 +213,9 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { int64 out_w = (pad_image_w - dil_filter_w) / str_w + 1; if (out_h < 0 || out_w < 0) { return errors::InvalidArgument( - "Image size after padding (", pad_image_h, " ", pad_image_w, + "Image size after padding (", pad_image_h, ", ", pad_image_w, ") should not be smaller than filter size after dilation (", - dil_filter_h, " ", dil_filter_w,")"); + dil_filter_h, ", ", dil_filter_w,")"); } std::vector out_dims(rank); -- Gitee From ab946081fb7cf8d8f2134857265f4a3d85c8506c Mon Sep 17 00:00:00 2001 From: wangchonghui Date: Thu, 12 Nov 2020 20:16:38 +0800 Subject: [PATCH 13/13] Add support for optional bias --- tf_adapter/kernels/deformable_conv2d_ops.cc | 40 +++++++++++++++---- tf_adapter/ops/npu_cube_ops.cc | 16 +++++++- .../python/npu_bridge/tbe/npu_cube_ops.py | 40 +++++++++++++------ 3 files changed, 73 insertions(+), 23 deletions(-) diff --git a/tf_adapter/kernels/deformable_conv2d_ops.cc b/tf_adapter/kernels/deformable_conv2d_ops.cc index 15b202704..00ac592b3 100644 --- a/tf_adapter/kernels/deformable_conv2d_ops.cc +++ b/tf_adapter/kernels/deformable_conv2d_ops.cc @@ -28,11 +28,10 @@ limitations under the License. #include "tensorflow/core/framework/op_kernel.h" namespace tensorflow { -namespace{ template -class DeformableConv2DOp : public tensorflow::OpKernel { +class DeformableConv2DOp : public OpKernel { public: - explicit DeformableConv2DOp(tensorflow::OpKernelConstruction *context) + explicit DeformableConv2DOp(OpKernelConstruction *context) : OpKernel(context) { LOG(INFO) << "new DeformableConv2DOp"; } @@ -44,13 +43,38 @@ class DeformableConv2DOp : public tensorflow::OpKernel { bool IsExpensive() override { return false; } }; -#define REGISTER_KERNEL(type) \ -REGISTER_KERNEL_BUILDER(Name("DeformableConv2D") \ - .Device(tensorflow::DEVICE_CPU) \ - .TypeConstraint("T"), \ +#define REGISTER_KERNEL(type) \ +REGISTER_KERNEL_BUILDER(Name("DeformableConv2D") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ DeformableConv2DOp) REGISTER_KERNEL(float); REGISTER_KERNEL(Eigen::half); #undef REGISTER_KERNEL -} // namespace + +template +class DeformableConv2DWithBiasOp : public OpKernel { + public: + explicit DeformableConv2DWithBiasOp(OpKernelConstruction *context) + : OpKernel(context) { + LOG(INFO) << "new DeformableConv2DWithBiasOp"; + } + ~DeformableConv2DWithBiasOp() override { + LOG(INFO) << "del DeformableConv2DWithBiasOp"; + } + void Compute(OpKernelContext *context) override { + LOG(INFO) << "DeformableConv2DWithBiasOp Compute, num_inputs: " + << context->num_inputs(); + } + bool IsExpensive() override { return false; } +}; + +#define REGISTER_KERNEL(type) \ +REGISTER_KERNEL_BUILDER(Name("DeformableConv2DWithBias") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableConv2DWithBiasOp) +REGISTER_KERNEL(float); +REGISTER_KERNEL(Eigen::half); +#undef REGISTER_KERNEL } // namespace tensorflow diff --git a/tf_adapter/ops/npu_cube_ops.cc b/tf_adapter/ops/npu_cube_ops.cc index c04130fe9..990b45024 100644 --- a/tf_adapter/ops/npu_cube_ops.cc +++ b/tf_adapter/ops/npu_cube_ops.cc @@ -33,7 +33,6 @@ using shape_inference::DimensionHandle; using shape_inference::InferenceContext; using shape_inference::ShapeHandle; -namespace{ Status Conv2dInferShape(shape_inference::InferenceContext *c) { std::string dt_format; const std::set kVaildFormat = {"NHWC", "NCHW"}; @@ -228,6 +227,20 @@ Status Conv2dInferShape(shape_inference::InferenceContext *c) { } REGISTER_OP("DeformableConv2D") + .Input("x: T") + .Input("filter: T") + .Input("offsets: T") + .Output("y: T") + .Attr("T: {float16, float32}") + .Attr("strides: list(int)") + .Attr("pads: list(int)") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .Attr("dilations: list(int) = [1,1,1,1]") + .Attr("groups: int = 1") + .Attr("deformable_groups: int = 1") + .SetShapeFn(Conv2dInferShape); + +REGISTER_OP("DeformableConv2DWithBias") .Input("x: T") .Input("filter: T") .Input("offsets: T") @@ -241,5 +254,4 @@ REGISTER_OP("DeformableConv2D") .Attr("groups: int = 1") .Attr("deformable_groups: int = 1") .SetShapeFn(Conv2dInferShape); -} // namespace } // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py index 09529210b..b9966a3c9 100644 --- a/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py +++ b/tf_adapter/python/npu_bridge/tbe/npu_cube_ops.py @@ -26,6 +26,7 @@ # limitations under the License. # ============================================================================== """Ops for aicore cube.""" +from tensorflow import Tensor from tensorflow.python.eager import context from npu_bridge.helper import helper gen_npu_ops = helper.get_gen_ops() @@ -43,7 +44,6 @@ def deformable_conv2d( # pylint: disable=redefined-builtin groups=1, deformable_groups=1, name=None): - """Computes a 2-D deformable convolution given 4-D `x`, `filters` and `offsets` tensors. @@ -100,15 +100,29 @@ def deformable_conv2d( # pylint: disable=redefined-builtin raise RuntimeError("tf.deformable_conv2d() is not compatible with " "eager execution.") - return gen_npu_ops.deformable_conv2d( - x=x, - filter=filter, - offsets=offsets, - bias=bias, - strides=strides, - pads=pads, - data_format=data_format, - dilations=dilations, - groups=groups, - deformable_groups=deformable_groups, - name=name) + if isinstance(bias, Tensor): + op_res = gen_npu_ops.deformable_conv2d_with_bias( + x=x, + filter=filter, + offsets=offsets, + bias=bias, + strides=strides, + pads=pads, + data_format=data_format, + dilations=dilations, + groups=groups, + deformable_groups=deformable_groups, + name=name) + else: + op_res = gen_npu_ops.deformable_conv2d( + x=x, + filter=filter, + offsets=offsets, + strides=strides, + pads=pads, + data_format=data_format, + dilations=dilations, + groups=groups, + deformable_groups=deformable_groups, + name=name) + return op_res -- Gitee