From 0b8203a9bab770d96294ad9c36513ca04c906f5e Mon Sep 17 00:00:00 2001 From: KangGrandesty Date: Thu, 1 Mar 2018 16:39:05 +0800 Subject: [PATCH] =?UTF-8?q?=E5=88=9D=E5=A7=8B=E5=8C=96=E6=8F=90=E4=BA=A4?= =?UTF-8?q?=EF=BC=8C=E5=AE=9A=E4=B9=89=E4=BA=86=E9=83=A8=E5=88=86=E6=A8=A1?= =?UTF-8?q?=E6=9D=BF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 删除目录.idea/inspectionProfiles --- .gitignore | 5 ++ README.md | 2 +- data/__init__.py | 8 ++++ data/base_data_generator.py | 71 ++++++++++++++++++++++++++++ data/base_data_reader.py | 71 ++++++++++++++++++++++++++++ data/base_record_generator.py | 72 +++++++++++++++++++++++++++++ model/__init__.py | 8 ++++ model/base_model.py | 56 ++++++++++++++++++++++ setup.py | 26 +++++++++++ train/__init__.py | 8 ++++ train/base_train.py | 87 +++++++++++++++++++++++++++++++++++ utils/__init__.py | 8 ++++ utils/log.py | 32 +++++++++++++ 13 files changed, 453 insertions(+), 1 deletion(-) create mode 100644 .gitignore create mode 100644 data/__init__.py create mode 100644 data/base_data_generator.py create mode 100644 data/base_data_reader.py create mode 100644 data/base_record_generator.py create mode 100644 model/__init__.py create mode 100644 model/base_model.py create mode 100644 setup.py create mode 100644 train/__init__.py create mode 100644 train/base_train.py create mode 100644 utils/__init__.py create mode 100644 utils/log.py diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f10e7b7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +# PyCharm +.idea/ + +# Visual Studio Code +.vscode/ \ No newline at end of file diff --git a/README.md b/README.md index 10b5c1f..0e58aac 100644 --- a/README.md +++ b/README.md @@ -1 +1 @@ -# tensorflow_basic_template +# tensorflow_basic_template diff --git a/data/__init__.py b/data/__init__.py new file mode 100644 index 0000000..4eefe15 --- /dev/null +++ b/data/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/2/28 20:26 +@desc: +""" \ No newline at end of file diff --git a/data/base_data_generator.py b/data/base_data_generator.py new file mode 100644 index 0000000..69fb516 --- /dev/null +++ b/data/base_data_generator.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/1 10:07 +@desc: +""" + +import tensorflow as tf + + +class BaseDataGenerator(object): + """ + 用于生成 tensorflow 支持的训练数据的基本类 + """ + + def __init__(self, data_files, + keys_to_features=None, + shuffle=False, + batch_size=32, + num_epochs=1): + # 初始化 tf dataset + if len(data_files) > 1: + self.dataset = tf.data.Dataset.from_tensor_slices(data_files) + if shuffle: + self.dataset = self.dataset.shuffle(buffer_size=len(data_files)) + self.dataset = self.dataset.flat_map(tf.data.TFRecordDataset) + else: + self.dataset = tf.data.TFRecordDataset(data_files) + self.batch_size = batch_size + self.shuffle = shuffle + self.num_epochs = num_epochs + + self.keys_to_features = keys_to_features + + self._fetch_data() + + def _fetch_data(self): + self.dataset = self.dataset.map(self._record_parser, + num_parallel_calls=5) + self.dataset = self.dataset.prefetch(self.batch_size) + + if self.shuffle: + # When choosing shuffle buffer sizes, larger sizes result in better + # randomness, while smaller sizes have better performance. + self.dataset = self.dataset.shuffle(buffer_size=4096) + + # We call repeat after shuffling, rather than before, to prevent + # separate epochs from blending together. + self.dataset = self.dataset.repeat(self.num_epochs) + self.dataset = self.dataset.batch(self.batch_size) + + def _record_parser(self, value): + """ + 在这里解析 tf record + + Parameters + ---------- + value + + Returns + ------- + + """ + raise NotImplementedError + + def __call__(self): + iterator = self.dataset.make_one_shot_iterator() + images, labels = iterator.get_next() + return images, labels diff --git a/data/base_data_reader.py b/data/base_data_reader.py new file mode 100644 index 0000000..2987e43 --- /dev/null +++ b/data/base_data_reader.py @@ -0,0 +1,71 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/2/28 17:17 +@desc: 基本数据 +""" +import os + + +class BaseFileReader(object): + """ + 基本原始图像数据读取 + """ + + def __init__(self, root, + recurrence=True, + max_number=1e12, + display=100, + logger=None): + self.root = root + if not os.path.exists(self.root): + raise NotADirectoryError + + self.buf_root = root + self.recurrence = recurrence + self.max_number = max_number + self.count = 0 + self.display = display + + self.logger = logger + + def __iter__(self): + list_file = os.listdir(self.buf_root) + for file in list_file: + file_path = os.path.join(self.buf_root, file) + if os.path.isfile(file_path): + flag, result = self._filter(file_path) + if flag: + if self.logger is not None and self.count % self.display == 0: + self.logger.info( + 'Reading the number of {}.'.format(self.count)) + yield result + elif self.recurrence: + self.buf_root = file_path + for result in self: + yield result + self.buf_root = os.path.dirname(file_path) + else: + pass + + def _count(self): + if self.count < self.max_number: + self.count += 1 + return True + else: + raise StopIteration + + def _filter(self, file): + """ + 自定义用于筛选文件的方法 + + Returns + ------- + + """ + if self._count(): + return True, file + else: + return False, file diff --git a/data/base_record_generator.py b/data/base_record_generator.py new file mode 100644 index 0000000..429af7d --- /dev/null +++ b/data/base_record_generator.py @@ -0,0 +1,72 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/2/28 19:58 +@desc: 用于生成 tensorflow 支持的二进制文件 +""" +import os +from collections import Iterable + +import numpy as np +import tensorflow as tf + + +class BaseRecordGenerator(object): + """ + 用于生成 tensorflow 支持的二进制文件的基本类 + """ + + def __init__(self, data, output, + compute_mean_value=True, + display=100, + logger=None): + assert isinstance(data, Iterable), '请输入正确的数据!data必须是可迭代的对象。' + self.data = data + + self.compute_mean_value = compute_mean_value + if not os.path.exists(os.path.dirname(output)): + os.makedirs(os.path.dirname(output)) + + self.writer = tf.python_io.TFRecordWriter(output) + + self.buf_data = {} + self.display = display + self.mean = None + + self.logger = logger + + def _encode_data(self): + """ + 将数据转化为 tf example + + Returns + ------- + + """ + raise NotImplementedError + + def update(self): + """ + 处理全部数据 + + Returns + ------- + + """ + mean = np.zeros(3, np.float128) + count = 0 + for meta in self.data: + if self.logger is not None and count % self.display == 0: + self.logger.info( + 'Processing the number of {} data.'.format(count)) + self.buf_data['raw'] = meta + self._encode_data() + if self.compute_mean_value: + mean += self.buf_data['mean'] + self.writer.write(self.buf_data['example'].SerializeToString()) + self.buf_data.clear() + count += 1 + self.mean = mean / count + self.writer.close() diff --git a/model/__init__.py b/model/__init__.py new file mode 100644 index 0000000..4eefe15 --- /dev/null +++ b/model/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/2/28 20:26 +@desc: +""" \ No newline at end of file diff --git a/model/base_model.py b/model/base_model.py new file mode 100644 index 0000000..0c683cd --- /dev/null +++ b/model/base_model.py @@ -0,0 +1,56 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/2/28 16:28 +@desc: 这是一个基本模型模板类 +""" +import tensorflow as tf + + +class BaseModel(object): + """ + 基本模型模板 + """ + def __init__(self, data_format='channels_last'): + self.data_format = data_format + + def transpose(self, data): + """ + 转换数据格式 + + Returns + ------- + + """ + if self.data_format == 'channels_last': + return data + else: + return tf.transpose(data, [0, 3, 1, 2]) + + def _init_model(self, *args, **kwargs): + """ + 重载这个方法,定义模型 + + Returns + ------- + + """ + raise NotImplementedError + + def __call__(self, *args, **kwargs): + """ + 调用模型 + + Parameters + ---------- + args + kwargs + + Returns + ------- + + """ + return self._init_model(*args, **kwargs) + diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..2e6dc09 --- /dev/null +++ b/setup.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/1 14:48 +@desc: +""" +from setuptools import find_packages +from setuptools import setup + +setup( + name='dltools', + version='0.1', + include_package_data=True, + packages=find_packages(), + description='deep learning toolbox', + license="MIT Licence", + url="https://gitee.com/study-cooperation", + author="hnu deep learning project group", + platforms="any", + install_requires=[ + "tensorflow >= 1.4.0", + "numpy >= 1.13" + ], +) diff --git a/train/__init__.py b/train/__init__.py new file mode 100644 index 0000000..1632092 --- /dev/null +++ b/train/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/1 11:10 +@desc: +""" \ No newline at end of file diff --git a/train/base_train.py b/train/base_train.py new file mode 100644 index 0000000..44b576c --- /dev/null +++ b/train/base_train.py @@ -0,0 +1,87 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/3/1 11:15 +@desc: +""" +import os + +import tensorflow as tf + + +class BaseTrain(object): + """ + 基本训练类 + """ + + def __init__(self, model_save_dir, + train_input, + params=None, + keep_checkpoint_max=10, + evaluate_input=None): + # Set up a RunConfig to only save checkpoints once per training cycle. + run_config = tf.estimator.RunConfig( + keep_checkpoint_max=keep_checkpoint_max) + self.estimator = tf.estimator.Estimator( + model_fn=self._model, model_dir=model_save_dir, + config=run_config, params=params) + self.model_save_dir = model_save_dir + self.train_input = train_input + self.evaluate_input = evaluate_input + + def _model(self, features, labels, mode, params): + """ + 定义方法模型 + + Parameters + ---------- + features + labels + mode + params + + Returns + ------- + + """ + raise NotImplementedError + + def train(self, log_fmt, log_iter=10, save_steps=100, hooks=None): + """ + 训练 + + Returns + ------- + + """ + logging_hook = tf.train.LoggingTensorHook( + tensors=log_fmt, every_n_iter=log_iter) + saving_hook = tf.train.CheckpointSaverHook( + checkpoint_dir=self.model_save_dir, save_steps=save_steps) + + all_hooks = [logging_hook, saving_hook] + if hooks is not None: + all_hooks += hooks + + self.estimator.train(input_fn=self.train_input, hooks=all_hooks) + + def evaluate(self, hooks=None, checkpoint_path=None): + """ + 测试 + + Returns + ------- + + """ + if checkpoint_path is not None: + assert os.path.exists(checkpoint_path), 'Not such Directory found !' + assert tf.train.get_checkpoint_state( + checkpoint_path), 'Could not find trained model in: {}.'.format( + checkpoint_path) + else: + checkpoint_path = self.model_save_dir + self.estimator.evaluate(input_fn=self.evaluate_input, + checkpoint_path=checkpoint_path, + hooks=hooks) diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..4eefe15 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,8 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/2/28 20:26 +@desc: +""" \ No newline at end of file diff --git a/utils/log.py b/utils/log.py new file mode 100644 index 0000000..31e41a6 --- /dev/null +++ b/utils/log.py @@ -0,0 +1,32 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@author: liang kang +@contact: gangkanli1219@gmail.com +@time: 2018/2/28 20:28 +@desc: +""" +import logging + +formatter = logging.Formatter('%(name)s %(levelname)s %(asctime)s: %(message)s') + + +def get_console_logger(name): + logger = logging.getLogger(name) + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter(formatter)) + logger.setLevel(logging.INFO) + logger.addHandler(console_handler) + return logger + + +def get_file_logger(name, file_name): + logger = logging.getLogger(name) + file_handler = logging.FileHandler(file_name) + file_handler.setFormatter(logging.Formatter(formatter)) + console_handler = logging.StreamHandler() + console_handler.setFormatter(logging.Formatter(formatter)) + logger.setLevel(logging.INFO) + logger.addHandler(file_handler) + logger.addHandler(console_handler) + return logger -- Gitee