From 65a8e70b09ad93fe3a79970be5da033cc725c2e5 Mon Sep 17 00:00:00 2001 From: wuzhifeng Date: Thu, 25 Sep 2025 21:36:32 +0800 Subject: [PATCH] add sharker --- mindscience/{gnn => sharker}/__init__.py | 24 +- mindscience/sharker/data/__init__.py | 38 + mindscience/sharker/data/batch.py | 222 ++++ mindscience/sharker/data/collate.py | 290 ++++ mindscience/sharker/data/database.py | 597 +++++++++ mindscience/sharker/data/datapipe.py | 67 + mindscience/sharker/data/dataset.py | 429 ++++++ mindscience/sharker/data/download.py | 67 + mindscience/sharker/data/extract.py | 77 ++ mindscience/sharker/data/graph.py | 971 ++++++++++++++ mindscience/sharker/data/heterograph.py | 1023 ++++++++++++++ mindscience/sharker/data/hypergraph.py | 234 ++++ mindscience/sharker/data/in_memory.py | 344 +++++ mindscience/sharker/data/on_disk.py | 175 +++ mindscience/sharker/data/remote_store.py | 582 ++++++++ mindscience/sharker/data/separate.py | 149 +++ mindscience/sharker/data/storage.py | 780 +++++++++++ mindscience/sharker/data/summary.py | 157 +++ mindscience/sharker/data/temporal.py | 315 +++++ mindscience/sharker/data/view.py | 39 + mindscience/sharker/dataset/__init__.py | 1 + mindscience/sharker/dataset/qm9.py | 332 +++++ mindscience/sharker/experimental.py | 136 ++ mindscience/sharker/home.py | 30 + mindscience/sharker/inspector.py | 554 ++++++++ mindscience/sharker/io/__init__.py | 18 + mindscience/sharker/io/fs.py | 207 +++ mindscience/sharker/io/npz.py | 36 + mindscience/sharker/io/obj.py | 41 + mindscience/sharker/io/off.py | 42 + mindscience/sharker/io/planetoid.py | 59 + mindscience/sharker/io/ply.py | 19 + mindscience/sharker/io/sdf.py | 32 + mindscience/sharker/io/tu.py | 142 ++ mindscience/sharker/io/txt_array.py | 32 + mindscience/sharker/loader/__init__.py | 7 + mindscience/sharker/loader/dataloader.py | 393 ++++++ mindscience/sharker/nn/__init__.py | 15 + mindscience/sharker/nn/aggr/__init__.py | 56 + mindscience/sharker/nn/aggr/attention.py | 92 ++ mindscience/sharker/nn/aggr/base.py | 204 +++ mindscience/sharker/nn/aggr/basic.py | 324 +++++ mindscience/sharker/nn/aggr/deep_sets.py | 81 ++ mindscience/sharker/nn/aggr/equilibrium.py | 206 +++ mindscience/sharker/nn/aggr/fused.py | 345 +++++ mindscience/sharker/nn/aggr/gmt.py | 102 ++ mindscience/sharker/nn/aggr/gru.py | 58 + mindscience/sharker/nn/aggr/lcm.py | 124 ++ mindscience/sharker/nn/aggr/lstm.py | 58 + mindscience/sharker/nn/aggr/mlp.py | 76 ++ mindscience/sharker/nn/aggr/multi.py | 204 +++ mindscience/sharker/nn/aggr/quantile.py | 171 +++ mindscience/sharker/nn/aggr/scaler.py | 115 ++ mindscience/sharker/nn/aggr/set2set.py | 67 + .../sharker/nn/aggr/set_transformer.py | 116 ++ mindscience/sharker/nn/aggr/sort.py | 76 ++ mindscience/sharker/nn/aggr/utils.py | 236 ++++ .../sharker/nn/aggr/variance_preserving.py | 34 + mindscience/sharker/nn/conv/__init__.py | 113 ++ mindscience/sharker/nn/conv/agnn_conv.py | 76 ++ .../sharker/nn/conv/antisymmetric_conv.py | 116 ++ mindscience/sharker/nn/conv/appnp.py | 109 ++ mindscience/sharker/nn/conv/arma_conv.py | 131 ++ mindscience/sharker/nn/conv/cg_conv copy.py | 86 ++ mindscience/sharker/nn/conv/cg_conv.py | 86 ++ mindscience/sharker/nn/conv/cheb_conv.py | 181 +++ .../sharker/nn/conv/cluster_gcn_conv.py | 89 ++ mindscience/sharker/nn/conv/dir_gnn_conv.py | 70 + mindscience/sharker/nn/conv/dna_conv.py | 295 ++++ mindscience/sharker/nn/conv/edge_conv.py | 138 ++ mindscience/sharker/nn/conv/eg_conv.py | 200 +++ mindscience/sharker/nn/conv/fa_conv.py | 162 +++ mindscience/sharker/nn/conv/feast_conv.py | 103 ++ mindscience/sharker/nn/conv/film_conv.py | 138 ++ mindscience/sharker/nn/conv/gat_conv.py | 310 +++++ .../sharker/nn/conv/gated_graph_conv.py | 86 ++ mindscience/sharker/nn/conv/gatv2_conv.py | 280 ++++ mindscience/sharker/nn/conv/gcn_conv.py | 183 +++ mindscience/sharker/nn/conv/gen_conv.py | 217 +++ mindscience/sharker/nn/conv/general_conv.py | 172 +++ mindscience/sharker/nn/conv/gin_conv.py | 191 +++ mindscience/sharker/nn/conv/gmm_conv.py | 168 +++ mindscience/sharker/nn/conv/graph_conv.py | 94 ++ mindscience/sharker/nn/conv/gravnet_conv.py | 113 ++ mindscience/sharker/nn/conv/heat_conv.py | 134 ++ mindscience/sharker/nn/conv/hetero_conv.py | 167 +++ mindscience/sharker/nn/conv/hgt_conv.py | 227 ++++ .../sharker/nn/conv/hypergraph_conv.py | 210 +++ mindscience/sharker/nn/conv/le_conv.py | 94 ++ mindscience/sharker/nn/conv/lg_conv.py | 48 + .../sharker/nn/conv/message_passing.py | 611 +++++++++ mindscience/sharker/nn/conv/mf_conv.py | 108 ++ mindscience/sharker/nn/conv/mixhop_conv.py | 110 ++ mindscience/sharker/nn/conv/nn_conv.py | 121 ++ mindscience/sharker/nn/conv/pdn_conv.py | 113 ++ mindscience/sharker/nn/conv/point_conv.py | 107 ++ mindscience/sharker/nn/conv/point_gnn_conv.py | 80 ++ .../sharker/nn/conv/point_transformer_conv.py | 139 ++ mindscience/sharker/nn/conv/ppf_conv.py | 129 ++ .../sharker/nn/conv/res_gated_graph_conv.py | 138 ++ mindscience/sharker/nn/conv/rgat_conv.py | 515 +++++++ mindscience/sharker/nn/conv/sage_conv.py | 142 ++ mindscience/sharker/nn/conv/sg_conv.py | 98 ++ mindscience/sharker/nn/conv/signed_conv.py | 136 ++ mindscience/sharker/nn/conv/simple_conv.py | 84 ++ mindscience/sharker/nn/conv/spline_conv.py | 142 ++ mindscience/sharker/nn/conv/ssg_conv.py | 109 ++ mindscience/sharker/nn/conv/tag_conv.py | 95 ++ .../sharker/nn/conv/transformer_conv.py | 224 ++++ mindscience/sharker/nn/conv/wl_conv.py | 81 ++ .../sharker/nn/conv/wl_conv_continuous.py | 72 + mindscience/sharker/nn/conv/x_conv.py | 148 +++ mindscience/sharker/nn/dense/__init__.py | 14 + mindscience/sharker/nn/dense/linear.py | 272 ++++ mindscience/sharker/nn/encoding.py | 97 ++ mindscience/sharker/nn/inits.py | 103 ++ mindscience/sharker/nn/lr_scheduler.py | 251 ++++ mindscience/sharker/nn/models/__init__.py | 3 + mindscience/sharker/nn/models/mlp.py | 251 ++++ mindscience/sharker/nn/norm/__init__.py | 12 + mindscience/sharker/nn/norm/batch_norm.py | 213 +++ mindscience/sharker/nn/norm/msg_norm.py | 48 + mindscience/sharker/nn/reshape.py | 16 + mindscience/sharker/nn/resolver.py | 174 +++ mindscience/sharker/profile/__init__.py | 19 + mindscience/sharker/profile/benchmark.py | 132 ++ mindscience/sharker/profile/utils.py | 93 ++ mindscience/sharker/resolver.py | 43 + mindscience/sharker/seed.py | 16 + mindscience/sharker/template.py | 37 + mindscience/sharker/testing/__init__.py | 62 + mindscience/sharker/testing/asserts.py | 91 ++ mindscience/sharker/testing/data.py | 57 + mindscience/sharker/testing/decorators.py | 150 +++ mindscience/sharker/testing/distributed.py | 92 ++ mindscience/sharker/typing.py | 194 +++ mindscience/sharker/utils/__init__.py | 144 ++ mindscience/sharker/utils/_scatter.py | 339 +++++ mindscience/sharker/utils/_segment.py | 115 ++ mindscience/sharker/utils/assortativity.py | 61 + mindscience/sharker/utils/augmentation.py | 242 ++++ mindscience/sharker/utils/cluster.py | 560 ++++++++ mindscience/sharker/utils/coalesce.py | 194 +++ mindscience/sharker/utils/convert.py | 389 ++++++ mindscience/sharker/utils/degree.py | 31 + mindscience/sharker/utils/dropout.py | 151 +++ mindscience/sharker/utils/embedding.py | 54 + mindscience/sharker/utils/functions.py | 76 ++ mindscience/sharker/utils/grid.py | 74 ++ mindscience/sharker/utils/hetero.py | 132 ++ mindscience/sharker/utils/homophily.py | 128 ++ mindscience/sharker/utils/isolated.py | 96 ++ mindscience/sharker/utils/laplacian.py | 207 +++ mindscience/sharker/utils/loop.py | 322 +++++ mindscience/sharker/utils/map.py | 169 +++ mindscience/sharker/utils/mask.py | 99 ++ mindscience/sharker/utils/mixin.py | 22 + mindscience/sharker/utils/ncon.py | 514 +++++++ .../sharker/utils/negative_sampling.py | 388 ++++++ mindscience/sharker/utils/noise_scheduler.py | 89 ++ mindscience/sharker/utils/normalize.py | 36 + mindscience/sharker/utils/num_nodes.py | 57 + mindscience/sharker/utils/random.py | 88 ++ mindscience/sharker/utils/repeat.py | 35 + mindscience/sharker/utils/select.py | 68 + mindscience/sharker/utils/softmax.py | 70 + mindscience/sharker/utils/sort_edge_index.py | 104 ++ mindscience/sharker/utils/sparse.py | 37 + mindscience/sharker/utils/subgraph.py | 449 +++++++ mindscience/sharker/utils/to_dense_adj.py | 105 ++ mindscience/sharker/utils/to_dense_batch.py | 136 ++ .../sharker/utils/tree_decomposition.py | 128 ++ mindscience/sharker/utils/trim_to_layer.py | 157 +++ mindscience/sharker/utils/unbatch.py | 70 + mindscience/sharker/utils/undirected.py | 142 ++ tests/graph/__init__.py | 1 + tests/graph/cluster/test_fps.py | 71 + tests/graph/cluster/test_graclus.py | 56 + tests/graph/cluster/test_grid.py | 43 + tests/graph/cluster/test_knn.py | 88 ++ tests/graph/cluster/test_nearest.py | 63 + tests/graph/cluster/test_radius.py | 133 ++ tests/graph/cluster/test_rw.py | 85 ++ tests/graph/conftest.py | 95 ++ tests/graph/data/test_batch.py | 379 ++++++ tests/graph/data/test_data.py | 429 ++++++ tests/graph/data/test_database.py | 219 +++ tests/graph/data/test_dataloader.py | 172 +++ tests/graph/data/test_dataset.py | 340 +++++ tests/graph/data/test_dataset_summary.py | 99 ++ tests/graph/data/test_feature_store.py | 109 ++ tests/graph/data/test_graph_store.py | 96 ++ tests/graph/data/test_hetero.py | 625 +++++++++ tests/graph/data/test_hypergraph.py | 156 +++ tests/graph/data/test_inherit.py | 61 + tests/graph/data/test_on_disk_dataset.py | 111 ++ tests/graph/data/test_remote_backend_utils.py | 33 + tests/graph/data/test_storage.py | 81 ++ tests/graph/data/test_temporal.py | 137 ++ tests/graph/data/test_view.py | 32 + .../datasets/graph_generator/test_ba_graph.py | 11 + .../datasets/graph_generator/test_er_graph.py | 12 + .../graph_generator/test_grid_graph.py | 11 + .../graph_generator/test_tree_graph.py | 25 + .../motif_generator/test_custom_motif.py | 37 + .../motif_generator/test_cycle_motif.py | 15 + .../motif_generator/test_grid_motif.py | 17 + .../motif_generator/test_house_motif.py | 12 + tests/graph/datasets/test_ba_shapes.py | 18 + tests/graph/datasets/test_bzr.py | 23 + tests/graph/datasets/test_elliptic.py | 28 + tests/graph/datasets/test_enzymes.py | 71 + .../graph/datasets/test_explainer_dataset.py | 47 + tests/graph/datasets/test_fake.py | 84 ++ tests/graph/datasets/test_imdb_binary.py | 17 + .../graph/datasets/test_infection_dataset.py | 54 + tests/graph/datasets/test_karate.py | 13 + tests/graph/datasets/test_mutag.py | 19 + tests/graph/datasets/test_planetoid.py | 62 + tests/graph/datasets/test_snap_dataset.py | 25 + tests/graph/datasets/test_suite_sparse.py | 19 + .../algorithm/test_attention_explainer.py | 84 ++ tests/graph/explain/algorithm/test_captum.py | 201 +++ .../algorithm/test_captum_explainer.py | 255 ++++ .../explain/algorithm/test_captum_hetero.py | 106 ++ .../algorithm/test_explain_algorithm_utils.py | 78 ++ .../explain/algorithm/test_gnn_explainer.py | 275 ++++ .../algorithm/test_graphmask_explainer.py | 232 ++++ .../explain/algorithm/test_pg_explainer.py | 178 +++ tests/graph/explain/conftest.py | 94 ++ .../graph/explain/metric/test_basic_metric.py | 46 + .../graph/explain/metric/test_faithfulness.py | 59 + tests/graph/explain/metric/test_fidelity.py | 74 ++ tests/graph/explain/test_explain_config.py | 48 + tests/graph/explain/test_explainer.py | 123 ++ tests/graph/explain/test_explanation.py | 150 +++ tests/graph/explain/test_hetero_explainer.py | 128 ++ .../graph/explain/test_hetero_explanation.py | 144 ++ tests/graph/io/example1.off | 8 + tests/graph/io/example2.off | 7 + tests/graph/io/test_fs.py | 131 ++ tests/graph/io/test_off.py | 37 + tests/graph/loader/test_cache.py | 68 + tests/graph/loader/test_cluster.py | 204 +++ tests/graph/loader/test_common.py | 161 +++ tests/graph/loader/test_dataloader.py | 282 ++++ .../loader/test_dynamic_batch_sampler.py | 37 + tests/graph/loader/test_graph_saint.py | 82 ++ tests/graph/loader/test_hgt_loader.py | 214 +++ tests/graph/loader/test_ibmb_loader.py | 67 + tests/graph/loader/test_imbalanced_sampler.py | 116 ++ .../graph/loader/test_link_neighbor_loader.py | 588 ++++++++ tests/graph/loader/test_mixin.py | 46 + tests/graph/loader/test_neighbor_loader.py | 955 +++++++++++++ tests/graph/loader/test_neighbor_sampler.py | 107 ++ tests/graph/loader/test_random_node_loader.py | 46 + tests/graph/loader/test_shadow.py | 54 + .../graph/loader/test_temporal_dataloader.py | 28 + tests/graph/loader/test_utils.py | 16 + tests/graph/loader/test_zip_loader.py | 40 + tests/graph/metrics/test_link_pred_metric.py | 94 ++ tests/graph/my_config.yaml | 15 + tests/graph/nn/aggr/test_aggr_utils.py | 66 + tests/graph/nn/aggr/test_attention.py | 24 + tests/graph/nn/aggr/test_basic.py | 95 ++ tests/graph/nn/aggr/test_deep_sets.py | 18 + tests/graph/nn/aggr/test_equilibrium.py | 45 + tests/graph/nn/aggr/test_fused.py | 74 ++ tests/graph/nn/aggr/test_gmt.py | 17 + tests/graph/nn/aggr/test_gru.py | 14 + tests/graph/nn/aggr/test_lcm.py | 72 + tests/graph/nn/aggr/test_lstm.py | 18 + tests/graph/nn/aggr/test_mlp_aggr.py | 19 + tests/graph/nn/aggr/test_multi.py | 42 + tests/graph/nn/aggr/test_quantile.py | 111 ++ tests/graph/nn/aggr/test_scaler.py | 25 + tests/graph/nn/aggr/test_set2set.py | 28 + tests/graph/nn/aggr/test_set_transformer.py | 19 + tests/graph/nn/aggr/test_sort.py | 72 + .../graph/nn/aggr/test_variance_preserving.py | 29 + tests/graph/nn/attn/test_performer.py | 12 + tests/graph/nn/conv/test_agnn_conv.py | 32 + .../graph/nn/conv/test_antisymmetric_conv.py | 32 + tests/graph/nn/conv/test_appnp.py | 56 + tests/graph/nn/conv/test_arma_conv.py | 53 + tests/graph/nn/conv/test_cg_conv.py | 92 ++ tests/graph/nn/conv/test_cheb_conv.py | 71 + tests/graph/nn/conv/test_cluster_gcn_conv.py | 30 + tests/graph/nn/conv/test_create_gnn.py | 36 + tests/graph/nn/conv/test_dir_gnn_conv.py | 24 + tests/graph/nn/conv/test_dna_conv.py | 88 ++ tests/graph/nn/conv/test_edge_conv.py | 96 ++ tests/graph/nn/conv/test_eg_conv.py | 68 + tests/graph/nn/conv/test_fa_conv.py | 125 ++ tests/graph/nn/conv/test_feast_conv.py | 49 + tests/graph/nn/conv/test_film_conv.py | 79 ++ tests/graph/nn/conv/test_fused_gat_conv.py | 35 + tests/graph/nn/conv/test_gat_conv.py | 175 +++ tests/graph/nn/conv/test_gated_graph_conv.py | 39 + tests/graph/nn/conv/test_gatv2_conv.py | 156 +++ tests/graph/nn/conv/test_gcn2_conv.py | 51 + tests/graph/nn/conv/test_gcn_conv.py | 72 + tests/graph/nn/conv/test_gen_conv.py | 141 ++ tests/graph/nn/conv/test_general_conv.py | 31 + tests/graph/nn/conv/test_gin_conv.py | 159 +++ tests/graph/nn/conv/test_gmm_conv.py | 93 ++ tests/graph/nn/conv/test_gps_conv.py | 37 + tests/graph/nn/conv/test_graph_conv.py | 108 ++ tests/graph/nn/conv/test_gravnet_conv.py | 34 + tests/graph/nn/conv/test_han_conv.py | 136 ++ tests/graph/nn/conv/test_heat_conv.py | 39 + tests/graph/nn/conv/test_hetero_conv.py | 204 +++ tests/graph/nn/conv/test_hgt_conv.py | 228 ++++ tests/graph/nn/conv/test_hypergraph_conv.py | 48 + tests/graph/nn/conv/test_le_conv.py | 30 + tests/graph/nn/conv/test_lg_conv.py | 40 + tests/graph/nn/conv/test_message_passing.py | 383 ++++++ tests/graph/nn/conv/test_mf_conv.py | 56 + tests/graph/nn/conv/test_mixhop_conv.py | 41 + tests/graph/nn/conv/test_nn_conv.py | 82 ++ tests/graph/nn/conv/test_pan_conv.py | 28 + tests/graph/nn/conv/test_pdn_conv.py | 55 + tests/graph/nn/conv/test_pna_conv.py | 75 ++ tests/graph/nn/conv/test_point_conv.py | 65 + tests/graph/nn/conv/test_point_gnn_conv.py | 41 + .../nn/conv/test_point_transformer_conv.py | 69 + tests/graph/nn/conv/test_ppf_conv.py | 74 ++ .../nn/conv/test_res_gated_graph_conv.py | 57 + tests/graph/nn/conv/test_rgat_conv.py | 126 ++ tests/graph/nn/conv/test_rgcn_conv.py | 136 ++ tests/graph/nn/conv/test_sage_conv.py | 143 ++ tests/graph/nn/conv/test_sg_conv.py | 49 + tests/graph/nn/conv/test_signed_conv.py | 72 + tests/graph/nn/conv/test_simple_conv.py | 58 + tests/graph/nn/conv/test_spline_conv.py | 84 ++ tests/graph/nn/conv/test_ssg_conv.py | 49 + tests/graph/nn/conv/test_static_graph.py | 27 + tests/graph/nn/conv/test_supergat_conv.py | 44 + tests/graph/nn/conv/test_tag_conv.py | 50 + tests/graph/nn/conv/test_transformer_conv.py | 154 +++ tests/graph/nn/conv/test_wl_conv.py | 31 + .../graph/nn/conv/test_wl_conv_continuous.py | 53 + tests/graph/nn/conv/test_x_conv.py | 30 + tests/graph/nn/dense/test_dense_gat_conv.py | 66 + tests/graph/nn/dense/test_dense_gcn_conv.py | 64 + tests/graph/nn/dense/test_dense_gin_conv.py | 71 + tests/graph/nn/dense/test_dense_graph_conv.py | 94 ++ tests/graph/nn/dense/test_dense_sage_conv.py | 64 + tests/graph/nn/dense/test_diff_pool.py | 65 + tests/graph/nn/dense/test_dmon_pool.py | 21 + tests/graph/nn/dense/test_linear.py | 248 ++++ tests/graph/nn/dense/test_mincut_pool.py | 30 + tests/graph/nn/kge/test_complex.py | 58 + tests/graph/nn/kge/test_distmult.py | 24 + tests/graph/nn/kge/test_rotate.py | 24 + tests/graph/nn/kge/test_transe.py | 24 + tests/graph/nn/models/test_attentive_fp.py | 23 + tests/graph/nn/models/test_autoencoder.py | 107 ++ tests/graph/nn/models/test_basic_gnn.py | 409 ++++++ .../nn/models/test_correct_and_smooth.py | 53 + .../nn/models/test_deep_graph_infomax.py | 68 + tests/graph/nn/models/test_deepgcn.py | 25 + tests/graph/nn/models/test_dimenet.py | 66 + tests/graph/nn/models/test_gnnff.py | 23 + tests/graph/nn/models/test_graph_mixer.py | 74 ++ tests/graph/nn/models/test_graph_unet.py | 24 + .../graph/nn/models/test_jumping_knowledge.py | 40 + tests/graph/nn/models/test_label_prop.py | 31 + tests/graph/nn/models/test_lightgcn.py | 72 + tests/graph/nn/models/test_linkx.py | 49 + tests/graph/nn/models/test_mask_label.py | 27 + tests/graph/nn/models/test_meta.py | 127 ++ tests/graph/nn/models/test_metapath2vec.py | 61 + tests/graph/nn/models/test_mlp.py | 90 ++ .../nn/models/test_neural_fingerprint.py | 28 + tests/graph/nn/models/test_node2vec.py | 32 + tests/graph/nn/models/test_pmlp.py | 23 + tests/graph/nn/models/test_re_net.py | 65 + tests/graph/nn/models/test_rect.py | 43 + tests/graph/nn/models/test_rev_gnn.py | 106 ++ tests/graph/nn/models/test_schnet.py | 66 + tests/graph/nn/models/test_signed_gcn.py | 37 + tests/graph/nn/models/test_tgn.py | 82 ++ tests/graph/nn/models/test_visnet.py | 28 + tests/graph/nn/norm/test_batch_norm.py | 70 + tests/graph/nn/norm/test_diff_group_norm.py | 38 + tests/graph/nn/norm/test_graph_norm.py | 26 + tests/graph/nn/norm/test_graph_size_norm.py | 19 + tests/graph/nn/norm/test_instance_norm.py | 52 + tests/graph/nn/norm/test_layer_norm.py | 67 + .../nn/norm/test_mean_subtraction_norm.py | 24 + tests/graph/nn/norm/test_msg_norm.py | 26 + tests/graph/nn/norm/test_pair_norm.py | 24 + .../nn/pool/connect/test_filter_edges.py | 33 + .../graph/nn/pool/select/test_select_topk.py | 89 ++ tests/graph/nn/pool/test_approx_knn.py | 60 + tests/graph/nn/pool/test_asap.py | 40 + tests/graph/nn/pool/test_avg_pool.py | 105 ++ tests/graph/nn/pool/test_consecutive.py | 13 + tests/graph/nn/pool/test_decimation.py | 40 + tests/graph/nn/pool/test_edge_pool.py | 103 ++ tests/graph/nn/pool/test_glob.py | 72 + tests/graph/nn/pool/test_graclus.py | 7 + tests/graph/nn/pool/test_knn.py | 130 ++ tests/graph/nn/pool/test_max_pool.py | 105 ++ tests/graph/nn/pool/test_mem_pool.py | 28 + tests/graph/nn/pool/test_pan_pool.py | 35 + tests/graph/nn/pool/test_pool.py | 18 + tests/graph/nn/pool/test_sag_pool.py | 51 + tests/graph/nn/pool/test_topk_pool.py | 60 + tests/graph/nn/pool/test_voxel_grid.py | 50 + tests/graph/nn/test_encoding.py | 18 + tests/graph/nn/test_inits.py | 62 + tests/graph/nn/test_reshape.py | 13 + tests/graph/nn/test_resolver.py | 111 ++ tests/graph/nn/unpool/test_knn_interpolate.py | 25 + tests/graph/profile/test_benchmark.py | 21 + tests/graph/profile/test_profile.py | 174 +++ tests/graph/profile/test_profile_utils.py | 81 ++ tests/graph/profile/test_profiler.py | 26 + tests/graph/sampler/test_sampler_base.py | 114 ++ tests/graph/sparse/test_add.py | 33 + tests/graph/sparse/test_cat.py | 48 + tests/graph/sparse/test_coalesce.py | 34 + tests/graph/sparse/test_convert.py | 24 + tests/graph/sparse/test_diag.py | 65 + tests/graph/sparse/test_ego_sample.py | 23 + tests/graph/sparse/test_eye.py | 48 + tests/graph/sparse/test_matmul.py | 83 ++ tests/graph/sparse/test_metis.py | 40 + tests/graph/sparse/test_mul.py | 52 + tests/graph/sparse/test_neighbor_sample.py | 43 + tests/graph/sparse/test_overload.py | 26 + tests/graph/sparse/test_permute.py | 17 + tests/graph/sparse/test_saint.py | 16 + tests/graph/sparse/test_sample.py | 36 + tests/graph/sparse/test_scatter.py | 124 ++ tests/graph/sparse/test_segment.py | 30 + tests/graph/sparse/test_spmm.py | 20 + tests/graph/sparse/test_spspmm.py | 48 + tests/graph/sparse/test_storage.py | 143 ++ tests/graph/sparse/test_tensor.py | 104 ++ tests/graph/sparse/test_transpose.py | 30 + tests/graph/test_config_store.py | 147 ++ tests/graph/test_debug.py | 28 + tests/graph/test_edge_index.py | 1183 +++++++++++++++++ tests/graph/test_experimental.py | 28 + tests/graph/test_home.py | 19 + tests/graph/test_inspector.py | 138 ++ tests/graph/test_schnet.py | 37 + tests/graph/test_seed.py | 16 + tests/graph/test_typing.py | 36 + tests/graph/transforms/test_add_metapaths.py | 232 ++++ .../test_add_positional_encoding.py | 114 ++ .../test_add_remaining_self_loops.py | 72 + tests/graph/transforms/test_add_self_loops.py | 58 + tests/graph/transforms/test_cartesian.py | 37 + tests/graph/transforms/test_center.py | 16 + tests/graph/transforms/test_compose.py | 62 + tests/graph/transforms/test_constant.py | 38 + tests/graph/transforms/test_delaunay.py | 32 + tests/graph/transforms/test_distance.py | 31 + tests/graph/transforms/test_face_to_edge.py | 18 + .../transforms/test_feature_propagation.py | 29 + tests/graph/transforms/test_fixed_points.py | 64 + tests/graph/transforms/test_gcn_norm.py | 47 + tests/graph/transforms/test_gdc.py | 103 ++ .../transforms/test_generate_mesh_normals.py | 29 + tests/graph/transforms/test_grid_sampling.py | 25 + tests/graph/transforms/test_half_hop.py | 45 + tests/graph/transforms/test_knn_graph.py | 27 + .../transforms/test_laplacian_lambda_max.py | 33 + .../test_largest_connected_components.py | 46 + tests/graph/transforms/test_line_graph.py | 33 + .../transforms/test_linear_transformation.py | 26 + .../graph/transforms/test_local_cartesian.py | 29 + .../transforms/test_local_degree_profile.py | 27 + tests/graph/transforms/test_mask_transform.py | 92 ++ .../transforms/test_node_property_split.py | 39 + .../transforms/test_normalize_features.py | 27 + .../transforms/test_normalize_rotation.py | 57 + .../graph/transforms/test_normalize_scale.py | 17 + tests/graph/transforms/test_one_hot_degree.py | 34 + tests/graph/transforms/test_pad.py | 585 ++++++++ .../transforms/test_point_pair_features.py | 40 + tests/graph/transforms/test_polar.py | 36 + tests/graph/transforms/test_radius_graph.py | 25 + tests/graph/transforms/test_random_flip.py | 20 + tests/graph/transforms/test_random_jitter.py | 29 + .../transforms/test_random_link_split.py | 323 +++++ .../transforms/test_random_node_split.py | 159 +++ tests/graph/transforms/test_random_rotate.py | 46 + tests/graph/transforms/test_random_scale.py | 20 + tests/graph/transforms/test_random_shear.py | 20 + .../test_remove_duplicated_edges.py | 20 + .../transforms/test_remove_isolated_nodes.py | 51 + .../test_remove_training_classes.py | 19 + .../graph/transforms/test_rooted_subgraph.py | 87 ++ tests/graph/transforms/test_sample_points.py | 31 + tests/graph/transforms/test_sign.py | 32 + tests/graph/transforms/test_spherical.py | 61 + .../transforms/test_svd_feature_reduction.py | 19 + .../graph/transforms/test_target_indegree.py | 25 + tests/graph/transforms/test_to_dense.py | 54 + .../graph/transforms/test_to_sparse_tensor.py | 124 ++ tests/graph/transforms/test_to_superpixels.py | 90 ++ tests/graph/transforms/test_to_undirected.py | 67 + tests/graph/transforms/test_two_hop.py | 27 + tests/graph/transforms/test_virtual_node.py | 38 + tests/graph/utils/test_assortativity.py | 29 + tests/graph/utils/test_augmentation.py | 96 ++ tests/graph/utils/test_cluster.py | 27 + tests/graph/utils/test_coalesce.py | 53 + tests/graph/utils/test_convert.py | 399 ++++++ tests/graph/utils/test_degree.py | 9 + tests/graph/utils/test_dropout.py | 44 + tests/graph/utils/test_embedding.py | 34 + tests/graph/utils/test_functions.py | 15 + tests/graph/utils/test_geodesic.py | 48 + tests/graph/utils/test_grid.py | 25 + tests/graph/utils/test_hetero.py | 38 + tests/graph/utils/test_homophily.py | 32 + tests/graph/utils/test_isolated.py | 45 + tests/graph/utils/test_laplacian.py | 28 + tests/graph/utils/test_lexsort.py | 11 + tests/graph/utils/test_loop.py | 187 +++ tests/graph/utils/test_map.py | 74 ++ tests/graph/utils/test_mask.py | 28 + tests/graph/utils/test_mesh_laplacian.py | 101 ++ tests/graph/utils/test_negative_sampling.py | 169 +++ tests/graph/utils/test_noise_scheduler.py | 34 + tests/graph/utils/test_normalized_cut.py | 20 + tests/graph/utils/test_num_nodes.py | 25 + tests/graph/utils/test_ppr.py | 27 + tests/graph/utils/test_random.py | 22 + tests/graph/utils/test_repeat.py | 9 + tests/graph/utils/test_scatter.py | 126 ++ tests/graph/utils/test_segment.py | 30 + tests/graph/utils/test_select.py | 23 + tests/graph/utils/test_softmax.py | 82 ++ tests/graph/utils/test_sort_edge_index.py | 72 + tests/graph/utils/test_sparse.py | 202 +++ tests/graph/utils/test_spmm.py | 144 ++ tests/graph/utils/test_subgraph.py | 122 ++ tests/graph/utils/test_to_dense_adj.py | 91 ++ tests/graph/utils/test_to_dense_batch.py | 90 ++ tests/graph/utils/test_tree_decomposition.py | 16 + tests/graph/utils/test_trim_to_layer.py | 154 +++ tests/graph/utils/test_unbatch.py | 25 + tests/graph/utils/test_undirected.py | 40 + 550 files changed, 58072 insertions(+), 1 deletion(-) rename mindscience/{gnn => sharker}/__init__.py (56%) create mode 100644 mindscience/sharker/data/__init__.py create mode 100644 mindscience/sharker/data/batch.py create mode 100644 mindscience/sharker/data/collate.py create mode 100644 mindscience/sharker/data/database.py create mode 100644 mindscience/sharker/data/datapipe.py create mode 100644 mindscience/sharker/data/dataset.py create mode 100644 mindscience/sharker/data/download.py create mode 100644 mindscience/sharker/data/extract.py create mode 100644 mindscience/sharker/data/graph.py create mode 100644 mindscience/sharker/data/heterograph.py create mode 100644 mindscience/sharker/data/hypergraph.py create mode 100644 mindscience/sharker/data/in_memory.py create mode 100644 mindscience/sharker/data/on_disk.py create mode 100644 mindscience/sharker/data/remote_store.py create mode 100644 mindscience/sharker/data/separate.py create mode 100644 mindscience/sharker/data/storage.py create mode 100644 mindscience/sharker/data/summary.py create mode 100644 mindscience/sharker/data/temporal.py create mode 100644 mindscience/sharker/data/view.py create mode 100644 mindscience/sharker/dataset/__init__.py create mode 100644 mindscience/sharker/dataset/qm9.py create mode 100644 mindscience/sharker/experimental.py create mode 100644 mindscience/sharker/home.py create mode 100644 mindscience/sharker/inspector.py create mode 100644 mindscience/sharker/io/__init__.py create mode 100644 mindscience/sharker/io/fs.py create mode 100644 mindscience/sharker/io/npz.py create mode 100644 mindscience/sharker/io/obj.py create mode 100644 mindscience/sharker/io/off.py create mode 100644 mindscience/sharker/io/planetoid.py create mode 100644 mindscience/sharker/io/ply.py create mode 100644 mindscience/sharker/io/sdf.py create mode 100644 mindscience/sharker/io/tu.py create mode 100644 mindscience/sharker/io/txt_array.py create mode 100644 mindscience/sharker/loader/__init__.py create mode 100644 mindscience/sharker/loader/dataloader.py create mode 100644 mindscience/sharker/nn/__init__.py create mode 100644 mindscience/sharker/nn/aggr/__init__.py create mode 100644 mindscience/sharker/nn/aggr/attention.py create mode 100644 mindscience/sharker/nn/aggr/base.py create mode 100644 mindscience/sharker/nn/aggr/basic.py create mode 100644 mindscience/sharker/nn/aggr/deep_sets.py create mode 100644 mindscience/sharker/nn/aggr/equilibrium.py create mode 100644 mindscience/sharker/nn/aggr/fused.py create mode 100644 mindscience/sharker/nn/aggr/gmt.py create mode 100644 mindscience/sharker/nn/aggr/gru.py create mode 100644 mindscience/sharker/nn/aggr/lcm.py create mode 100644 mindscience/sharker/nn/aggr/lstm.py create mode 100644 mindscience/sharker/nn/aggr/mlp.py create mode 100644 mindscience/sharker/nn/aggr/multi.py create mode 100644 mindscience/sharker/nn/aggr/quantile.py create mode 100644 mindscience/sharker/nn/aggr/scaler.py create mode 100644 mindscience/sharker/nn/aggr/set2set.py create mode 100644 mindscience/sharker/nn/aggr/set_transformer.py create mode 100644 mindscience/sharker/nn/aggr/sort.py create mode 100644 mindscience/sharker/nn/aggr/utils.py create mode 100644 mindscience/sharker/nn/aggr/variance_preserving.py create mode 100644 mindscience/sharker/nn/conv/__init__.py create mode 100644 mindscience/sharker/nn/conv/agnn_conv.py create mode 100644 mindscience/sharker/nn/conv/antisymmetric_conv.py create mode 100644 mindscience/sharker/nn/conv/appnp.py create mode 100644 mindscience/sharker/nn/conv/arma_conv.py create mode 100644 mindscience/sharker/nn/conv/cg_conv copy.py create mode 100644 mindscience/sharker/nn/conv/cg_conv.py create mode 100644 mindscience/sharker/nn/conv/cheb_conv.py create mode 100644 mindscience/sharker/nn/conv/cluster_gcn_conv.py create mode 100644 mindscience/sharker/nn/conv/dir_gnn_conv.py create mode 100644 mindscience/sharker/nn/conv/dna_conv.py create mode 100644 mindscience/sharker/nn/conv/edge_conv.py create mode 100644 mindscience/sharker/nn/conv/eg_conv.py create mode 100644 mindscience/sharker/nn/conv/fa_conv.py create mode 100644 mindscience/sharker/nn/conv/feast_conv.py create mode 100644 mindscience/sharker/nn/conv/film_conv.py create mode 100644 mindscience/sharker/nn/conv/gat_conv.py create mode 100644 mindscience/sharker/nn/conv/gated_graph_conv.py create mode 100644 mindscience/sharker/nn/conv/gatv2_conv.py create mode 100644 mindscience/sharker/nn/conv/gcn_conv.py create mode 100644 mindscience/sharker/nn/conv/gen_conv.py create mode 100644 mindscience/sharker/nn/conv/general_conv.py create mode 100644 mindscience/sharker/nn/conv/gin_conv.py create mode 100644 mindscience/sharker/nn/conv/gmm_conv.py create mode 100644 mindscience/sharker/nn/conv/graph_conv.py create mode 100644 mindscience/sharker/nn/conv/gravnet_conv.py create mode 100644 mindscience/sharker/nn/conv/heat_conv.py create mode 100644 mindscience/sharker/nn/conv/hetero_conv.py create mode 100644 mindscience/sharker/nn/conv/hgt_conv.py create mode 100644 mindscience/sharker/nn/conv/hypergraph_conv.py create mode 100644 mindscience/sharker/nn/conv/le_conv.py create mode 100644 mindscience/sharker/nn/conv/lg_conv.py create mode 100644 mindscience/sharker/nn/conv/message_passing.py create mode 100644 mindscience/sharker/nn/conv/mf_conv.py create mode 100644 mindscience/sharker/nn/conv/mixhop_conv.py create mode 100644 mindscience/sharker/nn/conv/nn_conv.py create mode 100644 mindscience/sharker/nn/conv/pdn_conv.py create mode 100644 mindscience/sharker/nn/conv/point_conv.py create mode 100644 mindscience/sharker/nn/conv/point_gnn_conv.py create mode 100644 mindscience/sharker/nn/conv/point_transformer_conv.py create mode 100644 mindscience/sharker/nn/conv/ppf_conv.py create mode 100644 mindscience/sharker/nn/conv/res_gated_graph_conv.py create mode 100644 mindscience/sharker/nn/conv/rgat_conv.py create mode 100644 mindscience/sharker/nn/conv/sage_conv.py create mode 100644 mindscience/sharker/nn/conv/sg_conv.py create mode 100644 mindscience/sharker/nn/conv/signed_conv.py create mode 100644 mindscience/sharker/nn/conv/simple_conv.py create mode 100644 mindscience/sharker/nn/conv/spline_conv.py create mode 100644 mindscience/sharker/nn/conv/ssg_conv.py create mode 100644 mindscience/sharker/nn/conv/tag_conv.py create mode 100644 mindscience/sharker/nn/conv/transformer_conv.py create mode 100644 mindscience/sharker/nn/conv/wl_conv.py create mode 100644 mindscience/sharker/nn/conv/wl_conv_continuous.py create mode 100644 mindscience/sharker/nn/conv/x_conv.py create mode 100644 mindscience/sharker/nn/dense/__init__.py create mode 100644 mindscience/sharker/nn/dense/linear.py create mode 100644 mindscience/sharker/nn/encoding.py create mode 100644 mindscience/sharker/nn/inits.py create mode 100644 mindscience/sharker/nn/lr_scheduler.py create mode 100644 mindscience/sharker/nn/models/__init__.py create mode 100644 mindscience/sharker/nn/models/mlp.py create mode 100644 mindscience/sharker/nn/norm/__init__.py create mode 100644 mindscience/sharker/nn/norm/batch_norm.py create mode 100644 mindscience/sharker/nn/norm/msg_norm.py create mode 100644 mindscience/sharker/nn/reshape.py create mode 100644 mindscience/sharker/nn/resolver.py create mode 100644 mindscience/sharker/profile/__init__.py create mode 100644 mindscience/sharker/profile/benchmark.py create mode 100644 mindscience/sharker/profile/utils.py create mode 100644 mindscience/sharker/resolver.py create mode 100644 mindscience/sharker/seed.py create mode 100644 mindscience/sharker/template.py create mode 100644 mindscience/sharker/testing/__init__.py create mode 100644 mindscience/sharker/testing/asserts.py create mode 100644 mindscience/sharker/testing/data.py create mode 100644 mindscience/sharker/testing/decorators.py create mode 100644 mindscience/sharker/testing/distributed.py create mode 100644 mindscience/sharker/typing.py create mode 100644 mindscience/sharker/utils/__init__.py create mode 100644 mindscience/sharker/utils/_scatter.py create mode 100644 mindscience/sharker/utils/_segment.py create mode 100644 mindscience/sharker/utils/assortativity.py create mode 100644 mindscience/sharker/utils/augmentation.py create mode 100644 mindscience/sharker/utils/cluster.py create mode 100644 mindscience/sharker/utils/coalesce.py create mode 100644 mindscience/sharker/utils/convert.py create mode 100644 mindscience/sharker/utils/degree.py create mode 100644 mindscience/sharker/utils/dropout.py create mode 100644 mindscience/sharker/utils/embedding.py create mode 100644 mindscience/sharker/utils/functions.py create mode 100644 mindscience/sharker/utils/grid.py create mode 100644 mindscience/sharker/utils/hetero.py create mode 100644 mindscience/sharker/utils/homophily.py create mode 100644 mindscience/sharker/utils/isolated.py create mode 100644 mindscience/sharker/utils/laplacian.py create mode 100644 mindscience/sharker/utils/loop.py create mode 100644 mindscience/sharker/utils/map.py create mode 100644 mindscience/sharker/utils/mask.py create mode 100644 mindscience/sharker/utils/mixin.py create mode 100644 mindscience/sharker/utils/ncon.py create mode 100644 mindscience/sharker/utils/negative_sampling.py create mode 100644 mindscience/sharker/utils/noise_scheduler.py create mode 100644 mindscience/sharker/utils/normalize.py create mode 100644 mindscience/sharker/utils/num_nodes.py create mode 100644 mindscience/sharker/utils/random.py create mode 100644 mindscience/sharker/utils/repeat.py create mode 100644 mindscience/sharker/utils/select.py create mode 100644 mindscience/sharker/utils/softmax.py create mode 100644 mindscience/sharker/utils/sort_edge_index.py create mode 100644 mindscience/sharker/utils/sparse.py create mode 100644 mindscience/sharker/utils/subgraph.py create mode 100644 mindscience/sharker/utils/to_dense_adj.py create mode 100644 mindscience/sharker/utils/to_dense_batch.py create mode 100644 mindscience/sharker/utils/tree_decomposition.py create mode 100644 mindscience/sharker/utils/trim_to_layer.py create mode 100644 mindscience/sharker/utils/unbatch.py create mode 100644 mindscience/sharker/utils/undirected.py create mode 100644 tests/graph/cluster/test_fps.py create mode 100644 tests/graph/cluster/test_graclus.py create mode 100644 tests/graph/cluster/test_grid.py create mode 100644 tests/graph/cluster/test_knn.py create mode 100644 tests/graph/cluster/test_nearest.py create mode 100644 tests/graph/cluster/test_radius.py create mode 100644 tests/graph/cluster/test_rw.py create mode 100644 tests/graph/conftest.py create mode 100644 tests/graph/data/test_batch.py create mode 100644 tests/graph/data/test_data.py create mode 100644 tests/graph/data/test_database.py create mode 100644 tests/graph/data/test_dataloader.py create mode 100644 tests/graph/data/test_dataset.py create mode 100644 tests/graph/data/test_dataset_summary.py create mode 100644 tests/graph/data/test_feature_store.py create mode 100644 tests/graph/data/test_graph_store.py create mode 100644 tests/graph/data/test_hetero.py create mode 100644 tests/graph/data/test_hypergraph.py create mode 100644 tests/graph/data/test_inherit.py create mode 100644 tests/graph/data/test_on_disk_dataset.py create mode 100644 tests/graph/data/test_remote_backend_utils.py create mode 100644 tests/graph/data/test_storage.py create mode 100644 tests/graph/data/test_temporal.py create mode 100644 tests/graph/data/test_view.py create mode 100644 tests/graph/datasets/graph_generator/test_ba_graph.py create mode 100644 tests/graph/datasets/graph_generator/test_er_graph.py create mode 100644 tests/graph/datasets/graph_generator/test_grid_graph.py create mode 100644 tests/graph/datasets/graph_generator/test_tree_graph.py create mode 100644 tests/graph/datasets/motif_generator/test_custom_motif.py create mode 100644 tests/graph/datasets/motif_generator/test_cycle_motif.py create mode 100644 tests/graph/datasets/motif_generator/test_grid_motif.py create mode 100644 tests/graph/datasets/motif_generator/test_house_motif.py create mode 100644 tests/graph/datasets/test_ba_shapes.py create mode 100644 tests/graph/datasets/test_bzr.py create mode 100644 tests/graph/datasets/test_elliptic.py create mode 100644 tests/graph/datasets/test_enzymes.py create mode 100644 tests/graph/datasets/test_explainer_dataset.py create mode 100644 tests/graph/datasets/test_fake.py create mode 100644 tests/graph/datasets/test_imdb_binary.py create mode 100644 tests/graph/datasets/test_infection_dataset.py create mode 100644 tests/graph/datasets/test_karate.py create mode 100644 tests/graph/datasets/test_mutag.py create mode 100644 tests/graph/datasets/test_planetoid.py create mode 100644 tests/graph/datasets/test_snap_dataset.py create mode 100644 tests/graph/datasets/test_suite_sparse.py create mode 100644 tests/graph/explain/algorithm/test_attention_explainer.py create mode 100644 tests/graph/explain/algorithm/test_captum.py create mode 100644 tests/graph/explain/algorithm/test_captum_explainer.py create mode 100644 tests/graph/explain/algorithm/test_captum_hetero.py create mode 100644 tests/graph/explain/algorithm/test_explain_algorithm_utils.py create mode 100644 tests/graph/explain/algorithm/test_gnn_explainer.py create mode 100644 tests/graph/explain/algorithm/test_graphmask_explainer.py create mode 100644 tests/graph/explain/algorithm/test_pg_explainer.py create mode 100644 tests/graph/explain/conftest.py create mode 100644 tests/graph/explain/metric/test_basic_metric.py create mode 100644 tests/graph/explain/metric/test_faithfulness.py create mode 100644 tests/graph/explain/metric/test_fidelity.py create mode 100644 tests/graph/explain/test_explain_config.py create mode 100644 tests/graph/explain/test_explainer.py create mode 100644 tests/graph/explain/test_explanation.py create mode 100644 tests/graph/explain/test_hetero_explainer.py create mode 100644 tests/graph/explain/test_hetero_explanation.py create mode 100644 tests/graph/io/example1.off create mode 100644 tests/graph/io/example2.off create mode 100644 tests/graph/io/test_fs.py create mode 100644 tests/graph/io/test_off.py create mode 100644 tests/graph/loader/test_cache.py create mode 100644 tests/graph/loader/test_cluster.py create mode 100644 tests/graph/loader/test_common.py create mode 100644 tests/graph/loader/test_dataloader.py create mode 100644 tests/graph/loader/test_dynamic_batch_sampler.py create mode 100644 tests/graph/loader/test_graph_saint.py create mode 100644 tests/graph/loader/test_hgt_loader.py create mode 100644 tests/graph/loader/test_ibmb_loader.py create mode 100644 tests/graph/loader/test_imbalanced_sampler.py create mode 100644 tests/graph/loader/test_link_neighbor_loader.py create mode 100644 tests/graph/loader/test_mixin.py create mode 100644 tests/graph/loader/test_neighbor_loader.py create mode 100644 tests/graph/loader/test_neighbor_sampler.py create mode 100644 tests/graph/loader/test_random_node_loader.py create mode 100644 tests/graph/loader/test_shadow.py create mode 100644 tests/graph/loader/test_temporal_dataloader.py create mode 100644 tests/graph/loader/test_utils.py create mode 100644 tests/graph/loader/test_zip_loader.py create mode 100644 tests/graph/metrics/test_link_pred_metric.py create mode 100644 tests/graph/my_config.yaml create mode 100644 tests/graph/nn/aggr/test_aggr_utils.py create mode 100644 tests/graph/nn/aggr/test_attention.py create mode 100644 tests/graph/nn/aggr/test_basic.py create mode 100644 tests/graph/nn/aggr/test_deep_sets.py create mode 100644 tests/graph/nn/aggr/test_equilibrium.py create mode 100644 tests/graph/nn/aggr/test_fused.py create mode 100644 tests/graph/nn/aggr/test_gmt.py create mode 100644 tests/graph/nn/aggr/test_gru.py create mode 100644 tests/graph/nn/aggr/test_lcm.py create mode 100644 tests/graph/nn/aggr/test_lstm.py create mode 100644 tests/graph/nn/aggr/test_mlp_aggr.py create mode 100644 tests/graph/nn/aggr/test_multi.py create mode 100644 tests/graph/nn/aggr/test_quantile.py create mode 100644 tests/graph/nn/aggr/test_scaler.py create mode 100644 tests/graph/nn/aggr/test_set2set.py create mode 100644 tests/graph/nn/aggr/test_set_transformer.py create mode 100644 tests/graph/nn/aggr/test_sort.py create mode 100644 tests/graph/nn/aggr/test_variance_preserving.py create mode 100644 tests/graph/nn/attn/test_performer.py create mode 100644 tests/graph/nn/conv/test_agnn_conv.py create mode 100644 tests/graph/nn/conv/test_antisymmetric_conv.py create mode 100644 tests/graph/nn/conv/test_appnp.py create mode 100644 tests/graph/nn/conv/test_arma_conv.py create mode 100644 tests/graph/nn/conv/test_cg_conv.py create mode 100644 tests/graph/nn/conv/test_cheb_conv.py create mode 100644 tests/graph/nn/conv/test_cluster_gcn_conv.py create mode 100644 tests/graph/nn/conv/test_create_gnn.py create mode 100644 tests/graph/nn/conv/test_dir_gnn_conv.py create mode 100644 tests/graph/nn/conv/test_dna_conv.py create mode 100644 tests/graph/nn/conv/test_edge_conv.py create mode 100644 tests/graph/nn/conv/test_eg_conv.py create mode 100644 tests/graph/nn/conv/test_fa_conv.py create mode 100644 tests/graph/nn/conv/test_feast_conv.py create mode 100644 tests/graph/nn/conv/test_film_conv.py create mode 100644 tests/graph/nn/conv/test_fused_gat_conv.py create mode 100644 tests/graph/nn/conv/test_gat_conv.py create mode 100644 tests/graph/nn/conv/test_gated_graph_conv.py create mode 100644 tests/graph/nn/conv/test_gatv2_conv.py create mode 100644 tests/graph/nn/conv/test_gcn2_conv.py create mode 100644 tests/graph/nn/conv/test_gcn_conv.py create mode 100644 tests/graph/nn/conv/test_gen_conv.py create mode 100644 tests/graph/nn/conv/test_general_conv.py create mode 100644 tests/graph/nn/conv/test_gin_conv.py create mode 100644 tests/graph/nn/conv/test_gmm_conv.py create mode 100644 tests/graph/nn/conv/test_gps_conv.py create mode 100644 tests/graph/nn/conv/test_graph_conv.py create mode 100644 tests/graph/nn/conv/test_gravnet_conv.py create mode 100644 tests/graph/nn/conv/test_han_conv.py create mode 100644 tests/graph/nn/conv/test_heat_conv.py create mode 100644 tests/graph/nn/conv/test_hetero_conv.py create mode 100644 tests/graph/nn/conv/test_hgt_conv.py create mode 100644 tests/graph/nn/conv/test_hypergraph_conv.py create mode 100644 tests/graph/nn/conv/test_le_conv.py create mode 100644 tests/graph/nn/conv/test_lg_conv.py create mode 100644 tests/graph/nn/conv/test_message_passing.py create mode 100644 tests/graph/nn/conv/test_mf_conv.py create mode 100644 tests/graph/nn/conv/test_mixhop_conv.py create mode 100644 tests/graph/nn/conv/test_nn_conv.py create mode 100644 tests/graph/nn/conv/test_pan_conv.py create mode 100644 tests/graph/nn/conv/test_pdn_conv.py create mode 100644 tests/graph/nn/conv/test_pna_conv.py create mode 100644 tests/graph/nn/conv/test_point_conv.py create mode 100644 tests/graph/nn/conv/test_point_gnn_conv.py create mode 100644 tests/graph/nn/conv/test_point_transformer_conv.py create mode 100644 tests/graph/nn/conv/test_ppf_conv.py create mode 100644 tests/graph/nn/conv/test_res_gated_graph_conv.py create mode 100644 tests/graph/nn/conv/test_rgat_conv.py create mode 100644 tests/graph/nn/conv/test_rgcn_conv.py create mode 100644 tests/graph/nn/conv/test_sage_conv.py create mode 100644 tests/graph/nn/conv/test_sg_conv.py create mode 100644 tests/graph/nn/conv/test_signed_conv.py create mode 100644 tests/graph/nn/conv/test_simple_conv.py create mode 100644 tests/graph/nn/conv/test_spline_conv.py create mode 100644 tests/graph/nn/conv/test_ssg_conv.py create mode 100644 tests/graph/nn/conv/test_static_graph.py create mode 100644 tests/graph/nn/conv/test_supergat_conv.py create mode 100644 tests/graph/nn/conv/test_tag_conv.py create mode 100644 tests/graph/nn/conv/test_transformer_conv.py create mode 100644 tests/graph/nn/conv/test_wl_conv.py create mode 100644 tests/graph/nn/conv/test_wl_conv_continuous.py create mode 100644 tests/graph/nn/conv/test_x_conv.py create mode 100644 tests/graph/nn/dense/test_dense_gat_conv.py create mode 100644 tests/graph/nn/dense/test_dense_gcn_conv.py create mode 100644 tests/graph/nn/dense/test_dense_gin_conv.py create mode 100644 tests/graph/nn/dense/test_dense_graph_conv.py create mode 100644 tests/graph/nn/dense/test_dense_sage_conv.py create mode 100644 tests/graph/nn/dense/test_diff_pool.py create mode 100644 tests/graph/nn/dense/test_dmon_pool.py create mode 100644 tests/graph/nn/dense/test_linear.py create mode 100644 tests/graph/nn/dense/test_mincut_pool.py create mode 100644 tests/graph/nn/kge/test_complex.py create mode 100644 tests/graph/nn/kge/test_distmult.py create mode 100644 tests/graph/nn/kge/test_rotate.py create mode 100644 tests/graph/nn/kge/test_transe.py create mode 100644 tests/graph/nn/models/test_attentive_fp.py create mode 100644 tests/graph/nn/models/test_autoencoder.py create mode 100644 tests/graph/nn/models/test_basic_gnn.py create mode 100644 tests/graph/nn/models/test_correct_and_smooth.py create mode 100644 tests/graph/nn/models/test_deep_graph_infomax.py create mode 100644 tests/graph/nn/models/test_deepgcn.py create mode 100644 tests/graph/nn/models/test_dimenet.py create mode 100644 tests/graph/nn/models/test_gnnff.py create mode 100644 tests/graph/nn/models/test_graph_mixer.py create mode 100644 tests/graph/nn/models/test_graph_unet.py create mode 100644 tests/graph/nn/models/test_jumping_knowledge.py create mode 100644 tests/graph/nn/models/test_label_prop.py create mode 100644 tests/graph/nn/models/test_lightgcn.py create mode 100644 tests/graph/nn/models/test_linkx.py create mode 100644 tests/graph/nn/models/test_mask_label.py create mode 100644 tests/graph/nn/models/test_meta.py create mode 100644 tests/graph/nn/models/test_metapath2vec.py create mode 100644 tests/graph/nn/models/test_mlp.py create mode 100644 tests/graph/nn/models/test_neural_fingerprint.py create mode 100644 tests/graph/nn/models/test_node2vec.py create mode 100644 tests/graph/nn/models/test_pmlp.py create mode 100644 tests/graph/nn/models/test_re_net.py create mode 100644 tests/graph/nn/models/test_rect.py create mode 100644 tests/graph/nn/models/test_rev_gnn.py create mode 100644 tests/graph/nn/models/test_schnet.py create mode 100644 tests/graph/nn/models/test_signed_gcn.py create mode 100644 tests/graph/nn/models/test_tgn.py create mode 100644 tests/graph/nn/models/test_visnet.py create mode 100644 tests/graph/nn/norm/test_batch_norm.py create mode 100644 tests/graph/nn/norm/test_diff_group_norm.py create mode 100644 tests/graph/nn/norm/test_graph_norm.py create mode 100644 tests/graph/nn/norm/test_graph_size_norm.py create mode 100644 tests/graph/nn/norm/test_instance_norm.py create mode 100644 tests/graph/nn/norm/test_layer_norm.py create mode 100644 tests/graph/nn/norm/test_mean_subtraction_norm.py create mode 100644 tests/graph/nn/norm/test_msg_norm.py create mode 100644 tests/graph/nn/norm/test_pair_norm.py create mode 100644 tests/graph/nn/pool/connect/test_filter_edges.py create mode 100644 tests/graph/nn/pool/select/test_select_topk.py create mode 100644 tests/graph/nn/pool/test_approx_knn.py create mode 100644 tests/graph/nn/pool/test_asap.py create mode 100644 tests/graph/nn/pool/test_avg_pool.py create mode 100644 tests/graph/nn/pool/test_consecutive.py create mode 100644 tests/graph/nn/pool/test_decimation.py create mode 100644 tests/graph/nn/pool/test_edge_pool.py create mode 100644 tests/graph/nn/pool/test_glob.py create mode 100644 tests/graph/nn/pool/test_graclus.py create mode 100644 tests/graph/nn/pool/test_knn.py create mode 100644 tests/graph/nn/pool/test_max_pool.py create mode 100644 tests/graph/nn/pool/test_mem_pool.py create mode 100644 tests/graph/nn/pool/test_pan_pool.py create mode 100644 tests/graph/nn/pool/test_pool.py create mode 100644 tests/graph/nn/pool/test_sag_pool.py create mode 100644 tests/graph/nn/pool/test_topk_pool.py create mode 100644 tests/graph/nn/pool/test_voxel_grid.py create mode 100644 tests/graph/nn/test_encoding.py create mode 100644 tests/graph/nn/test_inits.py create mode 100644 tests/graph/nn/test_reshape.py create mode 100644 tests/graph/nn/test_resolver.py create mode 100644 tests/graph/nn/unpool/test_knn_interpolate.py create mode 100644 tests/graph/profile/test_benchmark.py create mode 100644 tests/graph/profile/test_profile.py create mode 100644 tests/graph/profile/test_profile_utils.py create mode 100644 tests/graph/profile/test_profiler.py create mode 100644 tests/graph/sampler/test_sampler_base.py create mode 100644 tests/graph/sparse/test_add.py create mode 100644 tests/graph/sparse/test_cat.py create mode 100644 tests/graph/sparse/test_coalesce.py create mode 100644 tests/graph/sparse/test_convert.py create mode 100644 tests/graph/sparse/test_diag.py create mode 100644 tests/graph/sparse/test_ego_sample.py create mode 100644 tests/graph/sparse/test_eye.py create mode 100644 tests/graph/sparse/test_matmul.py create mode 100644 tests/graph/sparse/test_metis.py create mode 100644 tests/graph/sparse/test_mul.py create mode 100644 tests/graph/sparse/test_neighbor_sample.py create mode 100644 tests/graph/sparse/test_overload.py create mode 100644 tests/graph/sparse/test_permute.py create mode 100644 tests/graph/sparse/test_saint.py create mode 100644 tests/graph/sparse/test_sample.py create mode 100644 tests/graph/sparse/test_scatter.py create mode 100644 tests/graph/sparse/test_segment.py create mode 100644 tests/graph/sparse/test_spmm.py create mode 100644 tests/graph/sparse/test_spspmm.py create mode 100644 tests/graph/sparse/test_storage.py create mode 100644 tests/graph/sparse/test_tensor.py create mode 100644 tests/graph/sparse/test_transpose.py create mode 100644 tests/graph/test_config_store.py create mode 100644 tests/graph/test_debug.py create mode 100644 tests/graph/test_edge_index.py create mode 100644 tests/graph/test_experimental.py create mode 100644 tests/graph/test_home.py create mode 100644 tests/graph/test_inspector.py create mode 100644 tests/graph/test_schnet.py create mode 100644 tests/graph/test_seed.py create mode 100644 tests/graph/test_typing.py create mode 100644 tests/graph/transforms/test_add_metapaths.py create mode 100644 tests/graph/transforms/test_add_positional_encoding.py create mode 100644 tests/graph/transforms/test_add_remaining_self_loops.py create mode 100644 tests/graph/transforms/test_add_self_loops.py create mode 100644 tests/graph/transforms/test_cartesian.py create mode 100644 tests/graph/transforms/test_center.py create mode 100644 tests/graph/transforms/test_compose.py create mode 100644 tests/graph/transforms/test_constant.py create mode 100644 tests/graph/transforms/test_delaunay.py create mode 100644 tests/graph/transforms/test_distance.py create mode 100644 tests/graph/transforms/test_face_to_edge.py create mode 100644 tests/graph/transforms/test_feature_propagation.py create mode 100644 tests/graph/transforms/test_fixed_points.py create mode 100644 tests/graph/transforms/test_gcn_norm.py create mode 100644 tests/graph/transforms/test_gdc.py create mode 100644 tests/graph/transforms/test_generate_mesh_normals.py create mode 100644 tests/graph/transforms/test_grid_sampling.py create mode 100644 tests/graph/transforms/test_half_hop.py create mode 100644 tests/graph/transforms/test_knn_graph.py create mode 100644 tests/graph/transforms/test_laplacian_lambda_max.py create mode 100644 tests/graph/transforms/test_largest_connected_components.py create mode 100644 tests/graph/transforms/test_line_graph.py create mode 100644 tests/graph/transforms/test_linear_transformation.py create mode 100644 tests/graph/transforms/test_local_cartesian.py create mode 100644 tests/graph/transforms/test_local_degree_profile.py create mode 100644 tests/graph/transforms/test_mask_transform.py create mode 100644 tests/graph/transforms/test_node_property_split.py create mode 100644 tests/graph/transforms/test_normalize_features.py create mode 100644 tests/graph/transforms/test_normalize_rotation.py create mode 100644 tests/graph/transforms/test_normalize_scale.py create mode 100644 tests/graph/transforms/test_one_hot_degree.py create mode 100644 tests/graph/transforms/test_pad.py create mode 100644 tests/graph/transforms/test_point_pair_features.py create mode 100644 tests/graph/transforms/test_polar.py create mode 100644 tests/graph/transforms/test_radius_graph.py create mode 100644 tests/graph/transforms/test_random_flip.py create mode 100644 tests/graph/transforms/test_random_jitter.py create mode 100644 tests/graph/transforms/test_random_link_split.py create mode 100644 tests/graph/transforms/test_random_node_split.py create mode 100644 tests/graph/transforms/test_random_rotate.py create mode 100644 tests/graph/transforms/test_random_scale.py create mode 100644 tests/graph/transforms/test_random_shear.py create mode 100644 tests/graph/transforms/test_remove_duplicated_edges.py create mode 100644 tests/graph/transforms/test_remove_isolated_nodes.py create mode 100644 tests/graph/transforms/test_remove_training_classes.py create mode 100644 tests/graph/transforms/test_rooted_subgraph.py create mode 100644 tests/graph/transforms/test_sample_points.py create mode 100644 tests/graph/transforms/test_sign.py create mode 100644 tests/graph/transforms/test_spherical.py create mode 100644 tests/graph/transforms/test_svd_feature_reduction.py create mode 100644 tests/graph/transforms/test_target_indegree.py create mode 100644 tests/graph/transforms/test_to_dense.py create mode 100644 tests/graph/transforms/test_to_sparse_tensor.py create mode 100644 tests/graph/transforms/test_to_superpixels.py create mode 100644 tests/graph/transforms/test_to_undirected.py create mode 100644 tests/graph/transforms/test_two_hop.py create mode 100644 tests/graph/transforms/test_virtual_node.py create mode 100644 tests/graph/utils/test_assortativity.py create mode 100644 tests/graph/utils/test_augmentation.py create mode 100644 tests/graph/utils/test_cluster.py create mode 100644 tests/graph/utils/test_coalesce.py create mode 100644 tests/graph/utils/test_convert.py create mode 100644 tests/graph/utils/test_degree.py create mode 100644 tests/graph/utils/test_dropout.py create mode 100644 tests/graph/utils/test_embedding.py create mode 100644 tests/graph/utils/test_functions.py create mode 100644 tests/graph/utils/test_geodesic.py create mode 100644 tests/graph/utils/test_grid.py create mode 100644 tests/graph/utils/test_hetero.py create mode 100644 tests/graph/utils/test_homophily.py create mode 100644 tests/graph/utils/test_isolated.py create mode 100644 tests/graph/utils/test_laplacian.py create mode 100644 tests/graph/utils/test_lexsort.py create mode 100644 tests/graph/utils/test_loop.py create mode 100644 tests/graph/utils/test_map.py create mode 100644 tests/graph/utils/test_mask.py create mode 100644 tests/graph/utils/test_mesh_laplacian.py create mode 100644 tests/graph/utils/test_negative_sampling.py create mode 100644 tests/graph/utils/test_noise_scheduler.py create mode 100644 tests/graph/utils/test_normalized_cut.py create mode 100644 tests/graph/utils/test_num_nodes.py create mode 100644 tests/graph/utils/test_ppr.py create mode 100644 tests/graph/utils/test_random.py create mode 100644 tests/graph/utils/test_repeat.py create mode 100644 tests/graph/utils/test_scatter.py create mode 100644 tests/graph/utils/test_segment.py create mode 100644 tests/graph/utils/test_select.py create mode 100644 tests/graph/utils/test_softmax.py create mode 100644 tests/graph/utils/test_sort_edge_index.py create mode 100644 tests/graph/utils/test_sparse.py create mode 100644 tests/graph/utils/test_spmm.py create mode 100644 tests/graph/utils/test_subgraph.py create mode 100644 tests/graph/utils/test_to_dense_adj.py create mode 100644 tests/graph/utils/test_to_dense_batch.py create mode 100644 tests/graph/utils/test_tree_decomposition.py create mode 100644 tests/graph/utils/test_trim_to_layer.py create mode 100644 tests/graph/utils/test_unbatch.py create mode 100644 tests/graph/utils/test_undirected.py diff --git a/mindscience/gnn/__init__.py b/mindscience/sharker/__init__.py similarity index 56% rename from mindscience/gnn/__init__.py rename to mindscience/sharker/__init__.py index 69a14b29e..aa42b106b 100644 --- a/mindscience/gnn/__init__.py +++ b/mindscience/sharker/__init__.py @@ -16,4 +16,26 @@ init """ -__all__ = [] \ No newline at end of file +from .seed import seed_everything +from .home import get_home_dir, set_home_dir + +import sharker.profile +import sharker.utils +import sharker.data +import sharker.loader +import sharker.nn +from .experimental import (is_experimental_mode_enabled, experimental_mode, + set_experimental_mode) + +__version__ = '0.1' + +__all__ = [ + 'seed_everything', + 'get_home_dir', + 'set_home_dir', + 'is_experimental_mode_enabled', + 'experimental_mode', + 'set_experimental_mode', + 'sharker', + '__version__', +] diff --git a/mindscience/sharker/data/__init__.py b/mindscience/sharker/data/__init__.py new file mode 100644 index 000000000..1ff2a3984 --- /dev/null +++ b/mindscience/sharker/data/__init__.py @@ -0,0 +1,38 @@ +from .graph import Data, Graph +from .heterograph import HeteroGraph +from .batch import Batch +from .temporal import TemporalGraph +from .dataset import Dataset +from .in_memory import InMemoryDataset +from .on_disk import OnDiskDataset +from .database import Database, SQLiteDatabase, RocksDatabase +from .download import download_url, download_google_url +from .extract import extract_tar, extract_zip, extract_bz2, extract_gz + +data_classes = [ + "Data", + "Graph", + "HeteroGraph", + "Batch", + "TemporalGraph", + "Dataset", + "InMemoryDataset", + "OnDiskDataset", +] + +database_classes = [ + "Database", + "SQLiteDatabase", + "RocksDatabase", +] + +helper_functions = [ + "download_url", + "download_google_url", + "extract_tar", + "extract_zip", + "extract_bz2", + "extract_gz", +] + +__all__ = data_classes + helper_functions + database_classes diff --git a/mindscience/sharker/data/batch.py b/mindscience/sharker/data/batch.py new file mode 100644 index 000000000..c895ea428 --- /dev/null +++ b/mindscience/sharker/data/batch.py @@ -0,0 +1,222 @@ +import inspect +from collections.abc import Sequence, Mapping +from typing import Any, List, Optional, Type, Union +from typing_extensions import Self + +import mindspore as ms +from mindspore import mint +import numpy as np +from .graph import Graph +from .collate import collate +from .separate import separate +from .dataset import IndexType + + +class DynamicInheritance(type): + # A meta class that sets the base class of a `Batch` object, e.g.: + # * `Batch(Graph)` in case `Graph` objects are batched together + # * `Batch(HeteroGraph)` in case `HeteroGraph` objects are batched together + def __call__(cls, *args: Any, **kwargs: Any) -> Any: + base_cls = kwargs.pop("_base_cls", Graph) + + if issubclass(base_cls, Batch): + new_cls = base_cls + else: + name = f"{base_cls.__name__}{cls.__name__}" + + # NOTE `MetaResolver` is necessary to resolve metaclass conflict + # problems between `DynamicInheritance` and the metaclass of + # `base_cls`. In particular, it creates a new common metaclass + # from the defined metaclasses. + class MetaResolver(type(cls), type(base_cls)): # type: ignore + pass + + if name not in globals(): + globals()[name] = MetaResolver(name, (cls, base_cls), {}) + new_cls = globals()[name] + + params = list(inspect.signature(base_cls.__init__).parameters.items()) + for i, (k, v) in enumerate(params[1:]): + if k == "args" or k == "kwargs": + continue + if i < len(args) or k in kwargs: + continue + if v.default is not inspect.Parameter.empty: + continue + kwargs[k] = None + + return super(DynamicInheritance, new_cls).__call__(*args, **kwargs) + + +class DynamicInheritanceGetter: + def __call__(self, cls: Type, base_cls: Type) -> Self: + return cls(_base_cls=base_cls) + + +@ms.jit_class +class Batch(metaclass=DynamicInheritance): + r"""A data object describing a batch of graphs as one big (disconnected) + graph. + Inherits from :class:`sharker.data.Graph` or + :class:`sharker.data.HeteroGraph`. + In addition, single graphs can be identified via the assignment vector + :obj:`batch`, which maps each node to its respective graph identifier. + + :mindgeometric:`MindGeometric` allows modification to the underlying batching procedure by + overwriting the :meth:`~Data.__inc__` and :meth:`~Data.__cat_dim__` + functionalities. + The :meth:`~Data.__inc__` method defines the incremental count between two + consecutive graph attributes. + By default, :mindgeometric:`MindGeometric` increments attributes by the number of nodes + whenever their attribute names contain the substring :obj:`index` + (for historical reasons), which comes in handy for attributes such as + :obj:`edge_index` or :obj:`node_index`. + However, note that this may lead to unexpected behavior for attributes + whose names contain the substring :obj:`index` but should not be + incremented. + To make sure, it is best practice to always double-check the output of + batching. + Furthermore, :meth:`~Data.__cat_dim__` defines in which dimension graph + tensors of the same attribute should be concatenated together. + """ + + @classmethod + def from_data_list( + cls, + data_list: List[Graph], + follow_batch: Optional[List[str]] = None, + exclude_keys: Optional[List[str]] = None, + return_tensor: bool = True + ) -> Self: + batch, slice_dict, inc_dict = collate( + cls, + data_list=data_list, + increment=True, + add_batch=not isinstance(data_list[0], Batch), + return_tensor = False, + follow_batch=follow_batch, + exclude_keys=exclude_keys, + ) + + batch._num_graphs = len(data_list) + batch._slice_dict = slice_dict + batch._inc_dict = inc_dict + + if return_tensor == True: + batch.tensor() + return batch + + def get_example(self, idx: int) -> Graph: + r"""Gets the :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object at index :obj:`idx`. + The :class:`~sharker.data.Batch` object must have been created + via :meth:`from_data_list` in order to be able to reconstruct the + initial object. + """ + if not hasattr(self, "_slice_dict"): + raise RuntimeError( + ( + "Cannot reconstruct 'Data' object from 'Batch' because " + "'Batch' was not created via 'Batch.from_data_list()'" + ) + ) + + data = separate( + cls=self.__class__.__bases__[-1], + batch=self, + idx=idx, + slice_dict=getattr(self, "_slice_dict"), + inc_dict=getattr(self, "_inc_dict"), + decrement=True, + ) + + return data + + + def index_select(self, idx: IndexType) -> List[Graph]: + r"""Creates a subset of :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` objects from specified + indices :obj:`idx`. + Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a + list, a tuple, or a :obj:`Tensor` or :obj:`np.ndarray` of type + long or bool. + The :class:`~sharker.data.Batch` object must have been created + via :meth:`from_data_list` in order to be able to reconstruct the + initial objects. + """ + index: Sequence[int] + if isinstance(idx, slice): + index = list(range(self.num_graphs)[idx]) + + elif isinstance(idx, ms.Tensor) and idx.dtype == ms.int64: + index = idx.flatten().tolist() + + elif isinstance(idx, ms.Tensor) and idx.dtype == ms.bool_: + index = mint.nonzero(idx).flatten().tolist() + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + index = idx.flatten().tolist() + + elif isinstance(idx, np.ndarray) and idx.dtype == bool: + index = idx.flatten().nonzero()[0].flatten().tolist() + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + index = idx + + else: + raise IndexError( + f"Only slices (':'), list, tuples, Tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')" + ) + + return [self.get_example(i) for i in index] + + def __getitem__(self, idx: Union[int, np.integer, str, IndexType]) -> Any: + if ( + isinstance(idx, (int, np.integer)) + or (isinstance(idx, ms.Tensor) and idx.dim() == 0) + or (isinstance(idx, np.ndarray) and np.isscalar(idx)) + ): + return self.get_example(idx) + elif isinstance(idx, str) or ( + isinstance(idx, tuple) and isinstance(idx[0], str) + ): + # Accessing attributes or node/edge types: + return super().__getitem__(idx) + else: + return self.index_select(idx) + + def to_data_list(self) -> List[Graph]: + r"""Reconstructs the list of :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` objects from the + :class:`~sharker.data.Batch` object. + The :class:`~sharker.data.Batch` object must have been created + via :meth:`from_data_list` in order to be able to reconstruct the + initial objects. + """ + return [self.get_example(i) for i in range(self.num_graphs)] + + @property + def num_graphs(self) -> int: + """Returns the number of graphs in the batch.""" + if hasattr(self, "_num_graphs"): + return self._num_graphs + elif hasattr(self, "ptr"): + return self.ptr.numel() - 1 + elif hasattr(self, "batch"): + return int(self.batch.max()) + 1 + else: + raise ValueError("Can not infer the number of graphs") + + @property + def batch_size(self) -> int: + r"""Alias for :obj:`num_graphs`.""" + return self.num_graphs + + def __len__(self) -> int: + return self.num_graphs + + def __reduce__(self) -> Any: + state = self.__dict__.copy() + return DynamicInheritanceGetter(), self.__class__.__bases__, state diff --git a/mindscience/sharker/data/collate.py b/mindscience/sharker/data/collate.py new file mode 100644 index 000000000..970f035be --- /dev/null +++ b/mindscience/sharker/data/collate.py @@ -0,0 +1,290 @@ +from collections import defaultdict +from collections.abc import Mapping, Sequence +from typing import ( + Any, + Dict, + Iterable, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) +import numpy as np +import mindspore as ms +from mindspore import ops, mint +from .graph import Graph +from .storage import BaseStorage, NodeStorage +from ..utils.functions import cumsum, cumsum_np + +T = TypeVar("T") +SliceDictType = Dict[str, Union[ms.Tensor, Dict[str, ms.Tensor], np.array, Dict[str, np.array]]] +IncDictType = Dict[str, Union[ms.Tensor, Dict[str, ms.Tensor], np.array, Dict[str, np.array]]] + +def collate( + cls: Type[T], + data_list: List[Graph], + increment: bool = True, + add_batch: bool = True, + return_tensor: bool = True, + follow_batch: Optional[Iterable[str]] = None, + exclude_keys: Optional[Iterable[str]] = None, +) -> Tuple[T, SliceDictType, IncDictType]: + # Collates a list of `data` objects into a single object of type `cls`. + # `collate` can handle both homogeneous and heterogeneous data objects by + # individually collating all their stores. + # In addition, `collate` can handle nested data structures such as + # dictionaries and lists. + + if not isinstance(data_list, (list, tuple)): + # Materialize `data_list` to keep the `_parent` weakref alive. + data_list = list(data_list) + + if cls != data_list[0].__class__: # Dynamic inheritance. + out = cls(_base_cls=data_list[0].__class__) + else: + out = cls() + + # Create empty stores: + out.stores_as(data_list[0]) + + follow_batch = set(follow_batch or []) + exclude_keys = set(exclude_keys or []) + + # Group all storage objects of every data object in the `data_list` by key, + # i.e. `key_to_stores = { key: [store_1, store_2, ...], ... }`: + key_to_stores = defaultdict(list) + for data in data_list: + for store in data.stores: + key_to_stores[store._key].append(store) + + # With this, we iterate over each list of storage objects and recursively + # collate all its attributes into a unified representation: + + # We maintain two additional dictionaries: + # * `slice_dict` stores a compressed index representation of each attribute + # and is needed to re-construct individual elements from mini-batches. + # * `inc_dict` stores how individual elements need to be incremented, e.g., + # `edge_index` is incremented by the cumulated sum of previous elements. + # We also need to make use of `inc_dict` when re-constructuing individual + # elements as attributes that got incremented need to be decremented + # while separating to obtain original values. + slice_dict: SliceDictType = {} + inc_dict: IncDictType = {} + for out_store in out.stores: + key = out_store._key + stores = key_to_stores[key] + for attr in stores[0].keys(): + + if attr in exclude_keys: # Do not include top-level attribute. + continue + + values = [store[attr] for store in stores] + + # The `num_nodes` attribute needs special treatment, as we need to + # sum their values up instead of merging them to a list: + if attr == "num_nodes": + out_store._num_nodes = values + out_store.num_nodes = sum(values) + continue + + # Skip batching of `ptr` vectors for now: + if attr == "ptr": + continue + + # Collate attributes into a unified representation: + value, slices, incs = _collate(attr, values, data_list, stores, increment) + + out_store[attr] = value + + if key is not None: # Heterogeneous: + store_slice_dict = slice_dict.get(key, {}) + assert isinstance(store_slice_dict, dict) + store_slice_dict[attr] = slices + slice_dict[key] = store_slice_dict + + store_inc_dict = inc_dict.get(key, {}) + assert isinstance(store_inc_dict, dict) + store_inc_dict[attr] = incs + inc_dict[key] = store_inc_dict + else: # Homogeneous: + slice_dict[attr] = slices + inc_dict[attr] = incs + + # Add an additional batch vector for the given attribute: + if attr in follow_batch: + batch, ptr = _batch_and_ptr(slices) + out_store[f"{attr}_batch"] = batch + out_store[f"{attr}_ptr"] = ptr + + # In case of node-level storages, we add a top-level batch vector it: + if ( + add_batch + and isinstance(stores[0], NodeStorage) + and stores[0].can_infer_num_nodes + ): + repeats = [int(store.num_nodes) or 0 for store in stores] + out_store.batch = repeat(repeats) + out_store.ptr = cumsum_np(np.array(repeats)) + + if return_tensor == True: + out = out.tensor() + return out, slice_dict, inc_dict + + +def _collate( + key: str, + values: List[Any], + data_list: List[Graph], + stores: List[BaseStorage], + increment: bool, +) -> Tuple[Any, Any, Any]: + + elem = values[0] + + if isinstance(elem, ms.Tensor): + # Concatenate a list of `Tensor` along the `cat_dim`. + # NOTE: We need to take care of incrementing elements appropriately. + values = [value.asnumpy() for value in values] + value, slices, incs = _collate(key, values, data_list, stores, increment) + return value, slices, incs + + elif isinstance(elem, np.ndarray): + key = str(key) + cat_dim = data_list[0].__cat_dim__(key, elem, stores[0]) + if elem.ndim == 0: + values = [value.reshape(1) for value in values] + if cat_dim is None: + values = [value[None,:] for value in values] + sizes = np.array([value.shape[cat_dim or 0] for value in values]) + slices = cumsum_np(sizes) + if increment: + incs = get_incs_np(key, values, data_list, stores) + if incs.ndim > 1 or incs[-1] != 0: + values = [value + inc for value, inc in zip(values, incs)] + else: + incs = None + + value = np.concatenate(values, axis=cat_dim or 0) + return value, slices, incs + + elif isinstance(elem, (int, float)): + # Convert a list of numerical values to a `Tensor`. + value = np.array(values) + if increment: + incs = get_incs_np(key, values, data_list, stores) + if (incs[-1]) != 0: + value += incs + else: + incs = None + slices = np.arange(len(values) + 1) + return value, slices, incs + + elif isinstance(elem, Mapping): + # Recursively collate elements of dictionaries. + value_dict, slice_dict, inc_dict = {}, {}, {} + for key in elem.keys(): + value_dict[key], slice_dict[key], inc_dict[key] = _collate( + key, [v[key] for v in values], data_list, stores, increment + ) + return value_dict, slice_dict, inc_dict + + elif ( + isinstance(elem, Sequence) + and not isinstance(elem, str) + and len(elem) > 0 + and isinstance(elem[0], ms.Tensor) + ): + # Recursively collate elements of lists. + value_list, slice_list, inc_list = [], [], [] + for i in range(len(elem)): + value, slices, incs = _collate( + key, [v[i] for v in values], data_list, stores, increment + ) + value_list.append(value) + slice_list.append(slices) + inc_list.append(incs) + return value_list, slice_list, inc_list + + else: + # Other-wise, just return the list of values as it is. + slices = np.arange(len(values) + 1) + return values, slices, None + + +def _batch_and_ptr( + slices: Any, +) -> Tuple[Any, Any]: + if isinstance(slices, ms.Tensor) and slices.dim() == 1: + # Default case, turn slices tensor into batch. + slices_np = slices.asnumpy() + batch, ptr = _batch_and_ptr(slices_np) + return batch, ptr + + if isinstance(slices, np.ndarray) and slices.ndim == 1: + # Default case, turn slices tensor into batch. + repeats = slices[1:] - slices[:-1] + batch = repeat(repeats) + ptr = cumsum_np(repeats) + return batch, ptr + + elif isinstance(slices, Mapping): + # Recursively batch elements of dictionaries. + batch, ptr = {}, {} + for k, v in slices.items(): + batch[k], ptr[k] = _batch_and_ptr(v) + return batch, ptr + + elif ( + isinstance(slices, Sequence) + and not isinstance(slices, str) + and isinstance(slices[0], (ms.Tensor,np.ndarray)) + ): + # Recursively batch elements of lists. + batch, ptr = [], [] + for s in slices: + sub_batch, sub_ptr = _batch_and_ptr(s) + batch.append(sub_batch) + ptr.append(sub_ptr) + return batch, ptr + + else: + # Failure of batching, usually due to slices.dim() != 1 + return None, None + + +############################################################################### + + +def repeat( + repeats) -> np.ndarray: + if isinstance(repeats, List): + repeats = np.array(repeats) + if isinstance(repeats, ms.Tensor): + repeaats = repeats.asnumpy() + outs = np.repeat(np.arange(repeats.shape[0]), repeats, 0) + return outs + + +def get_incs( + key, values: List[Any], data_list: List[Graph], stores: List[BaseStorage] +) -> ms.Tensor: + repeats = [ + data.__inc__(key, value, store) + for value, data, store in zip(values, data_list, stores) + ] + if isinstance(repeats[0], ms.Tensor): + repeats = mint.stack(repeats, dim=0) + else: + repeats = ms.Tensor(repeats) + return cumsum(mint.narrow(repeats, 0, 0, repeats.shape[0] - 1)) + +def get_incs_np( + key, values: List[Any], data_list: List[Graph], stores: List[BaseStorage] +) -> np.ndarray: + repeats = [data.__inc__(key, value, store) + for value, data, store in zip(values, data_list, stores)] + repeats = np.stack(repeats, axis=0) + incs = cumsum_np(repeats[:-1], axis=0) + return incs diff --git a/mindscience/sharker/data/database.py b/mindscience/sharker/data/database.py new file mode 100644 index 000000000..fff7e6a78 --- /dev/null +++ b/mindscience/sharker/data/database.py @@ -0,0 +1,597 @@ +import pickle +import warnings +from abc import ABC, abstractmethod +from dataclasses import dataclass +from functools import cached_property +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +import numpy as np +import mindspore as ms +from mindspore import Tensor +from tqdm import tqdm +from ..utils import CastMixin + + +@dataclass +class TensorInfo(CastMixin): + dtype: ms.Type + size: Tuple[int, ...] = (-1,) + is_edge_index: bool = False + + def __post_init__(self) -> None: + if self.is_edge_index: + self.size = (2, -1) + + +def maybe_cast_to_tensor_info(value: Any) -> Union[Any, TensorInfo]: + if not isinstance(value, dict): + return value + if len(value) < 1 or len(value) > 3: + return value + if "dtype" not in value: + return value + if len(set(value.keys()) | {"dtype", "size", "is_edge_index"}) != 3: + return value + return TensorInfo.cast(value) + + +Schema = Union[Any, Dict[str, Any], Tuple[Any], List[Any]] + +SORT_ORDER_TO_INDEX: Dict[str, int] = { + None: -1, + "row": 0, + "col": 1, +} +INDEX_TO_SORT_ORDER = {v: k for k, v in SORT_ORDER_TO_INDEX.items()} + + +class Database(ABC): + r"""Base class for inserting and retrieving data from a database. + + A database acts as a persisted, out-of-memory and index-based key/value + store for tensor and custom data: + + .. code-block:: python + + db = Database() + db[0] = Data(x=ops.randn(5, 16), y=0, z='id_0') + print(db[0]) + >>> Data(x=[5, 16], y=0, z='id_0') + + To improve efficiency, it is recommended to specify the underlying + :obj:`schema` of the data: + + .. code-block:: python + + db = Database(schema={ # Custom schema: + # Tensor information can be specified through a dictionary: + 'x': dict(dtype=ms.float32, size=(-1, 16)), + 'y': int, + 'z': str, + }) + db[0] = dict(x=ops.randn(5, 16), y=0, z='id_0') + print(db[0]) + >>> {'x': Tensor(...), 'y': 0, 'z': 'id_0'} + + In addition, databases support batch-wise insert and get, and support + syntactic sugar known from indexing :python:`Python` lists, *e.g.*: + + .. code-block:: python + + db = Database() + db[2:5] = ops.randn(3, 16) + print(db[Tensor([2, 3])]) + >>> [Tensor(...), Tensor(...)] + + Args: + schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of + the input data. + Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a + dictionary with :obj:`dtype` and :obj:`size` keys (for specifying + tensor data) as input, and can be nested as a tuple or dictionary. + Specifying the schema will improve efficiency, since by default the + database will use python pickling for serializing and + deserializing. (default: :obj:`object`) + """ + + def __init__(self, schema: Schema = object) -> None: + schema_dict = self._to_dict(maybe_cast_to_tensor_info(schema)) + self.schema: Dict[Union[str, int], Any] = { + key: maybe_cast_to_tensor_info(value) for key, value in schema_dict.items() + } + + def connect(self) -> None: + r"""Connects to the database. + Databases will automatically connect on instantiation. + """ + pass + + def close(self) -> None: + r"""Closes the connection to the database.""" + pass + + @abstractmethod + def insert(self, index: int, data: Any) -> None: + r"""Inserts data at the specified index. + + Args: + index (int): The index at which to insert. + data (Any): The object to insert. + """ + raise NotImplementedError + + def multi_insert( + self, + indices: Union[Sequence[int], Tensor, slice, range], + data_list: Sequence[Any], + batch_size: Optional[int] = None, + log: bool = False, + ) -> None: + r"""Inserts a chunk of data at the specified indices. + + Args: + indices (List[int] or Tensor or range): The indices at which + to insert. + data_list (List[Any]): The objects to insert. + batch_size (int, optional): If specified, will insert the data to + the database in batches of size :obj:`batch_size`. + (default: :obj:`None`) + log (bool, optional): If set to :obj:`True`, will log progress to + the console. (default: :obj:`False`) + """ + if isinstance(indices, slice): + indices = self.slice_to_range(indices) + + length = min(len(indices), len(data_list)) + batch_size = length if batch_size is None else batch_size + + if log and length > batch_size: + desc = f"Insert {length} entries" + offsets = tqdm(range(0, length, batch_size), desc=desc) + else: + offsets = range(0, length, batch_size) + + for start in offsets: + self._multi_insert( + indices[start:start + batch_size], + data_list[start:start + batch_size], + ) + + def _multi_insert( + self, + indices: Union[Sequence[int], Tensor, range], + data_list: Sequence[Any], + ) -> None: + if isinstance(indices, Tensor): + indices = indices.tolist() + for index, data in zip(indices, data_list): + self.insert(index, data) + + @abstractmethod + def get(self, index: int) -> Any: + r"""Gets data from the specified index. + Args: + index (int): The index to query. + """ + raise NotImplementedError + + def multi_get( + self, + indices: Union[Sequence[int], Tensor, slice, range], + batch_size: Optional[int] = None, + ) -> List[Any]: + r"""Gets a chunk of data from the specified indices. + + Args: + indices (List[int] or Tensor or range): The indices to query. + batch_size (int, optional): If specified, will request the data + from the database in batches of size :obj:`batch_size`. + (default: :obj:`None`) + """ + if isinstance(indices, slice): + indices = self.slice_to_range(indices) + + length = len(indices) + batch_size = length if batch_size is None else batch_size + + data_list: List[Any] = [] + for start in range(0, length, batch_size): + chunk_indices = indices[start: start + batch_size] + data_list.extend(self._multi_get(chunk_indices)) + return data_list + + def _multi_get(self, indices: Union[Sequence[int], Tensor]) -> List[Any]: + if isinstance(indices, Tensor): + indices = indices.tolist() + return [self.get(index) for index in indices] + + # Helper functions ######################################################## + + @staticmethod + def _to_dict( + value: Union[Dict[Union[int, str], Any], Sequence[Any], Any], + ) -> Dict[Union[str, int], Any]: + if isinstance(value, dict): + return value + if isinstance(value, (tuple, list)): + return {i: v for i, v in enumerate(value)} + else: + return {0: value} + + def slice_to_range(self, indices: slice) -> range: + start = 0 if indices.start is None else indices.start + stop = len(self) if indices.stop is None else indices.stop + step = 1 if indices.step is None else indices.step + + return range(start, stop, step) + + # Python built-ins ######################################################## + + def __len__(self) -> int: + raise NotImplementedError + + def __getitem__( + self, + key: Union[int, Sequence[int], Tensor, slice, range], + ) -> Union[Any, List[Any]]: + + if isinstance(key, int): + return self.get(key) + else: + return self.multi_get(key) + + def __setitem__( + self, + key: Union[int, Sequence[int], Tensor, slice, range], + value: Union[Any, Sequence[Any]], + ) -> None: + if isinstance(key, int): + self.insert(key, value) + else: + self.multi_insert(key, value) + + def __repr__(self) -> str: + try: + return f"{self.__class__.__name__}({len(self)})" + except NotImplementedError: + return f"{self.__class__.__name__}()" + + +class SQLiteDatabase(Database): + r"""An index-based key/value database based on :obj:`sqlite3`. + + .. note:: + This database implementation requires the :obj:`sqlite3` package. + + Args: + path (str): The path to where the database should be saved. + name (str): The name of the table to save the data to. + schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of + the input data. + Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a + dictionary with :obj:`dtype` and :obj:`size` keys (for specifying + tensor data) as input, and can be nested as a tuple or dictionary. + Specifying the schema will improve efficiency, since by default the + database will use python pickling for serializing and + deserializing. (default: :obj:`object`) + """ + + def __init__(self, path: str, name: str, schema: Schema = object) -> None: + super().__init__(schema) + + warnings.filterwarnings("ignore", ".*given buffer is not writable.*") + + import sqlite3 + + self.path = path + self.name = name + + self._connection: Optional[sqlite3.Connection] = None + self._cursor: Optional[sqlite3.Cursor] = None + + self.connect() + + # Create the table (if it does not exist) by mapping the Python schema + # to the corresponding SQL schema: + sql_schema = ",\n".join( + [ + f" {col_name} {self._to_sql_type(type_info)}" + for col_name, type_info in zip(self._col_names, self.schema.values()) + ] + ) + query = ( + f"CREATE TABLE IF NOT EXISTS {self.name} (\n" + f" id INTEGER PRIMARY KEY,\n" + f"{sql_schema}\n" + f")" + ) + self.cursor.execute(query) + + def connect(self) -> None: + import sqlite3 + + self._connection = sqlite3.connect(self.path) + self._cursor = self._connection.cursor() + + def close(self) -> None: + if self._connection is not None: + self._connection.commit() + self._connection.close() + self._connection = None + self._cursor = None + + @property + def connection(self) -> Any: + if self._connection is None: + raise RuntimeError("No open database connection") + return self._connection + + @property + def cursor(self) -> Any: + if self._cursor is None: + raise RuntimeError("No open database connection") + return self._cursor + + def insert(self, index: int, data: Any) -> None: + query = ( + f"INSERT INTO {self.name} " + f"(id, {self._joined_col_names}) " + f"VALUES (?, {self._dummies})" + ) + self.cursor.execute(query, (index, *self._serialize(data))) + self.connection.commit() + + def _multi_insert( + self, + indices: Union[Sequence[int], Tensor, range], + data_list: Sequence[Any], + ) -> None: + if isinstance(indices, Tensor): + indices = indices.tolist() + + data_list = [ + (index, *self._serialize(data)) for index, data in zip(indices, data_list) + ] + + query = ( + f"INSERT INTO {self.name} " + f"(id, {self._joined_col_names}) " + f"VALUES (?, {self._dummies})" + ) + self.cursor.executemany(query, data_list) + self.connection.commit() + + def get(self, index: int) -> Any: + query = f"SELECT {self._joined_col_names} FROM {self.name} " f"WHERE id = ?" + self.cursor.execute(query, (index,)) + return self._deserialize(self.cursor.fetchone()) + + def multi_get( + self, + indices: Union[Sequence[int], Tensor, slice, range], + batch_size: Optional[int] = None, + ) -> List[Any]: + + if isinstance(indices, slice): + indices = self.slice_to_range(indices) + elif isinstance(indices, Tensor): + indices = indices.tolist() + + # We create a temporary ID table to then perform an INNER JOIN. + # This avoids having a long IN clause and guarantees sorted outputs: + join_table_name = f"{self.name}__join" + # Temporary tables do not lock the database. + query = ( + f"CREATE TEMP TABLE {join_table_name} (\n" + f" id INTEGER,\n" + f" row_id INTEGER\n" + f")" + ) + self.cursor.execute(query) + + query = f"INSERT INTO {join_table_name} (id, row_id) VALUES (?, ?)" + self.cursor.executemany(query, zip(indices, range(len(indices)))) + self.connection.commit() + + query = f"SELECT * FROM {join_table_name}" + self.cursor.execute(query) + + query = ( + f"SELECT {self._joined_col_names} " + f"FROM {self.name} INNER JOIN {join_table_name} " + f"ON {self.name}.id = {join_table_name}.id " + f"ORDER BY {join_table_name}.row_id" + ) + self.cursor.execute(query) + + if batch_size is None: + data_list = self.cursor.fetchall() + else: + data_list = [] + while True: + chunk_list = self.cursor.fetchmany(size=batch_size) + if len(chunk_list) == 0: + break + data_list.extend(chunk_list) + + query = f"DROP TABLE {join_table_name}" + self.cursor.execute(query) + + return [self._deserialize(data) for data in data_list] + + def __len__(self) -> int: + query = f"SELECT COUNT(*) FROM {self.name}" + self.cursor.execute(query) + return self.cursor.fetchone()[0] + + # Helper functions ######################################################## + + @cached_property + def _col_names(self) -> List[str]: + return [f"COL_{key}" for key in self.schema.keys()] + + @cached_property + def _joined_col_names(self) -> str: + return ", ".join(self._col_names) + + @cached_property + def _dummies(self) -> str: + return ", ".join(["?"] * len(self.schema.keys())) + + def _to_sql_type(self, type_info: Any) -> str: + if type_info == int: + return "INTEGER NOT NULL" + if type_info == float: + return "FLOAT" + if type_info == str: + return "TEXT NOT NULL" + else: + return "BLOB NOT NULL" + + def _serialize(self, row: Any) -> List[Any]: + # Serializes the given input data according to `schema`: + # * {int, float, str}: Use as they are. + # * Tensor: Convert into the raw byte string + # * object: Dump via pickle + # If we find a `Tensor` that is not registered as such in + # `schema`, we modify the schema in-place for improved efficiency. + out: List[Any] = [] + row_dict = self._to_dict(row) + for key, schema in self.schema.items(): + col = row_dict[key] + + if isinstance(col, Tensor) and not isinstance(schema, TensorInfo): + self.schema[key] = schema = TensorInfo( + col.dtype, + is_edge_index=False, + ) + + if isinstance(schema, TensorInfo): + assert isinstance(col, Tensor) + col = col.asnumpy() + schema.dtype = col.dtype + out.append(col.tobytes()) + + elif schema in {int, float, str}: + out.append(col) + + else: + out.append(pickle.dumps(col)) + + return out + + def _deserialize(self, row: Tuple[Any]) -> Any: + # Deserializes the DB data according to `schema`: + # * {int, float, str}: Use as they are. + # * Tensor: Load raw byte string with `dtype` and `size` + # information from `schema` + # * object: Load via pickle + out_dict = {} + for i, (key, schema) in enumerate(self.schema.items()): + value = row[i] + + if isinstance(schema, TensorInfo): + if len(value) > 0: + tensor = np.frombuffer(value, dtype=schema.dtype) + else: + tensor = np.empty(0, dtype=schema.dtype) + tensor = Tensor.from_numpy(tensor) + out_dict[key] = tensor.view(*schema.size) + elif schema == float: + out_dict[key] = value if value is not None else float("NaN") + + elif schema in {int, str}: + out_dict[key] = value + + else: + out_dict[key] = pickle.loads(value) + + # In case `0` exists as integer in the schema, this means that the + # schema was passed as either a single entry or a tuple: + if 0 in self.schema: + if len(self.schema) == 1: + return out_dict[0] + else: + return tuple(out_dict.values()) + else: # Otherwise, return the dictionary as it is: + return out_dict + + +class RocksDatabase(Database): + r"""An index-based key/value database based on :obj:`RocksDB`. + + .. note:: + This database implementation requires the :obj:`rocksdict` package. + + .. warning:: + :class:`RocksDatabase` is currently less optimized than + :class:`SQLiteDatabase`. + + Args: + path (str): The path to where the database should be saved. + schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of + the input data. + Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a + dictionary with :obj:`dtype` and :obj:`size` keys (for specifying + tensor data) as input, and can be nested as a tuple or dictionary. + Specifying the schema will improve efficiency, since by default the + database will use python pickling for serializing and + deserializing. (default: :obj:`object`) + """ + + def __init__(self, path: str, schema: Schema = object) -> None: + super().__init__(schema) + + import rocksdict + + self.path = path + + self._db: Optional[rocksdict.Rdict] = None + + self.connect() + + def connect(self) -> None: + import rocksdict + + self._db = rocksdict.Rdict( + self.path, + options=rocksdict.Options(raw_mode=True), + ) + + def close(self) -> None: + if self._db is not None: + self._db.close() + self._db = None + + @property + def db(self) -> Any: + if self._db is None: + raise RuntimeError("No open database connection") + return self._db + + @staticmethod + def to_key(index: int) -> bytes: + return index.to_bytes(8, byteorder="big", signed=True) + + def insert(self, index: int, data: Any) -> None: + self.db[self.to_key(index)] = self._serialize(data) + + def get(self, index: int) -> Any: + return self._deserialize(self.db[self.to_key(index)]) + + def _multi_get(self, indices: Union[Sequence[int], Tensor]) -> List[Any]: + if isinstance(indices, Tensor): + indices = indices.tolist() + data_list = self.db[[self.to_key(index) for index in indices]] + return [self._deserialize(data) for data in data_list] + + # Helper functions ######################################################## + + def _serialize(self, row: Any) -> bytes: + # Ensure that data is not a view of a larger tensor: + if isinstance(row, Tensor): + row = row.copy() + return pickle.dumps(row) + + def _deserialize(self, row: bytes) -> Any: + return pickle.loads(row) diff --git a/mindscience/sharker/data/datapipe.py b/mindscience/sharker/data/datapipe.py new file mode 100644 index 000000000..9342a9e33 --- /dev/null +++ b/mindscience/sharker/data/datapipe.py @@ -0,0 +1,67 @@ +import copy +from typing import Any, Callable, Iterator, Sequence +from .batch import Batch + +IterDataPipe = IterBatcher = object + + +def functional_datapipe(name: str) -> Callable: + return lambda cls: cls + + +class Batcher: + def __init__( + self, + dp: IterDataPipe, + batch_size: int, + drop_last: bool = False, + ) -> None: + super().__init__( + dp, + batch_size=batch_size, + drop_last=drop_last, + wrapper_class=Batch.from_data_list, + ) + + +class DatasetAdapter(IterDataPipe): + def __init__(self, dataset: Sequence[Any]) -> None: + super().__init__() + self.dataset = dataset + self.range = range(len(self)) + + def is_shardable(self) -> bool: + return True + + def apply_sharding(self, num_shards: int, shard_idx: int) -> None: + self.range = range(shard_idx, len(self), num_shards) + + def __iter__(self) -> Iterator: + for i in self.range: + yield self.dataset[i] + + def __len__(self) -> int: + return len(self.dataset) + + +def functional_transform(name: str) -> Callable: + def wrapper(cls: Any) -> Any: + @functional_datapipe(name) + class DynamicMapper(IterDataPipe): + def __init__( + self, + dp: IterDataPipe, + *args: Any, + **kwargs: Any, + ) -> None: + super().__init__() + self.dp = dp + self.fn = cls(*args, **kwargs) + + def __iter__(self) -> Iterator: + for data in self.dp: + yield self.fn(copy.copy(data)) + + return cls + + return wrapper diff --git a/mindscience/sharker/data/dataset.py b/mindscience/sharker/data/dataset.py new file mode 100644 index 000000000..5d467180a --- /dev/null +++ b/mindscience/sharker/data/dataset.py @@ -0,0 +1,429 @@ +import copy +import os +import re +import sys +import warnings +from collections.abc import Sequence +from typing import ( + Any, + Callable, + Iterable, + Iterator, + List, + Optional, + Tuple, + Union, +) + +import numpy as np +import mindspore as ms +from mindspore import ops, mint + +from .graph import Graph +from ..io import fs + +IndexType = Union[slice, ms.Tensor, np.ndarray, Sequence] +MISSING = "???" + + +class Dataset: + r"""Dataset base class for creating graph datasets. + See `here `__ for the accompanying tutorial. + + Args: + root (str, optional): Root directory where the dataset should be saved. + (optional: :obj:`None`) + transform (callable, optional): A function/transform that takes in a + :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object and returns a + transformed version. + The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + a :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object and returns a + transformed version. + The data object will be transformed before being saved to disk. + (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in a + :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object and returns a + boolean value, indicating whether the data object should be + included in the final dataset. (default: :obj:`None`) + log (bool, optional): Whether to print any console output while + downloading and processing the dataset. (default: :obj:`True`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + """ + + @property + def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]: + r"""The name of the files in the :obj:`self.raw_dir` folder that must + be present in order to skip downloading. + """ + raise NotImplementedError + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]: + r"""The name of the files in the :obj:`self.processed_dir` folder that + must be present in order to skip processing. + """ + raise NotImplementedError + + def download(self) -> None: + r"""Downloads the dataset to the :obj:`self.raw_dir` folder.""" + raise NotImplementedError + + def process(self) -> None: + r"""Processes the dataset to the :obj:`self.processed_dir` folder.""" + raise NotImplementedError + + def len(self) -> int: + r"""Returns the number of data objects stored in the dataset.""" + raise NotImplementedError + + def get(self, idx: int) -> Graph: + r"""Gets the data object at index :obj:`idx`.""" + raise NotImplementedError + + def __init__( + self, + root: Optional[str] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + log: bool = True, + force_reload: bool = False, + ) -> None: + super().__init__() + + if isinstance(root, str): + root = os.path.expanduser(fs.normpath(root)) + + self.root = root or MISSING + self.transform = transform + self.pre_transform = pre_transform + self.pre_filter = pre_filter + self.log = log + self._indices: Optional[Sequence] = None + self.force_reload = force_reload + + if self.has_download: + self._download() + + if self.has_process: + self._process() + + def indices(self) -> Sequence: + return range(self.len()) if self._indices is None else self._indices + + @property + def raw_dir(self) -> str: + return os.path.join(self.root, "raw") + + @property + def processed_dir(self) -> str: + return os.path.join(self.root, "processed") + + @property + def num_node_features(self) -> int: + r"""Returns the number of features per node in the dataset.""" + data = self[0] + # Do not fill cache for `InMemoryDataset`: + if hasattr(self, "_data_list") and self._data_list is not None: + self._data_list[0] = None + data = data[0] if isinstance(data, tuple) else data + if hasattr(data, "num_node_features"): + return data.num_node_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_node_features'" + ) + + @property + def num_features(self) -> int: + r"""Returns the number of features per node in the dataset. + Alias for :py:attr:`~num_node_features`. + """ + return self.num_node_features + + @property + def num_edge_features(self) -> int: + r"""Returns the number of features per edge in the dataset.""" + data = self[0] + # Do not fill cache for `InMemoryDataset`: + if hasattr(self, "_data_list") and self._data_list is not None: + self._data_list[0] = None + data = data[0] if isinstance(data, tuple) else data + if hasattr(data, "num_edge_features"): + return data.num_edge_features + raise AttributeError( + f"'{data.__class__.__name__}' object has no " + f"attribute 'num_edge_features'" + ) + + def _infer_num_classes(self, y: Optional[ms.Tensor]) -> int: + if y is None: + return 0 + elif y.numel() == y.shape[0] and not ops.is_floating_point(y): + return int(y.max()) + 1 + elif y.numel() == y.shape[0] and ops.is_floating_point(y): + num_classes = ops.unique(y)[0].numel() + if num_classes > 2: + warnings.warn( + "Found floating-point labels while calling " + "`dataset.num_classes`. Returning the number of " + "unique elements. Please make sure that this " + "is expected before proceeding." + ) + return num_classes + else: + return y.shape[-1] + + @property + def num_classes(self) -> int: + r"""Returns the number of classes in the dataset.""" + # We iterate over the dataset and collect all labels to determine the + # maximum number of classes. Importantly, in rare cases, `__getitem__` + # may produce a tuple of data objects (e.g., when used in combination + # with `RandomLinkSplit`, so we take care of this case here as well: + data_list = _flatten([data for data in self]) + if "y" in data_list[0] and isinstance(data_list[0].y, ms.Tensor): + y = mint.cat([data.y for data in data_list if "y" in data], dim=0) + else: + y = ms.Tensor([data.y for data in data_list if "y" in data]) + + # Do not fill cache for `InMemoryDataset`: + if hasattr(self, "_data_list") and self._data_list is not None: + self._data_list = self.len() * [None] + return self._infer_num_classes(y) + + @property + def raw_paths(self) -> List[str]: + r"""The absolute filepaths that must be present in order to skip + downloading. + """ + files = self.raw_file_names + # Prevent a common source of error in which `file_names` are not + # defined as a property. + if isinstance(files, Callable): + files = files() + return [os.path.join(self.raw_dir, f) for f in to_list(files)] + + @property + def processed_paths(self) -> List[str]: + r"""The absolute filepaths that must be present in order to skip + processing. + """ + files = self.processed_file_names + # Prevent a common source of error in which `file_names` are not + # defined as a property. + if isinstance(files, Callable): + files = files() + return [os.path.join(self.processed_dir, f) for f in to_list(files)] + + @property + def has_download(self) -> bool: + r"""Checks whether the dataset defines a :meth:`download` method.""" + return overrides_method(self.__class__, "download") + + def _download(self): + if len(self.raw_paths) != 0 and all([fs.exists(f) for f in self.raw_paths]): + return + + fs.makedirs(self.raw_dir, exist_ok=True) + self.download() + + @property + def has_process(self) -> bool: + r"""Checks whether the dataset defines a :meth:`process` method.""" + return overrides_method(self.__class__, "process") + + def _process(self): + + if not self.force_reload: + if len(self.processed_paths) != 0 and all( + [fs.exists(f) for f in self.processed_paths] + ): + return + + if self.log and "pytest" not in sys.modules: + print("Processing...", file=sys.stderr) + + fs.makedirs(self.processed_dir, exist_ok=True) + self.process() + + path = os.path.join(self.processed_dir, "pre_transform.pt") + fs.pickle_save(_repr(self.pre_transform), path) + path = os.path.join(self.processed_dir, "pre_filter.pt") + fs.pickle_save(_repr(self.pre_filter), path) + + if self.log and "pytest" not in sys.modules: + print("Done!", file=sys.stderr) + + def __len__(self) -> int: + r"""The number of examples in the dataset.""" + return len(self.indices()) + + def __getitem__( + self, + idx: Union[int, np.integer, IndexType], + ) -> Union["Dataset", Graph]: + r"""In case :obj:`idx` is of type integer, will return the data object + at index :obj:`idx` (and transforms it in case :obj:`transform` is + present). + In case :obj:`idx` is a slicing object, *e.g.*, :obj:`[2:5]`, a list, a + tuple, or a :obj:`Tensor` or :obj:`np.ndarray` of type long or + bool, will return a subset of the dataset at the specified indices. + """ + if ( + isinstance(idx, (int, np.integer)) + or (isinstance(idx, ms.Tensor) and idx.dim() == 0) + or (isinstance(idx, np.ndarray) and np.isscalar(idx)) + ): + + data = self.get(self.indices()[idx]) + data = data if self.transform is None else self.transform(data) + return data + + else: + return self.index_select(idx) + + def __iter__(self) -> Iterator[Graph]: + for i in range(len(self)): + yield self[i] + + def index_select(self, idx: IndexType) -> "Dataset": + r"""Creates a subset of the dataset from specified indices :obj:`idx`. + Indices :obj:`idx` can be a slicing object, *e.g.*, :obj:`[2:5]`, a + list, a tuple, or a :obj:`Tensor` or :obj:`np.ndarray` of type + long or bool. + """ + indices = self.indices() + + if isinstance(idx, slice): + start, stop, step = idx.start, idx.stop, idx.step + # Allow floating-point slicing, e.g., dataset[:0.9] + if isinstance(start, float): + start = round(start * len(self)) + if isinstance(stop, float): + stop = round(stop * len(self)) + idx = slice(start, stop, step) + + indices = indices[idx] + + elif isinstance(idx, ms.Tensor) and idx.dtype == ms.int64: + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, ms.Tensor) and idx.dtype == ms.bool_: + idx = idx.flatten().nonzero() + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, np.ndarray) and idx.dtype == np.int64: + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, np.ndarray) and idx.dtype == bool: + idx = idx.flatten().nonzero()[0] + return self.index_select(idx.flatten().tolist()) + + elif isinstance(idx, Sequence) and not isinstance(idx, str): + indices = [indices[i] for i in idx] + + else: + raise IndexError( + f"Only slices (':'), list, tuples, Tensor and " + f"np.ndarray of dtype long or bool are valid indices (got " + f"'{type(idx).__name__}')" + ) + + dataset = copy.copy(self) + dataset._indices = indices + return dataset + + def shuffle( + self, + return_perm: bool = False, + ) -> Union["Dataset", Tuple["Dataset", ms.Tensor]]: + r"""Randomly shuffles the examples in the dataset. + + Args: + return_perm (bool, optional): If set to :obj:`True`, will also + return the random permutation used to shuffle the dataset. + (default: :obj:`False`) + """ + perm = ops.shuffle(mint.arange(len(self))) + dataset = self.index_select(perm) + return (dataset, perm) if return_perm is True else dataset + + def __repr__(self) -> str: + arg_repr = str(len(self)) if len(self) > 1 else "" + return f"{self.__class__.__name__}({arg_repr})" + + def get_summary(self) -> Any: + r"""Collects summary statistics for the dataset.""" + from .summary import Summary + + return Summary.from_dataset(self) + + def print_summary(self) -> None: + r"""Prints summary statistics of the dataset to the console.""" + print(str(self.get_summary())) + + def to_datapipe(self) -> Any: + r"""Converts the dataset into a :class:`DatasetAdapter`. + + The returned instance can then be used with :mindgeometric:`MindGeometric's` built-in + :class:`DataPipes` for baching graphs as follows: + + .. code-block:: python + + from mindscience.sharker.datasets import QM9 + + dp = QM9(root='./data/QM9/').to_datapipe() + dp = dp.batch_graphs(batch_size=2, drop_last=True) + + for batch in dp: + pass + """ + from .datapipe import DatasetAdapter + + return DatasetAdapter(self) + + +def overrides_method(cls, method_name: str) -> bool: + from .in_memory import InMemoryDataset + + if method_name in cls.__dict__: + return True + + out = False + for base in cls.__bases__: + if base != Dataset and base != InMemoryDataset: + out |= overrides_method(base, method_name) + return out + + +def to_list(value: Any) -> Sequence: + if isinstance(value, Sequence) and not isinstance(value, str): + return value + else: + return [value] + + +def _repr(obj: Any) -> str: + if obj is None: + return "None" + return re.sub("(<.*?)\\s.*(>)", r"\1\2", str(obj)) + + +def _flatten(data_list: Iterable[Any]) -> List[Graph]: + outs: List[Graph] = [] + for data in data_list: + if isinstance(data, Graph): + outs.append(data) + elif isinstance(data, (tuple, list)): + outs.extend(_flatten(data)) + elif isinstance(data, dict): + outs.extend(_flatten(data.values())) + return outs diff --git a/mindscience/sharker/data/download.py b/mindscience/sharker/data/download.py new file mode 100644 index 000000000..009fd1990 --- /dev/null +++ b/mindscience/sharker/data/download.py @@ -0,0 +1,67 @@ +import os +import os.path as osp +import ssl +import sys +from urllib import request +from typing import Optional + +import fsspec + +from ..io import fs + + +def download_url( + url: str, + folder: str, + log: bool = True, + filename: Optional[str] = None, +): + r"""Downloads the content of an URL to a specific folder. + + Args: + url (str): The URL. + folder (str): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + filename (str, optional): The filename of the downloaded file. If set + to :obj:`None`, will correspond to the filename given by the URL. + (default: :obj:`None`) + """ + if filename is None: + filename = url.rpartition("/")[2] + filename = filename if filename[0] == "?" else filename.split("?")[0] + + path = osp.join(folder, filename) + + if fs.exists(path): + if log and "pytest" not in sys.modules: + print(f"Using existing file {filename}", file=sys.stderr) + return path + + if log and "pytest" not in sys.modules: + print(f"Downloading {url}", file=sys.stderr) + + os.makedirs(folder, exist_ok=True) + + context = ssl._create_unverified_context() + data = request.urlopen(url, context=context) + + with fsspec.open(path, "wb") as f: + while True: + chunk = data.read(10 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + + return path + + +def download_google_url( + id: str, + folder: str, + filename: str, + log: bool = True, +): + r"""Downloads the content of a Google Drive ID to a specific folder.""" + url = f"https://drive.usercontent.google.com/download?id={id}&confirm=t" + return download_url(url, folder, log, filename) diff --git a/mindscience/sharker/data/extract.py b/mindscience/sharker/data/extract.py new file mode 100644 index 000000000..7b58726f3 --- /dev/null +++ b/mindscience/sharker/data/extract.py @@ -0,0 +1,77 @@ +import bz2 +import gzip +import os +import sys +import tarfile +import zipfile + + +def maybe_log(path: str, log: bool = True) -> None: + if log and 'pytest' not in sys.modules: + print(f'Extracting {path}', file=sys.stderr) + + +def extract_tar( + path: str, + folder: str, + mode: str = 'r:gz', + log: bool = True, +) -> None: + r"""Extracts a tar archive to a specific folder. + + Args: + path (str): The path to the tar archive. + folder (str): The folder. + mode (str, optional): The compression mode. (default: :obj:`"r:gz"`) + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + maybe_log(path, log) + with tarfile.open(path, mode) as f: + f.extractall(folder) + + +def extract_zip(path: str, folder: str, log: bool = True) -> None: + r"""Extracts a zip archive to a specific folder. + + Args: + path (str): The path to the tar archive. + folder (str): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + maybe_log(path, log) + with zipfile.ZipFile(path, 'r') as f: + f.extractall(folder) + + +def extract_bz2(path: str, folder: str, log: bool = True) -> None: + r"""Extracts a bz2 archive to a specific folder. + + Args: + path (str): The path to the tar archive. + folder (str): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + maybe_log(path, log) + path = os.path.abspath(path) + with bz2.open(path, 'r') as r: + with open(os.path.join(folder, '.'.join(path.split('.')[:-1])), 'wb') as w: + w.write(r.read()) + + +def extract_gz(path: str, folder: str, log: bool = True) -> None: + r"""Extracts a gz archive to a specific folder. + + Args: + path (str): The path to the tar archive. + folder (str): The folder. + log (bool, optional): If :obj:`False`, will not print anything to the + console. (default: :obj:`True`) + """ + maybe_log(path, log) + path = os.path.abspath(path) + with gzip.open(path, 'r') as r: + with open(os.path.join(folder, '.'.join(path.split('.')[:-1])), 'wb') as w: + w.write(r.read()) diff --git a/mindscience/sharker/data/graph.py b/mindscience/sharker/data/graph.py new file mode 100644 index 000000000..ceb14a9f1 --- /dev/null +++ b/mindscience/sharker/data/graph.py @@ -0,0 +1,971 @@ +import copy +import warnings +from collections.abc import Mapping, Sequence +from itertools import chain +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import numpy as np +import mindspore as ms +from typing_extensions import Self +from mindspore import Tensor, ops, mint + +from .storage import BaseStorage, EdgeStorage, GlobalStorage, NodeStorage +from ..utils import select, subgraph + + +class Data: + def __getattr__(self, key: str) -> Any: + raise NotImplementedError + + def __setattr__(self, key: str, value: Any): + raise NotImplementedError + + def __delattr__(self, key: str): + raise NotImplementedError + + def __getitem__(self, key: str) -> Any: + raise NotImplementedError + + def __setitem__(self, key: str, value: Any): + raise NotImplementedError + + def __delitem__(self, key: str): + raise NotImplementedError + + def __copy__(self): + raise NotImplementedError + + def __deepcopy__(self, memo): + raise NotImplementedError + + def __repr__(self) -> str: + raise NotImplementedError + + def stores_as(self, data: Self): + raise NotImplementedError + + @property + def stores(self) -> List[BaseStorage]: + raise NotImplementedError + + @property + def node_stores(self) -> List[NodeStorage]: + raise NotImplementedError + + @property + def edge_stores(self) -> List[EdgeStorage]: + raise NotImplementedError + + def to_dict(self) -> Dict[str, Any]: + r"""Returns a dictionary of stored key/value pairs.""" + raise NotImplementedError + + def to_namedtuple(self) -> NamedTuple: + r"""Returns a :obj:`NamedTuple` of stored key/value pairs.""" + raise NotImplementedError + + def update(self, data: Self) -> Self: + r"""Updates the data object with the elements from another data object. + Added elements will override existing ones (in case of duplicates). + """ + raise NotImplementedError + + def concat(self, data: Self, return_tensor: bool = True) -> Self: + r"""Concatenates :obj:`self` with another :obj:`data` object. + All values needs to have matching shapes at non-concat dimensions. + """ + out = copy.copy(self) + for store, other_store in zip(out.stores, data.stores): + store.concat(other_store) + if return_tensor == True: + out.tensor() + else: + out.numpy() + return out + + def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: + r"""Returns the dimension for which the value :obj:`value` of the + attribute :obj:`key` will get concatenated when creating mini-batches + using :class:`sharker.loader.DataLoader`. + + .. note:: + + This method is for internal use only, and should only be overridden + in case the mini-batch creation process is corrupted for a specific + attribute. + """ + raise NotImplementedError + + def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: + r"""Returns the incremental count to cumulatively increase the value + :obj:`value` of the attribute :obj:`key` when creating mini-batches + using :class:`sharker.loader.DataLoader`. + + .. note:: + + This method is for internal use only, and should only be overridden + in case the mini-batch creation process is corrupted for a specific + attribute. + """ + raise NotImplementedError + + ########################################################################### + + def keys(self) -> List[str]: + r"""Returns a list of all graph attribute names.""" + out = [] + for store in self.stores: + out += list(store.keys()) + return list(set(out)) + + def __len__(self) -> int: + r"""Returns the number of graph attributes.""" + return len(self.keys()) + + def __contains__(self, key: str) -> bool: + r"""Returns :obj:`True` if the attribute :obj:`key` is present in the + data. + """ + return key in self.keys() + + def __getstate__(self) -> Dict[str, Any]: + return self.__dict__ + + def __setstate__(self, mapping: Dict[str, Any]): + for key, value in mapping.items(): + self.__dict__[key] = value + + @property + def num_nodes(self) -> Optional[int]: + r"""Returns the number of nodes in the graph. + + .. note:: + The number of nodes in the data object is automatically inferred + in case node-level attributes are present, *e.g.*, :obj:`data.x`. + In some cases, however, a graph may only be given without any + node-level attributes. + :mindgeometric:`MindGeometric` then *guesses* the number of nodes according to + :obj:`edge_index.max().item() + 1`. + However, in case there exists isolated nodes, this number does not + have to be correct which can result in unexpected behavior. + Thus, we recommend to set the number of nodes in your data object + explicitly via :obj:`data.num_nodes = ...`. + You will be given a warning that requests you to do so. + """ + try: + size = sum([v.num_nodes for v in self.node_stores]) + if isinstance(size, Tensor): + size = size.item() + return size + except TypeError: + return None + + @property + def shape(self) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: + r"""Returns the size of the adjacency matrix induced by the graph.""" + shape = (self.num_nodes, self.num_nodes) + return shape + + @property + def num_edges(self) -> int: + r"""Returns the number of edges in the graph. + For undirected graphs, this will return the number of bi-directional + edges, which is double the amount of unique edges. + """ + size = sum([v.num_edges for v in self.edge_stores]) + if isinstance(size, Tensor): + size = size.item() + return size + + def node_attrs(self) -> List[str]: + r"""Returns all node-level tensor attribute names.""" + return list(set(chain(*[s.node_attrs() for s in self.node_stores]))) + + def edge_attrs(self) -> List[str]: + r"""Returns all edge-level tensor attribute names.""" + return list(set(chain(*[s.edge_attrs() for s in self.edge_stores]))) + + @property + def node_offsets(self) -> Dict[str, int]: + out: Dict[str, int] = {} + offset: int = 0 + for store in self.node_stores: + out[store._key] = offset + offset = offset + store.num_nodes + return out + + def generate_ids(self): + r"""Generates and sets :obj:`n_id` and :obj:`e_id` attributes to assign + each node and edge to a continuously ascending and unique ID. + """ + for store in self.node_stores: + store.n_id = np.arange(store.num_nodes) + for store in self.edge_stores: + store.e_id = np.arange(store.num_edges) + + def is_sorted(self, sort_by_row: bool = True) -> bool: + r"""Returns :obj:`True` if edge indices :obj:`edge_index` are sorted. + + Args: + sort_by_row (bool, optional): If set to :obj:`False`, will require + column-wise order/by destination node order of + :obj:`edge_index`. (default: :obj:`True`) + """ + input_graph = copy.copy(self).numpy() + return all([store.is_sorted(sort_by_row) for store in input_graph.edge_stores]) + + def sort(self, sort_by_row: bool = True, return_tensor: bool = True) -> Self: + r"""Sorts edge indices :obj:`edge_index` and their corresponding edge + features. + + Args: + sort_by_row (bool, optional): If set to :obj:`False`, will sort + :obj:`edge_index` in column-wise order/by destination node. + (default: :obj:`True`) + """ + out = copy.copy(self).numpy() + for store in out.edge_stores: + store.sort(sort_by_row) + if return_tensor == True: + out.tensor() + else: + out.numpy() + return out + + def is_coalesced(self) -> bool: + r"""Returns :obj:`True` if edge indices :obj:`edge_index` are sorted + and do not contain duplicate entries. + """ + input_graph = copy.copy(self).numpy() + return all([store.is_coalesced() for store in input_graph.edge_stores]) + + def coalesce(self, return_tensor: bool = True) -> Self: + r"""Sorts and removes duplicated entries from edge indices + :obj:`edge_index`. + """ + out = copy.copy(self).numpy() + for store in out.edge_stores: + store.coalesce() + if return_tensor == True: + out.tensor() + else: + out.numpy() + return out + + def is_sorted_by_time(self) -> bool: + r"""Returns :obj:`True` if :obj:`time` is sorted.""" + input_graph = copy.copy(self).numpy() + return all([store.is_sorted_by_time() for store in input_graph.stores]) + + def sort_by_time(self, return_tensor: bool = True) -> Self: + r"""Sorts data associated with :obj:`time` according to :obj:`time`.""" + out = copy.copy(self).numpy() + for store in out.stores: + store.sort_by_time() + if return_tensor == True: + out.tensor() + else: + out.numpy() + return out + + def snapshot( + self, + start_time: Union[float, int], + end_time: Union[float, int], + return_tensor: bool = True + ) -> Self: + r"""Returns a snapshot of :obj:`data` to only hold events that occurred + in period :obj:`[start_time, end_time]`. + """ + out = copy.copy(self).numpy() + for store in out.stores: + store.snapshot(start_time, end_time) + if return_tensor == True: + out.tensor() + else: + out.numpy() + return out + + def up_to(self, end_time: Union[float, int], return_tensor: bool = True) -> Self: + r"""Returns a snapshot of :obj:`data` to only hold events that occurred + up to :obj:`end_time` (inclusive of :obj:`edge_time`). + """ + out = copy.copy(self).numpy() + for store in out.stores: + store.up_to(end_time) + if return_tensor == True: + out.tensor() + else: + out.numpy() + return out + + def has_isolated_nodes(self) -> bool: + r"""Returns :obj:`True` if the graph contains isolated nodes.""" + return any([store.has_isolated_nodes() for store in self.edge_stores]) + + def has_self_loops(self) -> bool: + """Returns :obj:`True` if the graph contains self-loops.""" + return any([store.has_self_loops() for store in self.edge_stores]) + + def is_undirected(self) -> bool: + r"""Returns :obj:`True` if graph edges are undirected.""" + return all([store.is_undirected() for store in self.edge_stores]) + + def is_directed(self) -> bool: + r"""Returns :obj:`True` if graph edges are directed.""" + return not self.is_undirected() + + def apply_(self, func: Callable, *args: str): + r"""Applies the in-place function :obj:`func`, either to all attributes + or only the ones given in :obj:`*args`. + """ + for store in self.stores: + store.apply_(func, *args) + return self + + def apply(self, func: Callable, *args: str): + r"""Applies the function :obj:`func`, either to all attributes or only + the ones given in :obj:`*args`. + """ + for store in self.stores: + store.apply(func, *args) + return self + + def numpy(self, *args: str): + r"""Copies attributes to CPU memory, either for all attributes or only + the ones given in :obj:`*args`. + """ + return self.apply(lambda x: x.asnumpy() if isinstance(x, Tensor) else x, *args) + + def tensor(self, *args: str): + r"""Copies attributes to CPU memory, either for all attributes or only + the ones given in :obj:`*args`. + """ + return self.apply( + lambda x: Tensor.from_numpy(x) if isinstance(x, np.ndarray) else x, *args + ) + + def copy(self, *args: str): + r"""Performs cloning of tensors, either for all attributes or only the + ones given in :obj:`*args`. + """ + return copy.copy(self).apply(lambda x: x.copy(), *args) +############################################################################### + +@ms.jit_class +class Graph(Data): + r"""A graph object describing a homogeneous graph. + The data object can hold node-level, link-level and graph-level attributes. + In general, :class:`~sharker.data.Graph` tries to mimic the + behavior of a regular :python:`Python` dictionary. + In addition, it provides useful functionality for analyzing graph + structures, and provides basic MindSpore tensor functionalities. + See `here `__ for the accompanying + tutorial. + + .. code-block:: python + + from mindscience.sharker.data import Graph + + data = Graph(x=x, edge_index=edge_index, ...) + + # Add additional arguments to `data`: + data.train_idx = Tensor([...], dtype=ms.int64) + data.test_mask = Tensor([...], dtype=ms.bool_) + + # Analyzing the graph structure: + data.num_nodes + >>> 23 + + data.is_directed() + >>> False + + # MindSpore tensor functionality: + data = data.pin_memory() + + Args: + x (Tensor, optional): Node feature matrix with shape + :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Graph connectivity in COO format + with shape :obj:`[2, num_edges]`. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. (default: :obj:`None`) + y (Tensor, optional): Graph-level or node-level ground-truth + labels with arbitrary shape. (default: :obj:`None`) + crd (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + time (Tensor, optional): The timestamps for each event with shape + :obj:`[num_edges]` or :obj:`[num_nodes]`. (default: :obj:`None`) + **kwargs (optional): Additional attributes. + """ + + def __init__( + self, + x: Optional[Tensor] = None, + edge_index: Optional[Tensor] = None, + edge_attr: Optional[Tensor] = None, + y: Optional[Union[Tensor, int, float]] = None, + crd: Optional[Tensor] = None, + time: Optional[Tensor] = None, + file_name: Optional[str] = None, + node_num: Optional[int] = None, + **kwargs, + ): + self.__dict__["_store"] = GlobalStorage(_parent=self) + + if x is not None: + self.x = x + if edge_index is not None: + self.edge_index = edge_index + if edge_attr is not None: + self.edge_attr = edge_attr + if y is not None: + self.y = y + if crd is not None: + self.crd = crd + if time is not None: + self.time = time + if file_name is not None: + self.file_name = file_name + if node_num is not None: + self.node_num = node_num + + for key, value in kwargs.items(): + setattr(self, key, value) + + def __getattr__(self, key: str) -> Any: + if "_store" not in self.__dict__: + raise RuntimeError( + "The 'data' object was created by an older version of MindeGometric. " + "If this error occurred while loading an already existing " + "dataset, remove the 'processed/' directory in the dataset's " + "root folder and try again." + ) + return getattr(self._store, key) + + def __setattr__(self, key: str, value: Any): + propobj = getattr(self.__class__, key, None) + if propobj is not None and getattr(propobj, "fset", None) is not None: + propobj.fset(self, value) + else: + setattr(self._store, key, value) + + def __delattr__(self, key: str): + delattr(self._store, key) + + def __getitem__(self, key: str) -> Any: + return self._store[key] + + def __setitem__(self, key: str, value: Any): + self._store[key] = value + + def __delitem__(self, key: str): + if key in self._store: + del self._store[key] + + def __copy__(self): + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = value + out.__dict__["_store"] = copy.copy(self._store) + out._store._parent = out + return out + + def __deepcopy__(self, memo): + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = copy.deepcopy(value, memo) + out._store._parent = out + return out + + def __repr__(self) -> str: + cls = self.__class__.__name__ + has_dict = any([isinstance(v, Mapping) for v in self._store.values()]) + + if not has_dict: + info = [size_repr(k, v) for k, v in self._store.items()] + info = ", ".join(info) + return f"{cls}({info})" + else: + info = [size_repr(k, v, indent=2) for k, v in self._store.items()] + info = ",\n".join(info) + return f"{cls}(\n{info}\n)" + + @property + def num_nodes(self) -> Optional[int]: + return super().num_nodes + + @num_nodes.setter + def num_nodes(self, num_nodes: Optional[int]): + self._store.num_nodes = num_nodes + + def stores_as(self, data: Self): + return self + + @property + def stores(self) -> List[BaseStorage]: + return [self._store] + + @property + def node_stores(self) -> List[NodeStorage]: + return [self._store] + + @property + def edge_stores(self) -> List[EdgeStorage]: + return [self._store] + + def to_dict(self) -> Dict[str, Any]: + return self._store.to_dict() + + def to_namedtuple(self) -> NamedTuple: + return self._store.to_namedtuple() + + def update(self, data: Union[Self, Dict[str, Any]]) -> Self: + for key, value in data.items(): + self[key] = value + return self + + def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: + if "adj" in key: + return (0, 1) + elif "index" in key or key == "face": + return -1 + else: + return 0 + + def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: + if "batch" in key and isinstance(value, (Tensor, np.ndarray)): + return int(value.max()) + 1 + elif "index" in key or key == "face": + return self.num_nodes + else: + return 0 + + def validate(self, raise_on_error: bool = True) -> bool: + r"""Validates the correctness of the data.""" + cls_name = self.__class__.__name__ + status = True + + num_nodes = self.num_nodes + if num_nodes is None: + status = False + warn_or_raise(f"'num_nodes' is undefined in '{cls_name}'", raise_on_error) + + if "edge_index" in self: + edge_index_np = self.edge_index.asnumpy() if isinstance(self.edge_index, ms.Tensor) else self.edge_index + if edge_index_np.ndim != 2 or edge_index_np.shape[0] != 2: + status = False + warn_or_raise( + f"'edge_index' needs to be of shape [2, num_edges] in " + f"'{cls_name}' (found {edge_index_np.shape})", + raise_on_error, + ) + + if "edge_index" in self and self.edge_index.size > 0: + if np.max(edge_index_np) < 0: + status = False + warn_or_raise( + f"'edge_index' contains negative indices in " + f"'{cls_name}' (found {int(edge_index_np.min())})", + raise_on_error, + ) + + if np.max(edge_index_np) >= num_nodes: + status = False + warn_or_raise( + f"'edge_index' contains larger indices than the number " + f"of nodes ({num_nodes}) in '{cls_name}' " + f"(found {int(edge_index_np.max())})", + raise_on_error, + ) + + return status + + def is_node_attr(self, key: str) -> bool: + r"""Returns :obj:`True` if the object at key :obj:`key` denotes a + node-level tensor attribute. + """ + return self._store.is_node_attr(key) + + def is_edge_attr(self, key: str) -> bool: + r"""Returns :obj:`True` if the object at key :obj:`key` denotes an + edge-level tensor attribute. + """ + return self._store.is_edge_attr(key) + + def subgraph(self, subset: Union[Tensor, np.ndarray], return_tensor: bool = True) -> Self: + r"""Returns the induced subgraph given by the node indices + :obj:`subset`. + + Args: + subset (LongTensor or BoolTensor): The nodes to keep. + """ + if isinstance(subset, Tensor): + subset = subset.asnumpy() + input_graph = copy.copy(self).numpy() + + if "edge_index" in input_graph: + edge_index, _, edge_mask = subgraph( + subset, + input_graph.edge_index, + relabel_nodes=True, + num_nodes=input_graph.num_nodes, + return_edge_mask=True, + ) + else: + edge_index = None + edge_mask = np.ones(input_graph.num_edges, dtype=np.bool_) + + data = copy.copy(input_graph) + + for key, value in input_graph: + if key == "edge_index": + data.edge_index = edge_index + elif key == "num_nodes": + if subset.dtype == np.bool_: + data.num_nodes = int(subset.sum()) + else: + data.num_nodes = subset.shape[0] + elif input_graph.is_node_attr(key): + cat_dim = input_graph.__cat_dim__(key, value) + data[key] = select(value, subset, axis=cat_dim) + elif input_graph.is_edge_attr(key): + cat_dim = input_graph.__cat_dim__(key, value) + data[key] = select(value, edge_mask, axis=cat_dim) + + if return_tensor == True: + data.tensor() + else: + data.numpy() + return data + + def edge_subgraph(self, subset: Union[Tensor, np.ndarray], return_tensor: bool = True) -> Self: + r"""Returns the induced subgraph given by the edge indices + :obj:`subset`. + Will currently preserve all the nodes in the graph, even if they are + isolated after subgraph computation. + + Args: + subset (LongTensor or BoolTensor): The edges to keep. + """ + data = copy.copy(self).numpy() + input_graph = copy.copy(self).numpy() + if isinstance(subset, Tensor): + subset = subset.asnumpy() + + for key, value in input_graph: + if input_graph.is_edge_attr(key): + cat_dim = input_graph.__cat_dim__(key, value) + data[key] = select(value, subset, axis=cat_dim) + if return_tensor == True: + data.tensor() + else: + data.numpy() + return data + + def to_hetero( + self, + node_type: Optional[Tensor] = None, + edge_type: Optional[Tensor] = None, + node_type_names: Optional[List[str]] = None, + edge_type_names: Optional[List[Tuple[str, str, str]]] = None, + return_tensor: bool = True + ): + r"""Converts a :class:`~sharker.data.Graph` object to a + heterogeneous :class:`~sharker.data.HeteroGraph` object. + For this, node and edge attributes are splitted according to the + node-level and edge-level vectors :obj:`node_type` and + :obj:`edge_type`, respectively. + :obj:`node_type_names` and :obj:`edge_type_names` can be used to give + meaningful node and edge type names, respectively. + That is, the node_type :obj:`0` is given by :obj:`node_type_names[0]`. + If the :class:`~sharker.data.Graph` object was constructed via + :meth:`~sharker.data.HeteroGraph.to_homogeneous`, the object can + be reconstructed without any need to pass in additional arguments. + + Args: + node_type (Tensor, optional): A node-level vector denoting + the type of each node. (default: :obj:`None`) + edge_type (Tensor, optional): An edge-level vector denoting + the type of each edge. (default: :obj:`None`) + node_type_names (List[str], optional): The names of node types. + (default: :obj:`None`) + edge_type_names (List[Tuple[str, str, str]], optional): The names + of edge types. (default: :obj:`None`) + """ + from .heterograph import HeteroGraph + + input_graph = copy.copy(self).numpy() + if node_type is not None and isinstance(node_type, Tensor): + node_type = node_type.asnumpy() + if edge_type is not None and isinstance(edge_type, Tensor): + edge_type = edge_type.asnumpy() + + + if node_type is None: + node_type = input_graph._store.get("node_type", None) + if node_type is None: + node_type = np.zeros(input_graph.num_nodes, dtype=np.int64) + + if node_type_names is None: + store = input_graph._store + node_type_names = store.__dict__.get("_node_type_names", None) + if node_type_names is None: + node_type_names = [str(i) for i in np.unique(node_type)] + + if edge_type is None: + edge_type = input_graph._store.get("edge_type", None) + if edge_type is None: + edge_type = np.zeros(input_graph.num_edges, dtype=np.int64) + + if edge_type_names is None: + store = input_graph._store + edge_type_names = store.__dict__.get("_edge_type_names", None) + if edge_type_names is None: + edge_type_names = [] + edge_index = input_graph.edge_index + for i in np.unique(edge_type): + src, dst = edge_index[:, edge_type == i] + src_types = np.unique(node_type[src]) + dst_types = np.unique(node_type[dst]) + if len(src_types) != 1 and len(dst_types) != 1: + raise ValueError( + "Could not construct a 'HeteroGraph' object from the " + "'Graph' object because single edge types span over " + "multiple node types" + ) + edge_type_names.append( + ( + node_type_names[src_types.item(0)], + str(i), + node_type_names[dst_types.item(0)], + ) + ) + + # We iterate over node types to find the local node indices belonging + # to each node type. Furthermore, we create a global `index_map` vector + # that maps global node indices to local ones in the final + # heterogeneous graph: + node_ids, index_map = {}, np.zeros_like(node_type) + for i, key in enumerate(node_type_names): + node_ids[i] = np.nonzero((node_type == i).reshape(-1))[0] + index_map[node_ids[i]] = np.arange(len(node_ids[i])) + + # We iterate over edge types to find the local edge indices: + edge_ids = {} + for i, key in enumerate(edge_type_names): + edge_ids[i] = np.nonzero((edge_type == i).reshape(-1))[0] + + graph = HeteroGraph() + + for i, key in enumerate(node_type_names): + for attr, value in input_graph.items(): + if attr in {"node_type", "edge_type", "ptr"}: + continue + elif isinstance(value, np.ndarray) and input_graph.is_node_attr(attr): + cat_dim = input_graph.__cat_dim__(attr, value) + graph[key][attr] = np.take(value, node_ids[i], cat_dim) + if len(graph[key]) == 0: + graph[key].num_nodes = node_ids[i].shape[0] + + for i, key in enumerate(edge_type_names): + src, _, dst = key + for attr, value in input_graph.items(): + if attr in {"node_type", "edge_type", "ptr"}: + continue + elif attr == "edge_index": + edge_index = value[:, edge_ids[i]] + edge_index[0] = index_map[edge_index[0]] + edge_index[1] = index_map[edge_index[1]] + graph[key].edge_index = edge_index + elif isinstance(value, np.ndarray) and input_graph.is_edge_attr(attr): + cat_dim = input_graph.__cat_dim__(attr, value) + graph[key][attr] = np.take(value, edge_ids[i], cat_dim) + + # Add global attributes. + exclude_keys = set(graph.keys()) | { + "node_type", + "edge_type", + "edge_index", + "num_nodes", + "ptr", + } + for attr, value in input_graph.items(): + if attr in exclude_keys: + continue + graph[attr] = value + + if return_tensor == True: + graph.tensor() + else: + graph.numpy() + return graph + + ########################################################################### + + @classmethod + def from_dict(cls, mapping: Dict[str, Any]) -> Self: + r"""Creates a :class:`~sharker.data.Graph` object from a + dictionary. + """ + return cls(**mapping) + + @property + def num_node_features(self) -> int: + r"""Returns the number of features per node in the graph.""" + return self._store.num_node_features + + @property + def num_features(self) -> int: + r"""Returns the number of features per node in the graph. + Alias for :py:attr:`~num_node_features`. + """ + return self.num_node_features + + @property + def num_edge_features(self) -> int: + r"""Returns the number of features per edge in the graph.""" + return self._store.num_edge_features + + @property + def num_node_types(self) -> int: + r"""Returns the number of node types in the graph.""" + return int(self.node_type.max()) + 1 if "node_type" in self else 1 + + @property + def num_edge_types(self) -> int: + r"""Returns the number of edge types in the graph.""" + return int(self.edge_type.max()) + 1 if "edge_type" in self else 1 + + def __iter__(self) -> Iterable: + r"""Iterates over all attributes in the data, yielding their attribute + names and values. + """ + for key, value in self._store.items(): + yield key, value + + @property + def x(self) -> Optional[Tensor]: + return self["x"] if "x" in self._store else None + + @x.setter + def x(self, x: Optional[Tensor]): + self._store.x = x + + @property + def edge_index(self) -> Optional[Tensor]: + return self["edge_index"] if "edge_index" in self._store else None + + @edge_index.setter + def edge_index(self, edge_index: Optional[Tensor]): + self._store.edge_index = edge_index + + @property + def edge_weight(self) -> Optional[Tensor]: + return self["edge_weight"] if "edge_weight" in self._store else None + + @edge_weight.setter + def edge_weight(self, edge_weight: Optional[Tensor]): + self._store.edge_weight = edge_weight + + @property + def edge_attr(self) -> Optional[Tensor]: + return self["edge_attr"] if "edge_attr" in self._store else None + + @edge_attr.setter + def edge_attr(self, edge_attr: Optional[Tensor]): + self._store.edge_attr = edge_attr + + @property + def y(self) -> Optional[Union[Tensor, int, float]]: + return self["y"] if "y" in self._store else None + + @y.setter + def y(self, y: Optional[Tensor]): + self._store.y = y + + @property + def crd(self) -> Optional[Tensor]: + return self["crd"] if "crd" in self._store else None + + @crd.setter + def crd(self, crd: Optional[Tensor]): + self._store.crd = crd + + @property + def batch(self) -> Optional[Tensor]: + return self["batch"] if "batch" in self._store else None + + @batch.setter + def batch(self, batch: Optional[Tensor]): + self._store.batch = batch + + @property + def time(self) -> Optional[Tensor]: + return self["time"] if "time" in self._store else None + + @time.setter + def time(self, time: Optional[Tensor]): + self._store.time = time + + @property + def face(self) -> Optional[Tensor]: + return self["face"] if "face" in self._store else None + + @face.setter + def face(self, face: Optional[Tensor]): + self._store.face = face + + +############################################################################### + + +def size_repr(key: Any, value: Any, indent: int = 0) -> str: + pad = " " * indent + if isinstance(value, Tensor) and value.dim() == 0: + out = value.item() + elif isinstance(value, Tensor) and getattr(value, "is_nested", False): + out = str(list(value.to_padded_tensor(padding=0.0).shape)) + elif isinstance(value, Tensor): + out = str(list(value.shape)) + elif isinstance(value, np.ndarray): + out = str(list(value.shape)) + elif isinstance(value, str): + out = f"'{value}'" + elif isinstance(value, Sequence): + out = str([len(value)]) + elif isinstance(value, Mapping) and len(value) == 0: + out = "{}" + elif ( + isinstance(value, Mapping) + and len(value) == 1 + and not isinstance(list(value.values())[0], Mapping) + ): + lines = [size_repr(k, v, 0) for k, v in value.items()] + out = "{ " + ", ".join(lines) + " }" + elif isinstance(value, Mapping): + lines = [size_repr(k, v, indent + 2) for k, v in value.items()] + out = "{\n" + ",\n".join(lines) + ",\n" + pad + "}" + else: + out = str(value) + + key = str(key).replace("'", "") + return f"{pad}{key}={out}" + + +def warn_or_raise(msg: str, raise_on_error: bool = True): + if raise_on_error: + raise ValueError(msg) + else: + warnings.warn(msg) diff --git a/mindscience/sharker/data/heterograph.py b/mindscience/sharker/data/heterograph.py new file mode 100644 index 000000000..90a644cd1 --- /dev/null +++ b/mindscience/sharker/data/heterograph.py @@ -0,0 +1,1023 @@ +import copy +import re +import warnings +from collections import defaultdict, namedtuple +from collections.abc import Mapping +from itertools import chain +from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Union +from typing_extensions import Self +import mindspore as ms +from mindspore import ops, mint +import numpy as np + +from .graph import Graph, size_repr, warn_or_raise +from .storage import BaseStorage, EdgeStorage, NodeStorage + +from ..utils import ( + bipartite_subgraph, + contains_isolated_nodes, + is_undirected, + mask_select, +) + +NodeOrEdgeStorage = Union[NodeStorage, EdgeStorage] +DEFAULT_REL = "to" +EDGE_TYPE_STR_SPLIT = "__" + + +class HeteroGraph(Graph): + r"""A data object describing a heterogeneous graph, holding multiple node + and/or edge types in disjunct storage objects. + Storage objects can hold either node-level, link-level or graph-level + attributes. + In general, :class:`~sharker.data.HeteroGraph` tries to mimic the + behavior of a regular **nested** :python:`Python` dictionary. + In addition, it provides useful functionality for analyzing graph + structures, and provides basic MindSpore tensor functionalities. + + .. code-block:: + + from mindscience.sharker.data import HeteroGraph + + data = HeteroGraph() + + # Create two node types "paper" and "author" holding a feature matrix: + data['paper'].x = ops.randn(num_papers, num_paper_features) + data['author'].x = ops.randn(num_authors, num_authors_features) + + # Create an edge type "(author, writes, paper)" and building the + # graph connectivity: + data['author', 'writes', 'paper'].edge_index = ... # [2, num_edges] + + data['paper'].num_nodes + >>> 23 + + data['author', 'writes', 'paper'].num_edges + >>> 52 + + # MindSpore tensor functionality: + data = data.pin_memory() + + Note that there exists multiple ways to create a heterogeneous graph data, + *e.g.*: + + * To initialize a node of type :obj:`"paper"` holding a node feature + matrix :obj:`x_paper` named :obj:`x`: + + .. code-block:: python + + from mindscience.sharker.data import HeteroGraph + + # (1) Assign attributes after initialization, + data = HeteroGraph() + data['paper'].x = x_paper + + # or (2) pass them as keyword arguments during initialization, + data = HeteroGraph(paper={ 'x': x_paper }) + + # or (3) pass them as dictionaries during initialization, + data = HeteroGraph({'paper': { 'x': x_paper }}) + + * To initialize an edge from source node type :obj:`"author"` to + destination node type :obj:`"paper"` with relation type :obj:`"writes"` + holding a graph connectivity matrix :obj:`edge_index_author_paper` named + :obj:`edge_index`: + + .. code-block:: python + + # (1) Assign attributes after initialization, + data = HeteroGraph() + data['author', 'writes', 'paper'].edge_index = edge_index_author_paper + + # or (2) pass them as keyword arguments during initialization, + data = HeteroGraph(author__writes__paper={ + 'edge_index': edge_index_author_paper + }) + + # or (3) pass them as dictionaries during initialization, + data = HeteroGraph({ + ('author', 'writes', 'paper'): + { 'edge_index': edge_index_author_paper } + }) + """ + + def __init__(self, _mapping: Optional[Dict[str, Any]] = None, **kwargs): + super().__init__() + + self.__dict__["_store"] = BaseStorage(_parent=self) + self.__dict__["_node_store_dict"] = {} + self.__dict__["_edge_store_dict"] = {} + + for key, value in chain((_mapping or {}).items(), kwargs.items()): + if "__" in key and isinstance(value, Mapping): + key = tuple(key.split("__")) + + if isinstance(value, Mapping): + self[key].update(value) + else: + setattr(self, key, value) + + @classmethod + def from_dict(cls, mapping: Dict[str, Any]) -> Self: + r"""Creates a :class:`~sharker.data.HeteroGraph` object from a + dictionary. + """ + out = cls() + for key, value in mapping.items(): + if key == "_store": + out.__dict__["_store"] = BaseStorage(_parent=out, **value) + elif isinstance(key, str): + out._node_store_dict[key] = NodeStorage(_parent=out, _key=key, **value) + else: + out._edge_store_dict[key] = EdgeStorage(_parent=out, _key=key, **value) + return out + + def __getattr__(self, key: str) -> Any: + # `data.*_dict` => Link to node and edge stores. + # `data.*` => Link to the `_store`. + # Using `data.*_dict` is the same as using `collect()` for collecting + # nodes and edges features. + if hasattr(self._store, key): + return getattr(self._store, key) + elif bool(re.search("_dict$", key)): + return self.collect(key[:-5]) + raise AttributeError( + f"'{self.__class__.__name__}' has no " f"attribute '{key}'" + ) + + def __setattr__(self, key: str, value: Any): + # NOTE: We aim to prevent duplicates in node or edge types. + if key in self.node_types: + raise AttributeError(f"'{key}' is already present as a node type") + elif key in self.edge_types: + raise AttributeError(f"'{key}' is already present as an edge type") + setattr(self._store, key, value) + + def __delattr__(self, key: str): + delattr(self._store, key) + + def __getitem__( + self, *args: Union[str, Tuple[str, str], Tuple[str, str, str]] + ) -> Any: + # `data[*]` => Link to either `_store`, _node_store_dict` or + # `_edge_store_dict`. + # If neither is present, we create a new `Storage` object for the given + # node/edge-type. + key = self._to_canonical(*args) + + out = self._store.get(key, None) + if out is not None: + return out + + if isinstance(key, tuple): + return self.get_edge_store(*key) + else: + return self.get_node_store(key) + + def __setitem__(self, key: str, value: Any): + if key in self.node_types: + raise AttributeError(f"'{key}' is already present as a node type") + elif key in self.edge_types: + raise AttributeError(f"'{key}' is already present as an edge type") + self._store[key] = value + + def __delitem__(self, *args: Union[str, Tuple[str, str], Tuple[str, str, str]]): + key = self._to_canonical(*args) + if key in self.edge_types: + del self._edge_store_dict[key] + elif key in self.node_types: + del self._node_store_dict[key] + else: + del self._store[key] + + def __copy__(self): + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = value + out.__dict__["_store"] = copy.copy(self._store) + out._store._parent = out + out.__dict__["_node_store_dict"] = {} + for key, store in self._node_store_dict.items(): + out._node_store_dict[key] = copy.copy(store) + out._node_store_dict[key]._parent = out + out.__dict__["_edge_store_dict"] = {} + for key, store in self._edge_store_dict.items(): + out._edge_store_dict[key] = copy.copy(store) + out._edge_store_dict[key]._parent = out + return out + + def __deepcopy__(self, memo): + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = copy.deepcopy(value, memo) + out._store._parent = out + for key in self._node_store_dict.keys(): + out._node_store_dict[key]._parent = out + for key in out._edge_store_dict.keys(): + out._edge_store_dict[key]._parent = out + return out + + def __repr__(self) -> str: + info1 = [size_repr(k, v, 2) for k, v in self._store.items()] + info2 = [size_repr(k, v, 2) for k, v in self._node_store_dict.items()] + info3 = [size_repr(k, v, 2) for k, v in self._edge_store_dict.items()] + info = ",\n".join(info1 + info2 + info3) + info = f"\n{info}\n" if len(info) > 0 else info + return f"{self.__class__.__name__}({info})" + + def stores_as(self, data: Self): + for node_type in data.node_types: + self.get_node_store(node_type) + for edge_type in data.edge_types: + self.get_edge_store(*edge_type) + return self + + @property + def stores(self) -> List[BaseStorage]: + r"""Returns a list of all storages of the graph.""" + return [self._store] + list(self.node_stores) + list(self.edge_stores) + + @property + def node_types(self) -> List[str]: + r"""Returns a list of all node types of the graph.""" + return list(self._node_store_dict.keys()) + + @property + def node_stores(self) -> List[NodeStorage]: + r"""Returns a list of all node storages of the graph.""" + return list(self._node_store_dict.values()) + + @property + def edge_types(self) -> List[Tuple[str, str, str]]: + r"""Returns a list of all edge types of the graph.""" + return list(self._edge_store_dict.keys()) + + @property + def edge_stores(self) -> List[EdgeStorage]: + r"""Returns a list of all edge storages of the graph.""" + return list(self._edge_store_dict.values()) + + def node_items(self) -> List[Tuple[str, NodeStorage]]: + r"""Returns a list of node type and node storage pairs.""" + return list(self._node_store_dict.items()) + + def edge_items(self) -> List[Tuple[Tuple[str, str, str], EdgeStorage]]: + r"""Returns a list of edge type and edge storage pairs.""" + return list(self._edge_store_dict.items()) + + def to_dict(self) -> Dict[str, Any]: + out_dict: Dict[str, Any] = {} + out_dict["_store"] = self._store.to_dict() + for key, store in chain( + self._node_store_dict.items(), self._edge_store_dict.items() + ): + out_dict[key] = store.to_dict() + return out_dict + + def to_namedtuple(self) -> NamedTuple: + field_names = list(self._store.keys()) + field_values = list(self._store.values()) + field_names += [ + "__".join(key) if isinstance(key, tuple) else key + for key in self.node_types + self.edge_types + ] + field_values += [ + store.to_namedtuple() for store in self.node_stores + self.edge_stores + ] + DataTuple = namedtuple("DataTuple", field_names) + return DataTuple(*field_values) + + def set_value_dict( + self, + key: str, + value_dict: Dict[str, Any], + ) -> Self: + r"""Sets the values in the dictionary :obj:`value_dict` to the + attribute with name :obj:`key` to all node/edge types present in the + dictionary. + + .. code-block:: python + + data = HeteroGraph() + + data.set_value_dict('x', { + 'paper': ops.randn(4, 16), + 'author': ops.randn(8, 32), + }) + + print(data['paper'].x) + """ + for k, v in (value_dict or {}).items(): + self[k][key] = v + return self + + def update(self, data: Self) -> Self: + for store in data.stores: + for key, value in store.items(): + self[store._key][key] = value + return self + + def __cat_dim__( + self, + key: str, + value: Any, + store: Optional[NodeOrEdgeStorage] = None, + *args, + **kwargs, + ) -> Any: + if isinstance(store, EdgeStorage) and "index" in key: + return -1 + return 0 + + def __inc__( + self, + key: str, + value: Any, + store: Optional[NodeOrEdgeStorage] = None, + *args, + **kwargs, + ) -> Any: + if "batch" in key and isinstance(value, (ms.Tensor, np.array)): + return int(value.max()) + 1 + elif isinstance(store, EdgeStorage) and "index" in key: + if isinstance(value, ms.Tensor): + return ms.Tensor(store.shape).view(2, 1) + elif isinstance(value, np.ndarray): + return np.array(store.shape).reshape(2, 1) + else: + return 0 + + @property + def num_nodes(self) -> Optional[int]: + r"""Returns the number of nodes in the graph.""" + return super().num_nodes + + @property + def num_node_features(self) -> Dict[str, int]: + r"""Returns the number of features per node type in the graph.""" + return { + key: store.num_node_features for key, store in self._node_store_dict.items() + } + + @property + def num_features(self) -> Dict[str, int]: + r"""Returns the number of features per node type in the graph. + Alias for :py:attr:`~num_node_features`. + """ + return self.num_node_features + + @property + def num_edge_features(self) -> Dict[Tuple[str, str, str], int]: + r"""Returns the number of features per edge type in the graph.""" + return { + key: store.num_edge_features for key, store in self._edge_store_dict.items() + } + + def has_isolated_nodes(self) -> bool: + r"""Returns :obj:`True` if the graph contains isolated nodes.""" + edge_index, _, _ = to_homogeneous_edge_index(self) + return contains_isolated_nodes(edge_index, num_nodes=self.num_nodes) + + def is_undirected(self) -> bool: + r"""Returns :obj:`True` if graph edges are undirected.""" + edge_index, _, _ = to_homogeneous_edge_index(self) + return is_undirected(edge_index, num_nodes=self.num_nodes) + + def validate(self, raise_on_error: bool = True) -> bool: + r"""Validates the correctness of the data.""" + cls_name = self.__class__.__name__ + status = True + + node_types = set(self.node_types) + num_src_node_types = {src for src, _, _ in self.edge_types} + num_dst_node_types = {dst for _, _, dst in self.edge_types} + + dangling_types = (num_src_node_types | num_dst_node_types) - node_types + if len(dangling_types) > 0: + status = False + warn_or_raise( + f"The node types {dangling_types} are referenced in edge " + f"types but do not exist as node types", + raise_on_error, + ) + + dangling_types = node_types - (num_src_node_types | num_dst_node_types) + if len(dangling_types) > 0: + warn_or_raise( + f"The node types {dangling_types} are isolated and are not " + f"referenced by any edge type ", + raise_on_error=False, + ) + + for edge_type, store in self._edge_store_dict.items(): + src, _, dst = edge_type + + num_src_nodes = self[src].num_nodes + num_dst_nodes = self[dst].num_nodes + if num_src_nodes is None: + status = False + warn_or_raise( + f"'num_nodes' is undefined in node type '{src}' of " + f"'{cls_name}'", + raise_on_error, + ) + + if num_dst_nodes is None: + status = False + warn_or_raise( + f"'num_nodes' is undefined in node type '{dst}' of " + f"'{cls_name}'", + raise_on_error, + ) + + if "edge_index" in store: + if store.edge_index.dim() != 2 or store.edge_index.shape[0] != 2: + status = False + warn_or_raise( + f"'edge_index' of edge type {edge_type} needs to be " + f"of shape [2, num_edges] in '{cls_name}' (found " + f"{store.edge_index.shape})", + raise_on_error, + ) + + if "edge_index" in store and store.edge_index.numel() > 0: + if store.edge_index.min() < 0: + status = False + warn_or_raise( + f"'edge_index' of edge type {edge_type} contains " + f"negative indices in '{cls_name}' " + f"(found {int(store.edge_index.min())})", + raise_on_error, + ) + + if ( + num_src_nodes is not None + and store.edge_index[0].max() >= num_src_nodes + ): + status = False + warn_or_raise( + f"'edge_index' of edge type {edge_type} contains " + f"larger source indices than the number of nodes " + f"({num_src_nodes}) of this node type in '{cls_name}' " + f"(found {int(store.edge_index[0].max())})", + raise_on_error, + ) + + if ( + num_dst_nodes is not None + and store.edge_index[1].max() >= num_dst_nodes + ): + status = False + warn_or_raise( + f"'edge_index' of edge type {edge_type} contains " + f"larger destination indices than the number of nodes " + f"({num_dst_nodes}) of this node type in '{cls_name}' " + f"(found {int(store.edge_index[1].max())})", + raise_on_error, + ) + + return status + + def debug(self): + pass # TODO + + ########################################################################### + + def _to_canonical( + self, *args: Union[str, Tuple[str, str], Tuple[str, str, str]] + ) -> Union[str, Tuple[str, str, str]]: + # Converts a given `QueryType` to its "canonical type": + # 1. `relation_type` will get mapped to the unique + # `(src_node_type, relation_type, dst_node_type)` tuple. + # 2. `(src_node_type, dst_node_type)` will get mapped to the unique + # `(src_node_type, *, dst_node_type)` tuple, and + # `(src_node_type, 'to', dst_node_type)` otherwise. + if len(args) == 1: + args = args[0] + + if isinstance(args, str): + node_types = [key for key in self.node_types if key == args] + if len(node_types) == 1: + args = node_types[0] + return args + + # Try to map to edge type based on unique relation type: + edge_types = [key for key in self.edge_types if key[1] == args] + if len(edge_types) == 1: + args = edge_types[0] + return args + + elif len(args) == 2: + # Try to find the unique source/destination node tuple: + edge_types = [ + key + for key in self.edge_types + if key[0] == args[0] and key[-1] == args[-1] + ] + if len(edge_types) == 1: + args = edge_types[0] + return args + elif len(edge_types) == 0: + args = (args[0], DEFAULT_REL, args[1]) + return args + + return args + + def metadata(self) -> Tuple[List[str], List[Tuple[str, str, str]]]: + r"""Returns the heterogeneous meta-data, *i.e.* its node and edge + types. + + .. code-block:: python + + data = HeteroGraph() + data['paper'].x = ... + data['author'].x = ... + data['author', 'writes', 'paper'].edge_index = ... + + print(data.metadata()) + >>> (['paper', 'author'], [('author', 'writes', 'paper')]) + """ + return self.node_types, self.edge_types + + def collect( + self, + key: str, + allow_empty: bool = False, + ) -> Dict[Union[str, Tuple[str, str, str]], Any]: + r"""Collects the attribute :attr:`key` from all node and edge types. + + .. code-block:: python + + data = HeteroGraph() + data['paper'].x = ... + data['author'].x = ... + + print(data.collect('x')) + >>> { 'paper': ..., 'author': ...} + + .. note:: + + This is equivalent to writing :obj:`data.x_dict`. + + Args: + key (str): The attribute to collect from all node and ege types. + allow_empty (bool, optional): If set to :obj:`True`, will not raise + an error in case the attribute does not exit in any node or + edge type. (default: :obj:`False`) + """ + mapping = {} + for subtype, store in chain( + self._node_store_dict.items(), self._edge_store_dict.items() + ): + if hasattr(store, key): + mapping[subtype] = getattr(store, key) + if not allow_empty and len(mapping) == 0: + raise KeyError( + f"Tried to collect '{key}' but did not find any " + f"occurrences of it in any node and/or edge type" + ) + return mapping + + def _check_type_name(self, name: str): + if "__" in name: + warnings.warn( + f"The type '{name}' contains double underscores " + f"('__') which may lead to unexpected behavior. " + f"To avoid any issues, ensure that your type names " + f"only contain single underscores." + ) + + def get_node_store(self, key: str) -> NodeStorage: + r"""Gets the :class:`~sharker.data.storage.NodeStorage` object + of a particular node type :attr:`key`. + If the storage is not present yet, will create a new + :class:`sharker.data.storage.NodeStorage` object for the given + node type. + + .. code-block:: python + + data = HeteroGraph() + node_storage = data.get_node_store('paper') + """ + out = self._node_store_dict.get(key, None) + if out is None: + self._check_type_name(key) + out = NodeStorage(_parent=self, _key=key) + self._node_store_dict[key] = out + return out + + def get_edge_store(self, src: str, rel: str, dst: str) -> EdgeStorage: + r"""Gets the :class:`~sharker.data.storage.EdgeStorage` object + of a particular edge type given by the tuple :obj:`(src, rel, dst)`. + If the storage is not present yet, will create a new + :class:`sharker.data.storage.EdgeStorage` object for the given + edge type. + + .. code-block:: python + + data = HeteroGraph() + edge_storage = data.get_edge_store('author', 'writes', 'paper') + """ + key = (src, rel, dst) + out = self._edge_store_dict.get(key, None) + if out is None: + self._check_type_name(rel) + out = EdgeStorage(_parent=self, _key=key) + self._edge_store_dict[key] = out + return out + + def rename(self, name: str, new_name: str) -> Self: + r"""Renames the node type :obj:`name` to :obj:`new_name` in-place.""" + node_store = self._node_store_dict.pop(name) + node_store._key = new_name + self._node_store_dict[new_name] = node_store + + for edge_type in self.edge_types: + src, rel, dst = edge_type + if src == name or dst == name: + edge_store = self._edge_store_dict.pop(edge_type) + src = new_name if src == name else src + dst = new_name if dst == name else dst + edge_type = (src, rel, dst) + edge_store._key = edge_type + self._edge_store_dict[edge_type] = edge_store + + return self + + def subgraph(self, subset_dict: Dict[str, ms.Tensor]) -> Self: + r"""Returns the induced subgraph containing the node types and + corresponding nodes in :obj:`subset_dict`. + + If a node type is not a key in :obj:`subset_dict` then all nodes of + that type remain in the graph. + + .. code-block:: python + + data = HeteroGraph() + data['paper'].x = ... + data['author'].x = ... + data['conference'].x = ... + data['paper', 'cites', 'paper'].edge_index = ... + data['author', 'paper'].edge_index = ... + data['paper', 'conference'].edge_index = ... + print(data) + >>> HeteroGraph( + paper={ x=[10, 16] }, + author={ x=[5, 32] }, + conference={ x=[5, 8] }, + (paper, cites, paper)={ edge_index=[2, 50] }, + (author, to, paper)={ edge_index=[2, 30] }, + (paper, to, conference)={ edge_index=[2, 25] } + ) + + subset_dict = { + 'paper': Tensor, 5, 6]), + 'author': Tensor([0, 2]), + } + + print(data.subgraph(subset_dict)) + >>> HeteroGraph( + paper={ x=[4, 16] }, + author={ x=[2, 32] }, + conference={ x=[5, 8] }, + (paper, cites, paper)={ edge_index=[2, 24] }, + (author, to, paper)={ edge_index=[2, 5] }, + (paper, to, conference)={ edge_index=[2, 10] } + ) + + Args: + subset_dict (Dict[str, LongTensor or BoolTensor]): A dictionary + holding the nodes to keep for each node type. + """ + data = copy.copy(self) + subset_dict = copy.copy(subset_dict) + + for node_type, subset in subset_dict.items(): + for key, value in self[node_type].items(): + if key == "num_nodes": + if subset.dtype == ms.bool_: + data[node_type].num_nodes = int(subset.sum()) + else: + data[node_type].num_nodes = subset.shape[0] + elif self[node_type].is_node_attr(key): + data[node_type][key] = value[subset] + else: + data[node_type][key] = value + + for edge_type in self.edge_types: + if "edge_index" not in self[edge_type]: + continue + + src, _, dst = edge_type + + src_subset = subset_dict.get(src) + if src_subset is None: + src_subset = mint.arange(data[src].num_nodes) + dst_subset = subset_dict.get(dst) + if dst_subset is None: + dst_subset = mint.arange(data[dst].num_nodes) + + edge_index, _, edge_mask = bipartite_subgraph( + (src_subset, dst_subset), + self[edge_type].edge_index, + relabel_nodes=True, + size=(self[src].num_nodes, self[dst].num_nodes), + return_edge_mask=True, + ) + + for key, value in self[edge_type].items(): + if key == "edge_index": + data[edge_type].edge_index = edge_index + elif self[edge_type].is_edge_attr(key): + data[edge_type][key] = value[edge_mask] + else: + data[edge_type][key] = value + + return data + + def edge_subgraph( + self, + subset_dict: Dict[Tuple[str, str, str], ms.Tensor], + ) -> Self: + r"""Returns the induced subgraph given by the edge indices in + :obj:`subset_dict` for certain edge types. + Will currently preserve all the nodes in the graph, even if they are + isolated after subgraph computation. + + Args: + subset_dict (Dict[Tuple[str, str, str], LongTensor or BoolTensor]): + A dictionary holding the edges to keep for each edge type. + """ + data = copy.copy(self) + + for edge_type, subset in subset_dict.items(): + edge_store, new_edge_store = self[edge_type], data[edge_type] + for key, value in edge_store.items(): + if edge_store.is_edge_attr(key): + dim = self.__cat_dim__(key, value, edge_store) + if subset.dtype == ms.bool_: + new_edge_store[key] = mask_select(value, dim, subset) + else: + new_edge_store[key] = value.index_select(dim, subset) + + return data + + def node_type_subgraph(self, node_types: List[str]) -> Self: + r"""Returns the subgraph induced by the given :obj:`node_types`, *i.e.* + the returned :class:`HeteroGraph` object only contains the node types + which are included in :obj:`node_types`, and only contains the edge + types where both end points are included in :obj:`node_types`. + """ + data = copy.copy(self) + for edge_type in self.edge_types: + src, _, dst = edge_type + if src not in node_types or dst not in node_types: + del data[edge_type] + for node_type in self.node_types: + if node_type not in node_types: + del data[node_type] + return data + + def edge_type_subgraph(self, edge_types: List[str]) -> Self: + r"""Returns the subgraph induced by the given :obj:`edge_types`, *i.e.* + the returned :class:`HeteroGraph` object only contains the edge types + which are included in :obj:`edge_types`, and only contains the node + types of the end points which are included in :obj:`node_types`. + """ + edge_types = [self._to_canonical(e) for e in edge_types] + + data = copy.copy(self) + for edge_type in self.edge_types: + if edge_type not in edge_types: + del data[edge_type] + node_types = set(e[0] for e in edge_types) + node_types |= set(e[-1] for e in edge_types) + for node_type in self.node_types: + if node_type not in node_types: + del data[node_type] + return data + + def to_homogeneous( + self, + node_attrs: Optional[List[str]] = None, + edge_attrs: Optional[List[str]] = None, + add_node_type: bool = True, + add_edge_type: bool = True, + dummy_values: bool = True, + ) -> Graph: + """Converts a :class:`~sharker.data.HeteroGraph` object to a + homogeneous :class:`~sharker.data.Graph` object. + By default, all features with same feature dimensionality across + different types will be merged into a single representation, unless + otherwise specified via the :obj:`node_attrs` and :obj:`edge_attrs` + arguments. + Furthermore, attributes named :obj:`node_type` and :obj:`edge_type` + will be added to the returned :class:`~sharker.data.Graph` + object, denoting node-level and edge-level vectors holding the + node and edge type as integers, respectively. + + Args: + node_attrs (List[str], optional): The node features to combine + across all node types. These node features need to be of the + same feature dimensionality. If set to :obj:`None`, will + automatically determine which node features to combine. + (default: :obj:`None`) + edge_attrs (List[str], optional): The edge features to combine + across all edge types. These edge features need to be of the + same feature dimensionality. If set to :obj:`None`, will + automatically determine which edge features to combine. + (default: :obj:`None`) + add_node_type (bool, optional): If set to :obj:`False`, will not + add the node-level vector :obj:`node_type` to the returned + :class:`~sharker.data.Graph` object. + (default: :obj:`True`) + add_edge_type (bool, optional): If set to :obj:`False`, will not + add the edge-level vector :obj:`edge_type` to the returned + :class:`~sharker.data.Graph` object. + (default: :obj:`True`) + dummy_values (bool, optional): If set to :obj:`True`, will fill + attributes of remaining types with dummy values. + Dummy values are :obj:`NaN` for floating point attributes, + :obj:`False` for booleans, and :obj:`-1` for integers. + (default: :obj:`True`) + """ + + def get_sizes(stores: List[BaseStorage]) -> Dict[str, List[Tuple]]: + sizes_dict = defaultdict(list) + for store in stores: + for key, value in store.items(): + if key in ["edge_index", "edge_label_index", "adj", "adj_t"]: + continue + if isinstance(value, ms.Tensor): + dim = self.__cat_dim__(key, value, store) + size = value.shape[:dim] + value.shape[dim+1:] + sizes_dict[key].append(tuple(size)) + return sizes_dict + + def fill_dummy_(stores: List[BaseStorage], keys: Optional[List[str]] = None): + sizes_dict = get_sizes(stores) + + if keys is not None: + sizes_dict = { + key: sizes for key, sizes in sizes_dict.items() if key in keys + } + + sizes_dict = { + key: sizes for key, sizes in sizes_dict.items() if len(set(sizes)) == 1 + } + + for store in stores: # Fill stores with dummy features: + for key, sizes in sizes_dict.items(): + if key not in store: + ref = list(self.collect(key).values())[0] + dim = self.__cat_dim__(key, ref, store) + if ref.is_floating_point(): + dummy = float("NaN") + elif ref.dtype == ms.bool_: + dummy = False + else: + dummy = -1 + if isinstance(store, NodeStorage): + dim_size = store.num_nodes + else: + dim_size = store.num_edges + shape = sizes[0][:dim] + (dim_size,) + sizes[0][dim:] + store[key] = ops.full(shape, dummy, dtype=ref.dtype) + + def _consistent_size(stores: List[BaseStorage]) -> List[str]: + sizes_dict = get_sizes(stores) + keys = [] + for key, sizes in sizes_dict.items(): + if len(sizes) != len(stores): + continue + lengths = set([len(size) for size in sizes]) + if len(lengths) != 1: + continue + if len(sizes[0]) != 1 and len(set(sizes)) != 1: + continue + keys.append(key) + return keys + + if dummy_values: + self = copy.copy(self) + fill_dummy_(self.node_stores, node_attrs) + fill_dummy_(self.edge_stores, edge_attrs) + + edge_index, node_slices, edge_slices = to_homogeneous_edge_index(self) + + data = Graph(**self._store.to_dict()) + if edge_index is not None: + data.edge_index = edge_index + data._node_type_names = list(node_slices.keys()) + data._edge_type_names = list(edge_slices.keys()) + + # Combine node attributes into a single tensor: + if node_attrs is None: + node_attrs = _consistent_size(self.node_stores) + for key in node_attrs: + if key in {"ptr"}: + continue + values = [store[key] for store in self.node_stores] + dim = self.__cat_dim__(key, values[0], self.node_stores[0]) + dim = values[0].dim() + dim if dim < 0 else dim + # For two-dimensional features, we allow arbitrary shapes and + # pad them with zeros if necessary in case their size doesn't + # match: + if values[0].dim() == 2 and dim == 0: + _max = max([value.shape[-1] for value in values]) + for i, v in enumerate(values): + if v.shape[-1] < _max: + pad = v.new_zeros([v.shape[0], _max - v.shape[-1]]) + values[i] = mint.cat([v, pad], dim=-1) + value = mint.cat(values, dim=dim) + data[key] = value + + if not data.can_infer_num_nodes: + data.num_nodes = list(node_slices.values())[-1][1] + + # Combine edge attributes into a single tensor: + if edge_attrs is None: + edge_attrs = _consistent_size(self.edge_stores) + for key in edge_attrs: + values = [store[key] for store in self.edge_stores] + dim = self.__cat_dim__(key, values[0], self.edge_stores[0]) + value = mint.cat(values, dim=dim) if len(values) > 1 else values[0] + data[key] = value + + if "edge_label_index" in self: + edge_label_index_dict = self.edge_label_index_dict + for edge_type, edge_label_index in edge_label_index_dict.items(): + edge_label_index = edge_label_index.copy() + edge_label_index[0] += node_slices[edge_type[0]][0] + edge_label_index[1] += node_slices[edge_type[-1]][0] + edge_label_index_dict[edge_type] = edge_label_index + data.edge_label_index = mint.cat( + list(edge_label_index_dict.values()), dim=-1 + ) + + if add_node_type: + sizes = [offset[1] - offset[0] for offset in node_slices.values()] + sizes = ms.Tensor(sizes, dtype=ms.int64) + node_type = mint.arange(len(sizes)) + data.node_type = node_type.repeat_interleave(sizes) + + if add_edge_type and edge_index is not None: + sizes = [offset[1] - offset[0] for offset in edge_slices.values()] + sizes = ms.Tensor(sizes, dtype=ms.int64) + edge_type = mint.arange(len(sizes)) + data.edge_type = edge_type.repeat_interleave(sizes) + + return data + + +# Helper functions ############################################################ + + +def get_node_slices(num_nodes: Dict[str, int]) -> Dict[str, Tuple[int, int]]: + r"""Returns the boundaries of each node type in a graph.""" + node_slices: Dict[str, Tuple[int, int]] = {} + cumsum = 0 + for node_type, N in num_nodes.items(): + node_slices[node_type] = (cumsum, cumsum + N) + cumsum += N + return node_slices + + +def offset_edge_index( + node_slices: Dict[str, Tuple[int, int]], + edge_type: Tuple[str, str, str], + edge_index: ms.Tensor, +) -> ms.Tensor: + r"""Increases the edge indices by the offsets of source and destination + node types. + """ + src, _, dst = edge_type + offset = [[node_slices[src][0]], [node_slices[dst][0]]] + offset = ms.Tensor(offset) + return edge_index + offset + + +def to_homogeneous_edge_index( + graph: HeteroGraph, +) -> Tuple[Optional[ms.Tensor], Dict[str, Any], Dict[Tuple[str, str, str], Any]]: + r"""Converts a heterogeneous graph into a homogeneous typed graph.""" + # Record slice information per node type: + node_slices = get_node_slices(graph.num_nodes_dict) + + # Record edge indices and slice information per edge type: + cumsum = 0 + edge_indices: List[ms.Tensor] = [] + edge_slices: Dict[Tuple[str, str, str], Tuple[int, int]] = {} + for edge_type, edge_index in graph.collect("edge_index", True).items(): + edge_index = offset_edge_index(node_slices, edge_type, edge_index) + edge_indices.append(edge_index) + edge_slices[edge_type] = (cumsum, cumsum + edge_index.shape[1]) + cumsum += edge_index.shape[1] + + edge_index: Optional[ms.Tensor] = None + if len(edge_indices) == 1: + edge_index = edge_indices[0] + elif len(edge_indices) > 1: + edge_index = mint.cat(edge_indices, dim=-1) + + return edge_index, node_slices, edge_slices diff --git a/mindscience/sharker/data/hypergraph.py b/mindscience/sharker/data/hypergraph.py new file mode 100644 index 000000000..a374ca10d --- /dev/null +++ b/mindscience/sharker/data/hypergraph.py @@ -0,0 +1,234 @@ +import copy +from collections.abc import Mapping, Sequence +from typing import Any, List, Optional, Tuple +from typing_extensions import Self +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, mint + +from .graph import Graph, warn_or_raise +from .heterograph import HeteroGraph +from ..utils import select, hyper_subgraph + + +class HyperGraph(Graph): + r"""A data object describing a hypergraph. + + The data object can hold node-level, link-level and graph-level attributes. + This object differs from a standard :obj:`~sharker.data.Graph` + object by having hyperedges, i.e. edges that connect more + than two nodes. For example, in the hypergraph scenario + :math:`\mathcal{G} = (\mathcal{V}, \mathcal{E})` with + :math:`\mathcal{V} = \{ 0, 1, 2, 3, 4 \}` and + :math:`\mathcal{E} = \{ \{ 0, 1, 2 \}, \{ 1, 2, 3, 4 \} \}`, the + hyperedge index :obj:`edge_index` is represented as: + + .. code-block:: python + + # hyper graph with two hyperedges + # connecting 3 and 4 nodes, respectively + edge_index = Tensor([ + [0, 1, 2, 1, 2, 3, 4], + [0, 0, 0, 1, 1, 1, 1], + ]) + + Args: + x (Tensor, optional): Node feature matrix with shape + :obj:`[num_nodes, num_node_features]`. (default: :obj:`None`) + edge_index (LongTensor, optional): Hyperedge tensor + with shape :obj:`[2, num_edges*num_nodes_per_edge]`. + Where `edge_index[1]` denotes the hyperedge index and + `edge_index[0]` denotes the node indicies that are connected + by the hyperedge. (default: :obj:`None`) + (default: :obj:`None`) + edge_attr (Tensor, optional): Edge feature matrix with shape + :obj:`[num_edges, num_edge_features]`. + (default: :obj:`None`) + y (Tensor, optional): Graph-level or node-level ground-truth + labels with arbitrary shape. (default: :obj:`None`) + crd (Tensor, optional): Node position matrix with shape + :obj:`[num_nodes, num_dimensions]`. (default: :obj:`None`) + **kwargs (optional): Additional attributes. + """ + + def __init__( + self, + x: Optional[Tensor] = None, + edge_index: Optional[Tensor] = None, + edge_attr: Optional[Tensor] = None, + y: Optional[Tensor] = None, + crd: Optional[Tensor] = None, + **kwargs: Any, + ) -> None: + super().__init__( + x=x, + edge_index=edge_index, + edge_attr=edge_attr, + y=y, + crd=crd, + **kwargs, + ) + + @property + def num_edges(self) -> int: + r"""Returns the number of hyperedges in the hypergraph.""" + if self.edge_index is None: + return 0 + size = max(self.edge_index[1]) + 1 + if isinstance(size, Tensor): + size = size.item() + return size + + @property + def num_nodes(self) -> Optional[int]: + num_nodes = super().num_nodes + + # For hypergraphs, `edge_index[1]` does not contain node indices. + # Therefore, the below code is used to prevent `num_nodes` being + # estimated as the number of hyperedges. + if self.edge_index is not None and num_nodes == self.num_edges: + return max(self.edge_index[0]) + 1 + else: + return num_nodes + + @num_nodes.setter + def num_nodes(self, num_nodes: Optional[int]) -> None: + self._store.num_nodes = num_nodes + + def is_edge_attr(self, key: str) -> bool: + val = super().is_edge_attr(key) + if not val and self.edge_index is not None: + return key in self and self[key].shape[0] == self.num_edges + return val + + def __inc__(self, key: str, value: Any, *args: Any, **kwargs: Any) -> Any: + if key == "edge_index": + if isinstance(value, Tensor): + return ms.Tensor([[self.num_nodes], [self.num_edges]]) + elif isinstance(value, np.ndarray): + return np.array([[self.num_nodes], [self.num_edges]]) + elif isinstance(value, Sequence): + return [[self.num_nodes], [self.num_edges]] + else: + return super().__inc__(key, value, *args, **kwargs) + + def subgraph(self, subset: Tensor) -> Self: + r"""Returns the induced subgraph given by the node indices + :obj:`subset`. + + .. note:: + + If only a subset of a hyperedge's nodes are to be + selected in the subgraph, the hyperedge will remain in the + subgraph, but only the selected nodes will be connected by + the hyperedge. Hyperedges that only connects one node in the + subgraph will be removed. + + Examples: + >>> x = ops.randn(4, 16) + >>> edge_index = Tensor([ + ... [0, 1, 0, 2, 1, 1, 2, 4], + ... [0, 0, 1, 1, 1, 2, 2, 2] + >>> ]) + >>> data = HyperGraph(x = x, edge_index = edge_index) + >>> subset = Tensor([1, 2, 4]) + >>> subgraph = data.subgraph(subset) + >>> subgraph.edge_index + tensor([[2, 1, 1, 2, 4], + [0, 0, 1, 1, 1]]) + + Args: + subset (LongTensor or BoolTensor): The nodes to keep. + """ + assert self.edge_index is not None + edge_index, _, edge_mask = hyper_subgraph( + subset, + self.edge_index, + relabel_nodes=True, + num_nodes=self.num_nodes, + return_edge_mask=True, + ) + + data = copy.copy(self) + + for key, value in self.items(): + if key == "edge_index": + data.edge_index = edge_index + elif key == "num_nodes": + if subset.dtype == ms.bool_: + data.num_nodes = int(subset.sum()) + else: + data.num_nodes = subset.shape[0] + elif self.is_node_attr(key): + cat_dim = self.__cat_dim__(key, value) + data[key] = select(value, subset, axis=cat_dim) + elif self.is_edge_attr(key): + cat_dim = self.__cat_dim__(key, value) + data[key] = select(value, edge_mask, axis=cat_dim) + + return data + + def edge_subgraph(self, subset: Tensor) -> Self: + raise NotImplementedError + + def to_heterogeneous( + self, + node_type: Optional[Tensor] = None, + edge_type: Optional[Tensor] = None, + node_type_names: Optional[List[str]] = None, + edge_type_names: Optional[List[Tuple[str, str, str]]] = None, + ) -> HeteroGraph: + raise NotImplementedError + + def has_isolated_nodes(self) -> bool: + if self.edge_index is None: + return False + return mint.unique(self.edge_index[0]).shape[0] < self.num_nodes + + def is_directed(self) -> bool: + raise NotImplementedError + + def is_undirected(self) -> bool: + raise NotImplementedError + + def has_self_loops(self) -> bool: + raise NotImplementedError + + def validate(self, raise_on_error: bool = True) -> bool: + r"""Validates the correctness of the data.""" + cls_name = self.__class__.__name__ + status = True + + num_nodes = self.num_nodes + if num_nodes is None: + status = False + warn_or_raise(f"'num_nodes' is undefined in '{cls_name}'", raise_on_error) + + if self.edge_index is not None: + if self.edge_index.dim() != 2 or self.edge_index.shape[0] != 2: + status = False + warn_or_raise( + f"'edge_index' needs to be of shape [2, num_edges] in " + f"'{cls_name}' (found {self.edge_index.shape})", + raise_on_error, + ) + + if self.edge_index is not None and self.edge_index.numel() > 0: + if self.edge_index.min() < 0: + status = False + warn_or_raise( + f"'edge_index' contains negative indices in " + f"'{cls_name}' (found {int(self.edge_index.min())})", + raise_on_error, + ) + + if num_nodes is not None and self.edge_index[0].max() >= num_nodes: + status = False + warn_or_raise( + f"'edge_index' contains larger indices than the number " + f"of nodes ({num_nodes}) in '{cls_name}' " + f"(found {int(self.edge_index.max())})", + raise_on_error, + ) + + return status diff --git a/mindscience/sharker/data/in_memory.py b/mindscience/sharker/data/in_memory.py new file mode 100644 index 000000000..810560bb6 --- /dev/null +++ b/mindscience/sharker/data/in_memory.py @@ -0,0 +1,344 @@ +import os +import copy +import warnings +from typing import ( + Any, + Callable, + Dict, + Iterable, + List, + Mapping, + MutableSequence, + Optional, + Sequence, + Tuple, + Type, + Union, +) +import mindspore as ms +from mindspore import Tensor +from tqdm import tqdm + +from .batch import Batch +from .graph import Graph +from .collate import collate +from .dataset import Dataset, IndexType +from .on_disk import OnDiskDataset +from .separate import separate +from ..io import fs +from ..utils import to_array + + +class InMemoryDataset(Dataset): + r"""Dataset base class for creating graph datasets which easily fit + into CPU memory. + See `here `__ for the accompanying + tutorial. + + Args: + root (str, optional): Root directory where the dataset should be saved. + (optional: :obj:`None`) + transform (callable, optional): A function/transform that takes in a + :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object and returns a + transformed version. + The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + a :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object and returns a + transformed version. + The data object will be transformed before being saved to disk. + (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in a + :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object and returns a + boolean value, indicating whether the data object should be + included in the final dataset. (default: :obj:`None`) + log (bool, optional): Whether to print any console output while + downloading and processing the dataset. (default: :obj:`True`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + """ + + @property + def raw_file_names(self) -> Union[str, List[str], Tuple[str, ...]]: + raise NotImplementedError + + @property + def processed_file_names(self) -> Union[str, List[str], Tuple[str, ...]]: + raise NotImplementedError + + def __init__( + self, + root: Optional[str] = None, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + log: bool = True, + force_reload: bool = False, + return_tensor: bool = True, + ) -> None: + super().__init__(root, transform, pre_transform, pre_filter, log, force_reload) + + self._data: Optional[Graph] = None + self.slices: Optional[Dict[str, Tensor]] = None + self._data_list: Optional[MutableSequence[Optional[Graph]]] = None + self.return_tensor = return_tensor + + @property + def num_classes(self) -> int: + if self.transform is None: + return self._infer_num_classes(self._data.y) + return super().num_classes + + def len(self) -> int: + if self.slices is None: + return 1 + for _, value in nested_iter(self.slices): + return len(value) - 1 + return 0 + + def get(self, idx: int) -> Graph: + # TODO (matthias) Avoid unnecessary copy here. + if self.len() == 1: + return copy.copy(self._data) + + if not hasattr(self, "_data_list") or self._data_list is None: + self._data_list = self.len() * [None] + elif self._data_list[idx] is not None: + return copy.copy(self._data_list[idx]) + + data = separate( + cls=self._data.__class__, + batch=self._data, + idx=idx, + slice_dict=self.slices, + decrement=False, + ) + + self._data_list[idx] = copy.copy(data) + + return data + + @classmethod + def save(cls, data_list: Sequence[Graph], path: str) -> None: + r"""Saves a list of data objects to the file path :obj:`path`.""" + data, slices = cls.collate(data_list) + slices = to_array(slices) + fs.pickle_save((data.numpy().to_dict(), slices, data.__class__), path) + + def load(self, path: str, data_cls: Type[Graph] = Graph) -> None: + r"""Loads the dataset from the file path :obj:`path`.""" + out = fs.pickle_load(path) + assert isinstance(out, tuple) + assert len(out) == 2 or len(out) == 3 + if len(out) == 2: + data, self.slices = out + else: + data, self.slices, data_cls = out + + if not isinstance(data, dict): + self.data = data + else: + self.data = data_cls.from_dict(data) + + @staticmethod + def collate( + data_list: Sequence[Graph], + return_tensor: bool = False + ) -> Tuple[Graph, Optional[Dict[str, Tensor]]]: + r"""Collates a list of :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` objects to the internal + storage format of :class:`~sharker.data.InMemoryDataset`. + """ + if len(data_list) == 1: + return data_list[0], None + + data, slices, _ = collate( + data_list[0].__class__, + data_list=data_list, + increment=False, + return_tensor=return_tensor, + add_batch=False, + ) + + return data, slices + + def copy(self, idx: Optional[IndexType] = None) -> "InMemoryDataset": + r"""Performs a deep-copy of the dataset. If :obj:`idx` is not given, + will clone the full dataset. Otherwise, will only clone a subset of the + dataset from indices :obj:`idx`. + Indices can be slices, lists, tuples, and a :obj:`Tensor` or + :obj:`np.ndarray` of type long or bool. + """ + if idx is None: + data_list = [self.get(i) for i in self.indices()] + else: + data_list = [self.get(i) for i in self.index_select(idx).indices()] + + dataset = copy.copy(self) + dataset._indices = None + dataset._data_list = None + dataset.data, dataset.slices = self.collate(data_list, self.return_tensor) + return dataset + + def to_on_disk_dataset( + self, + root: Optional[str] = None, + backend: str = "sqlite", + log: bool = True, + ) -> OnDiskDataset: + r"""Converts the :class:`InMemoryDataset` to a :class:`OnDiskDataset` + variant. Useful for distributed training and hardware instances with + limited amount of shared memory. + + root (str, optional): Root directory where the dataset should be saved. + If set to :obj:`None`, will save the dataset in + :obj:`root/on_disk`. + Note that it is important to specify :obj:`root` to account for + different dataset splits. (optional: :obj:`None`) + backend (str): The :class:`Database` backend to use. + (default: :obj:`"sqlite"`) + log (bool, optional): Whether to print any console output while + processing the dataset. (default: :obj:`True`) + """ + if root is None and (self.root is None or not os.path.exists(self.root)): + raise ValueError( + f"The root directory of " + f"'{self.__class__.__name__}' is not specified. " + f"Please pass in 'root' when creating on-disk " + f"datasets from it." + ) + + root = root or os.path.join(self.root, "on_disk") + + in_memory_dataset = self + ref_data = in_memory_dataset.get(0) + if not isinstance(ref_data, Graph): + raise NotImplementedError( + f"`{self.__class__.__name__}.to_on_disk_dataset()` is " + f"currently only supported on homogeneous graphs" + ) + + # Parse the schema ==================================================== + + schema: Dict[str, Any] = {} + for key, value in ref_data.to_dict().items(): + if isinstance(value, (int, float, str)): + schema[key] = value.__class__ + elif isinstance(value, Tensor) and value.dim() == 0: + schema[key] = dict(dtype=value.dtype, size=(-1,)) + elif isinstance(value, Tensor): + size = list(value.shape) + size[ref_data.__cat_dim__(key, value)] = -1 + schema[key] = dict(dtype=value.dtype, size=tuple(size)) + else: + schema[key] = object + + # Create the on-disk dataset ========================================== + + class OnDiskDataset(OnDiskDataset): + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + ): + super().__init__( + root=root, + transform=transform, + backend=backend, + schema=schema, + ) + + def process(self): + _iter = [in_memory_dataset.get(i) for i in in_memory_dataset.indices()] + if log: + _iter = tqdm(_iter, desc="Converting to OnDiskDataset") + + data_list: List[Graph] = [] + for i, data in enumerate(_iter): + data_list.append(data) + if i + 1 == len(in_memory_dataset) or (i + 1) % 1000 == 0: + self.extend(data_list) + data_list = [] + + def serialize(self, data: Graph) -> Dict[str, Any]: + return data.to_dict() + + def deserialize(self, data: Dict[str, Any]) -> Graph: + return Graph.from_dict(data) + + def __repr__(self) -> str: + arg_repr = str(len(self)) if len(self) > 1 else "" + return f"OnDisk{in_memory_dataset.__class__.__name__}(" f"{arg_repr})" + + return OnDiskDataset(root, transform=in_memory_dataset.transform) + + @property + def data(self) -> Any: + msg1 = ( + "It is not recommended to directly access the internal " + "storage format `data` of an 'InMemoryDataset'." + ) + msg2 = ( + "The given 'InMemoryDataset' only references a subset of " + "examples of the full dataset, but 'data' will contain " + "information of the full dataset." + ) + msg3 = ( + "The data of the dataset is already cached, so any " + "modifications to `data` will not be reflected when accessing " + "its elements. Clearing the cache now by removing all " + "elements in `dataset._data_list`." + ) + msg4 = ( + "If you are absolutely certain what you are doing, access the " + "internal storage via `InMemoryDataset._data` instead to " + "suppress this warning. Alternatively, you can access stacked " + "individual attributes of every graph via " + "`dataset.{attr_name}`." + ) + + msg = msg1 + if self._indices is not None: + msg += f" {msg2}" + if self._data_list is not None: + msg += f" {msg3}" + self._data_list = None + msg += f" {msg4}" + + warnings.warn(msg) + + return self._data + + @data.setter + def data(self, value: Any): + self._data = value + self._data_list = None + + def __getattr__(self, key: str) -> Any: + data = self.__dict__.get("_data") + if isinstance(data, Graph) and key in data: + if self._indices is None and data.__inc__(key, data[key]) == 0: + return data[key] + else: + data_list = [self.get(i) for i in self.indices()] + return Batch.from_data_list(data_list)[key] + + raise AttributeError( + f"'{self.__class__.__name__}' object has no " f"attribute '{key}'" + ) + + +def nested_iter(node: Union[Mapping, Sequence]) -> Iterable: + if isinstance(node, Mapping): + for key, value in node.items(): + for inner_key, inner_value in nested_iter(value): + yield inner_key, inner_value + elif isinstance(node, Sequence): + for i, inner_value in enumerate(node): + yield i, inner_value + else: + yield None, node diff --git a/mindscience/sharker/data/on_disk.py b/mindscience/sharker/data/on_disk.py new file mode 100644 index 000000000..2aa80408f --- /dev/null +++ b/mindscience/sharker/data/on_disk.py @@ -0,0 +1,175 @@ +import os +from typing import Any, Callable, Iterable, List, Optional, Sequence, Union + +from mindspore import Tensor + +from .database import Database, RocksDatabase, SQLiteDatabase, Schema +from .graph import Graph +from .dataset import Dataset + + +class OnDiskDataset(Dataset): + r"""Dataset base class for creating large graph datasets which do not + easily fit into CPU memory at once by leveraging a :class:`Database` + backend for on-disk storage and access of data objects. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in a + :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object and returns a + transformed version. + The data object will be transformed before every access. + (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in a + :class:`~sharker.data.Graphh` or + :class:`~sharker.data.HeteroGraph` object and returns a + boolean value, indicating whether the data object should be + included in the final dataset. (default: :obj:`None`) + backend (str): The :class:`Database` backend to use + (one of :obj:`"sqlite"` or :obj:`"rocksdb"`). + (default: :obj:`"sqlite"`) + schema (Any or Tuple[Any] or Dict[str, Any], optional): The schema of + the input data. + Can take :obj:`int`, :obj:`float`, :obj:`str`, :obj:`object`, or a + dictionary with :obj:`dtype` and :obj:`size` keys (for specifying + tensor data) as input, and can be nested as a tuple or dictionary. + Specifying the schema will improve efficiency, since by default the + database will use python pickling for serializing and + deserializing. If specified to anything different than + :obj:`object`, implementations of :class:`OnDiskDataset` need to + override :meth:`serialize` and :meth:`deserialize` methods. + (default: :obj:`object`) + log (bool, optional): Whether to print any console output while + downloading and processing the dataset. (default: :obj:`True`) + """ + + BACKENDS = { + "sqlite": SQLiteDatabase, + "rocksdb": RocksDatabase, + } + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + backend: str = "sqlite", + schema: Schema = object, + log: bool = True, + ) -> None: + if backend not in self.BACKENDS: + raise ValueError( + f"Database backend must be one of " + f"{set(self.BACKENDS.keys())} " + f"(got '{backend}')" + ) + + self.backend = backend + self.schema = schema + + self._db: Optional[Database] = None + self._numel: Optional[int] = None + + super().__init__(root, transform, pre_filter=pre_filter, log=log) + + @property + def processed_file_names(self) -> str: + return f"{self.backend}.db" + + @property + def db(self) -> Database: + r"""Returns the underlying :class:`Database`.""" + if self._db is not None: + return self._db + + kwargs = {} + cls = self.BACKENDS[self.backend] + if issubclass(cls, SQLiteDatabase): + kwargs["name"] = self.__class__.__name__ + + os.makedirs(self.processed_dir, exist_ok=True) + path = self.processed_paths[0] + self._db = cls(path=path, schema=self.schema, **kwargs) + self._numel = len(self._db) + return self._db + + def close(self) -> None: + r"""Closes the connection to the underlying database.""" + if self._db is not None: + self._db.close() + + def serialize(self, data: Graph) -> Any: + r"""Serializes the :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object into the expected DB + schema. + """ + if self.schema == object: + return data + raise NotImplementedError( + f"`{self.__class__.__name__}.serialize()` " + f"needs to be overridden in case a " + f"non-default schema was passed" + ) + + def deserialize(self, data: Any) -> Graph: + r"""Deserializes the DB entry into a + :class:`~sharker.data.Graph` or + :class:`~sharker.data.HeteroGraph` object. + """ + if self.schema == object: + return data + raise NotImplementedError( + f"`{self.__class__.__name__}.deserialize()` " + f"needs to be overridden in case a " + f"non-default schema was passed" + ) + + def append(self, data: Graph) -> None: + r"""Appends the data object to the dataset.""" + index = len(self) + self.db.insert(index, self.serialize(data)) + self._numel += 1 + + def extend( + self, + data_list: Sequence[Graph], + batch_size: Optional[int] = None, + ) -> None: + r"""Extends the dataset by a list of data objects.""" + start = len(self) + end = start + len(data_list) + data_list = [self.serialize(data) for data in data_list] + self.db.multi_insert(range(start, end), data_list, batch_size) + self._numel += end - start + + def get(self, idx: int) -> Graph: + r"""Gets the data object at index :obj:`idx`.""" + return self.deserialize(self.db.get(idx)) + + def multi_get( + self, + indices: Union[Iterable[int], Tensor, slice, range], + batch_size: Optional[int] = None, + ) -> List[Graph]: + r"""Gets a list of data objects from the specified indices.""" + if len(indices) == 1: + data_list = [self.db.get(indices[0])] + else: + data_list = self.db.multi_get(indices, batch_size) + + data_list = [self.deserialize(data) for data in data_list] + if self.transform is not None: + data_list = [self.transform(data) for data in data_list] + return data_list + + def __getitems__(self, indices: List[int]) -> List[Graph]: + return self.multi_get(indices) + + def len(self) -> int: + if self._numel is None: + self._numel = len(self.db) + return self._numel + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({len(self)})" diff --git a/mindscience/sharker/data/remote_store.py b/mindscience/sharker/data/remote_store.py new file mode 100644 index 000000000..b5f542501 --- /dev/null +++ b/mindscience/sharker/data/remote_store.py @@ -0,0 +1,582 @@ +r"""This class defines the abstraction for a backend-agnostic feature store. +The goal of the feature store is to abstract away all node and edge feature +memory management so that varying implementations can allow for independent +scale-out. + +This particular feature store abstraction makes a few key assumptions: +* The features we care about storing are node and edge features of a graph. + To this end, the attributes that the feature store supports include a + `group_name` (e.g. a heterogeneous node name or a heterogeneous edge type), + an `attr_name` (e.g. `x` or `edge_attr`), and an index. +* A feature can be uniquely identified from any associated attributes specified + in `TensorAttr`. + +It is the job of a feature store implementor class to handle these assumptions +properly. For example, a simple in-memory feature store implementation may +concatenate all metadata values with a feature index and use this as a unique +index in a KV store. More complicated implementations may choose to partition +features in interesting manners based on the provided metadata. + +Major TODOs for future implementation: +* Async `put` and `get` functionality +""" + +import copy +from abc import ABC, abstractmethod +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple +from abc import ABC, abstractmethod +from dataclasses import dataclass +from enum import Enum +from typing import Any, List, Optional, Tuple, Union + +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, nn + + +# We allow indexing with a tensor, numpy array, Python slicing, or a single +# integer index. +IndexType = Union[ms.Tensor, np.ndarray, slice, int] +# A representation of a feature tensor +FeatureTensorType = Union[ms.Tensor, np.ndarray] +ConversionOutputType = Tuple[ + Dict[Tuple[str, str, str], Tensor], Dict[Tuple[str, str, str], Tensor], Dict[Tuple[str, str, str], Optional[Tensor]] +] + + +class EdgeLayout(Enum): + COO = "coo" + CSC = "csc" + CSR = "csr" + + +@dataclass +class TensorAttr: + r"""Defines the attributes of a :class:`FeatureStore` tensor. + It holds all the parameters necessary to uniquely identify a tensor from + the :class:`FeatureStore`. + + Note that the order of the attributes is important; this is the order in + which attributes must be provided for indexing calls. :class:`FeatureStore` + implementations can define a different ordering by overriding + :meth:`TensorAttr.__init__`. + """ + + # The group name that the tensor corresponds to. Defaults to UNSET. + group_name: Optional[str] = None + + # The name of the tensor within its group. Defaults to UNSET. + attr_name: Optional[str] = None + + # The node indices the rows of the tensor correspond to. Defaults to UNSET. + index: Optional[IndexType] = None + + # Convenience methods ##################################################### + + def is_set(self, key: str) -> bool: + r"""Whether an attribute is set in :obj:`TensorAttr`.""" + assert key in self.__dataclass_fields__ + return getattr(self, key) != None + + def is_fully_specified(self) -> bool: + r"""Whether the :obj:`TensorAttr` has no unset fields.""" + return all([self.is_set(key) for key in self.__dataclass_fields__]) + + def fully_specify(self) -> "TensorAttr": + r"""Sets all :obj:`UNSET` fields to :obj:`None`.""" + for key in self.__dataclass_fields__: + if not self.is_set(key): + setattr(self, key, None) + return self + + def update(self, attr: "TensorAttr") -> "TensorAttr": + r"""Updates an :class:`TensorAttr` with set attributes from another + :class:`TensorAttr`. + """ + for key in self.__dataclass_fields__: + if attr.is_set(key): + setattr(self, key, getattr(attr, key)) + return self + + +@dataclass +class EdgeAttr: + r"""Defines the attributes of a :obj:`GraphStore` edge. + It holds all the parameters necessary to uniquely identify an edge from + the :class:`GraphStore`. + + Note that the order of the attributes is important; this is the order in + which attributes must be provided for indexing calls. :class:`GraphStore` + implementations can define a different ordering by overriding + :meth:`EdgeAttr.__init__`. + """ + + # The type of the edge: + edge_type: Tuple[str, str, str] + + # The layout of the edge representation: + layout: EdgeLayout + + # Whether the edge index is sorted by destination node. Useful for + # avoiding sorting costs when performing neighbor sampling, and only + # meaningful for COO (CSC is sorted and CSR is not sorted by definition): + is_sorted: bool = False + + # The number of source and destination nodes in this edge type: + size: Optional[Tuple[int, int]] = None + + def __posi_init__(self): + self.layout = EdgeLayout(self.layout) + if self.layout == EdgeLayout.CSR and self.is_sorted: + raise ValueError( + "Cannot create a 'CSR' edge attribute with " "option 'is_sorted=True'" + ) + if self.layout == EdgeLayout.CSC: + self.is_sorted = True + + +class AttrView: + r"""Defines a view of a :class:`FeatureStore` that is obtained from a + specification of attributes on the feature store. The view stores a + reference to the backing feature store as well as a :class:`TensorAttr` + object that represents the view's state. + + Users can create views either using the :class:`AttrView` constructor, + :meth:`FeatureStore.view`, or by incompletely indexing a feature store. + For example, the following calls all create views: + + .. code-block:: python + + store[group_name] + store[group_name].feat + store[group_name, feat] + + While the following calls all materialize those views and produce tensors + by either calling the view or fully-specifying the view: + + .. code-block:: python + + store[group_name]() + store[group_name].feat[index] + store[group_name, feat][index] + """ + + def __init__(self, store: "FeatureStore", attr: TensorAttr): + self.__dict__["_store"] = store + self.__dict__["_attr"] = attr + + # Advanced indexing ####################################################### + + def __getattr__(self, key: Any) -> Union["AttrView", FeatureTensorType]: + r"""Sets the first unset field of the backing :class:`TensorAttr` + object to the attribute. + + This allows for :class:`AttrView` to be indexed by different values of + attributes, in order. + In particular, for a feature store that we want to index by + :obj:`group_name` and :obj:`attr_name`, the following code will do so: + + .. code-block:: python + + store[group, attr] + store[group].attr + store.group.attr + """ + out = copy.copy(self) + + # Find the first attribute name that is UNSET: + attr_name: Optional[str] = None + for field in out._attr.__dataclass_fields__: + if getattr(out._attr, field) == None: + attr_name = field + break + + if attr_name is None: + raise AttributeError( + f"Cannot access attribute '{key}' on view " + f"'{out}' as all attributes have already " + f"been set in this view" + ) + + setattr(out._attr, attr_name, key) + + if out._attr.is_fully_specified(): + return out._store.get_tensor(out._attr) + + return out + + def __getitem__(self, key: Any) -> Union["AttrView", FeatureTensorType]: + r"""Sets the first unset field of the backing :class:`TensorAttr` + object to the attribute via indexing. + + This allows for :class:`AttrView` to be indexed by different values of + attributes, in order. + In particular, for a feature store that we want to index by + :obj:`group_name` and :obj:`attr_name`, the following code will do so: + + .. code-block:: python + + store[group, attr] + store[group][attr] + + """ + return self.__getattr__(key) + + # Setting attributes ###################################################### + + def __setattr__(self, key: str, value: Any): + r"""Supports attribute assignment to the backing :class:`TensorAttr` of + an :class:`AttrView`. + + This allows for :class:`AttrView` objects to set their backing + attribute values. + In particular, the following operation sets the :obj:`index` of an + :class:`AttrView`: + + .. code-block:: python + + view = store.view(group_name) + view.index = Tensor([1, 2, 3]) + """ + if key not in self._attr.__dataclass_fields__: + raise ValueError( + f"Attempted to set nonexistent attribute '{key}' " + f"(acceptable attributes are " + f"{self._attr.__dataclass_fields__})" + ) + + setattr(self._attr, key, value) + + def __setitem__(self, key: str, value: Any): + r"""Supports attribute assignment to the backing :class:`TensorAttr` of + an :class:`AttrView` via indexing. + + This allows for :class:`AttrView` objects to set their backing + attribute values. + In particular, the following operation sets the `index` of an + :class:`AttrView`: + + .. code-block:: python + + view = store.view(TensorAttr(group_name)) + view['index'] = Tensor([1, 2, 3]) + """ + self.__setattr__(key, value) + + # Miscellaneous built-ins ################################################# + + def __call__(self) -> FeatureTensorType: + r"""Supports :class:`AttrView` as a callable to force retrieval from + the currently specified attributes. + + In particular, this passes the current :class:`TensorAttr` object to a + GET call, regardless of whether all attributes have been specified. + It returns the result of this call. + In particular, the following operation returns a tensor by performing a + GET operation on the backing feature store: + + .. code-block:: python + + store[group_name, attr_name]() + """ + # Set all UNSET values to None: + out = copy.copy(self) + out._attr.fully_specify() + return out._store.get_tensor(out._attr) + + def __copy__(self) -> "AttrView": + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = value + out.__dict__["_attr"] = copy.copy(out.__dict__["_attr"]) + return out + + def __eq__(self, obj: Any) -> bool: + r"""Compares two :class:`AttrView` objects by checking equality of + their :class:`FeatureStore` references and :class:`TensorAttr` + attributes. + """ + if not isinstance(obj, AttrView): + return False + return self._store == obj._store and self._attr == obj._attr + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(store={self._store}, " f"attr={self._attr})" + + +class FeatureStore(ABC): + r"""An abstract base class to access features from a remote feature store. + + Args: + tensor_attr_cls (TensorAttr, optional): A user-defined + :class:`TensorAttr` class to customize the required attributes and + their ordering to unique identify tensor values. + (default: :obj:`None`) + """ + + _tensor_attr_cls: TensorAttr + + def __init__(self, tensor_attr_cls: Optional[Any] = None): + super().__init__() + self.__dict__["_tensor_attr_cls"] = tensor_attr_cls or TensorAttr + + # Core (CRUD) ############################################################# + + @abstractmethod + def _put_tensor(self, tensor: FeatureTensorType, attr: TensorAttr) -> bool: + r"""To be implemented by :class:`FeatureStore` subclasses.""" + pass + + def put_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool: + r"""Synchronously adds a :obj:`tensor` to the :class:`FeatureStore`. + Returns whether insertion was successful. + + Args: + tensor (Tensor or np.ndarray): The feature tensor to be + added. + *args: Arguments passed to :class:`TensorAttr`. + **kwargs: Keyword arguments passed to :class:`TensorAttr`. + + Raises: + ValueError: If the input :class:`TensorAttr` is not fully + specified. + """ + attr = self._tensor_attr_cls.cast(*args, **kwargs) + if not attr.is_fully_specified(): + raise ValueError( + f"The input TensorAttr '{attr}' is not fully " + f"specified. Please fully-specify the input by " + f"specifying all 'UNSET' fields" + ) + return self._put_tensor(tensor, attr) + + @abstractmethod + def _get_tensor(self, attr: TensorAttr) -> Optional[FeatureTensorType]: + r"""To be implemented by :class:`FeatureStore` subclasses.""" + pass + + def get_tensor( + self, + *args, + convert_type: bool = False, + **kwargs, + ) -> FeatureTensorType: + r"""Synchronously obtains a :class:`tensor` from the + :class:`FeatureStore`. + + Args: + *args: Arguments passed to :class:`TensorAttr`. + convert_type (bool, optional): Whether to convert the type of the + output tensor to the type of the attribute index. + (default: :obj:`False`) + **kwargs: Keyword arguments passed to :class:`TensorAttr`. + + Raises: + ValueError: If the input :class:`TensorAttr` is not fully + specified. + KeyError: If the tensor corresponding to the input + :class:`TensorAttr` was not found. + """ + attr = self._tensor_attr_cls.cast(*args, **kwargs) + if not attr.is_fully_specified(): + raise ValueError( + f"The input TensorAttr '{attr}' is not fully " + f"specified. Please fully-specify the input by " + f"specifying all 'UNSET' fields." + ) + + tensor = self._get_tensor(attr) + if tensor is None: + raise KeyError(f"A tensor corresponding to '{attr}' was not found") + return self._to_type(attr, tensor) if convert_type else tensor + + def _multi_get_tensor( + self, + attrs: List[TensorAttr], + ) -> List[Optional[FeatureTensorType]]: + r"""To be implemented by :class:`FeatureStore` subclasses.""" + return [self._get_tensor(attr) for attr in attrs] + + def multi_get_tensor( + self, + attrs: List[TensorAttr], + convert_type: bool = False, + ) -> List[FeatureTensorType]: + r"""Synchronously obtains a list of tensors from the + :class:`FeatureStore` for each tensor associated with the attributes in + :obj:`attrs`. + + .. note:: + The default implementation simply iterates over all calls to + :meth:`get_tensor`. Implementor classes that can provide + additional, more performant functionality are recommended to + to override this method. + + Args: + attrs (List[TensorAttr]): A list of input :class:`TensorAttr` + objects that identify the tensors to obtain. + convert_type (bool, optional): Whether to convert the type of the + output tensor to the type of the attribute index. + (default: :obj:`False`) + + Raises: + ValueError: If any input :class:`TensorAttr` is not fully + specified. + KeyError: If any of the tensors corresponding to the input + :class:`TensorAttr` was not found. + """ + attrs = [self._tensor_attr_cls.cast(attr) for attr in attrs] + bad_attrs = [attr for attr in attrs if not attr.is_fully_specified()] + if len(bad_attrs) > 0: + raise ValueError( + f"The input TensorAttr(s) '{bad_attrs}' are not fully " + f"specified. Please fully-specify them by specifying all " + f"'UNSET' fields" + ) + + tensors = self._multi_get_tensor(attrs) + if any(v is None for v in tensors): + bad_attrs = [attrs[i] for i, v in enumerate(tensors) if v is None] + raise KeyError( + f"Tensors corresponding to attributes " f"'{bad_attrs}' were not found" + ) + + return [ + self._to_type(attr, tensor) if convert_type else tensor + for attr, tensor in zip(attrs, tensors) + ] + + @abstractmethod + def _remove_tensor(self, attr: TensorAttr) -> bool: + r"""To be implemented by :obj:`FeatureStore` subclasses.""" + pass + + def remove_tensor(self, *args, **kwargs) -> bool: + r"""Removes a tensor from the :class:`FeatureStore`. + Returns whether deletion was successful. + + Args: + *args: Arguments passed to :class:`TensorAttr`. + **kwargs: Keyword arguments passed to :class:`TensorAttr`. + + Raises: + ValueError: If the input :class:`TensorAttr` is not fully + specified. + """ + attr = self._tensor_attr_cls.cast(*args, **kwargs) + if not attr.is_fully_specified(): + raise ValueError( + f"The input TensorAttr '{attr}' is not fully " + f"specified. Please fully-specify the input by " + f"specifying all 'UNSET' fields." + ) + return self._remove_tensor(attr) + + def update_tensor(self, tensor: FeatureTensorType, *args, **kwargs) -> bool: + r"""Updates a :obj:`tensor` in the :class:`FeatureStore` with a new + value. Returns whether the update was succesful. + + .. note:: + Implementor classes can choose to define more efficient update + methods; the default performs a removal and insertion. + + Args: + tensor (Tensor or np.ndarray): The feature tensor to be + updated. + *args: Arguments passed to :class:`TensorAttr`. + **kwargs: Keyword arguments passed to :class:`TensorAttr`. + """ + attr = self._tensor_attr_cls.cast(*args, **kwargs) + self.remove_tensor(attr) + return self.put_tensor(tensor, attr) + + # Additional methods ###################################################### + + @abstractmethod + def _get_tensor_size(self, attr: TensorAttr) -> Optional[Tuple[int, ...]]: + pass + + def get_tensor_size(self, *args, **kwargs) -> Optional[Tuple[int, ...]]: + r"""Obtains the size of a tensor given its :class:`TensorAttr`, or + :obj:`None` if the tensor does not exist. + """ + attr = self._tensor_attr_cls.cast(*args, **kwargs) + if not attr.is_set("index"): + attr.index = None + return self._get_tensor_size(attr) + + @abstractmethod + def get_all_tensor_attrs(self) -> List[TensorAttr]: + r"""Returns all registered tensor attributes.""" + pass + + # `AttrView` methods ###################################################### + + def view(self, *args, **kwargs) -> AttrView: + r"""Returns a view of the :class:`FeatureStore` given a not yet + fully-specified :class:`TensorAttr`. + """ + attr = self._tensor_attr_cls.cast(*args, **kwargs) + return AttrView(self, attr) + + # Helper functions ######################################################## + + @staticmethod + def _to_type( + attr: TensorAttr, + tensor: FeatureTensorType, + ) -> FeatureTensorType: + if isinstance(attr.index, ms.Tensor) and isinstance(tensor, np.ndarray): + return Tensor.from_numpy(tensor) + if isinstance(attr.index, np.ndarray) and isinstance(tensor, ms.Tensor): + return tensor.asnumpy() + return tensor + + # Python built-ins ######################################################## + + def __setitem__(self, key: TensorAttr, value: FeatureTensorType): + r"""Supports :obj:`store[tensor_attr] = tensor`.""" + # CastMixin will handle the case of key being a tuple or TensorAttr + # object: + key = self._tensor_attr_cls.cast(key) + # We need to fully-specify the key for __setitem__ as it does not make + # sense to work with a view here: + key.fully_specify() + self.put_tensor(value, key) + + def __getitem__(self, key: TensorAttr) -> Any: + r"""Supports pythonic indexing into the :class:`FeatureStore`. + + In particular, the following rules are followed for indexing: + + * A fully-specified :obj:`key` will produce a tensor output. + + * A partially-specified :obj:`key` will produce an :class:`AttrView` + output, which is a view on the :class:`FeatureStore`. If a view is + called, it will produce a tensor output from the corresponding + (partially specified) attributes. + """ + # CastMixin will handle the case of key being a tuple or TensorAttr: + attr = self._tensor_attr_cls.cast(key) + if attr.is_fully_specified(): + return self.get_tensor(attr) + # If the view is not fully-specified, return a :class:`AttrView`: + return self.view(attr) + + def __delitem__(self, key: TensorAttr): + r"""Supports :obj:`del store[tensor_attr]`.""" + # CastMixin will handle the case of key being a tuple or TensorAttr + # object: + key = self._tensor_attr_cls.cast(key) + key.fully_specify() + self.remove_tensor(key) + + def __iter__(self): + raise NotImplementedError + + def __eq__(self, obj: object) -> bool: + return id(self) == id(obj) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/mindscience/sharker/data/separate.py b/mindscience/sharker/data/separate.py new file mode 100644 index 000000000..f1287b699 --- /dev/null +++ b/mindscience/sharker/data/separate.py @@ -0,0 +1,149 @@ +from collections.abc import Mapping, Sequence +from typing import Any, Type, TypeVar + +import numpy as np +import mindspore as ms + +from .graph import Graph +from .storage import BaseStorage +from ..utils import narrow + + +T = TypeVar("T") + + +def separate( + cls: Type[T], + batch: Any, + idx: int, + slice_dict: Any, + inc_dict: Any = None, + decrement: bool = True, +) -> T: + # Separates the individual element from a `batch` at index `idx`. + # `separate` can handle both homogeneous and heterogeneous data objects by + # individually separating all their stores. + # In addition, `separate` can handle nested data structures such as + # dictionaries and lists. + + data = cls().stores_as(batch) + + # Iterate over each storage object and recursively separate its attributes: + for batch_store, data_store in zip(batch.stores, data.stores): + key = batch_store._key + if key is not None: + attrs = slice_dict[key].keys() + else: + attrs = set(batch_store.keys()) + attrs = [attr for attr in slice_dict.keys() if attr in attrs] + + for attr in attrs: + if key is not None: + slices = slice_dict[key][attr] + incs = inc_dict[key][attr] if decrement else None + else: + slices = slice_dict[attr] + incs = inc_dict[attr] if decrement else None + + data_store[attr] = _separate( + attr, + batch_store[attr], + idx, + slices, + incs, + batch, + batch_store, + decrement, + ) + + # The `num_nodes` attribute needs special treatment, as we cannot infer + # the real number of nodes from the total number of nodes alone: + if hasattr(batch_store, "_num_nodes"): + data_store.num_nodes = batch_store._num_nodes[idx] + + return data + + +def _separate( + key: str, + values: Any, + idx: int, + slices: Any, + incs: Any, + batch: Graph, + store: BaseStorage, + decrement: bool, +) -> Any: + + if isinstance(values, (ms.Tensor, np.ndarray)): + # Narrow a `Tensor` based on `slices`. + # NOTE: We need to take care of decrementing elements appropriately. + key = str(key) + is_tensor = False + if isinstance(values, ms.Tensor): + is_tensor = True + + slices_np = slices.asnumpy().astype(np.int64) if isinstance(slices, ms.Tensor) else slices + incs_np = incs.asnumpy().astype(np.int64) if isinstance(incs, ms.Tensor) else incs + values_np = values.asnumpy() if isinstance(values, ms.Tensor) else values + + cat_dim = batch.__cat_dim__(key, values, store) + start, end = int(slices_np[idx]), int(slices_np[idx + 1]) + value = narrow(values_np, cat_dim or 0, start, end - start) + value = np.squeeze(value, axis=0) if cat_dim is None else value + + if decrement and incs is not None and (incs.ndim > 1 or incs_np[idx] != 0): + value = value - incs_np[idx] + + return ms.Tensor(value) if is_tensor else value + + elif isinstance(values, Mapping): + # Recursively separate elements of dictionaries. + return { + key: _separate( + key, + value, + idx, + slices=slices[key], + incs=incs[key] if decrement else None, + batch=batch, + store=store, + decrement=decrement, + ) + for key, value in values.items() + } + + elif ( + isinstance(values, Sequence) + and isinstance(values[0], Sequence) + and not isinstance(values[0], str) + and len(values[0]) > 0 + and isinstance(values[0][0], ms.Tensor) + and isinstance(slices, Sequence) + ): + # Recursively separate elements of lists of lists. + return [value[idx] for value in values] + + elif ( + isinstance(values, Sequence) + and not isinstance(values, str) + and isinstance(values[0], ms.Tensor) + and isinstance(slices, Sequence) + ): + # Recursively separate elements of lists of Tensors/SparseTensors. + return [ + _separate( + key, + value, + idx, + slices=slices[i], + incs=incs[i] if decrement else None, + batch=batch, + store=store, + decrement=decrement, + ) + for i, value in enumerate(values) + ] + + else: + return values[idx] diff --git a/mindscience/sharker/data/storage.py b/mindscience/sharker/data/storage.py new file mode 100644 index 000000000..122f5078f --- /dev/null +++ b/mindscience/sharker/data/storage.py @@ -0,0 +1,780 @@ +import copy +import warnings +import weakref +from collections import defaultdict, namedtuple +from collections.abc import Mapping, MutableMapping, Sequence +from enum import Enum +from typing import ( + Any, + Callable, + Dict, + Iterable, + Iterator, + List, + NamedTuple, + Optional, + Set, + Tuple, + Union, +) +from typing_extensions import Self + +import numpy as np +import mindspore as ms +from mindspore import ops, mint + +from .view import ItemsView, KeysView, ValuesView +from ..utils.coalesce import coalesce, coalesce_np +from ..utils.undirected import is_undirected +from ..utils.select import select +from ..utils.sort_edge_index import sort_edge_index, sort_edge_index_np +from ..utils.isolated import contains_isolated_nodes + + +N_KEYS = {"x", "feat", "pos", "batch", "node_type", "n_id", "tf"} +E_KEYS = {"edge_index", "edge_weight", "edge_attr", "edge_type", "e_id"} + + +class AttrType(Enum): + NODE = "NODE" + EDGE = "EDGE" + OTHER = "OTHER" + + +class BaseStorage(MutableMapping): + # This class wraps a Python dictionary and extends it as follows: + # 1. It allows attribute assignments, e.g.: + # `storage.x = ...` in addition to `storage['x'] = ...` + # 2. It allows private attributes that are not exposed to the user, e.g.: + # `storage._{key} = ...` and accessible via `storage._{key}` + # 3. It holds an (optional) weak reference to its parent object, e.g.: + # `storage._parent = weakref.ref(parent)` + # 4. It allows iterating over only a subset of keys, e.g.: + # `storage.values('x', 'y')` or `storage.items('x', 'y') + # 5. It adds additional Mindspore Tensor functionality, e.g.: + # `storage.numpy()`, `storage.tensor()` or `storage.share_memory_()`. + def __init__( + self, + _mapping: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> None: + super().__init__() + self._mapping: Dict[str, Any] = {} + for key, value in (_mapping or {}).items(): + setattr(self, key, value) + for key, value in kwargs.items(): + setattr(self, key, value) + + @property + def _key(self) -> Any: + return None + + def _pop_cache(self, key: str) -> None: + for cache in getattr(self, "_cached_attr", {}).values(): + cache.discard(key) + + def __len__(self) -> int: + return len(self._mapping) + + def __getattr__(self, key: str) -> Any: + if key == "_mapping": + self._mapping = {} + return self._mapping + try: + return self[key] + except KeyError: + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{key}'" + ) from None + + def __setattr__(self, key: str, value: Any) -> None: + propobj = getattr(self.__class__, key, None) + if propobj is not None and getattr(propobj, "fset", None) is not None: + propobj.fset(self, value) + elif key == "_parent": + self.__dict__[key] = weakref.ref(value) + elif key[:1] == "_": + self.__dict__[key] = value + else: + self[key] = value + + def __delattr__(self, key: str) -> None: + if key[:1] == "_": + del self.__dict__[key] + else: + del self[key] + + def __getitem__(self, key: str) -> Any: + return self._mapping[key] + + def __setitem__(self, key: str, value: Any) -> None: + self._pop_cache(key) + if value is None and key in self._mapping: + del self._mapping[key] + elif value is not None: + self._mapping[key] = value + + def __delitem__(self, key: str) -> None: + if key in self._mapping: + self._pop_cache(key) + del self._mapping[key] + + def __iter__(self) -> Iterator[Any]: + return iter(self._mapping) + + def __copy__(self) -> Self: + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + if key != "_cached_attr": + out.__dict__[key] = value + out._mapping = copy.copy(out._mapping) + return out + + def __deepcopy__(self, memo: Optional[Dict[int, Any]]) -> Self: + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = value + out._mapping = copy.deepcopy(out._mapping, memo) + return out + + def __getstate__(self) -> Dict[str, Any]: + out = self.__dict__.copy() + + _parent = out.get("_parent", None) + if _parent is not None: + out["_parent"] = _parent() + + return out + + def __setstate__(self, mapping: Dict[str, Any]) -> None: + for key, value in mapping.items(): + self.__dict__[key] = value + + _parent = self.__dict__.get("_parent", None) + if _parent is not None: + self.__dict__["_parent"] = weakref.ref(_parent) + + def __repr__(self) -> str: + return repr(self._mapping) + + # Allow iterating over subsets ############################################ + + # In contrast to standard `keys()`, `values()` and `items()` functions of + # Python dictionaries, we allow to only iterate over a subset of items + # denoted by a list of keys `args`. + # This is especially useful for adding MindSpore Tensor functionality to the + # storage object, e.g., in case we only want to transfer a subset of keys + # to the GPU (i.e. the ones that are relevant to the deep learning model). + + def keys(self, *args: str) -> KeysView: # type: ignore + return KeysView(self._mapping, *args) + + def values(self, *args: str) -> ValuesView: # type: ignore + return ValuesView(self._mapping, *args) + + def items(self, *args: str) -> ItemsView: # type: ignore + return ItemsView(self._mapping, *args) + + def apply_(self, func: Callable, *args: str) -> Self: + r"""Applies the in-place function :obj:`func`, either to all attributes + or only the ones given in :obj:`*args`. + """ + for value in self.values(*args): + recursive_apply_(value, func) + return self + + def apply(self, func: Callable, *args: str) -> Self: + r"""Applies the function :obj:`func`, either to all attributes or only + the ones given in :obj:`*args`. + """ + for key, value in self.items(*args): + self[key] = recursive_apply(value, func) + return self + + # Additional functionality ################################################ + + def get(self, key: str, value: Optional[Any] = None) -> Any: + return self._mapping.get(key, value) + + def to_dict(self) -> Dict[str, Any]: + r"""Returns a dictionary of stored key/value pairs.""" + out_dict = copy.copy(self._mapping) + # Needed to preserve individual `num_nodes` attributes when calling + # `BaseData.collate`. + # TODO (matthias) Try to make this more generic. + if "_num_nodes" in self.__dict__: + out_dict["_num_nodes"] = self.__dict__["_num_nodes"] + return out_dict + + def to_namedtuple(self) -> NamedTuple: + r"""Returns a :obj:`NamedTuple` of stored key/value pairs.""" + field_names = list(self.keys()) + typename = f"{self.__class__.__name__}Tuple" + StorageTuple = namedtuple(typename, field_names) # type: ignore + return StorageTuple(*[self[key] for key in field_names]) + + def copy(self, *args: str) -> Self: + r"""Performs a deep-copy of the object.""" + return copy.deepcopy(self) + + def numpy(self, *args: str) -> Self: + r"""Copies attributes to CPU memory, either for all attributes or only + the ones given in :obj:`*args`. + """ + return self.apply(lambda x: x.asnumpy(), *args) + + def tensor(self, *args: str) -> Self: + r"""Copies attributes to CPU memory, either for all attributes or only + the ones given in :obj:`*args`. + """ + return self.apply( + lambda x: Tensor.from_numpy(x) if isinstance(x, np.ndarray) else x, *args + ) + + # Time Handling ########################################################### + + def _cat_dims(self, keys: Iterable[str]) -> Dict[str, int]: + return {key: self._parent().__cat_dim__(key, self[key], self) for key in keys} + + def _select( + self, + keys: Iterable[str], + index_or_mask: ms.Tensor, + ) -> Self: + + for key, dim in self._cat_dims(keys).items(): + self[key] = select(self[key], index_or_mask, dim) + + return self + + def concat(self, other: Self) -> Self: + if not (set(self.keys()) == set(other.keys())): + raise AttributeError("Given storage is not compatible") + + for key, dim in self._cat_dims(self.keys()).items(): + value1 = self[key] + value2 = other[key] + + if key in {"num_nodes", "num_edges"}: + self[key] = value1 + value2 + + elif isinstance(value1, list): + self[key] = value1 + value2 + + elif isinstance(value1, np.ndarray): + self[key] = np.concatenate([value1, value2], axis=dim) + + elif isinstance(value1, ms.Tensor): + self[key] = mint.cat([value1, value2], dim=dim) + + else: + raise NotImplementedError( + f"'{self.__class__.__name__}.concat' not yet implemented " + f"for '{type(value1)}'" + ) + + return self + + def is_sorted_by_time(self) -> bool: + if "time" in self: + return bool(np.all(self.time[:-1] <= self.time[1:])) + return True + + def sort_by_time(self) -> Self: + if self.is_sorted_by_time(): + return self + + if "time" in self: + perm = np.argsort(self.time) + + if self.is_node_attr("time"): + keys = self.node_attrs() + elif self.is_edge_attr("time"): + keys = self.edge_attrs() + + self._select(keys, perm) + + return self + + def snapshot( + self, + start_time: Union[float, int], + end_time: Union[float, int], + ) -> Self: + if "time" in self: + mask = np.logical_and((self.time >= start_time), (self.time <= end_time)) + + if self.is_node_attr("time"): + keys = self.node_attrs() + elif self.is_edge_attr("time"): + keys = self.edge_attrs() + + self._select(keys, mask) + + if self.is_node_attr("time") and "num_nodes" in self: + self.num_nodes: Optional[int] = int(mask.sum()) + + return self + + def up_to(self, time: Union[float, int]) -> Self: + if "time" in self: + return self.snapshot(self.time.min().item(), time) + return self + + +class NodeStorage(BaseStorage): + r"""A storage for node-level information.""" + + @property + def _key(self) -> str: + key = self.__dict__.get("_key", None) + if key is None or not isinstance(key, str): + raise ValueError("'_key' does not denote a valid node type") + return key + + @property + def can_infer_num_nodes(self) -> bool: + keys = set(self.keys()) + num_node_keys = { + "num_nodes", + "x", + "crd", + "batch", + "adj", + "adj_t", + "edge_index", + "face", + } + if len(keys & num_node_keys) > 0: + return True + elif len([key for key in keys if "node" in key]) > 0: + return True + else: + return False + + @property + def num_nodes(self) -> Optional[int]: + # We sequentially access attributes that reveal the number of nodes. + if "num_nodes" in self: + return self["num_nodes"] + for key, value in self.items(): + if isinstance(value, ms.Tensor) and key in N_KEYS: + cat_dim = self._parent().__cat_dim__(key, value, self) + return value.shape[cat_dim] + if isinstance(value, np.ndarray) and key in N_KEYS: + cat_dim = self._parent().__cat_dim__(key, value, self) + return value.shape[cat_dim] + for key, value in self.items(): + if isinstance(value, ms.Tensor) and "node" in key: + cat_dim = self._parent().__cat_dim__(key, value, self) + return value.shape[cat_dim] + if isinstance(value, np.ndarray) and "node" in key: + cat_dim = self._parent().__cat_dim__(key, value, self) + return value.shape[cat_dim] + warnings.warn( + f"Unable to accurately infer 'num_nodes' from the attribute set " + f"'{set(self.keys())}'. Please explicitly set 'num_nodes' as an " + f"attribute of " + + ("'data'" if self._key is None else f"'data[{self._key}]'") + + " to suppress this warning" + ) + if "edge_index" in self and isinstance(self.edge_index, ms.Tensor): + if self.edge_index.numel() > 0: + return int(mint.max(self.edge_index)) + 1 + return 0 + if "face" in self and isinstance(self.face, ms.Tensor): + if self.face.numel() > 0: + return int(mint.max(self.face)) + 1 + return 0 + return None + + @num_nodes.setter + def num_nodes(self, num_nodes: Optional[int]) -> None: + self["num_nodes"] = num_nodes + + @property + def num_node_features(self) -> int: + x: Optional[Any] = self.get("x") + if isinstance(x, ms.Tensor): + return 1 if x.dim() == 1 else x.shape[-1] + if isinstance(x, np.ndarray): + return 1 if x.ndim == 1 else x.shape[-1] + return 0 + + @property + def num_features(self) -> int: + return self.num_node_features + + def is_node_attr(self, key: str) -> bool: + if "_cached_attr" not in self.__dict__: + self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set) + + if key in self._cached_attr[AttrType.NODE]: + return True + if key in self._cached_attr[AttrType.OTHER]: + return False + + value = self[key] + + if not isinstance(value, (ms.Tensor, np.ndarray)): + self._cached_attr[AttrType.OTHER].add(key) + return False + + if value.ndim == 0: + self._cached_attr[AttrType.OTHER].add(key) + return False + + cat_dim = self._parent().__cat_dim__(key, value, self) + if value.shape[cat_dim] != self.num_nodes: + self._cached_attr[AttrType.OTHER].add(key) + return False + + self._cached_attr[AttrType.NODE].add(key) + return True + + def is_edge_attr(self, key: str) -> bool: + return False + + def node_attrs(self) -> List[str]: + return [key for key in self.keys() if self.is_node_attr(key)] + + +class EdgeStorage(BaseStorage): + r"""A storage for edge-level information. + + We support multiple ways to store edge connectivity in a + :class:`EdgeStorage` object: + + * :obj:`edge_index`: A :class:`ms.Tensor` holding edge indices in + COO format with shape :obj:`[2, num_edges]` (the default format) + + * :obj:`adj`: A :class:`mindspore.SparseTensor` holding edge indices in + a sparse format, supporting both COO and CSR format. + + * :obj:`adj_t`: A **transposed** :class:`mindspore.SparseTensor` holding + edge indices in a sparse format, supporting both COO and CSR format. + This is the most efficient one for graph-based deep learning models as + indices are sorted based on target nodes. + """ + + @property + def _key(self) -> Tuple[str, str, str]: + key = self.__dict__.get("_key", None) + if key is None or not isinstance(key, tuple) or not len(key) == 3: + raise ValueError("'_key' does not denote a valid edge type") + return key + + @property + def edge_index(self) -> ms.Tensor: + if "edge_index" in self: + return self["edge_index"] + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute " + f"'edge_index', 'adj' or 'adj_t'" + ) + + @edge_index.setter + def edge_index(self, edge_index: Optional[ms.Tensor]) -> None: + self["edge_index"] = edge_index + + @property + def num_edges(self) -> int: + # We sequentially access attributes that reveal the number of edges. + if "num_edges" in self: + return self["num_edges"] + for key, value in self.items(): + if isinstance(value, ms.Tensor) and key in E_KEYS: + cat_dim = self._parent().__cat_dim__(key, value, self) + return value.shape[cat_dim] + if isinstance(value, np.ndarray) and key in E_KEYS: + cat_dim = self._parent().__cat_dim__(key, value, self) + return value.shape[cat_dim] + for key, value in self.items(): + if isinstance(value, ms.Tensor) and "edge" in key: + cat_dim = self._parent().__cat_dim__(key, value, self) + return value.shape[cat_dim] + if isinstance(value, np.ndarray) and "edge" in key: + cat_dim = self._parent().__cat_dim__(key, value, self) + return value.shape[cat_dim] + return 0 + + @property + def num_edge_features(self) -> int: + edge_attr: Optional[Any] = self.get("edge_attr") + if isinstance(edge_attr, ms.Tensor): + return 1 if edge_attr.dim() == 1 else edge_attr.shape[-1] + if isinstance(edge_attr, np.ndarray): + return 1 if edge_attr.ndim == 1 else edge_attr.shape[-1] + return 0 + + @property + def num_features(self) -> int: + return self.num_edge_features + + @property + def shape(self) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: + + if self._key is None: + raise NameError( + "Unable to infer 'size' without explicit " "'_key' assignment" + ) + + size = ( + self._parent()[self._key[0]].num_nodes, + self._parent()[self._key[-1]].num_nodes, + ) + + return size + + def is_node_attr(self, key: str) -> bool: + return False + + def is_edge_attr(self, key: str) -> bool: + if "_cached_attr" not in self.__dict__: + self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set) + + if key in self._cached_attr[AttrType.EDGE]: + return True + if key in self._cached_attr[AttrType.OTHER]: + return False + + value = self[key] + + if not isinstance(value, (ms.Tensor, np.ndarray)): + self._cached_attr[AttrType.OTHER].add(key) + return False + + if value.ndim == 0: + self._cached_attr[AttrType.OTHER].add(key) + return False + + cat_dim = self._parent().__cat_dim__(key, value, self) + if value.shape[cat_dim] != self.num_edges: + self._cached_attr[AttrType.OTHER].add(key) + return False + + self._cached_attr[AttrType.EDGE].add(key) + return True + + def edge_attrs(self) -> List[str]: + return [key for key in self.keys() if self.is_edge_attr(key)] + + def is_sorted(self, sort_by_row: bool = True) -> bool: + if "edge_index" in self: + index = self.edge_index[0] if sort_by_row else self.edge_index[1] + return bool(np.all(index[:-1] <= index[1:])) + return True + + def sort(self, sort_by_row: bool = True) -> Self: + if "edge_index" in self: + edge_attrs = self.edge_attrs() + edge_attrs.remove("edge_index") + edge_feats = [self[edge_attr] for edge_attr in edge_attrs] + self.edge_index, edge_feats = sort_edge_index_np( + self.edge_index, edge_feats, sort_by_row=sort_by_row + ) + for key, edge_feat in zip(edge_attrs, edge_feats): + self[key] = edge_feat + return self + + def is_coalesced(self) -> bool: + for value in self.values("adj", "adj_t"): + return value.is_coalesced() + + if "edge_index" in self: + size = [s for s in self.shape if s is not None] + num_nodes = max(size) if len(size) > 0 else None + new_edge_index = coalesce_np(self.edge_index, num_nodes=num_nodes) + return self.edge_index.size == new_edge_index.size and np.all( + self.edge_index == new_edge_index + ) + + return True + + def coalesce(self, reduce: str = "sum") -> Self: + for key, value in self.items("adj", "adj_t"): + self[key] = value.coalesce(reduce) + + if "edge_index" in self: + + size = [s for s in self.shape if s is not None] + num_nodes = max(size) if len(size) > 0 else None + + self.edge_index, self.edge_attr = coalesce_np( + self.edge_index, + edge_attr=self.get("edge_attr"), + num_nodes=num_nodes, + ) + + return self + + def has_isolated_nodes(self) -> bool: + edge_index, num_nodes = self.edge_index, self.shape[1] + if num_nodes is None: + raise NameError("Unable to infer 'num_nodes'") + if self.is_bipartite(): + return np.unique(edge_index[1]).size < num_nodes + else: + return contains_isolated_nodes(edge_index, num_nodes) + + def has_self_loops(self) -> bool: + if self.is_bipartite(): + return False + edge_index = self.edge_index + return int((edge_index[0] == edge_index[1]).sum()) > 0 + + def is_undirected(self) -> bool: + if self.is_bipartite(): + return False + + for value in self.values("adj", "adj_t"): + return value.is_symmetric() + + edge_index = self.edge_index + edge_attr = self.edge_attr if "edge_attr" in self else None + return is_undirected(edge_index, edge_attr, num_nodes=self.shape[0]) + + def is_directed(self) -> bool: + return not self.is_undirected() + + def is_bipartite(self) -> bool: + return self._key is not None and self._key[0] != self._key[-1] + + +class GlobalStorage(NodeStorage, EdgeStorage): + r"""A storage for both node-level and edge-level information.""" + + @property + def _key(self) -> Any: + return None + + @property + def num_features(self) -> int: + return self.num_node_features + + @property + def shape(self) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: + size = (self.num_nodes, self.num_nodes) + return size + + def is_node_attr(self, key: str) -> bool: + if "_cached_attr" not in self.__dict__: + self._cached_attr: Dict[AttrType, Set[str]] = defaultdict(set) + + if key in self._cached_attr[AttrType.NODE]: + return True + if key in self._cached_attr[AttrType.EDGE]: + return False + if key in self._cached_attr[AttrType.OTHER]: + return False + + value = self[key] + + if (isinstance(value, (list, tuple)) + and len(value) == self.num_nodes): + self._cached_attr[AttrType.NODE].add(key) + return True + + if not isinstance(value, (ms.Tensor, np.ndarray)): + return False + + if value.ndim == 0: + self._cached_attr[AttrType.OTHER].add(key) + return False + + cat_dim = self._parent().__cat_dim__(key, value, self) + num_nodes, num_edges = self.num_nodes, self.num_edges + + if value.shape[cat_dim] != num_nodes: + if value.shape[cat_dim] == num_edges: + self._cached_attr[AttrType.EDGE].add(key) + else: + self._cached_attr[AttrType.OTHER].add(key) + return False + + if num_nodes != num_edges: + self._cached_attr[AttrType.NODE].add(key) + return True + + if "edge" not in key: + self._cached_attr[AttrType.NODE].add(key) + return True + else: + self._cached_attr[AttrType.EDGE].add(key) + return False + + def is_edge_attr(self, key: str) -> bool: + if "_cached_attr" not in self.__dict__: + self._cached_attr = defaultdict(set) + + if key in self._cached_attr[AttrType.EDGE]: + return True + if key in self._cached_attr[AttrType.NODE]: + return False + if key in self._cached_attr[AttrType.OTHER]: + return False + + value = self[key] + + if not isinstance(value, (ms.Tensor, np.ndarray)): + return False + + if value.ndim == 0: + self._cached_attr[AttrType.OTHER].add(key) + return False + + cat_dim = self._parent().__cat_dim__(key, value, self) + num_nodes, num_edges = self.num_nodes, self.num_edges + + if value.shape[cat_dim] != num_edges: + if value.shape[cat_dim] == num_nodes: + self._cached_attr[AttrType.NODE].add(key) + else: + self._cached_attr[AttrType.OTHER].add(key) + return False + + if num_edges != num_nodes: + self._cached_attr[AttrType.EDGE].add(key) + return True + + if "edge" in key: + self._cached_attr[AttrType.EDGE].add(key) + return True + else: + self._cached_attr[AttrType.NODE].add(key) + return False + + +def recursive_apply(data: Any, func: Callable) -> Any: + if isinstance(data, ms.Tensor): + return func(data) + elif isinstance(data, tuple) and hasattr(data, "_fields"): + return type(data)(*(recursive_apply(d, func) for d in data)) + elif isinstance(data, Sequence) and not isinstance(data, str): + return [recursive_apply(d, func) for d in data] + elif isinstance(data, Mapping): + return {key: recursive_apply(data[key], func) for key in data} + else: + try: + return func(data) + except Exception: + return data + + +def recursive_apply_(data: Any, func: Callable) -> Any: + if isinstance(data, ms.Tensor): + func(data) + elif isinstance(data, tuple) and hasattr(data, '_fields'): + for value in data: + recursive_apply_(value, func) + elif isinstance(data, Sequence) and not isinstance(data, str): + for value in data: + recursive_apply_(value, func) + elif isinstance(data, Mapping): + for value in data.values(): + recursive_apply_(value, func) + else: + try: + func(data) + except Exception: + pass diff --git a/mindscience/sharker/data/summary.py b/mindscience/sharker/data/summary.py new file mode 100644 index 000000000..cb0515de7 --- /dev/null +++ b/mindscience/sharker/data/summary.py @@ -0,0 +1,157 @@ +from collections import defaultdict +from dataclasses import dataclass +from typing import Dict, List, Optional, Union, Tuple + +import mindspore as ms +from tqdm import tqdm +from typing_extensions import Self + +from .dataset import Dataset +from .heterograph import HeteroGraph + + +@dataclass +class Stats: + mean: float + std: float + min: float + quantile25: float + median: float + quantile75: float + max: float + + @classmethod + def from_data( + cls, + data: Union[List[int], List[float], ms.Tensor], + ) -> Self: + if not isinstance(data, ms.Tensor): + data = ms.Tensor(data) + data = data.float() + + return cls( + mean=data.mean().item(), + std=data.std().item(), + min=data.min().item(), + quantile25=data.quantile(0.25).item(), + median=data.median().item(), + quantile75=data.quantile(0.75).item(), + max=data.max().item(), + ) + + +@dataclass(repr=False) +class Summary: + name: str + num_graphs: int + num_nodes: Stats + num_edges: Stats + num_nodes_per_type: Optional[Dict[str, Stats]] = None + num_edges_per_type: Optional[Dict[Tuple[str, str, str], Stats]] = None + + @classmethod + def from_dataset( + cls, + dataset: Dataset, + progress_bar: Optional[bool] = None, + per_type: bool = True, + ) -> Self: + r"""Creates a summary of a :class:`~sharker.data.Graphset` + object. + + Args: + dataset (Dataset): The dataset. + progress_bar (bool, optional): If set to :obj:`True`, will show a + progress bar during stats computation. If set to :obj:`None`, + will automatically decide whether to show a progress bar based + on dataset size. (default: :obj:`None`) + per_type (bool, optional): If set to :obj:`True`, will separate + statistics per node and edge type (only applicable in + heterogeneous graph datasets). (default: :obj:`True`) + """ + name = dataset.__class__.__name__ + + if progress_bar is None: + progress_bar = len(dataset) >= 10000 + + if progress_bar: + dataset = tqdm(dataset) + + num_nodes, num_edges = [], [] + _num_nodes_per_type = defaultdict(list) + _num_edges_per_type = defaultdict(list) + + for data in dataset: + assert data.num_nodes is not None + num_nodes.append(data.num_nodes) + num_edges.append(data.num_edges) + + if per_type and isinstance(data, HeteroGraph): + for node_type in data.node_types: + _num_nodes_per_type[node_type].append(data[node_type].num_nodes) + for edge_type in data.edge_types: + _num_edges_per_type[edge_type].append(data[edge_type].num_edges) + + num_nodes_per_type = None + if len(_num_nodes_per_type) > 0: + num_nodes_per_type = { + node_type: Stats.from_data(num_nodes_list) + for node_type, num_nodes_list in _num_nodes_per_type.items() + } + + num_edges_per_type = None + if len(_num_edges_per_type) > 0: + num_edges_per_type = { + edge_type: Stats.from_data(num_edges_list) + for edge_type, num_edges_list in _num_edges_per_type.items() + } + + return cls( + name=name, + num_graphs=len(dataset), + num_nodes=Stats.from_data(num_nodes), + num_edges=Stats.from_data(num_edges), + num_nodes_per_type=num_nodes_per_type, + num_edges_per_type=num_edges_per_type, + ) + + def __repr__(self) -> str: + from tabulate import tabulate + + body = f"{self.name} (#graphs={self.num_graphs}):\n" + + content = [["", "#nodes", "#edges"]] + stats = [self.num_nodes, self.num_edges] + for field in Stats.__dataclass_fields__: + row = [field] + [f"{getattr(s, field):.1f}" for s in stats] + content.append(row) + body += tabulate(content, headers="firstrow", tablefmt="psql") + + if self.num_nodes_per_type is not None: + content = [[""]] + content[0] += list(self.num_nodes_per_type.keys()) + + for field in Stats.__dataclass_fields__: + row = [field] + [ + f"{getattr(s, field):.1f}" for s in self.num_nodes_per_type.values() + ] + content.append(row) + body += "\nNumber of nodes per node type:\n" + body += tabulate(content, headers="firstrow", tablefmt="psql") + + if self.num_edges_per_type is not None: + content = [[""]] + content[0] += [ + f"({', '.join(edge_type)})" + for edge_type in self.num_edges_per_type.keys() + ] + + for field in Stats.__dataclass_fields__: + row = [field] + [ + f"{getattr(s, field):.1f}" for s in self.num_edges_per_type.values() + ] + content.append(row) + body += "\nNumber of edges per edge type:\n" + body += tabulate(content, headers="firstrow", tablefmt="psql") + + return body diff --git a/mindscience/sharker/data/temporal.py b/mindscience/sharker/data/temporal.py new file mode 100644 index 000000000..9f672eef8 --- /dev/null +++ b/mindscience/sharker/data/temporal.py @@ -0,0 +1,315 @@ +import copy +from typing import ( + Any, + Dict, + Iterable, + List, + NamedTuple, + Optional, + Tuple, + Union, +) + +import numpy as np +import mindspore as ms +from mindspore import ops, mint +from typing_extensions import Self + +from .graph import Graph, size_repr +from .storage import ( + BaseStorage, + EdgeStorage, + GlobalStorage, + NodeStorage, +) + + +class TemporalGraph(Graph): + r"""A data object composed by a stream of events describing a temporal + graph. + The :class:`~sharker.data.TemporalGraph` object can hold a list of + events (that can be understood as temporal edges in a graph) with + structured messages. + An event is composed by a source node, a destination node, a timestamp + and a message. Any *Continuous-Time Dynamic Graph* (CTDG) can be + represented with these four values. + + In general, :class:`~sharker.data.TemporalGraph` tries to mimic + the behavior of a regular :python:`Python` dictionary. + In addition, it provides useful functionality for analyzing graph + structures, and provides basic MindSpore tensor functionalities. + + .. code-block:: python + + from mindspore import + from mindscience.sharker.data import TemporalData + + events = TemporalData( + src=Tensor([1,2,3,4]), + dst=Tensor([2,3,4,5]), + t=Tensor([1000,1010,1100,2000]), + msg=Tensor([1,1,0,0]) + ) + + # Add additional arguments to `events`: + events.y = Tensor([1,1,0,0]) + + # It is also possible to set additional arguments in the constructor + events = TemporalData( + ..., + y=Tensor([1,1,0,0]) + ) + + # Get the number of events: + events.num_events + >>> 4 + + # Analyzing the graph structure: + events.num_nodes + >>> 5 + + # MindSpore tensor functionality: + events = events.pin_memory() + + Args: + src (Tensor, optional): A list of source nodes for the events + with shape :obj:`[num_events]`. (default: :obj:`None`) + dst (Tensor, optional): A list of destination nodes for the + events with shape :obj:`[num_events]`. (default: :obj:`None`) + t (Tensor, optional): The timestamps for each event with shape + :obj:`[num_events]`. (default: :obj:`None`) + msg (Tensor, optional): Messages feature matrix with shape + :obj:`[num_events, num_msg_features]`. (default: :obj:`None`) + **kwargs (optional): Additional attributes. + + .. note:: + The shape of :obj:`src`, :obj:`dst`, :obj:`t` and the first dimension + of :obj`msg` should be the same (:obj:`num_events`). + """ + + def __init__( + self, + src: Optional[ms.Tensor] = None, + dst: Optional[ms.Tensor] = None, + t: Optional[ms.Tensor] = None, + msg: Optional[ms.Tensor] = None, + **kwargs, + ): + super().__init__() + self.__dict__["_store"] = GlobalStorage(_parent=self) + + self.src = src + self.dst = dst + self.t = t + self.msg = msg + + for key, value in kwargs.items(): + setattr(self, key, value) + + @classmethod + def from_dict(cls, mapping: Dict[str, Any]) -> Self: + r"""Creates a :class:`~sharker.data.TemporalGraph` object from + a Python dictionary. + """ + return cls(**mapping) + + def index_select(self, idx: Any) -> Self: + idx = prepare_idx(idx) + data = copy.copy(self) + for key, value in data._store.items(): + if value.shape[0] == self.num_events: + data[key] = value[idx] + return data + + def __getitem__(self, idx: Any) -> Any: + if isinstance(idx, str): + return self._store[idx] + return self.index_select(idx) + + def __setitem__(self, key: str, value: Any): + """Sets the attribute :obj:`key` to :obj:`value`.""" + self._store[key] = value + + def __delitem__(self, key: str): + if key in self._store: + del self._store[key] + + def __getattr__(self, key: str) -> Any: + if "_store" not in self.__dict__: + raise RuntimeError( + "The 'data' object was created by an older version of MindGeometric. " + "If this error occurred while loading an already existing " + "dataset, remove the 'processed/' directory in the dataset's " + "root folder and try again." + ) + return getattr(self._store, key) + + def __setattr__(self, key: str, value: Any): + setattr(self._store, key, value) + + def __delattr__(self, key: str): + delattr(self._store, key) + + def __iter__(self) -> Iterable: + for i in range(self.num_events): + yield self[i] + + def __len__(self) -> int: + return self.num_events + + def __call__(self, *args: List[str]) -> Iterable: + for key, value in self._store.items(*args): + yield key, value + + def __copy__(self): + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = value + out.__dict__["_store"] = copy.copy(self._store) + out._store._parent = out + return out + + def __deepcopy__(self, memo): + out = self.__class__.__new__(self.__class__) + for key, value in self.__dict__.items(): + out.__dict__[key] = copy.deepcopy(value, memo) + out._store._parent = out + return out + + def stores_as(self, data: Self): + return self + + @property + def stores(self) -> List[BaseStorage]: + return [self._store] + + @property + def node_stores(self) -> List[NodeStorage]: + return [self._store] + + @property + def edge_stores(self) -> List[EdgeStorage]: + return [self._store] + + def to_dict(self) -> Dict[str, Any]: + return self._store.to_dict() + + def to_namedtuple(self) -> NamedTuple: + return self._store.to_namedtuple() + + def debug(self): + pass # TODO + + @property + def num_nodes(self) -> int: + r"""Returns the number of nodes in the graph.""" + return max(int(self.src.max()), int(self.dst.max())) + 1 + + @property + def num_events(self) -> int: + r"""Returns the number of events loaded. + + .. note:: + In a :class:`~sharker.data.TemporalGraph`, each row denotes + an event. + Thus, they can be also understood as edges. + """ + return self.src.shape[0] + + @property + def num_edges(self) -> int: + r"""Alias for :meth:`~sharker.data.TemporalGraph.num_events`.""" + return self.num_events + + @property + def edge_index(self) -> ms.Tensor: + r"""Returns the edge indices of the graph.""" + if "edge_index" in self: + return self._store["edge_index"] + if self.src is not None and self.dst is not None: + return mint.stack([self.src, self.dst], dim=0) + raise ValueError( + f"{self.__class__.__name__} does not contain " f"'edge_index' information" + ) + + @property + def shape(self) -> Union[Tuple[Optional[int], Optional[int]], Optional[int]]: + r"""Returns the size of the adjacency matrix induced by the graph.""" + size = (int(self.src.max()), int(self.dst.max())) + return size + + def __cat_dim__(self, key: str, value: Any, *args, **kwargs) -> Any: + return 0 + + def __inc__(self, key: str, value: Any, *args, **kwargs) -> Any: + if "batch" in key and isinstance(value, ms.Tensor): + return int(value.max()) + 1 + elif key in ["src", "dst"]: + return self.num_nodes + else: + return 0 + + def __repr__(self) -> str: + cls = self.__class__.__name__ + info = ", ".join([size_repr(k, v) for k, v in self._store.items()]) + return f"{cls}({info})" + + ########################################################################### + + def train_val_test_split(self, val_ratio: float = 0.15, test_ratio: float = 0.15): + r"""Splits the data in training, validation and test sets according to + time. + + Args: + val_ratio (float, optional): The proportion (in percents) of the + dataset to include in the validation split. + (default: :obj:`0.15`) + test_ratio (float, optional): The proportion (in percents) of the + dataset to include in the test split. (default: :obj:`0.15`) + """ + val_time, test_time = np.quantile( + self.t.asnumpy(), [1.0 - val_ratio - test_ratio, 1.0 - test_ratio] + ) + + val_idx = int((self.t <= val_time).sum()) + test_idx = int((self.t <= test_time).sum()) + + return self[:val_idx], self[val_idx:test_idx], self[test_idx:] + + ########################################################################### + + def coalesce(self): + raise NotImplementedError + + def has_isolated_nodes(self) -> bool: + raise NotImplementedError + + def has_self_loops(self) -> bool: + raise NotImplementedError + + def is_undirected(self) -> bool: + raise NotImplementedError + + def is_directed(self) -> bool: + raise NotImplementedError + + +############################################################################### + + +def prepare_idx(idx): + if isinstance(idx, int): + return slice(idx, idx + 1) + if isinstance(idx, (list, tuple)): + return ms.Tensor(idx) + elif isinstance(idx, slice): + return idx + elif isinstance(idx, ms.Tensor) and idx.dtype == ms.int64: + return idx + elif isinstance(idx, ms.Tensor) and idx.dtype == ms.bool_: + return idx + + raise IndexError( + f"Only strings, integers, slices (`:`), list, tuples, and long or " + f"bool tensors are valid indices (got '{type(idx).__name__}')" + ) diff --git a/mindscience/sharker/data/view.py b/mindscience/sharker/data/view.py new file mode 100644 index 000000000..2b117398f --- /dev/null +++ b/mindscience/sharker/data/view.py @@ -0,0 +1,39 @@ +from typing import Any, Iterator, List, Mapping, Tuple + + +class MappingView: + def __init__(self, mapping: Mapping[str, Any], *args: str): + self._mapping = mapping + self._args = args + + def _keys(self) -> List[str]: + if len(self._args) == 0: + return list(self._mapping.keys()) + else: + return [arg for arg in self._args if arg in self._mapping] + + def __len__(self) -> int: + return len(self._keys()) + + def __repr__(self) -> str: + mapping = {key: self._mapping[key] for key in self._keys()} + return f'{self.__class__.__name__}({mapping})' + + __class_getitem__ = classmethod(type([])) + + +class KeysView(MappingView): + def __iter__(self) -> Iterator[str]: + yield from self._keys() + + +class ValuesView(MappingView): + def __iter__(self) -> Iterator[Any]: + for key in self._keys(): + yield self._mapping[key] + + +class ItemsView(MappingView): + def __iter__(self) -> Iterator[Tuple[str, Any]]: + for key in self._keys(): + yield (key, self._mapping[key]) diff --git a/mindscience/sharker/dataset/__init__.py b/mindscience/sharker/dataset/__init__.py new file mode 100644 index 000000000..2f72f8f21 --- /dev/null +++ b/mindscience/sharker/dataset/__init__.py @@ -0,0 +1 @@ +from .qm9 import QM9 diff --git a/mindscience/sharker/dataset/qm9.py b/mindscience/sharker/dataset/qm9.py new file mode 100644 index 000000000..2cca5c4ef --- /dev/null +++ b/mindscience/sharker/dataset/qm9.py @@ -0,0 +1,332 @@ +import os +import os.path as osp +from typing import Callable, List, Optional + +import numpy as np +from mindspore import Tensor, ops +from tqdm import tqdm + +from ..data import ( + Graph, + InMemoryDataset, + download_url, + extract_zip, +) +from ..utils import scatter + +HAR2EV = 27.211386246 +KCALMOL2EV = 0.04336414 + +conversion = np.array( + [ + 1.0, + 1.0, + HAR2EV, + HAR2EV, + HAR2EV, + 1.0, + HAR2EV, + HAR2EV, + HAR2EV, + HAR2EV, + HAR2EV, + 1.0, + KCALMOL2EV, + KCALMOL2EV, + KCALMOL2EV, + KCALMOL2EV, + 1.0, + 1.0, + 1.0, + ] +) + +atomrefs = { + 6: [0.0, 0.0, 0.0, 0.0, 0.0], + 7: [-13.61312172, -1029.86312267, -1485.30251237, -2042.61123593, -2713.48485589], + 8: [-13.5745904, -1029.82456413, -1485.26398105, -2042.5727046, -2713.44632457], + 9: [-13.54887564, -1029.79887659, -1485.2382935, -2042.54701705, -2713.42063702], + 10: [-13.90303183, -1030.25891228, -1485.71166277, -2043.01812778, -2713.88796536], + 11: [0.0, 0.0, 0.0, 0.0, 0.0], +} + + +class QM9(InMemoryDataset): + r"""The QM9 dataset from the `"MoleculeNet: A Benchmark for Molecular + Machine Learning" `_ paper, consisting of + about 130,000 molecules with 19 regression targets. + Each molecule includes complete spatial information for the single low + energy conformation of the atoms in the molecule. + In addition, we provide the atom features from the `"Neural Message + Passing for Quantum Chemistry" `_ paper. + + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | Target | Property | Description | Unit | + +========+==================================+===================================================================================+=============================================+ + | 0 | :math:`\mu` | Dipole moment | :math:`\textrm{D}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 1 | :math:`\alpha` | Isotropic polarizability | :math:`{a_0}^3` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 2 | :math:`\epsilon_{\textrm{HOMO}}` | Highest occupied molecular orbital energy | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 3 | :math:`\epsilon_{\textrm{LUMO}}` | Lowest unoccupied molecular orbital energy | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 4 | :math:`\Delta \epsilon` | Gap between :math:`\epsilon_{\textrm{HOMO}}` and :math:`\epsilon_{\textrm{LUMO}}` | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 5 | :math:`\langle R^2 \rangle` | Electronic spatial extent | :math:`{a_0}^2` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 6 | :math:`\textrm{ZPVE}` | Zero point vibrational energy | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 7 | :math:`U_0` | Internal energy at 0K | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 8 | :math:`U` | Internal energy at 298.15K | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 9 | :math:`H` | Enthalpy at 298.15K | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 10 | :math:`G` | Free energy at 298.15K | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 11 | :math:`c_{\textrm{v}}` | Heat capavity at 298.15K | :math:`\frac{\textrm{cal}}{\textrm{mol K}}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 12 | :math:`U_0^{\textrm{ATOM}}` | Atomization energy at 0K | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 13 | :math:`U^{\textrm{ATOM}}` | Atomization energy at 298.15K | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 14 | :math:`H^{\textrm{ATOM}}` | Atomization enthalpy at 298.15K | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 15 | :math:`G^{\textrm{ATOM}}` | Atomization free energy at 298.15K | :math:`\textrm{eV}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 16 | :math:`A` | Rotational constant | :math:`\textrm{GHz}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 17 | :math:`B` | Rotational constant | :math:`\textrm{GHz}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + | 18 | :math:`C` | Rotational constant | :math:`\textrm{GHz}` | + +--------+----------------------------------+-----------------------------------------------------------------------------------+---------------------------------------------+ + + .. note:: + + We also provide a pre-processed version of the dataset in case + :class:`rdkit` is not installed. The pre-processed version matches with + the manually processed version as outlined in :meth:`process`. + + Args: + root (str): Root directory where the dataset should be saved. + transform (callable, optional): A function/transform that takes in an + :obj:`sharker.data.Graph` object and returns a transformed + version. The data object will be transformed before every access. + (default: :obj:`None`) + pre_transform (callable, optional): A function/transform that takes in + an :obj:`sharker.data.Graph` object and returns a + transformed version. The data object will be transformed before + being saved to disk. (default: :obj:`None`) + pre_filter (callable, optional): A function that takes in an + :obj:`sharker.data.Graph` object and returns a boolean + value, indicating whether the data object should be included in the + final dataset. (default: :obj:`None`) + force_reload (bool, optional): Whether to re-process the dataset. + (default: :obj:`False`) + + **STATS:** + + .. list-table:: + :widths: 10 10 10 10 10 + :header-rows: 1 + + * - #graphs + - #nodes + - #edges + - #features + - #tasks + * - 130,831 + - ~18.0 + - ~37.3 + - 11 + - 19 + """ # noqa: E501 + + raw_url = ( + "https://deepchemdata.s3-us-west-1.amazonaws.com/datasets/" + "molnet_publish/qm9.zip" + ) + raw_url2 = "https://figshare.com/files/3195404" + processed_url = "https://data.pyg.org/datasets/qm9_v3.zip" + + def __init__( + self, + root: str, + transform: Optional[Callable] = None, + pre_transform: Optional[Callable] = None, + pre_filter: Optional[Callable] = None, + force_reload: bool = False, + ) -> None: + super().__init__( + root, transform, pre_transform, pre_filter, force_reload=force_reload + ) + self.load(self.processed_paths[0]) + + def mean(self, target: int) -> float: + y = np.concatenate([self.get(i).y for i in range(len(self))], axis=0) + return float(y[:, target].mean()) + + def std(self, target: int) -> float: + y = np.concatenate([self.get(i).y for i in range(len(self))], axis=0) + return float(y[:, target].std()) + + def atomref(self, target: int) -> Optional[Tensor]: + if target in atomrefs: + out = np.zeros(100) + out[np.array([1, 6, 7, 8, 9])] = np.array(atomrefs[target]) + return out.reshape(-1, 1) + return None + + @property + def raw_file_names(self) -> List[str]: + try: + return ["gdb9.sdf", "gdb9.sdf.csv", "uncharacterized.txt"] + except ImportError: + return ["qm9_v3.pt"] + + @property + def processed_file_names(self) -> str: + return "data_v3.pt" + + def download(self) -> None: + try: + file_path = download_url(self.raw_url, self.raw_dir) + extract_zip(file_path, self.raw_dir) + # os.unlink(file_path) + + file_path = download_url(self.raw_url2, self.raw_dir) + os.rename( + osp.join(self.raw_dir, "3195404"), + osp.join(self.raw_dir, "uncharacterized.txt"), + ) + except ImportError: + path = download_url(self.processed_url, self.raw_dir) + extract_zip(path, self.raw_dir) + # os.unlink(path) + + def process(self) -> None: + from rdkit import Chem, RDLogger + from rdkit.Chem.rdchem import BondType as BT + from rdkit.Chem.rdchem import HybridizationType + + RDLogger.DisableLog("rdApp.*") + + # if rdkit is None: + # print( + # ( + # "Using a pre-processed version of the dataset. Please " + # "install 'rdkit' to alternatively process the raw data." + # ), + # file=sys.stderr, + # ) + + # data_list = ms.load(self.raw_paths[0]) + # data_list = [Graph(**data_dict) for data_dict in data_list] + + # if self.pre_filter is not None: + # data_list = [d for d in data_list if self.pre_filter(d)] + + # if self.pre_transform is not None: + # data_list = [self.pre_transform(d) for d in data_list] + + # self.save(data_list, self.processed_paths[0]) + # return + + types = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4} + bonds = {BT.SINGLE: 0, BT.DOUBLE: 1, BT.TRIPLE: 2, BT.AROMATIC: 3} + + with open(self.raw_paths[1], "r") as f: + target = [ + [float(x) for x in line.split(",")[1:20]] + for line in f.read().split("\n")[1:-1] + ] + y = np.array(target, dtype=np.float32) + y = np.concatenate([y[:, 3:], y[:, :3]], axis=-1) + y = y * conversion.reshape(1, -1) + + with open(self.raw_paths[2], "r") as f: + skip = [int(x.split()[0]) - 1 for x in f.read().split("\n")[9:-2]] + + suppl = Chem.SDMolSupplier(self.raw_paths[0], removeHs=False, sanitize=False) + + data_list = [] + for i, mol in enumerate(tqdm(suppl)): + if i == 100: + break + if i in skip: + continue + + N = mol.GetNumAtoms() + + conf = mol.GetConformer() + crd = conf.GetPositions() + crd = np.array(crd, np.float32) + + type_idx = [] + atomic_number = [] + aromatic = [] + sp = [] + sp2 = [] + sp3 = [] + num_hs = [] + for atom in mol.GetAtoms(): + type_idx.append(types[atom.GetSymbol()]) + atomic_number.append(atom.GetAtomicNum()) + aromatic.append(1 if atom.GetIsAromatic() else 0) + hybridization = atom.GetHybridization() + sp.append(1 if hybridization == HybridizationType.SP else 0) + sp2.append(1 if hybridization == HybridizationType.SP2 else 0) + sp3.append(1 if hybridization == HybridizationType.SP3 else 0) + + z = np.array(atomic_number, dtype=np.int64) + + rows, cols, edge_types = [], [], [] + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + rows += [start, end] + cols += [end, start] + edge_types += 2 * [bonds[bond.GetBondType()]] + + edge_index = np.array([rows, cols], dtype=np.int64) + edge_type = np.array(edge_types, dtype=np.int64) + edge_attr = np.eye(len(bonds))[edge_type] + + perm = (edge_index[0] * N + edge_index[1]).argsort() + edge_index = edge_index[:, perm] + edge_type = edge_type[perm] + edge_attr = edge_attr[perm] + + row, col = edge_index + hs = (z == 1).astype(np.float32) + num_hs = np.zeros(N) + np.add.at(num_hs, col, hs[row]) + num_hs.tolist() + + x1 = np.eye(len(types))[np.array(type_idx)].astype(np.float32) + x2 = np.array([atomic_number, aromatic, sp, sp2, sp3, num_hs]).astype(np.float32).T + x = np.concatenate([x1, x2], axis=-1) + + name = mol.GetProp("_Name") + smiles = Chem.MolToSmiles(mol, isomericSmiles=True) + + data = Graph( + x=x, + z=z, + crd=crd, + edge_index=edge_index, + smiles=smiles, + edge_attr=edge_attr, + y=np.expand_dims(y[i], axis=0), + name=name, + idx=i, + ) + + if self.pre_filter is not None and not self.pre_filter(data): + continue + if self.pre_transform is not None: + data = self.pre_transform(data) + + data_list.append(data) + + self.save(data_list, self.processed_paths[0]) diff --git a/mindscience/sharker/experimental.py b/mindscience/sharker/experimental.py new file mode 100644 index 000000000..55785195a --- /dev/null +++ b/mindscience/sharker/experimental.py @@ -0,0 +1,136 @@ +import functools +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + + +__experimental_flag__: Dict[str, bool] = { + "disable_dynamic_shapes": False, +} + +Options = Optional[Union[str, List[str]]] + + +def get_options(options: Options) -> List[str]: + if options is None: + options = list(__experimental_flag__.keys()) + if isinstance(options, str): + options = [options] + return options + + +def is_experimental_mode_enabled(options: Options = None) -> bool: + r"""Returns :obj:`True` if the experimental mode is enabled. See + :class:`sharker.experimental_mode` for a list of (optional) + options. + """ + options = get_options(options) + return all([__experimental_flag__[option] for option in options]) + + +def set_experimental_mode_enabled(mode: bool, options: Options = None) -> None: + for option in get_options(options): + __experimental_flag__[option] = mode + + +class experimental_mode: + r"""Context-manager that enables the experimental mode to test new but + potentially unstable features. + + .. code-block:: python + + with sharker.experimental_mode(): + out = model(data.x, data.edge_index) + + Args: + options (str or list, optional): Currently there are no experimental + features. + """ + + def __init__(self, options: Options = None) -> None: + self.options = get_options(options) + self.previous_state = { + option: __experimental_flag__[option] for option in self.options + } + + def __enter__(self) -> None: + set_experimental_mode_enabled(True, self.options) + + def __exit__(self, *args: Any) -> None: + for option, value in self.previous_state.items(): + __experimental_flag__[option] = value + + +class set_experimental_mode: + r"""Context-manager that sets the experimental mode on or off. + + :class:`set_experimental_mode` will enable or disable the experimental mode + based on its argument :attr:`mode`. + It can be used as a context-manager or as a function. + + See :class:`experimental_mode` above for more details. + """ + + def __init__(self, mode: bool, options: Options = None) -> None: + self.options = get_options(options) + self.previous_state = { + option: __experimental_flag__[option] for option in self.options + } + set_experimental_mode_enabled(mode, self.options) + + def __enter__(self) -> None: + pass + + def __exit__(self, *args: Any) -> None: + for option, value in self.previous_state.items(): + __experimental_flag__[option] = value + + +def disable_dynamic_shapes(required_args: List[str]) -> Callable: + r"""A decorator that disables the usage of dynamic shapes for the given + arguments, i.e., it will raise an error in case :obj:`required_args` are + not passed and needs to be automatically inferred. + """ + + def decorator(func: Callable) -> Callable: + spec = inspect.getfullargspec(func) + + required_args_pos: Dict[str, int] = {} + for arg_name in required_args: + if arg_name not in spec.args: + raise ValueError( + f"The function '{func}' does not have a " f"'{arg_name}' argument" + ) + required_args_pos[arg_name] = spec.args.index(arg_name) + + num_args = len(spec.args) + num_default_args = 0 if spec.defaults is None else len(spec.defaults) + num_positional_args = num_args - num_default_args + + @functools.wraps(func) + def wrapper(*args: Any, **kwargs: Any) -> Any: + if not is_experimental_mode_enabled("disable_dynamic_shapes"): + return func(*args, **kwargs) + + for required_arg in required_args: + index = required_args_pos[required_arg] + + value: Optional[Any] = None + if index < len(args): + value = args[index] + elif required_arg in kwargs: + value = kwargs[required_arg] + elif num_default_args > 0: + assert spec.defaults is not None + value = spec.defaults[index - num_positional_args] + + if value is None: + raise ValueError( + f"Dynamic shapes disabled. Argument " + f"'{required_arg}' needs to be set" + ) + + return func(*args, **kwargs) + + return wrapper + + return decorator diff --git a/mindscience/sharker/home.py b/mindscience/sharker/home.py new file mode 100644 index 000000000..677e5553e --- /dev/null +++ b/mindscience/sharker/home.py @@ -0,0 +1,30 @@ +import os +import os.path as osp +from typing import Optional + +ENV_PYG_HOME = "SHARKER_HOME" +DEFAULT_CACHE_DIR = osp.join("~", ".cache", "sharker") + +_home_dir: Optional[str] = None + + +def get_home_dir() -> str: + r"""Get the cache directory used for storing all :pyg:`PyG`-related data. + + If :meth:`set_home_dir` is not called, the path is given by the environment + variable :obj:`$PYG_HOME` which defaults to :obj:`"~/.cache/pyg"`. + """ + if _home_dir is not None: + return _home_dir + + return osp.expanduser(os.getenv(ENV_PYG_HOME, DEFAULT_CACHE_DIR)) + + +def set_home_dir(path: str) -> None: + r"""Set the cache directory used for storing all :pyg:`PyG`-related data. + + Args: + path (str): The path to a local folder. + """ + global _home_dir + _home_dir = path diff --git a/mindscience/sharker/inspector.py b/mindscience/sharker/inspector.py new file mode 100644 index 000000000..0c07028cc --- /dev/null +++ b/mindscience/sharker/inspector.py @@ -0,0 +1,554 @@ +import inspect +import re +import sys +from typing import Any, Callable, Dict, List, NamedTuple, Optional, Type, Union + +from mindspore import Tensor, nn +import typing + + +class Parameter(NamedTuple): + name: str + type: Type + type_repr: str + default: Any + + +class Signature(NamedTuple): + param_dict: Dict[str, Parameter] + return_type: Type + return_type_repr: str + + +class Inspector: + r"""Inspects a given class and collects information about its instance + methods. + + Args: + cls (Type): The class to inspect. + """ + + def __init__(self, cls: Type): + self._cls = cls + self._signature_dict: Dict[str, Signature] = {} + self._source_dict: Dict[str, str] = {} + + def _get_modules(self, cls: Type) -> List[str]: + from .nn.conv import MessagePassing + + modules: List[str] = [] + for base_cls in cls.__bases__: + if base_cls not in {object, nn.Cell, MessagePassing}: + modules.extend(self._get_modules(base_cls)) + + modules.append(cls.__module__) + return modules + + @property + def _modules(self) -> List[str]: + return self._get_modules(self._cls) + + @property + def _globals(self) -> Dict[str, Any]: + out: Dict[str, Any] = {} + for module in self._modules: + out.update(sys.modules[module].__dict__) + return out + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self._cls.__name__})" + + def eval_type(self, value: Any) -> Type: + r"""Returns the type hint of a string.""" + return eval_type(value, self._globals) + + def type_repr(self, obj: Any) -> str: + r"""Returns the type hint representation of an object.""" + return type_repr(obj, self._globals) + + def implements(self, func_name: str) -> bool: + r"""Returns :obj:`True` in case the inspected class implements the + :obj:`func_name` method. + + Args: + func_name (str): The function name to check for existence. + """ + func = getattr(self._cls, func_name, None) + if not callable(func): + return False + return not getattr(func, "__isabstractmethod__", False) + + # Inspecting Method Signatures ############################################ + + def inspect_signature( + self, + func: Union[Callable, str], + exclude: Optional[List[Union[str, int]]] = None, + ) -> Signature: + r"""Inspects the function signature of :obj:`func` and returns a tuple + of parameter types and return type. + + Args: + func (callabel or str): The function. + exclude (list[int or str]): A list of parameters to exclude, either + given by their name or index. (default: :obj:`None`) + """ + if isinstance(func, str): + func = getattr(self._cls, func) + assert callable(func) + + if func.__name__ in self._signature_dict: + return self._signature_dict[func.__name__] + + signature = inspect.signature(func) + params = [p for p in signature.parameters.values() if p.name != "self"] + + param_dict: Dict[str, Parameter] = {} + for i, param in enumerate(params): + if exclude is not None and (i in exclude or param.name in exclude): + continue + + param_type = param.annotation + # Mimic TorchScript to auto-infer `Tensor` on non-present types: + param_type = Tensor if param_type is inspect._empty else param_type + + param_dict[param.name] = Parameter( + name=param.name, + type=self.eval_type(param_type), + type_repr=self.type_repr(param_type), + default=param.default, + ) + + return_type = signature.return_annotation + # Mimic TorchScript to auto-infer `Tensor` on non-present types: + return_type = Tensor if return_type is inspect._empty else return_type + + self._signature_dict[func.__name__] = Signature( + param_dict=param_dict, + return_type=self.eval_type(return_type), + return_type_repr=self.type_repr(return_type), + ) + + return self._signature_dict[func.__name__] + + def get_signature( + self, + func: Union[Callable, str], + exclude: Optional[List[str]] = None, + ) -> Signature: + r"""Returns the function signature of the inspected function + :obj:`func`. + + Args: + func (callabel or str): The function. + exclude (list[str], optional): The parameter names to exclude. + (default: :obj:`None`) + """ + func_name = func if isinstance(func, str) else func.__name__ + signature = self._signature_dict.get(func_name) + if signature is None: + raise IndexError( + f"Could not access signature for function " + f"'{func_name}'. Did you forget to inspect it?" + ) + + if exclude is None: + return signature + + param_dict = { + name: param + for name, param in signature.param_dict.items() + if name not in exclude + } + return Signature( + param_dict=param_dict, + return_type=signature.return_type, + return_type_repr=signature.return_type_repr, + ) + + def remove_signature( + self, + func: Union[Callable, str], + ) -> Optional[Signature]: + r"""Removes the inspected function signature :obj:`func`. + + Args: + func (callabel or str): The function. + """ + func_name = func if isinstance(func, str) else func.__name__ + return self._signature_dict.pop(func_name, None) + + def get_param_dict( + self, + func: Union[Callable, str], + exclude: Optional[List[str]] = None, + ) -> Dict[str, Parameter]: + r"""Returns the parameters of the inspected function :obj:`func`. + + Args: + func (str or callable): The function. + exclude (list[str], optional): The parameter names to exclude. + (default: :obj:`None`) + """ + return self.get_signature(func, exclude).param_dict + + def get_params( + self, + func: Union[Callable, str], + exclude: Optional[List[str]] = None, + ) -> List[Parameter]: + r"""Returns the parameters of the inspected function :obj:`func`. + + Args: + func (str or callable): The function. + exclude (list[str], optional): The parameter names to exclude. + (default: :obj:`None`) + """ + return list(self.get_param_dict(func, exclude).values()) + + def get_flat_param_dict( + self, + funcs: List[Union[Callable, str]], + exclude: Optional[List[str]] = None, + ) -> Dict[str, Parameter]: + r"""Returns the union of parameters of all inspected functions in + :obj:`funcs`. + + Args: + funcs (list[str or callable]): The functions. + exclude (list[str], optional): The parameter names to exclude. + (default: :obj:`None`) + """ + param_dict: Dict[str, Parameter] = {} + for func in funcs: + params = self.get_params(func, exclude) + for param in params: + expected = param_dict.get(param.name) + if expected is not None and param.type != expected.type: + raise ValueError( + f"Found inconsistent types for argument " + f"'{param.name}'. Expected type " + f"'{expected.type}' but found type " + f"'{param.type}'." + ) + + if expected is not None and param.default != expected.default: + if ( + param.default is not inspect._empty + and expected.default is not inspect._empty + ): + raise ValueError( + f"Found inconsistent defaults for " + f"argument '{param.name}'. Expected " + f"'{expected.default}' but found " + f"'{param.default}'." + ) + + default = expected.default + if default is inspect._empty: + default = param.default + + param_dict[param.name] = Parameter( + name=param.name, + type=param.type, + type_repr=param.type_repr, + default=default, + ) + + if expected is None: + param_dict[param.name] = param + + return param_dict + + def get_flat_params( + self, + funcs: List[Union[Callable, str]], + exclude: Optional[List[str]] = None, + ) -> List[Parameter]: + r"""Returns the union of parameters of all inspected functions in + :obj:`funcs`. + + Args: + funcs (list[str or callable]): The functions. + exclude (list[str], optional): The parameter names to exclude. + (default: :obj:`None`) + """ + return list(self.get_flat_param_dict(funcs, exclude).values()) + + def get_param_names( + self, + func: Union[Callable, str], + exclude: Optional[List[str]] = None, + ) -> List[str]: + r"""Returns the parameter names of the inspected function :obj:`func`. + + Args: + func (str or callable): The function. + exclude (list[str], optional): The parameter names to exclude. + (default: :obj:`None`) + """ + return list(self.get_param_dict(func, exclude).keys()) + + def get_flat_param_names( + self, + funcs: List[Union[Callable, str]], + exclude: Optional[List[str]] = None, + ) -> List[str]: + r"""Returns the union of parameter names of all inspected functions in + :obj:`funcs`. + + Args: + funcs (list[str or callable]): The functions. + exclude (list[str], optional): The parameter names to exclude. + (default: :obj:`None`) + """ + return list(self.get_flat_param_dict(funcs, exclude).keys()) + + def collect_param_data( + self, + func: Union[Callable, str], + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + r"""Collects the input data of the inspected function :obj:`func` + according to its function signature from a data blob. + + Args: + func (callabel or str): The function. + kwargs (dict[str, Any]): The data blob which may serve as inputs. + """ + out_dict: Dict[str, Any] = {} + for param in self.get_params(func): + if param.name not in kwargs: + if param.default is inspect._empty: + raise TypeError(f"Parameter '{param.name}' is required") + out_dict[param.name] = param.default + else: + out_dict[param.name] = kwargs[param.name] + return out_dict + + # Inspecting Method Bodies ################################################ + + def get_source(self, cls: Optional[Type] = None) -> str: + r"""Returns the source code of :obj:`cls`.""" + from .nn import MessagePassing + + cls = cls or self._cls + if cls.__name__ in self._source_dict: + return self._source_dict[cls.__name__] + if cls in {object, nn.Cell, MessagePassing}: + return "" + source = inspect.getsource(cls) + self._source_dict[cls.__name__] = source + return source + + def get_params_from_method_call( + self, + func: Union[Callable, str], + exclude: Optional[List[Union[int, str]]] = None, + ) -> Dict[str, Parameter]: + r"""Parses a method call of :obj:`func` and returns its keyword + arguments. + + .. note:: + The method is required to be called via keyword arguments in case + type annotations are not found. + + Args: + func (callabel or str): The function. + exclude (list[int or str]): A list of parameters to exclude, either + given by their name or index. (default: :obj:`None`) + """ + func_name = func if isinstance(func, str) else func.__name__ + param_dict: Dict[str, Parameter] = {} + + # Three ways to specify the parameters of an unknown function header: + # 1. Defined as class attributes in `{func_name}_type`. + # 2. Defined via type annotations in `# {func_name}_type: (...)`. + # 3. Defined via parsing of the function call. + + # (1) Find class attribute: + if hasattr(self._cls, f"{func_name}_type"): + type_dict = getattr(self._cls, f"{func_name}_type") + if not isinstance(type_dict, dict): + raise ValueError( + f"'{func_name}_type' is expected to be a " + f"dictionary (got '{type(type_dict)}')" + ) + + for name, param_type in type_dict.items(): + param_dict[name] = Parameter( + name=name, + type=self.eval_type(param_type), + type_repr=self.type_repr(param_type), + default=inspect._empty, + ) + return param_dict + + # (2) Find type annotation: + for cls in self._cls.__mro__: + source = self.get_source(cls) + match = find_parenthesis_content(source, f"{func_name}_type:") + if match is not None: + for arg in split(match, sep=","): + name_and_type_repr = re.split(r"\s*:\s*", arg) + if len(name_and_type_repr) != 2: + raise ValueError( + f"Could not parse argument '{arg}' " + f"of '{func_name}_type' annotation" + ) + + name, type_repr = name_and_type_repr + param_dict[name] = Parameter( + name=name, + type=self.eval_type(type_repr), + type_repr=type_repr, + default=inspect._empty, + ) + return param_dict + + # (3) Parse the function call: + for cls in self._cls.__mro__: + source = self.get_source(cls) + source = remove_comments(source) + match = find_parenthesis_content(source, f"self.{func_name}") + if match is not None: + for i, kwarg in enumerate(split(match, sep=",")): + if exclude is not None and i in exclude: + continue + + name_and_content = re.split(r"\s*=\s*", kwarg) + if len(name_and_content) != 2: + raise ValueError( + f"Could not parse keyword argument " + f"'{kwarg}' in 'self.{func_name}()'" + ) + + name, _ = name_and_content + + if exclude is not None and name in exclude: + continue + + param_dict[name] = Parameter( + name=name, + type=Tensor, + type_repr=self.type_repr(Tensor), + default=inspect._empty, + ) + return param_dict + + return {} # (4) No function call found: + + +def eval_type(value: Any, _globals: Dict[str, Any]) -> Type: + r"""Returns the type hint of a string.""" + if isinstance(value, str): + value = typing.ForwardRef(value) + return typing._eval_type(value, _globals, None) # type: ignore + + +def type_repr(obj: Any, _globals: Dict[str, Any]) -> str: + r"""Returns the type hint representation of an object.""" + + def _get_name(name: str, module: str) -> str: + return name if name in _globals else f"{module}.{name}" + + if isinstance(obj, str): + return obj + + if obj is type(None): + return "None" + + if obj is ...: + return "..." + + if obj.__module__ == "typing": # Special logic for `typing.*` types: + name = obj._name + if name is None: # In some cases, `_name` is not populated. + name = str(obj.__origin__).split(".")[-1] + + args = getattr(obj, "__args__", None) + if args is None or len(args) == 0: + return _get_name(name, obj.__module__) + if all(isinstance(arg, typing.TypeVar) for arg in args): + return _get_name(name, obj.__module__) + + # Convert `Union[*, None]` to `Optional[*]`. + # This is only necessary for old Python versions, e.g. 3.8. + # TODO Only convert to `Optional` if `Optional` is importable. + if ( + name == "Union" + and len(args) == 2 + and any([arg is type(None) for arg in args]) + ): + name = "Optional" + + if name == "Optional": # Remove `None` from `Optional` arguments: + args = [arg for arg in obj.__args__ if arg is not type(None)] + + args_repr = ", ".join([type_repr(arg, _globals) for arg in args]) + return f"{_get_name(name, obj.__module__)}[{args_repr}]" + + if obj.__module__ == "builtins": + return obj.__qualname__ + + return _get_name(obj.__qualname__, obj.__module__) + + +def find_parenthesis_content(source: str, prefix: str) -> Optional[str]: + r"""Returns the content of :obj:`{prefix}.*(...)` within :obj:`source`.""" + match = re.search(prefix, source) + if match is None: + return None + + offset = source[match.start():].find("(") + if offset < 0: + return None + + source = source[match.start()+offset:] + + depth = 0 + for end, char in enumerate(source): + if char == "(": + depth += 1 + if char == ")": + depth -= 1 + if depth == 0: + content = source[1:end] + # Properly handle line breaks and multiple white-spaces: + content = content.replace("\n", " ") + content = content.replace("#", " ") + content = re.sub(" +", " ", content) + content = content.strip() + return content + + return None + + +def split(content: str, sep: str) -> List[str]: + r"""Splits :obj:`content` based on :obj:`sep`. + :obj:`sep` inside parentheses or square brackets are ignored. + """ + assert len(sep) == 1 + outs: List[str] = [] + + start = depth = 0 + for end, char in enumerate(content): + if char == "[" or char == "(": + depth += 1 + elif char == "]" or char == ")": + depth -= 1 + elif char == sep and depth == 0: + outs.append(content[start:end].strip()) + start = end + 1 + if start != len(content): # Respect dangling `sep`: + outs.append(content[start:].strip()) + return outs + + +def remove_comments(content: str) -> str: + content = re.sub(r"\s*#.*", "", content) + content = re.sub(re.compile(r'r"""(.*?)"""', re.DOTALL), "", content) + content = re.sub(re.compile(r'"""(.*?)"""', re.DOTALL), "", content) + content = re.sub(re.compile(r"r'''(.*?)'''", re.DOTALL), "", content) + content = re.sub(re.compile(r"'''(.*?)'''", re.DOTALL), "", content) + return content diff --git a/mindscience/sharker/io/__init__.py b/mindscience/sharker/io/__init__.py new file mode 100644 index 000000000..63522907c --- /dev/null +++ b/mindscience/sharker/io/__init__.py @@ -0,0 +1,18 @@ +from .txt_array import parse_txt_array, read_txt_array +from .tu import read_tu_data +from .ply import read_ply +from .obj import read_obj +from .sdf import read_sdf, parse_sdf +from .npz import read_npz, parse_npz + +__all__ = [ + 'parse_txt_array', + 'read_txt_array', + 'read_tu_data', + 'read_ply', + 'read_obj', + 'read_sdf', + 'parse_sdf', + 'read_npz', + 'parse_npz', +] diff --git a/mindscience/sharker/io/fs.py b/mindscience/sharker/io/fs.py new file mode 100644 index 000000000..df03f485c --- /dev/null +++ b/mindscience/sharker/io/fs.py @@ -0,0 +1,207 @@ +import os +import sys +from typing import Any, Dict, List, Optional, Union +from uuid import uuid4 +import io +import fsspec +import pickle +import mindspore as ms +from ..home import get_home_dir + +DEFAULT_CACHE_PATH = "/tmp/pyg_simplecache" + + +def get_fs(path: str) -> fsspec.AbstractFileSystem: + r"""Get filesystem backend given a path URI to the resource. + + Here are some common example paths and dispatch result: + + * :obj:`"/home/file"` -> + :class:`fsspec.implementations.local.LocalFileSystem` + * :obj:`"memory://home/file"` -> + :class:`fsspec.implementations.memory.MemoryFileSystem` + * :obj:`"https://home/file"` -> + :class:`fsspec.implementations.http.HTTPFileSystem` + * :obj:`"gs://home/file"` -> :class:`gcsfs.GCSFileSystem` + * :obj:`"s3://home/file"` -> :class:`s3fs.S3FileSystem` + + A full list of supported backend implementations of :class:`fsspec` can be + found `here `_. + + The backend dispatch logic can be updated with custom backends following + `this tutorial `_. + + Args: + path (str): The URI to the filesystem location, *e.g.*, + :obj:`"gs://home/me/file"`, :obj:`"s3://..."`. + """ + return fsspec.core.url_to_fs(path)[0] + + +def normpath(path: str) -> str: + if isdisk(path): + return os.path.normpath(path) + return path + + +def exists(path: str) -> bool: + return get_fs(path).exists(path) + + +def makedirs(path: str, exist_ok: bool = True) -> None: + return get_fs(path).makedirs(path, exist_ok) + + +def isdir(path: str) -> bool: + return get_fs(path).isdir(path) + + +def isfile(path: str) -> bool: + return get_fs(path).isfile(path) + + +def isdisk(path: str) -> bool: + return "file" in get_fs(path).protocol + + +def islocal(path: str) -> bool: + return isdisk(path) or "memory" in get_fs(path).protocol + + +def ls( + path: str, + detail: bool = False, +) -> Union[List[str], List[Dict[str, Any]]]: + fs = get_fs(path) + outputs = fs.ls(path, detail=detail) + + if not isdisk(path): + if detail: + for output in outputs: + output["name"] = fs.unstrip_protocol(output["name"]) + else: + outputs = [fs.unstrip_protocol(output) for output in outputs] + + return outputs + + +def cp( + path1: str, + path2: str, + extract: bool = False, + log: bool = True, + use_cache: bool = True, + clear_cache: bool = True, +) -> None: + kwargs: Dict[str, Any] = {} + + is_path1_dir = isdir(path1) + is_path2_dir = isdir(path2) + + # Cache result if the protocol is not local: + cache_dir: Optional[str] = None + if not islocal(path1): + if log and "pytest" not in sys.modules: + print(f"Downloading {path1}", file=sys.stderr) + + if extract and use_cache: # Cache seems to confuse the gcs filesystem. + home_dir = get_home_dir() + cache_dir = os.path.join(home_dir, "simplecache", uuid4().hex) + kwargs.setdefault("simplecache", dict(cache_storage=cache_dir)) + path1 = f"simplecache::{path1}" + + # Handle automatic extraction: + multiple_files = False + if extract and path1.endswith(".tar.gz"): + kwargs.setdefault("tar", dict(compression="gzip")) + path1 = f"tar://**::{path1}" + multiple_files = True + elif extract and path1.endswith(".zip"): + path1 = f"zip://**::{path1}" + multiple_files = True + elif extract and path1.endswith(".gz"): + kwargs.setdefault("compression", "infer") + elif extract: + raise NotImplementedError( + f"Automatic extraction of '{path1}' not yet supported" + ) + + # If the source path points to a directory, we need to make sure to + # recursively copy all files within this directory. Additionally, if the + # destination folder does not yet exist, we inherit the basename from the + # source folder. + if is_path1_dir: + if exists(path2): + path2 = os.path.join(path2, os.path.basename(path1)) + path1 = os.path.join(path1, "**") + multiple_files = True + + # Perform the copy: + for open_file in fsspec.open_files(path1, **kwargs): + with open_file as f_from: + if not multiple_files: + if is_path2_dir: + basename = os.path.basename(path1) + if extract and path1.endswith(".gz"): + basename = ".".join(basename.split(".")[:-1]) + to_path = os.path.join(path2, basename) + else: + to_path = path2 + else: + # Open file has protocol stripped. + common_path = os.path.commonprefix( + [fsspec.core.strip_protocol(path1), open_file.path] + ) + to_path = os.path.join(path2, open_file.path[len(common_path):]) + with fsspec.open(to_path, "wb") as f_to: + while True: + chunk = f_from.read(10 * 1024 * 1024) + if not chunk: + break + f_to.write(chunk) + + if use_cache and clear_cache and cache_dir is not None: + try: + rm(cache_dir) + except Exception: # FIXME + # Windows test yield "PermissionError: The process cannot access + # the file because it is being used by another process". + # Users may also observe "OSError: Directory not empty". + # This is a quick workaround until we figure out the deeper issue. + pass + + +def rm(path: str, recursive: bool = True) -> None: + get_fs(path).rm(path, recursive) + + +def mv(path1: str, path2: str, recursive: bool = True) -> None: + fs1 = get_fs(path1) + fs2 = get_fs(path2) + assert fs1.protocol == fs2.protocol + fs1.mv(path1, path2, recursive) + + +def glob(path: str) -> List[str]: + fs = get_fs(path) + paths = fs.glob(path) + + if not isdisk(path): + paths = [fs.unstrip_protocol(path) for path in paths] + + return paths + + +def pickle_save(data: Any, path: str) -> None: + buffer = io.BytesIO() + pickle.dump(data, buffer) + with fsspec.open(path, 'wb') as f: + f.write(buffer.getvalue()) + + +def pickle_load(path: str) -> Any: + with fsspec.open(path, 'rb') as f: + data = pickle.load(f) + return data diff --git a/mindscience/sharker/io/npz.py b/mindscience/sharker/io/npz.py new file mode 100644 index 000000000..4dc1fdc9a --- /dev/null +++ b/mindscience/sharker/io/npz.py @@ -0,0 +1,36 @@ +from typing import Any, Dict + +import numpy as np +import scipy.sparse as sp +from mindspore import Tensor, ops, nn + +from ..data import Graph +from ..utils import remove_self_loops +from ..utils import to_undirected as to_undirected_fn + + +def read_npz(path: str, to_undirected: bool = True) -> Graph: + with np.load(path) as f: + return parse_npz(f, to_undirected=to_undirected) + + +def parse_npz(f: Dict[str, Any], to_undirected: bool = True) -> Graph: + x = sp.csr_matrix( + (f["attr_data"], f["attr_indices"], f["attr_indptr"]), f["attr_shape"] + ).todense() + x = Tensor.from_numpy(x).float() + x[x > 0] = 1 + + adj = sp.csr_matrix( + (f["adj_data"], f["adj_indices"], f["adj_indptr"]), f["adj_shape"] + ).tocoo() + row = Tensor.from_numpy(adj.row).long() + col = Tensor.from_numpy(adj.col).long() + edge_index = ops.stack([row, col], axis=0) + edge_index, _ = remove_self_loops(edge_index) + if to_undirected: + edge_index = to_undirected_fn(edge_index, num_nodes=x.shape[0]) + + y = Tensor.from_numpy(f["labels"]).long() + + return Graph(x=x, edge_index=edge_index, y=y) diff --git a/mindscience/sharker/io/obj.py b/mindscience/sharker/io/obj.py new file mode 100644 index 000000000..8aac2d5f8 --- /dev/null +++ b/mindscience/sharker/io/obj.py @@ -0,0 +1,41 @@ +from typing import Iterator, List, Optional, Tuple, Union + +from mindspore import Tensor, ops, nn +from ..data import Graph + + +def yield_file(in_file: str) -> Iterator[Tuple[str, List[Union[int, float]]]]: + + f = open(in_file) + buf = f.read() + f.close() + for b in buf.split("\n"): + if b.startswith("v "): + yield "v", [float(x) for x in b.split(" ")[1:]] + elif b.startswith("f "): + triangles = b.split(" ")[1:] + # -1 as .obj is base 1 but the Graph class expects base 0 indices + yield "f", [int(t.split("/")[0]) - 1 for t in triangles] + else: + yield "", [] + + +def read_obj(in_file: str) -> Optional[Graph]: + vertices = [] + faces = [] + + for k, v in yield_file(in_file): + if k == "v": + vertices.append(v) + elif k == "f": + faces.append(v) + + if not len(faces) or not len(vertices): + return None + + crd = Tensor(vertices).float() + face = Tensor(faces).long().t() + + data = Graph(crd=crd, face=face) + + return data diff --git a/mindscience/sharker/io/off.py b/mindscience/sharker/io/off.py new file mode 100644 index 000000000..54836adfc --- /dev/null +++ b/mindscience/sharker/io/off.py @@ -0,0 +1,42 @@ +from typing import List + +from mindspore import Tensor, ops, nn +from mindspore import ops +from ..data import Graph +from .txt_array import parse_txt_array + + +def parse_off(src: List[str]) -> Graph: + # Some files may contain a bug and do not have a carriage return after OFF. + if src[0] == "OFF": + src = src[1:] + else: + src[0] = src[0][3:] + + num_nodes, num_faces = [int(item) for item in src[0].split()[:2]] + + crd = parse_txt_array(src[1:1+num_nodes]) + + face = face_to_tri(src[1+num_nodes:1+num_nodes+num_faces]) + + data = Graph(crd=crd) + data.face = face + + return data + + +def face_to_tri(face: List[str]) -> Tensor: + face_index = [[int(x) for x in line.strip().split()] for line in face] + + triangle = Tensor([line[1:] for line in face_index if line[0] == 3]).long() + + rect = Tensor([line[1:] for line in face_index if line[0] == 4]).long() + + if rect.numel() > 0: + first, second = rect[:, [0, 1, 2]], rect[:, [0, 2, 3]] + if triangle.numel() > 0: + return ops.cat([triangle, first, second], axis=0).T + else: + return ops.cat([first, second], axis=0).T + return triangle.T + diff --git a/mindscience/sharker/io/planetoid.py b/mindscience/sharker/io/planetoid.py new file mode 100644 index 000000000..f18460f53 --- /dev/null +++ b/mindscience/sharker/io/planetoid.py @@ -0,0 +1,59 @@ +import os.path as osp +import warnings +from itertools import repeat +from typing import Dict, List, Optional + +import fsspec +import mindspore as ms +from mindspore import ops +from mindspore import Tensor, ops, nn + +from ..data import Graph +from .txt_array import read_txt_array +from ..utils import ( + coalesce, + index_to_mask, + remove_self_loops, +) + +import pickle + + + +def read_file(folder: str, prefix: str, name: str) -> Tensor: + path = osp.join(folder, f"ind.{prefix.lower()}.{name}") + + if name == "test.index": + return read_txt_array(path, dtype=ms.int64) + + with fsspec.open(path, "rb") as f: + warnings.filterwarnings("ignore", ".*`scipy.sparse.csr` name.*") + out = pickle.load(f, encoding="latin1") + + if name == "graph": + return out + + out = out.todense() if hasattr(out, "todense") else out + out = Tensor.from_numpy(out).float() + return out + + +def edge_index_from_dict( + graph_dict: Dict[int, List[int]], + num_nodes: Optional[int] = None, +) -> Tensor: + rows: List[int] = [] + cols: List[int] = [] + for key, value in graph_dict.items(): + rows += repeat(key, len(value)) + cols += value + row = Tensor(rows) + col = Tensor(cols) + edge_index = ops.stack([row, col], axis=0) + + # NOTE: There are some duplicated edges and self loops in the datasets. + # Other implementations do not remove them! + edge_index, _ = remove_self_loops(edge_index) + edge_index = coalesce(edge_index, num_nodes=num_nodes, sort_by_row=False) + + return edge_index diff --git a/mindscience/sharker/io/ply.py b/mindscience/sharker/io/ply.py new file mode 100644 index 000000000..c6beafd4b --- /dev/null +++ b/mindscience/sharker/io/ply.py @@ -0,0 +1,19 @@ +from mindspore import Tensor, ops, nn + +from ..data import Graph + +try: + import openmesh +except ImportError: + openmesh = None + + +def read_ply(path: str) -> Graph: + if openmesh is None: + raise ImportError("`read_ply` requires the `openmesh` package.") + + mesh = openmesh.read_trimesh(path) + crd = Tensor.from_numpy(mesh.points()).float() + face = Tensor.from_numpy(mesh.face_vertex_indices()) + face = face.t().long() + return Graph(crd=crd, face=face) diff --git a/mindscience/sharker/io/sdf.py b/mindscience/sharker/io/sdf.py new file mode 100644 index 000000000..f8638b7de --- /dev/null +++ b/mindscience/sharker/io/sdf.py @@ -0,0 +1,32 @@ +import mindspore as ms +from mindspore import ops, Tensor +from ..data import Graph +from .txt_array import parse_txt_array +from ..utils import coalesce + +elems = {"H": 0, "C": 1, "N": 2, "O": 3, "F": 4} + + +def parse_sdf(src: str) -> Graph: + lines = src.split("\n")[3:] + num_atoms, num_bonds = [int(item) for item in lines[0].split()[:2]] + + atom_block = lines[1:num_atoms+1] + crd = parse_txt_array(atom_block, end=3) + x = Tensor([elems[item.split()[3]] for item in atom_block]) + x = ops.one_hot(x, depth=len(elems)) + + bond_block = lines[1+num_atoms:1+num_atoms+num_bonds] + row, col = parse_txt_array(bond_block, end=2, dtype=ms.int64).t() - 1 + row, col = ops.cat([row, col], axis=0), ops.cat([col, row], axis=0) + edge_index = ops.stack([row, col], axis=0) + edge_attr = parse_txt_array(bond_block, start=2, end=3) - 1 + edge_attr = ops.cat([edge_attr, edge_attr], axis=0) + edge_index, edge_attr = coalesce(edge_index, edge_attr, num_atoms) + + return Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, crd=crd) + + +def read_sdf(path: str) -> Graph: + with open(path, "r") as f: + return parse_sdf(f.read()) diff --git a/mindscience/sharker/io/tu.py b/mindscience/sharker/io/tu.py new file mode 100644 index 000000000..b70a5b6b8 --- /dev/null +++ b/mindscience/sharker/io/tu.py @@ -0,0 +1,142 @@ +import os.path as osp +from typing import Dict, List, Optional, Tuple + +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, nn + +from ..data import Graph +from . import fs +from .txt_array import read_txt_array +from ..utils import coalesce, cumsum, remove_self_loops + +names = [ + "A", + "graph_indicator", + "node_labels", + "node_attributes" "edge_labels", + "edge_attributes", + "graph_labels", + "graph_attributes", +] + + +def read_tu_data( + folder: str, + prefix: str, +) -> Tuple[Graph, Dict[str, Tensor], Dict[str, int]]: + files = fs.glob(osp.join(folder, f"{prefix}_*.txt")) + names = [osp.basename(f)[len(prefix) + 1: -4] for f in files] + + edge_index = read_file(folder, prefix, "A", ms.int64).t() - 1 + batch = read_file(folder, prefix, "graph_indicator", ms.int64) - 1 + + node_attribute = ops.zeros((batch.shape[0], 0)) + if "node_attributes" in names: + node_attribute = read_file(folder, prefix, "node_attributes") + if node_attribute.dim() == 1: + node_attribute = node_attribute.unsqueeze(-1) + + node_label = ops.zeros((batch.shape[0], 0)) + if "node_labels" in names: + node_label = read_file(folder, prefix, "node_labels", ms.int64) + if node_label.dim() == 1: + node_label = node_label.unsqueeze(-1) + node_label = node_label - node_label.min(axis=0)[0] + node_labels = node_label.unbind(axis=-1) + node_labels = [ops.one_hot(x, x.max() + 1) for x in node_labels] + if len(node_labels) == 1: + node_label = node_labels[0] + else: + node_label = ops.cat(node_labels, axis=-1) + + edge_attribute = ops.zeros((edge_index.shape[1], 0)) + if "edge_attributes" in names: + edge_attribute = read_file(folder, prefix, "edge_attributes") + if edge_attribute.dim() == 1: + edge_attribute = edge_attribute.unsqueeze(-1) + + edge_label = ops.zeros((edge_index.shape[1], 0)) + if "edge_labels" in names: + edge_label = read_file(folder, prefix, "edge_labels", ms.int64) + if edge_label.dim() == 1: + edge_label = edge_label.unsqueeze(-1) + edge_label = edge_label - edge_label.min(axis=0)[0] + edge_labels = edge_label.unbind(axis=-1) + edge_labels = [ops.one_hot(e, e.max() + 1) for e in edge_labels] + if len(edge_labels) == 1: + edge_label = edge_labels[0] + else: + edge_label = ops.cat(edge_labels, axis=-1) + + x = cat([node_attribute, node_label]) + edge_attr = cat([edge_attribute, edge_label]) + + y = None + if "graph_attributes" in names: # Regression problem. + y = read_file(folder, prefix, "graph_attributes") + elif "graph_labels" in names: # Classification problem. + y = read_file(folder, prefix, "graph_labels", ms.int64) + _, y = y.unique(sorted=True, return_inverse=True) + + num_nodes = int(edge_index.max()) + 1 if x is None else x.shape[0] + edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) + edge_index, edge_attr = coalesce(edge_index, edge_attr, num_nodes) + + data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) + data, slices = split(data, batch) + + sizes = { + "num_node_attributes": node_attribute.shape[-1], + "num_node_labels": node_label.shape[-1], + "num_edge_attributes": edge_attribute.shape[-1], + "num_edge_labels": edge_label.shape[-1], + } + + return data, slices, sizes + + +def read_file( + folder: str, + prefix: str, + name: str, + dtype: Optional[ms.Type] = None, +) -> Tensor: + path = osp.join(folder, f"{prefix}_{name}.txt") + return read_txt_array(path, sep=",", dtype=dtype) + + +def cat(seq: List[Optional[Tensor]]) -> Optional[Tensor]: + values = [v for v in seq if v is not None] + values = [v for v in values if v.numel() > 0] + values = [v.unsqueeze(-1) if v.dim() == 1 else v for v in values] + return ops.cat(values, axis=-1) if len(values) > 0 else None + + +def split(data: Graph, batch: Tensor) -> Tuple[Graph, Dict[str, Tensor]]: + node_slice = cumsum(Tensor.from_numpy(np.bincount(batch))) + + assert data.edge_index is not None + row, _ = data.edge_index + edge_slice = cumsum(Tensor.from_numpy(np.bincount(batch[row]))) + + # Edge indices should start at zero for every graph. + data.edge_index -= node_slice[batch[row]].unsqueeze(0) + + slices = {"edge_index": edge_slice} + if data.x is not None: + slices["x"] = node_slice + else: + # Imitate `collate` functionality: + data._num_nodes = ops.bincount(batch).tolist() + data.num_nodes = batch.numel() + if data.edge_attr is not None: + slices["edge_attr"] = edge_slice + if data.y is not None: + assert isinstance(data.y, Tensor) + if data.y.shape[0] == batch.shape[0]: + slices["y"] = node_slice + else: + slices["y"] = ops.arange(0, int(batch[-1]) + 2, dtype=ms.int64) + + return data, slices diff --git a/mindscience/sharker/io/txt_array.py b/mindscience/sharker/io/txt_array.py new file mode 100644 index 000000000..19bdb7f25 --- /dev/null +++ b/mindscience/sharker/io/txt_array.py @@ -0,0 +1,32 @@ +from typing import List, Optional + +import fsspec +from mindspore import Tensor, ops, nn, Type + + +def parse_txt_array( + src: List[str], + sep: Optional[str] = None, + start: int = 0, + end: Optional[int] = None, + dtype: Optional[Type] = None, +) -> Tensor: + empty = ops.zeros(0, dtype=dtype) + to_number = float if empty.is_floating_point() else int + + return Tensor( + [[to_number(x) for x in line.split(sep)[start:end]] for line in src], + dtype=dtype, + ).squeeze() + + +def read_txt_array( + path: str, + sep: Optional[str] = None, + start: int = 0, + end: Optional[int] = None, + dtype: Optional[Type] = None, +) -> Tensor: + with fsspec.open(path, "r") as f: + src = f.read().split("\n")[:-1] + return parse_txt_array(src, sep, start, end, dtype) diff --git a/mindscience/sharker/loader/__init__.py b/mindscience/sharker/loader/__init__.py new file mode 100644 index 000000000..74935993b --- /dev/null +++ b/mindscience/sharker/loader/__init__.py @@ -0,0 +1,7 @@ +from .dataloader import Dataloader + +data_classes = [ + "Dataloader", +] + +__all__ = data_classes \ No newline at end of file diff --git a/mindscience/sharker/loader/dataloader.py b/mindscience/sharker/loader/dataloader.py new file mode 100644 index 000000000..12ed72178 --- /dev/null +++ b/mindscience/sharker/loader/dataloader.py @@ -0,0 +1,393 @@ +import json +import math +import time +import numpy as np +import mindspore as ms +import mindspore._c_dataengine as cde +import mindspore.dataset.engine.offload as offload +import weakref +from mindspore import log as logger +from mindspore import ops +from mindspore import log as logger +from mindspore.dataset.core.config import get_debug_mode +from mindspore.dataset.engine.datasets import Dataset, BatchDataset +from mindspore.dataset.engine.datasets_user_defined import GeneratorDataset +from mindspore.dataset.engine.iterators import Iterator, DummyIterator, _transform_md_to_output +from mindspore.dataset.engine.validators import check_dict_iterator +from typing import Any, List, Optional, Sequence, Union + +from ..data import Graph, Batch, HeteroGraph, TemporalGraph +from ..data.hypergraph import HyperGraph + +ITERATORS_LIST = list() + + +def _unset_iterator_cleanup(): + global _ITERATOR_CLEANUP + _ITERATOR_CLEANUP = False + +def collate_default(*kwargs): + data_list = [x['graph'] for x in kwargs[0]] + col1 = Batch.from_data_list(data_list, return_tensor=False) + return {'graph':col1} + +class Dataloader(GeneratorDataset): + """ + A Dataloader that generates data from Python by invoking Python data source each epoch. + + If the type in source contains Graph, the column name will be automatically recognized. + Otherwisw the column names and column types of generated dataset depend on Python data defined by users. + + Args: + source (Union[Callable, Iterable, Random Accessible]): + A generator callable object, an iterable Python object or a random accessible Python object. + Callable source is required to return a tuple of NumPy arrays as a row of the dataset on source().next(). + Iterable source is required to return a tuple of NumPy arrays as a row of the dataset on + iter(source).next(). + Random accessible source is required to return a tuple of NumPy arrays as a row of the dataset on + source[idx]. + column_names (Union[str, list[str]], optional): List of column names of the dataset. Default: ``None`` . + Users are required to provide either column_names or schema. + column_types (list[mindspore.dtype], optional): List of column data types of the dataset. Default: ``None`` . + If provided, sanity check will be performed on generator output. + schema (Union[str, Schema], optional): Data format policy, which specifies the data types and shapes of the data + column to be read. Both JSON file path and objects constructed by :class:`mindspore.dataset.Schema` are + acceptable. Default: ``None`` . + num_samples (int, optional): The number of samples to be included in the dataset. + Default: ``None`` , all images. + num_parallel_workers (int, optional): Number of worker threads/subprocesses used to + fetch the dataset in parallel. Default: ``1``. + shuffle (bool, optional): Whether or not to perform shuffle on the dataset. Random accessible input is required. + Default: ``None`` , expected order behavior shown in the table below. + sampler (Union[Sampler, Iterable], optional): Object used to choose samples from the dataset. Random accessible + input is required. Default: ``None`` , expected order behavior shown in the table below. + num_shards (int, optional): Number of shards that the dataset will be divided into. Default: ``None`` . + Random accessible input is required. When this argument is specified, `num_samples` reflects the maximum + sample number of per shard. + shard_id (int, optional): The shard ID within `num_shards` . Default: ``None`` . + This argument must be specified only when `num_shards` is also specified. + Random accessible input is required. + python_multiprocessing (bool, optional): Parallelize Python operations with multiple worker process. This + option could be beneficial if the Python operation is computational heavy. Default: ``True``. + max_rowsize(int, optional): Maximum size of row in MB that is used for shared memory + allocation to copy data between processes, the total occupied shared memory will increase as + ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1, + shared memory will be dynamically allocated with the actual size of data. This is only used if + ``python_multiprocessing`` is set to True. Default: 16. + """ + + def __init__(self, source, column_names=['graph'], column_types=None, schema=None, num_samples=None, + num_parallel_workers=1, shuffle=None, sampler=None, num_shards=None, shard_id=None, + python_multiprocessing=False, max_rowsize=-1): + self.dataset_is_graph = False + if not callable(source) and isinstance(source[0], Graph): + self.dataset_is_graph = True + self._output_numpy = False + if self.dataset_is_graph: + self._output_numpy = True + source = [{'graph': data} for data in source] + self.iterator = [] + + super().__init__(source, column_names=column_names, column_types=column_types, schema=schema, + num_samples=num_samples, + num_parallel_workers=num_parallel_workers, shuffle=shuffle, sampler=sampler, + num_shards=num_shards, shard_id=shard_id, + python_multiprocessing=python_multiprocessing, max_rowsize=max_rowsize) + + def __iter__(self): + """Create an iterator over the dataset.""" + if not self.iterator: + self.iterator = self.create_dict_iterator(num_epochs=-1, output_numpy=self._output_numpy) + return self.iterator + + def batch(self, batch_size, drop_remainder=False, num_parallel_workers=None, + per_batch_map=collate_default, input_columns=["graph"], output_columns=["graph"], **kwargs): + """ + Combine batch_size number of consecutive rows into batch which apply per_batch_map or collate_fn to the samples first. + + If the type in source contains Graph, all the elements within that column do not need to have the same shape. + Otherwise, for any column, all the elements within that column must have the same shape. + + Refer to the following figure for the execution process: + + .. image:: batch_en.png + + Note: + The order of using repeat and batch reflects the number of batches and (er_batch_map or collate_fn). + It is recommended that the repeat operation applied after the batch operation finished. + + Args: + batch_size (Union[int, Callable]): The number of rows each batch is created with. An + int or callable object which takes exactly 1 parameter, BatchInfo. + drop_remainder (bool, optional): Determines whether or not to drop the last block + whose data row number is less than batch size. Default: ``False`` . If ``True`` , + and if there are less than `batch_size` rows available to make the last batch, + then those rows will be dropped and not propagated to the child node. + num_parallel_workers (int, optional): Number of workers(threads) to process the dataset in parallel. + Default: ``None`` . + **kwargs: + + - per_batch_map (Callable[[List[numpy.ndarray], ..., List[numpy.ndarray], BatchInfo], \ + (List[numpy.ndarray], ..., List[numpy.ndarray])], optional): Per batch map callable. + Default: ``None``. + A callable which takes (List[numpy.ndarray], ..., List[numpy.ndarray], BatchInfo) as input parameters. + Each list[numpy.ndarray] represents a batch of numpy.ndarray on a given column. The number of lists + should match with the number of entries in input_columns. The last parameter of the callable should + always be a BatchInfo object. Per_batch_map should return + (list[numpy.ndarray], list[numpy.ndarray], ...). The length of each list in output should be the same + as the input. output_columns is required if the number of output lists is different from input. + + - input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of + the list should match with signature of `per_batch_map` callable. Default: ``None`` . + + - output_columns (Union[str, list[str]], optional): List of names assigned to the columns + outputted by the last operation. This parameter is mandatory if len(input_columns) != + len(output_columns). The size of this list must match the number of output + columns of the last operation. Default: ``None`` , output columns will have the same + name as the input columns, i.e., the columns will be replaced. + + - python_multiprocessing (bool, optional): Parallelize Python function `per_batch_map` with + multi-processing or multi-threading mode, ``True`` means multi-processing, + ``False`` means multi-threading If `per_batch_map` is a I/O bound task, use + multi-threading mode. If `per_batch_map` is a CPU bound task, it is recommended to use + multi-processing mode. Default: ``False`` , use python multi-threading mode. + + - max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory + allocation to copy data between processes, the total occupied shared memory will increase as + ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set + to -1, shared memory will be dynamically allocated with the actual size of data. This is only used if + ``python_multiprocessing`` is set to True. If it is an int value, it represents + ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory. + If it is a list, the first element represents the ``input_columns`` use this value as the unit to + create shared memory, and the second element represents ``output_columns`` use this value as the unit + to create shared memory. Default: 16. + + - collate_fn (Callable, optional): merges a list of samples to form a mini-batch of Tensor(s). + Used when using batched loading from a map-style dataset. + Returns: + CustomBatchDataset, a new dataset with the above operation applied. + """ + if (not self.dataset_is_graph): + per_batch_map = None + input_columns = None + output_columns = None + return CustomBatchDataset(self, batch_size, drop_remainder, self.dataset_is_graph, num_parallel_workers, + per_batch_map, input_columns, output_columns, **kwargs) + + def create_dict_iterator(self, num_epochs=-1, output_numpy=False, do_copy=True): + """ + Create an CustomDictIterator over the dataset that yields samples of type dict, + while the key is the column name and the value is the data. + + Args: + num_epochs (int, optional): The number of epochs to iterate over the entire dataset. + Default: ``-1`` , the dataset can be iterated indefinitely. + output_numpy (bool, optional): Whether to keep the output data as NumPy ndarray, or + convert it to Tensor. Default: ``False`` . + do_copy (bool, optional): Whether to copy the data when converting output to Tensor, + or reuse the buffer for better performance, only works when `output_numpy` is ``False`` . + Default: ``True`` . + + Returns: + Iterator, a dataset iterator that yields samples of type dict. + """ + if output_numpy is None: + output_numpy = False + if Dataset._noop_mode(): + return DummyIterator(self, 'dict', output_numpy) + return CustomDictIterator(self, num_epochs, output_numpy, do_copy) + + +class CustomDictIterator(Iterator): + """ + The derived class of Iterator with dict type. + """ + + def __init__(self, dataset, num_epochs=-1, output_numpy=False, do_copy=True): + start_init = time.time() + super().__init__(dataset, num_epochs=num_epochs, output_numpy=output_numpy, do_copy=do_copy) + self.dataset_type = None + self.dataset_is_graph = dataset.dataset_is_graph + + self.__ori_dataset = dataset + + self.ir_tree, self.dataset = dataset.create_ir_tree() + + self._runtime_context = cde.PythonRuntimeContext() + self._runtime_context.Init() + if dataset.get_init_step() == 0: + init_step = 0 + dataset_size = -1 + else: + init_step = dataset.get_init_step() + dataset_size = dataset.get_dataset_size() + if get_debug_mode(): + consumer = cde.PythonPullBasedIteratorConsumer(num_epochs) + consumer.Init(self.ir_tree) + else: + consumer = cde.PythonIteratorConsumer(num_epochs) + consumer.Init(self.ir_tree, init_step, dataset_size) + self._runtime_context.AssignConsumer(consumer) + self._iterator = self._runtime_context.GetConsumer() + self._output_numpy = output_numpy + self._do_copy = do_copy + self.__index = 0 + self.last_step_end = False + self.offload_model = None + json_offload = json.loads(consumer.GetOffload()) + # See if GetOffload identified any operations set to be offloaded. + if json_offload is not None: + offload.check_concat_zip_dataset(self.__ori_dataset) + self.offload_model = offload.GetOffloadModel(consumer, self.__ori_dataset.get_col_names()) + + ITERATORS_LIST.append(weakref.ref(self)) + _unset_iterator_cleanup() + print('init', time.time()-start_init) + + def __next__(self): + """ + This is the implementation of the __next__() method for an iterator object in Python. + If the dataset type is 'Graph', it will call self._get_next() depends on the batch size, it then applies a collater + function if available to combine the data into a batch and returns the batch or data list. + If there is no collater function, it returns the data list. If there is no more data to iterate over, + it raises a StopIteration exception. + """ + if not self._runtime_context: + logger.warning("Iterator does not have a running C++ pipeline." + + "It might because Iterator stop() had been called, or C++ pipeline crashed silently.") + raise RuntimeError("Iterator does not have a running C++ pipeline.") + # Note offload is applied inside _get_next() if applicable since get_next converts to output format + start_next = time.time() + data = self._get_next() + if not data: + if self.__index == 0: + logger.warning("No records available.") + if self.__ori_dataset.dataset_size is None: + self.__ori_dataset.dataset_size = self.__index + self.__index = 0 + raise StopIteration + self.__index += 1 + if self.dataset_is_graph == True: + data = data.tensor() + return data + + def _get_next(self): + """ + Returns the next record in the dataset as dictionary, and convert the dictionary back to Graph class if required. + + Returns: + Dict, the next record in the dataset. + """ + try: + if self.offload_model is None: + if self.dataset_is_graph: #### Graph version + data_dict = {} + start_time = time.time() + for t in self._iterator.GetNextAsList(): + data_dict = _transform_md_to_output(t,self._output_numpy,self._do_copy) + if data_dict: + return data_dict['graph'] + else: + return None + else: + return [_transform_md_to_output(t,self._output_numpy,self._do_copy) for k, t in self._iterator.GetNextAsMap().items()] + + data = [self._transform_md_to_tensor(t) for t in self._iterator.GetNextAsList()] + if data: + data = offload.apply_offload_iterators(data, self.offload_model) + # Create output dictionary after offload + out_data = {} + for i, col in enumerate(self.get_col_names()): + out_data[col] = self._transform_tensor_to_output(data[i]) + data = out_data + return data + except RuntimeError as err: + err_info = str(err) + if err_info.find("Out of memory") >= 0 or err_info.find("MemoryError") >= 0: + logger.critical("Memory error occurred, process will exit.") + os.kill(os.getpid(), signal.SIGKILL) + raise err + + +class CustomBatchDataset(BatchDataset): + """ + The result of applying Batch operation to the input dataset. + + Args: + input_dataset (Dataset): Input Dataset to be batched. + batch_size (Union[int, function]): The number of rows each batch is created with. An + int or callable which takes exactly 1 parameter, BatchInfo. + drop_remainder (bool, optional): Determines whether or not to drop the last + possibly incomplete batch. Default: ``False``. If True, and if there are less + than batch_size rows available to make the last batch, then those rows will + be dropped and not propagated to the child node. + num_parallel_workers (int, optional): Number of workers to process the dataset in parallel. Default: ``None``. + per_batch_map (callable, optional): Per batch map callable. A callable which takes + (list[Tensor], list[Tensor], ..., BatchInfo) as input parameters. Each list[Tensor] represents a batch of + Tensors on a given column. The number of lists should match with number of entries in input_columns. The + last parameter of the callable must always be a BatchInfo object. + input_columns (Union[str, list[str]], optional): List of names of the input columns. The size of the list must + match with signature of per_batch_map callable. + output_columns (Union[str, list[str]], optional): List of names assigned to the columns outputted by + the last operation. This parameter is mandatory if len(input_columns) != + len(output_columns). The size of this list must match the number of output + columns of the last operation. Default: ``None``, output columns will have the same + name as the input columns, i.e., the columns will be replaced. + max_rowsize(Union[int, list[int]], optional): Maximum size of row in MB that is used for shared memory + allocation to copy data between processes, the total occupied shared memory will increase as + ``num_parallel_workers`` and :func:`mindspore.dataset.config.set_prefetch_size` increase. If set to -1, + shared memory will be dynamically allocated with the actual size of data. This is only used if + ``python_multiprocessing`` is set to True. If it is an int value, it represents + ``input_columns`` and ``output_columns`` use this value as the unit to create shared memory. + If it is a list, the first element represents the ``input_columns`` use this value as the unit to + create shared memory, and the second element represents ``output_columns`` use this value as the unit + to create shared memory. Default: 16. + collate_fn (Callable, optional): merges a list of samples to form a mini-batch of Tensor(s). + Used when using batched loading from a map-style dataset. + """ + + def __init__(self, input_dataset, batch_size, drop_remainder=False, dataset_is_graph=True, num_parallel_workers=None, + per_batch_map=None, input_columns=None, output_columns=None, python_multiprocessing=False, + max_rowsize=16, collate_fn=None): + + self.dataset_is_graph = dataset_is_graph + self._output_numpy = False + if self.dataset_is_graph: + self._output_numpy = True + self.iterator = [] + super().__init__(input_dataset, batch_size, drop_remainder=drop_remainder, + num_parallel_workers=num_parallel_workers, per_batch_map=per_batch_map, + input_columns=input_columns, output_columns=output_columns, + python_multiprocessing=python_multiprocessing, max_rowsize=max_rowsize) + + def __len__(self): + return self.get_dataset_size() + + def __iter__(self): + """Create an iterator over the dataset.""" + if not self.iterator: + self.iterator = self.create_dict_iterator(num_epochs=-1, output_numpy=self._output_numpy) + return self.iterator + + def create_dict_iterator(self, num_epochs=-1, output_numpy=False, do_copy=True): + """ + Create an CustomDictIterator over the dataset that yields samples of type dict, + while the key is the column name and the value is the data. + + Args: + num_epochs (int, optional): The number of epochs to iterate over the entire dataset. + Default: ``-1`` , the dataset can be iterated indefinitely. + output_numpy (bool, optional): Whether to keep the output data as NumPy ndarray, or + convert it to Tensor. Default: ``False`` . + do_copy (bool, optional): Whether to copy the data when converting output to Tensor, + or reuse the buffer for better performance, only works when `output_numpy` is ``False`` . + Default: ``True`` . + + Returns: + Iterator, a dataset iterator that yields samples of type dict. + """ + if output_numpy is None: + output_numpy = False + if Dataset._noop_mode(): + return DummyIterator(self, 'dict', output_numpy) + return CustomDictIterator(self, num_epochs, output_numpy, do_copy) \ No newline at end of file diff --git a/mindscience/sharker/nn/__init__.py b/mindscience/sharker/nn/__init__.py new file mode 100644 index 000000000..4bc0af480 --- /dev/null +++ b/mindscience/sharker/nn/__init__.py @@ -0,0 +1,15 @@ +from .reshape import Reshape +from .encoding import PositionalEncoding, TemporalEncoding + + +from .aggr import * # noqa +from .conv import * # noqa +from .norm import * # noqa +from .dense import * # noqa +from .models import * # noqa + +__all__ = [ + 'Reshape', + 'PositionalEncoding', + 'TemporalEncoding', +] diff --git a/mindscience/sharker/nn/aggr/__init__.py b/mindscience/sharker/nn/aggr/__init__.py new file mode 100644 index 000000000..0b3425849 --- /dev/null +++ b/mindscience/sharker/nn/aggr/__init__.py @@ -0,0 +1,56 @@ +from .base import Aggregation +from .multi import MultiAggregation +from .basic import ( + MeanAggregation, + SumAggregation, + MaxAggregation, + MinAggregation, + MulAggregation, + VarAggregation, + StdAggregation, + SoftmaxAggregation, + PowerMeanAggregation, +) +from .quantile import MedianAggregation, QuantileAggregation +from .lstm import LSTMAggregation +from .gru import GRUAggregation +from .set2set import Set2Set +from .scaler import DegreeScalerAggregation +from .equilibrium import EquilibriumAggregation +from .sort import SortAggregation +from .gmt import GraphMultisetTransformer +from .attention import AttentionalAggregation +from .mlp import MLPAggregation +from .deep_sets import DeepSetsAggregation +from .set_transformer import SetTransformerAggregation +from .lcm import LCMAggregation +from .variance_preserving import VariancePreservingAggregation + +__all__ = classes = [ + 'Aggregation', + 'MultiAggregation', + 'SumAggregation', + 'MeanAggregation', + 'MaxAggregation', + 'MinAggregation', + 'MulAggregation', + 'VarAggregation', + 'StdAggregation', + 'SoftmaxAggregation', + 'PowerMeanAggregation', + 'MedianAggregation', + 'QuantileAggregation', + 'LSTMAggregation', + 'GRUAggregation', + 'Set2Set', + 'DegreeScalerAggregation', + 'SortAggregation', + 'GraphMultisetTransformer', + 'AttentionalAggregation', + 'EquilibriumAggregation', + 'MLPAggregation', + 'DeepSetsAggregation', + 'SetTransformerAggregation', + 'LCMAggregation', + 'VariancePreservingAggregation', +] diff --git a/mindscience/sharker/nn/aggr/attention.py b/mindscience/sharker/nn/aggr/attention.py new file mode 100644 index 000000000..1df93c82b --- /dev/null +++ b/mindscience/sharker/nn/aggr/attention.py @@ -0,0 +1,92 @@ +from typing import Optional + +from mindspore import Tensor, ops, nn +from mindspore import nn +from ..aggr import Aggregation +from ..inits import reset +from ..models.mlp import MLP +from ...utils import softmax + + +class AttentionalAggregation(Aggregation): + r"""The soft attention aggregation layer from the `"Graph Matching Networks + for Learning the Similarity of Graph Structured Objects" + `_ paper. + + .. math:: + \mathbf{r}_i = \sum_{n=1}^{N_i} \mathrm{softmax} \left( + h_{\mathrm{gate}} ( \mathbf{x}_n ) \right) \cdot + h_{\mathbf{\Theta}} ( \mathbf{x}_n ), + + where :math:`h_{\mathrm{gate}} \colon \mathbb{R}^F \to + \mathbb{R}` and :math:`h_{\mathbf{\Theta}}` denote neural networks, *i.e.* + MLPs. + + Args: + gate_nn (nn.Cell): A neural network :math:`h_{\mathrm{gate}}` + that computes attention scores by mapping node features :obj:`x` of + shape :obj:`[-1, in_channels]` to shape :obj:`[-1, 1]` (for + node-level gating) or :obj:`[1, out_channels]` (for feature-level + gating), *e.g.*, defined by :class:`nn.Sequential`. + nn (nn.Cell, optional): A neural network + :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` of + shape :obj:`[-1, in_channels]` to shape :obj:`[-1, out_channels]` + before combining them with the attention scores, *e.g.*, defined by + :class:`nn.Sequential`. (default: :obj:`None`) + """ + + def __init__( + self, + gate_nn: nn.Cell, + nn: Optional[nn.Cell] = None, + ): + super().__init__() + + self.gate_nn = self.gate_mlp = None + if isinstance(gate_nn, MLP): + self.gate_mlp = gate_nn + else: + self.gate_nn = gate_nn + + self.nn = self.mlp = None + if isinstance(nn, MLP): + self.mlp = nn + else: + self.nn = nn + + def reset_parameters(self): + reset(self.gate_nn) + reset(self.gate_mlp) + reset(self.nn) + reset(self.mlp) + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + self.assert_two_dimensional_input(x, dim) + + if self.gate_mlp is not None: + gate = self.gate_mlp(x, batch=index, batch_size=dim_size) + else: + gate = self.gate_nn(x) + + if self.mlp is not None: + x = self.mlp(x, batch=index, batch_size=dim_size) + elif self.nn is not None: + x = self.nn(x) + + gate = softmax(gate, index, ptr, dim_size, dim) + return self.reduce(gate * x, index, ptr, dim_size, dim) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"gate_nn={self.gate_mlp or self.gate_nn}, " + f"nn={self.mlp or self.nn})" + ) diff --git a/mindscience/sharker/nn/aggr/base.py b/mindscience/sharker/nn/aggr/base.py new file mode 100644 index 000000000..7243a1922 --- /dev/null +++ b/mindscience/sharker/nn/aggr/base.py @@ -0,0 +1,204 @@ +from typing import Optional, Tuple + +import mindspore as ms +from mindspore import Tensor, ops, nn, mint + +from ...experimental import disable_dynamic_shapes +from ...utils import scatter, to_dense_batch, segment + + +class Aggregation(nn.Cell): + r"""An abstract base class for implementing custom aggregations. + + Aggregation can be either performed via an :obj:`index` vector, which + defines the mapping from input elements to their location in the output + + Notably, :obj:`index` does not have to be sorted (for most aggregation + operators): + + .. code-block:: + + # Feature matrix holding 10 elements with 64 features each: + x = ops.randn(10, 64) + + # Assign each element to one of three sets: + index = Tensor([0, 0, 1, 0, 2, 0, 2, 1, 0, 2]) + + output = aggr(x, index) # Output shape: [3, 64] + + Alternatively, aggregation can be achieved via a "compressed" index vector + called :obj:`ptr`. Here, elements within the same set need to be grouped + together in the input, and :obj:`ptr` defines their boundaries: + + .. code-block:: + + # Feature matrix holding 10 elements with 64 features each: + x = ops.randn(10, 64) + + # Define the boundary indices for three sets: + ptr = Tensor([0, 4, 7, 10]) + + output = aggr(x, ptr=ptr) # Output shape: [3, 64] + + Note that at least one of :obj:`index` or :obj:`ptr` must be defined. + + Shapes: + - **input:** + node features :math:`(*, |\mathcal{V}|, F_{in})` or edge features + :math:`(*, |\mathcal{E}|, F_{in})`, + index vector :math:`(|\mathcal{V}|)` or :math:`(|\mathcal{E}|)`, + - **output:** graph features :math:`(*, |\mathcal{G}|, F_{out})` or + node features :math:`(*, |\mathcal{V}|, F_{out})` + """ + + def __init__(self) -> None: + super().__init__() + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + max_num_elements: Optional[int] = None, + ) -> Tensor: + r"""Forward pass. + + Args: + x (Tensor): The source tensor. + index (Tensor, optional): The indices of elements for + applying the aggregation. + One of :obj:`index` or :obj:`ptr` must be defined. + (default: :obj:`None`) + ptr (Tensor, optional): If given, computes the aggregation + based on sorted inputs in CSR representation. + One of :obj:`index` or :obj:`ptr` must be defined. + (default: :obj:`None`) + dim_size (int, optional): The size of the output tensor at + dimension :obj:`dim` after aggregation. (default: :obj:`None`) + dim (int, optional): The dimension in which to aggregate. + (default: :obj:`-2`) + max_num_elements: (int, optional): The maximum number of elements + within a single aggregation group. (default: :obj:`None`) + """ + pass + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + pass + + @disable_dynamic_shapes(required_args=["dim_size"]) + def __call__( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + **kwargs, + ) -> Tensor: + + if dim >= x.dim() or dim < -x.dim(): + raise ValueError( + f"Encountered invalid dimension '{dim}' of " + f"source tensor with {x.dim()} dimensions" + ) + + if index is None and ptr is None: + index = mint.zeros(x.shape[dim], dtype=ms.int64) + + if ptr is not None: + if dim_size is None: + dim_size = ptr.numel() - 1 + elif dim_size != ptr.numel() - 1: + raise ValueError( + f"Encountered invalid 'dim_size' (got " + f"'{dim_size}' but expected " + f"'{ptr.numel() - 1}')" + ) + + if index is not None and dim_size is None: + dim_size = int(ops.amax(index)) + 1 if ops.numel(index) > 0 else 0 + elif index is not None and ops.numel(index) > 0 and dim_size <= int(ops.amax(index)): + raise ValueError( + f"Encountered invalid 'dim_size' (got '{dim_size}' but expected >= '{int(ops.amax(index)) + 1}')" + ) + return super().__call__( + x, index=index, ptr=ptr, dim_size=dim_size, dim=dim, **kwargs + ) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + # Assertions ############################################################## + + def assert_index_present(self, index: Optional[Tensor]): + if index is None: + raise NotImplementedError("Aggregation requires 'index' to be specified") + + def assert_sorted_index(self, index: Optional[Tensor]): + if index is not None and not mint.all(index[:-1] <= index[1:]): + raise ValueError( + "Can not perform aggregation since the 'index' " + "tensor is not sorted. Specifically, if you use " + "this aggregation as part of 'MessagePassing`, " + "ensure that 'edge_index' is sorted by " + "destination nodes, e.g., by calling " + "`data.sort(sort_by_row=False)`" + ) + + def assert_two_dimensional_input(self, x: Tensor, dim: int): + if x.dim() != 2: + raise ValueError( + f"Aggregation requires two-dimensional inputs " f"(got '{x.dim()}')" + ) + + if dim not in [-2, 0]: + raise ValueError( + f"Aggregation needs to perform aggregation in " + f"first dimension (got '{dim}')" + ) + + # Helper methods ########################################################## + + def reduce( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + reduce: str = "sum", + ) -> Tensor: + + if ptr is not None: + return segment(x, ptr, dim, dim_size=dim_size, reduce=reduce) + + if index is None: + raise RuntimeError("Aggregation requires 'index' to be specified") + + return scatter(x, index, dim, dim_size, reduce) + + def to_dense_batch( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + fill_value: float = 0.0, + max_num_elements: Optional[int] = None, + ) -> Tuple[Tensor, Tensor]: + + self.assert_index_present(index) + self.assert_sorted_index(index) + self.assert_two_dimensional_input(x, dim) + + return to_dense_batch( + x, + index, + batch_size=dim_size, + fill_value=fill_value, + max_num_nodes=max_num_elements, + ) diff --git a/mindscience/sharker/nn/aggr/basic.py b/mindscience/sharker/nn/aggr/basic.py new file mode 100644 index 000000000..0c81f815a --- /dev/null +++ b/mindscience/sharker/nn/aggr/basic.py @@ -0,0 +1,324 @@ +import math +from typing import Optional + +from mindspore import Tensor, ops, nn, mint +from mindspore import Parameter +from .base import Aggregation +from ...utils import softmax + + +class SumAggregation(Aggregation): + r"""An aggregation operator that sums up features across a set of elements. + + .. math:: + \mathrm{sum}(\mathcal{X}) = \sum_{\mathbf{x}_i \in \mathcal{X}} + \mathbf{x}_i. + """ + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + return self.reduce(x, index, ptr, dim_size, dim, reduce="sum") + + +class MeanAggregation(Aggregation): + r"""An aggregation operator that averages features across a set of + elements. + + .. math:: + \mathrm{mean}(\mathcal{X}) = \frac{1}{|\mathcal{X}|} + \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i. + """ + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + return self.reduce(x, index, ptr, dim_size, dim, reduce="mean") + + +class MaxAggregation(Aggregation): + r"""An aggregation operator that takes the feature-wise maximum across a + set of elements. + + .. math:: + \mathrm{max}(\mathcal{X}) = \max_{\mathbf{x}_i \in \mathcal{X}} + \mathbf{x}_i. + """ + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + return self.reduce(x, index, ptr, dim_size, dim, reduce="max") + + +class MinAggregation(Aggregation): + r"""An aggregation operator that takes the feature-wise minimum across a + set of elements. + + .. math:: + \mathrm{min}(\mathcal{X}) = \min_{\mathbf{x}_i \in \mathcal{X}} + \mathbf{x}_i. + """ + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + return self.reduce(x, index, ptr, dim_size, dim, reduce="min") + + +class MulAggregation(Aggregation): + r"""An aggregation operator that multiples features across a set of + elements. + + .. math:: + \mathrm{mul}(\mathcal{X}) = \prod_{\mathbf{x}_i \in \mathcal{X}} + \mathbf{x}_i. + """ + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + self.assert_index_present(index) + return self.reduce(x, index, None, dim_size, dim, reduce="mul") + + +class VarAggregation(Aggregation): + r"""An aggregation operator that takes the feature-wise variance across a + set of elements. + + .. math:: + \mathrm{var}(\mathcal{X}) = \mathrm{mean}(\{ \mathbf{x}_i^2 : x \in + \mathcal{X} \}) - \mathrm{mean}(\mathcal{X})^2. + + Args: + semi_grad (bool, optional): If set to :obj:`True`, will turn off + gradient calculation during :math:`E[X^2]` computation. Therefore, + only semi-gradients are used during backpropagation. Useful for + saving memory and accelerating backward computation. + (default: :obj:`False`) + """ + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + mean = self.reduce(x, index, ptr, dim_size, dim, reduce="mean") + mean2 = self.reduce(x * x, index, ptr, dim_size, dim, "mean") + return mean2 - mean * mean + + +class StdAggregation(Aggregation): + r"""An aggregation operator that takes the feature-wise standard deviation + across a set of elements. + + .. math:: + \mathrm{std}(\mathcal{X}) = \sqrt{\mathrm{var}(\mathcal{X})}. + + Args: + semi_grad (bool, optional): If set to :obj:`True`, will turn off + gradient calculation during :math:`E[X^2]` computation. Therefore, + only semi-gradients are used during backpropagation. Useful for + saving memory and accelerating backward computation. + (default: :obj:`False`) + """ + + def __init__(self): + super().__init__() + self.var_aggr = VarAggregation() + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + var = self.var_aggr(x, index, ptr, dim_size, dim) + # Allow "undefined" gradient at `sqrt(0.0)`: + out = var.clamp(min=1e-5).sqrt() + out[out <= math.sqrt(1e-5)] = 0.0 + return out + + +class SoftmaxAggregation(Aggregation): + r"""The softmax aggregation operator based on a temperature term, as + described in the `"DeeperGCN: All You Need to Train Deeper GCNs" + `_ paper. + + .. math:: + \mathrm{softmax}(\mathcal{X}|t) = \sum_{\mathbf{x}_i\in\mathcal{X}} + \frac{\exp(t\cdot\mathbf{x}_i)}{\sum_{\mathbf{x}_j\in\mathcal{X}} + \exp(t\cdot\mathbf{x}_j)}\cdot\mathbf{x}_{i}, + + where :math:`t` controls the softness of the softmax when aggregating over + a set of features :math:`\mathcal{X}`. + + Args: + t (float, optional): Initial inverse temperature for softmax + aggregation. (default: :obj:`1.0`) + learn (bool, optional): If set to :obj:`True`, will learn the value + :obj:`t` for softmax aggregation dynamically. + (default: :obj:`False`) + semi_grad (bool, optional): If set to :obj:`True`, will turn off + gradient calculation during softmax computation. Therefore, only + semi-gradients are used during backpropagation. Useful for saving + memory and accelerating backward computation when :obj:`t` is not + learnable. (default: :obj:`False`) + channels (int, optional): Number of channels to learn from :math:`t`. + If set to a value greater than :obj:`1`, :math:`t` will be learned + per input feature channel. This requires compatible shapes for the + input to the construct calculation. (default: :obj:`1`) + """ + + def __init__( + self, + t: float = 1.0, + learn: bool = False, + channels: int = 1, + ): + super().__init__() + + if not learn and channels != 1: + raise ValueError( + f"Cannot set 'channels' greater than '1' in case " + f"'{self.__class__.__name__}' is not trainable" + ) + + self._init_t = t + self.learn = learn + self.channels = channels + + self.t = Parameter(mint.zeros(channels)) if learn else t + self.reset_parameters() + + def reset_parameters(self): + if isinstance(self.t, Parameter): + self.t[:] = self._init_t + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + t = self.t + if self.channels != 1: + self.assert_two_dimensional_input(x, dim) + assert isinstance(t, Tensor) + t = t.view(-1, self.channels) + + alpha = x + if not isinstance(t, (int, float)) or t != 1: + alpha = x * t + + alpha = softmax(alpha, index, ptr, dim_size, dim) + return self.reduce(x * alpha, index, ptr, dim_size, dim, reduce="sum") + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(learn={self.learn})" + + +class PowerMeanAggregation(Aggregation): + r"""The powermean aggregation operator based on a power term, as + described in the `"DeeperGCN: All You Need to Train Deeper GCNs" + `_ paper. + + .. math:: + \mathrm{powermean}(\mathcal{X}|p) = \left(\frac{1}{|\mathcal{X}|} + \sum_{\mathbf{x}_i\in\mathcal{X}}\mathbf{x}_i^{p}\right)^{1/p}, + + where :math:`p` controls the power of the powermean when aggregating over + a set of features :math:`\mathcal{X}`. + + Args: + p (float, optional): Initial power for powermean aggregation. + (default: :obj:`1.0`) + learn (bool, optional): If set to :obj:`True`, will learn the value + :obj:`p` for powermean aggregation dynamically. + (default: :obj:`False`) + channels (int, optional): Number of channels to learn from :math:`p`. + If set to a value greater than :obj:`1`, :math:`p` will be learned + per input feature channel. This requires compatible shapes for the + input to the construct calculation. (default: :obj:`1`) + """ + + def __init__(self, p: float = 1.0, learn: bool = False, channels: int = 1): + super().__init__() + + if not learn and channels != 1: + raise ValueError( + f"Cannot set 'channels' greater than '1' in case " + f"'{self.__class__.__name__}' is not trainable" + ) + + self._init_p = p + self.learn = learn + self.channels = channels + + self.p = Parameter(mint.zeros(channels)) if learn else p + self.reset_parameters() + + def reset_parameters(self): + if isinstance(self.p, Parameter): + self.p[:] = self._init_p + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + p = self.p + if self.channels != 1: + assert isinstance(p, Tensor) + self.assert_two_dimensional_input(x, dim) + p = p.view(-1, self.channels) + + if not isinstance(p, (int, float)) or p != 1: + x = x.clamp(min=0, max=100).pow(p) + + out = self.reduce(x, index, ptr, dim_size, dim, reduce="mean") + + if not isinstance(p, (int, float)) or p != 1: + out = out.clamp(min=0, max=100).pow(1.0 / p) + + return out + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(learn={self.learn})" diff --git a/mindscience/sharker/nn/aggr/deep_sets.py b/mindscience/sharker/nn/aggr/deep_sets.py new file mode 100644 index 000000000..c317b5e14 --- /dev/null +++ b/mindscience/sharker/nn/aggr/deep_sets.py @@ -0,0 +1,81 @@ +from typing import Optional + +from mindspore import Tensor, ops, nn +from mindspore import nn +from .base import Aggregation +from ...nn.inits import reset +from ...nn.models.mlp import MLP + + +class DeepSetsAggregation(Aggregation): + r"""Performs Deep Sets aggregation in which the elements to aggregate are + first transformed by a Multi-Layer Perceptron (MLP) + :math:`\phi_{\mathbf{\Theta}}`, summed, and then transformed by another MLP + :math:`\rho_{\mathbf{\Theta}}`, as suggested in the `"Graph Neural Networks + with Adaptive Readouts" `_ paper. + + Args: + local_nn (nn.Cell, optional): The neural network + :math:`\phi_{\mathbf{\Theta}}`, *e.g.*, defined by + :class:`nn.Sequential` or + :class:`sharker.nn.models.MLP`. (default: :obj:`None`) + global_nn (nn.Cell, optional): The neural network + :math:`\rho_{\mathbf{\Theta}}`, *e.g.*, defined by + :class:`nn.Sequential` or + :class:`sharker.nn.models.MLP`. (default: :obj:`None`) + """ + + def __init__( + self, + local_nn: Optional[nn.Cell] = None, + global_nn: Optional[nn.Cell] = None, + ): + super().__init__() + + self.local_nn = self.local_mlp = None + if isinstance(local_nn, MLP): + self.local_mlp = local_nn + else: + self.local_nn = local_nn + + self.global_nn = self.global_mlp = None + if isinstance(global_nn, MLP): + self.global_mlp = global_nn + else: + self.global_nn = global_nn + + def reset_parameters(self): + reset(self.local_nn) + reset(self.local_mlp) + reset(self.global_nn) + reset(self.global_mlp) + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + if self.local_mlp is not None: + x = self.local_mlp(x, batch=index, batch_size=dim_size) + if self.local_nn is not None: + x = self.local_nn(x) + + x = self.reduce(x, index, ptr, dim_size, dim, reduce="sum") + + if self.global_mlp is not None: + x = self.global_mlp(x, batch=index, batch_size=dim_size) + elif self.global_nn is not None: + x = self.global_nn(x) + + return x + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"local_nn={self.local_mlp or self.local_nn}, " + f"global_nn={self.global_mlp or self.global_nn})" + ) diff --git a/mindscience/sharker/nn/aggr/equilibrium.py b/mindscience/sharker/nn/aggr/equilibrium.py new file mode 100644 index 000000000..21579ba02 --- /dev/null +++ b/mindscience/sharker/nn/aggr/equilibrium.py @@ -0,0 +1,206 @@ +from typing import Callable, List, Optional, Tuple + +import mindspore as ms +from mindspore import nn +from mindspore import Tensor, ops, nn, mint +from .base import Aggregation +from ..inits import reset +from ...utils import scatter + + +class ResNetPotential(nn.Cell): + def __init__(self, in_channels: int, out_channels: int, num_layers: List[int]): + + super().__init__() + sizes = [in_channels] + num_layers + [out_channels] + self.layers = nn.CellList( + [ + nn.SequentialCell( + nn.Dense(in_size, out_size), + nn.LayerNorm([out_size]), + nn.Tanh(), + ) + for in_size, out_size in zip(sizes[:-2], sizes[1:-1]) + ] + ) + self.layers.append(nn.Dense(sizes[-2], sizes[-1])) + + self.res_trans = nn.CellList( + [ + nn.Dense(in_channels, layer_size) + for layer_size in num_layers + [out_channels] + ] + ) + + def construct( + self, + x: Tensor, + y: Tensor, + index: Optional[Tensor], + dim_size: Optional[int] = None, + ) -> Tensor: + if index is None: + inp = mint.cat([x, y.broadcast_to((x, -1))], dim=1) + else: + inp = mint.cat([x, y[index]], dim=1) + + h = inp + for layer, res in zip(self.layers, self.res_trans): + h = layer(h) + h = res(inp) + h + + if index is None: + return h.mean() + + if dim_size is None: + dim_size = int(index.max().item() + 1) + + return scatter(h, index, 0, dim_size, reduce="mean").sum() + + +class MomentumOptimizer(nn.Cell): + r"""Provides an inner loop optimizer for the implicitly defined output + layer. It is based on an unrolled Nesterov momentum algorithm. + + Args: + learning_rate (flaot): learning rate for optimizer. + momentum (float): momentum for optimizer. + learnable (bool): If :obj:`True` then the :obj:`learning_rate` and + :obj:`momentum` will be learnable parameters. If False they + are fixed. (default: :obj:`True`) + """ + + def __init__( + self, learning_rate: float = 0.1, momentum: float = 0.9, learnable: bool = True + ): + super().__init__() + + self._initial_lr = learning_rate + self._initial_mom = momentum + self._lr = ms.Parameter(Tensor([learning_rate]), requires_grad=learnable) + self._mom = ms.Parameter(Tensor([momentum]), requires_grad=learnable) + + def reset_parameters(self): + self._lr[:] = self._initial_lr + self._mom[:] = self._initial_mom + + @property + def learning_rate(self): + return mint.nn.functional.softplus(self._lr) + + @property + def momentum(self): + return mint.sigmoid(self._mom) + + def construct( + self, + x: Tensor, + y: Tensor, + index: Optional[Tensor], + dim_size: Optional[int], + func: Callable[[Tensor, Tensor, Optional[Tensor]], Tensor], + iterations: int = 5, + ) -> Tuple[Tensor, float]: + + momentum_buffer = mint.zeros_like(y) + for _ in range(iterations): + grad_fn = ops.grad(func, grad_position=1) + grad = grad_fn(x, y, index, dim_size) + delta = self.learning_rate * grad + momentum_buffer = self.momentum * momentum_buffer - delta + y = y + momentum_buffer + return y + + +class EquilibriumAggregation(Aggregation): + r"""The equilibrium aggregation layer from the `"Equilibrium Aggregation: + Encoding Sets via Optimization" `_ paper. + + The output of this layer :math:`\mathbf{y}` is defined implicitly via a + potential function :math:`F(\mathbf{x}, \mathbf{y})`, a regularization term + :math:`R(\mathbf{y})`, and the condition + + .. math:: + \mathbf{y} = \min_\mathbf{y} R(\mathbf{y}) + \sum_{i} + F(\mathbf{x}_i, \mathbf{y}). + + The given implementation uses a ResNet-like model for the potential + function and a simple :math:`L_2` norm :math:`R(\mathbf{y}) = + \textrm{softplus}(\lambda) \cdot {\| \mathbf{y} \|}^2_2` for the + regularizer with learnable weight :math:`\lambda`. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + num_layers (List[int): List of hidden channels in the potential + function. + grad_iter (int): The number of steps to take in the internal gradient + descent. (default: :obj:`5`) + lamb (float): The initial regularization constant. + (default: :obj:`0.1`) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: List[int], + grad_iter: int = 5, + lamb: float = 0.1, + ): + super().__init__() + + self.potential = ResNetPotential(in_channels + out_channels, 1, num_layers) + self.optimizer = MomentumOptimizer() + self.initial_lamb = Tensor(lamb) + self.lamb = ms.Parameter(self.initial_lamb) + self.grad_iter = grad_iter + self.output_dim = out_channels + self.reset_parameters() + + def reset_parameters(self): + self.lamb = ms.Parameter(self.initial_lamb) + reset(self.optimizer) + reset(self.potential) + + def init_output(self, dim_size: int) -> Tensor: + return mint.zeros([dim_size, self.output_dim]).float() + + def reg(self, y: Tensor) -> Tensor: + return mint.nn.functional.softplus(self.lamb) * y.square().sum(-1).mean() + + def energy( + self, + x: Tensor, + y: Tensor, + index: Optional[Tensor], + dim_size: Optional[int] = None, + ): + return self.potential(x, y, index, dim_size) + self.reg(y) + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + self.assert_index_present(index) + + dim_size = int(index.max()) + 1 if dim_size is None else dim_size + + y = self.optimizer( + x, + self.init_output(dim_size), + index, + dim_size, + self.energy, + iterations=self.grad_iter, + ) + + return y + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/mindscience/sharker/nn/aggr/fused.py b/mindscience/sharker/nn/aggr/fused.py new file mode 100644 index 000000000..67f36815c --- /dev/null +++ b/mindscience/sharker/nn/aggr/fused.py @@ -0,0 +1,345 @@ +import math +from typing import Dict, List, Optional, Tuple, Union + +from mindspore import Tensor, ops, mint + +from .base import Aggregation +from .basic import ( + MaxAggregation, + MeanAggregation, + MinAggregation, + MulAggregation, + StdAggregation, + SumAggregation, + VarAggregation, +) +from ..resolver import aggregation_resolver +from ...utils import scatter + + +class FusedAggregation(Aggregation): + r"""Helper class to fuse computation of multiple aggregations together. + + Used internally in :class:`~sharker.nn.aggr.MultiAggregation` to + speed-up computation. + Currently, the following optimizations are performed: + + * :class:`MeanAggregation` will share the output with + :class:`SumAggregation` in case it is present as well. + + * :class:`VarAggregation` will share the output with either + :class:`MeanAggregation` or :class:`SumAggregation` in case one of them + is present as well. + + * :class:`StdAggregation` will share the output with either + :class:`VarAggregation`, :class:`MeanAggregation` or + :class:`SumAggregation` in case one of them is present as well. + + In addition, temporary values such as the count per group index are shared + as well. + + Benchmarking results on MindSpore (summed over 1000 runs): + + +------------------------------+---------+---------+ + | Aggregators | Vanilla | Fusion | + +==============================+=========+=========+ + | :obj:`[sum, mean]` | 0.3325s | 0.1996s | + +------------------------------+---------+---------+ + | :obj:`[sum, mean, min, max]` | 0.7139s | 0.5037s | + +------------------------------+---------+---------+ + | :obj:`[sum, mean, var]` | 0.6849s | 0.3871s | + +------------------------------+---------+---------+ + | :obj:`[sum, mean, var, std]` | 1.0955s | 0.3973s | + +------------------------------+---------+---------+ + + Args: + aggrs (list): The list of aggregation schemes to use. + """ + + # We can fuse all aggregations together that rely on `scatter` directives. + FUSABLE_AGGRS = { + SumAggregation, + MeanAggregation, + MinAggregation, + MaxAggregation, + MulAggregation, + VarAggregation, + StdAggregation, + } + + # All aggregations that rely on computing the degree of indices. + DEGREE_BASED_AGGRS = { + MeanAggregation, + VarAggregation, + StdAggregation, + } + + # Map aggregations to `reduce` options in `scatter` directives. + REDUCE = { + "SumAggregation": "sum", + "MeanAggregation": "sum", + "MinAggregation": "min", + "MaxAggregation": "max", + "MulAggregation": "mul", + "VarAggregation": "pow_sum", + "StdAggregation": "pow_sum", + } + + def __init__(self, aggrs: List[Union[Aggregation, str]]): + super().__init__() + + if not isinstance(aggrs, (list, tuple)): + raise ValueError( + f"'aggrs' of '{self.__class__.__name__}' should " + f"be a list or tuple (got '{type(aggrs)}')." + ) + + if len(aggrs) == 0: + raise ValueError( + f"'aggrs' of '{self.__class__.__name__}' should " f"not be empty." + ) + + aggrs = [aggregation_resolver(aggr) for aggr in aggrs] + aggr_classes = [aggr.__class__ for aggr in aggrs] + self.aggr_names = [cls.__name__ for cls in aggr_classes] + self.aggr_index: Dict[str, int] = { + name: i for i, name in enumerate(self.aggr_names) + } + + for cls in aggr_classes: + if cls not in self.FUSABLE_AGGRS: + raise ValueError( + f"Received aggregation '{cls.__name__}' in " + f"'{self.__class__.__name__}' which is not " + f"fusable" + ) + + # Check whether we need to compute degree information: + self.need_degree = False + for cls in aggr_classes: + if cls in self.DEGREE_BASED_AGGRS: + self.need_degree = True + + # Determine which reduction to use for each aggregator: + # An entry of `None` means that this operator re-uses intermediate + # outputs from other aggregators. + reduce_ops: List[Optional[str]] = [] + # Determine which `(Aggregator, index)` to use as intermediate output: + lookup_ops: List[Optional[Tuple[str, int]]] = [] + + for name in self.aggr_names: + if name == "MeanAggregation": + # Directly use output of `SumAggregation`: + if "SumAggregation" in self.aggr_index: + reduce_ops.append(None) + lookup_ops.append( + ( + "SumAggregation", + self.aggr_index["SumAggregation"], + ) + ) + else: + reduce_ops.append(self.REDUCE[name]) + lookup_ops.append(None) + + elif name == "VarAggregation": + if "MeanAggregation" in self.aggr_index: + reduce_ops.append(self.REDUCE[name]) + lookup_ops.append( + ( + "MeanAggregation", + self.aggr_index["MeanAggregation"], + ) + ) + elif "SumAggregation" in self.aggr_index: + reduce_ops.append(self.REDUCE[name]) + lookup_ops.append( + ( + "SumAggregation", + self.aggr_index["SumAggregation"], + ) + ) + else: + reduce_ops.append(self.REDUCE[name]) + lookup_ops.append(None) + + elif name == "StdAggregation": + # Directly use output of `VarAggregation`: + if "VarAggregation" in self.aggr_index: + reduce_ops.append(None) + lookup_ops.append( + ( + "VarAggregation", + self.aggr_index["VarAggregation"], + ) + ) + elif "MeanAggregation" in self.aggr_index: + reduce_ops.append(self.REDUCE[name]) + lookup_ops.append( + ( + "MeanAggregation", + self.aggr_index["MeanAggregation"], + ) + ) + elif "SumAggregation" in self.aggr_index: + reduce_ops.append(self.REDUCE[name]) + lookup_ops.append( + ( + "SumAggregation", + self.aggr_index["SumAggregation"], + ) + ) + else: + reduce_ops.append(self.REDUCE[name]) + lookup_ops.append(None) + + else: + reduce_ops.append(self.REDUCE[name]) + lookup_ops.append(None) + + self.reduce_ops: List[Optional[str]] = reduce_ops + self.lookup_ops: List[Optional[Tuple[str, int]]] = lookup_ops + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> List[Tensor]: + + self.assert_index_present(index) + self.assert_two_dimensional_input(x, dim) + + assert index is not None + + if dim_size is None: + if ptr is not None: + dim_size = ptr.numel() - 1 + else: + dim_size = int(index.max()) + 1 if index.numel() > 0 else 0 + + count: Optional[Tensor] = None + if self.need_degree: + count = scatter(mint.ones(x.shape, dtype=x.dtype), index, 0, dim_size, 'add').clamp(min=1) + + ####################################################################### + + outs: List[Optional[Tensor]] = [] + + # Iterate over all reduction ops to compute first results: + for i, reduce in enumerate(self.reduce_ops): + if reduce is None: + outs.append(None) + continue + assert isinstance(reduce, str) + + if reduce == "pow_sum": + out = scatter(x * x, index, 0, dim_size, reduce="sum") + else: + out = scatter(x, index, 0, dim_size, reduce=reduce) + + outs.append(out) + + ####################################################################### + + # Compute `MeanAggregation` first to be able to re-use it: + i = self.aggr_index.get("MeanAggregation") + if i is not None: + assert count is not None + + if self.lookup_ops[i] is None: + sum_ = outs[i] + else: + lookup_op = self.lookup_ops[i] + assert lookup_op is not None + tmp_aggr, j = lookup_op + assert tmp_aggr == "SumAggregation" + + sum_ = outs[j] + + assert sum_ is not None + outs[i] = sum_ / count + + # Compute `VarAggregation` second to be able to re-use it: + if "VarAggregation" in self.aggr_index: + i = self.aggr_index["VarAggregation"] + + assert count is not None + + if self.lookup_ops[i] is None: + sum_ = scatter(x, index, 0, dim_size, reduce="sum") + mean = sum_ / count + else: + lookup_op = self.lookup_ops[i] + assert lookup_op is not None + tmp_aggr, j = lookup_op + + if tmp_aggr == "SumAggregation": + sum_ = outs[j] + assert sum_ is not None + mean = sum_ / count + elif tmp_aggr == "MeanAggregation": + mean = outs[j] + else: + raise NotImplementedError + + pow_sum = outs[i] + + assert pow_sum is not None + assert mean is not None + outs[i] = (pow_sum / count) - (mean * mean) + + # Compute `StdAggregation` last: + if "StdAggregation" in self.aggr_index: + i = self.aggr_index["StdAggregation"] + + var: Optional[Tensor] = None + pow_sum: Optional[Tensor] = None + mean: Optional[Tensor] = None + + if self.lookup_ops[i] is None: + pow_sum = outs[i] + sum_ = scatter(x, index, 0, dim_size, reduce="sum") + assert count is not None + mean = sum_ / count + else: + lookup_op = self.lookup_ops[i] + assert lookup_op is not None + tmp_aggr, j = lookup_op + + if tmp_aggr == "VarAggregation": + var = outs[j] + elif tmp_aggr == "SumAggregation": + pow_sum = outs[i] + sum_ = outs[j] + assert sum_ is not None + assert count is not None + mean = sum_ / count + elif tmp_aggr == "MeanAggregation": + pow_sum = outs[i] + mean = outs[j] + else: + raise NotImplementedError + + if var is None: + assert pow_sum is not None + assert count is not None + assert mean is not None + var = (pow_sum / count) - (mean * mean) + + # Allow "undefined" gradient at `sqrt(0.0)`: + out = var.clamp(min=1e-5).sqrt() + out[out <= math.sqrt(1e-5)] = 0.0 + + outs[i] = out + + ####################################################################### + + vals: List[Tensor] = [] + for out in outs: + assert out is not None + vals.append(out) + + return vals diff --git a/mindscience/sharker/nn/aggr/gmt.py b/mindscience/sharker/nn/aggr/gmt.py new file mode 100644 index 000000000..9bff1d708 --- /dev/null +++ b/mindscience/sharker/nn/aggr/gmt.py @@ -0,0 +1,102 @@ +from typing import Optional + +from mindspore import Tensor, ops, nn + +from ...experimental import disable_dynamic_shapes +from .base import Aggregation +from .utils import ( + PoolingByMultiheadAttention, + SetAttentionBlock, +) + + +class GraphMultisetTransformer(Aggregation): + r"""The Graph Multiset Transformer pooling operator from the + `"Accurate Learning of Graph Representations + with Graph Multiset Pooling" `_ paper. + + The :class:`GraphMultisetTransformer` aggregates elements into + :math:`k` representative elements via attention-based pooling, computes the + interaction among them via :obj:`num_encoder_blocks` self-attention blocks, + and finally pools the representative elements via attention-based pooling + into a single cluster. + + .. note:: + + :class:`GraphMultisetTransformer` requires sorted indices :obj:`index` + as input. Specifically, if you use this aggregation as part of + :class:`~sharker.nn.conv.MessagePassing`, ensure that + :obj:`edge_index` is sorted by destination nodes, either by manually + sorting edge indices via :meth:`~sharker.utils.sort_edge_index` + or by calling :meth:`sharker.data.Graph.sort`. + + Args: + channels (int): Size of each input sample. + k (int): Number of :math:`k` representative nodes after pooling. + num_encoder_blocks (int, optional): Number of Set Attention Blocks + (SABs) between the two pooling blocks. (default: :obj:`1`) + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + norm (str, optional): If set to :obj:`True`, will apply layer + normalization. (default: :obj:`False`) + dropout (float, optional): Dropout probability of attention weights. + (default: :obj:`0`) + """ + + def __init__( + self, + channels: int, + k: int, + num_encoder_blocks: int = 1, + heads: int = 1, + layer_norm: bool = False, + dropout: float = 0.0, + ): + super().__init__() + + self.channels = channels + self.k = k + self.heads = heads + self.layer_norm = layer_norm + self.dropout = dropout + + self.pma1 = PoolingByMultiheadAttention(channels, k, heads, layer_norm, dropout) + self.encoders = nn.CellList( + [ + SetAttentionBlock(channels, heads, layer_norm, dropout) + for _ in range(num_encoder_blocks) + ] + ) + self.pma2 = PoolingByMultiheadAttention(channels, 1, heads, layer_norm, dropout) + + @disable_dynamic_shapes(required_args=["dim_size", "max_num_elements"]) + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + max_num_elements: Optional[int] = None, + ) -> Tensor: + + x, mask = self.to_dense_batch( + x, index, ptr, dim_size, dim, max_num_elements=max_num_elements + ) + + x = self.pma1(x, mask) + + for encoder in self.encoders: + x = encoder(x) + + x = self.pma2(x) + + return x.squeeze(1) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.channels}, " + f"k={self.k}, heads={self.heads}, " + f"layer_norm={self.layer_norm}, " + f"dropout={self.dropout})" + ) diff --git a/mindscience/sharker/nn/aggr/gru.py b/mindscience/sharker/nn/aggr/gru.py new file mode 100644 index 000000000..b5f4f176d --- /dev/null +++ b/mindscience/sharker/nn/aggr/gru.py @@ -0,0 +1,58 @@ +from typing import Optional + +from mindspore import Tensor, ops, nn +from mindspore.nn import GRU + +from ...experimental import disable_dynamic_shapes +from .base import Aggregation + + +class GRUAggregation(Aggregation): + r"""Performs GRU aggregation in which the elements to aggregate are + interpreted as a sequence, as described in the `"Graph Neural Networks + with Adaptive Readouts" `_ paper. + + .. note:: + + :class:`GRUAggregation` requires sorted indices :obj:`index` as input. + Specifically, if you use this aggregation as part of + :class:`~sharker.nn.conv.MessagePassing`, ensure that + :obj:`edge_index` is sorted by destination nodes, either by manually + sorting edge indices via :meth:`~sharker.utils.sort_edge_index` + or by calling :meth:`sharker.data.Graph.sort`. + + .. warning:: + + :class:`GRUAggregation` is not a permutation-invariant operator. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + **kwargs (optional): Additional arguments of :class:`nn.GRU`. + """ + + def __init__(self, in_channels: int, out_channels: int, **kwargs): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.gru = GRU(in_channels, out_channels, batch_first=True, **kwargs) + + @disable_dynamic_shapes(required_args=["dim_size", "max_num_elements"]) + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + max_num_elements: Optional[int] = None, + ) -> Tensor: + + x, _ = self.to_dense_batch( + x, index, ptr, dim_size, dim, max_num_elements=max_num_elements + ) + + return self.gru(x)[0][:, -1] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.in_channels}, " f"{self.out_channels})" diff --git a/mindscience/sharker/nn/aggr/lcm.py b/mindscience/sharker/nn/aggr/lcm.py new file mode 100644 index 000000000..de9a408df --- /dev/null +++ b/mindscience/sharker/nn/aggr/lcm.py @@ -0,0 +1,124 @@ +from math import ceil, log2 +from typing import Optional + +from mindspore import Tensor, ops, mint +from mindspore.nn import GRUCell, Dense + +from ...experimental import disable_dynamic_shapes +from .base import Aggregation + + +class LCMAggregation(Aggregation): + r"""The Learnable Commutative Monoid aggregation from the + `"Learnable Commutative Monoids for Graph Neural Networks" + `_ paper, in which the elements are + aggregated using a binary tree reduction with + :math:`\mathcal{O}(\log |\mathcal{V}|)` depth. + + .. note:: + + :class:`LCMAggregation` requires sorted indices :obj:`index` as input. + Specifically, if you use this aggregation as part of + :class:`~sharker.nn.conv.MessagePassing`, ensure that + :obj:`edge_index` is sorted by destination nodes, either by manually + sorting edge indices via :meth:`~sharker.utils.sort_edge_index` + or by calling :meth:`sharker.data.Graph.sort`. + + .. warning:: + + :class:`LCMAggregation` is not a permutation-invariant operator. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + project (bool, optional): If set to :obj:`True`, the layer will apply a + linear transformation followed by an activation function before + aggregation. (default: :obj:`True`) + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + project: bool = True, + ): + super().__init__() + + if in_channels != out_channels and not project: + raise ValueError( + f"Inputs of '{self.__class__.__name__}' must be " + f"projected if `in_channels != out_channels`" + ) + + self.in_channels = in_channels + self.out_channels = out_channels + self.project = project + + if self.project: + self.lin = Dense(in_channels, out_channels) + else: + self.lin = None + + self.gru_cell = GRUCell(out_channels, out_channels) + + def reset_parameters(self): + if self.project: + self.lin.reset_parameters() + self.gru_cell.reset_parameters() + + @disable_dynamic_shapes(required_args=["dim_size", "max_num_elements"]) + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + max_num_elements: Optional[int] = None, + ) -> Tensor: + + if self.project: + x = mint.nn.functional.relu(self.lin(x)) + + x, _ = self.to_dense_batch( + x, index, ptr, dim_size, dim, max_num_elements=max_num_elements + ) + + x = x.permute(1, 0, 2) + _, num_nodes, num_features = x.shape + + depth = ceil(log2(x.shape[0])) + for _ in range(depth): + half_size = ceil(x.shape[0] / 2) + + if x.shape[0] % 2 == 1: + # This level of the tree has an odd number of nodes, so the + # remaining unmatched node gets moved to the next level. + x, remainder = x[:-1], x[-1:] + else: + remainder = None + + left_right = x.view(-1, 2, num_nodes, num_features) + right_left = left_right.flip(dims=[1]) + + left_right = left_right.reshape(-1, num_features) + right_left = right_left.reshape(-1, num_features) + + # Execute the GRUCell for all (left, right) pairs in the current + # level of the tree in parallel: + out = self.gru_cell(left_right, right_left) + out = out.view(-1, 2, num_nodes, num_features) + out = out.mean(1) + if remainder is not None: + out = mint.cat([out, remainder], dim=0) + + x = out.view(half_size, num_nodes, num_features) + + assert x.shape[0] == 1 + return x.squeeze(0) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.in_channels}, " + f"{self.out_channels}, project={self.project})" + ) diff --git a/mindscience/sharker/nn/aggr/lstm.py b/mindscience/sharker/nn/aggr/lstm.py new file mode 100644 index 000000000..27637147e --- /dev/null +++ b/mindscience/sharker/nn/aggr/lstm.py @@ -0,0 +1,58 @@ +from typing import Optional + +from mindspore import Tensor +from mindspore.nn import LSTM + +from ...experimental import disable_dynamic_shapes +from .base import Aggregation + + +class LSTMAggregation(Aggregation): + r"""Performs LSTM-style aggregation in which the elements to aggregate are + interpreted as a sequence, as described in the `"Inductive Representation + Learning on Large Graphs" `_ paper. + + .. note:: + + :class:`LSTMAggregation` requires sorted indices :obj:`index` as input. + Specifically, if you use this aggregation as part of + :class:`~sharker.nn.conv.MessagePassing`, ensure that + :obj:`edge_index` is sorted by destination nodes, either by manually + sorting edge indices via :meth:`~sharker.utils.sort_edge_index` + or by calling :meth:`sharker.data.Graph.sort`. + + .. warning:: + + :class:`LSTMAggregation` is not a permutation-invariant operator. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + **kwargs (optional): Additional arguments of :class:`nn.LSTM`. + """ + + def __init__(self, in_channels: int, out_channels: int, **kwargs): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.lstm = LSTM(in_channels, out_channels, batch_first=True, **kwargs) + + @disable_dynamic_shapes(required_args=["dim_size", "max_num_elements"]) + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + max_num_elements: Optional[int] = None, + ) -> Tensor: + + x, _ = self.to_dense_batch( + x, index, ptr, dim_size, dim, max_num_elements=max_num_elements + ) + + return self.lstm(x)[0][:, -1] + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.in_channels}, " f"{self.out_channels})" diff --git a/mindscience/sharker/nn/aggr/mlp.py b/mindscience/sharker/nn/aggr/mlp.py new file mode 100644 index 000000000..34e209aec --- /dev/null +++ b/mindscience/sharker/nn/aggr/mlp.py @@ -0,0 +1,76 @@ +from typing import Optional + +from mindspore import Tensor, ops, nn + +from .base import Aggregation +from ..models.mlp import MLP + + +class MLPAggregation(Aggregation): + r"""Performs MLP aggregation in which the elements to aggregate are + flattened into a single vectorial representation, and are then processed by + a Multi-Layer Perceptron (MLP), as described in the `"Graph Neural Networks + with Adaptive Readouts" `_ paper. + + .. note:: + + :class:`MLPAggregation` requires sorted indices :obj:`index` as input. + Specifically, if you use this aggregation as part of + :class:`~sharker.nn.conv.MessagePassing`, ensure that + :obj:`edge_index` is sorted by destination nodes, either by manually + sorting edge indices via :meth:`~sharker.utils.sort_edge_index` + or by calling :meth:`sharker.data.Graph.sort`. + + .. warning:: + + :class:`MLPAggregation` is not a permutation-invariant operator. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + max_num_elements (int): The maximum number of elements to aggregate per + group. + **kwargs (optional): Additional arguments of + :class:`sharker.nn.models.MLP`. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + max_num_elements: int, + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.max_num_elements = max_num_elements + + self.mlp = MLP( + in_channels=in_channels * max_num_elements, + out_channels=out_channels, + **kwargs, + ) + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + x, _ = self.to_dense_batch( + x, index, ptr, dim_size, dim, max_num_elements=self.max_num_elements + ) + + return self.mlp(x.view(-1, x.shape[1] * x.shape[2]), index, dim_size) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.in_channels}, " + f"{self.out_channels}, " + f"max_num_elements={self.max_num_elements})" + ) diff --git a/mindscience/sharker/nn/aggr/multi.py b/mindscience/sharker/nn/aggr/multi.py new file mode 100644 index 000000000..aa6cb448e --- /dev/null +++ b/mindscience/sharker/nn/aggr/multi.py @@ -0,0 +1,204 @@ +import copy +from typing import Any, Dict, List, Optional, Union +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from .base import Aggregation +from .fused import FusedAggregation +from ..dense import HeteroDictLinear +from ..resolver import aggregation_resolver + + +class MultiAggregation(Aggregation): + r"""Performs aggregations with one or more aggregators and combines + aggregated results, as described in the `"Principal Neighbourhood + Aggregation for Graph Nets" `_ and + `"Adaptive Filters and Aggregator Fusion for Efficient Graph Convolutions" + `_ papers. + + Args: + aggrs (list): The list of aggregation schemes to use. + aggrs_kwargs (dict, optional): Arguments passed to the + respective aggregation function in case it gets automatically + resolved. (default: :obj:`None`) + mode (str, optional): The combine mode to use for combining + aggregated results from multiple aggregations (:obj:`"cat"`, + :obj:`"proj"`, :obj:`"sum"`, :obj:`"mean"`, :obj:`"max"`, + :obj:`"min"`, :obj:`"logsumexp"`, :obj:`"std"`, :obj:`"var"`, + :obj:`"attn"`). (default: :obj:`"cat"`) + mode_kwargs (dict, optional): Arguments passed for the combine + :obj:`mode`. When :obj:`"proj"` or :obj:`"attn"` is used as the + combine :obj:`mode`, :obj:`in_channels` (int or tuple) and + :obj:`out_channels` (int) are needed to be specified respectively + for the size of each input sample to combine from the respective + aggregation outputs and the size of each output sample after + combination. When :obj:`"attn"` mode is used, :obj:`num_heads` + (int) is needed to be specified for the number of parallel + attention heads. (default: :obj:`None`) + """ + + fused_out_index: List[int] + is_fused_aggr: List[bool] + + def __init__( + self, + aggrs: List[Union[Aggregation, str]], + aggrs_kwargs: Optional[List[Dict[str, Any]]] = None, + mode: Optional[str] = "cat", + mode_kwargs: Optional[Dict[str, Any]] = None, + ): + + super().__init__() + + if not isinstance(aggrs, (list, tuple)): + raise ValueError( + f"'aggrs' of '{self.__class__.__name__}' should " + f"be a list or tuple (got '{type(aggrs)}')." + ) + + if len(aggrs) == 0: + raise ValueError( + f"'aggrs' of '{self.__class__.__name__}' should " f"not be empty." + ) + + if aggrs_kwargs is None: + aggrs_kwargs = [{}] * len(aggrs) + elif len(aggrs) != len(aggrs_kwargs): + raise ValueError( + f"'aggrs_kwargs' with invalid length passed to " + f"'{self.__class__.__name__}' " + f"(got '{len(aggrs_kwargs)}', " + f"expected '{len(aggrs)}'). Ensure that both " + f"'aggrs' and 'aggrs_kwargs' are consistent." + ) + + self.aggrs = nn.CellList( + [ + aggregation_resolver(aggr, **aggr_kwargs) + for aggr, aggr_kwargs in zip(aggrs, aggrs_kwargs) + ] + ) + + # Divide the set into fusable and non-fusable aggregations: + fused_aggrs: List[Aggregation] = [] + self.fused_out_index: List[int] = [] + self.is_fused_aggr: List[bool] = [] + for i, aggr in enumerate(self.aggrs): + if aggr.__class__ in FusedAggregation.FUSABLE_AGGRS: + fused_aggrs.append(aggr) + self.fused_out_index.append(i) + self.is_fused_aggr.append(True) + else: + self.is_fused_aggr.append(False) + + if len(fused_aggrs) > 0: + self.fused_aggr = FusedAggregation(fused_aggrs) + else: + self.fused_aggr = None + + self.mode = mode + mode_kwargs = copy.copy(mode_kwargs) or {} + + self.in_channels = mode_kwargs.pop("in_channels", None) + self.out_channels = mode_kwargs.pop("out_channels", None) + + if mode == "proj" or mode == "attn": + if len(aggrs) == 1: + raise ValueError( + "Multiple aggregations are required for " + "'proj' or 'attn' combine mode." + ) + + if (self.in_channels and self.out_channels) is None: + raise ValueError( + f"Combine mode '{mode}' must have `in_channels` " + f"and `out_channels` specified." + ) + + if isinstance(self.in_channels, int): + self.in_channels = [self.in_channels] * len(aggrs) + + if mode == "proj": + self.lin = nn.Dense( + sum(self.in_channels), + self.out_channels, + **mode_kwargs, + ) + + elif mode == "attn": + channels = {str(k): v for k, v, in enumerate(self.in_channels)} + self.lin_heads = HeteroDictLinear(channels, self.out_channels) + num_heads = mode_kwargs.pop("num_heads", 1) + self.multihead_attn = nn.MultiheadAttention( + self.out_channels, + num_heads, + **mode_kwargs, + ) + + dense_combine_modes = ["sum", "mean", "max", "min", "logsumexp", "var"] + if mode in dense_combine_modes: + self.dense_combine = getattr(ops, mode) + elif mode == 'std': + self.dense_combine = ms.numpy.std + elif mode == 'var': + self.dense_combine = ms.numpy.var + + def get_out_channels(self, in_channels: int) -> int: + if self.out_channels is not None: + return self.out_channels + if self.mode == "cat": + return in_channels * len(self.aggrs) + return in_channels + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + # `FusedAggregation` is currently limited to two-dimensional inputs: + if index is None or x.dim() != 2 or self.fused_aggr is None: + outs = [aggr(x, index, ptr, dim_size, dim) for aggr in self.aggrs] + return self.combine(outs) + + outs: List[Tensor] = [x] * len(self.aggrs) + + fused_outs = self.fused_aggr(x, index, ptr, dim_size, dim) + for i, out in zip(self.fused_out_index, fused_outs): + outs[i] = out + + for i, aggr in enumerate(self.aggrs): + if not self.is_fused_aggr[i]: + outs[i] = aggr(x, index, ptr, dim_size, dim) + + return self.combine(outs) + + def combine(self, inputs: List[Tensor]) -> Tensor: + if len(inputs) == 1: + return inputs[0] + + if self.mode == "cat": + return mint.cat(inputs, dim=-1) + + if hasattr(self, "lin"): + return self.lin(mint.cat(inputs, dim=-1)) + + if hasattr(self, "multihead_attn"): + x_dict = {str(k): v for k, v, in enumerate(inputs)} + x_dict = self.lin_heads(x_dict) + xs = [x_dict[str(key)] for key in range(len(inputs))] + x = mint.stack(xs, dim=0) + attn_out, _ = self.multihead_attn(x, x, x) + return mint.mean(attn_out, dim=0) + + if hasattr(self, "dense_combine"): + out = self.dense_combine(mint.stack(inputs, dim=0), 0) + return out if isinstance(out, Tensor) else out[0] + + raise ValueError(f"Combine mode '{self.mode}' is not supported.") + + def __repr__(self) -> str: + aggrs = ",\n".join([f" {aggr}" for aggr in self.aggrs]) + ",\n" + return f"{self.__class__.__name__}([\n{aggrs}], mode={self.mode})" diff --git a/mindscience/sharker/nn/aggr/quantile.py b/mindscience/sharker/nn/aggr/quantile.py new file mode 100644 index 000000000..dc8e20e9c --- /dev/null +++ b/mindscience/sharker/nn/aggr/quantile.py @@ -0,0 +1,171 @@ +from typing import List, Optional, Union + +import mindspore as ms +from mindspore import Tensor, ops +from .base import Aggregation +from ...utils import cumsum + + +class QuantileAggregation(Aggregation): + r"""An aggregation operator that returns the feature-wise :math:`q`-th + quantile of a set :math:`\mathcal{X}`. + + That is, for every feature :math:`d`, it computes + + .. math:: + {\mathrm{Q}_q(\mathcal{X})}_d = \begin{cases} + x_{\pi_i,d} & i = q \cdot n, \\ + f(x_{\pi_i,d}, x_{\pi_{i+1},d}) & i < q \cdot n < i + 1,\\ + \end{cases} + + where :math:`x_{\pi_1,d} \le \dots \le x_{\pi_i,d} \le \dots \le + x_{\pi_n,d}` and :math:`f(a, b)` is an interpolation + function defined by :obj:`interpolation`. + + Args: + q (float or list): The quantile value(s) :math:`q`. Can be a scalar or + a list of scalars in the range :math:`[0, 1]`. If more than a + quantile is passed, the results are concatenated. + interpolation (str): Interpolation method applied if the quantile point + :math:`q\cdot n` lies between two values + :math:`a \le b`. Can be one of the following: + + * :obj:`"lower"`: Returns the one with lowest value. + + * :obj:`"higher"`: Returns the one with highest value. + + * :obj:`"midpoint"`: Returns the average of the two values. + + * :obj:`"nearest"`: Returns the one whose index is nearest to the + quantile point. + + * :obj:`"linear"`: Returns a linear combination of the two + elements, defined as + :math:`f(a, b) = a + (b - a)\cdot(q\cdot n - i)`. + + (default: :obj:`"linear"`) + fill_value (float, optional): The default value in the case no entry is + found for a given index (default: :obj:`0.0`). + """ + + interpolations = {"linear", "lower", "higher", "nearest", "midpoint"} + + def __init__( + self, + q: Union[float, List[float]], + interpolation: str = "linear", + fill_value: float = 0.0, + ): + super().__init__() + + qs = [q] if not isinstance(q, (list, tuple)) else q + if len(qs) == 0: + raise ValueError("Provide at least one quantile value for `q`.") + if not all(0.0 <= quantile <= 1.0 for quantile in qs): + raise ValueError("`q` must be in the range [0, 1].") + if interpolation not in self.interpolations: + raise ValueError( + f"Invalid interpolation method " f"got ('{interpolation}')" + ) + + self._q = q + self.q = ms.Parameter(Tensor(qs).view(-1, 1), requires_grad=False) + self.interpolation = interpolation + self.fill_value = fill_value + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + dim = x.dim() + dim if dim < 0 else dim + + self.assert_index_present(index) + assert index is not None + + count = ops.bincount(index, minlength=dim_size or 0) + ptr = cumsum(count)[:-1] + + # In case there exists dangling indices (`dim_size > index.max()`), we + # need to clamp them to prevent out-of-bound issues: + if dim_size is not None: + ptr = ptr.clamp(max=x.shape[dim] - 1) + + q_point = self.q * (count - 1) + ptr + q_point = q_point.t().reshape(-1) + + shape = [1] * x.dim() + shape[dim] = -1 + index = index.reshape(shape).expand_as(x) + + # Two sorts: the first one on the value, + # the second (stable) on the indices: + x, x_perm = ops.sort(x, axis=dim) + index = ms.numpy.take_along_axis(index, x_perm, axis=dim) + index, index_perm = ops.sort(index, axis=dim) + x = ms.numpy.take_along_axis(x, index_perm, axis=dim) + + # Compute the quantile interpolations: + if self.interpolation == "lower": + quantile = x.index_select(dim, q_point.floor().long()) + elif self.interpolation == "higher": + quantile = x.index_select(dim, q_point.ceil().long()) + elif self.interpolation == "nearest": + quantile = x.index_select(dim, q_point.round().long()) + else: + l_quant = x.index_select(dim, q_point.floor().long()) + r_quant = x.index_select(dim, q_point.ceil().long()) + + if self.interpolation == "linear": + q_frac = q_point.frac().view(*shape) + quantile = l_quant + (r_quant - l_quant) * q_frac + else: + quantile = 0.5 * l_quant + 0.5 * r_quant + + # If the number of elements is zero, fill with pre-defined value: + repeats = self.q.numel() + mask = (count == 0).short().repeat(repeats).view(*shape) + + out = quantile.masked_fill(mask.bool(), self.fill_value) + + if self.q.numel() > 1: + shape = list(out.shape) + shape = shape[:dim] + [shape[dim] // self.q.numel(), -1] + shape[dim + 2:] + out = out.view(*shape) + + return out + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(q={self._q})" + + +class MedianAggregation(QuantileAggregation): + r"""An aggregation operator that returns the feature-wise median of a set. + + That is, for every feature :math:`d`, it computes + + .. math:: + {\mathrm{median}(\mathcal{X})}_d = x_{\pi_i,d} + + where :math:`x_{\pi_1,d} \le x_{\pi_2,d} \le \dots \le + x_{\pi_n,d}` and :math:`i = \lfloor \frac{n}{2} \rfloor`. + + .. note:: + If the median lies between two values, the lowest one is returned. + To compute the midpoint (or other kind of interpolation) of the two + values, use :class:`QuantileAggregation` instead. + + Args: + fill_value (float, optional): The default value in the case no entry is + found for a given index (default: :obj:`0.0`). + """ + + def __init__(self, fill_value: float = 0.0): + super().__init__(0.5, "lower", fill_value) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/mindscience/sharker/nn/aggr/scaler.py b/mindscience/sharker/nn/aggr/scaler.py new file mode 100644 index 000000000..4d94b11df --- /dev/null +++ b/mindscience/sharker/nn/aggr/scaler.py @@ -0,0 +1,115 @@ +from typing import Any, Dict, List, Optional, Union + +import mindspore as ms +from mindspore import Tensor, ops, mint +from .base import Aggregation +from .multi import MultiAggregation +from ..resolver import aggregation_resolver as aggr_resolver +from ...utils import degree + + +class DegreeScalerAggregation(Aggregation): + r"""Combines one or more aggregators and transforms its output with one or + more scalers as introduced in the `"Principal Neighbourhood Aggregation for + Graph Nets" `_ paper. + The scalers are normalised by the in-degree of the training set and so must + be provided at time of construction. + See :class:`sharker.nn.conv.PNAConv` for more information. + + Args: + aggr (str or [str] or Aggregation): The aggregation scheme to use. + See :class:`~sharker.nn.conv.MessagePassing` for more + information. + scaler (str or list): Set of scaling function identifiers, namely one + or more of :obj:`"identity"`, :obj:`"amplification"`, + :obj:`"attenuation"`, :obj:`"linear"` and :obj:`"inverse_linear"`. + deg (Tensor): Histogram of in-degrees of nodes in the training set, + used by scalers to normalize. + train_norm (bool, optional): Whether normalization parameters + are trainable. (default: :obj:`False`) + aggr_kwargs (Dict[str, Any], optional): Arguments passed to the + respective aggregation function in case it gets automatically + resolved. (default: :obj:`None`) + """ + + def __init__( + self, + aggr: Union[str, List[str], Aggregation], + scaler: Union[str, List[str]], + deg: Tensor, + train_norm: bool = False, + aggr_kwargs: Optional[List[Dict[str, Any]]] = None, + ): + super().__init__() + + if isinstance(aggr, (str, Aggregation)): + self.aggr = aggr_resolver(aggr, **(aggr_kwargs or {})) + elif isinstance(aggr, (tuple, list)): + self.aggr = MultiAggregation(aggr, aggr_kwargs) + else: + raise ValueError( + f"Only strings, list, tuples and instances of" + f"`sharker.nn.aggr.Aggregation` are " + f"valid aggregation schemes (got '{type(aggr)}')" + ) + + self.scaler = [scaler] if isinstance(aggr, str) else scaler + + deg = deg.float() + N = int(deg.sum()) + bin_degree = mint.arange(deg.numel()) + + self.init_avg_deg_lin = (bin_degree * deg).sum().item() / N + self.init_avg_deg_log = ((bin_degree + 1).float().log() * deg).sum().item() / N + + if train_norm: + self.avg_deg_lin = ms.Parameter(ms.numpy.empty(1)) + self.avg_deg_log = ms.Parameter(ms.numpy.empty(1)) + else: + self.avg_deg_lin = ms.numpy.empty(1) + self.avg_deg_log = ms.numpy.empty(1) + + self.reset_parameters() + + def reset_parameters(self): + self.avg_deg_lin[:] = self.init_avg_deg_lin + self.avg_deg_log[:] = self.init_avg_deg_log + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + self.assert_index_present(index) + + out = self.aggr(x, index, ptr, dim_size, dim) + + assert index is not None + deg = degree(index, num_nodes=dim_size, dtype=out.dtype) + size = [1] * len(out.shape) + size[dim] = -1 + deg = deg.view(*size) + + outs = [] + for scaler in self.scaler: + if scaler == "identity": + out_scaler = out + elif scaler == "amplification": + out_scaler = out * (mint.log(deg + 1) / self.avg_deg_log) + elif scaler == "attenuation": + # Clamp minimum degree to one to avoid dividing by zero: + out_scaler = out * (self.avg_deg_log / (deg.clamp(min=1) + 1).log()) + elif scaler == "linear": + out_scaler = out * (deg / self.avg_deg_lin) + elif scaler == "inverse_linear": + # Clamp minimum degree to one to avoid dividing by zero: + out_scaler = out * (self.avg_deg_lin / deg.clamp(min=1)) + else: + raise ValueError(f"Unknown scaler '{scaler}'") + outs.append(out_scaler) + + return mint.cat(outs, dim=-1) if len(outs) > 1 else outs[0] diff --git a/mindscience/sharker/nn/aggr/set2set.py b/mindscience/sharker/nn/aggr/set2set.py new file mode 100644 index 000000000..5b2d950d1 --- /dev/null +++ b/mindscience/sharker/nn/aggr/set2set.py @@ -0,0 +1,67 @@ +from typing import Optional + +from mindspore import Tensor, ops, nn, mint +from .base import Aggregation +from ...utils import softmax + + +class Set2Set(Aggregation): + r"""The Set2Set aggregation operator based on iterative content-based + attention, as described in the `"Order Matters: Sequence to sequence for + Sets" `_ paper. + + .. math:: + \mathbf{q}_t &= \mathrm{LSTM}(\mathbf{q}^{*}_{t-1}) + + \alpha_{i,t} &= \mathrm{softmax}(\mathbf{x}_i \cdot \mathbf{q}_t) + + \mathbf{r}_t &= \sum_{i=1}^N \alpha_{i,t} \mathbf{x}_i + + \mathbf{q}^{*}_t &= \mathbf{q}_t \, \Vert \, \mathbf{r}_t, + + where :math:`\mathbf{q}^{*}_T` defines the output of the layer with twice + the dimensionality as the input. + + Args: + in_channels (int): Size of each input sample. + processing_steps (int): Number of iterations :math:`T`. + **kwargs (optional): Additional arguments of :class:`nn.LSTM`. + """ + + def __init__(self, in_channels: int, processing_steps: int, **kwargs): + super().__init__() + self.in_channels = in_channels + self.out_channels = 2 * in_channels + self.processing_steps = processing_steps + self.lstm = nn.LSTM(self.out_channels, in_channels, **kwargs) + + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + ) -> Tensor: + + self.assert_index_present(index) + self.assert_two_dimensional_input(x, dim) + + h = ( + mint.zeros((self.lstm.num_layers, dim_size, x.shape[-1]), dtype=x.dtype), + mint.zeros((self.lstm.num_layers, dim_size, x.shape[-1]), dtype=x.dtype), + ) + q_star = mint.zeros((dim_size, self.out_channels), dtype=x.dtype) + + for _ in range(self.processing_steps): + q, h = self.lstm(q_star.unsqueeze(0), h) + q = q.view(dim_size, self.in_channels) + e = (x * q[index]).sum(-1, keepdims=True) + a = softmax(e, index, ptr, dim_size, dim) + r = self.reduce(a * x, index, ptr, dim_size, dim, reduce="sum") + q_star = mint.cat([q, r], dim=-1) + + return q_star + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.in_channels}, " f"{self.out_channels})" diff --git a/mindscience/sharker/nn/aggr/set_transformer.py b/mindscience/sharker/nn/aggr/set_transformer.py new file mode 100644 index 000000000..735f4c5f1 --- /dev/null +++ b/mindscience/sharker/nn/aggr/set_transformer.py @@ -0,0 +1,116 @@ +from typing import Optional + +from mindspore import Tensor, ops, nn +from mindspore import nn + +from ...experimental import disable_dynamic_shapes +from .base import Aggregation +from .utils import PoolingByMultiheadAttention, SetAttentionBlock + + +class SetTransformerAggregation(Aggregation): + r"""Performs "Set Transformer" aggregation in which the elements to + aggregate are processed by multi-head attention blocks, as described in + the `"Graph Neural Networks with Adaptive Readouts" + `_ paper. + + .. note:: + + :class:`SetTransformerAggregation` requires sorted indices :obj:`index` + as input. Specifically, if you use this aggregation as part of + :class:`~sharker.nn.conv.MessagePassing`, ensure that + :obj:`edge_index` is sorted by destination nodes, either by manually + sorting edge indices via :meth:`~sharker.utils.sort_edge_index` + or by calling :meth:`sharker.data.Graph.sort`. + + Args: + channels (int): Size of each input sample. + num_seed_points (int, optional): Number of seed points. + (default: :obj:`1`) + num_encoder_blocks (int, optional): Number of Set Attention Blocks + (SABs) in the encoder. (default: :obj:`1`). + num_decoder_blocks (int, optional): Number of Set Attention Blocks + (SABs) in the decoder. (default: :obj:`1`). + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the seed embeddings + are averaged instead of concatenated. (default: :obj:`True`) + norm (str, optional): If set to :obj:`True`, will apply layer + normalization. (default: :obj:`False`) + dropout (float, optional): Dropout probability of attention weights. + (default: :obj:`0`) + """ + + def __init__( + self, + channels: int, + num_seed_points: int = 1, + num_encoder_blocks: int = 1, + num_decoder_blocks: int = 1, + heads: int = 1, + concat: bool = True, + layer_norm: bool = False, + dropout: float = 0.0, + ): + super().__init__() + + self.channels = channels + self.num_seed_points = num_seed_points + self.heads = heads + self.concat = concat + self.layer_norm = layer_norm + self.dropout = dropout + + self.encoders = nn.CellList( + [ + SetAttentionBlock(channels, heads, layer_norm, dropout) + for _ in range(num_encoder_blocks) + ] + ) + + self.pma = PoolingByMultiheadAttention( + channels, num_seed_points, heads, layer_norm, dropout + ) + + self.decoders = nn.CellList( + [ + SetAttentionBlock(channels, heads, layer_norm, dropout) + for _ in range(num_decoder_blocks) + ] + ) + + @disable_dynamic_shapes(required_args=["dim_size", "max_num_elements"]) + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + max_num_elements: Optional[int] = None, + ) -> Tensor: + + x, mask = self.to_dense_batch( + x, index, ptr, dim_size, dim, max_num_elements=max_num_elements + ) + + for encoder in self.encoders: + x = encoder(x, mask) + + x = self.pma(x, mask) + + for decoder in self.decoders: + x = decoder(x) + + x[x.isnan()] = 0 + + return x.flatten(start_dim=1, end_dim=2) if self.concat else x.mean(1) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.channels}, " + f"num_seed_points={self.num_seed_points}, " + f"heads={self.heads}, " + f"layer_norm={self.layer_norm}, " + f"dropout={self.dropout})" + ) diff --git a/mindscience/sharker/nn/aggr/sort.py b/mindscience/sharker/nn/aggr/sort.py new file mode 100644 index 000000000..bec7b38a5 --- /dev/null +++ b/mindscience/sharker/nn/aggr/sort.py @@ -0,0 +1,76 @@ +from typing import Optional + +import mindspore as ms +from mindspore import Tensor, ops, mint +from ...experimental import disable_dynamic_shapes +from .base import Aggregation + + +class SortAggregation(Aggregation): + r"""The pooling operator from the `"An End-to-End Deep Learning + Architecture for Graph Classification" + `_ paper, + where node features are sorted in descending order based on their last + feature channel. The first :math:`k` nodes form the output of the layer. + + .. note:: + + :class:`SortAggregation` requires sorted indices :obj:`index` as input. + Specifically, if you use this aggregation as part of + :class:`~sharker.nn.conv.MessagePassing`, ensure that + :obj:`edge_index` is sorted by destination nodes, either by manually + sorting edge indices via :meth:`~sharker.utils.sort_edge_index` + or by calling :meth:`sharker.data.Graph.sort`. + + Args: + k (int): The number of nodes to hold for each graph. + """ + + def __init__(self, k: int): + super().__init__() + self.k = k + + @disable_dynamic_shapes(required_args=["dim_size", "max_num_elements"]) + def construct( + self, + x: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + dim: int = -2, + max_num_elements: Optional[int] = None, + ) -> Tensor: + + fill_value = x.min() - 1 + batch_x, _ = self.to_dense_batch( + x, + index, + ptr, + dim_size, + dim, + fill_value=fill_value, + max_num_elements=max_num_elements, + ) + B, N, D = batch_x.shape + + _, perm = batch_x[:, :, -1].sort(axis=-1, descending=True) + arange = mint.arange(B, dtype=ms.int64) * N + perm = perm + arange.view(-1, 1) + + batch_x = batch_x.view(B * N, D) + batch_x = batch_x[perm] + batch_x = batch_x.view(B, N, D) + + if N >= self.k: + batch_x = batch_x[:, : self.k] + else: + expand_batch_x = mint.full((B, self.k - N, D), fill_value, dtype=batch_x.dtype) + batch_x = mint.cat([batch_x, expand_batch_x], dim=1) + + batch_x[batch_x == fill_value] = 0 + x = batch_x.view(B, self.k * D) + + return x + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(k={self.k})" diff --git a/mindscience/sharker/nn/aggr/utils.py b/mindscience/sharker/nn/aggr/utils.py new file mode 100644 index 000000000..2e774f3de --- /dev/null +++ b/mindscience/sharker/nn/aggr/utils.py @@ -0,0 +1,236 @@ +from typing import Optional + +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from mindspore.nn import LayerNorm, Dense, MultiheadAttention +from mindspore.common.initializer import initializer + + +class MultiheadAttentionBlock(nn.Cell): + r"""The Multihead Attention Block (MAB) from the `"Set Transformer: A + Framework for Attention-based Permutation-Invariant Neural Networks" + `_ paper. + + .. math:: + + \mathrm{MAB}(\mathbf{x}, \mathbf{y}) &= \mathrm{LayerNorm}(\mathbf{h} + + \mathbf{W} \mathbf{h}) + + \mathbf{h} &= \mathrm{LayerNorm}(\mathbf{x} + + \mathrm{Multihead}(\mathbf{x}, \mathbf{y}, \mathbf{y})) + + Args: + channels (int): Size of each input sample. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + norm (str, optional): If set to :obj:`False`, will not apply layer + normalization. (default: :obj:`True`) + dropout (float, optional): Dropout probability of attention weights. + (default: :obj:`0`) + """ + + def __init__( + self, + channels: int, + heads: int = 1, + layer_norm: bool = True, + dropout: float = 0.0, + ): + super().__init__() + + self.channels = channels + self.heads = heads + self.dropout = dropout + + self.attn = MultiheadAttention( + channels, + heads, + batch_first=True, + dropout=dropout, + ) + self.lin = Dense(channels, channels) + self.layer_norm1 = LayerNorm([channels]) if layer_norm else None + self.layer_norm2 = LayerNorm([channels]) if layer_norm else None + + def construct( + self, + x: Tensor, + y: Tensor, + x_mask: Optional[Tensor] = None, + y_mask: Optional[Tensor] = None, + ) -> Tensor: + """construct""" + if y_mask is not None: + y_mask = ~y_mask + + out = self.attn(x, y, y, y_mask, need_weights=False)[0] + + if x_mask is not None: + out[~x_mask] = 0.0 + + out += x + + if self.layer_norm1 is not None: + out = self.layer_norm1(out) + + out += mint.nn.functional.relu(self.lin(out)) + + if self.layer_norm2 is not None: + out = self.layer_norm2(out) + + return out + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.channels}, " + f"heads={self.heads}, " + f"layer_norm={self.layer_norm1 is not None}, " + f"dropout={self.dropout})" + ) + + +class SetAttentionBlock(nn.Cell): + r"""The Set Attention Block (SAB) from the `"Set Transformer: A + Framework for Attention-based Permutation-Invariant Neural Networks" + `_ paper. + + .. math:: + + \mathrm{SAB}(\mathbf{X}) = \mathrm{MAB}(\mathbf{x}, \mathbf{y}) + + Args: + channels (int): Size of each input sample. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + norm (str, optional): If set to :obj:`False`, will not apply layer + normalization. (default: :obj:`True`) + dropout (float, optional): Dropout probability of attention weights. + (default: :obj:`0`) + """ + + def __init__( + self, + channels: int, + heads: int = 1, + layer_norm: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.mab = MultiheadAttentionBlock(channels, heads, layer_norm, dropout) + + def construct(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + return self.mab(x, x, mask, mask) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.mab.channels}, " + f"heads={self.mab.heads}, " + f"layer_norm={self.mab.layer_norm1 is not None}, " + f"dropout={self.mab.dropout})" + ) + + +class InducedSetAttentionBlock(nn.Cell): + r"""The Induced Set Attention Block (SAB) from the `"Set Transformer: A + Framework for Attention-based Permutation-Invariant Neural Networks" + `_ paper. + + .. math:: + + \mathrm{ISAB}(\mathbf{X}) &= \mathrm{MAB}(\mathbf{x}, \mathbf{h}) + + \mathbf{h} &= \mathrm{MAB}(\mathbf{I}, \mathbf{x}) + + where :math:`\mathbf{I}` denotes :obj:`num_induced_points` learnable + vectors. + + Args: + channels (int): Size of each input sample. + num_induced_points (int): Number of induced points. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + norm (str, optional): If set to :obj:`False`, will not apply layer + normalization. (default: :obj:`True`) + dropout (float, optional): Dropout probability of attention weights. + (default: :obj:`0`) + """ + + def __init__( + self, + channels: int, + num_induced_points: int, + heads: int = 1, + layer_norm: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.ind = ms.Parameter(mint.zeros([1, num_induced_points, channels])) + self.mab1 = MultiheadAttentionBlock(channels, heads, layer_norm, dropout) + self.mab2 = MultiheadAttentionBlock(channels, heads, layer_norm, dropout) + + def construct(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + h = self.mab1(self.ind.broadcast_to((x.shape[0], -1, -1)), x, y_mask=mask) + return self.mab2(x, h, x_mask=mask) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.ind.shape[2]}, " + f"num_induced_points={self.ind.shape[1]}, " + f"heads={self.mab1.heads}, " + f"layer_norm={self.mab1.layer_norm1 is not None}, " + f"dropout={self.mab1.dropout})" + ) + + +class PoolingByMultiheadAttention(nn.Cell): + r"""The Pooling by Multihead Attention (PMA) layer from the `"Set + Transformer: A Framework for Attention-based Permutation-Invariant Neural + Networks" `_ paper. + + .. math:: + + \mathrm{PMA}(\mathbf{X}) = \mathrm{MAB}(\mathbf{S}, \mathbf{x}) + + where :math:`\mathbf{S}` denotes :obj:`num_seed_points` learnable vectors. + + Args: + channels (int): Size of each input sample. + num_seed_points (int, optional): Number of seed points. + (default: :obj:`1`) + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + norm (str, optional): If set to :obj:`False`, will not apply layer + normalization. (default: :obj:`True`) + dropout (float, optional): Dropout probability of attention weights. + (default: :obj:`0`) + """ + + def __init__( + self, + channels: int, + num_seed_points: int = 1, + heads: int = 1, + layer_norm: bool = True, + dropout: float = 0.0, + ): + super().__init__() + self.lin = Dense(channels, channels) + self.seed = ms.Parameter(mint.zeros([1, num_seed_points, channels])) + self.mab = MultiheadAttentionBlock(channels, heads, layer_norm, dropout) + self.reset_parameters() + + def reset_parameters(self): + self.seed = initializer("uniform", self.seed.shape, self.seed.dtype).init_data() + + def construct(self, x: Tensor, mask: Optional[Tensor] = None) -> Tensor: + x = mint.nn.functional.relu(self.lin(x)) + return self.mab(self.seed.broadcast_to((x.shape[0], -1, -1)), x, y_mask=mask) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.seed.shape[2]}, " + f"num_seed_points={self.seed.shape[1]}, " + f"heads={self.mab.heads}, " + f"layer_norm={self.mab.layer_norm1 is not None}, " + f"dropout={self.mab.dropout})" + ) diff --git a/mindscience/sharker/nn/aggr/variance_preserving.py b/mindscience/sharker/nn/aggr/variance_preserving.py new file mode 100644 index 000000000..c581be6b2 --- /dev/null +++ b/mindscience/sharker/nn/aggr/variance_preserving.py @@ -0,0 +1,34 @@ +from typing import Optional + +from mindspore import Tensor + +from . import Aggregation +from ...utils import degree +from ...utils import broadcast_to + + +class VariancePreservingAggregation(Aggregation): + r"""Performs the Variance Preserving Aggregation (VPA) from the `"GNN-VPA: + A Variance-Preserving Aggregation Strategy for Graph Neural Networks" + `_ paper. + + .. math:: + \mathrm{vpa}(\mathcal{X}) = \frac{1}{\sqrt{|\mathcal{X}|}} + \sum_{\mathbf{x}_i \in \mathcal{X}} \mathbf{x}_i + """ + + def construct(self, x: Tensor, index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, dim_size: Optional[int] = None, + dim: int = -2) -> Tensor: + + out = self.reduce(x, index, ptr, dim_size, dim, reduce='sum') + + if ptr is not None: + count = ptr.diff().float() + else: + count = degree(index, dim_size, dtype=out.dtype) + + count = count.sqrt().clamp(min=1.0) + count = broadcast_to(count, out, dim=dim, is_dense=True) + + return out / count diff --git a/mindscience/sharker/nn/conv/__init__.py b/mindscience/sharker/nn/conv/__init__.py new file mode 100644 index 000000000..4b62ae217 --- /dev/null +++ b/mindscience/sharker/nn/conv/__init__.py @@ -0,0 +1,113 @@ +from .message_passing import MessagePassing +from .simple_conv import SimpleConv +from .gcn_conv import GCNConv +from .cheb_conv import ChebConv +from .sage_conv import SAGEConv +from .graph_conv import GraphConv +from .gravnet_conv import GravNetConv +from .gated_graph_conv import GatedGraphConv +from .res_gated_graph_conv import ResGatedGraphConv +from .gat_conv import GATConv +from .gatv2_conv import GATv2Conv +from .transformer_conv import TransformerConv +from .agnn_conv import AGNNConv +from .tag_conv import TAGConv +from .gin_conv import GINConv, GINEConv +from .arma_conv import ARMAConv +from .sg_conv import SGConv +from .appnp import APPNP +from .mf_conv import MFConv +from .rgat_conv import RGATConv +from .signed_conv import SignedConv +from .dna_conv import DNAConv +from .point_conv import PointNetConv +from .gmm_conv import GMMConv +from .nn_conv import NNConv +from .cg_conv import CGConv +from .edge_conv import EdgeConv, DynamicEdgeConv +from .x_conv import XConv +from .ppf_conv import PPFConv +from .feast_conv import FeaStConv +from .point_transformer_conv import PointTransformerConv +from .hypergraph_conv import HypergraphConv +from .le_conv import LEConv +from .cluster_gcn_conv import ClusterGCNConv +from .gen_conv import GENConv +from .wl_conv import WLConv +from .wl_conv_continuous import WLConvContinuous +from .film_conv import FiLMConv +from .fa_conv import FAConv +from .eg_conv import EGConv +from .pdn_conv import PDNConv +from .general_conv import GeneralConv +from .hgt_conv import HGTConv +from .heat_conv import HEATConv +from .hetero_conv import HeteroConv +from .lg_conv import LGConv +from .ssg_conv import SSGConv +from .point_gnn_conv import PointGNNConv +from .antisymmetric_conv import AntiSymmetricConv +from .dir_gnn_conv import DirGNNConv +from .mixhop_conv import MixHopConv + + +__all__ = [ + "MessagePassing", + "SimpleConv", + "GCNConv", + "ChebConv", + "SAGEConv", + "GraphConv", + "GravNetConv", + "GatedGraphConv", + "ResGatedGraphConv", + "GATConv", + "GATv2Conv", + "TransformerConv", + "AGNNConv", + "TAGConv", + "GINConv", + "GINEConv", + "ARMAConv", + "SGConv", + "SSGConv", + "APPNP", + "MFConv", + "RGATConv", + "SignedConv", + "DNAConv", + "PointNetConv", + "GMMConv", + "NNConv", + "CGConv", + "EdgeConv", + "DynamicEdgeConv", + "XConv", + "PPFConv", + "FeaStConv", + "PointTransformerConv", + "HypergraphConv", + "LEConv", + "ClusterGCNConv", + "GENConv", + "WLConv", + "WLConvContinuous", + "FiLMConv", + "FAConv", + "EGConv", + "PDNConv", + "GeneralConv", + "HGTConv", + "HEATConv", + "HeteroConv", + "LGConv", + "PointGNNConv", + "AntiSymmetricConv", + "DirGNNConv", + "MixHopConv", +] + +classes = __all__ + +ECConv = NNConv +PointConv = PointNetConv diff --git a/mindscience/sharker/nn/conv/agnn_conv.py b/mindscience/sharker/nn/conv/agnn_conv.py new file mode 100644 index 000000000..f37373c1e --- /dev/null +++ b/mindscience/sharker/nn/conv/agnn_conv.py @@ -0,0 +1,76 @@ +from typing import Optional, Union +import mindspore as ms +from mindspore import Tensor, Parameter, ops, mint +from .message_passing import MessagePassing +from ...utils import add_self_loops, remove_self_loops, softmax + + +class AGNNConv(MessagePassing): + r"""The graph attentional propagation layer from the + `"Attention-based Graph Neural Network for Semi-Supervised Learning" + `_ paper. + + .. math:: + \mathbf{X}^{\prime} = \mathbf{P} \mathbf{X}, + + where the propagation matrix :math:`\mathbf{P}` is computed as + + .. math:: + P_{i,j} = \frac{\exp( \beta \cdot \cos(\mathbf{x}_i, \mathbf{x}_j))} + {\sum_{k \in \mathcal{N}(i)\cup \{ i \}} \exp( \beta \cdot + \cos(\mathbf{x}_i, \mathbf{x}_k))} + + with trainable parameter :math:`\beta`. + + Args: + requires_grad (bool, optional): If set to :obj:`False`, :math:`\beta` + will not be trainable. (default: :obj:`True`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F)`, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F)` + """ + + def __init__(self, requires_grad: bool = True, add_self_loops: bool = True, + **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.requires_grad = requires_grad + self.add_self_loops = add_self_loops + + if requires_grad: + self.beta = Parameter(ms.numpy.empty(1)) + else: + self.beta = Parameter(mint.ones(1), requires_grad=False) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + if self.requires_grad: + self.beta.data[:] = 1 + + def construct(self, x: Tensor, edge_index: Union[Tensor, ]) -> Tensor: + if self.add_self_loops: + if isinstance(edge_index, Tensor): + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index, + num_nodes=x.shape[self.node_dim]) + + x_norm = x / ms.numpy.norm(x, ord=2., axis=-1, keepdims=True) + + return self.propagate(edge_index, x=x, x_norm=x_norm) + + def message(self, x_j: Tensor, x_norm_i: Tensor, x_norm_j: Tensor, + index: Tensor, ptr: Optional[Tensor], + size_i: Optional[int]) -> Tensor: + alpha = self.beta * (x_norm_i * x_norm_j).sum(-1) + alpha = softmax(alpha, index, ptr, size_i) + return x_j * alpha.view(-1, 1) diff --git a/mindscience/sharker/nn/conv/antisymmetric_conv.py b/mindscience/sharker/nn/conv/antisymmetric_conv.py new file mode 100644 index 000000000..10d425940 --- /dev/null +++ b/mindscience/sharker/nn/conv/antisymmetric_conv.py @@ -0,0 +1,116 @@ +import numpy as np +from typing import Any, Callable, Dict, Optional, Union +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from .message_passing import MessagePassing +from .gcn_conv import GCNConv +from ..inits import zeros, kaiming_uniform +from ..resolver import activation_resolver + + +class AntiSymmetricConv(nn.Cell): + r"""The anti-symmetric graph convolutional operator from the + `"Anti-Symmetric DGN: a stable architecture for Deep Graph Networks" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{x}_i + \epsilon \cdot \sigma \left( + (\mathbf{W}-\mathbf{W}^T-\gamma \mathbf{I}) \mathbf{x}_i + + \Phi(\mathbf{X}, \mathcal{N}_i) + \mathbf{b}\right), + + where :math:`\Phi(\mathbf{X}, \mathcal{N}_i)` denotes a + :class:`~nn.conv.MessagePassing` layer. + + Args: + in_channels (int): Size of each input sample. + phi (MessagePassing, optional): The message passing module + :math:`\Phi`. If set to :obj:`None`, will use a + :class:`~sharker.nn.conv.GCNConv` layer as default. + (default: :obj:`None`) + num_iters (int, optional): The number of times the anti-symmetric deep + graph network operator is called. (default: :obj:`1`) + epsilon (float, optional): The discretization step size + :math:`\epsilon`. (default: :obj:`0.1`) + gamma (float, optional): The strength of the diffusion :math:`\gamma`. + It regulates the stability of the method. (default: :obj:`0.1`) + act (str, optional): The non-linear activation function :math:`\sigma`, + *e.g.*, :obj:`"tanh"` or :obj:`"relu"`. (default: :class:`"tanh"`) + act_kwargs (Dict[str, Any], optional): Arguments passed to the + respective activation function defined by :obj:`act`. + (default: :obj:`None`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{in})` + """ + + def __init__( + self, + in_channels: int, + phi: Optional[MessagePassing] = None, + num_iters: int = 1, + epsilon: float = 0.1, + gamma: float = 0.1, + act: Union[str, Callable, None] = 'tanh', + act_kwargs: Optional[Dict[str, Any]] = None, + bias: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.num_iters = num_iters + self.gamma = gamma + self.epsilon = epsilon + self.act = activation_resolver(act, **(act_kwargs or {})) + + if phi is None: + phi = GCNConv(in_channels, in_channels, has_bias=False) + + self.W = Parameter(ms.numpy.empty([in_channels, in_channels])) + self.eye = Parameter(mint.eye(in_channels), requires_grad=False) + self.phi = phi + + if bias: + self.bias = Parameter(ms.numpy.empty(in_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + kaiming_uniform(self.W, a=np.sqrt(5)) + self.phi.reset_parameters() + if self.bias is not None: + zeros(self.bias) + + def construct(self, x: Tensor, edge_index: Optional[Tensor], *args, **kwargs) -> Tensor: + r"""Runs the forward pass of the module.""" + antisymmetric_W = self.W - self.W.t() - self.gamma * self.eye + + for _ in range(self.num_iters): + h = self.phi(x, edge_index, *args, **kwargs) + h = x @ antisymmetric_W.t() + h + + if self.bias is not None: + h += self.bias + + if self.act is not None: + h = self.act(h) + + x = x + self.epsilon * h + + return x + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(' + f'{self.in_channels}, ' + f'phi={self.phi}, ' + f'num_iters={self.num_iters}, ' + f'epsilon={self.epsilon}, ' + f'gamma={self.gamma})') diff --git a/mindscience/sharker/nn/conv/appnp.py b/mindscience/sharker/nn/conv/appnp.py new file mode 100644 index 000000000..e88b6a740 --- /dev/null +++ b/mindscience/sharker/nn/conv/appnp.py @@ -0,0 +1,109 @@ +from typing import Optional, Union, Tuple +from mindspore import Tensor, ops +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm + + +class APPNP(MessagePassing): + r"""The approximate personalized propagation of neural predictions layer + from the `"Predict then Propagate: Graph Neural Networks meet Personalized + PageRank" `_ paper. + + .. math:: + \mathbf{X}^{(0)} &= \mathbf{X} + + \mathbf{X}^{(k)} &= (1 - \alpha) \mathbf{\hat{D}}^{-1/2} + \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \mathbf{X}^{(k-1)} + \alpha + \mathbf{X}^{(0)} + + \mathbf{X}^{\prime} &= \mathbf{X}^{(K)}, + + where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the + adjacency matrix with inserted self-loops and + :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. + The adjacency matrix can include other values than :obj:`1` representing + edge weights via the optional :obj:`edge_weight` tensor. + + Args: + K (int): Number of iterations :math:`K`. + alpha (float): Teleport probability :math:`\alpha`. + dropout (float, optional): Dropout probability of edges during + training. (default: :obj:`0`) + cached (bool, optional): If set to :obj:`True`, the layer will cache + the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the + cached version for further executions. + This parameter should only be set to :obj:`True` in transductive + learning scenarios. (default: :obj:`False`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + normalize (bool, optional): Whether to add self-loops and apply + symmetric normalization. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F)`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F)` + """ + _cached_edge_index: Optional[Tuple[Tensor, Optional[Tensor]]] + + def __init__(self, K: int, alpha: float, dropout: float = 0., + cached: bool = False, add_self_loops: bool = True, + normalize: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + self.K = K + self.alpha = alpha + self.dropout = dropout + self.cached = cached + self.add_self_loops = add_self_loops + self.normalize = normalize + + self._cached_edge_index = None + self._cached_adj_t = None + + def reset_parameters(self): + super().reset_parameters() + self._cached_edge_index = None + + def construct( + self, + x: Tensor, + edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None, + ) -> Tensor: + + if self.normalize: + if isinstance(edge_index, Tensor): + cache = self._cached_edge_index + if cache is None: + edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index, edge_weight, x.shape[self.node_dim], False, + self.add_self_loops, self.flow, dtype=x.dtype) + if self.cached: + self._cached_edge_index = (edge_index, edge_weight) + else: + edge_index, edge_weight = cache[0], cache[1] + + h = x + for k in range(self.K): + if self.dropout > 0 and self.training: + if isinstance(edge_index, Tensor): + assert edge_weight is not None + edge_weight = ops.dropout(edge_weight, p=self.dropout) + + x = self.propagate(edge_index, x=x, edge_weight=edge_weight) + x = x * (1 - self.alpha) + x = x + self.alpha * h + + return x + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(K={self.K}, alpha={self.alpha})' diff --git a/mindscience/sharker/nn/conv/arma_conv.py b/mindscience/sharker/nn/conv/arma_conv.py new file mode 100644 index 000000000..7d3920e8b --- /dev/null +++ b/mindscience/sharker/nn/conv/arma_conv.py @@ -0,0 +1,131 @@ +from typing import Callable, Optional, Union +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm +from ..inits import glorot, zeros + + +class ARMAConv(MessagePassing): + r"""The ARMA graph convolutional operator from the `"Graph Neural Networks + with Convolutional ARMA Filters" `_ + paper. + + .. math:: + \mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K \mathbf{X}_k^{(T)}, + + with :math:`\mathbf{X}_k^{(T)}` being recursively defined by + + .. math:: + \mathbf{X}_k^{(t+1)} = \sigma \left( \mathbf{\hat{L}} + \mathbf{X}_k^{(t)} \mathbf{W} + \mathbf{X}^{(0)} \mathbf{V} \right), + + where :math:`\mathbf{\hat{L}} = \mathbf{I} - \mathbf{L} = \mathbf{D}^{-1/2} + \mathbf{A} \mathbf{D}^{-1/2}` denotes the + modified Laplacian :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} + \mathbf{A} \mathbf{D}^{-1/2}`. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample + :math:`\mathbf{x}^{(t+1)}`. + num_stacks (int, optional): Number of parallel stacks :math:`K`. + (default: :obj:`1`). + num_layers (int, optional): Number of layers :math:`T`. + (default: :obj:`1`) + act (callable, optional): Activation function :math:`\sigma`. + (default: :meth:`nn.ReLU()`) + shared_weights (int, optional): If set to :obj:`True` the layers in + each stack will share the same parameters. (default: :obj:`False`) + dropout (float, optional): Dropout probability of the skip connection. + (default: :obj:`0.`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + + def __init__(self, in_channels: int, out_channels: int, + num_stacks: int = 1, num_layers: int = 1, + shared_weights: bool = False, + act: Optional[Callable] = nn.ReLU, dropout: float = 0., + bias: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_stacks = num_stacks + self.num_layers = num_layers + self.act = act() + self.shared_weights = shared_weights + self.dropout = dropout + + K, T, F_in, F_out = num_stacks, num_layers, in_channels, out_channels + T = 1 if self.shared_weights else T + + self.weight = Parameter(ms.numpy.empty([max(1, T - 1), K, F_out, F_out])) + + self.init_weight = Parameter(ms.numpy.empty([K, F_in, F_out])) + self.root_weight = Parameter(ms.numpy.empty([T, K, F_in, F_out])) + + if bias: + self.bias = Parameter(ms.numpy.empty([T, K, 1, F_out])) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + glorot(self.weight) + glorot(self.init_weight) + glorot(self.root_weight) + if self.bias is not None: + zeros(self.bias) + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None) -> Tensor: + + if isinstance(edge_index, Tensor): + edge_index, edge_weight = gcn_norm( + edge_index, edge_weight, x.shape[self.node_dim], + add_self_loops=False, flow=self.flow, dtype=x.dtype) + + x = x.unsqueeze(-3) + out = x + for t in range(self.num_layers): + if t == 0: + out = out @ self.init_weight + else: + out = out @ self.weight[0 if self.shared_weights else t - 1] + + out = self.propagate(edge_index, x=out, edge_weight=edge_weight) + + root = ops.dropout(x, p=self.dropout, training=self.training) + root = root @ self.root_weight[0 if self.shared_weights else t] + out += root + + if self.bias is not None: + out += self.bias[0 if self.shared_weights else t] + + if self.act is not None: + out = self.act(out) + + return out.mean(-3) + + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return edge_weight.view(-1, 1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, num_stacks={self.num_stacks}, ' + f'num_layers={self.num_layers})') diff --git a/mindscience/sharker/nn/conv/cg_conv copy.py b/mindscience/sharker/nn/conv/cg_conv copy.py new file mode 100644 index 000000000..da6ac8a27 --- /dev/null +++ b/mindscience/sharker/nn/conv/cg_conv copy.py @@ -0,0 +1,86 @@ +from typing import Tuple, Union, Optional +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing + + +class CGConv(MessagePassing): + r"""The crystal graph convolutional operator from the + `"Crystal Graph Convolutional Neural Networks for an + Accurate and Interpretable Prediction of Material Properties" + `_ + paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} + \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) + \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right) + + where :math:`\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, + \mathbf{e}_{i,j} ]` denotes the concatenation of central node features, + neighboring node features and edge features. + In addition, :math:`\sigma` and :math:`g` denote the sigmoid and softplus + functions, respectively. + + Args: + channels (int or tuple): Size of each input sample. A tuple + corresponds to the sizes of source and target dimensionalities. + dim (int, optional): Edge feature dimensionality. (default: :obj:`0`) + aggr (str, optional): The aggregation operator to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"add"`) + batch_norm (bool, optional): If set to :obj:`True`, will make use of + batch normalization. (default: :obj:`False`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F)` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F)` or + :math:`(|\mathcal{V_t}|, F_{t})` if bipartite + """ + + def __init__(self, channels: Union[int, Tuple[int, int]], dim: int = 0, + aggr: str = 'add', batch_norm: bool = False, + bias: bool = True, **kwargs): + super().__init__(aggr=aggr, **kwargs) + self.channels = channels + self.dim = dim + self.batch_norm = batch_norm + + if isinstance(channels, int): + channels = (channels, channels) + + self.lin_f = nn.Dense(sum(channels) + dim, channels[1], has_bias=bias) + self.lin_s = nn.Dense(sum(channels) + dim, channels[1], has_bias=bias) + if batch_norm: + self.bn = nn.BatchNorm1d(channels[1]) + else: + self.bn = None + + def construct(self, x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + out = self.propagate(edge_index, x=x, edge_attr=edge_attr) + out = out if self.bn is None else self.bn(out) + out += x[1] + return out + + def message(self, x_i, x_j, edge_attr: Optional[Tensor]) -> Tensor: + if edge_attr is None: + z = mint.cat(([x_i, x_j]), axis=-1) + else: + z = mint.cat(([x_i, x_j, edge_attr]), axis=-1) + return mint.sigmoid(self.lin_f(z)) * ops.softplus(self.lin_s(z)) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.channels}, dim={self.dim})' diff --git a/mindscience/sharker/nn/conv/cg_conv.py b/mindscience/sharker/nn/conv/cg_conv.py new file mode 100644 index 000000000..9b5ced71a --- /dev/null +++ b/mindscience/sharker/nn/conv/cg_conv.py @@ -0,0 +1,86 @@ +from typing import Tuple, Union, Optional +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing + + +class CGConv(MessagePassing): + r"""The crystal graph convolutional operator from the + `"Crystal Graph Convolutional Neural Networks for an + Accurate and Interpretable Prediction of Material Properties" + `_ + paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} + \sigma \left( \mathbf{z}_{i,j} \mathbf{W}_f + \mathbf{b}_f \right) + \odot g \left( \mathbf{z}_{i,j} \mathbf{W}_s + \mathbf{b}_s \right) + + where :math:`\mathbf{z}_{i,j} = [ \mathbf{x}_i, \mathbf{x}_j, + \mathbf{e}_{i,j} ]` denotes the concatenation of central node features, + neighboring node features and edge features. + In addition, :math:`\sigma` and :math:`g` denote the sigmoid and softplus + functions, respectively. + + Args: + channels (int or tuple): Size of each input sample. A tuple + corresponds to the sizes of source and target dimensionalities. + dim (int, optional): Edge feature dimensionality. (default: :obj:`0`) + aggr (str, optional): The aggregation operator to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"add"`) + batch_norm (bool, optional): If set to :obj:`True`, will make use of + batch normalization. (default: :obj:`False`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F)` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F)` or + :math:`(|\mathcal{V_t}|, F_{t})` if bipartite + """ + + def __init__(self, channels: Union[int, Tuple[int, int]], dim: int = 0, + aggr: str = 'add', batch_norm: bool = False, + bias: bool = True, **kwargs): + super().__init__(aggr=aggr, **kwargs) + self.channels = channels + self.dim = dim + self.batch_norm = batch_norm + + if isinstance(channels, int): + channels = (channels, channels) + + self.lin_f = nn.Dense(sum(channels) + dim, channels[1], has_bias=bias) + self.lin_s = nn.Dense(sum(channels) + dim, channels[1], has_bias=bias) + if batch_norm: + self.bn = nn.BatchNorm1d(channels[1]) + else: + self.bn = None + + def construct(self, x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + out = self.propagate(edge_index, x=x, edge_attr=edge_attr) + out = out if self.bn is None else self.bn(out) + out += x[1] + return out + + def message(self, x_i, x_j, edge_attr: Optional[Tensor]) -> Tensor: + if edge_attr is None: + z = mint.cat([x_i, x_j], dim=-1) + else: + z = mint.cat([x_i, x_j, edge_attr], dim=-1) + return self.lin_f(z).sigmoid() * ops.softplus(self.lin_s(z)) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.channels}, dim={self.dim})' diff --git a/mindscience/sharker/nn/conv/cheb_conv.py b/mindscience/sharker/nn/conv/cheb_conv.py new file mode 100644 index 000000000..e777ed984 --- /dev/null +++ b/mindscience/sharker/nn/conv/cheb_conv.py @@ -0,0 +1,181 @@ +from typing import Optional +import mindspore as ms +from mindspore import Tensor, ops, nn, Parameter, mint +from .message_passing import MessagePassing +from ..inits import zeros, glorot +from ...utils import get_laplacian + + +class ChebConv(MessagePassing): + r"""The chebyshev spectral graph convolutional operator from the + `"Convolutional Neural Networks on Graphs with Fast Localized Spectral + Filtering" `_ paper. + + .. math:: + \mathbf{X}^{\prime} = \sum_{k=1}^{K} \mathbf{Z}^{(k)} \cdot + \mathbf{\Theta}^{(k)} + + where :math:`\mathbf{Z}^{(k)}` is computed recursively by + + .. math:: + \mathbf{Z}^{(1)} &= \mathbf{X} + + \mathbf{Z}^{(2)} &= \mathbf{\hat{L}} \cdot \mathbf{X} + + \mathbf{Z}^{(k)} &= 2 \cdot \mathbf{\hat{L}} \cdot + \mathbf{Z}^{(k-1)} - \mathbf{Z}^{(k-2)} + + and :math:`\mathbf{\hat{L}}` denotes the scaled and normalized Laplacian + :math:`\frac{2\mathbf{L}}{\lambda_{\max}} - \mathbf{I}`. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + K (int): Chebyshev filter size :math:`K`. + normalization (str, optional): The normalization scheme for the graph + Laplacian (default: :obj:`"sym"`): + + 1. :obj:`None`: No normalization + :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` + + 2. :obj:`"sym"`: Symmetric normalization + :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} + \mathbf{D}^{-1/2}` + + 3. :obj:`"rw"`: Random-walk normalization + :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` + + :obj:`\lambda_max` should be a :class:`Tensor` of size + :obj:`[num_graphs]` in a mini-batch scenario and a + scalar/zero-dimensional tensor when operating on single graphs. + You can pre-compute :obj:`lambda_max` via the + :class:`sharker.transforms.LaplacianLambdaMax` transform. + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)*, + batch vector :math:`(|\mathcal{V}|)` *(optional)*, + maximum :obj:`lambda` value :math:`(|\mathcal{G}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + K: int, + normalization: Optional[str] = 'sym', + bias: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + assert K > 0 + assert normalization in [None, 'sym', 'rw'], 'Invalid normalization' + + self.in_channels = in_channels + self.out_channels = out_channels + self.normalization = normalization + self.lins = nn.CellList([ + nn.Dense(in_channels, out_channels, has_bias=False) for _ in range(K) + ]) + + if bias: + self.bias = Parameter(mint.zeros(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + for lin in self.lins: + glorot(lin.weight) + if self.bias is not None: + zeros(self.bias) + + def __norm__( + self, + edge_index: Tensor, + num_nodes: Optional[int], + edge_weight: Optional[Tensor], + normalization: Optional[str], + lambda_max: Optional[Tensor] = None, + dtype: Optional[int] = None, + batch: Optional[Tensor] = None, + ): + edge_index, edge_weight = get_laplacian(edge_index, edge_weight, + normalization, dtype, + num_nodes) + assert edge_weight is not None + + if lambda_max is None: + lambda_max = 2.0 * edge_weight.max() + elif not isinstance(lambda_max, Tensor): + lambda_max = ms.Tensor(lambda_max, dtype=dtype) + assert lambda_max is not None + + if batch is not None and lambda_max.numel() > 1: + lambda_max = lambda_max[batch[edge_index[0]]] + + edge_weight = (2.0 * edge_weight) / lambda_max + edge_weight[edge_weight == float('inf')] = 0 + + loop_mask = edge_index[0] == edge_index[1] + edge_weight[loop_mask] -= 1 + + return edge_index, edge_weight + + def construct( + self, + x: Tensor, + edge_index: Tensor, + edge_weight: Optional[Tensor] = None, + batch: Optional[Tensor] = None, + lambda_max: Optional[Tensor] = None, + ) -> Tensor: + + edge_index, norm = self.__norm__( + edge_index, + x.shape[self.node_dim], + edge_weight, + self.normalization, + lambda_max, + dtype=x.dtype, + batch=batch, + ) + + Tx_0 = x + Tx_1 = x + out = self.lins[0](Tx_0) + + if len(self.lins) > 1: + Tx_1 = self.propagate(edge_index, x=x, norm=norm) + out += self.lins[1](Tx_1) + + for lin in self.lins[2:]: + Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm) + Tx_2 = 2. * Tx_2 - Tx_0 + out += lin(Tx_2) + Tx_0, Tx_1 = Tx_1, Tx_2 + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, norm: Tensor) -> Tensor: + return norm.view(-1, 1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, K={len(self.lins)}, ' + f'normalization={self.normalization})') diff --git a/mindscience/sharker/nn/conv/cluster_gcn_conv.py b/mindscience/sharker/nn/conv/cluster_gcn_conv.py new file mode 100644 index 000000000..98018ea4c --- /dev/null +++ b/mindscience/sharker/nn/conv/cluster_gcn_conv.py @@ -0,0 +1,89 @@ +from typing import Union, Optional +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ..inits import glorot +from ...utils import ( + add_self_loops, + degree, + remove_self_loops,) + + +class ClusterGCNConv(MessagePassing): + r"""The ClusterGCN graph convolutional operator from the + `"Cluster-GCN: An Efficient Algorithm for Training Deep and Large Graph + Convolutional Networks" `_ paper. + + .. math:: + \mathbf{X}^{\prime} = \left( \mathbf{\hat{A}} + \lambda \cdot + \textrm{diag}(\mathbf{\hat{A}}) \right) \mathbf{X} \mathbf{W}_1 + + \mathbf{X} \mathbf{W}_2 + + where :math:`\mathbf{\hat{A}} = {(\mathbf{D} + \mathbf{I})}^{-1}(\mathbf{A} + + \mathbf{I})`. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + diag_lambda (float, optional): Diagonal enhancement value + :math:`\lambda`. (default: :obj:`0.`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + + def __init__(self, in_channels: int, out_channels: int, + diag_lambda: float = 0., add_self_loops: bool = True, + bias: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.diag_lambda = diag_lambda + self.add_self_loops = add_self_loops + + self.lin_out = nn.Dense(in_channels, out_channels, has_bias=bias) + self.lin_root = nn.Dense(in_channels, out_channels, has_bias=False) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + glorot(self.lin_out.weight) + glorot(self.lin_root.weight) + + def construct(self, x: Tensor, edge_index: Union[Tensor, ]) -> Tensor: + num_nodes = x.shape[self.node_dim] + edge_weight: Optional[Tensor] = None + + if self.add_self_loops: + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) + + row, col = edge_index[0], edge_index[1] + deg_inv = 1. / degree(col, num_nodes=num_nodes).clamp(1.) + + edge_weight = deg_inv[col] + edge_weight[row == col] += self.diag_lambda * deg_inv + + out = self.propagate(edge_index, x=x, edge_weight=edge_weight) + out = self.lin_out(out) + self.lin_root(x) + + return out + + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return edge_weight.view(-1, 1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, diag_lambda={self.diag_lambda})') diff --git a/mindscience/sharker/nn/conv/dir_gnn_conv.py b/mindscience/sharker/nn/conv/dir_gnn_conv.py new file mode 100644 index 000000000..a141a7254 --- /dev/null +++ b/mindscience/sharker/nn/conv/dir_gnn_conv.py @@ -0,0 +1,70 @@ +import copy +from mindspore import Tensor, ops, nn +from .message_passing import MessagePassing + + +class DirGNNConv(nn.Cell): + r"""A generic wrapper for computing graph convolution on directed + graphs as described in the `"Edge Directionality Improves Learning on + Heterophilic Graphs" `_ paper. + :class:`DirGNNConv` will pass messages both from source nodes to target + nodes and from target nodes to source nodes. + + Args: + conv (MessagePassing): The underlying + :class:`~sharker.nn.conv.MessagePassing` layer to use. + alpha (float, optional): The alpha coefficient used to weight the + aggregations of in- and out-edges as part of a convex combination. + (default: :obj:`0.5`) + root_weight (bool, optional): If set to :obj:`True`, the layer will add + transformed root node features to the output. + (default: :obj:`True`) + """ + + def __init__( + self, + conv: MessagePassing, + alpha: float = 0.5, + root_weight: bool = True, + ): + super().__init__() + + self.alpha = alpha + self.root_weight = root_weight + + self.conv_in = copy.deepcopy(conv) + self.conv_out = copy.deepcopy(conv) + + if hasattr(conv, 'add_self_loops'): + self.conv_in.add_self_loops = False + self.conv_out.add_self_loops = False + if hasattr(conv, 'root_weight'): + self.conv_in.root_weight = False + self.conv_out.root_weight = False + + if root_weight: + self.lin = nn.Dense(conv.in_channels, conv.out_channels) + else: + self.lin = None + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + self.conv_in.reset_parameters() + self.conv_out.reset_parameters() + + def construct(self, x: Tensor, edge_index: Tensor) -> Tensor: + """""" # noqa: D419 + x_in = self.conv_in(x, edge_index) + x_out = self.conv_out(x, edge_index.flip([0])) + + out = self.alpha * x_out + (1 - self.alpha) * x_in + + if self.root_weight: + out += self.lin(x) + + return out + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.conv_in}, alpha={self.alpha})' diff --git a/mindscience/sharker/nn/conv/dna_conv.py b/mindscience/sharker/nn/conv/dna_conv.py new file mode 100644 index 000000000..005ca2c5b --- /dev/null +++ b/mindscience/sharker/nn/conv/dna_conv.py @@ -0,0 +1,295 @@ +import math +from typing import Optional, Union, Tuple +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm +from ..inits import kaiming_uniform, uniform + + +class Linear(nn.Cell): + def __init__(self, in_channels, out_channels, groups=1, bias=True): + super().__init__() + assert in_channels % groups == 0 and out_channels % groups == 0 + + self.in_channels = in_channels + self.out_channels = out_channels + self.groups = groups + + self.weight = Parameter( + ms.numpy.empty([groups, in_channels // groups, out_channels // groups])) + + if bias: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + kaiming_uniform(self.weight, a=math.sqrt(5)) + uniform(self.weight.shape[1], self.bias) + + def construct(self, src): + + + if self.groups > 1: + size = src.shape[:-1] + src = src.view(-1, self.groups, self.in_channels // self.groups) + src = src.swapaxes(0, 1) + out = src @ self.weight + out = out.swapaxes(1, 0) + out = out.view(size + (self.out_channels, )) + else: + out = src @ self.weight.squeeze(0) + + if self.bias is not None: + out += self.bias + + return out + + def __repr__(self) -> str: # pragma: no cover + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, groups={self.groups})') + + +def restricted_softmax(src, dim: int = -1, margin: float = 0.): + src_max = mint.clamp(src.max(axis=dim, keepdims=True)[0], min=0.) + out = (src - src_max).exp() + out = out / (out.sum(dim, keepdims=True) + (margin - src_max).exp()) + return out + + +class Attention(nn.Cell): + def __init__(self, dropout=0): + super().__init__() + self.dropout = dropout + + def construct(self, query, key, value): + return self.compute_attention(query, key, value) + + def compute_attention(self, query, key, value): + # query: [*, query_entries, dim_k] + # key: [*, key_entries, dim_k] + # value: [*, key_entries, dim_v] + # Output: [*, query_entries, dim_v] + + assert query.dim() == key.dim() == value.dim() >= 2 + assert query.shape[-1] == key.shape[-1] + assert key.shape[-2] == value.shape[-2] + + # Score: [*, query_entries, key_entries] + score = query @ key.swapaxes(-2, -1) + score = score / math.sqrt(key.shape[-1]) + score = restricted_softmax(score, dim=-1) + score = ops.dropout(score, p=self.dropout, training=self.training) + + return score @ value + + def __repr__(self) -> str: # pragma: no cover + return f'{self.__class__.__name__}(dropout={self.dropout})' + + +class MultiHead(Attention): + def __init__(self, in_channels, out_channels, heads=1, groups=1, dropout=0, + bias=True): + super().__init__(dropout) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.groups = groups + self.bias = bias + + assert in_channels % heads == 0 and out_channels % heads == 0 + assert in_channels % groups == 0 and out_channels % groups == 0 + assert max(groups, self.heads) % min(groups, self.heads) == 0 + + self.lin_q = Linear(in_channels, out_channels, groups, bias) + self.lin_k = Linear(in_channels, out_channels, groups, bias) + self.lin_v = Linear(in_channels, out_channels, groups, bias) + + self.reset_parameters() + + def reset_parameters(self): + self.lin_q.reset_parameters() + self.lin_k.reset_parameters() + self.lin_v.reset_parameters() + + def construct(self, query, key, value): + # query: [*, query_entries, in_channels] + # key: [*, key_entries, in_channels] + # value: [*, key_entries, in_channels] + # Output: [*, query_entries, out_channels] + + assert query.dim() == key.dim() == value.dim() >= 2 + assert query.shape[-1] == key.shape[-1] == value.shape[-1] + assert key.shape[-2] == value.shape[-2] + + query = self.lin_q(query) + key = self.lin_k(key) + value = self.lin_v(value) + + # query: [*, heads, query_entries, out_channels // heads] + # key: [*, heads, key_entries, out_channels // heads] + # value: [*, heads, key_entries, out_channels // heads] + size = query.shape[:-2] + out_channels_per_head = self.out_channels // self.heads + + query_size = size + \ + (query.shape[-2], self.heads, out_channels_per_head) + query = query.view(query_size).swapaxes(-2, -3) + + key_size = size + (key.shape[-2], self.heads, out_channels_per_head) + key = key.view(key_size).swapaxes(-2, -3) + + value_size = size + \ + (value.shape[-2], self.heads, out_channels_per_head) + value = value.view(value_size).swapaxes(-2, -3) + + # Output: [*, heads, query_entries, out_channels // heads] + out = self.compute_attention(query, key, value) + # Output: [*, query_entries, heads, out_channels // heads] + out = out.swapaxes(-3, -2) + # Output: [*, query_entries, out_channels] + out = out.view(size + (query.shape[-2], self.out_channels)) + + return out + + def __repr__(self) -> str: # pragma: no cover + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads}, ' + f'groups={self.groups}, dropout={self.droput}, ' + f'bias={self.bias})') + + +class DNAConv(MessagePassing): + r"""The dynamic neighborhood aggregation operator from the `"Just Jump: + Towards Dynamic Neighborhood Aggregation in Graph Neural Networks" + `_ paper. + + .. math:: + \mathbf{x}_v^{(t)} = h_{\mathbf{\Theta}}^{(t)} \left( \mathbf{x}_{v + \leftarrow v}^{(t)}, \left\{ \mathbf{x}_{v \leftarrow w}^{(t)} : w \in + \mathcal{N}(v) \right\} \right) + + based on (multi-head) dot-product attention + + .. math:: + \mathbf{x}_{v \leftarrow w}^{(t)} = \textrm{Attention} \left( + \mathbf{x}^{(t-1)}_v \, \mathbf{\Theta}_Q^{(t)}, [\mathbf{x}_w^{(1)}, + \ldots, \mathbf{x}_w^{(t-1)}] \, \mathbf{\Theta}_K^{(t)}, \, + [\mathbf{x}_w^{(1)}, \ldots, \mathbf{x}_w^{(t-1)}] \, + \mathbf{\Theta}_V^{(t)} \right) + + with :math:`\mathbf{\Theta}_Q^{(t)}, \mathbf{\Theta}_K^{(t)}, + \mathbf{\Theta}_V^{(t)}` denoting (grouped) projection matrices for query, + key and value information, respectively. + :math:`h^{(t)}_{\mathbf{\Theta}}` is implemented as a non-trainable + version of :class:`sharker.nn.conv.GCNConv`. + + .. note:: + In contrast to other layers, this operator expects node features as + shape :obj:`[num_nodes, num_layers, channels]`. + + Args: + channels (int): Size of each input/output sample. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + groups (int, optional): Number of groups to use for all linear + projections. (default: :obj:`1`) + dropout (float, optional): Dropout probability of attention + coefficients. (default: :obj:`0.`) + cached (bool, optional): If set to :obj:`True`, the layer will cache + the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the + cached version for further executions. + This parameter should only be set to :obj:`True` in transductive + learning scenarios. (default: :obj:`False`) + normalize (bool, optional): Whether to add self-loops and apply + symmetric normalization. (default: :obj:`True`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, L, F)` where :math:`L` is the + number of layers, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F)` + """ + + _cached_edge_index: Optional[Tuple[Tensor, Optional[Tensor]]] + + def __init__(self, channels: int, heads: int = 1, groups: int = 1, + dropout: float = 0., cached: bool = False, + normalize: bool = True, add_self_loops: bool = True, + bias: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(node_dim=0, **kwargs) + + self.bias = bias + self.cached = cached + self.normalize = normalize + self.add_self_loops = add_self_loops + + self._cached_edge_index = None + self._cached_adj_t = None + + self.multi_head = MultiHead(channels, channels, heads, groups, dropout, + bias) + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + self.multi_head.reset_parameters() + self._cached_edge_index = None + self._cached_adj_t = None + + def construct( + self, + x: Tensor, + edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None, + ) -> Tensor: + r"""Runs the forward pass of the module. + + Args: + x (Tensor): The input node features of shape + :obj:`[num_nodes, num_layers, channels]`. + edge_index (Tensor or SparseTensor): The edge indices. + edge_weight (Tensor, optional): The edge weights. + (default: :obj:`None`) + """ + if x.dim() != 3: + raise ValueError('Feature shape must be [num_nodes, num_layers, ' + 'channels].') + + if self.normalize: + if isinstance(edge_index, Tensor): + cache = self._cached_edge_index + if cache is None: + edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index, edge_weight, x.shape[self.node_dim], False, + self.add_self_loops, self.flow, dtype=x.dtype) + if self.cached: + self._cached_edge_index = (edge_index, edge_weight) + else: + edge_index, edge_weight = cache[0], cache[1] + + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + return self.propagate(edge_index, x=x, edge_weight=edge_weight) + + def message(self, x_i: Tensor, x_j: Tensor, edge_weight: Tensor) -> Tensor: + x_i = x_i[:, -1:] # [num_edges, 1, channels] + out = self.multi_head(x_i, x_j, x_j) # [num_edges, 1, channels] + return edge_weight.view(-1, 1) * out.squeeze(1) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.multi_head.in_channels}, ' + f'heads={self.multi_head.heads}, ' + f'groups={self.multi_head.groups})') diff --git a/mindscience/sharker/nn/conv/edge_conv.py b/mindscience/sharker/nn/conv/edge_conv.py new file mode 100644 index 000000000..ff3c26798 --- /dev/null +++ b/mindscience/sharker/nn/conv/edge_conv.py @@ -0,0 +1,138 @@ +from typing import Callable, Optional, Union, Tuple +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ..inits import reset +from ...utils.cluster import knn +import mindspore as ms + + +class EdgeConv(MessagePassing): + r"""The edge convolutional operator from the `"Dynamic Graph CNN for + Learning on Point Clouds" `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} + h_{\mathbf{\Theta}}(\mathbf{x}_i \, \Vert \, + \mathbf{x}_j - \mathbf{x}_i), + + where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* a MLP. + + Args: + nn (nn.CellList): A neural network :math:`h_{\mathbf{\Theta}}` that + maps pair-wise concatenated node features :obj:`x` of shape + :obj:`[-1, 2 * in_channels]` to shape :obj:`[-1, out_channels]`, + *e.g.*, defined by :class:`nn.SequentialCell`. + aggr (str, optional): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"max"`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, nn: Callable, aggr: str = 'max', **kwargs): + super().__init__(aggr=aggr, **kwargs) + self.nn = nn + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + reset(self.nn) + + def construct(self, x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, ]) -> Tensor: + if isinstance(x, Tensor): + x = (x, x) + # propagate_type: (x: Tuple[Tensor, Tensor]) + return self.propagate(edge_index, x=x) + + def message(self, x_i: Tensor, x_j: Tensor) -> Tensor: + return self.nn(mint.cat(([x_i, x_j - x_i]), dim=-1)) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(nn={self.nn})' + + +class DynamicEdgeConv(MessagePassing): + r"""The dynamic edge convolutional operator from the `"Dynamic Graph CNN + for Learning on Point Clouds" `_ paper + (see :class:`sharker.nn.conv.EdgeConv`), where the graph is + dynamically constructed using nearest neighbors in the feature space. + + Args: + nn (nn.CellList): A neural network :math:`h_{\mathbf{\Theta}}` that + maps pair-wise concatenated node features :obj:`x` of shape + `:obj:`[-1, 2 * in_channels]` to shape :obj:`[-1, out_channels]`, + *e.g.* defined by :class:`nn.Sequential`. + k (int): Number of nearest neighbors. + aggr (str, optional): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"max"`) + num_workers (int): Number of workers to use for k-NN computation. + Has no effect in case :obj:`batch` is not :obj:`None`, or the input + lies on the GPU. (default: :obj:`1`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V}|, F_{in}), (|\mathcal{V}|, F_{in}))` + if bipartite, + batch vector :math:`(|\mathcal{V}|)` or + :math:`((|\mathcal{V}|), (|\mathcal{V}|))` + if bipartite *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, nn: Callable, k: int, aggr: str = 'max', + num_workers: int = 1, **kwargs): + super().__init__(aggr=aggr, flow='src_to_trg', **kwargs) + + if knn is None: + raise ImportError('`DynamicEdgeConv` requires `mindspore-cluster`.') + + self.nn = nn + self.k = k + self.num_workers = num_workers + self.reset_parameters() + + def reset_parameters(self): + reset(self.nn) + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + batch: Union[Optional[Tensor], Optional[Tuple[Tensor, Tensor]]] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + if x[0].dim() != 2: + raise ValueError("Static graphs not supported in DynamicEdgeConv") + + b: Tuple[Optional[Tensor], Optional[Tensor]] = (None, None) + if isinstance(batch, Tensor): + b = (batch, batch) + elif isinstance(batch, tuple): + assert batch is not None + b = (batch[0], batch[1]) + edge_index = knn(x[0], x[1], self.k, b[0], b[1]).flip([0]) + # propagate_type: (x: Tuple[Tensor, Tensor]) + + return self.propagate(edge_index, x=x) + + def message(self, x_i: Tensor, x_j: Tensor) -> Tensor: + return self.nn(mint.cat(([x_i, x_j - x_i]), dim=-1)) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(nn={self.nn}, k={self.k})' diff --git a/mindscience/sharker/nn/conv/eg_conv.py b/mindscience/sharker/nn/conv/eg_conv.py new file mode 100644 index 000000000..c46b8b10e --- /dev/null +++ b/mindscience/sharker/nn/conv/eg_conv.py @@ -0,0 +1,200 @@ +from typing import List, Optional, Tuple, Union +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm +from ..inits import zeros, glorot +from ...utils import add_remaining_self_loops, scatter + + +class EGConv(MessagePassing): + r"""The Efficient Graph Convolution from the `"Adaptive Filters and + Aggregator Fusion for Efficient Graph Convolutions" + `_ paper. + + Its node-wise formulation is given by: + + .. math:: + \mathbf{x}_i^{\prime} = {\LARGE ||}_{h=1}^H \sum_{\oplus \in + \mathcal{A}} \sum_{b = 1}^B w_{i, h, \oplus, b} \; + \underset{j \in \mathcal{N}(i) \cup \{i\}}{\bigoplus} + \mathbf{W}_b \mathbf{x}_{j} + + with :math:`\mathbf{W}_b` denoting a basis weight, + :math:`\oplus` denoting an aggregator, and :math:`w` denoting per-vertex + weighting coefficients across different heads, bases and aggregators. + + EGC retains :math:`\mathcal{O}(|\mathcal{V}|)` memory usage, making it a + sensible alternative to :class:`~sharker.nn.conv.GCNConv`, + :class:`~sharker.nn.conv.SAGEConv` or + :class:`~sharker.nn.conv.GINConv`. + + .. note:: + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + aggregators (List[str], optional): Aggregators to be used. + Supported aggregators are :obj:`"sum"`, :obj:`"mean"`, + :obj:`"symnorm"`, :obj:`"max"`, :obj:`"min"`, :obj:`"std"`, + :obj:`"var"`. + Multiple aggregators can be used to improve the performance. + (default: :obj:`["symnorm"]`) + num_heads (int, optional): Number of heads :math:`H` to use. Must have + :obj:`out_channels % num_heads == 0`. It is recommended to set + :obj:`num_heads >= num_bases`. (default: :obj:`8`) + num_bases (int, optional): Number of basis weights :math:`B` to use. + (default: :obj:`4`) + cached (bool, optional): If set to :obj:`True`, the layer will cache + the computation of the edge index with added self loops on first + execution, along with caching the calculation of the symmetric + normalized edge weights if the :obj:`"symnorm"` aggregator is + being used. This parameter should only be set to :obj:`True` in + transductive learning scenarios. (default: :obj:`False`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + + _cached_edge_index: Optional[Tuple[Tensor, Optional[Tensor]]] + # _cached_adj_t: Optional[SparseTensor] + + def __init__( + self, + in_channels: int, + out_channels: int, + aggregators: List[str] = ['symnorm'], + num_heads: int = 8, + num_bases: int = 4, + cached: bool = False, + add_self_loops: bool = True, + bias: bool = True, + **kwargs, + ): + super().__init__(node_dim=0, **kwargs) + + if out_channels % num_heads != 0: + raise ValueError(f"'out_channels' (got {out_channels}) must be " + f"divisible by the number of heads " + f"(got {num_heads})") + + for a in aggregators: + if a not in ['sum', 'mean', 'symnorm', 'min', 'max', 'var', 'std']: + raise ValueError(f"Unsupported aggregator: '{a}'") + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_heads = num_heads + self.num_bases = num_bases + self.cached = cached + self.add_self_loops = add_self_loops + self.aggregators = aggregators + + self.bases_lin = nn.Dense(in_channels, + (out_channels // num_heads) * num_bases, + has_bias=False) + self.comb_lin = nn.Dense( + in_channels, num_heads * num_bases * len(aggregators)) + + if bias: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + if self.bias is not None: + zeros(self.bias) + glorot(self.bases_lin.weight) + self._cached_adj_t = None + self._cached_edge_index = None + + def construct(self, x: Tensor, edge_index: Union[Tensor, ]) -> Tensor: + symnorm_weight: Optional[Tensor] = None + if "symnorm" in self.aggregators: + if isinstance(edge_index, Tensor): + cache = self._cached_edge_index + if cache is None: + edge_index, symnorm_weight = gcn_norm( # yapf: disable + edge_index, None, num_nodes=x.shape[self.node_dim], + improved=False, add_self_loops=self.add_self_loops, + flow=self.flow, dtype=x.dtype) + if self.cached: + self._cached_edge_index = (edge_index, symnorm_weight) + else: + edge_index, symnorm_weight = cache + + elif self.add_self_loops: + if isinstance(edge_index, Tensor): + cache = self._cached_edge_index + if self.cached and cache is not None: + edge_index = cache[0] + else: + edge_index, _ = add_remaining_self_loops(edge_index) + if self.cached: + self._cached_edge_index = (edge_index, None) + + bases = self.bases_lin(x) + weightings = self.comb_lin(x) + + aggregated = self.propagate(edge_index, x=bases, + symnorm_weight=symnorm_weight) + + weightings = weightings.view(-1, self.num_heads, + self.num_bases * len(self.aggregators)) + aggregated = aggregated.view( + -1, + len(self.aggregators) * self.num_bases, + self.out_channels // self.num_heads, + ) + + out = weightings @ aggregated + out = out.view(-1, self.out_channels) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor) -> Tensor: + return x_j + + def aggregate(self, inputs: Tensor, index: Tensor, + dim_size: Optional[int] = None, + symnorm_weight: Optional[Tensor] = None) -> Tensor: + + outs = [] + for aggr in self.aggregators: + if aggr == 'symnorm': + assert symnorm_weight is not None + out = scatter(inputs * symnorm_weight.view(-1, 1), index, 0, + dim_size, reduce='sum') + elif aggr == 'var' or aggr == 'std': + mean = scatter(inputs, index, 0, dim_size, reduce='mean') + mean_squares = scatter(inputs * inputs, index, 0, dim_size, + reduce='mean') + out = mean_squares - mean * mean + if aggr == 'std': + out = out.clamp(min=1e-5).sqrt() + else: + out = scatter(inputs, index, 0, dim_size, reduce=aggr) + + outs.append(out) + + return mint.stack((outs), dim=1) if len(outs) > 1 else outs[0] + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, aggregators={self.aggregators})') diff --git a/mindscience/sharker/nn/conv/fa_conv.py b/mindscience/sharker/nn/conv/fa_conv.py new file mode 100644 index 000000000..3c90e8572 --- /dev/null +++ b/mindscience/sharker/nn/conv/fa_conv.py @@ -0,0 +1,162 @@ +from typing import Optional, Tuple, Union +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm +from ...utils import is_sparse_tensor + + +class FAConv(MessagePassing): + r"""The Frequency Adaptive Graph Convolution operator from the + `"Beyond Low-Frequency Information in Graph Convolutional Networks" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i= \epsilon \cdot \mathbf{x}^{(0)}_i + + \sum_{j \in \mathcal{N}(i)} \frac{\alpha_{i,j}}{\sqrt{d_i d_j}} + \mathbf{x}_{j} + + where :math:`\mathbf{x}^{(0)}_i` and :math:`d_i` denote the initial feature + representation and node degree of node :math:`i`, respectively. + The attention coefficients :math:`\alpha_{i,j}` are computed as + + .. math:: + \mathbf{\alpha}_{i,j} = \textrm{tanh}(\mathbf{a}^{\top}[\mathbf{x}_i, + \mathbf{x}_j]) + + based on the trainable parameter vector :math:`\mathbf{a}`. + + Args: + channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + eps (float, optional): :math:`\epsilon`-value. (default: :obj:`0.1`) + dropout (float, optional): Dropout probability of the normalized + coefficients which exposes each node to a stochastically + sampled neighborhood during training. (default: :obj:`0`). + cached (bool, optional): If set to :obj:`True`, the layer will cache + the computation of :math:`\sqrt{d_i d_j}` on first execution, and + will use the cached version for further executions. + This parameter should only be set to :obj:`True` in transductive + learning scenarios. (default: :obj:`False`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + normalize (bool, optional): Whether to add self-loops (if + :obj:`add_self_loops` is :obj:`True`) and compute + symmetric normalization coefficients on the fly. + If set to :obj:`False`, :obj:`edge_weight` needs to be provided in + the layer's :meth:`forward` method. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F)`, + initial node features :math:`(|\mathcal{V}|, F)`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F)` or + :math:`((|\mathcal{V}|, F), ((2, |\mathcal{E}|), + (|\mathcal{E}|)))` if :obj:`return_attention_weights=True` + """ + _cached_edge_index: Optional[Tuple[Tensor, Optional[Tensor]]] + + _alpha: Optional[Tensor] + + def __init__(self, channels: int, eps: float = 0.1, dropout: float = 0.0, + cached: bool = False, add_self_loops: bool = True, + normalize: bool = True, **kwargs): + + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.channels = channels + self.eps = eps + self.dropout = dropout + self.cached = cached + self.add_self_loops = add_self_loops + self.normalize = normalize + + self._cached_edge_index = None + self._cached_adj_t = None + self._alpha = None + + self.att_l = nn.Dense(channels, 1, has_bias=False) + self.att_r = nn.Dense(channels, 1, has_bias=False) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + self._cached_edge_index = None + self._cached_adj_t = None + + def construct( # noqa: F811 + self, + x: Tensor, + x_0: Tensor, + edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None, + return_attention_weights: Optional[bool] = None, + ) -> Union[ + Tensor, + Tuple[Tensor, Tuple[Tensor, Tensor]], + Tuple[Tensor, ], + ]: + r"""Runs the forward pass of the module. + + Args: + x (Tensor): The node features. + x_0 (Tensor): The initial input node features. + edge_index (Tensor or SparseTensor): The edge indices. + edge_weight (Tensor, optional): The edge weights. + (default: :obj:`None`) + return_attention_weights (bool, optional): If set to :obj:`True`, + will additionally return the tuple + :obj:`(edge_index, attention_weights)`, holding the computed + attention weights for each edge. (default: :obj:`None`) + """ + if self.normalize: + if isinstance(edge_index, Tensor): + assert edge_weight is None + cache = self._cached_edge_index + if cache is None: + edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index, None, x.shape[self.node_dim], False, + self.add_self_loops, self.flow, dtype=x.dtype) + if self.cached: + self._cached_edge_index = (edge_index, edge_weight) + else: + edge_index, edge_weight = cache[0], cache[1] + + else: + if isinstance(edge_index, Tensor) and not is_sparse_tensor(edge_index): + assert edge_weight is not None + + alpha_l = self.att_l(x) + alpha_r = self.att_r(x) + + out = self.propagate(edge_index, x=x, alpha=(alpha_l, alpha_r), + edge_weight=edge_weight) + + alpha = self._alpha + self._alpha = None + + if self.eps != 0.0: + out += self.eps * x_0 + + if isinstance(return_attention_weights, bool): + assert alpha is not None + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + else: + return out + + def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: Tensor, + edge_weight: Optional[Tensor]) -> Tensor: + assert edge_weight is not None + alpha = (alpha_j + alpha_i).tanh().squeeze(-1) + self._alpha = alpha + alpha = ops.dropout(alpha, p=self.dropout, training=self.training) + return x_j * (alpha * edge_weight).view(-1, 1) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.channels}, eps={self.eps})' diff --git a/mindscience/sharker/nn/conv/feast_conv.py b/mindscience/sharker/nn/conv/feast_conv.py new file mode 100644 index 000000000..ec324ed9b --- /dev/null +++ b/mindscience/sharker/nn/conv/feast_conv.py @@ -0,0 +1,103 @@ +from typing import Union, Tuple +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from .message_passing import MessagePassing +from ..inits import normal, xavier_uniform +from ...utils import add_self_loops, remove_self_loops + + +class FeaStConv(MessagePassing): + r"""The (translation-invariant) feature-steered convolutional operator from + the `"FeaStNet: Feature-Steered Graph Convolutions for 3D Shape Analysis" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} + \sum_{j \in \mathcal{N}(i)} \sum_{h=1}^H + q_h(\mathbf{x}_i, \mathbf{x}_j) \mathbf{W}_h \mathbf{x}_j + + with :math:`q_h(\mathbf{x}_i, \mathbf{x}_j) = \mathrm{softmax}_j + (\mathbf{u}_h^{\top} (\mathbf{x}_j - \mathbf{x}_i) + c_h)`, where :math:`H` + denotes the number of attention heads, and :math:`\mathbf{W}_h`, + :math:`\mathbf{u}_h` and :math:`c_h` are trainable parameters. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + heads (int, optional): Number of attention heads :math:`H`. + (default: :obj:`1`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V_t}|, F_{out})` if bipartite + """ + + def __init__(self, in_channels: int, out_channels: int, heads: int = 1, + add_self_loops: bool = True, bias: bool = True, **kwargs): + kwargs.setdefault('aggr', 'mean') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.add_self_loops = add_self_loops + + self.lin = nn.Dense(in_channels, heads * out_channels, has_bias=False) + self.u = nn.Dense(in_channels, heads, has_bias=False) + self.c = Parameter(ms.numpy.empty(heads)) + + if bias: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + xavier_uniform(self.lin.weight) + xavier_uniform(self.u.weight) + normal(self.c, mean=0, std=0.1) + if self.bias is not None: + normal(self.bias, mean=0, std=0.1) + + def construct(self, x: Union[Tensor, Tuple[Tensor, Tensor]], edge_index: Union[Tensor, ]) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + if self.add_self_loops: + if isinstance(edge_index, Tensor): + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index, + num_nodes=x[1].shape[0]) + + # propagate_type: (x: Tuple[Tensor, Tensor]) + out = self.propagate(edge_index, x=x) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_i: Tensor, x_j: Tensor) -> Tensor: + q = self.u(x_j - x_i) + self.c # Translation invariance. + q = ops.softmax(q, axis=1) + x_j = self.lin(x_j).view(x_j.shape[0], self.heads, -1) + return (x_j * q.view(-1, self.heads, 1)).sum(1) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads})') diff --git a/mindscience/sharker/nn/conv/film_conv.py b/mindscience/sharker/nn/conv/film_conv.py new file mode 100644 index 000000000..3ca450675 --- /dev/null +++ b/mindscience/sharker/nn/conv/film_conv.py @@ -0,0 +1,138 @@ +import copy +from typing import Callable, Optional, Tuple, Union +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ..inits import reset + + +class FiLMConv(MessagePassing): + r"""The FiLM graph convolutional operator from the + `"GNN-FiLM: Graph Neural Networks with Feature-wise nn.Dense Modulation" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \sum_{r \in \mathcal{R}} + \sum_{j \in \mathcal{N}(i)} \sigma \left( + \boldsymbol{\gamma}_{r,i} \odot \mathbf{W}_r \mathbf{x}_j + + \boldsymbol{\beta}_{r,i} \right) + + where :math:`\boldsymbol{\beta}_{r,i}, \boldsymbol{\gamma}_{r,i} = + g(\mathbf{x}_i)` with :math:`g` being a single nn.Dense layer by default. + Self-loops are automatically added to the input graph and represented as + its own relation type. + + .. note:: + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + num_relations (int, optional): Number of relations. (default: :obj:`1`) + nn (nn.CellList, optional): The neural network :math:`g` that + maps node features :obj:`x_i` of shape + :obj:`[-1, in_channels]` to shape :obj:`[-1, 2 * out_channels]`. + If set to :obj:`None`, :math:`g` will be implemented as a single + nn.Dense layer. (default: :obj:`None`) + act (callable, optional): Activation function :math:`\sigma`. + (default: :meth:`nn.ReLU()`) + aggr (str, optional): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"mean"`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge types :math:`(|\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V_t}|, F_{out})` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + num_relations: int = 1, + net: Optional[Callable] = None, + act: Optional[Callable] = nn.ReLU, + aggr: str = 'mean', + **kwargs, + ): + super().__init__(aggr=aggr, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_relations = max(num_relations, 1) + self.act = act() + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lins = nn.CellList() + self.films = nn.CellList() + for _ in range(num_relations): + self.lins.append(nn.Dense(in_channels[0], out_channels, has_bias=False)) + if net is None: + film = nn.Dense(in_channels[1], 2 * out_channels) + else: + film = copy.deepcopy(net) + self.films.append(film) + + self.lin_skip = nn.Dense(in_channels[1], self.out_channels, has_bias=False) + if net is None: + self.film_skip = nn.Dense(in_channels[1], 2 * self.out_channels, + has_bias=False) + else: + self.film_skip = copy.deepcopy(net) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + reset(self.film_skip) + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + edge_index: Union[Tensor, ], + edge_type: Optional[Tensor] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + beta, gamma = self.film_skip(x[1]).split(self.out_channels, axis=-1) + out = gamma * self.lin_skip(x[1]) + beta + if self.act is not None: + out = self.act(out) + + # propagate_type: (x: Tensor, beta: Tensor, gamma: Tensor) + if self.num_relations <= 1: + beta, gamma = self.films[0](x[1]).split(self.out_channels, axis=-1) + out += self.propagate(edge_index, x=self.lins[0](x[0]), + beta=beta, gamma=gamma) + else: + for i, (lin, film) in enumerate(zip(self.lins, self.films)): + beta, gamma = film(x[1]).split(self.out_channels, axis=-1) + assert edge_type is not None + mask = edge_type == i + out += self.propagate(edge_index[:, mask], x=lin( + x[0]), beta=beta, gamma=gamma) + + return out + + def message(self, x_j: Tensor, beta_i: Tensor, gamma_i: Tensor) -> Tensor: + out = gamma_i * x_j + beta_i + if self.act is not None: + out = self.act(out) + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, num_relations={self.num_relations})') diff --git a/mindscience/sharker/nn/conv/gat_conv.py b/mindscience/sharker/nn/conv/gat_conv.py new file mode 100644 index 000000000..d62091d5c --- /dev/null +++ b/mindscience/sharker/nn/conv/gat_conv.py @@ -0,0 +1,310 @@ +from typing import Optional, Tuple, Union +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops +from .message_passing import MessagePassing +from ..inits import glorot, zeros +from ...utils import ( + add_self_loops, + remove_self_loops, + softmax, +) + + +class GATConv(MessagePassing): + r"""The graph attentional operator from the `"Graph Attention Networks" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}_{s}\mathbf{x}_{i} + + \sum_{j \in \mathcal{N}(i)} + \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j}, + + where the attention coefficients :math:`\alpha_{i,j}` are computed as + + .. math:: + \alpha_{i,j} = + \frac{ + \exp\left(\mathrm{LeakyReLU}\left( + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j + \right)\right)} + {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} + \exp\left(\mathrm{LeakyReLU}\left( + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + + \mathbf{a}^{\top}_{t}\mathbf{\Theta}_{t}\mathbf{x}_k + \right)\right)}. + + If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, + the attention coefficients :math:`\alpha_{i,j}` are computed as + + .. math:: + \alpha_{i,j} = + \frac{ + \exp\left(\mathrm{LeakyReLU}\left( + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_j + + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,j} + \right)\right)} + {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} + \exp\left(\mathrm{LeakyReLU}\left( + \mathbf{a}^{\top}_{s} \mathbf{\Theta}_{s}\mathbf{x}_i + + \mathbf{a}^{\top}_{t} \mathbf{\Theta}_{t}\mathbf{x}_k + + \mathbf{a}^{\top}_{e} \mathbf{\Theta}_{e} \mathbf{e}_{i,k} + \right)\right)}. + + If the graph is not bipartite, :math:`\mathbf{\Theta}_{s} = + \mathbf{\Theta}_{t}`. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities in case of a bipartite graph. + out_channels (int): Size of each output sample. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the multi-head + attentions are averaged instead of concatenated. + (default: :obj:`True`) + negative_slope (float, optional): LeakyReLU angle of the negative + slope. (default: :obj:`0.2`) + dropout (float, optional): Dropout probability of the normalized + attention coefficients which exposes each node to a stochastically + sampled neighborhood during training. (default: :obj:`0`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + edge_dim (int, optional): Edge feature dimensionality (in case + there are any). (default: :obj:`None`) + fill_value (float or Tensor or str, optional): The way to + generate edge features of self-loops (in case + :obj:`edge_dim != None`). + If given as :obj:`float` or :class:`Tensor`, edge features of + self-loops will be directly given by :obj:`fill_value`. + If given as :obj:`str`, edge features of self-loops are computed by + aggregating all features of edges that point to the specific node, + according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, + :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or + :math:`((|\mathcal{V}_t|, H * F_{out})` if bipartite. + If :obj:`return_attention_weights=True`, then + :math:`((|\mathcal{V}|, H * F_{out}), + ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))` + or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), + (|\mathcal{E}|, H)))` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + negative_slope: float = 0.2, + dropout: float = 0.0, + add_self_loops: bool = True, + edge_dim: Optional[int] = None, + fill_value: Union[float, Tensor, str] = 'mean', + bias: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super().__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.concat = concat + self.negative_slope = negative_slope + self.dropout = dropout + self.add_self_loops = add_self_loops + self.edge_dim = edge_dim + self.fill_value = fill_value + + # In case we are operating in bipartite graphs, we apply separate + # transformations 'lin_src' and 'lin_dst' to source and target nodes: + self.lin = self.lin_src = self.lin_dst = None + if isinstance(in_channels, int): + self.lin = nn.Dense(in_channels, heads * out_channels, has_bias=False) + else: + self.lin_src = nn.Dense(in_channels[0], heads * out_channels, False) + self.lin_dst = nn.Dense(in_channels[1], heads * out_channels, False) + + # The learnable parameters to compute attention coefficients: + self.att_src = Parameter(ms.numpy.empty([1, heads, out_channels])) + self.att_dst = Parameter(ms.numpy.empty([1, heads, out_channels])) + + if edge_dim is not None: + self.lin_edge = nn.Dense(edge_dim, heads * out_channels, has_bias=False) + self.att_edge = Parameter(ms.numpy.empty([1, heads, out_channels])) + else: + self.lin_edge = None + self.att_edge = None + + if bias and concat: + self.bias = Parameter(ms.numpy.empty(heads * out_channels)) + elif bias and not concat: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + if self.lin is not None: + glorot(self.lin.weight) + if self.lin_src is not None: + glorot(self.lin_src.weight) + if self.lin_dst is not None: + glorot(self.lin_dst.weight) + if self.lin_edge is not None: + glorot(self.lin_edge.weight) + glorot(self.att_src) + glorot(self.att_dst) + glorot(self.att_edge) + if self.bias is not None: + zeros(self.bias) + + def construct( # noqa: F811 + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, + size: Tuple[int, ...] = None, + return_attention_weights: Optional[bool] = None, + ) -> Union[ + Tensor, + Tuple[Tensor, Tuple[Tensor, Tensor]], + Tuple[Tensor, ], + ]: + r"""Runs the forward pass of the module. + + Args: + x (Tensor or (Tensor, Tensor)): The input node + features. + edge_index (Tensor or SparseTensor): The edge indices. + edge_attr (Tensor, optional): The edge features. + (default: :obj:`None`) + size ((int, int), optional): The shape of the adjacency matrix. + (default: :obj:`None`) + return_attention_weights (bool, optional): If set to :obj:`True`, + will additionally return the tuple + :obj:`(edge_index, attention_weights)`, holding the computed + attention weights for each edge. (default: :obj:`None`) + """ + + H, C = self.heads, self.out_channels + + # We first transform the input node features. If a tuple is passed, we + # transform source and target node features via separate weights: + if isinstance(x, Tensor): + assert x.dim() == 2, "Static graphs not supported in 'GATConv'" + + if self.lin is not None: + x_src = x_dst = self.lin(x).view(-1, H, C) + else: + # If the module is initialized as bipartite, transform source + # and destination node features separately: + assert self.lin_src is not None and self.lin_dst is not None + x_src = self.lin_src(x).view(-1, H, C) + x_dst = self.lin_dst(x).view(-1, H, C) + + else: # Tuple of source and target node features: + x_src, x_dst = x + assert x_src.dim() == 2, "Static graphs not supported in 'GATConv'" + + if self.lin is not None: + # If the module is initialized as non-bipartite, we expect that + # source and destination node features have the same shape and + # that they their transformations are shared: + x_src = self.lin(x_src).view(-1, H, C) + if x_dst is not None: + x_dst = self.lin(x_dst).view(-1, H, C) + else: + assert self.lin_src is not None and self.lin_dst is not None + + x_src = self.lin_src(x_src).view(-1, H, C) + if x_dst is not None: + x_dst = self.lin_dst(x_dst).view(-1, H, C) + + x = (x_src, x_dst) + + # Next, we compute node-level attention coefficients, both for source + # and target nodes (if present): + alpha_src = (x_src * self.att_src).sum(-1) + alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1) + alpha = (alpha_src, alpha_dst) + + if self.add_self_loops: + if isinstance(edge_index, Tensor): + # We only want to add self-loops for nodes that appear both as + # source and target nodes: + num_nodes = x_src.shape[0] + if x_dst is not None: + num_nodes = min(num_nodes, x_dst.shape[0]) + num_nodes = min(size) if size is not None else num_nodes + edge_index, edge_attr = remove_self_loops( + edge_index, edge_attr) + edge_index, edge_attr = add_self_loops( + edge_index, edge_attr, fill_value=self.fill_value, + num_nodes=num_nodes) + + alpha = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr, + size=size) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], alpha: Tensor) + out = self.propagate(edge_index, x=x, alpha=alpha, shape=size) + + if self.concat: + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(1) + + if self.bias is not None: + out += self.bias + + if isinstance(return_attention_weights, bool): + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + else: + return out + + def edge_update(self, alpha_j: Tensor, alpha_i: Optional[Tensor], + edge_attr: Optional[Tensor], index: Tensor, ptr: Optional[Tensor], + dim_size: Optional[int]) -> Tensor: + # Given edge-level attention coefficients for source and target nodes, + # we simply need to sum them up to "emulate" concatenation: + alpha = alpha_j if alpha_i is None else alpha_j + alpha_i + if index.numel() == 0: + return alpha + if edge_attr is not None and self.lin_edge is not None: + if edge_attr.dim() == 1: + edge_attr = edge_attr.view(-1, 1) + edge_attr = self.lin_edge(edge_attr) + edge_attr = edge_attr.view(-1, self.heads, self.out_channels) + alpha_edge = (edge_attr * self.att_edge).sum(-1) + alpha = alpha + alpha_edge + + alpha = ops.leaky_relu(alpha, self.negative_slope) + alpha = softmax(alpha, index, ptr, dim_size) + alpha = ops.dropout(alpha, p=self.dropout, training=self.training) + return alpha + + def message(self, x_j: Tensor, alpha: Tensor) -> Tensor: + return alpha.unsqueeze(-1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads})') diff --git a/mindscience/sharker/nn/conv/gated_graph_conv.py b/mindscience/sharker/nn/conv/gated_graph_conv.py new file mode 100644 index 000000000..d783618e8 --- /dev/null +++ b/mindscience/sharker/nn/conv/gated_graph_conv.py @@ -0,0 +1,86 @@ +from typing import Union, Optional +from mindspore import Tensor, Parameter, nn, ops, mint +from ..inits import uniform +from .message_passing import MessagePassing + + +class GatedGraphConv(MessagePassing): + r"""The gated graph convolution operator from the `"Gated Graph Sequence + Neural Networks" `_ paper. + + .. math:: + \mathbf{h}_i^{(0)} &= \mathbf{x}_i \, \Vert \, \mathbf{0} + + \mathbf{m}_i^{(l+1)} &= \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot + \mathbf{\Theta} \cdot \mathbf{h}_j^{(l)} + + \mathbf{h}_i^{(l+1)} &= \textrm{GRU} (\mathbf{m}_i^{(l+1)}, + \mathbf{h}_i^{(l)}) + + up to representation :math:`\mathbf{h}_i^{(L)}`. + The number of input channels of :math:`\mathbf{x}_i` needs to be less or + equal than :obj:`out_channels`. + :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target + node :obj:`i` (default: :obj:`1`) + + Args: + out_channels (int): Size of each output sample. + num_layers (int): The sequence length :math:`L`. + aggr (str, optional): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"add"`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + + """ + + def __init__(self, out_channels: int, num_layers: int, aggr: str = 'add', + bias: bool = True, **kwargs): + super().__init__(aggr=aggr, **kwargs) + + self.out_channels = out_channels + self.num_layers = num_layers + + self.weight = Parameter(mint.zeros([num_layers, out_channels, out_channels])) + self.rnn = nn.GRUCell(out_channels, out_channels, has_bias=bias) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + uniform(self.out_channels, self.weight) + self.rnn.reset_parameters() + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None) -> Tensor: + + if x.shape[-1] > self.out_channels: + raise ValueError('The number of input channels is not allowed to ' + 'be larger than the number of output channels') + + if x.shape[-1] < self.out_channels: + zero = mint.zeros([x.shape[0], self.out_channels - x.shape[-1]], dtype=x.dtype) + x = mint.cat(([x, zero]), dim=1) + + for i in range(self.num_layers): + m = x @ self.weight[i] + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + m = self.propagate(edge_index, x=m, edge_weight=edge_weight) + x = self.rnn(m, x) + + return x + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]): + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.out_channels}, ' + f'num_layers={self.num_layers})') diff --git a/mindscience/sharker/nn/conv/gatv2_conv.py b/mindscience/sharker/nn/conv/gatv2_conv.py new file mode 100644 index 000000000..7fc2de503 --- /dev/null +++ b/mindscience/sharker/nn/conv/gatv2_conv.py @@ -0,0 +1,280 @@ +from typing import Optional, Tuple, Union +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from .message_passing import MessagePassing +from ..inits import glorot, zeros +from ...utils import ( + add_self_loops, + remove_self_loops, + softmax, +) + + +class GATv2Conv(MessagePassing): + r"""The GATv2 operator from the `"How Attentive are Graph Attention + Networks?" `_ paper, which fixes the + static attention problem of the standard + :class:`~sharker.conv.GATConv` layer. + Since the nn.Dense layers in the standard GAT are applied right after each + other, the ranking of attended nodes is unconditioned on the query node. + In contrast, in :class:`GATv2`, every node can attend to any other node. + + .. math:: + \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}_{s}\mathbf{x}_{i} + + \sum_{j \in \mathcal{N}(i)} + \alpha_{i,j}\mathbf{\Theta}_{t}\mathbf{x}_{j}, + + where the attention coefficients :math:`\alpha_{i,j}` are computed as + + .. math:: + \alpha_{i,j} = + \frac{ + \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( + \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_j + \right)\right)} + {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} + \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( + \mathbf{\Theta}_{s} \mathbf{x}_i + \mathbf{\Theta}_{t} \mathbf{x}_k + \right)\right)}. + + If the graph has multi-dimensional edge features :math:`\mathbf{e}_{i,j}`, + the attention coefficients :math:`\alpha_{i,j}` are computed as + + .. math:: + \alpha_{i,j} = + \frac{ + \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( + \mathbf{\Theta}_{s} \mathbf{x}_i + + \mathbf{\Theta}_{t} \mathbf{x}_j + + \mathbf{\Theta}_{e} \mathbf{e}_{i,j} + \right)\right)} + {\sum_{k \in \mathcal{N}(i) \cup \{ i \}} + \exp\left(\mathbf{a}^{\top}\mathrm{LeakyReLU}\left( + \mathbf{\Theta}_{s} \mathbf{x}_i + + \mathbf{\Theta}_{t} \mathbf{x}_k + + \mathbf{\Theta}_{e} \mathbf{e}_{i,k}] + \right)\right)}. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities in case of a bipartite graph. + out_channels (int): Size of each output sample. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the multi-head + attentions are averaged instead of concatenated. + (default: :obj:`True`) + negative_slope (float, optional): LeakyReLU angle of the negative + slope. (default: :obj:`0.2`) + dropout (float, optional): Dropout probability of the normalized + attention coefficients which exposes each node to a stochastically + sampled neighborhood during training. (default: :obj:`0`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + edge_dim (int, optional): Edge feature dimensionality (in case + there are any). (default: :obj:`None`) + fill_value (float or Tensor or str, optional): The way to + generate edge features of self-loops + (in case :obj:`edge_dim != None`). + If given as :obj:`float` or :class:`Tensor`, edge features of + self-loops will be directly given by :obj:`fill_value`. + If given as :obj:`str`, edge features of self-loops are computed by + aggregating all features of edges that point to the specific node, + according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, + :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`"mean"`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + share_weights (bool, optional): If set to :obj:`True`, the same matrix + will be applied to the source and the target node of every edge, + *i.e.* :math:`\mathbf{\Theta}_{s} = \mathbf{\Theta}_{t}`. + (default: :obj:`False`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, H * F_{out})` or + :math:`((|\mathcal{V}_t|, H * F_{out})` if bipartite. + If :obj:`return_attention_weights=True`, then + :math:`((|\mathcal{V}|, H * F_{out}), + ((2, |\mathcal{E}|), (|\mathcal{E}|, H)))` + or :math:`((|\mathcal{V_t}|, H * F_{out}), ((2, |\mathcal{E}|), + (|\mathcal{E}|, H)))` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + negative_slope: float = 0.2, + dropout: float = 0.0, + add_self_loops: bool = True, + edge_dim: Optional[int] = None, + fill_value: Union[float, Tensor, str] = 'mean', + bias: bool = True, + share_weights: bool = False, + **kwargs, + ): + super().__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.concat = concat + self.negative_slope = negative_slope + self.dropout = dropout + self.add_self_loops = add_self_loops + self.edge_dim = edge_dim + self.fill_value = fill_value + self.share_weights = share_weights + + if isinstance(in_channels, int): + self.lin_l = nn.Dense(in_channels, heads * out_channels, has_bias=bias) + if share_weights: + self.lin_r = self.lin_l + else: + self.lin_r = nn.Dense(in_channels, heads * out_channels, has_bias=bias) + else: + self.lin_l = nn.Dense(in_channels[0], heads * out_channels, has_bias=bias) + if share_weights: + self.lin_r = self.lin_l + else: + self.lin_r = nn.Dense(in_channels[1], heads * out_channels, has_bias=bias) + + self.att = Parameter(ms.numpy.empty([1, heads, out_channels])) + + if edge_dim is not None: + self.lin_edge = nn.Dense(edge_dim, heads * out_channels, has_bias=False) + else: + self.lin_edge = None + + if bias and concat: + self.bias = Parameter(ms.numpy.empty(heads * out_channels)) + elif bias and not concat: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + glorot(self.lin_l.weight) + glorot(self.lin_r.weight) + if self.lin_edge is not None: + glorot(self.lin_edge.weight) + glorot(self.att) + if self.bias is not None: + zeros(self.bias) + + def construct( # noqa: F811 + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, + return_attention_weights: Optional[bool] = None, + ) -> Union[ + Tensor, + Tuple[Tensor, Tuple[Tensor, Tensor]], + Tuple[Tensor, ], + ]: + r"""Runs the forward pass of the module. + + Args: + x (Tensor or (Tensor, Tensor)): The input node + features. + edge_index (Tensor or SparseTensor): The edge indices. + edge_attr (Tensor, optional): The edge features. + (default: :obj:`None`) + return_attention_weights (bool, optional): If set to :obj:`True`, + will additionally return the tuple + :obj:`(edge_index, attention_weights)`, holding the computed + attention weights for each edge. (default: :obj:`None`) + """ + H, C = self.heads, self.out_channels + + x_l: Optional[Tensor] = None + x_r: Optional[Tensor] = None + if isinstance(x, Tensor): + assert x.dim() == 2 + x_l = self.lin_l(x).view(-1, H, C) + if self.share_weights: + x_r = x_l + else: + x_r = self.lin_r(x).view(-1, H, C) + else: + x_l, x_r = x[0], x[1] + assert x[0].dim() == 2 + x_l = self.lin_l(x_l).view(-1, H, C) + if x_r is not None: + x_r = self.lin_r(x_r).view(-1, H, C) + + assert x_l is not None + assert x_r is not None + + if self.add_self_loops: + if isinstance(edge_index, Tensor): + num_nodes = x_l.shape[0] + if x_r is not None: + num_nodes = min(num_nodes, x_r.shape[0]) + edge_index, edge_attr = remove_self_loops( + edge_index, edge_attr) + edge_index, edge_attr = add_self_loops( + edge_index, edge_attr, fill_value=self.fill_value, + num_nodes=num_nodes) + + # edge_updater_type: (x: Tuple[Tensor, Tensor], edge_attr: Optional[Tensor]) + alpha = self.edge_updater(edge_index, x=(x_l, x_r), + edge_attr=edge_attr) + + # propagate_type: (x: Tuple[Tensor, Tensor], alpha: Tensor) + out = self.propagate(edge_index, x=(x_l, x_r), alpha=alpha) + + if self.concat: + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(1) + + if self.bias is not None: + out += self.bias + + if isinstance(return_attention_weights, bool): + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + else: + return out + + def edge_update(self, x_j: Tensor, x_i: Tensor, edge_attr: Optional[Tensor], + index: Tensor, ptr: Optional[Tensor], + dim_size: Optional[int]) -> Tensor: + x = x_i + x_j + + if edge_attr is not None: + if edge_attr.dim() == 1: + edge_attr = edge_attr.view(-1, 1) + assert self.lin_edge is not None + edge_attr = self.lin_edge(edge_attr) + edge_attr = edge_attr.view(-1, self.heads, self.out_channels) + x = x + edge_attr + + x = ops.leaky_relu(x, self.negative_slope) + alpha = (x * self.att).sum(-1) + alpha = softmax(alpha, index, ptr, dim_size) + alpha = ops.dropout(alpha, p=self.dropout, training=self.training) + return alpha + + def message(self, x_j: Tensor, alpha: Tensor) -> Tensor: + return x_j * alpha.unsqueeze(-1) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads})') diff --git a/mindscience/sharker/nn/conv/gcn_conv.py b/mindscience/sharker/nn/conv/gcn_conv.py new file mode 100644 index 000000000..8084c1e8f --- /dev/null +++ b/mindscience/sharker/nn/conv/gcn_conv.py @@ -0,0 +1,183 @@ +from typing import Optional, Union, Tuple +import mindspore as ms +from mindspore import Tensor, Parameter, ops, nn, mint +from .message_passing import MessagePassing +from ..inits import zeros, glorot +from ...utils import ( + add_remaining_self_loops, + scatter, + maybe_num_nodes +) + + +def gcn_norm( # noqa: F811 + edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None, + num_nodes: Optional[int] = None, + improved: bool = False, + add_self_loops: bool = True, + flow: str = "src_to_trg", + dtype: Optional[ms.Type] = None, +): + fill_value = 2. if improved else 1. + + assert flow in ['src_to_trg', 'trg_to_src'] + num_nodes = maybe_num_nodes(edge_index, num_nodes) + if add_self_loops: + edge_index, edge_weight = add_remaining_self_loops( + edge_index, edge_weight, fill_value, num_nodes) + + if edge_weight is None: + edge_weight = ops.ones((edge_index.shape[1], ), dtype=dtype) + + row, col = edge_index[0], edge_index[1] + idx = col if flow == 'src_to_trg' else row + + deg = scatter(edge_weight, idx, dim=0, dim_size=num_nodes, reduce='sum') + + deg_inv_sqrt = deg.pow(-0.5) + deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0 + edge_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col] + return edge_index, edge_weight + + +class GCNConv(MessagePassing): + r"""The graph convolutional operator from the `"Semi-supervised + Classification with Graph Convolutional Networks" + `_ paper. + + .. math:: + \mathbf{X}^{\prime} = \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2} \mathbf{X} \mathbf{\Theta}, + + where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the + adjacency matrix with inserted self-loops and + :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. + The adjacency matrix can include other values than :obj:`1` representing + edge weights via the optional :obj:`edge_weight` tensor. + + Its node-wise formulation is given by: + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{\Theta}^{\top} \sum_{j \in + \mathcal{N}(i) \cup \{ i \}} \frac{e_{j,i}}{\sqrt{\hat{d}_j + \hat{d}_i}} \mathbf{x}_j + + with :math:`\hat{d}_i = 1 + \sum_{j \in \mathcal{N}(i)} e_{j,i}`, where + :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to target + node :obj:`i` (default: :obj:`1.0`) + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + improved (bool, optional): If set to :obj:`True`, the layer computes + :math:`\mathbf{\hat{A}}` as :math:`\mathbf{A} + 2\mathbf{I}`. + (default: :obj:`False`) + cached (bool, optional): If set to :obj:`True`, the layer will cache + the computation of :math:`\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2}` on first execution, and will use the + cached version for further executions. + This parameter should only be set to :obj:`True` in transductive + learning scenarios. (default: :obj:`False`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. By default, self-loops will be added + in case :obj:`normalize` is set to :obj:`True`, and not added + otherwise. (default: :obj:`None`) + normalize (bool, optional): Whether to add self-loops and compute + symmetric normalization coefficients on-the-fly. + (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)` + or sparse matrix :math:`(|\mathcal{V}|, |\mathcal{V}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + _cached_edge_index: Optional[Tuple[Tensor, Optional[Tensor]]] + + def __init__( + self, + in_channels: int, + out_channels: int, + improved: bool = False, + cached: bool = False, + add_self_loops: Optional[bool] = None, + normalize: bool = True, + has_bias: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + if add_self_loops is None: + add_self_loops = normalize + + if add_self_loops and not normalize: + raise ValueError(f"'{self.__class__.__name__}' does not support " + f"adding self-loops to the graph when no " + f"on-the-fly normalization is applied") + + self.in_channels = in_channels + self.out_channels = out_channels + self.improved = improved + self.cached = cached + self.add_self_loops = add_self_loops + self.normalize = normalize + + self._cached_edge_index = None + + self.lin = nn.Dense(in_channels, out_channels, has_bias=False) + + if has_bias: + self.bias = Parameter(mint.zeros(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + glorot(self.lin.weight) + if self.bias is not None: + zeros(self.bias) + self._cached_edge_index = None + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None) -> Tensor: + + if isinstance(x, (tuple, list)): + raise ValueError(f"'{self.__class__.__name__}' received a tuple " + f"of node features as input while this layer " + f"does not support bipartite message passing. " + f"Please try other layers such as 'SAGEConv' or " + f"'GraphConv' instead") + if self.normalize: + if isinstance(edge_index, Tensor): + cache = self._cached_edge_index + if cache is None: + edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index, edge_weight, x.shape[self.node_dim], + self.improved, self.add_self_loops, self.flow, x.dtype) + if self.cached: + self._cached_edge_index = (edge_index, edge_weight) + else: + edge_index, edge_weight = cache[0], cache[1] + + x = self.lin(x) + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_weight=edge_weight) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j diff --git a/mindscience/sharker/nn/conv/gen_conv.py b/mindscience/sharker/nn/conv/gen_conv.py new file mode 100644 index 000000000..c7b0df939 --- /dev/null +++ b/mindscience/sharker/nn/conv/gen_conv.py @@ -0,0 +1,217 @@ +from typing import List, Optional, Tuple, Union +from mindspore import Tensor, nn, ops, mint +from ..aggr import Aggregation, MultiAggregation +from .message_passing import MessagePassing +from ..inits import reset +from ..norm import MessageNorm + + +class MLP(nn.SequentialCell): + def __init__(self, channels: List[int], norm: Optional[str] = None, + has_bias: bool = True, dropout: float = 0.1): + m = [] + for i in range(1, len(channels)): + m.append(nn.Dense(channels[i - 1], channels[i], has_bias=has_bias)) + + if i < len(channels) - 1: + if norm and norm == 'batch': + m.append(nn.BatchNorm1d(channels[i], affine=True)) + elif norm and norm == 'layer': + m.append(nn.LayerNorm(channels[i], elementwise_affine=True)) + elif norm and norm == 'instance': + m.append(nn.InstanceNorm1d(channels[i], affine=False)) + elif norm: + raise NotImplementedError( + f'Normalization layer "{norm}" not supported.') + m.append(nn.ReLU()) + m.append(nn.Dropout(dropout)) + + super().__init__(*m) + + +class GENConv(MessagePassing): + r"""The GENeralized Graph Convolution (GENConv) from the `"DeeperGCN: All + You Need to Train Deeper GCNs" `_ paper. + + :class:`GENConv` supports both :math:`\textrm{softmax}` (see + :class:`~sharker.nn.aggr.SoftmaxAggregation`) and + :math:`\textrm{powermean}` (see + :class:`~sharker.nn.aggr.PowerMeanAggregation`) aggregation. + Its message construction is given by: + + .. math:: + \mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_i + + \mathrm{AGG} \left( \left\{ + \mathrm{ReLU} \left( \mathbf{x}_j + \mathbf{e_{ji}} \right) +\epsilon + : j \in \mathcal{N}(i) \right\} \right) + \right) + + .. note:: + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + aggr (str or Aggregation, optional): The aggregation scheme to use. + Any aggregation of :obj:`sharker.nn.aggr` can be used, + (:obj:`"softmax"`, :obj:`"powermean"`, :obj:`"add"`, :obj:`"mean"`, + :obj:`max`). (default: :obj:`"softmax"`) + t (float, optional): Initial inverse temperature for softmax + aggregation. (default: :obj:`1.0`) + learn_t (bool, optional): If set to :obj:`True`, will learn the value + :obj:`t` for softmax aggregation dynamically. + (default: :obj:`False`) + p (float, optional): Initial power for power mean aggregation. + (default: :obj:`1.0`) + learn_p (bool, optional): If set to :obj:`True`, will learn the value + :obj:`p` for power mean aggregation dynamically. + (default: :obj:`False`) + msg_norm (bool, optional): If set to :obj:`True`, will use message + normalization. (default: :obj:`False`) + learn_msg_scale (bool, optional): If set to :obj:`True`, will learn the + scaling factor of message normalization. (default: :obj:`False`) + norm (str, optional): Norm layer of MLP layers (:obj:`"batch"`, + :obj:`"layer"`, :obj:`"instance"`) (default: :obj:`batch`) + num_layers (int, optional): The number of MLP layers. + (default: :obj:`2`) + expansion (int, optional): The expansion factor of hidden channels in + MLP layers. (default: :obj:`2`) + eps (float, optional): The epsilon value of the message construction + function. (default: :obj:`1e-7`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + edge_dim (int, optional): Edge feature dimensionality. If set to + :obj:`None`, Edge feature dimensionality is expected to match + the `out_channels`. Other-wise, edge features are linearly + transformed to match `out_channels` of node feature dimensionality. + (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.GenMessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge attributes :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + aggr: Optional[Union[str, List[str], Aggregation]] = 'softmax', + t: float = 1.0, + learn_t: bool = False, + p: float = 1.0, + learn_p: bool = False, + msg_norm: bool = False, + learn_msg_scale: bool = False, + norm: str = 'batch', + num_layers: int = 2, + expansion: int = 2, + eps: float = 1e-7, + bias: bool = False, + edge_dim: Optional[int] = None, + **kwargs, + ): + + # Backward compatibility + aggr = 'softmax' if aggr == 'softmax_sg' else aggr + aggr = 'powermean' if aggr == 'power' else aggr + + # Override args of aggregator if `aggr_kwargs` is specified + if 'aggr_kwargs' not in kwargs: + if aggr == 'softmax': + kwargs['aggr_kwargs'] = dict(t=t, learn=learn_t) + elif aggr == 'powermean': + kwargs['aggr_kwargs'] = dict(p=p, learn=learn_p) + + super().__init__(aggr=aggr, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.eps = eps + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + if in_channels[0] != out_channels: + self.lin_src = nn.Dense(in_channels[0], out_channels, has_bias=bias) + + if edge_dim is not None and edge_dim != out_channels: + self.lin_edge = nn.Dense(edge_dim, out_channels, has_bias=bias) + + if isinstance(self.aggr_module, MultiAggregation): + aggr_out_channels = self.aggr_module.get_out_channels(out_channels) + else: + aggr_out_channels = out_channels + + if aggr_out_channels != out_channels: + self.lin_aggr_out = nn.Dense(aggr_out_channels, out_channels, + has_bias=bias) + + if in_channels[1] != out_channels: + self.lin_dst = nn.Dense(in_channels[1], out_channels, has_bias=bias) + + channels = [out_channels] + for i in range(num_layers - 1): + channels.append(out_channels * expansion) + channels.append(out_channels) + self.mlp = MLP(channels, norm=norm, has_bias=bias) + + if msg_norm: + self.msg_norm = MessageNorm(learn_msg_scale) + + def reset_parameters(self): + super().reset_parameters() + reset(self.mlp) + if hasattr(self, 'msg_norm'): + self.msg_norm.reset_parameters() + + def construct(self, x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, shape: Tuple[int, ...] = None) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + if hasattr(self, 'lin_src'): + x = (self.lin_src(x[0]), x[1]) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_attr: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_attr=edge_attr, shape=shape) + + if hasattr(self, 'lin_aggr_out'): + out = self.lin_aggr_out(out) + + if hasattr(self, 'msg_norm'): + h = x[1] if x[1] is not None else x[0] + assert h is not None + out = self.msg_norm(h, out) + + x_dst = x[1] + if x_dst is not None: + if hasattr(self, 'lin_dst'): + x_dst = self.lin_dst(x_dst) + out += x_dst + + return self.mlp(out) + + def message(self, x_j: Tensor, edge_attr: Optional[Tensor]) -> Tensor: + if edge_attr is not None and hasattr(self, 'lin_edge'): + edge_attr = self.lin_edge(edge_attr) + + if edge_attr is not None: + assert x_j.shape[-1] == edge_attr.shape[-1] + + msg = x_j if edge_attr is None else x_j + edge_attr + return ops.relu(msg) + self.eps + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, aggr={self.aggr})') diff --git a/mindscience/sharker/nn/conv/general_conv.py b/mindscience/sharker/nn/conv/general_conv.py new file mode 100644 index 000000000..0f82aeed3 --- /dev/null +++ b/mindscience/sharker/nn/conv/general_conv.py @@ -0,0 +1,172 @@ +from typing import Tuple, Union, Optional +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from .message_passing import MessagePassing +from ..inits import glorot +from ...utils import softmax + + +class GeneralConv(MessagePassing): + r"""A general GNN layer adapted from the `"Design Space for Graph Neural + Networks" `_ paper. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + in_edge_channels (int, optional): Size of each input edge. + (default: :obj:`None`) + aggr (str, optional): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"mean"`) + skip_linear (bool, optional): Whether apply linear function in skip + connection. (default: :obj:`False`) + directed_msg (bool, optional): If message passing is directed; + otherwise, message passing is bi-directed. (default: :obj:`True`) + heads (int, optional): Number of message passing ensembles. + If :obj:`heads > 1`, the GNN layer will output an ensemble of + multiple messages. + If attention is used (:obj:`attention=True`), this corresponds to + multi-head attention. (default: :obj:`1`) + attention (bool, optional): Whether to add attention to message + computation. (default: :obj:`False`) + attention_type (str, optional): Type of attention: :obj:`"additive"`, + :obj:`"dot_product"`. (default: :obj:`"additive"`) + l2_normalize (bool, optional): If set to :obj:`True`, output features + will be :math:`\ell_2`-normalized, *i.e.*, + :math:`\frac{\mathbf{x}^{\prime}_i} + {\| \mathbf{x}^{\prime}_i \|_2}`. + (default: :obj:`False`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge attributes :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: Optional[int], + in_edge_channels: int = None, + aggr: str = "add", + skip_linear: str = False, + directed_msg: bool = True, + heads: int = 1, + attention: bool = False, + attention_type: str = "additive", + l2_normalize: bool = False, + bias: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', aggr) + super().__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.in_edge_channels = in_edge_channels + self.aggr = aggr + self.skip_linear = skip_linear + self.directed_msg = directed_msg + self.heads = heads + self.attention = attention + self.attention_type = attention_type + self.normalize_l2 = l2_normalize + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + if self.directed_msg: + self.lin_msg = nn.Dense(in_channels[0], out_channels * self.heads, + has_bias=bias) + else: + self.lin_msg = nn.Dense(in_channels[0], out_channels * self.heads, + has_bias=bias) + self.lin_msg_i = nn.Dense(in_channels[0], out_channels * self.heads, + has_bias=bias) + + if self.skip_linear or self.in_channels != self.out_channels: + self.lin_self = nn.Dense(in_channels[1], out_channels, has_bias=bias) + else: + self.lin_self = nn.Identity() + + if self.in_edge_channels is not None: + self.lin_edge = nn.Dense(in_edge_channels, out_channels * self.heads, + has_bias=bias) + + # TODO: A general sharker.nn.AttentionLayer + if self.attention: + if self.attention_type == 'additive': + self.att_msg = Parameter( + ms.numpy.empty([1, self.heads, self.out_channels])) + elif self.attention_type == 'dot_product': + scaler = ms.Tensor(out_channels, dtype=ms.float32).sqrt() + self.scaler = Parameter(scaler, requires_grad=False) + else: + raise ValueError( + f"Attention type '{self.attention_type}' not supported") + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + if self.attention and self.attention_type == 'additive': + glorot(self.att_msg) + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, + size: Tuple[int, ...] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x: Tuple[Tensor, Optional[Tensor]] = (x, x) + x_self = x[1] + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_attr: Optional[Tensor]) + out = self.propagate(edge_index, x=x, shape=size, edge_attr=edge_attr) + out = out.mean(1) # todo: other approach to aggregate heads + out += self.lin_self(x_self) + if self.normalize_l2: + out /= ms.numpy.norm(out, ord=2, axis=-1, keepdims=True) + out[out.isnan()] = 0 + return out + + def message_basic(self, x_i: Tensor, x_j: Tensor, edge_attr: Optional[Tensor]): + if self.directed_msg: + x_j = self.lin_msg(x_j) + else: + x_j = self.lin_msg(x_j) + self.lin_msg_i(x_i) + if edge_attr is not None: + x_j = x_j + self.lin_edge(edge_attr) + return x_j + + def message(self, x_i: Tensor, x_j: Tensor, edge_index_i: Tensor, + size_i: Tensor, edge_attr: Tensor) -> Tensor: + x_j_out = self.message_basic(x_i, x_j, edge_attr) + x_j_out = x_j_out.view(-1, self.heads, self.out_channels) + if self.attention: + if self.attention_type == 'dot_product': + x_i_out = self.message_basic(x_j, x_i, edge_attr) + x_i_out = x_i_out.view(-1, self.heads, self.out_channels) + alpha = (x_i_out * x_j_out).sum(-1) / self.scaler + else: + alpha = (x_j_out * self.att_msg).sum(-1) + alpha = ops.leaky_relu(alpha, alpha=0.2) + alpha = softmax(alpha, edge_index_i, num_nodes=size_i) + alpha = alpha.view(-1, self.heads, 1) + return x_j_out * alpha + else: + return x_j_out diff --git a/mindscience/sharker/nn/conv/gin_conv.py b/mindscience/sharker/nn/conv/gin_conv.py new file mode 100644 index 000000000..bdf2d4002 --- /dev/null +++ b/mindscience/sharker/nn/conv/gin_conv.py @@ -0,0 +1,191 @@ +from typing import Callable, Optional, Union, Tuple +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops +from .message_passing import MessagePassing +from ..inits import reset + + +class GINConv(MessagePassing): + r"""The graph isomorphism operator from the `"How Powerful are + Graph Neural Networks?" `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot + \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \right) + + or + + .. math:: + \mathbf{X}^{\prime} = h_{\mathbf{\Theta}} \left( \left( \mathbf{A} + + (1 + \epsilon) \cdot \mathbf{I} \right) \cdot \mathbf{X} \right), + + here :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* an MLP. + + Args: + nn (nn.CellList): A neural network :math:`h_{\mathbf{\Theta}}` that + maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to + shape :obj:`[-1, out_channels]`, *e.g.*, defined by + :class:`nn.Sequential`. + eps (float, optional): (Initial) :math:`\epsilon`-value. + (default: :obj:`0.`) + train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` + will be a trainable parameter. (default: :obj:`False`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, nn: Callable, eps: float = 0., train_eps: bool = False, + **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + self.nn = nn + self.initial_eps = eps + if train_eps: + self.eps = Parameter(ms.numpy.empty(1)) + else: + self.eps = Parameter(ms.numpy.empty(1), requires_grad=False) + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + reset(self.nn) + self.eps.data[:] = self.initial_eps + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor, ], + shape: Tuple[int, ...] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]]) + out = self.propagate(edge_index, x=x, shape=shape) + + x_r = x[1] + if x_r is not None: + out += (1 + self.eps) * x_r + + return self.nn(out) + + def message(self, x_j: Tensor) -> Tensor: + return x_j + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(nn={self.nn})' + + +class GINEConv(MessagePassing): + r"""The modified :class:`GINConv` operator from the `"Strategies for + Pre-training Graph Neural Networks" `_ + paper. + + .. math:: + \mathbf{x}^{\prime}_i = h_{\mathbf{\Theta}} \left( (1 + \epsilon) \cdot + \mathbf{x}_i + \sum_{j \in \mathcal{N}(i)} \mathrm{ReLU} + ( \mathbf{x}_j + \mathbf{e}_{j,i} ) \right) + + that is able to incorporate edge features :math:`\mathbf{e}_{j,i}` into + the aggregation procedure. + + Args: + nn (nn.CellList): A neural network :math:`h_{\mathbf{\Theta}}` that + maps node features :obj:`x` of shape :obj:`[-1, in_channels]` to + shape :obj:`[-1, out_channels]`, *e.g.*, defined by + :class:`nn.Sequential`. + eps (float, optional): (Initial) :math:`\epsilon`-value. + (default: :obj:`0.`) + train_eps (bool, optional): If set to :obj:`True`, :math:`\epsilon` + will be a trainable parameter. (default: :obj:`False`) + edge_dim (int, optional): Edge feature dimensionality. If set to + :obj:`None`, node and edge feature dimensionality is expected to + match. Other-wise, edge features are linearly transformed to match + node feature dimensionality. (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, net: nn.CellList, eps: float = 0., + train_eps: bool = False, edge_dim: Optional[int] = None, + **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + self.nn = net + self.initial_eps = eps + if train_eps: + self.eps = Parameter(ms.numpy.empty(1)) + else: + self.eps = Parameter(ms.numpy.empty(1), requires_grad=False) + if edge_dim is not None: + if isinstance(self.nn, nn.SequentialCell): + net = self.nn[0] + if hasattr(net, 'in_features'): + in_channels = net.in_features + elif hasattr(net, 'in_channels'): + in_channels = net.in_channels + else: + raise ValueError("Could not infer input channels from `nn`.") + self.lin = nn.Dense(edge_dim, in_channels) + + else: + self.lin = None + self.reset_parameters() + + def reset_parameters(self): + reset(self.nn) + self.eps.data[:] = self.initial_eps + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, + size: Tuple[int] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_attr: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_attr=edge_attr, shape=size) + + x_r = x[1] + if x_r is not None: + out += (1 + self.eps) * x_r + + return self.nn(out) + + def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: + if self.lin is None and x_j.shape[-1] != edge_attr.shape[-1]: + raise ValueError("Node and edge feature dimensionalities do not " + "match. Consider setting the 'edge_dim' " + "attribute of 'GINEConv'") + + if self.lin is not None: + edge_attr = self.lin(edge_attr) + + return ops.relu(x_j + edge_attr) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(nn={self.nn})' diff --git a/mindscience/sharker/nn/conv/gmm_conv.py b/mindscience/sharker/nn/conv/gmm_conv.py new file mode 100644 index 000000000..1af8f1bca --- /dev/null +++ b/mindscience/sharker/nn/conv/gmm_conv.py @@ -0,0 +1,168 @@ +from typing import Tuple, Union, Optional +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from .message_passing import MessagePassing +from ..inits import glorot, zeros + + +class GMMConv(MessagePassing): + r"""The gaussian mixture model convolutional operator from the `"Geometric + Deep Learning on Graphs and Manifolds using Mixture Model CNNs" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} + \sum_{j \in \mathcal{N}(i)} \frac{1}{K} \sum_{k=1}^K + \mathbf{w}_k(\mathbf{e}_{i,j}) \odot \mathbf{\Theta}_k \mathbf{x}_j, + + where + + .. math:: + \mathbf{w}_k(\mathbf{e}) = \exp \left( -\frac{1}{2} {\left( + \mathbf{e} - \mathbf{\mu}_k \right)}^{\top} \Sigma_k^{-1} + \left( \mathbf{e} - \mathbf{\mu}_k \right) \right) + + denotes a weighting function based on trainable mean vector + :math:`\mathbf{\mu}_k` and diagonal covariance matrix + :math:`\mathbf{\Sigma}_k`. + + .. note:: + + The edge attribute :math:`\mathbf{e}_{ij}` is usually given by + :math:`\mathbf{e}_{ij} = \mathbf{p}_j - \mathbf{p}_i`, where + :math:`\mathbf{p}_i` denotes the position of node :math:`i` (see + :class:`sharker.transform.Cartesian`). + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + axis (int): Pseudo-coordinate dimensionality. + kernel_size (int): Number of kernels :math:`K`. + separate_gaussians (bool, optional): If set to :obj:`True`, will + learn separate GMMs for every pair of input and output channel, + inspired by traditional CNNs. (default: :obj:`False`) + aggr (str, optional): The aggregation operator to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"mean"`) + root_weight (bool, optional): If set to :obj:`False`, the layer will + not add transformed root node features to the output. + (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, in_channels: Union[int, Tuple[int, int]], + out_channels: int, axis: int, kernel_size: int, + separate_gaussians: bool = False, aggr: str = 'mean', + root_weight: bool = True, bias: bool = True, **kwargs): + super().__init__(aggr=aggr, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.axis = axis + self.kernel_size = kernel_size + self.separate_gaussians = separate_gaussians + self.root_weight = root_weight + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + self.rel_in_channels = in_channels[0] + + self.g = Parameter( + mint.zeros([in_channels[0], out_channels * kernel_size])) + + if not self.separate_gaussians: + self.mu = Parameter(mint.zeros([kernel_size, axis])) + self.sigma = Parameter(mint.zeros([kernel_size, axis])) + if self.separate_gaussians: + self.mu = Parameter( + mint.zeros([in_channels[0], out_channels, kernel_size, axis])) + self.sigma = Parameter( + mint.zeros([in_channels[0], out_channels, kernel_size, axis])) + + if root_weight: + self.root = nn.Dense(in_channels[1], out_channels, has_bias=False) + + if bias: + self.bias = Parameter(ms.numpy.empty([out_channels])) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + glorot(self.g) + glorot(self.mu) + glorot(self.sigma) + if self.root_weight: + glorot(self.root.weight) + if self.bias is not None: + zeros(self.bias) + + def construct(self, x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, size: Tuple[int, ...] = None): + + if isinstance(x, Tensor): + x = (x, x) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_attr: Optional[Tensor]) + if not self.separate_gaussians: + out: Tuple[Tensor, Optional[Tensor]] = (x[0] @ self.g, x[1]) + out = self.propagate(edge_index, x=out, edge_attr=edge_attr, + shape=size) + else: + out = self.propagate(edge_index, x=x, edge_attr=edge_attr, + shape=size) + + x_r = x[1] + if x_r is not None and self.root is not None: + out += self.root(x_r) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: + EPS = 1e-15 + F, M = self.rel_in_channels, self.out_channels + (E, D), K = edge_attr.shape, self.kernel_size + + if not self.separate_gaussians: + gaussian = -0.5 * (edge_attr.view(E, 1, D) - + self.mu.view(1, K, D)).pow(2) + gaussian = gaussian / (EPS + self.sigma.view(1, K, D).pow(2)) + gaussian = mint.exp(gaussian.sum(-1)) # [E, K] + + return (x_j.view(E, K, M) * gaussian.view(E, K, 1)).sum(-2) + + else: + gaussian = -0.5 * (edge_attr.view(E, 1, 1, 1, D) - + self.mu.view(1, F, M, K, D)).pow(2) + gaussian = gaussian / (EPS + self.sigma.view(1, F, M, K, D).pow(2)) + gaussian = mint.exp(gaussian.sum(-1)) # [E, F, M, K] + + gaussian = gaussian * self.g.view(1, F, M, K) + gaussian = gaussian.sum(-1) # [E, F, M] + + return (x_j.view(E, F, 1) * gaussian).sum(-2) # [E, M] + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, axis={self.axis})') diff --git a/mindscience/sharker/nn/conv/graph_conv.py b/mindscience/sharker/nn/conv/graph_conv.py new file mode 100644 index 000000000..7c16e89d9 --- /dev/null +++ b/mindscience/sharker/nn/conv/graph_conv.py @@ -0,0 +1,94 @@ +from typing import Tuple, Union, Optional +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing + + +class GraphConv(MessagePassing): + r"""The graph neural network operator from the `"Weisfeiler and Leman Go + Neural: Higher-order Graph Neural Networks" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 + \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j + + where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to + target node :obj:`i` (default: :obj:`1`) + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + aggr (str, optional): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"add"`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + aggr: str = "add", + bias: bool = True, + **kwargs, + ): + super().__init__(aggr=aggr, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_rel = nn.Dense(in_channels[0], out_channels, has_bias=bias) + self.lin_root = nn.Dense(in_channels[1], out_channels, has_bias=False) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor,], + edge_weight: Optional[Tensor] = None, + size: Optional[Tuple[int, int]] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_weight: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_weight=edge_weight, shape=size) + out = self.lin_rel(out) + + x_r = x[1] + if x_r is not None: + out += self.lin_root(x_r) + + return out + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j + + def message_and_aggregate( + self, adj_t: Union[Tensor,], x: Tuple[Tensor, Optional[Tensor]] + ) -> Tensor: + return adj_t.matmul(x[0]) diff --git a/mindscience/sharker/nn/conv/gravnet_conv.py b/mindscience/sharker/nn/conv/gravnet_conv.py new file mode 100644 index 000000000..e46663ae2 --- /dev/null +++ b/mindscience/sharker/nn/conv/gravnet_conv.py @@ -0,0 +1,113 @@ +import warnings +from typing import Optional, Union, Tuple +from mindspore import Tensor, ops, nn, nn, mint +from .message_passing import MessagePassing +from ...utils.cluster import knn + + +class GravNetConv(MessagePassing): + r"""The GravNet operator from the `"Learning Representations of Irregular + Particle-detector Geometry with Distance-weighted Graph + Networks" `_ paper, where the graph is + dynamically constructed using nearest neighbors. + The neighbors are constructed in a learnable low-dimensional projection of + the feature space. + A second projection of the input feature space is then propagated from the + neighbors to each vertex using distance weights that are derived by + applying a Gaussian function to the distances. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): The number of output channels. + space_dimensions (int): The dimensionality of the space used to + construct the neighbors; referred to as :math:`S` in the paper. + propagate_dimensions (int): The number of features to be propagated + between the vertices; referred to as :math:`F_{\textrm{LR}}` in the + paper. + k (int): The number of nearest neighbors. + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))` + if bipartite, + batch vector :math:`(|\mathcal{V}|)` or + :math:`((|\mathcal{V}_s|), (|\mathcal{V}_t|))` if bipartite + *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, in_channels: int, out_channels: int, + space_dimensions: int, propagate_dimensions: int, k: int, + num_workers: Optional[int] = None, **kwargs): + super().__init__(aggr=['mean', 'max'], flow='src_to_trg', + **kwargs) + + if num_workers is not None: + warnings.warn( + "'num_workers' attribute in '{self.__class__.__name__}' is " + "deprecated and will be removed in a future release") + + self.in_channels = in_channels + self.out_channels = out_channels + self.k = k + + self.lin_s = nn.Dense(in_channels, space_dimensions) + self.lin_h = nn.Dense(in_channels, propagate_dimensions) + + self.lin_out1 = nn.Dense(in_channels, out_channels, has_bias=False) + self.lin_out2 = nn.Dense(2 * propagate_dimensions, out_channels) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + batch: Union[Optional[Tensor], Optional[Tuple[Tensor, Tensor]]] = None, + ) -> Tensor: + + is_bipartite: bool = True + if isinstance(x, Tensor): + x = (x, x) + is_bipartite = False + + if x[0].dim() != 2: + raise ValueError("Static graphs not supported in 'GravNetConv'") + + b: Tuple[Optional[Tensor], Optional[Tensor]] = (None, None) + if isinstance(batch, Tensor): + b = (batch, batch) + elif isinstance(batch, tuple): + assert batch is not None + b = (batch[0], batch[1]) + + h_l: Tensor = self.lin_h(x[0]) + + s_l: Tensor = self.lin_s(x[0]) + s_r: Tensor = self.lin_s(x[1]) if is_bipartite else s_l + + edge_index = knn(s_l, s_r, self.k, b[0], b[1]).flip([0]) + + edge_weight = (s_l[edge_index[0]] - s_r[edge_index[1]]).pow(2).sum(-1) + edge_weight = mint.exp(-10. * edge_weight) # 10 gives a better spread + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_weight: Optional[Tensor]) + out = self.propagate(edge_index, x=(h_l, None), + edge_weight=edge_weight, + shape=(s_l.shape[0], s_r.shape[0])) + + return self.lin_out1(x[1]) + self.lin_out2(out) + + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return x_j * edge_weight.unsqueeze(1) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, k={self.k})') diff --git a/mindscience/sharker/nn/conv/heat_conv.py b/mindscience/sharker/nn/conv/heat_conv.py new file mode 100644 index 000000000..68111064f --- /dev/null +++ b/mindscience/sharker/nn/conv/heat_conv.py @@ -0,0 +1,134 @@ +from typing import Optional, Union +from mindspore import Tensor, ops, nn, mint +from mindspore.nn import Embedding +from .message_passing import MessagePassing +from ..dense.linear import HeteroLinear +from ...utils import softmax + + +class HEATConv(MessagePassing): + r"""The heterogeneous edge-enhanced graph attentional operator from the + `"Heterogeneous Edge-Enhanced Graph Attention Network For Multi-Agent + Trajectory Prediction" `_ paper. + + :class:`HEATConv` enhances :class:`~sharker.nn.conv.GATConv` by: + + 1. type-specific transformations of nodes of different types + 2. edge type and edge feature incorporation, in which edges are assumed to + have different types but contain the same kind of attributes + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + num_node_types (int): The number of node types. + num_edge_types (int): The number of edge types. + edge_type_emb_dim (int): The embedding size of edge types. + edge_dim (int): Edge feature dimensionality. + edge_attr_emb_dim (int): The embedding size of edge features. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the multi-head + attentions are averaged instead of concatenated. + (default: :obj:`True`) + negative_slope (float, optional): LeakyReLU angle of the negative + slope. (default: :obj:`0.2`) + dropout (float, optional): Dropout probability of the normalized + attention coefficients which exposes each node to a stochastically + sampled neighborhood during training. (default: :obj:`0`) + root_weight (bool, optional): If set to :obj:`False`, the layer will + not add transformed root node features to the output. + (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + node types :math:`(|\mathcal{V}|)`, + edge types :math:`(|\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + + def __init__(self, in_channels: int, out_channels: int, + num_node_types: int, num_edge_types: int, + edge_type_emb_dim: int, edge_dim: int, edge_attr_emb_dim: int, + heads: int = 1, concat: bool = True, + negative_slope: float = 0.2, dropout: float = 0.0, + root_weight: bool = True, bias: bool = True, **kwargs): + + kwargs.setdefault('aggr', 'add') + super().__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.concat = concat + self.negative_slope = negative_slope + self.dropout = dropout + self.root_weight = root_weight + + self.hetero_lin = HeteroLinear(in_channels, out_channels, + num_node_types, has_bias=bias) + + self.edge_type_emb = Embedding(num_edge_types, edge_type_emb_dim) + self.edge_attr_emb = nn.Dense(edge_dim, edge_attr_emb_dim, has_bias=False) + + self.att = nn.Dense(2 * out_channels + edge_type_emb_dim + edge_attr_emb_dim, + self.heads, has_bias=False) + + self.lin = nn.Dense(out_channels + edge_attr_emb_dim, out_channels, + has_bias=bias) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + self.hetero_lin.reset_parameters() + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], node_type: Tensor, + edge_type: Tensor, edge_attr: Optional[Tensor] = None) -> Tensor: + + x = self.hetero_lin(x, node_type) + + edge_type_emb = ops.leaky_relu(self.edge_type_emb(edge_type), + self.negative_slope) + + # propagate_type: (x: Tensor, edge_type_emb: Tensor, + # edge_attr: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_type_emb=edge_type_emb, + edge_attr=edge_attr) + + if self.concat: + if self.root_weight: + out += x.view(-1, 1, self.out_channels) + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(1) + if self.root_weight: + out += x + + return out + + def message(self, x_i: Tensor, x_j: Tensor, edge_type_emb: Tensor, + edge_attr: Tensor, index: Tensor, ptr: Optional[Tensor], + size_i: Optional[int]) -> Tensor: + + edge_attr = ops.leaky_relu(self.edge_attr_emb(edge_attr), + self.negative_slope) + + alpha = mint.cat(([x_i, x_j, edge_type_emb, edge_attr]), dim=-1) + alpha = ops.leaky_relu(self.att(alpha), self.negative_slope) + alpha = softmax(alpha, index, ptr, size_i) + alpha = ops.dropout(alpha, p=self.dropout, training=self.training) + + out = self.lin(mint.cat(([x_j, edge_attr]), dim=-1)).unsqueeze(-2) + return out * alpha.unsqueeze(-1) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads})') diff --git a/mindscience/sharker/nn/conv/hetero_conv.py b/mindscience/sharker/nn/conv/hetero_conv.py new file mode 100644 index 000000000..6f805f23e --- /dev/null +++ b/mindscience/sharker/nn/conv/hetero_conv.py @@ -0,0 +1,167 @@ +import warnings +from typing import Dict, List, Optional, Tuple +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ...utils.hetero import check_add_self_loops + +def group(xs: List[Tensor], aggr: Optional[str]) -> Optional[Tensor]: + if len(xs) == 0: + return None + elif aggr is None: + return mint.stack((xs), dim=1) + elif len(xs) == 1: + return xs[0] + elif aggr == "cat": + return mint.cat((xs), dim=-1) + else: + out = mint.stack((xs), dim=0) + out = getattr(ops, aggr)(out, 0) + out = out[0] if isinstance(out, tuple) else out + return out + + +class HeteroConv(nn.CellList): + r"""A generic wrapper for computing graph convolution on heterogeneous + graphs. + This layer will pass messages from source nodes to target nodes based on + the bipartite GNN layer given for a specific edge type. + If multiple relations point to the same destination, their results will be + aggregated according to :attr:`aggr`. + In comparison to :meth:`sharker.nn.to_hetero`, this layer is + especially useful if you want to apply different message passing modules + for different edge types. + + .. code-block:: python + + hetero_conv = HeteroConv({ + ('paper', 'cites', 'paper'): GCNConv(-1, 64), + ('author', 'writes', 'paper'): SAGEConv((-1, -1), 64), + ('paper', 'written_by', 'author'): GATConv((-1, -1), 64), + }, aggr='sum') + + out_dict = hetero_conv(x_dict, edge_index_dict) + + print(list(out_dict.keys())) + >>> ['paper', 'author'] + + Args: + convs (Dict[Tuple[str, str, str], MessagePassing]): A dictionary + holding a bipartite + :class:`~sharker.nn.conv.MessagePassing` layer for each + individual edge type. + aggr (str, optional): The aggregation scheme to use for grouping node + embeddings generated by different relations + (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, + :obj:`"cat"`, :obj:`None`). (default: :obj:`"sum"`) + """ + + def __init__( + self, + convs: Dict[Tuple[str, str, str], MessagePassing], + aggr: Optional[str] = "sum", + ): + super().__init__() + + self.conv_idx = {} + self.convs = nn.SequentialCell() + i = 0 + for edge_type, cell in convs.items(): + check_add_self_loops(cell, [edge_type]) + self.conv_idx[edge_type] = i + self.convs.append(cell) + i += 1 + src_node_types = set([key[0] for key in convs.keys()]) + dst_node_types = set([key[-1] for key in convs.keys()]) + if len(src_node_types - dst_node_types) > 0: + warnings.warn( + f"There exist node types ({src_node_types - dst_node_types}) " + f"whose representations do not get updated during message " + f"passing as they do not occur as destination type in any " + f"edge type. This may lead to unexpected behavior.") + self.aggr = aggr + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + for conv in self.convs.values(): + conv.reset_parameters() + + def construct( + self, + *args_dict, + **kwargs_dict, + ) -> Dict[str, Tensor]: + r"""Runs the forward pass of the module. + + Args: + x_dict (Dict[str, Tensor]): A dictionary holding node feature + information for each individual node type. + edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A + dictionary holding graph connectivity information for each + individual edge type, either as a :class:`Tensor` of + shape :obj:`[2, num_edges]` or a + :class:`mindspore_sparse.SparseTensor`. + *args_dict (optional): Additional forward arguments of invididual + :class:`sharker.nn.conv.MessagePassing` layers. + **kwargs_dict (optional): Additional forward arguments of + individual :class:`sharker.nn.conv.MessagePassing` + layers. + For example, if a specific GNN layer at edge type + :obj:`edge_type` expects edge attributes :obj:`edge_attr` as a + forward argument, then you can pass them to + :meth:`~sharker.nn.conv.HeteroConv.forward` via + :obj:`edge_attr_dict = { edge_type: edge_attr }`. + """ + out_dict: Dict[str, List[Tensor]] = {} + for edge_type, conv_idx in self.conv_idx.items(): + conv = self.convs[conv_idx] + src, rel, dst = edge_type + + has_edge_level_arg = False + + args = [] + for value_dict in args_dict: + if edge_type in value_dict: + has_edge_level_arg = True + args.append(value_dict[edge_type]) + elif src == dst and src in value_dict: + args.append(value_dict[src]) + elif src in value_dict or dst in value_dict: + args.append(( + value_dict.get(src, None), + value_dict.get(dst, None), + )) + + kwargs = {} + for arg, value_dict in kwargs_dict.items(): + if not arg.endswith('_dict'): + raise ValueError( + f"Keyword arguments in '{self.__class__.__name__}' " + f"need to end with '_dict' (got '{arg}')") + + arg = arg[:-5] # `{*}_dict` + if edge_type in value_dict: + has_edge_level_arg = True + kwargs[arg] = value_dict[edge_type] + elif src == dst and src in value_dict: + kwargs[arg] = value_dict[src] + elif src in value_dict or dst in value_dict: + kwargs[arg] = ( + value_dict.get(src, None), + value_dict.get(dst, None), + ) + + if not has_edge_level_arg: + continue + out = conv(*args, **kwargs) + + if dst not in out_dict: + out_dict[dst] = [out] + else: + out_dict[dst].append(out) + for key, value in out_dict.items(): + out_dict[key] = group(value, self.aggr) + + return out_dict + + def __repr__(self) -> str: + return f'{self.__class__.__name__}(num_relations={len(self.convs)})' \ No newline at end of file diff --git a/mindscience/sharker/nn/conv/hgt_conv.py b/mindscience/sharker/nn/conv/hgt_conv.py new file mode 100644 index 000000000..7a5131afc --- /dev/null +++ b/mindscience/sharker/nn/conv/hgt_conv.py @@ -0,0 +1,227 @@ +import math +from typing import Dict, List, Optional, Tuple, Union +import mindspore as ms +from mindspore import Tensor, ops, nn, Parameter, mint +from .message_passing import MessagePassing +from ..dense import HeteroDictLinear, HeteroLinear +from ..inits import ones +from ...utils import softmax +from ...utils.hetero import construct_bipartite_edge_index + + +class HGTConv(MessagePassing): + r"""The Heterogeneous Graph Transformer (HGT) operator from the + `"Heterogeneous Graph Transformer" `_ + paper. + + .. note:: + + Args: + in_channels (int or Dict[str, int]): Size of each input sample of every + node type, or :obj:`-1` to derive the size from the first input(s) + to the forward method. + out_channels (int): Size of each output sample. + metadata (Tuple[List[str], List[Tuple[str, str, str]]]): The metadata + of the heterogeneous graph, *i.e.* its node and edge types given + by a list of strings and a list of string triplets, respectively. + See :meth:`sharker.data.HeteroGraph.metadata` for more + information. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + """ + + def __init__( + self, + in_channels: Union[int, Dict[str, int]], + out_channels: int, + metadata: Tuple[List[str], List[Tuple[str, str, str]]], + heads: int = 1, + **kwargs, + ): + super().__init__(aggr='add', node_dim=0, **kwargs) + + if out_channels % heads != 0: + raise ValueError(f"'out_channels' (got {out_channels}) must be " + f"divisible by the number of heads (got {heads})") + + if not isinstance(in_channels, dict): + in_channels = {node_type: in_channels for node_type in metadata[0]} + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.node_types = metadata[0] + self.edge_types = metadata[1] + self.edge_types_map = { + edge_type: i + for i, edge_type in enumerate(metadata[1]) + } + + self.dst_node_types = set([key[-1] for key in self.edge_types]) + + self.kqv_lin = HeteroDictLinear(self.in_channels, + self.out_channels * 3) + + self.out_lin = HeteroDictLinear(self.out_channels, self.out_channels, + types=self.node_types) + + dim = out_channels // heads + num_types = heads * len(self.edge_types) + + self.k_rel = HeteroLinear(dim, dim, num_types, has_bias=False, + is_sorted=True) + self.v_rel = HeteroLinear(dim, dim, num_types, has_bias=False, + is_sorted=True) + + self.skip = { + node_type: Parameter(ms.numpy.empty(1)) + for node_type in self.node_types + } + + self.p_rel = {} + for edge_type in self.edge_types: + self.p_rel[edge_type] = Parameter(ms.numpy.empty([1, heads])) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + self.kqv_lin.reset_parameters() + self.out_lin.reset_parameters() + self.k_rel.reset_parameters() + self.v_rel.reset_parameters() + ones(self.skip) + ones(self.p_rel) + + def _cat(self, x_dict: Dict[str, Tensor]) -> Tuple[Tensor, Dict[str, int]]: + """Concatenates a dictionary of features.""" + cumsum = 0 + outs: List[Tensor] = [] + offset: Dict[str, int] = {} + for key, x in x_dict.items(): + outs.append(x) + offset[key] = cumsum + cumsum += x.shape[0] + return mint.cat((outs), dim=0), offset + + def _construct_src_node_feat( + self, k_dict: Dict[str, Tensor], v_dict: Dict[str, Tensor], + edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor,]] + ) -> Tuple[Tensor, Tensor, Dict[Tuple[str, str, str], int]]: + """Constructs the source node representations.""" + cumsum = 0 + num_edge_types = len(self.edge_types) + H, D = self.heads, self.out_channels // self.heads + + # Flatten into a single tensor with shape [num_edge_types * heads, D]: + ks: List[Tensor] = [] + vs: List[Tensor] = [] + type_list: List[Tensor] = [] + offset: Dict[Tuple[str, str, str]] = {} + for edge_type in edge_index_dict.keys(): + src = edge_type[0] + N = k_dict[src].shape[0] + offset[edge_type] = cumsum + cumsum += N + + # construct type_vec for curr edge_type with shape [H, D] + edge_type_offset = self.edge_types_map[edge_type] + type_vec = mint.arange(H, dtype=ms.int64).view(-1, 1).tile( + (1, N)) * num_edge_types + edge_type_offset + + type_list.append(type_vec) + ks.append(k_dict[src]) + vs.append(v_dict[src]) + + ks = mint.cat((ks), dim=0).swapaxes(0, 1).reshape(-1, D) + vs = mint.cat((vs), dim=0).swapaxes(0, 1).reshape(-1, D) + type_vec = mint.cat((type_list), dim=1).flatten() + + k = self.k_rel(ks, type_vec).view(H, -1, D).swapaxes(0, 1) + v = self.v_rel(vs, type_vec).view(H, -1, D).swapaxes(0, 1) + + return k, v, offset + + def construct( + self, + x_dict: Dict[str, Tensor], + edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor,]] # Support both. + ) -> Dict[str, Optional[Tensor]]: + r"""Runs the forward pass of the module. + + Args: + x_dict (Dict[str, Tensor]): A dictionary holding input node + features for each individual node type. + edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A + dictionary holding graph connectivity information for each + individual edge type, either as a :class:`Tensor` of + shape :obj:`[2, num_edges]` or a + :class:`mindspore_sparse.SparseTensor`. + + :rtype: :obj:`Dict[str, Optional[Tensor]]` - The output node + embeddings for each node type. + In case a node type does not receive any message, its output will + be set to :obj:`None`. + """ + F = self.out_channels + H = self.heads + D = F // H + + k_dict, q_dict, v_dict, out_dict = {}, {}, {}, {} + + # Compute K, Q, V over node types: + kqv_dict = self.kqv_lin(x_dict) + for key, val in kqv_dict.items(): + k, q, v = ms.numpy.split(val, 3, axis=1) + k_dict[key] = k.view(-1, H, D) + q_dict[key] = q.view(-1, H, D) + v_dict[key] = v.view(-1, H, D) + + q, dst_offset = self._cat(q_dict) + k, v, src_offset = self._construct_src_node_feat( + k_dict, v_dict, edge_index_dict) + + edge_index, edge_attr = construct_bipartite_edge_index( + edge_index_dict, src_offset, dst_offset, edge_attr_dict=self.p_rel, + num_nodes=k.shape[0]) + + out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr) + + # Reconstruct output node embeddings dict: + for node_type, start_offset in dst_offset.items(): + end_offset = start_offset + q_dict[node_type].shape[0] + if node_type in self.dst_node_types: + out_dict[node_type] = out[start_offset:end_offset] + + # Transform output node embeddings: + a_dict = self.out_lin({ + k: + ops.gelu(v) if v is not None else v + for k, v in out_dict.items() + }) + + # Iterate over node types: + for node_type, out in out_dict.items(): + out = a_dict[node_type] + + if out.shape[-1] == x_dict[node_type].shape[-1]: + alpha = self.skip[node_type].sigmoid() + out = alpha * out + (1 - alpha) * x_dict[node_type] + out_dict[node_type] = out + + return out_dict + + def message(self, k_j: Tensor, q_i: Tensor, v_j: Tensor, edge_attr: Tensor, + index: Tensor, ptr: Optional[Tensor], + size_i: Optional[int]) -> Tensor: + alpha = (q_i * k_j).sum(-1) * edge_attr + alpha = alpha / math.sqrt(q_i.shape[-1]) + alpha = softmax(alpha, index, ptr, size_i) + out = v_j * alpha.view(-1, self.heads, 1) + return out.view(-1, self.out_channels) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(-1, {self.out_channels}, ' + f'heads={self.heads})') diff --git a/mindscience/sharker/nn/conv/hypergraph_conv.py b/mindscience/sharker/nn/conv/hypergraph_conv.py new file mode 100644 index 000000000..d86120b71 --- /dev/null +++ b/mindscience/sharker/nn/conv/hypergraph_conv.py @@ -0,0 +1,210 @@ +from typing import Optional +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from ...experimental import disable_dynamic_shapes +from .message_passing import MessagePassing +from ..inits import glorot, zeros +from ...utils import scatter, softmax + + +class HypergraphConv(MessagePassing): + r"""The hypergraph convolutional operator from the `"Hypergraph Convolution + and Hypergraph Attention" `_ paper. + + .. math:: + \mathbf{X}^{\prime} = \mathbf{D}^{-1} \mathbf{H} \mathbf{W} + \mathbf{B}^{-1} \mathbf{H}^{\top} \mathbf{X} \mathbf{\Theta} + + where :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` is the incidence + matrix, :math:`\mathbf{W} \in \mathbb{R}^M` is the diagonal hyperedge + weight matrix, and + :math:`\mathbf{D}` and :math:`\mathbf{B}` are the corresponding degree + matrices. + + For example, in the hypergraph scenario + :math:`\mathcal{G} = (\mathcal{V}, \mathcal{E})` with + :math:`\mathcal{V} = \{ 0, 1, 2, 3 \}` and + :math:`\mathcal{E} = \{ \{ 0, 1, 2 \}, \{ 1, 2, 3 \} \}`, the + :obj:`hyperedge_index` is represented as: + + .. code-block:: python + + hyperedge_index = ms.Tensor([ + [0, 1, 2, 1, 2, 3], + [0, 0, 0, 1, 1, 1], + ]) + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + use_attention (bool, optional): If set to :obj:`True`, attention + will be added to this layer. (default: :obj:`False`) + attention_mode (str, optional): The mode on how to compute attention. + If set to :obj:`"node"`, will compute attention scores of nodes + within all nodes belonging to the same hyperedge. + If set to :obj:`"edge"`, will compute attention scores of nodes + across all edges holding this node belongs to. + (default: :obj:`"node"`) + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the multi-head + attentions are averaged instead of concatenated. + (default: :obj:`True`) + negative_slope (float, optional): LeakyReLU angle of the negative + slope. (default: :obj:`0.2`) + dropout (float, optional): Dropout probability of the normalized + attention coefficients which exposes each node to a stochastically + sampled neighborhood during training. (default: :obj:`0`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + hyperedge indices :math:`(|\mathcal{V}|, |\mathcal{E}|)`, + hyperedge weights :math:`(|\mathcal{E}|)` *(optional)* + hyperedge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + use_attention: bool = False, + attention_mode: str = 'node', + heads: int = 1, + concat: bool = True, + negative_slope: float = 0.2, + dropout: float = 0, + bias: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super().__init__(flow='src_to_trg', node_dim=0, **kwargs) + + assert attention_mode in ['node', 'edge'] + + self.in_channels = in_channels + self.out_channels = out_channels + self.use_attention = use_attention + self.attention_mode = attention_mode + + if self.use_attention: + self.heads = heads + self.concat = concat + self.negative_slope = negative_slope + self.dropout = dropout + self.lin = nn.Dense(in_channels, heads * out_channels, has_bias=False) + self.att = Parameter(ms.numpy.empty([1, heads, 2 * out_channels])) + else: + self.heads = 1 + self.concat = True + self.lin = nn.Dense(in_channels, out_channels, has_bias=False) + + if bias and concat: + self.bias = Parameter(ms.numpy.empty(heads * out_channels)) + elif bias and not concat: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + glorot(self.lin.weight) + if self.use_attention: + glorot(self.att) + if self.bias is not None: + zeros(self.bias) + + @disable_dynamic_shapes(required_args=['num_edges']) + def construct(self, x: Tensor, hyperedge_index: Tensor, + hyperedge_weight: Optional[Tensor] = None, + hyperedge_attr: Optional[Tensor] = None, + num_edges: Optional[int] = None) -> Tensor: + r"""Runs the forward pass of the module. + + Args: + x (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. + hyperedge_index (Tensor): The hyperedge indices, *i.e.* + the sparse incidence matrix + :math:`\mathbf{H} \in {\{ 0, 1 \}}^{N \times M}` mapping from + nodes to edges. + hyperedge_weight (Tensor, optional): Hyperedge weights + :math:`\mathbf{W} \in \mathbb{R}^M`. (default: :obj:`None`) + hyperedge_attr (Tensor, optional): Hyperedge feature matrix + in :math:`\mathbb{R}^{M \times F}`. + These features only need to get passed in case + :obj:`use_attention=True`. (default: :obj:`None`) + num_edges (int, optional) : The number of edges :math:`M`. + (default: :obj:`None`) + """ + num_nodes = x.shape[0] + + if num_edges is None: + num_edges = 0 + if hyperedge_index.numel() > 0: + num_edges = int(hyperedge_index[1].max()) + 1 + + if hyperedge_weight is None: + hyperedge_weight = x.new_ones(num_edges) + + x = self.lin(x) + + alpha = None + if self.use_attention: + assert hyperedge_attr is not None + x = x.view(-1, self.heads, self.out_channels) + hyperedge_attr = self.lin(hyperedge_attr) + hyperedge_attr = hyperedge_attr.view(-1, self.heads, + self.out_channels) + x_i = x[hyperedge_index[0]] + x_j = hyperedge_attr[hyperedge_index[1]] + alpha = (mint.cat(([x_i, x_j]), dim=-1) * self.att).sum(-1) + alpha = ops.leaky_relu(alpha, self.negative_slope) + if self.attention_mode == 'node': + alpha = softmax(alpha, hyperedge_index[1], num_nodes=num_edges) + else: + alpha = softmax(alpha, hyperedge_index[0], num_nodes=num_nodes) + alpha = ops.dropout(alpha, p=self.dropout, training=self.training) + + D = scatter(hyperedge_weight[hyperedge_index[1]], hyperedge_index[0], + dim=0, dim_size=num_nodes, reduce='sum') + D = 1.0 / D + D[D == float("inf")] = 0 + + B = scatter(x.new_ones(hyperedge_index.shape[1]), hyperedge_index[1], + dim=0, dim_size=num_edges, reduce='sum') + B = 1.0 / B + B[B == float("inf")] = 0 + + out = self.propagate(hyperedge_index, x=x, norm=B, alpha=alpha, + shape=(num_nodes, num_edges)) + out = self.propagate(hyperedge_index.flip([0]), x=out, norm=D, + alpha=alpha, shape=(num_edges, num_nodes)) + + if self.concat is True: + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(1) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, norm_i: Tensor, alpha: Tensor) -> Tensor: + H, F = self.heads, self.out_channels + + out = norm_i.view(-1, 1, 1) * x_j.view(-1, H, F) + + if alpha is not None: + out = alpha.view(-1, self.heads, 1) * out + + return out diff --git a/mindscience/sharker/nn/conv/le_conv.py b/mindscience/sharker/nn/conv/le_conv.py new file mode 100644 index 000000000..9b8811134 --- /dev/null +++ b/mindscience/sharker/nn/conv/le_conv.py @@ -0,0 +1,94 @@ +from typing import Tuple, Union, Optional +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing + + +class LEConv(MessagePassing): + r"""The local extremum graph neural network operator from the + `"ASAP: Adaptive Structure Aware Pooling for Learning Hierarchical Graph + Representations" `_ paper. + + :class:`LEConv` finds the importance of nodes with respect to their + neighbors using the difference operator: + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{x}_i \cdot \mathbf{\Theta}_1 + + \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot + (\mathbf{\Theta}_2 \mathbf{x}_i - \mathbf{\Theta}_3 \mathbf{x}_j) + + where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to + target node :obj:`i` (default: :obj:`1`) + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + bias (bool, optional): If set to :obj:`False`, the layer will + not learn an additive bias. (default: :obj:`True`). + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + bias: bool = True, + **kwargs, + ): + kwargs.setdefault("aggr", "add") + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin1 = nn.Dense(in_channels[0], out_channels, has_bias=bias) + self.lin2 = nn.Dense(in_channels[1], out_channels, has_bias=False) + self.lin3 = nn.Dense(in_channels[1], out_channels, has_bias=bias) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + edge_index: Union[Tensor,], + edge_weight: Optional[Tensor] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + a = self.lin1(x[0]) + b = self.lin2(x[1]) + + # propagate_type: (a: Tensor, b: Tensor, edge_weight: Optional[Tensor]) + out = self.propagate(edge_index, a=a, b=b, edge_weight=edge_weight) + + return out + self.lin3(x[1]) + + def message( + self, a_j: Tensor, b_i: Tensor, edge_weight: Optional[Tensor] + ) -> Tensor: + out = a_j - b_i + return out if edge_weight is None else out * edge_weight.view(-1, 1) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.in_channels}, " f"{self.out_channels})" diff --git a/mindscience/sharker/nn/conv/lg_conv.py b/mindscience/sharker/nn/conv/lg_conv.py new file mode 100644 index 000000000..b0a20738d --- /dev/null +++ b/mindscience/sharker/nn/conv/lg_conv.py @@ -0,0 +1,48 @@ +from typing import Union, Optional +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm + + +class LGConv(MessagePassing): + r"""The Light Graph Convolution (LGC) operator from the `"LightGCN: + Simplifying and Powering Graph Convolution Network for Recommendation" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i)} + \frac{e_{j,i}}{\sqrt{\deg(i)\deg(j)}} \mathbf{x}_j + + Args: + normalize (bool, optional): If set to :obj:`False`, output features + will not be normalized via symmetric normalization. + (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F)`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F)` + """ + + def __init__(self, normalize: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + self.normalize = normalize + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None) -> Tensor: + + if self.normalize and isinstance(edge_index, Tensor): + out = gcn_norm(edge_index, edge_weight, x.shape[self.node_dim], + add_self_loops=False, flow=self.flow, dtype=x.dtype) + edge_index, edge_weight = out + + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + return self.propagate(edge_index, x=x, edge_weight=edge_weight) + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j \ No newline at end of file diff --git a/mindscience/sharker/nn/conv/message_passing.py b/mindscience/sharker/nn/conv/message_passing.py new file mode 100644 index 000000000..4acb3b2e6 --- /dev/null +++ b/mindscience/sharker/nn/conv/message_passing.py @@ -0,0 +1,611 @@ +import os +from abc import abstractmethod +from inspect import Parameter +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + OrderedDict, + Set, + Tuple, + Union, +) + +import mindspore as ms +from mindspore import Tensor +from mindspore import nn +from mindspore import ops + +from ...template import module_from_template +from ...inspector import Inspector +from ..aggr import Aggregation +from ..resolver import aggregation_resolver as aggr_resolver + +FUSE_AGGRS = {"add", "sum", "mean", "min", "max"} +HookDict = OrderedDict[int, Callable] + + +class MessagePassing(nn.Cell): + r"""Base class for creating message passing layers. + + Message passing layers follow the form + + .. math:: + \mathbf{x}_i^{\prime} = \gamma_{\mathbf{\Theta}} \left( \mathbf{x}_i, + \bigoplus_{j \in \mathcal{N}(i)} \, \phi_{\mathbf{\Theta}} + \left(\mathbf{x}_i, \mathbf{x}_j,\mathbf{e}_{j,i}\right) \right), + + where :math:`\bigoplus` denotes a differentiable, permutation invariant + function, *e.g.*, sum, mean, min, max or mul, and + :math:`\gamma_{\mathbf{\Theta}}` and :math:`\phi_{\mathbf{\Theta}}` denote + differentiable functions such as MLPs. + See `here `__ for the accompanying tutorial. + + Args: + aggr (str or [str] or Aggregation, optional): The aggregation scheme + to use, *e.g.*, :obj:`"sum"` :obj:`"mean"`, :obj:`"min"`, + :obj:`"max"` or :obj:`"mul"`. + In addition, can be any + :class:`~sharker.nn.aggr.Aggregation` module (or any string + that automatically resolves to it). + If given as a list, will make use of multiple aggregations in which + different outputs will get concatenated in the last dimension. + If set to :obj:`None`, the :class:`MessagePassing` instantiation is + expected to implement its own aggregation logic via + :meth:`aggregate`. (default: :obj:`"add"`) + aggr_kwargs (Dict[str, Any], optional): Arguments passed to the + respective aggregation function in case it gets automatically + resolved. (default: :obj:`None`) + flow (str, optional): The flow direction of message passing + (:obj:`"src_to_trg"` or :obj:`"trg_to_src"`). + (default: :obj:`"src_to_trg"`) + node_dim (int, optional): The dim along which to propagate. + (default: :obj:`-2`) + decomposed_layers (int, optional): The number of feature decomposition + layers, as introduced in the `"Optimizing Memory Efficiency of + Graph Neural Networks on Edge Computing Platforms" + `_ paper. + Feature decomposition reduces the peak memory usage by slicing + the feature dimensions into separated feature decomposition layers + during GNN aggregation. + This method can accelerate GNN execution on CPU-based platforms + (*e.g.*, 2-3x speedup on the + :class:`~sharker.datasets.Reddit` dataset) for common GNN + models such as :class:`~sharker.nn.models.GCN`, + :class:`~sharker.nn.models.GraphSAGE`, + :class:`~sharker.nn.models.GIN`, etc. + However, this method is not applicable to all GNN operators + available, in particular for operators in which message computation + can not easily be decomposed, *e.g.* in attention-based GNNs. + The selection of the optimal value of :obj:`decomposed_layers` + depends both on the specific graph dataset and available hardware + resources. + A value of :obj:`2` is suitable in most cases. + Although the peak memory usage is directly associated with the + granularity of feature decomposition, the same is not necessarily + true for execution speedups. (default: :obj:`1`) + """ + + special_args: Set[str] = { + "edge_index", + "adj_t", + "edge_index_i", + "edge_index_j", + "size", + "size_i", + "size_j", + "ptr", + "index", + "dim_size", + } + + def __init__( + self, + aggr: Optional[Union[str, List[str], Aggregation]] = "sum", + *, + aggr_kwargs: Optional[Dict[str, Any]] = None, + flow: str = "src_to_trg", + node_dim: int = -2, + decomposed_layers: int = 1, + ) -> None: + super().__init__() + + if flow not in ["src_to_trg", "trg_to_src"]: + raise ValueError( + f"Expected 'flow' to be either 'src_to_trg'" + f" or 'trg_to_src' (got '{flow}')" + ) + + # Cast `aggr` into a string representation for backward compatibility: + self.aggr: Optional[Union[str, List[str]]] + if aggr is None: + self.aggr = None + elif isinstance(aggr, (str, Aggregation)): + self.aggr = str(aggr) + elif isinstance(aggr, (tuple, list)): + self.aggr = [str(x) for x in aggr] + + self.aggr_module = aggr_resolver(aggr, **(aggr_kwargs or {})) + self.flow = flow + self.node_dim = node_dim + + # Collect attribute names requested in message passing hooks: + self.inspector = Inspector(self.__class__) + self.inspector.inspect_signature(self.message) + self.inspector.inspect_signature(self.aggregate, exclude=[0, "aggr"]) + self.inspector.inspect_signature(self.message_and_aggregate, [0]) + self.inspector.inspect_signature(self.update, exclude=[0]) + self.inspector.inspect_signature(self.edge_update) + + self._user_args: List[str] = self.inspector.get_flat_param_names( + ["message", "aggregate", "update"], exclude=self.special_args + ) + self._fused_user_args: List[str] = self.inspector.get_flat_param_names( + ["message_and_aggregate", "update"], exclude=self.special_args + ) + self._edge_user_args: List[str] = self.inspector.get_param_names( + "edge_update", exclude=self.special_args + ) + + # Support for "fused" message passing: + self.fuse = self.inspector.implements("message_and_aggregate") + if self.aggr is not None: + self.fuse &= isinstance(self.aggr, str) and self.aggr in FUSE_AGGRS + + root_dir = os.path.dirname(os.path.realpath(__file__)) + jinja_prefix = f"{self.__module__}_{self.__class__.__name__}" + # Optimize `propagate()` via `*.jinja` templates: + if not self.propagate.__module__.startswith(jinja_prefix): + try: + if "propagate" in self.__class__.__dict__: + raise ValueError("Cannot compile custom 'propagate' " "method") + + module = module_from_template( + module_name=f"{jinja_prefix}_propagate", + template_path=os.path.join(root_dir, "propagate.jinja"), + tmp_dirname="message_passing", + # Keyword arguments: + modules=self.inspector._modules, + collect_name="collect", + signature=self._get_propagate_signature(), + collect_param_dict=self.inspector.get_flat_param_dict( + ["message", "aggregate", "update"] + ), + message_args=self.inspector.get_param_names("message"), + aggregate_args=self.inspector.get_param_names("aggregate"), + message_and_aggregate_args=self.inspector.get_param_names( + "message_and_aggregate" + ), + update_args=self.inspector.get_param_names("update"), + fuse=self.fuse, + ) + + self.__class__._orig_propagate = self.__class__.propagate + self.__class__._jinja_propagate = module.propagate + + self.__class__.propagate = module.propagate + self.__class__.collect = module.collect + except Exception: # pragma: no cover + self.__class__._orig_propagate = self.__class__.propagate + self.__class__._jinja_propagate = self.__class__.propagate + + # Explainability: + self._explain: Optional[bool] = None + self._edge_mask: Optional[Tensor] = None + self._loop_mask: Optional[Tensor] = None + self._apply_sigmoid: bool = True + + # Inference Decomposition: + self.decomposed_layers = decomposed_layers + + def reset_parameters(self) -> None: + r"""Resets all learnable parameters of the module.""" + if self.aggr_module is not None: + self.aggr_module.reset_parameters() + + def __repr__(self) -> str: + channels_repr = "" + if hasattr(self, "in_channels") and hasattr(self, "out_channels"): + channels_repr = f"{self.in_channels}, {self.out_channels}" + elif hasattr(self, "channels"): + channels_repr = f"{self.channels}" + return f"{self.__class__.__name__}({channels_repr})" + + # Utilities ############################################################### + + def _check_input( + self, + edge_index: Union[Tensor], + size: Optional[Tuple[int, int]], + ) -> List[Optional[int]]: + + if isinstance(edge_index, Tensor): + int_dtypes = (ms.uint8, ms.int8, ms.int32, ms.int64) + + if edge_index.dtype not in int_dtypes: + raise ValueError( + f"Expected 'edge_index' to be of integer " + f"type (got '{edge_index.dtype}')" + ) + if edge_index.dim() != 2: + raise ValueError( + f"Expected 'edge_index' to be two-dimensional" + f" (got {edge_index.dim()} dimensions)" + ) + + return list(size) if size is not None else [None, None] + + raise ValueError( + ( + "`MessagePassing.propagate` only supports integer tensors of " + "shape `[2, num_messages]`, `mindspore_sparse.SparseTensor` or " + "`mindspore.sparse.Tensor` for argument `edge_index`." + ) + ) + + def _set_size( + self, + size: List[Optional[int]], + dim: int, + src: Tensor, + ) -> None: + the_size = size[dim] + if the_size is None: + size[dim] = src.shape[self.node_dim] + elif the_size != src.shape[self.node_dim]: + raise ValueError( + ( + f"Encountered tensor with size {src.shape[self.node_dim]} in " + f"dimension {self.node_dim}, but expected size {the_size}." + ) + ) + + def _index_select(self, src: Tensor, index: Tensor) -> Tensor: + if ops.numel(index) > 0 and ops.amin(index) < 0: + raise IndexError( + f"Found negative indices in 'edge_index' (got " + f"{ops.amin(index).item()}). Please ensure that all " + f"indices in 'edge_index' point to valid indices " + f"in the interval [0, {src.shape[self.node_dim]}) in " + f"your node feature matrix and try again." + ) + + if ops.numel(index) > 0 and ops.amax(index) >= src.shape[self.node_dim]: + raise IndexError( + f"Found indices in 'edge_index' that are larger " + f"than {src.shape[self.node_dim] - 1} (got " + f"{ops.amax(index).item()}). Please ensure that all " + f"indices in 'edge_index' point to valid indices " + f"in the interval [0, {src.shape[self.node_dim]}) in " + f"your node feature matrix and try again." + ) + try: + return src.index_select(self.node_dim, index) + except (IndexError, RuntimeError) as e: + raise e + + def _lift( + self, + src: Tensor, + edge_index: Union[Tensor], + dim: int, + ) -> Tensor: + + if isinstance(edge_index, Tensor): + return self._index_select(src, edge_index[dim]) + + raise ValueError( + ( + "`MessagePassing.propagate` only supports integer tensors of " + "shape `[2, num_messages]`, `mindspore_sparse.SparseTensor` " + "or `mindspore.sparse.Tensor` for argument `edge_index`." + ) + ) + + def _collect( + self, + args: Set[str], + edge_index: Union[Tensor], + size: List[Optional[int]], + kwargs: Dict[str, Any], + ) -> Dict[str, Any]: + + i, j = (1, 0) if self.flow == "src_to_trg" else (0, 1) + + out = {} + for arg in args: + if arg[-2:] not in ["_i", "_j"]: + out[arg] = kwargs.get(arg, Parameter.empty) + else: + dim = j if arg[-2:] == "_j" else i + data = kwargs.get(arg[:-2], Parameter.empty) + + if isinstance(data, (tuple, list)): + assert len(data) == 2 + if isinstance(data[1 - dim], Tensor): + self._set_size(size, 1 - dim, data[1 - dim]) + data = data[dim] + + if isinstance(data, Tensor): + self._set_size(size, dim, data) + data = self._lift(data, edge_index, dim) + + out[arg] = data + + if isinstance(edge_index, Tensor): + out["adj_t"] = None + out["edge_index"] = edge_index + out["edge_index_i"] = edge_index[i] + out["edge_index_j"] = edge_index[j] + + out["ptr"] = None + + out["index"] = out["edge_index_i"] + out["size"] = size + out["size_i"] = size[i] if size[i] is not None else size[j] + out["size_j"] = size[j] if size[j] is not None else size[i] + out["dim_size"] = out["size_i"] + + return out + + def construct(self, *args: Any, **kwargs: Any) -> Any: + r"""Runs the forward pass of the module.""" + pass + + def propagate( + self, + edge_index: Tensor, + shape: Optional[Tuple[int, int]] = None, + **kwargs: Any, + ) -> Tensor: + r"""The initial call to start propagating messages. + + Args: + edge_index (Tensor or SparseTensor): A :class:`Tensor`, + a :class:`mindspore_sparse.SparseTensor` or a + :class:`mindspore.sparse.Tensor` that defines the underlying + graph connectivity/message passing flow. + :obj:`edge_index` holds the indices of a general (sparse) + assignment matrix of shape :obj:`[N, M]`. + If :obj:`edge_index` is a :obj:`Tensor`, its :obj:`dtype` + should be :obj:`ms.int64` and its shape needs to be defined + as :obj:`[2, num_messages]` where messages from nodes in + :obj:`edge_index[0]` are sent to nodes in :obj:`edge_index[1]` + (in case :obj:`flow="src_to_trg"`). + If :obj:`edge_index` is a :class:`mindspore_sparse.SparseTensor` or + a :class:`mindspore.sparse.Tensor`, its sparse indices + :obj:`(row, col)` should relate to :obj:`row = edge_index[1]` + and :obj:`col = edge_index[0]`. + The major difference between both formats is that we need to + input the *transposed* sparse adjacency matrix into + :meth:`propagate`. + size ((int, int), optional): The size :obj:`(N, M)` of the + assignment matrix in case :obj:`edge_index` is a + :class:`Tensor`. + If set to :obj:`None`, the size will be automatically inferred + and assumed to be quadratic. + This argument is ignored in case :obj:`edge_index` is a + :class:`mindspore_sparse.SparseTensor` or + a :class:`mindspore.sparse.Tensor`. (default: :obj:`None`) + **kwargs: Any additional data which is needed to construct and + aggregate messages, and to update node embeddings. + """ + decomposed_layers = 1 if self.explain else self.decomposed_layers + + mutable_size = self._check_input(edge_index, shape) + + if decomposed_layers > 1: + user_args = self._user_args + decomp_args = {a[:-2] for a in user_args if a[-2:] == "_j"} + decomp_kwargs = { + a: kwargs[a].chunk(decomposed_layers, -1) for a in decomp_args + } + decomp_out = [] + for i in range(decomposed_layers): + if decomposed_layers > 1: + for arg in decomp_args: + kwargs[arg] = decomp_kwargs[arg][i] + + coll_dict = self._collect(self._user_args, edge_index, mutable_size, kwargs) + + msg_kwargs = self.inspector.collect_param_data("message", coll_dict) + + out = self.message(**msg_kwargs) + + if self.explain: + explain_msg_kwargs = self.inspector.collect_param_data( + "explain_message", coll_dict + ) + out = self.explain_message(out, **explain_msg_kwargs) + aggr_kwargs = self.inspector.collect_param_data("aggregate", coll_dict) + + out = self.aggregate(out, **aggr_kwargs) + + update_kwargs = self.inspector.collect_param_data("update", coll_dict) + out = self.update(out, **update_kwargs) + if decomposed_layers > 1: + decomp_out.append(out) + if decomposed_layers > 1: + out = ops.cat(decomp_out, dim=-1) + + return out + + def message(self, x_j: Tensor) -> Tensor: + r"""Constructs messages from node :math:`j` to node :math:`i` + in analogy to :math:`\phi_{\mathbf{\Theta}}` for each edge in + :obj:`edge_index`. + This function can take any argument as input which was initially + passed to :meth:`propagate`. + Furthermore, tensors passed to :meth:`propagate` can be mapped to the + respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or + :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`. + """ + return x_j + + def aggregate( + self, + inputs: Tensor, + index: Tensor, + ptr: Optional[Tensor] = None, + dim_size: Optional[int] = None, + ) -> Tensor: + r"""Aggregates messages from neighbors as + :math:`\bigoplus_{j \in \mathcal{N}(i)}`. + + Takes in the output of message computation as first argument and any + argument which was initially passed to :meth:`propagate`. + + By default, this function will delegate its call to the underlying + :class:`~sharker.nn.aggr.Aggregation` module to reduce messages + as specified in :meth:`__init__` by the :obj:`aggr` argument. + """ + return self.aggr_module( + inputs, index, ptr=ptr, dim_size=dim_size, dim=self.node_dim + ) + + @abstractmethod + def message_and_aggregate(self, adj_t: Tensor) -> Tensor: + r"""Fuses computations of :func:`message` and :func:`aggregate` into a + single function. + If applicable, this saves both time and memory since messages do not + explicitly need to be materialized. + This function will only gets called in case it is implemented and + propagation takes place based on a :obj:`mindspore_sparse.SparseTensor` + or a :obj:`mindspore.sparse.Tensor`. + """ + raise NotImplementedError + + def update(self, inputs: Tensor) -> Tensor: + r"""Updates node embeddings in analogy to + :math:`\gamma_{\mathbf{\Theta}}` for each node + :math:`i \in \mathcal{V}`. + Takes in the output of aggregation as first argument and any argument + which was initially passed to :meth:`propagate`. + """ + return inputs + + # Edge-level Updates ###################################################### + + def edge_updater( + self, + edge_index: Tensor, + size: Optional[Tuple[int, int]] = None, + **kwargs: Any, + ) -> Tensor: + r"""The initial call to compute or update features for each edge in the + graph. + + Args: + edge_index (Tensor or SparseTensor): A :obj:`Tensor`, a + :class:`mindspore_sparse.SparseTensor` or a + :class:`mindspore.sparse.Tensor` that defines the underlying graph + connectivity/message passing flow. + See :meth:`propagate` for more information. + size ((int, int), optional): The size :obj:`(N, M)` of the + assignment matrix in case :obj:`edge_index` is a + :class:`Tensor`. + If set to :obj:`None`, the size will be automatically inferred + and assumed to be quadratic. + This argument is ignored in case :obj:`edge_index` is a + :class:`mindspore_sparse.SparseTensor` or + a :class:`mindspore.sparse.Tensor`. (default: :obj:`None`) + **kwargs: Any additional data which is needed to compute or update + features for each edge in the graph. + """ + + mutable_size = self._check_input(edge_index, size=None) + + coll_dict = self._collect( + self._edge_user_args, edge_index, mutable_size, kwargs + ) + + edge_kwargs = self.inspector.collect_param_data("edge_update", coll_dict) + out = self.edge_update(**edge_kwargs) + + return out + + @abstractmethod + def edge_update(self) -> Tensor: + r"""Computes or updates features for each edge in the graph. + This function can take any argument as input which was initially passed + to :meth:`edge_updater`. + Furthermore, tensors passed to :meth:`edge_updater` can be mapped to + the respective nodes :math:`i` and :math:`j` by appending :obj:`_i` or + :obj:`_j` to the variable name, *.e.g.* :obj:`x_i` and :obj:`x_j`. + """ + raise NotImplementedError + + # Inference Decomposition ################################################# + + @property + def decomposed_layers(self) -> int: + return self._decomposed_layers + + @decomposed_layers.setter + def decomposed_layers(self, decomposed_layers: int) -> None: + + self._decomposed_layers = decomposed_layers + + if decomposed_layers != 1: + self.propagate = self.__class__._orig_propagate.__get__(self, MessagePassing) + + elif ( + self.explain is None or self.explain is False + ) and not self.propagate.__module__.endswith("_propagate"): + self.propagate = self.__class__._jinja_propagate.__get__(self, MessagePassing) + + # Explainability ########################################################## + + @property + def explain(self) -> Optional[bool]: + return self._explain + + @explain.setter + def explain(self, explain: Optional[bool]) -> None: + + self._explain = explain + + if explain is True: + assert self.decomposed_layers == 1 + self.inspector.remove_signature(self.explain_message) + self.inspector.inspect_signature(self.explain_message, exclude=[0]) + self._user_args = self.inspector.get_flat_param_names( + funcs=["message", "explain_message", "aggregate", "update"], + exclude=self.special_args, + ) + + def explain_message( + self, + inputs: Tensor, + dim_size: Optional[int], + ) -> Tensor: + # NOTE Replace this method in custom explainers per message-passing + # layer to customize how messages shall be explained, e.g., via: + # conv.explain_message = explain_message.__get__(conv, MessagePassing) + # see stackoverflow.com: 394770/override-a-method-at-instance-level + edge_mask = self._edge_mask + + if edge_mask is None: + raise ValueError( + "Could not find a pre-defined 'edge_mask' " + "to explain. Did you forget to initialize it?" + ) + + if self._apply_sigmoid: + edge_mask = edge_mask.sigmoid() + + # Some ops add self-loops to `edge_index`. We need to do the same for + # `edge_mask` (but do not train these entries). + if inputs.shape[self.node_dim] != edge_mask.shape[0]: + assert dim_size is not None + edge_mask = edge_mask[self._loop_mask] + loop = ops.ones(dim_size, dtype=edge_mask.dtype) + edge_mask = ops.cat([edge_mask, loop], dim=0) + assert inputs.shape[self.node_dim] == edge_mask.shape[0] + + size = [1] * inputs.dim() + size[self.node_dim] = -1 + return inputs * edge_mask.view(*size) diff --git a/mindscience/sharker/nn/conv/mf_conv.py b/mindscience/sharker/nn/conv/mf_conv.py new file mode 100644 index 000000000..380e4ae70 --- /dev/null +++ b/mindscience/sharker/nn/conv/mf_conv.py @@ -0,0 +1,108 @@ +from typing import Tuple, Union, Optional +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from mindspore.nn import CellList +from .message_passing import MessagePassing +from ...utils import degree + + +class MFConv(MessagePassing): + r"""The graph neural network operator from the + `"Convolutional Networks on Graphs for Learning Molecular Fingerprints" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{W}^{(\deg(i))}_1 \mathbf{x}_i + + \mathbf{W}^{(\deg(i))}_2 \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j + + which trains a distinct weight matrix for each possible vertex degree. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + max_degree (int, optional): The maximum node degree to consider when + updating weights (default: :obj:`10`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **inputs:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V_t}|, F_{out})` if bipartite + """ + + def __init__(self, in_channels: Union[int, Tuple[int, int]], + out_channels: int, max_degree: int = 10, bias=True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.max_degree = max_degree + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lins_l = CellList([ + nn.Dense(in_channels[0], out_channels, has_bias=bias) + for _ in range(max_degree + 1) + ]) + + self.lins_r = CellList([ + nn.Dense(in_channels[1], out_channels, has_bias=False) + for _ in range(max_degree + 1) + ]) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor, ], + shape: Tuple[int, ...] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + x_r = x[1] + + deg = x[0] # Dummy. + i = 1 if self.flow == 'src_to_trg' else 0 + N = x[0].shape[self.node_dim] + N = shape[1] if shape is not None else N + N = x_r.shape[self.node_dim] if x_r is not None else N + deg = degree(edge_index[i], N, dtype=ms.int64) + + deg = deg.clamp(max=self.max_degree) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]]) + h = self.propagate(edge_index, x=x, shape=shape) + + out = mint.zeros(list(h.shape)[:-1] + [self.out_channels], dtype=h.dtype) + for i, (lin_l, lin_r) in enumerate(zip(self.lins_l, self.lins_r)): + idx = (deg == i).nonzero().view(-1) + if len(idx) != 0: + r = lin_l(h.index_select(self.node_dim, idx)) + + if x_r is not None: + r = r + lin_r(x_r.index_select(self.node_dim, idx)) + ops.swapaxes(out, 0, self.node_dim) + out[idx] = r + ops.swapaxes(out, 0, self.node_dim) + return out + + def message(self, x_j: Tensor) -> Tensor: + return x_j diff --git a/mindscience/sharker/nn/conv/mixhop_conv.py b/mindscience/sharker/nn/conv/mixhop_conv.py new file mode 100644 index 000000000..7c0228b8b --- /dev/null +++ b/mindscience/sharker/nn/conv/mixhop_conv.py @@ -0,0 +1,110 @@ +from typing import List, Optional, Union +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm +from ..inits import zeros + + +class MixHopConv(MessagePassing): + r"""The Mix-Hop graph convolutional operator from the + `"MixHop: Higher-Order Graph Convolutional Architecturesvia Sparsified + Neighborhood Mixing" `_ paper. + + .. math:: + \mathbf{X}^{\prime}={\Bigg\Vert}_{p\in P} + {\left( \mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2} \right)}^p \mathbf{X} \mathbf{\Theta}, + + where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the + adjacency matrix with inserted self-loops and + :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + powers (List[int], optional): The powers of the adjacency matrix to + use. (default: :obj:`[0, 1, 2]`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** + node features :math:`(|\mathcal{V}|, |P| \cdot F_{out})` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + powers: Optional[List[int]] = None, + add_self_loops: bool = True, + bias: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + if powers is None: + powers = [0, 1, 2] + + self.in_channels = in_channels + self.out_channels = out_channels + self.powers = powers + self.add_self_loops = add_self_loops + + self.lins = nn.CellList([ + nn.Dense(in_channels, out_channels, has_bias=False) + if p in powers else nn.Identity() + for p in range(max(powers) + 1) + ]) + + if bias: + self.bias = Parameter(ms.numpy.empty(len(powers) * out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + if self.bias is not None: + zeros(self.bias) + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None) -> Tensor: + + if isinstance(edge_index, Tensor): + edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index, edge_weight, x.shape[self.node_dim], False, + self.add_self_loops, self.flow, x.dtype) + + outs = [self.lins[0](x)] + + for lin in self.lins[1:]: + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + x = self.propagate(edge_index, x=x, edge_weight=edge_weight) + + outs.append(lin(x)) + + out = mint.cat(([outs[p] for p in self.powers]), dim=-1) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, powers={self.powers})') diff --git a/mindscience/sharker/nn/conv/nn_conv.py b/mindscience/sharker/nn/conv/nn_conv.py new file mode 100644 index 000000000..97d35515f --- /dev/null +++ b/mindscience/sharker/nn/conv/nn_conv.py @@ -0,0 +1,121 @@ +from typing import Callable, Tuple, Union, Optional +import mindspore as ms +from mindspore import Tensor, nn, Parameter +from .message_passing import MessagePassing +from ..inits import reset, zeros + + +class NNConv(MessagePassing): + r"""The continuous kernel-based convolutional operator from the + `"Neural Message Passing for Quantum Chemistry" + `_ paper. + + This convolution is also known as the edge-conditioned convolution from the + `"Dynamic Edge-Conditioned Filters in Convolutional Neural Networks on + Graphs" `_ paper (see + :class:`sharker.nn.conv.ECConv` for an alias): + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{\Theta} \mathbf{x}_i + + \sum_{j \in \mathcal{N}(i)} \mathbf{x}_j \cdot + h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}), + + where :math:`h_{\mathbf{\Theta}}` denotes a neural network, *.i.e.* + a MLP. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + nn (nn.CellList): A neural network :math:`h_{\mathbf{\Theta}}` that + maps edge features :obj:`edge_attr` of shape :obj:`[-1, + num_edge_features]` to shape + :obj:`[-1, in_channels * out_channels]`, *e.g.*, defined by + :class:`nn.SequentialCell`. + aggr (str, optional): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"add"`) + root_weight (bool, optional): If set to :obj:`False`, the layer will + not add the transformed root node features to the output. + (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, in_channels: Union[int, Tuple[int, int]], + out_channels: int, net: Callable, aggr: str = 'add', + root_weight: bool = True, bias: bool = True, **kwargs): + super().__init__(aggr=aggr, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.net = net + self.root_weight = root_weight + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.in_channels_l = in_channels[0] + + if root_weight: + self.lin = nn.Dense(in_channels[1], out_channels, has_bias=False, + weight_init='uniform') + + if bias: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + reset(self.net) + if self.bias is not None: + zeros(self.bias) + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, + size: Tuple[int, ...] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_attr: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_attr=edge_attr, shape=size) + + x_r = x[1] + if x_r is not None and self.root_weight: + out += self.lin(x_r) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: + weight = self.net(edge_attr) + weight = weight.view(-1, self.in_channels_l, self.out_channels) + return (x_j.unsqueeze(1) @ weight).squeeze(1) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, aggr={self.aggr}, nn={self.net})') diff --git a/mindscience/sharker/nn/conv/pdn_conv.py b/mindscience/sharker/nn/conv/pdn_conv.py new file mode 100644 index 000000000..a64ad0ed2 --- /dev/null +++ b/mindscience/sharker/nn/conv/pdn_conv.py @@ -0,0 +1,113 @@ +from typing import Union, Optional +import mindspore as ms +from mindspore import Tensor, Parameter, nn +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm +from ..inits import glorot, zeros + + +class PDNConv(MessagePassing): + r"""The pathfinder discovery network convolutional operator from the + `"Pathfinder Discovery Networks for Neural Message Passing" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \sum_{j \in \mathcal{N}(i) \cup + \{i\}}f_{\Theta}(\textbf{e}_{(j,i)}) \cdot f_{\Omega}(\mathbf{x}_{j}) + + where :math:`z_{i,j}` denotes the edge feature vector from source node + :math:`j` to target node :math:`i`, and :math:`\mathbf{x}_{j}` denotes the + node feature vector of node :math:`j`. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + edge_dim (int): Edge feature dimensionality. + hidden_channels (int): Hidden edge feature dimensionality. + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + normalize (bool, optional): Whether to add self-loops and compute + symmetric normalization coefficients on the fly. + (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge features :math:`(|\mathcal{E}|, D)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + + def __init__(self, in_channels: int, out_channels: int, edge_dim: int, + hidden_channels: int, add_self_loops: bool = True, + normalize: bool = True, bias: bool = True, **kwargs): + + kwargs.setdefault("aggr", "add") + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.edge_dim = edge_dim + self.hidden_channels = hidden_channels + self.add_self_loops = add_self_loops + self.normalize = normalize + + self.lin = nn.Dense(in_channels, out_channels, has_bias=False) + + self.mlp = nn.SequentialCell( + nn.Dense(edge_dim, hidden_channels), + nn.ReLU(), + nn.Dense(hidden_channels, 1), + nn.Sigmoid(), + ) + + if bias: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + glorot(self.lin.weight) + glorot(self.mlp[0].weight) + glorot(self.mlp[2].weight) + zeros(self.mlp[0].bias) + zeros(self.mlp[2].bias) + if self.bias is not None: + zeros(self.bias) + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None) -> Tensor: + + if edge_attr is not None: + edge_attr = self.mlp(edge_attr).squeeze(-1) + + if self.normalize: + if isinstance(edge_index, Tensor): + edge_index, edge_attr = gcn_norm(edge_index, edge_attr, + x.shape[self.node_dim], False, + self.add_self_loops, + self.flow, x.dtype) + + x = self.lin(x) + + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_weight=edge_attr) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return edge_weight.view(-1, 1) * x_j + + def __repr__(self): + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels})') diff --git a/mindscience/sharker/nn/conv/point_conv.py b/mindscience/sharker/nn/conv/point_conv.py new file mode 100644 index 000000000..7f2dd1d83 --- /dev/null +++ b/mindscience/sharker/nn/conv/point_conv.py @@ -0,0 +1,107 @@ +from typing import Callable, Optional, Union, Tuple +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ..inits import reset +from ...utils import add_self_loops, remove_self_loops + + +class PointNetConv(MessagePassing): + r"""The PointNet set layer from the `"PointNet: Deep Learning on Point Sets + for 3D Classification and Segmentation" + `_ and `"PointNet++: Deep Hierarchical + Feature Learning on Point Sets in a Metric Space" + `_ papers. + + .. math:: + \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in + \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, + \mathbf{p}_j - \mathbf{p}_i) \right), + + where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` + denote neural networks, *i.e.* MLPs, and + :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of + each point. + + Args: + local_nn (nn.CellList, optional): A neural network + :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` and + relative spatial coordinates :obj:`pos_j - pos_i` of shape + :obj:`[-1, in_channels + num_dimensions]` to shape + :obj:`[-1, out_channels]`, *e.g.*, defined by + :class:`nn.SequentialCell`. (default: :obj:`None`) + global_nn (nn.CellList, optional): A neural network + :math:`\gamma_{\mathbf{\Theta}}` that maps aggregated node features + of shape :obj:`[-1, out_channels]` to shape :obj:`[-1, + final_out_channels]`, *e.g.*, defined by + :class:`nn.SequentialCell`. (default: :obj:`None`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + positions :math:`(|\mathcal{V}|, 3)` or + :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, local_nn: Optional[Callable] = None, + global_nn: Optional[Callable] = None, + add_self_loops: bool = True, **kwargs): + kwargs.setdefault('aggr', 'max') + super().__init__(**kwargs) + + self.local_nn = local_nn + self.global_nn = global_nn + self.add_self_loops = add_self_loops + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + reset(self.local_nn) + reset(self.global_nn) + + def construct( + self, + x: Union[Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]], + pos: Union[Tensor, Tuple[Tensor, Tensor]], + edge_index: Union[Tensor, ], + ) -> Tensor: + + if not isinstance(x, tuple): + x = (x, None) + + if isinstance(pos, Tensor): + pos = (pos, pos) + + if self.add_self_loops: + if isinstance(edge_index, Tensor): + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops( + edge_index, num_nodes=min(pos[0].shape[0], pos[1].shape[0])) + out = self.propagate(edge_index, x=x, pos=pos) + + if self.global_nn is not None: + out = self.global_nn(out) + + return out + + def message(self, x_j: Optional[Tensor], pos_i: Tensor, + pos_j: Tensor) -> Tensor: + msg = pos_j - pos_i + if x_j is not None: + msg = mint.cat(([x_j, msg]), dim=1) + if self.local_nn is not None: + msg = self.local_nn(msg) + return msg + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(local_nn={self.local_nn}, ' + f'global_nn={self.global_nn})') diff --git a/mindscience/sharker/nn/conv/point_gnn_conv.py b/mindscience/sharker/nn/conv/point_gnn_conv.py new file mode 100644 index 000000000..6a3819fd6 --- /dev/null +++ b/mindscience/sharker/nn/conv/point_gnn_conv.py @@ -0,0 +1,80 @@ +from typing import Union +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ..inits import reset + + +class PointGNNConv(MessagePassing): + r"""The PointGNN operator from the `"Point-GNN: Graph Neural Network for + 3D Object Detection in a Point Cloud" `_ + paper. + + .. math:: + + \Delta \textrm{pos}_i &= h_{\mathbf{\Theta}}(\mathbf{x}_i) + + \mathbf{e}_{j,i} &= f_{\mathbf{\Theta}}(\textrm{pos}_j - + \textrm{pos}_i + \Delta \textrm{pos}_i, \mathbf{x}_j) + + \mathbf{x}^{\prime}_i &= g_{\mathbf{\Theta}}(\max_{j \in + \mathcal{N}(i)} \mathbf{e}_{j,i}) + \mathbf{x}_i + + The relative position is used in the message passing step to introduce + global translation invariance. + To also counter shifts in the local neighborhood of the center node, the + authors propose to utilize an alignment offset. + The graph should be statically constructed using radius-based cutoff. + + Args: + mlp_h (nn.CellList): A neural network :math:`h_{\mathbf{\Theta}}` + that maps node features of size :math:`F_{in}` to three-dimensional + coordination offsets :math:`\Delta \textrm{pos}_i`. + mlp_f (nn.CellList): A neural network :math:`f_{\mathbf{\Theta}}` + that computes :math:`\mathbf{e}_{j,i}` from the features of + neighbors of size :math:`F_{in}` and the three-dimensional vector + :math:`\textrm{pos_j} - \textrm{pos_i} + \Delta \textrm{pos}_i`. + mlp_g (nn.CellList): A neural network :math:`g_{\mathbf{\Theta}}` + that maps the aggregated edge features back to :math:`F_{in}`. + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + positions :math:`(|\mathcal{V}|, 3)`, + edge indices :math:`(2, |\mathcal{E}|)`, + - **output:** node features :math:`(|\mathcal{V}|, F_{in})` + """ + + def __init__( + self, + mlp_h: nn.CellList, + mlp_f: nn.CellList, + mlp_g: nn.CellList, + **kwargs, + ): + kwargs.setdefault('aggr', 'max') + super().__init__(**kwargs) + + self.mlp_h = mlp_h + self.mlp_f = mlp_f + self.mlp_g = mlp_g + + def construct(self, x: Tensor, pos: Tensor, edge_index: Union[Tensor, ]) -> Tensor: + # propagate_type: (x: Tensor, pos: Tensor) + out = self.propagate(edge_index, x=x, pos=pos) + out = self.mlp_g(out) + return x + out + + def message(self, pos_j: Tensor, pos_i: Tensor, x_i: Tensor, + x_j: Tensor) -> Tensor: + delta = self.mlp_h(x_i) + e = mint.cat(([pos_j - pos_i + delta, x_j]), dim=-1) + return self.mlp_f(e) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(\n' + f' mlp_h={self.mlp_h},\n' + f' mlp_f={self.mlp_f},\n' + f' mlp_g={self.mlp_g},\n' + f')') diff --git a/mindscience/sharker/nn/conv/point_transformer_conv.py b/mindscience/sharker/nn/conv/point_transformer_conv.py new file mode 100644 index 000000000..b6742797d --- /dev/null +++ b/mindscience/sharker/nn/conv/point_transformer_conv.py @@ -0,0 +1,139 @@ +from typing import Callable, Optional, Tuple, Union +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ..inits import reset +from ...utils import add_self_loops, remove_self_loops, softmax + + +class PointTransformerConv(MessagePassing): + r"""The Point Transformer layer from the `"Point Transformer" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \sum_{j \in + \mathcal{N}(i) \cup \{ i \}} \alpha_{i,j} \left(\mathbf{W}_3 + \mathbf{x}_j + \delta_{ij} \right), + + where the attention coefficients :math:`\alpha_{i,j}` and + positional embedding :math:`\delta_{ij}` are computed as + + .. math:: + \alpha_{i,j}= \textrm{softmax} \left( \gamma_\mathbf{\Theta} + (\mathbf{W}_1 \mathbf{x}_i - \mathbf{W}_2 \mathbf{x}_j + + \delta_{i,j}) \right) + + and + + .. math:: + \delta_{i,j}= h_{\mathbf{\Theta}}(\mathbf{p}_i - \mathbf{p}_j), + + with :math:`\gamma_\mathbf{\Theta}` and :math:`h_\mathbf{\Theta}` + denoting neural networks, *i.e.* MLPs, and + :math:`\mathbf{P} \in \mathbb{R}^{N \times D}` defines the position of + each point. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + pos_nn (nn.CellList, optional): A neural network + :math:`h_\mathbf{\Theta}` which maps relative spatial coordinates + :obj:`pos_j - pos_i` of shape :obj:`[-1, 3]` to shape + :obj:`[-1, out_channels]`. + Will default to a :class:`nn.Dense` transformation if not + further specified. (default: :obj:`None`) + attn_nn (nn.CellList, optional): A neural network + :math:`\gamma_\mathbf{\Theta}` which maps transformed + node features of shape :obj:`[-1, out_channels]` + to shape :obj:`[-1, out_channels]`. (default: :obj:`None`) + add_self_loops (bool, optional) : If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + positions :math:`(|\mathcal{V}|, 3)` or + :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + """ + + def __init__(self, in_channels: Union[int, Tuple[int, int]], + out_channels: int, pos_nn: Optional[Callable] = None, + attn_nn: Optional[Callable] = None, + add_self_loops: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.add_self_loops = add_self_loops + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.pos_nn = pos_nn + if self.pos_nn is None: + self.pos_nn = nn.Dense(3, out_channels) + + self.attn_nn = attn_nn + self.lin = nn.Dense(in_channels[0], out_channels, has_bias=False) + self.lin_src = nn.Dense(in_channels[0], out_channels, has_bias=False) + self.lin_dst = nn.Dense(in_channels[1], out_channels, has_bias=False) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + reset(self.pos_nn) + if self.attn_nn is not None: + reset(self.attn_nn) + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + pos: Union[Tensor, Tuple[Tensor, Tensor]], + edge_index: Union[Tensor, ], + ) -> Tensor: + + if isinstance(x, Tensor): + alpha = (self.lin_src(x), self.lin_dst(x)) + x = (self.lin(x), x) + else: + alpha = (self.lin_src(x[0]), self.lin_dst(x[1])) + x = (self.lin(x[0]), x[1]) + + if isinstance(pos, Tensor): + pos = (pos, pos) + + if self.add_self_loops: + if isinstance(edge_index, Tensor): + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops( + edge_index, num_nodes=min(pos[0].shape[0], pos[1].shape[0])) + + # propagate_type: (x: Tuple[Tensor, Tensor], pos: Tuple[Tensor, Tensor], alpha: Tuple[Tensor, Tensor]) + out = self.propagate(edge_index, x=x, pos=pos, alpha=alpha) + return out + + def message(self, x_j: Tensor, pos_i: Tensor, pos_j: Tensor, + alpha_i: Tensor, alpha_j: Tensor, index: Tensor, + ptr: Optional[Tensor], size_i: Optional[int]) -> Tensor: + + delta = self.pos_nn(pos_i - pos_j) + alpha = alpha_i - alpha_j + delta + if self.attn_nn is not None: + alpha = self.attn_nn(alpha) + alpha = softmax(alpha, index, ptr, size_i) + return alpha * (x_j + delta) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels})') diff --git a/mindscience/sharker/nn/conv/ppf_conv.py b/mindscience/sharker/nn/conv/ppf_conv.py new file mode 100644 index 000000000..0f1d47639 --- /dev/null +++ b/mindscience/sharker/nn/conv/ppf_conv.py @@ -0,0 +1,129 @@ +from typing import Callable, Optional, Union, Tuple +from mindspore import Tensor, ops, numpy, mint +from .message_passing import MessagePassing +from ..inits import reset +from ...utils import add_self_loops, remove_self_loops + + +def get_angle(v1: Tensor, v2: Tensor) -> Tensor: + return ops.atan2( + numpy.cross(v1, v2, axis=1).norm(ord=2, dim=1), (v1 * v2).sum(1)) + + +def point_pair_features(pos_i: Tensor, pos_j: Tensor, normal_i: Tensor, + normal_j: Tensor) -> Tensor: + pseudo = pos_j - pos_i + return mint.stack([ + numpy.norm(pseudo, ord=2, axis=1), + get_angle(normal_i, pseudo), + get_angle(normal_j, pseudo), + get_angle(normal_i, normal_j) + ], dim=1) + + +class PPFConv(MessagePassing): + r"""The PPFNet operator from the `"PPFNet: Global Context Aware Local + Features for Robust 3D Point Matching" `_ + paper. + + .. math:: + \mathbf{x}^{\prime}_i = \gamma_{\mathbf{\Theta}} \left( \max_{j \in + \mathcal{N}(i) \cup \{ i \}} h_{\mathbf{\Theta}} ( \mathbf{x}_j, \| + \mathbf{d_{j,i}} \|, \angle(\mathbf{n}_i, \mathbf{d_{j,i}}), + \angle(\mathbf{n}_j, \mathbf{d_{j,i}}), \angle(\mathbf{n}_i, + \mathbf{n}_j) \right) + + where :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` + denote neural networks, *.i.e.* MLPs, which takes in node features and + :class:`sharker.transforms.PointPairFeatures`. + + Args: + local_nn (nn.CellList, optional): A neural network + :math:`h_{\mathbf{\Theta}}` that maps node features :obj:`x` and + relative spatial coordinates :obj:`pos_j - pos_i` of shape + :obj:`[-1, in_channels + num_dimensions]` to shape + :obj:`[-1, out_channels]`, *e.g.*, defined by + :class:`nn.SequentialCell`. (default: :obj:`None`) + global_nn (nn.CellList, optional): A neural network + :math:`\gamma_{\mathbf{\Theta}}` that maps aggregated node features + of shape :obj:`[-1, out_channels]` to shape :obj:`[-1, + final_out_channels]`, *e.g.*, defined by + :class:`nn.SequentialCell`. (default: :obj:`None`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + positions :math:`(|\mathcal{V}|, 3)` or + :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, + point normals :math:`(|\mathcal{V}, 3)` or + :math:`((|\mathcal{V_s}|, 3), (|\mathcal{V_t}|, 3))` if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V}_t|, F_{out})` if bipartite + + """ + + def __init__(self, local_nn: Optional[Callable] = None, + global_nn: Optional[Callable] = None, + add_self_loops: bool = True, **kwargs): + kwargs.setdefault('aggr', 'max') + super().__init__(**kwargs) + + self.local_nn = local_nn + self.global_nn = global_nn + self.add_self_loops = add_self_loops + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + reset(self.local_nn) + reset(self.global_nn) + + def construct( + self, + x: Union[Optional[Tensor], Tuple[Optional[Tensor], Optional[Tensor]]], + pos: Union[Tensor, Tuple[Tensor, Tensor]], + normal: Union[Tensor, Tuple[Tensor, Tensor]], + edge_index: Union[Tensor, ], + ) -> Tensor: + + if not isinstance(x, tuple): + x = (x, None) + + if isinstance(pos, Tensor): + pos = (pos, pos) + + if isinstance(normal, Tensor): + normal = (normal, normal) + + if self.add_self_loops: + if isinstance(edge_index, Tensor): + edge_index, _ = remove_self_loops(edge_index) + edge_index, _ = add_self_loops(edge_index, + num_nodes=pos[1].shape[0]) + out = self.propagate(edge_index, x=x, pos=pos, normal=normal) + + if self.global_nn is not None: + out = self.global_nn(out) + + return out + + def message(self, x_j: Optional[Tensor], pos_i: Tensor, pos_j: Tensor, + normal_i: Tensor, normal_j: Tensor) -> Tensor: + msg = point_pair_features(pos_i, pos_j, normal_i, normal_j) + if x_j is not None: + msg = mint.cat(([x_j, msg]), dim=1) + if self.local_nn is not None: + msg = self.local_nn(msg) + return msg + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}(local_nn={self.local_nn}, ' + f'global_nn={self.global_nn})') diff --git a/mindscience/sharker/nn/conv/res_gated_graph_conv.py b/mindscience/sharker/nn/conv/res_gated_graph_conv.py new file mode 100644 index 000000000..5e3efc6f4 --- /dev/null +++ b/mindscience/sharker/nn/conv/res_gated_graph_conv.py @@ -0,0 +1,138 @@ +from typing import Callable, Optional, Tuple, Union +from mindspore import Tensor, Parameter, nn, ops, mint +from ..inits import zeros +from .message_passing import MessagePassing + + +class ResGatedGraphConv(MessagePassing): + r"""The residual gated graph convolutional operator from the + `"Residual Gated Graph ConvNets" `_ + paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + + \sum_{j \in \mathcal{N}(i)} \eta_{i,j} \odot \mathbf{W}_2 \mathbf{x}_j + + where the gate :math:`\eta_{i,j}` is defined as + + .. math:: + \eta_{i,j} = \sigma(\mathbf{W}_3 \mathbf{x}_i + \mathbf{W}_4 + \mathbf{x}_j) + + with :math:`\sigma` denoting the sigmoid function. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + act (callable, optional): Gating function :math:`\sigma`. + (default: :meth:`nn.Sigmoid()`) + edge_dim (int, optional): Edge feature dimensionality (in case + there are any). (default: :obj:`None`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + root_weight (bool, optional): If set to :obj:`False`, the layer will + not add transformed root node features to the output. + (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **inputs:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V_t}|, F_{out})` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + act: Optional[Callable] = nn.Sigmoid, + edge_dim: Optional[int] = None, + root_weight: bool = True, + bias: bool = True, + **kwargs, + ): + + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.act = act() + self.edge_dim = edge_dim + self.root_weight = root_weight + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + edge_dim = edge_dim if edge_dim is not None else 0 + self.lin_key = nn.Dense(in_channels[1] + edge_dim, out_channels) + self.lin_query = nn.Dense(in_channels[0] + edge_dim, out_channels) + self.lin_value = nn.Dense(in_channels[0] + edge_dim, out_channels) + + if root_weight: + self.lin_skip = nn.Dense(in_channels[1], out_channels, has_bias=False) + else: + self.lin_skip = None + + if bias: + self.bias = Parameter(mint.zeros(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + if self.bias is not None: + zeros(self.bias) + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + # In case edge features are not given, we can compute key, query and + # value tensors in node-level space, which is a bit more efficient: + if self.edge_dim is None: + k = self.lin_key(x[1]) + q = self.lin_query(x[0]) + v = self.lin_value(x[0]) + else: + k, q, v = x[1], x[0], x[0] + + # propagate_type: (k: Tensor, q: Tensor, v: Tensor, + # edge_attr: Optional[Tensor]) + out = self.propagate(edge_index, k=k, q=q, v=v, edge_attr=edge_attr) + + if self.root_weight: + out += self.lin_skip(x[1]) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, k_i: Tensor, q_j: Tensor, v_j: Tensor, + edge_attr: Optional[Tensor]) -> Tensor: + + assert (edge_attr is not None) == (self.edge_dim is not None) + + if edge_attr is not None: + k_i = self.lin_key(mint.cat(([k_i, edge_attr]), dim=-1)) + q_j = self.lin_query(mint.cat(([q_j, edge_attr]), dim=-1)) + v_j = self.lin_value(mint.cat(([v_j, edge_attr]), dim=-1)) + + return self.act(k_i + q_j) * v_j diff --git a/mindscience/sharker/nn/conv/rgat_conv.py b/mindscience/sharker/nn/conv/rgat_conv.py new file mode 100644 index 000000000..13d6f3bc0 --- /dev/null +++ b/mindscience/sharker/nn/conv/rgat_conv.py @@ -0,0 +1,515 @@ +from typing import Optional, Union, Tuple +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops, mint +from ...utils import scatter, softmax, Ncon +from ..inits import glorot, ones, zeros +from .message_passing import MessagePassing + + +class RGATConv(MessagePassing): + r"""The relational graph attentional operator from the `"Relational Graph + Attention Networks" `_ paper. + + Here, attention logits :math:`\mathbf{a}^{(r)}_{i,j}` are computed for each + relation type :math:`r` with the help of both query and key kernels, *i.e.* + + .. math:: + \mathbf{q}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot + \mathbf{Q}^{(r)} + \quad \textrm{and} \quad + \mathbf{k}^{(r)}_i = \mathbf{W}_1^{(r)}\mathbf{x}_{i} \cdot + \mathbf{K}^{(r)}. + + Two schemes have been proposed to compute attention logits + :math:`\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r`: + + **Additive attention** + + .. math:: + \mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + + \mathbf{k}^{(r)}_j) + + or **multiplicative attention** + + .. math:: + \mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j. + + If the graph has multi-dimensional edge features + :math:`\mathbf{e}^{(r)}_{i,j}`, the attention logits + :math:`\mathbf{a}^{(r)}_{i,j}` for each relation type :math:`r` are + computed as + + .. math:: + \mathbf{a}^{(r)}_{i,j} = \mathrm{LeakyReLU}(\mathbf{q}^{(r)}_i + + \mathbf{k}^{(r)}_j + \mathbf{W}_2^{(r)}\mathbf{e}^{(r)}_{i,j}) + + or + + .. math:: + \mathbf{a}^{(r)}_{i,j} = \mathbf{q}^{(r)}_i \cdot \mathbf{k}^{(r)}_j + \cdot \mathbf{W}_2^{(r)} \mathbf{e}^{(r)}_{i,j}, + + respectively. + The attention coefficients :math:`\alpha^{(r)}_{i,j}` for each relation + type :math:`r` are then obtained via two different attention mechanisms: + The **within-relation** attention mechanism + + .. math:: + \alpha^{(r)}_{i,j} = + \frac{\exp(\mathbf{a}^{(r)}_{i,j})} + {\sum_{k \in \mathcal{N}_r(i)} \exp(\mathbf{a}^{(r)}_{i,k})} + + or the **across-relation** attention mechanism + + .. math:: + \alpha^{(r)}_{i,j} = + \frac{\exp(\mathbf{a}^{(r)}_{i,j})} + {\sum_{r^{\prime} \in \mathcal{R}} + \sum_{k \in \mathcal{N}_{r^{\prime}}(i)} + \exp(\mathbf{a}^{(r^{\prime})}_{i,k})} + + where :math:`\mathcal{R}` denotes the set of relations, *i.e.* edge types. + Edge type needs to be a one-dimensional :obj:`ms.int64` tensor which + stores a relation identifier :math:`\in \{ 0, \ldots, |\mathcal{R}| - 1\}` + for each edge. + + To enhance the discriminative power of attention-based GNNs, this layer + further implements four different cardinality preservation options as + proposed in the `"Improving Attention Mechanism in Graph Neural Networks + via Cardinality Preservation" `_ paper: + + .. math:: + \text{additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= + \sum_{j \in \mathcal{N}_r(i)} + \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + \mathcal{W} \odot + \sum_{j \in \mathcal{N}_r(i)} \mathbf{x}^{(r)}_j + + \text{scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= + \psi(|\mathcal{N}_r(i)|) \odot + \sum_{j \in \mathcal{N}_r(i)} \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + + \text{f-additive:}~~~\mathbf{x}^{{\prime}(r)}_i &= + \sum_{j \in \mathcal{N}_r(i)} + (\alpha^{(r)}_{i,j} + 1) \cdot \mathbf{x}^{(r)}_j + + \text{f-scaled:}~~~\mathbf{x}^{{\prime}(r)}_i &= + |\mathcal{N}_r(i)| \odot \sum_{j \in \mathcal{N}_r(i)} + \alpha^{(r)}_{i,j} \mathbf{x}^{(r)}_j + + * If :obj:`attention_mode="additive-self-attention"` and + :obj:`concat=True`, the layer outputs :obj:`heads * out_channels` + features for each node. + + * If :obj:`attention_mode="multiplicative-self-attention"` and + :obj:`concat=True`, the layer outputs :obj:`heads * dim * out_channels` + features for each node. + + * If :obj:`attention_mode="additive-self-attention"` and + :obj:`concat=False`, the layer outputs :obj:`out_channels` features for + each node. + + * If :obj:`attention_mode="multiplicative-self-attention"` and + :obj:`concat=False`, the layer outputs :obj:`dim * out_channels` features + for each node. + + Please make sure to set the :obj:`in_channels` argument of the next + layer accordingly if more than one instance of this layer is used. + + .. note:: + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + num_relations (int): Number of relations. + num_bases (int, optional): If set, this layer will use the + basis-decomposition regularization scheme where :obj:`num_bases` + denotes the number of bases to use. (default: :obj:`None`) + num_blocks (int, optional): If set, this layer will use the + block-diagonal-decomposition regularization scheme where + :obj:`num_blocks` denotes the number of blocks to use. + (default: :obj:`None`) + mod (str, optional): The cardinality preservation option to use. + (:obj:`"additive"`, :obj:`"scaled"`, :obj:`"f-additive"`, + :obj:`"f-scaled"`, :obj:`None`). (default: :obj:`None`) + attention_mechanism (str, optional): The attention mechanism to use + (:obj:`"within-relation"`, :obj:`"across-relation"`). + (default: :obj:`"across-relation"`) + attention_mode (str, optional): The mode to calculate attention logits. + (:obj:`"additive-self-attention"`, + :obj:`"multiplicative-self-attention"`). + (default: :obj:`"additive-self-attention"`) + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + dim (int): Number of dimensions for query and key kernels. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the multi-head + attentions are averaged instead of concatenated. + (default: :obj:`True`) + negative_slope (float, optional): LeakyReLU angle of the negative + slope. (default: :obj:`0.2`) + dropout (float, optional): Dropout probability of the normalized + attention coefficients which exposes each node to a stochastically + sampled neighborhood during training. (default: :obj:`0`) + edge_dim (int, optional): Edge feature dimensionality (in case there + are any). (default: :obj:`None`) + bias (bool, optional): If set to :obj:`False`, the layer will not + learn an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + """ + + _alpha: Optional[Tensor] + + def __init__( + self, + in_channels: int, + out_channels: int, + num_relations: int, + num_bases: Optional[int] = None, + num_blocks: Optional[int] = None, + mod: Optional[str] = None, + attention_mechanism: str = "across-relation", + attention_mode: str = "additive-self-attention", + heads: int = 1, + dim: int = 1, + concat: bool = True, + negative_slope: float = 0.2, + dropout: float = 0.0, + edge_dim: Optional[int] = None, + bias: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super().__init__(node_dim=0, **kwargs) + + self.heads = heads + self.negative_slope = negative_slope + self.dropout = dropout + self.mod = mod + self.activation = nn.ReLU() + self.concat = concat + self.attention_mode = attention_mode + self.attention_mechanism = attention_mechanism + self.dim = dim + self.edge_dim = edge_dim + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_relations = num_relations + self.num_bases = num_bases + self.num_blocks = num_blocks + + mod_types = ['additive', 'scaled', 'f-additive', 'f-scaled'] + + if (self.attention_mechanism != "within-relation" + and self.attention_mechanism != "across-relation"): + raise ValueError('attention mechanism must either be ' + '"within-relation" or "across-relation"') + + if (self.attention_mode != "additive-self-attention" + and self.attention_mode != "multiplicative-self-attention"): + raise ValueError('attention mode must either be ' + '"additive-self-attention" or ' + '"multiplicative-self-attention"') + + if self.attention_mode == "additive-self-attention" and self.dim > 1: + raise ValueError('"additive-self-attention" mode cannot be ' + 'applied when value of d is greater than 1. ' + 'Use "multiplicative-self-attention" instead.') + + if self.dropout > 0.0 and self.mod in mod_types: + raise ValueError('mod must be None with dropout value greater ' + 'than 0 in order to sample attention ' + 'coefficients stochastically') + + if num_bases is not None and num_blocks is not None: + raise ValueError('Can not apply both basis-decomposition and ' + 'block-diagonal-decomposition at the same time.') + + # The learnable parameters to compute both attention logits and + # attention coefficients: + self.q = Parameter( + ms.numpy.empty([self.heads * self.out_channels, self.heads * self.dim])) + self.k = Parameter( + ms.numpy.empty([self.heads * self.out_channels, self.heads * self.dim])) + + if bias and concat: + self.bias = Parameter( + ms.numpy.empty(self.heads * self.dim * self.out_channels)) + elif bias and not concat: + self.bias = Parameter(ms.numpy.empty(self.dim * self.out_channels)) + else: + self.bias = None + + if edge_dim is not None: + self.lin_edge = nn.Dense(self.edge_dim, + self.heads * self.out_channels, has_bias=False) + self.e = Parameter( + ms.numpy.empty([self.heads * self.out_channels, + self.heads * self.dim])) + else: + self.lin_edge = None + self.e = None + + if num_bases is not None: + self.att = Parameter( + ms.numpy.empty([self.num_relations, self.num_bases])) + self.basis = Parameter( + ms.numpy.empty([self.num_bases, self.in_channels, + self.heads * self.out_channels])) + elif num_blocks is not None: + assert ( + self.in_channels % self.num_blocks == 0 + and (self.heads * self.out_channels) % self.num_blocks == 0), ( + "both 'in_channels' and 'heads * out_channels' must be " + "multiple of 'num_blocks' used") + self.weight = Parameter( + ms.numpy.empty([self.num_relations, self.num_blocks, + self.in_channels // self.num_blocks, + (self.heads * self.out_channels) // + self.num_blocks])) + else: + self.weight = Parameter( + ms.numpy.empty([self.num_relations, self.in_channels, + self.heads * self.out_channels])) + + self.w = Parameter(mint.ones(self.out_channels)) + self.l1 = Parameter(ms.numpy.empty([1, self.out_channels])) + self.b1 = Parameter(ms.numpy.empty([1, self.out_channels])) + self.l2 = Parameter(ms.numpy.empty([self.out_channels, self.out_channels])) + self.b2 = Parameter(ms.numpy.empty([1, self.out_channels])) + + self._alpha = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + if self.num_bases is not None: + glorot(self.basis) + glorot(self.att) + else: + glorot(self.weight) + glorot(self.q) + glorot(self.k) + if self.bias is not None: + zeros(self.bias) + ones(self.l1) + zeros(self.b1) + self.l2[:] = 1 / self.out_channels + zeros(self.b2) + if self.lin_edge is not None: + glorot(self.lin_edge) + glorot(self.e) + + def construct( + self, + x: Tensor, + edge_index: Union[Tensor, ], + edge_type: Optional[Tensor] = None, + edge_attr: Optional[Tensor] = None, + size: Tuple[int, ...] = None, + return_attention_weights=None, + ): + r"""Runs the forward pass of the module. + + Args: + x (Tensor): The input node features. + Can be either a :obj:`[num_nodes, in_channels]` node feature + matrix, or an optional one-dimensional node index tensor (in + which case input features are treated as trainable node + embeddings). + edge_index (Tensor or SparseTensor): The edge indices. + edge_type (Tensor, optional): The one-dimensional relation + type/index for each edge in :obj:`edge_index`. + Should be only :obj:`None` in case :obj:`edge_index` is of type + :class:`mindspore_sparse.SparseTensor` or + :class:`mindspore.sparse.Tensor`. (default: :obj:`None`) + edge_attr (Tensor, optional): The edge features. + (default: :obj:`None`) + size ((int, int), optional): The shape of the adjacency matrix. + (default: :obj:`None`) + return_attention_weights (bool, optional): If set to :obj:`True`, + will additionally return the tuple + :obj:`(edge_index, attention_weights)`, holding the computed + attention weights for each edge. (default: :obj:`None`) + """ + # propagate_type: (x: Tensor, edge_type: Optional[Tensor], + # edge_attr: Optional[Tensor]) + out = self.propagate(edge_index=edge_index, edge_type=edge_type, x=x, + shape=size, edge_attr=edge_attr) + + alpha = self._alpha + assert alpha is not None + self._alpha = None + + if isinstance(return_attention_weights, bool): + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + else: + return out + + def message(self, x_i: Tensor, x_j: Tensor, edge_type: Tensor, + edge_attr: Optional[Tensor], index: Tensor, ptr: Optional[Tensor], + size_i: Optional[int]) -> Tensor: + + if self.num_bases is not None: # Basis-decomposition ================= + w = self.att @ self.basis.view(self.num_bases, -1) + w = w.view(self.num_relations, self.in_channels, + self.heads * self.out_channels) + if self.num_blocks is not None: # Block-diagonal-decomposition ======= + if (x_i.dtype == ms.int64 and x_j.dtype == ms.int64 + and self.num_blocks is not None): + raise ValueError('Block-diagonal decomposition not supported ' + 'for non-continuous input features.') + w = self.weight + x_i = x_i.view(-1, 1, w.shape[1], w.shape[2]) + x_j = x_j.view(-1, 1, w.shape[1], w.shape[2]) + w = mint.index_select(w, 0, edge_type) + outi = Ncon([[-1, 1, -2, 2], [-1, -2, 2, -3]])([x_i, w]) + + outi = outi.view(-1, self.heads * self.out_channels) + outj = Ncon([[-1, 1, -2, 2], [-1, -2, 2, -3]])([x_i, w]) + + outj = outj.view(-1, self.heads * self.out_channels) + else: # No regularization/Basis-decomposition ======================== + if self.num_bases is None: + w = self.weight + w = mint.index_select(w, 0, edge_type) + outi = mint.bmm(x_i.unsqueeze(1), w).squeeze(-2) + outj = mint.bmm(x_j.unsqueeze(1), w).squeeze(-2) + + qi = outi @ self.q + kj = outj @ self.k + + alpha_edge, alpha = 0, ms.Tensor([0]) + if edge_attr is not None: + if edge_attr.dim() == 1: + edge_attr = edge_attr.view(-1, 1) + assert self.lin_edge is not None, ( + "Please set 'edge_dim = edge_attr.shape[-1]' while calling the " + "RGATConv layer") + edge_attributes = self.lin_edge(edge_attr).view( + -1, self.heads * self.out_channels) + if edge_attributes.shape[0] != edge_attr.shape[0]: + edge_attributes = ops.index_select(edge_attributes, 0, + edge_type) + alpha_edge = edge_attributes @ self.e + + if self.attention_mode == "additive-self-attention": + if edge_attr is not None: + alpha = mint.add(qi, kj) + alpha_edge + else: + alpha = mint.add(qi, kj) + alpha = ops.leaky_relu(alpha, self.negative_slope) + elif self.attention_mode == "multiplicative-self-attention": + if edge_attr is not None: + alpha = (qi * kj) * alpha_edge + else: + alpha = qi * kj + + if self.attention_mechanism == "within-relation": + across_out = mint.zeros_like(alpha) + for r in range(self.num_relations): + mask = edge_type == r + if mask.sum() > 0: + across_out[mask] = softmax(alpha[mask], index[mask]) + alpha = across_out + elif self.attention_mechanism == "across-relation": + alpha = softmax(alpha, index, ptr, size_i) + + self._alpha = alpha + + if self.mod == "additive": + if self.attention_mode == "additive-self-attention": + ones = mint.ones_like(alpha) + h = (outj.view(-1, self.heads, self.out_channels) * + ones.view(-1, self.heads, 1)) + h = self.w * h + + return (outj.view(-1, self.heads, self.out_channels) * + alpha.view(-1, self.heads, 1) + h) + elif self.attention_mode == "multiplicative-self-attention": + ones = mint.ones_like(alpha) + h = (outj.view(-1, self.heads, 1, self.out_channels) * + ones.view(-1, self.heads, self.dim, 1)) + h = self.w * h + + return (outj.view(-1, self.heads, 1, self.out_channels) * + alpha.view(-1, self.heads, self.dim, 1) + h) + + elif self.mod == "scaled": + if self.attention_mode == "additive-self-attention": + ones = alpha.new_ones(index.shape) + degree = scatter(ones, index, dim_size=size_i, + reduce='sum')[index].unsqueeze(-1) + degree = degree @ self.l1 + self.b1 + degree = self.activation(degree) + degree = degree @ self.l2 + self.b2 + + return outj.view(-1, self.heads, self.out_channels) * \ + alpha.view(-1, self.heads, 1) * \ + degree.view(-1, 1, self.out_channels) + elif self.attention_mode == "multiplicative-self-attention": + ones = alpha.new_ones(index.shape) + degree = scatter(ones, index, dim_size=size_i, + reduce='sum')[index].unsqueeze(-1) + degree = degree @ self.l1 + self.b1 + degree = self.activation(degree) + degree = degree @ self.l2 + self.b2 + + return mint.mul( + outj.view(-1, self.heads, 1, self.out_channels) * + alpha.view(-1, self.heads, self.dim, 1), + degree.view(-1, 1, 1, self.out_channels)) + + elif self.mod == "f-additive": + alpha = mint.where(alpha > 0, alpha + 1, alpha) + + elif self.mod == "f-scaled": + ones = alpha.new_ones(index.shape) + degree = scatter(ones, index, dim_size=size_i, + reduce='sum')[index].unsqueeze(-1) + alpha = alpha * degree + + elif self.training and self.dropout > 0: + alpha = ops.dropout(alpha, p=self.dropout, training=True) + + else: + alpha = alpha # original + + if self.attention_mode == "additive-self-attention": + return alpha.view(-1, self.heads, 1) * outj.view( + -1, self.heads, self.out_channels) + else: + return (alpha.view(-1, self.heads, self.dim, 1) * + outj.view(-1, self.heads, 1, self.out_channels)) + + def update(self, aggr_out: Tensor) -> Tensor: + if self.attention_mode == "additive-self-attention": + if self.concat is True: + aggr_out = aggr_out.view(-1, self.heads * self.out_channels) + else: + aggr_out = aggr_out.mean(1) + + if self.bias is not None: + aggr_out += self.bias + + return aggr_out + else: + if self.concat is True: + aggr_out = aggr_out.view( + -1, self.heads * self.dim * self.out_channels) + else: + aggr_out = aggr_out.mean(1) + aggr_out = aggr_out.view(-1, self.dim * self.out_channels) + + if self.bias is not None: + aggr_out += self.bias + + return aggr_out + + def __repr__(self) -> str: + return '{}({}, {}, heads={})'.format(self.__class__.__name__, + self.in_channels, + self.out_channels, self.heads) diff --git a/mindscience/sharker/nn/conv/sage_conv.py b/mindscience/sharker/nn/conv/sage_conv.py new file mode 100644 index 000000000..8bf3f2bcd --- /dev/null +++ b/mindscience/sharker/nn/conv/sage_conv.py @@ -0,0 +1,142 @@ +from typing import List, Optional, Tuple, Union +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from ..aggr import Aggregation, MultiAggregation +from .message_passing import MessagePassing + + +class SAGEConv(MessagePassing): + r"""The GraphSAGE operator from the `"Inductive Representation Learning on + Large Graphs" `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + \mathbf{W}_2 \cdot + \mathrm{mean}_{j \in \mathcal{N(i)}} \mathbf{x}_j + + If :obj:`project = True`, then :math:`\mathbf{x}_j` will first get + projected via + + .. math:: + \mathbf{x}_j \leftarrow \sigma ( \mathbf{W}_3 \mathbf{x}_j + + \mathbf{b}) + + as described in Eq. (3) of the paper. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + aggr (str or Aggregation, optional): The aggregation scheme to use. + Any aggregation of :obj:`sharker.nn.aggr` can be used, + *e.g.*, :obj:`"mean"`, :obj:`"max"`, or :obj:`"lstm"`. + (default: :obj:`"mean"`) + normalize (bool, optional): If set to :obj:`True`, output features + will be :math:`\ell_2`-normalized, *i.e.*, + :math:`\frac{\mathbf{x}^{\prime}_i} + {\| \mathbf{x}^{\prime}_i \|_2}`. + (default: :obj:`False`) + root_weight (bool, optional): If set to :obj:`False`, the layer will + not add transformed root node features to the output. + (default: :obj:`True`) + project (bool, optional): If set to :obj:`True`, the layer will apply a + linear transformation followed by an activation function before + aggregation (as described in Eq. (3) of the paper). + (default: :obj:`False`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **inputs:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{s}), (|\mathcal{V_t}|, F_{t}))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V_t}|, F_{out})` if bipartite + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + aggr: Optional[Union[str, List[str], Aggregation]] = "mean", + normalize: bool = False, + root_weight: bool = True, + project: bool = False, + bias: bool = True, + **kwargs, + ): + self.in_channels = in_channels + self.out_channels = out_channels + self.normalize = normalize + self.root_weight = root_weight + self.project = project + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + if aggr == 'lstm': + kwargs.setdefault('aggr_kwargs', {}) + kwargs['aggr_kwargs'].setdefault('in_channels', in_channels[0]) + kwargs['aggr_kwargs'].setdefault('out_channels', in_channels[0]) + + super().__init__(aggr, **kwargs) + + if self.project: + if in_channels[0] <= 0: + raise ValueError(f"'{self.__class__.__name__}' does not " + f"support lazy initialization with " + f"`project=True`") + self.lin = nn.Dense(in_channels[0], in_channels[0], has_bias=True) + + if isinstance(self.aggr_module, MultiAggregation): + aggr_out_channels = self.aggr_module.get_out_channels( + in_channels[0]) + else: + aggr_out_channels = in_channels[0] + + self.lin_l = nn.Dense(aggr_out_channels, out_channels, has_bias=bias) + if self.root_weight: + self.lin_r = nn.Dense(in_channels[1], out_channels, has_bias=False) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor, ], + size: Tuple[int, ...] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + if self.project and hasattr(self, 'lin'): + x = (ops.relu(self.lin(x[0])), x[1]) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]]) + out = self.propagate(edge_index, x=x, shape=size) + out = self.lin_l(out) + + x_r = x[1] + if self.root_weight and x_r is not None: + out += self.lin_r(x_r) + + if self.normalize: + out /= ms.numpy.norm(out, ord=2., axis=-1, keepdims=True) + out[out.isnan()] = 0 + return out + + def message(self, x_j: Tensor) -> Tensor: + return x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, aggr={self.aggr})') diff --git a/mindscience/sharker/nn/conv/sg_conv.py b/mindscience/sharker/nn/conv/sg_conv.py new file mode 100644 index 000000000..a18e12175 --- /dev/null +++ b/mindscience/sharker/nn/conv/sg_conv.py @@ -0,0 +1,98 @@ +from typing import Optional, Union +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm + + +class SGConv(MessagePassing): + r"""The simple graph convolutional operator from the `"Simplifying Graph + Convolutional Networks" `_ paper. + + .. math:: + \mathbf{X}^{\prime} = {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X} \mathbf{\Theta}, + + where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the + adjacency matrix with inserted self-loops and + :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. + The adjacency matrix can include other values than :obj:`1` representing + edge weights via the optional :obj:`edge_weight` tensor. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + K (int, optional): Number of hops :math:`K`. (default: :obj:`1`) + cached (bool, optional): If set to :obj:`True`, the layer will cache + the computation of :math:`{\left(\mathbf{\hat{D}}^{-1/2} + \mathbf{\hat{A}} \mathbf{\hat{D}}^{-1/2} \right)}^K \mathbf{X}` on + first execution, and will use the cached version for further + executions. + This parameter should only be set to :obj:`True` in transductive + learning scenarios. (default: :obj:`False`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** + node features :math:`(|\mathcal{V}|, F_{out})` + """ + + _cached_x: Optional[Tensor] + + def __init__(self, in_channels: int, out_channels: int, K: int = 1, + cached: bool = False, add_self_loops: bool = True, + bias: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.K = K + self.cached = cached + self.add_self_loops = add_self_loops + + self._cached_x = None + + self.lin = nn.Dense(in_channels, out_channels, has_bias=bias) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + self._cached_x = None + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None) -> Tensor: + + cache = self._cached_x + if cache is None: + if isinstance(edge_index, Tensor): + edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index, edge_weight, x.shape[self.node_dim], False, + self.add_self_loops, self.flow, dtype=x.dtype) + + for k in range(self.K): + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + x = self.propagate(edge_index, x=x, edge_weight=edge_weight) + if self.cached: + self._cached_x = x + else: + x = cache + + return self.lin(x) + + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return edge_weight.view(-1, 1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, K={self.K})') diff --git a/mindscience/sharker/nn/conv/signed_conv.py b/mindscience/sharker/nn/conv/signed_conv.py new file mode 100644 index 000000000..b6721a3d0 --- /dev/null +++ b/mindscience/sharker/nn/conv/signed_conv.py @@ -0,0 +1,136 @@ +from typing import Union, Tuple +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing + + +class SignedConv(MessagePassing): + r"""The signed graph convolutional operator from the `"Signed Graph + Convolutional Network" `_ paper. + + .. math:: + \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} + \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} + \mathbf{x}_w , \mathbf{x}_v \right] + + \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{neg})} + \left[ \frac{1}{|\mathcal{N}^{-}(v)|} \sum_{w \in \mathcal{N}^{-}(v)} + \mathbf{x}_w , \mathbf{x}_v \right] + + if :obj:`first_aggr` is set to :obj:`True`, and + + .. math:: + \mathbf{x}_v^{(\textrm{pos})} &= \mathbf{\Theta}^{(\textrm{pos})} + \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} + \mathbf{x}_w^{(\textrm{pos})}, \frac{1}{|\mathcal{N}^{-}(v)|} + \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{neg})}, + \mathbf{x}_v^{(\textrm{pos})} \right] + + \mathbf{x}_v^{(\textrm{neg})} &= \mathbf{\Theta}^{(\textrm{pos})} + \left[ \frac{1}{|\mathcal{N}^{+}(v)|} \sum_{w \in \mathcal{N}^{+}(v)} + \mathbf{x}_w^{(\textrm{neg})}, \frac{1}{|\mathcal{N}^{-}(v)|} + \sum_{w \in \mathcal{N}^{-}(v)} \mathbf{x}_w^{(\textrm{pos})}, + \mathbf{x}_v^{(\textrm{neg})} \right] + + otherwise. + In case :obj:`first_aggr` is :obj:`False`, the layer expects :obj:`x` to be + a tensor where :obj:`x[:, :in_channels]` denotes the positive node features + :math:`\mathbf{X}^{(\textrm{pos})}` and :obj:`x[:, in_channels:]` denotes + the negative node features :math:`\mathbf{X}^{(\textrm{neg})}`. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + first_aggr (bool): Denotes which aggregation formula to use. + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})` or + :math:`((|\mathcal{V_s}|, F_{in}), (|\mathcal{V_t}|, F_{in}))` + if bipartite, + positive edge indices :math:`(2, |\mathcal{E}^{(+)}|)`, + negative edge indices :math:`(2, |\mathcal{E}^{(-)}|)` + - **outputs:** node features :math:`(|\mathcal{V}|, F_{out})` or + :math:`(|\mathcal{V_t}|, F_{out})` if bipartite + """ + + def __init__(self, in_channels: int, out_channels: int, first_aggr: bool, + bias: bool = True, **kwargs): + + kwargs.setdefault('aggr', 'mean') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.first_aggr = first_aggr + + if first_aggr: + self.lin_pos_l = nn.Dense(in_channels, out_channels, False) + self.lin_pos_r = nn.Dense(in_channels, out_channels, bias) + self.lin_neg_l = nn.Dense(in_channels, out_channels, False) + self.lin_neg_r = nn.Dense(in_channels, out_channels, bias) + else: + self.lin_pos_l = nn.Dense(2 * in_channels, out_channels, False) + self.lin_pos_r = nn.Dense(in_channels, out_channels, bias) + self.lin_neg_l = nn.Dense(2 * in_channels, out_channels, False) + self.lin_neg_r = nn.Dense(in_channels, out_channels, bias) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + pos_edge_index: Union[Tensor, ], + neg_edge_index: Union[Tensor, ], + ): + + if isinstance(x, Tensor): + x = (x, x) + + # propagate_type: (x: Tuple[Tensor, Tensor]) + if self.first_aggr: + + out_pos = self.propagate(pos_edge_index, x=x) + out_pos = self.lin_pos_l(out_pos) + out_pos = out_pos + self.lin_pos_r(x[1]) + + out_neg = self.propagate(neg_edge_index, x=x) + out_neg = self.lin_neg_l(out_neg) + out_neg = out_neg + self.lin_neg_r(x[1]) + + return mint.cat(([out_pos, out_neg]), dim=-1) + + else: + F_in = self.in_channels + + out_pos1 = self.propagate(pos_edge_index, + x=(x[0][..., :F_in], x[1][..., :F_in])) + out_pos2 = self.propagate(neg_edge_index, + x=(x[0][..., F_in:], x[1][..., F_in:])) + out_pos = mint.cat(([out_pos1, out_pos2]), dim=-1) + out_pos = self.lin_pos_l(out_pos) + out_pos = out_pos + self.lin_pos_r(x[1][..., :F_in]) + + out_neg1 = self.propagate(pos_edge_index, + x=(x[0][..., F_in:], x[1][..., F_in:])) + out_neg2 = self.propagate(neg_edge_index, + x=(x[0][..., :F_in], x[1][..., :F_in])) + out_neg = mint.cat(([out_neg1, out_neg2]), dim=-1) + out_neg = self.lin_neg_l(out_neg) + out_neg = out_neg + self.lin_neg_r(x[1][..., F_in:]) + + return mint.cat(([out_pos, out_neg]), dim=-1) + + def message(self, x_j: Tensor) -> Tensor: + return x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, first_aggr={self.first_aggr})') diff --git a/mindscience/sharker/nn/conv/simple_conv.py b/mindscience/sharker/nn/conv/simple_conv.py new file mode 100644 index 000000000..54568b13f --- /dev/null +++ b/mindscience/sharker/nn/conv/simple_conv.py @@ -0,0 +1,84 @@ +from typing import List, Optional, Union, Tuple +from mindspore import Tensor, ops, nn, mint +from ..aggr import Aggregation +from .message_passing import MessagePassing +from ...utils import add_self_loops + + +class SimpleConv(MessagePassing): + r"""A simple message passing operator that performs (non-trainable) + propagation. + + .. math:: + \mathbf{x}^{\prime}_i = \bigoplus_{j \in \mathcal{N(i)}} e_{ji} \cdot + \mathbf{x}_j + + where :math:`\bigoplus` defines a custom aggregation scheme. + + Args: + aggr (str or [str] or Aggregation, optional): The aggregation scheme + to use, *e.g.*, :obj:`"add"`, :obj:`"sum"` :obj:`"mean"`, + :obj:`"min"`, :obj:`"max"` or :obj:`"mul"`. + In addition, can be any + :class:`~sharker.nn.aggr.Aggregation` module (or any string + that automatically resolves to it). (default: :obj:`"sum"`) + combine_root (str, optional): Specifies whether or how to combine the + central node representation (one of :obj:`"sum"`, :obj:`"cat"`, + :obj:`"self_loop"`, :obj:`None`). (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **inputs:** + node features :math:`(|\mathcal{V}|, F)` or + :math:`((|\mathcal{V_s}|, F), (|\mathcal{V_t}|, *))` + if bipartite, + edge indices :math:`(2, |\mathcal{E}|)` + - **outputs:** node features :math:`(|\mathcal{V}|, F)` or + :math:`(|\mathcal{V_t}|, F)` if bipartite + """ + + def __init__( + self, + aggr: Optional[Union[str, List[str], Aggregation]] = "sum", + combine_root: Optional[str] = None, + **kwargs, + ): + if combine_root not in ['sum', 'cat', 'self_loop', None]: + raise ValueError(f"Received invalid value for 'combine_root' " + f"(got '{combine_root}')") + + super().__init__(aggr, **kwargs) + self.combine_root = combine_root + + def construct(self, x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None, size: Tuple[int, ...] = None) -> Tensor: + + if self.combine_root is not None: + if self.combine_root == 'self_loop': + if not isinstance(x, Tensor) or (size is not None + and size[0] != size[1]): + raise ValueError("Cannot use `combine_root='self_loop'` " + "for bipartite message passing") + if isinstance(edge_index, Tensor): + edge_index, edge_weight = add_self_loops( + edge_index, edge_weight, num_nodes=x.shape[0]) + + if isinstance(x, Tensor): + x = (x, x) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_weight: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_weight=edge_weight, + shape=size) + + x_dst = x[1] + if x_dst is not None and self.combine_root is not None: + if self.combine_root == 'sum': + out += x_dst + elif self.combine_root == 'cat': + out = mint.cat(([x_dst, out]), dim=-1) + + return out + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j \ No newline at end of file diff --git a/mindscience/sharker/nn/conv/spline_conv.py b/mindscience/sharker/nn/conv/spline_conv.py new file mode 100644 index 000000000..583ca9b88 --- /dev/null +++ b/mindscience/sharker/nn/conv/spline_conv.py @@ -0,0 +1,142 @@ +import warnings +from typing import List, Tuple, Union, Optional +import mindspore as ms +from mindspore import Tensor, ops, nn, Parameter, mint +from .message_passing import MessagePassing +from ..inits import uniform, zeros +from ...utils.repeat import repeat + +try: + from mindspore_spline_conv import spline_basis, spline_weighting +except (ImportError, OSError): # Fail gracefully on GLIBC errors + spline_basis = None + spline_weighting = None + + +class SplineConv(MessagePassing): + r"""The spline-based convolutional operator from the `"SplineCNN: Fast + Geometric Deep Learning with Continuous B-Spline Kernels" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \frac{1}{|\mathcal{N}(i)|} \sum_{j \in + \mathcal{N}(i)} \mathbf{x}_j \cdot + h_{\mathbf{\Theta}}(\mathbf{e}_{i,j}), + + where :math:`h_{\mathbf{\Theta}}` denotes a kernel function defined + over the weighted B-Spline tensor product basis. + + .. note:: + + Pseudo-coordinates must lay in the fixed interval :math:`[0, 1]` for + this method to work as intended. + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + dim (int): Pseudo-coordinate dimensionality. + kernel_size (int or [int]): Size of the convolving kernel. + is_open_spline (bool or [bool], optional): If set to :obj:`False`, the + operator will use a closed B-spline basis in this dimension. + (default :obj:`True`) + degree (int, optional): B-spline basis degrees. (default: :obj:`1`) + aggr (str, optional): The aggregation scheme to use + (:obj:`"add"`, :obj:`"mean"`, :obj:`"max"`). + (default: :obj:`"mean"`) + root_weight (bool, optional): If set to :obj:`False`, the layer will + not add transformed root node features to the output. + (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + """ + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + dim: int, + kernel_size: Union[int, List[int]], + is_open_spline: bool = True, + degree: int = 1, + aggr: str = 'mean', + root_weight: bool = True, + bias: bool = True, + **kwargs, + ): + super().__init__(aggr=aggr, **kwargs) + + if spline_basis is None: + raise ImportError("'SplineConv' requires 'mindspore-spline-conv'") + + self.in_channels = in_channels + self.out_channels = out_channels + self.dim = dim + self.degree = degree + self.root_weight = root_weight + + kernel_size = ms.Tensor(repeat(kernel_size, dim), dtype=ms.int64) + self.kernel_size = Parameter(kernel_size, requires_grad=False) + + is_open_spline = repeat(is_open_spline, dim) + is_open_spline = ms.Tensor(is_open_spline, dtype=ms.uint8) + self.is_open_spline = Parameter(is_open_spline, requires_grad=False) + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.K = kernel_size.prod().item() + + if in_channels[0] > 0: + self.weight = Parameter( + ms.numpy.empty([self.K, in_channels[0], out_channels])) + + if root_weight: + self.lin = nn.Dense(in_channels[1], out_channels, has_bias=False, + weight_initializer='uniform') + + if bias: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + if hasattr(self, 'weight'): + size = self.weight.shape[0] * self.weight.shape[1] + uniform(size, self.weight) + if self.root_weight: + self.lin.reset_parameters() + if self.bias is not None: + zeros(self.bias) + + def construct(self, x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, size: Tuple[int, ...] = None) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + out = self.propagate(edge_index, x=x, edge_attr=edge_attr, shape=size) + + x_r = x[1] + if x_r is not None and self.root_weight: + out += self.lin(x_r) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, edge_attr: Tensor) -> Tensor: + data = spline_basis(edge_attr, self.kernel_size, self.is_open_spline, + self.degree) + return spline_weighting(x_j, self.weight, *data) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, dim={self.dim})') diff --git a/mindscience/sharker/nn/conv/ssg_conv.py b/mindscience/sharker/nn/conv/ssg_conv.py new file mode 100644 index 000000000..4a0a9d0d4 --- /dev/null +++ b/mindscience/sharker/nn/conv/ssg_conv.py @@ -0,0 +1,109 @@ +from typing import Optional, Union +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm + + +class SSGConv(MessagePassing): + r"""The simple spectral graph convolutional operator from the + `"Simple Spectral Graph Convolution" + `_ paper. + + .. math:: + \mathbf{X}^{\prime} = \frac{1}{K} \sum_{k=1}^K\left((1-\alpha) + {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2} \right)}^k + \mathbf{X}+\alpha \mathbf{X}\right) \mathbf{\Theta}, + + where :math:`\mathbf{\hat{A}} = \mathbf{A} + \mathbf{I}` denotes the + adjacency matrix with inserted self-loops and + :math:`\hat{D}_{ii} = \sum_{j=0} \hat{A}_{ij}` its diagonal degree matrix. + The adjacency matrix can include other values than :obj:`1` representing + edge weights via the optional :obj:`edge_weight` tensor. + :class:`~sharker.nn.conv.SSGConv` is an improved operator of + :class:`~sharker.nn.conv.SGConv` by introducing the :obj:`alpha` + parameter to address the oversmoothing issue. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + alpha (float): Teleport probability :math:`\alpha \in [0, 1]`. + K (int, optional): Number of hops :math:`K`. (default: :obj:`1`) + cached (bool, optional): If set to :obj:`True`, the layer will cache + the computation of :math:`\frac{1}{K} \sum_{k=1}^K\left((1-\alpha) + {\left(\mathbf{\hat{D}}^{-1/2} \mathbf{\hat{A}} + \mathbf{\hat{D}}^{-1/2} \right)}^k \mathbf{X}+ + \alpha \mathbf{X}\right)` on first execution, and will use the + cached version for further executions. + This parameter should only be set to :obj:`True` in transductive + learning scenarios. (default: :obj:`False`) + add_self_loops (bool, optional): If set to :obj:`False`, will not add + self-loops to the input graph. (default: :obj:`True`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** + node features :math:`(|\mathcal{V}|, F_{out})` + """ + + _cached_h: Optional[Tensor] + + def __init__(self, in_channels: int, out_channels: int, alpha: float, + K: int = 1, cached: bool = False, add_self_loops: bool = True, + bias: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.alpha = alpha + self.K = K + self.cached = cached + self.add_self_loops = add_self_loops + + self._cached_h = None + + self.lin = nn.Dense(in_channels, out_channels, has_bias=bias) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + self._cached_h = None + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None) -> Tensor: + + cache = self._cached_h + if cache is None: + if isinstance(edge_index, Tensor): + edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index, edge_weight, x.shape[self.node_dim], False, + self.add_self_loops, self.flow, dtype=x.dtype) + + h = x * self.alpha + for k in range(self.K): + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + x = self.propagate(edge_index, x=x, edge_weight=edge_weight) + h = h + (1 - self.alpha) / self.K * x + if self.cached: + self._cached_h = h + else: + h = cache + + return self.lin(h) + + def message(self, x_j: Tensor, edge_weight: Tensor) -> Tensor: + return edge_weight.view(-1, 1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, K={self.K}, alpha={self.alpha})') diff --git a/mindscience/sharker/nn/conv/tag_conv.py b/mindscience/sharker/nn/conv/tag_conv.py new file mode 100644 index 000000000..703a8cd1a --- /dev/null +++ b/mindscience/sharker/nn/conv/tag_conv.py @@ -0,0 +1,95 @@ +from typing import Union, Optional +import mindspore as ms +from mindspore import Tensor, ops, nn, Parameter, mint +from .message_passing import MessagePassing +from .gcn_conv import gcn_norm +from ..inits import zeros + + +class TAGConv(MessagePassing): + r"""The topology adaptive graph convolutional networks operator from the + `"Topology Adaptive Graph Convolutional Networks" + `_ paper. + + .. math:: + \mathbf{X}^{\prime} = \sum_{k=0}^K \left( \mathbf{D}^{-1/2} \mathbf{A} + \mathbf{D}^{-1/2} \right)^k \mathbf{X} \mathbf{W}_{k}, + + where :math:`\mathbf{A}` denotes the adjacency matrix and + :math:`D_{ii} = \sum_{j=0} A_{ij}` its diagonal degree matrix. + The adjacency matrix can include other values than :obj:`1` representing + edge weights via the optional :obj:`edge_weight` tensor. + + Args: + in_channels (int): Size of each input sample, or :obj:`-1` to derive + the size from the first input(s) to the forward method. + out_channels (int): Size of each output sample. + K (int, optional): Number of hops :math:`K`. (default: :obj:`3`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + normalize (bool, optional): Whether to apply symmetric normalization. + (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node_features :math:`(|\mathcal{V}|, F_{in})`, + edge_index :math:`(2, |\mathcal{E}|)`, + edge_weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F_{out})` + """ + + def __init__(self, in_channels: int, out_channels: int, K: int = 3, + bias: bool = True, normalize: bool = True, **kwargs): + kwargs.setdefault('aggr', 'add') + super().__init__(**kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.K = K + self.normalize = normalize + + self.lins = nn.CellList([ + nn.Dense(in_channels, out_channels, has_bias=False) for _ in range(K + 1) + ]) + + if bias: + self.bias = Parameter(ms.numpy.empty(out_channels)) + else: + self.bias = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + if self.bias is not None: + zeros(self.bias) + + def construct(self, x: Tensor, edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None) -> Tensor: + + if self.normalize: + if isinstance(edge_index, Tensor): + edge_index, edge_weight = gcn_norm( # yapf: disable + edge_index, edge_weight, x.shape[self.node_dim], + improved=False, add_self_loops=False, flow=self.flow, + dtype=x.dtype) + + out = self.lins[0](x) + for lin in self.lins[1:]: + # propagate_type: (x: Tensor, edge_weight: Optional[Tensor]) + x = self.propagate(edge_index, x=x, edge_weight=edge_weight) + out += lin(x) + + if self.bias is not None: + out += self.bias + + return out + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, K={self.K})') diff --git a/mindscience/sharker/nn/conv/transformer_conv.py b/mindscience/sharker/nn/conv/transformer_conv.py new file mode 100644 index 000000000..559c3b2f8 --- /dev/null +++ b/mindscience/sharker/nn/conv/transformer_conv.py @@ -0,0 +1,224 @@ +import math +from typing import Optional, Tuple, Union +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ...utils import softmax + + +class TransformerConv(MessagePassing): + r"""The graph transformer operator from the `"Masked Label Prediction: + Unified Message Passing Model for Semi-Supervised Classification" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j}, + + where the attention coefficients :math:`\alpha_{i,j}` are computed via + multi-head dot product attention: + + .. math:: + \alpha_{i,j} = \textrm{softmax} \left( + \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} (\mathbf{W}_4\mathbf{x}_j)} + {\sqrt{d}} \right) + + Args: + in_channels (int or tuple): Size of each input sample, or :obj:`-1` to + derive the size from the first input(s) to the forward method. + A tuple corresponds to the sizes of source and target + dimensionalities. + out_channels (int): Size of each output sample. + heads (int, optional): Number of multi-head-attentions. + (default: :obj:`1`) + concat (bool, optional): If set to :obj:`False`, the multi-head + attentions are averaged instead of concatenated. + (default: :obj:`True`) + beta (bool, optional): If set, will combine aggregation and + skip information via + + .. math:: + \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i + + (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)} + \alpha_{i,j} \mathbf{W}_2 \vec{x}_j \right)}_{=\mathbf{m}_i} + + with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top} + [ \mathbf{W}_1 \mathbf{x}_i, \mathbf{m}_i, \mathbf{W}_1 + \mathbf{x}_i - \mathbf{m}_i ])` (default: :obj:`False`) + dropout (float, optional): Dropout probability of the normalized + attention coefficients which exposes each node to a stochastically + sampled neighborhood during training. (default: :obj:`0`) + edge_dim (int, optional): Edge feature dimensionality (in case + there are any). Edge features are added to the keys after + linear transformation, that is, prior to computing the + attention dot product. They are also added to final values + after the same linear transformation. The model is: + + .. math:: + \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i + + \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left( + \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij} + \right), + + where the attention coefficients :math:`\alpha_{i,j}` are now + computed via: + + .. math:: + \alpha_{i,j} = \textrm{softmax} \left( + \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top} + (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})} + {\sqrt{d}} \right) + + (default :obj:`None`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + root_weight (bool, optional): If set to :obj:`False`, the layer will + not add the transformed root node features to the output and the + option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + """ + _alpha: Optional[Tensor] + + def __init__( + self, + in_channels: Union[int, Tuple[int, int]], + out_channels: int, + heads: int = 1, + concat: bool = True, + beta: bool = False, + dropout: float = 0., + edge_dim: Optional[int] = None, + bias: bool = True, + root_weight: bool = True, + **kwargs, + ): + kwargs.setdefault('aggr', 'add') + super().__init__(node_dim=0, **kwargs) + + self.in_channels = in_channels + self.out_channels = out_channels + self.heads = heads + self.beta = beta and root_weight + self.root_weight = root_weight + self.concat = concat + self.dropout = dropout + self.edge_dim = edge_dim + self._alpha = None + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_key = nn.Dense(in_channels[0], heads * out_channels) + self.lin_query = nn.Dense(in_channels[1], heads * out_channels) + self.lin_value = nn.Dense(in_channels[0], heads * out_channels) + if edge_dim is not None: + self.lin_edge = nn.Dense(edge_dim, heads * out_channels, has_bias=False) + else: + self.lin_edge = None + + if concat: + self.lin_skip = nn.Dense(in_channels[1], heads * out_channels, + has_bias=bias) + if self.beta: + self.lin_beta = nn.Dense(3 * heads * out_channels, 1, has_bias=False) + else: + self.lin_beta = None + else: + self.lin_skip = nn.Dense(in_channels[1], out_channels, has_bias=bias) + if self.beta: + self.lin_beta = nn.Dense(3 * out_channels, 1, has_bias=False) + else: + self.lin_beta = None + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + + def construct( # noqa: F811 + self, + x: Union[Tensor, Tuple[Tensor, Tensor]], + edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, + return_attention_weights: Optional[bool] = None, + ) -> Union[ + Tensor, + Tuple[Tensor, Tuple[Tensor, Tensor]], + Tuple[Tensor, ], + ]: + r"""Runs the forward pass of the module. + + Args: + x (Tensor or (Tensor, Tensor)): The input node + features. + edge_index (Tensor or SparseTensor): The edge indices. + edge_attr (Tensor, optional): The edge features. + (default: :obj:`None`) + return_attention_weights (bool, optional): If set to :obj:`True`, + will additionally return the tuple + :obj:`(edge_index, attention_weights)`, holding the computed + attention weights for each edge. (default: :obj:`None`) + """ + H, C = self.heads, self.out_channels + + if isinstance(x, Tensor): + x = (x, x) + + query = self.lin_query(x[1]).view(-1, H, C) + key = self.lin_key(x[0]).view(-1, H, C) + value = self.lin_value(x[0]).view(-1, H, C) + + # propagate_type: (query: Tensor, key:Tensor, value: Tensor, + # edge_attr: Optional[Tensor]) + out = self.propagate(edge_index, query=query, key=key, value=value, + edge_attr=edge_attr) + + alpha = self._alpha + self._alpha = None + + if self.concat: + out = out.view(-1, self.heads * self.out_channels) + else: + out = out.mean(1) + + if self.root_weight: + x_r = self.lin_skip(x[1]) + if self.lin_beta is not None: + beta = self.lin_beta(mint.cat(([out, x_r, out - x_r]), dim=-1)) + beta = beta.sigmoid() + out = beta * x_r + (1 - beta) * out + else: + out += x_r + + if isinstance(return_attention_weights, bool): + assert alpha is not None + if isinstance(edge_index, Tensor): + return out, (edge_index, alpha) + else: + return out + + def message(self, query_i: Tensor, key_j: Tensor, value_j: Tensor, + edge_attr: Optional[Tensor], index: Tensor, ptr: Optional[Tensor], + size_i: Optional[int]) -> Tensor: + + if self.lin_edge is not None: + assert edge_attr is not None + edge_attr = self.lin_edge(edge_attr).view(-1, self.heads, + self.out_channels) + key_j = key_j + edge_attr + + alpha = (query_i * key_j).sum(-1) / math.sqrt(self.out_channels) + alpha = softmax(alpha, index, ptr, size_i) + self._alpha = alpha + alpha = ops.dropout(alpha, p=self.dropout, training=self.training) + + out = value_j + if edge_attr is not None: + out += edge_attr + + out = out * alpha.view(-1, self.heads, 1) + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, heads={self.heads})') diff --git a/mindscience/sharker/nn/conv/wl_conv.py b/mindscience/sharker/nn/conv/wl_conv.py new file mode 100644 index 000000000..1dafb1677 --- /dev/null +++ b/mindscience/sharker/nn/conv/wl_conv.py @@ -0,0 +1,81 @@ +from typing import Optional, Union +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from ...utils import ( + degree, + is_sparse_tensor, + scatter, + sort_edge_index, +) + + +class WLConv(nn.CellList): + r"""The Weisfeiler Lehman (WL) operator from the `"A Reduction of a Graph + to a Canonical Form and an Algebra Arising During this Reduction" + `_ paper. + + :class:`WLConv` iteratively refines node colorings according to: + + .. math:: + \mathbf{x}^{\prime}_i = \textrm{hash} \left( \mathbf{x}_i, \{ + \mathbf{x}_j \colon j \in \mathcal{N}(i) \} \right) + + Shapes: + - **input:** + node coloring :math:`(|\mathcal{V}|, F_{in})` *(one-hot encodings)* + or :math:`(|\mathcal{V}|)` *(integer-based)*, + edge indices :math:`(2, |\mathcal{E}|)` + - **output:** node coloring :math:`(|\mathcal{V}|)` *(integer-based)* + """ + + def __init__(self): + super().__init__() + self.hashmap = {} + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + self.hashmap = {} + + def construct(self, x: Tensor, edge_index: Union[Tensor, ]) -> Tensor: + r"""Runs the forward pass of the module.""" + if x.dim() > 1: + assert (x.sum(-1) == 1).sum() == x.shape[0] + x = x.argmax(axis=-1) + assert x.dtype == ms.int64 + + edge_index = sort_edge_index(edge_index, num_nodes=x.shape[0], + sort_by_row=False) + row, col = edge_index[0], edge_index[1] + + deg = degree(col, x.shape[0], dtype=ms.int64).tolist() + + out = [] + for node, neighbors in zip(x.tolist(), x[row].split(deg)): + idx = hash(tuple([node] + neighbors.sort()[0].tolist())) + if idx not in self.hashmap: + self.hashmap[idx] = len(self.hashmap) + out.append(self.hashmap[idx]) + + return ms.Tensor(out) + + def histogram(self, x: Tensor, batch: Optional[Tensor] = None, + norm: bool = False) -> Tensor: + r"""Given a node coloring :obj:`x`, computes the color histograms of + the respective graphs (separated by :obj:`batch`). + """ + if batch is None: + batch = mint.zeros(x.shape[0], dtype=ms.int32) + + num_colors = len(self.hashmap) + batch_size = int(batch.max()) + 1 + + index = batch * num_colors + x + out = scatter(mint.ones_like(index), index, dim=0, + dim_size=num_colors * batch_size, reduce='sum') + out = out.view(batch_size, num_colors) + + if norm: + out = out.float() + out /= ms.numpy.norm(out, axis=-1, keepdims=True) + out[out.isnan()] = 0 + return out diff --git a/mindscience/sharker/nn/conv/wl_conv_continuous.py b/mindscience/sharker/nn/conv/wl_conv_continuous.py new file mode 100644 index 000000000..b5b0082b5 --- /dev/null +++ b/mindscience/sharker/nn/conv/wl_conv_continuous.py @@ -0,0 +1,72 @@ +from typing import Union, Tuple, Optional +from mindspore import Tensor, ops, nn, mint +from .message_passing import MessagePassing +from ...utils import scatter + + +class WLConvContinuous(MessagePassing): + r"""The Weisfeiler Lehman operator from the `"Wasserstein + Weisfeiler-Lehman Graph Kernels" `_ + paper. + + Refinement is done though a degree-scaled mean aggregation and works on + nodes with continuous attributes: + + .. math:: + \mathbf{x}^{\prime}_i = \frac{1}{2}\big(\mathbf{x}_i + + \frac{1}{\textrm{deg}(i)} + \sum_{j \in \mathcal{N}(i)} e_{j,i} \cdot \mathbf{x}_j \big) + + where :math:`e_{j,i}` denotes the edge weight from source node :obj:`j` to + target node :obj:`i` (default: :obj:`1`) + + Args: + **kwargs (optional): Additional arguments of + :class:`sharker.nn.conv.MessagePassing`. + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F)` or + :math:`((|\mathcal{V_s}|, F), (|\mathcal{V_t}|, F))` if bipartite, + edge indices :math:`(2, |\mathcal{E}|)`, + edge weights :math:`(|\mathcal{E}|)` *(optional)* + - **output:** node features :math:`(|\mathcal{V}|, F)` or + :math:`(|\mathcal{V}_t|, F)` if bipartite + """ + + def __init__(self, **kwargs): + super().__init__(aggr='add', **kwargs) + + def construct( + self, + x: Union[Tensor, Tuple[Tensor, Optional[Tensor]]], + edge_index: Union[Tensor, ], + edge_weight: Optional[Tensor] = None, + size: Tuple[int, ...] = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + # propagate_type: (x: Tuple[Tensor, Optional[Tensor]], edge_weight: Optional[Tensor]) + out = self.propagate(edge_index, x=x, edge_weight=edge_weight, + shape=size) + + dst_index = edge_index[1] + + if edge_weight is None: + edge_weight = x[0].new_ones(dst_index.numel()) + + deg = scatter(edge_weight, dst_index, 0, out.shape[0], reduce='sum') + deg_inv = 1. / deg + deg_inv[deg_inv == float('inf')] = 0 + out = deg_inv.view(-1, 1) * out + + x_dst = x[1] + if x_dst is not None: + out = 0.5 * (x_dst + out) + + return out + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j \ No newline at end of file diff --git a/mindscience/sharker/nn/conv/x_conv.py b/mindscience/sharker/nn/conv/x_conv.py new file mode 100644 index 000000000..59a21e69b --- /dev/null +++ b/mindscience/sharker/nn/conv/x_conv.py @@ -0,0 +1,148 @@ +from math import ceil +from typing import Optional +from mindspore import Tensor, ops, nn, mint +from ..reshape import Reshape +from ..inits import reset +from ...utils.cluster import knn_graph + + +class XConv(nn.CellList): + r"""The convolutional operator on :math:`\mathcal{X}`-transformed points + from the `"PointCNN: Convolution On X-Transformed Points" + `_ paper. + + .. math:: + \mathbf{x}^{\prime}_i = \mathrm{Conv}\left(\mathbf{K}, + \gamma_{\mathbf{\Theta}}(\mathbf{P}_i - \mathbf{p}_i) \times + \left( h_\mathbf{\Theta}(\mathbf{P}_i - \mathbf{p}_i) \, \Vert \, + \mathbf{x}_i \right) \right), + + where :math:`\mathbf{K}` and :math:`\mathbf{P}_i` denote the trainable + filter and neighboring point positions of :math:`\mathbf{x}_i`, + respectively. + :math:`\gamma_{\mathbf{\Theta}}` and :math:`h_{\mathbf{\Theta}}` describe + neural networks, *i.e.* MLPs, where :math:`h_{\mathbf{\Theta}}` + individually lifts each point into a higher-dimensional space, and + :math:`\gamma_{\mathbf{\Theta}}` computes the :math:`\mathcal{X}`- + transformation matrix based on *all* points in a neighborhood. + + Args: + in_channels (int): Size of each input sample. + out_channels (int): Size of each output sample. + dim (int): Point cloud dimensionality. + kernel_size (int): Size of the convolving kernel, *i.e.* number of + neighbors including self-loops. + hidden_channels (int, optional): Output size of + :math:`h_{\mathbf{\Theta}}`, *i.e.* dimensionality of lifted + points. If set to :obj:`None`, will be automatically set to + :obj:`in_channels / 4`. (default: :obj:`None`) + dilation (int, optional): The factor by which the neighborhood is + extended, from which :obj:`kernel_size` neighbors are then + uniformly sampled. Can be interpreted as the dilation rate of + classical convolutional operators. (default: :obj:`1`) + bias (bool, optional): If set to :obj:`False`, the layer will not learn + an additive bias. (default: :obj:`True`) + + Shapes: + - **input:** + node features :math:`(|\mathcal{V}|, F_{in})`, + positions :math:`(|\mathcal{V}|, D)`, + batch vector :math:`(|\mathcal{V}|)` *(optional)* + - **output:** + node features :math:`(|\mathcal{V}|, F_{out})` + """ + + def __init__(self, in_channels: int, out_channels: int, dim: int, + kernel_size: int, hidden_channels: Optional[int] = None, + dilation: int = 1, bias: bool = True): + super().__init__() + + if knn_graph is None: + raise ImportError('`XConv` requires `mindspore-cluster`.') + + self.in_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels // 4 + assert hidden_channels > 0 + self.hidden_channels = hidden_channels + self.out_channels = out_channels + self.dim = dim + self.kernel_size = kernel_size + self.dilation = dilation + + C_in, C_delta, C_out = in_channels, hidden_channels, out_channels + D, K = dim, kernel_size + + self.mlp1 = nn.SequentialCell( + nn.Dense(dim, C_delta), + nn.ELU(), + nn.BatchNorm1d(C_delta), + nn.Dense(C_delta, C_delta), + nn.ELU(), + nn.BatchNorm1d(C_delta), + Reshape(-1, K, C_delta), + ) + + self.mlp2 = nn.SequentialCell( + nn.Dense(D * K, K**2), + nn.ELU(), + nn.BatchNorm1d(K**2), + Reshape(-1, K, K), + nn.Conv1d(K, K**2, K, group=K, pad_mode='valid'), + nn.ELU(), + nn.BatchNorm1d(K**2), + Reshape(-1, K, K), + nn.Conv1d(K, K**2, K, group=K, pad_mode='valid'), + nn.BatchNorm1d(K**2), + Reshape(-1, K, K), + ) + + C_in += C_delta + depth_multiplier = int(ceil(C_out / C_in)) + self.conv = nn.SequentialCell( + nn.Conv1d(C_in, C_in * depth_multiplier, K, group=C_in, pad_mode='valid'), + Reshape(-1, C_in * depth_multiplier), + nn.Dense(C_in * depth_multiplier, C_out, has_bias=bias), + ) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + reset(self.mlp1) + reset(self.mlp2) + reset(self.conv) + + def construct(self, x: Tensor, pos: Tensor, batch: Optional[Tensor] = None): + r"""Runs the forward pass of the module.""" + pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos + (N, D), K = pos.shape, self.kernel_size + + edge_index = knn_graph(pos, K * self.dilation, batch, loop=True, + flow='trg_to_src') + + if self.dilation > 1: + edge_index = edge_index[:, ::self.dilation] + + row, col = edge_index[0], edge_index[1] + + pos = pos[col] - pos[row] + + x_star = self.mlp1(pos) + if x is not None: + x = x.unsqueeze(-1) if x.dim() == 1 else x + x = x[col].view(N, K, self.in_channels) + x_star = mint.cat(([x_star, x]), dim=-1) + x_star = x_star.swapaxes(1, 2) + + transform_matrix = self.mlp2(pos.view(N, K * D)) + + x_transformed = x_star @ transform_matrix + + out = self.conv(x_transformed) + + return out + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels})') diff --git a/mindscience/sharker/nn/dense/__init__.py b/mindscience/sharker/nn/dense/__init__.py new file mode 100644 index 000000000..02ca87f35 --- /dev/null +++ b/mindscience/sharker/nn/dense/__init__.py @@ -0,0 +1,14 @@ +r"""Dense neural network module package. + +This package provides modules applicable for operating on dense tensor +representations. +""" + +from .linear import HeteroLinear, HeteroDictLinear + +__all__ = [ + "HeteroLinear", + "HeteroDictLinear", +] + +lin_classes = __all__[:2] diff --git a/mindscience/sharker/nn/dense/linear.py b/mindscience/sharker/nn/dense/linear.py new file mode 100644 index 000000000..52ee2df2f --- /dev/null +++ b/mindscience/sharker/nn/dense/linear.py @@ -0,0 +1,272 @@ +import math +from typing import Any, Dict, Optional, Tuple, Union + +import mindspore as ms +from mindspore import Tensor, ops, nn, Parameter +from mindspore.common.initializer import initializer as init +from ..inits import glorot, zeros, uniform +from ...utils import index2ptr + + +def is_uninitialized_parameter(x: Any) -> bool: + if not hasattr(Parameter, "UninitializedParameter"): + return False + return x.is_init + + +def reset_weight_( + weight: Tensor, in_channels: int, initializer: Optional[str] = None +) -> Tensor: + if in_channels <= 0: + pass + elif initializer == "glorot": + glorot(weight) + elif initializer == "uniform": + bound = 1.0 / math.sqrt(in_channels) + weight[:] = init("uniform", weight.shape).init_data() * bound + else: + raise RuntimeError(f"Weight initializer '{initializer}' not supported") + + return weight + + +def reset_bias_( + bias: Optional[Tensor], in_channels: int, initializer: Optional[str] = None +) -> Optional[Tensor]: + if bias is None or in_channels <= 0: + pass + elif initializer == "zeros": + zeros(bias) + elif initializer is None: + uniform(bias, size=in_channels) + else: + raise RuntimeError(f"Bias initializer '{initializer}' not supported") + + return bias + + +class HeteroLinear(nn.Cell): + r"""Applies separate linear tranformations to the incoming data according + to types. + + For type :math:`\kappa`, it computes + + .. math:: + \mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} + \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}. + + It supports lazy initialization and customizable weight and bias + initialization. + + Args: + in_channels (int): Size of each input sample. Will be initialized + lazily in case it is given as :obj:`-1`. + out_channels (int): Size of each output sample. + num_types (int): The number of types. + is_sorted (bool, optional): If set to :obj:`True`, assumes that + :obj:`type_vec` is sorted. This avoids internal re-sorting of the + data and can improve runtime and memory efficiency. + (default: :obj:`False`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.Dense`. + + Shapes: + - **input:** + features :math:`(*, F_{in})`, + type vector :math:`(*)` + - **output:** features :math:`(*, F_{out})` + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + num_types: int, + is_sorted: bool = False, + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.num_types = num_types + self.is_sorted = is_sorted + self.kwargs = kwargs + + if self.in_channels != -1: + self.weight = ms.Parameter( + ms.numpy.empty([num_types, in_channels, out_channels]) + ) + if kwargs.get("bias", True): + self.bias = ms.Parameter(ms.numpy.empty([num_types, out_channels])) + self.reset_parameters() + + # Timing cache for benchmarking naive vs. segment matmul usage: + self._timing_cache: Dict[int, Tuple[float, float]] = {} + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + if hasattr(self, 'weight'): + reset_weight_( + self.weight, self.in_channels, self.kwargs.get("weight_init", 'uniform') + ) + if hasattr(self, 'bias'): + reset_bias_(self.bias, self.in_channels, self.kwargs.get("bias_init", 'zeros')) + + def construct_naive(self, x: Tensor, type_ptr: Tensor) -> Tensor: + out = ops.zeros([x.shape[0], self.out_channels], dtype=x.dtype) + for i, (start, end) in enumerate(zip(type_ptr[:-1], type_ptr[1:])): + out[start:end] = x[start:end] @ self.weight[i] + return out + + def construct(self, x: Tensor, type_vec: Tensor) -> Tensor: + r"""The forward pass. + + Args: + x (Tensor): The input features. + type_vec (Tensor): A vector that maps each entry to a type. + """ + perm: Optional[Tensor] = None + if not self.is_sorted and (type_vec[1:] < type_vec[:-1]).any(): + perm = ops.argsort(type_vec) + type_vec = type_vec[perm] + x = x[perm] + + type_ptr = index2ptr(type_vec, self.num_types) + + out = self.construct_naive(x, type_ptr) + + if self.bias is not None: + out += self.bias[type_vec] + + if perm is not None: # Restore original order (if necessary). + out_unsorted = ms.numpy.empty_like(out) + out_unsorted[perm] = out + out = out_unsorted + + return out + + def initialize_parameters(self, module, input): + if is_uninitialized_parameter(self.weight): + self.in_channels = input[0].shape[-1] + self.weight.materialize( + (self.num_types, self.in_channels, self.out_channels) + ) + self.reset_parameters() + # self._hook.remove() + delattr(self, "_hook") + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.in_channels}, " + f"{self.out_channels}, num_types={self.num_types}, " + f'bias={self.kwargs.get("bias", True)})' + ) + + +class HeteroDictLinear(nn.Cell): + r"""Applies separate linear tranformations to the incoming data dictionary. + + For key :math:`\kappa`, it computes + + .. math:: + \mathbf{x}^{\prime}_{\kappa} = \mathbf{x}_{\kappa} + \mathbf{W}^{\top}_{\kappa} + \mathbf{b}_{\kappa}. + + It supports lazy initialization and customizable weight and bias + initialization. + + Args: + in_channels (int or Dict[Any, int]): Size of each input sample. If + passed an integer, :obj:`types` will be a mandatory argument. + initialized lazily in case it is given as :obj:`-1`. + out_channels (int): Size of each output sample. + types (List[Any], optional): The keys of the input dictionary. + (default: :obj:`None`) + **kwargs (optional): Additional arguments of + :class:`sharker.nn.Dense`. + """ + + def __init__( + self, + in_channels: Union[int, Dict[Any, int]], + out_channels: int, + types: Optional[Any] = None, + **kwargs, + ): + super().__init__() + + if isinstance(in_channels, dict): + self.types = list(in_channels.keys()) + + # if any([i == -1 for i in in_channels.values()]): + # self._hook = self.register_forward_pre_hook(self.initialize_parameters) + + if types is not None and set(self.types) != set(types): + raise ValueError( + "The provided 'types' do not match with the " + "keys in the 'in_channels' dictionary" + ) + + else: + if types is None: + raise ValueError( + "Please provide a list of 'types' if passing " + "'in_channels' as an integer" + ) + + # if in_channels == -1: + # self._hook = self.register_forward_pre_hook(self.initialize_parameters) + + self.types = types + in_channels = {node_type: in_channels for node_type in types} + + self.in_channels = in_channels + self.out_channels = out_channels + self.kwargs = kwargs + + self.lins = nn.CellDict({ + key: nn.Dense(channels, self.out_channels, **kwargs) + for key, channels in self.in_channels.items() + }) + + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + # for lin in self.lins.values(): + # lin.reset_parameters() + + def construct( + self, + x_dict: Dict[str, Tensor], + ) -> Dict[str, Tensor]: + r"""Forward pass. + + Args: + x_dict (Dict[Any, Tensor]): A dictionary holding input + features for each individual type. + """ + out_dict = {} + + for key, lin in self.lins.items(): + if key in x_dict: + out_dict[key] = lin(x_dict[key]) + + return out_dict + + def initialize_parameters(self, module, input): + for key, x in input[0].items(): + lin = self.lins[key] + if is_uninitialized_parameter(lin.weight): + self.lins[key].initialize_parameters(None, x) + self.reset_parameters() + # self._hook.remove() + self.in_channels = {key: x.shape[-1] for key, x in input[0].items()} + delattr(self, "_hook") + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.in_channels}, " + f'{self.out_channels}, bias={self.kwargs.get("bias", True)})' + ) diff --git a/mindscience/sharker/nn/encoding.py b/mindscience/sharker/nn/encoding.py new file mode 100644 index 000000000..c23db7223 --- /dev/null +++ b/mindscience/sharker/nn/encoding.py @@ -0,0 +1,97 @@ +import math +from mindspore import Parameter, Tensor, nn, ops + + +class PositionalEncoding(nn.Cell): + r"""The positional encoding scheme from the `"Attention Is All You Need" + `_ paper. + + .. math:: + + PE(x)_{2 \cdot i} &= \sin(x / 10000^{2 \cdot i / d}) + + PE(x)_{2 \cdot i + 1} &= \cos(x / 10000^{2 \cdot i / d}) + + where :math:`x` is the position and :math:`i` is the dimension. + + Args: + out_channels (int): Size :math:`d` of each output sample. + base_freq (float, optional): The base frequency of sinusoidal + functions. (default: :obj:`1e-4`) + granularity (float, optional): The granularity of the positions. If + set to smaller value, the encoder will capture more fine-grained + changes in positions. (default: :obj:`1.0`) + """ + + def __init__( + self, + out_channels: int, + base_freq: float = 1e-4, + granularity: float = 1.0, + ): + super().__init__() + + if out_channels % 2 != 0: + raise ValueError(f"Cannot use sinusoidal positional encoding with " + f"odd 'out_channels' (got {out_channels}).") + + self.out_channels = out_channels + self.base_freq = base_freq + self.granularity = granularity + + frequency = ops.logspace(0, 1, out_channels // 2, base_freq) + self.frequency = Parameter(frequency, requires_grad=False) + + self.reset_parameters() + + def reset_parameters(self): + pass + + def construct(self, x: Tensor) -> Tensor: + """""" # noqa: D419 + x = x / self.granularity if self.granularity != 1.0 else x + out = x.view(-1, 1) * self.frequency.view(1, -1) + return ops.cat([ops.sin(out), ops.cos(out)], axis=-1) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.out_channels})' + + +class TemporalEncoding(nn.Cell): + r"""The time-encoding function from the `"Do We Really Need Complicated + Model Architectures for Temporal Networks?" + `_ paper. + + It first maps each entry to a vector with exponentially decreasing values, + and then uses the cosine function to project all values to range + :math:`[-1, 1]`. + + .. math:: + y_{i} = \cos \left(x \cdot \sqrt{d}^{-(i - 1)/\sqrt{d}} \right) + + where :math:`d` defines the output feature dimension, and + :math:`1 \leq i \leq d`. + + Args: + out_channels (int): Size :math:`d` of each output sample. + """ + + def __init__(self, out_channels: int): + super().__init__() + self.out_channels = out_channels + + sqrt = math.sqrt(out_channels) + weight = 1.0 / sqrt**ops.linspace(0, sqrt, out_channels).view(1, -1) + self.weight = Parameter(weight, requires_grad=False) + + self.reset_parameters() + + def reset_parameters(self): + pass + + def construct(self, x: Tensor) -> Tensor: + """""" # noqa: D419 + return ops.cos(x.view(-1, 1) @ self.weight) + + def __repr__(self) -> str: + return f'{self.__class__.__name__}({self.out_channels})' diff --git a/mindscience/sharker/nn/inits.py b/mindscience/sharker/nn/inits.py new file mode 100644 index 000000000..233db2c32 --- /dev/null +++ b/mindscience/sharker/nn/inits.py @@ -0,0 +1,103 @@ +import math +from typing import Any, Tuple + +from mindspore import Tensor, ops +from mindspore.common.initializer import initializer + + +def uniform(value: Any, size: int = None, bound: Tuple[float, float] = None): + if isinstance(value, Tensor): + if bound is None: + bound = Tensor(1.0 / math.sqrt(size)) + bound = (-bound, bound) + value[:] = ops.uniform(value.shape, bound[0], bound[1]) + else: + for v in value.parameters() if hasattr(value, "parameters") else []: + uniform(v, size, bound) + + +def xavier_uniform(value: Any): + if isinstance(value, Tensor): + init = initializer('xavier_uniform', value.shape, dtype=value.dtype) + value[:] = init.init_data() + else: + for v in value.parameters() if hasattr(value, "parameters") else []: + xavier_uniform(v) + + +def kaiming_uniform(value: Any, a: float, fan_mode: str = 'fan_in'): + fan_in, fan_out = _calculate_fan_in_and_fan_out(value) + fan = fan_in if fan_mode == 'fan_in' else fan_out + if isinstance(value, Tensor): + bound = Tensor(math.sqrt(6 / ((1 + a**2) * fan))) + value[:] = ops.uniform(value.shape, -bound, bound) + else: + for v in value.parameters() if hasattr(value, "parameters") else []: + kaiming_uniform(v, a, fan_mode) + + +def _calculate_fan_in_and_fan_out(tensor): + dimensions = tensor.dim() + if dimensions < 2: + raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions") + + num_input_fmaps = tensor.shape[1] + num_output_fmaps = tensor.shape[0] + receptive_field_size = 1 + if tensor.dim() > 2: + # np.prod is not always available, accumulate the product manually + # we could use functools.reduce but that is not supported by TorchScript + for s in tensor.shape[2:]: + receptive_field_size *= s + fan_in = num_input_fmaps * receptive_field_size + fan_out = num_output_fmaps * receptive_field_size + + return fan_in, fan_out + + +def glorot(value: Any): + if isinstance(value, Tensor): + stdv = Tensor(math.sqrt(6.0 / (value.shape[-2] + value.shape[-1]))) + value[:] = ops.uniform(value.shape, -stdv, stdv) + else: + for v in value.parameters() if hasattr(value, "parameters") else []: + glorot(v) + + +def glorot_orthogonal(tensor, scale): + if tensor is not None: + tensor[:] = initializer("orthogonal", [2, 3, 4], tensor.dtype).init_data() + scale /= (tensor.shape[-2] + tensor.shape[-1]) * tensor.var() + tensor *= scale.sqrt() + + +def constant(value: Any, fill_value: float): + if isinstance(value, Tensor): + value[:] = fill_value + else: + for v in value.parameters() if hasattr(value, "parameters") else []: + constant(v, fill_value) + + +def zeros(value: Any): + constant(value, 0.0) + + +def ones(tensor: Any): + constant(tensor, 1.0) + + +def normal(value: Any, mean: float, std: float): + if isinstance(value, Tensor): + value[:] = ops.normal(value.shape, mean, std) + else: + for v in value.parameters() if hasattr(value, "parameters") else []: + normal(v, mean, std) + + +def reset(value: Any): + if hasattr(value, "reset_parameters"): + value.reset_parameters() + else: + for child in value.children() if hasattr(value, "children") else []: + reset(child) diff --git a/mindscience/sharker/nn/lr_scheduler.py b/mindscience/sharker/nn/lr_scheduler.py new file mode 100644 index 000000000..6ce60b472 --- /dev/null +++ b/mindscience/sharker/nn/lr_scheduler.py @@ -0,0 +1,251 @@ +# See HuggingFace `transformers/optimization.py`. +import functools +import math + +from mindspore.experimental.optim import Optimizer +from mindspore.experimental.optim.lr_scheduler import LambdaLR + + +class ConstantWithWarmupLR(LambdaLR): + r"""Creates a LR scheduler with a constant learning rate preceded by a + warmup period during which the learning rate increases linearly between + :obj:`0` and the initial LR set in the optimizer. + + Args: + optimizer (Optimizer): The optimizer to be scheduled. + num_warmup_steps (int): The number of steps for the warmup phase. + last_epoch (int, optional): The index of the last epoch when resuming + training. (default: :obj:`-1`) + """ + + def __init__( + self, + optimizer: Optimizer, + num_warmup_steps: int, + last_epoch: int = -1, + ): + lr_lambda = functools.partial( + self._lr_lambda, + num_warmup_steps=num_warmup_steps, + ) + super().__init__(optimizer, lr_lambda, last_epoch) + + @staticmethod + def _lr_lambda( + current_step: int, + num_warmup_steps: int, + ) -> float: + if current_step < num_warmup_steps: + return float(current_step) / float(max(1.0, num_warmup_steps)) + return 1.0 + + +class LinearWithWarmupLR(LambdaLR): + r"""Creates a LR scheduler with a learning rate that decreases linearly + from the initial LR set in the optimizer to :obj:`0`, after a warmup period + during which it increases linearly from :obj:`0` to the initial LR set in + the optimizer. + + Args: + optimizer (Optimizer): The optimizer to be scheduled. + num_warmup_steps (int): The number of steps for the warmup phase. + num_training_steps (int): The total number of training steps. + last_epoch (int, optional): The index of the last epoch when resuming + training. (default: :obj:`-1`) + """ + + def __init__( + self, + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + last_epoch: int = -1, + ): + lr_lambda = functools.partial( + self._lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + ) + super().__init__(optimizer, lr_lambda, last_epoch) + + @staticmethod + def _lr_lambda( + current_step: int, + num_warmup_steps: int, + num_training_steps: int, + ) -> float: + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + return max( + 0.0, + float(num_training_steps - current_step) + / float(max(1, num_training_steps - num_warmup_steps)), + ) + + +class CosineWithWarmupLR(LambdaLR): + r"""Creates a LR scheduler with a learning rate that decreases following + the values of the cosine function between the initial LR set in the + optimizer to :obj:`0`, after a warmup period during which it increases + linearly between :obj:`0` and the initial LR set in the optimizer. + + Args: + optimizer (Optimizer): The optimizer to be scheduled. + num_warmup_steps (int): The number of steps for the warmup phase. + num_training_steps (int): The total number of training steps. + num_cycles (float, optional): The number of waves in the cosine + schedule (the default decreases LR from the max value to :obj:`0` + following a half-cosine). (default: :obj:`0.5`) + last_epoch (int, optional): The index of the last epoch when resuming + training. (default: :obj:`-1`) + """ + + def __init__( + self, + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float = 0.5, + last_epoch: int = -1, + ): + lr_lambda = functools.partial( + self._lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + super().__init__(optimizer, lr_lambda, last_epoch) + + @staticmethod + def _lr_lambda( + current_step: int, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: float, + ): + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + return max( + 0.0, + 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)), + ) + + +class CosineWithWarmupRestartsLR(LambdaLR): + r"""Creates a LR scheduler with a learning rate that decreases following + the values of the cosine function between the initial LR set in the + optimizer to :obj:`0`, with several hard restarts, after a warmup period + during which it increases linearly between :obj:`0` and the initial LR set + in the optimizer. + + Args: + optimizer (Optimizer): The optimizer to be scheduled. + num_warmup_steps (int): The number of steps for the warmup phase. + num_training_steps (int): The total number of training steps. + num_cycles (int, optional): The number of hard restarts to use. + (default: :obj:`3`) + last_epoch (int, optional): The index of the last epoch when resuming + training. (default: :obj:`-1`) + """ + + def __init__( + self, + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: int = 3, + last_epoch: int = -1, + ): + lr_lambda = functools.partial( + self._lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + num_cycles=num_cycles, + ) + super().__init__(optimizer, lr_lambda, last_epoch) + + @staticmethod + def _lr_lambda( + current_step: int, + num_warmup_steps: int, + num_training_steps: int, + num_cycles: int, + ) -> float: + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + progress = float(current_step - num_warmup_steps) / float( + max(1, num_training_steps - num_warmup_steps) + ) + if progress >= 1.0: + return 0.0 + return max( + 0.0, + 0.5 * (1.0 + math.cos(math.pi * ((float(num_cycles) * progress) % 1.0))), + ) + + +class PolynomialWithWarmupLR(LambdaLR): + r"""Creates a LR scheduler with a learning rate that decreases as a + polynomial decay from the initial LR set in the optimizer to end LR defined + by `lr_end`, after a warmup period during which it increases linearly from + :obj:`0` to the initial LR set in the optimizer. + + Args: + optimizer (Optimizer): The optimizer to be scheduled. + num_warmup_steps (int): The number of steps for the warmup phase. + num_training_steps (int): The total number of training steps. + lr_end (float, optional): The end learning rate. (default: :obj:`1e-7`) + power (float, optional): The power factor of the polynomial decay. + (default: :obj:`1.0`) + last_epoch (int, optional): The index of the last epoch when resuming + training. (default: :obj:`-1`) + """ + + def __init__( + self, + optimizer: Optimizer, + num_warmup_steps: int, + num_training_steps: int, + lr_end: float = 1e-7, + power: float = 1.0, + last_epoch: int = -1, + ): + lr_init = optimizer.defaults["lr"] + if not (lr_init > lr_end): + raise ValueError( + f"`lr_end` ({lr_end}) must be smaller than the " + f"initial lr ({lr_init})" + ) + + lr_lambda = functools.partial( + self._lr_lambda, + num_warmup_steps=num_warmup_steps, + num_training_steps=num_training_steps, + lr_init=lr_init, + lr_end=lr_end, + power=power, + ) + super().__init__(optimizer, lr_lambda, last_epoch) + + @staticmethod + def _lr_lambda( + current_step: int, + num_warmup_steps: int, + num_training_steps: int, + lr_init: float, + lr_end: float, + power: float, + ) -> float: + if current_step < num_warmup_steps: + return float(current_step) / float(max(1, num_warmup_steps)) + elif current_step > num_training_steps: + return lr_end / lr_init # As `LambdaLR` multiplies by `lr_init`. + else: + lr_range = lr_init - lr_end + decay_steps = num_training_steps - num_warmup_steps + pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps + decay = lr_range * pct_remaining**power + lr_end + return decay / lr_init # As `LambdaLR` multiplies by `lr_init`. diff --git a/mindscience/sharker/nn/models/__init__.py b/mindscience/sharker/nn/models/__init__.py new file mode 100644 index 000000000..45b4f2af8 --- /dev/null +++ b/mindscience/sharker/nn/models/__init__.py @@ -0,0 +1,3 @@ +from .mlp import * + +__all__ = ["MLP"] \ No newline at end of file diff --git a/mindscience/sharker/nn/models/mlp.py b/mindscience/sharker/nn/models/mlp.py new file mode 100644 index 000000000..744f138cc --- /dev/null +++ b/mindscience/sharker/nn/models/mlp.py @@ -0,0 +1,251 @@ +import inspect +import warnings +from typing import Any, Callable, Dict, Final, List, Optional, Union + +from mindspore import Tensor, ops, nn + +from ..resolver import activation_resolver, normalization_resolver + + +class MLP(nn.Cell): + r"""A Multi-Layer Perception (MLP) model. + + There exists two ways to instantiate an :class:`MLP`: + + 1. By specifying explicit channel sizes, *e.g.*, + + .. code-block:: python + + mlp = MLP([16, 32, 64, 128]) + + creates a three-layer MLP with **differently** sized hidden layers. + + 1. By specifying fixed hidden channel sizes over a number of layers, + *e.g.*, + + .. code-block:: python + + mlp = MLP(in_channels=16, hidden_channels=32, + out_channels=128, num_layers=3) + + creates a three-layer MLP with **equally** sized hidden layers. + + Args: + channel_list (List[int] or int, optional): List of input, intermediate + and output channels such that :obj:`len(channel_list) - 1` denotes + the number of layers of the MLP (default: :obj:`None`) + in_channels (int, optional): Size of each input sample. + Will override :attr:`channel_list`. (default: :obj:`None`) + hidden_channels (int, optional): Size of each hidden sample. + Will override :attr:`channel_list`. (default: :obj:`None`) + out_channels (int, optional): Size of each output sample. + Will override :attr:`channel_list`. (default: :obj:`None`) + num_layers (int, optional): The number of layers. + Will override :attr:`channel_list`. (default: :obj:`None`) + dropout (float or List[float], optional): Dropout probability of each + hidden embedding. If a list is provided, sets the dropout value per + layer. (default: :obj:`0.`) + act (str or Callable, optional): The non-linear activation function to + use. (default: :obj:`"relu"`) + act_first (bool, optional): If set to :obj:`True`, activation is + applied before normalization. (default: :obj:`False`) + act_kwargs (Dict[str, Any], optional): Arguments passed to the + respective activation function defined by :obj:`act`. + (default: :obj:`None`) + norm (str or Callable, optional): The normalization function to + use. (default: :obj:`"batch_norm"`) + norm_kwargs (Dict[str, Any], optional): Arguments passed to the + respective normalization function defined by :obj:`norm`. + (default: :obj:`None`) + plain_last (bool, optional): If set to :obj:`False`, will apply + non-linearity, batch normalization and dropout to the last layer as + well. (default: :obj:`True`) + bias (bool or List[bool], optional): If set to :obj:`False`, the module + will not learn additive biases. If a list is provided, sets the + bias per layer. (default: :obj:`True`) + **kwargs (optional): Additional deprecated arguments of the MLP layer. + """ + + supports_norm_batch: Final[bool] + + def __init__( + self, + channel_list: Optional[Union[List[int], int]] = None, + *, + in_channels: Optional[int] = None, + hidden_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: Optional[int] = None, + dropout: Union[float, List[float]] = 0.0, + act: Union[str, Callable, None] = "relu", + act_first: bool = False, + act_kwargs: Optional[Dict[str, Any]] = None, + norm: Union[str, Callable, None] = "batch_norm", + norm_kwargs: Optional[Dict[str, Any]] = None, + plain_last: bool = True, + bias: Union[bool, List[bool]] = True, + **kwargs, + ): + super().__init__() + + # Backward compatibility: + act_first = act_first or kwargs.get("relu_first", False) + batch_norm = kwargs.get("batch_norm", None) + if batch_norm is not None and isinstance(batch_norm, bool): + warnings.warn( + "Argument `batch_norm` is deprecated, " + "please use `norm` to specify normalization layer." + ) + norm = "batch_norm" if batch_norm else None + batch_norm_kwargs = kwargs.get("batch_norm_kwargs", None) + norm_kwargs = batch_norm_kwargs or {} + + if isinstance(channel_list, int): + in_channels = channel_list + + if in_channels is not None: + if num_layers is None: + raise ValueError("Argument `num_layers` must be given") + if num_layers > 1 and hidden_channels is None: + raise ValueError( + f"Argument `hidden_channels` must be given " + f"for `num_layers={num_layers}`" + ) + if out_channels is None: + raise ValueError("Argument `out_channels` must be given") + + channel_list = [hidden_channels] * (num_layers - 1) + channel_list = [in_channels] + channel_list + [out_channels] + + assert isinstance(channel_list, (tuple, list)) + assert len(channel_list) >= 2 + self.channel_list = channel_list + + self.act = activation_resolver(act, **(act_kwargs or {})) + self.act_first = act_first + self.plain_last = plain_last + + if isinstance(dropout, float): + dropout = [dropout] * (len(channel_list) - 1) + if plain_last: + dropout[-1] = 0.0 + if len(dropout) != len(channel_list) - 1: + raise ValueError( + f"Number of dropout values provided ({len(dropout)} does not " + f"match the number of layers specified " + f"({len(channel_list)-1})" + ) + self.dropout = dropout + + if isinstance(bias, bool): + bias = [bias] * (len(channel_list) - 1) + if len(bias) != len(channel_list) - 1: + raise ValueError( + f"Number of bias values provided ({len(bias)}) does not match " + f"the number of layers specified ({len(channel_list)-1})" + ) + + self.lins = nn.CellList() + iterator = zip(channel_list[:-1], channel_list[1:], bias) + for in_channels, out_channels, _bias in iterator: + self.lins.append(nn.Dense(in_channels, out_channels, has_bias=_bias)) + + self.norms = nn.CellList() + iterator = channel_list[1:-1] if plain_last else channel_list[1:] + for hidden_channels in iterator: + if norm is not None: + norm_layer = normalization_resolver( + norm, + hidden_channels, + **(norm_kwargs or {}), + ) + else: + norm_layer = nn.Identity() + self.norms.append(norm_layer) + + self.supports_norm_batch = False + if len(self.norms) > 0 and hasattr(self.norms[0], "forward"): + norm_params = inspect.signature(self.norms[0].forward).parameters + self.supports_norm_batch = "batch" in norm_params + + # self.reset_parameters() + + @property + def in_channels(self) -> int: + r"""Size of each input sample.""" + return self.channel_list[0] + + @property + def out_channels(self) -> int: + r"""Size of each output sample.""" + return self.channel_list[-1] + + @property + def num_layers(self) -> int: + r"""The number of layers.""" + return len(self.channel_list) - 1 + + # def reset_parameters(self): + # r"""Resets all learnable parameters of the module.""" + # for lin in self.lins: + # if hasattr(lin, "reset_parameters"): + # lin.reset_parameters() + # for norm in self.norms: + # if hasattr(norm, "reset_parameters"): + # norm.reset_parameters() + + def construct( + self, + x: Tensor, + batch: Optional[Tensor] = None, + batch_size: Optional[int] = None, + return_emb: Optional[Tensor] = None, + ) -> Tensor: + r"""Forward pass. + + Args: + x (Tensor): The source tensor. + batch (Tensor, optional): The batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns + each element to a specific example. + Only needs to be passed in case the underlying normalization + layers require the :obj:`batch` information. + (default: :obj:`None`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. + Only needs to be passed in case the underlying normalization + layers require the :obj:`batch` information. + (default: :obj:`None`) + return_emb (bool, optional): If set to :obj:`True`, will + additionally return the embeddings before execution of the + final output layer. (default: :obj:`False`) + """ + # `return_emb` is annotated here as `NoneType` to be compatible with + # TorchScript, which does not support different return types based on + # the value of an input argument. + emb: Optional[Tensor] = None + + # If `plain_last=True`, then `len(norms) = len(lins) -1, thus skipping + # the execution of the last layer inside the for-loop. + for i, (lin, norm) in enumerate(zip(self.lins, self.norms)): + x = lin(x) + if self.act is not None and self.act_first: + x = self.act(x) + if self.supports_norm_batch: + x = norm(x, batch, batch_size) + else: + x = norm(x) + if self.act is not None and not self.act_first: + x = self.act(x) + x = ops.dropout(x, p=self.dropout[i], training=self.training) + if isinstance(return_emb, bool) and return_emb is True: + emb = x + + if self.plain_last: + x = self.lins[-1](x) + x = ops.dropout(x, p=self.dropout[-1], training=self.training) + + return (x, emb) if isinstance(return_emb, bool) else x + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({str(self.channel_list)[1:-1]})" diff --git a/mindscience/sharker/nn/norm/__init__.py b/mindscience/sharker/nn/norm/__init__.py new file mode 100644 index 000000000..a19547ad3 --- /dev/null +++ b/mindscience/sharker/nn/norm/__init__.py @@ -0,0 +1,12 @@ +r"""Normalization package.""" + +from .batch_norm import BatchNorm, HeteroBatchNorm +from .msg_norm import MessageNorm + +__all__ = [ + 'BatchNorm', + 'HeteroBatchNorm', + 'MessageNorm', +] + +classes = __all__ diff --git a/mindscience/sharker/nn/norm/batch_norm.py b/mindscience/sharker/nn/norm/batch_norm.py new file mode 100644 index 000000000..fd57606f3 --- /dev/null +++ b/mindscience/sharker/nn/norm/batch_norm.py @@ -0,0 +1,213 @@ +from typing import Optional + +import mindspore as ms +from mindspore import Tensor, Parameter, nn, ops +from ..aggr.fused import FusedAggregation +from ..inits import zeros, ones + + +class BatchNorm(nn.Cell): + r"""Applies batch normalization over a batch of features as described in + the `"Batch Normalization: Accelerating Deep Network Training by + Reducing Internal Covariate Shift" `_ + paper. + + .. math:: + \mathbf{x}^{\prime}_i = \frac{\mathbf{x} - + \textrm{E}[\mathbf{x}]}{\sqrt{\textrm{Var}[\mathbf{x}] + \epsilon}} + \odot \gamma + \beta + + The mean and standard-deviation are calculated per-dimension over all nodes + inside the mini-batch. + + Args: + in_channels (int): Size of each input sample. + eps (float, optional): A value added to the denominator for numerical + stability. (default: :obj:`1e-5`) + momentum (float, optional): The value used for the running mean and + running variance computation. (default: :obj:`0.1`) + affine (bool, optional): If set to :obj:`True`, this module has + learnable affine parameters :math:`\gamma` and :math:`\beta`. + (default: :obj:`True`) + track_running_stats (bool, optional): If set to :obj:`True`, this + module tracks the running mean and variance, and when set to + :obj:`False`, this module does not track such statistics and always + uses batch statistics in both training and eval modes. + (default: :obj:`True`) + allow_single_element (bool, optional): If set to :obj:`True`, batches + with only a single element will work as during in evaluation. + That is the running mean and variance will be used. + Requires :obj:`track_running_stats=True`. (default: :obj:`False`) + """ + + def __init__( + self, + in_channels: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + allow_single_element: bool = False, + ): + super().__init__() + + if allow_single_element and not track_running_stats: + raise ValueError( + "'allow_single_element' requires " + "'track_running_stats' to be set to `True`" + ) + + self.module = nn.BatchNorm1d( + in_channels, eps, momentum, affine, track_running_stats + ) + self.in_channels = in_channels + self.allow_single_element = allow_single_element + + def reset_running_stats(self): + r"""Resets all running statistics of the module.""" + self.module.reset_running_stats() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + self.module.reset_parameters() + + def construct(self, x: Tensor) -> Tensor: + r"""Forward pass. + + Args: + x (Tensor): The source tensor. + """ + if self.allow_single_element and x.shape[0] <= 1: + return ops.batch_norm( + x, + self.module.running_mean, + self.module.running_var, + self.module.weight, + self.module.bias, + False, # bn_training + 0.0, # momentum + self.module.eps, + ) + return self.module(x) + + def __repr__(self): + return f"{self.__class__.__name__}({self.module.num_features})" + + +class HeteroBatchNorm(nn.Cell): + r"""Applies batch normalization over a batch of heterogeneous features as + described in the `"Batch Normalization: Accelerating Deep Network Training + by Reducing Internal Covariate Shift" `_ + paper. + Compared to :class:`BatchNorm`, :class:`HeteroBatchNorm` applies + normalization individually for each node or edge type. + + Args: + in_channels (int): Size of each input sample. + num_types (int): The number of types. + eps (float, optional): A value added to the denominator for numerical + stability. (default: :obj:`1e-5`) + momentum (float, optional): The value used for the running mean and + running variance computation. (default: :obj:`0.1`) + affine (bool, optional): If set to :obj:`True`, this module has + learnable affine parameters :math:`\gamma` and :math:`\beta`. + (default: :obj:`True`) + track_running_stats (bool, optional): If set to :obj:`True`, this + module tracks the running mean and variance, and when set to + :obj:`False`, this module does not track such statistics and always + uses batch statistics in both training and eval modes. + (default: :obj:`True`) + """ + + def __init__( + self, + in_channels: int, + num_types: int, + eps: float = 1e-5, + momentum: Optional[float] = 0.1, + affine: bool = True, + track_running_stats: bool = True, + ): + super().__init__() + + self.in_channels = in_channels + self.num_types = num_types + self.eps = eps + self.momentum = momentum + self.affine = affine + self.track_running_stats = track_running_stats + + if self.affine: + self.weight = Parameter(ms.numpy.empty(num_types, in_channels)) + self.bias = Parameter(ms.numpy.empty(num_types, in_channels)) + else: + self.weight = None + self.bias = None + + if self.track_running_stats: + self.running_mean = ms.Parameter(ms.numpy.empty(num_types, in_channels), requires_grad=False) + self.running_var = ms.Parameter(ms.numpy.empty(num_types, in_channels), requires_grad=False) + self.num_batches_tracked = ms.Parameter(Tensor(0), requires_grad=False) + else: + self.running_mean = None + self.running_var = None + self.num_batches_tracked = None + + self.mean_var = FusedAggregation(["mean", "var"]) + + self.reset_parameters() + + def reset_running_stats(self): + r"""Resets all running statistics of the module.""" + if self.track_running_stats: + self.running_mean[:] = 0 + self.running_var[:] = 1 + self.num_batches_tracked[:] = 0 + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + self.reset_running_stats() + if self.affine: + ones(self.weight) + zeros(self.bias) + + def construct(self, x: Tensor, type_vec: Tensor) -> Tensor: + r"""Forward pass. + + Args: + x (Tensor): The input features. + type_vec (Tensor): A vector that maps each entry to a type. + """ + if not self.training and self.track_running_stats: + mean, var = self.running_mean, self.running_var + else: + mean, var = self.mean_var(x, type_vec, dim_size=self.num_types) + + if self.training and self.track_running_stats: + if self.momentum is None: + self.num_batches_tracked.add_(1) + exp_avg_factor = 1.0 / float(self.num_batches_tracked) + else: + exp_avg_factor = self.momentum + + type_index = ops.unique(type_vec)[0] + + self.running_mean[type_index] = (1.0 - exp_avg_factor) * self.running_mean[ + type_index + ] + exp_avg_factor * mean[type_index] + self.running_var[type_index] = (1.0 - exp_avg_factor) * self.running_var[ + type_index + ] + exp_avg_factor * var[type_index] + + out = (x - mean[type_vec]) / var.clamp(self.eps).sqrt()[type_vec] + + if self.affine: + out = out * self.weight[type_vec] + self.bias[type_vec] + + return out + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}({self.in_channels}, " + f"num_types={self.num_types})" + ) diff --git a/mindscience/sharker/nn/norm/msg_norm.py b/mindscience/sharker/nn/norm/msg_norm.py new file mode 100644 index 000000000..067295007 --- /dev/null +++ b/mindscience/sharker/nn/norm/msg_norm.py @@ -0,0 +1,48 @@ +import mindspore as ms +from mindspore import Tensor, nn +from mindspore import Parameter + + +class MessageNorm(nn.Cell): + r"""Applies message normalization over the aggregated messages as described + in the `"DeeperGCNs: All You Need to Train Deeper GCNs" + `_ paper. + + .. math:: + + \mathbf{x}_i^{\prime} = \mathrm{MLP} \left( \mathbf{x}_{i} + s \cdot + {\| \mathbf{x}_i \|}_2 \cdot + \frac{\mathbf{m}_{i}}{{\|\mathbf{m}_i\|}_2} \right) + + Args: + learn_scale (bool, optional): If set to :obj:`True`, will learn the + scaling factor :math:`s` of message normalization. + (default: :obj:`False`) + """ + + def __init__(self, learn_scale: bool = False): + super().__init__() + self.scale = Parameter(ms.numpy.empty(1)) + self.reset_parameters() + + def reset_parameters(self): + r"""Resets all learnable parameters of the module.""" + self.scale[:] = 1.0 + + def construct(self, x: Tensor, msg: Tensor, p: float = 2.0) -> Tensor: + r"""Forward pass. + + Args: + x (Tensor): The source tensor. + msg (Tensor): The message tensor :math:`\mathbf{M}`. + p (float, optional): The norm :math:`p` to use for normalization. + (default: :obj:`2.0`) + """ + msg /= ms.numpy.norm(msg, ord=p, axis=-1, keepdims=True) + msg[msg.isnan()] = 0 + x_norm = ms.numpy.norm(x, ord=p, axis=-1, keepdims=True) + return msg * x_norm * self.scale + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}' + f'(learn_scale={self.scale.requires_grad})') diff --git a/mindscience/sharker/nn/reshape.py b/mindscience/sharker/nn/reshape.py new file mode 100644 index 000000000..59379fad6 --- /dev/null +++ b/mindscience/sharker/nn/reshape.py @@ -0,0 +1,16 @@ +from mindspore import Tensor, nn + + +class Reshape(nn.Cell): + def __init__(self, *shape): + super().__init__() + self.shape = shape + + def construct(self, x: Tensor) -> Tensor: + """""" # noqa: D419 + x = x.view(*self.shape) + return x + + def __repr__(self) -> str: + shape = ', '.join([str(dim) for dim in self.shape]) + return f'{self.__class__.__name__}({shape})' diff --git a/mindscience/sharker/nn/resolver.py b/mindscience/sharker/nn/resolver.py new file mode 100644 index 000000000..0932b9ef8 --- /dev/null +++ b/mindscience/sharker/nn/resolver.py @@ -0,0 +1,174 @@ +import inspect +from typing import Any, Optional, Union + +from mindspore import Tensor, nn +from mindspore.experimental.optim import Optimizer, lr_scheduler +from mindspore.experimental.optim.lr_scheduler import ReduceLROnPlateau +from mindspore.experimental.optim.lr_scheduler import LRScheduler + +from .lr_scheduler import ( + ConstantWithWarmupLR, + CosineWithWarmupLR, + CosineWithWarmupRestartsLR, + LinearWithWarmupLR, + PolynomialWithWarmupLR, +) +from ..resolver import normalize_string, resolver + + +def swish(x: Tensor) -> Tensor: + return x * x.sigmoid() + + +def activation_resolver(query: Union[Any, str] = "relu", *args, **kwargs): + base_cls = nn.Cell + base_cls_repr = "Act" + acts = [ + act + for act in vars(nn.layer.activation).values() + if isinstance(act, type) and issubclass(act, base_cls) + ] + acts += [ + swish, + ] + act_dict = {} + return resolver(acts, act_dict, query, base_cls, base_cls_repr, *args, **kwargs) + + +# Normalization Resolver ###################################################### + + +def normalization_resolver(query: Union[Any, str], *args, **kwargs): + from . import norm + + base_cls = nn.Cell + base_cls_repr = "Norm" + norms = [ + norm + for norm in vars(norm).values() + if isinstance(norm, type) and issubclass(norm, base_cls) + ] + norm_dict = {} + return resolver(norms, norm_dict, query, base_cls, base_cls_repr, *args, **kwargs) + + +# Aggregation Resolver ######################################################## + + +def aggregation_resolver(query: Union[Any, str], *args, **kwargs): + from . import aggr + + if isinstance(query, (list, tuple)): + return aggr.MultiAggregation(query, *args, **kwargs) + + base_cls = aggr.Aggregation + aggrs = [ + aggr + for aggr in vars(aggr).values() + if isinstance(aggr, type) and issubclass(aggr, base_cls) + ] + aggr_dict = { + "add": aggr.SumAggregation, + } + return resolver(aggrs, aggr_dict, query, base_cls, None, *args, **kwargs) + + +# Optimizer Resolver ########################################################## + + +def optimizer_resolver(query: Union[Any, str], *args, **kwargs): + base_cls = Optimizer + optimizers = [ + optimizer + for optimizer in vars(nn.optim).values() + if isinstance(optimizer, type) and issubclass(optimizer, base_cls) + ] + return resolver(optimizers, {}, query, base_cls, None, *args, **kwargs) + + +# Learning Rate Scheduler Resolver ############################################ + + +def lr_scheduler_resolver( + query: Union[Any, str], + optimizer: Optimizer, + warmup_ratio_or_steps: Optional[Union[float, int]] = 0.1, + num_training_steps: Optional[int] = None, + **kwargs, +) -> Union[LRScheduler, ReduceLROnPlateau]: + r"""A resolver to obtain a learning rate scheduler implemented in either + PyG or PyTorch from its name or type. + + Args: + query (Any or str): The query name of the learning rate scheduler. + optimizer (Optimizer): The optimizer to be scheduled. + warmup_ratio_or_steps (float or int, optional): The number of warmup + steps. If given as a `float`, it will act as a ratio that gets + multiplied with the number of training steps to obtain the number + of warmup steps. Only required for warmup-based LR schedulers. + (default: :obj:`0.1`) + num_training_steps (int, optional): The total number of training steps. + (default: :obj:`None`) + **kwargs (optional): Additional arguments of the LR scheduler. + """ + if not isinstance(query, str): + return query + + if isinstance(warmup_ratio_or_steps, float): + if warmup_ratio_or_steps < 0 or warmup_ratio_or_steps > 1: + raise ValueError( + f"`warmup_ratio_or_steps` needs to be between " + f"0.0 and 1.0 when given as a floating point " + f"number (got {warmup_ratio_or_steps})." + ) + if num_training_steps is not None: + warmup_steps = round(warmup_ratio_or_steps * num_training_steps) + elif isinstance(warmup_ratio_or_steps, int): + if warmup_ratio_or_steps < 0: + raise ValueError( + f"`warmup_ratio_or_steps` needs to be positive " + f"when given as an integer " + f"(got {warmup_ratio_or_steps})." + ) + warmup_steps = warmup_ratio_or_steps + else: + raise ValueError( + f"Found invalid type of `warmup_ratio_or_steps` " + f"(got {type(warmup_ratio_or_steps)})" + ) + + base_cls = LRScheduler + classes = [ + scheduler + for scheduler in vars(lr_scheduler).values() + if isinstance(scheduler, type) and issubclass(scheduler, base_cls) + ] + [ReduceLROnPlateau] + + customized_lr_schedulers = [ + ConstantWithWarmupLR, + LinearWithWarmupLR, + CosineWithWarmupLR, + CosineWithWarmupRestartsLR, + PolynomialWithWarmupLR, + ] + classes += customized_lr_schedulers + + query_repr = normalize_string(query) + base_cls_repr = normalize_string("LR") + + for cls in classes: + cls_repr = normalize_string(cls.__name__) + if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, "")]: + if inspect.isclass(cls): + if cls in customized_lr_schedulers: + cls_keys = inspect.signature(cls).parameters.keys() + if "num_warmup_steps" in cls_keys: + kwargs["num_warmup_steps"] = warmup_steps + if "num_training_steps" in cls_keys: + kwargs["num_training_steps"] = num_training_steps + obj = cls(optimizer, **kwargs) + return obj + return cls + + choices = set(cls.__name__ for cls in classes) + raise ValueError(f"Could not resolve '{query}' among choices {choices}") diff --git a/mindscience/sharker/profile/__init__.py b/mindscience/sharker/profile/__init__.py new file mode 100644 index 000000000..b89d11498 --- /dev/null +++ b/mindscience/sharker/profile/__init__.py @@ -0,0 +1,19 @@ +r"""GNN profiling package.""" + +from .benchmark import benchmark +from .utils import ( + count_parameters, + get_cpu_memory_from_gc, + get_data_size, + get_model_size, +) + +__all__ = [ + 'count_parameters', + 'get_model_size', + 'get_data_size', + 'get_cpu_memory_from_gc', + 'benchmark', +] + +classes = __all__ diff --git a/mindscience/sharker/profile/benchmark.py b/mindscience/sharker/profile/benchmark.py new file mode 100644 index 000000000..7a0243d91 --- /dev/null +++ b/mindscience/sharker/profile/benchmark.py @@ -0,0 +1,132 @@ +import time +from typing import Any, Callable, List, Optional, Tuple, Union + +from mindspore import Tensor, Parameter, ops + +from ..utils import is_sparse_tensor + + +def require_grad(x: Any, requires_grad: bool = True) -> Any: + if (isinstance(x, Parameter) and ops.is_floating_point(x) and not is_sparse_tensor(x)): + x.requires_grad = requires_grad + elif isinstance(x, list): + return [require_grad(v, requires_grad) for v in x] + elif isinstance(x, tuple): + return tuple(require_grad(v, requires_grad) for v in x) + elif isinstance(x, dict): + return {k: require_grad(v, requires_grad) for k, v in x.items()} + return x + + +def benchmark( + funcs: List[Callable], + args: Union[Tuple[Any], List[Tuple[Any]]], + num_steps: int, + func_names: Optional[List[str]] = None, + num_warmups: int = 10, + backward: bool = False, + per_step: bool = False, + progress_bar: bool = False, +): + r"""Benchmark a list of functions :obj:`funcs` that receive the same set + of arguments :obj:`args`. + + Args: + funcs ([Callable]): The list of functions to benchmark. + args ((Any, ) or [(Any, )]): The arguments to pass to the functions. + Can be a list of arguments for each function in :obj:`funcs` in + case their headers differ. + Alternatively, you can pass in functions that generate arguments + on-the-fly (e.g., useful for benchmarking models on various sizes). + num_steps (int): The number of steps to run the benchmark. + func_names ([str], optional): The names of the functions. If not given, + will try to infer the name from the function itself. + (default: :obj:`None`) + num_warmups (int, optional): The number of warmup steps. + (default: :obj:`10`) + backward (bool, optional): If set to :obj:`True`, will benchmark both + forward and backward passes. (default: :obj:`False`) + per_step (bool, optional): If set to :obj:`True`, will report runtimes + per step. (default: :obj:`False`) + progress_bar (bool, optional): If set to :obj:`True`, will print a + progress bar during benchmarking. (default: :obj:`False`) + """ + from tabulate import tabulate + + if num_steps <= 0: + raise ValueError(f"'num_steps' must be a positive integer " + f"(got {num_steps})") + + if num_warmups <= 0: + raise ValueError(f"'num_warmups' must be a positive integer " + f"(got {num_warmups})") + + if func_names is None: + func_names = [get_func_name(func) for func in funcs] + + if len(funcs) != len(func_names): + raise ValueError(f"Length of 'funcs' (got {len(funcs)}) and " + f"'func_names' (got {len(func_names)}) must be equal") + + # Zero-copy `args` for each function (if necessary): + args_list = [args] * len(funcs) if not isinstance(args, list) else args + + iterator = zip(funcs, args_list, func_names) + if progress_bar: + from tqdm import tqdm + iterator = tqdm(iterator, total=len(funcs)) + + ts: List[List[str]] = [] + for func, inputs, name in iterator: + t_forward = t_backward = 0 + for i in range(num_warmups + num_steps): + args = inputs() if callable(inputs) else inputs + args = require_grad(args, backward) + + t_start = time.perf_counter() + + out = func(*args) + + if i >= num_warmups: + t_forward += time.perf_counter() - t_start + + if backward: + if isinstance(out, (tuple, list)): + out = sum(o.sum() for o in out if isinstance(o, Tensor)) + elif isinstance(out, dict): + out = out.values() + out = sum(o.sum() for o in out if isinstance(o, Tensor)) + + out_grad = ops.randn_like(out) + t_start = time.perf_counter() + + out.backward(out_grad) + + if i >= num_warmups: + t_backward += time.perf_counter() - t_start + + if per_step: + ts.append([name, f'{t_forward/num_steps:.6f}s']) + else: + ts.append([name, f'{t_forward:.4f}s']) + if backward: + if per_step: + ts[-1].append(f'{t_backward/num_steps:.6f}s') + ts[-1].append(f'{(t_forward + t_backward)/num_steps:.6f}s') + else: + ts[-1].append(f'{t_backward:.4f}s') + ts[-1].append(f'{t_forward + t_backward:.4f}s') + + header = ['Name', 'Forward'] + if backward: + header.extend(['Backward', 'Total']) + + print(tabulate(ts, headers=header, tablefmt='psql')) + + +def get_func_name(func: Callable) -> str: + if hasattr(func, '__name__'): + return func.__name__ + elif hasattr(func, '__class__'): + return func.__class__.__name__ + raise ValueError("Could not infer name for function '{func}'") diff --git a/mindscience/sharker/profile/utils.py b/mindscience/sharker/profile/utils.py new file mode 100644 index 000000000..4458cf5f3 --- /dev/null +++ b/mindscience/sharker/profile/utils.py @@ -0,0 +1,93 @@ +import gc +import os +import os.path as osp +import random +import subprocess as sp +import sys +import warnings +from collections.abc import Mapping, Sequence +from typing import Any, Tuple + +import mindspore as ms +from mindspore import Tensor, nn + +from ..data import Data +# from ..typing import SparseTensor + + +def count_parameters(model: nn.Cell) -> int: + r"""Given a :class:`nn.Cell`, count its trainable parameters. + + Args: + model (mindspore.nn.Model): The model. + """ + return sum([p.numel() for p in model.parameters() if p.requires_grad]) + + +def get_model_size(model: nn.Cell) -> int: + r"""Given a :class:`nn.Cell`, get its actual disk size in bytes. + + Args: + model (mindspore model): The model. + """ + path = f'{random.randrange(sys.maxsize)}.pt' + ms.save_checkpoint(model, path) + model_size = osp.getsize(path) + os.remove(path) + return model_size + + +def get_data_size(data: Data) -> int: + r"""Given a :class:`mindGeometric.data.Data` object, get its theoretical + memory usage in bytes. + + Args: + data (mindGeometric.data.Data or mindGeometric.data.HeteroGraph): + The :class:`~mindGeometric.data.Data` or + :class:`~mindGeometric.data.HeteroGraph` graph object. + """ + data_ptrs = set() + + def _get_size(obj: Any) -> int: + if isinstance(obj, Tensor): + if obj in data_ptrs: + return 0 + data_ptrs.add(obj) + return obj.numel() * obj.element_size() + # elif isinstance(obj, SparseTensor): + # return _get_size(obj.csr()) + elif isinstance(obj, Sequence) and not isinstance(obj, str): + return sum([_get_size(x) for x in obj]) + elif isinstance(obj, Mapping): + return sum([_get_size(x) for x in obj.values()]) + else: + return 0 + + return sum([_get_size(store) for store in data.stores]) + + +def get_cpu_memory_from_gc() -> int: + r"""Returns the used CPU memory in bytes, as reported by the + :python:`Python` garbage collector. + """ + warnings.filterwarnings('ignore', '.*mindspore.distributed.reduce_op.*') + + mem = 0 + for obj in gc.get_objects(): + try: + if isinstance(obj, Tensor) and not obj.is_cuda: + mem += obj.numel() * obj.element_size() + except Exception: + pass + return mem + + +############################################################################### + + +def byte_to_megabyte(value: int, digits: int = 2) -> float: + return round(value / (1024 * 1024), digits) + + +def medibyte_to_megabyte(value: int, digits: int = 2) -> float: + return round(1.0485 * value, digits) diff --git a/mindscience/sharker/resolver.py b/mindscience/sharker/resolver.py new file mode 100644 index 000000000..3476c42c3 --- /dev/null +++ b/mindscience/sharker/resolver.py @@ -0,0 +1,43 @@ +import inspect +from typing import Any, Dict, List, Optional, Union + + +def normalize_string(s: str) -> str: + return s.lower().replace('-', '').replace('_', '').replace(' ', '') + + +def resolver( + classes: List[Any], + class_dict: Dict[str, Any], + query: Union[Any, str], + base_cls: Optional[Any], + base_cls_repr: Optional[str], + *args: Any, + **kwargs: Any, +) -> Any: + + if not isinstance(query, str): + return query + + query_repr = normalize_string(query) + if base_cls_repr is None: + base_cls_repr = base_cls.__name__ if base_cls else '' + base_cls_repr = normalize_string(base_cls_repr) + + for key_repr, cls in class_dict.items(): + if query_repr == key_repr: + if inspect.isclass(cls): + obj = cls(*args, **kwargs) + return obj + return cls + + for cls in classes: + cls_repr = normalize_string(cls.__name__) + if query_repr in [cls_repr, cls_repr.replace(base_cls_repr, '')]: + if inspect.isclass(cls): + obj = cls(*args, **kwargs) + return obj + return cls + + choices = set(cls.__name__ for cls in classes) | set(class_dict.keys()) + raise ValueError(f"Could not resolve '{query}' among choices {choices}") diff --git a/mindscience/sharker/seed.py b/mindscience/sharker/seed.py new file mode 100644 index 000000000..69b57f6b8 --- /dev/null +++ b/mindscience/sharker/seed.py @@ -0,0 +1,16 @@ +import random + +import numpy as np +import mindspore as ms + + +def seed_everything(seed: int) -> None: + r"""Sets the seed for generating random numbers in :mindspore:`Mindspore`, + :obj:`numpy` and :python:`Python`. + + Args: + seed (int): The desired seed. + """ + random.seed(seed) + np.random.seed(seed) + ms.set_seed(seed) diff --git a/mindscience/sharker/template.py b/mindscience/sharker/template.py new file mode 100644 index 000000000..2cc578fa3 --- /dev/null +++ b/mindscience/sharker/template.py @@ -0,0 +1,37 @@ +import importlib +import os.path as osp +import sys +import tempfile +from typing import Any + +from jinja2 import Environment, FileSystemLoader + + +def module_from_template( + module_name: str, + template_path: str, + **kwargs: Any, +) -> Any: + + if module_name in sys.modules: # If module is already loaded, return it: + return sys.modules[module_name] + + env = Environment(loader=FileSystemLoader(osp.dirname(template_path))) + template = env.get_template(osp.basename(template_path)) + module_repr = template.render(**kwargs) + + with tempfile.NamedTemporaryFile( + mode='w', + prefix=f'{module_name}_', + suffix='.py', + delete=False, + ) as tmp: + tmp.write(module_repr) + + spec = importlib.util.spec_from_file_location(module_name, tmp.name) + assert spec is not None + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + assert spec.loader is not None + spec.loader.exec_module(module) + return module diff --git a/mindscience/sharker/testing/__init__.py b/mindscience/sharker/testing/__init__.py new file mode 100644 index 000000000..c7cdca4c4 --- /dev/null +++ b/mindscience/sharker/testing/__init__.py @@ -0,0 +1,62 @@ +r"""Testing package. + +This package provides helper methods and decorators to ease testing. +""" + +from .decorators import ( + is_full_test, + onlyFullTest, + is_distributed_test, + onlyDistributedTest, + onlyLinux, + noWindows, + onlyPython, + # onlyCUDA, + # onlyXPU, + onlyOnline, + # onlyGraphviz, + # onlyNeighborSampler, + has_package, + withPackage, + # withDevice, + # withCUDA, + # withMETIS, + disableExtensions, + withoutExtensions, +) +from .asserts import assert_module +# from .feature_store import MyFeatureStore +# from .graph_store import MyGraphStore +from .data import ( + get_random_edge_index, + # get_random_tensor_frame, + FakeHeteroGraphset, +) + +__all__ = [ + 'is_full_test', + 'onlyFullTest', + 'is_distributed_test', + 'onlyDistributedTest', + 'onlyLinux', + 'noWindows', + 'onlyPython', + # 'onlyCUDA', + # 'onlyXPU', + 'onlyOnline', + # 'onlyGraphviz', + # 'onlyNeighborSampler', + 'has_package', + 'withPackage', + # 'withDevice', + # 'withCUDA', + # 'withMETIS', + 'disableExtensions', + 'withoutExtensions', + 'assert_module', + 'MyFeatureStore', + 'MyGraphStore', + 'get_random_edge_index', + # 'get_random_tensor_frame', + 'FakeHeteroGraphset', +] diff --git a/mindscience/sharker/testing/asserts.py b/mindscience/sharker/testing/asserts.py new file mode 100644 index 000000000..c6dc572b2 --- /dev/null +++ b/mindscience/sharker/testing/asserts.py @@ -0,0 +1,91 @@ +import copy +from typing import Any, Optional, Tuple + +import numpy as np +from mindspore import Tensor, ops, nn + + +# SPARSE_LAYOUTS: List[Union[str, torch.layout]] = [ +# 'torch_sparse', torch.sparse_csc, torch.sparse_coo +# ] + + +def assert_module( + module: nn.Cell, + x: Any, + edge_index: Tensor, + *, + expected_size: Tuple[int, ...], + test_edge_permutation: bool = True, + test_node_permutation: bool = False, + # test_sparse_layouts: Optional[List[Union[str, torch.layout]]] = None, + sparse_size: Optional[Tuple[int, int]] = None, + atol: float = 1e-08, + rtol: float = 1e-05, + equal_nan: bool = False, + **kwargs: Any, +) -> Any: + r"""Asserts that the output of a :obj:`module` is correct. + + Specifically, this method tests that: + + 1. The module output has the correct shape. + 2. The module is invariant to the permutation of edges. + 3. The module is invariant to the permutation of nodes. + 4. The module is invariant to the layout of :obj:`edge_index`. + + Args: + module (nn.Cell): The module to test. + x (Any): The input features to the module. + edge_index (torch.Tensor): The input edge indices. + expected_size (Tuple[int, ...]): The expected output size. + test_edge_permutation (bool, optional): If set to :obj:`False`, will + not test the module for edge permutation invariance. + test_node_permutation (bool, optional): If set to :obj:`False`, will + not test the module for node permutation invariance. + test_sparse_layouts (List[str or int], optional): The sparse layouts to + test for module invariance. (default: :obj:`["torch_sparse", + torch.sparse_csc, torch.sparse_coo]`) + sparse_size (Tuple[int, int], optional): The size of the sparse + adjacency matrix. If not given, will try to automatically infer it. + (default: :obj:`None`) + atol (float, optional): Absolute tolerance. (default: :obj:`1e-08`) + rtol (float, optional): Relative tolerance. (default: :obj:`1e-05`) + equal_nan (bool, optional): If set to :obj:`True`, then two :obj:`NaN`s + will be considered equal. (default: :obj:`False`) + **kwargs (optional): Additional arguments passed to + :meth:`module.forward`. + """ + # if test_sparse_layouts is None: + # test_sparse_layouts = SPARSE_LAYOUTS + + if sparse_size is None: + if 'size' in kwargs: + sparse_size = kwargs['size'] + elif isinstance(x, Tensor): + sparse_size = (x.shape[0], x.shape[0]) + elif (isinstance(x, (tuple, list)) and isinstance(x[0], Tensor) + and isinstance(x[1], Tensor)): + sparse_size = (x[0].shape[0], x[1].shape[0]) + + # if len(test_sparse_layouts) > 0 and sparse_size is None: + # raise ValueError(f"Got sparse layouts {test_sparse_layouts}, but no " + # f"'sparse_size' were specified") + + expected = module(x, edge_index=edge_index, **kwargs) + assert expected.shape == expected_size + + if test_edge_permutation: + perm = ops.shuffle(ops.arange(edge_index.shape[1])) + perm_kwargs = copy.copy(kwargs) + for key, value in kwargs.items(): + if isinstance(value, Tensor) and value.shape[0] == perm.numel(): + perm_kwargs[key] = value[perm] + out = module(x, edge_index[:, perm], **perm_kwargs) + assert np.allclose(out.asnumpy(), expected.asnumpy(), rtol, atol, equal_nan) + + if test_node_permutation: + raise NotImplementedError + + + return expected diff --git a/mindscience/sharker/testing/data.py b/mindscience/sharker/testing/data.py new file mode 100644 index 000000000..f2a4ea28a --- /dev/null +++ b/mindscience/sharker/testing/data.py @@ -0,0 +1,57 @@ +from typing import Callable, Optional + +import mindspore as ms +from mindspore import Tensor, ops + +from ..data import HeteroGraph, InMemoryDataset +from mindscience.sharker.utils import coalesce as coalesce_fn + + +def get_random_edge_index( + num_src_nodes: int, + num_dst_nodes: int, + num_edges: int, + dtype: Optional[ms.Type] = None, + coalesce: bool = False, +) -> Tensor: + row = ops.randint(0, num_src_nodes, (num_edges, ), dtype=dtype) + col = ops.randint(0, num_dst_nodes, (num_edges, ), dtype=dtype) + edge_index = ops.stack([row, col], axis=0) + + if coalesce: + edge_index = coalesce_fn(edge_index) + + return edge_index + + +class FakeHeteroGraphset(InMemoryDataset): + def __init__(self, transform: Optional[Callable] = None): + super().__init__(transform=transform) + + data = HeteroGraph() + + num_papers = 100 + num_authors = 10 + + data['paper'].x = ops.randn(num_papers, 16) + data['author'].x = ops.randn(num_authors, 8) + + edge_index = get_random_edge_index( + num_src_nodes=num_papers, + num_dst_nodes=num_authors, + num_edges=300, + ) + data['paper', 'author'].edge_index = edge_index + data['author', 'paper'].edge_index = edge_index.flip([0]) + + data['paper'].y = ops.randint(0, 4, (num_papers, )) + + perm = ops.shuffle(ops.arange(num_papers)) + data['paper'].train_mask = ops.zeros(num_papers).bool() + data['paper'].train_mask[perm[0:60]] = True + data['paper'].val_mask = ops.zeros(num_papers).bool() + data['paper'].val_mask[perm[60:80]] = True + data['paper'].test_mask = ops.zeros(num_papers).bool() + data['paper'].test_mask[perm[80:100]] = True + + self.data, self.slices = self.collate([data]) diff --git a/mindscience/sharker/testing/decorators.py b/mindscience/sharker/testing/decorators.py new file mode 100644 index 000000000..759fbcec5 --- /dev/null +++ b/mindscience/sharker/testing/decorators.py @@ -0,0 +1,150 @@ +import os +import sys +from importlib import import_module +from importlib.util import find_spec +from typing import Callable +import pytest +from packaging.requirements import Requirement + + +def is_full_test() -> bool: + r"""Whether to run the full but time-consuming test suite.""" + return os.getenv('FULL_TEST', '0') == '1' + + +def onlyFullTest(func: Callable) -> Callable: + r"""A decorator to specify that this function belongs to the full test + suite. + """ + return pytest.mark.skipif( + not is_full_test(), + reason="Fast test run", + )(func) + + +def is_distributed_test() -> bool: + r"""Whether to run the distributed test suite.""" + return ((is_full_test() or os.getenv('DIST_TEST', '0') == '1') + and sys.platform == 'linux' and has_package('pyg_lib')) + + +def onlyDistributedTest(func: Callable) -> Callable: + r"""A decorator to specify that this function belongs to the distributed + test suite. + """ + return pytest.mark.skipif( + not is_distributed_test(), + reason="Fast test run", + )(func) + + +def onlyLinux(func: Callable) -> Callable: + r"""A decorator to specify that this function should only execute on + Linux systems. + """ + return pytest.mark.skipif( + sys.platform != 'linux', + reason="No Linux system", + )(func) + + +def noWindows(func: Callable) -> Callable: + r"""A decorator to specify that this function should not execute on + Windows systems. + """ + return pytest.mark.skipif( + os.name == 'nt', + reason="Windows system", + )(func) + + +def onlyPython(*args: str) -> Callable: + r"""A decorator to run tests on specific :python:`Python` versions only.""" + def decorator(func: Callable) -> Callable: + python_version = f'{sys.version_info.major}.{sys.version_info.minor}' + return pytest.mark.skipif( + python_version not in args, + reason=f"Python {python_version} not supported", + )(func) + + return decorator + + +def onlyOnline(func: Callable) -> Callable: + r"""A decorator to skip tests if there exists no connection to the + internet. + """ + import http.client as httplib + + has_connection = True + connection = httplib.HTTPSConnection('8.8.8.8', timeout=5) + try: + connection.request('HEAD', '/') + except Exception: + has_connection = False + finally: + connection.close() + + return pytest.mark.skipif( + not has_connection, + reason="No internet connection", + )(func) + + +def has_package(package: str) -> bool: + r"""Returns :obj:`True` in case :obj:`package` is installed.""" + if '|' in package: + return any(has_package(p) for p in package.split('|')) + req = Requirement(package) + if find_spec(req.name) is None: + return False + module = import_module(req.name) + if not hasattr(module, '__version__'): + return True + + version = module.__version__ + # `req.specifier` does not support `.dev` suffixes, e.g., for + # `pyg_lib==0.1.0.dev*`, so we manually drop them: + if '.dev' in version: + version = '.'.join(version.split('.dev')[:-1]) + + return version in req.specifier + + +def withPackage(*args: str) -> Callable: + r"""A decorator to skip tests if certain packages are not installed. + Also supports version specification. + """ + na_packages = set(package for package in args if not has_package(package)) + + if len(na_packages) == 1: + reason = f"Package {list(na_packages)[0]} not found" + else: + reason = f"Packages {na_packages} not found" + + def decorator(func: Callable) -> Callable: + + return pytest.mark.skipif(len(na_packages) > 0, reason=reason)(func) + + return decorator + + + +def disableExtensions(func: Callable) -> Callable: + r"""A decorator to temporarily disable the usage of the + :obj:`torch_scatter`, :obj:`torch_sparse` and :obj:`pyg_lib` extension + packages. + """ + return pytest.mark.usefixtures('disable_extensions')(func) + + +def withoutExtensions(func: Callable) -> Callable: + r"""A decorator to test both with and without the usage of extension + packages such as :obj:`torch_scatter`, :obj:`torch_sparse` and + :obj:`pyg_lib`. + """ + return pytest.mark.parametrize( + 'without_extensions', + ['enable_extensions', 'disable_extensions'], + indirect=True, + )(func) diff --git a/mindscience/sharker/testing/distributed.py b/mindscience/sharker/testing/distributed.py new file mode 100644 index 000000000..3fb96dd7d --- /dev/null +++ b/mindscience/sharker/testing/distributed.py @@ -0,0 +1,92 @@ +import sys +import traceback +from dataclasses import dataclass +from io import StringIO +from typing import Any, Callable, List, Tuple + +import pytest +from multiprocessing import Manager, Queue +from typing_extensions import Self + + +@dataclass +class ProcArgs: + target: Callable + args: Tuple[Any, ...] + + +class MPCaptOutput: + def __enter__(self) -> Self: + self.stdout = StringIO() + self.stderr = StringIO() + self.old_stdout = sys.stdout + self.old_stderr = sys.stderr + + sys.stdout = self.stdout + sys.stderr = self.stderr + + return self + + def __exit__(self, *args: Any) -> None: + sys.stdout = self.old_stdout + sys.stderr = self.old_stderr + + @property + def stdout_str(self) -> str: + return self.stdout.getvalue() + + @property + def stderr_str(self) -> str: + return self.stderr.getvalue() + + +def ps_std_capture( + func: Callable, + queue: Queue, + *args: Any, + **kwargs: Any, +) -> None: + with MPCaptOutput() as capt: + try: + func(*args, **kwargs) + except Exception as e: + traceback.print_exc(file=sys.stderr) + raise e + finally: + queue.put((capt.stdout_str, capt.stderr_str)) + + +def assert_run_mproc( + mp_context: Any, + pargs: List[ProcArgs], + full_trace: bool = False, + timeout: int = 5, +) -> None: + manager = Manager() + world_size = len(pargs) + queues = [manager.Queue() for _ in pargs] + procs = [ + mp_context.Process( + target=ps_std_capture, + args=[p.target, q, world_size] + list(p.args), + ) for p, q in zip(pargs, queues) + ] + results = [] + + for p, q in zip(procs, queues): + p.start() + + for p, q in zip(procs, queues): + p.join() + stdout, stderr = q.get(timeout=timeout) + results.append((p, stdout, stderr)) + + for p, stdout, stderr in results: + if stdout: + print(stdout) + if stderr: # can be a warning as well => exitcode == 0 + print(stderr) + if p.exitcode != 0: + pytest.fail( + pytrace=full_trace, reason=stderr.splitlines()[-1] + if stderr else f"exitcode {p.exitcode}") diff --git a/mindscience/sharker/typing.py b/mindscience/sharker/typing.py new file mode 100644 index 000000000..b13df62b9 --- /dev/null +++ b/mindscience/sharker/typing.py @@ -0,0 +1,194 @@ +import os +import mindspore as ms +from typing import Any, Dict, List, Optional, Tuple, Union + +import numpy as np +from mindspore import Tensor + +WITH_WINDOWS = os.name == 'nt' + +MAX_INT64 = np.iinfo(np.int64).max + + +# Types for accessing data #################################################### + +# Node-types are denoted by a single string, e.g.: `data['paper']`: +NodeType = str + +# Edge-types are denotes by a triplet of strings, e.g.: +# `data[('author', 'writes', 'paper')] +EdgeType = Tuple[str, str, str] + +NodeOrEdgeType = Union[NodeType, EdgeType] + +DEFAULT_REL = 'to' +EDGE_TYPE_STR_SPLIT = '__' + + +WITH_SPARSE = False +WITH_SOFTMAX = False + + +class SparseStorage: # type: ignore + def __init__( + self, + row: Optional[Tensor] = None, + rowptr: Optional[Tensor] = None, + col: Optional[Tensor] = None, + value: Optional[Tensor] = None, + sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, + rowcount: Optional[Tensor] = None, + colptr: Optional[Tensor] = None, + colcount: Optional[Tensor] = None, + csr2csc: Optional[Tensor] = None, + csc2csr: Optional[Tensor] = None, + is_sorted: bool = False, + trust_data: bool = False, + ): + raise ImportError("'SparseStorage' requires 'torch-sparse'") + + def value(self) -> Optional[Tensor]: + raise ImportError("'SparseStorage' requires 'torch-sparse'") + + def rowcount(self) -> Tensor: + raise ImportError("'SparseStorage' requires 'torch-sparse'") + + +class SparseTensor: # type: ignore + def __init__( + self, + row: Optional[Tensor] = None, + rowptr: Optional[Tensor] = None, + col: Optional[Tensor] = None, + value: Optional[Tensor] = None, + sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, + is_sorted: bool = False, + trust_data: bool = False, + ): + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + @classmethod + def from_edge_index( + self, + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + sparse_sizes: Optional[Tuple[Optional[int], Optional[int]]] = None, + is_sorted: bool = False, + trust_data: bool = False, + ) -> 'SparseTensor': + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + @property + def storage(self) -> SparseStorage: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + @classmethod + def from_dense(self, mat: Tensor, + has_value: bool = True) -> 'SparseTensor': + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def size(self, dim: int) -> int: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def nnz(self) -> int: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def is_cuda(self) -> bool: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def has_value(self) -> bool: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def set_value(self, value: Optional[Tensor], + layout: Optional[str] = None) -> 'SparseTensor': + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def fill_value(self, fill_value: float, + dtype: Optional[ms.Type] = None) -> 'SparseTensor': + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def coo(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def csr(self) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def requires_grad(self) -> bool: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + def to_torch_sparse_csr_tensor( + self, + dtype: Optional[ms.Type] = None, + ) -> Tensor: + raise ImportError("'SparseTensor' requires 'torch-sparse'") + + +class EdgeTypeStr(str): + r"""A helper class to construct serializable edge types by merging an edge + type tuple into a single string. + """ + def __new__(cls, *args: Any) -> 'EdgeTypeStr': + if isinstance(args[0], (list, tuple)): + # Unwrap `EdgeType((src, rel, dst))` and `EdgeTypeStr((src, dst))`: + args = tuple(args[0]) + + if len(args) == 1 and isinstance(args[0], str): + arg = args[0] # An edge type string was passed. + + elif len(args) == 2 and all(isinstance(arg, str) for arg in args): + # A `(src, dst)` edge type was passed - add `DEFAULT_REL`: + arg = EDGE_TYPE_STR_SPLIT.join((args[0], DEFAULT_REL, args[1])) + + elif len(args) == 3 and all(isinstance(arg, str) for arg in args): + # A `(src, rel, dst)` edge type was passed: + arg = EDGE_TYPE_STR_SPLIT.join(args) + + else: + raise ValueError(f"Encountered invalid edge type '{args}'") + + return str.__new__(cls, arg) + + def to_tuple(self) -> EdgeType: + r"""Returns the original edge type.""" + out = tuple(self.split(EDGE_TYPE_STR_SPLIT)) + if len(out) != 3: + raise ValueError(f"Cannot convert the edge type '{self}' to a " + f"tuple since it holds invalid characters") + return out + + +# There exist some short-cuts to query edge-types (given that the full triplet +# can be uniquely reconstructed, e.g.: +# * via str: `data['writes']` +# * via Tuple[str, str]: `data[('author', 'paper')]` +QueryType = Union[NodeType, EdgeType, str, Tuple[str, str]] + +Metadata = Tuple[List[NodeType], List[EdgeType]] + +# A representation of a feature tensor +FeatureTensorType = Union[Tensor, np.ndarray] + +# A representation of an edge index, following the possible formats: +# * COO: (row, col) +# * CSC: (row, colptr) +# * CSR: (rowptr, col) +EdgeTensorType = Tuple[Tensor, Tensor] + +# Types for message passing ################################################### + +Adj = Union[Tensor, ] +OptTensor = Optional[Tensor] +PairTensor = Tuple[Tensor, Tensor] +OptPairTensor = Tuple[Tensor, Optional[Tensor]] +PairOptTensor = Tuple[Optional[Tensor], Optional[Tensor]] +Size = Optional[Tuple[int, int]] +NoneType = Optional[Tensor] + +MaybeHeteroNodeTensor = Union[Tensor, Dict[NodeType, Tensor]] +MaybeHeteroAdjTensor = Union[Tensor, Dict[EdgeType, Adj]] +MaybeHeteroEdgeTensor = Union[Tensor, Dict[EdgeType, Tensor]] + +# Types for sampling ########################################################## + +InputNodes = Union[OptTensor, NodeType, Tuple[NodeType, OptTensor]] +InputEdges = Union[OptTensor, EdgeType, Tuple[EdgeType, OptTensor]] diff --git a/mindscience/sharker/utils/__init__.py b/mindscience/sharker/utils/__init__.py new file mode 100644 index 000000000..07ec141a3 --- /dev/null +++ b/mindscience/sharker/utils/__init__.py @@ -0,0 +1,144 @@ +r'''Utility package.''' + +import copy +from .mixin import CastMixin +from ._scatter import scatter, group_argsort, group_cat, scatter_concat +from ._segment import segment +from .functions import cumsum, swapaxes, index_fill, index_select, broadcast_to +from .degree import degree +from .softmax import softmax +from .sort_edge_index import sort_edge_index +from .coalesce import coalesce +from .repeat import repeat +from .undirected import is_undirected, to_undirected +from .loop import ( + contains_self_loops, + remove_self_loops, + segregate_self_loops, + add_self_loops, + add_remaining_self_loops, + get_self_loop_attr, +) +from .isolated import contains_isolated_nodes, remove_isolated_nodes +from .subgraph import get_num_hops, subgraph, k_hop_subgraph, bipartite_subgraph, hyper_subgraph +from .dropout import dropout_node, dropout_edge +from .homophily import homophily +from .assortativity import assortativity +from .laplacian import get_laplacian, get_mesh_laplacian +from .mask import mask_select, index_to_mask, mask_to_index +from .select import select, narrow +from .to_dense_batch import to_dense_batch +from .to_dense_adj import to_dense_adj +from .sparse import ( + is_sparse_tensor, + ptr2index, + index2ptr, +) +from .num_nodes import maybe_num_nodes +from .unbatch import unbatch, unbatch_edge_index +from .normalize import normalized_cut +from .grid import grid +from .convert import to_scipy_sparse_matrix, from_scipy_sparse_matrix +from .convert import to_networkx, from_networkx +from .convert import to_trimesh, from_trimesh +from .convert import to_tensor, to_array +from .random import ( + erdos_renyi_graph, + barabasi_albert_graph, +) +from .negative_sampling import ( + negative_sampling, + batched_negative_sampling, + structured_negative_sampling, + structured_negative_sampling_feasible, +) +from .augmentation import shuffle_node, mask_feature, add_random_edge +from .tree_decomposition import tree_decomposition +from .embedding import get_embeddings +from .trim_to_layer import trim_to_layer, TrimToLayer +from .cluster import radius_graph +from .ncon import Ncon + + +__all__ = [ + 'segment', + 'scatter', + 'group_argsort', + 'group_cat', + 'scatter_concat', + 'cumsum', + 'swapaxes', + 'index_fill', + 'index_select', + 'broadcast_to', + 'degree', + 'softmax', + 'sort_edge_index', + 'coalesce', + 'is_undirected', + 'to_undirected', + 'contains_self_loops', + 'remove_self_loops', + 'segregate_self_loops', + 'add_self_loops', + 'add_remaining_self_loops', + 'get_self_loop_attr', + 'contains_isolated_nodes', + 'remove_isolated_nodes', + 'get_num_hops', + 'subgraph', + 'bipartite_subgraph', + 'k_hop_subgraph', + 'hyper_subgraph', + 'dropout_node', + 'dropout_edge', + 'CastMixin', + 'homophily', + 'assortativity', + 'get_laplacian', + 'get_mesh_laplacian', + 'mask_select', + 'index_to_mask', + 'mask_to_index', + 'select', + 'narrow', + 'to_dense_batch', + 'to_dense_adj', + 'to_tensor', + 'to_array', + 'is_sparse_tensor', + 'index2ptr', + 'ptr2index', + 'maybe_num_nodes', + 'unbatch', + 'unbatch_edge_index', + 'normalized_cut', + 'grid', + 'to_scipy_sparse_matrix', + 'from_scipy_sparse_matrix', + 'to_networkx', + 'from_networkx', + 'to_trimesh', + 'from_trimesh', + 'erdos_renyi_graph', + 'barabasi_albert_graph', + 'negative_sampling', + 'batched_negative_sampling', + 'structured_negative_sampling', + 'structured_negative_sampling_feasible', + 'shuffle_node', + 'mask_feature', + 'add_random_edge', + 'tree_decomposition', + 'get_embeddings', + 'trim_to_layer', + 'TrimToLayer', + 'repeat', + 'radius_graph', + 'Ncon', +] + +# `structured_negative_sampling_feasible` is a long name and thus destroys the +# documentation rendering. We remove it for now from the documentation: +classes = copy.copy(__all__) +classes.remove('structured_negative_sampling_feasible') diff --git a/mindscience/sharker/utils/_scatter.py b/mindscience/sharker/utils/_scatter.py new file mode 100644 index 000000000..c0eef11e8 --- /dev/null +++ b/mindscience/sharker/utils/_scatter.py @@ -0,0 +1,339 @@ +import mindspore as ms +from typing import List, Optional, Tuple, Union +from mindspore import Tensor, ops, mint +from .functions import cumsum, index_select, index_fill, broadcast_to + + +_scatter_max = ops.MultitypeFuncGraph('_scatter_max') +_scatter_amax = ops.MultitypeFuncGraph('_scatter_amax') +_scatter_min = ops.MultitypeFuncGraph('_scatter_min') +_scatter_amin = ops.MultitypeFuncGraph('_scatter_amin') + +_scatter_sum = ops.MultitypeFuncGraph('_scatter_sum') +_scatter_mean = ops.MultitypeFuncGraph('_scatter_mean') +_scatter_mul = ops.MultitypeFuncGraph('_scatter_mul') + + +@_scatter_max.register('Number', 'Tensor', 'Tensor', 'Number') +def _scatter_max(dim: int, index: Tensor, src: Tensor, scope: int): + mask = index == scope + out = index_select(src, mask, dim) + if out.shape[dim] in [0, 1]: + return out + return out.max(dim, keepdims=True) + + +@_scatter_amax.register('Number', 'Tensor', 'Tensor', 'Number') +def _scatter_amax(dim: int, index: Tensor, src: Tensor, scope: int): + mask = index != scope + out = index_fill(src.copy(), mask, src.min()-1, dim) + shape = list(src.shape) + if shape[dim] == 0: + shape[dim] = 1 + return ops.zeros(shape).long() + return out.argmax(dim, keepdims=True) + + +@_scatter_min.register('Number', 'Tensor', 'Tensor', 'Number') +def _scatter_min(dim: int, index: Tensor, src: Tensor, scope: int): + mask = index == scope + out = index_select(src, mask, dim) + if out.shape[dim] in [0, 1]: + return out + return out.min(dim, keepdims=True) + + +@_scatter_amin.register('Number', 'Tensor', 'Tensor', 'Number') +def _scatter_amin(dim: int, index: Tensor, src: Tensor, scope: int): + mask = index != scope + out = index_fill(src.copy(), mask, src.max()+1, dim) + shape = list(src.shape) + if shape[dim] == 0: + shape[dim] = 1 + return mint.zeros(shape).int() + return out.argmin(dim, keepdims=True) + + +@_scatter_sum.register('Number', 'Tensor', 'Tensor', 'Number') +def _scatter_sum(dim: int, index: Tensor, src: Tensor, scope: int): + mask = index == scope + out = index_select(src, mask, dim) + if out.shape[dim] in [0, 1]: + return out + return out.sum(dim, keepdims=True) + + +@_scatter_mean.register('Number', 'Tensor', 'Tensor', 'Number') +def _scatter_mean(dim: int, index: Tensor, src: Tensor, scope: int): + mask = index == scope + out = index_select(src, mask, dim) + if out.shape[dim] in [0, 1]: + return out + return out.mean(dim, keep_dims=True) + + +@_scatter_mul.register('Number', 'Tensor', 'Tensor', 'Number') +def _scatter_mul(dim: int, index: Tensor, src: Tensor, scope: int): + mask = index == scope + out = index_select(src, mask, dim) + if out.shape[dim] in [0, 1]: + return out + return out.prod(dim, keepdim=True) + + +def scatter( + src: Tensor, + index: Tensor, + dim: int = -1, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> Tensor: + if dim > src.ndim-1: + raise ValueError(f"`dim` must lay between 0 and {src.ndim-1}") + if dim_size is None: + dim_size = int(ops.amax(index)) + 1 + shape = list(src.shape) + shape[dim] = dim_size + out = mint.zeros(shape, dtype=src.dtype) + scope = ops.unique(index)[0].astype(ms.int64) + common_map = ops.Map() + + dim = src.ndim + dim if dim < 0 else dim + + size = src.shape[:dim] + (dim_size, ) + src.shape[dim + 1:] + + if reduce == 'any': + index = broadcast(index, src, dim) + return mint.scatter(mint.zeros(size, dtype=src.dtype), dim, ops.deepcopy(index), src) + + if reduce == 'sum' or reduce =='add': + index = broadcast(index, src, dim) + return mint.scatter_add(mint.zeros(size, dtype=src.dtype), dim, ops.deepcopy(index), src) + + if reduce == 'mean': + count = mint.zeros(dim_size) + count = mint.scatter_add(count, 0, index, mint.ones(src.shape[dim], dtype=src.dtype)) + count = count.clamp(min=1) + + index = broadcast(index, src, dim) + out = mint.scatter_add(mint.zeros(size, dtype=src.dtype), dim, ops.deepcopy(index), src) + + return out / broadcast(count, out, dim) + + elif reduce == 'mul': + vals = common_map(ops.partial(_scatter_mul, dim, index, src), scope) + elif reduce == 'max': + vals = common_map(ops.partial(_scatter_max, dim, index, src), scope) + elif reduce == 'amax': + out = out.long().fill(src.shape[dim]) + vals = common_map(ops.partial(_scatter_amax, dim, index, src), scope) + elif reduce == 'min': + vals = common_map(ops.partial(_scatter_min, dim, index, src), scope) + elif reduce == 'amin': + out = out.int().fill(src.shape[dim]) + vals = common_map(ops.partial(_scatter_amin, dim, index, src), scope) + else: + raise ValueError(f"invalid `reduce` argument '{reduce}'") + out = index_fill(out, scope, mint.cat(vals, dim=dim), dim) + return out + +def broadcast(src: Tensor, ref: Tensor, dim: int) -> Tensor: + dim = ref.ndim + dim if dim < 0 else dim + size = ((1, ) * dim) + (-1, ) + ((1, ) * (ref.dim() - dim - 1)) + return src.view(size).expand_as(ref) + +def scatter_softmax( + src: Tensor, index: Tensor, dim: int = -1, dim_size: Optional[int] = None +) -> Tensor: + max_value_per_index = scatter(src, index, dim=dim, dim_size=dim_size) + ix = broadcast_to(index, src, dim) + max_per_src_element = ops.gather_nd( + max_value_per_index, ix).reshape(src.shape) + + recentered_scores = src - max_per_src_element + recentered_scores_exp = mint.exp(recentered_scores) + + sum_per_index = scatter( + recentered_scores_exp, index, dim, dim_size=dim_size, reduce="sum" + ) + normalizing_constants = ops.gather_nd(sum_per_index, ix).reshape(src.shape) + return mint.div(recentered_scores_exp, normalizing_constants) + + +def scatter_log_softmax( + src: Tensor, + index: Tensor, + dim: int = -1, + eps: float = 1e-12, + dim_size: Optional[int] = None, +) -> Tensor: + max_value_per_index = scatter( + src, index, dim=dim, dim_size=dim_size, reduce="sum") + ix = broadcast_to(index, src, dim) + max_per_src_element = ops.gather_nd( + max_value_per_index, ix).reshape(src.shape) + + recentered_scores = src - max_per_src_element + recentered_scores_exp = mint.exp(recentered_scores) + + sum_per_index = scatter( + recentered_scores_exp, index, dim, dim_size=dim_size, reduce="sum" + ) + normalizing_constants = ops.gather_nd(mint.log(sum_per_index + eps), ix).reshape( + src.shape + ) + return recentered_scores - normalizing_constants + + +def group_argsort( + src: Tensor, + index: Tensor, + dim: int = 0, + num_groups: Optional[int] = None, + descending: bool = False, + return_consecutive: bool = False, +) -> Tensor: + r"""Returns the indices that sort the tensor :obj:`src` along a given + dimension in ascending order by value. + In contrast to :meth:`mindspore.argsort`, sorting is performed in groups + according to the values in :obj:`index`. + + Args: + src (Tensor): The source tensor. + index (Tensor): The index tensor. + dim (int, optional): The dimension along which to index. + (default: :obj:`0`) + num_groups (int, optional): The number of groups. + (default: :obj:`None`) + descending (bool, optional): Controls the sorting order (ascending or + descending). (default: :obj:`False`) + return_consecutive (bool, optional): If set to :obj:`True`, will not + offset the output to start from :obj:`0` for each group. + (default: :obj:`False`) + stable (bool, optional): Controls the relative order of equivalent + elements. (default: :obj:`False`) + + Example: + >>> src = Tensor([0, 1, 5, 4, 3, 2, 6, 7, 8]) + >>> index = Tensor, 1, 1, 1, 1, 2, 2, 2]) + >>> group_argsort(src, index) + tensor([0, 1, 3, 2, 1, 0, 0, 1, 2]) + """ + # Only implemented under certain conditions for now :( + assert src.dim() == 1 and index.dim() == 1 + assert dim == 0 or dim == -1 + assert src.numel() == index.numel() + + if src.numel() == 0: + return mint.zeros_like(src) + + # Normalize `src` to range [0, 1]: + src = src - src.min() + src = src / src.max() + + # Compute `grouped_argsort`: + src = src - 2 * index if descending else src + 2 * index + perm = src.argsort(descending=descending) + out = 0 - mint.ones_like(index) + out[perm] = mint.arange(index.numel()) + + if return_consecutive: + return out + + # Compute cumulative sum of number of entries with the same index: + count = scatter( + mint.ones_like(index), index, dim=dim, dim_size=num_groups, reduce="sum" + ) + ptr = cumsum(count) + + return out - ptr[index] + + +def group_cat( + tensors: Union[List[Tensor], Tuple[Tensor, ...]], + indices: Union[List[Tensor], Tuple[Tensor, ...]], + axis: int = 0, + return_index: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + r"""Concatenates the given sequence of tensors :obj:`tensors` in the given + dimension :obj:`dim`. + Different from :meth:`ms.cat`, values along the concatenating dimension + are grouped according to the indicies defined in the :obj:`index` tensors. + All tensors must have the same shape (except in the concatenating + dimension). + + Args: + tensors ([Tensor]): Sequence of tensors. + indices ([Tensor]): Sequence of index tensors. + dim (int, optional): The dimension along which the tensors are + concatenated. (default: :obj:`0`) + return_index (bool, optional): If set to :obj:`True`, will return the + new index tensor. (default: :obj:`False`) + + Example: + >>> x1 = ms.Tensor([[0.2716, 0.4233], + ... [0.3166, 0.0142], + ... [0.2371, 0.3839], + ... [0.4100, 0.0012]]) + >>> x2 = ms.Tensor([[0.3752, 0.5782], + ... [0.7757, 0.5999]]) + >>> index1 = ms.Tensor([0, 0, 1, 2]) + >>> index2 = ms.Tensor([0, 2]) + >>> scatter_concat([x1,x2], [index1, index2], axis=0) + tensor([[0.2716, 0.4233], + [0.3166, 0.0142], + [0.3752, 0.5782], + [0.2371, 0.3839], + [0.4100, 0.0012], + [0.7757, 0.5999]]) + """ + assert len(tensors) == len(indices) + index, perm = mint.sort(mint.cat(indices)) + out = mint.cat(tensors, dim=axis)[perm] + return (out, index) if return_index else out + + +def scatter_concat( + tensors: Union[List[Tensor], Tuple[Tensor, ...]], + indices: Union[List[Tensor], Tuple[Tensor, ...]], + axis: int = 0, + return_index: bool = False, +) -> Union[Tensor, Tuple[Tensor, Tensor]]: + r"""Concatenates the given sequence of tensors :obj:`tensors` in the given + dimension :obj:`dim`. + Different from :meth:`ops.cat`, values along the concatenating dimension + are grouped according to the indicies defined in the :obj:`index` tensors. + All tensors must have the same shape (except in the concatenating + dimension). + + Args: + tensors ([Tensor]): Sequence of tensors. + indices ([Tensor]): Sequence of index tensors. + dim (int, optional): The dimension along which the tensors are + concatenated. (default: :obj:`0`) + return_index (bool, optional): If set to :obj:`True`, will return the + new index tensor. (default: :obj:`False`) + + Example: + >>> x1 = Tensor([[0.2716, 0.4233], + ... [0.3166, 0.0142], + ... [0.2371, 0.3839], + ... [0.4100, 0.0012]]) + >>> x2 = Tensor([[0.3752, 0.5782], + ... [0.7757, 0.5999]]) + >>> index1 = Tensor([0, 0, 1, 2]) + >>> index2 = Tensor + >>> scatter_concat([x1,x2], [index1, index2], dim=0) + tensor([[0.2716, 0.4233], + [0.3166, 0.0142], + [0.3752, 0.5782], + [0.2371, 0.3839], + [0.4100, 0.0012], + [0.7757, 0.5999]]) + """ + assert len(tensors) == len(indices) + index = mint.cat(indices) + perm = index.argsort() + index = index[perm] + out = mint.cat(tensors, dim=axis)[perm] + return (out, index) if return_index else out diff --git a/mindscience/sharker/utils/_segment.py b/mindscience/sharker/utils/_segment.py new file mode 100644 index 000000000..b841f0c2a --- /dev/null +++ b/mindscience/sharker/utils/_segment.py @@ -0,0 +1,115 @@ +from mindspore import Tensor, ops, mint + + +_segment_max = ops.MultitypeFuncGraph('_segment_max') +_segment_amax = ops.MultitypeFuncGraph('_segment_amax') +_segment_min = ops.MultitypeFuncGraph('_segment_min') +_segment_amin = ops.MultitypeFuncGraph('_segment_amin') + +_segment_sum = ops.MultitypeFuncGraph('_segment_sum') +_segment_mean = ops.MultitypeFuncGraph('_segment_mean') +_segment_mul = ops.MultitypeFuncGraph('_segment_mul') + + +@_segment_max.register('Number', 'Tensor') +def _segment_max(axis: int, src: Tensor): + shape = list(src.shape) + if shape[axis] == 0: + shape[axis] = 1 + return mint.zeros(shape) + return src.max(axis, keepdims=True) + + +@_segment_amax.register('Number', 'Tensor') +def _segment_amax(axis: int, src: Tensor): + shape = list(src.shape) + if shape[axis] == 0: + shape[axis] = 1 + return mint.zeros(shape).long() + return src.argmax(axis, keepdims=True) + + +@_segment_min.register('Number', 'Tensor') +def _segment_min(axis: int, src: Tensor): + shape = list(src.shape) + if shape[axis] == 0: + shape[axis] = 1 + return mint.zeros(shape, dtype=src.dtype) + return src.min(axis, keepdims=True) + + +@_segment_amin.register('Number', 'Tensor') +def _segment_amin(axis: int, src: Tensor): + shape = list(src.shape) + if shape[axis] == 0: + shape[axis] = 1 + return mint.zeros(shape).int() + return src.argmin(axis, keepdims=True) + + +@_segment_sum.register('Number', 'Tensor') +def _segment_sum(axis: int, src: Tensor): + shape = list(src.shape) + if shape[axis] == 0: + shape[axis] = 1 + return mint.zeros(shape, dtype=src.dtype) + return src.sum(axis, keepdims=True) + + +@_segment_mean.register('Number', 'Tensor') +def _segment_mean(axis: int, src: Tensor): + shape = list(src.shape) + if shape[axis] == 0: + shape[axis] = 1 + return mint.zeros(shape, dtype=src.dtype) + return src.mean(axis, keep_dims=True) + + +@_segment_mul.register('Number', 'Tensor') +def _segment_mul(axis: int, src: Tensor): + shape = list(src.shape) + if shape[axis] == 0: + shape[axis] = 1 + return mint.zeros(shape, dtype=src.dtype) + return src.prod(axis, keepdim=True) + + +def segment(src: Tensor, ptr: Tensor, dim=0, dim_size=None, reduce: str = "sum") -> Tensor: + r"""Reduces all values in the first dimension of the :obj:`src` tensor + within the ranges specified in the :obj:`ptr`. :obj:`mindspore_scatter` package for more + information. + + Args: + src (Tensor): The source tensor. + ptr (Tensor): A monotonically increasing pointer tensor that + refers to the boundaries of segments such that :obj:`ptr[0] = 0` + and :obj:`ptr[-1] = src.shape[0]`. + reduce (str, optional): The reduce operation (:obj:`"sum"`, + :obj:`"mean"`, :obj:`"min"` or :obj:`"max"`). + (default: :obj:`"sum"`) + """ + vals = mint.split(src, ptr.diff().tolist(), dim=dim) + common_map = ops.Map() + if reduce == "max": + out = common_map(ops.partial(_segment_max, dim), vals) + elif reduce == "min": + out = common_map(ops.partial(_segment_min, dim), vals) + elif reduce == "amax": + out = common_map(ops.partial(_segment_amax, dim), vals) + elif reduce == "amin": + out = common_map(ops.partial(_segment_amin, dim), vals) + elif reduce == "mul": + out = common_map(ops.partial(_segment_mul, dim), vals) + elif reduce in ["sum", "add"]: + out = common_map(ops.partial(_segment_sum, dim), vals) + elif reduce == "mean": + out = common_map(ops.partial(_segment_mean, dim), vals) + else: + raise ValueError(f'The value of reduce `{reduce}` is not supported!') + + if dim_size is not None: + shape = list(src.shape) + shape[dim] = dim_size - len(vals) + out += (mint.zeros(shape, dtype=out[0].dtype), ) + val = mint.cat(out, dim=dim) + return val diff --git a/mindscience/sharker/utils/assortativity.py b/mindscience/sharker/utils/assortativity.py new file mode 100644 index 000000000..af374e911 --- /dev/null +++ b/mindscience/sharker/utils/assortativity.py @@ -0,0 +1,61 @@ +import mindspore as ms +from mindspore import Tensor, ops, mint +from mindspore import ops +from .coalesce import coalesce +from .degree import degree +from .to_dense_adj import to_dense_adj + + +def assortativity(edge_index: Tensor) -> float: + r"""The degree assortativity coefficient from the + `"Mixing patterns in networks" + `_ paper. + Assortativity in a network refers to the tendency of nodes to + connect with other similar nodes over dissimilar nodes. + It is computed from Pearson correlation coefficient of the node degrees. + + Args: + edge_index (Tensor or SparseTensor): The graph connectivity. + + Returns: + The value of the degree assortativity coefficient for the input + graph :math:`\in [-1, 1]` + + Example: + >>> edge_index = Tensor([[0, 1, 2, 3, 2], + ... [1, 2, 0, 1, 3]]) + >>> assortativity(edge_index) + -0.666667640209198 + """ + assert isinstance(edge_index, Tensor) + row, col = edge_index + + out_deg = degree(row, dtype=ms.int64) + in_deg = degree(col, dtype=ms.int64) + degrees = mint.unique(mint.cat([out_deg, in_deg])) + mapping = row.new_zeros(degrees.max().item() + 1) + mapping[degrees] = mint.arange(degrees.shape[0]) + + # Compute degree mixing matrix (joint probability distribution) `M` + num_degrees = degrees.shape[0] + src_deg = mapping[out_deg[row]] + dst_deg = mapping[in_deg[col]] + + pairs = mint.stack([src_deg, dst_deg], dim=0) + occurrence = mint.ones(pairs.shape[1]) + pairs, occurrence = coalesce(pairs, occurrence) + M = to_dense_adj(pairs, edge_attr=occurrence, max_num_nodes=num_degrees)[0] + # normalization + M /= M.sum() + + # numeric assortativity coefficient, computed by + # Pearson correlation coefficient of the node degrees + x = y = degrees.float() + a, b = M.sum(0), M.sum(1) + + vara = (a * x**2).sum() - ((a * x).sum()) ** 2 + varb = (b * x**2).sum() - ((b * x).sum()) ** 2 + xy = ops.outer(x, y) + ab = ops.outer(a, b) + out = (xy * (M - ab)).sum() / (vara * varb).sqrt() + return out.item() diff --git a/mindscience/sharker/utils/augmentation.py b/mindscience/sharker/utils/augmentation.py new file mode 100644 index 000000000..890b08c9c --- /dev/null +++ b/mindscience/sharker/utils/augmentation.py @@ -0,0 +1,242 @@ +from typing import Optional, Tuple, Union + +import mindspore as ms +from mindspore import Tensor, ops, mint, Generator +from .functions import cumsum +from . import scatter +from .negative_sampling import negative_sampling + + +def shuffle_node( + x: Tensor, + batch: Optional[Tensor] = None, + training: bool = True, +) -> Tuple[Tensor, Tensor]: + r"""Randomly shuffle the feature matrix :obj:`x` along the + first dimmension. + + The method returns (1) the shuffled :obj:`x`, (2) the permutation + indicating the orders of original nodes after shuffling. + + Args: + x (FloatTensor): The feature matrix. + batch (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. Must be ordered. (default: :obj:`None`) + training (bool, optional): If set to :obj:`False`, this operation is a + no-op. (default: :obj:`True`) + + :rtype: (:class:`FloatTensor`, :class:`LongTensor`) + + Example: + >>> # Standard case + >>> x = Tensor([[0, 1, 2], + ... [3, 4, 5], + ... [6, 7, 8], + ... [9, 10, 11]], dtype=ms.float32) + >>> x, node_perm = shuffle_node(x) + >>> x + tensor([[ 3., 4., 5.], + [ 9., 10., 11.], + [ 0., 1., 2.], + [ 6., 7., 8.]]) + >>> node_perm + tensor([1, 3, 0, 2]) + + >>> # For batched graphs as inputs + >>> batch = Tensor([0, 0, 1, 1]) + >>> x, node_perm = shuffle_node(x, batch) + >>> x + tensor([[ 3., 4., 5.], + [ 0., 1., 2.], + [ 9., 10., 11.], + [ 6., 7., 8.]]) + >>> node_perm + tensor([1, 0, 3, 2]) + """ + perm = ops.arange(x.shape[0]).int() + if not training: + return x, perm + if batch is None: + perm = ops.shuffle(perm.expand_dims(0).transpose()).transpose()[0] + return x[perm], perm + num_nodes = scatter(batch.new_ones(x.shape[0]), batch, dim=0, reduce="sum") + ptr = cumsum(num_nodes) + perm = mint.cat( + [ + ops.shuffle(ops.arange(n).int().expand_dims(0).transpose()).transpose()[0] + offset + for offset, n in zip(ptr[:-1], num_nodes) + ] + ).int() + return x[perm], perm + + +def mask_feature( + x: Tensor, + p: float = 0.5, + mode: str = "col", + fill_value: float = 0.0, + training: bool = True, +) -> Tuple[Tensor, Tensor]: + r"""Randomly masks feature from the feature matrix + :obj:`x` with probability :obj:`p` using samples from + a Bernoulli distribution. + + The method returns (1) the retained :obj:`x`, (2) the feature + mask broadcastable with :obj:`x` (:obj:`mode='row'` and :obj:`mode='col'`) + or with the same shape as :obj:`x` (:obj:`mode='all'`), + indicating where features are retained. + + Args: + x (FloatTensor): The feature matrix. + p (float, optional): The masking ratio. (default: :obj:`0.5`) + mode (str, optional): The masked scheme to use for feature masking. + (:obj:`"row"`, :obj:`"col"` or :obj:`"all"`). + If :obj:`mode='col'`, will mask entire features of all nodes + from the feature matrix. If :obj:`mode='row'`, will mask entire + nodes from the feature matrix. If :obj:`mode='all'`, will mask + individual features across all nodes. (default: :obj:`'col'`) + fill_value (float, optional): The value for masked features in the + output tensor. (default: :obj:`0`) + training (bool, optional): If set to :obj:`False`, this operation is a + no-op. (default: :obj:`True`) + + :rtype: (:class:`FloatTensor`, :class:`BoolTensor`) + + Examples: + >>> # Masked features are column-wise sampled + >>> x = Tensor([[1, 2, 3], + ... [4, 5, 6], + ... [7, 8, 9]], dtype=ms.float32) + >>> x, feat_mask = mask_feature(x) + >>> x + tensor([[1., 0., 3.], + [4., 0., 6.], + [7., 0., 9.]]), + >>> feat_mask + tensor([[True, False, True]]) + + >>> # Masked features are row-wise sampled + >>> x, feat_mask = mask_feature(x, mode='row') + >>> x + tensor([[1., 2., 3.], + [0., 0., 0.], + [7., 8., 9.]]), + >>> feat_mask + tensor([[True], [False], [True]]) + + >>> # Masked features are uniformly sampled + >>> x, feat_mask = mask_feature(x, mode='all') + >>> x + tensor([[0., 0., 0.], + [4., 0., 6.], + [0., 0., 9.]]) + >>> feat_mask + tensor([[False, False, False], + [True, False, True], + [False, False, True]]) + """ + y = x.copy() + if p < 0.0 or p > 1.0: + raise ValueError(f"Masking ratio has to be between 0 and 1 " f"(got {p}") + if not training or p == 0.0: + return y, mint.ones_like(y, dtype=ms.bool_) + assert mode in ["row", "col", "all"] + + if mode == "row": + mask = ops.rand(y.shape[0]) >= p + mask = mask.view(-1, 1) + elif mode == "col": + mask = ops.rand(y.shape[1]) >= p + mask = mask.view(1, -1) + + else: + mask = ops.rand_like(y) >= p + + y = y.masked_fill(~mask, fill_value) + return y, mask + + +def add_random_edge( + edge_index: Tensor, + p: float = 0.5, + force_undirected: bool = False, + num_nodes: Optional[Union[int, Tuple[int, int]]] = None, + training: bool = True, +) -> Tuple[Tensor, Tensor]: + r"""Randomly adds edges to :obj:`edge_index`. + + The method returns (1) the retained :obj:`edge_index`, (2) the added + edge indices. + + Args: + edge_index (LongTensor): The edge indices. + p (float): Ratio of added edges to the existing edges. + (default: :obj:`0.5`) + force_undirected (bool, optional): If set to :obj:`True`, + added edges will be undirected. + (default: :obj:`False`) + num_nodes (int, Tuple[int], optional): The overall number of nodes, + *i.e.* :obj:`max_val + 1`, or the number of source and + destination nodes, *i.e.* :obj:`(max_src_val + 1, max_dst_val + 1)` + of :attr:`edge_index`. (default: :obj:`None`) + training (bool, optional): If set to :obj:`False`, this operation is a + no-op. (default: :obj:`True`) + + :rtype: (:class:`LongTensor`, :class:`LongTensor`) + + Examples: + >>> # Standard case + >>> edge_index = Tensor([[0, 1, 1, 2, 2, 3], + ... [1, 0, 2, 1, 3, 2]]) + >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5) + >>> edge_index + tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3], + [1, 0, 2, 1, 3, 2, 0, 2, 1]]) + >>> added_edges + tensor([[2, 1, 3], + [0, 2, 1]]) + + >>> # The returned graph is kept undirected + >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5, + ... force_undirected=True) + >>> edge_index + tensor([[0, 1, 1, 2, 2, 3, 2, 1, 3, 0, 2, 1], + [1, 0, 2, 1, 3, 2, 0, 2, 1, 2, 1, 3]]) + >>> added_edges + tensor([[2, 1, 3, 0, 2, 1], + [0, 2, 1, 2, 1, 3]]) + + >>> # For bipartite graphs + >>> edge_index = Tensor([[0, 1, 2, 3, 4, 5], + ... [2, 3, 1, 4, 2, 1]]) + >>> edge_index, added_edges = add_random_edge(edge_index, p=0.5, + ... num_nodes=(6, 5)) + >>> edge_index + tensor([[0, 1, 2, 3, 4, 5, 3, 4, 1], + [2, 3, 1, 4, 2, 1, 1, 3, 2]]) + >>> added_edges + tensor([[3, 4, 1], + [1, 3, 2]]) + """ + if p < 0.0 or p > 1.0: + raise ValueError( + f"Ratio of added edges has to be between 0 and 1 " f"(got '{p}')" + ) + if force_undirected and isinstance(num_nodes, (tuple, list)): + raise RuntimeError("`force_undirected` is not supported for `bipartite graphs`") + + if not training or p == 0.0: + edge_index_to_add = mint.zeros([2, 0]) + return edge_index, edge_index_to_add + + edge_index_to_add = negative_sampling( + edge_index=edge_index, + num_nodes=num_nodes, + num_neg_samples=round(edge_index.shape[1] * p), + force_undirected=force_undirected, + ) + + edge_index = mint.cat([edge_index, edge_index_to_add], dim=1) + + return edge_index, edge_index_to_add diff --git a/mindscience/sharker/utils/cluster.py b/mindscience/sharker/utils/cluster.py new file mode 100644 index 000000000..b5593d270 --- /dev/null +++ b/mindscience/sharker/utils/cluster.py @@ -0,0 +1,560 @@ +from typing import List, Optional, Union + +import mindspore as ms +from mindspore import Tensor, ops, mint +from .functions import cumsum + + +from .repeat import repeat + +_knn = ops.MultitypeFuncGraph('_knn') +_radius = ops.MultitypeFuncGraph('_radius') +_nearest = ops.MultitypeFuncGraph('_nearest') +_fps = ops.MultitypeFuncGraph('_fps') +_grid = ops.MultitypeFuncGraph('_grid') +_rw = ops.MultitypeFuncGraph('_rw') +_graclus = ops.MultitypeFuncGraph('_graclus') + + +@_knn.register('Number', 'Bool', 'Tensor', 'Tensor', 'Number', 'Number') +def _knn(k: int, cosine: bool, x: Tensor, y: Tensor, ptr_x: int = 0, ptr_y: int = 0): + if cosine: + dist = None + raise NotImplementedError('The parameter cosine has not been implemented!') + else: + dist = ops.cdist(y, x) + + _, neighbors = mint.sort(dist, dim=1) + neighbors = neighbors[:, :k] + src = mint.nonzero(mint.ones_like(neighbors))[:, 0].int() + dst = neighbors.view(-1).astype(ms.int32) + edge_list = mint.stack([src + ptr_y, dst + ptr_x], dim=0) + return edge_list + + +def knn( + x: Tensor, + y: Tensor, + k: int, + batch_x: Optional[Tensor] = None, + batch_y: Optional[Tensor] = None, + cosine: bool = False, + batch_size: Optional[int] = None, +) -> Tensor: + r"""Finds for each element in :obj:`y` the :obj:`k` nearest points in + :obj:`x`. + + Args: + x (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. + y (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{M \times F}`. + k (int): The number of neighbors. + batch_x (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. :obj:`batch_x` needs to be sorted. + (default: :obj:`None`) + batch_y (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each + node to a specific example. :obj:`batch_y` needs to be sorted. + (default: :obj:`None`) + cosine (boolean, optional): If :obj:`True`, will use the Cosine + distance instead of the Euclidean distance to find nearest + neighbors. (default: :obj:`False`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) + + :rtype: :class:`LongTensor` + + .. code-block:: python + + import mindspore as ms + from mindspore import knn + + x = Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) + batch_x = ms.Tensor([0, 0, 0, 0]) + y = Tensor([[-1, 0], [1, 0]]) + batch_y = ms.Tensor([0, 0]) + assign_index = knn(x, y, 2, batch_x, batch_y) + """ + if batch_x is not None: + count_x = batch_x.bincount().long().astype(ms.int32) + ptr_x = [0] + count_x.cumsum().tolist()[:-1] + x = x.split(count_x.tolist()) + elif batch_size is not None: + ptr_x = (mint.arange(batch_size) * batch_size).tolist() + x = x.split(batch_size) + else: + ptr_x = [0] + x = [x] + if batch_y is not None: + count_y = batch_y.bincount().long().astype(ms.int32) + ptr_y = [0] + count_y.cumsum().tolist()[:-1] + y = y.split(count_y.tolist()) + elif batch_size is not None: + ptr_y = (mint.arange(batch_size) * batch_size).tolist() + y = y.split(batch_size) + else: + ptr_y = [0] + y = [y] + assert len(x) == len(y) + common_map = ops.Map() + edge_index = common_map(ops.partial(_knn, k, cosine), x, y, ptr_x, ptr_y) + edge_index = mint.cat(edge_index, dim=1) + return edge_index + + +def knn_graph( + x: Tensor, + k: int, + batch: Optional[Tensor] = None, + loop: bool = False, + flow: str = 'src_to_trg', + cosine: bool = False, + batch_size: Optional[int] = None, +) -> Tensor: + r"""Computes graph edges to the nearest :obj:`k` points. + + Args: + x (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. + k (int): The number of neighbors. + batch (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. :obj:`batch` needs to be sorted. + (default: :obj:`None`) + loop (bool, optional): If :obj:`True`, the graph will contain + self-loops. (default: :obj:`False`) + flow (string, optional): The flow direction when used in combination + with message passing (:obj:`"src_to_trg"` or + :obj:`"trg_to_src"`). (default: :obj:`"src_to_trg"`) + cosine (boolean, optional): If :obj:`True`, will use the Cosine + distance instead of Euclidean distance to find nearest neighbors. + (default: :obj:`False`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) + + :rtype: :class:`LongTensor` + + .. code-block:: python + + import mindspore as ms + from mindGeometric_cluster import knn_graph + + x = Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) + batch = ms.Tensor([0, 0, 0, 0]) + edge_index = knn_graph(x, k=2, batch=batch, loop=False) + """ + + assert flow in ['src_to_trg', 'trg_to_src'] + edge_index = knn(x, x, k if loop else k + 1, batch, batch, cosine, batch_size) + + if flow == 'src_to_trg': + row, col = edge_index[1], edge_index[0] + else: + row, col = edge_index[0], edge_index[1] + + if not loop: + mask = row != col + row, col = row[mask], col[mask] + + return mint.stack([row, col], dim=0) + + +@_radius.register('Number', 'Number', 'Tensor', 'Tensor') +def _radius(r: float, max_num_neighbors: int = 32, x: Tensor = None, y: Tensor = None) -> Tensor: + dist = ops.cdist(x, y) + if max_num_neighbors is None: + edge_list = mint.nonzero(dist < r) + else: + sorted_dist, neighbors = mint.sort(dist, dim=1) + mask = sorted_dist < r + if max_num_neighbors < len(x): + neighbors = neighbors[:, :max_num_neighbors] + mask = mask[:, :max_num_neighbors] + src = mint.nonzero(mask)[:, 0].int() + dst = neighbors[mask] + edge_list = mint.stack([src, dst], dim=0) + return edge_list + + +def radius( + x: Tensor, + y: Tensor, + r: float, + batch_x: Optional[Tensor] = None, + batch_y: Optional[Tensor] = None, + max_num_neighbors: int = 32, + batch_size: Optional[int] = None, +) -> Tensor: + r"""Finds for each element in :obj:`y` all points in :obj:`x` within + distance :obj:`r`. + + Args: + x (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. + y (Tensor): Node feature matrix + :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`. + r (float): The radius. + batch_x (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. :obj:`batch_x` needs to be sorted. + (default: :obj:`None`) + batch_y (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each + node to a specific example. :obj:`batch_y` needs to be sorted. + (default: :obj:`None`) + max_num_neighbors (int, optional): The maximum number of neighbors to + return for each element in :obj:`y`. + If the number of actual neighbors is greater than + :obj:`max_num_neighbors`, returned neighbors are picked randomly. + (default: :obj:`32`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) + + .. code-block:: python + + import mindspore as ms + from mindGeometric_cluster import radius + + x = Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) + batch_x = ms.Tensor([0, 0, 0, 0]) + y = Tensor([[-1, 0], [1, 0]]) + batch_y = ms.Tensor([0, 0]) + assign_index = radius(x, y, 1.5, batch_x, batch_y) + """ + if batch_x is not None: + batch_x = batch_x.bincount().long().tolist() + x = x.split(batch_x) + elif batch_size is not None: + x = x.split(batch_size) + else: + x = [x] + if batch_y is not None: + batch_y = batch_y.bincount().long().tolist() + y = y.split(batch_y) + elif batch_size is not None: + y = y.split(batch_size) + else: + y = [y] + assert len(x) == len(y) + common_map = ops.Map() + edge_index = common_map(ops.partial(_radius, r, max_num_neighbors), x, y) + edge_index = mint.cat(edge_index, dim=1) + return edge_index + + +def radius_graph( + x: Tensor, + r: float, + batch: Optional[Tensor] = None, + loop: bool = False, + max_num_neighbors: int = 32, + flow: str = "src_to_dst", + batch_size: Optional[int] = None, + + +) -> Tensor: + r"""Computes graph edges to all points within a given distance. + + Args: + x (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. + r (float): The radius. + batch (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. :obj:`batch` needs to be sorted. + (default: :obj:`None`) + loop (bool, optional): If :obj:`True`, the graph will contain + self-loops. (default: :obj:`False`) + max_num_neighbors (int, optional): The maximum number of neighbors to + return for each element. + If the number of actual neighbors is greater than + :obj:`max_num_neighbors`, returned neighbors are picked randomly. + (default: :obj:`32`) + flow (string, optional): The flow direction when used in combination + with message passing (:obj:`"src_to_dst"` or + :obj:`"dst_to_src"`). (default: :obj:`"src_to_dst"`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) + + :rtype: :class:`LongTensor` + + .. code-block:: python + + import mindspore as ms + from cluster import radius_graph + + x = Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) + batch = ms.Tensor([0, 0, 0, 0]) + edge_index = radius_graph(x, r=1.5, batch=batch, loop=False) + """ + assert flow in ['src_to_dst', 'dst_to_src'] + max_num_neighbors = max_num_neighbors if loop else max_num_neighbors + 1 + edge_index = radius(x, x, r=r, batch_x=batch, batch_y=batch, max_num_neighbors=max_num_neighbors, + batch_size=batch_size) + if flow == 'src_to_dst': + row, col = edge_index[1], edge_index[0] + else: + row, col = edge_index[0], edge_index[1] + + if not loop: + mask = row != col + row, col = row[mask], col[mask] + + return mint.stack([row, col], dim=0) + + +@_nearest.register('Tensor', 'Tensor', 'Number') +def _nearest(x: Tensor, y: Tensor, ptr: int = 0): + dist = ops.cdist(x.float(), y.float()) + return dist.argmin(axis=1) + ptr + + +def nearest( + x: Tensor, + y: Tensor, + batch_x: Optional[Tensor] = None, + batch_y: Optional[Tensor] = None, +) -> Tensor: + r"""Clusters points in :obj:`x` together which are nearest to a given query + point in :obj:`y`. + + Args: + x (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. + y (Tensor): Node feature matrix + :math:`\mathbf{Y} \in \mathbb{R}^{M \times F}`. + batch_x (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. :obj:`batch_x` needs to be sorted. + (default: :obj:`None`) + batch_y (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^M`, which assigns each + node to a specific example. :obj:`batch_y` needs to be sorted. + (default: :obj:`None`) + + :rtype: :class:`LongTensor` + + .. code-block:: python + + import mindspore as ms + from mindGeometric_cluster import nearest + + x = Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]) + batch_x = ms.Tensor([0, 0, 0, 0]) + y = Tensor([[-1, 0], [1, 0]]) + batch_y = ms.Tensor([0, 0]) + cluster = nearest(x, y, batch_x, batch_y) + """ + if batch_x is not None: + count_x = batch_x.bincount().long() + x = x.split(count_x.tolist()) + else: + x = [x] + if batch_y is not None: + count_y = batch_y.bincount().long() + ptr = [0] + cumsum(count_y).tolist()[:-1] + y = y.split(count_y.tolist()) + else: + ptr = [0] + y = [y] + common_map = ops.Map() + edge_index = common_map(_nearest, x, y, ptr) + edge_index = mint.cat(edge_index) + return edge_index + + +@_graclus.register('Tensor', 'Tensor', 'Tensor') +def _graclus(rowptr, col, weight): + pass + + +def graclus( + edge_index: Tensor, + weight: Optional[Tensor] = None, + num_nodes: Optional[int] = None +): + r"""A greedy clustering algorithm from the `"Weighted Graph Cuts without + Eigenvectors: A Multilevel Approach" `_ paper of picking an unmarked + vertex and matching it with one of its unmarked neighbors (that maximizes + its edge weight). + The GPU algorithm is adapted from the `"A GPU Algorithm for Greedy Graph + Matching" `_ + paper. + + Args: + edge_index (Tensor): The edge indices. + weight (Tensor, optional): One-dimensional edge weights. + (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + row, col = edge_index[0], edge_index[1] + if num_nodes is None: + num_nodes = max(int(row.max()), int(col.max())) + 1 + + # Remove self-loops. + mask = row != col + row, col = row[mask], col[mask] + + if weight is not None: + weight = weight[mask] + + # Randomly shuffle nodes. + if weight is None: + perm = ops.shuffle(mint.arange(row.size(0))) + row, col = row[perm], col[perm] + + # To CSR. + perm = ops.argsort(row) + row, col = row[perm], col[perm] + + if weight is not None: + weight = weight[perm] + + deg = mint.zeros(num_nodes).long() + ops.tensor_scatter_elements(row, 0, mint.ones_like(row), reduction='add') + rowptr = mint.zeros(num_nodes + 1).long() + rowptr[1:] = mint.cumsum(deg, 0) + return _graclus(rowptr, col, weight) + + +@_grid.register() +def _grid(pos, size, start, end): + pass + + +def grid(pos, size, start, end): + pass + + +def voxel_grid( + pos: Tensor, + size: Union[float, List[float], Tensor], + batch: Optional[Tensor] = None, + start: Optional[Union[float, List[float], Tensor]] = None, + end: Optional[Union[float, List[float], Tensor]] = None, +) -> Tensor: + r"""Voxel grid pooling from the, *e.g.*, `Dynamic Edge-Conditioned Filters + in Convolutional Networks on Graphs `_ + paper, which overlays a regular grid of user-defined size over a point + cloud and clusters all points within the same voxel. + + Args: + pos (Tensor): Node position matrix + :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times D}`. + size (float or [float] or Tensor): Size of a voxel (in each dimension). + batch (Tensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots,B-1\}}^N`, which assigns each + node to a specific example. (default: :obj:`None`) + start (float or [float] or Tensor, optional): Start coordinates of the + grid (in each dimension). If set to :obj:`None`, will be set to the + minimum coordinates found in :attr:`pos`. (default: :obj:`None`) + end (float or [float] or Tensor, optional): End coordinates of the grid + (in each dimension). If set to :obj:`None`, will be set to the + maximum coordinates found in :attr:`pos`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + pos = pos.unsqueeze(-1) if pos.dim() == 1 else pos + dim = pos.shape[1] + + if batch is None: + batch = mint.zeros(pos.shape[0], dtype=ms.int32) + + pos = mint.cat([pos, batch.view(-1, 1).astype(pos.dtype)], dim=-1) + + if not isinstance(size, Tensor): + size = Tensor(size, dtype=pos.dtype) + size = repeat(size, dim) + size = mint.cat([size, mint.ones(1, dtype=size.dtype)]) # Add additional batch dim. + + if start is not None: + if not isinstance(start, Tensor): + start = Tensor(start, dtype=pos.dtype) + start = repeat(start, dim) + start = mint.cat([start, mint.zeros(1, dtype=start.dtype)]) + + if end is not None: + if not isinstance(end, Tensor): + end = Tensor(end, dtype=pos.dtype) + end = repeat(end, dim) + end = mint.cat([end, batch.max().unsqueeze(0)]) + + return grid(pos, size, start, end) + + +@_fps.register('Tensor', 'Number', 'Bool', 'Number') +def _fps(x: Tensor, ratio: float, random_start: bool, ptr: int): + dist = ops.cdist(x, x) + pos = ops.shuffle(mint.arange(len(x)))[0] if random_start else 0 + num = x.shape[0] * ratio + num = int(num) if num > 0 else 1 + out = mint.zeros(num) + for i in range(num): + pos = dist[pos].argmax() + out[num - i] = pos + dist[:, pos] = float('nan') + return out + ptr + + +def fps(x: Tensor, + batch: Tensor, + ratio: float = 0.5, + random_start: bool = True, + batch_size: int = None): + r"""Farthest point sampling (FPS) algorithm from the + `"PointNet++: Deep Hierarchical Feature Learning on Point Sets in a Metric Space" + `_ paper, which iteratively samples the + most distant point with regard to the rest points. + + .. code-block:: python + + import mindspore as ms + from mindscience.sharker.nn import fps + + x = Tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) + batch = Tensor([0, 0, 0, 0]) + index = fps(x, batch, ratio=0.5) + + Args: + x (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{N \times F}`. + batch (Tensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. (default: :obj:`None`) + ratio (float, optional): Sampling ratio. (default: :obj:`0.5`) + random_start (bool, optional): If set to :obj:`False`, use the first + node in :math:`\mathbf{X}` as starting node. (default: obj:`True`) + batch_size (int, optional): The number of examples :math:`B`. + Automatically calculated if not given. (default: :obj:`None`) + + :rtype: :class:`Tensor` + """ + if batch is not None: + count = batch.bincount().long() + ptr = [0] + count.cumsum().tolist()[:-1] + x = x.split(count.tolist()) + elif batch_size is not None: + ptr = (mint.arange(batch_size) * batch_size).tolist() + x = x.split(batch_size) + else: + ptr = [0] + x = [x] + common_map = ops.Map() + edge_index = common_map(_fps, x, ratio, random_start, ptr) + edge_index = mint.cat(edge_index) + return edge_index + + +@_rw.register() +def _rw(): + pass + + +def random_walk(): + pass diff --git a/mindscience/sharker/utils/coalesce.py b/mindscience/sharker/utils/coalesce.py new file mode 100644 index 000000000..b22912e67 --- /dev/null +++ b/mindscience/sharker/utils/coalesce.py @@ -0,0 +1,194 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from .num_nodes import maybe_num_nodes +from . import scatter + +MISSING = "???" + + +def coalesce( # noqa: F811 + edge_index: Tensor, + edge_attr: Union[Optional[Tensor], List[Tensor], str] = MISSING, + num_nodes: Optional[int] = None, + reduce: str = "sum", + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Union[Tensor, Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, List[Tensor]]]: + """Row-wise sorts :obj:`edge_index` and removes its duplicated entries. + Duplicate entries in :obj:`edge_attr` are merged by scattering them + together according to the given :obj:`reduce` option. + + Args: + edge_index (Tensor): The edge indices. + edge_attr (Tensor or List[Tensor], optional): Edge weights + or multi-dimensional edge features. + If given as a list, will re-shuffle and remove duplicates for all + its entries. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + reduce (str, optional): The reduce operation to use for merging edge + features (:obj:`"sum"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, + :obj:`"mul"`, :obj:`"any"`). (default: :obj:`"sum"`) + is_sorted (bool, optional): If set to :obj:`True`, will expect + :obj:`edge_index` to be already sorted row-wise. + sort_by_row (bool, optional): If set to :obj:`False`, will sort + :obj:`edge_index` column-wise. + + :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else + (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]]`) + + .. warning:: + + From :pyg:`PyG >= 2.3.0` onwards, this function will always return a + tuple whenever :obj:`edge_attr` is passed as an argument (even in case + it is set to :obj:`None`). + + Example: + >>> edge_index = Tensor([[1, 1, 2, 3], + ... [3, 3, 1, 2]]) + >>> edge_attr = Tensor([1., 1., 1., 1.]) + >>> coalesce(edge_index) + tensor([[1, 2, 3], + [3, 1, 2]]) + + >>> # Sort `edge_index` column-wise + >>> coalesce(edge_index, sort_by_row=False) + tensor([[2, 3, 1], + [1, 2, 3]]) + + >>> coalesce(edge_index, edge_attr) + (tensor([[1, 2, 3], + [3, 1, 2]]), + tensor([2., 1., 1.])) + + >>> # Use 'mean' operation to merge edge features + >>> coalesce(edge_index, edge_attr, reduce='mean') + (tensor([[1, 2, 3], + [3, 1, 2]]), + tensor([1., 1., 1.])) + """ + num_edges = edge_index[0].shape[0] + num_nodes = maybe_num_nodes(edge_index, num_nodes) + idx = mint.neg(mint.ones(num_edges + 1, dtype=ms.int64)) + idx_from_1 = edge_index[1 - int(sort_by_row)] + idx_from_1 = mint.add(edge_index[int(sort_by_row)], idx_from_1, alpha=num_nodes) + + if not is_sorted: + idx_from_1, perm = mint.sort(idx_from_1) + if isinstance(edge_index, Tensor): + edge_index = mint.index_select(edge_index, 1, perm) + elif isinstance(edge_index, tuple): + edge_index = (edge_index[0][perm], edge_index[1][perm]) + else: + raise NotImplementedError + if isinstance(edge_attr, Tensor): + edge_attr = mint.index_select(edge_attr, 0, perm) + elif isinstance(edge_attr, (list, tuple)): + edge_attr = [e[perm] for e in edge_attr] + + idx[1:] = idx_from_1 + mask = mint.greater(idx_from_1, idx[:-1]) + + # Only perform expensive merging in case there exists duplicates:: + if mint.all(mask): + if edge_attr is None or isinstance(edge_attr, Tensor): + return edge_index, edge_attr + if isinstance(edge_attr, (list, tuple)): + return edge_index, edge_attr + return edge_index + if isinstance(edge_index, Tensor): + # edge_index = edge_index[:, mask] + edge_index = edge_index[:, mask] + elif isinstance(edge_index, tuple): + edge_index = (edge_index[0][mask], edge_index[1][mask]) + else: + raise NotImplementedError + + dim_size = None + if isinstance(edge_attr, (Tensor, list, tuple)) and len(edge_attr) > 0: + dim_size = edge_index.shape[1] + idx = mint.arange(0, num_edges) + idx -= mint.cumsum(~mask, dim=0) + + if edge_attr is None: + return edge_index, None + if isinstance(edge_attr, Tensor): + edge_attr = scatter(edge_attr, idx, 0, dim_size, reduce) + return edge_index, edge_attr + if isinstance(edge_attr, (list, tuple)): + if len(edge_attr) == 0: + return edge_index, edge_attr + edge_attr = [scatter(e, idx, 0, dim_size, reduce) for e in edge_attr] + return edge_index, edge_attr + + return edge_index + +def coalesce_np( # noqa: F811 + edge_index: np.ndarray, + edge_attr: Union[Optional[np.ndarray], List[np.ndarray], str] = MISSING, + num_nodes: Optional[int] = None, + reduce: str = "sum", + is_sorted: bool = False, + sort_by_row: bool = True, +) -> Union[Tensor, Tuple[np.ndarray, Optional[np.ndarray]], Tuple[np.ndarray, List[np.ndarray]]]: + + num_edges = edge_index[0].shape[0] + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + idx = -np.ones(num_edges + 1, dtype=np.int32) + idx_from_1 = edge_index[1 - int(sort_by_row)] + idx_from_1 = idx_from_1 * num_nodes + edge_index[int(sort_by_row)] + + if not is_sorted: + perm = np.argsort(idx_from_1) + idx_from_1 = np.sort(idx_from_1) + if isinstance(edge_index, np.ndarray): + edge_index = edge_index[:, perm] + elif isinstance(edge_index, tuple): + edge_index = (edge_index[0][perm], edge_index[1][perm]) + else: + raise NotImplementedError + if isinstance(edge_attr, np.ndarray): + edge_attr = edge_attr[perm] + elif isinstance(edge_attr, (list, tuple)): + edge_attr = [e[perm] for e in edge_attr] + + idx[1:] = idx_from_1 + + mask = idx[1:] > idx[:-1] + + # Only perform expensive merging in case there exists duplicates: + if mask.all(): + if edge_attr is None or isinstance(edge_attr, np.ndarray): + return edge_index, edge_attr + if isinstance(edge_attr, (list, tuple)): + return edge_index, edge_attr + return edge_index + if isinstance(edge_index, np.ndarray): + edge_index = edge_index[:, mask] + elif isinstance(edge_index, tuple): + edge_index = (edge_index[0][mask], edge_index[1][mask]) + else: + raise NotImplementedError + + dim_size = None + if isinstance(edge_attr, (np.ndarray, list, tuple)) and len(edge_attr) > 0: + dim_size = edge_index.shape[1] + idx = np.arange(0, num_edges) + idx -= np.cumsum(~mask, axis=0) + + if edge_attr is None: + return edge_index, None + if isinstance(edge_attr, Tensor): + edge_attr = scatter(edge_attr, idx, 0, dim_size, reduce) + return edge_index, edge_attr + if isinstance(edge_attr, (list, tuple)): + if len(edge_attr) == 0: + return edge_index, edge_attr + edge_attr = [scatter(e, idx, 0, dim_size, reduce) for e in edge_attr] + return edge_index, edge_attr + + return edge_index diff --git a/mindscience/sharker/utils/convert.py b/mindscience/sharker/utils/convert.py new file mode 100644 index 000000000..ce64a9533 --- /dev/null +++ b/mindscience/sharker/utils/convert.py @@ -0,0 +1,389 @@ +from collections import defaultdict +from typing import Any, Dict, Iterable, List, Literal, Optional, Tuple, Union + +import numpy as np +import scipy.sparse +import mindspore as ms +from mindspore import Tensor, ops, mint + +import sharker +from .num_nodes import maybe_num_nodes + + +def to_tensor(data): + if isinstance(data, tuple): + return (to_tensor(d) for d in data) + elif isinstance(data, list): + return [to_tensor(d) for d in data] + elif isinstance(data, set): + return set([to_tensor(d) for d in data]) + elif isinstance(data, dict): + return {k: to_tensor(v) for k, v in data.items()} + elif isinstance(data, np.ndarray): + return Tensor.from_numpy(data) + elif isinstance(data, Tensor): + return data + else: + raise NotImplementedError("Datatype {} cannot be cast to tensor!") + + +def to_array(data): + if isinstance(data, tuple): + return (to_array(d) for d in data) + elif isinstance(data, list): + return [to_array(d) for d in data] + elif isinstance(data, set): + return set([to_array(d) for d in data]) + elif isinstance(data, dict): + return {k: to_array(v) for k, v in data.items()} + elif isinstance(data, Tensor): + return data.asnumpy() + elif isinstance(data, np.ndarray): + return data + else: + raise NotImplementedError("Datatype {} cannot be cast to tensor!") + + +def to_scipy_sparse_matrix( + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + num_nodes: Optional[int] = None, +) -> scipy.sparse.coo_matrix: + r"""Converts a graph given by edge indices and edge attributes to a scipy + sparse matrix. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional + edge features. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + + Examples: + >>> edge_index = Tensor([ + ... [0, 1, 1, 2, 2, 3], + ... [1, 0, 2, 1, 3, 2], + ... ]) + >>> to_scipy_sparse_matrix(edge_index) + <4x4 sparse matrix of type '' + with 6 stored elements in COOrdinate format> + """ + src, dst = edge_index.asnumpy() + + if edge_attr is None: + edge_attr = mint.ones(src.shape[0]) + else: + edge_attr = edge_attr.reshape(-1) + assert edge_attr.shape[0] == edge_index.shape[1] + + N = maybe_num_nodes(edge_index, num_nodes) + out = scipy.sparse.coo_matrix((edge_attr.asnumpy(), (src, dst)), (N, N)) + return out + + +def from_scipy_sparse_matrix(A: scipy.sparse.spmatrix) -> Tuple[Tensor, Tensor]: + r"""Converts a scipy sparse matrix to edge indices and edge attributes. + + Args: + A (scipy.sparse): A sparse matrix. + + Examples: + >>> edge_index = Tensor([ + ... [0, 1, 1, 2, 2, 3], + ... [1, 0, 2, 1, 3, 2], + ... ]) + >>> adj = to_scipy_sparse_matrix(edge_index) + >>> # `edge_index` and `edge_weight` are both returned + >>> from_scipy_sparse_matrix(adj) + (tensor([[0, 1, 1, 2, 2, 3], + [1, 0, 2, 1, 3, 2]]), + tensor([1., 1., 1., 1., 1., 1.])) + """ + A = A.tocoo() + src = Tensor.from_numpy(A.row).astype(ms.int64) + dst = Tensor.from_numpy(A.col).astype(ms.int64) + edge_index = mint.stack([src, dst], dim=0) + edge_weight = Tensor.from_numpy(A.data) + return edge_index, edge_weight + + +def to_networkx( + graph: Union["sharker.data.Graph", "sharker.data.HeteroGraph"], + node_attrs: Optional[Iterable[str]] = None, + edge_attrs: Optional[Iterable[str]] = None, + graph_attrs: Optional[Iterable[str]] = None, + to_undirected: Optional[Union[bool, str]] = False, + to_multi: bool = False, + remove_self_loops: bool = False, +) -> Any: + r"""Converts a :class:`sharker.data.Graph` instance to a + :obj:`networkx.Graph` if :attr:`to_undirected` is set to :obj:`True`, or + a directed :obj:`networkx.DiGraph` otherwise. + + Args: + graph (sharker.data.Graph or sharker.data.HeteroGraph): A + homogeneous or heterogeneous data object. + node_attrs (iterable of str, optional): The node attributes to be + copied. (default: :obj:`None`) + edge_attrs (iterable of str, optional): The edge attributes to be + copied. (default: :obj:`None`) + graph_attrs (iterable of str, optional): The graph attributes to be + copied. (default: :obj:`None`) + to_undirected (bool or str, optional): If set to :obj:`True`, will + return a :class:`networkx.Graph` instead of a + :class:`networkx.DiGraph`. + By default, will include all edges and make them undirected. + If set to :obj:`"upper"`, the undirected graph will only correspond + to the upper triangle of the input adjacency matrix. + If set to :obj:`"lower"`, the undirected graph will only correspond + to the lower triangle of the input adjacency matrix. + Only applicable in case the :obj:`data` object holds a homogeneous + graph. (default: :obj:`False`) + to_multi (bool, optional): if set to :obj:`True`, will return a + :class:`networkx.MultiGraph` or a :class:`networkx:MultiDiGraph` + (depending on the :obj:`to_undirected` option), which will not drop + duplicated edges that may exist in :obj:`data`. + (default: :obj:`False`) + remove_self_loops (bool, optional): If set to :obj:`True`, will not + include self-loops in the resulting graph. (default: :obj:`False`) + + Examples: + >>> edge_index = Tensor([ + ... [0, 1, 1, 2, 2, 3], + ... [1, 0, 2, 1, 3, 2], + ... ]) + >>> data = Graph(edge_index=edge_index, num_nodes=4) + >>> to_networkx(data) + + + """ + import networkx as nx + from ..data import HeteroGraph + + to_undirected_upper: bool = to_undirected == "upper" + to_undirected_lower: bool = to_undirected == "lower" + + to_undirected = to_undirected is True + to_undirected |= to_undirected_upper or to_undirected_lower + assert isinstance(to_undirected, bool) + + if isinstance(graph, HeteroGraph) and to_undirected: + raise ValueError( + "'to_undirected' is not supported in " + "'to_networkx' for heterogeneous graphs" + ) + + if to_undirected: + G = nx.MultiGraph() if to_multi else nx.Graph() + else: + G = nx.MultiDiGraph() if to_multi else nx.DiGraph() + + def to_networkx_value(value: Any) -> Any: + return value.tolist() if isinstance(value, Tensor) else value + + for key in graph_attrs or []: + G.graph[key] = to_networkx_value(graph[key]) + + node_offsets = graph.node_offsets + for node_store in graph.node_stores: + start = node_offsets[node_store._key] + assert node_store.num_nodes is not None + for i in range(node_store.num_nodes): + node_kwargs: Dict[str, Any] = {} + if isinstance(graph, HeteroGraph): + node_kwargs["type"] = node_store._key + for key in node_attrs or []: + node_kwargs[key] = to_networkx_value(node_store[key][i]) + + G.add_node(start + i, **node_kwargs) + + for edge_store in graph.edge_stores: + for i, (v, w) in enumerate(edge_store.edge_index.t().tolist()): + if to_undirected_upper and v > w: + continue + elif to_undirected_lower and v < w: + continue + elif remove_self_loops and v == w and not edge_store.is_bipartite(): + continue + + edge_kwargs: Dict[str, Any] = {} + if isinstance(graph, HeteroGraph): + v = v + node_offsets[edge_store._key[0]] + w = w + node_offsets[edge_store._key[-1]] + edge_kwargs["type"] = edge_store._key + for key in edge_attrs or []: + edge_kwargs[key] = to_networkx_value(edge_store[key][i]) + + G.add_edge(v, w, **edge_kwargs) + + return G + + +def from_networkx( + G: Any, + group_node_attrs: Optional[Union[List[str], Literal["all"]]] = None, + group_edge_attrs: Optional[Union[List[str], Literal["all"]]] = None, +) -> "sharker.data.Graph": + r"""Converts a :obj:`networkx.Graph` or :obj:`networkx.DiGraph` to a + :class:`sharker.data.Graph` instance. + + Args: + G (networkx.Graph or networkx.DiGraph): A networkx graph. + group_node_attrs (List[str] or "all", optional): The node attributes to + be concatenated and added to :obj:`data.x`. (default: :obj:`None`) + group_edge_attrs (List[str] or "all", optional): The edge attributes to + be concatenated and added to :obj:`data.edge_attr`. + (default: :obj:`None`) + + .. note:: + + All :attr:`group_node_attrs` and :attr:`group_edge_attrs` values must + be numeric. + + Examples: + >>> edge_index = Tensor([ + ... [0, 1, 1, 2, 2, 3], + ... [1, 0, 2, 1, 3, 2], + ... ]) + >>> data = Graph(edge_index=edge_index, num_nodes=4) + >>> g = to_networkx(data) + >>> # A `Graph` object is returned + >>> from_networkx(g) + Graph(edge_index=[2, 6], num_nodes=4) + """ + import networkx as nx + from ..data import Graph + + G = G.to_directed() if not nx.is_directed(G) else G + + mapping = dict(zip(G.nodes(), range(G.number_of_nodes()))) + edge_index = -mint.ones((2, G.number_of_edges()), dtype=ms.int64) + for i, (src, dst) in enumerate(G.edges()): + edge_index[0, i] = mapping[src] + edge_index[1, i] = mapping[dst] + + data_dict: Dict[str, Any] = defaultdict(list) + data_dict["edge_index"] = edge_index + + node_attrs: List[str] = [] + if G.number_of_nodes() > 0: + node_attrs = list(next(iter(G.nodes(data=True)))[-1].keys()) + + edge_attrs: List[str] = [] + if G.number_of_edges() > 0: + edge_attrs = list(next(iter(G.edges(data=True)))[-1].keys()) + + if group_node_attrs is not None and not isinstance(group_node_attrs, list): + group_node_attrs = node_attrs + + if group_edge_attrs is not None and not isinstance(group_edge_attrs, list): + group_edge_attrs = edge_attrs + + for i, (_, feat_dict) in enumerate(G.nodes(data=True)): + if set(feat_dict.keys()) != set(node_attrs): + raise ValueError("Not all nodes contain the same attributes") + for key, value in feat_dict.items(): + data_dict[str(key)].append(value) + + for i, (_, _, feat_dict) in enumerate(G.edges(data=True)): + if set(feat_dict.keys()) != set(edge_attrs): + raise ValueError("Not all edges contain the same attributes") + for key, value in feat_dict.items(): + key = f"edge_{key}" if key in node_attrs else key + data_dict[str(key)].append(value) + + for key, value in G.graph.items(): + if key == "node_default" or key == "edge_default": + continue # Do not load default attributes. + key = f"graph_{key}" if key in node_attrs else key + data_dict[str(key)] = value + + for key, value in data_dict.items(): + if isinstance(value, (tuple, list)) and isinstance(value[0], Tensor): + data_dict[key] = mint.stack(value, dim=0) + else: + try: + data_dict[key] = ms.Tensor(value) + except Exception: + pass + + data = Graph.from_dict(data_dict) + + if group_node_attrs is not None: + xs = [] + for key in group_node_attrs: + x = data[key] + x = x.view(-1, 1) if x.dim() <= 1 else x + xs.append(x) + del data[key] + data.x = mint.cat(xs, dim=-1) + + if group_edge_attrs is not None: + xs = [] + for key in group_edge_attrs: + key = f"edge_{key}" if key in node_attrs else key + x = data[key] + x = x.view(-1, 1) if x.dim() <= 1 else x + xs.append(x) + del data[key] + data.edge_attr = mint.cat(xs, dim=-1) + + if data.x is None and data.crd is None: + data.num_nodes = G.number_of_nodes() + + return data + + +def to_trimesh(data: "sharker.data.Graph") -> Any: + r"""Converts a :class:`sharker.data.Graph` instance to a + :obj:`trimesh.Trimesh`. + + Args: + data (sharker.data.Graph): The data object. + + Example: + >>> crd = Tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]], + ... dtype=ms.float32) + >>> face = Tensor, 3]]).t() + + >>> data = Graph(crd=crd, face=face) + >>> to_trimesh(data) + + """ + import trimesh + + assert data.crd is not None + assert data.face is not None + + return trimesh.Trimesh( + vertices=data.crd.asnumpy(), + faces=data.face.T.asnumpy(), + process=False, + ) + + +def from_trimesh(mesh: Any) -> "sharker.data.Graph": + r"""Converts a :obj:`trimesh.Trimesh` to a + :class:`sharker.data.Graph` instance. + + Args: + mesh (trimesh.Trimesh): A :obj:`trimesh` mesh. + + Example: + >>> crd = Tensor [1, 0, 0], [0, 1, 0], [1, 1, 0]], + ... dtype=ms.float32) + >>> face = Tensor([[0, 1, 2], [1, 2, 3]]).t() + + >>> data = Graph(crd=crd, face=face) + >>> mesh = to_trimesh(data) + >>> from_trimesh(mesh) + Graph(crd=[4, 3], face=[3, 2]) + """ + from ..data import Graph + + crd = Tensor.from_numpy(mesh.vertices).float() + face = Tensor.from_numpy(mesh.faces).T + + return Graph(crd=crd, face=face) + + diff --git a/mindscience/sharker/utils/degree.py b/mindscience/sharker/utils/degree.py new file mode 100644 index 000000000..d721831bc --- /dev/null +++ b/mindscience/sharker/utils/degree.py @@ -0,0 +1,31 @@ +from typing import Optional + +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from .num_nodes import maybe_num_nodes + + +def degree( + index: Tensor, num_nodes: Optional[int] = None, dtype: Optional[ms.Type] = None +) -> Tensor: + r"""Computes the (unweighted) degree of a given one-dimensional index + tensor. + + Args: + index (Tensor): Index tensor. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + dtype (:obj:`ms.dtype`, optional): The desired data type of the + returned tensor. + + :rtype: :class:`Tensor` + + Example: + >>> row = Tensor([0, 1, 0, 2, 0]) + >>> degree(row, dtype=ms.int64) + tensor([3, 1, 1]) + """ + N = maybe_num_nodes(index, num_nodes) + one = mint.ones((index.shape[0],), dtype=dtype) + out = ops.unsorted_segment_sum(one, index, N) + return out diff --git a/mindscience/sharker/utils/dropout.py b/mindscience/sharker/utils/dropout.py new file mode 100644 index 000000000..602a643da --- /dev/null +++ b/mindscience/sharker/utils/dropout.py @@ -0,0 +1,151 @@ +from typing import Optional, Tuple + +import mindspore as ms +from mindspore import Tensor, ops, nn, mint, Generator + +from .subgraph import subgraph +from .num_nodes import maybe_num_nodes + + +def filter_adj( + row: Tensor, col: Tensor, edge_attr: Optional[Tensor], mask: Tensor +) -> Tuple[Tensor, Tensor, Optional[Tensor]]: + """_summary_ + + Args: + row (Tensor): _description_ + col (Tensor): _description_ + edge_attr (Optional[Tensor]): _description_ + mask (Tensor): _description_ + + Returns: + Tuple[Tensor, Tensor, Optional[Tensor]]: _description_ + """ + return row[mask], col[mask], None if edge_attr is None else edge_attr[mask] + + +def dropout_node( + edge_index: Tensor, + p: float = 0.5, + num_nodes: Optional[int] = None, + training: bool = True, + relabel_nodes: bool = False, +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Randomly drops nodes from the adjacency matrix + :obj:`edge_index` with probability :obj:`p` using samples from + a Bernoulli distribution. + + The method returns (1) the retained :obj:`edge_index`, (2) the edge mask + indicating which edges were retained. (3) the node mask indicating + which nodes were retained. + + Args: + edge_index (LongTensor): The edge indices. + p (float, optional): Dropout probability. (default: :obj:`0.5`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + training (bool, optional): If set to :obj:`False`, this operation is a + no-op. (default: :obj:`True`) + relabel_nodes (bool, optional): If set to `True`, the resulting + `edge_index` will be relabeled to hold consecutive indices + starting from zero. + + :rtype: (:class:`LongTensor`, :class:`BoolTensor`, :class:`BoolTensor`) + + Examples: + >>> edge_index = Tensor([[0, 1, 1, 2, 2, 3], + ... [1, 0, 2, 1, 3, 2]]) + >>> edge_index, edge_mask, node_mask = dropout_node(edge_index) + >>> edge_index + tensor([[0, 1], + [1, 0]]) + >>> edge_mask + tensor([ True, True, False, False, False, False]) + >>> node_mask + tensor([ True, True, False, False]) + """ + if p < 0.0 or p > 1.0: + raise ValueError(f"Dropout probability has to be between 0 and 1 " f"(got {p}") + + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + if not training or p == 0.0: + node_mask = mint.ones(num_nodes).bool() + edge_mask = mint.ones(edge_index.shape[1]).bool() + return edge_index, edge_mask, node_mask + + prob = ops.rand(num_nodes) + node_mask = prob > p + edge_index, _, edge_mask = subgraph( + node_mask, + edge_index, + relabel_nodes=relabel_nodes, + num_nodes=num_nodes, + return_edge_mask=True, + ) + return edge_index, edge_mask, node_mask + + +def dropout_edge( + edge_index: Tensor, + p: float = 0.5, + force_undirected: bool = False, + training: bool = True, +) -> Tuple[Tensor, Tensor]: + r"""Randomly drops edges from the adjacency matrix + :obj:`edge_index` with probability :obj:`p` using samples from + a Bernoulli distribution. + + The method returns (1) the retained :obj:`edge_index`, (2) the edge mask + or index indicating which edges were retained, depending on the argument + :obj:`force_undirected`. + + Args: + edge_index (LongTensor): The edge indices. + p (float, optional): Dropout probability. (default: :obj:`0.5`) + force_undirected (bool, optional): If set to :obj:`True`, will either + drop or keep both edges of an undirected edge. + (default: :obj:`False`) + training (bool, optional): If set to :obj:`False`, this operation is a + no-op. (default: :obj:`True`) + + :rtype: (:class:`LongTensor`, :class:`BoolTensor` or :class:`LongTensor`) + + Examples: + >>> edge_index = Tensor([[0, 1, 1, 2, 2, 3], + ... [1, 0, 2, 1, 3, 2]]) + >>> edge_index, edge_mask = dropout_edge(edge_index) + >>> edge_index + tensor([[0, 1, 2, 2], + [1, 2, 1, 3]]) + >>> edge_mask # masks indicating which edges are retained + tensor([ True, False, True, True, True, False]) + + >>> edge_index, edge_id = dropout_edge(edge_index, + ... force_undirected=True) + >>> edge_index + tensor([[0, 1, 2, 1, 2, 3], + [1, 2, 3, 0, 1, 2]]) + >>> edge_id # indices indicating which edges are retained + tensor([0, 2, 4, 0, 2, 4]) + """ + if p < 0.0 or p > 1.0: + raise ValueError(f"Dropout probability has to be between 0 and 1 " f"(got {p}") + + if not training or p == 0.0: + edge_mask = edge_index.new_ones(edge_index.shape[1], dtype=ms.bool_) + return edge_index, edge_mask + + row, col = edge_index + edge_mask = ops.rand(row.shape[0]) >= p + + if force_undirected: + edge_mask[row > col] = False + + edge_index = edge_index[:, edge_mask] + + if force_undirected: + edge_index = mint.cat([edge_index, edge_index.flip((0, ))], dim=1) + edge_mask = edge_mask.nonzero().tile((2, 1)).squeeze() + + return edge_index, edge_mask diff --git a/mindscience/sharker/utils/embedding.py b/mindscience/sharker/utils/embedding.py new file mode 100644 index 000000000..aa3deaa69 --- /dev/null +++ b/mindscience/sharker/utils/embedding.py @@ -0,0 +1,54 @@ +import warnings +from typing import Any, List + +from mindspore import nn +from mindspore import Tensor, ops, nn + + +def get_embeddings( + model: nn.Cell, + *args: Any, + **kwargs: Any, +) -> List[Tensor]: + """Returns the output embeddings of all + :class:`~sharker.nn.conv.MessagePassing` layers in + :obj:`model`. + + Internally, this method registers forward hooks on all + :class:`~sharker.nn.conv.MessagePassing` layers of a :obj:`model`, + and runs the forward pass of the :obj:`model` by calling + :obj:`model(*args, **kwargs)`. + + Args: + model (nn.Cell): The message passing model. + *args: Arguments passed to the model. + **kwargs (optional): Additional keyword arguments passed to the model. + """ + from ..nn import MessagePassing + + embeddings: List[Tensor] = [] + + def hook(model: nn.Cell, inputs: Any, outputs: Any) -> None: + # Clone output in case it will be later modified in-place: + outputs = outputs[0] if isinstance(outputs, tuple) else outputs + assert isinstance(outputs, Tensor) + embeddings.append(outputs.copy()) + + hook_handles = [] + for module in model.cells(): # Register forward hooks: + if isinstance(module, MessagePassing): + hook_handles.append(module.register_forward_hook(hook)) + + if len(hook_handles) == 0: + warnings.warn("The 'model' does not have any 'MessagePassing' layers") + + training = model.training + model.set_train(False) + # model.eval() + model(*args, **kwargs) + model.set_train(training) + + for handle in hook_handles: # Remove hooks: + handle.remove() + + return embeddings diff --git a/mindscience/sharker/utils/functions.py b/mindscience/sharker/utils/functions.py new file mode 100644 index 000000000..16d8cba6e --- /dev/null +++ b/mindscience/sharker/utils/functions.py @@ -0,0 +1,76 @@ +import mindspore as ms +from typing import Union +import numpy as np +import mindspore as ms +from mindspore import mint, Tensor + +def cumsum(x: ms.Tensor, axis: int = 0) -> ms.Tensor: + r"""Returns the cumulative sum of elements of :obj:`x`. + In contrast to :meth:`mindspore.cumsum`, prepends the output with zero. + + Args: + x (Tensor): The input tensor. + axis (int, optional): The dimension to do the operation over. + (default: :obj:`0`) + + Example: + >>> x = Tensor([2, 4, 1]) + >>> cumsum(x) + tensor([0, 2, 6, 7]) + + """ + size = list(x.shape) + size[axis] = 1 + pad_front = mint.zeros(size, dtype=x.dtype) + x_cum = mint.cumsum(x, dim=axis) + out = mint.cat([pad_front, x_cum], dim=axis) + return out + + +def cumsum_np(x: np.ndarray, axis: int = 0) -> np.ndarray: + temp = np.cumsum(x, axis=axis) + first_shape = list(x.shape) + first_shape[axis] = 1 + first = np.zeros(first_shape, dtype=temp.dtype) + out = np.concatenate([first,temp]) + return out + + +def broadcast_to(index: Tensor, src: Tensor, dim: int, is_dense=False) -> Tensor: + r"""Broadcat the index tensor to obtain the detailed information for the scatter operators. + :param src: The source tensor. + :param index: The indices of elements to scatter. + :param dim: The dim along which to index. + :return: the indices with detailed information. + """ + if dim > src.ndim-1: + raise ValueError(f"`dim` must lay between 0 and {src.ndim-1}") + dim = src.dim() + dim if dim < 0 else dim + index = index.expand_as(swapaxes(src, -1, dim)) + index = swapaxes(index, -1, dim) + if is_dense: + return index + + idx = (index == index).nonzero() + index = index.view(-1) + idx[:, dim] = index + return idx + + +def swapaxes(tensor: Tensor, dim0, dim1): + if dim0 == dim1 or dim1 == dim0 - tensor.ndim or dim0 == dim1 - tensor.ndim or tensor.ndim < 2: + return tensor + else: + return tensor.swapaxes(dim0, dim1) + + +def index_fill(tensor, index, value, dim=0): + tensor = swapaxes(tensor, 0, dim) + tensor[index] = swapaxes(value, 0, dim) + return swapaxes(tensor, 0, dim) + + +def index_select(tensor, mask, dim=0): + tensor = swapaxes(tensor, 0, dim) + out = tensor[mask] + return swapaxes(out, 0, dim) diff --git a/mindscience/sharker/utils/grid.py b/mindscience/sharker/utils/grid.py new file mode 100644 index 000000000..6d35fddd1 --- /dev/null +++ b/mindscience/sharker/utils/grid.py @@ -0,0 +1,74 @@ +from typing import Optional, Tuple + +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from .coalesce import coalesce + + +def grid( + height: int, + width: int, + dtype: Optional[ms.Type] = None, +) -> Tuple[Tensor, Tensor]: + r"""Returns the edge indices of a two-dimensional grid graph with height + :attr:`height` and width :attr:`width` and its node positions. + + Args: + height (int): The height of the grid. + width (int): The width of the grid. + dtype (ms.Type, optional): The desired data type of the returned + position tensor. (default: :obj:`None`) + + :rtype: (:class:`LongTensor`, :class:`Tensor`) + + Example: + >>> (row, col), pos = grid(height=2, width=2) + >>> row + tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3]) + >>> col + tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]) + >>> pos + tensor([[0., 1.], + [1., 1.], + [0., 0.], + [1., 0.]]) + """ + edge_index = grid_index(height, width) + pos = grid_pos(height, width, dtype) + return edge_index, pos + + +def grid_index( + height: int, + width: int, +) -> Tensor: + + w = width + kernel = ops.Tensor([-w - 1, -1, w - 1, -w, 0, w, -w + 1, 1, w + 1]) + + row = mint.arange(height * width).long() + row = row.view(-1, 1).tile((1, kernel.shape[0])) + col = row + kernel.view(1, -1) + row, col = row.view(height, -1), col.view(height, -1) + index = mint.arange(3, row.shape[-1] - 3).long() + row, col = row[:, index].view(-1), col[:, index].view(-1) + + mask = mint.logical_and((col >= 0), (col < height * width)) + row, col = row[mask], col[mask] + + edge_index = mint.stack([row, col], dim=0) + edge_index = coalesce(edge_index, num_nodes=height * width) + return edge_index + + +def grid_pos(height: int, width: int, dtype: Optional[ms.Type] = None) -> Tensor: + + dtype = ms.float32 if dtype is None else dtype + x = mint.arange(width, dtype=dtype) + y = (height - 1) - mint.arange(height, dtype=dtype) + + x = x.tile((height,)) + y = y.unsqueeze(-1).tile((1, width)).view(-1) + + return mint.stack([x, y], dim=-1) + diff --git a/mindscience/sharker/utils/hetero.py b/mindscience/sharker/utils/hetero.py new file mode 100644 index 000000000..7c5c9fd80 --- /dev/null +++ b/mindscience/sharker/utils/hetero.py @@ -0,0 +1,132 @@ +from typing import Dict, List, Optional, Set, Tuple, Union + +import mindspore as ms +from mindspore import Tensor, ops, mint + +from .num_nodes import maybe_num_nodes_dict + + +def group_hetero_graph( + edge_index_dict: Dict[Tuple[str, str, str], Tensor], + num_nodes_dict: Optional[Dict[str, int]] = None, +) -> Tuple[ + Tensor, + Tensor, + Tensor, + Tensor, + Dict[Union[str, int], Tensor], + Dict[Union[str, Tuple[str, str, str]], int], +]: + num_nodes_dict = maybe_num_nodes_dict(edge_index_dict, num_nodes_dict) + + tmp = list(edge_index_dict.values())[0] + + key2int: Dict[Union[str, Tuple[str, str, str]], int] = {} + + cumsum, offset = 0, {} # Helper data. + node_types, local_node_indices = [], [] + local2global: Dict[Union[str, int], Tensor] = {} + for i, (key, N) in enumerate(num_nodes_dict.items()): + key2int[key] = i + node_types.append(ops.full((N,), i), dtype=tmp.dtype) + local_node_indices.append(mint.arange(N)) + offset[key] = cumsum + local2global[key] = local_node_indices[-1] + cumsum + local2global[i] = local2global[key] + cumsum += N + + node_type = mint.cat(node_types, dim=0) + local_node_idx = mint.cat(local_node_indices, dim=0) + + edge_indices, edge_types = [], [] + for i, (keys, edge_index) in enumerate(edge_index_dict.items()): + key2int[keys] = i + inc = ms.Tensor([offset[keys[0]], offset[keys[-1]]]).view(2, 1) + edge_indices.append(edge_index + inc) + edge_types.append(ops.full((edge_index.shape[1],), i), dtype=tmp.dtype) + + edge_index = mint.cat(edge_indices, dim=-1) + edge_type = mint.cat(edge_types, dim=0) + + return ( + edge_index, + edge_type, + node_type, + local_node_idx, + local2global, + key2int, + ) + + +def get_unused_node_types( + node_types: List[str], edge_types: List[Tuple[str, str, str]] +) -> Set[str]: + dst_node_types = set(edge_type[-1] for edge_type in edge_types) + return set(node_types) - set(dst_node_types) + + +def check_add_self_loops( + module: ms.nn.Cell, + edge_types: List[Tuple[str, str, str]], +) -> None: + is_bipartite = any([key[0] != key[-1] for key in edge_types]) + if is_bipartite and getattr(module, "add_self_loops", False): + raise ValueError( + f"'add_self_loops' attribute set to 'True' on module '{module}' " + f"for use with edge type(s) '{edge_types}'. This will lead to " + f"incorrect message passing results." + ) + + +def construct_bipartite_edge_index( + edge_index_dict: Dict[Tuple[str, str, str], Union[Tensor,]], + src_offset_dict: Dict[Tuple[str, str, str], int], + dst_offset_dict: Dict[str, int], + edge_attr_dict: Optional[Dict[Tuple[str, str, str], Tensor]] = None, + num_nodes: Optional[int] = None, +) -> Tuple[Union[Tensor,], Optional[Tensor]]: + """Constructs a tensor of edge indices by concatenating edge indices + for each edge type. The edge indices are increased by the offset of the + source and destination nodes. + + Args: + edge_index_dict (Dict[Tuple[str, str, str], Tensor]): A + dictionary holding graph connectivity information for each + individual edge type, either as a :class:`Tensor` of + shape :obj:`[2, num_edges]` or a + :class:`mindGeometric_sparse.SparseTensor`. + src_offset_dict (Dict[Tuple[str, str, str], int]): A dictionary of + offsets to apply to the source node type for each edge type. + dst_offset_dict (Dict[str, int]): A dictionary of offsets to apply for + destination node types. + edge_attr_dict (Dict[Tuple[str, str, str], Tensor]): A + dictionary holding edge features for each individual edge type. + (default: :obj:`None`) + num_nodes (int, optional): The final number of nodes in the bipartite + adjacency matrix. (default: :obj:`None`) + """ + edge_indices: List[Tensor] = [] + edge_attrs: List[Tensor] = [] + for edge_type, src_offset in src_offset_dict.items(): + edge_index = edge_index_dict[edge_type] + dst_offset = dst_offset_dict[edge_type[-1]] + + edge_index = edge_index.copy() + + edge_index[0] += src_offset + edge_index[1] += dst_offset + edge_indices.append(edge_index) + + if edge_attr_dict is not None: + value = edge_attr_dict[edge_type] + if value.shape[0] != edge_index.shape[1]: + value = value.broadcast_to((edge_index.shape[1], -1)) + edge_attrs.append(value) + + edge_index = mint.cat(edge_indices, dim=1) + + edge_attr: Optional[Tensor] = None + if edge_attr_dict is not None: + edge_attr = mint.cat(edge_attrs, dim=0) + + return edge_index, edge_attr diff --git a/mindscience/sharker/utils/homophily.py b/mindscience/sharker/utils/homophily.py new file mode 100644 index 000000000..6d0ed06eb --- /dev/null +++ b/mindscience/sharker/utils/homophily.py @@ -0,0 +1,128 @@ +from typing import Union, Optional + +import mindspore as ms +from mindspore import Tensor, ops, mint +from .degree import degree +from . import scatter + + +def homophily( + edge_index: Union[Tensor, ], + y: Tensor, + batch: Optional[Tensor] = None, + method: str = "edge", +) -> Union[float, Tensor]: + r"""The homophily of a graph characterizes how likely nodes with the same + label are near each other in a graph. + + There are many measures of homophily that fits this definition. + In particular: + + - In the `"Beyond Homophily in Graph Neural Networks: Current Limitations + and Effective Designs" `_ paper, the + homophily is the fraction of edges in a graph which connects nodes + that have the same class label: + + .. math:: + \frac{| \{ (v,w) : (v,w) \in \mathcal{E} \wedge y_v = y_w \} | } + {|\mathcal{E}|} + + That measure is called the *edge homophily ratio*. + + - In the `"Geom-GCN: Geometric Graph Convolutional Networks" + `_ paper, edge homophily is normalized + across neighborhoods: + + .. math:: + \frac{1}{|\mathcal{V}|} \sum_{v \in \mathcal{V}} \frac{ | \{ (w,v) : w + \in \mathcal{N}(v) \wedge y_v = y_w \} | } { |\mathcal{N}(v)| } + + That measure is called the *node homophily ratio*. + + - In the `"Large-Scale Learning on Non-Homophilous Graphs: New Benchmarks + and Strong Simple Methods" `_ paper, + edge homophily is modified to be insensitive to the number of classes + and size of each class: + + .. math:: + \frac{1}{C-1} \sum_{k=1}^{C} \max \left(0, h_k - \frac{|\mathcal{C}_k|} + {|\mathcal{V}|} \right), + + where :math:`C` denotes the number of classes, :math:`|\mathcal{C}_k|` + denotes the number of nodes of class :math:`k`, and :math:`h_k` denotes + the edge homophily ratio of nodes of class :math:`k`. + + Thus, that measure is called the *class insensitive edge homophily + ratio*. + + Args: + edge_index (Tensor or SparseTensor): The graph connectivity. + y (Tensor): The labels. + batch (LongTensor, optional): Batch vector\ + :math:`\mathbf{b} \in {\{ 0, \ldots,B-1\}}^N`, which assigns + each node to a specific example. (default: :obj:`None`) + method (str, optional): The method used to calculate the homophily, + either :obj:`"edge"` (first formula), :obj:`"node"` (second + formula) or :obj:`"edge_insensitive"` (third formula). + (default: :obj:`"edge"`) + + Examples: + >>> edge_index = Tensor([[0, 1, 2, 3], + ... [1, 2, 0, 4]]) + >>> y = Tensor, 0, 0, 1]) + >>> # Edge homophily ratio + >>> homophily(edge_index, y, method='edge') + 0.75 + + >>> # Node homophily ratio + >>> homophily(edge_index, y, method='node') + 0.6000000238418579 + + >>> # Class insensitive edge homophily ratio + >>> homophily(edge_index, y, method='edge_insensitive') + 0.19999998807907104 + """ + assert method in {"edge", "node", "edge_insensitive"} + y = y.squeeze(-1) if y.dim() > 1 else y + row, col = edge_index + + if method == "edge": + out = mint.zeros(row.shape[0]) + out[y[row] == y[col]] = 1.0 + if batch is None: + return float(out.mean()) + else: + dim_size = int(batch.max()) + 1 + return scatter(out, batch[col], 0, dim_size, reduce="mean") + + elif method == "node": + out = mint.zeros(row.shape[0]) + out[y[row] == y[col]] = 1.0 + out = scatter(out, col, 0, dim_size=y.shape[0], reduce="mean") + if batch is None: + return float(out.mean()) + else: + return scatter(out, batch, dim=0, reduce="mean") + + elif method == "edge_insensitive": + assert y.dim() == 1 + num_classes = int(y.max()) + 1 + assert num_classes >= 2 + batch = mint.zeros_like(y) if batch is None else batch + num_nodes = degree(batch, dtype=ms.int64) + num_graphs = num_nodes.numel() + batch = num_classes * batch + y + + h = homophily(edge_index, y, batch, method="edge") + h = h.view(num_graphs, num_classes) + + counts = batch.bincount(minlength=num_classes * num_graphs) + counts = counts.view(num_graphs, num_classes) + proportions = counts / num_nodes.view(-1, 1) + + out = (h - proportions).clamp(min=0).sum(-1) + out /= num_classes - 1 + return out if out.numel() > 1 else float(out) + + else: + raise NotImplementedError diff --git a/mindscience/sharker/utils/isolated.py b/mindscience/sharker/utils/isolated.py new file mode 100644 index 000000000..8a7fa707c --- /dev/null +++ b/mindscience/sharker/utils/isolated.py @@ -0,0 +1,96 @@ +from typing import Optional, Tuple + +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from .loop import remove_self_loops, segregate_self_loops +from .num_nodes import maybe_num_nodes + + +def contains_isolated_nodes( + edge_index: Tensor, + num_nodes: Optional[int] = None, +) -> bool: + r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains + isolated nodes. + + Args: + edge_index (LongTensor): The edge indices. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + + :rtype: bool + + Examples: + >>> edge_index = Tensor([[0, 1, 0], + ... [1, 0, 0]]) + >>> contains_isolated_nodes(edge_index) + False + + >>> contains_isolated_nodes(edge_index, num_nodes=3) + True + """ + num_nodes = maybe_num_nodes(edge_index, num_nodes) + edge_index, _ = remove_self_loops(edge_index) + return mint.unique(edge_index.view(-1)).numel() < num_nodes + + +def remove_isolated_nodes( + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + num_nodes: Optional[int] = None, +) -> Tuple[Tensor, Optional[Tensor], Tensor]: + r"""Removes the isolated nodes from the graph given by :attr:`edge_index` + with optional edge attributes :attr:`edge_attr`. + In addition, returns a mask of shape :obj:`[num_nodes]` to manually filter + out isolated node features later on. + Self-loops are preserved for non-isolated nodes. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional + edge features. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + + :rtype: (LongTensor, Tensor, BoolTensor) + + Examples: + >>> edge_index = Tensor + ... [1, 0, 0]]) + >>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index) + >>> mask # node mask (2 nodes) + tensor([True, True]) + + >>> edge_index, edge_attr, mask = remove_isolated_nodes(edge_index, + ... num_nodes=3) + >>> mask # node mask (3 nodes) + tensor([True, True, False]) + """ + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + out = segregate_self_loops(edge_index, edge_attr) + edge_index, edge_attr, loop_edge_index, loop_edge_attr = out + + mask = mint.zeros(num_nodes, dtype=ms.bool_) + mask[edge_index.view(-1)] = 1 + + assoc = ops.full((num_nodes,), -1, dtype=ms.int64) + assoc[mask] = mint.arange(mask.sum().item()) # type: ignore + edge_index = assoc[edge_index] + + loop_mask = mint.zeros_like(mask) + loop_mask[loop_edge_index[0]] = 1 + loop_mask = mint.logical_and(loop_mask, mask) + loop_assoc = ops.full_like(assoc, -1) + loop_assoc[loop_edge_index[0]] = mint.arange(loop_edge_index.shape[1]) + loop_idx = loop_assoc[loop_mask] + loop_edge_index = assoc[loop_edge_index[:, loop_idx]] + + edge_index = mint.cat([edge_index, loop_edge_index], dim=1) + + if edge_attr is not None: + assert loop_edge_attr is not None + loop_edge_attr = loop_edge_attr[loop_idx] + edge_attr = mint.cat([edge_attr, loop_edge_attr], dim=0) + + return edge_index, edge_attr, mask diff --git a/mindscience/sharker/utils/laplacian.py b/mindscience/sharker/utils/laplacian.py new file mode 100644 index 000000000..7fbee1fe0 --- /dev/null +++ b/mindscience/sharker/utils/laplacian.py @@ -0,0 +1,207 @@ +from typing import Optional, Tuple + +import mindspore as ms +from mindspore import Tensor, ops, nn, mint + +from .loop import add_self_loops, remove_self_loops +from . import scatter +from .num_nodes import maybe_num_nodes +from .undirected import to_undirected +from .ncon import Ncon + + +def get_laplacian( + edge_index: Tensor, + edge_weight: Optional[Tensor] = None, + normalization: Optional[str] = None, + dtype: Optional[ms.Type] = None, + num_nodes: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + r"""Computes the graph Laplacian of the graph given by :obj:`edge_index` + and optional :obj:`edge_weight`. + + Args: + edge_index (LongTensor): The edge indices. + edge_weight (Tensor, optional): One-dimensional edge weights. + (default: :obj:`None`) + normalization (str, optional): The normalization scheme for the graph + Laplacian (default: :obj:`None`): + + 1. :obj:`None`: No normalization + :math:`\mathbf{L} = \mathbf{D} - \mathbf{A}` + + 2. :obj:`"sym"`: Symmetric normalization + :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1/2} \mathbf{A} + \mathbf{D}^{-1/2}` + + 3. :obj:`"rw"`: Random-walk normalization + :math:`\mathbf{L} = \mathbf{I} - \mathbf{D}^{-1} \mathbf{A}` + dtype (ms.Type, optional): The desired data type of returned tensor + in case :obj:`edge_weight=None`. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + + Examples: + >>> edge_index = Tensor([[0, 1, 1, 2], + ... [1, 0, 2, 1]]) + >>> edge_weight = Tensor([1., 2., 2., 4.]) + + >>> # No normalization + >>> lap = get_laplacian(edge_index, edge_weight) + + >>> # Symmetric normalization + >>> lap_sym = get_laplacian(edge_index, edge_weight, + normalization='sym') + + >>> # Random-walk normalization + >>> lap_rw = get_laplacian(edge_index, edge_weight, normalization='rw') + """ + if normalization is not None: + assert normalization in ["sym", "rw"] # 'Invalid normalization' + + edge_index, edge_weight = remove_self_loops(edge_index, edge_weight) + + if edge_weight is None: + edge_weight = mint.ones(edge_index.shape[1], dtype=dtype) + + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + src, dst = edge_index[0], edge_index[1] + deg = scatter(edge_weight, src, 0, dim_size=num_nodes, reduce="sum") + + if normalization is None: + # L = D - A. + edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes) + edge_weight = mint.cat([-edge_weight, deg], dim=0) + elif normalization == "sym": + # Compute A_norm = -D^{-1/2} A D^{-1/2}. + deg_inv_sqrt = deg**-0.5 + deg_inv_sqrt[deg_inv_sqrt == float("inf")] = 0 + edge_weight = deg_inv_sqrt[src] * edge_weight * deg_inv_sqrt[dst] + + # L = I - A_norm. + assert isinstance(edge_weight, Tensor) + edge_index, edge_weight = add_self_loops( # + edge_index, -edge_weight, fill_value=1.0, num_nodes=num_nodes + ) + else: + # Compute A_norm = -D^{-1} A. + deg_inv = 1.0 / deg + deg_inv[deg_inv == float("inf")] = 0 + edge_weight = deg_inv[src] * edge_weight + + # L = I - A_norm. + assert isinstance(edge_weight, Tensor) + edge_index, edge_weight = add_self_loops( # + edge_index, -edge_weight, fill_value=1.0, num_nodes=num_nodes + ) + + return edge_index, edge_weight + + +def get_mesh_laplacian( + crd: Tensor, + face: Tensor, + normalization: Optional[str] = None, +) -> Tuple[Tensor, Tensor]: + r"""Computes the mesh Laplacian of a mesh given by :obj:`pos` and + :obj:`face`. + + Computation is based on the cotangent matrix defined as + + .. math:: + \mathbf{C}_{ij} = \begin{cases} + \frac{\cot \angle_{ikj}~+\cot \angle_{ilj}}{2} & + \text{if } i, j \text{ is an edge} \\ + -\sum_{j \in N(i)}{C_{ij}} & + \text{if } i \text{ is in the diagonal} \\ + 0 & \text{otherwise} + \end{cases} + + Normalization depends on the mass matrix defined as + + .. math:: + \mathbf{M}_{ij} = \begin{cases} + a(i) & \text{if } i \text{ is in the diagonal} \\ + 0 & \text{otherwise} + \end{cases} + + where :math:`a(i)` is obtained by joining the barycenters of the + triangles around vertex :math:`i`. + + Args: + crd (Tensor): The node positions. + face (LongTensor): The face indices. + normalization (str, optional): The normalization scheme for the mesh + Laplacian (default: :obj:`None`): + + 1. :obj:`None`: No normalization + :math:`\mathbf{L} = \mathbf{C}` + + 2. :obj:`"sym"`: Symmetric normalization + :math:`\mathbf{L} = \mathbf{M}^{-1/2} \mathbf{C}\mathbf{M}^{-1/2}` + + 3. :obj:`"rw"`: Row-wise normalization + :math:`\mathbf{L} = \mathbf{M}^{-1} \mathbf{C}` + """ + assert crd.shape[1] == 3 and face.shape[0] == 3 + + num_nodes = crd.shape[0] + + def get_cots(left: Tensor, centre: Tensor, right: Tensor) -> Tensor: + left_pos, central_pos, right_pos = crd[left], crd[centre], crd[right] + left_vec = left_pos - central_pos + right_vec = right_pos - central_pos + dot = Ncon([[-1, 1], [-1, 1]])([left_vec, right_vec]) + cross = ms.numpy.norm(ms.numpy.cross(left_vec, right_vec, axis=1), axis=1) + cot = dot / cross # cot = cos / sin + return cot / 2.0 # by definition + + # For each triangle face, get all three cotangents: + cot_021 = get_cots(face[0], face[2], face[1]) + cot_102 = get_cots(face[1], face[0], face[2]) + cot_012 = get_cots(face[0], face[1], face[2]) + cot_weight = mint.cat([cot_021, cot_102, cot_012]) + + # Face to edge: + cot_index = mint.cat([face[:2], face[1:], face[::2]], dim=1) + cot_index, cot_weight = to_undirected(cot_index, cot_weight) + + # Compute the diagonal part: + cot_deg = scatter(cot_weight, cot_index[0], 0, num_nodes, reduce="sum") + edge_index, _ = add_self_loops(cot_index, num_nodes=num_nodes) + edge_weight = mint.cat([cot_weight, -cot_deg], dim=0) + + if normalization is not None: + + def get_areas(left: Tensor, centre: Tensor, right: Tensor) -> Tensor: + central_pos = crd[centre] + left_vec = crd[left] - central_pos + right_vec = crd[right] - central_pos + cross = ms.numpy.norm(ms.numpy.cross(left_vec, right_vec, axis=1), axis=1) + area = cross / 6.0 # one-third of a triangle's area is cross / 6.0 + return area / 2.0 # since each corresponding area is counted twice + + # Like before, but here we only need the diagonal (the mass matrix): + area_021 = get_areas(face[0], face[2], face[1]) + area_102 = get_areas(face[1], face[0], face[2]) + area_012 = get_areas(face[0], face[1], face[2]) + area_weight = mint.cat([area_021, area_102, area_012]) + area_index = mint.cat([face[:2], face[1:], face[::2]], dim=1) + area_index, area_weight = to_undirected(area_index, area_weight) + area_deg = scatter(area_weight, area_index[0], 0, num_nodes, "sum") + + if normalization == "sym": + area_deg_inv_sqrt = area_deg**-0.5 + area_deg_inv_sqrt[area_deg_inv_sqrt == float("inf")] = 0.0 + edge_weight = ( + area_deg_inv_sqrt[edge_index[0]] + * edge_weight + * area_deg_inv_sqrt[edge_index[1]] + ) + elif normalization == "rw": + area_deg_inv = 1.0 / area_deg + area_deg_inv[area_deg_inv == float("inf")] = 0.0 + edge_weight = area_deg_inv[edge_index[0]] * edge_weight + + return edge_index, edge_weight diff --git a/mindscience/sharker/utils/loop.py b/mindscience/sharker/utils/loop.py new file mode 100644 index 000000000..48024f18e --- /dev/null +++ b/mindscience/sharker/utils/loop.py @@ -0,0 +1,322 @@ +from typing import Optional, Tuple, Union + +from mindspore import Tensor, COOTensor, CSRTensor, ops, mint + +from . import scatter +from .num_nodes import maybe_num_nodes + + +def contains_self_loops(edge_index: Tensor) -> bool: + r"""Returns :obj:`True` if the graph given by :attr:`edge_index` contains + self-loops. + + Args: + edge_index (LongTensor): The edge indices. + + :rtype: bool + + Examples: + >>> edge_index = Tensor([[0, 1, 0], + ... [1, 0, 0]]) + >>> contains_self_loops(edge_index) + True + + >>> edge_index = Tensor([[0, 1, 1], + ... [1, 0, 2]]) + >>> contains_self_loops(edge_index) + False + """ + mask = edge_index[0] == edge_index[1] + return mask.sum().item() > 0 + + +def remove_self_loops( + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor]]: + r"""Removes every self-loop in the graph given by :attr:`edge_index`, so + that :math:`(i,i) \not\in \mathcal{E}` for every :math:`i \in \mathcal{V}`. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional + edge features. (default: :obj:`None`) + + :rtype: (:class:`LongTensor`, :class:`Tensor`) + + Example: + >>> edge_index = Tensor([[0, 1, 0], + ... [1, 0, 0]]) + >>> edge_attr = [[1, 2], [3, 4], [5, 6]] + >>> edge_attr = Tensor(edge_attr) + >>> remove_self_loops(edge_index, edge_attr) + (tensor([[0, 1], + [1, 0]]), + tensor([[1, 2], + [3, 4]])) + """ + value = None + mask = edge_index[0] != edge_index[1] + edge_index = edge_index[:, mask] + + if value is not None: + value = value[mask] + + if edge_attr is None: + return edge_index, None + else: + return edge_index, edge_attr[mask] + + +def segregate_self_loops( + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, +) -> Tuple[Tensor, Optional[Tensor], Tensor, Optional[Tensor]]: + r"""Segregates self-loops from the graph. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional + edge features. (default: :obj:`None`) + + :rtype: (:class:`LongTensor`, :class:`Tensor`, :class:`LongTensor`, + :class:`Tensor`) + + Example: + >>> edge_index = Tensor([[0, 0, 1], + ... [0, 1, 0]]) + >>> (edge_index, edge_attr, + ... loop_edge_index, + ... loop_edge_attr) = segregate_self_loops(edge_index) + >>> loop_edge_index + tensor([[0], + [0]]) + """ + mask = mint.ne(edge_index[0],edge_index[1]) + inv_mask = mint.logical_not(mask) + loop_edge_index = mint.index_select(edge_index, 1, mint.nonzero(inv_mask).view(-1)) + loop_edge_attr = None if edge_attr is None else edge_attr[inv_mask] + edge_index = mint.index_select(edge_index, 1, mint.nonzero(mask).view(-1)) + edge_attr = None if edge_attr is None else edge_attr[mask] + + return edge_index, edge_attr, loop_edge_index, loop_edge_attr + + +def add_self_loops( + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + fill_value: Optional[Union[float, Tensor, str]] = None, + num_nodes: Optional[Union[int, Tuple[int, int]]] = None, +) -> Tuple[Tensor, Optional[Tensor]]: + r"""Adds a self-loop :math:`(i,i) \in \mathcal{E}` to every node + :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`. + In case the graph is weighted or has multi-dimensional edge features + (:obj:`edge_attr != None`), edge features of self-loops will be added + according to :obj:`fill_value`. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional edge + features. (default: :obj:`None`) + fill_value (float or Tensor or str, optional): The way to generate + edge features of self-loops (in case :obj:`edge_attr != None`). + If given as :obj:`float` or :class:`Tensor`, edge features of + self-loops will be directly given by :obj:`fill_value`. + If given as :obj:`str`, edge features of self-loops are computed by + aggregating all features of edges that point to the specific node, + according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, + :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`1.`) + num_nodes (int or Tuple[int, int], optional): The number of nodes, + *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. + If given as a tuple, then :obj:`edge_index` is interpreted as a + bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`. + (default: :obj:`None`) + + :rtype: (:class:`LongTensor`, :class:`Tensor`) + + Examples: + >>> edge_index = Tensor([[0, 1, 0], + ... [1, 0, 0]]) + >>> edge_weight = Tensor0.5]) + >>> add_self_loops(edge_index) + (tensor([[0, 1, 0, 0, 1], + [1, 0, 0, 0, 1]]), + None) + + >>> add_self_loops(edge_index, edge_weight) + (tensor([[0, 1, 0, 0, 1], + [1, 0, 0, 0, 1]]), + tensor([0.5000, 0.5000, 0.5000, 1.0000, 1.0000])) + + >>> # edge features of self-loops are filled by constant `2.0` + >>> add_self_loops(edge_index, edge_weight, + ... fill_value=2.) + (tensor([[0, 1, 0, 0, 1], + [1, 0, 0, 0, 1]]), + tensor([0.5000, 0.5000, 0.5000, 2.0000, 2.0000])) + + >>> # Use 'add' operation to merge edge features for self-loops + >>> add_self_loops(edge_index, edge_weight, + ... fill_value='add') + (tensor([[0, 1, 0, 0, 1], + [1, 0, 0, 0, 1]]), + tensor([0.5000, 0.5000, 0.5000, 1.0000, 0.5000])) + """ + value = None + + if isinstance(num_nodes, (tuple, list)): + size = (num_nodes[0], num_nodes[1]) + N = min(size) + else: + N = maybe_num_nodes(edge_index, num_nodes) + size = (N, N) + + loop_index = mint.arange(N).view(1, -1).tile((2, 1)) + + full_edge_index = mint.cat([edge_index, loop_index], dim=1) + + if edge_attr is not None: + loop_attr = compute_loop_attr(edge_index, edge_attr, N, fill_value) # + edge_attr = mint.cat([edge_attr, loop_attr], dim=0) + + return full_edge_index, edge_attr + + +def compute_loop_attr( + edge_index: Tensor, + edge_attr: Tensor, + num_nodes: int, + fill_value: Optional[Union[float, Tensor, str]] = None, +) -> Tensor: + + if fill_value is None: + size = (num_nodes,) + edge_attr.shape[1:] + return mint.ones(size, dtype=edge_attr.dtype) + + elif isinstance(fill_value, (int, float)): + size = (num_nodes,) + edge_attr.shape[1:] + return ops.full(size, fill_value, dtype=edge_attr.dtype) + + elif isinstance(fill_value, Tensor): + size = (num_nodes,) + edge_attr.shape[1:] + loop_attr = fill_value.astype(edge_attr.dtype) + if edge_attr.dim() != loop_attr.dim(): + loop_attr = loop_attr.unsqueeze(0) + return loop_attr.broadcast_to(size) + + elif isinstance(fill_value, str): + col = edge_index[1] + return scatter(edge_attr, col, 0, num_nodes, fill_value) + else: + raise AttributeError("No valid 'fill_value' provided") + + +def add_remaining_self_loops( + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + fill_value: Optional[Union[float, Tensor, str]] = None, + num_nodes: Optional[int] = None, +) -> Tuple[Tensor, Optional[Tensor]]: + r"""Adds remaining self-loop :math:`(i,i) \in \mathcal{E}` to every node + :math:`i \in \mathcal{V}` in the graph given by :attr:`edge_index`. + In case the graph is weighted or has multi-dimensional edge features + (:obj:`edge_attr != None`), edge features of non-existing self-loops will + be added according to :obj:`fill_value`. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional edge + features. (default: :obj:`None`) + fill_value (float or Tensor or str, optional): The way to generate + edge features of self-loops (in case :obj:`edge_attr != None`). + If given as :obj:`float` or :class:`Tensor`, edge features of + self-loops will be directly given by :obj:`fill_value`. + If given as :obj:`str`, edge features of self-loops are computed by + aggregating all features of edges that point to the specific node, + according to a reduce operation. (:obj:`"add"`, :obj:`"mean"`, + :obj:`"min"`, :obj:`"max"`, :obj:`"mul"`). (default: :obj:`1.`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + + :rtype: (:class:`LongTensor`, :class:`Tensor`) + + Example: + >>> edge_index = Tensor([[0, 1], + ... [1, 0]]) + >>> edge_weight = Tensor + >>> add_remaining_self_loops(edge_index, edge_weight) + (tensor([[0, 1, 0, 1], + [1, 0, 0, 1]]), + tensor([0.5000, 0.5000, 1.0000, 1.0000])) + """ + N = maybe_num_nodes(edge_index, num_nodes) + mask = edge_index[0] != edge_index[1] + + loop_index = mint.arange(N).view(1, -1).tile((2, 1)).astype(edge_index.dtype) + + if edge_attr is not None: + + loop_attr = ops.deepcopy(compute_loop_attr(edge_index, edge_attr, N, fill_value)) # + + inv_mask = ~mask + if inv_mask.any(): + loop_attr[edge_index[0][inv_mask]] = edge_attr[inv_mask] + + edge_attr = mint.cat([edge_attr[mask], loop_attr], dim=0) + + edge_index = edge_index[:, mask] + + edge_index = mint.cat([edge_index, loop_index], dim=1) + + return edge_index, edge_attr + + +def get_self_loop_attr( + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + num_nodes: Optional[int] = None, +) -> Tensor: + r"""Returns the edge features or weights of self-loops + :math:`(i, i)` of every node :math:`i \in \mathcal{V}` in the + graph given by :attr:`edge_index`. Edge features of missing self-loops not + present in :attr:`edge_index` will be filled with zeros. If + :attr:`edge_attr` is not given, it will be the vector of ones. + + .. note:: + This operation is analogous to getting the diagonal elements of the + dense adjacency matrix. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional edge + features. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + + Examples: + >>> edge_index = Tensor([[0, 1, 0], + ... [1, 0, 0]]) + >>> edge_weight = Tensor([0.2, 0.3, 0.5]) + >>> get_self_loop_attr(edge_index, edge_weight) + tensor([0.5000, 0.0000]) + + >>> get_self_loop_attr(edge_index, edge_weight, num_nodes=4) + tensor([0.5000, 0.0000, 0.0000, 0.0000]) + """ + loop_mask = edge_index[0] == edge_index[1] + loop_index = edge_index[0][loop_mask] + + if edge_attr is not None: + loop_attr = edge_attr[loop_mask] + else: # A vector of ones: + loop_attr = mint.ones(loop_index.numel()) + + num_nodes = maybe_num_nodes(edge_index, num_nodes) + full_loop_attr = mint.zeros( + (num_nodes,) + loop_attr.shape[1:], dtype=loop_attr.dtype + ) + full_loop_attr[loop_index] = loop_attr + + return full_loop_attr diff --git a/mindscience/sharker/utils/map.py b/mindscience/sharker/utils/map.py new file mode 100644 index 000000000..60dd9eef1 --- /dev/null +++ b/mindscience/sharker/utils/map.py @@ -0,0 +1,169 @@ +from typing import Optional, Tuple, Union + +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, mint + + +def map_index( + src: Tensor, + index: Tensor, + max_index: Optional[Union[int, Tensor]] = None, + inclusive: bool = False, +) -> Tuple[Tensor, Optional[Tensor]]: + r"""Maps indices in :obj:`src` to the positional value of their + corresponding occurence in :obj:`index`. + Indices must be strictly positive. + + Args: + src (Tensor): The source tensor to map. + index (Tensor): The index tensor that denotes the new mapping. + max_index (int, optional): The maximum index value. + (default :obj:`None`) + inclusive (bool, optional): If set to :obj:`True`, it is assumed that + every entry in :obj:`src` has a valid entry in :obj:`index`. + Can speed-up computation. (default: :obj:`False`) + + :rtype: (:class:`Tensor`, :class:`mindspore.BoolTensor`) + + Examples: + >>> src = Tensor([2, 0, 1, 0, 3]) + >>> index = Tensor([3, 2, 0, 1]) + + >>> map_index(src, index) + (tensor([1, 2, 3, 2, 0]), tensor([True, True, True, True, True])) + + >>> src = Tensor([2, 0, 1, 0, 3]) + >>> index = Tensor([3, 2, 0]) + + >>> map_index(src, index) + (tensor([1, 2, 2, 0]), tensor([True, True, False, True, True])) + + .. note:: + + If inputs are on GPU and :obj:`cudf` is available, consider using RMM + for significant speed boosts. + Proceed with caution as RMM may conflict with other allocators or + fragments. + """ + if src.is_floating_point(): + raise ValueError(f"Expected 'src' to be an index (got '{src.dtype}')") + if index.is_floating_point(): + raise ValueError(f"Expected 'index' to be an index (got " f"'{index.dtype}')") + if max_index is None: + max_index = max(src.max(), index.max()) + + # If the `max_index` is in a reasonable range, we can accelerate this + # operation by creating a helper vector to perform the mapping. + # NOTE This will potentially consumes a large chunk of memory + # (max_index=10 million => ~75MB), so we cap it at a reasonable size: + THRESHOLD = 10_000_000 + if max_index <= THRESHOLD: + if inclusive: + assoc = mint.zeros(max_index + 1, dtype=src.dtype) + else: + assoc = -mint.ones(max_index + 1, dtype=src.dtype) + assoc = ms.ops.scatter_update(assoc, index, mint.arange(index.numel(), dtype=src.dtype)) + out = mint.index_select(assoc, 0, src) + + if inclusive: + return out, None + else: + mask = out != -1 + return out[mask], mask + + import pandas as pd + + left_ser = pd.Series(src.asnumpy(), name="left_ser") + right_ser = pd.Series( + index=index.asnumpy(), + data=pd.RangeIndex(0, index.shape[0]), + name="right_ser", + ) + + result = pd.merge( + left_ser, right_ser, how="left", left_on="left_ser", right_index=True + ) + + out_numpy = result["right_ser"].values + + out = Tensor.from_numpy(out_numpy) + + if out.is_floating_point() and inclusive: + raise ValueError( + "Found invalid entries in 'src' that do not have " + "a corresponding entry in 'index'. Set " + "`inclusive=False` to ignore these entries." + ) + + if out.is_floating_point(): + mask = mint.logical_not(ops.isnan(out)) + out = ops.masked_select(out, mask).astype(index.dtype) + return out, mask + + if inclusive: + return out, None + else: + mask = out != -1 + return out[mask], mask + +def map_index_np( + src: np.array, + index: np.array, + max_index: Optional[Union[int, np.array]] = None, + inclusive: bool = False, +) -> Tuple[np.array, Optional[np.array]]: + if src.dtype.kind == 'f': + raise ValueError(f"Expected 'src' to be an index (got '{src.dtype}')") + if index.dtype.kind == 'f': + raise ValueError(f"Expected 'index' to be an index (got " f"'{index.dtype}')") + if max_index is None: + max_index = np.max(src.max(), index.max()) + + THRESHOLD = 10_000_000 + if max_index <= THRESHOLD: + if inclusive: + assoc = np.zeros(max_index + 1, dtype=src.dtype) + else: + assoc = -np.ones(max_index + 1, dtype=src.dtype) + assoc[index] = np.arange(index.size, dtype=src.dtype) + out = assoc[src] + + if inclusive: + return out, None + else: + mask = out != -1 + return out[mask], mask + + import pandas as pd + + left_ser = pd.Series(src, name="left_ser") + right_ser = pd.Series( + index=index, + data=pd.RangeIndex(0, index.shape[0]), + name="right_ser", + ) + + result = pd.merge( + left_ser, right_ser, how="left", left_on="left_ser", right_index=True + ) + + out = result["right_ser"].values + + if out.dtype.kind == 'f' and inclusive: + raise ValueError( + "Found invalid entries in 'src' that do not have " + "a corresponding entry in 'index'. Set " + "`inclusive=False` to ignore these entries." + ) + + if out.dtype.kind == 'f': + mask = ~np.isnan(out) + out = out[mask].astype(index.dtype) + return out, mask + + if inclusive: + return out, None + else: + mask = out != -1 + return out[mask], mask \ No newline at end of file diff --git a/mindscience/sharker/utils/mask.py b/mindscience/sharker/utils/mask.py new file mode 100644 index 000000000..34523fa7e --- /dev/null +++ b/mindscience/sharker/utils/mask.py @@ -0,0 +1,99 @@ +from typing import Optional + +import mindspore as ms +from mindspore import mint +import numpy as np + + +def mask_select(src: ms.Tensor, axis: int, mask: ms.Tensor) -> ms.Tensor: + r"""Returns a new tensor which masks the :obj:`src` tensor along the + dimension :obj:`dim` according to the boolean mask :obj:`mask`. + + Args: + src (Tensor): The input tensor. + dim (int): The dimension in which to mask. + mask (mindspore.BoolTensor): The 1-D tensor containing the binary mask to + index with. + """ + assert mask.dim() == 1 + + assert src.shape[axis] == mask.numel() + axis += src.dim() if axis < 0 else axis + assert axis >= 0 and axis < src.dim() + + # Applying a 1-dimensional mask in the first dimension is significantly + # faster than broadcasting the mask and utilizing `masked_select`. + # As such, we transpose in the first dimension, perform the masking, and + # then transpose back to the original shape. + idx = mint.nonzero(mask).reshape(-1) + if axis != 0: + out = mint.index_select(src, 1, idx) + else: + out = mint.index_select(src, 0, idx) + + return out + +def mask_select_np(src: np.ndarray, axis: int, mask: np.ndarray) -> np.ndarray: + assert mask.ndim == 1 + + assert src.shape[axis] == mask.size + axis += src.ndim if axis < 0 else axis + assert axis >= 0 and axis < src.ndim + + # Applying a 1-dimensional mask in the first dimension is significantly + # faster than broadcasting the mask and utilizing `masked_select`. + # As such, we transpose in the first dimension, perform the masking, and + # then transpose back to the original shape. + src = src.transpose(axis, 0) if axis != 0 else src + out = src[mask] + out = out.transpose(axis, 0) if axis != 0 else out + + return out + + +def index_to_mask(index: ms.Tensor, size: Optional[int] = None) -> ms.Tensor: + r"""Converts indices to a mask representation. + + Args: + index (Tensor): The indices. + size (int, optional): The size of the mask. If set to :obj:`None`, a + minimal sized output mask is returned. + + Example: + >>> index = ms.Tensor([1, 3, 5]) + >>> index_to_mask(index) + tensor([False, True, False, True, False, True]) + + >>> index_to_mask(index, size=7) + tensor([False, True, False, True, False, True, False]) + """ + index = index.view(-1).astype(ms.int32) + size = mint.max(index) + 1 if size is None else size + mask = mint.zeros(size) + mask = ms.ops.index_fill(mask, 0, index, True).astype(ms.bool_) + + return mask + +def index_to_mask_np(index: np.ndarray, size: Optional[int] = None) -> np.ndarray: + index = index.reshape(-1) + size = int(index.max()) + 1 if size is None else size + mask = np.zeros(size) + mask[index] = True + return mask.astype(np.bool_) + + +def mask_to_index(mask: ms.Tensor) -> ms.Tensor: + r"""Converts a mask to an index representation. + + Args: + mask (Tensor): The mask. + + Example: + >>> mask = Tensore, True, False]) + >>> mask_to_index(mask) + tensor([1]) + """ + return mint.nonzero(mask).view(-1) + +def mask_to_index_np(mask: np.array) -> np.array: + return mint.nonzero(mask).view(-1) diff --git a/mindscience/sharker/utils/mixin.py b/mindscience/sharker/utils/mixin.py new file mode 100644 index 000000000..7bda5baa8 --- /dev/null +++ b/mindscience/sharker/utils/mixin.py @@ -0,0 +1,22 @@ +from typing import Any, Iterator, TypeVar + +T = TypeVar('T') + + +class CastMixin: + @classmethod + def cast(cls: T, *args: Any, **kwargs: Any) -> T: + if len(args) == 1 and len(kwargs) == 0: + elem = args[0] + if elem is None: + return None # type: ignore + if isinstance(elem, CastMixin): + return elem # type: ignore + if isinstance(elem, tuple): + return cls(*elem) # type: ignore + if isinstance(elem, dict): + return cls(**elem) # type: ignore + return cls(*args, **kwargs) # type: ignore + + def __iter__(self) -> Iterator: + return iter(self.__dict__.values()) diff --git a/mindscience/sharker/utils/ncon.py b/mindscience/sharker/utils/ncon.py new file mode 100644 index 000000000..5bbfb798e --- /dev/null +++ b/mindscience/sharker/utils/ncon.py @@ -0,0 +1,514 @@ +# Copyright 2022 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""ncon""" +from copy import deepcopy +import numpy as np +import mindspore as ms + +from mindspore import ops, nn, vmap +from mindspore.numpy import tensordot, trace, expand_dims + + +def list_to_tuple(lst): + """list_to_tuple""" + return tuple(list_to_tuple(item) if isinstance(item, list) else item for item in lst) + + +def nest_vmap(fn, in_list, out_list, pt): + """nest vmap function""" + if pt == len(in_list) - 1: + return vmap(fn, in_list[pt], out_list[pt]) + return vmap(nest_vmap(fn, in_list, out_list, pt + 1), in_list[pt], out_list[pt]) + + +def _create_order(con_list): + """ Identify all unique, positive indices and return them sorted. """ + flat_con = np.concatenate(con_list) + return np.unique(flat_con[flat_con > 0]).tolist() + + +def _single_trace(con, leg): + """_single_trace""" + leg = np.where(np.array(con) == leg)[0] + con = np.delete(con, leg).tolist() + return con, leg.tolist() + + +def _find_sum(con_list): + """_find_sum + + Args: + con_list: con_list + + Returns: + legs + """ + flat = [] + for item in con_list: + flat += item + legs = [] + for leg in np.unique(flat): + if leg < 0: + continue + if np.sum(np.array(flat) == leg) == 1: + legs.append(leg) + return legs + + +def _find_trace(con_list): + """_find_trace + + Args: + con_list: con_list + + Returns: + legs_list + """ + legs_list = [] + for i in range(len(con_list)): + tr_num = len(con_list[i]) - len(np.unique(con_list[i])) + legs = [] + if tr_num: + for leg in np.unique(con_list[i]): + if sum(con_list[i] == leg) > 1 and leg > 0: + leg = np.where(con_list[i] == leg)[0].tolist() + legs.append(leg) + con_list[i] = np.delete(con_list[i], leg).tolist() + + legs_list.append(legs) + return legs_list + + +def _find_batch(con_list): + """_find_batch + + Args: + con_list: con_list + + Returns: + outer + """ + outer = [] + for i in con_list: + if not isinstance(i, np.ndarray): + i = np.array(i) + outer.extend(i[i < 0].tolist()) + if not outer: + return None + if -len(outer) == min(outer): + return None + + for leg in np.unique(outer): + if sum(outer == leg) == 1: + outer = np.delete(outer, outer.index(leg)).tolist() + + return outer + + +def _process_perm(con, batch_leg): + """_process_perm""" + p = list(range(len(con))) + for i, ind in enumerate(batch_leg): + j = con.index(ind) + if i == j: + continue + con[i], con[j] = con[j], con[i] + p[i], p[j] = p[j], p[i] + + return con, tuple(p) + + +def _make_dict(mode, + inds=None, + legs=None, + batch_leg=None, + p_list=None, + res_legs=None, + permute_index=None, + expand_axis=None): + """_summary_ + + Args: + mode: mode + inds: inds. Defaults to None. + legs: legs. Defaults to None. + batch_leg: batch_leg. Defaults to None. + p_list: p_list. Defaults to None. + res_legs: res_legs. Defaults to None. + permute_index: permute_index. Defaults to None. + expand_axis: expand_axis. Defaults to None. + + Raises: + ValueError: ValueError + + Returns: + d + """ + d = {} + calculate_mode = 'mode' + indices = 'inds' + indices_legs = 'legs' + d[calculate_mode] = mode + + if d[calculate_mode] == 'permute': + d['perms'] = p_list + + elif d[calculate_mode] == 'outer': + d[indices] = inds + + elif d[calculate_mode] in ('diag', 'sum', 'trace'): + d[indices] = inds + d[indices_legs] = legs + + elif d[calculate_mode] == 'ndot': + d[indices] = inds + d[indices_legs] = legs + d['batch_leg'] = batch_leg + + elif d[calculate_mode] == 'hadamard': + d[indices] = inds + d[indices_legs] = legs + d['res_legs'] = res_legs + d['permute_index'] = permute_index + d['expand_axis'] = expand_axis + + else: + raise ValueError + + return d + + +def _process_commands(con_list): + """_process_commands + + Args: + con_list: con_list + + Returns: + conmmands, operators + """ + conmmands = [] + operators = [] + + # find sum index + sum_legs = _find_sum(con_list) + for leg in sum_legs: + for i, con in enumerate(con_list): + if leg in con: + leg_ind = con.index(leg) + con_list[i].remove(leg) + conmmands.append(_make_dict('sum', [i], [leg_ind])) + operators.append(ops.sum) + + # find trace + trace_legs = _find_trace(con_list) + for i, leg_list in enumerate(trace_legs): + if leg_list: + for legs in leg_list: + conmmands.append(_make_dict('trace', [i], legs)) + operators.append(trace) + + order = _create_order(con_list) + batch_legs = _find_batch(con_list) + + if not con_list[0]: + return conmmands, operators + + do_ndot(con_list, conmmands, operators, order, batch_legs) + + # do Hadamard(alike) product + do_hadamard(con_list, conmmands, operators) + + # do outer product + for i, con in enumerate(con_list): + if not i: + continue + if con: + inds = [0, i] + for leg in con: + con_list[0].append(leg) + con_list[i] = [] + conmmands.append(_make_dict('outer', inds)) + operators.append(tensordot) + + # do diagonal + min_leg = min(con_list[0]) + for leg in range(-1, min_leg - 1, -1): + num_leg = con_list[0].count(leg) + while num_leg > 1: + i = con_list[0].index(leg) + j = con_list[0].index(leg, i + 1) + conmmands.append(_make_dict('diag', [0], [i, j])) + operators.append(ops.diagonal) + con_list[0] = con_list[0][:i] + con_list[0][i + 1:j] + con_list[0][j + 1:] + [leg] + num_leg = con_list[0].count(leg) + + # do final permutation + fin_con = list(range(-1, -1 - len(con_list[0]), -1)) + con_list[0], p = _process_perm(con_list[0], fin_con) + conmmands.append(_make_dict('permute', p_list=[p])) + operators.append(ops.permute) + + return conmmands, operators + + +def do_ndot(con_list, conmmands, operators, order, batch_legs): + """do_ndot + + Args: + con_list: con_list + conmmands: conmmands + operators: operators + order: order + batch_legs: batch_legs + """ + while order: + leg_now = order[-1] + inds = [] + legs = [] + batch_legs_now = [] + + # find the two tensors' indices + for i, item in enumerate(con_list): + if leg_now in item: + inds.append(i) + + # check trace + if len(inds) == 1: + con_list[inds[0]], legs = _single_trace(con_list[inds[0]], leg_now) + conmmands.append(_make_dict('trace', inds, legs)) + operators.append(trace) + + else: + # find batch legs + batch_leg_inds = [] + if batch_legs is not None: + tmp = np.intersect1d(con_list[inds[0]], con_list[inds[1]]) + batch_legs_now = np.intersect1d(tmp, batch_legs, False).tolist() + + # find indices of batch legs + for batch_leg in batch_legs_now: + i_leg_0 = con_list[inds[0]].index(batch_leg) + i_leg_1 = con_list[inds[1]].index(batch_leg) + con_list[inds[0]].remove(batch_leg) + con_list[inds[1]].remove(batch_leg) + batch_leg_inds.append((i_leg_0, i_leg_1, None)) + + ndot_legs = [] + ndot_leg_inds = [] + # find all ndot legs and their indices + for leg in con_list[inds[0]]: + if leg in con_list[inds[1]]: + i_leg_0 = con_list[inds[0]].index(leg) + i_leg_1 = con_list[inds[1]].index(leg) + ndot_legs.append(leg) + ndot_leg_inds.append([i_leg_0, i_leg_1]) + + # do ndot contraction and update order + for leg in ndot_legs: + con_list[inds[0]].remove(leg) + con_list[inds[1]].remove(leg) + for leg in ndot_legs: + if leg != leg_now: + order.remove(leg) + + ndot_leg_inds = ndot_leg_inds[0] if len(ndot_leg_inds) == 1 else np.array( + ndot_leg_inds).transpose().tolist() + conmmands.append(_make_dict('ndot', inds, list_to_tuple(ndot_leg_inds), batch_leg_inds)) + operators.append( + nest_vmap(tensordot, batch_leg_inds, [0] * len(batch_leg_inds), 0) if batch_leg_inds else tensordot) + + # merge two con_list + for leg in con_list[inds[1]]: + if leg not in batch_legs_now: + con_list[inds[0]].append(leg) + con_list[inds[1]] = [] + con_list[inds[0]] = batch_legs_now + con_list[inds[0]] + + order = order[:-1] + + +def do_hadamard(con_list, conmmands, operators): + """do_hadamard + + Args: + con_list: con_list + conmmands: conmmands + operators: operators + """ + is_con_list_not_none = len(con_list) == 2 and con_list[1] + if is_con_list_not_none and not [i for i in con_list[0] if i > 0] and not [i for i in con_list[1] if i > 0]: + con_list_all = [] + for con in con_list: + con_list_all.extend(con) + con_min_leg = min(con_list_all) + out_list = [i for i in range(-1, con_min_leg - 1, -1)] + + res_legs = [] + for ind in out_list: + for i, con in enumerate(con_list): + if ind in con: + res_legs.append((i, con.index(ind))) + break + + hadamard_legs = [[], []] + con_raw = deepcopy(con_list) + handle_inds(con_list, out_list, hadamard_legs) + + expand_axis = deepcopy(hadamard_legs) + for i, axis in enumerate(expand_axis): + if axis and len(axis) <= 1: + expand_axis[i] = axis[0] + + # input permute + permute_index = [[], []] + con_sort = deepcopy(con_raw) + for i, con in enumerate(con_raw): + con_sort[i].sort(reverse=True) + _, permute_index[i] = _process_perm(con, con_sort[i]) + + conmmands.append( + _make_dict('hadamard', + inds=[0, 1], + legs=hadamard_legs, + res_legs=res_legs, + permute_index=permute_index, + expand_axis=expand_axis)) + operators.append([ops.permute, ops.tile, ops.mul, expand_dims]) + + +def handle_inds(con_list, out_list, hadamard_legs): + """handle_inds""" + for i, con in enumerate(con_list): + if con: + for ind in out_list: + if ind not in con: + hadamard_legs[i].append((out_list.index(ind))) + if i: + con_list[i] = [] + else: + con_list[i] = out_list + + +class Ncon(nn.Cell): + r""" + Multiple-tensor contraction operator which has similar function to Einsum. + + Args: + con_list (List[List[int]]): lists of indices for each tensor. + The number of each list in `con_list` should coincide with the corresponding tensor's dimensions. + The positive indices indicate the dimensions to be contracted or summed. + The negative indices indicate the dimensions to be keeped (as batch dimensions). + + Inputs: + - **input** (List[Tensor]) - Tensor List. + + Outputs: + - **output** (Tensor) - The shape of tensor depends on the input and the computation process. + + Raises: + ValueError: If the number of commands is not match the number of operations. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> from mindspore import ops + >>> from mindchemistry.e3.utils import Ncon + Trace of a matrix: + >>> a = ops.ones((3, 3)) + >>> Ncon([[1, 1]])([a]) + 3.0 + Diagonal of a matrix: + >>> Ncon([[-1, -1]])([a]) + [1. 1. 1.] + Outer product: + >>> b = ops.ones((2)) + >>> c = ops.ones((3)) + >>> Ncon([[-1], [-2]])([b, c]).shape + (2, 3) + Batch matrix multiplication + >>> d = ops.ones((2, 3, 4)) + >>> e = ops.ones((2, 4, 1)) + >>> Ncon([[-1, -2, 1], [-1, 1, -3]])([d, e]).shape + (2, 3, 1) + """ + + def __init__(self, con_list): + super().__init__() + self.con_list = tuple(con_list) + con_list_copy = deepcopy(con_list) + self.commands, self.ops = _process_commands(con_list_copy) + if len(self.commands) != len(self.ops): + raise ValueError(f'{self.commands} is not match {len(self.ops)}') + + def construct(self, ten_list): + """ + The list of tensors to be conctracted. + """ + i = 0 + for d in self.commands: + if d['mode'] == 'diag': + ten_list[0] = self.ops[i](ten_list[0], 0, *d['legs']) + elif d['mode'] == 'permute': + ten_list[0] = self.ops[i](ten_list[0], d['perms'][0]) + elif d['mode'] == 'sum': + i1 = d['inds'][0] + ten_list[i1] = self.ops[i](ten_list[i1], d['legs'][0]) + elif d['mode'] == 'trace': + i1 = d['inds'][0] + ten_list[i1] = self.ops[i](ten_list[i1], 0, d['legs'][0], d['legs'][1]) + elif d['mode'] == 'outer': + i1, i2 = d['inds'] + ten_list[i1] = self.ops[i](ten_list[i1], ten_list[i2], 0) + elif d['mode'] == 'ndot': + i1, i2 = d['inds'] + ten_list[i1] = self.ops[i](ten_list[i1], ten_list[i2], d['legs']) + elif d['mode'] == 'hadamard': + i1, i2 = d['inds'] + a = ten_list[i1] + b = ten_list[i2] + res_legs = d['res_legs'] + + a = ops.permute(a, d['permute_index'][i1]) + b = ops.permute(b, d['permute_index'][i2]) + + if d['expand_axis'][i1]: + a = expand_dims(a, d['expand_axis'][i1]) + if d['expand_axis'][i2]: + b = expand_dims(b, d['expand_axis'][i2]) + + tile_index = [[1 for _ in res_legs], [1 for _ in res_legs]] + for j in range(len(d['legs'][i1])): + tile_index[0][d['legs'][i1][j]] = ten_list[res_legs[d['legs'][i1][j]][0]].shape[res_legs[ + d['legs'][i1][j]][1]] + for j in range(len(d['legs'][i2])): + tile_index[1][d['legs'][i2][j]] = ten_list[res_legs[d['legs'][i2][j]][0]].shape[res_legs[ + d['legs'][i2][j]][1]] + a = ops.tile(a, tuple(tile_index[0])) + b = ops.tile(b, tuple(tile_index[1])) + + ten_list[i1] = ops.mul(a, b) + else: + i += 1 + continue + i += 1 + return ten_list[0] + + def __repr__(self): + s = f'Ncon: {self.con_list}\n' + for d in self.commands: + s += str(d) + '\n' + return s \ No newline at end of file diff --git a/mindscience/sharker/utils/negative_sampling.py b/mindscience/sharker/utils/negative_sampling.py new file mode 100644 index 000000000..1d45b7ffa --- /dev/null +++ b/mindscience/sharker/utils/negative_sampling.py @@ -0,0 +1,388 @@ +import random +from typing import Optional, Tuple, Union + +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, nn, mint + +from .coalesce import coalesce +from .functions import cumsum +from .degree import degree +from .loop import remove_self_loops +from .num_nodes import maybe_num_nodes + + +def negative_sampling( + edge_index: Tensor, + num_nodes: Optional[Union[int, Tuple[int, int]]] = None, + num_neg_samples: Optional[int] = None, + method: str = "sparse", + force_undirected: bool = False, +) -> Tensor: + r"""Samples random negative edges of a graph given by :attr:`edge_index`. + + Args: + edge_index (LongTensor): The edge indices. + num_nodes (int or Tuple[int, int], optional): The number of nodes, + *i.e.* :obj:`max_val + 1` of :attr:`edge_index`. + If given as a tuple, then :obj:`edge_index` is interpreted as a + bipartite graph with shape :obj:`(num_src_nodes, num_dst_nodes)`. + (default: :obj:`None`) + num_neg_samples (int, optional): The (approximate) number of negative + samples to return. + If set to :obj:`None`, will try to return a negative edge for every + positive edge. (default: :obj:`None`) + method (str, optional): The method to use for negative sampling, + *i.e.* :obj:`"sparse"` or :obj:`"dense"`. + This is a memory/runtime trade-off. + :obj:`"sparse"` will work on any graph of any size, while + :obj:`"dense"` can perform faster true-negative checks. + (default: :obj:`"sparse"`) + force_undirected (bool, optional): If set to :obj:`True`, sampled + negative edges will be undirected. (default: :obj:`False`) + + :rtype: LongTensor + + Examples: + >>> # Standard usage + >>> edge_index = Tensor([[0, 0, 1, 2], + ... [0, 1, 2, 3]]) + >>> negative_sampling(edge_index) + tensor([[3, 0, 0, 3], + [2, 3, 2, 1]]) + + >>> # For bipartite graph + >>> negative_sampling(edge_index, num_nodes=(3, 4)) + tensor([[0, 2, 2, 1], + [2, 2, 1, 3]]) + """ + assert method in ["sparse", "dense"] + + if num_nodes is None: + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + if isinstance(num_nodes, int): + size = (num_nodes, num_nodes) + bipartite = False + else: + size = num_nodes + bipartite = True + force_undirected = False + + idx, population = edge_index_to_vector( + edge_index, size, bipartite, force_undirected + ) + + if idx.numel() >= population: + return -ops.ones((2, 0), dtype=edge_index.dtype) + + if num_neg_samples is None: + num_neg_samples = edge_index.shape[1] + if force_undirected: + num_neg_samples = num_neg_samples // 2 + + prob = 1.0 - idx.numel() / population # Probability to sample a negative. + sample_size = int(1.1 * num_neg_samples / prob) # (Over)-sample size. + + neg_idx: Optional[Tensor] = None + if method == "dense": + # The dense version creates a mask of shape `population` to check for + # invalid samples. + mask = ops.ones(population).bool() + mask[idx] = False + for _ in range(3): # Number of tries to sample negative indices. + rnd = sample(population, sample_size) + rnd = rnd[((mask.astype("int32"))[rnd]).astype(ms.bool_)] # Filter true negatives. + neg_idx = rnd if neg_idx is None else mint.cat([neg_idx, rnd]) + if neg_idx.numel() >= num_neg_samples: + neg_idx = neg_idx[:num_neg_samples] + break + mask[neg_idx] = False + + else: # 'sparse' + # The sparse version checks for invalid samples via `np.isin`. + for _ in range(3): # Number of tries to sample negative indices. + rnd = sample(population, sample_size) + mask = np.isin(rnd.numpy(), idx.numpy()) # type: ignore + if neg_idx is not None: + mask |= np.isin(rnd, neg_idx) + mask = Tensor.from_numpy(mask).bool() + rnd = rnd[~mask] + neg_idx = rnd if neg_idx is None else mint.cat([neg_idx, rnd]) + if neg_idx.numel() >= num_neg_samples: + neg_idx = neg_idx[:num_neg_samples] + break + + assert neg_idx is not None + return vector_to_edge_index(neg_idx, size, bipartite, force_undirected) + + +def batched_negative_sampling( + edge_index: Tensor, + batch: Union[Tensor, Tuple[Tensor, Tensor]], + num_neg_samples: Optional[int] = None, + method: str = "sparse", + force_undirected: bool = False, +) -> Tensor: + r"""Samples random negative edges of multiple graphs given by + :attr:`edge_index` and :attr:`batch`. + + Args: + edge_index (LongTensor): The edge indices. + batch (LongTensor or Tuple[LongTensor, LongTensor]): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. + If given as a tuple, then :obj:`edge_index` is interpreted as a + bipartite graph connecting two different node types. + num_neg_samples (int, optional): The number of negative samples to + return. If set to :obj:`None`, will try to return a negative edge + for every positive edge. (default: :obj:`None`) + method (str, optional): The method to use for negative sampling, + *i.e.* :obj:`"sparse"` or :obj:`"dense"`. + This is a memory/runtime trade-off. + :obj:`"sparse"` will work on any graph of any size, while + :obj:`"dense"` can perform faster true-negative checks. + (default: :obj:`"sparse"`) + force_undirected (bool, optional): If set to :obj:`True`, sampled + negative edges will be undirected. (default: :obj:`False`) + + :rtype: LongTensor + + Examples: + >>> # Standard usage + >>> edge_index = Tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) + >>> edge_index = ops.cat([edge_index, edge_index + 4], dim=1) + >>> edge_index + tensor([[0, 0, 1, 2, 4, 4, 5, 6], + [0, 1, 2, 3, 4, 5, 6, 7]]) + >>> batch = Tensor([0, 0, 0, 0, 1, 1, 1, 1]) + >>> batched_negative_sampling(edge_index, batch) + tensor([[3, 1, 3, 2, 7, 7, 6, 5], + [2, 0, 1, 1, 5, 6, 4, 4]]) + + >>> # For bipartite graph + >>> edge_index1 = Tensor([[0, 0, 1, 1], [0, 1, 2, 3]]) + >>> edge_index2 = edge_index1 + Tensor([[2], [4]]) + >>> edge_index3 = edge_index2 + Tensor([[2], [4]]) + >>> edge_index = ops.cat([edge_index1, edge_index2, + ... edge_index3], dim=1) + >>> edge_index + tensor([[ 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], + [ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]]) + >>> src_batch = Tensor([0, 0, 1, 1, 2, 2]) + >>> dst_batch = Tensor, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]) + >>> batched_negative_sampling(edge_index, + ... (src_batch, dst_batch)) + tensor([[ 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5], + [ 2, 3, 0, 1, 6, 7, 4, 5, 10, 11, 8, 9]]) + """ + if isinstance(batch, Tensor): + src_batch, dst_batch = batch, batch + else: + src_batch, dst_batch = batch[0], batch[1] + + split = degree(src_batch[edge_index[0]], dtype=ms.int64).tolist() + edge_indices = mint.split(edge_index, split, dim=1) + + num_src = degree(src_batch, dtype=ms.int64) + cum_src = cumsum(num_src)[:-1] + + if isinstance(batch, Tensor): + num_nodes = num_src.tolist() + ptr = cum_src + else: + num_dst = degree(dst_batch, dtype=ms.int64) + cum_dst = cumsum(num_dst)[:-1] + + num_nodes = mint.stack([num_src, num_dst], dim=1).tolist() + ptr = mint.stack([cum_src, cum_dst], dim=1).unsqueeze(-1) + + neg_edge_indices = [] + for i, edge_index in enumerate(edge_indices): + edge_index = edge_index - ptr[i] + neg_edge_index = negative_sampling( + edge_index, num_nodes[i], num_neg_samples, method, force_undirected + ) + neg_edge_index += ptr[i] + neg_edge_indices.append(neg_edge_index) + + return mint.cat(neg_edge_indices, dim=1) + + +def structured_negative_sampling( + edge_index: Tensor, + num_nodes: Optional[int] = None, + contains_neg_self_loops: bool = True, +) -> Tuple[Tensor, Tensor, Tensor]: + r"""Samples a negative edge :obj:`(i,k)` for every positive edge + :obj:`(i,j)` in the graph given by :attr:`edge_index`, and returns it as a + tuple of the form :obj:`(i,j,k)`. + + Args: + edge_index (LongTensor): The edge indices. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + contains_neg_self_loops (bool, optional): If set to + :obj:`False`, sampled negative edges will not contain self loops. + (default: :obj:`True`) + + :rtype: (LongTensor, LongTensor, LongTensor) + + Example: + >>> edge_index = Tensor([[0, 0, 1, 2], + ... [0, 1, 2, 3]]) + >>> structured_negative_sampling(edge_index) + (tensor([0, 0, 1, 2]), tensor([0, 1, 2, 3]), tensor([2, 3, 0, 2])) + + """ + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + row, col = edge_index + pos_idx = row * num_nodes + col + if not contains_neg_self_loops: + loop_idx = mint.arange(num_nodes) * (num_nodes + 1) + pos_idx = mint.cat([pos_idx, loop_idx], dim=0) + + rand = ops.randint(0, num_nodes, (row.shape[0],)).long() + neg_idx = row * num_nodes + rand + + mask = Tensor.from_numpy(np.isin(neg_idx, pos_idx)).bool() + rest = mask.nonzero().view(-1) + while rest.numel() > 0: # pragma: no cover + tmp = ops.randint(0, num_nodes, (rest.shape[0],)).long() + rand[rest] = tmp + neg_idx = row[rest] * num_nodes + tmp + + mask = Tensor.from_numpy(np.isin(neg_idx, pos_idx)).bool() + rest = rest[mask] + + return edge_index[0], edge_index[1], rand + + +def structured_negative_sampling_feasible( + edge_index: Tensor, + num_nodes: Optional[int] = None, + contains_neg_self_loops: bool = True, +) -> bool: + r"""Returns :obj:`True` if + :meth:`~sharker.utils.structured_negative_sampling` is feasible + on the graph given by :obj:`edge_index`. + :meth:`~sharker.utils.structured_negative_sampling` is infeasible + if atleast one node is connected to all other nodes. + + Args: + edge_index (LongTensor): The edge indices. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + contains_neg_self_loops (bool, optional): If set to + :obj:`False`, sampled negative edges will not contain self loops. + (default: :obj:`True`) + + :rtype: bool + + Examples: + >>> edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 2], + ... [1, 2, 0, 2, 0, 1, 1]]) + >>> structured_negative_sampling_feasible(edge_index, 3, False) + False + + >>> structured_negative_sampling_feasible(edge_index, 3, True) + True + """ + num_nodes = maybe_num_nodes(edge_index, num_nodes) + max_num_neighbors = num_nodes + + edge_index = coalesce(edge_index, num_nodes=num_nodes) + + if not contains_neg_self_loops: + edge_index, _ = remove_self_loops(edge_index) + max_num_neighbors -= 1 # Reduce number of valid neighbors + + deg = degree(edge_index[0], num_nodes) + # True if there exists no node that is connected to all other nodes. + return bool(mint.all(deg < max_num_neighbors)) + + +############################################################################### + + +def sample( + population: int, + k: int, +) -> Tensor: + if population <= k: + return mint.arange(population) + else: + return Tensor(random.sample(range(population), k)) + + +def edge_index_to_vector( + edge_index: Tensor, + size: Tuple[int, int], + bipartite: bool, + force_undirected: bool = False, +) -> Tuple[Tensor, int]: + + row, col = edge_index.copy() + + if bipartite: # No need to account for self-loops. + idx = (row * size[1]) + col + population = size[0] * size[1] + return idx, population + + elif force_undirected: + assert size[0] == size[1] + num_nodes = size[0] + + # We only operate on the upper triangular matrix: + mask = row < col + row, col = row[mask], col[mask] + offset = mint.cumsum(mint.arange(1, num_nodes), dim=0)[row] + idx = row * num_nodes + col - offset + population = (num_nodes * (num_nodes + 1)) // 2 - num_nodes + return idx, population + + else: + assert size[0] == size[1] + num_nodes = size[0] + + # We remove self-loops as we do not want to take them into account + # when sampling negative values. + mask = row != col + row, col = row[mask], col[mask] + col[row < col] -= 1 + idx = row * (num_nodes - 1) + col + population = num_nodes * num_nodes - num_nodes + return idx, population + + +def vector_to_edge_index( + idx: Tensor, + size: Tuple[int, int], + bipartite: bool, + force_undirected: bool = False, +) -> Tensor: + + if bipartite: # No need to account for self-loops. + row = idx.div(size[1], rounding_mode="floor") + col = idx % size[1] + return mint.stack([row, col], dim=0) + + elif force_undirected: + assert size[0] == size[1] + num_nodes = size[0] + + offset = mint.cumsum(mint.arange(1, num_nodes), dim=0) + end = mint.arange(num_nodes, num_nodes * num_nodes, num_nodes) + row = ops.bucketize(idx, (end - offset).tolist(), right=True).astype(idx.dtype) + col = (offset[row] + idx) % num_nodes + return mint.stack([mint.cat([row, col]), mint.cat([col, row])], 0) + + else: + assert size[0] == size[1] + num_nodes = size[0] + + row = idx.div(num_nodes - 1, rounding_mode="floor") + col = idx % (num_nodes - 1) + col[row <= col] += 1 + return mint.stack([row, col], dim=0) diff --git a/mindscience/sharker/utils/noise_scheduler.py b/mindscience/sharker/utils/noise_scheduler.py new file mode 100644 index 000000000..35ebe4bf0 --- /dev/null +++ b/mindscience/sharker/utils/noise_scheduler.py @@ -0,0 +1,89 @@ +import math +from typing import Literal, Optional + +import mindspore as ms +from mindspore import Tensor, ops, mint + + +def get_smld_sigma_schedule( + sigma_min: float, + sigma_max: float, + num_scales: int, + dtype: Optional[ms.Type] = None, +) -> Tensor: + r"""Generates a set of noise values on a logarithmic scale for "Score + Matching with Langevin Dynamics" from the `"Generative Modeling by + Estimating Gradients of the Data Distribution" + `_ paper. + + This function returns a vector of sigma values that define the schedule of + noise levels used during Score Matching with Langevin Dynamics. + The sigma values are determined on a logarithmic scale from + :obj:`sigma_max` to :obj:`sigma_min`, inclusive. + + Args: + sigma_min (float): The minimum value of sigma, corresponding to the + lowest noise level. + sigma_max (float): The maximum value of sigma, corresponding to the + highest noise level. + num_scales (int): The number of sigma values to generate, defining the + granularity of the noise schedule. + dtype (ms.Type, optional): The output data type. + (default: :obj:`None`) + """ + out = ops.linspace( + math.log(sigma_max), + math.log(sigma_min), + num_scales, + ).exp() + + if dtype is not None: + out = out.astype(dtype) + return out + + +def get_diffusion_beta_schedule( + schedule_type: Literal["linear", "quadratic", "constant", "sigmoid"], + beta_start: float, + beta_end: float, + num_diffusion_timesteps: int, + dtype: Optional[ms.Type] = None, +) -> Tensor: + r"""Generates a schedule of beta values according to the specified strategy + for the diffusion process from the `"Denoising Diffusion Probabilistic + Models" `_ paper. + + Beta values are used to scale the noise added during the diffusion process + in generative models. This function creates an array of beta values + according to a pre-defined schedule, which can be either :obj:`"linear"`, + :obj:`"quadratic"`, :obj:`"constant"`, or :obj:`"sigmoid"`. + + Args: + schedule_type (str): The type of schedule to use for beta values. + beta_start (float): The starting value of beta. + beta_end (float): The ending value of beta. + num_diffusion_timesteps (int): The number of timesteps for the + diffusion process. + dtype (ms.Type, optional): The output data type. + (default: :obj:`None`) + ops.cat (ops.cat, optional): The output ops.cat. + (default: :obj:`None`) + """ + if schedule_type == "linear": + out = ops.linspace(beta_start, beta_end, num_diffusion_timesteps) + + elif schedule_type == "quadratic": + out = ops.linspace( + beta_start**0.5, beta_end**0.5, num_diffusion_timesteps + ) ** 2 + elif schedule_type == "constant": + return ops.full((num_diffusion_timesteps,), fill_value=beta_end) + + elif schedule_type == "sigmoid": + out = ops.linspace(-6, 6, num_diffusion_timesteps).sigmoid() * (beta_end - beta_start) + beta_start + else: + raise ValueError(f"Found invalid 'schedule_type' (got '{schedule_type}')") + + if dtype is not None: + out = out.astype(dtype=dtype) + return out diff --git a/mindscience/sharker/utils/normalize.py b/mindscience/sharker/utils/normalize.py new file mode 100644 index 000000000..105335329 --- /dev/null +++ b/mindscience/sharker/utils/normalize.py @@ -0,0 +1,36 @@ +from typing import Optional + +from mindspore import Tensor, ops, nn + +from .degree import degree + + +def normalized_cut( + edge_index: Tensor, + edge_attr: Tensor, + num_nodes: Optional[int] = None, +) -> Tensor: + r"""Computes the normalized cut :math:`\mathbf{e}_{i,j} \cdot + \left( \frac{1}{\deg(i)} + \frac{1}{\deg(j)} \right)` of a weighted graph + given by edge indices and edge attributes. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor): Edge weights or multi-dimensional edge features. + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + + :rtype: :class:`Tensor` + + Example: + >>> edge_index = Tensor([[1, 1, 2, 3], + ... [3, 3, 1, 2]]) + >>> edge_attr = Tensor1., 1., 1.]) + >>> normalized_cut(edge_index, edge_attr) + tensor([1.5000, 1.5000, 2.0000, 1.5000]) + """ + row, col = edge_index[0], edge_index[1] + deg = 1.0 / degree(col, num_nodes, edge_attr.dtype) + deg = deg[row] + deg[col] + cut = edge_attr * deg + return cut diff --git a/mindscience/sharker/utils/num_nodes.py b/mindscience/sharker/utils/num_nodes.py new file mode 100644 index 000000000..60304d386 --- /dev/null +++ b/mindscience/sharker/utils/num_nodes.py @@ -0,0 +1,57 @@ +from copy import copy +from typing import Dict, Optional, Tuple, Union + +import numpy as np +import mindspore as ms + +from mindspore import Tensor, COOTensor, CSRTensor + + +def maybe_num_nodes( + edge_index: Union[Tensor, Tuple[Tensor, Tensor]], + num_nodes: Optional[int] = None, +) -> int: + if num_nodes is not None: + return num_nodes + elif isinstance(edge_index, Tensor): + return int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 + elif isinstance(edge_index, np.ndarray): + return int(edge_index.max()) + 1 if edge_index.size > 0 else 0 + elif isinstance(edge_index, tuple): + if isinstance(edge_index[0], Tensor): + return max( + int(edge_index[0].max()) + 1 if edge_index[0].numel() > 0 else 0, + int(edge_index[1].max()) + 1 if edge_index[1].numel() > 0 else 0, + ) + elif isinstance(edge_index[0], np.ndarray): + return max( + int(edge_index[0].max()) + 1 if edge_index[0].size > 0 else 0, + int(edge_index[1].max()) + 1 if edge_index[1].size > 0 else 0, + ) + elif isinstance(edge_index, (COOTensor, CSRTensor)): + return max(edge_index.shape[0], edge_index.shape[1]) + else: + raise NotImplementedError + + +def maybe_num_nodes_dict( + edge_index_dict: Dict[Tuple[str, str, str], Tensor], + num_nodes_dict: Optional[Dict[str, int]] = None, +) -> Dict[str, int]: + num_nodes_dict = {} if num_nodes_dict is None else copy(num_nodes_dict) + + found_types = list(num_nodes_dict.keys()) + + for keys, edge_index in edge_index_dict.items(): + + key = keys[0] + if key not in found_types: + N = int(edge_index[0].max() + 1) + num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N)) + + key = keys[-1] + if key not in found_types: + N = int(edge_index[1].max() + 1) + num_nodes_dict[key] = max(N, num_nodes_dict.get(key, N)) + + return num_nodes_dict diff --git a/mindscience/sharker/utils/random.py b/mindscience/sharker/utils/random.py new file mode 100644 index 000000000..5918584e1 --- /dev/null +++ b/mindscience/sharker/utils/random.py @@ -0,0 +1,88 @@ +import warnings +from typing import List, Union + +import mindspore as ms +from mindspore import Tensor, ops, mint, Generator + +from .loop import remove_self_loops +from .undirected import to_undirected + + + +def erdos_renyi_graph( + num_nodes: int, + edge_prob: float, + directed: bool = False, +) -> Tensor: + r"""Returns the :obj:`edge_index` of a random Erdos-Renyi graph. + + Args: + num_nodes (int): The number of nodes. + edge_prob (float): Probability of an edge. + directed (bool, optional): If set to :obj:`True`, will return a + directed graph. (default: :obj:`False`) + + Examples: + >>> erdos_renyi_graph(5, 0.2, directed=False) + tensor([[0, 1, 1, 4], + [1, 0, 4, 1]]) + + >>> erdos_renyi_graph(5, 0.2, directed=True) + tensor([[0, 1, 3, 3, 4, 4], + [4, 3, 1, 2, 1, 3]]) + """ + if directed: + idx = mint.arange((num_nodes - 1) * num_nodes) + idx = idx.view(num_nodes - 1, num_nodes) + idx = idx + mint.arange(1, num_nodes).view(-1, 1) + idx = idx.view(-1) + else: + warnings.filterwarnings("ignore", ".*pass the indexing argument.*") + idx = ops.combinations(mint.arange(num_nodes), r=2) + + # Filter edges. + mask = ops.rand(idx.shape[0]) < edge_prob + + if not mask.any(): + return idx.T[:, mask.T] + idx = idx[mask] + + if directed: + row = idx.div(num_nodes, rounding_mode="floor") + col = idx % num_nodes + edge_index = mint.stack([row, col], dim=0) + else: + edge_index = to_undirected(idx.t(), num_nodes=num_nodes) + print(" erdos_renyi_grapht edge index-----: ", edge_index) + return edge_index + + +def barabasi_albert_graph(num_nodes: int, num_edges: int) -> Tensor: + r"""Returns the :obj:`edge_index` of a Barabasi-Albert preferential + attachment model, where a graph of :obj:`num_nodes` nodes grows by + attaching new nodes with :obj:`num_edges` edges that are preferentially + attached to existing nodes with high degree. + + Args: + num_nodes (int): The number of nodes. + num_edges (int): The number of edges from a new node to existing nodes. + + Example: + >>> barabasi_albert_graph(num_nodes=4, num_edges=3) + tensor([[0, 0, 0, 1, 1, 2, 2, 3], + [1, 2, 3, 0, 2, 0, 1, 0]]) + """ + assert num_edges > 0 and num_edges < num_nodes + + row, col = mint.arange(num_edges), ops.shuffle(mint.arange(num_edges)) + + for i in range(num_edges, num_nodes): + row = mint.cat([row, ops.full((num_edges,), i).long()]) + choice = ops.shuffle(mint.cat([row, col]))[:num_edges] + col = mint.cat([col, choice]) + + edge_index = mint.stack([row, col], dim=0) + edge_index, _ = remove_self_loops(edge_index) + edge_index = to_undirected(edge_index, num_nodes=num_nodes) + + return edge_index diff --git a/mindscience/sharker/utils/repeat.py b/mindscience/sharker/utils/repeat.py new file mode 100644 index 000000000..d4043b674 --- /dev/null +++ b/mindscience/sharker/utils/repeat.py @@ -0,0 +1,35 @@ +import itertools +import numbers +from typing import Any + +from mindspore import Tensor, mint + + +def repeat(src: Any, length: int) -> Any: + if src is None: + return None + + if isinstance(src, Tensor): + if src.numel() == 1: + return src.tile([length]) + + if src.numel() > length: + return src[:length] + + if src.numel() < length: + last_elem = src[-1].unsqueeze(0) + padding = last_elem.tile([length - src.numel()]) + return mint.cat([src, padding]) + + return src + + if isinstance(src, numbers.Number): + return list(itertools.repeat(src, length)) + + if len(src) > length: + return src[:length] + + if len(src) < length: + return src + list(itertools.repeat(src[-1], length - len(src))) + + return src diff --git a/mindscience/sharker/utils/select.py b/mindscience/sharker/utils/select.py new file mode 100644 index 000000000..16ed559b7 --- /dev/null +++ b/mindscience/sharker/utils/select.py @@ -0,0 +1,68 @@ +from typing import Any, List, Union + +import numpy as np +import mindspore as ms +from mindspore import ops, mint + +from .mask import mask_select, mask_select_np + + +def select( + src: Union[ms.Tensor, np.ndarray, List[Any]], + index_or_mask: Union[ms.Tensor, np.ndarray], + axis: int, +) -> Union[ms.Tensor, np.ndarray, List[Any]]: + r"""Selects the input tensor or input list according to a given index or + mask vector. + + Args: + src (Tensor or list): The input tensor or list. + index_or_mask (Tensor): The index or mask vector. + axis (int): The dimension along which to select. + """ + if isinstance(src, ms.Tensor): + if index_or_mask.dtype == ms.bool_: + return mask_select(src, axis=axis, mask=index_or_mask) + return mint.index_select(src, dim=axis, index=index_or_mask) + + if isinstance(src, np.ndarray): + if index_or_mask.dtype == np.bool_: + return mask_select_np(src, axis=axis, mask=index_or_mask) + return np.take(src, index_or_mask, axis=axis) + + if isinstance(src, (tuple, list)): + if axis != 0: + raise ValueError("Cannot select along dimension other than 0") + if index_or_mask.dtype == ms.bool_: + return [src[i] for i, m in enumerate(index_or_mask) if m] + return [src[i] for i in index_or_mask] + + raise ValueError(f"Encountered invalid input type (got '{type(src)}')") + + +def narrow( + src: Union[ms.Tensor, List[Any]], axis: int, start: int, length: int +) -> Union[ms.Tensor, List[Any]]: + r"""Narrows the input tensor or input list to the specified range. + + Args: + src (Tensor or list): The input tensor or list. + axis (int): The dimension along which to narrow. + start (int): The starting dimension. + length (int): The distance to the ending dimension. + """ + if isinstance(src, ms.Tensor): + return mint.narrow(src, axis, start, length) + if isinstance(src, np.ndarray): + if axis == 0: + return src[start:start+length] + else: + src = src.swapaxes(0, axis) + out = src[start:start+length] + return out.swapaxes(0, axis) + if isinstance(src, list): + if axis != 0: + raise ValueError("Cannot narrow along dimension other than 0") + return src[start: start + length] + + raise ValueError(f"Encountered invalid input type (got '{type(src)}')") diff --git a/mindscience/sharker/utils/softmax.py b/mindscience/sharker/utils/softmax.py new file mode 100644 index 000000000..149a67507 --- /dev/null +++ b/mindscience/sharker/utils/softmax.py @@ -0,0 +1,70 @@ +from typing import Optional + +from mindspore import Tensor, mint +from .num_nodes import maybe_num_nodes +from . import scatter +from . import segment + + +def softmax( + src: Tensor, + index: Optional[Tensor] = None, + ptr: Optional[Tensor] = None, + num_nodes: Optional[int] = None, + axis: int = 0, +) -> Tensor: + r"""Computes a sparsely evaluated softmax. + Given a value tensor :attr:`src`, this function first groups the values + along the first dimension based on the indices specified in :attr:`index`, + and then proceeds to compute the softmax individually for each group. + + Args: + src (Tensor): The source tensor. + index (LongTensor, optional): The indices of elements for applying the + softmax. (default: :obj:`None`) + ptr (LongTensor, optional): If given, computes the softmax based on + sorted inputs in CSR representation. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`index`. (default: :obj:`None`) + dim (int, optional): The dimension in which to normalize. + (default: :obj:`0`) + + :rtype: :class:`Tensor` + + Examples: + >>> src = Tensor([1., 1., 1., 1.]) + >>> index = Tensor([0, 0, 1, 2]) + >>> ptr = Tensor([0, 2, 3, 4]) + >>> softmax(src, index) + tensor([0.5000, 0.5000, 1.0000, 1.0000]) + + >>> softmax(src, None, ptr) + tensor([0.5000, 0.5000, 1.0000, 1.0000]) + + >>> src = ops.randn(4, 4) + >>> ptr = Tensor([0, 4]) + >>> softmax(src, index, dim=-1) + tensor([[0.7404, 0.2596, 1.0000, 1.0000], + [0.1702, 0.8298, 1.0000, 1.0000], + [0.7607, 0.2393, 1.0000, 1.0000], + [0.8062, 0.1938, 1.0000, 1.0000]]) + """ + if ptr is not None and (ptr.dim() == 1 or (ptr.dim() > 1 and index is None)): + axis = axis + src.dim() if axis < 0 else axis + count = ptr[1:] - ptr[:-1] + src_max = segment(src, ptr, dim=axis, reduce='max') + src_max = src_max.repeat_interleave(count.tolist(), dim=axis) + out = (src - src_max).exp() + out_sum = segment(out, ptr, dim=axis, reduce='sum') + 1e-16 + out_sum = out_sum.repeat_interleave(count.tolist(), dim=axis) + elif index is not None: + N = maybe_num_nodes(index, num_nodes) + src_max = scatter(src, index, axis, dim_size=N, reduce="max") + out = src - mint.index_select(src_max, axis, index.astype("int32")) + out = out.exp() + out_sum = scatter(out, index, axis, dim_size=N, reduce="sum") + 1e-16 + out_sum = mint.index_select(out_sum, axis, index.astype("int32")) + else: + raise NotImplementedError("'softmax' requires 'index' to be specified") + + return out / out_sum \ No newline at end of file diff --git a/mindscience/sharker/utils/sort_edge_index.py b/mindscience/sharker/utils/sort_edge_index.py new file mode 100644 index 000000000..69fd5c65b --- /dev/null +++ b/mindscience/sharker/utils/sort_edge_index.py @@ -0,0 +1,104 @@ +from typing import List, Optional, Tuple, Union +import numpy as np +from mindspore import Tensor, ops, mint + +from .num_nodes import maybe_num_nodes + +MISSING = "???" + + +def sort_edge_index( + edge_index: Tensor, + edge_attr: Union[Optional[Tensor], List[Tensor], str] = MISSING, + num_nodes: Optional[int] = None, + sort_by_row: bool = True, +) -> Union[Tensor, Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, List[Tensor]]]: + """Row-wise sorts :obj:`edge_index`. + + Args: + edge_index (Tensor): The edge indices. + edge_attr (Tensor or List[Tensor], optional): Edge weights + or multi-dimensional edge features. + If given as a list, will re-shuffle and remove duplicates for all + its entries. (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + sort_by_src (bool, optional): If set to :obj:`False`, will sort + :obj:`edge_index` column-wise/by destination node. + (default: :obj:`True`) + + :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else + (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]]`) + + .. warning:: + + From :pyg:`PyG >= 2.3.0` onwards, this function will always return a + tuple whenever :obj:`edge_attr` is passed as an argument (even in case + it is set to :obj:`None`). + + Examples: + >>> edge_index = Tensor([[2, 1, 1, 0], + [1, 2, 0, 1]]) + >>> edge_attr = Tensor [2], [3], [4]]) + >>> sort_edge_index(edge_index) + tensor([[0, 1, 1, 2], + [1, 0, 2, 1]]) + + >>> sort_edge_index(edge_index, edge_attr) + (tensor([[0, 1, 1, 2], + [1, 0, 2, 1]]), + tensor([[4], + [3], + [2], + [1]])) + """ + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + idx = edge_index[1 - int(sort_by_row)] * num_nodes + idx += edge_index[int(sort_by_row)] + _, perm = mint.sort(idx) + + if isinstance(edge_index, Tensor): + edge_index = mint.index_select(edge_index, 1, perm) + elif isinstance(edge_index, tuple): + edge_index = (edge_index[0][perm], edge_index[1][perm]) + else: + raise NotImplementedError + + if edge_attr is None: + return edge_index, None + if isinstance(edge_attr, Tensor): + return edge_index, mint.index_select(edge_attr, 0, perm) + if isinstance(edge_attr, (list, tuple)): + return edge_index, [mint.index_select(e, 0, perm) for e in edge_attr] + + return edge_index + +def sort_edge_index_np( + edge_index: np.ndarray, + edge_attr: Union[Optional[np.ndarray], List[np.ndarray], str] = MISSING, + num_nodes: Optional[int] = None, + sort_by_row: bool = True, +) -> Union[np.ndarray, Tuple[np.ndarray, Optional[np.ndarray]], Tuple[np.ndarray, List[np.ndarray]]]: + + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + idx = edge_index[1 - int(sort_by_row)] * num_nodes + idx += edge_index[int(sort_by_row)] + perm = np.argsort(idx) + + if isinstance(edge_index, np.ndarray): + edge_index = edge_index[:, perm] + elif isinstance(edge_index, tuple): + edge_index = (edge_index[0][perm], edge_index[1][perm]) + else: + raise NotImplementedError + + if edge_attr is None: + return edge_index, None + if isinstance(edge_attr, np.ndarray): + return edge_index, edge_attr[perm] + if isinstance(edge_attr, (list, tuple)): + return edge_index, [e[perm] for e in edge_attr] + + return edge_index diff --git a/mindscience/sharker/utils/sparse.py b/mindscience/sharker/utils/sparse.py new file mode 100644 index 000000000..8175c34e1 --- /dev/null +++ b/mindscience/sharker/utils/sparse.py @@ -0,0 +1,37 @@ +import warnings +from typing import Any, List, Optional, Tuple, Union +import mindspore as ms +from mindspore import Tensor, ops, COOTensor, CSRTensor, mint + +from .coalesce import coalesce +from .functions import cumsum + + +def is_sparse_tensor(src: Any) -> bool: + r"""Returns :obj:`True` if the input :obj:`src` is a + :class:`ms.sparse.Tensor` (in any sparse layout). + + Args: + src (Any): The input object to be checked. + """ + if isinstance(src, COOTensor): + return True + elif isinstance(src, CSRTensor): + return True + return False + + +def ptr2index(ptr: Tensor) -> Tensor: + index = mint.arange(ptr.numel() - 1, dtype=ptr.dtype) + return index.repeat(ptr.diff().tolist()) + + +def index2ptr(index: Tensor, shape: Optional[int] = None) -> Tensor: + count = index.bincount().astype(index.dtype) + if shape is not None: + ptr = mint.zeros(shape).astype(index.dtype) + ptr[:len(count)] = count + else: + ptr = count + return cumsum(ptr) + diff --git a/mindscience/sharker/utils/subgraph.py b/mindscience/sharker/utils/subgraph.py new file mode 100644 index 000000000..e1d7b27f7 --- /dev/null +++ b/mindscience/sharker/utils/subgraph.py @@ -0,0 +1,449 @@ +from typing import List, Optional, Tuple, Union +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, nn, mint +from .map import map_index, map_index_np +from .mask import index_to_mask, index_to_mask_np +from .num_nodes import maybe_num_nodes + + +def get_num_hops(model: nn.Cell) -> int: + r"""Returns the number of hops the model is aggregating information + from. + + .. note:: + + This function counts the number of message passing layers as an + approximation of the total number of hops covered by the model. + Its output may not necessarily be correct in case message passing + layers perform multi-hop aggregation, *e.g.*, as in + :class:`~sharker.nn.conv.ChebConv`. + + Example: + >>> class GNN(nn.Cell): + ... def __init__(self): + ... super().__init__() + ... self.conv1 = GCNConv(3, 16) + ... self.conv2 = GCNConv(16, 16) + ... self.lin = nn.Dense16, 2) + ... + ... def construct(self, x, edge_index): + ... x = ops.relu(self.conv1(x, edge_index)) + ... x = ops.relu(self.conv2(x, edge_index)) + ... return self.lin(x) + >>> get_num_hops(GNN()) + 2 + """ + from mindscience.sharker.nn.conv import MessagePassing + + num_hops = 0 + for cell in model.cells(): + if isinstance(cell, MessagePassing): + num_hops += 1 + return num_hops + + +def subgraph( + subset: Union[Tensor, np.ndarray, List[int]], + edge_index: Union[Tensor, np.ndarray], + edge_attr: Optional[Union[Tensor, np.ndarray]] = None, + relabel_nodes: bool = False, + num_nodes: Optional[int] = None, + *, + return_edge_mask: bool = False, +) -> Union[ + Tuple[Tensor, Optional[Tensor]], + Tuple[Tensor, Optional[Tensor], Tensor], +]: + r"""Returns the induced subgraph of :obj:`(edge_index, edge_attr)` + containing the nodes in :obj:`subset`. + + Args: + ## Not support BoolTensor at the moment + subset (LongTensor, BoolTensor or [int]): The nodes to keep. + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional + edge features. (default: :obj:`None`) + relabel_nodes (bool, optional): If set to :obj:`True`, the resulting + :obj:`edge_index` will be relabeled to hold consecutive indices + starting from zero. (default: :obj:`False`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max(edge_index) + 1`. (default: :obj:`None`) + return_edge_mask (bool, optional): If set to :obj:`True`, will return + the edge mask to filter out additional edge features. + (default: :obj:`False`) + + :rtype: (:class:`LongTensor`, :class:`Tensor`) + + Examples: + >>> edge_index = Tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6], + ... [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5]]) + >>> edge_attr = Tensor, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) + >>> subset = Tensor([3, 4, 5]) + >>> subgraph(subset, edge_index, edge_attr) + (tensor([[3, 4, 4, 5], + [4, 3, 5, 4]]), + tensor([ 7., 8., 9., 10.])) + + >>> subgraph(subset, edge_index, edge_attr, return_edge_mask=True) + (tensor([[3, 4, 4, 5], + [4, 3, 5, 4]]), + tensor([ 7., 8., 9., 10.]), + tensor([False, False, False, False, False, False, True, + True, True, True, False, False])) + """ + if isinstance(subset, (list, tuple)): + subset = np.array(subset, dtype=np.int64) + elif isinstance(subset, Tensor): + subset = subset.asnumpy() + + if isinstance(edge_index, Tensor): + edge_index = edge_index.asnumpy() + + if subset.dtype != np.bool_: + num_nodes = maybe_num_nodes(edge_index, num_nodes) + node_mask = index_to_mask_np(subset, size=num_nodes) + else: + num_nodes = subset.shape[0] + node_mask = subset + subset = np.nonzero(node_mask)[0] + + src, dst = edge_index + edge_mask = node_mask[src] & node_mask[dst] + edge_index = edge_index[:,edge_mask] + + if edge_attr is not None: + if isinstance(edge_attr, Tensor): + edge_attr = edge_attr.asnumpy() + edge_attr = edge_attr[edge_mask] + else: + None + + if relabel_nodes: + edge_index, _ = map_index_np( + edge_index.reshape(-1), + subset, + max_index=num_nodes, + inclusive=True, + ) + edge_index = edge_index.reshape(2, -1) + + if return_edge_mask == True: + return [edge_index, edge_attr, edge_mask] + else: + return edge_index, edge_attr + + +def bipartite_subgraph( + subset: Union[Tuple[Tensor, Tensor], Tuple[List[int], List[int]]], + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + relabel_nodes: bool = False, + size: Optional[Tuple[int, int]] = None, + return_edge_mask: bool = False, +) -> Union[ + Tuple[Tensor, Optional[Tensor]], + Tuple[Tensor, Optional[Tensor], Optional[Tensor]], +]: + r"""Returns the induced subgraph of the bipartite graph + :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`. + + Args: + subset (Tuple[Tensor, Tensor] or tuple([int],[int])): The nodes + to keep. + edge_index (LongTensor): The edge indices. + edge_attr (Tensor, optional): Edge weights or multi-dimensional + edge features. (default: :obj:`None`) + relabel_nodes (bool, optional): If set to :obj:`True`, the resulting + :obj:`edge_index` will be relabeled to hold consecutive indices + starting from zero. (default: :obj:`False`) + size (tuple, optional): The number of nodes. + (default: :obj:`None`) + return_edge_mask (bool, optional): If set to :obj:`True`, will return + the edge mask to filter out additional edge features. + (default: :obj:`False`) + + :rtype: (:class:`LongTensor`, :class:`Tensor`) + + Examples: + >>> edge_index = Tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6], + ... [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]]) + >>> edge_attr = Tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) + >>> subset = (Tensor([2, 3, 5]), Tensor + >>> bipartite_subgraph(subset, edge_index, edge_attr) + (tensor([[2, 3, 5, 5], + [3, 2, 2, 3]]), + tensor([ 3, 4, 9, 10])) + + >>> bipartite_subgraph(subset, edge_index, edge_attr, + ... return_edge_mask=True) + (tensor([[2, 3, 5, 5], + [3, 2, 2, 3]]), + tensor([ 3, 4, 9, 10]), + tensor([False, False, True, True, False, False, False, False, + True, True, False])) + """ + src_subset, dst_subset = subset + if not isinstance(src_subset, Tensor): + src_subset = Tensor(src_subset, dtype=ms.int64) + if not isinstance(dst_subset, Tensor): + dst_subset = Tensor(dst_subset, dtype=ms.int64) + + src, dst = edge_index + if src_subset.dtype != ms.bool_: + src_size = int(src.max()) + 1 if size is None else size[0] + src_node_mask = index_to_mask(src_subset, size=src_size) + else: + src_size = src_subset.shape[0] + src_node_mask = src_subset + src_subset = mint.nonzero(src_subset).view(-1) + + if dst_subset.dtype != ms.bool_: + dst_size = int(dst.max()) + 1 if size is None else size[1] + dst_node_mask = index_to_mask(dst_subset, size=dst_size) + else: + dst_size = dst_subset.shape[0] + dst_node_mask = dst_subset + dst_subset = mint.nonzero(dst_subset).view(-1) + + edge_mask = mint.logical_and(mint.index_select(src_node_mask, 0, src), mint.index_select(dst_node_mask, 0, dst)) + edge_index = (ops.masked_select(edge_index, edge_mask)).view(2, -1) + + if edge_attr is not None: + edge_attr = ops.masked_select(edge_attr, edge_mask) + else: + None + + if relabel_nodes: + src_index, _ = map_index(edge_index[0], src_subset, max_index=src_size, inclusive=True) + dst_index, _ = map_index(edge_index[1], dst_subset, max_index=dst_size, inclusive=True) + edge_index = mint.stack([src_index, dst_index], dim=0) + + if return_edge_mask: + return edge_index, edge_attr, edge_mask + else: + return edge_index, edge_attr + + +def k_hop_subgraph( + node_idx: Union[int, List[int], Tensor], + num_hops: int, + edge_index: Tensor, + relabel_nodes: bool = False, + num_nodes: Optional[int] = None, + flow: str = "src_to_dst", + directed: bool = False, +) -> Tuple[Tensor, Tensor, Tensor, Tensor]: + r"""Computes the induced subgraph of :obj:`edge_index` around all nodes in + :attr:`node_idx` reachable within :math:`k` hops. + + The :attr:`flow` argument denotes the direction of edges for finding + :math:`k`-hop neighbors. If set to :obj:`"src_to_dst"`, then the + method will find all neighbors that point to the initial set of seed nodes + in :attr:`node_idx.` + This mimics the natural flow of message passing in Graph Neural Networks. + + The method returns (1) the nodes involved in the subgraph, (2) the filtered + :obj:`edge_index` connectivity, (3) the mapping from node indices in + :obj:`node_idx` to their new location, and (4) the edge mask indicating + which edges were preserved. + + Args: + node_idx (int, list, tuple or :obj:`Tensor`): The central seed + node(s). + num_hops (int): The number of hops :math:`k`. + edge_index (LongTensor): The edge indices. + relabel_nodes (bool, optional): If set to :obj:`True`, the resulting + :obj:`edge_index` will be relabeled to hold consecutive indices + starting from zero. (default: :obj:`False`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max_val + 1` of :attr:`edge_index`. (default: :obj:`None`) + flow (str, optional): The flow direction of :math:`k`-hop aggregation + (:obj:`"src_to_trg"` or :obj:`"trg_to_src"`). + (default: :obj:`"src_to_trg"`) + directed (bool, optional): If set to :obj:`True`, will only include + directed edges to the seed nodes :obj:`node_idx`. + (default: :obj:`False`) + + :rtype: (:class:`Tensor`) + + Examples: + >>> edge_index = Tensor([[0, 1, 2, 3, 4, 5], + ... [2, 2, 4, 4, 6, 6]]) + + >>> # Center node 6, 2-hops + >>> subset, edge_index, mapping, edge_mask = k_hop_subgraph( + ... 6, 2, edge_index, relabel_nodes=True) + >>> subset + Tensor([2, 3, 4, 5, 6]) + >>> edge_index + Tensor([[0, 1, 2, 3], + [2, 2, 4, 4]]) + >>> mapping + Tensor([4]) + >>> edge_mask + Tensor([False, False, True, True, True, True]) + >>> subset[mapping] + Tensor([6]) + + >>> edge_index = Tensor([[1, 2, 4, 5], + ... [0, 1, 5, 6]]) + >>> (subset, edge_index, + ... mapping, edge_mask) = k_hop_subgraph([0, 6], 2, + ... edge_index, + ... relabel_nodes=True) + >>> subset + tensor([0, 1, 2, 4, 5, 6]) + >>> edge_index + tensor([[1, 2, 3, 4], + [0, 1, 4, 5]]) + >>> mapping + tensor([0, 5]) + >>> edge_mask + tensor([True, True, True, True]) + >>> subset[mapping] + tensor([0, 6]) + """ + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + assert flow in ["src_to_dst", "dst_to_src"] + if flow == "dst_to_src": + dst, src = edge_index + else: + src, dst = edge_index + + node_mask = mint.zeros(num_nodes, dtype=ms.bool_) + edge_mask = mint.zeros(dst.shape[0], dtype=ms.bool_) + + + if isinstance(node_idx, int): + node_idx = Tensor([node_idx]) + elif isinstance(node_idx, (list, tuple)): + node_idx = Tensor(node_idx) + subsets = [node_idx] + + for _ in range(num_hops): + + node_mask = node_mask.fill(False) + node_mask[subsets[-1]] = True + edge_mask = mint.index_select(node_mask, 0, dst) + subsets.append(ops.masked_select(src, edge_mask)) + + subset, inv = mint.unique(mint.cat(subsets), return_inverse=True) + inv = inv[:node_idx.numel()] + + node_mask = node_mask.fill(False) + node_mask[subset] = True + + if not directed: + edge_mask = mint.logical_and(mint.index_select(node_mask, 0, dst), mint.index_select(node_mask, 0, src)) + + edge_index = (ops.masked_select(edge_index, edge_mask)).view(2, -1) + + if relabel_nodes: + mapping = -mint.ones(num_nodes, dtype=dst.dtype) + mapping[subset] = mint.arange(subset.shape[0]) + edge_index = mapping[edge_index] + + return subset, edge_index, inv, edge_mask + + +def hyper_subgraph( + subset: Union[Tensor, List[int]], + edge_index: Tensor, + edge_attr: Optional[Tensor] = None, + relabel_nodes: bool = False, + num_nodes: Optional[int] = None, + return_edge_mask: bool = False, +) -> Union[ + Tuple[Tensor, Optional[Tensor]], + Tuple[Tensor, Optional[Tensor], Tensor], +]: + r"""Returns the induced subgraph of the hyper graph of + :obj:`(edge_index, edge_attr)` containing the nodes in :obj:`subset`. + + Args: + subset (Tensor or [int]): The nodes to keep. + edge_index (LongTensor): Hyperedge tensor + with shape :obj:`[2, num_edges*num_nodes_per_edge]`, where + :obj:`edge_index[1]` denotes the hyperedge index and + :obj:`edge_index[0]` denotes the node indices that are connected + by the hyperedge. + edge_attr (Tensor, optional): Edge weights or multi-dimensional + edge features of shape :obj:`[num_edges, *]`. + (default: :obj:`None`) + relabel_nodes (bool, optional): If set to :obj:`True`, the + resulting :obj:`edge_index` will be relabeled to hold + consecutive indices starting from zero. (default: :obj:`False`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max(edge_index[0]) + 1`. (default: :obj:`None`) + return_edge_mask (bool, optional): If set to :obj:`True`, will return + the edge mask to filter out additional edge features. + (default: :obj:`False`) + + :rtype: (:class:`LongTensor`, :class:`Tensor`) + + Examples: + >>> edge_index = Tensor([[0, 1, 2, 1, 2, 3, 0, 2, 3], + ... [0, 0, 0, 1, 1, 1, 2, 2, 2]]) + >>> edge_attr = Tensor([3, 2, 6]) + >>> subset = Tensor + >>> subgraph(subset, edge_index, edge_attr) + (tensor([[0, 3], + [0, 0]]), + tensor([ 6.])) + + >>> subgraph(subset, edge_index, edge_attr, return_edge_mask=True) + (tensor([[0, 3], + [0, 0]]), + tensor([ 6.])) + tensor([False, False, True]) + """ + if isinstance(subset, (list, tuple)): + subset = Tensor(subset, dtype=ms.int64) + + if subset.dtype != ms.bool_: + num_nodes = maybe_num_nodes(edge_index, num_nodes) + node_mask = index_to_mask(subset, size=num_nodes) + else: + num_nodes = subset.shape[0] + node_mask = subset + + src, dst = edge_index + # Mask all connections that contain a node not in the subset + hyper_edge_mask = mint.index_select(node_mask, 0, src) + + + # Mask hyperedges that contain one or less nodes from the subset + edge_mask = ops.unsorted_segment_sum( + hyper_edge_mask.astype(ms.int64), + dst, dst.max() + 1) > 1 + + # Mask connections if hyperedge contains one or less nodes from the subset + # or is connected to a node not in the subset + hyper_edge_mask = mint.logical_and(hyper_edge_mask, mint.index_select(edge_mask, 0, dst)) + + src = ops.masked_select(src, hyper_edge_mask) + dst = ops.masked_select(dst, hyper_edge_mask) + if edge_attr is not None: + edge_attr = ops.masked_select(edge_attr, edge_mask) + else: + None + + + # Relabel edges + edge_idx = mint.zeros(edge_mask.shape[0], dtype=ms.int64) + edge_mask_idx = ops.argwhere(edge_mask) + edge_mask_idx = edge_mask_idx.view(int(edge_mask.sum()),) + edge_idx[edge_mask_idx] = mint.arange(edge_mask.sum().item()) + src, dst = mint.stack([src, edge_idx[dst]]) + + if relabel_nodes: + node_idx = mint.zeros(node_mask.shape[0], dtype=ms.int64) + node_idx[subset] = mint.arange(node_mask.sum().item()) + src, dst = mint.stack([node_idx[src], dst]) + edge_index = mint.stack([src, dst]) + if return_edge_mask: + return edge_index, edge_attr, edge_mask + else: + return edge_index, edge_attr \ No newline at end of file diff --git a/mindscience/sharker/utils/to_dense_adj.py b/mindscience/sharker/utils/to_dense_adj.py new file mode 100644 index 000000000..84395f597 --- /dev/null +++ b/mindscience/sharker/utils/to_dense_adj.py @@ -0,0 +1,105 @@ +import mindspore as ms + +from typing import Optional +from mindspore import Tensor, ops, mint + +from .functions import cumsum +from . import scatter + + +def to_dense_adj( + edge_index: Tensor, + batch: Optional[Tensor] = None, + edge_attr: Optional[Tensor] = None, + max_num_nodes: Optional[int] = None, + batch_size: Optional[int] = None, +) -> Tensor: + r"""Converts batched sparse adjacency matrices given by edge indices and + edge attributes to a single dense batched adjacency matrix. + + Args: + edge_index (LongTensor): The edge indices. + batch (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. (default: :obj:`None`) + edge_attr (Tensor, optional): Edge weights or multi-dimensional edge + features. + If :obj:`edge_index` contains duplicated edges, the dense adjacency + matrix output holds the summed up entries of :obj:`edge_attr` for + duplicated edges. (default: :obj:`None`) + max_num_nodes (int, optional): The size of the output node dimension. + (default: :obj:`None`) + batch_size (int, optional): The batch size. (default: :obj:`None`) + + :rtype: :class:`Tensor` + + Examples: + >>> edge_index = Tensor([[0, 0, 1, 2, 3], + ... [0, 1, 0, 3, 0]]) + >>> batch = Tensor([0, 0, 1, 1]) + >>> to_dense_adj(edge_index, batch) + tensor([[[1., 1.], + [1., 0.]], + [[0., 1.], + [1., 0.]]]) + + >>> to_dense_adj(edge_index, batch, max_num_nodes=4) + tensor([[[1., 1., 0., 0.], + [1., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]], + [[0., 1., 0., 0.], + [1., 0., 0., 0.], + [0., 0., 0., 0.], + [0., 0., 0., 0.]]]) + + >>> edge_attr = Tensor([1.0, 2.0, 3.0, 4.0, 5.0]) + >>> to_dense_adj(edge_index, batch, edge_attr) + tensor([[[1., 2.], + [3., 0.]], + [[0., 4.], + [5., 0.]]]) + """ + edge_index = edge_index.astype("int32") + if batch is None: + max_index = int(edge_index.max()) + 1 if edge_index.numel() > 0 else 0 + # batch = edge_index.new_zeros(max_index, dtype=ms.int32) + batch = ops.zeros(max_index, edge_index.dtype) + + if batch_size is None: + batch_size = int(batch.max()) + 1 if batch.numel() > 0 else 1 + + one = batch.new_ones(batch.shape[0]) + num_nodes = scatter(one, batch, dim=0, dim_size=batch_size, reduce="sum") + cum_nodes = cumsum(num_nodes) + + idx0 = batch[edge_index[0]] + idx1 = edge_index[0] - cum_nodes[batch][edge_index[0]] + idx2 = edge_index[1] - cum_nodes[batch][edge_index[1]] + + if max_num_nodes is None: + max_num_nodes = int(num_nodes.max()) + + elif (idx1.numel() > 0 and idx1.max() >= max_num_nodes) or ( + idx2.numel() > 0 and idx2.max() >= max_num_nodes + ): + mask = mint.logical_and((idx1 < max_num_nodes), (idx2 < max_num_nodes)) + + idx0 = ops.masked_select(idx0, mask) + idx1 = ops.masked_select(idx1, mask) + idx2 = ops.masked_select(idx2, mask) + edge_attr = None if edge_attr is None else ops.masked_select(edge_attr, mask) + + + if edge_attr is None: + edge_attr = mint.ones(idx0.numel()) + + size = [batch_size, max_num_nodes, max_num_nodes] + size += list(edge_attr.shape)[1:] + flattened_size = batch_size * max_num_nodes * max_num_nodes + + idx = idx0 * max_num_nodes * max_num_nodes + idx1 * max_num_nodes + idx2 + adj = scatter(edge_attr, idx, dim=0, dim_size=flattened_size, reduce="sum") + adj = adj.reshape(size) + + return adj diff --git a/mindscience/sharker/utils/to_dense_batch.py b/mindscience/sharker/utils/to_dense_batch.py new file mode 100644 index 000000000..3c5646f69 --- /dev/null +++ b/mindscience/sharker/utils/to_dense_batch.py @@ -0,0 +1,136 @@ +from typing import Optional, Tuple + +import mindspore as ms +from mindspore import Tensor, ops, mint + +from .functions import cumsum +from . import scatter +from ..experimental import disable_dynamic_shapes, is_experimental_mode_enabled + + +@disable_dynamic_shapes(required_args=['batch_size', 'max_num_nodes']) +def to_dense_batch( + x: Tensor, + batch: Optional[Tensor] = None, + fill_value: float = 0.0, + max_num_nodes: Optional[int] = None, + batch_size: Optional[int] = None, +) -> Tuple[Tensor, Tensor]: + r"""Given a sparse batch of node features + :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}` (with + :math:`N_i` indicating the number of nodes in graph :math:`i`), creates a + dense node feature tensor + :math:`\mathbf{X} \in \mathbb{R}^{B \times N_{\max} \times F}` (with + :math:`N_{\max} = \max_i^B N_i`). + In addition, a mask of shape :math:`\mathbf{M} \in \{ 0, 1 \}^{B \times + N_{\max}}` is returned, holding information about the existence of + fake-nodes in the dense representation. + + Args: + x (Tensor): Node feature matrix + :math:`\mathbf{X} \in \mathbb{R}^{(N_1 + \ldots + N_B) \times F}`. + batch (LongTensor, optional): Batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. Must be ordered. (default: :obj:`None`) + fill_value (float, optional): The value for invalid entries in the + resulting dense output tensor. (default: :obj:`0`) + max_num_nodes (int, optional): The size of the output node dimension. + (default: :obj:`None`) + batch_size (int, optional): The batch size. (default: :obj:`None`) + + :rtype: (:class:`Tensor`, :class:`BoolTensor`) + + Examples: + >>> x = ops.arange(12).view(6, 2) + >>> x + tensor([[ 0, 1], + [ 2, 3], + [ 4, 5], + [ 6, 7], + [ 8, 9], + [10, 11]]) + + >>> out, mask = to_dense_batch(x) + >>> mask + tensor([[True, True, True, True, True, True]]) + + >>> batch = Tensor, 1, 2, 2, 2]) + >>> out, mask = to_dense_batch(x, batch) + >>> out + tensor([[[ 0, 1], + [ 2, 3], + [ 0, 0]], + [[ 4, 5], + [ 0, 0], + [ 0, 0]], + [[ 6, 7], + [ 8, 9], + [10, 11]]]) + >>> mask + tensor([[ True, True, False], + [ True, False, False], + [ True, True, True]]) + + >>> out, mask = to_dense_batch(x, batch, max_num_nodes=4) + >>> out + tensor([[[ 0, 1], + [ 2, 3], + [ 0, 0], + [ 0, 0]], + [[ 4, 5], + [ 0, 0], + [ 0, 0], + [ 0, 0]], + [[ 6, 7], + [ 8, 9], + [10, 11], + [ 0, 0]]]) + + >>> mask + tensor([[ True, True, False, False], + [ True, False, False, False], + [ True, True, True, False]]) + """ + if batch is None and max_num_nodes is None: + mask = mint.ones([1, x.shape[0]]).bool() + return x.unsqueeze(0), mask + + if batch is None: + batch = x.new_zeros(x.shape[0]).long() + + if batch_size is None: + batch_size = int(batch.max()) + 1 + + num_nodes = scatter( + ops.ones(x.shape[0], dtype=batch.dtype), batch, dim=0, dim_size=batch_size, reduce="sum" + ) + + cum_nodes = cumsum(num_nodes) + + filter_nodes = False + dynamic_shapes_disabled = is_experimental_mode_enabled( + 'disable_dynamic_shapes') + + if max_num_nodes is None: + max_num_nodes = int(num_nodes.max()) + elif not dynamic_shapes_disabled and num_nodes.max() > max_num_nodes: + filter_nodes = True + + tmp = mint.arange(batch.shape[0]) - cum_nodes[batch] + idx = tmp + (batch * max_num_nodes) + idx = idx.astype(ms.int64) + if filter_nodes: + mask = tmp < max_num_nodes + x, idx = x[mask], idx[mask] + + size = (batch_size * max_num_nodes, ) + x.shape[1:] + out = ms.Tensor(fill_value) + out = out.astype(x.dtype).tile(size) + out[idx] = x + out = out.reshape((batch_size, max_num_nodes) + x.shape[1:]) + + mask = mint.zeros(batch_size * max_num_nodes) + mask[idx] = 1 + mask = mask.view(batch_size, max_num_nodes) + + return out, mask diff --git a/mindscience/sharker/utils/tree_decomposition.py b/mindscience/sharker/utils/tree_decomposition.py new file mode 100644 index 000000000..0214fabf0 --- /dev/null +++ b/mindscience/sharker/utils/tree_decomposition.py @@ -0,0 +1,128 @@ +from itertools import chain +from typing import Any, List, Tuple, Union + +from scipy.sparse.csgraph import minimum_spanning_tree +import mindspore as ms +from mindspore import Tensor, ops, mint + +from .convert import from_scipy_sparse_matrix, to_scipy_sparse_matrix +from .undirected import to_undirected + + +def tree_decomposition( + mol: Any, + return_vocab: bool = False, +) -> Union[Tuple[Tensor, Tensor, int], Tuple[Tensor, Tensor, int, Tensor]]: + r"""The tree decomposition algorithm of molecules from the + `"Junction Tree Variational Autoencoder for Molecular Graph Generation" + `_ paper. + Returns the graph connectivity of the junction tree, the assignment + mapping of each atom to the clique in the junction tree, and the number + of cliques. + + Args: + mol (rdkit.Chem.Mol): An :obj:`rdkit` molecule. + return_vocab (bool, optional): If set to :obj:`True`, will return an + identifier for each clique (ring, bond, bridged compounds, single). + (default: :obj:`False`) + + :rtype: :obj:`(LongTensor, LongTensor, int)` if :obj:`return_vocab` is + :obj:`False`, else :obj:`(LongTensor, LongTensor, int, LongTensor)` + """ + import rdkit.Chem as Chem + + # Cliques = rings and bonds. + cliques: List[List[int]] = [list(x) for x in Chem.GetSymmSSSR(mol)] + xs: List[int] = [0] * len(cliques) + for bond in mol.GetBonds(): + if not bond.IsInRing(): + cliques.append([bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()]) + xs.append(1) + + # Generate `atom2cliques` mappings. + atom2cliques: List[List[int]] = [[] for i in range(mol.GetNumAtoms())] + for c in range(len(cliques)): + for atom in cliques[c]: + atom2cliques[atom].append(c) + + # Merge rings that share more than 2 atoms as they form bridged compounds. + for c1 in range(len(cliques)): + for atom in cliques[c1]: + for c2 in atom2cliques[atom]: + if c1 >= c2 or len(cliques[c1]) <= 2 or len(cliques[c2]) <= 2: + continue + if len(set(cliques[c1]) & set(cliques[c2])) > 2: + cliques[c1] = list(set(cliques[c1]) | set(cliques[c2])) + xs[c1] = 2 + cliques[c2] = [] + xs[c2] = -1 + cliques = [c for c in cliques if len(c) > 0] + xs = [x for x in xs if x >= 0] + + # Update `atom2cliques` mappings. + atom2cliques = [[] for i in range(mol.GetNumAtoms())] + for c in range(len(cliques)): + for atom in cliques[c]: + atom2cliques[atom].append(c) + + # Add singleton cliques in case there are more than 2 intersecting + # cliques. We further compute the "initial" clique graph. + edges = {} + for atom in range(mol.GetNumAtoms()): + cs = atom2cliques[atom] + if len(cs) <= 1: + continue + + # Number of bond clusters that the atom lies in. + bonds = [c for c in cs if len(cliques[c]) == 2] + # Number of ring clusters that the atom lies in. + rings = [c for c in cs if len(cliques[c]) > 4] + + if len(bonds) > 2 or (len(bonds) == 2 and len(cs) > 2): + cliques.append([atom]) + xs.append(3) + c2 = len(cliques) - 1 + for c1 in cs: + edges[(c1, c2)] = 1 + + elif len(rings) > 2: + cliques.append([atom]) + xs.append(3) + c2 = len(cliques) - 1 + for c1 in cs: + edges[(c1, c2)] = 99 + + else: + for i in range(len(cs)): + for j in range(i + 1, len(cs)): + c1, c2 = cs[i], cs[j] + count = len(set(cliques[c1]) & set(cliques[c2])) + edges[(c1, c2)] = min(count, edges.get((c1, c2), 99)) + + # Update `atom2cliques` mappings. + atom2cliques = [[] for i in range(mol.GetNumAtoms())] + for c in range(len(cliques)): + for atom in cliques[c]: + atom2cliques[atom].append(c) + + if len(edges) > 0: + edge_index_T, weight = zip(*edges.items()) + edge_index = ms.Tensor(edge_index_T).T + inv_weight = 100 - ms.Tensor(weight) + graph = to_scipy_sparse_matrix(edge_index, inv_weight, len(cliques)) + junc_tree = minimum_spanning_tree(graph) + edge_index, _ = from_scipy_sparse_matrix(junc_tree) + edge_index = to_undirected(edge_index, num_nodes=len(cliques)) + else: + edge_index = ops.Tensor((2, 0)).long() + + rows = [[i] * len(atom2cliques[i]) for i in range(mol.GetNumAtoms())] + row = ms.Tensor(list(chain.from_iterable(rows))) + col = ms.Tensor(list(chain.from_iterable(atom2cliques))) + atom2clique = mint.stack([row, col], dim=0).long() + + if return_vocab: + vocab = ms.Tensor(xs).long() + return edge_index, atom2clique, len(cliques), vocab + else: + return edge_index, atom2clique, len(cliques) diff --git a/mindscience/sharker/utils/trim_to_layer.py b/mindscience/sharker/utils/trim_to_layer.py new file mode 100644 index 000000000..339c00a6a --- /dev/null +++ b/mindscience/sharker/utils/trim_to_layer.py @@ -0,0 +1,157 @@ +from typing import Dict, List, Optional, Tuple, Union +from mindspore import Tensor, nn + + +def trim_to_layer( + layer: int, + num_sampled_nodes_per_hop: Union[List[int], Dict[str, List[int]]], + num_sampled_edges_per_hop: Union[List[int], Dict[Tuple[str, str, str], List[int]]], + x: Union[Tensor, Dict[str, Tensor]], + edge_index: Union[Tensor, Dict[Tuple[str, str, str], Tensor]], + edge_attr: Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]] = None, +) -> Tuple[ + Union[Tensor, Dict[str, Tensor]], + Union[Tensor, Dict[Tuple[str, str, str], Union[Tensor,]]], + Optional[Union[Tensor, Dict[Tuple[str, str, str], Tensor]]], +]: + r"""Trims the :obj:`edge_index` representation, node features :obj:`x` and + edge features :obj:`edge_attr` to a minimal-sized representation for the + current GNN layer :obj:`layer` in directed + :class:`~sharker.loader.NeighborLoader` scenarios. + + This ensures that no computation is performed for nodes and edges that are + not included in the current GNN layer, thus avoiding unnecessary + computation within the GNN when performing neighborhood sampling. + + Args: + layer (int): The current GNN layer. + num_sampled_nodes_per_hop (List[int] or Dict[str, List[int]]): The + number of sampled nodes per hop. + num_sampled_edges_per_hop (List[int] or Dict[Tuple[str, str, str], List[int]]): The + number of sampled edges per hop. + x (Tensor or Dict[str, Tensor]): The homogeneous or + heterogeneous (hidden) node features. + edge_index (Tensor or Dict[Tuple[str, str, str], Tensor]): The + homogeneous or heterogeneous edge indices. + edge_attr (Tensor or Dict[Tuple[str, str, str], Tensor], optional): The + homogeneous or heterogeneous (hidden) edge features. + """ + if layer <= 0: + return x, edge_index, edge_attr + + if isinstance(num_sampled_edges_per_hop, dict): + assert isinstance(num_sampled_nodes_per_hop, dict) + + assert isinstance(x, dict) + x = {k: trim_feat(v, layer, num_sampled_nodes_per_hop[k]) for k, v in x.items()} + + assert isinstance(edge_index, dict) + edge_index = { + k: trim_adj( + v, + layer, + num_sampled_nodes_per_hop[k[0]], + num_sampled_nodes_per_hop[k[-1]], + num_sampled_edges_per_hop[k], + ) + for k, v in edge_index.items() + } + + if edge_attr is not None: + assert isinstance(edge_attr, dict) + edge_attr = { + k: trim_feat(v, layer, num_sampled_edges_per_hop[k]) + for k, v in edge_attr.items() + } + + return x, edge_index, edge_attr + + assert isinstance(num_sampled_nodes_per_hop, list) + + assert isinstance(x, Tensor) + x = trim_feat(x, layer, num_sampled_nodes_per_hop) + + assert isinstance(edge_index, Tensor) + edge_index = trim_adj( + edge_index, + layer, + num_sampled_nodes_per_hop, + num_sampled_nodes_per_hop, + num_sampled_edges_per_hop, + ) + + if edge_attr is not None: + assert isinstance(edge_attr, Tensor) + edge_attr = trim_feat(edge_attr, layer, num_sampled_edges_per_hop) + + return x, edge_index, edge_attr + + +class TrimToLayer(nn.Cell): + def construct( + self, + layer: int, + num_sampled_nodes_per_hop: Optional[List[int]], + num_sampled_edges_per_hop: Optional[List[int]], + x: Tensor, + edge_index: Union[Tensor, ], + edge_attr: Optional[Tensor] = None, + ) -> Tuple[Tensor, Union[Tensor,], Optional[Tensor]]: + + if not isinstance(num_sampled_nodes_per_hop, list) and isinstance( + num_sampled_edges_per_hop, list + ): + raise ValueError("'num_sampled_nodes_per_hop' needs to be given") + if not isinstance(num_sampled_edges_per_hop, list) and isinstance( + num_sampled_nodes_per_hop, list + ): + raise ValueError("'num_sampled_edges_per_hop' needs to be given") + + if num_sampled_nodes_per_hop is None: + return x, edge_index, edge_attr + if num_sampled_edges_per_hop is None: + return x, edge_index, edge_attr + + return trim_to_layer( + layer, + num_sampled_nodes_per_hop, + num_sampled_edges_per_hop, + x, + edge_index, + edge_attr, + ) + + +# Helper functions ############################################################ + + +def trim_feat(x: Tensor, layer: int, num_samples_per_hop: List[int]) -> Tensor: + if layer <= 0: + return x + + return x.narrow( + axis=0, + start=0, + length=x.shape[0] - num_samples_per_hop[-layer], + ) + + +def trim_adj( + edge_index: Union[Tensor, ], + layer: int, + num_sampled_src_nodes_per_hop: List[int], + num_sampled_dst_nodes_per_hop: List[int], + num_sampled_edges_per_hop: List[int], +) -> Union[Tensor,]: + + if layer <= 0: + return edge_index + + if isinstance(edge_index, Tensor): + edge_index = edge_index.narrow( + axis=1, + start=0, + length=edge_index.shape[1] - num_sampled_edges_per_hop[-layer], + ) + return edge_index + raise ValueError(f"Unsupported 'edge_index' type '{type(edge_index)}'") diff --git a/mindscience/sharker/utils/unbatch.py b/mindscience/sharker/utils/unbatch.py new file mode 100644 index 000000000..66fa3fbe4 --- /dev/null +++ b/mindscience/sharker/utils/unbatch.py @@ -0,0 +1,70 @@ +from typing import List, Optional +from mindspore import Tensor + +from .functions import cumsum +from .degree import degree + + +def unbatch( + src: Tensor, + batch: Tensor, + axis: int = 0, + batch_size: Optional[int] = None, +) -> List[Tensor]: + r"""Splits :obj:`src` according to a :obj:`batch` vector along dimension + :obj:`dim`. + + Args: + src (Tensor): The source tensor. + batch (LongTensor): The batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + entry in :obj:`src` to a specific example. Must be ordered. + axis (int, optional): The dimension along which to split the :obj:`src` + tensor. (default: :obj:`0`) + batch_size (int, optional): The batch size. (default: :obj:`None`) + + :rtype: :class:`List[Tensor]` + + Example: + >>> src = ops.arange(7) + >>> batch = Tensor([0, 0, 0, 1, 1, 2, 2]) + >>> unbatch(src, batch) + (tensor([0, 1, 2]), tensor([3, 4]), tensor([5, 6])) + """ + sizes = degree(batch, batch_size).long() + return src.split(sizes.tolist(), axis) + + +def unbatch_edge_index( + edge_index: Tensor, + batch: Tensor, + batch_size: Optional[int] = None, +) -> List[Tensor]: + r"""Splits the :obj:`edge_index` according to a :obj:`batch` vector. + + Args: + edge_index (Tensor): The edge_index tensor. Must be ordered. + batch (LongTensor): The batch vector + :math:`\mathbf{b} \in {\{ 0, \ldots, B-1\}}^N`, which assigns each + node to a specific example. Must be ordered. + batch_size (int, optional): The batch size. (default: :obj:`None`) + + :rtype: :class:`List[Tensor]` + + Example: + >>> edge_index = Tensor([[0, 1, 1, 2, 2, 3, 4, 5, 5, 6], + ... [1, 0, 2, 1, 3, 2, 5, 4, 6, 5]]) + >>> batch = Tensor([0, 0, 0, 0, 1, 1, 1]) + >>> unbatch_edge_index(edge_index, batch) + (tensor([[0, 1, 1, 2, 2, 3], + [1, 0, 2, 1, 3, 2]]), + tensor([[0, 1, 1, 2], + [1, 0, 2, 1]])) + """ + deg = degree(batch, batch_size).long() + ptr = cumsum(deg) + + edge_batch = batch[edge_index[0]] + edge_index = edge_index - ptr[edge_batch] + sizes = degree(edge_batch, batch_size).long() + return edge_index.split(sizes.tolist(), axis=1) diff --git a/mindscience/sharker/utils/undirected.py b/mindscience/sharker/utils/undirected.py new file mode 100644 index 000000000..d0502afe1 --- /dev/null +++ b/mindscience/sharker/utils/undirected.py @@ -0,0 +1,142 @@ +from typing import List, Optional, Tuple, Union +from mindspore import Tensor, ops, mint +from .coalesce import coalesce +from .sort_edge_index import sort_edge_index +from .num_nodes import maybe_num_nodes + +MISSING = "???" + + +def is_undirected( + edge_index: Tensor, + edge_attr: Union[Optional[Tensor], List[Tensor]] = None, + num_nodes: Optional[int] = None, +) -> bool: + r"""Returns :obj:`True` if the graph given by :attr:`edge_index` is + undirected. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor or List[Tensor], optional): Edge weights or multi- + dimensional edge features. + If given as a list, will check for equivalence in all its entries. + (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max(edge_index) + 1`. (default: :obj:`None`) + + :rtype: bool + + Examples: + >>> edge_index = Tensor([[0, 1, 0], + ... [1, 0, 0]]) + >>> weight = Tensor([0, 0, 1]) + >>> is_undirected(edge_index, weight) + True + + >>> weight = Tensor([0, 1, 1]) + >>> is_undirected(edge_index, weight) + False + + """ + num_nodes = maybe_num_nodes(edge_index, num_nodes) + + edge_attrs: List[Tensor] = [] + if isinstance(edge_attr, Tensor): + edge_attrs.append(edge_attr) + elif isinstance(edge_attr, (list, tuple)): + edge_attrs = edge_attr + + edge_index1, edge_attrs1 = sort_edge_index( + edge_index, + edge_attrs, + num_nodes=num_nodes, + sort_by_row=True, + ) + edge_index2, edge_attrs2 = sort_edge_index( + edge_index, + edge_attrs, + num_nodes=num_nodes, + sort_by_row=False, + ) + + if not ops.equal(edge_index1[0], edge_index2[1]).all(): + return False + + if not ops.equal(edge_index1[1], edge_index2[0]).all(): + return False + + assert isinstance(edge_attrs1, list) and isinstance(edge_attrs2, list) + for edge_attr1, edge_attr2 in zip(edge_attrs1, edge_attrs2): + if not ops.equal(edge_attr1, edge_attr2).all(): + return False + + return True + + +def to_undirected( # noqa: F811 + edge_index: Tensor, + edge_attr: Union[Optional[Tensor], List[Tensor], str] = MISSING, + num_nodes: Optional[int] = None, + reduce: str = "add", +) -> Union[Tensor, Tuple[Tensor, Optional[Tensor]], Tuple[Tensor, List[Tensor]]]: + r"""Converts the graph given by :attr:`edge_index` to an undirected graph + such that :math:`(j,i) \in \mathcal{E}` for every edge :math:`(i,j) \in + \mathcal{E}`. + + Args: + edge_index (LongTensor): The edge indices. + edge_attr (Tensor or List[Tensor], optional): Edge weights or multi- + dimensional edge features. + If given as a list, will remove duplicates for all its entries. + (default: :obj:`None`) + num_nodes (int, optional): The number of nodes, *i.e.* + :obj:`max(edge_index) + 1`. (default: :obj:`None`) + reduce (str, optional): The reduce operation to use for merging edge + features (:obj:`"add"`, :obj:`"mean"`, :obj:`"min"`, :obj:`"max"`, + :obj:`"mul"`). (default: :obj:`"add"`) + + :rtype: :class:`LongTensor` if :attr:`edge_attr` is not passed, else + (:class:`LongTensor`, :obj:`Optional[Tensor]` or :obj:`List[Tensor]]`) + + .. warning:: + + From :pyg:`PyG >= 2.3.0` onwards, this function will always return a + tuple whenever :obj:`edge_attr` is passed as an argument (even in case + it is set to :obj:`None`). + + Examples: + >>> edge_index = Tensor([[0, 1, 1], + ... [1, 0, 2]]) + >>> to_undirected(edge_index) + tensor([[0, 1, 1, 2], + [1, 0, 2, 1]]) + + >>> edge_index = Tensor1, 1], + ... [1, 0, 2]]) + >>> edge_weight = Tensor([1., 1., 1.]) + >>> to_undirected(edge_index, edge_weight) + (tensor([[0, 1, 1, 2], + [1, 0, 2, 1]]), + tensor([2., 2., 1., 1.])) + + >>> # Use 'mean' operation to merge edge features + >>> to_undirected(edge_index, edge_weight, reduce='mean') + (tensor([[0, 1, 1, 2], + [1, 0, 2, 1]]), + tensor([1., 1., 1., 1.])) + """ + # Maintain backward compatibility to `to_undirected(edge_index, num_nodes)` + if isinstance(edge_attr, int): + num_nodes = edge_attr + edge_attr = MISSING + + src, dst = edge_index[0], edge_index[1] + src, dst = mint.cat([src, dst], dim=0), mint.cat([dst, src], dim=0) + edge_index = mint.stack([src, dst], dim=0) + + if isinstance(edge_attr, Tensor): + edge_attr = mint.cat([edge_attr, edge_attr], dim=0) + elif isinstance(edge_attr, (list, tuple)): + edge_attr = [mint.cat([e, e], dim=0) for e in edge_attr] + + return coalesce(edge_index, edge_attr, num_nodes, reduce) diff --git a/tests/graph/__init__.py b/tests/graph/__init__.py index 83b15297d..0ac945f97 100644 --- a/tests/graph/__init__.py +++ b/tests/graph/__init__.py @@ -12,3 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +from . import utils diff --git a/tests/graph/cluster/test_fps.py b/tests/graph/cluster/test_fps.py new file mode 100644 index 000000000..26e1ea19e --- /dev/null +++ b/tests/graph/cluster/test_fps.py @@ -0,0 +1,71 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse.cluster import fps +from mindscience.sharker.sparse.testing import grad_dtypes +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('dtype', grad_dtypes) +def test_fps(dtype): + x = Tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + [-2, -2], + [-2, +2], + [+2, +2], + [+2, -2], + ], dtype) + batch = Tensor([0, 0, 0, 0, 1, 1, 1, 1], ms.int64) + ptr_list = [0, 4, 8] + ptr = Tensor(ptr_list) + + out = fps(x, batch, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + out = fps(x, batch, ratio=0.5, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + ratio = Tensor(0.5) + out = fps(x, batch, ratio=ratio, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + out = fps(x, ptr=ptr_list, ratio=0.5, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + out = fps(x, ptr=ptr, ratio=0.5, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + ratio = Tensor([0.5, 0.5]) + out = fps(x, batch, ratio=ratio, random_start=False) + assert out.tolist() == [0, 2, 4, 6] + + out = fps(x, random_start=False) + assert out.sort()[0].tolist() == [0, 5, 6, 7] + + out = fps(x, ratio=0.5, random_start=False) + assert out.sort()[0].tolist() == [0, 5, 6, 7] + + out = fps(x, ratio=Tensor(0.5), random_start=False) + assert out.sort()[0].tolist() == [0, 5, 6, 7] + + out = fps(x, ratio=Tensor([0.5]), random_start=False) + assert out.sort()[0].tolist() == [0, 5, 6, 7] + + if is_full_test(): + fps2 = ms.jit(fps) + out = fps2(x, None, Tensor([0.5]), False) + assert out.sort()[0].tolist() == [0, 5, 6, 7] + + +def test_random_fps(): + N = 1024 + for _ in range(5): + pos = ops.randn((2 * N, 3)) + batch_1 = ops.zeros(N, dtype=ms.int64) + batch_2 = ops.ones(N, dtype=ms.int64) + batch = ops.cat([batch_1, batch_2]) + idx = fps(pos, batch, ratio=0.5) + assert idx.min() >= 0 and idx.max() < 2 * N diff --git a/tests/graph/cluster/test_graclus.py b/tests/graph/cluster/test_graclus.py new file mode 100644 index 000000000..f877062ae --- /dev/null +++ b/tests/graph/cluster/test_graclus.py @@ -0,0 +1,56 @@ +from itertools import product + +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse.cluster import graclus_cluster +from mindscience.sharker.sparse.testing import dtypes +from mindscience.sharker.sparse import ind2ptr +from mindscience.sharker.testing import is_full_test + +tests = [{ + 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], + 'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2], +}, { + 'row': [0, 0, 1, 1, 1, 2, 2, 2, 3, 3], + 'col': [1, 2, 0, 2, 3, 0, 1, 3, 1, 2], + 'weight': [1, 2, 1, 3, 2, 2, 3, 1, 2, 1], +}] + + +def assert_correct(row, col, cluster): + n = cluster.shape[0] + + # Every node was assigned a cluster. + assert cluster.min() >= 0 + + # There are no more than two nodes in each cluster. + _, index = ms.numpy.unique(cluster, return_inverse=True) + count = ops.zeros_like(cluster) + count = ops.tensor_scatter_elements(count, index, ops.ones_like(cluster), 0, reduction='add') + assert (count > 2).max() == 0 + + # Cluster value is minimal. + assert (cluster <= ops.arange(n, dtype=cluster.dtype)).sum() == n + + # Corresponding clusters must be adjacent. + for i in range(n): + x = cluster[col[row == i]] == cluster[i] # Neighbors with same cluster + y = cluster == cluster[i] # Nodes with same cluster. + y[i] = 0 # Do not look at cluster of `i`. + assert x.sum() == y.sum() + + +@pytest.mark.parametrize('test,dtype', product(tests, dtypes)) +def test_graclus_cluster(test, dtype): + row = Tensor(test['row'], ms.int64) + col = Tensor(test['col'], ms.int64) + weight = Tensor(test.get('weight'), dtype) if 'weight' in test else None + + cluster = graclus_cluster(ind2ptr(row), col, weight) + assert_correct(row, col, cluster) + + if is_full_test(): + jit = ms.jit(graclus_cluster) + cluster = jit(row, col, weight) + assert_correct(row, col, cluster) diff --git a/tests/graph/cluster/test_grid.py b/tests/graph/cluster/test_grid.py new file mode 100644 index 000000000..ea8c167cc --- /dev/null +++ b/tests/graph/cluster/test_grid.py @@ -0,0 +1,43 @@ +from itertools import product + +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.sparse.cluster import grid_cluster +from mindscience.sharker.sparse. testing import dtypes, Tensor +from mindscience.sharker.testing import is_full_test + +tests = [{ + 'crd': [2, 6], + 'size': [5], + 'cluster': [0, 0], +}, { + 'crd': [2, 6], + 'size': [5], + 'start': [0], + 'cluster': [0, 1], +}, { + 'crd': [[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]], + 'size': [5, 5], + 'cluster': [0, 5, 3, 0, 1], +}, { + 'crd': [[0, 0], [11, 9], [2, 8], [2, 2], [8, 3]], + 'size': [5, 5], + 'end': [19, 19], + 'cluster': [0, 6, 4, 0, 1], +}] + + +@pytest.mark.parametrize('test,dtype', product(tests, dtypes)) +def test_grid_cluster(test, dtype): + pos = Tensor(test['crd'], dtype) + size = Tensor(test['size'], dtype) + start = Tensor(test.get('start'), dtype) if 'start' in test else None + end = Tensor(test.get('end'), dtype) if 'end' in test else None + + cluster = grid_cluster(pos, size, start=start, end=end) + assert cluster.tolist() == test['cluster'] + + if is_full_test(): + jit = ms.jit(grid_cluster) + assert ops.equal(jit(pos, size, start, end), cluster).all() diff --git a/tests/graph/cluster/test_knn.py b/tests/graph/cluster/test_knn.py new file mode 100644 index 000000000..ac6149fe8 --- /dev/null +++ b/tests/graph/cluster/test_knn.py @@ -0,0 +1,88 @@ +import pytest +import scipy.spatial +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse.cluster import knn, knn_graph +from mindscience.sharker.sparse.testing import grad_dtypes +from mindscience.sharker.testing import is_full_test + + +def to_set(edge_index): + return set([(i, j) for i, j in edge_index.t().tolist()]) + + +@pytest.mark.parametrize('dtype', grad_dtypes) +def test_knn(dtype): + x = Tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + ], dtype) + y = Tensor([ + [1, 0], + [-1, 0], + ], dtype) + + batch_x = Tensor([0, 0, 0, 0, 1, 1, 1, 1], ms.int64) + batch_y = Tensor([0, 1], ms.int64) + + edge_index = knn(x, y, 2) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + + if is_full_test(): + jit = ms.jit(knn) + edge_index = jit(x, y, 2) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + + edge_index = knn(x, y, 2, batch_x, batch_y) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + + edge_index = knn(x, y, 2, batch_x, batch_y, cosine=True) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + + # Skipping a batch + batch_x = Tensor([0, 0, 0, 0, 2, 2, 2, 2], ms.int64) + batch_y = Tensor([0, 2], ms.int64) + edge_index = knn(x, y, 2, batch_x, batch_y) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + + +@pytest.mark.parametrize('dtype', grad_dtypes) +def test_knn_graph(dtype): + x = Tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + ], dtype) + + edge_index = knn_graph(x, k=2, flow='trg_to_src') + assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), + (2, 3), (3, 0), (3, 2)]) + + edge_index = knn_graph(x, k=2, flow='src_to_trg') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + if is_full_test(): + jit = ms.jit(knn_graph) + edge_index = jit(x, k=2, flow='src_to_trg') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + + +@pytest.mark.parametrize('dtype', [ms.single]) +def test_knn_graph_large(dtype): + x = ops.randn(1000, 3, dtype=dtype) + + edge_index = knn_graph(x, k=5, flow='trg_to_src', loop=True) + + tree = scipy.spatial.cKDTree(x.asnumpy()) + _, col = tree.query(x, k=5) + truth = set([(i, j) for i, ns in enumerate(col) for j in ns]) + + assert to_set(edge_index) == truth diff --git a/tests/graph/cluster/test_nearest.py b/tests/graph/cluster/test_nearest.py new file mode 100644 index 000000000..fab53ec6e --- /dev/null +++ b/tests/graph/cluster/test_nearest.py @@ -0,0 +1,63 @@ +import pytest +import mindspore as ms +from mindspore import Tensor +from mindscience.sharker.sparse.cluster import nearest +from mindscience.sharker.sparse. testing import grad_dtypes + + +@pytest.mark.parametrize('dtype', grad_dtypes) +def test_nearest(dtype): + x = Tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + [-2, -2], + [-2, +2], + [+2, +2], + [+2, -2], + ], dtype) + y = Tensor([ + [-1, 0], + [+1, 0], + [-2, 0], + [+2, 0], + ], dtype) + + batch_x = Tensor([0, 0, 0, 0, 1, 1, 1, 1], ms.int64) + batch_y = Tensor([0, 0, 1, 1], ms.int64) + + out = nearest(x, y, batch_x, batch_y) + assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] + + out = nearest(x, y) + assert out.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] + + # Invalid input: instance 1 only in batch_x + batch_x = Tensor([0, 0, 0, 0, 1, 1, 1, 1], ms.int64) + batch_y = Tensor([0, 0, 0, 0], ms.int64) + with pytest.raises(ValueError): + nearest(x, y, batch_x, batch_y) + + # Invalid input: instance 1 only in batch_x (implicitly as batch_y=None) + with pytest.raises(ValueError): + nearest(x, y, batch_x, batch_y=None) + + # Invalid input: instance 2 only in batch_x + # (i.e.instance in the middle missing) + batch_x = Tensor([0, 0, 1, 1, 2, 2, 3, 3], ms.int64) + batch_y = Tensor([0, 1, 3, 3], ms.int64) + with pytest.raises(ValueError): + nearest(x, y, batch_x, batch_y) + + # Invalid input: batch_x unsorted + batch_x = Tensor([0, 0, 1, 0, 0, 0, 0], ms.int64) + batch_y = Tensor([0, 0, 1, 1], ms.int64) + with pytest.raises(ValueError): + nearest(x, y, batch_x, batch_y) + + # Invalid input: batch_y unsorted + batch_x = Tensor([0, 0, 0, 0, 1, 1, 1, 1], ms.int64) + batch_y = Tensor([0, 0, 1, 0], ms.int64) + with pytest.raises(ValueError): + nearest(x, y, batch_x, batch_y) diff --git a/tests/graph/cluster/test_radius.py b/tests/graph/cluster/test_radius.py new file mode 100644 index 000000000..8c0946fcc --- /dev/null +++ b/tests/graph/cluster/test_radius.py @@ -0,0 +1,133 @@ +import pytest +import scipy.spatial +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse.cluster import radius, radius_graph +from mindscience.sharker.testing import is_full_test + + +def to_set(edge_index): + return set([(i, j) for i, j in edge_index.t().tolist()]) + + +def to_degree(edge_index): + counts = np.bincount(edge_index[1]) + return counts.tolist() + + +def to_batch(nodes): + return [int(i / 4) for i in nodes] + + +@pytest.mark.parametrize('dtype', [ms.half, ms.single, ms.double]) +def test_radius(dtype): + x = Tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + ], dtype) + y = Tensor([ + [0, 0], + [0, 1], + ], dtype) + + batch_x = Tensor([0, 0, 0, 0, 1, 1, 1, 1], ms.int64) + batch_y = Tensor([0, 1], ms.int64) + + edge_index = radius(x, y, 2, max_num_neighbors=4) + assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), + (1, 2), (1, 5), (1, 6)]) + if is_full_test(): + jit = ms.jit(radius) + edge_index = jit(x, y, 2, max_num_neighbors=4) + assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 1), + (1, 2), (1, 5), (1, 6)]) + + edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) + assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), + (1, 6)]) + + # Skipping a batch + batch_x = Tensor([0, 0, 0, 0, 2, 2, 2, 2], ms.int64) + batch_y = Tensor([0, 2], ms.int64) + edge_index = radius(x, y, 2, batch_x, batch_y, max_num_neighbors=4) + assert to_set(edge_index) == set([(0, 0), (0, 1), (0, 2), (0, 3), (1, 5), + (1, 6)]) + + +@pytest.mark.parametrize('dtype', [ms.half, ms.single, ms.double]) +def test_radius_graph(dtype): + x = Tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + ], dtype) + + edge_index = radius_graph(x, r=2.5, flow='trg_to_src') + assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), + (2, 3), (3, 0), (3, 2)]) + + edge_index = radius_graph(x, r=2.5, flow='src_to_trg') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + if is_full_test(): + jit = ms.jit(radius_graph) + edge_index = jit(x, r=2.5, flow='src_to_trg') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) + + edge_index = radius_graph(x, r=100, flow='src_to_trg', + max_num_neighbors=1) + assert set(to_degree(edge_index)) == set([1]) + + x = Tensor([ + [-1, -1], + [-1, -1], + [-1, -1], + [-1, -1], + ], dtype) + + edge_index = radius_graph(x, r=100, flow='src_to_trg', + max_num_neighbors=1) + assert set(to_degree(edge_index)) == set([1]) + + x = Tensor([ + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + [-1, -1], + [-1, +1], + [+1, +1], + [+1, -1], + ], dtype) + batch_x = Tensor([0, 0, 0, 0, 1, 1, 1, 1], ms.int64) + + edge_index = radius_graph(x, r=100, batch=batch_x, flow='src_to_trg', + max_num_neighbors=1) + assert set(to_degree(edge_index)) == set([1]) + assert to_batch(edge_index[0]) == batch_x.tolist() + + +@pytest.mark.parametrize('dtype', [ms.half, ms.single, ms.double]) +def test_radius_graph_large(dtype): + x = ops.randn(1000, 3, dtype=dtype) + + edge_index = radius_graph(x, + r=0.5, + flow='trg_to_src', + loop=True, + max_num_neighbors=2000) + + tree = scipy.spatial.cKDTree(x.asnumpy()) + col = tree.query_ball_point(x.asnumpy(), r=0.5) + truth = set([(i, j) for i, ns in enumerate(col) for j in ns]) + + assert to_set(edge_index) == truth diff --git a/tests/graph/cluster/test_rw.py b/tests/graph/cluster/test_rw.py new file mode 100644 index 000000000..2bc05c798 --- /dev/null +++ b/tests/graph/cluster/test_rw.py @@ -0,0 +1,85 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.sparse.cluster import random_walk +from mindscience.sharker.sparse. testing import Tensor +from mindscience.sharker.testing import is_full_test + + +def test_rw_large(): + row = Tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], ms.int64) + col = Tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], ms.int64) + start = Tensor([0, 1, 2, 3, 4], ms.int64) + walk_length = 10 + + out = random_walk(row, col, start, walk_length) + assert out[:, 0].tolist() == start.tolist() + + for n in range(start.shape[0]): + cur = start[n].item() + for i in range(1, walk_length): + assert out[n, i].item() in col[row == cur].tolist() + cur = out[n, i].item() + + +def test_rw_small(): + row = Tensor([0, 1], ms.int64) + col = Tensor([1, 0], ms.int64) + start = Tensor([0, 1, 2], ms.int64) + walk_length = 4 + + out = random_walk(row, col, start, walk_length, num_nodes=3) + assert out.tolist() == [[0, 1, 0, 1, 0], [1, 0, 1, 0, 1], [2, 2, 2, 2, 2]] + + if is_full_test(): + jit = ms.jit(random_walk) + assert ops.equal(jit(row, col, start, walk_length, num_nodes=3), out).all() + + +def test_rw_large_with_edge_indices(): + row = Tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4], ms.int64) + col = Tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3], ms.int64) + start = Tensor([0, 1, 2, 3, 4], ms.int64) + walk_length = 10 + + node_seq, edge_seq = random_walk( + row, + col, + start, + walk_length, + return_edge_indices=True, + ) + assert node_seq[:, 0].tolist() == start.tolist() + + for n in range(start.shape[0]): + cur = start[n].item() + for i in range(1, walk_length): + assert node_seq[n, i].item() in col[row == cur].tolist() + cur = node_seq[n, i].item() + + assert (edge_seq != -1).all() + + +def test_rw_small_with_edge_indices(): + row = Tensor([0, 1], ms.int64) + col = Tensor([1, 0], ms.int64) + start = Tensor([0, 1, 2], ms.int64) + walk_length = 4 + + node_seq, edge_seq = random_walk( + row, + col, + start, + walk_length, + num_nodes=3, + return_edge_indices=True, + ) + assert node_seq.tolist() == [ + [0, 1, 0, 1, 0], + [1, 0, 1, 0, 1], + [2, 2, 2, 2, 2], + ] + assert edge_seq.tolist() == [ + [0, 1, 0, 1], + [1, 0, 1, 0], + [-1, -1, -1, -1], + ] diff --git a/tests/graph/conftest.py b/tests/graph/conftest.py new file mode 100644 index 000000000..e4bb7af21 --- /dev/null +++ b/tests/graph/conftest.py @@ -0,0 +1,95 @@ +import functools +import logging +import os.path as osp +from typing import Callable + +import pytest + +import sharker.typing +from mindscience.sharker.data import Dataset +from mindscience.sharker.io import fs + + +def load_dataset(root: str, name: str, *args, **kwargs) -> Dataset: + r"""Returns a variety of datasets according to :obj:`name`.""" + if 'karate' in name.lower(): + from mindscience.sharker.datasets import KarateClub + return KarateClub(*args, **kwargs) + if name.lower() in ['cora', 'citeseer', 'pubmed']: + from mindscience.sharker.datasets import Planetoid + path = osp.join(root, 'Planetoid', name) + return Planetoid(path, name, *args, **kwargs) + if name in ['BZR', 'ENZYMES', 'IMDB-BINARY', 'MUTAG']: + from mindscience.sharker.datasets import TUDataset + path = osp.join(root, 'TUDataset') + return TUDataset(path, name, *args, **kwargs) + if name in ['ego-facebook', 'soc-Slashdot0811', 'wiki-vote']: + from mindscience.sharker.datasets import SNAPDataset + path = osp.join(root, 'SNAPDataset') + return SNAPDataset(path, name, *args, **kwargs) + if name.lower() in ['bashapes']: + from mindscience.sharker.datasets import BAShapes + return BAShapes(*args, **kwargs) + if name in ['citationCiteseer', 'illc1850']: + from mindscience.sharker.datasets import SuiteSparseMatrixCollection + path = osp.join(root, 'SuiteSparseMatrixCollection') + return SuiteSparseMatrixCollection(path, name=name, *args, **kwargs) + if 'elliptic' in name.lower(): + from mindscience.sharker.datasets import EllipticBitcoinDataset + path = osp.join(root, 'EllipticBitcoinDataset') + return EllipticBitcoinDataset(path, *args, **kwargs) + if name.lower() in ['hetero']: + from mindscience.sharker.testing import FakeHeteroDataset + return FakeHeteroDataset(*args, **kwargs) + + raise ValueError(f"Cannot load dataset with name '{name}'") + + +@pytest.fixture(scope='session') +def get_dataset() -> Callable: + # TODO Support memory filesystem on Windows. + if sharker.typing.WITH_WINDOWS: + root = osp.join('/', 'tmp', 'pyg_test_datasets') + else: + root = 'memory://pyg_test_datasets' + + yield functools.partial(load_dataset, root) + + fs.rm(root) + + +@pytest.fixture +def enable_extensions(): # Nothing to do. + yield + + +@pytest.fixture +def disable_extensions(): + def is_setting(name: str) -> bool: + if not name.startswith('WITH_'): + return False + if name.startswith('WITH_PT') or name.startswith('WITH_WINDOWS'): + return False + return True + + settings = dir(sharker.typing) + settings = [key for key in settings if is_setting(key)] + state = {key: getattr(sharker.typing, key) for key in settings} + + for key in state.keys(): + setattr(sharker.typing, key, False) + yield + for key, value in state.items(): + setattr(sharker.typing, key, value) + + +@pytest.fixture +def without_extensions(request): + request.getfixturevalue(request.param) + return request.param == 'disable_extensions' + + +# @pytest.fixture(scope='function') +# def spawn_context(): +# torch.multiprocessing.set_start_method('spawn', force=True) +# logging.info("Setting torch.multiprocessing context to 'spawn'") diff --git a/tests/graph/data/test_batch.py b/tests/graph/data/test_batch.py new file mode 100644 index 000000000..e6eaa136f --- /dev/null +++ b/tests/graph/data/test_batch.py @@ -0,0 +1,379 @@ +import os.path as osp +import pickle +import numpy as np +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.data import Batch, Graph, HeteroGraph +from mindscience.sharker.testing import get_random_edge_index + +def test_batch_basic(): + + x = ms.Tensor([1.0, 2.0, 3.0]) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + data1 = Graph(x=x, y=1, edge_index=edge_index, string='1', array=['1', '2'], + num_nodes=3) + + x = ms.Tensor([1.0, 2.0]) + edge_index = ms.Tensor([[0, 1], [1, 0]]) + data2 = Graph(x=x, y=2, edge_index=edge_index, string='2', + array=['3', '4', '5'], num_nodes=2) + + x = ms.Tensor([1.0, 2.0, 3.0, 4.0]) + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + data3 = Graph(x=x, y=3, edge_index=edge_index, string='3', + array=['6', '7', '8', '9'], num_nodes=4) + + batch = Batch.from_data_list([data1]) + assert str(batch) == ('GraphBatch(x=[3], edge_index=[2, 4], y=[1], ' + 'string=[1], array=[1], num_nodes=3, batch=[3], ' + 'ptr=[2])') + assert batch.num_graphs == len(batch) == 1 + assert batch.x.tolist() == [1, 2, 3] + assert batch.y.tolist() == [1] + assert batch.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert batch.string == ['1'] + assert batch.array == [['1', '2']] + assert batch.num_nodes == 3 + assert batch.batch.tolist() == [0, 0, 0] + assert batch.ptr.tolist() == [0, 3] + + batch = Batch.from_data_list([data1, data2, data3], + follow_batch=['string']) + + assert str(batch) == ('GraphBatch(x=[9], edge_index=[2, 12], y=[3], ' + 'string=[3], string_batch=[3], string_ptr=[4], ' + 'array=[3], num_nodes=9, batch=[9], ptr=[4])') + assert batch.num_graphs == len(batch) == 3 + assert batch.x.tolist() == [1, 2, 3, 1, 2, 1, 2, 3, 4] + assert batch.y.tolist() == [1, 2, 3] + assert batch.edge_index.tolist() == [[0, 1, 1, 2, 3, 4, 5, 6, 6, 7, 7, 8], + [1, 0, 2, 1, 4, 3, 6, 5, 7, 6, 8, 7]] + assert batch.string == ['1', '2', '3'] + assert batch.string_batch.tolist() == [0, 1, 2] + assert batch.string_ptr.tolist() == [0, 1, 2, 3] + assert batch.array == [['1', '2'], ['3', '4', '5'], ['6', '7', '8', '9']] + assert batch.num_nodes == 9 + assert batch.batch.tolist() == [0, 0, 0, 1, 1, 2, 2, 2, 2] + assert batch.ptr.tolist() == [0, 3, 5, 9] + + assert str(batch[0]) == ("Graph(x=[3], edge_index=[2, 4], y=[1], " + "string='1', array=[2], num_nodes=3)") + assert str(batch[1]) == ("Graph(x=[2], edge_index=[2, 2], y=[1], " + "string='2', array=[3], num_nodes=2)") + assert str(batch[2]) == ("Graph(x=[4], edge_index=[2, 6], y=[1], " + "string='3', array=[4], num_nodes=4)") + + assert len(batch.index_select([1, 0])) == 2 + assert len(batch.index_select(ms.Tensor([1, 0]))) == 2 + assert len(batch.index_select(ms.Tensor([True, False]))) == 1 + assert len(batch.index_select(np.array([1, 0], dtype=np.int64))) == 2 + assert len(batch.index_select(np.array([True, False]))) == 1 + assert len(batch[:2]) == 2 + + data_list = batch.to_data_list() + assert len(data_list) == 3 + + assert len(data_list[0]) == 6 + assert data_list[0].x.tolist() == [1, 2, 3] + assert data_list[0].y.tolist() == [1] + assert data_list[0].edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert data_list[0].string == '1' + assert data_list[0].array == ['1', '2'] + assert data_list[0].num_nodes == 3 + + assert len(data_list[1]) == 6 + assert data_list[1].x.tolist() == [1, 2] + assert data_list[1].y.tolist() == [2] + assert data_list[1].edge_index.tolist() == [[0, 1], [1, 0]] + assert data_list[1].string == '2' + assert data_list[1].array == ['3', '4', '5'] + assert data_list[1].num_nodes == 2 + + assert len(data_list[2]) == 6 + assert data_list[2].x.tolist() == [1, 2, 3, 4] + assert data_list[2].y.tolist() == [3] + assert data_list[2].edge_index.tolist() == [[0, 1, 1, 2, 2, 3], + [1, 0, 2, 1, 3, 2]] + assert data_list[2].string == '3' + assert data_list[2].array == ['6', '7', '8', '9'] + assert data_list[2].num_nodes == 4 + +def test_batching_with_new_dimension(): + + class MyGraph(Graph): + def __cat_dim__(self, key, value, *args, **kwargs): + if key == 'foo': + return None + else: + return super().__cat_dim__(key, value, *args, **kwargs) + + x1 = ms.Tensor([1, 2, 3]).float() + foo1 = ops.randn(4) + y1 = ms.Tensor(1) + + x2 = ms.Tensor([1, 2]).float() + foo2 = ops.randn(4) + y2 = ms.Tensor(2) + + batch = Batch.from_data_list( + [MyGraph(x=x1, foo=foo1, y=y1), + MyGraph(x=x2, foo=foo2, y=y2)]) + + assert str(batch) == ('MyGraphBatch(x=[5], y=[2], foo=[2, 4], batch=[5], ' + 'ptr=[3])') + assert batch.num_graphs == len(batch) == 2 + assert batch.x.tolist() == [1, 2, 3, 1, 2] + assert batch.foo.shape == (2, 4) + assert batch.foo[0].tolist() == foo1.tolist() + assert batch.foo[1].tolist() == foo2.tolist() + assert batch.y.tolist() == [1, 2] + assert batch.batch.tolist() == [0, 0, 0, 1, 1] + assert batch.ptr.tolist() == [0, 3, 5] + assert batch.num_graphs == 2 + + data = batch[0] + assert str(data) == ('MyGraph(x=[3], y=[1], foo=[4])') + data = batch[1] + assert str(data) == ('MyGraph(x=[2], y=[1], foo=[4])') + + +def test_pickling(tmp_path): + data = Graph(x=ops.randn(5, 16)) + batch = Batch.from_data_list([data, data, data, data]) + assert id(batch._store._parent()) == id(batch) + assert batch.num_nodes == 20 + + path = osp.join(tmp_path, 'batch.ckpt') + pickle.dump(batch.numpy(), open(path, 'wb')) + assert id(batch._store._parent()) == id(batch) + assert batch.num_nodes == 20 + + batch = pickle.load(open(path, 'rb')) + assert id(batch._store._parent()) == id(batch) + assert batch.num_nodes == 20 + + assert batch.__class__.__name__ == 'GraphBatch' + assert batch.num_graphs == len(batch) == 4 + + +def test_recursive_batch(): + data1 = Graph( + x={ + '1': ops.randn(10, 32), + '2': ops.randn(20, 48) + }, + edge_index=[ + get_random_edge_index(30, 30, 50), + get_random_edge_index(30, 30, 70) + ], + num_nodes=30, + ) + + data2 = Graph( + x={ + '1': ops.randn(20, 32), + '2': ops.randn(40, 48) + }, + edge_index=[ + get_random_edge_index(60, 60, 80), + get_random_edge_index(60, 60, 90) + ], + num_nodes=60, + ) + + batch = Batch.from_data_list([data1, data2]) + + assert batch.num_graphs == len(batch) == 2 + assert batch.num_nodes == 90 + + assert mint.isclose(batch.x['1'], + mint.cat(([data1.x['1'], data2.x['1']]), dim=0), + equal_nan = True).all() + assert mint.isclose(batch.x['2'], + mint.cat(([data1.x['2'], data2.x['2']]), dim=0), + equal_nan = True).all() + assert (batch.edge_index[0].tolist() == mint.cat( + [data1.edge_index[0], data2.edge_index[0] + 30], dim=1).tolist()) + assert (batch.edge_index[1].tolist() == mint.cat( + [data1.edge_index[1], data2.edge_index[1] + 30], dim=1).tolist()) + assert batch.batch.shape == (90, ) + assert batch.ptr.shape == (3, ) + + out1 = batch[0] + assert len(out1) == 3 + assert out1.num_nodes == 30 + assert mint.isclose(out1.x['1'], data1.x['1'], equal_nan = True).all() + assert mint.isclose(out1.x['2'], data1.x['2'], equal_nan = True).all() + assert out1.edge_index[0].tolist(), data1.edge_index[0].tolist() + assert out1.edge_index[1].tolist(), data1.edge_index[1].tolist() + + out2 = batch[1] + assert len(out2) == 3 + assert out2.num_nodes == 60 + assert mint.isclose(out2.x['1'], data2.x['1'], equal_nan = True).all() + assert mint.isclose(out2.x['2'], data2.x['2'], equal_nan = True).all() + assert out2.edge_index[0].tolist(), data2.edge_index[0].tolist() + assert out2.edge_index[1].tolist(), data2.edge_index[1].tolist() + + +def test_batching_of_batches(): + data = Graph(x=ops.randn(2, 16)) + batch = Batch.from_data_list([data, data]) + + batch = Batch.from_data_list([batch, batch]) + assert batch.num_graphs == len(batch) == 2 + assert batch.x[0:2].tolist() == data.x.tolist() + assert batch.x[2:4].tolist() == data.x.tolist() + assert batch.x[4:6].tolist() == data.x.tolist() + assert batch.x[6:8].tolist() == data.x.tolist() + assert batch.batch.tolist() == [0, 0, 1, 1, 2, 2, 3, 3] + + +def test_hetero_batch(): + e1 = ('p', 'a') + e2 = ('a', 'p') + data1 = HeteroGraph() + data1['p'].x = ops.randn(100, 128) + data1['a'].x = ops.randn(200, 128) + data1[e1].edge_index = get_random_edge_index(100, 200, 500) + data1[e1].edge_attr = ops.randn(500, 32) + data1[e2].edge_index = get_random_edge_index(200, 100, 400) + data1[e2].edge_attr = ops.randn(400, 32) + + data2 = HeteroGraph() + data2['p'].x = ops.randn(50, 128) + data2['a'].x = ops.randn(100, 128) + data2[e1].edge_index = get_random_edge_index(50, 100, 300) + data2[e1].edge_attr = ops.randn(300, 32) + data2[e2].edge_index = get_random_edge_index(100, 50, 200) + data2[e2].edge_attr = ops.randn(200, 32) + + batch = Batch.from_data_list([data1, data2]) + + assert batch.num_graphs == len(batch) == 2 + assert batch.num_nodes == 450 + + assert mint.isclose(batch['p'].x[:100], data1['p'].x, equal_nan = True).all() + assert mint.isclose(batch['a'].x[:200], data1['a'].x, equal_nan = True).all() + assert mint.isclose(batch['p'].x[100:], data2['p'].x, equal_nan = True).all() + assert mint.isclose(batch['a'].x[200:], data2['a'].x, equal_nan = True).all() + assert (batch[e1].edge_index.tolist() == mint.cat([ + data1[e1].edge_index, + data2[e1].edge_index + ms.Tensor([[100], [200]]) + ], 1).tolist()) + assert mint.isclose( + batch[e1].edge_attr, + mint.cat([data1[e1].edge_attr, data2[e1].edge_attr], 0), + equal_nan = True).all() + assert (batch[e2].edge_index.tolist() == mint.cat([ + data1[e2].edge_index, + data2[e2].edge_index + ms.Tensor([[200], [100]]) + ], 1).tolist()) + assert mint.isclose( + batch[e2].edge_attr, + mint.cat([data1[e2].edge_attr, data2[e2].edge_attr], 0), + equal_nan = True).all() + assert batch['p'].batch.shape == (150, ) + assert batch['p'].ptr.shape == (3, ) + assert batch['a'].batch.shape == (300, ) + assert batch['a'].ptr.shape == (3, ) + + out1 = batch[0] + assert len(out1) == 3 + assert out1.num_nodes == 300 + assert mint.isclose(out1['p'].x, data1['p'].x, equal_nan = True).all() + assert mint.isclose(out1['a'].x, data1['a'].x, equal_nan = True).all() + assert out1[e1].edge_index.tolist() == data1[e1].edge_index.tolist() + assert mint.isclose(out1[e1].edge_attr, data1[e1].edge_attr, equal_nan = True).all() + assert out1[e2].edge_index.tolist() == data1[e2].edge_index.tolist() + assert mint.isclose(out1[e2].edge_attr, data1[e2].edge_attr, equal_nan = True).all() + + out2 = batch[1] + assert len(out2) == 3 + assert out2.num_nodes == 150 + assert mint.isclose(out2['p'].x, data2['p'].x, equal_nan = True).all() + assert mint.isclose(out2['a'].x, data2['a'].x, equal_nan = True).all() + assert out2[e1].edge_index.tolist() == data2[e1].edge_index.tolist() + assert mint.isclose(out2[e1].edge_attr, data2[e1].edge_attr, equal_nan = True).all() + assert out2[e2].edge_index.tolist() == data2[e2].edge_index.tolist() + assert mint.isclose(out2[e2].edge_attr, data2[e2].edge_attr, equal_nan = True).all() + + +def test_pair_data_batching(): + class PairGraph(Graph): + def __inc__(self, key, value, *args, **kwargs): + if key == 'edge_index_s': + return self.x_s.shape[0] + if key == 'edge_index_t': + return self.x_t.shape[0] + return super().__inc__(key, value, *args, **kwargs) + + x_s = ops.randn(5, 16) + edge_index_s = ms.Tensor([ + [0, 0, 0, 0], + [1, 2, 3, 4], + ]) + x_t = ops.randn(4, 16) + edge_index_t = ms.Tensor([ + [0, 0, 0], + [1, 2, 3], + ]) + + data = PairGraph(x_s=x_s, edge_index_s=edge_index_s, x_t=x_t, + edge_index_t=edge_index_t) + batch = Batch.from_data_list([data, data]) + + assert mint.isclose(batch.x_s, mint.cat(([x_s, x_s]), dim=0), equal_nan = True).all() + assert batch.edge_index_s.tolist() == [[0, 0, 0, 0, 5, 5, 5, 5], + [1, 2, 3, 4, 6, 7, 8, 9]] + + assert mint.isclose(batch.x_t, mint.cat(([x_t, x_t]), dim=0), equal_nan = True).all() + assert batch.edge_index_t.tolist() == [[0, 0, 0, 4, 4, 4], + [1, 2, 3, 5, 6, 7]] + + +def test_batch_with_empty_list(): + x = ops.randn(4, 1) + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + data = Graph(x=x, edge_index=edge_index, nontensor=[]) + + batch = Batch.from_data_list([data, data]) + assert batch.nontensor == [[], []] + assert batch[0].nontensor == [] + assert batch[1].nontensor == [] + + +def test_nested_follow_batch(): + def tr(n, m): + return ms.ops.rand(n, m) + + d1 = Graph(xs=[tr(4, 3), tr(11, 4), tr(1, 2)], a={"aa": tr(11, 3)}, + x=tr(10, 5)) + d2 = Graph(xs=[tr(5, 3), tr(14, 4), tr(3, 2)], a={"aa": tr(2, 3)}, + x=tr(11, 5)) + d3 = Graph(xs=[tr(6, 3), tr(15, 4), tr(2, 2)], a={"aa": tr(4, 3)}, + x=tr(9, 5)) + d4 = Graph(xs=[tr(4, 3), tr(16, 4), tr(1, 2)], a={"aa": tr(8, 3)}, + x=tr(8, 5)) + + data_list = [d1, d2, d3, d4] + + batch = Batch.from_data_list(data_list, follow_batch=['xs', 'a']) + + assert batch.xs[0].shape == (19, 3) + assert batch.xs[1].shape == (56, 4) + assert batch.xs[2].shape == (7, 2) + assert batch.a['aa'].shape == (25, 3) + + assert len(batch.xs_batch) == 3 + assert len(batch.a_batch) == 1 + + assert batch.xs_batch[0].tolist() == \ + [0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3] + assert batch.xs_batch[1].tolist() == \ + [0] * 11 + [1] * 14 + [2] * 15 + [3] * 16 + assert batch.xs_batch[2].tolist() == \ + [0] * 1 + [1] * 3 + [2] * 2 + [3] * 1 + + assert batch.a_batch['aa'].tolist() == \ + [0] * 11 + [1] * 2 + [2] * 4 + [3] * 8 diff --git a/tests/graph/data/test_data.py b/tests/graph/data/test_data.py new file mode 100644 index 000000000..9fd649a40 --- /dev/null +++ b/tests/graph/data/test_data.py @@ -0,0 +1,429 @@ +import copy +import multiprocessing as mp + +import pytest +import mindspore as ms +from mindspore import ops, mint + +from mindscience.sharker.data import Graph +from mindscience.sharker.data.storage import AttrType +from mindscience.sharker.testing import withPackage + + +def test_data(): + + x = ms.Tensor([[1, 3, 5], [2, 4, 6]]).float().t() + edge_index = ms.Tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]]) + data = Graph(x=x, edge_index=edge_index) + data.validate(raise_on_error=True) + + N = data.num_nodes + assert N == 3 + + assert data.node_attrs() == ['x'] + assert data.edge_attrs() == ['edge_index'] + + assert data.x.tolist() == x.tolist() + assert data['x'].tolist() == x.tolist() + assert data.get('x').tolist() == x.tolist() + assert data.get('y', 2) == 2 + assert data.get('y', None) is None + assert data.num_edge_types == 1 + assert data.num_node_types == 1 + + assert sorted(data.keys()) == ['edge_index', 'x'] + assert len(data) == 2 + assert 'x' in data and 'edge_index' in data and 'pos' not in data + + data.apply_(lambda x: x * 2, 'x') + assert mint.isclose(data.x, x, equal_nan = True).all() + + D = data.to_dict() + assert len(D) == 2 + assert 'x' in D and 'edge_index' in D + + D = data.to_namedtuple() + assert len(D) == 2 + assert D.x is not None and D.edge_index is not None + + assert data.__cat_dim__('x', data.x) == 0 + assert data.__cat_dim__('edge_index', data.edge_index) == -1 + assert data.__inc__('x', data.x) == 0 + assert data.__inc__('edge_index', data.edge_index) == data.num_nodes + + assert not data.is_coalesced() + data = data.coalesce() + assert data.is_coalesced() + + clone = data.copy() + assert clone != data + assert len(clone) == len(data) + assert clone.x is not data + assert clone.x.tolist() == data.x.tolist() + assert clone.edge_index is not data.edge_index + assert clone.edge_index.tolist() == data.edge_index.tolist() + + out = data.to_hetero() + assert mint.isclose(data.x, out['0'].x, equal_nan = True).all() + assert mint.isclose(data.edge_index, out['0', '0'].edge_index, equal_nan = True).all() + + data.edge_type = ms.Tensor([0, 0, 1, 0]) + out = data.to_hetero() + assert mint.isclose(data.x, out['0'].x, equal_nan = True).all() + assert [store.num_edges for store in out.edge_stores] == [3, 1] + data.edge_type = None + + data['x'] = x + 1 + assert data.x.tolist() == (x + 1).tolist() + + assert str(data) == 'Graph(x=[3, 2], edge_index=[2, 4])' + + dictionary = {'x': data.x, 'edge_index': data.edge_index} + data = Graph.from_dict(dictionary) + assert sorted(data.keys()) == ['edge_index', 'x'] + + assert not data.has_isolated_nodes() + assert not data.has_self_loops() + assert data.is_undirected() + assert not data.is_directed() + + assert data.num_nodes == 3 + assert data.num_edges == 4 + assert data.num_node_features == 2 + assert data.num_features == 2 + + data.edge_attr = ops.randn(data.num_edges, 2) + assert data.num_edge_features == 2 + data.edge_attr = None + + data.x = None + with pytest.warns(UserWarning, match='Unable to accurately infer'): + assert data.num_nodes == 3 + + data.edge_index = None + with pytest.warns(UserWarning, match='Unable to accurately infer'): + assert data.num_nodes is None + assert data.num_edges == 0 + + data.num_nodes = 4 + assert data.num_nodes == 4 + + data = Graph(x=x, attribute=x) + assert len(data) == 2 + assert data.x.tolist() == x.tolist() + assert data.attribute.tolist() == x.tolist() + + face = ms.Tensor([[0, 1], [1, 2], [2, 3]]) + data = Graph(num_nodes=4, face=face) + assert data.num_nodes == 4 + + data = Graph(title='test') + assert str(data) == "Graph(title='test')" + assert data.num_node_features == 0 + assert data.num_edge_features == 0 + + key = value = 'test_value' + data[key] = value + assert data[key] == value + del data[value] + del data[value] + assert data.get(key) is None + assert data.get('title') == 'test' + + +def test_data_attr_cache(): + x = ops.randn(3, 16) + edge_index = ms.Tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]]) + edge_attr = ops.randn(5, 4) + y = ms.Tensor([0]) + + data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr, y=y) + + assert data.is_node_attr('x') + assert 'x' in data._store._cached_attr[AttrType.NODE] + assert 'x' not in data._store._cached_attr[AttrType.EDGE] + assert 'x' not in data._store._cached_attr[AttrType.OTHER] + + assert not data.is_node_attr('edge_index') + assert 'edge_index' not in data._store._cached_attr[AttrType.NODE] + assert 'edge_index' in data._store._cached_attr[AttrType.EDGE] + assert 'edge_index' not in data._store._cached_attr[AttrType.OTHER] + + assert data.is_edge_attr('edge_attr') + assert 'edge_attr' not in data._store._cached_attr[AttrType.NODE] + assert 'edge_attr' in data._store._cached_attr[AttrType.EDGE] + assert 'edge_attr' not in data._store._cached_attr[AttrType.OTHER] + + assert not data.is_edge_attr('y') + assert 'y' not in data._store._cached_attr[AttrType.NODE] + assert 'y' not in data._store._cached_attr[AttrType.EDGE] + assert 'y' in data._store._cached_attr[AttrType.OTHER] + + +def test_data_attr_cache_not_shared(): + x = ops.rand((4, 4)) + edge_index = ms.Tensor([[0, 1, 2, 3, 0, 1], [0, 1, 2, 3, 0, 1]]) + time = mint.arange(edge_index.shape[1]) + data = Graph(x=x, edge_index=edge_index, time=time) + assert data.is_node_attr('x') + + out = data.up_to(3.5) + # This is expected behavior due to the ambiguity of between node-level and + # edge-level tensors when they share the same number of nodes/edges. + assert out.is_node_attr('time') + assert not data.is_node_attr('time') + + +def test_data_subgraph(): + x = mint.arange(5) + y = ms.Tensor([0.]) + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3, 3, 4], + [1, 0, 2, 1, 3, 2, 4, 3]]) + edge_weight = mint.arange(edge_index.shape[1]) + + data = Graph(x=x, y=y, edge_index=edge_index, edge_weight=edge_weight, + num_nodes=5) + + out = data.subgraph(ms.Tensor([1, 2, 3])) + assert len(out) == 5 + assert ops.equal(out.x, mint.arange(1, 4)).all() + assert ops.equal(out.y, data.y).all() + assert out.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert ops.equal(out.edge_weight, edge_weight[mint.arange(2, 6)]).all() + assert out.num_nodes == 3 + + # Test unordered selection: + out = data.subgraph(ms.Tensor([3, 1, 2])) + assert len(out) == 5 + assert ops.equal(out.x, ms.Tensor([3, 1, 2])).all() + assert ops.equal(out.y, data.y).all() + assert out.edge_index.tolist() == [[1, 2, 2, 0], [2, 1, 0, 2]] + assert ops.equal(out.edge_weight, edge_weight[mint.arange(2, 6)]).all() + assert out.num_nodes == 3 + + # Not support Bool Tensor at the moment + out = data.subgraph(ms.Tensor([False, False, False, True, True])) + assert len(out) == 5 + assert ops.equal(out.x, mint.arange(3, 5)).all() + assert ops.equal(out.y, data.y).all() + assert out.edge_index.tolist() == [[0, 1], [1, 0]] + assert ops.equal(out.edge_weight, edge_weight[mint.arange(6, 8)]).all() + assert out.num_nodes == 2 + + out = data.edge_subgraph(ms.Tensor([1, 2, 3])) + assert len(out) == 5 + assert out.num_nodes == data.num_nodes + assert ops.equal(out.x, data.x).all() + assert ops.equal(out.y, data.y).all() + assert out.edge_index.tolist() == [[1, 1, 2], [0, 2, 1]] + assert ops.equal(out.edge_weight, edge_weight[ms.Tensor([1, 2, 3])]).all() + + out = data.edge_subgraph( + ms.Tensor([False, True, True, True, False, False, False, False])) + assert len(out) == 5 + assert out.num_nodes == data.num_nodes + assert ops.equal(out.x, data.x).all() + assert ops.equal(out.y, data.y).all() + assert out.edge_index.tolist() == [[1, 1, 2], [0, 2, 1]] + assert ops.equal(out.edge_weight, edge_weight[ms.Tensor([1, 2, 3])]).all() + + +def test_data_subgraph_with_list_field(): + x = mint.arange(5) + y = mint.arange(5) + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3, 3, 4], + [1, 0, 2, 1, 3, 2, 4, 3]]) + data = Graph(x=x, y=y, edge_index=edge_index) + + out = data.subgraph(ms.Tensor([1, 2, 3])) + assert len(out) == 3 + assert out.x.tolist() == out.y.tolist() == [1, 2, 3] + + # not support Bool Tensor at the moment + out = data.subgraph(ms.Tensor([False, True, True, True, False])) + assert len(out) == 3 + assert out.x.tolist() == out.y.tolist() == [1, 2, 3] + + +def test_data_empty_subgraph(): + data = Graph(x=mint.arange(5), y=ms.Tensor(0.0)) + + out = data.subgraph(ms.Tensor([1, 2, 3])) + assert 'edge_index' not in out + assert ops.equal(out.x, mint.arange(1, 4)).all() + assert ops.equal(out.y, data.y).all() + assert out.num_nodes == 3 + + +def test_copy_data(): + data = Graph(x=ops.randn(20, 5)) + + out = copy.copy(data) + assert id(data) != id(out) + assert id(data._store) != id(out._store) + assert len(data.stores) == len(out.stores) + for store1, store2 in zip(data.stores, out.stores): + assert id(store1) != id(store2) + assert id(data) == id(store1._parent()) + assert id(out) == id(store2._parent()) + assert data.x is out.x + + out = copy.deepcopy(data) + assert id(data) != id(out) + assert id(data._store) != id(out._store) + assert len(data.stores) == len(out.stores) + for store1, store2 in zip(data.stores, out.stores): + assert id(store1) != id(store2) + assert id(data) == id(store1._parent()) + assert id(out) == id(store2._parent()) + assert data.x is not out.x + assert data.x.tolist() == out.x.tolist() + + +def test_data_sort(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 2, 1, 3], [1, 2, 3, 0, 0, 0]]) + edge_attr = ops.randn(6, 8) + + data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr) + assert not data.is_sorted(sort_by_row=True) + assert not data.is_sorted(sort_by_row=False) + + out = data.sort(sort_by_row=True) + assert out.is_sorted(sort_by_row=True) + assert not out.is_sorted(sort_by_row=False) + assert ops.equal(out.x, data.x).all() + assert out.edge_index.tolist() == [[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]] + assert ops.equal( + out.edge_attr, + data.edge_attr[ms.Tensor([0, 1, 2, 4, 3, 5])], + ).all() + + out = data.sort(sort_by_row=False) + assert not out.is_sorted(sort_by_row=True) + assert out.is_sorted(sort_by_row=False) + assert ops.equal(out.x, data.x).all() + assert out.edge_index.tolist() == [[1, 2, 3, 0, 0, 0], [0, 0, 0, 1, 2, 3]] + assert ops.equal( + out.edge_attr, + data.edge_attr[ms.Tensor([4, 3, 5, 0, 1, 2])], + ).all() + + +def test_debug_data(): + + Graph() + Graph(edge_index=mint.zeros((2, 0), dtype=ms.int64), num_nodes=10) + Graph(face=mint.zeros((3, 0), dtype=ms.int64), num_nodes=10) + Graph(edge_index=ms.Tensor([[0, 1], [1, 0]]), edge_attr=ops.randn(2)) + Graph(x=ops.randn(5, 3), num_nodes=5) + Graph(pos=ops.randn(5, 3), num_nodes=5) + Graph(norm=ops.randn(5, 3), num_nodes=5) + + +def run(rank, data_list): + for data in data_list: + assert data.x.is_shared() + data.x.add_(1) + + +def test_data_setter_properties(): + class MyData(Graph): + def __init__(self): + super().__init__() + self.my_attr1 = 1 + self.my_attr2 = 2 + + @property + def my_attr1(self): + return self._my_attr1 + + @my_attr1.setter + def my_attr1(self, value): + self._my_attr1 = value + + data = MyData() + assert data.my_attr2 == 2 + + assert 'my_attr1' not in data._store + assert data.my_attr1 == 1 + + data.my_attr1 = 2 + assert 'my_attr1' not in data._store + assert data.my_attr1 == 2 + + +def test_data_update(): + data = Graph(x=mint.arange(0, 5), y=mint.arange(5, 10)) + other = Graph(z=mint.arange(10, 15), x=mint.arange(15, 20)) + data.update(other) + + assert len(data) == 3 + assert ops.equal(data.x, mint.arange(15, 20)).all() + assert ops.equal(data.y, mint.arange(5, 10)).all() + assert ops.equal(data.z, mint.arange(10, 15)).all() + + +def test_data_generate_ids(): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 0, 1, 1, 2], [1, 1, 0, 2, 1]]) + + data = Graph(x=x, edge_index=edge_index) + assert len(data) == 2 + + data.generate_ids() + assert len(data) == 4 + assert data.n_id.tolist() == [0, 1, 2] + assert data.e_id.tolist() == [0, 1, 2, 3, 4] + + +@pytest.mark.parametrize('num_nodes', [4]) +@pytest.mark.parametrize('num_edges', [8]) +def test_data_time_handling(num_nodes, num_edges): + data = Graph( + x=ops.randn(num_nodes, 12), + edge_index=ops.randint(0, num_nodes, (2, num_edges)), + edge_attr=ops.rand(num_edges, 16), + time=mint.arange(num_edges), + num_nodes=num_nodes, + ) + + assert data.is_edge_attr('time') + assert not data.is_node_attr('time') + assert data.is_sorted_by_time() + + out = data.up_to(5) + assert out.num_edges == 6 + assert mint.isclose(out.x, data.x, equal_nan = True).all() + assert ops.equal(out.edge_index, data.edge_index[:, :6]).all() + assert mint.isclose(out.edge_attr, data.edge_attr[:6], equal_nan = True).all() + assert ops.equal(out.time, data.time[:6]).all() + + out = data.snapshot(2, 5) + assert out.num_edges == 4 + assert mint.isclose(out.x, data.x, equal_nan = True).all() + assert ops.equal(out.edge_index, data.edge_index[:, 2:6], ).all() + assert mint.isclose(out.edge_attr, data.edge_attr[2:6, :], equal_nan = True).all() + assert ops.equal(out.time, data.time[2:6]).all() + + out = data.sort_by_time() + assert data.is_sorted_by_time() + + out = data.concat(data) + assert out.num_nodes == 8 + assert not out.is_sorted_by_time() + + assert mint.isclose(out.x, mint.cat(([data.x, data.x]), dim=0), equal_nan = True).all() + assert ops.equal( + out.edge_index, + mint.cat(([data.edge_index, data.edge_index]), dim=1), + ).all() + assert mint.isclose( + out.edge_attr, + mint.cat(([data.edge_attr, data.edge_attr]), dim=0), + equal_nan = True + ).all() + assert mint.isclose(out.time, mint.cat(([data.time, data.time]), dim=0), equal_nan = True).all() + + out = out.sort_by_time() + assert ops.equal(out.time, data.time.repeat_interleave(2)).all() diff --git a/tests/graph/data/test_database.py b/tests/graph/data/test_database.py new file mode 100644 index 000000000..4a1504030 --- /dev/null +++ b/tests/graph/data/test_database.py @@ -0,0 +1,219 @@ +import math +import os.path as osp + +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.data import Graph, RocksDatabase, SQLiteDatabase +from mindscience.sharker.data.database import TensorInfo +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import has_package, withPackage + +AVAILABLE_DATABASES = [] +if has_package('sqlite3'): + AVAILABLE_DATABASES.append(SQLiteDatabase) +if has_package('rocksdict'): + AVAILABLE_DATABASES.append(RocksDatabase) + + +@pytest.mark.parametrize('Database', AVAILABLE_DATABASES) +@pytest.mark.parametrize('batch_size', [None, 1]) +def test_database_single_tensor(tmp_path, Database, batch_size): + kwargs = dict(path=osp.join(tmp_path, 'storage.db')) + if Database == SQLiteDatabase: + kwargs['name'] = 'test_table' + + db = Database(**kwargs) + assert db.schema == {0: object} + + try: + assert len(db) == 0 + assert str(db) == f'{Database.__name__}(0)' + except NotImplementedError: + assert str(db) == f'{Database.__name__}()' + + data = ops.randn(5) + db.insert(0, data) + try: + assert len(db) == 1 + except NotImplementedError: + pass + assert ops.equal(db.get(0), data).all() + + indices = ms.Tensor([1, 2]) + data_list = ops.randn(2, 5) + db.multi_insert(indices, data_list, batch_size=batch_size) + try: + assert len(db) == 3 + except NotImplementedError: + pass + out_list = db.multi_get(indices, batch_size=batch_size) + assert isinstance(out_list, list) + assert len(out_list) == 2 + assert ops.equal(out_list[0], data_list[0]).all() + assert ops.equal(out_list[1], data_list[1]).all() + + db.close() + + +@pytest.mark.parametrize('Database', AVAILABLE_DATABASES) +def test_database_schema(tmp_path, Database): + kwargs = dict(name='test_table') if Database == SQLiteDatabase else {} + + path = osp.join(tmp_path, 'tuple_storage.db') + schema = (int, float, str, dict(dtype=ms.float32, size=(2, -1)), object) + db = Database(path, schema=schema, **kwargs) + assert db.schema == { + 0: int, + 1: float, + 2: str, + 3: TensorInfo(dtype=ms.float32, size=(2, -1)), + 4: object, + } + + data1 = (1, 0.1, 'a', ops.randn(2, 8), Graph(x=ops.randn(8))) + data2 = (2, float('inf'), 'b', ops.randn(2, 16), Graph(x=ops.randn(8))) + data3 = (3, float('NaN'), 'c', ops.randn(2, 32), Graph(x=ops.randn(8))) + db.insert(0, data1) + db.multi_insert([1, 2], [data2, data3]) + + out1 = db.get(0) + out2, out3 = db.multi_get([1, 2]) + + for out, data in zip([out1, out2, out3], [data1, data2, data3]): + assert out[0] == data[0] + if math.isnan(data[1]): + assert math.isnan(out[1]) + else: + assert out[1] == data[1] + assert out[2] == data[2] + assert ops.equal(out[3], data[3]).all() + assert isinstance(out[4], Graph) and len(out[4]) == 1 + assert ops.equal(out[4].x, data[4].x).all() + + db.close() + + path = osp.join(tmp_path, 'dict_storage.db') + schema = { + 'int': int, + 'float': float, + 'str': str, + 'tensor': dict(dtype=ms.float32, size=(2, -1)), + 'data': object + } + db = Database(path, schema=schema, **kwargs) + assert db.schema == { + 'int': int, + 'float': float, + 'str': str, + 'tensor': TensorInfo(dtype=ms.float32, size=(2, -1)), + 'data': object, + } + + data1 = { + 'int': 1, + 'float': 0.1, + 'str': 'a', + 'tensor': ops.randn(2, 8), + 'data': Graph(x=ops.randn(1, 8)), + } + data2 = { + 'int': 2, + 'float': 0.2, + 'str': 'b', + 'tensor': ops.randn(2, 16), + 'data': Graph(x=ops.randn(2, 8)), + } + data3 = { + 'int': 3, + 'float': 0.3, + 'str': 'c', + 'tensor': ops.randn(2, 32), + 'data': Graph(x=ops.randn(3, 8)), + } + db.insert(0, data1) + db.multi_insert([1, 2], [data2, data3]) + + out1 = db.get(0) + out2, out3 = db.multi_get([1, 2]) + + for out, data in zip([out1, out2, out3], [data1, data2, data3]): + assert out['int'] == data['int'] + assert out['float'] == data['float'] + assert out['str'] == data['str'] + assert ops.equal(out['tensor'], data['tensor']).all() + assert isinstance(out['data'], Graph) and len(out['data']) == 1 + assert ops.equal(out['data'].x, data['data'].x).all() + + db.close() + + + +@withPackage('sqlite3') +def test_database_syntactic_sugar(tmp_path): + path = osp.join(tmp_path, 'storage.db') + db = SQLiteDatabase(path, name='test_table') + + data = ops.randn(5, 16) + db[0] = data[0] + db[1:3] = data[1:3] + db[ms.Tensor([3, 4])] = data[ms.Tensor([3, 4])] + assert len(db) == 5 + + assert ops.equal(db[0], data[0]).all() + assert ops.equal(mint.stack((db[:3]), dim=0), data[:3]).all() + assert ops.equal(mint.stack((db[3:]), dim=0), data[3:]).all() + assert ops.equal(mint.stack((db[1::2]), dim=0), data[1::2]).all() + assert ops.equal(mint.stack((db[[4, 3]]), dim=0), data[[4, 3]]).all() + assert ops.equal( + mint.stack((db[ms.Tensor([4, 3])]), dim=0), + data[ms.Tensor([4, 3])], + ).all() + assert ops.equal( + mint.stack((db[ms.Tensor([4, 4])]), dim=0), + data[ms.Tensor([4, 4])], + ).all() + + +if __name__ == '__main__': + import argparse + import tempfile + import time + + parser = argparse.ArgumentParser() + parser.add_argument('--numel', type=int, default=100_000) + parser.add_argument('--batch_size', type=int, default=256) + args = parser.parse_args() + + data = ops.randn(args.numel, 128) + tmp_dir = tempfile.TemporaryDirectory() + + path = osp.join(tmp_dir.name, 'sqlite.db') + sqlite_db = SQLiteDatabase(path, name='test_table') + t = time.perf_counter() + sqlite_db.multi_insert(range(args.numel), data, batch_size=100, log=True) + print(f'Initialized SQLiteDB in {time.perf_counter() - t:.2f} seconds') + + path = osp.join(tmp_dir.name, 'rocks.db') + rocks_db = RocksDatabase(path) + t = time.perf_counter() + rocks_db.multi_insert(range(args.numel), data, batch_size=100, log=True) + print(f'Initialized RocksDB in {time.perf_counter() - t:.2f} seconds') + + def in_memory_get(data): + index = ops.randint(0, args.numel, (args.batch_size, )) + return data[index] + + def db_get(db): + index = ops.randint(0, args.numel, (args.batch_size, )) + return db[index] + + benchmark( + funcs=[in_memory_get, db_get, db_get], + func_names=['In-Memory', 'SQLite', 'RocksDB'], + args=[(data, ), (sqlite_db, ), (rocks_db, )], + num_steps=50, + num_warmups=5, + ) + + tmp_dir.cleanup() diff --git a/tests/graph/data/test_dataloader.py b/tests/graph/data/test_dataloader.py new file mode 100644 index 000000000..f94b97a8f --- /dev/null +++ b/tests/graph/data/test_dataloader.py @@ -0,0 +1,172 @@ +import math +import mindspore as ms +import mindspore.dataset as ds +import numpy as np +import random +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import Dataloader +from mindscience.sharker.dataset import QM9 + + +def test_qm9_dataset_graph(): + dataset = QM9(root="./qm9") + dataset_length = len(dataset) + randomSampler = ds.RandomSampler() + + ## test qm9 + Dataset + Graph + drop_remainder=False + randomSampler + random batch_size + epoch = 5 + batch_size = random.randint(1, dataset_length - 1) + newloader = Dataloader(dataset, sampler=randomSampler, shuffle=None) + newloader = newloader.batch(batch_size, drop_remainder=False) + for epoch in range(5): + index = 0 + for data in newloader: + assert isinstance(data, Graph) == True + index = index + 1 + assert dataset_length - (index - 1) * batch_size == len(data.name) + assert index == math.ceil(dataset_length / batch_size) + + ## test qm9 + list + Graph + drop_remainder=True + randomSampler + random batch_size + dataset_list = [graph for graph in dataset] + batch_size = random.randint(1, dataset_length - 1) + newloader = Dataloader(dataset_list, sampler=randomSampler, shuffle=None) + newloader = newloader.batch(batch_size, drop_remainder=True) + index = 0 + for data in newloader: + assert isinstance(data, Graph) == True + assert len(data.name) == batch_size + index = index + 1 + assert index == len(newloader) + assert index == math.floor(dataset_length / batch_size) + + ## test qm9 + Dataset + Graph + drop_remainder=False + randomSampler + no .batch + newloader = Dataloader(dataset, sampler=randomSampler, shuffle=None) + index = 0 + for data in newloader: + assert isinstance(data, Graph) + index = index + 1 + assert index == len(newloader) + assert index == dataset_length + + ## test qm9 + list + Graph + drop_remainder=True + randomSampler + no .batch + newloader = Dataloader(dataset_list, sampler=randomSampler, shuffle=None) + index = 0 + for data in newloader: + assert isinstance(data, Graph) + index = index + 1 + assert index == len(newloader) + assert index == dataset_length + + +def test_custom_class(): + class MyMapDataset(): + def __init__(self): + super(MyMapDataset).__init__() + self.data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + dataset = MyMapDataset() + dataset_length = len(dataset) + randomSampler = ds.RandomSampler() + + ## custom iterable class + drop_remainder=True + randomSampler + random batch_size + batch_size = random.randint(1, dataset_length - 1) + newloader = Dataloader(dataset, shuffle=None, column_names=["data"], sampler=randomSampler) + newloader = newloader.batch(batch_size, drop_remainder=True) + index = 0 + for data in newloader: + assert isinstance(data, list) + assert isinstance(data[0], ms.Tensor) + assert data[0].shape[0] == batch_size + index = index + 1 + assert index == len(newloader) + assert index == math.floor(dataset_length / batch_size) + + ## custom iterable class + drop_remainder=True + nosampler + shuffle=False + batch_size=1 + batch_size = 1 + newloader = Dataloader(dataset, shuffle=False, column_names=["data"]) + newloader = newloader.batch(batch_size, drop_remainder=True) + index = 0 + + for data in newloader: + assert isinstance(data, list) + assert isinstance(data[0], ms.Tensor) + assert data[0].shape[0] == batch_size + assert int(data[0]) == index + index = index + 1 + assert index == len(newloader) + assert index == math.floor(dataset_length / batch_size) + + ## custom iterable class + drop_remainder=True + nosampler + shuffle=False + no .batch + newloader = Dataloader(dataset, shuffle=False, column_names=["data"]) + index = 0 + for data in newloader: + assert isinstance(data, list) + assert isinstance(data[0], ms.Tensor) + assert data[0].shape == () + assert int(data[0]) == index + index = index + 1 + assert index == len(newloader) + assert index == dataset_length + + +def test_callable(): + def generator_multi_column(): + for i in range(64): + yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]) + + ## custom callable function + batch_size=2 + drop_remainder=False + batch_size = 2 + dataset = Dataloader(source=generator_multi_column, column_names=["col1", "col2"]) + dataset = dataset.batch(batch_size) + index = 0 + for data in dataset: + assert isinstance(data[0], ms.Tensor) + assert isinstance(data[1], ms.Tensor) + assert len(data) == 2 + assert data[0].shape == (2, 1) + assert data[1].shape == (2, 2, 2) + index = index + 1 + assert index == len(dataset) + + ## custom callable function + drop_remainder=False + no .batch + dataset = Dataloader(source=generator_multi_column, column_names=["col1", "col2"]) + index = 0 + for data in dataset: + assert isinstance(data[0], ms.Tensor) + assert isinstance(data[1], ms.Tensor) + assert len(data) == 2 + assert data[0].shape == (1,) + assert data[1].shape == (2, 2,) + index = index + 1 + assert index == len(dataset) + + +def test_list(): + input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # custom list + drop_remainder=True + batch_size=3 + batch_size = 3 + dataset = Dataloader(source=input, column_names=["col1"]) + dataset = dataset.batch(batch_size, drop_remainder=True) + index = 0 + for data in dataset: + assert isinstance(data[0], ms.Tensor) + assert len(data) == 1 + assert data[0].shape == (3,) + index = index + 1 + assert index == len(dataset) + + # custom list + drop_remainder=True + no .batch + dataset = Dataloader(source=input, column_names=["col1"]) + index = 0 + for data in dataset: + assert isinstance(data[0], ms.Tensor) + assert len(data) == 1 + assert data[0].shape == () + index = index + 1 + assert index == len(dataset) + diff --git a/tests/graph/data/test_dataset.py b/tests/graph/data/test_dataset.py new file mode 100644 index 000000000..b894306b0 --- /dev/null +++ b/tests/graph/data/test_dataset.py @@ -0,0 +1,340 @@ +import copy + +import pytest +import mindspore as ms +from mindspore import ops, mint + +from mindscience.sharker.data import Graph, HeteroGraph, InMemoryDataset +from mindscience.sharker.testing import withPackage + +from mindscience.sharker.typing import SparseTensor + + +class MyTestDataset(InMemoryDataset): + def __init__(self, data_list, transform=None): + super().__init__(None, transform=transform) + self.data, self.slices = self.collate(data_list, True) + + +class MyStoredTestDataset(InMemoryDataset): + def __init__(self, root, data_list, transform=None): + self.data_list = data_list + super().__init__(root, transform=transform) + self.load(self.processed_paths[0], data_cls=data_list[0].__class__) + + @property + def processed_file_names(self) -> str: + return 'data.pt' + + def process(self): + self.save(self.data_list, self.processed_paths[0]) + + +def test_in_memory_dataset(): + x1 = ms.Tensor([[1.0], [1.0], [1.0]]) + x2 = ms.Tensor([[2.0], [2.0], [2.0]]) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + face = ms.Tensor([[0], [1], [2]]) + + data1 = Graph(x1, edge_index, face=face, test_int=1, test_str='1') + data1.num_nodes = 10 + + data2 = Graph(x2, edge_index, face=face, test_int=2, test_str='2') + data2.num_nodes = 5 + + dataset = MyTestDataset([data1, data2]) + assert str(dataset) == 'MyTestDataset(2)' + assert len(dataset) == 2 + + assert len(dataset[0]) == 6 + assert dataset[0].num_nodes == 10 + assert dataset[0].x.tolist() == x1.tolist() + assert dataset[0].edge_index.tolist() == edge_index.tolist() + assert dataset[0].face.tolist() == face.tolist() + assert dataset[0].test_int == 1 + assert dataset[0].test_str == '1' + + assert len(dataset[1]) == 6 + assert dataset[1].num_nodes == 5 + assert dataset[1].x.tolist() == x2.tolist() + assert dataset[1].edge_index.tolist() == edge_index.tolist() + assert dataset[1].face.tolist() == face.tolist() + assert dataset[1].test_int == 2 + assert dataset[1].test_str == '2' + + with pytest.warns(UserWarning, match="internal storage format"): + dataset.data + + assert ops.equal(dataset.x, mint.cat(([x1, x2]), dim=0)).all() + assert dataset.edge_index.tolist() == [ + [0, 1, 1, 2, 10, 11, 11, 12], + [1, 0, 2, 1, 11, 10, 12, 11], + ] + assert ops.equal(dataset[1:].x, x2).all() + + +def test_stored_in_memory_dataset(tmp_path): + x1 = ms.Tensor([[1.0], [1.0], [1.0]]) + x2 = ms.Tensor([[2.0], [2.0], [2.0], [2.0]]) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + data1 = Graph(x1, edge_index, num_nodes=3, test_int=1, test_str='1') + data2 = Graph(x2, edge_index, num_nodes=4, test_int=2, test_str='2') + + dataset = MyStoredTestDataset(tmp_path, [data1, data2]) + assert dataset._data.num_nodes == 7 + assert dataset._data._num_nodes == [3, 4] + data = dataset[0].tensor() + assert ops.equal(data.x, x1).all() + assert ops.equal(data.edge_index, edge_index).all() + assert data.num_nodes == 3 + assert ops.equal(data.test_int, ms.Tensor([1])).all() + assert data.test_str == '1' + + data = dataset[1].tensor() + assert ops.equal(data.x, x2).all() + assert ops.equal(data.edge_index, edge_index).all() + assert data.num_nodes == 4 + assert ops.equal(data.test_int, ms.Tensor([2])).all() + assert data.test_str == '2' + + +def test_stored_hetero_in_memory_dataset(tmp_path): + x1 = ms.Tensor([[1.0], [1.0], [1.0]]) + x2 = ms.Tensor([[2.0], [2.0], [2.0], [2.0]]) + + data1 = HeteroGraph() + data1['paper'].x = x1 + data1['paper'].num_nodes = 3 + + data2 = HeteroGraph() + data2['paper'].x = x2 + data2['paper'].num_nodes = 4 + + dataset = MyStoredTestDataset(tmp_path, [data1, data2]) + assert dataset._data['paper'].num_nodes == 7 + assert dataset._data['paper']._num_nodes == [3, 4] + + data = dataset[0].tensor()['paper'] + assert ops.equal(data.x, x1).all() + assert data.num_nodes == 3 + + data = dataset[1].tensor()['paper'] + assert ops.equal(data.x, x2).all() + assert data.num_nodes == 4 + + +def test_in_memory_num_classes(): + dataset = MyTestDataset([Graph(), Graph()]) + assert dataset.num_classes == 0 + + dataset = MyTestDataset([Graph(y=0), Graph(y=1)]) + assert dataset.num_classes == 2 + + dataset = MyTestDataset([Graph(y=1.5), Graph(y=2.5), Graph(y=3.5)]) + with pytest.warns(UserWarning, match="unique elements"): + assert dataset.num_classes == 3 + + dataset = MyTestDataset([ + Graph(y=ms.Tensor([[0, 1, 0, 1]])), + Graph(y=ms.Tensor([[1, 0, 0, 0]])), + Graph(y=ms.Tensor([[0, 0, 1, 0]])), + ]) + assert dataset.num_classes == 4 + + def transform(data): + copied_data = copy.copy(data) + copied_data.y += 1 + return data, copied_data, 'foo' + + dataset = MyTestDataset([Graph(y=0), Graph(y=1)], transform=transform) + assert dataset.num_classes == 3 + + +def test_in_memory_dataset_copy(): + data_list = [Graph(x=ops.randn(5, 16)) for _ in range(4)] + dataset = MyTestDataset(data_list) + + copied_dataset = dataset.copy() + assert id(copied_dataset) != id(dataset) + + assert len(copied_dataset) == len(dataset) == 4 + for copied_data, data in zip(copied_dataset, dataset): + assert ops.equal(copied_data.x, data.x).all() + + copied_dataset = dataset.copy([1, 2]) + assert len(copied_dataset) == 2 + assert ops.equal(copied_dataset[0].x, data_list[1].x).all() + assert ops.equal(copied_dataset[1].x, data_list[2].x).all() + + +def test_collate_with_new_dimension(): + class MyData(Graph): + def __cat_dim__(self, key, value, *args, **kwargs): + if key == 'foo': + return None + else: + return super().__cat_dim__(key, value, *args, **kwargs) + + x = ms.Tensor([1, 2, 3]).float() + foo = ops.randn(4) + y = ms.Tensor(1) + + data = MyData(x=x, foo=foo, y=y) + + dataset = MyTestDataset([data, data]) + assert str(dataset) == 'MyTestDataset(2)' + assert len(dataset) == 2 + + data1 = dataset[0] + assert len(data1) == 3 + assert data1.x.tolist() == x.tolist() + assert data1.foo.tolist() == foo.tolist() + assert data1.y.tolist() == [1] + + data2 = dataset[0] + assert len(data2) == 3 + assert data2.x.tolist() == x.tolist() + assert data2.foo.tolist() == foo.tolist() + assert data2.y.tolist() == [1] + + +def test_hetero_in_memory_dataset(): + data1 = HeteroGraph() + data1.y = ops.randn(5) + data1['paper'].x = ops.randn(10, 16) + data1['paper', 'paper'].edge_index = ops.randint(0, 10, (2, 30)).long() + + data2 = HeteroGraph() + data2.y = ops.randn(5) + data2['paper'].x = ops.randn(10, 16) + data2['paper', 'paper'].edge_index = ops.randint(0, 10, (2, 30)).long() + + dataset = MyTestDataset([data1, data2]) + assert str(dataset) == 'MyTestDataset(2)' + assert len(dataset) == 2 + + assert len(dataset[0]) == 3 + assert dataset[0].y.tolist() == data1.y.tolist() + assert dataset[0]['paper'].x.tolist() == data1['paper'].x.tolist() + assert (dataset[0]['paper', 'paper'].edge_index.tolist() == data1[ + 'paper', 'paper'].edge_index.tolist()) + + assert len(dataset[1]) == 3 + assert dataset[1].y.tolist() == data2.y.tolist() + assert dataset[1]['paper'].x.tolist() == data2['paper'].x.tolist() + assert (dataset[1]['paper', 'paper'].edge_index.tolist() == data2[ + 'paper', 'paper'].edge_index.tolist()) + + +def test_override_behavior(): + class DS1(InMemoryDataset): + def __init__(self): + self.enter_download = False + self.enter_process = False + super().__init__() + + def _download(self): + self.enter_download = True + + def _process(self): + self.enter_process = True + + def download(self): + pass + + def process(self): + pass + + class DS2(InMemoryDataset): + def __init__(self): + self.enter_download = False + self.enter_process = False + super().__init__() + + def _download(self): + self.enter_download = True + + def _process(self): + self.enter_process = True + + def process(self): + pass + + class DS3(InMemoryDataset): + def __init__(self): + self.enter_download = False + self.enter_process = False + super().__init__() + + def _download(self): + self.enter_download = True + + def _process(self): + self.enter_process = True + + class DS4(DS1): + pass + + ds = DS1() + assert ds.enter_download + assert ds.enter_process + + ds = DS2() + assert not ds.enter_download + assert ds.enter_process + + ds = DS3() + assert not ds.enter_download + assert not ds.enter_process + + ds = DS4() + assert ds.enter_download + assert ds.enter_process + + +def test_lists_of_tensors_in_memory_dataset(): + def tr(n, m): + return ops.rand((n, m)) + + d1 = Graph(xs=[tr(4, 3), tr(11, 4), tr(1, 2)]) + d2 = Graph(xs=[tr(5, 3), tr(14, 4), tr(3, 2)]) + d3 = Graph(xs=[tr(6, 3), tr(15, 4), tr(2, 2)]) + d4 = Graph(xs=[tr(4, 3), tr(16, 4), tr(1, 2)]) + + data_list = [d1, d2, d3, d4] + + dataset = MyTestDataset(data_list) + assert len(dataset) == 4 + assert dataset[0].xs[1].shape == (11, 4) + assert dataset[0].xs[2].shape == (1, 2) + assert dataset[1].xs[0].shape == (5, 3) + assert dataset[2].xs[1].shape == (15, 4) + assert dataset[3].xs[1].shape == (16, 4) + + +def test_file_names_as_property_and_method(): + class MyTestDataset(InMemoryDataset): + def __init__(self): + super().__init__('/tmp/MyTestDataset') + + @property + def raw_file_names(self): + return ['test_file'] + + def download(self): + pass + + MyTestDataset() + + class MyTestDataset(InMemoryDataset): + def __init__(self): + super().__init__('/tmp/MyTestDataset') + + def raw_file_names(self): + return ['test_file'] + + def download(self): + pass + + MyTestDataset() diff --git a/tests/graph/data/test_dataset_summary.py b/tests/graph/data/test_dataset_summary.py new file mode 100644 index 000000000..02d90c4e3 --- /dev/null +++ b/tests/graph/data/test_dataset_summary.py @@ -0,0 +1,99 @@ +import mindspore as ms +from mindspore import Tensor +import numpy as np +from mindscience.sharker.data.summary import Stats, Summary +from mindscience.sharker.datasets import FakeDataset, FakeHeteroDataset +from mindscience.sharker.testing import withPackage + + +def check_stats(stats: Stats, expected: Tensor): + expected = expected.to(ms.single) + assert stats.mean == float(expected.mean()) + assert stats.std == float(expected.std()) + assert stats.min == float(expected.min()) + assert stats.quantile25 == float(np.quantile(expected.asnumpy(), 0.25)) + assert stats.median == float(expected.median()[0]) + assert stats.quantile75 == float(np.quantile(expected.asnumpy(), 0.75)) + assert stats.max == float(expected.max()) + + +def test_dataset_summary(): + dataset = FakeDataset(num_graphs=10) + num_nodes = Tensor([data.num_nodes for data in dataset]) + num_edges = Tensor([data.num_edges for data in dataset]) + + summary = dataset.get_summary() + + assert summary.name == 'FakeDataset' + assert summary.num_graphs == 10 + + check_stats(summary.num_nodes, num_nodes) + check_stats(summary.num_edges, num_edges) + + +@withPackage('tabulate') +def test_dataset_summary_representation(): + dataset = FakeDataset(num_graphs=10) + + summary1 = Summary.from_dataset(dataset, per_type=False) + summary2 = Summary.from_dataset(dataset, per_type=True) + + assert str(summary1) == str(summary2) + + +@withPackage('tabulate') +def test_dataset_summary_hetero(): + dataset1 = FakeHeteroDataset(num_graphs=10) + summary1 = Summary.from_dataset(dataset1, per_type=False) + + dataset2 = [data.to_homogeneous() for data in dataset1] + summary2 = Summary.from_dataset(dataset2) + summary2.name = 'FakeHeteroDataset' + + assert summary1 == summary2 + assert str(summary1) == str(summary2) + + +@withPackage('tabulate') +def test_dataset_summary_hetero_representation_length(): + dataset = FakeHeteroDataset(num_graphs=10) + summary = Summary.from_dataset(dataset) + num_lines = len(str(summary).splitlines()) + + stats_len = len(Stats.__dataclass_fields__) + len_header_and_border = 5 + num_tables = 3 # general, stats per node type, stats per edge type + + assert num_lines == num_tables * (stats_len + len_header_and_border) + + +def test_dataset_summary_hetero_per_type_check(): + dataset = FakeHeteroDataset(num_graphs=10) + exp_num_nodes = Tensor([data.num_nodes for data in dataset]) + exp_num_edges = Tensor([data.num_edges for data in dataset]) + + summary = dataset.get_summary() + + assert summary.name == 'FakeHeteroDataset' + assert summary.num_graphs == 10 + + check_stats(summary.num_nodes, exp_num_nodes) + check_stats(summary.num_edges, exp_num_edges) + + num_nodes_per_type = {} + for node_type in dataset.node_types: + num_nodes = [data[node_type].num_nodes for data in dataset] + num_nodes_per_type[node_type] = Tensor(num_nodes) + + assert len(summary.num_nodes_per_type) == len(dataset.node_types) + for node_type, stats in summary.num_nodes_per_type.items(): + check_stats(stats, num_nodes_per_type[node_type]) + + num_edges_per_type = {} + for edge_type in dataset.edge_types: + num_edges = [data[edge_type].num_edges for data in dataset] + num_edges_per_type[edge_type] = Tensor(num_edges) + + assert len(summary.num_edges_per_type) == len(dataset.edge_types) + for edge_type, stats in summary.num_edges_per_type.items(): + check_stats(stats, num_edges_per_type[edge_type]) diff --git a/tests/graph/data/test_feature_store.py b/tests/graph/data/test_feature_store.py new file mode 100644 index 000000000..1395df8b2 --- /dev/null +++ b/tests/graph/data/test_feature_store.py @@ -0,0 +1,109 @@ +from dataclasses import dataclass + +import pytest +from mindspore import Tensor, ops + +from mindscience.sharker.data import TensorAttr +from mindscience.sharker.data.feature_store import AttrView, _FieldStatus +from mindscience.sharker.testing import MyFeatureStore + + +@dataclass +class MyTensorAttrNoGroupName(TensorAttr): + def __init__(self, attr_name=_FieldStatus.UNSET, index=_FieldStatus.UNSET): + # Treat group_name as optional, and move it to the end + super().__init__(None, attr_name, index) + + +class MyFeatureStoreNoGroupName(MyFeatureStore): + def __init__(self): + super().__init__() + self._tensor_attr_cls = MyTensorAttrNoGroupName + + +def test_feature_store(): + store = MyFeatureStore() + tensor = Tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) + + group_name = 'A' + attr_name = 'feat' + index = Tensor([0, 1, 2]) + attr = TensorAttr(group_name, attr_name, index) + assert TensorAttr(group_name).update(attr) == attr + + # Normal API: + store.put_tensor(tensor, attr) + assert ops.equal(store.get_tensor(attr), tensor).all() + assert ops.equal( + store.get_tensor(group_name, attr_name, index=Tensor([0, 2])), + tensor[Tensor([0, 2])], + ).all() + + assert store.update_tensor(tensor + 1, attr) + assert ops.equal(store.get_tensor(attr), tensor + 1).all() + + store.remove_tensor(attr) + with pytest.raises(KeyError): + _ = store.get_tensor(attr) + + # Views: + view = store.view(group_name=group_name) + view.attr_name = attr_name + view['index'] = index + assert view != "not a 'AttrView' object" + assert view == AttrView(store, TensorAttr(group_name, attr_name, index)) + assert str(view) == ("AttrView(store=MyFeatureStore(), " + "attr=TensorAttr(group_name='A', attr_name='feat', " + "index=Tensor(shape=[3], dtype=Int64, value= [0, 1, 2])))") + + # Indexing: + store[group_name, attr_name, index] = tensor + + # Fully-specified forms, all of which produce a tensor output + assert ops.equal(store[group_name, attr_name, index], tensor).all() + assert ops.equal(store[group_name, attr_name, None], tensor).all() + assert ops.equal(store[group_name, attr_name, :], tensor).all() + assert ops.equal(store[group_name][attr_name][:], tensor).all() + assert ops.equal(store[group_name].feat[:], tensor).all() + assert ops.equal(store.view().A.feat[:], tensor).all() + + with pytest.raises(AttributeError) as exc_info: + _ = store.view(group_name=group_name, index=None).feat.A + print(exc_info) + + # Partially-specified forms, which produce an AttrView object + assert store[group_name] == store.view(TensorAttr(group_name=group_name)) + assert store[group_name].feat == store.view( + TensorAttr(group_name=group_name, attr_name=attr_name)) + + # Partially-specified forms, when called, produce a Tensor output + # from the `TensorAttr` that has been partially specified. + store[group_name] = tensor + assert isinstance(store[group_name], AttrView) + assert ops.equal(store[group_name](), tensor).all() + + # Deletion: + del store[group_name, attr_name, index] + with pytest.raises(KeyError): + _ = store[group_name, attr_name, index] + del store[group_name] + with pytest.raises(KeyError): + _ = store[group_name]() + + +def test_feature_store_override(): + store = MyFeatureStoreNoGroupName() + tensor = Tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [2.0, 2.0, 2.0]]) + + attr_name = 'feat' + index = Tensor([0, 1, 2]) + + # Only use attr_name and index, in that order: + store[attr_name, index] = tensor + + # A few assertions to ensure group_name is not needed: + assert isinstance(store[attr_name], AttrView) + assert ops.equal(store[attr_name, index], tensor).all() + assert ops.equal(store[attr_name][index], tensor).all() + assert ops.equal(store[attr_name][:], tensor).all() + assert ops.equal(store[attr_name, :], tensor).all() diff --git a/tests/graph/data/test_graph_store.py b/tests/graph/data/test_graph_store.py new file mode 100644 index 000000000..128fe9e79 --- /dev/null +++ b/tests/graph/data/test_graph_store.py @@ -0,0 +1,96 @@ +import pytest + +from mindspore import Tensor, ops +from mindscience.sharker.data.graph_store import EdgeAttr +from mindscience.sharker.sparse import Layout +from mindscience.sharker.testing import ( + MyGraphStore, + get_random_edge_index +) +from mindscience.sharker.utils import to_coo +from mindscience.sharker.sparse import ind2ptr + + +def test_graph_store(): + graph_store = MyGraphStore() + + assert str(graph_store) == 'MyGraphStore()' + + coo = Tensor([0, 1]), Tensor([1, 2]) + csr = Tensor([0, 1, 2]), Tensor([1, 2]) + # csc = Tensor([0, 1]), Tensor([0, 0, 1, 2]) + + graph_store['edge_type', Layout.COO] = coo + graph_store['edge_type', Layout.CSR] = csr + # graph_store['edge_type', Layout.CSC] = csc + + assert ops.equal(graph_store['edge_type', Layout.COO][0], coo[0]).all() + assert ops.equal(graph_store['edge_type', Layout.COO][1], coo[1]).all() + assert ops.equal(graph_store['edge_type', Layout.CSR][0], csr[0]).all() + assert ops.equal(graph_store['edge_type', Layout.CSR][1], csr[1]).all() + # assert ops.equal(graph_store['edge_type', Layout.CSC][0], csc[0]).all() + # assert ops.equal(graph_store['edge_type', Layout.CSC][1], csc[1]).all() + + assert len(graph_store.get_all_edge_attrs()) == 2 + + del graph_store['edge_type', Layout.COO] + with pytest.raises(KeyError): + graph_store['edge_type', Layout.COO] + + with pytest.raises(KeyError): + graph_store['edge_type_2', Layout.COO] + + +def test_graph_store_conversion(): + graph_store = MyGraphStore() + + edge_index = get_random_edge_index(100, 100, 300) + adj = to_coo(edge_index, shape=(100, 100)) + coo = (adj.indices[:, 0], adj.indices[:, 1]) + csr = (ind2ptr(adj.indices[:, 0]), adj.indices[:, 1]) + csc = (adj.indices[:, 0], ind2ptr(adj.indices[:, 1])) + + graph_store.put_edge_index(coo, ('v', '1', 'v'), Layout.COO, shape=(100, 100)) + graph_store.put_edge_index(csr, ('v', '2', 'v'), Layout.CSR, shape=(100, 100)) + graph_store.put_edge_index(csc, ('v', '3', 'v'), Layout.CSC, shape=(100, 100)) + + # Convert to COO: + row_dict, col_dict, perm_dict = graph_store.coo() + assert len(row_dict) == len(col_dict) == len(perm_dict) == 3 + for row, col, perm in zip(row_dict.values(), col_dict.values(), perm_dict.values()): + assert ops.equal(row.sort()[0], coo[0].sort()[0]).all() + assert ops.equal(col.sort()[0], coo[1].sort()[0]).all() + assert perm is None + + # Convert to CSR: + row_dict, col_dict, perm_dict = graph_store.csr() + assert len(row_dict) == len(col_dict) == len(perm_dict) == 3 + for row, col in zip(row_dict.values(), col_dict.values()): + assert ops.equal(row, csr[0]).all() + assert ops.equal(col.sort()[0], csr[1].sort()[0]).all() + + # Convert to CSC: + row_dict, col_dict, perm_dict = graph_store.csc() + assert len(row_dict) == len(col_dict) == len(perm_dict) == 3 + for row, col in zip(row_dict.values(), col_dict.values()): + assert ops.equal(row.sort()[0], csc[0].sort()[0]).all() + assert ops.equal(col, csc[1]).all() + + # Ensure that 'edge_types' parameters work as intended: + out = graph_store.coo([('v', '1', 'v')]) + assert ops.equal(list(out[0].values())[0], coo[0]).all() + assert ops.equal(list(out[1].values())[0], coo[1]).all() + + # Ensure that 'store' parameter works as intended: + key = EdgeAttr(edge_type=('v', '1', 'v'), layout=Layout.CSR, + is_sorted=False, shape=(100, 100)) + with pytest.raises(KeyError): + graph_store[key] + + out = graph_store.csr([('v', '1', 'v')], store=True) + assert ops.equal(list(out[0].values())[0], csr[0]).all() + assert ops.equal(list(out[1].values())[0].sort()[0], csr[1].sort()[0]).all() + + out = graph_store[key] + assert ops.equal(out[0], csr[0]).all() + assert ops.equal(out[1].sort()[0], csr[1].sort()[0]).all() diff --git a/tests/graph/data/test_hetero.py b/tests/graph/data/test_hetero.py new file mode 100644 index 000000000..7a7aa5ba0 --- /dev/null +++ b/tests/graph/data/test_hetero.py @@ -0,0 +1,625 @@ +import copy + +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.data import HeteroGraph +from mindscience.sharker.data.storage import EdgeStorage +from mindscience.sharker.testing import ( + get_random_edge_index, + withPackage, +) + +x_paper = ops.randn(10, 16) +x_author = ops.randn(5, 32) +x_conference = ops.randn(5, 8) + +idx_paper = ops.randint(0, x_paper.shape[0], (100, ), dtype=ms.int64) +idx_author = ops.randint(0, x_author.shape[0], (100, ), dtype=ms.int64) +idx_conference = ops.randint(0, x_conference.shape[0], (100, ), dtype=ms.int64) + +edge_index_paper_paper = mint.stack(([idx_paper[:50], idx_paper[:50]]), dim=0) +edge_index_paper_author = mint.stack(([idx_paper[:30], idx_author[:30]]), dim=0) +edge_index_author_paper = mint.stack(([idx_author[:30], idx_paper[:30]]), dim=0) +edge_index_paper_conference = mint.stack( + [idx_paper[:25], idx_conference[:25]], dim=0) + +edge_attr_paper_paper = ops.randn(edge_index_paper_paper.shape[1], 8) +edge_attr_author_paper = ops.randn(edge_index_author_paper.shape[1], 8) + + +def test_init_hetero_data(): + data = HeteroGraph() + data['v1'].x = 1 + data['paper'].x = x_paper + data['author'].x = x_author + data['paper', 'paper'].edge_index = edge_index_paper_paper + data['paper', 'author'].edge_index = edge_index_paper_author + data['author', 'paper'].edge_index = edge_index_author_paper + with pytest.warns(UserWarning, match="{'v1'} are isolated"): + data.validate(raise_on_error=True) + + assert len(data) == 2 + assert data.node_types == ['v1', 'paper', 'author'] + assert len(data.node_stores) == 3 + assert len(data.node_items()) == 3 + assert len(data.edge_types) == 3 + assert len(data.edge_stores) == 3 + assert len(data.edge_items()) == 3 + + data = HeteroGraph( + v1={'x': 1}, + paper={'x': x_paper}, + author={'x': x_author}, + paper__paper={'edge_index': edge_index_paper_paper}, + paper__author={'edge_index': edge_index_paper_author}, + author__paper={'edge_index': edge_index_author_paper}, + ) + + assert len(data) == 2 + assert data.node_types == ['v1', 'paper', 'author'] + assert len(data.node_stores) == 3 + assert len(data.node_items()) == 3 + assert len(data.edge_types) == 3 + assert len(data.edge_stores) == 3 + assert len(data.edge_items()) == 3 + + data = HeteroGraph({ + 'v1': { + 'x': 1 + }, + 'paper': { + 'x': x_paper + }, + 'author': { + 'x': x_author + }, + ('paper', 'paper'): { + 'edge_index': edge_index_paper_paper + }, + ('paper', 'author'): { + 'edge_index': edge_index_paper_author + }, + ('author', 'paper'): { + 'edge_index': edge_index_author_paper + }, + }) + + assert len(data) == 2 + assert data.node_types == ['v1', 'paper', 'author'] + assert len(data.node_stores) == 3 + assert len(data.node_items()) == 3 + assert len(data.edge_types) == 3 + assert len(data.edge_stores) == 3 + assert len(data.edge_items()) == 3 + + +def test_hetero_data_to_from_dict(): + data = HeteroGraph() + data.global_id = '1' + data['v1'].x = ops.randn(5, 16) + data['v2'].y = ops.randn(4, 16) + data['v1', 'v2'].edge_index = ms.Tensor([[0, 1, 2, 3], [0, 1, 2, 3]]) + + out = HeteroGraph.from_dict(data.to_dict()) + assert out.global_id == data.global_id + assert ops.equal(out['v1'].x, data['v1'].x).all() + assert ops.equal(out['v2'].y, data['v2'].y).all() + assert ops.equal(out['v1', 'v2'].edge_index, data['v1', 'v2'].edge_index).all() + + +def test_hetero_data_functions(): + data = HeteroGraph() + with pytest.raises(KeyError, match="did not find any occurrences of it"): + data.collect('x') + data['paper'].x = x_paper + data['author'].x = x_author + data['paper', 'paper'].edge_index = edge_index_paper_paper + data['paper', 'author'].edge_index = edge_index_paper_author + data['author', 'paper'].edge_index = edge_index_author_paper + data['paper', 'paper'].edge_attr = edge_attr_paper_paper + assert len(data) == 3 + assert sorted(data.keys()) == ['edge_attr', 'edge_index', 'x'] + assert 'x' in data and 'edge_index' in data and 'edge_attr' in data + assert data.num_nodes == 15 + assert data.num_edges == 110 + + assert data.node_attrs() == ['x'] + assert sorted(data.edge_attrs()) == ['edge_attr', 'edge_index'] + + assert data.num_node_features == {'paper': 16, 'author': 32} + assert data.num_edge_features == { + ('paper', 'to', 'paper'): 8, + ('paper', 'to', 'author'): 0, + ('author', 'to', 'paper'): 0, + } + + node_types, edge_types = data.metadata() + assert node_types == ['paper', 'author'] + assert edge_types == [ + ('paper', 'to', 'paper'), + ('paper', 'to', 'author'), + ('author', 'to', 'paper'), + ] + + x_dict = data.collect('x') + assert len(x_dict) == 2 + assert x_dict['paper'].tolist() == x_paper.tolist() + assert x_dict['author'].tolist() == x_author.tolist() + assert x_dict == data.x_dict + + data.y = 0 + assert data['y'] == 0 and data.y == 0 + assert len(data) == 4 + assert sorted(data.keys()) == ['edge_attr', 'edge_index', 'x', 'y'] + + del data['paper', 'author'] + node_types, edge_types = data.metadata() + assert node_types == ['paper', 'author'] + assert edge_types == [('paper', 'to', 'paper'), ('author', 'to', 'paper')] + + assert len(data.to_dict()) == 5 + assert len(data.to_namedtuple()) == 5 + assert data.to_namedtuple().y == 0 + assert len(data.to_namedtuple().paper) == 1 + + +def test_hetero_data_set_value_dict(): + data = HeteroGraph() + data.set_value_dict('x', { + 'paper': ops.randn(4, 16), + 'author': ops.randn(8, 32), + }) + assert data.node_types == ['paper', 'author'] + assert data.edge_types == [] + assert data['paper'].x.shape == (4, 16) + assert data['author'].x.shape == (8, 32) + + +def test_hetero_data_rename(): + data = HeteroGraph() + data['paper'].x = x_paper + data['author'].x = x_author + data['paper', 'paper'].edge_index = edge_index_paper_paper + data['paper', 'author'].edge_index = edge_index_paper_author + data['author', 'paper'].edge_index = edge_index_author_paper + + data = data.rename('paper', 'article') + assert data.node_types == ['author', 'article'] + assert data.edge_types == [ + ('article', 'to', 'article'), + ('article', 'to', 'author'), + ('author', 'to', 'article'), + ] + + assert data['article'].x.tolist() == x_paper.tolist() + edge_index = data['article', 'article'].edge_index + assert edge_index.tolist() == edge_index_paper_paper.tolist() + + +def test_dangling_types(): + data = HeteroGraph() + data['src', 'to', 'dst'].edge_index = ops.randint(0, 10, (2, 20)) + with pytest.raises(ValueError, match="do not exist as node types"): + data.validate() + + data = HeteroGraph() + data['node'].num_nodes = 10 + with pytest.warns(UserWarning, match="{'node'} are isolated"): + data.validate() + + +def test_hetero_data_subgraph(): + data = HeteroGraph() + data.num_node_types = 3 + data['paper'].x = x_paper + data['paper'].name = 'paper' + data['paper'].num_nodes = x_paper.shape[0] + data['author'].x = x_author + data['author'].num_nodes = x_author.shape[0] + data['conf'].x = x_conference + data['conf'].num_nodes = x_conference.shape[0] + data['paper', 'paper'].edge_index = edge_index_paper_paper + data['paper', 'paper'].edge_attr = edge_attr_paper_paper + data['paper', 'paper'].name = 'cites' + data['author', 'paper'].edge_index = edge_index_author_paper + data['paper', 'author'].edge_index = edge_index_paper_author + data['paper', 'conf'].edge_index = edge_index_paper_conference + + subset = { + 'paper': ops.shuffle(mint.arange(x_paper.shape[0]))[:4], + 'author': ops.shuffle(mint.arange(x_author.shape[0]))[:2], + 'conf': ops.shuffle(mint.arange(x_conference.shape[0]))[:2], + } + + out = data.subgraph(subset) + out.validate(raise_on_error=True) + + assert out.num_node_types == data.num_node_types + assert out.node_types == ['paper', 'author', 'conf'] + + for key in out.node_types: + assert len(out[key]) == len(data[key]) + assert mint.isclose(out[key].x, data[key].x[subset[key]]).all() + assert out[key].num_nodes == subset[key].shape[0] + if key == 'paper': + assert out['paper'].name == 'paper' + + # Construct correct edge index manually: + node_mask = {} # for each node type a mask of nodes in the subgraph + node_map = {} # for each node type a map from old node id to new node id + for key in out.node_types: + # node_mask[key] = mint.zeros((data[key].num_nodes, )).bool() + node_mask[key] = mint.zeros((data[key].num_nodes, ), dtype=ms.bool_) + node_map[key] = mint.zeros((data[key].num_nodes, ), dtype=ms.int64) + node_mask[key][subset[key]] = True + node_map[key][subset[key]] = mint.arange(subset[key].shape[0]) + + edge_mask = {} # for each edge type a mask of edges in the subgraph + subgraph_edge_index = { + } # for each edge type the edge index of the subgraph + for key in out.edge_types: + edge_mask[key] = mint.logical_and(mint.index_select(node_mask[key[0]], 0, data[key].edge_index[0]), + mint.index_select(node_mask[key[-1]], 0, data[key].edge_index[1])) + subgraph_edge_index[key] = ops.masked_select(data[key].edge_index, edge_mask[key]).view(2, -1) + subgraph_edge_index[key][0] = node_map[key[0]][subgraph_edge_index[key] + [0]] + subgraph_edge_index[key][1] = node_map[key[-1]][ + subgraph_edge_index[key][1]] + + assert out.edge_types == [ + ('paper', 'to', 'paper'), + ('author', 'to', 'paper'), + ('paper', 'to', 'author'), + ('paper', 'to', 'conf'), + ] + + for key in out.edge_types: + assert len(out[key]) == len(data[key]) + assert ops.equal(out[key].edge_index, subgraph_edge_index[key]).all() + if key == ('paper', 'to', 'paper'): + assert mint.isclose(out[key].edge_attr, data[key].edge_attr[edge_mask[key]]).all() + assert out[key].name == 'cites' + + # Test for bool and long in `subset_dict`. + author_mask = mint.zeros((x_author.shape[0], )).bool() + author_mask[subset['author']] = True + subset_mixed = { + 'paper': subset['paper'], + 'author': author_mask, + } + out = data.subgraph(subset_mixed) + out.validate(raise_on_error=True) + + assert out.num_node_types == data.num_node_types + assert out.node_types == ['paper', 'author', 'conf'] + assert out['paper'].num_nodes == subset['paper'].shape[0] + assert out['author'].num_nodes == subset['author'].shape[0] + assert out['conf'].num_nodes == data['conf'].num_nodes + assert out.edge_types == [ + ('paper', 'to', 'paper'), + ('author', 'to', 'paper'), + ('paper', 'to', 'author'), + ('paper', 'to', 'conf'), + ] + + out = data.node_type_subgraph(['paper', 'author']) + assert out.node_types == ['paper', 'author'] + assert out.edge_types == [('paper', 'to', 'paper'), + ('author', 'to', 'paper'), + ('paper', 'to', 'author')] + + out = data.edge_type_subgraph([('paper', 'author')]) + assert out.node_types == ['paper', 'author'] + assert out.edge_types == [('paper', 'to', 'author')] + + subset = { + ('paper', 'to', 'paper'): mint.arange(4), + } + + out = data.edge_subgraph(subset) + assert out.node_types == data.node_types + assert out.edge_types == data.edge_types + assert data['paper'] == out['paper'] + assert data['author'] == out['author'] + assert data['paper', 'author'] == out['paper', 'author'] + assert data['author', 'paper'] == out['author', 'paper'] + + assert out['paper', 'paper'].num_edges == 4 + assert out['paper', 'paper'].edge_index.shape == (2, 4) + assert out['paper', 'paper'].edge_attr.shape == (4, 8) + + +def test_hetero_data_empty_subgraph(): + data = HeteroGraph() + data.num_node_types = 3 + data['paper'].x = mint.arange(5) + data['author'].x = mint.arange(5) + data['paper', 'author'].edge_weight = mint.arange(5) + + out = data.subgraph(subset_dict={ + 'paper': ms.Tensor([1, 2, 3]), + 'author': ms.Tensor([1, 2, 3]), + }) + + assert ops.equal(out['paper'].x, mint.arange(1, 4)).all() + assert out['paper'].num_nodes == 3 + assert ops.equal(out['author'].x, mint.arange(1, 4)).all() + assert out['author'].num_nodes == 3 + assert 'edge_index' not in out['paper', 'author'] + assert ops.equal(out['paper', 'author'].edge_weight, mint.arange(5)).all() + + +def test_copy_hetero_data(): + data = HeteroGraph() + data['paper'].x = x_paper + data['paper', 'to', 'paper'].edge_index = edge_index_paper_paper + + out = copy.copy(data) + assert id(data) != id(out) + assert len(data.stores) == len(out.stores) + for store1, store2 in zip(data.stores, out.stores): + assert id(store1) != id(store2) + assert id(data) == id(store1._parent()) + assert id(out) == id(store2._parent()) + assert out['paper']._key == 'paper' + assert data['paper'].x is out['paper'].x + assert out['to']._key == ('paper', 'to', 'paper') + assert data['to'].edge_index is out['to'].edge_index + + out = copy.deepcopy(data) + assert id(data) != id(out) + assert len(data.stores) == len(out.stores) + for store1, store2 in zip(data.stores, out.stores): + assert id(store1) != id(store2) + assert id(out) == id(out['paper']._parent()) + assert out['paper']._key == 'paper' + assert data['paper'].x is not out['paper'].x + assert data['paper'].x.tolist() == out['paper'].x.tolist() + assert id(out) == id(out['to']._parent()) + assert out['to']._key == ('paper', 'to', 'paper') + assert data['to'].edge_index is not out['to'].edge_index + assert data['to'].edge_index.tolist() == out['to'].edge_index.tolist() + + +def test_to_homogeneous_and_vice_versa(): + data = HeteroGraph() + + data['paper'].x = ops.randn(100, 128) + data['paper'].y = ops.randint(0, 10, (100, )) + data['author'].x = ops.randn(200, 128) + + data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 250) + data['paper', 'paper'].edge_weight = ops.randn(250, ) + data['paper', 'paper'].edge_attr = ops.randn(250, 64) + + data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 500) + data['paper', 'author'].edge_weight = ops.randn(500, ) + data['paper', 'author'].edge_attr = ops.randn(500, 64) + + data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) + data['author', 'paper'].edge_weight = ops.randn(1000, ) + data['author', 'paper'].edge_attr = ops.randn(1000, 64) + + out = data.to_homogeneous() + assert len(out) == 7 + assert out.num_nodes == 300 + assert out.num_edges == 1750 + assert out.num_node_features == 128 + assert out.num_edge_features == 64 + assert out.node_type.shape == (300, ) + assert out.node_type.min() == 0 + assert out.node_type.max() == 1 + assert out.edge_type.shape == (1750, ) + assert out.edge_type.min() == 0 + assert out.edge_type.max() == 2 + assert len(out._node_type_names) == 2 + assert len(out._edge_type_names) == 3 + assert out.y.shape == (300, ) + assert mint.isclose(out.y[:100], data['paper'].y).all() + assert mint.all(out.y[100:] == -1) + assert 'y' not in data['author'] + + out = out.to_hetero() + assert len(out) == 5 + assert mint.isclose(data['paper'].x, out['paper'].x).all() + assert mint.isclose(data['author'].x, out['author'].x).all() + assert mint.isclose(data['paper'].y, out['paper'].y).all() + + edge_index1 = data['paper', 'paper'].edge_index + edge_index2 = out['paper', 'paper'].edge_index + assert edge_index1.tolist() == edge_index2.tolist() + assert mint.isclose( + data['paper', 'paper'].edge_weight, + out['paper', 'paper'].edge_weight, + ).all() + assert mint.isclose( + data['paper', 'paper'].edge_attr, + out['paper', 'paper'].edge_attr, + ).all() + + edge_index1 = data['paper', 'author'].edge_index + edge_index2 = out['paper', 'author'].edge_index + assert edge_index1.tolist() == edge_index2.tolist() + assert mint.isclose( + data['paper', 'author'].edge_weight, + out['paper', 'author'].edge_weight, + ).all() + assert mint.isclose( + data['paper', 'author'].edge_attr, + out['paper', 'author'].edge_attr, + ).all() + + edge_index1 = data['author', 'paper'].edge_index + edge_index2 = out['author', 'paper'].edge_index + assert edge_index1.tolist() == edge_index2.tolist() + assert mint.isclose( + data['author', 'paper'].edge_weight, + out['author', 'paper'].edge_weight, + ).all() + assert mint.isclose( + data['author', 'paper'].edge_attr, + out['author', 'paper'].edge_attr, + ).all() + + out = data.to_homogeneous() + node_type = out.node_type + edge_type = out.edge_type + del out.node_type + del out.edge_type + del out._edge_type_names + del out._node_type_names + out = out.to_hetero(node_type, edge_type) + assert len(out) == 5 + assert mint.isclose(data['paper'].x, out['0'].x).all() + assert mint.isclose(data['author'].x, out['1'].x).all() + + edge_index1 = data['paper', 'paper'].edge_index + edge_index2 = out['0', '0'].edge_index + assert edge_index1.tolist() == edge_index2.tolist() + assert mint.isclose( + data['paper', 'paper'].edge_weight, + out['0', '0'].edge_weight, + ).all() + assert mint.isclose( + data['paper', 'paper'].edge_attr, + out['0', '0'].edge_attr, + ).all() + + edge_index1 = data['paper', 'author'].edge_index + edge_index2 = out['0', '1'].edge_index + assert edge_index1.tolist() == edge_index2.tolist() + assert mint.isclose( + data['paper', 'author'].edge_weight, + out['0', '1'].edge_weight, + ).all() + assert mint.isclose( + data['paper', 'author'].edge_attr, + out['0', '1'].edge_attr, + ).all() + + edge_index1 = data['author', 'paper'].edge_index + edge_index2 = out['1', '0'].edge_index + assert edge_index1.tolist() == edge_index2.tolist() + assert mint.isclose( + data['author', 'paper'].edge_weight, + out['1', '0'].edge_weight, + ).all() + assert mint.isclose( + data['author', 'paper'].edge_attr, + out['1', '0'].edge_attr, + ).all() + + data = HeteroGraph() + + data['paper'].num_nodes = 100 + data['author'].num_nodes = 200 + + out = data.to_homogeneous(add_node_type=False) + assert len(out) == 1 + assert out.num_nodes == 300 + + out = data.to_homogeneous().to_hetero() + assert len(out) == 1 + assert out['paper'].num_nodes == 100 + assert out['author'].num_nodes == 200 + + +def test_to_homogeneous_padding(): + data = HeteroGraph() + data['paper'].x = ops.randn(100, 128) + data['author'].x = ops.randn(50, 64) + + out = data.to_homogeneous() + assert len(out) == 2 + assert out.node_type.shape == (150, ) + assert out.node_type[:100].abs().sum() == 0 + assert out.node_type[100:].sub(1).abs().sum() == 0 + assert out.x.shape == (150, 128) + assert ops.equal(out.x[:100], data['paper'].x).all() + assert ops.equal(out.x[100:, :64], data['author'].x).all() + assert out.x[100:, 64:].abs().sum() == 0 + + +def test_hetero_data_to_canonical(): + data = HeteroGraph() + assert isinstance(data['user', 'product'], EdgeStorage) + assert len(data.edge_types) == 1 + assert isinstance(data['user', 'to', 'product'], EdgeStorage) + assert len(data.edge_types) == 1 + + data = HeteroGraph() + assert isinstance(data['user', 'buys', 'product'], EdgeStorage) + assert isinstance(data['user', 'clicks', 'product'], EdgeStorage) + assert len(data.edge_types) == 2 + + with pytest.raises(TypeError, match="missing 1 required"): + data['user', 'product'] + + +def test_hetero_data_invalid_names(): + data = HeteroGraph() + with pytest.warns(UserWarning, match="single underscores"): + data['my test', 'a__b', 'my test'].edge_attr = ops.randn(10, 16) + assert data.edge_types == [('my test', 'a__b', 'my test')] + + +def test_hetero_data_update(): + data = HeteroGraph() + data['paper'].x = mint.arange(0, 5) + data['paper'].y = mint.arange(5, 10) + data['author'].x = mint.arange(10, 15) + + other = HeteroGraph() + other['paper'].x = mint.arange(15, 20) + other['author'].y = mint.arange(20, 25) + other['paper', 'paper'].edge_index = ops.randint(0, 5, (2, 20)) + + data.update(other) + assert len(data) == 3 + assert ops.equal(data['paper'].x, mint.arange(15, 20)).all() + assert ops.equal(data['paper'].y, mint.arange(5, 10)).all() + assert ops.equal(data['author'].x, mint.arange(10, 15)).all() + assert ops.equal(data['author'].y, mint.arange(20, 25)).all() + assert ops.equal(data['paper', 'paper'].edge_index, + other['paper', 'paper'].edge_index).all() + +def test_generate_ids(): + data = HeteroGraph() + + data['paper'].x = ops.randn(100, 128) + data['author'].x = ops.randn(200, 128) + + data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 300) + data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 400) + assert len(data) == 2 + + data.generate_ids() + assert len(data) == 4 + assert data['paper'].n_id.tolist() == list(range(100)) + assert data['author'].n_id.tolist() == list(range(200)) + assert data['paper', 'author'].e_id.tolist() == list(range(300)) + assert data['author', 'paper'].e_id.tolist() == list(range(400)) + + +def test_invalid_keys(): + data = HeteroGraph() + + data['paper'].x = ops.randn(10, 128) + data['paper'].node_attrs = ['y'] + data['paper', 'paper'].edge_index = get_random_edge_index(10, 10, 20) + data['paper', 'paper'].edge_attrs = ['edge_attr'] + + assert data['paper'].node_attrs() == ['x'] + assert data['paper']['node_attrs'] == ['y'] + assert data['paper', 'paper'].edge_attrs() == ['edge_index'] + assert data['paper', 'paper']['edge_attrs'] == ['edge_attr'] + + out = data.to_homogeneous() + assert set(out.node_attrs()) == {'x', 'node_type'} + assert set(out.edge_attrs()) == {'edge_index', 'edge_type'} + + +if __name__ == '__main__': + test_hetero_data_subgraph() + diff --git a/tests/graph/data/test_hypergraph.py b/tests/graph/data/test_hypergraph.py new file mode 100644 index 000000000..53110df97 --- /dev/null +++ b/tests/graph/data/test_hypergraph.py @@ -0,0 +1,156 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.data.hypergraph import HyperGraph +from mindscience.sharker.loader import Dataloader + + +def test_hypergraph_data(): + + x = ms.Tensor([[1, 3, 5, 7], [2, 4, 6, 8], [7, 8, 9, 10]], + dtype=ms.float32).t() + edge_index = ms.Tensor([[0, 1, 2, 1, 2, 3, 0, 2, 3], + [0, 0, 0, 1, 1, 1, 2, 2, 2]]) + data = HyperGraph(x=x, edge_index=edge_index) + data.validate(raise_on_error=True) + + assert data.num_nodes == 4 + assert data.num_edges == 3 + + assert data.node_attrs() == ['x'] + assert data.edge_attrs() == ['edge_index'] + + assert data.x.tolist() == x.tolist() + assert data['x'].tolist() == x.tolist() + assert data.get('x').tolist() == x.tolist() + assert data.get('y', 2) == 2 + assert data.get('y', None) is None + + assert sorted(data.keys()) == ['edge_index', 'x'] + assert len(data) == 2 + assert 'x' in data and 'edge_index' in data and 'pos' not in data + + D = data.to_dict() + assert len(D) == 2 + assert 'x' in D and 'edge_index' in D + + D = data.to_namedtuple() + assert len(D) == 2 + assert D.x is not None and D.edge_index is not None + + assert data.__cat_dim__('x', data.x) == 0 + assert data.__cat_dim__('edge_index', data.edge_index) == -1 + assert data.__inc__('x', data.x) == 0 + assert ops.equal(data.__inc__('edge_index', data.edge_index), + ms.Tensor([[data.num_nodes], [data.num_edges]])).all() + data_list = [data, data] + loader = Dataloader(data_list) + loader = loader.batch(batch_size=2) + batch = next(iter(loader)) + batched_edge_index = batch.edge_index + assert batched_edge_index.tolist() == [[ + 0, 1, 2, 1, 2, 3, 0, 2, 3, 4, 5, 6, 5, 6, 7, 4, 6, 7 + ], [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5]] + + assert not data.is_coalesced() + data = data.coalesce() + assert data.is_coalesced() + + clone = data.copy() + assert clone != data + assert len(clone) == len(data) + assert clone.x is not data.x + assert clone.x.tolist() == data.x.tolist() + assert clone.edge_index is not data.edge_index + assert clone.edge_index.tolist() == data.edge_index.tolist() + + data['x'] = x + 1 + assert data.x.tolist() == (x + 1).tolist() + + assert str(data) == 'HyperGraph(x=[4, 3], edge_index=[2, 9])' + + dictionary = {'x': data.x, 'edge_index': data.edge_index} + data = HyperGraph.from_dict(dictionary) + assert sorted(data.keys()) == ['edge_index', 'x'] + + assert not data.has_isolated_nodes() + + assert data.num_nodes == 4 + assert data.num_edges == 3 + assert data.num_node_features == 3 + assert data.num_features == 3 + + data.edge_attr = ops.randn(data.num_edges, 2) + assert data.num_edge_features == 2 + assert data.is_edge_attr('edge_attr') + data.edge_attr = None + + data.x = None + with pytest.warns(UserWarning, match='Unable to accurately infer'): + assert data.num_nodes == 4 + + data.edge_index = None + with pytest.warns(UserWarning, match='Unable to accurately infer'): + assert data.num_nodes is None + assert data.num_edges == 0 + + data.num_nodes = 4 + assert data.num_nodes == 4 + + data = HyperGraph(x=x, attribute=x) + assert len(data) == 2 + assert data.x.tolist() == x.tolist() + assert data.attribute.tolist() == x.tolist() + + face = ms.Tensor([[0, 1], [1, 2], [2, 3]]) + data = HyperGraph(num_nodes=4, face=face) + assert data.num_nodes == 4 + + data = HyperGraph(title='test') + assert str(data) == "HyperGraph(title='test')" + assert data.num_node_features == 0 + + key = value = 'test_value' + data[key] = value + assert data[key] == value + del data[value] + del data[value] # Deleting unset attributes should work as well. + + assert data.get(key) is None + assert data.get('title') == 'test' + + +def test_hypergraph_subgraph(): + x = mint.arange(5) + y = ms.Tensor([0.]) + edge_index = ms.Tensor([[0, 1, 3, 2, 4, 0, 3, 4, 2, 1, 2, 3], + [0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3]]) + edge_attr = ops.rand(4, 2) + data = HyperGraph(x=x, y=y, edge_index=edge_index, edge_attr=edge_attr, + num_nodes=5) + + out = data.subgraph(ms.Tensor([1, 2, 4])) + assert len(out) == 5 + assert ops.equal(out.x, ms.Tensor([1, 2, 4])).all() + assert ops.equal(out.y, data.y).all() + assert out.edge_index.tolist() == [[1, 2, 2, 1, 0, 1], [0, 0, 1, 1, 2, 2]] + assert ops.equal(out.edge_attr, edge_attr[[1, 2, 3]]).all() + assert out.num_nodes == 3 + + # Test unordered selection: + out = data.subgraph(ms.Tensor([3, 1, 2])) + assert len(out) == 5 + assert ops.equal(out.x, ms.Tensor([3, 1, 2])).all() + assert ops.equal(out.y, data.y).all() + assert out.edge_index.tolist() == [[0, 2, 0, 2, 1, 2, 0], + [0, 0, 1, 1, 2, 2, 2]] + assert ops.equal(out.edge_attr, edge_attr[[1, 2, 3]]).all() + assert out.num_nodes == 3 + + out = data.subgraph(ms.Tensor([False, False, False, True, True])) + assert len(out) == 5 + assert ops.equal(out.x, mint.arange(3, 5)).all() + assert ops.equal(out.y, data.y).all() + assert out.edge_index.tolist() == [[0, 1, 0, 1], [0, 0, 1, 1]] + assert ops.equal(out.edge_attr, edge_attr[[1, 2]]).all() + assert out.num_nodes == 2 diff --git a/tests/graph/data/test_inherit.py b/tests/graph/data/test_inherit.py new file mode 100644 index 000000000..e163c5525 --- /dev/null +++ b/tests/graph/data/test_inherit.py @@ -0,0 +1,61 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph, Dataset, InMemoryDataset + + +class MyData(Graph): + def __init__(self, x=None, edge_index=None, arg=None): + super().__init__(x=x, edge_index=edge_index, arg=arg) + + def random(self): + return ops.randn(list(self.x.shape) + list(self.arg.shape)) + + +class MyInMemoryDataset(InMemoryDataset): + def __init__(self): + super().__init__('/tmp/MyInMemoryDataset') + + x = ops.randn(4, 5) + edge_index = ms.Tensor([[0, 0, 0], [1, 2, 3]]) + arg = ops.randn(4, 3) + + data_list = [MyData(x, edge_index, arg) for _ in range(10)] + self.data, self.slices = self.collate(data_list) + + def _download(self): + pass + + def _process(self): + pass + + +class MyDataset(Dataset): + def __init__(self): + super().__init__('/tmp/MyDataset') + + def _download(self): + pass + + def _process(self): + pass + + def len(self): + return 10 + + def get(self, idx): + x = ops.randn(4, 5) + edge_index = ms.Tensor([[0, 0, 0], [1, 2, 3]]) + arg = ops.randn(4, 3) + return MyData(x, edge_index, arg) + + +def test_inherit(): + dataset = MyDataset() + assert len(dataset) == 10 + data = dataset[0] + assert data.random().shape == (4, 5, 4, 3) + + dataset = MyInMemoryDataset() + assert len(dataset) == 10 + data = dataset[0] + assert data.random().shape == (4, 5, 4, 3) diff --git a/tests/graph/data/test_on_disk_dataset.py b/tests/graph/data/test_on_disk_dataset.py new file mode 100644 index 000000000..0ee50de56 --- /dev/null +++ b/tests/graph/data/test_on_disk_dataset.py @@ -0,0 +1,111 @@ +import os.path as osp +from typing import Any, Dict + +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph, OnDiskDataset +from mindscience.sharker.testing import withPackage + + +@withPackage('sqlite3') +def test_pickle(tmp_path): + dataset = OnDiskDataset(tmp_path) + assert len(dataset) == 0 + assert str(dataset) == 'OnDiskDataset(0)' + assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db')) + + data_list = [ + Graph( + x=ops.randn(5, 8), + edge_index=ops.randint(0, 5, (2, 16)), + num_nodes=5, + ) for _ in range(4) + ] + + dataset.append(data_list[0]) + assert len(dataset) == 1 + + dataset.extend(data_list[1:]) + assert len(dataset) == 4 + + out = dataset.get(0) + assert ops.equal(out.x, data_list[0].x).all() + assert ops.equal(out.edge_index, data_list[0].edge_index).all() + assert out.num_nodes == data_list[0].num_nodes + + out_list = dataset.multi_get([1, 2, 3]) + for out, data in zip(out_list, data_list[1:]): + assert ops.equal(out.x, data.x).all() + assert ops.equal(out.edge_index, data.edge_index).all() + assert out.num_nodes == data.num_nodes + + dataset.close() + + # Test persistence of datasets: + dataset = OnDiskDataset(tmp_path) + assert len(dataset) == 4 + + out = dataset.get(0) + assert ops.equal(out.x, data_list[0].x).all() + assert ops.equal(out.edge_index, data_list[0].edge_index).all() + assert out.num_nodes == data_list[0].num_nodes + + dataset.close() + + +@withPackage('sqlite3') +def test_custom_schema(tmp_path): + class CustomSchemaOnDiskDataset(OnDiskDataset): + def __init__(self, root: str): + schema = { + 'x': dict(dtype=ms.float32, size=(-1, 8)), + 'edge_index': dict(dtype=ms.int64, size=(2, -1)), + 'num_nodes': int, + } + self.serialize_count = 0 + self.deserialize_count = 0 + super().__init__(root, schema=schema) + + def serialize(self, data: Graph) -> Dict[str, Any]: + self.serialize_count += 1 + return data.to_dict() + + def deserialize(self, mapping: Dict[str, Any]) -> Any: + self.deserialize_count += 1 + return Graph.from_dict(mapping) + + dataset = CustomSchemaOnDiskDataset(tmp_path) + assert len(dataset) == 0 + assert str(dataset) == 'CustomSchemaOnDiskDataset(0)' + assert osp.exists(osp.join(tmp_path, 'processed', 'sqlite.db')) + + data_list = [ + Graph( + x=ops.randn(5, 8), + edge_index=ops.randint(0, 5, (2, 16)), + num_nodes=5, + ) for _ in range(4) + ] + + dataset.append(data_list[0]) + assert dataset.serialize_count == 1 + assert len(dataset) == 1 + + dataset.extend(data_list[1:]) + assert dataset.serialize_count == 4 + assert len(dataset) == 4 + + out = dataset.get(0) + assert dataset.deserialize_count == 1 + assert ops.equal(out.x, data_list[0].x).all() + assert ops.equal(out.edge_index, data_list[0].edge_index).all() + assert out.num_nodes == data_list[0].num_nodes + + out_list = dataset.multi_get([1, 2, 3]) + assert dataset.deserialize_count == 4 + for out, data in zip(out_list, data_list[1:]): + assert ops.equal(out.x, data.x).all() + assert ops.equal(out.edge_index, data.edge_index).all() + assert out.num_nodes == data.num_nodes + + dataset.close() diff --git a/tests/graph/data/test_remote_backend_utils.py b/tests/graph/data/test_remote_backend_utils.py new file mode 100644 index 000000000..4c4995949 --- /dev/null +++ b/tests/graph/data/test_remote_backend_utils.py @@ -0,0 +1,33 @@ +import pytest +from mindspore import ops +from mindscience.sharker.data import HeteroGraph +from mindscience.sharker.data.remote_backend_utils import num_nodes, size +from mindscience.sharker.testing import ( + MyFeatureStore, + MyGraphStore, + get_random_edge_index, +) +from mindscience.sharker.sparse import Layout + + +@pytest.mark.parametrize('FeatureStore', [MyFeatureStore, HeteroGraph]) +@pytest.mark.parametrize('GraphStore', [MyGraphStore, HeteroGraph]) +def test_num_nodes_shape(FeatureStore, GraphStore): + feature_store = FeatureStore() + graph_store = GraphStore() + + # Infer num nodes from features: + x = ops.arange(100) + feature_store.put_tensor(x, group_name='x', attr_name='x', index=None) + assert num_nodes(feature_store, graph_store, 'x') == 100 + + # Infer num nodes and shape from edges: + xy = get_random_edge_index(100, 50, 20) + graph_store.put_edge_index(xy, edge_type=('x', 'to', 'y'), layout=Layout.COO, + shape=(100, 50)) + assert num_nodes(feature_store, graph_store, 'y') == 50 + assert size(feature_store, graph_store, ('x', 'to', 'y')) == (100, 50) + + # Throw an error if we cannot infer for an unknown node type: + with pytest.raises(ValueError, match="Unable to accurately infer"): + _ = num_nodes(feature_store, graph_store, 'z') diff --git a/tests/graph/data/test_storage.py b/tests/graph/data/test_storage.py new file mode 100644 index 000000000..4ebfbb40d --- /dev/null +++ b/tests/graph/data/test_storage.py @@ -0,0 +1,81 @@ +import copy +from typing import Any + +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.data.storage import BaseStorage + + +def test_base_storage(): + storage = BaseStorage() + assert storage._mapping == {} + storage.x = mint.zeros(1) + storage.y = mint.ones(1) + assert len(storage) == 2 + assert storage._mapping == {'x': mint.zeros(1), 'y': mint.ones(1)} + assert storage.x is not None + assert storage.y is not None + + assert mint.isclose(storage.get('x', None), storage.x, equal_nan= True).all() + assert mint.isclose(storage.get('y', None), storage.y, equal_nan= True).all() + assert storage.get('z', 2) == 2 + assert storage.get('z', None) is None + assert len(list(storage.keys('x', 'y', 'z'))) == 2 + assert len(list(storage.keys('x', 'y', 'z'))) == 2 + assert len(list(storage.values('x', 'y', 'z'))) == 2 + assert len(list(storage.items('x', 'y', 'z'))) == 2 + + del storage.y + assert len(storage) == 1 + assert storage.x is not None + + storage = BaseStorage({'x': mint.zeros(1)}) + assert len(storage) == 1 + assert storage.x is not None + + storage = BaseStorage(x=mint.zeros(1)) + assert len(storage) == 1 + assert storage.x is not None + + storage = BaseStorage(x=mint.zeros(1)) + copied_storage = copy.copy(storage) + assert storage == copied_storage + assert id(storage) != id(copied_storage) + assert storage.x is copied_storage.x + assert int(storage.x) == 0 + assert int(copied_storage.x) == 0 + + deepcopied_storage = copy.deepcopy(storage) + assert storage == deepcopied_storage + assert id(storage) != id(deepcopied_storage) + assert storage.x is not deepcopied_storage + assert int(storage.x) == 0 + assert int(deepcopied_storage.x) == 0 + + with pytest.raises(AttributeError, match="has no attribute 'asdf'"): + storage.asdf + + +def test_storage_tensor_methods(): + x = ops.randn(5) + storage = BaseStorage({'x': x}) + + storage = storage.copy() + assert storage.x is not x + + +def test_setter_and_getter(): + class MyStorage(BaseStorage): + @property + def my_property(self) -> Any: + return self._my_property + + @my_property.setter + def my_property(self, value: Any): + self._my_property = value + + storage = MyStorage() + storage.my_property = 'hello' + assert storage.my_property == 'hello' + assert storage._my_property == storage._my_property diff --git a/tests/graph/data/test_temporal.py b/tests/graph/data/test_temporal.py new file mode 100644 index 000000000..c26a68774 --- /dev/null +++ b/tests/graph/data/test_temporal.py @@ -0,0 +1,137 @@ +import copy + +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.data import TemporalGraph + + +def get_temporal_graph(num_events, msg_channels): + return TemporalGraph( + src=mint.arange(num_events), + dst=mint.arange(num_events, num_events * 2), + t=mint.arange(0, num_events * 1000, step=1000), + msg=ops.randn(num_events, msg_channels), + y=ops.randint(0, 2, (num_events, )), + ) + + +def test_temporal_grpah(): + data = get_temporal_graph(num_events=3, msg_channels=16) + assert str(data) == ("TemporalGraph(src=[3], dst=[3], t=[3], " + "msg=[3, 16], y=[3])") + + assert data.num_nodes == 6 + assert data.num_events == data.num_edges == len(data) == 3 + + assert data.src.tolist() == [0, 1, 2] + assert data['src'].tolist() == [0, 1, 2] + + assert data.edge_index.tolist() == [[0, 1, 2], [3, 4, 5]] + data.edge_index = 'edge_index' + assert data.edge_index == 'edge_index' + del data.edge_index + assert data.edge_index.tolist() == [[0, 1, 2], [3, 4, 5]] + + assert sorted(data.keys()) == ['dst', 'msg', 'src', 't', 'y'] + assert sorted(data.to_dict().keys()) == sorted(data.keys()) + + data_tuple = data.to_namedtuple() + assert len(data_tuple) == 5 + assert data_tuple.src is not None + assert data_tuple.dst is not None + assert data_tuple.t is not None + assert data_tuple.msg is not None + assert data_tuple.y is not None + + assert data.__cat_dim__('src', data.src) == 0 + assert data.__inc__('src', data.src) == 6 + + clone = data.copy() + assert clone != data + assert len(clone) == len(data) + assert clone.src is not data.src + assert clone.src.tolist() == data.src.tolist() + assert clone.dst is not data.dst + assert clone.dst.tolist() == data.dst.tolist() + + deepcopy = copy.deepcopy(data) + assert deepcopy != data + assert len(deepcopy) == len(data) + assert deepcopy.src is not data.src + assert deepcopy.src.tolist() == data.src.tolist() + assert deepcopy.dst is not data.dst + assert deepcopy.dst.tolist() == data.dst.tolist() + + key = value = 'test_value' + data[key] = value + assert data[key] == value + assert data.test_value == value + del data[key] + del data[key] + + assert data.get(key, 10) == 10 + + assert len([event for event in data]) == 3 + + assert len([attr for attr in data()]) == 5 + + assert data.shape == (2, 5) + + del data.src + assert 'src' not in data + + +def test_train_val_test_split(): + data = get_temporal_graph(num_events=100, msg_channels=16) + + train_data, val_data, test_data = data.train_val_test_split( + val_ratio=0.2, test_ratio=0.15) + + assert len(train_data) == 65 + assert len(val_data) == 20 + assert len(test_data) == 15 + + assert train_data.t.max() < val_data.t.min() + assert val_data.t.max() < test_data.t.min() + + +def test_temporal_indexing(): + data = get_temporal_graph(num_events=10, msg_channels=16) + + elem = data[0] + assert isinstance(elem, TemporalGraph) + assert len(elem) == 1 + assert elem.src.tolist() == data.src[0:1].tolist() + assert elem.dst.tolist() == data.dst[0:1].tolist() + assert elem.t.tolist() == data.t[0:1].tolist() + assert elem.msg.tolist() == data.msg[0:1].tolist() + assert elem.y.tolist() == data.y[0:1].tolist() + + subset = data[0:5] + assert isinstance(subset, TemporalGraph) + assert len(subset) == 5 + assert subset.src.tolist() == data.src[0:5].tolist() + assert subset.dst.tolist() == data.dst[0:5].tolist() + assert subset.t.tolist() == data.t[0:5].tolist() + assert subset.msg.tolist() == data.msg[0:5].tolist() + assert subset.y.tolist() == data.y[0:5].tolist() + + index = [0, 4, 8] + subset = data[ms.Tensor(index)] + assert isinstance(subset, TemporalGraph) + assert len(subset) == 3 + assert subset.src.tolist() == data.src[0::4].tolist() + assert subset.dst.tolist() == data.dst[0::4].tolist() + assert subset.t.tolist() == data.t[0::4].tolist() + assert subset.msg.tolist() == data.msg[0::4].tolist() + assert subset.y.tolist() == data.y[0::4].tolist() + + mask = [True, False, True, False, True, False, True, False, True, False] + subset = data[ms.Tensor(mask)] + assert isinstance(subset, TemporalGraph) + assert len(subset) == 5 + assert subset.src.tolist() == data.src[0::2].tolist() + assert subset.dst.tolist() == data.dst[0::2].tolist() + assert subset.t.tolist() == data.t[0::2].tolist() + assert subset.msg.tolist() == data.msg[0::2].tolist() + assert subset.y.tolist() == data.y[0::2].tolist() diff --git a/tests/graph/data/test_view.py b/tests/graph/data/test_view.py new file mode 100644 index 000000000..8375a41b2 --- /dev/null +++ b/tests/graph/data/test_view.py @@ -0,0 +1,32 @@ +import mindspore as ms +from mindscience.sharker.data.storage import BaseStorage + + +def test_views(): + storage = BaseStorage(x=1, y=2, z=3) + + assert str(storage.keys()) == "KeysView({'x': 1, 'y': 2, 'z': 3})" + assert len(storage.keys()) == 3 + assert list(storage.keys()) == ['x', 'y', 'z'] + + assert str(storage.values()) == "ValuesView({'x': 1, 'y': 2, 'z': 3})" + assert len(storage.values()) == 3 + assert list(storage.values()) == [1, 2, 3] + + assert str(storage.items()) == "ItemsView({'x': 1, 'y': 2, 'z': 3})" + assert len(storage.items()) == 3 + assert list(storage.items()) == [('x', 1), ('y', 2), ('z', 3)] + + args = ['x', 'z', 'foo'] + + assert str(storage.keys(*args)) == "KeysView({'x': 1, 'z': 3})" + assert len(storage.keys(*args)) == 2 + assert list(storage.keys(*args)) == ['x', 'z'] + + assert str(storage.values(*args)) == "ValuesView({'x': 1, 'z': 3})" + assert len(storage.values(*args)) == 2 + assert list(storage.values(*args)) == [1, 3] + + assert str(storage.items(*args)) == "ItemsView({'x': 1, 'z': 3})" + assert len(storage.items(*args)) == 2 + assert list(storage.items(*args)) == [('x', 1), ('z', 3)] diff --git a/tests/graph/datasets/graph_generator/test_ba_graph.py b/tests/graph/datasets/graph_generator/test_ba_graph.py new file mode 100644 index 000000000..29523d5ba --- /dev/null +++ b/tests/graph/datasets/graph_generator/test_ba_graph.py @@ -0,0 +1,11 @@ +from mindscience.sharker.datasets.graph_generator import BAGraph + + +def test_ba_graph(): + graph_generator = BAGraph(num_nodes=300, num_edges=5) + assert str(graph_generator) == 'BAGraph(num_nodes=300, num_edges=5)' + + data = graph_generator() + assert len(data) == 2 + assert data.num_nodes == 300 + assert data.num_edges <= 2 * 300 * 5 diff --git a/tests/graph/datasets/graph_generator/test_er_graph.py b/tests/graph/datasets/graph_generator/test_er_graph.py new file mode 100644 index 000000000..567e16249 --- /dev/null +++ b/tests/graph/datasets/graph_generator/test_er_graph.py @@ -0,0 +1,12 @@ +from mindscience.sharker.datasets.graph_generator import ERGraph + + +def test_er_graph(): + graph_generator = ERGraph(num_nodes=300, edge_prob=0.1) + assert str(graph_generator) == 'ERGraph(num_nodes=300, edge_prob=0.1)' + + data = graph_generator() + assert len(data) == 2 + assert data.num_nodes == 300 + assert data.num_edges >= 300 * 300 * 0.05 + assert data.num_edges <= 300 * 300 * 0.15 diff --git a/tests/graph/datasets/graph_generator/test_grid_graph.py b/tests/graph/datasets/graph_generator/test_grid_graph.py new file mode 100644 index 000000000..87405e37a --- /dev/null +++ b/tests/graph/datasets/graph_generator/test_grid_graph.py @@ -0,0 +1,11 @@ +from mindscience.sharker.datasets.graph_generator import GridGraph + + +def test_grid_graph(): + graph_generator = GridGraph(height=10, width=10) + assert str(graph_generator) == 'GridGraph(height=10, width=10)' + + data = graph_generator() + assert len(data) == 2 + assert data.num_nodes == 100 + assert data.num_edges == 784 diff --git a/tests/graph/datasets/graph_generator/test_tree_graph.py b/tests/graph/datasets/graph_generator/test_tree_graph.py new file mode 100644 index 000000000..89f8e4cd4 --- /dev/null +++ b/tests/graph/datasets/graph_generator/test_tree_graph.py @@ -0,0 +1,25 @@ +import pytest + +from mindscience.sharker.datasets.graph_generator import TreeGraph + + +@pytest.mark.parametrize('undirected', [False, True]) +def test_tree_graph(undirected): + graph_generator = TreeGraph(depth=2, branch=2, undirected=undirected) + assert str(graph_generator) == (f'TreeGraph(depth=2, branch=2, ' + f'undirected={undirected})') + + data = graph_generator() + assert len(data) == 3 + assert data.num_nodes == 7 + assert data.depth.tolist() == [0, 1, 1, 2, 2, 2, 2] + if not undirected: + assert data.edge_index.tolist() == [ + [0, 0, 1, 1, 2, 2], + [1, 2, 3, 4, 5, 6], + ] + else: + assert data.edge_index.tolist() == [ + [0, 0, 1, 1, 1, 2, 2, 2, 3, 4, 5, 6], + [1, 2, 0, 3, 4, 0, 5, 6, 1, 1, 2, 2], + ] diff --git a/tests/graph/datasets/motif_generator/test_custom_motif.py b/tests/graph/datasets/motif_generator/test_custom_motif.py new file mode 100644 index 000000000..325714a6e --- /dev/null +++ b/tests/graph/datasets/motif_generator/test_custom_motif.py @@ -0,0 +1,37 @@ +import pytest +import mindspore as ms +from mindscience.sharker.data import Graph +from mindscience.sharker.datasets.motif_generator import CustomMotif +from mindscience.sharker.testing import withPackage + + +def test_custom_motif_pyg_data(): + structure = Graph( + num_nodes=3, + edge_index=ms.Tensor([[0, 1, 2, 1, 2, 0], [1, 2, 0, 0, 1, 2]]), + ) + + motif_generator = CustomMotif(structure) + assert str(motif_generator) == 'CustomMotif()' + + assert structure == motif_generator() + + +@withPackage('networkx') +def test_custom_motif_networkx(): + import networkx as nx + + structure = nx.gnm_random_graph(5, 10, seed=2000) + + motif_generator = CustomMotif(structure) + assert str(motif_generator) == 'CustomMotif()' + + out = motif_generator() + assert len(out) == 2 + assert out.num_nodes == 5 + assert out.num_edges == 20 + + +def test_custom_motif_unknown(): + with pytest.raises(ValueError, match="motif structure of type"): + CustomMotif(structure='unknown') diff --git a/tests/graph/datasets/motif_generator/test_cycle_motif.py b/tests/graph/datasets/motif_generator/test_cycle_motif.py new file mode 100644 index 000000000..af315547f --- /dev/null +++ b/tests/graph/datasets/motif_generator/test_cycle_motif.py @@ -0,0 +1,15 @@ +from mindscience.sharker.datasets.motif_generator import CycleMotif + + +def test_cycle_motif(): + motif_generator = CycleMotif(5) + assert str(motif_generator) == 'CycleMotif(5)' + + motif = motif_generator() + assert len(motif) == 2 + assert motif.num_nodes == 5 + assert motif.num_edges == 10 + assert motif.edge_index.tolist() == [ + [0, 0, 1, 1, 2, 2, 3, 3, 4, 4], + [1, 4, 0, 2, 1, 3, 2, 4, 0, 3], + ] diff --git a/tests/graph/datasets/motif_generator/test_grid_motif.py b/tests/graph/datasets/motif_generator/test_grid_motif.py new file mode 100644 index 000000000..df2c24625 --- /dev/null +++ b/tests/graph/datasets/motif_generator/test_grid_motif.py @@ -0,0 +1,17 @@ +from mindscience.sharker.datasets.motif_generator import GridMotif + + +def test_grid_motif(): + motif_generator = GridMotif() + assert str(motif_generator) == 'GridMotif()' + + motif = motif_generator() + assert len(motif) == 3 + assert motif.num_nodes == 9 + assert motif.num_edges == 24 + assert motif.edge_index.shape == (2, 24) + assert motif.edge_index.min() == 0 + assert motif.edge_index.max() == 8 + assert motif.y.shape == (9, ) + assert motif.y.min() == 0 + assert motif.y.max() == 2 diff --git a/tests/graph/datasets/motif_generator/test_house_motif.py b/tests/graph/datasets/motif_generator/test_house_motif.py new file mode 100644 index 000000000..e7a90f580 --- /dev/null +++ b/tests/graph/datasets/motif_generator/test_house_motif.py @@ -0,0 +1,12 @@ +from mindscience.sharker.datasets.motif_generator import HouseMotif + + +def test_house_motif(): + motif_generator = HouseMotif() + assert str(motif_generator) == 'HouseMotif()' + + motif = motif_generator() + assert len(motif) == 3 + assert motif.num_nodes == 5 + assert motif.num_edges == 12 + assert motif.y.min() == 0 and motif.y.max() == 2 diff --git a/tests/graph/datasets/test_ba_shapes.py b/tests/graph/datasets/test_ba_shapes.py new file mode 100644 index 000000000..09de88908 --- /dev/null +++ b/tests/graph/datasets/test_ba_shapes.py @@ -0,0 +1,18 @@ +import pytest + + +def test_ba_shapes(get_dataset): + dataset = get_dataset(name='BAShapes') + + assert str(dataset) == 'BAShapes()' + assert len(dataset) == 1 + assert dataset.num_features == 10 + assert dataset.num_classes == 4 + + data = dataset[0] + assert len(data) == 5 + assert data.edge_index.shape[1] >= 1120 + assert data.x.shape == (700, 10) + assert data.y.shape == (700, ) + assert data.expl_mask.sum() == 60 + assert data.edge_label.sum() == 960 diff --git a/tests/graph/datasets/test_bzr.py b/tests/graph/datasets/test_bzr.py new file mode 100644 index 000000000..56945db4d --- /dev/null +++ b/tests/graph/datasets/test_bzr.py @@ -0,0 +1,23 @@ +from mindscience.sharker.testing import onlyFullTest, onlyOnline + + +@onlyOnline +# @onlyFullTest +def test_bzr(get_dataset): + dataset = get_dataset(name='BZR') + assert len(dataset) == 405 + assert dataset.num_features == 53 + assert dataset.num_node_labels == 53 + assert dataset.num_node_attributes == 3 + assert dataset.num_classes == 2 + assert str(dataset) == 'BZR(405)' + assert len(dataset[0]) == 3 + + +@onlyOnline +# @onlyFullTest +def test_bzr_with_node_attr(get_dataset): + dataset = get_dataset(name='BZR', use_node_attr=True) + assert dataset.num_features == 56 + assert dataset.num_node_labels == 53 + assert dataset.num_node_attributes == 3 diff --git a/tests/graph/datasets/test_elliptic.py b/tests/graph/datasets/test_elliptic.py new file mode 100644 index 000000000..2e915e7ae --- /dev/null +++ b/tests/graph/datasets/test_elliptic.py @@ -0,0 +1,28 @@ +from mindscience.sharker.testing import onlyFullTest, onlyOnline + + +@onlyOnline +# @onlyFullTest +def test_elliptic_bitcoin_dataset(get_dataset): + dataset = get_dataset(name='EllipticBitcoinDataset') + assert str(dataset) == 'EllipticBitcoinDataset()' + assert len(dataset) == 1 + assert dataset.num_features == 165 + assert dataset.num_classes == 2 + + data = dataset[0] + assert len(data) == 5 + assert data.x.shape == (203769, 165) + assert data.edge_index.shape == (2, 234355) + assert data.y.shape == (203769, ) + + assert data.train_mask.shape == (203769, ) + assert data.train_mask.sum() > 0 + assert data.test_mask.shape == (203769, ) + assert data.test_mask.sum() > 0 + assert data.train_mask.sum() + data.test_mask.sum() == 4545 + 42019 + assert data.y[data.train_mask].sum() == 3462 + assert data.y[data.test_mask].sum() == 1083 + assert data.y[data.train_mask].sum() + data.y[data.test_mask].sum() == 4545 + assert data.y[data.test_mask | data.train_mask].min() == 0 + assert data.y[data.test_mask | data.train_mask].max() == 1 diff --git a/tests/graph/datasets/test_enzymes.py b/tests/graph/datasets/test_enzymes.py new file mode 100644 index 000000000..bf2efc8d6 --- /dev/null +++ b/tests/graph/datasets/test_enzymes.py @@ -0,0 +1,71 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.loader import DataListLoader, DataLoader, DenseDataLoader +from mindscience.sharker.testing import onlyOnline +from mindscience.sharker.transforms import ToDense + + +@onlyOnline +def test_enzymes(get_dataset): + dataset = get_dataset(name='ENZYMES') + assert len(dataset) == 600 + assert dataset.num_features == 3 + assert dataset.num_classes == 6 + assert str(dataset) == 'ENZYMES(600)' + + assert len(dataset[0]) == 3 + assert len(dataset.shuffle()) == 600 + assert len(dataset.shuffle(return_perm=True)) == 2 + assert len(dataset[:100]) == 100 + assert len(dataset[0.1:0.2]) == 60 + assert len(dataset[ops.arange(100, dtype=ms.int64)]) == 100 + mask = ops.zeros(600).bool() + mask[:100] = 1 + assert len(dataset[mask]) == 100 + + loader = DataLoader(dataset, batch_size=len(dataset)) + for batch in loader: + assert batch.num_graphs == len(batch) == 600 + + avg_num_nodes = batch.num_nodes / batch.num_graphs + assert pytest.approx(avg_num_nodes, abs=1e-2) == 32.63 + + avg_num_edges = batch.num_edges / (2 * batch.num_graphs) + assert pytest.approx(avg_num_edges, abs=1e-2) == 62.14 + + assert list(batch.x.shape) == [batch.num_nodes, 3] + assert list(batch.y.shape) == [batch.num_graphs] + assert batch.y.max() + 1 == 6 + assert list(batch.batch.shape) == [batch.num_nodes] + assert batch.ptr.numel() == batch.num_graphs + 1 + + assert batch.has_isolated_nodes() + assert not batch.has_self_loops() + assert batch.is_undirected() + + loader = DataListLoader(dataset, batch_size=len(dataset)) + for data_list in loader: + assert len(data_list) == 600 + + dataset.transform = ToDense(num_nodes=126) + loader = DenseDataLoader(dataset, batch_size=len(dataset)) + for data in loader: + assert list(data.x.shape) == [600, 126, 3] + assert list(data.adj.shape) == [600, 126, 126] + assert list(data.mask.shape) == [600, 126] + assert list(data.y.shape) == [600, 1] + + +@onlyOnline +def test_enzymes_with_node_attr(get_dataset): + dataset = get_dataset(name='ENZYMES', use_node_attr=True) + assert dataset.num_node_features == 21 + assert dataset.num_features == 21 + assert dataset.num_edge_features == 0 + + +@onlyOnline +def test_cleaned_enzymes(get_dataset): + dataset = get_dataset(name='ENZYMES', cleaned=True) + assert len(dataset) == 595 diff --git a/tests/graph/datasets/test_explainer_dataset.py b/tests/graph/datasets/test_explainer_dataset.py new file mode 100644 index 000000000..978da7585 --- /dev/null +++ b/tests/graph/datasets/test_explainer_dataset.py @@ -0,0 +1,47 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import seed_everything +from mindscience.sharker.datasets import ExplainerDataset +from mindscience.sharker.datasets.graph_generator import BAGraph +from mindscience.sharker.datasets.motif_generator import HouseMotif + + +@pytest.mark.parametrize('graph_generator', [ + pytest.param(BAGraph(num_nodes=80, num_edges=5), id='BAGraph'), +]) +@pytest.mark.parametrize('motif_generator', [ + pytest.param(HouseMotif(), id='HouseMotif'), + 'house', +]) +def test_explainer_dataset_ba_house(graph_generator, motif_generator): + dataset = ExplainerDataset(graph_generator, motif_generator, num_motifs=2) + assert str(dataset) == ('ExplainerDataset(1, graph_generator=' + 'BAGraph(num_nodes=80, num_edges=5), ' + 'motif_generator=HouseMotif(), num_motifs=2)') + assert len(dataset) == 1 + + data = dataset[0] + assert len(data) == 4 + assert data.num_nodes == 80 + (2 * 5) + assert data.edge_index.min() >= 0 + assert data.edge_index.max() < data.num_nodes + assert data.y.min() == 0 and data.y.max() == 3 + assert data.node_mask.shape == (data.num_nodes, ) + assert data.edge_mask.shape == (data.num_edges, ) + assert data.node_mask.min() == 0 and data.node_mask.max() == 1 + assert data.node_mask.sum() == 2 * 5 + assert data.edge_mask.min() == 0 and data.edge_mask.max() == 1 + assert data.edge_mask.sum() == 2 * 12 + + +def test_explainer_dataset_reproducibility(): + seed_everything(12345) + data1 = ExplainerDataset(BAGraph(num_nodes=80, num_edges=5), HouseMotif(), + num_motifs=2)[0] + + seed_everything(12345) + data2 = ExplainerDataset(BAGraph(num_nodes=80, num_edges=5), HouseMotif(), + num_motifs=2)[0] + + assert ops.equal(data1.edge_index, data2.edge_index).all() diff --git a/tests/graph/datasets/test_fake.py b/tests/graph/datasets/test_fake.py new file mode 100644 index 000000000..a57087e43 --- /dev/null +++ b/tests/graph/datasets/test_fake.py @@ -0,0 +1,84 @@ +import pytest + +from mindscience.sharker.datasets import FakeDataset, FakeHeteroDataset + + +@pytest.mark.parametrize('num_graphs', [1, 10]) +@pytest.mark.parametrize('edge_dim', [0, 1, 4]) +@pytest.mark.parametrize('task', ['node', 'graph', 'auto']) +def test_fake_dataset(num_graphs, edge_dim, task): + dataset = FakeDataset(num_graphs, edge_dim=edge_dim, task=task, + global_features=3) + + if num_graphs > 1: + assert str(dataset) == f'FakeDataset({num_graphs})' + else: + assert str(dataset) == 'FakeDataset()' + + assert len(dataset) == num_graphs + + data = dataset[0] + + assert data.num_features == 64 + + if edge_dim == 0: + assert len(data) == 4 + elif edge_dim == 1: + assert len(data) == 5 + assert data.edge_weight.shape == (data.num_edges, ) + assert data.edge_weight.min() >= 0 and data.edge_weight.max() < 1 + else: + assert len(data) == 5 + assert data.edge_attr.shape == (data.num_edges, edge_dim) + assert data.edge_attr.min() >= 0 and data.edge_attr.max() < 1 + + assert data.y.min() >= 0 and data.y.max() < 10 + if task == 'node' or (task == 'auto' and num_graphs == 1): + assert data.y.shape == (data.num_nodes, ) + else: + assert data.y.shape == (1, ) + + assert data.global_features.shape == (3, ) + + +@pytest.mark.parametrize('num_graphs', [1, 10]) +@pytest.mark.parametrize('edge_dim', [0, 1, 4]) +@pytest.mark.parametrize('task', ['node', 'graph', 'auto']) +def test_fake_hetero_dataset(num_graphs, edge_dim, task): + dataset = FakeHeteroDataset(num_graphs, edge_dim=edge_dim, task=task, + global_features=3) + + if num_graphs > 1: + assert str(dataset) == f'FakeHeteroDataset({num_graphs})' + else: + assert str(dataset) == 'FakeHeteroDataset()' + + assert len(dataset) == num_graphs + + data = dataset[0] + + for store in data.node_stores: + assert store.num_features > 0 + + if task == 'node' or (task == 'auto' and num_graphs == 1): + if store._key == 'v0': + assert store.y.min() >= 0 and store.y.max() < 10 + assert store.y.shape == (store.num_nodes, ) + + for store in data.edge_stores: + if edge_dim == 0: + assert len(data) == 4 + elif edge_dim == 1: + assert len(data) == 5 + assert store.edge_weight.shape == (store.num_edges, ) + assert store.edge_weight.min() >= 0 and store.edge_weight.max() < 1 + else: + assert len(data) == 5 + assert store.edge_attr.shape == (store.num_edges, edge_dim) + assert store.edge_attr.min() >= 0 and store.edge_attr.max() < 1 + + if task == 'graph' or (task == 'auto' and num_graphs > 1): + assert data.y.min() >= 0 and data.y.max() < 10 + assert data.y.shape == (1, ) + + assert data.global_features.shape == (3, ) diff --git a/tests/graph/datasets/test_imdb_binary.py b/tests/graph/datasets/test_imdb_binary.py new file mode 100644 index 000000000..cee55e266 --- /dev/null +++ b/tests/graph/datasets/test_imdb_binary.py @@ -0,0 +1,17 @@ +from mindscience.sharker.testing import onlyFullTest, onlyOnline + + +@onlyOnline +# @onlyFullTest +def test_imdb_binary(get_dataset): + dataset = get_dataset(name='IMDB-BINARY') + assert len(dataset) == 1000 + assert dataset.num_features == 0 + assert dataset.num_classes == 2 + assert str(dataset) == 'IMDB-BINARY(1000)' + + data = dataset[0] + assert len(data) == 3 + assert data.edge_index.shape == (2, 146) + assert data.y.shape == (1, ) + assert data.num_nodes == 20 diff --git a/tests/graph/datasets/test_infection_dataset.py b/tests/graph/datasets/test_infection_dataset.py new file mode 100644 index 000000000..d91cc840d --- /dev/null +++ b/tests/graph/datasets/test_infection_dataset.py @@ -0,0 +1,54 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import seed_everything +from mindscience.sharker.data import Graph +from mindscience.sharker.datasets import InfectionDataset +from mindscience.sharker.datasets.graph_generator import ERGraph, GraphGenerator + + +class DummyGraph(GraphGenerator): + def __call__(self) -> Graph: + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8], + ]) + return Graph(num_nodes=10, edge_index=edge_index) + + +def test_infection_dataset(): + seed_everything(12345) + graph_generator = DummyGraph() + dataset = InfectionDataset(graph_generator, num_infected_nodes=2, + max_path_length=2) + assert str(dataset) == ('InfectionDataset(1, ' + 'graph_generator=DummyGraph(), ' + 'num_infected_nodes=2, ' + 'max_path_length=2)') + assert len(dataset) == 1 + + data = dataset[0] + assert len(data) == 4 + assert data.x.shape == (10, 2) + assert data.x[:, 0].sum() == 8 and data.x[:, 1].sum() == 2 + assert ops.equal(data.edge_index, graph_generator().edge_index).all() + assert data.y.shape == (10, ) + + # With `seed=12345`, node 0 and node 7 will be infected: + assert data.x[0].tolist() == [0, 1] # First infected node. + assert data.x[7].tolist() == [1, 0] # Second infected node. + assert data.y.tolist() == [0, 0, 1, 2, 3, 3, 3, 3, 3, 3] + assert data.edge_mask.tolist() == [1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + + +def test_infection_dataset_reproducibility(): + graph_generator = ERGraph(num_nodes=500, edge_prob=0.004) + + seed_everything(12345) + dataset1 = InfectionDataset(graph_generator, num_infected_nodes=50, + max_path_length=5) + + seed_everything(12345) + dataset2 = InfectionDataset(graph_generator, num_infected_nodes=50, + max_path_length=5) + + assert ops.equal(dataset1[0].edge_mask, dataset2[0].edge_mask).all() diff --git a/tests/graph/datasets/test_karate.py b/tests/graph/datasets/test_karate.py new file mode 100644 index 000000000..3cc8a1ea3 --- /dev/null +++ b/tests/graph/datasets/test_karate.py @@ -0,0 +1,13 @@ +def test_karate(get_dataset): + dataset = get_dataset(name='KarateClub') + assert str(dataset) == 'KarateClub()' + assert len(dataset) == 1 + assert dataset.num_features == 34 + assert dataset.num_classes == 4 + + assert len(dataset[0]) == 4 + assert dataset[0].edge_index.shape == (2, 156) + assert dataset[0].x.shape == (34, 34) + assert dataset[0].y.shape == (34, ) + assert dataset[0].train_mask.shape == (34, ) + assert dataset[0].train_mask.sum().item() == 4 diff --git a/tests/graph/datasets/test_mutag.py b/tests/graph/datasets/test_mutag.py new file mode 100644 index 000000000..2479c75ab --- /dev/null +++ b/tests/graph/datasets/test_mutag.py @@ -0,0 +1,19 @@ +from mindscience.sharker.testing import onlyOnline + + +@onlyOnline +def test_mutag(get_dataset): + dataset = get_dataset(name='MUTAG') + assert len(dataset) == 188 + assert dataset.num_features == 7 + assert dataset.num_classes == 2 + assert str(dataset) == 'MUTAG(188)' + + assert len(dataset[0]) == 4 + assert dataset[0].edge_attr.shape[1] == 4 + + +@onlyOnline +def test_mutag_with_node_attr(get_dataset): + dataset = get_dataset(name='MUTAG', use_node_attr=True) + assert dataset.num_features == 7 diff --git a/tests/graph/datasets/test_planetoid.py b/tests/graph/datasets/test_planetoid.py new file mode 100644 index 000000000..dff420950 --- /dev/null +++ b/tests/graph/datasets/test_planetoid.py @@ -0,0 +1,62 @@ +from mindscience.sharker.loader import DataLoader +from mindscience.sharker.testing import onlyOnline + + +@onlyOnline +def test_citeseer(get_dataset): + dataset = get_dataset(name='CiteSeer') + loader = DataLoader(dataset, batch_size=len(dataset)) + + assert len(dataset) == 1 + assert str(dataset) == 'CiteSeer()' + + for batch in loader: + assert batch.num_graphs == len(batch) == 1 + assert batch.num_nodes == 3327 + assert batch.num_edges / 2 == 4552 + + assert list(batch.x.shape) == [batch.num_nodes, 3703] + assert list(batch.y.shape) == [batch.num_nodes] + assert batch.y.max() + 1 == 6 + assert batch.train_mask.sum() == 6 * 20 + assert batch.val_mask.sum() == 500 + assert batch.test_mask.sum() == 1000 + assert (batch.train_mask & batch.val_mask & batch.test_mask).sum() == 0 + assert list(batch.batch.shape) == [batch.num_nodes] + assert batch.ptr.tolist() == [0, batch.num_nodes] + + assert batch.has_isolated_nodes() + assert not batch.has_self_loops() + assert batch.is_undirected() + + +@onlyOnline +def test_citeseer_with_full_split(get_dataset): + dataset = get_dataset(name='CiteSeer', split='full') + data = dataset[0] + assert data.val_mask.sum() == 500 + assert data.test_mask.sum() == 1000 + assert data.train_mask.sum() == data.num_nodes - 1500 + assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0 + + +@onlyOnline +def test_citeseer_with_random_split(get_dataset): + dataset = get_dataset( + name='CiteSeer', + split='random', + num_train_per_class=11, + num_val=29, + num_test=41, + ) + data = dataset[0] + # from mindscience.sharker import EdgeIndex + # assert isinstance(data.edge_index, EdgeIndex) + # assert data.edge_index.sparse_shape == (data.num_nodes, data.num_nodes) + # assert data.edge_index.is_undirected + # assert data.edge_index.is_sorted_by_col + + assert data.train_mask.sum() == dataset.num_classes * 11 + assert data.val_mask.sum() == 29 + assert data.test_mask.sum() == 41 + assert (data.train_mask & data.val_mask & data.test_mask).sum() == 0 diff --git a/tests/graph/datasets/test_snap_dataset.py b/tests/graph/datasets/test_snap_dataset.py new file mode 100644 index 000000000..f35a42a09 --- /dev/null +++ b/tests/graph/datasets/test_snap_dataset.py @@ -0,0 +1,25 @@ +from mindscience.sharker.testing import onlyFullTest, onlyOnline + + +@onlyOnline +# @onlyFullTest +def test_ego_facebook_snap_dataset(get_dataset): + dataset = get_dataset(name='ego-facebook') + assert str(dataset) == 'SNAP-ego-facebook(10)' + assert len(dataset) == 10 + + +@onlyOnline +# @onlyFullTest +def test_soc_slashdot_snap_dataset(get_dataset): + dataset = get_dataset(name='soc-Slashdot0811') + assert str(dataset) == 'SNAP-soc-slashdot0811(1)' + assert len(dataset) == 1 + + +@onlyOnline +# @onlyFullTest +def test_wiki_vote_snap_dataset(get_dataset): + dataset = get_dataset(name='wiki-vote') + assert str(dataset) == 'SNAP-wiki-vote(1)' + assert len(dataset) == 1 diff --git a/tests/graph/datasets/test_suite_sparse.py b/tests/graph/datasets/test_suite_sparse.py new file mode 100644 index 000000000..bf65bed76 --- /dev/null +++ b/tests/graph/datasets/test_suite_sparse.py @@ -0,0 +1,19 @@ +from mindscience.sharker.testing import onlyFullTest, onlyOnline + + +@onlyOnline +# @onlyFullTest +def test_suite_sparse_dataset(get_dataset): + dataset = get_dataset(group='DIMACS10', name='citationCiteseer') + assert str(dataset) == ('SuiteSparseMatrixCollection(' + 'group=DIMACS10, name=citationCiteseer)') + assert len(dataset) == 1 + + +@onlyOnline +# @onlyFullTest +def test_illc1850_suite_sparse_dataset(get_dataset): + dataset = get_dataset(group='HB', name='illc1850') + assert str(dataset) == ('SuiteSparseMatrixCollection(' + 'group=HB, name=illc1850)') + assert len(dataset) == 1 diff --git a/tests/graph/explain/algorithm/test_attention_explainer.py b/tests/graph/explain/algorithm/test_attention_explainer.py new file mode 100644 index 000000000..529e117b8 --- /dev/null +++ b/tests/graph/explain/algorithm/test_attention_explainer.py @@ -0,0 +1,84 @@ +import pytest +import mindspore as ms +from mindspore import nn, ops +from mindscience.sharker.explain import AttentionExplainer, Explainer +from mindscience.sharker.explain.config import ExplanationType, MaskType +from mindscience.sharker.nn import AttentiveFP, GATConv, GATv2Conv, TransformerConv + + +class AttentionGNN(nn.Cell): + def __init__(self): + super().__init__() + self.conv1 = GATConv(3, 16, heads=4) + self.conv2 = GATv2Conv(4 * 16, 16, heads=2) + self.conv3 = TransformerConv(2 * 16, 7, heads=1) + + def construct(self, x, edge_index): + x = ops.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + x = self.conv3(x, edge_index) + return x + + +x = ops.randn(8, 3) +edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], +]) +edge_attr = ops.randn(edge_index.shape[1], 5) +batch = ms.Tensor([0, 0, 0, 1, 1, 2, 2, 2]) + + +@pytest.mark.parametrize('index', [None, 2, ops.arange(3)]) +def test_attention_explainer(index, check_explanation): + explainer = Explainer( + model=AttentionGNN(), + algorithm=AttentionExplainer(), + explanation_type='model', + edge_mask_type='object', + model_config=dict( + mode='multiclass_classification', + task_level='node', + return_type='raw', + ), + ) + + explanation = explainer(x, edge_index, index=index) + check_explanation(explanation, None, explainer.edge_mask_type) + + +@pytest.mark.parametrize('explanation_type', [e for e in ExplanationType]) +@pytest.mark.parametrize('node_mask_type', [m for m in MaskType]) +def test_attention_explainer_supports(explanation_type, node_mask_type): + with pytest.raises(ValueError, match="not support the given explanation"): + Explainer( + model=AttentionGNN(), + algorithm=AttentionExplainer(), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type='object', + model_config=dict( + mode='multiclass_classification', + task_level='node', + return_type='raw', + ), + ) + + +def test_attention_explainer_attentive_fp(check_explanation): + model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2) + + explainer = Explainer( + model=model, + algorithm=AttentionExplainer(), + explanation_type='model', + edge_mask_type='object', + model_config=dict( + mode='binary_classification', + task_level='node', + return_type='raw', + ), + ) + + explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch) + check_explanation(explanation, None, explainer.edge_mask_type) diff --git a/tests/graph/explain/algorithm/test_captum.py b/tests/graph/explain/algorithm/test_captum.py new file mode 100644 index 000000000..4e6996ef9 --- /dev/null +++ b/tests/graph/explain/algorithm/test_captum.py @@ -0,0 +1,201 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.explain.algorithm.captum import to_captum_input, to_captum_model +from mindscience.sharker.nn import GAT, GCN, SAGEConv +from mindscience.sharker.nn.conv import MessagePassing +from mindscience.sharker.testing import withPackage + +x = ops.randn(8, 3) +edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6]]) + +GCN = GCN(3, 16, 2, 7, dropout=0.5) +GAT = GAT(3, 16, 2, 7, heads=2, concat=False) +mask_types = ['edge', 'node_and_edge', 'node'] +methods = [ + 'Saliency', + 'InputXGradient', + 'Deconvolution', + 'FeatureAblation', + 'ShapleyValueSampling', + 'IntegratedGradients', + 'GradientShap', + 'Occlusion', + 'GuidedBackprop', + 'KernelShap', + 'Lime', +] + + +@pytest.mark.parametrize('mask_type', mask_types) +@pytest.mark.parametrize('model', [GCN, GAT]) +@pytest.mark.parametrize('output_idx', [None, 1]) +def test_to_captum(model, mask_type, output_idx): + captum_model = to_captum_model(model, mask_type=mask_type, + output_idx=output_idx) + pre_out = model(x, edge_index) + if mask_type == 'node': + mask = x * 0.0 + out = captum_model(mask.unsqueeze(0), edge_index) + elif mask_type == 'edge': + mask = ops.ones(edge_index.shape[1], dtype=ms.float32) * 0.5 + out = captum_model(mask.unsqueeze(0), x, edge_index) + elif mask_type == 'node_and_edge': + node_mask = x * 0.0 + edge_mask = ops.ones(edge_index.shape[1], dtype=ms.float32) * 0.5 + out = captum_model(node_mask.unsqueeze(0), edge_mask.unsqueeze(0), + edge_index) + + if output_idx is not None: + assert out.shape == (1, 7) + assert ops.any(out != pre_out[[output_idx]]) + else: + assert out.shape == (8, 7) + assert ops.any(out != pre_out) + + +@withPackage('captum') +@pytest.mark.parametrize('mask_type', mask_types) +@pytest.mark.parametrize('method', methods) +def test_captum_attribution_methods(mask_type, method): + from captum import attr # noqa + + captum_model = to_captum_model(GCN, mask_type, 0) + explainer = getattr(attr, method)(captum_model) + data = Graph(x, edge_index) + input, additional_forward_args = to_captum_input(data.x, data.edge_index, + mask_type) + if mask_type == 'node': + sliding_window_shapes = (3, 3) + elif mask_type == 'edge': + sliding_window_shapes = (5, ) + elif mask_type == 'node_and_edge': + sliding_window_shapes = ((3, 3), (5, )) + + if method == 'IntegratedGradients': + attributions, delta = explainer.attribute( + input, target=0, internal_batch_size=1, + additional_forward_args=additional_forward_args, + return_convergence_delta=True) + elif method == 'GradientShap': + attributions, delta = explainer.attribute( + input, target=0, return_convergence_delta=True, baselines=input, + n_samples=1, additional_forward_args=additional_forward_args) + elif method == 'DeepLiftShap' or method == 'DeepLift': + attributions, delta = explainer.attribute( + input, target=0, return_convergence_delta=True, baselines=input, + additional_forward_args=additional_forward_args) + elif method == 'Occlusion': + attributions = explainer.attribute( + input, target=0, sliding_window_shapes=sliding_window_shapes, + additional_forward_args=additional_forward_args) + else: + attributions = explainer.attribute( + input, target=0, additional_forward_args=additional_forward_args) + if mask_type == 'node': + assert attributions[0].shape == (1, 8, 3) + elif mask_type == 'edge': + assert attributions[0].shape == (1, 14) + else: + assert attributions[0].shape == (1, 8, 3) + assert attributions[1].shape == (1, 14) + + +def test_custom_explain_message(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + + conv = SAGEConv(8, 32) + + def explain_message(self, inputs, x_i, x_j): + assert isinstance(self, SAGEConv) + assert inputs.shape == (6, 8) + assert inputs.shape == x_i.shape == x_j.shape + assert ops.isclose(inputs, x_j).all() + self.x_i = x_i + self.x_j = x_j + return inputs + + conv.explain_message = explain_message.__get__(conv, MessagePassing) + conv.explain = True + + conv(x, edge_index) + + assert ops.isclose(conv.x_i, x[edge_index[1]]).all() + assert ops.isclose(conv.x_j, x[edge_index[0]]).all() + + +@withPackage('captum') +@pytest.mark.parametrize('mask_type', ['node', 'edge', 'node_and_edge']) +def test_to_captum_input(mask_type): + num_nodes = x.shape[0] + num_node_feats = x.shape[1] + num_edges = edge_index.shape[1] + + # Check for Data: + data = Graph(x, edge_index) + args = 'test_args' + inputs, additional_forward_args = to_captum_input(data.x, data.edge_index, + mask_type, args) + if mask_type == 'node': + assert len(inputs) == 1 + assert inputs[0].shape == (1, num_nodes, num_node_feats) + assert len(additional_forward_args) == 2 + assert ops.isclose(additional_forward_args[0], edge_index).all() + elif mask_type == 'edge': + assert len(inputs) == 1 + assert inputs[0].shape == (1, num_edges) + assert inputs[0].sum() == num_edges + assert len(additional_forward_args) == 3 + assert ops.isclose(additional_forward_args[0], x).all() + assert ops.isclose(additional_forward_args[1], edge_index).all() + else: + assert len(inputs) == 2 + assert inputs[0].shape == (1, num_nodes, num_node_feats) + assert inputs[1].shape == (1, num_edges) + assert inputs[1].sum() == num_edges + assert len(additional_forward_args) == 2 + assert ops.isclose(additional_forward_args[0], edge_index).all() + + # Check for HeteroGraph: + data = HeteroGraph() + x2 = ops.rand(8, 3) + data['paper'].x = x + data['author'].x = x2 + data['paper', 'to', 'author'].edge_index = edge_index + data['author', 'to', 'paper'].edge_index = edge_index.flip([0]) + inputs, additional_forward_args = to_captum_input(data.x_dict, + data.edge_index_dict, + mask_type, args) + if mask_type == 'node': + assert len(inputs) == 2 + assert inputs[0].shape == (1, num_nodes, num_node_feats) + assert inputs[1].shape == (1, num_nodes, num_node_feats) + assert len(additional_forward_args) == 2 + for key in data.edge_types: + ops.isclose(additional_forward_args[0][key], + data[key].edge_index).all() + elif mask_type == 'edge': + assert len(inputs) == 2 + assert inputs[0].shape == (1, num_edges) + assert inputs[1].shape == (1, num_edges) + assert inputs[1].sum() == inputs[0].sum() == num_edges + assert len(additional_forward_args) == 3 + for key in data.node_types: + ops.isclose(additional_forward_args[0][key], data[key].x).all() + for key in data.edge_types: + ops.isclose(additional_forward_args[1][key], + data[key].edge_index).all() + else: + assert len(inputs) == 4 + assert inputs[0].shape == (1, num_nodes, num_node_feats) + assert inputs[1].shape == (1, num_nodes, num_node_feats) + assert inputs[2].shape == (1, num_edges) + assert inputs[3].shape == (1, num_edges) + assert inputs[3].sum() == inputs[2].sum() == num_edges + assert len(additional_forward_args) == 2 + for key in data.edge_types: + ops.isclose(additional_forward_args[0][key], + data[key].edge_index).all() diff --git a/tests/graph/explain/algorithm/test_captum_explainer.py b/tests/graph/explain/algorithm/test_captum_explainer.py new file mode 100644 index 000000000..d444d8f97 --- /dev/null +++ b/tests/graph/explain/algorithm/test_captum_explainer.py @@ -0,0 +1,255 @@ +from typing import Optional + +import pytest +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker.explain import Explainer, Explanation +from mindscience.sharker.explain.algorithm import CaptumExplainer +from mindscience.sharker.explain.config import ( + MaskType, + ModelConfig, + ModelMode, + ModelReturnType, + ModelTaskLevel, +) +from mindscience.sharker.nn import GCNConv, global_add_pool +from mindscience.sharker.testing import withPackage + +methods = [ + 'Saliency', + 'InputXGradient', + 'Deconvolution', + 'ShapleyValueSampling', + 'IntegratedGradients', + 'GuidedBackprop', +] + +unsupported_methods = [ + 'FeatureAblation', + 'Occlusion', + 'DeepLift', + 'DeepLiftShap', + 'GradientShap', + 'KernelShap', + 'Lime', +] + + +class GCN(nn.Cell): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + if model_config.mode == ModelMode.multiclass_classification: + out_channels = 7 + else: + out_channels = 1 + + self.conv1 = GCNConv(3, 16) + self.conv2 = GCNConv(16, out_channels) + + def construct(self, x, edge_index, batch=None, edge_label_index=None): + x = ops.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + + if self.model_config.task_level == ModelTaskLevel.graph: + x = global_add_pool(x, batch) + elif self.model_config.task_level == ModelTaskLevel.edge: + assert edge_label_index is not None + x = x[edge_label_index[0]] * x[edge_label_index[1]] + + if self.model_config.mode == ModelMode.binary_classification: + if self.model_config.return_type == ModelReturnType.probs: + x = x.sigmoid() + elif self.model_config.mode == ModelMode.multiclass_classification: + if self.model_config.return_type == ModelReturnType.probs: + x = ops.softmax(x, -1) + elif self.model_config.return_type == ModelReturnType.log_probs: + x = ops.log_softmax(x, -1) + + return x + + +node_mask_types = [MaskType.attributes, None] +edge_mask_types = [MaskType.object, None] +task_levels = [ModelTaskLevel.node, ModelTaskLevel.edge, ModelTaskLevel.graph] +indices = [1, ops.arange(2)] + + +def check_explanation( + explanation: Explanation, + node_mask_type: Optional[MaskType], + edge_mask_type: Optional[MaskType], +): + if node_mask_type == MaskType.attributes: + assert explanation.node_mask.shape == explanation.x.shape + elif node_mask_type is None: + assert 'node_mask' not in explanation + + if edge_mask_type == MaskType.object: + assert explanation.edge_mask.shape == (explanation.num_edges, ) + elif edge_mask_type is None: + assert 'edge_mask' not in explanation + + +@withPackage('captum') +@pytest.mark.parametrize('method', unsupported_methods) +def test_unsupported_methods(method): + model_config = ModelConfig(mode='regression', task_level='node') + + with pytest.raises(ValueError, match="does not support attribution"): + Explainer( + GCN(model_config), + algorithm=CaptumExplainer(method), + explanation_type='model', + edge_mask_type='object', + node_mask_type='attributes', + model_config=model_config, + ) + + +@withPackage('captum') +@pytest.mark.parametrize('method', ['IntegratedGradients']) +@pytest.mark.parametrize('node_mask_type', node_mask_types) +@pytest.mark.parametrize('edge_mask_type', edge_mask_types) +@pytest.mark.parametrize('task_level', task_levels) +@pytest.mark.parametrize('index', indices) +def test_captum_explainer_binary_classification( + method, + data, + node_mask_type, + edge_mask_type, + task_level, + index, +): + if node_mask_type is None and edge_mask_type is None: + return + + batch = ms.Tensor([0, 0, 1, 1]) + edge_label_index = ms.Tensor([[0, 1, 2], [2, 3, 1]]) + + model_config = ModelConfig( + mode='binary_classification', + task_level=task_level, + return_type='probs', + ) + + explainer = Explainer( + GCN(model_config), + algorithm=CaptumExplainer(method), + explanation_type='model', + edge_mask_type=edge_mask_type, + node_mask_type=node_mask_type, + model_config=model_config, + ) + + explanation = explainer( + data.x, + data.edge_index, + index=index, + batch=batch, + edge_label_index=edge_label_index, + ) + check_explanation(explanation, node_mask_type, edge_mask_type) + + +@withPackage('captum') +@pytest.mark.parametrize('method', methods) +@pytest.mark.parametrize('node_mask_type', node_mask_types) +@pytest.mark.parametrize('edge_mask_type', edge_mask_types) +@pytest.mark.parametrize('task_level', task_levels) +@pytest.mark.parametrize('index', indices) +def test_captum_explainer_multiclass_classification( + method, + data, + node_mask_type, + edge_mask_type, + task_level, + index, +): + if node_mask_type is None and edge_mask_type is None: + return + + batch = ms.Tensor([0, 0, 1, 1]) + edge_label_index = ms.Tensor([[0, 1, 2], [2, 3, 1]]) + + model_config = ModelConfig( + mode='multiclass_classification', + task_level=task_level, + return_type='probs', + ) + + explainer = Explainer( + GCN(model_config), + algorithm=CaptumExplainer(method), + explanation_type='model', + edge_mask_type=edge_mask_type, + node_mask_type=node_mask_type, + model_config=model_config, + ) + + explanation = explainer( + data.x, + data.edge_index, + index=index, + batch=batch, + edge_label_index=edge_label_index, + ) + check_explanation(explanation, node_mask_type, edge_mask_type) + + +@withPackage('captum') +@pytest.mark.parametrize( + 'method', + [m for m in methods if m != 'ShapleyValueSampling'], +) +@pytest.mark.parametrize( + 'node_mask_type', + [nm for nm in node_mask_types if nm is not None], +) +@pytest.mark.parametrize( + 'edge_mask_type', + [em for em in edge_mask_types if em is not None], +) +@pytest.mark.parametrize('index', [1, ops.arange(2)]) +def test_captum_hetero_data(method, node_mask_type, edge_mask_type, index, + hetero_data, hetero_model): + + model_config = ModelConfig(mode='regression', task_level='node') + + explainer = Explainer( + hetero_model(hetero_data.metadata()), + algorithm=CaptumExplainer(method), + edge_mask_type=edge_mask_type, + node_mask_type=node_mask_type, + model_config=model_config, + explanation_type='model', + ) + + explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict, + index=index) + + explanation.validate(raise_on_error=True) + + +@withPackage('captum') +@pytest.mark.parametrize('node_mask_type', [ + MaskType.object, + MaskType.common_attributes, +]) +def test_captum_explainer_supports(node_mask_type): + model_config = ModelConfig( + mode='multiclass_classification', + task_level='node', + return_type='probs', + ) + + with pytest.raises(ValueError, match="not support the given explanation"): + Explainer( + GCN(model_config), + algorithm=CaptumExplainer('IntegratedGradients'), + edge_mask_type=MaskType.object, + node_mask_type=node_mask_type, + model_config=model_config, + explanation_type='model', + ) diff --git a/tests/graph/explain/algorithm/test_captum_hetero.py b/tests/graph/explain/algorithm/test_captum_hetero.py new file mode 100644 index 000000000..3af5a8739 --- /dev/null +++ b/tests/graph/explain/algorithm/test_captum_hetero.py @@ -0,0 +1,106 @@ +import pytest + +from mindscience.sharker.explain.algorithm.captum import ( + CaptumHeteroModel, + captum_output_to_dicts, + to_captum_input, + to_captum_model +) +from mindscience.sharker.testing import withPackage + +mask_types = [ + 'node', + 'edge', + 'node_and_edge', +] + +methods = [ + 'Saliency', + 'InputXGradient', + 'Deconvolution', + 'FeatureAblation', + 'ShapleyValueSampling', + 'IntegratedGradients', + 'GradientShap', + 'Occlusion', + 'GuidedBackprop', + 'KernelShap', + 'Lime', +] + + +@withPackage('captum') +@pytest.mark.parametrize('mask_type', mask_types) +@pytest.mark.parametrize('method', methods) +def test_captum_attribution_methods_hetero(mask_type, method, hetero_data, + hetero_model): + from captum import attr # noqa + data = hetero_data + metadata = data.metadata() + model = hetero_model(metadata) + captum_model = to_captum_model(model, mask_type, 0, metadata) + explainer = getattr(attr, method)(captum_model) + assert isinstance(captum_model, CaptumHeteroModel) + + inputs, additional_forward_args = to_captum_input( + data.x_dict, + data.edge_index_dict, + mask_type, + 'additional_arg', + ) + + if mask_type == 'node': + sliding_window_shapes = ((3, 3), (3, 3)) + elif mask_type == 'edge': + sliding_window_shapes = ((5, ), (5, ), (5, )) + else: + sliding_window_shapes = ((3, 3), (3, 3), (5, ), (5, ), (5, )) + + if method == 'IntegratedGradients': + attributions, delta = explainer.attribute( + inputs, target=0, internal_batch_size=1, + additional_forward_args=additional_forward_args, + return_convergence_delta=True) + elif method == 'GradientShap': + attributions, delta = explainer.attribute( + inputs, target=0, return_convergence_delta=True, baselines=inputs, + n_samples=1, additional_forward_args=additional_forward_args) + elif method == 'DeepLiftShap' or method == 'DeepLift': + attributions, delta = explainer.attribute( + inputs, target=0, return_convergence_delta=True, baselines=inputs, + additional_forward_args=additional_forward_args) + elif method == 'Occlusion': + attributions = explainer.attribute( + inputs, target=0, sliding_window_shapes=sliding_window_shapes, + additional_forward_args=additional_forward_args) + else: + attributions = explainer.attribute( + inputs, target=0, additional_forward_args=additional_forward_args) + + if mask_type == 'node': + assert len(attributions) == len(metadata[0]) + x_attr_dict, _ = captum_output_to_dicts(attributions, mask_type, + metadata) + for node_type in metadata[0]: + num_nodes = data[node_type].num_nodes + num_node_feats = data[node_type].x.shape[1] + assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats) + elif mask_type == 'edge': + assert len(attributions) == len(metadata[1]) + _, edge_attr_dict = captum_output_to_dicts(attributions, mask_type, + metadata) + for edge_type in metadata[1]: + num_edges = data[edge_type].edge_index.shape[1] + assert edge_attr_dict[edge_type].shape == (num_edges, ) + else: + assert len(attributions) == len(metadata[0]) + len(metadata[1]) + x_attr_dict, edge_attr_dict = captum_output_to_dicts( + attributions, mask_type, metadata) + for edge_type in metadata[1]: + num_edges = data[edge_type].edge_index.shape[1] + assert edge_attr_dict[edge_type].shape == (num_edges, ) + + for node_type in metadata[0]: + num_nodes = data[node_type].num_nodes + num_node_feats = data[node_type].x.shape[1] + assert x_attr_dict[node_type].shape == (num_nodes, num_node_feats) diff --git a/tests/graph/explain/algorithm/test_explain_algorithm_utils.py b/tests/graph/explain/algorithm/test_explain_algorithm_utils.py new file mode 100644 index 000000000..af69e57fd --- /dev/null +++ b/tests/graph/explain/algorithm/test_explain_algorithm_utils.py @@ -0,0 +1,78 @@ +from mindspore import nn, ops +from mindscience.sharker.explain.algorithm.utils import ( + clear_masks, + set_hetero_masks, +) +from mindscience.sharker.nn import GCNConv, HeteroConv, SAGEConv # , to_hetero + + +class HeteroModel(nn.Cell): + def __init__(self): + super().__init__() + + self.conv1 = HeteroConv({ + ('paper', 'to', 'paper'): + GCNConv(-1, 32), + ('author', 'to', 'paper'): + SAGEConv((-1, -1), 32), + ('paper', 'to', 'author'): + SAGEConv((-1, -1), 32), + }) + + self.conv2 = HeteroConv({ + ('paper', 'to', 'paper'): + GCNConv(-1, 32), + ('author', 'to', 'paper'): + SAGEConv((-1, -1), 32), + ('paper', 'to', 'author'): + SAGEConv((-1, -1), 32), + }) + + +class GraphSAGE(nn.Cell): + def __init__(self): + super().__init__() + self.conv1 = SAGEConv((-1, -1), 32) + self.conv2 = SAGEConv((-1, -1), 32) + + def construct(self, x, edge_index): + x = ops.relu(self.conv1(x, edge_index)) + return self.conv2(x, edge_index) + + +def test_set_clear_mask(hetero_data): + edge_mask_dict = { + ('paper', 'to', 'paper'): ops.ones(200), + ('author', 'to', 'paper'): ops.ones(100), + ('paper', 'to', 'author'): ops.ones(100), + } + + model = HeteroModel() + + set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict) + for edge_type in hetero_data.edge_types: + # Check that masks are correctly set: + assert ops.isclose(model.conv1.convs[edge_type]._edge_mask, + edge_mask_dict[edge_type]).all() + assert model.conv1.convs[edge_type].explain + + clear_masks(model) + for edge_type in hetero_data.edge_types: + assert model.conv1.convs[edge_type]._edge_mask is None + assert not model.conv1.convs[edge_type].explain + + model = to_hetero(GraphSAGE(), hetero_data.metadata(), debug=False) + + set_hetero_masks(model, edge_mask_dict, hetero_data.edge_index_dict) + for edge_type in hetero_data.edge_types: + # Check that masks are correctly set: + str_edge_type = '__'.join(edge_type) + assert ops.isclose(model.conv1[str_edge_type]._edge_mask, + edge_mask_dict[edge_type]).all() + assert model.conv1[str_edge_type].explain + + clear_masks(model) + for edge_type in hetero_data.edge_types: + str_edge_type = '__'.join(edge_type) + assert model.conv1[str_edge_type]._edge_mask is None + assert not model.conv1[str_edge_type].explain diff --git a/tests/graph/explain/algorithm/test_gnn_explainer.py b/tests/graph/explain/algorithm/test_gnn_explainer.py new file mode 100644 index 000000000..f996d087a --- /dev/null +++ b/tests/graph/explain/algorithm/test_gnn_explainer.py @@ -0,0 +1,275 @@ +import pytest +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker.explain import Explainer, GNNExplainer +from mindscience.sharker.explain.config import ( + ExplanationType, + MaskType, + ModelConfig, + ModelMode, + ModelReturnType, + ModelTaskLevel, +) +from mindscience.sharker.nn import AttentiveFP, ChebConv, GCNConv, global_add_pool + + +class GCN(nn.Cell): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + if model_config.mode == ModelMode.multiclass_classification: + out_channels = 7 + else: + out_channels = 1 + + self.conv1 = GCNConv(3, 16) + self.conv2 = GCNConv(16, out_channels) + + def construct(self, x, edge_index, batch=None, edge_label_index=None): + x = ops.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + + if self.model_config.task_level == ModelTaskLevel.graph: + x = global_add_pool(x, batch) + elif self.model_config.task_level == ModelTaskLevel.edge: + assert edge_label_index is not None + x = x[edge_label_index[0]] * x[edge_label_index[1]] + + if self.model_config.mode == ModelMode.binary_classification: + if self.model_config.return_type == ModelReturnType.probs: + x = x.sigmoid() + elif self.model_config.mode == ModelMode.multiclass_classification: + if self.model_config.return_type == ModelReturnType.probs: + x = ops.softmax(x, axis=-1) + elif self.model_config.return_type == ModelReturnType.log_probs: + x = ops.log_softmax(x, axis=-1) + + return x + + +node_mask_types = [ + MaskType.object, + MaskType.common_attributes, + MaskType.attributes, +] +edge_mask_types = [MaskType.object, None] +explanation_types = [ExplanationType.model, ExplanationType.phenomenon] +task_levels = [ModelTaskLevel.node, ModelTaskLevel.edge, ModelTaskLevel.graph] +indices = [None, 2, ops.arange(3)] + +x = ops.randn(8, 3) +edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], +]) +edge_attr = ops.randn(edge_index.shape[1], 5) +batch = ms.Tensor([0, 0, 0, 1, 1, 2, 2, 2]) +edge_label_index = ms.Tensor([[0, 1, 2], [3, 4, 5]]) + + +@pytest.mark.parametrize('edge_mask_type', edge_mask_types) +@pytest.mark.parametrize('node_mask_type', node_mask_types) +@pytest.mark.parametrize('explanation_type', explanation_types) +@pytest.mark.parametrize('task_level', task_levels) +@pytest.mark.parametrize('return_type', [ + ModelReturnType.probs, + ModelReturnType.raw, +]) +@pytest.mark.parametrize('index', indices) +def test_gnn_explainer_binary_classification( + edge_mask_type, + node_mask_type, + explanation_type, + task_level, + return_type, + index, + check_explanation, +): + model_config = ModelConfig( + mode='binary_classification', + task_level=task_level, + return_type=return_type, + ) + + model = GCN(model_config) + + target = None + if explanation_type == ExplanationType.phenomenon: + out = model(x, edge_index, batch, edge_label_index) + if model_config.return_type == ModelReturnType.raw: + target = (out > 0).long().view(-1) + if model_config.return_type == ModelReturnType.probs: + target = (out > 0.5).long().view(-1) + + explainer = Explainer( + model=model, + algorithm=GNNExplainer(epochs=2), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type=edge_mask_type, + model_config=model_config, + ) + + explanation = explainer( + x, + edge_index, + target=target, + index=index, + batch=batch, + edge_label_index=edge_label_index, + ) + + assert explainer.algorithm.node_mask is None + assert explainer.algorithm.edge_mask is None + + check_explanation(explanation, node_mask_type, edge_mask_type) + + +@pytest.mark.parametrize('edge_mask_type', edge_mask_types) +@pytest.mark.parametrize('node_mask_type', node_mask_types) +@pytest.mark.parametrize('explanation_type', explanation_types) +@pytest.mark.parametrize('task_level', task_levels) +@pytest.mark.parametrize('return_type', [ + ModelReturnType.log_probs, + ModelReturnType.probs, + ModelReturnType.raw, +]) +@pytest.mark.parametrize('index', indices) +def test_gnn_explainer_multiclass_classification( + edge_mask_type, + node_mask_type, + explanation_type, + task_level, + return_type, + index, + check_explanation, +): + model_config = ModelConfig( + mode='multiclass_classification', + task_level=task_level, + return_type=return_type, + ) + + model = GCN(model_config) + + target = None + if explanation_type == ExplanationType.phenomenon: + target = model(x, edge_index, batch, edge_label_index).argmax(-1) + + explainer = Explainer( + model=model, + algorithm=GNNExplainer(epochs=2), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type=edge_mask_type, + model_config=model_config, + ) + + explanation = explainer( + x, + edge_index, + target=target, + index=index, + batch=batch, + edge_label_index=edge_label_index, + ) + + assert explainer.algorithm.node_mask is None + assert explainer.algorithm.edge_mask is None + + check_explanation(explanation, node_mask_type, edge_mask_type) + + +@pytest.mark.parametrize('edge_mask_type', edge_mask_types) +@pytest.mark.parametrize('node_mask_type', node_mask_types) +@pytest.mark.parametrize('explanation_type', explanation_types) +@pytest.mark.parametrize('task_level', task_levels) +@pytest.mark.parametrize('index', indices) +def test_gnn_explainer_regression( + edge_mask_type, + node_mask_type, + explanation_type, + task_level, + index, + check_explanation, +): + model_config = ModelConfig( + mode='regression', + task_level=task_level, + ) + + model = GCN(model_config) + + target = None + if explanation_type == ExplanationType.phenomenon: + target = model(x, edge_index, batch, edge_label_index) + + explainer = Explainer( + model=model, + algorithm=GNNExplainer(epochs=2), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type=edge_mask_type, + model_config=model_config, + ) + + explanation = explainer( + x, + edge_index, + target=target, + index=index, + batch=batch, + edge_label_index=edge_label_index, + ) + + assert explainer.algorithm.node_mask is None + assert explainer.algorithm.edge_mask is None + + check_explanation(explanation, node_mask_type, edge_mask_type) + + +def test_gnn_explainer_cheb_conv(check_explanation): + explainer = Explainer( + model=ChebConv(3, 1, K=2), + algorithm=GNNExplainer(epochs=2), + explanation_type='model', + node_mask_type='object', + edge_mask_type='object', + model_config=dict( + mode='binary_classification', + task_level='node', + return_type='raw', + ), + ) + + explanation = explainer(x, edge_index) + + assert explainer.algorithm.node_mask is None + assert explainer.algorithm.edge_mask is None + + check_explanation(explanation, MaskType.object, MaskType.object) + + +def test_gnn_explainer_attentive_fp(check_explanation): + model = AttentiveFP(3, 16, 1, edge_dim=5, num_layers=2, num_timesteps=2) + + explainer = Explainer( + model=model, + algorithm=GNNExplainer(epochs=2), + explanation_type='model', + node_mask_type='object', + edge_mask_type='object', + model_config=dict( + mode='binary_classification', + task_level='node', + return_type='raw', + ), + ) + + explanation = explainer(x, edge_index, edge_attr=edge_attr, batch=batch) + + assert explainer.algorithm.node_mask is None + assert explainer.algorithm.edge_mask is None + + check_explanation(explanation, MaskType.object, MaskType.object) diff --git a/tests/graph/explain/algorithm/test_graphmask_explainer.py b/tests/graph/explain/algorithm/test_graphmask_explainer.py new file mode 100644 index 000000000..230189e22 --- /dev/null +++ b/tests/graph/explain/algorithm/test_graphmask_explainer.py @@ -0,0 +1,232 @@ +import pytest +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker.explain import Explainer, Explanation, GraphMaskExplainer +from mindscience.sharker.explain.config import ( + MaskType, + ModelConfig, + ModelMode, + ModelReturnType, + ModelTaskLevel, +) +from mindscience.sharker.nn import GCNConv, global_add_pool + + +class GCN(nn.Cell): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + if model_config.mode == ModelMode.multiclass_classification: + out_channels = 7 + else: + out_channels = 1 + + self.conv1 = GCNConv(3, 16) + self.conv2 = GCNConv(16, out_channels) + + def construct(self, x, edge_index, batch=None, edge_label_index=None): + x = ops.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + + if self.model_config.task_level == ModelTaskLevel.graph: + x = global_add_pool(x, batch) + elif self.model_config.task_level == ModelTaskLevel.edge: + assert edge_label_index is not None + x = x[edge_label_index[0]] * x[edge_label_index[1]] + + if self.model_config.mode == ModelMode.binary_classification: + if self.model_config.return_type == ModelReturnType.probs: + x = x.sigmoid() + elif self.model_config.mode == ModelMode.multiclass_classification: + if self.model_config.return_type == ModelReturnType.probs: + x = ops.softmax(x, axis=-1) + elif self.model_config.return_type == ModelReturnType.log_probs: + x = ops.log_softmax(x, axis=-1) + + return x + + +def check_explanation( + edge_mask_type: MaskType, + node_mask_type: MaskType, + explanation: Explanation, +): + if node_mask_type == MaskType.attributes: + assert explanation.node_mask.shape == explanation.x.shape + assert explanation.node_mask.min() >= 0 + assert explanation.node_mask.max() <= 1 + elif node_mask_type == MaskType.object: + assert explanation.node_mask.shape == (explanation.num_nodes, 1) + assert explanation.node_mask.min() >= 0 + assert explanation.node_mask.max() <= 1 + elif node_mask_type == MaskType.common_attributes: + assert explanation.node_mask.shape == (1, explanation.num_features) + assert explanation.node_mask.min() >= 0 + assert explanation.node_mask.max() <= 1 + + if edge_mask_type == MaskType.object: + assert explanation.edge_mask.shape == (explanation.num_edges, ) + assert explanation.edge_mask.min() >= 0 + assert explanation.edge_mask.max() <= 1 + + +node_mask_types = [ + MaskType.object, + MaskType.common_attributes, + MaskType.attributes, +] +edge_mask_types = [ + MaskType.object, + None, +] + +x = ops.randn(8, 3) +edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], +]) +batch = ms.Tensor([0, 0, 0, 1, 1, 2, 2, 2]) +edge_label_index = ms.Tensor([[0, 1, 2], [3, 4, 5]]) + + +@pytest.mark.parametrize('edge_mask_type', edge_mask_types) +@pytest.mark.parametrize('node_mask_type', node_mask_types) +@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) +@pytest.mark.parametrize('task_level', ['node', 'edge', 'graph']) +@pytest.mark.parametrize('return_type', ['probs', 'raw']) +@pytest.mark.parametrize('index', [None, 2, ops.arange(3)]) +def test_graph_mask_explainer_binary_classification( + edge_mask_type, + node_mask_type, + explanation_type, + task_level, + return_type, + index, +): + model_config = ModelConfig( + mode='binary_classification', + task_level=task_level, + return_type=return_type, + ) + + model = GCN(model_config) + + target = None + if explanation_type == 'phenomenon': + out = model(x, edge_index, batch, edge_label_index) + if model_config.return_type == ModelReturnType.raw: + target = (out > 0).long().view(-1) + if model_config.return_type == ModelReturnType.probs: + target = (out > 0.5).long().view(-1) + + explainer = Explainer( + model=model, + algorithm=GraphMaskExplainer(2, epochs=5, log=False), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type=edge_mask_type, + model_config=model_config, + ) + + explanation = explainer( + x, + edge_index, + target=target, + index=index, + batch=batch, + edge_label_index=edge_label_index, + ) + + check_explanation(edge_mask_type, node_mask_type, explanation) + + +@pytest.mark.parametrize('edge_mask_type', edge_mask_types) +@pytest.mark.parametrize('node_mask_type', node_mask_types) +@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) +@pytest.mark.parametrize('task_level', ['node', 'edge', 'graph']) +@pytest.mark.parametrize('return_type', ['log_probs', 'probs', 'raw']) +@pytest.mark.parametrize('index', [None, 2, ops.arange(3)]) +def test_graph_mask_explainer_multiclass_classification( + edge_mask_type, + node_mask_type, + explanation_type, + task_level, + return_type, + index, +): + model_config = ModelConfig( + mode='multiclass_classification', + task_level=task_level, + return_type=return_type, + ) + + model = GCN(model_config) + + target = None + if explanation_type == 'phenomenon': + target = model(x, edge_index, batch, edge_label_index).argmax(-1) + + explainer = Explainer( + model=model, + algorithm=GraphMaskExplainer(2, epochs=5, log=False), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type=edge_mask_type, + model_config=model_config, + ) + + explanation = explainer( + x, + edge_index, + target=target, + index=index, + batch=batch, + edge_label_index=edge_label_index, + ) + + check_explanation(edge_mask_type, node_mask_type, explanation) + + +@pytest.mark.parametrize('edge_mask_type', edge_mask_types) +@pytest.mark.parametrize('node_mask_type', node_mask_types) +@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) +@pytest.mark.parametrize('task_level', ['node', 'edge', 'graph']) +@pytest.mark.parametrize('index', [None, 2, ops.arange(3)]) +def test_graph_mask_explainer_regression( + edge_mask_type, + node_mask_type, + explanation_type, + task_level, + index, +): + model_config = ModelConfig( + mode='regression', + task_level=task_level, + ) + + model = GCN(model_config) + + target = None + if explanation_type == 'phenomenon': + target = model(x, edge_index, batch, edge_label_index) + + explainer = Explainer( + model=model, + algorithm=GraphMaskExplainer(2, epochs=5, log=False), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type=edge_mask_type, + model_config=model_config, + ) + + explanation = explainer( + x, + edge_index, + target=target, + index=index, + batch=batch, + edge_label_index=edge_label_index, + ) + + check_explanation(edge_mask_type, node_mask_type, explanation) diff --git a/tests/graph/explain/algorithm/test_pg_explainer.py b/tests/graph/explain/algorithm/test_pg_explainer.py new file mode 100644 index 000000000..001004a89 --- /dev/null +++ b/tests/graph/explain/algorithm/test_pg_explainer.py @@ -0,0 +1,178 @@ +import pytest +import numpy as np +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker.explain import Explainer, PGExplainer +from mindscience.sharker.explain.config import ( + ModelConfig, + ModelMode, + ModelTaskLevel, +) +from mindscience.sharker.nn import GCNConv, global_add_pool + + +class GCN(nn.Cell): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + if model_config.mode == ModelMode.multiclass_classification: + out_channels = 7 + else: + out_channels = 1 + + self.conv1 = GCNConv(3, 16) + self.conv2 = GCNConv(16, out_channels) + + def construct(self, x, edge_index, batch=None, edge_label_index=None): + x = ops.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + if self.model_config.task_level == ModelTaskLevel.graph: + x = global_add_pool(x, batch) + return x + + +@pytest.mark.parametrize('mode', [ + ModelMode.binary_classification, + ModelMode.multiclass_classification, + ModelMode.regression, +]) +def test_pg_explainer_node(check_explanation, mode): + x = ops.randn(8, 3) + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], + ]) + + if mode == ModelMode.binary_classification: + target = ops.randint(0, 2, (x.shape[0], )) + in_channels = 3 + elif mode == ModelMode.multiclass_classification: + target = ops.randint(0, 7, (x.shape[0], )) + in_channels = 21 + elif mode == ModelMode.regression: + target = ops.randn((x.shape[0], 1)) + in_channels = 3 + + model_config = ModelConfig(mode=mode, task_level='node', return_type='raw') + + model = GCN(model_config) + + explainer = Explainer( + model=model, + algorithm=PGExplainer(in_channels=in_channels, epochs=2), + explanation_type='phenomenon', + edge_mask_type='object', + model_config=model_config, + ) + + with pytest.raises(ValueError, match="not yet fully trained"): + explainer(x, edge_index, target=target) + + explainer.algorithm.reset_parameters() + for epoch in range(2): + for index in range(x.shape[0]): + loss = explainer.algorithm.train(epoch, model, x, edge_index, + target=target, index=index) + assert loss >= 0.0 + + explanation = explainer(x, edge_index, target=target, index=0) + + check_explanation(explanation, None, explainer.edge_mask_type) + + +@pytest.mark.parametrize('mode', [ + ModelMode.binary_classification, + ModelMode.multiclass_classification, + ModelMode.regression, +]) +def test_pg_explainer_graph(check_explanation, mode): + x = ops.randn(8, 3) + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], + ]) + in_channels, target = None, None + if mode == ModelMode.binary_classification: + target = np.random.randint(2) + in_channels = 2 + elif mode == ModelMode.multiclass_classification: + target = np.random.randint(7) + in_channels = 14 + elif mode == ModelMode.regression: + target = ops.randn((1, 1)) + in_channels = 2 + + model_config = ModelConfig(mode=mode, task_level='graph', + return_type='raw') + + model = GCN(model_config) + + explainer = Explainer( + model=model, + algorithm=PGExplainer(in_channels=in_channels, epochs=2), + explanation_type='phenomenon', + edge_mask_type='object', + model_config=model_config, + ) + + with pytest.raises(ValueError, match="not yet fully trained"): + explainer(x, edge_index, target=target) + + explainer.algorithm.reset_parameters() + for epoch in range(2): + loss = explainer.algorithm.train(epoch, model, x, edge_index, + target=target) + assert loss >= 0.0 + + explanation = explainer(x, edge_index, target=target) + + check_explanation(explanation, None, explainer.edge_mask_type) + + +def test_pg_explainer_supports(): + # Test unsupported model task level: + with pytest.raises(ValueError, match="not support the given explanation"): + model_config = ModelConfig( + mode='binary_classification', + task_level='edge', + return_type='raw', + ) + Explainer( + model=GCN(model_config), + algorithm=PGExplainer(2, epochs=2), + explanation_type='phenomenon', + edge_mask_type='object', + model_config=model_config, + ) + + # Test unsupported explanation type: + with pytest.raises(ValueError, match="not support the given explanation"): + model_config = ModelConfig( + mode='binary_classification', + task_level='node', + return_type='raw', + ) + Explainer( + model=GCN(model_config), + algorithm=PGExplainer(2, epochs=2), + explanation_type='model', + edge_mask_type='object', + model_config=model_config, + ) + + # Test unsupported node mask: + with pytest.raises(ValueError, match="not support the given explanation"): + model_config = ModelConfig( + mode='binary_classification', + task_level='node', + return_type='raw', + ) + Explainer( + model=GCN(model_config), + algorithm=PGExplainer(2, epochs=2), + explanation_type='model', + node_mask_type='object', + edge_mask_type='object', + model_config=model_config, + ) diff --git a/tests/graph/explain/conftest.py b/tests/graph/explain/conftest.py new file mode 100644 index 000000000..1942c76a6 --- /dev/null +++ b/tests/graph/explain/conftest.py @@ -0,0 +1,94 @@ +from typing import Optional + +import pytest +from mindspore import nn, load_obf_params_into_net +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.explain import Explanation +from mindscience.sharker.explain.config import MaskType +from mindscience.sharker.nn import SAGEConv +from mindscience.sharker.testing import get_random_edge_index + + +@pytest.fixture() +def data(): + return Graph( + x=ops.randn(4, 3), + edge_index=get_random_edge_index(4, 4, num_edges=6), + edge_attr=ops.randn(6, 3), + ) + + +@pytest.fixture() +def hetero_data(): + data = HeteroGraph() + data['paper'].x = ops.randn(8, 16) + data['author'].x = ops.randn(10, 8) + + data['paper', 'paper'].edge_index = get_random_edge_index(8, 8, 10) + data['paper', 'paper'].edge_attr = ops.randn(10, 16) + data['paper', 'author'].edge_index = get_random_edge_index(8, 10, 10) + data['paper', 'author'].edge_attr = ops.randn(10, 8) + data['author', 'paper'].edge_index = get_random_edge_index(10, 8, 10) + data['author', 'paper'].edge_attr = ops.randn(10, 8) + + return data + + +@pytest.fixture() +def hetero_model(): + return HeteroSAGE + + +class GraphSAGE(nn.Cell): + def __init__(self): + super().__init__() + self.conv1 = SAGEConv((-1, -1), 32) + self.conv2 = SAGEConv((-1, -1), 32) + + def construct(self, x, edge_index): + x = ops.relu(self.conv1(x, edge_index)) + return self.conv2(x, edge_index) + + +class HeteroSAGE(nn.Cell): + def __init__(self, metadata): + super().__init__() + self.graph_sage = to_hetero(GraphSAGE(), metadata, debug=False) + self.lin = nn.Dense(32, 1) + + def construct(self, x_dict, edge_index_dict, + additonal_arg=None) -> Tensor: + return self.lin(self.graph_sage(x_dict, edge_index_dict)['paper']) + + +@pytest.fixture() +def check_explanation(): + def _check_explanation( + explanation: Explanation, + node_mask_type: Optional[MaskType], + edge_mask_type: Optional[MaskType], + ): + if node_mask_type == MaskType.attributes: + assert explanation.node_mask.shape == explanation.x.shape + assert explanation.node_mask.min() >= 0 + assert explanation.node_mask.max() <= 1 + elif node_mask_type == MaskType.object: + assert explanation.node_mask.shape == (explanation.num_nodes, 1) + assert explanation.node_mask.min() >= 0 + assert explanation.node_mask.max() <= 1 + elif node_mask_type == MaskType.common_attributes: + assert explanation.node_mask.shape == (1, explanation.x.shape[-1]) + assert explanation.node_mask.min() >= 0 + assert explanation.node_mask.max() <= 1 + elif node_mask_type is None: + assert 'node_mask' not in explanation + + if edge_mask_type == MaskType.object: + assert explanation.edge_mask.shape == (explanation.num_edges, ) + assert explanation.edge_mask.min() >= 0 + assert explanation.edge_mask.max() <= 1 + elif edge_mask_type is None: + assert 'edge_mask' not in explanation + + return _check_explanation diff --git a/tests/graph/explain/metric/test_basic_metric.py b/tests/graph/explain/metric/test_basic_metric.py new file mode 100644 index 000000000..de0a1a6f4 --- /dev/null +++ b/tests/graph/explain/metric/test_basic_metric.py @@ -0,0 +1,46 @@ +import warnings + +import mindspore as ms +from mindspore import ops +from mindscience.sharker.explain import groundtruth_metrics + + +def test_groundtruth_metrics(): + pred_mask = ops.rand(10) + target_mask = ops.rand(10) + + accuracy, recall, precision, f1_score, auroc = groundtruth_metrics( + pred_mask, target_mask) + + assert accuracy >= 0.0 and accuracy <= 1.0 + assert recall >= 0.0 and recall <= 1.0 + assert precision >= 0.0 and precision <= 1.0 + assert f1_score >= 0.0 and f1_score <= 1.0 + assert auroc >= 0.0 and auroc <= 1.0 + + +def test_perfect_groundtruth_metrics(): + pred_mask = target_mask = ops.rand(10) + + accuracy, recall, precision, f1_score, auroc = groundtruth_metrics( + pred_mask, target_mask) + + assert round(accuracy, 6) == 1.0 + assert round(recall, 6) == 1.0 + assert round(precision, 6) == 1.0 + assert round(f1_score, 6) == 1.0 + assert round(auroc, 6) == 1.0 + + +def test_groundtruth_true_negative(): + warnings.filterwarnings('ignore', '.*No positive samples in targets.*') + pred_mask = target_mask = ops.zeros(10) + + accuracy, recall, precision, f1_score, auroc = groundtruth_metrics( + pred_mask, target_mask) + + assert round(accuracy, 6) == 1.0 + assert round(recall, 6) == 0.0 + assert round(precision, 6) == 0.0 + assert round(f1_score, 6) == 0.0 + assert round(auroc, 6) == 0.0 diff --git a/tests/graph/explain/metric/test_faithfulness.py b/tests/graph/explain/metric/test_faithfulness.py new file mode 100644 index 000000000..797077634 --- /dev/null +++ b/tests/graph/explain/metric/test_faithfulness.py @@ -0,0 +1,59 @@ +import pytest +import mindspore as ms +from mindspore import nn, ops +from mindscience.sharker.explain import ( + DummyExplainer, + Explainer, + ModelConfig, + unfaithfulness, +) + + +class DummyModel(nn.Cell): + def __init__(self, model_config: ModelConfig): + super().__init__() + self.model_config = model_config + + def construct(self, x, edge_index): + if self.model_config.return_type.value == 'probs': + x = ops.softmax(x, -1) + elif self.model_config.return_type.value == 'log_probs': + x = ops.log_softmax(x, -1) + return x + + +@pytest.mark.parametrize('top_k', [None, 2]) +@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) +@pytest.mark.parametrize('node_mask_type', ['common_attributes', 'attributes']) +@pytest.mark.parametrize('return_type', ['raw', 'probs', 'log_probs']) +def test_unfaithfulness(top_k, explanation_type, node_mask_type, return_type): + x = ops.randn(8, 4) + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], + ]) + + model_config = ModelConfig( + mode='multiclass_classification', + task_level='node', + return_type=return_type, + ) + + explainer = Explainer( + DummyModel(model_config), + algorithm=DummyExplainer(), + explanation_type=explanation_type, + node_mask_type=node_mask_type, + edge_mask_type='object', + model_config=model_config, + ) + + target = None + if explanation_type == 'phenomenon': + target = ops.randint(0, x.shape[1], (x.shape[0], )) + + explanation = explainer(x, edge_index, target=target, + index=ops.arange(4)) + + metric = unfaithfulness(explainer, explanation, top_k) + assert metric >= 0. and metric <= 1. diff --git a/tests/graph/explain/metric/test_fidelity.py b/tests/graph/explain/metric/test_fidelity.py new file mode 100644 index 000000000..023878906 --- /dev/null +++ b/tests/graph/explain/metric/test_fidelity.py @@ -0,0 +1,74 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops, nn +from mindscience.sharker.explain import ( + DummyExplainer, + Explainer, + characterization_score, + fidelity, + fidelity_curve_auc, +) + + +class DummyModel(nn.Cell): + def construct(self, x, edge_index): + return x + + +@pytest.mark.parametrize('explanation_type', ['model', 'phenomenon']) +def test_fidelity(explanation_type): + x = ops.randn(8, 4) + edge_index = Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6], + ]) + + explainer = Explainer( + DummyModel(), + algorithm=DummyExplainer(), + explanation_type=explanation_type, + node_mask_type='object', + edge_mask_type='object', + model_config=dict( + mode='multiclass_classification', + return_type='raw', + task_level='node', + ), + ) + + target = None + if explanation_type == 'phenomenon': + target = ops.randint(0, x.shape[1], (x.shape[0], )) + + explanation = explainer(x, edge_index, target=target, + index=ops.arange(4)) + + pos_fidelity, neg_fidelity = fidelity(explainer, explanation) + assert pos_fidelity == 0.0 and neg_fidelity == 0.0 + + +def test_characterization_score(): + out = characterization_score( + pos_fidelity=Tensor([1.0, 0.6, 0.5, 1.0]), + neg_fidelity=Tensor([0.0, 0.2, 0.5, 1.0]), + pos_weight=0.2, + neg_weight=0.8, + ) + assert ops.isclose(out, Tensor([1.0, 0.75, 0.5, 0.0])).all() + + +def test_fidelity_curve_auc(): + pos_fidelity = Tensor([1.0, 1.0, 0.5, 1.0]) + neg_fidelity = Tensor([0.0, 0.5, 0.5, 0.9]) + + x = Tensor([0, 1, 2, 3]) + out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 4) + assert out == 8.5 + + x = Tensor([10, 11, 12, 13]) + out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 4) + assert out == 8.5 + + x = Tensor([0, 1, 2, 5]) + out = round(float(fidelity_curve_auc(pos_fidelity, neg_fidelity, x)), 4) + assert out == 19.5 diff --git a/tests/graph/explain/test_explain_config.py b/tests/graph/explain/test_explain_config.py new file mode 100644 index 000000000..0cadefc6c --- /dev/null +++ b/tests/graph/explain/test_explain_config.py @@ -0,0 +1,48 @@ +import pytest + +from mindscience.sharker.explain.config import ExplainerConfig, ThresholdConfig + + +@pytest.mark.parametrize('threshold_pairs', [ + ('hard', 0.5, True), + ('hard', 1.1, False), + ('hard', -1, False), + ('topk', 1, True), + ('topk', 0, False), + ('topk', -1, False), + ('topk', 0.5, False), + ('invalid', None, False), + ('hard', None, False), +]) +def test_threshold_config(threshold_pairs): + threshold_type, threshold_value, valid = threshold_pairs + if valid: + threshold = ThresholdConfig(threshold_type, threshold_value) + assert threshold.type.value == threshold_type + assert threshold.value == threshold_value + else: + with pytest.raises(ValueError): + ThresholdConfig(threshold_type, threshold_value) + + +@pytest.mark.parametrize('explanation_type', [ + 'model', + 'phenomenon', + 'invalid', +]) +@pytest.mark.parametrize('mask_type', [ + None, + 'object', + 'common_attributes', + 'attributes', + 'invalid', +]) +def test_configuration_config(explanation_type, mask_type): + if (explanation_type != 'invalid' and mask_type is not None + and mask_type != 'invalid'): + config = ExplainerConfig(explanation_type, mask_type, None) + assert config.explanation_type.value == explanation_type + assert config.node_mask_type.value == mask_type + else: + with pytest.raises(ValueError): + ExplainerConfig(explanation_type, mask_type, mask_type) diff --git a/tests/graph/explain/test_explainer.py b/tests/graph/explain/test_explainer.py new file mode 100644 index 000000000..cca7fa461 --- /dev/null +++ b/tests/graph/explain/test_explainer.py @@ -0,0 +1,123 @@ +import pytest +import mindspore as ms +from mindspore import nn, ops +from mindscience.sharker.explain import DummyExplainer, Explainer, Explanation +from mindscience.sharker.explain.config import ExplanationType + + +class DummyModel(nn.Cell): + def construct(self, x, edge_index): + return x.mean().view(-1) + + +def test_get_prediction(data): + model = DummyModel() + # assert model.training + + explainer = Explainer( + model, + algorithm=DummyExplainer(), + explanation_type='phenomenon', + node_mask_type='object', + model_config=dict( + mode='regression', + task_level='graph', + ), + ) + pred = explainer.get_prediction(data.x, data.edge_index) + # assert model.training + assert pred.shape == (1, ) + + +@pytest.mark.parametrize('target', [None, ops.randn(2)]) +@pytest.mark.parametrize('explanation_type', [x for x in ExplanationType]) +def test_forward(data, target, explanation_type): + model = DummyModel() + # assert model.training + + explainer = Explainer( + model, + algorithm=DummyExplainer(), + explanation_type=explanation_type, + node_mask_type='attributes', + model_config=dict( + mode='regression', + task_level='graph', + ), + ) + + if target is None and explanation_type == ExplanationType.phenomenon: + with pytest.raises(ValueError): + explainer(data.x, data.edge_index, target=target) + else: + explanation = explainer( + data.x, + data.edge_index, + target=target + if explanation_type == ExplanationType.phenomenon else None, + ) + # assert model.training + assert isinstance(explanation, Explanation) + assert 'x' in explanation + assert 'edge_index' in explanation + assert 'target' in explanation + assert 'node_mask' in explanation.available_explanations + assert explanation.node_mask.shape == data.x.shape + + +@pytest.mark.parametrize('threshold_value', [0.2, 0.5, 0.8]) +@pytest.mark.parametrize('node_mask_type', ['object', 'attributes']) +def test_hard_threshold(data, threshold_value, node_mask_type): + explainer = Explainer( + DummyModel(), + algorithm=DummyExplainer(), + explanation_type='model', + node_mask_type=node_mask_type, + edge_mask_type='object', + model_config=dict( + mode='regression', + task_level='graph', + ), + threshold_config=('hard', threshold_value), + ) + explanation = explainer(data.x, data.edge_index) + + assert 'node_mask' in explanation.available_explanations + assert 'edge_mask' in explanation.available_explanations + + for key in explanation.available_explanations: + mask = explanation[key] + assert set(ops.unique(mask)[0].tolist()).issubset({0, 1}) + + +@pytest.mark.parametrize('threshold_value', [1, 5, 10]) +@pytest.mark.parametrize('threshold_type', ['topk', 'topk_hard']) +@pytest.mark.parametrize('node_mask_type', ['object', 'attributes']) +def test_topk_threshold(data, threshold_value, threshold_type, node_mask_type): + explainer = Explainer( + DummyModel(), + algorithm=DummyExplainer(), + explanation_type='model', + node_mask_type=node_mask_type, + edge_mask_type='object', + model_config=dict( + mode='regression', + task_level='graph', + ), + threshold_config=(threshold_type, threshold_value), + ) + explanation = explainer(data.x, data.edge_index) + + assert 'node_mask' in explanation.available_explanations + assert 'edge_mask' in explanation.available_explanations + + for key in explanation.available_explanations: + mask = explanation[key] + if threshold_type == 'topk': + assert (mask > 0).sum() == min(mask.numel(), threshold_value) + assert ((mask == 0).sum() == mask.numel() - + min(mask.numel(), threshold_value)) + else: + assert (mask == 1).sum() == min(mask.numel(), threshold_value) + assert ((mask == 0).sum() == mask.numel() - + min(mask.numel(), threshold_value)) diff --git a/tests/graph/explain/test_explanation.py b/tests/graph/explain/test_explanation.py new file mode 100644 index 000000000..33a4f71c1 --- /dev/null +++ b/tests/graph/explain/test_explanation.py @@ -0,0 +1,150 @@ +import os.path as osp +from typing import Optional, Union + +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph +from mindscience.sharker.explain import Explanation +from mindscience.sharker.explain.config import MaskType +from mindscience.sharker.testing import withPackage + + +def create_random_explanation( + data: Graph, + node_mask_type: Optional[Union[MaskType, str]] = None, + edge_mask_type: Optional[Union[MaskType, str]] = None, +): + if node_mask_type is not None: + node_mask_type = MaskType(node_mask_type) + if edge_mask_type is not None: + edge_mask_type = MaskType(edge_mask_type) + + if node_mask_type == MaskType.object: + node_mask = ops.rand(data.x.shape[0], 1) + elif node_mask_type == MaskType.common_attributes: + node_mask = ops.rand(1, data.x.shape[1]) + elif node_mask_type == MaskType.attributes: + node_mask = ops.rand_like(data.x) + else: + node_mask = None + + if edge_mask_type == MaskType.object: + edge_mask = ops.rand(data.edge_index.shape[1]) + else: + edge_mask = None + + return Explanation( # Create explanation. + node_mask=node_mask, + edge_mask=edge_mask, + x=data.x, + edge_index=data.edge_index, + edge_attr=data.edge_attr, + ) + + +@pytest.mark.parametrize('node_mask_type', + [None, 'object', 'common_attributes', 'attributes']) +@pytest.mark.parametrize('edge_mask_type', [None, 'object']) +def test_available_explanations(data, node_mask_type, edge_mask_type): + expected = [] + if node_mask_type is not None: + expected.append('node_mask') + if edge_mask_type is not None: + expected.append('edge_mask') + + explanation = create_random_explanation( + data, + node_mask_type=node_mask_type, + edge_mask_type=edge_mask_type, + ) + + assert set(explanation.available_explanations) == set(expected) + + +def test_validate_explanation(data): + explanation = create_random_explanation(data) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match="with 5 nodes"): + explanation = create_random_explanation(data, node_mask_type='object') + explanation.x = ops.randn(5, 5) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match="with 4 features"): + explanation = create_random_explanation(data, 'attributes') + explanation.x = ops.randn(4, 4) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match="with 7 edges"): + explanation = create_random_explanation(data, edge_mask_type='object') + explanation.edge_index = ops.randint(0, 4, (2, 7)) + explanation.validate(raise_on_error=True) + + +def test_node_mask(data): + node_mask = ms.Tensor([[1.], [0.], [1.], [0.]]) + + explanation = Explanation( + node_mask=node_mask, + x=data.x, + edge_index=data.edge_index, + edge_attr=data.edge_attr, + ) + explanation.validate(raise_on_error=True) + + out = explanation.get_explanation_subgraph() + assert out.node_mask.shape == (2, 1) + assert (out.node_mask > 0.0).sum() == 2 + assert out.x.shape == (2, 3) + assert out.edge_index.shape[1] <= 6 + assert out.edge_index.shape[1] == out.edge_attr.shape[0] + + out = explanation.get_complement_subgraph() + assert out.node_mask.shape == (2, 1) + assert (out.node_mask == 0.0).sum() == 2 + assert out.x.shape == (2, 3) + assert out.edge_index.shape[1] <= 6 + assert out.edge_index.shape[1] == out.edge_attr.shape[0] + + +def test_edge_mask(data): + edge_mask = ms.Tensor([1., 0., 1., 0., 0., 1.]) + + explanation = Explanation( + edge_mask=edge_mask, + x=data.x, + edge_index=data.edge_index, + edge_attr=data.edge_attr, + ) + explanation.validate(raise_on_error=True) + + out = explanation.get_explanation_subgraph() + assert out.x.shape == (4, 3) + assert out.edge_mask.shape == (3, ) + assert (out.edge_mask > 0.0).sum() == 3 + assert out.edge_index.shape == (2, 3) + assert out.edge_attr.shape == (3, 3) + + out = explanation.get_complement_subgraph() + assert out.x.shape == (4, 3) + assert out.edge_mask.shape == (3, ) + assert (out.edge_mask == 0.0).sum() == 3 + assert out.edge_index.shape == (2, 3) + assert out.edge_attr.shape == (3, 3) + + +@withPackage('matplotlib', 'pandas') +@pytest.mark.parametrize('top_k', [2, None]) +@pytest.mark.parametrize('node_mask_type', [None, 'attributes']) +def test_visualize_feature_importance(tmp_path, data, top_k, node_mask_type): + explanation = create_random_explanation(data, node_mask_type) + + path = osp.join(tmp_path, 'feature_importance.png') + + if node_mask_type is None: + with pytest.raises(ValueError, match="node_mask' is not"): + explanation.visualize_feature_importance(path, top_k=top_k) + else: + explanation.visualize_feature_importance(path, top_k=top_k) + assert osp.exists(path) diff --git a/tests/graph/explain/test_hetero_explainer.py b/tests/graph/explain/test_hetero_explainer.py new file mode 100644 index 000000000..2688a7de6 --- /dev/null +++ b/tests/graph/explain/test_hetero_explainer.py @@ -0,0 +1,128 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, nn, ops +from mindscience.sharker.explain import ( + DummyExplainer, + Explainer, + HeteroExplanation, +) +from mindscience.sharker.explain.config import ExplanationType + + +class DummyModel(nn.Cell): + def construct(self, x_dict, edge_index_dict, *args) -> Tensor: + return x_dict['paper'].mean().view(-1) + + +def test_get_prediction(hetero_data): + model = DummyModel() + # assert model.training + + explainer = Explainer( + model, + algorithm=DummyExplainer(), + explanation_type='phenomenon', + node_mask_type='object', + model_config=dict( + mode='regression', + task_level='graph', + ), + ) + pred = explainer.get_prediction(hetero_data.x_dict, + hetero_data.edge_index_dict) + # assert model.training + assert pred.shape == (1, ) + + +@pytest.mark.parametrize('target', [None, ops.randn(2)]) +@pytest.mark.parametrize('explanation_type', [x for x in ExplanationType]) +def test_forward(hetero_data, target, explanation_type): + model = DummyModel() + + explainer = Explainer( + model, + algorithm=DummyExplainer(), + explanation_type=explanation_type, + node_mask_type='attributes', + model_config=dict( + mode='regression', + task_level='graph', + ), + ) + + if target is None and explanation_type == ExplanationType.phenomenon: + with pytest.raises(ValueError): + explainer(hetero_data.x_dict, hetero_data.edge_index_dict, + target=target) + else: + explanation = explainer( + hetero_data.x_dict, + hetero_data.edge_index_dict, + target=target + if explanation_type == ExplanationType.phenomenon else None, + ) + # assert model.training + assert isinstance(explanation, HeteroExplanation) + assert 'node_mask' in explanation.available_explanations + for key in explanation.node_types: + expected_size = hetero_data[key].x.shape + assert explanation[key].node_mask.shape == expected_size + + +@pytest.mark.parametrize('threshold_value', [0.2, 0.5, 0.8]) +@pytest.mark.parametrize('node_mask_type', ['object', 'attributes']) +def test_hard_threshold(hetero_data, threshold_value, node_mask_type): + + explainer = Explainer( + DummyModel(), + algorithm=DummyExplainer(), + explanation_type='model', + node_mask_type=node_mask_type, + edge_mask_type='object', + model_config=dict( + mode='regression', + task_level='graph', + ), + threshold_config=('hard', threshold_value), + ) + explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict) + assert 'node_mask' in explanation.available_explanations + assert 'edge_mask' in explanation.available_explanations + + for key in explanation.available_explanations: + for mask in explanation.collect(key).values(): + assert set(ops.unique(mask)[0].tolist()).issubset({0, 1}) + + +@pytest.mark.parametrize('threshold_value', [1, 5, 10]) +@pytest.mark.parametrize('threshold_type', ['topk', 'topk_hard']) +@pytest.mark.parametrize('node_mask_type', ['object', 'attributes']) +def test_topk_threshold(hetero_data, threshold_value, threshold_type, + node_mask_type): + explainer = Explainer( + DummyModel(), + algorithm=DummyExplainer(), + explanation_type='model', + node_mask_type=node_mask_type, + edge_mask_type='object', + model_config=dict( + mode='regression', + task_level='graph', + ), + threshold_config=(threshold_type, threshold_value), + ) + explanation = explainer(hetero_data.x_dict, hetero_data.edge_index_dict) + + assert 'node_mask' in explanation.available_explanations + assert 'edge_mask' in explanation.available_explanations + + for key in explanation.available_explanations: + for mask in explanation.collect(key).values(): + if threshold_type == 'topk': + assert (mask > 0).sum() == min(mask.numel(), threshold_value) + assert ((mask == 0).sum() == mask.numel() - + min(mask.numel(), threshold_value)) + else: + assert (mask == 1).sum() == min(mask.numel(), threshold_value) + assert ((mask == 0).sum() == mask.numel() - + min(mask.numel(), threshold_value)) diff --git a/tests/graph/explain/test_hetero_explanation.py b/tests/graph/explain/test_hetero_explanation.py new file mode 100644 index 000000000..df30cffef --- /dev/null +++ b/tests/graph/explain/test_hetero_explanation.py @@ -0,0 +1,144 @@ +import os.path as osp +from typing import Optional, Union + +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import HeteroGraph +from mindscience.sharker.explain import HeteroExplanation +from mindscience.sharker.explain.config import MaskType +from mindscience.sharker.testing import withPackage + + +def create_random_explanation( + hetero_data: HeteroGraph, + node_mask_type: Optional[Union[MaskType, str]] = None, + edge_mask_type: Optional[Union[MaskType, str]] = None, +): + if node_mask_type is not None: + node_mask_type = MaskType(node_mask_type) + if edge_mask_type is not None: + edge_mask_type = MaskType(edge_mask_type) + + out = HeteroExplanation() + + for key in ['paper', 'author']: + out[key].x = hetero_data[key].x + if node_mask_type == MaskType.object: + out[key].node_mask = ops.rand(hetero_data[key].num_nodes, 1) + elif node_mask_type == MaskType.common_attributes: + out[key].node_mask = ops.rand(1, hetero_data[key].num_features) + elif node_mask_type == MaskType.attributes: + out[key].node_mask = ops.rand_like(hetero_data[key].x) + + for key in [('paper', 'paper'), ('paper', 'author')]: + out[key].edge_index = hetero_data[key].edge_index + out[key].edge_attr = hetero_data[key].edge_attr + if edge_mask_type == MaskType.object: + out[key].edge_mask = ops.rand(hetero_data[key].num_edges) + + return out + + +@pytest.mark.parametrize('node_mask_type', + [None, 'object', 'common_attributes', 'attributes']) +@pytest.mark.parametrize('edge_mask_type', [None, 'object']) +def test_available_explanations(hetero_data, node_mask_type, edge_mask_type): + expected = [] + if node_mask_type: + expected.append('node_mask') + if edge_mask_type: + expected.append('edge_mask') + + explanation = create_random_explanation( + hetero_data, + node_mask_type=node_mask_type, + edge_mask_type=edge_mask_type, + ) + + assert set(explanation.available_explanations) == set(expected) + + +def test_validate_explanation(hetero_data): + explanation = create_random_explanation(hetero_data) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match="with 8 nodes"): + explanation = create_random_explanation(hetero_data) + explanation['paper'].node_mask = ops.rand(5, 5) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match="with 5 features"): + explanation = create_random_explanation(hetero_data, 'attributes') + explanation['paper'].x = ops.randn(8, 5) + explanation.validate(raise_on_error=True) + + with pytest.raises(ValueError, match="with 10 edges"): + explanation = create_random_explanation(hetero_data) + explanation['paper', 'paper'].edge_mask = ops.randn(5) + explanation.validate(raise_on_error=True) + + +def test_node_mask(): + explanation = HeteroExplanation() + explanation['paper'].node_mask = ms.Tensor([[1.], [0.], [1.], [1.]]) + explanation['author'].node_mask = ms.Tensor([[1.], [0.], [1.], [1.]]) + with pytest.warns(UserWarning, match="are isolated"): + explanation.validate(raise_on_error=True) + + out = explanation.get_explanation_subgraph() + assert out['paper'].node_mask.shape == (3, 1) + assert out['author'].node_mask.shape == (3, 1) + + out = explanation.get_complement_subgraph() + assert out['paper'].node_mask.shape == (1, 1) + assert out['author'].node_mask.shape == (1, 1) + + +def test_edge_mask(): + explanation = HeteroExplanation() + explanation['paper'].num_nodes = 4 + explanation['author'].num_nodes = 4 + explanation['paper', 'author'].edge_index = ms.Tensor([ + [0, 1, 2, 3], + [0, 1, 2, 3], + ]) + explanation['paper', 'author'].edge_mask = ms.Tensor([1., 0., 1., 1.]) + + out = explanation.get_explanation_subgraph() + assert out['paper'].num_nodes == 4 + assert out['author'].num_nodes == 4 + assert out['paper', 'author'].edge_mask.shape == (3, ) + assert ops.equal(out['paper', 'author'].edge_index, + ms.Tensor([[0, 2, 3], [0, 2, 3]])).all() + + out = explanation.get_complement_subgraph() + assert out['paper'].num_nodes == 4 + assert out['author'].num_nodes == 4 + assert out['paper', 'author'].edge_mask.shape == (1, ) + assert ops.equal(out['paper', 'author'].edge_index, + ms.Tensor([[1], [1]])).all() + + +@withPackage('matplotlib') +@pytest.mark.parametrize('top_k', [2, None]) +@pytest.mark.parametrize('node_mask_type', [None, 'attributes']) +def test_visualize_feature_importance( + top_k, + node_mask_type, + tmp_path, + hetero_data, +): + explanation = create_random_explanation( + hetero_data, + node_mask_type=node_mask_type, + ) + + path = osp.join(tmp_path, 'feature_importance.png') + + if node_mask_type is None: + with pytest.raises(KeyError, match="Tried to collect 'node_mask'"): + explanation.visualize_feature_importance(path, top_k=top_k) + else: + explanation.visualize_feature_importance(path, top_k=top_k) + assert osp.exists(path) diff --git a/tests/graph/io/example1.off b/tests/graph/io/example1.off new file mode 100644 index 000000000..c01aa035f --- /dev/null +++ b/tests/graph/io/example1.off @@ -0,0 +1,8 @@ +OFF +4 2 0 +0.0 0.0 0.0 +0.0 1.0 0.0 +1.0 0.0 0.0 +1.0 1.0 0.0 +3 0 1 2 +3 1 2 3 diff --git a/tests/graph/io/example2.off b/tests/graph/io/example2.off new file mode 100644 index 000000000..83566449e --- /dev/null +++ b/tests/graph/io/example2.off @@ -0,0 +1,7 @@ +OFF +4 1 0 +0.0 0.0 0.0 +0.0 1.0 0.0 +1.0 0.0 0.0 +1.0 1.0 0.0 +4 0 1 2 3 diff --git a/tests/graph/io/test_fs.py b/tests/graph/io/test_fs.py new file mode 100644 index 000000000..65d6a9277 --- /dev/null +++ b/tests/graph/io/test_fs.py @@ -0,0 +1,131 @@ +import zipfile +from os import path as osp + +import fsspec +import pytest +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.data import extract_zip +from mindscience.sharker.io import fs +from mindscience.sharker.testing import noWindows + +if typing.WITH_WINDOWS: # FIXME + params = ['file'] +else: + params = ['file', 'memory'] + + +@pytest.fixture(params=params) +def tmp_fs_path(request, tmp_path) -> str: + if request.param == 'file': + return tmp_path.resolve().as_posix() + elif request.param == 'memory': + return f'memory://{tmp_path}' + raise NotImplementedError + + +def test_get_fs(): + assert 'file' in fs.get_fs('/tmp/test').protocol + assert 'memory' in fs.get_fs('memory:///tmp/test').protocol + + +@noWindows +def test_normpath(): + assert fs.normpath('////home') == '/home' + assert fs.normpath('memory:////home') == 'memory:////home' + + +def test_exists(tmp_fs_path): + path = osp.join(tmp_fs_path, 'file.txt') + assert not fs.exists(path) + with fsspec.open(path, 'w') as f: + f.write('here') + assert fs.exists(path) + + +def test_makedirs(tmp_fs_path): + path = osp.join(tmp_fs_path, '1', '2') + assert not fs.isdir(path) + fs.makedirs(path) + assert fs.isdir(path) + + +@pytest.mark.parametrize('detail', [False, True]) +def test_ls(tmp_fs_path, detail): + for i in range(2): + with fsspec.open(osp.join(tmp_fs_path, str(i)), 'w') as f: + f.write('here') + res = fs.ls(tmp_fs_path, detail) + assert len(res) == 2 + expected_protocol = fs.get_fs(tmp_fs_path).protocol + for output in res: + if detail: + output = output['name'] + assert fs.get_fs(output).protocol == expected_protocol + + +def test_cp(tmp_fs_path): + src = osp.join(tmp_fs_path, 'src') + for i in range(2): + with fsspec.open(osp.join(src, str(i)), 'w') as f: + f.write('here') + assert fs.exists(src) + + dst = osp.join(tmp_fs_path, 'dst') + assert not fs.exists(dst) + + # Can copy a file to new name: + fs.cp(osp.join(src, '1'), dst) + assert fs.isfile(dst) + fs.rm(dst) + + # Can copy a single file to directory: + fs.makedirs(dst) + fs.cp(osp.join(src, '1'), dst) + assert len(fs.ls(dst)) == 1 + + # Can copy multiple files to directory: + fs.cp(src, dst) + assert len(fs.ls(dst)) == 2 + for i in range(2): + fs.exists(osp.join(dst, str(i))) + + +def test_extract(tmp_fs_path): + def make_zip(path: str): + with fsspec.open(path, mode='wb') as f: + with zipfile.ZipFile(f, mode='w') as z: + z.writestr('1', b'data') + z.writestr('2', b'data') + + src = osp.join(tmp_fs_path, 'src', 'test.zip') + make_zip(src) + assert len(fsspec.open_files(f'zip://*::{src}')) == 2 + + dst = osp.join(tmp_fs_path, 'dst') + assert not fs.exists(dst) + + # Can copy and extract afterwards: + if fs.isdisk(tmp_fs_path): + fs.cp(src, osp.join(dst, 'test.zip')) + assert fs.exists(osp.join(dst, 'test.zip')) + extract_zip(osp.join(dst, 'test.zip'), dst) + assert len(fs.ls(dst)) == 3 + for i in range(2): + fs.exists(osp.join(dst, str(i))) + fs.rm(dst) + + # Can copy and extract: + fs.cp(src, dst, extract=True) + assert len(fs.ls(dst)) == 2 + for i in range(2): + fs.exists(osp.join(dst, str(i))) + + +def test_pickle_save_load(tmp_fs_path): + x = ops.randn(5, 5) + path = osp.join(tmp_fs_path, 'x.pt') + + fs.pickle_save(x, path) + out = fs.pickle_load(path) + assert ops.equal(x, out).all() diff --git a/tests/graph/io/test_off.py b/tests/graph/io/test_off.py new file mode 100644 index 000000000..8d14bd144 --- /dev/null +++ b/tests/graph/io/test_off.py @@ -0,0 +1,37 @@ +import os +import os.path as osp +import random +import sys + +import mindspore as ms +from mindscience.sharker.data import Graph +from mindscience.sharker.io import read_off +# from mindscience.sharker.io import write_off + + +def test_read_off(): + root_dir = osp.join(osp.dirname(osp.realpath(__file__))) + + data = read_off(osp.join(root_dir, 'example1.off')) + assert len(data) == 2 + assert data.crd.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]] + assert data.face.tolist() == [[0, 1], [1, 2], [2, 3]] + + data = read_off(osp.join(root_dir, 'example2.off')) + assert len(data) == 2 + assert data.crd.tolist() == [[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]] + assert data.face.tolist() == [[0, 0], [1, 2], [2, 3]] + + +# def test_write_off(): +# pos = ms.Tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]) +# face = ms.Tensor([[0, 1], [1, 2], [2, 3]]) + +# name = str(random.randrange(sys.maxsize)) +# path = osp.join('/', 'tmp', f'{name}.off') +# write_off(Graph(pos=pos, face=face), path) +# data = read_off(path) +# os.unlink(path) + +# assert data.crd.tolist() == pos.tolist() +# assert data.face.tolist() == face.tolist() diff --git a/tests/graph/loader/test_cache.py b/tests/graph/loader/test_cache.py new file mode 100644 index 000000000..2289add36 --- /dev/null +++ b/tests/graph/loader/test_cache.py @@ -0,0 +1,68 @@ +import mindspore as ms +from mindspore import Tensor, ops, nn + +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import CachedLoader, NeighborLoader + + +def test_cached_loader(): + x = ops.randn(14, 16) + edge_index = ms.Tensor([ + [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], + [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], + ]) + + loader = NeighborLoader( + Graph(x=x, edge_index=edge_index), + num_neighbors=[2], + batch_size=10, + shuffle=False, + ) + cached_loader = CachedLoader(loader) + + assert len(cached_loader) == len(loader) + assert len(cached_loader._cache) == 0 + + cache = [] + for i, batch in enumerate(cached_loader): + assert len(cached_loader._cache) == i + 1 + cache.append(batch) + + for i, batch in enumerate(cached_loader): + assert batch == cache[i] + + cached_loader.clear() + assert len(cached_loader._cache) == 0 + + +def test_cached_loader_transform(): + x = ops.randn(14, 16) + edge_index = ms.Tensor([ + [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], + [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], + ]) + + loader = NeighborLoader( + Graph(x=x, edge_index=edge_index), + num_neighbors=[2], + batch_size=10, + shuffle=False, + ) + cached_loader = CachedLoader( + loader, + transform=lambda batch: batch.edge_index, + ) + + assert len(cached_loader) == len(loader) + assert len(cached_loader._cache) == 0 + + cache = [] + for i, batch in enumerate(cached_loader): + assert len(cached_loader._cache) == i + 1 + assert isinstance(batch, Tensor) + assert batch.dim() == 2 and batch.shape[0] == 2 + + cache.append(batch) + + for i, batch in enumerate(cached_loader): + assert ops.equal(batch, cache[i]).all() diff --git a/tests/graph/loader/test_cluster.py b/tests/graph/loader/test_cluster.py new file mode 100644 index 000000000..1e9b22785 --- /dev/null +++ b/tests/graph/loader/test_cluster.py @@ -0,0 +1,204 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import ClusterData, ClusterLoader +from mindscience.sharker.testing import onlyFullTest, onlyOnline, withMETIS +from mindscience.sharker.utils import sort_edge_index +from mindscience.sharker.sparse import Layout + + +@withMETIS +def test_cluster_gcn(): + adj = ms.Tensor([ + [1, 1, 1, 0, 1, 0], + [1, 1, 0, 1, 0, 1], + [1, 0, 1, 0, 1, 0], + [0, 1, 0, 1, 0, 1], + [1, 0, 1, 0, 1, 0], + [0, 1, 0, 1, 0, 1], + ]) + + x = ms.Tensor([ + [0.0, 0.0], + [1.0, 1.0], + [2.0, 2.0], + [3.0, 3.0], + [4.0, 4.0], + [5.0, 5.0], + ]) + edge_index = adj.nonzero().t() + edge_attr = ops.arange(edge_index.shape[1]) + n_id = ops.arange(6) + data = Graph(x=x, n_id=n_id, edge_index=edge_index, edge_attr=edge_attr) + data.num_nodes = 6 + + cluster_data = ClusterData(data, num_parts=2, log=False) + + partition = cluster_data._partition( + edge_index, cluster=ms.Tensor([0, 1, 0, 1, 0, 1])) + assert partition.partptr.tolist() == [0, 3, 6] + assert partition.node_perm.tolist() == [0, 2, 4, 1, 3, 5] + assert partition.edge_perm.tolist() == [ + 0, 2, 3, 1, 8, 9, 10, 14, 15, 16, 4, 5, 6, 7, 11, 12, 13, 17, 18, 19 + ] + + assert cluster_data.partition.partptr.tolist() == [0, 3, 6] + assert ops.equal( + cluster_data.partition.node_perm.sort()[0], + ops.arange(data.num_nodes), + ).all() + assert ops.equal( + cluster_data.partition.edge_perm.sort()[0], + ops.arange(data.num_edges), + ).all() + + out = cluster_data[0] + expected = data.subgraph(out.n_id) + out.validate() + assert out.num_nodes == 3 + assert out.n_id.shape == (3, ) + assert ops.equal(out.x, expected.x).all() + tmp = sort_edge_index(expected.edge_index, expected.edge_attr) + assert ops.equal(out.edge_index, tmp[0]).all() + assert ops.equal(out.edge_attr, tmp[1]).all() + + out = cluster_data[1] + out.validate() + assert out.num_nodes == 3 + assert out.n_id.shape == (3, ) + expected = data.subgraph(out.n_id) + assert ops.equal(out.x, expected.x).all() + tmp = sort_edge_index(expected.edge_index, expected.edge_attr) + assert ops.equal(out.edge_index, tmp[0]).all() + assert ops.equal(out.edge_attr, tmp[1]).all() + + loader = ClusterLoader(cluster_data, batch_size=1) + iterator = iter(loader) + + out = next(iterator) + out.validate() + assert out.num_nodes == 3 + assert out.n_id.shape == (3, ) + expected = data.subgraph(out.n_id) + assert ops.equal(out.x, expected.x).all() + tmp = sort_edge_index(expected.edge_index, expected.edge_attr) + assert ops.equal(out.edge_index, tmp[0]).all() + assert ops.equal(out.edge_attr, tmp[1]).all() + + out = next(iterator) + out.validate() + assert out.num_nodes == 3 + assert out.n_id.shape == (3, ) + expected = data.subgraph(out.n_id) + assert ops.equal(out.x, expected.x).all() + tmp = sort_edge_index(expected.edge_index, expected.edge_attr) + assert ops.equal(out.edge_index, tmp[0]).all() + assert ops.equal(out.edge_attr, tmp[1]).all() + + loader = ClusterLoader(cluster_data, batch_size=2, shuffle=False) + out = next(iter(loader)) + out.validate() + assert out.num_nodes == 6 + assert out.n_id.shape == (6, ) + expected = data.subgraph(out.n_id) + assert ops.equal(out.x, expected.x).all() + tmp = sort_edge_index(expected.edge_index, expected.edge_attr) + assert ops.equal(out.edge_index, tmp[0]).all() + assert ops.equal(out.edge_attr, tmp[1]).all() + + +@withMETIS +def test_keep_inter_cluster_edges(): + adj = ms.Tensor([ + [1, 1, 1, 0, 1, 0], + [1, 1, 0, 1, 0, 1], + [1, 0, 1, 0, 1, 0], + [0, 1, 0, 1, 0, 1], + [1, 0, 1, 0, 1, 0], + [0, 1, 0, 1, 0, 1], + ]) + + x = ms.Tensor([ + [0.0, 0.0], + [1.0, 1.0], + [2.0, 2.0], + [3.0, 3.0], + [4.0, 4.0], + [5.0, 5.0], + ]) + edge_index = adj.nonzero().t() + edge_attr = ops.arange(edge_index.shape[1]) + data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr) + data.num_nodes = 6 + + cluster_data = ClusterData(data, num_parts=2, log=False, + keep_inter_cluster_edges=True) + + data = cluster_data[0] + assert data.edge_index[0].min() == 0 + assert data.edge_index[0].max() == 2 + assert data.edge_index[1].min() == 0 + assert data.edge_index[1].max() > 2 + assert data.edge_index.shape[1] == data.edge_attr.shape[0] + + data = cluster_data[1] + assert data.edge_index[0].min() == 0 + assert data.edge_index[0].max() == 2 + assert data.edge_index[1].min() == 0 + assert data.edge_index[1].max() > 2 + assert data.edge_index.shape[1] == data.edge_attr.shape[0] + + +@withMETIS +@onlyOnline +@onlyFullTest +@pytest.mark.parametrize('sparse_format', [Layout.CSR, Layout.CSC]) +def test_cluster_gcn_correctness(get_dataset, sparse_format): + dataset = get_dataset('Cora') + data = dataset[0].copy() + data.n_id = ops.arange(data.num_nodes) + cluster_data = ClusterData( + data, + num_parts=10, + log=False, + sparse_format=sparse_format, + ) + loader = ClusterLoader(cluster_data, batch_size=3, shuffle=False) + + for batch1 in loader: + batch1.validate() + batch2 = data.subgraph(batch1.n_id) + assert batch1.num_nodes == batch2.num_nodes + assert batch1.num_edges == batch2.num_edges + assert ops.equal(batch1.x, batch2.x).all() + assert ops.equal( + batch1.edge_index, + sort_edge_index( + batch2.edge_index, + sort_by_row=sparse_format == Layout.CSR, + ), + ).all() + + +# if __name__ == '__main__': +# import argparse + +# from ogb.nodeproppred import PygNodePropPredDataset +# from tqdm import tqdm + +# parser = argparse.ArgumentParser() +# parser.add_argument('--num_workers', type=int, default=0) +# args = parser.parse_args() + +# data = PygNodePropPredDataset('ogbn-products', root='/tmp/ogb')[0] + +# loader = ClusterLoader( +# ClusterData(data, num_parts=15_000, save_dir='/tmp/ogb/ogbn_products'), +# batch_size=32, +# shuffle=True, +# num_workers=args.num_workers, +# ) + +# for batch in tqdm(loader): +# pass diff --git a/tests/graph/loader/test_common.py b/tests/graph/loader/test_common.py new file mode 100644 index 000000000..60fcd39a4 --- /dev/null +++ b/tests/graph/loader/test_common.py @@ -0,0 +1,161 @@ +import math +import random + +import mindspore as ms +import mindspore.dataset as ds +import numpy as np +from mindscience.sharker.data import Graph +from mindscience.sharker.loader.common import Dataloader, Collater +from mindscience.sharker.datasets import QM9 + + +def test_qm9_dataset_graph(): + dataset = QM9(root="~/db/") + dataset_length = len(dataset) + randomSampler = ds.RandomSampler() + + # test qm9 + Dataset + Graph + drop_remainder=False + randomSampler + random batch_size + batch_size = random.randint(1, dataset_length - 1) + newloader = Dataloader(dataset, sampler=randomSampler, shuffle=None) + newloader = newloader.batch(batch_size, collate_fn=Collater, drop_remainder=False) + index = 0 + for data in newloader: + assert isinstance(data, Graph) is True + index = index + 1 + assert dataset_length - (index - 1) * batch_size == len(data.name) + assert index == math.ceil(dataset_length / batch_size) + + # test qm9 + list + Graph + drop_remainder=True + randomSampler + random batch_size + dataset_list = [graph for graph in dataset] + batch_size = random.randint(1, dataset_length - 1) + newloader = Dataloader(dataset_list, sampler=randomSampler, shuffle=None) + newloader = newloader.batch(batch_size, drop_remainder=True) + index = 0 + for data in newloader: + assert isinstance(data, Graph) is True + assert len(data.name) == batch_size + index = index + 1 + assert index == math.floor(dataset_length / batch_size) + + # test qm9 + Dataset + Graph + drop_remainder=False + randomSampler + no .batch + newloader = Dataloader(dataset, sampler=randomSampler, shuffle=None) + index = 0 + for data in newloader: + assert isinstance(data, Graph) + index = index + 1 + assert index == dataset_length + + # test qm9 + list + Graph + drop_remainder=True + randomSampler + no .batch + newloader = Dataloader(dataset_list, sampler=randomSampler, shuffle=None) + index = 0 + for data in newloader: + assert isinstance(data, Graph) + index = index + 1 + assert index == dataset_length + + +def test_custom_class(): + class MyMapDataset(): + def __init__(self): + super(MyMapDataset).__init__() + self.data = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + + def __getitem__(self, index): + return self.data[index] + + def __len__(self): + return len(self.data) + + dataset = MyMapDataset() + dataset_length = len(dataset) + randomSampler = ds.RandomSampler() + + # custom iterable class + drop_remainder=True + randomSampler + random batch_size + batch_size = random.randint(1, dataset_length - 1) + newloader = Dataloader(dataset, shuffle=None, column_names=["data"], sampler=randomSampler) + newloader = newloader.batch(batch_size, drop_remainder=True) + index = 0 + for data in newloader: + assert isinstance(data, list) + assert isinstance(data[0], ms.Tensor) + assert data[0].shape[0] == batch_size + index = index + 1 + assert index == math.floor(dataset_length / batch_size) + + # custom iterable class + drop_remainder=True + nosampler + shuffle=False + batch_size=1 + batch_size = 1 + newloader = Dataloader(dataset, shuffle=False, column_names=["data"]) + newloader = newloader.batch(batch_size, drop_remainder=True) + index = 0 + + for data in newloader: + assert isinstance(data, list) + assert isinstance(data[0], ms.Tensor) + assert data[0].shape[0] == batch_size + assert int(data[0]) == index + index = index + 1 + assert index == math.floor(dataset_length / batch_size) + + # custom iterable class + drop_remainder=True + nosampler + shuffle=False + no .batch + newloader = Dataloader(dataset, shuffle=False, column_names=["data"]) + index = 0 + for data in newloader: + assert isinstance(data, list) + assert isinstance(data[0], ms.Tensor) + assert data[0].shape == () + assert int(data[0]) == index + index = index + 1 + assert index == dataset_length + + +def test_callable(): + def generator_multi_column(): + for i in range(64): + yield np.array([i]), np.array([[i, i + 1], [i + 2, i + 3]]) + + # custom callable function + batch_size=2 + drop_remainder=False + batch_size = 2 + dataset = Dataloader(source=generator_multi_column, column_names=["col1", "col2"]) + dataset = dataset.batch(batch_size) + index = 0 + for data in dataset: + assert isinstance(data[0], ms.Tensor) + assert isinstance(data[1], ms.Tensor) + assert len(data) == 2 + assert data[0].shape == (2, 1) + assert data[1].shape == (2, 2, 2) + index = index + 1 + + # custom callable function + drop_remainder=False + no .batch + dataset = Dataloader(source=generator_multi_column, column_names=["col1", "col2"]) + index = 0 + for data in dataset: + assert isinstance(data[0], ms.Tensor) + assert isinstance(data[1], ms.Tensor) + assert len(data) == 2 + assert data[0].shape == (1,) + assert data[1].shape == (2, 2,) + index = index + 1 + + +def test_list(): + input = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] + # custom list + drop_remainder=True + batch_size=3 + batch_size = 3 + dataset = Dataloader(source=input, column_names=["col1"]) + dataset = dataset.batch(batch_size, drop_remainder=True) + index = 0 + for data in dataset: + assert isinstance(data[0], ms.Tensor) + assert len(data) == 1 + assert data[0].shape == (3,) + index = index + 1 + + # custom list + drop_remainder=True + no .batch + dataset = Dataloader(source=input, column_names=["col1"]) + index = 0 + for data in dataset: + assert isinstance(data[0], ms.Tensor) + assert len(data) == 1 + assert data[0].shape == () + index = index + 1 diff --git a/tests/graph/loader/test_dataloader.py b/tests/graph/loader/test_dataloader.py new file mode 100644 index 000000000..92d216952 --- /dev/null +++ b/tests/graph/loader/test_dataloader.py @@ -0,0 +1,282 @@ +import multiprocessing +import sys +from collections import namedtuple + +import pytest +import mindspore as ms +from mindspore import COOTensor, ops +from mindscience.sharker import EdgeIndex +from mindscience.sharker.data import Graph, HeteroGraph, OnDiskDataset +from mindscience.sharker.loader import DataLoader +from mindscience.sharker.testing import ( + get_random_edge_index, + onlyLinux, +) + +with_mp = sys.platform not in ['win32'] +num_workers_list = [0, 2] if with_mp else [0] + +if sys.platform == 'darwin': + multiprocessing.set_start_method('spawn') + + +@pytest.mark.parametrize('num_workers', num_workers_list) +def test_dataloader(num_workers): + if num_workers > 0: + return + + x = ms.Tensor([[1.0], [1.0], [1.0]]) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + face = ms.Tensor([[0], [1], [2]]) + y = 2. + z = ms.Tensor(0.) + name = 'data' + + data = Graph(x=x, edge_index=edge_index, y=y, z=z, name=name) + assert str(data) == ("Graph(x=[3, 1], edge_index=[2, 4], y=2.0, z=0.0, " + "name='data')") + data.face = face + + loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False) + assert len(loader) == 2 + + for batch in loader: + # assert batch.x.device == device + # assert batch.edge_index.device == device + # assert batch.z.device == device + assert batch.num_graphs == len(batch) == 2 + assert batch.batch.tolist() == [0, 0, 0, 1, 1, 1] + assert batch.ptr.tolist() == [0, 3, 6] + assert batch.x.tolist() == [[1], [1], [1], [1], [1], [1]] + assert batch.edge_index.tolist() == [[0, 1, 1, 2, 3, 4, 4, 5], + [1, 0, 2, 1, 4, 3, 5, 4]] + assert batch.y.tolist() == [2.0, 2.0] + assert batch.z.tolist() == [0.0, 0.0] + assert batch.name == ['data', 'data'] + assert batch.face.tolist() == [[0, 3], [1, 4], [2, 5]] + + for store in batch.stores: + assert id(batch) == id(store._parent()) + + loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False, + follow_batch=['edge_index'], collate_fn=None) + assert len(loader) == 2 + + for batch in loader: + assert batch.num_graphs == len(batch) == 2 + assert batch.edge_index_batch.tolist() == [0, 0, 0, 0, 1, 1, 1, 1] + + +@onlyLinux +@pytest.mark.parametrize('num_workers', num_workers_list) +def test_dataloader_on_disk_dataset(tmp_path, num_workers): + dataset = OnDiskDataset(tmp_path) + data1 = Graph(x=ops.randn(3, 8)) + data2 = Graph(x=ops.randn(4, 8)) + dataset.extend([data1, data2]) + + loader = DataLoader(dataset, batch_size=2) + assert len(loader) == 1 + batch = next(iter(loader)) + assert batch.num_nodes == 7 + assert ops.equal(batch.x, ops.cat(([data1.x, data2.x]), axis=0)).all() + assert batch.batch.tolist() == [0, 0, 0, 1, 1, 1, 1] + + dataset.close() + + +def test_dataloader_fallbacks(): + # Test inputs of type List[Tensor]: + data_list = [ops.ones(3) for _ in range(4)] + batch = next(iter(DataLoader(data_list, batch_size=4))) + assert ops.equal(batch, ops.ones([4, 3])).all() + + # Test inputs of type List[float]: + data_list = [1.0, 1.0, 1.0, 1.0] + batch = next(iter(DataLoader(data_list, batch_size=4))) + assert ops.equal(batch, ops.ones(4)).all() + + # Test inputs of type List[int]: + data_list = [1, 1, 1, 1] + batch = next(iter(DataLoader(data_list, batch_size=4))) + assert ops.equal(batch, ops.ones(4, dtype=ms.int64)).all() + + # Test inputs of type List[str]: + data_list = ['test'] * 4 + batch = next(iter(DataLoader(data_list, batch_size=4))) + assert batch == data_list + + # Test inputs of type List[Mapping]: + data_list = [{'x': ops.ones(3), 'y': 1}] * 4 + batch = next(iter(DataLoader(data_list, batch_size=4))) + assert ops.equal(batch['x'], ops.ones([4, 3])).all() + assert ops.equal(batch['y'], ops.ones(4, dtype=ms.int64)).all() + + # Test inputs of type List[Tuple]: + DataTuple = namedtuple('DataTuple', 'x y') + data_list = [DataTuple(0.0, 1)] * 4 + batch = next(iter(DataLoader(data_list, batch_size=4))) + assert ops.equal(batch.x, ops.zeros(4)).all() + assert ops.equal(batch[1], ops.ones(4, dtype=ms.int64)).all() + + # Test inputs of type List[Sequence]: + data_list = [[0.0, 1]] * 4 + batch = next(iter(DataLoader(data_list, batch_size=4))) + assert ops.equal(batch[0], ops.zeros(4)).all() + assert ops.equal(batch[1], ops.ones(4, dtype=ms.int64)).all() + + # Test that inputs of unsupported types raise an error: + class DummyClass: + pass + + with pytest.raises(TypeError): + data_list = [DummyClass()] * 4 + next(iter(DataLoader(data_list, batch_size=4))) + + +# @pytest.mark.skipif(not with_mp, reason='Multi-processing not available') +# def test_multiprocessing(): +# queue = ops.multiprocessing.Manager().Queue() +# data = Graph(x=ops.randn(5, 16)) +# data_list = [data, data, data, data] +# loader = DataLoader(data_list, batch_size=2) +# for batch in loader: +# queue.put(batch) + +# batch = queue.get() +# assert batch.num_graphs == len(batch) == 2 + +# batch = queue.get() +# assert batch.num_graphs == len(batch) == 2 + + +# def test_pin_memory(): +# x = ops.randn(3, 16) +# edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) +# data = Graph(x=x, edge_index=edge_index) + +# loader = DataLoader([data] * 4, batch_size=2, pin_memory=True) +# for batch in loader: +# assert batch.x.is_pinned() or not torch.cuda.is_available() +# assert batch.edge_index.is_pinned() or not torch.cuda.is_available() + + +@pytest.mark.parametrize('num_workers', num_workers_list) +def test_heterogeneous_dataloader(num_workers): + data = HeteroGraph() + data['p'].x = ops.randn(100, 128) + data['a'].x = ops.randn(200, 128) + data['p', 'a'].edge_index = get_random_edge_index(100, 200, 500) + data['p'].edge_attr = ops.randn(500, 32) + data['a', 'p'].edge_index = get_random_edge_index(200, 100, 400) + data['a', 'p'].edge_attr = ops.randn(400, 32) + + loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False) + assert len(loader) == 2 + + for batch in loader: + assert batch.num_graphs == len(batch) == 2 + assert batch.num_nodes == 600 + + for store in batch.stores: + assert id(batch) == id(store._parent()) + + +@pytest.mark.parametrize('num_workers', num_workers_list) +@pytest.mark.parametrize('sort_order', [None, 'row', 'col']) +def test_edge_index_dataloader(num_workers, sort_order): + if sort_order == 'col': + edge_index = [[1, 0, 2, 1], [0, 1, 1, 2]] + else: + edge_index = [[0, 1, 1, 2], [1, 0, 2, 1]] + + edge_index = EdgeIndex( + edge_index, + sparse_shape=(3, 3), + sort_order=sort_order, + is_undirected=True, + ) + data = Graph(edge_index=edge_index) + assert data.num_nodes == 3 + + loader = DataLoader( + [data, data, data, data], + batch_size=2 + ) + assert len(loader) == 2 + + for batch in loader: + assert isinstance(batch.edge_index, EdgeIndex) + assert batch.edge_index.sparse_shape == (6, 6) + assert batch.edge_index.sort_order == sort_order + assert batch.edge_index.is_undirected + + +# @withPackage('torch_frame') +# def test_dataloader_tensor_frame(): +# tf = get_random_tensor_frame(num_rows=10) +# loader = DataLoader([tf, tf, tf, tf], batch_size=2, shuffle=False) +# assert len(loader) == 2 + +# for batch in loader: +# assert batch.num_rows == 20 + +# data = Graph(tf=tf, edge_index=get_random_edge_index(10, 10, 20)) +# loader = DataLoader([data, data, data, data], batch_size=2, shuffle=False) +# assert len(loader) == 2 + +# for batch in loader: +# assert batch.num_graphs == len(batch) == 2 +# assert batch.num_nodes == 20 +# assert batch.tf.num_rows == 20 +# assert batch.edge_index.max() >= 10 + + +def test_dataloader_sparse(): + adj_t = COOTensor( + indices=ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]).T, + values=ops.randn(4), + shape=(3, 3), + ) + data = Graph(adj_t=adj_t) + + loader = DataLoader([data, data], batch_size=2) + for batch in loader: + assert batch.adj_t.shape == (6, 6) + + +if __name__ == '__main__': + import argparse + import time + + from mindscience.sharker.datasets import QM9 + + parser = argparse.ArgumentParser() + parser.add_argument('--num_workers', type=int, default=0) + args = parser.parse_args() + + kwargs = dict(batch_size=128, shuffle=test_heterogeneous_dataloader) + + in_memory_dataset = QM9('/tmp/QM9') + loader = DataLoader(in_memory_dataset, **kwargs) + + print('In-Memory Dataset:') + for _ in range(2): + print(f'Start loading {len(loader)} mini-batches ... ', end='') + t = time.perf_counter() + for batch in loader: + pass + print(f'Done! [{time.perf_counter() - t:.4f}s]') + + on_disk_dataset = in_memory_dataset.to_on_disk_dataset() + loader = DataLoader(on_disk_dataset, **kwargs) + + print('On-Disk Dataset:') + for _ in range(2): + print(f'Start loading {len(loader)} mini-batches ... ', end='') + t = time.perf_counter() + for batch in loader: + pass + print(f'Done! [{time.perf_counter() - t:.4f}s]') + + on_disk_dataset.close() diff --git a/tests/graph/loader/test_dynamic_batch_sampler.py b/tests/graph/loader/test_dynamic_batch_sampler.py new file mode 100644 index 000000000..b51459586 --- /dev/null +++ b/tests/graph/loader/test_dynamic_batch_sampler.py @@ -0,0 +1,37 @@ +from typing import List + +import pytest +import mindspore as ms +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import DataLoader, DynamicBatchSampler + + +def test_dataloader_with_dynamic_batches(): + data_list: List[Graph] = [] + for num_nodes in range(100, 110): + data_list.append(Graph(num_nodes=num_nodes)) + + ms.set_seed(12345) + batch_sampler = DynamicBatchSampler(data_list, 300, shuffle=True) + loader = DataLoader(data_list, sampler=batch_sampler) + + num_nodes_total = 0 + for data in loader: + assert data.num_nodes <= 300 + num_nodes_total += data.num_nodes + assert num_nodes_total == 1045 + + # Test skipping + data_list = [Graph(num_nodes=400)] + data_list + batch_sampler = DynamicBatchSampler(data_list, 300, skip_too_big=True, + num_steps=2) + loader = DataLoader(data_list, sampler=batch_sampler) + + num_nodes_total = 0 + for data in loader: + num_nodes_total += data.num_nodes + assert num_nodes_total == 404 + + with pytest.raises(ValueError, match="length of 'DynamicBatchSampler'"): + len(DynamicBatchSampler(data_list, max_num=300)) + assert len(DynamicBatchSampler(data_list, max_num=300, num_steps=2)) == 2 diff --git a/tests/graph/loader/test_graph_saint.py b/tests/graph/loader/test_graph_saint.py new file mode 100644 index 000000000..3014c2319 --- /dev/null +++ b/tests/graph/loader/test_graph_saint.py @@ -0,0 +1,82 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import ( + GraphSAINTEdgeSampler, + GraphSAINTNodeSampler, + GraphSAINTRandomWalkSampler, +) + + +def test_graph_saint(): + adj = ms.Tensor([ + [+1, +2, +3, +0, +4, +0], + [+5, +6, +0, +7, +0, +8], + [+9, +0, 10, +0, 11, +0], + [+0, 12, +0, 13, +0, 14], + [15, +0, 16, +0, 17, +0], + [+0, 18, +0, 19, +0, 20], + ]) + + edge_index = adj.nonzero().t() + edge_id = adj[edge_index[0], edge_index[1]] + x = ms.Tensor([ + [0.0, 0.0], + [1.0, 1.0], + [2.0, 2.0], + [3.0, 3.0], + [4.0, 4.0], + [5.0, 5.0], + ]) + n_id = ops.arange(6) + data = Graph(edge_index=edge_index, x=x, n_id=n_id, edge_id=edge_id, + num_nodes=6) + + loader = GraphSAINTNodeSampler(data, batch_size=3, num_steps=4, + sample_coverage=10, log=False) + + assert len(loader) == 4 + for sample in loader: + assert sample.num_nodes <= data.num_nodes + assert sample.n_id.min() >= 0 and sample.n_id.max() < 6 + assert sample.num_nodes == sample.n_id.numel() + assert sample.x.tolist() == x[sample.n_id].tolist() + assert sample.edge_index.min() >= 0 + assert sample.edge_index.max() < sample.num_nodes + assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21 + assert sample.edge_id.numel() == sample.num_edges + assert sample.node_norm.numel() == sample.num_nodes + assert sample.edge_norm.numel() == sample.num_edges + + loader = GraphSAINTEdgeSampler(data, batch_size=2, num_steps=4, + sample_coverage=10, log=False) + + assert len(loader) == 4 + for sample in loader: + assert sample.num_nodes <= data.num_nodes + assert sample.n_id.min() >= 0 and sample.n_id.max() < 6 + assert sample.num_nodes == sample.n_id.numel() + assert sample.x.tolist() == x[sample.n_id].tolist() + assert sample.edge_index.min() >= 0 + assert sample.edge_index.max() < sample.num_nodes + assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21 + assert sample.edge_id.numel() == sample.num_edges + assert sample.node_norm.numel() == sample.num_nodes + assert sample.edge_norm.numel() == sample.num_edges + + loader = GraphSAINTRandomWalkSampler(data, batch_size=2, walk_length=1, + num_steps=4, sample_coverage=10, + log=False) + + assert len(loader) == 4 + for sample in loader: + assert sample.num_nodes <= data.num_nodes + assert sample.n_id.min() >= 0 and sample.n_id.max() < 6 + assert sample.num_nodes == sample.n_id.numel() + assert sample.x.tolist() == x[sample.n_id].tolist() + assert sample.edge_index.min() >= 0 + assert sample.edge_index.max() < sample.num_nodes + assert sample.edge_id.min() >= 1 and sample.edge_id.max() <= 21 + assert sample.edge_id.numel() == sample.num_edges + assert sample.node_norm.numel() == sample.num_nodes + assert sample.edge_norm.numel() == sample.num_edges diff --git a/tests/graph/loader/test_hgt_loader.py b/tests/graph/loader/test_hgt_loader.py new file mode 100644 index 000000000..b101aaab0 --- /dev/null +++ b/tests/graph/loader/test_hgt_loader.py @@ -0,0 +1,214 @@ +import numpy as np +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker.data import HeteroGraph +from mindscience.sharker.loader import HGTLoader +from mindscience.sharker.nn import GraphConv +from mindscience.sharker.testing import ( + get_random_edge_index, + onlyOnline, +) +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import k_hop_subgraph + + +def is_subset(subedge_index, edge_index, src_idx, dst_idx): + num_nodes = int(edge_index.max()) + 1 + idx = num_nodes * edge_index[0] + edge_index[1] + subidx = num_nodes * src_idx[subedge_index[0]] + dst_idx[subedge_index[1]] + mask = ms.Tensor.from_numpy(np.isin(subidx, idx)) + return int(mask.sum()) == mask.numel() + + +def test_hgt_loader(): + ms.set_seed(12345) + + data = HeteroGraph() + + data['paper'].x = ops.arange(100) + data['author'].x = ops.arange(100, 300) + + data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500) + data['paper', 'paper'].edge_attr = ops.arange(500) + data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) + data['paper', 'author'].edge_attr = ops.arange(500, 1500) + data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) + data['author', 'paper'].edge_attr = ops.arange(1500, 2500) + + r1, c1 = data['paper', 'paper'].edge_index + r2, c2 = data['paper', 'author'].edge_index + ms.Tensor([[0], [100]]) + r3, c3 = data['author', 'paper'].edge_index + ms.Tensor([[100], [0]]) + full_adj = SparseTensor( + row=ops.cat([r1, r2, r3]), + col=ops.cat([c1, c2, c3]), + value=ops.arange(2500), + ) + + batch_size = 20 + loader = HGTLoader(data, num_samples=[5] * 4, batch_size=batch_size, + input_nodes='paper') + assert str(loader) == 'HGTLoader()' + assert len(loader) == (100 + batch_size - 1) // batch_size + + for batch in loader: + assert isinstance(batch, HeteroGraph) + + # Test node and types: + assert set(batch.node_types) == {'paper', 'author'} + assert set(batch.edge_types) == set(data.edge_types) + + assert len(batch['paper']) == 4 + assert batch['paper'].n_id.shape == (batch['paper'].num_nodes, ) + assert batch['paper'].x.shape == (40, ) # 20 + 4 * 5 + assert batch['paper'].input_id.numel() == batch_size + assert batch['paper'].batch_size == batch_size + assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 + + assert len(batch['author']) == 2 + assert batch['author'].n_id.shape == (batch['author'].num_nodes, ) + assert batch['author'].x.shape == (20, ) # 4 * 5 + assert batch['author'].x.min() >= 100 and batch['author'].x.max() < 300 + + # Test edge type selection: + assert set(batch.edge_types) == {('paper', 'to', 'paper'), + ('paper', 'to', 'author'), + ('author', 'to', 'paper')} + + assert len(batch['paper', 'paper']) == 3 + num_edges = batch['paper', 'paper'].num_edges + assert batch['paper', 'paper'].e_id.shape == (num_edges, ) + row, col = batch['paper', 'paper'].edge_index + value = batch['paper', 'paper'].edge_attr + adj = full_adj[batch['paper'].x, batch['paper'].x] + assert row.min() >= 0 and row.max() < 40 + assert col.min() >= 0 and col.max() < 40 + assert value.min() >= 0 and value.max() < 500 + assert adj.nnz() == row.shape[0] + assert ops.isclose(ms.numpy.unique(row), ms.numpy.unique(adj.storage.row())).all() + assert ops.isclose(ms.numpy.unique(col), ms.numpy.unique(adj.storage.col())).all() + assert ops.isclose(ms.numpy.unique(value), ms.numpy.unique(adj.storage.value())).all() + + assert is_subset(batch['paper', 'paper'].edge_index, + data['paper', 'paper'].edge_index, batch['paper'].x, + batch['paper'].x) + + assert len(batch['paper', 'author']) == 3 + num_edges = batch['paper', 'author'].num_edges + assert batch['paper', 'author'].e_id.shape == (num_edges, ) + row, col = batch['paper', 'author'].edge_index + value = batch['paper', 'author'].edge_attr + adj = full_adj[batch['paper'].x, batch['author'].x] + assert row.min() >= 0 and row.max() < 40 + assert col.min() >= 0 and col.max() < 20 + assert value.min() >= 500 and value.max() < 1500 + assert adj.nnz() == row.shape[0] + assert ops.isclose(ms.numpy.unique(row), ms.numpy.unique(adj.storage.row())).all() + assert ops.isclose(ms.numpy.unique(col), ms.numpy.unique(adj.storage.col())).all() + assert ops.isclose(ms.numpy.unique(value), ms.numpy.unique(adj.storage.value())).all() + + assert is_subset(batch['paper', 'author'].edge_index, + data['paper', 'author'].edge_index, batch['paper'].x, + batch['author'].x - 100) + + assert len(batch['author', 'paper']) == 3 + num_edges = batch['author', 'paper'].num_edges + assert batch['author', 'paper'].e_id.shape == (num_edges, ) + row, col = batch['author', 'paper'].edge_index + value = batch['author', 'paper'].edge_attr + adj = full_adj[batch['author'].x, batch['paper'].x] + assert row.min() >= 0 and row.max() < 20 + assert col.min() >= 0 and col.max() < 40 + assert value.min() >= 1500 and value.max() < 2500 + assert adj.nnz() == row.shape[0] + assert ops.isclose(ms.numpy.unique(row), ms.numpy.unique(adj.storage.row())).all() + assert ops.isclose(ms.numpy.unique(col), ms.numpy.unique(adj.storage.col())).all() + assert ops.isclose(ms.numpy.unique(value), ms.numpy.unique(adj.storage.value())).all() + + assert is_subset(batch['author', 'paper'].edge_index, + data['author', 'paper'].edge_index, + batch['author'].x - 100, batch['paper'].x) + + # Test for isolated nodes (there shouldn't exist any): + n_id = ops.cat([batch['paper'].x, batch['author'].x]) + row, col, _ = full_adj[n_id, n_id].coo() + assert ms.numpy.unique(ops.cat([row, col])).numel() >= 59 + + +# @onlyOnline +# def test_hgt_loader_on_cora(get_dataset): +# dataset = get_dataset(name='Cora') +# data = dataset[0] +# data.edge_weight = ops.rand(data.num_edges) + +# hetero_data = HeteroGraph() +# hetero_data['paper'].x = data.x +# hetero_data['paper'].n_id = ops.arange(data.num_nodes) +# hetero_data['paper', 'paper'].edge_index = data.edge_index +# hetero_data['paper', 'paper'].edge_weight = data.edge_weight + +# split_idx = ops.arange(5, 8) + +# # Sample the complete two-hop neighborhood: +# loader = HGTLoader(hetero_data, num_samples=[data.num_nodes] * 2, +# batch_size=split_idx.numel(), +# input_nodes=('paper', split_idx)) +# assert len(loader) == 1 + +# hetero_batch = next(iter(loader)) +# batch_size = hetero_batch['paper'].batch_size + +# n_id, _, _, e_mask = k_hop_subgraph(split_idx, num_hops=2, +# edge_index=data.edge_index, +# num_nodes=data.num_nodes) + +# n_id = n_id.sort()[0] +# assert n_id.tolist() == hetero_batch['paper'].n_id.sort()[0].tolist() +# assert hetero_batch['paper', 'paper'].num_edges == int(e_mask.sum()) + +# class GNN(nn.Cell): +# def __init__(self, in_channels, hidden_channels, out_channels): +# super().__init__() +# self.conv1 = GraphConv(in_channels, hidden_channels) +# self.conv2 = GraphConv(hidden_channels, out_channels) + +# def construct(self, x, edge_index, edge_weight): +# x = ops.relu(self.conv1(x, edge_index, edge_weight)) +# x = ops.relu(self.conv2(x, edge_index, edge_weight)) +# return x + +# model = GNN(dataset.num_features, 16, dataset.num_classes) +# hetero_model = to_hetero(model, hetero_data.metadata()) + +# out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx] +# out2 = hetero_model(hetero_batch.x_dict, hetero_batch.edge_index_dict, +# hetero_batch.edge_weight_dict)['paper'][:batch_size] +# assert ops.isclose(out1, out2, atol=1e-6).all() + + +def test_hgt_loader_disconnected(): + data = HeteroGraph() + + data['paper'].x = ops.randn(10, 16) + data['author'].x = ops.randn(10, 16) + + # Paper nodes are disconnected from author nodes: + data['paper', 'paper'].edge_index = get_random_edge_index(10, 10, 15) + data['paper', 'paper'].edge_attr = ops.randn(15, 8) + data['author', 'author'].edge_index = get_random_edge_index(10, 10, 15) + data['author', 'author'].edge_attr = ops.randn(15, 8) + + loader = HGTLoader(data, num_samples=[2], batch_size=2, + input_nodes='paper') + + for batch in loader: + assert isinstance(batch, HeteroGraph) + + # Test node and edge types: + assert set(batch.node_types) == set(data.node_types) + assert set(batch.edge_types) == set(data.edge_types) + + assert batch['author'].num_nodes == 0 + assert batch['author'].x.shape == (0, 16) + assert batch['author', 'author'].num_edges == 0 + assert batch['author', 'author'].edge_index.shape == (2, 0) + assert batch['author', 'author'].edge_attr.shape == (0, 8) diff --git a/tests/graph/loader/test_ibmb_loader.py b/tests/graph/loader/test_ibmb_loader.py new file mode 100644 index 000000000..0ea5d0f07 --- /dev/null +++ b/tests/graph/loader/test_ibmb_loader.py @@ -0,0 +1,67 @@ +import pytest +from mindspore import Tensor, ops + +from mindscience.sharker import typing +from mindscience.sharker.datasets import KarateClub +from mindscience.sharker.loader.ibmb_loader import IBMBBatchLoader, IBMBNodeLoader +from mindscience.sharker.testing import withPackage +from mindscience.sharker.typing import SparseTensor + + +@withPackage('python_tsp') +@pytest.mark.parametrize( + 'use_sparse_tensor', + [False] + [True] if typing.WITH_SPARSE else []) +@pytest.mark.parametrize('kwargs', [ + dict(num_partitions=4, batch_size=1), + dict(num_partitions=8, batch_size=2), +]) +def test_ibmb_batch_loader(use_sparse_tensor, kwargs): + data = KarateClub()[0] + + loader = IBMBBatchLoader( + data, + batch_order='order', + input_nodes=ops.shuffle(ops.arange(data.num_nodes))[:20], + return_edge_index_type='adj' if use_sparse_tensor else 'edge_index', + **kwargs, + ) + assert str(loader) == 'IBMBBatchLoader()' + assert len(loader) == 4 + assert sum([batch.output_node_mask.sum() for batch in loader]) == 20 + + for batch in loader: + if use_sparse_tensor: + assert isinstance(batch.edge_index, SparseTensor) + else: + assert isinstance(batch.edge_index, Tensor) + + +@ withPackage('python_tsp', 'numba') +@ pytest.mark.parametrize( + 'use_sparse_tensor', + [False] + [True] if typing.WITH_SPARSE else []) +@ pytest.mark.parametrize('kwargs', [ + dict(num_nodes_per_batch=4, batch_size=1), + dict(num_nodes_per_batch=2, batch_size=2), +]) +def test_ibmb_node_loader(use_sparse_tensor, kwargs): + data = KarateClub()[0] + + loader = IBMBNodeLoader( + data, + batch_order='order', + input_nodes=ops.shuffle(ops.arange(data.num_nodes))[:20], + num_auxiliary_nodes=4, + return_edge_index_type='adj' if use_sparse_tensor else 'edge_index', + **kwargs, + ) + assert str(loader) == 'IBMBNodeLoader()' + assert len(loader) == 5 + assert sum([batch.output_node_mask.sum() for batch in loader]) == 20 + + for batch in loader: + if use_sparse_tensor: + assert isinstance(batch.edge_index, SparseTensor) + else: + assert isinstance(batch.edge_index, Tensor) diff --git a/tests/graph/loader/test_imbalanced_sampler.py b/tests/graph/loader/test_imbalanced_sampler.py new file mode 100644 index 000000000..1a61d01db --- /dev/null +++ b/tests/graph/loader/test_imbalanced_sampler.py @@ -0,0 +1,116 @@ +from typing import List + +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph +from mindscience.sharker.datasets import FakeDataset, FakeHeteroDataset +from mindscience.sharker.loader import ( + DataLoader, + ImbalancedSampler, + NeighborLoader, +) +# from mindscience.sharker.testing import onlyNeighborSampler + + +def test_dataloader_with_imbalanced_sampler(): + data_list: List[Graph] = [] + for _ in range(10): + data_list.append(Graph(num_nodes=10, y=0)) + for _ in range(90): + data_list.append(Graph(num_nodes=10, y=1)) + + ms.set_seed(12345) + sampler = ImbalancedSampler(data_list) + loader = DataLoader(data_list, batch_size=10, sampler=sampler) + + y = ops.cat([batch.y for batch in loader]) + + histogram = y.bincount() + prob = histogram / histogram.sum() + + assert histogram.sum() == len(data_list) + assert prob.min() > 0.4 and prob.max() < 0.6 + + # Test with label tensor as input: + ms.set_seed(12345) + sampler = ImbalancedSampler(ms.Tensor([data.y for data in data_list])) + loader = DataLoader(data_list, batch_size=10, sampler=sampler) + + assert ops.isclose(y, ops.cat([batch.y for batch in loader])).all() + + # Test with list of data objects as input where each y is a tensor: + ms.set_seed(12345) + for data in data_list: + data.y = ms.Tensor([data.y]) + sampler = ImbalancedSampler(data_list) + loader = DataLoader(data_list, batch_size=100, sampler=sampler) + + assert ops.isclose(y, ops.cat([batch.y for batch in loader])).all() + + +def test_in_memory_dataset_imbalanced_sampler(): + ms.set_seed(12345) + dataset = FakeDataset(num_graphs=100, avg_num_nodes=10, avg_degree=0, + num_channels=0, num_classes=2) + sampler = ImbalancedSampler(dataset) + loader = DataLoader(dataset, batch_size=10, sampler=sampler) + + y = ops.cat([batch.y for batch in loader]) + histogram = y.bincount() + prob = histogram / histogram.sum() + + assert histogram.sum() == len(dataset) + assert prob.min() > 0.4 and prob.max() < 0.6 + + +# @onlyNeighborSampler +def test_neighbor_loader_with_imbalanced_sampler(): + zeros = ops.zeros(10, dtype=ms.int64) + ones = ops.ones(90, dtype=ms.int64) + + y = ops.cat(([zeros, ones]), axis=0) + edge_index = ms.numpy.empty((2, 0), dtype=ms.int64) + data = Graph(edge_index=edge_index, y=y, num_nodes=y.shape[0]) + + ms.set_seed(12345) + sampler = ImbalancedSampler(data) + loader = NeighborLoader(data, batch_size=10, sampler=sampler, + num_neighbors=[-1]) + + y = ops.cat([batch.y for batch in loader]) + + histogram = y.bincount() + prob = histogram / histogram.sum() + + assert histogram.sum() == data.num_nodes + assert prob.min() > 0.4 and prob.max() < 0.6 + + # Test with label tensor as input: + ms.set_seed(12345) + sampler = ImbalancedSampler(data.y) + loader = NeighborLoader(data, batch_size=10, sampler=sampler, + num_neighbors=[-1]) + + assert ops.isclose(y, ops.cat([batch.y for batch in loader])).all() + + +# @onlyNeighborSampler +def test_hetero_neighbor_loader_with_imbalanced_sampler(): + ms.set_seed(12345) + data = FakeHeteroDataset(num_classes=2)[0] + + loader = NeighborLoader( + data, + batch_size=100, + input_nodes='v0', + num_neighbors=[-1], + sampler=ImbalancedSampler(data['v0'].y), + ) + + y = ops.cat([batch['v0'].y[:batch['v0'].batch_size] for batch in loader]) + + histogram = y.asnumpy().bincount() + prob = histogram / histogram.sum() + + assert histogram.sum() == data['v0'].num_nodes + assert prob.min() > 0.4 and prob.max() < 0.6 diff --git a/tests/graph/loader/test_link_neighbor_loader.py b/tests/graph/loader/test_link_neighbor_loader.py new file mode 100644 index 000000000..ede1ab9bd --- /dev/null +++ b/tests/graph/loader/test_link_neighbor_loader.py @@ -0,0 +1,588 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.loader import LinkNeighborLoader +from mindscience.sharker.testing import ( + MyFeatureStore, + MyGraphStore, + get_random_edge_index, + # onlyNeighborSampler, + withPackage, +) +from mindscience.sharker.sparse import Layout + + +def unique_edge_pairs(edge_index): + return set(map(tuple, edge_index.t().tolist())) + + +# @onlyNeighborSampler +@pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional']) +@pytest.mark.parametrize('neg_sampling_ratio', [None, 1.0]) +@pytest.mark.parametrize('filter_per_worker', [None, True, False]) +def test_homo_link_neighbor_loader_basic(subgraph_type, + neg_sampling_ratio, + filter_per_worker): + pos_edge_index = get_random_edge_index(50, 50, 500) + neg_edge_index = get_random_edge_index(50, 50, 500) + neg_edge_index += 50 + + input_edges = ops.cat(([pos_edge_index, neg_edge_index]), axis=-1) + edge_label = ops.cat([ + ops.ones(500), + ops.zeros(500), + ], axis=0) + + data = Graph() + + data.edge_index = pos_edge_index + data.x = ops.arange(100) + data.edge_attr = ops.arange(500) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + batch_size=20, + edge_label_index=input_edges, + edge_label=edge_label if neg_sampling_ratio is None else None, + subgraph_type=subgraph_type, + neg_sampling_ratio=neg_sampling_ratio, + shuffle=True, + filter_per_worker=filter_per_worker, + ) + + assert str(loader) == 'LinkNeighborLoader()' + assert len(loader) == 1000 / 20 + + batch = loader([0]) + assert isinstance(batch, Graph) + assert int(input_edges[0, 0]) in batch.n_id.tolist() + assert int(input_edges[1, 0]) in batch.n_id.tolist() + + for batch in loader: + assert isinstance(batch, Graph) + + assert batch.n_id.shape == (batch.num_nodes, ) + assert batch.e_id.shape == (batch.num_edges, ) + # assert batch.x.device == device + assert batch.x.shape[0] <= 100 + assert batch.x.min() >= 0 and batch.x.max() < 100 + assert batch.input_id.numel() == 20 + # assert batch.edge_index.device == device + assert batch.edge_index.min() >= 0 + assert batch.edge_index.max() < batch.num_nodes + # assert batch.edge_attr.device == device + assert batch.edge_attr.min() >= 0 + assert batch.edge_attr.max() < 500 + + if neg_sampling_ratio is None: + assert batch.edge_label_index.shape[1] == 20 + + # Assert positive samples are present in the original graph: + edge_index = unique_edge_pairs(batch.edge_index) + edge_label_index = batch.edge_label_index[:, batch.edge_label == 1] + edge_label_index = unique_edge_pairs(edge_label_index) + assert len(edge_index | edge_label_index) == len(edge_index) + + # Assert negative samples are not present in the original graph: + edge_index = unique_edge_pairs(batch.edge_index) + edge_label_index = batch.edge_label_index[:, batch.edge_label == 0] + edge_label_index = unique_edge_pairs(edge_label_index) + assert len(edge_index & edge_label_index) == 0 + + else: + assert batch.edge_label_index.shape[1] == 40 + assert ops.all(batch.edge_label[:20] == 1) + assert ops.all(batch.edge_label[20:] == 0) + + # Ensure local `edge_label_index` correctly maps to input edges. + global_edge_label_index = batch.n_id[batch.edge_label_index] + global_edge_label_index = ( + global_edge_label_index[:, batch.edge_label >= 1]) + global_edge_label_index = unique_edge_pairs(global_edge_label_index) + assert (len(global_edge_label_index & unique_edge_pairs(input_edges)) + == len(global_edge_label_index)) + + +@pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional']) +@pytest.mark.parametrize('neg_sampling_ratio', [None, 1.0]) +def test_hetero_link_neighbor_loader_basic(subgraph_type, neg_sampling_ratio): + data = HeteroGraph() + + data['paper'].x = ops.arange(100) + data['author'].x = ops.arange(100, 300) + + data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500) + data['paper', 'paper'].edge_attr = ops.arange(500) + data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) + data['paper', 'author'].edge_attr = ops.arange(500, 1500) + data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) + data['author', 'paper'].edge_attr = ops.arange(1500, 2500) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + edge_label_index=('paper', 'author'), + batch_size=20, + subgraph_type=subgraph_type, + neg_sampling_ratio=neg_sampling_ratio, + shuffle=True, + ) + + assert str(loader) == 'LinkNeighborLoader()' + assert len(loader) == 1000 / 20 + + for batch in loader: + assert isinstance(batch, HeteroGraph) + if neg_sampling_ratio is None: + # Assert only positive samples are present in the original graph: + edge_index = unique_edge_pairs(batch['paper', 'author'].edge_index) + edge_label_index = batch['paper', 'author'].edge_label_index + edge_label_index = unique_edge_pairs(edge_label_index) + assert len(edge_index | edge_label_index) == len(edge_index) + + else: + assert batch['paper', 'author'].edge_label_index.shape[1] == 40 + assert ops.all(batch['paper', 'author'].edge_label[:20] == 1) + assert ops.all(batch['paper', 'author'].edge_label[20:] == 0) + + +@pytest.mark.parametrize('subgraph_type', ['directional', 'bidirectional']) +def test_hetero_link_neighbor_loader_loop(subgraph_type): + data = HeteroGraph() + + data['paper'].x = ops.arange(100) + data['author'].x = ops.arange(100, 300) + + data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500) + data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) + data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + edge_label_index=('paper', 'paper'), + batch_size=20, + subgraph_type=subgraph_type, + ) + + for batch in loader: + assert batch['paper'].x.shape[0] <= 100 + assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 + + # Assert positive samples are present in the original graph: + edge_index = unique_edge_pairs(batch['paper', 'paper'].edge_index) + edge_label_index = batch['paper', 'paper'].edge_label_index + edge_label_index = unique_edge_pairs(edge_label_index) + assert len(edge_index | edge_label_index) == len(edge_index) + + +def test_link_neighbor_loader_edge_label(): + edge_index = get_random_edge_index(100, 100, 500) + data = Graph(edge_index=edge_index, x=ops.arange(100)) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + batch_size=10, + neg_sampling_ratio=1.0, + ) + + for batch in loader: + assert batch.edge_label.dtype == ms.float32 + assert ops.all(batch.edge_label[:10] == 1.0) + assert ops.all(batch.edge_label[10:] == 0.0) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + batch_size=10, + edge_label=ops.ones(500, dtype=ms.int64), + neg_sampling_ratio=1.0, + ) + + for batch in loader: + assert batch.edge_label.dtype == ms.int64 + assert ops.all(batch.edge_label[:10] == 1) + assert ops.all(batch.edge_label[10:] == 0) + + +@withPackage('pyg_lib') +@pytest.mark.parametrize('batch_size', [1]) +def test_temporal_homo_link_neighbor_loader(batch_size): + data = Graph( + x=ops.randn(10, 5), + edge_index=ops.randint(0, 10, (2, 123)), + time=ops.arange(10), + ) + + # Ensure that nodes exist at the time of the `edge_label_time`: + edge_label_time = ops.max( + data.time[data.edge_index[0]], + data.time[data.edge_index[1]], + ) + + loader = LinkNeighborLoader( + data, + num_neighbors=[-1], + time_attr='time', + edge_label=ops.ones(data.num_edges), + edge_label_time=edge_label_time, + batch_size=batch_size, + shuffle=True, + ) + + for batch in loader: + assert batch.edge_label_index.shape == (2, batch_size) + assert batch.edge_label_time.shape == (batch_size, ) + assert batch.edge_label.shape == (batch_size, ) + assert ops.all(batch.time <= batch.edge_label_time) + + +@withPackage('pyg_lib') +def test_temporal_hetero_link_neighbor_loader(): + data = HeteroGraph() + + data['paper'].x = ops.arange(100) + data['paper'].time = ops.arange(data['paper'].num_nodes) - 200 + data['author'].x = ops.arange(100, 300) + data['author'].time = ops.arange(data['author'].num_nodes) + + data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 500) + data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) + data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) + + with pytest.raises(ValueError, match=r"'edge_label_time' is not set"): + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + edge_label_index=('paper', 'paper'), + batch_size=32, + time_attr='time', + ) + + # With edge_time: + edge_time = ops.arange(data['paper', 'paper'].edge_index.shape[1]) + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + edge_label_index=('paper', 'paper'), + edge_label_time=edge_time, + batch_size=32, + time_attr='time', + neg_sampling_ratio=0.5, + drop_last=True, + ) + for batch in loader: + # Check if each seed edge has a different batch: + assert int(batch['paper'].batch.max()) + 1 == 32 + + author_max = batch['author'].time.max() + edge_max = batch['paper', 'paper'].edge_label_time.max() + assert edge_max >= author_max + author_min = batch['author'].time.min() + edge_min = batch['paper', 'paper'].edge_label_time.min() + assert edge_min >= author_min + + +def test_custom_hetero_link_neighbor_loader(): + data = HeteroGraph() + feature_store = MyFeatureStore() + graph_store = MyGraphStore() + + # Set up node features: + x = ops.arange(100) + data['paper'].x = x + feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None) + + x = ops.arange(100, 300) + data['author'].x = x + feature_store.put_tensor(x, group_name='author', attr_name='x', index=None) + + # Set up edge indices (GraphStore does not support `edge_attr` at the + # moment): + edge_index = get_random_edge_index(100, 100, 500) + data['paper', 'to', 'paper'].edge_index = edge_index + graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), + edge_type=('paper', 'to', 'paper'), + layout=Layout.COO, size=(100, 100)) + + edge_index = get_random_edge_index(100, 200, 1000) + data['paper', 'to', 'author'].edge_index = edge_index + graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), + edge_type=('paper', 'to', 'author'), + layout=Layout.COO, size=(100, 200)) + + edge_index = get_random_edge_index(200, 100, 1000) + data['author', 'to', 'paper'].edge_index = edge_index + graph_store.put_edge_index(edge_index=(edge_index[0], edge_index[1]), + edge_type=('author', 'to', 'paper'), + layout=Layout.COO, size=(200, 100)) + + loader1 = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + edge_label_index=('paper', 'to', 'author'), + batch_size=20, + ) + + loader2 = LinkNeighborLoader( + (feature_store, graph_store), + num_neighbors=[-1] * 2, + edge_label_index=('paper', 'to', 'author'), + batch_size=20, + ) + + assert str(loader1) == str(loader2) + + for (batch1, batch2) in zip(loader1, loader2): + # Mapped indices of neighbors may be differently sorted: + assert ops.isclose(batch1['paper'].x.sort().all()[0], + batch2['paper'].x.sort()[0]) + assert ops.isclose(batch1['author'].x.sort().all()[0], + batch2['author'].x.sort()[0]) + + # Assert that edge indices have the same size: + assert (batch1['paper', 'to', 'paper'].edge_index.shape == batch1[ + 'paper', 'to', 'paper'].edge_index.shape) + assert (batch1['paper', 'to', 'author'].edge_index.shape == batch1[ + 'paper', 'to', 'author'].edge_index.shape) + assert (batch1['author', 'to', 'paper'].edge_index.shape == batch1[ + 'author', 'to', 'paper'].edge_index.shape) + + +def test_homo_link_neighbor_loader_no_edges(): + loader = LinkNeighborLoader( + Graph(num_nodes=100), + num_neighbors=[], + batch_size=20, + edge_label_index=get_random_edge_index(100, 100, 100), + ) + + for batch in loader: + assert isinstance(batch, Graph) + assert batch.input_id.numel() == 20 + assert batch.edge_label_index.shape[1] == 20 + assert batch.num_nodes == batch.edge_label_index.unique().numel() + + +def test_hetero_link_neighbor_loader_no_edges(): + loader = LinkNeighborLoader( + HeteroGraph(paper=dict(num_nodes=100)), + num_neighbors=[], + edge_label_index=( + ('paper', 'paper'), + get_random_edge_index(100, 100, 100), + ), + batch_size=20, + ) + + for batch in loader: + assert isinstance(batch, HeteroGraph) + assert batch['paper', 'paper'].input_id.numel() == 20 + assert batch['paper', 'paper'].edge_label_index.shape[1] == 20 + assert batch['paper'].num_nodes == batch[ + 'paper', 'paper'].edge_label_index.unique().numel() + + +@withPackage('pyg_lib') +@pytest.mark.parametrize('disjoint', [False, True]) +@pytest.mark.parametrize('temporal', [False, True]) +@pytest.mark.parametrize('amount', [1, 2]) +def test_homo_link_neighbor_loader_triplet(disjoint, temporal, amount): + if not disjoint and temporal: + return + + data = Graph() + data.x = ops.arange(100) + data.edge_index = get_random_edge_index(100, 100, 400) + data.edge_label_index = get_random_edge_index(100, 100, 500) + data.edge_attr = ops.arange(data.num_edges) + + time_attr = edge_label_time = None + if temporal: + time_attr = 'time' + data.time = ops.arange(data.num_nodes) + + edge_label_time = ops.max(data.time[data.edge_label_index[0]], + data.time[data.edge_label_index[1]]) + edge_label_time = edge_label_time + 50 + + batch_size = 20 + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + batch_size=batch_size, + edge_label_index=data.edge_label_index, + edge_label_time=edge_label_time, + time_attr=time_attr, + disjoint=disjoint, + neg_sampling=dict(mode='triplet', amount=amount), + shuffle=True, + ) + + assert str(loader) == 'LinkNeighborLoader()' + assert len(loader) == 500 / batch_size + + for batch in loader: + assert isinstance(batch, Graph) + + # Check that `src_index` and `dst_pos_index` point to valid edges: + assert ops.equal(batch.x[batch.src_index], + data.edge_label_index[0, batch.input_id]).all() + assert ops.equal(batch.x[batch.dst_pos_index], + data.edge_label_index[1, batch.input_id]).all() + + # Check that `dst_neg_index` points to valid nodes in the batch: + if amount == 1: + assert batch.dst_neg_index.shape == (batch_size, ) + else: + assert batch.dst_neg_index.shape == (batch_size, amount) + assert batch.dst_neg_index.min() >= 0 + assert batch.dst_neg_index.max() < batch.num_nodes + + if disjoint: + # In disjoint mode, seed nodes should always be placed first: + assert batch.src_index.min() == 0 + assert batch.src_index.max() == batch_size - 1 + + assert batch.dst_pos_index.min() == batch_size + assert batch.dst_pos_index.max() == 2 * batch_size - 1 + + assert batch.dst_neg_index.min() == 2 * batch_size + max_seed_nodes = 2 * batch_size + batch_size * amount + assert batch.dst_neg_index.max() == max_seed_nodes - 1 + + assert batch.batch.min() == 0 + assert batch.batch.max() == batch_size - 1 + + # Check that `batch` is always increasing: + for i in range(0, max_seed_nodes, batch_size): + batch_vector = batch.batch[i:i + batch_size] + assert ops.equal(batch_vector, ops.arange(batch_size)).all() + + if temporal: + for i in range(batch_size): + assert batch.time[batch.batch == i].max() <= batch.seed_time[i] + + +@withPackage('pyg_lib') +@pytest.mark.parametrize('disjoint', [False, True]) +@pytest.mark.parametrize('temporal', [False, True]) +@pytest.mark.parametrize('amount', [1, 2]) +def test_hetero_link_neighbor_loader_triplet(disjoint, temporal, amount): + if not disjoint and temporal: + return + + data = HeteroGraph() + + data['paper'].x = ops.arange(100) + data['author'].x = ops.arange(100, 300) + + data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 400) + edge_label_index = get_random_edge_index(100, 100, 500) + data['paper', 'paper'].edge_label_index = edge_label_index + data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 1000) + data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 1000) + + time_attr = edge_label_time = None + if temporal: + time_attr = 'time' + data['paper'].time = ops.arange(data['paper'].num_nodes) + data['author'].time = ops.arange(data['author'].num_nodes) + + edge_label_time = ops.max( + data['paper'].time[data['paper', 'paper'].edge_label_index[0]], + data['paper'].time[data['paper', 'paper'].edge_label_index[1]], + ) + edge_label_time = edge_label_time + 50 + + weight = ops.rand(data['paper'].num_nodes) if not temporal else None + + batch_size = 20 + index = (('paper', 'paper'), data['paper', 'paper'].edge_label_index) + loader = LinkNeighborLoader( + data, + num_neighbors=[-1] * 2, + batch_size=batch_size, + edge_label_index=index, + edge_label_time=edge_label_time, + time_attr=time_attr, + disjoint=disjoint, + neg_sampling=dict(mode='triplet', amount=amount, weight=weight), + shuffle=True, + ) + + assert str(loader) == 'LinkNeighborLoader()' + assert len(loader) == 500 / batch_size + + for batch in loader: + assert isinstance(batch, HeteroGraph) + + node_store = batch['paper'] + edge_store = batch['paper', 'paper'] + + # Check that `src_index` and `dst_pos_index` point to valid edges: + assert ops.equal( + node_store.x[node_store.src_index], + data['paper', 'paper'].edge_label_index[0, edge_store.input_id]).all() + assert ops.equal( + node_store.x[node_store.dst_pos_index], + data['paper', 'paper'].edge_label_index[1, edge_store.input_id]).all() + + # Check that `dst_neg_index` points to valid nodes in the batch: + if amount == 1: + assert node_store.dst_neg_index.shape == (batch_size, ) + else: + assert node_store.dst_neg_index.shape == (batch_size, amount) + assert node_store.dst_neg_index.min() >= 0 + assert node_store.dst_neg_index.max() < node_store.num_nodes + + if disjoint: + # In disjoint mode, seed nodes should always be placed first: + assert node_store.src_index.min() == 0 + assert node_store.src_index.max() == batch_size - 1 + + assert node_store.dst_pos_index.min() == batch_size + assert node_store.dst_pos_index.max() == 2 * batch_size - 1 + + assert node_store.dst_neg_index.min() == 2 * batch_size + max_seed_nodes = 2 * batch_size + batch_size * amount + assert node_store.dst_neg_index.max() == max_seed_nodes - 1 + + assert node_store.batch.min() == 0 + assert node_store.batch.max() == batch_size - 1 + + # Check that `batch` is always increasing: + for i in range(0, max_seed_nodes, batch_size): + batch_vector = node_store.batch[i:i + batch_size] + assert ops.equal(batch_vector, ops.arange(batch_size)).all() + + if temporal: + for i in range(batch_size): + assert (node_store.time[node_store.batch == i].max() + <= node_store.seed_time[i]) + + +@withPackage('pyg_lib') +def test_link_neighbor_loader_mapping(): + edge_index = ms.Tensor([ + [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 5], + [1, 2, 3, 4, 5, 8, 6, 7, 9, 10, 6, 11], + ]) + data = Graph(edge_index=edge_index, num_nodes=12) + + loader = LinkNeighborLoader( + data, + edge_label_index=data.edge_index, + num_neighbors=[1], + batch_size=2, + shuffle=True, + ) + + for batch in loader: + assert ops.equal( + batch.n_id[batch.edge_index], + data.edge_index[:, batch.e_id], + ).all() diff --git a/tests/graph/loader/test_mixin.py b/tests/graph/loader/test_mixin.py new file mode 100644 index 000000000..7e42746b3 --- /dev/null +++ b/tests/graph/loader/test_mixin.py @@ -0,0 +1,46 @@ +import subprocess +from time import sleep + +import pytest +from mindspore import ops +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import NeighborLoader +# from mindscience.sharker.testing import onlyLinux, onlyNeighborSampler + + +# @onlyLinux +# @onlyNeighborSampler +@pytest.mark.parametrize('loader_cores', [None, [1, 2]]) +def test_cpu_affinity_neighbor_loader(loader_cores, spawn_context): + data = Graph(x=ops.randn(1, 1)) + loader = NeighborLoader(data, num_neighbors=[-1], batch_size=1, + num_workers=2) + out = [] + with loader.enable_cpu_affinity(loader_cores): + iterator = loader._get_iterator() + workers = iterator._workers + sleep(3) # Gives time for worker to initialize. + for worker in workers: + process = subprocess.Popen( + ['taskset', '-c', '-p', f'{worker.pid}'], + stdout=subprocess.PIPE) + stdout = process.communicate()[0].decode('utf-8') + # returns "pid 's current affinity list -" + out.append(stdout.split(':')[1].strip()) + if loader_cores: + out == ['[1]', '[2]'] + else: + out[0] != out[1] + + +# @onlyLinux +# @onlyNeighborSampler +def test_multithreading_neighbor_loader(spawn_context): + loader = NeighborLoader( + data=Graph(x=ops.randn(1, 1)), + num_neighbors=[-1], + batch_size=1 + ) + + with loader.enable_multithreading(2): + loader._get_iterator() # Runs assertion in `init_fn`. diff --git a/tests/graph/loader/test_neighbor_loader.py b/tests/graph/loader/test_neighbor_loader.py new file mode 100644 index 000000000..0c0ccc161 --- /dev/null +++ b/tests/graph/loader/test_neighbor_loader.py @@ -0,0 +1,955 @@ +import os.path as osp + +import numpy as np +import pytest +import mindspore as ms +from mindspore import Tensor, ops, nn +from mindscience.sharker import typing +from mindscience.sharker import EdgeIndex +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.loader import NeighborLoader +from mindscience.sharker.nn import GraphConv # , to_hetero +from mindscience.sharker.sampler.base import SubgraphType +from mindscience.sharker.testing import ( + MyFeatureStore, + MyGraphStore, + get_random_edge_index, + onlyLinux, + onlyOnline, + withPackage, +) +from mindscience.sharker.typing import ( + # WITH_EDGE_TIME_NEIGHBOR_SAMPLE, + # WITH_PYG_LIB, + WITH_SPARSE, + # WITH_WEIGHTED_NEIGHBOR_SAMPLE, +) +from mindscience.sharker.utils import ( + is_undirected, + sort_edge_index, + to_csr, + to_undirected, +) +from mindscience.sharker.sparse import Layout + + +DTYPES = [ + pytest.param(ms.int64, id='int64'), + pytest.param(ms.int32, id='int32'), +] + +SUBGRAPH_TYPES = [ + pytest.param(SubgraphType.directional, id='directional'), + pytest.param(SubgraphType.bidirectional, id='bidirectional'), + pytest.param(SubgraphType.induced, id='induced'), +] + +FILTER_PER_WORKERS = [ + pytest.param(None, id='auto_filter'), + pytest.param(True, id='filter_per_worker'), + pytest.param(False, id='filter_in_main'), +] + + +def is_subset(subedge_index, edge_index, src_idx, dst_idx): + num_nodes = int(edge_index.max()) + 1 + idx = num_nodes * edge_index[0] + edge_index[1] + subidx = num_nodes * src_idx[subedge_index[0]] + dst_idx[subedge_index[1]] + mask = ms.Tensor.from_numpy(np.isin(subidx.numpy(), idx.numpy())) + return int(mask.sum()) == mask.numel() + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES) +@pytest.mark.parametrize('filter_per_worker', FILTER_PER_WORKERS) +def test_homo_neighbor_loader_basic( + dtype, + subgraph_type, + filter_per_worker, +): + # if dtype != ms.int64 and not typing.WITH_PT20: + # return + # induced = SubgraphType.induced + # if subgraph_type == SubgraphType.induced and not WITH_SPARSE: + # return + # if dtype != ms.int64 and (not WITH_PYG_LIB or subgraph_type == induced): + # return + + ms.set_seed(12345) + + data = Graph() + + data.x = ops.arange(100) + data.edge_index = get_random_edge_index(100, 100, 500, dtype) + data.edge_attr = ops.arange(500) + + loader = NeighborLoader( + data, + num_neighbors=[5] * 2, + batch_size=20, + subgraph_type=subgraph_type, + filter_per_worker=filter_per_worker, + ) + + assert str(loader) == 'NeighborLoader()' + assert len(loader) == 5 + + batch = loader([0]) + assert isinstance(batch, Graph) + assert batch.n_id[:1].tolist() == [0] + + for i, batch in enumerate(loader): + assert isinstance(batch, Graph) + # assert batch.x.device == device + assert batch.x.shape[0] <= 100 + assert batch.n_id.shape == (batch.num_nodes, ) + assert batch.input_id.numel() == batch.batch_size == 20 + assert batch.x.min() >= 0 and batch.x.max() < 100 + assert isinstance(batch.edge_index, EdgeIndex) + batch.edge_index.validate() + size = (batch.num_nodes, batch.num_nodes) + assert batch.edge_index.sparse_shape == size + assert batch.edge_index.sort_order == 'col' + # assert batch.edge_index.device == device + assert batch.edge_index.min() >= 0 + assert batch.edge_index.max() < batch.num_nodes + # assert batch.edge_attr.device == device + assert batch.edge_attr.shape[0] == batch.edge_index.shape[1] + + # Input nodes are always sampled first: + assert ops.equal( + batch.x[:batch.batch_size], + ops.arange(i * batch.batch_size, (i + 1) * batch.batch_size)).all() + + if subgraph_type != SubgraphType.bidirectional: + assert batch.e_id.shape == (batch.num_edges, ) + assert batch.edge_attr.min() >= 0 + assert batch.edge_attr.max() < 500 + + assert is_subset( + batch.edge_index.long(), + data.edge_index.long(), + batch.x, + batch.x, + ) + + +# @onlyNeighborSampler +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES) +def test_hetero_neighbor_loader_basic(subgraph_type, dtype): + # induced = SubgraphType.induced + # if subgraph_type == SubgraphType.induced and not WITH_SPARSE: + # return + # if dtype != ms.int64 and (not WITH_PYG_LIB or subgraph_type == induced): + # return + + ms.set_seed(12345) + + data = HeteroGraph() + + data['paper'].x = ops.arange(100) + data['author'].x = ops.arange(100, 300) + + edge_index = get_random_edge_index(100, 100, 500, dtype) + data['paper', 'paper'].edge_index = edge_index + data['paper', 'paper'].edge_attr = ops.arange(500) + edge_index = get_random_edge_index(100, 200, 1000, dtype) + data['paper', 'author'].edge_index = edge_index + data['paper', 'author'].edge_attr = ops.arange(500, 1500) + edge_index = get_random_edge_index(200, 100, 1000, dtype) + data['author', 'paper'].edge_index = edge_index + data['author', 'paper'].edge_attr = ops.arange(1500, 2500) + + r1, c1 = data['paper', 'paper'].edge_index + r2, c2 = data['paper', 'author'].edge_index + ms.Tensor([[0], [100]]) + r3, c3 = data['author', 'paper'].edge_index + ms.Tensor([[100], [0]]) + + batch_size = 20 + + with pytest.raises(ValueError, match="hops must be the same across all"): + loader = NeighborLoader( + data, + num_neighbors={ + ('paper', 'to', 'paper'): [-1], + ('paper', 'to', 'author'): [-1, -1], + ('author', 'to', 'paper'): [-1, -1], + }, + input_nodes='paper', + batch_size=batch_size, + subgraph_type=subgraph_type, + ) + next(iter(loader)) + + loader = NeighborLoader( + data, + num_neighbors=[10] * 2, + input_nodes='paper', + batch_size=batch_size, + subgraph_type=subgraph_type, + ) + + assert str(loader) == 'NeighborLoader()' + assert len(loader) == (100 + batch_size - 1) // batch_size + + for batch in loader: + assert isinstance(batch, HeteroGraph) + + # Test node type selection: + assert set(batch.node_types) == {'paper', 'author'} + + assert batch['paper'].n_id.shape == (batch['paper'].num_nodes, ) + assert batch['paper'].x.shape[0] <= 100 + assert batch['paper'].input_id.numel() == batch_size + assert batch['paper'].batch_size == batch_size + assert batch['paper'].x.min() >= 0 and batch['paper'].x.max() < 100 + + assert batch['author'].n_id.shape == (batch['author'].num_nodes, ) + assert batch['author'].x.shape[0] <= 200 + assert batch['author'].x.min() >= 100 and batch['author'].x.max() < 300 + + # Test edge type selection: + assert set(batch.edge_types) == {('paper', 'to', 'paper'), + ('paper', 'to', 'author'), + ('author', 'to', 'paper')} + + for edge_type, edge_index in batch.edge_index_dict.items(): + src, _, dst = edge_type + assert isinstance(edge_index, EdgeIndex) + edge_index.validate() + size = (batch[src].num_nodes, batch[dst].num_nodes) + assert edge_index.sparse_shape == size + assert edge_index.sort_order == 'col' + + row, col = batch['paper', 'paper'].edge_index + assert row.min() >= 0 and row.max() < batch['paper'].num_nodes + assert col.min() >= 0 and col.max() < batch['paper'].num_nodes + + if subgraph_type != SubgraphType.bidirectional: + assert batch['paper', 'paper'].e_id.shape == (row.numel(), ) + value = batch['paper', 'paper'].edge_attr + assert value.min() >= 0 and value.max() < 500 + + assert is_subset( + batch['paper', 'paper'].edge_index.long(), + data['paper', 'paper'].edge_index.long(), + batch['paper'].x, + batch['paper'].x, + ) + elif subgraph_type != SubgraphType.directional: + assert 'e_id' not in batch['paper', 'paper'] + assert 'edge_attr' not in batch['paper', 'paper'] + + assert is_undirected(batch['paper', 'paper'].edge_index) + + row, col = batch['paper', 'author'].edge_index + assert row.min() >= 0 and row.max() < batch['paper'].num_nodes + assert col.min() >= 0 and col.max() < batch['author'].num_nodes + + if subgraph_type != SubgraphType.bidirectional: + assert batch['paper', 'author'].e_id.shape == (row.numel(), ) + value = batch['paper', 'author'].edge_attr + assert value.min() >= 500 and value.max() < 1500 + + assert is_subset( + batch['paper', 'author'].edge_index.long(), + data['paper', 'author'].edge_index.long(), + batch['paper'].x, + batch['author'].x - 100, + ) + elif subgraph_type != SubgraphType.directional: + assert 'e_id' not in batch['paper', 'author'] + assert 'edge_attr' not in batch['paper', 'author'] + + edge_index1 = batch['paper', 'author'].edge_index + edge_index2 = batch['author', 'paper'].edge_index + assert ops.equal( + edge_index1, + sort_edge_index(edge_index2.flip([0]), sort_by_row=False), + ).all() + + row, col = batch['author', 'paper'].edge_index + assert row.min() >= 0 and row.max() < batch['author'].num_nodes + assert col.min() >= 0 and col.max() < batch['paper'].num_nodes + + if subgraph_type != SubgraphType.bidirectional: + assert batch['author', 'paper'].e_id.shape == (row.numel(), ) + value = batch['author', 'paper'].edge_attr + assert value.min() >= 1500 and value.max() < 2500 + + assert is_subset( + batch['author', 'paper'].edge_index.long(), + data['author', 'paper'].edge_index.long(), + batch['author'].x - 100, + batch['paper'].x, + ) + elif subgraph_type != SubgraphType.directional: + assert 'e_id' not in batch['author', 'paper'] + assert 'edge_attr' not in batch['author', 'paper'] + + edge_index1 = batch['author', 'paper'].edge_index + edge_index2 = batch['paper', 'author'].edge_index + assert ops.equal( + edge_index1, + sort_edge_index(edge_index2.flip([0]), sort_by_row=False), + ).all() + + # Test for isolated nodes (there shouldn't exist any): + assert not batch.has_isolated_nodes() + + +@onlyOnline +@pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES) +def test_homo_neighbor_loader_on_cora(get_dataset, subgraph_type): + if subgraph_type == SubgraphType.induced and not WITH_SPARSE: + return + dataset = get_dataset(name='Cora') + data = dataset[0] + + mask = data.edge_index[0] < data.edge_index[1] + edge_index = data.edge_index[:, mask] + edge_weight = ops.rand(edge_index.shape[1]) + data.edge_index, data.edge_weight = to_undirected(edge_index, edge_weight) + + split_idx = ops.arange(5, 8) + + loader = NeighborLoader( + data, + num_neighbors=[-1, -1], + batch_size=split_idx.numel(), + input_nodes=split_idx, + subgraph_type=subgraph_type, + ) + assert len(loader) == 1 + + batch = next(iter(loader)) + batch_size = batch.batch_size + + class GNN(nn.Cell): + def __init__(self, in_channels, hidden_channels, out_channels): + super().__init__() + self.conv1 = GraphConv(in_channels, hidden_channels) + self.conv2 = GraphConv(hidden_channels, out_channels) + + def construct(self, x, edge_index, edge_weight): + x = ops.relu(self.conv1(x, edge_index, edge_weight)) + x = self.conv2(x, edge_index, edge_weight) + return x + + model = GNN(dataset.num_features, 16, dataset.num_classes) + + out1 = model(data.x, data.edge_index, data.edge_weight)[split_idx] + out2 = model(batch.x, batch.edge_index, batch.edge_weight)[:batch_size] + assert ops.isclose(out1, out2, atol=1e-6).all() + + +@onlyOnline +# @onlyNeighborSampler +@pytest.mark.parametrize('subgraph_type', SUBGRAPH_TYPES) +def test_hetero_neighbor_loader_on_cora(get_dataset, subgraph_type): + if subgraph_type == SubgraphType.induced and not WITH_SPARSE: + return + dataset = get_dataset(name='Cora') + data = dataset[0] + + hetero_data = HeteroGraph() + hetero_data['paper'].x = data.x + hetero_data['paper', 'paper'].edge_index = data.edge_index + + split_idx = ops.arange(5, 8) + + loader = NeighborLoader( + hetero_data, + num_neighbors=[-1, -1], + batch_size=split_idx.numel(), + input_nodes=('paper', split_idx), + subgraph_type=subgraph_type, + ) + assert len(loader) == 1 + + hetero_batch = next(iter(loader)) + batch_size = hetero_batch['paper'].batch_size + + class GNN(nn.Cell): + def __init__(self, in_channels, hidden_channels, out_channels): + super().__init__() + self.conv1 = GraphConv(in_channels, hidden_channels) + self.conv2 = GraphConv(hidden_channels, out_channels) + + def construct(self, x, edge_index): + x = ops.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + return x + + model = GNN(dataset.num_features, 16, dataset.num_classes) + hetero_model = to_hetero(model, hetero_data.metadata()) + + out1 = model(data.x, data.edge_index)[split_idx] + out2 = hetero_model(hetero_batch.x_dict, + hetero_batch.edge_index_dict)['paper'][:batch_size] + assert ops.isclose(out1, out2, atol=1e-6).all() + + +@onlyOnline +@withPackage('pyg_lib') +def test_temporal_hetero_neighbor_loader_on_cora(get_dataset): + dataset = get_dataset(name='Cora') + data = dataset[0] + + hetero_data = HeteroGraph() + hetero_data['paper'].x = data.x + hetero_data['paper'].time = ops.arange(data.num_nodes, 0, -1) + hetero_data['paper', 'paper'].edge_index = data.edge_index + + loader = NeighborLoader(hetero_data, num_neighbors=[-1, -1], + input_nodes='paper', time_attr='time', + batch_size=1) + + for batch in loader: + mask = batch['paper'].time[0] >= batch['paper'].time[1:] + assert ops.all(mask) + + +# @onlyNeighborSampler +def test_custom_neighbor_loader(): + # Initialize feature store, graph store, and reference: + feature_store = MyFeatureStore() + graph_store = MyGraphStore() + + # Set up node features: + x = ops.arange(100, 300) + feature_store.put_tensor(x, group_name=None, attr_name='x', index=None) + + y = ops.arange(100, 300) + feature_store.put_tensor(y, group_name=None, attr_name='y', index=None) + + # COO: + edge_index = get_random_edge_index(100, 100, 500, coalesce=True) + edge_index = edge_index[:, ops.shuffle(ops.arange(edge_index.shape[1]))] + coo = (edge_index[0], edge_index[1]) + graph_store.put_edge_index(edge_index=coo, edge_type=None, layout=Layout.COO, + size=(100, 100)) + + data = Graph(x=x, edge_index=edge_index, y=y, num_nodes=200) + + # Construct neighbor loaders: + loader1 = NeighborLoader(data, batch_size=20, + input_nodes=ops.arange(100), + num_neighbors=[-1] * 2) + + loader2 = NeighborLoader((feature_store, graph_store), batch_size=20, + input_nodes=ops.arange(100), + num_neighbors=[-1] * 2) + + assert str(loader1) == str(loader2) + assert len(loader1) == len(loader2) + + for batch1, batch2 in zip(loader1, loader2): + assert len(batch1) == len(batch2) + assert batch1.num_nodes == batch2.num_nodes + assert batch1.num_edges == batch2.num_edges + assert batch1.batch_size == batch2.batch_size + + # Mapped indices of neighbors may be differently sorted ... + assert ops.isclose(batch1.x.sort()[0], batch2.x.sort()[0]).all() + assert ops.isclose(batch1.y.sort()[0], batch2.y.sort()[0]).all() + + +# @onlyNeighborSampler +def test_custom_hetero_neighbor_loader(): + # Initialize feature store, graph store, and reference: + feature_store = MyFeatureStore() + graph_store = MyGraphStore() + data = HeteroGraph() + + # Set up node features: + x = ops.arange(100) + data['paper'].x = x + feature_store.put_tensor(x, group_name='paper', attr_name='x', index=None) + + x = ops.arange(100, 300) + data['author'].x = x + feature_store.put_tensor(x, group_name='author', attr_name='x', index=None) + + # COO: + edge_index = get_random_edge_index(100, 100, 500, coalesce=True) + edge_index = edge_index[:, ops.shuffle(ops.arange(edge_index.shape[1]))] + data['paper', 'to', 'paper'].edge_index = edge_index + coo = (edge_index[0], edge_index[1]) + graph_store.put_edge_index(edge_index=coo, + edge_type=('paper', 'to', 'paper'), + layout=Layout.COO, size=(100, 100)) + + # CSR: + edge_index = get_random_edge_index(100, 200, 1000, coalesce=True) + data['paper', 'to', 'author'].edge_index = edge_index + adj = to_csr(edge_index, shape=(100, 200)) + csr = (adj.crow_indices(), adj.col_indices()) + graph_store.put_edge_index(edge_index=csr, + edge_type=('paper', 'to', 'author'), + layout=Layout.CSR, size=(100, 200)) + + # CSC: + edge_index = get_random_edge_index(200, 100, 1000, coalesce=True) + data['author', 'to', 'paper'].edge_index = edge_index + adj = to_csr(edge_index.flip([0]), shape=(100, 200)) + csc = (adj.col_indices(), adj.crow_indices()) + graph_store.put_edge_index(edge_index=csc, + edge_type=('author', 'to', 'paper'), + layout=Layout.CSC, size=(200, 100)) + + # COO (sorted): + edge_index = get_random_edge_index(200, 200, 100, coalesce=True) + edge_index = edge_index[:, edge_index[1].argsort()] + data['author', 'to', 'author'].edge_index = edge_index + coo = (edge_index[0], edge_index[1]) + graph_store.put_edge_index(edge_index=coo, + edge_type=('author', 'to', 'author'), + layout=Layout.COO, size=(200, 200), is_sorted=True) + + # Construct neighbor loaders: + loader1 = NeighborLoader(data, batch_size=20, + input_nodes=('paper', range(100)), + num_neighbors=[-1] * 2) + + loader2 = NeighborLoader((feature_store, graph_store), batch_size=20, + input_nodes=('paper', range(100)), + num_neighbors=[-1] * 2) + + assert str(loader1) == str(loader2) + assert len(loader1) == len(loader2) + + for batch1, batch2 in zip(loader1, loader2): + # `loader2` explicitly adds `num_nodes` to the batch: + assert len(batch1) + 1 == len(batch2) + assert batch1['paper'].batch_size == batch2['paper'].batch_size + + # Mapped indices of neighbors may be differently sorted ... + for node_type in data.node_types: + assert ops.isclose( + batch1[node_type].x.sort()[0], + batch2[node_type].x.sort()[0], + ).all() + + # ... but should sample the exact same number of edges: + for edge_type in data.edge_types: + assert batch1[edge_type].num_edges == batch2[edge_type].num_edges + + +@onlyOnline +@withPackage('pyg_lib') +def test_temporal_custom_neighbor_loader_on_cora(get_dataset): + # Initialize dataset (once): + dataset = get_dataset(name='Cora') + data = dataset[0] + data.time = ops.arange(data.num_nodes, 0, -1) + + # Initialize feature store, graph store, and reference: + feature_store = MyFeatureStore() + graph_store = MyGraphStore() + hetero_data = HeteroGraph() + + feature_store.put_tensor( + data.x, + group_name='paper', + attr_name='x', + index=None, + ) + hetero_data['paper'].x = data.x + + feature_store.put_tensor( + data.time, + group_name='paper', + attr_name='time', + index=None, + ) + hetero_data['paper'].time = data.time + + # Sort according to time in local neighborhoods: + row, col = data.edge_index + perm = ((col * (data.num_nodes + 1)) + data.time[row]).argsort() + edge_index = data.edge_index[:, perm] + + graph_store.put_edge_index( + edge_index, + edge_type=('paper', 'to', 'paper'), + layout=Layout.COO, + is_sorted=True, + size=(data.num_nodes, data.num_nodes), + ) + hetero_data['paper', 'to', 'paper'].edge_index = data.edge_index + + loader1 = NeighborLoader( + hetero_data, + num_neighbors=[-1, -1], + input_nodes='paper', + time_attr='time', + batch_size=128, + ) + + loader2 = NeighborLoader( + (feature_store, graph_store), + num_neighbors=[-1, -1], + input_nodes='paper', + time_attr='time', + batch_size=128, + ) + + for batch1, batch2 in zip(loader1, loader2): + assert ops.equal(batch1['paper'].time, batch2['paper'].time).all() + + +@withPackage('pyg_lib') +def test_pyg_lib_and_torch_sparse_homo_equality(): + edge_index = get_random_edge_index(20, 20, 100) + adj = to_csr(edge_index.flip([0]), shape=(20, 20)) + colptr, row = adj.crow_indices(), adj.col_indices() + + seed = ops.arange(10) + + sample = torch.ops.pyg.neighbor_sample + out1 = sample(colptr, row, seed, [-1, -1], None, None, None, None, True) + sample = torch.ops.torch_sparse.neighbor_sample + out2 = sample(colptr, row, seed, [-1, -1], False, True) + + row1, col1, node_id1, edge_id1 = out1[:4] + node_id2, row2, col2, edge_id2 = out2 + assert ops.equal(node_id1, node_id2).all() + assert ops.equal(row1, row2).all() + assert ops.equal(col1, col2).all() + assert ops.equal(edge_id1, edge_id2).all() + + +@withPackage('pyg_lib') +def test_pyg_lib_and_torch_sparse_hetero_equality(): + edge_index = get_random_edge_index(20, 10, 50) + adj = to_csr(edge_index.flip([0]), shape=(10, 20)) + colptr1, row1 = adj.crow_indices(), adj.col_indices() + + edge_index = get_random_edge_index(10, 20, 50) + adj = to_csr(edge_index.flip([0]), shape=(20, 10)) + colptr2, row2 = adj.crow_indices(), adj.col_indices() + + node_types = ['paper', 'author'] + edge_types = [('paper', 'to', 'author'), ('author', 'to', 'paper')] + colptr_dict = { + 'paper__to__author': colptr1, + 'author__to__paper': colptr2, + } + row_dict = { + 'paper__to__author': row1, + 'author__to__paper': row2, + } + seed_dict = {'paper': ops.arange(1)} + num_neighbors_dict = { + 'paper__to__author': [-1, -1], + 'author__to__paper': [-1, -1], + } + + sample = ops.pyg.hetero_neighbor_sample + out1 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict, + num_neighbors_dict, None, None, None, None, True, False, + True, False, "uniform", True) + sample = ops.torch_sparse.hetero_neighbor_sample + out2 = sample(node_types, edge_types, colptr_dict, row_dict, seed_dict, + num_neighbors_dict, 2, False, True) + + row1_dict, col1_dict, node_id1_dict, edge_id1_dict = out1[:4] + node_id2_dict, row2_dict, col2_dict, edge_id2_dict = out2 + assert len(node_id1_dict) == len(node_id2_dict) + for key in node_id1_dict.keys(): + assert ops.equal(node_id1_dict[key], node_id2_dict[key]).all() + assert len(row1_dict) == len(row2_dict) + for key in row1_dict.keys(): + assert ops.equal(row1_dict[key], row2_dict[key]).all() + assert len(col1_dict) == len(col2_dict) + for key in col1_dict.keys(): + assert ops.equal(col1_dict[key], col2_dict[key]).all() + assert len(edge_id1_dict) == len(edge_id2_dict) + for key in edge_id1_dict.keys(): + assert ops.equal(edge_id1_dict[key], edge_id2_dict[key]).all() + + +@onlyLinux +# @onlyNeighborSampler +def test_memmap_neighbor_loader(tmp_path): + path = osp.join(tmp_path, 'x.npy') + x = np.memmap(path, dtype=np.float32, mode='w+', shape=(100, 32)) + x[:] = np.random.randn(100, 32) + + data = Graph() + data.x = np.memmap(path, dtype=np.float32, mode='r', shape=(100, 32)) + data.edge_index = get_random_edge_index(100, 100, 500) + + assert str(data) == 'Data(x=[100, 32], edge_index=[2, 500])' + assert data.num_nodes == 100 + + loader = NeighborLoader(data, num_neighbors=[5] * 2, batch_size=20, + num_workers=2) + batch = next(iter(loader)) + assert batch.num_nodes <= 100 + assert isinstance(batch.x, Tensor) + assert batch.x.shape == (batch.num_nodes, 32) + + +@withPackage('pyg_lib') +def test_homo_neighbor_loader_sampled_info(): + edge_index = ms.Tensor([ + [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], + [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], + ]) + + data = Graph(edge_index=edge_index, num_nodes=14) + + loader = NeighborLoader( + data, + num_neighbors=[1, 2, 4], + batch_size=2, + shuffle=False, + ) + batch = next(iter(loader)) + + assert batch.num_sampled_nodes == [2, 2, 3, 4] + assert batch.num_sampled_edges == [2, 4, 4] + + +@withPackage('pyg_lib') +def test_hetero_neighbor_loader_sampled_info(): + edge_index = ms.Tensor([ + [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], + [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], + ]) + + data = HeteroGraph() + data['paper'].num_nodes = data['author'].num_nodes = 14 + data['paper', 'paper'].edge_index = edge_index + data['paper', 'author'].edge_index = edge_index + data['author', 'paper'].edge_index = edge_index + + loader = NeighborLoader( + data, + num_neighbors=[1, 2, 4], + batch_size=2, + input_nodes='paper', + shuffle=False, + ) + batch = next(iter(loader)) + + expected_num_sampled_nodes = { + 'paper': [2, 2, 3, 4], + 'author': [0, 2, 3, 4], + } + expected_num_sampled_edges = { + ('paper', 'to', 'paper'): [2, 4, 4], + ('paper', 'to', 'author'): [0, 4, 4], + ('author', 'to', 'paper'): [2, 4, 4], + } + + for node_type in batch.node_types: + assert (batch[node_type].num_sampled_nodes == + expected_num_sampled_nodes[node_type]) + for edge_type in batch.edge_types: + assert (batch[edge_type].num_sampled_edges == + expected_num_sampled_edges[edge_type]) + + +@withPackage('pyg_lib') +def test_neighbor_loader_mapping(): + edge_index = ms.Tensor([ + [0, 0, 0, 0, 0, 1, 1, 1, 2, 2, 3, 5], + [1, 2, 3, 4, 5, 8, 6, 7, 9, 10, 6, 11], + ]) + data = Graph(edge_index=edge_index, num_nodes=12) + + loader = NeighborLoader( + data, + num_neighbors=[1], + batch_size=2, + shuffle=True, + ) + + for batch in loader: + assert ops.equal( + batch.n_id[batch.edge_index], + data.edge_index[:, batch.e_id], + ).all() + + +@pytest.mark.skipif( + not WITH_WEIGHTED_NEIGHBOR_SAMPLE, + reason="'pyg-lib' does not support weighted neighbor sampling", +) +def test_weighted_homo_neighbor_loader(): + edge_index = ms.Tensor([ + [1, 3, 0, 4], + [2, 2, 1, 3], + ]) + edge_weight = ms.Tensor([0.0, 1.0, 0.0, 1.0]) + + data = Graph(num_nodes=5, edge_index=edge_index, edge_weight=edge_weight) + + loader = NeighborLoader( + data, + input_nodes=ms.Tensor([2]), + num_neighbors=[1] * 2, + batch_size=1, + weight_attr='edge_weight', + ) + assert len(loader) == 1 + + batch = next(iter(loader)) + + assert batch.num_nodes == 3 + assert batch.n_id.tolist() == [2, 3, 4] + assert batch.num_edges == 2 + assert batch.n_id[batch.edge_index].tolist() == [[3, 4], [2, 3]] + + +@pytest.mark.skipif( + not WITH_WEIGHTED_NEIGHBOR_SAMPLE, + reason="'pyg-lib' does not support weighted neighbor sampling", +) +def test_weighted_hetero_neighbor_loader(): + edge_index = ms.Tensor([ + [1, 3, 0, 4], + [2, 2, 1, 3], + ]) + edge_weight = ms.Tensor([0.0, 1.0, 0.0, 1.0]) + + data = HeteroGraph() + data['paper'].num_nodes = 5 + data['paper', 'to', 'paper'].edge_index = edge_index + data['paper', 'to', 'paper'].edge_weight = edge_weight + + loader = NeighborLoader( + data, + input_nodes=('paper', ms.Tensor([2])), + num_neighbors=[1] * 2, + batch_size=1, + weight_attr='edge_weight', + ) + assert len(loader) == 1 + + batch = next(iter(loader)) + + assert batch['paper'].num_nodes == 3 + assert batch['paper'].n_id.tolist() == [2, 3, 4] + assert batch['paper', 'paper'].num_edges == 2 + global_edge_index = batch['paper'].n_id[batch['paper', 'paper'].edge_index] + assert global_edge_index.tolist() == [[3, 4], [2, 3]] + + +@pytest.mark.skipif( + not WITH_EDGE_TIME_NEIGHBOR_SAMPLE, + reason="'pyg-lib' does not support weighted neighbor sampling", +) +def test_edge_level_temporal_homo_neighbor_loader(): + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4], + [1, 0, 2, 1, 3, 2, 4, 3], + ]) + edge_time = ops.arange(edge_index.shape[1]) + + data = Graph(edge_index=edge_index, edge_time=edge_time, num_nodes=5) + + loader = NeighborLoader( + data, + num_neighbors=[-1, -1], + input_time=ms.Tensor([4, 4, 4, 4, 4]), + time_attr='edge_time', + batch_size=1, + ) + + for batch in loader: + assert batch.edge_time.numel() == batch.num_edges + if batch.edge_time.numel() > 0: + assert batch.edge_time.max() <= 4 + + +@pytest.mark.skipif( + not WITH_EDGE_TIME_NEIGHBOR_SAMPLE, + reason="'pyg-lib' does not support weighted neighbor sampling", +) +def test_edge_level_temporal_hetero_neighbor_loader(): + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4], + [1, 0, 2, 1, 3, 2, 4, 3], + ]) + edge_time = ops.arange(edge_index.shape[1]) + + data = HeteroGraph() + data['A'].num_nodes = 5 + data['A', 'A'].edge_index = edge_index + data['A', 'A'].edge_time = edge_time + + loader = NeighborLoader( + data, + num_neighbors=[-1, -1], + input_nodes='A', + input_time=ms.Tensor([4, 4, 4, 4, 4]), + time_attr='edge_time', + batch_size=1, + ) + + for batch in loader: + assert batch['A', 'A'].edge_time.numel() == batch['A', 'A'].num_edges + if batch['A', 'A'].edge_time.numel() > 0: + assert batch['A', 'A'].edge_time.max() <= 4 + + +# @onlyNeighborSampler +# @withPackage('torch_frame') +# def test_neighbor_loader_with_tensor_frame(device): +# data = Graph() +# data.tf = get_random_tensor_frame(num_rows=100) +# data.edge_index = get_random_edge_index(100, 100, 500) +# data.edge_attr = get_random_tensor_frame(500) +# data.global_tf = get_random_tensor_frame(num_rows=1) + +# loader = NeighborLoader(data, num_neighbors=[5] * 2, batch_size=20) +# assert len(loader) == 5 + +# for batch in loader: +# assert isinstance(batch.tf, TensorFrame) +# assert batch.tf.device == device +# assert batch.tf.num_rows == batch.n_id.numel() +# assert batch.tf == data.tf[batch.n_id] + +# assert isinstance(batch.edge_attr, TensorFrame) +# assert batch.edge_attr.device == device +# assert batch.edge_attr.num_rows == batch.e_id.numel() +# assert batch.edge_attr == data.edge_attr[batch.e_id] + +# assert isinstance(batch.global_tf, TensorFrame) +# assert batch.global_tf.device == device +# assert batch.global_tf.num_rows == 1 +# assert batch.global_tf == data.global_tf + + +# @onlyNeighborSampler +def test_neighbor_loader_input_id(): + data = HeteroGraph() + data['a'].num_nodes = 10 + data['b'].num_nodes = 12 + + row = ops.randint(0, data['a'].num_nodes, (40, )) + col = ops.randint(0, data['b'].num_nodes, (40, )) + data['a', 'b'].edge_index = ops.stack(([row, col]), axis=0) + data['b', 'a'].edge_index = ops.stack(([col, row]), axis=0) + + mask = ops.ones(data['a'].num_nodes).bool() + mask[0] = False + + loader = NeighborLoader( + data, + input_nodes=('a', mask), + batch_size=2, + num_neighbors=[2, 2], + ) + for i, batch in enumerate(loader): + if i < 4: + expected = [(2 * i) + 1, (2 * i) + 2] + else: + expected = [(2 * i) + 1] + + assert batch['a'].input_id.tolist() == expected diff --git a/tests/graph/loader/test_neighbor_sampler.py b/tests/graph/loader/test_neighbor_sampler.py new file mode 100644 index 000000000..e847764ed --- /dev/null +++ b/tests/graph/loader/test_neighbor_sampler.py @@ -0,0 +1,107 @@ +import numpy as np +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker.loader import NeighborSampler +from mindscience.sharker.nn.conv import GATConv, SAGEConv +from mindscience.sharker.testing import onlyOnline +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import erdos_renyi_graph + + +def test_neighbor_sampler_basic(): + edge_index = erdos_renyi_graph(num_nodes=10, edge_prob=0.5) + adj_t = SparseTensor.from_edge_index(edge_index, sparse_shape=(10, 10)).t() + E = edge_index.shape[1] + + loader = NeighborSampler(edge_index, sizes=[2, 4], batch_size=2) + assert str(loader) == 'NeighborSampler(sizes=[2, 4])' + assert len(loader) == 5 + + for batch_size, n_id, adjs in loader: + assert batch_size == 2 + assert all(np.isin(n_id, ops.arange(10)).tolist()) + assert ms.numpy.unique(n_id).shape[0] == n_id.shape[0] + for (edge_index, e_id, size) in adjs: + assert int(edge_index[0].max() + 1) <= size[0] + assert int(edge_index[1].max() + 1) <= size[1] + assert all(np.isin(e_id, ops.arange(E)).tolist()) + assert ms.numpy.unique(e_id).shape[0] == e_id.shape[0] + assert size[0] >= size[1] + + out = loader.sample([1, 2]) + assert len(out) == 3 + + loader = NeighborSampler(adj_t, sizes=[2, 4], batch_size=2) + + for batch_size, n_id, adjs in loader: + for (adj_t, e_id, size) in adjs: + assert adj_t.shape[0] == size[1] + assert adj_t.shape[1] == size[0] + + +def test_neighbor_sampler_invalid_kwargs(): + # Ignore `collate_fn` and `dataset` arguments: + edge_index = ms.Tensor([[0, 1], [1, 0]]) + NeighborSampler(edge_index, sizes=[-1], collate_fn=None, dataset=None) + + +@onlyOnline +def test_neighbor_sampler_on_cora(get_dataset): + dataset = get_dataset(name='Cora') + data = dataset[0] + + batch = ops.arange(10) + loader = NeighborSampler(data.edge_index, sizes=[-1, -1, -1], + node_idx=batch, batch_size=10) + + class SAGE(nn.Cell): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.convs = nn.CellList() + self.convs.append(SAGEConv(in_channels, 16)) + self.convs.append(SAGEConv(16, 16)) + self.convs.append(SAGEConv(16, out_channels)) + + def batch(self, x, adjs): + for i, (edge_index, _, size) in enumerate(adjs): + x_target = x[:size[1]] # Target nodes are always placed first. + x = self.convs[i]((x, x_target), edge_index) + return x + + def full(self, x, edge_index): + for conv in self.convs: + x = conv(x, edge_index) + return x + + model = SAGE(dataset.num_features, dataset.num_classes) + + _, n_id, adjs = next(iter(loader)) + out1 = model.batch(data.x[n_id], adjs) + out2 = model.full(data.x, data.edge_index)[batch] + assert ops.isclose(out1, out2, atol=1e-7).all() + + class GAT(nn.Cell): + def __init__(self, in_channels, out_channels): + super().__init__() + + self.convs = nn.CellList() + self.convs.append(GATConv(in_channels, 16, heads=2)) + self.convs.append(GATConv(32, 16, heads=2)) + self.convs.append(GATConv(32, out_channels, heads=2, concat=False)) + + def batch(self, x, adjs): + for i, (edge_index, _, size) in enumerate(adjs): + x_target = x[:size[1]] # Target nodes are always placed first. + x = self.convs[i]((x, x_target), edge_index) + return x + + def full(self, x, edge_index): + for conv in self.convs: + x = conv(x, edge_index) + return x + + _, n_id, adjs = next(iter(loader)) + out1 = model.batch(data.x[n_id], adjs) + out2 = model.full(data.x, data.edge_index)[batch] + assert ops.isclose(out1, out2, atol=1e-7).all() diff --git a/tests/graph/loader/test_random_node_loader.py b/tests/graph/loader/test_random_node_loader.py new file mode 100644 index 000000000..c0d12a235 --- /dev/null +++ b/tests/graph/loader/test_random_node_loader.py @@ -0,0 +1,46 @@ +from mindspore import ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.loader import RandomNodeLoader +from mindscience.sharker.testing import get_random_edge_index + + +def test_random_node_loader(): + data = Graph() + data.x = ops.randn(100, 128) + data.node_id = ops.arange(100) + data.edge_index = get_random_edge_index(100, 100, 500) + data.edge_attr = ops.randn(500, 32) + + loader = RandomNodeLoader(data, num_parts=4, shuffle=True) + assert len(loader) == 4 + + for batch in loader: + assert len(batch) == 4 + assert batch.node_id.min() >= 0 + assert batch.node_id.max() < 100 + assert batch.edge_index.shape[1] == batch.edge_attr.shape[0] + assert ops.isclose(batch.x, data.x[batch.node_id]).all() + batch.validate() + + +def test_heterogeneous_random_node_loader(): + data = HeteroGraph() + data['paper'].x = ops.randn(100, 128) + data['paper'].node_id = ops.arange(100) + data['author'].x = ops.randn(200, 128) + data['author'].node_id = ops.arange(200) + data['paper', 'author'].edge_index = get_random_edge_index(100, 200, 500) + data['paper', 'author'].edge_attr = ops.randn(500, 32) + data['author', 'paper'].edge_index = get_random_edge_index(200, 100, 400) + data['author', 'paper'].edge_attr = ops.randn(400, 32) + data['paper', 'paper'].edge_index = get_random_edge_index(100, 100, 600) + data['paper', 'paper'].edge_attr = ops.randn(600, 32) + + loader = RandomNodeLoader(data, num_parts=4, shuffle=True) + assert len(loader) == 4 + + for batch in loader: + assert len(batch) == 4 + assert batch.node_types == data.node_types + assert batch.edge_types == data.edge_types + batch.validate() diff --git a/tests/graph/loader/test_shadow.py b/tests/graph/loader/test_shadow.py new file mode 100644 index 000000000..1bb2a3eb2 --- /dev/null +++ b/tests/graph/loader/test_shadow.py @@ -0,0 +1,54 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import ShaDowKHopSampler +from mindscience.sharker.typing import SparseTensor + + +def test_shadow_k_hop_sampler(): + row = ms.Tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5]) + col = ms.Tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]) + edge_index = ops.stack(([row, col]), axis=0) + edge_weight = ops.arange(row.shape[0]) + x = ops.randn(6, 16) + y = ops.randint(0, 3, (6, ), dtype=ms.int64) + data = Graph(edge_index=edge_index, edge_weight=edge_weight, x=x, y=y) + + train_mask = ms.Tensor([1, 1, 0, 0, 0, 0]).bool() + loader = ShaDowKHopSampler(data, depth=1, num_neighbors=3, + node_idx=train_mask, batch_size=2) + assert len(loader) == 1 + + batch1 = next(iter(loader)) + assert batch1.num_graphs == len(batch1) == 2 + + assert batch1.batch.tolist() == [0, 0, 0, 0, 1, 1, 1] + assert batch1.ptr.tolist() == [0, 4, 7] + assert batch1.root_n_id.tolist() == [0, 5] + assert batch1.x.tolist() == x[ms.Tensor([0, 1, 2, 3, 0, 1, 2])].tolist() + assert batch1.y.tolist() == y[train_mask].tolist() + row, col = batch1.edge_index + assert row.tolist() == [0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6] + assert col.tolist() == [1, 2, 3, 0, 2, 0, 1, 0, 5, 6, 4, 6, 4, 5] + e_id = ms.Tensor([0, 1, 2, 3, 4, 5, 6, 9, 0, 1, 3, 4, 5, 6]) + assert batch1.edge_weight.tolist() == edge_weight[e_id].tolist() + + adj_t = SparseTensor(row=edge_index[0], col=edge_index[1], value=edge_weight) + data = Graph(adj_t=adj_t.t(), x=x, y=y) + + loader = ShaDowKHopSampler(data, depth=1, num_neighbors=3, + node_idx=train_mask, batch_size=2) + assert len(loader) == 1 + + batch2 = next(iter(loader)) + assert batch2.num_graphs == len(batch2) == 2 + + assert batch1.batch.tolist() == batch2.batch.tolist() + assert batch1.ptr.tolist() == batch2.ptr.tolist() + assert batch1.root_n_id.tolist() == batch2.root_n_id.tolist() + assert batch1.x.tolist() == batch2.x.tolist() + assert batch1.y.tolist() == batch2.y.tolist() + row, col, value = batch2.adj_t.t().coo() + assert batch1.edge_index[0].tolist() == row.tolist() + assert batch1.edge_index[1].tolist() == col.tolist() + assert batch1.edge_weight.tolist() == value.tolist() diff --git a/tests/graph/loader/test_temporal_dataloader.py b/tests/graph/loader/test_temporal_dataloader.py new file mode 100644 index 000000000..e1a124d3b --- /dev/null +++ b/tests/graph/loader/test_temporal_dataloader.py @@ -0,0 +1,28 @@ +import pytest +# import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import TemporalGraph +from mindscience.sharker.loader import TemporalDataLoader + + +@pytest.mark.parametrize('batch_size,drop_last', [(4, True), (2, False)]) +def test_temporal_dataloader(batch_size, drop_last): + src = dst = t = ops.arange(10) + msg = ops.randn(10, 16) + + data = TemporalGraph(src=src, dst=dst, t=t, msg=msg) + + loader = TemporalDataLoader( + data, + batch_size=batch_size, + drop_last=drop_last, + ) + assert len(loader) == 10 // batch_size + + for i, batch in enumerate(loader): + assert len(batch) == batch_size + arange = list(range(len(batch) * i, len(batch) * i + len(batch))) + assert batch.src.tolist() == data.src[arange].tolist() + assert batch.dst.tolist() == data.dst[arange].tolist() + assert batch.t.tolist() == data.t[arange].tolist() + assert batch.msg.tolist() == data.msg[arange].tolist() diff --git a/tests/graph/loader/test_utils.py b/tests/graph/loader/test_utils.py new file mode 100644 index 000000000..f00ceffac --- /dev/null +++ b/tests/graph/loader/test_utils.py @@ -0,0 +1,16 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.loader.utils import index_select + + +def test_index_select(): + x = ops.randn(3, 5) + index = ms.Tensor([0, 2]) + assert ops.equal(index_select(x, index), x[index]).all() + assert ops.equal(index_select(x, index, dim=-1), x[..., index]).all() + + +# def test_index_select_out_of_range(): +# with pytest.raises(IndexError, match="out of range"): +# index_select(ops.randn(3, 5), ms.Tensor([0, 2, 3])) diff --git a/tests/graph/loader/test_zip_loader.py b/tests/graph/loader/test_zip_loader.py new file mode 100644 index 000000000..d1e3dbb33 --- /dev/null +++ b/tests/graph/loader/test_zip_loader.py @@ -0,0 +1,40 @@ +import pytest +# import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import NeighborLoader, ZipLoader + + +@pytest.mark.parametrize('filter_per_worker', [True]) # , False +def test_zip_loader(filter_per_worker): + x = ops.arange(100) + edge_index = ops.randint(0, 100, (2, 1000)) + data = Graph(x=x, edge_index=edge_index) + + loaders = [ + NeighborLoader(data, [5], input_nodes=ops.arange(0, 50)), + NeighborLoader(data, [5], input_nodes=ops.arange(50, 95)), + ] + + loader = ZipLoader(loaders, batch_size=10, + filter_per_worker=filter_per_worker) + + assert str(loader) == ('ZipLoader(loaders=[NeighborLoader(), ' + 'NeighborLoader()])') + assert len(loader) == 5 + assert loader.dataset == range(0, 45) + + for i, (batch1, batch2) in enumerate(loader): + n_id1 = batch1.n_id[:batch1.batch_size] + n_id2 = batch2.n_id[:batch2.batch_size] + + if i < 4: + assert batch1.batch_size == 10 + assert batch2.batch_size == 10 + assert ops.equal(n_id1, ops.arange(0 + i * 10, 10 + i * 10)).all() + assert ops.equal(n_id2, ops.arange(50 + i * 10, 60 + i * 10)).all() + else: + assert batch1.batch_size == 5 + assert batch2.batch_size == 5 + assert ops.equal(n_id1, ops.arange(0 + i * 10, 5 + i * 10)).all() + assert ops.equal(n_id2, ops.arange(50 + i * 10, 55 + i * 10)).all() diff --git a/tests/graph/metrics/test_link_pred_metric.py b/tests/graph/metrics/test_link_pred_metric.py new file mode 100644 index 000000000..0b5d1966c --- /dev/null +++ b/tests/graph/metrics/test_link_pred_metric.py @@ -0,0 +1,94 @@ +from typing import List + +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.metrics import LinkPredF1, LinkPredMAP, LinkPredNDCG, LinkPredPrecision, LinkPredRecall + + +@pytest.mark.parametrize('num_src_nodes', [100]) +@pytest.mark.parametrize('num_dst_nodes', [1000]) +@pytest.mark.parametrize('num_edges', [3000]) +@pytest.mark.parametrize('batch_size', [32]) +@pytest.mark.parametrize('k', [1, 10, 100]) +def test_precision(num_src_nodes, num_dst_nodes, num_edges, batch_size, k): + row = ops.randint(0, num_src_nodes, (num_edges, )) + col = ops.randint(0, num_dst_nodes, (num_edges, )) + edge_label_index = ops.stack(([row, col]), axis=0) + + pred = ops.rand(num_src_nodes, num_dst_nodes) + pred[row, col] += 0.3 # Offset positive links by a little. + top_k_pred_mat = pred.topk(k, dim=1)[1] + + metric = LinkPredPrecision(k) + assert str(metric) == f'LinkPredPrecision(k={k})' + + for node_id in ops.split(ops.shuffle(ops.arange(num_src_nodes)), batch_size): + mask = ms.numpy.isin(edge_label_index[0], node_id) + + y_batch, y_index = edge_label_index[:, mask] + # Remap `y_batch` back to `[0, batch_size - 1]` range: + arange = ms.numpy.empty(num_src_nodes, dtype=node_id.dtype) + arange[node_id] = ops.arange(node_id.numel()) + y_batch = arange[y_batch] + + metric.update(top_k_pred_mat[node_id], (y_batch, y_index)) + + out = metric.eval() + metric.clear() + + values: List[float] = [] + for i in range(num_src_nodes): # Naive computation per node: + y_index = col[row == i] + if y_index.numel() > 0: + mask = ms.numpy.isin(top_k_pred_mat[i], y_index) + precision = float(mask.sum() / k) + values.append(precision) + expected = ms.Tensor(values).mean() + assert ops.isclose(out.float(), expected).all() + + +def test_recall(): + pred_mat = ms.Tensor([[1, 0], [1, 2], [0, 2]]) + edge_label_index = ms.Tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) + + metric = LinkPredRecall(k=2) + assert str(metric) == 'LinkPredRecall(k=2)' + metric.update(pred_mat, edge_label_index) + result = metric.eval() + + assert float(result) == pytest.approx(0.5 * (2 / 3 + 0.5)) + + +def test_f1(): + pred_mat = ms.Tensor([[1, 0], [1, 2], [0, 2]]) + edge_label_index = ms.Tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) + + metric = LinkPredF1(k=2) + assert str(metric) == 'LinkPredF1(k=2)' + metric.update(pred_mat, edge_label_index) + result = metric.eval() + assert float(result) == pytest.approx(0.6500) + + +def test_map(): + pred_mat = ms.Tensor([[1, 0], [1, 2], [0, 2]]) + edge_label_index = ms.Tensor([[0, 0, 0, 2, 2], [0, 1, 2, 2, 1]]) + + metric = LinkPredMAP(k=2) + assert str(metric) == 'LinkPredMAP(k=2)' + metric.update(pred_mat, edge_label_index) + result = metric.eval() + assert float(result) == pytest.approx(0.6250) + + +def test_ndcg(): + pred_mat = ms.Tensor([[1, 0], [1, 2], [0, 2]]) + edge_label_index = ms.Tensor([[0, 0, 2, 2], [0, 1, 2, 1]]) + + metric = LinkPredNDCG(k=2) + assert str(metric) == 'LinkPredNDCG(k=2)' + metric.update(pred_mat, edge_label_index) + result = metric.eval() + + assert float(result) == pytest.approx(0.6934264) diff --git a/tests/graph/my_config.yaml b/tests/graph/my_config.yaml new file mode 100644 index 000000000..ef976f43b --- /dev/null +++ b/tests/graph/my_config.yaml @@ -0,0 +1,15 @@ +defaults: + - dataset: KarateClub + - transform@dataset.transform: + - NormalizeFeatures + - AddSelfLoops + - model: GCN + - optimizer: Adam + - lr_scheduler: ReduceLROnPlateau + - _self_ + +model: + in_channels: 34 + out_channels: 4 + hidden_channels: 16 + num_layers: 2 diff --git a/tests/graph/nn/aggr/test_aggr_utils.py b/tests/graph/nn/aggr/test_aggr_utils.py new file mode 100644 index 000000000..48fde3ceb --- /dev/null +++ b/tests/graph/nn/aggr/test_aggr_utils.py @@ -0,0 +1,66 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.nn.aggr.utils import ( + InducedSetAttentionBlock, + MultiheadAttentionBlock, + PoolingByMultiheadAttention, + SetAttentionBlock, +) + + +def test_multihead_attention_block(): + x = ops.randn(2, 4, 8) + y = ops.randn(2, 3, 8) + x_mask = ms.Tensor([[1, 1, 1, 1], [1, 1, 0, 0]]).bool() + y_mask = ms.Tensor([[1, 1, 0], [1, 1, 1]]).bool() + + block = MultiheadAttentionBlock(8, heads=2) + assert str(block) == ('MultiheadAttentionBlock(8, heads=2, ' + 'layer_norm=True, dropout=0.0)') + + out = block(x, y, x_mask, y_mask) + assert out.shape == (2, 4, 8) + + +def test_multihead_attention_block_dropout(): + x = ops.randn(2, 4, 8) + + block = MultiheadAttentionBlock(8, dropout=0.5) + block.set_train() + assert not mint.isclose(block(x, x), block(x, x)).all() + + +def test_set_attention_block(): + x = ops.randn(2, 4, 8) + mask = ms.Tensor([[1, 1, 1, 1], [1, 1, 0, 0]]).bool() + + block = SetAttentionBlock(8, heads=2) + assert str(block) == ('SetAttentionBlock(8, heads=2, layer_norm=True, ' + 'dropout=0.0)') + + out = block(x, mask) + assert out.shape == (2, 4, 8) + + +def test_induced_set_attention_block(): + x = ops.randn(2, 4, 8) + mask = ms.Tensor([[1, 1, 1, 1], [1, 1, 0, 0]]).bool() + + block = InducedSetAttentionBlock(8, num_induced_points=2, heads=2) + assert str(block) == ('InducedSetAttentionBlock(8, num_induced_points=2, ' + 'heads=2, layer_norm=True, dropout=0.0)') + + out = block(x, mask) + assert out.shape == (2, 4, 8) + + +def test_pooling_by_multihead_attention(): + x = ops.randn(2, 4, 8) + mask = ms.Tensor([[1, 1, 1, 1], [1, 1, 0, 0]]).bool() + + block = PoolingByMultiheadAttention(8, num_seed_points=2, heads=2) + assert str(block) == ('PoolingByMultiheadAttention(8, num_seed_points=2, ' + 'heads=2, layer_norm=True, dropout=0.0)') + + out = block(x, mask) + assert out.shape == (2, 2, 8) diff --git a/tests/graph/nn/aggr/test_attention.py b/tests/graph/nn/aggr/test_attention.py new file mode 100644 index 000000000..3eb07f5d7 --- /dev/null +++ b/tests/graph/nn/aggr/test_attention.py @@ -0,0 +1,24 @@ +import mindspore as ms +from mindspore import ops,mint +from mindscience.sharker.nn.models.mlp import MLP +from mindscience.sharker.nn.aggr import AttentionalAggregation +import pytest + +@pytest.mark.parametrize('dim', [2,3]) +def test_attentional_aggregation(): + channels = 16 + x = ops.randn(6, channels) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + ptr = ms.Tensor([0, 2, 5, 6]) + + gate_nn = MLP([channels, 1], act='relu') + nn = MLP([channels, channels], act='relu') + aggr = AttentionalAggregation(gate_nn, nn) + aggr.reset_parameters() + assert str(aggr) == (f'AttentionalAggregation(gate_nn=MLP({channels}, 1), ' + f'nn=MLP({channels}, {channels}))') + + out = aggr(x, index) + assert out.shape == (3, channels) + + assert mint.isclose(out, aggr(x, ptr=ptr), rtol=1e-04, atol=1e-4).all() diff --git a/tests/graph/nn/aggr/test_basic.py b/tests/graph/nn/aggr/test_basic.py new file mode 100644 index 000000000..e3ee9a59b --- /dev/null +++ b/tests/graph/nn/aggr/test_basic.py @@ -0,0 +1,95 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker import typing +from mindscience.sharker.nn.aggr import ( + MaxAggregation, + MeanAggregation, + MinAggregation, + MulAggregation, + PowerMeanAggregation, + SoftmaxAggregation, + StdAggregation, + SumAggregation, + VarAggregation, +) + + +def test_validate(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + ptr = ms.Tensor([0, 2, 5, 6]) + + aggr = MeanAggregation() + + with pytest.raises(ValueError, match="invalid dimension"): + aggr(x, index, dim=-3) + + with pytest.raises(ValueError, match="invalid 'dim_size'"): + aggr(x, ptr=ptr, dim_size=2) + + with pytest.raises(ValueError, match="invalid 'dim_size'"): + aggr(x, index, dim_size=2) + + +@pytest.mark.parametrize('Aggregation', [ + MeanAggregation, + SumAggregation, + MaxAggregation, + MinAggregation, + MulAggregation, + VarAggregation, + StdAggregation, +]) +def test_basic_aggregation(Aggregation): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + ptr = ms.Tensor([0, 2, 5, 6]) + + aggr = Aggregation() + assert str(aggr) == f'{Aggregation.__name__}()' + + out = aggr(x, index) + assert out.shape == (3, x.shape[1]) + + if isinstance(aggr, MulAggregation): + with pytest.raises(RuntimeError, match="requires 'index'"): + aggr(x, ptr=ptr) + else: + assert mint.isclose(out, aggr(x, ptr=ptr), rtol=1e-04, atol=1e-4).all() + + +def test_var_aggregation(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + + var_aggr = VarAggregation() + out = var_aggr(x, index) + + mean_aggr = MeanAggregation() + expected = mean_aggr((x - mean_aggr(x, index)[index]).pow(2), index) + assert mint.isclose(out, expected, rtol=1e-04, atol=1e-4).all() + + +@pytest.mark.parametrize('Aggregation', [ + SoftmaxAggregation, + PowerMeanAggregation, +]) +@pytest.mark.parametrize('learn', [True, False]) +def test_learnable_aggregation(Aggregation, learn): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + ptr = ms.Tensor([0, 2, 5, 6]) + + aggr = Aggregation(learn=learn) + assert str(aggr) == f'{Aggregation.__name__}(learn={learn})' + + out = aggr(x, index) + assert out.shape == (3, x.shape[1]) + assert mint.isclose(out, aggr(x, ptr=ptr), rtol=1e-04, atol=1e-4).all() + + grad_fn = ops.grad(lambda a, b: aggr(a, b).mean(), grad_position=None, weights=aggr.trainable_params()) + if learn: + grads = grad_fn(x, index) + for grad in grads: + assert not ops.isnan(grad).any() diff --git a/tests/graph/nn/aggr/test_deep_sets.py b/tests/graph/nn/aggr/test_deep_sets.py new file mode 100644 index 000000000..70dab1564 --- /dev/null +++ b/tests/graph/nn/aggr/test_deep_sets.py @@ -0,0 +1,18 @@ +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker.nn.aggr import DeepSetsAggregation + + +def test_deep_sets_aggregation(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + + aggr = DeepSetsAggregation( + local_nn=nn.Dense(16, 32), + global_nn=nn.Dense(32, 64), + ) + + assert str(aggr) == ('DeepSetsAggregation(local_nn=Dense(input_channels=16, output_channels=32, has_bias=True), global_nn=Dense(input_channels=32, output_channels=64, has_bias=True))') + + out = aggr(x, index) + assert out.shape == (3, 64) diff --git a/tests/graph/nn/aggr/test_equilibrium.py b/tests/graph/nn/aggr/test_equilibrium.py new file mode 100644 index 000000000..2a808263a --- /dev/null +++ b/tests/graph/nn/aggr/test_equilibrium.py @@ -0,0 +1,45 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.nn.aggr import EquilibriumAggregation + + +@pytest.mark.parametrize('iter', [0, 1, 5]) +@pytest.mark.parametrize('alpha', [0, .1, 5]) +def test_equilibrium(iter, alpha): + batch_size = 10 + feature_channels = 3 + output_channels = 2 + x = ops.randn((batch_size, feature_channels)) + model = EquilibriumAggregation(feature_channels, output_channels, + num_layers=[10, 10], grad_iter=iter) + + assert str(model) == 'EquilibriumAggregation()' + out = model(x) + assert out.shape == (1, 2) + + out = model(x, dim_size=3) + assert out.shape == (3, 2) + assert mint.all(out[1:, :] == 0) + + +@pytest.mark.parametrize('iter', [0, 1, 5]) +@pytest.mark.parametrize('alpha', [0, .1, 5]) +def test_equilibrium_batch(iter, alpha): + batch_1, batch_2 = 4, 6 + feature_channels = 3 + output_channels = 2 + x = ops.randn(batch_1 + batch_2, feature_channels) + batch = ms.Tensor([0 for _ in range(batch_1)] + + [1 for _ in range(batch_2)]) + + model = EquilibriumAggregation(feature_channels, output_channels, + num_layers=[10, 10], grad_iter=iter) + + assert str(model) == 'EquilibriumAggregation()' + out = model(x, batch) + assert out.shape == (2, 2) + + out = model(x, dim_size=3) + assert out.shape == (3, 2) + assert mint.all(out[1:, :] == 0) diff --git a/tests/graph/nn/aggr/test_fused.py b/tests/graph/nn/aggr/test_fused.py new file mode 100644 index 000000000..5ac83c87d --- /dev/null +++ b/tests/graph/nn/aggr/test_fused.py @@ -0,0 +1,74 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.nn.aggr.fused import FusedAggregation +from mindscience.sharker.nn.resolver import aggregation_resolver +from mindscience.sharker.profile import benchmark + + +@pytest.mark.parametrize('aggrs', [ + ['sum', 'mean', 'min', 'max', 'mul', 'var', 'std'], + ['sum', 'min', 'max', 'mul', 'var', 'std'], + ['min', 'max', 'mul', 'var', 'std'], + ['mean', 'min', 'max', 'mul', 'var', 'std'], + ['sum', 'min', 'max', 'mul', 'std'], + ['mean', 'min', 'max', 'mul', 'std'], + ['min', 'max', 'mul', 'std'], +]) +def test_fused_aggregation(aggrs): + aggrs = [aggregation_resolver(aggr) for aggr in aggrs] + + x = ops.randn(6, 1) + y = x.copy() + index = ms.Tensor([0, 0, 1, 1, 1, 3]) + + aggr = FusedAggregation(aggrs) + assert str(aggr) == 'FusedAggregation()' + out = mint.cat((aggr(x, index)), dim=-1) + + expected = mint.cat(([aggr(y, index) for aggr in aggrs]), dim=-1) + assert mint.isclose(out, expected, rtol=1e-04, atol=1e-4).all() + + grad_fn = ops.grad(lambda a, b: mint.cat((aggr(a, b)), dim=-1).mean(), grad_position=0, weights=None) + + grad_x = grad_fn(x, index) + assert grad_x is not None + grad_y = grad_fn(y, index) + assert grad_y is not None + assert mint.isclose(grad_x, grad_y, rtol=1e-04, atol=1e-4).all() + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--backward', action='store_true') + args = parser.parse_args() + + num_nodes, num_edges = 1_000, 50_000 + x = ops.randn(num_edges, 64) + index = ops.randint(0, num_nodes, (num_edges, )) + + aggrs = ['sum', 'mean', 'max', 'std'] + print(f'Aggregators: {", ".join(aggrs)}') + + aggrs = [aggregation_resolver(aggr) for aggr in aggrs] + fused_aggregation = FusedAggregation(aggrs) + + def naive_aggr(x, index, dim_size): + outs = [aggr(x, index, dim_size=dim_size) for aggr in aggrs] + return mint.cat((outs), dim=-1) + + def fused_aggr(x, index, dim_size): + outs = fused_aggregation(x, index, dim_size=dim_size) + return mint.cat((outs), dim=-1) + + benchmark( + funcs=[naive_aggr, fused_aggr], + func_names=['Naive', 'Fused'], + args=(x, index, num_nodes), + num_steps=100 if args.device == 'cpu' else 1000, + num_warmups=50 if args.device == 'cpu' else 500, + backward=args.backward, + ) diff --git a/tests/graph/nn/aggr/test_gmt.py b/tests/graph/nn/aggr/test_gmt.py new file mode 100644 index 000000000..e17d81a30 --- /dev/null +++ b/tests/graph/nn/aggr/test_gmt.py @@ -0,0 +1,17 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.aggr import GraphMultisetTransformer +from mindscience.sharker.testing import is_full_test + + +def test_graph_multiset_transformer(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + + aggr = GraphMultisetTransformer(16, k=2, heads=2) + assert str(aggr) == ('GraphMultisetTransformer(16, k=2, heads=2, ' + 'layer_norm=False, dropout=0.0)') + + out = aggr(x, index) + assert out.shape == (3, 16) + diff --git a/tests/graph/nn/aggr/test_gru.py b/tests/graph/nn/aggr/test_gru.py new file mode 100644 index 000000000..3932e89f4 --- /dev/null +++ b/tests/graph/nn/aggr/test_gru.py @@ -0,0 +1,14 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.aggr import GRUAggregation + + +def test_gru_aggregation(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + + aggr = GRUAggregation(16, 32) + assert str(aggr) == 'GRUAggregation(16, 32)' + + out = aggr(x, index) + assert out.shape == (3, 32) diff --git a/tests/graph/nn/aggr/test_lcm.py b/tests/graph/nn/aggr/test_lcm.py new file mode 100644 index 000000000..d1c963183 --- /dev/null +++ b/tests/graph/nn/aggr/test_lcm.py @@ -0,0 +1,72 @@ +from itertools import product + +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.aggr import LCMAggregation +from mindscience.sharker.profile import benchmark + + +def test_lcm_aggregation_with_project(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + + aggr = LCMAggregation(16, 32) + assert str(aggr) == 'LCMAggregation(16, 32, project=True)' + + out = aggr(x, index) + assert out.shape == (3, 32) + + +def test_lcm_aggregation_without_project(): + x = ops.randn(5, 16) + index = ms.Tensor([0, 1, 1, 2, 2]) + + aggr = LCMAggregation(16, 16, project=False) + assert str(aggr) == 'LCMAggregation(16, 16, project=False)' + + out = aggr(x, index) + assert out.shape == (3, 16) + + +def test_lcm_aggregation_error_handling(): + with pytest.raises(ValueError, match="must be projected"): + LCMAggregation(16, 32, project=False) + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--backward', action='store_true') + args = parser.parse_args() + + channels = 128 + batch_size_list = [2**i for i in range(10, 12)] + num_nodes_list = [2**i for i in range(15, 18)] + + aggr = LCMAggregation(channels, channels, project=False) + aggr = aggr.to(args.device) + + funcs = [] + func_names = [] + args_list = [] + for batch_size, num_nodes in product(batch_size_list, num_nodes_list): + x = ops.randn((num_nodes, channels)) + index = ops.randint(0, batch_size, (num_nodes, )) + index = index.sort()[0] + + funcs.append(aggr) + func_names.append(f'B={batch_size}, N={num_nodes}') + args_list.append((x, index)) + + benchmark( + funcs=funcs, + func_names=func_names, + args=args_list, + num_steps=10 if args.device == 'cpu' else 100, + num_warmups=5 if args.device == 'cpu' else 50, + backward=args.backward, + progress_bar=True, + ) diff --git a/tests/graph/nn/aggr/test_lstm.py b/tests/graph/nn/aggr/test_lstm.py new file mode 100644 index 000000000..006422c5e --- /dev/null +++ b/tests/graph/nn/aggr/test_lstm.py @@ -0,0 +1,18 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.aggr import LSTMAggregation + + +def test_lstm_aggregation(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + + aggr = LSTMAggregation(16, 32) + assert str(aggr) == 'LSTMAggregation(16, 32)' + + with pytest.raises(ValueError, match="is not sorted"): + aggr(x, ms.Tensor([0, 1, 0, 1, 2, 1])) + + out = aggr(x, index) + assert out.shape == (3, 32) diff --git a/tests/graph/nn/aggr/test_mlp_aggr.py b/tests/graph/nn/aggr/test_mlp_aggr.py new file mode 100644 index 000000000..1fb7497d4 --- /dev/null +++ b/tests/graph/nn/aggr/test_mlp_aggr.py @@ -0,0 +1,19 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.aggr import MLPAggregation + + +def test_mlp_aggregation(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + + aggr = MLPAggregation( + in_channels=16, + out_channels=32, + max_num_elements=3, + num_layers=1, + ) + assert str(aggr) == 'MLPAggregation(16, 32, max_num_elements=3)' + + out = aggr(x, index) + assert out.shape == (3, 32) diff --git a/tests/graph/nn/aggr/test_multi.py b/tests/graph/nn/aggr/test_multi.py new file mode 100644 index 000000000..bfd407c95 --- /dev/null +++ b/tests/graph/nn/aggr/test_multi.py @@ -0,0 +1,42 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.nn.aggr import MultiAggregation + + +@pytest.mark.parametrize('multi_aggr_tuple', [ + (dict(mode='cat'), 3), + (dict(mode='proj', mode_kwargs=dict(in_channels=16, out_channels=16)), 1), + (dict(mode='attn', mode_kwargs=dict(in_channels=16, out_channels=16, + num_heads=4)), 1), + (dict(mode='sum'), 1), + (dict(mode='mean'), 1), + (dict(mode='max'), 1), + (dict(mode='min'), 1), + (dict(mode='logsumexp'), 1), + (dict(mode='std'), 1), + (dict(mode='var'), 1), +]) +def test_multi_aggr(multi_aggr_tuple): + # The 'cat' combine mode will expand the output dimensions by + # the number of aggregators which is 3 here, while the other + # modes keep output dimensions unchanged. + aggr_kwargs, expand = multi_aggr_tuple + x = ops.randn(7, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2, 3]) + ptr = ms.Tensor([0, 2, 5, 6, 7]) + + aggrs = ['mean', 'sum', 'max'] + aggr = MultiAggregation(aggrs, **aggr_kwargs) + assert str(aggr) == ('MultiAggregation([\n' + ' MeanAggregation(),\n' + ' SumAggregation(),\n' + ' MaxAggregation(),\n' + f"], mode={aggr_kwargs['mode']})") + + out = aggr(x, index) + assert out.shape == (4, expand * x.shape[1]) + + assert mint.isclose(out, aggr(x, ptr=ptr), rtol=1e-04, atol=1e-4).all() + + diff --git a/tests/graph/nn/aggr/test_quantile.py b/tests/graph/nn/aggr/test_quantile.py new file mode 100644 index 000000000..8c50e1aba --- /dev/null +++ b/tests/graph/nn/aggr/test_quantile.py @@ -0,0 +1,111 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.aggr import MedianAggregation, QuantileAggregation +import numpy as np + + +@pytest.mark.parametrize('q', [0., .1, .2, .3, .4, .5, .6, .7, .8, .9, 1.]) +@pytest.mark.parametrize('interpolation', QuantileAggregation.interpolations) +@pytest.mark.parametrize('dim', [0, 1]) +@pytest.mark.parametrize('dim_size', [None, 15]) +@pytest.mark.parametrize('fill_value', [0.0, 10.0]) +def test_quantile_aggregation(q, interpolation, dim, dim_size, fill_value): + x = ms.Tensor([ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 0.0, 1.0], + [2.0, 3.0, 4.0], + [5.0, 6.0, 7.0], + [8.0, 9.0, 0.0], + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ]) + index = ops.zeros(x.shape[dim], dtype=ms.int64) + + aggr = QuantileAggregation(q=q, interpolation=interpolation, + fill_value=fill_value) + assert str(aggr) == f"QuantileAggregation(q={q})" + + out = aggr(x, index, dim=dim, dim_size=dim_size) + expected = np.quantile(x.asnumpy(), q, axis=dim, method=interpolation, keepdims=True) + assert np.allclose(out.narrow(dim, 0, 1).asnumpy(), expected) + + if out.shape[0] > index.max() + 1 and out.shape[dim] > 1: + padding = out.narrow(dim, 1, out.shape[dim] - 1) + assert ops.isclose(padding, ms.Tensor(fill_value)).all() + + +def test_median_aggregation(): + x = ms.Tensor([ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 0.0, 1.0], + [2.0, 3.0, 4.0], + [5.0, 6.0, 7.0], + [8.0, 9.0, 0.0], + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ]) + + aggr = MedianAggregation() + assert str(aggr) == "MedianAggregation()" + + index = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2]) + assert aggr(x, index).tolist() == [ + [3.0, 1.0, 2.0], + [5.0, 6.0, 4.0], + [4.0, 5.0, 6.0], + ] + + index = ms.Tensor([0, 1, 0]) + assert aggr(x, index, dim=1).tolist() == [ + [0.0, 1.0], + [3.0, 4.0], + [6.0, 7.0], + [1.0, 0.0], + [2.0, 3.0], + [5.0, 6.0], + [0.0, 9.0], + [1.0, 2.0], + [4.0, 5.0], + [7.0, 8.0], + ] + + +def test_quantile_aggregation_multi(): + x = ms.Tensor([ + [0.0, 1.0, 2.0], + [3.0, 4.0, 5.0], + [6.0, 7.0, 8.0], + [9.0, 0.0, 1.0], + [2.0, 3.0, 4.0], + [5.0, 6.0, 7.0], + [8.0, 9.0, 0.0], + [1.0, 2.0, 3.0], + [4.0, 5.0, 6.0], + [7.0, 8.0, 9.0], + ]) + index = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 2, 2, 2]) + + qs = [0.25, 0.5, 0.75] + + assert ops.isclose( + QuantileAggregation(qs)(x, index), + ops.cat(([QuantileAggregation(q)(x, index) for q in qs]), axis=-1), + ).all() + + +def test_quantile_aggregation_validate(): + with pytest.raises(ValueError, match="at least one quantile"): + QuantileAggregation(q=[]) + + with pytest.raises(ValueError, match="must be in the range"): + QuantileAggregation(q=-1) + + with pytest.raises(ValueError, match="Invalid interpolation method"): + QuantileAggregation(q=0.5, interpolation=None) diff --git a/tests/graph/nn/aggr/test_scaler.py b/tests/graph/nn/aggr/test_scaler.py new file mode 100644 index 000000000..f158f3005 --- /dev/null +++ b/tests/graph/nn/aggr/test_scaler.py @@ -0,0 +1,25 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.aggr import DegreeScalerAggregation + + +@pytest.mark.parametrize('train_norm', [True, False]) +def test_degree_scaler_aggregation(train_norm): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 2]) + ptr = ms.Tensor([0, 2, 5, 6]) + deg = ms.Tensor([0, 3, 0, 1, 1, 0]) + + aggr = ['mean', 'sum', 'max'] + scaler = [ + 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear' + ] + aggr = DegreeScalerAggregation(aggr, scaler, deg, train_norm=train_norm) + assert str(aggr) == 'DegreeScalerAggregation()' + + out = aggr(x, index) + assert out.shape == (3, 240) + + with pytest.raises(NotImplementedError, match="requires 'index'"): + aggr(x, ptr=ptr) diff --git a/tests/graph/nn/aggr/test_set2set.py b/tests/graph/nn/aggr/test_set2set.py new file mode 100644 index 000000000..c985ade6d --- /dev/null +++ b/tests/graph/nn/aggr/test_set2set.py @@ -0,0 +1,28 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.nn.aggr import Set2Set + + +def test_set2set(): + set2set = Set2Set(in_channels=2, processing_steps=1) + assert str(set2set) == 'Set2Set(2, 4)' + + N = 4 + x_1, batch_1 = ops.randn(N, 2), mint.zeros(N, dtype=ms.int64) + out_1 = set2set(x_1, batch_1).view(-1) + + N = 6 + x_2, batch_2 = ops.randn(N, 2), mint.zeros(N, dtype=ms.int64) + out_2 = set2set(x_2, batch_2).view(-1) + + x, batch = mint.cat([x_1, x_2]), mint.cat([batch_1, batch_2 + 1]) + out = set2set(x, batch) + assert out.shape == (2, 4) + assert mint.isclose(out_1, out[0]).all() + assert mint.isclose(out_2, out[1]).all() + + x, batch = mint.cat([x_2, x_1]), mint.cat([batch_2, batch_1 + 1]) + out = set2set(x, batch) + assert out.shape == (2, 4) + assert mint.isclose(out_1, out[1]).all() + assert mint.isclose(out_2, out[0]).all() diff --git a/tests/graph/nn/aggr/test_set_transformer.py b/tests/graph/nn/aggr/test_set_transformer.py new file mode 100644 index 000000000..9aa4eeff2 --- /dev/null +++ b/tests/graph/nn/aggr/test_set_transformer.py @@ -0,0 +1,19 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.nn.aggr import SetTransformerAggregation +from mindscience.sharker.testing import is_full_test + + +def test_set_transformer_aggregation(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 3]) + + aggr = SetTransformerAggregation(16, num_seed_points=2, heads=2) + assert str(aggr) == ('SetTransformerAggregation(16, num_seed_points=2, ' + 'heads=2, layer_norm=False, dropout=0.0)') + + out = aggr(x, index) + assert out.shape == (4, 2 * 16) + assert out.isnan().sum() == 0 + assert out[2].abs().sum() == 0 + diff --git a/tests/graph/nn/aggr/test_sort.py b/tests/graph/nn/aggr/test_sort.py new file mode 100644 index 000000000..1d81a55f9 --- /dev/null +++ b/tests/graph/nn/aggr/test_sort.py @@ -0,0 +1,72 @@ +import mindspore as ms +from mindspore import ops,mint +from mindscience.sharker.nn.aggr import SortAggregation + + +def test_sort_aggregation(): + N_1, N_2 = 4, 6 + x = ops.randn(N_1 + N_2, 4) + index = ms.Tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)]) + + aggr = SortAggregation(k=5) + assert str(aggr) == 'SortAggregation(k=5)' + + out = aggr(x, index) + assert out.shape == (2, 5 * 4) + + out_dim = out = aggr(x, index, dim=0) + assert mint.isclose(out_dim, out).all() + + out = out.view(2, 5, 4) + + # First graph output has been filled up with zeros. + assert out[0, -1].tolist() == [0, 0, 0, 0] + + # Nodes are sorted. + assert ops.equal(out[0, :4, -1].argsort(), 3 - mint.arange(4)).all() + assert ops.equal(out[1, :, -1].argsort(), 4 - mint.arange(5)).all() + + +def test_sort_aggregation_smaller_than_k(): + N_1, N_2 = 4, 6 + x = ops.randn(N_1 + N_2, 4) + index = ms.Tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)]) + + # Set k which is bigger than both N_1=4 and N_2=6. + aggr = SortAggregation(k=10) + assert str(aggr) == 'SortAggregation(k=10)' + + out = aggr(x, index) + assert out.shape == (2, 10 * 4) + + out_dim = out = aggr(x, index, dim=0) + assert mint.isclose(out_dim, out).all() + + out = out.view(2, 10, 4) + + # Both graph outputs have been filled up with zeros. + assert out[0, -1].tolist() == [0, 0, 0, 0] + assert out[1, -1].tolist() == [0, 0, 0, 0] + + # Nodes are sorted. + assert ops.equal(out[0, :4, -1].argsort(), 3 - mint.arange(4)).all() + assert ops.equal(out[1, :6, -1].argsort(), 5 - mint.arange(6)).all() + + +def test_sort_aggregation_dim_size(): + N_1, N_2 = 4, 6 + x = ops.randn(N_1 + N_2, 4) + index = ms.Tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)]) + + aggr = SortAggregation(k=5) + assert str(aggr) == 'SortAggregation(k=5)' + + # expand batch output by 1 + out = aggr(x, index, dim_size=3) + assert out.shape == (3, 5 * 4) + + out = out.view(3, 5, 4) + + # Both first and last graph outputs have been filled up with zeros. + assert out[0, -1].tolist() == [0, 0, 0, 0] + assert out[2, -1].tolist() == [0, 0, 0, 0] diff --git a/tests/graph/nn/aggr/test_variance_preserving.py b/tests/graph/nn/aggr/test_variance_preserving.py new file mode 100644 index 000000000..34dbede83 --- /dev/null +++ b/tests/graph/nn/aggr/test_variance_preserving.py @@ -0,0 +1,29 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.nn.aggr import ( + MeanAggregation, + SumAggregation, + VariancePreservingAggregation, +) + + +def test_variance_preserving(): + x = ops.randn(6, 16) + index = ms.Tensor([0, 0, 1, 1, 1, 3]) + ptr = ms.Tensor([0, 2, 5, 5, 6]) + + vpa_aggr = VariancePreservingAggregation() + mean_aggr = MeanAggregation() + sum_aggr = SumAggregation() + + out_vpa = vpa_aggr(x, index) + out_mean = mean_aggr(x, index) + out_sum = sum_aggr(x, index) + + # Equivalent formulation: + expected = mint.sqrt(out_mean.abs() * out_sum.abs()) * out_sum.sign() + + assert out_vpa.shape == (4, 16) + assert mint.isclose(out_vpa, expected, rtol=1e-04, atol=1e-4).all() + + assert mint.isclose(out_vpa, vpa_aggr(x, ptr=ptr), rtol=1e-04, atol=1e-4).all() diff --git a/tests/graph/nn/attn/test_performer.py b/tests/graph/nn/attn/test_performer.py new file mode 100644 index 000000000..5e9d89247 --- /dev/null +++ b/tests/graph/nn/attn/test_performer.py @@ -0,0 +1,12 @@ +from mindspore import ops +from mindscience.sharker.nn.attn import PerformerAttention + + +def test_performer_attention(): + x = ops.randn(1, 4, 16) + mask = ops.ones([1, 4]).bool() + attn = PerformerAttention(channels=16, heads=4) + out = attn(x, mask) + assert out.shape == (1, 4, 16) + assert str(attn) == ('PerformerAttention(heads=4, ' + 'head_channels=64 kernel=ReLU<>)') diff --git a/tests/graph/nn/conv/test_agnn_conv.py b/tests/graph/nn/conv/test_agnn_conv.py new file mode 100644 index 000000000..b3f635376 --- /dev/null +++ b/tests/graph/nn/conv/test_agnn_conv.py @@ -0,0 +1,32 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import AGNNConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.utils import to_csr +from mindscience.sharker.sparse import SparseTensor +from mindscience.sharker import typing, seed_everything + + +@pytest.mark.parametrize('requires_grad', [True, False]) +def test_agnn_conv(requires_grad): + seed_everything(0) + + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = AGNNConv(requires_grad=requires_grad) + assert str(conv) == 'AGNNConv()' + out = conv(x, edge_index) + assert out.shape == (4, 16) + + assert ops.isclose(conv(x, adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_antisymmetric_conv.py b/tests/graph/nn/conv/test_antisymmetric_conv.py new file mode 100644 index 000000000..1878923ed --- /dev/null +++ b/tests/graph/nn/conv/test_antisymmetric_conv.py @@ -0,0 +1,32 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import AntiSymmetricConv +from mindscience.sharker.utils import to_csr +from mindscience.sharker import typing +from mindscience.sharker.sparse import SparseTensor + + +def test_antisymmetric_conv(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = AntiSymmetricConv(8) + assert str(conv) == ('AntiSymmetricConv(8, phi=GCNConv(8, 8), ' + 'num_iters=1, epsilon=0.1, gamma=0.1)') + + out1 = conv(x, edge_index) + assert out1.shape == (4, 8) + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 8) + assert ops.isclose(conv(x, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_appnp.py b/tests/graph/nn/conv/test_appnp.py new file mode 100644 index 000000000..a42384ee1 --- /dev/null +++ b/tests/graph/nn/conv/test_appnp.py @@ -0,0 +1,56 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import APPNP +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_appnp(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = APPNP(K=3, alpha=0.1, cached=True) + assert str(conv) == 'APPNP(K=3, alpha=0.1)' + out = conv(x, edge_index) + assert out.shape == (4, 16) + assert ops.isclose(conv(x, adj1.t()), out).all() + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, adj2.t()), out).all() + + # Run again to test the cached functionality: + assert conv._cached_edge_index is not None + assert ops.isclose(conv(x, edge_index), conv(x, adj1.t())).all() + if typing.WITH_SPARSE: + assert conv._cached_adj_t is not None + assert ops.isclose(conv(x, edge_index), conv(x, adj2.t())).all() + + conv.reset_parameters() + assert conv._cached_edge_index is None + assert conv._cached_adj_t is None + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj2.t()), out).all() + + +def test_appnp_dropout(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + # With dropout probability of 1.0, the final output equals to alpha * x: + conv = APPNP(K=2, alpha=0.1, dropout=0.999) + conv.set_train() + assert ops.isclose(0.1 * x, conv(x, edge_index)).all() + assert ops.isclose(0.1 * x, conv(x, adj1.t())).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(0.1 * x, conv(x, adj2.t())).all() diff --git a/tests/graph/nn/conv/test_arma_conv.py b/tests/graph/nn/conv/test_arma_conv.py new file mode 100644 index 000000000..2afe57ca3 --- /dev/null +++ b/tests/graph/nn/conv/test_arma_conv.py @@ -0,0 +1,53 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import ARMAConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_arma_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = ARMAConv(16, 32, num_stacks=8, num_layers=4) + assert str(conv) == 'ARMAConv(16, 32, num_stacks=8, num_layers=4)' + out = conv(x, edge_index) + assert out.shape == (4, 32) + with pytest.raises(AssertionError): # No 3D feature tensor support. + assert ops.isclose(conv(x, adj1.t()), out).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, adj2.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj2.t()), out, atol=1e-6).all() + + +# def test_lazy_arma_conv(): +# x = ops.randn(4, 16) +# edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + +# conv = ARMAConv(-1, 32, num_stacks=8, num_layers=4) +# assert str(conv) == 'ARMAConv(-1, 32, num_stacks=8, num_layers=4)' +# out = conv(x, edge_index) +# assert out.shape == (4, 32) + + # if typing.WITH_SPARSE: + # adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + # assert ops.isclose(conv(x, adj2.t()), out).all() + + # if is_full_test(): + # jit = ms.jit(conv) + # assert ops.isclose(jit(x, edge_index), out).all() + + # if typing.WITH_SPARSE: + # assert ops.isclose(jit(x, adj2.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_cg_conv.py b/tests/graph/nn/conv/test_cg_conv.py new file mode 100644 index 000000000..c6f1f90e0 --- /dev/null +++ b/tests/graph/nn/conv/test_cg_conv.py @@ -0,0 +1,92 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import CGConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +@pytest.mark.parametrize('batch_norm', [False, True]) +def test_cg_conv(batch_norm): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = CGConv(8, batch_norm=batch_norm) + assert str(conv) == 'CGConv(8, dim=0)' + out = conv(x1, edge_index) + assert out.shape == (4, 8) + assert ops.isclose(conv(x1, adj1.t()), out).all() + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + conv = CGConv((8, 16)) + assert str(conv) == 'CGConv((8, 16), dim=0)' + out = conv((x1, x2), edge_index) + assert out.shape == (2, 16) + assert ops.isclose(conv((x1, x2), adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out, atol=1e-6).all() + + +def test_cg_conv_with_edge_features(): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.rand(edge_index.shape[1], 3) + + conv = CGConv(8, dim=3) + assert str(conv) == 'CGConv(8, dim=3)' + out = conv(x1, edge_index, value) + assert out.shape == (4, 8) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x1, adj.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index, value), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj.t()), out).all() + + # Test bipartite message passing: + conv = CGConv((8, 16), dim=3) + assert str(conv) == 'CGConv((8, 16), dim=3)' + out = conv((x1, x2), edge_index, value) + assert out.shape == (2, 16) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + assert ops.isclose(conv((x1, x2), adj.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index, value), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj.t()), out).all() diff --git a/tests/graph/nn/conv/test_cheb_conv.py b/tests/graph/nn/conv/test_cheb_conv.py new file mode 100644 index 000000000..60a89f50c --- /dev/null +++ b/tests/graph/nn/conv/test_cheb_conv.py @@ -0,0 +1,71 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Batch, Graph +from mindscience.sharker.nn import ChebConv +from mindscience.sharker.testing import is_full_test + + +def test_cheb_conv(): + in_channels, out_channels = (16, 32) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + num_nodes = edge_index.max().item() + 1 + edge_weight = ops.rand(edge_index.shape[1]) + x = ops.randn((num_nodes, in_channels)) + + conv = ChebConv(in_channels, out_channels, K=3) + assert str(conv) == 'ChebConv(16, 32, K=3, normalization=sym)' + out1 = conv(x, edge_index) + assert out1.shape == (num_nodes, out_channels) + out2 = conv(x, edge_index, edge_weight) + assert out2.shape == (num_nodes, out_channels) + out3 = conv(x, edge_index, edge_weight, lambda_max=3.0) + assert out3.shape == (num_nodes, out_channels) + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out1).all() + assert ops.isclose(jit(x, edge_index, edge_weight), out2).all() + assert ops.isclose( + jit(x, edge_index, edge_weight, lambda_max=ms.Tensor(3.0)), + out3).all() + + batch = ms.Tensor([0, 0, 1, 1]) + edge_index = ms.Tensor([[0, 1, 2, 3], [1, 0, 3, 2]]) + num_nodes = edge_index.max().item() + 1 + edge_weight = ops.rand(edge_index.shape[1]) + x = ops.randn((num_nodes, in_channels)) + lambda_max = ms.Tensor([2.0, 3.0]) + + out4 = conv(x, edge_index, edge_weight, batch) + assert out4.shape == (num_nodes, out_channels) + out5 = conv(x, edge_index, edge_weight, batch, lambda_max) + assert out5.shape == (num_nodes, out_channels) + + if is_full_test(): + assert ops.isclose(ms.jit(x, edge_index, edge_weight, batch), out4).all() + assert ops.iscloselose( + ms.jit(x, edge_index, edge_weight, batch, lambda_max), out5).all() + + +def test_cheb_conv_batch(): + x1 = ops.randn(4, 8) + edge_index1 = ms.Tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + edge_weight1 = ops.rand(edge_index1.shape[1]) + data1 = Graph(x=x1, edge_index=edge_index1, edge_weight=edge_weight1) + + x2 = ops.randn(3, 8) + edge_index2 = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight2 = ops.rand(edge_index2.shape[1]) + data2 = Graph(x=x2, edge_index=edge_index2, edge_weight=edge_weight2) + + conv = ChebConv(8, 16, K=2) + + out1 = conv(x1, edge_index1, edge_weight1) + out2 = conv(x2, edge_index2, edge_weight2) + + batch = Batch.from_data_list([data1, data2]) + out = conv(batch.x, batch.edge_index, batch.edge_weight, batch.batch) + + assert out.shape == (7, 16) + assert ops.isclose(out1, out[:4], atol=1e-6).all() + assert ops.isclose(out2, out[4:], atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_cluster_gcn_conv.py b/tests/graph/nn/conv/test_cluster_gcn_conv.py new file mode 100644 index 000000000..be8f40827 --- /dev/null +++ b/tests/graph/nn/conv/test_cluster_gcn_conv.py @@ -0,0 +1,30 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import ClusterGCNConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_cluster_gcn_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = ClusterGCNConv(16, 32, diag_lambda=1.) + assert str(conv) == 'ClusterGCNConv(16, 32, diag_lambda=1.0)' + out = conv(x, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x, adj1.t()), out, atol=1e-5).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, adj2.t()), out, atol=1e-5).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj2.t()), out, atol=1e-5).all() diff --git a/tests/graph/nn/conv/test_create_gnn.py b/tests/graph/nn/conv/test_create_gnn.py new file mode 100644 index 000000000..9c1e94132 --- /dev/null +++ b/tests/graph/nn/conv/test_create_gnn.py @@ -0,0 +1,36 @@ +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker.nn import MessagePassing +from mindscience.sharker.utils import add_self_loops, degree + + +class GCNConv(MessagePassing): + def __init__(self, in_channels, out_channels): + super().__init__(aggr='add') + self.lin = nn.Dense(in_channels, out_channels) + + def construct(self, x, edge_index): + edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0]) + + row, col = edge_index + deg = degree(row, x.shape[0], dtype=x.dtype) + deg_inv_sqrt = deg.pow(-0.5) + norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] + + x = self.lin(x) + return self.propagate(edge_index, shape=(x.shape[0], x.shape[0]), x=x, + norm=norm) + + def message(self, x_j, norm): + return norm.view(-1, 1) * x_j + + def update(self, aggr_out): + return aggr_out + + +def test_create_gnn(): + conv = GCNConv(16, 32) + x = ops.randn(5, 16) + edge_index = ops.randint(0, 5, (2, 64), dtype=ms.int64) + out = conv(x, edge_index) + assert out.shape == (5, 32) diff --git a/tests/graph/nn/conv/test_dir_gnn_conv.py b/tests/graph/nn/conv/test_dir_gnn_conv.py new file mode 100644 index 000000000..8df95fbd6 --- /dev/null +++ b/tests/graph/nn/conv/test_dir_gnn_conv.py @@ -0,0 +1,24 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import DirGNNConv, SAGEConv + + +def test_dir_gnn_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 2], [1, 2, 3]]) + + conv = DirGNNConv(SAGEConv(16, 32)) + assert str(conv) == 'DirGNNConv(SAGEConv(16, 32, aggr=mean), alpha=0.5)' + + out = conv(x, edge_index) + assert out.shape == (4, 32) + + +def test_static_dir_gnn_conv(): + x = ops.randn(3, 4, 16) + edge_index = ms.Tensor([[0, 1, 2], [1, 2, 3]]) + + conv = DirGNNConv(SAGEConv(16, 32)) + + out = conv(x, edge_index) + assert out.shape == (3, 4, 32) diff --git a/tests/graph/nn/conv/test_dna_conv.py b/tests/graph/nn/conv/test_dna_conv.py new file mode 100644 index 000000000..124c883b0 --- /dev/null +++ b/tests/graph/nn/conv/test_dna_conv.py @@ -0,0 +1,88 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import DNAConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +@pytest.mark.parametrize('channels', [32]) +@pytest.mark.parametrize('num_layers', [3]) +def test_dna_conv(channels, num_layers): + x = ops.randn((4, num_layers, channels)) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + + conv = DNAConv(channels, heads=4, groups=8, dropout=0.0) + assert str(conv) == 'DNAConv(32, heads=4, groups=8)' + out = conv(x, edge_index) + assert out.shape == (4, channels) + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out, atol=1e-6).all() + + conv = DNAConv(channels, heads=1, groups=1, dropout=0.0) + assert str(conv) == 'DNAConv(32, heads=1, groups=1)' + out = conv(x, edge_index) + assert out.shape == (4, channels) + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out, atol=1e-6).all() + + conv = DNAConv(channels, heads=1, groups=1, dropout=0.0, cached=True) + out = conv(x, edge_index) + assert conv._cached_edge_index is not None + out = conv(x, edge_index) + assert out.shape == (4, channels) + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out, atol=1e-6).all() + + +@pytest.mark.parametrize('channels', [32]) +@pytest.mark.parametrize('num_layers', [3]) +def test_dna_conv_sparse_tensor(channels, num_layers): + x = ops.randn((4, num_layers, channels)) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = DNAConv(32, heads=4, groups=8, dropout=0.0) + assert str(conv) == 'DNAConv(32, heads=4, groups=8)' + out1 = conv(x, edge_index) + assert out1.shape == (4, 32) + # assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 32) + assert ops.isclose(conv(x, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit(x, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x, adj4.t()), out2, atol=1e-6).all() + + conv = DNAConv(channels, heads=1, groups=1, dropout=0.0, cached=True) + + out1 = conv(x, adj1.t()) + assert conv._cached_edge_index is not None + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + assert conv._cached_adj_t is not None + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_edge_conv.py b/tests/graph/nn/conv/test_edge_conv.py new file mode 100644 index 000000000..dbef51d9c --- /dev/null +++ b/tests/graph/nn/conv/test_edge_conv.py @@ -0,0 +1,96 @@ +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Lin +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Seq + +from mindscience.sharker import typing +from mindscience.sharker.nn import DynamicEdgeConv, EdgeConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_edge_conv_conv(): + x1 = ops.randn(4, 16) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + nn = Seq(Lin(32, 16), ReLU(), Lin(16, 32)) + conv = EdgeConv(nn) + assert str(conv) == ( + 'EdgeConv(nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n ' + '(2): Dense\n >)' + ) + out = conv(x1, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv((x1, x1), edge_index), out, atol=1e-6).all() + assert ops.isclose(conv(x1, adj1.t()), out, atol=1e-6).all() + assert ops.isclose(conv((x1, x1), adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out, atol=1e-6).all() + assert ops.isclose(conv((x1, x1), adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out, atol=1e-6).all() + assert ops.isclose(jit((x1, x1), edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out, atol=1e-6).all() + assert ops.isclose(jit((x1, x1), adj2.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + out = conv((x1, x2), edge_index) + assert out.shape == (2, 32) + assert ops.isclose(conv((x1, x2), adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + assert ops.isclose(jit((x1, x2), edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out, atol=1e-6).all() + + +def test_dynamic_edge_conv(): + x1 = ops.randn(8, 16) + x2 = ops.randn(4, 16) + batch1 = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 1]) + batch2 = ms.Tensor([0, 0, 1, 1]) + + nn = Seq(Lin(32, 16), ReLU(), Lin(16, 32)) + conv = DynamicEdgeConv(nn, k=2) + assert str(conv) == ( + 'DynamicEdgeConv(nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n ' + '(2): Dense\n >, k=2)') + out11 = conv(x1) + assert out11.shape == (8, 32) + + out12 = conv(x1, batch1) + assert out12.shape == (8, 32) + + out21 = conv((x1, x2)) + assert out21.shape == (4, 32) + + out22 = conv((x1, x2), (batch1, batch2)) + assert out22.shape == (4, 32) + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1), out11).all() + assert ops.isclose(jit(x1, batch1), out12).all() + assert ops.isclose(jit((x1, x2)), out21).all() + assert ops.isclose(jit((x1, x2), (batch1, batch2)), out22).all() diff --git a/tests/graph/nn/conv/test_eg_conv.py b/tests/graph/nn/conv/test_eg_conv.py new file mode 100644 index 000000000..97e6084d1 --- /dev/null +++ b/tests/graph/nn/conv/test_eg_conv.py @@ -0,0 +1,68 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import EGConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_eg_conv_with_error(): + with pytest.raises(ValueError, match="must be divisible by the number of"): + EGConv(16, 30, num_heads=8) + + with pytest.raises(ValueError, match="Unsupported aggregator"): + EGConv(16, 32, aggregators=['xxx']) + + +@pytest.mark.parametrize('aggregators', [ + ['symnorm'], + ['sum', 'symnorm', 'std'], +]) +@pytest.mark.parametrize('add_self_loops', [True, False]) +def test_eg_conv(aggregators, add_self_loops): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = EGConv( + in_channels=16, + out_channels=32, + aggregators=aggregators, + add_self_loops=add_self_loops, + ) + assert str(conv) == f"EGConv(16, 32, aggregators={aggregators})" + out = conv(x, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x, adj1.t()), out, atol=1e-2).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, adj2.t()), out, atol=1e-2).all() + + conv.cached = True + assert ops.isclose(conv(x, edge_index), out, atol=1e-2).all() + assert conv._cached_edge_index is not None + assert ops.isclose(conv(x, edge_index), out, atol=1e-2).all() + assert ops.isclose(conv(x, adj1.t()), out, atol=1e-2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(conv(x, adj2.t()), out, atol=1e-2).all() + assert conv._cached_adj_t is not None + assert ops.isclose(conv(x, adj2.t()), out, atol=1e-2).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out, atol=1e-2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj2.t()), out, atol=1e-2).all() + + +def test_eg_conv_with_sparse_input_feature(): + x = ops.randn(4, 16).to_coo() + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + + conv = EGConv(16, 32) + assert conv(x, edge_index).shape == (4, 32) diff --git a/tests/graph/nn/conv/test_fa_conv.py b/tests/graph/nn/conv/test_fa_conv.py new file mode 100644 index 000000000..12ea46285 --- /dev/null +++ b/tests/graph/nn/conv/test_fa_conv.py @@ -0,0 +1,125 @@ +from typing import Tuple + +import mindspore as ms +from mindspore import Tensor, ops, nn + +from mindscience.sharker import typing +from mindscience.sharker.nn import FAConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import Adj, SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_fa_conv(): + x = ops.randn(4, 16) + x_0 = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = FAConv(16, eps=1.0, cached=True) + assert str(conv) == 'FAConv(16, eps=1.0)' + out = conv(x, x_0, edge_index) + assert conv._cached_edge_index is not None + assert out.shape == (4, 16) + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, x_0, adj2.t()), out, atol=1e-6).all() + assert conv._cached_adj_t is not None + assert ops.isclose(conv(x, x_0, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + x_0: Tensor, + edge_index: Adj, + ) -> Tensor: + return self.conv(x, x_0, edge_index) + + jit = ms.jit(MyModule()) + assert ops.isclose(jit(x, x_0, edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, x_0, adj2.t()), out).all() + + conv.reset_parameters() + assert conv._cached_edge_index is None + assert conv._cached_adj_t is None + + # Test without caching: + conv.cached = False + out = conv(x, x_0, edge_index) + assert ops.isclose(conv(x, x_0, adj1.t()), out, atol=1e-6).all() + if typing.WITH_SPARSE: + assert ops.isclose(conv(x, x_0, adj2.t()), out, atol=1e-6).all() + + # Test `return_attention_weights`: + result = conv(x, x_0, edge_index, return_attention_weights=True) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1][0].shape == (2, 10) + assert result[1][1].shape == (10, ) + assert conv._alpha is None + + result = conv(x, x_0, adj1.t(), return_attention_weights=True) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1][0].shape == (4, 4) + # assert result[1][0]._nnz() == 10 + assert conv._alpha is None + + if typing.WITH_SPARSE: + result = conv(x, x_0, adj2.t(), return_attention_weights=True) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1].shape == [4, 4] and result[1].nnz() == 10 + assert conv._alpha is None + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + x_0: Tensor, + edge_index: Tensor, + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + return self.conv(x, x_0, edge_index, + return_attention_weights=True) + + jit = ms.jit(MyModule()) + result = jit(x, x_0, edge_index) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1][0].shape == (2, 10) + assert result[1][1].shape == (10, ) + assert conv._alpha is None + + if typing.WITH_SPARSE: + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + x_0: Tensor, + edge_index: SparseTensor, + ) -> Tuple[Tensor, SparseTensor]: + return self.conv(x, x_0, edge_index, + return_attention_weights=True) + + jit = ms.jit(MyModule()) + result = jit(x, x_0, adj2.t()) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1].shape == [4, 4] and result[1].nnz() == 10 + assert conv._alpha is None diff --git a/tests/graph/nn/conv/test_feast_conv.py b/tests/graph/nn/conv/test_feast_conv.py new file mode 100644 index 000000000..77b65cd36 --- /dev/null +++ b/tests/graph/nn/conv/test_feast_conv.py @@ -0,0 +1,49 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import FeaStConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_feast_conv(): + x1 = ops.randn(4, 16) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = FeaStConv(16, 32, heads=2) + assert str(conv) == 'FeaStConv(16, 32, heads=2)' + + out = conv(x1, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + out = conv((x1, x2), edge_index) + assert out.shape == (2, 32) + assert ops.isclose(conv((x1, x2), adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + assert ops.isclose(jit((x1, x2), edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_film_conv.py b/tests/graph/nn/conv/test_film_conv.py new file mode 100644 index 000000000..77b0a256c --- /dev/null +++ b/tests/graph/nn/conv/test_film_conv.py @@ -0,0 +1,79 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import FiLMConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +def test_film_conv(): + x1 = ops.randn(4, 4) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3], [0, 0, 1, 0, 1, 1]]) + edge_type = ms.Tensor([0, 1, 1, 0, 0, 1]) + + conv = FiLMConv(4, 32) + assert str(conv) == 'FiLMConv(4, 32, num_relations=1)' + out = conv(x1, edge_index) + assert out.shape == (4, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj.t()), out, atol=1e-6).all() + + conv = FiLMConv(4, 32, num_relations=2) + assert str(conv) == 'FiLMConv(4, 32, num_relations=2)' + out = conv(x1, edge_index, edge_type) + assert out.shape == (4, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 4)) + assert ops.isclose(conv(x1, adj.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index, edge_type), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + conv = FiLMConv((4, 16), 32) + assert str(conv) == 'FiLMConv((4, 16), 32, num_relations=1)' + out = conv((x1, x2), edge_index) + assert out.shape == (2, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj.t()), out, atol=1e-6).all() + + conv = FiLMConv((4, 16), 32, num_relations=2) + assert str(conv) == 'FiLMConv((4, 16), 32, num_relations=2)' + out = conv((x1, x2), edge_index, edge_type) + assert out.shape == (2, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 2)) + assert ops.isclose(conv((x1, x2), adj.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index, edge_type).all(), out, + atol=1e-6) + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_fused_gat_conv.py b/tests/graph/nn/conv/test_fused_gat_conv.py new file mode 100644 index 000000000..a979ece0c --- /dev/null +++ b/tests/graph/nn/conv/test_fused_gat_conv.py @@ -0,0 +1,35 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import FusedGATConv +from mindscience.sharker.testing import withPackage + + +def test_to_graph_format() -> None: + edge_index = ms.Tensor([[1, 0, 2, 3], [0, 0, 1, 1]]) + + csr, csc, perm = FusedGATConv.to_graph_format(edge_index, size=(4, 4)) + + assert csr[0].dtype == ms.int64 + assert ops.equal(csr[0], ms.Tensor([0, 1, 2, 3, 4], dtype=ms.int64)).all() + assert csr[1].dtype == ms.int64 + assert ops.equal(csr[1], ms.Tensor([0, 0, 1, 1], dtype=ms.int64)).all() + assert csc[0].dtype == ms.int64 + assert ops.equal(csc[0], ms.Tensor([0, 1, 2, 3], dtype=ms.int64)).all() + assert csc[1].dtype == ms.int64 + assert ops.equal(csc[1], ms.Tensor([0, 2, 4, 4, 4], dtype=ms.int64)).all() + assert perm.dtype == ms.int64 + assert ops.equal(perm, ms.Tensor([0, 1, 2, 3], dtype=ms.int64)).all() + + +@withPackage('dgNN') +def test_fused_gat_conv() -> None: + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + csr, csc, perm = FusedGATConv.to_graph_format(edge_index, size=(4, 4)) + + conv = FusedGATConv(8, 32, heads=2, add_self_loops=False) + assert str(conv) == 'FusedGATConv(8, 32, heads=2)' + + out = conv(x, csr, csc, perm) + assert out.shape == (4, 64) diff --git a/tests/graph/nn/conv/test_gat_conv.py b/tests/graph/nn/conv/test_gat_conv.py new file mode 100644 index 000000000..b7bbb6e46 --- /dev/null +++ b/tests/graph/nn/conv/test_gat_conv.py @@ -0,0 +1,175 @@ +import pytest +import mindspore as ms +from typing import Optional, Tuple +from mindspore import Tensor, ops, nn, mint +from mindspore.ops import IsClose +from mindscience.sharker import typing +from mindscience.sharker.nn import GATConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import Adj, Size, SparseTensor + + +def test_gat_conv(): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = GATConv(8, 32, heads=2) + assert str(conv) == 'GATConv(8, 32, heads=2)' + out = conv(x1, edge_index) + assert out.shape == (4, 64) + + assert mint.isclose(conv(x1, edge_index, size=(4, 4)), out, rtol=1e-04, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: Adj, + size: Size = None, + ) -> Tensor: + return self.conv(x, edge_index, size=size) + + jit = ms.jit(MyModule()) + assert mint.isclose(jit(x1, edge_index), out).all() + assert mint.isclose(jit(x1, edge_index, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + assert mint.isclose(jit(x1, adj2.t()), out, atol=1e-6).all() + + result = conv(x1, edge_index, return_attention_weights=True) + assert mint.isclose(result[0], out, rtol=1e-04, atol=1e-6).all() + assert result[1][0].shape == (2, 7) + assert result[1][1].shape == (7, 2) + assert result[1][1].min() >= 0 and result[1][1].max() <= 1 + + if typing.WITH_SPARSE: + result = conv(x1, adj2.t(), return_attention_weights=True) + assert mint.isclose(result[0], out, atol=1e-6).all() + assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: Tensor, + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + return self.conv(x, edge_index, return_attention_weights=True) + + jit = ms.jit(MyModule()) + result = jit(x1, edge_index) + assert mint.isclose(result[0], out).all() + assert result[1][0].shape == (2, 7) + assert result[1][1].shape == (7, 2) + assert result[1][1].min() >= 0 and result[1][1].max() <= 1 + + if typing.WITH_SPARSE: + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: SparseTensor, + ) -> Tuple[Tensor, SparseTensor]: + return self.conv(x, edge_index, + return_attention_weights=True) + + jit = ms.jit(MyModule()) + result = jit(x1, adj2.t()) + assert mint.isclose(result[0], out, atol=1e-6).all() + assert result[1].sizes() == [4, 4, 2] and result[1].nnz() == 7 + + conv = GATConv((8, 16), 32, heads=2) + assert str(conv) == 'GATConv((8, 16), 32, heads=2)' + + out1 = conv((x1, x2), edge_index) + assert out1.shape == (2, 64) + assert mint.isclose(conv((x1, x2), edge_index, size=(4, 2)), out1, rtol=1e-04, atol=1e-6).all() + + out2 = conv((x1, None), edge_index, size=(4, 2)) + assert out2.shape == (2, 64) + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 2)) + assert mint.isclose(conv((x1, x2), adj2.t()), out1, atol=1e-6).all() + assert mint.isclose(conv((x1, None), adj2.t()), out2, atol=1e-6).all() + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tuple[Tensor, Optional[Tensor]], + edge_index: Adj, + size: Size = None, + ) -> Tensor: + return self.conv(x, edge_index, size=size) + + jit = ms.jit(MyModule()) + assert mint.isclose(jit((x1, x2), edge_index), out1).all() + assert mint.isclose(jit((x1, x2), edge_index, size=(4, 2)), out1).all() + assert mint.isclose(jit((x1, None), edge_index, size=(4, 2)), out2).all() + + if typing.WITH_SPARSE: + assert mint.isclose(jit((x1, x2), adj2.t()), out1, atol=1e-6).all() + assert mint.isclose(jit((x1, None), adj2.t()), out2, atol=1e-6).all() + + +def test_gat_conv_with_edge_attr(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [1, 0, 1, 1]]) + edge_weight = ops.randn(edge_index.shape[1]) + edge_attr = ops.randn(edge_index.shape[1], 4) + + conv = GATConv(8, 32, heads=2, edge_dim=1, fill_value=0.5) + out = conv(x, edge_index, edge_weight) + assert out.shape == (4, 64) + if typing.WITH_SPARSE: + adj1 = SparseTensor.from_edge_index(edge_index, edge_weight, (4, 4)) + with pytest.raises(NotImplementedError): + assert mint.isclose(conv(x, adj1.t()), out).all() + + conv = GATConv(8, 32, heads=2, edge_dim=1, fill_value='mean') + out = conv(x, edge_index, edge_weight) + assert out.shape == (4, 64) + if typing.WITH_SPARSE: + with pytest.raises(NotImplementedError): + assert mint.isclose(conv(x, adj1.t()), out).all() + + conv = GATConv(8, 32, heads=2, edge_dim=4, fill_value=0.5) + out = conv(x, edge_index, edge_attr) + assert out.shape == (4, 64) + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) + with pytest.raises(NotImplementedError): + assert mint.isclose(conv(x, adj2.t()), out).all() + + conv = GATConv(8, 32, heads=2, edge_dim=4, fill_value='mean') + out = conv(x, edge_index, edge_attr) + assert out.shape == (4, 64) + if typing.WITH_SPARSE: + with pytest.raises(NotImplementedError): + assert mint.isclose(conv(x, adj2.t()), out).all() \ No newline at end of file diff --git a/tests/graph/nn/conv/test_gated_graph_conv.py b/tests/graph/nn/conv/test_gated_graph_conv.py new file mode 100644 index 000000000..43a4255d6 --- /dev/null +++ b/tests/graph/nn/conv/test_gated_graph_conv.py @@ -0,0 +1,39 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import GatedGraphConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_gated_graph_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = GatedGraphConv(32, num_layers=3) + assert str(conv) == 'GatedGraphConv(32, num_layers=3)' + out1 = conv(x, edge_index) + assert out1.shape == (4, 32) + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 32) + assert ops.isclose(conv(x, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit(x, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x, adj4.t()), out2, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_gatv2_conv.py b/tests/graph/nn/conv/test_gatv2_conv.py new file mode 100644 index 000000000..e8bb9371e --- /dev/null +++ b/tests/graph/nn/conv/test_gatv2_conv.py @@ -0,0 +1,156 @@ +from typing import Tuple + +import mindspore as ms +from mindspore import Tensor, ops, nn + +from mindscience.sharker import typing +from mindscience.sharker.nn import GATv2Conv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import Adj, SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_gatv2_conv(): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = GATv2Conv(8, 32, heads=2) + assert str(conv) == 'GATv2Conv(8, 32, heads=2)' + out = conv(x1, edge_index) + assert out.shape == (4, 64) + assert ops.isclose(conv(x1, edge_index), out).all() + assert ops.isclose(conv(x1, adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: Adj, + ) -> Tensor: + return self.conv(x, edge_index) + + jit = ms.jit(MyModule()) + assert ops.isclose(jit(x1, edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out, atol=1e-6).all() + + # Test `return_attention_weights`. + result = conv(x1, edge_index, return_attention_weights=True) + assert ops.isclose(result[0], out).all() + assert result[1][0].shape == (2, 7) + assert result[1][1].shape == (7, 2) + assert result[1][1].min() >= 0 and result[1][1].max() <= 1 + + if typing.WITH_SPARSE: + result = conv(x1, adj2.t(), return_attention_weights=True) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1].shape == [4, 4, 2] and result[1].nnz() == 7 + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: Tensor, + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + return self.conv(x, edge_index, return_attention_weights=True) + + jit = ms.jit(MyModule()) + result = jit(x1, edge_index) + assert ops.isclose(result[0], out).all() + assert result[1][0].shape == (2, 7) + assert result[1][1].shape == (7, 2) + assert result[1][1].min() >= 0 and result[1][1].max() <= 1 + + if typing.WITH_SPARSE: + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: SparseTensor, + ) -> Tuple[Tensor, SparseTensor]: + return self.conv(x, edge_index, + return_attention_weights=True) + + jit = ms.jit(MyModule()) + result = jit(x1, adj2.t()) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1].shape == [4, 4, 2] and result[1].nnz() == 7 + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + out = conv((x1, x2), edge_index) + assert out.shape == (2, 64) + assert ops.isclose(conv((x1, x2), edge_index), out).all() + assert ops.isclose(conv((x1, x2), adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tuple[Tensor, Tensor], + edge_index: Adj, + ) -> Tensor: + return self.conv(x, edge_index) + + jit = ms.jit(MyModule()) + assert ops.isclose(jit((x1, x2), edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out, atol=1e-6).all() + + +def test_gatv2_conv_with_edge_attr(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [1, 0, 1, 1]]) + edge_weight = ops.randn(edge_index.shape[1]) + edge_attr = ops.randn(edge_index.shape[1], 4) + + conv = GATv2Conv(8, 32, heads=2, edge_dim=1, fill_value=0.5) + out = conv(x, edge_index, edge_weight) + assert out.shape == (4, 64) + + conv = GATv2Conv(8, 32, heads=2, edge_dim=1, fill_value='mean') + out = conv(x, edge_index, edge_weight) + assert out.shape == (4, 64) + + conv = GATv2Conv(8, 32, heads=2, edge_dim=4, fill_value=0.5) + out = conv(x, edge_index, edge_attr) + assert out.shape == (4, 64) + + conv = GATv2Conv(8, 32, heads=2, edge_dim=4, fill_value='mean') + out = conv(x, edge_index, edge_attr) + assert out.shape == (4, 64) diff --git a/tests/graph/nn/conv/test_gcn2_conv.py b/tests/graph/nn/conv/test_gcn2_conv.py new file mode 100644 index 000000000..c125ac9b9 --- /dev/null +++ b/tests/graph/nn/conv/test_gcn2_conv.py @@ -0,0 +1,51 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import GCN2Conv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_gcn2_conv(): + x = ops.randn(4, 16) + x_0 = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = GCN2Conv(16, alpha=0.2) + assert str(conv) == 'GCN2Conv(16, alpha=0.2, beta=1.0)' + out1 = conv(x, x_0, edge_index) + assert out1.shape == (4, 16) + assert ops.isclose(conv(x, x_0, adj1.t()), out1, atol=1e-6).all() + out2 = conv(x, x_0, edge_index, value) + assert out2.shape == (4, 16) + assert ops.isclose(conv(x, x_0, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, x_0, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv(x, x_0, adj4.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, x_0, edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit(x, x_0, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, x_0, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x, x_0, adj4.t()), out2, atol=1e-6).all() + + conv.cached = True + conv(x, x_0, edge_index) + assert conv._cached_edge_index is not None + assert ops.isclose(conv(x, x_0, edge_index), out1, atol=1e-6).all() + assert ops.isclose(conv(x, x_0, adj1.t()), out1, atol=1e-6).all() + + if typing.WITH_SPARSE: + conv(x, x_0, adj3.t()) + assert conv._cached_adj_t is not None + assert ops.isclose(conv(x, x_0, adj3.t()), out1, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_gcn_conv.py b/tests/graph/nn/conv/test_gcn_conv.py new file mode 100644 index 000000000..2ca9a3a4f --- /dev/null +++ b/tests/graph/nn/conv/test_gcn_conv.py @@ -0,0 +1,72 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker import typing +from mindscience.sharker.nn import GCNConv +from mindscience.sharker.nn.conv.gcn_conv import gcn_norm +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +def test_gcn_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = mint.rand(edge_index.shape[1]) + + conv = GCNConv(16, 32) + assert str(conv) == 'GCNConv(16, 32)' + + out1 = conv(x, edge_index) + assert out1.shape == (4, 32) + + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 32) + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert mint.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + assert mint.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert mint.isclose(jit(x, edge_index), out1, atol=1e-6).all() + assert mint.isclose(jit(x, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert mint.isclose(jit(x, adj3.t()), out1, atol=1e-6).all() + assert mint.isclose(jit(x, adj4.t()), out2, atol=1e-6).all() + + conv.cached = True + conv(x, edge_index) + assert conv._cached_edge_index is not None + assert mint.isclose(conv(x, edge_index), out1 , rtol=1e-04, atol=1e-6).all() + + if typing.WITH_SPARSE: + conv(x, adj3.t()) + assert conv._cached_adj_t is not None + assert mint.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + +def test_static_gcn_conv(): + x = ops.randn(3, 4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + + conv = GCNConv(16, 32) + out = conv(x, edge_index) + assert out.shape == (3, 4, 32) + + +def test_gcn_conv_error(): + with pytest.raises(ValueError, match="does not support adding self-loops"): + GCNConv(16, 32, normalize=False, add_self_loops=True) + + +def test_gcn_conv_flow(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0], [1, 2, 3]]) + + conv = GCNConv(16, 32, flow="src_to_trg") + out1 = conv(x, edge_index) + conv.flow = "trg_to_src" + out2 = conv(x, edge_index.flip([0])) + assert mint.isclose(out1, out2, atol=1e-6).all() \ No newline at end of file diff --git a/tests/graph/nn/conv/test_gen_conv.py b/tests/graph/nn/conv/test_gen_conv.py new file mode 100644 index 000000000..7636fbeb5 --- /dev/null +++ b/tests/graph/nn/conv/test_gen_conv.py @@ -0,0 +1,141 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import GENConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_coo + + +@pytest.mark.parametrize('aggr', [ + 'softmax', + 'powermean', + ['softmax', 'powermean'], +]) +def test_gen_conv(aggr): + x1 = ops.randn(4, 16) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.randn(edge_index.shape[1], 16) + adj1 = to_coo(edge_index, shape=(4, 4)) + # adj2 = to_coo(edge_index, value, shape=(4, 4)) + + conv = GENConv(16, 32, aggr, edge_dim=16, msg_norm=True) + assert str(conv) == f'GENConv(16, 32, aggr={aggr})' + out1 = conv(x1, edge_index) + assert out1.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, shape=(4, 4)), out1).all() + assert ops.isclose(conv(x1, adj1.t().coalesce()), out1).all() + + out2 = conv(x1, edge_index, value) + assert out2.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, value, (4, 4)), out2).all() + # t() expects a tensor with <= 2 sparse and 0 dense dimensions + # assert ops.isclose(conv(x1, adj2.swapaxes(1, 0).coalesce()), out2).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x1, adj3.t()), out1, atol=1e-4).all() + assert ops.isclose(conv(x1, adj4.t()), out2, atol=1e-4).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out1, atol=1e-4).all() + assert ops.isclose(jit(x1, edge_index, shape=(4, 4)).all(), out1, + atol=1e-4) + assert ops.isclose(jit(x1, edge_index, value), out2, atol=1e-4).all() + assert ops.isclose(jit(x1, edge_index, value, shape=(4, 4)).all(), out2, + atol=1e-4) + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj3.t()), out1, atol=1e-4).all() + assert ops.isclose(jit(x1, adj4.t()), out2, atol=1e-4).all() + + # Test bipartite message passing: + adj1 = to_coo(edge_index, shape=(4, 2)) + # adj2 = to_coo(edge_index, value, shape=(4, 2)) + + out1 = conv((x1, x2), edge_index) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, shape=(4, 2)), out1).all() + assert ops.isclose(conv((x1, x2), adj1.t().coalesce()), out1).all() + + out2 = conv((x1, x2), edge_index, value) + assert out2.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, value, (4, 2)), out2).all() + # assert ops.isclose(conv((x1, x2).all(), adj2.swapaxes(1, 0).coalesce()), out2) + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + assert ops.isclose(conv((x1, x2), adj3.t()), out1, atol=1e-4).all() + assert ops.isclose(conv((x1, x2), adj4.t()), out2, atol=1e-4).all() + + if is_full_test(): + assert ops.isclose(jit((x1, x2), edge_index), out1, atol=1e-4).all() + assert ops.isclose(jit((x1, x2), edge_index, shape=(4, 2)).all(), out1, + atol=1e-4) + assert ops.isclose(jit((x1, x2), edge_index, value).all(), out2, + atol=1e-4) + assert ops.isclose(jit((x1, x2), edge_index, value, (4, 2)).all(), out2, + atol=1e-4) + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj3.t()), out1, atol=1e-4).all() + assert ops.isclose(jit((x1, x2), adj4.t()), out2, atol=1e-4).all() + + # Test bipartite message passing with unequal feature dimensions: + conv.reset_parameters() + assert float(conv.msg_norm.scale) == 1 + + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + + conv = GENConv((8, 16), 32, aggr) + assert str(conv) == f'GENConv((8, 16), 32, aggr={aggr})' + + out1 = conv((x1, x2), edge_index) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, shape=(4, 2)), out1).all() + # assert ops.isclose(conv((x1, x2), adj1.t().coalesce()), out1).all() + + out2 = conv((x1, None), edge_index, shape=(4, 2)) + assert out2.shape == (2, 32) + # assert ops.isclose(conv((x1, None), adj1.t().coalesce()), out2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(conv((x1, x2), adj3.t()), out1, atol=1e-4).all() + assert ops.isclose(conv((x1, None), adj3.t()), out2, atol=1e-4).all() + + # Test lazy initialization: + # conv = GENConv((-1, -1), 32, aggr, edge_dim=-1) + # assert str(conv) == f'GENConv((-1, -1), 32, aggr={aggr})' + # out1 = conv((x1, x2), edge_index, value) + # assert out1.shape == (2, 32) + # assert ops.isclose(conv((x1, x2), edge_index, value, shape=(4, 2)), out1).all() + # assert ops.isclose(conv((x1, x2).all(), + # adj2.swapaxes(1, 0).coalesce()), out1) + + # out2 = conv((x1, None), edge_index, value, shape=(4, 2)) + # assert out2.shape == (2, 32) + # assert ops.isclose(conv((x1, None).all(), + # adj2.swapaxes(1, 0).coalesce()), out2) + + # if typing.WITH_SPARSE: + # assert ops.isclose(conv((x1, x2), adj4.t()), out1, atol=1e-4).all() + # assert ops.isclose(conv((x1, None), adj4.t()), out2, atol=1e-4).all() + + # if is_full_test(): + # jit = ms.jit(conv) + # assert ops.isclose(jit((x1, x2), edge_index, value).all(), out1, + # atol=1e-4) + # assert ops.isclose(jit((x1, x2), edge_index, value, shape=(4, 2)).all(), + # out1, atol=1e-4) + # assert ops.isclose(jit((x1, None), edge_index, value, shape=(4, 2)).all(), + # out2, atol=1e-4) + + # if typing.WITH_SPARSE: + # assert ops.isclose(jit((x1, x2), adj4.t()), out1, atol=1e-4).all() + # assert ops.isclose(jit((x1, None), adj4.t()), out2, atol=1e-4).all() diff --git a/tests/graph/nn/conv/test_general_conv.py b/tests/graph/nn/conv/test_general_conv.py new file mode 100644 index 000000000..fceffe65f --- /dev/null +++ b/tests/graph/nn/conv/test_general_conv.py @@ -0,0 +1,31 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import GeneralConv +from mindscience.sharker.typing import SparseTensor + + +@pytest.mark.parametrize('kwargs', [ + dict(), + dict(skip_linear=True), + dict(directed_msg=False), + dict(heads=3), + dict(attention=True), + dict(heads=3, attention=True), + dict(heads=3, attention=True, attention_type='dot_product'), + dict(l2_normalize=True), +]) +def test_general_conv(kwargs): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_attr = ops.randn(edge_index.shape[1], 16) + + conv = GeneralConv(8, 32, in_edge_channels=16, **kwargs) + assert str(conv) == 'GeneralConv(8, 32)' + out = conv(x, edge_index, edge_attr) + assert out.shape == (4, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) + assert ops.isclose(conv(x, adj.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_gin_conv.py b/tests/graph/nn/conv/test_gin_conv.py new file mode 100644 index 000000000..2a3ee4743 --- /dev/null +++ b/tests/graph/nn/conv/test_gin_conv.py @@ -0,0 +1,159 @@ +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Lin +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Seq + +from mindscience.sharker import typing +from mindscience.sharker.nn import GINConv, GINEConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_gin_conv(): + x1 = ops.randn(4, 16) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) + conv = GINConv(nn, train_eps=True) + assert str(conv) == ( + 'GINConv(nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n ' + '(2): Dense\n >)') + out = conv(x1, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, shape=(4, 4)), out, atol=1e-6).all() + assert ops.isclose(conv(x1, adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out, atol=1e-6).all() + assert ops.isclose(jit(x1, edge_index, shape=(4, 4)), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + out1 = conv((x1, x2), edge_index) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, (4, 2)), out1, atol=1e-6).all() + assert ops.isclose(conv((x1, x2), adj1.t()), out1, atol=1e-6).all() + + out2 = conv((x1, None), edge_index, (4, 2)) + assert out2.shape == (2, 32) + assert ops.isclose(conv((x1, None), adj1.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out1, atol=1e-6).all() + assert ops.isclose(conv((x1, None), adj2.t()), out2, atol=1e-6).all() + + if is_full_test(): + assert ops.isclose(jit((x1, x2), edge_index), out1).all() + assert ops.isclose(jit((x1, x2), edge_index, size=(4, 2)), out1).all() + assert ops.isclose(jit((x1, None), edge_index, size=(4, 2)), out2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out1).all() + assert ops.isclose(jit((x1, None), adj2.t()), out2).all() + + +def test_gine_conv(): + x1 = ops.randn(4, 16) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.randn(edge_index.shape[1], 16) + + nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) + conv = GINEConv(nn, train_eps=True) + assert str(conv) == ( + 'GINEConv(nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n ' + '(2): Dense\n >)') + out = conv(x1, edge_index, value) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, value, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x1, adj.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index, value), out).all() + assert ops.isclose(jit(x1, edge_index, value, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj.t()), out).all() + + # Test bipartite message passing: + out1 = conv((x1, x2), edge_index, value) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, value, (4, 2)), out1).all() + + out2 = conv((x1, None), edge_index, value, (4, 2)) + assert out2.shape == (2, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + assert ops.isclose(conv((x1, x2), adj.t()), out1).all() + assert ops.isclose(conv((x1, None), adj.t()), out2).all() + + if is_full_test(): + assert ops.isclose(jit((x1, x2), edge_index, value), out1).all() + assert ops.isclose(jit((x1, x2), edge_index, value, size=(4, 2)).all(), + out1) + assert ops.isclose(jit((x1, None), edge_index, value, size=(4, 2)).all(), + out2) + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj.t()), out1).all() + assert ops.isclose(jit((x1, None), adj.t()), out2).all() + + +def test_gine_conv_edge_dim(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) + conv = GINEConv(nn, train_eps=True, edge_dim=8) + out = conv(x, edge_index, edge_attr) + assert out.shape == (4, 32) + + nn = Lin(16, 32) + conv = GINEConv(nn, train_eps=True, edge_dim=8) + out = conv(x, edge_index, edge_attr) + assert out.shape == (4, 32) + + +def test_static_gin_conv(): + x = ops.randn(3, 4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + + nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) + conv = GINConv(nn, train_eps=True) + out = conv(x, edge_index) + assert out.shape == (3, 4, 32) + + +def test_static_gine_conv(): + x = ops.randn(3, 4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + edge_attr = ops.randn(edge_index.shape[1], 16) + + nn = Seq(Lin(16, 32), ReLU(), Lin(32, 32)) + conv = GINEConv(nn, train_eps=True) + out = conv(x, edge_index, edge_attr) + assert out.shape == (3, 4, 32) diff --git a/tests/graph/nn/conv/test_gmm_conv.py b/tests/graph/nn/conv/test_gmm_conv.py new file mode 100644 index 000000000..23d24a2c0 --- /dev/null +++ b/tests/graph/nn/conv/test_gmm_conv.py @@ -0,0 +1,93 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import GMMConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +# from mindscience.sharker.utils import to_coo + + +@pytest.mark.parametrize('separate_gaussians', [True, False]) +def test_gmm_conv(separate_gaussians): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.rand(edge_index.shape[1], 3) + # adj1 = to_coo(edge_index, value, shape=(4, 4)) + + conv = GMMConv(8, 32, axis=3, kernel_size=25, + separate_gaussians=separate_gaussians) + assert str(conv) == 'GMMConv(8, 32, axis=3)' + out = conv(x1, edge_index, value) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, value, size=(4, 4)), out).all() + # t() expects a tensor with <= 2 sparse and 0 dense dimensions + # assert ops.isclose(conv(x1, adj1.swapaxes(0, 1).coalesce()), out).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index, value), out).all() + assert ops.isclose(jit(x1, edge_index, value, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out).all() + + # Test bipartite message passing: + # adj1 = to_coo(edge_index, value, shape=(4, 2)) + + conv = GMMConv((8, 16), 32, axis=3, kernel_size=5, + separate_gaussians=separate_gaussians) + assert str(conv) == 'GMMConv((8, 16), 32, axis=3)' + + out1 = conv((x1, x2), edge_index, value) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, value, (4, 2)), out1).all() + # assert ops.isclose(conv((x1, x2), + # adj1.swapaxes(0, 1).coalesce()), out1).all() + + out2 = conv((x1, None), edge_index, value, (4, 2)) + assert out2.shape == (2, 32) + # assert ops.isclose(conv((x1, None).all(), + # adj1.swapaxes(0, 1).coalesce()), out2) + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out1).all() + assert ops.isclose(conv((x1, None), adj2.t()), out2).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index, value), out1).all() + assert ops.isclose(jit((x1, x2), edge_index, value, size=(4, 2)).all(), + out1) + assert ops.isclose(jit((x1, None), edge_index, value, size=(4, 2)).all(), + out2) + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out1).all() + assert ops.isclose(jit((x1, None), adj2.t()), out2).all() + + +# @pytest.mark.parametrize('separate_gaussians', [True, False]) +# def test_lazy_gmm_conv(separate_gaussians): +# x1 = ops.randn(4, 8) +# x2 = ops.randn(2, 16) +# edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) +# value = ops.rand(edge_index.shape[1], 3) + +# conv = GMMConv(-1, 32, axis=3, kernel_size=25, +# separate_gaussians=separate_gaussians) +# assert str(conv) == 'GMMConv(-1, 32, axis=3)' +# out = conv(x1, edge_index, value) +# assert out.shape == (4, 32) + +# conv = GMMConv((-1, -1), 32, axis=3, kernel_size=25, +# separate_gaussians=separate_gaussians) +# assert str(conv) == 'GMMConv((-1, -1), 32, axis=3)' +# out = conv((x1, x2), edge_index, value) +# assert out.shape == (2, 32) diff --git a/tests/graph/nn/conv/test_gps_conv.py b/tests/graph/nn/conv/test_gps_conv.py new file mode 100644 index 000000000..686a8dfdd --- /dev/null +++ b/tests/graph/nn/conv/test_gps_conv.py @@ -0,0 +1,37 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import GPSConv, SAGEConv +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +@pytest.mark.parametrize('attn_type', ['multihead', 'performer']) +@pytest.mark.parametrize('norm', [None, 'batch_norm', 'layer_norm']) +def test_gps_conv(norm, attn_type): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [1, 0, 3, 2]]) + batch = ms.Tensor([0, 0, 1, 1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = GPSConv(16, conv=SAGEConv(16, 16), heads=4, norm=norm, + attn_type=attn_type) + conv.reset_parameters() + assert str(conv) == (f'GPSConv(16, conv=SAGEConv(16, 16, aggr=mean), ' + f'heads=4, attn_type={attn_type})') + + out = conv(x, edge_index) + assert out.shape == (4, 16) + assert ops.isclose(conv(x, adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, adj2.t()), out, atol=1e-6).all() + + out = conv(x, edge_index, batch) + assert out.shape == (4, 16) + assert ops.isclose(conv(x, adj1.t(), batch), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(conv(x, adj2.t(), batch), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_graph_conv.py b/tests/graph/nn/conv/test_graph_conv.py new file mode 100644 index 000000000..24d224a45 --- /dev/null +++ b/tests/graph/nn/conv/test_graph_conv.py @@ -0,0 +1,108 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import GraphConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_graph_conv(): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.randn(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = GraphConv(8, 32) + assert str(conv) == 'GraphConv(8, 32)' + out1 = conv(x1, edge_index) + assert out1.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, size=(4, 4)), out1, atol=1e-6).all() + assert ops.isclose(conv(x1, adj1.t()), out1, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj3.t()), out1, atol=1e-6).all() + + out2 = conv(x1, edge_index, value) + assert out2.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, value, size=(4, 4)), out2, atol=1e-6).all() + assert ops.isclose(conv(x1, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x1, adj4.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out1).all() + assert ops.isclose(jit(x1, edge_index, size=(4, 4)), out1).all() + assert ops.isclose(jit(x1, edge_index, value), out2).all() + assert ops.isclose(jit(x1, edge_index, value, size=(4, 4)), out2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x1, adj4.t()), out2, atol=1e-6).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + adj2 = to_csr(edge_index, value, shape=(4, 2)) + + conv = GraphConv((8, 16), 32) + assert str(conv) == 'GraphConv((8, 16), 32)' + out1 = conv((x1, x2), edge_index) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, size=(4, 2)), out1).all() + assert ops.isclose(conv((x1, x2), adj1.t()), out1, atol=1e-6).all() + + out2 = conv((x1, None), edge_index, size=(4, 2)) + assert out2.shape == (2, 32) + assert ops.isclose(conv((x1, None), adj1.t()), out2, atol=1e-6).all() + + out3 = conv((x1, x2), edge_index, value) + assert out3.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, value, (4, 2)), out3).all() + assert ops.isclose(conv((x1, x2), adj2.t()), out3, atol=1e-6).all() + + out4 = conv((x1, None), edge_index, value, size=(4, 2)) + assert out4.shape == (2, 32) + assert ops.isclose(conv((x1, None), adj2.t()), out4, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + assert ops.isclose(conv((x1, x2), adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv((x1, None), adj3.t()), out2, atol=1e-6).all() + assert ops.isclose(conv((x1, x2), adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv((x1, None), adj4.t()), out4, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index), out1).all() + assert ops.isclose(jit((x1, x2), edge_index, size=(4, 2)), out1).all() + assert ops.isclose(jit((x1, None), edge_index, size=(4, 2)), out2).all() + assert ops.isclose(jit((x1, x2), edge_index, value), out3).all() + assert ops.isclose(jit((x1, x2), edge_index, value, (4, 2)), out3).all() + assert ops.isclose(jit((x1, None), edge_index, value, (4, 2)), out4).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit((x1, None), adj3.t()), out2, atol=1e-6).all() + assert ops.isclose(jit((x1, x2), adj4.t()), out3, atol=1e-6).all() + assert ops.isclose(jit((x1, None), adj4.t()), out4, atol=1e-6).all() + + +class EdgeGraphConv(GraphConv): + def message(self, x_j, edge_weight): + return edge_weight.view(-1, 1) * x_j + + +def test_inheritance(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_weight = ops.rand(4) + + conv = EdgeGraphConv(8, 16) + assert conv(x, edge_index, edge_weight).shape == (4, 16) diff --git a/tests/graph/nn/conv/test_gravnet_conv.py b/tests/graph/nn/conv/test_gravnet_conv.py new file mode 100644 index 000000000..93327180a --- /dev/null +++ b/tests/graph/nn/conv/test_gravnet_conv.py @@ -0,0 +1,34 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import GravNetConv +from mindscience.sharker.testing import is_full_test + + +def test_gravnet_conv(): + x1 = ops.randn(8, 16) + x2 = ops.randn(4, 16) + batch1 = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 1]) + batch2 = ms.Tensor([0, 0, 1, 1]) + + conv = GravNetConv(16, 32, space_dimensions=4, propagate_dimensions=8, k=2) + assert str(conv) == 'GravNetConv(16, 32, k=2)' + + out11 = conv(x1) + assert out11.shape == (8, 32) + + out12 = conv(x1, batch1) + assert out12.shape == (8, 32) + + out21 = conv((x1, x2)) + assert out21.shape == (4, 32) + + out22 = conv((x1, x2), (batch1, batch2)) + assert out22.shape == (4, 32) + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1), out11).all() + assert ops.isclose(jit(x1, batch1), out12).all() + + assert ops.isclose(jit((x1, x2)), out21).all() + assert ops.isclose(jit((x1, x2), (batch1, batch2)), out22).all() diff --git a/tests/graph/nn/conv/test_han_conv.py b/tests/graph/nn/conv/test_han_conv.py new file mode 100644 index 000000000..925621325 --- /dev/null +++ b/tests/graph/nn/conv/test_han_conv.py @@ -0,0 +1,136 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import HANConv +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import coalesce, to_csr + + +def test_han_conv(): + x_dict = { + 'author': ops.randn(6, 16), + 'paper': ops.randn(5, 12), + 'term': ops.randn(4, 3) + } + edge_index1 = coalesce(ops.randint(0, 6, (2, 7))) + edge_index2 = coalesce(ops.randint(0, 5, (2, 4))) + edge_index3 = coalesce(ops.randint(0, 3, (2, 5))) + edge_index_dict = { + ('author', 'metapath0', 'author'): edge_index1, + ('paper', 'metapath1', 'paper'): edge_index2, + ('paper', 'metapath2', 'paper'): edge_index3, + } + + adj_t_dict1 = {} + for edge_type, edge_index in edge_index_dict.items(): + src_type, _, dst_type = edge_type + adj_t_dict1[edge_type] = to_csr( + edge_index, + shape=(x_dict[src_type].shape[0], x_dict[dst_type].shape[0]), + ).to_dense().t().to_csr() + + metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) + in_channels = {'author': 16, 'paper': 12, 'term': 3} + + conv = HANConv(in_channels, 16, metadata, heads=2) + assert str(conv) == 'HANConv(16, heads=2)' + out_dict1 = conv(x_dict, edge_index_dict) + assert len(out_dict1) == 3 + assert out_dict1['author'].shape == (6, 16) + assert out_dict1['paper'].shape == (5, 16) + assert out_dict1['term'] is None + del out_dict1['term'] + del x_dict['term'] + + out_dict2 = conv(x_dict, adj_t_dict1) + assert len(out_dict1) == len(out_dict2) + for key in out_dict1.keys(): + assert ops.isclose(out_dict1[key], out_dict2[key], atol=1e-6).all() + + if typing.WITH_SPARSE: + adj_t_dict2 = {} + for edge_type, edge_index in edge_index_dict.items(): + adj_t_dict2[edge_type] = SparseTensor.from_edge_index( + edge_index, + sparse_shape=adj_t_dict1[edge_type].shape[::-1], + ).t() + out_dict3 = conv(x_dict, adj_t_dict2) + assert len(out_dict1) == len(out_dict3) + for key in out_dict3.keys(): + assert ops.isclose(out_dict1[key], out_dict3[key], atol=1e-6).all() + + # Test non-zero dropout: + conv = HANConv(in_channels, 16, metadata, heads=2, dropout=0.1) + assert str(conv) == 'HANConv(16, heads=2)' + out_dict1 = conv(x_dict, edge_index_dict) + assert len(out_dict1) == 2 + assert out_dict1['author'].shape == (6, 16) + assert out_dict1['paper'].shape == (5, 16) + + +# def test_han_conv_lazy(): +# x_dict = { +# 'author': ops.randn(6, 16), +# 'paper': ops.randn(5, 12), +# } +# edge_index1 = coalesce(ops.randint(0, 6, (2, 8))) +# edge_index2 = coalesce(ops.randint(0, 5, (2, 6))) +# edge_index_dict = { +# ('author', 'to', 'author'): edge_index1, +# ('paper', 'to', 'paper'): edge_index2, +# } + +# adj_t_dict1 = {} +# for edge_type, edge_index in edge_index_dict.items(): +# src_type, _, dst_type = edge_type +# adj_t_dict1[edge_type] = to_csr( +# edge_index, +# shape=(x_dict[src_type].shape[0], x_dict[dst_type].shape[0]), +# ).to_dense().t().to_csr() + +# metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) +# conv = HANConv(-1, 16, metadata, heads=2) +# assert str(conv) == 'HANConv(16, heads=2)' +# out_dict1 = conv(x_dict, edge_index_dict) +# assert len(out_dict1) == 2 +# assert out_dict1['author'].shape == (6, 16) +# assert out_dict1['paper'].shape == (5, 16) + +# # out_dict2 = conv(x_dict, adj_t_dict1) +# # assert len(out_dict1) == len(out_dict2) +# # for key in out_dict1.keys(): +# # assert ops.isclose(out_dict1[key], out_dict2[key], atol=1e-6).all() + +# # if typing.WITH_SPARSE: +# # adj_t_dict2 = {} +# # for edge_type, edge_index in edge_index_dict.items(): +# # adj_t_dict2[edge_type] = SparseTensor.from_edge_index( +# # edge_index, +# # sparse_shapeadj_t_dict1[edge_type].shape[::-1], +# # ).t() +# # out_dict3 = conv(x_dict, adj_t_dict2) +# # assert len(out_dict1) == len(out_dict3) +# # for key in out_dict1.keys(): +# # assert ops.isclose(out_dict1[key], out_dict3[key], atol=1e-6).all() + + +# def test_han_conv_empty_tensor(): +# x_dict = { +# 'author': ops.randn(6, 16), +# 'paper': ms.numpy.empty([0, 12]), +# } +# edge_index_dict = { +# ('paper', 'to', 'author'): ms.numpy.empty((2, 0), dtype=ms.int64), +# ('author', 'to', 'paper'): ms.numpy.empty((2, 0), dtype=ms.int64), +# ('paper', 'to', 'paper'): ms.numpy.empty((2, 0), dtype=ms.int64), +# } + +# metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) +# in_channels = {'author': 16, 'paper': 12} +# conv = HANConv(in_channels, 16, metadata, heads=2) + +# out_dict = conv(x_dict, edge_index_dict) +# assert len(out_dict) == 2 +# assert out_dict['author'].shape == (6, 16) +# assert ops.all(out_dict['author'] == 0) +# assert out_dict['paper'].shape == (0, 16) diff --git a/tests/graph/nn/conv/test_heat_conv.py b/tests/graph/nn/conv/test_heat_conv.py new file mode 100644 index 000000000..ad7c4dc9c --- /dev/null +++ b/tests/graph/nn/conv/test_heat_conv.py @@ -0,0 +1,39 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import HEATConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +@pytest.mark.parametrize('concat', [True, False]) +def test_heat_conv(concat): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_attr = ops.randn((4, 2)) + node_type = ms.Tensor([0, 0, 1, 2]) + edge_type = ms.Tensor([0, 2, 1, 2]) + + conv = HEATConv(in_channels=8, out_channels=16, num_node_types=3, + num_edge_types=3, edge_type_emb_dim=5, edge_dim=2, + edge_attr_emb_dim=6, heads=2, concat=concat) + assert str(conv) == 'HEATConv(8, 16, heads=2)' + + out = conv(x, edge_index, node_type, edge_type, edge_attr) + assert out.shape == (4, 32 if concat else 16) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) + assert ops.isclose(conv(x, adj.t(), node_type, edge_type), out, + atol=1e-5).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose( + jit(x, edge_index, node_type, edge_type, edge_attr), out, + atol=1e-5).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj.t(), node_type, edge_type), out, + atol=1e-5).all() diff --git a/tests/graph/nn/conv/test_hetero_conv.py b/tests/graph/nn/conv/test_hetero_conv.py new file mode 100644 index 000000000..542699297 --- /dev/null +++ b/tests/graph/nn/conv/test_hetero_conv.py @@ -0,0 +1,204 @@ +import pytest +from mindspore import ops, nn +from mindscience.sharker.data import HeteroGraph +# from mindscience.sharker.datasets import FakeHeteroDataset +from mindscience.sharker.nn import ( + GATConv, + GCN2Conv, + GCNConv, + HeteroConv, + MessagePassing, + SAGEConv, +) + +from mindscience.sharker.testing import ( + get_random_edge_index +) + + +@pytest.mark.parametrize('aggr', ['sum', 'mean', 'min', 'max', 'cat', None]) +def test_hetero_conv(aggr): + data = HeteroGraph() + data['paper'].x = ops.randn(50, 32) + data['author'].x = ops.randn(30, 64) + data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) + data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100) + data['paper', 'author'].edge_attr = ops.randn(100, 3) + data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100) + data['paper', 'paper'].edge_weight = ops.rand(200) + + # Unspecified edge types should be ignored: + data['author', 'author'].edge_index = get_random_edge_index(30, 30, 100) + + conv = HeteroConv( + { + ('paper', 'to', 'paper'): + GCNConv(32, 64), + ('author', 'to', 'paper'): + SAGEConv((64, 32), 64), + ('paper', 'to', 'author'): + GATConv((32, 64), 64, edge_dim=3, add_self_loop=False), + }, + aggr=aggr, + ) + + assert len(list(conv.trainable_params())) > 0 + assert str(conv) == 'HeteroConv(num_relations=3)' + + out_dict = conv( + data.x_dict, + data.edge_index_dict, + data.edge_attr_dict, + edge_weight_dict=data.edge_weight_dict, + ) + + assert len(out_dict) == 2 + if aggr == 'cat': + assert out_dict['paper'].shape == (50, 128) + assert out_dict['author'].shape == (30, 64) + elif aggr is not None: + assert out_dict['paper'].shape == (50, 64) + assert out_dict['author'].shape == (30, 64) + else: + assert out_dict['paper'].shape == (50, 2, 64) + assert out_dict['author'].shape == (30, 1, 64) + + +def test_gcn2_hetero_conv(): + data = HeteroGraph() + data['paper'].x = ops.randn(50, 32) + data['author'].x = ops.randn(30, 64) + data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) + data['author', 'author'].edge_index = get_random_edge_index(30, 30, 100) + data['paper', 'paper'].edge_weight = ops.rand(200) + + conv = HeteroConv({ + ('paper', 'to', 'paper'): GCN2Conv(32, alpha=0.1), + ('author', 'to', 'author'): GCN2Conv(64, alpha=0.2), + }) + + out_dict = conv( + data.x_dict, + data.x_dict, + data.edge_index_dict, + edge_weight_dict=data.edge_weight_dict, + ) + + assert len(out_dict) == 2 + assert out_dict['paper'].shape == (50, 32) + assert out_dict['author'].shape == (30, 64) + + +class CustomConv(MessagePassing): + def __init__(self, in_channels, out_channels): + super().__init__(aggr='add') + self.lin = nn.Dense(in_channels, out_channels) + + def construct(self, x, edge_index, y, z): + return self.propagate(edge_index, x=x, y=y, z=z) + + def message(self, x_j, y_j, z_j): + return self.lin(ops.cat(([x_j, y_j, z_j]), axis=-1)) + + +def test_hetero_conv_with_custom_conv(): + data = HeteroGraph() + data['paper'].x = ops.randn(50, 32) + data['paper'].y = ops.randn(50, 3) + data['paper'].z = ops.randn(50, 3) + data['author'].x = ops.randn(30, 64) + data['author'].y = ops.randn(30, 3) + data['author'].z = ops.randn(30, 3) + data['paper', 'paper'].edge_index = get_random_edge_index(50, 50, 200) + data['paper', 'author'].edge_index = get_random_edge_index(50, 30, 100) + data['author', 'paper'].edge_index = get_random_edge_index(30, 50, 100) + + conv = {} + for key in data.edge_types: + name = key[0] + in_channel = data[name].x.shape[1] + data[name].y.shape[1] + data[name].z.shape[1] + cell = CustomConv(in_channel, 64) + conv[key] = cell + net = HeteroConv(conv) + # Test node `args_dict` and `kwargs_dict` with `y_dict` and `z_dict`: + out_dict = net( + data.x_dict, + data.edge_index_dict, + data.y_dict, + z_dict=data.z_dict, + ) + assert len(out_dict) == 2 + assert out_dict['paper'].shape == (50, 64) + assert out_dict['author'].shape == (30, 64) + + +class MessagePassingLoops(MessagePassing): + def __init__(self): + super().__init__() + self.add_self_loops = True + + +def test_hetero_conv_self_loop_error(): + HeteroConv({('a', 'to', 'a'): MessagePassingLoops()}) + with pytest.raises(ValueError, match="incorrect message passing"): + HeteroConv({('a', 'to', 'b'): MessagePassingLoops()}) + + +def test_hetero_conv_with_dot_syntax_node_types(): + data = HeteroGraph() + data['src.paper'].x = ops.randn(50, 32) + data['author'].x = ops.randn(30, 64) + edge_index = get_random_edge_index(50, 50, 200) + data['src.paper', 'src.paper'].edge_index = edge_index + data['src.paper', 'author'].edge_index = get_random_edge_index(50, 30, 100) + data['author', 'src.paper'].edge_index = get_random_edge_index(30, 50, 100) + data['src.paper', 'src.paper'].edge_weight = ops.rand(200) + + conv = HeteroConv({ + ('src.paper', 'to', 'src.paper'): + GCNConv(32, 64), + ('author', 'to', 'src.paper'): + SAGEConv((64, 32), 64), + ('src.paper', 'to', 'author'): + GATConv((32, 64), 64, add_self_loop=False), + }) + + assert len(list(conv.trainable_params())) > 0 + assert str(conv) == 'HeteroConv(num_relations=3)' + + out_dict = conv( + data.x_dict, + data.edge_index_dict, + edge_weight_dict=data.edge_weight_dict, + ) + + assert len(out_dict) == 2 + assert out_dict['src.paper'].shape == (50, 64) + assert out_dict['author'].shape == (30, 64) + + +# def test_compile_hetero_conv_graph_breaks(): +# import torch._dynamo as dynamo + +# data = HeteroGraph() +# data['a'].x = ops.randn(50, 16) +# data['b'].x = ops.randn(50, 16) +# edge_index = get_random_edge_index(50, 50, 100) +# data['a', 'to', 'b'].edge_index = edge_index +# data['b', 'to', 'a'].edge_index = edge_index.flip([0]) + +# conv = HeteroConv({ +# ('a', 'to', 'b'): SAGEConv(16, 32), +# ('b', 'to', 'a'): SAGEConv(16, 32), +# }) + +# explanation = dynamo.explain(conv)(data.x_dict, data.edge_index_dict) +# assert explanation.graph_break_count == 0 + +# compiled_conv = torch.compile(conv) + +# expected = conv(data.x_dict, data.edge_index_dict) +# out = compiled_conv(data.x_dict, data.edge_index_dict) +# assert len(out) == len(expected) +# for key in expected.keys(): +# assert ops.isclose(out[key], expected[key], atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_hgt_conv.py b/tests/graph/nn/conv/test_hgt_conv.py new file mode 100644 index 000000000..647e70710 --- /dev/null +++ b/tests/graph/nn/conv/test_hgt_conv.py @@ -0,0 +1,228 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.data import HeteroGraph +from mindscience.sharker.nn import HGTConv +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import get_random_edge_index +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import coalesce, to_csr + + +def test_hgt_conv_same_dimensions(): + x_dict = { + 'author': ops.randn(4, 16), + 'paper': ops.randn(6, 16), + } + edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20)) + + edge_index_dict = { + ('author', 'writes', 'paper'): edge_index, + ('paper', 'written_by', 'author'): edge_index.flip([0]), + } + + adj_t_dict1 = {} + for edge_type, edge_index in edge_index_dict.items(): + src_type, _, dst_type = edge_type + adj_t_dict1[edge_type] = to_csr( + edge_index, + shape=(x_dict[src_type].shape[0], x_dict[dst_type].shape[0]), + ).t() + + metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) + + conv = HGTConv(16, 16, metadata, heads=2) + assert str(conv) == 'HGTConv(-1, 16, heads=2)' + out_dict1 = conv(x_dict, edge_index_dict) + assert len(out_dict1) == 2 + assert out_dict1['author'].shape == (4, 16) + assert out_dict1['paper'].shape == (6, 16) + + out_dict2 = conv(x_dict, adj_t_dict1) + assert len(out_dict1) == len(out_dict2) + for key in out_dict1.keys(): + assert ops.isclose(out_dict1[key], out_dict2[key], atol=1e-6).all() + + if typing.WITH_SPARSE: + adj_t_dict2 = {} + for edge_type, edge_index in edge_index_dict.items(): + adj_t_dict2[edge_type] = SparseTensor.from_edge_index( + edge_index, + sparse_shape=adj_t_dict1[edge_type].shape[::-1], + ).t() + out_dict3 = conv(x_dict, adj_t_dict2) + assert len(out_dict1) == len(out_dict3) + for key in out_dict1.keys(): + assert ops.isclose(out_dict1[key], out_dict3[key], atol=1e-6).all() + + +def test_hgt_conv_different_dimensions(): + x_dict = { + 'author': ops.randn(4, 16), + 'paper': ops.randn(6, 32), + } + edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20)) + + edge_index_dict = { + ('author', 'writes', 'paper'): edge_index, + ('paper', 'written_by', 'author'): edge_index.flip([0]), + } + + adj_t_dict1 = {} + for edge_type, edge_index in edge_index_dict.items(): + src_type, _, dst_type = edge_type + adj_t_dict1[edge_type] = to_csr( + edge_index, + shape=(x_dict[src_type].shape[0], x_dict[dst_type].shape[0]), + ).t() + + metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) + + conv = HGTConv(in_channels={ + 'author': 16, + 'paper': 32 + }, out_channels=32, metadata=metadata, heads=2) + assert str(conv) == 'HGTConv(-1, 32, heads=2)' + out_dict1 = conv(x_dict, edge_index_dict) + assert len(out_dict1) == 2 + assert out_dict1['author'].shape == (4, 32) + assert out_dict1['paper'].shape == (6, 32) + + out_dict2 = conv(x_dict, adj_t_dict1) + assert len(out_dict1) == len(out_dict2) + for key in out_dict1.keys(): + assert ops.isclose(out_dict1[key], out_dict2[key], atol=1e-6).all() + + if typing.WITH_SPARSE: + adj_t_dict2 = {} + for edge_type, edge_index in edge_index_dict.items(): + adj_t_dict2[edge_type] = SparseTensor.from_edge_index( + edge_index, + sparse_shape=adj_t_dict1[edge_type].shape[::-1], + ).t() + out_dict3 = conv(x_dict, adj_t_dict2) + assert len(out_dict1) == len(out_dict3) + for node_type in out_dict1.keys(): + assert ops.isclose(out_dict1[key], out_dict3[key], atol=1e-6).all() + + +# def test_hgt_conv_lazy(): +# x_dict = { +# 'author': ops.randn(4, 16), +# 'paper': ops.randn(6, 32), +# } +# edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20)) + +# edge_index_dict = { +# ('author', 'writes', 'paper'): edge_index, +# ('paper', 'written_by', 'author'): edge_index.flip([0]), +# } + +# adj_t_dict1 = {} +# for edge_type, edge_index in edge_index_dict.items(): +# src_type, _, dst_type = edge_type +# adj_t_dict1[edge_type] = to_csr( +# edge_index, +# size=(x_dict[src_type].shape[0], x_dict[dst_type].shape[0]), +# ).t() + +# metadata = (list(x_dict.keys()), list(edge_index_dict.keys())) + +# conv = HGTConv(-1, 32, metadata, heads=2) +# assert str(conv) == 'HGTConv(-1, 32, heads=2)' +# out_dict1 = conv(x_dict, edge_index_dict) +# assert len(out_dict1) == 2 +# assert out_dict1['author'].shape == (4, 32) +# assert out_dict1['paper'].shape == (6, 32) + +# out_dict2 = conv(x_dict, adj_t_dict1) +# assert len(out_dict1) == len(out_dict2) +# for key in out_dict1.keys(): +# assert ops.isclose(out_dict1[key], out_dict2[key], atol=1e-6).all() + +# if False and typing.WITH_SPARSE: +# adj_t_dict2 = {} +# for edge_type, edge_index in edge_index_dict.items(): +# adj_t_dict2[edge_type] = SparseTensor.from_edge_index( +# edge_index, +# sparse_shapeadj_t_dict1[edge_type].shape[::-1], +# ).t() +# out_dict3 = conv(x_dict, adj_t_dict2) +# assert len(out_dict1) == len(out_dict3) +# for key in out_dict1.keys(): +# assert ops.isclose(out_dict1[key], out_dict3[key], atol=1e-6).all() + + +def test_hgt_conv_out_of_place(): + data = HeteroGraph() + data['author'].x = ops.randn(4, 16) + data['paper'].x = ops.randn(6, 32) + + edge_index = coalesce(get_random_edge_index(4, 6, num_edges=20)) + + data['author', 'paper'].edge_index = edge_index + data['paper', 'author'].edge_index = edge_index.flip([0]) + + conv = HGTConv({'author': 16, 'paper': 32}, 64, data.metadata(), heads=1) + + x_dict, edge_index_dict = data.x_dict, data.edge_index_dict + assert x_dict['author'].shape == (4, 16) + assert x_dict['paper'].shape == (6, 32) + + _ = conv(x_dict, edge_index_dict) + + assert x_dict['author'].shape == (4, 16) + assert x_dict['paper'].shape == (6, 32) + + +def test_hgt_conv_missing_dst_node_type(): + data = HeteroGraph() + data['author'].x = ops.randn(4, 16) + data['paper'].x = ops.randn(6, 32) + data['university'].x = ops.randn(10, 32) + + data['author', 'paper'].edge_index = get_random_edge_index(4, 6, 20) + data['paper', 'author'].edge_index = get_random_edge_index(6, 4, 20) + data['university', 'author'].edge_index = get_random_edge_index(10, 4, 10) + + conv = HGTConv({'author': 16, 'paper': 32, 'university': 32}, 64, data.metadata(), heads=1) + + out_dict = conv(data.x_dict, data.edge_index_dict) + assert out_dict['author'].shape == (4, 64) + assert out_dict['paper'].shape == (6, 64) + assert 'university' not in out_dict + + +def test_hgt_conv_missing_input_node_type(): + data = HeteroGraph() + data['author'].x = ops.randn(4, 16) + data['paper'].x = ops.randn(6, 32) + data['author', 'writes', 'paper'].edge_index = get_random_edge_index(4, 6, 20) + + # Some nodes from metadata are missing in data. + # This might happen while using NeighborLoader. + metadata = (['author', 'paper', 'university'], [('author', 'writes', 'paper')]) + conv = HGTConv({'author': 16, 'paper': 32}, 64, metadata, heads=1) + + out_dict = conv(data.x_dict, data.edge_index_dict) + assert out_dict['paper'].shape == (6, 64) + assert 'university' not in out_dict + + +def test_hgt_conv_missing_edge_type(): + data = HeteroGraph() + data['author'].x = ops.randn(4, 16) + data['paper'].x = ops.randn(6, 32) + data['university'].x = ops.randn(10, 32) + + data['author', 'writes', 'paper'].edge_index = get_random_edge_index(4, 6, 20) + + metadata = (['author', 'paper', + 'university'], [('author', 'writes', 'paper'), + ('university', 'employs', 'author')]) + conv = HGTConv({'author': 16, 'paper': 32}, 64, metadata, heads=1) + + out_dict = conv(data.x_dict, data.edge_index_dict) + assert out_dict['author'].shape == (4, 64) + assert out_dict['paper'].shape == (6, 64) + assert 'university' not in out_dict diff --git a/tests/graph/nn/conv/test_hypergraph_conv.py b/tests/graph/nn/conv/test_hypergraph_conv.py new file mode 100644 index 000000000..18a795265 --- /dev/null +++ b/tests/graph/nn/conv/test_hypergraph_conv.py @@ -0,0 +1,48 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import HypergraphConv + + +def test_hypergraph_conv_with_more_nodes_than_edges(): + in_channels, out_channels = (16, 32) + hyperedge_index = ms.Tensor([[0, 0, 1, 1, 2, 3], [0, 1, 0, 1, 0, 1]]) + num_nodes = hyperedge_index[0].max().item() + 1 + num_edges = hyperedge_index[1].max().item() + 1 + x = ops.randn((num_nodes, in_channels)) + hyperedge_weight = ms.Tensor([1.0, 0.5]) + hyperedge_attr = ops.randn((num_edges, in_channels)) + + conv = HypergraphConv(in_channels, out_channels) + assert str(conv) == 'HypergraphConv(16, 32)' + out = conv(x, hyperedge_index) + assert out.shape == (num_nodes, out_channels) + out = conv(x, hyperedge_index, hyperedge_weight) + assert out.shape == (num_nodes, out_channels) + + conv = HypergraphConv(in_channels, out_channels, use_attention=True, + heads=2) + out = conv(x, hyperedge_index, hyperedge_attr=hyperedge_attr) + assert out.shape == (num_nodes, 2 * out_channels) + out = conv(x, hyperedge_index, hyperedge_weight, hyperedge_attr) + assert out.shape == (num_nodes, 2 * out_channels) + + conv = HypergraphConv(in_channels, out_channels, use_attention=True, + heads=2, concat=False, dropout=0.5) + out = conv(x, hyperedge_index, hyperedge_weight, hyperedge_attr) + assert out.shape == (num_nodes, out_channels) + + +def test_hypergraph_conv_with_more_edges_than_nodes(): + in_channels, out_channels = (16, 32) + hyperedge_index = ms.Tensor([[0, 0, 1, 1, 2, 3, 3, 3, 2, 1, 2], + [0, 1, 2, 1, 2, 1, 0, 3, 3, 4, 4]]) + hyperedge_weight = ms.Tensor([1.0, 0.5, 0.8, 0.2, 0.7]) + num_nodes = hyperedge_index[0].max().item() + 1 + x = ops.randn((num_nodes, in_channels)) + + conv = HypergraphConv(in_channels, out_channels) + assert str(conv) == 'HypergraphConv(16, 32)' + out = conv(x, hyperedge_index) + assert out.shape == (num_nodes, out_channels) + out = conv(x, hyperedge_index, hyperedge_weight) + assert out.shape == (num_nodes, out_channels) diff --git a/tests/graph/nn/conv/test_le_conv.py b/tests/graph/nn/conv/test_le_conv.py new file mode 100644 index 000000000..f338d3ba0 --- /dev/null +++ b/tests/graph/nn/conv/test_le_conv.py @@ -0,0 +1,30 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import LEConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_le_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = LEConv(16, 32) + assert str(conv) == 'LEConv(16, 32)' + out = conv(x, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x, adj1.t()), out).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, adj2.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + ops.isclose(jit(x, edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj2.t()), out).all() diff --git a/tests/graph/nn/conv/test_lg_conv.py b/tests/graph/nn/conv/test_lg_conv.py new file mode 100644 index 000000000..ddf67f2b8 --- /dev/null +++ b/tests/graph/nn/conv/test_lg_conv.py @@ -0,0 +1,40 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import LGConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_lg_conv(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = LGConv() + assert str(conv) == 'LGConv()' + out1 = conv(x, edge_index) + assert out1.shape == (4, 8) + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 8) + assert ops.isclose(conv(x, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit(x, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x, adj4.t()), out2, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_message_passing.py b/tests/graph/nn/conv/test_message_passing.py new file mode 100644 index 000000000..eaa995138 --- /dev/null +++ b/tests/graph/nn/conv/test_message_passing.py @@ -0,0 +1,383 @@ +import copy +import os.path as osp +import pytest +import mindspore as ms +from typing import Optional, Tuple, Union +from mindspore import Tensor, ops, nn, mint +from mindscience.sharker import typing +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.nn import MessagePassing, aggr +from mindscience.sharker.typing import ( + Adj, + OptPairTensor, + OptTensor, + Size, + SparseTensor, +) +from mindscience.sharker.utils import ( + add_self_loops, + scatter, +) + + +class MyConv(MessagePassing): + def __init__(self, in_channels: Union[int, Tuple[int, int]], + out_channels: int, aggr: str = 'add'): + super().__init__(aggr=aggr) + + if isinstance(in_channels, int): + in_channels = (in_channels, in_channels) + + self.lin_l = nn.Dense(in_channels[0], out_channels) + self.lin_r = nn.Dense(in_channels[1], out_channels) + + def construct( + self, + x: Union[Tensor, OptPairTensor], + edge_index: Adj, + edge_weight: OptTensor = None, + size: Size = None, + ) -> Tensor: + + if isinstance(x, Tensor): + x = (x, x) + + out = self.propagate(edge_index, x=x, edge_weight=edge_weight, + shape=size) + out = self.lin_l(out) + + x_r = x[1] + if x_r is not None: + out += self.lin_r(x_r) + + return out + + def message(self, x_j: Tensor, edge_weight: Optional[Tensor]) -> Tensor: + return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j + + def message_and_aggregate(self, adj_t: Adj, x: OptPairTensor) -> Tensor: + return spmm(adj_t, x[0], reduce=self.aggr) + + +class MyConvWithSelfLoops(MessagePassing): + def __init__(self, aggr: str = 'add'): + super().__init__(aggr=aggr) + + def construct(self, x: Tensor, edge_index: Tensor) -> Tensor: + edge_index, _ = add_self_loops(edge_index) + + return self.propagate(edge_index, x=x) + + +def test_my_conv_basic(): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.randn(edge_index.shape[1]) + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + + conv = MyConv(8, 32) + out = conv(x1, edge_index, value) + assert out.shape == (4, 32) + assert mint.isclose(conv(x1, edge_index, value, (4, 4)), out , rtol=1e-04, atol=1e-6).all() + if typing.WITH_SPARSE: + assert mint.isclose(conv(x1, adj2.t()), out , rtol=1e-04, atol=1e-6).all() + conv.fuse = False + if typing.WITH_SPARSE: + assert mint.isclose(conv(x1, adj2.t()), out , rtol=1e-04, atol=1e-6).all() + conv.fuse = True + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + + conv = MyConv((8, 16), 32) + out1 = conv((x1, x2), edge_index, value) + out2 = conv((x1, None), edge_index, value, (4, 2)) + assert out1.shape == (2, 32) + assert out2.shape == (2, 32) + assert mint.isclose(conv((x1, x2), edge_index, value, (4, 2)), out1).all() + if typing.WITH_SPARSE: + assert mint.isclose(conv((x1, x2), adj2.t()), out1 , rtol=1e-04, atol=1e-6).all() + assert mint.isclose(conv((x1, None), adj2.t()), out2 , rtol=1e-04, atol=1e-6).all() + conv.fuse = False + if typing.WITH_SPARSE: + assert mint.isclose(conv((x1, x2), adj2.t()), out1 , rtol=1e-04, atol=1e-6).all() + assert mint.isclose(conv((x1, None), adj2.t()), out2 , rtol=1e-04, atol=1e-6).all() + + +class MyCommentedConv(MessagePassing): + r"""This layer calls `self.propagate()` internally.""" + + def construct(self, x: Tensor, edge_index: Tensor) -> Tensor: + return self.propagate(edge_index, x=x) + + +def test_my_commented_conv(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = MyCommentedConv() + conv(x, edge_index) + if is_full_test(): + jit = ms.jit(conv) + jit(x, edge_index) + + +def test_my_conv_out_of_bounds(): + x = ops.randn(3, 8) + value = ops.randn(4) + + conv = MyConv(8, 32) + + with pytest.raises(IndexError, match="valid indices"): + edge_index = ms.Tensor([[-1, 1, 2, 2], [0, 0, 1, 1]]) + conv(x, edge_index, value) + + with pytest.raises(IndexError, match="valid indices"): + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + conv(x, edge_index, value) + + +@pytest.mark.parametrize('aggr', ['add', 'sum', 'mean', 'min', 'max', 'mul']) +def test_my_conv_aggr(aggr): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_weight = ops.randn(edge_index.shape[1]) + + conv = MyConv(8, 32, aggr=aggr) + out = conv(x, edge_index, edge_weight) + assert out.shape == (4, 32) + + +def test_my_static_graph_conv(): + x1 = ops.randn(3, 4, 8) + x2 = ops.randn(3, 2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.randn(edge_index.shape[1]) + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + + conv = MyConv(8, 32) + out = conv(x1, edge_index, value) + assert out.shape == (3, 4, 32) + assert mint.isclose(conv(x1, edge_index, value, (4, 4)), out).all() + if typing.WITH_SPARSE: + assert mint.isclose(conv(x1, adj.t()), out).all() + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + + conv = MyConv((8, 16), 32) + out1 = conv((x1, x2), edge_index, value) + out2 = conv((x1, None), edge_index, value, (4, 2)) + assert out1.shape == (3, 2, 32) + assert out2.shape == (3, 2, 32) + assert mint.isclose(conv((x1, x2), edge_index, value, (4, 2)), out1).all() + if typing.WITH_SPARSE: + assert mint.isclose(conv((x1, x2), adj.t()), out1).all() + assert mint.isclose(conv((x1, None), adj.t()), out2).all() + + +class MyMultipleAggrConv(MessagePassing): + def __init__(self, **kwargs): + super().__init__(aggr=['add', 'mean', 'max'], **kwargs) + + def construct(self, x: Tensor, edge_index: Adj) -> Tensor: + return self.propagate(edge_index, x=x) + + +@pytest.mark.parametrize('multi_aggr_tuple', [ + (dict(mode='cat'), 3), + (dict(mode='proj', mode_kwargs=dict(in_channels=16, out_channels=16)), 1) +]) +def test_my_multiple_aggr_conv(multi_aggr_tuple): + aggr_kwargs, expand = multi_aggr_tuple + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + + conv = MyMultipleAggrConv(aggr_kwargs=aggr_kwargs) + out = conv(x, edge_index) + assert out.shape == (4, 16 * expand) + if typing.WITH_SPARSE: + assert mint.isclose(conv(x, adj2.t()), out).all() + + +def test_copy(): + conv = MyConv(8, 32) + conv2 = copy.copy(conv) + + assert conv != conv2 + assert ops.equal(conv.lin_l.weight, conv2.lin_l.weight).all() + assert ops.equal(conv.lin_r.weight, conv2.lin_r.weight).all() + assert conv.lin_l.weight is conv2.lin_l.weight + assert conv.lin_r.weight is conv2.lin_r.weight + + conv = copy.deepcopy(conv) + assert conv != conv2 + assert ops.equal(conv.lin_l.weight, conv2.lin_l.weight).all() + assert ops.equal(conv.lin_r.weight, conv2.lin_r.weight).all() + assert conv.lin_l.weight is not conv2.lin_l.weight + assert conv.lin_r.weight is not conv2.lin_r.weight + + +class MyEdgeConv(MessagePassing): + def __init__(self): + super().__init__(aggr='add') + + def construct(self, x: Tensor, edge_index: Adj) -> Tensor: + edge_attr = self.edge_updater(edge_index, x=x) + return self.propagate(edge_index, edge_attr=edge_attr, + shape=(x.shape[0], x.shape[0])) + + def edge_update(self, x_j: Tensor, x_i: Tensor) -> Tensor: + return x_j - x_i + + def message(self, edge_attr: Tensor) -> Tensor: + return edge_attr + + +def test_my_edge_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + row, col = edge_index + expected = scatter(x[row] - x[col], col, dim=0, dim_size=4, reduce='sum') + + conv = MyEdgeConv() + out = conv(x, edge_index) + assert out.shape == (4, 16) + assert mint.isclose(out, expected).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + assert mint.isclose(conv(x, adj2.t()), out).all() + + +class MyDefaultArgConv(MessagePassing): + def __init__(self): + super().__init__(aggr='mean') + + def construct(self, x: Tensor, edge_index: Adj) -> Tensor: + return self.propagate(edge_index, x=x) + + def message(self, x_j, zeros: bool = True): + return x_j * 0 if zeros else x_j + + +def test_my_default_arg_conv(): + x = ops.randn(4, 1) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = MyDefaultArgConv() + assert conv(x, edge_index).view(-1).tolist() == [0, 0, 0, 0] + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + + if is_full_test(): + jit = ms.jit(conv) + assert jit(x, edge_index).view(-1).tolist() == [0, 0, 0, 0] + + +class MyMultipleOutputConv(MessagePassing): + def __init__(self): + super().__init__() + + def construct(self, x: Tensor, edge_index: Tensor) -> Tuple[Tensor, Tensor]: + return self.propagate(edge_index, x=x) + + def message(self, x_j: Tensor) -> Tuple[Tensor, Tensor]: + return x_j, x_j + + def aggregate(self, inputs: Tuple[Tensor, Tensor], + index: Tensor) -> Tuple[Tensor, Tensor]: + return (scatter(inputs[0], index, dim=0, reduce='sum'), + scatter(inputs[0], index, dim=0, reduce='mean')) + + def update(self, inputs: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: + return inputs + + +def test_tuple_output(): + conv = MyMultipleOutputConv() + + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + out1 = conv(x, edge_index) + assert isinstance(out1, tuple) and len(out1) == 2 + + +class MyExplainConv(MessagePassing): + def __init__(self): + super().__init__(aggr='add') + + def construct(self, x: Tensor, edge_index: Adj) -> Tensor: + return self.propagate(edge_index, x=x) + + +def test_explain_message(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = MyExplainConv() + conv.explain = True + assert conv.propagate.__module__.endswith('message_passing') + + with pytest.raises(ValueError, match="pre-defined 'edge_mask'"): + conv(x, edge_index) + + conv._edge_mask = ms.Tensor([0.0, 0.0, 0.0, 0.0]) + conv._apply_sigmoid = False + assert conv(x, edge_index).abs().sum() == 0. + + conv._edge_mask = ms.Tensor([1.0, 1.0, 1.0, 1.0]) + conv._apply_sigmoid = False + out1 = conv(x, edge_index) + + conv.explain = False + out2 = conv(x, edge_index) + assert mint.isclose(out1, out2).all() + + +class MyAggregatorConv(MessagePassing): + def __init__(self, **kwargs): + super().__init__(**kwargs) + + def construct(self, x: Tensor, edge_index: Adj) -> Tensor: + return self.propagate(edge_index, x=x) + + +@pytest.mark.parametrize('aggr_module', [ + aggr.MeanAggregation(), + aggr.SumAggregation(), + aggr.MaxAggregation(), + aggr.SoftmaxAggregation(), + aggr.PowerMeanAggregation(), + aggr.MultiAggregation(['mean', 'max']) +]) +def test_message_passing_with_aggr_module(aggr_module): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + row, col = edge_index + + conv = MyAggregatorConv(aggr=aggr_module) + assert isinstance(conv.aggr_module, aggr.Aggregation) + out = conv(x, edge_index) + assert out.shape[0] == 4 and out.shape[1] in {8, 16} + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_sizes=(4, 4)) + assert mint.isclose(conv(x, adj2.t()), out).all() + + +def test_message_passing_int32_edge_index(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]], dtype=ms.int32) + edge_weight = ops.randn(edge_index.shape[1]) + + conv = MyConv(8, 32) + assert conv(x, edge_index, edge_weight).shape == (4, 32) \ No newline at end of file diff --git a/tests/graph/nn/conv/test_mf_conv.py b/tests/graph/nn/conv/test_mf_conv.py new file mode 100644 index 000000000..19dc98da3 --- /dev/null +++ b/tests/graph/nn/conv/test_mf_conv.py @@ -0,0 +1,56 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import MFConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +def test_mf_conv(): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = MFConv(8, 32) + assert str(conv) == 'MFConv(8, 32)' + out = conv(x1, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, shape=(4, 4)), out).all() + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out).all() + assert ops.isclose(jit(x1, edge_index, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj.t()), out).all() + + # Test bipartite message passing: + conv = MFConv((8, 16), 32) + assert str(conv) == 'MFConv((8, 16), 32)' + + out1 = conv((x1, x2), edge_index) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, (4, 2)), out1).all() + + out2 = conv((x1, None), edge_index, (4, 2)) + assert out2.shape == (2, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj.t()), out1).all() + assert ops.isclose(conv((x1, None), adj.t()), out2).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index), out1).all() + assert ops.isclose(jit((x1, x2), edge_index, size=(4, 2)), out1).all() + assert ops.isclose(jit((x1, None), edge_index, size=(4, 2)), out2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj.t()), out1).all() + assert ops.isclose(jit((x1, None), adj.t()), out2).all() diff --git a/tests/graph/nn/conv/test_mixhop_conv.py b/tests/graph/nn/conv/test_mixhop_conv.py new file mode 100644 index 000000000..451f32ab8 --- /dev/null +++ b/tests/graph/nn/conv/test_mixhop_conv.py @@ -0,0 +1,41 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import MixHopConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_mixhop_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = MixHopConv(16, 32, powers=[0, 1, 2, 4]) + assert str(conv) == 'MixHopConv(16, 32, powers=[0, 1, 2, 4])' + + out1 = conv(x, edge_index) + assert out1.shape == (4, 128) + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 128) + assert ops.isclose(conv(x, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit(x, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x, adj4.t()), out2, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_nn_conv.py b/tests/graph/nn/conv/test_nn_conv.py new file mode 100644 index 000000000..733ca5f1d --- /dev/null +++ b/tests/graph/nn/conv/test_nn_conv.py @@ -0,0 +1,82 @@ +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Lin +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Seq + +from mindscience.sharker import typing +from mindscience.sharker.nn import NNConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_coo + + +def test_nn_conv(): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.rand(edge_index.shape[1], 3) + # adj1 = to_coo(edge_index, value, shape=(4, 4)) + + net = Seq(Lin(3, 32), ReLU(), Lin(32, 8 * 32)) + conv = NNConv(8, 32, net=net) + assert str(conv) == ( + 'NNConv(8, 32, aggr=add, nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n ' + '(2): Dense\n >)') + + out = conv(x1, edge_index, value) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, value, size=(4, 4)), out).all() + # assert ops.isclose(conv(x1, adj1.swapaxes(0, 1).coalesce()), out).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index, value), out).all() + assert ops.isclose(jit(x1, edge_index, value, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out).all() + + # Test bipartite message passing: + # adj1 = to_coo(edge_index, value, shape=(4, 2)) + + conv = NNConv((8, 16), 32, net=net) + assert str(conv) == ( + 'NNConv((8, 16), 32, aggr=add, nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n ' + '(2): Dense\n >)') + + out1 = conv((x1, x2), edge_index, value) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, value, (4, 2)), out1).all() + # assert ops.isclose(conv((x1, x2).all(), + # adj1.swapaxes(0, 1).coalesce()), out1) + + out2 = conv((x1, None), edge_index, value, (4, 2)) + assert out2.shape == (2, 32) + # assert ops.isclose(conv((x1, None).all(), + # adj1.swapaxes(0, 1).coalesce()), out2) + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out1).all() + assert ops.isclose(conv((x1, None), adj2.t()), out2).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index, value), out1).all() + assert ops.isclose(jit((x1, x2), edge_index, value, size=(4, 2)).all(), + out1) + assert ops.isclose(jit((x1, None), edge_index, value, size=(4, 2)).all(), + out2) + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out1).all() + assert ops.isclose(jit((x1, None), adj2.t()), out2).all() diff --git a/tests/graph/nn/conv/test_pan_conv.py b/tests/graph/nn/conv/test_pan_conv.py new file mode 100644 index 000000000..62b2a34c0 --- /dev/null +++ b/tests/graph/nn/conv/test_pan_conv.py @@ -0,0 +1,28 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import PANConv +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_pan_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + + conv = PANConv(16, 32, filter_size=2) + assert str(conv) == 'PANConv(16, 32, filter_size=2)' + out1, M1 = conv(x, edge_index) + assert out1.shape == (4, 32) + + out2, M2 = conv(x, adj1.t()) + assert ops.isclose(out1, out2, atol=1e-6).all() + assert ops.isclose(M1.to_dense(), M2.to_dense()).all() + + if typing.WITH_SPARSE: + out3, M3 = conv(x, adj2.t()) + assert ops.isclose(out1, out3, atol=1e-6).all() + assert ops.isclose(M1.to_dense(), M3.to_dense()).all() diff --git a/tests/graph/nn/conv/test_pdn_conv.py b/tests/graph/nn/conv/test_pdn_conv.py new file mode 100644 index 000000000..bf2a8db0e --- /dev/null +++ b/tests/graph/nn/conv/test_pdn_conv.py @@ -0,0 +1,55 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import PDNConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +def test_pdn_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + conv = PDNConv(16, 32, edge_dim=8, hidden_channels=128) + assert str(conv) == "PDNConv(16, 32)" + + out = conv(x, edge_index, edge_attr) + assert out.shape == (4, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) + assert ops.isclose(conv(x, adj.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index, edge_attr), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj.t()), out, atol=1e-6).all() + + +# def test_pdn_conv_with_sparse_node_input_feature(): +# x = torch.sparse_coo_tensor( +# indices=ms.Tensor([[0, 0], [0, 1]]), +# values=ms.Tensor([1.0, 1.0]), +# size=torch.Size([4, 16]), +# ) +# edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) +# edge_attr = ops.randn(edge_index.shape[1], 8) + +# conv = PDNConv(16, 32, edge_dim=8, hidden_channels=128) + +# out = conv(x, edge_index, edge_attr) +# assert out.shape == (4, 32) + +# if typing.WITH_SPARSE: +# adj = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) +# assert ops.isclose(conv(x, adj.t(), edge_attr), out, atol=1e-6).all() + +# if is_full_test(): +# jit = ms.jit(conv) +# assert ops.isclose(jit(x, edge_index, edge_attr), out).all() + +# if typing.WITH_SPARSE: +# assert ops.isclose(jit(x, adj.t(), edge_attr), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_pna_conv.py b/tests/graph/nn/conv/test_pna_conv.py new file mode 100644 index 000000000..2fe11710d --- /dev/null +++ b/tests/graph/nn/conv/test_pna_conv.py @@ -0,0 +1,75 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import DataLoader, NeighborLoader +from mindscience.sharker.nn import PNAConv +from mindscience.sharker.testing import is_full_test # , onlyNeighborSampler +from mindscience.sharker.typing import SparseTensor + +aggregators = ['sum', 'mean', 'min', 'max', 'var', 'std'] +scalers = [ + 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear' +] + + +@pytest.mark.parametrize('divide_input', [True, False]) +def test_pna_conv(divide_input): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + deg = ms.Tensor([0, 3, 0, 1]) + value = ops.rand(edge_index.shape[1], 3) + + conv = PNAConv(16, 32, aggregators, scalers, deg=deg, edge_dim=3, towers=4, + pre_layers=2, post_layers=2, divide_input=divide_input) + assert str(conv) == 'PNAConv(16, 32, towers=4, edge_dim=3)' + + out = conv(x, edge_index, value) + assert out.shape == (4, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index, value), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj.t()), out, atol=1e-6).all() + + +# @onlyNeighborSampler +def test_pna_conv_get_degree_histogram_neighbor_loader(): + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]]) + data = Graph(num_nodes=5, edge_index=edge_index) + loader = NeighborLoader( + data, + num_neighbors=[-1], + input_nodes=None, + batch_size=5, + shuffle=False, + ) + deg_hist = PNAConv.get_degree_histogram(loader) + assert ops.equal(deg_hist, ms.Tensor([1, 2, 1, 1])).all() + + +def test_pna_conv_get_degree_histogram_dataloader(): + edge_index_1 = ms.Tensor([[0, 0, 0, 1, 1, 2, 3], [1, 2, 3, 2, 0, 0, 0]]) + edge_index_2 = ms.Tensor([[1, 1, 2, 2, 0, 3, 3], [2, 3, 3, 1, 1, 0, 2]]) + edge_index_3 = ms.Tensor([[1, 3, 2, 0, 0, 4, 2], [2, 0, 4, 1, 1, 0, 3]]) + edge_index_4 = ms.Tensor([[0, 1, 2, 4, 0, 1, 3], [2, 3, 3, 1, 1, 0, 2]]) + + data_1 = Graph(num_nodes=5, edge_index=edge_index_1) # hist = [1, 2 ,1 ,1] + data_2 = Graph(num_nodes=5, edge_index=edge_index_2) # hist = [1, 1, 3] + data_3 = Graph(num_nodes=5, edge_index=edge_index_3) # hist = [0, 3, 2] + data_4 = Graph(num_nodes=5, edge_index=edge_index_4) # hist = [1, 1, 3] + + loader = DataLoader( + [data_1, data_2, data_3, data_4], + batch_size=1, + shuffle=False, + ) + deg_hist = PNAConv.get_degree_histogram(loader) + assert ops.equal(deg_hist, ms.Tensor([3, 7, 9, 1])).all() diff --git a/tests/graph/nn/conv/test_point_conv.py b/tests/graph/nn/conv/test_point_conv.py new file mode 100644 index 000000000..46286f806 --- /dev/null +++ b/tests/graph/nn/conv/test_point_conv.py @@ -0,0 +1,65 @@ +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Lin +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Seq + +from mindscience.sharker import typing +from mindscience.sharker.nn import PointNetConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_point_net_conv(): + x1 = ops.randn(4, 16) + pos1 = ops.randn(4, 3) + pos2 = ops.randn(2, 3) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + local_nn = Seq(Lin(16 + 3, 32), ReLU(), Lin(32, 32)) + global_nn = Seq(Lin(32, 32)) + conv = PointNetConv(local_nn, global_nn) + assert str(conv) == ( + 'PointNetConv(local_nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n ' + '(2): Dense\n >, global_nn=SequentialCell<\n ' + '(0): Dense\n >)' + ) + + out = conv(x1, pos1, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, pos1, adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, pos1, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, pos1, edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, pos1, adj2.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + out = conv(x1, (pos1, pos2), edge_index) + assert out.shape == (2, 32) + assert ops.isclose(conv((x1, None), (pos1, pos2), edge_index), out).all() + assert ops.isclose(conv(x1, (pos1, pos2), adj1.t()), out, atol=1e-6).all() + assert ops.isclose(conv((x1, None), (pos1, pos2), adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv(x1, (pos1, pos2), adj2.t()), out, atol=1e-6).all() + assert ops.isclose(conv((x1, None), (pos1, pos2), adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + assert ops.isclose(jit((x1, None), (pos1, pos2), edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, None), (pos1, pos2), adj2.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_point_gnn_conv.py b/tests/graph/nn/conv/test_point_gnn_conv.py new file mode 100644 index 000000000..50d6a308f --- /dev/null +++ b/tests/graph/nn/conv/test_point_gnn_conv.py @@ -0,0 +1,41 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import PointGNNConv +from mindscience.sharker.nn.models.mlp import MLP +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_point_gnn_conv(): + x = ops.randn(6, 8) + pos = ops.randn(6, 3) + edge_index = ms.Tensor([[0, 1, 1, 1, 2, 5], [1, 2, 3, 4, 3, 4]]) + adj1 = to_csr(edge_index, shape=(6, 6)) + + conv = PointGNNConv( + mlp_h=MLP([8, 16, 3]), + mlp_f=MLP([3 + 8, 16, 8]), + mlp_g=MLP([8, 16, 8]), + ) + assert str(conv) == ('PointGNNConv(\n' + ' mlp_h=MLP(8, 16, 3),\n' + ' mlp_f=MLP(11, 16, 8),\n' + ' mlp_g=MLP(8, 16, 8),\n' + ')') + + out = conv(x, pos, edge_index) + assert out.shape == (6, 8) + assert ops.isclose(conv(x, pos, adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(6, 6)) + assert ops.isclose(conv(x, pos, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, pos, edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, pos, adj2.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_point_transformer_conv.py b/tests/graph/nn/conv/test_point_transformer_conv.py new file mode 100644 index 000000000..8c69c3c97 --- /dev/null +++ b/tests/graph/nn/conv/test_point_transformer_conv.py @@ -0,0 +1,69 @@ +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Linear +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Sequential + +from mindscience.sharker import typing +from mindscience.sharker.nn import PointTransformerConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_point_transformer_conv(): + x1 = ops.rand(4, 16) + x2 = ops.randn(2, 8) + pos1 = ops.rand(4, 3) + pos2 = ops.randn(2, 3) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = PointTransformerConv(in_channels=16, out_channels=32) + assert str(conv) == 'PointTransformerConv(16, 32)' + + out = conv(x1, pos1, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, pos1, adj1.t()), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, pos1, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, pos1, edge_index), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, pos1, adj2.t()), out, atol=1e-6).all() + + pos_nn = Sequential(Linear(3, 16), ReLU(), Linear(16, 32)) + attn_nn = Sequential(Linear(32, 32), ReLU(), Linear(32, 32)) + conv = PointTransformerConv(16, 32, pos_nn, attn_nn) + + out = conv(x1, pos1, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, pos1, adj1.t()), out, atol=1e-6).all() + if typing.WITH_SPARSE: + assert ops.isclose(conv(x1, pos1, adj2.t()), out, atol=1e-6).all() + + # Test biparitite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + conv = PointTransformerConv((16, 8), 32) + assert str(conv) == 'PointTransformerConv((16, 8), 32)' + + out = conv((x1, x2), (pos1, pos2), edge_index) + assert out.shape == (2, 32) + assert ops.isclose(conv((x1, x2), (pos1, pos2), adj1.t()), out).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), (pos1, pos2), adj2.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), (pos1, pos2), edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), (pos1, pos2), adj2.t()), out).all() diff --git a/tests/graph/nn/conv/test_ppf_conv.py b/tests/graph/nn/conv/test_ppf_conv.py new file mode 100644 index 000000000..4f1a3140e --- /dev/null +++ b/tests/graph/nn/conv/test_ppf_conv.py @@ -0,0 +1,74 @@ +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Lin +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Seq + +from mindscience.sharker import typing +from mindscience.sharker.nn import PPFConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_ppf_conv(): + x1 = ops.randn(4, 16) + pos1 = ops.randn(4, 3) + pos2 = ops.randn(2, 3) + n1 = ops.rand(4, 3) + n1 /= ops.norm(n1, dim=-1, keepdim=True) + n2 = ops.rand(2, 3) + n2 /= ops.norm(n2, dim=-1, keepdim=True) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + local_nn = Seq(Lin(16 + 4, 32), ReLU(), Lin(32, 32)) + global_nn = Seq(Lin(32, 32)) + conv = PPFConv(local_nn, global_nn) + assert str(conv) == ( + 'PPFConv(local_nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n (2): Dense\n >, global_nn=SequentialCell<\n ' + '(0): Dense\n >)' + ) + + out = conv(x1, pos1, n1, edge_index) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, pos1, n1, adj1.t()), out, atol=1e-3).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, pos1, n1, adj2.t()), out, atol=1e-3).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, pos1, n1, edge_index), out, atol=1e-3).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, pos1, n1, adj2.t()), out, atol=1e-3).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + out = conv(x1, (pos1, pos2), (n1, n2), edge_index) + assert out.shape == (2, 32) + assert ops.isclose(conv((x1, None), (pos1, pos2), (n1, n2), edge_index), + out, atol=1e-3).all() + assert ops.isclose(conv(x1, (pos1, pos2), (n1, n2), adj1.t()), out, atol=1e-3).all() + assert ops.isclose(conv((x1, None), (pos1, pos2), (n1, n2), adj1.t()), out, atol=1e-3).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv(x1, (pos1, pos2), (n1, n2), adj2.t()), out, + atol=1e-3).all() + assert ops.isclose( + conv((x1, None), (pos1, pos2), (n1, n2), adj2.t()), out, atol=1e-3).all() + + if is_full_test(): + assert ops.isclose( + jit((x1, None), (pos1, pos2), (n1, n2), edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose( + jit((x1, None), (pos1, pos2), (n1, n2), adj2.t()), out, + atol=1e-3).all() diff --git a/tests/graph/nn/conv/test_res_gated_graph_conv.py b/tests/graph/nn/conv/test_res_gated_graph_conv.py new file mode 100644 index 000000000..5c6209afb --- /dev/null +++ b/tests/graph/nn/conv/test_res_gated_graph_conv.py @@ -0,0 +1,57 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import ResGatedGraphConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +@pytest.mark.parametrize('edge_dim', [None, 4]) +def test_res_gated_graph_conv(edge_dim): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 32) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_attr = ops.randn(edge_index.shape[1], edge_dim) if edge_dim else None + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = ResGatedGraphConv(8, 32, edge_dim=edge_dim) + assert str(conv) == 'ResGatedGraphConv(8, 32)' + + out = conv(x1, edge_index, edge_attr) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, adj1.t(), edge_attr), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index, edge_attr), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + conv = ResGatedGraphConv((8, 32), 32, edge_dim=edge_dim) + assert str(conv) == 'ResGatedGraphConv((8, 32), 32)' + + out = conv((x1, x2), edge_index, edge_attr) + assert out.shape == (2, 32) + assert ops.isclose(conv((x1, x2), adj1.t(), edge_attr), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index, edge_attr).all(), out, + atol=1e-6) + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_rgat_conv.py b/tests/graph/nn/conv/test_rgat_conv.py new file mode 100644 index 000000000..5c855dd26 --- /dev/null +++ b/tests/graph/nn/conv/test_rgat_conv.py @@ -0,0 +1,126 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import RGATConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_coo + + +@pytest.mark.parametrize('mod', [ + 'additive', + 'scaled', + 'f-additive', + 'f-scaled', +]) +@pytest.mark.parametrize('attention_mechanism', [ + 'within-relation', + 'across-relation', +]) +@pytest.mark.parametrize('attention_mode', [ + 'additive-self-attention', + 'multiplicative-self-attention', +]) +@pytest.mark.parametrize('concat', [True, False]) +@pytest.mark.parametrize('edge_dim', [8, None]) +def test_rgat_conv(mod, attention_mechanism, attention_mode, concat, edge_dim): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_type = ms.Tensor([0, 2, 1, 2]) + edge_attr = ops.randn((4, edge_dim)) if edge_dim else None + + conv1 = RGATConv( # `num_bases` is not None: + in_channels=8, + out_channels=16, + num_relations=4, + num_bases=4, + mod=mod, + attention_mechanism=attention_mechanism, + attention_mode=attention_mode, + heads=2, + dim=1, + concat=concat, + edge_dim=edge_dim, + ) + + conv2 = RGATConv( # `num_blocks` is not `None` + in_channels=8, + out_channels=16, + num_relations=4, + num_blocks=4, + mod=mod, + attention_mechanism=attention_mechanism, + attention_mode=attention_mode, + heads=2, + dim=1, + concat=concat, + edge_dim=edge_dim, + ) + + conv3 = RGATConv( # Both `num_bases` and `num_blocks` are `None`: + in_channels=8, + out_channels=16, + num_relations=4, + mod=mod, + attention_mechanism=attention_mechanism, + attention_mode=attention_mode, + heads=2, + dim=1, + concat=concat, + edge_dim=edge_dim, + ) + + conv4 = RGATConv( # `dropout > 0` and `mod` is `None`: + in_channels=8, + out_channels=16, + num_relations=4, + mod=None, + attention_mechanism=attention_mechanism, + attention_mode=attention_mode, + heads=2, + dim=1, + concat=concat, + edge_dim=edge_dim, + dropout=0.5, + ) + + for conv in [conv1, conv2, conv3, conv4]: + assert str(conv) == 'RGATConv(8, 16, heads=2)' + + out = conv(x, edge_index, edge_type, edge_attr) + assert out.shape == (4, 16 * (2 if concat else 1)) + + out, (adj, alpha) = conv(x, edge_index, edge_type, edge_attr, + return_attention_weights=True) + assert out.shape == (4, 16 * (2 if concat else 1)) + assert adj.shape == edge_index.shape + assert alpha.shape == (4, 2) + + +def test_rgat_conv_jit(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_attr = ops.randn((edge_index.shape[1], 8)) + edge_type = ms.Tensor([0, 2, 1, 2]) + # adj1 = to_coo(edge_index, edge_attr, shape=(4, 4)) + + conv = RGATConv(8, 20, num_relations=4, num_bases=4, mod='additive', + attention_mechanism='across-relation', + attention_mode='additive-self-attention', heads=2, dim=1, + edge_dim=8, bias=False) + + out = conv(x, edge_index, edge_type, edge_attr) + assert out.shape == (4, 40) + # t() expects a tensor with <= 2 sparse and 0 dense dimensions + # adj1_t = adj1.swapaxes(0, 1).coalesce() + # assert ops.isclose(conv(x, adj1_t, edge_type), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, (4, 4)) + assert ops.isclose(conv(x, adj2.t(), edge_type), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index, edge_type).all(), + conv(x, edge_index, edge_type)) diff --git a/tests/graph/nn/conv/test_rgcn_conv.py b/tests/graph/nn/conv/test_rgcn_conv.py new file mode 100644 index 000000000..47d5a7044 --- /dev/null +++ b/tests/graph/nn/conv/test_rgcn_conv.py @@ -0,0 +1,136 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import FastRGCNConv, RGCNConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + +classes = [RGCNConv, FastRGCNConv] +confs = [(None, None), (2, None), (None, 2)] + + +@pytest.mark.parametrize('conf', confs) +def test_rgcn_conv_equality(conf): + num_bases, num_blocks = conf + + x1 = ops.randn(4, 4) + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3], + [0, 0, 1, 0, 1, 1], + ]) + edge_type = ms.Tensor([0, 1, 1, 0, 0, 1]) + + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 0, 1, 1, 2, 2, 3], + [0, 0, 1, 0, 1, 1, 0, 0, 1, 0, 1, 1], + ]) + edge_type = ms.Tensor([0, 1, 1, 0, 0, 1, 2, 3, 3, 2, 2, 3]) + + ms.set_seed(12345) + conv1 = RGCNConv(4, 32, 4, num_bases, num_blocks, aggr='sum') + + ms.set_seed(12345) + conv2 = FastRGCNConv(4, 32, 4, num_bases, num_blocks, aggr='sum') + + out1 = conv1(x1, edge_index, edge_type) + out2 = conv2(x1, edge_index, edge_type) + out1[out1.isnan()] = 0 + out2[out2.isnan()] = 0 + assert ops.isclose(out1, out2, atol=1e-2).all() + + if num_blocks is None: + out1 = conv1(None, edge_index, edge_type) + out2 = conv2(None, edge_index, edge_type) + assert ops.isclose(out1, out2, atol=1e-2).all() + + +@pytest.mark.parametrize('cls', classes) +@pytest.mark.parametrize('conf', confs) +def test_rgcn_conv(cls, conf): + num_bases, num_blocks = conf + + x1 = ops.randn(4, 4) + x2 = ops.randn(2, 16) + idx1 = ops.arange(4) + idx2 = ops.arange(2) + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3], + [0, 0, 1, 0, 1, 1], + ]) + edge_type = ms.Tensor([0, 1, 1, 0, 0, 1]) + + conv = cls(4, 32, 2, num_bases, num_blocks, aggr='sum') + assert str(conv) == f'{cls.__name__}(4, 32, num_relations=2)' + + out1 = conv(x1, edge_index, edge_type) + assert out1.shape == (4, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 4)) + assert ops.isclose(conv(x1, adj.t()), out1, atol=1e-3).all() + + if num_blocks is None: + out1 = conv(idx1, edge_index, edge_type) + out2 = conv(None, edge_index, edge_type) + + assert ops.isclose(out1, out2, 1e-3).all() + assert out2.shape == (4, 32) + if typing.WITH_SPARSE: + assert ops.isclose(conv(None, adj.t()), out2, atol=1e-3).all() + assert ops.isclose(conv(idx1, adj.t()), out2, atol=1e-3).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index, edge_type), out1, atol=1e-3).all() + if num_blocks is None: + assert ops.isclose(jit(idx1, edge_index, edge_type), out2, + atol=1e-3).all() + assert ops.isclose(jit(None, edge_index, edge_type), out2, + atol=1e-3).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj.t()), out1).all() + if num_blocks is None: + assert ops.isclose(jit(idx1, adj.t()), out2, atol=1e-3).all() + assert ops.isclose(jit(None, adj.t()), out2, atol=1e-3).all() + + # Test bipartite message passing: + conv = cls((4, 16), 32, 2, num_bases, num_blocks, aggr='sum') + assert str(conv) == f'{cls.__name__}((4, 16), 32, num_relations=2)' + + out1 = conv((x1, x2), edge_index, edge_type) + assert out1.shape == (2, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_type, (4, 2)) + assert ops.isclose(conv((x1, x2), adj.t()), out1, atol=1e-3).all() + + if num_blocks is None: + out1 = conv((idx1, idx2), edge_index, edge_type) + out2 = conv((None, idx2), edge_index, edge_type) + out1[out1.isnan()] = 0 + out2[out2.isnan()] = 0 + assert out2.shape == (2, 32) + assert ops.isclose(out1, out2, atol=1e-3).all() + if typing.WITH_SPARSE: + assert ops.isclose(conv((None, idx2), adj.t()), out2, atol=1e-3).all() + assert ops.isclose(conv((idx1, idx2), adj.t()), out2, atol=1e-3).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index, edge_type), out1, + atol=1e-3).all() + if num_blocks is None: + assert ops.isclose(jit((None, idx2), edge_index, edge_type), + out2, atol=1e-3).all() + assert ops.isclose(jit((idx1, idx2), edge_index, edge_type), + out2, atol=1e-3).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj.t()), out1, atol=1e-3).all() + if num_blocks is None: + assert ops.isclose(jit((None, idx2), adj.t()), out2, + atol=1e-3).all() + assert ops.isclose(jit((idx1, idx2), adj.t()), out2, + atol=1e-3).all() diff --git a/tests/graph/nn/conv/test_sage_conv.py b/tests/graph/nn/conv/test_sage_conv.py new file mode 100644 index 000000000..2ab882b19 --- /dev/null +++ b/tests/graph/nn/conv/test_sage_conv.py @@ -0,0 +1,143 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import MLPAggregation, SAGEConv +from mindscience.sharker.testing import ( + assert_module, + is_full_test +) +from mindscience.sharker.typing import SparseTensor + + +@pytest.mark.parametrize('project', [False, True]) +@pytest.mark.parametrize('aggr', ['mean', 'sum']) +def test_sage_conv(project, aggr): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = SAGEConv(8, 32, project=project, aggr=aggr) + assert str(conv) == f'SAGEConv(8, 32, aggr={aggr})' + + out = assert_module(conv, x, edge_index, expected_size=(4, 32)) + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out, atol=1e-6).all() + assert ops.isclose(jit(x, edge_index, size=(4, 4)), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(jit(x, adj.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + + conv = SAGEConv((8, 16), 32, project=project, aggr=aggr) + assert str(conv) == f'SAGEConv((8, 16), 32, aggr={aggr})' + + out1 = assert_module(conv, (x1, x2), edge_index, expected_size=(2, 32)) + out2 = assert_module(conv, (x1, None), edge_index, size=(4, 2), + expected_size=(2, 32)) + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit((x1, x2), edge_index, size=(4, 2)), out1).all() + assert ops.isclose(jit((x1, None), edge_index, size=(4, 2)), out2).all() + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(jit((x1, x2), adj.t()), out1, atol=1e-6).all() + assert ops.isclose(jit((x1, None), adj.t()), out2, atol=1e-6).all() + + +# @pytest.mark.parametrize('project', [False, True]) +# def test_lazy_sage_conv(project): +# x = ops.randn(4, 8) +# edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + +# if project: +# with pytest.raises(ValueError, match="does not support lazy"): +# SAGEConv(-1, 32, project=project) +# else: +# conv = SAGEConv(-1, 32, project=project) +# assert str(conv) == 'SAGEConv(-1, 32, aggr=mean)' + +# out = conv(x, edge_index) +# assert out.shape == (4, 32) + + +def test_lstm_aggr_sage_conv(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = SAGEConv(8, 32, aggr='lstm') + assert str(conv) == 'SAGEConv(8, 32, aggr=lstm)' + + assert_module(conv, x, edge_index, expected_size=(4, 32), + test_edge_permutation=False) + + edge_index = ms.Tensor([[0, 1, 2, 3], [1, 0, 1, 0]]) + with pytest.raises(ValueError, match="'index' tensor is not sorted"): + conv(x, edge_index) + + +def test_mlp_sage_conv(): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = SAGEConv( + in_channels=8, + out_channels=32, + aggr=MLPAggregation( + in_channels=8, + out_channels=8, + max_num_elements=2, + num_layers=1, + ), + ) + + out = conv(x, edge_index) + assert out.shape == (4, 32) + + +@pytest.mark.parametrize('aggr_kwargs', [ + dict(mode='cat'), + dict(mode='proj', mode_kwargs=dict(in_channels=8, out_channels=16)), + dict(mode='attn', mode_kwargs=dict(in_channels=8, out_channels=16, + num_heads=4)), + dict(mode='sum'), +]) +def test_multi_aggr_sage_conv(aggr_kwargs): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + aggr_kwargs['aggrs_kwargs'] = [{}, {}, {}, dict(learn=True, t=1)] + conv = SAGEConv(8, 32, aggr=['mean', 'max', 'sum', 'softmax'], + aggr_kwargs=aggr_kwargs) + + assert_module(conv, x, edge_index, expected_size=(4, 32)) + + +# def test_compile_multi_aggr_sage_conv(device): +# import torch._dynamo as dynamo + +# x = ops.randn(4, 8) +# edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + +# conv = SAGEConv( +# in_channels=8, +# out_channels=32, +# aggr=['mean', 'sum', 'min', 'max', 'std'], +# ) + +# explanation = dynamo.explain(conv)(x, edge_index) +# assert explanation.graph_break_count == 0 + +# compiled_conv = torch.compile(conv) + +# expected = conv(x, edge_index) +# out = compiled_conv(x, edge_index) +# assert ops.isclose(out, expected, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_sg_conv.py b/tests/graph/nn/conv/test_sg_conv.py new file mode 100644 index 000000000..fd199f95f --- /dev/null +++ b/tests/graph/nn/conv/test_sg_conv.py @@ -0,0 +1,49 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import SGConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_sg_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = SGConv(16, 32, K=10) + assert str(conv) == 'SGConv(16, 32, K=10)' + + out1 = conv(x, edge_index) + assert out1.shape == (4, 32) + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 32) + assert ops.isclose(conv(x, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit(x, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x, adj4.t()), out2, atol=1e-6).all() + + conv.cached = True + conv(x, edge_index) + assert conv._cached_x is not None + assert ops.isclose(conv(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + if typing.WITH_SPARSE: + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_signed_conv.py b/tests/graph/nn/conv/test_signed_conv.py new file mode 100644 index 000000000..78df977d4 --- /dev/null +++ b/tests/graph/nn/conv/test_signed_conv.py @@ -0,0 +1,72 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import SignedConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_signed_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv1 = SignedConv(16, 32, first_aggr=True) + assert str(conv1) == 'SignedConv(16, 32, first_aggr=True)' + + conv2 = SignedConv(32, 48, first_aggr=False) + assert str(conv2) == 'SignedConv(32, 48, first_aggr=False)' + + out1 = conv1(x, edge_index, edge_index) + assert out1.shape == (4, 64) + assert ops.isclose(conv1(x, adj1.t(), adj1.t()), out1).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv1(x, adj2.t(), adj2.t()), out1).all() + + out2 = conv2(out1, edge_index, edge_index) + assert out2.shape == (4, 96) + assert ops.isclose(conv2(out1, adj1.t(), adj1.t()), out2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(conv2(out1, adj2.t(), adj2.t()), out2).all() + + if is_full_test(): + jit1 = ms.jit(conv1) + jit2 = ms.jit(conv2) + assert ops.isclose(jit1(x, edge_index, edge_index), out1).all() + assert ops.isclose(jit2(out1, edge_index, edge_index), out2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit1(x, adj2.t(), adj2.t()), out1).all() + assert ops.isclose(jit2(out1, adj2.t(), adj2.t()), out2).all() + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + assert ops.isclose(conv1((x, x[:2]), edge_index, edge_index), out1[:2], + atol=1e-6).all() + assert ops.isclose(conv1((x, x[:2]), adj1.t(), adj1.t()), out1[:2], atol=1e-6).all() + assert ops.isclose(conv2((out1, out1[:2]), edge_index, edge_index), out2[:2], atol=1e-6).all() + assert ops.isclose(conv2((out1, out1[:2]), adj1.t(), adj1.t()), out2[:2], atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 2)) + assert ops.isclose(conv1((x, x[:2]), adj2.t(), adj2.t()), out1[:2], + atol=1e-6).all() + assert ops.isclose(conv2((out1, out1[:2]), adj2.t(), adj2.t()), + out2[:2], atol=1e-6).all() + + if is_full_test(): + assert ops.isclose(jit1((x, x[:2]), edge_index, edge_index), + out1[:2], atol=1e-6).all() + assert ops.isclose(jit2((out1, out1[:2]), edge_index, edge_index), + out2[:2], atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit1((x, x[:2]), adj2.t(), adj2.t()), + out1[:2], atol=1e-6).all() + assert ops.isclose(jit2((out1, out1[:2]), adj2.t(), adj2.t()), + out2[:2], atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_simple_conv.py b/tests/graph/nn/conv/test_simple_conv.py new file mode 100644 index 000000000..0b8e95177 --- /dev/null +++ b/tests/graph/nn/conv/test_simple_conv.py @@ -0,0 +1,58 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import SimpleConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +@pytest.mark.parametrize('aggr, combine_root', [ + ('mean', None), + ('sum', 'sum'), + (['mean', 'max'], 'cat'), + ('mean', 'self_loop'), +]) +def test_simple_conv(aggr, combine_root): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [1, 0, 1, 1]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = SimpleConv(aggr, combine_root) + assert str(conv) == 'SimpleConv()' + + num_aggrs = 1 if isinstance(aggr, str) else len(aggr) + output_size = sum([8] * num_aggrs) + (8 if combine_root == 'cat' else 0) + + out = conv(x1, edge_index) + assert out.shape == (4, output_size) + assert ops.isclose(conv(x1, edge_index, size=(4, 4)), out).all() + assert ops.isclose(conv(x1, adj1.t()), out).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index), out).all() + assert ops.isclose(jit(x1, edge_index, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out).all() + + # Test bipartite message passing: + if combine_root != 'self_loop': + adj1 = to_csr(edge_index, shape=(4, 2)) + + out = conv((x1, x2), edge_index) + assert out.shape == (2, output_size) + assert ops.isclose(conv((x1, x2), edge_index, size=(4, 2)), out).all() + assert ops.isclose(conv((x1, x2), adj1.t()), out).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, + sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out).all() diff --git a/tests/graph/nn/conv/test_spline_conv.py b/tests/graph/nn/conv/test_spline_conv.py new file mode 100644 index 000000000..dc39f8e31 --- /dev/null +++ b/tests/graph/nn/conv/test_spline_conv.py @@ -0,0 +1,84 @@ +import warnings + +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import SplineConv +from mindscience.sharker.testing import is_full_test, withPackage +from mindscience.sharker.typing import SparseTensor + + +@withPackage('torch_spline_conv') +def test_spline_conv(): + warnings.filterwarnings('ignore', '.*non-optimized CPU version.*') + + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.rand(edge_index.shape[1], 3) + + conv = SplineConv(8, 32, dim=3, kernel_size=5) + assert str(conv) == 'SplineConv(8, 32, dim=3)' + out = conv(x1, edge_index, value) + assert out.shape == (4, 32) + assert ops.isclose(conv(x1, edge_index, value, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x1, adj.t()), out, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x1, edge_index, value), out, atol=1e-6).all() + assert ops.isclose(jit(x1, edge_index, value, size=(4, 4)), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + conv = SplineConv((8, 16), 32, dim=3, kernel_size=5) + assert str(conv) == 'SplineConv((8, 16), 32, dim=3)' + + out1 = conv((x1, x2), edge_index, value) + assert out1.shape == (2, 32) + assert ops.isclose(conv((x1, x2), edge_index, value, (4, 2)), out1).all() + + out2 = conv((x1, None), edge_index, value, (4, 2)) + assert out2.shape == (2, 32) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, value, (4, 2)) + assert ops.isclose(conv((x1, x2), adj.t()), out1, atol=1e-6).all() + assert ops.isclose(conv((x1, None), adj.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit((x1, x2), edge_index, value), out1).all() + assert ops.isclose(jit((x1, x2), edge_index, value, size=(4, 2)).all(), + out1, atol=1e-6) + assert ops.isclose(jit((x1, None), edge_index, value, size=(4, 2)).all(), + out2, atol=1e-6) + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj.t()), out1, atol=1e-6).all() + assert ops.isclose(jit((x1, None), adj.t()), out2, atol=1e-6).all() + + +@withPackage('torch_spline_conv') +def test_lazy_spline_conv(): + warnings.filterwarnings('ignore', '.*non-optimized CPU version.*') + + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + value = ops.rand(edge_index.shape[1], 3) + + conv = SplineConv(-1, 32, dim=3, kernel_size=5) + assert str(conv) == 'SplineConv(-1, 32, dim=3)' + out = conv(x1, edge_index, value) + assert out.shape == (4, 32) + + conv = SplineConv((-1, -1), 32, dim=3, kernel_size=5) + assert str(conv) == 'SplineConv((-1, -1), 32, dim=3)' + out = conv((x1, x2), edge_index, value) + assert out.shape == (2, 32) diff --git a/tests/graph/nn/conv/test_ssg_conv.py b/tests/graph/nn/conv/test_ssg_conv.py new file mode 100644 index 000000000..50df871af --- /dev/null +++ b/tests/graph/nn/conv/test_ssg_conv.py @@ -0,0 +1,49 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import SSGConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_ssg_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = SSGConv(16, 32, alpha=0.1, K=10) + assert str(conv) == 'SSGConv(16, 32, K=10, alpha=0.1)' + + out1 = conv(x, edge_index) + assert out1.shape == (4, 32) + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 32) + assert ops.isclose(conv(x, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit(x, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x, adj4.t()), out2, atol=1e-6).all() + + conv.cached = True + conv(x, edge_index) + assert conv._cached_h is not None + assert ops.isclose(conv(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + if typing.WITH_SPARSE: + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_static_graph.py b/tests/graph/nn/conv/test_static_graph.py new file mode 100644 index 000000000..955990872 --- /dev/null +++ b/tests/graph/nn/conv/test_static_graph.py @@ -0,0 +1,27 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Batch, Graph +from mindscience.sharker.nn import ChebConv, GCNConv, MessagePassing + + +class MyConv(MessagePassing): + def construct(self, x, edge_index): + return self.propagate(edge_index, x=x) + + +def test_static_graph(): + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + x1, x2 = ops.randn(3, 8), ops.randn(3, 8) + + data1 = Graph(edge_index=edge_index, x=x1) + data2 = Graph(edge_index=edge_index, x=x2) + batch = Batch.from_data_list([data1, data2]) + + x = ops.stack(([x1, x2]), axis=0) + for conv in [MyConv(), GCNConv(8, 16), ChebConv(8, 16, K=2)]: + out1 = conv(batch.x, batch.edge_index) + assert out1.shape[0] == 6 + conv.node_dim = 1 + out2 = conv(x, edge_index) + assert out2.shape[:2] == (2, 3) + assert ops.isclose(out1, out2.view(-1, out2.shape[-1])).all() diff --git a/tests/graph/nn/conv/test_supergat_conv.py b/tests/graph/nn/conv/test_supergat_conv.py new file mode 100644 index 000000000..cbb10ed11 --- /dev/null +++ b/tests/graph/nn/conv/test_supergat_conv.py @@ -0,0 +1,44 @@ +import pytest +import mindspore as ms +from mindspore import ops, Tensor +from mindscience.sharker import typing +from mindscience.sharker.nn import SuperGATConv +from mindscience.sharker.typing import SparseTensor + + +@pytest.mark.parametrize('att_type', ['MX', 'SD']) +def test_supergat_conv(att_type): + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + conv = SuperGATConv(8, 32, heads=2, attention_type=att_type, + neg_sample_ratio=1.0, edge_sample_ratio=1.0) + conv.set_train(True) + assert str(conv) == f'SuperGATConv(8, 32, heads=2, type={att_type})' + + out = conv(x, edge_index) + assert out.shape == (4, 64) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(conv(x, adj.t()), out, atol=1e-6).all() + + # Negative samples are given: + neg_edge_index = conv.negative_sampling(edge_index, x.shape[0]) + assert ops.isclose(conv(x, edge_index, neg_edge_index), out).all() + att_loss = conv.get_attention_loss() + assert isinstance(att_loss, Tensor) and att_loss > 0 + + # Batch of graphs: + x = ops.randn(8, 8) + edge_index = ms.Tensor([[0, 1, 2, 3, 4, 5, 6, 7], + [0, 0, 1, 1, 4, 4, 5, 5]]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 1]) + out = conv(x, edge_index, batch=batch) + assert out.shape == (8, 64) + + # Batch of graphs and negative samples are given: + neg_edge_index = conv.negative_sampling(edge_index, x.shape[0], batch) + assert ops.isclose(conv(x, edge_index, neg_edge_index), out).all() + att_loss = conv.get_attention_loss() + assert isinstance(att_loss, Tensor) and att_loss > 0 diff --git a/tests/graph/nn/conv/test_tag_conv.py b/tests/graph/nn/conv/test_tag_conv.py new file mode 100644 index 000000000..547952bc3 --- /dev/null +++ b/tests/graph/nn/conv/test_tag_conv.py @@ -0,0 +1,50 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import TAGConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_tag_conv(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + value = ops.rand(edge_index.shape[1]) + adj1 = to_csr(edge_index, shape=(4, 4)) + adj2 = to_csr(edge_index, value, shape=(4, 4)) + + conv = TAGConv(16, 32) + assert str(conv) == 'TAGConv(16, 32, K=3)' + + out1 = conv(x, edge_index) + assert out1.shape == (4, 32) + assert ops.isclose(conv(x, adj1.t()), out1, atol=1e-6).all() + + out2 = conv(x, edge_index, value) + assert out2.shape == (4, 32) + assert ops.isclose(conv(x, adj2.t()), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj3 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + adj4 = SparseTensor.from_edge_index(edge_index, value, (4, 4)) + assert ops.isclose(conv(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(conv(x, adj4.t()), out2, atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out1, atol=1e-6).all() + assert ops.isclose(jit(x, edge_index, value), out2, atol=1e-6).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj3.t()), out1, atol=1e-6).all() + assert ops.isclose(jit(x, adj4.t()), out2, atol=1e-6).all() + + +def test_static_tag_conv(): + x = ops.randn(3, 4, 16) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + + conv = TAGConv(16, 32) + out = conv(x, edge_index) + assert out.shape == (3, 4, 32) diff --git a/tests/graph/nn/conv/test_transformer_conv.py b/tests/graph/nn/conv/test_transformer_conv.py new file mode 100644 index 000000000..10585bab3 --- /dev/null +++ b/tests/graph/nn/conv/test_transformer_conv.py @@ -0,0 +1,154 @@ +from typing import Optional, Tuple + +import pytest +import mindspore as ms +from mindspore import Tensor, ops, nn + +from mindscience.sharker import typing +from mindscience.sharker.nn import TransformerConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import Adj, SparseTensor +from mindscience.sharker.utils import to_csr + + +@pytest.mark.parametrize('edge_dim', [None, 8]) +@pytest.mark.parametrize('concat', [True, False]) +def test_transformer_conv(edge_dim, concat): + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 16) + out_channels = 32 + heads = 2 + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_attr = ops.randn(edge_index.shape[1], edge_dim) if edge_dim else None + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = TransformerConv(8, out_channels, heads, beta=True, + edge_dim=edge_dim, concat=concat) + assert str(conv) == f'TransformerConv(8, {out_channels}, heads={heads})' + + out = conv(x1, edge_index, edge_attr) + assert out.shape == (4, out_channels * (heads if concat else 1)) + assert ops.isclose(conv(x1, adj1.t(), edge_attr), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, + sparse_shape=(4, 4)) + assert ops.isclose(conv(x1, adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: Adj, + edge_attr: Optional[Tensor] = None, + ) -> Tensor: + return self.conv(x, edge_index, edge_attr) + + jit = ms.jit(MyModule()) + assert ops.isclose(jit(x1, edge_index, edge_attr), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x1, adj2.t()), out, atol=1e-6).all() + + # Test `return_attention_weights`. + result = conv(x1, edge_index, edge_attr, return_attention_weights=True) + assert ops.isclose(result[0], out).all() + assert result[1][0].shape == (2, 4) + assert result[1][1].shape == (4, 2) + assert result[1][1].min() >= 0 and result[1][1].max() <= 1 + assert conv._alpha is None + + if typing.WITH_SPARSE: + result = conv(x1, adj2.t(), return_attention_weights=True) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1].shape == (4, 4, 2) and result[1].nnz() == 4 + assert conv._alpha is None + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: Tensor, + edge_attr: Optional[Tensor], + ) -> Tuple[Tensor, Tuple[Tensor, Tensor]]: + return self.conv(x, edge_index, edge_attr, + return_attention_weights=True) + + jit = ms.jit(MyModule()) + result = jit(x1, edge_index, edge_attr) + assert ops.isclose(result[0], out).all() + assert result[1][0].shape == (2, 4) + assert result[1][1].shape == (4, 2) + assert result[1][1].min() >= 0 and result[1][1].max() <= 1 + assert conv._alpha is None + + if typing.WITH_SPARSE: + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tensor, + edge_index: SparseTensor, + ) -> Tuple[Tensor, SparseTensor]: + return self.conv(x, edge_index, + return_attention_weights=True) + + jit = ms.jit(MyModule()) + result = jit(x1, adj2.t()) + assert ops.isclose(result[0], out, atol=1e-6).all() + assert result[1].shape == (4, 4, 2) and result[1].nnz() == 4 + assert conv._alpha is None + + # Test bipartite message passing: + adj1 = to_csr(edge_index, shape=(4, 2)) + + conv = TransformerConv((8, 16), out_channels, heads=heads, beta=True, + edge_dim=edge_dim, concat=concat) + assert str(conv) == (f'TransformerConv((8, 16), {out_channels}, ' + f'heads={heads})') + + out = conv((x1, x2), edge_index, edge_attr) + assert out.shape == (2, out_channels * (heads if concat else 1)) + assert ops.isclose(conv((x1, x2), adj1.t(), edge_attr), out, atol=1e-6).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, edge_attr, + sparse_shape=(4, 2)) + assert ops.isclose(conv((x1, x2), adj2.t()), out, atol=1e-6).all() + + if is_full_test(): + + class MyModule(nn.Cell): + def __init__(self): + super().__init__() + self.conv = conv + + def construct( + self, + x: Tuple[Tensor, Tensor], + edge_index: Adj, + edge_attr: Optional[Tensor] = None, + ) -> Tensor: + return self.conv(x, edge_index, edge_attr) + + jit = ms.jit(MyModule()) + assert ops.isclose(jit((x1, x2), edge_index, edge_attr), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, x2), adj2.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_wl_conv.py b/tests/graph/nn/conv/test_wl_conv.py new file mode 100644 index 000000000..c35f8e063 --- /dev/null +++ b/tests/graph/nn/conv/test_wl_conv.py @@ -0,0 +1,31 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import WLConv +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import to_csr + + +def test_wl_conv(): + x1 = ms.Tensor([1, 0, 0, 1]) + x2 = ops.one_hot(x1, x1.max() + 1) + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + adj1 = to_csr(edge_index, shape=(4, 4)) + + conv = WLConv() + assert str(conv) == 'WLConv<>' + + out = conv(x1, edge_index) + assert out.tolist() == [0, 1, 1, 0] + assert ops.equal(conv(x2, edge_index), out).all() + assert ops.equal(conv(x1, adj1.t()), out).all() + assert ops.equal(conv(x2, adj1.t()), out).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.equal(conv(x1, adj2.t()), out).all() + assert ops.equal(conv(x2, adj2.t()), out).all() + + assert conv.histogram(out).tolist() == [[2, 2]] + assert ops.isclose(conv.histogram(out, norm=True), + ms.Tensor([[0.7071, 0.7071]])).all() diff --git a/tests/graph/nn/conv/test_wl_conv_continuous.py b/tests/graph/nn/conv/test_wl_conv_continuous.py new file mode 100644 index 000000000..a08e51f25 --- /dev/null +++ b/tests/graph/nn/conv/test_wl_conv_continuous.py @@ -0,0 +1,53 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import WLConvContinuous +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +def test_wl_conv(): + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=ms.int64) + x = ms.Tensor([[-1], [0], [1]]).float() + + conv = WLConvContinuous() + assert str(conv) == 'WLConvContinuous()' + + out = conv(x, edge_index) + assert out.tolist() == [[-0.5], [0.0], [0.5]] + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(3, 3)) + assert ops.isclose(conv(x, adj.t()), out).all() + + if is_full_test(): + jit = ms.jit(conv) + assert ops.isclose(jit(x, edge_index), out).all() + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj.t()), out, atol=1e-6).all() + + # Test bipartite message passing: + x1 = ops.randn(4, 8) + x2 = ops.randn(2, 8) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_weight = ops.randn(edge_index.shape[1]) + + out1 = conv((x1, None), edge_index, edge_weight, size=(4, 2)) + assert out1.shape == (2, 8) + + out2 = conv((x1, x2), edge_index, edge_weight) + assert out2.shape == (2, 8) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_weight, (4, 2)) + assert ops.isclose(conv((x1, None), adj.t()), out1).all() + assert ops.isclose(conv((x1, x2), adj.t()), out2).all() + + if is_full_test(): + assert ops.isclose( + jit((x1, None), edge_index, edge_weight, size=(4, 2)), out1).all() + assert ops.isclose(jit((x1, x2), edge_index, edge_weight), out2).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit((x1, None), adj.t()), out1, atol=1e-6).all() + assert ops.isclose(jit((x1, x2), adj.t()), out2, atol=1e-6).all() diff --git a/tests/graph/nn/conv/test_x_conv.py b/tests/graph/nn/conv/test_x_conv.py new file mode 100644 index 000000000..61333fab8 --- /dev/null +++ b/tests/graph/nn/conv/test_x_conv.py @@ -0,0 +1,30 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import XConv +from mindscience.sharker.testing import is_full_test + + +def test_x_conv(): + x = ops.randn(8, 16) + pos = ops.rand(8, 3) + batch = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 1]) + + conv = XConv(16, 32, dim=3, kernel_size=2, dilation=2) + assert str(conv) == 'XConv(16, 32)' + + ms.set_seed(12345) + out1 = conv(x, pos) + assert out1.shape == (8, 32) + + ms.set_seed(12345) + out2 = conv(x, pos, batch) + assert out2.shape == (8, 32) + + if is_full_test(): + jit = ms.jit(conv) + + ms.set_seed(12345) + assert ops.isclose(jit(x, pos), out1, atol=1e-6).all() + + ms.set_seed(12345) + assert ops.isclose(jit(x, pos, batch), out2, atol=1e-6).all() diff --git a/tests/graph/nn/dense/test_dense_gat_conv.py b/tests/graph/nn/dense/test_dense_gat_conv.py new file mode 100644 index 000000000..eafdf62d2 --- /dev/null +++ b/tests/graph/nn/dense/test_dense_gat_conv.py @@ -0,0 +1,66 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import DenseGATConv, GATConv +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('heads', [1, 4]) +@pytest.mark.parametrize('concat', [True, False]) +def test_dense_gat_conv(heads, concat): + channels = 16 + sparse_conv = GATConv(channels, channels, heads=heads, concat=concat) + dense_conv = DenseGATConv(channels, channels, heads=heads, concat=concat) + assert str(dense_conv) == f'DenseGATConv(16, 16, heads={heads})' + + # Ensure same weights and bias: + dense_conv.lin = sparse_conv.lin + dense_conv.att_src = sparse_conv.att_src + dense_conv.att_dst = sparse_conv.att_dst + dense_conv.bias = sparse_conv.bias + + x = ops.randn((5, channels)) + edge_index = ms.Tensor([[0, 1, 1, 2, 3, 4], [1, 0, 2, 1, 4, 3]]) + + sparse_out = sparse_conv(x, edge_index) + + x = ops.cat(([x, ops.zeros([1, channels], dtype=x.dtype)]), axis=0).view(2, 3, channels) + adj = ms.Tensor([ + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 1.0], + [0.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + ]) + mask = ms.Tensor([[1, 1, 1], [1, 1, 0]]).bool() + + dense_out = dense_conv(x, adj, mask) + + if is_full_test(): + jit = ms.jit(dense_conv) + assert ops.isclose(jit(x, adj, mask), dense_out).all() + + assert dense_out[1, 2].abs().sum() == 0 + dense_out = dense_out.view(6, -1)[:-1] + assert ops.isclose(sparse_out, dense_out, atol=1e-4).all() + + +def test_dense_gat_conv_with_broadcasting(): + batch_size, num_nodes, channels = 8, 3, 16 + conv = DenseGATConv(channels, channels, heads=4) + + x = ops.randn(batch_size, num_nodes, channels) + adj = ms.Tensor([ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ]) + + assert conv(x, adj).shape == (batch_size, num_nodes, 64) + mask = ms.Tensor([1, 1, 1]).bool() + assert conv(x, adj, mask).shape == (batch_size, num_nodes, 64) diff --git a/tests/graph/nn/dense/test_dense_gcn_conv.py b/tests/graph/nn/dense/test_dense_gcn_conv.py new file mode 100644 index 000000000..441b78c18 --- /dev/null +++ b/tests/graph/nn/dense/test_dense_gcn_conv.py @@ -0,0 +1,64 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import DenseGCNConv, GCNConv +from mindscience.sharker.testing import is_full_test + + +def test_dense_gcn_conv(): + channels = 16 + sparse_conv = GCNConv(channels, channels) + dense_conv = DenseGCNConv(channels, channels) + assert str(dense_conv) == 'DenseGCNConv(16, 16)' + + # Ensure same weights and bias: + dense_conv.lin.weight = sparse_conv.lin.weight + dense_conv.bias = sparse_conv.bias + + x = ops.randn((5, channels)) + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 3, 4], + [1, 2, 0, 2, 0, 1, 4, 3]]) + + sparse_out = sparse_conv(x, edge_index) + assert sparse_out.shape == (5, channels) + + x = ops.cat(([x, ops.zeros([1, channels], dtype=x.dtype)]), axis=0).view(2, 3, channels) + adj = ms.Tensor([ + [ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + ]) + mask = ms.Tensor([[1, 1, 1], [1, 1, 0]]).bool() + + dense_out = dense_conv(x, adj, mask) + assert dense_out.shape == (2, 3, channels) + + if is_full_test(): + jit = ms.jit(dense_conv) + assert ops.isclose(jit(x, adj, mask), dense_out).all() + + assert dense_out[1, 2].abs().sum() == 0 + dense_out = dense_out.view(6, channels)[:-1] + assert ops.isclose(sparse_out, dense_out, atol=1e-4).all() + + +def test_dense_gcn_conv_with_broadcasting(): + batch_size, num_nodes, channels = 8, 3, 16 + conv = DenseGCNConv(channels, channels) + + x = ops.randn(batch_size, num_nodes, channels) + adj = ms.Tensor([ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ]) + + assert conv(x, adj).shape == (batch_size, num_nodes, channels) + mask = ms.Tensor([1, 1, 1]).bool() + assert conv(x, adj, mask).shape == (batch_size, num_nodes, channels) diff --git a/tests/graph/nn/dense/test_dense_gin_conv.py b/tests/graph/nn/dense/test_dense_gin_conv.py new file mode 100644 index 000000000..533123f63 --- /dev/null +++ b/tests/graph/nn/dense/test_dense_gin_conv.py @@ -0,0 +1,71 @@ +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Lin +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Seq + +from mindscience.sharker.nn import DenseGINConv, GINConv +from mindscience.sharker.testing import is_full_test + + +def test_dense_gin_conv(): + channels = 16 + nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels)) + sparse_conv = GINConv(nn) + dense_conv = DenseGINConv(nn) + dense_conv = DenseGINConv(nn, train_eps=True) + assert str(dense_conv) == ( + 'DenseGINConv(nn=SequentialCell<\n ' + '(0): Dense\n ' + '(1): ReLU<>\n ' + '(2): Dense\n >)') + + x = ops.randn((5, channels)) + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 3, 4], + [1, 2, 0, 2, 0, 1, 4, 3]]) + + sparse_out = sparse_conv(x, edge_index) + assert sparse_out.shape == (5, channels) + + x = ops.cat(([x, ops.zeros([1, channels], dtype=x.dtype)]), axis=0).view(2, 3, channels) + adj = ms.Tensor([ + [ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + ]) + mask = ms.Tensor([[1, 1, 1], [1, 1, 0]]).bool() + + dense_out = dense_conv(x, adj, mask) + assert dense_out.shape == (2, 3, channels) + + if is_full_test(): + jit = ms.jit(dense_conv) + assert ops.isclose(jit(x, adj, mask), dense_out).all() + + assert dense_out[1, 2].abs().sum().item() == 0 + dense_out = dense_out.view(6, channels)[:-1] + assert ops.isclose(sparse_out, dense_out, atol=1e-04).all() + + +def test_dense_gin_conv_with_broadcasting(): + batch_size, num_nodes, channels = 8, 3, 16 + nn = Seq(Lin(channels, channels), ReLU(), Lin(channels, channels)) + conv = DenseGINConv(nn) + + x = ops.randn(batch_size, num_nodes, channels) + adj = ms.Tensor([ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ]) + + assert conv(x, adj).shape == (batch_size, num_nodes, channels) + mask = ms.Tensor([1, 1, 1]).bool() + assert conv(x, adj, mask).shape == (batch_size, num_nodes, channels) diff --git a/tests/graph/nn/dense/test_dense_graph_conv.py b/tests/graph/nn/dense/test_dense_graph_conv.py new file mode 100644 index 000000000..32061be9c --- /dev/null +++ b/tests/graph/nn/dense/test_dense_graph_conv.py @@ -0,0 +1,94 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import DenseGraphConv, GraphConv +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.utils import to_dense_adj + + +@pytest.mark.parametrize('aggr', ['add', 'mean', 'max']) +def test_dense_graph_conv(aggr): + channels = 16 + sparse_conv = GraphConv(channels, channels, aggr=aggr) + dense_conv = DenseGraphConv(channels, channels, aggr=aggr) + assert str(dense_conv) == 'DenseGraphConv(16, 16)' + + # Ensure same weights and bias. + dense_conv.lin_rel = sparse_conv.lin_rel + dense_conv.lin_root = sparse_conv.lin_root + + x = ops.randn((5, channels)) + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 3, 4], + [1, 2, 0, 2, 0, 1, 4, 3]]) + + sparse_out = sparse_conv(x, edge_index) + assert sparse_out.shape == (5, channels) + + adj = to_dense_adj(edge_index) + mask = ops.ones(5).bool() + + dense_out = dense_conv(x, adj, mask)[0] + + assert dense_out.shape == (5, channels) + assert ops.isclose(sparse_out, dense_out, atol=1e-04).all() + + if is_full_test(): + jit = ms.jit(dense_conv) + assert ops.isclose(jit(x, adj, mask), dense_out).all() + + +@pytest.mark.parametrize('aggr', ['add', 'mean', 'max']) +def test_dense_graph_conv_batch(aggr): + channels = 16 + sparse_conv = GraphConv(channels, channels, aggr=aggr) + dense_conv = DenseGraphConv(channels, channels, aggr=aggr) + + # Ensure same weights and bias. + dense_conv.lin_rel = sparse_conv.lin_rel + dense_conv.lin_root = sparse_conv.lin_root + + x = ops.randn((5, channels)) + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 3, 4], + [1, 2, 0, 2, 0, 1, 4, 3]]) + + sparse_out = sparse_conv(x, edge_index) + assert sparse_out.shape == (5, channels) + + x = ops.cat(([x, ops.zeros([1, channels], dtype=x.dtype)]), axis=0).view(2, 3, channels) + adj = ms.Tensor([ + [ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + ]) + mask = ms.Tensor([[1, 1, 1], [1, 1, 0]]).bool() + + dense_out = dense_conv(x, adj, mask) + assert dense_out.shape == (2, 3, channels) + dense_out = dense_out.view(-1, channels) + + assert ops.isclose(sparse_out, dense_out[:5], atol=1e-04).all() + assert dense_out[-1].abs().sum() == 0 + + +@pytest.mark.parametrize('aggr', ['add', 'mean', 'max']) +def test_dense_graph_conv_with_broadcasting(aggr): + batch_size, num_nodes, channels = 8, 3, 16 + conv = DenseGraphConv(channels, channels, aggr=aggr) + + x = ops.randn(batch_size, num_nodes, channels) + adj = ms.Tensor([ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ]) + + assert conv(x, adj).shape == (batch_size, num_nodes, channels) + mask = ms.Tensor([1, 1, 1]).bool() + assert conv(x, adj, mask).shape == (batch_size, num_nodes, channels) diff --git a/tests/graph/nn/dense/test_dense_sage_conv.py b/tests/graph/nn/dense/test_dense_sage_conv.py new file mode 100644 index 000000000..2b12063b4 --- /dev/null +++ b/tests/graph/nn/dense/test_dense_sage_conv.py @@ -0,0 +1,64 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import DenseSAGEConv, SAGEConv +from mindscience.sharker.testing import is_full_test + + +def test_dense_sage_conv(): + channels = 16 + sparse_conv = SAGEConv(channels, channels, normalize=True) + dense_conv = DenseSAGEConv(channels, channels, normalize=True) + assert str(dense_conv) == 'DenseSAGEConv(16, 16)' + + # Ensure same weights and bias. + dense_conv.lin_rel = sparse_conv.lin_l + dense_conv.lin_root = sparse_conv.lin_r + + x = ops.randn((5, channels)) + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 3, 4], + [1, 2, 0, 2, 0, 1, 4, 3]]) + + sparse_out = sparse_conv(x, edge_index) + assert sparse_out.shape == (5, channels) + + x = ops.cat(([x, ops.zeros([1, channels], dtype=x.dtype)]), axis=0).view(2, 3, channels) + adj = ms.Tensor([ + [ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + ]) + mask = ms.Tensor([[1, 1, 1], [1, 1, 0]]).bool() + + dense_out = dense_conv(x, adj, mask) + assert dense_out.shape == (2, 3, channels) + + if is_full_test(): + jit = ms.jit(dense_conv) + assert ops.isclose(jit(x, adj, mask), dense_out).all() + + assert dense_out[1, 2].abs().sum().item() == 0 + dense_out = dense_out.view(6, channels)[:-1] + assert ops.isclose(sparse_out, dense_out, atol=1e-04).all() + + +def test_dense_sage_conv_with_broadcasting(): + batch_size, num_nodes, channels = 8, 3, 16 + conv = DenseSAGEConv(channels, channels) + + x = ops.randn(batch_size, num_nodes, channels) + adj = ms.Tensor([ + [0.0, 1.0, 1.0], + [1.0, 0.0, 1.0], + [1.0, 1.0, 0.0], + ]) + + assert conv(x, adj).shape == (batch_size, num_nodes, channels) + mask = ms.Tensor([1, 1, 1]).bool() + assert conv(x, adj, mask).shape == (batch_size, num_nodes, channels) diff --git a/tests/graph/nn/dense/test_diff_pool.py b/tests/graph/nn/dense/test_diff_pool.py new file mode 100644 index 000000000..306b0bdb2 --- /dev/null +++ b/tests/graph/nn/dense/test_diff_pool.py @@ -0,0 +1,65 @@ +from itertools import product + +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import dense_diff_pool +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import is_full_test + + +def test_dense_diff_pool(): + batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10) + x = ops.randn((batch_size, num_nodes, channels)) + adj = ops.rand((batch_size, num_nodes, num_nodes)) + s = ops.randn((batch_size, num_nodes, num_clusters)) + mask = ops.randint(0, 2, (batch_size, num_nodes)).bool() + + x_out, adj_out, link_loss, ent_loss = dense_diff_pool(x, adj, s, mask) + assert x_out.shape == (2, 10, 16) + assert adj_out.shape == (2, 10, 10) + assert link_loss.item() >= 0 + assert ent_loss.item() >= 0 + + if is_full_test(): + jit = ms.jit(dense_diff_pool) + x_jit, adj_jit, link_loss, ent_loss = jit(x, adj, s, mask) + assert ops.isclose(x_jit, x_out).all() + assert ops.isclose(adj_jit, adj_out).all() + assert link_loss.item() >= 0 + assert ent_loss.item() >= 0 + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--backward', action='store_true') + args = parser.parse_args() + + BS = [2**i for i in range(4, 8)] + NS = [2**i for i in range(4, 8)] + FS = [2**i for i in range(5, 9)] + CS = [2**i for i in range(5, 9)] + + funcs = [] + func_names = [] + args_list = [] + for B, N, F, C in product(BS, NS, FS, CS): + x = ops.randn(B, N, F) + adj = ops.randint(0, 2, (B, N, N), dtype=x.dtype) + s = ops.randn(B, N, C) + + funcs.append(dense_diff_pool) + func_names.append(f'B={B}, N={N}, F={F}, C={C}') + args_list.append((x, adj, s)) + + benchmark( + funcs=funcs, + func_names=func_names, + args=args_list, + num_steps=50 if args.device == 'cpu' else 500, + num_warmups=10 if args.device == 'cpu' else 100, + backward=args.backward, + progress_bar=True, + ) diff --git a/tests/graph/nn/dense/test_dmon_pool.py b/tests/graph/nn/dense/test_dmon_pool.py new file mode 100644 index 000000000..933d38e9f --- /dev/null +++ b/tests/graph/nn/dense/test_dmon_pool.py @@ -0,0 +1,21 @@ +import math +from mindspore import ops +from mindscience.sharker.nn import DMoNPooling + + +def test_dmon_pooling(): + batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10) + x = ops.randn((batch_size, num_nodes, channels)) + adj = ops.ones((batch_size, num_nodes, num_nodes)) + mask = ops.randint(0, 2, (batch_size, num_nodes)).bool() + + pool = DMoNPooling([channels, channels], num_clusters) + assert str(pool) == 'DMoNPooling(16, num_clusters=10)' + + s, x, adj, spectral_loss, ortho_loss, cluster_loss = pool(x, adj, mask) + assert s.shape == (2, 20, 10) + assert x.shape == (2, 10, 16) + assert adj.shape == (2, 10, 10) + assert -1 <= spectral_loss <= 0.5 + assert 0 <= ortho_loss <= math.sqrt(2) + assert 0 <= cluster_loss <= math.sqrt(num_clusters) - 1 diff --git a/tests/graph/nn/dense/test_linear.py b/tests/graph/nn/dense/test_linear.py new file mode 100644 index 000000000..5f009206e --- /dev/null +++ b/tests/graph/nn/dense/test_linear.py @@ -0,0 +1,248 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import HeteroDictLinear, HeteroLinear +from mindscience.sharker.testing import is_full_test + + +weight_inits = ['glorot', 'kaiming_uniform', None] +bias_inits = ['zeros', None] + + +# @pytest.mark.parametrize('weight', weight_inits) +# @pytest.mark.parametrize('bias', bias_inits) +# def test_linear(weight, bias, device): +# x = ops.randn(3, 4, 16) +# lin = Linear(16, 32, weight_initializer=weight, bias_initializer=bias) +# lin = lin +# assert str(lin) == 'Linear(16, 32, bias=True)' +# assert lin(x).shape == (3, 4, 32) + + +# @pytest.mark.parametrize('weight', weight_inits) +# @pytest.mark.parametrize('bias', bias_inits) +# def test_lazy_linear(weight, bias, device): +# x = ops.randn(3, 4, 16) +# lin = Linear(-1, 32, weight_initializer=weight, bias_initializer=bias) +# lin = lin +# copied_lin = copy.deepcopy(lin) + +# assert lin.weight.device == device +# assert lin.bias.device == device +# assert str(lin) == 'Linear(-1, 32, bias=True)' +# assert lin(x).shape == (3, 4, 32) +# assert str(lin) == 'Linear(16, 32, bias=True)' + +# assert copied_lin.weight.device == device +# assert copied_lin.bias.device == device +# assert copied_lin(x).shape == (3, 4, 32) + + +# @pytest.mark.parametrize('dim1', [-1, 16]) +# @pytest.mark.parametrize('dim2', [-1, 16]) +# @pytest.mark.parametrize('bias', [True, False]) +# def test_load_lazy_linear(dim1, dim2, bias, device): +# lin1 = Linear(dim1, 32, bias=bias) +# lin2 = Linear(dim2, 32, bias=bias) +# lin2.load_state_dict(lin1.state_dict()) + +# if dim1 != -1: +# assert isinstance(lin1.weight, torch.nn.Parameter) +# assert isinstance(lin2.weight, torch.nn.Parameter) +# assert ops.isclose(lin1.weight, lin2.weight).all() +# assert not hasattr(lin1, '_hook') +# assert not hasattr(lin2, '_hook') +# else: +# assert isinstance(lin1.weight, UninitializedParameter) +# assert isinstance(lin2.weight, UninitializedParameter) +# assert hasattr(lin1, '_hook') +# assert hasattr(lin2, '_hook') + +# if bias: +# assert isinstance(lin1.bias, torch.nn.Parameter) +# assert isinstance(lin2.bias, torch.nn.Parameter) +# if dim1 != -1: # Only check for equality on materialized bias: +# assert ops.isclose(lin1.bias, lin2.bias).all() +# else: +# assert lin1.bias is None +# assert lin2.bias is None + +# with pytest.raises(RuntimeError, match="in state_dict"): +# lin1.load_state_dict({}, strict=True) +# lin1.load_state_dict({}, strict=False) + + +# @pytest.mark.parametrize('lazy', [True, False]) +# def test_identical_linear_default_initialization(lazy): +# x = ops.randn(3, 4, 16) + +# ms.set_seed(12345) +# lin1 = Linear(-1 if lazy else 16, 32) +# lin1(x) + +# ms.set_seed(12345) +# lin2 = PTLinear(16, 32) + +# assert ops.equal(lin1.weight, lin2.weight).all() +# assert ops.equal(lin1.bias, lin2.bias).all() +# assert ops.isclose(lin1(x), lin2(x)).all() + + +# def test_copy_unintialized_parameter(): +# weight = UninitializedParameter() +# if typing.WITH_PT113: +# copy.deepcopy(weight) +# else: # PyTorch <= 1.12 +# with pytest.raises(Exception): +# copy.deepcopy(weight) + + +# @pytest.mark.parametrize('lazy', [True, False]) +# def test_copy_linear(lazy, device): +# lin = Linear(-1 if lazy else 16, 32) + +# copied_lin = copy.copy(lin) +# assert id(copied_lin) != id(lin) +# assert id(copied_lin.weight) == id(lin.weight) +# if not isinstance(copied_lin.weight, UninitializedParameter): +# assert copied_lin.weight.data_ptr() == lin.weight.data_ptr() +# assert id(copied_lin.bias) == id(lin.bias) +# assert copied_lin.bias.data_ptr() == lin.bias.data_ptr() + +# copied_lin = copy.deepcopy(lin) +# assert id(copied_lin) != id(lin) +# assert id(copied_lin.weight) != id(lin.weight) +# if not isinstance(copied_lin.weight, UninitializedParameter): +# assert copied_lin.weight.data_ptr() != lin.weight.data_ptr() +# assert ops.isclose(copied_lin.weight, lin.weight).all() +# assert id(copied_lin.bias) != id(lin.bias) +# assert copied_lin.bias.data_ptr() != lin.bias.data_ptr() +# if int(torch.isnan(lin.bias).sum()) == 0: +# assert ops.isclose(copied_lin.bias, lin.bias).all() + + +def test_hetero_linear(): + x = ops.randn(3, 16) + type_vec = ms.Tensor([0, 1, 2]) + + lin = HeteroLinear(16, 32, num_types=3) + assert str(lin) == 'HeteroLinear(16, 32, num_types=3, bias=True)' + + out = lin(x, type_vec) + assert out.shape == (3, 32) + if is_full_test(): + jit = ms.jit(lin) + assert ops.isclose(jit(x, type_vec), out, atol=1e-3).all() + + +def test_hetero_linear_initializer(): + lin = HeteroLinear( + 16, + 32, + num_types=3, + weight_initializer='glorot', + bias_initializer='zeros', + ) + assert ops.equal(lin.bias, ops.zeros_like(lin.bias)).all() + + +# @pytest.mark.parametrize('use_segment_matmul', [None, True, False]) +# def test_hetero_linear_amp(device, use_segment_matmul): +# warnings.filterwarnings('ignore', '.*but CUDA is not available.*') + +# old_state = sharker.backend.use_segment_matmul +# sharker.backend.use_segment_matmul = use_segment_matmul + +# x = ops.randn(3, 16) +# type_vec = ms.Tensor([0, 1, 2]) + +# lin = HeteroLinear(16, 32, num_types=3) + +# with torch.cuda.amp.autocast(): +# assert lin(x, type_vec).shape == (3, 32) + +# sharker.backend.use_segment_matmul = old_state + + +# def test_lazy_hetero_linear(device): +# x = ops.randn(3, 16) +# type_vec = ms.Tensor([0, 1, 2]) + +# lin = HeteroLinear(-1, 32, num_types=3) +# assert str(lin) == 'HeteroLinear(-1, 32, num_types=3, bias=True)' + +# out = lin(x, type_vec) +# assert out.shape == (3, 32) + + +@pytest.mark.parametrize('bias', [True, False]) +def test_hetero_dict_linear(bias): + x_dict = { + 'v': ops.randn(3, 16), + 'w': ops.randn(2, 8), + } + + lin = HeteroDictLinear({'v': 16, 'w': 8}, 32, has_bias=bias) + assert str(lin) == (f"HeteroDictLinear({{'v': 16, 'w': 8}}, 32, " + f"bias={bias})") + + out_dict = lin(x_dict) + assert len(out_dict) == 2 + assert out_dict['v'].shape == (3, 32) + assert out_dict['w'].shape == (2, 32) + + x_dict = { + 'v': ops.randn(3, 16), + 'w': ops.randn(2, 16), + } + + lin = HeteroDictLinear(16, 32, types=['v', 'w'], has_bias=bias) + assert str(lin) == (f"HeteroDictLinear({{'v': 16, 'w': 16}}, 32, " + f"bias={bias})") + + out_dict = lin(x_dict) + assert len(out_dict) == 2 + assert out_dict['v'].shape == (3, 32) + assert out_dict['w'].shape == (2, 32) + + if is_full_test(): + x_dict = { + 'v': ops.randn(3, 16), + 'w': ops.randn(2, 8), + } + + lin = HeteroDictLinear({'v': 16, 'w': 8}, 32) + + jit = ms.jit(lin) + assert len(jit(x_dict)) == 2 + + +# def test_lazy_hetero_dict_linear(device): +# x_dict = { +# 'v': ops.randn(3, 16), +# 'w': ops.randn(2, 8), +# } + +# lin = HeteroDictLinear(-1, 32, types=['v', 'w']) +# assert str(lin) == "HeteroDictLinear({'v': -1, 'w': -1}, 32, bias=True)" + +# out_dict = lin(x_dict) +# assert len(out_dict) == 2 +# assert out_dict['v'].shape == (3, 32) +# assert out_dict['w'].shape == (2, 32) + + +@pytest.mark.parametrize('type_vec', [ + ms.Tensor([0, 0, 1, 1, 2, 2]), + ms.Tensor([0, 1, 2, 0, 1, 2]), +]) +def test_hetero_linear_sort(type_vec): + x = ops.randn(type_vec.numel(), 16) + + lin = HeteroLinear(16, 32, num_types=3) + out = lin(x, type_vec) + + for i in range(type_vec.numel()): + node_type = int(type_vec[i]) + expected = x[i] @ lin.weight[node_type] + lin.bias[node_type] + assert ops.isclose(out[i], expected, atol=1e-3).all() diff --git a/tests/graph/nn/dense/test_mincut_pool.py b/tests/graph/nn/dense/test_mincut_pool.py new file mode 100644 index 000000000..bd2b2f454 --- /dev/null +++ b/tests/graph/nn/dense/test_mincut_pool.py @@ -0,0 +1,30 @@ +import math + +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import dense_mincut_pool +from mindscience.sharker.testing import is_full_test + + +def test_dense_mincut_pool(): + batch_size, num_nodes, channels, num_clusters = (2, 20, 16, 10) + x = ops.randn((batch_size, num_nodes, channels)) + adj = ops.ones((batch_size, num_nodes, num_nodes)) + s = ops.randn((batch_size, num_nodes, num_clusters)) + mask = ops.randint(0, 2, (batch_size, num_nodes)).bool() + + x_out, adj_out, mincut_loss, ortho_loss = dense_mincut_pool( + x, adj, s, mask) + assert x_out.shape == (2, 10, 16) + assert adj_out.shape == (2, 10, 10) + assert -1 <= mincut_loss <= 0 + assert 0 <= ortho_loss <= 2 + + if is_full_test(): + jit = ms.jit(dense_mincut_pool) + + x_jit, adj_jit, mincut_loss, ortho_loss = jit(x, adj, s, mask) + assert x_jit.shape == (2, 10, 16) + assert adj_jit.shape == (2, 10, 10) + assert -1 <= mincut_loss <= 0 + assert 0 <= ortho_loss <= math.sqrt(2) diff --git a/tests/graph/nn/kge/test_complex.py b/tests/graph/nn/kge/test_complex.py new file mode 100644 index 000000000..b7355fcd0 --- /dev/null +++ b/tests/graph/nn/kge/test_complex.py @@ -0,0 +1,58 @@ +import mindspore as ms +from mindscience.sharker.nn import ComplEx + + +def test_complex_scoring(): + model = ComplEx(num_nodes=5, num_relations=2, hidden_channels=1) + + model.node_emb.embedding_table[:] = ms.Tensor([ + [2.], + [3.], + [5.], + [1.], + [2.], + ]) + model.node_emb_im.embedding_table[:] = ms.Tensor([ + [4.], + [1.], + [3.], + [1.], + [2.], + ]) + model.rel_emb.embedding_table[:] = ms.Tensor([ + [2.], + [3.], + ]) + model.rel_emb_im.embedding_table[:] = ms.Tensor([ + [3.], + [1.], + ]) + + score = model( + head_index=ms.Tensor([1, 3]), + rel_type=ms.Tensor([1, 0]), + tail_index=ms.Tensor([2, 4]), + ) + assert score.tolist() == [58., 8.] + + +def test_complex(): + model = ComplEx(num_nodes=10, num_relations=5, hidden_channels=32) + assert str(model) == 'ComplEx(10, num_relations=5, hidden_channels=32)' + + head_index = ms.Tensor([0, 2, 4, 6, 8]) + rel_type = ms.Tensor([0, 1, 2, 3, 4]) + tail_index = ms.Tensor([1, 3, 5, 7, 9]) + + loader = model.loader(head_index, rel_type, tail_index, batch_size=5) + for h, r, t in loader: + out = model(h, r, t) + assert out.shape == (5, ) + + loss = model.loss(h, r, t) + assert loss >= 0. + + mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) + assert 0 <= mean_rank <= 10 + assert 0 < mrr <= 1 + assert hits == 1.0 diff --git a/tests/graph/nn/kge/test_distmult.py b/tests/graph/nn/kge/test_distmult.py new file mode 100644 index 000000000..3de72dacf --- /dev/null +++ b/tests/graph/nn/kge/test_distmult.py @@ -0,0 +1,24 @@ +import mindspore as ms +from mindscience.sharker.nn import DistMult + + +def test_distmult(): + model = DistMult(num_nodes=10, num_relations=5, hidden_channels=32) + assert str(model) == 'DistMult(10, num_relations=5, hidden_channels=32)' + + head_index = ms.Tensor([0, 2, 4, 6, 8]) + rel_type = ms.Tensor([0, 1, 2, 3, 4]) + tail_index = ms.Tensor([1, 3, 5, 7, 9]) + + loader = model.loader(head_index, rel_type, tail_index, batch_size=5) + for h, r, t in loader: + out = model(h, r, t) + assert out.shape == (5, ) + + loss = model.loss(h, r, t) + assert loss >= 0. + + mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) + assert 0 <= mean_rank <= 10 + assert 0 < mrr <= 1 + assert hits == 1.0 diff --git a/tests/graph/nn/kge/test_rotate.py b/tests/graph/nn/kge/test_rotate.py new file mode 100644 index 000000000..a493e29c5 --- /dev/null +++ b/tests/graph/nn/kge/test_rotate.py @@ -0,0 +1,24 @@ +import mindspore as ms +from mindscience.sharker.nn import RotatE + + +def test_rotate(): + model = RotatE(num_nodes=10, num_relations=5, hidden_channels=32) + assert str(model) == 'RotatE(10, num_relations=5, hidden_channels=32)' + + head_index = ms.Tensor([0, 2, 4, 6, 8]) + rel_type = ms.Tensor([0, 1, 2, 3, 4]) + tail_index = ms.Tensor([1, 3, 5, 7, 9]) + + loader = model.loader(head_index, rel_type, tail_index, batch_size=5) + for h, r, t in loader: + out = model(h, r, t) + assert out.shape == (5, ) + + loss = model.loss(h, r, t) + assert loss >= 0. + + mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) + assert 0 <= mean_rank <= 10 + assert 0 < mrr <= 1 + assert hits == 1.0 diff --git a/tests/graph/nn/kge/test_transe.py b/tests/graph/nn/kge/test_transe.py new file mode 100644 index 000000000..4e2a29542 --- /dev/null +++ b/tests/graph/nn/kge/test_transe.py @@ -0,0 +1,24 @@ +import mindspore as ms +from mindscience.sharker.nn import TransE + + +def test_transe(): + model = TransE(num_nodes=10, num_relations=5, hidden_channels=32) + assert str(model) == 'TransE(10, num_relations=5, hidden_channels=32)' + + head_index = ms.Tensor([0, 2, 4, 6, 8]) + rel_type = ms.Tensor([0, 1, 2, 3, 4]) + tail_index = ms.Tensor([1, 3, 5, 7, 9]) + + loader = model.loader(head_index, rel_type, tail_index, batch_size=5) + for h, r, t in loader: + out = model(h, r, t) + assert out.shape == (5, ) + + loss = model.loss(h, r, t) + assert loss >= 0. + + mean_rank, mrr, hits = model.test(h, r, t, batch_size=5, log=False) + assert 0 <= mean_rank <= 10 + assert 0 < mrr <= 1 + assert hits == 1.0 diff --git a/tests/graph/nn/models/test_attentive_fp.py b/tests/graph/nn/models/test_attentive_fp.py new file mode 100644 index 000000000..f0b779f9a --- /dev/null +++ b/tests/graph/nn/models/test_attentive_fp.py @@ -0,0 +1,23 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import AttentiveFP +from mindscience.sharker.testing import is_full_test + + +def test_attentive_fp(): + model = AttentiveFP(8, 16, 32, edge_dim=3, num_layers=2, num_timesteps=2) + assert str(model) == ('AttentiveFP(in_channels=8, hidden_channels=16, ' + 'out_channels=32, edge_dim=3, num_layers=2, ' + 'num_timesteps=2)') + + x = ops.randn(4, 8) + edge_index = ms.Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + edge_attr = ops.randn(edge_index.shape[1], 3) + batch = ms.Tensor([0, 0, 0, 0]) + + out = model(x, edge_index, edge_attr, batch) + assert out.shape == (1, 32) + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(x, edge_index, edge_attr, batch), out).all() diff --git a/tests/graph/nn/models/test_autoencoder.py b/tests/graph/nn/models/test_autoencoder.py new file mode 100644 index 000000000..677b03789 --- /dev/null +++ b/tests/graph/nn/models/test_autoencoder.py @@ -0,0 +1,107 @@ +import mindspore as ms +from mindspore import ops, nn + +from mindscience.sharker.data import Graph +from mindscience.sharker.nn import ARGA, ARGVA, GAE, VGAE +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.transforms import RandomLinkSplit + + +def test_gae(): + model = GAE(encoder=lambda x: x) + model.reset_parameters() + + x = ms.Tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]]) + z = model.encode(x) + assert ops.isclose(z, x).all() + + adj = model.decoder.forward_all(z) + expected = ms.Tensor([ + [2.0, -1.0, 1.0], + [-1.0, 5.0, 4.0], + [1.0, 4.0, 5.0], + ]).sigmoid() + assert ops.isclose(adj, expected).all() + + edge_index = ms.Tensor([[0, 1], [1, 2]]) + value = model.decode(z, edge_index) + assert ops.isclose(value, ms.Tensor([-1.0, 4.0]).sigmoid()).all() + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit.encode(x), z).all() + assert ops.isclose(jit.decode(z, edge_index), value).all() + + edge_index = ms.Tensor([[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], + [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]]) + data = Graph(edge_index=edge_index, num_nodes=11) + transform = RandomLinkSplit(split_labels=True, + add_negative_train_samples=False) + train_data, val_data, test_data = transform(data) + + z = ops.randn(11, 16) + loss = model.recon_loss(z, train_data.pos_edge_label_index) + assert float(loss) > 0 + + auc, ap = model.test(z, val_data.pos_edge_label_index, + val_data.neg_edge_label_index) + assert auc >= 0 and auc <= 1 and ap >= 0 and ap <= 1 + + +def test_vgae(): + model = VGAE(encoder=lambda x: (x, x)) + + x = ms.Tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]]) + model.encode(x) + assert float(model.kl_loss()) > 0 + + model.eval() + model.encode(x) + + if is_full_test(): + jit = ms.jit(model) + jit.encode(x) + assert float(jit.kl_loss()) > 0 + + +def test_arga(): + model = ARGA(encoder=lambda x: x, discriminator=lambda x: T([0.5])) + model.reset_parameters() + + x = ms.Tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]]) + z = model.encode(x) + + assert float(model.reg_loss(z)) > 0 + assert float(model.discriminator_loss(z)) > 0 + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit.encode(x), z).all() + assert float(jit.reg_loss(z)) > 0 + assert float(jit.discriminator_loss(z)) > 0 + + +def test_argva(): + model = ARGVA(encoder=lambda x: (x, x), discriminator=lambda x: T([0.5])) + + x = ms.Tensor([[1.0, -1.0], [1.0, 2.0], [2.0, 1.0]]) + model.encode(x) + model.reparametrize(model.__mu__, model.__logstd__) + assert float(model.kl_loss()) > 0 + + if is_full_test(): + jit = ms.jit(model) + jit.encode(x) + jit.reparametrize(jit.__mu__, jit.__logstd__) + assert float(jit.kl_loss()) > 0 + + +def test_init(): + encoder = nn.Dense(16, 32) + decoder = nn.Dense(32, 16) + discriminator = nn.Dense(32, 1) + + GAE(encoder, decoder) + VGAE(encoder, decoder) + ARGA(encoder, discriminator, decoder) + ARGVA(encoder, discriminator, decoder) diff --git a/tests/graph/nn/models/test_basic_gnn.py b/tests/graph/nn/models/test_basic_gnn.py new file mode 100644 index 000000000..b817f67a5 --- /dev/null +++ b/tests/graph/nn/models/test_basic_gnn.py @@ -0,0 +1,409 @@ +import os +import os.path as osp +import random +import sys +import warnings + +import pytest +import mindspore as ms +from mindspore import ops, nn +from mindscience.sharker import typing +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import NeighborLoader +from mindscience.sharker.nn import SAGEConv +from mindscience.sharker.nn.models import GAT, GCN, GIN, PNA, EdgeCNN, GraphSAGE +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import ( + onlyFullTest, + onlyLinux, + # onlyNeighborSampler, + onlyOnline, +) + +out_dims = [None, 8] +dropouts = [0.0, 0.5] +acts = [None, 'leaky_relu', ops.elu, nn.ReLU()] +norms = [None, 'batch_norm', 'layer_norm'] +jks = [None, 'last', 'cat', 'max', 'lstm'] + + +@pytest.mark.parametrize('out_dim', out_dims) +@pytest.mark.parametrize('dropout', dropouts) +@pytest.mark.parametrize('act', acts) +@pytest.mark.parametrize('norm', norms) +@pytest.mark.parametrize('jk', jks) +def test_gcn(out_dim, dropout, act, norm, jk): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out_channels = 16 if out_dim is None else out_dim + + model = GCN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, + act=act, norm=norm, jk=jk) + assert str(model) == f'GCN(8, {out_channels}, num_layers=3)' + assert model(x, edge_index).shape == (3, out_channels) + + +@pytest.mark.parametrize('out_dim', out_dims) +@pytest.mark.parametrize('dropout', dropouts) +@pytest.mark.parametrize('act', acts) +@pytest.mark.parametrize('norm', norms) +@pytest.mark.parametrize('jk', jks) +def test_graph_sage(out_dim, dropout, act, norm, jk): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out_channels = 16 if out_dim is None else out_dim + + model = GraphSAGE(8, 16, num_layers=3, out_channels=out_dim, + dropout=dropout, act=act, norm=norm, jk=jk) + assert str(model) == f'GraphSAGE(8, {out_channels}, num_layers=3)' + assert model(x, edge_index).shape == (3, out_channels) + + +@pytest.mark.parametrize('out_dim', out_dims) +@pytest.mark.parametrize('dropout', dropouts) +@pytest.mark.parametrize('act', acts) +@pytest.mark.parametrize('norm', norms) +@pytest.mark.parametrize('jk', jks) +def test_gin(out_dim, dropout, act, norm, jk): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out_channels = 16 if out_dim is None else out_dim + + model = GIN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, + act=act, norm=norm, jk=jk) + assert str(model) == f'GIN(8, {out_channels}, num_layers=3)' + assert model(x, edge_index).shape == (3, out_channels) + + +@pytest.mark.parametrize('out_dim', out_dims) +@pytest.mark.parametrize('dropout', dropouts) +@pytest.mark.parametrize('act', acts) +@pytest.mark.parametrize('norm', norms) +@pytest.mark.parametrize('jk', jks) +def test_gat(out_dim, dropout, act, norm, jk): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out_channels = 16 if out_dim is None else out_dim + + for v2 in [False, True]: + model = GAT(8, 16, num_layers=3, out_channels=out_dim, v2=v2, + dropout=dropout, act=act, norm=norm, jk=jk) + assert str(model) == f'GAT(8, {out_channels}, num_layers=3)' + assert model(x, edge_index).shape == (3, out_channels) + + model = GAT(8, 16, num_layers=3, out_channels=out_dim, v2=v2, + dropout=dropout, act=act, norm=norm, jk=jk, heads=4) + assert str(model) == f'GAT(8, {out_channels}, num_layers=3)' + assert model(x, edge_index).shape == (3, out_channels) + + +@pytest.mark.parametrize('out_dim', out_dims) +@pytest.mark.parametrize('dropout', dropouts) +@pytest.mark.parametrize('act', acts) +@pytest.mark.parametrize('norm', norms) +@pytest.mark.parametrize('jk', jks) +def test_pna(out_dim, dropout, act, norm, jk): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + deg = ms.Tensor([0, 2, 1]) + out_channels = 16 if out_dim is None else out_dim + aggregators = ['mean', 'min', 'max', 'std', 'var', 'sum'] + scalers = [ + 'identity', 'amplification', 'attenuation', 'linear', 'inverse_linear' + ] + + model = PNA(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, + act=act, norm=norm, jk=jk, aggregators=aggregators, + scalers=scalers, deg=deg) + assert str(model) == f'PNA(8, {out_channels}, num_layers=3)' + assert model(x, edge_index).shape == (3, out_channels) + + +@pytest.mark.parametrize('out_dim', out_dims) +@pytest.mark.parametrize('dropout', dropouts) +@pytest.mark.parametrize('act', acts) +@pytest.mark.parametrize('norm', norms) +@pytest.mark.parametrize('jk', jks) +def test_edge_cnn(out_dim, dropout, act, norm, jk): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out_channels = 16 if out_dim is None else out_dim + + model = EdgeCNN(8, 16, num_layers=3, out_channels=out_dim, dropout=dropout, + act=act, norm=norm, jk=jk) + assert str(model) == f'EdgeCNN(8, {out_channels}, num_layers=3)' + assert model(x, edge_index).shape == (3, out_channels) + + +def test_jit(): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = GCN(8, 16, num_layers=2) + model = ms.jit(model) + + assert model(x, edge_index).shape == (3, 16) + + +@pytest.mark.parametrize('out_dim', out_dims) +@pytest.mark.parametrize('jk', jks) +def test_one_layer_gnn(out_dim, jk): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + out_channels = 16 if out_dim is None else out_dim + + model = GraphSAGE(8, 16, num_layers=1, out_channels=out_dim, jk=jk) + assert model(x, edge_index).shape == (3, out_channels) + + +@pytest.mark.parametrize('norm', [ + 'BatchNorm', + 'GraphNorm', + 'InstanceNorm', + 'LayerNorm', +]) +def test_batch(norm): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + batch = ms.Tensor([0, 0, 1]) + + model = GraphSAGE(8, 16, num_layers=2, norm=norm) + assert model.supports_norm_batch == (norm != 'BatchNorm') + + out = model(x, edge_index, batch=batch) + assert out.shape == (3, 16) + + if model.supports_norm_batch: + with pytest.raises(RuntimeError, match="out of bounds"): + model(x, edge_index, batch=batch, batch_size=1) + + +@onlyOnline +# @onlyNeighborSampler +@pytest.mark.parametrize('jk', [None, 'last']) +def test_basic_gnn_inference(get_dataset, jk): + dataset = get_dataset(name='Cora') + data = dataset[0] + + model = GraphSAGE(dataset.num_features, hidden_channels=16, num_layers=2, + out_channels=dataset.num_classes, jk=jk) + model.eval() + + out1 = model(data.x, data.edge_index) + assert out1.shape == (data.num_nodes, dataset.num_classes) + + loader = NeighborLoader(data, num_neighbors=[-1], batch_size=128) + out2 = model.inference(loader) + assert out1.shape == out2.shape + assert ops.isclose(out1, out2, atol=1e-4).all() + + assert 'n_id' not in data + + +@onlyFullTest +def test_compile_basic(): + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = GCN(8, 16, num_layers=3) + compiled_model = torch.compile(model) + + expected = model(x, edge_index) + out = compiled_model(x, edge_index) + assert ops.isclose(out, expected, atol=1e-6).all() + + +def test_packaging(): + warnings.filterwarnings('ignore', '.*TypedStorage is deprecated.*') + + os.makedirs(torch.hub._get_torch_home(), exist_ok=True) + + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = GraphSAGE(8, 16, num_layers=3) + path = osp.join(torch.hub._get_torch_home(), 'pyg_test_model.pt') + torch.save(model, path) + + model = torch.load(path) + with torch.no_grad(): + assert model(x, edge_index).shape == (3, 16) + + model = GraphSAGE(8, 16, num_layers=3) + path = osp.join(torch.hub._get_torch_home(), 'pyg_test_package.pt') + with torch.package.PackageExporter(path) as pe: + pe.extern('sharker.nn.**') + pe.extern('sharker.inspector') + pe.extern('sharker.utils._trim_to_layer') + pe.extern('_operator') + pe.save_pickle('models', 'model.pkl', model) + + pi = torch.package.PackageImporter(path) + model = pi.load_pickle('models', 'model.pkl') + with torch.no_grad(): + assert model(x, edge_index).shape == (3, 16) + + +@withPackage('onnx', 'onnxruntime') +def test_onnx(tmp_path): + import onnx + import onnxruntime as ort + + warnings.filterwarnings('ignore', '.*tensor to a Python boolean.*') + warnings.filterwarnings('ignore', '.*shape inference of prim::Constant.*') + + class MyModel(nn.Cell): + def __init__(self): + super().__init__() + self.conv1 = SAGEConv(8, 16) + self.conv2 = SAGEConv(16, 16) + + def construct(self, x, edge_index): + x = ops.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + return x + + model = MyModel() + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 2], [1, 0, 2]]) + expected = model(x, edge_index) + assert expected.shape == (3, 16) + + path = osp.join(tmp_path, 'model.onnx') + torch.onnx.export(model, (x, edge_index), path, + input_names=('x', 'edge_index'), opset_version=16) + + model = onnx.load(path) + onnx.checker.check_model(model) + + providers = ['CPUExecutionProvider'] + ort_session = ort.InferenceSession(path, providers=providers) + + out = ort_session.run(None, { + 'x': x.numpy(), + 'edge_index': edge_index.numpy() + })[0] + out = ms.Tensor.from_numpy(out) + assert ops.isclose(out, expected, atol=1e-6).all() + + +def test_trim_to_layer(): + x = ops.randn(14, 16) + edge_index = ms.Tensor([ + [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], + [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], + ]) + data = Graph(x=x, edge_index=edge_index) + + loader = NeighborLoader( + data, + num_neighbors=[1, 2, 4], + batch_size=2, + shuffle=False, + ) + batch = next(iter(loader)) + + model = GraphSAGE(in_channels=16, hidden_channels=16, num_layers=3) + out1 = model(batch.x, batch.edge_index)[:2] + assert out1.shape == (2, 16) + + out2 = model( + batch.x, + batch.edge_index, + num_sampled_nodes_per_hop=batch.num_sampled_nodes, + num_sampled_edges_per_hop=batch.num_sampled_edges, + )[:2] + assert out2.shape == (2, 16) + + assert ops.isclose(out1, out2).all() + + +@pytest.mark.parametrize('Model', [GCN, GraphSAGE, GIN, GAT, EdgeCNN, PNA]) +def test_compile_graph_breaks(Model, device): + import torch._dynamo as dynamo + + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + kwargs = {} + if Model in {GCN, GAT}: + # Adding self-loops inside the model leads to graph breaks :( + kwargs['add_self_loops'] = False + + if Model in {PNA}: # `PNA` requires additional arguments: + kwargs['aggregators'] = ['sum', 'mean', 'min', 'max', 'var', 'std'] + kwargs['scalers'] = ['identity', 'amplification', 'attenuation'] + kwargs['deg'] = ms.Tensor([1, 2, 1]) + + model = Model( + in_channels=8, + hidden_channels=16, + num_layers=2, + **kwargs, + ) + + explanation = dynamo.explain(model)(x, edge_index) + assert explanation.graph_break_count == 0 + + +def test_basic_gnn_cache(): + x = ops.randn(14, 16) + edge_index = ms.Tensor([ + [2, 3, 4, 5, 7, 7, 10, 11, 12, 13], + [0, 1, 2, 3, 2, 3, 7, 7, 7, 7], + ]) + + loader = NeighborLoader( + Graph(x=x, edge_index=edge_index), + num_neighbors=[-1], + batch_size=2, + ) + + model = GCN(in_channels=16, hidden_channels=16, num_layers=2) + model.eval() + + out1 = model.inference(loader, cache=False) + out2 = model.inference(loader, cache=True) + + assert ops.isclose(out1, out2).all() + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--backward', action='store_true') + parser.add_argument('--dynamic', action='store_true') + args = parser.parse_args() + + if args.dynamic: + min_num_nodes, max_num_nodes = 10_000, 15_000 + min_num_edges, max_num_edges = 200_000, 300_000 + else: + min_num_nodes, max_num_nodes = 10_000, 10_000 + min_num_edges, max_num_edges = 200_000, 200_000 + + def gen_args(): + N = random.randint(min_num_nodes, max_num_nodes) + E = random.randint(min_num_edges, max_num_edges) + + x = ops.randn(N, 64) + edge_index = ops.randint(0, N, (2, E)) + + return x, edge_index + + for Model in [GCN, GraphSAGE, GIN, EdgeCNN]: + print(f'Model: {Model.__name__}') + + model = Model(64, 64, num_layers=3) + compiled_model = ms.jit(model) + + benchmark( + funcs=[model, compiled_model], + func_names=['Vanilla', 'Compiled'], + args=gen_args, + num_steps=50 if args.device == 'cpu' else 500, + num_warmups=10 if args.device == 'cpu' else 100, + backward=args.backward, + ) diff --git a/tests/graph/nn/models/test_correct_and_smooth.py b/tests/graph/nn/models/test_correct_and_smooth.py new file mode 100644 index 000000000..c98fdb339 --- /dev/null +++ b/tests/graph/nn/models/test_correct_and_smooth.py @@ -0,0 +1,53 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn.models import CorrectAndSmooth +from mindscience.sharker.testing import noWindows +from mindscience.sharker.typing import SparseTensor + + +@noWindows +def test_correct_and_smooth(): + y_soft = ms.Tensor([0.1, 0.5, 0.4]).tile((6, 1)) + y_true = ms.Tensor([1, 0, 0, 2, 1, 1]) + edge_index = ms.Tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]]) + mask = ops.randint(0, 2, (6, )).bool() + + model = CorrectAndSmooth( + num_correction_layers=2, + correction_alpha=0.5, + num_smoothing_layers=2, + smoothing_alpha=0.5, + ) + assert str(model) == ('CorrectAndSmooth(\n' + ' correct: num_layers=2, alpha=0.5\n' + ' smooth: num_layers=2, alpha=0.5\n' + ' autoscale=True, scale=1.0\n' + ')') + + out = model.correct(y_soft, y_true[mask], mask, edge_index) + assert out.shape == (6, 3) + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(6, 6)) + assert ops.isclose( + out, model.correct(y_soft, y_true[mask], mask, adj.t())).all() + + out = model.smooth(y_soft, y_true[mask], mask, edge_index) + assert out.shape == (6, 3) + if typing.WITH_SPARSE: + assert ops.isclose( + out, model.smooth(y_soft, y_true[mask], mask, adj.t())).all() + + # Test without autoscale: + model = CorrectAndSmooth( + num_correction_layers=2, + correction_alpha=0.5, + num_smoothing_layers=2, + smoothing_alpha=0.5, + autoscale=False, + ) + out = model.correct(y_soft, y_true[mask], mask, edge_index) + assert out.shape == (6, 3) + if typing.WITH_SPARSE: + assert ops.isclose( + out, model.correct(y_soft, y_true[mask], mask, adj.t())).all() diff --git a/tests/graph/nn/models/test_deep_graph_infomax.py b/tests/graph/nn/models/test_deep_graph_infomax.py new file mode 100644 index 000000000..0902ec084 --- /dev/null +++ b/tests/graph/nn/models/test_deep_graph_infomax.py @@ -0,0 +1,68 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import GCN, DeepGraphInfomax +from mindscience.sharker.testing import is_full_test + + +def test_infomax(): + model = DeepGraphInfomax( + hidden_channels=16, + encoder=lambda x: x, + summary=lambda z, *args: z.mean(0), + corruption=lambda x: x + 1, + ) + assert str(model) == 'DeepGraphInfomax(16)' + + x = ops.ones([20, 16]) + + pos_z, neg_z, summary = model(x) + assert pos_z.shape == (20, 16) + assert neg_z.shape == (20, 16) + assert summary.shape == (16, ) + + loss = model.loss(pos_z, neg_z, summary) + assert float(loss) >= 0 + + if is_full_test(): + jit = ms.jit(model) + pos_z, neg_z, summary = jit(x) + assert pos_z.shape == (20, 16) and neg_z.shape == (20, 16) + assert summary.shape == (16, ) + + acc = model.test( + train_z=ops.ones([20, 16]), + train_y=ops.randint(0, 10, (20, )), + test_z=ops.ones([20, 16]), + test_y=ops.randint(0, 10, (20, )), + ) + assert 0 <= acc <= 1 + + +def test_infomax_predefined_model(): + def corruption(x, edge_index, edge_weight): + return ( + x[ops.shuffle(ops.arange(x.shape[0]))], + edge_index, + edge_weight, + ) + + model = DeepGraphInfomax( + hidden_channels=16, + encoder=GCN(16, 16, num_layers=2), + summary=lambda z, *args, **kwargs: z.mean(0).sigmoid(), + corruption=corruption, + ) + + x = ops.randn(4, 16) + edge_index = ms.Tensor( + [[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]] + ) + edge_weight = ops.rand(edge_index.shape[1]) + + pos_z, neg_z, summary = model(x, edge_index, edge_weight=edge_weight) + assert pos_z.shape == (4, 16) + assert neg_z.shape == (4, 16) + assert summary.shape == (16, ) + + loss = model.loss(pos_z, neg_z, summary) + assert float(loss) >= 0 diff --git a/tests/graph/nn/models/test_deepgcn.py b/tests/graph/nn/models/test_deepgcn.py new file mode 100644 index 000000000..666212198 --- /dev/null +++ b/tests/graph/nn/models/test_deepgcn.py @@ -0,0 +1,25 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindspore.nn import ReLU + +from mindscience.sharker.nn import DeepGCNLayer, GENConv, LayerNorm + + +@pytest.mark.parametrize( + 'block_tuple', + [('res+', 1), ('res', 1), ('dense', 2), ('plain', 1)], +) +@pytest.mark.parametrize('ckpt_grad', [True, False]) +def test_deepgcn(block_tuple, ckpt_grad): + block, expansion = block_tuple + x = ops.randn(3, 8) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + conv = GENConv(8, 8) + norm = LayerNorm(8) + act = ReLU() + layer = DeepGCNLayer(conv, norm, act, block=block, ckpt_grad=ckpt_grad) + assert str(layer) == f'DeepGCNLayer(block={block})' + + out = layer(x, edge_index) + assert out.shape == (3, 8 * expansion) diff --git a/tests/graph/nn/models/test_dimenet.py b/tests/graph/nn/models/test_dimenet.py new file mode 100644 index 000000000..f1f0fe2e1 --- /dev/null +++ b/tests/graph/nn/models/test_dimenet.py @@ -0,0 +1,66 @@ +import pytest +import mindspore as ms +from mindspore import ops, nn + +from mindscience.sharker.nn import DimeNet, DimeNetPlusPlus +from mindscience.sharker.nn.models.dimenet import ( + BesselBasisLayer, + Envelope, + ResidualLayer, +) +from mindscience.sharker.testing import is_full_test, withPackage + + +def test_dimenet_modules(): + env = Envelope(exponent=5) + x = ops.randn(10, 3) + assert env(x).shape == (10, 3) # Isotonic layer. + + bbl = BesselBasisLayer(5) + x = ops.randn(10, 3) + assert bbl(x).shape == (10, 3, 5) # Non-isotonic layer. + + rl = ResidualLayer(128, ops.relu) + x = ops.randn(128, 128) + assert rl(x).shape == (128, 128) # Isotonic layer. + + +@withPackage('sympy') +@pytest.mark.parametrize('Model', [DimeNet, DimeNetPlusPlus]) +def test_dimenet(Model): + z = ops.randint(1, 10, (20, )) + pos = ops.randn(20, 3) + + if Model == DimeNet: + kwargs = dict(num_bilinear=3) + else: + kwargs = dict(out_emb_channels=3, int_emb_size=5, basis_emb_size=5) + + model = Model( + hidden_channels=5, + out_channels=1, + num_blocks=5, + num_spherical=5, + num_radial=5, + **kwargs, + ) + model.reset_parameters() + + out = model(z, pos) + assert out.shape == (1, ) + + jit = ms.jit(model) + assert ops.isclose(jit(z, pos), out).all() + + if is_full_test(): + optimizer = nn.Adam(model.trainable_params(), lr=0.1) + + min_loss = float('inf') + for _ in range(100): + optimizer.zero_grad() + out = model(z, pos) + loss = ops.l1_loss(out, ms.Tensor([1.0])) + loss.backward() + optimizer.step() + min_loss = min(float(loss), min_loss) + assert min_loss < 2 diff --git a/tests/graph/nn/models/test_gnnff.py b/tests/graph/nn/models/test_gnnff.py new file mode 100644 index 000000000..65d5b7819 --- /dev/null +++ b/tests/graph/nn/models/test_gnnff.py @@ -0,0 +1,23 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import GNNFF +from mindscience.sharker.testing import is_full_test + + +def test_gnnff(): + z = ops.randint(1, 10, (20, )) + pos = ops.randn(20, 3) + + model = GNNFF( + hidden_node_channels=5, + hidden_edge_channels=5, + num_layers=5, + ) + model.reset_parameters() + + out = model(z, pos) + assert out.shape == (20, 3) + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(z, pos), out).all() diff --git a/tests/graph/nn/models/test_graph_mixer.py b/tests/graph/nn/models/test_graph_mixer.py new file mode 100644 index 000000000..016ae9b77 --- /dev/null +++ b/tests/graph/nn/models/test_graph_mixer.py @@ -0,0 +1,74 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.models.graph_mixer import ( + LinkEncoder, + NodeEncoder, + get_latest_k_edge_attr, +) + + +def test_node_encoder(): + x = ops.arange(4).float().view(-1, 1) + edge_index = ms.Tensor([[1, 2, 0, 0, 1, 3], [0, 0, 1, 2, 2, 2]]) + edge_time = ms.Tensor([0, 1, 1, 1, 2, 3]) + seed_time = ms.Tensor([2, 2, 2, 2]) + + encoder = NodeEncoder(time_window=2) + encoder.reset_parameters() + assert str(encoder) == 'NodeEncoder(time_window=2)' + + out = encoder(x, edge_index, edge_time, seed_time) + # Node 0 aggregates information from node 2 (excluding node 1). + # Node 1 aggregates information from node 0. + # Node 2 aggregates information from node 0 and node 1 (exluding node 3). + # Node 3 aggregates no information. + expected = ms.Tensor([ + [0 + 2], + [1 + 0], + [2 + 0.5 * (0 + 1)], + [3], + ]) + assert ops.isclose(out, expected).all() + + +def test_link_encoder(): + num_nodes = 3 + num_edges = 6 + edge_attr = ops.rand((num_edges, 10)) + edge_index = ops.randint(low=0, high=num_nodes, size=(2, num_edges)) + edge_time = ops.rand(num_edges) + seed_time = ops.ones(num_nodes) + + encoder = LinkEncoder( + k=3, + in_channels=edge_attr.shape[1], + hidden_channels=7, + out_channels=11, + time_channels=13, + ) + encoder.reset_parameters() + assert str(encoder) == ('LinkEncoder(k=3, in_channels=10, ' + 'hidden_channels=7, out_channels=11, ' + 'time_channels=13, dropout=0.0)') + + out = encoder(edge_index, edge_attr, edge_time, seed_time) + assert out.shape == (num_nodes, 11) + + +def test_latest_k_edge_attr(): + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 0], [0, 1, 0, 1, 0, 1, 2]]) + edge_time = ms.Tensor([3, 1, 2, 3, 1, 2, 3]) + edge_attr = ms.Tensor([1, -1, 3, 4, -1, 6, 7]).view(-1, 1) + + k = 2 + out = get_latest_k_edge_attr(k, edge_index, edge_attr, edge_time, num_nodes=3) + expected = ms.Tensor([[[1], [3]], [[4], [6]], [[7], [0]]]) + assert out.shape == (3, 2, 1) + assert ops.equal(out, expected).all() + + k = 1 + out = get_latest_k_edge_attr(k, edge_index, edge_attr, edge_time, + num_nodes=3) + expected = ms.Tensor([[[1]], [[4]], [[7]]]) + assert out.shape == (3, 1, 1) + assert ops.equal(out, expected).all() diff --git a/tests/graph/nn/models/test_graph_unet.py b/tests/graph/nn/models/test_graph_unet.py new file mode 100644 index 000000000..672b35a03 --- /dev/null +++ b/tests/graph/nn/models/test_graph_unet.py @@ -0,0 +1,24 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import GraphUNet +from mindscience.sharker.testing import is_full_test + + +def test_graph_unet(): + model = GraphUNet(16, 32, 8, depth=2) + out = 'GraphUNet(16, 32, 8, depth=2, pool_ratios=[0.5, 0.5])' + # todo may cause empty tensor of edge_index, which should be solved with empty tensor + # model = GraphUNet(16, 32, 8, depth=3) + # out = 'GraphUNet(16, 32, 8, depth=3, pool_ratios=[0.5, 0.5, 0.5])' + assert str(model) == out + + x = ops.randn(3, 16) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + out = model(x, edge_index) + assert out.shape == (3, 8) + + if is_full_test(): + jit = ms.jit(model) + out = jit(x, edge_index) + assert out.shape == (3, 8) diff --git a/tests/graph/nn/models/test_jumping_knowledge.py b/tests/graph/nn/models/test_jumping_knowledge.py new file mode 100644 index 000000000..5631894d3 --- /dev/null +++ b/tests/graph/nn/models/test_jumping_knowledge.py @@ -0,0 +1,40 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import JumpingKnowledge +from mindscience.sharker.testing import is_full_test + + +def test_jumping_knowledge(): + num_nodes, channels, num_layers = 100, 17, 5 + xs = list([ops.randn(num_nodes, channels) for _ in range(num_layers)]) + + model = JumpingKnowledge('cat') + assert str(model) == 'JumpingKnowledge(cat)' + + out = model(xs) + assert out.shape == (num_nodes, channels * num_layers) + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(xs), out).all() + + model = JumpingKnowledge('max') + assert str(model) == 'JumpingKnowledge(max)' + + out = model(xs) + assert out.shape == (num_nodes, channels) + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(xs), out).all() + + model = JumpingKnowledge('lstm', channels, num_layers) + assert str(model) == (f'JumpingKnowledge(lstm, channels=' + f'{channels}, layers={num_layers})') + + out = model(xs) + assert out.shape == (num_nodes, channels) + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(xs), out).all() diff --git a/tests/graph/nn/models/test_label_prop.py b/tests/graph/nn/models/test_label_prop.py new file mode 100644 index 000000000..25ddba3ab --- /dev/null +++ b/tests/graph/nn/models/test_label_prop.py @@ -0,0 +1,31 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn.models import LabelPropagation +from mindscience.sharker.typing import SparseTensor + + +def test_label_prop(): + y = ms.Tensor([1, 0, 0, 2, 1, 1]) + edge_index = ms.Tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]]) + mask = ops.randint(0, 2, (6, )).bool() + + model = LabelPropagation(num_layers=2, alpha=0.5) + assert str(model) == 'LabelPropagation(num_layers=2, alpha=0.5)' + + # Test without mask: + out = model(y, edge_index) + assert out.shape == (6, 3) + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(6, 6)) + assert ops.isclose(model(y, adj.t()), out).all() + + # Test with mask: + out = model(y, edge_index, mask) + assert out.shape == (6, 3) + if typing.WITH_SPARSE: + assert ops.isclose(model(y, adj.t(), mask), out).all() + + # Test post step: + out = model(y, edge_index, mask, post_step=lambda y: ops.zeros_like(y)) + assert ops.sum(out) == 0. diff --git a/tests/graph/nn/models/test_lightgcn.py b/tests/graph/nn/models/test_lightgcn.py new file mode 100644 index 000000000..4e3c249cf --- /dev/null +++ b/tests/graph/nn/models/test_lightgcn.py @@ -0,0 +1,72 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.models import LightGCN + + +@pytest.mark.parametrize('embedding_dim', [32, 64]) +@pytest.mark.parametrize('with_edge_weight', [False, True]) +@pytest.mark.parametrize('lambda_reg', [0, 1e-4]) +@pytest.mark.parametrize('alpha', [0, .25, ms.Tensor([0.4, 0.3, 0.2])]) +def test_lightgcn_ranking(embedding_dim, with_edge_weight, lambda_reg, alpha): + num_nodes = 500 + num_edges = 400 + edge_index = ops.randint(0, num_nodes, (2, num_edges)) + edge_weight = ops.rand(num_edges) if with_edge_weight else None + edge_label_index = ops.randint(0, num_nodes, (2, 100)) + + model = LightGCN(num_nodes, embedding_dim, num_layers=2, alpha=alpha) + assert str(model) == f'LightGCN(500, {embedding_dim}, num_layers=2)' + + pred = model(edge_index, edge_label_index, edge_weight) + assert pred.shape == (100, ) + + loss = model.recommendation_loss( + pos_edge_rank=pred[:50], + neg_edge_rank=pred[50:], + node_id=ops.unique(edge_index)[0], + lambda_reg=lambda_reg, + ) + assert loss.dim() == 0 and loss > 0 + + out = model.recommend(edge_index, edge_weight, k=2) + assert out.shape == (500, 2) + assert out.min() >= 0 and out.max() < 500 + + src_index = ops.arange(0, 250) + dst_index = ops.arange(250, 500) + + out = model.recommend(edge_index, edge_weight, src_index, dst_index, k=2) + assert out.shape == (250, 2) + assert out.min() >= 250 and out.max() < 500 + + +@pytest.mark.parametrize('embedding_dim', [32, 64]) +@pytest.mark.parametrize('with_edge_weight', [False, True]) +@pytest.mark.parametrize('alpha', [0, .25, ms.Tensor([0.4, 0.3, 0.2])]) +def test_lightgcn_link_prediction(embedding_dim, with_edge_weight, alpha): + num_nodes = 500 + num_edges = 400 + edge_index = ops.randint(0, num_nodes, (2, num_edges)) + edge_weight = ops.rand(num_edges) if with_edge_weight else None + edge_label_index = ops.randint(0, num_nodes, (2, 100)) + edge_label = ops.randint(0, 2, (edge_label_index.shape[1], )) + + model = LightGCN(num_nodes, embedding_dim, num_layers=2, alpha=alpha) + assert str(model) == f'LightGCN(500, {embedding_dim}, num_layers=2)' + + pred = model(edge_index, edge_label_index, edge_weight) + assert pred.shape == (100, ) + + loss = model.link_pred_loss(pred, edge_label) + assert loss.dim() == 0 and loss > 0 + + prob = model.predict_link(edge_index, edge_label_index, edge_weight, + prob=True) + assert prob.shape == (100, ) + assert prob.min() > 0 and prob.max() < 1 + + prob = model.predict_link(edge_index, edge_label_index, edge_weight, + prob=False) + assert prob.shape == (100, ) + assert ((prob == 0) | (prob == 1)).sum() == 100 diff --git a/tests/graph/nn/models/test_linkx.py b/tests/graph/nn/models/test_linkx.py new file mode 100644 index 000000000..0b57a222a --- /dev/null +++ b/tests/graph/nn/models/test_linkx.py @@ -0,0 +1,49 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import LINKX +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +@pytest.mark.parametrize('num_edge_layers', [1, 2]) +def test_linkx(num_edge_layers): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 2], [1, 2, 3]]) + edge_weight = ops.rand(edge_index.shape[1]) + + model = LINKX(num_nodes=4, in_channels=16, hidden_channels=32, + out_channels=8, num_layers=2, + num_edge_layers=num_edge_layers) + assert str(model) == 'LINKX(num_nodes=4, in_channels=16, out_channels=8)' + + out = model(x, edge_index) + assert out.shape == (4, 8) + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(4, 4)) + assert ops.isclose(out, model(x, adj.t()), atol=1e-6).all() + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(x, edge_index), out).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj.t()), out, atol=1e-6).all() + + out = model(None, edge_index) + assert out.shape == (4, 8) + if typing.WITH_SPARSE: + assert ops.isclose(out, model(None, adj.t()), atol=1e-6).all() + + out = model(x, edge_index, edge_weight) + assert out.shape == (4, 8) + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, edge_weight, + sparse_shape=(4, 4)) + assert ops.isclose(model(x, adj.t()), out, atol=1e-6).all() + + out = model(None, edge_index, edge_weight) + assert out.shape == (4, 8) + if typing.WITH_SPARSE: + assert ops.isclose(model(None, adj.t()), out, atol=1e-6).all() diff --git a/tests/graph/nn/models/test_mask_label.py b/tests/graph/nn/models/test_mask_label.py new file mode 100644 index 000000000..3154af8d2 --- /dev/null +++ b/tests/graph/nn/models/test_mask_label.py @@ -0,0 +1,27 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import MaskLabel + + +def test_mask_label(): + model = MaskLabel(2, 10) + assert str(model) == 'MaskLabel()' + + x = ops.rand(4, 10) + y = ms.Tensor([1, 0, 1, 0]) + mask = ms.Tensor([False, False, True, True]) + + out = model(x, y, mask) + assert out.shape == (4, 10) + assert ops.isclose(out[~mask], x[~mask]).all() + + model = MaskLabel(2, 10, method='concat') + out = model(x, y, mask) + assert out.shape == (4, 20) + assert ops.isclose(out[:, :10], x).all() + + +def test_ratio_mask(): + mask = ms.Tensor([True, True, True, True, False, False, False, False]) + out = MaskLabel.ratio_mask(mask, 0.5) + assert out[:4].sum() <= 4 and out[4:].sum() == 0 diff --git a/tests/graph/nn/models/test_meta.py b/tests/graph/nn/models/test_meta.py new file mode 100644 index 000000000..78db643ec --- /dev/null +++ b/tests/graph/nn/models/test_meta.py @@ -0,0 +1,127 @@ +from typing import Optional + +import mindspore as ms +from mindspore import Tensor, ops, nn +from mindspore.nn import Dense as Lin +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Seq + +from mindscience.sharker.nn import MetaLayer +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.sparse import scatter + +count = 0 + + +def test_meta_layer(): + assert str(MetaLayer()) == ('MetaLayer(\n' + ' edge_model=None,\n' + ' node_model=None,\n' + ' global_model=None\n' + ')') + + def dummy_model(*args): + global count + count += 1 + return None + + x = ops.randn(20, 10) + edge_index = ops.randint(0, high=10, size=(2, 20), dtype=ms.int64) + + for edge_model in (dummy_model, None): + for node_model in (dummy_model, None): + for global_model in (dummy_model, None): + model = MetaLayer(edge_model, node_model, global_model) + out = model(x, edge_index) + assert isinstance(out, tuple) and len(out) == 3 + + assert count == 12 + + +def test_meta_layer_example(): + class EdgeModel(nn.Cell): + def __init__(self): + super().__init__() + self.edge_mlp = Seq(Lin(2 * 10 + 5 + 20, 5), ReLU(), Lin(5, 5)) + + def construct( + self, + src: Tensor, + dst: Tensor, + edge_attr: Optional[Tensor], + u: Optional[Tensor], + batch: Optional[Tensor], + ) -> Tensor: + assert edge_attr is not None + assert u is not None + assert batch is not None + out = ops.cat([src, dst, edge_attr, u[batch]], 1) + return self.edge_mlp(out) + + class NodeModel(nn.Cell): + def __init__(self): + super().__init__() + self.node_mlp_1 = Seq(Lin(15, 10), ReLU(), Lin(10, 10)) + self.node_mlp_2 = Seq(Lin(2 * 10 + 20, 10), ReLU(), Lin(10, 10)) + + def construct( + self, + x: Tensor, + edge_index: Tensor, + edge_attr: Optional[Tensor], + u: Optional[Tensor], + batch: Optional[Tensor], + ) -> Tensor: + assert edge_attr is not None + assert u is not None + assert batch is not None + row = edge_index[0] + col = edge_index[1] + out = ops.cat(([x[row], edge_attr]), axis=1) + out = self.node_mlp_1(out) + out = scatter(out, col, dim=0, dim_size=x.shape[0], reduce='mean') + out = ops.cat(([x, out, u[batch]]), axis=1) + return self.node_mlp_2(out) + + class GlobalModel(nn.Cell): + def __init__(self): + super().__init__() + self.global_mlp = Seq(Lin(20 + 10, 20), ReLU(), Lin(20, 20)) + + def construct( + self, + x: Tensor, + edge_index: Tensor, + edge_attr: Optional[Tensor], + u: Optional[Tensor], + batch: Optional[Tensor], + ) -> Tensor: + assert u is not None + assert batch is not None + out = ops.cat([ + u, + scatter(x, batch, dim=0, reduce='mean'), + ], axis=1) + return self.global_mlp(out) + + op = MetaLayer(EdgeModel(), NodeModel(), GlobalModel()) + + x = ops.randn(20, 10) + edge_attr = ops.randn(40, 5) + u = ops.randn(2, 20) + batch = ms.Tensor([0] * 10 + [1] * 10) + edge_index = ops.randint(0, high=10, size=(2, 20), dtype=ms.int64) + edge_index = ops.cat(([edge_index, 10 + edge_index]), axis=1) + + x_out, edge_attr_out, u_out = op(x, edge_index, edge_attr, u, batch) + assert x_out.shape == (20, 10) + assert edge_attr_out.shape == (40, 5) + assert u_out.shape == (2, 20) + + if is_full_test(): + jit = ms.jit(op) + + x_out, edge_attr_out, u_out = jit(x, edge_index, edge_attr, u, batch) + assert x_out.shape == (20, 10) + assert edge_attr_out.shape == (40, 5) + assert u_out.shape == (2, 20) diff --git a/tests/graph/nn/models/test_metapath2vec.py b/tests/graph/nn/models/test_metapath2vec.py new file mode 100644 index 000000000..2fd83cb16 --- /dev/null +++ b/tests/graph/nn/models/test_metapath2vec.py @@ -0,0 +1,61 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import MetaPath2Vec + + +def test_metapath2vec(): + edge_index_dict = { + ('author', 'writes', 'paper'): + ms.Tensor([[0, 1, 1, 2], [0, 0, 1, 1]]), + ('paper', 'written_by', 'author'): + ms.Tensor([[0, 0, 1, 1], [0, 1, 1, 2]]) + } + + metapath = [ + ('author', 'writes', 'paper'), + ('paper', 'written_by', 'author'), + ] + + model = MetaPath2Vec(edge_index_dict, embedding_dim=16, metapath=metapath, + walk_length=2, context_size=2) + assert str(model) == 'MetaPath2Vec(5, 16)' + + z = model('author') + assert z.shape == (3, 16) + + z = model('paper') + assert z.shape == (2, 16) + + z = model('author', ops.arange(2)) + assert z.shape == (2, 16) + + pos_rw, neg_rw = model._sample(ops.arange(3)) + + loss = model.loss(pos_rw, neg_rw) + assert 0 <= loss.item() + + acc = model.test(ops.ones([20, 16]), ops.randint(0, 10, (20, )), + ops.ones([20, 16]), ops.randint(0, 10, (20, ))) + assert 0 <= acc and acc <= 1 + + +def test_metapath2vec_empty_edges(): + num_nodes_dict = {'a': 3, 'b': 4} + edge_index_dict = { + ('a', 'to', 'b'): ms.numpy.empty((2, 0), dtype=ms.int64), + ('b', 'to', 'a'): ms.numpy.empty((2, 0), dtype=ms.int64), + } + metapath = [('a', 'to', 'b'), ('b', 'to', 'a')] + + model = MetaPath2Vec( + edge_index_dict, + embedding_dim=16, + metapath=metapath, + walk_length=10, + context_size=7, + walks_per_node=5, + num_negative_samples=5, + num_nodes_dict=num_nodes_dict, + ) + loader = model.loader(batch_size=16, shuffle=True) + next(iter(loader)) diff --git a/tests/graph/nn/models/test_mlp.py b/tests/graph/nn/models/test_mlp.py new file mode 100644 index 000000000..75a7eccd3 --- /dev/null +++ b/tests/graph/nn/models/test_mlp.py @@ -0,0 +1,90 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import MLP +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('norm', ['batch_norm', None]) +@pytest.mark.parametrize('act_first', [False, True]) +@pytest.mark.parametrize('plain_last', [False, True]) +def test_mlp(norm, act_first, plain_last): + x = ops.randn(4, 16) + + ms.set_seed(12345) + mlp = MLP( + [16, 32, 32, 64], + norm=norm, + act_first=act_first, + plain_last=plain_last, + ) + assert str(mlp) == 'MLP(16, 32, 32, 64)' + out = mlp(x) + assert out.shape == (4, 64) + if is_full_test(): + jit = ms.jit(mlp) + assert ops.isclose(jit(x), out).all() + + ms.set_seed(12345) + mlp = MLP( + 16, + hidden_channels=32, + out_channels=64, + num_layers=3, + norm=norm, + act_first=act_first, + plain_last=plain_last, + ) + assert ops.isclose(mlp(x), out).all() + + +@pytest.mark.parametrize('norm', [ + 'BatchNorm', + 'GraphNorm', + 'InstanceNorm', + 'LayerNorm', +]) +def test_batch(norm): + x = ops.randn(3, 8) + batch = ms.Tensor([0, 0, 1]) + + model = MLP( + 8, + hidden_channels=16, + out_channels=32, + num_layers=2, + norm=norm, + ) + assert model.supports_norm_batch == (norm != 'BatchNorm') + + out = model(x, batch=batch) + assert out.shape == (3, 32) + + if model.supports_norm_batch: + # with pytest.raises(RuntimeError, match="out of bounds"): + with pytest.raises(RuntimeError): + model(x, batch=batch, batch_size=1) + + +def test_mlp_return_emb(): + x = ops.randn(4, 16) + + mlp = MLP([16, 32, 1]) + + out, emb = mlp(x, return_emb=True) + assert out.shape == (4, 1) + assert emb.shape == (4, 32) + + out, emb = mlp(x, return_emb=False) + assert out.shape == (4, 1) + assert emb is None + + +@pytest.mark.parametrize('plain_last', [False, True]) +def test_fine_grained_mlp(plain_last): + mlp = MLP( + [16, 32, 32, 64], + dropout=[0.1, 0.2, 0.3], + bias=[False, True, False], + ) + assert mlp(ops.randn(4, 16)).shape == (4, 64) diff --git a/tests/graph/nn/models/test_neural_fingerprint.py b/tests/graph/nn/models/test_neural_fingerprint.py new file mode 100644 index 000000000..5dcc70c21 --- /dev/null +++ b/tests/graph/nn/models/test_neural_fingerprint.py @@ -0,0 +1,28 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import NeuralFingerprint +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +@pytest.mark.parametrize('batch', [None, ms.Tensor([0, 1, 1])]) +def test_neural_fingerprint(batch): + x = ops.randn(3, 7) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + model = NeuralFingerprint(7, 16, out_channels=5, num_layers=4) + assert str(model) == 'NeuralFingerprint(7, 5, num_layers=4)' + model.reset_parameters() + + out = model(x, edge_index, batch) + assert out.shape == (1, 5) if batch is None else (2, 5) + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(3, 3)) + assert ops.isclose(model(x, adj.t(), batch), out).all() + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(x, edge_index, batch), out).all() diff --git a/tests/graph/nn/models/test_node2vec.py b/tests/graph/nn/models/test_node2vec.py new file mode 100644 index 000000000..809dd31ab --- /dev/null +++ b/tests/graph/nn/models/test_node2vec.py @@ -0,0 +1,32 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import Node2Vec +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('p', [1.0]) +@pytest.mark.parametrize('q', [1.0, 0.5]) +def test_node2vec(p, q): + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + kwargs = dict(embedding_dim=16, walk_length=2, context_size=2, p=p, q=q) + + model = Node2Vec(edge_index, **kwargs) + assert str(model) == 'Node2Vec(3, 16)' + + assert model(ops.arange(3)).shape == (3, 16) + + pos_rw, neg_rw = model.sample(ops.arange(3)) + assert float(model.loss(pos_rw, neg_rw)) >= 0 + + acc = model.test(ops.ones(20, 16), ops.randint(0, 10, (20, )), + ops.ones(20, 16), ops.randint(0, 10, (20, ))) + assert 0 <= acc and acc <= 1 + + if is_full_test(): + jit = ms.jit(model) + + assert jit(ops.arange(3)).shape == (3, 16) + + pos_rw, neg_rw = jit.sample(ops.arange(3)) + assert float(jit.loss(pos_rw, neg_rw)) >= 0 diff --git a/tests/graph/nn/models/test_pmlp.py b/tests/graph/nn/models/test_pmlp.py new file mode 100644 index 000000000..2c2d8bc6b --- /dev/null +++ b/tests/graph/nn/models/test_pmlp.py @@ -0,0 +1,23 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.models import PMLP + + +def test_pmlp(): + x = ops.randn(4, 16) + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + pmlp = PMLP(in_channels=16, hidden_channels=32, out_channels=2, + num_layers=4) + assert str(pmlp) == 'PMLP(16, 2, num_layers=4)' + + pmlp.training = True + assert pmlp(x).shape == (4, 2) + + pmlp.training = False + assert pmlp(x, edge_index).shape == (4, 2) + + with pytest.raises(ValueError, match="'edge_index' needs to be present"): + pmlp.training = False + pmlp(x) diff --git a/tests/graph/nn/models/test_re_net.py b/tests/graph/nn/models/test_re_net.py new file mode 100644 index 000000000..1921c99e0 --- /dev/null +++ b/tests/graph/nn/models/test_re_net.py @@ -0,0 +1,65 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.datasets.icews import EventDataset +from mindscience.sharker.loader import DataLoader +from mindscience.sharker.nn import RENet +from mindscience.sharker.testing import is_full_test + + +class MyTestEventDataset(EventDataset): + def __init__(self, root, seq_len): + super().__init__(root, pre_transform=RENet.pre_transform(seq_len)) + self.load(self.processed_paths[0]) + + @property + def num_nodes(self): + return 16 + + @property + def num_rels(self): + return 8 + + @property + def processed_file_names(self): + return 'data.pt' + + def _download(self): + pass + + def process_events(self): + sub = ops.randint(0, self.num_nodes, (64, ), dtype=ms.int64) + rel = ops.randint(0, self.num_rels, (64, ), dtype=ms.int64) + obj = ops.randint(0, self.num_nodes, (64, ), dtype=ms.int64) + t = ops.arange(8, dtype=ms.int64).view(-1, 1).repeat(1, 8).view(-1) + return ops.stack(([sub, rel, obj, t]), axis=1) + + def process(self): + data_list = self._process_data_list() + self.save(data_list, self.processed_paths[0]) + + +def test_re_net(tmp_path): + dataset = MyTestEventDataset(tmp_path, seq_len=4) + loader = DataLoader(dataset, 2, follow_batch=['h_sub', 'h_obj']) + + model = RENet(dataset.num_nodes, dataset.num_rels, hidden_channels=16, + seq_len=4) + + if is_full_test(): + jit = ms.jit(model) + + logits = ops.randn(6, 6) + y = ms.Tensor([0, 1, 2, 3, 4, 5]) + + mrr, hits1, hits3, hits10 = model.test(logits, y) + assert 0.15 < mrr <= 1 + assert hits1 <= hits3 and hits3 <= hits10 and hits10 == 1 + + for data in loader: + log_prob_obj, log_prob_sub = model(data) + if is_full_test(): + log_prob_obj_jit, log_prob_sub_jit = jit(data) + assert ops.isclose(log_prob_obj_jit, log_prob_obj).all() + assert ops.isclose(log_prob_sub_jit, log_prob_sub).all() + model.test(log_prob_obj, data.obj) + model.test(log_prob_sub, data.sub) diff --git a/tests/graph/nn/models/test_rect.py b/tests/graph/nn/models/test_rect.py new file mode 100644 index 000000000..ba333e102 --- /dev/null +++ b/tests/graph/nn/models/test_rect.py @@ -0,0 +1,43 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import RECT_L +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor + + +def test_rect(): + x = ops.randn(6, 8) + y = ms.Tensor([1, 0, 0, 2, 1, 1]) + edge_index = ms.Tensor([[0, 1, 1, 2, 4, 5], [1, 0, 2, 1, 5, 4]]) + mask = ops.randint(0, 2, (6, )).bool() + + model = RECT_L(8, 16) + assert str(model) == 'RECT_L(8, 16)' + + out = model(x, edge_index) + assert out.shape == (6, 8) + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_shape=(6, 6)) + assert ops.isclose(out, model(x, adj.t()), atol=1e-6).all() + + # Test `embed`: + embed_out = model.embed(x, edge_index) + assert embed_out.shape == (6, 16) + if typing.WITH_SPARSE: + assert ops.isclose(embed_out, model.embed(x, adj.t()), atol=1e-6).all() + + # Test `get_semantic_labels`: + labels_out = model.get_semantic_labels(x, y, mask) + assert labels_out.shape == (int(mask.sum()), 8) + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(x, edge_index), out, atol=1e-6).all() + assert ops.isclose(embed_out, jit.embed(x, edge_index), atol=1e-6).all() + assert ops.isclose(labels_out, jit.get_semantic_labels(x, y, mask)).all() + + if typing.WITH_SPARSE: + assert ops.isclose(jit(x, adj.t()), out, atol=1e-6).all() + assert ops.isclose(embed_out, jit.embed(x, adj.t()), atol=1e-6).all() + assert ops.isclose(labels_out, jit.get_semantic_labels(x, y, mask)).all() diff --git a/tests/graph/nn/models/test_rev_gnn.py b/tests/graph/nn/models/test_rev_gnn.py new file mode 100644 index 000000000..99a9f924c --- /dev/null +++ b/tests/graph/nn/models/test_rev_gnn.py @@ -0,0 +1,106 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Linear +from mindscience.sharker import typing +from mindscience.sharker.nn import GraphConv, GroupAddRev, SAGEConv + + +@pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) +def test_revgnn_forward_inverse(num_groups): + x = ops.randn(4, 32) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + lin = Linear(32, 32) + conv = SAGEConv(32 // num_groups, 32 // num_groups) + conv = GroupAddRev(conv, num_groups=num_groups) + assert str(conv) == (f'GroupAddRev(SAGEConv({32 // num_groups}, ' + f'{32 // num_groups}, aggr=mean), ' + f'num_groups={num_groups})') + + h = lin(x) + h_o = h.copy() + + out = conv(h, edge_index) + if typing.WITH_PT20: + assert h.untyped_storage().shape == 0 + else: + assert h.storage().shape == 0 + + h_rev = conv.inverse(out, edge_index) + assert ops.isclose(h_o, h_rev, atol=0.001).all() + + +@pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) +def test_revgnn_grad(num_groups): + x = ops.randn(4, 32) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + lin = Linear(32, 32) + conv = SAGEConv(32 // num_groups, 32 // num_groups) + conv = GroupAddRev(conv, num_groups=num_groups) + + def fn(x, edge_index): + h = lin(x) + out = conv(h, edge_index) + target = out.mean() + return target + + grad_fn = ops.grad(fn) + grad_fn(x, edge_index) + + +@pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) +def test_revgnn_multi_grad(num_groups): + x = ops.randn(4, 32) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + lin = Linear(32, 32) + conv = SAGEConv(32 // num_groups, 32 // num_groups) + conv = GroupAddRev(conv, num_groups=num_groups, num_bwd_passes=4) + + h = lin(x) + out = conv(h, edge_index) + target = out.mean() + target.backward(retain_graph=True) + target.backward(retain_graph=True) + ops.grad(outputs=target, inputs=[h] + list(conv.trainable_params()), + retain_graph=True) + ops.grad(outputs=target, inputs=[h] + list(conv.trainable_params())) + + +@pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) +def test_revgnn_diable(num_groups): + x = ops.randn(4, 32) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + lin = Linear(32, 32) + conv = SAGEConv(32 // num_groups, 32 // num_groups) + conv = GroupAddRev(conv, num_groups=num_groups, disable=True) + + h = lin(x) + out = conv(h, edge_index) + target = out.mean() + target.backward() + + # Memory will not be freed if disable: + if typing.WITH_PT20: + assert h.untyped_storage().shape == 4 * 4 * 32 + else: + assert h.storage().shape == 4 * 32 + + +@pytest.mark.parametrize('num_groups', [2, 4, 8, 16]) +def test_revgnn_with_args(num_groups): + x = ops.randn(4, 32) + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + edge_weight = ops.rand(4) + + lin = Linear(32, 32) + conv = GraphConv(32 // num_groups, 32 // num_groups) + conv = GroupAddRev(conv, num_groups=num_groups) + + h = lin(x) + out = conv(h, edge_index, edge_weight) + target = out.mean() + target.backward() diff --git a/tests/graph/nn/models/test_schnet.py b/tests/graph/nn/models/test_schnet.py new file mode 100644 index 000000000..f685c231f --- /dev/null +++ b/tests/graph/nn/models/test_schnet.py @@ -0,0 +1,66 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Batch, Graph +from mindscience.sharker.nn import SchNet +from mindscience.sharker.nn.models.schnet import RadiusInteractionGraph +from mindscience.sharker.testing import is_full_test, withPackage + + +def generate_data(): + data = Graph( + z=ops.randint(1, 10, (20, )), + crd=ops.randn(20, 3), + ) + print(data.num_nodes) + return data + + +@withPackage('ase') +@pytest.mark.parametrize('use_interaction_graph', [False, True]) +@pytest.mark.parametrize('use_atomref', [False, True]) +def test_schnet(use_interaction_graph, use_atomref): + data = generate_data() + + interaction_graph = None + if use_interaction_graph: + interaction_graph = RadiusInteractionGraph(cutoff=6.0) + + model = SchNet( + hidden_channels=16, + num_filters=16, + num_interactions=2, + interaction_graph=interaction_graph, + num_gaussians=10, + cutoff=6.0, + dipole=True, + atomref=ops.randn(100, 1) if use_atomref else None, + ) + + assert str(model) == ('SchNet(hidden_channels=16, num_filters=16, ' + 'num_interactions=2, num_gaussians=10, cutoff=6.0)') + + out = model(data.z, data.crd) + assert out.shape == (1, 1) + + if is_full_test(): + jit = ms.jit(model) + out = jit(data.z, data.crd) + assert out.shape == (1, 1) + + +def test_schnet_batch(): + num_graphs = 3 + batch = [generate_data() for _ in range(num_graphs)] + batch = Batch.from_data_list(batch) + + model = SchNet( + hidden_channels=16, + num_filters=16, + num_interactions=2, + num_gaussians=10, + cutoff=6.0, + ) + + out = model(batch.z, batch.crd, batch.batch) + assert out.shape == (num_graphs, 1) diff --git a/tests/graph/nn/models/test_signed_gcn.py b/tests/graph/nn/models/test_signed_gcn.py new file mode 100644 index 000000000..1230cfebe --- /dev/null +++ b/tests/graph/nn/models/test_signed_gcn.py @@ -0,0 +1,37 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import SignedGCN +from mindscience.sharker.testing import is_full_test + + +def test_signed_gcn(): + model = SignedGCN(8, 16, num_layers=2, lamb=5) + assert str(model) == 'SignedGCN(8, 16, num_layers=2)' + + pos_index = ops.randint(0, high=10, size=(2, 40), dtype=ms.int64) + neg_index = ops.randint(0, high=10, size=(2, 40), dtype=ms.int64) + + train_pos_index, test_pos_index = model.split_edges(pos_index) + train_neg_index, test_neg_index = model.split_edges(neg_index) + + assert train_pos_index.shape == (2, 32) + assert test_pos_index.shape == (2, 8) + assert train_neg_index.shape == (2, 32) + assert test_neg_index.shape == (2, 8) + + x = model.create_spectral_features(train_pos_index, train_neg_index, 10) + assert x.shape == (10, 8) + + z = model(x, train_pos_index, train_neg_index) + assert z.shape == (10, 16) + + loss = model.loss(z, train_pos_index, train_neg_index) + assert loss.item() >= 0 + + auc, f1 = model.test(z, test_pos_index, test_neg_index) + assert auc >= 0 + assert f1 >= 0 + + if is_full_test(): + jit = ms.jit(model) + assert ops.isclose(jit(x, train_pos_index, train_neg_index), z).all() diff --git a/tests/graph/nn/models/test_tgn.py b/tests/graph/nn/models/test_tgn.py new file mode 100644 index 000000000..165ba6a3e --- /dev/null +++ b/tests/graph/nn/models/test_tgn.py @@ -0,0 +1,82 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import TemporalGraph +from mindscience.sharker.loader import TemporalDataLoader +from mindscience.sharker.nn import TGNMemory +from mindscience.sharker.nn.models.tgn import ( + IdentityMessage, + LastAggregator, + LastNeighborLoader, +) + + +@pytest.mark.parametrize('neg_sampling_ratio', [0.0, 1.0]) +def test_tgn(neg_sampling_ratio): + memory_dim = 16 + time_dim = 16 + + src = ms.Tensor([0, 1, 0, 2, 0, 3, 1, 4, 2, 3]) + dst = ms.Tensor([1, 2, 1, 1, 3, 2, 4, 3, 3, 4]) + t = ops.arange(10) + msg = ops.randn(10, 16) + data = TemporalGraph(src=src, dst=dst, t=t, msg=msg) + + loader = TemporalDataLoader( + data, + batch_size=5, + neg_sampling_ratio=neg_sampling_ratio, + ) + neighbor_loader = LastNeighborLoader(data.num_nodes, size=3) + assert neighbor_loader.cur_e_id == 0 + assert neighbor_loader.e_id.shape == (data.num_nodes, 3) + + memory = TGNMemory( + num_nodes=data.num_nodes, + raw_msg_dim=msg.shape[-1], + memory_dim=memory_dim, + time_dim=time_dim, + message_module=IdentityMessage(msg.shape[-1], memory_dim, time_dim), + aggregator_module=LastAggregator(), + ) + assert memory.memory.shape == (data.num_nodes, memory_dim) + assert memory.last_update.shape == (data.num_nodes, ) + + # Test TGNMemory training: + for i, batch in enumerate(loader): + n_id, edge_index, e_id = neighbor_loader(batch.n_id) + z, last_update = memory(n_id) + memory.update_state(batch.src, batch.dst, batch.t, batch.msg) + neighbor_loader.insert(batch.src, batch.dst) + if i == 0: + assert n_id.shape[0] >= 4 + assert edge_index.numel() == 0 + assert e_id.numel() == 0 + assert z.shape == (n_id.shape[0], memory_dim) + assert ops.sum(last_update) == 0 + else: + assert n_id.shape[0] == 5 + assert edge_index.numel() == 12 + assert e_id.numel() == 6 + assert z.shape == (n_id.shape[0], memory_dim) + assert ops.equal(last_update, ms.Tensor([4, 3, 3, 4, 0])).all() + + # Test TGNMemory inference: + memory.eval() + all_n_id = ops.arange(data.num_nodes) + z, last_update = memory(all_n_id) + assert z.shape == (data.num_nodes, memory_dim) + assert ops.equal(last_update, ms.Tensor([4, 6, 8, 9, 9])).all() + + post_src = ms.Tensor([3, 4]) + post_dst = ms.Tensor([4, 3]) + post_t = ms.Tensor([10, 10]) + post_msg = ops.randn(2, 16) + memory.update_state(post_src, post_dst, post_t, post_msg) + post_z, post_last_update = memory(all_n_id) + assert ops.isclose(z[0:3], post_z[0:3]).all() + assert ops.equal(post_last_update, ms.Tensor([4, 6, 8, 10, 10])).all() + + memory.reset_state() + assert memory.memory.sum() == 0 + assert memory.last_update.sum() == 0 diff --git a/tests/graph/nn/models/test_visnet.py b/tests/graph/nn/models/test_visnet.py new file mode 100644 index 000000000..5640798a4 --- /dev/null +++ b/tests/graph/nn/models/test_visnet.py @@ -0,0 +1,28 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import ViSNet + + +@pytest.mark.parametrize('kwargs', [ + dict(lmax=2, vecnorm_type=None, vertex=False), + dict(lmax=1, vecnorm_type='max_min', vertex=True), +]) +@pytest.mark.parametrize('derivative', [True, False]) +def test_visnet(kwargs, derivative): + z = ops.randint(1, 10, (20, )) + pos = ops.randn(20, 3) + batch = ops.zeros(20, dtype=ms.int64) + + model = ViSNet(**kwargs) + + model.reset_parameters() + + if derivative: + grad_fn = ops.value_and_grad(model, 1) + energy, forces = grad_fn(z, pos, batch) + assert forces.shape == (20, 3) + assert energy.shape == (1, 1) + else: + energy = model(z, pos, batch) + assert energy.shape == (1, 1) diff --git a/tests/graph/nn/norm/test_batch_norm.py b/tests/graph/nn/norm/test_batch_norm.py new file mode 100644 index 000000000..e58be07a8 --- /dev/null +++ b/tests/graph/nn/norm/test_batch_norm.py @@ -0,0 +1,70 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import BatchNorm, HeteroBatchNorm +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('conf', [True, False]) +def test_batch_norm(conf): + x = ops.randn(100, 16) + + norm = BatchNorm(16, affine=conf, track_running_stats=conf) + # norm.reset_running_stats() + norm.reset_parameters() + assert str(norm) == 'BatchNorm(16)' + + if is_full_test(): + ms.jit(norm) + + out = norm(x) + assert out.shape == (100, 16) + + +def test_batch_norm_single_element(): + x = ops.randn(1, 16) + + norm = BatchNorm(16) + with pytest.raises(ValueError, match="Expected more than 1 value"): + norm(x) + + with pytest.raises(ValueError, match="requires 'track_running_stats'"): + norm = BatchNorm(16, track_running_stats=False, + allow_single_element=True) + + norm = BatchNorm(16, track_running_stats=True, allow_single_element=True) + out = norm(x) + assert ops.isclose(out, x).all() + + +@pytest.mark.parametrize('conf', [True, False]) +def test_hetero_batch_norm(conf): + x = ops.randn((100, 16)) + + # Test single type: + norm = BatchNorm(16, affine=conf, track_running_stats=conf) + expected = norm(x) + + type_vec = ops.zeros(100, dtype=ms.int64) + norm = HeteroBatchNorm(16, num_types=1, affine=conf, + track_running_stats=conf) + norm.reset_running_stats() + norm.reset_parameters() + assert str(norm) == 'HeteroBatchNorm(16, num_types=1)' + + out = norm(x, type_vec) + assert out.shape == (100, 16) + assert ops.isclose(out, expected, atol=1e-3).all() + + # Test multiple types: + type_vec = ops.randint(0, 5, (100, )) + norm = HeteroBatchNorm(16, num_types=5, affine=conf, + track_running_stats=conf) + out = norm(x, type_vec) + assert out.shape == (100, 16) + + for i in range(5): # Check that mean=0 and std=1 across all types: + mean = out[type_vec == i].mean() + std = out[type_vec == i].std(unbiased=False) + assert ops.isclose(mean, ops.zeros_like(mean), atol=1e-7).all() + assert ops.isclose(std, ops.ones_like(std), atol=1e-7).all() diff --git a/tests/graph/nn/norm/test_diff_group_norm.py b/tests/graph/nn/norm/test_diff_group_norm.py new file mode 100644 index 000000000..6c9795109 --- /dev/null +++ b/tests/graph/nn/norm/test_diff_group_norm.py @@ -0,0 +1,38 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import DiffGroupNorm +from mindscience.sharker.testing import is_full_test + + +def test_diff_group_norm(): + x = ops.randn(6, 16) + + norm = DiffGroupNorm(16, groups=4, lamda=0) + assert str(norm) == 'DiffGroupNorm(16, groups=4)' + + assert ops.isclose(norm(x), x).all() + + if is_full_test(): + jit = ms.jit(norm) + assert ops.isclose(jit(x), x).all() + + norm = DiffGroupNorm(16, groups=4, lamda=0.01) + assert str(norm) == 'DiffGroupNorm(16, groups=4)' + + out = norm(x) + assert out.shape == x.shape + + if is_full_test(): + jit = ms.jit(norm) + assert ops.isclose(jit(x), out).all() + + +def test_group_distance_ratio(): + x = ops.randn(6, 16) + y = ms.Tensor([0, 1, 0, 1, 1, 1]) + + assert DiffGroupNorm.group_distance_ratio(x, y) > 0 + + if is_full_test(): + jit = ms.jit(DiffGroupNorm.group_distance_ratio) + assert jit(x, y) > 0 diff --git a/tests/graph/nn/norm/test_graph_norm.py b/tests/graph/nn/norm/test_graph_norm.py new file mode 100644 index 000000000..b1c79c17b --- /dev/null +++ b/tests/graph/nn/norm/test_graph_norm.py @@ -0,0 +1,26 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import GraphNorm +from mindscience.sharker.testing import is_full_test + + +def test_graph_norm(): + ms.set_seed(42) + x = ops.randn(200, 16) + batch = ops.arange(4).view(-1, 1).tile((1, 50)).view(-1) + + norm = GraphNorm(16) + assert str(norm) == 'GraphNorm(16)' + + if is_full_test(): + ms.jit(norm) + + out = norm(x) + assert out.shape == (200, 16) + assert ops.isclose(out.mean(0), ops.zeros(16), atol=1e-6).all() + assert ops.isclose(out.std(0), ops.ones(16), atol=1e-6).all() + + out = norm(x, batch) + assert out.shape == (200, 16) + assert ops.isclose(out[:50].mean(0), ops.zeros(16), atol=1e-6).all() + assert ops.isclose(out[:50].std(0), ops.ones(16), atol=1e-6).all() diff --git a/tests/graph/nn/norm/test_graph_size_norm.py b/tests/graph/nn/norm/test_graph_size_norm.py new file mode 100644 index 000000000..9e7cec3dc --- /dev/null +++ b/tests/graph/nn/norm/test_graph_size_norm.py @@ -0,0 +1,19 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import GraphSizeNorm +from mindscience.sharker.testing import is_full_test + + +def test_graph_size_norm(): + x = ops.randn(100, 16) + batch = ops.arange(10).repeat(10) + + norm = GraphSizeNorm() + assert str(norm) == 'GraphSizeNorm()' + + out = norm(x, batch) + assert out.shape == (100, 16) + + if is_full_test(): + jit = ms.jit(norm) + assert ops.isclose(jit(x, batch), out).all() diff --git a/tests/graph/nn/norm/test_instance_norm.py b/tests/graph/nn/norm/test_instance_norm.py new file mode 100644 index 000000000..4f9463b5b --- /dev/null +++ b/tests/graph/nn/norm/test_instance_norm.py @@ -0,0 +1,52 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import InstanceNorm +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('conf', [True, False]) +def test_instance_norm(conf): + batch = ops.zeros(100, dtype=ms.int64) + + x1 = ops.randn(100, 16) + x2 = ops.randn(100, 16) + + norm1 = InstanceNorm(16, affine=conf, track_running_stats=conf) + norm2 = InstanceNorm(16, affine=conf, track_running_stats=conf) + assert str(norm1) == 'InstanceNorm(16)' + + if is_full_test(): + ms.jit(norm1) + + out1 = norm1(x1) + out2 = norm2(x1, batch) + assert out1.shape == (100, 16) + assert ops.isclose(out1, out2, atol=1e-7).all() + if conf: + assert ops.isclose(norm1.running_mean, norm2.running_mean).all() + assert ops.isclose(norm1.running_var, norm2.running_var).all() + + out1 = norm1(x2) + out2 = norm2(x2, batch) + assert ops.isclose(out1, out2, atol=1e-7).all() + if conf: + assert ops.isclose(norm1.running_mean, norm2.running_mean).all() + assert ops.isclose(norm1.running_var, norm2.running_var).all() + + norm1.eval() + norm2.eval() + + out1 = norm1(x1) + out2 = norm2(x1, batch) + assert ops.isclose(out1, out2, atol=1e-7).all() + + out1 = norm1(x2) + out2 = norm2(x2, batch) + assert ops.isclose(out1, out2, atol=1e-7).all() + + out1 = norm2(x1) + out2 = norm2(x2) + out3 = norm2(ops.cat(([x1, x2]), axis=0), ops.cat([batch, batch + 1])) + assert ops.isclose(out1, out3[:100], atol=1e-7).all() + assert ops.isclose(out2, out3[100:], atol=1e-7).all() diff --git a/tests/graph/nn/norm/test_layer_norm.py b/tests/graph/nn/norm/test_layer_norm.py new file mode 100644 index 000000000..c75bef343 --- /dev/null +++ b/tests/graph/nn/norm/test_layer_norm.py @@ -0,0 +1,67 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import HeteroLayerNorm, LayerNorm +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('affine', [True, False]) +@pytest.mark.parametrize('mode', ['graph', 'node']) +def test_layer_norm(affine, mode): + x = ops.randn(100, 16) + batch = ops.zeros(100, dtype=ms.int64) + + norm = LayerNorm(16, affine=affine, mode=mode) + assert str(norm) == f'LayerNorm(16, affine={affine}, mode={mode})' + + if is_full_test(): + ms.jit(norm) + + out1 = norm(x) + assert out1.shape == (100, 16) + assert ops.isclose(norm(x, batch), out1, atol=1e-6).all() + + out2 = norm(ops.cat(([x, x]), axis=0), ops.cat(([batch, batch + 1]), axis=0)) + assert ops.isclose(out1, out2[:100], atol=1e-6).all() + assert ops.isclose(out1, out2[100:], atol=1e-6).all() + + +@pytest.mark.parametrize('affine', [False, True]) +def test_hetero_layer_norm(affine): + x = ops.randn((100, 16)) + model = LayerNorm(16, affine=affine, mode='node') + expected = model(x) + # Test single type: + type_vec = ops.zeros(100, dtype=ms.int64) + type_ptr = [0, 100] + + norm = HeteroLayerNorm(16, num_types=1, affine=affine) + assert str(norm) == 'HeteroLayerNorm(16, num_types=1)' + + out = norm(x, type_vec) + assert out.shape == (100, 16) + assert ops.isclose(out, expected, atol=1e-3).all() + assert ops.isclose(norm(out, type_ptr=type_ptr), expected, atol=1e-3).all() + + mean = out.mean(-1) + std = out.std(-1) + assert ops.isclose(mean, ops.zeros_like(mean), atol=1e-2).all() + assert ops.isclose(std, ops.ones_like(std), atol=1e-2).all() + + # Test multiple types: + type_vec = ops.arange(5) + type_vec = type_vec.view(-1, 1).tile((1, 20)).view(-1) + type_ptr = [0, 20, 40, 60, 80, 100] + + norm = HeteroLayerNorm(16, num_types=5, affine=affine) + assert str(norm) == 'HeteroLayerNorm(16, num_types=5)' + + out = norm(x, type_vec) + assert out.shape == (100, 16) + assert ops.isclose(out, expected, atol=1e-3).all() + assert ops.isclose(norm(out, type_ptr=type_ptr), expected, atol=1e-3).all() + + mean = out.mean(-1) + std = out.std(-1) + assert ops.isclose(mean, ops.zeros_like(mean), atol=1e-2).all() + assert ops.isclose(std, ops.ones_like(std), atol=1e-2).all() diff --git a/tests/graph/nn/norm/test_mean_subtraction_norm.py b/tests/graph/nn/norm/test_mean_subtraction_norm.py new file mode 100644 index 000000000..7067b0416 --- /dev/null +++ b/tests/graph/nn/norm/test_mean_subtraction_norm.py @@ -0,0 +1,24 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import MeanSubtractionNorm +from mindscience.sharker.testing import is_full_test + + +def test_mean_subtraction_norm(): + x = ops.randn(6, 16) + batch = ms.Tensor([0, 0, 1, 1, 1, 2]) + + norm = MeanSubtractionNorm() + assert str(norm) == 'MeanSubtractionNorm()' + + if is_full_test(): + ms.jit(norm) + + out = norm(x) + assert out.shape == (6, 16) + assert ops.isclose(out.mean(), ms.Tensor(0.), atol=1e-6).all() + + out = norm(x, batch) + assert out.shape == (6, 16) + assert ops.isclose(out[0:2].mean(), ms.Tensor(0.), atol=1e-6).all() + assert ops.isclose(out[0:2].mean(), ms.Tensor(0.), atol=1e-6).all() diff --git a/tests/graph/nn/norm/test_msg_norm.py b/tests/graph/nn/norm/test_msg_norm.py new file mode 100644 index 000000000..59ecc471c --- /dev/null +++ b/tests/graph/nn/norm/test_msg_norm.py @@ -0,0 +1,26 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import MessageNorm +from mindscience.sharker.testing import is_full_test + + +def test_message_norm(): + norm = MessageNorm(learn_scale=True) + assert str(norm) == 'MessageNorm(learn_scale=True)' + x = ops.randn(100, 16) + msg = ops.randn(100, 16) + out = norm(x, msg) + assert out.shape == (100, 16) + + if is_full_test(): + jit = ms.jit(norm) + assert ops.isclose(jit(x, msg), out).all() + + norm = MessageNorm(learn_scale=False) + assert str(norm) == 'MessageNorm(learn_scale=False)' + out = norm(x, msg) + assert out.shape == (100, 16) + + if is_full_test(): + jit = ms.jit(norm) + assert ops.isclose(jit(x, msg), out).all() diff --git a/tests/graph/nn/norm/test_pair_norm.py b/tests/graph/nn/norm/test_pair_norm.py new file mode 100644 index 000000000..9bbd00944 --- /dev/null +++ b/tests/graph/nn/norm/test_pair_norm.py @@ -0,0 +1,24 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import PairNorm +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('scale_individually', [False, True]) +def test_pair_norm(scale_individually): + x = ops.randn(100, 16) + batch = ops.zeros(100, dtype=ms.int64) + + norm = PairNorm(scale_individually=scale_individually) + assert str(norm) == 'PairNorm()' + + if is_full_test(): + ms.jit(norm) + + out1 = norm(x) + assert out1.shape == (100, 16) + + out2 = norm(ops.cat(([x, x]), axis=0), ops.cat(([batch, batch + 1]), axis=0)) + assert ops.isclose(out1, out2[:100], atol=1e-6).all() + assert ops.isclose(out1, out2[100:], atol=1e-6).all() diff --git a/tests/graph/nn/pool/connect/test_filter_edges.py b/tests/graph/nn/pool/connect/test_filter_edges.py new file mode 100644 index 000000000..9e1fc36e9 --- /dev/null +++ b/tests/graph/nn/pool/connect/test_filter_edges.py @@ -0,0 +1,33 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.pool.connect import FilterEdges +from mindscience.sharker.nn.pool.select import SelectOutput +from mindscience.sharker.testing import is_full_test + + +def test_filter_edges(): + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3], [1, 0, 1, 3, 2, 2]]) + edge_attr = ms.Tensor([1, 2, 3, 4, 5, 6]) + batch = ms.Tensor([0, 0, 1, 1]) + + select_output = SelectOutput( + node_index=ms.Tensor([1, 2]), + num_nodes=4, + cluster_index=ms.Tensor([0, 1]), + num_clusters=2, + ) + + connect = FilterEdges() + assert str(connect) == 'FilterEdges()' + + out1 = connect(select_output, edge_index, edge_attr, batch) + assert out1.edge_index.tolist() == [[0, 1], [0, 1]] + assert out1.edge_attr.tolist() == [3, 5] + assert out1.batch.tolist() == [0, 1] + + if is_full_test(): + jit = ms.jit(connect) + out2 = jit(select_output, edge_index, edge_attr, batch) + ops.isclose(out1.edge_index, out2.edge_index).all() + ops.isclose(out1.edge_attr, out2.edge_attr).all() + ops.equal(out1.batch, out2.batch).all() diff --git a/tests/graph/nn/pool/select/test_select_topk.py b/tests/graph/nn/pool/select/test_select_topk.py new file mode 100644 index 000000000..0e602ce23 --- /dev/null +++ b/tests/graph/nn/pool/select/test_select_topk.py @@ -0,0 +1,89 @@ +from itertools import product + +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.pool.select import SelectOutput, SelectTopK +from mindscience.sharker.nn.pool.select.topk import topk +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import is_full_test + + +def test_topk_ratio(): + x = ms.Tensor([2.0, 4.0, 5.0, 6.0, 2.0, 9.0]) + batch = ms.Tensor([0, 0, 1, 1, 1, 1]) + + perm1 = topk(x, 0.5, batch) + assert perm1.tolist() == [1, 5, 3] + assert x[perm1].tolist() == [4.0, 9.0, 6.0] + assert batch[perm1].tolist() == [0, 1, 1] + + perm2 = topk(x, 2, batch) + assert perm2.tolist() == [1, 0, 5, 3] + assert x[perm2].tolist() == [4.0, 2.0, 9.0, 6.0] + assert batch[perm2].tolist() == [0, 0, 1, 1] + + perm3 = topk(x, 3, batch) + assert perm3.tolist() == [1, 0, 5, 3, 2] + assert x[perm3].tolist() == [4.0, 2.0, 9.0, 6.0, 5.0] + assert batch[perm3].tolist() == [0, 0, 1, 1, 1] + + if is_full_test(): + jit = ms.jit(topk) + assert ops.equal(jit(x, 0.5, batch), perm1).all() + assert ops.equal(jit(x, 2, batch), perm2).all() + assert ops.equal(jit(x, 3, batch), perm3).all() + + +@pytest.mark.parametrize('min_score', [None, 2.0]) +def test_select_topk(min_score): + x = ops.randn(6, 16) + batch = ms.Tensor([0, 0, 1, 1, 1, 1]) + + pool = SelectTopK(16, min_score=min_score) + + if min_score is None: + assert str(pool) == 'SelectTopK(16, ratio=0.5)' + else: + assert str(pool) == 'SelectTopK(16, min_score=2.0)' + + out = pool(x, batch) + assert isinstance(out, SelectOutput) + + assert out.num_nodes == 6 + assert out.num_clusters <= out.num_nodes + assert out.node_index.min() >= 0 + assert out.node_index.max() < out.num_nodes + assert out.cluster_index.min() == 0 + assert out.cluster_index.max() == out.num_clusters - 1 + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default='cuda') + args = parser.parse_args() + + BS = [2**i for i in range(6, 8)] + NS = [2**i for i in range(8, 16)] + + funcs = [] + func_names = [] + args_list = [] + for B, N in product(BS, NS): + x = ops.randn(N) + batch = ops.randint(0, B, (N, )).sort()[0] + + funcs.append(topk) + func_names.append(f'B={B}, N={N}') + args_list.append((x, 0.5, batch)) + + benchmark( + funcs=funcs, + func_names=func_names, + args=args_list, + num_steps=50 if args.device == 'cpu' else 500, + num_warmups=10 if args.device == 'cpu' else 100, + progress_bar=True, + ) diff --git a/tests/graph/nn/pool/test_approx_knn.py b/tests/graph/nn/pool/test_approx_knn.py new file mode 100644 index 000000000..98edd4bde --- /dev/null +++ b/tests/graph/nn/pool/test_approx_knn.py @@ -0,0 +1,60 @@ +import warnings + +import mindspore as ms +from mindscience.sharker.nn import approx_knn, approx_knn_graph +from mindscience.sharker.testing import onlyFullTest, withPackage + + +def to_set(edge_index): + return set([(i, j) for i, j in edge_index.t().tolist()]) + + +@onlyFullTest # JIT compile makes this test too slow :( +@withPackage('pynndescent') +def test_approx_knn(): + warnings.filterwarnings('ignore', '.*find n_neighbors.*') + + x = ms.Tensor([ + [-1.0, -1.0], + [-1.0, +1.0], + [+1.0, +1.0], + [+1.0, -1.0], + [-1.0, -1.0], + [-1.0, +1.0], + [+1.0, +1.0], + [+1.0, -1.0], + ]) + y = ms.Tensor([ + [+1.0, 0.0], + [-1.0, 0.0], + ]) + + batch_x = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 1]) + batch_y = ms.Tensor([0, 1]) + + edge_index = approx_knn(x, y, 2) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 0), (1, 1)]) + + edge_index = approx_knn(x, y, 2, batch_x, batch_y) + assert to_set(edge_index) == set([(0, 2), (0, 3), (1, 4), (1, 5)]) + + +@onlyFullTest # JIT compile makes this test too slow :( +@withPackage('pynndescent') +def test_approx_knn_graph(): + warnings.filterwarnings('ignore', '.*find n_neighbors.*') + + x = ms.Tensor([ + [-1.0, -1.0], + [-1.0, +1.0], + [+1.0, +1.0], + [+1.0, -1.0], + ]) + + edge_index = approx_knn_graph(x, k=2, flow='target_to_source') + assert to_set(edge_index) == set([(0, 1), (0, 3), (1, 0), (1, 2), (2, 1), + (2, 3), (3, 0), (3, 2)]) + + edge_index = approx_knn_graph(x, k=2, flow='source_to_target') + assert to_set(edge_index) == set([(1, 0), (3, 0), (0, 1), (2, 1), (1, 2), + (3, 2), (0, 3), (2, 3)]) diff --git a/tests/graph/nn/pool/test_asap.py b/tests/graph/nn/pool/test_asap.py new file mode 100644 index 000000000..beb72ca8c --- /dev/null +++ b/tests/graph/nn/pool/test_asap.py @@ -0,0 +1,40 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import ASAPooling, GCNConv, GraphConv +from mindscience.sharker.testing import is_full_test + + +def test_asap(): + in_channels = 16 + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) + num_nodes = edge_index.max().item() + 1 + x = ops.randn((num_nodes, in_channels)) + + for GNN in [GraphConv, GCNConv]: + pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN, + add_self_loops=False) + assert str(pool) == ('ASAPooling(16, ratio=0.5)') + out = pool(x, edge_index) + assert out[0].shape == (num_nodes // 2, in_channels) + assert out[1].shape == (2, 2) + + if is_full_test(): + ms.jit(pool) + + pool = ASAPooling(in_channels, ratio=0.5, GNN=GNN, add_self_loops=True) + assert str(pool) == ('ASAPooling(16, ratio=0.5)') + out = pool(x, edge_index) + assert out[0].shape == (num_nodes // 2, in_channels) + assert out[1].shape == (2, 4) + + pool = ASAPooling(in_channels, ratio=2, GNN=GNN, add_self_loops=False) + assert str(pool) == ('ASAPooling(16, ratio=2)') + out = pool(x, edge_index) + assert out[0].shape == (2, in_channels) + assert out[1].shape == (2, 2) + + +# def test_asap_jit_save(): +# pool = ASAPooling(in_channels=16) +# ms.jit.save(ms.jit(pool), io.BytesIO()) diff --git a/tests/graph/nn/pool/test_avg_pool.py b/tests/graph/nn/pool/test_avg_pool.py new file mode 100644 index 000000000..caa3f688d --- /dev/null +++ b/tests/graph/nn/pool/test_avg_pool.py @@ -0,0 +1,105 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Batch +from mindscience.sharker.nn import avg_pool, avg_pool_neighbor_x, avg_pool_x +from mindscience.sharker.testing import is_full_test + + +def test_avg_pool_x(): + cluster = ms.Tensor([0, 1, 0, 1, 2, 2]) + x = ms.Tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1]) + + out = avg_pool_x(cluster, x, batch) + assert out[0].tolist() == [[3, 4], [5, 6], [10, 11]] + assert out[1].tolist() == [0, 0, 1] + + if is_full_test(): + jit = ms.jit(avg_pool_x) + out = jit(cluster, x, batch) + assert out[0].tolist() == [[3, 4], [5, 6], [10, 11]] + assert out[1].tolist() == [0, 0, 1] + + out, _ = avg_pool_x(cluster, x, batch, size=2) + assert out.tolist() == [[3, 4], [5, 6], [10, 11], [0, 0]] + + batch_size = int(batch.max().item()) + 1 + out2, _ = avg_pool_x(cluster, x, batch, batch_size=batch_size, size=2) + assert ops.equal(out, out2).all() + + if is_full_test(): + jit = ms.jit(avg_pool_x) + out, _ = jit(cluster, x, batch, size=2) + assert out.tolist() == [[3, 4], [5, 6], [10, 11], [0, 0]] + + out2, _ = jit(cluster, x, batch, batch_size=batch_size, size=2) + assert ops.equal(out, out2).all() + + +def test_avg_pool(): + cluster = ms.Tensor([0, 1, 0, 1, 2, 2]) + x = ms.Tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + pos = ms.Tensor([ + [0.0, 0.0], + [1.0, 1.0], + [2.0, 2.0], + [3.0, 3.0], + [4.0, 4.0], + [5.0, 5.0], + ]) + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + edge_attr = ops.ones(edge_index.shape[1]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1]) + + data = Batch(x=x, crd=pos, edge_index=edge_index, edge_attr=edge_attr, + batch=batch) + + data = avg_pool(cluster, data, transform=lambda x: x) + + assert data.x.tolist() == [[3, 4], [5, 6], [10, 11]] + assert data.crd.tolist() == [[1, 1], [2, 2], [4.5, 4.5]] + assert data.edge_index.tolist() == [[0, 1], [1, 0]] + assert data.edge_attr.tolist() == [4, 4] + assert data.batch.tolist() == [0, 0, 1] + + +def test_avg_pool_neighbor_x(): + x = ms.Tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1]) + + data = Batch(x=x, edge_index=edge_index, batch=batch) + data = avg_pool_neighbor_x(data) + + assert data.x.tolist() == [ + [4, 5], + [4, 5], + [4, 5], + [4, 5], + [10, 11], + [10, 11], + ] + assert ops.equal(data.edge_index, edge_index).all() diff --git a/tests/graph/nn/pool/test_consecutive.py b/tests/graph/nn/pool/test_consecutive.py new file mode 100644 index 000000000..cf458282a --- /dev/null +++ b/tests/graph/nn/pool/test_consecutive.py @@ -0,0 +1,13 @@ +import mindspore as ms +from mindscience.sharker.nn.pool.consecutive import consecutive_cluster + + +def test_consecutive_cluster(): + src = ms.Tensor([8, 2, 10, 15, 100, 1, 100]) + + out, perm = consecutive_cluster(src) + assert out.tolist() == [2, 1, 3, 4, 5, 0, 5] + + # Todo: check the difference of Tensor.scatter between mindspore and torch + # assert perm.tolist() == [5, 1, 0, 2, 3, 6] + assert perm.tolist() == [5, 1, 0, 2, 3, 4] diff --git a/tests/graph/nn/pool/test_decimation.py b/tests/graph/nn/pool/test_decimation.py new file mode 100644 index 000000000..42cf73639 --- /dev/null +++ b/tests/graph/nn/pool/test_decimation.py @@ -0,0 +1,40 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.pool.decimation import decimation_indices + + +def test_decimation_basic(): + N_1, N_2 = 4, 6 + decimation_factor = 2 + ptr = ms.Tensor([0, N_1, N_1 + N_2]) + + idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor) + + expected_size = (N_1 // decimation_factor) + (N_2 // decimation_factor) + assert idx_decim.shape[0] == expected_size + + expected = ms.Tensor([0, N_1 // decimation_factor, expected_size]) + assert ops.equal(ptr_decim, expected).all() + + +def test_decimation_single_cloud(): + N_1 = 4 + decimation_factor = 2 + ptr = ms.Tensor([0, N_1]) + + idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor) + + expected_size = N_1 // decimation_factor + assert idx_decim.shape[0] == expected_size + assert ops.equal(ptr_decim, ms.Tensor([0, expected_size])).all() + + +def test_decimation_almost_empty(): + N_1 = 4 + decimation_factor = 666 # greater than N_1 + ptr = ms.Tensor([0, N_1]) + + idx_decim, ptr_decim = decimation_indices(ptr, decimation_factor) + + assert idx_decim.shape[0] == 0 + assert ops.equal(ptr_decim, ms.Tensor([0, 0])).all() diff --git a/tests/graph/nn/pool/test_edge_pool.py b/tests/graph/nn/pool/test_edge_pool.py new file mode 100644 index 000000000..26e92fee0 --- /dev/null +++ b/tests/graph/nn/pool/test_edge_pool.py @@ -0,0 +1,103 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import EdgePooling +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.sparse import scatter + + +def test_compute_edge_score_softmax(): + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + raw = ops.randn(edge_index.shape[1]) + e = EdgePooling.compute_edge_score_softmax(raw, edge_index, 6) + assert ops.all(e >= 0) and ops.all(e <= 1) + + # Test whether all incoming edge scores sum up to one. + assert ops.isclose( + scatter(e, edge_index[1], reduce='sum'), + ops.ones(6), + ).all() + + if is_full_test(): + jit = ms.jit(EdgePooling.compute_edge_score_softmax) + assert ops.isclose(jit(raw, edge_index, 6), e).all() + + +def test_compute_edge_score_tanh(): + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + raw = ops.randn(edge_index.shape[1]) + e = EdgePooling.compute_edge_score_tanh(raw, edge_index, 6) + assert ops.all(e >= -1) and ops.all(e <= 1) + assert ops.all(ops.argsort(raw) == ops.argsort(e)) + + if is_full_test(): + jit = ms.jit(EdgePooling.compute_edge_score_tanh) + assert ops.isclose(jit(raw, edge_index, 6), e).all() + + +def test_compute_edge_score_sigmoid(): + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + raw = ops.randn(edge_index.shape[1]) + e = EdgePooling.compute_edge_score_sigmoid(raw, edge_index, 6) + assert ops.all(e >= 0) and ops.all(e <= 1) + assert ops.all(ops.argsort(raw) == ops.argsort(e)) + + if is_full_test(): + jit = ms.jit(EdgePooling.compute_edge_score_sigmoid) + assert ops.isclose(jit(raw, edge_index, 6), e).all() + + +def test_edge_pooling(): + x = ms.Tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [-1.0]]) + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5, 6], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4, 0]]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1, 0]) + + op = EdgePooling(in_channels=1) + assert str(op) == 'EdgePooling(1)' + + # Setting parameters fixed so we can test the expected outcome: + op.lin.weight[:] = 1. + op.lin.bias[:] = 0 + + # Test pooling: + new_x, new_edge_index, new_batch, unpool_info = op(x, edge_index, batch) + + assert new_x.shape[0] == new_batch.shape[0] == 4 + assert new_edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [0, 1, 2, 1, 2, 2]] + assert new_batch.tolist() == [1, 0, 0, 0] + + if is_full_test(): + jit = ms.jit(op) + out = jit(x, edge_index, batch) + assert ops.isclose(new_x, out[0]).all() + assert ops.equal(new_edge_index, out[1]).all() + assert ops.equal(new_batch, out[2]).all() + + # Test unpooling: + out = op.unpool(new_x, unpool_info) + assert out[0].shape == x.shape + assert out[0].tolist() == [[1], [1], [5], [5], [9], [9], [-1]] + assert ops.equal(out[1], edge_index).all() + assert ops.equal(out[2], batch).all() + + if is_full_test(): + jit = ms.jit(op) + out = jit.unpool(new_x, unpool_info) + assert out[0].shape == x.shape + assert out[0].tolist() == [[1], [1], [5], [5], [9], [9], [-1]] + assert ops.equal(out[1], edge_index).all() + assert ops.equal(out[2], batch).all() + + # Test edge cases. + x = ms.Tensor([[0.0], [1.0], [2.0], [3.0], [4.0], [5.0]]) + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1]) + new_x, new_edge_index, new_batch, _ = op(x, edge_index, batch) + + assert new_x.shape[0] == new_batch.shape[0] == 3 + assert new_batch.tolist() == [1, 0, 0] + assert new_edge_index.tolist() == [[0, 1, 1, 2, 2], [0, 1, 2, 1, 2]] diff --git a/tests/graph/nn/pool/test_glob.py b/tests/graph/nn/pool/test_glob.py new file mode 100644 index 000000000..bb50a89fc --- /dev/null +++ b/tests/graph/nn/pool/test_glob.py @@ -0,0 +1,72 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import ( + global_add_pool, + global_max_pool, + global_mean_pool, +) + + +def test_global_pool(): + N_1, N_2 = 4, 6 + x = ops.randn(N_1 + N_2, 4) + batch = ms.Tensor([0 for _ in range(N_1)] + [1 for _ in range(N_2)]) + + out = global_add_pool(x, batch) + assert out.shape == (2, 4) + assert ops.isclose(out[0], x[:4].sum(0)).all() + assert ops.isclose(out[1], x[4:].sum(0)).all() + + out = global_add_pool(x, None) + assert out.shape == (1, 4) + assert ops.isclose(out, x.sum(0, keepdims=True)).all() + + out = global_mean_pool(x, batch) + assert out.shape == (2, 4) + assert ops.isclose(out[0], x[:4].mean(0)).all() + assert ops.isclose(out[1], x[4:].mean(0)).all() + + out = global_mean_pool(x, None) + assert out.shape == (1, 4) + assert ops.isclose(out, x.mean(0, keep_dims=True)).all() + + out = global_max_pool(x, batch) + assert out.shape == (2, 4) + assert ops.isclose(out[0], x[:4].max(0)).all() + assert ops.isclose(out[1], x[4:].max(0)).all() + + out = global_max_pool(x, None) + assert out.shape == (1, 4) + assert ops.isclose(out, x.max(0, keepdims=True)).all() + + +def test_permuted_global_pool(): + N_1, N_2 = 4, 6 + x = ops.randn(N_1 + N_2, 4) + batch = ops.cat([ops.zeros(N_1), ops.ones(N_2)]).long() + perm = ops.shuffle(ops.arange(N_1 + N_2)) + + px = x[perm] + pbatch = batch[perm] + px1 = px[pbatch == 0] + px2 = px[pbatch == 1] + + out = global_add_pool(px, pbatch) + assert out.shape == (2, 4) + assert ops.isclose(out[0], px1.sum(0)).all() + assert ops.isclose(out[1], px2.sum(0)).all() + + out = global_mean_pool(px, pbatch) + assert out.shape == (2, 4) + assert ops.isclose(out[0], px1.mean(0)).all() + assert ops.isclose(out[1], px2.mean(0)).all() + + out = global_max_pool(px, pbatch) + assert out.shape == (2, 4) + assert ops.isclose(out[0], px1.max(0)).all() + assert ops.isclose(out[1], px2.max(0)).all() + + +def test_dense_global_pool(): + x = ops.randn(3, 16, 32) + assert ops.isclose(global_add_pool(x, None), x.sum(1)).all() diff --git a/tests/graph/nn/pool/test_graclus.py b/tests/graph/nn/pool/test_graclus.py new file mode 100644 index 000000000..51233802c --- /dev/null +++ b/tests/graph/nn/pool/test_graclus.py @@ -0,0 +1,7 @@ +import mindspore as ms +from mindscience.sharker.nn import graclus + + +def test_graclus(): + edge_index = ms.Tensor([[0, 1], [1, 0]]) + assert graclus(edge_index).tolist() == [0, 0] diff --git a/tests/graph/nn/pool/test_knn.py b/tests/graph/nn/pool/test_knn.py new file mode 100644 index 000000000..c06580c65 --- /dev/null +++ b/tests/graph/nn/pool/test_knn.py @@ -0,0 +1,130 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import ( + ApproxL2KNNIndex, + ApproxMIPSKNNIndex, + L2KNNIndex, + MIPSKNNIndex, +) +from mindscience.sharker.testing import withPackage + + +@withPackage('faiss') +@pytest.mark.parametrize('k', [2]) +def test_l2(device, k): + lhs = ops.randn(10, 16) + rhs = ops.randn(100, 16) + + index = L2KNNIndex(rhs) + assert index.get_emb().device == device + assert ops.equal(index.get_emb(), rhs).all() + + out = index.search(lhs, k) + assert out.score.device == device + assert out.index.device == device + assert out.score.shape == (10, k) + assert out.index.shape == (10, k) + + mat = ms.numpy.norm(lhs.unsqueeze(1) - rhs.unsqueeze(0), axis=-1).pow(2) + score, index = mat.sort(dim=-1) + + assert ops.isclose(out.score, score[:, :k]).all() + assert ops.equal(out.index, index[:, :k]).all() + + +@withPackage('faiss') +@pytest.mark.parametrize('k', [2]) +def test_mips(device, k): + lhs = ops.randn(10, 16) + rhs = ops.randn(100, 16) + + index = MIPSKNNIndex(rhs) + assert index.get_emb().device == device + assert ops.equal(index.get_emb(), rhs).all() + + out = index.search(lhs, k) + assert out.score.device == device + assert out.index.device == device + assert out.score.shape == (10, k) + assert out.index.shape == (10, k) + + mat = lhs @ rhs.t() + score, index = mat.sort(dim=-1, descending=True) + + assert ops.isclose(out.score, score[:, :k]).all() + assert ops.equal(out.index, index[:, :k]).all() + + +@withPackage('faiss') +@pytest.mark.parametrize('k', [2]) +@pytest.mark.parametrize('reserve', [None, 100]) +def test_approx_l2(device, k, reserve): + lhs = ops.randn(10, 16) + rhs = ops.randn(10_000, 16) + + index = ApproxL2KNNIndex( + num_cells=10, + num_cells_to_visit=10, + bits_per_vector=8, + emb=rhs, + reserve=reserve, + ) + + out = index.search(lhs, k) + assert out.score.device == device + assert out.index.device == device + assert out.score.shape == (10, k) + assert out.index.shape == (10, k) + assert out.index.min() >= 0 and out.index.max() < 10_000 + + +@withPackage('faiss') +@pytest.mark.parametrize('k', [2]) +@pytest.mark.parametrize('reserve', [None, 100]) +def test_approx_mips(device, k, reserve): + lhs = ops.randn(10, 16) + rhs = ops.randn(10_000, 16) + + index = ApproxMIPSKNNIndex( + num_cells=10, + num_cells_to_visit=10, + bits_per_vector=8, + emb=rhs, + reserve=reserve, + ) + + out = index.search(lhs, k) + assert out.score.device == device + assert out.index.device == device + assert out.score.shape == (10, k) + assert out.index.shape == (10, k) + assert out.index.min() >= 0 and out.index.max() < 10_000 + + +@withPackage('faiss') +@pytest.mark.parametrize('k', [50]) +def test_mips_exclude(device, k): + lhs = ops.randn(10, 16) + rhs = ops.randn(100, 16) + + exclude_lhs = ops.randint(0, 10, (500, )) + exclude_rhs = ops.randint(0, 100, (500, )) + exclude_links = ops.stack(([exclude_lhs, exclude_rhs]), axis=0) + exclude_links = exclude_links.unique(dim=1) + + index = MIPSKNNIndex(rhs) + + out = index.search(lhs, k, exclude_links) + assert out.score.device == device + assert out.index.device == device + assert out.score.shape == (10, k) + assert out.index.shape == (10, k) + + # Ensure that excluded links are not present in `out.index`: + batch = ops.arange(lhs.shape[0]).repeat_interleave(k) + knn_links = ops.stack(([batch, out.index.view(-1)]), axis=0) + knn_links = knn_links[:, knn_links[1] >= 0] + + unique_links = ops.cat(([knn_links, exclude_links]), axis=1).unique(dim=1) + assert unique_links.shape[1] == knn_links.shape[1] + exclude_links.shape[1] diff --git a/tests/graph/nn/pool/test_max_pool.py b/tests/graph/nn/pool/test_max_pool.py new file mode 100644 index 000000000..b600f8beb --- /dev/null +++ b/tests/graph/nn/pool/test_max_pool.py @@ -0,0 +1,105 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Batch +from mindscience.sharker.nn import max_pool, max_pool_neighbor_x, max_pool_x +from mindscience.sharker.testing import is_full_test + + +def test_max_pool_x(): + cluster = ms.Tensor([0, 1, 0, 1, 2, 2]) + x = ms.Tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1]) + + out = max_pool_x(cluster, x, batch) + assert out[0].tolist() == [[5, 6], [7, 8], [11, 12]] + assert out[1].tolist() == [0, 0, 1] + + if is_full_test(): + jit = ms.jit(max_pool_x) + out = jit(cluster, x, batch) + assert out[0].tolist() == [[5, 6], [7, 8], [11, 12]] + assert out[1].tolist() == [0, 0, 1] + + out, _ = max_pool_x(cluster, x, batch, size=2) + assert out.tolist() == [[5, 6], [7, 8], [11, 12], [0, 0]] + + batch_size = int(batch.max().item()) + 1 + out2, _ = max_pool_x(cluster, x, batch, batch_size=batch_size, size=2) + assert ops.equal(out, out2).all() + + if is_full_test(): + jit = ms.jit(max_pool_x) + out, _ = jit(cluster, x, batch, size=2) + assert out.tolist() == [[5, 6], [7, 8], [11, 12], [0, 0]] + + out2, _ = jit(cluster, x, batch, batch_size=batch_size, size=2) + assert ops.equal(out, out2).all() + + +def test_max_pool(): + cluster = ms.Tensor([0, 1, 0, 1, 2, 2]) + x = ms.Tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + crd = ms.Tensor([ + [0.0, 0.0], + [1.0, 1.0], + [2.0, 2.0], + [3.0, 3.0], + [4.0, 4.0], + [5.0, 5.0], + ]) + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + edge_attr = ops.ones(edge_index.shape[1]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1]) + + data = Batch(x=x, crd=crd, edge_index=edge_index, edge_attr=edge_attr, + batch=batch) + + data = max_pool(cluster, data, transform=lambda x: x) + + assert data.x.tolist() == [[5, 6], [7, 8], [11, 12]] + assert data.crd.tolist() == [[1, 1], [2, 2], [4.5, 4.5]] + assert data.edge_index.tolist() == [[0, 1], [1, 0]] + assert data.edge_attr.tolist() == [4, 4] + assert data.batch.tolist() == [0, 0, 1] + + +def test_max_pool_neighbor_x(): + x = ms.Tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1]) + + data = Batch(x=x, edge_index=edge_index, batch=batch) + data = max_pool_neighbor_x(data) + + assert data.x.tolist() == [ + [7, 8], + [7, 8], + [7, 8], + [7, 8], + [11, 12], + [11, 12], + ] + assert ops.equal(data.edge_index, edge_index).all() diff --git a/tests/graph/nn/pool/test_mem_pool.py b/tests/graph/nn/pool/test_mem_pool.py new file mode 100644 index 000000000..4ea68332c --- /dev/null +++ b/tests/graph/nn/pool/test_mem_pool.py @@ -0,0 +1,28 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn import MemPooling +from mindscience.sharker.utils import to_dense_batch + + +def test_mem_pool(): + mpool1 = MemPooling(4, 8, heads=3, num_clusters=2) + assert str(mpool1) == 'MemPooling(4, 8, heads=3, num_clusters=2)' + mpool2 = MemPooling(8, 4, heads=2, num_clusters=1) + + x = ops.randn(17, 4) + batch = ms.Tensor([0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 4, 4, 4]) + _, mask = to_dense_batch(x, batch) + + out1, S = mpool1(x, batch) + loss = MemPooling.kl_loss(S) + grad_fn = ops.value_and_grad(MemPooling.kl_loss, weights=mpool1.trainable_params()) + loss, grad = grad_fn(S) + + out2, _ = mpool2(out1) + + assert out1.shape == (5, 2, 8) + assert out2.shape == (5, 1, 4) + assert S[~mask].sum() == 0 + assert round(S[mask].sum().item()) == x.shape[0] + assert float(loss) > 0 + assert not grad[0].isnan().any() diff --git a/tests/graph/nn/pool/test_pan_pool.py b/tests/graph/nn/pool/test_pan_pool.py new file mode 100644 index 000000000..3443d9a17 --- /dev/null +++ b/tests/graph/nn/pool/test_pan_pool.py @@ -0,0 +1,35 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import PANConv, PANPooling +from mindscience.sharker.testing import is_full_test + + +def test_pan_pooling(): + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) + num_nodes = edge_index.max().item() + 1 + x = ops.randn((num_nodes, 16)) + + conv = PANConv(16, 32, filter_size=2) + pool = PANPooling(32, ratio=0.5) + assert str(pool) == 'PANPooling(32, ratio=0.5, multiplier=1.0)' + + x, M = conv(x, edge_index) + h, edge_index, edge_weight, batch, perm, score = pool(x, M) + + assert h.shape == (2, 32) + assert edge_index.shape == (2, 4) + assert edge_weight.shape == (4, ) + assert perm.shape == (2, ) + assert score.shape == (2, ) + + if is_full_test(): + jit = ms.jit(pool) + out = jit(x, M) + assert ops.isclose(h, out[0]).all() + assert ops.equal(edge_index, out[1]).all() + assert ops.isclose(edge_weight, out[2]).all() + assert ops.equal(batch, out[3]).all() + assert ops.equal(perm, out[4]).all() + assert ops.isclose(score, out[5]).all() diff --git a/tests/graph/nn/pool/test_pool.py b/tests/graph/nn/pool/test_pool.py new file mode 100644 index 000000000..c4768d155 --- /dev/null +++ b/tests/graph/nn/pool/test_pool.py @@ -0,0 +1,18 @@ +import mindspore as ms +from mindspore import Tensor, nn, ops + +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.sparse.cluster import radius_graph + + +def test_radius_graph(): + x = ms.Tensor([[-1, -1], [-1, 1], [1, -1], [1, 1]]).float() + batch = ms.Tensor([0, 0, 0, 0]) + + out = radius_graph(x, r=2.5, batch=batch, loop=False) + assert out.tolist() == [[1, 2, 0, 3, 0, 3, 1, 2], [0, 0, 1, 1, 2, 2, 3, 3]] + + if is_full_test(): + jit = ms.jit(radius_graph) + out1 = jit(x, r=2.5, batch=batch, loop=False) + assert ops.isclose(out, out1).all() diff --git a/tests/graph/nn/pool/test_sag_pool.py b/tests/graph/nn/pool/test_sag_pool.py new file mode 100644 index 000000000..56de4cc92 --- /dev/null +++ b/tests/graph/nn/pool/test_sag_pool.py @@ -0,0 +1,51 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn import ( + GATConv, + GCNConv, + GraphConv, + SAGEConv, + SAGPooling, +) +from mindscience.sharker.testing import is_full_test + + +def test_sag_pooling(): + in_channels = 16 + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) + num_nodes = edge_index.max().item() + 1 + x = ops.randn((num_nodes, in_channels)) + + for GNN in [GraphConv, GCNConv, GATConv, SAGEConv]: + pool1 = SAGPooling(in_channels, ratio=0.5, GNN=GNN) + assert str(pool1) == (f'SAGPooling({GNN.__name__}, 16, ' + f'ratio=0.5, multiplier=1.0)') + out1 = pool1(x, edge_index) + assert out1[0].shape == (num_nodes // 2, in_channels) + assert out1[1].shape == (2, 2) + + pool2 = SAGPooling(in_channels, ratio=None, GNN=GNN, min_score=0.1) + assert str(pool2) == (f'SAGPooling({GNN.__name__}, 16, ' + f'min_score=0.1, multiplier=1.0)') + out2 = pool2(x, edge_index) + assert out2[0].shape[0] <= x.shape[0] and out2[0].shape[1] == (16) + assert out2[1].shape[0] == 2 and out2[1].shape[1] <= edge_index.shape[1] + + pool3 = SAGPooling(in_channels, ratio=2, GNN=GNN) + assert str(pool3) == (f'SAGPooling({GNN.__name__}, 16, ' + f'ratio=2, multiplier=1.0)') + out3 = pool3(x, edge_index) + assert out3[0].shape == (2, in_channels) + assert out3[1].shape == (2, 2) + + if is_full_test(): + jit1 = ms.jit(pool1) + assert ops.isclose(jit1(x, edge_index)[0], out1[0]).all() + + jit2 = ms.jit(pool2) + assert ops.isclose(jit2(x, edge_index)[0], out2[0]).all() + + jit3 = ms.jit(pool3) + assert ops.isclose(jit3(x, edge_index)[0], out3[0]).all() diff --git a/tests/graph/nn/pool/test_topk_pool.py b/tests/graph/nn/pool/test_topk_pool.py new file mode 100644 index 000000000..6aa613725 --- /dev/null +++ b/tests/graph/nn/pool/test_topk_pool.py @@ -0,0 +1,60 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker import typing +from mindscience.sharker.nn.pool import TopKPooling +from mindscience.sharker.nn.pool.connect.filter_edges import filter_adj +from mindscience.sharker.testing import is_full_test + + +def test_filter_adj(): + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 3, 3], + [1, 3, 0, 2, 1, 3, 0, 2]]) + edge_attr = ms.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + perm = ms.Tensor([2, 3]) + + out = filter_adj(edge_index, edge_attr, perm) + assert out[0].tolist() == [[0, 1], [1, 0]] + assert out[1].tolist() == [6.0, 8.0] + + if is_full_test(): + jit = ms.jit(filter_adj) + + out = jit(edge_index, edge_attr, perm) + assert out[0].tolist() == [[0, 1], [1, 0]] + assert out[1].tolist() == [6.0, 8.0] + + +def test_topk_pooling(): + in_channels = 16 + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]]) + num_nodes = edge_index.max().item() + 1 + x = ops.randn((num_nodes, in_channels)) + + pool1 = TopKPooling(in_channels, ratio=0.5) + assert str(pool1) == 'TopKPooling(16, ratio=0.5, multiplier=1.0)' + out1 = pool1(x, edge_index) + assert out1[0].shape == (num_nodes // 2, in_channels) + assert out1[1].shape == (2, 2) + + pool2 = TopKPooling(in_channels, ratio=None, min_score=0.1) + assert str(pool2) == 'TopKPooling(16, min_score=0.1, multiplier=1.0)' + out2 = pool2(x, edge_index) + assert out2[0].shape[0] <= x.shape[0] and out2[0].shape[1] == (16) + assert out2[1].shape[0] == 2 and out2[1].shape[1] <= edge_index.shape[1] + + pool3 = TopKPooling(in_channels, ratio=2) + assert str(pool3) == 'TopKPooling(16, ratio=2, multiplier=1.0)' + out3 = pool3(x, edge_index) + assert out3[0].shape == (2, in_channels) + assert out3[1].shape == (2, 2) + + if is_full_test(): + jit1 = ms.jit(pool1) + assert ops.isclose(jit1(x, edge_index)[0], out1[0]).all() + + jit2 = ms.jit(pool2) + assert ops.isclose(jit2(x, edge_index)[0], out2[0]).all() + + jit3 = ms.jit(pool3) + assert ops.isclose(jit3(x, edge_index)[0], out3[0]).all() diff --git a/tests/graph/nn/pool/test_voxel_grid.py b/tests/graph/nn/pool/test_voxel_grid.py new file mode 100644 index 000000000..2ab6c8f5d --- /dev/null +++ b/tests/graph/nn/pool/test_voxel_grid.py @@ -0,0 +1,50 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.data import Batch +from mindscience.sharker.nn import avg_pool +from mindscience.sharker.sparse.cluster import voxel_grid + + +def test_voxel_grid(): + pos = ms.Tensor([ + [0.0, 0.0], + [11.0, 9.0], + [2.0, 8.0], + [2.0, 2.0], + [8.0, 3.0], + ]) + batch = ms.Tensor([0, 0, 0, 1, 1]) + + assert voxel_grid(pos, size=5, batch=batch).tolist() == [0, 5, 3, 6, 7] + assert voxel_grid(pos, size=5).tolist() == [0, 5, 3, 0, 1] + + cluster = voxel_grid(pos, size=5, batch=batch, start=-1, end=[18, 14]) + assert cluster.tolist() == [0, 10, 4, 16, 17] + + cluster_no_batch = voxel_grid(pos, size=5, start=-1, end=[18, 14]) + assert cluster_no_batch.tolist() == [0, 10, 4, 0, 1] + + +def test_single_voxel_grid(): + pos = ms.Tensor([ + [0.0, 0.0], + [1.0, 1.0], + [2.0, 2.0], + [3.0, 3.0], + [4.0, 4.0], + ]) + edge_index = ms.Tensor([[0, 0, 3], [1, 2, 4]]) + batch = ms.Tensor([0, 0, 0, 1, 1]) + x = ops.randn(5, 16) + + cluster = voxel_grid(pos, size=5, batch=batch) + assert cluster.tolist() == [0, 0, 0, 1, 1] + + data = Batch(x=x, edge_index=edge_index, pos=pos, batch=batch) + data = avg_pool(cluster, data) + + cluster_no_batch = voxel_grid(pos, size=5) + assert cluster_no_batch.tolist() == [0, 0, 0, 0, 0] + + data_no_batch = Batch(x=x, edge_index=edge_index, pos=pos) + data_no_batch = avg_pool(cluster_no_batch, data_no_batch) diff --git a/tests/graph/nn/test_encoding.py b/tests/graph/nn/test_encoding.py new file mode 100644 index 000000000..ff217e1b9 --- /dev/null +++ b/tests/graph/nn/test_encoding.py @@ -0,0 +1,18 @@ +import mindspore as ms +from mindscience.sharker.nn import PositionalEncoding, TemporalEncoding + + +def test_positional_encoding(): + encoder = PositionalEncoding(64) + assert str(encoder) == 'PositionalEncoding(64)' + + x = ms.Tensor([1.0, 2.0, 3.0]) + assert encoder(x).shape == (3, 64) + + +def test_temporal_encoding(): + encoder = TemporalEncoding(64) + assert str(encoder) == 'TemporalEncoding(64)' + + x = ms.Tensor([1.0, 2.0, 3.0]) + assert encoder(x).shape == (3, 64) diff --git a/tests/graph/nn/test_inits.py b/tests/graph/nn/test_inits.py new file mode 100644 index 000000000..7fcff7a80 --- /dev/null +++ b/tests/graph/nn/test_inits.py @@ -0,0 +1,62 @@ +import mindspore as ms +from mindspore import ops +from mindspore.nn import Dense as Lin +from mindspore.nn import ReLU +from mindspore.nn import SequentialCell as Seq + +from mindscience.sharker.nn.inits import ( + glorot, + glorot_orthogonal, + ones, + reset, + uniform, + zeros, +) + + +def test_inits(): + x = ms.numpy.empty([1, 4]) + + uniform(size=4, value=x) + assert x.min() >= -0.5 + assert x.max() <= 0.5 + + glorot(x) + assert x.min() >= -1.1 + assert x.max() <= 1.1 + + glorot_orthogonal(x, scale=1.0) + assert x.min() >= -2.5 + assert x.max() <= 2.5 + + zeros(x) + assert x.tolist() == [[0, 0, 0, 0]] + + ones(x) + assert x.tolist() == [[1, 1, 1, 1]] + + nn = Lin(16, 16) + uniform(size=4, value=nn.weight) + assert nn.weight[0].min() >= -0.5 + assert nn.weight[0].max() <= 0.5 + + glorot(nn.weight) + assert nn.weight[0].min() >= -0.45 + assert nn.weight[0].max() <= 0.45 + + glorot_orthogonal(nn.weight, scale=1.0) + assert nn.weight[0].min() >= -2.5 + assert nn.weight[0].max() <= 2.5 + + +def test_reset(): + nn = Lin(16, 16) + w = nn.weight.copy() + reset(nn) + assert not ops.isclose(nn.weight, w).all() + + nn = Seq(Lin(16, 16), ReLU(), Lin(16, 16)) + w_1, w_2 = nn[0].weight.copy(), nn[2].weight.copy() + reset(nn) + assert not ops.isclose(nn[0].weight, w_1).all() + assert not ops.isclose(nn[2].weight, w_2).all() diff --git a/tests/graph/nn/test_reshape.py b/tests/graph/nn/test_reshape.py new file mode 100644 index 000000000..10188022f --- /dev/null +++ b/tests/graph/nn/test_reshape.py @@ -0,0 +1,13 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.nn.reshape import Reshape + + +def test_reshape(): + x = ops.randn(10, 4) + + op = Reshape(5, 2, 4) + assert str(op) == 'Reshape(5, 2, 4)' + + assert op(x).shape == (5, 2, 4) + assert ops.equal(op(x).view(10, 4), x).all() diff --git a/tests/graph/nn/test_resolver.py b/tests/graph/nn/test_resolver.py new file mode 100644 index 000000000..45797890f --- /dev/null +++ b/tests/graph/nn/test_resolver.py @@ -0,0 +1,111 @@ +import pytest +import mindspore as ms +from mindspore import Parameter, ops, nn +from mindspore.experimental.optim import Adam +from mindspore.experimental.optim.lr_scheduler import ( + ConstantLR, LambdaLR, ReduceLROnPlateau +) + +from mindscience.sharker import nn as snn +from mindscience.sharker.nn.resolver import ( + activation_resolver, + aggregation_resolver, + lr_scheduler_resolver, + normalization_resolver, + optimizer_resolver, +) + + +def test_activation_resolver(): + assert isinstance(activation_resolver(nn.ELU()), nn.ELU) + assert isinstance(activation_resolver(nn.ReLU()), nn.ReLU) + assert isinstance(activation_resolver(nn.PReLU()), nn.PReLU) + + assert isinstance(activation_resolver('elu'), nn.ELU) + assert isinstance(activation_resolver('relu'), nn.ReLU) + assert isinstance(activation_resolver('prelu'), nn.PReLU) + + +@pytest.mark.parametrize('aggr_tuple', [ + (snn.MeanAggregation, 'mean'), + (snn.SumAggregation, 'sum'), + (snn.SumAggregation, 'add'), + (snn.MaxAggregation, 'max'), + (snn.MinAggregation, 'min'), + (snn.MulAggregation, 'mul'), + (snn.VarAggregation, 'var'), + (snn.StdAggregation, 'std'), + (snn.SoftmaxAggregation, 'softmax'), + (snn.PowerMeanAggregation, 'powermean'), +]) +def test_aggregation_resolver(aggr_tuple): + aggr_module, aggr_repr = aggr_tuple + assert isinstance(aggregation_resolver(aggr_module()), aggr_module) + assert isinstance(aggregation_resolver(aggr_repr), aggr_module) + + +def test_multi_aggregation_resolver(): + aggr = aggregation_resolver(None) + assert aggr is None + + # aggr = aggregation_resolver(['sum', 'mean', None]) + # assert len(aggr.aggrs) == 3 + # assert aggr.aggrs[2] is None + + aggr = aggregation_resolver(['sum', 'mean']) + assert len(aggr.aggrs) == 2 + assert isinstance(aggr, snn.MultiAggregation) + assert isinstance(aggr.aggrs[0], snn.SumAggregation) + assert isinstance(aggr.aggrs[1], snn.MeanAggregation) + + +@pytest.mark.parametrize('norm_tuple', [ + (snn.BatchNorm, 'batch', (16, )), + (snn.BatchNorm, 'batch_norm', (16, )), + # (snn.InstanceNorm, 'instance_norm', (16, )), + (snn.LayerNorm, 'layer_norm', (16, )), + (snn.GraphNorm, 'graph_norm', (16, )), + (snn.GraphSizeNorm, 'graphsize_norm', ()), + (snn.PairNorm, 'pair_norm', ()), + (snn.MessageNorm, 'message_norm', ()), + (snn.DiffGroupNorm, 'diffgroup_norm', (16, 4)), +]) +def test_normalization_resolver(norm_tuple): + norm_module, norm_repr, norm_args = norm_tuple + assert isinstance(normalization_resolver(norm_module(*norm_args)), norm_module) + assert isinstance(normalization_resolver(norm_repr, *norm_args), norm_module) + + +def test_optimizer_resolver(): + params = [Parameter(ops.randn(1))] + + assert isinstance(optimizer_resolver(nn.SGD(params, learning_rate=0.01)), nn.SGD) + assert isinstance(optimizer_resolver(nn.Adam(params)), nn.Adam) + assert isinstance(optimizer_resolver(nn.Rprop(params)), nn.Rprop) + + assert isinstance(optimizer_resolver('sgd', params, learning_rate=0.01), nn.SGD) + assert isinstance(optimizer_resolver('adam', params), nn.Adam) + assert isinstance(optimizer_resolver('rprop', params), nn.Rprop) + + +@pytest.mark.parametrize('scheduler_args', [ + ('constant_with_warmup', LambdaLR), + ('linear_with_warmup', LambdaLR), + ('cosine_with_warmup', LambdaLR), + ('cosine_with_warmup_restarts', LambdaLR), + ('polynomial_with_warmup', LambdaLR), + ('constant', ConstantLR), + ('ReduceLROnPlateau', ReduceLROnPlateau), +]) +def test_lr_scheduler_resolver(scheduler_args): + scheduler_name, scheduler_cls = scheduler_args + + model = nn.Dense(10, 5) + optimizer = Adam(model.trainable_params(), lr=0.01) + + lr_scheduler = lr_scheduler_resolver( + scheduler_name, + optimizer, + num_training_steps=100, + ) + assert isinstance(lr_scheduler, scheduler_cls) diff --git a/tests/graph/nn/unpool/test_knn_interpolate.py b/tests/graph/nn/unpool/test_knn_interpolate.py new file mode 100644 index 000000000..4851bed76 --- /dev/null +++ b/tests/graph/nn/unpool/test_knn_interpolate.py @@ -0,0 +1,25 @@ +import mindspore as ms +from mindscience.sharker.nn import knn_interpolate + + +def test_knn_interpolate(): + x = ms.Tensor([[1.0], [10.0], [100.0], [-1.0], [-10.0], [-100.0]]) + pos_x = ms.Tensor([ + [-1.0, 0.0], + [0.0, 0.0], + [1.0, 0.0], + [-2.0, 0.0], + [0.0, 0.0], + [2.0, 0.0], + ]) + pos_y = ms.Tensor([ + [-1.0, -1.0], + [1.0, 1.0], + [-2.0, -2.0], + [2.0, 2.0], + ]) + batch_x = ms.Tensor([0, 0, 0, 1, 1, 1]) + batch_y = ms.Tensor([0, 0, 1, 1]) + + y = knn_interpolate(x, pos_x, pos_y, batch_x, batch_y, k=2) + assert y.tolist() == [[4.0], [70.0], [-4.0], [-70.0]] diff --git a/tests/graph/profile/test_benchmark.py b/tests/graph/profile/test_benchmark.py new file mode 100644 index 000000000..b6a396827 --- /dev/null +++ b/tests/graph/profile/test_benchmark.py @@ -0,0 +1,21 @@ +import mindspore as ms +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import withPackage + + +@withPackage('tabulate') +def test_benchmark(capfd): + def add(x, y): + return x + y + + benchmark( + funcs=[add], + args=(ops.randn(10), ops.randn(10)), + num_steps=1, + num_warmups=1, + backward=True, + ) + + out, _ = capfd.readouterr() + assert '| Name | Forward | Backward | Total |' in out + assert '| add |' in out diff --git a/tests/graph/profile/test_profile.py b/tests/graph/profile/test_profile.py new file mode 100644 index 000000000..50f6311e4 --- /dev/null +++ b/tests/graph/profile/test_profile.py @@ -0,0 +1,174 @@ +import os +import os.path as osp +import warnings + +import pytest +import mindspore as msimport torch.nn.functional as F + +from mindscience.sharker.nn import GraphSAGE +from mindscience.sharker.profile import ( + get_stats_summary, + profileit, + rename_profile_file, + timeit, +) +from mindscience.sharker.profile.profile import torch_profile, xpu_profile +from mindscience.sharker.testing import ( + onlyCUDA, + onlyLinux, + onlyOnline, + onlyXPU, + withDevice, + withPackage, +) + + +@onlyLinux +def test_timeit(device): + x = ops.randn(100, 16) + lin = torch.nn.Linear(16, 32) + + with timeit(log=False) as t: + assert not hasattr(t, 'duration') + + with torch.no_grad(): + lin(x) + t.reset() + assert t.duration > 0 + + del t.duration + assert not hasattr(t, 'duration') + assert t.duration > 0 + + +@onlyCUDA +@onlyOnline +@withPackage('pytorch_memlab') +def test_profileit_cuda(get_dataset): + warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*') + + dataset = get_dataset(name='Cora') + data = dataset[0] + model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3, + out_channels=dataset.num_classes) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + @profileit('cuda') + def train(model, x, edge_index, y): + model.train() + optimizer.zero_grad() + out = model(x, edge_index) + loss = F.cross_entropy(out, y) + loss.backward() + return float(loss) + + stats_list = [] + for epoch in range(5): + _, stats = train(model, data.x, data.edge_index, data.y) + assert stats.time > 0 + assert stats.max_allocated_gpu > 0 + assert stats.max_reserved_gpu > 0 + assert stats.max_active_gpu > 0 + assert stats.nvidia_smi_free_cuda > 0 + assert stats.nvidia_smi_used_cuda > 0 + + if epoch >= 2: # Warm-up + stats_list.append(stats) + + stats_summary = get_stats_summary(stats_list) + assert stats_summary.time_mean > 0 + assert stats_summary.time_std > 0 + assert stats_summary.max_allocated_gpu > 0 + assert stats_summary.max_reserved_gpu > 0 + assert stats_summary.max_active_gpu > 0 + assert stats_summary.min_nvidia_smi_free_cuda > 0 + assert stats_summary.max_nvidia_smi_used_cuda > 0 + + +@onlyXPU +def test_profileit_xpu(get_dataset): + warnings.filterwarnings('ignore', '.*arguments of DataFrame.drop.*') + + dataset = get_dataset(name='Cora') + data = dataset[0] + model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3, + out_channels=dataset.num_classes) + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + + @profileit('xpu') + def train(model, x, edge_index, y): + model.train() + optimizer.zero_grad() + out = model(x, edge_index) + loss = F.cross_entropy(out, y) + loss.backward() + return float(loss) + + stats_list = [] + for epoch in range(5): + _, stats = train(model, data.x, data.edge_index, data.y) + assert stats.time > 0 + assert stats.max_allocated_gpu > 0 + assert stats.max_reserved_gpu > 0 + assert stats.max_active_gpu > 0 + assert not hasattr(stats, 'nvidia_smi_free_cuda') + assert not hasattr(stats, 'nvidia_smi_used_cuda') + + if epoch >= 2: # Warm-up + stats_list.append(stats) + + stats_summary = get_stats_summary(stats_list) + assert stats_summary.time_mean > 0 + assert stats_summary.time_std > 0 + assert stats_summary.max_allocated_gpu > 0 + assert stats_summary.max_reserved_gpu > 0 + assert stats_summary.max_active_gpu > 0 + assert not hasattr(stats_summary, 'min_nvidia_smi_free_cuda') + assert not hasattr(stats_summary, 'max_nvidia_smi_used_cuda') + + +@onlyOnline +def test_torch_profile(capfd, get_dataset, device): + dataset = get_dataset(name='Cora') + data = dataset[0] + model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3, + out_channels=dataset.num_classes) + + with torch_profile(): + model(data.x, data.edge_index) + + out, _ = capfd.readouterr() + assert 'Self CPU time total' in out + if data.x.is_cuda: + assert 'Self CUDA time total' in out + + rename_profile_file('test_profile') + assert osp.exists('profile-test_profile.json') + os.remove('profile-test_profile.json') + + +@onlyXPU +@onlyOnline +@pytest.mark.parametrize('export_chrome_trace', [False, True]) +def test_xpu_profile(capfd, get_dataset, export_chrome_trace): + dataset = get_dataset(name='Cora') + device = torch.device('xpu') + data = dataset[0] + model = GraphSAGE(dataset.num_features, hidden_channels=64, num_layers=3, + out_channels=dataset.num_classes) + + with xpu_profile(export_chrome_trace): + model(data.x, data.edge_index) + + out, _ = capfd.readouterr() + assert 'Self CPU' in out + if data.x.is_xpu: + assert 'Self XPU' in out + + f_name = 'timeline.json' + f_exists = osp.exists(f_name) + if not export_chrome_trace: + assert not f_exists + else: + assert f_exists + os.remove(f_name) diff --git a/tests/graph/profile/test_profile_utils.py b/tests/graph/profile/test_profile_utils.py new file mode 100644 index 000000000..51384d46c --- /dev/null +++ b/tests/graph/profile/test_profile_utils.py @@ -0,0 +1,81 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindspore.nn import Dense as Linear + +from mindscience.sharker.data import Graph +from mindscience.sharker.profile import ( + count_parameters, + get_cpu_memory_from_gc, + get_data_size, + get_gpu_memory_from_gc, + # get_gpu_memory_from_ipex, + get_gpu_memory_from_nvidia_smi, + get_model_size, +) +from mindscience.sharker.profile.utils import ( + byte_to_megabyte, + medibyte_to_megabyte, +) +from mindscience.sharker.testing import withPackage +from mindscience.sharker.typing import SparseTensor + + +def test_count_parameters(): + assert count_parameters(Linear(32, 128)) == 32 * 128 + 128 + + +def test_get_model_size(): + model_size = get_model_size(Linear(32, 128, bias=False)) + assert model_size >= 32 * 128 * 4 and model_size < 32 * 128 * 4 + 2000 + + +def test_get_data_size(): + x = ops.randn(10, 128) + data = Graph(x=x, y=x) + + data_size = get_data_size(data) + assert data_size == 10 * 128 * 4 + + +@withPackage('torch_sparse') +def test_get_data_size_with_sparse_tensor(): + x = ops.randn(10, 128) + row, col = ops.randint(0, 10, (2, 100), dtype=ms.int64) + adj_t = SparseTensor(row=row, col=col, value=None, sparse_shape=(10, 10)) + data = Graph(x=x, y=x, adj_t=adj_t) + + data_size = get_data_size(data) + assert data_size == 10 * 128 * 4 + 11 * 8 + 100 * 8 + + +def test_get_cpu_memory_from_gc(): + old_mem = get_cpu_memory_from_gc() + _ = ops.randn(10, 128) + new_mem = get_cpu_memory_from_gc() + assert new_mem - old_mem == 10 * 128 * 4 + + +def test_get_gpu_memory_from_gc(): + old_mem = get_gpu_memory_from_gc() + _ = ops.randn(10, 128) + new_mem = get_gpu_memory_from_gc() + assert new_mem - old_mem == 10 * 128 * 4 + + +def test_get_gpu_memory_from_nvidia_smi(): + free_mem, used_mem = get_gpu_memory_from_nvidia_smi(device=0, digits=2) + assert free_mem >= 0 + assert used_mem >= 0 + + +# @onlyXPU +# def test_get_gpu_memory_from_ipex(): +# max_allocated, max_reserved, max_active = get_gpu_memory_from_ipex() +# assert max_allocated >= 0 +# assert max_reserved >= 0 +# assert max_active >= 0 + + +def test_bytes_function(): + assert byte_to_megabyte((1024 * 1024)) == 1.00 + assert medibyte_to_megabyte(1 / 1.0485) == 1.00 diff --git a/tests/graph/profile/test_profiler.py b/tests/graph/profile/test_profiler.py new file mode 100644 index 000000000..980967318 --- /dev/null +++ b/tests/graph/profile/test_profiler.py @@ -0,0 +1,26 @@ +import mindspore as ms +from mindscience.sharker.nn import GraphSAGE +from mindscience.sharker.profile.profiler import Profiler +from mindscience.sharker.testing import withDevice, withPackage + + +@withPackage('torch>=1.13.0') # TODO Investigate test errors +def test_profiler(capfd, get_dataset, device): + x = ops.randn(10, 16) + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5, 7, 6, 8, 7, 9, 8], + ]) + + model = GraphSAGE(16, hidden_channels=32, num_layers=2) + + with Profiler(model, profile_memory=True, use_cuda=x.is_cuda) as prof: + model(x, edge_index) + + _, err = capfd.readouterr() + assert 'Completed Stage' in err + + _, heading_list, raw_results, layer_names, layer_stats = prof.get_trace() + assert 'Self CPU total' in heading_list + assert 'aten::relu' in raw_results + assert '-act--aten::relu' in layer_names diff --git a/tests/graph/sampler/test_sampler_base.py b/tests/graph/sampler/test_sampler_base.py new file mode 100644 index 000000000..51f5e92bd --- /dev/null +++ b/tests/graph/sampler/test_sampler_base.py @@ -0,0 +1,114 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.sampler import HeteroSamplerOutput, NumNeighbors, SamplerOutput +from mindscience.sharker.testing import get_random_edge_index +from mindscience.sharker.utils import is_undirected + + +def test_homogeneous_num_neighbors(): + with pytest.raises(ValueError, match="'default' must be set to 'None'"): + num_neighbors = NumNeighbors([25, 10], default=[-1, -1]) + + num_neighbors = NumNeighbors([25, 10]) + assert str(num_neighbors) == 'NumNeighbors(values=[25, 10], default=None)' + + assert num_neighbors.get_values() == [25, 10] + assert num_neighbors.__dict__['_values'] == [25, 10] + assert num_neighbors.get_values() == [25, 10] # Test caching. + + assert num_neighbors.get_mapped_values() == [25, 10] + assert num_neighbors.__dict__['_mapped_values'] == [25, 10] + assert num_neighbors.get_mapped_values() == [25, 10] # Test caching. + + assert num_neighbors.num_hops == 2 + assert num_neighbors.__dict__['_num_hops'] == 2 + assert num_neighbors.num_hops == 2 # Test caching. + + +def test_heterogeneous_num_neighbors_list(): + num_neighbors = NumNeighbors([25, 10]) + + values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) + assert values == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]} + + values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')]) + assert values == {'A__to__B': [25, 10], 'B__to__A': [25, 10]} + + assert num_neighbors.num_hops == 2 + + +def test_heterogeneous_num_neighbors_dict_and_default(): + num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1]) + with pytest.raises(ValueError, match="hops must be the same across all"): + values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) + + num_neighbors = NumNeighbors({('A', 'B'): [25, 10]}, default=[-1, -1]) + + values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) + assert values == {('A', 'B'): [25, 10], ('B', 'A'): [-1, -1]} + + values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')]) + assert values == {'A__to__B': [25, 10], 'B__to__A': [-1, -1]} + + assert num_neighbors.num_hops == 2 + + +def test_heterogeneous_num_neighbors_empty_dict(): + num_neighbors = NumNeighbors({}, default=[25, 10]) + + values = num_neighbors.get_values([('A', 'B'), ('B', 'A')]) + assert values == {('A', 'B'): [25, 10], ('B', 'A'): [25, 10]} + + values = num_neighbors.get_mapped_values([('A', 'B'), ('B', 'A')]) + assert values == {'A__to__B': [25, 10], 'B__to__A': [25, 10]} + + assert num_neighbors.num_hops == 2 + + +def test_homogeneous_to_bidirectional(): + edge_index = get_random_edge_index(10, 10, num_edges=20) + + obj = SamplerOutput( + node=ops.arange(10), + row=edge_index[0], + col=edge_index[0], + edge=ops.arange(edge_index.shape[1]), + ).to_bidirectional() + + assert is_undirected(ops.stack(([obj.row, obj.col]), axis=0)) + + +def test_heterogeneous_to_bidirectional(): + edge_index1 = get_random_edge_index(10, 5, num_edges=20) + edge_index2 = get_random_edge_index(5, 10, num_edges=20) + edge_index3 = get_random_edge_index(10, 10, num_edges=20) + + obj = HeteroSamplerOutput( + node={ + 'v1': ops.arange(10), + 'v2': ops.arange(5) + }, + row={ + ('v1', 'to', 'v2'): edge_index1[0], + ('v2', 'rev_to', 'v1'): edge_index2[0], + ('v1', 'to', 'v1'): edge_index3[0], + }, + col={ + ('v1', 'to', 'v2'): edge_index1[1], + ('v2', 'rev_to', 'v1'): edge_index2[1], + ('v1', 'to', 'v1'): edge_index3[1], + }, + edge={}, + ).to_bidirectional() + + assert ops.equal( + obj.row['v1', 'to', 'v2'].sort()[0], + obj.col['v2', 'rev_to', 'v1'].sort()[0], + ).all() + assert ops.equal( + obj.col['v1', 'to', 'v2'].sort()[0], + obj.row['v2', 'rev_to', 'v1'].sort()[0], + ).all() + assert is_undirected( + ops.stack([obj.row['v1', 'to', 'v1'], obj.col['v1', 'to', 'v1']], 0)) diff --git a/tests/graph/sparse/test_add.py b/tests/graph/sparse/test_add.py new file mode 100644 index 000000000..5181bbe10 --- /dev/null +++ b/tests/graph/sparse/test_add.py @@ -0,0 +1,33 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse import SparseTensor, add +from mindscience.sharker.sparse.testing import dtypes, tensor +from mindscience.sharker.testing import is_full_test + + +@pytest.mark.parametrize('dtype', dtypes) +def test_add(dtype): + rowA = Tensor([0, 0, 1, 2, 2]) + colA = Tensor([0, 2, 1, 0, 1]) + valueA = tensor([1, 2, 4, 1, 3], dtype) + A = SparseTensor(row=rowA, col=colA, value=valueA) + + rowB = Tensor([0, 0, 1, 2, 2]) + colB = Tensor([1, 2, 2, 1, 2]) + valueB = tensor([2, 3, 1, 2, 4], dtype) + B = SparseTensor(row=rowB, col=colB, value=valueB) + + C = A + B + rowC, colC, valueC = C.coo() + + assert rowC.tolist() == [0, 0, 0, 1, 1, 2, 2, 2] + assert colC.tolist() == [0, 1, 2, 1, 2, 0, 1, 2] + assert valueC.tolist() == [1, 2, 5, 4, 1, 1, 5, 4] + + if is_full_test(): + @ms.jit + def jit_add(A: SparseTensor, B: SparseTensor) -> SparseTensor: + return add(A, B) + jit_add(A, B) diff --git a/tests/graph/sparse/test_cat.py b/tests/graph/sparse/test_cat.py new file mode 100644 index 000000000..0a188f706 --- /dev/null +++ b/tests/graph/sparse/test_cat.py @@ -0,0 +1,48 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse.cat import cat +from mindscience.sharker.sparse import SparseTensor, Layout +from mindscience.sharker.sparse.testing import tensor + + +def test_cat(): + row, col = tensor([[0, 0, 1], [0, 1, 2]], ms.int64) + mat1 = SparseTensor(row=row, col=col) + mat1.fill_cache_() + + row, col = tensor([[0, 0, 1, 2], [0, 1, 1, 0]], ms.int64) + mat2 = SparseTensor(row=row, col=col) + mat2.fill_cache_() + + out = cat([mat1, mat2], axis=0) + assert out.to_dense().tolist() == [[1, 1, 0], [0, 0, 1], [1, 1, 0], + [0, 1, 0], [1, 0, 0]] + assert out.storage.has_row() + assert out.storage.has_rowptr() + assert out.storage.has_rowcount() + assert out.storage.num_cached_keys() == 1 + + out = cat([mat1, mat2], axis=1) + assert out.to_dense().tolist() == [[1, 1, 0, 1, 1], [0, 0, 1, 0, 1], + [0, 0, 0, 1, 0]] + assert out.storage.has_row() + assert not out.storage.has_rowptr() + assert out.storage.num_cached_keys() == 2 + + out = cat([mat1, mat2], axis=(0, 1)) + assert out.to_dense().tolist() == [[1, 1, 0, 0, 0], [0, 0, 1, 0, 0], + [0, 0, 0, 1, 1], [0, 0, 0, 0, 1], + [0, 0, 0, 1, 0]] + assert out.storage.has_row() + assert out.storage.has_rowptr() + assert out.storage.num_cached_keys() == 5 + + value = ops.randn((mat1.nnz(), 4)) + mat1 = mat1.set_value_(value, layout=Layout.COO) + out = cat([mat1, mat1], axis=-1) + assert out.storage.value().shape == (mat1.nnz(), 8) + assert out.storage.has_row() + assert out.storage.has_rowptr() + assert out.storage.num_cached_keys() == 5 diff --git a/tests/graph/sparse/test_coalesce.py b/tests/graph/sparse/test_coalesce.py new file mode 100644 index 000000000..d218ab6f8 --- /dev/null +++ b/tests/graph/sparse/test_coalesce.py @@ -0,0 +1,34 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse import coalesce + + +def test_coalesce(): + row = Tensor([1, 0, 1, 0, 2, 1]) + col = Tensor([0, 1, 1, 1, 0, 0]) + index = ops.stack([row, col], axis=0) + + index, _ = coalesce(index, None, m=3, n=2) + assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]] + + +def test_coalesce_add(): + row = Tensor([1, 0, 1, 0, 2, 1]) + col = Tensor([0, 1, 1, 1, 0, 0]) + index = ops.stack([row, col], axis=0) + value = Tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]]) + + index, value = coalesce(index, value, m=3, n=2) + assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]] + assert value.tolist() == [[6, 8], [7, 9], [3, 4], [5, 6]] + + +def test_coalesce_max(): + row = Tensor([1, 0, 1, 0, 2, 1]) + col = Tensor([0, 1, 1, 1, 0, 0]) + index = ops.stack([row, col], axis=0) + value = Tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]]) + + index, value = coalesce(index, value, m=3, n=2, op='max') + assert index.tolist() == [[0, 1, 1, 2], [1, 0, 1, 0]] + assert value.tolist() == [[4, 5], [6, 7], [3, 4], [5, 6]] diff --git a/tests/graph/sparse/test_convert.py b/tests/graph/sparse/test_convert.py new file mode 100644 index 000000000..016b4bd47 --- /dev/null +++ b/tests/graph/sparse/test_convert.py @@ -0,0 +1,24 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse import to_coo_tensor, from_coo_tensor +from mindscience.sharker.sparse import to_scipy, from_scipy + + +def test_convert_scipy(): + index = Tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]]) + value = Tensor([1, 2, 4, 1, 3]) + N = 3 + + out = from_scipy(to_scipy(index, value, N, N)) + assert out[0].tolist() == index.tolist() + assert out[1].tolist() == value.tolist() + + +def test_convert_torch_sparse(): + index = Tensor([[0, 0, 1, 2, 2], [0, 2, 1, 0, 1]]) + value = Tensor([1, 2, 4, 1, 3]) + N = 3 + + out = from_coo_tensor(to_coo_tensor(index, value.float(), N, N).coalesce()) + assert out[0].tolist() == index.tolist() + assert out[1].tolist() == value.tolist() diff --git a/tests/graph/sparse/test_diag.py b/tests/graph/sparse/test_diag.py new file mode 100644 index 000000000..f8e9908a4 --- /dev/null +++ b/tests/graph/sparse/test_diag.py @@ -0,0 +1,65 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse.tensor import SparseTensor +from mindscience.sharker.sparse.testing import dtypes, tensor + + +@pytest.mark.parametrize('dtype', dtypes) +def test_remove_diag(dtype): + row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], ms.int64) + value = tensor([1, 2, 3, 4], dtype) + mat = SparseTensor(row=row, col=col, value=value) + mat.fill_cache_() + + mat = mat.remove_diag() + assert mat.storage.row().tolist() == [0, 1] + assert mat.storage.col().tolist() == [1, 2] + assert mat.storage.value().tolist() == [2, 3] + assert mat.storage.num_cached_keys() == 2 + assert mat.storage.rowcount().tolist() == [1, 1, 0] + assert mat.storage.colcount().tolist() == [0, 1, 1] + + mat = SparseTensor(row=row, col=col, value=value) + mat.fill_cache_() + + mat = mat.remove_diag(k=1) + assert mat.storage.row().tolist() == [0, 2] + assert mat.storage.col().tolist() == [0, 2] + assert mat.storage.value().tolist() == [1, 4] + assert mat.storage.num_cached_keys() == 2 + assert mat.storage.rowcount().tolist() == [1, 0, 1] + assert mat.storage.colcount().tolist() == [1, 0, 1] + + +@pytest.mark.parametrize('dtype', dtypes) +def test_set_diag(dtype): + row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], ms.int64) + value = tensor([1, 2, 3, 4], dtype) + mat = SparseTensor(row=row, col=col, value=value) + + mat = mat.set_diag(tensor([-8, -8], dtype), k=-1) + mat = mat.set_diag(tensor([-8], dtype), k=1) + + +@pytest.mark.parametrize('dtype', dtypes) +def test_fill_diag(dtype): + row, col = tensor([[0, 0, 9, 9], [0, 1, 0, 1]], ms.int64) + value = tensor([1, 2, 3, 4], dtype) + mat = SparseTensor(row=row, col=col, value=value) + + mat = mat.fill_diag(-8, k=-1) + mat = mat.fill_diag(-8, k=1) + + +@pytest.mark.parametrize('dtype', dtypes) +def test_get_diag(dtype): + row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], ms.int64) + value = tensor([[1, 1], [2, 2], [3, 3], [4, 4]], dtype) + mat = SparseTensor(row=row, col=col, value=value) + assert mat.get_diag().tolist() == [[1, 1], [0, 0], [4, 4]] + + row, col = tensor([[0, 0, 1, 2], [0, 1, 2, 2]], ms.int64) + mat = SparseTensor(row=row, col=col) + assert mat.get_diag().tolist() == [1, 0, 1] diff --git a/tests/graph/sparse/test_ego_sample.py b/tests/graph/sparse/test_ego_sample.py new file mode 100644 index 000000000..5a101e7ef --- /dev/null +++ b/tests/graph/sparse/test_ego_sample.py @@ -0,0 +1,23 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse import SparseTensor +from mindscience.sharker.sparse.sampling import ego_k_hop_sample_adj + + +def test_ego_k_hop_sample_adj(): + rowptr = Tensor([0, 3, 5, 9, 10, 12, 14]) + row = Tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5]) + col = Tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]) + _ = SparseTensor(row=row, col=col, sparse_shape=(6, 6)) + + nid = Tensor([0, 1]) + out = ego_k_hop_sample_adj(rowptr, col, nid, 1, 3, False) + rowptr, col, nid, eid, ptr, root_n_id = out + + assert nid.tolist() == [0, 1, 2, 3, 0, 1, 2] + assert rowptr.tolist() == [0, 3, 5, 7, 8, 10, 12, 14] + # row [0, 0, 0, 1, 1, 2, 2, 3, 4, 4, 5, 5, 6, 6] + assert col.tolist() == [1, 2, 3, 0, 2, 0, 1, 0, 5, 6, 4, 6, 4, 5] + assert eid.tolist() == [0, 1, 2, 3, 4, 5, 6, 9, 0, 1, 3, 4, 5, 6] + assert ptr.tolist() == [0, 4, 7] + assert root_n_id.tolist() == [0, 5] diff --git a/tests/graph/sparse/test_eye.py b/tests/graph/sparse/test_eye.py new file mode 100644 index 000000000..7640aa676 --- /dev/null +++ b/tests/graph/sparse/test_eye.py @@ -0,0 +1,48 @@ +import pytest + +from mindscience.sharker.sparse.tensor import SparseTensor +from mindscience.sharker.sparse.testing import dtypes + + +@pytest.mark.parametrize('dtype', dtypes) +def test_eye(dtype): + mat = SparseTensor.eye(3, dtype=dtype) + assert mat.storage.sparse_shape == (3, 3) + assert mat.storage.row().tolist() == [0, 1, 2] + assert mat.storage.rowptr().tolist() == [0, 1, 2, 3] + assert mat.storage.col().tolist() == [0, 1, 2] + assert mat.storage.value().tolist() == [1, 1, 1] + assert mat.storage.value().dtype == dtype + assert mat.storage.num_cached_keys() == 0 + + mat = SparseTensor.eye(3, has_value=False) + assert mat.storage.sparse_shape == (3, 3) + assert mat.storage.row().tolist() == [0, 1, 2] + assert mat.storage.rowptr().tolist() == [0, 1, 2, 3] + assert mat.storage.col().tolist() == [0, 1, 2] + assert mat.storage.value() is None + assert mat.storage.num_cached_keys() == 0 + + mat = SparseTensor.eye(3, 4, fill_cache=True) + assert mat.storage.sparse_shape == (3, 4) + assert mat.storage.row().tolist() == [0, 1, 2] + assert mat.storage.rowptr().tolist() == [0, 1, 2, 3] + assert mat.storage.col().tolist() == [0, 1, 2] + assert mat.storage.num_cached_keys() == 5 + assert mat.storage.rowcount().tolist() == [1, 1, 1] + assert mat.storage.colptr().tolist() == [0, 1, 2, 3, 3] + assert mat.storage.colcount().tolist() == [1, 1, 1, 0] + assert mat.storage.csr2csc().tolist() == [0, 1, 2] + assert mat.storage.csc2csr().tolist() == [0, 1, 2] + + mat = SparseTensor.eye(4, 3, fill_cache=True) + assert mat.storage.sparse_shape == (4, 3) + assert mat.storage.row().tolist() == [0, 1, 2] + assert mat.storage.rowptr().tolist() == [0, 1, 2, 3, 3] + assert mat.storage.col().tolist() == [0, 1, 2] + assert mat.storage.num_cached_keys() == 5 + assert mat.storage.rowcount().tolist() == [1, 1, 1, 0] + assert mat.storage.colptr().tolist() == [0, 1, 2, 3] + assert mat.storage.colcount().tolist() == [1, 1, 1] + assert mat.storage.csr2csc().tolist() == [0, 1, 2] + assert mat.storage.csc2csr().tolist() == [0, 1, 2] diff --git a/tests/graph/sparse/test_matmul.py b/tests/graph/sparse/test_matmul.py new file mode 100644 index 000000000..dbde6aebf --- /dev/null +++ b/tests/graph/sparse/test_matmul.py @@ -0,0 +1,83 @@ +from itertools import product + +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse import scatter +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.sparse.matmul import matmul, spspmm +from mindscience.sharker.sparse.tensor import SparseTensor +from mindscience.sharker.sparse.testing import grad_dtypes, reductions + + +@pytest.mark.parametrize('dtype, reduce', + product(grad_dtypes, reductions)) +def test_spmm(dtype, reduce): + src = ops.randn((10, 8), dtype=dtype) + src[2:4, :] = 0 # Remove multiple rows. + src[:, 2:4] = 0 # Remove multiple columns. + src = SparseTensor.from_dense(src) + row, col, val = src.coo() + + other = ops.randn((2, 8, 2), dtype=dtype) + + def spmm(row, col, value, other): + src_col = other.index_select(-2, col) * value.unsqueeze(-1) + expected = scatter(src_col, row, dim=-2, reduce=reduce) + # Todo: Check this segment of code: it is exceptional in gradient computation. + # if reduce == 'min': + # expected[expected > 1000] = 0 + # if reduce == 'max': + # expected[expected < -1000] = 0 + return expected + grad_spmm = ops.value_and_grad(spmm, grad_position=(2, 3)) + expected, (_, expected_grad_other) = grad_spmm(row, col, val, other) + + grad_matmul = ops.value_and_grad(matmul, grad_position=(0, 1)) + out, other_grad = grad_matmul(src, other, reduce) + + atol = 1e-1 if dtype == ms.half else 1e-3 + + assert ops.isclose(expected, out, atol=atol).all() + # assert ops.isclose(expected_grad_value, val_grad, atol=atol).all() + # assert ops.isclose(expected_grad_other, other_grad, atol=atol).all() + + +@pytest.mark.parametrize('dtype', grad_dtypes) +def test_spspmm(dtype): + src = Tensor([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=dtype) + + src = SparseTensor.from_dense(src) + out = matmul(src, src) + assert out.shape == (3, 3) + assert out.has_value() + rowptr, col, value = out.csr() + assert rowptr.tolist() == [0, 1, 2, 3] + assert col.tolist() == [0, 1, 2] + assert value.tolist() == [1, 1, 1] + + src.set_value_(None) + out = matmul(src, src) + assert out.shape == (3, 3) + assert not out.has_value() + rowptr, col, value = out.csr() + assert rowptr.tolist() == [0, 1, 2, 3] + assert col.tolist() == [0, 1, 2] + + src = ops.randn((10, 8), dtype=dtype) + src[2:4, :] = 0 # Remove multiple rows. + src[:, 2:4] = 0 # Remove multiple columns. + src = SparseTensor.from_dense(src) + + trg = ops.randn((8, 5), dtype=dtype) + trg[1:3, :] = 0 # Remove multiple rows. + trg[:, 1:3] = 0 # Remove multiple columns. + trg = SparseTensor.from_dense(trg) + + out1 = matmul(src, trg).to_dense() + out2 = src.to_dense() @ trg.to_dense() + + atol = 1e-1 if dtype == ms.half else 1e-3 + assert ops.isclose(out1, out2, atol=atol).all() + if is_full_test(): + ms.jit(spspmm) diff --git a/tests/graph/sparse/test_metis.py b/tests/graph/sparse/test_metis.py new file mode 100644 index 000000000..d3aa973dd --- /dev/null +++ b/tests/graph/sparse/test_metis.py @@ -0,0 +1,40 @@ +from itertools import product + +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse.tensor import SparseTensor +from mindscience.sharker.sparse import utils + +try: + rowptr = Tensor([0, 1]) + col = Tensor([0]) + utils.partition(rowptr, col, None, 1, None, True) + with_metis = True +except RuntimeError: + with_metis = False + + +@pytest.mark.skipif(not with_metis, reason='Not compiled with METIS support') +@pytest.mark.parametrize('weighted', [False, True]) +def test_metis(weighted): + mat1 = ops.randn(6 * 6).view(6, 6) + mat2 = ops.arange(6 * 6, dtype=ms.int64).view(6, 6) + mat3 = ops.ones(6 * 6).view(6, 6) + + vec1 = None + vec2 = ops.rand(6) + + for mat, vec in product([mat1, mat2, mat3], [vec1, vec2]): + mat = SparseTensor.from_dense(mat) + + _, partptr, perm = mat.partition(num_parts=1, recursive=False, + weighted=weighted, node_weight=vec) + assert partptr.numel() == 2 + assert perm.numel() == 6 + + _, partptr, perm = mat.partition(num_parts=2, recursive=False, + weighted=weighted, node_weight=vec) + assert partptr.numel() == 3 + assert perm.numel() == 6 diff --git a/tests/graph/sparse/test_mul.py b/tests/graph/sparse/test_mul.py new file mode 100644 index 000000000..745ad7bbf --- /dev/null +++ b/tests/graph/sparse/test_mul.py @@ -0,0 +1,52 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.sparse import SparseTensor, mul +from mindscience.sharker.sparse.testing import dtypes, tensor + + +@pytest.mark.parametrize('dtype', dtypes) +def test_sparse_sparse_mul(dtype): + rowA = Tensor([0, 0, 1, 2, 2]) + colA = Tensor([0, 2, 1, 0, 1]) + valueA = tensor([1, 2, 4, 1, 3], dtype) + A = SparseTensor(row=rowA, col=colA, value=valueA) + + rowB = Tensor([0, 0, 1, 2, 2]) + colB = Tensor([1, 2, 2, 1, 2]) + valueB = tensor([2, 3, 1, 2, 4], dtype) + B = SparseTensor(row=rowB, col=colB, value=valueB) + + C = A * B + rowC, colC, valueC = C.coo() + + assert rowC.tolist() == [0, 2] + assert colC.tolist() == [2, 1] + assert valueC.tolist() == [6, 6] + if is_full_test(): + @ms.jit + def jit_mul(A: SparseTensor, B: SparseTensor) -> SparseTensor: + return mul(A, B) + + jit_mul(A, B) + + +@pytest.mark.parametrize('dtype', dtypes) +def test_sparse_sparse_mul_empty(dtype): + rowA = Tensor([0]) + colA = Tensor([1]) + valueA = tensor([1], dtype) + A = SparseTensor(row=rowA, col=colA, value=valueA) + + rowB = Tensor([1]) + colB = Tensor([0]) + valueB = tensor([2], dtype) + B = SparseTensor(row=rowB, col=colB, value=valueB) + + C = A * B + rowC, colC, valueC = C.coo() + + assert rowC.tolist() == [] + assert colC.tolist() == [] + assert valueC.tolist() == [] diff --git a/tests/graph/sparse/test_neighbor_sample.py b/tests/graph/sparse/test_neighbor_sample.py new file mode 100644 index 000000000..f760a7a58 --- /dev/null +++ b/tests/graph/sparse/test_neighbor_sample.py @@ -0,0 +1,43 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse import SparseTensor +from mindscience.sharker.seed import seed_everything +from mindscience.sharker.sparse.sampling import neighbor_sample + + +def test_neighbor_sample(): + adj = SparseTensor.from_edge_index(Tensor([[0], [1]])) + colptr, row, _ = adj.csc() + + # Sampling in a non-directed way should not sample in wrong direction: + out = neighbor_sample(colptr.asnumpy(), row.asnumpy(), Tensor([0]).asnumpy(), [1], False, False) + assert out[0].tolist() == [0] + assert out[1].tolist() == [] + assert out[2].tolist() == [] + + # Sampling should work: + out = neighbor_sample(colptr.asnumpy(), row.asnumpy(), Tensor([1]).asnumpy(), [1], False, False) + assert out[0].tolist() == [1, 0] + assert out[1].tolist() == [1] + assert out[2].tolist() == [0] + + # Sampling with more hops: + out = neighbor_sample(colptr.asnumpy(), row.asnumpy(), Tensor([1]).asnumpy(), [1, 1], False, False) + assert out[0].tolist() == [1, 0] + assert out[1].tolist() == [1] + assert out[2].tolist() == [0] + + +def test_neighbor_sample_seed(): + colptr = Tensor([0, 3, 6, 9]) + row = Tensor([0, 1, 2, 0, 1, 2, 0, 1, 2]) + input_nodes = Tensor([0, 1]) + + seed_everything(42) + out1 = neighbor_sample(colptr.asnumpy(), row.asnumpy(), input_nodes.asnumpy(), [1, 1], True, False) + + seed_everything(42) + out2 = neighbor_sample(colptr.asnumpy(), row.asnumpy(), input_nodes.asnumpy(), [1, 1], True, False) + + for data1, data2 in zip(out1, out2): + assert data1.tolist() == data2.tolist() diff --git a/tests/graph/sparse/test_overload.py b/tests/graph/sparse/test_overload.py new file mode 100644 index 000000000..1284b5d5e --- /dev/null +++ b/tests/graph/sparse/test_overload.py @@ -0,0 +1,26 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse.tensor import SparseTensor + + +def test_overload(): + row = Tensor([0, 1, 1, 2, 2]) + col = Tensor([1, 0, 2, 1, 2]) + mat = SparseTensor(row=row, col=col) + + other = Tensor([1, 2, 3]).view(3, 1) + a = other + mat + b = mat + other + c = other * mat + d = mat * other + + other = Tensor([1, 2, 3]).view(1, 3) + e = other + mat + f = mat + other + + g = other * mat + h = mat * other + assert a == b + assert c == d + assert e == f + assert g == h diff --git a/tests/graph/sparse/test_permute.py b/tests/graph/sparse/test_permute.py new file mode 100644 index 000000000..dc7dafd11 --- /dev/null +++ b/tests/graph/sparse/test_permute.py @@ -0,0 +1,17 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse.tensor import SparseTensor +from mindscience.sharker.sparse.testing import tensor + + +def test_permute(): + row, col = tensor([[0, 0, 1, 2, 2], [0, 1, 0, 1, 2]], ms.int64) + value = tensor([1, 2, 3, 4, 5], ms.single) + adj = SparseTensor(row=row, col=col, value=value) + + row, col, value = adj.permute(Tensor([1, 0, 2])).coo() + assert row.tolist() == [0, 1, 1, 2, 2] + assert col.tolist() == [1, 0, 1, 0, 2] + assert value.tolist() == [3, 2, 1, 4, 5] diff --git a/tests/graph/sparse/test_saint.py b/tests/graph/sparse/test_saint.py new file mode 100644 index 000000000..4ccf00b42 --- /dev/null +++ b/tests/graph/sparse/test_saint.py @@ -0,0 +1,16 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse.tensor import SparseTensor + + +def test_saint_subgraph(): + row = Tensor([0, 0, 1, 1, 2, 2, 2, 3, 3, 4]) + col = Tensor([1, 2, 0, 2, 0, 1, 3, 2, 4, 3]) + adj = SparseTensor(row=row, col=col) + node_idx = Tensor([0, 1, 2]) + + adj, edge_index = adj.saint_subgraph(node_idx) + row, col, _ = adj.coo() + assert row.tolist() == [0, 0, 1, 1, 2, 2] + assert col.tolist() == [1, 2, 0, 2, 0, 1] + assert edge_index.tolist() == [0, 1, 2, 3, 4, 5] diff --git a/tests/graph/sparse/test_sample.py b/tests/graph/sparse/test_sample.py new file mode 100644 index 000000000..f509c89e2 --- /dev/null +++ b/tests/graph/sparse/test_sample.py @@ -0,0 +1,36 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse import SparseTensor, sample, sample_adj + + +def test_sample(): + row = Tensor([0, 0, 2, 2]) + col = Tensor([1, 2, 0, 1]) + adj = SparseTensor(row=row, col=col, sparse_shape=(3, 3)) + + out = sample(adj, num_neighbors=1) + assert out.min() >= 0 and out.max() <= 2 + + +def test_sample_adj(): + row = Tensor([0, 0, 0, 1, 1, 2, 2, 2, 2, 3, 4, 4, 5, 5]) + col = Tensor([1, 2, 3, 0, 2, 0, 1, 4, 5, 0, 2, 5, 2, 4]) + value = ops.arange(row.shape[0]) + adj_t = SparseTensor(row=row, col=col, value=value, sparse_shape=(6, 6)) + + out, n_id = sample_adj(adj_t, ops.arange(2, 6), num_neighbors=-1) + + assert n_id.tolist() == [2, 3, 4, 5, 0, 1] + + row, col, val = out.coo() + assert row.tolist() == [0, 0, 0, 0, 1, 2, 2, 3, 3] + assert col.tolist() == [2, 3, 4, 5, 4, 0, 3, 0, 2] + assert val.tolist() == [7, 8, 5, 6, 9, 10, 11, 12, 13] + + out, n_id = sample_adj(adj_t, ops.arange(2, 6), num_neighbors=2, + replace=True) + assert out.nnz() == 8 + + out, n_id = sample_adj(adj_t, ops.arange(2, 6), num_neighbors=2, + replace=False) + assert out.nnz() == 7 # node 3 has only one edge... diff --git a/tests/graph/sparse/test_scatter.py b/tests/graph/sparse/test_scatter.py new file mode 100644 index 000000000..0ec3a270f --- /dev/null +++ b/tests/graph/sparse/test_scatter.py @@ -0,0 +1,124 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker import seed_everything +from mindscience.sharker.sparse import scatter, ptr2ind, group_argsort, group_cat + + +def test_scatter_validate(): + src = ops.randn(100, 32) + index = ops.randint(0, 10, (100, ), dtype=ms.int64) + + with pytest.raises(ValueError, match="must lay between 0 and 1"): + scatter(src, index, dim=2) + + with pytest.raises(ValueError, match="invalid `reduce` argument 'std'"): + scatter(src, index, reduce='std') + + +@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'mul', 'min', 'amin', 'max', 'amax']) +def test_scatter(reduce): + seed_everything(1) + src = ops.randn(20, 16) + ptr = ms.Tensor([0, 0, 5, 10, 15, 20]) + index = ptr2ind(ptr) + + out1 = scatter(src, index, dim=0, reduce=reduce) + out2 = scatter(src.T, index, dim=1, reduce=reduce).T + if reduce == 'mul': + expected = ops.prod(src.view(4, 5, -1), 1) + elif reduce == 'add': + expected = ops.sum(src.view(4, 5, -1), 1) + elif reduce == 'amax': + expected = ops.argmax(src.view(4, 5, -1), 1) + expected[1] += 5 + expected[2] += 10 + expected[3] += 15 + elif reduce == 'amin': + expected = ops.argmin(src.view(4, 5, -1), 1) + expected[1] += 5 + expected[2] += 10 + expected[3] += 15 + else: + expected = getattr(ops, reduce)(src.view(4, 5, -1), 1) + expected = expected[0] if isinstance(expected, tuple) else expected + + assert out1.shape == (5, 16) + assert (out1[:1] == (20 if reduce in ['amin', 'amax'] else 0)).all() + assert ops.isclose(out1[1:], expected, atol=1e-3).all() + assert ops.isclose(out1, out2, atol=1e-3).all() + + # jit = ms.jit(scatter) + # out3 = jit(src, index, dim=0, reduce=reduce) + # assert out3.shape == (8, 8) + # assert ops.isclose(out1, out3, atol=1e-6).all() + + src = ops.randn(2, 4, 8) + index = ops.randint(0, 8, (4, )) + out1 = scatter(src, index, dim=1, reduce=reduce) + assert out1.shape[0] == 2 and out1.shape[2] == 8 + + +@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max']) +def test_scatter_gradient(reduce): + src = ops.randn([8, 100, 8]) + index = ops.randint(0, 8, (100, )) + + grad_fn = ops.value_and_grad( + scatter, grad_position=0, weights=None, has_aux=False) + value, grad = grad_fn(src, index, dim=1, reduce=reduce) + + assert value is not None + assert grad is not None + + +def test_scatter_any(): + src = ops.randn(6, 4) + index = ms.Tensor([0, 0, 1, 1, 2, 2]) + + out = scatter(src, index, dim=0, reduce='any') + + for i in range(3): + for j in range(4): + assert float(out[i, j]) in src[2 * i:2 * i + 2, j].tolist() + + +@pytest.mark.parametrize('num_groups', [4]) +@pytest.mark.parametrize('descending', [False, True]) +def test_group_argsort(num_groups, descending): + src = ops.randn(20) + index = ops.randint(0, num_groups, (20, )) + + out = group_argsort(src, index, 0, num_groups, descending=descending) + + expected = ops.zeros_like(index) + for i in range(num_groups): + mask = index == i + tmp = src[mask].argsort(descending=descending).long() + perm = ops.zeros_like(tmp) + perm[tmp] = ops.arange(tmp.numel()) + expected[mask] = perm + + assert ops.equal(out, expected).all() + + empty_tensor = ms.Tensor([]) + out = group_argsort(empty_tensor, empty_tensor) + assert out.numel() == 0 + + +def test_group_cat(): + x1 = ops.randn(4, 4) + x2 = ops.randn(2, 4) + index1 = ms.Tensor([0, 0, 1, 2]) + index2 = ms.Tensor([0, 2]) + + expected = ops.cat(([x1[:2], x2[:1], x1[2:4], x2[1:]]), axis=0) + + out, index = group_cat( + [x1, x2], + [index1, index2], + axis=0, + return_index=True, + ) + assert ops.equal(out, expected).all() + assert index.tolist() == [0, 0, 0, 1, 2, 2] diff --git a/tests/graph/sparse/test_segment.py b/tests/graph/sparse/test_segment.py new file mode 100644 index 000000000..2e8872a9f --- /dev/null +++ b/tests/graph/sparse/test_segment.py @@ -0,0 +1,30 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.sparse import segment + + +@pytest.mark.parametrize('reduce', ['sum', 'mean', 'mul', 'min', 'max', 'amax', 'amin']) +def test_segment(reduce): + src = ops.randn(20, 16) + ptr = ms.Tensor([0, 0, 5, 10, 15, 20]) + out = segment(src, ptr, dim=0, reduce=reduce) + out1 = segment(src.T, ptr, dim=1, reduce=reduce).T + + if reduce == 'mul': + expected = ops.prod(src.view(4, 5, -1), 1) + elif reduce == 'amax': + expected = ops.argmax(src.view(4, 5, -1), 1) + elif reduce == 'amin': + expected = ops.argmin(src.view(4, 5, -1), 1) + else: + expected = getattr(ops, reduce)(src.view(4, 5, -1), 1) + expected = expected[0] if isinstance(expected, tuple) else expected + + assert ops.isclose(out[:1], ops.zeros([1, 16], dtype=out.dtype)).all() + assert ops.isclose(out[1:], expected).all() + assert ops.isclose(out, out1).all() + + # jit = ms.jit(segment) + # out1 = jit(src, ptr, reduce=reduce) + # assert ops.isclose(out, out1).all() diff --git a/tests/graph/sparse/test_spmm.py b/tests/graph/sparse/test_spmm.py new file mode 100644 index 000000000..498ea8fae --- /dev/null +++ b/tests/graph/sparse/test_spmm.py @@ -0,0 +1,20 @@ +from itertools import product + +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse import spmm +from mindscience.sharker.sparse.testing import dtypes, tensor + + +@pytest.mark.parametrize('dtype', dtypes) +def test_spmm(dtype): + row = Tensor([0, 0, 1, 2, 2]) + col = Tensor([0, 2, 1, 0, 1]) + index = ops.stack([row, col], axis=0) + value = tensor([1, 2, 4, 1, 3], dtype) + x = tensor([[1, 4], [2, 5], [3, 6]], dtype) + + out = spmm(index, value, 3, 3, x) + assert out.tolist() == [[7, 16], [8, 20], [7, 19]] diff --git a/tests/graph/sparse/test_spspmm.py b/tests/graph/sparse/test_spspmm.py new file mode 100644 index 000000000..3adeb5ffc --- /dev/null +++ b/tests/graph/sparse/test_spspmm.py @@ -0,0 +1,48 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse import SparseTensor, spspmm +from mindscience.sharker.sparse.testing import grad_dtypes, tensor + + +@pytest.mark.parametrize('dtype', grad_dtypes) +def test_spspmm(dtype): + if dtype in {ms.half, ms.bfloat16}: + return # Not yet implemented. + + indexA = Tensor([[0, 0, 1, 2, 2], [1, 2, 0, 0, 1]]) + valueA = tensor([1, 2, 3, 4, 5], dtype) + indexB = Tensor([[0, 2], [1, 0]]) + valueB = tensor([2, 4], dtype) + + indexC, valueC = spspmm(indexA, valueA, indexB, valueB, 3, 3, 2) + assert indexC.tolist() == [[0, 1, 2], [0, 1, 1]] + assert valueC.tolist() == [8, 6, 8] + + +@pytest.mark.parametrize('dtype', grad_dtypes) +def test_sparse_tensor_spspmm(dtype): + if dtype in {ms.half, ms.bfloat16}: + return # Not yet implemented. + + x = SparseTensor( + row=Tensor( + [0, 1, 1, 1, 2, 3, 4, 5, 5, 6, 6, 7, 7, 7, 8, 8, 9, 9]), + col=Tensor( + [0, 5, 10, 15, 1, 2, 3, 7, 13, 6, 9, 5, 10, 15, 11, 14, 5, 15]), + value=Tensor([ + 1, 3**-0.5, 3**-0.5, 3**-0.5, 1, 1, 1, -2**-0.5, -2**-0.5, + -2**-0.5, -2**-0.5, 6**-0.5, -6**0.5 / 3, 6**-0.5, -2**-0.5, + -2**-0.5, 2**-0.5, -2**-0.5 + ], dtype=dtype), + ) + + expected = ops.eye(10, dtype=dtype) + + out = x @ x.to_dense().t() + assert ops.isclose(out, expected, atol=1e-2).all() + + out = x @ x.t() + out = out.to_dense() + assert ops.isclose(out, expected, atol=1e-2).all() diff --git a/tests/graph/sparse/test_storage.py b/tests/graph/sparse/test_storage.py new file mode 100644 index 000000000..4f788b4fd --- /dev/null +++ b/tests/graph/sparse/test_storage.py @@ -0,0 +1,143 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.sparse import SparseStorage, Layout +from mindscience.sharker.sparse.testing import dtypes + + +@pytest.mark.parametrize('dtype', dtypes) +def test_storage(dtype): + row, col = Tensor([[0, 0, 1, 1], [0, 1, 0, 1]], ms.int64) + + storage = SparseStorage(row=row, col=col) + assert storage.row().tolist() == [0, 0, 1, 1] + assert storage.col().tolist() == [0, 1, 0, 1] + assert storage.value() is None + assert storage.sparse_shape == (2, 2) + + row, col = Tensor([[0, 0, 1, 1], [1, 0, 1, 0]], ms.int64) + value = Tensor([2, 1, 4, 3], dtype) + storage = SparseStorage(row=row, col=col, value=value) + assert storage.row().tolist() == [0, 0, 1, 1] + assert storage.col().tolist() == [0, 1, 0, 1] + assert storage.value().tolist() == [1, 2, 3, 4] + assert storage.sparse_shape == (2, 2) + + +@pytest.mark.parametrize('dtype', dtypes) +def test_caching(dtype): + row, col = Tensor([[0, 0, 1, 1], [0, 1, 0, 1]], ms.int64) + storage = SparseStorage(row=row, col=col) + + assert storage._row.tolist() == row.tolist() + assert storage._col.tolist() == col.tolist() + assert storage._value is None + + assert storage._rowcount is None + assert storage._rowptr is None + assert storage._colcount is None + assert storage._colptr is None + assert storage._csr2csc is None + assert storage.num_cached_keys() == 0 + + storage.fill_cache_() + assert storage._rowcount.tolist() == [2, 2] + assert storage._rowptr.tolist() == [0, 2, 4] + assert storage._colcount.tolist() == [2, 2] + assert storage._colptr.tolist() == [0, 2, 4] + assert storage._csr2csc.tolist() == [0, 2, 1, 3] + assert storage._csc2csr.tolist() == [0, 2, 1, 3] + assert storage.num_cached_keys() == 5 + + storage = SparseStorage(row=row, rowptr=storage._rowptr, col=col, + value=storage._value, + sparse_shape=storage._sparse_shape, + rowcount=storage._rowcount, colptr=storage._colptr, + colcount=storage._colcount, + csr2csc=storage._csr2csc, csc2csr=storage._csc2csr) + + assert storage._rowcount.tolist() == [2, 2] + assert storage._rowptr.tolist() == [0, 2, 4] + assert storage._colcount.tolist() == [2, 2] + assert storage._colptr.tolist() == [0, 2, 4] + assert storage._csr2csc.tolist() == [0, 2, 1, 3] + assert storage._csc2csr.tolist() == [0, 2, 1, 3] + assert storage.num_cached_keys() == 5 + + storage.clear_cache_() + assert storage._rowcount is None + assert storage._rowptr is not None + assert storage._colcount is None + assert storage._colptr is None + assert storage._csr2csc is None + assert storage.num_cached_keys() == 0 + + +@pytest.mark.parametrize('dtype', dtypes) +def test_utility(dtype): + row, col = Tensor([[0, 0, 1, 1], [1, 0, 1, 0]], ms.int64) + value = Tensor([1, 2, 3, 4], dtype) + storage = SparseStorage(row=row, col=col, value=value) + + assert storage.has_value() + + storage.set_value_(value, layout=Layout.CSC) + assert storage.value().tolist() == [1, 3, 2, 4] + storage.set_value_(value, layout=Layout.COO) + assert storage.value().tolist() == [1, 2, 3, 4] + + storage = storage.set_value(value, layout=Layout.CSC) + assert storage.value().tolist() == [1, 3, 2, 4] + storage = storage.set_value(value, layout=Layout.COO) + assert storage.value().tolist() == [1, 2, 3, 4] + + storage = storage.sparse_resize((3, 3)) + assert storage.sparse_shape == (3, 3) + + new_storage = storage.copy() + assert new_storage != storage + assert new_storage.col() is storage.col() + + new_storage = storage.clone() + assert new_storage != storage + assert new_storage.col() is not storage.col() + + +@pytest.mark.parametrize('dtype', dtypes) +def test_coalesce(dtype): + row, col = Tensor([[0, 0, 0, 1, 1], [0, 1, 1, 0, 1]], ms.int64) + value = Tensor([1, 1, 1, 3, 4], dtype) + storage = SparseStorage(row=row, col=col, value=value) + + assert storage.row().tolist() == row.tolist() + assert storage.col().tolist() == col.tolist() + assert storage.value().tolist() == value.tolist() + + assert not storage.is_coalesced() + storage = storage.coalesce() + assert storage.is_coalesced() + + assert storage.row().tolist() == [0, 0, 1, 1] + assert storage.col().tolist() == [0, 1, 0, 1] + assert storage.value().tolist() == [1, 2, 3, 4] + + +@pytest.mark.parametrize('dtype', dtypes) +def test_sparse_reshape(dtype): + row, col = Tensor([[0, 1, 2, 3], [0, 1, 2, 3]], ms.int64) + storage = SparseStorage(row=row, col=col) + + storage = storage.sparse_reshape(2, 8) + assert storage.sparse_shape == (2, 8) + assert storage.row().tolist() == [0, 0, 1, 1] + assert storage.col().tolist() == [0, 5, 2, 7] + + storage = storage.sparse_reshape(-1, 4) + assert storage.sparse_shape == (4, 4) + assert storage.row().tolist() == [0, 1, 2, 3] + assert storage.col().tolist() == [0, 1, 2, 3] + + storage = storage.sparse_reshape(2, -1) + assert storage.sparse_shape == (2, 8) + assert storage.row().tolist() == [0, 0, 1, 1] + assert storage.col().tolist() == [0, 5, 2, 7] diff --git a/tests/graph/sparse/test_tensor.py b/tests/graph/sparse/test_tensor.py new file mode 100644 index 000000000..3dba93337 --- /dev/null +++ b/tests/graph/sparse/test_tensor.py @@ -0,0 +1,104 @@ +from mindscience.sharker.sparse.testing import dtypes, grad_dtypes +from mindscience.sharker.sparse import ptr2ind, ind2ptr +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse import SparseTensor + + +@pytest.mark.parametrize('dtype', grad_dtypes) +def test_getitem(dtype): + m = 50 + n = 40 + k = 10 + mat = ops.randn(m, n, dtype=dtype) + mat = SparseTensor.from_dense(mat) + + idx1 = ops.randint(0, m, (k, ), dtype=ms.int64) + idx2 = ops.randint(0, n, (k, ), dtype=ms.int64) + bool1 = ops.zeros(m, dtype=ms.bool_) + bool2 = ops.zeros(n, dtype=ms.bool_) + bool1[idx1] = 1 + bool2[idx2] = 1 + # idx1 and idx2 may have duplicates + k1_bool = bool1.nonzero().shape[0] + k2_bool = bool2.nonzero().shape[0] + + idx1np = idx1.asnumpy() + idx2np = idx2.asnumpy() + bool1np = bool1.asnumpy() + bool2np = bool2.asnumpy() + + idx1list = idx1np.tolist() + idx2list = idx2np.tolist() + bool1list = bool1np.tolist() + bool2list = bool2np.tolist() + + assert mat[:k, :k].shape == (k, k) + assert mat[..., :k].shape == (m, k) + + assert mat[idx1, idx2].shape == (k, k) + assert mat[idx1np, idx2np].shape == (k, k) + assert mat[idx1list, idx2list].shape == (k, k) + + assert mat[bool1, bool2].shape == (k1_bool, k2_bool) + assert mat[bool1np, bool2np].shape == (k1_bool, k2_bool) + assert mat[bool1list, bool2list].shape == (k1_bool, k2_bool) + + assert mat[idx1].shape == (k, n) + assert mat[idx1np].shape == (k, n) + assert mat[idx1list].shape == (k, n) + + assert mat[bool1].shape == (k1_bool, n) + assert mat[bool1np].shape == (k1_bool, n) + assert mat[bool1list].shape == (k1_bool, n) + + +def test_to_symmetric(): + row = Tensor([0, 0, 0, 1, 1]) + col = Tensor([0, 1, 2, 0, 2]) + value = ops.arange(1, 6) + mat = SparseTensor(row=row, col=col, value=value) + assert not mat.is_symmetric() + + mat = mat.to_symmetric() + + assert mat.is_symmetric() + assert mat.to_dense().tolist() == [ + [2, 6, 3], + [6, 0, 5], + [3, 5, 0], + ] + + +def test_equal(): + row = Tensor([0, 0, 0, 1, 1]) + col = Tensor([0, 1, 2, 0, 2]) + value = ops.arange(1, 6) + matA = SparseTensor(row=row, col=col, value=value) + matB = SparseTensor(row=row, col=col, value=value) + col = Tensor([0, 1, 2, 0, 1]) + matC = SparseTensor(row=row, col=col, value=value) + + assert id(matA) != id(matB) + assert matA == matB + + assert id(matA) != id(matC) + assert matA != matC + + +def test_ind2ptr(): + row = Tensor([2, 2, 4, 5, 5, 6], ms.int64) + rowptr = ind2ptr(row, 8) + assert rowptr.tolist() == [0, 0, 0, 2, 2, 3, 5, 6, 6] + + row = ptr2ind(rowptr, 6) + assert row.tolist() == [2, 2, 4, 5, 5, 6] + + row = Tensor([], ms.int64) + rowptr = ind2ptr(row, 8) + assert rowptr.tolist() == [0, 0, 0, 0, 0, 0, 0, 0, 0] + + row = ptr2ind(rowptr, 0) + assert row.tolist() == [] diff --git a/tests/graph/sparse/test_transpose.py b/tests/graph/sparse/test_transpose.py new file mode 100644 index 000000000..6caed02c4 --- /dev/null +++ b/tests/graph/sparse/test_transpose.py @@ -0,0 +1,30 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.sparse import transpose +from mindscience.sharker.sparse.testing import dtypes, tensor + + +@pytest.mark.parametrize('dtype', dtypes) +def test_transpose_matrix(dtype): + row = Tensor([1, 0, 1, 2]) + col = Tensor([0, 1, 1, 0]) + index = ops.stack([row, col], axis=0) + value = tensor([1, 2, 3, 4], dtype) + + index, value = transpose(index, value, m=3, n=2) + assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]] + assert value.tolist() == [1, 4, 2, 3] + + +@pytest.mark.parametrize('dtype', dtypes) +def test_transpose(dtype): + row = Tensor([1, 0, 1, 0, 2, 1]) + col = Tensor([0, 1, 1, 1, 0, 0]) + index = ops.stack([row, col], axis=0) + value = tensor([[1, 2], [2, 3], [3, 4], [4, 5], [5, 6], [6, 7]], dtype) + + index, value = transpose(index, value, m=3, n=2) + assert index.tolist() == [[0, 0, 1, 1], [1, 2, 0, 1]] + assert value.tolist() == [[7, 9], [5, 6], [6, 8], [3, 4]] diff --git a/tests/graph/test_config_store.py b/tests/graph/test_config_store.py new file mode 100644 index 000000000..0e8fb2113 --- /dev/null +++ b/tests/graph/test_config_store.py @@ -0,0 +1,147 @@ +from typing import Any + +from mindscience.sharker.config_store import ( + class_from_dataclass, + clear_config_store, + dataclass_from_class, + fill_config_store, + get_config_store, + register, + to_dataclass, +) +from mindscience.sharker.testing import withPackage +from mindscience.sharker.transforms import AddSelfLoops + + +def teardown_function(): + clear_config_store() + + +def test_to_dataclass(): + from mindscience.sharker.transforms import AddSelfLoops + + AddSelfLoopsConfig = to_dataclass(AddSelfLoops, with_target=True) + assert AddSelfLoopsConfig.__name__ == 'AddSelfLoops' + + fields = AddSelfLoopsConfig.__dataclass_fields__ + + assert fields['attr'].name == 'attr' + assert fields['attr'].type == str + assert fields['attr'].default == 'edge_weight' + + assert fields['fill_value'].name == 'fill_value' + assert fields['fill_value'].type == Any + assert fields['fill_value'].default == 1.0 + + assert fields['_target_'].name == '_target_' + assert fields['_target_'].type == str + assert fields['_target_'].default == ( + 'sharker.transforms.add_self_loops.AddSelfLoops') + + cfg = AddSelfLoopsConfig() + assert str(cfg) == ("AddSelfLoops(attr='edge_weight', fill_value=1.0, " + "_target_='sharker.transforms.add_self_loops." + "AddSelfLoops')") + + +def test_register(): + register(AddSelfLoops, group='transform') + assert 'transform' in get_config_store().repo + + AddSelfLoopsConfig = dataclass_from_class('AddSelfLoops') + + Cls = class_from_dataclass('AddSelfLoops') + assert Cls == AddSelfLoops + Cls = class_from_dataclass(AddSelfLoopsConfig) + assert Cls == AddSelfLoops + + ConfigCls = dataclass_from_class('AddSelfLoops') + assert ConfigCls == AddSelfLoopsConfig + ConfigCls = dataclass_from_class(ConfigCls) + assert ConfigCls == AddSelfLoopsConfig + + +def test_fill_config_store(): + fill_config_store() + + assert { + 'transform', + 'dataset', + 'model', + 'optimizer', + 'lr_scheduler', + }.issubset(get_config_store().repo.keys()) + + +@withPackage('hydra') +def test_hydra_config_store(): + import hydra + from omegaconf import DictConfig + + fill_config_store() + + with hydra.initialize(config_path='.', version_base='1.1'): + cfg = hydra.compose(config_name='my_config') + + assert len(cfg) == 4 + assert 'dataset' in cfg + assert 'model' in cfg + assert 'optimizer' in cfg + assert 'lr_scheduler' in cfg + + # Check `cfg.dataset`: + assert len(cfg.dataset) == 2 + assert cfg.dataset._target_.split('.')[-1] == 'KarateClub' + + # Check `cfg.dataset.transform`: + assert isinstance(cfg.dataset.transform, DictConfig) + assert len(cfg.dataset.transform) == 2 + assert 'NormalizeFeatures' in cfg.dataset.transform + assert 'AddSelfLoops' in cfg.dataset.transform + + assert isinstance(cfg.dataset.transform.NormalizeFeatures, DictConfig) + assert (cfg.dataset.transform.NormalizeFeatures._target_.split('.')[-1] == + 'NormalizeFeatures') + assert cfg.dataset.transform.NormalizeFeatures.attrs == ['x'] + + assert isinstance(cfg.dataset.transform.AddSelfLoops, DictConfig) + assert (cfg.dataset.transform.AddSelfLoops._target_.split('.')[-1] == + 'AddSelfLoops') + assert cfg.dataset.transform.AddSelfLoops.attr == 'edge_weight' + assert cfg.dataset.transform.AddSelfLoops.fill_value == 1.0 + + # Check `cfg.model`: + assert len(cfg.model) == 12 + assert cfg.model._target_.split('.')[-1] == 'GCN' + assert cfg.model.in_channels == 34 + assert cfg.model.out_channels == 4 + assert cfg.model.hidden_channels == 16 + assert cfg.model.num_layers == 2 + assert cfg.model.dropout == 0.0 + assert cfg.model.act == 'relu' + assert cfg.model.norm is None + assert cfg.model.norm_kwargs is None + assert cfg.model.jk is None + assert not cfg.model.act_first + assert cfg.model.act_kwargs is None + + # Check `cfg.optimizer`: + assert cfg.optimizer._target_.split('.')[-1] == 'Adam' + assert cfg.optimizer.lr == 0.001 + assert cfg.optimizer.betas == [0.9, 0.999] + assert cfg.optimizer.eps == 1e-08 + assert cfg.optimizer.weight_decay == 0 + assert not cfg.optimizer.amsgrad + if hasattr(cfg.optimizer, 'maximize'): + assert not cfg.optimizer.maximize + + # Check `cfg.lr_scheduler`: + assert cfg.lr_scheduler._target_.split('.')[-1] == 'ReduceLROnPlateau' + assert cfg.lr_scheduler.mode == 'min' + assert cfg.lr_scheduler.factor == 0.1 + assert cfg.lr_scheduler.patience == 10 + assert cfg.lr_scheduler.threshold == 0.0001 + assert cfg.lr_scheduler.threshold_mode == 'rel' + assert cfg.lr_scheduler.cooldown == 0 + assert cfg.lr_scheduler.min_lr == 0 + assert cfg.lr_scheduler.eps == 1e-08 diff --git a/tests/graph/test_debug.py b/tests/graph/test_debug.py new file mode 100644 index 000000000..33b0d4cc0 --- /dev/null +++ b/tests/graph/test_debug.py @@ -0,0 +1,28 @@ +from mindscience.sharker import debug, is_debug_enabled, set_debug + + +def test_debug(): + assert is_debug_enabled() is False + set_debug(True) + assert is_debug_enabled() is True + set_debug(False) + assert is_debug_enabled() is False + + assert is_debug_enabled() is False + with set_debug(True): + assert is_debug_enabled() is True + assert is_debug_enabled() is False + + assert is_debug_enabled() is False + set_debug(True) + assert is_debug_enabled() is True + with set_debug(False): + assert is_debug_enabled() is False + assert is_debug_enabled() is True + set_debug(False) + assert is_debug_enabled() is False + + assert is_debug_enabled() is False + with debug(): + assert is_debug_enabled() is True + assert is_debug_enabled() is False diff --git a/tests/graph/test_edge_index.py b/tests/graph/test_edge_index.py new file mode 100644 index 000000000..822318536 --- /dev/null +++ b/tests/graph/test_edge_index.py @@ -0,0 +1,1183 @@ +import os.path as osp +import warnings +from typing import List, Optional + +import pytest +import mindspore as ms +from mindspore import Tensor, ops, nn + +from mindscience.sharker.loader import DataLoader +from mindscience.sharker import EdgeIndex +from mindscience.sharker.edge_index import ( + SUPPORTED_DTYPES, + ReduceType, + SortReturnType, + _scatter_spmm, + _torch_sparse_spmm, + _TorchSPMM, + set_tuple_item, +) +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import withoutExtensions +from mindscience.sharker.sparse import SparseTensor, Layout, scatter +from mindscience.sharker import typing + +DTYPES = [pytest.param(dtype, id=str(dtype)[6:]) for dtype in SUPPORTED_DTYPES] +IS_UNDIRECTED = [ + pytest.param(False, id='directed'), + pytest.param(True, id='undirected'), +] +TRANSPOSE = [ + pytest.param(False, id=''), + pytest.param(True, id='transpose'), +] + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_basic(dtype, device): + kwargs = dict(dtype=dtype, sparse_shape=(3, 3)) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + adj.validate() + assert isinstance(adj, EdgeIndex) + + if typing.WITH_PT112: + assert str(adj).startswith('EdgeIndex([[0, 1, 1, 2],\n' + ' [1, 0, 2, 1]], ') + else: + assert str(adj).startswith('tensor([[0, 1, 1, 2],\n' + ' [1, 0, 2, 1]], ') + assert str(adj).endswith('sparse_shape=(3, 3), nnz=4)') + assert (f"device='{device}'" in str(adj)) == adj.is_cuda + assert (f'dtype={dtype}' in str(adj)) == (dtype != ms.int64) + + assert adj.dtype == dtype + assert adj.device == device + assert adj.sparse_shape == (3, 3) + assert adj.sparse_shape[0] == 3 + assert adj.sparse_shape[-1] == 3 + + assert adj.sort_order is None + assert not adj.is_sorted + assert not adj.is_sorted_by_row + assert not adj.is_sorted_by_col + + assert not adj.is_undirected + + out = adj.as_tensor() + assert not isinstance(out, EdgeIndex) + assert out.dtype == dtype + assert out.device == device + + out = adj * 1 + assert not isinstance(out, EdgeIndex) + assert out.dtype == dtype + assert out.device == device + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_identity(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_shape=(3, 3), **kwargs) + + out = EdgeIndex(adj) + assert out.data_ptr() == adj.data_ptr() + assert out.dtype == adj.dtype + assert out.device == adj.device + assert out.sparse_shape == adj.sparse_shape + assert out.sort_order == adj.sort_order + assert out.is_undirected == adj.is_undirected + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_sparse_tensor(dtype): + kwargs = dict(dtype=dtype, is_undirected=True) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + + out = EdgeIndex(adj.to_coo()) + assert out.equal(adj) + assert out.sort_order == 'row' + assert out.sparse_shape == (3, 3) + assert out._indptr is None + + out = EdgeIndex(adj.to_csr()) + assert out.equal(adj) + assert out.sort_order == 'row' + assert out.sparse_shape == (3, 3) + assert out._indptr.equal(Tensor([0, 1, 3, 4])) + + out = EdgeIndex(adj.to_sparse_csc()) + assert out.equal(adj.sort_by('col')[0]) + assert out.sort_order == 'col' + assert out.sparse_shape == (3, 3) + assert out._indptr.equal(Tensor([0, 1, 3, 4])) + + +def test_set_tuple_item(): + tmp = (0, 1, 2) + assert set_tuple_item(tmp, 0, 3) == (3, 1, 2) + assert set_tuple_item(tmp, 1, 3) == (0, 3, 2) + assert set_tuple_item(tmp, 2, 3) == (0, 1, 3) + with pytest.raises(IndexError, match="tuple index out of range"): + set_tuple_item(tmp, 3, 3) + assert set_tuple_item(tmp, -1, 3) == (0, 1, 3) + assert set_tuple_item(tmp, -2, 3) == (0, 3, 2) + assert set_tuple_item(tmp, -3, 3) == (3, 1, 2) + with pytest.raises(IndexError, match="tuple index out of range"): + set_tuple_item(tmp, -4, 3) + + +def test_validate(): + with pytest.raises(ValueError, match="unsupported data type"): + EdgeIndex([[0.0, 1.0], [1.0, 0.0]]) + with pytest.raises(ValueError, match="needs to be two-dimensional"): + EdgeIndex([[[0], [1]], [[1], [0]]]) + with pytest.raises(ValueError, match="needs to have a shape of"): + EdgeIndex([[0, 1], [1, 0], [1, 1]]) + with pytest.raises(ValueError, match="received a non-symmetric size"): + EdgeIndex([[0, 1], [1, 0]], is_undirected=True, sparse_shape=(2, 3)) + with pytest.raises(TypeError, match="invalid combination of arguments"): + EdgeIndex(ms.Tensor([[0, 1], [1, 0]]), ms.int64) + with pytest.raises(TypeError, match="invalid keyword arguments"): + EdgeIndex(ms.Tensor([[0, 1], [1, 0]]), dtype=ms.int64) + with pytest.raises(ValueError, match="contains negative indices"): + EdgeIndex([[-1, 0], [0, 1]]).validate() + with pytest.raises(ValueError, match="than its number of rows"): + EdgeIndex([[0, 10], [1, 0]], sparse_shape=(2, 2)).validate() + with pytest.raises(ValueError, match="than its number of columns"): + EdgeIndex([[0, 1], [10, 0]], sparse_shape=(2, 2)).validate() + with pytest.raises(ValueError, match="not sorted by row indices"): + EdgeIndex([[1, 0], [0, 1]], sort_order='row').validate() + with pytest.raises(ValueError, match="not sorted by column indices"): + EdgeIndex([[0, 1], [1, 0]], sort_order='col').validate() + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_undirected(dtype, device): + kwargs = dict(dtype=dtype, is_undirected=True) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + assert isinstance(adj, EdgeIndex) + assert adj.is_undirected + + assert adj.sparse_shape == (None, None) + adj.get_num_rows() + assert adj.sparse_shape == (3, 3) + adj.validate() + + adj = EdgeIndex([[0, 1], [1, 0]], sparse_shape=(3, None), **kwargs) + assert adj.sparse_shape == (3, 3) + adj.validate() + + adj = EdgeIndex([[0, 1], [1, 0]], sparse_shape=(None, 3), **kwargs) + assert adj.sparse_shape == (3, 3) + adj.validate() + + with pytest.raises(ValueError, match="'EdgeIndex' is not undirected"): + EdgeIndex([[0, 1, 1, 2], [0, 0, 1, 1]], **kwargs).validate() + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_fill_cache_(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + adj.validate().fill_cache_() + assert adj.sparse_shape == (3, 3) + assert adj._indptr.dtype == dtype + assert adj._indptr.equal(Tensor([0, 1, 3, 4])) + assert adj._T_perm.dtype == ms.int64 + assert (adj._T_perm.equal(Tensor([1, 0, 3, 2])) + or adj._T_perm.equal(Tensor([1, 3, 0, 2]))) + assert adj._T_index[0].dtype == dtype + assert (adj._T_index[0].equal(Tensor([1, 0, 2, 1])) + or adj._T_index[0].equal(Tensor([1, 2, 0, 1]))) + assert adj._T_index[1].dtype == dtype + assert adj._T_index[1].equal(Tensor([0, 1, 1, 2])) + if is_undirected: + assert adj._T_indptr is None + else: + assert adj._T_indptr.dtype == dtype + assert adj._T_indptr.equal(Tensor([0, 1, 3, 4])) + + adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs) + adj.validate().fill_cache_() + assert adj.sparse_shape == (3, 3) + assert adj._indptr.dtype == dtype + assert adj._indptr.equal(Tensor([0, 1, 3, 4])) + assert (adj._T_perm.equal(Tensor([1, 0, 3, 2])) + or adj._T_perm.equal(Tensor([1, 3, 0, 2]))) + assert adj._T_index[0].dtype == dtype + assert adj._T_index[0].equal(Tensor([0, 1, 1, 2])) + assert adj._T_index[1].dtype == dtype + assert (adj._T_index[1].equal(Tensor([1, 0, 2, 1])) + or adj._T_index[1].equal(Tensor([1, 2, 0, 1]))) + if is_undirected: + assert adj._T_indptr is None + else: + assert adj._T_indptr.dtype == dtype + assert adj._T_indptr.equal(Tensor([0, 1, 3, 4])) + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_clone(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + + out = adj.copy() + assert isinstance(out, EdgeIndex) + assert out.dtype == dtype + assert out.device == device + assert out.is_sorted_by_row + assert out.is_undirected == is_undirected + + out = adj.copy() + assert isinstance(out, EdgeIndex) + assert out.dtype == dtype + assert out.device == device + assert out.is_sorted_by_row + assert out.is_undirected == is_undirected + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_to(dtype, device, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + adj.fill_cache_() + + adj = adj + assert isinstance(adj, EdgeIndex) + assert adj.device == device + assert adj._indptr.device == device + assert adj._T_perm.device == device + + out = adj + assert isinstance(out, EdgeIndex) + # assert out.device == torch.device('cpu') + + out = adj.int() + assert out.dtype == ms.int32 + assert isinstance(out, EdgeIndex) + assert out._indptr.dtype == ms.int32 + assert out._T_perm.dtype == ms.int32 + + out = adj.float() + assert not isinstance(out, EdgeIndex) + assert out.dtype == ms.float32 + + out = adj.long() + assert isinstance(out, EdgeIndex) + assert out.dtype == ms.int64 + + out = adj.int() + assert out.dtype == ms.int32 + assert isinstance(out, EdgeIndex) + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_cpu_cuda(dtype): + kwargs = dict(dtype=dtype) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + + out = adj + assert isinstance(out, EdgeIndex) + assert out.is_cuda + + out = out + assert isinstance(out, EdgeIndex) + assert not out.is_cuda + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_share_memory(dtype): + kwargs = dict(dtype=dtype) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + adj.fill_cache_() + + adj = adj.share_memory_() + assert isinstance(adj, EdgeIndex) + assert adj.is_shared() + assert adj._indptr.is_shared() + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_contiguous(dtype): + kwargs = dict(dtype=dtype) + data = Tensor([[0, 1], [1, 0], [1, 2], [2, 1]], **kwargs).t() + + with pytest.raises(ValueError, match="needs to be contiguous"): + EdgeIndex(data) + + adj = EdgeIndex(data) + assert isinstance(adj, EdgeIndex) + # assert adj.is_contiguous() + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_sort_by(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + out = adj.sort_by('row') + assert isinstance(out, SortReturnType) + assert isinstance(out.values, EdgeIndex) + assert not isinstance(out.indices, EdgeIndex) + assert out.values.equal(adj) + assert out.indices == slice(None, None, None) + + adj = EdgeIndex([[0, 1, 2, 1], [1, 0, 1, 2]], **kwargs) + out = adj.sort_by('row') + assert isinstance(out, SortReturnType) + assert isinstance(out.values, EdgeIndex) + assert not isinstance(out.indices, EdgeIndex) + assert out.values[0].equal(Tensor([0, 1, 1, 2])) + assert (out.values[1].equal(Tensor([1, 0, 2, 1])) + or out.values[1].equal(Tensor([1, 2, 0, 1]))) + assert (out.indices.equal(Tensor([0, 1, 3, 2])) + or out.indices.equal(Tensor([0, 3, 1, 2]))) + + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + + out, perm = adj.sort_by('col') + assert adj._T_perm is not None # Check caches. + assert adj._T_index[0] is not None and adj._T_index[1] is not None + assert (out[0].equal(Tensor([1, 0, 2, 1])) + or out[0].equal(Tensor([1, 2, 0, 1]))) + assert out[1].equal(Tensor([0, 1, 1, 2])) + assert (perm.equal(Tensor([1, 0, 3, 2])) + or perm.equal(Tensor([1, 3, 0, 2]))) + assert out._T_perm is None + assert out._T_index[0] is None and out._T_index[1] is None + + out, perm = out.sort_by('row') + assert out[0].equal(Tensor([0, 1, 1, 2])) + assert (out[1].equal(Tensor([1, 0, 2, 1])) + or out[1].equal(Tensor([1, 2, 0, 1]))) + assert (perm.equal(Tensor([1, 0, 3, 2])) + or perm.equal(Tensor([2, 3, 0, 1]))) + assert out._T_perm is None + assert out._T_index[0] is None and out._T_index[1] is None + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_cat(dtype, is_undirected): + args = dict(dtype=dtype, is_undirected=is_undirected) + adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_shape=(3, 3), **args) + adj2 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], sparse_shape=(4, 4), **args) + adj3 = EdgeIndex([[1, 2, 2, 3], [2, 1, 3, 2]], dtype=dtype) + + out = ops.cat(([adj1]), axis=1) + assert id(out) == id(adj1) + + out = ops.cat(([adj1, adj2]), axis=1) + assert out.shape == (2, 8) + assert isinstance(out, EdgeIndex) + assert out.sparse_shape == (4, 4) + assert not out.is_sorted + assert out.is_undirected == is_undirected + + assert out._cat_metadata.nnz == [4, 4] + assert out._cat_metadata.sparse_shape == [(3, 3), (4, 4)] + assert out._cat_metadata.sort_order == [None, None] + assert out._cat_metadata.is_undirected == [is_undirected, is_undirected] + + out = ops.cat(([adj1, adj2, adj3]), axis=1) + assert out.shape == (2, 12) + assert isinstance(out, EdgeIndex) + assert out.sparse_shape == (None, None) + assert not out.is_sorted + assert not out.is_undirected + + out = ops.cat(([adj1, adj2]), axis=0) + assert out.shape == (4, 4) + assert not isinstance(out, EdgeIndex) + + inplace = ms.numpy.empty((2, 8), dtype=dtype) + out = ops.cat(([adj1, adj2]), axis=1, out=inplace) + assert isinstance(out, EdgeIndex) + assert not isinstance(inplace, EdgeIndex) + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_flip(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + adj.fill_cache_() + + out = adj.flip(0) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[1, 0, 2, 1], [0, 1, 1, 2]])) + assert out.is_sorted_by_col + assert out.is_undirected == is_undirected + assert out._T_indptr.equal(Tensor([0, 1, 3, 4])) + + out = adj.flip([0, 1]) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[1, 2, 0, 1], [2, 1, 1, 0]])) + assert not out.is_sorted + assert out.is_undirected == is_undirected + assert out._T_indptr is None + + adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs) + out = adj.flip(0) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[0, 1, 1, 2], [1, 0, 2, 1]])) + assert out.is_sorted_by_row + assert out.is_undirected == is_undirected + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_index_select(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + + index = Tensor([1, 3]) + out = adj.index_select(1, index) + assert out.equal(Tensor([[1, 2], [0, 1]])) + assert isinstance(out, EdgeIndex) + assert not out.is_sorted + assert not out.is_undirected + + index = Tensor([0]) + out = adj.index_select(0, index) + assert out.equal(Tensor([[0, 1, 1, 2]])) + assert not isinstance(out, EdgeIndex) + + index = Tensor([1, 3]) + inplace = ms.numpy.empty((2, 2), dtype=dtype) + inplace[:] = ops.index_select(adj, 1, index) + assert not isinstance(inplace, EdgeIndex) + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_narrow(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + + out = adj.narrow(axis=1, start=1, length=2) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[1, 1], [0, 2]])) + assert out.is_sorted_by_row + assert not out.is_undirected + + out = adj.narrow(axis=0, start=0, length=1) + assert not isinstance(out, EdgeIndex) + assert out.equal(Tensor([[0, 1, 1, 2]])) + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_getitem(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + + out = adj[:, Tensor([False, True, False, True])] + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[1, 2], [0, 1]])) + assert out.is_sorted_by_row + assert not out.is_undirected + + out = adj[..., Tensor([1, 3])] + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[1, 2], [0, 1]])) + assert not out.is_sorted + assert not out.is_undirected + + out = adj[..., 1::2] + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[1, 2], [0, 1]])) + assert out.is_sorted_by_row + assert not out.is_undirected + + out = adj[:, 0] + assert not isinstance(out, EdgeIndex) + + out = adj[Tensor([0])] + assert not isinstance(out, EdgeIndex) + + out = adj[Tensor([0]), Tensor([0])] + assert not isinstance(out, EdgeIndex) + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('value_dtype', [None, ms.double]) +def test_to_dense(dtype, value_dtype): + kwargs = dict(dtype=dtype) + adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs) + + out = adj.to_dense(dtype=value_dtype) + assert isinstance(out, Tensor) + assert out.shape == (3, 3) + expected = [[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]] + assert out.equal(Tensor(expected, dtype=value_dtype)) + + value = ops.arange(1, 5, dtype=value_dtype or ms.float32) + out = adj.to_dense(value) + assert isinstance(out, Tensor) + assert out.shape == (3, 3) + expected = [[0.0, 2.0, 0.0], [1.0, 0.0, 4.0], [0.0, 3.0, 0.0]] + assert out.equal(Tensor(expected, dtype=value_dtype)) + + value = ops.arange(1, 5, dtype=value_dtype or ms.float32) + out = adj.to_dense(value.view(-1, 1)) + assert isinstance(out, Tensor) + assert out.shape == (3, 3, 1) + expected = [ + [[0.0], [2.0], [0.0]], + [[1.0], [0.0], [4.0]], + [[0.0], [3.0], [0.0]], + ] + assert out.equal(Tensor(expected, dtype=value_dtype)) + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_to_sparse_coo(dtype): + kwargs = dict(dtype=dtype) + adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs) + + with pytest.raises(ValueError, match="Unexpected tensor layout"): + adj.to_sparse(layout='int64') + + out = adj.to_sparse(layout=Layout.COO) + assert isinstance(out, Tensor) + assert out.dtype == ms.float32 + assert isinstance(out, ms.COOTensor) + assert out.shape == (3, 3) + assert adj.equal(out.indices.t()) + assert not out.is_coalesced() + + adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], **kwargs) + out = adj.to_coo() + assert isinstance(out, Tensor) + assert out.dtype == ms.float32 + assert isinstance(out, ms.COOTensor) + assert out.shape == (3, 3) + assert adj.equal(out.indices.t()) + assert not out.is_coalesced() + + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + out = adj.to_coo() + assert isinstance(out, Tensor) + assert out.dtype == ms.float32 + assert isinstance(out, ms.COOTensor) + assert out.shape == (3, 3) + assert adj.equal(out.indices.t()) + assert out.is_coalesced() + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_to_sparse_csr(dtype, device): + kwargs = dict(dtype=dtype) + with pytest.raises(ValueError, match="not sorted"): + EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs).to_csr() + + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + out = adj.to_sparse(layout=Layout.CSR) + assert isinstance(out, Tensor) + assert out.dtype == ms.float32 + assert out.device == device + assert out.layout == Layout.CSR + assert out.shape == (3, 3) + assert adj._indptr.equal(out.indptr) + assert adj[1].equal(out.indices) + + +# @pytest.mark.parametrize('dtype', DTYPES) +# @pytest.mark.skipif(not typing.WITH_PT112, reason="<1.12") +# def test_to_sparse_csc(dtype, device): +# kwargs = dict(dtype=dtype) +# with pytest.raises(ValueError, match="not sorted"): +# EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs).to_sparse_csc() + +# adj = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', **kwargs) +# if typing.WITH_PT20: +# out = adj.to_sparse(layout=torch.sparse_csc) +# else: +# out = adj.to_sparse_csc() +# assert isinstance(out, Tensor) +# assert out.dtype == ms.float32 +# assert out.layout == torch.sparse_csc +# assert out.shape == (3, 3) +# assert adj._indptr.equal(out.ccol_indices()) +# assert adj[0].equal(out.row_indices()) + + +def test_to_sparse_tensor(): + kwargs = dict() + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + out = adj.to_sparse_tensor() + assert isinstance(out, SparseTensor) + assert out.shape == (3, 3) + row, col, _ = out.coo() + assert row.equal(adj[0]) + assert col.equal(adj[1]) + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_add(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sparse_shape=(3, 3), **kwargs) + + out = ops.add(adj, 2 * 2) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[4, 5, 5, 6], [5, 4, 6, 5]])) + assert out.is_undirected == is_undirected + assert out.sparse_shape == (7, 7) + + out = adj + ms.Tensor([2], dtype=dtype) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[2, 3, 3, 4], [3, 2, 4, 3]])) + assert out.is_undirected == is_undirected + assert out.sparse_shape == (5, 5) + + out = adj + ms.Tensor([[2], [1]], dtype=dtype) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[2, 3, 3, 4], [2, 1, 3, 2]])) + assert not out.is_undirected + assert out.sparse_shape == (5, 4) + + out = adj + ms.Tensor([[2], [2]], dtype=dtype) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[2, 3, 3, 4], [3, 2, 4, 3]])) + assert out.is_undirected == is_undirected + assert out.sparse_shape == (5, 5) + + out = adj.add(adj) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[0, 2, 2, 4], [2, 0, 4, 2]])) + assert not out.is_undirected + assert out.sparse_shape == (6, 6) + + adj += 2 + assert isinstance(adj, EdgeIndex) + assert adj.equal(Tensor([[2, 3, 3, 4], [3, 2, 4, 3]])) + assert adj.is_undirected == is_undirected + assert adj.sparse_shape == (5, 5) + + +@pytest.mark.parametrize('dtype', DTYPES) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_sub(dtype, is_undirected): + kwargs = dict(dtype=dtype, is_undirected=is_undirected) + adj = EdgeIndex([[4, 5, 5, 6], [5, 4, 6, 5]], sparse_shape=(7, 7), **kwargs) + + out = ops.sub(adj, 2 * 2) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[0, 1, 1, 2], [1, 0, 2, 1]])) + assert out.is_undirected == is_undirected + assert out.sparse_shape == (3, 3) + + out = adj - ms.Tensor([2], dtype=dtype) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[2, 3, 3, 4], [3, 2, 4, 3]])) + assert out.is_undirected == is_undirected + assert out.sparse_shape == (5, 5) + + out = adj - ms.Tensor([[2], [1]], dtype=dtype) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[2, 3, 3, 4], [4, 3, 5, 4]])) + assert not out.is_undirected + assert out.sparse_shape == (5, 6) + + out = adj - ms.Tensor([[2], [2]], dtype=dtype) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[2, 3, 3, 4], [3, 2, 4, 3]])) + assert out.is_undirected == is_undirected + assert out.sparse_shape == (5, 5) + + out = adj.sub(adj) + assert isinstance(out, EdgeIndex) + assert out.equal(Tensor([[0, 0, 0, 0], [0, 0, 0, 0]])) + assert not out.is_undirected + assert out.sparse_shape == (None, None) + + adj -= 2 + assert isinstance(adj, EdgeIndex) + assert adj.equal(Tensor([[2, 3, 3, 4], [3, 2, 4, 3]])) + assert adj.is_undirected == is_undirected + assert adj.sparse_shape == (5, 5) + + +@pytest.mark.parametrize('reduce', ReduceType.__args__) +@pytest.mark.parametrize('transpose', TRANSPOSE) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_torch_sparse_spmm(reduce, transpose, is_undirected): + if is_undirected: + kwargs = dict(is_undirected=True) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + else: + adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]]) + adj = adj.sort_by('col' if transpose else 'row').values + + # Basic: + x = ops.randn(3, 1) + + out = _torch_sparse_spmm(adj, x, None, reduce, transpose) + exp = _scatter_spmm(adj, x, None, reduce, transpose) + assert out.allclose(exp, atol=1e-6) + + # With non-zero values: + x = ops.randn(3, 1) + value = ops.rand(adj.shape[1]) + + out = _torch_sparse_spmm(adj, x, value, reduce, transpose) + exp = _scatter_spmm(adj, x, value, reduce, transpose) + assert out.allclose(exp, atol=1e-6) + + # Gradients w.r.t. other: + x1 = ops.randn(3, 1) + x2 = x1.requires_grad_() + grad = ops.randn_like(x1) + + out = _torch_sparse_spmm(adj, x1, None, reduce, transpose) + out.backward(grad) + exp = _scatter_spmm(adj, x2, None, reduce, transpose) + exp.backward(grad) + assert x1.grad.allclose(x2.grad, atol=1e-6) + + # Gradients w.r.t. value: + x = ops.randn(3, 1) + value1 = ops.rand(adj.shape[1]) + value2 = value1.requires_grad_() + grad = ops.randn_like(x) + + out = _torch_sparse_spmm(adj, x, value1, reduce, transpose) + out.backward(grad) + exp = _scatter_spmm(adj, x, value2, reduce, transpose) + exp.backward(grad) + assert value1.grad.allclose(value2.grad, atol=1e-6) + + +@pytest.mark.parametrize('reduce', ReduceType.__args__) +@pytest.mark.parametrize('transpose', TRANSPOSE) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_torch_spmm(reduce, transpose, is_undirected): + if is_undirected: + kwargs = dict(is_undirected=True) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + else: + adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]]) + adj, perm = adj.sort_by('col' if transpose else 'row') + + # Basic: + x = ops.randn(3, 2) + + if reduce in ['sum', 'add']: + out = _TorchSPMM.apply(adj, x, None, reduce, transpose) + exp = _scatter_spmm(adj, x, None, reduce, transpose) + assert out.allclose(exp) + else: + with pytest.raises(AssertionError): + _TorchSPMM.apply(adj, x, None, reduce, transpose) + + # With non-zero values: + x = ops.randn(3, 1) + value = ops.rand(adj.shape[1]) + + if reduce in ['sum', 'add']: + out = _TorchSPMM.apply(adj, x, value, reduce, transpose) + exp = _scatter_spmm(adj, x, value, reduce, transpose) + assert out.allclose(exp) + else: + with pytest.raises(AssertionError): + _TorchSPMM.apply(adj, x, value, reduce, transpose) + + # Gradients w.r.t. other: + x1 = ops.randn(3, 1, requires_grad=True) + x2 = x1.requires_grad_() + grad = ops.randn_like(x1) + + if reduce in ['sum', 'add']: + out = _TorchSPMM.apply(adj, x1, None, reduce, transpose) + out.backward(grad) + exp = _scatter_spmm(adj, x2, None, reduce, transpose) + exp.backward(grad) + assert x1.grad.allclose(x2.grad) + else: + with pytest.raises(AssertionError): + out = _TorchSPMM.apply(adj, x1, None, reduce, transpose) + out.backward(grad) + + # Gradients w.r.t. value: + x = ops.randn(3, 1) + value1 = ops.rand(adj.shape[1], requires_grad=True) + grad = ops.randn_like(x) + + with pytest.raises((AssertionError, NotImplementedError)): + out = _TorchSPMM.apply(adj, x, value1, reduce, transpose) + out.backward(grad) + + +@withoutExtensions +@pytest.mark.parametrize('reduce', ReduceType.__args__) +@pytest.mark.parametrize('transpose', TRANSPOSE) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_spmm(reduce, transpose, is_undirected): + warnings.filterwarnings('ignore', '.*can be accelerated via.*') + + if is_undirected: + kwargs = dict(is_undirected=True) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + else: + adj = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]]) + adj = adj.sort_by('col' if transpose else 'row').values + + # Basic: + x = ops.randn(3, 1) + + with pytest.raises(ValueError, match="to be sorted by"): + adj.matmul(x, reduce=reduce, transpose=not transpose) + + out = adj.matmul(x, reduce=reduce, transpose=transpose) + exp = _scatter_spmm(adj, x, None, reduce, transpose) + assert out.allclose(exp) + + # With non-zero values: + x = ops.randn(3, 1) + value = ops.rand(adj.shape[1]) + + with pytest.raises(ValueError, match="'other_value' not supported"): + adj.matmul(x, reduce=reduce, other_value=value, transpose=transpose) + + out = adj.matmul(x, value, reduce=reduce, transpose=transpose) + exp = _scatter_spmm(adj, x, value, reduce, transpose) + assert out.allclose(exp) + + # Gradients w.r.t. other: + x1 = ops.randn(3, 1) + x2 = x1.requires_grad_() + grad = ops.randn_like(x1) + + out = adj.matmul(x1, reduce=reduce, transpose=transpose) + out.backward(grad) + exp = _scatter_spmm(adj, x2, None, reduce, transpose) + exp.backward(grad) + assert x1.grad.allclose(x2.grad) + + # Gradients w.r.t. value: + x = ops.randn(3, 1) + value1 = ops.rand(adj.shape[1]) + value2 = value1.requires_grad_() + grad = ops.randn_like(x) + + out = adj.matmul(x, value1, reduce=reduce, transpose=transpose) + out.backward(grad) + exp = _scatter_spmm(adj, x, value2, reduce, transpose) + exp.backward(grad) + assert value1.grad.allclose(value2.grad) + + +@pytest.mark.parametrize('reduce', ReduceType.__args__) +@pytest.mark.parametrize('transpose', TRANSPOSE) +@pytest.mark.parametrize('is_undirected', IS_UNDIRECTED) +def test_spspmm(device, reduce, transpose, is_undirected): + if is_undirected: + kwargs = dict(device=device, sort_order='row', is_undirected=True) + adj1 = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + else: + kwargs = dict(device=device, sort_order='row') + adj1 = EdgeIndex([[0, 1, 1, 2], [2, 0, 1, 2]], **kwargs) + + adj1_dense = adj1.to_dense().t() if transpose else adj1.to_dense() + adj2 = EdgeIndex([[1, 0, 2, 1], [0, 1, 1, 2]], sort_order='col', + device=device) + adj2_dense = adj2.to_dense() + + if reduce in ['sum', 'add']: + out, value = adj1.matmul(adj2, reduce=reduce, transpose=transpose) + assert isinstance(out, EdgeIndex) + assert out.is_sorted_by_row + assert out._sparse_shape == (3, 3) + if not typing.NO_MKL: + assert out._indptr is not None + assert ops.isclose(out.to_dense(value), adj1_dense @ adj2_dense).all() + else: + with pytest.raises(NotImplementedError, match="not yet supported"): + adj1.matmul(adj2, reduce=reduce, transpose=transpose) + + +@withoutExtensions +def test_matmul(): + kwargs = dict(sort_order='row') + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], **kwargs) + x = ops.randn(3, 1) + expected = adj.to_dense() @ x + + out = adj @ x + assert ops.isclose(out, expected).all() + + out = adj.matmul(x) + assert ops.isclose(out, expected).all() + + out = ops.mm(adj, x) + assert ops.isclose(out, expected).all() + + out = ops.matmul(adj, x) + assert ops.isclose(out, expected).all() + + out = torch.sparse.mm(adj, x, reduce='sum') + assert ops.isclose(out, expected).all() + + +def test_sparse_narrow(device): + adj = EdgeIndex( + [[0, 1, 1, 2], [1, 0, 2, 1]], + device=device, + sort_order='row', + ) + + out = adj.sparse_narrow(dim=0, start=1, length=1) + assert out.equal(ms.Tensor([[0, 0], [0, 2]])) + assert out.sparse_shape == (1, None) + assert out.sort_order == 'row' + assert out._indptr.equal(ms.Tensor([0, 2])) + + out = adj.sparse_narrow(dim=0, start=2, length=0) + assert out.equal(ms.Tensor([[], []])) + assert out.sparse_shape == (0, None) + assert out.sort_order == 'row' + assert out._indptr is None + + out = adj.sparse_narrow(dim=1, start=1, length=1) + assert (out.equal(ms.Tensor([[0, 2], [0, 0]])) + or out.equal(ms.Tensor([[2, 0], [0, 0]]))) + assert out.sparse_shape == (3, 1) + assert out.sort_order == 'col' + assert out._indptr.equal(ms.Tensor([0, 2])) + + out = adj.sparse_narrow(dim=1, start=2, length=0) + assert out.equal(ms.Tensor([[], []])) + assert out.sparse_shape == (3, 0) + assert out.sort_order == 'col' + assert out._indptr is None + + +def test_sparse_resize(): + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]]) + + out = adj.sort_by('row')[0].fill_cache_() + assert out.sparse_shape == (3, 3) + assert out._indptr.equal(Tensor([0, 1, 3, 4])) + assert out._T_indptr.equal(Tensor([0, 1, 3, 4])) + out = out.sparse_rehape(4, 5) + assert out.sparse_shape == (4, 5) + assert out._indptr.equal(Tensor([0, 1, 3, 4, 4])) + assert out._T_indptr.equal(Tensor([0, 1, 3, 4, 4, 4])) + out = out.sparse_rehape(3, 3) + assert out.sparse_shape == (3, 3) + assert out._indptr is None + assert out._T_indptr is None + + out = adj.sort_by('col')[0].fill_cache_() + assert out.sparse_shape == (3, 3) + assert out._indptr.equal(Tensor([0, 1, 3, 4])) + assert out._T_indptr.equal(Tensor([0, 1, 3, 4])) + out = out.sparse_rehape(4, 5) + assert out.sparse_shape == (4, 5) + assert out._indptr.equal(Tensor([0, 1, 3, 4, 4, 4])) + assert out._T_indptr.equal(Tensor([0, 1, 3, 4, 4])) + out = out.sparse_rehape(3, 3) + assert out.sparse_shape == (3, 3) + assert out._indptr is None + assert out._T_indptr is None + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_save_and_load(dtype, tmp_path): + kwargs = dict(dtype=dtype) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + adj.fill_cache_() + + assert adj.sort_order == 'row' + assert adj._indptr is not None + import numpy as np + path = osp.join(tmp_path, 'edge_index.pt') + np.save(adj, path) + out = np.load(path) + + assert isinstance(out, EdgeIndex) + assert out.equal(adj) + assert out.sort_order == 'row' + assert out._indptr.equal(adj._indptr) + + +def _collate_fn(edge_indices: List[EdgeIndex]) -> List[EdgeIndex]: + return edge_indices + + +@pytest.mark.parametrize('dtype', DTYPES) +def test_data_loader(dtype): + kwargs = dict(dtype=dtype) + adj = EdgeIndex([[0, 1, 1, 2], [1, 0, 2, 1]], sort_order='row', **kwargs) + adj.fill_cache_() + + loader = DataLoader( + [adj] * 4, + batch_size=2, + collate_fn=_collate_fn, + drop_last=True, + ) + + assert len(loader) == 2 + for batch in loader: + assert isinstance(batch, list) + assert len(batch) == 2 + for adj in batch: + assert isinstance(adj, EdgeIndex) + assert adj.dtype == adj.dtype + assert adj.is_shared() == True + assert adj._indptr.is_shared() == True + + +def test_torch_script(): + class Model(nn.Cell): + def construct(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: + row, col = edge_index[0], edge_index[1] + x_j = x[row] + out = scatter(x_j, col, dim_size=edge_index.num_cols) + return out + + x = ops.randn(3, 8) + # Test that `num_cols` gets picked up by making last node isolated. + edge_index = EdgeIndex([[0, 1, 1, 2], [1, 0, 0, 1]], sparse_shape=(3, 3)) + + model = Model() + expected = model(x, edge_index) + assert expected.shape == (3, 8) + + # `ms.jit` does not support inheritance at the `Tensor` level :( + with pytest.raises(RuntimeError, match="attribute or method 'num_cols'"): + ms.jit(model) + + # A valid workaround is to treat `EdgeIndex` as a regular PyTorch tensor + # whenever we are in script mode: + class ScriptableModel(nn.Cell): + def construct(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: + row, col = edge_index[0], edge_index[1] + x_j = x[row] + dim_size: Optional[int] = None + if isinstance(edge_index, EdgeIndex): + dim_size = edge_index.num_cols + out = scatter(x_j, col, dim_size=dim_size) + return out + + script_model = ms.jit(ScriptableModel()) + out = script_model(x, edge_index) + assert out.shape == (2, 8) + assert ops.isclose(out, expected[:2]).all() + + +# @onlyLinux +# @withPackage('torch==2.2.0') # TODO Make it work on nightly. +# def test_compile(): +# import torch._dynamo as dynamo + +# class Model(nn.Cell): +# def construct(self, x: Tensor, edge_index: EdgeIndex) -> Tensor: +# x_j = x[edge_index[0]] +# out = scatter(x_j, edge_index[1], dim_size=edge_index.num_cols) +# return out + +# x = ops.randn(3, 8) +# # Test that `num_cols` gets picked up by making last node isolated. +# edge_index = EdgeIndex( +# [[0, 1, 1, 2], [1, 0, 0, 1]], +# sparse_shape=(3, 3), +# sort_order='row', +# ).fill_cache_() + +# model = Model() +# expected = model(x, edge_index) +# assert expected.shape == (3, 8) + +# explanation = dynamo.explain(model)(x, edge_index) +# assert explanation.graph_break_count == 0 + +# compiled_model = torch.compile(model, fullgraph=True) +# out = compiled_model(x, edge_index) +# assert ops.isclose(out, expected).all() + + +# @onlyLinux +# @withPackage('torch==2.2.0') # TODO Make it work on nightly. +# def test_compile_create_edge_index(): +# import torch._dynamo as dynamo + +# class Model(nn.Cell): +# def construct(self) -> None: +# # TODO Add more tests once closed: +# # https://github.com/pytorch/pytorch/issues/117806 +# out = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) +# out.as_subclass(EdgeIndex) +# return + +# model = Model() + +# explanation = dynamo.explain(model)() +# assert explanation.graph_break_count == 0 + +# compiled_model = torch.compile(model, fullgraph=True) +# assert compiled_model() is None + + +# if __name__ == '__main__': +# import argparse + +# warnings.filterwarnings('ignore', ".*Sparse CSR tensor support.*") + +# parser = argparse.ArgumentParser() +# parser.add_argument('--device', type=str, default='cuda') +# parser.add_argument('--backward', action='store_true') +# args = parser.parse_args() + +# channels = 128 +# num_nodes = 20_000 +# num_edges = 200_000 + +# x = ops.randn(num_nodes, channels) +# edge_index = EdgeIndex( +# ops.randint(0, num_nodes, size=(2, num_edges)), +# sparse_shape=(num_nodes, num_nodes), +# ).sort_by('row')[0] +# edge_index.fill_cache_() +# adj1 = edge_index.to_sparse_csr() +# adj2 = SparseTensor( +# row=edge_index[0], +# col=edge_index[1], +# sparse_shape=(num_nodes, num_nodes), +# ) + +# def edge_index_mm(edge_index, x, reduce): +# return edge_index.matmul(x, reduce=reduce) + +# def torch_sparse_mm(adj, x): +# return adj @ x + +# def sparse_tensor_mm(adj, x, reduce): +# return adj.matmul(x, reduce=reduce) + +# def scatter_mm(edge_index, x, reduce): +# return _scatter_spmm(edge_index, x, reduce=reduce) + +# funcs = [edge_index_mm, torch_sparse_mm, sparse_tensor_mm, scatter_mm] +# func_names = ['edge_index', 'torch.sparse', 'SparseTensor', 'scatter'] + +# for reduce in ['sum', 'mean', 'amin', 'amax']: +# func_args = [(edge_index, x, reduce), (adj1, x), (adj2, x, reduce), +# (edge_index, x, reduce)] +# print(f"reduce='{reduce}':") + +# benchmark( +# funcs=funcs, +# func_names=func_names, +# args=func_args, +# num_steps=100 if args.device == 'cpu' else 1000, +# num_warmups=50 if args.device == 'cpu' else 500, +# backward=args.backward, +# ) diff --git a/tests/graph/test_experimental.py b/tests/graph/test_experimental.py new file mode 100644 index 000000000..1174f003c --- /dev/null +++ b/tests/graph/test_experimental.py @@ -0,0 +1,28 @@ +import pytest + +from mindscience.sharker import ( + experimental_mode, + is_experimental_mode_enabled, + set_experimental_mode, +) + + +@pytest.mark.parametrize('options', ['disable_dynamic_shapes']) +def test_experimental_mode(options): + assert is_experimental_mode_enabled(options) is False + with experimental_mode(options): + assert is_experimental_mode_enabled(options) is True + assert is_experimental_mode_enabled(options) is False + + with set_experimental_mode(True, options): + assert is_experimental_mode_enabled(options) is True + assert is_experimental_mode_enabled(options) is False + + with set_experimental_mode(False, options): + assert is_experimental_mode_enabled(options) is False + assert is_experimental_mode_enabled(options) is False + + set_experimental_mode(True, options) + assert is_experimental_mode_enabled(options) is True + set_experimental_mode(False, options) + assert is_experimental_mode_enabled(options) is False diff --git a/tests/graph/test_home.py b/tests/graph/test_home.py new file mode 100644 index 000000000..2dee1d832 --- /dev/null +++ b/tests/graph/test_home.py @@ -0,0 +1,19 @@ +import os +import os.path as osp + +from mindscience.sharker import get_home_dir, set_home_dir +from mindscience.sharker.home import DEFAULT_CACHE_DIR + + +def test_home(): + os.environ.pop('SHARKER_HOME', None) + home_dir = osp.expanduser(DEFAULT_CACHE_DIR) + assert get_home_dir() == home_dir + + home_dir = '/tmp/test_pyg1' + os.environ['SHARKER_HOME'] = home_dir + assert get_home_dir() == home_dir + + home_dir = '/tmp/test_pyg2' + set_home_dir(home_dir) + assert get_home_dir() == home_dir diff --git a/tests/graph/test_inspector.py b/tests/graph/test_inspector.py new file mode 100644 index 000000000..ef2aae331 --- /dev/null +++ b/tests/graph/test_inspector.py @@ -0,0 +1,138 @@ +import inspect +from typing import Any, Dict, Final, List, Optional, Set, Tuple, Union + +import mindspore as ms +from mindspore import Tensor, ops, nn + +from mindscience.sharker.inspector import Inspector, Parameter, Signature +from mindscience.sharker.nn import GATConv, SAGEConv +from mindscience.sharker.typing import OptPairTensor + + +def test_eval_type() -> None: + inspector = Inspector(SAGEConv) + + assert inspector.eval_type('Tensor') == Tensor + assert inspector.eval_type('List[Tensor]') == List[Tensor] + assert inspector.eval_type('Tuple[Tensor, int]') == Tuple[Tensor, int] + assert inspector.eval_type('Tuple[int, ...]') == Tuple[int, ...] + + +def test_type_repr() -> None: + inspector = Inspector(SAGEConv) + + assert inspector.type_repr(Any) == 'typing.Any' + assert inspector.type_repr(Final) == 'typing.Final' + assert inspector.type_repr(OptPairTensor) == ( + 'Tuple[Tensor, Optional[Tensor]]') + assert inspector.type_repr( + Final[Optional[Tensor]]) == ('typing.Final[Optional[Tensor]]') + assert inspector.type_repr(Union[None, Tensor]) == 'Optional[Tensor]' + assert inspector.type_repr(Optional[Tensor]) == 'Optional[Tensor]' + assert inspector.type_repr(Set[Tensor]) == 'typing.Set[Tensor]' + assert inspector.type_repr(List) == 'List' + assert inspector.type_repr(Tuple) == 'Tuple' + assert inspector.type_repr(Set) == 'typing.Set' + assert inspector.type_repr(Dict) == 'typing.Dict' + assert inspector.type_repr(Dict[str, Tuple[Tensor, Tensor]]) == ( # + 'typing.Dict[str, Tuple[Tensor, Tensor]]') + assert inspector.type_repr(Tuple[int, ...]) == 'Tuple[int, ...]' + assert inspector.type_repr(Union[int, str, None]) == ( # + 'Union[int, str, None]') + + +def test_inspector_sage_conv() -> None: + inspector = Inspector(SAGEConv) + assert str(inspector) == 'Inspector(SAGEConv)' + assert inspector.implements('message') + assert inspector.implements('message_and_aggregate') + + out = inspector.inspect_signature(SAGEConv.message) + assert isinstance(out, Signature) + assert out.param_dict == { + 'x_j': Parameter('x_j', Tensor, 'Tensor', inspect._empty) + } + assert out.return_type == Tensor + assert inspector.get_flat_params(['message', 'message']) == [ + Parameter('x_j', Tensor, 'Tensor', inspect._empty), + ] + assert inspector.get_flat_param_names(['message']) == ['x_j'] + + kwargs = {'x_j': ops.randn(5), 'x_i': ops.randn(5)} + data = inspector.collect_param_data('message', kwargs) + assert len(data) == 1 + assert ops.isclose(data['x_j'], kwargs['x_j']).all() + + assert inspector.get_params_from_method_call(SAGEConv.propagate) == { + 'x': Parameter('x', OptPairTensor, 'OptPairTensor', inspect._empty), + } + + +def test_inspector_gat_conv() -> None: + inspector = Inspector(GATConv) + assert str(inspector) == 'Inspector(GATConv)' + assert inspector.implements('message') + assert not inspector.implements('message_and_aggregate') + + out = inspector.inspect_signature(GATConv.message) + assert isinstance(out, Signature) + assert out.param_dict == { + 'x_j': Parameter('x_j', Tensor, 'Tensor', inspect._empty), + 'alpha': Parameter('alpha', Tensor, 'Tensor', inspect._empty), + } + assert out.return_type == Tensor + assert inspector.get_flat_params(['message', 'message']) == [ + Parameter('x_j', Tensor, 'Tensor', inspect._empty), + Parameter('alpha', Tensor, 'Tensor', inspect._empty), + ] + assert inspector.get_flat_param_names(['message']) == ['x_j', 'alpha'] + + kwargs = {'x_j': ops.randn(5), 'alpha': ops.randn(5)} + data = inspector.collect_param_data('message', kwargs) + assert len(data) == 2 + assert ops.isclose(data['x_j'], kwargs['x_j']).all() + assert ops.isclose(data['alpha'], kwargs['alpha']).all() + + assert inspector.get_params_from_method_call(SAGEConv.propagate) == { + 'x': Parameter('x', OptPairTensor, 'OptPairTensor', inspect._empty), + 'alpha': Parameter('alpha', Tensor, 'Tensor', inspect._empty), + } + + +def test_get_params_from_method_call() -> None: + class FromMethodCall1: + propagate_type = {'x': Tensor} + + inspector = Inspector(FromMethodCall1) + assert inspector.get_params_from_method_call('propagate') == { + 'x': Parameter('x', Tensor, 'Tensor', inspect._empty), + } + + class FromMethodCall2: + # propagate_type: (x: Tensor) + pass + + inspector = Inspector(FromMethodCall2) + assert inspector.get_params_from_method_call('propagate') == { + 'x': Parameter('x', Tensor, 'Tensor', inspect._empty), + } + + class FromMethodCall3: + def construct(self) -> None: + self.propagate( # type: ignore + ops.randn(5, 5), + x=None, + size=None, + ) + + inspector = Inspector(FromMethodCall3) + exclude = [0, 'size'] + assert inspector.get_params_from_method_call('propagate', exclude) == { + 'x': Parameter('x', Tensor, 'Tensor', inspect._empty), + } + + class FromMethodCall4: + pass + + inspector = Inspector(FromMethodCall4) + assert inspector.get_params_from_method_call('propagate') == {} diff --git a/tests/graph/test_schnet.py b/tests/graph/test_schnet.py new file mode 100644 index 000000000..2ccc545f3 --- /dev/null +++ b/tests/graph/test_schnet.py @@ -0,0 +1,37 @@ +from mindspore import ops + +from mindscience.sharker.data import Batch, Graph +from mindscience.sharker.nn.models.schnet import SchNet +from mindscience.sharker.nn.models.schnet import RadiusInteractionGraph +from mindscience.sharker.loader import DataLoader + + +def qm9_train(use_interaction_graph=None, use_atomref=None): + from mindscience.sharker.datasets import QM9 + interaction_graph = RadiusInteractionGraph(cutoff=6.0) + + model = SchNet( + hidden_channels=16, + num_filters=16, + num_interactions=2, + interaction_graph=interaction_graph, + num_gaussians=10, + cutoff=6.0, + dipole=True, + atomref=ops.randn(100, 1) if use_atomref else None, + ) + + assert str(model) == ( + "SchNet(hidden_channels=16, num_filters=16, " + "num_interactions=2, num_gaussians=10, cutoff=6.0)" + ) + + dataset = QM9(root="/home/liuxh/db") + loader = DataLoader(dataset, 16, shuffle=True) + for batch in loader: + out = model(batch.z, batch.crd) + assert out.shape == (1, 1) + + +if __name__ == "__main__": + qm9_train() diff --git a/tests/graph/test_seed.py b/tests/graph/test_seed.py new file mode 100644 index 000000000..2c4baf66d --- /dev/null +++ b/tests/graph/test_seed.py @@ -0,0 +1,16 @@ +import random + +import numpy as np +from mindspore import ops +from mindscience.sharker import seed_everything + + +def test_seed_everything(): + seed_everything(0) + + assert random.randint(0, 100) == 49 + assert random.randint(0, 100) == 97 + assert np.random.randint(0, 100) == 44 + assert np.random.randint(0, 100) == 47 + assert ops.randint(0, 100, (1, ))[0] == 21 + assert ops.randint(0, 100, (1, ))[0] == 12 diff --git a/tests/graph/test_typing.py b/tests/graph/test_typing.py new file mode 100644 index 000000000..7355c8f81 --- /dev/null +++ b/tests/graph/test_typing.py @@ -0,0 +1,36 @@ +import pytest + +from mindscience.sharker.typing import EdgeTypeStr + + +def test_edge_type_str(): + edge_type_str = EdgeTypeStr('a__links__b') + assert isinstance(edge_type_str, str) + assert edge_type_str == 'a__links__b' + assert edge_type_str.to_tuple() == ('a', 'links', 'b') + + edge_type_str = EdgeTypeStr('a', 'b') + assert isinstance(edge_type_str, str) + assert edge_type_str == 'a__to__b' + assert edge_type_str.to_tuple() == ('a', 'to', 'b') + + edge_type_str = EdgeTypeStr(('a', 'b')) + assert isinstance(edge_type_str, str) + assert edge_type_str == 'a__to__b' + assert edge_type_str.to_tuple() == ('a', 'to', 'b') + + edge_type_str = EdgeTypeStr('a', 'links', 'b') + assert isinstance(edge_type_str, str) + assert edge_type_str == 'a__links__b' + assert edge_type_str.to_tuple() == ('a', 'links', 'b') + + edge_type_str = EdgeTypeStr(('a', 'links', 'b')) + assert isinstance(edge_type_str, str) + assert edge_type_str == 'a__links__b' + assert edge_type_str.to_tuple() == ('a', 'links', 'b') + + with pytest.raises(ValueError, match="invalid edge type"): + EdgeTypeStr('a', 'b', 'c', 'd') + + with pytest.raises(ValueError, match="Cannot convert the edge type"): + EdgeTypeStr('a__b__c__d').to_tuple() diff --git a/tests/graph/transforms/test_add_metapaths.py b/tests/graph/transforms/test_add_metapaths.py new file mode 100644 index 000000000..ce6e145b6 --- /dev/null +++ b/tests/graph/transforms/test_add_metapaths.py @@ -0,0 +1,232 @@ +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.data import HeteroGraph +from mindscience.sharker.transforms import AddMetaPaths, AddRandomMetaPaths +from mindscience.sharker.utils import coalesce +from mindscience.sharker.seed import seed_everything + + +def generate_data() -> HeteroGraph: + data = HeteroGraph() + data['p'].x = ops.ones(5) + data['a'].x = ops.ones(6) + data['c'].x = ops.ones(3) + data['p', 'p'].edge_index = Tensor([[0, 1, 2, 3], [1, 2, 4, 2]]) + data['p', 'a'].edge_index = Tensor([[0, 1, 2, 3, 4], [2, 2, 5, 2, 5]]) + data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0]) + data['c', 'p'].edge_index = Tensor([[0, 0, 1, 2, 2], [0, 1, 2, 3, 4]]) + data['p', 'c'].edge_index = data['c', 'p'].edge_index.flip([0]) + return data + + +def test_add_metapaths() -> None: + data = generate_data() + # Test transform options: + metapaths = [[('p', 'c'), ('c', 'p')]] + + transform = AddMetaPaths(metapaths) + assert str(transform) == 'AddMetaPaths()' + meta1 = transform(data) + + transform = AddMetaPaths(metapaths, drop_orig_edge_types=True) + assert str(transform) == 'AddMetaPaths()' + meta2 = transform(data) + + transform = AddMetaPaths(metapaths, drop_orig_edge_types=True, + keep_same_node_type=True) + assert str(transform) == 'AddMetaPaths()' + meta3 = transform(data) + + transform = AddMetaPaths(metapaths, drop_orig_edge_types=True, + keep_same_node_type=True, + drop_unconnected_node_types=True) + assert str(transform) == 'AddMetaPaths()' + meta4 = transform(data) + + assert meta1['metapath_0'].edge_index.shape == (2, 9) + assert meta2['metapath_0'].edge_index.shape == (2, 9) + assert meta3['metapath_0'].edge_index.shape == (2, 9) + assert meta4['metapath_0'].edge_index.shape == (2, 9) + + assert all([i in meta1.edge_types for i in data.edge_types]) + assert meta2.edge_types == [('p', 'metapath_0', 'p')] + assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')] + assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')] + + assert meta3.node_types == ['p', 'a', 'c'] + assert meta4.node_types == ['p'] + + # Test 4-hop metapath: + metapaths = [ + [('a', 'p'), ('p', 'c')], + [('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')], + ] + transform = AddMetaPaths(metapaths) + meta = transform(data) + new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')] + assert meta['metapath_0'].edge_index.shape == (2, 4) + assert meta['metapath_1'].edge_index.shape == (2, 4) + + # Test `metapath_dict` information: + assert list(meta.metapath_dict.values()) == metapaths + assert list(meta.metapath_dict.keys()) == new_edge_types + + +def test_add_metapaths_max_sample() -> None: + seed_everything(12345) + + data = generate_data() + + metapaths = [[('p', 'c'), ('c', 'p')]] + transform = AddMetaPaths(metapaths, max_sample=1) + + meta = transform(data) + assert meta['metapath_0'].edge_index.shape[1] < 9 + + +def test_add_weighted_metapaths() -> None: + seed_everything(12345) + + data = HeteroGraph() + data['a'].num_nodes = 2 + data['b'].num_nodes = 3 + data['c'].num_nodes = 2 + data['d'].num_nodes = 2 + data['a', 'b'].edge_index = Tensor([[0, 1, 1], [0, 1, 2]]) + data['b', 'a'].edge_index = data['a', 'b'].edge_index.flip([0]) + data['b', 'c'].edge_index = Tensor([[0, 1, 2], [0, 1, 1]]) + data['c', 'b'].edge_index = data['b', 'c'].edge_index.flip([0]) + data['c', 'd'].edge_index = Tensor([[0, 1], [0, 0]]) + data['d', 'c'].edge_index = data['c', 'd'].edge_index.flip([0]) + + metapaths = [ + [('a', 'b'), ('b', 'c')], + [('a', 'b'), ('b', 'c'), ('c', 'd')], + [('a', 'b'), ('b', 'c'), ('c', 'd'), ('d', 'c'), ('c', 'b'), + ('b', 'a')], + ] + transform = AddMetaPaths(metapaths, weighted=True) + out = transform(data) + + # Make sure manually added metapaths compute the correct number of edges: + edge_index = out['a', 'a'].edge_index + edge_weight = out['a', 'a'].edge_weight + edge_index, edge_weight = coalesce(edge_index, edge_weight) + assert edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]] + assert edge_weight.tolist() == [1, 2, 2, 4] + + edge_index = out['a', 'c'].edge_index + edge_weight = out['a', 'c'].edge_weight + edge_index, edge_weight = coalesce(edge_index, edge_weight) + assert edge_index.tolist() == [[0, 1], [0, 1]] + assert edge_weight.tolist() == [1, 2] + + edge_index = out['a', 'd'].edge_index + edge_weight = out['a', 'd'].edge_weight + edge_index, edge_weight = coalesce(edge_index, edge_weight) + assert edge_index.tolist() == [[0, 1], [0, 0]] + assert edge_weight.tolist() == [1, 2] + + # Compute intra-table metapaths efficiently: + metapaths = [[('a', 'b'), ('b', 'c'), ('c', 'd')]] + out = AddMetaPaths(metapaths, weighted=True)(data) + out['d', 'a'].edge_index = out['a', 'd'].edge_index.flip([0]) + out['d', 'a'].edge_weight = out['a', 'd'].edge_weight + metapaths = [[('a', 'd'), ('d', 'a')]] + out = AddMetaPaths(metapaths, weighted=True)(out) + + edge_index = out['a', 'a'].edge_index + edge_weight = out['a', 'a'].edge_weight + edge_index, edge_weight = coalesce(edge_index, edge_weight) + assert edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]] + assert edge_weight.tolist() == [1, 2, 2, 4] + + +def test_add_random_metapaths() -> None: + data = generate_data() + + # Test transform options: + metapaths = [[('p', 'c'), ('c', 'p')]] + seed_everything(12345) + + transform = AddRandomMetaPaths(metapaths) + assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, ' + 'walks_per_node=[1])') + meta1 = transform(data) + + transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True) + assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, ' + 'walks_per_node=[1])') + meta2 = transform(data) + + transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True, + keep_same_node_type=True) + assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, ' + 'walks_per_node=[1])') + meta3 = transform(data) + + transform = AddRandomMetaPaths(metapaths, drop_orig_edge_types=True, + keep_same_node_type=True, + drop_unconnected_node_types=True) + assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, ' + 'walks_per_node=[1])') + meta4 = transform(data) + + transform = AddRandomMetaPaths(metapaths, sample_ratio=0.8, + drop_orig_edge_types=True, + keep_same_node_type=True, + drop_unconnected_node_types=True) + assert str(transform) == ('AddRandomMetaPaths(sample_ratio=0.8, ' + 'walks_per_node=[1])') + meta5 = transform(data) + + transform = AddRandomMetaPaths(metapaths, walks_per_node=5, + drop_orig_edge_types=True, + keep_same_node_type=True, + drop_unconnected_node_types=True) + assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, ' + 'walks_per_node=[5])') + meta6 = transform(data) + + assert meta1['metapath_0'].edge_index.shape == (2, 5) + assert meta2['metapath_0'].edge_index.shape == (2, 5) + assert meta3['metapath_0'].edge_index.shape == (2, 5) + assert meta4['metapath_0'].edge_index.shape == (2, 5) + assert meta5['metapath_0'].edge_index.shape == (2, 4) + assert meta6['metapath_0'].edge_index.shape == (2, 9) + + assert all([i in meta1.edge_types for i in data.edge_types]) + assert meta2.edge_types == [('p', 'metapath_0', 'p')] + assert meta3.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')] + assert meta4.edge_types == [('p', 'to', 'p'), ('p', 'metapath_0', 'p')] + + assert meta3.node_types == ['p', 'a', 'c'] + assert meta4.node_types == ['p'] + + # Test 4-hop metapath: + metapaths = [ + [('a', 'p'), ('p', 'c')], + [('a', 'p'), ('p', 'c'), ('c', 'p'), ('p', 'a')], + ] + transform = AddRandomMetaPaths(metapaths) + assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, ' + 'walks_per_node=[1, 1])') + + meta1 = transform(data) + new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')] + assert meta1['metapath_0'].edge_index.shape == (2, 2) + assert meta1['metapath_1'].edge_index.shape == (2, 2) + + # Test `metapath_dict` information: + assert list(meta1.metapath_dict.values()) == metapaths + assert list(meta1.metapath_dict.keys()) == new_edge_types + + transform = AddRandomMetaPaths(metapaths, walks_per_node=[2, 5]) + assert str(transform) == ('AddRandomMetaPaths(sample_ratio=1.0, ' + 'walks_per_node=[2, 5])') + + meta2 = transform(data) + new_edge_types = [('a', 'metapath_0', 'c'), ('a', 'metapath_1', 'a')] + assert meta2['metapath_0'].edge_index.shape == (2, 3) + assert meta2['metapath_1'].edge_index.shape == (2, 3) diff --git a/tests/graph/transforms/test_add_positional_encoding.py b/tests/graph/transforms/test_add_positional_encoding.py new file mode 100644 index 000000000..4a3091572 --- /dev/null +++ b/tests/graph/transforms/test_add_positional_encoding.py @@ -0,0 +1,114 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import ( + AddLaplacianEigenvectorPE, + AddRandomWalkPE, +) + + +def test_add_laplacian_eigenvector_pe(): + x = ops.randn(6, 4) + edge_index = Tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5], + [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]]) + data = Graph(x=x, edge_index=edge_index) + + transform = AddLaplacianEigenvectorPE(k=3) + assert str(transform) == 'AddLaplacianEigenvectorPE()' + out = transform(data) + assert out.laplacian_eigenvector_pe.shape == (6, 3) + + transform = AddLaplacianEigenvectorPE(k=3, attr_name=None) + out = transform(data) + assert out.x.shape == (6, 4 + 3) + + transform = AddLaplacianEigenvectorPE(k=3, attr_name='x') + out = transform(data) + assert out.x.shape == (6, 3) + + # Output tests: + edge_index = Tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5, 2, 5], + [1, 0, 4, 0, 4, 1, 3, 2, 5, 3, 5, 2]]) + data = Graph(x=x, edge_index=edge_index) + + transform1 = AddLaplacianEigenvectorPE(k=1, is_undirected=True) + transform2 = AddLaplacianEigenvectorPE(k=1, is_undirected=False) + + # Clustering test with first non-trivial eigenvector (Fiedler vector) + pe = transform1(data).laplacian_eigenvector_pe + pe_cluster_1 = pe[[0, 1, 4]] + pe_cluster_2 = pe[[2, 3, 5]] + assert not ops.isclose(pe_cluster_1, pe_cluster_2).all() + assert ops.isclose(pe_cluster_1, pe_cluster_1.mean()).all() + assert ops.isclose(pe_cluster_2, pe_cluster_2.mean()).all() + + pe = transform2(data).laplacian_eigenvector_pe + pe_cluster_1 = pe[[0, 1, 4]] + pe_cluster_2 = pe[[2, 3, 5]] + assert not ops.isclose(pe_cluster_1, pe_cluster_2).all() + assert ops.isclose(pe_cluster_1, pe_cluster_1.mean()).all() + assert ops.isclose(pe_cluster_2, pe_cluster_2.mean()).all() + + +def test_eigenvector_permutation_invariance(): + edge_index = Tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5], + [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]]) + data = Graph(edge_index=edge_index, num_nodes=6) + + perm = ops.shuffle(ops.arange(data.num_nodes)) + transform = AddLaplacianEigenvectorPE( + k=2, + is_undirected=True, + attr_name='x', + ) + out1 = transform(data) + + transform = AddLaplacianEigenvectorPE( + k=2, + is_undirected=True, + attr_name='x', + ) + out2 = transform(data.subgraph(perm)) + + assert ops.isclose(out1.x[perm].abs(), out2.x.abs(), atol=1e-6).all() + + +def test_add_random_walk_pe(): + x = ops.randn(6, 4) + edge_index = Tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5], + [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]]) + data = Graph(x=x, edge_index=edge_index) + + transform = AddRandomWalkPE(walk_length=3) + assert str(transform) == 'AddRandomWalkPE()' + out = transform(data) + assert out.random_walk_pe.shape == (6, 3) + + transform = AddRandomWalkPE(walk_length=3, attr_name=None) + out = transform(data) + assert out.x.shape == (6, 4 + 3) + + transform = AddRandomWalkPE(walk_length=3, attr_name='x') + out = transform(data) + assert out.x.shape == (6, 3) + + # Output tests: + assert out.x.tolist() == [ + [0.0, 0.5, 0.25], + [0.0, 0.5, 0.25], + [0.0, 0.5, 0.00], + [0.0, 1.0, 0.00], + [0.0, 0.5, 0.25], + [0.0, 0.5, 0.00], + ] + + edge_index = Tensor([[0, 1, 2], [0, 1, 2]]) + data = Graph(edge_index=edge_index, num_nodes=4) + out = transform(data) + + assert out.x.tolist() == [ + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [1.0, 1.0, 1.0], + [0.0, 0.0, 0.0], + ] diff --git a/tests/graph/transforms/test_add_remaining_self_loops.py b/tests/graph/transforms/test_add_remaining_self_loops.py new file mode 100644 index 000000000..0115fcff3 --- /dev/null +++ b/tests/graph/transforms/test_add_remaining_self_loops.py @@ -0,0 +1,72 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import AddRemainingSelfLoops + + +def test_add_remaining_self_loops(): + assert str(AddRemainingSelfLoops()) == 'AddRemainingSelfLoops()' + + assert len(AddRemainingSelfLoops()(Graph())) == 0 + + # No self-loops in `edge_index`. + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight = Tensor([1, 2, 3, 4]) + edge_attr = Tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) + + data = Graph(edge_index=edge_index, num_nodes=3) + data = AddRemainingSelfLoops()(data) + assert len(data) == 2 + assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2], + [1, 0, 2, 1, 0, 1, 2]] + assert data.num_nodes == 3 + + # Single self-loop in `edge_index`. + edge_index = Tensor([[0, 0, 1, 2], [1, 0, 2, 1]]) + data = Graph(edge_index=edge_index, num_nodes=3) + data = AddRemainingSelfLoops()(data) + assert len(data) == 2 + assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2], [1, 2, 1, 0, 1, 2]] + assert data.num_nodes == 3 + + data = Graph(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3) + data = AddRemainingSelfLoops(attr='edge_weight', fill_value=5)(data) + assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2], [1, 2, 1, 0, 1, 2]] + assert data.num_nodes == 3 + assert data.edge_weight.tolist() == [1, 3, 4, 2, 5, 5] + + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) + data = AddRemainingSelfLoops(attr='edge_attr', fill_value='add')(data) + assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2], [1, 2, 1, 0, 1, 2]] + assert data.num_nodes == 3 + assert data.edge_attr.tolist() == [[1, 2], [5, 6], [7, 8], [3, 4], [8, 10], + [5, 6]] + + +def test_add_remaining_self_loops_all_loops_exist(): + # All self-loops already exist in the data object. + edge_index = Tensor([[0, 1, 2], [0, 1, 2]]) + data = Graph(edge_index=edge_index, num_nodes=3) + data = AddRemainingSelfLoops()(data) + assert data.edge_index.tolist() == edge_index.tolist() + + # All self-loops already exist in the data object, some of them appear + # multiple times. + edge_index = Tensor([[0, 0, 1, 1, 2], [0, 0, 1, 1, 2]]) + data = Graph(edge_index=edge_index, num_nodes=3) + data = AddRemainingSelfLoops()(data) + assert data.edge_index.tolist() == [[0, 1, 2], [0, 1, 2]] + + +def test_hetero_add_remaining_self_loops(): + edge_index = Tensor([[0, 0, 1, 2], [1, 0, 2, 1]]) + + data = HeteroGraph() + data['v'].num_nodes = 3 + data['w'].num_nodes = 3 + data['v', 'v'].edge_index = edge_index + data['v', 'w'].edge_index = edge_index + data = AddRemainingSelfLoops()(data) + assert data['v', 'v'].edge_index.tolist() == [[0, 1, 2, 0, 1, 2], + [1, 2, 1, 0, 1, 2]] + assert data['v', 'w'].edge_index.tolist() == edge_index.tolist() diff --git a/tests/graph/transforms/test_add_self_loops.py b/tests/graph/transforms/test_add_self_loops.py new file mode 100644 index 000000000..c7df87814 --- /dev/null +++ b/tests/graph/transforms/test_add_self_loops.py @@ -0,0 +1,58 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import AddSelfLoops + + +def test_add_self_loops(): + assert str(AddSelfLoops()) == 'AddSelfLoops()' + + assert len(AddSelfLoops()(Graph())) == 0 + + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight = Tensor([1, 2, 3, 4]) + edge_attr = Tensor([[1, 2], [3, 4], [5, 6], [7, 8]]) + + data = Graph(edge_index=edge_index, num_nodes=3) + data = AddSelfLoops()(data) + assert len(data) == 2 + assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2], + [1, 0, 2, 1, 0, 1, 2]] + assert data.num_nodes == 3 + + data = Graph(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3) + data = AddSelfLoops(attr='edge_weight', fill_value=5)(data) + assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2], + [1, 0, 2, 1, 0, 1, 2]] + assert data.num_nodes == 3 + assert data.edge_weight.tolist() == [1, 2, 3, 4, 5, 5, 5] + + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) + data = AddSelfLoops(attr='edge_attr', fill_value='add')(data) + assert data.edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2], + [1, 0, 2, 1, 0, 1, 2]] + assert data.num_nodes == 3 + assert data.edge_attr.tolist() == [[1, 2], [3, 4], [5, 6], [7, 8], [3, 4], + [8, 10], [5, 6]] + + +def test_add_self_loops_with_existing_self_loops(): + edge_index = Tensor([[0, 1, 2], [0, 1, 2]]) + data = Graph(edge_index=edge_index, num_nodes=3) + data = AddSelfLoops()(data) + assert data.edge_index.tolist() == [[0, 1, 2, 0, 1, 2], [0, 1, 2, 0, 1, 2]] + assert data.num_nodes == 3 + + +def test_hetero_add_self_loops(): + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + data = HeteroGraph() + data['v'].num_nodes = 3 + data['w'].num_nodes = 3 + data['v', 'v'].edge_index = edge_index + data['v', 'w'].edge_index = edge_index + data = AddSelfLoops()(data) + assert data['v', 'v'].edge_index.tolist() == [[0, 1, 1, 2, 0, 1, 2], + [1, 0, 2, 1, 0, 1, 2]] + assert data['v', 'w'].edge_index.tolist() == edge_index.tolist() diff --git a/tests/graph/transforms/test_cartesian.py b/tests/graph/transforms/test_cartesian.py new file mode 100644 index 000000000..5f0174b1b --- /dev/null +++ b/tests/graph/transforms/test_cartesian.py @@ -0,0 +1,37 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import Cartesian + + +def test_cartesian(): + assert str(Cartesian()) == 'Cartesian(norm=True, max_value=None)' + + pos = Tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]]) + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_attr = Tensor([1.0, 2.0, 3.0, 4.0]) + + data = Graph(edge_index=edge_index, crd=pos) + data = Cartesian(norm=False)(data) + assert len(data) == 3 + assert ops.equal(data.crd, pos).all() + assert ops.equal(data.edge_index, edge_index).all() + assert ops.isclose( + data.edge_attr, + Tensor([[-1.0, 0.0], [1.0, 0.0], [-2.0, 0.0], [2.0, 0.0]]), + ).all() + + data = Graph(edge_index=edge_index, crd=pos, edge_attr=edge_attr) + data = Cartesian(norm=True)(data) + assert len(data) == 3 + assert ops.equal(data.crd, pos).all() + assert ops.equal(data.edge_index, edge_index).all() + assert ops.isclose( + data.edge_attr, + Tensor([ + [1, 0.25, 0.5], + [2, 0.75, 0.5], + [3, 0, 0.5], + [4, 1, 0.5], + ]), + ).all() diff --git a/tests/graph/transforms/test_center.py b/tests/graph/transforms/test_center.py new file mode 100644 index 000000000..ef632b3e4 --- /dev/null +++ b/tests/graph/transforms/test_center.py @@ -0,0 +1,16 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import Center + + +def test_center(): + transform = Center() + assert str(transform) == 'Center()' + + crd = Tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]]) + data = Graph(crd=crd) + + data = transform(data) + assert len(data) == 1 + assert data.crd.tolist() == [[-2, 0], [0, 0], [2, 0]] diff --git a/tests/graph/transforms/test_compose.py b/tests/graph/transforms/test_compose.py new file mode 100644 index 000000000..7e5120b2b --- /dev/null +++ b/tests/graph/transforms/test_compose.py @@ -0,0 +1,62 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker import transforms as T +from mindscience.sharker.data import Graph + + +def test_compose(): + transform = T.Compose([T.Center(), T.AddSelfLoops()]) + assert str(transform) == ('Compose([\n' + ' Center(),\n' + ' AddSelfLoops()\n' + '])') + + pos = Tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]]) + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + data = Graph(edge_index=edge_index, crd=pos) + data = transform(data) + assert len(data) == 2 + assert data.crd.tolist() == [[-2.0, 0.0], [0.0, 0.0], [2.0, 0.0]] + assert data.edge_index.shape == (2, 7) + + +def test_compose_data_list(): + transform = T.Compose([T.Center(), T.AddSelfLoops()]) + + pos = Tensor([[0.0, 0.0], [2.0, 0.0], [4.0, 0.0]]) + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + data_list = [Graph(edge_index=edge_index, crd=pos) for _ in range(3)] + data_list = transform(data_list) + assert len(data_list) == 3 + for data in data_list: + assert len(data) == 2 + assert data.crd.tolist() == [[-2.0, 0.0], [0.0, 0.0], [2.0, 0.0]] + assert data.edge_index.shape == (2, 7) + + +def test_compose_filters(): + filter_fn = T.ComposeFilters([ + lambda d: d.num_nodes > 2, + lambda d: d.num_edges > 2, + ]) + assert str(filter_fn)[:16] == 'ComposeFilters([' + + data1 = Graph(x=ops.arange(3)) + assert not filter_fn(data1) + + data2 = Graph(x=ops.arange(2), edge_index=Tensor([ + [0, 0, 1], + [0, 1, 1], + ])) + assert not filter_fn(data2) + + data3 = Graph(x=ops.arange(3), edge_index=Tensor([ + [0, 0, 1], + [0, 1, 1], + ])) + assert filter_fn(data3) + + # Test tuple of data objects: + assert filter_fn((data1, data2, data3)) is False diff --git a/tests/graph/transforms/test_constant.py b/tests/graph/transforms/test_constant.py new file mode 100644 index 000000000..555198713 --- /dev/null +++ b/tests/graph/transforms/test_constant.py @@ -0,0 +1,38 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import Constant + + +def test_constant(): + assert str(Constant()) == 'Constant(value=1.0)' + + x = Tensor([[-1, 0], [0, 0], [2, 0]], dtype=ms.float32) + edge_index = Tensor([[0, 1], [1, 2]]) + + data = Graph(edge_index=edge_index, num_nodes=3) + data = Constant()(data) + assert len(data) == 3 + assert data.edge_index.tolist() == edge_index.tolist() + assert data.x.tolist() == [[1], [1], [1]] + assert data.num_nodes == 3 + + data = Graph(edge_index=edge_index, x=x) + data = Constant()(data) + assert len(data) == 2 + assert data.edge_index.tolist() == edge_index.tolist() + assert data.x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]] + + data = HeteroGraph() + data['v'].x = x + data = Constant()(data) + assert len(data) == 1 + assert data['v'].x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]] + + data = HeteroGraph() + data['v'].x = x + data['w'].x = x + data = Constant(node_types='w')(data) + assert len(data) == 1 + assert data['v'].x.tolist() == x.tolist() + assert data['w'].x.tolist() == [[-1, 0, 1], [0, 0, 1], [2, 0, 1]] diff --git a/tests/graph/transforms/test_delaunay.py b/tests/graph/transforms/test_delaunay.py new file mode 100644 index 000000000..2d21e3de6 --- /dev/null +++ b/tests/graph/transforms/test_delaunay.py @@ -0,0 +1,32 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import Delaunay + + +def test_delaunay(): + assert str(Delaunay()) == 'Delaunay()' + + pos = Tensor([[-1, -1], [-1, 1], [1, 1], [1, -1]], dtype=ms.float32) + data = Graph(crd=pos) + data = Delaunay()(data) + assert len(data) == 2 + assert data.face.tolist() == [[3, 1], [1, 3], [0, 2]] + + pos = Tensor([[-1, -1], [-1, 1], [1, 1]], dtype=ms.float32) + data = Graph(crd=pos) + data = Delaunay()(data) + assert len(data) == 2 + assert data.face.tolist() == [[0], [1], [2]] + + pos = Tensor([[-1, -1], [1, 1]], dtype=ms.float32) + data = Graph(crd=pos) + data = Delaunay()(data) + assert len(data) == 2 + assert data.edge_index.tolist() == [[0, 1], [1, 0]] + + pos = Tensor([[-1, -1]], dtype=ms.float32) + data = Graph(crd=pos) + data = Delaunay()(data) + assert len(data) == 2 + assert data.edge_index.tolist() == [[], []] diff --git a/tests/graph/transforms/test_distance.py b/tests/graph/transforms/test_distance.py new file mode 100644 index 000000000..9208448c9 --- /dev/null +++ b/tests/graph/transforms/test_distance.py @@ -0,0 +1,31 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import Distance + + +def test_distance(): + assert str(Distance()) == 'Distance(norm=True, max_value=None)' + + pos = Tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]]) + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_attr = Tensor([1.0, 1.0, 1.0, 1.0]) + + data = Graph(edge_index=edge_index, crd=pos) + data = Distance(norm=False)(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert data.edge_attr.tolist() == [[1.0], [1.0], [2.0], [2.0]] + + data = Graph(edge_index=edge_index, crd=pos, edge_attr=edge_attr) + data = Distance(norm=True)(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert data.edge_attr.tolist() == [ + [1.0, 0.5], + [1.0, 0.5], + [1.0, 1.0], + [1.0, 1.0], + ] diff --git a/tests/graph/transforms/test_face_to_edge.py b/tests/graph/transforms/test_face_to_edge.py new file mode 100644 index 000000000..73b8a7ea5 --- /dev/null +++ b/tests/graph/transforms/test_face_to_edge.py @@ -0,0 +1,18 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import FaceToEdge + + +def test_face_to_edge(): + transform = FaceToEdge() + assert str(transform) == 'FaceToEdge()' + + face = Tensor([[0, 0], [1, 1], [2, 3]]) + data = Graph(face=face, num_nodes=4) + + data = transform(data) + assert len(data) == 2 + assert data.edge_index.tolist() == [[0, 0, 0, 1, 1, 1, 2, 2, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 0, 1]] + assert data.num_nodes == 4 diff --git a/tests/graph/transforms/test_feature_propagation.py b/tests/graph/transforms/test_feature_propagation.py new file mode 100644 index 000000000..4a1247596 --- /dev/null +++ b/tests/graph/transforms/test_feature_propagation.py @@ -0,0 +1,29 @@ +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import FeaturePropagation, ToSparseTensor + + +def test_feature_propagation(): + x = ops.randn(6, 4) + x[0, 1] = float('nan') + x[2, 3] = float('nan') + missing_mask = ops.isnan(x) + edge_index = Tensor([[0, 1, 0, 4, 1, 4, 2, 3, 3, 5], + [1, 0, 4, 0, 4, 1, 3, 2, 5, 3]]) + + transform = FeaturePropagation(missing_mask) + assert str(transform) == ('FeaturePropagation(missing_features=8.3%, ' + 'num_iterations=40)') + + data1 = Graph(x=x, edge_index=edge_index) + assert ops.isnan(data1.x).sum() == 2 + data1 = FeaturePropagation(missing_mask)(data1) + assert ops.isnan(data1.x).sum() == 0 + assert data1.x.shape == x.shape + + data2 = Graph(x=x, edge_index=edge_index) + assert ops.isnan(data2.x).sum() == 2 + data2 = ToSparseTensor()(data2) + data2 = transform(data2) + assert ops.isnan(data2.x).sum() == 0 + assert ops.isclose(data1.x, data2.x).all() diff --git a/tests/graph/transforms/test_fixed_points.py b/tests/graph/transforms/test_fixed_points.py new file mode 100644 index 000000000..764639c35 --- /dev/null +++ b/tests/graph/transforms/test_fixed_points.py @@ -0,0 +1,64 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import FixedPoints + + +def test_fixed_points(): + assert str(FixedPoints(1024)) == 'FixedPoints(1024, replace=True)' + + data = Graph( + crd=ops.randn(100, 3), + x=ops.randn(100, 16), + y=ops.randn(1), + edge_attr=ops.randn(100, 3), + num_nodes=100, + ) + + out = FixedPoints(50, replace=True)(data) + assert len(out) == 5 + assert out.crd.shape == (50, 3) + assert out.x.shape == (50, 16) + assert out.y.shape == (1, ) + assert out.edge_attr.shape == (100, 3) + assert out.num_nodes == 50 + + out = FixedPoints(200, replace=True)(data) + assert len(out) == 5 + assert out.crd.shape == (200, 3) + assert out.x.shape == (200, 16) + assert out.y.shape == (1, ) + assert out.edge_attr.shape == (100, 3) + assert out.num_nodes == 200 + + out = FixedPoints(50, replace=False, allow_duplicates=False)(data) + assert len(out) == 5 + assert out.crd.shape == (50, 3) + assert out.x.shape == (50, 16) + assert out.y.shape == (1, ) + assert out.edge_attr.shape == (100, 3) + assert out.num_nodes == 50 + + out = FixedPoints(200, replace=False, allow_duplicates=False)(data) + assert len(out) == 5 + assert out.crd.shape == (100, 3) + assert out.x.shape == (100, 16) + assert out.y.shape == (1, ) + assert out.edge_attr.shape == (100, 3) + assert out.num_nodes == 100 + + out = FixedPoints(50, replace=False, allow_duplicates=True)(data) + assert len(out) == 5 + assert out.crd.shape == (50, 3) + assert out.x.shape == (50, 16) + assert out.y.shape == (1, ) + assert out.edge_attr.shape == (100, 3) + assert out.num_nodes == 50 + + out = FixedPoints(200, replace=False, allow_duplicates=True)(data) + assert len(out) == 5 + assert out.crd.shape == (200, 3) + assert out.x.shape == (200, 16) + assert out.y.shape == (1, ) + assert out.edge_attr.shape == (100, 3) + assert out.num_nodes == 200 diff --git a/tests/graph/transforms/test_gcn_norm.py b/tests/graph/transforms/test_gcn_norm.py new file mode 100644 index 000000000..4ef963253 --- /dev/null +++ b/tests/graph/transforms/test_gcn_norm.py @@ -0,0 +1,47 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker import typing +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import GCNNorm +from mindscience.sharker.typing import SparseTensor + + +def test_gcn_norm(): + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight = ops.ones(edge_index.shape[1]) + + transform = GCNNorm() + assert str(transform) == 'GCNNorm(add_self_loops=True)' + + expected_edge_index = [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]] + expected_edge_weight = Tensor( + [0.4082, 0.4082, 0.4082, 0.4082, 0.5000, 0.3333, 0.5000]) + + data = Graph(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3) + data = transform(data) + assert len(data) == 3 + assert data.num_nodes == 3 + assert data.edge_index.tolist() == expected_edge_index + assert ops.isclose(data.edge_weight, expected_edge_weight, atol=1e-4).all() + + data = Graph(edge_index=edge_index, num_nodes=3) + data = transform(data) + assert len(data) == 3 + assert data.num_nodes == 3 + assert data.edge_index.tolist() == expected_edge_index + assert ops.isclose(data.edge_weight, expected_edge_weight, atol=1e-4).all() + + # For `SparseTensor`, expected outputs will be sorted: + if typing.WITH_SPARSE: + expected_edge_index = [[0, 0, 1, 1, 1, 2, 2], [0, 1, 0, 1, 2, 1, 2]] + expected_edge_weight = Tensor( + [0.500, 0.4082, 0.4082, 0.3333, 0.4082, 0.4082, 0.5000]) + + adj_t = SparseTensor.from_edge_index(edge_index, edge_weight).t() + data = Graph(adj_t=adj_t) + data = transform(data) + assert len(data) == 1 + row, col, value = data.adj_t.coo() + assert row.tolist() == expected_edge_index[0] + assert col.tolist() == expected_edge_index[1] + assert ops.isclose(value, expected_edge_weight, atol=1e-4).all() diff --git a/tests/graph/transforms/test_gdc.py b/tests/graph/transforms/test_gdc.py new file mode 100644 index 000000000..ea5da2ea7 --- /dev/null +++ b/tests/graph/transforms/test_gdc.py @@ -0,0 +1,103 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.datasets import KarateClub +from mindscience.sharker.testing import withPackage +from mindscience.sharker.transforms import GDC +from mindscience.sharker.utils import to_dense_adj + + +@withPackage('numba') +def test_gdc(): + data = KarateClub()[0] + + gdc = GDC( + self_loop_weight=1, + normalization_in='sym', + normalization_out='sym', + diffusion_kwargs=dict(method='ppr', alpha=0.15), + sparsification_kwargs=dict(method='threshold', avg_degree=2), + exact=True, + ) + out = gdc(data) + mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze() + assert ops.all(mat >= -1e-8) + assert ops.isclose(mat, mat.t(), atol=1e-4).all() + + gdc = GDC( + self_loop_weight=1, + normalization_in='sym', + normalization_out='sym', + diffusion_kwargs=dict(method='heat', t=10), + sparsification_kwargs=dict(method='threshold', avg_degree=2), + exact=True, + ) + out = gdc(data) + mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze() + assert ops.all(mat >= -1e-8) + assert ops.isclose(mat, mat.t(), atol=1e-4).all() + + gdc = GDC( + self_loop_weight=1, + normalization_in='col', + normalization_out='col', + diffusion_kwargs=dict(method='heat', t=10), + sparsification_kwargs=dict(method='topk', k=2, dim=0), + exact=True, + ) + out = gdc(data) + mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze() + col_sum = mat.sum(0) + assert ops.all(mat >= -1e-8) + assert ops.all( + ops.isclose(col_sum, Tensor(1.0)) + | ops.isclose(col_sum, Tensor(0.0))) + assert ops.all((~ops.isclose(mat, Tensor(0.0))).sum(0) == 2) + + gdc = GDC( + self_loop_weight=1, + normalization_in='row', + normalization_out='row', + diffusion_kwargs=dict(method='heat', t=5), + sparsification_kwargs=dict(method='topk', k=2, dim=1), + exact=True, + ) + out = gdc(data) + mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze() + row_sum = mat.sum(1) + assert ops.all(mat >= -1e-8) + assert ops.all( + ops.isclose(row_sum, Tensor(1.0)) + | ops.isclose(row_sum, Tensor(0.0))) + assert ops.all((~ops.isclose(mat, Tensor(0.0))).sum(1) == 2) + + gdc = GDC( + self_loop_weight=1, + normalization_in='row', + normalization_out='row', + diffusion_kwargs=dict(method='coeff', coeffs=[0.8, 0.3, 0.1]), + sparsification_kwargs=dict(method='threshold', eps=0.1), + exact=True, + ) + out = gdc(data) + mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze() + row_sum = mat.sum(1) + assert ops.all(mat >= -1e-8) + assert ops.all( + ops.isclose(row_sum, Tensor(1.0)) + | ops.isclose(row_sum, Tensor(0.0))) + + gdc = GDC( + self_loop_weight=1, + normalization_in='sym', + normalization_out='col', + diffusion_kwargs=dict(method='ppr', alpha=0.15, eps=1e-4), + sparsification_kwargs=dict(method='threshold', avg_degree=2), + exact=False, + ) + out = gdc(data) + mat = to_dense_adj(out.edge_index, edge_attr=out.edge_attr).squeeze() + col_sum = mat.sum(0) + assert ops.all(mat >= -1e-8) + assert ops.all( + ops.isclose(col_sum, Tensor(1.0)) + | ops.isclose(col_sum, Tensor(0.0))) diff --git a/tests/graph/transforms/test_generate_mesh_normals.py b/tests/graph/transforms/test_generate_mesh_normals.py new file mode 100644 index 000000000..48ca0db50 --- /dev/null +++ b/tests/graph/transforms/test_generate_mesh_normals.py @@ -0,0 +1,29 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import GenerateMeshNormals + + +def test_generate_mesh_normals(): + transform = GenerateMeshNormals() + assert str(transform) == 'GenerateMeshNormals()' + + pos = Tensor([ + [0.0, 0.0, 0.0], + [-2.0, 1.0, 0.0], + [-1.0, 1.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 1.0, 0.0], + [2.0, 1.0, 0.0], + ]) + face = Tensor([ + [0, 0, 0, 0], + [1, 2, 3, 4], + [2, 3, 4, 5], + ]) + + data = transform(Graph(crd=pos, face=face)) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.face.tolist() == face.tolist() + assert data.norm.tolist() == [[0.0, 0.0, -1.0]] * 6 diff --git a/tests/graph/transforms/test_grid_sampling.py b/tests/graph/transforms/test_grid_sampling.py new file mode 100644 index 000000000..92cec9b2d --- /dev/null +++ b/tests/graph/transforms/test_grid_sampling.py @@ -0,0 +1,25 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import GridSampling + + +def test_grid_sampling(): + assert str(GridSampling(5)) == 'GridSampling(size=5)' + + pos = Tensor([ + [0.0, 2.0], + [3.0, 2.0], + [3.0, 2.0], + [2.0, 8.0], + [2.0, 6.0], + ]) + y = Tensor([0, 1, 1, 2, 2]) + batch = Tensor([0, 0, 0, 0, 0]) + + data = Graph(crd=pos, y=y, batch=batch) + data = GridSampling(size=5, start=0)(data) + assert len(data) == 3 + assert data.crd.tolist() == [[2, 2], [2, 7]] + assert data.y.tolist() == [1, 2] + assert data.batch.tolist() == [0, 0] diff --git a/tests/graph/transforms/test_half_hop.py b/tests/graph/transforms/test_half_hop.py new file mode 100644 index 000000000..23bf55b6f --- /dev/null +++ b/tests/graph/transforms/test_half_hop.py @@ -0,0 +1,45 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import HalfHop +from mindscience.sharker.seed import seed_everything + + +def test_half_hop(): + edge_index = Tensor([[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]]) + x = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]], + dtype=ms.float32) + data = Graph(x=x, edge_index=edge_index) + + transform = HalfHop() + assert str(transform) == 'HalfHop(alpha=0.5, p=1.0)' + data = transform(data) + + expected_edge_index = [[0, 1, 2, 0, 1, 1, 2, 3, 4, 5, 6, 1, 0, 2, 1], + [0, 1, 2, 3, 4, 5, 6, 1, 0, 2, 1, 3, 4, 5, 6]] + expected_x = Tensor( + [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [3, 4, 5, 6], + [3, 4, 5, 6], [7, 8, 9, 10], [7, 8, 9, 10]], dtype=ms.float32) + assert len(data) == 3 + assert data.num_nodes == 7 + assert data.edge_index.tolist() == expected_edge_index + assert ops.isclose(data.x, expected_x, atol=1e-4).all() + assert data.slow_node_mask.tolist() == [ + False, False, False, True, True, True, True + ] + + seed_everything(12345) + data = Graph(x=x, edge_index=edge_index) + transform = HalfHop(p=0.5) + assert str(transform) == 'HalfHop(alpha=0.5, p=0.5)' + data = transform(data) + + expected_edge_index = [[1, 0, 1, 2, 0, 1, 2, 3, 4, 5, 1, 0, 1], + [2, 0, 1, 2, 3, 4, 5, 1, 0, 1, 3, 4, 5]] + expected_x = Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], + [3, 4, 5, 6], [3, 4, 5, 6], [7, 8, 9, 10]], + dtype=ms.float32) + assert data.num_nodes == 6 + assert data.edge_index.tolist() == expected_edge_index + assert ops.isclose(data.x, expected_x, atol=1e-4).all() + assert data.slow_node_mask.tolist() == [False, False, False, True, True, True] diff --git a/tests/graph/transforms/test_knn_graph.py b/tests/graph/transforms/test_knn_graph.py new file mode 100644 index 000000000..7cfbc16ca --- /dev/null +++ b/tests/graph/transforms/test_knn_graph.py @@ -0,0 +1,27 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import KNNGraph + + +def test_knn_graph(): + assert str(KNNGraph()) == 'KNNGraph(k=6)' + + pos = Tensor([ + [0.0, 0.0], + [1.0, 0.0], + [2.0, 0.0], + [0.0, 1.0], + [-2.0, 0.0], + [0.0, -2.0], + ]) + + expected_row = [0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4, 5, 5] + expected_col = [1, 2, 3, 4, 5, 0, 2, 3, 5, 0, 1, 0, 1, 4, 0, 3, 0, 1] + + data = Graph(crd=pos) + data = KNNGraph(k=2, force_undirected=True)(data) + assert len(data) == 2 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index[0].tolist() == expected_row + assert data.edge_index[1].tolist() == expected_col diff --git a/tests/graph/transforms/test_laplacian_lambda_max.py b/tests/graph/transforms/test_laplacian_lambda_max.py new file mode 100644 index 000000000..13d49692b --- /dev/null +++ b/tests/graph/transforms/test_laplacian_lambda_max.py @@ -0,0 +1,33 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import LaplacianLambdaMax + + +def test_laplacian_lambda_max(): + out = str(LaplacianLambdaMax()) + assert out == 'LaplacianLambdaMax(normalization=None)' + + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=ms.int64) + edge_attr = Tensor([1, 1, 2, 2], dtype=ms.float32) + + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) + out = LaplacianLambdaMax(normalization=None, is_undirected=True)(data) + assert len(out) == 4 + assert ops.isclose(Tensor(out.lambda_max), Tensor(4.732049)).all() + + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) + out = LaplacianLambdaMax(normalization='sym', is_undirected=True)(data) + assert len(out) == 4 + assert ops.isclose(Tensor(out.lambda_max), Tensor(2.0)).all() + + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) + out = LaplacianLambdaMax(normalization='rw', is_undirected=True)(data) + assert len(out) == 4 + assert ops.isclose(Tensor(out.lambda_max), Tensor(2.0)).all() + + data = Graph(edge_index=edge_index, edge_attr=ops.randn(4, 2), + num_nodes=3) + out = LaplacianLambdaMax(normalization=None)(data) + assert len(out) == 4 + assert ops.isclose(Tensor(out.lambda_max), Tensor(3.0)).all() diff --git a/tests/graph/transforms/test_largest_connected_components.py b/tests/graph/transforms/test_largest_connected_components.py new file mode 100644 index 000000000..746fcb7c9 --- /dev/null +++ b/tests/graph/transforms/test_largest_connected_components.py @@ -0,0 +1,46 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import LargestConnectedComponents + + +def test_largest_connected_components(): + assert str(LargestConnectedComponents()) == 'LargestConnectedComponents(1)' + + edge_index = Tensor([ + [0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 5, 6, 8, 9], + [1, 2, 0, 2, 0, 1, 3, 2, 4, 3, 6, 7, 9, 8], + ]) + data = Graph(edge_index=edge_index, num_nodes=10) + + # Testing without `connection` specified: + transform = LargestConnectedComponents(num_components=2) + out = transform(data) + assert out.num_nodes == 8 + assert out.edge_index.tolist() == data.edge_index[:, :12].tolist() + + # Testing with `connection = strong`: + transform = LargestConnectedComponents(num_components=2, + connection='strong') + out = transform(data) + assert out.num_nodes == 7 + assert out.edge_index.tolist() == [[0, 0, 1, 1, 2, 2, 2, 3, 3, 4, 5, 6], + [1, 2, 0, 2, 0, 1, 3, 2, 4, 3, 6, 5]] + + edge_index = Tensor([ + [0, 1, 2, 3, 3, 4], + [1, 0, 3, 2, 4, 3], + ]) + data = Graph(edge_index=edge_index, num_nodes=5) + + # Testing without `num_components` and `connection` specified: + transform = LargestConnectedComponents() + out = transform(data) + assert out.num_nodes == 3 + assert out.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + + # Testing with larger `num_components` than actual number of components: + transform = LargestConnectedComponents(num_components=3) + out = transform(data) + assert out.num_nodes == 5 + assert out.edge_index.tolist() == data.edge_index.tolist() diff --git a/tests/graph/transforms/test_line_graph.py b/tests/graph/transforms/test_line_graph.py new file mode 100644 index 000000000..e12048343 --- /dev/null +++ b/tests/graph/transforms/test_line_graph.py @@ -0,0 +1,33 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import LineGraph + + +def test_line_graph(): + transform = LineGraph() + assert str(transform) == 'LineGraph()' + + # Directed. + edge_index = Tensor([ + [0, 1, 2, 2, 3], + [1, 2, 0, 3, 0], + ]) + data = Graph(edge_index=edge_index, num_nodes=4) + data = transform(data) + assert data.edge_index.tolist() == [[0, 1, 1, 2, 3, 4], [1, 2, 3, 0, 4, 0]] + assert data.num_nodes == data.edge_index.max().item() + 1 + + # Undirected. + edge_index = Tensor([[0, 0, 0, 1, 1, 2, 2, 3, 3, 3, 4, 4], + [1, 2, 3, 0, 4, 0, 3, 0, 2, 4, 1, 3]]) + edge_attr = ops.ones(edge_index.shape[1]) + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=5) + data = transform(data) + assert data.edge_index.max().item() + 1 == data.x.shape[0] + assert data.edge_index.tolist() == [ + [0, 0, 0, 1, 1, 1, 2, 2, 2, 2, 3, 3, 4, 4, 4, 5, 5, 5], + [1, 2, 3, 0, 2, 4, 0, 1, 4, 5, 0, 5, 1, 2, 5, 2, 3, 4], + ] + assert data.x.tolist() == [2, 2, 2, 2, 2, 2] + assert data.num_nodes == data.edge_index.max().item() + 1 diff --git a/tests/graph/transforms/test_linear_transformation.py b/tests/graph/transforms/test_linear_transformation.py new file mode 100644 index 000000000..90a6f26a5 --- /dev/null +++ b/tests/graph/transforms/test_linear_transformation.py @@ -0,0 +1,26 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import LinearTransformation + + +@pytest.mark.parametrize('matrix', [ + [[2.0, 0.0], [0.0, 2.0]], + Tensor([[2.0, 0.0], [0.0, 2.0]]), +]) +def test_linear_transformation(matrix): + pos = Tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]]) + + transform = LinearTransformation(matrix) + assert str(transform) == ('LinearTransformation(\n' + '[[2. 0.]\n' + ' [0. 2.]]\n' + ')') + + out = transform(Graph(crd=pos)) + assert len(out) == 1 + assert ops.isclose(out.crd, 2 * pos).all() + + out = transform(Graph()) + assert len(out) == 0 diff --git a/tests/graph/transforms/test_local_cartesian.py b/tests/graph/transforms/test_local_cartesian.py new file mode 100644 index 000000000..5206c5ac5 --- /dev/null +++ b/tests/graph/transforms/test_local_cartesian.py @@ -0,0 +1,29 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import LocalCartesian + + +def test_local_cartesian(): + transform = LocalCartesian() + assert str(transform) == 'LocalCartesian()' + + pos = Tensor([[-1.0, 0.0], [0.0, 0.0], [2.0, 0.0]]) + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_attr = Tensor([1.0, 2.0, 3.0, 4.0]) + data = Graph(edge_index=edge_index, crd=pos) + + data = transform(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert data.edge_attr.tolist() == [[0.25, 0.5], [1.0, 0.5], [0.0, 0.5], + [1.0, 0.5]] + + data = Graph(edge_index=edge_index, crd=pos, edge_attr=edge_attr) + data = transform(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert data.edge_attr.tolist() == [[1, 0.25, 0.5], [2, 1.0, 0.5], + [3, 0.0, 0.5], [4, 1.0, 0.5]] diff --git a/tests/graph/transforms/test_local_degree_profile.py b/tests/graph/transforms/test_local_degree_profile.py new file mode 100644 index 000000000..1d0e60f48 --- /dev/null +++ b/tests/graph/transforms/test_local_degree_profile.py @@ -0,0 +1,27 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import LocalDegreeProfile + + +def test_target_indegree(): + assert str(LocalDegreeProfile()) == 'LocalDegreeProfile()' + + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + x = Tensor([[1.0], [1.0], [1.0], [1.0]]) # One isolated node. + + expected = Tensor([ + [1, 2, 2, 2, 0], + [2, 1, 1, 1, 0], + [1, 2, 2, 2, 0], + [0, 0, 0, 0, 0], + ], dtype=ms.float32) + + data = Graph(edge_index=edge_index, num_nodes=x.shape[0]) + data = LocalDegreeProfile()(data) + assert ops.isclose(data.x, expected, atol=1e-2).all() + + data = Graph(edge_index=edge_index, x=x) + data = LocalDegreeProfile()(data) + assert ops.isclose(data.x[:, :1], x).all() + assert ops.isclose(data.x[:, 1:], expected, atol=1e-2).all() diff --git a/tests/graph/transforms/test_mask_transform.py b/tests/graph/transforms/test_mask_transform.py new file mode 100644 index 000000000..077538b26 --- /dev/null +++ b/tests/graph/transforms/test_mask_transform.py @@ -0,0 +1,92 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import IndexToMask, MaskToIndex + + +def test_index_to_mask(): + assert str(IndexToMask()) == ('IndexToMask(attrs=None, sizes=None, ' + 'replace=False)') + + edge_index = Tensor([[0, 1, 1, 2, 2, 3, 3, 4], + [1, 0, 2, 1, 3, 2, 4, 3]]) + train_index = ops.arange(0, 3) + test_index = ops.arange(3, 5) + data = Graph(edge_index=edge_index, train_index=train_index, + test_index=test_index, num_nodes=5) + + out = IndexToMask(replace=True)(data) + assert len(out) == len(data) + assert out.train_mask.tolist() == [True, True, True, False, False] + assert out.test_mask.tolist() == [False, False, False, True, True] + + out = IndexToMask(replace=False)(data) + assert len(out) == len(data) + 2 + + out = IndexToMask(sizes=6, replace=True)(data) + assert out.train_mask.tolist() == [True, True, True, False, False, False] + assert out.test_mask.tolist() == [False, False, False, True, True, False] + + out = IndexToMask(attrs='train_index')(data) + assert len(out) == len(data) + 1 + assert 'train_index' in out + assert 'train_mask' in out + assert 'test_index' in out + assert 'test_mask' not in out + + +def test_mask_to_index(): + assert str(MaskToIndex()) == 'MaskToIndex(attrs=None, replace=False)' + + train_mask = Tensor([True, True, True, False, False]) + test_mask = Tensor([False, False, False, True, True]) + data = Graph(train_mask=train_mask, test_mask=test_mask) + + out = MaskToIndex(replace=True)(data) + assert len(out) == len(data) + assert out.train_index.tolist() == [0, 1, 2] + assert out.test_index.tolist() == [3, 4] + + out = MaskToIndex(replace=False)(data) + assert len(out) == len(data) + 2 + + out = MaskToIndex(attrs='train_mask')(data) + assert len(out) == len(data) + 1 + assert 'train_mask' in out + assert 'train_index' in out + assert 'test_mask' in out + assert 'test_index' not in out + + +def test_hetero_index_to_mask(): + data = HeteroGraph() + data['u'].train_index = ops.arange(0, 3) + data['u'].test_index = ops.arange(3, 5) + data['u'].num_nodes = 5 + + data['v'].train_index = ops.arange(0, 3) + data['v'].test_index = ops.arange(3, 5) + data['v'].num_nodes = 5 + + out = IndexToMask()(data) + assert len(out) == len(data) + 2 + assert 'train_mask' in out['u'] + assert 'test_mask' in out['u'] + assert 'train_mask' in out['v'] + assert 'test_mask' in out['v'] + + +def test_hetero_mask_to_index(): + data = HeteroGraph() + data['u'].train_mask = Tensor([True, True, True, False, False]) + data['u'].test_mask = Tensor([False, False, False, True, True]) + + data['v'].train_mask = Tensor([True, True, True, False, False]) + data['v'].test_mask = Tensor([False, False, False, True, True]) + + out = MaskToIndex()(data) + assert len(out) == len(data) + 2 + assert 'train_index' in out['u'] + assert 'test_index' in out['u'] + assert 'train_index' in out['v'] + assert 'test_index' in out['v'] diff --git a/tests/graph/transforms/test_node_property_split.py b/tests/graph/transforms/test_node_property_split.py new file mode 100644 index 000000000..520c60b38 --- /dev/null +++ b/tests/graph/transforms/test_node_property_split.py @@ -0,0 +1,39 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.datasets import graph_generator +from mindscience.sharker.testing import withPackage +from mindscience.sharker.transforms import NodePropertySplit + + +@withPackage('networkx') +@pytest.mark.parametrize('property_name', [ + 'popularity', + 'locality', + 'density', +]) +def test_node_property_split(property_name): + ratios = [0.3, 0.1, 0.1, 0.2, 0.3] + + transform = NodePropertySplit(property_name, ratios) + assert str(transform) == f'NodePropertySplit({property_name})' + + data = graph_generator.ERGraph(num_nodes=100, edge_prob=0.4)() + data = transform(data) + + node_ids = [] + for name, ratio in zip([ + 'id_train_mask', + 'id_val_mask', + 'id_test_mask', + 'ood_val_mask', + 'ood_test_mask', + ], ratios): + assert data[name].dtype == ms.bool_ + assert data[name].shape == (100, ) + assert int(data[name].sum()) == 100 * ratio + node_ids.extend(data[name].nonzero().view(-1).tolist()) + + # Check that masks are non-intersecting and cover all nodes: + node_ids = Tensor(node_ids) + assert node_ids.numel() == ops.unique(node_ids)[0].numel() == 100 diff --git a/tests/graph/transforms/test_normalize_features.py b/tests/graph/transforms/test_normalize_features.py new file mode 100644 index 000000000..21dfe19fa --- /dev/null +++ b/tests/graph/transforms/test_normalize_features.py @@ -0,0 +1,27 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import NormalizeFeatures + + +def test_normalize_scale(): + transform = NormalizeFeatures() + assert str(transform) == 'NormalizeFeatures()' + + x = Tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=ms.float32) + data = Graph(x=x) + + data = transform(data) + assert len(data) == 1 + assert data.x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]] + + +def test_hetero_normalize_scale(): + x = Tensor([[1, 0, 1], [0, 1, 0], [0, 0, 0]], dtype=ms.float32) + + data = HeteroGraph() + data['v'].x = x + data['w'].x = x + data = NormalizeFeatures()(data) + assert data['v'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]] + assert data['w'].x.tolist() == [[0.5, 0, 0.5], [0, 1, 0], [0, 0, 0]] diff --git a/tests/graph/transforms/test_normalize_rotation.py b/tests/graph/transforms/test_normalize_rotation.py new file mode 100644 index 000000000..8f84fe150 --- /dev/null +++ b/tests/graph/transforms/test_normalize_rotation.py @@ -0,0 +1,57 @@ +from math import sqrt + +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import NormalizeRotation + + +def test_normalize_rotation(): + assert str(NormalizeRotation()) == 'NormalizeRotation()' + + pos = Tensor([ + [-2.0, -2.0], + [-1.0, -1.0], + [0.0, 0.0], + [1.0, 1.0], + [2.0, 2.0], + ]) + normal = Tensor([ + [-1.0, 1.0], + [-1.0, 1.0], + [-1.0, 1.0], + [-1.0, 1.0], + [-1.0, 1.0], + ]) + data = Graph(crd=pos) + data.normal = normal + data = NormalizeRotation()(data) + assert len(data) == 2 + + expected_pos = Tensor([ + [-2 * sqrt(2), 0.0], + [-sqrt(2), 0.0], + [0.0, 0.0], + [sqrt(2), 0.0], + [2 * sqrt(2), 0.0], + ]) + expected_normal = [[0, 1], [0, 1], [0, 1], [0, 1], [0, 1]] + + assert ops.isclose(data.crd, expected_pos, atol=1e-04).all() + assert data.normal.tolist() == expected_normal + + data = Graph(crd=pos) + data.normal = normal + data = NormalizeRotation(max_points=3)(data) + assert len(data) == 2 + + assert ops.isclose(data.crd, expected_pos, atol=1e-04).all() + assert data.normal.tolist() == expected_normal + + data = Graph(crd=pos) + data.normal = normal + data = NormalizeRotation(sort=True)(data) + assert len(data) == 2 + + assert ops.isclose(data.crd, expected_pos, atol=1e-04).all() + assert data.normal.tolist() == expected_normal diff --git a/tests/graph/transforms/test_normalize_scale.py b/tests/graph/transforms/test_normalize_scale.py new file mode 100644 index 000000000..25ac813df --- /dev/null +++ b/tests/graph/transforms/test_normalize_scale.py @@ -0,0 +1,17 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import NormalizeScale + + +def test_normalize_scale(): + transform = NormalizeScale() + assert str(transform) == 'NormalizeScale()' + + pos = ops.randn((10, 3)) + data = Graph(crd=pos) + + data = transform(data) + assert len(data) == 1 + assert data.crd.min().item() > -1 + assert data.crd.max().item() < 1 diff --git a/tests/graph/transforms/test_one_hot_degree.py b/tests/graph/transforms/test_one_hot_degree.py new file mode 100644 index 000000000..7a5b3151e --- /dev/null +++ b/tests/graph/transforms/test_one_hot_degree.py @@ -0,0 +1,34 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import OneHotDegree + + +def test_one_hot_degree(): + assert str(OneHotDegree(max_degree=3)) == 'OneHotDegree(3)' + + edge_index = Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + x = Tensor([1.0, 1.0, 1.0, 1.0]) + + data = Graph(edge_index=edge_index, num_nodes=4) + data = OneHotDegree(max_degree=3)(data) + assert len(data) == 3 + assert data.edge_index.tolist() == edge_index.tolist() + assert data.x.tolist() == [ + [0.0, 0.0, 0.0, 1.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + ] + assert data.num_nodes == 4 + + data = Graph(edge_index=edge_index, x=x) + data = OneHotDegree(max_degree=3)(data) + assert len(data) == 2 + assert data.edge_index.tolist() == edge_index.tolist() + assert data.x.tolist() == [ + [1.0, 0.0, 0.0, 0.0, 1.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + [1.0, 0.0, 1.0, 0.0, 0.0], + ] diff --git a/tests/graph/transforms/test_pad.py b/tests/graph/transforms/test_pad.py new file mode 100644 index 000000000..1bda383f4 --- /dev/null +++ b/tests/graph/transforms/test_pad.py @@ -0,0 +1,585 @@ +import numbers +from typing import Dict, Generator, List, Optional, Tuple, Union + +import pytest +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.datasets import FakeDataset, FakeHeteroDataset +from mindscience.sharker.transforms import Pad +from mindscience.sharker.transforms.pad import ( + AttrNamePadding, + EdgeTypePadding, + NodeTypePadding, + Padding, + UniformPadding, +) +from mindscience.sharker.typing import EdgeType, NodeType + + +def fake_data() -> Graph: + return FakeDataset(avg_num_nodes=10, avg_degree=5, edge_dim=2)[0] + + +def fake_hetero_data(node_types=2, edge_types=5) -> HeteroGraph: + return FakeHeteroDataset(num_node_types=node_types, + num_edge_types=edge_types, avg_num_nodes=10, + edge_dim=2)[0] + + +def _generate_homodata_node_attrs(data: Graph) -> Generator[str, None, None]: + for attr in data.keys(): + if data.is_node_attr(attr): + yield attr + + +def _generate_homodata_edge_attrs(data: Graph) -> Generator[str, None, None]: + for attr in data.keys(): + if data.is_edge_attr(attr): + yield attr + + +def _generate_heterodata_nodes( + data: HeteroGraph +) -> Generator[Tuple[NodeType, str, Tensor], None, None]: + for node_type, store in data.node_items(): + for attr in store.keys(): + yield node_type, attr + + +def _generate_heterodata_edges( + data: HeteroGraph +) -> Generator[Tuple[EdgeType, str, Tensor], None, None]: + for edge_type, store in data.edge_items(): + for attr in store.keys(): + yield edge_type, attr + + +def _check_homo_data_nodes( + original: Graph, + padded: Graph, + max_num_nodes: Union[int, Dict[NodeType, int]], + node_pad_value: Optional[Padding] = None, + is_mask_available: bool = False, + exclude_keys: Optional[List[str]] = None, +): + assert padded.num_nodes == max_num_nodes + + compare_pad_start_idx = original.num_nodes + + if is_mask_available: + assert padded.pad_node_mask.numel() == padded.num_nodes + assert ops.all(padded.pad_node_mask[:compare_pad_start_idx]) + assert not ops.any(padded.pad_node_mask[compare_pad_start_idx:]) + + for attr in _generate_homodata_node_attrs(original): + if attr in exclude_keys: + assert attr not in padded.keys() + continue + + assert attr in padded.keys() + + if not isinstance(padded[attr], Tensor): + continue + + assert padded[attr].shape[0] == max_num_nodes + + # Check values in padded area. + pad_value = node_pad_value.get_value( + None, attr) if node_pad_value is not None else 0.0 + assert (ops.flatten(padded[attr][compare_pad_start_idx:]) == pad_value).all() + + # Check values in non-padded area. + assert ops.equal(original[attr], + padded[attr][:compare_pad_start_idx]).all() + + +def _check_homo_data_edges( + original: Graph, + padded: Graph, + max_num_edges: Optional[int] = None, + edge_pad_value: Optional[Padding] = None, + is_mask_available: bool = False, + exclude_keys: Optional[List[str]] = None, + + +): + # Check edge index attribute. + if max_num_edges is None: + max_num_edges = padded.num_nodes**2 + assert padded.num_edges == max_num_edges + assert padded.edge_index.shape[1] == max_num_edges + assert padded.edge_index.shape[1] == max_num_edges + + compare_pad_start_idx = original.num_edges + expected_node = original.num_nodes + + # Check values in padded area. + assert (padded.edge_index[1, compare_pad_start_idx:max_num_edges] == expected_node).all() + assert (padded.edge_index[0, compare_pad_start_idx:max_num_edges] == expected_node).all() + + # Check values in non-padded area. + assert ops.equal(original.edge_index, padded.edge_index[:, :compare_pad_start_idx]).all() + + if is_mask_available: + assert padded.pad_edge_mask.numel() == padded.num_edges + assert ops.all(padded.pad_edge_mask[:compare_pad_start_idx]) + assert not ops.any(padded.pad_edge_mask[compare_pad_start_idx:]) + + # Check other attributes. + for attr in _generate_homodata_edge_attrs(original): + if attr == 'edge_index': + continue + if attr in exclude_keys: + assert attr not in padded.keys() + continue + + assert attr in padded.keys() + + if not isinstance(padded[attr], Tensor): + continue + + assert padded[attr].shape[0] == max_num_edges + + # Check values in padded area. + pad_value = edge_pad_value.get_value( + None, attr) if edge_pad_value is not None else 0.0 + assert (ops.flatten(padded[attr][compare_pad_start_idx:, :] == pad_value)).all() + + # Check values in non-padded area. + assert ops.equal(original[attr], + padded[attr][:compare_pad_start_idx, :]).all() + + +def _check_hetero_data_nodes( + original: HeteroGraph, + padded: HeteroGraph, + max_num_nodes: Union[int, Dict[NodeType, int]], + node_pad_value: Optional[Padding] = None, + is_mask_available: bool = False, + exclude_keys: Optional[List[str]] = None, +): + if is_mask_available: + for store in padded.node_stores: + assert 'pad_node_mask' in store + + expected_nodes = max_num_nodes + + for node_type, attr in _generate_heterodata_nodes(original): + if attr in exclude_keys: + assert attr not in padded[node_type].keys() + continue + + assert attr in padded[node_type].keys() + + if not isinstance(padded[node_type][attr], Tensor): + continue + + compare_pad_start_idx = original[node_type].num_nodes + padded_Tensor = padded[node_type][attr] + + if attr == 'pad_node_mask': + assert padded_Tensor.numel() == padded[node_type].num_nodes + assert ops.all(padded_Tensor[:compare_pad_start_idx]) + assert not ops.any(padded_Tensor[compare_pad_start_idx:]) + continue + + original_Tensor = original[node_type][attr] + + # Check the number of nodes. + if isinstance(max_num_nodes, dict): + expected_nodes = max_num_nodes[node_type] + + assert padded_Tensor.shape[0] == expected_nodes + + compare_pad_start_idx = original_Tensor.shape[0] + pad_value = node_pad_value.get_value( + node_type, attr) if node_pad_value is not None else 0.0 + assert (ops.flatten(padded_Tensor[compare_pad_start_idx:]) == pad_value).all() + # Compare non-padded area with the original. + assert ops.equal(original_Tensor, padded_Tensor[:compare_pad_start_idx]).all() + + +def _check_hetero_data_edges( + original: HeteroGraph, + padded: HeteroGraph, + max_num_edges: Optional[Union[int, Dict[EdgeType, int]]] = None, + edge_pad_value: Optional[Padding] = None, + is_mask_available: bool = False, + exclude_keys: Optional[List[str]] = None, +): + if is_mask_available: + for store in padded.edge_stores: + assert 'pad_edge_mask' in store + + for edge_type, attr in _generate_heterodata_edges(padded): + if attr in exclude_keys: + assert attr not in padded[edge_type].keys() + continue + + assert attr in padded[edge_type].keys() + + if not isinstance(padded[edge_type][attr], Tensor): + continue + + compare_pad_start_idx = original[edge_type].num_edges + padded_Tensor = padded[edge_type][attr] + + if attr == 'pad_edge_mask': + assert padded_Tensor.numel() == padded[edge_type].num_edges + assert ops.all(padded_Tensor[:compare_pad_start_idx]) + assert not ops.any(padded_Tensor[compare_pad_start_idx:]) + continue + + original_Tensor = original[edge_type][attr] + + if isinstance(max_num_edges, numbers.Number): + expected_num_edges = max_num_edges + elif max_num_edges is None or edge_type not in max_num_edges.keys(): + v1, _, v2 = edge_type + expected_num_edges = padded[v1].num_nodes * padded[v2].num_nodes + else: + expected_num_edges = max_num_edges[edge_type] + + if attr == 'edge_index': + # Check the number of edges. + assert padded_Tensor.shape[1] == expected_num_edges + + # Check padded area values. + src_nodes = original[edge_type[0]].num_nodes + assert (ops.flatten(padded_Tensor[0, compare_pad_start_idx:]) == src_nodes).all() + dst_nodes = original[edge_type[2]].num_nodes + assert (ops.flatten(padded_Tensor[1, compare_pad_start_idx:]) == dst_nodes).all() + + # Compare non-padded area with the original. + assert ops.equal(original_Tensor, padded_Tensor[:, :compare_pad_start_idx]).all() + else: + # Check padded area size. + assert padded_Tensor.shape[0] == expected_num_edges + + # Check padded area values. + pad_value = edge_pad_value.get_value( + edge_type, attr) if edge_pad_value is not None else 0.0 + assert (ops.flatten(padded_Tensor[compare_pad_start_idx:, :]) == pad_value).all() + + # Compare non-padded area with the original. + assert ops.equal(original_Tensor, padded_Tensor[:compare_pad_start_idx, :]).all() + + +def _check_data( + original: Union[Graph, HeteroGraph], + padded: Union[Graph, HeteroGraph], + max_num_nodes: Union[int, Dict[NodeType, int]], + max_num_edges: Optional[Union[int, Dict[EdgeType, int]]] = None, + node_pad_value: Optional[Union[Padding, int, float]] = None, + edge_pad_value: Optional[Union[Padding, int, float]] = None, + is_mask_available: bool = False, + exclude_keys: Optional[List[str]] = None, +): + + if not isinstance(node_pad_value, Padding) and node_pad_value is not None: + node_pad_value = UniformPadding(node_pad_value) + if not isinstance(edge_pad_value, Padding) and edge_pad_value is not None: + edge_pad_value = UniformPadding(edge_pad_value) + + if is_mask_available is None: + is_mask_available = False + + if exclude_keys is None: + exclude_keys = [] + + if isinstance(original, Graph): + _check_homo_data_nodes(original, padded, max_num_nodes, node_pad_value, + is_mask_available, exclude_keys) + _check_homo_data_edges(original, padded, max_num_edges, edge_pad_value, + is_mask_available, exclude_keys) + else: + _check_hetero_data_nodes(original, padded, max_num_nodes, + node_pad_value, is_mask_available, + exclude_keys) + _check_hetero_data_edges(original, padded, max_num_edges, + edge_pad_value, is_mask_available, + exclude_keys) + + +def test_pad_repr(): + pad_str = 'Pad(max_num_nodes=10, max_num_edges=15, ' \ + 'node_pad_value=UniformPadding(value=3.0), ' \ + 'edge_pad_value=UniformPadding(value=1.5))' + assert str(eval(pad_str)) == pad_str + + +@ pytest.mark.parametrize('data', [fake_data(), fake_hetero_data()]) +@ pytest.mark.parametrize('num_nodes', [32, 64]) +@ pytest.mark.parametrize('add_pad_mask', [True, False]) +def test_pad_auto_edges(data, num_nodes, add_pad_mask): + transform = Pad(max_num_nodes=num_nodes, add_pad_mask=add_pad_mask) + + out = transform(data) + _check_data(data, out, num_nodes, is_mask_available=add_pad_mask) + + +@ pytest.mark.parametrize('num_nodes', [32, 64]) +@ pytest.mark.parametrize('num_edges', [300, 411]) +@ pytest.mark.parametrize('add_pad_mask', [True, False]) +def test_pad_data_explicit_edges(num_nodes, num_edges, add_pad_mask): + data = fake_data() + transform = Pad(max_num_nodes=num_nodes, max_num_edges=num_edges, + add_pad_mask=add_pad_mask) + + out = transform(data) + _check_data(data, out, num_nodes, num_edges, + is_mask_available=add_pad_mask) + + +@ pytest.mark.parametrize('num_nodes', [32, {'v0': 64, 'v1': 36}]) +@ pytest.mark.parametrize('num_edges', [300, {('v0', 'e0', 'v1'): 397}]) +@ pytest.mark.parametrize('add_pad_mask', [True, False]) +def test_pad_heterodata_explicit_edges(num_nodes, num_edges, add_pad_mask): + data = fake_hetero_data() + transform = Pad(max_num_nodes=num_nodes, max_num_edges=num_edges, + add_pad_mask=add_pad_mask) + + out = transform(data) + _check_data(data, out, num_nodes, num_edges, + is_mask_available=add_pad_mask) + + +@ pytest.mark.parametrize('node_pad_value', [10, AttrNamePadding({'x': 3.0})]) +@ pytest.mark.parametrize('edge_pad_value', + [11, AttrNamePadding({'edge_attr': 2.0})]) +def test_pad_data_pad_values(node_pad_value, edge_pad_value): + data = fake_data() + num_nodes = 32 + transform = Pad(max_num_nodes=num_nodes, node_pad_value=node_pad_value, + edge_pad_value=edge_pad_value) + out = transform(data) + _check_data(data, out, num_nodes, node_pad_value=node_pad_value, + edge_pad_value=edge_pad_value) + + +@ pytest.mark.parametrize('node_pad_value', [ + UniformPadding(12), + AttrNamePadding({'x': 0}), + NodeTypePadding({ + 'v0': UniformPadding(12), + 'v1': AttrNamePadding({'x': 7}) + }) +]) +@ pytest.mark.parametrize('edge_pad_value', [ + UniformPadding(13), + EdgeTypePadding({ + ('v0', 'e0', 'v1'): + UniformPadding(13), + ('v1', 'e0', 'v0'): + AttrNamePadding({'edge_attr': UniformPadding(-1.0)}) + }) +]) +def test_pad_heterodata_pad_values(node_pad_value, edge_pad_value): + data = fake_hetero_data() + num_nodes = 32 + transform = Pad(max_num_nodes=num_nodes, node_pad_value=node_pad_value, + edge_pad_value=edge_pad_value) + + out = transform(data) + _check_data(data, out, num_nodes, node_pad_value=node_pad_value, + edge_pad_value=edge_pad_value) + + +@ pytest.mark.parametrize('data', [fake_data(), fake_hetero_data()]) +@ pytest.mark.parametrize('add_pad_mask', [True, False]) +@ pytest.mark.parametrize('exclude_keys', [ + ['y'], + ['edge_attr'], + ['y', 'edge_attr'], +]) +def test_pad_data_exclude_keys(data, add_pad_mask, exclude_keys): + num_nodes = 32 + transform = Pad(max_num_nodes=num_nodes, add_pad_mask=add_pad_mask, + exclude_keys=exclude_keys) + + out = transform(data) + _check_data(data, out, num_nodes, is_mask_available=add_pad_mask, + exclude_keys=exclude_keys) + + +@ pytest.mark.parametrize('is_hetero', [False, True]) +def test_pad_invalid_max_num_nodes(is_hetero): + if is_hetero: + data = fake_hetero_data(node_types=1) + else: + data = fake_data() + + transform = Pad(max_num_nodes=data.num_nodes - 1) + + with pytest.raises(AssertionError, match="after padding"): + transform(data) + + +@ pytest.mark.parametrize('is_hetero', [False, True]) +def test_pad_invalid_max_num_edges(is_hetero): + if is_hetero: + data = fake_hetero_data(node_types=1, edge_types=1) + else: + data = fake_data() + + transform = Pad(max_num_nodes=data.num_nodes + 10, + max_num_edges=data.num_edges - 1) + + with pytest.raises(AssertionError, match="after padding"): + transform(data) + + +def test_pad_num_nodes_not_complete(): + data = fake_hetero_data(node_types=2, edge_types=1) + transform = Pad(max_num_nodes={'v0': 100}) + + with pytest.raises(KeyError): + transform(data) + + +def test_pad_invalid_padding_type(): + with pytest.raises(ValueError, match="to be an integer or float"): + Pad(max_num_nodes=100, node_pad_value='somestring') + with pytest.raises(ValueError, match="to be an integer or float"): + Pad(max_num_nodes=100, edge_pad_value='somestring') + + +def test_pad_data_non_Tensor_attr(): + data = fake_data() + batch_size = 13 + data.batch_size = batch_size + + transform = Pad(max_num_nodes=100) + padded = transform(data) + assert padded.batch_size == batch_size + + exclude_transform = Pad(max_num_nodes=101, exclude_keys=('batch_size', )) + padded = exclude_transform(data) + assert 'batch_size' not in padded.keys() + + +@ pytest.mark.parametrize('mask_pad_value', [True, False]) +def test_pad_node_additional_attr_mask(mask_pad_value): + data = fake_data() + mask = ops.randn(data.num_nodes) > 0 + mask_names = ['train_mask', 'test_mask', 'val_mask'] + for mask_name in mask_names: + setattr(data, mask_name, mask) + padding_num = 20 + + max_num_nodes = int(data.num_nodes) + padding_num + max_num_edges = data.num_edges + padding_num + + transform = Pad(max_num_nodes, max_num_edges, node_pad_value=0.1, + mask_pad_value=mask_pad_value) + padded = transform(data) + padded_masks = [getattr(padded, mask_name) for mask_name in mask_names] + + for padded_mask in padded_masks: + assert padded_mask.ndim == 1 + assert padded_mask.shape[0] == max_num_nodes + assert ops.all(padded_mask[-padding_num:] == mask_pad_value) + + +def test_uniform_padding(): + pad_val = 10.0 + p = UniformPadding(pad_val) + assert p.get_value() == pad_val + assert p.get_value("v1", "x") == pad_val + + p = UniformPadding() + assert p.get_value() == 0.0 + + with pytest.raises(ValueError, match="to be an integer or float"): + UniformPadding('') + + +def test_attr_name_padding(): + x_val = 10.0 + y_val = 15.0 + default = 3.0 + padding_dict = {'x': x_val, 'y': UniformPadding(y_val)} + padding = AttrNamePadding(padding_dict, default=default) + + assert padding.get_value(attr_name='x') == x_val + assert padding.get_value('v1', 'x') == x_val + assert padding.get_value(attr_name='y') == y_val + assert padding.get_value('v1', 'y') == y_val + assert padding.get_value(attr_name='x2') == default + + padding = AttrNamePadding({}) + assert padding.get_value(attr_name='x') == 0.0 + + +def test_attr_name_padding_invalid(): + with pytest.raises(ValueError, match="to be a dictionary"): + AttrNamePadding(10.0) + + with pytest.raises(ValueError, match="to be a string"): + AttrNamePadding({10: 10.0}) + + with pytest.raises(ValueError, match="to be of type"): + AttrNamePadding({"x": {}}) + + with pytest.raises(ValueError, match="to be of type"): + AttrNamePadding({"x": {}}) + + node_type_padding = NodeTypePadding({"x": 10.0}) + with pytest.raises(ValueError, match="to be of type"): + AttrNamePadding({'x': node_type_padding}) + + +@ pytest.mark.parametrize('store_type', ['node', 'edge']) +def test_node_edge_type_padding(store_type): + if store_type == "node": + stores = ['v1', 'v2', 'v3', 'v4'] + padding_cls = NodeTypePadding + else: + stores = [('v1', 'e1', 'v1'), ('v1', 'e2', 'v1'), ('v1', 'e1', 'v2'), + ('v2', 'e1', 'v1')] + padding_cls = EdgeTypePadding + + s0_default = 3.0 + s0_padding_dict = {'x': 10.0, 'y': -12.0} + s0_padding = AttrNamePadding(s0_padding_dict, s0_default) + s1_default = 0.1 + s1_padding_dict = {'y': 0.0, 'p': 13.0} + s1_padding = AttrNamePadding(s1_padding_dict, s1_default) + + s2_default = 7.5 + store_default = -11.0 + padding_dict = { + stores[0]: s0_padding, + stores[1]: s1_padding, + stores[2]: s2_default + } + padding = padding_cls(padding_dict, store_default) + + assert padding.get_value(stores[0], 'x') == s0_padding_dict['x'] + assert padding.get_value(stores[0], 'y') == s0_padding_dict['y'] + assert padding.get_value(stores[0], 'p') == s0_default + assert padding.get_value(stores[0], 'z') == s0_default + + assert padding.get_value(stores[1], 'x') == s1_default + assert padding.get_value(stores[1], 'y') == s1_padding_dict['y'] + assert padding.get_value(stores[1], 'p') == s1_padding_dict['p'] + assert padding.get_value(stores[1], 'z') == s1_default + + assert padding.get_value(stores[2], 'x') == s2_default + assert padding.get_value(stores[2], 'z') == s2_default + + assert padding.get_value(stores[3], 'x') == store_default + + +def test_edge_padding_invalid(): + with pytest.raises(ValueError, match="to be a tuple"): + EdgeTypePadding({'v1': 10.0}) + + with pytest.raises(ValueError, match="got 1"): + EdgeTypePadding({('v1', ): 10.0}) + + with pytest.raises(ValueError, match="got 2"): + EdgeTypePadding({('v1', 'v2'): 10.0}) + + with pytest.raises(ValueError, match="got 4"): + EdgeTypePadding({('v1', 'e2', 'v1', 'v2'): 10.0}) diff --git a/tests/graph/transforms/test_point_pair_features.py b/tests/graph/transforms/test_point_pair_features.py new file mode 100644 index 000000000..faaf1d7e2 --- /dev/null +++ b/tests/graph/transforms/test_point_pair_features.py @@ -0,0 +1,40 @@ +from math import pi as PI + +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import PointPairFeatures + + +def test_point_pair_features(): + transform = PointPairFeatures() + assert str(transform) == 'PointPairFeatures()' + + pos = Tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + edge_index = Tensor([[0, 1], [1, 0]]) + norm = Tensor([[1.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + edge_attr = Tensor([1.0, 1.0]) + data = Graph(edge_index=edge_index, crd=pos, norm=norm) + + data = transform(data) + assert len(data) == 4 + assert data.crd.tolist() == pos.tolist() + assert data.norm.tolist() == norm.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert ops.isclose( + data.edge_attr, + Tensor([[1.0, 0.0, 0.0, 0.0], [1.0, PI, PI, 0.0]]), + atol=1e-4, + ).all() + + data = Graph(edge_index=edge_index, crd=pos, norm=norm, edge_attr=edge_attr) + data = transform(data) + assert len(data) == 4 + assert data.crd.tolist() == pos.tolist() + assert data.norm.tolist() == norm.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert ops.isclose( + data.edge_attr, + Tensor([[1.0, 1.0, 0.0, 0.0, 0.0], [1.0, 1.0, PI, PI, 0.0]]), + atol=1e-4, + ).all() diff --git a/tests/graph/transforms/test_polar.py b/tests/graph/transforms/test_polar.py new file mode 100644 index 000000000..a9322393f --- /dev/null +++ b/tests/graph/transforms/test_polar.py @@ -0,0 +1,36 @@ +from math import pi as PI + +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import Polar + + +def test_polar(): + assert str(Polar()) == 'Polar(norm=True, max_value=None)' + + pos = Tensor([[0.0, 0.0], [1.0, 0.0]]) + edge_index = Tensor([[0, 1], [1, 0]]) + edge_attr = Tensor([1.0, 1.0]) + + data = Graph(edge_index=edge_index, crd=pos) + data = Polar(norm=False)(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert ops.isclose( + data.edge_attr, + Tensor([[1.0, 0.0], [1.0, PI]]), + atol=1e-4, + ).all() + + data = Graph(edge_index=edge_index, crd=pos, edge_attr=edge_attr) + data = Polar(norm=True)(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert ops.isclose( + data.edge_attr, + Tensor([[1.0, 1.0, 0.0], [1.0, 1.0, 0.5]]), + atol=1e-4, + ).all() diff --git a/tests/graph/transforms/test_radius_graph.py b/tests/graph/transforms/test_radius_graph.py new file mode 100644 index 000000000..b127012b3 --- /dev/null +++ b/tests/graph/transforms/test_radius_graph.py @@ -0,0 +1,25 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import RadiusGraph +from mindscience.sharker.utils import coalesce + + +def test_radius_graph(): + assert str(RadiusGraph(r=1)) == 'RadiusGraph(r=1)' + + pos = Tensor([ + [0.0, 0.0], + [1.0, 0.0], + [2.0, 0.0], + [0.0, 1.0], + [-2.0, 0.0], + [0.0, -2.0], + ]) + + data = Graph(crd=pos) + data = RadiusGraph(r=1.5)(data) + assert len(data) == 2 + assert data.crd.tolist() == pos.tolist() + assert coalesce(data.edge_index).tolist() == [[0, 0, 1, 1, 1, 2, 3, 3], + [1, 3, 0, 2, 3, 1, 0, 1]] diff --git a/tests/graph/transforms/test_random_flip.py b/tests/graph/transforms/test_random_flip.py new file mode 100644 index 000000000..4ba768cd2 --- /dev/null +++ b/tests/graph/transforms/test_random_flip.py @@ -0,0 +1,20 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import RandomFlip + + +def test_random_flip(): + assert str(RandomFlip(axis=0)) == 'RandomFlip(axis=0, p=0.5)' + + pos = Tensor([[-1.0, 1.0], [-3.0, 0.0], [2.0, -1.0]]) + + data = Graph(crd=pos) + data = RandomFlip(axis=0, p=1)(data) + assert len(data) == 1 + assert data.crd.tolist() == [[1.0, 1.0], [3.0, 0.0], [-2.0, -1.0]] + + data = Graph(crd=pos) + data = RandomFlip(axis=1, p=1)(data) + assert len(data) == 1 + assert data.crd.tolist() == [[-1.0, -1.0], [-3.0, 0.0], [2.0, 1.0]] diff --git a/tests/graph/transforms/test_random_jitter.py b/tests/graph/transforms/test_random_jitter.py new file mode 100644 index 000000000..8de83fe16 --- /dev/null +++ b/tests/graph/transforms/test_random_jitter.py @@ -0,0 +1,29 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import RandomJitter + + +def test_random_jitter(): + assert str(RandomJitter(0.1)) == 'RandomJitter(0.1)' + + pos = Tensor([[0.0, 0.0], [0.0, 0.0], [0.0, 0.0], [0.0, 0.0]]) + + data = Graph(crd=pos) + data = RandomJitter(0)(data) + assert len(data) == 1 + assert ops.isclose(data.crd, pos).all() + + data = Graph(crd=pos) + data = RandomJitter(0.1)(data) + assert len(data) == 1 + assert data.crd.min() >= -0.1 + assert data.crd.max() <= 0.1 + + data = Graph(crd=pos) + data = RandomJitter([0.1, 1])(data) + assert len(data) == 1 + assert data.crd[:, 0].min() >= -0.1 + assert data.crd[:, 0].max() <= 0.1 + assert data.crd[:, 1].min() >= -1 + assert data.crd[:, 1].max() <= 1 diff --git a/tests/graph/transforms/test_random_link_split.py b/tests/graph/transforms/test_random_link_split.py new file mode 100644 index 000000000..db0035b51 --- /dev/null +++ b/tests/graph/transforms/test_random_link_split.py @@ -0,0 +1,323 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.testing import ( + get_random_edge_index, + onlyFullTest, + onlyOnline, +) +from mindscience.sharker.transforms import RandomLinkSplit +from mindscience.sharker.utils import is_undirected, to_undirected + + +def test_random_link_split(): + assert str(RandomLinkSplit()) == ('RandomLinkSplit(' + 'num_val=0.1, num_test=0.2)') + + edge_index = Tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]) + edge_attr = ops.randn(edge_index.shape[1], 3) + + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=100) + + # No test split: + transform = RandomLinkSplit(num_val=2, num_test=0, is_undirected=True) + train_data, val_data, test_data = transform(data) + + assert len(train_data) == 5 + assert train_data.num_nodes == 100 + assert train_data.edge_index.shape == (2, 6) + assert train_data.edge_attr.shape == (6, 3) + assert train_data.edge_label_index.shape[1] == 6 + assert train_data.edge_label.shape[0] == 6 + + assert len(val_data) == 5 + assert val_data.num_nodes == 100 + assert val_data.edge_index.shape == (2, 6) + assert val_data.edge_attr.shape == (6, 3) + assert val_data.edge_label_index.shape[1] == 4 + assert val_data.edge_label.shape[0] == 4 + + assert len(test_data) == 5 + assert test_data.num_nodes == 100 + assert test_data.edge_index.shape == (2, 10) + assert test_data.edge_attr.shape == (10, 3) + assert test_data.edge_label_index.shape == (2, 0) + assert test_data.edge_label.shape == (0, ) + + # Percentage split: + transform = RandomLinkSplit(num_val=0.2, num_test=0.2, + neg_sampling_ratio=2.0, is_undirected=False) + train_data, val_data, test_data = transform(data) + + assert len(train_data) == 5 + assert train_data.num_nodes == 100 + assert train_data.edge_index.shape == (2, 6) + assert train_data.edge_attr.shape == (6, 3) + assert train_data.edge_label_index.shape[1] == 18 + assert train_data.edge_label.shape[0] == 18 + + assert len(val_data) == 5 + assert val_data.num_nodes == 100 + assert val_data.edge_index.shape == (2, 6) + assert val_data.edge_attr.shape == (6, 3) + assert val_data.edge_label_index.shape[1] == 6 + assert val_data.edge_label.shape[0] == 6 + + assert len(test_data) == 5 + assert test_data.num_nodes == 100 + assert test_data.edge_index.shape == (2, 8) + assert test_data.edge_attr.shape == (8, 3) + assert test_data.edge_label_index.shape[1] == 6 + assert test_data.edge_label.shape[0] == 6 + + # Disjoint training split: + transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=False, + disjoint_train_ratio=0.5) + train_data, val_data, test_data = transform(data) + + assert len(train_data) == 5 + assert train_data.num_nodes == 100 + assert train_data.edge_index.shape == (2, 3) + assert train_data.edge_attr.shape == (3, 3) + assert train_data.edge_label_index.shape[1] == 6 + assert train_data.edge_label.shape[0] == 6 + + +def test_random_link_split_with_label(): + edge_index = Tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]) + edge_label = Tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + + data = Graph(edge_index=edge_index, edge_label=edge_label, num_nodes=6) + + transform = RandomLinkSplit(num_val=0.2, num_test=0.2, + neg_sampling_ratio=0.0) + train_data, _, _ = transform(data) + assert len(train_data) == 4 + assert train_data.num_nodes == 6 + assert train_data.edge_index.shape == (2, 6) + assert train_data.edge_label_index.shape == (2, 6) + assert train_data.edge_label.shape == (6, ) + assert train_data.edge_label.min() == 0 + assert train_data.edge_label.max() == 1 + + transform = RandomLinkSplit(num_val=0.2, num_test=0.2, + neg_sampling_ratio=1.0) + train_data, _, _ = transform(data) + assert len(train_data) == 4 + assert train_data.num_nodes == 6 + assert train_data.edge_index.shape == (2, 6) + assert train_data.edge_label_index.shape == (2, 12) + assert train_data.edge_label.shape == (12, ) + assert train_data.edge_label.min() == 0 + assert train_data.edge_label.max() == 2 + assert train_data.edge_label[6:].sum() == 0 + + +def test_random_link_split_increment_label(): + edge_index = Tensor([[0, 1, 1, 2, 2, 3, 3, 4, 4, 5], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4]]) + edge_label = Tensor([0, 0, 0, 0, 0, 1, 1, 1, 1, 1]) + + data = Graph(edge_index=edge_index, edge_label=edge_label, num_nodes=6) + + transform = RandomLinkSplit(num_val=0, num_test=0, neg_sampling_ratio=0.0) + train_data, _, _ = transform(data) + assert train_data.edge_label.numel() == edge_index.shape[1] + assert train_data.edge_label.min() == 0 + assert train_data.edge_label.max() == 1 + + transform = RandomLinkSplit(num_val=0, num_test=0, neg_sampling_ratio=1.0) + train_data, _, _ = transform(data) + assert train_data.edge_label.numel() == 2 * edge_index.shape[1] + assert train_data.edge_label.min() == 0 + assert train_data.edge_label.max() == 2 + assert train_data.edge_label[edge_index.shape[1]:].sum() == 0 + + +def test_random_link_split_on_hetero_data(): + data = HeteroGraph() + + data['p'].x = ops.arange(100) + data['a'].x = ops.arange(100, 300) + + data['p', 'p'].edge_index = get_random_edge_index(100, 100, 500) + data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index) + data['p', 'p'].edge_attr = ops.arange(data['p', 'p'].num_edges) + data['p', 'a'].edge_index = get_random_edge_index(100, 200, 1000) + data['p', 'a'].edge_attr = ops.arange(500, 1500) + data['a', 'p'].edge_index = data['p', 'a'].edge_index.flip([0]) + data['a', 'p'].edge_attr = ops.arange(1500, 2500) + + transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=True, + edge_types=('p', 'p')) + train_data, val_data, test_data = transform(data) + + assert len(train_data['p']) == 1 + assert len(train_data['a']) == 1 + assert len(train_data['p', 'p']) == 4 + assert len(train_data['p', 'a']) == 2 + assert len(train_data['a', 'p']) == 2 + + assert is_undirected(train_data['p', 'p'].edge_index, + train_data['p', 'p'].edge_attr) + assert is_undirected(val_data['p', 'p'].edge_index, + val_data['p', 'p'].edge_attr) + assert is_undirected(test_data['p', 'p'].edge_index, + test_data['p', 'p'].edge_attr) + + transform = RandomLinkSplit(num_val=0.2, num_test=0.2, + edge_types=('p', 'a'), + rev_edge_types=('a', 'p')) + train_data, val_data, test_data = transform(data) + + assert len(train_data['p']) == 1 + assert len(train_data['a']) == 1 + assert len(train_data['p', 'p']) == 2 + assert len(train_data['p', 'a']) == 4 + assert len(train_data['a', 'p']) == 2 + + assert train_data['p', 'a'].edge_index.shape == (2, 600) + assert train_data['p', 'a'].edge_attr.shape == (600, ) + assert train_data['p', 'a'].edge_attr.min() >= 500 + assert train_data['p', 'a'].edge_attr.max() <= 1500 + assert train_data['a', 'p'].edge_index.shape == (2, 600) + assert train_data['a', 'p'].edge_attr.shape == (600, ) + assert train_data['a', 'p'].edge_attr.min() >= 500 + assert train_data['a', 'p'].edge_attr.max() <= 1500 + assert train_data['p', 'a'].edge_label_index.shape == (2, 1200) + assert train_data['p', 'a'].edge_label.shape == (1200, ) + + assert val_data['p', 'a'].edge_index.shape == (2, 600) + assert val_data['p', 'a'].edge_attr.shape == (600, ) + assert val_data['p', 'a'].edge_attr.min() >= 500 + assert val_data['p', 'a'].edge_attr.max() <= 1500 + assert val_data['a', 'p'].edge_index.shape == (2, 600) + assert val_data['a', 'p'].edge_attr.shape == (600, ) + assert val_data['a', 'p'].edge_attr.min() >= 500 + assert val_data['a', 'p'].edge_attr.max() <= 1500 + assert val_data['p', 'a'].edge_label_index.shape == (2, 400) + assert val_data['p', 'a'].edge_label.shape == (400, ) + + assert test_data['p', 'a'].edge_index.shape == (2, 800) + assert test_data['p', 'a'].edge_attr.shape == (800, ) + assert test_data['p', 'a'].edge_attr.min() >= 500 + assert test_data['p', 'a'].edge_attr.max() <= 1500 + assert test_data['a', 'p'].edge_index.shape == (2, 800) + assert test_data['a', 'p'].edge_attr.shape == (800, ) + assert test_data['a', 'p'].edge_attr.min() >= 500 + assert test_data['a', 'p'].edge_attr.max() <= 1500 + assert test_data['p', 'a'].edge_label_index.shape == (2, 400) + assert test_data['p', 'a'].edge_label.shape == (400, ) + + transform = RandomLinkSplit(num_val=0.2, num_test=0.2, is_undirected=True, + edge_types=[('p', 'p'), ('p', 'a')], + rev_edge_types=[None, ('a', 'p')]) + train_data, val_data, test_data = transform(data) + + assert len(train_data['p']) == 1 + assert len(train_data['a']) == 1 + assert len(train_data['p', 'p']) == 4 + assert len(train_data['p', 'a']) == 4 + assert len(train_data['a', 'p']) == 2 + + assert is_undirected(train_data['p', 'p'].edge_index, + train_data['p', 'p'].edge_attr) + assert train_data['p', 'a'].edge_index.shape == (2, 600) + assert train_data['a', 'p'].edge_index.shape == (2, 600) + + # No reverse edge types specified: + transform = RandomLinkSplit(edge_types=[('p', 'p'), ('p', 'a')]) + train_data, val_data, test_data = transform(data) + assert train_data['p', 'p'].num_edges < data['p', 'p'].num_edges + assert train_data['p', 'a'].num_edges < data['p', 'a'].num_edges + assert train_data['a', 'p'].num_edges == data['a', 'p'].num_edges + + +def test_random_link_split_on_undirected_hetero_data(): + data = HeteroGraph() + data['p'].x = ops.arange(100) + data['p', 'p'].edge_index = get_random_edge_index(100, 100, 500) + data['p', 'p'].edge_index = to_undirected(data['p', 'p'].edge_index) + + transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p')) + train_data, val_data, test_data = transform(data) + assert train_data['p', 'p'].is_undirected() + + transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'), + rev_edge_types=('p', 'p')) + train_data, val_data, test_data = transform(data) + assert train_data['p', 'p'].is_undirected() + + transform = RandomLinkSplit(is_undirected=True, edge_types=('p', 'p'), + rev_edge_types=('p', 'p')) + train_data, val_data, test_data = transform(data) + assert train_data['p', 'p'].is_undirected() + + +def test_random_link_split_insufficient_negative_edges(): + edge_index = Tensor([[0, 0, 1, 1, 2, 2], [1, 3, 0, 2, 0, 1]]) + data = Graph(edge_index=edge_index, num_nodes=4) + + transform = RandomLinkSplit(num_val=0.34, num_test=0.34, + is_undirected=False, neg_sampling_ratio=2, + split_labels=True) + + with pytest.warns(UserWarning, match="not enough negative edges"): + train_data, val_data, test_data = transform(data) + + assert train_data.neg_edge_label_index.shape == (2, 2) + assert val_data.neg_edge_label_index.shape == (2, 2) + assert test_data.neg_edge_label_index.shape == (2, 2) + + +# def test_random_link_split_non_contiguous(): +# edge_index = get_random_edge_index(40, 40, num_edges=150) +# edge_index = edge_index[:, :100] +# assert not edge_index.is_contiguous() + +# data = Graph(edge_index=edge_index, num_nodes=40) +# transform = RandomLinkSplit(num_val=0.2, num_test=0.2) +# train_data, val_data, test_data = transform(data) +# assert train_data.num_edges == 60 +# assert train_data.edge_index.is_contiguous() + +# data = HeteroGraph() +# data['p'].num_nodes = 40 +# data['p', 'p'].edge_index = edge_index +# transform = RandomLinkSplit(num_val=0.2, num_test=0.2, +# edge_types=('p', 'p')) +# train_data, val_data, test_data = transform(data) +# assert train_data['p', 'p'].num_edges == 60 +# assert train_data['p', 'p'].edge_index.is_contiguous() + + +@onlyOnline +def test_random_link_split_on_dataset(get_dataset): + dataset = get_dataset(name='MUTAG') + + dataset.transform = RandomLinkSplit( + num_val=0.1, + num_test=0.1, + disjoint_train_ratio=0.3, + add_negative_train_samples=False, + ) + + train_dataset, val_dataset, test_dataset = zip(*dataset) + assert len(train_dataset) == len(dataset) + assert len(val_dataset) == len(dataset) + assert len(test_dataset) == len(dataset) + + assert isinstance(train_dataset[0], Graph) + assert train_dataset[0].edge_label.min() == 1.0 + assert train_dataset[0].edge_label.max() == 1.0 + + assert isinstance(val_dataset[0], Graph) + assert val_dataset[0].edge_label.min() == 0.0 + assert val_dataset[0].edge_label.max() == 1.0 + + assert isinstance(test_dataset[0], Graph) + assert test_dataset[0].edge_label.min() == 0.0 + assert test_dataset[0].edge_label.max() == 1.0 diff --git a/tests/graph/transforms/test_random_node_split.py b/tests/graph/transforms/test_random_node_split.py new file mode 100644 index 000000000..068cb5451 --- /dev/null +++ b/tests/graph/transforms/test_random_node_split.py @@ -0,0 +1,159 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import RandomNodeSplit + + +@pytest.mark.parametrize('num_splits', [1, 2]) +def test_random_node_split(num_splits): + num_nodes, num_classes = 1000, 4 + x = ops.randn(num_nodes, 16) + y = ops.randint(0, num_classes, (num_nodes, ), dtype=ms.int64) + data = Graph(x=x, y=y) + + transform = RandomNodeSplit(split='train_rest', num_splits=num_splits, + num_val=100, num_test=200) + assert str(transform) == 'RandomNodeSplit(split=train_rest)' + data = transform(data) + assert len(data) == 5 + + train_mask = data.train_mask + train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask + assert train_mask.shape == (num_nodes, num_splits) + val_mask = data.val_mask + val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask + assert val_mask.shape == (num_nodes, num_splits) + test_mask = data.test_mask + test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask + assert test_mask.shape == (num_nodes, num_splits) + + for i in range(train_mask.shape[-1]): + assert train_mask[:, i].sum() == num_nodes - 100 - 200 + assert val_mask[:, i].sum() == 100 + assert test_mask[:, i].sum() == 200 + assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 + assert ((train_mask[:, i] | val_mask[:, i] + | test_mask[:, i]).sum() == num_nodes) + + transform = RandomNodeSplit(split='train_rest', num_splits=num_splits, + num_val=0.1, num_test=0.2) + data = transform(data) + + train_mask = data.train_mask + train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask + val_mask = data.val_mask + val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask + test_mask = data.test_mask + test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask + + for i in range(train_mask.shape[-1]): + assert train_mask[:, i].sum() == num_nodes - 100 - 200 + assert val_mask[:, i].sum() == 100 + assert test_mask[:, i].sum() == 200 + assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 + assert ((train_mask[:, i] | val_mask[:, i] + | test_mask[:, i]).sum() == num_nodes) + + transform = RandomNodeSplit(split='test_rest', num_splits=num_splits, + num_train_per_class=10, num_val=100) + assert str(transform) == 'RandomNodeSplit(split=test_rest)' + data = transform(data) + assert len(data) == 5 + + train_mask = data.train_mask + train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask + val_mask = data.val_mask + val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask + test_mask = data.test_mask + test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask + + for i in range(train_mask.shape[-1]): + assert train_mask[:, i].sum() == 10 * num_classes + assert val_mask[:, i].sum() == 100 + assert test_mask[:, i].sum() == num_nodes - 10 * num_classes - 100 + assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 + assert ((train_mask[:, i] | val_mask[:, i] + | test_mask[:, i]).sum() == num_nodes) + + transform = RandomNodeSplit(split='test_rest', num_splits=num_splits, + num_train_per_class=10, num_val=0.1) + data = transform(data) + + train_mask = data.train_mask + train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask + val_mask = data.val_mask + val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask + test_mask = data.test_mask + test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask + + for i in range(train_mask.shape[-1]): + assert train_mask[:, i].sum() == 10 * num_classes + assert val_mask[:, i].sum() == 100 + assert test_mask[:, i].sum() == num_nodes - 10 * num_classes - 100 + assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 + assert ((train_mask[:, i] | val_mask[:, i] + | test_mask[:, i]).sum() == num_nodes) + + transform = RandomNodeSplit(split='random', num_splits=num_splits, + num_train_per_class=10, num_val=100, + num_test=200) + assert str(transform) == 'RandomNodeSplit(split=random)' + data = transform(data) + assert len(data) == 5 + + train_mask = data.train_mask + train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask + val_mask = data.val_mask + val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask + test_mask = data.test_mask + test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask + + for i in range(train_mask.shape[-1]): + assert train_mask[:, i].sum() == 10 * num_classes + assert val_mask[:, i].sum() == 100 + assert test_mask[:, i].sum() == 200 + assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 + assert ((train_mask[:, i] | val_mask[:, i] + | test_mask[:, i]).sum() == 10 * num_classes + 100 + 200) + + transform = RandomNodeSplit(split='random', num_splits=num_splits, + num_train_per_class=10, num_val=0.1, + num_test=0.2) + assert str(transform) == 'RandomNodeSplit(split=random)' + data = transform(data) + + train_mask = data.train_mask + train_mask = train_mask.unsqueeze(-1) if num_splits == 1 else train_mask + val_mask = data.val_mask + val_mask = val_mask.unsqueeze(-1) if num_splits == 1 else val_mask + test_mask = data.test_mask + test_mask = test_mask.unsqueeze(-1) if num_splits == 1 else test_mask + + for i in range(train_mask.shape[-1]): + assert train_mask[:, i].sum() == 10 * num_classes + assert val_mask[:, i].sum() == 100 + assert test_mask[:, i].sum() == 200 + assert (train_mask[:, i] & val_mask[:, i] & test_mask[:, i]).sum() == 0 + assert ((train_mask[:, i] | val_mask[:, i] + | test_mask[:, i]).sum() == 10 * num_classes + 100 + 200) + + +def test_random_node_split_on_hetero_data(): + data = HeteroGraph() + + data['paper'].x = ops.randn(2000, 16) + data['paper'].y = ops.randint(0, 4, (2000, ), dtype=ms.int64) + data['author'].x = ops.randn(300, 16) + + transform = RandomNodeSplit() + assert str(transform) == 'RandomNodeSplit(split=train_rest)' + data = transform(data) + assert len(data) == 5 + + assert len(data['author']) == 1 + assert len(data['paper']) == 5 + + assert data['paper'].train_mask.sum() == 500 + assert data['paper'].val_mask.sum() == 500 + assert data['paper'].test_mask.sum() == 1000 diff --git a/tests/graph/transforms/test_random_rotate.py b/tests/graph/transforms/test_random_rotate.py new file mode 100644 index 000000000..85c27f134 --- /dev/null +++ b/tests/graph/transforms/test_random_rotate.py @@ -0,0 +1,46 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import RandomRotate + + +def test_random_rotate(): + assert str(RandomRotate([-180, 180])) == ('RandomRotate(' + '[-180, 180], axis=0)') + + pos = Tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) + + data = Graph(crd=pos) + data = RandomRotate(0)(data) + assert len(data) == 1 + assert data.crd.tolist() == pos.tolist() + + data = Graph(crd=pos) + data = RandomRotate([180, 180])(data) + assert len(data) == 1 + assert data.crd.tolist() == [[1, 1], [1, -1], [-1, 1], [-1, -1]] + + pos = Tensor([ + [-1.0, -1.0, 1.0], + [-1.0, 1.0, 1.0], + [1.0, -1.0, -1.0], + [1.0, 1.0, -1.0], + ]) + + data = Graph(crd=pos) + data = RandomRotate([180, 180], axis=0)(data) + assert len(data) == 1 + assert data.crd.tolist() == [[-1, 1, -1], [-1, -1, -1], [1, 1, 1], + [1, -1, 1]] + + data = Graph(crd=pos) + data = RandomRotate([180, 180], axis=1)(data) + assert len(data) == 1 + assert data.crd.tolist() == [[1, -1, -1], [1, 1, -1], [-1, -1, 1], + [-1, 1, 1]] + + data = Graph(crd=pos) + data = RandomRotate([180, 180], axis=2)(data) + assert len(data) == 1 + assert data.crd.tolist() == [[1, 1, 1], [1, -1, 1], [-1, 1, -1], + [-1, -1, -1]] diff --git a/tests/graph/transforms/test_random_scale.py b/tests/graph/transforms/test_random_scale.py new file mode 100644 index 000000000..a5ceb2d35 --- /dev/null +++ b/tests/graph/transforms/test_random_scale.py @@ -0,0 +1,20 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import RandomScale + + +def test_random_scale(): + assert str(RandomScale([1, 2])) == 'RandomScale([1, 2])' + + pos = Tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) + + data = Graph(crd=pos) + data = RandomScale([1, 1])(data) + assert len(data) == 1 + assert data.crd.tolist() == pos.tolist() + + data = Graph(crd=pos) + data = RandomScale([2, 2])(data) + assert len(data) == 1 + assert data.crd.tolist() == [[-2, -2], [-2, 2], [2, -2], [2, 2]] diff --git a/tests/graph/transforms/test_random_shear.py b/tests/graph/transforms/test_random_shear.py new file mode 100644 index 000000000..4ed6cac62 --- /dev/null +++ b/tests/graph/transforms/test_random_shear.py @@ -0,0 +1,20 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import RandomShear + + +def test_random_shear(): + assert str(RandomShear(0.1)) == 'RandomShear(0.1)' + + pos = Tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]]) + + data = Graph(crd=pos) + data = RandomShear(0)(data) + assert len(data) == 1 + assert ops.isclose(data.crd, pos).all() + + data = Graph(crd=pos) + data = RandomShear(0.1)(data) + assert len(data) == 1 + assert not ops.isclose(data.crd, pos).all() diff --git a/tests/graph/transforms/test_remove_duplicated_edges.py b/tests/graph/transforms/test_remove_duplicated_edges.py new file mode 100644 index 000000000..02195db3d --- /dev/null +++ b/tests/graph/transforms/test_remove_duplicated_edges.py @@ -0,0 +1,20 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import RemoveDuplicatedEdges + + +def test_remove_duplicated_edges(): + edge_index = Tensor([[0, 0, 0, 0, 1, 1, 1, 1], + [0, 0, 1, 1, 0, 0, 1, 1]]) + edge_weight = Tensor([1, 2, 3, 4, 5, 6, 7, 8]) + data = Graph(edge_index=edge_index, edge_weight=edge_weight, num_nodes=2) + + transform = RemoveDuplicatedEdges() + assert str(transform) == 'RemoveDuplicatedEdges()' + + out = transform(data) + assert len(out) == 3 + assert out.num_nodes == 2 + assert out.edge_index.tolist() == [[0, 0, 1, 1], [0, 1, 0, 1]] + assert out.edge_weight.tolist() == [3, 7, 11, 15] diff --git a/tests/graph/transforms/test_remove_isolated_nodes.py b/tests/graph/transforms/test_remove_isolated_nodes.py new file mode 100644 index 000000000..a3a371ae7 --- /dev/null +++ b/tests/graph/transforms/test_remove_isolated_nodes.py @@ -0,0 +1,51 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import RemoveIsolatedNodes + + +def test_remove_isolated_nodes(): + assert str(RemoveIsolatedNodes()) == 'RemoveIsolatedNodes()' + + data = Graph() + data.x = ops.arange(3) + data.edge_index = Tensor([[0, 2], [2, 0]]) + data.edge_attr = ops.arange(2) + + data = RemoveIsolatedNodes()(data) + + assert len(data) == 3 + assert data.x.tolist() == [0, 2] + assert data.edge_index.tolist() == [[0, 1], [1, 0]] + assert data.edge_attr.tolist() == [0, 1] + + +def test_remove_isolated_nodes_in_hetero_data(): + data = HeteroGraph() + + data['p'].x = ops.arange(6) + data['a'].x = ops.arange(6) + data['i'].num_nodes = 4 + + # isolated paper nodes: {4} + # isolated author nodes: {3, 4, 5} + # isolated institution nodes: {0, 1, 2, 3} + data['p', '1', 'p'].edge_index = Tensor([[0, 1, 2], [0, 1, 3]]) + data['p', '2', 'a'].edge_index = Tensor([[1, 3, 5], [0, 1, 2]]) + data['p', '2', 'a'].edge_attr = ops.arange(3) + data['p', '3', 'a'].edge_index = Tensor([[5], [2]]) + + data = RemoveIsolatedNodes()(data) + + assert len(data) == 4 + assert data['p'].num_nodes == 5 + assert data['a'].num_nodes == 3 + assert data['i'].num_nodes == 0 + + assert data['p'].x.tolist() == [0, 1, 2, 3, 5] + assert data['a'].x.tolist() == [0, 1, 2] + + assert data['1'].edge_index.tolist() == [[0, 1, 2], [0, 1, 3]] + assert data['2'].edge_index.tolist() == [[1, 3, 4], [0, 1, 2]] + assert data['2'].edge_attr.tolist() == [0, 1, 2] + assert data['3'].edge_index.tolist() == [[4], [2]] diff --git a/tests/graph/transforms/test_remove_training_classes.py b/tests/graph/transforms/test_remove_training_classes.py new file mode 100644 index 000000000..5b2258b5a --- /dev/null +++ b/tests/graph/transforms/test_remove_training_classes.py @@ -0,0 +1,19 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import RemoveTrainingClasses + + +def test_remove_training_classes(): + y = Tensor([1, 0, 0, 2, 1, 3]) + train_mask = Tensor([False, False, True, True, True, True]) + + data = Graph(y=y, train_mask=train_mask) + + transform = RemoveTrainingClasses(classes=[0, 1]) + assert str(transform) == 'RemoveTrainingClasses([0, 1])' + + data = transform(data) + assert len(data) == 2 + assert ops.equal(data.y, y).all() + assert data.train_mask.tolist() == [False, False, False, True, False, True] diff --git a/tests/graph/transforms/test_rooted_subgraph.py b/tests/graph/transforms/test_rooted_subgraph.py new file mode 100644 index 000000000..10111ccff --- /dev/null +++ b/tests/graph/transforms/test_rooted_subgraph.py @@ -0,0 +1,87 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.loader import DataLoader +from mindscience.sharker.transforms import RootedEgoNets, RootedRWSubgraph + + +def test_rooted_ego_nets(): + x = ops.randn(3, 8) + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_attr = ops.randn(4, 8) + data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr) + + transform = RootedEgoNets(num_hops=1) + assert str(transform) == 'RootedEgoNets(num_hops=1)' + + out = transform(data) + assert len(out) == 8 + + assert ops.equal(out.x, data.x).all() + assert ops.equal(out.edge_index, data.edge_index).all() + assert ops.equal(out.edge_attr, data.edge_attr).all() + + assert out.sub_edge_index.tolist() == [[0, 1, 2, 3, 3, 4, 5, 6], + [1, 0, 3, 2, 4, 3, 6, 5]] + assert out.n_id.tolist() == [0, 1, 0, 1, 2, 1, 2] + assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 1, 2, 2] + assert out.e_id.tolist() == [0, 1, 0, 1, 2, 3, 2, 3] + assert out.e_sub_batch.tolist() == [0, 0, 1, 1, 1, 1, 2, 2] + + out = out.map_data() + assert len(out) == 4 + + assert ops.isclose(out.x, x[[0, 1, 0, 1, 2, 1, 2]]).all() + assert out.edge_index.tolist() == [[0, 1, 2, 3, 3, 4, 5, 6], + [1, 0, 3, 2, 4, 3, 6, 5]] + assert ops.isclose(out.edge_attr, edge_attr[[0, 1, 0, 1, 2, 3, 2, 3]]).all() + assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 1, 2, 2] + + +def test_rooted_rw_subgraph(): + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + data = Graph(edge_index=edge_index, num_nodes=3) + + transform = RootedRWSubgraph(walk_length=1) + assert str(transform) == 'RootedRWSubgraph(walk_length=1)' + + out = transform(data) + assert len(out) == 7 + + assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 2, 2] + assert out.sub_edge_index.shape == (2, 6) + + out = out.map_data() + assert len(out) == 3 + + assert out.edge_index.shape == (2, 6) + assert out.num_nodes == 6 + assert out.n_sub_batch.tolist() == [0, 0, 1, 1, 2, 2] + + +def test_rooted_subgraph_minibatch(): + x = ops.randn(3, 8) + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_attr = ops.randn(4, 8) + data = Graph(x=x, edge_index=edge_index, edge_attr=edge_attr) + + transform = RootedEgoNets(num_hops=1) + data = transform(data) + + loader = DataLoader([data, data], batch_size=2) + batch = next(iter(loader)) + batch = batch.map_data() + assert batch.num_graphs == len(batch) == 2 + + assert batch.x.shape == (14, 8) + assert batch.edge_index.shape == (2, 16) + assert batch.edge_attr.shape == (16, 8) + assert batch.n_sub_batch.shape == (14, ) + assert batch.batch.shape == (14, ) + assert batch.ptr.shape == (3, ) + + assert batch.edge_index.min() == 0 + assert batch.edge_index.max() == 13 + + assert batch.n_sub_batch.min() == 0 + assert batch.n_sub_batch.max() == 5 diff --git a/tests/graph/transforms/test_sample_points.py b/tests/graph/transforms/test_sample_points.py new file mode 100644 index 000000000..02694ff80 --- /dev/null +++ b/tests/graph/transforms/test_sample_points.py @@ -0,0 +1,31 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import SamplePoints + + +def test_sample_points(): + assert str(SamplePoints(1024)) == 'SamplePoints(1024)' + + pos = Tensor([ + [0.0, 0.0, 0.0], + [1.0, 0.0, 0.0], + [0.0, 1.0, 0.0], + [1.0, 1.0, 0.0], + ]) + face = Tensor([[0, 1], [1, 2], [2, 3]]) + + data = Graph(crd=pos) + data.face = face + data = SamplePoints(8)(data) + assert len(data) == 1 + assert pos[:, 0].min() >= 0 and pos[:, 0].max() <= 1 + assert pos[:, 1].min() >= 0 and pos[:, 1].max() <= 1 + assert pos[:, 2].abs().sum() == 0 + + data = Graph(crd=pos) + data.face = face + data = SamplePoints(8, include_normals=True)(data) + assert len(data) == 2 + assert data.normal[:, :2].abs().sum() == 0 + assert data.normal[:, 2].abs().sum() == 8 diff --git a/tests/graph/transforms/test_sign.py b/tests/graph/transforms/test_sign.py new file mode 100644 index 000000000..bd290fd9b --- /dev/null +++ b/tests/graph/transforms/test_sign.py @@ -0,0 +1,32 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import SIGN + + +def test_sign(): + x = ops.ones([5, 3]) + edge_index = Tensor([ + [0, 1, 2, 3, 3, 4], + [1, 0, 3, 2, 4, 3], + ]) + data = Graph(x=x, edge_index=edge_index) + + transform = SIGN(K=2) + assert str(transform) == 'SIGN(K=2)' + + expected_x1 = Tensor([ + [1, 1, 1], + [1, 1, 1], + [0.7071, 0.7071, 0.7071], + [1.4142, 1.4142, 1.4142], + [0.7071, 0.7071, 0.7071], + ]) + expected_x2 = ops.ones([5, 3]) + + out = transform(data) + assert len(out) == 4 + assert ops.equal(out.edge_index, edge_index).all() + assert ops.isclose(out.x, x).all() + assert ops.isclose(out.x1, expected_x1, atol=1e-4).all() + assert ops.isclose(out.x2, expected_x2).all() diff --git a/tests/graph/transforms/test_spherical.py b/tests/graph/transforms/test_spherical.py new file mode 100644 index 000000000..ca549da30 --- /dev/null +++ b/tests/graph/transforms/test_spherical.py @@ -0,0 +1,61 @@ +from math import pi as PI + +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import Spherical + + +def test_spherical(): + assert str(Spherical()) == 'Spherical(norm=True, max_value=None)' + + pos = Tensor([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]]) + edge_index = Tensor([[0, 1], [1, 0]]) + edge_attr = Tensor([1.0, 1.0]) + + data = Graph(edge_index=edge_index, crd=pos) + data = Spherical(norm=False)(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert ops.isclose( + data.edge_attr, + Tensor([[1.0, 0.0, PI / 2.0], [1.0, PI, PI / 2.0]]), + atol=1e-4, + ).all() + + data = Graph(edge_index=edge_index, crd=pos, edge_attr=edge_attr) + data = Spherical(norm=True)(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert ops.isclose( + data.edge_attr, + Tensor([[1.0, 1.0, 0.0, 0.5], [1.0, 1.0, 0.5, 0.5]]), + atol=1e-4, + ).all() + + pos = Tensor([[0.0, 0.0, 0.0], [0.0, 0.0, 1.0]]) + edge_index = Tensor([[0, 1], [1, 0]]) + + data = Graph(edge_index=edge_index, crd=pos) + data = Spherical(norm=False)(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert ops.isclose( + data.edge_attr, + Tensor([[1.0, 0.0, 0.0], [1.0, 0.0, PI]]), + atol=1e-4, + ).all() + + data = Graph(edge_index=edge_index, crd=pos, edge_attr=edge_attr) + data = Spherical(norm=True)(data) + assert len(data) == 3 + assert data.crd.tolist() == pos.tolist() + assert data.edge_index.tolist() == edge_index.tolist() + assert ops.isclose( + data.edge_attr, + Tensor([[1.0, 1.0, 0.0, 0.0], [1.0, 1.0, 0.0, 1.0]]), + atol=1e-4, + ).all() diff --git a/tests/graph/transforms/test_svd_feature_reduction.py b/tests/graph/transforms/test_svd_feature_reduction.py new file mode 100644 index 000000000..9b24e0bac --- /dev/null +++ b/tests/graph/transforms/test_svd_feature_reduction.py @@ -0,0 +1,19 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import SVDFeatureReduction + + +def test_svd_feature_reduction(): + assert str(SVDFeatureReduction(10)) == 'SVDFeatureReduction(10)' + + x = ops.randn(4, 16) + S, U, _ = ops.svd(x) + data = Graph(x=x) + data = SVDFeatureReduction(10)(data) + assert ops.isclose(data.x, ops.mm(U[:, :10], ops.diag(S[:10]))).all() + + x = ops.randn(4, 8) + data.x = x + data = SVDFeatureReduction(10)(Graph(x=x)) + assert ops.isclose(data.x, x).all() diff --git a/tests/graph/transforms/test_target_indegree.py b/tests/graph/transforms/test_target_indegree.py new file mode 100644 index 000000000..ffcf532cf --- /dev/null +++ b/tests/graph/transforms/test_target_indegree.py @@ -0,0 +1,25 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import TargetIndegree + + +def test_target_indegree(): + assert str(TargetIndegree()) == 'TargetIndegree(norm=True, max_value=None)' + + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_attr = Tensor([1.0, 1.0, 1.0, 1.0]) + + data = Graph(edge_index=edge_index, num_nodes=3) + data = TargetIndegree(norm=False)(data) + assert len(data) == 3 + assert data.edge_index.tolist() == edge_index.tolist() + assert data.edge_attr.tolist() == [[2], [1], [1], [2]] + assert data.num_nodes == 3 + + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=3) + data = TargetIndegree(norm=True)(data) + assert len(data) == 3 + assert data.edge_index.tolist() == edge_index.tolist() + assert data.edge_attr.tolist() == [[1, 1], [1, 0.5], [1, 0.5], [1, 1]] + assert data.num_nodes == 3 diff --git a/tests/graph/transforms/test_to_dense.py b/tests/graph/transforms/test_to_dense.py new file mode 100644 index 000000000..95cfe76c2 --- /dev/null +++ b/tests/graph/transforms/test_to_dense.py @@ -0,0 +1,54 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import ToDense + + +def test_to_dense(): + edge_index = Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + edge_attr = Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + num_nodes = edge_index.max().item() + 1 + x = ops.randn((num_nodes, 4)) + pos = ops.randn((num_nodes, 3)) + y = ops.randint(0, 4, (num_nodes, ), dtype=ms.int64) + + transform = ToDense() + assert str(transform) == 'ToDense()' + data = Graph(x=x, crd=pos, edge_index=edge_index, edge_attr=edge_attr, y=y) + data = transform(data) + assert len(data) == 5 + assert data.x.tolist() == x.tolist() + assert data.crd.tolist() == pos.tolist() + assert data.y.tolist() == y.tolist() + assert data.adj.shape == (num_nodes, num_nodes) + assert data.adj.tolist() == [ + [0, 1, 2, 3], + [4, 0, 0, 0], + [5, 0, 0, 0], + [6, 0, 0, 0], + ] + assert data.mask.tolist() == [1, 1, 1, 1] + + transform = ToDense(num_nodes=5) + assert str(transform) == 'ToDense(num_nodes=5)' + data = Graph(x=x, crd=pos, edge_index=edge_index, edge_attr=edge_attr, y=y) + data = transform(data) + assert len(data) == 5 + assert data.x.shape == (5, 4) + assert data.x[:4].tolist() == x.tolist() + assert data.x[4].tolist() == [0, 0, 0, 0] + assert data.crd.shape == (5, 3) + assert data.crd[:4].tolist() == pos.tolist() + assert data.crd[4].tolist() == [0, 0, 0] + assert data.y.shape == (5, ) + assert data.y[:4].tolist() == y.tolist() + assert data.y[4].tolist() == 0 + assert data.adj.shape == (5, 5) + assert data.adj.tolist() == [ + [0, 1, 2, 3, 0], + [4, 0, 0, 0, 0], + [5, 0, 0, 0, 0], + [6, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + ] + assert data.mask.tolist() == [1, 1, 1, 1, 0] diff --git a/tests/graph/transforms/test_to_sparse_tensor.py b/tests/graph/transforms/test_to_sparse_tensor.py new file mode 100644 index 000000000..270e5ff9e --- /dev/null +++ b/tests/graph/transforms/test_to_sparse_tensor.py @@ -0,0 +1,124 @@ +import pytest +from mindspore import Tensor, ops +from mindscience.sharker import typing +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import ToSparseTensor +from mindscience.sharker.sparse import Layout + + +@pytest.mark.parametrize('layout', [None, Layout.COO, Layout.CSR]) +def test_to_sparse_Tensor_basic(layout): + transform = ToSparseTensor(layout=layout) + assert str(transform) == (f'ToSparseTensor(attr=edge_weight, ' + f'layout={layout})') + + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight = ops.randn(edge_index.shape[1]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + perm = Tensor([1, 0, 3, 2]) + + data = Graph(edge_index=edge_index, edge_weight=edge_weight, + edge_attr=edge_attr, num_nodes=3) + data = transform(data) + + assert len(data) == 3 + assert data.num_nodes == 3 + assert ops.equal(data.edge_attr, edge_attr[perm]).all() + assert 'adj_t' in data + + if layout is None and typing.WITH_SPARSE: + row, col, value = data.adj_t.coo() + assert row.tolist() == [0, 1, 1, 2] + assert col.tolist() == [1, 0, 2, 1] + assert ops.equal(value, edge_weight[perm]).all() + else: + adj_t = data.adj_t + assert adj_t.layout == layout or Layout.CSR + if layout != Layout.COO: + adj_t = adj_t.to_coo() + assert adj_t.coalesce().indices.t().tolist() == [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ] + assert ops.equal(adj_t.coalesce().values, edge_weight[perm]).all() + + +def test_to_sparse_Tensor_and_keep_edge_index(): + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight = ops.randn(edge_index.shape[1]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + perm = Tensor([1, 0, 3, 2]) + + data = Graph(edge_index=edge_index, edge_weight=edge_weight, + edge_attr=edge_attr, num_nodes=3) + data = ToSparseTensor(remove_edge_index=False)(data) + + assert len(data) == 5 + assert ops.equal(data.edge_index, edge_index[:, perm]).all() + assert ops.equal(data.edge_weight, edge_weight[perm]).all() + assert ops.equal(data.edge_attr, edge_attr[perm]).all() + + +@pytest.mark.parametrize('layout', [None, Layout.COO, Layout.CSR]) +def test_hetero_to_sparse_Tensor(layout): + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + + data = HeteroGraph() + data['v'].num_nodes = 3 + data['w'].num_nodes = 3 + data['v', 'v'].edge_index = edge_index + data['v', 'w'].edge_index = edge_index + + data = ToSparseTensor(layout=layout)(data) + + if layout is None and typing.WITH_SPARSE: + row, col, value = data['v', 'v'].adj_t.coo() + assert row.tolist() == [0, 1, 1, 2] + assert col.tolist() == [1, 0, 2, 1] + assert value is None + + row, col, value = data['v', 'w'].adj_t.coo() + assert row.tolist() == [0, 1, 1, 2] + assert col.tolist() == [1, 0, 2, 1] + assert value is None + else: + adj_t = data['v', 'v'].adj_t + assert adj_t.layout == layout or Layout.CSR + if layout != Layout.COO: + adj_t = adj_t.to_coo() + assert adj_t.coalesce().indices.t().tolist() == [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ] + assert adj_t.coalesce().values.tolist() == [1., 1., 1., 1.] + + adj_t = data['v', 'w'].adj_t + assert adj_t.layout == layout or Layout.CSR + if layout != Layout.COO: + adj_t = adj_t.to_coo() + assert adj_t.coalesce().indices.t().tolist() == [ + [0, 1, 1, 2], + [1, 0, 2, 1], + ] + assert adj_t.coalesce().values.tolist() == [1., 1., 1., 1.] + + +def test_to_sparse_Tensor_num_nodes_equals_num_edges(): + x = ops.arange(4) + y = ops.arange(4) + edge_index = Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight = ops.randn(edge_index.shape[1]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + perm = Tensor([1, 0, 3, 2]) + + data = Graph(x=x, edge_index=edge_index, edge_weight=edge_weight, + edge_attr=edge_attr, y=y) + data = ToSparseTensor()(data) + + assert len(data) == 4 + assert ops.equal(data.x, x).all() + assert ops.equal(data.y, y).all() + assert ops.equal(data.edge_attr, edge_attr[perm]).all() diff --git a/tests/graph/transforms/test_to_superpixels.py b/tests/graph/transforms/test_to_superpixels.py new file mode 100644 index 000000000..3400a4851 --- /dev/null +++ b/tests/graph/transforms/test_to_superpixels.py @@ -0,0 +1,90 @@ +import os +import os.path as osp +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import download_url, extract_gz +from mindscience.sharker.loader import DataLoader +from mindscience.sharker.testing import onlyOnline, withPackage +from mindscience.sharker.transforms import ToSLIC + +resources = [ + 'https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz', + 'https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz', +] + + +@onlyOnline +@withPackage('torchvision', 'skimage') +def test_to_superpixels(tmp_path): + import sharker.transforms as T + from torchvision.datasets.mnist import ( + MNIST, + read_image_file, + read_label_file, + ) + + raw_folder = osp.join(tmp_path, 'MNIST', 'raw') + processed_folder = osp.join(tmp_path, 'MNIST', 'processed') + + os.makedirs(raw_folder, exist_ok=True) + os.makedirs(processed_folder, exist_ok=True) + for resource in resources: + path = download_url(resource, raw_folder) + extract_gz(path, osp.join(tmp_path, raw_folder)) + + test_set = ( + read_image_file(osp.join(raw_folder, 't10k-images-idx3-ubyte')), + read_label_file(osp.join(raw_folder, 't10k-labels-idx1-ubyte')), + ) + + np.save(test_set, osp.join(processed_folder, 'training.pt')) + np.save(test_set, osp.join(processed_folder, 'test.pt')) + + dataset = MNIST(tmp_path, download=False) + + dataset.transform = T.Compose([T.ToTensor(), ToSLIC()]) + + data, y = dataset[0] + assert len(data) == 2 + assert data.crd.dim() == 2 and data.crd.shape[1] == 2 + assert data.x.dim() == 2 and data.x.shape[1] == 1 + assert data.crd.shape[0] == data.x.shape[0] + assert y == 7 + + loader = DataLoader(dataset, batch_size=2, shuffle=False) + for batch, y in loader: + assert batch.num_graphs == len(batch) == 2 + assert batch.crd.dim() == 2 and batch.crd.shape[1] == 2 + assert batch.x.dim() == 2 and batch.x.shape[1] == 1 + assert batch.batch.dim() == 1 + assert batch.ptr.dim() == 1 + assert batch.crd.shape[0] == batch.x.shape[0] == batch.batch.shape[0] + assert y.tolist() == [7, 2] + break + + dataset.transform = T.Compose( + [T.ToTensor(), ToSLIC(add_seg=True, add_img=True)]) + + data, y = dataset[0] + assert len(data) == 4 + assert data.crd.dim() == 2 and data.crd.shape[1] == 2 + assert data.x.dim() == 2 and data.x.shape[1] == 1 + assert data.crd.shape[0] == data.x.shape[0] + assert data.seg.shape == (1, 28, 28) + assert data.img.shape == (1, 1, 28, 28) + assert data.seg.max().item() + 1 == data.x.shape[0] + assert y == 7 + + loader = DataLoader(dataset, batch_size=2, shuffle=False) + for batch, y in loader: + assert batch.num_graphs == len(batch) == 2 + assert batch.crd.dim() == 2 and batch.crd.shape[1] == 2 + assert batch.x.dim() == 2 and batch.x.shape[1] == 1 + assert batch.batch.dim() == 1 + assert batch.ptr.dim() == 1 + assert batch.crd.shape[0] == batch.x.shape[0] == batch.batch.shape[0] + assert batch.seg.shape == (2, 28, 28) + assert batch.img.shape == (2, 1, 28, 28) + assert y.tolist() == [7, 2] + break diff --git a/tests/graph/transforms/test_to_undirected.py b/tests/graph/transforms/test_to_undirected.py new file mode 100644 index 000000000..638139616 --- /dev/null +++ b/tests/graph/transforms/test_to_undirected.py @@ -0,0 +1,67 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.transforms import ToUndirected + + +def test_to_undirected(): + assert str(ToUndirected()) == 'ToUndirected()' + + edge_index = Tensor([[2, 0, 2], [3, 1, 0]]) + edge_weight = ops.randn(edge_index.shape[1]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + perm = Tensor([1, 2, 1, 2, 0, 0]) + + data = Graph(edge_index=edge_index, edge_weight=edge_weight, + edge_attr=edge_attr, num_nodes=4) + data = ToUndirected()(data) + assert len(data) == 4 + assert data.edge_index.tolist() == [[0, 0, 1, 2, 2, 3], [1, 2, 0, 0, 3, 2]] + assert data.edge_weight.tolist() == edge_weight[perm].tolist() + assert data.edge_attr.tolist() == edge_attr[perm].tolist() + assert data.num_nodes == 4 + + +def test_to_undirected_with_duplicates(): + edge_index = Tensor([[0, 0, 1, 1], [0, 1, 0, 2]]) + edge_weight = ops.ones(4) + + data = Graph(edge_index=edge_index, edge_weight=edge_weight, num_nodes=3) + data = ToUndirected()(data) + assert len(data) == 3 + assert data.edge_index.tolist() == [[0, 0, 1, 1, 2], [0, 1, 0, 2, 1]] + assert data.edge_weight.tolist() == [2, 2, 2, 1, 1] + assert data.num_nodes == 3 + + +def test_hetero_to_undirected(): + edge_index = Tensor([[2, 0], [3, 1]]) + edge_weight = ops.randn(edge_index.shape[1]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + perm = Tensor([1, 1, 0, 0]) + + data = HeteroGraph() + data['v'].num_nodes = 4 + data['w'].num_nodes = 4 + data['v', 'v'].edge_index = edge_index + data['v', 'v'].edge_weight = edge_weight + data['v', 'v'].edge_attr = edge_attr + data['v', 'w'].edge_index = edge_index + data['v', 'w'].edge_weight = edge_weight + data['v', 'w'].edge_attr = edge_attr + + assert not data.is_undirected() + data = ToUndirected()(data) + assert data.is_undirected() + + assert data['v', 'v'].edge_index.tolist() == [[0, 1, 2, 3], [1, 0, 3, 2]] + assert data['v', 'v'].edge_weight.tolist() == edge_weight[perm].tolist() + assert data['v', 'v'].edge_attr.tolist() == edge_attr[perm].tolist() + assert data['v', 'w'].edge_index.tolist() == edge_index.tolist() + assert data['v', 'w'].edge_weight.tolist() == edge_weight.tolist() + assert data['v', 'w'].edge_attr.tolist() == edge_attr.tolist() + assert data['w', 'v'].edge_index.tolist() == [[3, 1], [2, 0]] + assert data['w', 'v'].edge_weight.tolist() == edge_weight.tolist() + assert data['w', 'v'].edge_attr.tolist() == edge_attr.tolist() diff --git a/tests/graph/transforms/test_two_hop.py b/tests/graph/transforms/test_two_hop.py new file mode 100644 index 000000000..05e0aa60e --- /dev/null +++ b/tests/graph/transforms/test_two_hop.py @@ -0,0 +1,27 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import TwoHop + + +def test_two_hop(): + transform = TwoHop() + assert str(transform) == 'TwoHop()' + + edge_index = Tensor([[0, 0, 0, 1, 2, 3], [1, 2, 3, 0, 0, 0]]) + edge_attr = Tensor([1, 2, 3, 1, 2, 3], dtype=ms.float32) + data = Graph(edge_index=edge_index, edge_attr=edge_attr, num_nodes=4) + + data = transform(data) + assert len(data) == 3 + assert data.edge_index.tolist() == [[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]] + assert data.edge_attr.tolist() == [1, 2, 3, 1, 0, 0, 2, 0, 0, 3, 0, 0] + assert data.num_nodes == 4 + + data = Graph(edge_index=edge_index, num_nodes=4) + data = transform(data) + assert len(data) == 2 + assert data.edge_index.tolist() == [[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2]] + assert data.num_nodes == 4 diff --git a/tests/graph/transforms/test_virtual_node.py b/tests/graph/transforms/test_virtual_node.py new file mode 100644 index 000000000..1dcf64af2 --- /dev/null +++ b/tests/graph/transforms/test_virtual_node.py @@ -0,0 +1,38 @@ +import mindspore as ms +from mindspore import Tensor, ops +from mindscience.sharker.data import Graph +from mindscience.sharker.transforms import VirtualNode + + +def test_virtual_node(): + assert str(VirtualNode()) == 'VirtualNode()' + + x = ops.randn(4, 16) + edge_index = Tensor([[2, 0, 2], [3, 1, 0]]) + edge_weight = ops.rand(edge_index.shape[1]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + data = Graph(x=x, edge_index=edge_index, edge_weight=edge_weight, + edge_attr=edge_attr, num_nodes=x.shape[0]) + + data = VirtualNode()(data) + assert len(data) == 6 + + assert data.x.shape == (5, 16) + assert ops.isclose(data.x[:4], x).all() + assert data.x[4:].abs().sum() == 0 + + assert data.edge_index.tolist() == [[2, 0, 2, 0, 1, 2, 3, 4, 4, 4, 4], + [3, 1, 0, 4, 4, 4, 4, 0, 1, 2, 3]] + + assert data.edge_weight.shape == (11, ) + assert ops.isclose(data.edge_weight[:3], edge_weight).all() + assert data.edge_weight[3:].abs().sum() == 8 + + assert data.edge_attr.shape == (11, 8) + assert ops.isclose(data.edge_attr[:3], edge_attr).all() + assert data.edge_attr[3:].abs().sum() == 0 + + assert data.num_nodes == 5 + + assert data.edge_type.tolist() == [0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2] diff --git a/tests/graph/utils/test_assortativity.py b/tests/graph/utils/test_assortativity.py new file mode 100644 index 000000000..c675c1f8b --- /dev/null +++ b/tests/graph/utils/test_assortativity.py @@ -0,0 +1,29 @@ +import pytest +import mindspore as ms +from mindscience.sharker import typing +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import assortativity + + +def test_assortativity(): + # Completely assortative graph: + edge_index = ms.Tensor([[0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 5], + [1, 2, 3, 0, 2, 3, 0, 1, 3, 0, 1, 2, 5, 4]]) + out = assortativity(edge_index) + assert pytest.approx(out, abs=1e-5) == 1.0 + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[6, 6]) + out = assortativity(adj) + assert pytest.approx(out, abs=1e-5) == 1.0 + + # Completely disassortative graph: + edge_index = ms.Tensor([[0, 1, 2, 3, 4, 5, 5, 5, 5, 5], + [5, 5, 5, 5, 5, 0, 1, 2, 3, 4]]) + out = assortativity(edge_index) + assert pytest.approx(out, abs=1e-5) == -1.0 + + if typing.WITH_SPARSE: + adj = SparseTensor.from_edge_index(edge_index, sparse_sizes=[6, 6]) + out = assortativity(adj) + assert pytest.approx(out, abs=1e-5) == -1.0 diff --git a/tests/graph/utils/test_augmentation.py b/tests/graph/utils/test_augmentation.py new file mode 100644 index 000000000..01d3c67cd --- /dev/null +++ b/tests/graph/utils/test_augmentation.py @@ -0,0 +1,96 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker import seed_everything +from mindscience.sharker.utils import ( + add_random_edge, + is_undirected, + mask_feature, + shuffle_node, +) + + +def test_shuffle_node(): + x = ms.Tensor([[0, 1, 2], [3, 4, 5]]).float() + + out = shuffle_node(x, training=False) + assert out[0].tolist() == x.tolist() + assert out[1].tolist() == list(range(len(x))) + + seed_everything(1) + out = shuffle_node(x) + assert out[0].shape == (2, 3) + assert out[1].shape == (2, ) + + + seed_everything(66) + x = ops.arange(21).view(7, 3).float() + batch = ms.Tensor([0, 0, 1, 1, 2, 2, 2]) + out = shuffle_node(x, batch) + assert out[0].shape == (7, 3) + assert out[1].shape == (7, ) + + +def test_mask_feature(): + x = ms.Tensor([[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]]).float() + + out = mask_feature(x, training=False) + assert out[0].tolist() == x.tolist() + assert mint.all(out[1]) + seed_everything(31) + out = mask_feature(x) + assert out[0].tolist() == [[1.0, 0.0, 3.0, 4.0], [5.0, 0.0, 7.0, 8.0], + [9.0, 0.0, 11.0, 12.0]] + assert out[1].tolist() == [[True, False, True, True]] + + + + seed_everything(32) + out = mask_feature(x, mode='row') + assert out[0].tolist() == [[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], + [9.0, 10.0, 11.0, 12.0]] + assert out[1].tolist() == [[True], [True], [True]] + + seed_everything(251) + out = mask_feature(x, mode='all') + assert out[0].tolist() == [[1.0, 2.0, 3.0, 4.0], [0.0, 6.0, 0.0, 0.0], + [0.0, 0.0, 0.0, 0.0]] + + assert out[1].tolist() == [[True, True, True, True], + [False, True, False, False], + [False, False, False, False]] + + seed_everything(251) + out = mask_feature(x, mode='all', fill_value=-1) + + +def test_add_random_edge(): + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + out = add_random_edge(edge_index, p=0.5, training=False) + assert out[0].tolist() == edge_index.tolist() + assert out[1].tolist() == [[], []] + + seed_everything(5) + out = add_random_edge(edge_index, p=0.5) + assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 3, 1, 2], + [1, 0, 2, 1, 3, 2, 0, 3, 0]] + assert out[1].tolist() == [[3, 1, 2], [0, 3, 0]] + + seed_everything(6) + out = add_random_edge(edge_index, p=0.5, force_undirected=True) + assert out[0].tolist() == [[0, 1, 1, 2, 2, 3, 1, 3], + [1, 0, 2, 1, 3, 2, 3, 1]] + assert out[1].tolist() == [[1, 3], [3, 1]] + assert is_undirected(out[0]) + assert is_undirected(out[1]) + + # Test for bipartite graph: + seed_everything(7) + edge_index = ms.Tensor([[0, 1, 2, 3, 4, 5], [2, 3, 1, 4, 2, 1]]) + out = add_random_edge(edge_index, p=0.5, num_nodes=(6, 5)) + assert out[0].tolist() == [[0, 1, 2, 3, 4, 5, 2, 0, 2], + [2, 3, 1, 4, 2, 1, 0, 4, 2]] + assert out[1].tolist() == [[2, 0, 2], [0, 4, 2]] + + with pytest.raises(RuntimeError, match="not supported for `bipartite graphs`"): + add_random_edge(edge_index, force_undirected=True, num_nodes=(6, 5)) diff --git a/tests/graph/utils/test_cluster.py b/tests/graph/utils/test_cluster.py new file mode 100644 index 000000000..07d58e621 --- /dev/null +++ b/tests/graph/utils/test_cluster.py @@ -0,0 +1,27 @@ +from mindspore import Tensor + +from mindscience.sharker.utils.cluster import nearest + + +def test_nearest(): + code_book = Tensor([ + [1., 1., 1.], + [1., 1., 1.], + [1., 1., 1.], + [2., 2., 2.], + [2., 2., 2.], + [2., 2., 2.]]).float() + features = Tensor([ + [1.9000, 2.3000, 1.7000], + [1.9000, 2.3000, 1.7000], + [1.9000, 2.3000, 1.7000], + [1.5000, 2.5000, 2.2000], + [1.5000, 2.5000, 2.2000], + [1.5000, 2.5000, 2.2000], + [0.8000, 0.6000, 1.7000], + [0.8000, 0.6000, 1.7000], + [0.8000, 0.6000, 1.7000]]).float() + results = nearest(code_book.float(), features.float(), batch_x=Tensor( + [0, 0, 1, 1, 2, 2]), batch_y= Tensor([0, 0, 0, 1, 1, 1, 2, 2, 2])) + assert results.tolist() == [0, 0, 0, 0, 3, 3] + diff --git a/tests/graph/utils/test_coalesce.py b/tests/graph/utils/test_coalesce.py new file mode 100644 index 000000000..c6cf46169 --- /dev/null +++ b/tests/graph/utils/test_coalesce.py @@ -0,0 +1,53 @@ +from typing import List, Optional, Tuple + +import mindspore as ms +from mindspore import Tensor, ops, nn + +from mindscience.sharker.utils import coalesce + + +def test_coalesce(): + edge_index = ms.Tensor([[2, 1, 1, 0, 2], [1, 2, 0, 1, 1]]) + edge_attr = ms.Tensor([[1], [2], [3], [4], [5]], dtype=ms.float32) + + out = coalesce(edge_index) + assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + + out = coalesce(edge_index, None) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1] is None + + out = coalesce(edge_index, edge_attr) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1].tolist() == [[4], [3], [2], [6]] + + out = coalesce(edge_index, [edge_attr, edge_attr.view(-1)]) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1][0].tolist() == [[4], [3], [2], [6]] + assert out[1][1].tolist() == [4, 3, 2, 6] + + out = coalesce((edge_index[0], edge_index[1])) + assert isinstance(out, tuple) + assert out[0].tolist() == [0, 1, 1, 2] + assert out[1].tolist() == [1, 0, 2, 1] + + +def test_coalesce_without_duplicates(): + edge_index = ms.Tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) + edge_attr = ms.Tensor([[1], [2], [3], [4]]) + + out = coalesce(edge_index) + assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + + out = coalesce(edge_index, None) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1] is None + + out = coalesce(edge_index, edge_attr) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1].tolist() == [[4], [3], [2], [1]] + + out = coalesce(edge_index, [edge_attr, edge_attr.view(-1)]) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1][0].tolist() == [[4], [3], [2], [1]] + assert out[1][1].tolist() == [4, 3, 2, 1] diff --git a/tests/graph/utils/test_convert.py b/tests/graph/utils/test_convert.py new file mode 100644 index 000000000..b5ad5d04c --- /dev/null +++ b/tests/graph/utils/test_convert.py @@ -0,0 +1,399 @@ +import pytest +import scipy.sparse +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.data import Graph, HeteroGraph +from mindscience.sharker.testing import get_random_edge_index, withPackage +from mindscience.sharker.utils import ( + from_networkx, + from_scipy_sparse_matrix, + from_trimesh, + subgraph, + to_networkx, + to_scipy_sparse_matrix, + to_trimesh, +) + + +def test_to_scipy_sparse_matrix(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + + adj = to_scipy_sparse_matrix(edge_index) + assert isinstance(adj, scipy.sparse.coo_matrix) is True + assert adj.shape == (2, 2) + assert adj.row.tolist() == edge_index[0].tolist() + assert adj.col.tolist() == edge_index[1].tolist() + assert adj.data.tolist() == [1, 1, 1] + + edge_attr = ms.Tensor([1.0, 2.0, 3.0]) + adj = to_scipy_sparse_matrix(edge_index, edge_attr) + assert isinstance(adj, scipy.sparse.coo_matrix) is True + assert adj.shape == (2, 2) + assert adj.row.tolist() == edge_index[0].tolist() + assert adj.col.tolist() == edge_index[1].tolist() + assert adj.data.tolist() == edge_attr.tolist() + + +def test_from_scipy_sparse_matrix(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + adj = to_scipy_sparse_matrix(edge_index) + + out = from_scipy_sparse_matrix(adj) + assert out[0].tolist() == edge_index.tolist() + assert out[1].tolist() == [1, 1, 1] + + +@withPackage('networkx') +def test_to_networkx(): + import networkx as nx + + x = ms.Tensor([[1.0, 2.0], [3.0, 4.0]]) + crd = ms.Tensor([[0.0, 0.0], [1.0, 1.0]]) + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + edge_attr = ms.Tensor([1.0, 2.0, 3.0]) + data = Graph(x=x, crd=crd, edge_index=edge_index, weight=edge_attr) + + for remove_self_loops in [True, False]: + G = to_networkx(data, node_attrs=['x', 'crd'], edge_attrs=['weight'], + remove_self_loops=remove_self_loops) + + assert G.nodes[0]['x'] == [1.0, 2.0] + assert G.nodes[1]['x'] == [3.0, 4.0] + assert G.nodes[0]['crd'] == [0.0, 0.0] + assert G.nodes[1]['crd'] == [1.0, 1.0] + + if remove_self_loops: + assert nx.to_numpy_array(G).tolist() == [[0.0, 1.0], [2.0, 0.0]] + else: + assert nx.to_numpy_array(G).tolist() == [[3.0, 1.0], [2.0, 0.0]] + + +@withPackage('networkx') +def test_from_networkx_set_node_attributes(): + import networkx as nx + + G = nx.path_graph(3) + attrs = { + 0: { + 'x': ms.Tensor([1, 0, 0]) + }, + 1: { + 'x': ms.Tensor([0, 1, 0]) + }, + 2: { + 'x': ms.Tensor([0, 0, 1]) + }, + } + nx.set_node_attributes(G, attrs) + + assert from_networkx(G).x.tolist() == [[1, 0, 0], [0, 1, 0], [0, 0, 1]] + + +@withPackage('networkx') +def test_to_networkx_undirected(): + import networkx as nx + + x = ms.Tensor([[1.0, 2.0], [3.0, 4.0]]) + crd = ms.Tensor([[0.0, 0.0], [1.0, 1.0]]) + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + edge_attr = ms.Tensor([1.0, 2.0, 3.0]) + data = Graph(x=x, crd=crd, edge_index=edge_index, weight=edge_attr) + + for remove_self_loops in [True, False]: + G = to_networkx( + data, + node_attrs=['x', 'crd'], + edge_attrs=['weight'], + remove_self_loops=remove_self_loops, + to_undirected=True, + ) + + assert G.nodes[0]['x'] == [1, 2] + assert G.nodes[1]['x'] == [3, 4] + assert G.nodes[0]['crd'] == [0, 0] + assert G.nodes[1]['crd'] == [1, 1] + + if remove_self_loops: + assert nx.to_numpy_array(G).tolist() == [[0, 2], [2, 0]] + else: + assert nx.to_numpy_array(G).tolist() == [[3, 2], [2, 0]] + + G = to_networkx(data, edge_attrs=['weight'], to_undirected=False) + assert nx.to_numpy_array(G).tolist() == [[3, 1], [2, 0]] + + G = to_networkx(data, edge_attrs=['weight'], to_undirected='upper') + assert nx.to_numpy_array(G).tolist() == [[3, 1], [1, 0]] + + G = to_networkx(data, edge_attrs=['weight'], to_undirected='lower') + assert nx.to_numpy_array(G).tolist() == [[3, 2], [2, 0]] + + +def test_to_networkx_undirected_options(): + import networkx as nx + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 0]]) + data = Graph(edge_index=edge_index, num_nodes=3) + + G = to_networkx(data, to_undirected=True) + assert nx.to_numpy_array(G).tolist() == [[0, 1, 1], [1, 0, 1], [1, 1, 0]] + + G = to_networkx(data, to_undirected='upper') + assert nx.to_numpy_array(G).tolist() == [[0, 1, 0], [1, 0, 1], [0, 1, 0]] + + G = to_networkx(data, to_undirected='lower') + assert nx.to_numpy_array(G).tolist() == [[0, 1, 1], [1, 0, 0], [1, 0, 0]] + + +@withPackage('networkx') +def test_to_networkx_hetero(): + edge_index = get_random_edge_index(5, 10, 20, coalesce=True) + + data = HeteroGraph() + data['global_id'] = 0 + data['author'].x = mint.arange(5) + data['paper'].x = mint.arange(10) + data['author', 'paper'].edge_index = edge_index + data['author', 'paper'].edge_attr = mint.arange(edge_index.shape[1]) + + G = to_networkx(data, node_attrs=['x'], edge_attrs=['edge_attr'], + graph_attrs=['global_id']) + + assert G.number_of_nodes() == 15 + assert G.number_of_edges() == edge_index.shape[1] + + assert G.graph == {'global_id': 0} + + for i, (v, data) in enumerate(G.nodes(data=True)): + assert i == v + assert len(data) == 2 + if i < 5: + assert data['x'] == i + assert data['type'] == 'author' + else: + assert data['x'] == i - 5 + assert data['type'] == 'paper' + + for i, (v, w, data) in enumerate(G.edges(data=True)): + assert v == int(edge_index[0, i]) + assert w == int(edge_index[1, i]) + 5 + assert len(data) == 2 + assert data['type'] == ('author', 'to', 'paper') + assert data['edge_attr'] == i + + +@withPackage('networkx') +def test_from_networkx(): + x = ops.randn(2, 8) + crd = ops.randn(2, 3) + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + edge_attr = ops.randn(edge_index.shape[1]) + perm = ms.Tensor([0, 2, 1]) + data = Graph(x=x, crd=crd, edge_index=edge_index, edge_attr=edge_attr) + G = to_networkx(data, node_attrs=['x', 'crd'], edge_attrs=['edge_attr']) + data = from_networkx(G) + assert len(data) == 4 + assert data.x.tolist() == x.tolist() + assert data.crd.tolist() == crd.tolist() + assert data.edge_index.tolist() == edge_index[:, perm].tolist() + assert data.edge_attr.tolist() == edge_attr[perm].tolist() + + +@withPackage('networkx') +def test_from_networkx_group_attrs(): + x = ops.randn(2, 2) + x1 = ops.randn(2, 4) + x2 = ops.randn(2, 8) + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + edge_attr1 = ops.randn(edge_index.shape[1]) + edge_attr2 = ops.randn(edge_index.shape[1]) + perm = ms.Tensor([0, 2, 1]) + data = Graph(x=x, x1=x1, x2=x2, edge_index=edge_index, + edge_attr1=edge_attr1, edge_attr2=edge_attr2) + G = to_networkx(data, node_attrs=['x', 'x1', 'x2'], + edge_attrs=['edge_attr1', 'edge_attr2']) + data = from_networkx(G, group_node_attrs=['x', 'x2'], group_edge_attrs=all) + assert len(data) == 4 + assert data.x.tolist() == mint.cat(([x, x2]), dim=-1).tolist() + assert data.x1.tolist() == x1.tolist() + assert data.edge_index.tolist() == edge_index[:, perm].tolist() + assert data.edge_attr.tolist() == mint.stack([edge_attr1, edge_attr2], + dim=-1)[perm].tolist() + + +@withPackage('networkx') +def test_networkx_vice_versa_convert(): + import networkx as nx + + G = nx.complete_graph(5) + assert G.is_directed() is False + data = from_networkx(G) + assert data.is_directed() is False + G = to_networkx(data) + assert G.is_directed() is True + G = nx.to_undirected(G) + assert G.is_directed() is False + + +@withPackage('networkx') +def test_from_networkx_non_consecutive(): + import networkx as nx + + graph = nx.Graph() + graph.add_node(4) + graph.add_node(2) + graph.add_edge(4, 2) + for node in graph.nodes(): + graph.nodes[node]['x'] = node + + data = from_networkx(graph) + assert len(data) == 2 + assert data.x.tolist() == [4, 2] + assert data.edge_index.tolist() == [[0, 1], [1, 0]] + + +@withPackage('networkx') +def test_from_networkx_inverse(): + import networkx as nx + + graph = nx.Graph() + graph.add_node(3) + graph.add_node(2) + graph.add_node(1) + graph.add_node(0) + graph.add_edge(3, 1) + graph.add_edge(2, 1) + graph.add_edge(1, 0) + + data = from_networkx(graph) + assert len(data) == 2 + assert data.edge_index.tolist() == [[0, 1, 2, 2, 2, 3], [2, 2, 0, 1, 3, 2]] + assert data.num_nodes == 4 + + +@withPackage('networkx') +def test_from_networkx_non_numeric_labels(): + import networkx as nx + + graph = nx.Graph() + graph.add_node('4') + graph.add_node('2') + graph.add_edge('4', '2') + for node in graph.nodes(): + graph.nodes[node]['x'] = node + data = from_networkx(graph) + assert len(data) == 2 + assert data.x == ['4', '2'] + assert data.edge_index.tolist() == [[0, 1], [1, 0]] + + +@withPackage('networkx') +def test_from_networkx_without_edges(): + import networkx as nx + + graph = nx.Graph() + graph.add_node(1) + graph.add_node(2) + data = from_networkx(graph) + assert len(data) == 2 + assert data.edge_index.shape == (2, 0) + assert data.num_nodes == 2 + + +@withPackage('networkx') +def test_from_networkx_with_same_node_and_edge_attributes(): + import networkx as nx + + G = nx.Graph() + G.add_nodes_from([(0, {'age': 1}), (1, {'age': 6}), (2, {'age': 5})]) + G.add_edges_from([(0, 1, {'age': 2}), (1, 2, {'age': 7})]) + + data = from_networkx(G) + assert len(data) == 4 + assert data.age.tolist() == [1, 6, 5] + assert data.num_nodes == 3 + assert data.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert data.edge_age.tolist() == [2, 2, 7, 7] + + data = from_networkx(G, group_node_attrs=all, group_edge_attrs=all) + assert len(data) == 3 + assert data.x.tolist() == [[1], [6], [5]] + assert data.edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert data.edge_attr.tolist() == [[2], [2], [7], [7]] + + +@withPackage('networkx') +def test_from_networkx_subgraph_convert(): + import networkx as nx + + G = nx.complete_graph(5) + + edge_index = from_networkx(G).edge_index + sub_edge_index_1, _ = subgraph([0, 1, 3, 4], edge_index, + relabel_nodes=True) + + sub_edge_index_2 = from_networkx(G.subgraph([0, 1, 3, 4])).edge_index + + assert sub_edge_index_1.tolist() == sub_edge_index_2.tolist() + + +@withPackage('networkx') +@pytest.mark.parametrize('n', [100]) +@pytest.mark.parametrize('p', [0.8]) +@pytest.mark.parametrize('q', [0.2]) +def test_from_networkx_sbm(n, p, q): + import networkx as nx + G = nx.stochastic_block_model( + sizes=[n // 2, n // 2], + p=[[p, q], [q, p]], + seed=0, + directed=False, + ) + + data = from_networkx(G) + assert data.num_nodes == 100 + assert ops.equal(data.block[:50], mint.zeros(50, dtype=data.block.dtype)).all() + assert ops.equal(data.block[50:], mint.ones(50, dtype=data.block.dtype)).all() + + + +@withPackage('trimesh') +def test_trimesh_vice_versa(): + crd = ms.Tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0]]).float() + face = ms.Tensor([[0, 1, 2], [1, 2, 3]]).t() + + data = Graph(crd=crd, face=face) + mesh = to_trimesh(data) + data = from_trimesh(mesh) + + assert crd.tolist() == data.crd.tolist() + assert face.tolist() == data.face.tolist() + + +@withPackage('trimesh') +def test_to_trimesh(): + import trimesh + + crd = ms.Tensor([[0, 0, 0], [0, 1, 0], [1, 0, 0], [1, 1, 0]]) + face = ms.Tensor([[0, 1, 2], [2, 1, 3]]).t() + data = Graph(crd=crd, face=face) + + obj = to_trimesh(data) + + assert isinstance(obj, trimesh.Trimesh) + assert obj.vertices.shape == (4, 3) + assert obj.faces.shape == (2, 3) + assert obj.vertices.tolist() == data.crd.tolist() + assert obj.faces.tolist() == data.face.t().tolist() + + +@withPackage('trimesh') +def test_from_trimesh(): + import trimesh + + vertices = [[0, 0, 0], [1, 0, 0], [0, 1, 0]] + faces = [[0, 1, 2]] + mesh = trimesh.Trimesh(vertices=vertices, faces=faces, process=False) + + data = from_trimesh(mesh) + + assert data.crd.tolist() == vertices + assert data.face.t().tolist() == faces diff --git a/tests/graph/utils/test_degree.py b/tests/graph/utils/test_degree.py new file mode 100644 index 000000000..a4e168ee7 --- /dev/null +++ b/tests/graph/utils/test_degree.py @@ -0,0 +1,9 @@ +import mindspore as ms +from mindscience.sharker.utils import degree + + +def test_degree(): + row = ms.Tensor([0, 1, 0, 2, 0]) + deg = degree(row, dtype=ms.int64) + assert deg.dtype == ms.int64 + assert deg.tolist() == [3, 1, 1] diff --git a/tests/graph/utils/test_dropout.py b/tests/graph/utils/test_dropout.py new file mode 100644 index 000000000..6f7be2cc3 --- /dev/null +++ b/tests/graph/utils/test_dropout.py @@ -0,0 +1,44 @@ +import mindspore as ms +from mindscience.sharker.utils import ( + dropout_edge, + dropout_node, +) + + +def test_dropout_node(): + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3], + [1, 0, 2, 1, 3, 2], + ]) + + out = dropout_node(edge_index, training=False) + assert edge_index.tolist() == out[0].tolist() + assert out[1].tolist() == [True, True, True, True, True, True] + assert out[2].tolist() == [True, True, True, True] + + + ms.set_seed(5) + out = dropout_node(edge_index) + assert out[0].tolist() == [[1, 2], [2, 1]] + assert out[1].tolist() == [False, False, True, True, False, False] + assert out[2].tolist() == [False, True, True, False] + + + +def test_dropout_edge(): + edge_index = ms.Tensor([[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]]) + + out = dropout_edge(edge_index, training=False) + assert edge_index.tolist() == out[0].tolist() + assert out[1].tolist() == [True, True, True, True, True, True] + + ms.set_seed(5) + out = dropout_edge(edge_index) + assert out[0].tolist() == [[1, 1, 2, 3], [0, 2, 3, 2]] + assert out[1].tolist() == [False, True, True, False, True, True] + + ms.set_seed(6) + out = dropout_edge(edge_index, force_undirected=True) + assert out[0].tolist() == [[2, 3], [3, 2]] + assert out[1].tolist() == [4, 4] + diff --git a/tests/graph/utils/test_embedding.py b/tests/graph/utils/test_embedding.py new file mode 100644 index 000000000..9cbe437de --- /dev/null +++ b/tests/graph/utils/test_embedding.py @@ -0,0 +1,34 @@ +import pytest +import mindspore as ms +from mindspore import nn, ops +from mindscience.sharker.nn.conv import GCNConv +from mindscience.sharker.utils import get_embeddings + + +class GNN(nn.Cell): + def __init__(self): + super().__init__() + self.conv1 = GCNConv(5, 6) + self.conv2 = GCNConv(6, 7) + + def construct(self, x0, edge_index): + x1 = self.conv1(x0, edge_index) + x2 = self.conv2(x1, edge_index) + return [x1, x2] + + +def test_get_embeddings(): + x = ops.randn(6, 5) + edge_index = ms.Tensor([[0, 1, 2, 3, 4], [1, 2, 3, 4, 5]]) + + with pytest.warns(UserWarning, match="any 'MessagePassing' layers"): + intermediate_outs = get_embeddings(nn.Dense(5, 5), x) + assert len(intermediate_outs) == 0 + + model = GNN() + expected_embeddings = model(x, edge_index) + + embeddings = get_embeddings(model, x, edge_index) + assert len(embeddings) == 2 + for expected, out in zip(expected_embeddings, embeddings): + assert ops.isclose(expected, out).all() diff --git a/tests/graph/utils/test_functions.py b/tests/graph/utils/test_functions.py new file mode 100644 index 000000000..1ebbd3ea2 --- /dev/null +++ b/tests/graph/utils/test_functions.py @@ -0,0 +1,15 @@ +import mindspore as ms +from mindscience.sharker.utils import cumsum + + +def test_cumsum(): + """ + cumsum only support {Tensor[Float16], Tensor[Float32], + Tensor[Float64], Tensor[Int32], Tensor[Int8], Tensor[UInt8]} + + """ + x = ms.Tensor([2, 4, 1]) + assert cumsum(x).tolist() == [0, 2, 6, 7] + + x = ms.Tensor([[2, 4], [3, 6]]) + assert cumsum(x, axis=1).tolist() == [[0, 2, 6], [0, 3, 9]] diff --git a/tests/graph/utils/test_geodesic.py b/tests/graph/utils/test_geodesic.py new file mode 100644 index 000000000..8c11fb68d --- /dev/null +++ b/tests/graph/utils/test_geodesic.py @@ -0,0 +1,48 @@ +from math import sqrt + +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.testing import withPackage +from mindscience.sharker.utils import geodesic_distance + + +@withPackage('gdist') +@pytest.mark.skip(reason="No way of currently testing this") +def test_geodesic_distance(): + pos = ms.Tensor([ + [0.0, 0.0, 0.0], + [2.0, 0.0, 0.0], + [0.0, 2.0, 0.0], + [2.0, 2.0, 0.0], + ]) + face = ms.Tensor([[0, 1, 3], [0, 2, 3]]).t() + + out = geodesic_distance(pos, face) + expected = ms.Tensor([ + [0.0, 1.0, 1.0, sqrt(2)], + [1.0, 0.0, sqrt(2), 1.0], + [1.0, sqrt(2), 0.0, 1.0], + [sqrt(2), 1.0, 1.0, 0.0], + ]) + assert ops.isclose(out, expected).all() + assert ops.isclose(out, geodesic_distance(pos, face, num_workers=-1)).all() + + out = geodesic_distance(pos, face, norm=False) + expected = [ + [0, 2, 2, 2 * sqrt(2)], + [2, 0, 2 * sqrt(2), 2], + [2, 2 * sqrt(2), 0, 2], + [2 * sqrt(2), 2, 2, 0], + ] + assert ops.isclose(out, expected).all() + + src = ms.Tensor([0, 0, 0, 0]) + dst = ms.Tensor([0, 1, 2, 3]) + out = geodesic_distance(pos, face, src=src, dst=dst) + expected = ms.Tensor([0.0, 1.0, 1.0, sqrt(2)]) + assert ops.isclose(out, expected).all() + + out = geodesic_distance(pos, face, dst=dst) + expected = ms.Tensor([0.0, 0.0, 0.0, 0.0]) + assert ops.isclose(out, expected).all() diff --git a/tests/graph/utils/test_grid.py b/tests/graph/utils/test_grid.py new file mode 100644 index 000000000..94251f201 --- /dev/null +++ b/tests/graph/utils/test_grid.py @@ -0,0 +1,25 @@ +import mindspore as ms +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.utils import grid + + +def test_grid(): + (row, col), pos = grid(height=3, width=2) + + expected_row = [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2] + expected_col = [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5] + expected_row += [3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5] + expected_col += [0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5] + + expected_pos = [[0, 2], [1, 2], [0, 1], [1, 1], [0, 0], [1, 0]] + + assert row.tolist() == expected_row + assert col.tolist() == expected_col + assert pos.tolist() == expected_pos + + if is_full_test(): + jit = ms.jit(grid) + (row, col), pos = jit(height=3, width=2) + assert row.tolist() == expected_row + assert col.tolist() == expected_col + assert pos.tolist() == expected_pos diff --git a/tests/graph/utils/test_hetero.py b/tests/graph/utils/test_hetero.py new file mode 100644 index 000000000..7e3ce07a0 --- /dev/null +++ b/tests/graph/utils/test_hetero.py @@ -0,0 +1,38 @@ +from mindspore import ops +import mindspore as ms +from mindscience.sharker.testing import get_random_edge_index +from mindscience.sharker.utils.hetero import construct_bipartite_edge_index + + +def test_construct_bipartite_edge_index(): + edge_index = get_random_edge_index(4, 6, num_edges=20) + + edge_index_dict = { + ('author', 'paper'): edge_index, + ('paper', 'author'): edge_index.flip([0]), + } + edge_attr_dict = { + ('author', 'paper'): ops.randn(edge_index.shape[1], 16), + ('paper', 'author'): ops.randn(edge_index.shape[1], 16) + } + + edge_index, edge_attr = construct_bipartite_edge_index( + edge_index_dict, + src_offset_dict={ + ('author', 'paper'): 0, + ('paper', 'author'): 4 + }, + dst_offset_dict={ + 'author': 0, + 'paper': 4 + }, + edge_attr_dict=edge_attr_dict, + ) + + assert edge_index.shape == (2, 40) + assert edge_index.min() >= 0 + assert edge_index[0].max() > 4 and edge_index[1].max() > 6 + assert edge_index.max() <= 10 + assert edge_attr.shape == (40, 16) + assert ops.equal(edge_attr[:20], edge_attr_dict['author', 'paper']).all() + assert ops.equal(edge_attr[20:], edge_attr_dict['paper', 'author']).all() diff --git a/tests/graph/utils/test_homophily.py b/tests/graph/utils/test_homophily.py new file mode 100644 index 000000000..77fae6675 --- /dev/null +++ b/tests/graph/utils/test_homophily.py @@ -0,0 +1,32 @@ +import pytest +import mindspore as ms +from mindscience.sharker import typing +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import homophily + + +def test_homophily(): + edge_index = ms.Tensor([[0, 1, 2, 3], [1, 2, 0, 4]]) + y = ms.Tensor([0, 0, 0, 0, 1]) + batch = ms.Tensor([0, 0, 0, 1, 1]) + row, col = edge_index + if typing.WITH_SPARSE: + adj = SparseTensor(row=row, col=col, sparse_sizes=(5, 5)) + + method = 'edge' + assert pytest.approx(homophily(edge_index, y, method=method)) == 0.75 + if typing.WITH_SPARSE: + assert pytest.approx(homophily(adj, y, method=method)) == 0.75 + assert homophily(edge_index, y, batch, method).tolist() == [1., 0.] + + method = 'node' + assert pytest.approx(homophily(edge_index, y, method=method)) == 0.6 + if typing.WITH_SPARSE: + assert pytest.approx(homophily(adj, y, method=method)) == 0.6 + assert homophily(edge_index, y, batch, method).tolist() == [1., 0.] + + method = 'edge_insensitive' + assert pytest.approx(homophily(edge_index, y, method=method)) == 0.1999999 + if typing.WITH_SPARSE: + assert pytest.approx(homophily(adj, y, method=method)) == 0.1999999 + assert homophily(edge_index, y, batch, method).tolist() == [0., 0.] diff --git a/tests/graph/utils/test_isolated.py b/tests/graph/utils/test_isolated.py new file mode 100644 index 000000000..b68fd7731 --- /dev/null +++ b/tests/graph/utils/test_isolated.py @@ -0,0 +1,45 @@ +import mindspore as ms +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.utils import ( + contains_isolated_nodes, + remove_isolated_nodes, +) + + +def test_contains_isolated_nodes(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + assert not contains_isolated_nodes(edge_index) + assert contains_isolated_nodes(edge_index, num_nodes=3) + + if is_full_test(): + jit = ms.jit(contains_isolated_nodes) + assert not jit(edge_index) + assert jit(edge_index, num_nodes=3) + + edge_index = ms.Tensor([[0, 1, 2, 0], [1, 0, 2, 0]]) + assert contains_isolated_nodes(edge_index) + + +def test_remove_isolated_nodes(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + + out, _, mask = remove_isolated_nodes(edge_index) + assert out.tolist() == [[0, 1, 0], [1, 0, 0]] + assert mask.tolist() == [1, 1] + + if is_full_test(): + jit = ms.jit(remove_isolated_nodes) + out, _, mask = jit(edge_index) + assert out.tolist() == [[0, 1, 0], [1, 0, 0]] + assert mask.tolist() == [1, 1] + + out, _, mask = remove_isolated_nodes(edge_index, num_nodes=3) + assert out.tolist() == [[0, 1, 0], [1, 0, 0]] + assert mask.tolist() == [1, 1, 0] + + edge_index = ms.Tensor([[0, 2, 1, 0, 2], [2, 0, 1, 0, 2]]) + edge_attr = ms.Tensor([1, 2, 3, 4, 5]) + out1, out2, mask = remove_isolated_nodes(edge_index, edge_attr) + assert out1.tolist() == [[0, 1, 0, 1], [1, 0, 0, 1]] + assert out2.tolist() == [1, 2, 4, 5] + assert mask.tolist() == [1, 0, 1] diff --git a/tests/graph/utils/test_laplacian.py b/tests/graph/utils/test_laplacian.py new file mode 100644 index 000000000..9be74c3de --- /dev/null +++ b/tests/graph/utils/test_laplacian.py @@ -0,0 +1,28 @@ +import mindspore as ms +from mindspore import mint +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.utils import get_laplacian + + +def test_get_laplacian(): + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=ms.int64) + edge_weight = ms.Tensor([1, 2, 2, 4]).float() + + lap = get_laplacian(edge_index, edge_weight) + assert lap[0].tolist() == [[0, 1, 1, 2, 0, 1, 2], [1, 0, 2, 1, 0, 1, 2]] + assert lap[1].tolist() == [-1, -2, -2, -4, 1, 4, 4] + + if is_full_test(): + jit = ms.jit(get_laplacian) + lap = jit(edge_index, edge_weight) + assert lap[0].tolist() == [[0, 1, 1, 2, 0, 1, 2], + [1, 0, 2, 1, 0, 1, 2]] + assert lap[1].tolist() == [-1, -2, -2, -4, 1, 4, 4] + + lap_sym = get_laplacian(edge_index, edge_weight, normalization='sym') + assert lap_sym[0].tolist() == lap[0].tolist() + assert mint.isclose(lap_sym[1], ms.Tensor([-0.5, -1., -0.5, -1., 1., 1., 1.]), rtol=1e-04, atol=1e-04).all() + + lap_rw = get_laplacian(edge_index, edge_weight, normalization='rw') + assert lap_rw[0].tolist() == lap[0].tolist() + assert lap_rw[1].tolist() == [-1, -0.5, -0.5, -1, 1, 1, 1] diff --git a/tests/graph/utils/test_lexsort.py b/tests/graph/utils/test_lexsort.py new file mode 100644 index 000000000..c1ba62a7d --- /dev/null +++ b/tests/graph/utils/test_lexsort.py @@ -0,0 +1,11 @@ +import numpy as np +from mindspore import Tensor, ops + +from mindscience.sharker.utils import lexsort + + +def test_lexsort(): + keys = [ops.randn(100) for _ in range(3)] + + expected = np.lexsort([key.numpy() for key in keys]) + assert ops.equal(lexsort(keys), Tensor.from_numpy(expected)).all() diff --git a/tests/graph/utils/test_loop.py b/tests/graph/utils/test_loop.py new file mode 100644 index 000000000..ef6293386 --- /dev/null +++ b/tests/graph/utils/test_loop.py @@ -0,0 +1,187 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.utils import ( + add_remaining_self_loops, + add_self_loops, + contains_self_loops, + get_self_loop_attr, + remove_self_loops, + segregate_self_loops +) + + +def test_contains_self_loops(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + assert contains_self_loops(edge_index) + + edge_index = ms.Tensor([[0, 1, 1], [1, 0, 2]]) + assert not contains_self_loops(edge_index) + + +def test_remove_self_loops(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + edge_attr = ms.Tensor([[1, 2], [3, 4], [5, 6]]) + + expected = [[0, 1], [1, 0]] + + out = remove_self_loops(edge_index) + assert out[0].tolist() == expected + assert out[1] is None + + out = remove_self_loops(edge_index, edge_attr) + assert out[0].tolist() == expected + assert out[1].tolist() == [[1, 2], [3, 4]] + + out = remove_self_loops(edge_index, edge_attr) + assert out[0].tolist() == expected + assert out[0].shape == (2, 2) + assert out[1].tolist() == [[1, 2], [3, 4]] + + +def test_segregate_self_loops(): + edge_index = ms.Tensor([[0, 0, 1], [0, 1, 0]]) + + out = segregate_self_loops(edge_index) + assert out[0].tolist() == [[0, 1], [1, 0]] + assert out[1] is None + assert out[2].tolist() == [[0], [0]] + assert out[3] is None + + edge_attr = ms.Tensor([1, 2, 3]) + out = segregate_self_loops(edge_index, edge_attr) + assert out[0].tolist() == [[0, 1], [1, 0]] + assert out[1].tolist() == [2, 3] + assert out[2].tolist() == [[0], [0]] + assert out[3].tolist() == [1] + + out = segregate_self_loops(edge_index, edge_attr) + assert out[0].tolist() == [[0, 1], [1, 0]] + assert out[0].shape == (2, 2) + assert out[1].tolist() == [2, 3] + assert out[2].tolist() == [[0], [0]] + assert out[2].shape == (2, 1) + assert out[3].tolist() == [1] + + +def test_add_self_loops(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + edge_weight = ms.Tensor([0.5, 0.5, 0.5]) + edge_attr = mint.eye(3) + + + expected = [[0, 1, 0, 0, 1], [1, 0, 0, 0, 1]] + assert add_self_loops(edge_index)[0].tolist() == expected + + out = add_self_loops(edge_index, edge_weight) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 1., 1.] + + out = add_self_loops(edge_index, edge_weight, fill_value=5) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 5.0, 5.0] + + out = add_self_loops(edge_index, edge_weight, fill_value=ms.Tensor(2.)) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 2., 2.] + + out = add_self_loops(edge_index, edge_weight, fill_value='add') + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 1, 0.5] + + # Tests with `edge_attr`: + out = add_self_loops(edge_index, edge_attr) + assert out[0].tolist() == expected + assert out[1].tolist() == [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.], + [1., 1., 1.], [1., 1., 1.]] + + out = add_self_loops(edge_index, edge_attr, + fill_value=ms.Tensor([0., 1., 0.])) + assert out[0].tolist() == expected + assert out[1].tolist() == [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.], + [0., 1., 0.], [0., 1., 0.]] + + out = add_self_loops(edge_index, edge_attr, fill_value='add') + assert out[0].tolist() == expected + assert out[1].tolist() == [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.], + [0., 1., 1.], [1., 0., 0.]] + + +def test_add_self_loops_bipartite(): + edge_index = ms.Tensor([[0, 1, 2, 3], [0, 0, 1, 1]]) + + edge_index, _ = add_self_loops(edge_index, num_nodes=(4, 2)) + assert edge_index.tolist() == [[0, 1, 2, 3, 0, 1], [0, 0, 1, 1, 0, 1]] + +def test_add_remaining_self_loops(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + edge_weight = ms.Tensor([0.5, 0.5, 0.5]) + edge_attr = mint.eye(3) + + expected = [[0, 1, 0, 1], [1, 0, 0, 1]] + + out = add_remaining_self_loops(edge_index, edge_weight) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 1] + + out = add_remaining_self_loops(edge_index, edge_weight, fill_value=5) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 5.0] + + out = add_remaining_self_loops(edge_index, edge_weight, + fill_value=ms.Tensor(2.)) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 2.0] + + out = add_remaining_self_loops(edge_index, edge_weight, fill_value='add') + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 0.5] + + # Test with `edge_attr`: + out = add_remaining_self_loops(edge_index, edge_attr, + fill_value=ms.Tensor([0., 1., 0.])) + assert out[0].tolist() == expected + assert out[1].tolist() == [[1., 0., 0.], [0., 1., 0.], [0., 0., 1.], + [0., 1., 0.]] + + +def test_add_remaining_self_loops_without_initial_loops(): + edge_index = ms.Tensor([[0, 1], [1, 0]]) + edge_weight = ms.Tensor([0.5, 0.5]) + + expected = [[0, 1, 0, 1], [1, 0, 0, 1]] + + out = add_remaining_self_loops(edge_index, edge_weight) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 1, 1] + + out = add_remaining_self_loops(edge_index, edge_weight, fill_value=5) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 5.0, 5.0] + + out = add_remaining_self_loops(edge_index, edge_weight, + fill_value=ms.Tensor(2.0)) + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 2.0, 2.0] + + # Test string `fill_value`: + out = add_remaining_self_loops(edge_index, edge_weight, fill_value='add') + assert out[0].tolist() == expected + assert out[1].tolist() == [0.5, 0.5, 0.5, 0.5] + + +def test_get_self_loop_attr(): + edge_index = ms.Tensor([[0, 1, 0], [1, 0, 0]]) + edge_weight = ms.Tensor([0.2, 0.3, 0.5]) + + full_loop_weight = get_self_loop_attr(edge_index, edge_weight) + assert full_loop_weight.tolist() == [0.5, 0.0] + + full_loop_weight = get_self_loop_attr(edge_index, edge_weight, num_nodes=4) + assert full_loop_weight.tolist() == [0.5, 0.0, 0.0, 0.0] + + full_loop_weight = get_self_loop_attr(edge_index) + assert full_loop_weight.tolist() == [1.0, 0.0] + + edge_attr = ms.Tensor([[1.0, 0.0], [0.0, 1.0], [0.5, 1.0]]) + full_loop_attr = get_self_loop_attr(edge_index, edge_attr) + assert full_loop_attr.tolist() == [[0.5, 1.0], [0.0, 0.0]] diff --git a/tests/graph/utils/test_map.py b/tests/graph/utils/test_map.py new file mode 100644 index 000000000..863459771 --- /dev/null +++ b/tests/graph/utils/test_map.py @@ -0,0 +1,74 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import withPackage +from mindscience.sharker.utils.map import map_index + + +@withPackage('pandas') +@pytest.mark.parametrize('max_index', [3, 100_000_000]) +def test_map_index(max_index): + src = ms.Tensor([2, 0, 1, 0, max_index]) + index = ms.Tensor([max_index, 2, 0, 1]) + + out, mask = map_index(src, index, inclusive=True) + assert mask is None + assert out.tolist() == [1, 2, 3, 2, 0] + + +@withPackage('pandas') +@pytest.mark.parametrize('max_index', [3, 100_000_000]) +def test_map_index_na(max_index): + src = ms.Tensor([2, 0, 1, 0, max_index]) + index = ms.Tensor([max_index, 2, 0]) + + out, mask = map_index(src, index, inclusive=False) + assert out.tolist() == [1, 2, 2, 0] + assert mask.tolist() == [True, True, False, True, True] + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default='cuda') + args = parser.parse_args() + + src = ops.randint(0, 100_000_000, (100_000, )) + index = src.unique() + + def trivial_map(src, index, max_index, inclusive): + if max_index is None: + max_index = max(src.max(), index.max()) + + if inclusive: + assoc = src.new_empty(max_index + 1) + else: + assoc = ops.full((max_index + 1, ), -1, dtype=src.dtype) + assoc[index] = mint.arange(index.numel()) + out = assoc[src] + + if inclusive: + return out, None + else: + mask = out != -1 + return out[mask], mask + + print('Inclusive:') + benchmark( + funcs=[trivial_map, map_index], + func_names=['trivial', 'map_index'], + args=(src, index, None, True), + num_steps=100, + num_warmups=50, + ) + + print('Exclusive:') + benchmark( + funcs=[trivial_map, map_index], + func_names=['trivial', 'map_index'], + args=(src, index[:50_000], None, False), + num_steps=100, + num_warmups=50, + ) diff --git a/tests/graph/utils/test_mask.py b/tests/graph/utils/test_mask.py new file mode 100644 index 000000000..c3f5ee935 --- /dev/null +++ b/tests/graph/utils/test_mask.py @@ -0,0 +1,28 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.utils import index_to_mask, mask_select, mask_to_index + + +def test_mask_select(): + src = ops.randn(6, 8) + mask = ms.Tensor([False, True, False, True, False, True]) + + out = mask_select(src, 0, mask) + assert out.shape == (3, 8) + assert ops.equal(src[ms.Tensor([1, 3, 5])], out).all() + +def test_index_to_mask(): + index = ms.Tensor([1, 3, 5]) + + mask = index_to_mask(index) + assert mask.tolist() == [False, True, False, True, False, True] + + mask = index_to_mask(index, size=7) + assert mask.tolist() == [False, True, False, True, False, True, False] + + +def test_mask_to_index(): + mask = ms.Tensor([False, True, False, True, False, True]) + + index = mask_to_index(mask) + assert index.tolist() == [1, 3, 5] diff --git a/tests/graph/utils/test_mesh_laplacian.py b/tests/graph/utils/test_mesh_laplacian.py new file mode 100644 index 000000000..ed770d57b --- /dev/null +++ b/tests/graph/utils/test_mesh_laplacian.py @@ -0,0 +1,101 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.utils import get_mesh_laplacian + + +def test_get_mesh_laplacian_of_cube(): + pos = ms.Tensor([ + [1.0, 1.0, 1.0], + [1.0, -1.0, 1.0], + [-1.0, -1.0, 1.0], + [-1.0, 1.0, 1.0], + [1.0, 1.0, -1.0], + [1.0, -1.0, -1.0], + [-1.0, -1.0, -1.0], + [-1.0, 1.0, -1.0], + ]) + + face = ms.Tensor([ + [0, 1, 2], + [0, 3, 2], + [4, 5, 1], + [4, 0, 1], + [7, 6, 5], + [7, 4, 5], + [3, 2, 6], + [3, 7, 6], + [4, 0, 3], + [4, 7, 3], + [1, 5, 6], + [1, 2, 6], + ]) + + edge_index, edge_weight = get_mesh_laplacian(pos, face.t(), + normalization='rw') + + assert edge_index.tolist() == [ + [ + 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 4, 4, 4, 4, + 4, 5, 5, 5, 5, 6, 6, 6, 6, 6, 7, 7, 7, 7, 0, 1, 2, 3, 4, 5, 6, 7 + ], + [ + 1, 2, 3, 4, 0, 2, 4, 5, 6, 0, 1, 3, 6, 0, 2, 4, 6, 7, 0, 1, 3, 5, + 7, 1, 4, 6, 7, 1, 2, 3, 5, 7, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7 + ], + ] + + assert mint.isclose( + edge_weight, + ms.Tensor([ + 0.375, 0.0, 0.375, 0.375, 0.3, 0.3, 0.0, 0.3, 0.0, 0.0, 0.375, + 0.375, 0.375, 0.3, 0.3, 0.0, 0.0, 0.3, 0.3, 0.0, 0.0, 0.3, 0.3, + 0.375, 0.375, 0.375, 0.0, 0.0, 0.3, 0.0, 0.3, 0.3, 0.375, 0.375, + 0.0, 0.375, -1.125, -0.9, -1.125, -0.9, -0.9, -1.125, -0.9, -1.125 + ]), rtol=1e-04, atol=1e-04).all() + + +def test_get_mesh_laplacian_of_irregular_triangular_prism(): + pos = ms.Tensor([ + [0.0, 0.0, 0.0], + [4.0, 0.0, 0.0], + [0.0, 0.0, -3.0], + [1.0, 5.0, -1.0], + [3.0, 5.0, -1.0], + [2.0, 5.0, -2.0], + ]) + + face = ms.Tensor([ + [0, 1, 2], + [3, 4, 5], + [0, 1, 4], + [0, 3, 4], + [1, 2, 5], + [1, 4, 5], + [2, 0, 3], + [2, 5, 3], + ]) + + edge_index, edge_weight = get_mesh_laplacian(pos, face.t(), + normalization='rw') + + assert edge_index.tolist() == [ + [ + 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, + 5, 5, 0, 1, 2, 3, 4, 5 + ], + [ + 1, 2, 3, 4, 0, 2, 4, 5, 0, 1, 3, 5, 0, 2, 4, 5, 0, 1, 3, 5, 1, 2, + 3, 4, 0, 1, 2, 3, 4, 5 + ], + ] + + assert mint.isclose( + edge_weight, + ms.Tensor([ + 0.09730332, 0.15039921, 0.05081503, 0.00000000, 0.08726977, + 0.03521059, 0.05363689, 0.00723919, 0.14497279, 0.03784235, + 0.01629947, 0.03438699, 0.08362866, 0.02782887, 0.24252312, + 0.40727590, 0.00000000, 0.08728313, 0.21507657, 0.38582093, + 0.01117009, 0.04936920, 0.34247482, 0.36583540, -0.29851755, + -0.18335645, -0.23350160, -0.76125660, -0.68818060, -0.76884955 + ]), rtol=1e-04, atol=1e-04).all() diff --git a/tests/graph/utils/test_negative_sampling.py b/tests/graph/utils/test_negative_sampling.py new file mode 100644 index 000000000..7a60018c6 --- /dev/null +++ b/tests/graph/utils/test_negative_sampling.py @@ -0,0 +1,169 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.utils import ( + batched_negative_sampling, + contains_self_loops, + is_undirected, + negative_sampling, + structured_negative_sampling, + structured_negative_sampling_feasible, + to_undirected, +) +from mindscience.sharker.utils.negative_sampling import ( + edge_index_to_vector, + vector_to_edge_index, +) + + +def is_negative(edge_index, neg_edge_index, size, bipartite): + adj = mint.zeros(size, dtype=ms.bool_) + neg_adj = mint.zeros(size, dtype=ms.bool_) + + adj[edge_index[0], edge_index[1]] = True + neg_adj[neg_edge_index[0], neg_edge_index[1]] = True + + if not bipartite: + arange = mint.arange(size[0]) + assert neg_adj[arange, arange].sum() == 0 + + return mint.logical_and(adj, neg_adj).sum() == 0 + + +def test_edge_index_to_vector_and_vice_versa(): + # Create a fully-connected graph: + N = 10 + row = mint.arange(N).view(-1, 1).tile((1, N)).view(-1) + col = mint.arange(N).view(1, -1).tile((N, 1)).view(-1) + edge_index = mint.stack(([row, col]), dim=0) + + idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=True) + assert population == N * N + assert idx.tolist() == list(range(population)) + edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=True) + assert is_undirected(edge_index2) + assert edge_index.tolist() == edge_index2.tolist() + + idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False) + assert population == N * N - N + assert idx.tolist() == list(range(population)) + mask = edge_index[0] != edge_index[1] # Remove self-loops. + edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False) + assert is_undirected(edge_index2) + assert edge_index[:, mask].tolist() == edge_index2.tolist() + + idx, population = edge_index_to_vector(edge_index, (N, N), bipartite=False, + force_undirected=True) + assert population == (N * (N + 1)) / 2 - N + assert idx.tolist() == list(range(population)) + mask = edge_index[0] != edge_index[1] # Remove self-loops. + edge_index2 = vector_to_edge_index(idx, (N, N), bipartite=False, + force_undirected=True) + assert is_undirected(edge_index2) + assert edge_index[:, mask].tolist() == to_undirected(edge_index2).tolist() + + +def test_negative_sampling(): + edge_index = ms.Tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) + + neg_edge_index = negative_sampling(edge_index) + assert neg_edge_index.shape[1] == edge_index.shape[1] + assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) + + neg_edge_index = negative_sampling(edge_index, method='dense') + assert neg_edge_index.shape[1] == edge_index.shape[1] + assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) + + neg_edge_index = negative_sampling(edge_index, num_neg_samples=2) + assert neg_edge_index.shape[1] == 2 + assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) + + edge_index = to_undirected(edge_index) + neg_edge_index = negative_sampling(edge_index, force_undirected=True) + assert neg_edge_index.shape[1] == edge_index.shape[1] - 1 + assert is_undirected(neg_edge_index) + assert is_negative(edge_index, neg_edge_index, (4, 4), bipartite=False) + + +def test_bipartite_negative_sampling(): + edge_index = ms.Tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) + + neg_edge_index = negative_sampling(edge_index, num_nodes=(3, 4)) + assert neg_edge_index.shape[1] == edge_index.shape[1] + assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True) + + neg_edge_index = negative_sampling(edge_index, num_nodes=(3, 4), + num_neg_samples=2) + assert neg_edge_index.shape[1] == 2 + assert is_negative(edge_index, neg_edge_index, (3, 4), bipartite=True) + + +def test_batched_negative_sampling(): + edge_index = ms.Tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) + edge_index = mint.cat(([edge_index, edge_index + 4]), dim=1) + batch = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 1]) + + neg_edge_index = batched_negative_sampling(edge_index, batch) + assert neg_edge_index.shape[1] <= edge_index.shape[1] + + adj = mint.zeros([8, 8], dtype=ms.bool_) + adj[edge_index[0], edge_index[1]] = True + neg_adj = mint.zeros([8, 8], dtype=ms.bool_) + neg_adj[neg_edge_index[0], neg_edge_index[1]] = True + + assert mint.logical_and(adj, neg_adj).sum() == 0 + assert mint.logical_or(adj, neg_adj).sum() == edge_index.shape[1] + neg_edge_index.shape[1] + + assert neg_adj[:4, 4:].sum() == 0 + assert neg_adj[4:, :4].sum() == 0 + + +def test_bipartite_batched_negative_sampling(): + edge_index1 = ms.Tensor([[0, 0, 1, 1], [0, 1, 2, 3]]) + edge_index2 = edge_index1 + ms.Tensor([[2], [4]]) + edge_index3 = edge_index2 + ms.Tensor([[2], [4]]) + edge_index = mint.cat(([edge_index1, edge_index2, edge_index3]), dim=1) + src_batch = ms.Tensor([0, 0, 1, 1, 2, 2]) + dst_batch = ms.Tensor([0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2]) + + neg_edge_index = batched_negative_sampling(edge_index, + (src_batch, dst_batch)) + assert neg_edge_index.shape[1] <= edge_index.shape[1] + + adj = mint.zeros([6, 12], dtype=ms.bool_) + adj[edge_index[0], edge_index[1]] = True + neg_adj = mint.zeros([6, 12], dtype=ms.bool_) + neg_adj[neg_edge_index[0], neg_edge_index[1]] = True + + assert mint.logical_and(adj, neg_adj).sum() == 0 + assert mint.logical_or(adj, neg_adj).sum() == edge_index.shape[1] + neg_edge_index.shape[1] + + +def test_structured_negative_sampling(): + edge_index = ms.Tensor([[0, 0, 1, 2], [0, 1, 2, 3]]) + + i, j, k = structured_negative_sampling(edge_index) + assert i.shape[0] == edge_index.shape[1] + assert j.shape[0] == edge_index.shape[1] + assert k.shape[0] == edge_index.shape[1] + + adj = mint.zeros([4, 4], dtype=ms.bool_) + adj[i, j] = 1 + + neg_adj = mint.zeros([4, 4], dtype=ms.bool_) + neg_adj[i, k] = 1 + assert mint.logical_and(adj, neg_adj).sum() == 0 + + # Test with no self-loops: + edge_index = ms.Tensor([[0, 0, 1, 1, 2], [1, 2, 0, 2, 1]]) + i, j, k = structured_negative_sampling(edge_index, num_nodes=4, + contains_neg_self_loops=False) + neg_edge_index = ops.vstack([i, k]) + assert not contains_self_loops(neg_edge_index) + + +def test_structured_negative_sampling_feasible(): + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 2], + [1, 2, 0, 2, 0, 1, 1]]) + assert not structured_negative_sampling_feasible(edge_index, 3, False) + assert structured_negative_sampling_feasible(edge_index, 3, True) + assert structured_negative_sampling_feasible(edge_index, 4, False) diff --git a/tests/graph/utils/test_noise_scheduler.py b/tests/graph/utils/test_noise_scheduler.py new file mode 100644 index 000000000..23702dec0 --- /dev/null +++ b/tests/graph/utils/test_noise_scheduler.py @@ -0,0 +1,34 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.utils.noise_scheduler import ( + get_diffusion_beta_schedule, + get_smld_sigma_schedule, +) + + +def test_get_smld_sigma_schedule(): + expected = ms.Tensor([ + 1., 0.59948425, 0.35938137, 0.21544347, 0.12915497, 0.07742637, + 0.04641589, 0.02782559, 0.01668101, 0.01 + ]) + out = get_smld_sigma_schedule( + sigma_min=0.01, + sigma_max=1.0, + num_scales=10, + ) + assert mint.isclose(out, expected, rtol=1e-04, atol=1e-04, equal_nan=True).all() + + +@pytest.mark.parametrize( + 'schedule_type', + ['linear', 'quadratic', 'constant', 'sigmoid'], +) +def test_get_diffusion_beta_schedule(schedule_type): + out = get_diffusion_beta_schedule( + schedule_type, + beta_start=0.1, + beta_end=0.2, + num_diffusion_timesteps=10, + ) + assert out.shape == (10, ) diff --git a/tests/graph/utils/test_normalized_cut.py b/tests/graph/utils/test_normalized_cut.py new file mode 100644 index 000000000..2ac16d842 --- /dev/null +++ b/tests/graph/utils/test_normalized_cut.py @@ -0,0 +1,20 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.utils import normalized_cut + + +def test_normalized_cut(): + row = ms.Tensor([0, 1, 1, 1, 2, 2, 3, 3, 4, 4]) + col = ms.Tensor([1, 0, 2, 3, 1, 4, 1, 4, 2, 3]) + edge_attr = ms.Tensor( + [3.0, 3.0, 6.0, 3.0, 6.0, 1.0, 3.0, 2.0, 1.0, 2.0]) + expected = ms.Tensor([4.0, 4.0, 5.0, 2.5, 5.0, 1.0, 2.5, 2.0, 1.0, 2.0]) + + out = normalized_cut(mint.stack(([row, col]), dim=0), edge_attr) + assert mint.isclose(out, expected, equal_nan=True).all() + + if is_full_test(): + jit = ms.jit(normalized_cut) + out = jit(mint.stack(([row, col]), dim=0), edge_attr) + assert mint.isclose(out, expected).all() diff --git a/tests/graph/utils/test_num_nodes.py b/tests/graph/utils/test_num_nodes.py new file mode 100644 index 000000000..628a98fda --- /dev/null +++ b/tests/graph/utils/test_num_nodes.py @@ -0,0 +1,25 @@ +import mindspore as ms +from mindscience.sharker.utils.num_nodes import ( + maybe_num_nodes, + maybe_num_nodes_dict, +) + + +def test_maybe_num_nodes(): + edge_index = ms.Tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]) + + assert maybe_num_nodes(edge_index, 4) == 4 + assert maybe_num_nodes(edge_index) == 3 + +def test_maybe_num_nodes_dict(): + edge_index_dict = { + '1': ms.Tensor([[0, 0, 1, 1, 2, 2, 2], [1, 2, 0, 2, 0, 1, 1]]), + '2': ms.Tensor([[0, 0, 1, 3], [1, 2, 0, 4]]) + } + num_nodes_dict = {'2': 6} + + assert maybe_num_nodes_dict(edge_index_dict) == {'1': 3, '2': 5} + assert maybe_num_nodes_dict(edge_index_dict, num_nodes_dict) == { + '1': 3, + '2': 6, + } diff --git a/tests/graph/utils/test_ppr.py b/tests/graph/utils/test_ppr.py new file mode 100644 index 000000000..98e2e0372 --- /dev/null +++ b/tests/graph/utils/test_ppr.py @@ -0,0 +1,27 @@ +import pytest +import mindspore as ms +from mindscience.sharker.datasets import KarateClub +from mindscience.sharker.testing import withPackage +from mindscience.sharker.utils import get_ppr + + +@withPackage('numba') +@pytest.mark.parametrize('target', [None, ms.Tensor([0, 4, 5, 6])]) +def test_get_ppr(target): + data = KarateClub()[0] + + edge_index, edge_weight = get_ppr( + data.edge_index, + alpha=0.1, + eps=1e-5, + target=target, + ) + + assert edge_index.shape[0] == 2 + assert edge_index.shape[1] == edge_weight.numel() + + min_row = 0 if target is None else target.min() + max_row = data.num_nodes - 1 if target is None else target.max() + assert edge_index[0].min() == min_row and edge_index[0].max() == max_row + assert edge_index[1].min() >= 0 and edge_index[1].max() < data.num_nodes + assert edge_weight.min() >= 0.0 and edge_weight.max() <= 1.0 diff --git a/tests/graph/utils/test_random.py b/tests/graph/utils/test_random.py new file mode 100644 index 000000000..00eaf0826 --- /dev/null +++ b/tests/graph/utils/test_random.py @@ -0,0 +1,22 @@ +import mindspore as ms +from mindscience.sharker import seed_everything +from mindscience.sharker.utils import ( + barabasi_albert_graph, + erdos_renyi_graph +) + + +def test_erdos_renyi_graph(): + seed_everything(1023) + edge_index = erdos_renyi_graph(5, 0.2, directed=False) + assert edge_index.tolist() == [[3, 4], [4, 3]] + + seed_everything(1023) + edge_index = erdos_renyi_graph(5, 0.5, directed=True) + assert edge_index.tolist() == [[0, 1, 1, 1, 2, 2, 3, 3, 3, 4, 4], [2, 2, 3, 4, 1, 4, 0, 1, 2, 0, 2]] + + +def test_barabasi_albert_graph(): + seed_everything(12345) + edge_index = barabasi_albert_graph(num_nodes=8, num_edges=3) + assert edge_index.shape == (2, 30) diff --git a/tests/graph/utils/test_repeat.py b/tests/graph/utils/test_repeat.py new file mode 100644 index 000000000..b0887d716 --- /dev/null +++ b/tests/graph/utils/test_repeat.py @@ -0,0 +1,9 @@ +from mindscience.sharker.utils.repeat import repeat + + +def test_repeat(): + assert repeat(None, length=4) is None + assert repeat(4, length=4) == [4, 4, 4, 4] + assert repeat([2, 3, 4], length=4) == [2, 3, 4, 4] + assert repeat([1, 2, 3, 4], length=4) == [1, 2, 3, 4] + assert repeat([1, 2, 3, 4, 5], length=4) == [1, 2, 3, 4] diff --git a/tests/graph/utils/test_scatter.py b/tests/graph/utils/test_scatter.py new file mode 100644 index 000000000..2182c1677 --- /dev/null +++ b/tests/graph/utils/test_scatter.py @@ -0,0 +1,126 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker import seed_everything +from mindscience.sharker.utils import group_argsort, group_cat, scatter +from mindscience.sharker.utils import ptr2index + + +def test_scatter_validate(): + src = ops.randn(100, 32) + index = ops.randint(0, 10, (100, ), dtype=ms.int64) + + with pytest.raises(ValueError, match="must lay between 0 and 1"): + scatter(src, index, dim=2) + + with pytest.raises(ValueError, match="invalid `reduce` argument 'std'"): + scatter(src, index, reduce='std') + + +@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'mul', 'min', 'amin', 'max', 'amax']) +def test_scatter(reduce): + seed_everything(1) + + + src = ops.randn(20, 16) + index = ms.Tensor([1 ,1 ,1, 1, 1, 2, 2 ,2 ,2 ,2 ,3, 3, 3, 3, 3 ,4, 4 ,4, 4 ,4]) + + out1 = scatter(src, index, dim=0, reduce=reduce) + out2 = scatter(src.T, index, dim=1, reduce=reduce).T + + if reduce == 'mul': + expected = mint.prod(src.view(4, 5, -1), 1) + elif reduce == 'add': + expected = mint.sum(src.view(4, 5, -1), 1) + elif reduce == 'amax': + expected = mint.argmax(src.view(4, 5, -1), 1) + expected[1] += 5 + expected[2] += 10 + expected[3] += 15 + elif reduce == 'amin': + expected = ops.argmin(src.view(4, 5, -1), 1) + expected[1] += 5 + expected[2] += 10 + expected[3] += 15 + else: + expected = getattr(ops, reduce)(src.view(4, 5, -1), 1) + expected = expected[0] if isinstance(expected, tuple) else expected + + assert out1.shape == (5, 16) + assert (out1[:1] == (20 if reduce in ['amin', 'amax'] else 0)).all() + assert mint.isclose(out1[1:], expected, atol=1e-3).all() + assert mint.isclose(out1, out2, atol=1e-3).all() + + # jit = ms.jit(scatter) + # out3 = jit(src, index, dim=0, reduce=reduce) + # assert out3.shape == (8, 8) + # assert ops.isclose(out1, out3, atol=1e-6).all() + + src = mint.randn(2, 4, 8) + index = mint.randint(0, 8, (4, )) + out1 = scatter(src, index, dim=1, reduce=reduce) + assert out1.shape[0] == 2 and out1.shape[2] == 8 + + +@pytest.mark.parametrize('reduce', ['sum', 'add', 'mean', 'min', 'max']) +def test_scatter_gradient(reduce): + src = ops.randn([8, 100, 8]) + index = ops.randint(0, 8, (100, )) + grad_fn = ms.value_and_grad(scatter, grad_position=0, weights=None, has_aux=False) + value, grad = grad_fn(src, index, dim=1, reduce=reduce) + assert value is not None + assert grad is not None + + +def test_scatter_any(): + src = ops.randn(6, 4) + index = ms.Tensor([0, 0, 1, 1, 2, 2]) + + out = scatter(src, index, dim=0, reduce='any') + + for i in range(3): + for j in range(4): + assert float(out[i, j]) in src[2 * i:2 * i + 2, j].tolist() + + +@pytest.mark.parametrize('num_groups', [4]) +@pytest.mark.parametrize('descending', [False, True]) +def test_group_argsort(num_groups, descending): + src = ops.randn(20) + index = ops.randint(0, num_groups, (20, )) + + out = group_argsort(src, index, 0, num_groups, descending=descending) + + expected = mint.zeros_like(index) + for i in range(num_groups): + mask = index == i + tmp = src[mask].argsort(descending=descending).long() + perm = mint.zeros_like(tmp) + perm[tmp] = mint.arange(tmp.numel()) + expected[mask] = perm + + assert ops.equal(out, expected).all() + + # # Not suppoert empty Tensor at the moment + # empty_tensor = ms.Tensor([]) + # out = group_argsort(empty_tensor, empty_tensor) + # assert out.numel() == 0 + + +def test_group_cat(): + x1 = ops.randn(4, 4) + x2 = ops.randn(2, 4) + index1 = ms.Tensor([0, 0, 1, 2]) + index2 = ms.Tensor([0, 2]) + + expected = mint.cat(([x1[:2], x2[:1], x1[2:4], x2[1:]]), dim=0) + + out, index = group_cat( + [x1, x2], + [index1, index2], + axis=0, + return_index=True, + ) + assert ops.equal(out, expected).all() + assert index.tolist() == [0, 0, 0, 1, 2, 2] + diff --git a/tests/graph/utils/test_segment.py b/tests/graph/utils/test_segment.py new file mode 100644 index 000000000..23f133b1d --- /dev/null +++ b/tests/graph/utils/test_segment.py @@ -0,0 +1,30 @@ +import pytest +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.utils._segment import segment + + +@pytest.mark.parametrize('reduce', ['sum', 'mean', 'mul', 'min', 'max', 'amax', 'amin']) +def test_segment(reduce): + src = ops.randn(20, 16) + ptr = ms.Tensor([0, 0, 5, 10, 15, 20]) + out = segment(src, ptr, dim=0, reduce=reduce) + out1 = segment(src.T, ptr, dim=1, reduce=reduce).T + + if reduce == 'mul': + expected = mint.prod(src.view(4, 5, -1), 1) + elif reduce == 'amax': + expected = mint.argmax(src.view(4, 5, -1), 1) + elif reduce == 'amin': + expected = ops.argmin(src.view(4, 5, -1), 1) + else: + expected = getattr(ops, reduce)(src.view(4, 5, -1), 1) + expected = expected[0] if isinstance(expected, tuple) else expected + + assert mint.isclose(out[:1], mint.zeros([1, 16], dtype=out.dtype)).all() + assert mint.isclose(out[1:], expected).all() + assert mint.isclose(out, out1, rtol=1e-04, atol=1e-4).all() + + # jit = ms.jit(segment) + # out1 = jit(src, ptr, reduce=reduce) + # assert ops.isclose(out, out1).all() diff --git a/tests/graph/utils/test_select.py b/tests/graph/utils/test_select.py new file mode 100644 index 000000000..d364680ec --- /dev/null +++ b/tests/graph/utils/test_select.py @@ -0,0 +1,23 @@ +import mindspore as ms +from mindspore import ops +from mindscience.sharker.utils import narrow, select + + +def test_select(): + src = ops.randn(5, 3) + index = ms.Tensor([0, 2, 4]) + mask = ms.Tensor([True, False, True, False, True]) + + out = select(src, index, axis=0) + assert ops.equal(out, src[index]).all() + assert ops.equal(out, select(src, mask, axis=0)).all() + assert ops.equal(out, ms.Tensor(select(src.tolist(), index, axis=0))).all() + assert ops.equal(out, ms.Tensor(select(src.tolist(), mask, axis=0))).all() + + +def test_narrow(): + src = ops.randn(5, 3) + + out = narrow(src, axis=0, start=2, length=2) + assert ops.equal(out, src[2:4]).all() + assert ops.equal(out, ms.Tensor(narrow(src.tolist(), 0, 2, 2))).all() diff --git a/tests/graph/utils/test_softmax.py b/tests/graph/utils/test_softmax.py new file mode 100644 index 000000000..26bbb08fb --- /dev/null +++ b/tests/graph/utils/test_softmax.py @@ -0,0 +1,82 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.profile import benchmark +from mindscience.sharker.utils import softmax + + +def test_softmax(): + src = ms.Tensor([1., 1., 1., 1.]) + index = ms.Tensor([0, 0, 1, 2]) + ptr = ms.Tensor([0, 2, 3, 4]) + + out = softmax(src, index) + assert out.tolist() == [0.5, 0.5, 1, 1] + assert softmax(src, ptr=ptr).tolist() == out.tolist() + + src = src.view(-1, 1) + out = softmax(src, index) + assert out.tolist() == [[0.5], [0.5], [1], [1]] + assert softmax(src, ptr=ptr).tolist() == out.tolist() + + +def test_softmax_grad(): + src_sparse = ops.rand(4, 8) + index = ms.Tensor([0, 0, 1, 1]) + src_dense = src_sparse.copy().view(2, 2, src_sparse.shape[-1]) + + grad_sparse = ops.value_and_grad(lambda src, index: softmax(src, index).mean()) + out_sparse, grad_sparse = grad_sparse(src_sparse, index) + grad_dense = ops.value_and_grad(lambda src: mint.nn.functional.softmax(src, dim=0).mean()) + out_dense, grad_dense = grad_dense(src_dense) + + assert mint.isclose(out_sparse, out_dense.view_as(out_sparse), rtol=1e-04, atol=1e-4).all() + assert mint.isclose(grad_sparse, grad_dense.view_as(src_sparse), rtol=1e-04, atol=1e-4).all() + + +def test_softmax_dim(): + index = ms.Tensor([0, 0, 0, 0]) + ptr = ms.Tensor([0, 4]) + + src = ops.randn(4) + assert mint.isclose(softmax(src, index, axis=0), mint.nn.functional.softmax(src, dim=0), rtol=1e-04, atol=1e-4).all() + assert mint.isclose(softmax(src, ptr=ptr, axis=0), mint.nn.functional.softmax(src, dim=0), rtol=1e-04, atol=1e-4).all() + + src = ops.randn(4, 16) + assert mint.isclose(softmax(src, index, axis=0), mint.nn.functional.softmax(src, dim=0), rtol=1e-04, atol=1e-4).all() + assert mint.isclose(softmax(src, ptr=ptr, axis=0), mint.nn.functional.softmax(src, dim=0), rtol=1e-04, atol=1e-4).all() + + src = ops.randn(4, 4) + assert mint.isclose(softmax(src, index, axis=-1), mint.nn.functional.softmax(src, dim=-1), rtol=1e-04, atol=1e-4).all() + assert mint.isclose(softmax(src, ptr=ptr, axis=-1), mint.nn.functional.softmax(src, dim=-1), rtol=1e-04, atol=1e-4).all() + + src = ops.randn(4, 4, 16) + assert mint.isclose(softmax(src, index, axis=1), mint.nn.functional.softmax(src, dim=1), rtol=1e-04, atol=1e-4).all() + assert mint.isclose(softmax(src, ptr=ptr, axis=1), mint.nn.functional.softmax(src, dim=1), rtol=1e-04, atol=1e-4).all() + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--backward', action='store_true') + args = parser.parse_args() + + num_nodes, num_edges = 10_000, 200_000 + x = ops.randn(num_edges, 64) + index = ops.randint(0, num_nodes, (num_edges, )) + + compiled_softmax = ms.jit(softmax) + + def dense_softmax(x, index): + x = x.view(num_nodes, -1, x.shape[-1]) + return x.softmax(axis=-1) + + benchmark( + funcs=[dense_softmax, softmax, compiled_softmax], + func_names=['Dense Softmax', 'Vanilla', 'Compiled'], + args=(x, index), + num_steps=500, + num_warmups=100, + backward=args.backward, + ) diff --git a/tests/graph/utils/test_sort_edge_index.py b/tests/graph/utils/test_sort_edge_index.py new file mode 100644 index 000000000..1e3ea1609 --- /dev/null +++ b/tests/graph/utils/test_sort_edge_index.py @@ -0,0 +1,72 @@ +from typing import List, Optional, Tuple +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, nn +from mindscience.sharker import typing +from mindscience.sharker.utils import sort_edge_index + + +def test_sort_edge_index(): + edge_index = ms.Tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) + edge_attr = ms.Tensor([[1], [2], [3], [4]]) + + out = sort_edge_index(edge_index) + assert out.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + + out = sort_edge_index((edge_index[0], edge_index[1])) + assert isinstance(out, tuple) + assert out[0].tolist() == [0, 1, 1, 2] + assert out[1].tolist() == [1, 0, 2, 1] + + out = sort_edge_index(edge_index, None) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1] is None + + out = sort_edge_index(edge_index, edge_attr) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1].tolist() == [[4], [3], [2], [1]] + + out = sort_edge_index(edge_index, [edge_attr, edge_attr.view(-1)]) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1][0].tolist() == [[4], [3], [2], [1]] + assert out[1][1].tolist() == [4, 3, 2, 1] + + +def test_sort_edge_index_jit(): + @ms.jit + def wrapper1(edge_index: Tensor) -> Tensor: + return sort_edge_index(edge_index) + + @ms.jit + def wrapper2( + edge_index: Tensor, + edge_attr: Optional[Tensor], + ) -> Tuple[Tensor, Optional[Tensor]]: + return sort_edge_index(edge_index, edge_attr) + + @ms.jit + def wrapper3( + edge_index: Tensor, + edge_attr: List[Tensor], + ) -> Tuple[Tensor, List[Tensor]]: + return sort_edge_index(edge_index, edge_attr) + + edge_index = ms.Tensor([[2, 1, 1, 0], [1, 2, 0, 1]]) + edge_attr = ms.Tensor([[1], [2], [3], [4]]) + + out = wrapper1(edge_index) + assert out.shape == edge_index.shape + + out = wrapper2(edge_index, None) + assert out[0].shape == edge_index.shape + assert out[1] is None + + out = wrapper2(edge_index, edge_attr) + assert out[0].shape == edge_index.shape + assert out[1].shape == edge_attr.shape + + out = wrapper3(edge_index, [edge_attr, edge_attr.view(-1)]) + assert out[0].shape == edge_index.shape + assert len(out[1]) == 2 + assert out[1][0].shape == edge_attr.shape + assert out[1][1].shape == edge_attr.view(-1).shape diff --git a/tests/graph/utils/test_sparse.py b/tests/graph/utils/test_sparse.py new file mode 100644 index 000000000..0061979ff --- /dev/null +++ b/tests/graph/utils/test_sparse.py @@ -0,0 +1,202 @@ +import pytest +import mindspore as ms +from mindspore import ops +from mindscience.sharker.profile import benchmark +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import ( + dense_to_sparse, + is_sparse_tensor, + to_edge_index, + to_coo, + to_csr, + to_sparse, +) +from mindscience.sharker.utils.sparse import cat +from mindscience.sharker.sparse import Layout + + +def test_dense_to_sparse(): + adj = ms.Tensor([ + [3.0, 1.0], + [2.0, 0.0], + ]) + edge_index, edge_attr = dense_to_sparse(adj) + assert edge_index.tolist() == [[0, 0, 1], [0, 1, 0]] + assert edge_attr.tolist() == [3, 1, 2] + + if is_full_test(): + jit = ms.jit(dense_to_sparse) + edge_index, edge_attr = jit(adj) + assert edge_index.tolist() == [[0, 0, 1], [0, 1, 0]] + assert edge_attr.tolist() == [3, 1, 2] + + adj = ms.Tensor([[ + [3.0, 1.0], + [2.0, 0.0], + ], [ + [0.0, 1.0], + [0.0, 2.0], + ]]) + edge_index, edge_attr = dense_to_sparse(adj) + assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]] + assert edge_attr.tolist() == [3, 1, 2, 1, 2] + + if is_full_test(): + jit = ms.jit(dense_to_sparse) + edge_index, edge_attr = jit(adj) + assert edge_index.tolist() == [[0, 0, 1, 2, 3], [0, 1, 0, 3, 3]] + assert edge_attr.tolist() == [3, 1, 2, 1, 2] + + adj = ms.Tensor([ + [ + [3.0, 1.0, 0.0], + [2.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + [ + [0.0, 1.0, 0.0], + [0.0, 2.0, 3.0], + [0.0, 5.0, 0.0], + ], + ]) + mask = ms.Tensor([[True, True, False], [True, True, True]]) + + edge_index, edge_attr = dense_to_sparse(adj, mask) + + assert edge_index.tolist() == [[0, 0, 1, 2, 3, 3, 4], + [0, 1, 0, 3, 3, 4, 3]] + assert edge_attr.tolist() == [3, 1, 2, 1, 2, 3, 5] + + +def test_dense_to_sparse_bipartite(): + edge_index, edge_attr = dense_to_sparse(ops.rand(2, 10, 5)) + assert edge_index[0].max() == 19 + assert edge_index[1].max() == 9 + + +def test_is_sparse(): + x = ops.randn([5, 5]) + + assert not is_sparse_tensor(x) + assert is_sparse_tensor(x.to_coo()) + assert is_sparse_tensor(x.to_csr()) + + +def test_to_coo(): + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3], + [1, 0, 2, 1, 3, 2], + ]) + edge_attr = ops.randn(edge_index.shape[1], 8) + + adj = to_coo(edge_index, is_coalesced=False) + # assert adj.is_coalesced() + assert adj.shape == (4, 4) + assert ops.isclose(adj.indices.T, edge_index).all() + + adj = to_coo(edge_index, is_coalesced=True) + # assert adj.is_coalesced() + assert adj.shape == (4, 4) + assert ops.isclose(adj.indices.T, edge_index).all() + + adj = to_coo(edge_index, shape=6) + assert adj.shape == (6, 6) + assert ops.isclose(adj.indices.T, edge_index).all() + + # adj = to_coo(edge_index, edge_attr) + # assert adj.shape == (4, 4, 8) + # assert ops.isclose(adj.indices.T, edge_index).all() + # assert ops.isclose(adj.values, edge_attr).all() + + if is_full_test(): + jit = ms.jit(to_coo) + adj = jit(edge_index, edge_attr) + assert adj.shape == (4, 4, 8) + assert ops.isclose(adj.indices.T, edge_index).all() + assert ops.isclose(adj.values, edge_attr).all() + + +def test_to_csr(): + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3], + [1, 0, 2, 1, 3, 2], + ]).int() + + adj = to_csr(edge_index) + assert adj.shape == (4, 4) + assert ops.isclose(adj.to_dense().to_coo().indices.T, edge_index).all() + + edge_weight = ops.randn(edge_index.shape[1]) + adj = to_csr(edge_index, edge_weight) + assert adj.shape == (4, 4) + coo = adj.to_dense().to_coo() # .coalesce() + assert ops.isclose(coo.indices.T, edge_index).all() + assert ops.isclose(coo.values, edge_weight).all() + + +def test_to_edge_index(): + adj = ms.Tensor([ + [0., 1., 0., 0.], + [1., 0., 1., 0.], + [0., 1., 0., 1.], + [0., 0., 1., 0.], + ]).to_csr() + + edge_index, edge_attr = to_edge_index(adj) + assert edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]] + assert edge_attr.tolist() == [1., 1., 1., 1., 1., 1.] + + # if is_full_test(): + # jit = ms.jit(to_edge_index) + # edge_index, edge_attr = jit(adj) + # assert edge_index.tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]] + # assert edge_attr.tolist() == [1., 1., 1., 1., 1., 1.] + + +@pytest.mark.parametrize( + 'layout', + [Layout.COO, Layout.CSR], +) +@pytest.mark.parametrize('axis', [0, 1, (0, 1)]) +def test_cat(layout, axis): + edge_index = ms.Tensor([[0, 1, 1, 2], [1, 0, 2, 1]]) + edge_weight = ops.rand(4) + adj = to_sparse(edge_index, edge_weight, layout=layout) + + out = cat([adj, adj], axis=axis) + edge_index, edge_weight = to_edge_index(out.to_dense().to_csr()) + + if axis == 0: + assert out.shape == (6, 3) + assert edge_index[0].tolist() == [0, 1, 1, 2, 3, 4, 4, 5] + assert edge_index[1].tolist() == [1, 0, 2, 1, 1, 0, 2, 1] + elif axis == 1: + assert out.shape == (3, 6) + assert edge_index[0].tolist() == [0, 0, 1, 1, 1, 1, 2, 2] + assert edge_index[1].tolist() == [1, 4, 0, 2, 3, 5, 1, 4] + else: + assert out.shape == (6, 6) + assert edge_index[0].tolist() == [0, 1, 1, 2, 3, 4, 4, 5] + assert edge_index[1].tolist() == [1, 0, 2, 1, 4, 3, 5, 4] + + +if __name__ == '__main__': + import argparse + + parser = argparse.ArgumentParser() + device = ms.get_context('device_target') + args = parser.parse_args() + + num_nodes, num_edges = 10_000, 200_000 + edge_index = ops.randint(0, num_nodes, size=(2, num_edges)) + + benchmark( + funcs=[ + SparseTensor.from_edge_index, to_coo, to_csr, + ], + func_names=['SparseTensor', 'To COO', 'To CSR'], + args=(edge_index, None, (num_nodes, num_nodes)), + num_steps=50 if args.device == 'cpu' else 500, + num_warmups=10 if args.device == 'cpu' else 100, + ) diff --git a/tests/graph/utils/test_spmm.py b/tests/graph/utils/test_spmm.py new file mode 100644 index 000000000..f4d5f7dc0 --- /dev/null +++ b/tests/graph/utils/test_spmm.py @@ -0,0 +1,144 @@ +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker import typing +from mindscience.sharker import EdgeIndex +from mindscience.sharker.profile import benchmark +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import spmm, to_coo +from mindscience.sharker.sparse import Layout + + +@pytest.mark.parametrize('reduce', ['sum', 'mean']) +def test_spmm_basic(reduce): + src = ops.randn(5, 4) + other = ops.randn(4, 8) + + out1 = (src @ other) / (src.shape[1] if reduce == 'mean' else 1) + out2 = spmm(src.to_csr(), other, reduce=reduce) + assert out1.shape == (5, 8) + assert ops.isclose(out1, out2, atol=1e-6).all() + if typing.WITH_SPARSE: + out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce) + assert ops.isclose(out2, out3, atol=1e-6).all() + + # Test `mean` reduction with isolated nodes: + src[0] = 0. + out1 = (src @ other) / (4. if reduce == 'mean' else 1.) + out2 = spmm(src.to_csr(), other, reduce=reduce) + assert out1.shape == (5, 8) + assert ops.isclose(out1, out2, atol=1e-6).all() + if typing.WITH_SPARSE: + out3 = spmm(SparseTensor.from_dense(src), other, reduce=reduce) + assert ops.isclose(out2, out3, atol=1e-6).all() + + +@pytest.mark.parametrize('reduce', ['min', 'max']) +def test_spmm_reduce(reduce): + src = ops.randn(5, 4) + other = ops.randn(4, 8) + + out1 = spmm(src.to_csr(), other, reduce) + assert out1.shape == (5, 8) + if typing.WITH_SPARSE: + out2 = spmm(SparseTensor.from_dense(src), other, reduce=reduce) + assert ops.isclose(out1, out2).all() + + +@pytest.mark.parametrize( + 'layout', [Layout.COO, Layout.CSR]) # , Layout.CSC +@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max']) +def test_spmm_layout(layout, reduce): + src = ops.randn(5, 4) + if layout == Layout.COO: + src = src.to_coo() + elif layout == Layout.CSR: + src = src.to_csr() + # else: + # assert layout == Layout.CSC + # src = src.to_sparse_csc() + other = ops.randn(4, 8) + + # if src.is_cuda and reduce in {'min', 'max'}: + # with pytest.raises(NotImplementedError, match="not yet supported"): + # spmm(src, other, reduce=reduce) + # elif layout != Layout.CSR: + # with pytest.warns(UserWarning, match="Converting sparse tensor"): + # spmm(src, other, reduce=reduce) + # else: + spmm(src, other, reduce=reduce) + + +# @pytest.mark.parametrize('reduce', ['sum', 'mean']) +# def test_spmm_jit(reduce): +# @ms.jit +# def jit_torch_sparse(src: SparseTensor, other: Tensor, +# reduce: str) -> Tensor: +# return spmm(src, other, reduce=reduce) + +# @ms.jit +# def jit_torch(src: Tensor, other: Tensor, reduce: str) -> Tensor: +# return spmm(src, other, reduce=reduce) + +# src = ops.randn(5, 4) +# other = ops.randn(4, 8) + +# out1 = src @ other +# out2 = jit_torch(src.to_csr(), other, reduce) +# assert out1.shape == (5, 8) +# if reduce == 'sum': +# assert ops.isclose(out1, out2, atol=1e-6).all() +# if typing.WITH_SPARSE: +# out3 = jit_torch_sparse(SparseTensor.from_dense(src), other, reduce) +# assert ops.isclose(out2, out3, atol=1e-6).all() + + +@pytest.mark.parametrize('reduce', ['sum', 'mean', 'min', 'max']) +def test_spmm_edge_index(reduce): + src = EdgeIndex( + [[0, 1, 1, 2], [1, 0, 2, 1]], + sparse_shape=(4, 3), + sort_order='row' + ) + other = ops.rand(3, 4) + out = spmm(src, other, reduce=reduce) + assert out.shape == (4, 4) + + out2 = spmm(src.to_coo(), other, reduce=reduce) + assert ops.isclose(out, out2).all() + + +# if __name__ == '__main__': +# import argparse + +# warnings.filterwarnings('ignore', ".*Sparse CSR tensor support.*") +# warnings.filterwarnings('ignore', ".*Converting sparse tensor to CSR.*") + +# parser = argparse.ArgumentParser() +# parser.add_argument('--backward', action='store_true') +# args = parser.parse_args() +# device = ms.get_context('device_target') +# num_nodes, num_edges = 10_000, 200_000 +# x = ops.randn(num_nodes, 64) +# edge_index = ops.randint(0, num_nodes, (2, num_edges)) + +# reductions = ['sum', 'mean'] +# if not x.is_cuda: +# reductions.extend(['min', 'max']) +# layouts = [Layout.COO, Layout.CSR, Layout.CSC] + +# for reduce, layout in itertools.product(reductions, layouts): +# print(f'Aggregator: {reduce}, Layout: {layout}') + +# adj = to_coo(edge_index, shape=num_nodes) +# adj = adj.to_sparse(layout=layout) + +# benchmark( +# funcs=[spmm], +# func_names=['spmm'], +# args=(adj, x, reduce), +# num_steps=50 if device == 'CPU' else 500, +# num_warmups=10 if device == 'CPU' else 100, +# backward=args.backward, +# ) diff --git a/tests/graph/utils/test_subgraph.py b/tests/graph/utils/test_subgraph.py new file mode 100644 index 000000000..631642483 --- /dev/null +++ b/tests/graph/utils/test_subgraph.py @@ -0,0 +1,122 @@ +import mindspore as ms +from mindspore import nn, ops, mint +from mindscience.sharker.nn.conv import GCNConv +from mindscience.sharker.testing import withPackage +from mindscience.sharker.utils import ( + bipartite_subgraph, + get_num_hops, + index_to_mask, + k_hop_subgraph, + subgraph, +) + + +def test_get_num_hops(): + class GNN(nn.Cell): + def __init__(self): + super().__init__() + self.conv1 = GCNConv(3, 16, normalize=False) + self.conv2 = GCNConv(16, 16, normalize=False) + self.lin = nn.Dense(16, 2) + + def construct(self, x, edge_index): + x = ops.relu(self.conv1(x, edge_index)) + x = self.conv2(x, edge_index) + return self.lin(x) + net = GNN() + assert get_num_hops(net) == 2 + + +def test_subgraph(): + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6], + [1, 0, 2, 1, 3, 2, 4, 3, 5, 4, 6, 5], + ]) + edge_attr = ms.Tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]) + + idx = ms.Tensor([3, 4, 5]) + mask = index_to_mask(idx, 7) + indices = idx.tolist() + + for subset in [idx, mask, indices]: + out = subgraph(subset, edge_index, edge_attr, return_edge_mask=True) + assert out[0].tolist() == [[3, 4, 4, 5], [4, 3, 5, 4]] + assert out[1].tolist() == [7.0, 8.0, 9.0, 10.0] + assert out[2].tolist() == [0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0] + + out = subgraph(subset, edge_index, edge_attr, relabel_nodes=True) + assert out[0].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + assert out[1].tolist() == [7, 8, 9, 10] + + +@withPackage('pandas') +def test_subgraph_large_index(): + subset = ms.Tensor([50_000_000]) + edge_index = ms.Tensor([[50_000_000], [50_000_000]]) + edge_index, _ = subgraph(subset, edge_index, relabel_nodes=True) + assert edge_index.tolist() == [[0], [0]] + + +def test_bipartite_subgraph(): + edge_index = ms.Tensor([[0, 5, 2, 3, 3, 4, 4, 3, 5, 5, 6], + [0, 0, 3, 2, 0, 0, 2, 1, 2, 3, 1]]) + edge_attr = ms.Tensor( + [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0]) + idx = (ms.Tensor([2, 3, 5]), ms.Tensor([2, 3])) + mask = (index_to_mask(idx[0], 7), index_to_mask(idx[1], 4)) + indices = (idx[0].tolist(), idx[1].tolist()) + mixed = (mask[0], idx[1]) + + for subset in [idx, mask, indices, mixed]: + out = bipartite_subgraph(subset, edge_index, edge_attr, + return_edge_mask=True) + assert out[0].tolist() == [[2, 3, 5, 5], [3, 2, 2, 3]] + assert out[1].tolist() == [3.0, 4.0, 9.0, 10.0] + assert out[2].int().tolist() == [0, 0, 1, 1, 0, 0, 0, 0, 1, 1, 0] + + out = bipartite_subgraph(subset, edge_index, edge_attr, + relabel_nodes=True) + assert out[0].tolist() == [[0, 1, 2, 2], [1, 0, 0, 1]] + assert out[1].tolist() == [3.0, 4.0, 9.0, 10.0] + + +@withPackage('pandas') +def test_bipartite_subgraph_large_index(): + subset = ms.Tensor([50_000_000]) + edge_index = ms.Tensor([[50_000_000], [50_000_000]]) + + edge_index, _ = bipartite_subgraph( + (subset, subset), + edge_index, + relabel_nodes=True, + ) + assert edge_index.tolist() == [[0], [0]] + + +def test_k_hop_subgraph(): + edge_index = ms.Tensor([ + [0, 1, 2, 3, 4, 5], + [2, 2, 4, 4, 6, 6], + ]) + + subset, edge_index, mapping, edge_mask = k_hop_subgraph( + 6, 2, edge_index, relabel_nodes=True) + assert subset.tolist() == [2, 3, 4, 5, 6] + assert edge_index.tolist() == [[0, 1, 2, 3], [2, 2, 4, 4]] + assert mapping.tolist() == [4] + assert edge_mask.tolist() == [False, False, True, True, True, True] + + edge_index = ms.Tensor([ + [1, 2, 4, 5], + [0, 1, 5, 6], + ]) + + subset, edge_index, mapping, edge_mask = k_hop_subgraph([0, 6], 2, + edge_index, + relabel_nodes=True) + + assert subset.tolist() == [0, 1, 2, 4, 5, 6] + assert edge_index.tolist() == [[1, 2, 3, 4], [0, 1, 4, 5]] + assert mapping.tolist() == [0, 5] + assert edge_mask.tolist() == [True, True, True, True] diff --git a/tests/graph/utils/test_to_dense_adj.py b/tests/graph/utils/test_to_dense_adj.py new file mode 100644 index 000000000..5cd65f08d --- /dev/null +++ b/tests/graph/utils/test_to_dense_adj.py @@ -0,0 +1,91 @@ +import mindspore as ms +from mindscience.sharker.testing import is_full_test +from mindscience.sharker.utils import to_dense_adj +from mindspore.ops import operations as P + + +def test_to_dense_adj(): + edge_index = ms.Tensor([ + [0, 0, 1, 2, 3, 4], + [0, 1, 0, 3, 4, 2], + ]) + batch = ms.Tensor([0, 0, 1, 1, 1]) + + adj = to_dense_adj(edge_index, batch) + assert adj.shape == (2, 3, 3) + assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]] + assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]] + + if is_full_test(): + jit = ms.jit(to_dense_adj) + adj = jit(edge_index, batch) + assert adj.shape == (2, 3, 3) + assert adj[0].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]] + assert adj[1].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]] + + adj = to_dense_adj(edge_index, batch, max_num_nodes=2) + assert adj.shape == (2, 2, 2) + assert adj[0].tolist() == [[1, 1], [1, 0]] + assert adj[1].tolist() == [[0, 1], [0, 0]] + + adj = to_dense_adj(edge_index, batch, max_num_nodes=5) + assert adj.shape == (2, 5, 5) + assert adj[0][:3, :3].tolist() == [[1, 1, 0], [1, 0, 0], [0, 0, 0]] + assert adj[1][:3, :3].tolist() == [[0, 1, 0], [0, 0, 1], [1, 0, 0]] + + edge_attr = ms.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0]) + adj = to_dense_adj(edge_index, batch, edge_attr) + assert adj.shape == (2, 3, 3) + assert adj[0].tolist() == [[1, 2, 0], [3, 0, 0], [0, 0, 0]] + assert adj[1].tolist() == [[0, 4, 0], [0, 0, 5], [6, 0, 0]] + + adj = to_dense_adj(edge_index, batch, edge_attr, max_num_nodes=5) + assert adj.shape == (2, 5, 5) + assert adj[0][:3, :3].tolist() == [[1, 2, 0], [3, 0, 0], [0, 0, 0]] + assert adj[1][:3, :3].tolist() == [[0, 4, 0], [0, 0, 5], [6, 0, 0]] + + edge_attr = edge_attr.view(-1, 1) + adj = to_dense_adj(edge_index, batch, edge_attr) + assert adj.shape == (2, 3, 3, 1) + + edge_attr = edge_attr.view(-1, 1) + adj = to_dense_adj(edge_index, batch, edge_attr, max_num_nodes=5) + assert adj.shape == (2, 5, 5, 1) + + adj = to_dense_adj(edge_index) + assert adj.shape == (1, 5, 5) + assert adj[0].nonzero().t().tolist() == edge_index.tolist() + + adj = to_dense_adj(edge_index, max_num_nodes=10) + assert adj.shape == (1, 10, 10) + assert adj[0].nonzero().t().tolist() == edge_index.tolist() + + adj = to_dense_adj(edge_index, batch, batch_size=4) + assert adj.shape == (4, 3, 3) + + +def test_to_dense_adj_with_duplicate_entries(): + edge_index = ms.Tensor([ + [0, 0, 0, 1, 2, 3, 3, 4], + [0, 0, 1, 0, 3, 4, 4, 2], + ]) + batch = ms.Tensor([0, 0, 1, 1, 1]) + + adj = to_dense_adj(edge_index, batch) + assert adj.shape == (2, 3, 3) + assert adj[0].tolist() == [[2, 1, 0], [1, 0, 0], [0, 0, 0]] + assert adj[1].tolist() == [[0, 1, 0], [0, 0, 2], [1, 0, 0]] + + edge_attr = ms.Tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]) + adj = to_dense_adj(edge_index, batch, edge_attr) + assert adj.shape == (2, 3, 3) + assert adj[0].tolist() == [ + [3.0, 3.0, 0.0], + [4.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ] + assert adj[1].tolist() == [ + [0.0, 5.0, 0.0], + [0.0, 0.0, 13.0], + [8.0, 0.0, 0.0], + ] diff --git a/tests/graph/utils/test_to_dense_batch.py b/tests/graph/utils/test_to_dense_batch.py new file mode 100644 index 000000000..bfb562f30 --- /dev/null +++ b/tests/graph/utils/test_to_dense_batch.py @@ -0,0 +1,90 @@ +from typing import Tuple +import pytest +import mindspore as ms +from mindspore import Tensor, ops + +from mindscience.sharker.experimental import set_experimental_mode +from mindscience.sharker.testing import onlyFullTest +from mindscience.sharker.utils import to_dense_batch + + +@pytest.mark.parametrize('fill', [70.0, ms.Tensor(49.0)]) +def test_to_dense_batch(fill): + x = ms.Tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + batch = ms.Tensor([0, 0, 1, 2, 2, 2]) + + item = fill.item() if isinstance(fill, Tensor) else fill + expected = ms.Tensor([ + [[1.0, 2.0], [3.0, 4.0], [item, item]], + [[5.0, 6.0], [item, item], [item, item]], + [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], + ]) + + out, mask = to_dense_batch(x, batch, fill_value=fill) + assert out.shape == (3, 3, 2) + assert ops.equal(out, expected).all() + assert mask.tolist() == [[1, 1, 0], [1, 0, 0], [1, 1, 1]] + + out, mask = to_dense_batch(x, batch, max_num_nodes=2, fill_value=fill) + assert out.shape == (3, 2, 2) + assert ops.equal(out, expected[:, :2]).all() + assert mask.tolist() == [[1, 1], [1, 0], [1, 1]] + + out, mask = to_dense_batch(x, batch, max_num_nodes=5, fill_value=fill) + assert out.shape == (3, 5, 2) + assert ops.equal(out[:, :3], expected).all() + assert mask.tolist() == [[1, 1, 0, 0, 0], [1, 0, 0, 0, 0], [1, 1, 1, 0, 0]] + + out, mask = to_dense_batch(x, fill_value=fill) + assert out.shape == (1, 6, 2) + assert ops.equal(out[0], x).all() + assert mask.tolist() == [[1, 1, 1, 1, 1, 1]] + + out, mask = to_dense_batch(x, max_num_nodes=2, fill_value=fill) + assert out.shape == (1, 2, 2) + assert ops.equal(out[0], x[:2]).all() + assert mask.tolist() == [[1, 1]] + + out, mask = to_dense_batch(x, max_num_nodes=10, fill_value=fill) + assert out.shape == (1, 10, 2) + assert ops.equal(out[0, :6], x).all() + assert mask.tolist() == [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0]] + + out, mask = to_dense_batch(x, batch, batch_size=4, fill_value=fill) + assert out.shape == (4, 3, 2) + + +def test_to_dense_batch_disable_dynamic_shapes(): + x = ms.Tensor([ + [1.0, 2.0], + [3.0, 4.0], + [5.0, 6.0], + [7.0, 8.0], + [9.0, 10.0], + [11.0, 12.0], + ]) + batch = ms.Tensor([0, 0, 1, 2, 2, 2]) + + with set_experimental_mode(True, 'disable_dynamic_shapes'): + with pytest.raises(ValueError, match="'batch_size' needs to be set"): + out, mask = to_dense_batch(x, batch, max_num_nodes=6) + with pytest.raises(ValueError, match="'max_num_nodes' needs to be set"): + out, mask = to_dense_batch(x, batch, batch_size=4) + with pytest.raises(ValueError, match="'batch_size' needs to be set"): + out, mask = to_dense_batch(x) + + out, mask = to_dense_batch(x, batch_size=1, max_num_nodes=6) + assert out.shape == (1, 6, 2) + assert mask.shape == (1, 6) + + out, mask = to_dense_batch(x, batch, batch_size=3, max_num_nodes=10) + assert out.shape == (3, 10, 2) + assert mask.shape == (3, 10) + diff --git a/tests/graph/utils/test_tree_decomposition.py b/tests/graph/utils/test_tree_decomposition.py new file mode 100644 index 000000000..815dc9ff2 --- /dev/null +++ b/tests/graph/utils/test_tree_decomposition.py @@ -0,0 +1,16 @@ +import pytest +import mindspore as ms + +from mindscience.sharker.testing import withPackage +from mindscience.sharker.utils import tree_decomposition + + +@withPackage('rdkit') +@pytest.mark.parametrize('smiles', [ + r'F/C=C/F', + r'C/C(=C\C(=O)c1ccc(C)o1)Nc1ccc2c(c1)OCO2', +]) +def test_tree_decomposition(smiles): + from rdkit import Chem + mol = Chem.MolFromSmiles(smiles) + tree_decomposition(mol) # TODO Test output diff --git a/tests/graph/utils/test_trim_to_layer.py b/tests/graph/utils/test_trim_to_layer.py new file mode 100644 index 000000000..64ee7029d --- /dev/null +++ b/tests/graph/utils/test_trim_to_layer.py @@ -0,0 +1,154 @@ +from typing import List, Optional + +import mindspore as ms +from mindspore import Tensor, ops, nn, mint + +from mindscience.sharker import typing +from mindscience.sharker.nn.conv import GraphConv +from mindscience.sharker.typing import SparseTensor +from mindscience.sharker.utils import trim_to_layer + + +def test_trim_to_layer_basic(): + x0 = mint.arange(4) + edge_index0 = ms.Tensor([[1, 2, 3], [0, 1, 2]]) + edge_weight0 = mint.arange(3) + + num_sampled_nodes_per_hop = [1, 1, 1] + num_sampled_edges_per_hop = [1, 1, 1] + + x1, edge_index1, edge_weight1 = trim_to_layer( + layer=0, + num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, + num_sampled_edges_per_hop=num_sampled_edges_per_hop, + x=x0, + edge_index=edge_index0, + edge_attr=edge_weight0, + ) + assert ops.equal(x1, mint.arange(4)).all() + assert edge_index1.tolist() == [[1, 2, 3], [0, 1, 2]] + assert ops.equal(edge_weight1, mint.arange(3)).all() + + if typing.WITH_SPARSE: + adj0 = SparseTensor.from_edge_index(edge_index0, edge_weight0, (4, 4)) + x1, adj_t1, _ = trim_to_layer( + layer=0, + num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, + num_sampled_edges_per_hop=num_sampled_edges_per_hop, + x=x0, + edge_index=adj0.t(), + edge_attr=edge_weight0, + ) + adj1 = adj_t1.t() + assert adj1.sizes() == [4, 4] + + row, col, value = adj1.coo() + assert ops.equal(x1, mint.arange(4)).all() + assert row.tolist() == [1, 2, 3] + assert col.tolist() == [0, 1, 2] + assert ops.equal(value, mint.arange(3)).all() + + x2, edge_index2, edge_weight2 = trim_to_layer( + layer=1, + num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, + num_sampled_edges_per_hop=num_sampled_edges_per_hop, + x=x1, + edge_index=edge_index1, + edge_attr=edge_weight1, + ) + assert ops.equal(x2, mint.arange(3)).all() + assert edge_index2.tolist() == [[1, 2], [0, 1]] + assert ops.equal(edge_weight2, mint.arange(2)).all() + + if typing.WITH_SPARSE: + adj1 = SparseTensor.from_edge_index(edge_index1, edge_weight1, (4, 4)) + x2, adj_t2, _ = trim_to_layer( + layer=1, + num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, + num_sampled_edges_per_hop=num_sampled_edges_per_hop, + x=x1, + edge_index=adj1.t(), + ) + adj2 = adj_t2.t() + assert adj2.sizes() == [3, 3] + + row, col, value = adj2.coo() + assert ops.equal(x2, mint.arange(3)).all() + assert row.tolist() == [1, 2] + assert col.tolist() == [0, 1] + assert ops.equal(value, mint.arange(2)).all() + + x3, edge_index3, edge_weight3 = trim_to_layer( + layer=2, + num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, + num_sampled_edges_per_hop=num_sampled_edges_per_hop, + x=x2, + edge_index=edge_index2, + edge_attr=edge_weight2, + ) + assert ops.equal(x3, mint.arange(2)).all() + assert edge_index3.tolist() == [[1], [0]] + assert ops.equal(edge_weight3, mint.arange(1)).all() + + if typing.WITH_SPARSE: + adj2 = SparseTensor.from_edge_index(edge_index2, edge_weight2, (3, 3)) + x3, adj_t3, _ = trim_to_layer( + layer=2, + num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, + num_sampled_edges_per_hop=num_sampled_edges_per_hop, + x=x2, + edge_index=adj2.t(), + ) + adj3 = adj_t3.t() + assert adj3.sizes() == [2, 2] + + row, col, value = adj3.coo() + assert ops.equal(x3, mint.arange(2)).all() + assert row.tolist() == [1] + assert col.tolist() == [0] + assert ops.equal(value, mint.arange(1)).all() + + +def test_trim_to_layer_hetero(): + x = {'v': mint.arange(4)} + edge_index = {('v', 'to', 'v'): ms.Tensor([[1, 2, 3], [0, 1, 2]])} + edge_weight = {('v', 'to', 'v'): mint.arange(3)} + + num_sampled_nodes_per_hop = {'v': [1, 1, 1, 1]} + num_sampled_edges_per_hop = {('v', 'to', 'v'): [1, 1, 1]} + + x, edge_index, edge_weight = trim_to_layer( + layer=1, + num_sampled_nodes_per_hop=num_sampled_nodes_per_hop, + num_sampled_edges_per_hop=num_sampled_edges_per_hop, + x=x, + edge_index=edge_index, + edge_attr=edge_weight, + ) + assert ops.equal(x['v'], mint.arange(3)).all() + assert edge_index['v', 'to', 'v'].tolist() == [[1, 2], [0, 1]] + assert ops.equal(edge_weight['v', 'to', 'v'], mint.arange(2)).all() + + +class GNN(nn.Cell): + def __init__(self, num_layers: int): + super().__init__() + + self.convs = nn.CellList( + GraphConv(16, 16) for _ in range(num_layers)) + + def construct( + self, + x: Tensor, + edge_index: Tensor, + edge_weight: Tensor, + num_sampled_nodes: Optional[List[int]] = None, + num_sampled_edges: Optional[List[int]] = None, + ) -> Tensor: + for i, conv in enumerate(self.convs): + if num_sampled_nodes is not None: + x, edge_index, edge_weight = trim_to_layer( + i, num_sampled_nodes, num_sampled_edges, x, edge_index, + edge_weight) + x = conv(x, edge_index, edge_weight) + return x diff --git a/tests/graph/utils/test_unbatch.py b/tests/graph/utils/test_unbatch.py new file mode 100644 index 000000000..2da9576af --- /dev/null +++ b/tests/graph/utils/test_unbatch.py @@ -0,0 +1,25 @@ +import mindspore as ms +from mindspore import ops, mint +from mindscience.sharker.utils import unbatch, unbatch_edge_index + + +def test_unbatch(): + src = mint.arange(10) + batch = ms.Tensor([0, 0, 0, 1, 1, 2, 2, 3, 4, 4]) + + out = unbatch(src, batch) + assert len(out) == 5 + for i in range(len(out)): + assert ops.equal(out[i], src[batch == i]).all() + + +def test_unbatch_edge_index(): + edge_index = ms.Tensor([ + [0, 1, 1, 2, 2, 3, 4, 5, 5, 6], + [1, 0, 2, 1, 3, 2, 5, 4, 6, 5], + ]) + batch = ms.Tensor([0, 0, 0, 0, 1, 1, 1]) + + edge_indices = unbatch_edge_index(edge_index, batch) + assert edge_indices[0].tolist() == [[0, 1, 1, 2, 2, 3], [1, 0, 2, 1, 3, 2]] + assert edge_indices[1].tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] diff --git a/tests/graph/utils/test_undirected.py b/tests/graph/utils/test_undirected.py new file mode 100644 index 000000000..67350f1d4 --- /dev/null +++ b/tests/graph/utils/test_undirected.py @@ -0,0 +1,40 @@ +import mindspore as ms +from mindspore import Tensor, ops, mint + +from mindscience.sharker.utils import is_undirected, to_undirected + + +def test_is_undirected(): + row = ms.Tensor([0, 1, 0]) + col = ms.Tensor([1, 0, 0]) + sym_weight = ms.Tensor([0, 0, 1]) + asym_weight = ms.Tensor([0, 1, 1]) + + assert is_undirected(mint.stack(([row, col]), dim=0)) + assert is_undirected(mint.stack(([row, col]), dim=0), sym_weight) + assert not is_undirected(mint.stack(([row, col]), dim=0), asym_weight) + + row = ms.Tensor([0, 1, 1]) + col = ms.Tensor([1, 0, 2]) + + assert not is_undirected(mint.stack(([row, col]), dim=0)) + + # @ms.jit + # def jit(edge_index: Tensor) -> bool: + # return is_undirected(edge_index) + + # assert not jit(ops.stack(([row, col]), axis=0)) + + +def test_to_undirected(): + row = ms.Tensor([0, 1, 1]) + col = ms.Tensor([1, 0, 2]) + + edge_index = to_undirected(mint.stack(([row, col]), dim=0)) + assert edge_index.tolist() == [[0, 1, 1, 2], [1, 0, 2, 1]] + + @ms.jit + def jit(edge_index: Tensor) -> Tensor: + return to_undirected(edge_index) + + assert ops.equal(jit(mint.stack(([row, col]), dim=0)), edge_index).all() -- Gitee