diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 096f2447e916c7e97877262c40cb1eb2a72cc5be..1cc318f98e848478dc60718e09f241cc475619ad 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -941,6 +941,24 @@ void GeOp::SetDynamicInput() { } } +void GeOp::MakeAllInputsShapeUnknown(OpKernelContext *const ctx, bool &updated) { + ADP_LOG(INFO) << "Begin make all shape unknown"; + constexpr static int64_t kUnknownDim = -1; + const static auto kUnknownRankShape = PartialTensorShape(); + for (size_t i = 0UL; i < static_cast(ctx->num_inputs()); i++) { + auto &shape = input_shapes_vec_[i]; + auto &value_shape = ctx->input(static_cast(i)).shape(); + std::vector dims; + for (int i = 0; i < value_shape.dims(); i++) { + dims.push_back(kUnknownDim); + } + PartialTensorShape out_shape; + auto status = PartialTensorShape::MakePartialShape(dims.data(), static_cast(dims.size()), &out_shape); + shape = status.ok() ? out_shape : kUnknownRankShape; + ADP_LOG(INFO) << "Refresh input " << i << " shape to " << shape.value().DebugString(); + } +} + PartialTensorShape GeOp::MakeCompatShape(const PartialTensorShape &a, const PartialTensorShape &b) const { const static auto kUnknownRankShape = PartialTensorShape(); if (a.dims() != b.dims()) { @@ -976,7 +994,8 @@ bool GeOp::MaybeUpdateShape(OpKernelContext *const ctx) { shape = value_shape; ADP_LOG(WARNING) << "Dynamic shape, recommended to configure jit_compile value to false or auto"; } else { - shape = MakeCompatShape(shape.value(), value_shape); + MakeAllInputsShapeUnknown(ctx, updated); + return updated; } ADP_LOG(INFO) << "Refresh input " << i << " shape to " << shape.value().DebugString(); } diff --git a/tf_adapter/kernels/geop_npu.h b/tf_adapter/kernels/geop_npu.h index 86bd5f317216bd7f1ba58cdb599aa74bcd3a47e2..6bdbdc2525ef24dc11d2dd2162773348ff6cf0e9 100644 --- a/tf_adapter/kernels/geop_npu.h +++ b/tf_adapter/kernels/geop_npu.h @@ -158,6 +158,8 @@ public: bool IsDynamicConfig(); + void MakeAllInputsShapeUnknown(OpKernelContext *const ctx, bool &updated); + PartialTensorShape MakeCompatShape(const PartialTensorShape &a, const PartialTensorShape &b) const; bool MaybeUpdateShape(OpKernelContext *const ctx);