From 531de383f0efe55b32c40c4f9a2509f8091aa6ee Mon Sep 17 00:00:00 2001 From: koala_zhang <571700104@qq.com> Date: Mon, 8 Mar 2021 20:48:17 +0800 Subject: [PATCH 1/4] add DropOutGenMaskV3 --- tf_adapter/kernels/dropout_ops.cc | 1 + tf_adapter/ops/npu_ops.cc | 32 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/tf_adapter/kernels/dropout_ops.cc b/tf_adapter/kernels/dropout_ops.cc index 7f58b314b..fd37943b7 100644 --- a/tf_adapter/kernels/dropout_ops.cc +++ b/tf_adapter/kernels/dropout_ops.cc @@ -47,5 +47,6 @@ class DropOutGenMaskOp : public OpKernel { }; REGISTER_KERNEL_BUILDER(Name("DropOutGenMask").Device(DEVICE_CPU), DropOutGenMaskOp); +REGISTER_KERNEL_BUILDER(Name("DropOutGenMaskV3").Device(DEVICE_CPU), DropOutGenMaskOp); REGISTER_KERNEL_BUILDER(Name("DropOutDoMask").Device(DEVICE_CPU), DropOutDoMaskOp); } // namespace tensorflow \ No newline at end of file diff --git a/tf_adapter/ops/npu_ops.cc b/tf_adapter/ops/npu_ops.cc index 234330c33..04b6386ea 100644 --- a/tf_adapter/ops/npu_ops.cc +++ b/tf_adapter/ops/npu_ops.cc @@ -214,6 +214,38 @@ REGISTER_OP("DropOutGenMask") return Status::OK(); }); +REGISTER_OP("DropOutGenMaskV3") + .Input("shape: T") + .Attr("T: {int64, int32}") + .Input("prob: S") + .Attr("S: {float, half}") + .Output("output: uint8") + .Attr("seed: int = 0") + .Attr("seed2: int = 0") + .SetIsStateful() + .SetShapeFn([](shape_inference::InferenceContext *c) { + ShapeHandle unused; + TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused)); // prob must be 0-d + ShapeHandle input_shape_handle; + TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &input_shape_handle)); + if(!c->FulllyDefined(input_shape_handle)) { + ShapeHandle out = c->UnknowShapeOfRank(); + c->set_output(0, out); + return Status::OK(); + } + DimensionHandle input_dim_handle = c->NumElement(&input_shape_handle); + uint64 random_count = static_cast(c->Value(input_dim_handle)); + if(random_count > (INT64 - 15)) { + return errors::InvalidArgument("Required random count[", random_count, + "] exceed INT64 - 15"); + } + // align to 16 + random_count = (random_count + 15) & (~15); + ShapeHandle out = c->Vector(static_cast(random_count)); + c->set_output(0, out); + return Status::OK(); + }); + REGISTER_OP("BasicLSTMCell") .Input("x: T") .Input("h: T") -- Gitee From 9622723344d339962e2f56a880187dd0138bc46f Mon Sep 17 00:00:00 2001 From: koala_zhang <571700104@qq.com> Date: Tue, 9 Mar 2021 14:57:32 +0800 Subject: [PATCH 2/4] DropOutGenMaskV3 --- tf_adapter/ops/npu_ops.cc | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tf_adapter/ops/npu_ops.cc b/tf_adapter/ops/npu_ops.cc index 04b6386ea..f799db8d8 100644 --- a/tf_adapter/ops/npu_ops.cc +++ b/tf_adapter/ops/npu_ops.cc @@ -233,11 +233,11 @@ REGISTER_OP("DropOutGenMaskV3") c->set_output(0, out); return Status::OK(); } - DimensionHandle input_dim_handle = c->NumElement(&input_shape_handle); + DimensionHandle input_dim_handle = c->NumElements(input_shape_handle); uint64 random_count = static_cast(c->Value(input_dim_handle)); - if(random_count > (INT64 - 15)) { + if(random_count > (INT64_MAX - 15)) { return errors::InvalidArgument("Required random count[", random_count, - "] exceed INT64 - 15"); + "] exceed INT64_MAX - 15"); } // align to 16 random_count = (random_count + 15) & (~15); -- Gitee From acde92e90c33ee8518ef72ff0a2b3ea31e2b0b77 Mon Sep 17 00:00:00 2001 From: koala_zhang <571700104@qq.com> Date: Tue, 9 Mar 2021 15:14:50 +0800 Subject: [PATCH 3/4] DropOutGenMaskV3 --- tf_adapter/ops/npu_ops.cc | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tf_adapter/ops/npu_ops.cc b/tf_adapter/ops/npu_ops.cc index f799db8d8..e22f091b4 100644 --- a/tf_adapter/ops/npu_ops.cc +++ b/tf_adapter/ops/npu_ops.cc @@ -228,8 +228,8 @@ REGISTER_OP("DropOutGenMaskV3") TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(1), 0, &unused)); // prob must be 0-d ShapeHandle input_shape_handle; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &input_shape_handle)); - if(!c->FulllyDefined(input_shape_handle)) { - ShapeHandle out = c->UnknowShapeOfRank(); + if(!c->FullyDefined(input_shape_handle)) { + ShapeHandle out = c->UnknownShapeOfRank(); c->set_output(0, out); return Status::OK(); } -- Gitee From 105315c053a4bbc606e9d3f80c1dab2888c5e231 Mon Sep 17 00:00:00 2001 From: koala_zhang <571700104@qq.com> Date: Tue, 9 Mar 2021 15:20:19 +0800 Subject: [PATCH 4/4] DropOutGenMaskV3 --- tf_adapter/ops/npu_ops.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tf_adapter/ops/npu_ops.cc b/tf_adapter/ops/npu_ops.cc index e22f091b4..de5544cd2 100644 --- a/tf_adapter/ops/npu_ops.cc +++ b/tf_adapter/ops/npu_ops.cc @@ -229,7 +229,7 @@ REGISTER_OP("DropOutGenMaskV3") ShapeHandle input_shape_handle; TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &input_shape_handle)); if(!c->FullyDefined(input_shape_handle)) { - ShapeHandle out = c->UnknownShapeOfRank(); + ShapeHandle out = c->UnknownShapeOfRank(1); c->set_output(0, out); return Status::OK(); } -- Gitee