# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
r"""Convert raw PASCAL dataset to TFRecord for object_detection.
Example usage in shell file:
#!/usr/bin/env bash
DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )"
export PYTHONPATH=$PYTHONPATH:$DIR/../../
export PYTHONPATH=$PYTHONPATH:$DIR/../../slim
export PYTHONPATH=$PYTHONPATH:$DIR/../../object_detection
POSTFIX=""
DATA=/home/admins/data/beer_data
echo "generating train dataset ..."
python $DIR/../../object_detection/dataset_tools/create_object_detection_tf_record.py \
--data_dir $DATA \
--set train \
--postfix $POSTFIX \
--output_path $DATA/train.record \
--label_map_path $DIR/../../object_detection/data/beer.pbtxt
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import hashlib
import os
import xml.etree.ElementTree as ET
import tensorflow as tf
from object_detection.utils import dataset_util
from .utils.io import read_label_as_list
# Command-line interface (TF1 flags). All flags are consumed in main() below.
flags = tf.app.flags
# Root of the dataset; must contain the <set><postfix>.txt example list.
flags.DEFINE_string('data_dir', '',
                    'Root directory to raw PASCAL VOC dataset.')
# Suffix appended to both the example-list filename and the output record name.
flags.DEFINE_string('postfix', '', 'postfix of dataset')
flags.DEFINE_string('set', 'train', 'Convert training set, validation set.')
# NOTE(review): despite the help text, main() joins this value with a generated
# filename, so it is actually treated as an output *directory* — see main().
flags.DEFINE_string('output_path', '', 'Path to output TFRecord')
flags.DEFINE_string('label_map_path', 'data/beer.pbtxt',
                    'Path to label map proto')
# NOTE(review): this flag is declared but never read anywhere in this file.
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')
# class_num / instance are forwarded verbatim to read_label_as_list; their
# exact semantics live in .utils.io — confirm there.
flags.DEFINE_integer('class_num', 73, 'number of class')
flags.DEFINE_integer('instance', 0, 'number of instance for each class')
FLAGS = flags.FLAGS
# Valid values for --set.
SETS = ['train', 'val']
def dict_to_tf_example(xml_path, img_path, label_list, ignore_difficult=False):
    """Build a tf.train.Example from one PASCAL-VOC-style XML annotation.

    Args:
      xml_path: Path to the VOC XML annotation file.
      img_path: Path to the corresponding JPEG image file.
      label_list: List of class-name strings. The label id written to the
        record is ``label_list.index(name) + 1`` (1-based; 0 is background).
      ignore_difficult: If True, objects whose ``<difficult>`` field is 1 are
        skipped entirely. Defaults to False, preserving previous behavior.

    Returns:
      A tf.train.Example holding the encoded JPEG plus per-object boxes and
      labels, with box coordinates normalized to [0, 1].

    Raises:
      ValueError: If an object's class name is not present in label_list.
    """
    with tf.gfile.GFile(img_path, 'rb') as fid:
        encoded_jpg = fid.read()
    # SHA-256 of the raw JPEG bytes serves as a stable per-image key.
    key = hashlib.sha256(encoded_jpg).hexdigest()

    root = ET.parse(xml_path).getroot()
    size = root.find('size')
    width = int(size.find('width').text)
    height = int(size.find('height').text)

    xmin, ymin, xmax, ymax = [], [], [], []
    classes, classes_text = [], []
    # truncated/poses are written as empty lists: the XML here carries no
    # <truncated>/<pose> data, but downstream readers expect the keys.
    truncated, poses, difficult_obj = [], [], []
    for obj in root.iter('object'):
        # Some annotation tools omit <difficult>; treat a missing node as 0
        # instead of crashing with AttributeError (fixes the original).
        difficult_node = obj.find('difficult')
        difficult = (difficult_node is not None
                     and bool(int(difficult_node.text)))
        if ignore_difficult and difficult:
            continue
        difficult_obj.append(int(difficult))
        xml_box = obj.find('bndbox')
        # Normalize pixel coordinates to [0, 1], as the OD API expects.
        xmin.append(float(xml_box.find('xmin').text) / width)
        ymin.append(float(xml_box.find('ymin').text) / height)
        xmax.append(float(xml_box.find('xmax').text) / width)
        ymax.append(float(xml_box.find('ymax').text) / height)
        name = obj.find('name').text
        classes_text.append(name.encode('utf8'))
        classes.append(label_list.index(name) + 1)

    example = tf.train.Example(
        features=tf.train.Features(feature={
            'image/height':
                dataset_util.int64_feature(height),
            'image/width':
                dataset_util.int64_feature(width),
            'image/filename':
                dataset_util.bytes_feature(
                    os.path.basename(xml_path).encode('utf8')),
            'image/source_id':
                dataset_util.bytes_feature(
                    os.path.basename(img_path).encode('utf8')),
            'image/key/sha256':
                dataset_util.bytes_feature(key.encode('utf8')),
            'image/encoded':
                dataset_util.bytes_feature(encoded_jpg),
            'image/format':
                dataset_util.bytes_feature('jpeg'.encode('utf8')),
            'image/object/bbox/xmin':
                dataset_util.float_list_feature(xmin),
            'image/object/bbox/xmax':
                dataset_util.float_list_feature(xmax),
            'image/object/bbox/ymin':
                dataset_util.float_list_feature(ymin),
            'image/object/bbox/ymax':
                dataset_util.float_list_feature(ymax),
            'image/object/class/text':
                dataset_util.bytes_list_feature(classes_text),
            'image/object/class/label':
                dataset_util.int64_list_feature(classes),
            'image/object/difficult':
                dataset_util.int64_list_feature(difficult_obj),
            'image/object/truncated':
                dataset_util.int64_list_feature(truncated),
            'image/object/view':
                dataset_util.bytes_list_feature(poses),
        }))
    return example
def main(_):
    """Convert the example list for FLAGS.set into one TFRecord file.

    Reads '<data_dir>/<set><postfix>.txt', where each line is
    '<image path>&!&<xml path>', and writes every converted example to
    '<output_path>/<set>_<postfix>.record'.

    NOTE(review): FLAGS.output_path is used as a *directory* here, even though
    its help text (and the shell example in the module docstring) treat it as
    a file path. Behavior is kept as-is; confirm which is intended.

    Raises:
      ValueError: If FLAGS.set is not one of SETS.
    """
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))

    label_list = read_label_as_list(FLAGS.label_map_path, FLAGS.class_num,
                                    FLAGS.instance)
    examples_path = os.path.join(FLAGS.data_dir,
                                 FLAGS.set + '{}.txt'.format(FLAGS.postfix))
    examples_list = dataset_util.read_examples_list(examples_path)

    record_path = os.path.join(
        FLAGS.output_path, '{}_{}.record'.format(FLAGS.set, FLAGS.postfix))
    writer = tf.python_io.TFRecordWriter(record_path)
    # try/finally so a conversion failure mid-loop cannot leak the writer
    # and leave a truncated, unflushed record file handle open.
    try:
        for idx, example in enumerate(examples_list):
            if idx % 100 == 0:
                print('On image {} of {}'.format(idx, len(examples_list)))
            img_path, xml_path = example.split('&!&')
            tf_example = dict_to_tf_example(xml_path, img_path, label_list)
            writer.write(tf_example.SerializeToString())
    finally:
        writer.close()


if __name__ == '__main__':
    tf.app.run()