From 06b8301097a434a755dae29c6e93c8b1d15d389f Mon Sep 17 00:00:00 2001 From: leo Date: Thu, 29 Feb 2024 11:46:32 +0800 Subject: [PATCH] fixed a31cdfc from https://gitee.com/guopeian/tensorflow/pulls/2580 tmp --- tf_adapter/kernels/geop_npu.cc | 21 ++++++++++++++++++++- tf_adapter/kernels/geop_npu.h | 2 ++ 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 096f2447e..1cc318f98 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -941,6 +941,24 @@ void GeOp::SetDynamicInput() { } } +void GeOp::MakeAllInputsShapeUnknown(OpKernelContext *const ctx, bool &updated) { + ADP_LOG(INFO) << "Begin make all shape unknown"; + constexpr static int64_t kUnknownDim = -1; + const static auto kUnknownRankShape = PartialTensorShape(); + for (size_t i = 0UL; i < static_cast(ctx->num_inputs()); i++) { + auto &shape = input_shapes_vec_[i]; + auto &value_shape = ctx->input(static_cast(i)).shape(); + std::vector dims; + for (int i = 0; i < value_shape.dims(); i++) { + dims.push_back(kUnknownDim); + } + PartialTensorShape out_shape; + auto status = PartialTensorShape::MakePartialShape(dims.data(), static_cast(dims.size()), &out_shape); + shape = status.ok() ? out_shape : kUnknownRankShape; + ADP_LOG(INFO) << "Refresh input " << i << " shape to " << shape.value().DebugString(); + } +} + PartialTensorShape GeOp::MakeCompatShape(const PartialTensorShape &a, const PartialTensorShape &b) const { const static auto kUnknownRankShape = PartialTensorShape(); if (a.dims() != b.dims()) { @@ -976,7 +994,8 @@ bool GeOp::MaybeUpdateShape(OpKernelContext *const ctx) { shape = value_shape; ADP_LOG(WARNING) << "Dynamic shape, recommended to configure jit_compile value to false or auto"; } else { - shape = MakeCompatShape(shape.value(), value_shape); + MakeAllInputsShapeUnknown(ctx, updated); + return updated; } ADP_LOG(INFO) << "Refresh input " << i << " shape to " << shape.value().DebugString(); } diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index 86bd5f317..6bdbdc252 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -158,6 +158,8 @@ public: bool IsDynamicConfig(); + void MakeAllInputsShapeUnknown(OpKernelContext *const ctx, bool &updated); + PartialTensorShape MakeCompatShape(const PartialTensorShape &a, const PartialTensorShape &b) const; bool MaybeUpdateShape(OpKernelContext *const ctx); -- Gitee