diff --git a/tf_adapter/ops/aicpu/npu_cpu_ops.cc b/tf_adapter/ops/aicpu/npu_cpu_ops.cc index 93c822cca5ba097b9700f4faddfd857b8d69c9a4..44471265dbe5f29966d7323ccf9aa1fa4144b4fd 100644 --- a/tf_adapter/ops/aicpu/npu_cpu_ops.cc +++ b/tf_adapter/ops/aicpu/npu_cpu_ops.cc @@ -232,7 +232,12 @@ REGISTER_OP("InitEmbeddingHashmap") .Attr("value_total_len: int = 0") .Attr("dtype: {uint8, uint16, float32} = DT_FLOAT") .Attr("embedding_dim: int = 0") - .Attr("random_alg: string = '' ") + .Attr("initializer_mode: string = '' ") + .Attr("constant_value: float = 0") + .Attr("min: float = -2") + .Attr("max: float = 2") + .Attr("mu: float = 0") + .Attr("sigma: float = 1") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn(shape_inference::NoOutputs); @@ -286,7 +291,12 @@ REGISTER_OP("EmbeddingTableFindAndInit") .Output("values: float32") .Attr("embedding_dim: int = 0") .Attr("value_total_len: int = 0") - .Attr("random_alg: string = 'random_uniform'") + .Attr("initializer_mode: string = 'random_uniform'") + .Attr("constant_value: float = 0") + .Attr("min: float = -2") + .Attr("max: float = 2") + .Attr("mu: float = 0") + .Attr("sigma: float = 1") .Attr("seed: int = 0") .Attr("seed2: int = 0") .SetShapeFn([](shape_inference::InferenceContext *c) { @@ -344,6 +354,7 @@ REGISTER_OP("EmbeddingTableExport") .Input("table_id: int32") .Attr("embedding_dim: int = 0") .Attr("value_total_len: int = 0") + .Attr("export_mode: {'all', 'old', 'new', 'specifiednew'} = 'all'") .Attr("only_var_flag: bool = false") .Attr("file_type: string = 'bin' ") .SetShapeFn(shape_inference::NoOutputs);