diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..85ce2a3e6dad2e88e757067206309f0e650734a9
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,11 @@
+# PyCharm
+.idea/
+
+# Visual Studio Code
+.vscode/
+
+# Python
+__pycache__/
+build/
+dist/
+*.egg-info/
diff --git a/README.md b/README.md
index 10b5c1f9407b2c32632b75e7d87c6373c88d7329..872898b0234c8bbd0227403905efe63bd9c3af94 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,42 @@
-# tensorflow_basic_template
+# tensorflow_basic_template
+
+我们开发一个可以作为基本工具包的项目,为之后一些深度学习项目的开发提供统一的规范。
+
+## 安装教程
+
+### 配置
+
+需要安装 `Python`, 建议使用 `Anaconda` ,可以方便的进行包管理。
+
+requirements:
+
+- tensorflow
+- numpy
+
+### 下载
+
+使用命令下载:
+
+```bash
+git clone https://gitee.com/study-cooperation/tensorflow_basic_template.git
+```
+
+或者在 `gitee` 的页面[下载](https://gitee.com/study-cooperation/tensorflow_basic_template)。
+
+### 安装
+
+执行如下命令安装:
+
+```bash
+cd tensorflow_basic_template
+python setup.py install
+```
+
+## 使用说明
+
+安装后的包名为 `dltools` ,可以这样使用该包:
+
+```python
+import dltools
+```
+
diff --git a/data/__init__.py b/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eefe15bf0a24b01fcf1b62a7127d52de24b8a90
--- /dev/null
+++ b/data/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/2/28 20:26
+@desc:
+"""
\ No newline at end of file
diff --git a/data/base_data_generator.py b/data/base_data_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..69fb516c1895bf0657c015518ccafba65c39d45c
--- /dev/null
+++ b/data/base_data_generator.py
@@ -0,0 +1,71 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/3/1 10:07
+@desc: base class feeding TFRecord data through a tf.data pipeline
+"""
+
+import tensorflow as tf
+
+
+class BaseDataGenerator(object):
+    """
+    Base class that builds a tf.data input pipeline from TFRecord files.
+    """
+
+    def __init__(self, data_files,
+                 keys_to_features=None,
+                 shuffle=False,
+                 batch_size=32,
+                 num_epochs=1):
+        # Initialise the tf dataset; with several files, optionally shuffle file order first.
+        if len(data_files) > 1:
+            self.dataset = tf.data.Dataset.from_tensor_slices(data_files)
+            if shuffle:
+                self.dataset = self.dataset.shuffle(buffer_size=len(data_files))
+            self.dataset = self.dataset.flat_map(tf.data.TFRecordDataset)
+        else:
+            self.dataset = tf.data.TFRecordDataset(data_files)
+        self.batch_size = batch_size
+        self.shuffle = shuffle
+        self.num_epochs = num_epochs
+
+        self.keys_to_features = keys_to_features
+
+        self._fetch_data()
+
+    def _fetch_data(self):
+        self.dataset = self.dataset.map(self._record_parser,
+                                        num_parallel_calls=5)
+        self.dataset = self.dataset.prefetch(self.batch_size)
+
+        if self.shuffle:
+            # When choosing shuffle buffer sizes, larger sizes result in better
+            # randomness, while smaller sizes have better performance.
+            self.dataset = self.dataset.shuffle(buffer_size=4096)
+
+        # We call repeat after shuffling, rather than before, to prevent
+        # separate epochs from blending together.
+        self.dataset = self.dataset.repeat(self.num_epochs)
+        self.dataset = self.dataset.batch(self.batch_size)
+
+    def _record_parser(self, value):
+        """
+        Parse one serialized tf record here (override in subclasses).
+
+        Parameters
+        ----------
+        value
+
+        Returns
+        -------
+
+        """
+        raise NotImplementedError
+
+    def __call__(self):
+        iterator = self.dataset.make_one_shot_iterator()
+        images, labels = iterator.get_next()
+        return images, labels
diff --git a/data/base_data_reader.py b/data/base_data_reader.py
new file mode 100644
index 0000000000000000000000000000000000000000..6280c19c3c806e6f540d9166e8455d8210958d38
--- /dev/null
+++ b/data/base_data_reader.py
@@ -0,0 +1,102 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/2/28 17:17
+@desc: basic raw-data file readers
+"""
+import os
+
+from utils import log
+
+
+class BaseFileReader(object):
+    """
+    Iterates over raw (image) files below a root directory.
+    """
+
+    def __init__(self, root,
+                 recurrence=True,
+                 max_number=1e12,
+                 display=100,
+                 logger=None):
+        """
+
+        Parameters
+        ----------
+        root: root directory to scan
+        recurrence: whether to recurse into sub-directories
+        max_number: maximum number of samples to read
+        display: logging interval (in number of files)
+        logger: logger object
+        """
+        self.root = root
+        if not os.path.exists(self.root):
+            raise NotADirectoryError
+
+        # When recursing, this variable holds the sub-directory being visited.
+        self.buf_root = root
+
+        # Holds the most recently produced sample.
+        self.buf_data = None
+        self.recurrence = recurrence
+
+        self.max_number = max_number
+        self.count = 0
+        self.display = display
+
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('FileReader')
+
+    def __iter__(self):
+        list_file = os.listdir(self.buf_root)
+        for file in list_file:
+            file_path = os.path.join(self.buf_root, file)
+            if os.path.isfile(file_path):
+                if self._filter(file_path):
+                    if self.count % self.display == 0:
+                        self.logger.info(
+                            'Reading the number of {}.'.format(self.count))
+                    yield self.buf_data
+            elif self.recurrence:
+                self.buf_root = file_path
+                for result in self:
+                    yield result
+                self.buf_root = os.path.dirname(file_path)
+            else:
+                pass
+
+    def has_next(self):
+        """
+        Report whether the maximum read count has not yet been exceeded.
+
+        Returns
+        -------
+
+        """
+        if self.count < self.max_number:
+            self.count += 1
+            return True
+        else:
+            return False  # raising StopIteration here breaks __iter__ (PEP 479)
+
+    def _filter(self, file):
+        """
+        Hook used to select files; override for custom filtering.
+
+        Returns
+        -------
+
+        """
+        if self.has_next():
+            # Uniform sample layout shared by all readers.
+            self.buf_data = {'feature': file,  # 'feature': the sample data
+                             'label': None,  # 'label': the annotation
+                             'data': None,  # 'data': auxiliary information
+                             'name': file}  # 'name': the sample identifier
+            return True
+        else:
+            return False
diff --git a/data/base_record_generator.py b/data/base_record_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..b519dc79cb8d13154c289a1554cc437ee2c9cd5d
--- /dev/null
+++ b/data/base_record_generator.py
@@ -0,0 +1,92 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/2/28 19:58
+@desc: generates binary record files supported by tensorflow
+"""
+import os
+from collections.abc import Iterable
+
+import numpy as np
+import tensorflow as tf
+
+from utils import log
+
+
+class BaseRecordGenerator(object):
+    """
+    Base class that writes data out as a tensorflow TFRecord file.
+    """
+
+    def __init__(self, data, output,
+                 mean_value_length=0,
+                 display=100,
+                 logger=None):
+        """
+
+        Parameters
+        ----------
+        data: iterable data object
+        output: file name of the record to write
+        mean_value_length: 0 disables mean computation; a positive value is the mean vector length
+        display: logging interval (in number of samples)
+        logger: logger object
+        """
+        assert isinstance(data, Iterable), '请输入正确的数据!data必须是可迭代的对象。'
+        self.data = data
+
+        self.mean_value_length = mean_value_length
+        if self.mean_value_length > 0:
+            self.mean = np.zeros((self.mean_value_length, ), dtype=np.float64)
+        else:
+            self.mean = None
+
+        if not os.path.exists(os.path.dirname(output)):
+            os.makedirs(os.path.dirname(output))
+
+        self.writer = tf.python_io.TFRecordWriter(output)
+
+        self.buf_data = {}
+        self.display = display
+        # NOTE: self.mean was initialised above and must not be reset here.
+
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('RecordGenerator')
+
+    def _encode_data(self):
+        """
+        Convert one sample into a tf example (override in subclasses).
+
+        Returns
+        -------
+
+        """
+        raise NotImplementedError
+
+    def update(self):
+        """
+        Process all data and write it to the record file.
+
+        Returns
+        -------
+
+        """
+        count = 0
+        for meta in self.data:
+            if count % self.display == 0:
+                self.logger.info(
+                    'Processing the number of {} data.'.format(count))
+            self.buf_data['raw'] = meta
+            self._encode_data()
+            if self.mean_value_length:
+                self.mean += self.buf_data['mean']
+            self.writer.write(self.buf_data['example'].SerializeToString())
+            self.buf_data.clear()
+            count += 1
+        if self.mean_value_length:
+            self.mean = self.mean / count
+        self.writer.close()
diff --git a/eval/__init__.py b/eval/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..edf8778e9be4a633e8a19e08b10fd18e5b584907
--- /dev/null
+++ b/eval/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/3/2 10:08
+@desc:
+"""
\ No newline at end of file
diff --git a/eval/base_eval.py b/eval/base_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff396f7fa1ab2b31ff94c7edda7966f199ee064d
--- /dev/null
+++ b/eval/base_eval.py
@@ -0,0 +1,203 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/3/2 10:08
+@desc:
+"""
+from collections.abc import Iterable
+
+import tensorflow as tf
+
+from utils import log
+
+
+class BaseExporter(object):
+    """
+    Exports a trained checkpoint through the tensorflow SavedModel API.
+    """
+
+    def __init__(self, model_dir,
+                 model_name,
+                 input_tensor_map,
+                 output_tensor_map,
+                 logger=None):
+        """
+
+        Parameters
+        ----------
+        model_dir: checkpoint path
+        model_name: tag name of the exported model
+        input_tensor_map: dict of the model's input tensors;
+            key is the tensor name inside the graph,
+            value is the name of the matching input node after export
+        output_tensor_map: dict of the model's output tensors;
+            key is the tensor name inside the graph,
+            value is the name of the matching output node after export
+        logger: logger object
+        """
+        self.model_dir = model_dir
+        self.model_name = model_name
+        self.input_tensor_map = {'map': input_tensor_map, 'name': 'Input'}
+        self.output_tensor_map = {'map': output_tensor_map, 'name': 'Output'}
+
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('Exporter')
+
+    def _get_tensor_map(self, graph, src_tensor_map):
+        """
+        Build a tensor-info dict from the import/export tensor names.
+
+        Parameters
+        ----------
+        graph
+        src_tensor_map
+
+        Returns
+        -------
+
+        """
+        dst_tensor_map = {}
+        self.logger.info(
+            'Generating tensor map from {}'.format(src_tensor_map['name']))
+        for import_name, export_name in src_tensor_map['map'].items():
+            tensor = graph.get_tensor_by_name(import_name)
+            tensor = tf.saved_model.utils.build_tensor_info(tensor)
+            dst_tensor_map[export_name] = tensor
+            self.logger.info(
+                'Importing {} from graph and Exporting as {}'.format(
+                    import_name, export_name))
+        return dst_tensor_map
+
+    def export(self):
+        """
+        Run the export.
+
+        Returns
+        -------
+
+        """
+        self.logger.info('Exporting trained model from {}'.format(self.model_dir))
+
+        with tf.Session() as sess:
+            builder = tf.saved_model.builder.SavedModelBuilder(self.model_dir)
+            saver = tf.train.import_meta_graph('{}.meta'.format(self.model_dir))
+            saver.restore(sess, self.model_dir)
+            self.logger.info('Loading trained model !')
+
+            graph = tf.get_default_graph()
+            inputs = self._get_tensor_map(graph, self.input_tensor_map)
+            outputs = self._get_tensor_map(graph, self.output_tensor_map)
+
+            self.logger.info('Exporting...')
+            signature = tf.saved_model.signature_def_utils.build_signature_def(
+                inputs, outputs,
+                tf.saved_model.signature_constants.PREDICT_METHOD_NAME)
+            builder.add_meta_graph_and_variables(
+                sess, ['frozen_model'],
+                signature_def_map={self.model_name: signature})
+            builder.save()
+
+        self.logger.info('Done exporting!')
+
+
+class BasePredictor(object):
+    """
+    Runs prediction with a Saved Model.
+    """
+    def __init__(self, model_dir,
+                 model_name,
+                 data,
+                 output_tensors,
+                 logger=None,
+                 config=None):
+        """
+
+        Parameters
+        ----------
+        model_dir: path of the exported pb file
+        model_name: tag name of the exported model
+        data: iterable data object
+        output_tensors: names of the output tensors to fetch
+        logger: logger object
+        config: tensorflow session config object
+        """
+        self.model_dir = model_dir
+        self.model_name = model_name
+        assert isinstance(data, Iterable), '请输入正确的数据!data必须是可迭代的对象。'
+        self.data = data
+        if config is not None:
+            self.config = config
+        else:
+            self.config = tf.ConfigProto()
+            self.config.gpu_options.allow_growth = True
+
+        self.input_tensors = None
+        self.output_tensors = output_tensors
+        self.buf_data = {}
+
+        if logger is not None:
+            self.logger = logger
+        else:
+            self.logger = log.get_console_logger('Predictor')
+
+    def _prepare_output_tensor_map(self, sess, signature):
+        """
+        Resolve the requested output names into graph tensors.
+
+        Returns
+        -------
+
+        """
+        output_tensor_map = {}
+        for tensor_name in self.output_tensors:
+            tensor = signature.outputs[tensor_name].name  # outputs, not inputs
+            tensor = sess.graph.get_tensor_by_name(tensor)
+            output_tensor_map[tensor_name] = tensor
+        self.output_tensors = output_tensor_map
+
+    def _prepare_input_data_map(self, sess, signature):
+        """
+        Build the feed dict required by the computation (override).
+
+        Returns
+        -------
+
+        """
+        raise NotImplementedError
+
+    def _evaluate_result(self):
+        """
+        Evaluate the prediction result (override).
+
+        Returns
+        -------
+
+        """
+        raise NotImplementedError
+
+    def predict(self):
+        """
+        Run prediction over all data.
+
+        Returns
+        -------
+
+        """
+        with tf.Session(config=self.config) as sess:
+            meta_graph_def = tf.saved_model.loader.load(sess, ['frozen_model'],
+                                                        self.model_dir)
+            signature = meta_graph_def.signature_def[self.model_name]
+            self.logger.info('Loading model completed !')
+            self._prepare_output_tensor_map(sess, signature)
+
+            for data in self.data:
+                self.buf_data['data'] = data
+                self._prepare_input_data_map(sess, signature)
+                self.buf_data['result'] = sess.run(self.output_tensors,
+                                                   feed_dict=self.input_tensors)
+                self.logger.info('Predicting {} ...'.format(data['name']))
+                self._evaluate_result()
diff --git a/model/__init__.py b/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eefe15bf0a24b01fcf1b62a7127d52de24b8a90
--- /dev/null
+++ b/model/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/2/28 20:26
+@desc:
+"""
\ No newline at end of file
diff --git a/model/base_model.py b/model/base_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c683cd040161bda5313aa0f90e87af4d1b2dd7e
--- /dev/null
+++ b/model/base_model.py
@@ -0,0 +1,56 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/2/28 16:28
+@desc: a basic model template class
+"""
+import tensorflow as tf
+
+
+class BaseModel(object):
+    """
+    Basic model template.
+    """
+    def __init__(self, data_format='channels_last'):
+        self.data_format = data_format
+
+    def transpose(self, data):
+        """
+        Convert the data to the configured data format.
+
+        Returns
+        -------
+
+        """
+        if self.data_format == 'channels_last':
+            return data
+        else:
+            return tf.transpose(data, [0, 3, 1, 2])
+
+    def _init_model(self, *args, **kwargs):
+        """
+        Override this method to define the model.
+
+        Returns
+        -------
+
+        """
+        raise NotImplementedError
+
+    def __call__(self, *args, **kwargs):
+        """
+        Invoke the model.
+
+        Parameters
+        ----------
+        args
+        kwargs
+
+        Returns
+        -------
+
+        """
+        return self._init_model(*args, **kwargs)
+
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d44545d6a6184b9ef23e054992533cd73d81391
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,22 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/3/1 14:48
+@desc:
+"""
+from setuptools import find_packages
+from setuptools import setup
+
+setup(
+    name='dltools',
+    version='0.1',
+    include_package_data=True,
+    packages=find_packages(),
+    description='deep learning toolbox',
+    license="MIT Licence",
+    url="https://gitee.com/study-cooperation",
+    author="hnu deep learning project group",
+    platforms="any",
+)
diff --git a/train/__init__.py b/train/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..163209272259173cda130175a70167b78a8e2ce5
--- /dev/null
+++ b/train/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/3/1 11:10
+@desc:
+"""
\ No newline at end of file
diff --git a/train/base_train.py b/train/base_train.py
new file mode 100644
index 0000000000000000000000000000000000000000..44b576c7e95a6bea8987f3a0e08ea97430c60cde
--- /dev/null
+++ b/train/base_train.py
@@ -0,0 +1,87 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/3/1 11:15
+@desc:
+"""
+import os
+
+import tensorflow as tf
+
+
+class BaseTrain(object):
+    """
+    Basic training wrapper around tf.estimator.
+    """
+
+    def __init__(self, model_save_dir,
+                 train_input,
+                 params=None,
+                 keep_checkpoint_max=10,
+                 evaluate_input=None):
+        # Set up a RunConfig to only save checkpoints once per training cycle.
+        run_config = tf.estimator.RunConfig(
+            keep_checkpoint_max=keep_checkpoint_max)
+        self.estimator = tf.estimator.Estimator(
+            model_fn=self._model, model_dir=model_save_dir,
+            config=run_config, params=params)
+        self.model_save_dir = model_save_dir
+        self.train_input = train_input
+        self.evaluate_input = evaluate_input
+
+    def _model(self, features, labels, mode, params):
+        """
+        Define the model function (override in subclasses).
+
+        Parameters
+        ----------
+        features
+        labels
+        mode
+        params
+
+        Returns
+        -------
+
+        """
+        raise NotImplementedError
+
+    def train(self, log_fmt, log_iter=10, save_steps=100, hooks=None):
+        """
+        Train the model.
+
+        Returns
+        -------
+
+        """
+        logging_hook = tf.train.LoggingTensorHook(
+            tensors=log_fmt, every_n_iter=log_iter)
+        saving_hook = tf.train.CheckpointSaverHook(
+            checkpoint_dir=self.model_save_dir, save_steps=save_steps)
+
+        all_hooks = [logging_hook, saving_hook]
+        if hooks is not None:
+            all_hooks += hooks
+
+        self.estimator.train(input_fn=self.train_input, hooks=all_hooks)
+
+    def evaluate(self, hooks=None, checkpoint_path=None):
+        """
+        Evaluate the model.
+
+        Returns
+        -------
+
+        """
+        if checkpoint_path is not None:
+            assert os.path.exists(checkpoint_path), 'Not such Directory found !'
+            assert tf.train.get_checkpoint_state(
+                checkpoint_path), 'Could not find trained model in: {}.'.format(
+                checkpoint_path)
+        else:
+            checkpoint_path = self.model_save_dir
+        self.estimator.evaluate(input_fn=self.evaluate_input,
+                                checkpoint_path=checkpoint_path,
+                                hooks=hooks)
diff --git a/utils/__init__.py b/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4eefe15bf0a24b01fcf1b62a7127d52de24b8a90
--- /dev/null
+++ b/utils/__init__.py
@@ -0,0 +1,8 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/2/28 20:26
+@desc:
+"""
\ No newline at end of file
diff --git a/utils/log.py b/utils/log.py
new file mode 100644
index 0000000000000000000000000000000000000000..184a9b405f23c90c0f37bd7ef26f60b04eb2de95
--- /dev/null
+++ b/utils/log.py
@@ -0,0 +1,32 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+@author: liang kang
+@contact: gangkanli1219@gmail.com
+@time: 2018/2/28 20:28
+@desc:
+"""
+import logging
+
+FORMATTER = '%(name)s %(levelname)s %(asctime)s: %(message)s'
+
+
+def get_console_logger(name, formatter=FORMATTER):
+    logger = logging.getLogger(name)
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter(formatter))
+    logger.setLevel(logging.INFO)
+    logger.addHandler(console_handler)
+    return logger
+
+
+def get_file_logger(name, file_name, formatter=FORMATTER):
+    logger = logging.getLogger(name)
+    file_handler = logging.FileHandler(file_name)
+    file_handler.setFormatter(logging.Formatter(formatter))
+    console_handler = logging.StreamHandler()
+    console_handler.setFormatter(logging.Formatter(formatter))
+    logger.setLevel(logging.INFO)
+    logger.addHandler(file_handler)
+    logger.addHandler(console_handler)
+    return logger