diff --git a/tf_adapter/kernels/geop_npu.cc b/tf_adapter/kernels/geop_npu.cc index 4767e4e43e08d97de6cdb6c146cd5182fc6edb6e..b69dcd6f0bb7e2137f9319b3b64091e1dc34693e 100644 --- a/tf_adapter/kernels/geop_npu.cc +++ b/tf_adapter/kernels/geop_npu.cc @@ -907,6 +907,33 @@ Status GeOp::BuildGraphDef(FunctionLibraryDefinition &flib_def, } graph.ToGraphDef(&graph_def); WriteTextProto(Env::Default(), function_.name() + "_v2.pbtxt", graph_def); + + for (int i = 0; i < graph_def.node_size(); ++i) { + NodeDef *node_def = graph_def.mutable_node(i); + node_def->clear_experimental_debug_info(); + } + + for (auto &name : flib_def.ListFunctionNames()) { + const FunctionDef *fdef = flib_def.Find(name); + if (fdef == nullptr) { continue; } + + FunctionDef func_def(*fdef); + for (int i = 0; i < func_def.node_def_size(); ++i) { + NodeDef *node_def = func_def.mutable_node_def(i); + node_def->clear_experimental_debug_info(); + + AttrValue &value_i = (*node_def->mutable_attr())["input_tensor_desc"]; + for (NameAttrList &func : *(value_i.mutable_list()->mutable_func())) { + (*func.mutable_attr())["serialize_shape"].set_type(DT_INT32); + } + + AttrValue &value_o = (*node_def->mutable_attr())["output_tensor_desc"]; + for (NameAttrList &func : *(value_o.mutable_list()->mutable_func())) { + (*func.mutable_attr())["serialize_shape"].set_type(DT_INT32); + } + } + flib_def.ReplaceFunction(name, func_def); + } } return Status::OK(); } @@ -963,7 +990,7 @@ Status GeOp::ChangeInputsShapeDesc() { AttrValue &output_tensor_descs = (*node_def.mutable_attr())[OUTPUT_DESC]; for (int32 i = 0; i < dynamic_shape_nodes_[0]->num_outputs(); ++i) { AttrValue attr_shape_value; - attr_shape_value.set_type(DT_INT32); + // attr_shape_value.set_type(DT_INT32); SetShapesToOutputDesc(result, i, attr_shape_value); (*output_tensor_descs.mutable_list()->mutable_func(i)->mutable_attr())[SERIALIZE_SHAPE] = attr_shape_value; } @@ -978,7 +1005,7 @@ Status GeOp::ChangeInputsShapeDesc() { NodeDef &node_def = const_cast(dynamic_shape_nodes_[i]->def()); AttrValue &output_tensor_descs = (*node_def.mutable_attr())[OUTPUT_DESC]; AttrValue attr_shape_value; - attr_shape_value.set_type(DT_INT32); + // attr_shape_value.set_type(DT_INT32); SetShapesToOutputDesc(result, i, attr_shape_value); (*output_tensor_descs.mutable_list()->mutable_func(0)->mutable_attr())[SERIALIZE_SHAPE] = attr_shape_value; } @@ -1113,7 +1140,7 @@ Status GeOp::GenerateDesc(Node *&node) { attr_datatype_value.set_i((int64_t)inputs[num]); name_attr_list.mutable_attr()->insert({SERIALIZE_DATATYPE, attr_datatype_value}); AttrValue attr_shape_value; - attr_shape_value.set_type(DT_INT32); + // attr_shape_value.set_type(DT_INT32); name_attr_list.mutable_attr()->insert({SERIALIZE_SHAPE, attr_shape_value}); *(input_tensor_descs.mutable_list()->add_func()) = name_attr_list; } @@ -1167,7 +1194,7 @@ Status GeOp::GenerateDesc(Node *&node) { // shape AttrValue attr_shape_value; - attr_shape_value.set_type(DT_INT32); + // attr_shape_value.set_type(DT_INT32); if (shape_value.has_list()) { TensorShapeProto shape_proto = shape_value.list().shape(num); for (int j = 0; j < shape_proto.dim_size(); j++) {