diff --git a/tf_adapter/kernels/npu_cpu_ops.cc b/tf_adapter/kernels/npu_cpu_ops.cc index 72a0ae3b3dc316723a8680206bea993422398a9a..1f86d57645725d4710bb5063e2138d8023e1a91d 100644 --- a/tf_adapter/kernels/npu_cpu_ops.cc +++ b/tf_adapter/kernels/npu_cpu_ops.cc @@ -1,199 +1,239 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "tensorflow/core/framework/bounds_check.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/register_types.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/resource_op_kernel.h" -#include "tf_adapter/common/adp_logger.h" -#include "tf_adapter/util/cache_interface.h" - -namespace tensorflow { -class EmbeddingRankIdOpKernel : public OpKernel { - public: - explicit EmbeddingRankIdOpKernel(OpKernelConstruction *context) : OpKernel(context) {} - ~EmbeddingRankIdOpKernel() {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingRankIdOp Compute."; } -}; - -class EmbeddingLocalIndexOpKernel : public OpKernel { - public: - explicit EmbeddingLocalIndexOpKernel(OpKernelConstruction *context) : OpKernel(context) {} - ~EmbeddingLocalIndexOpKernel() {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingLocalIndexOp Compute."; } -}; - -class LruCacheOp : public ResourceOpKernel { - public: - explicit 
LruCacheOp(OpKernelConstruction* context) : ResourceOpKernel(context) {} - ~LruCacheOp() override {} - void Compute(OpKernelContext* context) override { ADP_LOG(INFO) << "LruCacheOp Compute"; } - private: - Status CreateResource(CacheInterface** resource) override - EXCLUSIVE_LOCKS_REQUIRED(mu_) { - return Status::OK(); - } -}; - -class CacheAddOp : public OpKernel { - public: - explicit CacheAddOp(OpKernelConstruction *context) : OpKernel(context) {} - ~CacheAddOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "CacheAddOp Compute"; } -}; - -class CacheRemoteIndexToLocalOp : public OpKernel { - public: - explicit CacheRemoteIndexToLocalOp(OpKernelConstruction *context) : OpKernel(context) {} - ~CacheRemoteIndexToLocalOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "CacheRemoteIndexToLocalOp Compute"; } -}; - -class CacheAllIndexToLocalOp : public OpKernel { - public: - explicit CacheAllIndexToLocalOp(OpKernelConstruction *context) : OpKernel(context) {} - ~CacheAllIndexToLocalOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "CacheAllIndexToLocalOp Compute"; } -}; - -template -class DeformableOffsetsOp : public OpKernel { - public: - explicit DeformableOffsetsOp(OpKernelConstruction *context) : OpKernel(context) {} - ~DeformableOffsetsOp() override {} - void Compute(OpKernelContext *context) override { - ADP_LOG(INFO) << "DeformableOffsetsOp Compute, num_inputs: " - << context->num_inputs(); - } - bool IsExpensive() override { return false; } -}; - -template -class DeformableOffsetsGradOp : public OpKernel { - public: - explicit DeformableOffsetsGradOp(OpKernelConstruction *context) : OpKernel(context) {} - ~DeformableOffsetsGradOp() override {} - void Compute(OpKernelContext *context) override { - ADP_LOG(INFO) << "DeformableOffsetsGradOp Compute, num_inputs: " - << context->num_inputs(); - } - bool IsExpensive() override { return false; } -}; - -class 
RandomChoiceWithMaskOp : public OpKernel { - public: - explicit RandomChoiceWithMaskOp(OpKernelConstruction *context) : OpKernel(context) {} - ~RandomChoiceWithMaskOp() override {} - void Compute(OpKernelContext *context) override { - ADP_LOG(INFO) << "RandomChoiceWithMaskOp Compute "; - } -}; - -template -class DenseImageWarpOp : public OpKernel { - public: - explicit DenseImageWarpOp(OpKernelConstruction *context) : OpKernel(context) {} - ~DenseImageWarpOp() override {} - void Compute(OpKernelContext *context) override {} - bool IsExpensive() override { return false; } -}; - -template -class DenseImageWarpGradOp : public OpKernel { - public: - explicit DenseImageWarpGradOp(OpKernelConstruction *context) : OpKernel(context) {} - ~DenseImageWarpGradOp() override {} - void Compute(OpKernelContext *context) override {} - bool IsExpensive() override { return false; } -}; - -class BatchEnqueueOp : public OpKernel { - public: - explicit BatchEnqueueOp(OpKernelConstruction *context) : OpKernel(context) {} - ~BatchEnqueueOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "BatchEnqueueOp Compute"; } -}; - -class OCRRecognitionPreHandleOp : public OpKernel { - public: - explicit OCRRecognitionPreHandleOp(OpKernelConstruction *context) : OpKernel(context) {} - ~OCRRecognitionPreHandleOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "OCRRecognitionPreHandleOp Compute"; } -}; - -class OCRDetectionPreHandleOp : public OpKernel { - public: - explicit OCRDetectionPreHandleOp(OpKernelConstruction *context) : OpKernel(context) {} - ~OCRDetectionPreHandleOp() override {} - void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "OCRDetectionPreHandleOp Compute"; } -}; - -class OCRIdentifyPreHandleOp : public OpKernel { - public: - explicit OCRIdentifyPreHandleOp(OpKernelConstruction *context) : OpKernel(context) {} - ~OCRIdentifyPreHandleOp() override {} - void Compute(OpKernelContext *context) 
override { ADP_LOG(INFO) << "OCRIdentifyPreHandleOp Compute"; } -}; - -REGISTER_KERNEL_BUILDER(Name("EmbeddingRankId").Device(DEVICE_CPU), EmbeddingRankIdOpKernel); -REGISTER_KERNEL_BUILDER(Name("EmbeddingLocalIndex").Device(DEVICE_CPU), EmbeddingLocalIndexOpKernel); -REGISTER_KERNEL_BUILDER(Name("LruCache").Device(DEVICE_CPU), LruCacheOp); -REGISTER_KERNEL_BUILDER(Name("CacheAdd").Device(DEVICE_CPU), CacheAddOp); -REGISTER_KERNEL_BUILDER(Name("CacheRemoteIndexToLocal").Device(DEVICE_CPU), CacheRemoteIndexToLocalOp); -REGISTER_KERNEL_BUILDER(Name("CacheAllIndexToLocal").Device(DEVICE_CPU), CacheAllIndexToLocalOp); -REGISTER_KERNEL_BUILDER(Name("RandomChoiceWithMask").Device(DEVICE_CPU), RandomChoiceWithMaskOp); -REGISTER_KERNEL_BUILDER(Name("BatchEnqueue").Device(DEVICE_CPU), BatchEnqueueOp); -REGISTER_KERNEL_BUILDER(Name("OCRRecognitionPreHandle").Device(DEVICE_CPU), OCRRecognitionPreHandleOp); -REGISTER_KERNEL_BUILDER(Name("OCRDetectionPreHandle").Device(DEVICE_CPU), OCRDetectionPreHandleOp); -REGISTER_KERNEL_BUILDER(Name("OCRIdentifyPreHandle").Device(DEVICE_CPU), OCRIdentifyPreHandleOp); - -#define REGISTER_KERNEL(type) \ -REGISTER_KERNEL_BUILDER(Name("DeformableOffsets") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - DeformableOffsetsOp) -REGISTER_KERNEL(float); -REGISTER_KERNEL(Eigen::half); -#undef REGISTER_KERNEL - -#define REGISTER_KERNEL(type) \ -REGISTER_KERNEL_BUILDER(Name("DeformableOffsetsGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - DeformableOffsetsGradOp) -REGISTER_KERNEL(float); -REGISTER_KERNEL(Eigen::half); -#undef REGISTER_KERNEL - -#define REGISTER_KERNEL(type) \ -REGISTER_KERNEL_BUILDER(Name("DenseImageWarp") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - DenseImageWarpOp) -REGISTER_KERNEL(float); -REGISTER_KERNEL(Eigen::half); -#undef REGISTER_KERNEL - -#define REGISTER_KERNEL(type) \ -REGISTER_KERNEL_BUILDER(Name("DenseImageWarpGrad") \ - .Device(DEVICE_CPU) \ - .TypeConstraint("T"), \ - DenseImageWarpGradOp) 
-REGISTER_KERNEL(float); -REGISTER_KERNEL(Eigen::half); -#undef REGISTER_KERNEL +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "tensorflow/core/framework/bounds_check.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/register_types.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/resource_op_kernel.h" +#include "tf_adapter/common/adp_logger.h" +#include "tf_adapter/util/cache_interface.h" + +namespace tensorflow { +class EmbeddingRankIdOpKernel : public OpKernel { + public: + explicit EmbeddingRankIdOpKernel(OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingRankIdOpKernel() {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingRankIdOp Compute."; } +}; + +class EmbeddingLocalIndexOpKernel : public OpKernel { + public: + explicit EmbeddingLocalIndexOpKernel(OpKernelConstruction *context) : OpKernel(context) {} + ~EmbeddingLocalIndexOpKernel() {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "EmbeddingLocalIndexOp Compute."; } +}; + +class LruCacheOp : public ResourceOpKernel { + public: + explicit LruCacheOp(OpKernelConstruction* context) : ResourceOpKernel(context) {} + ~LruCacheOp() override {} + void Compute(OpKernelContext* context) override { ADP_LOG(INFO) << "LruCacheOp Compute"; } + private: + Status 
CreateResource(CacheInterface** resource) override + EXCLUSIVE_LOCKS_REQUIRED(mu_) { + return Status::OK(); + } +}; + +class CacheAddOp : public OpKernel { + public: + explicit CacheAddOp(OpKernelConstruction *context) : OpKernel(context) {} + ~CacheAddOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "CacheAddOp Compute"; } +}; + +class CacheRemoteIndexToLocalOp : public OpKernel { + public: + explicit CacheRemoteIndexToLocalOp(OpKernelConstruction *context) : OpKernel(context) {} + ~CacheRemoteIndexToLocalOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "CacheRemoteIndexToLocalOp Compute"; } +}; + +class CacheAllIndexToLocalOp : public OpKernel { + public: + explicit CacheAllIndexToLocalOp(OpKernelConstruction *context) : OpKernel(context) {} + ~CacheAllIndexToLocalOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "CacheAllIndexToLocalOp Compute"; } +}; + +template +class DeformableOffsetsOp : public OpKernel { + public: + explicit DeformableOffsetsOp(OpKernelConstruction *context) : OpKernel(context) {} + ~DeformableOffsetsOp() override {} + void Compute(OpKernelContext *context) override { + ADP_LOG(INFO) << "DeformableOffsetsOp Compute, num_inputs: " + << context->num_inputs(); + } + bool IsExpensive() override { return false; } +}; + +template +class DeformableOffsetsGradOp : public OpKernel { + public: + explicit DeformableOffsetsGradOp(OpKernelConstruction *context) : OpKernel(context) {} + ~DeformableOffsetsGradOp() override {} + void Compute(OpKernelContext *context) override { + ADP_LOG(INFO) << "DeformableOffsetsGradOp Compute, num_inputs: " + << context->num_inputs(); + } + bool IsExpensive() override { return false; } +}; + +class RandomChoiceWithMaskOp : public OpKernel { + public: + explicit RandomChoiceWithMaskOp(OpKernelConstruction *context) : OpKernel(context) {} + ~RandomChoiceWithMaskOp() override {} + void Compute(OpKernelContext 
*context) override { + ADP_LOG(INFO) << "RandomChoiceWithMaskOp Compute "; + } +}; + +template +class DenseImageWarpOp : public OpKernel { + public: + explicit DenseImageWarpOp(OpKernelConstruction *context) : OpKernel(context) {} + ~DenseImageWarpOp() override {} + void Compute(OpKernelContext *context) override {} + bool IsExpensive() override { return false; } +}; + +template +class DenseImageWarpGradOp : public OpKernel { + public: + explicit DenseImageWarpGradOp(OpKernelConstruction *context) : OpKernel(context) {} + ~DenseImageWarpGradOp() override {} + void Compute(OpKernelContext *context) override {} + bool IsExpensive() override { return false; } +}; + +class BatchEnqueueOp : public OpKernel { + public: + explicit BatchEnqueueOp(OpKernelConstruction *context) : OpKernel(context) {} + ~BatchEnqueueOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "BatchEnqueueOp Compute"; } +}; + +class OCRRecognitionPreHandleOp : public OpKernel { + public: + explicit OCRRecognitionPreHandleOp(OpKernelConstruction *context) : OpKernel(context) {} + ~OCRRecognitionPreHandleOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "OCRRecognitionPreHandleOp Compute"; } +}; + +class OCRDetectionPreHandleOp : public OpKernel { + public: + explicit OCRDetectionPreHandleOp(OpKernelConstruction *context) : OpKernel(context) {} + ~OCRDetectionPreHandleOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "OCRDetectionPreHandleOp Compute"; } +}; + +class OCRIdentifyPreHandleOp : public OpKernel { + public: + explicit OCRIdentifyPreHandleOp(OpKernelConstruction *context) : OpKernel(context) {} + ~OCRIdentifyPreHandleOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "OCRIdentifyPreHandleOp Compute"; } +}; + +class BatchDilatePolysOp : public OpKernel { + public : + explicit BatchDilatePolysOp(OpKernelConstruction *context):OpKernel(context){} + 
~BatchDilatePolysOp() override{} + void Compute(OpKernelContext *context) override{ADP_LOG(INFO)<<"BatchDilatePolysOp Compute";} +}; + +class OCRFindContoursOp : public OpKernel { + public : + explicit OCRFindContoursOp(OpKernelConstruction *context):OpKernel(context){} + ~OCRFindContoursOp() override{} + void Compute(OpKernelContext *context) override{ADP_LOG(INFO)<<"OCRFindContoursOp Compute";} +}; + +class OCRDetectionPostHandleOp : public OpKernel { + public: + explicit OCRDetectionPostHandleOp(OpKernelConstruction *context) : OpKernel(context) {} + ~OCRDetectionPostHandleOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "OCRDetectionPostHandleOp Compute"; } +}; + +class ResizeAndClipPolysOp : public OpKernel { + public: + explicit ResizeAndClipPolysOp(OpKernelConstruction *context) : OpKernel(context) {} + ~ResizeAndClipPolysOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "ResizeAndClipPolysOp Compute"; } +}; + +class DequeueOp : public OpKernel { + public: + explicit DequeueOp(OpKernelConstruction *context) : OpKernel(context) {} + ~DequeueOp() override {} + void Compute(OpKernelContext *context) override { ADP_LOG(INFO) << "DequeueOp Compute"; } +}; + +REGISTER_KERNEL_BUILDER(Name("EmbeddingRankId").Device(DEVICE_CPU), EmbeddingRankIdOpKernel); +REGISTER_KERNEL_BUILDER(Name("EmbeddingLocalIndex").Device(DEVICE_CPU), EmbeddingLocalIndexOpKernel); +REGISTER_KERNEL_BUILDER(Name("LruCache").Device(DEVICE_CPU), LruCacheOp); +REGISTER_KERNEL_BUILDER(Name("CacheAdd").Device(DEVICE_CPU), CacheAddOp); +REGISTER_KERNEL_BUILDER(Name("CacheRemoteIndexToLocal").Device(DEVICE_CPU), CacheRemoteIndexToLocalOp); +REGISTER_KERNEL_BUILDER(Name("CacheAllIndexToLocal").Device(DEVICE_CPU), CacheAllIndexToLocalOp); +REGISTER_KERNEL_BUILDER(Name("RandomChoiceWithMask").Device(DEVICE_CPU), RandomChoiceWithMaskOp); +REGISTER_KERNEL_BUILDER(Name("BatchEnqueue").Device(DEVICE_CPU), BatchEnqueueOp); 
+REGISTER_KERNEL_BUILDER(Name("OCRRecognitionPreHandle").Device(DEVICE_CPU), OCRRecognitionPreHandleOp); +REGISTER_KERNEL_BUILDER(Name("OCRDetectionPreHandle").Device(DEVICE_CPU), OCRDetectionPreHandleOp); +REGISTER_KERNEL_BUILDER(Name("OCRIdentifyPreHandle").Device(DEVICE_CPU), OCRIdentifyPreHandleOp); +REGISTER_KERNEL_BUILDER(Name("BatchDilatePolys").Device(DEVICE_CPU), BatchDilatePolysOp); +REGISTER_KERNEL_BUILDER(Name("OCRFindContours").Device(DEVICE_CPU), OCRFindContoursOp); +REGISTER_KERNEL_BUILDER(Name("OCRDetectionPostHandle").Device(DEVICE_CPU), OCRDetectionPostHandleOp); +REGISTER_KERNEL_BUILDER(Name("ResizeAndClipPolys").Device(DEVICE_CPU), ResizeAndClipPolysOp); +REGISTER_KERNEL_BUILDER(Name("Dequeue").Device(DEVICE_CPU), DequeueOp); + +#define REGISTER_KERNEL(type) \ +REGISTER_KERNEL_BUILDER(Name("DeformableOffsets") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableOffsetsOp) +REGISTER_KERNEL(float); +REGISTER_KERNEL(Eigen::half); +#undef REGISTER_KERNEL + +#define REGISTER_KERNEL(type) \ +REGISTER_KERNEL_BUILDER(Name("DeformableOffsetsGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DeformableOffsetsGradOp) +REGISTER_KERNEL(float); +REGISTER_KERNEL(Eigen::half); +#undef REGISTER_KERNEL + +#define REGISTER_KERNEL(type) \ +REGISTER_KERNEL_BUILDER(Name("DenseImageWarp") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DenseImageWarpOp) +REGISTER_KERNEL(float); +REGISTER_KERNEL(Eigen::half); +#undef REGISTER_KERNEL + +#define REGISTER_KERNEL(type) \ +REGISTER_KERNEL_BUILDER(Name("DenseImageWarpGrad") \ + .Device(DEVICE_CPU) \ + .TypeConstraint("T"), \ + DenseImageWarpGradOp) +REGISTER_KERNEL(float); +REGISTER_KERNEL(Eigen::half); +#undef REGISTER_KERNEL } // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/ops/npu_cpu_ops.cc b/tf_adapter/ops/npu_cpu_ops.cc index 4401b3d49398fc5f00bed27fecf7f49ec5ce6f0a..501253eb88756dd48800ec5d0343cd71875eabba 100644 --- a/tf_adapter/ops/npu_cpu_ops.cc +++ 
b/tf_adapter/ops/npu_cpu_ops.cc @@ -1,370 +1,479 @@ -/* - * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#include "tensorflow/core/framework/common_shape_fns.h" -#include "tensorflow/core/framework/op.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/framework/tensor.pb.h" - -namespace tensorflow { -using shape_inference::DimensionHandle; -using shape_inference::InferenceContext; -using shape_inference::ShapeHandle; - -REGISTER_OP("EmbeddingRankId") - .Input("addr_table: uint64") - .Input("index: T") - .Output("rank_id: uint64") - .Attr("T: {int64,int32,uint64}") - .Attr("row_memory: int = 320") - .Attr("mode: string = 'mod' ") - .SetAllowsUninitializedInput() - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto out_shape = c->MakeShape({c->Dim(c->input(1), 0), c->Dim(c->input(0), 1)}); - c->set_output(0, out_shape); - return Status::OK(); - }) - .Doc(R"doc( - Traverse the index calculation server and its position in the server. - Arguments - addr_table: Tensors of addr_table. - index: Tensors of index. - Output - rank_id: Tensors with the same shape as index.dim(0)*3. 
- )doc"); -//regist embedding local index op -REGISTER_OP("EmbeddingLocalIndex") - .Input("addr_table: uint64") - .Input("index: T") - .Output("local_idx: T") - .Output("nums: T") - .Output("recover_idx: T") - .Attr("T: {int64,int32,uint64,uint32}") - .Attr("row_memory: int = 320") - .Attr("mode: string = 'mod' ") - .SetAllowsUninitializedInput() - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto index_shape = c->input(1); - c->set_output(0, index_shape); - auto nums_shape = c->MakeShape({c->Dim(c->input(0), 0)}); - c->set_output(1, nums_shape); - c->set_output(2, index_shape); - return Status::OK(); - }) - .Doc(R"doc( - Traverse the index calculation server and its position in the server. - Arguments - addr_table: Tensors of addr_table. - index: Tensors of index. - Output - local_idx: Local_idx sorted by rank_id. - nums: The number of local_idx found on each rank_id. - recover_idx: The sorted local_idx element corresponds to the position of - the original input index. - )doc"); -//regist lru cahe op -REGISTER_OP("LruCache") - .Output("cache: resource") - .Attr("cache_size: int") - .Attr("load_factor: float = 1.0") - .Attr("container: string = ''") - .Attr("shared_name: string = 'LruCache'") - .Attr("dtype: {uint32, uint64, int32, int64}") - .SetIsStateful() - .SetShapeFn(shape_inference::ScalarShape); -//regist cache add op -REGISTER_OP("CacheAdd") - .Input("cache: resource") - .Input("ids: T") - .Output("swap_in_id: T") - .Output("swap_in_idx: T") - .Output("swap_out_id: T") - .Output("swap_out_idx: T") - .Attr("T: {int64, int32, uint64, uint32}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - c->set_output(0, c->Vector(c->UnknownDim())); - c->set_output(1, c->Vector(c->UnknownDim())); - c->set_output(2, c->Vector(c->UnknownDim())); - c->set_output(3, c->Vector(c->UnknownDim())); - return Status::OK(); - }); -//regist cache remote index to local op -REGISTER_OP("CacheRemoteIndexToLocal") - .Input("cache: resource") - .Input("ids: T") - 
.Output("local_idx: T") - .Attr("T: {int64, int32, uint32, uint64}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - c->set_output(0, c->Vector(c->Rank(c->input(1)))); - return Status::OK(); - }); -//regist cache all index to local op -REGISTER_OP("CacheAllIndexToLocal") - .Input("cache: resource") - .Output("local_idx: dtype") - .Attr("dtype: {int64, int32, uint32, uint64}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - c->set_output(0, c->Vector(c->UnknownDim())); - return Status::OK(); - }); - -//regist deformable offsets op -REGISTER_OP("DeformableOffsets") - .Input("x: T") - .Input("offsets: T") - .Output("y: T") - .Attr("T: {float16, float32}") - .Attr("strides: list(int)") - .Attr("pads: list(int)") - .Attr("ksize: list(int)") - .Attr("dilations: list(int) = [1,1,1,1]") - .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") - .Attr("deformable_groups: int = 1") - .Attr("modulated: bool = true") - .SetShapeFn([](shape_inference::InferenceContext *c) { - std::string dt_format; - const std::set kValidFormat = {"NHWC", "NCHW"}; - if (!c->GetAttr("data_format", &dt_format).ok()) { - dt_format = "NHWC"; - } - if (kValidFormat.find(dt_format) == kValidFormat.end()) { - return errors::InvalidArgument("Invalid data format string: ", - dt_format); - } - - size_t pos_n = dt_format.find("N"); - size_t pos_c = dt_format.find("C"); - size_t pos_h = dt_format.find("H"); - size_t pos_w = dt_format.find("W"); - - auto input_x_shape = c->input(0); - auto input_offsets_shape = c->input(1); - int64_t input_offsets_h = c->Value(c->Dim(input_offsets_shape, pos_h)); - int64_t input_offsets_w = c->Value(c->Dim(input_offsets_shape, pos_w)); - - std::vector ksizes; - TF_RETURN_IF_ERROR(c->GetAttr("ksize", &ksizes)); - if (ksizes.size() != 2) { - return errors::InvalidArgument( - "ksize attribute should contain 2 values, but got: ", - ksizes.size()); - } - const int64_t kh = ksizes[0]; - const int64_t kw = ksizes[1]; - - const int32_t rank = 4; - std::vector 
out_dims(rank); - out_dims[pos_n] = c->Dim(input_x_shape, pos_n); - out_dims[pos_c] = c->Dim(input_x_shape, pos_c); - out_dims[pos_h] = c->MakeDim(input_offsets_h * kh); - out_dims[pos_w] = c->MakeDim(input_offsets_w * kw); - c->set_output(0, c->MakeShape(out_dims)); - return Status::OK(); - }); -//regist deformable offsets grad op -REGISTER_OP("DeformableOffsetsGrad") - .Input("grad: T") - .Input("x: T") - .Input("offsets: T") - .Output("grad_x: T") - .Output("grad_offsets: T") - .Attr("T: {float16, float32}") - .Attr("strides: list(int)") - .Attr("pads: list(int)") - .Attr("ksize: list(int)") - .Attr("dilations: list(int) = [1,1,1,1]") - .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") - .Attr("deformable_groups: int = 1") - .Attr("modulated: bool = true") - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto input_x_shape = c->input(1); - auto input_offsets_shape = c->input(2); - c->set_output(0, input_x_shape); - c->set_output(1, input_offsets_shape); - return Status::OK(); - }); -//regist Random Choice With Mask op -REGISTER_OP("RandomChoiceWithMask") - .Input("x: bool") - .Output("y: int32") - .Output("mask: bool") - .Attr("count: int = 0") - .Attr("seed: int = 0") - .Attr("seed2: int = 0") - .SetShapeFn([](shape_inference::InferenceContext *c) { - int64 count(0); - c->GetAttr("count", &count); - if (count >0) { - c->set_output(0, c->Matrix(count, c->Rank(c->input(0)))); - c->set_output(1, c->Vector(count)); - } else if (count == 0) { - c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0)))); - c->set_output(1, c->Vector(c->UnknownDim())); - } else { - return errors::InvalidArgument( - "input count must greater or equal to 0 but instead is ", - count); - } - return Status::OK(); - }); -//regist dense image warp op -REGISTER_OP("DenseImageWarp") - .Input("image: T") - .Input("flow: S") - .Output("y: T") - .Attr("T: {float16, float32}") - .Attr("S: {float16, float32}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto 
input_image_shape = c->input(0); - c->set_output(0, input_image_shape); - return Status::OK(); - }); -//regist dense image warp grad op -REGISTER_OP("DenseImageWarpGrad") - .Input("grad: T") - .Input("image: T") - .Input("flow: S") - .Output("grad_image: T") - .Output("grad_flow: S") - .Attr("T: {float16, float32}") - .Attr("S: {float16, float32}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto input_image_shape = c->input(1); - auto input_flow_shape = c->input(2); - c->set_output(0, input_image_shape); - c->set_output(1, input_flow_shape); - return Status::OK(); - }); - - REGISTER_OP("ScatterElements") - .Input("data: T") - .Input("indices: indexT") - .Input("updates: T") - .Output("y: T") - .Attr("axis: int = 0") - .Attr("T: numbertype") - .Attr("indexT: {int32, int64}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - auto data_shape = c->input(0); - c->set_output(0, data_shape); - return Status::OK(); - }); - - REGISTER_OP("BatchEnqueue") - .Input("x: T") - .Input("queue_id: uint32") - .Output("enqueue_count: int32") - .Attr("batch_size: int = 8") - .Attr("queue_name: string = ''") - .Attr("pad_mode: {'REPLICATE', 'ZERO'} = 'REPLICATE'") - .Attr("T: {float16, float32, float64, int8, uint8, int16, uint16, int32, uint32, int64, uint64}") - .SetShapeFn(tensorflow::shape_inference::ScalarShape); - - REGISTER_OP("OCRRecognitionPreHandle") - .Input("imgs_data: uint8") - .Input("imgs_offset: int32") - .Input("imgs_size: int32") - .Input("langs: int32") - .Input("langs_score: T") - .Output("imgs: uint8") - .Output("imgs_relation: int32") - .Output("imgs_lang: int32") - .Attr("batch_size: int = 8") - .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") - .Attr("pad_mode: {'REPLICATE', 'ZERO'} = 'REPLICATE'") - .Attr("T: {float16, float32}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - c->set_output(0, c->Vector(c->UnknownDim())); - c->set_output(1, c->Vector(c->UnknownDim())); - c->set_output(2, c->Vector(c->UnknownDim())); - return 
Status::OK(); - }); - - REGISTER_OP("OCRDetectionPreHandle") - .Input("img: uint8") - .Output("resized_img: uint8") - .Output("h_scale: float32") - .Output("w_scale: float32") - .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") - .SetShapeFn([](shape_inference::InferenceContext *c) { - std::string dt_format; - const std::set kVaildFormat = {"NHWC", "NCHW"}; - if (!c->GetAttr("data_format", &dt_format).ok()) { - dt_format = "NHWC"; - } - if (kVaildFormat.find(dt_format) == kVaildFormat.end()) { - return errors::InvalidArgument("Invalid data format string: ", - dt_format); - } - const int32_t kRank = 3; - std::vector out_dims(kRank); - if (dt_format == "NHWC") { - out_dims[0] = c->UnknownDim(); - out_dims[1] = c->UnknownDim(); - out_dims[2] = c->MakeDim(3); - } else { - out_dims[0] = c->MakeDim(3); - out_dims[1] = c->UnknownDim(); - out_dims[2] = c->UnknownDim(); - } - c->set_output(0, c->MakeShape(out_dims)); - c->set_output(1, c->Scalar()); - c->set_output(2, c->Scalar()); - return Status::OK(); - }); - - REGISTER_OP("OCRIdentifyPreHandle") - .Input("imgs_data: uint8") - .Input("imgs_offset: int32") - .Input("imgs_size: int32") - .Output("resized_imgs: uint8") - .Attr("size: list(int)") - .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") - .SetShapeFn([](shape_inference::InferenceContext *c) { - std::vector size; - TF_RETURN_IF_ERROR(c->GetAttr("size", &size)); - if (size.size() != 2) { - return errors::InvalidArgument( - "size attribute should contain 2 values, but got: ", - size.size()); - } - const int64_t k1 = size[0]; - const int64_t k2 = size[1]; - - std::string dt_format; - const std::set kVaildFormat = {"NHWC", "NCHW"}; - if (!c->GetAttr("data_format", &dt_format).ok()) { - dt_format = "NHWC"; - } - if (kVaildFormat.find(dt_format) == kVaildFormat.end()) { - return errors::InvalidArgument("Invalid data format string: ", - dt_format); - } - const int32_t kRank = 4; - std::vector out_dims(kRank); - out_dims[0] = c->UnknownDim(); - if (dt_format == "NHWC") { - 
out_dims[0] = c->MakeDim(k1); - out_dims[1] = c->MakeDim(k2); - out_dims[2] = c->MakeDim(3); - } else { - out_dims[0] = c->MakeDim(3); - out_dims[1] = c->MakeDim(k1); - out_dims[2] = c->MakeDim(k2); - } - c->set_output(0, c->MakeShape(out_dims)); - return Status::OK(); - }); -} // namespace tensorflow +/* + * Copyright (c) Huawei Technologies Co., Ltd. 2019-2020. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.pb.h" + +namespace tensorflow { +using shape_inference::DimensionHandle; +using shape_inference::InferenceContext; +using shape_inference::ShapeHandle; + +REGISTER_OP("EmbeddingRankId") + .Input("addr_table: uint64") + .Input("index: T") + .Output("rank_id: uint64") + .Attr("T: {int64,int32,uint64}") + .Attr("row_memory: int = 320") + .Attr("mode: string = 'mod' ") + .SetAllowsUninitializedInput() + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto out_shape = c->MakeShape({c->Dim(c->input(1), 0), c->Dim(c->input(0), 1)}); + c->set_output(0, out_shape); + return Status::OK(); + }) + .Doc(R"doc( + Traverse the index calculation server and its position in the server. + Arguments + addr_table: Tensors of addr_table. + index: Tensors of index. + Output + rank_id: Tensors with the same shape as index.dim(0)*3. 
+ )doc"); +//regist embedding local index op +REGISTER_OP("EmbeddingLocalIndex") + .Input("addr_table: uint64") + .Input("index: T") + .Output("local_idx: T") + .Output("nums: T") + .Output("recover_idx: T") + .Attr("T: {int64,int32,uint64,uint32}") + .Attr("row_memory: int = 320") + .Attr("mode: string = 'mod' ") + .SetAllowsUninitializedInput() + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto index_shape = c->input(1); + c->set_output(0, index_shape); + auto nums_shape = c->MakeShape({c->Dim(c->input(0), 0)}); + c->set_output(1, nums_shape); + c->set_output(2, index_shape); + return Status::OK(); + }) + .Doc(R"doc( + Traverse the index calculation server and its position in the server. + Arguments + addr_table: Tensors of addr_table. + index: Tensors of index. + Output + local_idx: Local_idx sorted by rank_id. + nums: The number of local_idx found on each rank_id. + recover_idx: The sorted local_idx element corresponds to the position of + the original input index. + )doc"); +//regist lru cahe op +REGISTER_OP("LruCache") + .Output("cache: resource") + .Attr("cache_size: int") + .Attr("load_factor: float = 1.0") + .Attr("container: string = ''") + .Attr("shared_name: string = 'LruCache'") + .Attr("dtype: {uint32, uint64, int32, int64}") + .SetIsStateful() + .SetShapeFn(shape_inference::ScalarShape); +//regist cache add op +REGISTER_OP("CacheAdd") + .Input("cache: resource") + .Input("ids: T") + .Output("swap_in_id: T") + .Output("swap_in_idx: T") + .Output("swap_out_id: T") + .Output("swap_out_idx: T") + .Attr("T: {int64, int32, uint64, uint32}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->Vector(c->UnknownDim())); + c->set_output(1, c->Vector(c->UnknownDim())); + c->set_output(2, c->Vector(c->UnknownDim())); + c->set_output(3, c->Vector(c->UnknownDim())); + return Status::OK(); + }); +//regist cache remote index to local op +REGISTER_OP("CacheRemoteIndexToLocal") + .Input("cache: resource") + .Input("ids: T") + 
.Output("local_idx: T") + .Attr("T: {int64, int32, uint32, uint64}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->Vector(c->Rank(c->input(1)))); + return Status::OK(); + }); +//regist cache all index to local op +REGISTER_OP("CacheAllIndexToLocal") + .Input("cache: resource") + .Output("local_idx: dtype") + .Attr("dtype: {int64, int32, uint32, uint64}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->Vector(c->UnknownDim())); + return Status::OK(); + }); + +//regist deformable offsets op +REGISTER_OP("DeformableOffsets") + .Input("x: T") + .Input("offsets: T") + .Output("y: T") + .Attr("T: {float16, float32}") + .Attr("strides: list(int)") + .Attr("pads: list(int)") + .Attr("ksize: list(int)") + .Attr("dilations: list(int) = [1,1,1,1]") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .Attr("deformable_groups: int = 1") + .Attr("modulated: bool = true") + .SetShapeFn([](shape_inference::InferenceContext *c) { + std::string dt_format; + const std::set kValidFormat = {"NHWC", "NCHW"}; + if (!c->GetAttr("data_format", &dt_format).ok()) { + dt_format = "NHWC"; + } + if (kValidFormat.find(dt_format) == kValidFormat.end()) { + return errors::InvalidArgument("Invalid data format string: ", + dt_format); + } + + size_t pos_n = dt_format.find("N"); + size_t pos_c = dt_format.find("C"); + size_t pos_h = dt_format.find("H"); + size_t pos_w = dt_format.find("W"); + + auto input_x_shape = c->input(0); + auto input_offsets_shape = c->input(1); + int64_t input_offsets_h = c->Value(c->Dim(input_offsets_shape, pos_h)); + int64_t input_offsets_w = c->Value(c->Dim(input_offsets_shape, pos_w)); + + std::vector ksizes; + TF_RETURN_IF_ERROR(c->GetAttr("ksize", &ksizes)); + if (ksizes.size() != 2) { + return errors::InvalidArgument( + "ksize attribute should contain 2 values, but got: ", + ksizes.size()); + } + const int64_t kh = ksizes[0]; + const int64_t kw = ksizes[1]; + + const int32_t rank = 4; + std::vector 
out_dims(rank); + out_dims[pos_n] = c->Dim(input_x_shape, pos_n); + out_dims[pos_c] = c->Dim(input_x_shape, pos_c); + out_dims[pos_h] = c->MakeDim(input_offsets_h * kh); + out_dims[pos_w] = c->MakeDim(input_offsets_w * kw); + c->set_output(0, c->MakeShape(out_dims)); + return Status::OK(); + }); +//regist deformable offsets grad op +REGISTER_OP("DeformableOffsetsGrad") + .Input("grad: T") + .Input("x: T") + .Input("offsets: T") + .Output("grad_x: T") + .Output("grad_offsets: T") + .Attr("T: {float16, float32}") + .Attr("strides: list(int)") + .Attr("pads: list(int)") + .Attr("ksize: list(int)") + .Attr("dilations: list(int) = [1,1,1,1]") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .Attr("deformable_groups: int = 1") + .Attr("modulated: bool = true") + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto input_x_shape = c->input(1); + auto input_offsets_shape = c->input(2); + c->set_output(0, input_x_shape); + c->set_output(1, input_offsets_shape); + return Status::OK(); + }); +//regist Random Choice With Mask op +REGISTER_OP("RandomChoiceWithMask") + .Input("x: bool") + .Output("y: int32") + .Output("mask: bool") + .Attr("count: int = 0") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .SetShapeFn([](shape_inference::InferenceContext *c) { + int64 count(0); + c->GetAttr("count", &count); + if (count >0) { + c->set_output(0, c->Matrix(count, c->Rank(c->input(0)))); + c->set_output(1, c->Vector(count)); + } else if (count == 0) { + c->set_output(0, c->Matrix(c->UnknownDim(), c->Rank(c->input(0)))); + c->set_output(1, c->Vector(c->UnknownDim())); + } else { + return errors::InvalidArgument( + "input count must greater or equal to 0 but instead is ", + count); + } + return Status::OK(); + }); +//regist dense image warp op +REGISTER_OP("DenseImageWarp") + .Input("image: T") + .Input("flow: S") + .Output("y: T") + .Attr("T: {float16, float32}") + .Attr("S: {float16, float32}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto 
input_image_shape = c->input(0); + c->set_output(0, input_image_shape); + return Status::OK(); + }); +//regist dense image warp grad op +REGISTER_OP("DenseImageWarpGrad") + .Input("grad: T") + .Input("image: T") + .Input("flow: S") + .Output("grad_image: T") + .Output("grad_flow: S") + .Attr("T: {float16, float32}") + .Attr("S: {float16, float32}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto input_image_shape = c->input(1); + auto input_flow_shape = c->input(2); + c->set_output(0, input_image_shape); + c->set_output(1, input_flow_shape); + return Status::OK(); + }); + + REGISTER_OP("ScatterElements") + .Input("data: T") + .Input("indices: indexT") + .Input("updates: T") + .Output("y: T") + .Attr("axis: int = 0") + .Attr("T: numbertype") + .Attr("indexT: {int32, int64}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + auto data_shape = c->input(0); + c->set_output(0, data_shape); + return Status::OK(); + }); + + REGISTER_OP("BatchEnqueue") + .Input("x: T") + .Input("queue_id: uint32") + .Output("enqueue_count: int32") + .Attr("batch_size: int = 8") + .Attr("queue_name: string = ''") + .Attr("pad_mode: {'REPLICATE', 'ZERO'} = 'REPLICATE'") + .Attr("T: {float16, float32, float64, int8, uint8, int16, uint16, int32, uint32, int64, uint64}") + .SetShapeFn(tensorflow::shape_inference::ScalarShape); + + REGISTER_OP("OCRRecognitionPreHandle") + .Input("imgs_data: uint8") + .Input("imgs_offset: int32") + .Input("imgs_size: int32") + .Input("langs: int32") + .Input("langs_score: T") + .Output("imgs: uint8") + .Output("imgs_relation: int32") + .Output("imgs_lang: int32") + .Attr("batch_size: int = 8") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .Attr("pad_mode: {'REPLICATE', 'ZERO'} = 'REPLICATE'") + .Attr("T: {float16, float32}") + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->Vector(c->UnknownDim())); + c->set_output(1, c->Vector(c->UnknownDim())); + c->set_output(2, c->Vector(c->UnknownDim())); + return 
Status::OK(); + }); + + REGISTER_OP("OCRDetectionPreHandle") + .Input("img: uint8") + .Output("resized_img: uint8") + .Output("h_scale: float32") + .Output("w_scale: float32") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .SetShapeFn([](shape_inference::InferenceContext *c) { + std::string dt_format; + const std::set kVaildFormat = {"NHWC", "NCHW"}; + if (!c->GetAttr("data_format", &dt_format).ok()) { + dt_format = "NHWC"; + } + if (kVaildFormat.find(dt_format) == kVaildFormat.end()) { + return errors::InvalidArgument("Invalid data format string: ", + dt_format); + } + const int32_t kRank = 3; + std::vector out_dims(kRank); + if (dt_format == "NHWC") { + out_dims[0] = c->UnknownDim(); + out_dims[1] = c->UnknownDim(); + out_dims[2] = c->MakeDim(3); + } else { + out_dims[0] = c->MakeDim(3); + out_dims[1] = c->UnknownDim(); + out_dims[2] = c->UnknownDim(); + } + c->set_output(0, c->MakeShape(out_dims)); + c->set_output(1, c->Scalar()); + c->set_output(2, c->Scalar()); + return Status::OK(); + }); + + REGISTER_OP("OCRIdentifyPreHandle") + .Input("imgs_data: uint8") + .Input("imgs_offset: int32") + .Input("imgs_size: int32") + .Output("resized_imgs: uint8") + .Attr("size: list(int)") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .SetShapeFn([](shape_inference::InferenceContext *c) { + std::vector size; + TF_RETURN_IF_ERROR(c->GetAttr("size", &size)); + if (size.size() != 2) { + return errors::InvalidArgument( + "size attribute should contain 2 values, but got: ", + size.size()); + } + const int64_t k1 = size[0]; + const int64_t k2 = size[1]; + + std::string dt_format; + const std::set kVaildFormat = {"NHWC", "NCHW"}; + if (!c->GetAttr("data_format", &dt_format).ok()) { + dt_format = "NHWC"; + } + if (kVaildFormat.find(dt_format) == kVaildFormat.end()) { + return errors::InvalidArgument("Invalid data format string: ", + dt_format); + } + const int32_t kRank = 4; + std::vector out_dims(kRank); + out_dims[0] = c->UnknownDim(); + if (dt_format == "NHWC") { + 
out_dims[0] = c->MakeDim(k1); + out_dims[1] = c->MakeDim(k2); + out_dims[2] = c->MakeDim(3); + } else { + out_dims[0] = c->MakeDim(3); + out_dims[1] = c->MakeDim(k1); + out_dims[2] = c->MakeDim(k2); + } + c->set_output(0, c->MakeShape(out_dims)); + return Status::OK(); + }); + +REGISTER_OP("BatchDilatePolys") + .Input("polys_data:int32") + .Input("polys_offset:int32") + .Input("polys_size:int32") + .Input("score:float") + .Input("min_border:int32") + .Input("min_area_thr:int32") + .Input("score_thr:float") + .Input("expands_cale:float") + .Output("dilated_polys_data:int32") + .Output("dilated_polys_offset:int32") + .Output("dilated_polys_size:int32") + .SetShapeFn([](shape_inference::InferenceContext *c){ + auto input_shape0=c->input(0); + auto input_shape1=c->input(1); + auto input_shape2=c->input(2); + auto input_shape3=c->input(3); + auto input_shape4=c->input(4); + auto input_shape5=c->input(5); + auto input_shape6=c->input(6); + auto input_shape7=c->input(7); + c->set_output(0,c->Vector(c->UnknownDim())); + c->set_output(1,c->Vector(c->UnknownDim())); + c->set_output(2,c->Vector(c->UnknownDim())); + return Status::OK(); + }); + +REGISTER_OP("OCRFindContours") + .Input("img:uint8") + .Output("polys_data:int32") + .Output("polys_offset:int32") + .Output("polys_size:int32") + .Attr("value_mode:int = 0") + .SetShapeFn([](shape_inference::InferenceContext *c){ + auto input_shape0=c->input(0); + auto input_shape1=c->input(1); + auto input_shape2=c->input(2); + auto input_shape3=c->input(3); + auto input_shape4=c->input(4); + auto input_shape5=c->input(5); + auto input_shape6=c->input(6); + auto input_shape7=c->input(7); + c->set_output(0,c->Vector(c->UnknownDim())); + c->set_output(1,c->Vector(c->UnknownDim())); + c->set_output(2,c->Vector(c->UnknownDim())); + return Status::OK(); + }); + +REGISTER_OP("Dequeue") + .Input("queue_id: uint32") + .Output("data: output_type") + .Attr("output_type: {float16, float32, float64, uint8, uint16} = DT_UINT8") + 
.Attr("output_shape: list(int)") + .Attr("queue_name: string = ''") + .SetShapeFn([](shape_inference::InferenceContext *c) { + std::vector output_shape; + TF_RETURN_IF_ERROR(c->GetAttr("output_shape", &output_shape)); + int32_t rank = output_shape.size(); + std::vector out_dims(rank); + for (auto i = 0; i < rank; ++i){ + out_dims[i] = c->MakeDim(output_shape[i]); + } + c->set_output(0, c->MakeShape(out_dims)); + return Status::OK(); + }); + + REGISTER_OP("OCRDetectionPostHandle") + .Input("img: uint8") + .Input("polys_data: int32") + .Input("polys_offset: int32") + .Input("polys_size: int32") + .Output("imgs_data: uint8") + .Output("imgs_offset: int32") + .Output("imgs_size: int32") + .Output("rect_points: int32") + .Attr("data_format: {'NHWC', 'NCHW'} = 'NHWC'") + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->Vector(c->UnknownDim())); + auto data_shape = c->input(2); + c->set_output(1, data_shape); + c->set_output(2, c->Matrix(c->Rank(c->input(2)), 3)); + const int32_t rank = 3; + std::vector out_dims(rank); + out_dims[0] = c->Dim(data_shape, 0); + out_dims[1] = c->MakeDim(4); + out_dims[2] = c->MakeDim(2); + c->set_output(3, c->MakeShape(out_dims)); + return Status::OK(); + }); + + REGISTER_OP("ResizeAndClipPolys") + .Input("polys_data: int32") + .Input("polys_offset: int32") + .Input("polys_size: int32") + .Input("h_scale: float32") + .Input("w_scale: float32") + .Input("img_h: int32") + .Input("img_w: int32") + .Output("clipped_polys_data: int32") + .Output("clipped_polys_offset: int32") + .Output("clipped_polys_size: int32") + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->Vector(c->UnknownDim())); + c->set_output(1, c->Vector(c->UnknownDim())); + c->set_output(2, c->Vector(c->UnknownDim())); + return Status::OK(); + }); + +} // namespace tensorflow diff --git a/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py b/tf_adapter/python/npu_bridge/npu_cpu/npu_cpu_ops.py index 
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for the NPU CPU custom ops."""

from tensorflow.contrib.util import loader
from tensorflow.python.framework import load_library
from tensorflow.python.framework import ops
from tensorflow.python.platform import resource_loader
from npu_bridge.helper import helper

gen_npu_cpu_ops = helper.get_gen_ops()


## Provides the EmbeddingRankId feature.
# @param addr_tensor TensorFlow tensor, input of the embedding_rank_id op.
# @param index TensorFlow tensor, input of the embedding_rank_id op.
# @param row_memory int, storage size of one row of data; defaults to 320.
# @param mode string, operation mode, one of "mod" or "order" (data layout).
# @return result tensor of applying embedding_rank_id to addr_tensor/index.
def embeddingrankid(addr_tensor, index, row_memory=320, mode='mod'):
    result = gen_npu_cpu_ops.embedding_rank_id(
        addr_table=addr_tensor,
        index=index,
        row_memory=row_memory,
        mode=mode)
    return result


## Provides the EmbeddingLocalIndex feature.
# @param addr_tensor TensorFlow tensor, input of the embedding_local_index op.
# @param index TensorFlow tensor, input of the embedding_local_index op.
# @param row_memory int, storage size of one row of data; defaults to 320.
# @param mode string, operation mode, one of "mod" or "order" (data layout).
# @return result tensor of applying embedding_local_index to addr_tensor/index.
def embedding_local_index(addr_tensor, index, row_memory=320, mode='mod'):
    result = gen_npu_cpu_ops.embedding_local_index(
        addr_table=addr_tensor,
        index=index,
        row_memory=row_memory,
        mode=mode)
    return result


## Provides the RandomChoiceWithMask feature.
# @param x bool tensor.
# @param count int, number of samples to draw (0 means data dependent).
# @param seed int random seed.
# @param seed2 int random seed.
# @return y int32 tensor of coordinates; mask bool tensor.
def randomchoicewithmask(x, count, seed=0, seed2=0):
    result = gen_npu_cpu_ops.random_choice_with_mask(
        x=x,
        count=count,
        seed=seed,
        seed2=seed2)
    return result


## Provides the DenseImageWarp feature.
# @param image tensor to warp.
# @param flow tensor of per-pixel displacements.
# @return y tensor, the warped image.
def dense_image_warp(image, flow, name=None):
    result = gen_npu_cpu_ops.dense_image_warp(
        image=image,
        flow=flow,
        name=name
    )
    return result


## Gradient function for DenseImageWarp.
@ops.RegisterGradient("DenseImageWarp")
def dense_image_warp_grad(op, grad):
    image = op.inputs[0]
    flow = op.inputs[1]
    grad_image, grad_flow = gen_npu_cpu_ops.dense_image_warp_grad(
        grad, image, flow)
    return [grad_image, grad_flow]


## Provides the BatchEnqueue feature.
# @param x uint8 tensor.
# @param queue_id uint32 tensor.
# @param batch_size int.
# @param queue_name string.
# @param pad_mode string, "REPLICATE" or "ZERO".
# @return enqueue_count int64 tensor.
def batch_enqueue(x, queue_id, batch_size=8, queue_name="", pad_mode="REPLICATE"):
    result = gen_npu_cpu_ops.batch_enqueue(
        x=x,
        queue_id=queue_id,
        batch_size=batch_size,
        queue_name=queue_name,
        pad_mode=pad_mode)
    return result


## Provides the OCRRecognitionPreHandle feature.
# @param imgs_data uint8 tensor.
# @param imgs_offset int32 tensor.
# @param imgs_size int32 tensor.
# @param langs int32 tensor.
# @param langs_score float tensor.
# @param batch_size int.
# @param data_format string, "NHWC" or "NCHW".
# @param pad_mode string, "REPLICATE" or "ZERO".
# @return imgs, imgs_relation, imgs_lang of types uint8, int32, int32.
def ocr_recognition_pre_handle(imgs_data, imgs_offset, imgs_size, langs, langs_score,
                               batch_size=8, data_format="NHWC", pad_mode="REPLICATE"):
    result = gen_npu_cpu_ops.ocr_recognition_pre_handle(
        imgs_data=imgs_data,
        imgs_offset=imgs_offset,
        imgs_size=imgs_size,
        langs=langs,
        langs_score=langs_score,
        batch_size=batch_size,
        data_format=data_format,
        pad_mode=pad_mode)
    return result


## Provides the OCRDetectionPreHandle feature.
# @param img uint8 tensor.
# @param data_format string, "NHWC" or "NCHW".
# @return resized_img, h_scale, w_scale of types uint8, float32, float32.
def ocr_detection_pre_handle(img, data_format="NHWC"):
    # Fix: this wrapper previously called ocr_recognition_pre_handle
    # (copy-paste error), which does not accept these arguments.
    result = gen_npu_cpu_ops.ocr_detection_pre_handle(
        img=img,
        data_format=data_format)
    return result


## Provides the OCRIdentifyPreHandle feature.
# @param imgs_data uint8 tensor.
# @param imgs_offset int32 tensor.
# @param imgs_size int32 tensor.
# @param size list(int), target (h, w).
# @param data_format string, "NHWC" or "NCHW".
# @return resized_imgs uint8 tensor.
def ocr_identify_pre_handle(imgs_data, imgs_offset, imgs_size, size, data_format="NHWC"):
    # Fix: this wrapper previously called ocr_recognition_pre_handle
    # (copy-paste error), which does not accept these arguments.
    result = gen_npu_cpu_ops.ocr_identify_pre_handle(
        imgs_data=imgs_data,
        imgs_offset=imgs_offset,
        imgs_size=imgs_size,
        size=size,
        data_format=data_format)
    return result


## Provides the BatchDilatePolys feature.
# @return dilated_polys_data, dilated_polys_offset, dilated_polys_size (int32).
def batch_dilate_polys(polys_data, polys_offset, polys_size, score, min_border,
                       min_area_thr, score_thr, expands_cale):
    result = gen_npu_cpu_ops.batch_dilate_polys(
        polys_data=polys_data,
        polys_offset=polys_offset,
        polys_size=polys_size,
        score=score,
        min_border=min_border,
        min_area_thr=min_area_thr,
        score_thr=score_thr,
        expands_cale=expands_cale)
    return result


## Provides the OCRFindContours feature.
# @param img uint8 tensor.
# @param value_mode int.
# @return polys_data, polys_offset, polys_size (int32).
def ocr_find_contours(img, value_mode=0):
    result = gen_npu_cpu_ops.ocr_find_contours(img=img, value_mode=value_mode)
    return result


## Provides the Dequeue feature.
# @param queue_id uint32 tensor.
# @param output_type dtype of the dequeued data.
# @param output_shape list(int), static shape of the dequeued data.
# @param queue_name string.
# @return data tensor of type output_type.
def dequeue(queue_id, output_type, output_shape, queue_name=""):
    result = gen_npu_cpu_ops.dequeue(
        queue_id=queue_id,
        output_type=output_type,
        output_shape=output_shape,
        queue_name=queue_name)
    return result


## Provides the OCRDetectionPostHandle feature.
# @return imgs_data, imgs_offset, imgs_size, rect_points.
def ocr_detection_post_handle(img, polys_data, polys_offset, polys_size, data_format="NHWC"):
    result = gen_npu_cpu_ops.ocr_detection_post_handle(
        img=img,
        polys_data=polys_data,
        polys_offset=polys_offset,
        polys_size=polys_size,
        data_format=data_format)
    return result


## Provides the ResizeAndClipPolys feature.
# @return clipped_polys_data, clipped_polys_offset, clipped_polys_size (int32).
def resize_and_clip_polys(polys_data, polys_offset, polys_size, h_scale, w_scale, img_h, img_w):
    result = gen_npu_cpu_ops.resize_and_clip_polys(
        polys_data=polys_data,
        polys_offset=polys_offset,
        polys_size=polys_size,
        h_scale=h_scale,
        w_scale=w_scale,
        img_h=img_h,
        img_w=img_w)
    return result
b/tf_adapter/tests/ut/kernels/testcase/npu_cpu_ops_test.cc @@ -1,29 +1,29 @@ -#include -#include "tf_adapter/kernels/npu_cpu_ops.cc" -#include "gtest/gtest.h" - -namespace tensorflow { -class NpuCpuOpTest : public testing::Test { - protected: - virtual void SetUp() {} - virtual void TearDown() {} -}; - -TEST_F(NpuCpuOpTest, TestCacheAdd) { - DataTypeSlice input_types({DT_RESOURCE, DT_INT64}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_INT64, DT_INT64, DT_INT64, DT_INT64}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - CacheAddOp cache(context); - delete device; - delete node_def; - delete op_def; - delete context; -} +#include +#include "tf_adapter/kernels/npu_cpu_ops.cc" +#include "gtest/gtest.h" + +namespace tensorflow { +class NpuCpuOpTest : public testing::Test { + protected: + virtual void SetUp() {} + virtual void TearDown() {} +}; + +TEST_F(NpuCpuOpTest, TestCacheAdd) { + DataTypeSlice input_types({DT_RESOURCE, DT_INT64}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_INT64, DT_INT64, DT_INT64, DT_INT64}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + CacheAddOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} } \ No newline at end of file diff --git a/tf_adapter/tests/ut/kernels/testcase/ocr_ops_test.cc 
b/tf_adapter/tests/ut/kernels/testcase/ocr_ops_test.cc index d688e4db5e7af3cb23f2033baaf6e7f36a36b7c5..85456c1a0a66117aba3bae0bfc7b04fef5f20cf1 100644 --- a/tf_adapter/tests/ut/kernels/testcase/ocr_ops_test.cc +++ b/tf_adapter/tests/ut/kernels/testcase/ocr_ops_test.cc @@ -1,171 +1,345 @@ -#include -#include "tf_adapter/kernels/npu_cpu_ops.cc" -#include "gtest/gtest.h" -#include "tensorflow/core/framework/attr_value.pb.h" -#include "tensorflow/core/framework/attr_value_util.h" -#include "tensorflow/core/framework/fake_input.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/node_def_builder.h" -#include "tensorflow/core/framework/op_kernel.h" -#include "tensorflow/core/framework/shape_inference.h" -#include "tensorflow/core/platform/test.h" - -namespace tensorflow { -namespace { - -PartialTensorShape TShape(std::initializer_list dims) { - return PartialTensorShape(dims); -} - -FakeInputFunctor FakeInputStub(DataType dt) { - return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, - NodeDefBuilder* builder) { - char c = 'a' + (in_index % 26); - string in_node = string(&c, 1); - builder->Input(in_node, 0, dt); - return Status::OK(); - }; -} - -TEST(OCROpsTest, TestBatchEnqueue) { - DataTypeSlice input_types({DT_INT32, DT_INT32}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_INT32}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - BatchEnqueueOp cache(context); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(OCROpsTest, TestOCRRecognitionPreHandle) { - DataTypeSlice input_types({DT_UINT8, DT_INT32, DT_INT32, DT_INT32, DT_FLOAT}); - 
MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_UINT8, DT_INT32, DT_INT32}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - OCRRecognitionPreHandleOp cache(context); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(OCROpsTest, TestOCRDetectionPreHandle) { - DataTypeSlice input_types({DT_UINT8}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_UINT8, DT_FLOAT, DT_FLOAT}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - OCRDetectionPreHandleOp cache(context); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(OCROpsTest, TestOCRIdentifyPreHandle) { - DataTypeSlice input_types({DT_UINT8, DT_INT32, DT_INT32}); - MemoryTypeSlice input_memory_types; - DataTypeSlice output_types({DT_UINT8}); - MemoryTypeSlice output_memory_types; - DeviceBase *device = new DeviceBase(Env::Default()); - NodeDef *node_def = new NodeDef(); - OpDef *op_def = new OpDef(); - OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, - input_types, input_memory_types, output_types, output_memory_types, - 1, nullptr); - OCRIdentifyPreHandleOp cache(context); - delete device; - delete node_def; - delete op_def; - delete context; -} - -TEST(OCROpsTest, TestBatchEnqueueShapeInference) { - const OpRegistrationData* reg; 
- TF_CHECK_OK(OpRegistry::Global()->LookUp("BatchEnqueue", ®)); - OpDef op_def = reg->op_def; - NodeDef def; - TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) - .Attr("T", DT_INT32) - .Attr("batch_size", 8) - .Attr("queue_name", "TEST") - .Attr("pad_mode", "REPLICATE") - .Input(FakeInputStub(DT_INT32)) - .Input(FakeInputStub(DT_UINT32)) - .Finalize(&def)); - shape_inference::InferenceContext c(0, &def, op_def,{TShape({5}), TShape({})}, {}, {}, {}); - TF_CHECK_OK(reg->shape_inference_fn(&c)); -} - -TEST(OCROpsTest, TestOCRRecognitionPreHandleShapeInference) { - const OpRegistrationData* reg; - TF_CHECK_OK(OpRegistry::Global()->LookUp("OCRRecognitionPreHandle", ®)); - OpDef op_def = reg->op_def; - NodeDef def; - TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) - .Attr("T", DT_FLOAT) - .Attr("batch_size", 8) - .Attr("data_format", "NHWC") - .Attr("pad_mode", "REPLICATE") - .Input(FakeInputStub(DT_UINT8)) - .Input(FakeInputStub(DT_INT32)) - .Input(FakeInputStub(DT_INT32)) - .Input(FakeInputStub(DT_INT32)) - .Input(FakeInputStub(DT_FLOAT)) - .Finalize(&def)); - shape_inference::InferenceContext c(0, &def, op_def, - {TShape({3}), TShape({3}), TShape({3}), TShape({3}), TShape({3})}, {}, {}, {}); - TF_CHECK_OK(reg->shape_inference_fn(&c)); -} - -TEST(OCROpsTest, TestOCRDetectionPreHandleShapeInference) { - const OpRegistrationData* reg; - TF_CHECK_OK(OpRegistry::Global()->LookUp("OCRDetectionPreHandle", ®)); - OpDef op_def = reg->op_def; - NodeDef def; - TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) - .Attr("data_format", "NHWC") - .Input(FakeInputStub(DT_UINT8)) - .Finalize(&def)); - shape_inference::InferenceContext c(0, &def, op_def, - {TShape({3})}, {}, {}, {}); - TF_CHECK_OK(reg->shape_inference_fn(&c)); -} - -TEST(OCROpsTest, TestOCRIdentifyPreHandleShapeInference) { - const OpRegistrationData* reg; - TF_CHECK_OK(OpRegistry::Global()->LookUp("OCRIdentifyPreHandle", ®)); - OpDef op_def = reg->op_def; - NodeDef def; - TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) - 
.Attr("size", {1,2}) - .Attr("data_format", "NHWC") - .Input(FakeInputStub(DT_UINT8)) - .Input(FakeInputStub(DT_INT32)) - .Input(FakeInputStub(DT_INT32)) - .Finalize(&def)); - shape_inference::InferenceContext c(0, &def, op_def, - {TShape({3}), TShape({3}), TShape({3})}, {}, {}, {}); - TF_CHECK_OK(reg->shape_inference_fn(&c)); -} -} // namespace +#include +#include "tf_adapter/kernels/npu_cpu_ops.cc" +#include "gtest/gtest.h" +#include "tensorflow/core/framework/attr_value.pb.h" +#include "tensorflow/core/framework/attr_value_util.h" +#include "tensorflow/core/framework/fake_input.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/framework/node_def_builder.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/platform/test.h" + +namespace tensorflow { +namespace { + +PartialTensorShape TShape(std::initializer_list dims) { + return PartialTensorShape(dims); +} + +FakeInputFunctor FakeInputStub(DataType dt) { + return [dt](const OpDef& op_def, int in_index, const NodeDef& node_def, + NodeDefBuilder* builder) { + char c = 'a' + (in_index % 26); + string in_node = string(&c, 1); + builder->Input(in_node, 0, dt); + return Status::OK(); + }; +} + +TEST(OCROpsTest, TestBatchEnqueue) { + DataTypeSlice input_types({DT_INT32, DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_INT32}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + BatchEnqueueOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(OCROpsTest, TestOCRRecognitionPreHandle) { + DataTypeSlice 
input_types({DT_UINT8, DT_INT32, DT_INT32, DT_INT32, DT_FLOAT}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_UINT8, DT_INT32, DT_INT32}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + OCRRecognitionPreHandleOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(OCROpsTest, TestOCRDetectionPreHandle) { + DataTypeSlice input_types({DT_UINT8}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_UINT8, DT_FLOAT, DT_FLOAT}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + OCRDetectionPreHandleOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(OCROpsTest, TestOCRIdentifyPreHandle) { + DataTypeSlice input_types({DT_UINT8, DT_INT32, DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_UINT8}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + OCRIdentifyPreHandleOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(OCROpsTest, 
TestBatchDilatePolys) { + DataTypeSlice input_types({DT_INT32, DT_INT32, DT_INT32,DT_FLOAT,DT_INT32, DT_INT32,DT_FLOAT,DT_FLOAT}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_INT32, DT_INT32, DT_INT32}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + BatchDilatePolysOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(OCROpsTest, TestOCRFindContours) { + DataTypeSlice input_types({DT_UINT8}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_INT32, DT_INT32, DT_INT32}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + OCRFindContoursOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(OCROpsTest, TestDequeue) { + DataTypeSlice input_types({DT_UINT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_UINT8}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + DequeueOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + 
+TEST(OCROpsTest, TestOCRDetectionPostHandle) { + DataTypeSlice input_types({DT_UINT8, DT_INT32, DT_INT32, DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_UINT8, DT_INT32, DT_INT32, DT_INT32}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + OCRDetectionPostHandleOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(OCROpsTest, TestResizeAndClipPolys) { + DataTypeSlice input_types({DT_INT32, DT_INT32, DT_INT32, DT_FLOAT, DT_FLOAT, DT_INT32, DT_INT32}); + MemoryTypeSlice input_memory_types; + DataTypeSlice output_types({DT_INT32, DT_INT32, DT_INT32}); + MemoryTypeSlice output_memory_types; + DeviceBase *device = new DeviceBase(Env::Default()); + NodeDef *node_def = new NodeDef(); + OpDef *op_def = new OpDef(); + OpKernelConstruction *context = new OpKernelConstruction(DEVICE_CPU, device, nullptr, node_def, op_def, nullptr, + input_types, input_memory_types, output_types, output_memory_types, + 1, nullptr); + ResizeAndClipPolysOp cache(context); + delete device; + delete node_def; + delete op_def; + delete context; +} + +TEST(OCROpsTest, TestBatchEnqueueShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("BatchEnqueue", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("T", DT_INT32) + .Attr("batch_size", 8) + .Attr("queue_name", "TEST") + .Attr("pad_mode", "REPLICATE") + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_UINT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def,{TShape({5}), TShape({})}, {}, {}, {}); + 
TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +TEST(OCROpsTest, TestOCRRecognitionPreHandleShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("OCRRecognitionPreHandle", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("T", DT_FLOAT) + .Attr("batch_size", 8) + .Attr("data_format", "NHWC") + .Attr("pad_mode", "REPLICATE") + .Input(FakeInputStub(DT_UINT8)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_FLOAT)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def, + {TShape({3}), TShape({3}), TShape({3}), TShape({3}), TShape({3})}, {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +TEST(OCROpsTest, TestOCRDetectionPreHandleShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("OCRDetectionPreHandle", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("data_format", "NHWC") + .Input(FakeInputStub(DT_UINT8)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def, + {TShape({3})}, {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +TEST(OCROpsTest, TestOCRIdentifyPreHandleShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("OCRIdentifyPreHandle", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("size", {1,2}) + .Attr("data_format", "NHWC") + .Input(FakeInputStub(DT_UINT8)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def, + {TShape({3}), TShape({3}), TShape({3})}, {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +TEST(OCROpsTest, TestBatchDilatePolysShapeInference) { + const OpRegistrationData* reg; + 
TF_CHECK_OK(OpRegistry::Global()->LookUp("BatchDilatePolys", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def, + {TShape({1}), TShape({1}), TShape({1}), TShape({1}),TShape({1}),TShape({1}),TShape({1})},{}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +TEST(OCROpsTest, TestOCRFindContoursShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("OCRFindContours", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("value_mode", 0) + .Input(FakeInputStub(DT_UINT8)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def, + {TShape({2})},{}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +TEST(OCROpsTest, TestDequeueShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("Dequeue", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("queue_name", "TEST") + .Attr("output_type", DT_UINT8) + .Attr("output_shape", {2}) + .Input(FakeInputStub(DT_UINT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def,{TShape({})}, {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +TEST(OCROpsTest, TestOCRDetectionPostHandleShapeInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("OCRDetectionPostHandle", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Attr("data_format", "NHWC") + .Input(FakeInputStub(DT_UINT8)) + .Input(FakeInputStub(DT_INT32)) + 
.Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def,{TShape({3}), TShape({3}), TShape({3}), TShape({3})}, {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +TEST(OCROpsTest, TestResizeAndClipPolysInference) { + const OpRegistrationData* reg; + TF_CHECK_OK(OpRegistry::Global()->LookUp("ResizeAndClipPolys", ®)); + OpDef op_def = reg->op_def; + NodeDef def; + TF_CHECK_OK(NodeDefBuilder("dummy", &op_def) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_FLOAT)) + .Input(FakeInputStub(DT_INT32)) + .Input(FakeInputStub(DT_INT32)) + .Finalize(&def)); + shape_inference::InferenceContext c(0, &def, op_def,{TShape({})}, {}, {}, {}); + TF_CHECK_OK(reg->shape_inference_fn(&c)); +} + +} // namespace } // namespace tensorflow \ No newline at end of file