diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/.keep b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/README.md b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4f88621704a3eeec0aee8d4f8b37aca7eb7c6111
--- /dev/null
+++ b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/README.md
@@ -0,0 +1,65 @@
+# MvDSCN
+:game_die: TensorFlow repo for "Multi-view Deep Subspace Clustering Networks"
+
+
+[[Paper]](https://arxiv.org/abs/1908.01978) (submitted to **TIP 2019**)
+
+# Overview
+
+In this work, we propose a novel multi-view deep subspace clustering network (MvDSCN) that learns a multi-view self-representation matrix in an end-to-end manner.
+MvDSCN consists of two sub-networks, i.e., a diversity network (Dnet) and a universality network (Unet).
+A latent space is built upon deep convolutional auto-encoders, and a self-representation matrix is learned in the latent space using a fully connected layer.
+Dnet learns view-specific self-representation matrices, while Unet learns a common self-representation matrix for all views.
+To exploit the complementarity of multi-view representations, the Hilbert-Schmidt Independence Criterion (HSIC) is introduced as a diversity regularization, which can capture
+the non-linear and high-order inter-view relations.
+As different views share the same label space, the self-representation matrix of each view is aligned to the common one by a universality regularization.
+
+
+![MvDSCN](/assets/Architecture.jpg)
+
+
+# Requirements
+
+* TensorFlow
+* scipy
+* numpy
+* sklearn
+* munkres
+
+# Usage
+
+* Test with the Released Result:
+
+```bash
+python main.py --test
+```
+
+* Train the Network with Finetuning.
+
+We have released the pretrained model in the `/pretrain` folder; you can fine-tune from it:
+
+```bash
+python main.py --ft
+```
+
+* Pretrain the Autoencoder from Scratch:
+
+You can re-train the autoencoder from scratch:
+```bash
+python main.py
+```
+
+# Citation
+Please star :star2: this repo and cite :page_facing_up: this paper if you want to use it in your work.
+ +``` +@article{zhu2019multiview, + title={Multi-view Deep Subspace Clustering Networks}, + author={Pengfei Zhu and Binyuan Hui and Changqing Zhang and Dawei Du and Longyin Wen and Qinghua Hu}, + journal={ArXiv: 1908.01978} + year={2019} +} +``` + +# License +MIT License diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/assets/.keep b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/assets/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/assets/Architecture.jpg b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/assets/Architecture.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ca57bb16ea033cbdaca2d77e2322aac05f5a5af1 Binary files /dev/null and b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/assets/Architecture.jpg differ diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/main.py b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/main.py new file mode 100644 index 0000000000000000000000000000000000000000..8b3a805b1d70e19bf64e76dc5d19b3d0a651f4f0 --- /dev/null +++ b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/main.py @@ -0,0 +1,142 @@ +from npu_bridge.npu_init import * +import argparse +import numpy as np +from model.rgbd import MTV +from utils import process_data +from metric import thrC, post_proC, err_rate +from metric import normalized_mutual_info_score, f1_score, rand_index_score, adjusted_rand_score + +import tensorflow as tf +import os + +import scipy.io as sio +import time + +parser = argparse.ArgumentParser(description='Multi-view Deep Subspace CLustering Networks') +parser.add_argument('--path', metavar='DIR', default='./Data/rgbd_mtv.mat', + help='path to dataset') + +parser.add_argument('--data_url', help='path to dataset') + +parser.add_argument('--train_url', help='path to output') + +parser.add_argument('--epochs', default=10000, type=int, metavar='N', + help='number of total epochs to run') + +parser.add_argument('--pretrain', default=100000, type=int, metavar='N', + help='number of total epochs to run') + +parser.add_argument('--lr', default=1e-3, type=float, + help='number of total epochs to run') + +parser.add_argument('--gpu', default='0', type=str, + help='GPU id to use.') + +parser.add_argument('--ft', action='store_true', help='finetune') + +parser.add_argument('--test', action='store_true', help='run kmeans on learned coef') + +np.random.seed(1) +tf.compat.v1.set_random_seed(1) + + +def main(): + args = parser.parse_args() + np.random.seed(1) + tf.compat.v1.set_random_seed(1) + # ignore tensorflow warning + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '0' + os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu + + view_shape, views, label = process_data(args) + num_class = np.unique(label).shape[0] + batch_size = label.shape[0] + # class_single = batch_size / num_class # 10 + + reg1 = 1.0 + reg2 = 1.0 + alpha = max(0.4 - (num_class-1)/10 * 0.1, 0.1) + lr = args.lr + acc_= [] + + tf.compat.v1.reset_default_graph() + + if args.test: + label_10_subjs = label - label.min() + 1 + label_10_subjs = np.squeeze(label_10_subjs) + print("args.gpu = ", args.gpu) + Coef = sio.loadmat(args.data_url + '/result/rgbd_coef.mat')['coef'] + y_x, L = post_proC(Coef, label_10_subjs.max(), 3, 1) + missrate_x = err_rate(label_10_subjs, y_x) + acc_x = 1 - missrate_x + nmi = normalized_mutual_info_score(label_10_subjs, y_x) + f_measure = f1_score(label_10_subjs, y_x) + ri = rand_index_score(label_10_subjs, 
y_x) + ar = adjusted_rand_score(label_10_subjs, y_x) + print("Final Accuracy accuracy %.4f " % acc_x) + print("nmi: %.4f" % nmi, + "accuracy: %.4f" % acc_x, + "F-measure: %.4f" % f_measure, + "RI: %.4f" % ri, + "AR: %.4f" % ar + ) + exit() + + if not args.ft: + # pretrian stage + mtv = MTV(view_shape=view_shape, batch_size=batch_size, ft=False, reg_constant1=reg1, reg_constant2=reg2) + mtv.restore() + epoch = 0 + min_loss = 9970 + while epoch < args.pretrain: + loss = mtv.reconstruct(views[0], views[1], lr) + print("epoch: %.1d" % epoch, "loss: %.8f" % (loss/float(batch_size))) + if loss/float(batch_size) < min_loss: + print('save model.') + mtv.save_model() + min_loss = loss/float(batch_size) + epoch += 1 + else: + t3=time.time() + # self-expressive stage + mtv = MTV(view_shape=view_shape, batch_size=batch_size, ft=True, reg_constant1=reg1, reg_constant2=reg2) + mtv.restore() + Coef = None + label_10_subjs = label - label.min() + 1 + label_10_subjs = np.squeeze(label_10_subjs) + + best_acc, best_epoch = 0, 0 + + epoch = 0 + + while epoch < args.epochs: + start = time.time() + loss, Coef, Coef_1, Coef_2 = mtv.finetune(views[0], views[1], lr) + end = time.time() + print("epoch: %.1d" % epoch) + print("loss: %.8f" % (loss)) + epoch += 1 + print("sec/step :", end - start) + + Coef = thrC(Coef, alpha) + sio.savemat(args.data_url +'/result/rgbd_coef.mat', dict([('coef', Coef)])) + y_x, L = post_proC(Coef, label_10_subjs.max(), 3, 1) + missrate_x = err_rate(label_10_subjs, y_x) + acc_x = 1 - missrate_x + nmi = normalized_mutual_info_score(label_10_subjs, y_x) + f_measure = f1_score(label_10_subjs, y_x) + ri = rand_index_score(label_10_subjs, y_x) + ar = adjusted_rand_score(label_10_subjs, y_x) + + print("Final Accuracy accuracy %.4f" % acc_x) + print("nmi: %.4f" % nmi, + "accuracy: %.4f" % acc_x, + "F-measure: %.4f" % f_measure, + "RI: %.4f" % ri, + "AR: %.4f" % ar + ) + t4=time.time() + print("The overal time is:", t4-t3) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/metric.py b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..8876acdc22b09c92e0c0cfd6fe42a20260ae390e --- /dev/null +++ b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/metric.py @@ -0,0 +1,120 @@ +from npu_bridge.npu_init import * +import numpy as np +from sklearn import cluster +from sklearn.preprocessing import normalize +from munkres import Munkres + +from sklearn.metrics.cluster import normalized_mutual_info_score +from sklearn.metrics.cluster import adjusted_rand_score + +from scipy.sparse.linalg import svds +from scipy import sparse as sp +from scipy.special import comb + +np.random.seed(1) + +def best_map(L1, L2): + #L1 should be the groundtruth labels and L2 should be the clustering labels we got + Label1 = np.unique(L1) + nClass1 = len(Label1) + Label2 = np.unique(L2) + nClass2 = len(Label2) + nClass = np.maximum(nClass1,nClass2) + G = np.zeros((nClass,nClass)) + for i in range(nClass1): + ind_cla1 = L1 == Label1[i] + ind_cla1 = ind_cla1.astype(float) + for j in range(nClass2): + ind_cla2 = L2 == Label2[j] + ind_cla2 = ind_cla2.astype(float) + G[i,j] = np.sum(ind_cla2 * ind_cla1) + m = Munkres() + index = m.compute(-G.T) + index = np.array(index) + c = index[:,1] + newL2 = np.zeros(L2.shape) + for i in range(nClass2): + newL2[L2 == Label2[i]] = Label1[c[i]] + return newL2 + +def thrC(C, ro): + if ro < 1: + N = C.shape[1] + Cp = np.zeros((N,N)) + S = 
np.abs(np.sort(-np.abs(C),axis=0)) + Ind = np.argsort(-np.abs(C),axis=0) + for i in range(N): + cL1 = np.sum(S[:,i]).astype(float) + stop = False + csum = 0 + t = 0 + while(stop == False): + # print(S.shape, t, i) + csum = csum + S[t,i] + if csum > ro*cL1: + stop = True + Cp[Ind[0:t+1,i],i] = C[Ind[0:t+1,i],i] + t = t + 1 + else: + Cp = C + return Cp + +def post_proC(C, K, d, alpha): + # C: coefficient matrix, K: number of clusters, d: dimension of each subspace + C = 0.5*(C + C.T) + r = min(d*K + 1, C.shape[0]-1) + U, S, _ = svds(C, r, v0=np.ones(C.shape[0])) + U = U[:,::-1] + S = np.sqrt(S[::-1]) + S = np.diag(S) + U = U.dot(S) + U = normalize(U, norm='l2', axis = 1) + Z = U.dot(U.T) + Z = Z * (Z>0) + L = np.abs(Z ** alpha) + L = L/L.max() + L = 0.5 * (L + L.T) + spectral = cluster.SpectralClustering(n_clusters=K, eigen_solver='arpack', affinity='precomputed', assign_labels='discretize', random_state=66) + spectral.fit(L) + # print("fit :", 1111) + grp = spectral.fit_predict(L) + 1 + return grp, L + +def err_rate(gt_s, s): + c_x = best_map(gt_s, s) + err_x = np.sum(gt_s[:] != c_x[:]) + missrate = err_x.astype(float) / (gt_s.shape[0]) + return missrate + +def f1_score(gt_s, s): + N = len(gt_s) + num_t = 0 + num_h = 0 + num_i = 0 + for n in range(N-1): + tn = (gt_s[n] == gt_s[n+1:]).astype('int') + hn = (s[n] == s[n+1:]).astype('int') + num_t += np.sum(tn) + num_h += np.sum(hn) + num_i += np.sum(tn * hn) + p = r = f = 1 + if num_h > 0: + p = num_i / num_h + if num_t > 0: + r = num_i / num_t + if p + r == 0: + f = 0 + else: + f = 2 * p * r / (p + r) + return f + +def rand_index_score(clusters, classes): + tp_plus_fp = comb(np.bincount(clusters), 2).sum() + tp_plus_fn = comb(np.bincount(classes), 2).sum() + A = np.c_[(clusters, classes)] + tp = sum(comb(np.bincount(A[A[:, 0] == i, 1]), 2).sum() + for i in set(clusters)) + fp = tp_plus_fp - tp + fn = tp_plus_fn - tp + tn = comb(len(A), 2) - tp - fp - fn + return (tp + tn) / (tp + fp + fn + tn) diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/model/.keep b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/model/.keep new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/model/__init__.py b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1e2e97566dac83c754049fa5bc6572ac7b180390 --- /dev/null +++ b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/model/__init__.py @@ -0,0 +1,7 @@ +from npu_bridge.npu_init import * +import os +os.system('pip install munkres') +os.system('pip install sklearn') +os.system('pip install scipy') + + diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/model/rgbd.py b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/model/rgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..47dc4e10633af53b4141c56d2970e5afd9c3bb52 --- /dev/null +++ b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/model/rgbd.py @@ -0,0 +1,267 @@ +from npu_bridge.npu_init import * +import tensorflow as tf + + +class MTV: + def __init__(self, view_shape, batch_size, ft=False, reg_constant1 = 1.0, reg_constant2 = 1.0, reg = None, \ + denoise = False, model_path = '/home/ma-user/modelarts/inputs/data_url_0/pretrain/rgbd/ae_fusion',\ + restore_path = '/home/ma-user/modelarts/inputs/data_url_0/pretrain/rgbd/ae_fusion'): + self.ft = ft + self.view1_input = view_shape[0] + self.view2_input = 
view_shape[1] + self.view3_input = view_shape[2] + self.batch_size = batch_size + + self.model_path = model_path + self.restore_path = restore_path + self.iter = 0 + + # different view feature input + # self.view1 = tf.placeholder(tf.float32, [None, 64, 64, 3], name='view_1') + # tf.compat.v1.placeholder + self.view1 = tf.compat.v1.placeholder(tf.float32, [None, 64, 64, 3], name='view_1') + self.view2 = tf.compat.v1.placeholder(tf.float32, [None, 64, 64, 1], name='view_2') + + # learning rate + self.lr = tf.compat.v1.placeholder(tf.float32, [], name='lr') + + # encoder + # latent is the output of Unet encoder + latent1 = self.encoder1(self.view1) + latent2 = self.encoder2(self.view2) + + # lantent_single means output of Dnet encoder + latent1_single = self.encoder1_single(self.view1) + latent2_single = self.encoder2_single(self.view2) + + # reshape + self.z1 = tf.reshape(latent1, [batch_size, -1]) + self.z2 = tf.reshape(latent2, [batch_size, -1]) + z1 = self.z1 + z2 = self.z2 + + z1_single = tf.reshape(latent1_single, [batch_size, -1]) + z2_single = tf.reshape(latent2_single, [batch_size, -1]) + + # self-expressive layer + # common expressive + self.Coef = tf.Variable(1.0e-8 * tf.ones([self.batch_size, self.batch_size], tf.float32), name = 'Coef') + # single expressive + self.Coef_1 = tf.Variable(1.0e-8 * tf.ones([self.batch_size, self.batch_size], tf.float32), name = 'Coef_1') + self.Coef_2 = tf.Variable(1.0e-8 * tf.ones([self.batch_size, self.batch_size], tf.float32), name = 'Coef_2') + + + # normalize + self.Coef = (self.Coef - tf.linalg.tensor_diag(tf.linalg.tensor_diag_part(self.Coef))) + + self.Coef_1 = (self.Coef_1 - tf.linalg.tensor_diag(tf.linalg.tensor_diag_part(self.Coef_1))) + self.Coef_2 = (self.Coef_2 - tf.linalg.tensor_diag(tf.linalg.tensor_diag_part(self.Coef_2))) + + # zc + z1_c = tf.matmul(self.Coef, z1) + z2_c = tf.matmul(self.Coef, z2) + + z1_c_single = tf.matmul(self.Coef_1, z1_single) + z2_c_single = tf.matmul(self.Coef_2, z2_single) + + # reshape + latent1_c = tf.reshape(z1_c, tf.shape(latent1)) + latent2_c = tf.reshape(z2_c, tf.shape(latent2)) + + latent1_c_single = tf.reshape(z1_c_single, tf.shape(latent1_single)) + latent2_c_single = tf.reshape(z2_c_single, tf.shape(latent2_single)) + + if self.ft: + # reconst with self-expressive + self.view1_r = self.decoder1(latent1_c) + self.view2_r = self.decoder2(latent2_c) + + self.view1_r_single = self.decoder1_single(latent1_c_single) + self.view2_r_single = self.decoder2_single(latent2_c_single) + + else: + # only reconst by autoencoder + self.view1_r = self.decoder1(latent1) + self.view2_r = self.decoder2(latent2) + + self.view1_r_single = self.decoder1_single(latent1_single) + self.view2_r_single = self.decoder2_single(latent2_single) + + print(latent1.shape, self.view1_r.shape) + print(latent2.shape, self.view2_r.shape) + + # reconstruction loss by Unet + self.reconst_loss_1 = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.view1_r, self.view1), 2.0)) + self.reconst_loss_2 = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.view2_r, self.view2), 2.0)) + + # reconstruction loss by Dnet + self.reconst_loss_1_single = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.view1_r_single, self.view1), 2.0)) + self.reconst_loss_2_single = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(self.view2_r_single, self.view2), 2.0)) + + + self.reconst_loss_single = self.reconst_loss_1_single + self.reconst_loss_2_single + + # reconstruction loss all (Unet + Dnet) + self.reconst_loss = self.reconst_loss_1 + self.reconst_loss_2 + self.reconst_loss += 
self.reconst_loss_single + + # self-expressive loss by Unet + self.selfexpress_loss_1 = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(z1_c, z1), 2.0)) + self.selfexpress_loss_2 = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(z2_c, z2), 2.0)) + # self-expressive loss by Dnet + self.selfexpress_loss_1_single = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(z1_c_single, z1_single), 2.0)) + self.selfexpress_loss_2_single = 0.5 * tf.reduce_sum(tf.pow(tf.subtract(z2_c_single, z2_single), 2.0)) + + # selfexpress all (Unet + Dnet) + self.selfexpress_loss = self.selfexpress_loss_1 + self.selfexpress_loss_2 + self.selfexpress_loss_single = self.selfexpress_loss_1_single + self.selfexpress_loss_2_single + + self.selfexpress_loss += self.selfexpress_loss_single + + # Coef regularization + self.reg_loss = tf.reduce_sum(tf.pow(self.Coef, 2.0)) + + self.reg_loss += tf.reduce_sum(tf.pow(self.Coef_1, 2.0)) + self.reg_loss += tf.reduce_sum(tf.pow(self.Coef_2, 2.0)) + + + # unify loss + self.unify_loss = tf.reduce_sum(tf.abs(tf.subtract(self.Coef, self.Coef_1))) + \ + tf.reduce_sum(tf.abs(tf.subtract(self.Coef, self.Coef_2))) + + + self.hsic_loss = self.HSIC(self.Coef_1, self.Coef_2) + + # summary loss + self.loss = self.reconst_loss + reg_constant1 * self.reg_loss + reg_constant2 * self.selfexpress_loss + self.unify_loss * 0.1 + self.hsic_loss * 0.1 + + # selfexpression optimizer + self.optimizer = npu_distributed_optimizer_wrapper(tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr)).minimize(self.loss) + # autoencoder optimizer + self.optimizer_ae = npu_distributed_optimizer_wrapper(tf.compat.v1.train.AdamOptimizer(learning_rate=self.lr)).minimize(self.reconst_loss) + # session + self.init = tf.compat.v1.global_variables_initializer() + config = tf.compat.v1.ConfigProto() + config.gpu_options.allow_growth=True + self.sess = tf.compat.v1.Session(config=npu_config_proto(config_proto=config)) + self.sess.run(self.init) + + + self.saver = tf.compat.v1.train.Saver( + [v for v in tf.compat.v1.trainable_variables() if not (v.name.startswith("Coef"))] + ) + + def HSIC(self, c_v, c_w): + N = tf.shape(c_v)[0] + H = tf.ones((N, N)) * tf.cast((1/N), tf.float32) * (-1) + tf.eye(N) + K_1 = tf.matmul(c_v, tf.transpose(c_v)) + K_2 = tf.matmul(c_w, tf.transpose(c_w)) + rst = tf.matmul(K_1, H) + rst = tf.matmul(rst, K_2) + rst = tf.matmul(rst, H) + rst = tf.linalg.trace(rst) + return rst + + + def conv_block(self, inputs, out_channels, name='conv'): + # conv = tf.layers.conv2d(inputs, out_channels, kernel_size=3, strides=2, padding="same") + # tf.keras.layers.Conv2D + conv = tf.keras.layers.Conv2D(out_channels, kernel_size=3, strides=2, padding="same")(inputs) + conv = tf.nn.relu(conv) + return conv + + def deconv_block(self, inputs, out_channels, name='conv'): + deconv = tf.keras.layers.Conv2DTranspose(out_channels, kernel_size=3, strides=2, padding='same')(inputs) + deconv = tf.nn.relu(deconv) + return deconv + + def encoder1(self, x): + net = self.conv_block(x, 64) + net = self.conv_block(net, 64) + net = self.conv_block(net, 64) + return net + + def encoder1_single(self, x): + net = self.conv_block(x, 64) + net = self.conv_block(net, 64) + net = self.conv_block(net, 64) + return net + + def decoder1(self, z): + net = self.deconv_block(z, 64) + net = self.deconv_block(net, 64) + net = self.deconv_block(net, 3) + return net + + def decoder1_single(self, z): + net = self.deconv_block(z, 64) + net = self.deconv_block(net, 64) + net = self.deconv_block(net, 3) + return net + + def encoder2(self, x): + net = self.conv_block(x, 64) + net = 
self.conv_block(net, 64) + net = self.conv_block(net, 64) + return net + + def encoder2_single(self, x): + net = self.conv_block(x, 64) + net = self.conv_block(net, 64) + net = self.conv_block(net, 64) + return net + + def decoder2(self, z): + net = self.deconv_block(z, 64) + net = self.deconv_block(net, 64) + net = self.deconv_block(net, 1) + return net + + def decoder2_single(self, z): + net = self.deconv_block(z, 64) + net = self.deconv_block(net, 64) + net = self.deconv_block(net, 1) + return net + + def finetune(self, view1, view2, lr): + loss, _, Coef, Coef_1, Coef_2 = self.sess.run( + (self.loss, self.optimizer, self.Coef, self.Coef_1, self.Coef_2), + feed_dict={ + self.view1: view1, + self.view2: view2, + self.lr: lr + }) + return loss, Coef, Coef_1, Coef_2 + + def reconstruct(self, view1, view2, lr): + loss, _ = self.sess.run( + [self.reconst_loss, self.optimizer_ae], + feed_dict={ + self.view1: view1, + self.view2: view2, + self.lr: lr + } + ) + return loss + + def initlization(self): + self.sess.run(self.init) + + def save_model(self): + save_path = self.saver.save(self.sess, self.model_path) + print("model saved in ", save_path) + return save_path + + def restore(self): + self.saver.restore(self.sess, self.restore_path) + print("mode restored successed.") + + def get_latent(self, view1, view2): + latent_1, latent_2 = self.sess.run( + [self.z1, self.z2], + feed_dict={ + self.view1: view1, + self.view2: view2 + } + ) + return latent_1, latent_2 diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/modelarts_entry_acc.py b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/modelarts_entry_acc.py new file mode 100644 index 0000000000000000000000000000000000000000..ef1554168f39ee6fc4146573cc031901b940ad22 --- /dev/null +++ b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/modelarts_entry_acc.py @@ -0,0 +1,53 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This is the boot file for ModelArts platform. +Firstly, the train datasets are copyed from obs to ModelArts. 
+Then, the string of train shell command is concated and using 'os.system()' to execute +""" +import os +import argparse +import sys + +# 解析输入参数data_url +parser = argparse.ArgumentParser() +parser.add_argument("--data_url", type=str, default="/home/ma-user/modelarts/inputs/data_url_0") +parser.add_argument("--train_url", type=str, default="/home/ma-user/modelarts/outputs/train_url_0/") +config = parser.parse_args() + +print("[CANN-Modelzoo] code_dir path is [%s]" % (sys.path[0])) +code_dir = sys.path[0] +os.chdir(code_dir) +print("[CANN-Modelzoo] work_dir path is [%s]" % (os.getcwd())) + +print("[CANN-Modelzoo] before train - list my run files:") +os.system("ls -al /usr/local/Ascend/ascend-toolkit/") + +print("[CANN-Modelzoo] before train - list my dataset files:") +os.system("ls -al %s" % config.data_url) + +print("[CANN-Modelzoo] start run train shell") +# 设置sh文件格式为linux可执行 +os.system("dos2unix ./test/*") + +# 执行train_full_1p.sh或者train_performance_1p.sh,需要用户自己指定 +# full和performance的差异,performance只需要执行很少的step,控制在15分钟以内,主要关注性能FPS +os.system("bash ./test/train_full_1p.sh --data_path=%s --output_path=%s " % (config.data_url, config.train_url)) + +print("[CANN-Modelzoo] finish run train shell") + +# 将当前执行目录所有文件拷贝到obs的output进行备份 +print("[CANN-Modelzoo] after train - list my output files:") +os.system("cp -r %s %s " % (code_dir, config.train_url)) +os.system("ls -al %s" % config.train_url) diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/modelarts_entry_perf.py b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/modelarts_entry_perf.py new file mode 100644 index 0000000000000000000000000000000000000000..d76fe4046bf66e19e745ba3642e4c066e7ff1614 --- /dev/null +++ b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/modelarts_entry_perf.py @@ -0,0 +1,53 @@ +# Copyright 2021 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This is the boot file for ModelArts platform. +Firstly, the train datasets are copyed from obs to ModelArts. 
+Then, the string of train shell command is concated and using 'os.system()' to execute +""" +import os +import argparse +import sys + +# 解析输入参数data_url +parser = argparse.ArgumentParser() +parser.add_argument("--data_url", type=str, default="/home/ma-user/modelarts/inputs/data_url_0") +parser.add_argument("--train_url", type=str, default="/home/ma-user/modelarts/outputs/train_url_0/") +config = parser.parse_args() + +print("[CANN-Modelzoo] code_dir path is [%s]" % (sys.path[0])) +code_dir = sys.path[0] +os.chdir(code_dir) +print("[CANN-Modelzoo] work_dir path is [%s]" % (os.getcwd())) + +print("[CANN-Modelzoo] before train - list my run files:") +os.system("ls -al /usr/local/Ascend/ascend-toolkit/") + +print("[CANN-Modelzoo] before train - list my dataset files:") +os.system("ls -al %s" % config.data_url) + +print("[CANN-Modelzoo] start run train shell") +# 设置sh文件格式为linux可执行 +os.system("dos2unix ./test/*") + +# 执行train_full_1p.sh或者train_performance_1p.sh,需要用户自己指定 +# full和performance的差异,performance只需要执行很少的step,控制在15分钟以内,主要关注性能FPS +os.system("bash ./test/train_performance_1p.sh --data_path=%s --output_path=%s " % (config.data_url, config.train_url)) + +print("[CANN-Modelzoo] finish run train shell") + +# 将当前执行目录所有文件拷贝到obs的output进行备份 +print("[CANN-Modelzoo] after train - list my output files:") +os.system("cp -r %s %s " % (code_dir, config.train_url)) +os.system("ls -al %s" % config.train_url) diff --git a/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/utils.py b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6d83c70291cf24f59930da89b8bb96c010f20430 --- /dev/null +++ b/TensorFlow/contrib/cv/MVDSCN_ID1272_for_TensorFlow/utils.py @@ -0,0 +1,19 @@ +from npu_bridge.npu_init import * +import numpy as np + +import scipy.io as sio + +def process_data(args): + # to do release other dataset. + if 'rgbd' in args.path: + data = sio.loadmat(args.data_url + '/rgbd_mtv.mat') + features = data['X'] + label = data['gt'] + + views = [] + view_shape = [] + for v in features[0]: + view_shape.append(v.shape[1]) + views.append(v) + + return view_shape, views, label