diff --git a/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/README.md b/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/README.md index 068560664a5deab6fd3d45bfb26026e81a2e5baa..4b581197385b3ff6f90f2d11107275d2ccf33b53 100644 --- a/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/README.md +++ b/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/README.md @@ -35,6 +35,74 @@ cd Modelzoo-TensorFlow/ACL_TensorFlow/built-in/cv/Facenet_for_ACL ``` + +原始pb模型修改输入适配: + +github链接中原始的pb模型,输入int型batchsize和bool型phase_train来控制模型走训练还是推理分支,以及batch,通过FIFO队列来读取文件,组装后送到模型中推理pb包含训练bn和dropout分支对推理场景没用,所以裁剪模型如下: + +phase_train改为const类型,固定值为false,删除前面的FIFO节点,改造input节点为placeholder,接受输入数据,其他不变。 +具体操作如下: + +1.使用pb转pbtxt脚本 +python3 pb_to_pbtxt.py 20180408-102900.pb + +2.将protobuf.pbtxt编辑: + +1.删除第一个节点batch_size + +2.第二个节点phase_train的op修改为const + +3.删除第三个节点batch_join/fifo_queue + +4.删除第四个节点batch_join + +5.删除第五个节点image_batch + +6.修改input节点为: + +node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + dim { + size: -1 + } + dim { + size: 160 + } + dim { + size: 160 + } + dim { + size: 3 + } + } + } + } +} + +7.删除第七个节点 label_batch + + +3.保存后转为pb模型 + +python3 pb_to_pbtxt.py protobuf.pbtxt + +修改模型名字 mv protobuf.pb facenet_tf.pb + + + + + ### 3. 离线推理 diff --git a/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/pb_to_pbtxt.py b/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/pb_to_pbtxt.py new file mode 100644 index 0000000000000000000000000000000000000000..81e3ff5d61f4546a547f41025baa57b0b71b02bb --- /dev/null +++ b/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/pb_to_pbtxt.py @@ -0,0 +1,27 @@ +import tensorflow as tf +from tensorflow.python.platform import gfile +from google.protobuf import text_format +import sys + +def convert_pb_to_pbtxt(filename): + with gfile.FastGFile(filename, 'rb') as f: + graph_def = tf.GraphDef() + graph_def.ParseFromString(f.read()) + tf.import_graph_def(graph_def, name='') + tf.train.write_graph(graph_def, './', 'protobuf.pbtxt', as_text=True) + +def convert_pbtxt_to_pb(filename): + with tf.gfile.FastGFile(filename, 'r') as f: + graph_def = tf.GraphDef() + file_content = f.read() + # Merges the human-readable string in `file_content` into `graph_def`. + text_format.Merge(file_content, graph_def) + tf.train.write_graph(graph_def, './', 'protobuf.pb', as_text=False) + +filepath = sys.argv[1] +if filepath.endswith(".pbtxt"): + convert_pbtxt_to_pb(filepath) +elif filepath.endswith(".pb"): + convert_pb_to_pbtxt(filepath) +else: + print("Error! Please set the file path of pb or pbtxt!") \ No newline at end of file