From a920c3e302922b3e32e4da63ed8d72deed1cd142 Mon Sep 17 00:00:00 2001 From: hqy2022 <2609131663@qq.com> Date: Mon, 13 Jun 2022 03:04:33 +0000 Subject: [PATCH] =?UTF-8?q?=E4=B8=BB=E8=A6=81=E5=81=9A=E4=BA=86=E9=80=82?= =?UTF-8?q?=E5=BA=94npu=E7=9A=84=E8=B0=83=E6=95=B4=201.=20pc=5Fdistance?= =?UTF-8?q?=E4=B8=AD=E7=94=A8=E5=88=B0=E4=BA=86=E4=B8=80=E4=B8=AAcuda?= =?UTF-8?q?=E8=87=AA=E5=AE=9A=E4=B9=89=E7=AE=97=E5=AD=90=E5=BA=93=EF=BC=8C?= =?UTF-8?q?=E5=9C=A8npu=E4=B8=8D=E5=85=BC=E5=AE=B9=EF=BC=8C=E6=89=80?= =?UTF-8?q?=E4=BB=A5=E5=90=AF=E7=94=A8=E4=BA=86cpu+npu=E6=B7=B7=E5=90=88?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=EF=BC=9B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 2. npu不支持动态shape,所以原代码里所有动态shape都去掉了,或者替换成了不涉及动态shape的实现; --- data_util.py | 119 ++++++++++++++++++++++++++ io_util.py | 39 +++++++++ kitti_registration.py | 164 ++++++++++++++++++++++++++++++++++++ lmdb_writer.py | 65 +++++++++++++++ test_kitti.py | 109 ++++++++++++++++++++++++ test_shapenet.py | 126 ++++++++++++++++++++++++++++ tf_nndistance.py | 105 +++++++++++++++++++++++ tf_util.py | 164 ++++++++++++++++++++++++++++++++++++ train.py | 188 ++++++++++++++++++++++++++++++++++++++++++ visu_util.py | 60 ++++++++++++++ 10 files changed, 1139 insertions(+) create mode 100644 data_util.py create mode 100644 io_util.py create mode 100644 kitti_registration.py create mode 100644 lmdb_writer.py create mode 100644 test_kitti.py create mode 100644 test_shapenet.py create mode 100644 tf_nndistance.py create mode 100644 tf_util.py create mode 100644 train.py create mode 100644 visu_util.py diff --git a/data_util.py b/data_util.py new file mode 100644 index 000000000..4f36811e5 --- /dev/null +++ b/data_util.py @@ -0,0 +1,119 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from npu_bridge.npu_init import * + +import numpy as np +import tensorflow as tf +import tensorflow.compat.v1 as tf +from tensorpack import dataflow + + +def resample_pcd(pcd, n): + """Drop or duplicate points so that pcd has exactly n points""" + idx = np.random.permutation(pcd.shape[0]) + if idx.shape[0] < n: + idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])]) + return pcd[idx[:n]] + + +class PreprocessData(dataflow.ProxyDataFlow): + def __init__(self, ds, input_size, output_size): + super(PreprocessData, self).__init__(ds) + self.input_size = input_size + self.output_size = output_size + + def get_data(self): + for id, input, gt in self.ds.get_data(): + input = resample_pcd(input, self.input_size) + gt = resample_pcd(gt, self.output_size) + print(gt.shape) + yield id, input, gt + + +class BatchData(dataflow.ProxyDataFlow): + def __init__(self, ds, batch_size, input_size, gt_size, remainder=True, use_list=False): + super(BatchData, self).__init__(ds) + self.batch_size = batch_size + self.input_size = input_size + self.gt_size = gt_size + self.remainder = remainder + self.use_list = use_list + + def __len__(self): + ds_size = len(self.ds) + div = ds_size // self.batch_size + rem = ds_size % self.batch_size + if rem == 0: + return div + return div + int(self.remainder) + + def __iter__(self): + holder = [] + for data in self.ds: + holder.append(data) + if len(holder) == self.batch_size: + yield self._aggregate_batch(holder, self.use_list) + del holder[:] + if self.remainder and len(holder) > 0: + yield self._aggregate_batch(holder, self.use_list) + + def _aggregate_batch(self, data_holder, use_list=False): + ''' Concatenate input points along the 0-th dimension + Stack all other data along the 0-th dimension + ''' + ids = np.stack([x[0] for x in data_holder]) + inputs = [resample_pcd(x[1], self.input_size) for x in data_holder] + inputs = np.expand_dims(np.concatenate([x for x in inputs]), 0).astype(np.float32) + npts = np.stack([self.input_size for x in data_holder]).astype(np.int32) + gts = np.stack([resample_pcd(x[2], self.gt_size) for x in data_holder]).astype(np.float32) + return ids, inputs, npts, gts + + +def lmdb_dataflow(lmdb_path, batch_size, input_size, output_size, is_training, test_speed=False): + df = dataflow.LMDBSerializer.load(lmdb_path, shuffle=False) + size = df.size() + if is_training: + df = dataflow.LocallyShuffleData(df, buffer_size=2000) + df = dataflow.PrefetchData(df, nr_prefetch=500, nr_proc=1) + df = BatchData(df, batch_size, input_size, output_size) + if is_training: + df = dataflow.PrefetchDataZMQ(df, nr_proc=8) + df = dataflow.RepeatedData(df, -1) + if test_speed: + dataflow.TestDataSpeed(df, size=3000).start() + df.reset_state() + return df, size + + +def get_queued_data(generator, dtypes, shapes, queue_capacity=10): + assert len(dtypes) == len(shapes), 'dtypes and shapes must have the same length' + queue = tf.FIFOQueue(queue_capacity, dtypes, shapes) + placeholders = [tf.placeholder(dtype, shape) for dtype, shape in zip(dtypes, shapes)] + enqueue_op = queue.enqueue(placeholders) + close_op = queue.close(cancel_pending_enqueues=True) + feed_fn = lambda: {placeholder: value for placeholder, value in zip(placeholders, next(generator))} + queue_runner = tf.contrib.training.FeedingQueueRunner(queue, [enqueue_op], close_op, feed_fns=[feed_fn]) + tf.train.add_queue_runner(queue_runner) + return queue.dequeue() + diff --git a/io_util.py b/io_util.py new file mode 100644 index 000000000..d8082aab0 --- /dev/null +++ b/io_util.py @@ -0,0 +1,39 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from npu_bridge.npu_init import * + +import numpy as np +import open3d as o3d + + +def read_pcd(filename): + pcd = o3d.io.read_point_cloud(filename) + return np.array(pcd.points) + + +def save_pcd(filename, points): + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + o3d.io.write_point_cloud(filename, pcd) + diff --git a/kitti_registration.py b/kitti_registration.py new file mode 100644 index 000000000..3e9cbb269 --- /dev/null +++ b/kitti_registration.py @@ -0,0 +1,164 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from npu_bridge.npu_init import * + +import argparse +import copy +import csv +import matplotlib.pyplot as plt +import numpy as np +import os +from mpl_toolkits.mplot3d import Axes3D +from open3d import * + + +def bbox2rt(bbox): + center = (bbox.min(0) + bbox.max(0)) / 2 + bbox -= center + yaw = np.arctan2(bbox[3, 1] - bbox[0, 1], bbox[3, 0] - bbox[0, 0]) + rotation = np.array([[np.cos(yaw), -np.sin(yaw), 0], + [np.sin(yaw), np.cos(yaw), 0], + [0, 0, 1]]) + return rotation, center + + +def register(source, target, args): + residual = TransformationEstimationPointToPoint() + criteria = ICPConvergenceCriteria(max_iteration=args.max_iter) + # Align the centroids of the point clouds + source_points = np.array(source.points) + target_points = np.array(target.points) + source_center = np.mean(source_points, axis=0) + target_center = np.mean(target_points, axis=0) + source = PointCloud() + source.points = Vector3dVector(source_points - source_center) + target = PointCloud() + target.points = Vector3dVector(target_points - target_center) + result = registration_icp(source, target, args.max_dist, np.eye(4), residual, criteria) + source_trans = copy.deepcopy(source) + source_trans.transform(result.transformation) + R = result.transformation[:3, :3] + t = result.transformation[:3, 3] + target_center - np.dot(source_center, R.T) + return R, t, np.array(source_trans.points), np.array(target.points) + + +def rotation_error(R1, R2): + cos = (np.trace(np.dot(R1, R2.T)) - 1) / 2 + cos = np.maximum(np.minimum(cos, 1), -1) + return 180 * np.arccos(cos) / np.pi + + +def translation_error(t1, t2): + return np.sqrt(np.sum((t1 - t2) ** 2)) + + +def plot_pcd_pair(ax, pcd1, pcd2, title, cmaps, size, xlim=(-1.5, 1.5), ylim=(-1.5, 1.5), zlim=(-1, 2)): + ax.scatter(pcd1[:, 0], pcd1[:, 1], pcd1[:, 2], c=pcd1[:, 0], s=size, cmap=cmaps[0], vmin=-5, vmax=1.5) + ax.scatter(pcd2[:, 0], pcd2[:, 1], pcd2[:, 2], c=pcd2[:, 0], s=size, cmap=cmaps[1], vmin=-5, vmax=1.5) + ax.set_title(title) + ax.set_axis_off() + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.set_zlim(zlim) + + +def track(args): + os.makedirs(os.path.join(args.results_dir, 'plots'), exist_ok=True) + csv_file = open(os.path.join(args.results_dir, 'error.csv'), 'w') + writer = csv.writer(csv_file) + writer.writerow(['id', 'r_err_part', 't_err_part', 'r_err_comp', 't_err_comp']) + + n = 0 + total_r_err_part = 0 + total_t_err_part = 0 + total_r_err_comp = 0 + total_t_err_comp = 0 + for filename in os.listdir(args.tracklet_dir): + tracklet_id = filename.split('.')[0] + with open(os.path.join(args.tracklet_dir, filename)) as file: + car_ids = file.read().splitlines() + + prev_frame = int(car_ids[0].split('_')[1]) + prev_R, prev_t = bbox2rt(np.loadtxt(os.path.join(args.bbox_dir, '%s.txt' % car_ids[0]))) + prev_partial = read_point_cloud(os.path.join(args.partial_dir, '%s.pcd' % car_ids[0])) + prev_complete = read_point_cloud(os.path.join(args.complete_dir, '%s.pcd' % car_ids[0])) + for i in range(args.interval, len(car_ids), args.interval): + n += 1 + frame = int(car_ids[i].split('_')[1]) + instance_id = '%s_frame_%d_to_%d' % (tracklet_id, prev_frame, frame) + + R, t = bbox2rt(np.loadtxt(os.path.join(args.bbox_dir, '%s.txt' % car_ids[i]))) + R_gt = np.dot(R, prev_R.T) + t_gt = t - np.dot(prev_t, R_gt.T) + + partial = read_point_cloud(os.path.join(args.partial_dir, '%s.pcd' % car_ids[i])) + R_part, t_part, partial_trans, partial_target = register(prev_partial, partial, args) + r_err_part = rotation_error(R_part, R_gt) + t_err_part = translation_error(t_part, t_gt) + total_r_err_part += r_err_part + total_t_err_part += t_err_part + + complete = read_point_cloud(os.path.join(args.complete_dir, '%s.pcd' % car_ids[i])) + R_comp, t_comp, complete_trans, complete_target = register(prev_complete, complete, args) + r_err_comp = rotation_error(R_comp, R_gt) + t_err_comp = translation_error(t_comp, t_gt) + total_r_err_comp += r_err_comp + total_t_err_comp += t_err_comp + + writer.writerow([instance_id, r_err_part, t_err_part, r_err_comp, t_err_comp]) + + if n % args.plot_freq == 0: + fig = plt.figure(figsize=(8, 4)) + ax = fig.add_subplot(121, projection='3d') + plot_pcd_pair(ax, partial_trans, partial_target, + 'Rotation error %.4f\nTranslation error %.4f' % (r_err_part, t_err_part), + ['Reds', 'Blues'], size=5) + ax = fig.add_subplot(122, projection='3d') + plot_pcd_pair(ax, complete_trans, complete_target, + 'Rotation error %.4f\nTranslation error %.4f' % (r_err_comp, t_err_comp), + ['Reds', 'Blues'], size=0.5) + plt.subplots_adjust(left=0, right=1, bottom=0, top=0.95, wspace=0) + fig.savefig(os.path.join(args.results_dir, 'plots', '%s.png' % instance_id)) + plt.close(fig) + print('Using original pcd: average roration error %.4f average translation error %.4f' % + (total_r_err_part / n, total_t_err_part / n)) + print('Using completed pcd: average roration error %.4f average translation error %.4f' % + (total_r_err_comp / n, total_t_err_comp / n)) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--partial_dir', default='data/kitti/cars') + parser.add_argument('--complete_dir', default='data/results/kitti/pcn_emd/completions') + parser.add_argument('--bbox_dir', default='data/kitti/bboxes') + parser.add_argument('--tracklet_dir', default='data/kitti/tracklets') + parser.add_argument('--results_dir', default='data/results/kitti_registration') + parser.add_argument('--interval', type=int, default=1, help='number of frames to skip') + parser.add_argument('--max_iter', type=int, default=100, help='max iteration for ICP') + parser.add_argument('--max_dist', type=float, default=0.05, help='matching threshold for ICP') + parser.add_argument('--plot_freq', type=int, default=100) + args = parser.parse_args() + + track(args) + diff --git a/lmdb_writer.py b/lmdb_writer.py new file mode 100644 index 000000000..7eaebc0c2 --- /dev/null +++ b/lmdb_writer.py @@ -0,0 +1,65 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from npu_bridge.npu_init import * + +import argparse +import os +from io_util import read_pcd +from tensorpack import DataFlow, dataflow + + +class pcd_df(DataFlow): + def __init__(self, model_list, num_scans, partial_dir, complete_dir): + self.model_list = model_list + self.num_scans = num_scans + self.partial_dir = partial_dir + self.complete_dir = complete_dir + + def size(self): + return len(self.model_list) * self.num_scans + + def get_data(self): + for model_id in model_list: + complete = read_pcd(os.path.join(self.complete_dir, '%s.pcd' % model_id)) + for i in range(self.num_scans): + partial = read_pcd(os.path.join(self.partial_dir, model_id, '%d.pcd' % i)) + yield model_id.replace('/', '_'), partial, complete + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--list_path') + parser.add_argument('--num_scans', type=int) + parser.add_argument('--partial_dir') + parser.add_argument('--complete_dir') + parser.add_argument('--output_path') + args = parser.parse_args() + + with open(args.list_path) as file: + model_list = file.read().splitlines() + df = pcd_df(model_list, args.num_scans, args.partial_dir, args.complete_dir) + if os.path.exists(args.output_path): + os.system('rm %s' % args.output_path) + dataflow.LMDBSerializer.save(df, args.output_path) + diff --git a/test_kitti.py b/test_kitti.py new file mode 100644 index 000000000..5bdfc4b1c --- /dev/null +++ b/test_kitti.py @@ -0,0 +1,109 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from npu_bridge.npu_init import * + +import argparse +import importlib +import models +import numpy as np +import os +import tensorflow as tf +import time +from io_util import read_pcd, save_pcd +from visu_util import plot_pcd_three_views + + +def test(args): + inputs = tf.placeholder(tf.float32, (1, None, 3)) + npts = tf.placeholder(tf.int32, (1,)) + gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 3)) + model_module = importlib.import_module('.%s' % args.model_type, 'models') + model = model_module.Model(inputs, npts, gt, tf.constant(1.0)) + + os.makedirs(os.path.join(args.results_dir, 'plots'), exist_ok=True) + os.makedirs(os.path.join(args.results_dir, 'completions'), exist_ok=True) + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.allow_soft_placement = True + sess = tf.Session(config=npu_config_proto(config_proto=config)) + + saver = tf.train.Saver() + saver.restore(sess, args.checkpoint) + + car_ids = [filename.split('.')[0] for filename in os.listdir(args.pcd_dir)] + total_time = 0 + total_points = 0 + for i, car_id in enumerate(car_ids): + partial = read_pcd(os.path.join(args.pcd_dir, '%s.pcd' % car_id)) + bbox = np.loadtxt(os.path.join(args.bbox_dir, '%s.txt' % car_id)) + total_points += partial.shape[0] + + # Calculate center, rotation and scale + center = (bbox.min(0) + bbox.max(0)) / 2 + bbox -= center + yaw = np.arctan2(bbox[3, 1] - bbox[0, 1], bbox[3, 0] - bbox[0, 0]) + rotation = np.array([[np.cos(yaw), -np.sin(yaw), 0], + [np.sin(yaw), np.cos(yaw), 0], + [0, 0, 1]]) + bbox = np.dot(bbox, rotation) + scale = bbox[3, 0] - bbox[0, 0] + bbox /= scale + + partial = np.dot(partial - center, rotation) / scale + partial = np.dot(partial, [[1, 0, 0], [0, 0, 1], [0, 1, 0]]) + + start = time.time() + completion = sess.run(model.outputs, feed_dict={inputs: [partial], npts: [partial.shape[0]]}) + total_time += time.time() - start + completion = completion[0] + + completion_w = np.dot(completion, [[1, 0, 0], [0, 0, 1], [0, 1, 0]]) + completion_w = np.dot(completion_w * scale, rotation.T) + center + pcd_path = os.path.join(args.results_dir, 'completions', '%s.pcd' % car_id) + save_pcd(pcd_path, completion_w) + + if i % args.plot_freq == 0: + plot_path = os.path.join(args.results_dir, 'plots', '%s.png' % car_id) + plot_pcd_three_views(plot_path, [partial, completion], ['input', 'output'], + '%d input points' % partial.shape[0], [5, 0.5]) + print('Average # input points:', total_points / len(car_ids)) + print('Average time:', total_time / len(car_ids)) + sess.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--model_type', default='pcn_emd') + parser.add_argument('--checkpoint', default='D://shapenet/trained_models/pcn_emd_car') + parser.add_argument('--pcd_dir', default='D://shapenet/kitti/cars') + parser.add_argument('--bbox_dir', default='D://shapenet/kitti/bboxes') + parser.add_argument('--results_dir', default='D://pcn/results/kitti_pcn_emd') + parser.add_argument('--num_gt_points', type=int, default=16384) + parser.add_argument('--plot_freq', type=int, default=100) + parser.add_argument('--save_pcd', action='store_false') + args = parser.parse_args() + print('deded') + test(args) + diff --git a/test_shapenet.py b/test_shapenet.py new file mode 100644 index 000000000..cc4d8ed16 --- /dev/null +++ b/test_shapenet.py @@ -0,0 +1,126 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from npu_bridge.npu_init import * + +import argparse +import csv +import importlib +import models +import numpy as np +import os +import tensorflow.compat.v1 as tf +import time +from io_util import read_pcd, save_pcd +from tf_util import chamfer, earth_mover +from visu_util import plot_pcd_three_views + + +def test(args): + inputs = tf.placeholder(tf.float32, (1, None, 3)) + npts = tf.placeholder(tf.int32, (1,)) + gt = tf.placeholder(tf.float32, (1, args.num_gt_points, 3)) + model_module = importlib.import_module('.%s' % args.model_type, 'models') + model = model_module.Model(inputs, npts, gt, tf.constant(1.0)) + + output = tf.placeholder(tf.float32, (1, args.num_gt_points, 3)) + cd_op = chamfer(output, gt) + emd_op = earth_mover(output, gt) + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.allow_soft_placement = True + sess = tf.Session(config=npu_config_proto(config_proto=config)) + + saver = tf.train.Saver() + saver.restore(sess, args.checkpoint) + + os.makedirs(args.results_dir, exist_ok=True) + csv_file = open(os.path.join(args.results_dir, 'results.csv'), 'w') + writer = csv.writer(csv_file) + writer.writerow(['id', 'cd', 'emd']) + + with open(args.list_path) as file: + model_list = file.read().splitlines() + total_time = 0 + total_cd = 0 + total_emd = 0 + cd_per_cat = {} + emd_per_cat = {} + for i, model_id in enumerate(model_list): + partial = read_pcd(os.path.join(args.data_dir, 'partial', '%s.pcd' % model_id)) + complete = read_pcd(os.path.join(args.data_dir, 'complete', '%s.pcd' % model_id)) + start = time.time() + completion = sess.run(model.outputs, feed_dict={inputs: [partial], npts: [partial.shape[0]]}) + total_time += time.time() - start + cd, emd = sess.run([cd_op, emd_op], feed_dict={output: completion, gt: [complete]}) + total_cd += cd + total_emd += emd + writer.writerow([model_id, cd, emd]) + + synset_id, model_id = model_id.split('/') + if not cd_per_cat.get(synset_id): + cd_per_cat[synset_id] = [] + if not emd_per_cat.get(synset_id): + emd_per_cat[synset_id] = [] + cd_per_cat[synset_id].append(cd) + emd_per_cat[synset_id].append(emd) + + if i % args.plot_freq == 0: + os.makedirs(os.path.join(args.results_dir, 'plots', synset_id), exist_ok=True) + plot_path = os.path.join(args.results_dir, 'plots', synset_id, '%s.png' % model_id) + plot_pcd_three_views(plot_path, [partial, completion[0], complete], + ['input', 'output', 'ground truth'], + 'CD %.4f EMD %.4f' % (cd, emd), + [5, 0.5, 0.5]) + if args.save_pcd: + os.makedirs(os.path.join(args.results_dir, 'pcds', synset_id), exist_ok=True) + save_pcd(os.path.join(args.results_dir, 'pcds', '%s.pcd' % model_id), completion[0]) + csv_file.close() + sess.close() + + print('Average time: %f' % (total_time / len(model_list))) + print('Average Chamfer distance: %f' % (total_cd / len(model_list))) + print('Average Earth mover distance: %f' % (total_emd / len(model_list))) + print('Chamfer distance per category') + for synset_id in cd_per_cat.keys(): + print(synset_id, '%f' % np.mean(cd_per_cat[synset_id])) + print('Earth mover distance per category') + for synset_id in emd_per_cat.keys(): + print(synset_id, '%f' % np.mean(emd_per_cat[synset_id])) + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--list_path', default='/data/pcn/test.list') + parser.add_argument('--data_dir', default='/data/pcn/test') + parser.add_argument('--model_type', default='pcn_cd') + parser.add_argument('--checkpoint', default='log/pcn_cd/model-400000') + parser.add_argument('--results_dir', default='results/shapenet_pcn_cd') + parser.add_argument('--num_gt_points', type=int, default=16384) + parser.add_argument('--plot_freq', type=int, default=100) + parser.add_argument('--save_pcd', action='store_true') + args = parser.parse_args() + + test(args) + diff --git a/tf_nndistance.py b/tf_nndistance.py new file mode 100644 index 000000000..704f38a47 --- /dev/null +++ b/tf_nndistance.py @@ -0,0 +1,105 @@ +import tensorflow as tf +from tensorflow.python.framework import ops +import os.path as osp + +from npu_bridge.npu_init import * +import os.path as osp +import os +import sys + +base_dir = osp.dirname(osp.abspath(__file__)) + +nn_distance_module = tf.load_op_library(osp.join(base_dir, 'tf_nndistance_so.so')) + + +def nn_distance(xyz1, xyz2): + ''' + Computes the distance of nearest neighbors for a pair of point clouds + input: xyz1: (batch_size,#points_1,3) the first point cloud + input: xyz2: (batch_size,#points_2,3) the second point cloud + output: dist1: (batch_size,#point_1) distance from first to second + output: idx1: (batch_size,#point_1) nearest neighbor from first to second + output: dist2: (batch_size,#point_2) distance from second to first + output: idx2: (batch_size,#point_2) nearest neighbor from second to first + ''' + print('xyz1',xyz1,'xyz2',xyz2) + #xyz1 = tf.expand_dims(xyz1, 0) + #xyz2 = tf.expand_dims(xyz2, 0) + return nn_distance_module.nn_distance(xyz1,xyz2) + +#@tf.RegisterShape('NnDistance') +@ops.RegisterShape('NnDistance') +def _nn_distance_shape(op): + shape1=op.inputs[0].get_shape().with_rank(3) + shape2=op.inputs[1].get_shape().with_rank(3) + return [tf.TensorShape([shape1.dims[0],shape1.dims[1]]),tf.TensorShape([shape1.dims[0],shape1.dims[1]]), + tf.TensorShape([shape2.dims[0],shape2.dims[1]]),tf.TensorShape([shape2.dims[0],shape2.dims[1]])] +@ops.RegisterGradient('NnDistance') +def _nn_distance_grad(op,grad_dist1,grad_idx1,grad_dist2,grad_idx2): + xyz1=op.inputs[0] + xyz2=op.inputs[1] + idx1=op.outputs[1] + idx2=op.outputs[3] + return nn_distance_module.nn_distance_grad(xyz1,xyz2,grad_dist1,idx1,grad_dist2,idx2) + + +if __name__=='__main__': + import numpy as np + import random + import time + #from tensorflow.python.kernel_tests.gradient_checker import compute_gradient + from tensorflow.python.ops.gradient_checker import compute_gradient + random.seed(100) + np.random.seed(100) + # Create a session + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.allow_soft_placement = True + config.log_device_placement = False + + # 增加混合计算开关 start + custom_op = config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + custom_op.parameter_map["mix_compile_mode"].b = True + # 增加混合计算开关 end + config.graph_options.rewrite_options.remapping = RewriterConfig.OFF # 必须显式关闭 + #config.graph_options.rewrite_options.memory_optimization = RewriterConfig.OFF # 必须显式关闭 + with tf.Session(config=npu_config_proto(config_proto=config)) as sess: + xyz1=np.random.randn(32,16384,3).astype('float32') + xyz2=np.random.randn(32,1024,3).astype('float32') + with tf.device('/gpu:0'): + inp1=tf.Variable(xyz1) + inp2=tf.constant(xyz2) + reta,retb,retc,retd=nn_distance(inp1,inp2) + loss=tf.reduce_sum(reta)+tf.reduce_sum(retc) + train=tf.train.GradientDescentOptimizer(learning_rate=0.05).minimize(loss) + sess.run(tf.initialize_all_variables()) + t0=time.time() + t1=t0 + best=1e100 + for i in range(100): + trainloss,_=sess.run([loss,train]) + newt=time.time() + best=min(best,newt-t1) + print(i,trainloss,(newt-t0)/(i+1),best) + t1=newt + #print sess.run([inp1,retb,inp2,retd]) + #grads=compute_gradient([inp1,inp2],[(16,32,3),(16,32,3)],loss,(1,),[xyz1,xyz2]) + #for i,j in grads: + #print i.shape,j.shape,np.mean(np.abs(i-j)),np.mean(np.abs(i)),np.mean(np.abs(j)) + #for i in xrange(10): + #t0=time.time() + #a,b,c,d=sess.run([reta,retb,retc,retd],feed_dict={inp1:xyz1,inp2:xyz2}) + #print 'time',time.time()-t0 + #print a.shape,b.shape,c.shape,d.shape + #print a.dtype,b.dtype,c.dtype,d.dtype + #samples=np.array(random.sample(range(xyz2.shape[1]),100),dtype='int32') + #dist1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).min(axis=-1) + #idx1=((xyz1[:,samples,None,:]-xyz2[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) + #print np.abs(dist1-a[:,samples]).max() + #print np.abs(idx1-b[:,samples]).max() + #dist2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).min(axis=-1) + #idx2=((xyz2[:,samples,None,:]-xyz1[:,None,:,:])**2).sum(axis=-1).argmin(axis=-1) + #print np.abs(dist2-c[:,samples]).max() + #print np.abs(idx2-d[:,samples]).max() + diff --git a/tf_util.py b/tf_util.py new file mode 100644 index 000000000..44bab555f --- /dev/null +++ b/tf_util.py @@ -0,0 +1,164 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from npu_bridge.npu_init import * + +import tensorflow as tf +from pc_distance import tf_nndistance, tf_approxmatch + + +def mlp(features, layer_dims, bn=None, bn_params=None): + for i, num_outputs in enumerate(layer_dims[:-1]): + features = tf.contrib.layers.fully_connected( + features, num_outputs, + normalizer_fn=bn, + normalizer_params=bn_params, + scope='fc_%d' % i) + outputs = tf.contrib.layers.fully_connected( + features, layer_dims[-1], + activation_fn=None, + scope='fc_%d' % (len(layer_dims) - 1)) + return outputs + + +def mlp_conv(inputs, layer_dims, bn=None, bn_params=None): + for i, num_out_channel in enumerate(layer_dims[:-1]): + inputs = tf.contrib.layers.conv1d( + inputs, num_out_channel, + kernel_size=1, + normalizer_fn=bn, + normalizer_params=bn_params, + scope='conv_%d' % i) + outputs = tf.contrib.layers.conv1d( + inputs, layer_dims[-1], + kernel_size=1, + activation_fn=None, + scope='conv_%d' % (len(layer_dims) - 1)) + return outputs + + +def point_maxpool(inputs, npts, keepdims=False): + inputs = [inputs[:,3000*i:3000*(i+1),:] for i in range(npts.shape[0])] + #print(tf.split(inputs, npts, axis=1)) + outputs = [tf.reduce_max(f, axis=1, keepdims=keepdims) + for f in inputs]#tf.split(inputs, npts, axis=1)] + return tf.concat(outputs, axis=0) + + +def point_unpool(inputs, npts): + inputs = [tf.expand_dims(inputs[i,...],0) for i in range(inputs.shape[0])] + #inputs2 = tf.split(inputs, inputs.shape[0], axis=0) + #print(inputs2) + #print(inputs) + outputs = [tf.tile(f, [1, 3000, 1]) for i,f in enumerate(inputs)] + return tf.concat(outputs, axis=1) + +def chamfer(pcd1, pcd2): + print('chamfer:',pcd1,pcd2) + dist1, _, dist2, _ = tf_nndistance.nn_distance(pcd1, pcd2) + dist1 = tf.reduce_mean(tf.sqrt(dist1)) + dist2 = tf.reduce_mean(tf.sqrt(dist2)) + + """ + this is an official implementation by tensorflow-graphics/google + project url:https://github.com/tensorflow/graphics/blob/master/tensorflow_graphics/nn/loss/chamfer_distance.py + Computes the Chamfer distance for the given two point sets. + Note: + This is a symmetric version of the Chamfer distance, calculated as the sum + of the average minimum distance from point_set_a to point_set_b and vice + versa. + The average minimum distance from one point set to another is calculated as + the average of the distances between the points in the first set and their + closest point in the second set, and is thus not symmetrical. + Note: + This function returns the exact Chamfer distance and not an approximation. + Note: + In the following, A1 to An are optional batch dimensions, which must be + broadcast compatible. + Args: + point_set_a: A tensor of shape `[A1, ..., An, N, D]`, where the last axis + represents points in a D dimensional space. + point_set_b: A tensor of shape `[A1, ..., An, M, D]`, where the last axis + represents points in a D dimensional space. + name: A name for this op. Defaults to "chamfer_distance_evaluate". + Returns: + A tensor of shape `[A1, ..., An]` storing the chamfer distance between the + two point sets. + Raises: + ValueError: if the shape of `point_set_a`, `point_set_b` is not supported. + point_set_a = tf.convert_to_tensor(value=pcd1) + point_set_b = tf.convert_to_tensor(value=pcd2) + + shape.compare_batch_dimensions( + tensors=(point_set_a, point_set_b), + tensor_names=("point_set_a", "point_set_b"), + last_axes=-3, + broadcast_compatible=True) + # Verify that the last axis of the tensors has the same dimension. + dimension = point_set_a.shape.as_list()[-1] + shape.check_static( + tensor=point_set_b, + tensor_name="point_set_b", + has_dim_equals=(-1, dimension)) + + # Create N x M matrix where the entry i,j corresponds to ai - bj (vector of + # dimension D). + difference = ( + tf.expand_dims(point_set_a, axis=-2) - + tf.expand_dims(point_set_b, axis=-3)) + # Calculate the square distances between each two points: |ai - bj|^2. + square_distances = tf.einsum("...i,...i->...", difference, difference) + + minimum_square_distance_a_to_b = tf.reduce_min( + input_tensor=square_distances, axis=-1) + minimum_square_distance_b_to_a = tf.reduce_min( + input_tensor=square_distances, axis=-2) + + return (tf.reduce_mean(input_tensor=minimum_square_distance_a_to_b, axis=-1) + + tf.reduce_mean(input_tensor=minimum_square_distance_b_to_a, axis=-1)) + """ + return (dist1 + dist2) / 2 + +#def earth_mover(pcd1,pcd2): +# +# return + +def earth_mover(pcd1, pcd2): + assert pcd1.shape[1] == pcd2.shape[1] + num_points = tf.cast(pcd1.shape[1], tf.float32) + match = tf_approxmatch.approx_match(pcd1, pcd2) + print(match) + cost = tf_approxmatch.match_cost(pcd1, pcd2, match) + print('costshape:',cost) + return tf.reduce_mean(cost / num_points) + + +def add_train_summary(name, value): + tf.compat.v1.summary.scalar(name, value, collections=['train_summary']) + + +def add_valid_summary(name, value): + avg, update = tf.compat.v1.metrics.mean(value) + tf.compat.v1.summary.scalar(name, avg, collections=['valid_summary']) + return update + diff --git a/train.py b/train.py new file mode 100644 index 000000000..c27e0659e --- /dev/null +++ b/train.py @@ -0,0 +1,188 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' + +import argparse +import datetime +import importlib +import os +import tensorflow as tf +import tensorflow.compat.v1 as tf +from tensorflow.core.protobuf.rewriter_config_pb2 import RewriterConfig +import time +from data_util import lmdb_dataflow +from termcolor import colored +from tf_util import add_train_summary +#from visu_util import plot_pcd_three_views +from npu_bridge.npu_init import * +from npu_bridge.estimator import npu_ops +#import moxing as mox +import models + +def train(args): + is_training_pl = tf.placeholder(tf.bool, shape=(), name='is_training') + global_step = tf.Variable(0, trainable=False, name='global_step') + alpha = tf.train.piecewise_constant(global_step, [10000, 20000, 50000], + [0.01, 0.1, 0.5, 1.0], 'alpha_op') + inputs_pl = tf.placeholder(tf.float32, (1, args.batch_size*args.num_input_points, 3), 'inputs') + npts_pl = tf.placeholder(tf.int32, (args.batch_size,), 'num_points') + gt_pl = tf.placeholder(tf.float32, (args.batch_size, args.num_gt_points, 3), 'ground_truths') + + model_module = importlib.import_module('.%s' % args.model_type, 'models') + model = model_module.Model(inputs_pl, npts_pl, gt_pl, alpha, args.batch_size) + add_train_summary('alpha', alpha) + + if args.lr_decay: + learning_rate = tf.train.exponential_decay(args.base_lr, global_step, + args.lr_decay_steps, args.lr_decay_rate, + staircase=True, name='lr') + learning_rate = tf.maximum(learning_rate, args.lr_clip) + add_train_summary('learning_rate', learning_rate) + else: + learning_rate = tf.constant(args.base_lr, name='lr') + train_summary = tf.summary.merge_all('train_summary') + valid_summary = tf.summary.merge_all('valid_summary') + + trainer = tf.train.AdamOptimizer(learning_rate) + train_op = trainer.minimize(model.loss, global_step) + #mox.file.copy_parallel(src_url=os.path.join(args.data_url,args.lmdb_train), + # dst_url=args.lmdb_train) + #mox.file.copy_parallel(src_url=os.path.join(args.data_url,args.lmdb_valid), + # dst_url=args.lmdb_valid) + #print(os.path.join(args.data_url,args.lmdb_train)) + df_train, num_train = lmdb_dataflow( + args.lmdb_train, args.batch_size, args.num_input_points, args.num_gt_points, is_training=True) + train_gen = df_train.get_data() + df_valid, num_valid = lmdb_dataflow( + args.lmdb_valid, args.batch_size, args.num_input_points, args.num_gt_points, is_training=False) + valid_gen = df_valid.get_data() + + config = tf.ConfigProto() + config.gpu_options.allow_growth = True + config.allow_soft_placement = True + custom_op = config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + custom_op = config.graph_options.rewrite_options.custom_optimizers.add() + custom_op.name = "NpuOptimizer" + custom_op.parameter_map["mix_compile_mode"].b = True + config.graph_options.rewrite_options.remapping = RewriterConfig.OFF + sess = tf.Session(config=npu_config_proto(config_proto=config)) + saver = tf.train.Saver() + + if args.restore: + saver.restore(sess, tf.train.latest_checkpoint(args.log_dir)) + writer = tf.summary.FileWriter(args.log_dir) + else: + sess.run(tf.global_variables_initializer()) + if os.path.exists(args.log_dir): + # delete_key = input(colored('%s exists. Delete? [y (or enter)/N]' + # % args.log_dir, 'white', 'on_red')) + # if delete_key == 'y' or delete_key == "": + os.system('rm -rf %s/*' % args.log_dir) + os.makedirs(os.path.join(args.log_dir, 'plots')) + else: + os.makedirs(os.path.join(args.log_dir, 'plots')) + with open(os.path.join(args.log_dir, 'args.txt'), 'w') as log: + for arg in sorted(vars(args)): + log.write(arg + ': ' + str(getattr(args, arg)) + '\n') # log of arguments + os.system('cp models/%s.py %s' % (args.model_type, args.log_dir)) # bkp of model def + os.system('cp train.py %s' % args.log_dir) # bkp of train procedure + writer = tf.summary.FileWriter(args.log_dir, sess.graph) + + total_time = 0 + train_start = time.time() + init_step = sess.run(global_step) + for step in range(init_step+1, args.max_step+1): + epoch = step * args.batch_size // num_train + 1 + ids, inputs, npts, gt = next(train_gen) + start = time.time() + feed_dict = {inputs_pl: inputs, npts_pl: npts, gt_pl: gt, is_training_pl: True} + _, loss, summary = sess.run([train_op, model.loss, train_summary], feed_dict=feed_dict) + total_time += time.time() - start + writer.add_summary(summary, step) + if step % args.steps_per_print == 0: + print('epoch %d step %d loss %.8f - time per batch %.4f' % + (epoch, step, loss, total_time / args.steps_per_print)) + total_time = 0 + + if step % args.steps_per_eval == 0: + print(colored('Testing...', 'grey', 'on_green')) + num_eval_steps = num_valid // args.batch_size + total_loss = 0 + total_time = 0 + sess.run(tf.local_variables_initializer()) + for i in range(num_eval_steps): + start = time.time() + ids, inputs, npts, gt = next(valid_gen) + feed_dict = {inputs_pl: inputs, npts_pl: npts, gt_pl: gt, is_training_pl: False} + loss, _ = sess.run([model.loss, model.update], feed_dict=feed_dict) + total_loss += loss + total_time += time.time() - start + summary = sess.run(valid_summary, feed_dict={is_training_pl: False}) + writer.add_summary(summary, step) + print(colored('epoch %d step %d loss %.8f - time per batch %.4f' % + (epoch, step, total_loss / num_eval_steps, total_time / num_eval_steps), + 'grey', 'on_green')) + total_time = 0 + if step % args.steps_per_visu == 0: + all_pcds = sess.run(model.visualize_ops, feed_dict=feed_dict) + for i in range(0, args.batch_size, args.visu_freq): + plot_path = os.path.join(args.log_dir, 'plots', + 'epoch_%d_step_%d_%s.png' % (epoch, step, ids[i])) + pcds = [x[i] for x in all_pcds] + #plot_pcd_three_views(plot_path, pcds, model.visualize_titles) + if step % args.steps_per_save == 0: + saver.save(sess, os.path.join(args.log_dir, 'model'), step) + print(colored('Model saved at %s' % args.log_dir, 'white', 'on_blue')) + + print('Total time', datetime.timedelta(seconds=time.time() - train_start)) + sess.close() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--data_url', default='/home/test_user01/dataset/shapenet/train.lmdb') + parser.add_argument('--train_url', default='/home/test_user01/dataset/shapenet/train.lmdb') + parser.add_argument('--lmdb_train', default='/home/test_user01/dataset/shapenet/train.lmdb') + parser.add_argument('--lmdb_valid', default='/home/test_user01/dataset/shapenet/valid.lmdb') + parser.add_argument('--log_dir', default='log/pcn_cd3') + parser.add_argument('--model_type', default='pcn_cd') + parser.add_argument('--restore', action='store_true') + parser.add_argument('--batch_size', type=int, default=16) + parser.add_argument('--num_input_points', type=int, default=3000) + parser.add_argument('--num_gt_points', type=int, default=16384) + parser.add_argument('--base_lr', type=float, default=0.001) + parser.add_argument('--lr_decay', action='store_true') + parser.add_argument('--lr_decay_steps', type=int, default=50000) + parser.add_argument('--lr_decay_rate', type=float, default=0.7) + parser.add_argument('--lr_clip', type=float, default=1e-6) + parser.add_argument('--max_step', type=int, default=1000000) + parser.add_argument('--steps_per_print', type=int, default=1) + parser.add_argument('--steps_per_eval', type=int, default=1000) + parser.add_argument('--steps_per_visu', type=int, default=3000) + parser.add_argument('--steps_per_save', type=int, default=100000) + parser.add_argument('--visu_freq', type=int, default=5) + args = parser.parse_args() + print(args.lmdb_train) + train(args) + diff --git a/visu_util.py b/visu_util.py new file mode 100644 index 000000000..18757f9c9 --- /dev/null +++ b/visu_util.py @@ -0,0 +1,60 @@ +''' +MIT License + +Copyright (c) 2018 Wentao Yuan + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +''' +from npu_bridge.npu_init import * + +#import open3d as o3d + +from matplotlib import pyplot as plt +from mpl_toolkits.mplot3d import Axes3D + + +def plot_pcd_three_views(filename, pcds, titles, suptitle='', sizes=None, cmap='Reds', zdir='y', + xlim=(-0.3, 0.3), ylim=(-0.3, 0.3), zlim=(-0.3, 0.3)): + if sizes is None: + sizes = [0.5 for i in range(len(pcds))] + fig = plt.figure(figsize=(len(pcds) * 3, 9)) + for i in range(3): + elev = 30 + azim = -45 + 90 * i + for j, (pcd, size) in enumerate(zip(pcds, sizes)): + color = pcd[:, 0] + ax = fig.add_subplot(3, len(pcds), i * len(pcds) + j + 1, projection='3d') + ax.view_init(elev, azim) + ax.scatter(pcd[:, 0], pcd[:, 1], pcd[:, 2], zdir=zdir, c=color, s=size, cmap=cmap, vmin=-1, vmax=0.5) + ax.set_title(titles[j]) + ax.set_axis_off() + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.set_zlim(zlim) + plt.subplots_adjust(left=0.05, right=0.95, bottom=0.05, top=0.9, wspace=0.1, hspace=0.1) + plt.suptitle(suptitle) + fig.savefig(filename) + plt.close(fig) + +''' +def show_pcd(points): + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + o3d.visualization.draw_geometries([pcd]) +''' -- Gitee