diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/README.md b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/README.md
index 2f94a656c4d7846bed369392dd3d24f3f4e5cc42..abe6c8fafb79a5ab44abb6abf51a1e4713763ee0 100644
--- a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/README.md
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/README.md
@@ -1,84 +1,99 @@
-# AdvancedEAST
+# ID0130_AdvancedEAST
 AdvancedEAST is an algorithm for text detection in scene images, based mainly on EAST: An Efficient and Accurate Scene Text Detector, with major improvements that make long-text prediction more accurate. Reference project: https://github.com/huoyijie/AdvancedEAST
 
 ## Training environment
-* python 3.7.5+
-* tensorflow-gpu 1.15.0+
-* numpy 1.14.1+
-* tqdm 4.19.7+
+* TensorFlow 1.15.0+
+* Python 3.7.0+
 
 ## Code and path description
 ```
-AdvancedEAST_ID0130_for_TensorFlow
-├── advanced_east.py                     GPU training entry
-├── cfg.py                               parameter configuration
-├── icpr                                 dataset
-│   ├── image_10000                      image files
-│   └── txt_10000                        label files
-├── demo                                 sample images
-│   ├── 001.png
-│   └── 004.png
-│   └── ...
-├── data                                 dataset
-│   └── image_test                       test image files
-├── model                                checkpoints
-├── saved_model                          saved models
-├── data_generator.py                    data generation
-├── image_util.py                        keras utilities
-├── keras_npu.py                         NPU training entry
-├── label.py                             image labeling
-├── losses.py                            loss functions
-├── network_tensorflow_changeVGG_npu.py  network definition
+AdvancedEAST_ID0130_for_ACL
 ├── nms.py                helper function used by prediction
-├── predict.py            prediction functions
+├── cfg.py                parameter configuration
+├── cfg_bank.py           parameter configuration for the 832 task
+├── advanced_east.py      GPU training entry
+├── image_util.py         keras utilities
+├── label.py              image labeling
+├── network_add_bn.py     network definition
+├── icpr                  dataset directory
+│   ├── image_10000       image files
+│   └── txt_10000         label files
+├── demo                  sample images
+│   └── 001.png
+├── image_test_bin        images converted to bin files
+├── image_test_output     msame inference results (bin files)
 ├── preprocess.py         image preprocessing
+├── image2bin.py          inference preprocessing: converts the images in image_test to bin and applies the other preprocessing steps
+├── h5_to_pb.py           freezes the h5 model into a pb
+├── atc.sh                ATC tool: pb -> om conversion command
+├── msame.sh              msame tool: om offline inference command
+├── postprocess.py        postprocessing
+├── predict.py            accuracy prediction
 ```
-## Dataset
-```
-Uses the tianchi ICPR dataset
-```
-## training
+## Dataset: tianchi ICPR MTWI 2018
 
-* tianchi ICPR dataset download
-Link: https://pan.baidu.com/s/1NSyc-cHKV3IwDo6qojIrKA (password: ye9y)
+Test set download: link: (access code: 1234)
 
-* Data preprocessing:
-```bash
- $ python3 preprocess.py
- $ python3 label.py
- ```
-* Run GPU training:
-```bash
- $ python3 advanced_east.py
-```
-* Run NPU training:
-```bash
- $ python3 keras_npu.py
+Accuracy validation link: https://tianchi.aliyun.com/competition/entrance/231685/rankingList
+
+## Image preprocessing
+```shell
+python3.7.5 preprocess.py
 ```
-* Run predict:
-```bash
- $ python3 predict.py
+
+## Convert the test set images to bin files
+
+```shell
+python3.7.5 image2bin.py
 ```
-## Accuracy validation
-* tianchi ICPR MTWI 2018
+## Model files
+Includes the initial h5 file, the frozen pb file, and the om file for inference.
+h5 model download: link: (access code: )
+pb model download: link: (access code: )
 
-Test set download: link: https://pan.baidu.com/s/1pU4TXFWfOoZxAmIeAx98dQ (access code: 1234)
+## pb model
 
-Accuracy validation link: https://tianchi.aliyun.com/competition/entrance/231685/rankingList
+Freeze the model:
+```shell
+python3.7.5 h5_to_pb.py
+```
+## Generate the om model
 
-## Model files
+When converting the model with the ATC conversion tool, the following command (atc.sh) can be used as a reference:
+```shell
+atc --model=model.pb --input_shape="input_img:1,736,736,3" --framework=3 --output=model --soc_version=Ascend310 --input_format=NHWC
+```
+See the official documentation for details on each parameter.
 
-Link: https://pan.baidu.com/s/1csf-VEwEIF-P0pArvf9lnw
-Access code: 7kru
+## Inference with the msame tool
 
-## Accuracy
+When running inference with the msame tool, the following command (msame.sh) can be used as a reference:
+```shell
+./msame --model $MODEL --input $INPUT --output $OUTPUT --outfmt BIN
+```
+Refer to https://gitee.com/ascend/tools/tree/master/msame for the msame inference tool and its usage.
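+
+The msame output can be decoded without the Keras model. A minimal sketch for loading one result back into a score map (it assumes the om model keeps the NHWC layout with a 736x736 input and this network's stride-4, 7-channel output; the file name is illustrative):
+```python
+import numpy as np
+
+# msame --outfmt BIN writes one raw float32 file per input image
+y = np.fromfile('image_test_output/001.png_output_0.bin', dtype=np.float32)
+y = y.reshape(736 // 4, 736 // 4, 7)  # (h, w, inside score + vertex code + geometry)
+print(y.shape)  # (184, 184, 7) -- the map that postprocess.py thresholds and feeds to nms
+```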
+## Run predict
+```shell
+python3.7.5 predict.py
+```
+
+## Accuracy
 
 * Paper accuracy:
 
 | Score | Precision | Recall |
 | :--------: | ---------- | ------ |
 | 0.611 | 0.809 | 0.492 |
@@ -91,17 +106,10 @@ AdvancedEAST_ID0130_for_TensorFlow
 | :--------: | ---------- | ------ |
 | 0.554 | 0.760 | 0.436 |
 
-* Ascend accuracy:
+* Ascend inference accuracy:
 
 | Score | Precision | Recall |
 | :--------: | ---------- | ------ |
-| 0.582 | 0.762 | 0.471 |
-
-
-## Performance comparison:
-
-| GPU V100 | Ascend 910 |
-| :--------: | --------|
-| 1057s/epoch | 956s/epoch |
+| 0.632 | 0.849 | 0.513 |
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/atc.sh b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/atc.sh
new file mode 100644
index 0000000000000000000000000000000000000000..dcb54cffe468a0da4c31e69e03ab0b02d68d7857
--- /dev/null
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/atc.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+atc --model=model.pb --input_shape="input_img:1,736,736,3" --framework=3 --output=./om/modelom --soc_version=Ascend310 --input_format=NHWC
+
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/losses.py b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/cfg_bank.py
similarity index 30%
rename from TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/losses.py
rename to TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/cfg_bank.py
index 90f994a2b7d22b107a24fe0bb52ef7cb0cbc8d86..c7e6f919ddad526c323472506301d564cdfbb617 100644
--- a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/losses.py
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/cfg_bank.py
@@ -26,69 +26,76 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import cfg
-import tensorflow as tf
+import os
 
 
-def quad_loss(y_true, y_pred):
-    print('y_true', y_true, 'y_pred', y_pred)
-    # loss for inside_score
-    logits = y_pred[:, :, :, :1]
-    labels = y_true[:, :, :, :1]
-    # balance positive and negative samples in an image
-    beta = 1 - tf.reduce_mean(labels)
-    # first apply sigmoid activation
-    predicts = tf.nn.sigmoid(logits)
-    # log +epsilon for stable cal
-    inside_score_loss = tf.reduce_mean(
-        -1 * (beta * labels * tf.log(predicts + cfg.epsilon) +
-              (1 - beta) * (1 - labels) * tf.log(1 - predicts + cfg.epsilon)))
-    inside_score_loss *= cfg.lambda_inside_score_loss
-    print(inside_score_loss)
-    # loss for side_vertex_code
-    vertex_logits = y_pred[:, :, :, 1:3]
-    vertex_labels = y_true[:, :, :, 1:3]
-    vertex_beta = 1 - (tf.reduce_mean(y_true[:, :, :, 1:2])
-                       / (tf.reduce_mean(labels) + cfg.epsilon))
-    vertex_predicts = tf.nn.sigmoid(vertex_logits)
-    pos = -1 * vertex_beta * vertex_labels * tf.log(vertex_predicts +
-                                                    cfg.epsilon)
-    neg = -1 * (1 - vertex_beta) * (1 - vertex_labels) * tf.log(
-        1 - vertex_predicts + cfg.epsilon)
-    positive_weights = tf.cast(tf.equal(y_true[:, :, :, 0], 1), tf.float32)
-    side_vertex_code_loss = \
-        tf.reduce_sum(tf.reduce_sum(pos + neg, axis=-1) * positive_weights) / (
-                tf.reduce_sum(positive_weights) + cfg.epsilon)
-    side_vertex_code_loss *= cfg.lambda_side_vertex_code_loss
-    print(side_vertex_code_loss)
-    # loss for side_vertex_coord delta
-    g_hat = y_pred[:, :, :, 3:]
-    g_true = y_true[:, :, :, 3:]
-    vertex_weights = tf.cast(tf.equal(y_true[:, :, :, 1], 1), tf.float32)
-    pixel_wise_smooth_l1norm = smooth_l1_loss(g_hat, g_true, vertex_weights)
-    side_vertex_coord_loss = tf.reduce_sum(pixel_wise_smooth_l1norm) / (
-            tf.reduce_sum(vertex_weights) + cfg.epsilon)
-    side_vertex_coord_loss *= cfg.lambda_side_vertex_coord_loss
-    print(side_vertex_coord_loss)
-    return inside_score_loss + side_vertex_code_loss + side_vertex_coord_loss
+train_task_id = '3T832'
+initial_epoch = 0
+epoch_num = 24
+lr = 1e-3
+decay = 5e-4
+# clipvalue = 0.5  # default 0.5, 0 means no clip
+patience = 5
+load_weights = False
+lambda_inside_score_loss = 4.0
+lambda_side_vertex_code_loss = 1.0
+lambda_side_vertex_coord_loss = 1.0
+total_img = 10000
+validation_split_ratio = 0.1
+max_train_img_size = int(train_task_id[-3:])
+max_predict_img_size = int(train_task_id[-3:])  # 2400
+assert max_train_img_size in [256, 384, 512, 640, 736, 832], \
+    'max_train_img_size must be in [256, 384, 512, 640, 736, 832]'
+if max_train_img_size == 256:
+    batch_size = 8
+elif max_train_img_size == 384:
+    batch_size = 4
+elif max_train_img_size == 512:
+    batch_size = 2
+else:
+    batch_size = 1
+steps_per_epoch = total_img * (1 - validation_split_ratio) // batch_size
+validation_steps = total_img * validation_split_ratio // batch_size
 
 
-def smooth_l1_loss(prediction_tensor, target_tensor, weights):
-    n_q = tf.reshape(quad_norm(target_tensor), tf.shape(weights))
-    diff = prediction_tensor - target_tensor
-    abs_diff = tf.abs(diff)
-    abs_diff_lt_1 = tf.less(abs_diff, 1)
-    pixel_wise_smooth_l1norm = (tf.reduce_sum(
-        tf.where(abs_diff_lt_1, 0.5 * tf.square(abs_diff), abs_diff - 0.5),
-        axis=-1) / n_q) * weights
-    return pixel_wise_smooth_l1norm
+data_dir = 'dataset/train'
+origin_image_dir_name = 'pitures/'
+origin_txt_dir_name = 'txts/'
+train_image_dir_name = 'images_%s/' % train_task_id
+train_label_dir_name = 'labels_%s/' % train_task_id
+show_gt_image_dir_name = 'show_gt_images_%s/' % train_task_id
+show_act_image_dir_name = 'show_act_images_%s/' % train_task_id
+gen_origin_img = True
+draw_gt_quad = True
+draw_act_quad = True
+val_fname = 'val_%s.txt' % train_task_id
+train_fname = 'train_%s.txt' % train_task_id
+# in the paper it's 0.3, which may be too large for this problem
+shrink_ratio = 0.2
+# pixels between 0.2 and 0.6 are side pixels
+shrink_side_ratio = 0.6
+epsilon = 1e-4
+num_channels = 3
+feature_layers_range = range(5, 1, -1)
+# feature_layers_range = range(3, 0, -1)
+feature_layers_num = len(feature_layers_range)
+# pixel_size = 4
+pixel_size = 2 ** feature_layers_range[-1]
+locked_layers = False
 
 
-def quad_norm(g_true):
-    shape = tf.shape(g_true)
-    delta_xy_matrix = tf.reshape(g_true, [-1, 2, 2])
-    diff = delta_xy_matrix[:, 0:1, :] - delta_xy_matrix[:, 1:2, :]
-    square = tf.square(diff)
-    distance = tf.sqrt(tf.reduce_sum(square, axis=-1))
-    distance *= 4.0
-    distance += cfg.epsilon
-    return tf.reshape(distance, shape[:-1])
+if not os.path.exists('model'):
+    os.mkdir('model')
+if not os.path.exists('saved_model'):
+    os.mkdir('saved_model')
+
+model_weights_path = 'model/weights_%s.{epoch:03d}-{val_loss:.3f}.h5' \
+                     % train_task_id
+saved_model_file_path = 'saved_model/east_model_%s.h5' % train_task_id
+saved_model_weights_file_path = 'saved_model/east_model_weights_%s.h5' \
+                                % train_task_id
+
+pixel_threshold = 0.7
+side_vertex_pixel_threshold = 0.7
+trunc_threshold = 0.1
+predict_cut_text_line = False
+predict_write2txt = True
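All of the derived values above key off `train_task_id`; a minimal sketch of what this configuration evaluates to, mirroring the expressions in cfg_bank.py rather than adding anything new:

```python
train_task_id = '3T832'
max_train_img_size = int(train_task_id[-3:])       # 832 -> the else branch, so batch_size = 1
batch_size = 1
steps_per_epoch = 10000 * (1 - 0.1) // batch_size  # 9000.0
validation_steps = 10000 * 0.1 // batch_size       # 1000.0
pixel_size = 2 ** range(5, 1, -1)[-1]              # range(5, 1, -1)[-1] == 2, so pixel_size == 4
print(max_train_img_size, steps_per_epoch, validation_steps, pixel_size)
```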
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/clock_wise.py b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/clock_wise.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa03f94b4381ccd0f25624ded9889fe6b6c3ff9e
--- /dev/null
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/clock_wise.py
@@ -0,0 +1,92 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from scipy.spatial import distance as dist
+import numpy as np
+import math
+
+def cos_dist(a, b):
+    if len(a) != len(b):
+        return None
+    part_up = 0.0
+    a_sq = 0.0
+    b_sq = 0.0
+    for a1, b1 in zip(a, b):
+        part_up += a1 * b1
+        a_sq += a1 ** 2
+        b_sq += b1 ** 2
+    part_down = math.sqrt(a_sq * b_sq)
+    if part_down == 0.0:
+        return None
+    else:
+        return part_up / part_down
+
+def order_points(pts):
+    xSorted = pts[np.argsort(pts[:, 0]), :]
+    leftMost = xSorted[:2, :]
+    rightMost = xSorted[2:, :]
+
+    leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
+    (tl, bl) = leftMost
+    D = dist.cdist(tl[np.newaxis], rightMost, "euclidean")[0]
+    (br, tr) = rightMost[np.argsort(D)[::-1], :]
+    return np.array([tl, tr, br, bl], dtype="int32")
+
+def order_points_quadrangle(pts):
+    xSorted = pts[np.argsort(pts[:, 0]), :]
+    leftMost = xSorted[:2, :]
+    rightMost = xSorted[2:, :]
+    leftMost = leftMost[np.argsort(leftMost[:, 1]), :]
+    (tl, bl) = leftMost
+
+    vector_0 = np.array(bl - tl)
+    vector_1 = np.array(rightMost[0] - tl)
+    vector_2 = np.array(rightMost[1] - tl)
+    angle = [np.arccos(cos_dist(vector_0, vector_1)), np.arccos(cos_dist(vector_0, vector_2))]
+    (br, tr) = rightMost[np.argsort(angle), :]
+    return np.array([tl, tr, br, bl], dtype="int32")
+
+from functools import reduce
+import operator
+import math
+def order_points_tuple(pts):
+    pts = pts.tolist()
+    coords = []
+    for elem in pts:
+        coords.append(tuple(elem))
+    center = tuple(map(operator.truediv, reduce(lambda x, y: map(operator.add, x, y), coords), [len(coords)] * 2))
+    output = sorted(coords, key=lambda coords: (-135 - math.degrees(math.atan2(*tuple(map(operator.sub, coords, center))[::-1]))) % 360, reverse=True)
+    res = []
+    for elem in output:
+        res.append(list(elem))
+    return np.array(res, dtype="int32")
+points = np.array([[54, 20], [39, 48], [117, 52], [121, 21]])
+print(order_points(points))
+pt = np.array([703, 211, 754, 283, 756, 223, 747, 212]).reshape(4, 2)
+print(order_points(pt))
+print(order_points_tuple(pt))
\ No newline at end of file
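The three helpers above return quadrilateral vertices in top-left, top-right, bottom-right, bottom-left order and differ only in how they break ties between the two right-hand points. A small usage sketch for `order_points` (the commented output follows from its distance-based tie-break):

```python
import numpy as np
from clock_wise import order_points  # assumes clock_wise.py is importable

quad = np.array([[54, 20], [39, 48], [117, 52], [121, 21]])
print(order_points(quad))
# [[ 54  20]   top-left
#  [121  21]   top-right
#  [117  52]   bottom-right
#  [ 39  48]]  bottom-left
```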
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/data_generator.py b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/h5_to_pb.py
similarity index 45%
rename from TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/data_generator.py
rename to TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/h5_to_pb.py
index 0ec219c8cdd2c19f7311f69b0f51498aef16b70a..fa6c05dad0671bfcd49dd59e101494fca10ec72e 100644
--- a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/data_generator.py
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/h5_to_pb.py
@@ -1,66 +1,68 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import os
-import numpy as np
-from tensorflow.python.keras.preprocessing import image
-from tensorflow.python.keras.applications.vgg16 import preprocess_input
-
-import cfg
-
-
-def gen(batch_size=cfg.batch_size, is_val=False):
-    img_h, img_w = cfg.max_train_img_size, cfg.max_train_img_size
-    x = np.zeros((batch_size, img_h, img_w, cfg.num_channels), dtype=np.float32)
-    pixel_num_h = img_h // cfg.pixel_size
-    pixel_num_w = img_w // cfg.pixel_size
-    y = np.zeros((batch_size, pixel_num_h, pixel_num_w, 7), dtype=np.float32)
-    if is_val:
-        with open(os.path.join(cfg.data_dir, cfg.val_fname), 'r') as f_val:
-            f_list = f_val.readlines()
-    else:
-        with open(os.path.join(cfg.data_dir, cfg.train_fname), 'r') as f_train:
-            f_list = f_train.readlines()
-    while True:
-        for i in range(batch_size):
-            # random gen an image name
-            random_img = np.random.choice(f_list)
-            img_filename = str(random_img).strip().split(',')[0]
-            # load img and img anno
-            img_path = os.path.join(cfg.data_dir,
-                                    cfg.train_image_dir_name,
-                                    img_filename)
-            img = image.load_img(img_path)
-            img = image.img_to_array(img, dtype='float32')
-            x[i] = preprocess_input(img, mode='tf')
-            gt_file = os.path.join(cfg.data_dir,
-                                   cfg.train_label_dir_name,
-                                   img_filename[:-4] + '_gt.npy')
-            y[i] = np.load(gt_file)
-
-        yield x, y
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from tensorflow.keras.models import load_model
+import tensorflow as tf
+import os
+import os.path as osp
+from tensorflow.keras import backend as K
+from network_add_bn import East
+from predict import predict_txt
+import cfg
+
+# path parameters
+input_path = './model/'
+weight_file = 'east_model_3T832.h5'
+weight_file_path = osp.join(input_path, weight_file)
+output_graph_name = weight_file[:-3] + '.pb'
+# conversion function: freeze the Keras model into a pb graph
+def h5_to_pb(h5_model, output_dir, model_name, out_prefix="output_", log_tensorboard=True):
+    if not osp.exists(output_dir):
+        os.mkdir(output_dir)
+    out_nodes = []
+    for i in range(len(h5_model.outputs)):
+        out_nodes.append(out_prefix + str(i + 1))
+        tf.identity(h5_model.outputs[i], out_prefix + str(i + 1))
+    sess = K.get_session()
+    from tensorflow.python.framework import graph_util, graph_io
+    init_graph = sess.graph.as_graph_def()
+    main_graph = graph_util.convert_variables_to_constants(sess, init_graph, out_nodes)
+    graph_io.write_graph(main_graph, output_dir, name=model_name, as_text=False)
+    if log_tensorboard:
+        from tensorflow.python.tools import import_pb_to_tensorboard
+        import_pb_to_tensorboard.import_to_tensorboard(osp.join(output_dir, model_name), output_dir)
+# output directory
+output_dir = "trans_model"
+# load the trained weights into the network
+east = East()
+east_detect = east.east_network()
+east_detect.load_weights(cfg.saved_model_weights_file_path)
+h5_model = east_detect
+h5_to_pb(h5_model, output_dir=output_dir, model_name=output_graph_name)
+print('model saved')
\ No newline at end of file
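Before running atc.sh, it can be worth confirming the node names inside the frozen graph; the sketch below uses the TF 1.15 API, assumes the .pb path that the script above would produce, and expects the 'input_img' placeholder and 'output_1' identity node that the ATC command refers to:

```python
import tensorflow as tf

# load the frozen GraphDef written by h5_to_pb.py
with tf.io.gfile.GFile('trans_model/east_model_3T832.pb', 'rb') as f:
    graph_def = tf.compat.v1.GraphDef()
    graph_def.ParseFromString(f.read())

inputs = [n.name for n in graph_def.node if n.op == 'Placeholder']
outputs = [n.name for n in graph_def.node if n.name.startswith('output_')]
print(inputs, outputs)  # expected: ['input_img'] ['output_1']
```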
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/keras_npu.py b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/image2bin.py
similarity index 31%
rename from TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/keras_npu.py
rename to TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/image2bin.py
index 945c9916004959d5caa0299d316aa511945bc5b2..291e0e8449b8ece5c5d89c53de9178e0ae2317bd 100644
--- a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/keras_npu.py
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/image2bin.py
@@ -1,94 +1,89 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import tensorflow as tf
-from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
-from tensorflow.python.keras import backend as K
-
-import cfg
-from network_tensorflow_changeVGG_npu import East
-from losses import quad_loss
-from data_generator import gen
-import argparse
-from npu_bridge.estimator.npu.npu_loss_scale_optimizer import NPULossScaleOptimizer
-from npu_bridge.estimator.npu.npu_loss_scale_manager import ExponentialUpdateLossScaleManager
-
-parser = argparse.ArgumentParser()
-parser.add_argument('--data_path', type=str, default='', help='data path')
-parser.add_argument('--epochs', type=int, default=24, help='epochs')
-parser.add_argument('--steps_per_epoch', type=int, default=9000, help='steps_per_epoch')
-parser.add_argument('--validation_steps', type=int, default=1000, help='validation_steps')
-
-args = parser.parse_args()
-cfg.data_dir = args.data_path
-
-from npu_bridge.npu_init import *
-
-# session config
-sess_config = tf.ConfigProto()
-custom_op = sess_config.graph_options.rewrite_options.custom_optimizers.add()
-custom_op.name = "NpuOptimizer"
-sess_config.graph_options.rewrite_options.remapping = RewriterConfig.OFF
-sess_config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF
-# custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
-
-sess = tf.Session(config=sess_config)
-K.set_session(sess)
-
-east = East()
-east_network = east.east_network()
-east_network.summary()
-
-opt_tmp = tf.compat.v1.train.AdamOptimizer(learning_rate=cfg.lr)
-loss_scale_manager = ExponentialUpdateLossScaleManager(init_loss_scale=65536, incr_every_n_steps=1000,
-                                                       decr_every_n_nan_or_inf=2, decr_ratio=0.5)
-opt = NPULossScaleOptimizer(opt_tmp, loss_scale_manager)
-east_network.compile(loss=quad_loss, optimizer=opt)
-
-# if cfg.load_weights and os.path.exists(cfg.saved_model_weights_file_path):
-#     east_network.load_weights(cfg.saved_model_weights_file_path)
-#     print('load model')
-east_network.fit_generator(generator=gen(),
-                           steps_per_epoch=int(args.steps_per_epoch),
-                           epochs=args.epochs,
-                           validation_data=gen(is_val=True),
-                           validation_steps=int(args.validation_steps),
-                           verbose=1,
-                           initial_epoch=cfg.initial_epoch,
-                           callbacks=[
-                               EarlyStopping(patience=cfg.patience, verbose=1),
-                               ModelCheckpoint(filepath=cfg.model_weights_path,
-                                               save_best_only=True,
-                                               save_weights_only=True,
-                                               verbose=1)])
-print('Train Success')
-east_network.save(cfg.saved_model_file_path)
-east_network.save_weights(cfg.saved_model_weights_file_path)
-
-
-sess.close()
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import cfg
+import argparse
+import os
+import numpy as np
+from PIL import Image, ImageDraw
+from tensorflow.keras.applications.vgg16 import preprocess_input
+from tensorflow.keras.preprocessing import image
+
+def resize_image(im, max_img_size=cfg.max_train_img_size):
+    im_width = np.minimum(im.width, max_img_size)
+    if im_width == max_img_size < im.width:
+        im_height = int((im_width / im.width) * im.height)
+    else:
+        im_height = im.height
+    o_height = np.minimum(im_height, max_img_size)
+    if o_height == max_img_size < im_height:
+        o_width = int((o_height / im_height) * im_width)
+    else:
+        o_width = im_width
+    d_wight = o_width - (o_width % 32)
+    d_height = o_height - (o_height % 32)
+    return d_wight, d_height
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--src_path", default="../image_test1", help="path of original pictures")
+    parser.add_argument("--dst_path", default="../image_test_bin", help="path of output bin files")
+    parser.add_argument("--pic_num", default=10000, help="picture number")
+    args = parser.parse_args()
+    src_path = args.src_path
+    dst_path = args.dst_path
+    pic_num = args.pic_num
+    files = os.listdir(src_path)
+    files.sort()
+    n = 0
+    for file in files:
+        src = src_path + "/" + file
+        print("start to process %s" % src)
+        img = image.load_img(src)
+        d_wight, d_height = resize_image(img, cfg.max_predict_img_size)
+        img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
+        img = image.img_to_array(img, dtype=np.float32)
+        print(img.shape)
+        print(d_wight, d_height)
+        if d_height != 736:
+            zero = np.zeros((736 - d_height, d_wight, 3), dtype=np.float32)
+            img = np.concatenate((img, zero), axis=0)
+            print(img.shape)
+        if d_wight != 736:
+            zero = np.zeros((736, 736 - d_wight, 3), dtype=np.float32)
+            img = np.concatenate((img, zero), axis=1)
+            print('img.shape', img.shape)
+        img = preprocess_input(img, mode='tf')
+        x = np.expand_dims(img, axis=0)
+        print('x.shape', x.shape)
+        print(x.dtype)
+        x.tofile(dst_path + "/" + file + ".bin")
+        n += 1
+        if int(pic_num) == n:
+            break
\ No newline at end of file
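A quick sanity check on the generated inputs before running msame; this sketch assumes every bin holds a single 736x736x3 float32 tensor, matching the "input_img:1,736,736,3" shape given to ATC:

```python
import os

import numpy as np

bin_dir = 'image_test_bin'          # wherever --dst_path pointed (assumed)
expected = 736 * 736 * 3 * 4        # float32 NHWC image -> 6,500,352 bytes
for name in sorted(os.listdir(bin_dir)):
    path = os.path.join(bin_dir, name)
    assert os.path.getsize(path) == expected, '%s: unexpected file size' % name
    x = np.fromfile(path, dtype=np.float32)
    # preprocess_input(mode='tf') maps pixel values into [-1, 1]
    assert x.min() >= -1.0 and x.max() <= 1.0, '%s: unexpected value range' % name
print('all input bins look consistent')
```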
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/msame.sh b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/msame.sh
new file mode 100644
index 0000000000000000000000000000000000000000..fd27404fbb15b18261c20ab8290bae5e87801068
--- /dev/null
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/msame.sh
@@ -0,0 +1,7 @@
+#!/bin/bash
+MODEL="/home/HwHiAiUser/11my310/ACL/modelom.om"
+INPUT="/home/HwHiAiUser/11my310/image_test_bin/"
+OUTPUT="/home/HwHiAiUser/11my310/image_test_output"
+./msame --model $MODEL --input $INPUT --output $OUTPUT --outfmt BIN
+
+
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/network_tensorflow_changeVGG_npu.py b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/network_add_bn.py
similarity index 100%
rename from TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/network_tensorflow_changeVGG_npu.py
rename to TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/network_add_bn.py
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/postprocess.py b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/postprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..5281ef160d08d09c1f326ab9b52a3deea7cef0ca
--- /dev/null
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/postprocess.py
@@ -0,0 +1,178 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import numpy as np
+from PIL import Image, ImageDraw
+from tensorflow.keras.preprocessing import image
+from tensorflow.keras.applications.vgg16 import preprocess_input
+
+import cfg_bank as cfg
+from label import point_inside_of_quad
+from network_add_bn import East
+from preprocess import resize_image
+from nms import nms
+from clock_wise import *
+
+
+def sigmoid(x):
+    """`y = 1 / (1 + exp(-x))`"""
+    return 1 / (1 + np.exp(-x))
+
+def cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array, img_path, s):
+    geo /= [scale_ratio_w, scale_ratio_h]
+    p_min = np.amin(geo, axis=0)
+    p_max = np.amax(geo, axis=0)
+    min_xy = p_min.astype(int)
+    max_xy = p_max.astype(int) + 2
+    sub_im_arr = im_array[min_xy[1]:max_xy[1], min_xy[0]:max_xy[0], :].copy()
+    for m in range(min_xy[1], max_xy[1]):
+        for n in range(min_xy[0], max_xy[0]):
+            if not point_inside_of_quad(n, m, geo, p_min, p_max):
+                sub_im_arr[m - min_xy[1], n - min_xy[0], :] = 255
+    sub_im = image.array_to_img(sub_im_arr, scale=False)
+    sub_im.save(img_path + '_subim%d.jpg' % s)
+
+def predict(east_detect, img_path, pixel_threshold, quiet=False):
+    img = image.load_img(img_path)
+    d_weight, d_height = resize_image(img, cfg.max_predict_img_size)
+    img = img.resize((d_weight, d_height), Image.NEAREST).convert('RGB')
+    img = image.img_to_array(img)
+    # img = img/255
+    img = preprocess_input(img, mode='tf')
+    x = np.expand_dims(img, axis=0)
+    y = east_detect.predict(x)
+
+    y = np.squeeze(y, axis=0)
+    y[:, :, :3] = sigmoid(y[:, :, :3])
+    cond = np.greater_equal(y[:, :, 0], pixel_threshold)
+    activation_pixels = np.where(cond)
+    quad_scores, quad_after_nms = nms(y, activation_pixels)
+    with Image.open(img_path) as im:
+        im_array = image.img_to_array(im.convert('RGB'))
+        d_wight, d_height = resize_image(im, cfg.max_predict_img_size)
+        scale_ratio_w = d_wight / im.width
+        scale_ratio_h = d_height / im.height
+        im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB')
+        quad_im = im.copy()
+        draw = ImageDraw.Draw(im)
+        for i, j in zip(activation_pixels[0], activation_pixels[1]):
+            px = (j + 0.5) * cfg.pixel_size
+            py = (i + 0.5) * cfg.pixel_size
+            line_width, line_color = 1, 'red'
+            if y[i, j, 1] >= cfg.side_vertex_pixel_threshold:
+                if y[i, j, 2] < cfg.trunc_threshold:
+                    line_width, line_color = 2, 'yellow'
+                elif y[i, j, 2] >= 1 - cfg.trunc_threshold:
+                    line_width, line_color = 2, 'green'
+            draw.line([(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
+                       (px + 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size),
+                       (px + 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
+                       (px - 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size),
+                       (px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size)],
+                      width=line_width, fill=line_color)
+        im.save(img_path + '_act.jpg')
+        quad_draw = ImageDraw.Draw(quad_im)
+        txt_items = []
+        for score, geo, s in zip(quad_scores, quad_after_nms,
+                                 range(len(quad_scores))):
+            if np.amin(score) > 0:
+                quad_draw.line([tuple(geo[0]),
+                                tuple(geo[1]),
+                                tuple(geo[2]),
+                                tuple(geo[3]),
+                                tuple(geo[0])], width=2, fill='red')
+                if cfg.predict_cut_text_line:
+                    cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array,
+                                  img_path, s)
+                rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
+                rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
+                txt_item = ','.join(map(str, rescaled_geo_list))
+                txt_items.append(txt_item + '\n')
+            elif not quiet:
+                print('quad invalid with vertex num less than 4.')
+        quad_im.save(img_path + '_predict.jpg')
+        if cfg.predict_write2txt and len(txt_items) > 0:
+            with open(img_path[:-4] + '.txt', 'w') as f_txt:
+                f_txt.writelines(txt_items)
+
+def predict_txt(east_detect, img_path, txt_path, pixel_threshold, quiet=False):
+    img = image.load_img(img_path)
+    d_weight, d_height = cfg.max_predict_img_size, cfg.max_predict_img_size
+    scale_ratio_w = d_weight / img.width
+    scale_ratio_h = d_height / img.height
+    img = img.resize((d_weight, d_height), Image.NEAREST).convert('RGB')
+    img = image.img_to_array(img)
+    img = preprocess_input(img, mode='tf')
+    x = np.expand_dims(img, axis=0)
+    y = east_detect.predict(x)
+    y = np.squeeze(y, axis=0)
+    y[:, :, :3] = sigmoid(y[:, :, :3])
+    cond = np.greater_equal(y[:, :, 0], pixel_threshold)
+    activation_pixels = np.where(cond)
+    quad_scores, quad_after_nms = nms(y, activation_pixels)
+
+    txt_items = []
+    for score, geo in zip(quad_scores, quad_after_nms):
+        if np.amin(score) > 0:
+            rescaled_geo = geo / [scale_ratio_w, scale_ratio_h]
+            rescaled_geo = np.round(rescaled_geo, decimals=0).astype(np.int32)
+            rescaled_geo = order_points_tuple(rescaled_geo)
+            rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist()
+            txt_item = ','.join(map(str, rescaled_geo_list))
+            txt_items.append(txt_item + '\n')
+        elif not quiet:
+            print('quad invalid with vertex num less than 4.')
+    if cfg.predict_write2txt and len(txt_items) > 0:
+        with open(txt_path, 'w') as f_txt:
+            f_txt.writelines(txt_items)
+
+def parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--path', '-p',
+                        default='demo/012.png',
+                        help='image path')
+    parser.add_argument('--threshold', '-t',
+                        default=cfg.pixel_threshold,
+                        help='pixel activation threshold')
+    return parser.parse_args()
+
+
+if __name__ == '__main__':
+    args = parse_args()
+    img_path = args.path
+    threshold = float(args.threshold)
+    print(img_path, threshold)
+
+    east = East()
+    east_detect = east.east_network()
+    east_detect.load_weights(cfg.saved_model_weights_file_path)
+    predict(east_detect, img_path, threshold)
+    east_detect.save("saved_model/AdvancedEast_model.h5")
+
\ No newline at end of file
diff --git a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/predict.py b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/predict.py
index 52f5fc822986196b45523e1abf576de2f0d0b2f0..4fcc92e45a2983589a07eb101ee33c835d065349 100644
--- a/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/predict.py
+++ b/TensorFlow/contrib/cv/AdvancedEast_ID0130_for_TensorFlow/predict.py
@@ -1,201 +1,50 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import argparse - -import numpy as np -from PIL import Image, ImageDraw -from tensorflow.keras.applications.vgg16 import preprocess_input -from tensorflow.keras.preprocessing import image - -import cfg -from label import point_inside_of_quad -from network_tensorflow_changeVGG_npu import East -# from network import East -from nms import nms -from preprocess import resize_image - - -def sigmoid(x): - """`y = 1 / (1 + exp(-x))`""" - return 1 / (1 + np.exp(-x)) - - -def cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array, img_path, s): - geo /= [scale_ratio_w, scale_ratio_h] - p_min = np.amin(geo, axis=0) - p_max = np.amax(geo, axis=0) - min_xy = p_min.astype(int) - max_xy = p_max.astype(int) + 2 - sub_im_arr = im_array[min_xy[1]:max_xy[1], min_xy[0]:max_xy[0], :].copy() - for m in range(min_xy[1], max_xy[1]): - for n in range(min_xy[0], max_xy[0]): - if not point_inside_of_quad(n, m, geo, p_min, p_max): - sub_im_arr[m - min_xy[1], n - min_xy[0], :] = 255 - sub_im = image.array_to_img(sub_im_arr, scale=False) - sub_im.save(img_path + '_subim%d.jpg' % s) - - -def predict(east_detect, img_path, pixel_threshold, quiet=False): - img = image.load_img(img_path) - d_wight, d_height = resize_image(img, cfg.max_predict_img_size) - img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB') - img = image.img_to_array(img) - img = preprocess_input(img, mode='tf') - x = np.expand_dims(img, axis=0) - print('x.shape', x.shape) - y = east_detect.predict(x) - - y = np.squeeze(y, axis=0) - print('y.shape:', y.shape) - y[:, :, :3] = sigmoid(y[:, :, :3]) - cond = np.greater_equal(y[:, :, 0], pixel_threshold) - activation_pixels = np.where(cond) - quad_scores, quad_after_nms = nms(y, activation_pixels) - # print('quad_scores:',quad_scores,'quad_after_nms', quad_after_nms) - with Image.open(img_path) as im: - im_array = image.img_to_array(im.convert('RGB')) - d_wight, d_height = resize_image(im, cfg.max_predict_img_size) - scale_ratio_w = d_wight / im.width - scale_ratio_h = d_height / im.height - im = im.resize((d_wight, d_height), Image.NEAREST).convert('RGB') - quad_im = im.copy() - draw = ImageDraw.Draw(im) - for i, j in zip(activation_pixels[0], activation_pixels[1]): - px = (j + 0.5) * cfg.pixel_size - py = (i + 0.5) * cfg.pixel_size - line_width, line_color = 1, 'red' - if y[i, j, 1] >= cfg.side_vertex_pixel_threshold: - if y[i, j, 2] < cfg.trunc_threshold: - line_width, line_color = 2, 'yellow' - elif y[i, j, 2] >= 1 - cfg.trunc_threshold: - line_width, line_color = 2, 'green' - draw.line([(px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size), - (px + 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size), - (px + 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size), - (px - 0.5 * cfg.pixel_size, py + 0.5 * cfg.pixel_size), - (px - 0.5 * cfg.pixel_size, py - 0.5 * cfg.pixel_size)], - width=line_width, fill=line_color) - im.save(img_path + '_act.jpg') - quad_draw = ImageDraw.Draw(quad_im) - txt_items = [] - for score, geo, s in zip(quad_scores, quad_after_nms, - range(len(quad_scores))): - if np.amin(score) > 0: - quad_draw.line([tuple(geo[0]), - tuple(geo[1]), - tuple(geo[2]), - 
tuple(geo[3]), - tuple(geo[0])], width=2, fill='red') - if cfg.predict_cut_text_line: - cut_text_line(geo, scale_ratio_w, scale_ratio_h, im_array, - img_path, s) - rescaled_geo = geo / [scale_ratio_w, scale_ratio_h] - rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist() - txt_item = ','.join(map(str, rescaled_geo_list)) - txt_items.append(txt_item + '\n') - elif not quiet: - print('quad invalid with vertex num less then 4.') - quad_im.save(img_path + '_predict.jpg') - if cfg.predict_write2txt and len(txt_items) > 0: - with open(img_path[:-4] + '.txt', 'w') as f_txt: - f_txt.writelines(txt_items) - - -def predict_txt(east_detect, img_path, txt_path, pixel_threshold, quiet=False): - img = image.load_img(img_path) - d_wight, d_height = resize_image(img, cfg.max_predict_img_size) - scale_ratio_w = d_wight / img.width - scale_ratio_h = d_height / img.height - img = img.resize((d_wight, d_height), Image.NEAREST).convert('RGB') - img = image.img_to_array(img) - img = preprocess_input(img, mode='tf') - x = np.expand_dims(img, axis=0) - y = east_detect.predict(x) - - y = np.squeeze(y, axis=0) - y[:, :, :3] = sigmoid(y[:, :, :3]) - cond = np.greater_equal(y[:, :, 0], pixel_threshold) - activation_pixels = np.where(cond) - quad_scores, quad_after_nms = nms(y, activation_pixels) - print(quad_scores, quad_after_nms) - - txt_items = [] - for score, geo in zip(quad_scores, quad_after_nms): - if np.amin(score) > 0: - rescaled_geo = geo / [scale_ratio_w, scale_ratio_h] - rescaled_geo_list = np.reshape(rescaled_geo, (8,)).tolist() - txt_item = ','.join(map(str, rescaled_geo_list)) - txt_items.append(txt_item + '\n') - elif not quiet: - print('quad invalid with vertex num less then 4.') - if cfg.predict_write2txt and len(txt_items) > 0: - with open(txt_path, 'w') as f_txt: - f_txt.writelines(txt_items) - - -# def parse_args(): -# parser = argparse.ArgumentParser() -# parser.add_argument('--path', '-p', -# default='demo/012.png', -# help='image path') -# parser.add_argument('--threshold', '-t', -# default=cfg.pixel_threshold, -# help='pixel activation threshold') -# return parser.parse_args() - - -import os - -if __name__ == '__main__': - parser = argparse.ArgumentParser() - parser.add_argument('--model', - default='model/weights.h5', - help='model path') - parser.add_argument('--testdata', - default='data/image_test', - help='data path') - parser.add_argument('--threshold', '-t', - default=cfg.pixel_threshold, - help='pixel activation threshold') - args = parser.parse_args() - east = East() - east_detect = east.east_network() - east_detect.load_weights(args.model) - # east_detect.load_weights('model/east_model_weights_%s.h5' \ - # % '3T736') - path = args.testdata - filelist = os.listdir(path) - print(len(filelist)) - for i in range(len(filelist)): - s = path + '/' + filelist[i] - print(i, ':', s) - img_path = s - threshold = float(args.threshold) - print(img_path, threshold) - predict(east_detect, img_path, threshold, quiet=True) +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import os
+from tqdm import tqdm
+from network_add_bn import East
+from postprocess import predict_txt
+import cfg
+
+if __name__ == '__main__':
+    east = East()
+    east_detect = east.east_network()
+    east_detect.load_weights("./model.h5")
+    image_test_dir = os.path.join(cfg.data_dir, 'rename_images/')
+    txt_test_dir = os.path.join(cfg.data_dir, 'txt_test')
+    test_imgname_list = os.listdir(image_test_dir)
+    test_imgname_list = sorted(test_imgname_list)
+    print('found %d test images.' % len(test_imgname_list))
+    for test_img_name, _ in zip(test_imgname_list,
+                                tqdm(range(len(test_imgname_list)))):
+        img_path = os.path.join(image_test_dir, test_img_name)
+        txt_path = os.path.join(txt_test_dir, 'res_' + test_img_name[:-4] + '.txt')
+        predict_txt(east_detect, img_path, txt_path, cfg.pixel_threshold, True)
+