From d3b7d09b3c73c018122a5612cce011cf1c74f0ea Mon Sep 17 00:00:00 2001 From: Gaffey_1437 Date: Thu, 24 Dec 2020 16:18:58 +0800 Subject: [PATCH] Prelu --- tf_adapter/kernels/prelu.cc | 42 +++++++++++++++++++ tf_adapter/kernels/prelu_grad.cc | 42 +++++++++++++++++++ tf_adapter/ops/npu_ops.cc | 24 +++++++++++ .../python/npu_bridge/estimator/npu_ops.py | 9 ++++ 4 files changed, 117 insertions(+) create mode 100644 tf_adapter/kernels/prelu.cc create mode 100644 tf_adapter/kernels/prelu_grad.cc diff --git a/tf_adapter/kernels/prelu.cc b/tf_adapter/kernels/prelu.cc new file mode 100644 index 000000000..26b9a186f --- /dev/null +++ b/tf_adapter/kernels/prelu.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +class PReluOp : public OpKernel { + public: + explicit PReluOp(OpKernelConstruction *context) : OpKernel(context) {} + ~PReluOp() override = default; + void Compute(OpKernelContext *context) override { + LOG(INFO) << "PReluOp Compute, num_inputs: " << context->num_inputs(); + } + bool IsExpensive() override { return false; } +}; + +REGISTER_KERNEL_BUILDER(Name("PRelu").Device(DEVICE_CPU), PReluOp); +} // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/kernels/prelu_grad.cc b/tf_adapter/kernels/prelu_grad.cc new file mode 100644 index 000000000..8ea81991c --- /dev/null +++ b/tf_adapter/kernels/prelu_grad.cc @@ -0,0 +1,42 @@ +/* Copyright 2017 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +Copyright (C) 2019-2020. Huawei Technologies Co., Ltd. All rights reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/core/framework/op_kernel.h" + +namespace tensorflow { +class PReluGradOp : public OpKernel { + public: + explicit PReluGradOp(OpKernelConstruction *context) : OpKernel(context) {} + ~PReluGradOp() override = default; + void Compute(OpKernelContext *context) override { + LOG(INFO) << "PReluGradOp Compute, num_inputs: " << context->num_inputs(); + } + bool IsExpensive() override { return false; } +}; + +REGISTER_KERNEL_BUILDER(Name("PReluGrad").Device(DEVICE_CPU), PReluGradOp); +} // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/ops/npu_ops.cc b/tf_adapter/ops/npu_ops.cc index 234330c33..22a8676a4 100644 --- a/tf_adapter/ops/npu_ops.cc +++ b/tf_adapter/ops/npu_ops.cc @@ -214,6 +214,30 @@ REGISTER_OP("DropOutGenMask") return Status::OK(); }); +REGISTER_OP("PRelu") + .Input("x: T") + .Input("weight: T") + .Output("y: T") + .Attr("T: {float16, float32}") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->input(0)); + return Status::OK(); + }); + +REGISTER_OP("PReluGrad") + .Input("grads: float32") + .Input("features: float32") + .Input("weights: float32") + .Output("dx: float32") + .Output("da: float32") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext *c) { + c->set_output(0, c->input(1)); + c->set_output(1, c->input(1)); + return Status::OK(); + }); + REGISTER_OP("BasicLSTMCell") .Input("x: T") .Input("h: T") diff --git a/tf_adapter/python/npu_bridge/estimator/npu_ops.py b/tf_adapter/python/npu_bridge/estimator/npu_ops.py index ab88bf56c..68da4afa0 100644 --- a/tf_adapter/python/npu_bridge/estimator/npu_ops.py +++ b/tf_adapter/python/npu_bridge/estimator/npu_ops.py @@ -248,3 +248,12 @@ def adam_apply_one_with_decay_assign(input0, input1, input2, input3, input4, result = gen_npu_ops.adam_apply_one_with_decay_assign(input0, input1, input2, input3, input4, mul0_x, mul1_x, mul2_x, mul3_x, mul4_x, add2_y, name) return result + +@ops.RegisterGradient("PRelu") +def prelu_grad(op, grad): + dx, da = gen_npu_ops.prelu_grad(grad, op.inputs[1], op.inputs[2]) + return [dx, da] + +def prelu(x, weight): + return gen_npu_ops.prelu(x,weight) + -- Gitee