代码拉取完成,页面将自动刷新
import os
import cv2
import numpy as np
import tensorflow as tf
from .utils.io import read_text_file
flags = tf.app.flags

# Subsets accepted by --set; main() rejects any other value.
SETS = ['train', 'val']

flags.DEFINE_string('data_dir', '',
                    'Directory containing the <set>_<postfix>.txt example '
                    'list files.')
flags.DEFINE_string('postfix', '', 'postfix of dataset')
flags.DEFINE_string('set', 'train',
                    'Subset to convert, one of: {}.'.format(', '.join(SETS)))
flags.DEFINE_string('output_path', '',
                    'Directory in which to write the output TFRecord.')
# Declared as an integer flag so its type matches its numeric default
# (a string flag with an int default is rejected by newer absl versions).
# NOTE(review): not read by main() in this file — confirm it is used elsewhere.
flags.DEFINE_integer('image_size', 256, 'size of input image')
# NOTE(review): not read by main() in this file (apparently a leftover from a
# PASCAL VOC converter); kept so existing command lines keep parsing.
flags.DEFINE_boolean('ignore_difficult_instances', False, 'Whether to ignore '
                     'difficult instances')
flags.DEFINE_boolean('channel_mean', False,
                     'Whether to compute channel mean value')
FLAGS = flags.FLAGS
def _build_example(img, label):
    """Wrap one decoded image and its integer label in a ``tf.train.Example``.

    The pixels are stored as raw bytes (``img.tobytes()``) together with the
    ``height``/``width``/``channel`` taken from ``img.shape``. A reader needs
    those three values to reshape the bytes back into an array. The dtype is
    not stored; with ``cv2.imread``'s default flags images are 8-bit, but
    confirm this in the reader.
    """
    def int64(value):
        return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))

    height, width, channels = img.shape
    return tf.train.Example(features=tf.train.Features(feature={
        'label': int64(label),
        'img_raw': tf.train.Feature(
            bytes_list=tf.train.BytesList(value=[img.tobytes()])),
        'height': int64(height),
        'width': int64(width),
        'channel': int64(channels),
    }))


def main(_):
    """Convert an example list file into a TFRecord file.

    Reads ``<data_dir>/<set>_<postfix>.txt`` with ``read_text_file``. Each
    entry must have the form ``<image_path>&!&<integer_label>``. Every image
    is loaded with OpenCV and written as a ``tf.train.Example`` to
    ``<output_path>/<set>_<postfix>.record``.

    With ``--channel_mean``, also prints the dataset-wide per-channel mean,
    i.e. the average over all images of each image's per-channel mean.
    ``cv2.imread`` returns BGR order by default, so the printed values are
    in B, G, R order.

    Raises:
        ValueError: if ``--set`` is not in ``SETS``, an entry does not
            contain exactly one ``&!&`` separator, or a label is not an
            integer.
        IOError: if OpenCV cannot read an image.
    """
    if FLAGS.set not in SETS:
        raise ValueError('set must be in : {}'.format(SETS))

    list_path = os.path.join(
        FLAGS.data_dir, '{}_{}.txt'.format(FLAGS.set, FLAGS.postfix))
    record_path = os.path.join(
        FLAGS.output_path, '{}_{}.record'.format(FLAGS.set, FLAGS.postfix))

    examples_list = read_text_file(list_path)
    total = len(examples_list)

    # Sum of per-image channel means; divided by `total` once at the end so
    # every image (including the last partial progress chunk) is counted once.
    channel_sum = np.zeros(3, np.float64)

    # `with` guarantees the record file is closed even if an entry fails.
    with tf.python_io.TFRecordWriter(record_path) as writer:
        for idx, line in enumerate(examples_list):
            if idx % 500 == 0:
                print('On image {} of {}'.format(idx, total))
            img_path, label = line.split('&!&')
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports failure by returning None, not raising.
                raise IOError('Could not read image: {}'.format(img_path))
            if FLAGS.channel_mean:
                channel_sum += np.mean(img, axis=(0, 1))
            tf_example = _build_example(img, int(label))
            writer.write(tf_example.SerializeToString())

    if FLAGS.channel_mean:
        print(channel_sum / total if total else channel_sum)
if __name__ == '__main__':
    # tf.app.run parses the command line into FLAGS, then calls main(argv)
    # and exits with its return value. Passing main explicitly is equivalent
    # to tf.app.run's default lookup of __main__.main.
    tf.app.run(main=main)
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。