diff --git a/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/.keep b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/README.md b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/README.md new file mode 100644 index 0000000000000000000000000000000000000000..25ea1bb7511d72cb317ebb08f7a17ff3e5cfe15d --- /dev/null +++ b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/README.md @@ -0,0 +1,89 @@ +# SVD_ID2019_for_ACL + +#### 概述 +给定两个三维点云图,利用SVD正交化过程SVDO+(M)将其投射到SO(3)上,要求网络预测最佳对齐它们的3D旋转。 + +- 开源代码:训练获取 + + https://github.com/google-research/google-research/tree/master/special_orthogonalization + +- 参考论文: + + [An Analysis of SVD for Deep Rotation Estimation](https://arxiv.org/abs/2006.14616) + +#### 数据集 + +训练数据集 points + +测试数据集 points_test + +旋转后数据集 points_test_modified + +#### 模型固化 +- 直接获取 + +直接下载获取,百度网盘 +链接:https://pan.baidu.com/s/17zKWq2aY06cF9IQW6htn_A +提取码:2019 +- 训练获取 + +训练获取 +训练完成saved_model模型网盘链接:https://pan.baidu.com/s/1Y4ato6Ob-6-rcXr31AvgoA +提取码:2019 + +1.按照SVD_ID2019_for_Tensorflow中的流程训练,模型保存为saved_model格式 + +2.将saved_model格式文件冻结为pb文件(需要在frezze.py文件中修改路径) + +python frezze.py + +得到svd.pb + +#### 使用ATC工具将pb文件转换为om模型 +命令行代码示例 + +atc --model=/home/test_user04/svd.pb --framework=3 --output=/home/test_user04/svd --soc_version=Ascend310 --input_shape="data_1:1,1410,3;rot_1:1,3,3" + +注意所使用机器的Ascend的型号 + +模型直接下载百度网盘链接:https://pan.baidu.com/s/14-m0ZhPQyIr8enpUgVytpg +提取码:2019 + +得到svd.om +#### 制作数据集 +- 直接下载,数据集在svd_inference/data_1 + + + +- 自己制作 +原数据链接:https://pan.baidu.com/s/1aGAO3os8ifDnYm1yXrxndQ +提取码:2019 + +使用pts2txt制作数据集(注意修改数据路径,数据路径为/xxx/points_test_modified/*.pts,注意修改产生数据集后的路径) + +python pts2txt.py + +#### 获取离线推理输出bin文件 + +推理文件在压缩包svd_inference中,下载百度网盘链接:https://pan.baidu.com/s/1OfCxHMUJcnyqp2IcvV3eWg +提取码:2019 + +脚本在src文件夹中,直接运行 + +python svdom_inference.py + 
+推理结果直接下载网盘链接:https://pan.baidu.com/s/1NFNfJkTUW4u7YJcaHK9mLw +提取码:2019 + +#### 使用输出的bin文件验证推理精度 + +运行脚本 + +python calc_acc.py + +得到推理精度:3.150504164928697 + +与在线推理精度近似 + +关于以上所有文件的百度网盘链接:https://pan.baidu.com/s/1sR8gYK8jM6xCZwbq7eK50A +提取码:2019 \ No newline at end of file diff --git a/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/calc_acc.py b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/calc_acc.py new file mode 100644 index 0000000000000000000000000000000000000000..b71c645e3be27303ba1d56628328983a05611541 --- /dev/null +++ b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/calc_acc.py @@ -0,0 +1,33 @@ +import numpy as np +import glob +import pathlib +import os +import utils_gpu +import tensorflow as tf + +INFERENCE_DIR = "C:/Users/1young/Desktop/svd_output/svd_output/*.bin" +TEST_DIR = "D:/svd_code2/data_1" +tf.enable_eager_execution() + + +def main(): + input_test_files = glob.glob(INFERENCE_DIR) + mean_err = [] + for in_file in input_test_files: + out_file_prefix = pathlib.Path(in_file).stem + rot_path = os.path.join(TEST_DIR,'%s.txt' % out_file_prefix) + rot = np.loadtxt(rot_path)[:3,:].reshape((-1,3,3)) + r = np.loadtxt(in_file).reshape((-1,3,3)) + theta = utils_gpu.relative_angle(rot, r) + mean_theta = tf.reduce_mean(theta) + mean_theta_deg = mean_theta * 180.0 / np.pi + mean_theta_deg = mean_theta_deg.numpy() + mean_err.append(mean_theta_deg) + print("the mean of error") + print(np.mean(np.array(mean_err))) + +if __name__=="__main__": + main() + + + diff --git a/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/frezze.py b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/frezze.py new file mode 100644 index 0000000000000000000000000000000000000000..70014c274a48fe831a544872df55ec61accd3998 --- /dev/null +++ b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/frezze.py @@ -0,0 +1,50 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import tensorflow as tf +from tensorflow.python.tools import freeze_graph +#from npu_bridge.npu_init import * + +saved_model_path = 'C:/Users/1young/Desktop/1666004668' +def main(): + freeze_graph.freeze_graph( + input_saved_model_dir=saved_model_path, + output_node_names='rotation_matrix', + output_graph='svd.pb', + initializer_nodes='', + input_graph= None, + input_saver= False, + input_binary=False, + input_checkpoint=None, + restore_op_name=None, + filename_tensor_name=None, + clear_devices=False, + input_meta_graph=False) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/pts2txt.py b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/pts2txt.py new file mode 100644 index 0000000000000000000000000000000000000000..01f84fb3af7f203150f1a86204c81f76f61228cb --- /dev/null +++ b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/pts2txt.py @@ -0,0 +1,44 @@ +import os +import numpy as np +import tensorflow as tf +import glob +import pathlib + +inputpath = r"D:\special_orthogonalization\points_test_modified\*.pts" +outputpath = r"D:\svd_code2\data_1" + + +tf.enable_eager_execution() + +def data_processing(pts_path): + """Parse a .pts file and pad it to 1414 rows by repeating row 4.""" + file_buffer = tf.read_file(pts_path) + lines = tf.string_split([file_buffer], delimiter='\n') + lines1 = tf.string_split(lines.values, delimiter='\r') + values = tf.stack(tf.decode_csv(lines1.values, + record_defaults=[[0.0], [0.0], [0.0]], field_delim=' ')) + values = tf.transpose(values) # 3xN --> Nx3. + diff_num = 1414-tf.shape(values)[0] + repeat_pts = tf.tile(tf.reshape(values[4,:],(1,-1)),[diff_num,1]) + + # Pad exactly once so every output has 1414 rows (a second concat breaks the fixed shape). + values = tf.concat([values,repeat_pts],axis=0) + # First three rows are the rotation matrix, remaining rows the point cloud. 
+ return values.numpy() + +def file_save(path,datapath): + input_test_files = glob.glob(path) + for in_file in input_test_files: + out_file_prefix = pathlib.Path(in_file).stem + values = data_processing(in_file) + out_file1 = os.path.join( + datapath, '%s.txt' % out_file_prefix) + np.savetxt(out_file1,values) + +def main(): + os.makedirs(outputpath, exist_ok=True) + file_save(inputpath,outputpath) + +if __name__ == '__main__': + main() + diff --git a/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/utils_gpu.py b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/utils_gpu.py new file mode 100644 index 0000000000000000000000000000000000000000..ff47d69c208dfa44d00c739d81dec4532675d64e --- /dev/null +++ b/ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL/utils_gpu.py @@ -0,0 +1,115 @@ +# coding=utf-8 +# Copyright 2021 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Utility functions.""" +import numpy as np +from scipy.stats import special_ortho_group +import tensorflow as tf + + +def relative_angle(r1, r2): + """Relative angle (radians) between 3D rotation matrices.""" + rel_rot = tf.matmul(tf.transpose(r1, perm=[0, 2, 1]), r2) + trace = rel_rot[:, 0, 0] + rel_rot[:, 1, 1] + rel_rot[:, 2, 2] + cos_theta = (trace - 1.0) / 2.0 + cos_theta = tf.minimum(cos_theta, tf.ones_like(cos_theta)) + cos_theta = tf.maximum(cos_theta, (-1.0) * tf.ones_like(cos_theta)) + theta = tf.acos(cos_theta) + return theta + + +def random_rotation_benchmark_np(n): + """Sample a random 3D rotation by method used in Zhou et al, CVPR19. + + This numpy function is a copy of the PyTorch function + get_sampled_rotation_matrices_by_axisAngle() in the code made available + for Zhou et al, CVPR19, at https://github.com/papagina/RotationContinuity/. + + Args: + n: the number of rotation matrices to return. + + Returns: + [n, 3, 3] np array. + """ + theta = np.random.uniform(-1, 1, n) * np.pi + sin = np.sin(theta) + axis = np.random.randn(n, 3) + axis = axis / np.maximum(np.linalg.norm(axis, axis=-1, keepdims=True), 1e-7) + qw = np.cos(theta) + qx = axis[:, 0] * sin + qy = axis[:, 1] * sin + qz = axis[:, 2] * sin + + xx = qx*qx + yy = qy*qy + zz = qz*qz + xy = qx*qy + xz = qx*qz + yz = qy*qz + xw = qx*qw + yw = qy*qw + zw = qz*qw + + row0 = np.stack((1-2*yy-2*zz, 2*xy-2*zw, 2*xz+2*yw), axis=-1) + row1 = np.stack((2*xy+2*zw, 1-2*xx-2*zz, 2*yz-2*xw), axis=-1) + row2 = np.stack((2*xz-2*yw, 2*yz+2*xw, 1-2*xx-2*yy), axis=-1) + matrix = np.stack((row0, row1, row2), axis=1) + + return matrix + + +def random_rotation_benchmark(n): + """A TF wrapper for random_rotation_benchmark_np().""" + mat = tf.py_func( + func=lambda t: np.float32(random_rotation_benchmark_np(t)), + inp=[n], + Tout=tf.float32, + stateful=True) + return tf.reshape(mat, (n, 3, 3)) + + +def random_rotation(n): + """Sample rotations from a uniform distribution on SO(3).""" + mat = tf.py_func( + 
func=lambda t: np.float32(special_ortho_group.rvs(3, size=t)), + inp=[n], + Tout=tf.float32, + stateful=True) + return tf.reshape(mat, (n, 3, 3)) + + +def symmetric_orthogonalization(x): + """Maps 9D input vectors onto SO(3) via symmetric orthogonalization.""" + # Innner dimensions of the input should be 3x3 matrices. + m = tf.reshape(x, (-1, 3, 3)) + _, u, v = tf.svd(m) + det = tf.linalg.det(tf.matmul(u, v, transpose_b=True)) + r = tf.matmul( + tf.concat([u[:, :, :-1], u[:, :, -1:] * tf.reshape(det, [-1, 1, 1])], 2), + v, transpose_b=True) + return r + + +def gs_orthogonalization(p6): + """Gram-Schmidt orthogonalization from 6D input.""" + # Input should be [batch_size, 6] + x = p6[:, 0:3] + y = p6[:, 3:6] + xn = tf.math.l2_normalize(x, axis=-1) + z = tf.linalg.cross(xn, y) + zn = tf.math.l2_normalize(z, axis=-1) + y = tf.linalg.cross(zn, xn) + r = tf.stack([xn, y, zn], -1) + return r diff --git a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/README.md b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/README.md index 4a337f83165faf964367f33d01450728d8c5eb7c..cbb95c3f1388153ad36c9e415694ca4391ce4a8a 100644 --- a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/README.md +++ b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/README.md @@ -27,12 +27,9 @@ [An Analysis of SVD for Deep Rotation Estimation](https://arxiv.org/abs/2006.14616) - 参考实现: - gpu最后训练出来的模型在 - obs://cann-id2019/gpu - - 数据在 - obs://cann-id2019/dataset/ + 数据下载百度网盘链接:https://pan.baidu.com/s/1up1HW6McgSor3JF0yqQZSA +提取码:2019 共有3个数据集 @@ -42,12 +39,11 @@ 第一步旋转后的数据集 test_points_modified - npu训练出来的模型在 + npu训练出来的模型下载百度网盘链接:https://pan.baidu.com/s/1JU1koZR7uGlkKfRYIk8tsw +提取码:2019 + - obs://cann-id2019/dataset/output - - 相关代码均以上传,在NPU和GPU文件夹中可以找到 - 相关迁移的工作: 
在进行代码迁移到NPU上时,输入的训练数据为点云数据,点云数据的shape为(N,3),其中N并不是固定的,因此在NPU上存在动态shape的问题,导致模型训练无法正常进行。我们为此想了三个解决方法:1、找出所有点云数据中最小的N,对于大于N的点云数据,仅取前N行的数据输入训练。2、找到所有点云数据中最大的N,对于小于N的点云进行补0操作,将所有数据固定为最大的N后,输入网络进行训练。3、找到所有点云数据中最大的N,对小于N的点云数据,从原数据中选择一个点云进行填补至行数为N,再将数据输入网络进行训练。该三种方法均成功解决了NPU上的动态shape问题,但是第一种方法删除了样本点,因此导致最后训练出的模型精度很差;第二种方法虽然并没有丢失样本信息,但是向数据中填入大量的0,改变了本来的代码逻辑,导致最后训练出的模型精度也并不高。对于第三种方法,即没有丢失样本信息,对每个点云数据中的某一个点云样本点进行重复操作,没有改变原始的代码逻辑,最后也获得了不错的精度表现。 @@ -64,7 +60,7 @@ ## 默认配置 -- 数据集获取:obs://cann-id2019/dataset/ +- 数据集获取百度网盘链接:https://pan.baidu.com/s/1up1HW6McgSor3JF0yqQZSA (提取码:2019) - 训练超参 @@ -250,4 +246,7 @@ checkpoint文件中最新模型的路径修改 | ------------------- | ------- | ------ | ------ | | global_step/sec| 无 | 87.64 | 116.77 | +## 离线推理 +参考 ACL_TensorFlow/contrib/cv/SVD-ID2019_for_ACL + diff --git a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/genStatistical.sh b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/genStatistical.sh index 990ca29a57588f3964dc67e662ad0bd451c6b615..f9a760ef8cc851ef42e7d8fa1aacd6588a1a36ee 100644 --- a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/genStatistical.sh +++ b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/genStatistical.sh @@ -123,13 +123,13 @@ echo "$testData_Path" if [ x"${modelarts_flag}" != x ]; then - python /home/ma-user/modelarts/user-job-dir/code/main_point_cloud.py \ + python /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_perf.py \ --method=svd \ --checkpoint_dir=${outputPath} \ --pt_cloud_test_files=${testData_Path} \ --predict_all_test=True else - python /home/ma-user/modelarts/user-job-dir/code/main_point_cloud.py \ + python /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_perf.py \ --method=svd \ --checkpoint_dir=${outputPath} \ --pt_cloud_test_files=${testData_Path} \ diff --git a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/main_point_cloud_boostPerf.py b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/main_point_cloud_perf.py similarity index 79% rename from TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/main_point_cloud_boostPerf.py rename to 
TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/main_point_cloud_perf.py index 54f2fffa0010bedd9e07db0b228778fc11d5aa3b..316d7da2065cbbbc3959ba40397e6342d0c218cc 100644 --- a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/main_point_cloud_boostPerf.py +++ b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/main_point_cloud_perf.py @@ -29,11 +29,6 @@ from npu_bridge.npu_init import * from npu_bridge.estimator.npu.npu_estimator import NPUEstimatorSpec from npu_bridge.estimator.npu.npu_estimator import NPUEstimator from npu_bridge.estimator.npu.npu_config import NPURunConfig -from npu_bridge.estimator.npu.npu_config import ProfilingConfig -#import precision_tool.tf_config as npu_tf_config - -#import precision_tool.config.config as CONFIG - import os FLAGS = tf.app.flags.FLAGS @@ -76,11 +71,6 @@ flags.DEFINE_boolean('random_rotation_axang', True, 'If true, samples random rotations using the method ' 'from the original benchmark code. Otherwise samples ' 'by Haar measure.') -flags.DEFINE_boolean('Profiling',True, - 'parse NPU operator performance') - -flags.DEFINE_boolean('Dump',False, - 'overflow test') def pt_features(batch_pts): @@ -148,13 +138,14 @@ def net_point_cloud(points1, points2, mode): def model_fn(features, labels, mode, params): """The model_fn used to construct the tf.Estimator.""" - del labels, params # Unused. + del params # Unused. if mode == tf.estimator.ModeKeys.TRAIN: # Training data has point cloud of size [1, N, 3] and random rotations # of size [1, FLAGS.num_train_augmentations, 3, 3] - rot = features['rot'][0] + rot = labels[0] + data = features num_rot = FLAGS.num_train_augmentations - batch_pts1 = tf.tile(features['data'], [num_rot, 1, 1]) + batch_pts1 = tf.tile(data, [num_rot, 1, 1]) # In this experiment it does not matter if we pre or post-multiply the # rotation as long as we are consistent between training and eval. batch_pts2 = tf.matmul(batch_pts1, rot) # post-multiplying! 
@@ -167,16 +158,22 @@ def model_fn(features, labels, mode, params): rot = tf.reshape(rot, (-1, 3, 3)) # Predict the rotation. + + r = net_point_cloud(batch_pts1, batch_pts2, mode) + unit_one = tf.constant(1.0,dtype=tf.float32) + + rotation_matrix = tf.multiply(r,unit_one,name='rotation_matrix') # Compute the loss. loss = tf.nn.l2_loss(rot - r) # Compute the relative angle in radians. - theta = utils.relative_angle(r, rot) + theta = utils.relative_angle(rot, r) # Mean angle error over the batch. mean_theta = tf.reduce_mean(theta) + mean_theta_deg = mean_theta * 180.0 / np.pi # Train, eval, or predict depending on mode. @@ -204,7 +201,8 @@ def model_fn(features, labels, mode, params): return NPUEstimatorSpec( mode=mode, loss=loss, - train_op=train_op) + train_op=train_op, + predictions=rotation_matrix) if mode == tf.estimator.ModeKeys.EVAL: if FLAGS.predict_all_test: @@ -268,13 +266,12 @@ def train_input_fn(): dataset = dataset.repeat() dataset = dataset.map(_random_rotation) dataset = dataset.batch(1) - iterator = tf.data.make_one_shot_iterator(dataset) batch_data, batch_rot = iterator.get_next() - features_dict = {'data': batch_data, 'rot': batch_rot} - batch_size = tf.shape(batch_data)[0] - batch_labels_dummy = tf.zeros(shape=(batch_size, 1)) - return (features_dict, batch_labels_dummy) + #features_dict = {'data': batch_data, 'rot': batch_rot} + #batch_size = tf.shape(batch_data)[0] + #batch_labels_dummy = tf.zeros(shape=(batch_size, 1)) + return batch_data,batch_rot def eval_input_fn(): @@ -321,11 +318,10 @@ def eval_input_fn(): dataset = dataset.batch(1) iterator = tf.data.make_one_shot_iterator(dataset) batch_data, batch_rot = iterator.get_next() - features_dict = {'data': batch_data, 'rot': batch_rot} - batch_size = tf.shape(batch_data)[0] - batch_labels_dummy = tf.zeros(shape=(batch_size, 1)) - return (features_dict, batch_labels_dummy) + #batch_size = tf.shape(batch_data)[0] + #batch_labels_dummy = tf.zeros(shape=(batch_size, 1)) + return features_dict def 
print_variable_names(): @@ -373,6 +369,16 @@ def predict_all_test(): index = np.int32(np.float32(n * perc) / 100.0) - 1 print('%3d%%: %f'%(perc, sorted_errors[index])) +def serving_input_fn(): + input_data = tf.placeholder(tf.float32, [None, None, 3], name='data') + input_rot = tf.placeholder(tf.float32, [None, None, 3], name='rot') + input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn({ + 'data': input_data, + 'rot': input_rot + })() + return input_fn + + def train_and_eval(): """Train and evaluate a model.""" @@ -380,69 +386,42 @@ def train_and_eval(): save_checkpoints_steps = FLAGS.save_checkpoints_steps log_step_count = FLAGS.log_step_count + config = NPURunConfig( + save_summary_steps=save_summary_steps, + save_checkpoints_steps=save_checkpoints_steps, + log_step_count_steps=log_step_count, + keep_checkpoint_max=None, + precision_mode = "allow_mix_precision") + - # dump_config = npu_tf_config.estimator_dump_config(action='overflow') - # - profilingPath=os.path.join(FLAGS.checkpoint_dir,'npu_profiling') - if not os.path.exists(profilingPath): - os.makedirs(profilingPath) - - profiling_options= '{"output":"%s",\ - "task_trace":"on",\ - "aicpu":"on"}'%(profilingPath) - - profiling_config=ProfilingConfig(enable_profiling=True,profiling_options=profiling_options) - session_config=tf.ConfigProto() - - if FLAGS.Profiling: - config = NPURunConfig( - save_summary_steps=save_summary_steps, - save_checkpoints_steps=save_checkpoints_steps, - log_step_count_steps=log_step_count, - keep_checkpoint_max=None, - precision_mode = "allow_mix_precision", - profiling_config=profiling_config, - session_config=session_config, - customize_dtypes="/home/ma-user/modelarts/user-job-dir/code/switch_config.txt") - # if FLAGS.Dump == True: - # config = NPURunConfig( - # save_summary_steps=save_summary_steps, - # save_checkpoints_steps=save_checkpoints_steps, - # log_step_count_steps=log_step_count, - # keep_checkpoint_max=None, - # precision_mode="allow_mix_precision", - # 
dump_config=dump_config) - else: - config = NPURunConfig( - save_summary_steps=save_summary_steps, - save_checkpoints_steps=save_checkpoints_steps, - log_step_count_steps=log_step_count, - keep_checkpoint_max=None, - precision_mode="allow_mix_precision", - customize_dtypes="/home/ma-user/modelarts/user-job-dir/code/switch_config.txt") params = {'dummy': 0} estimator = NPUEstimator( - model_fn=model_fn, - model_dir=FLAGS.checkpoint_dir, - config=config, - params=params) - + model_fn=model_fn, + model_dir=FLAGS.checkpoint_dir, + config=config, + params=params) - train_spec = tf.estimator.TrainSpec( - input_fn=train_input_fn, - max_steps=FLAGS.train_steps) + # train_spec = tf.estimator.TrainSpec( + # input_fn=train_input_fn, + # max_steps=FLAGS.train_steps) + # + # eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, + # start_delay_secs=60, + # steps=FLAGS.eval_examples, + # throttle_secs=60) - eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn, - start_delay_secs=60, - steps=FLAGS.eval_examples, - throttle_secs=60) + #tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + estimator.train(input_fn=train_input_fn, + max_steps=FLAGS.train_steps) - tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec) + estimator.export_savedmodel(FLAGS.checkpoint_dir, + serving_input_fn) + estimator.evaluate(input_fn=eval_input_fn, steps=FLAGS.eval_examples) def main(argv=None): # pylint: disable=unused-argument diff --git a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/train_full_1p.sh b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/train_full_1p.sh index edd5eed9ab52c3f94a377aea26e45f76599f3154..13753883246ddb6859ebaa098e03f88e528f1ea6 100644 --- a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/train_full_1p.sh +++ b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/train_full_1p.sh @@ -119,7 +119,7 @@ testData_Path="$data_path$test_data" if [ x"${modelarts_flag}" != x ]; then - python3.7 
/home/ma-user/modelarts/user-job-dir/code/main_point_cloud_boostPerf.py \ + python3.7 /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_perf.py \ --method=svd \ --checkpoint_dir=${output_path} \ --log_step_count=200 \ @@ -130,8 +130,8 @@ then --save_checkpoints_steps=100000 \ --eval_examples=39900 else - python3.7 /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_boostPerf.py \ - --method=svd \ + python3.7 /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_perf.py \ + --method=svd \ --checkpoint_dir=${output_path} \ --log_step_count=200 \ --save_summaries_steps=25000 \ diff --git a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/train_performance_1p.sh b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/train_performance_1p.sh index ab8be4ed0d7421d610930c67efd8fde2c2d520b5..330245b56975c30b60663151a083f713b336b6a4 100644 --- a/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/train_performance_1p.sh +++ b/TensorFlow/contrib/cv/SVD_ID2019_for_Tensorflow/train_performance_1p.sh @@ -119,7 +119,7 @@ testData_Path="$data_path$test_data" if [ x"${modelarts_flag}" != x ]; then - python3.7 /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_boostPerf.py \ + python3.7 /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_perf.py \ --method=svd \ --checkpoint_dir=${output_path} \ --log_step_count=200 \ @@ -130,7 +130,7 @@ then --save_checkpoints_steps=10000 \ --eval_examples=399 else - python3.7 /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_boostPerf.py \ + python3.7 /home/ma-user/modelarts/user-job-dir/code/main_point_cloud_perf.py \ --method=svd \ --checkpoint_dir=${output_path} \ --log_step_count=200 \