From 845aea52bf742cc9428f12cc65e85c37d04dbf14 Mon Sep 17 00:00:00 2001 From: alex Date: Thu, 17 Aug 2023 13:22:58 +0000 Subject: [PATCH] update ACL_TensorFlow/built-in/cv/Facenet_for_ACL/script/preprocess_data.py. Signed-off-by: alex --- .../Facenet_for_ACL/script/preprocess_data.py | 23 ++++++++++++++----- 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/script/preprocess_data.py b/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/script/preprocess_data.py index 0eedf0c7e..79ae01342 100644 --- a/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/script/preprocess_data.py +++ b/ACL_TensorFlow/built-in/cv/Facenet_for_ACL/script/preprocess_data.py @@ -57,7 +57,7 @@ def random_rotate_image(image): def get_control_flag(control, field): return tf.equal(tf.math.mod(tf.math.floordiv(control, field), 2), 1) -def create_input_pipeline(input_queue, image_size, nrof_preprocess_threads, batch_size_placeholder): +def create_input_pipeline(input_queue, image_size, nrof_preprocess_threads, batch_size_placeholder,fake_data=False): images_and_labels_list = [] for _ in range(nrof_preprocess_threads): filenames, label, control = input_queue.dequeue() @@ -74,10 +74,18 @@ def create_input_pipeline(input_queue, image_size, nrof_preprocess_threads, batc image = tf.cond(get_control_flag(control[0], RANDOM_FLIP), lambda:tf.image.random_flip_left_right(image), lambda:tf.identity(image)) - image = tf.cond(get_control_flag(control[0], FIXED_STANDARDIZATION), - #lambda:(tf.cast(image, tf.float32) - 127.5)/128.0, - lambda:(tf.cast(image, tf.float32) - 0)/1, - lambda:tf.cast(tf.image.per_image_standardization(image),tf.float32)) + if fake_data == False: + print("It will generate real data!") + image = tf.cond(get_control_flag(control[0], FIXED_STANDARDIZATION), + #lambda:(tf.cast(image, tf.float32) - 127.5)/128.0, + lambda:(tf.cast(image, tf.float32) - 0)/1, + lambda:tf.cast(tf.image.per_image_standardization(image),tf.float32)) + else: + print("It will generate fake data!") + image = tf.cond(get_control_flag(control[0], FIXED_STANDARDIZATION), + #lambda:(tf.cast(image, tf.float32) - 127.5)/128.0, + lambda:(tf.cast(image, tf.float32) + 20)/1, + lambda:tf.cast(tf.image.per_image_standardization(image),tf.float32)) image = tf.cond(get_control_flag(control[0], FLIP), lambda:tf.image.flip_left_right(image), lambda:tf.identity(image)) @@ -134,6 +142,7 @@ def main(args): with tf.Graph().as_default(): with tf.Session() as sess: output_path = args.output_dir + fake_data = args.fake_data # Read the file containing the pairs used for testing pairs = read_pairs(os.path.expanduser(args.lfw_pairs)) @@ -158,7 +167,7 @@ def main(args): eval_enqueue_op = eval_input_queue.enqueue_many( [image_paths_placeholder, labels_placeholder, control_placeholder], name='eval_enqueue_op') image_batch, label_batch = create_input_pipeline(eval_input_queue, image_size, nrof_preprocess_threads, - batch_size_placeholder) + batch_size_placeholder,fake_data) coord = tf.train.Coordinator() threads =tf.train.start_queue_runners(coord=coord, sess=sess) @@ -250,6 +259,8 @@ def parse_arguments(argv): help='Subtract feature mean before calculating distance.', action='store_true') parser.add_argument('--use_fixed_image_standardization', help='Performs fixed standardization of images.', action='store_true') + parser.add_argument('--fake_data', type=bool, + help='Whether generate fake datas', default=False) return parser.parse_args(argv) if __name__ == '__main__': -- Gitee