Cherrytier/object_detection

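This repository does not declare an open-source license file (LICENSE). Before use, check the project description and the upstream dependencies of its code.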
create_classification_tf_record.py 2.45 KB
Cherrytier committed on 2018-04-24 10:09 +08:00: first commit
import os
import cv2
import numpy as np
import tensorflow as tf
from .utils.io import read_text_file
flags = tf.app.flags
flags.DEFINE_string('data_dir', '',
                    'Directory containing the <set>_<postfix>.txt image list.')
flags.DEFINE_string('postfix', '', 'Postfix of the dataset list and record file names')
flags.DEFINE_string('set', 'train', 'Convert the training set or the validation set.')
flags.DEFINE_string('output_path', '', 'Directory to write the <set>_<postfix>.record TFRecord to')
flags.DEFINE_integer('image_size', 256, 'Size of the input image')
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')
flags.DEFINE_boolean('channel_mean', False, 'Whether to compute the per-channel mean value')
FLAGS = flags.FLAGS
SETS = ['train', 'val']


def main(_):
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))
    data_dir = FLAGS.data_dir
    writer = tf.python_io.TFRecordWriter(
        os.path.join(FLAGS.output_path, '{}_{}.record'.format(FLAGS.set, FLAGS.postfix)))
    examples_path = os.path.join(data_dir, '{}_{}.txt'.format(FLAGS.set, FLAGS.postfix))
    examples_list = read_text_file(examples_path)
    total = len(examples_list)
    # Running sum of per-image channel means (BGR order, as returned by cv2.imread).
    mean_sum = np.zeros(3, np.float64)
    for idx, line in enumerate(examples_list):
        if idx % 500 == 0:
            print('On image {} of {}'.format(idx, total))
        # Each line has the form "<image path>&!&<integer label>".
        img_path, label = line.strip().split('&!&')
        img = cv2.imread(img_path)
        if img is None:
            raise IOError('Could not read image: {}'.format(img_path))
        if FLAGS.channel_mean:
            mean_sum += np.mean(img, axis=(0, 1))
        # Raw uint8 pixels in row-major (height, width, channel) order.
        img_raw = img.tobytes()
        tf_example = tf.train.Example(features=tf.train.Features(feature={
            'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(label)])),
            'img_raw': tf.train.Feature(bytes_list=tf.train.BytesList(value=[img_raw])),
            'height': tf.train.Feature(int64_list=tf.train.Int64List(value=[img.shape[0]])),
            'width': tf.train.Feature(int64_list=tf.train.Int64List(value=[img.shape[1]])),
            'channel': tf.train.Feature(int64_list=tf.train.Int64List(value=[img.shape[2]]))
        }))
        writer.write(tf_example.SerializeToString())
    writer.close()
    if FLAGS.channel_mean and total > 0:
        # Average of the per-image means. This equals the mean over all pixels
        # only when every image has the same size.
        print(mean_sum / total)


if __name__ == '__main__':
    tf.app.run()
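The script reads `<set>_<postfix>.txt` from `data_dir`. Each line of that file holds one example in the form `<image path>&!&<integer label>`. The file is loaded with `read_text_file` from `.utils.io`, and that module is not shown on this page. The stdlib-only sketch below illustrates the parsing step. `read_list_file` and `parse_example_lines` are hypothetical names. The assumption that `read_text_file` returns stripped, non-empty lines is a guess, not something the repository confirms.

```python
from typing import List, Tuple

SEPARATOR = '&!&'  # field separator used in the image list file


def read_list_file(path: str) -> List[str]:
    """Hypothetical stand-in for read_text_file: stripped, non-empty lines."""
    with open(path, encoding='utf-8') as f:
        return [line.strip() for line in f if line.strip()]


def parse_example_lines(lines: List[str]) -> List[Tuple[str, int]]:
    """Split each '<path>&!&<label>' line into a (path, int label) pair."""
    examples = []
    for number, line in enumerate(lines, start=1):
        parts = line.strip().split(SEPARATOR)
        if len(parts) != 2:
            raise ValueError('line {}: expected "<path>{}<label>", got {!r}'.format(
                number, SEPARATOR, line))
        path, label = parts
        examples.append((path, int(label)))  # int() also rejects non-numeric labels
    return examples


print(parse_example_lines(['imgs/cat_001.jpg&!&0', 'imgs/dog_002.jpg&!&1']))
```

Failing loudly on a malformed line is a design choice. With a bare `split('&!&')`, a line without the separator surfaces as a less descriptive unpacking error.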
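Each record stores `img.tobytes()` next to `height`, `width` and `channel`. `cv2.imread` returns a `uint8` array with channels in BGR order, and `ndarray.tobytes()` emits it in C (row-major) order by default. A reader therefore has to reshape the flat buffer to `(height, width, channel)`. The stdlib sketch below shows the offset arithmetic for one pixel. `pixel_at` is a hypothetical helper, and the sketch assumes 8-bit pixels as produced by the script.

```python
from typing import Tuple


def pixel_at(raw: bytes, height: int, width: int, channels: int,
             y: int, x: int) -> Tuple[int, ...]:
    """Return the channel values of pixel (y, x) from a row-major HxWxC uint8 buffer."""
    if len(raw) != height * width * channels:
        raise ValueError('buffer size does not match height * width * channels')
    if not (0 <= y < height and 0 <= x < width):
        raise IndexError('pixel ({}, {}) outside {}x{} image'.format(y, x, height, width))
    offset = (y * width + x) * channels  # rows first, then columns, channels interleaved
    return tuple(raw[offset:offset + channels])


# A 2x2 BGR image: row 0 = blue, green; row 1 = red, white.
raw = bytes([255, 0, 0,   0, 255, 0,
             0, 0, 255,   255, 255, 255])
b, g, r = pixel_at(raw, 2, 2, 3, 1, 0)  # bottom-left pixel, stored as B, G, R
print((r, g, b))  # (255, 0, 0): red once reordered to RGB
```

In a TensorFlow input pipeline, the same step is usually done with `tf.decode_raw` (`tf.io.decode_raw` in later releases) followed by `tf.reshape`, using the stored `height`, `width` and `channel` features.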
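With `--channel_mean`, the script accumulates the per-image value `np.mean(img, axis=(0, 1))`. The intended result is therefore the average of per-image means, which weights every image equally. When images differ in size, this is not the same as the mean over all pixels. The sketch below computes both from per-image statistics. `channel_means` is a hypothetical helper, and its input format, `(height, width, per-image channel means)` tuples, is an assumption for illustration. Values taken from `cv2.imread` images are in BGR order.

```python
from typing import Iterable, List, Sequence, Tuple


def channel_means(stats: Iterable[Tuple[int, int, Sequence[float]]]
                  ) -> Tuple[List[float], List[float]]:
    """Return (mean of per-image means, mean over all pixels).

    Each item is (height, width, per-image channel means).
    """
    count = 0
    total_pixels = 0
    image_sum: List[float] = []
    pixel_sum: List[float] = []
    for height, width, means in stats:
        if not image_sum:
            image_sum = [0.0] * len(means)
            pixel_sum = [0.0] * len(means)
        pixels = height * width
        for c, m in enumerate(means):
            image_sum[c] += m           # every image counts once
            pixel_sum[c] += m * pixels  # every pixel counts once
        count += 1
        total_pixels += pixels
    if count == 0:
        raise ValueError('no images')
    return ([s / count for s in image_sum],
            [s / total_pixels for s in pixel_sum])


stats = [(2, 2, (10.0, 20.0, 30.0)),  # 4-pixel image
         (1, 2, (40.0, 50.0, 60.0))]  # 2-pixel image
per_image, per_pixel = channel_means(stats)
print(per_image)  # [25.0, 35.0, 45.0]
print(per_pixel)  # [20.0, 30.0, 40.0]
```

If the dataset images are all resized to one resolution beforehand, the two results coincide, and the cheaper per-image average is sufficient.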