diff --git a/TensorFlow/contrib/nlp/tabnet/README.md b/TensorFlow/contrib/nlp/tabnet/README.md new file mode 100644 index 0000000000000000000000000000000000000000..873de110e7d635562a6cfbfe193ec407f21d0ac0 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/README.md @@ -0,0 +1,151 @@ +## 基本信息 + +**发布者(Publisher):Huawei** + +**应用领域(Application Domain):NLP** + +**版本(Version):1.2** + +**修改时间(Modified) :2022.11.14** + +**大小(Size):66KB** + +**框架(Framework):TensorFlow 1.15.0** + +**模型格式(Model Format):ckpt** + +**精度(Precision):Mixed** + +**处理器(Processor):昇腾910** + +**应用级别(Categories):Contrib** + +**描述(Description):基于TensorFlow框架的tabnet表格处理神经网络训练代码 ** + +## 概述 + +Papåer: https://arxiv.org/abs/1908.07442 + +This directory contains an example implementation of TabNet on the Forest Covertype dataset (https://archive.ics.uci.edu/ml/datasets/covertype). + +First, run `python -m download_prepare_covertype.py` to download and prepare the Forest Covertype dataset. This command creates train.csv, val.csv and test.csv files under the "data/" directory. + +To run the pipeline for training and evaluation, simply use `python -m experiment_covertype.py`. For debugging in a low-resource environment, you can use `python -m test_experiment_covertype.py`. + +To modify the experiment to other tabular datasets: + +- Substitute the train.csv, val.csv, and test.csv files under "data/" directory, +- Modify the data_helper function with the numerical and categorical features of the new dataset, +- Reoptimize the TabNet hyperparameters for the new dataset. + +## Requirements + +- Tensorflow 1.15.0 +- absl-py >= 0.5.0 +- numpy \=\= 1.15.1 +- wget >\= 3.2 +- Ascend910 + +## 模型训练 + +### 脚本和示例代码 + +``` +tabnet +├── check_result.tf.json +├── data_helper_covertype.py +├── download_prepare_covertype.py +├── experiment_covertype.py +├── fusion_result.json +├── fusion_switch.cfg +├── myTest.py +├── requirements.txt +├── run.sh +├── tabnet README.md +├── tabnet_model.py +└── test_experiment_covertype.py +``` + +### 脚本参数 + +``` +- TRAIN_FILE = "data/train_covertype.csv" +- VAL_FILE = "data/val_covertype.csv" +- TEST_FILE = "data/test_covertype.csv" +- MAX_STEPS = 10 +- DISPLAY_STEP = 5 +- VAL_STEP = 5 +- SAVE_STEP = 40000 +- INIT_LEARNING_RATE = 0.02 +- DECAY_EVERY = 500 +- DECAY_RATE = 0.95 +- BATCH_SIZE = 32 +- SPARSITY_LOSS_WEIGHT = 0.0001 +- GRADIENT_THRESH = 2000.0 +- SEED = 1 +``` + +### 训练环境 + +- 华为NPU裸机 + +### 训练过程 + +- 首先运行 `python -m download_prepare_covertype.py` 下载数据集 +- 基于[原始代码](https://github.com/google-research/google-research/tree/master/tabnet),参考[TensorFlow 1.15网络模型迁移和训练](https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/600alpha001/moddevg/tfmigr1/atlasmprtg_13_0009.html)官方文档,将其中代码适配入NPU +- 全量运行参数过大,目前仅在GPU和NPU执行了 `python -m test_experiment_covertype.py` ,同时对比了精度和性能 + +## 结果对比 + +### GPU训练结果 + +``` +Step 1 , Step Training Time = 15.6996 +Step 2 , Step Training Time = 0.0656 +Step 3 , Step Training Time = 0.0643 +Step 4 , Step Training Time = 0.0658 +Step 5 , Training Loss = 2.1249 +Step 5 , Step Training Time = 42.3971 +Step 5 , Val Accuracy = 0.3649 +Step 6 , Step Training Time = 0.0686 +Step 7 , Step Training Time = 0.0667 +Step 8 , Step Training Time = 0.0663 +Step 9 , Step Training Time = 0.0665 +Step 10 , Training Loss = 1.9217 +Step 10 , Step Training Time = 16.2367 +Step 10 , Val Accuracy = 0.3444 +``` + +### NPU训练结果 + +``` +Step 1 , Step Training Time = 931.2861 +Step 2 , Step Training Time = 0.3418 +Step 3 , Step Training Time = 0.3229 +Step 4 , Step Training Time = 0.3138 +Step 5 , Training 
Loss = 1.7622 +Step 5 , Step Training Time = 253.5153 +Step 5 , Val Accuracy = 0.2752 +Step 6 , Step Training Time = 116.0940 +Step 7 , Step Training Time = 0.4088 +Step 8 , Step Training Time = 0.3603 +Step 9 , Step Training Time = 0.3069 +Step 10 , Training Loss = 1.7022 +Step 10 , Step Training Time = 196.8820 +Step 10 , Val Accuracy = 0.3621 +``` + +### 精度与性能分析 + +- 不考虑NPU train和eval的首次编图耗时,目前NPU的训练耗时大约是GPU的5倍,最终精度有所提升 + + + + + + + + + + + diff --git a/TensorFlow/contrib/nlp/tabnet/check_result.tf.json b/TensorFlow/contrib/nlp/tabnet/check_result.tf.json new file mode 100644 index 0000000000000000000000000000000000000000..f45084148c3728c68ce23e424406b14be9a4fe8b --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/check_result.tf.json @@ -0,0 +1,256 @@ +{ + "op": [ + { + "is_support": false, + "name": "Encoder/Aggregated_mask", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder/Mask_for_step0", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder/Mask_for_step1", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder/Mask_for_step2", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder/Mask_for_step3", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder/Mask_for_step4", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_1/Aggregated_mask", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_1/Mask_for_step0", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_1/Mask_for_step1", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_1/Mask_for_step2", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_1/Mask_for_step3", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_1/Mask_for_step4", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_2/Aggregated_mask", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_2/Mask_for_step0", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_2/Mask_for_step1", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." 
+ }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_2/Mask_for_step2", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_2/Mask_for_step3", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "Encoder_2/Mask_for_step4", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ImageSummary" + }, + { + "is_support": false, + "name": "IteratorGetNext", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "IteratorGetNext" + }, + { + "is_support": false, + "name": "IteratorGetNext_1", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "IteratorGetNext" + }, + { + "is_support": false, + "name": "IteratorGetNext_2", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "IteratorGetNext" + }, + { + "is_support": false, + "name": "IteratorV2", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "IteratorV2" + }, + { + "is_support": false, + "name": "IteratorV2_1", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "IteratorV2" + }, + { + "is_support": false, + "name": "IteratorV2_2", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "IteratorV2" + }, + { + "is_support": false, + "name": "Merge/MergeSummary", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "MergeSummary" + }, + { + "is_support": false, + "name": "Test_accuracy", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ScalarSummary" + }, + { + "is_support": false, + "name": "Total_loss", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ScalarSummary" + }, + { + "is_support": false, + "name": "Val_accuracy", + "not_support_reason": { + "code": 1, + "message": "This op is not exsit on npu." + }, + "type": "ScalarSummary" + } + ] +} \ No newline at end of file diff --git a/TensorFlow/contrib/nlp/tabnet/data_helper_covertype.py b/TensorFlow/contrib/nlp/tabnet/data_helper_covertype.py new file mode 100644 index 0000000000000000000000000000000000000000..98ecbb83f73a93984a35ff2d75d6f1c85056fa69 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/data_helper_covertype.py @@ -0,0 +1,131 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
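+# Rough usage sketch (mirrors how experiment_covertype.py consumes this module;
+# the CSV paths are the ones written by download_prepare_covertype.py):
+#   train_batch = input_fn("data/train_covertype.csv",
+#                          num_epochs=1, shuffle=True, batch_size=32)
+#   train_iter = train_batch.make_initializable_iterator()
+#   features, labels = train_iter.get_next()
+# `features` is a dict keyed by FEATURE_COLUMNS and `labels` holds the
+# zero-based Covertype class indices produced by parse_csv.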
+ +"""Data helper function for the Forest Covertype dataset.""" + +import tensorflow as tf + +# Dataset size +# N_TRAIN_SAMPLES = 309871 +N_VAL_SAMPLES = 154937 +N_TEST_SAMPLES = 116203 +NUM_FEATURES = 54 +NUM_CLASSES = 7 + +# All feature columns in the data +LABEL_COLUMN = "Covertype" + +BOOL_COLUMNS = [ + "Wilderness_Area1", "Wilderness_Area2", "Wilderness_Area3", + "Wilderness_Area4", "Soil_Type1", "Soil_Type2", "Soil_Type3", "Soil_Type4", + "Soil_Type5", "Soil_Type6", "Soil_Type7", "Soil_Type8", "Soil_Type9", + "Soil_Type10", "Soil_Type11", "Soil_Type12", "Soil_Type13", "Soil_Type14", + "Soil_Type15", "Soil_Type16", "Soil_Type17", "Soil_Type18", "Soil_Type19", + "Soil_Type20", "Soil_Type21", "Soil_Type22", "Soil_Type23", "Soil_Type24", + "Soil_Type25", "Soil_Type26", "Soil_Type27", "Soil_Type28", "Soil_Type29", + "Soil_Type30", "Soil_Type31", "Soil_Type32", "Soil_Type33", "Soil_Type34", + "Soil_Type35", "Soil_Type36", "Soil_Type37", "Soil_Type38", "Soil_Type39", + "Soil_Type40" +] + +INT_COLUMNS = [ + "Elevation", "Aspect", "Slope", "Horizontal_Distance_To_Hydrology", + "Vertical_Distance_To_Hydrology", "Horizontal_Distance_To_Roadways", + "Hillshade_9am", "Hillshade_Noon", "Hillshade_3pm", + "Horizontal_Distance_To_Fire_Points" +] + +STR_COLUMNS = [] +STR_NUNIQUESS = [] + +FLOAT_COLUMNS = [] + +DEFAULTS = ([[0] for col in INT_COLUMNS] + [[""] for col in BOOL_COLUMNS] + + [[0.0] for col in FLOAT_COLUMNS] + [[""] for col in STR_COLUMNS] + + [[-1]]) + +FEATURE_COLUMNS = ( + INT_COLUMNS + BOOL_COLUMNS + STR_COLUMNS + FLOAT_COLUMNS) +ALL_COLUMNS = FEATURE_COLUMNS + [LABEL_COLUMN] + + +def get_columns(): + """Get the representations for all input columns.""" + + columns = [] + if FLOAT_COLUMNS: + columns += [tf.feature_column.numeric_column(ci) for ci in FLOAT_COLUMNS] + if INT_COLUMNS: + columns += [tf.feature_column.numeric_column(ci) for ci in INT_COLUMNS] + if STR_COLUMNS: + # pylint: disable=g-complex-comprehension + columns += [ + tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_hash_bucket( + ci, hash_bucket_size=int(3 * num)), + dimension=1) for ci, num in zip(STR_COLUMNS, STR_NUNIQUESS) + ] + if BOOL_COLUMNS: + # pylint: disable=g-complex-comprehension + columns += [ + tf.feature_column.embedding_column( + tf.feature_column.categorical_column_with_hash_bucket( + ci, hash_bucket_size=3), + dimension=1) for ci in BOOL_COLUMNS + ] + return columns + + +def parse_csv(value_column): + """Parses a CSV file based on the provided column types.""" + columns = tf.decode_csv(value_column, record_defaults=DEFAULTS) + features = dict(zip(ALL_COLUMNS, columns)) + label = features.pop(LABEL_COLUMN) + classes = tf.cast(label, tf.int32) - 1 + return features, classes + + +def input_fn(data_file, + num_epochs, + shuffle, + batch_size, + n_buffer=50, + n_parallel=16): + """Function to read the input file and return the dataset. + + Args: + data_file: Name of the file. + num_epochs: Number of epochs. + shuffle: Whether to shuffle the data. + batch_size: Batch size. + n_buffer: Buffer size. + n_parallel: Number of cores for multi-core processing option. + + Returns: + The Tensorflow dataset. + """ + + # Extract lines from input files using the Dataset API. + dataset = tf.data.TextLineDataset(data_file) + + if shuffle: + dataset = dataset.shuffle(buffer_size=n_buffer) + + dataset = dataset.map(parse_csv, num_parallel_calls=n_parallel) + + # Repeat after shuffling, to prevent separate epochs from blending together. 
+ dataset = dataset.repeat(num_epochs) + dataset = dataset.batch(batch_size, drop_remainder=True) + return dataset diff --git a/TensorFlow/contrib/nlp/tabnet/download_prepare_covertype.py b/TensorFlow/contrib/nlp/tabnet/download_prepare_covertype.py new file mode 100644 index 0000000000000000000000000000000000000000..b430368569254b224b6c5dde37e1c1bbbaf0a091 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/download_prepare_covertype.py @@ -0,0 +1,61 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Downloads and prepares the Forest Covertype dataset.""" + +import gzip +import os +import shutil +import pandas as pd +from sklearn.model_selection import train_test_split +import wget + +URL = 'https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz' + + +def main(): + + if not os.path.exists('./data'): + os.makedirs('./data') + + filename = wget.download(URL) + with gzip.open(filename, 'rb') as f_in: + with open('data/covtype.csv', 'wb') as f_out: + shutil.copyfileobj(f_in, f_out) + + df = pd.read_csv('data/covtype.csv') + n_total = len(df) + + # Train, val and test split follows + # Rory Mitchell, Andrey Adinets, Thejaswi Rao, and Eibe Frank. + # Xgboost: Scalable GPU accelerated learning. arXiv:1806.11248, 2018. + + train_val_indices, test_indices = train_test_split( + range(n_total), test_size=0.2, random_state=0) + train_indices, val_indices = train_test_split( + train_val_indices, test_size=0.2 / 0.6, random_state=0) + + traindf = df.iloc[train_indices] + valdf = df.iloc[val_indices] + testdf = df.iloc[test_indices] + traindf = traindf.sample(frac=1) + + traindf.to_csv('data/train_covertype.csv', index=False, header=False) + valdf.to_csv('data/val_covertype.csv', index=False, header=False) + testdf.to_csv('data/test_covertype.csv', index=False, header=False) + +if __name__ == '__main__': + main() + diff --git a/TensorFlow/contrib/nlp/tabnet/experiment_covertype.py b/TensorFlow/contrib/nlp/tabnet/experiment_covertype.py new file mode 100644 index 0000000000000000000000000000000000000000..8b983cf0ffbce847862c0e6180d54951d8865906 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/experiment_covertype.py @@ -0,0 +1,200 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Experiment to train and evaluate the TabNet model on Forest Covertype.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +import os +from absl import app +import data_helper_covertype +import numpy as np +import tabnet_model +import tensorflow as tf + +# Run Tensorflow on GPU 0 +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# Training parameters +TRAIN_FILE = "data/train_covertype.csv" +VAL_FILE = "data/val_covertype.csv" +TEST_FILE = "data/test_covertype.csv" +MAX_STEPS = 1000000 +DISPLAY_STEP = 5000 +VAL_STEP = 10000 +SAVE_STEP = 40000 +INIT_LEARNING_RATE = 0.02 +DECAY_EVERY = 500 +DECAY_RATE = 0.95 +BATCH_SIZE = 16384 +SPARSITY_LOSS_WEIGHT = 0.0001 +GRADIENT_THRESH = 2000.0 +SEED = 1 + + +def main(unused_argv): + + # Fix random seeds + tf.set_random_seed(SEED) + np.random.seed(SEED) + + # Define the TabNet model + tabnet_forest_covertype = tabnet_model.TabNet( + columns=data_helper_covertype.get_columns(), + num_features=data_helper_covertype.NUM_FEATURES, + feature_dim=128, + output_dim=64, + num_decision_steps=6, + relaxation_factor=1.5, + batch_momentum=0.7, + virtual_batch_size=512, + num_classes=data_helper_covertype.NUM_CLASSES) + + column_names = sorted(data_helper_covertype.FEATURE_COLUMNS) + print( + "Ordered column names, corresponding to the indexing in Tensorboard visualization" + ) + for fi in range(len(column_names)): + print(str(fi) + " : " + column_names[fi]) + + # Input sampling + train_batch = data_helper_covertype.input_fn( + TRAIN_FILE, num_epochs=100000, shuffle=True, batch_size=BATCH_SIZE) + val_batch = data_helper_covertype.input_fn( + VAL_FILE, + num_epochs=10000, + shuffle=False, + batch_size=data_helper_covertype.N_VAL_SAMPLES) + test_batch = data_helper_covertype.input_fn( + TEST_FILE, + num_epochs=10000, + shuffle=False, + batch_size=data_helper_covertype.N_TEST_SAMPLES) + + train_iter = train_batch.make_initializable_iterator() + val_iter = val_batch.make_initializable_iterator() + test_iter = test_batch.make_initializable_iterator() + + feature_train_batch, label_train_batch = train_iter.get_next() + feature_val_batch, label_val_batch = val_iter.get_next() + feature_test_batch, label_test_batch = test_iter.get_next() + + # Define the model and losses + + encoded_train_batch, total_entropy = tabnet_forest_covertype.encoder( + feature_train_batch, reuse=False, is_training=True) + + logits_orig_batch, _ = tabnet_forest_covertype.classify( + encoded_train_batch, reuse=False) + + softmax_orig_key_op = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits_orig_batch, labels=label_train_batch)) + + train_loss_op = softmax_orig_key_op + SPARSITY_LOSS_WEIGHT * total_entropy + tf.summary.scalar("Total loss", train_loss_op) + + # Optimization step + global_step = tf.train.get_or_create_global_step() + learning_rate = tf.train.exponential_decay( + INIT_LEARNING_RATE, + global_step=global_step, + decay_steps=DECAY_EVERY, + decay_rate=DECAY_RATE) + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(update_ops): + gvs = optimizer.compute_gradients(train_loss_op) + capped_gvs = [(tf.clip_by_value(grad, -GRADIENT_THRESH, + GRADIENT_THRESH), var) for grad, var in gvs] + train_op = optimizer.apply_gradients(capped_gvs, global_step=global_step) + + # Model evaluation + + # Validation performance + encoded_val_batch, _ = 
tabnet_forest_covertype.encoder( + feature_val_batch, reuse=True, is_training=False) + + _, prediction_val = tabnet_forest_covertype.classify( + encoded_val_batch, reuse=True) + + predicted_labels = tf.cast(tf.argmax(prediction_val, 1), dtype=tf.int32) + val_eq_op = tf.equal(predicted_labels, label_val_batch) + val_acc_op = tf.reduce_mean(tf.cast(val_eq_op, dtype=tf.float32)) + tf.summary.scalar("Val accuracy", val_acc_op) + + # Test performance + encoded_test_batch, _ = tabnet_forest_covertype.encoder( + feature_test_batch, reuse=True, is_training=False) + + _, prediction_test = tabnet_forest_covertype.classify( + encoded_test_batch, reuse=True) + + predicted_labels = tf.cast(tf.argmax(prediction_test, 1), dtype=tf.int32) + test_eq_op = tf.equal(predicted_labels, label_test_batch) + test_acc_op = tf.reduce_mean(tf.cast(test_eq_op, dtype=tf.float32)) + tf.summary.scalar("Test accuracy", test_acc_op) + + # Training setup + model_name = "tabnet_forest_covertype_model" + init = tf.initialize_all_variables() + init_local = tf.local_variables_initializer() + init_table = tf.tables_initializer(name="Initialize_all_tables") + saver = tf.train.Saver() + summaries = tf.summary.merge_all() + + with tf.Session() as sess: + summary_writer = tf.summary.FileWriter("./tflog/" + model_name, sess.graph) + + sess.run(init) + sess.run(init_local) + sess.run(init_table) + sess.run(train_iter.initializer) + sess.run(val_iter.initializer) + sess.run(test_iter.initializer) + + for step in range(1, MAX_STEPS + 1): + if step % DISPLAY_STEP == 0: + _, train_loss, merged_summary = sess.run( + [train_op, train_loss_op, summaries]) + summary_writer.add_summary(merged_summary, step) + print("Step " + str(step) + " , Training Loss = " + + "{:.4f}".format(train_loss)) + else: + _ = sess.run(train_op) + + if step % VAL_STEP == 0: + feed_arr = [ + vars()["summaries"], + vars()["val_acc_op"], + vars()["test_acc_op"] + ] + + val_arr = sess.run(feed_arr) + merged_summary = val_arr[0] + val_acc = val_arr[1] + + print("Step " + str(step) + " , Val Accuracy = " + + "{:.4f}".format(val_acc)) + summary_writer.add_summary(merged_summary, step) + + if step % SAVE_STEP == 0: + saver.save(sess, "./checkpoints/" + model_name + ".ckpt") + + +if __name__ == "__main__": + app.run(main) diff --git a/TensorFlow/contrib/nlp/tabnet/fusion_result.json b/TensorFlow/contrib/nlp/tabnet/fusion_result.json new file mode 100644 index 0000000000000000000000000000000000000000..8c8b8cf65d22b589be547088957c002df64d7598 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/fusion_result.json @@ -0,0 +1,593 @@ +{ + "graph_fusion": { + "MulAddFusionPass": { + "effect_times": "0", + "match_times": "40" + }, + "MulSquareFusionPass": { + "effect_times": "0", + "match_times": "20" + } + }, + "session_and_graph_id": "0_1", + "ub_fusion": { + "AutomaticUbFusion": { + "effect_times": "20", + "match_times": "20" + } + } +}{ + "graph_fusion": { + "AReduceMeanFusionPass": { + "effect_times": "0", + "match_times": "118" + }, + "AReduceSumFusionPass": { + "effect_times": "0", + "match_times": "128" + }, + "AddNFusionPass": { + "effect_times": "0", + "match_times": "123" + }, + "ApplyAddOutputPass": { + "effect_times": "80", + "match_times": "80" + }, + "BatchMatMulFusionPass": { + "effect_times": "0", + "match_times": "90" + }, + "CastRemoveFusionPass": { + "effect_times": "0", + "match_times": "55" + }, + "ClipFusionRule0": { + "effect_times": "46", + "match_times": "46" + }, + "ConstToAttrGatherV2Fusion": { + "effect_times": "0", + "match_times": "132" + }, + 
"ConstToAttrPass": { + "effect_times": "208", + "match_times": "262" + }, + "ConstToAttrReduceSumFusion": { + "effect_times": "123", + "match_times": "128" + }, + "ConstToAttrStridedSliceFusion": { + "effect_times": "156", + "match_times": "200" + }, + "ConvConcatFusionPass": { + "effect_times": "0", + "match_times": "45" + }, + "ConvWeightCompressFusionPass": { + "effect_times": "0", + "match_times": "90" + }, + "DreluFusionPass": { + "effect_times": "0", + "match_times": "5" + }, + "FIXPIPEAPREQUANTFUSIONPASS": { + "effect_times": "0", + "match_times": "90" + }, + "FIXPIPEFUSIONPASS": { + "effect_times": "0", + "match_times": "90" + }, + "ForceFp16CastFusionPass": { + "effect_times": "5", + "match_times": "55" + }, + "LayerNormInferenceFusionPass": { + "effect_times": "0", + "match_times": "29" + }, + "LayerNormTrainingFusionPass": { + "effect_times": "0", + "match_times": "29" + }, + "MulAddFusionPass": { + "effect_times": "0", + "match_times": "76" + }, + "MulAddNL2LossFusionPass": { + "effect_times": "0", + "match_times": "87" + }, + "MulAddNPass": { + "effect_times": "0", + "match_times": "87" + }, + "MulGradFusionPass": { + "effect_times": "0", + "match_times": "59" + }, + "MulSquareFusionPass": { + "effect_times": "0", + "match_times": "927" + }, + "PackFusionPass": { + "effect_times": "0", + "match_times": "5" + }, + "Pow2SquareFusionPass": { + "effect_times": "0", + "match_times": "1" + }, + "RealDiv2MulsFusionPass": { + "effect_times": "0", + "match_times": "98" + }, + "SparseSoftMaxFusionPass": { + "effect_times": "1", + "match_times": "1" + }, + "SplitConvConcatFusionPass": { + "effect_times": "0", + "match_times": "45" + }, + "StridedSliceGradFusionPass": { + "effect_times": "0", + "match_times": "63" + }, + "StridedSliceRemovePass": { + "effect_times": "0", + "match_times": "200" + }, + "SubFusionPass": { + "effect_times": "0", + "match_times": "227" + }, + "TileConstToAttrFusion": { + "effect_times": "60", + "match_times": "60" + }, + "TopKFusionPass": { + "effect_times": "5", + "match_times": "5" + }, + "UnsortedSegmentSumFusionPass": { + "effect_times": "0", + "match_times": "44" + }, + "ZAttentionQKVGradXFusionPass": { + "effect_times": "0", + "match_times": "21" + }, + "ZConcatExt2FusionPass": { + "effect_times": "45", + "match_times": "45" + }, + "ZConcatv2dFusionPass": { + "effect_times": "0", + "match_times": "45" + }, + "ZReduceMeanVarianceFusionPass": { + "effect_times": "0", + "match_times": "30" + }, + "ZSplitVFusionPass": { + "effect_times": "1", + "match_times": "1" + } + }, + "session_and_graph_id": "0_11", + "ub_fusion": { + "AutomaticUbFusion": { + "effect_times": "339", + "match_times": "339" + }, + "TbeDynamicElemwiseBroadcastFusionPass": { + "effect_times": "44", + "match_times": "44" + }, + "TbeEltwiseCastFusionPass": { + "effect_times": "4", + "match_times": "4" + }, + "TbeEltwiseFusionPass": { + "effect_times": "82", + "match_times": "82" + }, + "TbeMultiOutputFusionPass": { + "effect_times": "49", + "match_times": "59" + }, + "TbeReduceElemwiseFusionPass": { + "effect_times": "176", + "match_times": "176" + } + } +}{ + "graph_fusion": { + "AReduceMeanFusionPass": { + "effect_times": "0", + "match_times": "126" + }, + "AReduceSumFusionPass": { + "effect_times": "0", + "match_times": "158" + }, + "ASoftmaxFusionPass": { + "effect_times": "0", + "match_times": "2" + }, + "AddNFusionPass": { + "effect_times": "0", + "match_times": "129" + }, + "ApplyAddOutputPass": { + "effect_times": "80", + "match_times": "80" + }, + "ArgMaxV2FusionPass": { + 
"effect_times": "2", + "match_times": "2" + }, + "BatchMatMulFusionPass": { + "effect_times": "0", + "match_times": "150" + }, + "CastRemoveFusionPass": { + "effect_times": "0", + "match_times": "169" + }, + "ClipFusionRule0": { + "effect_times": "46", + "match_times": "46" + }, + "ConstToAttrGatherV2Fusion": { + "effect_times": "0", + "match_times": "396" + }, + "ConstToAttrPass": { + "effect_times": "228", + "match_times": "282" + }, + "ConstToAttrReduceSumFusion": { + "effect_times": "143", + "match_times": "158" + }, + "ConstToAttrStridedSliceFusion": { + "effect_times": "292", + "match_times": "424" + }, + "ConvConcatFusionPass": { + "effect_times": "0", + "match_times": "47" + }, + "ConvWeightCompressFusionPass": { + "effect_times": "0", + "match_times": "150" + }, + "DreluFusionPass": { + "effect_times": "0", + "match_times": "5" + }, + "FIXPIPEAPREQUANTFUSIONPASS": { + "effect_times": "0", + "match_times": "150" + }, + "FIXPIPEFUSIONPASS": { + "effect_times": "0", + "match_times": "150" + }, + "ForceFp16CastFusionPass": { + "effect_times": "15", + "match_times": "169" + }, + "LayerNormInferenceFusionPass": { + "effect_times": "0", + "match_times": "29" + }, + "LayerNormTrainingFusionPass": { + "effect_times": "0", + "match_times": "29" + }, + "MulAddFusionPass": { + "effect_times": "0", + "match_times": "232" + }, + "MulAddNL2LossFusionPass": { + "effect_times": "0", + "match_times": "91" + }, + "MulAddNPass": { + "effect_times": "0", + "match_times": "91" + }, + "MulGradFusionPass": { + "effect_times": "0", + "match_times": "64" + }, + "MulSquareFusionPass": { + "effect_times": "0", + "match_times": "1206" + }, + "PackFusionPass": { + "effect_times": "0", + "match_times": "15" + }, + "Pow2SquareFusionPass": { + "effect_times": "0", + "match_times": "1" + }, + "RealDiv2MulsFusionPass": { + "effect_times": "0", + "match_times": "108" + }, + "RefreshInt64ToInt32FusionPass": { + "effect_times": "1", + "match_times": "1" + }, + "SoftmaxFusionPass": { + "effect_times": "0", + "match_times": "2" + }, + "SparseSoftMaxFusionPass": { + "effect_times": "1", + "match_times": "1" + }, + "SplitConvConcatFusionPass": { + "effect_times": "0", + "match_times": "47" + }, + "StridedSliceGradFusionPass": { + "effect_times": "0", + "match_times": "63" + }, + "StridedSliceRemovePass": { + "effect_times": "0", + "match_times": "424" + }, + "SubFusionPass": { + "effect_times": "0", + "match_times": "295" + }, + "TileConstToAttrFusion": { + "effect_times": "60", + "match_times": "60" + }, + "TopKFusionPass": { + "effect_times": "15", + "match_times": "15" + }, + "TransposedUpdateFusionPass": { + "effect_times": "4", + "match_times": "4" + }, + "UnsortedSegmentSumFusionPass": { + "effect_times": "0", + "match_times": "44" + }, + "ZAttentionQKVGradXFusionPass": { + "effect_times": "0", + "match_times": "21" + }, + "ZConcatExt2FusionPass": { + "effect_times": "47", + "match_times": "47" + }, + "ZConcatv2dFusionPass": { + "effect_times": "0", + "match_times": "47" + }, + "ZReduceMeanVarianceFusionPass": { + "effect_times": "0", + "match_times": "30" + }, + "ZSplitVFusionPass": { + "effect_times": "1", + "match_times": "1" + }, + "softmaxTransFusionPass": { + "effect_times": "2", + "match_times": "2" + } + }, + "session_and_graph_id": "0_21", + "ub_fusion": { + "AutomaticUbFusion": { + "effect_times": "436", + "match_times": "446" + }, + "TbeDynamicElemwiseBroadcastFusionPass": { + "effect_times": "44", + "match_times": "44" + }, + "TbeEltwiseCastFusionPass": { + "effect_times": "12", + "match_times": "12" + 
}, + "TbeEltwiseFusionPass": { + "effect_times": "112", + "match_times": "112" + }, + "TbeMultiOutputFusionPass": { + "effect_times": "55", + "match_times": "95" + }, + "TbeReduceElemwiseFusionPass": { + "effect_times": "176", + "match_times": "176" + } + } +}{ + "graph_fusion": { + "AReduceMeanFusionPass": { + "effect_times": "0", + "match_times": "68" + }, + "AReduceSumFusionPass": { + "effect_times": "0", + "match_times": "35" + }, + "ASoftmaxFusionPass": { + "effect_times": "0", + "match_times": "2" + }, + "AddNFusionPass": { + "effect_times": "0", + "match_times": "7" + }, + "ArgMaxV2FusionPass": { + "effect_times": "2", + "match_times": "2" + }, + "BatchMatMulFusionPass": { + "effect_times": "0", + "match_times": "90" + }, + "CastRemoveFusionPass": { + "effect_times": "0", + "match_times": "168" + }, + "ConstToAttrGatherV2Fusion": { + "effect_times": "0", + "match_times": "396" + }, + "ConstToAttrPass": { + "effect_times": "85", + "match_times": "85" + }, + "ConstToAttrReduceSumFusion": { + "effect_times": "20", + "match_times": "35" + }, + "ConstToAttrStridedSliceFusion": { + "effect_times": "204", + "match_times": "336" + }, + "ConvConcatFusionPass": { + "effect_times": "0", + "match_times": "3" + }, + "ConvWeightCompressFusionPass": { + "effect_times": "0", + "match_times": "90" + }, + "FIXPIPEAPREQUANTFUSIONPASS": { + "effect_times": "0", + "match_times": "90" + }, + "FIXPIPEFUSIONPASS": { + "effect_times": "0", + "match_times": "90" + }, + "ForceFp16CastFusionPass": { + "effect_times": "15", + "match_times": "168" + }, + "LayerNormInferenceFusionPass": { + "effect_times": "0", + "match_times": "29" + }, + "MulAddFusionPass": { + "effect_times": "0", + "match_times": "229" + }, + "MulAddNL2LossFusionPass": { + "effect_times": "0", + "match_times": "3" + }, + "MulAddNPass": { + "effect_times": "0", + "match_times": "3" + }, + "MulGradFusionPass": { + "effect_times": "0", + "match_times": "5" + }, + "MulSquareFusionPass": { + "effect_times": "0", + "match_times": "424" + }, + "PackFusionPass": { + "effect_times": "0", + "match_times": "15" + }, + "RealDiv2MulsFusionPass": { + "effect_times": "0", + "match_times": "15" + }, + "RefreshInt64ToInt32FusionPass": { + "effect_times": "1", + "match_times": "1" + }, + "SoftmaxFusionPass": { + "effect_times": "0", + "match_times": "2" + }, + "SparseSoftMaxFusionPass": { + "effect_times": "1", + "match_times": "1" + }, + "SplitConvConcatFusionPass": { + "effect_times": "0", + "match_times": "3" + }, + "StridedSliceRemovePass": { + "effect_times": "0", + "match_times": "336" + }, + "SubFusionPass": { + "effect_times": "0", + "match_times": "117" + }, + "TopKFusionPass": { + "effect_times": "15", + "match_times": "15" + }, + "TransposedUpdateFusionPass": { + "effect_times": "4", + "match_times": "4" + }, + "ZConcatExt2FusionPass": { + "effect_times": "3", + "match_times": "3" + }, + "ZConcatv2dFusionPass": { + "effect_times": "0", + "match_times": "3" + }, + "ZReduceMeanVarianceFusionPass": { + "effect_times": "0", + "match_times": "30" + }, + "softmaxTransFusionPass": { + "effect_times": "2", + "match_times": "2" + } + }, + "session_and_graph_id": "0_31", + "ub_fusion": { + "AutomaticUbFusion": { + "effect_times": "147", + "match_times": "162" + }, + "TbeEltwiseCastFusionPass": { + "effect_times": "12", + "match_times": "12" + }, + "TbeEltwiseFusionPass": { + "effect_times": "35", + "match_times": "35" + }, + "TbeMultiOutputFusionPass": { + "effect_times": "33", + "match_times": "63" + }, + "TbeReduceElemwiseFusionPass": { + "effect_times": 
"30", + "match_times": "30" + } + } +} \ No newline at end of file diff --git a/TensorFlow/contrib/nlp/tabnet/fusion_switch.cfg b/TensorFlow/contrib/nlp/tabnet/fusion_switch.cfg new file mode 100644 index 0000000000000000000000000000000000000000..01472ddc01fd6fd56c3eae8c0525534738e7b327 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/fusion_switch.cfg @@ -0,0 +1,10 @@ +{ + "Switch":{ + "GraphFusion":{ + "ALL":"off" + }, + "UBFusion":{ + "ALL":"off" + } + } +} diff --git a/TensorFlow/contrib/nlp/tabnet/myTest.py b/TensorFlow/contrib/nlp/tabnet/myTest.py new file mode 100644 index 0000000000000000000000000000000000000000..0dcc356e4594eb6c2831cfe37d0a09d81616e8f2 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/myTest.py @@ -0,0 +1,2 @@ +from npu_bridge.npu_init import * +print(3) \ No newline at end of file diff --git a/TensorFlow/contrib/nlp/tabnet/requirements.txt b/TensorFlow/contrib/nlp/tabnet/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f0e62b9e64efc87a144186b129e0b87f2a40ec98 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/requirements.txt @@ -0,0 +1,4 @@ +tensorflow-gpu==1.11.0 +absl-py>=0.5.0 +numpy==1.15.1 +wget>=3.2 diff --git a/TensorFlow/contrib/nlp/tabnet/run.sh b/TensorFlow/contrib/nlp/tabnet/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..c5adb3ac3a0b169ad0542d83f4001c2b4a87df3c --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/run.sh @@ -0,0 +1,24 @@ +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +#!/bin/bash +set -e +set -x + +virtualenv -p python3 . +source ./bin/activate + +pip install tensorflow +pip install -r requirements.txt +python -m test_experiment_covertype diff --git a/TensorFlow/contrib/nlp/tabnet/tabnet_model.py b/TensorFlow/contrib/nlp/tabnet/tabnet_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b0ea6f719bb7e62c51fc758c9ee117157e80e990 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/tabnet_model.py @@ -0,0 +1,315 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""TabNet model.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import tensorflow as tf + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn + +def sparsemax(logits, name=None): + """Computes sparsemax activations [1]. + For each batch `i` and class `j` we have + $$sparsemax[i, j] = max(logits[i, j] - tau(logits[i, :]), 0)$$ + [1]: https://arxiv.org/abs/1602.02068 + Args: + logits: A `Tensor`. Must be one of the following types: `half`, `float32`, + `float64`. + name: A name for the operation (optional). + Returns: + A `Tensor`. Has the same type as `logits`. + """ + + with ops.name_scope(name, "sparsemax", [logits]) as name: + logits = ops.convert_to_tensor(logits, name="logits") + obs = array_ops.shape(logits)[0] + dims = array_ops.shape(logits)[1] + + # In the paper, they call the logits z. + # The mean(logits) can be substracted from logits to make the algorithm + # more numerically stable. the instability in this algorithm comes mostly + # from the z_cumsum. Substacting the mean will cause z_cumsum to be close + # to zero. However, in practise the numerical instability issues are very + # minor and substacting the mean causes extra issues with inf and nan + # input. + z = logits + + # sort z + z_sorted, _ = nn.top_k(z, k=dims) + + # calculate k(z) + z_cumsum = math_ops.cumsum(z_sorted, axis=1) + k = math_ops.range( + 1, math_ops.cast(dims, logits.dtype) + 1, dtype=logits.dtype) + z_check = 1 + k * z_sorted > z_cumsum + # because the z_check vector is always [1,1,...1,0,0,...0] finding the + # (index + 1) of the last `1` is the same as just summing the number of 1. + k_z = math_ops.reduce_sum(math_ops.cast(z_check, dtypes.int32), axis=1) + + # calculate tau(z) + # If there are inf values or all values are -inf, the k_z will be zero, + # this is mathematically invalid and will also cause the gather_nd to fail. + # Prevent this issue for now by setting k_z = 1 if k_z = 0, this is then + # fixed later (see p_safe) by returning p = nan. This results in the same + # behavior as softmax. + k_z_safe = math_ops.maximum(k_z, 1) + indices = array_ops.stack([math_ops.range(0, obs), k_z_safe - 1], axis=1) + tau_sum = array_ops.gather_nd(z_cumsum, indices) + tau_z = (tau_sum - 1) / math_ops.cast(k_z, logits.dtype) + + # calculate p + p = math_ops.maximum( + math_ops.cast(0, logits.dtype), z - tau_z[:, array_ops.newaxis]) + # If k_z = 0 or if z = nan, then the input is invalid + p_safe = array_ops.where( + math_ops.logical_or( + math_ops.equal(k_z, 0), math_ops.is_nan(z_cumsum[:, -1])), + array_ops.fill([obs, dims], math_ops.cast(0, logits.dtype)), + p) + + return p_safe + +def glu(act, n_units): + """Generalized linear unit nonlinear activation.""" + return act[:, :n_units] * tf.nn.sigmoid(act[:, n_units:]) + + +class TabNet(object): + """TabNet model class.""" + + def __init__(self, + columns, + num_features, + feature_dim, + output_dim, + num_decision_steps, + relaxation_factor, + batch_momentum, + virtual_batch_size, + num_classes, + epsilon=0.00001): + """Initializes a TabNet instance. + + Args: + columns: The Tensorflow column names for the dataset. + num_features: The number of input features (i.e the number of columns for + tabular data assuming each feature is represented with 1 dimension). 
+ feature_dim: Dimensionality of the hidden representation in feature + transformation block. Each layer first maps the representation to a + 2*feature_dim-dimensional output and half of it is used to determine the + nonlinearity of the GLU activation where the other half is used as an + input to GLU, and eventually feature_dim-dimensional output is + transferred to the next layer. + output_dim: Dimensionality of the outputs of each decision step, which is + later mapped to the final classification or regression output. + num_decision_steps: Number of sequential decision steps. + relaxation_factor: Relaxation factor that promotes the reuse of each + feature at different decision steps. When it is 1, a feature is enforced + to be used only at one decision step and as it increases, more + flexibility is provided to use a feature at multiple decision steps. + batch_momentum: Momentum in ghost batch normalization. + virtual_batch_size: Virtual batch size in ghost batch normalization. The + overall batch size should be an integer multiple of virtual_batch_size. + num_classes: Number of output classes. + epsilon: A small number for numerical stability of the entropy calcations. + + Returns: + A TabNet instance. + """ + + self.columns = columns + self.num_features = num_features + self.feature_dim = feature_dim + self.output_dim = output_dim + self.num_decision_steps = num_decision_steps + self.relaxation_factor = relaxation_factor + self.batch_momentum = batch_momentum + self.virtual_batch_size = virtual_batch_size + self.num_classes = num_classes + self.epsilon = epsilon + + def encoder(self, data, reuse, is_training): + """TabNet encoder model.""" + + with tf.variable_scope("Encoder", reuse=reuse): + + # Reads and normalizes input features. + features = tf.feature_column.input_layer(data, self.columns) + features = tf.layers.batch_normalization( + features, training=is_training, momentum=self.batch_momentum) + batch_size = tf.shape(features)[0] + + # Initializes decision-step dependent variables. + output_aggregated = tf.zeros([batch_size, self.output_dim]) + masked_features = features + mask_values = tf.zeros([batch_size, self.num_features]) + aggregated_mask_values = tf.zeros([batch_size, self.num_features]) + complemantary_aggregated_mask_values = tf.ones( + [batch_size, self.num_features]) + total_entropy = 0 + + if is_training: + v_b = self.virtual_batch_size + else: + v_b = 1 + + for ni in range(self.num_decision_steps): + + # Feature transformer with two shared and two decision step dependent + # blocks is used below. 
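+ # Transform_f1/Transform_f2 are shared across decision steps (reuse_flag
+ # reuses their variables for ni > 0), while Transform_f3/Transform_f4 get
+ # per-step variables through the str(ni) name suffix. Each dense layer
+ # outputs 2 * feature_dim units so glu() can use one half as values and the
+ # other half as gates, and the residual sums are scaled by sqrt(0.5) to keep
+ # the variance roughly constant across blocks.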
+ + reuse_flag = (ni > 0) + + transform_f1 = tf.layers.dense( + masked_features, + self.feature_dim * 2, + name="Transform_f1", + reuse=reuse_flag, + use_bias=False) + transform_f1 = tf.layers.batch_normalization( + transform_f1, + training=is_training, + momentum=self.batch_momentum, + virtual_batch_size=v_b) + transform_f1 = glu(transform_f1, self.feature_dim) + + transform_f2 = tf.layers.dense( + transform_f1, + self.feature_dim * 2, + name="Transform_f2", + reuse=reuse_flag, + use_bias=False) + transform_f2 = tf.layers.batch_normalization( + transform_f2, + training=is_training, + momentum=self.batch_momentum, + virtual_batch_size=v_b) + transform_f2 = (glu(transform_f2, self.feature_dim) + + transform_f1) * np.sqrt(0.5) + + transform_f3 = tf.layers.dense( + transform_f2, + self.feature_dim * 2, + name="Transform_f3" + str(ni), + use_bias=False) + transform_f3 = tf.layers.batch_normalization( + transform_f3, + training=is_training, + momentum=self.batch_momentum, + virtual_batch_size=v_b) + transform_f3 = (glu(transform_f3, self.feature_dim) + + transform_f2) * np.sqrt(0.5) + + transform_f4 = tf.layers.dense( + transform_f3, + self.feature_dim * 2, + name="Transform_f4" + str(ni), + use_bias=False) + transform_f4 = tf.layers.batch_normalization( + transform_f4, + training=is_training, + momentum=self.batch_momentum, + virtual_batch_size=v_b) + transform_f4 = (glu(transform_f4, self.feature_dim) + + transform_f3) * np.sqrt(0.5) + + if ni > 0: + + decision_out = tf.nn.relu(transform_f4[:, :self.output_dim]) + + # Decision aggregation. + output_aggregated += decision_out + + # Aggregated masks are used for visualization of the + # feature importance attributes. + scale_agg = tf.reduce_sum( + decision_out, axis=1, keep_dims=True) / ( + self.num_decision_steps - 1) + aggregated_mask_values += mask_values * scale_agg + + features_for_coef = (transform_f4[:, self.output_dim:]) + + if ni < self.num_decision_steps - 1: + + # Determines the feature masks via linear and nonlinear + # transformations, taking into account of aggregated feature use. + mask_values = tf.layers.dense( + features_for_coef, + self.num_features, + name="Transform_coef" + str(ni), + use_bias=False) + mask_values = tf.layers.batch_normalization( + mask_values, + training=is_training, + momentum=self.batch_momentum, + virtual_batch_size=v_b) + mask_values *= complemantary_aggregated_mask_values + # mask_values = tf.contrib.sparsemax.sparsemax(mask_values) + mask_values = sparsemax(mask_values) + + # Relaxation factor controls the amount of reuse of features between + # different decision blocks and updated with the values of + # coefficients. + complemantary_aggregated_mask_values *= ( + self.relaxation_factor - mask_values) + + # Entropy is used to penalize the amount of sparsity in feature + # selection. + total_entropy += tf.reduce_mean( + tf.reduce_sum( + -mask_values * tf.log(mask_values + self.epsilon), + axis=1)) / ( + self.num_decision_steps - 1) + + # Feature selection. 
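+ # mask_values comes from sparsemax, so each row is a sparse distribution
+ # over the num_features inputs (entries in [0, 1] that sum to 1).
+ # Multiplying it into the normalized features below is what implements the
+ # per-step soft feature selection.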
+ masked_features = tf.multiply(mask_values, features) + + # Visualization of the feature selection mask at decision step ni + tf.summary.image( + "Mask for step" + str(ni), + tf.expand_dims(tf.expand_dims(mask_values, 0), 3), + max_outputs=1) + + # Visualization of the aggregated feature importances + tf.summary.image( + "Aggregated mask", + tf.expand_dims(tf.expand_dims(aggregated_mask_values, 0), 3), + max_outputs=1) + + return output_aggregated, total_entropy + + def classify(self, activations, reuse): + """TabNet classify block.""" + + with tf.variable_scope("Classify", reuse=reuse): + logits = tf.layers.dense(activations, self.num_classes, use_bias=False) + predictions = tf.nn.softmax(logits) + return logits, predictions + + def regress(self, activations, reuse): + """TabNet regress block.""" + + with tf.variable_scope("Regress", reuse=reuse): + predictions = tf.layers.dense(activations, 1) + return predictions diff --git a/TensorFlow/contrib/nlp/tabnet/test_experiment_covertype.py b/TensorFlow/contrib/nlp/tabnet/test_experiment_covertype.py new file mode 100644 index 0000000000000000000000000000000000000000..846e517c7201fef6fe79a0db900cbc7185ed14b9 --- /dev/null +++ b/TensorFlow/contrib/nlp/tabnet/test_experiment_covertype.py @@ -0,0 +1,218 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
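+# Note: this script mirrors experiment_covertype.py but with a much smaller
+# model (feature_dim=4, output_dim=2, virtual_batch_size=4), only 10 training
+# steps, per-step wall-clock timing, and the NPU session config (NpuOptimizer
+# custom op with remapping and memory optimization switched off), so it serves
+# as a quick functional check on Ascend hardware.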
+ +"""Low-resource test for a small-scale TabNet model on Forest Covertype.""" +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from npu_bridge.npu_init import * +import time + +import os +from absl import app +import data_helper_covertype +import numpy as np +import tabnet_model +import tensorflow as tf + +# Run Tensorflow on GPU 0 +os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" +os.environ["CUDA_VISIBLE_DEVICES"] = "0" + +# Training parameters +TRAIN_FILE = "data/train_covertype.csv" +VAL_FILE = "data/val_covertype.csv" +TEST_FILE = "data/test_covertype.csv" +MAX_STEPS = 10 +DISPLAY_STEP = 5 +VAL_STEP = 5 +SAVE_STEP = 40000 +INIT_LEARNING_RATE = 0.02 +DECAY_EVERY = 500 +DECAY_RATE = 0.95 +BATCH_SIZE = 32 +SPARSITY_LOSS_WEIGHT = 0.0001 +GRADIENT_THRESH = 2000.0 +SEED = 1 + + +def main(unused_argv): + + # Fix random seeds + tf.set_random_seed(SEED) + np.random.seed(SEED) + + # Define the TabNet model + tabnet_forest_covertype = tabnet_model.TabNet( + columns=data_helper_covertype.get_columns(), + num_features=data_helper_covertype.NUM_FEATURES, + feature_dim=4, + output_dim=2, + num_decision_steps=6, + relaxation_factor=1.5, + batch_momentum=0.7, + virtual_batch_size=4, + num_classes=data_helper_covertype.NUM_CLASSES) + + column_names = sorted(data_helper_covertype.FEATURE_COLUMNS) + print( + "Ordered column names, corresponding to the indexing in Tensorboard visualization" + ) + for fi in range(len(column_names)): + print(str(fi) + " : " + column_names[fi]) + + # Input sampling + train_batch = data_helper_covertype.input_fn( + TRAIN_FILE, num_epochs=100000, shuffle=True, batch_size=BATCH_SIZE) + val_batch = data_helper_covertype.input_fn( + VAL_FILE, + num_epochs=10000, + shuffle=False, + batch_size=data_helper_covertype.N_VAL_SAMPLES) + test_batch = data_helper_covertype.input_fn( + TEST_FILE, + num_epochs=10000, + shuffle=False, + batch_size=data_helper_covertype.N_TEST_SAMPLES) + + train_iter = train_batch.make_initializable_iterator() + val_iter = val_batch.make_initializable_iterator() + test_iter = test_batch.make_initializable_iterator() + + feature_train_batch, label_train_batch = train_iter.get_next() + feature_val_batch, label_val_batch = val_iter.get_next() + feature_test_batch, label_test_batch = test_iter.get_next() + + # Define the model and losses + + encoded_train_batch, total_entropy = tabnet_forest_covertype.encoder( + feature_train_batch, reuse=False, is_training=True) + + logits_orig_batch, _ = tabnet_forest_covertype.classify( + encoded_train_batch, reuse=False) + + softmax_orig_key_op = tf.reduce_mean( + tf.nn.sparse_softmax_cross_entropy_with_logits( + logits=logits_orig_batch, labels=label_train_batch)) + + train_loss_op = softmax_orig_key_op + SPARSITY_LOSS_WEIGHT * total_entropy + tf.summary.scalar("Total loss", train_loss_op) + + # Optimization step + global_step = tf.train.get_or_create_global_step() + learning_rate = tf.train.exponential_decay( + INIT_LEARNING_RATE, + global_step=global_step, + decay_steps=DECAY_EVERY, + decay_rate=DECAY_RATE) + optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate) + update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS) + with tf.control_dependencies(update_ops): + gvs = optimizer.compute_gradients(train_loss_op) + capped_gvs = [(tf.clip_by_value(grad, -GRADIENT_THRESH, + GRADIENT_THRESH), var) for grad, var in gvs] + train_op = optimizer.apply_gradients(capped_gvs, global_step=global_step) + + # Model evaluation + + # Validation performance + 
encoded_val_batch, _ = tabnet_forest_covertype.encoder( + feature_val_batch, reuse=True, is_training=False) + + _, prediction_val = tabnet_forest_covertype.classify( + encoded_val_batch, reuse=True) + + predicted_labels = tf.cast(tf.argmax(prediction_val, 1), dtype=tf.int32) + val_eq_op = tf.equal(predicted_labels, label_val_batch) + val_acc_op = tf.reduce_mean(tf.cast(val_eq_op, dtype=tf.float32)) + tf.summary.scalar("Val accuracy", val_acc_op) + + # Test performance + encoded_test_batch, _ = tabnet_forest_covertype.encoder( + feature_test_batch, reuse=True, is_training=False) + + _, prediction_test = tabnet_forest_covertype.classify( + encoded_test_batch, reuse=True) + + predicted_labels = tf.cast(tf.argmax(prediction_test, 1), dtype=tf.int32) + test_eq_op = tf.equal(predicted_labels, label_test_batch) + test_acc_op = tf.reduce_mean(tf.cast(test_eq_op, dtype=tf.float32)) + tf.summary.scalar("Test accuracy", test_acc_op) + + # Training setup + model_name = "tabnet_forest_covertype_model" + init = tf.initialize_all_variables() + init_local = tf.local_variables_initializer() + init_table = tf.tables_initializer(name="Initialize_all_tables") + saver = tf.train.Saver() + summaries = tf.summary.merge_all() + + config = tf.ConfigProto() + custom_op = config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + # custom_op.parameter_map["enable_data_pre_proc"].b = True + # custom_op.parameter_map["fusion_switch_file"].s = tf.compat.as_bytes("./fusion_switch.cfg") + config.graph_options.rewrite_options.remapping = RewriterConfig.OFF # 必须显式关闭 + config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF # 必须显式关闭 + + with tf.Session(config=config) as sess: + summary_writer = tf.summary.FileWriter("./tflog/" + model_name, sess.graph) + + sess.run(init) + sess.run(init_local) + sess.run(init_table) + sess.run(train_iter.initializer) + sess.run(val_iter.initializer) + sess.run(test_iter.initializer) + + for step in range(1, MAX_STEPS + 1): + start_time = time.time() + if step % DISPLAY_STEP == 0: + _, train_loss, merged_summary = sess.run( + [train_op, train_loss_op, summaries]) + summary_writer.add_summary(merged_summary, step) + print("Step " + str(step) + " , Training Loss = " + + "{:.4f}".format(train_loss)) + else: + _ = sess.run(train_op) + + end_time = time.time() + print("Step " + str(step) + " , Step Training Time = " + + "{:.4f}".format(end_time - start_time)) + # print("Step " + str(step) + " , Step Loss = " + train_loss + " , Step Training Time = " + + # "{:.4f}".format(end_time - start_time)) + + if step % VAL_STEP == 0: + feed_arr = [ + vars()["summaries"], + vars()["val_acc_op"], + vars()["test_acc_op"] + ] + + val_arr = sess.run(feed_arr) + merged_summary = val_arr[0] + val_acc = val_arr[1] + + print("Step " + str(step) + " , Val Accuracy = " + + "{:.4f}".format(val_acc)) + summary_writer.add_summary(merged_summary, step) + + if step % SAVE_STEP == 0: + saver.save(sess, "./checkpoints/" + model_name + ".ckpt") + + +if __name__ == "__main__": + app.run(main)