From 89f942b68abec534f62a2cce16c223ebfafafc75 Mon Sep 17 00:00:00 2001 From: yanzhi2024 Date: Sat, 9 Nov 2024 18:37:22 +0800 Subject: [PATCH] EmbeddingHashTableApplyAdamW infer shape bug fix --- tf_adapter/ops/aicore/npu_aicore_ops.cc | 24 ++++++------------------ 1 file changed, 6 insertions(+), 18 deletions(-) diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index 166956fcb..c4dfd316c 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -699,35 +699,23 @@ REGISTER_OP("ProdEnvMatA") REGISTER_OP("EmbeddingHashTableApplyAdamW") .Input("table_handle: int64") .Input("keys: int64") - .Input("m: T") - .Input("v: T") - .Input("beta1_power: T") - .Input("beta2_power: T") + .Input("m: Ref(T)") + .Input("v: Ref(T)") + .Input("beta1_power: Ref(T)") + .Input("beta2_power: Ref(T)") .Input("lr: T") .Input("weight_decay: T") .Input("beta1: T") .Input("beta2: T") .Input("epsilon: T") .Input("grad: T") - .Input("max_grad_norm: T") - .Output("m_output: T") - .Output("v_output: T") - .Output("beta1_power_output: T") - .Output("beta2_power_output: T") - .Output("max_grad_norm_output: T") + .Input("max_grad_norm: Ref(T)") .Attr("embedding_dim: int") .Attr("bucket_size: int") .Attr("amsgrad: bool = false") .Attr("maximize: bool = false") .Attr("T: {float16, float32}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - c->set_output(0, c->input(2)); - c->set_output(1, c->input(3)); - c->set_output(2, c->input(4)); - c->set_output(3, c->input(5)); - c->set_output(4, c->input(12)); - return Status::OK(); - }); + .SetShapeFn(tensorflow::shape_inference::NoOutputs); REGISTER_OP("ProdVirialSeA") .Input("net_deriv:T") -- Gitee