From 1ab49768b571bba678d914091a440538081918e8 Mon Sep 17 00:00:00 2001 From: KangGrandesty Date: Fri, 2 Mar 2018 17:34:18 +0800 Subject: [PATCH 1/4] =?UTF-8?q?=E6=9B=B4=E6=96=B0=E4=BC=98=E5=8C=96?= =?UTF-8?q?=E9=83=A8=E5=88=86=E4=BB=A3=E7=A0=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 设定默认的 logger; 2. 添加了注释; 3. 修改的默认 `filter` 并定义了标注的数据格式 添加行间注释 --- data/base_data_reader.py | 51 ++++++++++++++++++++++++++++++++-------- 1 file changed, 41 insertions(+), 10 deletions(-) diff --git a/data/base_data_reader.py b/data/base_data_reader.py index 2987e43..6280c19 100644 --- a/data/base_data_reader.py +++ b/data/base_data_reader.py @@ -8,6 +8,8 @@ """ import os +from utils import log + class BaseFileReader(object): """ @@ -19,29 +21,46 @@ class BaseFileReader(object): max_number=1e12, display=100, logger=None): + """ + + Parameters + ---------- + root: 根目录 + recurrence: 是否递归搜索子目录 + max_number: 读入的最大样本数 + display: 显示间隔 + logger: 日志对象 + """ self.root = root if not os.path.exists(self.root): raise NotADirectoryError + # 当需要递归遍历目录时,该变量保存进入的子目录 self.buf_root = root + + # 保存中间数据 + self.buf_data = None self.recurrence = recurrence + self.max_number = max_number self.count = 0 self.display = display - self.logger = logger + if logger is not None: + self.logger = logger + else: + self.logger = log.get_console_logger('FileReader') def __iter__(self): list_file = os.listdir(self.buf_root) for file in list_file: file_path = os.path.join(self.buf_root, file) if os.path.isfile(file_path): - flag, result = self._filter(file_path) - if flag: - if self.logger is not None and self.count % self.display == 0: + if self._filter(file_path): + if self.count % self.display == 0: self.logger.info( 'Reading the number of {}.'.format(self.count)) - yield result + yield self.buf_data elif self.recurrence: self.buf_root = file_path for result in self: @@ -50,7 +69,14 @@ class BaseFileReader(object): else: pass - def _count(self): + def has_next(self): 
+ """ + 判断是否已经超过最大读取数量 + + Returns + ------- + + """ if self.count < self.max_number: self.count += 1 return True @@ -59,13 +85,18 @@ class BaseFileReader(object): def _filter(self, file): """ - 自定义用于筛选文件的方法 + 自定义用于筛选文件的方法, 可以重载 Returns ------- """ - if self._count(): - return True, file + if self.has_next(): + # 统一使用这样的方法表示数据 + self.buf_data = {'feature': file, # feature 表示数据的特征 + 'label': None, # label 表示数据的标记信息 + 'data': None, # data 表示其他辅助信息 + 'name': file} # name 表示数据的表示 + return True else: - return False, file + return False -- Gitee From 797aa9fea8f01ffdda4ceb2c7e693956b08101c5 Mon Sep 17 00:00:00 2001 From: KangGrandesty Date: Fri, 2 Mar 2018 17:53:09 +0800 Subject: [PATCH 2/4] =?UTF-8?q?=E4=B8=BB=E8=A6=81=E6=94=B9=E8=BF=9B?= =?UTF-8?q?=E4=BA=86=E8=AE=A1=E7=AE=97=E5=9D=87=E5=80=BC=E7=9A=84=E6=96=B9?= =?UTF-8?q?=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1. 设定默认的 logger; 2. 添加了大量注释; 3. 改进了计算均值的方法 --- data/base_record_generator.py | 36 +++++++++++++++++++++++++++-------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/data/base_record_generator.py b/data/base_record_generator.py index 429af7d..b519dc7 100644 --- a/data/base_record_generator.py +++ b/data/base_record_generator.py @@ -12,6 +12,8 @@ from collections import Iterable import numpy as np import tensorflow as tf +from utils import log + class BaseRecordGenerator(object): """ @@ -19,13 +21,28 @@ class BaseRecordGenerator(object): """ def __init__(self, data, output, - compute_mean_value=True, + mean_value_length=0, display=100, logger=None): + """ + + Parameters + ---------- + data: 可迭代的数据对象 + output: 保存 record 的文件名 + mean_value_length: 是否计算数据均值, 为 0 不计算,大于 0 时数字就是均值长度 + display: 显示间隔 + logger: 日志对象 + """ assert isinstance(data, Iterable), '请输入正确的数据!data必须是可迭代的对象。' self.data = data - self.compute_mean_value = compute_mean_value + self.mean_value_length = mean_value_length + if self.mean_value_length == 0: + self.mean = 
None
+        else:
+            self.mean = np.zeros((self.mean_value_length, ), dtype=np.float128)
+
         if not os.path.exists(os.path.dirname(output)):
             os.makedirs(os.path.dirname(output))
@@ -35,7 +52,9 @@ class BaseRecordGenerator(object):
         self.display = display
 
-        self.mean = None
-        self.logger = logger
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('RecordGenerator')
 
     def _encode_data(self):
         """
@@ -55,18 +75,18 @@ class BaseRecordGenerator(object):
         -------
 
         """
-        mean = np.zeros(3, np.float128)
         count = 0
         for meta in self.data:
-            if self.logger is not None and count % self.display == 0:
+            if count % self.display == 0:
                 self.logger.info(
                     'Processing the number of {} data.'.format(count))
             self.buf_data['raw'] = meta
             self._encode_data()
-            if self.compute_mean_value:
-                mean += self.buf_data['mean']
+            if self.mean_value_length:
+                self.mean += self.buf_data['mean']
             self.writer.write(self.buf_data['example'].SerializeToString())
             self.buf_data.clear()
             count += 1
-        self.mean = mean / count
+        if self.mean_value_length:
+            self.mean = self.mean / count
         self.writer.close()
-- 
Gitee


From b521ef50e8b97ec82d40b08ca8e34a9703427880 Mon Sep 17 00:00:00 2001
From: KangGrandesty
Date: Fri, 2 Mar 2018 17:55:51 +0800
Subject: [PATCH 3/4] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E4=BA=86=E5=9F=BA?=
 =?UTF-8?q?=E7=A1=80=E6=A3=80=E9=AA=8C=E6=A8=A1=E5=9D=97=EF=BC=8C=E5=8C=85?=
 =?UTF-8?q?=E6=8B=AC=20Exporter=20=E5=92=8C=20Predictor?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 eval/__init__.py  |   8 ++
 eval/base_eval.py | 203 ++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 211 insertions(+)
 create mode 100644 eval/__init__.py
 create mode 100644 eval/base_eval.py

diff --git a/eval/__init__.py b/eval/__init__.py
new file mode 100644
index 0000000..edf8778
--- /dev/null
+++ b/eval/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com 
+@time: 2018/3/2 10:08 +@desc: +""" \ No newline at end of file diff --git a/eval/base_eval.py b/eval/base_eval.py new file mode 100644 index 0000000..ff396f7 --- /dev/null +++ b/eval/base_eval.py @@ -0,0 +1,203 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/2 10:08 +@desc: +""" +from collections import Iterable + +import tensorflow as tf + +from utils import log + + +class BaseExporter(object): + """ + 适用于利用 tensorflow SaveModel 导出的模型 + """ + + def __init__(self, model_dir, + model_name, + input_tensor_map, + output_tensor_map, + logger=None): + """ + + Parameters + ---------- + model_dir: checkpoint 路径 + model_name: 导出的模型标记名 + input_tensor_map: Dict 对象,模型的输入数据张量名称; + key 是计算图中的张量名称, + value 是导出后的模型对应输入节点的名称 + output_tensor_map: Dict 对象,模型的输出数据张量名称; + key 是计算图中的张量名称, + value 是导出后的模型对应输出节点的名称 + logger: 日志对象 + """ + self.model_dir = model_dir + self.model_name = model_name + self.input_tensor_map = {'map': input_tensor_map, 'name': 'Input'} + self.output_tensor_map = {'map': output_tensor_map, 'name': 'Output'} + + if logger is not None: + self.logger = logger + else: + self.logger = log.get_console_logger('Exporter') + + def _get_tensor_map(self, graph, src_tensor_map): + """ + 基于导入和导出张量的名字构造张量字典 + + Parameters + ---------- + graph + src_tensor_map + + Returns + ------- + + """ + dst_tensor_map = {} + self.logger.info( + 'Generating tensor map from {}'.format(src_tensor_map['name'])) + for import_name, export_name in src_tensor_map['map'].items(): + tensor = graph.get_tensor_by_name(import_name) + tensor = tf.saved_model.utils.build_tensor_info(tensor) + dst_tensor_map[export_name] = tensor + self.logger.info( + 'Importing {} from graph and Exporting as {}'.format( + import_name, export_name)) + return dst_tensor_map + + def export(self): + """ + 执行导出操作 + + Returns + ------- + + """ + self.logger.info('Exporting trained model from {}', self.model_dir) + + with tf.Session() as sess: + 
builder = tf.saved_model.builder.SavedModelBuilder(self.model_dir)
+            saver = tf.train.import_meta_graph('{}.meta'.format(self.model_dir))
+            saver.restore(sess, self.model_dir)
+            self.logger.info('Loading trained model !')
+
+            graph = tf.get_default_graph()
+            inputs = self._get_tensor_map(graph, self.input_tensor_map)
+            outputs = self._get_tensor_map(graph, self.output_tensor_map)
+
+            self.logger.info('Exporting...')
+            signature = tf.saved_model.signature_def_utils.build_signature_def(
+                inputs, outputs,
+                tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+            builder.add_meta_graph_and_variables(
+                sess, ['frozen_model'],
+                signature_def_map={self.model_name: signature})
+            builder.save()
+
+        self.logger.info('Done exporting!')
+
+
+class BasePredictor(object):
+    """
+    使用 Saved Model 进行预测
+    """
+    def __init__(self, model_dir,
+                 model_name,
+                 data,
+                 output_tensors,
+                 logger=None,
+                 config=None):
+        """
+
+        Parameters
+        ----------
+        model_dir: 导出的 pb 文件的路径
+        model_name: 导出的模型标记名
+        data: 可迭代的数据对象
+        output_tensors: 所需要的输出数据的张量名称
+        logger: 日志对象
+        config: tensorflow session 的配置对象
+        """
+        self.model_dir = model_dir
+        self.model_name = model_name
+        assert isinstance(data, Iterable), '请输入正确的数据!data必须是可迭代的对象。'
+        self.data = data
+        if config is not None:
+            self.config = config
+        else:
+            self.config = tf.ConfigProto()
+            self.config.gpu_options.allow_growth = True
+
+        self.input_tensors = None
+        self.output_tensors = output_tensors
+        self.buf_data = {}
+
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('Predictor')
+
+    def _prepare_output_tensor_map(self, sess, signature):
+        """
+        获取输出变量的关键字和张量字典
+
+        Returns
+        -------
+
+        """
+        output_tensor_map = {}
+        for tensor_name in self.output_tensors:
+            tensor = signature.outputs[tensor_name].name
+            tensor = sess.graph.get_tensor_by_name(tensor)
+            output_tensor_map[tensor_name] = tensor
+        self.output_tensors = output_tensor_map
+
+    def _prepare_input_data_map(self, sess, signature):
+        """
+        
实现出来计算所需的输入 + + Returns + ------- + + """ + raise NotImplementedError + + def _evaluate_result(self): + """ + 对预测结果的评估 + + Returns + ------- + + """ + raise NotImplementedError + + def predict(self): + """ + 执行预测操作 + + Returns + ------- + + """ + with tf.Session(config=self.config) as sess: + meta_graph_def = tf.saved_model.loader.load(sess, ['frozen_model'], + self.model_dir) + signature = meta_graph_def.signature_def[self.model_name] + self.logger.info('Loading model completed !') + self._prepare_output_tensor_map(sess, signature) + + for data in self.data: + self.buf_data['data'] = data + self._prepare_input_data_map(sess, signature) + self.buf_data['result'] = sess.run(self.output_tensors, + feed_dict=self.input_tensors) + self.logger.info('Predicting {} ...'.format(data['name'])) + self._evaluate_result() -- Gitee From 151036cfe4a88e82abd34128a3722e0dc80434d5 Mon Sep 17 00:00:00 2001 From: KangGrandesty Date: Fri, 2 Mar 2018 17:56:20 +0800 Subject: [PATCH 4/4] =?UTF-8?q?=E5=BF=BD=E7=95=A5=20Python=20=E8=BF=90?= =?UTF-8?q?=E8=A1=8C=E4=BA=A7=E7=94=9F=E7=9A=84=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f10e7b7..1cb956a 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,7 @@ .idea/ # Visual Studio Code -.vscode/ \ No newline at end of file +.vscode/ + +# Python +__pycache__/ -- Gitee