diff --git a/tf_adapter/ops/aicore/npu_aicore_ops.cc b/tf_adapter/ops/aicore/npu_aicore_ops.cc index 166956fcbfcad85a0e319981250a071c43517d5d..c4dfd316c1fe1e214f9eae53733d6b0d34569680 100644 --- a/tf_adapter/ops/aicore/npu_aicore_ops.cc +++ b/tf_adapter/ops/aicore/npu_aicore_ops.cc @@ -699,35 +699,23 @@ REGISTER_OP("ProdEnvMatA") REGISTER_OP("EmbeddingHashTableApplyAdamW") .Input("table_handle: int64") .Input("keys: int64") - .Input("m: T") - .Input("v: T") - .Input("beta1_power: T") - .Input("beta2_power: T") + .Input("m: Ref(T)") + .Input("v: Ref(T)") + .Input("beta1_power: Ref(T)") + .Input("beta2_power: Ref(T)") .Input("lr: T") .Input("weight_decay: T") .Input("beta1: T") .Input("beta2: T") .Input("epsilon: T") .Input("grad: T") - .Input("max_grad_norm: T") - .Output("m_output: T") - .Output("v_output: T") - .Output("beta1_power_output: T") - .Output("beta2_power_output: T") - .Output("max_grad_norm_output: T") + .Input("max_grad_norm: Ref(T)") .Attr("embedding_dim: int") .Attr("bucket_size: int") .Attr("amsgrad: bool = false") .Attr("maximize: bool = false") .Attr("T: {float16, float32}") - .SetShapeFn([](shape_inference::InferenceContext *c) { - c->set_output(0, c->input(2)); - c->set_output(1, c->input(3)); - c->set_output(2, c->input(4)); - c->set_output(3, c->input(5)); - c->set_output(4, c->input(12)); - return Status::OK(); - }); + .SetShapeFn(tensorflow::shape_inference::NoOutputs); REGISTER_OP("ProdVirialSeA") .Input("net_deriv:T")