diff --git a/.gitignore b/.gitignore
index f10e7b78c8bd3fb11c5700195a7cc8e93110cc47..1cb956a010ca7bd81c79e29c1aaded8911462211 100644
--- a/.gitignore
+++ b/.gitignore
@@ -2,4 +2,7 @@
 .idea/
 
 # Visual Studio Code
-.vscode/
\ No newline at end of file
+.vscode/
+
+# Python
+__pycache__/
diff --git a/data/base_data_reader.py b/data/base_data_reader.py
index 2987e43c651b314145f7fff2c465af80e97138c2..6280c19c3c806e6f540d9166e8455d8210958d38 100644
--- a/data/base_data_reader.py
+++ b/data/base_data_reader.py
@@ -8,6 +8,8 @@
 """
 import os
 
+from utils import log
+
 
 class BaseFileReader(object):
     """
@@ -19,29 +21,46 @@ class BaseFileReader(object):
                  max_number=1e12,
                  display=100,
                  logger=None):
+        """
+
+        Parameters
+        ----------
+        root: root directory
+        recurrence: whether to search subdirectories recursively
+        max_number: maximum number of samples to read
+        display: logging interval
+        logger: logger object
+        """
         self.root = root
         if not os.path.exists(self.root):
             raise NotADirectoryError
+        # When traversing recursively, this holds the subdirectory being visited
        self.buf_root = root
+
+        # Holds the intermediate sample data
+        self.buf_data = None
         self.recurrence = recurrence
+
         self.max_number = max_number
         self.count = 0
         self.display = display
-        self.logger = logger
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('FileReader')
 
     def __iter__(self):
         list_file = os.listdir(self.buf_root)
         for file in list_file:
             file_path = os.path.join(self.buf_root, file)
             if os.path.isfile(file_path):
-                flag, result = self._filter(file_path)
-                if flag:
-                    if self.logger is not None and self.count % self.display == 0:
+                if self._filter(file_path):
+                    if self.count % self.display == 0:
                         self.logger.info(
                             'Reading the number of {}.'.format(self.count))
-                    yield result
+                    yield self.buf_data
             elif self.recurrence:
                 self.buf_root = file_path
                 for result in self:
@@ -50,7 +69,14 @@ class BaseFileReader(object):
         else:
             pass
 
-    def _count(self):
+    def has_next(self):
+        """
+        Check whether the maximum number of reads has been reached
+
+        Returns
+        -------
+
+        """
         if self.count < self.max_number:
             self.count += 1
             return True
@@ -59,13 +85,18 @@ class BaseFileReader(object):
 
     def _filter(self, file):
         """
-        Custom method for filtering files
+        Custom method for filtering files; subclasses may override it
 
         Returns
         -------
 
         """
-        if self._count():
-            return True, file
+        if self.has_next():
+            # Represent every sample with this uniform dict layout
+            self.buf_data = {'feature': file,  # 'feature' holds the sample's feature data
+                             'label': None,  # 'label' holds the annotation
+                             'data': None,  # 'data' holds other auxiliary information
+                             'name': file}  # 'name' identifies the sample
+            return True
         else:
-            return False, file
+            return False
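
Note on the new reader contract: `_filter()` now stores a uniform sample dict in `self.buf_data` and `__iter__` yields that dict instead of a bare path. A minimal usage sketch under that contract; the `ImageReader` subclass and the directory path are hypothetical:

```python
import os

from data.base_data_reader import BaseFileReader


class ImageReader(BaseFileReader):
    """Hypothetical reader that keeps only .jpg files."""

    def _filter(self, file):
        if not file.endswith('.jpg'):
            return False
        if self.has_next():
            # Fill the uniform dict layout expected by __iter__.
            self.buf_data = {'feature': file,                 # path to the image file
                             'label': None,                   # no annotation in this sketch
                             'data': None,                    # no auxiliary information
                             'name': os.path.basename(file)}  # identifier of the sample
            return True
        return False


reader = ImageReader('/path/to/images', recurrence=True, max_number=1000)
for sample in reader:
    print(sample['name'])
```
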
diff --git a/data/base_record_generator.py b/data/base_record_generator.py
index 429af7db21518b1c47e664edcf37205e62708042..b519dc79cb8d13154c289a1554cc437ee2c9cd5d 100644
--- a/data/base_record_generator.py
+++ b/data/base_record_generator.py
@@ -12,6 +12,8 @@ from collections import Iterable
 
 import numpy as np
 import tensorflow as tf
 
+from utils import log
+
 
 class BaseRecordGenerator(object):
@@ -19,13 +21,28 @@ class BaseRecordGenerator(object):
     """
 
     def __init__(self, data, output,
-                 compute_mean_value=True,
+                 mean_value_length=0,
                  display=100,
                  logger=None):
+        """
+
+        Parameters
+        ----------
+        data: an iterable data object
+        output: file name the record is written to
+        mean_value_length: whether to compute the data mean; 0 disables it, any value greater than 0 is the length of the mean vector
+        display: logging interval
+        logger: logger object
+        """
         assert isinstance(data, Iterable), 'Invalid data! data must be an iterable object.'
         self.data = data
-        self.compute_mean_value = compute_mean_value
+        self.mean_value_length = mean_value_length
+        if self.mean_value_length > 0:
+            self.mean = np.zeros((self.mean_value_length,), dtype=np.float128)
+        else:
+            self.mean = None
+
         if not os.path.exists(os.path.dirname(output)):
             os.makedirs(os.path.dirname(output))
@@ -35,7 +52,9 @@ class BaseRecordGenerator(object):
         self.display = display
 
-        self.mean = None
-        self.logger = logger
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('RecordGenerator')
 
     def _encode_data(self):
         """
@@ -55,18 +74,18 @@ class BaseRecordGenerator(object):
         -------
 
         """
-        mean = np.zeros(3, np.float128)
         count = 0
         for meta in self.data:
-            if self.logger is not None and count % self.display == 0:
+            if count % self.display == 0:
                 self.logger.info(
                     'Processing the number of {} data.'.format(count))
             self.buf_data['raw'] = meta
             self._encode_data()
-            if self.compute_mean_value:
-                mean += self.buf_data['mean']
+            if self.mean_value_length:
+                self.mean += self.buf_data['mean']
             self.writer.write(self.buf_data['example'].SerializeToString())
             self.buf_data.clear()
             count += 1
-        self.mean = mean / count
+        if self.mean_value_length:
+            self.mean = self.mean / count
         self.writer.close()
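
A sketch of a concrete generator against the new `mean_value_length` interface. It assumes the processing loop above belongs to a method called `process()` (its `def` line falls outside the hunk) and that each sample's `feature` is a decoded HxWx3 array with an integer `label`; all names here are illustrative:

```python
import numpy as np
import tensorflow as tf

from data.base_record_generator import BaseRecordGenerator


class ImageRecordGenerator(BaseRecordGenerator):
    """Hypothetical generator that packs each sample into a tf.train.Example."""

    def _encode_data(self):
        meta = self.buf_data['raw']
        image = np.asarray(meta['feature'], dtype=np.float32)
        self.buf_data['example'] = tf.train.Example(
            features=tf.train.Features(feature={
                'image': tf.train.Feature(
                    bytes_list=tf.train.BytesList(value=[image.tobytes()])),
                'label': tf.train.Feature(
                    int64_list=tf.train.Int64List(value=[meta['label']])),
            }))
        # Per-channel mean; the loop above accumulates it when mean_value_length > 0.
        self.buf_data['mean'] = image.reshape(-1, 3).mean(axis=0)


generator = ImageRecordGenerator(samples, 'records/train.tfrecord',
                                 mean_value_length=3)
generator.process()
print('channel means:', generator.mean)
```

Here `samples` stands for any iterable of sample dicts, such as a `BaseFileReader` subclass.
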
diff --git a/eval/__init__.py b/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..edf8778e9be4a633e8a19e08b10fd18e5b584907
--- /dev/null
+++ b/eval/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/3/2 10:08
+@desc:
+"""
\ No newline at end of file
diff --git a/eval/base_eval.py b/eval/base_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff396f7fa1ab2b31ff94c7edda7966f199ee064d
--- /dev/null
+++ b/eval/base_eval.py
@@ -0,0 +1,203 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/3/2 10:08
+@desc:
+"""
+from collections import Iterable
+
+import tensorflow as tf
+
+from utils import log
+
+
+class BaseExporter(object):
+    """
+    For models exported with the tensorflow SavedModel utilities
+    """
+
+    def __init__(self, model_dir,
+                 model_name,
+                 input_tensor_map,
+                 output_tensor_map,
+                 logger=None):
+        """
+
+        Parameters
+        ----------
+        model_dir: checkpoint path
+        model_name: tag name of the exported model
+        input_tensor_map: dict of the model's input tensor names;
+            a key is a tensor name in the computation graph,
+            a value is the name of the matching input node after export
+        output_tensor_map: dict of the model's output tensor names;
+            a key is a tensor name in the computation graph,
+            a value is the name of the matching output node after export
+        logger: logger object
+        """
+        self.model_dir = model_dir
+        self.model_name = model_name
+        self.input_tensor_map = {'map': input_tensor_map, 'name': 'Input'}
+        self.output_tensor_map = {'map': output_tensor_map, 'name': 'Output'}
+
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('Exporter')
+
+    def _get_tensor_map(self, graph, src_tensor_map):
+        """
+        Build the tensor dict from the import and export tensor names
+
+        Parameters
+        ----------
+        graph
+        src_tensor_map
+
+        Returns
+        -------
+
+        """
+        dst_tensor_map = {}
+        self.logger.info(
+            'Generating tensor map from {}'.format(src_tensor_map['name']))
+        for import_name, export_name in src_tensor_map['map'].items():
+            tensor = graph.get_tensor_by_name(import_name)
+            tensor = tf.saved_model.utils.build_tensor_info(tensor)
+            dst_tensor_map[export_name] = tensor
+            self.logger.info(
+                'Importing {} from graph and exporting it as {}'.format(
+                    import_name, export_name))
+        return dst_tensor_map
+
+    def export(self):
+        """
+        Run the export
+
+        Returns
+        -------
+
+        """
+        self.logger.info('Exporting trained model from {}'.format(self.model_dir))
+
+        with tf.Session() as sess:
+            builder = tf.saved_model.builder.SavedModelBuilder(self.model_dir)
+            saver = tf.train.import_meta_graph('{}.meta'.format(self.model_dir))
+            saver.restore(sess, self.model_dir)
+            self.logger.info('Loading trained model !')
+
+            graph = tf.get_default_graph()
+            inputs = self._get_tensor_map(graph, self.input_tensor_map)
+            outputs = self._get_tensor_map(graph, self.output_tensor_map)
+
+            self.logger.info('Exporting...')
+            signature = tf.saved_model.signature_def_utils.build_signature_def(
+                inputs, outputs,
+                tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+            builder.add_meta_graph_and_variables(
+                sess, ['frozen_model'],
+                signature_def_map={self.model_name: signature})
+            builder.save()
+
+        self.logger.info('Done exporting!')
+
+
+class BasePredictor(object):
+    """
+    Run predictions with a SavedModel
+    """
+    def __init__(self, model_dir,
+                 model_name,
+                 data,
+                 output_tensors,
+                 logger=None,
+                 config=None):
+        """
+
+        Parameters
+        ----------
+        model_dir: path of the exported pb file
+        model_name: tag name of the exported model
+        data: an iterable data object
+        output_tensors: tensor names of the required output data
+        logger: logger object
+        config: config object for the tensorflow session
+        """
+        self.model_dir = model_dir
+        self.model_name = model_name
+        assert isinstance(data, Iterable), 'Invalid data! data must be an iterable object.'
+        self.data = data
+        if config is not None:
+            self.config = config
+        else:
+            self.config = tf.ConfigProto()
+            self.config.gpu_options.allow_growth = True
+
+        self.input_tensors = None
+        self.output_tensors = output_tensors
+        self.buf_data = {}
+
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('Predictor')
+
+    def _prepare_output_tensor_map(self, sess, signature):
+        """
+        Build the dict mapping output keys to tensors
+
+        Returns
+        -------
+
+        """
+        output_tensor_map = {}
+        for tensor_name in self.output_tensors:
+            tensor = signature.outputs[tensor_name].name
+            tensor = sess.graph.get_tensor_by_name(tensor)
+            output_tensor_map[tensor_name] = tensor
+        self.output_tensors = output_tensor_map
+
+    def _prepare_input_data_map(self, sess, signature):
+        """
+        Prepare the inputs required by the computation
+
+        Returns
+        -------
+
+        """
+        raise NotImplementedError
+
+    def _evaluate_result(self):
+        """
+        Evaluate the prediction result
+
+        Returns
+        -------
+
+        """
+        raise NotImplementedError
+
+    def predict(self):
+        """
+        Run the prediction
+
+        Returns
+        -------
+
+        """
+        with tf.Session(config=self.config) as sess:
+            meta_graph_def = tf.saved_model.loader.load(sess, ['frozen_model'],
+                                                        self.model_dir)
+            signature = meta_graph_def.signature_def[self.model_name]
+            self.logger.info('Loading model completed !')
+            self._prepare_output_tensor_map(sess, signature)
+
+            for data in self.data:
+                self.buf_data['data'] = data
+                self._prepare_input_data_map(sess, signature)
+                self.buf_data['result'] = sess.run(self.output_tensors,
+                                                   feed_dict=self.input_tensors)
+                self.logger.info('Predicting {} ...'.format(data['name']))
+                self._evaluate_result()
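
How the two classes compose end to end, as a hedged sketch: the checkpoint path, tensor names, and the `ScorePredictor` subclass are all hypothetical, and `samples` is again an iterable of sample dicts. A `BasePredictor` subclass must fill `self.input_tensors` with a feed dict:

```python
import tensorflow as tf

from eval.base_eval import BaseExporter, BasePredictor

# Export a trained checkpoint as a SavedModel (all names are placeholders).
exporter = BaseExporter(model_dir='checkpoints/model',
                        model_name='classify',
                        input_tensor_map={'input/images:0': 'images'},
                        output_tensor_map={'output/scores:0': 'scores'})
exporter.export()


class ScorePredictor(BasePredictor):
    """Hypothetical predictor wiring sample features into the signature."""

    def _prepare_input_data_map(self, sess, signature):
        # Resolve the exported input node back to a graph tensor.
        tensor = sess.graph.get_tensor_by_name(signature.inputs['images'].name)
        self.input_tensors = {tensor: self.buf_data['data']['feature']}

    def _evaluate_result(self):
        print(self.buf_data['data']['name'],
              self.buf_data['result']['scores'])


predictor = ScorePredictor('checkpoints/model', 'classify',
                           data=samples, output_tensors=['scores'])
predictor.predict()
```
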