From b70540ae62ffe64de4490f40b491dc9309b0ca58 Mon Sep 17 00:00:00 2001 From: KangGrandesty Date: Mon, 5 Mar 2018 10:09:19 +0800 Subject: [PATCH] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E4=B8=80=E4=BA=9B=E5=9F=BA?= =?UTF-8?q?=E7=A1=80=E5=86=85=E5=AE=B9?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 改变了计算均值的方法; 2. 添加了 requirements 文件; 3. 添加一些用用的工具函数. --- README.md | 41 +++++++++ data/base_record_generator.py | 24 ++++-- requirements.txt | 5 ++ utils/basic.py | 94 +++++++++++++++++++++ utils/image.py | 85 +++++++++++++++++++ utils/io.py | 155 ++++++++++++++++++++++++++++++++++ utils/tf_model.py | 135 +++++++++++++++++++++++++++++ 7 files changed, 531 insertions(+), 8 deletions(-) create mode 100644 requirements.txt create mode 100644 utils/basic.py create mode 100644 utils/image.py create mode 100644 utils/io.py create mode 100644 utils/tf_model.py diff --git a/README.md b/README.md index 0e58aac..872898b 100644 --- a/README.md +++ b/README.md @@ -1 +1,42 @@ # tensorflow_basic_template + +我们开发一个可以作为基本工具包的项目,为一些深度学习项目的之后的开发提供统一的规范。 + +## 安装教程 + +### 配置 + +需要安装 `Python`, 建议使用 `Anaconda` ,可以方便的进行包管理。 + +requirements: + +- tensorflow +- numpy + +### 下载 + +使用命令下载: + +```bash +git clone https://gitee.com/study-cooperation/tensorflow_basic_template.git +``` + +或者在 `gitee` 的页面[下载](https://gitee.com/study-cooperation/tensorflow_basic_template)。 + +### 安装 + +执行如下命令安装: + +```bash +cd tensorflow_basic_template +python setup.py install +``` + +## 使用说明 + +安装后的包名为 `dltools` ,可以这样使用该包: + +```python +import dltools +``` + diff --git a/data/base_record_generator.py b/data/base_record_generator.py index b519dc7..052b75a 100644 --- a/data/base_record_generator.py +++ b/data/base_record_generator.py @@ -51,6 +51,7 @@ class BaseRecordGenerator(object): self.buf_data = {} self.display = display self.mean = None + self.count = 0 if logger is not None: self.logger = logger @@ -67,6 +68,16 @@ class BaseRecordGenerator(object): """ raise NotImplementedError + def _write_data(self): + """ + 将 tf example 写入 record 文件 + + Returns + ------- + + """ + self.writer.write(self.buf_data['example'].SerializeToString()) + def update(self): """ 处理全部数据 @@ -75,18 +86,15 @@ class BaseRecordGenerator(object): ------- """ - count = 0 for meta in self.data: - if count % self.display == 0: + if self.count % self.display == 0: self.logger.info( - 'Processing the number of {} data.'.format(count)) + 'Processing the number of {} data.'.format(self.count)) self.buf_data['raw'] = meta self._encode_data() - if self.mean_value_length: - self.mean += self.buf_data['mean'] - self.writer.write(self.buf_data['example'].SerializeToString()) + self._write_data() self.buf_data.clear() - count += 1 + self.count += 1 if self.mean_value_length: - self.mean = self.mean / count + self.mean = self.mean / self.count self.writer.close() diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..1f2abf2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +opencv-python +opencv-contrib-python +numpy +scikit-image +tensorflow diff --git a/utils/basic.py b/utils/basic.py new file mode 100644 index 0000000..987f443 --- /dev/null +++ b/utils/basic.py @@ -0,0 +1,94 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/9 21:15 +@desc: +""" +import os + +import numpy as np + + +def num_samples(x): + """ + 返回类似数组的数据数量 + + Parameters + ---------- + x: array-like + + Returns + ------- + + """ + if not hasattr(x, '__len__') and not hasattr(x, 'shape'): + if hasattr(x, '__array__'): + x = np.asarray(x) + else: + raise TypeError("Expected sequence or array-like, got %s" % + type(x)) + if hasattr(x, 'shape'): + if len(x.shape) == 0: + raise TypeError("Singleton array %r cannot be considered" + " a valid collection." % x) + return x.shape[0] + else: + return len(x) + + +def check_consistent_length(*arrays): + """ + 检查所有数组的第一维的长度是否相等 + + Parameters + ---------- + *arrays : list or tuple of input objects. + Objects that will be checked for consistent length. + """ + lengths = [num_samples(X) for X in arrays if X is not None] + uniques = np.unique(lengths) + if len(uniques) > 1: + raise ValueError("Found input variables with inconsistent numbers of" + " samples: %r" % [int(l) for l in lengths]) + + +def merge_dict(*dicts): + """ + 合并一系列字典,合并后的字典的key是所有字典的key的并集, + 其value取和 + + Parameters + ---------- + dicts: 字典列表 + + Returns + ------- + 合并后的字典 + """ + res = {} + for _dict in dicts: + for key, value in _dict.items(): + if key not in res: + res[key] = value + else: + res[key] += value + return res + + +def get_name_and_ext(file_name): + """ + 获取文件名,包含扩展名 + + Parameters + ---------- + file_name: 文件完整路径 + + Returns + ------- + + """ + (_, temp_filename) = os.path.split(file_name) + (shot_name, extension) = os.path.splitext(temp_filename) + return shot_name, extension diff --git a/utils/image.py b/utils/image.py new file mode 100644 index 0000000..af08d16 --- /dev/null +++ b/utils/image.py @@ -0,0 +1,85 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/9 21:53 +@desc: +""" +import base64 + +import cv2 +import numpy as np +from skimage import transform + + +def image_encode_base64(image): + """ + 图像转换为base64编码,然后转换为字符串 + + Parameters + ---------- + image + + Returns + ------- + + """ + code = cv2.imencode('.jpg', image)[1] + code = base64.b64encode(code) + code = base64.encodebytes(code) + code = bytes.decode(code) + return code + + +def image_decode_base64(string): + """ + 将字符串转化为图像 + + Parameters + ---------- + string + + Returns + ------- + + """ + byte = str.encode(string) + byte = base64.decodebytes(byte) + byte = base64.b64decode(byte) + byte = np.frombuffer(byte, np.uint8) + img = cv2.imdecode(byte, cv2.CAP_MODE_RGB) + return img + + +def rotate_image(angle, image): + """ + 逆时针旋转一个图像 angle 度 + + Parameters + ---------- + angle: 要旋转的角度 + image: 原始图像 + + Returns + ------- + img_buf: 从旋转后的图像的中心切割出来与原始图像相同shape的图像 + offset: img_buf 相对于旋转后的图像的顶点(左上角)的起始坐标 + + """ + shape = image.shape + arc = angle * np.pi / 180 + post_size = (np.ceil(abs(shape[0] * np.cos(arc)) + + abs(shape[1] * np.sin(arc))).astype(np.int32), + np.ceil(abs(shape[0] * np.sin(arc)) + + abs(shape[1] * np.cos(arc))).astype(np.int32)) + offset = list(map(lambda x, y: (x - y) // 2, post_size, shape)) + if 0 == angle: + img_buf = image.copy() + img_buf = img_buf.astype(np.float32) + else: + img_buf = transform.rotate(image, angle, + resize=True, preserve_range=True) + img_buf = img_buf[offset[0]:(offset[0] + shape[0]), + offset[1]:(offset[1] + shape[1]), :] + return offset, img_buf diff --git a/utils/io.py b/utils/io.py new file mode 100644 index 0000000..8da2c55 --- /dev/null +++ b/utils/io.py @@ -0,0 +1,155 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/9 21:15 +@desc: +""" +import xml.etree.ElementTree as ET + +from object_detection.utils import label_map_util + +from .basic import check_consistent_length + + +def read_label_from_pd_file(pd_file, class_num): + """ + 读取 .pdtxt 文件并返回类别索引 + + Parameters + ---------- + pd_file + class_num + + Returns + ------- + category index, like: + {1: {'id': 1, 'name': 'fisheye'}, + 2: {'id': 2, 'name': 'person'}, + ...} + """ + label_map = label_map_util.load_labelmap(pd_file) + categories = label_map_util.convert_label_map_to_categories( + label_map, max_num_classes=class_num, use_display_name=True) + category_index = label_map_util.create_category_index(categories) + return category_index + + +def get_names_from_category_index(category_index=None, + pd_file='', class_num=0): + """ + 从类别索引或 .pdtxt 文件获取类名 + + Parameters + ---------- + category_index + pd_file + class_num + + Returns + ------- + + """ + label_list = [] + if pd_file != '': + category_index = read_label_from_pd_file(pd_file, class_num) + for idx in range(len(category_index)): + item = category_index[idx + 1] + label_list.append(item['name']) + return label_list + + +def read_voc_xml(file_path, image_size, label_list=()): + """ + 读取 VOC 格式的 xml 文件并提取对象信息 + + Parameters + ---------- + file_path + image_size + label_list + + Returns + ------- + + """ + image_info = {} + tree = ET.parse(file_path.strip()) + root = tree.getroot() + size = root.find('size') + width = int(size.find('width').text) + height = int(size.find('height').text) + image_info['objects'] = [] + + def _read_xml(changed=False): + for obj in root.iter('object'): + cls_name = obj.find('name').text + if len(label_list) > 0: + if cls_name not in label_list: + continue + xml_box = obj.find('bndbox') + xmin = int(xml_box.find('xmin').text) + ymin = int(xml_box.find('ymin').text) + xmax = int(xml_box.find('xmax').text) + ymax = int(xml_box.find('ymax').text) + if changed: + xmin, ymin, xmax, ymax = ymin, xmin, ymax, xmax + if (0 <= xmin < xmax < image_size[1]) or ( + 0 <= ymin < ymax < image_size[0]): + image_info['objects'].append({'name': cls_name, + 'label': label_list.index( + cls_name), + 'box': [xmin, ymin, xmax, ymax]}) + + image_info['shape'] = {'width': image_size[1], 'height': image_size[0]} + if (image_size[0] == height) and (image_size[1] == width): + _read_xml() + return image_info, False + elif (image_size[0] == width) and (image_size[1] == height): + _read_xml(True) + return image_info, False + else: + return None, True + + +def write_text_file(file_name, *file_lists, split=' ', encoding='w'): + """ + 将多个字符串列表写入文本文件中 + + Parameters + ---------- + file_name + file_lists + split + encoding + + Returns + ------- + + """ + check_consistent_length(*file_lists) + with open(file_name, encoding) as file: + for strings in zip(*file_lists): + string = split.join(strings) + file.write(string) + + +def create_map_pdtxt(file_path, label_list): + """ + 根据已有的类别生成 .pdtxt 文件 + + Parameters + ---------- + file_path + label_list + + Returns + ------- + + """ + base_str = """item {}\n id: {}\n name: '{}'\n{}\n""" + string = [] + for idx, label in enumerate(label_list): + string.append(base_str.format('{', idx + 1, label, '}')) + write_text_file(file_path, string) diff --git a/utils/tf_model.py b/utils/tf_model.py new file mode 100644 index 0000000..3844972 --- /dev/null +++ b/utils/tf_model.py @@ -0,0 +1,135 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/9 21:50 +@desc: +""" +import hashlib +import os + +import numpy as np +import tensorflow as tf + +from object_detection.utils import dataset_util + + +def load_model(checkpoint): + """ + load the frozen graph of tensorflow as a detection model + + Parameters + ---------- + checkpoint + + Returns + ------- + + """ + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + detection_graph = tf.Graph() + with detection_graph.as_default(): + od_graph_def = tf.GraphDef() + with tf.gfile.GFile(checkpoint, 'rb') as fid: + serialized_graph = fid.read() + od_graph_def.ParseFromString(serialized_graph) + tf.import_graph_def(od_graph_def, name='') + return detection_graph + + +def run_detection(sess, detection_graph, image_np): + image_tensor = detection_graph.get_tensor_by_name('image_tensor:0') + boxes = detection_graph.get_tensor_by_name('detection_boxes:0') + scores = detection_graph.get_tensor_by_name('detection_scores:0') + classes = detection_graph.get_tensor_by_name('detection_classes:0') + num_detections = detection_graph.get_tensor_by_name('num_detections:0') + (boxes, scores, classes, num_detections) = sess.run( + [boxes, scores, classes, num_detections], + feed_dict={ + image_tensor: image_np + }) + boxes = np.squeeze(boxes) + classes = np.squeeze(classes).astype(np.int32) + scores = np.squeeze(scores) + return boxes, classes, scores + + +def create_tf_example(image_path, + height, + width, + xmin, + ymin, + xmax, + ymax, + classes, + classes_text=(), + truncated=(), + poses=(), + difficult_obj=(), + source_id=''): + """ + 创建 tf example 实例 + + Parameters + ---------- + image_path + height + width + xmin + ymin + xmax + ymax + classes + classes_text + truncated + poses + difficult_obj + source_id + + Returns + ------- + + """ + with tf.gfile.GFile(image_path, 'rb') as fid: + encoded_jpg = fid.read() + key = hashlib.sha256(encoded_jpg).hexdigest() + example = tf.train.Example( + features=tf.train.Features(feature={ + 'image/height': + dataset_util.int64_feature(height), + 'image/width': + dataset_util.int64_feature(width), + 'image/filename': + dataset_util.bytes_feature( + os.path.basename(image_path).encode('utf8')), + 'image/source_id': + dataset_util.bytes_feature( + os.path.basename(source_id).encode('utf8')), + 'image/key/sha256': + dataset_util.bytes_feature(key.encode('utf8')), + 'image/encoded': + dataset_util.bytes_feature(encoded_jpg), + 'image/format': + dataset_util.bytes_feature('jpeg'.encode('utf8')), + 'image/object/bbox/xmin': + dataset_util.float_list_feature(xmin), + 'image/object/bbox/xmax': + dataset_util.float_list_feature(xmax), + 'image/object/bbox/ymin': + dataset_util.float_list_feature(ymin), + 'image/object/bbox/ymax': + dataset_util.float_list_feature(ymax), + 'image/object/class/text': + dataset_util.bytes_list_feature(classes_text), + 'image/object/class/label': + dataset_util.int64_list_feature(classes), + 'image/object/difficult': + dataset_util.int64_list_feature(difficult_obj), + 'image/object/truncated': + dataset_util.int64_list_feature(truncated), + 'image/object/view': + dataset_util.bytes_list_feature(poses), + })) + return example + -- Gitee