From 6870a2b1fafac49a791e8df4f00957f636d1aa63 Mon Sep 17 00:00:00 2001 From: Yuheng Wang Date: Thu, 10 Jul 2025 14:34:48 +0800 Subject: [PATCH] add alphafold3 to MindSPONGE research --- .../research/AlphaFold3/CMakeLists.txt | 95 + .../applications/research/AlphaFold3/LICENSE | 437 +++ .../research/AlphaFold3/README.md | 227 ++ .../research/AlphaFold3/README_EN.md | 226 ++ .../research/AlphaFold3/requirements.txt | 6 + .../research/AlphaFold3/run_alphafold.py | 680 ++++ .../AlphaFold3/run_alphafold_data_test.py | 299 ++ .../AlphaFold3/run_alphafold_test_v2.py | 381 ++ .../research/AlphaFold3/set_path.sh | 37 + .../AlphaFold3/src/alphafold3/__init__.py | 0 .../AlphaFold3/src/alphafold3/build_data.py | 44 + .../src/alphafold3/common/base_config.py | 151 + .../src/alphafold3/common/folding_input.py | 1115 ++++++ .../src/alphafold3/common/resources.py | 77 + .../src/alphafold3/common/testing/data.py | 70 + .../src/alphafold3/constants/atom_types.py | 262 ++ .../constants/chemical_component_sets.py | 38 + .../constants/chemical_components.py | 188 + .../constants/converters/ccd_pickle_gen.py | 53 + .../converters/chemical_component_sets_gen.py | 81 + .../src/alphafold3/constants/mmcif_names.py | 218 ++ .../alphafold3/constants/periodic_table.py | 399 +++ .../src/alphafold3/constants/residue_names.py | 421 +++ .../src/alphafold3/constants/side_chains.py | 112 + .../research/AlphaFold3/src/alphafold3/cpp.cc | 48 + .../alphafold3/data/cpp/msa_profile_pybind.cc | 79 + .../alphafold3/data/cpp/msa_profile_pybind.h | 25 + .../src/alphafold3/data/featurisation.py | 90 + .../AlphaFold3/src/alphafold3/data/msa.py | 344 ++ .../src/alphafold3/data/msa_config.py | 168 + .../src/alphafold3/data/msa_features.py | 203 ++ .../src/alphafold3/data/msa_identifiers.py | 86 + .../src/alphafold3/data/msa_store.py | 67 + .../AlphaFold3/src/alphafold3/data/parsers.py | 180 + .../src/alphafold3/data/pipeline.py | 538 +++ .../src/alphafold3/data/structure_stores.py | 101 + .../src/alphafold3/data/template_realign.py | 169 + .../src/alphafold3/data/template_store.py | 47 + .../src/alphafold3/data/templates.py | 969 +++++ .../src/alphafold3/data/tools/hmmalign.py | 143 + .../src/alphafold3/data/tools/hmmbuild.py | 145 + .../src/alphafold3/data/tools/hmmsearch.py | 150 + .../src/alphafold3/data/tools/jackhmmer.py | 135 + .../src/alphafold3/data/tools/msa_tool.py | 31 + .../src/alphafold3/data/tools/nhmmer.py | 167 + .../src/alphafold3/data/tools/rdkit_utils.py | 520 +++ .../alphafold3/data/tools/subprocess_utils.py | 107 + .../model/atom_layout/atom_layout.py | 1193 +++++++ .../src/alphafold3/model/base_config.py | 153 + .../alphafold3/model/components/base_model.py | 52 + .../model/components/base_modules.py | 146 + .../alphafold3/model/components/mapping.py | 356 ++ .../src/alphafold3/model/components/utils.py | 63 + .../src/alphafold3/model/confidence_types.py | 306 ++ .../src/alphafold3/model/confidences.py | 664 ++++ .../AlphaFold3/src/alphafold3/model/data3.py | 127 + .../src/alphafold3/model/data_constants.py | 27 + .../model/diffusion/atom_cross_attention.py | 490 +++ .../model/diffusion/confidence_head.py | 293 ++ .../model/diffusion/diffusion_head.py | 326 ++ .../model/diffusion/diffusion_transformer.py | 496 +++ .../model/diffusion/distogram_head.py | 85 + .../model/diffusion/featurization.py | 214 ++ .../alphafold3/model/diffusion/load_ckpt.py | 577 +++ .../model/diffusion/load_ckpt.py.bak | 318 ++ .../src/alphafold3/model/diffusion/model.py | 759 ++++ .../src/alphafold3/model/diffusion/modules.py | 567 +++ 
.../model/diffusion/random/bias.npy | Bin 0 -> 1152 bytes .../model/diffusion/random/weight.npy | Bin 0 -> 1152 bytes .../model/diffusion/template_modules.py | 347 ++ .../alphafold3/model/diffusion/triangle.py | 258 ++ .../src/alphafold3/model/feat_batch.py | 180 + .../src/alphafold3/model/features.py | 2101 +++++++++++ .../src/alphafold3/model/load_batch.py | 22 + .../src/alphafold3/model/merging_features.py | 92 + .../src/alphafold3/model/mkdssp_pybind.cc | 63 + .../src/alphafold3/model/mkdssp_pybind.h | 26 + .../src/alphafold3/model/mmcif_metadata.py | 199 ++ .../src/alphafold3/model/model_config.py | 32 + .../src/alphafold3/model/msa_pairing.py | 316 ++ .../AlphaFold3/src/alphafold3/model/params.py | 218 ++ .../model/pipeline/inter_chain_bonds.py | 347 ++ .../src/alphafold3/model/pipeline/pipeline.py | 446 +++ .../model/pipeline/structure_cleaning.py | 371 ++ .../src/alphafold3/model/post_processing.py | 114 + .../model/protein_data_processing.py | 128 + .../src/alphafold3/model/scoring/alignment.py | 146 + .../model/scoring/covalent_bond_cleaning.py | 265 ++ .../src/alphafold3/model/scoring/scoring.py | 67 + .../src/alphafold3/parsers/cpp/cif_dict.pyi | 125 + .../alphafold3/parsers/cpp/cif_dict_lib.cc | 648 ++++ .../src/alphafold3/parsers/cpp/cif_dict_lib.h | 149 + .../alphafold3/parsers/cpp/cif_dict_pybind.cc | 652 ++++ .../alphafold3/parsers/cpp/cif_dict_pybind.h | 24 + .../alphafold3/parsers/cpp/fasta_iterator.pyi | 22 + .../parsers/cpp/fasta_iterator_lib.cc | 121 + .../parsers/cpp/fasta_iterator_lib.h | 94 + .../parsers/cpp/fasta_iterator_pybind.cc | 127 + .../parsers/cpp/fasta_iterator_pybind.h | 24 + .../alphafold3/parsers/cpp/msa_conversion.pyi | 26 + .../parsers/cpp/msa_conversion_pybind.cc | 162 + .../parsers/cpp/msa_conversion_pybind.h | 24 + .../src/alphafold3/structure/__init__.py | 46 + .../src/alphafold3/structure/bioassemblies.py | 333 ++ .../src/alphafold3/structure/bonds.py | 237 ++ .../structure/chemical_components.py | 286 ++ .../alphafold3/structure/cpp/aggregation.pyi | 13 + .../structure/cpp/aggregation_pybind.cc | 54 + .../structure/cpp/aggregation_pybind.h | 24 + .../alphafold3/structure/cpp/membership.pyi | 18 + .../structure/cpp/membership_pybind.cc | 82 + .../structure/cpp/membership_pybind.h | 24 + .../alphafold3/structure/cpp/mmcif_altlocs.cc | 249 ++ .../alphafold3/structure/cpp/mmcif_altlocs.h | 51 + .../structure/cpp/mmcif_atom_site.pyi | 23 + .../structure/cpp/mmcif_atom_site_pybind.cc | 83 + .../structure/cpp/mmcif_atom_site_pybind.h | 24 + .../alphafold3/structure/cpp/mmcif_layout.h | 146 + .../alphafold3/structure/cpp/mmcif_layout.pyi | 26 + .../structure/cpp/mmcif_layout_lib.cc | 213 ++ .../structure/cpp/mmcif_layout_pybind.cc | 49 + .../structure/cpp/mmcif_layout_pybind.h | 24 + .../structure/cpp/mmcif_struct_conn.h | 34 + .../structure/cpp/mmcif_struct_conn.pyi | 13 + .../structure/cpp/mmcif_struct_conn_lib.cc | 380 ++ .../structure/cpp/mmcif_struct_conn_pybind.cc | 68 + .../structure/cpp/mmcif_struct_conn_pybind.h | 24 + .../alphafold3/structure/cpp/mmcif_utils.pyi | 71 + .../structure/cpp/mmcif_utils_pybind.cc | 787 ++++ .../structure/cpp/mmcif_utils_pybind.h | 24 + .../alphafold3/structure/cpp/string_array.pyi | 50 + .../structure/cpp/string_array_pybind.cc | 329 ++ .../structure/cpp/string_array_pybind.h | 24 + .../src/alphafold3/structure/mmcif.py | 333 ++ .../src/alphafold3/structure/parsing.py | 1806 ++++++++++ .../src/alphafold3/structure/sterics.py | 142 + .../src/alphafold3/structure/structure.py | 3181 +++++++++++++++++ 
.../alphafold3/structure/structure_tables.py | 843 +++++ .../src/alphafold3/structure/table.py | 565 +++ .../src/alphafold3/structure/test_utils.py | 358 ++ .../alphafold3/utils/attention/attention.py | 78 + .../utils/attention/attention_base.py | 275 ++ .../attention/attention_call_arg_specs.py | 61 + .../utils/attention/ms_attention.py | 97 + .../src/alphafold3/utils/common/precision.py | 90 + .../gated_linear_unit/gated_linear_unit.py | 69 + .../gated_linear_unit_base.py | 84 + .../src/alphafold3/utils/geometry/__init__.py | 28 + .../utils/geometry/rigid_matrix_vector.py | 192 + .../utils/geometry/rotation_matrix.py | 257 ++ .../utils/geometry/struct_of_array.py | 291 ++ .../src/alphafold3/utils/geometry/utils.py | 149 + .../src/alphafold3/utils/geometry/vector.py | 258 ++ 153 files changed, 38243 insertions(+) create mode 100644 MindSPONGE/applications/research/AlphaFold3/CMakeLists.txt create mode 100644 MindSPONGE/applications/research/AlphaFold3/LICENSE create mode 100644 MindSPONGE/applications/research/AlphaFold3/README.md create mode 100644 MindSPONGE/applications/research/AlphaFold3/README_EN.md create mode 100644 MindSPONGE/applications/research/AlphaFold3/requirements.txt create mode 100644 MindSPONGE/applications/research/AlphaFold3/run_alphafold.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/run_alphafold_data_test.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/run_alphafold_test_v2.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/set_path.sh create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/__init__.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/build_data.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/base_config.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/atom_types.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_component_sets.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/ccd_pickle_gen.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/chemical_component_sets_gen.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/mmcif_names.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/periodic_table.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/residue_names.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/side_chains.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/cpp.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/featurisation.py create mode 100644 
MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_config.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_features.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_store.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/parsers.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/structure_stores.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_realign.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_store.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/templates.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmalign.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmbuild.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/jackhmmer.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/msa_tool.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/nhmmer.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/rdkit_utils.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/subprocess_utils.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/atom_layout/atom_layout.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/base_config.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_model.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/utils.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidence_types.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidences.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data3.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data_constants.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/confidence_head.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_transformer.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/distogram_head.py create mode 100644 
MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py.bak create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/bias.npy create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/weight.npy create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/triangle.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/feat_batch.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/features.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/load_batch.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/merging_features.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/model_config.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/msa_pairing.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/params.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/inter_chain_bonds.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/pipeline.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/structure_cleaning.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/protein_data_processing.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/alignment.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/covalent_bond_cleaning.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator.pyi create mode 100644 
MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/__init__.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bioassemblies.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bonds.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/chemical_components.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.h 
create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array.pyi create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.h create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/sterics.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure_tables.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/table.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/test_utils.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_base.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_call_arg_specs.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/common/precision.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rigid_matrix_vector.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rotation_matrix.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/utils.py create mode 100644 MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py diff --git a/MindSPONGE/applications/research/AlphaFold3/CMakeLists.txt b/MindSPONGE/applications/research/AlphaFold3/CMakeLists.txt new file mode 100644 index 000000000..81162722e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/CMakeLists.txt @@ -0,0 +1,95 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +cmake_minimum_required(VERSION 3.28) +project( + "${SKBUILD_PROJECT_NAME}" + LANGUAGES CXX + VERSION "${SKBUILD_PROJECT_VERSION}") + +include(FetchContent) +set(CMAKE_CXX_STANDARD 20) +set(CMAKE_CXX_STANDARD_REQUIRED ON) +set(CMAKE_POSITION_INDEPENDENT_CODE TRUE) +set(ABSL_PROPAGATE_CXX_STD ON) + +# Remove support for scan deps, which is only useful when using C++ modules. +unset(CMAKE_CXX_SCANDEP_SOURCE) + +FetchContent_Declare( + abseil-cpp + GIT_REPOSITORY https://github.com/abseil/abseil-cpp + GIT_TAG d7aaad83b488fd62bd51c81ecf16cd938532cc0a # 20240116.2 + EXCLUDE_FROM_ALL) + +FetchContent_Declare( + pybind11 + GIT_REPOSITORY https://github.com/pybind/pybind11 + GIT_TAG 2e0815278cb899b20870a67ca8205996ef47e70f # v2.12.0 + EXCLUDE_FROM_ALL) + +FetchContent_Declare( + pybind11_abseil + GIT_REPOSITORY https://github.com/pybind/pybind11_abseil + GIT_TAG bddf30141f9fec8e577f515313caec45f559d319 # HEAD @ 2024-08-07 + EXCLUDE_FROM_ALL) + +FetchContent_Declare( + cifpp + GIT_REPOSITORY https://github.com/pdb-redo/libcifpp + GIT_TAG ac98531a2fc8daf21131faa0c3d73766efa46180 # v7.0.3 + # Don't `EXCLUDE_FROM_ALL` as necessary for build_data. +) + +FetchContent_Declare( + dssp + GIT_REPOSITORY https://github.com/PDB-REDO/dssp + GIT_TAG 57560472b4260dc41f457706bc45fc6ef0bc0f10 # v4.4.7 + EXCLUDE_FROM_ALL) + +FetchContent_MakeAvailable(pybind11 abseil-cpp pybind11_abseil cifpp dssp) + +find_package( + Python3 + COMPONENTS Interpreter Development NumPy + REQUIRED) + +include_directories(${PYTHON_INCLUDE_DIRS}) +include_directories(src/) + +file(GLOB_RECURSE cpp_srcs src/alphafold3/*.cc) +list(FILTER cpp_srcs EXCLUDE REGEX ".*\(_test\|_main\|_benchmark\).cc$") + +add_compile_definitions(NPY_NO_DEPRECATED_API=NPY_1_7_API_VERSION) + +pybind11_add_module(cpp ${cpp_srcs}) + +target_link_libraries( + cpp + PRIVATE absl::check + absl::flat_hash_map + absl::node_hash_map + absl::strings + absl::status + absl::statusor + absl::log + pybind11_abseil::absl_casters + Python3::NumPy + dssp::dssp + cifpp::cifpp) + +target_compile_definitions(cpp PRIVATE VERSION_INFO=${PROJECT_VERSION}) +install(TARGETS cpp LIBRARY DESTINATION alphafold3) +install( + FILES LICENSE + OUTPUT_TERMS_OF_USE.md + WEIGHTS_PROHIBITED_USE_POLICY.md + WEIGHTS_TERMS_OF_USE.md + DESTINATION alphafold3) diff --git a/MindSPONGE/applications/research/AlphaFold3/LICENSE b/MindSPONGE/applications/research/AlphaFold3/LICENSE new file mode 100644 index 000000000..bfef380bf --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/LICENSE @@ -0,0 +1,437 @@ +Attribution-NonCommercial-ShareAlike 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. 
+ +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International +Public License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial-ShareAlike 4.0 International Public License +("Public License"). To the extent this Public License may be +interpreted as a contract, You are granted the Licensed Rights in +consideration of Your acceptance of these terms and conditions, and the +Licensor grants You such rights in consideration of benefits the +Licensor receives from making the Licensed Material available under +these terms and conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. 
Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. BY-NC-SA Compatible License means a license listed at + creativecommons.org/compatiblelicenses, approved by Creative + Commons as essentially the equivalent of this Public License. + + d. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + + e. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + f. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + g. License Elements means the license attributes listed in the name + of a Creative Commons Public License. The License Elements of this + Public License are Attribution, NonCommercial, and ShareAlike. + + h. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + i. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + j. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + k. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + l. Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + m. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + n. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. 
Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. Additional offer from the Licensor -- Adapted Material. + Every recipient of Adapted Material from You + automatically receives an offer from the Licensor to + exercise the Licensed Rights in the Adapted Material + under the conditions of the Adapter's License You apply. + + c. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. 
+ +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + b. ShareAlike. + + In addition to the conditions in Section 3(a), if You Share + Adapted Material You produce, the following conditions also apply. + + 1. The Adapter's License You apply must be a Creative Commons + license with the same License Elements, this version or + later, or a BY-NC-SA Compatible License. + + 2. You must include the text of, or the URI or hyperlink to, the + Adapter's License You apply. You may satisfy this condition + in any reasonable manner based on the medium, means, and + context in which You Share Adapted Material. + + 3. You may not offer or impose any additional or different terms + or conditions on, or apply any Effective Technological + Measures to, Adapted Material that restrict exercise of the + rights granted under the Adapter's License You apply. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material, + including for purposes of Section 3(b); and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. + +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. 
UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. 
No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/AlphaFold3/README.md b/MindSPONGE/applications/research/AlphaFold3/README.md
new file mode 100644
index 000000000..ce10f7e71
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/README.md
@@ -0,0 +1,227 @@
+# AlphaFold3-MindSpore
+
+[**Ascend MindSpore implementation of AlphaFold3**] An implementation of the AlphaFold3 inference network architecture based on the MindSpore deep learning framework.
+
+> 📖 **Language**: [中文](README.md) | [English](README_EN.md)
+
+## 📑 Table of Contents
+
+- [Project Overview](#project-overview)
+- [Installation](#installation)
+- [Quick Start](#quick-start)
+- [License](#license)
+- [Acknowledgments](#acknowledgments)
+- [References](#references)
+
+## Project Overview
+
+**Project background**:
+AlphaFold3 is a groundbreaking biomolecular structure prediction model released by DeepMind in 2024 that predicts the three-dimensional structures of proteins, DNA, RNA, and other biological macromolecules. This project implements AlphaFold3 inference on the Ascend NPU with the MindSpore framework.
+
+The AlphaFold3 model architecture is shown below:
+![architecture](https://wiki.huawei.com/vision-file-storage/api/file/download/upload-v2/WIKI202503276382286/20006468/ef66568e62bf49a7a1885bbc903a2579.png?appKey=56f69231-0ee9-11ed-8d72-fa163ecf9d11)
+
+- 🧬 **Protein structure prediction**: a biomolecular structure prediction model based on the AlphaFold3 algorithm
+- 🚀 **MindSpore support**: the model is adapted to Ascend and MindSpore
+- 🔧 **Inference optimization**: optimized specifically for inference scenarios
+
+### Software Requirements
+
+- Python >= 3.11
+- MindSpore >= 2.5.0
+- CANN = 8.0.0
+- cmake >= 3.28.1
+
+## Installation
+
+### 1. Clone the repository
+
+```bash
+git clone https://gitee.com/mindspore/mindscience.git
+cd mindscience/MindSPONGE/applications/research/AlphaFold3
+```
+
+### 2. Install dependencies
+
+```bash
+pip install -r requirements.txt
+# `{PATH}` is the current directory
+export PYTHONPATH={PATH}/mindscience/MindSPONGE/src:$PYTHONPATH
+export PYTHONPATH={PATH}/mindscience/MindChemistry:$PYTHONPATH
+```
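+
+To verify the environment before moving on, a quick check can help (a minimal sketch; 2.5.0 is the version pinned in requirements.txt):
+
+```bash
+# Confirm MindSpore is importable and on the expected version.
+python -c "import mindspore; print(mindspore.__version__)"
+```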
+
+### 3. Install the software package
+
+Download the installation package (e.g. hmmer-3.4.tar.gz) from http://eddylab.org/software/hmmer/ and place it in the current directory.
+
+```bash
+mkdir /path/to/hmmer_build /path/to/hmmer && \
+mv ./hmmer-3.4.tar.gz /path/to/hmmer_build && \
+cd /path/to/hmmer_build && tar -zxf hmmer-3.4.tar.gz && rm hmmer-3.4.tar.gz && \
+cd /path/to/hmmer_build/hmmer-3.4 && ./configure --prefix=/path/to/hmmer && \
+make -j8 && make install && \
+cd /path/to/hmmer_build/hmmer-3.4/easel && make install && \
+rm -rf /path/to/hmmer_build
+export PATH=/path/to/hmmer/bin:$PATH
+which jackhmmer
+```
+
+If `which jackhmmer` prints `/path/to/hmmer/bin/jackhmmer`, the installation succeeded.
+
+### 4. Compile
+
+```bash
+cd {PATH}/mindscience/MindSPONGE/applications/research/AlphaFold3
+mkdir build
+cd build
+cmake ..
+make
+cp ./cpp.cpython-311-aarch64-linux-gnu.so ../src/alphafold3
+cd ..
+```
+
+The build produces {PATH}/mindscience/MindSPONGE/applications/research/AlphaFold3/build/cpp.cpython-311-aarch64-linux-gnu.so; copying it into {PATH}/mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3 completes the compilation step.
+
+### 5. Download the databases
+
+A small test database, [miniature_databases](https://github.com/google-deepmind/alphafold3/tree/main/src/alphafold3/test_data/miniature_databases), can be downloaded from DeepMind (it affects inference results; for testing only!).
+After downloading, place all the files in one folder and rename them as shown below:
+
+```txt
+miniature_databases
+  ├─ mmcif_files
+  ├─ bfd-first_non_consensus_sequences.fasta
+  ├─ mgy_clusters_2022_05.fa
+  ├─ pdb_seqres_2022_09_28.fasta
+  ├─ uniprot_all_2021_04.fa
+  ├─ uniref90_2022_05.fa
+  ├─ nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta
+  ├─ rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta
+  └─ rnacentral_active_seq_id_90_cov_80_linclust.fasta
+```
+
+To search the full databases instead, download them from the following links and place them in the same folder:
+
+- [mmcif](https://storage.googleapis.com/alphafold-databases/v3.0/pdb_2022_09_28_mmcif_files.tar.zst)
+- [BFD](https://storage.googleapis.com/alphafold-databases/v3.0/bfd-first_non_consensus_sequences.fasta.zst)
+- [MGnify](https://storage.googleapis.com/alphafold-databases/v3.0/mgy_clusters_2022_05.fa.zst)
+- [PDB seqres](https://storage.googleapis.com/alphafold-databases/v3.0/pdb_seqres_2022_09_28.fasta.zst)
+- [UniProt](https://storage.googleapis.com/alphafold-databases/v3.0/uniprot_all_2021_04.fa.zst)
+- [uniref90](https://storage.googleapis.com/alphafold-databases/v3.0/uniref90_2022_05.fa.zst)
+- [NT](https://storage.googleapis.com/alphafold-databases/v3.0/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta.zst)
+- [RFam](https://storage.googleapis.com/alphafold-databases/v3.0/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta.zst)
+- [RNACentral](https://storage.googleapis.com/alphafold-databases/v3.0/rnacentral_active_seq_id_90_cov_80_linclust.fasta.zst)
+
+Make sure there is enough disk space:
+
+| Database   | Compressed Size | Uncompressed Size |
+|------------|-----------------|-------------------|
+| mmcif      | 233G            | 233G              |
+| BFD        | 9.2G            | 16.9G             |
+| MGnify     | 64.5G           | 119G              |
+| PDB seqres | 25.3M           | 217M              |
+| UniProt    | 45.3G           | 101G              |
+| uniref90   | 30.9G           | 66.8G             |
+| NT         | 15.8G           | 75.4G             |
+| RFam       | 53.9M           | 217M              |
+| RNACentral | 3.27G           | 12.9G             |
+| total      | 402G            | 534G              |
+
+Decompress the downloaded data files:
+
+```bash
+cd /PATH/TO/YOUR/DATA_DIR
+tar --use-compress-program=unzstd -xf pdb_2022_09_28_mmcif_files.tar.zst
+zstd -d bfd-first_non_consensus_sequences.fasta.zst
+zstd -d mgy_clusters_2022_05.fa.zst
+zstd -d pdb_seqres_2022_09_28.fasta.zst
+zstd -d uniprot_all_2021_04.fa.zst
+zstd -d uniref90_2022_05.fa.zst
+zstd -d nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta.zst
+zstd -d rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta.zst
+zstd -d rnacentral_active_seq_id_90_cov_80_linclust.fasta.zst
+```
+
+Run the model with the following command:
+
+```bash
+bash set_path.sh
+python run_alphafold.py \
+    --json_path=input.json \
+    --output_dir=output \
+    --run_data_pipeline=true \
+    --run_inference=true \
+    --db_dir=/PATH/TO/DB_DIR
+```
+
+If all the databases are placed in /mindscience/MindSPONGE/applications/research/AlphaFold3/public_databases, the `--db_dir=/PATH/TO/DB_DIR` flag can be omitted.
+
+## Quick Start
+
+### Input format
+
+Example input JSON:
+
+```json
+{
+  "name": "5tgy",
+  "sequences": [
+    {
+      "protein": {
+        "id": "A",
+        "sequence": "SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN"
+      }
+    }
+  ],
+  "modelSeeds": [1],
+  "dialect": "alphafold3",
+  "version": 1
+}
+```
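+
+Before launching a run, the input JSON can be sanity-checked with a one-liner (a minimal sketch; it assumes the file is named input.json, as in the example above):
+
+```bash
+# The file must parse as JSON and declare the alphafold3 dialect.
+python -c "import json; d = json.load(open('input.json')); assert d['dialect'] == 'alphafold3'; print(d['name'])"
+```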
+
+### Running the pipeline
+
+Run the model with the following command:
+
+```bash
+python run_alphafold.py \
+    --json_path=input.json \
+    --output_dir=output \
+    --run_data_pipeline=true \
+    --run_inference=true
+```
+
+- **JSON input**: contains the sequence information for proteins, nucleic acids, and other molecules
+- **CIF output**: five standard structure files in CIF format, plus confidence information
+
+The following log indicates that inference finished normally:
+
+```txt
+=======write output to /PATH/TO/OUTPUT/DIR/name_of_your_input==========
+Done processing fold input name_of_your_input.
+Done processing 1 fold inputs.
+```
+
+### Parameters
+
+- `--json_path`: path to the input JSON file
+- `--output_dir`: output directory
+- `--model_dir`: path to the model checkpoint
+- `--bucket`: padded sequence length; defaults to the length of the input sequence
+- `--run_data_pipeline`: whether to run the data-processing pipeline
+- `--run_inference`: whether to run inference
+- `--db_dir`: path to the database directory
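+
+For illustration, a hypothetical invocation that spells out every flag above (the checkpoint path, bucket value, and database path are placeholders to adapt to your setup):
+
+```bash
+# Placeholder paths and bucket size; see the parameter list above.
+python run_alphafold.py \
+    --json_path=input.json \
+    --output_dir=output \
+    --model_dir=/PATH/TO/CKPT \
+    --bucket=256 \
+    --run_data_pipeline=true \
+    --run_inference=true \
+    --db_dir=/PATH/TO/DB_DIR
+```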
+
+## License
+
+See the [LICENSE](LICENSE) file for details.
+
+## Acknowledgments
+
+- The data, structure, common, and constants modules use the [DeepMind](https://deepmind.com/) implementation.
+- The model and utils modules are implemented based on [MindSpore](https://www.mindspore.cn/).
+
+## References
+
+- Abramson J, Adler J, Dunger J, et al. Accurate structure prediction of biomolecular interactions with AlphaFold 3[J]. Nature, 2024, 630(8016): 493-500.
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/AlphaFold3/README_EN.md b/MindSPONGE/applications/research/AlphaFold3/README_EN.md
new file mode 100644
index 000000000..f5af0051e
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/README_EN.md
@@ -0,0 +1,226 @@
+# AlphaFold3-MindSpore
+
+[**Ascend MindSpore Implementation of AlphaFold3**] An implementation of the AlphaFold3 inference network architecture based on the MindSpore deep learning framework.
+
+> 📖 **Language**: [中文](README.md) | [English](README_EN.md)
+
+## 📑 Table of Contents
+
+- [Project Overview](#project-overview)
+- [Installation](#installation)
+- [Quick Start](#quick-start)
+- [License](#license)
+- [Acknowledgments](#acknowledgments)
+- [Reference](#reference)
+
+## Project Overview
+
+**Project Background**:
+AlphaFold3 is a revolutionary biomolecular structure prediction model released by DeepMind in 2024, capable of predicting the three-dimensional structures of proteins, DNA, RNA, and other biological macromolecules. This project implements AlphaFold3's inference functionality on the Ascend NPU and the MindSpore framework, providing important biological computing tools for China's autonomous and controllable AI computing ecosystem.
+
+The model architecture is shown below:
+![architecture](https://wiki.huawei.com/vision-file-storage/api/file/download/upload-v2/WIKI202503276382286/20006468/ef66568e62bf49a7a1885bbc903a2579.png?appKey=56f69231-0ee9-11ed-8d72-fa163ecf9d11)
+
+- 🧬 **Protein Structure Prediction**: Biomolecular structure prediction model based on the AlphaFold3 algorithm
+- 🚀 **MindSpore Support**: Model adaptation based on Ascend and MindSpore
+- 🔧 **Inference Optimization**: Specifically optimized for inference scenarios
+
+### Software Requirements
+
+- Python >= 3.11
+- MindSpore >= 2.5.0
+- CANN = 8.0.0
+- cmake >= 3.28.1
+
+## Installation
+
+### 1. Clone Repository
+
+```bash
+git clone https://gitee.com/mindspore/mindscience.git
+cd mindscience/MindSPONGE/applications/research/AlphaFold3
+```
+
+### 2. Install Dependencies
+
+```bash
+pip install -r requirements.txt
+# `{PATH}` is the current directory.
+export PYTHONPATH={PATH}/mindscience/MindSPONGE/src:$PYTHONPATH
+export PYTHONPATH={PATH}/mindscience/MindChemistry:$PYTHONPATH
+```
+
+### 3. Install the Software Package
+
+Download the installation package (e.g. hmmer-3.4.tar.gz) from http://eddylab.org/software/hmmer/ and place it in the current directory.
+
+```bash
+mkdir /path/to/hmmer_build /path/to/hmmer && \
+mv ./hmmer-3.4.tar.gz /path/to/hmmer_build && \
+cd /path/to/hmmer_build && tar -zxf hmmer-3.4.tar.gz && rm hmmer-3.4.tar.gz && \
+cd /path/to/hmmer_build/hmmer-3.4 && ./configure --prefix=/path/to/hmmer && \
+make -j8 && make install && \
+cd /path/to/hmmer_build/hmmer-3.4/easel && make install && \
+rm -rf /path/to/hmmer_build
+export PATH=/path/to/hmmer/bin:$PATH
+which jackhmmer
+```
+
+If `which jackhmmer` prints `/path/to/hmmer/bin/jackhmmer`, the installation succeeded.
+
+### 4. Compile
+
+```bash
+cd {PATH}/mindscience/MindSPONGE/applications/research/AlphaFold3
+mkdir build
+cd build
+cmake ..
+make
+cp ./cpp.cpython-311-aarch64-linux-gnu.so ../src/alphafold3
+cd ..
+```
+
+The build produces {PATH}/mindscience/MindSPONGE/applications/research/AlphaFold3/build/cpp.cpython-311-aarch64-linux-gnu.so; copy it to {PATH}/mindscience/MindSPONGE/applications/research/AlphaFold3/src/alphafold3 to complete the compilation.
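+
+To confirm that the compiled extension can be found by Python, an optional check (a sketch; it assumes the AlphaFold3 src directory is on PYTHONPATH, e.g. after running set_path.sh):
+
+```bash
+# run_alphafold.py does `import alphafold3.cpp`, so this import should succeed
+# once the .so has been copied into src/alphafold3.
+python -c "import alphafold3.cpp; print('alphafold3.cpp loaded')"
+```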
+### Running Pipeline
+
+AlphaFold3 can be run with the following command.
+
+```bash
+source set_path.sh
+python run_alphafold.py \
+  --json_path=input.json \
+  --output_dir=output \
+  --run_data_pipeline=true \
+  --run_inference=true
+```
+
+- **JSON Input**: Contains the sequence information of proteins, nucleic acids, etc.
+- **CIF Output**: Five standard structure files (one per diffusion sample) plus confidence information.
+
+When you see the following log, inference has finished correctly:
+
+```text
+=======write output to /PATH/TO/OUTPUT/DIR/name_of_your_input==========
+Done processing fold input name_of_your_input.
+Done processing 1 fold inputs.
+```
+
+### Parameter Introduction
+
+- `--json_path`: Path to the input JSON file
+- `--output_dir`: Output directory
+- `--model_dir`: Path to the model checkpoint
+- `--buckets`: Token-size buckets used to pad inputs and cache compilations
+- `--run_data_pipeline`: Whether to run the data pipeline
+- `--run_inference`: Whether to run inference
+- `--db_dir`: Path to the database directory
+
+## License
+
+See the [LICENSE](LICENSE) file for details.
+
+## Acknowledgments
+
+- The implementation of the data, structure, common, and constants modules refers to [DeepMind](https://github.com/google-deepmind/alphafold3).
+- The implementation of the model and utils modules is based on [MindScience](https://gitee.com/mindspore/mindscience/).
+
+## Reference
+
+- Abramson J, Adler J, Dunger J, et al. Accurate structure prediction of biomolecular interactions with AlphaFold 3[J]. Nature, 2024, 630(8016): 493-500.
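+
+## Appendix: Selecting the Top-Ranked Sample
+
+Every sample's score is also written to `ranking_scores.csv` in the job's output directory (columns `seed`, `sample`, `ranking_score`, as produced by `write_outputs` in run_alphafold.py). A minimal sketch for locating the best sample from that file; the output path below is a placeholder:
+
+```python
+import csv
+import os
+
+output_dir = 'output/5tgy'  # placeholder: <output_dir>/<sanitised job name>
+
+with open(os.path.join(output_dir, 'ranking_scores.csv')) as f:
+    rows = list(csv.DictReader(f))
+
+# Per-sample results live in sub-directories named seed-<seed>_sample-<sample>;
+# the highest ranking_score marks the top-ranked prediction.
+best = max(rows, key=lambda row: float(row['ranking_score']))
+print(f"best sample: seed-{best['seed']}_sample-{best['sample']} "
+      f"(ranking_score={best['ranking_score']})")
+```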
diff --git a/MindSPONGE/applications/research/AlphaFold3/requirements.txt b/MindSPONGE/applications/research/AlphaFold3/requirements.txt new file mode 100644 index 000000000..1c230c665 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/requirements.txt @@ -0,0 +1,6 @@ +mindSpore==2.5.0 +absl-py==2.1.0 +numpy==1.26.0 +rdkit==2024.3.5 +scipy==1.14.1 +tqdm==4.67.0 \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py b/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py new file mode 100644 index 000000000..65f9376dd --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/run_alphafold.py @@ -0,0 +1,680 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from collections.abc import Callable, Iterable, Sequence +import csv +import dataclasses +import datetime +import functools +import multiprocessing +import os +import pathlib +import shutil +import string +import textwrap +import time +import typing +from typing import Protocol, Self, TypeVar, overload + +from absl import app +from absl import flags +from alphafold3.common import base_config +from alphafold3.common import folding_input +from alphafold3.common import resources +from alphafold3.constants import chemical_components +import alphafold3.cpp +from alphafold3.data import featurisation +from alphafold3.data import pipeline +from alphafold3.utils.attention import attention +from alphafold3.model import features +from alphafold3.model.diffusion.load_ckpt import load_diffuser +# from alphafold3.model import params +from alphafold3.model import post_processing +from alphafold3.model.components import base_model +from alphafold3.model.components import utils +from alphafold3.model.diffusion import model as diffusion_model +from alphafold3.model.feat_batch import Batch +import mindspore as ms +import numpy as np + + +_HOME_DIR = pathlib.Path(os.environ.get('HOME')) +_DEFAULT_MODEL_DIR = _HOME_DIR / 'ckpt' +_DEFAULT_DB_DIR = _HOME_DIR / 'public_databases' + + +# Input and output paths. +_JSON_PATH = flags.DEFINE_string( + 'json_path', + None, + 'Path to the input JSON file.', +) +_INPUT_DIR = flags.DEFINE_string( + 'input_dir', + None, + 'Path to the directory containing input JSON files.', +) +_OUTPUT_DIR = flags.DEFINE_string( + 'output_dir', + None, + 'Path to a directory where the results will be saved.', +) +MODEL_DIR = flags.DEFINE_string( + 'model_dir', + _DEFAULT_MODEL_DIR.as_posix(), + 'Path to the model to use for inference.', +) + +# Control which stages to run. +_RUN_DATA_PIPELINE = flags.DEFINE_bool( + 'run_data_pipeline', + True, + 'Whether to run the data pipeline on the fold inputs.', +) +_RUN_INFERENCE = flags.DEFINE_bool( + 'run_inference', + True, + 'Whether to run inference on the fold inputs.', +) + +# Binary paths. 
+_JACKHMMER_BINARY_PATH = flags.DEFINE_string( + 'jackhmmer_binary_path', + shutil.which('jackhmmer'), + 'Path to the Jackhmmer binary.', +) +_NHMMER_BINARY_PATH = flags.DEFINE_string( + 'nhmmer_binary_path', + shutil.which('nhmmer'), + 'Path to the Nhmmer binary.', +) +_HMMALIGN_BINARY_PATH = flags.DEFINE_string( + 'hmmalign_binary_path', + shutil.which('hmmalign'), + 'Path to the Hmmalign binary.', +) +_HMMSEARCH_BINARY_PATH = flags.DEFINE_string( + 'hmmsearch_binary_path', + shutil.which('hmmsearch'), + 'Path to the Hmmsearch binary.', +) +_HMMBUILD_BINARY_PATH = flags.DEFINE_string( + 'hmmbuild_binary_path', + shutil.which('hmmbuild'), + 'Path to the Hmmbuild binary.', +) + +# Database paths. +DB_DIR = flags.DEFINE_multi_string( + 'db_dir', + (_DEFAULT_DB_DIR.as_posix(),), + 'Path to the directory containing the databases. Can be specified multiple' + ' times to search multiple directories in order.', +) + +_SMALL_BFD_DATABASE_PATH = flags.DEFINE_string( + 'small_bfd_database_path', + '${DB_DIR}/bfd-first_non_consensus_sequences.fasta', + 'Small BFD database path, used for protein MSA search.', +) +_MGNIFY_DATABASE_PATH = flags.DEFINE_string( + 'mgnify_database_path', + '${DB_DIR}/mgy_clusters_2022_05.fa', + 'Mgnify database path, used for protein MSA search.', +) +_UNIPROT_CLUSTER_ANNOT_DATABASE_PATH = flags.DEFINE_string( + 'uniprot_cluster_annot_database_path', + '${DB_DIR}/uniprot_all_2021_04.fa', + 'UniProt database path, used for protein paired MSA search.', +) +_UNIREF90_DATABASE_PATH = flags.DEFINE_string( + 'uniref90_database_path', + '${DB_DIR}/uniref90_2022_05.fa', + 'UniRef90 database path, used for MSA search. The MSA obtained by ' + 'searching it is used to construct the profile for template search.', +) +_NTRNA_DATABASE_PATH = flags.DEFINE_string( + 'ntrna_database_path', + '${DB_DIR}/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta', + 'NT-RNA database path, used for RNA MSA search.', +) +_RFAM_DATABASE_PATH = flags.DEFINE_string( + 'rfam_database_path', + '${DB_DIR}/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta', + 'Rfam database path, used for RNA MSA search.', +) +_RNA_CENTRAL_DATABASE_PATH = flags.DEFINE_string( + 'rna_central_database_path', + '${DB_DIR}/rnacentral_active_seq_id_90_cov_80_linclust.fasta', + 'RNAcentral database path, used for RNA MSA search.', +) +_PDB_DATABASE_PATH = flags.DEFINE_string( + 'pdb_database_path', + '${DB_DIR}/mmcif_files', + 'PDB database directory with mmCIF files path, used for template search.', +) +_SEQRES_DATABASE_PATH = flags.DEFINE_string( + 'seqres_database_path', + '${DB_DIR}/pdb_seqres_2022_09_28.fasta', + 'PDB sequence database path, used for template search.', +) + +# Number of CPUs to use for MSA tools. +_JACKHMMER_N_CPU = flags.DEFINE_integer( + 'jackhmmer_n_cpu', + min(multiprocessing.cpu_count(), 8), + 'Number of CPUs to use for Jackhmmer. Default to min(cpu_count, 8). Going' + ' beyond 8 CPUs provides very little additional speedup.', +) +_NHMMER_N_CPU = flags.DEFINE_integer( + 'nhmmer_n_cpu', + min(multiprocessing.cpu_count(), 8), + 'Number of CPUs to use for Nhmmer. Default to min(cpu_count, 8). Going' + ' beyond 8 CPUs provides very little additional speedup.', +) + +# Template search configuration. +_MAX_TEMPLATE_DATE = flags.DEFINE_string( + 'max_template_date', + '2021-09-30', # By default, use the date from the AlphaFold 3 paper. + 'Maximum template release date to consider. Format: YYYY-MM-DD. 
All '
+    'templates released after this date will be ignored.',
+)
+
+
+_BUCKETS = flags.DEFINE_list(
+    'buckets',
+    # pyformat: disable
+    ['256', '512', '768', '1024', '1280', '1536', '2048', '2560', '3072',
+     '3584', '4096', '4608', '5120'],
+    # pyformat: enable
+    'Strictly increasing order of token sizes for which to cache compilations.'
+    ' For any input with more tokens than the largest bucket size, a new bucket'
+    ' is created for exactly that number of tokens.',
+)
+_FLASH_ATTENTION_IMPLEMENTATION = flags.DEFINE_enum(
+    'flash_attention_implementation',
+    default='ms',
+    enum_values=['ms'],
+    help=(
+        "Flash attention implementation to use. 'ms' uses the MindSpore"
+        ' attention implementation and is currently the only supported value.'
+    ),
+)
+
+
+class ConfigurableModel(Protocol):
+  """A model with a nested config class."""
+
+  class Config(base_config.BaseConfig):
+    ...
+
+  def __call__(self, config: Config) -> Self:
+    ...
+
+  @classmethod
+  def get_inference_result(
+      cls: Self,
+      batch: features.BatchDict,
+      result: base_model.ModelResult,
+      target_name: str = '',
+  ) -> Iterable[base_model.InferenceResult]:
+    ...
+
+
+ModelT = TypeVar('ModelT', bound=ConfigurableModel)
+
+
+def make_model_config(
+    *,
+    model_class: type[ModelT] = diffusion_model.Diffuser,
+    flash_attention_implementation: attention.Implementation = 'ms',
+):
+  """Returns the model config with the requested attention implementation."""
+  config = model_class.Config()
+  if hasattr(config, 'global_config'):
+    config.global_config.flash_attention_implementation = (
+        flash_attention_implementation
+    )
+  return config
+
+
+class ModelRunner:
+  """Helper class to run structure prediction stages."""
+
+  def __init__(
+      self,
+      model_class: ConfigurableModel,
+      config: base_config.BaseConfig,
+      model_dir: pathlib.Path,
+  ):
+    self._model_class = model_class
+    self._model_config = config
+    self._model_dir = model_dir
+
+  @functools.cached_property
+  def model_params(self):
+    """Loads model parameters from the model directory."""
+    # Checkpoint loading happens inside the forward pass via load_diffuser;
+    # this property is kept as a cheap pre-flight hook and returns None.
+    # param_dict = ms.load_checkpoint(self._model_dir / "test.ckpt")
+    # return param_dict
+
+  @functools.cached_property
+  def _model(
+      self
+  ) -> Callable[[np.ndarray, features.BatchDict], base_model.ModelResult]:
+    """Loads model parameters and returns a model forward pass."""
+    assert isinstance(self._model_config, self._model_class.Config)
+
+    def forward_fn(batch):
+      num_residues = batch.token_features.residue_index.shape[0]
+      model = self._model_class(
+          self._model_config, 447, (256, 447), (num_residues, 256, 128),
+          (256, 256, 128), (256, 384), (256, 24, 3), 128, 4, dtype=ms.float32)
+      load_diffuser(model, self._model_dir, dtype=ms.float32)
+      res = model(batch, 42)
+      return res
+
+    return forward_fn
+
+  def run_inference(
+      self, featurised_example: features.BatchDict
+  ) -> base_model.ModelResult:
+    """Computes a forward pass of the model on a featurised example."""
+    featurised_example = Batch.from_data_dict(featurised_example)
+    featurised_example.convert_to_tensor(ms.float32)
+
+    result = self._model(featurised_example)
+
+    # Convert identifier to bytes
+    if '__identifier__' in result:
+      result['__identifier__'] = result['__identifier__'].tobytes()
+    return result
+
+  def extract_structures(
self, + batch: features.BatchDict, + result: base_model.ModelResult, + target_name: str, + ) -> list[base_model.InferenceResult]: + """Generates structures from model outputs.""" + batch = Batch.from_data_dict(batch) + batch.convert_to_tensor(ms.float32) + return list( + self._model_class.get_inference_result( + batch=batch, result=result, target_name=target_name + ) + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ResultsForSeed: + """Stores the inference results (diffusion samples) for a single seed. + + Attributes: + seed: The seed used to generate the samples. + inference_results: The inference results, one per sample. + full_fold_input: The fold input that must also include the results of + running the data pipeline - MSA and templates. + """ + + seed: int + inference_results: Sequence[base_model.InferenceResult] + full_fold_input: folding_input.Input + + +def predict_structure( + fold_input: folding_input.Input, + model_runner: ModelRunner, + buckets: Sequence[int] | None = None, +) -> Sequence[ResultsForSeed]: + """Runs the full inference pipeline to predict structures for each seed.""" + + print(f'Featurising data for seeds {fold_input.rng_seeds}...') + featurisation_start_time = time.time() + ccd = chemical_components.cached_ccd(user_ccd=fold_input.user_ccd) + featurised_examples = featurisation.featurise_input( + fold_input=fold_input, buckets=buckets, ccd=ccd, verbose=True + ) + print( + f'Featurising data for seeds {fold_input.rng_seeds} took ' + f' {time.time() - featurisation_start_time:.2f} seconds.' + ) + all_inference_start_time = time.time() + all_inference_results = [] + for seed, example in zip(fold_input.rng_seeds, featurised_examples): + print(f'Running model inference for seed {seed}...') + inference_start_time = time.time() + result = model_runner.run_inference(example) + print( + f'Running model inference for seed {seed} took ' + f' {time.time() - inference_start_time:.2f} seconds.' + ) + print(f'Extracting output structures (one per sample) for seed {seed}...') + extract_structures = time.time() + inference_results = model_runner.extract_structures( + batch=example, result=result, target_name=fold_input.name + ) + print( + f'Extracting output structures (one per sample) for seed {seed} took ' + f' {time.time() - extract_structures:.2f} seconds.' + ) + all_inference_results.append( + ResultsForSeed( + seed=seed, + inference_results=inference_results, + full_fold_input=fold_input, + ) + ) + print( + 'Running model inference and extracting output structures for seed' + f' {seed} took {time.time() - inference_start_time:.2f} seconds.' + ) + print( + 'Running model inference and extracting output structures for seeds' + f' {fold_input.rng_seeds} took ' + f' {time.time() - all_inference_start_time:.2f} seconds.' 
+ ) + return all_inference_results + + +def write_fold_input_json( + fold_input: folding_input.Input, + output_dir: os.PathLike[str] | str, +) -> None: + """Writes the input JSON to the output directory.""" + os.makedirs(output_dir, exist_ok=True) + with open( + os.path.join(output_dir, f'{fold_input.sanitised_name()}_data.json'), 'wt' + ) as f: + f.write(fold_input.to_json()) + + +def write_outputs( + all_inference_results: Sequence[ResultsForSeed], + output_dir: os.PathLike[str] | str, + job_name: str, +) -> None: + """Writes outputs to the specified output directory.""" + ranking_scores = [] + max_ranking_score = None + max_ranking_result = None + + + os.makedirs(output_dir, exist_ok=True) + for results_for_seed in all_inference_results: + seed = results_for_seed.seed + for sample_idx, result in enumerate(results_for_seed.inference_results): + sample_dir = os.path.join(output_dir, f'seed-{seed}_sample-{sample_idx}') + os.makedirs(sample_dir, exist_ok=True) + post_processing.write_output( + inference_result=result, output_dir=sample_dir + ) + ranking_score = float(result.metadata['ranking_score']) + ranking_scores.append((seed, sample_idx, ranking_score)) + if max_ranking_score is None or ranking_score > max_ranking_score: + max_ranking_score = ranking_score + max_ranking_result = result + + if max_ranking_result is not None: # True iff ranking_scores non-empty. + post_processing.write_output( + inference_result=max_ranking_result, + output_dir=output_dir, + # The output terms of use are the same for all seeds/samples. + # terms_of_use=output_terms, + terms_of_use=None, + name=job_name, + ) + # Save csv of ranking scores with seeds and sample indices, to allow easier + # comparison of ranking scores across different runs. + with open(os.path.join(output_dir, 'ranking_scores.csv'), 'wt') as f: + writer = csv.writer(f) + writer.writerow(['seed', 'sample', 'ranking_score']) + writer.writerows(ranking_scores) + + +@overload +def process_fold_input( + fold_input: folding_input.Input, + data_pipeline_config: pipeline.DataPipelineConfig | None, + model_runner: None, + output_dir: os.PathLike[str] | str, + buckets: Sequence[int] | None = None, +) -> folding_input.Input: + ... + + +@overload +def process_fold_input( + fold_input: folding_input.Input, + data_pipeline_config: pipeline.DataPipelineConfig | None, + model_runner: ModelRunner, + output_dir: os.PathLike[str] | str, + buckets: Sequence[int] | None = None, +) -> Sequence[ResultsForSeed]: + ... + + +def replace_db_dir(path_with_db_dir: str, db_dirs: Sequence[str]) -> str: + """Replaces the DB_DIR placeholder in a path with the given DB_DIR.""" + template = string.Template(path_with_db_dir) + if 'DB_DIR' in template.get_identifiers(): + for db_dir in db_dirs: + path = template.substitute(DB_DIR=db_dir) + if os.path.exists(path): + return path + raise FileNotFoundError( + f'{path_with_db_dir} with ${{DB_DIR}} not found in any of {db_dirs}.' + ) + if not os.path.exists(path_with_db_dir): + raise FileNotFoundError(f'{path_with_db_dir} does not exist.') + return path_with_db_dir + + +def process_fold_input( + fold_input: folding_input.Input, + data_pipeline_config: pipeline.DataPipelineConfig | None, + model_runner: ModelRunner | None, + output_dir: os.PathLike[str] | str, + buckets: Sequence[int] | None = None, +) -> folding_input.Input | Sequence[ResultsForSeed]: + """Runs data pipeline and/or inference on a single fold input. + + Args: + fold_input: Fold input to process. + data_pipeline_config: Data pipeline config to use. 
If None, skip the data
+      pipeline.
+    model_runner: Model runner to use. If None, skip inference.
+    output_dir: Output directory to write to.
+    buckets: Bucket sizes to pad the data to, to avoid excessive re-compilation
+      of the model. If None, calculate the appropriate bucket size from the
+      number of tokens. If not None, must be a sequence of at least one integer,
+      in strictly increasing order. Will raise an error if the number of tokens
+      is more than the largest bucket size.
+
+  Returns:
+    The processed fold input, or the inference results for each seed.
+
+  Raises:
+    ValueError: If the fold input has no chains.
+  """
+  print(f'Processing fold input {fold_input.name}')
+
+  if not fold_input.chains:
+    raise ValueError('Fold input has no chains.')
+
+  if os.path.exists(output_dir) and os.listdir(output_dir):
+    new_output_dir = (
+        f'{output_dir}_{datetime.datetime.now().strftime("%Y%m%d_%H%M%S")}'
+    )
+    print(
+        f'Output directory {output_dir} exists and is non-empty; using'
+        f' {new_output_dir} instead.'
+    )
+    output_dir = new_output_dir
+
+  if model_runner is not None:
+    # If we're running inference, check we can load the model parameters before
+    # (possibly) launching the data pipeline.
+    print('Checking we can load the model parameters...')
+    _ = model_runner.model_params
+
+  if data_pipeline_config is None:
+    print('Skipping data pipeline...')
+  else:
+    print('Running data pipeline...')
+    fold_input = pipeline.DataPipeline(data_pipeline_config).process(fold_input)
+
+  print(f'Output directory: {output_dir}')
+  print(f'Writing model input JSON to {output_dir}')
+  write_fold_input_json(fold_input, output_dir)
+  if model_runner is None:
+    print('Skipping inference...')
+    output = fold_input
+  else:
+    print(
+        f'Predicting 3D structure for {fold_input.name} for seed(s)'
+        f' {fold_input.rng_seeds}...'
+    )
+    all_inference_results = predict_structure(
+        fold_input=fold_input,
+        model_runner=model_runner,
+        buckets=buckets,
+    )
+    print(
+        f'Writing outputs for {fold_input.name} for seed(s)'
+        f' {fold_input.rng_seeds}...'
+    )
+    write_outputs(
+        all_inference_results=all_inference_results,
+        output_dir=output_dir,
+        job_name=fold_input.sanitised_name(),
+    )
+    output = all_inference_results
+
+  print(f'Done processing fold input {fold_input.name}.')
+  return output
+
+
+def main(_):
+  # Exactly one of the two flags must be provided. The `is None` checks are
+  # parenthesised: without parentheses Python would chain the comparison and
+  # the check would misfire when both flags are set.
+  if (_JSON_PATH.value is None) == (_INPUT_DIR.value is None):
+    raise ValueError(
+        'Exactly one of --json_path or --input_dir must be specified.'
+    )
+
+  if not _RUN_INFERENCE.value and not _RUN_DATA_PIPELINE.value:
+    raise ValueError(
+        'At least one of --run_inference or --run_data_pipeline must be'
+        ' set to true.'
+    )
+
+  if _INPUT_DIR.value is not None:
+    fold_inputs = folding_input.load_fold_inputs_from_dir(
+        pathlib.Path(_INPUT_DIR.value)
+    )
+  elif _JSON_PATH.value is not None:
+    fold_inputs = folding_input.load_fold_inputs_from_path(
+        pathlib.Path(_JSON_PATH.value)
+    )
+  else:
+    raise AssertionError(
+        'Exactly one of --json_path or --input_dir must be specified.'
+    )
+
+  # Make sure we can create the output directory before running anything.
+  try:
+    os.makedirs(_OUTPUT_DIR.value, exist_ok=True)
+  except OSError as e:
+    print(f'Failed to create output directory {_OUTPUT_DIR.value}: {e}')
+    raise
+
+  notice = textwrap.wrap(
+      'Running AlphaFold 3. Please note that standard AlphaFold 3 model'
+      ' parameters are only available under terms of use provided at'
+      ' https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.'
+ ' If you do not agree to these terms and are using AlphaFold 3 derived' + ' model parameters, cancel execution of AlphaFold 3 inference with' + ' CTRL-C, and do not use the model parameters.', + break_long_words=False, + break_on_hyphens=False, + width=80, + ) + print('\n'.join(notice)) + + if _RUN_DATA_PIPELINE.value: + expand_path = lambda x: replace_db_dir(x, DB_DIR.value) + max_template_date = datetime.date.fromisoformat(_MAX_TEMPLATE_DATE.value) + data_pipeline_config = pipeline.DataPipelineConfig( + jackhmmer_binary_path=_JACKHMMER_BINARY_PATH.value, + nhmmer_binary_path=_NHMMER_BINARY_PATH.value, + hmmalign_binary_path=_HMMALIGN_BINARY_PATH.value, + hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH.value, + hmmbuild_binary_path=_HMMBUILD_BINARY_PATH.value, + small_bfd_database_path=expand_path(_SMALL_BFD_DATABASE_PATH.value), + mgnify_database_path=expand_path(_MGNIFY_DATABASE_PATH.value), + uniprot_cluster_annot_database_path=expand_path( + _UNIPROT_CLUSTER_ANNOT_DATABASE_PATH.value + ), + uniref90_database_path=expand_path(_UNIREF90_DATABASE_PATH.value), + ntrna_database_path=expand_path(_NTRNA_DATABASE_PATH.value), + rfam_database_path=expand_path(_RFAM_DATABASE_PATH.value), + rna_central_database_path=expand_path(_RNA_CENTRAL_DATABASE_PATH.value), + pdb_database_path=expand_path(_PDB_DATABASE_PATH.value), + seqres_database_path=expand_path(_SEQRES_DATABASE_PATH.value), + jackhmmer_n_cpu=_JACKHMMER_N_CPU.value, + nhmmer_n_cpu=_NHMMER_N_CPU.value, + max_template_date=max_template_date, + ) + else: + print('Skipping running the data pipeline.') + data_pipeline_config = None + + if _RUN_INFERENCE.value: + print('Building model from scratch...') + model_runner = ModelRunner( + model_class=diffusion_model.Diffuser, + config=make_model_config( + flash_attention_implementation=typing.cast( + attention.Implementation, _FLASH_ATTENTION_IMPLEMENTATION.value + ) + ), + model_dir=pathlib.Path(MODEL_DIR.value), + ) + else: + print('Skipping running model inference.') + model_runner = None + + print(f'Processing {len(fold_inputs)} fold inputs.') + for fold_input in fold_inputs: + process_fold_input( + fold_input=fold_input, + data_pipeline_config=data_pipeline_config, + model_runner=model_runner, + output_dir=os.path.join(_OUTPUT_DIR.value, fold_input.sanitised_name()), + buckets=tuple(int(bucket) for bucket in _BUCKETS.value), + ) + + print(f'Done processing {len(fold_inputs)} fold inputs.') + + +if __name__ == '__main__': + flags.mark_flags_as_required([ + 'output_dir', + ]) + app.run(main) \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/run_alphafold_data_test.py b/MindSPONGE/applications/research/AlphaFold3/run_alphafold_data_test.py new file mode 100644 index 000000000..1ace86666 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/run_alphafold_data_test.py @@ -0,0 +1,299 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Tests the AlphaFold 3 data pipeline.""" + +import contextlib +import datetime +import difflib +import functools +import hashlib +import json +import os +import pathlib +import pickle +import shutil +from typing import Any +from absl.testing import absltest +from absl.testing import parameterized +import numpy as np +import mindspore as ms +import run_alphafold +from alphafold3 import structure +from alphafold3.common import folding_input +from alphafold3.common import resources +from alphafold3.common.testing import data as testing_data +from alphafold3.constants import chemical_components +from alphafold3.data import featurisation +from alphafold3.data import pipeline +from alphafold3.model.atom_layout import atom_layout +from alphafold3.structure import test_utils + + +_JACKHMMER_BINARY_PATH = shutil.which('jackhmmer') +_NHMMER_BINARY_PATH = shutil.which('nhmmer') +_HMMALIGN_BINARY_PATH = shutil.which('hmmalign') +_HMMSEARCH_BINARY_PATH = shutil.which('hmmsearch') +_HMMBUILD_BINARY_PATH = shutil.which('hmmbuild') + + +@contextlib.contextmanager +def _output(name: str): + result_path = f'{absltest.TEST_TMPDIR.value}/{name}' + with open(result_path, "wb") as f: + yield result_path, f + + +@functools.singledispatch +def _hash_data(x: Any, /) -> str: + if x is None: + return '<>' + return _hash_data(json.dumps(x).encode('utf-8')) + + +@_hash_data.register +def _(x: bytes, /) -> str: + return hashlib.sha256(x).hexdigest() + + +@_hash_data.register +def _(x: ms.Tensor) -> str: + return _hash_data(x.asnumpy()) + + +@_hash_data.register +def _(x: np.ndarray) -> str: + if x.dtype == object: + return ';'.join(map(_hash_data, x.ravel().tolist())) + return _hash_data(x.tobytes()) + + +@_hash_data.register +def _(_: structure.Structure) -> str: + return '<>' + + +@_hash_data.register +def _(_: atom_layout.AtomLayout) -> str: + return '<>' + + +def _generate_diff(actual: str, expected: str) -> str: + return '\n'.join( + difflib.unified_diff( + expected.split('\n'), + actual.split('\n'), + fromfile='expected', + tofile='actual', + lineterm='', + ) + ) + + +def tree_map(func, dict_tree): + if isinstance(dict_tree, dict): + return {k: tree_map(func, v) for k, v in dict_tree.items()} + else: + if func == "asnumpy": + return dict_tree.asnumpy() + elif func == "float32": + return dict_tree.astype(ms.float32) + elif func == "bfloat16": + return dict_tree.astype(ms.bfloat16) + else: + return func(dict_tree) + + +class DataPipelineTest(test_utils.StructureTestCase): + """Test AlphaFold 3 inference.""" + + def setUp(self): + super().setUp() + small_bfd_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/bfd-first_non_consensus_sequences__subsampled_1000.fasta' + ).path() + mgnify_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/mgy_clusters__subsampled_1000.fa' + ).path() + uniprot_cluster_annot_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/uniprot_all__subsampled_1000.fasta' + ).path() + uniref90_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/uniref90__subsampled_1000.fasta' + ).path() + ntrna_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq__subsampled_1000.fasta' 
+    ).path()
+    rfam_database_path = testing_data.Data(
+        resources.ROOT
+        / 'test_data/miniature_databases/rfam_14_4_clustered_rep_seq__subsampled_1000.fasta'
+    ).path()
+    rna_central_database_path = testing_data.Data(
+        resources.ROOT
+        / 'test_data/miniature_databases/rnacentral_active_seq_id_90_cov_80_linclust__subsampled_1000.fasta'
+    ).path()
+    pdb_database_path = testing_data.Data(
+        resources.ROOT / 'test_data/miniature_databases/pdb_mmcif'
+    ).path()
+    seqres_database_path = testing_data.Data(
+        resources.ROOT
+        / 'test_data/miniature_databases/pdb_seqres_2022_09_28__subsampled_1000.fasta'
+    ).path()
+
+    self._data_pipeline_config = pipeline.DataPipelineConfig(
+        jackhmmer_binary_path=_JACKHMMER_BINARY_PATH,
+        nhmmer_binary_path=_NHMMER_BINARY_PATH,
+        hmmalign_binary_path=_HMMALIGN_BINARY_PATH,
+        hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH,
+        hmmbuild_binary_path=_HMMBUILD_BINARY_PATH,
+        small_bfd_database_path=small_bfd_database_path,
+        mgnify_database_path=mgnify_database_path,
+        uniprot_cluster_annot_database_path=uniprot_cluster_annot_database_path,
+        uniref90_database_path=uniref90_database_path,
+        ntrna_database_path=ntrna_database_path,
+        rfam_database_path=rfam_database_path,
+        rna_central_database_path=rna_central_database_path,
+        pdb_database_path=pdb_database_path,
+        seqres_database_path=seqres_database_path,
+        max_template_date=datetime.date(2021, 9, 30),
+    )
+    test_input = {
+        'name': '5tgy',
+        'modelSeeds': [1234],
+        'sequences': [
+            {
+                'protein': {
+                    'id': 'A',
+                    # Implicit string concatenation keeps the sequence free of
+                    # the stray indentation a backslash continuation inside the
+                    # literal would embed.
+                    'sequence': (
+                        'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREA'
+                        'IDKGDKDSLEQLLEELEQALQKIRELAEKKN'
+                    ),
+                    'modifications': [],
+                    'unpairedMsa': None,
+                    'pairedMsa': None,
+                }
+            },
+            {'ligand': {'id': 'B', 'ccdCodes': ['7BU']}},
+        ],
+        'dialect': folding_input.JSON_DIALECT,
+        'version': folding_input.JSON_VERSION,
+    }
+    self._test_input_json = json.dumps(test_input)
+
+  def compare_golden(self, result_path: str) -> None:
+    filename = os.path.split(result_path)[1]
+    golden_path = testing_data.Data(
+        resources.ROOT / f'test_data/{filename}'
+    ).path()
+    with open(golden_path, 'r') as golden_file:
+      golden_text = golden_file.read()
+    with open(result_path, 'r') as result_file:
+      result_text = result_file.read()
+
+    diff = _generate_diff(result_text, golden_text)
+
+    self.assertEqual(diff, "", f"Result differs from golden:\n{diff}")
+
+  def test_config(self):
+    model_config = run_alphafold.make_model_config()
+    model_config_as_str = json.dumps(
+        model_config.as_dict(), sort_keys=True, indent=2
+    )
+    with _output('model_config.json') as (result_path, output):
+      output.write(model_config_as_str.encode('utf-8'))
+    self.compare_golden(result_path)
+
+  def test_featurisation(self):
+    """Run featurisation and assert that the output is as expected."""
+    fold_input = folding_input.Input.from_json(self._test_input_json)
+    data_pipeline = pipeline.DataPipeline(self._data_pipeline_config)
+    full_fold_input = data_pipeline.process(fold_input)
+    featurised_example = featurisation.featurise_input(
+        full_fold_input,
+        ccd=chemical_components.cached_ccd(),
+        buckets=None,
+    )
+
+    del featurised_example[0]['ref_pos']
+
+    with _output('featurised_example.pkl') as (_, output):
+      output.write(pickle.dumps(featurised_example))
+    featurised_example = [tree_map(_hash_data, featurised_example[0])]
+    with _output('featurised_example.json') as (result_path, output):
+      output.write(
+          json.dumps(featurised_example, sort_keys=True, indent=2).encode(
+              'utf-8'
+          )
+      )
+    self.compare_golden(result_path)
+
+  def
test_write_input_json(self): + fold_input = folding_input.Input.from_json(self._test_input_json) + output_dir = self.create_tempdir().full_path + run_alphafold.write_fold_input_json(fold_input, output_dir) + with open( + os.path.join( + output_dir, f'{fold_input.sanitised_name()}_data.json'), + 'rt', + ) as f: + actual_fold_input = folding_input.Input.from_json(f.read()) + + self.assertEqual(actual_fold_input, fold_input) + + def test_process_fold_input_runs_only_data_pipeline(self): + fold_input = folding_input.Input.from_json(self._test_input_json) + output_dir = self.create_tempdir().full_path + run_alphafold.process_fold_input( + fold_input=fold_input, + data_pipeline_config=self._data_pipeline_config, + model_runner=None, + output_dir=output_dir, + ) + with open( + os.path.join( + output_dir, f'{fold_input.sanitised_name()}_data.json'), + 'rt', + ) as f: + actual_fold_input = folding_input.Input.from_json(f.read()) + + featurisation.validate_fold_input(actual_fold_input) + + @parameterized.product(num_db_dirs=tuple(range(1, 3))) + def test_replace_db_dir(self, num_db_dirs: int) -> None: + """Test that the db_dir is replaced correctly.""" + db_dirs = [pathlib.Path(self.create_tempdir()) + for _ in range(num_db_dirs)] + db_dirs_posix = [db_dir.as_posix() for db_dir in db_dirs] + + for i, db_dir in enumerate(db_dirs): + for j in range(i + 1): + (db_dir / f'filename{j}.txt').write_text(f'hello world {i}') + + for i in range(num_db_dirs): + self.assertEqual( + pathlib.Path( + run_alphafold.replace_db_dir( + f'${{DB_DIR}}/filename{i}.txt', db_dirs_posix + ) + ).read_text(), + f'hello world {i}', + ) + with self.assertRaises(FileNotFoundError): + run_alphafold.replace_db_dir( + f'${{DB_DIR}}/filename{num_db_dirs}.txt', db_dirs_posix + ) + + +if __name__ == '__main__': + absltest.main() diff --git a/MindSPONGE/applications/research/AlphaFold3/run_alphafold_test_v2.py b/MindSPONGE/applications/research/AlphaFold3/run_alphafold_test_v2.py new file mode 100644 index 000000000..5698056b7 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/run_alphafold_test_v2.py @@ -0,0 +1,381 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Tests end-to-end running of AlphaFold 3.""" + +import contextlib +import csv +import datetime +import difflib +import functools +import hashlib +import json +import os +import pathlib +import pickle +from typing import Any + +from absl import logging +from absl.testing import absltest +from absl.testing import parameterized +from alphafold3 import structure +from alphafold3.common import folding_input +from alphafold3.common import resources +from alphafold3.common.testing import data as testing_data +from alphafold3.data import pipeline +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.diffusion import model as diffusion_model +from alphafold3.model.scoring import alignment +from alphafold3.structure import test_utils +import mindspore as ms +import numpy as np + +import run_alphafold +import shutil + + +_JACKHMMER_BINARY_PATH = shutil.which('jackhmmer') +_NHMMER_BINARY_PATH = shutil.which('nhmmer') +_HMMALIGN_BINARY_PATH = shutil.which('hmmalign') +_HMMSEARCH_BINARY_PATH = shutil.which('hmmsearch') +_HMMBUILD_BINARY_PATH = shutil.which('hmmbuild') + + +@contextlib.contextmanager +def _output(name: str): + with open(result_path := f'{absltest.TEST_TMPDIR.value}/{name}', "wb") as f: + yield result_path, f + + +def _generate_diff(actual: str, expected: str) -> str: + return '\n'.join( + difflib.unified_diff( + expected.split('\n'), + actual.split('\n'), + fromfile='expected', + tofile='actual', + lineterm='', + ) + ) + + +@functools.singledispatch +def _hash_data(x: Any, /) -> str: + if x is None: + return '<>' + return _hash_data(json.dumps(x).encode('utf-8')) + + +@_hash_data.register +def _(x: bytes, /) -> str: + return hashlib.sha256(x).hexdigest() + + +@_hash_data.register +def _(x: ms.Tensor) -> str: + return _hash_data(x.asnumpy()) + + +@_hash_data.register +def _(x: np.ndarray) -> str: + if x.dtype == object: + return ';'.join(map(_hash_data, x.ravel().tolist())) + return _hash_data(x.tobytes()) + + +@_hash_data.register +def _(_: structure.Structure) -> str: + return '<>' + + +@_hash_data.register +def _(_: atom_layout.AtomLayout) -> str: + return '<>' + +def tree_map(func, dict_tree): + if isinstance(dict_tree, dict): + return {k: tree_map(func, v) for k, v in dict_tree.items()} + else: + if func == "asnumpy": + return dict_tree.asnumpy() + elif func == "float32": + return dict_tree.astype(ms.float32) + elif func == "bfloat16": + return dict_tree.astype(ms.bfloat16) + else: + return func(dict_tree) + +class InferenceTest(test_utils.StructureTestCase): + """Test AlphaFold 3 inference.""" + + def setUp(self): + super().setUp() + small_bfd_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/bfd-first_non_consensus_sequences__subsampled_1000.fasta' + ).path() + mgnify_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/mgy_clusters__subsampled_1000.fa' + ).path() + uniprot_cluster_annot_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/uniprot_all__subsampled_1000.fasta' + ).path() + uniref90_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/uniref90__subsampled_1000.fasta' + ).path() + ntrna_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq__subsampled_1000.fasta' + ).path() + rfam_database_path = 
testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/rfam_14_4_clustered_rep_seq__subsampled_1000.fasta' + ).path() + rna_central_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/rnacentral_active_seq_id_90_cov_80_linclust__subsampled_1000.fasta' + ).path() + pdb_database_path = testing_data.Data( + resources.ROOT / 'test_data/miniature_databases/pdb_mmcif' + ).path() + seqres_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/pdb_seqres_2022_09_28__subsampled_1000.fasta' + ).path() + + self._data_pipeline_config = pipeline.DataPipelineConfig( + jackhmmer_binary_path=_JACKHMMER_BINARY_PATH, + nhmmer_binary_path=_NHMMER_BINARY_PATH, + hmmalign_binary_path=_HMMALIGN_BINARY_PATH, + hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH, + hmmbuild_binary_path=_HMMBUILD_BINARY_PATH, + small_bfd_database_path=small_bfd_database_path, + mgnify_database_path=mgnify_database_path, + uniprot_cluster_annot_database_path=uniprot_cluster_annot_database_path, + uniref90_database_path=uniref90_database_path, + ntrna_database_path=ntrna_database_path, + rfam_database_path=rfam_database_path, + rna_central_database_path=rna_central_database_path, + pdb_database_path=pdb_database_path, + seqres_database_path=seqres_database_path, + max_template_date=datetime.date(2021, 9, 30), + ) + test_input = { + 'name': '5tgy', + 'modelSeeds': [1234], + 'sequences': [ + { + 'protein': { + 'id': 'A', + 'sequence': 'SEFEKLRQTGDELVQAFQRLREIFDKGDDDSLEQVLEEIEELIQKHRQLFDNRQEAADTEAAKQGDQWVQLFQRFREAIDKGDKDSLEQLLEELEQALQKIRELAEKKN', + 'modifications': [], + 'unpairedMsa': None, + 'pairedMsa': None, + } + }, + {'ligand': {'id': 'B', 'ccdCodes': ['7BU']}}, + ], + 'dialect': folding_input.JSON_DIALECT, + 'version': folding_input.JSON_VERSION, + } + self._test_input_json = json.dumps(test_input) + + self._runner = run_alphafold.ModelRunner( + model_class=run_alphafold.diffusion_model.Diffuser, + config=run_alphafold.make_model_config(), + model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value), + ) + + def compare_golden(self, result_path: str) -> None: + filename = os.path.split(result_path)[1] + golden_path = testing_data.Data( + resources.ROOT / f'test_data/{filename}' + ).path() + with open(golden_path, 'r') as golden_file: + golden_text = golden_file.read() + with open(result_path, 'r') as result_file: + result_text = result_file.read() + + diff = _generate_diff(result_text, golden_text) + + self.assertEqual(diff, "", f"Result differs from golden:\n{diff}") + + def test_model_inference(self): + """Run model inference and assert that the output is as expected.""" + featurised_examples = pickle.loads( + (resources.ROOT / 'test_data' / 'featurised_example.pkl').read_bytes() + ) + + self.assertLen(featurised_examples, 1) + featurised_example = featurised_examples[0] + inference_result = self._runner.run_inference( + featurised_example + ) + inference_result = tree_map(_hash_data, inference_result) + self.assertIsNotNone(inference_result) + + def test_process_fold_input_runs_only_inference(self): + with self.assertRaisesRegex(ValueError, 'missing unpaired MSA.'): + run_alphafold.process_fold_input( + fold_input=folding_input.Input.from_json(self._test_input_json), + # No data pipeline config, so featursation will run first, and fail + # since the input is missing MSAs. 
+ data_pipeline_config=None, + model_runner=self._runner, + output_dir=self.create_tempdir(cleanup=absltest.TempFileCleanup.OFF).full_path, + ) + + @parameterized.named_parameters( + { + 'testcase_name': 'default_bucket', + 'bucket': None, + 'exp_ranking_scores': [0.69, 0.69, 0.72, 0.75, 0.70], + }, + ) + def test_inference(self, bucket, exp_ranking_scores): + """Run AlphaFold 3 inference.""" + + ### Prepare inputs. + fold_input = folding_input.Input.from_json(self._test_input_json) + + output_dir = self.create_tempdir(cleanup=absltest.TempFileCleanup.OFF).full_path + actual = run_alphafold.process_fold_input( + fold_input, + self._data_pipeline_config, + run_alphafold.ModelRunner( + model_class=diffusion_model.Diffuser, + config=run_alphafold.make_model_config(), + model_dir=pathlib.Path(run_alphafold.MODEL_DIR.value), + ), + output_dir=output_dir, + buckets=None if bucket is None else [bucket], + ) + logging.info('finished get_inference_result') + expected_model_cif_filename = f'{fold_input.sanitised_name()}_model.cif' + expected_summary_confidences_filename = ( + f'{fold_input.sanitised_name()}_summary_confidences.json' + ) + expected_confidences_filename = ( + f'{fold_input.sanitised_name()}_confidences.json' + ) + expected_data_json_filename = f'{fold_input.sanitised_name()}_data.json' + + + self.assertSameElements( + os.listdir(output_dir), + [ + # Subdirectories, one for each sample. + 'seed-1234_sample-0', + 'seed-1234_sample-1', + 'seed-1234_sample-2', + 'seed-1234_sample-3', + 'seed-1234_sample-4', + # Top ranking result. + expected_confidences_filename, + expected_model_cif_filename, + expected_summary_confidences_filename, + # Ranking scores for all samples. + 'ranking_scores.csv', + # The input JSON defining the job. + expected_data_json_filename, + # The output terms of use. + # 'TERMS_OF_USE.md', + ], + ) + + with open(os.path.join(output_dir, expected_data_json_filename), 'rt') as f: + actual_input_json = json.load(f) + + self.assertEqual( + actual_input_json['sequences'][0]['protein']['sequence'], + fold_input.protein_chains[0].sequence, + ) + self.assertSequenceEqual( + actual_input_json['sequences'][1]['ligand']['ccdCodes'], + fold_input.ligands[0].ccd_ids, + ) + self.assertNotEmpty( + actual_input_json['sequences'][0]['protein']['unpairedMsa'] + ) + self.assertNotEmpty( + actual_input_json['sequences'][0]['protein']['pairedMsa'] + ) + self.assertIsNotNone( + actual_input_json['sequences'][0]['protein']['templates'] + ) + + with open(os.path.join(output_dir, 'ranking_scores.csv'), 'rt') as f: + actual_ranking_scores = list(csv.DictReader(f)) + + self.assertLen(actual_ranking_scores, 5) + self.assertEqual( + [int(s['seed']) for s in actual_ranking_scores], [1234] * 5 + ) + self.assertEqual( + [int(s['sample']) for s in actual_ranking_scores], [0, 1, 2, 3, 4] + ) + + bucket_label = 'default' if bucket is None else bucket + output_filename = f'run_alphafold_test_output_bucket_{bucket_label}.pkl' + + # Convert to dict to enable simple serialization. + actual_dict = [ + dict( + seed=actual_inf.seed, + inference_results=actual_inf.inference_results, + full_fold_input=actual_inf.full_fold_input, + ) + for actual_inf in actual + ] + with _output(output_filename) as (_, output): + output.write(pickle.dumps(actual_dict)) + + logging.info('Comparing inference results with expected values.') + + ### Assert that output is as expected. 
+ expected_dict = pickle.loads( + ( + resources.ROOT + / 'test_data' + / 'alphafold_run_outputs' + / output_filename + ).read_bytes() + ) + expected = [ + run_alphafold.ResultsForSeed(**expected_inf) + for expected_inf in expected_dict + ] + for actual_inf, expected_inf in zip(actual, expected, strict=True): + for actual_inf, expected_inf in zip( + actual_inf.inference_results, + expected_inf.inference_results, + strict=True, + ): + + # Check RMSD is within tolerance. + # 5tgy is very stable, NMR samples were all within 3.0 RMSD. + actual_rmsd = alignment.rmsd_from_coords( + actual_inf.predicted_structure.coords, + expected_inf.predicted_structure.coords, + ) + self.assertLess(actual_rmsd, 3.0) + np.testing.assert_array_equal( + actual_inf.predicted_structure.atom_occupancy, + [1.0] * actual_inf.predicted_structure.num_atoms, + ) + + +if __name__ == '__main__': + absltest.main() \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/set_path.sh b/MindSPONGE/applications/research/AlphaFold3/set_path.sh new file mode 100644 index 000000000..9ca81c633 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/set_path.sh @@ -0,0 +1,37 @@ +#!/bin/bash + +# Get the script directory to make paths more reliable +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +# From AlphaFold3 directory, go up to the mindscience directory +MINDSCIENCE_PATH="$(cd "$SCRIPT_DIR/../../../.." && pwd)" + +# Check if the base directory exists +if [ ! -d "$MINDSCIENCE_PATH" ]; then + echo "Error: MindScience path not found: $MINDSCIENCE_PATH" + echo "Please run this script from the correct directory" + exit 1 +fi + +# Function to add to PYTHONPATH if directory exists +add_to_pythonpath() { + local dir_path="$1" + if [ -d "$dir_path" ]; then + export PYTHONPATH="$PYTHONPATH:$dir_path" + echo "Added to PYTHONPATH: $dir_path" + else + echo "Warning: Directory not found, skipping: $dir_path" + fi +} + +add_to_pythonpath "$MINDSCIENCE_PATH/MindSPONGE/src" +add_to_pythonpath "$MINDSCIENCE_PATH/MindChemistry" +add_to_pythonpath "$MINDSCIENCE_PATH/MindSPONGE/applications/research/AlphaFold3" + +# Add directories to PATH +export PATH=$PATH:/hmmer/bin + +# Display current PYTHONPATH +echo "Current PYTHONPATH:" +echo "$PYTHONPATH" | tr ':' '\n' | sed 's/^/ /' + +echo "Environment setup completed." diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/__init__.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/build_data.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/build_data.py new file mode 100644 index 000000000..ae02eacbf --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/build_data.py @@ -0,0 +1,44 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Script for building intermediate data.""" + +from importlib import resources +import pathlib +import site + +import alphafold3.constants.converters +from alphafold3.constants.converters import ccd_pickle_gen +from alphafold3.constants.converters import chemical_component_sets_gen + + +def build_data(): + """Builds intermediate data.""" + for site_path in site.getsitepackages(): + path = pathlib.Path(site_path) / 'share/libcifpp/components.cif' + if path.exists(): + cif_path = path + break + else: + raise ValueError('Could not find components.cif') + + out_root = resources.files(alphafold3.constants.converters) + ccd_pickle_path = out_root.joinpath('ccd.pickle') + chemical_component_sets_pickle_path = out_root.joinpath( + 'chemical_component_sets.pickle' + ) + ccd_pickle_gen.main(['', str(cif_path), str(ccd_pickle_path)]) + chemical_component_sets_gen.main( + ['', str(chemical_component_sets_pickle_path)] + ) + +if __name__ == '__main__': + build_data() diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/base_config.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/base_config.py new file mode 100644 index 000000000..27f6eba12 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/base_config.py @@ -0,0 +1,151 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ +"""Config for the protein folding model and experiment.""" + +from collections.abc import Mapping +import copy +import dataclasses +import types +import typing +from typing import Any, ClassVar, TypeVar + + +_T = TypeVar('_T') +_ConfigT = TypeVar('_ConfigT', bound='BaseConfig') + + +def _strip_optional(t: type[Any]) -> type[Any]: + """Transforms type annotations of the form `T | None` to `T`.""" + if typing.get_origin(t) in (typing.Union, types.UnionType): + args = set(typing.get_args(t)) - {types.NoneType} + if len(args) == 1: + return args.pop() + return t + + +_NO_UPDATE = object() + + +class _Autocreate: + + def __init__(self, **defaults: Any): + self.defaults = defaults + + +def autocreate(**defaults: Any) -> Any: + """Marks a field as having a default factory derived from its type.""" + return _Autocreate(**defaults) + + +def _clone_field( + field: dataclasses.Field[_T], new_default: _T +) -> dataclasses.Field[_T]: + if new_default is _NO_UPDATE: + return copy.copy(field) + return dataclasses.field( + default=new_default, + init=True, + kw_only=True, + repr=field.repr, + hash=field.hash, + compare=field.compare, + metadata=field.metadata, + ) + + +@typing.dataclass_transform() +class ConfigMeta(type): + """Metaclass that synthesizes a __post_init__ that coerces dicts to Config subclass instances.""" + + def __new__(mcs, name, bases, classdict): + cls = super().__new__(mcs, name, bases, classdict) + + def _coercable_fields(self) -> Mapping[str, tuple[ConfigMeta, Any]]: + type_hints = typing.get_type_hints(self.__class__) + fields = dataclasses.fields(self.__class__) + field_to_type_and_default = { + field.name: (_strip_optional( + type_hints[field.name]), field.default) + for field in fields + } + coercable_fields = { + f: t + for f, t in field_to_type_and_default.items() + if issubclass(type(t[0]), ConfigMeta) + } + return coercable_fields + + cls._coercable_fields = property(_coercable_fields) + + old_post_init = getattr(cls, '__post_init__', None) + + def _post_init(self) -> None: + # Use get_type_hints instead of Field.type to ensure that forward + # references are resolved. + for field_name, ( + field_type, + field_default, + ) in self._coercable_fields.items(): # pylint: disable=protected-access + field_value = getattr(self, field_name) + if field_value is None: + continue + try: + match field_value: + case _Autocreate(): + # Construct from field defaults. + setattr(self, field_name, field_type( + **field_value.defaults)) + case Mapping(): + # Field value is not yet a `Config` instance; Assume we can create + # one by splatting keys and values. + args = {} + # Apply default args first, if present. + if isinstance(field_default, _Autocreate): + args.update(field_default.defaults) + args.update(field_value) + setattr(self, field_name, field_type(**args)) + case _: + pass + except TypeError as e: + raise TypeError( + f'Failure while coercing field {field_name!r} of' + f' {self.__class__.__qualname__}' + ) from e + if old_post_init: + old_post_init(self) + + cls.__post_init__ = _post_init + + return dataclasses.dataclass(kw_only=True)(cls) + + +class BaseConfig(metaclass=ConfigMeta): + """Config base class. 
+ + Subclassing Config automatically makes the subclass a kw_only dataclass with + a `__post_init__` that coerces Config-subclass field values from mappings to + instances of the right type. + """ + # Provided by dataclasses.make_dataclass + __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] + + # Overridden by metaclass + @property + def _coercable_fields(self) -> Mapping[str, tuple[type['BaseConfig'], Any]]: + return {} + + def as_dict(self) -> Mapping[str, Any]: + result = dataclasses.asdict(self) + for field_name in self._coercable_fields: + field_value = getattr(self, field_name, None) + if isinstance(field_value, BaseConfig): + result[field_name] = field_value.as_dict() + return result diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py new file mode 100644 index 000000000..5a39b6597 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/folding_input.py @@ -0,0 +1,1115 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Model input dataclass.""" + +from collections.abc import Collection, Mapping, Sequence +import dataclasses +import json +import logging +import pathlib +import random +import re +import string +from typing_extensions import Any, Final, Self, TypeAlias + +from alphafold3 import structure +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.structure import mmcif as mmcif_lib +import rdkit.Chem as rd_chem + + +BondAtomId: TypeAlias = tuple[str, int, str] + +JSON_DIALECT: Final[str] = 'alphafold3' +JSON_VERSION: Final[int] = 1 + +ALPHAFOLDSERVER_JSON_DIALECT: Final[str] = 'alphafoldserver' +ALPHAFOLDSERVER_JSON_VERSION: Final[int] = 1 + + +def _validate_keys(actual: Collection[str], expected: Collection[str]): + """Validates that the JSON doesn't contain any extra unwanted keys.""" + if bad_keys := set(actual) - set(expected): + raise ValueError( + f'Unexpected JSON keys in: {", ".join(sorted(bad_keys))}') + + +class Template: + """Structural template input.""" + + __slots__ = ('_mmcif', '_query_to_template') + + def __init__(self, mmcif: str, query_to_template_map: Mapping[int, int]): + """Initializes the template. + + Args: + mmcif: The structural template in mmCIF format. The mmCIF should have only + one protein chain. + query_to_template_map: A mapping from query residue index to template + residue index. + """ + self._mmcif = mmcif + # Needed to make the Template class hashable. 
+ self._query_to_template = tuple(query_to_template_map.items()) + + @property + def query_to_template_map(self) -> Mapping[int, int]: + return dict(self._query_to_template) + + @property + def mmcif(self) -> str: + return self._mmcif + + def __hash__(self) -> int: + return hash((self._mmcif, tuple(sorted(self._query_to_template)))) + + def __eq__(self, other: Self) -> bool: + mmcifs_equal = self._mmcif == other._mmcif + maps_equal = sorted(self._query_to_template) == sorted( + other._query_to_template + ) + return mmcifs_equal and maps_equal + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ProteinChain: + """Protein chain input. + + Attributes: + id: Unique protein chain identifier. + sequence: The amino acid sequence of the chain. + ptms: A list of tuples containing the post-translational modification type + and the (1-based) residue index where the modification is applied. + paired_msa: Paired A3M-formatted MSA for this chain. This MSA is not + deduplicated and will be used to compute paired features. If None, this + field is unset and must be filled in by the data pipeline before + featurisation. If set to an empty string, it will be treated as a custom + MSA with no sequences. + unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be + deduplicated and used to compute unpaired features. If None, this field is + unset and must be filled in by the data pipeline before featurisation. If + set to an empty string, it will be treated as a custom MSA with no + sequences. + templates: A list of structural templates for this chain. If None, this + field is unset and must be filled in by the data pipeline before + featurisation. The list can be empty or contain up to 20 templates. + """ + + id: str + sequence: str + ptms: Sequence[tuple[str, int]] + paired_msa: str | None = None + unpaired_msa: str | None = None + templates: Sequence[Template] | None = None + + def __post_init__(self): + if not all(res.isalpha() for res in self.sequence): + raise ValueError( + f'Protein must contain only letters, got "{self.sequence}"' + ) + if any(not 0 < mod[1] <= len(self.sequence) for mod in self.ptms): + raise ValueError( + f'Invalid protein modification index: {self.ptms}') + + # Use hashable types for ptms and templates. + if self.ptms is not None: + object.__setattr__(self, 'ptms', tuple(self.ptms)) + if self.templates is not None: + object.__setattr__(self, 'templates', tuple(self.templates)) + + @classmethod + def from_alphafoldserver_dict( + cls, json_dict: Mapping[str, Any], seq_id: str + ) -> Self: + """Constructs ProteinChain from the AlphaFoldServer JSON dict.""" + _validate_keys( + json_dict.keys(), + {'sequence', 'glycans', 'modifications', 'count'}, + ) + sequence = json_dict['sequence'] + + if 'glycans' in json_dict: + raise ValueError( + f'Specifying glycans in the `{ALPHAFOLDSERVER_JSON_DIALECT}` format' + ' is not currently supported.' 
+ ) + + ptms = [ + (mod['ptmType'].removeprefix('CCD_'), mod['ptmPosition']) + for mod in json_dict.get('modifications', []) + ] + return cls(id=seq_id, sequence=sequence, ptms=ptms) + + @classmethod + def from_dict( + cls, json_dict: Mapping[str, Any], seq_id: str | None = None + ) -> Self: + """Constructs ProteinChain from the AlphaFold JSON dict.""" + json_dict = json_dict['protein'] + _validate_keys( + json_dict.keys(), + { + 'id', + 'sequence', + 'modifications', + 'unpairedMsa', + 'pairedMsa', + 'templates', + }, + ) + + sequence = json_dict['sequence'] + ptms = [ + (mod['ptmType'], mod['ptmPosition']) + for mod in json_dict.get('modifications', []) + ] + + unpaired_msa = json_dict.get('unpairedMsa', None) + paired_msa = json_dict.get('pairedMsa', None) + + raw_templates = json_dict.get('templates', None) + + if raw_templates is None: + templates = None + else: + templates = [ + Template( + mmcif=template['mmcif'], + query_to_template_map=dict( + zip(template['queryIndices'], + template['templateIndices']) + ), + ) + for template in raw_templates + ] + + return cls( + id=seq_id or json_dict['id'], + sequence=sequence, + ptms=ptms, + paired_msa=paired_msa, + unpaired_msa=unpaired_msa, + templates=templates, + ) + + def to_dict(self) -> Mapping[str, Mapping[str, Any]]: + """Converts ProteinChain to an AlphaFold JSON dict.""" + if self.templates is None: + templates = None + else: + templates = [ + { + 'mmcif': template.mmcif, + 'queryIndices': list(template.query_to_template_map.keys()), + 'templateIndices': ( + list(template.query_to_template_map.values()) or None + ), + } + for template in self.templates + ] + contents = { + 'id': self.id, + 'sequence': self.sequence, + 'modifications': [ + {'ptmType': ptm[0], 'ptmPosition': ptm[1]} for ptm in self.ptms + ], + 'unpairedMsa': self.unpaired_msa, + 'pairedMsa': self.paired_msa, + 'templates': templates, + } + return {'protein': contents} + + def to_ccd_sequence(self) -> Sequence[str]: + """Converts to a sequence of CCD codes.""" + ccd_coded_seq = [ + residue_names.PROTEIN_COMMON_ONE_TO_THREE.get( + res, residue_names.UNK) + for res in self.sequence + ] + for ptm_code, ptm_index in self.ptms: + ccd_coded_seq[ptm_index - 1] = ptm_code + return ccd_coded_seq + + def fill_missing_fields(self) -> Self: + """Fill missing MSA and template fields with default values.""" + return dataclasses.replace( + self, + unpaired_msa=self.unpaired_msa or '', + paired_msa=self.paired_msa or '', + templates=self.templates or [], + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class RnaChain: + """RNA chain input. + + Attributes: + id: Unique RNA chain identifier. + sequence: The RNA sequence of the chain. + modifications: A list of tuples containing the modification type and the + (1-based) residue index where the modification is applied. + unpaired_msa: Unpaired A3M-formatted MSA for this chain. This will be + deduplicated and used to compute unpaired features. If None, this field is + unset and must be filled in by the data pipeline before featurisation. If + set to an empty string, it will be treated as a custom MSA with no + sequences. 
+ """ + + id: str + sequence: str + modifications: Sequence[tuple[str, int]] + unpaired_msa: str | None = None + + def __post_init__(self): + if not all(res.isalpha() for res in self.sequence): + raise ValueError( + f'RNA must contain only letters, got "{self.sequence}"') + if any(not 0 < mod[1] <= len(self.sequence) for mod in self.modifications): + raise ValueError( + f'Invalid RNA modification index: {self.modifications}') + + # Use hashable types for modifications. + object.__setattr__(self, 'modifications', tuple(self.modifications)) + + @classmethod + def from_alphafoldserver_dict( + cls, json_dict: Mapping[str, Any], seq_id: str + ) -> Self: + """Constructs RnaChain from the AlphaFoldServer JSON dict.""" + _validate_keys(json_dict.keys(), { + 'sequence', 'modifications', 'count'}) + sequence = json_dict['sequence'] + modifications = [ + (mod['modificationType'].removeprefix('CCD_'), mod['basePosition']) + for mod in json_dict.get('modifications', []) + ] + return cls(id=seq_id, sequence=sequence, modifications=modifications) + + @classmethod + def from_dict( + cls, json_dict: Mapping[str, Any], seq_id: str | None = None + ) -> Self: + """Constructs RnaChain from the AlphaFold JSON dict.""" + json_dict = json_dict['rna'] + _validate_keys( + json_dict.keys(), {'id', 'sequence', + 'unpairedMsa', 'modifications'} + ) + sequence = json_dict['sequence'] + modifications = [ + (mod['modificationType'], mod['basePosition']) + for mod in json_dict.get('modifications', []) + ] + unpaired_msa = json_dict.get('unpairedMsa', None) + return cls( + id=seq_id or json_dict['id'], + sequence=sequence, + modifications=modifications, + unpaired_msa=unpaired_msa, + ) + + def to_dict(self) -> Mapping[str, Mapping[str, Any]]: + """Converts RnaChain to an AlphaFold JSON dict.""" + contents = { + 'id': self.id, + 'sequence': self.sequence, + 'modifications': [ + {'modificationType': mod[0], 'basePosition': mod[1]} + for mod in self.modifications + ], + 'unpairedMsa': self.unpaired_msa, + } + return {'rna': contents} + + def to_ccd_sequence(self) -> Sequence[str]: + """Converts to a sequence of CCD codes.""" + mapping = { + r: r for r in residue_names.RNA_TYPES} # Same 1-letter and CCD. + ccd_coded_seq = [ + mapping.get(res, residue_names.UNK_RNA) for res in self.sequence + ] + for ccd_code, modification_index in self.modifications: + ccd_coded_seq[modification_index - 1] = ccd_code + return ccd_coded_seq + + def fill_missing_fields(self) -> Self: + """Fill missing MSA fields with default values.""" + return dataclasses.replace(self, unpaired_msa=self.unpaired_msa or '') + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class DnaChain: + """Single strand DNA chain input. + + Attributes: + id: Unique DNA chain identifier. + sequence: The DNA sequence of the chain. + modifications: A list of tuples containing the modification type and the + (1-based) residue index where the modification is applied. + """ + + id: str + sequence: str + modifications: Sequence[tuple[str, int]] + + def __post_init__(self): + if not all(res.isalpha() for res in self.sequence): + raise ValueError( + f'DNA must contain only letters, got "{self.sequence}"') + if any(not 0 < mod[1] <= len(self.sequence) for mod in self.modifications): + raise ValueError( + f'Invalid DNA modification index: {self.modifications}') + + # Use hashable types for modifications. 
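+        # (object.__setattr__ is required because this dataclass is frozen.)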
+ object.__setattr__(self, 'modifications', tuple(self.modifications)) + + @classmethod + def from_alphafoldserver_dict( + cls, json_dict: Mapping[str, Any], seq_id: str + ) -> Self: + """Constructs DnaChain from the AlphaFoldServer JSON dict.""" + _validate_keys(json_dict.keys(), { + 'sequence', 'modifications', 'count'}) + sequence = json_dict['sequence'] + modifications = [ + (mod['modificationType'].removeprefix('CCD_'), mod['basePosition']) + for mod in json_dict.get('modifications', []) + ] + return cls(id=seq_id, sequence=sequence, modifications=modifications) + + @classmethod + def from_dict( + cls, json_dict: Mapping[str, Any], seq_id: str | None = None + ) -> Self: + """Constructs DnaChain from the AlphaFold JSON dict.""" + json_dict = json_dict['dna'] + _validate_keys(json_dict.keys(), {'id', 'sequence', 'modifications'}) + sequence = json_dict['sequence'] + modifications = [ + (mod['modificationType'], mod['basePosition']) + for mod in json_dict.get('modifications', []) + ] + return cls( + id=seq_id or json_dict['id'], + sequence=sequence, + modifications=modifications, + ) + + def to_dict(self) -> Mapping[str, Mapping[str, Any]]: + """Converts DnaChain to an AlphaFold JSON dict.""" + contents = { + 'id': self.id, + 'sequence': self.sequence, + 'modifications': [ + {'modificationType': mod[0], 'basePosition': mod[1]} + for mod in self.modifications + ], + } + return {'dna': contents} + + def to_ccd_sequence(self) -> Sequence[str]: + """Converts to a sequence of CCD codes.""" + ccd_coded_seq = [ + residue_names.DNA_COMMON_ONE_TO_TWO.get(res, residue_names.UNK_DNA) + for res in self.sequence + ] + for ccd_code, modification_index in self.modifications: + ccd_coded_seq[modification_index - 1] = ccd_code + return ccd_coded_seq + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Ligand: + """Ligand input. + + Attributes: + id: Unique ligand "chain" identifier. + ccd_ids: The Chemical Component Dictionary or user-defined CCD IDs of the + chemical components of the ligand. Typically, this is just a single ID, + but some ligands are composed of multiple components. If that is the case, + a bond linking these components should be added to the bonded_atom_pairs + Input field. + smiles: The SMILES representation of the ligand. + """ + + id: str + ccd_ids: Sequence[str] | None = None + smiles: str | None = None + + def __post_init__(self): + if (self.ccd_ids is None) == (self.smiles is None): + raise ValueError('Ligand must have one of CCD ID or SMILES set.') + + if self.smiles is not None: + mol = rd_chem.MolFromSmiles(self.smiles) + if not mol: + raise ValueError( + f'Unable to make RDKit Mol from SMILES: {self.smiles}') + + # Use hashable types for ccd_ids. + if self.ccd_ids is not None: + object.__setattr__(self, 'ccd_ids', tuple(self.ccd_ids)) + + @classmethod + def from_alphafoldserver_dict( + cls, json_dict: Mapping[str, Any], seq_id: str + ) -> Self: + """Constructs Ligand from the AlphaFoldServer JSON dict.""" + # Ligand can be specified either as a ligand, or ion (special-case). 
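+        # Illustrative AlphaFoldServer payloads: {'ligand': 'CCD_ATP', 'count': 1}
+        # or {'ion': 'MG', 'count': 2}; the 'CCD_' prefix is stripped below while
+        # ion names are used as-is.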
+ _validate_keys(json_dict.keys(), {'ligand', 'ion', 'count'}) + if 'ligand' in json_dict: + return cls(id=seq_id, ccd_ids=[json_dict['ligand'].removeprefix('CCD_')]) + elif 'ion' in json_dict: + return cls(id=seq_id, ccd_ids=[json_dict['ion']]) + else: + raise ValueError(f'Unknown ligand type: {json_dict}') + + @classmethod + def from_dict( + cls, json_dict: Mapping[str, Any], seq_id: str | None = None + ) -> Self: + """Constructs Ligand from the AlphaFold JSON dict.""" + json_dict = json_dict['ligand'] + _validate_keys(json_dict.keys(), {'id', 'ccdCodes', 'smiles'}) + if json_dict.get('ccdCodes') and json_dict.get('smiles'): + raise ValueError( + 'Ligand cannot have both CCD code and SMILES set at the same time, ' + f'got CCD: {json_dict["ccdCodes"]} and SMILES: {json_dict["smiles"]}' + ) + + if 'ccdCodes' in json_dict: + return cls(id=seq_id or json_dict['id'], ccd_ids=json_dict['ccdCodes']) + elif 'smiles' in json_dict: + return cls(id=seq_id or json_dict['id'], smiles=json_dict['smiles']) + else: + raise ValueError(f'Unknown ligand type: {json_dict}') + + def to_dict(self) -> Mapping[str, Any]: + """Converts Ligand to an AlphaFold JSON dict.""" + contents = {'id': self.id} + if self.ccd_ids is not None: + contents['ccdCodes'] = self.ccd_ids + if self.smiles is not None: + contents['smiles'] = self.smiles + return {'ligand': contents} + + +def _sample_rng_seed() -> int: + """Sample a random seed for AlphaFoldServer job.""" + # See https://alphafoldserver.com/faq#what-are-seeds-and-how-are-they-set. + return random.randint(0, 2**32 - 1) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Input: + """AlphaFold input. + + Attributes: + name: The name of the target. + chains: Protein chains, RNA chains, DNA chains, or ligands. + protein_chains: Protein chains. + rna_chains: RNA chains. + dna_chains: Single strand DNA chains. + ligands: Ligand (including ion) inputs. + rng_seeds: Random number generator seeds, one for each model execution. + bonded_atom_pairs: A list of tuples of atoms that are bonded to each other. + Each atom is defined by a tuple of (chain_id, res_id, atom_name). Chain + IDs must be set if there are any bonded atoms. Residue IDs are 1-indexed. + Atoms in ligands defined by SMILES can't be bonded since SMILES doesn't + define unique atom names. + user_ccd: Optional user-defined chemical component dictionary in the CIF + format. This can be used to provide additional CCD entries that are not + present in the default CCD and thus define arbitrary new ligands. This is + more expressive than SMILES since it allows to name all atoms within the + ligand which in turn makes it possible to define bonds using those atoms. + """ + + name: str + chains: Sequence[ProteinChain | RnaChain | DnaChain | Ligand] + rng_seeds: Sequence[int] + bonded_atom_pairs: Sequence[tuple[BondAtomId, BondAtomId]] | None = None + user_ccd: str | None = None + + def __post_init__(self): + if not self.rng_seeds: + raise ValueError('Input must have at least one RNG seed.') + + if not self.name.strip() or not self.sanitised_name(): + raise ValueError( + 'Input name must be non-empty and contain at least one valid' + ' character (letters, numbers, dots, dashes, underscores).' 
+ ) + + chain_ids = [c.id for c in self.chains] + if any(not c.id.isalpha() or c.id.islower() for c in self.chains): + raise ValueError( + f'IDs must be upper case letters, got: {chain_ids}') + if len(set(chain_ids)) != len(chain_ids): + raise ValueError( + 'Input JSON contains sequences with duplicate IDs.') + + # Use hashable types for chains, rng_seeds, and bonded_atom_pairs. + object.__setattr__(self, 'chains', tuple(self.chains)) + object.__setattr__(self, 'rng_seeds', tuple(self.rng_seeds)) + if self.bonded_atom_pairs is not None: + object.__setattr__( + self, 'bonded_atom_pairs', tuple(self.bonded_atom_pairs) + ) + + @property + def protein_chains(self) -> Sequence[ProteinChain]: + return [chain for chain in self.chains if isinstance(chain, ProteinChain)] + + @property + def rna_chains(self) -> Sequence[RnaChain]: + return [chain for chain in self.chains if isinstance(chain, RnaChain)] + + @property + def dna_chains(self) -> Sequence[DnaChain]: + return [chain for chain in self.chains if isinstance(chain, DnaChain)] + + @property + def ligands(self) -> Sequence[Ligand]: + return [chain for chain in self.chains if isinstance(chain, Ligand)] + + @classmethod + def from_alphafoldserver_fold_job(cls, fold_job: Mapping[str, Any]) -> Self: + """Constructs Input from an AlphaFoldServer fold job.""" + + # Validate the fold job has the correct format. + _validate_keys( + fold_job.keys(), + {'name', 'modelSeeds', 'sequences', 'dialect', 'version'}, + ) + if 'dialect' not in fold_job and 'version' not in fold_job: + dialect = ALPHAFOLDSERVER_JSON_DIALECT + version = ALPHAFOLDSERVER_JSON_VERSION + elif 'dialect' in fold_job and 'version' in fold_job: + dialect = fold_job['dialect'] + version = fold_job['version'] + else: + raise ValueError( + 'AlphaFold Server input JSON must either contain both `dialect` and' + ' `version` fields, or neither. If neither is specified, it is' + f' assumed that `dialect="{ALPHAFOLDSERVER_JSON_DIALECT}"` and' + f' `version="{ALPHAFOLDSERVER_JSON_VERSION}"`.' + ) + + if dialect != ALPHAFOLDSERVER_JSON_DIALECT: + raise ValueError( + f'AlphaFold Server input JSON has unsupported dialect: {dialect}, ' + f'expected {ALPHAFOLDSERVER_JSON_DIALECT}.' + ) + + # For now, there is only one AlphaFold Server JSON version. + if version != ALPHAFOLDSERVER_JSON_VERSION: + raise ValueError( + f'AlphaFold Server input JSON has unsupported version: {version}, ' + f'expected {ALPHAFOLDSERVER_JSON_VERSION}.' + ) + + # Parse the chains. 
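+        # Each entry may carry a `count`; it is expanded into that many chains,
+        # each assigned the next sequential mmCIF-style chain ID by
+        # int_id_to_str_id (illustratively: 1 -> 'A', 2 -> 'B', ...).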
+ chains = [] + for sequence in fold_job['sequences']: + if 'proteinChain' in sequence: + for _ in range(sequence['proteinChain'].get('count', 1)): + chains.append( + ProteinChain.from_alphafoldserver_dict( + sequence['proteinChain'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + elif 'rnaSequence' in sequence: + for _ in range(sequence['rnaSequence'].get('count', 1)): + chains.append( + RnaChain.from_alphafoldserver_dict( + sequence['rnaSequence'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + elif 'dnaSequence' in sequence: + for _ in range(sequence['dnaSequence'].get('count', 1)): + chains.append( + DnaChain.from_alphafoldserver_dict( + sequence['dnaSequence'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + elif 'ion' in sequence: + for _ in range(sequence['ion'].get('count', 1)): + chains.append( + Ligand.from_alphafoldserver_dict( + sequence['ion'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + elif 'ligand' in sequence: + for _ in range(sequence['ligand'].get('count', 1)): + chains.append( + Ligand.from_alphafoldserver_dict( + sequence['ligand'], + seq_id=mmcif_lib.int_id_to_str_id(len(chains) + 1), + ) + ) + else: + raise ValueError(f'Unknown sequence type: {sequence}') + + if 'modelSeeds' in fold_job and fold_job['modelSeeds']: + rng_seeds = [int(seed) for seed in fold_job['modelSeeds']] + else: + rng_seeds = [_sample_rng_seed()] + + return cls(name=fold_job['name'], chains=chains, rng_seeds=rng_seeds) + + @classmethod + def from_json(cls, json_str: str) -> Self: + """Loads the input from the AlphaFold JSON string.""" + raw_json = json.loads(json_str) + + _validate_keys( + raw_json.keys(), + { + 'dialect', + 'version', + 'name', + 'modelSeeds', + 'sequences', + 'bondedAtomPairs', + 'userCCD', + }, + ) + + if 'dialect' not in raw_json or 'version' not in raw_json: + raise ValueError( + 'AlphaFold 3 input JSON must contain `dialect` and `version` fields.' + ) + + if raw_json['dialect'] != JSON_DIALECT: + raise ValueError( + 'AlphaFold 3 input JSON has unsupported dialect:' + f' {raw_json["dialect"]}, expected {JSON_DIALECT}.' + ) + + # For now, there is only one AlphaFold 3 JSON version. + if raw_json['version'] != JSON_VERSION: + raise ValueError( + 'AlphaFold 3 input JSON has unsupported version:' + f' {raw_json["version"]}, expected {JSON_VERSION}.' + ) + + if 'sequences' not in raw_json: + raise ValueError( + 'AlphaFold 3 input JSON does not contain any sequences.') + + if 'modelSeeds' not in raw_json or not raw_json['modelSeeds']: + raise ValueError( + 'AlphaFold 3 input JSON must specify at least one rng seed in' + ' `modelSeeds`.' + ) + + sequences = raw_json['sequences'] + + # Make sure sequence IDs are all set. + raw_sequence_ids = [next(iter(s.values())).get('id') + for s in sequences] + if all(raw_sequence_ids): + sequence_ids = [] + for sequence_id in raw_sequence_ids: + if isinstance(sequence_id, list): + sequence_ids.append(sequence_id) + else: + sequence_ids.append([sequence_id]) + else: + raise ValueError( + 'AlphaFold 3 input JSON contains sequences with unset IDs.' 
+ ) + + flat_seq_ids = [] + for seq_ids in sequence_ids: + flat_seq_ids.extend(seq_ids) + + chains = [] + for seq_ids, sequence in zip(sequence_ids, sequences, strict=True): + if len(sequence) != 1: + raise ValueError(f'Chain {seq_ids} has more than 1 sequence.') + for seq_id in seq_ids: + if 'protein' in sequence: + chains.append(ProteinChain.from_dict( + sequence, seq_id=seq_id)) + elif 'rna' in sequence: + chains.append(RnaChain.from_dict(sequence, seq_id=seq_id)) + elif 'dna' in sequence: + chains.append(DnaChain.from_dict(sequence, seq_id=seq_id)) + elif 'ligand' in sequence: + chains.append(Ligand.from_dict(sequence, seq_id=seq_id)) + else: + raise ValueError(f'Unknown sequence type: {sequence}') + + ligands = [chain for chain in chains if isinstance(chain, Ligand)] + bonded_atom_pairs = None + if bonds := raw_json.get('bondedAtomPairs'): + bonded_atom_pairs = [] + for bond in bonds: + if len(bond) != 2: + raise ValueError( + f'Bond {bond} must have 2 atoms, got {len(bond)}.') + bond_beg, bond_end = bond + if ( + len(bond_beg) != 3 + or not isinstance(bond_beg[0], str) + or not isinstance(bond_beg[1], int) + or not isinstance(bond_beg[2], str) + ): + raise ValueError( + f'Atom {bond_beg} in bond {bond} must have 3 components: ' + '(chain_id: str, res_id: int, atom_name: str).' + ) + if ( + len(bond_end) != 3 + or not isinstance(bond_end[0], str) + or not isinstance(bond_end[1], int) + or not isinstance(bond_end[2], str) + ): + raise ValueError( + f'Atom {bond_end} in bond {bond} must have 3 components: ' + '(chain_id: str, res_id: int, atom_name: str).' + ) + if bond_beg[0] not in flat_seq_ids or bond_end[0] not in flat_seq_ids: + raise ValueError(f'Invalid chain ID(s) in bond {bond}') + if bond_beg[1] <= 0 or bond_end[1] <= 0: + raise ValueError(f'Invalid residue ID(s) in bond {bond}') + smiles_ligand_ids = set( + l.id for l in ligands if l.smiles is not None) + if bond_beg[0] in smiles_ligand_ids: + raise ValueError( + f'Bond {bond} involves an unsupported SMILES ligand {bond_beg[0]}' + ) + if bond_end[0] in smiles_ligand_ids: + raise ValueError( + f'Bond {bond} involves an unsupported SMILES ligand {bond_end[0]}' + ) + bonded_atom_pairs.append((tuple(bond_beg), tuple(bond_end))) + + return cls( + name=raw_json['name'], + chains=chains, + rng_seeds=[int(seed) for seed in raw_json['modelSeeds']], + bonded_atom_pairs=bonded_atom_pairs, + user_ccd=raw_json.get('userCCD'), + ) + + @classmethod + def from_mmcif(cls, mmcif_str: str, ccd: chemical_components.Ccd) -> Self: + """Loads the input from an mmCIF string. + + WARNING: Since rng seeds are not stored in mmCIFs, an rng seed is sampled + in the returned `Input`. + + Args: + mmcif_str: The mmCIF string. + ccd: The chemical components dictionary. + + Returns: + The input in an Input format. + """ + + struc = structure.from_mmcif( + mmcif_str, + include_water=False, + fix_mse_residues=True, + fix_unknown_dna=True, + include_bonds=True, + include_other=False, + ) + + # Create default bioassembly, expanding structures implied by stoichiometry. 
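+        # Passing None requests the default assembly. Illustratively, an
+        # asymmetric unit deposited with two assembly operators expands into
+        # two copies of each chain here.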
+ struc = struc.generate_bioassembly(None) + + sequences = struc.chain_single_letter_sequence( + include_missing_residues=True + ) + + chains = [] + for chain_id, chain_type in zip( + struc.group_by_chain.chain_id, struc.group_by_chain.chain_type + ): + sequence = sequences[chain_id] + + if chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES: + residues = list(struc.chain_res_name_sequence()[chain_id]) + if all(ccd.get(res) is not None for res in residues): + chains.append(Ligand(id=chain_id, ccd_ids=residues)) + elif len(residues) == 1: + comp_name = residues[0] + comps = struc.chemical_components_data + if comps is None: + raise ValueError( + 'Missing mmCIF chemical components data - this is required for ' + f'a non-CCD ligand {comp_name} defined using SMILES string.' + ) + chains.append( + Ligand(id=chain_id, + smiles=comps.chem_comp[comp_name].pdbx_smiles) + ) + else: + raise ValueError( + 'Multi-component ligand must be defined using CCD IDs, defining' + ' using SMILES is supported only for single-component ligands. ' + f'Got {residues}' + ) + else: + residues = struc.chain_res_name_sequence()[chain_id] + fixed = struc.chain_res_name_sequence( + fix_non_standard_polymer_res=True + )[chain_id] + modifications = [ + (orig, i + 1) + for i, (orig, fixed) in enumerate(zip(residues, fixed, strict=True)) + if orig != fixed + ] + + if chain_type == mmcif_names.PROTEIN_CHAIN: + chains.append( + ProteinChain(id=chain_id, sequence=sequence, + ptms=modifications) + ) + elif chain_type == mmcif_names.RNA_CHAIN: + chains.append( + RnaChain( + id=chain_id, sequence=sequence, modifications=modifications + ) + ) + elif chain_type == mmcif_names.DNA_CHAIN: + chains.append( + DnaChain( + id=chain_id, sequence=sequence, modifications=modifications + ) + ) + + bonded_atom_pairs = [] + chain_ids = set(c.id for c in chains) + for atom_a, atom_b, _ in struc.iter_bonds(): + if atom_a['chain_id'] in chain_ids and atom_b['chain_id'] in chain_ids: + beg = (atom_a['chain_id'], int( + atom_a['res_id']), atom_a['atom_name']) + end = (atom_b['chain_id'], int( + atom_b['res_id']), atom_b['atom_name']) + bonded_atom_pairs.append((beg, end)) + + return cls( + name=struc.name, + chains=chains, + # mmCIFs don't store rng seeds, so we need to sample one here. + rng_seeds=[_sample_rng_seed()], + bonded_atom_pairs=bonded_atom_pairs or None, + ) + + def to_structure(self, ccd: chemical_components.Ccd) -> structure.Structure: + """Converts Input to a Structure. + + WARNING: This method does not preserve the rng seeds. + + Args: + ccd: The chemical components dictionary. + + Returns: + The input in a structure.Structure format. 
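+
+        Example (illustrative, assuming a loaded CCD):
+          struc = fold_input.to_structure(ccd=chemical_components.cached_ccd())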
+        """
+        ids: list[str] = []
+        sequences: list[str] = []
+        poly_types: list[str] = []
+        formats: list[structure.SequenceFormat] = []
+
+        for chain in self.chains:
+            ids.append(chain.id)
+            match chain:
+                case ProteinChain():
+                    sequences.append(
+                        '(' + ')('.join(chain.to_ccd_sequence()) + ')')
+                    poly_types.append(mmcif_names.PROTEIN_CHAIN)
+                    formats.append(structure.SequenceFormat.CCD_CODES)
+                case RnaChain():
+                    sequences.append(
+                        '(' + ')('.join(chain.to_ccd_sequence()) + ')')
+                    poly_types.append(mmcif_names.RNA_CHAIN)
+                    formats.append(structure.SequenceFormat.CCD_CODES)
+                case DnaChain():
+                    sequences.append(
+                        '(' + ')('.join(chain.to_ccd_sequence()) + ')')
+                    poly_types.append(mmcif_names.DNA_CHAIN)
+                    formats.append(structure.SequenceFormat.CCD_CODES)
+                case Ligand():
+                    if chain.ccd_ids is not None:
+                        sequences.append('(' + ')('.join(chain.ccd_ids) + ')')
+                        if len(chain.ccd_ids) == 1:
+                            poly_types.append(mmcif_names.NON_POLYMER_CHAIN)
+                        else:
+                            poly_types.append(mmcif_names.BRANCHED_CHAIN)
+                        formats.append(structure.SequenceFormat.CCD_CODES)
+                    elif chain.smiles is not None:
+                        # Convert to the `name:smiles` format that is expected
+                        # by structure.from_sequences_and_bonds.
+                        sequences.append(f'LIG_{chain.id}:{chain.smiles}')
+                        poly_types.append(mmcif_names.NON_POLYMER_CHAIN)
+                        formats.append(structure.SequenceFormat.LIGAND_SMILES)
+                    else:
+                        raise ValueError(
+                            'Ligand must have one of CCD ID or SMILES set.')
+
+        # Remap bond chain IDs from chain IDs to chain indices and convert to
+        # 0-based residue indexing.
+        bonded_atom_pairs = []
+        chain_indices = {cid: i for i, cid in enumerate(ids)}
+        if self.bonded_atom_pairs is not None:
+            for bond_beg, bond_end in self.bonded_atom_pairs:
+                bonded_atom_pairs.append((
+                    (chain_indices[bond_beg[0]], bond_beg[1] - 1, bond_beg[2]),
+                    (chain_indices[bond_end[0]], bond_end[1] - 1, bond_end[2]),
+                ))
+
+        struc = structure.from_sequences_and_bonds(
+            sequences=sequences,
+            chain_types=poly_types,
+            sequence_formats=formats,
+            bonded_atom_pairs=bonded_atom_pairs,
+            ccd=ccd,
+            name=self.sanitised_name(),
+            bond_type=mmcif_names.COVALENT_BOND,
+            release_date=None,
+        )
+        # Rename chain IDs to the original ones.
+        return struc.rename_chain_ids(dict(zip(struc.chains, ids, strict=True)))
+
+    def to_json(self) -> str:
+        """Converts Input to an AlphaFold JSON."""
+        alphafold_json = json.dumps(
+            {
+                'dialect': JSON_DIALECT,
+                'version': JSON_VERSION,
+                'name': self.name,
+                'sequences': [chain.to_dict() for chain in self.chains],
+                'modelSeeds': self.rng_seeds,
+                'bondedAtomPairs': self.bonded_atom_pairs,
+                'userCCD': self.user_ccd,
+            },
+            indent=2,
+        )
+        # Remove newlines from the query/template indices arrays. We match the
+        # queryIndices/templateIndices with a non-capturing group. We then match
+        # the entire region between the square brackets by looking for lines
+        # containing only whitespace, numbers, or commas.
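+        # Illustrative effect on the dumped JSON:
+        #   "queryIndices": [
+        #     0,
+        #     1
+        #   ],
+        # becomes: "queryIndices": [0, 1],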
+        return re.sub(
+            r'("(?:queryIndices|templateIndices)": \[)([\s\n\d,]+)(\],?)',
+            lambda mtch: mtch[1] +
+            re.sub(r'\n\s+', ' ', mtch[2].strip()) + mtch[3],
+            alphafold_json,
+        )
+
+    def fill_missing_fields(self) -> Self:
+        """Fill missing MSA and template fields with default values."""
+        with_missing_fields = [
+            c.fill_missing_fields()
+            if isinstance(c, (ProteinChain, RnaChain))
+            else c
+            for c in self.chains
+        ]
+        return dataclasses.replace(self, chains=with_missing_fields)
+
+    def sanitised_name(self) -> str:
+        """Returns sanitised version of the name that can be used as a filename."""
+        lower_spaceless_name = self.name.lower().replace(' ', '_')
+        allowed_chars = set(string.ascii_lowercase + string.digits + '_-.')
+        return ''.join(l for l in lower_spaceless_name if l in allowed_chars)
+
+
+def check_unique_sanitised_names(fold_inputs: Sequence[Input]) -> None:
+    """Checks that the names of the fold inputs are unique."""
+    names = [fi.sanitised_name() for fi in fold_inputs]
+    if len(set(names)) != len(names):
+        raise ValueError(
+            f'Fold inputs must have unique sanitised names, got {names}.'
+        )
+
+
+def load_fold_inputs_from_path(json_path: pathlib.Path) -> Sequence[Input]:
+    """Loads multiple fold inputs from the JSON file at the given path."""
+    with open(json_path, 'r') as f:
+        json_str = f.read()
+
+    # Parse the JSON string, so we can detect its format.
+    raw_json = json.loads(json_str)
+
+    fold_inputs = []
+    if isinstance(raw_json, list):
+        # AlphaFold Server JSON.
+        logging.info(
+            'Detected %s is an AlphaFold Server JSON since the top-level is a'
+            ' list.',
+            json_path,
+        )
+
+        logging.info('Loading %d fold jobs from %s', len(raw_json), json_path)
+        for fold_job_idx, fold_job in enumerate(raw_json):
+            try:
+                fold_inputs.append(
+                    Input.from_alphafoldserver_fold_job(fold_job))
+            except ValueError as e:
+                raise ValueError(
+                    f'Failed to load fold job {fold_job_idx} from {json_path}. The JSON'
+                    f' at {json_path} was detected to be an AlphaFold Server JSON since'
+                    ' the top-level is a list.'
+                ) from e
+    else:
+        logging.info(
+            'Detected %s is an AlphaFold 3 JSON since the top-level is not a list.',
+            json_path,
+        )
+        # AlphaFold 3 JSON.
+        try:
+            fold_inputs.append(Input.from_json(json_str))
+        except ValueError as e:
+            raise ValueError(
+                f'Failed to load fold input from {json_path}. The JSON at'
+                f' {json_path} was detected to be an AlphaFold 3 JSON since the'
+                ' top-level is not a list.'
+            ) from e
+
+    check_unique_sanitised_names(fold_inputs)
+
+    return fold_inputs
+
+
+def load_fold_inputs_from_dir(input_dir: pathlib.Path) -> Sequence[Input]:
+    """Loads multiple fold inputs from all JSON files in a given input_dir.
+
+    Args:
+      input_dir: The directory containing the JSON files.
+
+    Returns:
+      The fold inputs from all JSON files in the input directory.
+
+    Raises:
+      ValueError: If the fold inputs have non-unique sanitised names.
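+
+    Example (illustrative):
+      fold_inputs = load_fold_inputs_from_dir(pathlib.Path('/path/to/jsons'))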
+    """
+    fold_inputs = []
+    for file_path in input_dir.glob('*.json'):
+        if not file_path.is_file():
+            continue
+
+        fold_inputs.extend(load_fold_inputs_from_path(file_path))
+
+    check_unique_sanitised_names(fold_inputs)
+
+    return fold_inputs
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py
new file mode 100644
index 000000000..0a1f09ba6
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/resources.py
@@ -0,0 +1,77 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Load external resources, such as external tools or data resources."""
+
+from collections.abc import Iterator
+import os
+import pathlib
+import typing
+from typing import BinaryIO, Final, Literal, TextIO
+
+from importlib import resources
+import alphafold3.common
+
+
+_DATA_ROOT: Final[pathlib.Path] = (
+    resources.files(alphafold3.common).joinpath('..').resolve()
+)
+ROOT = _DATA_ROOT
+
+
+def filename(name: str | os.PathLike[str]) -> str:
+    """Returns the absolute path to an external resource.
+
+    Note that resolving packaged data to a real filesystem path may trigger
+    resource extraction (e.g. when running from a zip or par archive), which
+    might be unfriendly on diskless machines.
+
+    Args:
+      name: the name of the resource corresponding to its path relative to the
+        root of the repository.
+    """
+    return (_DATA_ROOT / name).as_posix()
+
+
+@typing.overload
+def open_resource(
+    name: str | os.PathLike[str], mode: Literal['r', 'rt'] = 'rt'
+) -> TextIO:
+    ...
+
+
+@typing.overload
+def open_resource(
+    name: str | os.PathLike[str], mode: Literal['rb']
+) -> BinaryIO:
+    ...
+
+
+def open_resource(
+    name: str | os.PathLike[str], mode: str = 'rt'
+) -> TextIO | BinaryIO:
+    """Returns an open file object for the named resource.
+
+    Args:
+      name: the name of the resource corresponding to its path relative to the
+        root of the repository.
+      mode: the mode to use when opening the file.
+    """
+    return (_DATA_ROOT / name).open(mode)
+
+
+def get_resource_dir(path: str | os.PathLike[str]) -> os.PathLike[str]:
+    return _DATA_ROOT / path
+
+
+def walk(path: str) -> Iterator[tuple[str, list[str], list[str]]]:
+    """Walks the directory tree of resources similar to os.walk."""
+    return os.walk((_DATA_ROOT / path).as_posix())
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py
new file mode 100644
index 000000000..ec935f137
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/common/testing/data.py
@@ -0,0 +1,70 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Module that provides an abstraction for accessing test data."""
+
+import os
+import pathlib
+from typing import Literal, overload
+
+from absl.testing import absltest
+
+
+class Data:
+    """Provides an abstraction for accessing test data."""
+
+    def __init__(self, data_dir: os.PathLike[str] | str):
+        """Initializes the data wrapper, providing users with high-level data access.
+
+        Args:
+          data_dir: Directory containing test data.
+        """
+        self._data_dir = pathlib.Path(data_dir)
+
+    def path(self, data_name: str | os.PathLike[str] | None = None) -> str:
+        """Returns the path to a given test data file.
+
+        Args:
+          data_name: the name of the test data file relative to data_dir. If not
+            set, this will return the absolute path to the data directory.
+        """
+        data_dir_path = (
+            pathlib.Path(absltest.get_default_test_srcdir()) / self._data_dir
+        )
+
+        if data_name:
+            return str(data_dir_path / data_name)
+
+        return str(data_dir_path)
+
+    @overload
+    def load(
+        self, data_name: str | os.PathLike[str], mode: Literal['rt'] = 'rt'
+    ) -> str:
+        ...
+
+    @overload
+    def load(
+        self, data_name: str | os.PathLike[str], mode: Literal['rb'] = 'rb'
+    ) -> bytes:
+        ...
+
+    def load(
+        self, data_name: str | os.PathLike[str], mode: str = 'rt'
+    ) -> str | bytes:
+        """Returns the contents of a given test data file.
+
+        Args:
+          data_name: the name of the test data file relative to data_dir.
+          mode: the mode in which to read the data file. Defaults to text ('rt').
+        """
+        with open(self.path(data_name), mode=mode) as f:
+            return f.read()
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/atom_types.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/atom_types.py
new file mode 100644
index 000000000..8630278a1
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/atom_types.py
@@ -0,0 +1,262 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""List of atom types with reverse look-up."""
+
+from collections.abc import Mapping, Sequence, Set
+import itertools
+import sys
+from typing import Final
+from alphafold3.constants import residue_names
+
+# Note:
+# `sys.intern` places the values in the Python internal db for fast lookup.
+
+# 37 common residue atoms.
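+# These names are the union of the non-hydrogen atom names over the 20 standard
+# amino acids (PDB naming conventions); they back the fixed-width ATOM37 layout
+# defined below.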
+N = sys.intern('N') +CA = sys.intern('CA') +C = sys.intern('C') +CB = sys.intern('CB') +O = sys.intern('O') +CG = sys.intern('CG') +CG1 = sys.intern('CG1') +CG2 = sys.intern('CG2') +OG = sys.intern('OG') +OG1 = sys.intern('OG1') +SG = sys.intern('SG') +CD = sys.intern('CD') +CD1 = sys.intern('CD1') +CD2 = sys.intern('CD2') +ND1 = sys.intern('ND1') +ND2 = sys.intern('ND2') +OD1 = sys.intern('OD1') +OD2 = sys.intern('OD2') +SD = sys.intern('SD') +CE = sys.intern('CE') +CE1 = sys.intern('CE1') +CE2 = sys.intern('CE2') +CE3 = sys.intern('CE3') +NE = sys.intern('NE') +NE1 = sys.intern('NE1') +NE2 = sys.intern('NE2') +OE1 = sys.intern('OE1') +OE2 = sys.intern('OE2') +CH2 = sys.intern('CH2') +NH1 = sys.intern('NH1') +NH2 = sys.intern('NH2') +OH = sys.intern('OH') +CZ = sys.intern('CZ') +CZ2 = sys.intern('CZ2') +CZ3 = sys.intern('CZ3') +NZ = sys.intern('NZ') +OXT = sys.intern('OXT') + +# 29 common nucleic acid atoms. +C1PRIME = sys.intern("C1'") +C2 = sys.intern('C2') +C2PRIME = sys.intern("C2'") +C3PRIME = sys.intern("C3'") +C4 = sys.intern('C4') +C4PRIME = sys.intern("C4'") +C5 = sys.intern('C5') +C5PRIME = sys.intern("C5'") +C6 = sys.intern('C6') +C7 = sys.intern('C7') +C8 = sys.intern('C8') +N1 = sys.intern('N1') +N2 = sys.intern('N2') +N3 = sys.intern('N3') +N4 = sys.intern('N4') +N6 = sys.intern('N6') +N7 = sys.intern('N7') +N9 = sys.intern('N9') +O2 = sys.intern('O2') +O2PRIME = sys.intern("O2'") +O3PRIME = sys.intern("O3'") +O4 = sys.intern('O4') +O4PRIME = sys.intern("O4'") +O5PRIME = sys.intern("O5'") +O6 = sys.intern('O6') +OP1 = sys.intern('OP1') +OP2 = sys.intern('OP2') +OP3 = sys.intern('OP3') +P = sys.intern('P') + +# A list of atoms (excluding hydrogen) for each AA type. PDB naming convention. +RESIDUE_ATOMS: Mapping[str, tuple[str, ...]] = { + residue_names.ALA: (C, CA, CB, N, O), + residue_names.ARG: (C, CA, CB, CG, CD, CZ, N, NE, O, NH1, NH2), + residue_names.ASN: (C, CA, CB, CG, N, ND2, O, OD1), + residue_names.ASP: (C, CA, CB, CG, N, O, OD1, OD2), + residue_names.CYS: (C, CA, CB, N, O, SG), + residue_names.GLN: (C, CA, CB, CG, CD, N, NE2, O, OE1), + residue_names.GLU: (C, CA, CB, CG, CD, N, O, OE1, OE2), + residue_names.GLY: (C, CA, N, O), + residue_names.HIS: (C, CA, CB, CG, CD2, CE1, N, ND1, NE2, O), + residue_names.ILE: (C, CA, CB, CG1, CG2, CD1, N, O), + residue_names.LEU: (C, CA, CB, CG, CD1, CD2, N, O), + residue_names.LYS: (C, CA, CB, CG, CD, CE, N, NZ, O), + residue_names.MET: (C, CA, CB, CG, CE, N, O, SD), + residue_names.PHE: (C, CA, CB, CG, CD1, CD2, CE1, CE2, CZ, N, O), + residue_names.PRO: (C, CA, CB, CG, CD, N, O), + residue_names.SER: (C, CA, CB, N, O, OG), + residue_names.THR: (C, CA, CB, CG2, N, O, OG1), + residue_names.TRP: + (C, CA, CB, CG, CD1, CD2, CE2, CE3, CZ2, CZ3, CH2, N, NE1, O), + residue_names.TYR: (C, CA, CB, CG, CD1, CD2, CE1, CE2, CZ, N, O, OH), + residue_names.VAL: (C, CA, CB, CG1, CG2, N, O), +} # pyformat: disable + +# Used to identify backbone for alignment and distance calculation for sterics. +PROTEIN_BACKBONE_ATOMS: tuple[str, ...] = (N, CA, C) + +# Naming swaps for ambiguous atom names. Due to symmetries in the amino acids +# the naming of atoms is ambiguous in 4 of the 20 amino acids. 
(The LDDT paper
+# lists 7 amino acids as ambiguous, but the naming ambiguities in LEU, VAL and
+# ARG can be resolved by using the 3D constellations of the 'ambiguous' atoms
+# and their neighbours)
+AMBIGUOUS_ATOM_NAMES: Mapping[str, Mapping[str, str]] = {
+    residue_names.ASP: {OD1: OD2},
+    residue_names.GLU: {OE1: OE2},
+    residue_names.PHE: {CD1: CD2, CE1: CE2},
+    residue_names.TYR: {CD1: CD2, CE1: CE2},
+}
+
+# Used when we need to store atom data in a format that requires fixed atom data
+# size for every protein residue (e.g. a numpy array).
+ATOM37: tuple[str, ...] = (
+    N, CA, C, CB, O, CG, CG1, CG2, OG, OG1, SG, CD, CD1, CD2, ND1, ND2, OD1,
+    OD2, SD, CE, CE1, CE2, CE3, NE, NE1, NE2, OE1, OE2, CH2, NH1, NH2, OH, CZ,
+    CZ2, CZ3, NZ, OXT)  # pyformat: disable
+ATOM37_ORDER: Mapping[str, int] = {name: i for i, name in enumerate(ATOM37)}
+ATOM37_NUM: Final[int] = len(ATOM37)  # := 37.
+
+# Used when we need to store protein atom data in a format that requires fixed
+# atom data size for any residue but takes less space than ATOM37 by having 14
+# fields, which is sufficient for storing atoms of all protein residues (e.g. a
+# numpy array).
+ATOM14: Mapping[str, tuple[str, ...]] = {
+    residue_names.ALA: (N, CA, C, O, CB),
+    residue_names.ARG: (N, CA, C, O, CB, CG, CD, NE, CZ, NH1, NH2),
+    residue_names.ASN: (N, CA, C, O, CB, CG, OD1, ND2),
+    residue_names.ASP: (N, CA, C, O, CB, CG, OD1, OD2),
+    residue_names.CYS: (N, CA, C, O, CB, SG),
+    residue_names.GLN: (N, CA, C, O, CB, CG, CD, OE1, NE2),
+    residue_names.GLU: (N, CA, C, O, CB, CG, CD, OE1, OE2),
+    residue_names.GLY: (N, CA, C, O),
+    residue_names.HIS: (N, CA, C, O, CB, CG, ND1, CD2, CE1, NE2),
+    residue_names.ILE: (N, CA, C, O, CB, CG1, CG2, CD1),
+    residue_names.LEU: (N, CA, C, O, CB, CG, CD1, CD2),
+    residue_names.LYS: (N, CA, C, O, CB, CG, CD, CE, NZ),
+    residue_names.MET: (N, CA, C, O, CB, CG, SD, CE),
+    residue_names.PHE: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ),
+    residue_names.PRO: (N, CA, C, O, CB, CG, CD),
+    residue_names.SER: (N, CA, C, O, CB, OG),
+    residue_names.THR: (N, CA, C, O, CB, OG1, CG2),
+    residue_names.TRP:
+        (N, CA, C, O, CB, CG, CD1, CD2, NE1, CE2, CE3, CZ2, CZ3, CH2),
+    residue_names.TYR: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ, OH),
+    residue_names.VAL: (N, CA, C, O, CB, CG1, CG2),
+    residue_names.UNK: (),
+}  # pyformat: disable
+
+# A compact atom encoding with 14 columns, padded with '' in empty slots.
+ATOM14_PADDED: Mapping[str, Sequence[str]] = {
+    k: [v for _, v in itertools.zip_longest(range(14), values, fillvalue='')]
+    for k, values in ATOM14.items()
+}
+
+ATOM14_ORDER: Mapping[str, Mapping[str, int]] = {
+    k: {name: i for i, name in enumerate(v)} for k, v in ATOM14.items()
+}
+ATOM14_NUM: Final[int] = max(len(v) for v in ATOM14.values())
+
+# Used when we need a single fixed-size atom layout that covers both protein
+# and nucleic acid residues (e.g. a numpy array).
+DENSE_ATOM: Mapping[str, tuple[str, ...]] = {
+    # Protein.
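+    # (The protein entries below mirror the ATOM14 layout above.)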
+ residue_names.ALA: (N, CA, C, O, CB), + residue_names.ARG: (N, CA, C, O, CB, CG, CD, NE, CZ, NH1, NH2), + residue_names.ASN: (N, CA, C, O, CB, CG, OD1, ND2), + residue_names.ASP: (N, CA, C, O, CB, CG, OD1, OD2), + residue_names.CYS: (N, CA, C, O, CB, SG), + residue_names.GLN: (N, CA, C, O, CB, CG, CD, OE1, NE2), + residue_names.GLU: (N, CA, C, O, CB, CG, CD, OE1, OE2), + residue_names.GLY: (N, CA, C, O), + residue_names.HIS: (N, CA, C, O, CB, CG, ND1, CD2, CE1, NE2), + residue_names.ILE: (N, CA, C, O, CB, CG1, CG2, CD1), + residue_names.LEU: (N, CA, C, O, CB, CG, CD1, CD2), + residue_names.LYS: (N, CA, C, O, CB, CG, CD, CE, NZ), + residue_names.MET: (N, CA, C, O, CB, CG, SD, CE), + residue_names.PHE: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ), + residue_names.PRO: (N, CA, C, O, CB, CG, CD), + residue_names.SER: (N, CA, C, O, CB, OG), + residue_names.THR: (N, CA, C, O, CB, OG1, CG2), + residue_names.TRP: + (N, CA, C, O, CB, CG, CD1, CD2, NE1, CE2, CE3, CZ2, CZ3, CH2), + residue_names.TYR: (N, CA, C, O, CB, CG, CD1, CD2, CE1, CE2, CZ, OH), + residue_names.VAL: (N, CA, C, O, CB, CG1, CG2), + residue_names.UNK: (), + # RNA. + residue_names.A: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, O2PRIME, C1PRIME, N9, C8, N7, C5, C6, N6, N1, C2, N3, C4), + residue_names.C: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, O2PRIME, C1PRIME, N1, C2, O2, N3, C4, N4, C5, C6), + residue_names.G: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, O2PRIME, C1PRIME, N9, C8, N7, C5, C6, O6, N1, C2, N2, N3, C4), + residue_names.U: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, O2PRIME, C1PRIME, N1, C2, O2, N3, C4, O4, C5, C6), + residue_names.UNK_RNA: (), + # DNA. + residue_names.DA: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, C1PRIME, N9, C8, N7, C5, C6, N6, N1, C2, N3, C4), + residue_names.DC: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, C1PRIME, N1, C2, O2, N3, C4, N4, C5, C6), + residue_names.DG: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, C1PRIME, N9, C8, N7, C5, C6, O6, N1, C2, N2, N3, C4), + residue_names.DT: + (OP3, P, OP1, OP2, O5PRIME, C5PRIME, C4PRIME, O4PRIME, C3PRIME, O3PRIME, + C2PRIME, C1PRIME, N1, C2, O2, N3, C4, O4, C5, C7, C6), + # Unknown nucleic. + residue_names.UNK_DNA: (), +} # pyformat: disable + +DENSE_ATOM_ORDER: Mapping[str, Mapping[str, int]] = { + k: {name: i for i, name in enumerate(v)} for k, v in DENSE_ATOM.items() +} +DENSE_ATOM_NUM: Final[int] = max(len(v) for v in DENSE_ATOM.values()) + +# Used when we need to store atom data in a format that requires fixed atom data +# size for every nucleic molecule (e.g. a numpy array). +ATOM29: tuple[str, ...] = ( + "C1'", 'C2', "C2'", "C3'", 'C4', "C4'", 'C5', "C5'", 'C6', 'C7', 'C8', 'N1', + 'N2', 'N3', 'N4', 'N6', 'N7', 'N9', 'OP3', 'O2', "O2'", "O3'", 'O4', "O4'", + "O5'", 'O6', 'OP1', 'OP2', 'P') # pyformat: disable +ATOM29_ORDER: Mapping[str, int] = { + atom_type: i for i, atom_type in enumerate(ATOM29) +} +ATOM29_NUM: Final[int] = len(ATOM29) # := 29 + +# Hydrogens that exist depending on the protonation state of the residue. 
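+# For example, HIS carries HD1 and/or HE2 depending on its tautomer and charge.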
+# Extracted from third_party/py/openmm/app/data/hydrogens.xml
+PROTONATION_HYDROGENS: Mapping[str, Set[str]] = {
+    'ASP': {'HD2'},
+    'CYS': {'HG'},
+    'GLU': {'HE2'},
+    'HIS': {'HD1', 'HE2'},
+    'LYS': {'HZ3'},
+}
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_component_sets.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_component_sets.py
new file mode 100644
index 000000000..eaf7b5db4
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_component_sets.py
@@ -0,0 +1,38 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Sets of chemical components."""
+
+import pickle
+from typing import Final
+
+from alphafold3.common import resources
+
+
+_CCD_SETS_CCD_PICKLE_FILE = resources.filename(
+    resources.ROOT / 'constants/converters/chemical_component_sets.pickle'
+)
+
+with open(_CCD_SETS_CCD_PICKLE_FILE, 'rb') as _ccd_sets_file:
+    _CCD_SET = pickle.load(_ccd_sets_file)
+
+# Glycan (or 'Saccharide') ligands.
+# _chem_comp.type containing 'saccharide' and 'linking' (when lower-case).
+GLYCAN_LINKING_LIGANDS: Final[frozenset[str]] = _CCD_SET['glycans_linking']
+
+# _chem_comp.type containing 'saccharide' and not 'linking' (when lower-case).
+GLYCAN_OTHER_LIGANDS: Final[frozenset[str]] = _CCD_SET['glycans_other']
+
+# Each of these molecules appears in over 1k PDB structures; they are used to
+# facilitate crystallization but have no biological relevance.
+COMMON_CRYSTALLIZATION_AIDS: Final[frozenset[str]] = frozenset({
+    'SO4', 'GOL', 'EDO', 'PO4', 'ACT', 'PEG', 'DMS', 'TRS', 'PGE', 'PG4', 'FMT',
+    'EPE', 'MPD', 'MES', 'CD', 'IOD',
+})  # pyformat: disable
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py
new file mode 100644
index 000000000..b4df2bc8e
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/chemical_components.py
@@ -0,0 +1,188 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Chemical Components found in PDB (CCD) constants."""
+
+from collections.abc import ItemsView, Iterator, KeysView, Mapping, Sequence, ValuesView
+import dataclasses
+import functools
+import os
+import pickle
+
+from alphafold3.common import resources
+from alphafold3.cpp import cif_dict
+
+
+_CCD_PICKLE_FILE = resources.filename(
+    resources.ROOT / 'constants/converters/ccd.pickle'
+)
+
+
+class Ccd(Mapping[str, Mapping[str, Sequence[str]]]):
+    """Chemical Components found in PDB (CCD) constants.
+
+    See https://academic.oup.com/bioinformatics/article/31/8/1274/212200 for CCD
+    CIF format documentation.
+
+    Wraps the dict to prevent accidental mutation.
+    """
+
+    __slots__ = ('_dict', '_ccd_pickle_path')
+
+    def __init__(
+        self,
+        ccd_pickle_path: os.PathLike[str] | None = None,
+        user_ccd: str | None = None,
+    ):
+        """Initialises the chemical components dictionary.
+
+        Args:
+          ccd_pickle_path: Path to the CCD pickle file. If None, uses the default
+            CCD pickle file included in the source code.
+          user_ccd: A string containing the user-provided CCD. This has to conform
+            to the same format as the CCD, see https://www.wwpdb.org/data/ccd. If
+            provided, takes precedence over the CCD for the same key. This can
+            be used to override specific entries in the CCD if desired.
+        """
+        self._ccd_pickle_path = ccd_pickle_path or _CCD_PICKLE_FILE
+        with open(self._ccd_pickle_path, 'rb') as f:
+            self._dict = pickle.loads(f.read())
+
+        if user_ccd is not None:
+            if not user_ccd:
+                raise ValueError('User CCD cannot be an empty string.')
+            user_ccd_cifs = {
+                key: {k: tuple(v) for k, v in value.items()}
+                for key, value in cif_dict.parse_multi_data_cif(user_ccd).items()
+            }
+            self._dict.update(user_ccd_cifs)
+
+    def __getitem__(self, key: str) -> Mapping[str, Sequence[str]]:
+        return self._dict[key]
+
+    def __contains__(self, key: str) -> bool:
+        return key in self._dict
+
+    def __iter__(self) -> Iterator[str]:
+        return self._dict.__iter__()
+
+    def __len__(self) -> int:
+        return len(self._dict)
+
+    def __hash__(self) -> int:
+        return id(self)  # Ok since this is immutable.
+
+    def get(
+        self, key: str, default: None | Mapping[str, Sequence[str]] = None
+    ) -> Mapping[str, Sequence[str]] | None:
+        return self._dict.get(key, default)
+
+    def items(self) -> ItemsView[str, Mapping[str, Sequence[str]]]:
+        return self._dict.items()
+
+    def values(self) -> ValuesView[Mapping[str, Sequence[str]]]:
+        return self._dict.values()
+
+    def keys(self) -> KeysView[str]:
+        return self._dict.keys()
+
+
+@functools.cache
+def cached_ccd(user_ccd: str | None = None) -> Ccd:
+    return Ccd(user_ccd=user_ccd)
+
+
+@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
+class ComponentInfo:
+    name: str
+    type: str
+    pdbx_synonyms: str
+    formula: str
+    formula_weight: str
+    mon_nstd_parent_comp_id: str
+    mon_nstd_flag: str
+    pdbx_smiles: str
+
+
+def mmcif_to_info(mmcif: Mapping[str, Sequence[str]]) -> ComponentInfo:
+    """Converts CCD mmCIFs to component info.
Missing fields are left empty.""" + names = mmcif['_chem_comp.name'] + types = mmcif['_chem_comp.type'] + mon_nstd_parent_comp_ids = mmcif['_chem_comp.mon_nstd_parent_comp_id'] + pdbx_synonyms = mmcif['_chem_comp.pdbx_synonyms'] + formulas = mmcif['_chem_comp.formula'] + formula_weights = mmcif['_chem_comp.formula_weight'] + + def front_or_empty(values: Sequence[str]) -> str: + return values[0] if values else '' + + type_ = front_or_empty(types) + mon_nstd_parent_comp_id = front_or_empty(mon_nstd_parent_comp_ids) + if type_.lower() == 'non-polymer': + # Unset for non-polymers, e.g. water or ions. + mon_nstd_flag = '.' + elif mon_nstd_parent_comp_id == '?': + # A standard component - it doesn't have a standard parent, e.g. MET. + mon_nstd_flag = 'y' + else: + # A non-standard component, e.g. MSE. + mon_nstd_flag = 'n' + + pdbx_smiles = '' + descriptor_types = mmcif['_pdbx_chem_comp_descriptor.type'] + descriptors = mmcif['_pdbx_chem_comp_descriptor.descriptor'] + for descriptor_type, descriptor in zip(descriptor_types, descriptors): + if descriptor_type == 'SMILES_CANONICAL': + pdbx_smiles = descriptor + break + elif not pdbx_smiles and descriptor_type == 'SMILES': + pdbx_smiles = descriptor + + return ComponentInfo( + name=front_or_empty(names), + type=type_, + pdbx_synonyms=front_or_empty(pdbx_synonyms), + formula=front_or_empty(formulas), + formula_weight=front_or_empty(formula_weights), + mon_nstd_parent_comp_id=mon_nstd_parent_comp_id, + mon_nstd_flag=mon_nstd_flag, + pdbx_smiles=pdbx_smiles, + ) + + +@functools.lru_cache(maxsize=128) +def component_name_to_info(ccd: Ccd, res_name: str) -> ComponentInfo | None: + component = ccd.get(res_name) + if component is None: + return None + return mmcif_to_info(component) + + +def type_symbol(ccd: Ccd, res_name: str, atom_name: str) -> str: + """Returns the element type for the given component name and atom name. + + Args: + ccd: The chemical components dictionary. + res_name: The component name, e.g. ARG. + atom_name: The atom name, e.g. CB, OXT, or NH1. + + Returns: + Element type, e.g. C for (ARG, CB), O for (ARG, OXT), N for (ARG, NH1). + """ + res = ccd.get(res_name) + if res is None: + return '?' + try: + return res['_chem_comp_atom.type_symbol'][ + res['_chem_comp_atom.atom_id'].index(atom_name) + ] + except (ValueError, IndexError, KeyError): + return '?' diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/ccd_pickle_gen.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/ccd_pickle_gen.py new file mode 100644 index 000000000..e793f216b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/ccd_pickle_gen.py @@ -0,0 +1,53 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Reads the Chemical Components gz file and generates a CCD pickle file."""
+
+from collections.abc import Sequence
+import gzip
+import pickle
+import sys
+
+from alphafold3.cpp import cif_dict
+import tqdm
+
+
+def main(argv: Sequence[str]) -> None:
+  if len(argv) != 3:
+    raise ValueError(
+        'Must specify input_file components.cif and output_file')
+
+  _, input_file, output_file = argv
+
+  print(f'Parsing {input_file}', flush=True)
+  if input_file.endswith('.gz'):
+    opener = gzip.open
+  else:
+    opener = open
+
+  with opener(input_file, 'rb') as f:
+    whole_file = f.read()
+  result = {
+      key: {k: tuple(v) for k, v in value.items()}
+      for key, value in tqdm.tqdm(
+          cif_dict.parse_multi_data_cif(whole_file).items()
+      )
+  }
+  assert len(result) == whole_file.count(b'data_')
+
+  print(f'Writing {output_file}', flush=True)
+  with open(output_file, 'wb') as f:
+    pickle.dump(result, f, protocol=pickle.HIGHEST_PROTOCOL)
+  print('Done', flush=True)
+
+
+if __name__ == '__main__':
+  main(sys.argv)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/chemical_component_sets_gen.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/chemical_component_sets_gen.py
new file mode 100644
index 000000000..d66611e69
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/converters/chemical_component_sets_gen.py
@@ -0,0 +1,81 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Script for updating chemical_component_sets.py."""
+
+from collections.abc import Mapping, Sequence
+import pathlib
+import pickle
+import re
+import sys
+
+from alphafold3.common import resources
+import tqdm
+
+
+_CCD_PICKLE_FILE = resources.filename(
+    'constants/converters/ccd.pickle'
+)
+
+
+def find_ions_and_glycans_in_ccd(
+    ccd: Mapping[str, Mapping[str, Sequence[str]]],
+) -> dict[str, frozenset[str]]:
+  """Finds glycans and ions in all versions of the CCD."""
+  glycans_linking = []
+  glycans_other = []
+  ions = []
+  for name, comp in tqdm.tqdm(ccd.items()):
+    if name == 'UNX':
+      continue  # Skip "unknown atom or ion".
+    comp_type = comp['_chem_comp.type'][0].lower()
+    # Glycans have the type 'saccharide'.
+    if re.findall(r'\bsaccharide\b', comp_type):
+      # Separate out linking glycans from others.
+      if 'linking' in comp_type:
+        glycans_linking.append(name)
+      else:
+        glycans_other.append(name)
+
+    # Ions have the word 'ion' in their name.
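+    # For example, 'ZN' is named 'zinc ion' and 'NA' 'sodium ion' in the CCD.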
+    comp_name = comp['_chem_comp.name'][0].lower()
+    if re.findall(r'\bion\b', comp_name):
+      ions.append(name)
+  result = dict(
+      glycans_linking=frozenset(glycans_linking),
+      glycans_other=frozenset(glycans_other),
+      ions=frozenset(ions),
+  )
+
+  return result
+
+
+def main(argv: Sequence[str]) -> None:
+  if len(argv) != 2:
+    raise ValueError(
+        'Directory to write to must be specified as a command-line argument.'
+    )
+
+  print(f'Loading {_CCD_PICKLE_FILE}', flush=True)
+  with open(_CCD_PICKLE_FILE, 'rb') as f:
+    ccd: Mapping[str, Mapping[str, Sequence[str]]] = pickle.load(f)
+  output_path = pathlib.Path(argv[1])
+  output_path.parent.mkdir(exist_ok=True)
+  print('Finding ions and glycans', flush=True)
+  result = find_ions_and_glycans_in_ccd(ccd)
+  print(f'Writing to {output_path}', flush=True)
+  with output_path.open('wb') as f:
+    pickle.dump(result, f)
+  print('Done', flush=True)
+
+
+if __name__ == '__main__':
+  main(sys.argv)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/mmcif_names.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/mmcif_names.py
new file mode 100644
index 000000000..15eabf2f9
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/mmcif_names.py
@@ -0,0 +1,218 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Names of things in mmCIF format.
+
+See https://www.iucr.org/__data/iucr/cifdic_html/2/cif_mm.dic/index.html
+"""
+
+from collections.abc import Mapping, Sequence, Set
+from typing import Final
+
+from alphafold3.constants import atom_types
+from alphafold3.constants import residue_names
+
+
+# The following are all possible values for the "_entity.type".
+# https://mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_entity.type.html
+BRANCHED_CHAIN: Final[str] = 'branched'
+MACROLIDE_CHAIN: Final[str] = 'macrolide'
+NON_POLYMER_CHAIN: Final[str] = 'non-polymer'
+POLYMER_CHAIN: Final[str] = 'polymer'
+WATER: Final[str] = 'water'
+
+CYCLIC_PSEUDO_PEPTIDE_CHAIN: Final[str] = 'cyclic-pseudo-peptide'
+DNA_CHAIN: Final[str] = 'polydeoxyribonucleotide'
+DNA_RNA_HYBRID_CHAIN: Final[str] = (
+    'polydeoxyribonucleotide/polyribonucleotide hybrid'
+)
+OTHER_CHAIN: Final[str] = 'other'
+PEPTIDE_NUCLEIC_ACID_CHAIN: Final[str] = 'peptide nucleic acid'
+POLYPEPTIDE_D_CHAIN: Final[str] = 'polypeptide(D)'
+PROTEIN_CHAIN: Final[str] = 'polypeptide(L)'
+RNA_CHAIN: Final[str] = 'polyribonucleotide'
+
+# Most common _entity_poly.types.
+STANDARD_POLYMER_CHAIN_TYPES: Final[Set[str]] = {
+    PROTEIN_CHAIN,
+    DNA_CHAIN,
+    RNA_CHAIN,
+}
+
+# Possible values for _entity.type other than polymer and water.
+LIGAND_CHAIN_TYPES: Final[Set[str]] = {
+    BRANCHED_CHAIN,
+    MACROLIDE_CHAIN,
+    NON_POLYMER_CHAIN,
+}
+
+# Possible values for _entity.type other than polymer.
+NON_POLYMER_CHAIN_TYPES: Final[Set[str]] = {
+    *LIGAND_CHAIN_TYPES,
+    WATER,
+}
+
+# Possible peptide values for _entity_poly.type.
+PEPTIDE_CHAIN_TYPES: Final[Set[str]] = {
+    CYCLIC_PSEUDO_PEPTIDE_CHAIN,
+    POLYPEPTIDE_D_CHAIN,
+    PROTEIN_CHAIN,
+    PEPTIDE_NUCLEIC_ACID_CHAIN,
+}
+
+
+# Possible nucleic-acid values for _entity_poly.type.
+NUCLEIC_ACID_CHAIN_TYPES: Final[Set[str]] = {
+    RNA_CHAIN,
+    DNA_CHAIN,
+    DNA_RNA_HYBRID_CHAIN,
+}
+
+# All possible values for _entity_poly.type.
+POLYMER_CHAIN_TYPES: Final[Set[str]] = {
+    *NUCLEIC_ACID_CHAIN_TYPES,
+    *PEPTIDE_CHAIN_TYPES,
+    OTHER_CHAIN,
+}
+
+
+TERMINAL_OXYGENS: Final[Mapping[str, str]] = {
+    PROTEIN_CHAIN: 'OXT',
+    DNA_CHAIN: 'OP3',
+    RNA_CHAIN: 'OP3',
+}
+
+
+# For each chain type, which atom should be used to represent each residue.
+RESIDUE_REPRESENTATIVE_ATOMS: Final[Mapping[str, str]] = {
+    PROTEIN_CHAIN: atom_types.CA,
+    DNA_CHAIN: atom_types.C1PRIME,
+    RNA_CHAIN: atom_types.C1PRIME,
+}
+
+# Methods involving crystallization. See the documentation at
+# mmcif.wwpdb.org/dictionaries/mmcif_pdbx_v50.dic/Items/_exptl.method.html
+# for the full list of experimental methods.
+CRYSTALLIZATION_METHODS: Final[Set[str]] = {
+    'X-RAY DIFFRACTION',
+    'NEUTRON DIFFRACTION',
+    'ELECTRON CRYSTALLOGRAPHY',
+    'POWDER CRYSTALLOGRAPHY',
+    'FIBER DIFFRACTION',
+}
+
+# Possible bond types.
+COVALENT_BOND: Final[str] = 'covale'
+HYDROGEN_BOND: Final[str] = 'hydrog'
+METAL_COORDINATION: Final[str] = 'metalc'
+DISULFIDE_BRIDGE: Final[str] = 'disulf'
+
+
+def is_standard_polymer_type(chain_type: str) -> bool:
+  """Returns whether the chain type is a protein, DNA or RNA chain type.
+
+  Args:
+    chain_type: The type of the chain.
+
+  Returns:
+    True if the chain_type is protein, DNA or RNA, False otherwise.
+  """
+  return chain_type in STANDARD_POLYMER_CHAIN_TYPES
+
+
+def guess_polymer_type(chain_residues: Sequence[str]) -> str:
+  """Guesses the polymer type (protein/rna/dna/other) based on the residues.
+
+  The polymer type is guessed by first checking for any of the standard
+  protein residues. If one is present then the chain is considered to be a
+  polypeptide. Otherwise we decide by counting residue types and deciding by
+  majority voting (e.g. mostly DNA residues -> DNA). If there is a tie between
+  the counts, the ordering is rna > dna > other.
+
+  Note that we count MSE and UNK as protein residues.
+
+  Args:
+    chain_residues: A sequence of full residue names (1-letter for RNA,
+      2-letter for DNA, 3-letter for protein). The _atom_site.label_comp_id
+      column in mmCIF.
+
+  Returns:
+    The most probable chain type as set in the _entity_poly mmCIF table:
+    protein - polypeptide(L), rna - polyribonucleotide,
+    dna - polydeoxyribonucleotide or other.
+  """
+  residue_types = {
+      **{r: RNA_CHAIN for r in residue_names.RNA_TYPES},
+      **{r: DNA_CHAIN for r in residue_names.DNA_TYPES},
+      **{r: PROTEIN_CHAIN for r in residue_names.PROTEIN_TYPES_WITH_UNKNOWN},
+      residue_names.MSE: PROTEIN_CHAIN,
+  }
+
+  counts = {PROTEIN_CHAIN: 0, RNA_CHAIN: 0, DNA_CHAIN: 0, OTHER_CHAIN: 0}
+  for residue in chain_residues:
+    residue_type = residue_types.get(residue, OTHER_CHAIN)
+    # If we ever see a protein residue we'll consider this a polypeptide(L).
+    if residue_type == PROTEIN_CHAIN:
+      return residue_type
+    counts[residue_type] += 1
+
+  # Make sure protein > rna > dna > other if there is a tie.
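+  # For example, counts of {RNA: 2, DNA: 2, other: 1} resolve to RNA_CHAIN,
+  # because equal counts fall back to the tie-breaker ranking defined next.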
+  tie_breaker = {PROTEIN_CHAIN: 3, RNA_CHAIN: 2, DNA_CHAIN: 1, OTHER_CHAIN: 0}
+
+  def order_fn(item):
+    name, count = item
+    return count, tie_breaker[name]
+
+  most_probable_type = max(counts.items(), key=order_fn)[0]
+  return most_probable_type
+
+
+def fix_non_standard_polymer_res(*, res_name: str, chain_type: str) -> str:
+  """Returns the res_name of the closest standard protein/RNA/DNA residue.
+
+  Optimized for the case where a single residue needs to be converted.
+
+  If res_name is already a standard type, it is returned unaltered.
+  If a match cannot be found, returns 'UNK' for protein chains and 'N' for
+  RNA/DNA chains.
+
+  Args:
+    res_name: A residue_name (monomer code from the CCD).
+    chain_type: The type of the chain, must be PROTEIN_CHAIN, RNA_CHAIN or
+      DNA_CHAIN.
+
+  Returns:
+    An element from PROTEIN_TYPES_WITH_UNKNOWN | RNA_TYPES | DNA_TYPES | {'N'}.
+
+  Raises:
+    ValueError: If chain_type not in PEPTIDE_CHAIN_TYPES or
+      {OTHER_CHAIN, RNA_CHAIN, DNA_CHAIN, DNA_RNA_HYBRID_CHAIN}.
+  """
+  # Map to one letter code, then back to common res_names.
+  one_letter_code = residue_names.letters_three_to_one(res_name, default='X')
+
+  if chain_type in PEPTIDE_CHAIN_TYPES or chain_type == OTHER_CHAIN:
+    return residue_names.PROTEIN_COMMON_ONE_TO_THREE.get(one_letter_code, 'UNK')
+  elif chain_type == RNA_CHAIN:
+    # RNA's CCD monomer code is single-letter.
+    return (
+        one_letter_code if one_letter_code in residue_names.RNA_TYPES else 'N'
+    )
+  elif chain_type == DNA_CHAIN:
+    return residue_names.DNA_COMMON_ONE_TO_TWO.get(one_letter_code, 'N')
+  elif chain_type == DNA_RNA_HYBRID_CHAIN:
+    return (
+        res_name
+        if res_name in residue_names.NUCLEIC_TYPES_WITH_UNKNOWN
+        else 'N'
+    )
+  else:
+    raise ValueError(
+        f'Expected a protein/DNA/RNA chain but got {chain_type}')
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/periodic_table.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/periodic_table.py
new file mode 100644
index 000000000..7385245ff
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/periodic_table.py
@@ -0,0 +1,399 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Periodic table of elements."""
+
+from collections.abc import Mapping, Sequence
+import dataclasses
+from typing import Final
+
+import numpy as np
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class Element:
+  name: str
+  number: int
+  symbol: str
+  weight: float
+
+
+# Weights taken from rdkit/Code/GraphMol/atomic_data.cpp for compatibility.
+# pylint: disable=invalid-name
+
+# X is an unknown element that can be present in the CCD,
+# https://www.rcsb.org/ligand/UNX.
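+# It is assigned atomic number 0 and a weight of 0.0 below.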
+X: Final[Element] = Element(name='Unknown', number=0, symbol='X', weight=0.0) +H: Final[Element] = Element( + name='Hydrogen', number=1, symbol='H', weight=1.008) +He: Final[Element] = Element( + name='Helium', number=2, symbol='He', weight=4.003) +Li: Final[Element] = Element( + name='Lithium', number=3, symbol='Li', weight=6.941 +) +Be: Final[Element] = Element( + name='Beryllium', number=4, symbol='Be', weight=9.012 +) +B: Final[Element] = Element(name='Boron', number=5, symbol='B', weight=10.812) +C: Final[Element] = Element(name='Carbon', number=6, symbol='C', weight=12.011) +N: Final[Element] = Element( + name='Nitrogen', number=7, symbol='N', weight=14.007 +) +O: Final[Element] = Element(name='Oxygen', number=8, symbol='O', weight=15.999) +F: Final[Element] = Element( + name='Fluorine', number=9, symbol='F', weight=18.998 +) +Ne: Final[Element] = Element(name='Neon', number=10, symbol='Ne', weight=20.18) +Na: Final[Element] = Element( + name='Sodium', number=11, symbol='Na', weight=22.99 +) +Mg: Final[Element] = Element( + name='Magnesium', number=12, symbol='Mg', weight=24.305 +) +Al: Final[Element] = Element( + name='Aluminium', number=13, symbol='Al', weight=26.982 +) +Si: Final[Element] = Element( + name='Silicon', number=14, symbol='Si', weight=28.086 +) +P: Final[Element] = Element( + name='Phosphorus', number=15, symbol='P', weight=30.974 +) +S: Final[Element] = Element( + name='Sulfur', number=16, symbol='S', weight=32.067) +Cl: Final[Element] = Element( + name='Chlorine', number=17, symbol='Cl', weight=35.453 +) +Ar: Final[Element] = Element( + name='Argon', number=18, symbol='Ar', weight=39.948 +) +K: Final[Element] = Element( + name='Potassium', number=19, symbol='K', weight=39.098 +) +Ca: Final[Element] = Element( + name='Calcium', number=20, symbol='Ca', weight=40.078 +) +Sc: Final[Element] = Element( + name='Scandium', number=21, symbol='Sc', weight=44.956 +) +Ti: Final[Element] = Element( + name='Titanium', number=22, symbol='Ti', weight=47.867 +) +V: Final[Element] = Element( + name='Vanadium', number=23, symbol='V', weight=50.942 +) +Cr: Final[Element] = Element( + name='Chromium', number=24, symbol='Cr', weight=51.996 +) +Mn: Final[Element] = Element( + name='Manganese', number=25, symbol='Mn', weight=54.938 +) +Fe: Final[Element] = Element( + name='Iron', number=26, symbol='Fe', weight=55.845) +Co: Final[Element] = Element( + name='Cobalt', number=27, symbol='Co', weight=58.933 +) +Ni: Final[Element] = Element( + name='Nickel', number=28, symbol='Ni', weight=58.693 +) +Cu: Final[Element] = Element( + name='Copper', number=29, symbol='Cu', weight=63.546 +) +Zn: Final[Element] = Element(name='Zinc', number=30, symbol='Zn', weight=65.39) +Ga: Final[Element] = Element( + name='Gallium', number=31, symbol='Ga', weight=69.723 +) +Ge: Final[Element] = Element( + name='Germanium', number=32, symbol='Ge', weight=72.61 +) +As: Final[Element] = Element( + name='Arsenic', number=33, symbol='As', weight=74.922 +) +Se: Final[Element] = Element( + name='Selenium', number=34, symbol='Se', weight=78.96 +) +Br: Final[Element] = Element( + name='Bromine', number=35, symbol='Br', weight=79.904 +) +Kr: Final[Element] = Element( + name='Krypton', number=36, symbol='Kr', weight=83.8 +) +Rb: Final[Element] = Element( + name='Rubidium', number=37, symbol='Rb', weight=85.468 +) +Sr: Final[Element] = Element( + name='Strontium', number=38, symbol='Sr', weight=87.62 +) +Y: Final[Element] = Element( + name='Yttrium', number=39, symbol='Y', weight=88.906 +) +Zr: Final[Element] = Element( + 
name='Zirconium', number=40, symbol='Zr', weight=91.224 +) +Nb: Final[Element] = Element( + name='Niobium', number=41, symbol='Nb', weight=92.906 +) +Mo: Final[Element] = Element( + name='Molybdenum', number=42, symbol='Mo', weight=95.94 +) +Tc: Final[Element] = Element( + name='Technetium', number=43, symbol='Tc', weight=98 +) +Ru: Final[Element] = Element( + name='Ruthenium', number=44, symbol='Ru', weight=101.07 +) +Rh: Final[Element] = Element( + name='Rhodium', number=45, symbol='Rh', weight=102.906 +) +Pd: Final[Element] = Element( + name='Palladium', number=46, symbol='Pd', weight=106.42 +) +Ag: Final[Element] = Element( + name='Silver', number=47, symbol='Ag', weight=107.868 +) +Cd: Final[Element] = Element( + name='Cadmium', number=48, symbol='Cd', weight=112.412 +) +In: Final[Element] = Element( + name='Indium', number=49, symbol='In', weight=114.818 +) +Sn: Final[Element] = Element( + name='Tin', number=50, symbol='Sn', weight=118.711) +Sb: Final[Element] = Element( + name='Antimony', number=51, symbol='Sb', weight=121.76 +) +Te: Final[Element] = Element( + name='Tellurium', number=52, symbol='Te', weight=127.6 +) +I: Final[Element] = Element( + name='Iodine', number=53, symbol='I', weight=126.904 +) +Xe: Final[Element] = Element( + name='Xenon', number=54, symbol='Xe', weight=131.29 +) +Cs: Final[Element] = Element( + name='Caesium', number=55, symbol='Cs', weight=132.905 +) +Ba: Final[Element] = Element( + name='Barium', number=56, symbol='Ba', weight=137.328 +) +La: Final[Element] = Element( + name='Lanthanum', number=57, symbol='La', weight=138.906 +) +Ce: Final[Element] = Element( + name='Cerium', number=58, symbol='Ce', weight=140.116 +) +Pr: Final[Element] = Element( + name='Praseodymium', number=59, symbol='Pr', weight=140.908 +) +Nd: Final[Element] = Element( + name='Neodymium', number=60, symbol='Nd', weight=144.24 +) +Pm: Final[Element] = Element( + name='Promethium', number=61, symbol='Pm', weight=145 +) +Sm: Final[Element] = Element( + name='Samarium', number=62, symbol='Sm', weight=150.36 +) +Eu: Final[Element] = Element( + name='Europium', number=63, symbol='Eu', weight=151.964 +) +Gd: Final[Element] = Element( + name='Gadolinium', number=64, symbol='Gd', weight=157.25 +) +Tb: Final[Element] = Element( + name='Terbium', number=65, symbol='Tb', weight=158.925 +) +Dy: Final[Element] = Element( + name='Dysprosium', number=66, symbol='Dy', weight=162.5 +) +Ho: Final[Element] = Element( + name='Holmium', number=67, symbol='Ho', weight=164.93 +) +Er: Final[Element] = Element( + name='Erbium', number=68, symbol='Er', weight=167.26 +) +Tm: Final[Element] = Element( + name='Thulium', number=69, symbol='Tm', weight=168.934 +) +Yb: Final[Element] = Element( + name='Ytterbium', number=70, symbol='Yb', weight=173.04 +) +Lu: Final[Element] = Element( + name='Lutetium', number=71, symbol='Lu', weight=174.967 +) +Hf: Final[Element] = Element( + name='Hafnium', number=72, symbol='Hf', weight=178.49 +) +Ta: Final[Element] = Element( + name='Tantalum', number=73, symbol='Ta', weight=180.948 +) +W: Final[Element] = Element( + name='Tungsten', number=74, symbol='W', weight=183.84 +) +Re: Final[Element] = Element( + name='Rhenium', number=75, symbol='Re', weight=186.207 +) +Os: Final[Element] = Element( + name='Osmium', number=76, symbol='Os', weight=190.23 +) +Ir: Final[Element] = Element( + name='Iridium', number=77, symbol='Ir', weight=192.217 +) +Pt: Final[Element] = Element( + name='Platinum', number=78, symbol='Pt', weight=195.078 +) +Au: Final[Element] = Element( + name='Gold', 
number=79, symbol='Au', weight=196.967 +) +Hg: Final[Element] = Element( + name='Mercury', number=80, symbol='Hg', weight=200.59 +) +Tl: Final[Element] = Element( + name='Thallium', number=81, symbol='Tl', weight=204.383 +) +Pb: Final[Element] = Element(name='Lead', number=82, symbol='Pb', weight=207.2) +Bi: Final[Element] = Element( + name='Bismuth', number=83, symbol='Bi', weight=208.98 +) +Po: Final[Element] = Element( + name='Polonium', number=84, symbol='Po', weight=209 +) +At: Final[Element] = Element( + name='Astatine', number=85, symbol='At', weight=210 +) +Rn: Final[Element] = Element(name='Radon', number=86, symbol='Rn', weight=222) +Fr: Final[Element] = Element( + name='Francium', number=87, symbol='Fr', weight=223 +) +Ra: Final[Element] = Element(name='Radium', number=88, symbol='Ra', weight=226) +Ac: Final[Element] = Element( + name='Actinium', number=89, symbol='Ac', weight=227 +) +Th: Final[Element] = Element( + name='Thorium', number=90, symbol='Th', weight=232.038 +) +Pa: Final[Element] = Element( + name='Protactinium', number=91, symbol='Pa', weight=231.036 +) +U: Final[Element] = Element( + name='Uranium', number=92, symbol='U', weight=238.029 +) +Np: Final[Element] = Element( + name='Neptunium', number=93, symbol='Np', weight=237 +) +Pu: Final[Element] = Element( + name='Plutonium', number=94, symbol='Pu', weight=244 +) +Am: Final[Element] = Element( + name='Americium', number=95, symbol='Am', weight=243 +) +Cm: Final[Element] = Element(name='Curium', number=96, symbol='Cm', weight=247) +Bk: Final[Element] = Element( + name='Berkelium', number=97, symbol='Bk', weight=247 +) +Cf: Final[Element] = Element( + name='Californium', number=98, symbol='Cf', weight=251 +) +Es: Final[Element] = Element( + name='Einsteinium', number=99, symbol='Es', weight=252 +) +Fm: Final[Element] = Element( + name='Fermium', number=100, symbol='Fm', weight=257 +) +Md: Final[Element] = Element( + name='Mendelevium', number=101, symbol='Md', weight=258 +) +No: Final[Element] = Element( + name='Nobelium', number=102, symbol='No', weight=259 +) +Lr: Final[Element] = Element( + name='Lawrencium', number=103, symbol='Lr', weight=262 +) +Rf: Final[Element] = Element( + name='Rutherfordium', number=104, symbol='Rf', weight=267 +) +Db: Final[Element] = Element( + name='Dubnium', number=105, symbol='Db', weight=268 +) +Sg: Final[Element] = Element( + name='Seaborgium', number=106, symbol='Sg', weight=269 +) +Bh: Final[Element] = Element( + name='Bohrium', number=107, symbol='Bh', weight=270 +) +Hs: Final[Element] = Element( + name='Hassium', number=108, symbol='Hs', weight=269 +) +Mt: Final[Element] = Element( + name='Meitnerium', number=109, symbol='Mt', weight=278 +) +Ds: Final[Element] = Element( + name='Darmstadtium', number=110, symbol='Ds', weight=281 +) +Rg: Final[Element] = Element( + name='Roentgenium', number=111, symbol='Rg', weight=281 +) +Cn: Final[Element] = Element( + name='Copernicium', number=112, symbol='Cn', weight=285 +) +Nh: Final[Element] = Element( + name='Nihonium', number=113, symbol='Nh', weight=284 +) +Fl: Final[Element] = Element( + name='Flerovium', number=114, symbol='Fl', weight=289 +) +Mc: Final[Element] = Element( + name='Moscovium', number=115, symbol='Mc', weight=288 +) +Lv: Final[Element] = Element( + name='Livermorium', number=116, symbol='Lv', weight=293 +) +Ts: Final[Element] = Element( + name='Tennessine', number=117, symbol='Ts', weight=292 +) +Og: Final[Element] = Element( + name='Oganesson', number=118, symbol='Og', weight=294 +) +# pylint: enable=invalid-name + 
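+# For example, Fe.number == 26, Fe.symbol == 'Fe' and Fe.weight == 55.845;
+# lookups by atomic number or symbol go through the tables built below.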
+# fmt: off
+# Lanthanides
+_L: Final[Sequence[Element]] = (
+    La, Ce, Pr, Nd, Pm, Sm, Eu, Gd, Tb, Dy, Ho, Er, Tm, Yb, Lu)
+# Actinides
+_A: Final[Sequence[Element]] = (
+    Ac, Th, Pa, U, Np, Pu, Am, Cm, Bk, Cf, Es, Fm, Md, No, Lr)
+
+# pylint: disable=bad-whitespace
+PERIODIC_TABLE: Final[Sequence[Element]] = (
+    X,  # Unknown
+    H, He,
+    Li, Be, B, C, N, O, F, Ne,
+    Na, Mg, Al, Si, P, S, Cl, Ar,
+    K, Ca, Sc, Ti, V, Cr, Mn, Fe, Co, Ni, Cu, Zn, Ga, Ge, As, Se, Br, Kr,
+    Rb, Sr, Y, Zr, Nb, Mo, Tc, Ru, Rh, Pd, Ag, Cd, In, Sn, Sb, Te, I, Xe,
+    Cs, Ba, *_L, Hf, Ta, W, Re, Os, Ir, Pt, Au, Hg, Tl, Pb, Bi, Po, At, Rn,
+    Fr, Ra, *_A, Rf, Db, Sg, Bh, Hs, Mt, Ds, Rg, Cn, Nh, Fl, Mc, Lv, Ts, Og
+)
+# pylint: enable=bad-whitespace
+# fmt: on
+ATOMIC_SYMBOL: Mapping[int, str] = {e.number: e.symbol for e in PERIODIC_TABLE}
+ATOMIC_NUMBER = {e.symbol: e.number for e in PERIODIC_TABLE}
+# Add Deuterium, as the previous table contained it.
+ATOMIC_NUMBER['D'] = 1
+
+ATOMIC_NUMBER: Mapping[str, int] = ATOMIC_NUMBER
+ATOMIC_WEIGHT: np.ndarray = np.zeros(len(PERIODIC_TABLE), dtype=np.float64)
+
+for e in PERIODIC_TABLE:
+  ATOMIC_WEIGHT[e.number] = e.weight
+ATOMIC_WEIGHT.setflags(write=False)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/residue_names.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/residue_names.py
new file mode 100644
index 000000000..40d42587c
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/residue_names.py
@@ -0,0 +1,421 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google.
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Constants associated with residue names.""" + +from collections.abc import Mapping +import functools +import sys + +# pyformat: disable +# common_typos_disable +CCD_NAME_TO_ONE_LETTER: Mapping[str, str] = { + '00C': 'C', '01W': 'X', '02K': 'A', '03Y': 'C', '07O': 'C', '08P': 'C', + '0A0': 'D', '0A1': 'Y', '0A2': 'K', '0A8': 'C', '0AA': 'V', '0AB': 'V', + '0AC': 'G', '0AD': 'G', '0AF': 'W', '0AG': 'L', '0AH': 'S', '0AK': 'D', + '0AM': 'A', '0AP': 'C', '0AU': 'U', '0AV': 'A', '0AZ': 'P', '0BN': 'F', + '0C': 'C', '0CS': 'A', '0DC': 'C', '0DG': 'G', '0DT': 'T', '0FL': 'A', + '0G': 'G', '0NC': 'A', '0SP': 'A', '0U': 'U', '10C': 'C', '125': 'U', + '126': 'U', '127': 'U', '128': 'N', '12A': 'A', '143': 'C', '193': 'X', + '1AP': 'A', '1MA': 'A', '1MG': 'G', '1PA': 'F', '1PI': 'A', '1PR': 'N', + '1SC': 'C', '1TQ': 'W', '1TY': 'Y', '1X6': 'S', '200': 'F', '23F': 'F', + '23S': 'X', '26B': 'T', '2AD': 'X', '2AG': 'A', '2AO': 'X', '2AR': 'A', + '2AS': 'X', '2AT': 'T', '2AU': 'U', '2BD': 'I', '2BT': 'T', '2BU': 'A', + '2CO': 'C', '2DA': 'A', '2DF': 'N', '2DM': 'N', '2DO': 'X', '2DT': 'T', + '2EG': 'G', '2FE': 'N', '2FI': 'N', '2FM': 'M', '2GT': 'T', '2HF': 'H', + '2LU': 'L', '2MA': 'A', '2MG': 'G', '2ML': 'L', '2MR': 'R', '2MT': 'P', + '2MU': 'U', '2NT': 'T', '2OM': 'U', '2OT': 'T', '2PI': 'X', '2PR': 'G', + '2SA': 'N', '2SI': 'X', '2ST': 'T', '2TL': 'T', '2TY': 'Y', '2VA': 'V', + '2XA': 'C', '32S': 'X', '32T': 'X', '3AH': 'H', '3AR': 'X', '3CF': 'F', + '3DA': 'A', '3DR': 'N', '3GA': 'A', '3MD': 'D', '3ME': 'U', '3NF': 'Y', + '3QN': 'K', '3TY': 'X', '3XH': 'G', '4AC': 'N', '4BF': 'Y', '4CF': 'F', + '4CY': 'M', '4DP': 'W', '4FB': 'P', '4FW': 'W', '4HT': 'W', '4IN': 'W', + '4MF': 'N', '4MM': 'X', '4OC': 'C', '4PC': 'C', '4PD': 'C', '4PE': 'C', + '4PH': 'F', '4SC': 'C', '4SU': 'U', '4TA': 'N', '4U7': 'A', '56A': 'H', + '5AA': 'A', '5AB': 'A', '5AT': 'T', '5BU': 'U', '5CG': 'G', '5CM': 'C', + '5CS': 'C', '5FA': 'A', '5FC': 'C', '5FU': 'U', '5HP': 'E', '5HT': 'T', + '5HU': 'U', '5IC': 'C', '5IT': 'T', '5IU': 'U', '5MC': 'C', '5MD': 'N', + '5MU': 'U', '5NC': 'C', '5PC': 'C', '5PY': 'T', '5SE': 'U', '64T': 'T', + '6CL': 'K', '6CT': 'T', '6CW': 'W', '6HA': 'A', '6HC': 'C', '6HG': 'G', + '6HN': 'K', '6HT': 'T', '6IA': 'A', '6MA': 'A', '6MC': 'A', '6MI': 'N', + '6MT': 'A', '6MZ': 'N', '6OG': 'G', '70U': 'U', '7DA': 'A', '7GU': 'G', + '7JA': 'I', '7MG': 'G', '8AN': 'A', '8FG': 'G', '8MG': 'G', '8OG': 'G', + '9NE': 'E', '9NF': 'F', '9NR': 'R', '9NV': 'V', 'A': 'A', 'A1P': 'N', + 'A23': 'A', 'A2L': 'A', 'A2M': 'A', 'A34': 'A', 'A35': 'A', 'A38': 'A', + 'A39': 'A', 'A3A': 'A', 'A3P': 'A', 'A40': 'A', 'A43': 'A', 'A44': 'A', + 'A47': 'A', 'A5L': 'A', 'A5M': 'C', 'A5N': 'N', 'A5O': 'A', 'A66': 'X', + 'AA3': 'A', 'AA4': 'A', 'AAR': 'R', 'AB7': 'X', 'ABA': 'A', 'ABR': 'A', + 'ABS': 'A', 'ABT': 'N', 'ACB': 'D', 'ACL': 'R', 'AD2': 'A', 'ADD': 'X', + 'ADX': 'N', 'AEA': 'X', 'AEI': 'D', 'AET': 'A', 'AFA': 'N', 'AFF': 'N', + 'AFG': 'G', 'AGM': 'R', 'AGT': 'C', 'AHB': 'N', 'AHH': 'X', 'AHO': 'A', + 'AHP': 'A', 'AHS': 'X', 'AHT': 'X', 'AIB': 'A', 'AKL': 'D', 'AKZ': 'D', + 'ALA': 'A', 'ALC': 'A', 'ALM': 'A', 'ALN': 'A', 'ALO': 'T', 'ALQ': 'X', + 'ALS': 'A', 'ALT': 'A', 'ALV': 'A', 'ALY': 'K', 'AN8': 'A', 'AP7': 'A', + 'APE': 'X', 'APH': 'A', 'API': 'K', 'APK': 'K', 'APM': 'X', 'APP': 'X', + 'AR2': 'R', 'AR4': 'E', 
'AR7': 'R', 'ARG': 'R', 'ARM': 'R', 'ARO': 'R', + 'ARV': 'X', 'AS': 'A', 'AS2': 'D', 'AS9': 'X', 'ASA': 'D', 'ASB': 'D', + 'ASI': 'D', 'ASK': 'D', 'ASL': 'D', 'ASM': 'X', 'ASN': 'N', 'ASP': 'D', + 'ASQ': 'D', 'ASU': 'N', 'ASX': 'B', 'ATD': 'T', 'ATL': 'T', 'ATM': 'T', + 'AVC': 'A', 'AVN': 'X', 'AYA': 'A', 'AZK': 'K', 'AZS': 'S', 'AZY': 'Y', + 'B1F': 'F', 'B1P': 'N', 'B2A': 'A', 'B2F': 'F', 'B2I': 'I', 'B2V': 'V', + 'B3A': 'A', 'B3D': 'D', 'B3E': 'E', 'B3K': 'K', 'B3L': 'X', 'B3M': 'X', + 'B3Q': 'X', 'B3S': 'S', 'B3T': 'X', 'B3U': 'H', 'B3X': 'N', 'B3Y': 'Y', + 'BB6': 'C', 'BB7': 'C', 'BB8': 'F', 'BB9': 'C', 'BBC': 'C', 'BCS': 'C', + 'BE2': 'X', 'BFD': 'D', 'BG1': 'S', 'BGM': 'G', 'BH2': 'D', 'BHD': 'D', + 'BIF': 'F', 'BIL': 'X', 'BIU': 'I', 'BJH': 'X', 'BLE': 'L', 'BLY': 'K', + 'BMP': 'N', 'BMT': 'T', 'BNN': 'F', 'BNO': 'X', 'BOE': 'T', 'BOR': 'R', + 'BPE': 'C', 'BRU': 'U', 'BSE': 'S', 'BT5': 'N', 'BTA': 'L', 'BTC': 'C', + 'BTR': 'W', 'BUC': 'C', 'BUG': 'V', 'BVP': 'U', 'BZG': 'N', 'C': 'C', + 'C1X': 'K', 'C25': 'C', 'C2L': 'C', 'C2S': 'C', 'C31': 'C', 'C32': 'C', + 'C34': 'C', 'C36': 'C', 'C37': 'C', 'C38': 'C', 'C3Y': 'C', 'C42': 'C', + 'C43': 'C', 'C45': 'C', 'C46': 'C', 'C49': 'C', 'C4R': 'C', 'C4S': 'C', + 'C5C': 'C', 'C66': 'X', 'C6C': 'C', 'CAF': 'C', 'CAL': 'X', 'CAR': 'C', + 'CAS': 'C', 'CAV': 'X', 'CAY': 'C', 'CB2': 'C', 'CBR': 'C', 'CBV': 'C', + 'CCC': 'C', 'CCL': 'K', 'CCS': 'C', 'CDE': 'X', 'CDV': 'X', 'CDW': 'C', + 'CEA': 'C', 'CFL': 'C', 'CG1': 'G', 'CGA': 'E', 'CGU': 'E', 'CH': 'C', + 'CHF': 'X', 'CHG': 'X', 'CHP': 'G', 'CHS': 'X', 'CIR': 'R', 'CLE': 'L', + 'CLG': 'K', 'CLH': 'K', 'CM0': 'N', 'CME': 'C', 'CMH': 'C', 'CML': 'C', + 'CMR': 'C', 'CMT': 'C', 'CNU': 'U', 'CP1': 'C', 'CPC': 'X', 'CPI': 'X', + 'CR5': 'G', 'CS0': 'C', 'CS1': 'C', 'CS3': 'C', 'CS4': 'C', 'CS8': 'N', + 'CSA': 'C', 'CSB': 'C', 'CSD': 'C', 'CSE': 'C', 'CSF': 'C', 'CSI': 'G', + 'CSJ': 'C', 'CSL': 'C', 'CSO': 'C', 'CSP': 'C', 'CSR': 'C', 'CSS': 'C', + 'CSU': 'C', 'CSW': 'C', 'CSX': 'C', 'CSZ': 'C', 'CTE': 'W', 'CTG': 'T', + 'CTH': 'T', 'CUC': 'X', 'CWR': 'S', 'CXM': 'M', 'CY0': 'C', 'CY1': 'C', + 'CY3': 'C', 'CY4': 'C', 'CYA': 'C', 'CYD': 'C', 'CYF': 'C', 'CYG': 'C', + 'CYJ': 'X', 'CYM': 'C', 'CYQ': 'C', 'CYR': 'C', 'CYS': 'C', 'CZ2': 'C', + 'CZZ': 'C', 'D11': 'T', 'D1P': 'N', 'D3': 'N', 'D33': 'N', 'D3P': 'G', + 'D3T': 'T', 'D4M': 'T', 'D4P': 'X', 'DA': 'A', 'DA2': 'X', 'DAB': 'A', + 'DAH': 'F', 'DAL': 'A', 'DAR': 'R', 'DAS': 'D', 'DBB': 'T', 'DBM': 'N', + 'DBS': 'S', 'DBU': 'T', 'DBY': 'Y', 'DBZ': 'A', 'DC': 'C', 'DC2': 'C', + 'DCG': 'G', 'DCI': 'X', 'DCL': 'X', 'DCT': 'C', 'DCY': 'C', 'DDE': 'H', + 'DDG': 'G', 'DDN': 'U', 'DDX': 'N', 'DFC': 'C', 'DFG': 'G', 'DFI': 'X', + 'DFO': 'X', 'DFT': 'N', 'DG': 'G', 'DGH': 'G', 'DGI': 'G', 'DGL': 'E', + 'DGN': 'Q', 'DHA': 'S', 'DHI': 'H', 'DHL': 'X', 'DHN': 'V', 'DHP': 'X', + 'DHU': 'U', 'DHV': 'V', 'DI': 'I', 'DIL': 'I', 'DIR': 'R', 'DIV': 'V', + 'DLE': 'L', 'DLS': 'K', 'DLY': 'K', 'DM0': 'K', 'DMH': 'N', 'DMK': 'D', + 'DMT': 'X', 'DN': 'N', 'DNE': 'L', 'DNG': 'L', 'DNL': 'K', 'DNM': 'L', + 'DNP': 'A', 'DNR': 'C', 'DNS': 'K', 'DOA': 'X', 'DOC': 'C', 'DOH': 'D', + 'DON': 'L', 'DPB': 'T', 'DPH': 'F', 'DPL': 'P', 'DPP': 'A', 'DPQ': 'Y', + 'DPR': 'P', 'DPY': 'N', 'DRM': 'U', 'DRP': 'N', 'DRT': 'T', 'DRZ': 'N', + 'DSE': 'S', 'DSG': 'N', 'DSN': 'S', 'DSP': 'D', 'DT': 'T', 'DTH': 'T', + 'DTR': 'W', 'DTY': 'Y', 'DU': 'U', 'DVA': 'V', 'DXD': 'N', 'DXN': 'N', + 'DYS': 'C', 'DZM': 'A', 'E': 'A', 'E1X': 'A', 'ECC': 'Q', 'EDA': 'A', + 'EFC': 'C', 'EHP': 'F', 'EIT': 'T', 
'ENP': 'N', 'ESB': 'Y', 'ESC': 'M', + 'EXB': 'X', 'EXY': 'L', 'EY5': 'N', 'EYS': 'X', 'F2F': 'F', 'FA2': 'A', + 'FA5': 'N', 'FAG': 'N', 'FAI': 'N', 'FB5': 'A', 'FB6': 'A', 'FCL': 'F', + 'FFD': 'N', 'FGA': 'E', 'FGL': 'G', 'FGP': 'S', 'FHL': 'X', 'FHO': 'K', + 'FHU': 'U', 'FLA': 'A', 'FLE': 'L', 'FLT': 'Y', 'FME': 'M', 'FMG': 'G', + 'FMU': 'N', 'FOE': 'C', 'FOX': 'G', 'FP9': 'P', 'FPA': 'F', 'FRD': 'X', + 'FT6': 'W', 'FTR': 'W', 'FTY': 'Y', 'FVA': 'V', 'FZN': 'K', 'G': 'G', + 'G25': 'G', 'G2L': 'G', 'G2S': 'G', 'G31': 'G', 'G32': 'G', 'G33': 'G', + 'G36': 'G', 'G38': 'G', 'G42': 'G', 'G46': 'G', 'G47': 'G', 'G48': 'G', + 'G49': 'G', 'G4P': 'N', 'G7M': 'G', 'GAO': 'G', 'GAU': 'E', 'GCK': 'C', + 'GCM': 'X', 'GDP': 'G', 'GDR': 'G', 'GFL': 'G', 'GGL': 'E', 'GH3': 'G', + 'GHG': 'Q', 'GHP': 'G', 'GL3': 'G', 'GLH': 'Q', 'GLJ': 'E', 'GLK': 'E', + 'GLM': 'X', 'GLN': 'Q', 'GLQ': 'E', 'GLU': 'E', 'GLX': 'Z', 'GLY': 'G', + 'GLZ': 'G', 'GMA': 'E', 'GMS': 'G', 'GMU': 'U', 'GN7': 'G', 'GND': 'X', + 'GNE': 'N', 'GOM': 'G', 'GPL': 'K', 'GS': 'G', 'GSC': 'G', 'GSR': 'G', + 'GSS': 'G', 'GSU': 'E', 'GT9': 'C', 'GTP': 'G', 'GVL': 'X', 'H2U': 'U', + 'H5M': 'P', 'HAC': 'A', 'HAR': 'R', 'HBN': 'H', 'HCS': 'X', 'HDP': 'U', + 'HEU': 'U', 'HFA': 'X', 'HGL': 'X', 'HHI': 'H', 'HIA': 'H', 'HIC': 'H', + 'HIP': 'H', 'HIQ': 'H', 'HIS': 'H', 'HL2': 'L', 'HLU': 'L', 'HMR': 'R', + 'HOL': 'N', 'HPC': 'F', 'HPE': 'F', 'HPH': 'F', 'HPQ': 'F', 'HQA': 'A', + 'HRG': 'R', 'HRP': 'W', 'HS8': 'H', 'HS9': 'H', 'HSE': 'S', 'HSL': 'S', + 'HSO': 'H', 'HTI': 'C', 'HTN': 'N', 'HTR': 'W', 'HV5': 'A', 'HVA': 'V', + 'HY3': 'P', 'HYP': 'P', 'HZP': 'P', 'I': 'I', 'I2M': 'I', 'I58': 'K', + 'I5C': 'C', 'IAM': 'A', 'IAR': 'R', 'IAS': 'D', 'IC': 'C', 'IEL': 'K', + 'IG': 'G', 'IGL': 'G', 'IGU': 'G', 'IIL': 'I', 'ILE': 'I', 'ILG': 'E', + 'ILX': 'I', 'IMC': 'C', 'IML': 'I', 'IOY': 'F', 'IPG': 'G', 'IPN': 'N', + 'IRN': 'N', 'IT1': 'K', 'IU': 'U', 'IYR': 'Y', 'IYT': 'T', 'IZO': 'M', + 'JJJ': 'C', 'JJK': 'C', 'JJL': 'C', 'JW5': 'N', 'K1R': 'C', 'KAG': 'G', + 'KCX': 'K', 'KGC': 'K', 'KNB': 'A', 'KOR': 'M', 'KPI': 'K', 'KST': 'K', + 'KYQ': 'K', 'L2A': 'X', 'LA2': 'K', 'LAA': 'D', 'LAL': 'A', 'LBY': 'K', + 'LC': 'C', 'LCA': 'A', 'LCC': 'N', 'LCG': 'G', 'LCH': 'N', 'LCK': 'K', + 'LCX': 'K', 'LDH': 'K', 'LED': 'L', 'LEF': 'L', 'LEH': 'L', 'LEI': 'V', + 'LEM': 'L', 'LEN': 'L', 'LET': 'X', 'LEU': 'L', 'LEX': 'L', 'LG': 'G', + 'LGP': 'G', 'LHC': 'X', 'LHU': 'U', 'LKC': 'N', 'LLP': 'K', 'LLY': 'K', + 'LME': 'E', 'LMF': 'K', 'LMQ': 'Q', 'LMS': 'N', 'LP6': 'K', 'LPD': 'P', + 'LPG': 'G', 'LPL': 'X', 'LPS': 'S', 'LSO': 'X', 'LTA': 'X', 'LTR': 'W', + 'LVG': 'G', 'LVN': 'V', 'LYF': 'K', 'LYK': 'K', 'LYM': 'K', 'LYN': 'K', + 'LYR': 'K', 'LYS': 'K', 'LYX': 'K', 'LYZ': 'K', 'M0H': 'C', 'M1G': 'G', + 'M2G': 'G', 'M2L': 'K', 'M2S': 'M', 'M30': 'G', 'M3L': 'K', 'M5M': 'C', + 'MA': 'A', 'MA6': 'A', 'MA7': 'A', 'MAA': 'A', 'MAD': 'A', 'MAI': 'R', + 'MBQ': 'Y', 'MBZ': 'N', 'MC1': 'S', 'MCG': 'X', 'MCL': 'K', 'MCS': 'C', + 'MCY': 'C', 'MD3': 'C', 'MD6': 'G', 'MDH': 'X', 'MDR': 'N', 'MEA': 'F', + 'MED': 'M', 'MEG': 'E', 'MEN': 'N', 'MEP': 'U', 'MEQ': 'Q', 'MET': 'M', + 'MEU': 'G', 'MF3': 'X', 'MG1': 'G', 'MGG': 'R', 'MGN': 'Q', 'MGQ': 'A', + 'MGV': 'G', 'MGY': 'G', 'MHL': 'L', 'MHO': 'M', 'MHS': 'H', 'MIA': 'A', + 'MIS': 'S', 'MK8': 'L', 'ML3': 'K', 'MLE': 'L', 'MLL': 'L', 'MLY': 'K', + 'MLZ': 'K', 'MME': 'M', 'MMO': 'R', 'MMT': 'T', 'MND': 'N', 'MNL': 'L', + 'MNU': 'U', 'MNV': 'V', 'MOD': 'X', 'MP8': 'P', 'MPH': 'X', 'MPJ': 'X', + 'MPQ': 'G', 'MRG': 'G', 'MSA': 'G', 'MSE': 'M', 
'MSL': 'M', 'MSO': 'M', + 'MSP': 'X', 'MT2': 'M', 'MTR': 'T', 'MTU': 'A', 'MTY': 'Y', 'MVA': 'V', + 'N': 'N', 'N10': 'S', 'N2C': 'X', 'N5I': 'N', 'N5M': 'C', 'N6G': 'G', + 'N7P': 'P', 'NA8': 'A', 'NAL': 'A', 'NAM': 'A', 'NB8': 'N', 'NBQ': 'Y', + 'NC1': 'S', 'NCB': 'A', 'NCX': 'N', 'NCY': 'X', 'NDF': 'F', 'NDN': 'U', + 'NEM': 'H', 'NEP': 'H', 'NF2': 'N', 'NFA': 'F', 'NHL': 'E', 'NIT': 'X', + 'NIY': 'Y', 'NLE': 'L', 'NLN': 'L', 'NLO': 'L', 'NLP': 'L', 'NLQ': 'Q', + 'NMC': 'G', 'NMM': 'R', 'NMS': 'T', 'NMT': 'T', 'NNH': 'R', 'NP3': 'N', + 'NPH': 'C', 'NPI': 'A', 'NSK': 'X', 'NTY': 'Y', 'NVA': 'V', 'NYM': 'N', + 'NYS': 'C', 'NZH': 'H', 'O12': 'X', 'O2C': 'N', 'O2G': 'G', 'OAD': 'N', + 'OAS': 'S', 'OBF': 'X', 'OBS': 'X', 'OCS': 'C', 'OCY': 'C', 'ODP': 'N', + 'OHI': 'H', 'OHS': 'D', 'OIC': 'X', 'OIP': 'I', 'OLE': 'X', 'OLT': 'T', + 'OLZ': 'S', 'OMC': 'C', 'OMG': 'G', 'OMT': 'M', 'OMU': 'U', 'ONE': 'U', + 'ONH': 'A', 'ONL': 'X', 'OPR': 'R', 'ORN': 'A', 'ORQ': 'R', 'OSE': 'S', + 'OTB': 'X', 'OTH': 'T', 'OTY': 'Y', 'OXX': 'D', 'P': 'G', 'P1L': 'C', + 'P1P': 'N', 'P2T': 'T', 'P2U': 'U', 'P2Y': 'P', 'P5P': 'A', 'PAQ': 'Y', + 'PAS': 'D', 'PAT': 'W', 'PAU': 'A', 'PBB': 'C', 'PBF': 'F', 'PBT': 'N', + 'PCA': 'E', 'PCC': 'P', 'PCE': 'X', 'PCS': 'F', 'PDL': 'X', 'PDU': 'U', + 'PEC': 'C', 'PF5': 'F', 'PFF': 'F', 'PFX': 'X', 'PG1': 'S', 'PG7': 'G', + 'PG9': 'G', 'PGL': 'X', 'PGN': 'G', 'PGP': 'G', 'PGY': 'G', 'PHA': 'F', + 'PHD': 'D', 'PHE': 'F', 'PHI': 'F', 'PHL': 'F', 'PHM': 'F', 'PIV': 'X', + 'PLE': 'L', 'PM3': 'F', 'PMT': 'C', 'POM': 'P', 'PPN': 'F', 'PPU': 'A', + 'PPW': 'G', 'PQ1': 'N', 'PR3': 'C', 'PR5': 'A', 'PR9': 'P', 'PRN': 'A', + 'PRO': 'P', 'PRS': 'P', 'PSA': 'F', 'PSH': 'H', 'PST': 'T', 'PSU': 'U', + 'PSW': 'C', 'PTA': 'X', 'PTH': 'Y', 'PTM': 'Y', 'PTR': 'Y', 'PU': 'A', + 'PUY': 'N', 'PVH': 'H', 'PVL': 'X', 'PYA': 'A', 'PYO': 'U', 'PYX': 'C', + 'PYY': 'N', 'QMM': 'Q', 'QPA': 'C', 'QPH': 'F', 'QUO': 'G', 'R': 'A', + 'R1A': 'C', 'R4K': 'W', 'RE0': 'W', 'RE3': 'W', 'RIA': 'A', 'RMP': 'A', + 'RON': 'X', 'RT': 'T', 'RTP': 'N', 'S1H': 'S', 'S2C': 'C', 'S2D': 'A', + 'S2M': 'T', 'S2P': 'A', 'S4A': 'A', 'S4C': 'C', 'S4G': 'G', 'S4U': 'U', + 'S6G': 'G', 'SAC': 'S', 'SAH': 'C', 'SAR': 'G', 'SBL': 'S', 'SC': 'C', + 'SCH': 'C', 'SCS': 'C', 'SCY': 'C', 'SD2': 'X', 'SDG': 'G', 'SDP': 'S', + 'SEB': 'S', 'SEC': 'A', 'SEG': 'A', 'SEL': 'S', 'SEM': 'S', 'SEN': 'S', + 'SEP': 'S', 'SER': 'S', 'SET': 'S', 'SGB': 'S', 'SHC': 'C', 'SHP': 'G', + 'SHR': 'K', 'SIB': 'C', 'SLA': 'P', 'SLR': 'P', 'SLZ': 'K', 'SMC': 'C', + 'SME': 'M', 'SMF': 'F', 'SMP': 'A', 'SMT': 'T', 'SNC': 'C', 'SNN': 'N', + 'SOC': 'C', 'SOS': 'N', 'SOY': 'S', 'SPT': 'T', 'SRA': 'A', 'SSU': 'U', + 'STY': 'Y', 'SUB': 'X', 'SUN': 'S', 'SUR': 'U', 'SVA': 'S', 'SVV': 'S', + 'SVW': 'S', 'SVX': 'S', 'SVY': 'S', 'SVZ': 'X', 'SYS': 'C', 'T': 'T', + 'T11': 'F', 'T23': 'T', 'T2S': 'T', 'T2T': 'N', 'T31': 'U', 'T32': 'T', + 'T36': 'T', 'T37': 'T', 'T38': 'T', 'T39': 'T', 'T3P': 'T', 'T41': 'T', + 'T48': 'T', 'T49': 'T', 'T4S': 'T', 'T5O': 'U', 'T5S': 'T', 'T66': 'X', + 'T6A': 'A', 'TA3': 'T', 'TA4': 'X', 'TAF': 'T', 'TAL': 'N', 'TAV': 'D', + 'TBG': 'V', 'TBM': 'T', 'TC1': 'C', 'TCP': 'T', 'TCQ': 'Y', 'TCR': 'W', + 'TCY': 'A', 'TDD': 'L', 'TDY': 'T', 'TFE': 'T', 'TFO': 'A', 'TFQ': 'F', + 'TFT': 'T', 'TGP': 'G', 'TH6': 'T', 'THC': 'T', 'THO': 'X', 'THR': 'T', + 'THX': 'N', 'THZ': 'R', 'TIH': 'A', 'TLB': 'N', 'TLC': 'T', 'TLN': 'U', + 'TMB': 'T', 'TMD': 'T', 'TNB': 'C', 'TNR': 'S', 'TOX': 'W', 'TP1': 'T', + 'TPC': 'C', 'TPG': 'G', 'TPH': 'X', 'TPL': 'W', 'TPO': 'T', 
'TPQ': 'Y', + 'TQI': 'W', 'TQQ': 'W', 'TRF': 'W', 'TRG': 'K', 'TRN': 'W', 'TRO': 'W', + 'TRP': 'W', 'TRQ': 'W', 'TRW': 'W', 'TRX': 'W', 'TS': 'N', 'TST': 'X', + 'TT': 'N', 'TTD': 'T', 'TTI': 'U', 'TTM': 'T', 'TTQ': 'W', 'TTS': 'Y', + 'TY1': 'Y', 'TY2': 'Y', 'TY3': 'Y', 'TY5': 'Y', 'TYB': 'Y', 'TYI': 'Y', + 'TYJ': 'Y', 'TYN': 'Y', 'TYO': 'Y', 'TYQ': 'Y', 'TYR': 'Y', 'TYS': 'Y', + 'TYT': 'Y', 'TYU': 'N', 'TYW': 'Y', 'TYX': 'X', 'TYY': 'Y', 'TZB': 'X', + 'TZO': 'X', 'U': 'U', 'U25': 'U', 'U2L': 'U', 'U2N': 'U', 'U2P': 'U', + 'U31': 'U', 'U33': 'U', 'U34': 'U', 'U36': 'U', 'U37': 'U', 'U8U': 'U', + 'UAR': 'U', 'UCL': 'U', 'UD5': 'U', 'UDP': 'N', 'UFP': 'N', 'UFR': 'U', + 'UFT': 'U', 'UMA': 'A', 'UMP': 'U', 'UMS': 'U', 'UN1': 'X', 'UN2': 'X', + 'UNK': 'X', 'UR3': 'U', 'URD': 'U', 'US1': 'U', 'US2': 'U', 'US3': 'T', + 'US5': 'U', 'USM': 'U', 'VAD': 'V', 'VAF': 'V', 'VAL': 'V', 'VB1': 'K', + 'VDL': 'X', 'VLL': 'X', 'VLM': 'X', 'VMS': 'X', 'VOL': 'X', 'X': 'G', + 'X2W': 'E', 'X4A': 'N', 'XAD': 'A', 'XAE': 'N', 'XAL': 'A', 'XAR': 'N', + 'XCL': 'C', 'XCN': 'C', 'XCP': 'X', 'XCR': 'C', 'XCS': 'N', 'XCT': 'C', + 'XCY': 'C', 'XGA': 'N', 'XGL': 'G', 'XGR': 'G', 'XGU': 'G', 'XPR': 'P', + 'XSN': 'N', 'XTH': 'T', 'XTL': 'T', 'XTR': 'T', 'XTS': 'G', 'XTY': 'N', + 'XUA': 'A', 'XUG': 'G', 'XX1': 'K', 'Y': 'A', 'YCM': 'C', 'YG': 'G', + 'YOF': 'Y', 'YRR': 'N', 'YYG': 'G', 'Z': 'C', 'Z01': 'A', 'ZAD': 'A', + 'ZAL': 'A', 'ZBC': 'C', 'ZBU': 'U', 'ZCL': 'F', 'ZCY': 'C', 'ZDU': 'U', + 'ZFB': 'X', 'ZGU': 'G', 'ZHP': 'N', 'ZTH': 'T', 'ZU0': 'T', 'ZZJ': 'A', +} +# common_typos_enable +# pyformat: enable + + +@functools.lru_cache(maxsize=64) +def letters_three_to_one(restype: str, *, default: str) -> str: + """Returns single letter name if one exists otherwise returns default.""" + return CCD_NAME_TO_ONE_LETTER.get(restype, default) + + +ALA = sys.intern('ALA') +ARG = sys.intern('ARG') +ASN = sys.intern('ASN') +ASP = sys.intern('ASP') +CYS = sys.intern('CYS') +GLN = sys.intern('GLN') +GLU = sys.intern('GLU') +GLY = sys.intern('GLY') +HIS = sys.intern('HIS') +ILE = sys.intern('ILE') +LEU = sys.intern('LEU') +LYS = sys.intern('LYS') +MET = sys.intern('MET') +PHE = sys.intern('PHE') +PRO = sys.intern('PRO') +SER = sys.intern('SER') +THR = sys.intern('THR') +TRP = sys.intern('TRP') +TYR = sys.intern('TYR') +VAL = sys.intern('VAL') +UNK = sys.intern('UNK') +GAP = sys.intern('-') + +# Unknown ligand. +UNL = sys.intern('UNL') + +# Non-standard version of MET (with Se instead of S), but often appears in PDB. +MSE = sys.intern('MSE') + +# 20 standard protein amino acids (no unknown). +PROTEIN_TYPES: tuple[str, ...] = ( + ALA, ARG, ASN, ASP, CYS, GLN, GLU, GLY, HIS, ILE, LEU, LYS, MET, PHE, PRO, + SER, THR, TRP, TYR, VAL, +) # pyformat: disable + +# 20 standard protein amino acids plus the unknown (UNK) amino acid. +PROTEIN_TYPES_WITH_UNKNOWN: tuple[str, ...] = PROTEIN_TYPES + (UNK,) + +# This is the standard residue order when coding AA type as a number. +# Reproduce it by taking 3-letter AA codes and sorting them alphabetically. +# For legacy reasons this only refers to protein residues. + +PROTEIN_TYPES_ONE_LETTER: tuple[str, ...] = ( + 'A', 'R', 'N', 'D', 'C', 'Q', 'E', 'G', 'H', 'I', 'L', 'K', 'M', 'F', 'P', + 'S', 'T', 'W', 'Y', 'V', +) # pyformat: disable + +PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = ( + PROTEIN_TYPES_ONE_LETTER + ('X',) +) +PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP: tuple[str, ...] 
= ( + PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN + (GAP,) +) + +PROTEIN_TYPES_ONE_LETTER_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(PROTEIN_TYPES_ONE_LETTER) +} +PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN) +} + +PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP) +} + + +PROTEIN_COMMON_ONE_TO_THREE: Mapping[str, str] = { + 'A': ALA, + 'R': ARG, + 'N': ASN, + 'D': ASP, + 'C': CYS, + 'Q': GLN, + 'E': GLU, + 'G': GLY, + 'H': HIS, + 'I': ILE, + 'L': LEU, + 'K': LYS, + 'M': MET, + 'F': PHE, + 'P': PRO, + 'S': SER, + 'T': THR, + 'W': TRP, + 'Y': TYR, + 'V': VAL, +} + +PROTEIN_COMMON_THREE_TO_ONE: Mapping[str, str] = { + v: k for k, v in PROTEIN_COMMON_ONE_TO_THREE.items() +} + +A = sys.intern('A') +G = sys.intern('G') +C = sys.intern('C') +U = sys.intern('U') +T = sys.intern('T') + +DA = sys.intern('DA') +DG = sys.intern('DG') +DC = sys.intern('DC') +DT = sys.intern('DT') + +UNK_NUCLEIC_ONE_LETTER = sys.intern('N') # Unknown nucleic acid single letter. +UNK_RNA = sys.intern('N') # Unknown RNA. +UNK_DNA = sys.intern('DN') # Unknown DNA residue (differs from N). + +RNA_TYPES: tuple[str, ...] = (A, G, C, U) +DNA_TYPES: tuple[str, ...] = (DA, DG, DC, DT) + +NUCLEIC_TYPES: tuple[str, ...] = RNA_TYPES + DNA_TYPES +# Without UNK DNA. +NUCLEIC_TYPES_WITH_UNKNOWN: tuple[str, ...] = NUCLEIC_TYPES + ( + UNK_NUCLEIC_ONE_LETTER, +) +NUCLEIC_TYPES_WITH_2_UNKS: tuple[str, ...] = NUCLEIC_TYPES + ( + UNK_RNA, + UNK_DNA, +) + +RNA_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = RNA_TYPES + (UNK_RNA,) +RNA_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(RNA_TYPES_ONE_LETTER_WITH_UNKNOWN) +} + +DNA_TYPES_WITH_UNKNOWN: tuple[str, ...] = DNA_TYPES + (UNK_DNA,) +DNA_TYPES_ONE_LETTER: tuple[str, ...] = (A, G, C, T) +DNA_TYPES_ONE_LETTER_WITH_UNKNOWN: tuple[str, ...] = DNA_TYPES_ONE_LETTER + ( + UNK_NUCLEIC_ONE_LETTER, +) +DNA_TYPES_ONE_LETTER_WITH_UNKNOWN_TO_INT: Mapping[str, int] = { + r: i for i, r in enumerate(DNA_TYPES_ONE_LETTER_WITH_UNKNOWN) +} +DNA_COMMON_ONE_TO_TWO: Mapping[str, str] = { + 'A': 'DA', + 'G': 'DG', + 'C': 'DC', + 'T': 'DT', +} + +STANDARD_POLYMER_TYPES: tuple[str, ...] = PROTEIN_TYPES + NUCLEIC_TYPES +POLYMER_TYPES: tuple[str, ...] = PROTEIN_TYPES_WITH_UNKNOWN + NUCLEIC_TYPES +POLYMER_TYPES_WITH_UNKNOWN: tuple[str, ...] = ( + PROTEIN_TYPES_WITH_UNKNOWN + NUCLEIC_TYPES_WITH_UNKNOWN +) +POLYMER_TYPES_WITH_GAP: tuple[str, ...] = PROTEIN_TYPES + \ + (GAP,) + NUCLEIC_TYPES +POLYMER_TYPES_WITH_UNKNOWN_AND_GAP: tuple[str, ...] = ( + PROTEIN_TYPES_WITH_UNKNOWN + (GAP,) + NUCLEIC_TYPES_WITH_UNKNOWN +) +POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP: tuple[str, ...] = ( + PROTEIN_TYPES_WITH_UNKNOWN + (GAP,) + NUCLEIC_TYPES_WITH_2_UNKS +) + +POLYMER_TYPES_ORDER = {restype: i for i, restype in enumerate(POLYMER_TYPES)} + +POLYMER_TYPES_ORDER_WITH_UNKNOWN = { + restype: i for i, restype in enumerate(POLYMER_TYPES_WITH_UNKNOWN) +} + +POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP = { + restype: i for i, restype in enumerate(POLYMER_TYPES_WITH_UNKNOWN_AND_GAP) +} + +POLYMER_TYPES_ORDER_WITH_ALL_UNKS_AND_GAP = { + restype: i for i, restype in enumerate(POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP) +} + +POLYMER_TYPES_NUM = len(POLYMER_TYPES) # := 29. +POLYMER_TYPES_NUM_WITH_UNKNOWN = len(POLYMER_TYPES_WITH_UNKNOWN) # := 30. +POLYMER_TYPES_NUM_WITH_GAP = len(POLYMER_TYPES_WITH_GAP) # := 29. 
+POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP = len( + POLYMER_TYPES_WITH_UNKNOWN_AND_GAP +) # := 31. +POLYMER_TYPES_NUM_ORDER_WITH_ALL_UNKS_AND_GAP = len( + POLYMER_TYPES_WITH_ALL_UNKS_AND_GAP +) # := 32. + +WATER_TYPES: tuple[str, ...] = ('HOH', 'DOD') + +UNKNOWN_TYPES: tuple[str, ...] = (UNK, UNK_RNA, UNK_DNA, UNL) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/side_chains.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/side_chains.py new file mode 100644 index 000000000..0e8cd1297 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/constants/side_chains.py @@ -0,0 +1,112 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Constants associated with side chains.""" + +from collections.abc import Mapping, Sequence +import itertools + +# Format: The list for each AA type contains chi1, chi2, chi3, chi4 in +# this order (or a relevant subset from chi1 onwards). ALA and GLY don't have +# chi angles so their chi angle lists are empty. +CHI_ANGLES_ATOMS: Mapping[str, Sequence[tuple[str, ...]]] = { + 'ALA': [], + # Chi5 in arginine is always 0 +- 5 degrees, so ignore it. + 'ARG': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'CD'), + ('CB', 'CG', 'CD', 'NE'), + ('CG', 'CD', 'NE', 'CZ'), + ], + 'ASN': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'OD1')], + 'ASP': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'OD1')], + 'CYS': [('N', 'CA', 'CB', 'SG')], + 'GLN': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'CD'), + ('CB', 'CG', 'CD', 'OE1'), + ], + 'GLU': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'CD'), + ('CB', 'CG', 'CD', 'OE1'), + ], + 'GLY': [], + 'HIS': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'ND1')], + 'ILE': [('N', 'CA', 'CB', 'CG1'), ('CA', 'CB', 'CG1', 'CD1')], + 'LEU': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')], + 'LYS': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'CD'), + ('CB', 'CG', 'CD', 'CE'), + ('CG', 'CD', 'CE', 'NZ'), + ], + 'MET': [ + ('N', 'CA', 'CB', 'CG'), + ('CA', 'CB', 'CG', 'SD'), + ('CB', 'CG', 'SD', 'CE'), + ], + 'PHE': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')], + 'PRO': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD')], + 'SER': [('N', 'CA', 'CB', 'OG')], + 'THR': [('N', 'CA', 'CB', 'OG1')], + 'TRP': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')], + 'TYR': [('N', 'CA', 'CB', 'CG'), ('CA', 'CB', 'CG', 'CD1')], + 'VAL': [('N', 'CA', 'CB', 'CG1')], +} + +CHI_GROUPS_FOR_ATOM = {} +for res_name, chi_angle_atoms_for_res in CHI_ANGLES_ATOMS.items(): + for chi_group_i, chi_group in enumerate(chi_angle_atoms_for_res): + for atom_i, atom in enumerate(chi_group): + CHI_GROUPS_FOR_ATOM.setdefault((res_name, atom), []).append( + (chi_group_i, atom_i) + ) + +# Mapping from (residue_name, atom_name) pairs to the atom's chi group index +# and atom index within that group. 
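+# For example, CHI_GROUPS_FOR_ATOM[('ARG', 'CD')] == [(1, 3), (2, 2), (3, 1)]:
+# CD is the 4th atom of chi2, the 3rd atom of chi3 and the 2nd atom of chi4.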
+CHI_GROUPS_FOR_ATOM: Mapping[tuple[str, str], Sequence[tuple[int, int]]] = (
+    CHI_GROUPS_FOR_ATOM
+)
+
+MAX_NUM_CHI_ANGLES: int = 4
+ATOMS_PER_CHI_ANGLE: int = 4
+
+# A list of atoms for each AA type that are involved in chi angle calculations.
+CHI_ATOM_SETS: Mapping[str, set[str]] = {
+    residue_name: set(itertools.chain(*atoms))
+    for residue_name, atoms in CHI_ANGLES_ATOMS.items()
+}
+
+# If chi angles given in fixed-length array, this matrix determines how to mask
+# them for each AA type. The rows are ordered alphabetically by three-letter
+# residue name (ALA..VAL), matching residue_names.PROTEIN_TYPES.
+CHI_ANGLES_MASK: Sequence[Sequence[float]] = (
+    (0.0, 0.0, 0.0, 0.0),  # ALA
+    (1.0, 1.0, 1.0, 1.0),  # ARG
+    (1.0, 1.0, 0.0, 0.0),  # ASN
+    (1.0, 1.0, 0.0, 0.0),  # ASP
+    (1.0, 0.0, 0.0, 0.0),  # CYS
+    (1.0, 1.0, 1.0, 0.0),  # GLN
+    (1.0, 1.0, 1.0, 0.0),  # GLU
+    (0.0, 0.0, 0.0, 0.0),  # GLY
+    (1.0, 1.0, 0.0, 0.0),  # HIS
+    (1.0, 1.0, 0.0, 0.0),  # ILE
+    (1.0, 1.0, 0.0, 0.0),  # LEU
+    (1.0, 1.0, 1.0, 1.0),  # LYS
+    (1.0, 1.0, 1.0, 0.0),  # MET
+    (1.0, 1.0, 0.0, 0.0),  # PHE
+    (1.0, 1.0, 0.0, 0.0),  # PRO
+    (1.0, 0.0, 0.0, 0.0),  # SER
+    (1.0, 0.0, 0.0, 0.0),  # THR
+    (1.0, 1.0, 0.0, 0.0),  # TRP
+    (1.0, 1.0, 0.0, 0.0),  # TYR
+    (1.0, 0.0, 0.0, 0.0),  # VAL
+)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/cpp.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/cpp.cc
new file mode 100644
index 000000000..b2286b5c3
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/cpp.cc
@@ -0,0 +1,48 @@
+/*
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+ */
+
+#include "alphafold3/data/cpp/msa_profile_pybind.h"
+#include "alphafold3/model/mkdssp_pybind.h"
+#include "alphafold3/parsers/cpp/cif_dict_pybind.h"
+#include "alphafold3/parsers/cpp/fasta_iterator_pybind.h"
+#include "alphafold3/parsers/cpp/msa_conversion_pybind.h"
+#include "alphafold3/structure/cpp/aggregation_pybind.h"
+#include "alphafold3/structure/cpp/membership_pybind.h"
+#include "alphafold3/structure/cpp/mmcif_atom_site_pybind.h"
+#include "alphafold3/structure/cpp/mmcif_layout_pybind.h"
+#include "alphafold3/structure/cpp/mmcif_struct_conn_pybind.h"
+#include "alphafold3/structure/cpp/mmcif_utils_pybind.h"
+#include "alphafold3/structure/cpp/string_array_pybind.h"
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+namespace {
+
+// Include all modules as submodules to simplify building.
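+// On the Python side each submodule is then importable as
+// alphafold3.cpp.<name>, e.g. `from alphafold3.cpp import cif_dict`.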
+PYBIND11_MODULE(cpp, m) {
+  RegisterModuleCifDict(m.def_submodule("cif_dict"));
+  RegisterModuleFastaIterator(m.def_submodule("fasta_iterator"));
+  RegisterModuleMsaConversion(m.def_submodule("msa_conversion"));
+  RegisterModuleMmcifLayout(m.def_submodule("mmcif_layout"));
+  RegisterModuleMmcifStructConn(m.def_submodule("mmcif_struct_conn"));
+  RegisterModuleMembership(m.def_submodule("membership"));
+  RegisterModuleMmcifUtils(m.def_submodule("mmcif_utils"));
+  RegisterModuleAggregation(m.def_submodule("aggregation"));
+  RegisterModuleStringArray(m.def_submodule("string_array"));
+  RegisterModuleMmcifAtomSite(m.def_submodule("mmcif_atom_site"));
+  RegisterModuleMkdssp(m.def_submodule("mkdssp"));
+  RegisterModuleMsaProfile(m.def_submodule("msa_profile"));
+}
+
+}  // namespace
+}  // namespace alphafold3
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc
new file mode 100644
index 000000000..83b86f4e2
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.cc
@@ -0,0 +1,79 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include <algorithm>
+
+#include "absl/strings/str_cat.h"
+#include "pybind11/cast.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+
+namespace {
+
+namespace py = pybind11;
+
+py::array_t<float> ComputeMsaProfile(
+    const py::array_t<int>& msa, int num_residue_types) {
+  if (msa.size() == 0) {
+    throw py::value_error("The MSA must be non-empty.");
+  }
+  if (msa.ndim() != 2) {
+    throw py::value_error(absl::StrCat("The MSA must be rectangular, got ",
+                                       msa.ndim(), "-dimensional MSA array."));
+  }
+  const int msa_depth = msa.shape()[0];
+  const int sequence_length = msa.shape()[1];
+
+  py::array_t<float> profile({sequence_length, num_residue_types});
+  std::fill(profile.mutable_data(), profile.mutable_data() + profile.size(),
+            0.0f);
+  auto profile_unchecked = profile.mutable_unchecked<2>();
+
+  const double normalized_count = 1.0 / msa_depth;
+  const int* msa_it = msa.data();
+  for (int row_index = 0; row_index < msa_depth; ++row_index) {
+    for (int column_index = 0; column_index < sequence_length; ++column_index) {
+      const int residue_code = *(msa_it++);
+      if (residue_code < 0 || residue_code >= num_residue_types) {
+        throw py::value_error(
+            absl::StrCat("All residue codes must be non-negative and smaller "
+                         "than num_residue_types ",
+                         num_residue_types, ", got ", residue_code));
+      }
+      profile_unchecked(column_index, residue_code) += normalized_count;
+    }
+  }
+  return profile;
+}
+
+constexpr char kComputeMsaProfileDoc[] = R"(
+Computes the MSA profile for the given encoded MSA.
+
+Args:
+  msa: A Numpy array of shape (num_msa, num_res) with the integer coded MSA.
+  num_residue_types: Integer that determines the number of unique residue types.
+    This will determine the shape of the output profile.
+ +Returns: + A float Numpy array of shape (num_res, num_residue_types) with residue + frequency (residue type count normalized by MSA depth) for every column of the + MSA. +)"; + +} // namespace + +namespace alphafold3 { + +void RegisterModuleMsaProfile(pybind11::module m) { + m.def("compute_msa_profile", &ComputeMsaProfile, py::arg("msa"), + py::arg("num_residue_types"), py::doc(kComputeMsaProfileDoc + 1)); +} + +} // namespace alphafold3 \ No newline at end of file diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.h new file mode 100644 index 000000000..1145d331b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/cpp/msa_profile_pybind.h @@ -0,0 +1,25 @@ +/* +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMsaProfile(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_DATA_PYTHON_MSA_PROFILE_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/featurisation.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/featurisation.py new file mode 100644 index 000000000..e0e63af55 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/featurisation.py @@ -0,0 +1,90 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""AlphaFold 3 featurisation pipeline.""" + +from collections.abc import Sequence +import datetime +import time + +from alphafold3.common import folding_input +from alphafold3.constants import chemical_components +from alphafold3.model import features +from alphafold3.model.pipeline import pipeline +import numpy as np + + +def validate_fold_input(fold_input: folding_input.Input): + """Validates the fold input contains MSA and templates for featurisation.""" + for i, chain in enumerate(fold_input.protein_chains): + if chain.unpaired_msa is None: + raise ValueError(f'Protein chain {i + 1} is missing unpaired MSA.') + if chain.paired_msa is None: + raise ValueError(f'Protein chain {i + 1} is missing paired MSA.') + if chain.templates is None: + raise ValueError(f'Protein chain {i + 1} is missing Templates.') + for i, chain in enumerate(fold_input.rna_chains): + if chain.unpaired_msa is None: + raise ValueError(f'RNA chain {i + 1} is missing unpaired MSA.') + + +def featurise_input( + fold_input: folding_input.Input, + ccd: chemical_components.Ccd, + buckets: Sequence[int] | None, + max_template_date: datetime.date | None = None, + verbose: bool = False, +) -> Sequence[features.BatchDict]: + """Featurise the folding input. + + Args: + fold_input: The input to featurise. + ccd: The chemical components dictionary. + buckets: Bucket sizes to pad the data to, to avoid excessive re-compilation + of the model. If None, calculate the appropriate bucket size from the + number of tokens. If not None, must be a sequence of at least one integer, + in strictly increasing order. Will raise an error if the number of tokens + is more than the largest bucket size. + max_template_date: Optional max template date to prevent data leakage in + validation. + verbose: Whether to print progress messages. + + Returns: + A featurised batch for each rng_seed in the input. + """ + validate_fold_input(fold_input) + + # Set up data pipeline for single use. + data_pipeline = pipeline.WholePdbPipeline( + config=pipeline.WholePdbPipeline.Config( + buckets=buckets, max_template_date=max_template_date + ), + ) + + batches = [] + for rng_seed in fold_input.rng_seeds: + featurisation_start_time = time.time() + if verbose: + print(f'Featurising {fold_input.name} with rng_seed {rng_seed}.') + batch = data_pipeline.process_item( + fold_input=fold_input, + ccd=ccd, + random_state=np.random.RandomState(rng_seed), + random_seed=rng_seed, + ) + if verbose: + print( + f'Featurising {fold_input.name} with rng_seed {rng_seed} ' + f'took {time.time() - featurisation_start_time:.2f} seconds.' + ) + batches.append(batch) + + return batches diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa.py new file mode 100644 index 000000000..b77c93e44 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa.py @@ -0,0 +1,344 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. 
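featurisation.py above is the entry point the run scripts call. A minimal usage sketch follows, assuming a fold-input JSON file exists and that folding_input.Input.from_json and chemical_components.cached_ccd behave as in the upstream AlphaFold 3 code:

from alphafold3.common import folding_input
from alphafold3.constants import chemical_components
from alphafold3.data import featurisation

with open('input.json') as f:           # Hypothetical fold input file.
    fold_input = folding_input.Input.from_json(f.read())
ccd = chemical_components.cached_ccd()  # The chemical components dictionary.
batches = featurisation.featurise_input(
    fold_input=fold_input,
    ccd=ccd,
    buckets=(256, 512, 768),  # Token counts are padded up to the next bucket.
    verbose=True,
)
print(f'{len(batches)} batch(es), one per rng seed in the input.')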
+# You may only use these if received directly from Google. Use is subject to
+# terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Functions for getting MSA and calculating alignment features."""
+
+from collections.abc import MutableMapping, Sequence
+import string
+from typing import Self
+
+from absl import logging
+from alphafold3.constants import mmcif_names
+from alphafold3.data import msa_config
+from alphafold3.data import msa_features
+from alphafold3.data import parsers
+from alphafold3.data.tools import jackhmmer
+from alphafold3.data.tools import msa_tool
+from alphafold3.data.tools import nhmmer
+import numpy as np
+
+
+class Error(Exception):
+  """Error indicating a problem with MSA search."""
+
+
+def _featurize(seq: str, chain_poly_type: str) -> str | list[int]:
+  if mmcif_names.is_standard_polymer_type(chain_poly_type):
+    featurized_seqs, _ = msa_features.extract_msa_features(
+        msa_sequences=[seq], chain_poly_type=chain_poly_type
+    )
+    return featurized_seqs[0].tolist()
+  # For anything else simply require an identical match.
+  return seq
+
+
+def sequences_are_feature_equivalent(
+    sequence1: str,
+    sequence2: str,
+    chain_poly_type: str,
+) -> bool:
+  feat1 = _featurize(sequence1, chain_poly_type)
+  feat2 = _featurize(sequence2, chain_poly_type)
+  return feat1 == feat2
+
+
+class Msa:
+  """Multiple Sequence Alignment container with methods for manipulating it."""
+
+  def __init__(
+      self,
+      query_sequence: str,
+      chain_poly_type: str,
+      sequences: Sequence[str],
+      descriptions: Sequence[str],
+      deduplicate: bool = True,
+  ):
+    """Raw constructor, prefer using the from_{a3m,multiple_msas} class methods.
+
+    The first sequence must be equal (in featurised form) to the query sequence.
+    If sequences/descriptions are empty, they will be initialised to the query.
+
+    Args:
+      query_sequence: The sequence that was used to search for MSA.
+      chain_poly_type: Polymer type of the query sequence, see mmcif_names.
+      sequences: The sequences returned by the MSA search tool.
+      descriptions: Metadata for the sequences returned by the MSA search tool.
+      deduplicate: If True, the MSA sequences will be deduplicated in the input
+        order. Lowercase letters (insertions) are ignored when deduplicating.
+    """
+    if len(sequences) != len(descriptions):
+      raise ValueError('The number of sequences and descriptions must match.')
+
+    self.query_sequence = query_sequence
+    self.chain_poly_type = chain_poly_type
+
+    if not deduplicate:
+      self.sequences = sequences
+      self.descriptions = descriptions
+    else:
+      self.sequences = []
+      self.descriptions = []
+      # A replacement table that removes all lowercase characters.
+      deletion_table = str.maketrans('', '', string.ascii_lowercase)
+      unique_sequences = set()
+      for seq, desc in zip(sequences, descriptions, strict=True):
+        # Using string.translate is faster than re.sub('[a-z]+', '').
+        sequence_no_deletions = seq.translate(deletion_table)
+        if sequence_no_deletions not in unique_sequences:
+          unique_sequences.add(sequence_no_deletions)
+          self.sequences.append(seq)
+          self.descriptions.append(desc)
+
+    # Make sure the MSA always has at least the query.
+    self.sequences = self.sequences or [query_sequence]
+    self.descriptions = self.descriptions or ['Original query']
+
+    # Check if the 1st MSA sequence matches the query sequence.
Since it may be + # mutated by the search tool (jackhmmer) check using the featurized version. + if not sequences_are_feature_equivalent( + self.sequences[0], query_sequence, chain_poly_type + ): + raise ValueError( + f'First MSA sequence {self.sequences[0]} is not the {query_sequence=}' + ) + + @classmethod + def from_multiple_msas( + cls, msas: Sequence[Self], deduplicate: bool = True + ) -> Self: + """Initializes the MSA from multiple MSAs. + + Args: + msas: A sequence of Msa objects representing individual MSAs produced by + different tools/dbs. + deduplicate: If True, the MSA sequences will be deduplicated in the input + order. Lowercase letters (insertions) are ignored when deduplicating. + + Returns: + An Msa object created by merging multiple MSAs. + """ + if not msas: + raise ValueError('At least one MSA must be provided.') + + query_sequence = msas[0].query_sequence + chain_poly_type = msas[0].chain_poly_type + sequences = [] + descriptions = [] + + for msa in msas: + if msa.query_sequence != query_sequence: + raise ValueError( + f'Query sequences must match: {[m.query_sequence for m in msas]}' + ) + if msa.chain_poly_type != chain_poly_type: + raise ValueError( + f'Chain poly types must match: {[m.chain_poly_type for m in msas]}' + ) + sequences.extend(msa.sequences) + descriptions.extend(msa.descriptions) + + return cls( + query_sequence=query_sequence, + chain_poly_type=chain_poly_type, + sequences=sequences, + descriptions=descriptions, + deduplicate=deduplicate, + ) + + @classmethod + def from_multiple_a3ms( + cls, a3ms: Sequence[str], chain_poly_type: str, deduplicate: bool = True + ) -> Self: + """Initializes the MSA from multiple A3M strings. + + Args: + a3ms: A sequence of A3M strings representing individual MSAs produced by + different tools/dbs. + chain_poly_type: Polymer type of the query sequence, see mmcif_names. + deduplicate: If True, the MSA sequences will be deduplicated in the input + order. Lowercase letters (insertions) are ignored when deduplicating. + + Returns: + An Msa object created by merging multiple A3Ms. 
+ """ + if not a3ms: + raise ValueError('At least one A3M must be provided.') + + query_sequence = None + all_sequences = [] + all_descriptions = [] + + for a3m in a3ms: + sequences, descriptions = parsers.parse_fasta(a3m) + if query_sequence is None: + query_sequence = sequences[0] + + if sequences[0] != query_sequence: + raise ValueError( + f'Query sequences must match: {sequences[0]=} != {query_sequence=}' + ) + all_sequences.extend(sequences) + all_descriptions.extend(descriptions) + + return cls( + query_sequence=query_sequence, + chain_poly_type=chain_poly_type, + sequences=all_sequences, + descriptions=all_descriptions, + deduplicate=deduplicate, + ) + + @classmethod + def from_a3m( + cls, + query_sequence: str, + chain_poly_type: str, + a3m: str, + max_depth: int | None = None, + deduplicate: bool = True, + ) -> Self: + """Parses the single A3M and builds the Msa object.""" + sequences, descriptions = parsers.parse_fasta(a3m) + + if max_depth is not None and 0 < max_depth < len(sequences): + logging.info( + 'MSA cropped from depth of %d to %d for %s.', + len(sequences), + max_depth, + query_sequence, + ) + sequences = sequences[:max_depth] + descriptions = descriptions[:max_depth] + + return cls( + query_sequence=query_sequence, + chain_poly_type=chain_poly_type, + sequences=sequences, + descriptions=descriptions, + deduplicate=deduplicate, + ) + + @classmethod + def from_empty(cls, query_sequence: str, chain_poly_type: str) -> Self: + """Creates an empty Msa containing just the query sequence.""" + return cls( + query_sequence=query_sequence, + chain_poly_type=chain_poly_type, + sequences=[], + descriptions=[], + deduplicate=False, + ) + + @property + def depth(self) -> int: + return len(self.sequences) + + def __repr__(self) -> str: + return f'Msa({self.depth} sequences, {self.chain_poly_type})' + + def to_a3m(self) -> str: + """Returns the MSA in the A3M format.""" + a3m_lines = [] + for desc, seq in zip(self.descriptions, self.sequences, strict=True): + a3m_lines.append(f'>{desc}') + a3m_lines.append(seq) + return '\n'.join(a3m_lines) + '\n' + + def featurize(self) -> MutableMapping[str, np.ndarray]: + """Featurises the MSA and returns a map of feature names to features. + + Returns: + A dictionary mapping feature names to values. + + Raises: + msa.Error: + * If the sequences in the MSA don't have the same length after deletions + (lower case letters) are removed. + * If the MSA contains an unknown amino acid code. + * If there are no sequences after aligning. 
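Before the featurize implementation below, a toy example of the Msa container, assuming only the modules in this patch; the lowercase insertion is ignored during deduplication and shows up in the deletion matrix instead:

from alphafold3.constants import mmcif_names
from alphafold3.data import msa as msa_lib

# A toy A3M: the query plus one hit with an insertion ('a') and a gap ('-').
toy = msa_lib.Msa.from_a3m(
    query_sequence='MKLV',
    chain_poly_type=mmcif_names.PROTEIN_CHAIN,
    a3m='>query\nMKLV\n>hit_1\nMKaL-\n',
)
print(toy)                              # Msa(2 sequences, polypeptide(L))
feats = toy.featurize()
print(feats['msa'].shape)               # (2, 4): insertions are not columns.
print(feats['deletion_matrix_int'][1])  # [0 0 1 0]: one insertion before 'L'.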
+ """ + try: + msa, deletion_matrix = msa_features.extract_msa_features( + msa_sequences=self.sequences, chain_poly_type=self.chain_poly_type + ) + except ValueError as e: + raise Error(f'Error extracting MSA or deletion features: {e}') from e + + if msa.shape == (0, 0): + raise Error(f'Empty MSA feature for {self}') + + species_ids = msa_features.extract_species_ids(self.descriptions) + + return { + 'msa_species_identifiers': np.array(species_ids, dtype=object), + 'num_alignments': np.array(self.depth, dtype=np.int32), + 'msa': msa, + 'deletion_matrix_int': deletion_matrix, + } + + +def get_msa_tool( + msa_tool_config: msa_config.JackhmmerConfig | msa_config.NhmmerConfig, +) -> msa_tool.MsaTool: + """Returns the requested MSA tool.""" + + match msa_tool_config: + case msa_config.JackhmmerConfig(): + return jackhmmer.Jackhmmer( + binary_path=msa_tool_config.binary_path, + database_path=msa_tool_config.database_config.path, + n_cpu=msa_tool_config.n_cpu, + n_iter=msa_tool_config.n_iter, + e_value=msa_tool_config.e_value, + z_value=msa_tool_config.z_value, + max_sequences=msa_tool_config.max_sequences, + ) + case msa_config.NhmmerConfig(): + return nhmmer.Nhmmer( + binary_path=msa_tool_config.binary_path, + hmmalign_binary_path=msa_tool_config.hmmalign_binary_path, + hmmbuild_binary_path=msa_tool_config.hmmbuild_binary_path, + database_path=msa_tool_config.database_config.path, + n_cpu=msa_tool_config.n_cpu, + e_value=msa_tool_config.e_value, + max_sequences=msa_tool_config.max_sequences, + alphabet=msa_tool_config.alphabet, + ) + case _: + raise ValueError(f'Unknown MSA tool: {msa_tool_config}.') + + +def get_msa( + target_sequence: str, + run_config: msa_config.RunConfig, + chain_poly_type: str, + deduplicate: bool = False, +) -> Msa: + """Computes the MSA for a given query sequence. + + Args: + target_sequence: The target amino-acid sequence. + run_config: MSA run configuration. + chain_poly_type: The type of chain for which to get an MSA. + deduplicate: If True, the MSA sequences will be deduplicated in the input + order. Lowercase letters (insertions) are ignored when deduplicating. + + Returns: + Aligned MSA sequences. + """ + + return Msa.from_a3m( + query_sequence=target_sequence, + chain_poly_type=chain_poly_type, + a3m=get_msa_tool(run_config.config).query(target_sequence).a3m, + max_depth=run_config.crop_size, + deduplicate=deduplicate, + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_config.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_config.py new file mode 100644 index 000000000..c195e1c3d --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_config.py @@ -0,0 +1,168 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
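To make the get_msa flow above concrete, here is a sketch that wires a Jackhmmer RunConfig (the msa_config dataclasses are defined in the next file) into get_msa; the binary and database paths are placeholders, not part of the patch:

from alphafold3.constants import mmcif_names
from alphafold3.data import msa as msa_lib
from alphafold3.data import msa_config

run_config = msa_config.RunConfig(
    config=msa_config.JackhmmerConfig(
        binary_path='/usr/bin/jackhmmer',          # Placeholder path.
        database_config=msa_config.DatabaseConfig(
            name='uniref90',
            path='/data/uniref90/uniref90.fasta',  # Placeholder path.
        ),
        n_cpu=8,
        n_iter=1,
        e_value=1e-4,
        z_value=None,
        max_sequences=10_000,
    ),
    chain_poly_type=mmcif_names.PROTEIN_CHAIN,
    crop_size=None,
)
protein_msa = msa_lib.get_msa(
    target_sequence='MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ',
    run_config=run_config,
    chain_poly_type=mmcif_names.PROTEIN_CHAIN,
)
print(protein_msa.depth)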
+# Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Genetic search config settings for data pipelines."""
+
+import dataclasses
+import datetime
+from typing import Self
+from alphafold3.constants import mmcif_names
+
+
+def _validate_chain_poly_type(chain_poly_type: str) -> None:
+  if chain_poly_type not in mmcif_names.STANDARD_POLYMER_CHAIN_TYPES:
+    raise ValueError(
+        'chain_poly_type must be one of'
+        f' {mmcif_names.STANDARD_POLYMER_CHAIN_TYPES}: {chain_poly_type}'
+    )
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
+class DatabaseConfig:
+  """Configuration for a database."""
+
+  name: str
+  path: str
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
+class JackhmmerConfig:
+  """Configuration for a jackhmmer run.
+
+  Attributes:
+    binary_path: Path to the binary of the msa tool.
+    database_config: Database configuration.
+    n_cpu: An integer with the number of CPUs to use.
+    n_iter: An integer with the number of database search iterations.
+    e_value: e-value for the database lookup.
+    z_value: The Z-value representing the number of comparisons done (i.e. the
+      correct database size) for E-value calculation.
+    max_sequences: Max sequences to return in MSA.
+  """
+
+  binary_path: str
+  database_config: DatabaseConfig
+  n_cpu: int
+  n_iter: int
+  e_value: float
+  z_value: float | int | None
+  max_sequences: int
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
+class NhmmerConfig:
+  """Configuration for a nhmmer run.
+
+  Attributes:
+    binary_path: Path to the binary of the msa tool.
+    hmmalign_binary_path: Path to the hmmalign binary.
+    hmmbuild_binary_path: Path to the hmmbuild binary.
+    database_config: Database configuration.
+    n_cpu: An integer with the number of CPUs to use.
+    e_value: e-value for the database lookup.
+    max_sequences: Max sequences to return in MSA.
+    alphabet: The alphabet when building a profile with hmmbuild.
+  """
+
+  binary_path: str
+  hmmalign_binary_path: str
+  hmmbuild_binary_path: str
+  database_config: DatabaseConfig
+  n_cpu: int
+  e_value: float
+  max_sequences: int
+  alphabet: str | None
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True, slots=True)
+class RunConfig:
+  """Configuration for an MSA run.
+
+  Attributes:
+    config: MSA tool config.
+    chain_poly_type: The chain type for which the tools will be run.
+    crop_size: The maximum number of sequences to keep in the MSA. If None, all
+      sequences are kept. Note that the query is included in the MSA, so it
+      doesn't make sense to set this to less than 2.
+ """ + + config: JackhmmerConfig | NhmmerConfig + chain_poly_type: str + crop_size: int | None + + def __post_init__(self): + if self.crop_size is not None and self.crop_size < 2: + raise ValueError(f'crop_size must be None or >= 2: {self.crop_size}') + + _validate_chain_poly_type(self.chain_poly_type) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class HmmsearchConfig: + """Configuration for a hmmsearch.""" + + hmmsearch_binary_path: str + hmmbuild_binary_path: str + + e_value: float + inc_e: float + dom_e: float + incdom_e: float + alphabet: str = 'amino' + filter_f1: float | None = None + filter_f2: float | None = None + filter_f3: float | None = None + filter_max: bool = False + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class TemplateToolConfig: + """Configuration for a template tool.""" + + database_path: str + chain_poly_type: str + hmmsearch_config: HmmsearchConfig + max_a3m_query_sequences: int | None = 300 + + def __post_init__(self): + _validate_chain_poly_type(self.chain_poly_type) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class TemplateFilterConfig: + """Configuration for a template filter.""" + + max_subsequence_ratio: float | None + min_align_ratio: float | None + min_hit_length: int | None + deduplicate_sequences: bool + max_hits: int | None + max_template_date: datetime.date + + @classmethod + def no_op_filter(cls) -> Self: + """Returns a config for filter that keeps everything.""" + return cls( + max_subsequence_ratio=None, + min_align_ratio=None, + min_hit_length=None, + deduplicate_sequences=False, + max_hits=None, + max_template_date=datetime.date(3000, 1, 1), # Very far in the future. + ) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class TemplatesConfig: + """Configuration for the template search pipeline.""" + + template_tool_config: TemplateToolConfig + filter_config: TemplateFilterConfig diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_features.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_features.py new file mode 100644 index 000000000..b3eaca976 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_features.py @@ -0,0 +1,203 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Utilities for computing MSA features.""" + +from collections.abc import Sequence +import re +from alphafold3.constants import mmcif_names +import numpy as np + +_PROTEIN_TO_ID = { + 'A': 0, + 'B': 3, # Same as D. + 'C': 4, + 'D': 3, + 'E': 6, + 'F': 13, + 'G': 7, + 'H': 8, + 'I': 9, + 'J': 20, # Same as unknown (X). + 'K': 11, + 'L': 10, + 'M': 12, + 'N': 2, + 'O': 20, # Same as unknown (X). + 'P': 14, + 'Q': 5, + 'R': 1, + 'S': 15, + 'T': 16, + 'U': 4, # Same as C. + 'V': 19, + 'W': 17, + 'X': 20, + 'Y': 18, + 'Z': 6, # Same as E. 
+    '-': 21,
+}
+
+_RNA_TO_ID = {
+    # Map non-standard residues to UNK_NUCLEIC (N) -> 30.
+    **{chr(i): 30 for i in range(ord('A'), ord('Z') + 1)},
+    # Continue the RNA indices from where the protein indices left off.
+    '-': 21,
+    'A': 22,
+    'G': 23,
+    'C': 24,
+    'U': 25,
+}
+
+_DNA_TO_ID = {
+    # Map non-standard residues to UNK_NUCLEIC (N) -> 30.
+    **{chr(i): 30 for i in range(ord('A'), ord('Z') + 1)},
+    # Continue the DNA indices from where the RNA indices left off.
+    '-': 21,
+    'A': 26,
+    'G': 27,
+    'C': 28,
+    'T': 29,
+}
+
+
+def extract_msa_features(
+    msa_sequences: Sequence[str], chain_poly_type: str
+) -> tuple[np.ndarray, np.ndarray]:
+  """Extracts MSA features.
+
+  Example:
+    The input raw MSA is: `[["AAAAAA"], ["Ai-CiDiiiEFa"]]`
+    The output MSA will be: `[["AAAAAA"], ["A-CDEF"]]`
+    The deletions will be: `[[0, 0, 0, 0, 0, 0], [0, 1, 0, 1, 3, 0]]`
+
+  Args:
+    msa_sequences: A list of strings, each string with one MSA sequence. Each
+      string must have the same, constant number of non-lowercase (matching)
+      residues.
+    chain_poly_type: Either 'polypeptide(L)' (protein), 'polyribonucleotide'
+      (RNA), or 'polydeoxyribonucleotide' (DNA). Use the appropriate string
+      constant from mmcif_names.py.
+
+  Returns:
+    A tuple with:
+    * MSA array of shape (num_seq, num_res) that contains only the uppercase
+      characters or gaps (-) from the original MSA.
+    * Deletions array of shape (num_seq, num_res) that contains the number
+      of deletions (lowercase letters in the MSA) to the left from each
+      non-deleted residue (uppercase letters in the MSA).
+
+  Raises:
+    ValueError if any of the preconditions are not met.
+  """
+
+  # Select the appropriate character map based on the chain type.
+  if chain_poly_type == mmcif_names.RNA_CHAIN:
+    char_map = _RNA_TO_ID
+  elif chain_poly_type == mmcif_names.DNA_CHAIN:
+    char_map = _DNA_TO_ID
+  elif chain_poly_type == mmcif_names.PROTEIN_CHAIN:
+    char_map = _PROTEIN_TO_ID
+  else:
+    raise ValueError(f'{chain_poly_type=} invalid.')
+
+  # Handle empty MSA.
+  if not msa_sequences:
+    empty_msa = np.array([], dtype=np.int32).reshape((0, 0))
+    empty_deletions = np.array([], dtype=np.int32).reshape((0, 0))
+    return empty_msa, empty_deletions
+
+  # Get the number of rows and columns in the MSA.
+  num_rows = len(msa_sequences)
+  num_cols = sum(1 for c in msa_sequences[0] if c in char_map)
+
+  # Initialize the output arrays.
+  msa_arr = np.zeros((num_rows, num_cols), dtype=np.int32)
+  deletions_arr = np.zeros((num_rows, num_cols), dtype=np.int32)
+
+  # Populate the output arrays.
+  for problem_row, msa_sequence in enumerate(msa_sequences):
+    deletion_count = 0
+    upper_count = 0
+    problem_col = 0
+    problems = []
+    for current in msa_sequence:
+      msa_id = char_map.get(current, -1)
+      if msa_id == -1:
+        if not current.islower():
+          problems.append(f'({problem_row}, {problem_col}):{current}')
+        deletion_count += 1
+      else:
+        # Check the access is safe before writing to the array.
+        # We don't need to check problem_row since it's guaranteed to be within
+        # the array bounds, while upper_count is incremented in the loop.
+        if upper_count < deletions_arr.shape[1]:
+          deletions_arr[problem_row, upper_count] = deletion_count
+          msa_arr[problem_row, upper_count] = msa_id
+        deletion_count = 0
+        upper_count += 1
+      problem_col += 1
+    if problems:
+      raise ValueError(
+          f"Unknown residues in MSA: {', '.join(problems)}. "
" + f'target_sequence: {msa_sequences[0]}' + ) + if upper_count != num_cols: + raise ValueError( + 'Invalid shape all strings must have the same number ' + 'of non-lowercase characters; First string has ' + f"{num_cols} non-lowercase characters but '{msa_sequence}' has " + f'{upper_count}. target_sequence: {msa_sequences[0]}' + ) + + return msa_arr, deletions_arr + + +# UniProtKB SwissProt/TrEMBL dbs have the following description format: +# `db|UniqueIdentifier|EntryName`, e.g. `sp|P0C2L1|A3X1_LOXLA` or +# `tr|A0A146SKV9|A0A146SKV9_FUNHE`. +_UNIPROT_ENTRY_NAME_REGEX = re.compile( + # UniProtKB TrEMBL or SwissProt database. + r'(?:tr|sp)\|' + # A primary accession number of the UniProtKB entry. + r'(?:[A-Z0-9]{6,10})' + # Occasionally there is an isoform suffix (e.g. _1 or _10) which we ignore. + r'(?:_\d+)?\|' + # TrEMBL: Same as AccessionId (6-10 characters). + # SwissProt: A mnemonic protein identification code (1-5 characters). + r'(?:[A-Z0-9]{1,10}_)' + # A mnemonic species identification code. + r'(?P[A-Z0-9]{1,5})' +) + + +def extract_species_ids(msa_descriptions: Sequence[str]) -> Sequence[str]: + """Extracts species ID from MSA UniProtKB sequence identifiers. + + Args: + msa_descriptions: The descriptions (the FASTA/A3M comment line) for each of + the sequences. + + Returns: + Extracted UniProtKB species IDs if there is a regex match for each + description line, blank if the regex doesn't match. + """ + species_ids = [] + for msa_description in msa_descriptions: + msa_description = msa_description.strip() + match = _UNIPROT_ENTRY_NAME_REGEX.match(msa_description) + if match: + species_ids.append(match.group('SpeciesId')) + else: + # Handle cases where the regex doesn't match + # (e.g., append None or raise an error depending on your needs) + species_ids.append('') + return species_ids diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py new file mode 100644 index 000000000..f2936e4f3 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py @@ -0,0 +1,86 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Utilities for extracting identifiers from MSA sequence descriptions.""" + +import dataclasses +import re + + +# Sequences coming from UniProtKB database come in the +# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE` +# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively). +_UNIPROT_PATTERN = re.compile( + r""" + ^ + # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot + (?:tr|sp) + \| + # A primary accession number of the UniProtKB entry. + (?P[A-Za-z0-9]{6,10}) + # Occasionally there is a _0 or _1 isoform suffix, which we ignore. + (?:_\d)? + \| + # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic + # protein ID code. + (?:[A-Za-z0-9]+) + _ + # A mnemonic species identification code. 
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py
new file mode 100644
index 000000000..f2936e4f3
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_identifiers.py
@@ -0,0 +1,86 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Utilities for extracting identifiers from MSA sequence descriptions."""
+
+import dataclasses
+import re
+
+
+# Sequences coming from UniProtKB database come in the
+# `db|UniqueIdentifier|EntryName` format, e.g. `tr|A0A146SKV9|A0A146SKV9_FUNHE`
+# or `sp|P0C2L1|A3X1_LOXLA` (for TREMBL/Swiss-Prot respectively).
+_UNIPROT_PATTERN = re.compile(
+    r"""
+    ^
+    # UniProtKB/TrEMBL or UniProtKB/Swiss-Prot
+    (?:tr|sp)
+    \|
+    # A primary accession number of the UniProtKB entry.
+    (?P<AccessionIdentifier>[A-Za-z0-9]{6,10})
+    # Occasionally there is a _0 or _1 isoform suffix, which we ignore.
+    (?:_\d)?
+    \|
+    # TREMBL repeats the accession ID here. Swiss-Prot has a mnemonic
+    # protein ID code.
+    (?:[A-Za-z0-9]+)
+    _
+    # A mnemonic species identification code.
+    (?P<SpeciesIdentifier>([A-Za-z0-9]){1,5})
+    # Small BFD uses a final value after an underscore, which we ignore.
+    (?:_\d+)?
+    $
+    """,
+    re.VERBOSE,
+)
+
+
+@dataclasses.dataclass(frozen=True)
+class Identifiers:
+  species_id: str = ''
+
+
+def _parse_sequence_identifier(msa_sequence_identifier: str) -> Identifiers:
+  """Gets species from an msa sequence identifier.
+
+  The sequence identifier has the format specified by _UNIPROT_PATTERN.
+  An example of a sequence identifier: `tr|A0A146SKV9|A0A146SKV9_FUNHE`
+
+  Args:
+    msa_sequence_identifier: a sequence identifier.
+
+  Returns:
+    An `Identifiers` instance with species_id. This can be empty in the case
+    where no identifier was found.
+  """
+  matches = re.search(_UNIPROT_PATTERN, msa_sequence_identifier.strip())
+  if matches:
+    return Identifiers(species_id=matches.group('SpeciesIdentifier'))
+  return Identifiers()
+
+
+def _extract_sequence_identifier(description: str) -> str | None:
+  """Extracts sequence identifier from description. Returns None if no match."""
+  split_description = description.split()
+  if split_description:
+    return split_description[0].partition('/')[0]
+  else:
+    return None
+
+
+def get_identifiers(description: str) -> Identifiers:
+  """Computes extra MSA features from the description."""
+  sequence_identifier = _extract_sequence_identifier(description)
+  if sequence_identifier is None:
+    return Identifiers()
+  else:
+    return _parse_sequence_identifier(sequence_identifier)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_store.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_store.py
new file mode 100644
index 000000000..69cd6a54c
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/msa_store.py
@@ -0,0 +1,67 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Interface and implementations for fetching MSA data."""
+
+from collections.abc import Sequence
+from typing_extensions import Protocol, TypeAlias
+
+from alphafold3.data import msa
+from alphafold3.data import msa_config
+
+
+MsaErrors: TypeAlias = Sequence[tuple[msa_config.RunConfig, str]]
+
+
+class MsaProvider(Protocol):
+  """Interface for providing Multiple Sequence Alignments."""
+
+  def __call__(
+      self,
+      query_sequence: str,
+      chain_polymer_type: str,
+  ) -> tuple[msa.Msa, MsaErrors]:
+    """Retrieve MSA for the given polymer query_sequence.
+
+    Args:
+      query_sequence: The residue sequence of the polymer to search for.
+      chain_polymer_type: The polymer type of the query_sequence. This must
+        match the chain_polymer_type of the provider.
+
+    Returns:
+      A tuple containing the MSA and MsaErrors. MsaErrors is a Sequence
+      containing a tuple for each msa_query that failed. Each tuple contains
+      the failing query and the associated error message.
+ """ + + +class EmptyMsaProvider: + """MSA provider that returns just the query sequence, useful for testing.""" + + def __init__(self, chain_polymer_type: str): + self._chain_polymer_type = chain_polymer_type + + def __call__( + self, query_sequence: str, chain_polymer_type: str + ) -> tuple[msa.Msa, MsaErrors]: + """Returns an MSA containing just the query sequence, never errors.""" + if chain_polymer_type != self._chain_polymer_type: + raise ValueError( + f'EmptyMsaProvider of type {self._chain_polymer_type} called with ' + f'sequence of {chain_polymer_type=}, {query_sequence=}.' + ) + return ( + msa.Msa.from_empty( + query_sequence=query_sequence, + chain_poly_type=self._chain_polymer_type, + ), + (), + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/parsers.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/parsers.py new file mode 100644 index 000000000..c12982783 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/parsers.py @@ -0,0 +1,180 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +"""Functions for parsing various file formats.""" + +from collections.abc import Iterable, Sequence +from typing import IO, TypeAlias + +from alphafold3.cpp import fasta_iterator +from alphafold3.cpp import msa_conversion + + +DeletionMatrix: TypeAlias = Sequence[Sequence[int]] + + +def lazy_parse_fasta_string(fasta_string: str) -> Iterable[tuple[str, str]]: + """Lazily parses a FASTA/A3M string and yields (sequence, description) tuples. + + This implementation is more memory friendly than `fasta_sequence` while + offering comparable performance. The underlying implementation is in C++ and + is therefore faster than a pure Python implementation. + + Use this method when parsing FASTA files where you already have the FASTA + string, but need to control how far you iterate through its sequences. + + Arguments: + fasta_string: A string with the contents of FASTA/A3M file. + + Returns: + Iterator of (sequence, description). In the description, the leading ">" is + stripped. + + Raises: + ValueError if the FASTA/A3M file is invalid, e.g. empty. + """ + + # The lifetime of the FastaStringIterator is tied to the lifetime of + # fasta_string - fasta_string must be kept while the iterator is in use. + return fasta_iterator.FastaStringIterator(fasta_string) + + +def parse_fasta(fasta_string: str) -> tuple[Sequence[str], Sequence[str]]: + """Parses FASTA string and returns list of strings with amino-acid sequences. + + Arguments: + fasta_string: The string contents of a FASTA file. + + Returns: + A tuple of two lists: + * A list of sequences. + * A list of sequence descriptions taken from the comment lines. In the + same order as the sequences. 
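The EmptyMsaProvider above is mainly useful in tests; a minimal usage sketch:

from alphafold3.constants import mmcif_names
from alphafold3.data import msa_store

provider = msa_store.EmptyMsaProvider(mmcif_names.PROTEIN_CHAIN)
query_only_msa, errors = provider('MKLV', mmcif_names.PROTEIN_CHAIN)
print(query_only_msa.depth, errors)  # 1 (): just the query, no search errors.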
+ """ + return fasta_iterator.parse_fasta_include_descriptions(fasta_string) + + +def convert_a3m_to_stockholm(a3m: str, max_seqs: int | None = None) -> str: + """Converts MSA in the A3M format to the Stockholm format.""" + sequences, descriptions = parse_fasta(a3m) + if max_seqs is not None: + sequences = sequences[:max_seqs] + descriptions = descriptions[:max_seqs] + + stockholm = ['# STOCKHOLM 1.0', ''] + + # Add the Stockholm header with the sequence metadata. + names = [] + for i, description in enumerate(descriptions): + name, _, rest = description.partition(' ') + # Ensure that the names are unique - stockholm format requires that + # the sequence names are unique. + name = f'{name}_{i}' + names.append(name) + # Avoid zero-length description due to historic hmmbuild parsing bug. + desc = rest.strip() or '' + stockholm.append(f'#=GS {name.strip()} DE {desc}') + stockholm.append('') + + # Convert insertions in a sequence into gaps in all other sequences that don't + # have an insertion in that column as well. + sequences = msa_conversion.convert_a3m_to_stockholm(sequences) + + # Add the MSA data. + max_name_width = max(len(name) for name in names) + for name, sequence in zip(names, sequences, strict=True): + # Align the names to the left and pad with spaces to the maximum length. + stockholm.append(f'{name:<{max_name_width}s} {sequence}') + + # Add the reference annotation for the query (the first sequence). + ref_annotation = ''.join('.' if c == '-' else 'x' for c in sequences[0]) + stockholm.append(f'{"#=GC RF":<{max_name_width}s} {ref_annotation}') + stockholm.append('//') + + return '\n'.join(stockholm) + + +def convert_stockholm_to_a3m( + stockholm: IO[str], + max_sequences: int | None = None, + remove_first_row_gaps: bool = True, + linewidth: int | None = None, +) -> str: + """Converts MSA in Stockholm format to the A3M format.""" + descriptions = {} + sequences = {} + reached_max_sequences = False + + if linewidth is not None and linewidth <= 0: + raise ValueError('linewidth must be > 0 or None') + + for line in stockholm: + reached_max_sequences = max_sequences and len(sequences) >= max_sequences + line = line.strip() + # Ignore blank lines, markup and end symbols - remainder are alignment + # sequence parts. + if not line or line.startswith(('#', '//')): + continue + seqname, aligned_seq = line.split(maxsplit=1) + if seqname not in sequences: + if reached_max_sequences: + continue + sequences[seqname] = '' + sequences[seqname] += aligned_seq + + stockholm.seek(0) + for line in stockholm: + line = line.strip() + if line[:4] == '#=GS': + # Description row - example format is: + # #=GS UniRef90_Q9H5Z4/4-78 DE [subseq from] cDNA: FLJ22755 ... 
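A quick demonstration of convert_a3m_to_stockholm above (its Stockholm-to-A3M counterpart continues below); sequence names get an index suffix such as _0 to stay unique, as the Stockholm format requires:

from alphafold3.data import parsers

sto = parsers.convert_a3m_to_stockholm('>query my protein\nMKLV\n>hit_1\nMK-V\n')
print(sto.splitlines()[0])                  # '# STOCKHOLM 1.0'
print('#=GS query_0 DE my protein' in sto)  # True: names are index-suffixed.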
+ columns = line.split(maxsplit=3) + seqname, feature = columns[1:3] + value = columns[3] if len(columns) == 4 else '' + if feature != 'DE': + continue + if reached_max_sequences and seqname not in sequences: + continue + descriptions[seqname] = value + if len(descriptions) == len(sequences): + break + + assert len(descriptions) <= len(sequences) + + # Convert sto format to a3m line by line + a3m_sequences = {} + # query_sequence is assumed to be the first sequence + query_sequence = next(iter(sequences.values())) + for seqname, sto_sequence in sequences.items(): + if remove_first_row_gaps: + a3m_sequences[seqname] = msa_conversion.align_sequence_to_gapless_query( + sequence=sto_sequence, query_sequence=query_sequence + ).replace('.', '') + else: + a3m_sequences[seqname] = sto_sequence.replace('.', '') + + fasta_chunks = [] + + for seqname, seq in a3m_sequences.items(): + fasta_chunks.append(f'>{seqname} {descriptions.get(seqname, "")}') + + if linewidth: + fasta_chunks.extend( + seq[i : linewidth + i] for i in range(0, len(seq), linewidth) + ) + else: + fasta_chunks.append(seq) + + return '\n'.join(fasta_chunks) + '\n' # Include terminating newline. diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py new file mode 100644 index 000000000..9fc2e8c6f --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/pipeline.py @@ -0,0 +1,538 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Functions for running the MSA and template tools for the AlphaFold model.""" + +from concurrent import futures +import dataclasses +import datetime +import functools +import logging +import time + +from alphafold3.common import folding_input +from alphafold3.constants import mmcif_names +from alphafold3.data import msa +from alphafold3.data import msa_config +from alphafold3.data import structure_stores +from alphafold3.data import templates as templates_lib + + +# Cache to avoid re-running template search for the same sequence in homomers. 
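And the reverse direction, completing the round trip; convert_stockholm_to_a3m takes a seekable text handle, so io.StringIO works for in-memory data:

import io
from alphafold3.data import parsers

stockholm = io.StringIO(
    '# STOCKHOLM 1.0\n'
    '\n'
    '#=GS query_0 DE my protein\n'
    '\n'
    'query_0 MKLV\n'
    'hit_1_1 MK-V\n'
    '//\n'
)
print(parsers.convert_stockholm_to_a3m(stockholm))
# >query_0 my protein
# MKLV
# >hit_1_1
# MK-V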
+@functools.cache +def _get_protein_templates( + sequence: str, + input_msa_a3m: str, + run_template_search: bool, + templates_config: msa_config.TemplatesConfig, + pdb_database_path: str, +) -> templates_lib.Templates: + """Searches for templates for a single protein chain.""" + if run_template_search: + templates_start_time = time.time() + logging.info('Getting protein templates for sequence %s', sequence) + protein_templates = templates_lib.Templates.from_seq_and_a3m( + query_sequence=sequence, + msa_a3m=input_msa_a3m, + max_template_date=templates_config.filter_config.max_template_date, + database_path=templates_config.template_tool_config.database_path, + hmmsearch_config=templates_config.template_tool_config.hmmsearch_config, + max_a3m_query_sequences=None, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + structure_store=structure_stores.StructureStore(pdb_database_path), + filter_config=templates_config.filter_config, + ) + logging.info( + 'Getting protein templates took %.2f seconds for sequence %s', + time.time() - templates_start_time, + sequence, + ) + else: + logging.info('Skipping template search for sequence %s', sequence) + protein_templates = templates_lib.Templates( + query_sequence=sequence, + hits=[], + max_template_date=templates_config.filter_config.max_template_date, + structure_store=structure_stores.StructureStore(pdb_database_path), + ) + return protein_templates + + +# Cache to avoid re-running the MSA tools for the same sequence in homomers. +@functools.cache +def _get_protein_msa_and_templates( + sequence: str, + run_template_search: bool, + uniref90_msa_config: msa_config.RunConfig, + mgnify_msa_config: msa_config.RunConfig, + small_bfd_msa_config: msa_config.RunConfig, + uniprot_msa_config: msa_config.RunConfig, + templates_config: msa_config.TemplatesConfig, + pdb_database_path: str, +) -> tuple[msa.Msa, msa.Msa, templates_lib.Templates]: + """Processes a single protein chain.""" + logging.info('Getting protein MSAs for sequence %s', sequence) + msa_start_time = time.time() + # Run various MSA tools in parallel. Use a ThreadPoolExecutor because + # they're not blocked by the GIL, as they're sub-shelled out. 
+ with futures.ThreadPoolExecutor(max_workers=4) as executor: + uniref90_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=uniref90_msa_config, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + mgnify_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=mgnify_msa_config, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + small_bfd_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=small_bfd_msa_config, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + uniprot_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=uniprot_msa_config, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + uniref90_msa = uniref90_msa_future.result() + mgnify_msa = mgnify_msa_future.result() + small_bfd_msa = small_bfd_msa_future.result() + uniprot_msa = uniprot_msa_future.result() + logging.info( + 'Getting protein MSAs took %.2f seconds for sequence %s', + time.time() - msa_start_time, + sequence, + ) + + logging.info('Deduplicating MSAs for sequence %s', sequence) + msa_dedupe_start_time = time.time() + with futures.ThreadPoolExecutor() as executor: + unpaired_protein_msa_future = executor.submit( + msa.Msa.from_multiple_msas, + msas=[uniref90_msa, small_bfd_msa, mgnify_msa], + deduplicate=True, + ) + paired_protein_msa_future = executor.submit( + msa.Msa.from_multiple_msas, msas=[uniprot_msa], deduplicate=False + ) + unpaired_protein_msa = unpaired_protein_msa_future.result() + paired_protein_msa = paired_protein_msa_future.result() + logging.info( + 'Deduplicating MSAs took %.2f seconds for sequence %s', + time.time() - msa_dedupe_start_time, + sequence, + ) + + protein_templates = _get_protein_templates( + sequence=sequence, + input_msa_a3m=unpaired_protein_msa.to_a3m(), + run_template_search=run_template_search, + templates_config=templates_config, + pdb_database_path=pdb_database_path, + ) + + return unpaired_protein_msa, paired_protein_msa, protein_templates + + +# Cache to avoid re-running the Nhmmer for the same sequence in homomers. +@functools.cache +def _get_rna_msa( + sequence: str, + nt_rna_msa_config: msa_config.NhmmerConfig, + rfam_msa_config: msa_config.NhmmerConfig, + rnacentral_msa_config: msa_config.NhmmerConfig, +) -> msa.Msa: + """Processes a single RNA chain.""" + logging.info('Getting RNA MSAs for sequence %s', sequence) + rna_msa_start_time = time.time() + # Run various MSA tools in parallel. Use a ThreadPoolExecutor because + # they're not blocked by the GIL, as they're sub-shelled out. 
+ with futures.ThreadPoolExecutor() as executor: + nt_rna_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=nt_rna_msa_config, + chain_poly_type=mmcif_names.RNA_CHAIN, + ) + rfam_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=rfam_msa_config, + chain_poly_type=mmcif_names.RNA_CHAIN, + ) + rnacentral_msa_future = executor.submit( + msa.get_msa, + target_sequence=sequence, + run_config=rnacentral_msa_config, + chain_poly_type=mmcif_names.RNA_CHAIN, + ) + nt_rna_msa = nt_rna_msa_future.result() + rfam_msa = rfam_msa_future.result() + rnacentral_msa = rnacentral_msa_future.result() + logging.info( + 'Getting RNA MSAs took %.2f seconds for sequence %s', + time.time() - rna_msa_start_time, + sequence, + ) + + return msa.Msa.from_multiple_msas( + msas=[rfam_msa, rnacentral_msa, nt_rna_msa], + deduplicate=True, + ) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class DataPipelineConfig: + """The configuration for the data pipeline. + + Attributes: + jackhmmer_binary_path: Jackhmmer binary path, used for protein MSA search. + nhmmer_binary_path: Nhmmer binary path, used for RNA MSA search. + hmmalign_binary_path: Hmmalign binary path, used to align hits to the query + profile. + hmmsearch_binary_path: Hmmsearch binary path, used for template search. + hmmbuild_binary_path: Hmmbuild binary path, used to build HMM profile from + raw MSA in template search. + small_bfd_database_path: Small BFD database path, used for protein MSA + search. + mgnify_database_path: Mgnify database path, used for protein MSA search. + uniprot_cluster_annot_database_path: Uniprot database path, used for protein + paired MSA search. + uniref90_database_path: UniRef90 database path, used for MSA search, and the + MSA obtained by searching it is used to construct the profile for template + search. + ntrna_database_path: NT-RNA database path, used for RNA MSA search. + rfam_database_path: Rfam database path, used for RNA MSA search. + rna_central_database_path: RNAcentral database path, used for RNA MSA + search. + seqres_database_path: PDB sequence database path, used for template search. + pdb_database_path: PDB database directory with mmCIF files path, used for + template search. + jackhmmer_n_cpu: Number of CPUs to use for Jackhmmer. + nhmmer_n_cpu: Number of CPUs to use for Nhmmer. + max_template_date: The latest date of templates to use. + """ + + # Binary paths. + jackhmmer_binary_path: str + nhmmer_binary_path: str + hmmalign_binary_path: str + hmmsearch_binary_path: str + hmmbuild_binary_path: str + + # Jackhmmer databases. + small_bfd_database_path: str + mgnify_database_path: str + uniprot_cluster_annot_database_path: str + uniref90_database_path: str + # Nhmmer databases. + ntrna_database_path: str + rfam_database_path: str + rna_central_database_path: str + # Template search databases. + seqres_database_path: str + pdb_database_path: str + + # Optional configuration for MSA tools. 
+ jackhmmer_n_cpu: int = 8 + nhmmer_n_cpu: int = 8 + + max_template_date: datetime.date + + +class DataPipeline: + """Runs the alignment tools and assembles the input features.""" + + def __init__(self, data_pipeline_config: DataPipelineConfig): + """Initializes the data pipeline with default configurations.""" + self._uniref90_msa_config = msa_config.RunConfig( + config=msa_config.JackhmmerConfig( + binary_path=data_pipeline_config.jackhmmer_binary_path, + database_config=msa_config.DatabaseConfig( + name='uniref90', + path=data_pipeline_config.uniref90_database_path, + ), + n_cpu=data_pipeline_config.jackhmmer_n_cpu, + n_iter=1, + e_value=1e-4, + z_value=None, + max_sequences=10_000, + ), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + crop_size=None, + ) + self._mgnify_msa_config = msa_config.RunConfig( + config=msa_config.JackhmmerConfig( + binary_path=data_pipeline_config.jackhmmer_binary_path, + database_config=msa_config.DatabaseConfig( + name='mgnify', + path=data_pipeline_config.mgnify_database_path, + ), + n_cpu=data_pipeline_config.jackhmmer_n_cpu, + n_iter=1, + e_value=1e-4, + z_value=None, + max_sequences=5_000, + ), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + crop_size=None, + ) + self._small_bfd_msa_config = msa_config.RunConfig( + config=msa_config.JackhmmerConfig( + binary_path=data_pipeline_config.jackhmmer_binary_path, + database_config=msa_config.DatabaseConfig( + name='small_bfd', + path=data_pipeline_config.small_bfd_database_path, + ), + n_cpu=data_pipeline_config.jackhmmer_n_cpu, + n_iter=1, + e_value=1e-4, + # Set z_value=138_515_945 to match the z_value used in the paper. + # In practice, this has minimal impact on predicted structures. + z_value=None, + max_sequences=5_000, + ), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + crop_size=None, + ) + self._uniprot_msa_config = msa_config.RunConfig( + config=msa_config.JackhmmerConfig( + binary_path=data_pipeline_config.jackhmmer_binary_path, + database_config=msa_config.DatabaseConfig( + name='uniprot_cluster_annot', + path=data_pipeline_config.uniprot_cluster_annot_database_path, + ), + n_cpu=data_pipeline_config.jackhmmer_n_cpu, + n_iter=1, + e_value=1e-4, + z_value=None, + max_sequences=50_000, + ), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + crop_size=None, + ) + self._nt_rna_msa_config = msa_config.RunConfig( + config=msa_config.NhmmerConfig( + binary_path=data_pipeline_config.nhmmer_binary_path, + hmmalign_binary_path=data_pipeline_config.hmmalign_binary_path, + hmmbuild_binary_path=data_pipeline_config.hmmbuild_binary_path, + database_config=msa_config.DatabaseConfig( + name='nt_rna', + path=data_pipeline_config.ntrna_database_path, + ), + n_cpu=data_pipeline_config.nhmmer_n_cpu, + e_value=1e-3, + alphabet='rna', + max_sequences=10_000, + ), + chain_poly_type=mmcif_names.RNA_CHAIN, + crop_size=None, + ) + self._rfam_msa_config = msa_config.RunConfig( + config=msa_config.NhmmerConfig( + binary_path=data_pipeline_config.nhmmer_binary_path, + hmmalign_binary_path=data_pipeline_config.hmmalign_binary_path, + hmmbuild_binary_path=data_pipeline_config.hmmbuild_binary_path, + database_config=msa_config.DatabaseConfig( + name='rfam_rna', + path=data_pipeline_config.rfam_database_path, + ), + n_cpu=data_pipeline_config.nhmmer_n_cpu, + e_value=1e-3, + alphabet='rna', + max_sequences=10_000, + ), + chain_poly_type=mmcif_names.RNA_CHAIN, + crop_size=None, + ) + self._rnacentral_msa_config = msa_config.RunConfig( + config=msa_config.NhmmerConfig( + binary_path=data_pipeline_config.nhmmer_binary_path, + 
hmmalign_binary_path=data_pipeline_config.hmmalign_binary_path, + hmmbuild_binary_path=data_pipeline_config.hmmbuild_binary_path, + database_config=msa_config.DatabaseConfig( + name='rna_central_rna', + path=data_pipeline_config.rna_central_database_path, + ), + n_cpu=data_pipeline_config.nhmmer_n_cpu, + e_value=1e-3, + alphabet='rna', + max_sequences=10_000, + ), + chain_poly_type=mmcif_names.RNA_CHAIN, + crop_size=None, + ) + + self._templates_config = msa_config.TemplatesConfig( + template_tool_config=msa_config.TemplateToolConfig( + database_path=data_pipeline_config.seqres_database_path, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + hmmsearch_config=msa_config.HmmsearchConfig( + hmmsearch_binary_path=data_pipeline_config.hmmsearch_binary_path, + hmmbuild_binary_path=data_pipeline_config.hmmbuild_binary_path, + filter_f1=0.1, + filter_f2=0.1, + filter_f3=0.1, + e_value=100, + inc_e=100, + dom_e=100, + incdom_e=100, + alphabet='amino', + ), + ), + filter_config=msa_config.TemplateFilterConfig( + max_subsequence_ratio=0.95, + min_align_ratio=0.1, + min_hit_length=10, + deduplicate_sequences=True, + max_hits=4, + max_template_date=data_pipeline_config.max_template_date, + ), + ) + self._pdb_database_path = data_pipeline_config.pdb_database_path + + def process_protein_chain( + self, chain: folding_input.ProteinChain + ) -> folding_input.ProteinChain: + """Processes a single protein chain.""" + has_unpaired_msa = chain.unpaired_msa is not None + has_paired_msa = chain.paired_msa is not None + has_templates = chain.templates is not None + + if not has_unpaired_msa and not has_paired_msa and not chain.templates: + # MSA None - search. Templates either [] - don't search, or None - search. + unpaired_msa, paired_msa, template_hits = _get_protein_msa_and_templates( + sequence=chain.sequence, + run_template_search=not has_templates, # Skip template search if []. + uniref90_msa_config=self._uniref90_msa_config, + mgnify_msa_config=self._mgnify_msa_config, + small_bfd_msa_config=self._small_bfd_msa_config, + uniprot_msa_config=self._uniprot_msa_config, + templates_config=self._templates_config, + pdb_database_path=self._pdb_database_path, + ) + unpaired_msa = unpaired_msa.to_a3m() + paired_msa = paired_msa.to_a3m() + templates = [ + folding_input.Template( + mmcif=struc.to_mmcif(), + query_to_template_map=hit.query_to_hit_mapping, + ) + for hit, struc in template_hits.get_hits_with_structures() + ] + elif has_unpaired_msa and has_paired_msa and not has_templates: + # Has MSA, but doesn't have templates. Search for templates only. + empty_msa = msa.Msa.from_empty( + query_sequence=chain.sequence, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ).to_a3m() + unpaired_msa = chain.unpaired_msa or empty_msa + paired_msa = chain.paired_msa or empty_msa + template_hits = _get_protein_templates( + sequence=chain.sequence, + input_msa_a3m=unpaired_msa, + run_template_search=True, + templates_config=self._templates_config, + pdb_database_path=self._pdb_database_path, + ) + templates = [ + folding_input.Template( + mmcif=struc.to_mmcif(), + query_to_template_map=hit.query_to_hit_mapping, + ) + for hit, struc in template_hits.get_hits_with_structures() + ] + else: + # Has MSA and templates, don't search for anything. + if not has_unpaired_msa or not has_paired_msa or not has_templates: + raise ValueError( + f'Protein chain {chain.id} has unpaired MSA, paired MSA, or' + ' templates set only partially. If you want to run the pipeline' + ' with custom MSA/templates, you need to set all of them. 
You can' + ' set MSA to empty string and templates to empty list to signify' + ' that they should not be used and searched for.' + ) + logging.info( + 'Skipping MSA and template search for protein chain %s because it ' + 'already has MSAs and templates.', + chain.id, + ) + if not chain.unpaired_msa: + logging.info('Using empty unpaired MSA for protein chain %s', chain.id) + if not chain.paired_msa: + logging.info('Using empty paired MSA for protein chain %s', chain.id) + if not chain.templates: + logging.info('Using no templates for protein chain %s', chain.id) + empty_msa = msa.Msa.from_empty( + query_sequence=chain.sequence, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ).to_a3m() + unpaired_msa = chain.unpaired_msa or empty_msa + paired_msa = chain.paired_msa or empty_msa + templates = chain.templates + + return dataclasses.replace( + chain, + unpaired_msa=unpaired_msa, + paired_msa=paired_msa, + templates=templates, + ) + + def process_rna_chain( + self, chain: folding_input.RnaChain + ) -> folding_input.RnaChain: + """Processes a single RNA chain.""" + if chain.unpaired_msa is not None: + # Don't run MSA tools if the chain already has an MSA. + logging.info( + 'Skipping MSA search for RNA chain %s because it already has MSA.', + chain.id, + ) + if not chain.unpaired_msa: + logging.info('Using empty unpaired MSA for RNA chain %s', chain.id) + empty_msa = msa.Msa.from_empty( + query_sequence=chain.sequence, chain_poly_type=mmcif_names.RNA_CHAIN + ).to_a3m() + unpaired_msa = chain.unpaired_msa or empty_msa + else: + unpaired_msa = _get_rna_msa( + sequence=chain.sequence, + nt_rna_msa_config=self._nt_rna_msa_config, + rfam_msa_config=self._rfam_msa_config, + rnacentral_msa_config=self._rnacentral_msa_config, + ).to_a3m() + return dataclasses.replace(chain, unpaired_msa=unpaired_msa) + + def process(self, fold_input: folding_input.Input) -> folding_input.Input: + """Runs MSA and template tools and returns a new Input with the results.""" + processed_chains = [] + for chain in fold_input.chains: + print(f'Processing chain {chain.id}') + process_chain_start_time = time.time() + match chain: + case folding_input.ProteinChain(): + processed_chains.append(self.process_protein_chain(chain)) + case folding_input.RnaChain(): + processed_chains.append(self.process_rna_chain(chain)) + case _: + processed_chains.append(chain) + print( + f'Processing chain {chain.id} took' + f' {time.time() - process_chain_start_time:.2f} seconds', + ) + + return dataclasses.replace(fold_input, chains=processed_chains) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/structure_stores.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/structure_stores.py new file mode 100644 index 000000000..b3e17ba52 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/structure_stores.py @@ -0,0 +1,101 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
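An end-to-end sketch of the data pipeline above; every binary and database path is a placeholder that must point at a real installation, and input.json is a hypothetical fold input:

import datetime

from alphafold3.common import folding_input
from alphafold3.data import pipeline

config = pipeline.DataPipelineConfig(
    jackhmmer_binary_path='/usr/bin/jackhmmer',
    nhmmer_binary_path='/usr/bin/nhmmer',
    hmmalign_binary_path='/usr/bin/hmmalign',
    hmmsearch_binary_path='/usr/bin/hmmsearch',
    hmmbuild_binary_path='/usr/bin/hmmbuild',
    small_bfd_database_path='/data/small_bfd.fasta',
    mgnify_database_path='/data/mgnify.fasta',
    uniprot_cluster_annot_database_path='/data/uniprot.fasta',
    uniref90_database_path='/data/uniref90.fasta',
    ntrna_database_path='/data/nt_rna.fasta',
    rfam_database_path='/data/rfam.fasta',
    rna_central_database_path='/data/rnacentral.fasta',
    seqres_database_path='/data/pdb_seqres.fasta',
    pdb_database_path='/data/mmcif_files',
    max_template_date=datetime.date(2021, 9, 30),
)
data_pipeline = pipeline.DataPipeline(config)
with open('input.json') as f:
    fold_input = folding_input.Input.from_json(f.read())
fold_input_with_msas = data_pipeline.process(fold_input)  # MSA + template search.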
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Library for loading structure data from various sources.""" + +from collections.abc import Mapping, Sequence +import functools +import os +import pathlib +import tarfile + + +class NotFoundError(KeyError): + """Raised when the structure store doesn't contain the requested target.""" + + +class StructureStore: + """Handles the retrieval of mmCIF files from a filesystem.""" + + def __init__( + self, + structures: str | os.PathLike[str] | Mapping[str, str], + ): + """Initialises the instance. + + Args: + structures: Path of the directory where the mmCIF files are or a Mapping + from target name to mmCIF string. + """ + if isinstance(structures, Mapping): + self._structure_mapping = structures + self._structure_path = None + self._structure_tar = None + else: + self._structure_mapping = None + path_str = os.fspath(structures) + if path_str.endswith('.tar'): + self._structure_tar = tarfile.open(path_str, 'r') + self._structure_path = None + else: + self._structure_path = pathlib.Path(structures) + self._structure_tar = None + + @functools.cached_property + def _tar_members(self) -> Mapping[str, tarfile.TarInfo]: + assert self._structure_tar is not None + return { + path.stem: tarinfo + for tarinfo in self._structure_tar.getmembers() + if tarinfo.isfile() + and (path := pathlib.Path(tarinfo.path.lower())).suffix == '.cif' + } + + def get_mmcif_str(self, target_name: str) -> str: + """Returns an mmCIF for a given `target_name`. + + Args: + target_name: Name specifying the target mmCIF. + + Raises: + NotFoundError: If the target is not found. + """ + if self._structure_mapping is not None: + try: + return self._structure_mapping[target_name] + except KeyError as e: + raise NotFoundError(f'{target_name=} not found') from e + + if self._structure_tar is not None: + try: + member = self._tar_members[target_name] + if struct_file := self._structure_tar.extractfile(member): + return struct_file.read().decode() + else: + raise NotFoundError(f'{target_name=} not found') + except KeyError: + raise NotFoundError(f'{target_name=} not found') from None + + filepath = self._structure_path / f'{target_name}.cif' + try: + return filepath.read_text() + except FileNotFoundError as e: + raise NotFoundError(f'{target_name=} not found at {filepath=}') from e + + def target_names(self) -> Sequence[str]: + """Returns all targets in the store.""" + if self._structure_mapping is not None: + return [*self._structure_mapping.keys()] + elif self._structure_tar is not None: + return sorted(self._tar_members.keys()) + elif self._structure_path is not None: + return sorted([path.stem for path in self._structure_path.glob('*.cif')]) + return () diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_realign.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_realign.py new file mode 100644 index 000000000..453a00ee9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_realign.py @@ -0,0 +1,169 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Realign sequences found in PDB seqres to the actual CIF sequences."""
+
+from collections.abc import Mapping
+
+
+class AlignmentError(Exception):
+  """Failed alignment between the hit sequence and the actual mmCIF sequence."""
+
+
+def realign_hit_to_structure(
+    *,
+    hit_sequence: str,
+    hit_start_index: int,
+    hit_end_index: int,
+    full_length: int,
+    structure_sequence: str,
+    query_to_hit_mapping: Mapping[int, int],
+) -> Mapping[int, int]:
+  """Realigns the hit sequence to the Structure sequence.
+
+  For example, for the given input:
+    query_sequence : ABCDEFGHIJKL
+    hit_sequence   : ---DEFGHIJK-
+    struc_sequence : XDEFGHKL
+  the mapping is {3: 0, 4: 1, 5: 2, 6: 3, 7: 4, 8: 5, 9: 6, 10: 7}. However, the
+  actual Structure sequence has an extra X at the start as well as no IJ. So the
+  alignment from the query to the Structure sequence will be:
+    hit_sequence  : ---DEFGHIJK-
+    struc_aligned : --XDEFGH--KL
+  and the new mapping will therefore be: {3: 1, 4: 2, 5: 3, 6: 4, 7: 5, 10: 6}.
+
+  Args:
+    hit_sequence: The PDB seqres hit sequence obtained from Hmmsearch, but
+      without any gaps. This is not the full PDB seqres template sequence but
+      rather just its subsequence from hit_start_index to hit_end_index.
+    hit_start_index: The start index of the hit sequence in the full PDB seqres
+      template sequence (inclusive).
+    hit_end_index: The end index of the hit sequence in the full PDB seqres
+      template sequence (exclusive).
+    full_length: The length of the full PDB seqres template sequence.
+    structure_sequence: The actual sequence extracted from the Structure
+      corresponding to this template. In the vast majority of cases this is the
+      same as the PDB seqres sequence, but this function handles the cases when
+      it is not.
+    query_to_hit_mapping: The mapping from the query sequence to the
+      hit_sequence.
+
+  Raises:
+    AlignmentError: If the sequence returned by Hmmsearch differs from the
+      actual sequence found in the mmCIF and can't be aligned using the simple
+      alignment algorithm.
+
+  Returns:
+    A mapping from the query sequence to the actual Structure sequence.
+  """
+  max_num_gaps = full_length - len(structure_sequence)
+  if max_num_gaps < 0:
+    raise AlignmentError(
+        f'The Structure sequence (length {len(structure_sequence)}) must not '
+        f'be longer than the PDB seqres sequence (length {full_length}):\n'
+        f'Structure sequence : {structure_sequence}\n'
+        f'PDB seqres sequence: {hit_sequence}'
+    )
+
+  if len(hit_sequence) != hit_end_index - hit_start_index:
+    raise AlignmentError(
+        f'The difference of {hit_end_index=} and {hit_start_index=} does not '
+        f'equal the length of hit_sequence: {len(hit_sequence)}'
+    )
+
+  best_score = -1
+  best_start = 0
+  best_query_to_hit_mapping = query_to_hit_mapping
+  max_num_gaps_before_subseq = min(hit_start_index, max_num_gaps)
+  # It is possible the gaps needed to align the PDB seqres subsequence and
+  # the Structure subsequence need to be inserted before the match region.
+  # Try to pick the alignment with the best number of aligned residues.
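+  # Illustrative walk-through (hypothetical values, not from the source):
+  # with hit_start_index=3, hit_end_index=8 and max_num_gaps=2, the loop
+  # below tries num_gaps_before_subseq in {0, 1, 2}, i.e. the structure
+  # slices [3:8], [2:7] and [1:6], scores each candidate alignment, and
+  # keeps the best-scoring one.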
+ for num_gaps_before_subseq in range(0, max_num_gaps_before_subseq + 1): + start = hit_start_index - num_gaps_before_subseq + end = hit_end_index - num_gaps_before_subseq + structure_subseq = structure_sequence[start:end] + + new_query_to_hit_mapping, score = _remap_to_struc_seq( + hit_seq=hit_sequence, + struc_seq=structure_subseq, + max_num_gaps=max_num_gaps - num_gaps_before_subseq, + mapping=query_to_hit_mapping, + ) + if score >= best_score: + # Use >= to prefer matches with larger number of gaps before. + best_score = score + best_start = start + best_query_to_hit_mapping = new_query_to_hit_mapping + + return {q: h + best_start for q, h in best_query_to_hit_mapping.items()} + + +def _remap_to_struc_seq( + *, + hit_seq: str, + struc_seq: str, + max_num_gaps: int, + mapping: Mapping[int, int], +) -> tuple[Mapping[int, int], int]: + """Remaps the query -> hit mapping to match the actual Structure sequence. + + Args: + hit_seq: The hit sequence - a subsequence of the PDB seqres sequence without + any Hmmsearch modifications like inserted gaps or lowercased residues. + struc_seq: The actual sequence obtained from the corresponding Structure. + max_num_gaps: The maximum number of gaps that can be inserted in the + Structure sequence. In practice, this is the length difference between the + PDB seqres sequence and the actual Structure sequence. + mapping: The mapping from the query residues to the hit residues. This will + be remapped to point to the actual Structure sequence using a simple + realignment algorithm. + + Returns: + A tuple of (mapping, score): + * Mapping from the query to the actual Structure sequence. + * Score which is the number of matching aligned residues. + + Raises: + ValueError if the structure sequence isn't shorter than the seqres sequence. + ValueError if the alignment fails. + """ + hit_seq_idx = 0 + struc_seq_idx = 0 + hit_to_struc_seq_mapping = {} + score = 0 + + # This while loop is guaranteed to terminate since we increase both + # struc_seq_idx and hit_seq_idx by at least 1 in each iteration. + remaining_num_gaps = max_num_gaps + while hit_seq_idx < len(hit_seq) and struc_seq_idx < len(struc_seq): + if hit_seq[hit_seq_idx] != struc_seq[struc_seq_idx]: + # Explore which alignment aligns the next residue (if present). + best_shift = 0 + for shift in range(0, remaining_num_gaps + 1): + next_hit_res = hit_seq[hit_seq_idx + shift : hit_seq_idx + shift + 1] + next_struc_res = struc_seq[struc_seq_idx : struc_seq_idx + 1] + if next_hit_res == next_struc_res: + best_shift = shift + break + hit_seq_idx += best_shift + remaining_num_gaps -= best_shift + + hit_to_struc_seq_mapping[hit_seq_idx] = struc_seq_idx + score += hit_seq[hit_seq_idx] == struc_seq[struc_seq_idx] + hit_seq_idx += 1 + struc_seq_idx += 1 + + fixed_mapping = {} + for query_idx, original_hit_idx in mapping.items(): + fixed_hit_idx = hit_to_struc_seq_mapping.get(original_hit_idx) + if fixed_hit_idx is not None: + fixed_mapping[query_idx] = fixed_hit_idx + + return fixed_mapping, score diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_store.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_store.py new file mode 100644 index 000000000..080fca74f --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/template_store.py @@ -0,0 +1,47 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Interface and implementations for fetching templates data.""" + +from collections.abc import Mapping +import datetime +from typing import Any, Protocol, TypeAlias + + +TemplateFeatures: TypeAlias = Mapping[str, Any] + + +class TemplateFeatureProvider(Protocol): + """Interface for providing Template Features.""" + + def __call__( + self, + sequence: str, + release_date: datetime.date | None, + include_ligand_features: bool = True, + ) -> TemplateFeatures: + """Retrieve template features for the given sequence and release_date. + + Args: + sequence: The residue sequence of the query. + release_date: The release_date of the template query, this is used to + filter templates for training, ensuring that they do not leak structure + information from the future. + include_ligand_features: Whether to include ligand features. + + Returns: + Template features: A mapping of template feature labels to features, which + may be numpy arrays, bytes objects, or for the special case of label + `ligand_features`, a nested feature map of labels to numpy arrays. + + Raises: + TemplateRetrievalError if the template features were not found. + """ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/templates.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/templates.py new file mode 100644 index 000000000..f9695106a --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/templates.py @@ -0,0 +1,969 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""API for retrieving and manipulating template search results.""" + +from collections.abc import Iterable, Iterator, Mapping, Sequence +import dataclasses +import datetime +import functools +import os +import re +from typing import Any, Final, Self, TypeAlias + +from absl import logging +from alphafold3 import structure +from alphafold3.common import resources +from alphafold3.constants import atom_types +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.data import msa_config +from alphafold3.data import parsers +from alphafold3.data import structure_stores +from alphafold3.data import template_realign +from alphafold3.data.tools import hmmsearch +from alphafold3.structure import mmcif +import numpy as np + + +_POLYMER_FEATURES: Final[Mapping[str, np.float64 | np.int32 | object]] = { + 'template_aatype': np.int32, + 'template_all_atom_masks': np.float64, + 'template_all_atom_positions': np.float64, + 'template_domain_names': object, + 'template_release_date': object, + 'template_sequence': object, +} + +_LIGAND_FEATURES: Final[Mapping[str, Any]] = { + 'ligand_features': Mapping[str, Any] +} + + +TemplateFeatures: TypeAlias = Mapping[ + str, np.ndarray | bytes | Mapping[str, np.ndarray | bytes] +] +_REQUIRED_METADATA_COLUMNS: Final[Sequence[str]] = ( + 'seq_release_date', + 'seq_unresolved_res_num', + 'seq_author_chain_id', + 'seq_sequence', +) + + +@dataclasses.dataclass(frozen=True, kw_only=True, slots=True) +class _Polymer: + """Container for alphabet specific (dna, rna, protein) atom information.""" + + min_atoms: int + num_atom_types: int + atom_order: Mapping[str, int] + + +_POLYMERS = { + mmcif_names.PROTEIN_CHAIN: _Polymer( + min_atoms=5, + num_atom_types=atom_types.ATOM37_NUM, + atom_order=atom_types.ATOM37_ORDER, + ), + mmcif_names.DNA_CHAIN: _Polymer( + min_atoms=21, + num_atom_types=atom_types.ATOM29_NUM, + atom_order=atom_types.ATOM29_ORDER, + ), + mmcif_names.RNA_CHAIN: _Polymer( + min_atoms=20, + num_atom_types=atom_types.ATOM29_NUM, + atom_order=atom_types.ATOM29_ORDER, + ), +} + + +def _encode_restype( + chain_poly_type: str, + sequence: str, +) -> Sequence[int]: + """Encodes a sequence of residue names as a sequence of ints. + + Args: + chain_poly_type: Polymer chain type to determine sequence encoding. + sequence: Polymer residues. Protein encoded by single letters. RNA and DNA + encoded by multi-letter CCD codes. + + Returns: + A sequence of integers encoding amino acid types for the given chain type. + """ + if chain_poly_type == mmcif_names.PROTEIN_CHAIN: + return [ + residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP_TO_INT[ + _STANDARDIZED_AA.get(res, res) + ] + for res in sequence + ] + + unk_nucleic = residue_names.UNK_NUCLEIC_ONE_LETTER + unk_nucleic_idx = residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP[ + unk_nucleic + ] + if chain_poly_type == mmcif_names.RNA_CHAIN: + return [ + residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP.get( + res, unk_nucleic_idx + ) + for res in sequence + ] + elif chain_poly_type == mmcif_names.DNA_CHAIN: + # Map UNK DNA to the generic nucleic UNK (N), which happens to also be the + # same as the RNA UNK. 
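+    # For example (illustrative, not from the source): the DNA residue 'A'
+    # maps via DNA_COMMON_ONE_TO_TWO to the CCD code 'DA' before the order
+    # lookup below, while an unrecognised letter falls back to
+    # unk_nucleic_idx.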
+    return [
+        residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP.get(
+            residue_names.DNA_COMMON_ONE_TO_TWO.get(res, unk_nucleic),
+            unk_nucleic_idx,
+        )
+        for res in sequence
+    ]
+
+  raise NotImplementedError(f'"{chain_poly_type}" unsupported.')
+
+
+_DAYS_BEFORE_QUERY_DATE: Final[int] = 60
+_HIT_DESCRIPTION_REGEX = re.compile(
+    r'(?P<pdb_id>[a-z0-9]{4,})_(?P<chain_id>\w+)/(?P<start>\d+)-(?P<end>\d+) '
+    r'.* length:(?P<length>\d+)\b.*'
+)
+
+_STANDARDIZED_AA = {'B': 'D', 'J': 'X', 'O': 'X', 'U': 'C', 'Z': 'E'}
+
+
+class Error(Exception):
+  """Base class for exceptions."""
+
+
+class HitDateError(Error):
+  """An error indicating that an invalid release date was detected."""
+
+
+class InvalidTemplateError(Error):
+  """An error indicating that the template is invalid."""
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class Hit:
+  """Template hit metrics derived from the MSA for filtering and featurising.
+
+  Attributes:
+    pdb_id: The PDB ID of the hit.
+    auth_chain_id: The author chain ID of the hit.
+    hmmsearch_sequence: Hit sequence as given in hmmsearch a3m output.
+    structure_sequence: Hit sequence as given in the PDB structure.
+    unresolved_res_indices: Indices of unresolved residues in the structure
+      sequence. 0-based.
+    query_sequence: The query nucleotide/amino acid sequence.
+    start_index: The start index of the sequence relative to the full PDB
+      seqres sequence. Inclusive and uses 0-based indexing.
+    end_index: The end index of the sequence relative to the full PDB seqres
+      sequence. Exclusive and uses 0-based indexing.
+    full_length: Length of the full PDB seqres sequence. This can differ from
+      the length of the actual sequence we get from the mmCIF, and we use this
+      to detect whether we need to realign or not.
+    release_date: The release date of the PDB entry corresponding to this hit.
+    chain_poly_type: The polymer type of the selected hit structure.
+  """
+
+  pdb_id: str
+  auth_chain_id: str
+  hmmsearch_sequence: str
+  structure_sequence: str
+  unresolved_res_indices: Sequence[int] | None
+  query_sequence: str
+  start_index: int
+  end_index: int
+  full_length: int
+  release_date: datetime.date
+  chain_poly_type: str
+
+  @functools.cached_property
+  def query_to_hit_mapping(self) -> Mapping[int, int]:
+    """0-based query index to hit index mapping."""
+    query_to_hit_mapping = {}
+    hit_index = 0
+    query_index = 0
+    for residue in self.hmmsearch_sequence:
+      # Gap inserted in the template.
+      if residue == '-':
+        query_index += 1
+      # Deleted residue in the template (would be a gap in the query).
+      elif residue.islower():
+        hit_index += 1
+      # Normal aligned residue, in both query and template. Add to mapping.
+      elif residue.isupper():
+        query_to_hit_mapping[query_index] = hit_index
+        query_index += 1
+        hit_index += 1
+
+    structure_subseq = self.structure_sequence[
+        self.start_index : self.end_index
+    ]
+    if self.matching_sequence != structure_subseq:
+      # The seqres sequence doesn't match the structure sequence. Two cases:
+      # 1. The sequences have the same length. The sequences are different
+      #    because our 3->1 residue code mapping is different from the one PDB
+      #    uses. We don't do anything in this case as both sequences have the
+      #    same length, so the original query to hit mapping stays valid.
+      # 2. The sequences don't have the same length; the one in the structure
+      #    is shorter. In this case we change the mapping to match the actual
+      #    structure sequence using a simple realignment algorithm.
+      # This procedure was validated on all PDB seqres (2023_01_12) sequences
+      # and handles all cases that can happen.
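+      # Illustrative case 2 (hypothetical sequences): seqres 'ABCDEF'
+      # (full_length=6) vs. structure sequence 'ABCEF' (length 5) triggers
+      # the realignment below; equal-length mismatches fall through with the
+      # mapping unchanged.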
+ if self.full_length != len(self.structure_sequence): + return template_realign.realign_hit_to_structure( + hit_sequence=self.matching_sequence, + hit_start_index=self.start_index, + hit_end_index=self.end_index, + full_length=self.full_length, + structure_sequence=self.structure_sequence, + query_to_hit_mapping=query_to_hit_mapping, + ) + + # Hmmsearch returns a subsequence and so far indices have been relative to + # the subsequence. Add an offset to index relative to the full structure + # sequence. + return {q: h + self.start_index for q, h in query_to_hit_mapping.items()} + + @property + def matching_sequence(self) -> str: + """Returns the matching hit sequence including insertions. + + Make deleted residues uppercase and remove gaps ("-"). + """ + return self.hmmsearch_sequence.upper().replace('-', '') + + @functools.cached_property + def output_templates_sequence(self) -> str: + """Returns the final template sequence.""" + result_seq = ['-'] * len(self.query_sequence) + for query_index, template_index in self.query_to_hit_mapping.items(): + result_seq[query_index] = self.structure_sequence[template_index] + return ''.join(result_seq) + + @property + def length_ratio(self) -> float: + """Ratio of the length of the hit sequence to the query.""" + return len(self.matching_sequence) / len(self.query_sequence) + + @property + def align_ratio(self) -> float: + """Ratio of the number of aligned residues to the query length.""" + return len(self.query_to_hit_mapping) / len(self.query_sequence) + + @functools.cached_property + def is_valid(self) -> bool: + """Whether hit can be used as a template.""" + if self.unresolved_res_indices is None: + return False + + return bool( + set(self.query_to_hit_mapping.values()) + - set(self.unresolved_res_indices) + ) + + @property + def full_name(self) -> str: + """A full name of the hit.""" + return f'{self.pdb_id}_{self.auth_chain_id}' + + def __post_init__(self): + if not self.pdb_id.islower() and not self.pdb_id.isdigit(): + raise ValueError(f'pdb_id must be lowercase {self.pdb_id}') + + if not (0 <= self.start_index <= self.end_index): + raise ValueError( + 'Start must be non-negative and less than or equal to end index. ' + f'Range: {self.start_index}-{self.end_index}' + ) + + if len(self.matching_sequence) != (self.end_index - self.start_index): + raise ValueError( + 'Sequence length must be equal to end_index - start_index. ' + f'{len(self.matching_sequence)} != {self.end_index} - ' + f'{self.start_index}' + ) + + if self.full_length < 0: + raise ValueError(f'Full length must be non-negative: {self.full_length}') + + def keep( + self, + *, + release_date_cutoff: datetime.date | None, + max_subsequence_ratio: float | None, + min_hit_length: int | None, + min_align_ratio: float | None, + ) -> bool: + """Returns whether the hit should be kept. + + In addition to filtering on all of the provided parameters, this method also + excludes hits with unresolved residues. + + Args: + release_date_cutoff: Maximum release date of the template. + max_subsequence_ratio: If set, excludes hits which are an exact + subsequence of the query sequence, and longer than this ratio. Useful to + avoid ground truth leakage. + min_hit_length: If set, excludes hits which have fewer residues than this. + min_align_ratio: If set, excludes hits where the number of residues + aligned to the query is less than this proportion of the template + length. + """ + # Exclude hits which are too recent. 
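+    # (Illustrative dates, not from the source: a hit released 2021-10-05 is
+    # dropped when the cutoff is 2021-09-30.)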
+ if ( + release_date_cutoff is not None + and self.release_date > release_date_cutoff + ): + return False + + # Exclude hits which are large duplicates of the query_sequence. + if ( + max_subsequence_ratio is not None + and self.length_ratio > max_subsequence_ratio + ): + if self.matching_sequence in self.query_sequence: + return False + + # Exclude hits which are too short. + if ( + min_hit_length is not None + and len(self.matching_sequence) < min_hit_length + ): + return False + + # Exclude hits with unresolved residues. + if not self.is_valid: + return False + + # Exclude hits with too few alignments. + try: + if min_align_ratio is not None and self.align_ratio <= min_align_ratio: + return False + except template_realign.AlignmentError as e: + logging.warning('Failed to align %s: %s', self, str(e)) + return False + + return True + + +def _filter_hits( + hits: Iterable[Hit], + release_date_cutoff: datetime.date, + max_subsequence_ratio: float | None, + min_align_ratio: float | None, + min_hit_length: int | None, + deduplicate_sequences: bool, + max_hits: int | None, +) -> Sequence[Hit]: + """Filters hits based on the filter config.""" + filtered_hits = [] + seen_before = set() + for hit in hits: + if not hit.keep( + max_subsequence_ratio=max_subsequence_ratio, + min_align_ratio=min_align_ratio, + min_hit_length=min_hit_length, + release_date_cutoff=release_date_cutoff, + ): + continue + + # Remove duplicate templates, keeping the first. + if deduplicate_sequences: + if hit.output_templates_sequence in seen_before: + continue + seen_before.add(hit.output_templates_sequence) + + filtered_hits.append(hit) + if max_hits and len(filtered_hits) == max_hits: + break + + return filtered_hits + + +@dataclasses.dataclass(init=False) +class Templates: + """A container for templates that were found for the given query sequence. + + The structure_store is constructed from the config by default. Callers can + optionally supply a structure_store to the constructor to avoid the cost of + construction and metadata loading. + """ + + def __init__( + self, + *, + query_sequence: str, + hits: Sequence[Hit], + max_template_date: datetime.date, + structure_store: structure_stores.StructureStore, + query_release_date: datetime.date | None = None, + ): + self._query_sequence = query_sequence + self._hits = tuple(hits) + self._max_template_date = max_template_date + self._query_release_date = query_release_date + self._hit_structures = {} + self._structure_store = structure_store + + if any(h.query_sequence != self._query_sequence for h in self.hits): + raise ValueError('All hits must match the query sequence.') + + if self._hits: + chain_poly_type = self._hits[0].chain_poly_type + if any(h.chain_poly_type != chain_poly_type for h in self.hits): + raise ValueError('All hits must have the same chain_poly_type.') + + @classmethod + def from_seq_and_a3m( + cls, + *, + query_sequence: str, + msa_a3m: str, + max_template_date: datetime.date, + database_path: os.PathLike[str] | str, + hmmsearch_config: msa_config.HmmsearchConfig, + max_a3m_query_sequences: int | None, + structure_store: structure_stores.StructureStore, + filter_config: msa_config.TemplateFilterConfig | None = None, + query_release_date: datetime.date | None = None, + chain_poly_type: str = mmcif_names.PROTEIN_CHAIN, + ) -> Self: + """Creates templates from a run of hmmsearch tool against a custom a3m. + + Args: + query_sequence: The polymer sequence of the target query. 
+ msa_a3m: An a3m of related polymers aligned to the query sequence, this is + used to create an HMM for the hmmsearch run. + max_template_date: This is used to filter templates for training, ensuring + that they do not leak ground truth information used in testing sets. + database_path: A path to the sequence database to search for templates. + hmmsearch_config: Config with Hmmsearch settings. + max_a3m_query_sequences: The maximum number of input MSA sequences to use + to construct the profile which is then used to search for templates. + structure_store: Structure store to fetch template structures from. + filter_config: Optional config that controls which and how many hits to + keep. More performant than constructing and then filtering. If not + provided, no filtering is done. + query_release_date: The release_date of the template query, this is used + to filter templates for training, ensuring that they do not leak + structure information from the future. + chain_poly_type: The polymer type of the templates. + + Returns: + Templates object containing a list of Hits initialised from the + structure_store metadata and a3m alignments. + """ + hmmsearch_a3m = run_hmmsearch_with_a3m( + database_path=database_path, + hmmsearch_config=hmmsearch_config, + max_a3m_query_sequences=max_a3m_query_sequences, + a3m=msa_a3m, + ) + return cls.from_hmmsearch_a3m( + query_sequence=query_sequence, + a3m=hmmsearch_a3m, + max_template_date=max_template_date, + query_release_date=query_release_date, + chain_poly_type=chain_poly_type, + structure_store=structure_store, + filter_config=filter_config, + ) + + @classmethod + def from_hmmsearch_a3m( + cls, + *, + query_sequence: str, + a3m: str, + max_template_date: datetime.date, + structure_store: structure_stores.StructureStore, + filter_config: msa_config.TemplateFilterConfig | None = None, + query_release_date: datetime.date | None = None, + chain_poly_type: str = mmcif_names.PROTEIN_CHAIN, + ) -> Self: + """Creates Templates from a Hmmsearch A3M. + + Args: + query_sequence: The polymer sequence of the target query. + a3m: Results of Hmmsearch in A3M format. This provides a list of potential + template alignments and pdb codes. + max_template_date: This is used to filter templates for training, ensuring + that they do not leak ground truth information used in testing sets. + structure_store: Structure store to fetch template structures from. + filter_config: Optional config that controls which and how many hits to + keep. More performant than constructing and then filtering. If not + provided, no filtering is done. + query_release_date: The release_date of the template query, this is used + to filter templates for training, ensuring that they do not leak + structure information from the future. + chain_poly_type: The polymer type of the templates. + + Returns: + Templates object containing a list of Hits initialised from the + structure_store metadata and a3m alignments. + """ + + def hit_generator(a3m: str): + for hit_seq, hit_desc in parsers.lazy_parse_fasta_string(a3m): + pdb_id, auth_chain_id, start, end, full_length = _parse_hit_description( + hit_desc + ) + + release_date, sequence, unresolved_res_ids = _parse_hit_metadata( + structure_store, pdb_id, auth_chain_id + ) + if unresolved_res_ids is None: + continue + + # seq_unresolved_res_num are 1-based, setting to 0-based indices. 
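+        # e.g. residue numbers [1, 5, 9] become indices [0, 4, 8].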
+ unresolved_indices = [i - 1 for i in unresolved_res_ids] + + yield Hit( + pdb_id=pdb_id, + auth_chain_id=auth_chain_id, + hmmsearch_sequence=hit_seq, + structure_sequence=sequence, + query_sequence=query_sequence, + unresolved_res_indices=unresolved_indices, + start_index=start - 1, # Raw value is residue number, not index. + end_index=end, + full_length=full_length, + release_date=datetime.date.fromisoformat(release_date), + chain_poly_type=chain_poly_type, + ) + + if filter_config is None: + hits = tuple(hit_generator(a3m)) + else: + hits = _filter_hits( + hit_generator(a3m), + release_date_cutoff=filter_config.max_template_date, + max_subsequence_ratio=filter_config.max_subsequence_ratio, + min_align_ratio=filter_config.min_align_ratio, + min_hit_length=filter_config.min_hit_length, + deduplicate_sequences=filter_config.deduplicate_sequences, + max_hits=filter_config.max_hits, + ) + + return Templates( + query_sequence=query_sequence, + query_release_date=query_release_date, + hits=hits, + max_template_date=max_template_date, + structure_store=structure_store, + ) + + @property + def query_sequence(self) -> str: + return self._query_sequence + + @property + def hits(self) -> tuple[Hit, ...]: + return self._hits + + @property + def query_release_date(self) -> datetime.date | None: + return self._query_release_date + + @property + def num_hits(self) -> int: + return len(self._hits) + + @functools.cached_property + def release_date_cutoff(self) -> datetime.date: + if self.query_release_date is None: + return self._max_template_date + return min( + self._max_template_date, + self.query_release_date + - datetime.timedelta(days=_DAYS_BEFORE_QUERY_DATE), + ) + + def __repr__(self) -> str: + return f'Templates({self.num_hits} hits)' + + def filter( + self, + *, + max_subsequence_ratio: float | None, + min_align_ratio: float | None, + min_hit_length: int | None, + deduplicate_sequences: bool, + max_hits: int | None, + ) -> Self: + """Returns a new Templates object with only the hits that pass all filters. + + This also filters on query_release_date and max_template_date. + + Args: + max_subsequence_ratio: If set, excludes hits which are an exact + subsequence of the query sequence, and longer than this ratio. Useful to + avoid ground truth leakage. + min_align_ratio: If set, excludes hits where the number of residues + aligned to the query is less than this proportion of the template + length. + min_hit_length: If set, excludes hits which have fewer residues than this. + deduplicate_sequences: Whether to exclude duplicate template sequences, + keeping only the first. This can be useful in increasing the diversity + of hits especially in the case of homomer hits. + max_hits: If set, excludes any hits which exceed this count. 
+ """ + filtered_hits = _filter_hits( + hits=self._hits, + release_date_cutoff=self.release_date_cutoff, + max_subsequence_ratio=max_subsequence_ratio, + min_align_ratio=min_align_ratio, + min_hit_length=min_hit_length, + deduplicate_sequences=deduplicate_sequences, + max_hits=max_hits, + ) + return Templates( + query_sequence=self.query_sequence, + query_release_date=self.query_release_date, + hits=filtered_hits, + max_template_date=self._max_template_date, + structure_store=self._structure_store, + ) + + def get_hits_with_structures( + self, + ) -> Sequence[tuple[Hit, structure.Structure]]: + """Returns hits + Structures, Structures filtered to the hit's chain.""" + results = [] + structures = {struc.name.lower(): struc for struc in self.structures} + for hit in self.hits: + if not hit.is_valid: + raise InvalidTemplateError( + 'Hits must be filtered before calling get_hits_with_structures.' + ) + struc = structures[hit.pdb_id] + label_chain_id = struc.polymer_auth_asym_id_to_label_asym_id().get( + hit.auth_chain_id + ) + results.append((hit, struc.filter(chain_id=label_chain_id))) + return results + + def featurize( + self, + include_ligand_features: bool = True, + ) -> TemplateFeatures: + """Featurises the templates and returns a map of feature names to features. + + NB: If you don't do any prefiltering, this method might be slow to run + as it has to fetch many CIFs and featurize them all. + + Args: + include_ligand_features: Whether to compute ligand features. + + Returns: + Template features: A mapping of template feature labels to features, which + may be numpy arrays, bytes objects, or for the special case of label + `ligand_features` (if `include_ligand_features` is True), a nested + feature map of labels to numpy arrays. + + Raises: + InvalidTemplateError: If hits haven't been filtered before featurization. + """ + hits_by_pdb_id = {} + for idx, hit in enumerate(self.hits): + if not hit.is_valid: + raise InvalidTemplateError( + f'Hits must be filtered before featurizing, got unprocessed {hit=}' + ) + hits_by_pdb_id.setdefault(hit.pdb_id, []).append((idx, hit)) + + unsorted_features = [] + for struc in self.structures: + pdb_id = str(struc.name).lower() + for idx, hit in hits_by_pdb_id[pdb_id]: + try: + label_chain_id = struc.polymer_auth_asym_id_to_label_asym_id()[ + hit.auth_chain_id + ] + hit_features = { + **get_polymer_features( + chain=struc.filter(chain_id=label_chain_id), + chain_poly_type=hit.chain_poly_type, + query_sequence_length=len(hit.query_sequence), + query_to_hit_mapping=hit.query_to_hit_mapping, + ), + } + if include_ligand_features: + hit_features['ligand_features'] = _get_ligand_features(struc) + unsorted_features.append((idx, hit_features)) + except Error as e: + raise type(e)(f'Failed to featurise {hit=}') from e + + sorted_features = sorted(unsorted_features, key=lambda x: x[0]) + sorted_features = [feat for _, feat in sorted_features] + return package_template_features( + hit_features=sorted_features, + include_ligand_features=include_ligand_features, + ) + + @property + def structures(self) -> Iterator[structure.Structure]: + """Yields template structures for each unique PDB ID among hits. + + If there are multiple hits in the same Structure, the Structure will be + included only once by this method. + + Yields: + A Structure object for each unique PDB ID among hits. + + Raises: + HitDateError: If template's release date exceeds max cutoff date. 
+ """ + + for hit in self.hits: + if hit.release_date > self.release_date_cutoff: # pylint: disable=comparison-with-callable + raise HitDateError( + f'Invalid release date for hit {hit.pdb_id=}, when release date ' + f'cutoff is {self.release_date_cutoff}.' + ) + + # Get the set of pdbs to load. In particular, remove duplicate PDB IDs. + targets_to_load = tuple({hit.pdb_id for hit in self.hits}) + + for target_name in targets_to_load: + yield structure.from_mmcif( + mmcif_string=self._structure_store.get_mmcif_str(target_name), + fix_mse_residues=True, + fix_arginines=True, + include_water=False, + include_bonds=False, + include_other=True, # For non-standard polymer chains. + ) + + +def _parse_hit_description(description: str) -> tuple[str, str, int, int, int]: + """Parses the hmmsearch A3M sequence description line.""" + # Example lines (protein, nucleic, no description): + # >4pqx_A/2-217 [subseq from] mol:protein length:217 Free text + # >4pqx_A/2-217 [subseq from] mol:na length:217 Free text + # >5g3r_A/1-55 [subseq from] mol:protein length:352 + if match := re.fullmatch(_HIT_DESCRIPTION_REGEX, description): + return ( + match['pdb_id'], + match['chain_id'], + int(match['start']), + int(match['end']), + int(match['length']), + ) + else: + raise ValueError(f'Could not parse description "{description}"') + + +def _parse_hit_metadata( + structure_store: structure_stores.StructureStore, + pdb_id: str, + auth_chain_id: str, +) -> tuple[Any, str | None, Sequence[int] | None]: + """Parse hit metadata by parsing mmCIF from structure store.""" + try: + cif = mmcif.from_string(structure_store.get_mmcif_str(pdb_id)) + except structure_stores.NotFoundError: + logging.warning('Failed to get mmCIF for %s.', pdb_id) + return None, None, None + release_date = mmcif.get_release_date(cif) + + try: + struc = structure.from_parsed_mmcif( + cif, + model_id=structure.ModelID.ALL, + include_water=True, + include_other=True, + include_bonds=False, + ) + except ValueError: + struc = structure.from_parsed_mmcif( + cif, + model_id=structure.ModelID.FIRST, + include_water=True, + include_other=True, + include_bonds=False, + ) + + sequence = struc.polymer_author_chain_single_letter_sequence( + include_missing_residues=True, + protein=True, + dna=True, + rna=True, + other=True, + )[auth_chain_id] + + unresolved_res_ids = struc.filter( + chain_auth_asym_id=auth_chain_id + ).unresolved_residues.id + + return release_date, sequence, unresolved_res_ids + + +def get_polymer_features( + *, + chain: structure.Structure, + chain_poly_type: str, + query_sequence_length: int, + query_to_hit_mapping: Mapping[int, int], +) -> Mapping[str, Any]: + """Returns features for this polymer chain. + + Args: + chain: Structure object representing the template. Must be already filtered + to a single chain. + chain_poly_type: The chain polymer type (protein, DNA, RNA). + query_sequence_length: The length of the query sequence. + query_to_hit_mapping: 0-based query index to hit index mapping. + + Returns: + A dictionary with polymer features for template_chain_id in the struc. + + Raises: + ValueError: If the input structure contains more than just a single chain. 
+ """ + if len(chain.polymer_auth_asym_id_to_label_asym_id()) != 1: + raise ValueError('The structure must be filtered to a single chain.') + + if chain.name is None: + raise ValueError('The structure must have a name.') + + if chain.release_date is None: + raise ValueError('The structure must have a release date.') + + auth_chain_id, label_chain_id = next( + iter(chain.polymer_auth_asym_id_to_label_asym_id().items()) + ) + chain_sequence = chain.chain_single_letter_sequence()[label_chain_id] + + polymer = _POLYMERS[chain_poly_type] + positions, positions_mask = chain.to_res_arrays( + include_missing_residues=True, atom_order=polymer.atom_order + ) + template_all_atom_positions = np.zeros( + (query_sequence_length, polymer.num_atom_types, 3), dtype=np.float64 + ) + template_all_atom_masks = np.zeros( + (query_sequence_length, polymer.num_atom_types), dtype=np.int64 + ) + + template_sequence = ['-'] * query_sequence_length + for query_index, template_index in query_to_hit_mapping.items(): + template_all_atom_positions[query_index] = positions[template_index] + template_all_atom_masks[query_index] = positions_mask[template_index] + template_sequence[query_index] = chain_sequence[template_index] + + template_sequence = ''.join(template_sequence) + template_aatype = _encode_restype(chain_poly_type, template_sequence) + template_name = f'{chain.name.lower()}_{auth_chain_id}' + release_date = chain.release_date.strftime('%Y-%m-%d') + return { + 'template_all_atom_positions': template_all_atom_positions, + 'template_all_atom_masks': template_all_atom_masks, + 'template_sequence': template_sequence.encode(), + 'template_aatype': np.array(template_aatype, dtype=np.int32), + 'template_domain_names': np.array(template_name.encode(), dtype=object), + 'template_release_date': np.array(release_date.encode(), dtype=object), + } + + +def _get_ligand_features( + struc: structure.Structure, +) -> Mapping[str, Mapping[str, np.ndarray | bytes]]: + """Returns features for the ligands in this structure.""" + ligand_struc = struc.filter_to_entity_type(ligand=True) + assert ligand_struc.coords is not None + assert ligand_struc.atom_name is not None + assert ligand_struc.atom_occupancy is not None + + ligand_features = {} + for ligand_chain_id in ligand_struc.chains: + idxs = np.where(ligand_struc.chain_id == ligand_chain_id)[0] + if idxs.shape[0]: + ligand_features[ligand_chain_id] = { + 'ligand_atom_positions': ligand_struc.coords[idxs, :].astype( + np.float32 + ), + 'ligand_atom_names': ligand_struc.atom_name[idxs].astype(object), + 'ligand_atom_occupancies': ligand_struc.atom_occupancy[idxs].astype( + np.float32 + ), + 'ccd_id': ligand_struc.res_name[idxs][0].encode(), + } + return ligand_features + + +def package_template_features( + *, + hit_features: Sequence[Mapping[str, Any]], + include_ligand_features: bool, +) -> Mapping[str, Any]: + """Stacks polymer features, adds empty and keeps ligand features unstacked.""" + + features_to_include = set(_POLYMER_FEATURES) + if include_ligand_features: + features_to_include.update(_LIGAND_FEATURES) + + features = { + feat: [single_hit_features[feat] for single_hit_features in hit_features] + for feat in features_to_include + } + + stacked_features = {} + for k, v in features.items(): + if k in _POLYMER_FEATURES: + v = np.stack(v, axis=0) if v else np.array([], dtype=_POLYMER_FEATURES[k]) + stacked_features[k] = v + + return stacked_features + + +def _resolve_path(path: os.PathLike[str] | str) -> str: + """Resolves path for data dep paths, stringifies otherwise.""" + 
# Data dependency paths: db baked into the binary. + resolved_path = resources.filename(path) + if os.path.exists(resolved_path): + return resolved_path + else: + # Other paths, e.g. local. + return str(path) + + +def run_hmmsearch_with_a3m( + *, + database_path: os.PathLike[str] | str, + hmmsearch_config: msa_config.HmmsearchConfig, + max_a3m_query_sequences: int | None, + a3m: str | None, +) -> str: + """Runs Hmmsearch to get a3m string of hits.""" + searcher = hmmsearch.Hmmsearch( + binary_path=hmmsearch_config.hmmsearch_binary_path, + hmmbuild_binary_path=hmmsearch_config.hmmbuild_binary_path, + database_path=_resolve_path(database_path), + e_value=hmmsearch_config.e_value, + inc_e=hmmsearch_config.inc_e, + dom_e=hmmsearch_config.dom_e, + incdom_e=hmmsearch_config.incdom_e, + alphabet=hmmsearch_config.alphabet, + filter_f1=hmmsearch_config.filter_f1, + filter_f2=hmmsearch_config.filter_f2, + filter_f3=hmmsearch_config.filter_f3, + filter_max=hmmsearch_config.filter_max, + ) + # STO enables us to annotate query non-gap columns as reference columns. + sto = parsers.convert_a3m_to_stockholm(a3m, max_a3m_query_sequences) + return searcher.query_with_sto(sto, model_construction='hand') diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmalign.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmalign.py new file mode 100644 index 000000000..a7020c010 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmalign.py @@ -0,0 +1,143 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""A Python wrapper for hmmalign from the HMMER Suite.""" + +from collections.abc import Mapping, Sequence +import os +import tempfile + +from alphafold3.data import parsers +from alphafold3.data.tools import subprocess_utils + + +def _to_a3m(sequences: Sequence[str], name_prefix: str = 'sequence') -> str: + a3m = '' + for i, sequence in enumerate(sequences, 1): + a3m += f'> {name_prefix} {i}\n{sequence}\n' + return a3m + + +class Hmmalign: + """Python wrapper of the hmmalign binary.""" + + def __init__(self, binary_path: str): + """Initializes the Python hmmalign wrapper. + + Args: + binary_path: Path to the hmmalign binary. + + Raises: + RuntimeError: If hmmalign binary not found within the path. + """ + self.binary_path = binary_path + + subprocess_utils.check_binary_exists(path=self.binary_path, name='hmmalign') + + def align_sequences( + self, + sequences: Sequence[str], + profile: str, + extra_flags: Mapping[str, str] | None = None, + ) -> str: + """Aligns sequence list to the profile and returns the alignment in A3M.""" + return self.align( + a3m_str=_to_a3m(sequences, name_prefix='query'), + profile=profile, + extra_flags=extra_flags, + ) + + def align( + self, + a3m_str: str, + profile: str, + extra_flags: Mapping[str, str] | None = None, + ) -> str: + """Aligns sequences in A3M to the profile and returns the alignment in A3M. 
+
+    Args:
+      a3m_str: A string with the sequences to align, in the A3M format.
+      profile: A string with the HMM profile to align the sequences to.
+      extra_flags: Dictionary of extra flags (flag_name: flag_value) that are
+        passed to hmmalign.
+
+    Returns:
+      An A3M string with the aligned sequences.
+
+    Raises:
+      RuntimeError: If hmmalign fails.
+    """
+    with tempfile.TemporaryDirectory() as query_tmp_dir:
+      input_profile = os.path.join(query_tmp_dir, 'profile.hmm')
+      input_sequences = os.path.join(query_tmp_dir, 'sequences.a3m')
+      output_a3m_path = os.path.join(query_tmp_dir, 'output.a3m')
+
+      with open(input_profile, 'w') as f:
+        f.write(profile)
+
+      with open(input_sequences, 'w') as f:
+        f.write(a3m_str)
+
+      cmd = [
+          self.binary_path,
+          *('-o', output_a3m_path),
+          *('--outformat', 'A2M'),  # A2M is A3M in the HMMER suite.
+      ]
+      if extra_flags:
+        for flag_name, flag_value in extra_flags.items():
+          cmd.extend([flag_name, flag_value])
+      cmd.extend([input_profile, input_sequences])
+
+      subprocess_utils.run(
+          cmd=cmd,
+          cmd_name='hmmalign',
+          log_stdout=False,
+          log_stderr=True,
+          log_on_process_error=True,
+      )
+
+      with open(output_a3m_path, encoding='utf-8') as f:
+        a3m = f.read()
+
+    return a3m
+
+  def align_sequences_to_profile(self, profile: str, sequences_a3m: str) -> str:
+    """Aligns the sequences to the profile and returns the alignment as A3M.
+
+    Uses hmmalign to align the sequences to the profile, then outputs the
+    aligned sequences in the A3M format. As the input sequences are
+    represented by an alignment with possible gaps ('-') and insertions
+    (lowercase characters), the method first removes the gaps to prepare the
+    sequences for realignment; sequences with gaps cannot be aligned, as '-'
+    is not a valid symbol to align.
+
+    Args:
+      profile: The Hmmbuild profile to align the sequences to.
+      sequences_a3m: Sequences in A3M format to align to the profile.
+
+    Returns:
+      An A3M string with the aligned sequences.
+
+    Raises:
+      RuntimeError: If hmmalign fails.
+    """
+    deletion_table = str.maketrans('', '', '-')
+    sequences_no_gaps_a3m = []
+    for seq, desc in parsers.lazy_parse_fasta_string(sequences_a3m):
+      sequences_no_gaps_a3m.append(f'>{desc}')
+      sequences_no_gaps_a3m.append(seq.translate(deletion_table))
+    sequences_no_gaps_a3m = '\n'.join(sequences_no_gaps_a3m)
+
+    aligned_sequences = self.align(sequences_no_gaps_a3m, profile)
+
+    return aligned_sequences
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmbuild.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmbuild.py
new file mode 100644
index 000000000..d08747c38
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmbuild.py
@@ -0,0 +1,145 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""A Python wrapper for hmmbuild - construct HMM profiles from MSA."""
+
+import os
+import re
+import tempfile
+from typing import Literal
+
+from alphafold3.data import parsers
+from alphafold3.data.tools import subprocess_utils
+
+
+class Hmmbuild(object):
+  """Python wrapper of the hmmbuild binary."""
+
+  def __init__(
+      self,
+      *,
+      binary_path: str,
+      singlemx: bool = False,
+      alphabet: str | None = None,
+  ):
+    """Initializes the Python hmmbuild wrapper.
+
+    Args:
+      binary_path: The path to the hmmbuild executable.
+      singlemx: Whether to use the --singlemx flag. If True, it forces hmmbuild
+        to just use a common substitution score matrix.
+      alphabet: The alphabet to assert when building a profile. Useful when
+        hmmbuild cannot guess the alphabet. If None, no alphabet is asserted.
+
+    Raises:
+      RuntimeError: If the hmmbuild binary is not found within the path.
+    """
+    self.binary_path = binary_path
+    self.singlemx = singlemx
+    self.alphabet = alphabet
+
+    subprocess_utils.check_binary_exists(path=self.binary_path, name='hmmbuild')
+
+  def build_profile_from_sto(self, sto: str, model_construction='fast') -> str:
+    """Builds an HMM for the aligned sequences given as a Stockholm string.
+
+    Args:
+      sto: A string with the aligned sequences in the Stockholm format.
+      model_construction: Whether to use reference annotation in the msa to
+        determine consensus columns ('hand') or default ('fast').
+
+    Returns:
+      A string with the profile in the HMM format.
+
+    Raises:
+      RuntimeError: If hmmbuild fails.
+    """
+    return self._build_profile(
+        sto, informat='stockholm', model_construction=model_construction
+    )
+
+  def build_profile_from_a3m(self, a3m: str) -> str:
+    """Builds an HMM for the aligned sequences given as an A3M string.
+
+    Args:
+      a3m: A string with the aligned sequences in the A3M format.
+
+    Returns:
+      A string with the profile in the HMM format.
+
+    Raises:
+      RuntimeError: If hmmbuild fails.
+    """
+    lines = []
+    for sequence, description in parsers.lazy_parse_fasta_string(a3m):
+      sequence = re.sub('[a-z]+', '', sequence)  # Remove inserted residues.
+      lines.append(f'>{description}\n{sequence}\n')
+    msa = ''.join(lines)
+    return self._build_profile(msa, informat='afa')
+
+  def _build_profile(
+      self,
+      msa: str,
+      informat: Literal['afa', 'stockholm'],
+      model_construction: str = 'fast',
+  ) -> str:
+    """Builds an HMM for the aligned sequences given as an MSA string.
+
+    Args:
+      msa: A string with the aligned sequences, in aligned FASTA or Stockholm
+        format.
+      informat: One of 'afa' (aligned FASTA) or 'stockholm' (Stockholm).
+      model_construction: Whether to use reference annotation in the msa to
+        determine consensus columns ('hand') or default ('fast').
+
+    Returns:
+      A string with the profile in the HMM format.
+
+    Raises:
+      RuntimeError: If hmmbuild fails.
+      ValueError: If model_construction is not one of 'hand' or 'fast'.
+    """
+    if model_construction not in {'hand', 'fast'}:
+      raise ValueError(f'Bad {model_construction=}. Only hand or fast allowed.')
+
+    with tempfile.TemporaryDirectory() as query_tmp_dir:
+      input_msa_path = os.path.join(query_tmp_dir, 'query.msa')
+      output_hmm_path = os.path.join(query_tmp_dir, 'output.hmm')
+
+      with open(input_msa_path, 'w') as f:
+        f.write(msa)
+
+      # Specify the format as we don't specify the input file extension. See
+      # https://github.com/EddyRivasLab/hmmer/issues/321 for more details.
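+      # Illustrative final command (hypothetical paths, 'hand' construction
+      # with an asserted amino alphabet):
+      #   hmmbuild --informat stockholm --hand --amino output.hmm query.msa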
+ cmd_flags = ['--informat', informat] + # If adding flags, we have to do so before the output and input: + if model_construction == 'hand': + cmd_flags.append(f'--{model_construction}') + if self.singlemx: + cmd_flags.append('--singlemx') + if self.alphabet: + cmd_flags.append(f'--{self.alphabet}') + + cmd_flags.extend([output_hmm_path, input_msa_path]) + + cmd = [self.binary_path, *cmd_flags] + + subprocess_utils.run( + cmd=cmd, + cmd_name='Hmmbuild', + log_stdout=False, + log_stderr=True, + log_on_process_error=True, + ) + + with open(output_hmm_path) as f: + hmm = f.read() + + return hmm diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py new file mode 100644 index 000000000..e425b768c --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/hmmsearch.py @@ -0,0 +1,150 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""A Python wrapper for hmmsearch - search profile against a sequence db.""" + +import os +import tempfile + +from absl import logging +from alphafold3.data import parsers +from alphafold3.data.tools import hmmbuild +from alphafold3.data.tools import subprocess_utils + + +class Hmmsearch(object): + """Python wrapper of the hmmsearch binary.""" + + def __init__( + self, + *, + binary_path: str, + hmmbuild_binary_path: str, + database_path: str, + alphabet: str = 'amino', + filter_f1: float | None = None, + filter_f2: float | None = None, + filter_f3: float | None = None, + e_value: float | None = None, + inc_e: float | None = None, + dom_e: float | None = None, + incdom_e: float | None = None, + filter_max: bool = False, + ): + """Initializes the Python hmmsearch wrapper. + + Args: + binary_path: The path to the hmmsearch executable. + hmmbuild_binary_path: The path to the hmmbuild executable. Used to build + an hmm from an input a3m. + database_path: The path to the hmmsearch database (FASTA format). + alphabet: Chain type e.g. amino, rna, dna. + filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off. + filter_f2: Viterbi pre-filter, set to >1.0 to turn off. + filter_f3: Forward pre-filter, set to >1.0 to turn off. + e_value: E-value criteria for inclusion in tblout. + inc_e: E-value criteria for inclusion in MSA/next round. + dom_e: Domain e-value criteria for inclusion in tblout. + incdom_e: Domain e-value criteria for inclusion of domains in MSA/next + round. + filter_max: Remove all filters, will ignore all filter_f* settings. + + Raises: + RuntimeError: If hmmsearch binary not found within the path. 
+ """ + self.binary_path = binary_path + self.hmmbuild_runner = hmmbuild.Hmmbuild( + alphabet=alphabet, binary_path=hmmbuild_binary_path + ) + self.database_path = database_path + flags = [] + if filter_max: + flags.append('--max') + else: + if filter_f1 is not None: + flags.extend(('--F1', filter_f1)) + if filter_f2 is not None: + flags.extend(('--F2', filter_f2)) + if filter_f3 is not None: + flags.extend(('--F3', filter_f3)) + + if e_value is not None: + flags.extend(('-E', e_value)) + if inc_e is not None: + flags.extend(('--incE', inc_e)) + if dom_e is not None: + flags.extend(('--domE', dom_e)) + if incdom_e is not None: + flags.extend(('--incdomE', incdom_e)) + + self.flags = tuple(map(str, flags)) + + subprocess_utils.check_binary_exists( + path=self.binary_path, name='hmmsearch' + ) + + if not os.path.exists(self.database_path): + logging.error('Could not find hmmsearch database %s', database_path) + raise ValueError(f'Could not find hmmsearch database {database_path}') + + def query_with_hmm(self, hmm: str) -> str: + """Queries the database using hmmsearch using a given hmm.""" + with tempfile.TemporaryDirectory() as query_tmp_dir: + hmm_input_path = os.path.join(query_tmp_dir, 'query.hmm') + sto_out_path = os.path.join(query_tmp_dir, 'output.sto') + with open(hmm_input_path, 'w') as f: + f.write(hmm) + + cmd = [ + self.binary_path, + '--noali', # Don't include the alignment in stdout. + *('--cpu', '8'), + ] + # If adding flags, we have to do so before the output and input: + if self.flags: + cmd.extend(self.flags) + cmd.extend([ + *('-A', sto_out_path), + hmm_input_path, + self.database_path, + ]) + + subprocess_utils.run( + cmd=cmd, + cmd_name='Hmmsearch', + log_stdout=False, + log_stderr=True, + log_on_process_error=True, + ) + + with open(sto_out_path) as f: + a3m_out = parsers.convert_stockholm_to_a3m( + f, remove_first_row_gaps=False, linewidth=60 + ) + + return a3m_out + + def query_with_a3m(self, a3m_in: str) -> str: + """Query the database using hmmsearch using a given a3m.""" + + # Only the "fast" model construction makes sense with A3M, as it doesn't + # have any way to annotate reference columns. + hmm = self.hmmbuild_runner.build_profile_from_a3m(a3m_in) + return self.query_with_hmm(hmm) + + def query_with_sto( + self, msa_sto: str, model_construction: str = 'fast' + ) -> str: + """Queries the database using hmmsearch using a given stockholm msa.""" + hmm = self.hmmbuild_runner.build_profile_from_sto( + msa_sto, model_construction=model_construction + ) + return self.query_with_hmm(hmm) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/jackhmmer.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/jackhmmer.py new file mode 100644 index 000000000..ecfa6e88d --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/jackhmmer.py @@ -0,0 +1,135 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Library to run Jackhmmer from Python."""
+
+import os
+import tempfile
+
+from absl import logging
+from alphafold3.data import parsers
+from alphafold3.data.tools import msa_tool
+from alphafold3.data.tools import subprocess_utils
+
+
+class Jackhmmer(msa_tool.MsaTool):
+  """Python wrapper of the Jackhmmer binary."""
+
+  def __init__(
+      self,
+      *,
+      binary_path: str,
+      database_path: str,
+      n_cpu: int = 8,
+      n_iter: int = 3,
+      e_value: float | None = 1e-3,
+      z_value: float | int | None = None,
+      max_sequences: int = 5000,
+      filter_f1: float = 5e-4,
+      filter_f2: float = 5e-5,
+      filter_f3: float = 5e-7,
+  ):
+    """Initializes the Python Jackhmmer wrapper.
+
+    Args:
+      binary_path: The path to the jackhmmer executable.
+      database_path: The path to the jackhmmer database (FASTA format).
+      n_cpu: The number of CPUs to give Jackhmmer.
+      n_iter: The number of Jackhmmer iterations.
+      e_value: The E-value, see Jackhmmer docs for more details.
+      z_value: The Z-value representing the number of comparisons done (i.e.
+        the correct database size) for E-value calculation.
+      max_sequences: Maximum number of sequences to return in the MSA.
+      filter_f1: MSV and biased composition pre-filter, set to >1.0 to turn off.
+      filter_f2: Viterbi pre-filter, set to >1.0 to turn off.
+      filter_f3: Forward pre-filter, set to >1.0 to turn off.
+
+    Raises:
+      RuntimeError: If Jackhmmer binary not found within the path.
+    """
+    self.binary_path = binary_path
+    self.database_path = database_path
+
+    subprocess_utils.check_binary_exists(
+        path=self.binary_path, name='Jackhmmer'
+    )
+
+    if not os.path.exists(self.database_path):
+      raise ValueError(f'Could not find Jackhmmer database {database_path}')
+
+    self.n_cpu = n_cpu
+    self.n_iter = n_iter
+    self.e_value = e_value
+    self.z_value = z_value
+    self.max_sequences = max_sequences
+    self.filter_f1 = filter_f1
+    self.filter_f2 = filter_f2
+    self.filter_f3 = filter_f3
+
+  def query(self, target_sequence: str) -> msa_tool.MsaToolResult:
+    """Queries the database using Jackhmmer."""
+    logging.info('Query sequence: %s', target_sequence)
+    with tempfile.TemporaryDirectory() as query_tmp_dir:
+      input_fasta_path = os.path.join(query_tmp_dir, 'query.fasta')
+      subprocess_utils.create_query_fasta_file(
+          sequence=target_sequence, path=input_fasta_path
+      )
+
+      output_sto_path = os.path.join(query_tmp_dir, 'output.sto')
+
+      # F1/F2/F3 are the expected proportions of sequences that pass each of
+      # the filtering stages (which get progressively more expensive);
+      # reducing them speeds up the pipeline at the expense of sensitivity.
+      # They are currently set very low to make querying Mgnify run in a
+      # reasonable amount of time.
+      cmd_flags = [
+          *('-o', '/dev/null'),  # Don't pollute stdout with Jackhmmer output.
+          *('-A', output_sto_path),
+          '--noali',
+          *('--F1', str(self.filter_f1)),
+          *('--F2', str(self.filter_f2)),
+          *('--F3', str(self.filter_f3)),
+          *('--cpu', str(self.n_cpu)),
+          *('-N', str(self.n_iter)),
+      ]
+
+      # Report only sequences with E-values <= x in per-sequence output.
+      if self.e_value is not None:
+        cmd_flags.extend(['-E', str(self.e_value)])
+
+        # Use the same value as the reporting e-value (`-E` flag).
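+        # (--incE sets the inclusion threshold deciding which hits enter the
+        # -A alignment output; -E alone only affects the reported hit tables.)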
+ cmd_flags.extend(['--incE', str(self.e_value)]) + + if self.z_value is not None: + cmd_flags.extend(['-Z', str(self.z_value)]) + + cmd = ( + [self.binary_path] + + cmd_flags + + [input_fasta_path, self.database_path] + ) + + subprocess_utils.run( + cmd=cmd, + cmd_name='Jackhmmer', + log_stdout=False, + log_stderr=True, + log_on_process_error=True, + ) + + with open(output_sto_path) as f: + a3m = parsers.convert_stockholm_to_a3m( + f, max_sequences=self.max_sequences + ) + + return msa_tool.MsaToolResult( + target_sequence=target_sequence, a3m=a3m, e_value=self.e_value + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/msa_tool.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/msa_tool.py new file mode 100644 index 000000000..0fc329c6a --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/msa_tool.py @@ -0,0 +1,31 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Defines protocol for MSA tools.""" + +import dataclasses +from typing import Protocol + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class MsaToolResult: + """The result of a MSA tool query.""" + + target_sequence: str + e_value: float + a3m: str + + +class MsaTool(Protocol): + """Interface for MSA tools.""" + + def query(self, target_sequence: str) -> MsaToolResult: + """Runs the MSA tool on the target sequence.""" diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/nhmmer.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/nhmmer.py new file mode 100644 index 000000000..9fce2a67b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/nhmmer.py @@ -0,0 +1,167 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Library to run Nhmmer from Python."""
+
+import os
+import pathlib
+import tempfile
+from typing import Final
+
+from absl import logging
+from alphafold3.data import parsers
+from alphafold3.data.tools import hmmalign
+from alphafold3.data.tools import hmmbuild
+from alphafold3.data.tools import msa_tool
+from alphafold3.data.tools import subprocess_utils
+
+_SHORT_SEQUENCE_CUTOFF: Final[int] = 50
+
+
+class Nhmmer(msa_tool.MsaTool):
+  """Python wrapper of the Nhmmer binary."""
+
+  def __init__(
+      self,
+      binary_path: str,
+      hmmalign_binary_path: str,
+      hmmbuild_binary_path: str,
+      database_path: str,
+      n_cpu: int = 8,
+      e_value: float = 1e-3,
+      max_sequences: int = 5000,
+      filter_f3: float = 1e-5,
+      alphabet: str | None = None,
+      strand: str | None = None,
+  ):
+    """Initializes the Python Nhmmer wrapper.
+
+    Args:
+      binary_path: Path to the Nhmmer binary.
+      hmmalign_binary_path: Path to the Hmmalign binary.
+      hmmbuild_binary_path: Path to the Hmmbuild binary.
+      database_path: MSA database path to search against. This can be either a
+        FASTA (slow) or HMMERDB produced from the FASTA using the makehmmerdb
+        binary. The HMMERDB is ~10x faster but experimental.
+      n_cpu: The number of CPUs to give Nhmmer.
+      e_value: The E-value, see Nhmmer docs for more details.
+      max_sequences: Maximum number of sequences to return in the MSA.
+      filter_f3: Forward pre-filter, set to >1.0 to turn off.
+      alphabet: The alphabet to assert when building a profile with hmmbuild.
+        This must be 'rna', 'dna', or None.
+      strand: "watson" searches only the query sequence, "crick" searches only
+        its reverse-complement, and the default None searches both strands.
+
+    Raises:
+      RuntimeError: If Nhmmer binary not found within the path.
+    """
+    self._binary_path = binary_path
+    self._hmmalign_binary_path = hmmalign_binary_path
+    self._hmmbuild_binary_path = hmmbuild_binary_path
+    self._db_path = database_path
+
+    subprocess_utils.check_binary_exists(path=self._binary_path, name='Nhmmer')
+
+    if strand and strand not in {'watson', 'crick'}:
+      raise ValueError(f'Invalid {strand=}, only "watson" or "crick" supported')
+
+    if alphabet and alphabet not in {'rna', 'dna'}:
+      raise ValueError(f'Invalid {alphabet=}, only "rna" or "dna" supported')
+
+    self._e_value = e_value
+    self._n_cpu = n_cpu
+    self._max_sequences = max_sequences
+    self._filter_f3 = filter_f3
+    self._alphabet = alphabet
+    self._strand = strand
+
+  def query(self, target_sequence: str) -> msa_tool.MsaToolResult:
+    """Query the database using Nhmmer."""
+    logging.info('Query sequence: %s', target_sequence)
+
+    with tempfile.TemporaryDirectory() as query_tmp_dir:
+      input_a3m_path = os.path.join(query_tmp_dir, 'query.a3m')
+      output_sto_path = os.path.join(query_tmp_dir, 'output.sto')
+      pathlib.Path(output_sto_path).touch()
+      subprocess_utils.create_query_fasta_file(
+          sequence=target_sequence, path=input_a3m_path
+      )
+
+      cmd_flags = [
+          *('-o', '/dev/null'),  # Don't pollute stdout with nhmmer output.
+          '--noali',  # Don't include the alignment in stdout.
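+          # (Hit alignments are instead captured in Stockholm format via the
+          # -A flag appended below.)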
+          *('--cpu', str(self._n_cpu)),
+      ]
+
+      cmd_flags.extend(['-E', str(self._e_value)])
+
+      if self._alphabet:
+        cmd_flags.extend([f'--{self._alphabet}'])
+
+      if self._strand is not None:
+        cmd_flags.extend([f'--{self._strand}'])
+
+      cmd_flags.extend(['-A', output_sto_path])
+      # As recommended by RNAcentral for short sequences.
+      if (
+          self._alphabet == 'rna'
+          and len(target_sequence) < _SHORT_SEQUENCE_CUTOFF
+      ):
+        cmd_flags.extend(['--F3', str(0.02)])
+      else:
+        cmd_flags.extend(['--F3', str(self._filter_f3)])
+
+      # The input A3M and the db are the last two arguments.
+      cmd_flags.extend((input_a3m_path, self._db_path))
+
+      cmd = [self._binary_path, *cmd_flags]
+
+      subprocess_utils.run(
+          cmd=cmd,
+          cmd_name='Nhmmer',
+          log_stdout=False,
+          log_stderr=True,
+          log_on_process_error=True,
+      )
+
+      if os.path.getsize(output_sto_path) > 0:
+        with open(output_sto_path) as f:
+          a3m_out = parsers.convert_stockholm_to_a3m(
+              f, max_sequences=self._max_sequences - 1  # Query not included.
+          )
+        # Nhmmer hits are generally shorter than the query sequence. To get an
+        # MSA of width equal to the query sequence, align hits to the query
+        # profile.
+        logging.info('Aligning output a3m of size %d bytes', len(a3m_out))
+
+        aligner = hmmalign.Hmmalign(self._hmmalign_binary_path)
+        target_sequence_fasta = f'>query\n{target_sequence}\n'
+        profile_builder = hmmbuild.Hmmbuild(
+            binary_path=self._hmmbuild_binary_path, alphabet=self._alphabet
+        )
+        profile = profile_builder.build_profile_from_a3m(target_sequence_fasta)
+        a3m_out = aligner.align_sequences_to_profile(
+            profile=profile, sequences_a3m=a3m_out
+        )
+        a3m_out = ''.join([target_sequence_fasta, a3m_out])
+
+        # Parse the output a3m to remove line breaks.
+        a3m = '\n'.join(
+            [f'>{n}\n{s}' for s, n in parsers.lazy_parse_fasta_string(a3m_out)]
+        )
+      else:
+        # Nhmmer returns an empty file if there are no hits.
+        # In this case return only the query sequence.
+        a3m = f'>query\n{target_sequence}'
+
+      return msa_tool.MsaToolResult(
+          target_sequence=target_sequence, e_value=self._e_value, a3m=a3m
+      )
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/rdkit_utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/rdkit_utils.py
new file mode 100644
index 000000000..be505a3ea
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/rdkit_utils.py
@@ -0,0 +1,520 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google.
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Tools for calculating features for ligands.""" + +import collections +from collections.abc import Mapping, Sequence + +from absl import logging +from alphafold3.cpp import cif_dict +import numpy as np +import rdkit.Chem as rd_chem + + +_RDKIT_MMCIF_TO_BOND_TYPE: Mapping[str, rd_chem.BondType] = { + 'SING': rd_chem.BondType.SINGLE, + 'DOUB': rd_chem.BondType.DOUBLE, + 'TRIP': rd_chem.BondType.TRIPLE, +} + +_RDKIT_BOND_TYPE_TO_MMCIF: Mapping[rd_chem.BondType, str] = { + v: k for k, v in _RDKIT_MMCIF_TO_BOND_TYPE.items() +} + +_RDKIT_BOND_STEREO_TO_MMCIF: Mapping[rd_chem.BondStereo, str] = { + rd_chem.BondStereo.STEREONONE: 'N', + rd_chem.BondStereo.STEREOE: 'E', + rd_chem.BondStereo.STEREOZ: 'Z', + rd_chem.BondStereo.STEREOCIS: 'Z', + rd_chem.BondStereo.STEREOTRANS: 'E', +} + + +class MolFromMmcifError(Exception): + """Raised when conversion from mmCIF to RDKit Mol fails.""" + + +class UnsupportedMolBondError(Exception): + """Raised when we try to handle unsupported RDKit bonds.""" + + +def _populate_atoms_in_mol( + mol: rd_chem.Mol, + atom_names: Sequence[str], + atom_types: Sequence[str], + atom_charges: Sequence[int], + implicit_hydrogens: bool, + ligand_name: str, + atom_leaving_flags: Sequence[str], +): + """Populate the atoms of a Mol given atom features. + + Args: + mol: Mol object. + atom_names: Names of the atoms. + atom_types: Types of the atoms. + atom_charges: Charges of the atoms. + implicit_hydrogens: Whether to mark the atoms to allow implicit Hs. + ligand_name: Name of the ligand which the atoms are in. + atom_leaving_flags: Whether the atom is possibly a leaving atom. Values from + the CCD column `_chem_comp_atom.pdbx_leaving_atom_flag`. The expected + values are 'Y' (yes), 'N' (no), '?' (unknown/unset, interpreted as no). + + Raises: + ValueError: If atom type is invalid. + """ + # Map atom names to the position they will take in the rdkit molecule. + atom_name_to_idx = {name: i for i, name in enumerate(atom_names)} + + for atom_name, atom_type, atom_charge, atom_leaving_flag in zip( + atom_names, atom_types, atom_charges, atom_leaving_flags, strict=True + ): + try: + if atom_type == 'X': + atom_type = '*' + atom = rd_chem.Atom(atom_type) + except RuntimeError as e: + raise ValueError(f'Failed to use atom type: {str(e)}') from e + + if not implicit_hydrogens: + atom.SetNoImplicit(True) + + atom.SetProp('atom_name', atom_name) + atom.SetProp('atom_leaving_flag', atom_leaving_flag) + atom.SetFormalCharge(atom_charge) + residue_info = rd_chem.AtomPDBResidueInfo() + residue_info.SetName(_format_atom_name(atom_name, atom_type)) + residue_info.SetIsHeteroAtom(True) + residue_info.SetResidueName(ligand_name) + residue_info.SetResidueNumber(1) + atom.SetPDBResidueInfo(residue_info) + atom_index = mol.AddAtom(atom) + assert atom_index == atom_name_to_idx[atom_name] + + +def _populate_bonds_in_mol( + mol: rd_chem.Mol, + atom_names: Sequence[str], + bond_begins: Sequence[str], + bond_ends: Sequence[str], + bond_orders: Sequence[str], + bond_is_aromatics: Sequence[bool], +): + """Populate the bonds of a Mol given bond features. + + Args: + mol: Mol object. + atom_names: Names of atoms in the molecule. + bond_begins: Names of atoms at the beginning of the bond. + bond_ends: Names of atoms at the end of the bond. + bond_orders: What order the bonds are. 
+    bond_is_aromatics: Whether the bonds are aromatic.
+  """
+  atom_name_to_idx = {name: i for i, name in enumerate(atom_names)}
+  for begin, end, bond_type, is_aromatic in zip(
+      bond_begins, bond_ends, bond_orders, bond_is_aromatics, strict=True
+  ):
+    begin_idx, end_idx = atom_name_to_idx[begin], atom_name_to_idx[end]
+    # AddBond returns the new number of bonds, so the bond just added lives at
+    # index bond_idx - 1.
+    bond_idx = mol.AddBond(begin_idx, end_idx, bond_type)
+    mol.GetBondWithIdx(bond_idx - 1).SetIsAromatic(is_aromatic)
+
+
+def _sanitize_mol(mol, sort_alphabetically, remove_hydrogens) -> rd_chem.Mol:
+  # https://www.rdkit.org/docs/source/rdkit.Chem.rdmolops.html#rdkit.Chem.rdmolops.SanitizeMol
+  # Kekulize, check valencies, set aromaticity, conjugation and hybridization.
+  # This can repair e.g. incorrect aromatic flags.
+  rd_chem.SanitizeMol(mol)
+  if sort_alphabetically:
+    mol = sort_atoms_by_name(mol)
+  if remove_hydrogens:
+    mol = rd_chem.RemoveHs(mol)
+  return mol
+
+
+def _add_conformer_to_mol(mol, conformer, force_parse) -> None:
+  # Create conformer and use it to assign stereochemistry.
+  if conformer is not None:
+    try:
+      mol.AddConformer(conformer)
+      rd_chem.AssignStereochemistryFrom3D(mol)
+    except ValueError as e:
+      logging.warning('Failed to parse conformer: %s', e)
+      if not force_parse:
+        raise
+
+
+def mol_from_ccd_cif(
+    mol_cif: cif_dict.CifDict,
+    *,
+    force_parse: bool = False,
+    sort_alphabetically: bool = True,
+    remove_hydrogens: bool = True,
+    implicit_hydrogens: bool = False,
+) -> rd_chem.Mol:
+  """Creates an rdkit Mol object from a CCD mmcif data block.
+
+  The atoms are renumbered so that their names are in alphabetical order and
+  these names are placed on the atoms under property 'atom_name'.
+  Only hydrogens which are not required to define the molecule are removed.
+  For example, hydrogens that define stereochemistry around a double bond are
+  retained. See this link for more details:
+  https://www.rdkit.org/docs/source/rdkit.Chem.rdmolops.html#rdkit.Chem.rdmolops.RemoveHs
+
+  Args:
+    mol_cif: An mmcif object representing a molecule.
+    force_parse: If True, assumes missing aromatic flags are false, substitutes
+      deuterium for hydrogen, assumes missing charges are 0 and ignores missing
+      conformer / stereochemistry information.
+    sort_alphabetically: True: sort atoms alphabetically; False: keep CCD order.
+    remove_hydrogens: If True, remove non-important hydrogens.
+    implicit_hydrogens: Sets a marker on the atom that allows implicit Hs.
+
+  Returns:
+    An rdkit molecule, with the atoms sorted by name.
+
+  Raises:
+    MolFromMmcifError: If conversion from mmcif to rdkit Mol fails. A more
+      detailed error is available as this error's cause.
+  """
+  # Read data fields.
+  try:
+    atom_names, atom_types, atom_charges, atom_leaving_flags = parse_atom_data(
+        mol_cif, force_parse
+    )
+    bond_begins, bond_ends, bond_orders, bond_is_aromatics = parse_bond_data(
+        mol_cif, force_parse
+    )
+    lig_name = mol_cif['_chem_comp.id'][0].rjust(3)
+  except (KeyError, ValueError) as e:
+    raise MolFromMmcifError from e
+
+  # Build Rdkit molecule.
+  mol = rd_chem.RWMol()
+
+  # Per atom features.
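+  # (Atoms are populated first so that the name-to-index map used when adding
+  # bonds matches the RDKit atom indices assigned here.)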
+ try: + _populate_atoms_in_mol( + mol=mol, + atom_names=atom_names, + atom_types=atom_types, + atom_charges=atom_charges, + implicit_hydrogens=implicit_hydrogens, + ligand_name=lig_name, + atom_leaving_flags=atom_leaving_flags, + ) + except (ValueError, RuntimeError) as e: + raise MolFromMmcifError from e + + _populate_bonds_in_mol( + mol, atom_names, bond_begins, bond_ends, bond_orders, bond_is_aromatics + ) + + try: + conformer = _parse_ideal_conformer(mol_cif) + except (KeyError, ValueError) as e: + logging.warning('Failed to parse ideal conformer: %s', e) + if not force_parse: + raise MolFromMmcifError from e + conformer = None + + mol.UpdatePropertyCache(strict=False) + + try: + _add_conformer_to_mol(mol, conformer, force_parse) + mol = _sanitize_mol(mol, sort_alphabetically, remove_hydrogens) + except ( + ValueError, + rd_chem.KekulizeException, + rd_chem.AtomValenceException, + ) as e: + raise MolFromMmcifError from e + + return mol + + +def mol_to_ccd_cif( + mol: rd_chem.Mol, + component_id: str, + pdbx_smiles: str | None = None, + include_hydrogens: bool = True, +) -> cif_dict.CifDict: + """Creates a CCD-like mmcif data block from an rdkit Mol object. + + Only a subset of associated mmcif fields is populated, but that is + sufficient for further usage, e.g. in featurization code. + + Atom names can be specified via `atom_name` property. For atoms with + unspecified value of that property, the name is assigned based on element type + and the order in the Mol object. + + If the Mol object has associated conformers, atom positions from the first of + them will be populated in the resulting mmcif file. + + Args: + mol: An rdkit molecule. + component_id: Name of the molecule to use in the resulting mmcif. That is + equivalent to CCD code. + pdbx_smiles: If specified, the value will be used to populate + `_chem_comp.pdbx_smiles`. + include_hydrogens: Whether to include atom and bond data involving + hydrogens. + + Returns: + An mmcif data block corresponding for the given rdkit molecule. + + Raises: + UnsupportedMolBond: When a molecule contains a bond that can't be + represented with mmcif. + """ + mol = rd_chem.Mol(mol) + if include_hydrogens: + mol = rd_chem.AddHs(mol) + rd_chem.Kekulize(mol) + + if mol.GetNumConformers() > 0: + ideal_conformer = mol.GetConformer(0).GetPositions() + ideal_conformer = np.vectorize(lambda x: f'{x:.3f}')(ideal_conformer) + else: + # No data will be populated in the resulting mmcif if the molecule doesn't + # have any conformers attached to it. 
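+    # (The _chem_comp_atom.pdbx_model_Cartn_*_ideal columns are then simply
+    # left out when the atom loop below writes the table.)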
+ ideal_conformer = None + + mol_cif = collections.defaultdict(list) + mol_cif['data_'] = [component_id] + mol_cif['_chem_comp.id'] = [component_id] + if pdbx_smiles: + mol_cif['_chem_comp.pdbx_smiles'] = [pdbx_smiles] + + mol = assign_atom_names_from_graph(mol, keep_existing_names=True) + + for atom_idx, atom in enumerate(mol.GetAtoms()): + element = atom.GetSymbol() + if not include_hydrogens and element in ('H', 'D'): + continue + + mol_cif['_chem_comp_atom.comp_id'].append(component_id) + mol_cif['_chem_comp_atom.atom_id'].append(atom.GetProp('atom_name')) + mol_cif['_chem_comp_atom.type_symbol'].append(atom.GetSymbol().upper()) + mol_cif['_chem_comp_atom.charge'].append(str(atom.GetFormalCharge())) + if ideal_conformer is not None: + coords = ideal_conformer[atom_idx] + mol_cif['_chem_comp_atom.pdbx_model_Cartn_x_ideal'].append(coords[0]) + mol_cif['_chem_comp_atom.pdbx_model_Cartn_y_ideal'].append(coords[1]) + mol_cif['_chem_comp_atom.pdbx_model_Cartn_z_ideal'].append(coords[2]) + + for bond in mol.GetBonds(): + atom1 = bond.GetBeginAtom() + atom2 = bond.GetEndAtom() + if not include_hydrogens and ( + atom1.GetSymbol() in ('H', 'D') or atom2.GetSymbol() in ('H', 'D') + ): + continue + mol_cif['_chem_comp_bond.comp_id'].append(component_id) + mol_cif['_chem_comp_bond.atom_id_1'].append( + bond.GetBeginAtom().GetProp('atom_name') + ) + mol_cif['_chem_comp_bond.atom_id_2'].append( + bond.GetEndAtom().GetProp('atom_name') + ) + try: + bond_type = bond.GetBondType() + # Older versions of RDKit did not have a DATIVE bond type. Convert it to + # SINGLE to match the AF3 training setup. + if bond_type == rd_chem.BondType.DATIVE: + bond_type = rd_chem.BondType.SINGLE + mol_cif['_chem_comp_bond.value_order'].append( + _RDKIT_BOND_TYPE_TO_MMCIF[bond_type] + ) + mol_cif['_chem_comp_bond.pdbx_stereo_config'].append( + _RDKIT_BOND_STEREO_TO_MMCIF[bond.GetStereo()] + ) + except KeyError as e: + raise UnsupportedMolBondError from e + mol_cif['_chem_comp_bond.pdbx_aromatic_flag'].append( + 'Y' if bond.GetIsAromatic() else 'N' + ) + + return cif_dict.CifDict(mol_cif) + + +def _format_atom_name(atom_name: str, atom_type: str) -> str: + """Formats an atom name to fit in the four characters specified in PDB. + + See for example the following note on atom name formatting in PDB files: + https://www.cgl.ucsf.edu/chimera/docs/UsersGuide/tutorials/pdbintro.html#note1 + + Args: + atom_name: The unformatted atom name. + atom_type: The atom element symbol. + + Returns: + formatted_atom_name: The formatted 4-character atom name. + """ + atom_name = atom_name.strip() + atom_type = atom_type.strip().upper() + if len(atom_name) == 1: + return atom_name.rjust(2).ljust(4) + elif len(atom_name) == 2: + if atom_name == atom_type: + return atom_name.ljust(4) + return atom_name.center(4) + elif len(atom_name) == 3: + if atom_name[:2] == atom_type: + return atom_name.ljust(4) + return atom_name.rjust(4) + elif len(atom_name) == 4: + return atom_name + else: + raise ValueError( + f'Atom name `{atom_name}` has more than four characters ' + 'or is an empty string.' + ) + + +def parse_atom_data( + mol_cif: cif_dict.CifDict | Mapping[str, Sequence[str]], force_parse: bool +) -> tuple[Sequence[str], Sequence[str], Sequence[int], Sequence[str]]: + """Parses atoms. 
If force_parse is True, fix deuterium and missing charge.""" + atom_types = [t.capitalize() for t in mol_cif['_chem_comp_atom.type_symbol']] + atom_names = mol_cif['_chem_comp_atom.atom_id'] + atom_charges = mol_cif['_chem_comp_atom.charge'] + atom_leaving_flags = ['?'] * len(atom_names) + if '_chem_comp_atom.pdbx_leaving_atom_flag' in mol_cif: + atom_leaving_flags = mol_cif['_chem_comp_atom.pdbx_leaving_atom_flag'] + + if force_parse: + # Replace missing charges with 0. + atom_charges = [charge if charge != '?' else '0' for charge in atom_charges] + # Deuterium for hydrogen. + atom_types = [type_ if type_ != 'D' else 'H' for type_ in atom_types] + + atom_charges = [int(atom_charge) for atom_charge in atom_charges] + return atom_names, atom_types, atom_charges, atom_leaving_flags + + +def parse_bond_data( + mol_cif: cif_dict.CifDict | Mapping[str, Sequence[str]], force_parse: bool +) -> tuple[ + Sequence[str], Sequence[str], Sequence[rd_chem.BondType], Sequence[bool] +]: + """Parses bond data. If force_parse is True, ignore missing aromatic flags.""" + # The bond table isn't present if there are no bonds. Use [] in that case. + begin_atoms = mol_cif.get('_chem_comp_bond.atom_id_1', []) + end_atoms = mol_cif.get('_chem_comp_bond.atom_id_2', []) + orders = mol_cif.get('_chem_comp_bond.value_order', []) + bond_types = [_RDKIT_MMCIF_TO_BOND_TYPE[order] for order in orders] + + try: + aromatic_flags = mol_cif.get('_chem_comp_bond.pdbx_aromatic_flag', []) + is_aromatic = [{'Y': True, 'N': False}[flag] for flag in aromatic_flags] + except KeyError: + if force_parse: + # Set them all to not aromatic. + is_aromatic = [False for _ in begin_atoms] + else: + raise + + return begin_atoms, end_atoms, bond_types, is_aromatic + + +def _parse_ideal_conformer(mol_cif: cif_dict.CifDict) -> rd_chem.Conformer: + """Builds a conformer containing the ideal coordinates from the CCD. + + Args: + mol_cif: An mmcif object representing a molecule. + + Returns: + An rdkit conformer filled with the ideal positions from the mmcif. + + Raises: + ValueError: if the positions can't be interpreted. + """ + atom_x = [ + float(x) for x in mol_cif['_chem_comp_atom.pdbx_model_Cartn_x_ideal'] + ] + atom_y = [ + float(y) for y in mol_cif['_chem_comp_atom.pdbx_model_Cartn_y_ideal'] + ] + atom_z = [ + float(z) for z in mol_cif['_chem_comp_atom.pdbx_model_Cartn_z_ideal'] + ] + atom_positions = zip(atom_x, atom_y, atom_z, strict=True) + + conformer = rd_chem.Conformer(len(atom_x)) + for atom_index, atom_position in enumerate(atom_positions): + conformer.SetAtomPosition(atom_index, atom_position) + + return conformer + + +def sort_atoms_by_name(mol: rd_chem.Mol) -> rd_chem.Mol: + """Sorts the atoms in the molecule by their names.""" + atom_names = { + atom.GetProp('atom_name'): atom.GetIdx() for atom in mol.GetAtoms() + } + + # Sort the name, int tuples by the names. + sorted_atom_names = sorted(atom_names.items()) + + # Zip these tuples back together to the sorted indices. + _, new_order = zip(*sorted_atom_names, strict=True) + + # Reorder the molecule. + # new_order is effectively an argsort of the names. + return rd_chem.RenumberAtoms(mol, new_order) + + +def assign_atom_names_from_graph( + mol: rd_chem.Mol, + keep_existing_names: bool = False, +) -> rd_chem.Mol: + """Assigns atom names from the molecular graph. + + The atom name is stored as an atom property 'atom_name', accessible + with atom.GetProp('atom_name'). If the property is already specified, and + keep_existing_names is True we keep the original name. 
+ + We traverse the graph in the order of the rdkit atom index and give each atom + a name equal to '{ELEMENT_TYPE}_{INDEX}'. E.g. C5 is the name for the fifth + unnamed carbon encountered. + + NOTE: A new mol is returned, the original is not changed in place. + + Args: + mol: + keep_existing_names: If True, atoms that already have the atom_name property + will keep their assigned names. + + Returns: + A new mol, with potentially new 'atom_name' properties. + """ + mol = rd_chem.Mol(mol) + + specified_atom_names = { + atom.GetProp('atom_name') + for atom in mol.GetAtoms() + if atom.HasProp('atom_name') and keep_existing_names + } + + element_counts = collections.Counter() + for atom in mol.GetAtoms(): + if not atom.HasProp('atom_name') or not keep_existing_names: + element = atom.GetSymbol() + while True: + element_counts[element] += 1 + new_name = f'{element}{element_counts[element]}' + if new_name not in specified_atom_names: + break + atom.SetProp('atom_name', new_name) + + return mol diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/subprocess_utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/subprocess_utils.py new file mode 100644 index 000000000..e1a34e9d6 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/data/tools/subprocess_utils.py @@ -0,0 +1,107 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Helper functions for launching external tools.""" + +from collections.abc import Sequence +import os +import subprocess +import time +from typing import Any + +from absl import logging + + +def create_query_fasta_file(sequence: str, path: str, linewidth: int = 80): + """Creates a fasta file with the sequence with line width limit.""" + with open(path, 'w') as f: + f.write('>query\n') + + i = 0 + while i < len(sequence): + f.write(f'{sequence[i:(i + linewidth)]}\n') + i += linewidth + + +def check_binary_exists(path: str, name: str) -> None: + """Checks if a binary exists on the given path and raises otherwise.""" + if not os.path.exists(path): + raise RuntimeError(f'{name} binary not found at {path}') + + +def run( + cmd: Sequence[str], + cmd_name: str, + log_on_process_error: bool = False, + log_stderr: bool = False, + log_stdout: bool = False, + max_out_streams_len: int | None = 500_000, + **run_kwargs, +) -> subprocess.CompletedProcess[Any]: + """Launches a subprocess, times it, and checks for errors. + + Args: + cmd: Command to launch. + cmd_name: Human-readable command name to be used in logs. + log_on_process_error: Whether to use `logging.error` to log the process' + stderr on failure. + log_stderr: Whether to log the stderr of the command. + log_stdout: Whether to log the stdout of the command. + max_out_streams_len: Max length of prefix of stdout and stderr included in + the exception message. Set to `None` to disable truncation. + **run_kwargs: Any other kwargs for `subprocess.run`. 
+ + Returns: + The completed process object. + + Raises: + RuntimeError: if the process completes with a non-zero return code. + """ + + logging.info('Launching subprocess "%s"', ' '.join(cmd)) + + start_time = time.time() + try: + completed_process = subprocess.run( + cmd, + check=True, + stderr=subprocess.PIPE, + stdout=subprocess.PIPE, + text=True, + **run_kwargs, + ) + except subprocess.CalledProcessError as e: + if log_on_process_error: + # Logs have a 15k character limit, so log the error line by line. + logging.error('%s failed. %s stderr begin:', cmd_name, cmd_name) + for error_line in e.stderr.splitlines(): + if stripped_error_line := error_line.strip(): + logging.error(stripped_error_line) + logging.error('%s stderr end.', cmd_name) + + error_msg = ( + f'{cmd_name} failed' + f'\nstdout:\n{e.stdout[:max_out_streams_len]}\n' + f'\nstderr:\n{e.stderr[:max_out_streams_len]}' + ) + raise RuntimeError(error_msg) from e + end_time = time.time() + + logging.info('Finished %s in %.3f seconds', cmd_name, end_time - start_time) + stdout, stderr = completed_process.stdout, completed_process.stderr + + if log_stdout and stdout: + logging.info('%s stdout:\n%s', cmd_name, stdout) + + if log_stderr and stderr: + logging.info('%s stderr:\n%s', cmd_name, stderr) + + return completed_process diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/atom_layout/atom_layout.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/atom_layout/atom_layout.py new file mode 100644 index 000000000..6afdda52b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/atom_layout/atom_layout.py @@ -0,0 +1,1193 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Helper functions for different atom layouts and conversion between them.""" + +import collections +from collections.abc import Mapping, Sequence +import math +import dataclasses +import types +from typing import Any, TypeAlias + +import numpy as np +import mindspore as ms +from mindspore import ops +from rdkit import Chem + +from alphafold3 import structure +from alphafold3.constants import atom_types +from alphafold3.constants import chemical_component_sets +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.structure import chemical_components as struc_chem_comps + + +xnp_ndarray: TypeAlias = np.ndarray # pylint: disable=invalid-name +NumpyIndex: TypeAlias = Any + + +def _assign_atom_names_from_graph(mol: Chem.Mol) -> Chem.Mol: + """Assigns atom names from the molecular graph. + + The atom name is stored as an atom property 'atom_name', accessible with + atom.GetProp('atom_name'). If the property is already specified, we keep the + original name. + + We traverse the graph in the order of the rdkit atom index and give each atom + a name equal to '{ELEMENT_TYPE}_{INDEX}'. E.g. C5 is the name for the fifth + unnamed carbon encountered. 
+
+  NOTE: A new mol is returned, the original is not changed in place.
+
+  Args:
+    mol: RDKit molecule.
+
+  Returns:
+    A new mol, with potentially new 'atom_name' properties.
+  """
+  mol = Chem.Mol(mol)
+
+  specified_atom_names = {
+      a.GetProp('atom_name') for a in mol.GetAtoms() if a.HasProp('atom_name')
+  }
+
+  element_counts = collections.Counter()
+  for atom in mol.GetAtoms():
+    if not atom.HasProp('atom_name'):
+      element = atom.GetSymbol()
+      while True:
+        element_counts[element] += 1
+        new_name = f'{element}{element_counts[element]}'
+        if new_name not in specified_atom_names:
+          break
+      atom.SetProp('atom_name', new_name)
+
+  return mol
+
+
+@dataclasses.dataclass(frozen=True)
+class AtomLayout:
+  """Atom layout in a fixed shape (usually 1-dim or 2-dim).
+
+  Examples for atom layouts are atom37, atom14, and similar.
+  All members are np.ndarrays with the same shape, e.g.
+  - [num_atoms]
+  - [num_residues, max_atoms_per_residue]
+  - [num_fragments, max_fragments_per_residue]
+  All string arrays should have dtype=object to avoid pitfalls with Numpy's
+  fixed-size strings.
+
+  Attributes:
+    atom_name: np.ndarray of str: atom names (e.g. 'CA', 'NE2'), padding
+      elements have an empty string (''), None or any other value that maps to
+      False for .astype(bool). mmCIF field: _atom_site.label_atom_id.
+    res_id: np.ndarray of int: residue index (usually starting from 1), padding
+      elements can have an arbitrary value. mmCIF field:
+      _atom_site.label_seq_id.
+    chain_id: np.ndarray of str: chain names (e.g. 'A', 'B'), padding elements
+      can have an arbitrary value. mmCIF field: _atom_site.label_asym_id.
+    atom_element: np.ndarray of str: atom elements (e.g. 'C', 'N', 'O'), padding
+      elements have an empty string (''), None or any other value that maps to
+      False for .astype(bool). mmCIF field: _atom_site.type_symbol.
+    res_name: np.ndarray of str: residue names (e.g. 'ARG', 'TRP'), padding
+      elements can have an arbitrary value. mmCIF field:
+      _atom_site.label_comp_id.
+    chain_type: np.ndarray of str: chain types (e.g. 'polypeptide(L)'), padding
+      elements can have an arbitrary value. mmCIF field: _entity_poly.type OR
+      _entity.type (for non-polymers).
+    shape: shape of the layout (just returns atom_name.shape)
+  """
+
+  atom_name: np.ndarray
+  res_id: np.ndarray
+  chain_id: np.ndarray
+  atom_element: np.ndarray | None = None
+  res_name: np.ndarray | None = None
+  chain_type: np.ndarray | None = None
+
+  def __post_init__(self):
+    """Assert all arrays have the same shape."""
+    attribute_names = (
+        'atom_name',
+        'atom_element',
+        'res_name',
+        'res_id',
+        'chain_id',
+        'chain_type',
+    )
+    _assert_all_arrays_have_same_shape(
+        obj=self,
+        expected_shape=self.atom_name.shape,
+        attribute_names=attribute_names,
+    )
+    # atom_name must have dtype object, such that we can convert it to bool to
+    # obtain the mask.
+    if self.atom_name.dtype != object:
+      raise ValueError(
+          'atom_name must have dtype object, such that it can '
+          'be converted to bool to obtain the mask'
+      )
+
+  def __getitem__(self, key: NumpyIndex) -> 'AtomLayout':
+    return AtomLayout(
+        atom_name=self.atom_name[key],
+        res_id=self.res_id[key],
+        chain_id=self.chain_id[key],
+        atom_element=(
+            self.atom_element[key] if self.atom_element is not None else None
+        ),
+        res_name=(self.res_name[key] if self.res_name is not None else None),
+        chain_type=(
+            self.chain_type[key] if self.chain_type is not None else None
+        ),
+    )
+
+  def __eq__(self, other: 'AtomLayout') -> bool:
+    if not np.array_equal(self.atom_name, other.atom_name):
+      return False
+
+    mask = self.atom_name.astype(bool)
+    # Check essential fields.
+    for field in ('res_id', 'chain_id'):
+      my_arr = getattr(self, field)
+      other_arr = getattr(other, field)
+      if not np.array_equal(my_arr[mask], other_arr[mask]):
+        return False
+
+    # Check optional fields.
+    for field in ('atom_element', 'res_name', 'chain_type'):
+      my_arr = getattr(self, field)
+      other_arr = getattr(other, field)
+      if (
+          my_arr is not None
+          and other_arr is not None
+          and not np.array_equal(my_arr[mask], other_arr[mask])
+      ):
+        return False
+
+    return True
+
+  def copy_and_pad_to(self, shape: tuple[int, ...]) -> 'AtomLayout':
+    """Copies and pads the layout to the requested shape.
+
+    Args:
+      shape: new shape for the atom layout
+
+    Returns:
+      a copy of the atom layout padded to the requested shape
+
+    Raises:
+      ValueError: incompatible shapes.
+    """
+    if len(shape) != len(self.atom_name.shape):
+      raise ValueError(
+          f'Incompatible shape {shape}. Current layout has shape {self.shape}.'
+      )
+    if any(new < old for old, new in zip(self.atom_name.shape, shape)):
+      raise ValueError(
+          "Can't pad to a smaller shape. Current layout has shape "
+          f'{self.shape} and you requested shape {shape}.'
+      )
+    pad_width = [
+        (0, new - old) for old, new in zip(self.atom_name.shape, shape)
+    ]
+    pad_val = np.array('', dtype=object)
+    return AtomLayout(
+        atom_name=np.pad(self.atom_name, pad_width, constant_values=pad_val),
+        res_id=np.pad(self.res_id, pad_width, constant_values=0),
+        chain_id=np.pad(self.chain_id, pad_width, constant_values=pad_val),
+        atom_element=(
+            np.pad(self.atom_element, pad_width, constant_values=pad_val)
+            if self.atom_element is not None
+            else None
+        ),
+        res_name=(
+            np.pad(self.res_name, pad_width, constant_values=pad_val)
+            if self.res_name is not None
+            else None
+        ),
+        chain_type=(
+            np.pad(self.chain_type, pad_width, constant_values=pad_val)
+            if self.chain_type is not None
+            else None
+        ),
+    )
+
+  def to_array(self) -> np.ndarray:
+    """Stacks the fields to a numpy array with shape (6, ).
+
+    Creates a pure numpy array of type `object` by stacking the 6 fields of the
+    AtomLayout, i.e.
(atom_name, atom_element, res_name, res_id, chain_id, + chain_type). This method together with from_array() provides an easy way to + apply pure numpy methods like np.concatenate() to `AtomLayout`s. + + Returns: + np.ndarray of object with shape (6, ), e.g. + array([['N', 'CA', 'C', ..., 'CB', 'CG', 'CD'], + ['N', 'C', 'C', ..., 'C', 'C', 'C'], + ['LEU', 'LEU', 'LEU', ..., 'PRO', 'PRO', 'PRO'], + [1, 1, 1, ..., 403, 403, 403], + ['A', 'A', 'A', ..., 'D', 'D', 'D'], + ['polypeptide(L)', 'polypeptide(L)', ..., 'polypeptide(L)']], + dtype=object) + """ + if ( + self.atom_element is None + or self.res_name is None + or self.chain_type is None + ): + raise ValueError('All optional fields need to be present.') + + return np.stack(dataclasses.astuple(self), axis=0) + + @classmethod + def from_array(cls, arr: np.ndarray) -> 'AtomLayout': + """Creates an AtomLayout object from a numpy array with shape (6, ...). + + see also to_array() + Args: + arr: np.ndarray of object with shape (6, ) + + Returns: + AtomLayout object with shape () + """ + if arr.shape[0] != 6: + raise ValueError( + 'Given array must have shape (6, ...) to match the 6 fields of ' + 'AtomLayout (atom_name, atom_element, res_name, res_id, chain_id, ' + f'chain_type). Your array has {arr.shape=}' + ) + return cls(*arr) + + @property + def shape(self) -> tuple[int, ...]: + return self.atom_name.shape + + +@dataclasses.dataclass(frozen=True) +class Residues: + """List of residues with meta data. + + Attributes: + res_name: np.ndarray of str [num_res], e.g. 'ARG', 'TRP' + res_id: np.ndarray of int [num_res] + chain_id: np.ndarray of str [num_res], e.g. 'A', 'B' + chain_type: np.ndarray of str [num_res], e.g. 'polypeptide(L)' + is_start_terminus: np.ndarray of bool [num_res] + is_end_terminus: np.ndarray of bool [num_res] + deprotonation: (optional) np.ndarray of set() [num_res], e.g. {'HD1', 'HE2'} + smiles_string: (optional) np.ndarray of str [num_res], e.g. 'Cc1ccccc1' + shape: shape of the layout (just returns res_name.shape) + """ + + res_name: np.ndarray + res_id: np.ndarray + chain_id: np.ndarray + chain_type: np.ndarray + is_start_terminus: np.ndarray + is_end_terminus: np.ndarray + deprotonation: np.ndarray | None = None + smiles_string: np.ndarray | None = None + + def __post_init__(self): + """Assert all arrays are 1D have the same shape.""" + attribute_names = ( + 'res_name', + 'res_id', + 'chain_id', + 'chain_type', + 'is_start_terminus', + 'is_end_terminus', + 'deprotonation', + 'smiles_string', + ) + _assert_all_arrays_have_same_shape( + obj=self, + expected_shape=(self.res_name.shape[0],), + attribute_names=attribute_names, + ) + + def __getitem__(self, key: NumpyIndex) -> 'Residues': + return Residues( + res_name=self.res_name[key], + res_id=self.res_id[key], + chain_id=self.chain_id[key], + chain_type=self.chain_type[key], + is_start_terminus=self.is_start_terminus[key], + is_end_terminus=self.is_end_terminus[key], + deprotonation=( + self.deprotonation[key] if self.deprotonation is not None else None + ), + smiles_string=( + self.smiles_string[key] if self.smiles_string is not None else None + ), + ) + + def __eq__(self, other: 'Residues') -> bool: + return all( + np.array_equal(getattr(self, field.name), + getattr(other, field.name)) + for field in dataclasses.fields(self) + ) + + @property + def shape(self) -> tuple[int, ...]: + return self.res_name.shape + + +@dataclasses.dataclass # (frozen=True) +class GatherInfo: + """Gather indices to translate from one atom layout to another. 
+ + All members are np or jnp ndarray (usually 1-dim or 2-dim) with the same + shape, e.g. + - [num_atoms] + - [num_residues, max_atoms_per_residue] + - [num_fragments, max_fragments_per_residue] + + Attributes: + gather_idxs: np or jnp ndarray of int: gather indices into a flattened array + gather_mask: np or jnp ndarray of bool: mask for resulting array + input_shape: np or jnp ndarray of int: the shape of the unflattened input + array + shape: output shape. Just returns gather_idxs.shape + """ + + gather_idxs: ms.Tensor + gather_mask: ms.Tensor + input_shape: ms.Tensor + + def __post_init__(self): + if self.gather_mask.shape != self.gather_idxs.shape: + raise ValueError( + 'All arrays must have the same shape. Got\n' + f'gather_idxs.shape = {self.gather_idxs.shape}\n' + f'gather_mask.shape = {self.gather_mask.shape}\n' + ) + + def __getitem__(self, key: NumpyIndex) -> 'GatherInfo': + return GatherInfo( + gather_idxs=self.gather_idxs[key], + gather_mask=self.gather_mask[key], + input_shape=self.input_shape, + ) + + @property + def shape(self) -> tuple[int, ...]: + return self.gather_idxs.shape + + def as_np_or_jnp(self, xnp: types.ModuleType) -> 'GatherInfo': + return GatherInfo( + gather_idxs=xnp.array(self.gather_idxs), + gather_mask=xnp.array(self.gather_mask), + input_shape=xnp.array(self.input_shape), + ) + + def as_dict( + self, + key_prefix: str | None = None, + ) -> dict[str, xnp_ndarray]: + prefix = f'{key_prefix}:' if key_prefix else '' + return { + prefix + 'gather_idxs': self.gather_idxs, + prefix + 'gather_mask': self.gather_mask, + prefix + 'input_shape': self.input_shape, + } + + @classmethod + def from_dict( + cls, + d: Mapping[str, xnp_ndarray], + key_prefix: str | None = None, + ) -> 'GatherInfo': + """Creates GatherInfo from a given dictionary.""" + prefix = f'{key_prefix}:' if key_prefix else '' + return cls( + gather_idxs=d[prefix + 'gather_idxs'], + gather_mask=d[prefix + 'gather_mask'], + input_shape=d[prefix + 'input_shape'], + ) + + +def fill_in_optional_fields( + minimal_atom_layout: AtomLayout, + reference_atoms: AtomLayout, +) -> AtomLayout: + """Fill in the optional fields (atom_element, res_name, chain_type). + + Extracts the optional fields (atom_element, res_name, chain_type) from a + flat reference layout and fills them into the fields from this layout. + + Args: + minimal_atom_layout: An AtomLayout that only contains the essential fields + (atom_name, res_id, chain_id). + reference_atoms: A flat layout that contains all fields for all atoms. + + Returns: + An AtomLayout that contains all fields. + + Raises: + ValueError: Reference atoms layout is not flat. + ValueError: Missing atoms in reference. 
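+
+  Example (an illustrative sketch, not part of the original patch; assumes
+  `full_layout` is a flat AtomLayout that carries all six fields):
+
+    minimal = AtomLayout(atom_name=names, res_id=res_ids, chain_id=chain_ids)
+    completed = fill_in_optional_fields(minimal, reference_atoms=full_layout)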
+ """ + if len(reference_atoms.shape) > 1: + raise ValueError('Only flat layouts are supported as reference.') + ref_to_self = compute_gather_idxs( + source_layout=reference_atoms, target_layout=minimal_atom_layout + ) + atom_mask = minimal_atom_layout.atom_name.astype(bool) + missing_atoms_mask = atom_mask & ~ref_to_self.gather_mask + if np.any(missing_atoms_mask): + raise ValueError( + f'{np.sum(missing_atoms_mask)} missing atoms in reference: ' + f'{minimal_atom_layout[missing_atoms_mask]}' + ) + + def _convert_str_array(gather: GatherInfo, arr: np.ndarray): + output = arr[gather.gather_idxs] + output[~gather.gather_mask] = '' + return output + + return dataclasses.replace( + minimal_atom_layout, + atom_element=_convert_str_array( + ref_to_self, reference_atoms.atom_element + ), + res_name=_convert_str_array(ref_to_self, reference_atoms.res_name), + chain_type=_convert_str_array(ref_to_self, reference_atoms.chain_type), + ) + + +def guess_deprotonation(residues: Residues) -> Residues: + """Convenience function to create a plausible deprotonation field. + + Assumes a pH of 7 and always prefers HE2 over HD1 for HIS. + Args: + residues: a Residues object without a depronotation field + + Returns: + a Residues object with a depronotation field + """ + num_residues = residues.res_name.shape[0] + deprotonation = np.empty(num_residues, dtype=object) + deprotonation_at_ph7 = { + 'ASP': 'HD2', + 'GLU': 'HE2', + 'HIS': 'HD1', + } + for idx, res_name in enumerate(residues.res_name): + deprotonation[idx] = set() + if res_name in deprotonation_at_ph7: + deprotonation[idx].add(deprotonation_at_ph7[res_name]) + if residues.is_end_terminus[idx]: + deprotonation[idx].add('HXT') + + return dataclasses.replace(residues, deprotonation=deprotonation) + + +def atom_layout_from_structure( + struct: structure.Structure, + *, + fix_non_standard_polymer_res: bool = False, +) -> AtomLayout: + """Extract AtomLayout from a Structure.""" + + if not fix_non_standard_polymer_res: + return AtomLayout( + atom_name=np.array(struct.atom_name, dtype=object), + atom_element=np.array(struct.atom_element, dtype=object), + res_name=np.array(struct.res_name, dtype=object), + res_id=np.array(struct.res_id, dtype=int), + chain_id=np.array(struct.chain_id, dtype=object), + chain_type=np.array(struct.chain_type, dtype=object), + ) + + # Target lists. 
+ target_atom_names = [] + target_atom_elements = [] + target_res_ids = [] + target_res_names = [] + target_chain_ids = [] + target_chain_types = [] + + for atom in struct.iter_atoms(): + target_atom_names.append(atom['atom_name']) + target_atom_elements.append(atom['atom_element']) + target_res_ids.append(atom['res_id']) + target_chain_ids.append(atom['chain_id']) + target_chain_types.append(atom['chain_type']) + if mmcif_names.is_standard_polymer_type(atom['chain_type']): + fixed_res_name = mmcif_names.fix_non_standard_polymer_res( + res_name=atom['res_name'], chain_type=atom['chain_type'] + ) + target_res_names.append(fixed_res_name) + else: + target_res_names.append(atom['res_name']) + + return AtomLayout( + atom_name=np.array(target_atom_names, dtype=object), + atom_element=np.array(target_atom_elements, dtype=object), + res_name=np.array(target_res_names, dtype=object), + res_id=np.array(target_res_ids, dtype=int), + chain_id=np.array(target_chain_ids, dtype=object), + chain_type=np.array(target_chain_types, dtype=object), + ) + + +def residues_from_structure( + struct: structure.Structure, + *, + include_missing_residues: bool = True, + fix_non_standard_polymer_res: bool = False, +) -> Residues: + """Create a Residues object from a Structure object.""" + + def _get_smiles(res_name): + """Get SMILES string from chemical components.""" + smiles = None + if ( + struct.chemical_components_data is not None + and struct.chemical_components_data.chem_comp is not None + and struct.chemical_components_data.chem_comp.get(res_name) + ): + smiles = struct.chemical_components_data.chem_comp[res_name].pdbx_smiles + return smiles + + res_names_per_chain = struct.chain_res_name_sequence( + include_missing_residues=include_missing_residues, + fix_non_standard_polymer_res=fix_non_standard_polymer_res, + ) + res_name = [] + res_id = [] + chain_id = [] + chain_type = [] + smiles = [] + is_start_terminus = [] + for c in struct.iter_chains(): + if include_missing_residues: + this_res_ids = [ + id for (_, id) in struct.all_residues[c['chain_id']]] + else: + this_res_ids = [ + r['res_id'] + for r in struct.iter_residues() + if r['chain_id'] == c['chain_id'] + ] + fixed_res_names = res_names_per_chain[c['chain_id']] + assert len(this_res_ids) == len( + fixed_res_names + ), f'{len(this_res_ids)} != {len(fixed_res_names)}' + this_start_res_id = min(min(this_res_ids), 1) + this_is_start_terminus = [r == this_start_res_id for r in this_res_ids] + smiles.extend([_get_smiles(res_name) for res_name in fixed_res_names]) + num_res = len(fixed_res_names) + res_name.extend(fixed_res_names) + res_id.extend(this_res_ids) + chain_id.extend([c['chain_id']] * num_res) + chain_type.extend([c['chain_type']] * num_res) + is_start_terminus.extend(this_is_start_terminus) + res_name = np.array(res_name, dtype=object) + res_id = np.array(res_id, dtype=int) + chain_id = np.array(chain_id, dtype=object) + chain_type = np.array(chain_type, dtype=object) + smiles = np.array(smiles, dtype=object) + is_start_terminus = np.array(is_start_terminus, dtype=bool) + + res_uid_to_idx = { + uid: idx for idx, uid in enumerate(zip(chain_id, res_id, strict=True)) + } + + # Start terminus indicates whether residue index is 1 and chain is polymer. + is_polymer = np.isin(chain_type, tuple(mmcif_names.POLYMER_CHAIN_TYPES)) + is_start_terminus = is_start_terminus & is_polymer + + # Start also indicates whether amino acid is attached to H2 or proline to H. 
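+  # (Mid-chain proline has no amide hydrogen, so an 'H' atom on a PRO residue
+  # implies a free N-terminus; for other residues, 'H2' occurs only there.)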
+ start_terminus_atom_index = np.nonzero( + (struct.chain_type == mmcif_names.PROTEIN_CHAIN) + & ( + (struct.atom_name == 'H2') + | ((struct.atom_name == 'H') & (struct.res_name == 'PRO')) + ) + )[0] + + # Translate atom idx to residue idx to assign start terminus. + for atom_idx in start_terminus_atom_index: + res_uid = (struct.chain_id[atom_idx], struct.res_id[atom_idx]) + res_idx = res_uid_to_idx[res_uid] + is_start_terminus[res_idx] = True + + # Infer end terminus: Check for OXT, or in case of + # include_missing_residues==True for the last residue of the chain. + num_all_residues = res_name.shape[0] + is_end_terminus = np.zeros(num_all_residues, dtype=bool) + end_term_atom_idxs = np.nonzero(struct.atom_name == 'OXT')[0] + for atom_idx in end_term_atom_idxs: + res_uid = (struct.chain_id[atom_idx], struct.res_id[atom_idx]) + res_idx = res_uid_to_idx[res_uid] + is_end_terminus[res_idx] = True + + if include_missing_residues: + for idx in range(num_all_residues - 1): + if is_polymer[idx] and chain_id[idx] != chain_id[idx + 1]: + is_end_terminus[idx] = True + if (num_all_residues > 0) and is_polymer[-1]: + is_end_terminus[-1] = True + + # Infer (de-)protonation: Only if hydrogens are given. + num_hydrogens = np.sum( + (struct.atom_element == 'H') & (struct.chain_type == 'polypeptide(L)') + ) + if num_hydrogens > 0: + deprotonation = np.empty(num_all_residues, dtype=object) + all_atom_uids = set( + zip(struct.chain_id, struct.res_id, struct.atom_name, strict=True) + ) + for idx in range(num_all_residues): + deprotonation[idx] = set() + check_hydrogens = set() + if is_end_terminus[idx]: + check_hydrogens.add('HXT') + if res_name[idx] in atom_types.PROTONATION_HYDROGENS: + check_hydrogens.update( + atom_types.PROTONATION_HYDROGENS[res_name[idx]]) + for hydrogen in check_hydrogens: + if (chain_id[idx], res_id[idx], hydrogen) not in all_atom_uids: + deprotonation[idx].add(hydrogen) + else: + deprotonation = None + + return Residues( + res_name=res_name, + res_id=res_id, + chain_id=chain_id, + chain_type=chain_type, + is_start_terminus=is_start_terminus.astype(bool), + is_end_terminus=is_end_terminus, + deprotonation=deprotonation, + smiles_string=smiles, + ) + + +def get_link_drop_atoms( + res_name: str, + chain_type: str, + *, + is_start_terminus: bool, + is_end_terminus: bool, + bonded_atoms: set[str], + drop_ligand_leaving_atoms: bool = False, +) -> set[str]: + """Returns set of atoms that are dropped when this res_name gets linked. + + Args: + res_name: residue name, e.g. 'ARG' + chain_type: chain_type, e.g. 'polypeptide(L)' + is_start_terminus: whether the residue is the n-terminus + is_end_terminus: whether the residue is the c-terminus + bonded_atoms: Names of atoms coming off this residue. + drop_ligand_leaving_atoms: Flag to switch on/off leaving atoms for ligands. + + Returns: + Set of atoms that are dropped when this amino acid gets linked. 
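+
+  Example (illustrative; follows the rules implemented below):
+
+    get_link_drop_atoms('ALA', 'polypeptide(L)', is_start_terminus=False,
+                        is_end_terminus=False, bonded_atoms=set())
+    # -> {'H2', 'H3', 'OXT', 'HXT'}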
+ """ + drop_atoms = set() + if chain_type == mmcif_names.PROTEIN_CHAIN: + if res_name == 'PRO': + if not is_start_terminus: + drop_atoms.update({'H', 'H2', 'H3'}) + if not is_end_terminus: + drop_atoms.update({'OXT', 'HXT'}) + else: + if not is_start_terminus: + drop_atoms.update({'H2', 'H3'}) + if not is_end_terminus: + drop_atoms.update({'OXT', 'HXT'}) + elif chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES: + if not is_start_terminus: + drop_atoms.update({'OP3'}) + elif ( + drop_ligand_leaving_atoms and chain_type in mmcif_names.LIGAND_CHAIN_TYPES + ): + if res_name in { + *chemical_component_sets.GLYCAN_OTHER_LIGANDS, + *chemical_component_sets.GLYCAN_LINKING_LIGANDS, + }: + if 'O1' not in bonded_atoms: + drop_atoms.update({'O1'}) + return drop_atoms + + +def get_bonded_atoms( + polymer_ligand_bonds: AtomLayout, + ligand_ligand_bonds: AtomLayout, + res_id: int, + chain_id: str, +) -> set[str]: + """Finds the res_name on the opposite end of the bond, if a bond exists. + + Args: + polymer_ligand_bonds: Bond information for polymer-ligand pairs. + ligand_ligand_bonds: Bond information for ligand-ligand pairs. + res_id: residue id in question. + chain_id: chain id of residue in question. + + Returns: + res_name of bonded atom. + """ + bonded_atoms = set() + if polymer_ligand_bonds: + # Filter before searching to speed this up. + bond_idx = np.logical_and( + polymer_ligand_bonds.res_id == res_id, + polymer_ligand_bonds.chain_id == chain_id, + ).any(axis=1) + relevant_polymer_bonds = polymer_ligand_bonds[bond_idx] + for atom_names, res_ids, chain_ids in zip( + relevant_polymer_bonds.atom_name, + relevant_polymer_bonds.res_id, + relevant_polymer_bonds.chain_id, + ): + if (res_ids[0], chain_ids[0]) == (res_id, chain_id): + bonded_atoms.add(atom_names[0]) + elif (res_ids[1], chain_ids[1]) == (res_id, chain_id): + bonded_atoms.add(atom_names[1]) + if ligand_ligand_bonds: + bond_idx = np.logical_and( + ligand_ligand_bonds.res_id == res_id, + ligand_ligand_bonds.chain_id == chain_id, + ).any(axis=1) + relevant_ligand_bonds = ligand_ligand_bonds[bond_idx] + for atom_names, res_ids, chain_ids in zip( + relevant_ligand_bonds.atom_name, + relevant_ligand_bonds.res_id, + relevant_ligand_bonds.chain_id, + ): + if (res_ids[0], chain_ids[0]) == (res_id, chain_id): + bonded_atoms.add(atom_names[0]) + elif (res_ids[1], chain_ids[1]) == (res_id, chain_id): + bonded_atoms.add(atom_names[1]) + return bonded_atoms + + +def make_flat_atom_layout( + residues: Residues, + ccd: chemical_components.Ccd, + polymer_ligand_bonds: AtomLayout | None = None, + ligand_ligand_bonds: AtomLayout | None = None, + *, + with_hydrogens: bool = False, + skip_unk_residues: bool = True, + drop_ligand_leaving_atoms: bool = False, +) -> AtomLayout: + """Make a flat atom layout for given residues. + + Create a flat layout from a `Residues` object. The required atoms for each + amino acid type are taken from the CCD, hydrogens and oxygens are dropped to + make the linked residues. Terminal OXT's and protonation state for the + hydrogens come from the `Residues` object. + + Args: + residues: a `Residues` object. + ccd: The chemical components dictionary. + polymer_ligand_bonds: Bond information for polymer-ligand pairs. + ligand_ligand_bonds: Bond information for ligand-ligand pairs. 
+ with_hydrogens: whether to create hydrogens + skip_unk_residues: whether to skip 'UNK' resides -- default is True to be + compatible with the rest of AlphaFold that does not predict atoms for + unknown residues + drop_ligand_leaving_atoms: Flag to switch on/ off leaving atoms for ligands. + + Returns: + an `AtomLayout` object + """ + num_res = residues.res_name.shape[0] + + # Target lists. + target_atom_names = [] + target_atom_elements = [] + target_res_ids = [] + target_res_names = [] + target_chain_ids = [] + target_chain_types = [] + + for idx in range(num_res): + # skip 'UNK' residues if requested + if ( + skip_unk_residues + and residues.res_name[idx] in residue_names.UNKNOWN_TYPES + ): + continue + + # Get the atoms for this residue type from CCD. + if ccd.get(residues.res_name[idx]): + res_atoms = struc_chem_comps.get_all_atoms_in_entry( + ccd=ccd, res_name=residues.res_name[idx] + ) + atom_names_elements = list( + zip( + res_atoms['_chem_comp_atom.atom_id'], + res_atoms['_chem_comp_atom.type_symbol'], + strict=True, + ) + ) + elif residues.smiles_string[idx]: + # Get atoms from RDKit via SMILES. + mol = Chem.MolFromSmiles(residues.smiles_string[idx]) + mol = _assign_atom_names_from_graph(mol) + atom_names_elements = [ + (a.GetProp('atom_name'), a.GetSymbol()) for a in mol.GetAtoms() + ] + else: + raise ValueError( + f'{residues.res_name[idx]} not found in CCD and no SMILES string' + ) + + # Remove hydrogens if requested. + if not with_hydrogens: + atom_names_elements = [ + (n, e) for n, e in atom_names_elements if (e != 'H' and e != 'D') + ] + bonded_atoms = get_bonded_atoms( + polymer_ligand_bonds, + ligand_ligand_bonds, + residues.res_id[idx], + residues.chain_id[idx], + ) + # Connect the amino-acids, i.e. remove OXT, HXT and H2. + drop_atoms = get_link_drop_atoms( + res_name=residues.res_name[idx], + chain_type=residues.chain_type[idx], + is_start_terminus=residues.is_start_terminus[idx], + is_end_terminus=residues.is_end_terminus[idx], + bonded_atoms=bonded_atoms, + drop_ligand_leaving_atoms=drop_ligand_leaving_atoms, + ) + + # If deprotonation info is available, remove the specific atoms. + if residues.deprotonation is not None: + drop_atoms.update(residues.deprotonation[idx]) + + atom_names_elements = [ + (n, e) for n, e in atom_names_elements if n not in drop_atoms + ] + + # Append the found atoms to the target lists. 
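+        # (At this point `atom_names_elements` already excludes hydrogens, if
+        # requested, as well as all link/deprotonation drop atoms.)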
+ target_atom_names.extend([n for n, _ in atom_names_elements]) + target_atom_elements.extend([e for _, e in atom_names_elements]) + num_atoms = len(atom_names_elements) + target_res_names.extend([residues.res_name[idx]] * num_atoms) + target_res_ids.extend([residues.res_id[idx]] * num_atoms) + target_chain_ids.extend([residues.chain_id[idx]] * num_atoms) + target_chain_types.extend([residues.chain_type[idx]] * num_atoms) + + return AtomLayout( + atom_name=np.array(target_atom_names, dtype=object), + atom_element=np.array(target_atom_elements, dtype=object), + res_name=np.array(target_res_names, dtype=object), + res_id=np.array(target_res_ids, dtype=int), + chain_id=np.array(target_chain_ids, dtype=object), + chain_type=np.array(target_chain_types, dtype=object), + ) + + +def compute_gather_idxs( + *, + source_layout: AtomLayout, + target_layout: AtomLayout, + fill_value: int = 0, +) -> GatherInfo: + """Produce gather indices and mask to convert from source layout to target.""" + source_uid_to_idx = { + uid: idx + for idx, uid in enumerate( + zip( + source_layout.chain_id.ravel(), + source_layout.res_id.ravel(), + source_layout.atom_name.ravel(), + strict=True, + ) + ) + } + gather_idxs = [] + gather_mask = [] + for uid in zip( + target_layout.chain_id.ravel(), + target_layout.res_id.ravel(), + target_layout.atom_name.ravel(), + strict=True, + ): + if uid in source_uid_to_idx: + gather_idxs.append(source_uid_to_idx[uid]) + gather_mask.append(True) + else: + gather_idxs.append(fill_value) + gather_mask.append(False) + target_shape = target_layout.atom_name.shape + return GatherInfo( + gather_idxs=np.array(gather_idxs, dtype=int).reshape(target_shape), + gather_mask=np.array(gather_mask, dtype=bool).reshape(target_shape), + input_shape=np.array(source_layout.atom_name.shape), + ) + + +def convert( + gather_info: GatherInfo, + arr: xnp_ndarray, + *, + layout_axes: tuple[int, ...] = (0,), +) -> xnp_ndarray: + """Convert an array from one atom layout to another.""" + # Translate negative indices to the corresponding positives. + layout_axes = tuple(i if i >= 0 else i + arr.ndim for i in layout_axes) + + # Ensure that layout_axes are continuous. + layout_axes_begin = layout_axes[0] + layout_axes_end = layout_axes[-1] + 1 + + if layout_axes != tuple(range(layout_axes_begin, layout_axes_end)): + raise ValueError(f'layout_axes must be continuous. Got {layout_axes}.') + layout_shape = arr.shape[layout_axes_begin:layout_axes_end] + + # Ensure that the layout shape is compatible + # with the gather_info. I.e. the first axis size must be equal or greater + # than the gather_info.input_shape, and all subsequent axes sizes must match. + if (len(layout_shape) != gather_info.input_shape.size) or ( + isinstance(gather_info.input_shape, np.ndarray) + and ( + (layout_shape[0] < gather_info.input_shape[0]) + or (np.any(layout_shape[1:] != gather_info.input_shape[1:])) + ) + ): + raise ValueError( + 'Input array layout axes are incompatible. You specified layout ' + f'axes {layout_axes} with an input array of shape {arr.shape}, but ' + f'the gather info expects shape {gather_info.input_shape}. ' + 'Your first axis size must be equal or greater than the ' + 'gather_info.input_shape, and all subsequent axes sizes must ' + 'match.' + ) + + # Compute the shape of the input array with flattened layout. 
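+    # Illustrative example (not in the original code): with layout_axes=(1, 2)
+    # and arr of shape (num_samples, num_res, max_atoms, channels), the layout
+    # axes are flattened to (num_samples, num_res * max_atoms, channels); the
+    # gather below then yields (num_samples, *target_layout_shape, channels).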
+    batch_shape = arr.shape[:layout_axes_begin]
+    features_shape = arr.shape[layout_axes_end:]
+    arr_flattened_shape = batch_shape + \
+        (int(np.prod(layout_shape)),) + features_shape
+
+    # Flatten input array and perform the gather.
+    arr_flattened = arr.reshape(arr_flattened_shape)
+    if layout_axes_begin == 0:
+        out_arr = arr_flattened[gather_info.gather_idxs, ...]
+    elif layout_axes_begin == 1:
+        out_arr = arr_flattened[:, gather_info.gather_idxs, ...]
+    elif layout_axes_begin == 2:
+        out_arr = arr_flattened[:, :, gather_info.gather_idxs, ...]
+    elif layout_axes_begin == 3:
+        out_arr = arr_flattened[:, :, :, gather_info.gather_idxs, ...]
+    elif layout_axes_begin == 4:
+        out_arr = arr_flattened[:, :, :, :, gather_info.gather_idxs, ...]
+    else:
+        raise ValueError(
+            'Only 4 batch axes supported. If you need more, the code '
+            'is easy to extend.'
+        )
+
+    # Broadcast the mask and apply it.
+    broadcasted_mask_shape = (
+        (1,) * len(batch_shape)
+        + gather_info.gather_mask.shape
+        + (1,) * len(features_shape)
+    )
+    out_arr *= gather_info.gather_mask.reshape(broadcasted_mask_shape)
+    return out_arr
+
+
+def convert_ms(
+    gather_info: GatherInfo,
+    arr: ms.Tensor,
+    *,
+    layout_axes: tuple[int, ...] = (0,),
+) -> ms.Tensor:
+    """Convert a MindSpore tensor from one atom layout to another."""
+    # Translate negative indices to the corresponding positives.
+    layout_axes = tuple(i if i >= 0 else i + arr.ndim for i in layout_axes)
+
+    # Ensure that layout_axes are continuous.
+    layout_axes_begin = layout_axes[0]
+    layout_axes_end = layout_axes[-1] + 1
+
+    if layout_axes != tuple(range(layout_axes_begin, layout_axes_end)):
+        raise ValueError(f'layout_axes must be continuous. Got {layout_axes}.')
+    layout_shape = arr.shape[layout_axes_begin:layout_axes_end]
+
+    # NOTE: the shape-compatibility check done in `convert` above is skipped
+    # here; the layout axes are assumed to match `gather_info.input_shape`.
+
+    # Compute the shape of the input array with flattened layout.
+    batch_shape = arr.shape[:layout_axes_begin]
+    features_shape = arr.shape[layout_axes_end:]
+    arr_flattened_shape = batch_shape + \
+        (int(math.prod(layout_shape)),) + features_shape
+
+    # Flatten input array and perform the gather. The gather indices are
+    # stored as a NumPy array, so convert them to a Tensor for ops.gather.
+    arr_flattened = arr.reshape(arr_flattened_shape)
+    out_arr = ops.gather(
+        arr_flattened, ms.Tensor(gather_info.gather_idxs), axis=layout_axes_begin
+    )
+
+    # Broadcast the mask and apply it.
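+    # (Entries missing from the source layout were gathered from the fill
+    # index, so the mask multiplication below zeroes them out.)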
+    broadcasted_mask_shape = (
+        (1,) * len(batch_shape)
+        + gather_info.gather_mask.shape
+        + (1,) * len(features_shape)
+    )
+    out_arr *= ms.Tensor(gather_info.gather_mask.reshape(broadcasted_mask_shape))
+    return out_arr.astype(ms.float32)
+
+
+def make_structure(
+    flat_layout: AtomLayout,
+    atom_coords: np.ndarray,
+    name: str,
+    *,
+    atom_b_factors: np.ndarray | None = None,
+    all_physical_residues: Residues | None = None,
+) -> structure.Structure:
+    """Returns a Structure from a flat layout and atom coordinates.
+
+    The provided flat_layout must be 1-dim and must not contain any padding
+    elements. The flat_layout.atom_name must conform to the OpenMM/CCD standard
+    and must not contain deuterium.
+
+    Args:
+      flat_layout: flat 1-dim AtomLayout without padding elements.
+      atom_coords: np.ndarray of float, shape (num_atoms, 3).
+      name: str: the name (usually the PDB id), e.g. '1uao'.
+      atom_b_factors: np.ndarray of float, shape (num_atoms,) or None. If None,
+        they will be set to all zeros.
+      all_physical_residues: a Residues object that contains all physically
+        existing residues, i.e. also those residues that have no resolved
+        atoms. This is common in experimental structures, but also appears in
+        predicted structures for 'UNK' or other non-standard residue types,
+        where the model does not predict coordinates. This will be used to
+        create the `all_residues` field of the structure object.
+    """
+
+    if flat_layout.atom_name.ndim != 1 or not np.all(
+        flat_layout.atom_name.astype(bool)
+    ):
+        raise ValueError(
+            'flat_layout must be 1-dim and must not contain any padding element'
+        )
+    if (
+        flat_layout.atom_element is None
+        or flat_layout.res_name is None
+        or flat_layout.chain_type is None
+    ):
+        raise ValueError('All optional fields must be present.')
+
+    if atom_b_factors is None:
+        atom_b_factors = np.zeros(atom_coords.shape[:-1])
+
+    if all_physical_residues is not None:
+        # Create the all_residues field from a Residues object.
+        # (Unfortunately there is no central place to keep the chain_types in
+        # the structure class, so we drop them here.)
+        all_residues = collections.defaultdict(list)
+        for chain_id, res_id, res_name in zip(
+            all_physical_residues.chain_id,
+            all_physical_residues.res_id,
+            all_physical_residues.res_name,
+            strict=True,
+        ):
+            all_residues[chain_id].append((res_name, res_id))
+    else:
+        # Create the all_residues field from the flat_layout.
+        all_residues = collections.defaultdict(list)
+        if flat_layout.chain_id.shape[0] > 0:
+            all_residues[flat_layout.chain_id[0]].append(
+                (flat_layout.res_name[0], flat_layout.res_id[0])
+            )
+        for i in range(1, flat_layout.shape[0]):
+            if (
+                flat_layout.chain_id[i] != flat_layout.chain_id[i - 1]
+                or flat_layout.res_name[i] != flat_layout.res_name[i - 1]
+                or flat_layout.res_id[i] != flat_layout.res_id[i - 1]
+            ):
+                all_residues[flat_layout.chain_id[i]].append(
+                    (flat_layout.res_name[i], flat_layout.res_id[i])
+                )
+
+    return structure.from_atom_arrays(
+        name=name,
+        all_residues=dict(all_residues),
+        chain_id=flat_layout.chain_id,
+        chain_type=flat_layout.chain_type,
+        res_id=flat_layout.res_id.astype(np.int32),
+        res_name=flat_layout.res_name,
+        atom_name=flat_layout.atom_name,
+        atom_element=flat_layout.atom_element,
+        atom_x=atom_coords[..., 0],
+        atom_y=atom_coords[..., 1],
+        atom_z=atom_coords[..., 2],
+        atom_b_factor=atom_b_factors,
+    )
+
+
+def _assert_all_arrays_have_same_shape(
+    *,
+    obj: AtomLayout | Residues | GatherInfo,
+    expected_shape: tuple[int, ...],
+    attribute_names: Sequence[str],
+) -> None:
+    """Checks that
given attributes of the object have the expected shape.""" + attribute_shapes_description = [] + all_shapes_are_valid = True + + for attribute_name in attribute_names: + attribute = getattr(obj, attribute_name) + + if attribute is None: + attribute_shape = None + else: + attribute_shape = attribute.shape + + if attribute_shape is not None and expected_shape != attribute_shape: + all_shapes_are_valid = False + + attribute_shape_name = attribute_name + '.shape' + attribute_shapes_description.append( + f'{attribute_shape_name:25} = {attribute_shape}' + ) + + if not all_shapes_are_valid: + raise ValueError( + f'All arrays must have the same shape ({expected_shape=}). Got\n' + + '\n'.join(attribute_shapes_description) + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/base_config.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/base_config.py new file mode 100644 index 000000000..0d3a08b62 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/base_config.py @@ -0,0 +1,153 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Config for the protein folding model and experiment.""" + +from collections.abc import Mapping +import copy +import dataclasses +import types +import typing +from typing import Any, ClassVar, TypeVar + + +_T = TypeVar('_T') +_ConfigT = TypeVar('_ConfigT', bound='BaseConfig') + + +def _strip_optional(t: type[Any]) -> type[Any]: + """Transforms type annotations of the form `T | None` to `T`.""" + if typing.get_origin(t) in (typing.Union, types.UnionType): + args = set(typing.get_args(t)) - {types.NoneType} + if len(args) == 1: + return args.pop() + return t + + +_NO_UPDATE = object() + + +class _Autocreate: + + def __init__(self, **defaults: Any): + self.defaults = defaults + + +def autocreate(**defaults: Any) -> Any: + """Marks a field as having a default factory derived from its type.""" + return _Autocreate(**defaults) + + +def _clone_field( + field: dataclasses.Field[_T], new_default: _T +) -> dataclasses.Field[_T]: + if new_default is _NO_UPDATE: + return copy.copy(field) + return dataclasses.field( + default=new_default, + init=True, + kw_only=True, + repr=field.repr, + hash=field.hash, + compare=field.compare, + metadata=field.metadata, + ) + + +@typing.dataclass_transform() +class ConfigMeta(type): + """Metaclass that synthesizes a __post_init__ that coerces dicts to Config subclass instances.""" + + def __new__(mcs, name, bases, classdict): + cls = super().__new__(mcs, name, bases, classdict) + + def _coercable_fields(self) -> Mapping[str, tuple[ConfigMeta, Any]]: + type_hints = typing.get_type_hints(self.__class__) + fields = dataclasses.fields(self.__class__) + field_to_type_and_default = { + field.name: (_strip_optional( + type_hints[field.name]), field.default) + for field in fields + } + coercable_fields = { + f: t + for f, t in field_to_type_and_default.items() + if issubclass(type(t[0]), ConfigMeta) + } + return coercable_fields + + 
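+        # Expose the mapping as a property so it is computed lazily, once the
+        # class is fully defined and forward references in the annotations can
+        # resolve.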
cls._coercable_fields = property(_coercable_fields) + + old_post_init = getattr(cls, '__post_init__', None) + + def _post_init(self) -> None: + # Use get_type_hints instead of Field.type to ensure that forward + # references are resolved. + for field_name, ( + field_type, + field_default, + ) in self._coercable_fields.items(): # pylint: disable=protected-access + field_value = getattr(self, field_name) + if field_value is None: + continue + try: + match field_value: + case _Autocreate(): + # Construct from field defaults. + setattr(self, field_name, field_type( + **field_value.defaults)) + case Mapping(): + # Field value is not yet a `Config` instance; Assume we can create + # one by splatting keys and values. + args = {} + # Apply default args first, if present. + if isinstance(field_default, _Autocreate): + args.update(field_default.defaults) + args.update(field_value) + setattr(self, field_name, field_type(**args)) + case _: + pass + except TypeError as e: + raise TypeError( + f'Failure while coercing field {field_name!r} of' + f' {self.__class__.__qualname__}' + ) from e + if old_post_init: + old_post_init(self) + + cls.__post_init__ = _post_init + + return dataclasses.dataclass(kw_only=True)(cls) + + +class BaseConfig(metaclass=ConfigMeta): + """Config base class. + + Subclassing Config automatically makes the subclass a kw_only dataclass with + a `__post_init__` that coerces Config-subclass field values from mappings to + instances of the right type. + """ + # Provided by dataclasses.make_dataclass + __dataclass_fields__: ClassVar[dict[str, dataclasses.Field[Any]]] + + # Overridden by metaclass + @property + def _coercable_fields(self) -> Mapping[str, tuple[type['BaseConfig'], Any]]: + return {} + + def as_dict(self) -> Mapping[str, Any]: + result = dataclasses.asdict(self) + for field_name in self._coercable_fields: + field_value = getattr(self, field_name, None) + if isinstance(field_value, BaseConfig): + result[field_name] = field_value.as_dict() + return result diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_model.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_model.py new file mode 100644 index 000000000..429fc9e03 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_model.py @@ -0,0 +1,52 @@ +"""Defines interface of a BaseModel.""" + +from collections.abc import Callable, Mapping +import dataclasses +from typing import Any, TypeAlias +from alphafold3 import structure +from alphafold3.model import features +import numpy as np +import mindspore as ms + +ModelResult: TypeAlias = Mapping[str, Any] +ScalarNumberOrArray: TypeAlias = Mapping[str, float | int | np.ndarray] + +# Eval result will contain scalars (e.g. metrics or losses), selected from the +# forward pass outputs or computed in the online evaluation; np.ndarrays or +# jax.Arrays generated from the forward pass outputs (e.g. distogram expected +# distances) or batch inputs; protein structures (predicted and ground-truth). +EvalResultValue: TypeAlias = ( + float | int | np.ndarray | ms.Tensor | structure.Structure +) +# Eval result may be None for some metrics if they are not computable. +EvalResults: TypeAlias = Mapping[str, EvalResultValue | None] +# Interface metrics are all floats or None. +InterfaceMetrics: TypeAlias = Mapping[str, float | None] +# Interface results are a mapping from interface name to mappings from score +# type to metric value. 
+InterfaceResults: TypeAlias = Mapping[str, Mapping[str, InterfaceMetrics]]
+# Eval output consists of full eval results and a dict of interface metrics.
+EvalOutput: TypeAlias = tuple[EvalResults, InterfaceResults]
+
+# The upstream haiku/JAX `ForwardFn` alias (the signature of the `apply`
+# method of hk.transform_with_state called on a BaseModel) is not used in
+# this MindSpore port.
+
+
+@dataclasses.dataclass(frozen=True)
+class InferenceResult:
+    """Postprocessed model result."""
+
+    # Predicted protein structure.
+    predicted_structure: structure.Structure = dataclasses.field()
+    # Useful numerical data (scalars or arrays) to be saved at inference time.
+    numerical_data: ScalarNumberOrArray = dataclasses.field(
+        default_factory=dict)
+    # Smaller numerical data (usually scalar) to be saved as inference metadata.
+    metadata: ScalarNumberOrArray = dataclasses.field(default_factory=dict)
+    # Additional dict for debugging, e.g. raw outputs of a model forward pass.
+    debug_outputs: ModelResult | None = dataclasses.field(default_factory=dict)
+    # Model identifier.
+    model_id: bytes = b''
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py
new file mode 100644
index 000000000..090ab43d3
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/base_modules.py
@@ -0,0 +1,146 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Common modules."""
+
+from collections.abc import Sequence
+import contextlib
+import numbers
+from typing import TypeAlias
+
+import numpy as np
+import mindspore as ms
+from mindspore import nn, mint
+from mindspore.common import initializer
+from mindchemistry.e3.utils import Ncon
+
+# Useful for mocking in tests.
+DEFAULT_PRECISION = None
+
+# Constant from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
+TRUNCATED_NORMAL_STDDEV_FACTOR = np.asarray(
+    0.87962566103423978, dtype=np.float32
+)
+
+
+class LayerNorm(nn.Cell):
+    """LayerNorm module.
+
+    A thin wrapper around ms.nn.LayerNorm; in most cases it can be replaced by
+    ms.nn.LayerNorm directly. Here gamma is the scale and beta is the shift
+    (offset).
+
+    Args:
+        normalized_shape (tuple | list): Shape of the Tensor to normalize.
+        name (str): Name of this layer.
+        begin_norm_axis (int): Axis from which normalization begins.
+        begin_params_axis (int): Axis from which the parameters begin.
+        gamma_init (str): Initializer for gamma.
+        beta_init (str): Initializer for beta.
+        epsilon (float): Epsilon value.
+        dtype (ms.dtype): Parameter dtype.
+        create_beta (bool): Whether beta is a trainable parameter.
+        create_gamma (bool): Whether gamma is a trainable parameter.
+
+    Inputs:
+        - **x** (Tensor) - Tensor of any shape.
+
+    Outputs:
+        Tensor with the same shape as x.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    def __init__(self, normalized_shape, name=None, begin_norm_axis=-1,
+                 begin_params_axis=-1, gamma_init='ones',
+                 beta_init='zeros', epsilon=1e-5, dtype=ms.float32,
+                 create_beta=True, create_gamma=True):
+        super().__init__()
+        if not create_beta:
+            beta_init = 'zeros'
+        if not create_gamma:
+            gamma_init = 'ones'
+        self.layernorm = nn.LayerNorm(normalized_shape[begin_norm_axis:], begin_norm_axis=begin_norm_axis,
+                                      begin_params_axis=begin_params_axis, gamma_init=gamma_init,
+                                      beta_init=beta_init, epsilon=epsilon, dtype=dtype)
+        if not create_beta:
+            self.layernorm.beta.requires_grad = False
+        if not create_gamma:
+            self.layernorm.gamma.requires_grad = False
+        self.dtype = dtype
+
+    def construct(self, x):
+        out = self.layernorm(x.astype(ms.float32)).astype(x.dtype)
+        return out
+
+
+class CustomDense(nn.Cell):
+    """Custom linear module.
+
+    Can be applied to a high-dimensional Tensor, i.e. it supports matmul over
+    more than one dimension. Where AlphaFold uses Einsum instead of Matmul,
+    here we use Ncon instead. If in_shape and out_shape are both int, this
+    layer is equivalent to nn.Dense.
+
+    Args:
+        in_shape (Union(int, List, Tuple)): input shape to be multiplied.
+        out_shape (Union(int, List, Tuple)): output shape to be multiplied.
+
+    Inputs:
+        - **x** (Tensor)
+
+    Outputs:
+        Tensor with the trailing `in_shape` dimensions replaced by `out_shape`.
+
+    Supported Platforms:
+        ``Ascend``
+    """
+
+    def __init__(self, in_shape, out_shape, weight_init="ones", use_bias=False,
+                 bias_init="zero", ndim=None, dtype=ms.float32):
+        # NOTE: the default weight_init was changed from "zeros" to "ones" for
+        # testing.
+        super().__init__()
+        if isinstance(in_shape, int):
+            in_shape = (in_shape,)
+        if isinstance(out_shape, int):
+            out_shape = (out_shape,)
+        self.num_output_dims = len(out_shape)
+        self.num_input_dims = len(in_shape)
+        if ndim is None:
+            ndim = len(in_shape) + 1
+        if weight_init in ["relu", "linear"]:
+            self.weight = custom_initializer(
+                weight_init, in_shape + out_shape, dtype=dtype)
+        else:
+            self.weight = ms.Parameter(initializer.initializer(
+                weight_init, in_shape + out_shape, dtype=dtype))
+        self.use_bias = use_bias
+        if self.use_bias:
+            self.bias = ms.Parameter(
+                initializer.initializer(bias_init, out_shape, dtype=dtype))
+        ncon_list1 = [-i - 1 for i in range(ndim - self.num_input_dims)] + [
+            i + 1 for i in range(len(in_shape))]
+        ncon_list2 = (ncon_list1[ndim - self.num_input_dims:]) + \
+            [-i - ndim + self.num_input_dims - 1 for i in range(len(out_shape))]
+        self.ncon = Ncon([ncon_list1, ncon_list2])
+
+        in_letters = 'abcde'[: self.num_input_dims]
+        out_letters = 'hijkl'[: self.num_output_dims]
+        self.equation = f'...{in_letters}, {in_letters}{out_letters}->...{out_letters}'
+
+    def construct(self, x):
+        if self.use_bias:
+            output = self.ncon([x, self.weight]) + self.bias
+        else:
+            output = self.ncon([x, self.weight])
+        return output
+
+
+def custom_initializer(initializer_name, input_shape, dtype=ms.float32):
+    """Truncated-normal initializer with scale derived from `input_shape`."""
+    noise_scale = ms.Tensor(1.0)
+    for channel_dim in input_shape:
+        noise_scale /= channel_dim
+    if initializer_name == 'relu':
+        noise_scale *= 2
+    stddev = ms.ops.sqrt(noise_scale)
+    stddev = stddev / ms.Tensor(TRUNCATED_NORMAL_STDDEV_FACTOR)
+    param = ms.Parameter(initializer.initializer(
+        initializer.TruncatedNormal(stddev, 0), input_shape, dtype))
+    return param
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py
new file mode 100644
index 000000000..2b18d6122
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/mapping.py
@@ -0,0 +1,356 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Specialized mapping functions."""
+
+from collections.abc import Callable, Sequence
+import functools
+from typing import Any
+import mindspore as ms
+
+
+Pytree = Any
+PytreeJaxArray = Any
+
+partial = functools.partial
+PROXY = object()
+
+
+def _maybe_slice(array, i, slice_size, axis):
+    """Slices `array` along `axis`; passes through when axis is PROXY."""
+    if axis is PROXY:
+        return array
+    else:
+        start = [0] * array.ndim
+        start[axis] = i
+        size = list(array.shape)
+        size[axis] = slice_size
+        return ms.ops.slice(array, start, size)
+
+
+def _maybe_get_size(array, axis):
+    """Returns the size of `array` along `axis`, or -1 when axis is PROXY."""
+    if axis is PROXY:
+        return -1
+    else:
+        return array.shape[axis]
+
+
+def tree_flatten(tree):
+    """Flattens nested lists/tuples/dicts into (leaves, structure)."""
+    if isinstance(tree, (list, tuple)):
+        flat, structure = [], []
+        for item in tree:
+            sub_flat, sub_struct = tree_flatten(item)
+            flat.extend(sub_flat)
+            structure.append(sub_struct)
+        return flat, structure
+    elif isinstance(tree, dict):
+        flat, structure = [], {}
+        for key, value in tree.items():
+            sub_flat, sub_struct = tree_flatten(value)
+            flat.extend(sub_flat)
+            structure[key] = sub_struct
+        return flat, structure
+    else:
+        return [tree], None
+
+
+def tree_unflatten(flat, structure):
+    """Inverse of tree_flatten; returns (tree, number of leaves consumed)."""
+    if isinstance(structure, list):
+        result, idx = [], 0
+        for sub_struct in structure:
+            sub_tree, consumed = tree_unflatten(flat[idx:], sub_struct)
+            result.append(sub_tree)
+            idx += consumed
+        return result, idx
+    elif isinstance(structure, dict):
+        result, idx = {}, 0
+        for key, sub_struct in structure.items():
+            sub_tree, consumed = tree_unflatten(flat[idx:], sub_struct)
+            result[key] = sub_tree
+            idx += consumed
+        return result, idx
+    else:
+        return flat[0], 1
+
+
+def _expand_axes(axes, values, name="sharded_apply"):
+    """Expands `axes` to one entry per leaf of `values` (None -> PROXY)."""
+    flat_values, values_tree_def = tree_flatten(values)
+    flat_axes = [PROXY if axes is None else axes for _ in flat_values]
+    expanded_axes, _ = tree_unflatten(flat_axes, values_tree_def)
+    return expanded_axes
+
+
+def tree_map(fn, *trees):
+    """Minimal replacement for jax.tree.map over nested lists and dicts."""
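+    # Illustrative example: leaves are combined elementwise across trees, e.g.
+    #   tree_map(lambda a, b: a + b, {'x': [1, 2]}, {'x': [10, 20]})
+    #   returns {'x': [11, 22]}.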
+ tree_types = set([type(tree) for tree in trees]) + # if len(tree_types) != 1: + # raise ValueError("All input trees must have the same structure") + tree_type = tree_types.pop() + if tree_type in (list,): + return tree_type(tree_map(fn, *subtrees) for subtrees in zip(*trees)) + elif tree_type is dict: + keys = trees[0].keys() + if not all(tree.keys() == keys for tree in trees): + raise ValueError("All input dictionaries must have the same keys") + return {key: tree_map(fn, *(tree[key] for tree in trees)) for key in keys} + else: + return fn(*trees) + + +def tree_leaves(tree): + "same as tree_map" + if isinstance(tree, (list, tuple)): + leaves = [] + for item in tree: + leaves.extend(tree_leaves(item)) + return leaves + elif isinstance(tree, dict): + leaves = [] + for key in tree: + leaves.extend(tree_leaves(tree[key])) + return leaves + else: + return [tree] + + +def eval_shape(fun, *args, **kwargs): + fake_inputs = [ms.ops.zeros(arg.shape, dtype=arg.dtype) if isinstance( + arg, ms.Tensor) else arg for arg in args] + output = fun(*fake_inputs, **kwargs) + return output + + +def sharded_apply( + fun: Callable[..., PytreeJaxArray], + shard_size: int | None = 1, + in_axes: int | Pytree = 0, + out_axes: int | Pytree = 0, + new_out_axes: bool = False, +) -> Callable[..., PytreeJaxArray]: + """Sharded apply. + + Applies `fun` over shards to axes, in a way similar to vmap, + but does so in shards of `shard_size`. Shards are stacked after. + This allows a smooth trade-off between + memory usage (as in a plain map) vs higher throughput (as in a vmap). + + Args: + fun: Function to apply smap transform to. + shard_size: Integer denoting shard size. + in_axes: Either integer or pytree describing which axis to map over for each + input to `fun`, None denotes broadcasting. + out_axes: Integer or pytree denoting to what axis in the output the mapped + over axis maps. + new_out_axes: Whether to stack outputs on new axes. This assumes that the + output sizes for each shard (including the possible remainder shard) are + the same. + + Returns: + Function with smap applied. + """ + docstr = ( + "Mapped version of {fun}. Takes similar arguments to {fun} " + "but with additional array axes over which {fun} is mapped." + ) + if new_out_axes: + raise NotImplementedError("New output axes not yet implemented.") + + # shard size None denotes no sharding + if shard_size is None: + return fun + + def mapped_fn(*args, **kwargs): + # Expand in axes and determine loop range. + in_axes_ = _expand_axes(ms.Tensor(in_axes), args) + + in_sizes = tree_map(_maybe_get_size, list(args), in_axes_) + in_size = max(tree_leaves(in_sizes)) + + num_extra_shards = (in_size - 1) // shard_size + + # Fix if necessary. 
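+        # (The final shard is smaller than `shard_size` whenever `in_size` is
+        # not an exact multiple of it; that remainder shard is computed and
+        # applied separately below.)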
+        last_shard_size = in_size % shard_size
+        last_shard_size = shard_size if last_shard_size == 0 else last_shard_size
+
+        def apply_fun_to_slice(slice_start, slice_size, args, in_axes_):
+            input_slice = tree_map(
+                lambda array, axis: _maybe_slice(
+                    array, slice_start, slice_size, axis
+                ),
+                args,
+                in_axes_,
+            )
+            return fun(input_slice, **kwargs)
+
+        remainder_shape_dtype = eval_shape(
+            lambda array, axis: apply_fun_to_slice(
+                0, last_shard_size, array, axis),
+            args, in_axes_
+        )
+
+        out_shapes = tree_map(lambda x: x.shape, remainder_shape_dtype)
+        out_dtypes = tree_map(lambda x: x.dtype, remainder_shape_dtype)
+        out_axes_ = _expand_axes(out_axes, out_shapes)
+
+        if num_extra_shards > 0:
+            regular_shard_shape_dtype = eval_shape(
+                lambda array, axis: apply_fun_to_slice(
+                    0, shard_size, array, axis),
+                args, in_axes_
+            )
+            shard_shapes = tree_map(
+                lambda x: x.shape, regular_shard_shape_dtype)
+
+            def make_output_shape(axis, shard_shape, remainder_shape):
+                axis = axis if isinstance(axis, int) else int(axis[0])
+                shard_shape = tuple(shard_shape)
+                remainder_shape = tuple(remainder_shape)
+                return ms.ops.stack(
+                    shard_shape[:axis]
+                    + (shard_shape[axis] * num_extra_shards
+                       + remainder_shape[axis],)
+                    + shard_shape[axis + 1:]
+                )
+
+            out_shapes = tree_map(
+                make_output_shape, out_axes_[0], ms.Tensor(
+                    shard_shapes), ms.Tensor(out_shapes)
+            )
+
+        # Calls dynamic update slice with a different argument order.
+        # This is here since tree_map only works with positional arguments.
+        def dynamic_update_slice_in_dim(array, update, axis, i):
+            start = [0] * array.ndim
+            start[axis] = int(i)
+            size = list(array.shape)
+            size[axis] = update.shape[axis]
+            end = [x + y for x, y in zip(start, size)]
+            # Write `update` into the target window, slicing every axis.
+            index = tuple(slice(s, e) for s, e in zip(start, end))
+            array[index] = update
+            return array
+
+        def compute_shard(outputs, slice_start, slice_size):
+            slice_out = apply_fun_to_slice(
+                int(slice_start), slice_size, args, in_axes_)
+            update_slice = partial(dynamic_update_slice_in_dim, i=slice_start)
+            return tree_map(update_slice, outputs, slice_out, out_axes_[0])
+
+        def scan_iteration(outputs, i):
+            new_outputs = compute_shard(outputs, i, shard_size)
+            return new_outputs
+
+        slice_starts = ms.ops.arange(0, in_size - shard_size + 1, shard_size)
+
+        def allocate_buffer(dtype, shape):
+            return ms.ops.zeros(shape, dtype=dtype)
+
+        outputs = tree_map(allocate_buffer, out_dtypes, out_shapes)
+
+        if slice_starts.shape[0] > 0:
+            for slice_start in slice_starts:
+                outputs = scan_iteration(outputs, slice_start)
+
+        if last_shard_size != shard_size:
+            remainder_start = in_size - last_shard_size
+            outputs = compute_shard(outputs, remainder_start, last_shard_size)
+
+        return outputs
+
+    return mapped_fn
+
+
+def sharded_map(fun, shard_size=1, in_axes=0, out_axes=0):
+    vmapped_fun = ms.vmap(fun, int(in_axes), int(out_axes))
+    return sharded_apply(vmapped_fun, shard_size, in_axes, out_axes)
+
+
+def reshape_partitioned_inputs(batched_args, partitioned_dim, subbatch_size):
+    """Reshapes so subbatching doesn't happen on the partitioned dim."""
+    subbatched_args = []
+    for arg in batched_args:
+        shape = arg.shape
+        new_shape = (
+            shape[:partitioned_dim]
+            + (subbatch_size, shape[partitioned_dim] // subbatch_size)
+            + shape[partitioned_dim + 1:]
+        )
+        subbatched_args.append(arg.reshape(new_shape))
+    return subbatched_args
+
+
+def reshape_partitioned_output(output, output_subbatch_dim):
+    """Reshapes outputs as if reshape_partitioned_inputs were never applied."""
+    out_shape = (
+        output.shape[: output_subbatch_dim - 1]
+        + (-1,)
+        + output.shape[output_subbatch_dim + 1:]
+    )
+    return output.reshape(out_shape)
+
+
+def inference_subbatch(module, subbatch_size, batched_args,
+                       nonbatched_args, input_subbatch_dim=0,
+                       output_subbatch_dim=None,
+                       input_subbatch_dim_is_partitioned=False):
+    """Run through subbatches (like batch apply but with split and concat)."""
+    assert batched_args
+    if output_subbatch_dim is None:
+        output_subbatch_dim = input_subbatch_dim
+    if input_subbatch_dim_is_partitioned:
+        # Subbatching along the partitioned axis would induce an all-gather
+        # that undoes the partitioning. So instead we reshape such that
+        # [..., partitioned_input_size, ...] becomes [..., subbatch_size,
+        # partitioned_input_size // subbatch_size, ...] and then actually
+        # subbatch along the partitioned_input_size // subbatch_size axis in
+        # slices of size 1. Partitioning is then preserved on the partitioned
+        # axis, except that dimension is now of size subbatch_size instead of
+        # partitioned_input_size. Note that the module itself still sees
+        # inputs of size [..., subbatch_size, ...], just as it would if this
+        # reshaping were not applied.
+        batched_args = reshape_partitioned_inputs(
+            batched_args, input_subbatch_dim, subbatch_size
+        )
+        input_subbatch_dim += 1
+        output_subbatch_dim += 1
+        subbatch_size = 1
+
+    def run_module(*batched_args):
+        if input_subbatch_dim_is_partitioned:
+            # Squeeze off the singleton dimension (otherwise the module would
+            # see [..., subbatch_size, 1, ...]).
+            batched_args = [b.squeeze(axis=input_subbatch_dim)
+                            for b in batched_args]
+        args = list(batched_args)[0] + list(nonbatched_args)
+        res = module(*args)
+        if input_subbatch_dim_is_partitioned:
+            # Add back the singleton dimension so the outputs are stacked on
+            # the axis we are actually subbatching over (i.e. stacked back to
+            # [..., subbatch_size, partitioned_input_size // subbatch_size,
+            # ...]), rather than on the partitioned axis, which would again
+            # induce an all-gather that breaks partitioning.
+            res = ms.ops.expand_dims(res, axis=output_subbatch_dim)
+        return res
+
+    sharded_module = sharded_apply(
+        run_module,
+        shard_size=subbatch_size,
+        in_axes=input_subbatch_dim,
+        out_axes=output_subbatch_dim,
+    )
+    output = sharded_module(*batched_args)
+    if input_subbatch_dim_is_partitioned:
+        # The output has the same shape as the inputs [..., subbatch_size,
+        # partitioned_input_size // subbatch_size, ...]. Reshape to
+        # [..., partitioned_input_size, ...] as if the reshaping due to
+        # partitioning had never been applied.
+        output = reshape_partitioned_output(output, output_subbatch_dim)
+
+    return output
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/utils.py
new file mode 100644
index 000000000..e2b0815c8
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/components/utils.py
@@ -0,0 +1,63 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+from collections import abc
+import contextlib
+import numbers
+
+import numpy as np
+import mindspore as ms
+
+VALID_DTYPES = [np.float32, np.float64, np.int8, np.int32, np.int64, bool]
+
+
+def remove_invalidly_typed_feats(batch):
+    """Remove features of types we don't want to send to the TPU, e.g. strings."""
+    return {
+        k: v
+        for k, v in batch.items()
+        if hasattr(v, 'dtype') and v.dtype in VALID_DTYPES
+    }
+
+
+def mask_mean(mask, value, axis=None, keepdims=False, eps=1e-10):
+    """Masked mean."""
+
+    mask_shape = mask.shape
+    value_shape = value.shape
+
+    assert len(mask_shape) == len(
+        value_shape
+    ), 'Shapes are not compatible, shapes: {}, {}'.format(mask_shape, value_shape)
+
+    if isinstance(axis, numbers.Integral):
+        axis = [axis]
+    elif axis is None:
+        axis = list(range(len(mask_shape)))
+    assert isinstance(
+        axis, abc.Iterable
+    ), 'axis needs to be either an iterable, integer or "None"'
+
+    broadcast_factor = 1.0
+    for axis_ in axis:
+        value_size = value_shape[axis_]
+        mask_size = mask_shape[axis_]
+        if mask_size == 1:
+            broadcast_factor *= value_size
+        else:
+            error = f'Shapes are not compatible, shapes: {mask_shape}, {value_shape}'
+            assert mask_size == value_size, error
+
+    return ms.ops.sum(mask * value, keepdim=keepdims, dim=axis) / (
+        ms.ops.maximum(
+            ms.ops.sum(mask, keepdim=keepdims, dim=axis) *
+            broadcast_factor, eps
+        )
+    )
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidence_types.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidence_types.py
new file mode 100644
index 000000000..a956b3c13
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidence_types.py
@@ -0,0 +1,306 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Confidence categories for predictions."""
+
+import dataclasses
+import enum
+import json
+from typing import Any, Self
+
+from absl import logging
+import numpy as np
+import mindspore as ms
+from alphafold3.model.components import base_model
+from alphafold3.model.components.mapping import tree_map
+
+
+class StructureConfidenceFullEncoder(json.JSONEncoder):
+    """JSON encoder for serializing confidence types."""
+
+    def __init__(self, **kwargs):
+        super().__init__(**(kwargs | dict(separators=(',', ':'))))
+
+    def encode(self, o: 'StructureConfidenceFull'):
+        # Cast to np.float64 before rounding, since casting to Python float
+        # will cast to a 64 bit float, potentially undoing np.float32 rounding.
+ atom_plddts = np.round( + np.clip(np.asarray(o.atom_plddts, dtype=np.float64), 0.0, 99.99), 2 + ).astype(float) + contact_probs = np.round( + np.clip(np.asarray(o.contact_probs, dtype=np.float64), 0.0, 1.0), 2 + ).astype(float) + pae = np.round( + np.clip(np.asarray(o.pae, dtype=np.float64), 0.0, 99.9), 1 + ).astype(float) + return """\ +{ + "atom_chain_ids": %s, + "atom_plddts": %s, + "contact_probs": %s, + "pae": %s, + "token_chain_ids": %s, + "token_res_ids": %s +}""" % ( + super().encode(o.atom_chain_ids), + super().encode(list(atom_plddts)).replace('NaN', 'null'), + super().encode([list(x) for x in contact_probs]).replace('NaN', 'null'), + super().encode([list(x) for x in pae]).replace('NaN', 'null'), + super().encode(o.token_chain_ids), + super().encode(o.token_res_ids), + ) + + +def _dump_json(data: Any, indent: int | None = None) -> str: + """Dumps a json string with JSON compatible NaN representation.""" + json_str = json.dumps( + data, + sort_keys=True, + indent=indent, + separators=(',', ': '), + ) + return json_str.replace('NaN', 'null') + + +@enum.unique +class ConfidenceCategory(enum.Enum): + """Confidence categories for AlphaFold predictions.""" + + HIGH = 0 + MEDIUM = 1 + LOW = 2 + DISORDERED = 3 + + @classmethod + def from_char(cls, char: str) -> Self: + match char: + case 'H': + return cls.HIGH + case 'M': + return cls.MEDIUM + case 'L': + return cls.LOW + case 'D': + return cls.DISORDERED + case _: + raise ValueError( + f'Unknown character. Expected one of H, M, L or D; got: {char}' + ) + + def to_char(self) -> str: + match self: + case self.HIGH: + return 'H' + case self.MEDIUM: + return 'M' + case self.LOW: + return 'L' + case self.DISORDERED: + return 'D' + + @classmethod + def from_confidence_score(cls, confidence: float) -> Self: + if 90 <= confidence <= 100: + return cls.HIGH + if 70 <= confidence < 90: + return cls.MEDIUM + if 50 <= confidence < 70: + return cls.LOW + if 0 <= confidence < 50: + return cls.DISORDERED + raise ValueError(f'Confidence score out of range [0, 100]: {confidence}') + + +@dataclasses.dataclass() +class AtomConfidence: + """Dataclass for 1D per-atom confidences from AlphaFold.""" + + chain_id: list[str] + atom_number: list[int] + confidence: list[float] + confidence_category: list[ConfidenceCategory] + + def __post_init__(self): + num_res = len(self.atom_number) + if not all( + len(v) == num_res + for v in [self.chain_id, self.confidence, self.confidence_category] + ): + raise ValueError('All confidence fields must have the same length.') + + @classmethod + def from_inference_result( + cls, inference_result: base_model.InferenceResult + ) -> Self: + """Instantiates an AtomConfidence from a structure. + + Args: + inference_result: Inference result from AlphaFold. + + Returns: + Scores in AtomConfidence dataclass. 
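+
+        For example, an atom whose B-factor column stores a pLDDT of 91.3 is
+        assigned confidence 91.3 and category HIGH (90 <= score <= 100).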
+ """ + struc = inference_result.predicted_structure + as_dict = { + 'chain_id': [], + 'atom_number': [], + 'confidence': [], + 'confidence_category': [], + } + for atom_number, atom in enumerate(struc.iter_atoms()): + this_confidence = float(struc.atom_b_factor[atom_number]) + as_dict['chain_id'].append(atom['chain_id']) + as_dict['atom_number'].append(atom_number) + as_dict['confidence'].append(round(this_confidence, 2)) + as_dict['confidence_category'].append( + ConfidenceCategory.from_confidence_score(this_confidence) + ) + return cls(**as_dict) + + @classmethod + def from_json(cls, json_string: str) -> Self: + """Instantiates a AtomConfidence from a json string.""" + input_dict = json.loads(json_string) + input_dict['confidence_category'] = [ + ConfidenceCategory.from_char(k) + for k in input_dict['confidence_category'] + ] + return cls(**input_dict) + + def to_json(self) -> str: + output = dataclasses.asdict(self) + output['confidence_category'] = [ + k.to_char() for k in output['confidence_category'] + ] + output['atom_number'] = [int(k) for k in output['atom_number']] + return _dump_json(output) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class StructureConfidenceSummary: + """Dataclass for the summary of structure scores from AlphaFold. + + Attributes: + ptm: Predicted TM global score. + iptm: Interface predicted TM global score. + ranking_score: Ranking score extracted from CIF metadata. + fraction_disordered: Fraction disordered, measured with RASA. + has_clash: Has significant clashing. + chain_pair_pae_min: [num_chains, num_chains] Minimum cross chain PAE. + chain_pair_iptm: [num_chains, num_chains] Chain pair ipTM. + chain_ptm: [num_chains] Chain pTM. + chain_iptm: [num_chains] Mean cross chain ipTM for a chain. + """ + + ptm: float + iptm: float + ranking_score: float + fraction_disordered: float + has_clash: float + chain_pair_pae_min: np.ndarray + chain_pair_iptm: np.ndarray + chain_ptm: np.ndarray + chain_iptm: np.ndarray + + @classmethod + def from_inference_result( + cls, inference_result: base_model.InferenceResult + ) -> Self: + """Returns a new instance based on a given inference result.""" + return cls( + ptm=float(inference_result.metadata['ptm']), + iptm=float(inference_result.metadata['iptm']), + ranking_score=float(inference_result.metadata['ranking_score']), + fraction_disordered=float( + inference_result.metadata['fraction_disordered'] + ), + has_clash=float(inference_result.metadata['has_clash']), + chain_pair_pae_min=inference_result.metadata['chain_pair_pae_min'], + chain_pair_iptm=inference_result.metadata['chain_pair_iptm'], + chain_ptm=inference_result.metadata['iptm_ichain'], + chain_iptm=inference_result.metadata['iptm_xchain'], + ) + + @classmethod + def from_json(cls, json_string: str) -> Self: + """Returns a new instance from a given json string.""" + return cls(**json.loads(json_string)) + + def to_json(self) -> str: + def convert(data): + if isinstance(data, np.ndarray): + # Cast to np.float64 before rounding, since casting to Python float will + # cast to a 64 bit float, potentially undoing np.float32 rounding. 
+ rounded_data = np.round(data.astype(np.float64), decimals=2).tolist() + else: + rounded_data = np.round(data, decimals=2) + return rounded_data + + return _dump_json(tree_map(convert, dataclasses.asdict(self)), indent=1) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class StructureConfidenceFull: + """Dataclass for full structure data from AlphaFold.""" + + pae: np.ndarray + token_chain_ids: list[str] + token_res_ids: list[int] + atom_plddts: list[float] + atom_chain_ids: list[str] + contact_probs: np.ndarray # [num_tokens, num_tokens] + + @classmethod + def from_inference_result( + cls, inference_result: base_model.InferenceResult + ) -> Self: + """Returns a new instance based on a given inference result.""" + + pae = inference_result.numerical_data['full_pae'] + if isinstance(pae, ms.Tensor): + pae = pae.asnumpy() + if not isinstance(pae, np.ndarray): + logging.info('%s', type(pae)) + raise TypeError('pae should be a numpy array.') + + contact_probs = inference_result.numerical_data['contact_probs'] + if isinstance(contact_probs, ms.Tensor): + contact_probs = contact_probs.asnumpy() + if not isinstance(contact_probs, np.ndarray): + logging.info('%s', type(contact_probs)) + raise TypeError('contact_probs should be a numpy array.') + + struc = inference_result.predicted_structure + chain_ids = struc.chain_id.tolist() + atom_plddts = struc.atom_b_factor.tolist() + token_chain_ids = [ + str(token_id) + for token_id in inference_result.metadata['token_chain_ids'] + ] + token_res_ids = [ + int(token_id) for token_id in inference_result.metadata['token_res_ids'] + ] + return cls( + pae=pae, + token_chain_ids=token_chain_ids, + token_res_ids=token_res_ids, + atom_plddts=atom_plddts, + atom_chain_ids=chain_ids, + contact_probs=contact_probs, + ) + + @classmethod + def from_json(cls, json_string: str) -> Self: + """Returns a new instance from a given json string.""" + return cls(**json.loads(json_string)) + + def to_json(self) -> str: + """Converts StructureConfidenceFull to json string.""" + return json.dumps(self, cls=StructureConfidenceFullEncoder) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidences.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidences.py new file mode 100644 index 000000000..02bf0478e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/confidences.py @@ -0,0 +1,664 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Functions for extracting and processing confidences from model outputs.""" +import warnings +import numpy as np +from absl import logging +from alphafold3 import structure +from alphafold3.constants import residue_names +from alphafold3.cpp import mkdssp +from scipy import spatial + + +# From Sander & Rost 1994 https://doi.org/10.1002/prot.340200303 +MAX_ACCESSIBLE_SURFACE_AREA = { + 'ALA': 106.0, + 'ARG': 248.0, + 'ASN': 157.0, + 'ASP': 163.0, + 'CYS': 135.0, + 'GLN': 198.0, + 'GLU': 194.0, + 'GLY': 84.0, + 'HIS': 184.0, + 'ILE': 169.0, + 'LEU': 164.0, + 'LYS': 205.0, + 'MET': 188.0, + 'PHE': 197.0, + 'PRO': 136.0, + 'SER': 130.0, + 'THR': 142.0, + 'TRP': 227.0, + 'TYR': 222.0, + 'VAL': 142.0, +} + +# Weights for ranking confidence. +_IPTM_WEIGHT = 0.8 +_FRACTION_DISORDERED_WEIGHT = 0.5 +_CLASH_PENALIZATION_WEIGHT = 100.0 + + +def windowed_solvent_accessible_area(cif: str, window: int = 25) -> np.ndarray: + """Implementation of AlphaFold_RSA. + + AlphaFold_RSA defined in + https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9601767/. + + Args: + cif: raw cif string. + window: The window over which to average accessible surface area + + Returns: + An array of size num_res that predicts disorder by using windowed solvent + accessible surface area. + """ + result = mkdssp.get_dssp(cif, calculate_surface_accessibility=True) + parse_row = False + rasa = [] + for row in result.splitlines(): + if parse_row: + aa = row[13:14] + if aa == '!': + continue + aa3 = residue_names.PROTEIN_COMMON_ONE_TO_THREE.get(aa, 'ALA') + max_acc = MAX_ACCESSIBLE_SURFACE_AREA[aa3] + acc = int(row[34:38]) + norm_acc = acc / max_acc + if norm_acc > 1.0: + norm_acc = 1.0 + rasa.append(norm_acc) + if row.startswith(' # RESIDUE'): + parse_row = True + + half_w = (window - 1) // 2 + pad_rasa = np.pad(rasa, (half_w, half_w), 'reflect') + rasa = np.convolve(pad_rasa, np.ones(window), 'valid') / window + return rasa + + +def fraction_disordered( + struc: structure.Structure, rasa_disorder_cutoff: float = 0.581 +) -> float: + """Compute fraction of protein residues that are disordered. + + Args: + struc: A structure to compute rASA metrics on. + rasa_disorder_cutoff: The threshold at which residues are considered + disordered. Default value taken from + https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9601767/. + + Returns: + The fraction of protein residues that are disordered + (rasa > rasa_disorder_cutoff). + """ + struc = struc.filter_to_entity_type(protein=True) + rasa = [] + seq_rasa = {} + for chain_id, chain_seq in struc.chain_single_letter_sequence().items(): + if chain_seq in seq_rasa: + # We assume that identical sequences have approximately similar rasa + # values to speed up the computation. + rasa.extend(seq_rasa[chain_seq]) + continue + chain_struc = struc.filter(chain_id=chain_id) + try: + rasa_per_residue = windowed_solvent_accessible_area( + chain_struc.to_mmcif() + ) + seq_rasa[chain_seq] = rasa_per_residue + rasa.extend(rasa_per_residue) + except (ValueError, RuntimeError): + logging.warning('%s: rasa calculation failed', struc.name) + + if not rasa: + return 0.0 + return np.mean(np.array(rasa) > rasa_disorder_cutoff) + + +def has_clash( + struc: structure.Structure, + cutoff_radius: float = 1.1, + min_clashes_for_overlap: int = 100, + min_fraction_for_overlap: float = 0.5, +) -> bool: + """Determine whether the structure has at least one clashing chain. 
+ + A clashing chain is defined as having greater than 100 polymer atoms within + 1.1A of another polymer atom, or having more than 50% of the chain with + clashing atoms. + + Args: + struc: A structure to get clash metrics for. + cutoff_radius: atom distances under this threshold are considered a clash. + min_clashes_for_overlap: The minimum number of atom-atom clashes for a chain + to be considered overlapping. + min_fraction_for_overlap: The minimum fraction of atoms within a chain that + are clashing for the chain to be considered overlapping. + + Returns: + True if the structure has at least one clashing chain. + """ + struc = struc.filter_to_entity_type(protein=True, rna=True, dna=True) + if not struc.chains: + return False + coords = struc.coords + coord_kdtree = spatial.cKDTree(coords) + clashes_per_atom = coord_kdtree.query_ball_point( + coords, p=2.0, r=cutoff_radius + ) + per_atom_has_clash = np.zeros(len(coords), dtype=np.int32) + for atom_idx, clashing_indices in enumerate(clashes_per_atom): + for clashing_idx in clashing_indices: + if np.abs(struc.res_id[atom_idx] - struc.res_id[clashing_idx]) > 1 or ( + struc.chain_id[atom_idx] != struc.chain_id[clashing_idx] + ): + per_atom_has_clash[atom_idx] = True + break + for chain_id in struc.chains: + mask = struc.chain_id == chain_id + num_atoms = np.sum(mask) + if num_atoms == 0: + continue + num_clashes = np.sum(per_atom_has_clash * mask) + frac_clashes = num_clashes / num_atoms + if ( + num_clashes > min_clashes_for_overlap + or frac_clashes > min_fraction_for_overlap + ): + return True + return False + + +def get_ranking_score( + ptm: float, iptm: float, fraction_disordered_: float, has_clash_: bool +) -> float: + # ipTM is NaN for single chain structures. Use pTM for such cases. + if np.isnan(iptm): + ptm_iptm_average = ptm + else: + ptm_iptm_average = _IPTM_WEIGHT * iptm + (1.0 - _IPTM_WEIGHT) * ptm + return ( + ptm_iptm_average + + _FRACTION_DISORDERED_WEIGHT * fraction_disordered_ + - _CLASH_PENALIZATION_WEIGHT * has_clash_ + ) + + +def rank_metric( + full_pde: np.ndarray, contact_probs: np.ndarray +) -> np.ndarray: + """Compute the metric that will be used to rank predictions, higher is better. + + Args: + full_pde: A [num_samples, num_tokens,num_tokens] matrix of predicted + distance errors between pairs of tokens. + contact_probs: A [num_tokens, num_tokens] matrix consisting of the + probability of contact (<8A) that is returned from the distogram head. + + Returns: + A scalar that can be used to rank (higher is better). + """ + if not isinstance(full_pde, type(contact_probs)): + raise ValueError( + 'full_pde and contact_probs must be of the same type.') + + if isinstance(full_pde, np.ndarray): + sum_fn = np.sum + else: + raise ValueError('full_pde must be a numpy array or a jax array.') + # It was found that taking the contact_map weighted average was better than + # just the predicted distance error on its own. + return -sum_fn(full_pde * contact_probs[None, :, :], axis=(-2, -1)) / ( + sum_fn(contact_probs) + 1e-6 + ) + + +def weighted_mean(mask, value, axis): + return np.mean(mask * value, axis=axis) / (1e-8 + np.mean(mask, axis=axis)) + + +def pde_single( + num_tokens: int, + asym_ids: np.ndarray, + full_pde: np.ndarray, + contact_probs: np.ndarray, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute 1D PDE summaries. + + Args: + num_tokens: The number of tokens (not including padding). + asym_ids: The asym_ids (array of shape num_tokens). 
+    full_pde: A [num_samples, num_tokens, num_tokens] matrix of predicted
+      distance errors.
+    contact_probs: A [num_tokens, num_tokens] matrix consisting of the
+      probability of contact (<8A) that is returned from the distogram head.
+
+  Returns:
+    A tuple (ichain, xchain, full_chain) where:
+    `ichain` is a [num_samples, num_chains] matrix where the
+    value assigned to each chain is an average of the full PDE matrix over all
+    its within-chain interactions, weighted by `contact_probs`.
+    `xchain` is a [num_samples, num_chains] matrix where the
+    value assigned to each chain is an average of the full PDE matrix over all
+    its cross-chain interactions, weighted by `contact_probs`.
+    `full_chain` is a [num_samples, num_tokens] matrix where the
+    value assigned to each token is an average of its PDE against all tokens,
+    weighted by `contact_probs`.
+  """
+
+  full_pde = full_pde[:, :num_tokens, :num_tokens]
+  contact_probs = contact_probs[:num_tokens, :num_tokens]
+  asym_ids = asym_ids[:num_tokens]
+  unique_asym_ids = np.unique(asym_ids)
+  num_chains = len(unique_asym_ids)
+  num_samples = full_pde.shape[0]
+
+  asym_ids = asym_ids[None]
+  contact_probs = contact_probs[None]
+
+  ichain = np.zeros((num_samples, num_chains))
+  xchain = np.zeros((num_samples, num_chains))
+
+  for idx, asym_id in enumerate(unique_asym_ids):
+    my_asym_id = asym_ids == asym_id
+    imask = my_asym_id[:, :, None] * my_asym_id[:, None, :]
+    xmask = my_asym_id[:, :, None] * ~my_asym_id[:, None, :]
+    imask = imask * contact_probs
+    xmask = xmask * contact_probs
+    ichain[:, idx] = weighted_mean(mask=imask, value=full_pde, axis=(-2, -1))
+    xchain[:, idx] = weighted_mean(mask=xmask, value=full_pde, axis=(-2, -1))
+
+  full_chain = weighted_mean(mask=contact_probs, value=full_pde, axis=(-1,))
+
+  return ichain, xchain, full_chain
+
+
+def chain_pair_pde(
+    num_tokens: int, asym_ids: np.ndarray, full_pde: np.ndarray
+) -> tuple[np.ndarray, np.ndarray]:
+  """Compute predicted distance errors for all pairs of chains.
+
+  Args:
+    num_tokens: The number of tokens (not including padding).
+    asym_ids: The asym_ids (array of shape num_tokens).
+    full_pde: A [num_samples, num_tokens, num_tokens] matrix of predicted
+      distance errors.
+
+  Returns:
+    chain_pair_pred_err_mean - a [num_samples, num_chains, num_chains] matrix
+      with average per chain-pair predicted distance error.
+    chain_pair_pred_err_min - a [num_samples, num_chains, num_chains] matrix
+      with min per chain-pair predicted distance error.
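+
+  Example (a minimal illustrative sketch; toy shapes and values, hypothetical):
+
+    import numpy as np
+    mean_err, min_err = chain_pair_pde(
+        num_tokens=3,
+        asym_ids=np.array([0, 0, 1]),       # 2 chains
+        full_pde=np.ones((2, 3, 3)),        # 2 samples, 3 tokens
+    )
+    # mean_err.shape == min_err.shape == (2, 2, 2)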
+ """ + full_pde = full_pde[:, :num_tokens, :num_tokens] + asym_ids = asym_ids[:num_tokens] + unique_asym_ids = np.unique(asym_ids) + num_chains = len(unique_asym_ids) + num_samples = full_pde.shape[0] + chain_pair_pred_err_mean = np.zeros((num_samples, num_chains, num_chains)) + chain_pair_pred_err_min = np.zeros((num_samples, num_chains, num_chains)) + + for idx1, asym_id_1 in enumerate(unique_asym_ids): + subset = full_pde[:, asym_ids == asym_id_1, :] + for idx2, asym_id_2 in enumerate(unique_asym_ids): + subsubset = subset[:, :, asym_ids == asym_id_2] + chain_pair_pred_err_mean[:, idx1, idx2] = np.mean( + subsubset, axis=(1, 2)) + chain_pair_pred_err_min[:, idx1, idx2] = np.min( + subsubset, axis=(1, 2)) + return chain_pair_pred_err_mean, chain_pair_pred_err_min + + +def weighted_nanmean( + value: np.ndarray, mask: np.ndarray, axis: int +) -> np.ndarray: + """Nan-mean with weighting -- empty slices return NaN.""" + assert mask.shape == value.shape + assert not np.isnan(mask).all() + + nan_idxs = np.where(np.isnan(value)) + # Need to NaN the mask to get the correct denominator weighting. + mask_with_nan = mask.copy() + mask_with_nan[nan_idxs] = np.nan + with warnings.catch_warnings(): + # Mean of empty slice is ok and should return a NaN. + warnings.filterwarnings(action='ignore', message='Mean of empty slice') + return np.nanmean(value * mask_with_nan, axis=axis) / np.nanmean( + mask_with_nan, axis=axis + ) + + +def chain_pair_pae( + *, + num_tokens: int, + asym_ids: np.ndarray, + full_pae: np.ndarray, + mask: np.ndarray | None = None, + contact_probs: np.ndarray | None = None, +) -> tuple[np.ndarray, np.ndarray, np.ndarray]: + """Compute predicted errors for all pairs of chains. + + Args: + num_tokens: The number of tokens (not including padding). + asym_ids: The asym_ids (array of shape num_tokens). + full_pae: A [num_samples, num_tokens, num_tokens] matrix of predicted + errors. + mask: A [num_tokens, num_tokens] mask matrix. + contact_probs: A [num_tokens, num_tokens] matrix consisting of the + probability of contact (<8A) that is returned from the distogram head. + + Returns: + chain_pair_pred_err_mean - a [num_chains, num_chains] matrix with average + per chain-pair predicted error. 
+ """ + if mask is None: + mask = np.ones(shape=full_pae.shape[1:], dtype=bool) + if contact_probs is None: + contact_probs = np.ones(shape=full_pae.shape[1:], dtype=float) + assert mask.shape == full_pae.shape[1:] + + full_pae = full_pae[:, :num_tokens, :num_tokens] + mask = mask[:num_tokens, :num_tokens] + asym_ids = asym_ids[:num_tokens] + contact_probs = contact_probs[:num_tokens, :num_tokens] + unique_asym_ids = np.unique(asym_ids) + num_chains = len(unique_asym_ids) + num_samples = full_pae.shape[0] + chain_pair_pred_err_mean = np.zeros((num_samples, num_chains, num_chains)) + chain_pair_pred_err_min = np.zeros((num_samples, num_chains, num_chains)) + + for idx1, asym_id_1 in enumerate(unique_asym_ids): + subset = full_pae[:, asym_ids == asym_id_1, :] + subset_mask = mask[asym_ids == asym_id_1, :] + subset_contact_probs = contact_probs[asym_ids == asym_id_1, :] + for idx2, asym_id_2 in enumerate(unique_asym_ids): + subsubset = subset[:, :, asym_ids == asym_id_2] + subsubset_mask = subset_mask[:, asym_ids == asym_id_2] + subsubset_contact_probs = subset_contact_probs[:, + asym_ids == asym_id_2] + (flat_mask_idxs,) = np.where(subsubset_mask.flatten() > 0) + flat_subsubset = subsubset.reshape([num_samples, -1]) + flat_contact_probs = subsubset_contact_probs.flatten() + # A ligand chain will have no valid frames if it contains fewer than + # three non-colinear atoms (e.g. a sodium ion). + if not flat_mask_idxs.size: + chain_pair_pred_err_mean[:, idx1, idx2] = np.nan + chain_pair_pred_err_min[:, idx1, idx2] = np.nan + else: + chain_pair_pred_err_min[:, idx1, idx2] = np.min( + flat_subsubset[:, flat_mask_idxs], axis=1 + ) + chain_pair_pred_err_mean[:, idx1, idx2] = weighted_mean( + mask=flat_contact_probs[flat_mask_idxs], + value=flat_subsubset[:, flat_mask_idxs], + axis=-1, + ) + return chain_pair_pred_err_mean, chain_pair_pred_err_min, unique_asym_ids + + +def reduce_chain_pair( + *, + chain_pair_met: np.ndarray, + num_chain_tokens: np.ndarray, + agg_over_col: bool, + agg_type: str, + weight_method: str, +) -> tuple[np.ndarray, np.ndarray]: + """Compute 1D summaries from a chain-pair summary. + + Args: + chain_pair_met: A [num_samples, num_chains, num_chains] aggregate matrix. + num_chain_tokens: A [num_chains] array of number of tokens for each chain. + Used for 'per_token' weighting. + agg_over_col: Whether to aggregate the PAE over rows (i.e. average error + when aligned to me) or columns (i.e. my average error when aligned to all + others.) + agg_type: The type of aggregation to use, 'mean' or 'min'. + weight_method: The method to use for weighting the PAE, 'per_token' or + 'per_chain'. + + Returns: + A tuple (ichain, xchain) where: + `ichain` is a [num_samples, num_chains] matrix where the + value assigned to each chain is an average of the full PAE matrix over all + its within-chain interactions, weighted by `contact_probs`. + `xchain` is a [num_samples, num_chains] matrix where the + value assigned to each chain is an average of the full PAE matrix over all + its cross-chain interactions, weighted by `contact_probs`. 
+ """ + num_samples, num_chains, _ = chain_pair_met.shape + + ichain = chain_pair_met.diagonal(axis1=-2, axis2=-1) + + if weight_method == 'per_chain': + chain_weight = np.ones((num_chains,), dtype=float) + elif weight_method == 'per_token': + chain_weight = num_chain_tokens + else: + raise ValueError(f'Unknown weight method: {weight_method}') + + if agg_over_col: + agg_axis = -1 + else: + agg_axis = -2 + + if agg_type == 'mean': + weight = np.ones((num_samples, num_chains, num_chains), dtype=float) + weight -= np.eye(num_chains, dtype=float) + weight *= chain_weight[None] * chain_weight[:, None] + xchain = weighted_nanmean(chain_pair_met, mask=weight, axis=agg_axis) + elif agg_type == 'min': + is_self = np.eye(num_chains) + with warnings.catch_warnings(): + # Min over empty slice is ok and should return a NaN. + warnings.filterwarnings( + 'ignore', message='All-NaN slice encountered') + xchain = np.nanmin(chain_pair_met + 1e8 * is_self, axis=agg_axis) + else: + raise ValueError(f'Unknown aggregation method: {agg_type}') + + return ichain, xchain + + +def pae_metrics( + num_tokens: int, + asym_ids: np.ndarray, + full_pae: np.ndarray, + mask: np.ndarray, + contact_probs: np.ndarray, + tm_adjusted_pae: np.ndarray, +): + """PAE aggregate metrics.""" + assert mask.shape == full_pae.shape[1:] + assert contact_probs.shape == full_pae.shape[1:] + + chain_pair_contact_weighted, _, unique_asym_ids = chain_pair_pae( + num_tokens=num_tokens, + asym_ids=asym_ids, + full_pae=full_pae, + mask=mask, + contact_probs=contact_probs, + ) + + ret = {} + ret['chain_pair_pae_mean'], ret['chain_pair_pae_min'], _ = chain_pair_pae( + num_tokens=num_tokens, + asym_ids=asym_ids, + full_pae=full_pae, + mask=mask, + ) + chain_pair_iptm = np.stack( + [ + chain_pairwise_predicted_tm_scores( + tm_adjusted_pae=sample_tm_adjusted_pae[:num_tokens], + asym_id=asym_ids[:num_tokens], + pair_mask=mask[:num_tokens, :num_tokens], + ) + for sample_tm_adjusted_pae in tm_adjusted_pae + ], + axis=0, + ) + + num_chain_tokens = np.array( + [sum(asym_ids == asym_id) for asym_id in unique_asym_ids] + ) + + def reduce_chain_pair_fn(chain_pair: np.ndarray): + def inner(agg_over_col): + ichain_pae, xchain_pae = reduce_chain_pair( + num_chain_tokens=num_chain_tokens, + chain_pair_met=chain_pair, + agg_over_col=agg_over_col, + agg_type='mean', + weight_method='per_chain', + ) + return ichain_pae, xchain_pae + + ichain, xchain_row_agg = inner(False) + _, xchain_col_agg = inner(True) + with warnings.catch_warnings(): + # Mean of empty slice is ok and should return a NaN. 
+ warnings.filterwarnings( + action='ignore', message='Mean of empty slice') + xchain = np.nanmean( + np.stack([xchain_row_agg, xchain_col_agg], axis=0), axis=0 + ) + return ichain, xchain + + pae_ichain, pae_xchain = reduce_chain_pair_fn(chain_pair_contact_weighted) + iptm_ichain, iptm_xchain = reduce_chain_pair_fn(chain_pair_iptm) + + ret.update({ + 'chain_pair_iptm': chain_pair_iptm, + 'iptm_ichain': iptm_ichain, + 'iptm_xchain': iptm_xchain, + 'pae_ichain': pae_ichain, + 'pae_xchain': pae_xchain, + }) + + return ret + + +def get_iptm_xchain(chain_pair_iptm: np.ndarray) -> np.ndarray: + """Cross chain aggregate ipTM.""" + num_samples, num_chains, _ = chain_pair_iptm.shape + weight = np.ones((num_samples, num_chains, num_chains), dtype=float) + weight -= np.eye(num_chains, dtype=float) + xchain_row_agg = weighted_nanmean(chain_pair_iptm, mask=weight, axis=-2) + xchain_col_agg = weighted_nanmean(chain_pair_iptm, mask=weight, axis=-1) + with warnings.catch_warnings(): + # Mean of empty slice is ok and should return a NaN. + warnings.filterwarnings(action='ignore', message='Mean of empty slice') + iptm_xchain = np.nanmean( + np.stack([xchain_row_agg, xchain_col_agg], axis=0), axis=0 + ) + return iptm_xchain + + +def predicted_tm_score( + tm_adjusted_pae: np.ndarray, + pair_mask: np.ndarray, + asym_id: np.ndarray, + interface: bool = False, +) -> float: + """Computes predicted TM alignment or predicted interface TM alignment score. + + Args: + tm_adjusted_pae: [num_res, num_res] Relevant tensor for computing TMScore + values. + pair_mask: A [num_res, num_res] mask. The TM score will only aggregate over + masked-on entries. + asym_id: [num_res] asymmetric unit ID (the chain ID). Only needed for ipTM + calculation, i.e. when interface=True. + interface: If True, the interface predicted TM score is computed. If False, + the predicted TM score without any residue pair restrictions is computed. + + Returns: + score: pTM or ipTM score. + """ + num_tokens, _ = tm_adjusted_pae.shape + if tm_adjusted_pae.shape != (num_tokens, num_tokens): + raise ValueError( + f'Bad tm_adjusted_pae shape, expected ({num_tokens, num_tokens}), got ' + f'{tm_adjusted_pae.shape}.' + ) + + if pair_mask.shape != (num_tokens, num_tokens): + raise ValueError( + f'Bad pair_mask shape, expected ({num_tokens, num_tokens}), got ' + f'{pair_mask.shape}.' + ) + if pair_mask.dtype != bool: + raise TypeError( + f'Bad pair mask type, expected bool, got {pair_mask.dtype}') + if asym_id.shape[0] != num_tokens: + raise ValueError( + f'Bad asym_id shape, expected ({num_tokens},), got {asym_id.shape}.' + ) + + # Create pair mask. + if interface: + pair_mask = pair_mask * (asym_id[:, None] != asym_id[None, :]) + + # Ions and other ligands with colinear atoms have ill-defined frames. + if pair_mask.sum() == 0: + return np.nan + + normed_residue_mask = pair_mask / ( + 1e-8 + np.sum(pair_mask, axis=-1, keepdims=True) + ) + per_alignment = np.sum(tm_adjusted_pae * normed_residue_mask, axis=-1) + return per_alignment.max() + + +def chain_pairwise_predicted_tm_scores( + tm_adjusted_pae: np.ndarray, + pair_mask: np.ndarray, + asym_id: np.ndarray, +) -> np.ndarray: + """Compute predicted TM (pTM) between each pair of chains independently. + + Args: + tm_adjusted_pae: [num_res, num_res] Relevant tensor for computing TMScore + values. + pair_mask: A [num_res, num_res] mask specifying which frames are valid. + Invalid frames can be the result of chains with not enough atoms (e.g. + ions). + asym_id: [num_res] asymmetric unit ID (the chain ID). 
+ + Returns: + A [num_chains, num_chains] matrix, where row i, column j indicates the + predicted TM-score for the interface between chain i and chain j. + """ + unique_chains = list(np.unique(asym_id)) + num_chains = len(unique_chains) + all_pairs_iptms = np.zeros((num_chains, num_chains)) + for i, chain_i in enumerate(unique_chains): + chain_i_mask = asym_id == chain_i + for j, chain_j in enumerate(unique_chains[i:]): + chain_j_mask = asym_id == chain_j + mask = chain_i_mask | chain_j_mask + (indices,) = np.where(mask) + is_interface = chain_i != chain_j + indices = np.ix_(indices, indices) + iptm = predicted_tm_score( + tm_adjusted_pae=tm_adjusted_pae[indices], + pair_mask=pair_mask[indices], + asym_id=asym_id[mask], + interface=is_interface, + ) + all_pairs_iptms[i, i + j] = iptm + all_pairs_iptms[i + j, i] = iptm + return all_pairs_iptms diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data3.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data3.py new file mode 100644 index 000000000..0cd6961bd --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data3.py @@ -0,0 +1,127 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Protein features that are computed from parsed mmCIF objects.""" + +from collections.abc import Mapping, MutableMapping +import datetime +from typing import TypeAlias + +from alphafold3.constants import residue_names +from alphafold3.cpp import msa_profile +from alphafold3.model import protein_data_processing +import numpy as np + + +FeatureDict: TypeAlias = Mapping[str, np.ndarray] +MutableFeatureDict: TypeAlias = MutableMapping[str, np.ndarray] + + +def fix_features(msa_features: MutableFeatureDict) -> MutableFeatureDict: + """Renames the deletion_matrix feature.""" + msa_features['deletion_matrix'] = msa_features.pop('deletion_matrix_int') + return msa_features + + +def get_profile_features( + msa: np.ndarray, deletion_matrix: np.ndarray +) -> FeatureDict: + """Returns the MSA profile and deletion_mean features.""" + num_restypes = residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + profile = msa_profile.compute_msa_profile( + msa=msa, num_residue_types=num_restypes + ) + + return { + 'profile': profile.astype(np.float32), + 'deletion_mean': np.mean(deletion_matrix, axis=0), + } + + +def fix_template_features( + sequence: str, + template_features: FeatureDict, +) -> FeatureDict: + """Convert template features to AlphaFold 3 format. + + Args: + sequence: amino acid sequence of the protein. + template_features: Template features for the protein. + + Returns: + Updated template_features for the chain. 
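+
+  If the input contains no templates (i.e. `template_aatype` has an empty
+  leading axis), fully masked features from `empty_template_features` are
+  returned instead.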
+ """ + num_res = len(sequence) + if not template_features['template_aatype'].shape[0]: + template_features = empty_template_features(num_res) + else: + template_release_timestamp = [ + _get_timestamp(x.decode('utf-8')) + for x in template_features['template_release_date'] + ] + + # Convert from atom37 to dense atom + dense_atom_indices = np.take( + protein_data_processing.PROTEIN_AATYPE_DENSE_ATOM_TO_ATOM37, + template_features['template_aatype'], + axis=0, + ) + + atom_mask = np.take_along_axis( + template_features['template_all_atom_masks'], dense_atom_indices, axis=2 + ) + atom_positions = np.take_along_axis( + template_features['template_all_atom_positions'], + dense_atom_indices[..., None], + axis=2, + ) + atom_positions *= atom_mask[..., None] + + template_features = { + 'template_aatype': template_features['template_aatype'], + 'template_atom_mask': atom_mask.astype(np.int32), + 'template_atom_positions': atom_positions.astype(np.float32), + 'template_domain_names': np.array( + template_features['template_domain_names'], dtype=object + ), + 'template_release_timestamp': np.array( + template_release_timestamp, dtype=np.float32 + ), + } + return template_features + + +def empty_template_features(num_res: int) -> FeatureDict: + """Creates a fully masked out template features to allow padding to work. + + Args: + num_res: The length of the target chain. + + Returns: + Empty template features for the chain. + """ + template_features = { + 'template_aatype': np.zeros(num_res, dtype=np.int32)[None, ...], + 'template_atom_mask': np.zeros( + (num_res, protein_data_processing.NUM_DENSE), dtype=np.int32 + )[None, ...], + 'template_atom_positions': np.zeros( + (num_res, protein_data_processing.NUM_DENSE, 3), dtype=np.float32 + )[None, ...], + 'template_domain_names': np.array([b''], dtype=object), + 'template_release_timestamp': np.array([0.0], dtype=np.float32), + } + return template_features + + +def _get_timestamp(date_str: str): + dt = datetime.datetime.fromisoformat(date_str) + dt = dt.replace(tzinfo=datetime.timezone.utc) + return dt.timestamp() diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data_constants.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data_constants.py new file mode 100644 index 000000000..eabdcfda9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/data_constants.py @@ -0,0 +1,27 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Constants shared across modules in the AlphaFold data pipeline.""" + +from alphafold3.constants import residue_names + +MSA_GAP_IDX = residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP.index( + '-' +) + +# Feature groups. 
+NUM_SEQ_NUM_RES_MSA_FEATURES = ('msa', 'msa_mask', 'deletion_matrix') +NUM_SEQ_MSA_FEATURES = ('msa_species_identifiers',) +TEMPLATE_FEATURES = ( + 'template_aatype', + 'template_atom_positions', + 'template_atom_mask', +) +MSA_PAD_VALUES = {'msa': MSA_GAP_IDX, 'msa_mask': 1, 'deletion_matrix': 0} diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py new file mode 100644 index 000000000..efa941b68 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/atom_cross_attention.py @@ -0,0 +1,490 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from dataclasses import dataclass +import numpy as np +import mindspore as ms +from mindspore import nn, ops, Tensor + +from alphafold3.model import base_config +from alphafold3.model import feat_batch +from alphafold3.model import model_config +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.components import base_modules as bm +from alphafold3.model.components import utils +from alphafold3.model.diffusion import diffusion_transformer + +@dataclass +class AtomCrossAttEncoderConfig(base_config.BaseConfig): + per_token_channels: int = 768 + per_atom_channels: int = 128 + atom_transformer: diffusion_transformer.CrossAttTransformer.Config = ( + base_config.autocreate(num_intermediate_factor=2, num_blocks=3) + ) + per_atom_pair_channels: int = 16 + + +class _PerAtomConditioning(nn.Cell): + """ + A class to compute per-atom and pairwise conditioning information for structural data. + + Args: + config: Configuration object containing model parameters. + + Inputs: + - **batch** (dict) - A dictionary containing structural information: + - **ref_structure.positions** (Tensor) - Tensor of atomic positions. + - **ref_structure.mask** (Tensor) - Tensor of masks indicating valid atoms. + - **ref_structure.element** (Tensor) - Tensor of atomic elements. + - **ref_structure.charge** (Tensor) - Tensor of atomic charges. + - **ref_structure.atom_name_chars** (Tensor) - Tensor of atomic name characters. + + Outputs: + - **act** (Tensor) - Per-atom conditioning information. + - **pair_act** (Tensor) - Pairwise conditioning information. 
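+
+        Per the shape comments in ``construct``, ``act`` has shape
+        (num_tokens, num_dense, per_atom_channels) and ``pair_act`` has shape
+        (num_tokens, num_dense, num_dense, per_atom_pair_channels).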
+ """ + + def __init__(self, config, dtype=ms.float32): + super().__init__() + self.c = config + self.linear1 = nn.Dense(3, self.c.per_atom_channels, has_bias=False) + self.linear2 = nn.Dense(1, self.c.per_atom_channels, has_bias=False) + self.linear3 = nn.Dense(128, self.c.per_atom_channels, has_bias=False) + self.linear4 = nn.Dense(1, self.c.per_atom_channels, has_bias=False) + self.linear5 = nn.Dense(256, self.c.per_atom_channels, has_bias=False) + self.linear_row_act = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, has_bias=False) + self.linear_col_act = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, has_bias=False) + self.linear_pair_act1 = nn.Dense( + 3, self.c.per_atom_pair_channels, has_bias=False) + self.linear_pair_act2 = nn.Dense( + 1, self.c.per_atom_pair_channels, has_bias=False) + + @ms.jit + def construct(self, batch): + # Compute per-atom single conditioning + # Shape (num_tokens, num_dense, channels) + act = self.linear1(batch.ref_structure.positions) + act += self.linear2(batch.ref_structure.mask[:, :, None]) + # Element is encoded as atomic number if the periodic table, so + # 128 should be fine. + act += self.linear3(ops.one_hot(batch.ref_structure.element, 128, + Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)).astype(act.dtype)) + act += self.linear4(ops.arcsinh(batch.ref_structure.charge) + [:, :, None]) + # Characters are encoded as ASCII code minus 32, so we need 64 classes, + # to encode all standard ASCII characters between 32 and 96. + atom_name_chars_1hot = ops.one_hot(batch.ref_structure.atom_name_chars, 64, + Tensor(1.0, ms.float32), Tensor(0.0, ms.float32)).astype(act.dtype) + num_token, num_dense, _ = act.shape + act += self.linear5(atom_name_chars_1hot.reshape(num_token, num_dense, -1)) + act *= batch.ref_structure.mask[:, :, None] + + # Compute pair conditioning + # shape (num_tokens, num_dense, num_dense, channels) + # Embed single features + row_act = self.linear_row_act(ms.ops.relu(act)) + col_act = self.linear_col_act(ms.ops.relu(act)) + pair_act = row_act[:, :, None, :] + col_act[:, None, :, :] + + # Embed pairwise offsets + pair_act += self.linear_pair_act1(batch.ref_structure.positions[:, :, None, :] + - batch.ref_structure.positions[:, None, :, :]) + # Embed pairwise inverse squared distances + sq_dists = ops.sum(ops.square(batch.ref_structure.positions[:, :, None, :] + - batch.ref_structure.positions[:, None, :, :]), dim=-1) + pair_act += self.linear_pair_act2(1.0 / (1 + sq_dists[:, :, :, None])) + return act, pair_act + +@dataclass +class AtomCrossAttEncoderOutput: + def __init__( + self, + token_act, # (num_tokens, ch) + skip_connection, # (num_subsets, num_queries, ch) + queries_mask, # (num_subsets, num_queries) + queries_single_cond, # (num_subsets, num_queries, ch) + keys_mask, # (num_subsets, num_keys) + keys_single_cond, # (num_subsets, num_keys, ch) + pair_cond, # (num_subsets, num_queries, num_keys, ch) + ): + self.token_act = token_act + self.skip_connection = skip_connection + self.queries_mask = queries_mask + self.queries_single_cond = queries_single_cond + self.keys_mask = keys_mask + self.keys_single_cond = keys_single_cond + self.pair_cond = pair_cond + + +class AtomCrossAttEncoder(nn.Cell): + """Cross-attention on flat atom subsets and mapping to per-token features. + + Args: + config: Configuration object containing model parameters. + global_config: Global configuration object with initialization settings. + name (str): Name of the module. 
+ cond_channels (int): Number of conditioning channels. Default: ``384``. + with_cond (bool): Whether to include conditioning layers. Default: ``True``. + + Inputs: + - **token_atoms_act** (ms.Tensor): Tensor representing token atom activations. + - **trunk_single_cond** (ms.Tensor): Tensor representing single token conditioning. + - **trunk_pair_cond** (ms.Tensor): Tensor representing pair token conditioning. + - **batch** (feat_batch.Batch) : Batch of input data. + + Outputs: + - **token_act** (ms.Tensor): Activations for tokens after processing. + - **skip_connection** (ms.Tensor): Skip connection tensor for token queries. + - **queries_mask** (ms.Tensor): Mask for token queries. + - **queries_single_cond** (ms.Tensor): Single conditioning for token queries. + - **keys_mask** (ms.Tensor): Mask for token keys. + - **keys_single_cond** (ms.Tensor): Single conditioning for token keys. + - **pair_cond** (ms.Tensor): Pair conditioning tensor. + """ + + def __init__(self, config, global_config, name, cond_channels=384, with_cond=True, dtype=ms.float32): + super().__init__() + self.c = config + self.with_cond = with_cond + self.dtype = dtype + in_channels = 1 + self._per_atom_conditioning = _PerAtomConditioning(config, dtype=dtype) + if self.with_cond: + self._embed_trunk_single_cond = nn.Dense( + cond_channels, self.c.per_atom_channels, weight_init=global_config.final_init, has_bias=False, dtype=dtype) + self._lnorm_trunk_single_cond = bm.LayerNorm( + (cond_channels,), create_beta=False, gamma_init="ones", dtype=dtype) + + self._atom_positions_to_features = nn.Dense( + 3, self.c.per_atom_channels, has_bias=False, dtype=dtype) + + self._embed_trunk_pair_cond = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, weight_init=global_config.final_init, has_bias=False, dtype=dtype) + self._lnorm_trunk_pair_cond = bm.LayerNorm( + (self.c.per_atom_channels,), create_beta=False, gamma_init="ones", dtype=dtype) + + self._single_to_pair_cond_row = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._single_to_pair_cond_col = nn.Dense( + self.c.per_atom_channels, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + + self._embed_pair_offsets = nn.Dense( + 3, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + # self._embed_pair_offsets = bm.CustomDense(3, self.c.per_atom_pair_channels, use_bias=False, ndim=4, dtype=dtype) + self._embed_pair_distances = nn.Dense( + 1, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._embed_pair_offsets_valid = nn.Dense( + 1, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + + self._pair_mlp_1 = nn.Dense( + self.c.per_atom_pair_channels, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._pair_mlp_2 = nn.Dense( + self.c.per_atom_pair_channels, self.c.per_atom_pair_channels, has_bias=False, dtype=dtype) + self._pair_mlp_3 = nn.Dense(self.c.per_atom_pair_channels, self.c.per_atom_pair_channels, + weight_init=global_config.final_init, has_bias=False, dtype=dtype) + self.relu = nn.ReLU() + self._project_atom_features_for_aggr = nn.Dense( + self.c.per_atom_channels, self.c.per_token_channels, has_bias=False, dtype=dtype) + + self._atom_transformer_encoder = diffusion_transformer.CrossAttTransformer( + self.c.atom_transformer, global_config, in_shape=[ + self.c.per_atom_channels, self.c.per_atom_pair_channels], dtype=dtype + ) + + def construct( + self, + token_atoms_act, # (num_tokens, max_atoms_per_token, 3) + trunk_single_cond, # (num_tokens, 
ch) + trunk_pair_cond, # (num_tokens, num_tokens, ch) + batch, # : feat_batch.Batch, + ): + # Compute single conditioning from atom meta data and convert to queries + # layout. + # (num_subsets, num_queries, channels) + token_atoms_single_cond, _ = self._per_atom_conditioning( + batch) # (num_res, max_atoms_per_token, 128) + # (num_tokens, max_atoms_per_token) + token_atoms_mask = batch.predicted_structure_info.atom_mask + queries_single_cond = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + token_atoms_single_cond, + layout_axes=(-3, -2), + ) # (num_subsets, num_queries, ch) + queries_mask = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + token_atoms_mask, + layout_axes=(-2, -1), + ) # (num_subsets, num_queries) + + # If provided, broadcast single conditioning from trunk to all queries + if trunk_single_cond is not None: + # (num_tokens, ch) -> (num_tokens, ch) -> (num_tokens, per_atom_channels:128) + trunk_single_cond = self._embed_trunk_single_cond( + self._lnorm_trunk_single_cond( + trunk_single_cond) + ) + # (num_subsets, num_queries, ch) + queries_single_cond += atom_layout.convert_ms( + batch.atom_cross_att.tokens_to_queries, + trunk_single_cond, + layout_axes=(-2,), + ) + + if token_atoms_act is None: + # if no token_atoms_act is given (e.g. begin of evoformer), we use the + # static conditioning only + # (num_subsets, num_queries, ch) + queries_act = queries_single_cond + else: + # Convert token_atoms_act to queries layout and map to per_atom_channels + # (num_subsets, num_queries, channels) + queries_act = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + token_atoms_act, + layout_axes=(-3, -2), + ) + queries_act = self._atom_positions_to_features( + queries_act) + queries_act *= queries_mask[..., None] + queries_act += queries_single_cond + + # Gather the keys from the queries. + keys_single_cond = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_keys, queries_single_cond, layout_axes=( + -3, -2), + ) + keys_mask = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_keys, queries_mask, layout_axes=( + -2, -1) + ) + + # Embed single features into the pair conditioning. + # shape (num_subsets, num_queries, num_keys, ch) + row_act = self._single_to_pair_cond_row( + self.relu(queries_single_cond)) + pair_cond_keys_input = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_keys, queries_single_cond, layout_axes=( + -3, -2), + ) + col_act = self._single_to_pair_cond_col( + self.relu(pair_cond_keys_input)) + pair_act = row_act[:, :, None, :] + col_act[:, None, :, :] + + if trunk_pair_cond is not None: + # If provided, broadcast the pair conditioning for the trunk (evoformer + # pairs) to the atom pair activations. This should boost ligands, but also + # help for cross attention within proteins, because we always have atoms + # from multiple residues in a subset. + # Map trunk pair conditioning to per_atom_pair_channels + # (num_tokens, num_tokens, per_atom_pair_channels) + trunk_pair_cond = self._embed_trunk_pair_cond( + self._lnorm_trunk_pair_cond( + trunk_pair_cond) + ) + + # Create the GatherInfo into a flattened trunk_pair_cond from the + # queries and keys gather infos. 
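+            # For a row-major flattened (num_tokens * num_tokens) pair array,
+            # entry (i, j) sits at index i * num_tokens + j, so the combined
+            # index below is num_tokens * query_token_idx + key_token_idx.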
+            num_tokens = trunk_pair_cond.shape[0]
+            # (num_subsets, num_queries)
+            tokens_to_queries = batch.atom_cross_att.tokens_to_queries
+            # (num_subsets, num_keys)
+            tokens_to_keys = batch.atom_cross_att.tokens_to_keys
+            # (num_subsets, num_queries, num_keys)
+
+            # Gather the conditioning and add it to the atom-pair activations.
+            gather_idxs = Tensor(num_tokens * tokens_to_queries.gather_idxs[:, :, None]
+                                 + tokens_to_keys.gather_idxs[:, None, :])
+            gather_mask = Tensor(np.array(tokens_to_queries.gather_mask[:, :, None], dtype=bool) &
+                                 np.array(tokens_to_keys.gather_mask[:, None, :], dtype=bool))
+            input_shape = Tensor((num_tokens, num_tokens))
+            trunk_pair_to_atom_pair = atom_layout.GatherInfo(gather_idxs=gather_idxs,
+                                                             gather_mask=gather_mask,
+                                                             input_shape=input_shape)
+
+            pair_act += atom_layout.convert_ms(
+                trunk_pair_to_atom_pair, trunk_pair_cond, layout_axes=(-3, -2)
+            )
+
+        # Embed pairwise offsets
+        queries_ref_pos = atom_layout.convert_ms(
+            batch.atom_cross_att.token_atoms_to_queries,
+            batch.ref_structure.positions,
+            layout_axes=(-3, -2),
+        )
+        queries_ref_space_uid = atom_layout.convert_ms(
+            batch.atom_cross_att.token_atoms_to_queries,
+            batch.ref_structure.ref_space_uid,
+            layout_axes=(-2, -1),
+        )
+        keys_ref_pos = atom_layout.convert_ms(
+            batch.atom_cross_att.queries_to_keys,
+            queries_ref_pos,
+            layout_axes=(-3, -2),
+        )
+        keys_ref_space_uid = atom_layout.convert_ms(
+            batch.atom_cross_att.queries_to_keys,
+            batch.ref_structure.ref_space_uid,
+            layout_axes=(-2, -1),
+        )
+
+        offsets_valid = (
+            queries_ref_space_uid[:, :, None] == keys_ref_space_uid[:, None, :]
+        )
+        offsets = queries_ref_pos[:, :, None, :] - keys_ref_pos[:, None, :, :]
+        pair_act += (self._embed_pair_offsets(offsets)
+                     * offsets_valid[:, :, :, None])
+
+        # Embed pairwise inverse squared distances
+        sq_dists = ops.sum(ops.square(offsets), dim=-1)
+        pair_act += (
+            self._embed_pair_distances(1.0 / (1 + sq_dists[:, :, :, None]))
+            * offsets_valid[:, :, :, None]
+        )
+
+        # Embed offsets valid mask
+        pair_act += self._embed_pair_offsets_valid(
+            offsets_valid[:, :, :, None].astype(ms.float32))
+
+        # Run a small MLP on the pair activations
+        pair_act2 = self._pair_mlp_1(self.relu(pair_act))
+        pair_act2 = self._pair_mlp_2(self.relu(pair_act2))
+        pair_act += self._pair_mlp_3(self.relu(pair_act2))
+
+        # Run the atom cross attention transformer.
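+        # The transformer cross-attends each query subset to its gathered
+        # keys, conditioned on the per-atom single activations and the pair
+        # activations assembled above.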
+ # (num_subsets, num_queries, ch) + queries_act = self._atom_transformer_encoder( + queries_act=queries_act, + queries_mask=queries_mask, + queries_to_keys=batch.atom_cross_att.queries_to_keys, + keys_mask=keys_mask, + queries_single_cond=queries_single_cond, + keys_single_cond=keys_single_cond, + pair_cond=pair_act, + ) + queries_act *= queries_mask[..., None] + skip_connection = queries_act + + # convert back to token-atom layout and aggregate to tokens + queries_act = self._project_atom_features_for_aggr(queries_act) + token_atoms_act = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_token_atoms, + queries_act, + layout_axes=(-3, -2), + ) + token_act = utils.mask_mean( + token_atoms_mask[..., None], self.relu(token_atoms_act), axis=-2 + ) + + return AtomCrossAttEncoderOutput( + token_act=token_act, + # (num_subsets, num_queries, ch) + skip_connection=skip_connection, + # (num_subsets, num_queries) + queries_mask=queries_mask, + # (num_subsets, num_queries, ch) + queries_single_cond=queries_single_cond, + # (num_subsets, num_keys) + keys_mask=keys_mask, + # (num_subsets, num_keys, ch) + keys_single_cond=keys_single_cond, + # (num_subsets, num_queries, num_keys, ch) + pair_cond=pair_act, + ) + +@dataclass +class AtomCrossAttDecoderConfig(base_config.BaseConfig): + per_token_channels: int = 768 + per_atom_channels: int = 128 + per_atom_pair_channels: int = 16 + atom_transformer: diffusion_transformer.CrossAttTransformer.Config = ( + base_config.autocreate(num_intermediate_factor=2, num_blocks=3) + ) + + +class AtomCrossAttDecoder(nn.Cell): + """Mapping to per-atom features and self-attention on subsets. + + Args: + config: Configuration object containing model parameters. + global_config: Global configuration object with additional parameters. + name (str): Name of the decoder. Default: ``None``. + + Inputs: + - **token_act** (Tensor) - Tensor representing token activations. + - **enc** (AtomCrossAttEncoderOutput) - Output from the encoder containing necessary features and masks. + - **batch** (feat_batch.Batch) - Batch containing atom cross attention features. + + Outputs: + - **position_update** (Tensor) - Tensor representing the updated positions after processing. + """ + + def __init__(self, config, global_config, name, dtype=ms.float32): + super().__init__() + self.c = config + self._project_token_features_for_broadcast = nn.Dense( + self.c.per_token_channels, self.c.per_atom_channels, has_bias=False, dtype=dtype) + self._atom_features_layer_norm = bm.LayerNorm( + (self.c.per_atom_channels,), create_beta=False, gamma_init="ones", dtype=dtype) + self._atom_features_to_position_update = nn.Dense( + self.c.per_atom_channels, 3, weight_init=global_config.final_init, has_bias=False, dtype=dtype) + self._atom_transformer_decoder = diffusion_transformer.CrossAttTransformer( + self.c.atom_transformer, global_config, in_shape=[ + self.c.per_atom_channels, self.c.per_atom_pair_channels], dtype=dtype + ) + + # @ms.jit + def construct( + self, + token_act, # (num_tokens, ch) + enc, + batch, + ): # (num_tokens, max_atoms_per_token, 3) + # map per-token act down to per_atom channels + token_act = self._project_token_features_for_broadcast(token_act) + # Broadcast to token-atoms layout and convert to queries layout. 
+ num_token, max_atoms_per_token = ( + batch.atom_cross_att.queries_to_token_atoms.shape + ) + token_atom_act = ops.broadcast_to( + token_act[:, None, :], + (num_token, max_atoms_per_token, self.c.per_atom_channels), + ) + queries_act = atom_layout.convert_ms( + batch.atom_cross_att.token_atoms_to_queries, + token_atom_act, + layout_axes=(-3, -2), + ) + queries_act += enc.skip_connection + queries_act *= enc.queries_mask[..., None] + + # Run the atom cross attention transformer. + queries_act = self._atom_transformer_decoder( + queries_act=queries_act, + queries_mask=enc.queries_mask, + queries_to_keys=batch.atom_cross_att.queries_to_keys, + keys_mask=enc.keys_mask, + queries_single_cond=enc.queries_single_cond, + keys_single_cond=enc.keys_single_cond, + pair_cond=enc.pair_cond, + ) + + queries_act *= enc.queries_mask[..., None] + queries_position_update = self._atom_features_to_position_update( + self._atom_features_layer_norm(queries_act) + ) + position_update = atom_layout.convert_ms( + batch.atom_cross_att.queries_to_token_atoms, + queries_position_update, + layout_axes=(-3, -2), + ) + return position_update diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/confidence_head.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/confidence_head.py new file mode 100644 index 000000000..df6375571 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/confidence_head.py @@ -0,0 +1,293 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +"""Confidence Head.""" +from dataclasses import dataclass +import mindspore as ms +from mindspore import nn, ops, Tensor +from alphafold3.model import base_config +from alphafold3.model import model_config +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.components import base_modules as bm +from alphafold3.model.components import utils +from alphafold3.model.diffusion import modules +from alphafold3.model.diffusion import template_modules +import numpy as np + +def _safe_norm(x, keepdims, axis, eps=1e-8): + return ops.sqrt(eps + ops.sum(ops.square(x), dim=axis, keepdims=keepdims)) + + +class ConfidenceHead(nn.Cell): + """Head to predict the distance errors in a prediction. + + Args: + config (ConfidenceHead.Config): Configuration for the ConfidenceHead module. + global_config (base_config.BaseConfig): Global configuration for the model. + pair_shape (tuple): Shape of the pair features. + single_shape (tuple): Shape of the single features. + atom_shape (tuple): Shape of the atom features. + feat_in_channel (int): Number of input channels for feature projections. + out_channel (int): Number of output channels for feature projections. + + Inputs: + - **dense_atom_positions** (Tensor): [N_res, N_atom, 3] array of atom positions. + - **embeddings** (dict): Dictionary containing pair, single, and target features. + - **seq_mask** (Tensor): Sequence mask indicating valid residues. 
+ - **token_atoms_to_pseudo_beta** (Tensor): Pseudo beta information for atom tokens. + - **asym_id** (Tensor): Asym ID token features. + + Outputs: + - **predicted_lddt** (Tensor): Predicted LDDT scores for each residue. + - **predicted_experimentally_resolved** (Tensor): Predicted experimental resolution scores. + - **full_pde** (Tensor): Full predicted distance errors. + - **average_pde** (Tensor): Average predicted distance errors. + - **pae_outputs** (dict): Additional outputs from PAE (Predicted Alignment Error) calculations. + """ + @dataclass + class PAEConfig(base_config.BaseConfig): + max_error_bin: float = 31.0 + num_bins: int = 64 + + @dataclass + class Config(base_config.BaseConfig): + """Configuration for ConfidenceHead.""" + + pairformer: modules.PairFormerIteration.Config = base_config.autocreate( + single_attention=base_config.autocreate(), + single_transition=base_config.autocreate(), + num_layer=4, + ) + max_error_bin: float = 31.0 + num_plddt_bins: int = 50 + num_bins: int = 64 + no_embedding_prob: float = 0.2 + pae: 'ConfidenceHead.PAEConfig' = base_config.autocreate() + dgram_features: template_modules.DistogramFeaturesConfig = ( + base_config.autocreate() + ) + + def __init__(self, config, global_config, pair_shape, single_shape, atom_shape, feat_in_channel, out_channel, dtype=ms.float32): + super().__init__() + self.dtype = dtype + self.config = config + self.global_config = global_config + in_channel = pair_shape[-1] + self.left_target_feat_project = nn.Dense( + feat_in_channel, out_channel, has_bias=False, dtype=dtype) + self.right_target_feat_project = nn.Dense( + feat_in_channel, out_channel, has_bias=False, dtype=dtype) + self.distogram_feat_project = nn.Dense( + template_modules.DistogramFeaturesConfig.num_bins, out_channel, has_bias=False, dtype=dtype) + self.pairformer_block = ms.nn.CellList( + [ + modules.PairFormerIteration( + self.config.pairformer, global_config, pair_shape, single_shape, with_single=True, dtype=dtype + ) + for _ in range(self.config.pairformer.num_layer) + ] + ) + self.left_half_distance_logits = nn.Dense( + pair_shape[-1], self.config.num_bins, has_bias=False, dtype=ms.float32) + self.logits_ln = bm.LayerNorm(pair_shape, dtype=ms.float32) + self.pae_logits = nn.Dense( + pair_shape[-1], self.config.pae.num_bins, has_bias=False, dtype=ms.float32) + self.pae_logits_ln = bm.LayerNorm(pair_shape, dtype=ms.float32) + self.plddt_logits = bm.CustomDense( + single_shape[-1], (atom_shape[-2], self.config.num_plddt_bins), ndim=2, dtype=ms.float32) + self.plddt_logits_ln = bm.LayerNorm(single_shape, dtype=ms.float32) + self.experimentally_resolved_logits = bm.CustomDense( + single_shape[-1], (atom_shape[-2], 2), ndim=2, dtype=ms.float32) + self.experimentally_resolved_ln = bm.LayerNorm(single_shape, dtype=ms.float32) + + def _embed_features(self, dense_atom_positions, token_atoms_to_pseude_beta, + pair_mask, pair_act, target_feat): + out = self.left_target_feat_project(target_feat) + out2 = self.right_target_feat_project(target_feat)[:, None] + out = out + out2 + positions = atom_layout.convert_ms( + token_atoms_to_pseude_beta, + dense_atom_positions, + layout_axes=(-3, -2), + ) + dgram = template_modules.dgram_from_positions( + positions, self.config.dgram_features, dtype=ms.float32 + ) + dgram *= pair_mask[..., None] + out += self.distogram_feat_project(dgram.astype(pair_act.dtype)) + return out + + def construct(self, dense_atom_positions, embeddings, seq_mask, + token_atoms_to_pseudo_beta, asym_id): + seq_mask_cast = 
seq_mask.astype(self.dtype) + pair_mask = seq_mask_cast[:, None] * seq_mask_cast[None, :].astype(self.dtype) + pair_act = embeddings['pair'].astype(self.dtype) + single_act = embeddings['single'].astype(self.dtype) + target_feat = embeddings['target_feat'].astype(self.dtype) + num_residues = seq_mask.shape[0] + num_pair_channels = pair_act.shape[2] + pair_act += self._embed_features( + dense_atom_positions, + token_atoms_to_pseudo_beta, + pair_mask, + pair_act, + target_feat, + ) + + for i in range(self.config.pairformer.num_layer): + pair_act, single_act = self.pairformer_block[i]( + pair_act, pair_mask, single_act, seq_mask) + pair_act = pair_act.astype(ms.float32) + assert pair_act.shape == ( + num_residues, num_residues, num_pair_channels) + + # Produce logits to predict a distogram of pairwise distance errors + # between the input prediction and the ground truth. + # Shape (num_res, num_res, num_bins) + left_distance_logits = self.left_half_distance_logits( + self.logits_ln(pair_act)) + right_distance_logits = left_distance_logits + distance_logits = left_distance_logits + ops.swapaxes( # Symmetrize. + right_distance_logits, -2, -3 + ) + # Shape (num_bins,) + distance_breaks = ops.linspace( + 0.0, self.config.max_error_bin, self.config.num_bins - 1 + ) + + step = distance_breaks[1] - distance_breaks[0] + + # Add half-step to get the center + bin_centers = distance_breaks + step / 2 + # Add a catch-all bin at the end. + bin_centers = ops.concat( + [bin_centers, bin_centers[-1:] + step], axis=0 + ) + + distance_probs = ops.softmax(distance_logits, axis=-1) + + pred_distance_error = ( + ops.sum(distance_probs * bin_centers, dim=-1) * pair_mask + ) + average_pred_distance_error = ops.sum( + pred_distance_error, dim=[-2, -1] + ) / ops.sum(pair_mask, dim=[-2, -1]) + + # Predicted aligned error + pae_outputs = {} + # Shape (num_res, num_res, num_bins) + pae_logits = self.pae_logits(self.pae_logits_ln(pair_act)) + # Shape (num_bins,) + pae_breaks = ops.linspace( + 0.0, self.config.pae.max_error_bin, self.config.pae.num_bins - 1 + ) + step = pae_breaks[1] - pae_breaks[0] + # Add half-step to get the center + bin_centers = pae_breaks + step / 2 + # Add a catch-all bin at the end. + bin_centers = ops.concat( + [bin_centers, bin_centers[-1:] + step], axis=0 + ) + pae_probs = ops.softmax(pae_logits, axis=-1) + + seq_mask_bool = seq_mask.astype(bool) + pair_mask_bool = seq_mask_bool[:, None] * seq_mask_bool[None, :] + pae = ops.sum(pae_probs * bin_centers, dim=-1) * pair_mask_bool + pae_outputs.update({ + 'full_pae': pae, + }) + + # The pTM is computed outside of bfloat16 context. 
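+        # _get_tmscore_adjusted_pae turns the PAE distribution into expected
+        # per-bin TM-score terms, 1 / (1 + (d / d0)^2), averaged under
+        # pae_probs; d0 is the TM-score normalisation of Yang & Skolnick
+        # (2004) and depends on the (interface) token count.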
+ tmscore_adjusted_pae_global, tmscore_adjusted_pae_interface = ( + self._get_tmscore_adjusted_pae( + asym_id=asym_id, + seq_mask=seq_mask, + pair_mask=pair_mask_bool, + bin_centers=bin_centers, + pae_probs=pae_probs, + ) + ) + pae_outputs.update({ + 'tmscore_adjusted_pae_global': tmscore_adjusted_pae_global, + 'tmscore_adjusted_pae_interface': tmscore_adjusted_pae_interface, + }) + single_act = single_act.astype('float32') + + # pLDDT + # Shape (num_res, num_atom, num_bins) + plddt_logits = self.plddt_logits(self.plddt_logits_ln(single_act)) + + bin_width = 1.0 / self.config.num_plddt_bins + bin_centers = ops.arange(0.5 * bin_width, 1.0, bin_width) + predicted_lddt = ops.sum( + ops.softmax(plddt_logits, axis=-1) * bin_centers, dim=-1 + ) + predicted_lddt = predicted_lddt * 100.0 + + # Experimentally resolved + # Shape (num_res, num_atom, 2) + experimentally_resolved_logits = self.experimentally_resolved_logits( + self.experimentally_resolved_ln(single_act) + ) + + predicted_experimentally_resolved = ops.softmax( + experimentally_resolved_logits, axis=-1 + )[..., 1] + + return { + 'predicted_lddt': predicted_lddt, + 'predicted_experimentally_resolved': predicted_experimentally_resolved, + 'full_pde': pred_distance_error, + 'average_pde': average_pred_distance_error, + **pae_outputs, + } + + def _get_tmscore_adjusted_pae( + self, asym_id, seq_mask, pair_mask, bin_centers, pae_probs, + ): + def get_tmscore_adjusted_pae(num_interface_tokens, bin_centers, pae_probs): + # Clip to avoid negative/undefined d0. + clipped_num_res = ops.maximum(num_interface_tokens, 19) + + # Compute d_0(num_res) as defined by TM-score, eqn. (5) in + # http://zhanglab.ccmb.med.umich.edu/papers/2004_3.pdf + # Yang & Skolnick "Scoring function for automated + # assessment of protein structure template quality" 2004. + d0 = 1.24 * (clipped_num_res - 15) ** (1.0 / 3) - 1.8 + + # Make compatible with [num_tokens, num_tokens, num_bins] + d0 = d0[:, :, None] + bin_centers = bin_centers[None, None, :] + + # TM-Score term for every bin. + tm_per_bin = 1.0 / (1 + ops.square(bin_centers) / ops.square(d0)) + # E_distances tm(distance). + predicted_tm_term = ops.sum(pae_probs * tm_per_bin, dim=-1) + return predicted_tm_term + + # Interface version + x = asym_id[None, :] == asym_id[:, None] + num_chain_tokens = ops.sum(x * pair_mask, dim=-1) + num_interface_tokens = num_chain_tokens[None, + :] + num_chain_tokens[:, None] + # Don't double-count within a single chain + num_interface_tokens -= x * (num_interface_tokens // 2) + num_interface_tokens = num_interface_tokens * pair_mask + + num_global_tokens = ops.full( + size=pair_mask.shape, fill_value=seq_mask.sum() + ).astype(ms.int32) + + global_apae = get_tmscore_adjusted_pae( + num_global_tokens, bin_centers, pae_probs + ) + interface_apae = get_tmscore_adjusted_pae( + num_interface_tokens, bin_centers, pae_probs + ) + return global_apae, interface_apae diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py new file mode 100644 index 000000000..c3efe67de --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_head.py @@ -0,0 +1,326 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Diffusion Head.""" + +from dataclasses import dataclass +from collections.abc import Callable +import math +import numpy as np +import mindspore as ms +from mindspore import mint +from mindspore import nn, ops, Tensor +from mindchemistry.e3.utils import Ncon +from alphafold3.constants import residue_names +from alphafold3.model import base_config +from alphafold3.model import feat_batch +from alphafold3.model import model_config +from alphafold3.model.components import base_modules as bm +from alphafold3.model.components import utils +from alphafold3.model.diffusion import atom_cross_attention +from alphafold3.model.diffusion import diffusion_transformer +from alphafold3.model.diffusion import featurization + + +# Carefully measured by averaging multimer training set. +SIGMA_DATA = 16.0 + + +def fourier_embeddings(x, weight, bias, dim): + return mint.cos(2 * math.pi * (x[..., None] * weight + bias)) + +def random_rotation(key): + # Create a random rotation (Gram-Schmidt orthogonalization of two + # random normal vectors) + np.random.seed(key) + v0, v1 = ms.Tensor(np.random.normal(0, 1, (2, 3)), dtype=ms.float32) + e0 = v0 / mint.maximum(1e-10, mint.norm(v0)) + v1 = v1 - e0 * mint.matmul(v1, e0) + e1 = v1 / mint.maximum(1e-10, mint.norm(v1)) + e2 = mint.cross(e0, e1) + return mint.stack([e0, e1, e2]) + +def random_augmentation(rng_key, positions, mask): + """Apply random rigid augmentation. + Args: + rng_key: random key + positions: atom positions of shape (, 3) + mask: per-atom mask of shape (,) + Returns: + Transformed positions with the same shape as input positions. + """ + center = utils.mask_mean( + mask.unsqueeze(-1), positions, axis=(-2, -3), keepdims=True, eps=1e-6 + ).astype(ms.float32) + rot = random_rotation(rng_key) + np.random.seed(rng_key) + translation = ms.Tensor(np.random.normal(0, 1, (3,)), dtype=ms.float32) + + augmented_positions = ( + mint.einsum( + '...i,ij->...j', + (positions - center).astype(ms.float32), + rot, + ) + + translation + ) + return augmented_positions * mask[..., None] + +def noise_schedule(t, smin=0.0004, smax=160.0, p=7): + return ( + SIGMA_DATA + * (smax ** (1 / p) + t * (smin ** (1 / p) - smax ** (1 / p))) ** p + ) + +@dataclass +class ConditioningConfig(base_config.BaseConfig): + pair_channel: int + seq_channel: int + prob: float + +@dataclass +class SampleConfig(base_config.BaseConfig): + steps: int + gamma_0: float = 0.8 + gamma_min: float = 1.0 + noise_scale: float = 1.003 + step_scale: float = 1.5 + num_samples: int = 1 + +class DiffusionHead(nn.Cell): + """Denoising Diffusion Head. + + Args: + config (Config): Configuration object containing parameters for the diffusion head. + global_config (GlobalConfig): Global configuration object containing shared parameters. + in_shape (tuple): Input shape for the module. + max_relative_chain (int): Maximum number of relative chains for positional encoding. Default: ``2``. + max_relative_idx (int): Maximum relative index for positional encoding. Default: ``32``. + + Inputs: + - **positions_noisy** (Tensor) - Noisy atomic positions tensor. 
+ - **noise_level** (Tensor) - Tensor representing the noise level. + - **batch** (Batch) - Batch of input data containing token features and structure information. + - **embeddings** (dict) - Dictionary of embeddings for single and pair features. + - **use_conditioning** (bool) - Flag to enable or disable conditioning. + + Outputs: + - **position_update** (Tensor) - Refined atomic positions tensor. + """ + + class Config( + atom_cross_attention.AtomCrossAttEncoderConfig, + atom_cross_attention.AtomCrossAttDecoderConfig, + ): + """Configuration for DiffusionHead.""" + eval_batch_size: int = 5 + eval_batch_dim_shard_size: int = 5 + conditioning: ConditioningConfig = base_config.autocreate( + prob=0.8, pair_channel=128, seq_channel=384 + ) + eval: SampleConfig = base_config.autocreate( + num_samples=5, + steps=200, + ) + transformer: diffusion_transformer.Transformer.Config = ( + base_config.autocreate() + ) + + def __init__(self, config, global_config, in_shape, max_relative_chain=2, max_relative_idx=32, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.dtype = dtype + in_channel = in_shape[-1] + self.max_relative_chain = max_relative_chain + self.max_relative_idx = max_relative_idx + + # _conditioning modules + in_channel_pair = in_channel + 4 * self.max_relative_idx + 4 + 2 * self.max_relative_chain + 2 + 1 + self.pair_cond_initial_norm = bm.LayerNorm( + in_shape[:-1] + (in_channel_pair,), + create_beta=False, gamma_init="ones", + name='pair_cond_initial_norm', dtype=dtype) + self.pair_cond_initial_projection = nn.Dense(in_channel_pair, self.config.conditioning.pair_channel, has_bias=False, dtype=ms.float32) + self.transition_block1 = diffusion_transformer.TransitionBlock( + global_config, in_channel, 2, with_single_cond=False, name=f'pair_transition_1', dtype=dtype) + self.transition_block2 = diffusion_transformer.TransitionBlock( + global_config, in_channel, 2, with_single_cond=False, name=f'pair_transition_2', dtype=dtype) + in_channel_single = self.config.conditioning.seq_channel * 2 \ + + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP * 2 + 1 + self.single_cond_initial_norm = bm.LayerNorm( + in_shape[:-1] + (in_channel_single,), + create_beta=False, gamma_init="ones", + name='single_cond_initial_norm', dtype=dtype) + self.single_cond_initial_projection = nn.Dense(in_channel_single, self.config.conditioning.seq_channel, has_bias=False, dtype=dtype) + self.num_noise_embedding = 256 + self.layer_norm_noise = bm.LayerNorm( + in_shape[:-1]+(self.num_noise_embedding,), + create_beta=False, gamma_init="ones", + name='noise_embedding_initial_norm', dtype=dtype) + self.linear_noise = nn.Dense(self.num_noise_embedding, + self.config.conditioning.seq_channel, has_bias=False, dtype=dtype) + self.single_transition1 = diffusion_transformer.TransitionBlock( + global_config, self.config.conditioning.seq_channel, 2, ndim=2, with_single_cond=False, name=f'single_transition_1', dtype=dtype) + self.single_transition2 = diffusion_transformer.TransitionBlock( + global_config, self.config.conditioning.seq_channel, 2, ndim=2, with_single_cond=False, name=f'single_transition_2', dtype=dtype) + + # modules + self.layer_norm_act = bm.LayerNorm( + (in_channel,)+(self.config.conditioning.seq_channel,), + create_beta=False, gamma_init="ones", + name='single_cond_embedding_norm', dtype=dtype) + self.linear_act = nn.Dense(self.config.conditioning.seq_channel, + self.config.per_token_channels, has_bias=False, dtype=dtype) + self.layer_norm_out = 
bm.LayerNorm( + in_shape[:-1]+(self.config.per_token_channels,), + create_beta=False, gamma_init="ones", + name='output_norm', dtype=dtype) + self.atom_cross_att_encoder = atom_cross_attention.AtomCrossAttEncoder( + self.config, self.global_config, "", dtype=dtype + ) + self.transformer = diffusion_transformer.Transformer( + self.config.transformer, self.global_config, in_shape[:-1] + (self.config.conditioning.seq_channel * 2,), + in_shape, using_pair_act=True, dtype=dtype + ) + self.atom_cross_att_decoder = atom_cross_attention.AtomCrossAttDecoder( + self.config, self.global_config, '', dtype=dtype + ) + + @ms.jit + def _conditioning(self, batch, embeddings, noise_level, use_conditioning, weight, bias): + single_embedding = use_conditioning * embeddings['single'] + pair_embedding = use_conditioning * embeddings['pair'] + rel_features = featurization.create_relative_encoding( + batch.token_features, max_relative_idx=self.max_relative_idx, max_relative_chain=self.max_relative_chain + ).astype(pair_embedding.dtype) + features_2d = mint.concat([pair_embedding, rel_features], dim=-1) + pair_cond = self.pair_cond_initial_projection(self.pair_cond_initial_norm(features_2d.astype(ms.float32))).astype(pair_embedding.dtype) #(256,256,267) -> (256,256,128) + pair_cond += self.transition_block1(pair_cond) + pair_cond += self.transition_block2(pair_cond) + + target_feat = embeddings['target_feat'] + features_1d = mint.concat([single_embedding, target_feat.astype(single_embedding.dtype)], dim=-1) + single_cond = self.single_cond_initial_norm(features_1d) + single_cond = self.single_cond_initial_projection(single_cond) #(256,831) -> (256,384) + noise_embedding = fourier_embeddings( + (1 / 4) * mint.log(noise_level / SIGMA_DATA).astype(self.dtype), weight, bias, dim=self.num_noise_embedding + ) + single_cond += self.linear_noise(self.layer_norm_noise(noise_embedding)) #(1,256) -> (1,384) + single_cond += self.single_transition1(single_cond) + single_cond += self.single_transition2(single_cond) + + return single_cond, pair_cond + + def construct(self, positions_noisy, noise_level, batch, embeddings, use_conditioning, weight, bias): + trunk_single_cond, trunk_pair_cond = self._conditioning( + batch=batch, + embeddings=embeddings, + noise_level=noise_level, + use_conditioning=use_conditioning, + weight=weight, + bias=bias + ) + + # Extract features + sequence_mask = batch.token_features.mask + atom_mask = batch.predicted_structure_info.atom_mask + # Position features + act = positions_noisy * atom_mask[..., None] + act = act / mint.sqrt(noise_level**2 + SIGMA_DATA**2) + enc = self.atom_cross_att_encoder(act, embeddings["single"], trunk_pair_cond.astype(ms.float32), batch) + + act = enc.token_act + act += self.linear_act(self.layer_norm_act(trunk_single_cond)) + act = act.astype(ms.float32) + trunk_single_cond = trunk_single_cond.astype(ms.float32) + trunk_pair_cond = trunk_pair_cond.astype(ms.float32) + sequence_mask = sequence_mask.astype(ms.float32) + act = self.transformer(act, trunk_single_cond, sequence_mask, trunk_pair_cond) + act = self.layer_norm_out(act) + position_update = self.atom_cross_att_decoder(act, enc, batch) + skip_scaling = SIGMA_DATA**2 / (noise_level**2 + SIGMA_DATA**2) + out_scaling = ( + noise_level * SIGMA_DATA / mint.sqrt(noise_level**2 + SIGMA_DATA**2) + ) + return ( + skip_scaling * positions_noisy + out_scaling * position_update + ) * atom_mask[..., None] + +def sample(denoising_step, batch, key, config, init_positions=None): + """Sample using denoiser on batch. 
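+
+    Runs the iterative denoiser: at every step the current positions are
+    randomly rigid-augmented, re-noised up to ``t_hat``, denoised by
+    ``denoising_step``, and stepped along the estimated score direction,
+    scaled by ``config.step_scale``.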
+
+    Args:
+        denoising_step: the denoising function.
+        batch: the batch.
+        key: random key.
+        config: config for the sampling process (e.g. number of denoising
+            steps, etc.)
+        init_positions: optional initial positions; if None, they are drawn
+            from a standard normal distribution and scaled by the first noise
+            level.
+
+    Returns:
+        a dict
+        {
+            'atom_positions': ms.Tensor  # shape (<common_dims>, 3)
+            'mask': ms.Tensor  # shape (<common_dims>,)
+        }
+        where the <common_dims> are
+        (num_samples, num_tokens, max_atoms_per_token)
+    """
+
+    mask = batch.predicted_structure_info.atom_mask
+    # Get weight and bias from JAX; these two values cannot be randomly
+    # generated.
+    weight = ms.Tensor(np.load("./src/alphafold3/model/diffusion/random/weight.npy"), dtype=ms.float32)
+    bias = ms.Tensor(np.load("./src/alphafold3/model/diffusion/random/bias.npy"), dtype=ms.float32)
+
+    def apply_denoising_step(carry, noise_level):
+        key, positions, noise_level_prev = carry
+
+        positions = random_augmentation(
+            rng_key=key, positions=positions, mask=mask,
+        )
+        gamma = config.gamma_0 * (noise_level > config.gamma_min)
+        t_hat = noise_level_prev * (1 + gamma)
+
+        noise_scale = config.noise_scale * mint.sqrt(t_hat**2 - noise_level_prev**2)
+        np.random.seed(key)
+        noise = noise_scale * ms.Tensor(np.random.normal(0, 1, positions.shape), dtype=ms.float32)
+        positions_noisy = positions + noise
+
+        positions_denoised = denoising_step(positions_noisy, t_hat, weight=weight, bias=bias)
+        grad = (positions_noisy - positions_denoised) / t_hat
+
+        d_t = noise_level - t_hat
+        positions_out = positions_noisy + config.step_scale * d_t * grad
+
+        return (key, positions_out, noise_level), positions_out
+
+    num_samples = config.num_samples
+
+    noise_levels = noise_schedule(mint.linspace(0, 1, config.steps + 1))
+
+    noise_key, key = key, key + 1
+    np.random.seed(noise_key)
+    if init_positions is None:
+        init_positions = ms.Tensor(np.random.normal(0, 1, (num_samples,) + mask.shape + (3,)), dtype=ms.float32)
+        init_positions *= noise_levels[0]
+    init = (ms.Tensor([key + i for i in range(num_samples)]).reshape((-1, 1)),
+            init_positions,
+            mint.tile(noise_levels[None, 0], (num_samples,)).reshape((-1, 1)))
+    count = 0
+    for noise_level in noise_levels[1:]:
+        for i in range(num_samples):
+            temp, _ = apply_denoising_step((count * 10 + i, init[1][i], init[2][i]), noise_level)
+            init[0][i], init[1][i], init[2][i] = temp
+        count += 1
+    _, positions_out, _ = init
+
+    final_dense_atom_mask = mint.tile(mask[None], (num_samples, 1, 1))
+
+    return {'atom_positions': positions_out, 'mask': final_dense_atom_mask}
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_transformer.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_transformer.py
new file mode 100644
index 000000000..7433dd8fe
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/diffusion_transformer.py
@@ -0,0 +1,496 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google.
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Diffusion transformer model.""" + +from dataclasses import dataclass +from alphafold3.model import base_config +from alphafold3.utils.gated_linear_unit import gated_linear_unit +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.components import base_modules as bm + +from mindspore import mint +import mindspore as ms +from mindspore import nn +from mindchemistry.e3.utils import Ncon + + +class AdaptiveLayernorm(nn.Cell): + """ + If single condition is None, this layer is the same as layernorm. + If single condition is given, the layer is modified from Scalable Diffusion Models with Transformers + https://arxiv.org/abs/2212.09748 + + Args: + num_channels (int): Number of channels in the input tensor. + single_channel (int, optional): Number of channels in the single condition tensor. Required if `with_single_cond` is True. Default: ``None``. + ndim (int, optional): Number of dimensions for the dense layers. Default: ``3``. + with_single_cond (bool, optional): Whether to include the single condition adaptation. Default: ``True``. + + Inputs: + - **x** (Tensor) - Input tensor to be normalized. + - **single_cond** (Tensor, optional) - Optional single condition tensor used to adapt the normalization parameters. Required if `with_single_cond` is True. + + Outputs: + - **output** (Tensor) - The normalized output tensor. + """ + + def __init__(self, num_channels, single_channel=None, ndim=3, with_single_cond=True, dtype=ms.float32): + super().__init__() + self.with_single_cond = with_single_cond + if self.with_single_cond: + self.layernorm = bm.LayerNorm([num_channels], name='layer_norm', + create_gamma=False, create_beta=False, + gamma_init='ones', beta_init='zeros', dtype=ms.float32) + self.single_cond_layer_norm = bm.LayerNorm([single_channel], name='single_cond_layer_norm', + create_beta=False, gamma_init='ones', beta_init='zeros', + dtype=ms.float32) + self.single_cond_scale = bm.CustomDense(single_channel, num_channels, weight_init='zeros', + use_bias=True, bias_init='ones', ndim=ndim, dtype=dtype) + self.single_cond_bias = bm.CustomDense( + single_channel, num_channels, weight_init='zeros', ndim=ndim, dtype=dtype) + else: + self.layernorm = bm.LayerNorm([num_channels], dtype=ms.float32) + + def construct(self, x, single_cond=None): + if not self.with_single_cond: + x = self.layernorm(x) + else: + x = self.layernorm(x) + single_cond = self.single_cond_layer_norm(single_cond) + single_scale = self.single_cond_scale(single_cond) + single_bias = self.single_cond_bias(single_cond) + x = mint.add(mint.mul(mint.sigmoid(single_scale.astype(ms.float32)).astype(x.dtype), x), single_bias) + return x + + +class AdaptiveZeroInit(nn.Cell): + """ + An adaptive initialization layer that combines two conditional linear transformations. + + Args: + global_config: Configuration object containing initialization settings. + in_channels (int): Number of input channels. + out_channels (int): Number of output channels. + single_channels (int, optional): Number of single conditional channels. Default: ``None``. + ndim (int, optional): Number of dimensions for the dense layer input. Default: ``3``. + with_single_cond (bool, optional): Whether to use single conditional transformation. Default: ``True``. + + Inputs: + - **x** (Tensor) - Input tensor to the layer. + - **single_cond** (Tensor, optional) - Single conditional tensor. 
Required if `with_single_cond` is True. + + Outputs: + - **output** (Tensor) - Output tensor after applying the adaptive initialization. + """ + + def __init__(self, global_config, in_channels, out_channels, single_channels=None, ndim=3, with_single_cond=True, dtype=ms.float32): + super().__init__() + self.with_single_cond = with_single_cond + # linear, change to ones for test + self.cond_linear1 = bm.CustomDense( + in_channels, out_channels, weight_init='ones', ndim=ndim, dtype=dtype) + if self.with_single_cond: + if single_channels is None: + single_channels = in_channels + self.cond_linear2 = bm.CustomDense(single_channels, out_channels, weight_init='zeros', + use_bias=True, bias_init='ones', ndim=ndim, dtype=dtype) # zeros, change to ones for test + self.cond_linear2.bias = ms.Parameter(self.cond_linear2.bias * (-2)) + + def construct(self, x, single_cond=None): + if not self.with_single_cond: + output = self.cond_linear1(x) + else: + output = self.cond_linear1(x) + cond = self.cond_linear2(single_cond) + output = mint.mul(mint.sigmoid(cond.astype(ms.float32)).astype(cond.dtype), output) + return output + + +class TransitionBlock(nn.Cell): + """ + A neural network layer that combines adaptive layer normalization, a gated linear unit (GLU), and adaptive zero initialization to process input data with optional conditional inputs. + + Args: + global_config: Configuration object containing initialization settings. + in_channels (int): Number of input channels. + num_intermediate_factor (int): Factor to determine the number of intermediate channels. + single_channels (int, optional): Number of single conditional channels. Default: ``None``. + ndim (int, optional): Number of dimensions for input tensor. Default: ``3``. + with_single_cond (bool, optional): Whether to use single conditional processing. Default: ``True``. + use_glu_kernel (bool, optional): Whether to use GLU. Default: ``True``. + name (str, optional): Name of the layer. Default: ``''``. + + Inputs: + - **x** (Tensor) - Input tensor to the layer. + - **single_cond** (Tensor, optional) - Single conditional tensor. Required if `with_single_cond` is True. + + Outputs: + - **output** (Tensor) - Output tensor after processing through the TransitionBlock. 
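+
+    Example (illustrative sketch only; assumes a stub ``global_config`` and
+    hypothetical shapes; ``use_glu_kernel=False`` avoids the custom GLU
+    kernel path):
+
+    >>> import mindspore as ms
+    >>> from mindspore import ops
+    >>> x = ops.ones((16, 32, 128), ms.float32)
+    >>> cond = ops.ones((16, 32, 128), ms.float32)
+    >>> block = TransitionBlock(global_config, 128, 2, use_glu_kernel=False)
+    >>> out = block(x, cond)  # shape preserved: (16, 32, 128)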
+ """ + + def __init__(self, global_config, in_channels, num_intermediate_factor, single_channels=None, ndim=3, with_single_cond=True, use_glu_kernel=True, name='', dtype=ms.float32): + super().__init__() + self.num_intermediate = num_intermediate_factor * in_channels + if single_channels is None: + single_channels = in_channels + self.adaptive_layernorm = AdaptiveLayernorm( + in_channels, single_channels, ndim=ndim, with_single_cond=with_single_cond, dtype=dtype) + self.use_glu_kernel = use_glu_kernel + if self.use_glu_kernel: + self.weights = bm.custom_initializer( + 'relu', [in_channels, self.num_intermediate * 2], dtype=dtype) + self.weights = ms.Parameter(ms.Tensor(self.weights).reshape( + in_channels, 2, self.num_intermediate)) + else: + # relu, change to ones for test + self.linear = bm.CustomDense( + in_channels, self.num_intermediate * 2, weight_init='ones', ndim=3, dtype=dtype) + self.adaptive_zero_init = AdaptiveZeroInit( + global_config, self.num_intermediate, in_channels, single_channels, ndim=ndim, with_single_cond=with_single_cond, dtype=dtype) + + def construct(self, x, single_cond=None): + x = self.adaptive_layernorm(x, single_cond) + if self.use_glu_kernel: + c = gated_linear_unit.gated_linear_unit( + x=x, weight=self.weights.astype(x.dtype), + implementation=None, activation=mint.nn.functional.silu, precision=None + ).astype(x.dtype) + else: + x = self.linear(x) + x0, x1 = ms.ops.split(x, int(x.shape[-1]/2), axis=-1) + c = ms.ops.silu(x0) * x1 + output = self.adaptive_zero_init(c, single_cond) + return output + +@dataclass +class SelfAttentionConfig(base_config.BaseConfig): + num_head: int = 16 + key_dim: int | None = None + value_dim: int | None = None + + +class SelfAttention(nn.Cell): + """ + A self-attention mechanism implementation with adaptive layer normalization and adaptive zero initialization. + + This class implements the self-attention mechanism commonly used in transformer models. It includes adaptive layer normalization for input processing and adaptive zero initialization for the final output. The mechanism computes attention scores using query, key, and value transformations, applies masking, and optionally incorporates pair-wise logits. + + Args: + config: Configuration object containing parameters such as key dimension, value dimension, and number of attention heads. + global_config: Global configuration object for additional settings. + num_channels (int): Number of channels in the input tensor. + in_shape (tuple): Shape of the input tensor. + ndim (int, optional): Number of dimensions for the dense layers. Default: ``3``. + with_single_cond (bool, optional): Whether to include single condition adaptation. Default: ``True``. + + Inputs: + - **x** (Tensor) - Input tensor to the self-attention layer. + - **mask** (Tensor) - Attention mask to apply. + - **single_cond** (Tensor, optional) - Single condition tensor for adaptation. + - **pair_logits** (Tensor, optional) - Additional logits to incorporate into attention scores. + + Outputs: + - **output** (Tensor) - The output tensor after self-attention and adaptive zero initialization. + + Notes: + - The class uses adaptive layer normalization and adaptive zero initialization for processing inputs and outputs. + - The attention mechanism supports optional single condition adaptation and pair-wise logits. 
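+
+    Example (illustrative sketch only; ``config`` is assumed to be a
+    ``SelfAttentionConfig`` stub with its default ``num_head=16``):
+
+    >>> import mindspore as ms
+    >>> from mindspore import ops
+    >>> x = ops.ones((32, 128), ms.float32)    # (num_tokens, channels)
+    >>> mask = ops.ones((32,), ms.float32)     # 1.0 marks valid tokens
+    >>> cond = ops.ones((32, 64), ms.float32)  # channels // 2, as passed by Block
+    >>> attn = SelfAttention(config, global_config, 128, (32, 128), ndim=2)
+    >>> out = attn(x, mask, cond, None)        # same shape as x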
+ """ + + def __init__(self, config, global_config, num_channels, in_shape, ndim=3, with_single_cond=True, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.adaptive_layernorm = AdaptiveLayernorm(num_channels, int( + num_channels//2), ndim=ndim, with_single_cond=with_single_cond, dtype=dtype) + key_dim = self.config.key_dim if self.config.key_dim is not None else num_channels + value_dim = self.config.value_dim if self.config.value_dim is not None else num_channels + num_head = self.config.num_head + assert key_dim % num_head == 0, f'{key_dim=} % {num_head=} != 0' + assert value_dim % num_head == 0, f'{value_dim=} % {num_head=} != 0' + key_dim = key_dim // num_head + self.key_dim = key_dim + value_dim = value_dim // num_head + qk_shape = (num_head, key_dim) + v_shape = (num_head, value_dim) + self.q_linear = bm.CustomDense(num_channels, qk_shape, use_bias=True, dtype=dtype) + self.k_linear = bm.CustomDense(num_channels, qk_shape, use_bias=False, dtype=dtype) + self.v_linear = bm.CustomDense(num_channels, v_shape, use_bias=False, dtype=dtype) + # weight_init="zeros", change to ones for test + self.linear = bm.CustomDense( + num_channels, num_head * value_dim, weight_init='zeros', dtype=dtype) + self.adaptive_zero_init = AdaptiveZeroInit(global_config, num_channels, num_channels, int( + num_channels//2), 2, with_single_cond=with_single_cond, dtype=dtype) + self.ncon1 = Ncon([[-2, -1, 1], [-3, -1, 1]]) + self.ncon2 = Ncon([[-2, -1, 2], [2, -2, -3]]) + + def construct(self, x, mask, single_cond, pair_logits): + bias = (1e9 * (mask - 1.0))[..., None, None, :].astype(x.dtype) + x = self.adaptive_layernorm(x, single_cond) + q = self.q_linear(x) + k = self.k_linear(x) + logits = mint.einsum('...qhc,...khc->...hqk', q * self.key_dim ** (-0.5), k) + bias + if pair_logits is not None: + logits += pair_logits # (num_heads, seq_len, seq_len) + weights = mint.softmax(logits, dim=-1) + weights = weights.astype(q.dtype) + v = self.v_linear(x) + weighted_avg = mint.einsum('...hqk,...khc->...qhc', weights, v) + weighted_avg = weighted_avg.reshape(weighted_avg.shape[:-2] + (-1,)) + gate_logits = self.linear(x) + weighted_avg *= mint.sigmoid(gate_logits.astype(ms.float32)).astype(gate_logits.dtype) + output = self.adaptive_zero_init(weighted_avg, single_cond) + return output + + +class Transformer(nn.Cell): + @dataclass + class Config(base_config.BaseConfig): + attention: SelfAttentionConfig = base_config.autocreate() + num_blocks: int = 24 + block_remat: bool = False + super_block_size: int = 4 + num_intermediate_factor: int = 2 + + def __init__(self, config, global_config, in_shape, pair_shape, using_pair_act=False, name="transformer", dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.using_pair_act = using_pair_act + self.act = [] + if using_pair_act: + self.pair_layernorm = bm.LayerNorm(pair_shape, create_beta=False, dtype=ms.float32) + else: + self.pair_layernorm = None + assert self.config.num_blocks % self.config.super_block_size == 0 + self.num_super_blocks = self.config.num_blocks // self.config.super_block_size + self.super_blocks = ms.nn.CellList( + [ + SuperBlock( + config, global_config, self.config.num_blocks, + using_pair_act, in_shape, pair_shape, name, dtype=dtype + ) + for _ in range(self.num_super_blocks) + ] + ) + + @ms.jit + def construct(self, act, single_cond, mask, pair_cond=None): + if pair_cond is None: + pair_act = None + else: + pair_act = self.pair_layernorm(pair_cond) 
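+        # `pair_act` is computed once and shared by all super blocks; each
+        # SuperBlock projects it into per-block, per-head attention biases.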
+        for i in range(self.num_super_blocks):
+            act = self.super_blocks[i](act, mask, single_cond, pair_act)
+        return act
+
+
+class Block(nn.Cell):
+    def __init__(self, config, global_config, in_shape, dtype=ms.float32):
+        super().__init__()
+        self.self_attention = SelfAttention(
+            config.attention, global_config, in_shape[-1], in_shape, ndim=2, dtype=dtype)
+        self.transition_block = TransitionBlock(global_config, in_shape[-1],
+                                                config.num_intermediate_factor, int(in_shape[-1]//2), ndim=2, dtype=dtype)
+
+    def construct(self, act, mask, single_cond, pair_logits):
+        act += self.self_attention(act, mask, single_cond, pair_logits)
+        act += self.transition_block(act, single_cond)
+        return act
+
+
+class SuperBlock(nn.Cell):
+    def __init__(self, config, global_config, num_blocks, using_pair_act, in_shape, pair_shape=None, name='', dtype=ms.float32):
+        super().__init__()
+        self.config = config
+        self.global_config = global_config
+        self.num_blocks = num_blocks
+        self.using_pair_act = using_pair_act
+        self.blocks = ms.nn.CellList(
+            [
+                Block(
+                    config, global_config, in_shape, dtype=dtype
+                )
+                for _ in range(self.config.super_block_size)
+            ]
+        )
+        if self.using_pair_act:
+            self.pair_linear = bm.CustomDense(
+                pair_shape[-1], (self.config.super_block_size, self.config.attention.num_head), ndim=3, dtype=dtype)
+        else:
+            self.pair_linear = None
+
+    def construct(self, act, mask, single_cond, pair_act):
+        if pair_act is None:
+            pair_logits = None
+        else:
+            pair_logits = self.pair_linear(pair_act).transpose([2, 3, 0, 1])
+        for j in range(self.config.super_block_size):
+            # Guard against a missing pair condition: indexing `pair_logits`
+            # when it is None would raise.
+            act = self.blocks[j](act, mask, single_cond,
+                                 pair_logits[j] if pair_logits is not None else None)
+        return act
+
+@dataclass
+class CrossAttentionConfig(base_config.BaseConfig):
+    num_head: int = 4
+    key_dim: int = 128
+    value_dim: int = 128
+
+
+class CrossAttention(nn.Cell):
+    """
+    A CrossAttention class implementing a multi-head cross-attention mechanism for processing sequential data.
+
+    Args:
+        config (Config): Configuration object containing attention settings.
+        global_config (GlobalConfig): Global configuration object.
+        in_channel (int): Input dimension for the attention mechanism.
+
+    Inputs:
+        - **x_q** (Tensor) - Query tensor.
+        - **x_k** (Tensor) - Key tensor.
+        - **mask_q** (Tensor) - Query mask tensor.
+        - **mask_k** (Tensor) - Key mask tensor.
+        - **pair_logits** (Tensor, optional) - Optional pair logits tensor. Default: ``None``.
+        - **single_cond_q** (Tensor) - Single condition tensor for queries.
+        - **single_cond_k** (Tensor) - Single condition tensor for keys.
+
+    Outputs:
+        - **output** (Tensor) - Output tensor after cross-attention processing.
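+
+    Example (illustrative sketch only; ``config`` is assumed to be a
+    ``CrossAttentionConfig`` stub with its defaults, i.e. 4 heads of width 32):
+
+    >>> import mindspore as ms
+    >>> from mindspore import ops
+    >>> x_q = ops.ones((8, 32, 128), ms.float32)
+    >>> x_k = ops.ones((8, 96, 128), ms.float32)
+    >>> mask_q = ops.ones((8, 32), ms.float32)   # one dim fewer than x_q
+    >>> mask_k = ops.ones((8, 96), ms.float32)
+    >>> cond_q = ops.ones((8, 32, 128), ms.float32)
+    >>> cond_k = ops.ones((8, 96, 128), ms.float32)
+    >>> xattn = CrossAttention(config, global_config, 128)
+    >>> out = xattn(x_q, x_k, mask_q, mask_k, None, cond_q, cond_k)  # (8, 32, 128)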
+ """ + + def __init__(self, config, global_config, in_channel, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.adaptive_layernorm_q = AdaptiveLayernorm(in_channel, in_channel, dtype=dtype) + self.adaptive_layernorm_k = AdaptiveLayernorm(in_channel, in_channel, dtype=dtype) + assert config.key_dim % config.num_head == 0 + assert config.value_dim % config.num_head == 0 + self.key_dim = config.key_dim // config.num_head + self.value_dim = config.value_dim // config.num_head + self.linear_q = bm.CustomDense( + in_channel, (self.config.num_head, self.key_dim), use_bias=True, ndim=3, dtype=dtype) + self.linear_k = bm.CustomDense( + in_channel, (self.config.num_head, self.key_dim), use_bias=False, ndim=3, dtype=dtype) + self.linear_v = bm.CustomDense( + in_channel, (self.config.num_head, self.value_dim), use_bias=False, ndim=3, dtype=dtype) + self.ncon1 = Ncon([[-1, -3, -2, 1], [-1, -4, -2, 1]]) + self.ncon2 = Ncon([[-1, -3, -2, 1], [-1, 1, -3, -4]]) + self.gating_query = bm.CustomDense( + in_channel, self.config.num_head * self.value_dim, use_bias=False, + weight_init='zeros', bias_init='ones', ndim=3, dtype=dtype) + self.adaptive_zero_init = AdaptiveZeroInit( + global_config, in_channel, in_channel, in_channel, dtype=dtype) + + def construct(self, x_q, x_k, mask_q, mask_k, pair_logits, single_cond_q, single_cond_k): + """Multihead self-attention.""" + assert len(mask_q.shape) == len(x_q.shape) - \ + 1, f'{mask_q.shape}, {x_q.shape}' + assert len(mask_k.shape) == len(x_k.shape) - \ + 1, f'{mask_k.shape}, {x_k.shape}' + # bias: ... x heads (1) x query x key + bias = ( + 1e9 + * (mask_q - 1.0)[..., None, :, None] + * (mask_k - 1.0)[..., None, None, :] + ) + x_q = self.adaptive_layernorm_q(x_q, single_cond_q) + x_k = self.adaptive_layernorm_k(x_k, single_cond_k) + q = self.linear_q(x_q) + k = self.linear_k(x_k) + logits = mint.einsum('...qhc,...khc->...hqk', q * self.key_dim ** (-0.5), k) + bias + if pair_logits is not None: + logits += pair_logits + weights = ms.ops.softmax(logits, axis=-1) + v = self.linear_v(x_k) + weighted_avg = mint.einsum('...hqk,...khc->...qhc', weights, v) + weighted_avg = ms.ops.reshape( + weighted_avg, weighted_avg.shape[:-2] + (-1,)) + + gate_logits = self.gating_query(x_q) + weighted_avg *= ms.ops.sigmoid(gate_logits.astype(ms.float32)).astype(gate_logits.dtype) + + output = self.adaptive_zero_init(weighted_avg, single_cond_q,) + return output + + +class CrossAttTransformer(nn.Cell): + """ + A CrossAttTransformer class implementing a transformer that applies cross attention between two sets of subsets. + + Args: + config (Config): Configuration object containing settings for the transformer. + global_config (GlobalConfig): Global configuration object. + in_shape (tuple): Input shape for the transformer. + + Inputs: + - **queries_act** (Tensor) - Query activations tensor. + - **queries_mask** (Tensor) - Mask tensor for queries. + - **queries_to_keys** (Tensor) - Tensor mapping queries to keys. + - **keys_mask** (Tensor) - Mask tensor for keys. + - **queries_single_cond** (Tensor) - Single condition tensor for queries. + - **keys_single_cond** (Tensor) - Single condition tensor for keys. + - **pair_cond** (Tensor) - Pair condition tensor. + + Outputs: + - **queries_act** (Tensor) - Processed query activations tensor after cross attention. 
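+
+    Notes:
+        - The pair condition is normalised once, projected to
+          (num_blocks, num_head) logits per pair, and transposed so that
+          ``pair_logits[i]`` supplies the attention bias for block ``i``.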
+ """ + @dataclass + class Config(base_config.BaseConfig): + num_intermediate_factor: int + num_blocks: int + attention: CrossAttentionConfig = base_config.autocreate() + + def __init__(self, config, global_config, in_shape, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.pair_input_layer_norm = bm.LayerNorm(in_shape, create_beta=False, dtype=ms.float32) + self.pair_logits_projection = bm.CustomDense( + in_shape[-1], (self.config.num_blocks, self.config.attention.num_head), ndim=4, dtype=dtype) + self.block = ms.nn.CellList( + [ + CrossAttTransformerBlock( + config, global_config, in_shape[-2], dtype=dtype + ) + for _ in range(self.config.num_blocks) + ] + ) + + def construct(self, queries_act, queries_mask, queries_to_keys, + keys_mask, queries_single_cond, keys_single_cond, + pair_cond): + pair_act = self.pair_input_layer_norm(pair_cond) + pair_logits = self.pair_logits_projection(pair_act) + pair_logits = ms.ops.transpose(pair_logits, (3, 0, 4, 1, 2)) + for i in range(self.config.num_blocks): + queries_act = self.block[i](queries_act, queries_mask, queries_to_keys, keys_mask, pair_logits[i], + queries_single_cond, keys_single_cond) + return queries_act + + +class CrossAttTransformerBlock(nn.Cell): + def __init__(self, config, global_config, in_channel, dtype=ms.float32): + super().__init__() + self.cross_attention = CrossAttention( + config.attention, global_config, in_channel, dtype=dtype) + self.transition = TransitionBlock( + global_config, in_channel, config.num_intermediate_factor, dtype=dtype) + + def construct(self, queries_act, queries_mask, queries_to_keys, keys_mask, pair_logits, + queries_single_cond, keys_single_cond): + keys_act = atom_layout.convert_ms( + queries_to_keys, queries_act, layout_axes=(-3, -2) + ) + queries_act += self.cross_attention(queries_act, keys_act, queries_mask, keys_mask, + pair_logits, queries_single_cond, keys_single_cond) + queries_act += self.transition(queries_act, queries_single_cond) + return queries_act diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/distogram_head.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/distogram_head.py new file mode 100644 index 000000000..f8e439600 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/distogram_head.py @@ -0,0 +1,85 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Distogram head.""" + +from typing import Final +from dataclasses import dataclass +import mindspore as ms +from mindspore import nn, ops, Tensor +from alphafold3.model import base_config +from alphafold3.model.components import base_modules as bm +from mindchemistry.e3.utils import Ncon + + +_CONTACT_THRESHOLD: Final[float] = 8.0 +_CONTACT_EPSILON: Final[float] = 1e-3 + + +class DistogramHead(nn.Cell): + """ + A DistogramHead class that computes a distogram from pair embeddings, predicting distances between residues. 
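+
+    Distances are discretised into ``num_bins`` bins with edges running from
+    ``first_break`` to ``last_break`` angstroms; contact probabilities sum the
+    probability mass of all bins whose upper edge lies at or below roughly
+    8 A (``_CONTACT_THRESHOLD``).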
+ + Args: + config (Config): Configuration object containing parameters for the distogram head. + global_config (GlobalConfig): Global configuration object. + in_channel (int): Number of input channels for the linear layer. + + Inputs: + - **batch** (dict) - Dictionary containing batch features. + - **embeddings** (dict) - Dictionary containing pair embeddings. + + Outputs: + - **bin_edges** (Tensor) - Tensor of bin edges for distance predictions. + - **contact_probs** (Tensor) - Tensor of contact probabilities. + + Notes: + - The distogram head computes distance probabilities using a linear transformation and softmax. + - The Ncon class is used for tensor contraction operations. + """ + @dataclass + class Config(base_config.BaseConfig): + first_break: float = 2.3125 + last_break: float = 21.6875 + num_bins: int = 64 + + def __init__( + self, config, global_config, in_channel, dtype=ms.float32 + ): + super().__init__() + self.config = config + self.global_config = global_config + self.linear = bm.CustomDense( + in_channel, self.config.num_bins, weight_init=self.global_config.final_init, ndim=3, dtype=dtype) + self.ncon = Ncon([[-1, -2, 1], [1]]) + + def construct(self, batch, embeddings): + pair_act = embeddings["pair"] + seq_mask = batch.token_features.mask.astype(ms.bool_) + pair_mask = seq_mask[:, None] * seq_mask[None, :] + left_half_logits = self.linear(pair_act) + right_half_logits = left_half_logits + logits = left_half_logits + ms.ops.swapaxes(right_half_logits, -2, -3) + probs = ms.ops.softmax(logits, axis=-1) + breaks = ms.ops.linspace( + self.config.first_break, + self.config.last_break, + self.config.num_bins - 1, + ) + bin_tops = ms.ops.concat( + (breaks, (breaks[-1] + breaks[-1] - breaks[-2]).reshape(-1))) + threshold = _CONTACT_THRESHOLD + _CONTACT_EPSILON + is_contact_bin = 1.0 * (bin_tops <= threshold) + contact_probs = self.ncon([probs.astype(ms.float32), is_contact_bin.astype(ms.float32)]) + contact_probs = pair_mask * contact_probs + return { + 'bin_edges': breaks, + 'contact_probs': contact_probs, + } diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py new file mode 100644 index 000000000..e69300d2e --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/featurization.py @@ -0,0 +1,214 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Model-side of the input features processing.""" +import math +import functools +import numpy as np +import mindspore as ms +from mindspore import ops +from alphafold3.constants import residue_names +from alphafold3.model import feat_batch +from alphafold3.model import features +from alphafold3.model.components import utils + + +def _grid_keys(key, shape): + """Generate a grid of rng keys that is consistent with different padding. 
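+
+    For example, the keys generated for shape (2, 3) coincide with the
+    top-left (2, 3) block of the keys generated for shape (4, 5).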
+ + Generate random keys such that the keys will be identical, regardless of + how much padding is added to any dimension. + + Args: + key: A PRNG key. + shape: The shape of the output array of keys that will be generated. + + Returns: + An array of shape `shape` consisting of random keys. + """ + if not shape: + return key + + def partial_bitwise_xor(other): + return ms.ops.bitwise_xor(key, other) + + def _partial_grid_keys(key): + return _grid_keys(key, shape[1:]) + new_keys = ms.vmap(partial_bitwise_xor)( + ms.ops.arange(shape[0]) + ) + return ms.vmap(_partial_grid_keys)(new_keys) + + +def _padding_consistent_rng(f): + def inner(key, shape, **kwargs): + keys = _grid_keys(key, shape) + out = keys.flatten() + count = 0 + for key in keys.flatten(): + out[count] = (f((), key)) + count += 1 + return out.reshape(keys.shape) + return inner + + +def gumbel_sample(shape, seed): + uniform_samples = ms.Tensor(np.random.uniform(0.0, 1.0, shape)) + gumbel_samples = -ops.log(-ops.log(uniform_samples)) + return gumbel_samples + + +def gumbel_argsort_sample_idx(key, logits): + gumbel = _padding_consistent_rng(gumbel_sample) + z = gumbel(key, logits.shape) + perm = ms.ops.argsort(logits + z, axis=-1, descending=False) + return perm[::-1] + + +def create_msa_feat(msa): + msa_1hot = ms.ops.one_hot(msa.rows.astype( + ms.int64), residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 1) + deletion_matrix = msa.deletion_matrix + has_deletion = ms.ops.clip(deletion_matrix, 0.0, 1.0)[..., None] + deletion_value = (ms.ops.arctan(deletion_matrix / 3.0) + * (2.0 / math.pi))[..., None] + msa_feat = [msa_1hot.astype(deletion_value.dtype), has_deletion, deletion_value] + return ms.ops.concat(msa_feat, axis=-1) + + +def truncate_msa_batch(msa, num_msa): + indices = ms.ops.arange(num_msa) + return msa.index_msa_rows(indices) + + +def create_target_feat(batch, append_per_atom_features, dtype=ms.float32): + token_features = batch.token_features + target_features = [] + target_features.append(ms.ops.one_hot( + token_features.aatype.astype(ms.int64), + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP).astype(dtype)) + target_features.append(batch.msa.profile) + target_features.append(batch.msa.deletion_mean[..., None]) + + if append_per_atom_features: + ref_mask = batch.ref_structure.mask + element_feat = ms.ops.one_hot(batch.ref_structure.element, 128) + element_feat = utils.mask_mean( + mask=ref_mask[..., None], value=element_feat, axis=-2, eps=1e-6) + target_features.append(element_feat) + pos_feat = batch.ref_structure.positions + pos_feat = pos_feat.reshape([pos_feat.shape[0], -1]) + target_features.append(pos_feat) + target_features.append(ref_mask) + return ms.ops.concat(target_features, axis=-1) + + +def create_relative_encoding( + seq_features, + max_relative_idx, + max_relative_chain, +): + """Add relative position encodings.""" + rel_feats = [] + token_index = seq_features.token_index + residue_index = seq_features.residue_index + asym_id = seq_features.asym_id + entity_id = seq_features.entity_id + sym_id = seq_features.sym_id + + left_asym_id = asym_id[:, None] + right_asym_id = asym_id[None, :] + + left_residue_index = residue_index[:, None] + right_residue_index = residue_index[None, :] + + left_token_index = token_index[:, None] + right_token_index = token_index[None, :] + + left_entity_id = entity_id[:, None] + right_entity_id = entity_id[None, :] + left_sym_id = sym_id[:, None] + right_sym_id = sym_id[None, :] + + # Embed relative positions using a one-hot embedding of distance along chain + offset = 
left_residue_index - right_residue_index + clipped_offset = ms.ops.clip( + offset + max_relative_idx, min=0, max=2 * max_relative_idx + ) + asym_id_same = left_asym_id == right_asym_id + final_offset = ms.ops.where( + asym_id_same, + clipped_offset, + (2 * max_relative_idx + 1) * ms.ops.ones_like(clipped_offset), + ) + rel_pos = ms.ops.one_hot(final_offset.astype( + ms.int64), 2 * max_relative_idx + 2) + rel_feats.append(rel_pos) + + # Embed relative token index as a one-hot embedding of distance along residue + token_offset = left_token_index - right_token_index + clipped_token_offset = ms.ops.clip( + token_offset + max_relative_idx, min=0, max=2 * max_relative_idx + ) + residue_same = ms.ops.logical_and((left_asym_id == right_asym_id), ( + left_residue_index == right_residue_index + )) + final_token_offset = ms.ops.where( + residue_same, + clipped_token_offset, + (2 * max_relative_idx + 1) * ms.ops.ones_like(clipped_token_offset), + ) + rel_token = ms.ops.one_hot(final_token_offset.astype( + ms.int64), 2 * max_relative_idx + 2) + rel_feats.append(rel_token) + + # Embed same entity ID + entity_id_same = left_entity_id == right_entity_id + rel_feats.append(entity_id_same.astype(rel_pos.dtype)[..., None]) + + # Embed relative chain ID inside each symmetry class + rel_sym_id = left_sym_id - right_sym_id + + max_rel_chain = max_relative_chain + + clipped_rel_chain = ms.ops.clip( + rel_sym_id + max_rel_chain, min=0, max=2 * max_rel_chain + ) + + final_rel_chain = ms.ops.where( + entity_id_same, + clipped_rel_chain, + (2 * max_rel_chain + 1) * ms.ops.ones_like(clipped_rel_chain), + ) + rel_chain = ms.ops.one_hot(final_rel_chain.astype( + ms.int64), 2 * max_relative_chain + 2) + + rel_feats.append(rel_chain) + + return ms.ops.concat(rel_feats, axis=-1) + + +def shuffle_msa(key, msa): + """Shuffle MSA randomly, return batch with shuffled MSA. + + Args: + key: rng key for random number generation. + msa: MSA object to sample msa from. + + Returns: + Protein with sampled msa. + """ + key, sample_key = key, key + 1 + # Sample uniformly among sequences with at least one non-masked position. + logits = (ms.ops.clip(ms.ops.sum(msa.mask, dim=-1), 0.0, 1.0) - 1.0) * 1e6 + index_order = gumbel_argsort_sample_idx(sample_key, logits) + return msa.index_msa_rows(index_order), sample_key diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py new file mode 100644 index 000000000..42d775578 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/load_ckpt.py @@ -0,0 +1,577 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +import os +import pathlib +import numpy as np +import mindspore as ms +from mindspore import load_checkpoint +from alphafold3.model.params import get_model_af3_params + + +def np_slice(arr, i, j, dtype=ms.bfloat16): + if i is not None and j is not None: + return ms.Parameter(ms.Tensor(arr[i, j], dtype)) + elif i is not None and j is None: + return ms.Parameter(ms.Tensor(arr[i], dtype)) + elif i is None and j is not None: + return ms.Parameter(ms.Tensor(arr[j], dtype)) + else: + return ms.Parameter(ms.Tensor(arr, dtype)) + + +def load_adaptive_layernorm(adaptive_layernorm, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + if not ckpt.get(path + "single_cond_layer_norm"): + adaptive_layernorm.layernorm.layernorm.gamma.set_data( + np_slice(ckpt[path + 'layer_norm']['scale'], i, j, dtype=ms.float32)) + adaptive_layernorm.layernorm.layernorm.beta.set_data( + np_slice(ckpt[path + 'layer_norm']['offset'], i, j, dtype=ms.float32)) + else: + adaptive_layernorm.single_cond_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + 'single_cond_layer_norm']['scale'], i, j, dtype=ms.float32)) + adaptive_layernorm.single_cond_scale.weight.set_data( + np_slice(ckpt[path + 'single_cond_scale']['weights'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_scale.bias.set_data( + np_slice(ckpt[path + 'single_cond_scale']['bias'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_bias.weight.set_data( + np_slice(ckpt[path + 'single_cond_bias']['weights'], i, j, dtype=dtype)) + + +def load_adaptive_zero_init(adaptive_zero_init, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + adaptive_zero_init.cond_linear1.weight.set_data( + np_slice(ckpt[path + 'transition2']['weights'], i, j, dtype=dtype)) + if ckpt.get(path + "adaptive_zero_cond"): + adaptive_zero_init.cond_linear2.weight.set_data( + np_slice(ckpt[path + 'adaptive_zero_cond']['weights'], i, j, dtype=dtype)) + adaptive_zero_init.cond_linear2.bias.set_data( + np_slice(ckpt[path + 'adaptive_zero_cond']['bias'], i, j, dtype=dtype)) + + +def load_transition(transition_block, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm( + transition_block.adaptive_layernorm, path + 'ffw_', ckpt, i, j, dtype=dtype) + transition_block.weights.set_data( + np_slice(ckpt[path + 'ffw_transition1']['weights'], i, j, dtype=dtype).reshape( + (transition_block.weights.shape[0], 2, transition_block.num_intermediate))) + load_adaptive_zero_init( + transition_block.adaptive_zero_init, path + 'ffw_', ckpt, i, j, dtype=dtype) + + +def load_self_attention(self_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm( + self_attention.adaptive_layernorm, path, ckpt, i, j) + self_attention.q_linear.weight.set_data( + np_slice(ckpt[path + 'q_projection']['weights'], i, j, dtype=dtype)) + self_attention.q_linear.bias.set_data( + np_slice(ckpt[path + 'q_projection']['bias'], i, j, dtype=dtype)) + self_attention.k_linear.weight.set_data( + np_slice(ckpt[path + 'k_projection']['weights'], i, j, dtype=dtype)) + self_attention.v_linear.weight.set_data( + np_slice(ckpt[path + 'v_projection']['weights'], i, j, dtype=dtype)) + self_attention.linear.weight.set_data( + np_slice(ckpt[path + 'gating_query']['weights'], i, j, dtype=dtype)) + load_adaptive_zero_init( + self_attention.adaptive_zero_init, path, ckpt, i, j, dtype=dtype) + + +def load_transformer(transformer, path, ckpt, dtype=ms.bfloat16): + for i in 
range(6): + for j in range(4): + transformer_path = (path + + '/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformer') + load_self_attention(transformer.super_blocks[i].blocks[j].self_attention, + transformer_path, ckpt, i, j, dtype=dtype) + load_transition(transformer.super_blocks[i].blocks[j].transition_block, + transformer_path, ckpt, i, j, dtype=dtype) + if transformer.using_pair_act == True: + pair_projection_path = path + '/__layer_stack_with_per_layer/pair_logits_projection' + transformer.super_blocks[i].pair_linear.weight.set_data( + np_slice(ckpt[pair_projection_path]['weights'], i, None, dtype=dtype)) + if transformer.using_pair_act == True: + pair_norm_path = path + '/pair_input_layer_norm' + transformer.pair_layernorm.layernorm.gamma.set_data( + np_slice(ckpt[pair_norm_path]['scale'].T, dtype=ms.float32)) + + +def load_transition_block(transition_block, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + transition_block.glu_weight.set_data( + np_slice(ckpt[path + '/transition1']['weights'], i, j, dtype=dtype).reshape( + (-1, 2, transition_block.num_intermediate))) + transition_block.out_linear.weight.set_data( + np_slice(ckpt[path + '/transition2']['weights'], i, j, dtype=dtype)) + transition_block.layernorm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/input_layer_norm']['scale'], i, j, dtype=ms.float32)) + transition_block.layernorm.layernorm.beta.set_data( + np_slice(ckpt[path + '/input_layer_norm']['offset'], i, j, dtype=ms.float32)) + + +def load_grid_self_attention(grid_self_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + grid_self_attention.q_projection.weight.set_data( + np_slice(ckpt[path + '/q_projection']['weights'], i, j, dtype=dtype).transpose(2, 0, 1)) + grid_self_attention.k_projection.weight.set_data( + np_slice(ckpt[path + '/k_projection']['weights'], i, j, dtype=dtype).transpose(2, 0, 1)) + grid_self_attention.v_projection.weight.set_data( + np_slice(ckpt[path + '/v_projection']['weights'], i, j, dtype=dtype)) + grid_self_attention.gating_query.weight.set_data( + np_slice(ckpt[path + '/gating_query']['weights'], i, j, dtype=dtype).T) + grid_self_attention.output_projection.weight.set_data( + np_slice(ckpt[path + '/output_projection']['weights'], i, j, dtype=dtype)) + grid_self_attention.pair_bias_projection.weight.set_data( + np_slice(ckpt[path + '/pair_bias_projection']['weights'], i, j, dtype=dtype)) + grid_self_attention.act_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/act_norm']['scale'], i, j, dtype=ms.float32)) + grid_self_attention.act_norm.layernorm.beta.set_data( + np_slice(ckpt[path + '/act_norm']['offset'], i, j, dtype=ms.float32)) + + +def load_outer_product_mean(outer_product_mean, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + outer_product_mean.outer_product_mean.o_biases.set_data( + np_slice(ckpt[path]['output_b'], i, j, dtype=dtype)) + outer_product_mean.outer_product_mean.linear_output_weight.set_data( + np_slice(ckpt[path]['output_w'], i, j, dtype=dtype)) + outer_product_mean.outer_product_mean.left_projection_weight.set_data( + np_slice(ckpt[path + '/left_projection']['weights'], i, j, dtype=dtype).T) + outer_product_mean.outer_product_mean.right_projection_weight.set_data( + np_slice(ckpt[path + '/right_projection']['weights'], i, j, dtype=dtype).T) + outer_product_mean.outer_product_mean.layer_norm_input_gamma.set_data( + np_slice(ckpt[path + '/layer_norm_input']['scale'], i, j, dtype=ms.float32)) + outer_product_mean.outer_product_mean.layer_norm_input_beta.set_data( + np_slice(ckpt[path + 
'/layer_norm_input']['offset'], i, j, dtype=ms.float32)) + + +def load_msa_attention(msa_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + msa_attention.actnorm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/act_norm']['scale'], i, j, dtype=ms.float32)) + msa_attention.actnorm.layernorm.beta.set_data( + np_slice(ckpt[path + '/act_norm']['offset'], i, j, dtype=ms.float32)) + msa_attention.pairnorm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/pair_norm']['scale'], i, j, dtype=ms.float32)) + msa_attention.pairnorm.layernorm.beta.set_data( + np_slice(ckpt[path + '/pair_norm']['offset'], i, j, dtype=ms.float32)) + msa_attention.pair_logits.weight.set_data( + np_slice(ckpt[path + '/pair_logits']['weights'], i, j, dtype=dtype)) + msa_attention.v_projection.weight.set_data( + np_slice(ckpt[path + '/v_projection']['weights'], i, j, dtype=dtype)) + msa_attention.gating_query.weight.set_data( + np_slice(ckpt[path + '/gating_query']['weights'], i, j, dtype=dtype)) + msa_attention.output_projection.weight.set_data( + np_slice(ckpt[path + '/output_projection']['weights'], i, j, dtype=dtype)) + + +def load_triangle_multiplication(triangle_multiplication, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + triangle_multiplication.triangle_multi.gate.weight.set_data( + np_slice(ckpt[path + '/gate']['weights'], i, j, dtype=dtype).T) + triangle_multiplication.triangle_multi.projection.weight.set_data( + np_slice(ckpt[path + '/projection']['weights'], i, j, dtype=dtype).T) + triangle_multiplication.triangle_multi.weight_glu = ms.ops.stack( + [triangle_multiplication.triangle_multi.gate.weight, + triangle_multiplication.triangle_multi.projection.weight], axis=1) + triangle_multiplication.triangle_multi.output_projection.weight.set_data( + np_slice(ckpt[path + '/output_projection']['weights'], i, j, dtype=dtype)) + triangle_multiplication.triangle_multi.gating_linear.weight.set_data( + np_slice(ckpt[path + '/gating_linear']['weights'], i, j, dtype=dtype)) + triangle_multiplication.triangle_multi.left_norm_input.layernorm.gamma.set_data( + np_slice(ckpt[path + '/left_norm_input']['scale'], i, j, dtype=ms.float32)) + triangle_multiplication.triangle_multi.left_norm_input.layernorm.beta.set_data( + np_slice(ckpt[path + '/left_norm_input']['offset'], i, j, dtype=ms.float32)) + triangle_multiplication.triangle_multi.center_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/center_norm']['scale'], i, j, dtype=ms.float32)) + triangle_multiplication.triangle_multi.center_norm.layernorm.beta.set_data( + np_slice(ckpt[path + '/center_norm']['offset'], i, j, dtype=ms.float32)) + + +def load_pair_former(pair_former, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_grid_self_attention(pair_former.grid_self_attention1, path + '/pair_attention1', + ckpt, i, j, dtype=dtype) + load_grid_self_attention(pair_former.grid_self_attention2, path + '/pair_attention2', + ckpt, i, j, dtype=dtype) + load_triangle_multiplication(pair_former.triangle_multiplication1, + path + '/triangle_multiplication_outgoing', ckpt, i, j, dtype=dtype) + load_triangle_multiplication(pair_former.triangle_multiplication2, + path + '/triangle_multiplication_incoming', ckpt, i, j, dtype=dtype) + load_transition_block(pair_former.transition_block, path + '/pair_transition', + ckpt, i, j, dtype=dtype) + if pair_former.with_single: + pair_former.single_pair_logits_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/single_pair_logits_norm']['scale'], i, j, dtype=ms.float32)) + 
pair_former.single_pair_logits_norm.layernorm.beta.set_data( + np_slice(ckpt[path + '/single_pair_logits_norm']['offset'], i, j, dtype=ms.float32)) + pair_former.single_pair_logits_projection.weight.set_data( + np_slice(ckpt[path + '/single_pair_logits_projection']['weights'], i, j, dtype=dtype)) + load_self_attention(pair_former.single_attention, path + '/single_attention_', + ckpt, i, j, dtype=dtype) + load_transition_block(pair_former.single_transition, path + '/single_transition', + ckpt, i, j, dtype=dtype) + + +def load_evo_former(evo_former, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_outer_product_mean(evo_former.outer_product_mean, path + '/outer_product_mean', + ckpt, i, j, dtype=dtype) + load_msa_attention(evo_former.msa_attention, path + '/msa_attention1', + ckpt, i, j, dtype=dtype) + load_transition_block(evo_former.msa_transition, path + '/msa_transition', + ckpt, i, j, dtype=dtype) + load_triangle_multiplication(evo_former.triangle_multiplication1, + path + '/triangle_multiplication_outgoing', ckpt, i, j, dtype=dtype) + load_triangle_multiplication(evo_former.triangle_multiplication2, + path + '/triangle_multiplication_incoming', ckpt, i, j, dtype=dtype) + load_grid_self_attention(evo_former.pair_attention1, path + '/pair_attention1', + ckpt, i, j, dtype=dtype) + load_grid_self_attention(evo_former.pair_attention2, path + '/pair_attention2', + ckpt, i, j, dtype=dtype) + load_transition_block(evo_former.transition_block, path + '/pair_transition', + ckpt, i, j, dtype=dtype) + + +def load_single_template_embedding(single_template_embedding, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + num_layer = single_template_embedding.config.template_stack.num_layer + for ii in range(num_layer): + template_path = path + '/__layer_stack_no_per_layer/template_embedding_iteration' + load_pair_former(single_template_embedding.template_stack[ii], template_path, + ckpt, ii, dtype=dtype) + for jj in range(9): + template_pair_path = f'{path}/template_pair_embedding_{jj}' + single_template_embedding.template_pair_embedding[jj].weight.set_data( + np_slice(ckpt[template_pair_path]['weights'], None, None, dtype=dtype)) + single_template_embedding.output_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/output_layer_norm']['scale'], i, j, dtype=ms.float32)) + single_template_embedding.output_layer_norm.layernorm.beta.set_data( + np_slice(ckpt[path + '/output_layer_norm']['offset'], i, j, dtype=ms.float32)) + single_template_embedding.query_embedding_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/query_embedding_norm']['scale'], i, j, dtype=ms.float32)) + single_template_embedding.query_embedding_norm.layernorm.beta.set_data( + np_slice(ckpt[path + '/query_embedding_norm']['offset'], i, j, dtype=ms.float32)) + + +def load_template_embedding(template_embedding, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + template_embedding.output_linear.weight.set_data( + np_slice(ckpt[path + '/output_linear']['weights'], i, j, dtype=dtype)) + load_single_template_embedding(template_embedding.template_embedder, + path + '/single_template_embedding', ckpt, i, j, dtype=dtype) + + +def load_distogram_head(distogram_head, path, ckpt, i=None, j=None, dtype=ms.float32): + distogram_head.linear.weight.set_data( + np_slice(ckpt[path + '/half_logits']['weights'], i, j, dtype=dtype)) + + +def load_evoformer(evoformer, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + relative_encoding_path = path + '/~_relative_encoding/position_activations' + 
evoformer.position_activations.weight.set_data( + np_slice(ckpt[relative_encoding_path]['weights'], i, j, dtype=dtype)) + evoformer.left_single.weight.set_data( + np_slice(ckpt[path + '/left_single']['weights'], i, j, dtype=dtype)) + evoformer.right_single.weight.set_data( + np_slice(ckpt[path + '/right_single']['weights'], i, j, dtype=dtype)) + evoformer.bond_embedding.weight.set_data( + np_slice(ckpt[path + '/bond_embedding']['weights'], i, j, dtype=dtype)) + evoformer.msa_activations.weight.set_data( + np_slice(ckpt[path + '/msa_activations']['weights'], i, j, dtype=dtype)) + evoformer.extra_msa_target_feat.weight.set_data( + np_slice(ckpt[path + '/extra_msa_target_feat']['weights'], i, j, dtype=dtype)) + evoformer.prev_embedding.weight.set_data( + np_slice(ckpt[path + '/prev_embedding']['weights'], i, j, dtype=dtype)) + evoformer.prev_embedding_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/prev_embedding_layer_norm']['scale'], i, j, dtype=ms.float32)) + evoformer.prev_embedding_layer_norm.layernorm.beta.set_data( + np_slice(ckpt[path + '/prev_embedding_layer_norm']['offset'], i, j, dtype=ms.float32)) + evoformer.single_activations.weight.set_data( + np_slice(ckpt[path + '/single_activations']['weights'], i, j, dtype=dtype)) + evoformer.prev_single_embedding.weight.set_data( + np_slice(ckpt[path + '/prev_single_embedding']['weights'], i, j, dtype=dtype)) + evoformer.prev_single_embedding_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/prev_single_embedding_layer_norm']['scale'], i, j, dtype=ms.float32)) + evoformer.prev_single_embedding_layer_norm.layernorm.beta.set_data( + np_slice(ckpt[path + '/prev_single_embedding_layer_norm']['offset'], i, j, dtype=ms.float32)) + load_template_embedding(evoformer.template_module, path + '/template_embedding', + ckpt, i, j, dtype=dtype) + for ii in range(evoformer.config.pairformer.num_layer): + pairformer_path = path+'/__layer_stack_no_per_layer_1/trunk_pairformer' + load_pair_former( + evoformer.pairformer_stack[ii], pairformer_path, ckpt, ii, dtype=dtype) + for jj in range(evoformer.config.msa_stack.num_layer): + msa_stack_path = path+'/__layer_stack_no_per_layer/msa_stack' + load_evo_former( + evoformer.evoformer_stack[jj], msa_stack_path, ckpt, jj, dtype=dtype) + + +def load_adaptive_layernorm_ms(adaptive_layernorm, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + if not ckpt.get(path + "single_cond_layer_norm"): + adaptive_layernorm.layernorm.layernorm.gamma.set_data( + np_slice(ckpt[path + 'layer_norm']['scale'], i, j, dtype=dtype)) + adaptive_layernorm.layernorm.layernorm.beta.set_data( + np_slice(ckpt[path + 'layer_norm']['offset'], i, j, dtype=dtype)) + else: + adaptive_layernorm.single_cond_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + 'single_cond_layer_norm']['scale'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_scale.weight.set_data( + np_slice(ckpt[path + 'single_cond_scale']['weights'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_scale.bias.set_data( + np_slice(ckpt[path + 'single_cond_scale']['bias'], i, j, dtype=dtype)) + adaptive_layernorm.single_cond_bias.weight.set_data( + np_slice(ckpt[path + 'single_cond_bias']['weights'], i, j, dtype=dtype)) + + +def load_adaptive_zero_init_ms(adaptive_zero_init, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + adaptive_zero_init.cond_linear1.weight.set_data( + np_slice(ckpt[path + 'transition2']['weights'], i, j, dtype=dtype)) + if ckpt.get(path + 'adaptive_zero_cond'): + adaptive_zero_init.cond_linear2.weight.set_data( + 
np_slice(ckpt[path + 'adaptive_zero_cond']['weights'], i, j, dtype=dtype)) + adaptive_zero_init.cond_linear2.bias.set_data( + np_slice(ckpt[path + 'adaptive_zero_cond']['bias'], i, j, dtype=dtype)) + + +def load_transition_ms(transition_block, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm_ms( + transition_block.adaptive_layernorm, path + 'ffw_', ckpt, i, j, dtype=dtype) + transition_block.weights.set_data( + np_slice(ckpt[path + 'ffw_transition1']['weights'], i, j, dtype=dtype).reshape( + (transition_block.weights.shape[0], 2, transition_block.num_intermediate))) + load_adaptive_zero_init_ms( + transition_block.adaptive_zero_init, path + 'ffw_', ckpt, i, j, dtype=dtype) + + +def load_self_attention_ms(self_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm_ms( + self_attention.adaptive_layernorm, path, ckpt, i, j, dtype=dtype) + self_attention.q_linear.weight.set_data( + np_slice(ckpt[path + 'q_projection']['weights'], i, j, dtype=dtype)) + self_attention.q_linear.bias.set_data( + np_slice(ckpt[path + 'q_projection']['bias'], i, j, dtype=dtype)) + self_attention.k_linear.weight.set_data( + np_slice(ckpt[path + 'k_projection']['weights'], i, j, dtype=dtype)) + self_attention.v_linear.weight.set_data( + np_slice(ckpt[path + 'v_projection']['weights'], i, j, dtype=dtype)) + self_attention.linear.weight.set_data( + np_slice(ckpt[path + 'gating_query']['weights'], i, j, dtype=dtype)) + load_adaptive_zero_init_ms( + self_attention.adaptive_zero_init, path, ckpt, i, j, dtype=dtype) + + +def load_transformer_ms(transformer, path, ckpt, dtype=ms.float16): + for i in range(6): + for j in range(4): + transformer_path = (path + + f'/__layer_stack_with_per_layer/__layer_stack_with_per_layer/transformer') + load_self_attention_ms(transformer.super_blocks[i].blocks[j].self_attention, + transformer_path, ckpt, i, j, dtype=dtype) + load_transition_ms(transformer.super_blocks[i].blocks[j].transition_block, + transformer_path, ckpt, i, j, dtype=dtype) + if transformer.using_pair_act == True: + pair_projection_path = path + f'/__layer_stack_with_per_layer/pair_logits_projection' + transformer.super_blocks[i].pair_linear.weight.set_data( + np_slice(ckpt[pair_projection_path]['weights'], i, None, dtype=dtype)) + if transformer.using_pair_act == True: + pair_norm_path = path + '/pair_input_layer_norm' + transformer.pair_layernorm.layernorm.gamma.set_data( + np_slice(ckpt[pair_norm_path]['scale'].T, None, None, dtype=ms.float32)) + + +def load_cross_attention(cross_attention, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_adaptive_layernorm_ms( + cross_attention.adaptive_layernorm_q, path + 'q', ckpt, i, j, dtype=dtype) + load_adaptive_layernorm_ms( + cross_attention.adaptive_layernorm_k, path + 'k', ckpt, i, j, dtype=dtype) + cross_attention.linear_q.weight.set_data( + np_slice(ckpt[path + 'q_projection']['weights'], i, j, dtype=dtype)) + cross_attention.linear_q.bias.set_data( + np_slice(ckpt[path + 'q_projection']['bias'], i, j, dtype=dtype)) + cross_attention.linear_k.weight.set_data( + np_slice(ckpt[path + 'k_projection']['weights'], i, j, dtype=dtype)) + cross_attention.linear_v.weight.set_data( + np_slice(ckpt[path + 'v_projection']['weights'], i, j, dtype=dtype)) + cross_attention.gating_query.weight.set_data( + np_slice(ckpt[path + 'gating_query']['weights'], i, j, dtype=dtype)) + load_adaptive_zero_init_ms( + cross_attention.adaptive_zero_init, path, ckpt, i, j, dtype=dtype) + + +def 
load_cross_att_transformer_block(cross_att_transformer_block, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + load_cross_attention( + cross_att_transformer_block.cross_attention, path, ckpt, i, dtype=dtype) + load_transition_ms(cross_att_transformer_block.transition, + path, ckpt, i, dtype=dtype) + + +def load_cross_attention_transformer(cross_attention_transformer, path, ckpt, last_name, i, j, dtype=ms.bfloat16): + cross_attention_transformer.pair_input_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/pair_input_layer_norm']['scale'], i, j, dtype=dtype)) + cross_attention_transformer.pair_logits_projection.weight.set_data( + np_slice(ckpt[path + '/pair_logits_projection']['weights'], i, j, dtype=dtype)) + for ii in range(cross_attention_transformer.config.num_blocks): + block_path = path + f'/__layer_stack_with_per_layer/{last_name}' + load_cross_att_transformer_block(cross_attention_transformer.block[ii], block_path, + ckpt, ii, dtype=dtype) + + +def load_per_atom_conditioning(_per_atom_conditioning, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + _per_atom_conditioning.linear1.weight.set_data( + np_slice(ckpt[path + '_embed_ref_pos']['weights'].T, i, j, dtype=dtype)) + _per_atom_conditioning.linear2.weight.set_data( + np_slice(ckpt[path + '_embed_ref_mask']['weights'].T, i, j, dtype=dtype)) + _per_atom_conditioning.linear3.weight.set_data( + np_slice(ckpt[path + '_embed_ref_element']['weights'].T, i, j, dtype=dtype)) + _per_atom_conditioning.linear4.weight.set_data( + np_slice(ckpt[path + '_embed_ref_charge']['weights'].T, i, j, dtype=dtype)) + _per_atom_conditioning.linear5.weight.set_data( + np_slice(ckpt[path + '_embed_ref_atom_name']['weights'].T, i, j, dtype=dtype)) + _per_atom_conditioning.linear_row_act.weight.set_data( + np_slice(ckpt[path + '_single_to_pair_cond_row']['weights'].T, i, j, dtype=dtype)) + _per_atom_conditioning.linear_col_act.weight.set_data( + np_slice(ckpt[path + '_single_to_pair_cond_col']['weights'].T, i, j, dtype=dtype)) + _per_atom_conditioning.linear_pair_act1.weight.set_data( + np_slice(ckpt[path + '_embed_pair_offsets']['weights'].T, i, j, dtype=dtype)) + _per_atom_conditioning.linear_pair_act2.weight.set_data( + np_slice(ckpt[path + '_embed_pair_distances']['weights'].T, i, j, dtype=dtype)) + + +def load_atom_cross_encoder(atom_cross_att_encoder, path, ckpt, last_name, i=None, j=None, dtype=ms.bfloat16): + load_per_atom_conditioning( + atom_cross_att_encoder._per_atom_conditioning, path, ckpt, dtype=dtype) + if atom_cross_att_encoder.with_cond: + atom_cross_att_encoder._embed_trunk_single_cond.weight.set_data( + np_slice(ckpt[path + '_embed_trunk_single_cond']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._lnorm_trunk_single_cond.layernorm.gamma.set_data( + np_slice(ckpt[path + '_lnorm_trunk_single_cond']['scale'], i, j, dtype=ms.float32)) + atom_cross_att_encoder._atom_positions_to_features.weight.set_data( + np_slice(ckpt[path + '_atom_positions_to_features']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._embed_trunk_pair_cond.weight.set_data( + np_slice(ckpt[path + '_embed_trunk_pair_cond']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._lnorm_trunk_pair_cond.layernorm.gamma.set_data( + np_slice(ckpt[path + '_lnorm_trunk_pair_cond']['scale'], i, j, dtype=ms.float32)) + atom_cross_att_encoder._single_to_pair_cond_row.weight.set_data( + np_slice(ckpt[path + '_single_to_pair_cond_row_1']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._single_to_pair_cond_col.weight.set_data( + 
np_slice(ckpt[path + '_single_to_pair_cond_col_1']['weights'].T, i, j, dtype=dtype)) + if atom_cross_att_encoder.with_cond: + atom_cross_att_encoder._embed_pair_offsets.weight.set_data( + np_slice(ckpt[path + '_embed_pair_offsets_1']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._embed_pair_distances.weight.set_data( + np_slice(ckpt[path + '_embed_pair_distances_1']['weights'].T, i, j, dtype=dtype)) + else: + atom_cross_att_encoder._embed_pair_offsets.weight.set_data( + np_slice(ckpt[path + '_embed_pair_offsets']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._embed_pair_distances.weight.set_data( + np_slice(ckpt[path + '_embed_pair_distances']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._embed_pair_offsets_valid.weight.set_data( + np_slice(ckpt[path + '_embed_pair_offsets_valid']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._pair_mlp_1.weight.set_data( + np_slice(ckpt[path + '_pair_mlp_1']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._pair_mlp_2.weight.set_data( + np_slice(ckpt[path + '_pair_mlp_2']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._pair_mlp_3.weight.set_data( + np_slice(ckpt[path + '_pair_mlp_3']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_encoder._project_atom_features_for_aggr.weight.set_data( + np_slice(ckpt[path + '_project_atom_features_for_aggr']['weights'].T, i, j, dtype=dtype)) + load_cross_attention_transformer(atom_cross_att_encoder._atom_transformer_encoder, + path + '_atom_transformer_encoder', ckpt, + f"{last_name}_atom_transformer_encoder", i, j, dtype=dtype) + + +def load_atom_cross_decoder(atom_cross_att_decoder, path, ckpt, i=None, j=None, dtype=ms.bfloat16): + atom_cross_att_decoder._project_token_features_for_broadcast.weight.set_data( + np_slice(ckpt[path + '_project_token_features_for_broadcast']['weights'].T, i, j, dtype=dtype)) + atom_cross_att_decoder._atom_features_layer_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '_atom_features_layer_norm']['scale'], i, j, dtype=ms.float32)) + atom_cross_att_decoder._atom_features_to_position_update.weight.set_data( + np_slice(ckpt[path + '_atom_features_to_position_update']['weights'].T, i, j, dtype=dtype)) + load_cross_attention_transformer(atom_cross_att_decoder._atom_transformer_decoder, + path + '_atom_transformer_decoder', ckpt, + last_name='diffusion_atom_transformer_decoder', i=i, j=j, dtype=dtype) + + +def load_diffusion_head(diffusion_head, path, ckpt, i=None, j=None, dtype=ms.float32): + diffusion_head.pair_cond_initial_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/pair_cond_initial_norm']['scale'], i, j, dtype=ms.float32)) + diffusion_head.pair_cond_initial_projection.weight.set_data( + np_slice(ckpt[path + '/pair_cond_initial_projection']['weights'].T, i, j, dtype=ms.float32)) + load_transition_ms(diffusion_head.transition_block1, + path + '/pair_transition_0', ckpt, dtype=dtype) + load_transition_ms(diffusion_head.transition_block2, + path + '/pair_transition_1', ckpt, dtype=dtype) + diffusion_head.single_cond_initial_norm.layernorm.gamma.set_data( + np_slice(ckpt[path + '/single_cond_initial_norm']['scale'], i, j, dtype=ms.float32)) + diffusion_head.single_cond_initial_projection.weight.set_data( + np_slice(ckpt[path + '/single_cond_initial_projection']['weights'].T, i, j, dtype=dtype)) + diffusion_head.layer_norm_noise.layernorm.gamma.set_data( + np_slice(ckpt[path + '/noise_embedding_initial_norm']['scale'], i, j, dtype=ms.float32)) + diffusion_head.linear_noise.weight.set_data( + 
np_slice(ckpt[path + '/noise_embedding_initial_projection']['weights'].T, i, j, dtype=dtype)) + load_transition_ms(diffusion_head.single_transition1, + path + '/single_transition_0', ckpt, dtype=dtype) + load_transition_ms(diffusion_head.single_transition2, + path + '/single_transition_1', ckpt, dtype=dtype) + diffusion_head.layer_norm_act.layernorm.gamma.set_data( + np_slice(ckpt[path + '/single_cond_embedding_norm']['scale'], i, j, dtype=ms.float32)) + diffusion_head.linear_act.weight.set_data( + np_slice(ckpt[path + '/single_cond_embedding_projection']['weights'].T, i, j, dtype=dtype)) + diffusion_head.layer_norm_out.layernorm.gamma.set_data( + np_slice(ckpt[path + '/output_norm']['scale'], i, j, dtype=ms.float32)) + load_atom_cross_encoder(diffusion_head.atom_cross_att_encoder, path + '/diffusion', ckpt, + last_name="diffusion", dtype=dtype) + load_transformer_ms(diffusion_head.transformer, path + + '/transformer', ckpt, dtype=dtype) + load_atom_cross_decoder( + diffusion_head.atom_cross_att_decoder, path + '/diffusion', ckpt, dtype=dtype) + + +def load_confidence_head(confidence_head, path, ckpt, i=None, j=None, dtype=ms.float32): + confidence_head.left_target_feat_project.weight.set_data( + np_slice(ckpt[path + '/~_embed_features/left_target_feat_project']['weights'].T, i, j, dtype=dtype)) + confidence_head.right_target_feat_project.weight.set_data( + np_slice(ckpt[path + '/~_embed_features/right_target_feat_project']['weights'].T, i, j, dtype=dtype)) + confidence_head.distogram_feat_project.weight.set_data( + np_slice(ckpt[path + '/~_embed_features/distogram_feat_project']['weights'].T, i, j, dtype=dtype)) + for ii in range(confidence_head.config.pairformer.num_layer): + confidence_pairformer_path = path + \ + f'/__layer_stack_no_per_layer/confidence_pairformer' + load_pair_former(confidence_head.pairformer_block[ii], confidence_pairformer_path, + ckpt, ii, dtype=dtype) + confidence_head.left_half_distance_logits.weight.set_data( + np_slice(ckpt[path + '/left_half_distance_logits']['weights'].T, i, j, dtype=ms.float32)) + confidence_head.logits_ln.layernorm.gamma.set_data( + np_slice(ckpt[path + '/logits_ln']['scale'], i, j, dtype=ms.float32)) + confidence_head.logits_ln.layernorm.beta.set_data( + np_slice(ckpt[path + '/logits_ln']['offset'], i, j, dtype=ms.float32)) + confidence_head.pae_logits.weight.set_data( + np_slice(ckpt[path + '/pae_logits']['weights'].T, i, j, dtype=ms.float32)) + confidence_head.pae_logits_ln.layernorm.gamma.set_data( + np_slice(ckpt[path + '/pae_logits_ln']['scale'], i, j, dtype=ms.float32)) + confidence_head.pae_logits_ln.layernorm.beta.set_data( + np_slice(ckpt[path + '/pae_logits_ln']['offset'], i, j, dtype=ms.float32)) + confidence_head.plddt_logits.weight.set_data( + np_slice(ckpt[path + '/plddt_logits']['weights'], i, j, dtype=ms.float32)) + confidence_head.plddt_logits_ln.layernorm.gamma.set_data( + np_slice(ckpt[path + '/plddt_logits_ln']['scale'], i, j, dtype=ms.float32)) + confidence_head.plddt_logits_ln.layernorm.beta.set_data( + np_slice(ckpt[path + '/plddt_logits_ln']['offset'], i, j, dtype=ms.float32)) + confidence_head.experimentally_resolved_logits.weight.set_data( + np_slice(ckpt[path + '/experimentally_resolved_logits']['weights'], i, j, dtype=ms.float32)) + confidence_head.experimentally_resolved_ln.layernorm.gamma.set_data( + np_slice(ckpt[path + '/experimentally_resolved_ln']['scale'], i, j, dtype=ms.float32)) + confidence_head.experimentally_resolved_ln.layernorm.beta.set_data( + np_slice(ckpt[path + 
'/experimentally_resolved_ln']['offset'], i, j, dtype=ms.float32)) + + +def load_diffuser(diffuser, ckpt_dir, dtype=ms.bfloat16): + path = 'diffuser' + ckpt = get_model_af3_params(pathlib.Path(ckpt_dir)) + load_evoformer(diffuser.embedding_module, path + + '/evoformer', ckpt, dtype=dtype) + load_distogram_head(diffuser.distogram_head, path + + '/distogram_head', ckpt, dtype=ms.float32) + load_atom_cross_encoder(diffuser.create_target_feat_embedding.atom_cross_att_encoder, + path + '/evoformer_conditioning', ckpt, + last_name='evoformer_conditioning', dtype=ms.float32) + load_diffusion_head(diffuser.diffusion_module, path + + '/~/diffusion_head', ckpt, dtype=ms.float32) + load_confidence_head(diffuser.confidence_head, path + + '/confidence_head', ckpt, dtype=ms.float32)
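+ + +# Usage sketch (illustrative only; 'diffuser' is an already-constructed +# Diffuser cell and '/path/to/af3_params' is a placeholder for a local +# checkpoint directory obtained per the weights terms of use): +# +# load_diffuser(diffuser, '/path/to/af3_params', dtype=ms.bfloat16)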
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py new file mode 100644 index 000000000..d4d60d3c5 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/model.py @@ -0,0 +1,759 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from collections.abc import Iterable +from dataclasses import dataclass +import random +import concurrent.futures +import functools +from absl import logging +import numpy as np +import mindspore as ms +from mindspore import ops, nn +from alphafold3 import structure +from alphafold3.constants import residue_names +from alphafold3.model import base_config +from alphafold3.model import confidences +from alphafold3.model import feat_batch +from alphafold3.model import features +from alphafold3.model import model_config +from alphafold3.model.atom_layout import atom_layout +from alphafold3.model.components import base_model +from alphafold3.model.components import base_modules as bm +from alphafold3.model.components import mapping +from alphafold3.model.components import utils +from alphafold3.model.diffusion import atom_cross_attention +from alphafold3.model.diffusion import confidence_head +from alphafold3.model.diffusion import diffusion_head +from alphafold3.model.diffusion import distogram_head +from alphafold3.model.diffusion import featurization +from alphafold3.model.diffusion import modules +from alphafold3.model.diffusion import template_modules +from alphafold3.structure import mmcif + + +def get_predicted_structure(result, batch): + """Creates the predicted structure and ion predictions. + + Args: + result: model output in a model specific layout + batch: model input batch + + Returns: + Predicted structure.
+ """ + model_output_coords = result['diffusion_samples']['atom_positions'] + + # Rearrange model output coordinates to the flat output layout. + model_output_to_flat = atom_layout.compute_gather_idxs( + source_layout=batch.convert_model_output.token_atoms_layout, + target_layout=batch.convert_model_output.flat_output_layout, + ) + pred_flat_atom_coords = atom_layout.convert( + gather_info=model_output_to_flat, + arr=model_output_coords.asnumpy(), + layout_axes=(-3, -2), + ) + + predicted_lddt = result.get('predicted_lddt') + + if predicted_lddt is not None: + pred_flat_b_factors = atom_layout.convert( + gather_info=model_output_to_flat, + arr=predicted_lddt.asnumpy(), + layout_axes=(-2, -1), + ) + else: + # Handle models which don't have predicted_lddt outputs. + pred_flat_b_factors = np.zeros(pred_flat_atom_coords.shape[:-1]) + + (missing_atoms_indices,) = np.nonzero( + model_output_to_flat.gather_mask == 0) + if missing_atoms_indices.shape[0] > 0: + missing_atoms_flat_layout = batch.convert_model_output.flat_output_layout[ + missing_atoms_indices + ] + missing_atoms_uids = list( + zip( + missing_atoms_flat_layout.chain_id, + missing_atoms_flat_layout.res_id, + missing_atoms_flat_layout.res_name, + missing_atoms_flat_layout.atom_name, + ) + ) + logging.warning( + 'Target %s: warning: %s atoms were not predicted by the ' + 'model, setting their coordinates to (0, 0, 0). ' + 'Missing atoms: %s', + batch.convert_model_output.empty_output_struc.name, + missing_atoms_indices.shape[0], + missing_atoms_uids, + ) + + # Put them into a structure + pred_struc = batch.convert_model_output.empty_output_struc + pred_struc = pred_struc.copy_and_update_atoms( + atom_x=pred_flat_atom_coords[..., 0], + atom_y=pred_flat_atom_coords[..., 1], + atom_z=pred_flat_atom_coords[..., 2], + atom_b_factor=pred_flat_b_factors, + # Always 1.0. + atom_occupancy=np.ones(pred_flat_atom_coords.shape[:-1]), + ) + # Set manually/differently when adding metadata. + pred_struc = pred_struc.copy_and_update_globals(release_date=None) + return pred_struc + + +class CreateTargetFeatEmbedding(nn.Cell): + """ + A class that creates target feature embeddings by combining raw features with cross-attention encoded features. + + Args: + config (Config): Configuration object containing parameters for the target feature embedding. + global_config (GlobalConfig): Global configuration object. + + Inputs: + - **batch** (dict) - Dictionary containing batch features. + + Outputs: + - **target_feat** (Tensor) - Tensor of target feature embeddings. 
+ """ + + def __init__(self, config, global_config, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.dtype = dtype + self.atom_cross_att_encoder = atom_cross_attention.AtomCrossAttEncoder( + self.config.per_atom_conditioning, self.global_config, '', with_cond=False, dtype=dtype + ) + + def construct(self, batch): + target_feat = featurization.create_target_feat( + batch, + append_per_atom_features=False, + dtype=ms.float32 + ).astype(self.dtype) + enc = self.atom_cross_att_encoder( + token_atoms_act=None, + trunk_single_cond=None, + trunk_pair_cond=None, + batch=batch, + ) + target_feat = ops.concat( + [target_feat, enc.token_act.astype(self.dtype)], axis=-1) + return target_feat + + +def _compute_ptm(result, num_tokens, asym_id, pae_single_mask, interface): + """Computes the pTM metrics from PAE.""" + return np.stack( + [ + confidences.predicted_tm_score( + tm_adjusted_pae=tm_adjusted_pae[:num_tokens, :num_tokens].asnumpy( + ), + asym_id=asym_id.asnumpy(), + pair_mask=pae_single_mask[:num_tokens, :num_tokens], + interface=interface, + ) + for tm_adjusted_pae in result['tmscore_adjusted_pae_global'] + ], + axis=0, + ) + + +def _compute_chain_pair_iptm( + num_tokens, + asym_ids, + mask, + tm_adjusted_pae): + """Computes the chain pair ipTM metrics from PAE.""" + return np.stack( + [ + confidences.chain_pairwise_predicted_tm_scores( + tm_adjusted_pae=sample_tm_adjusted_pae[:num_tokens], + asym_id=asym_ids[:num_tokens], + pair_mask=mask[:num_tokens, :num_tokens], + ) + for sample_tm_adjusted_pae in tm_adjusted_pae + ], + axis=0, + ) + + +class Diffuser(nn.Cell): + """ + Diffuser class for processing and generating diffusion samples, confidence scores, and distanceograms. + + Args: + config (Diffuser.Config): Configuration object containing parameters for the diffuser. + in_channel (int): Number of input channels. + feat_shape (tuple): Shape of the feature tensor. + act_shape (tuple): Shape of the activation tensor. + pair_shape (tuple): Shape of the pair tensor. + single_shape (tuple): Shape of the single tensor. + atom_shape (tuple): Shape of the atom tensor. + out_channel (int): Number of output channels. + num_templates (int): Number of templates. + + Inputs: + - **batch** (dict): Dictionary containing batch data. + - **key** (int): Random key generator. + + Outputs: + - **result** (dict): Dictionary containing diffusion samples, distanceogram, and confidence outputs. 
+ """ + @dataclass + class HeadsConfig(base_config.BaseConfig): + diffusion: diffusion_head.DiffusionHead.Config = base_config.autocreate() + confidence: confidence_head.ConfidenceHead.Config = base_config.autocreate() + distogram: distogram_head.DistogramHead.Config = base_config.autocreate() + + @dataclass + class Config(base_config.BaseConfig): + evoformer: 'Evoformer.Config' = base_config.autocreate() + global_config: model_config.GlobalConfig = base_config.autocreate() + heads: 'Diffuser.HeadsConfig' = base_config.autocreate() + num_recycles: int = 10 # 10, change to 0 for test + return_embeddings: bool = False + + def __init__(self, config, in_channel, feat_shape, act_shape, pair_shape, single_shape, atom_shape, + out_channel, num_templates, dtype=ms.float32, name="model"): + super().__init__(auto_prefix=True) + self.config = config + self.global_config = config.global_config + self.dtype = dtype + self.diffusion_module = diffusion_head.DiffusionHead( + self.config.heads.diffusion, self.global_config, pair_shape, dtype=ms.float32 + ) + self.embedding_module = Evoformer(self.config.evoformer, self.global_config, + feat_shape, act_shape, pair_shape, single_shape, num_templates, dtype=dtype) + self.create_target_feat_embedding = CreateTargetFeatEmbedding( + self.embedding_module.config, self.global_config, dtype=ms.float32) + self.confidence_head = confidence_head.ConfidenceHead( + self.config.heads.confidence, self.global_config, + pair_shape, single_shape, atom_shape, feat_shape[-1], out_channel, dtype=dtype + ) + self.distogram_head = distogram_head.DistogramHead( + self.config.heads.distogram, self.global_config, pair_shape[-1], dtype=ms.float32 + ) + + def _sample_diffusion(self, batch, embeddings, sample_config, key, init_positions=None): + denoising_step = functools.partial( + self.diffusion_module, + batch=batch, + embeddings=embeddings, + use_conditioning=True, + ) + sample = diffusion_head.sample( + denoising_step=denoising_step, + batch=batch, + key=key+1, + config=sample_config, + init_positions=init_positions, + ) + return sample + + def construct(self, batch, key): + if key is None: + # generate a random number + key = int(np.random.randint(100)) + # batch = feat_batch.Batch.from_data_dict(batch) + target_feat = self.create_target_feat_embedding( + batch).astype(self.dtype) + + def recycle_body(prev, key): + key, subkey = random.randint(0, 1e6), key + embeddings = self.embedding_module( + batch=batch, + prev=prev, + target_feat=target_feat, + key=subkey, + ) + embeddings['pair'] = embeddings['pair'].astype(ms.float32) + embeddings['single'] = embeddings['single'].astype(ms.float32) + return embeddings, key + + num_res = batch.num_res + embeddings = { + 'pair': ops.zeros( + [num_res, num_res, self.config.evoformer.pair_channel], + dtype=ms.float32, + ), + 'single': ops.zeros( + [num_res, self.config.evoformer.seq_channel], dtype=ms.float32 + ), + 'target_feat': target_feat, + } + num_iter = self.config.num_recycles + 1 + for _ in range(num_iter): + embeddings, _ = recycle_body(embeddings, key) + + samples = self._sample_diffusion( + batch, + embeddings, + sample_config=self.config.heads.diffusion.eval, + key=key + ) + confidence_output = [] + for i in range(samples['atom_positions'].shape[0]): + confidence_output.append(self.confidence_head( + dense_atom_positions=samples['atom_positions'][i], + embeddings=embeddings, + seq_mask=batch.token_features.mask, + token_atoms_to_pseudo_beta=batch.pseudo_beta_info.token_atoms_to_pseudo_beta, + asym_id=batch.token_features.asym_id, 
+ for name in confidence_output[0].keys(): + confidence_output[0][name] = ops.stack( + [value[name] for value in confidence_output]) + confidence_output = confidence_output[0] + distogram = self.distogram_head(batch, embeddings) + output = { + 'diffusion_samples': samples, + 'distogram': distogram, + **confidence_output, + } + if self.config.return_embeddings: + output['single_embeddings'] = embeddings['single'] + output['pair_embeddings'] = embeddings['pair'] + return output + + @classmethod + def get_inference_result(cls, batch, result, target_name): + """Get the predicted structure, scalars, and arrays for inference. + + This function also computes any inference-time quantities, which are not a + part of the forward-pass, e.g. additional confidence scores. Note that this + function is not serialized, so it should be slim if possible. + + Args: + batch: data batch used for model inference, incl. TPU invalid types. + result: output dict from the model's forward pass. + target_name: target name to be saved within structure. + + Yields: + inference_result: dataclass object that contains a predicted structure, + important inference-time scalars and arrays, as well as a slightly trimmed + dictionary of raw model result from the forward pass (for debugging). + """ + del target_name + # Retrieve structure and construct a predicted structure. + pred_structure = get_predicted_structure(result=result, batch=batch) + num_tokens = batch.token_features.seq_length.item() + pae_single_mask = np.tile( + batch.frames.mask[:, None], + [1, batch.frames.mask.shape[0]], + ) + ptm = _compute_ptm( + result=result, + num_tokens=num_tokens, + asym_id=batch.token_features.asym_id[:num_tokens], + pae_single_mask=pae_single_mask, + interface=False, + ) + iptm = _compute_ptm( + result=result, + num_tokens=num_tokens, + asym_id=batch.token_features.asym_id[:num_tokens], + pae_single_mask=pae_single_mask, + interface=True, + ) + ptm_iptm_average = 0.8 * iptm + 0.2 * ptm + + asym_ids = batch.token_features.asym_id[:num_tokens].asnumpy() + chain_ids = [mmcif.int_id_to_str_id(asym_id) for asym_id in asym_ids] + res_ids = batch.token_features.residue_index[:num_tokens] + + if len(np.unique(asym_ids)) > 1: + # There is more than one chain, hence interface pTM (i.e. ipTM) defined, + # so use it. + ranking_confidence = ptm_iptm_average + else: + # There is only one chain, hence ipTM=NaN, so use just pTM. + ranking_confidence = ptm + + contact_probs = result['distogram']['contact_probs'].astype(ms.float32) + # Compute PAE related summaries.
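+ # The confidences helpers below collapse the full token-pair PAE/PDE + # matrices into per-chain-pair summaries (e.g. mean and minimum over all + # token pairs between two chains); only the aggregates needed for ranking + # and output metadata are kept.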
+        _, chain_pair_pae_min, _ = confidences.chain_pair_pae(
+            num_tokens=num_tokens,
+            asym_ids=batch.token_features.asym_id.asnumpy(),
+            full_pae=result['full_pae'].asnumpy(),
+            mask=pae_single_mask,
+        )
+        chain_pair_pde_mean, chain_pair_pde_min = confidences.chain_pair_pde(
+            num_tokens=num_tokens,
+            asym_ids=batch.token_features.asym_id.asnumpy(),
+            full_pde=result['full_pde'].asnumpy(),
+        )
+        intra_chain_single_pde, cross_chain_single_pde, _ = confidences.pde_single(
+            num_tokens,
+            batch.token_features.asym_id.asnumpy(),
+            result['full_pde'].asnumpy(),
+            contact_probs.asnumpy(),
+        )
+        pae_metrics = confidences.pae_metrics(
+            num_tokens=num_tokens,
+            asym_ids=batch.token_features.asym_id.asnumpy(),
+            full_pae=result['full_pae'].asnumpy(),
+            mask=pae_single_mask,
+            contact_probs=contact_probs.asnumpy(),
+            tm_adjusted_pae=result['tmscore_adjusted_pae_interface'].asnumpy(),
+        )
+        ranking_confidence_pae = confidences.rank_metric(
+            result['full_pae'].asnumpy(),
+            contact_probs.asnumpy() * batch.frames.mask[:, None].astype(float),
+        )
+        chain_pair_iptm = _compute_chain_pair_iptm(
+            num_tokens=num_tokens,
+            asym_ids=batch.token_features.asym_id.asnumpy(),
+            mask=pae_single_mask,
+            tm_adjusted_pae=result['tmscore_adjusted_pae_interface'].asnumpy(),
+        )
+        # iptm_ichain is a vector of per-chain ptm values. iptm_ichain[0],
+        # for example, is just the zeroth diagonal entry of the chain pair iptm
+        # matrix:
+        # [[x, , ],
+        #  [ , , ],
+        #  [ , , ]]
+        iptm_ichain = chain_pair_iptm.diagonal(axis1=-2, axis2=-1)
+        # iptm_xchain is a vector of cross-chain interactions for each chain.
+        # iptm_xchain[0], for example, is an average of chain 0's interactions with
+        # other chains:
+        # [[ ,x,x],
+        #  [x, , ],
+        #  [x, , ]]
+        iptm_xchain = confidences.get_iptm_xchain(chain_pair_iptm)
+
+        predicted_distance_errors = result['average_pde']
+
+        # Computing solvent accessible area with dssp can be slow for large
+        # structures with lots of chains, so we parallelize the call.
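+        # has_clash and fraction_disordered are evaluated independently per
+        # predicted sample, so they are mapped across a thread pool with one
+        # worker per sample.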
+ pred_structures = pred_structure.unstack() + num_workers = len(pred_structures) + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_workers + ) as executor: + has_clash = list(executor.map( + confidences.has_clash, pred_structures)) + fraction_disordered = list( + executor.map(confidences.fraction_disordered, pred_structures) + ) + for idx, pred_structure in enumerate(pred_structures): + ranking_score = confidences.get_ranking_score( + ptm=ptm[idx], + iptm=iptm[idx], + fraction_disordered_=fraction_disordered[idx], + has_clash_=has_clash[idx], + ) + print(f"####### result {idx} ######") + print(f"####### ranking_score {ranking_score} ######") + print(f"####### predicted_tm_score {ptm[idx]} ######") + print(f"####### interface_predicted_tm_score {iptm[idx]} ######") + yield base_model.InferenceResult( + predicted_structure=pred_structure, + numerical_data={ + 'full_pde': result['full_pde'][idx, :num_tokens, :num_tokens], + 'full_pae': result['full_pae'][idx, :num_tokens, :num_tokens], + 'contact_probs': contact_probs[:num_tokens, :num_tokens], + }, + metadata={ + 'predicted_distance_error': predicted_distance_errors[idx], + 'ranking_score': ranking_score, + 'fraction_disordered': fraction_disordered[idx], + 'has_clash': has_clash[idx], + 'predicted_tm_score': ptm[idx], + 'interface_predicted_tm_score': iptm[idx], + 'chain_pair_pde_mean': chain_pair_pde_mean[idx], + 'chain_pair_pde_min': chain_pair_pde_min[idx], + 'chain_pair_pae_min': chain_pair_pae_min[idx], + 'ptm': ptm[idx], + 'iptm': iptm[idx], + 'ptm_iptm_average': ptm_iptm_average[idx], + 'intra_chain_single_pde': intra_chain_single_pde[idx], + 'cross_chain_single_pde': cross_chain_single_pde[idx], + 'pae_ichain': pae_metrics['pae_ichain'][idx], + 'pae_xchain': pae_metrics['pae_xchain'][idx], + 'ranking_confidence': ranking_confidence[idx], + 'ranking_confidence_pae': ranking_confidence_pae[idx], + 'chain_pair_iptm': chain_pair_iptm[idx], + 'iptm_ichain': iptm_ichain[idx], + 'iptm_xchain': iptm_xchain[idx], + 'token_chain_ids': chain_ids, + 'token_res_ids': res_ids, + }, + ) + + +class Evoformer(nn.Cell): + """ + Evoformer class for generating 'single' and 'pair' embeddings in protein structure prediction. + + Args: + config (Evoformer.Config): Configuration object defining the parameters for the Evoformer module. + global_config (base_config.BaseConfig): Global configuration object containing general settings. + feat_shape (tuple): Shape of the feature tensor. + act_shape (tuple): Shape of the activation tensor. + pair_shape (tuple): Shape of the pair tensor. + single_shape (tuple): Shape of the single tensor. + num_templates (int): Number of templates used in the model. + + Inputs: + - **batch** (dict): Dictionary containing batch data including token features, MSA, and other relevant information. + - **prev** (dict): Dictionary containing previous embeddings for 'single' and 'pair' activations. + - **target_feat** (Tensor): Target feature tensor used for generating embeddings. + - **key** (int): Random key for reproducibility. + + Outputs: + - **output** (dict): Dictionary containing the generated embeddings: + - **single** (Tensor): Single residue embeddings. + - **pair** (Tensor): Pairwise residue embeddings. + - **target_feat** (Tensor): Target feature tensor. + + Notes: + - The class processes input data through multiple modules including position encoding, bond embedding, template embedding, MSA processing, and Pairformer iterations. 
+        - The `construct` method iteratively processes the input data to generate rich embeddings for downstream tasks in protein structure prediction.
+    """
+    @dataclass
+    # pytype: disable=invalid-function-definition
+    class PairformerConfig(modules.PairFormerIteration.Config):
+        block_remat: bool = False
+        remat_block_size: int = 8
+
+    @dataclass
+    class Config(base_config.BaseConfig):
+        """Configuration for Evoformer."""
+
+        max_relative_chain: int = 2
+        msa_channel: int = 64
+        seq_channel: int = 384
+        max_relative_idx: int = 32
+        num_msa: int = 1024
+        pair_channel: int = 128
+        pairformer: 'Evoformer.PairformerConfig' = base_config.autocreate(
+            single_transition=base_config.autocreate(),
+            single_attention=base_config.autocreate(),
+            num_layer=48,
+        )
+        per_atom_conditioning: atom_cross_attention.AtomCrossAttEncoderConfig = (
+            base_config.autocreate(
+                per_token_channels=384,
+                per_atom_channels=128,
+                atom_transformer=base_config.autocreate(
+                    num_intermediate_factor=2,
+                    num_blocks=3,
+                ),
+                per_atom_pair_channels=16,
+            )
+        )
+        template: template_modules.TemplateEmbedding.Config = (
+            base_config.autocreate()
+        )
+        msa_stack: modules.EvoformerIteration.Config = base_config.autocreate()
+
+    def __init__(self, config, global_config, feat_shape, act_shape, pair_shape, single_shape, num_templates, dtype=ms.float32):
+        super().__init__()
+        self.config = config
+        self.global_config = global_config
+        in_channel = feat_shape[-1]
+        position_activations_in = 4 * self.config.max_relative_idx + \
+            4 + 2 * self.config.max_relative_chain + 2 + 1
+        self.position_activations = bm.CustomDense(
+            position_activations_in, self.config.pair_channel, ndim=3, dtype=dtype)
+        self.left_single = bm.CustomDense(
+            in_channel, self.config.pair_channel, ndim=2, dtype=dtype)
+        self.right_single = bm.CustomDense(
+            in_channel, self.config.pair_channel, ndim=2, dtype=dtype)
+        self.bond_embedding = bm.CustomDense(
+            1, self.config.pair_channel, ndim=3, dtype=dtype)
+        self.template_module = template_modules.TemplateEmbedding(
+            self.config.template, self.global_config, num_templates, act_shape, dtype=dtype
+        )
+        self.msa_activations = bm.CustomDense(
+            residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP + 3, self.config.msa_channel, ndim=3, dtype=dtype)
+        self.extra_msa_target_feat = bm.CustomDense(
+            in_channel, self.config.msa_channel, ndim=2, dtype=dtype)
+        evoformer_act_shape = (self.config.num_msa,
+                               act_shape[1], self.config.msa_channel)
+        self.evoformer_stack = nn.CellList(
+            [
+                modules.EvoformerIteration(
+                    self.config.msa_stack, self.global_config, evoformer_act_shape, pair_shape, dtype=dtype
+                ) for _ in range(self.config.msa_stack.num_layer)
+            ]
+        )
+        self.prev_embedding = bm.CustomDense(
+            pair_shape[-1], pair_shape[-1], ndim=3, dtype=dtype)
+        self.prev_embedding_layer_norm = bm.LayerNorm(
+            pair_shape, dtype=ms.float32)
+        self.single_activations = bm.CustomDense(
+            in_channel, self.config.seq_channel, ndim=2, dtype=dtype)
+        self.prev_single_embedding = bm.CustomDense(
+            self.config.seq_channel, self.config.seq_channel, ndim=2, dtype=dtype)
+        self.prev_single_embedding_layer_norm = bm.LayerNorm(act_shape[:-1] +
+                                                             (self.config.seq_channel,), dtype=ms.float32)
+        self.pairformer_stack = nn.CellList(
+            [
+                modules.PairFormerIteration(
+                    self.config.pairformer, self.global_config, pair_shape, single_shape, with_single=True, dtype=dtype
+                ) for _ in range(self.config.pairformer.num_layer)
+            ]
+        )
+
+    def _relative_encoding(self, batch, pair_activations):
+        rel_feat = featurization.create_relative_encoding(
+
batch.token_features, + self.config.max_relative_idx, + self.config.max_relative_chain, + ) + rel_feat = rel_feat.astype(pair_activations.dtype) + pair_activations += self.position_activations(rel_feat) + return pair_activations + + def _seq_pair_embedding(self, token_features, target_feat): + left_single = self.left_single(target_feat)[:, None] + right_single = self.right_single(target_feat)[None] + dtype = left_single.dtype + pair_activations = left_single + right_single + num_residues = pair_activations.shape[0] + mask = token_features.mask + pair_mask = (mask[:, None] * mask[None, :]).astype(dtype) + assert pair_mask.shape == (num_residues, num_residues) + return pair_activations, pair_mask + + def _embed_bonds(self, batch, pair_activations): + """Embeds bond features and merges into pair activations.""" + # Construct contact matrix. + num_tokens = batch.token_features.token_index.shape[0] + contact_matrix = ops.zeros((num_tokens, num_tokens)) + + tokens_to_polymer_ligand_bonds = ( + batch.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds + ) + gather_idxs_polymer_ligand = tokens_to_polymer_ligand_bonds.gather_idxs + gather_mask_polymer_ligand = ( + tokens_to_polymer_ligand_bonds.gather_mask.prod(dim=1).astype( + gather_idxs_polymer_ligand.dtype + )[:, None] + ) + # If valid mask then it will be all 1's, so idxs should be unchanged. + gather_idxs_polymer_ligand = ( + gather_idxs_polymer_ligand * gather_mask_polymer_ligand + ) + tokens_to_ligand_ligand_bonds = ( + batch.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds + ) + gather_idxs_ligand_ligand = tokens_to_ligand_ligand_bonds.gather_idxs + gather_mask_ligand_ligand = tokens_to_ligand_ligand_bonds.gather_mask.prod( + dim=1 + ).astype(gather_idxs_ligand_ligand.dtype)[:, None] + gather_idxs_ligand_ligand = ( + gather_idxs_ligand_ligand * gather_mask_ligand_ligand + ) + gather_idxs = ops.concat( + [gather_idxs_polymer_ligand, gather_idxs_ligand_ligand] + ) + contact_matrix[gather_idxs[:, 0], gather_idxs[:, 1]] = 1.0 + contact_matrix[0, 0] = 0.0 + + bonds_act = self.bond_embedding( + contact_matrix[:, :, None].astype(pair_activations.dtype) + ) + return pair_activations + bonds_act + + def _embed_template_pair(self, batch, pair_activations, pair_mask, key): + """Embeds Templates and merges into pair activations.""" + dtype = pair_activations.dtype + key, subkey = key, key + 1 + + templates = batch.templates + asym_id = batch.token_features.asym_id + # Construct a mask such that only intra-chain template features are + # computed, since all templates are for each chain individually. + multichain_mask = (asym_id[:, None] == asym_id[None, :]).astype(dtype) + template_fn = functools.partial(self.template_module, key=subkey) + template_act = template_fn( + query_embedding=pair_activations, + templates=templates, + multichain_mask_2d=multichain_mask, + padding_mask_2d=pair_mask, + ) + return pair_activations + template_act, key + + def _embed_process_msa(self, msa_batch, pair_activations, pair_mask, key, target_feat): + """Processes MSA and returns updated pair activations.""" + dtype = pair_activations.dtype + # msa_batch, key = featurization.shuffle_msa(key, msa_batch) # no random in test + msa_batch = featurization.truncate_msa_batch( + msa_batch, self.config.num_msa) + msa_feat = featurization.create_msa_feat(msa_batch).astype(dtype) + + msa_activations = self.msa_activations(msa_feat) + msa_activations += self.extra_msa_target_feat(target_feat)[None] + msa_mask = msa_batch.mask.astype(dtype) + # Evoformer MSA stack. 
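+        # The Evoformer MSA stack threads a {'msa', 'pair'} activation dict
+        # through num_layer EvoformerIteration cells; only the updated pair
+        # activations are returned to the caller.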
+ evoformer_input = {'msa': msa_activations, 'pair': pair_activations} + mask = {'msa': msa_mask, 'pair': pair_mask} + for i in range(self.config.msa_stack.num_layer): + evoformer_input = self.evoformer_stack[i](evoformer_input, mask) + + return evoformer_input['pair'], key + + def construct(self, batch, prev, target_feat, key): + num_residues = target_feat.shape[0] + + dtype = (ms.bfloat16 if self.global_config.bfloat16 == + 'all' else ms.float32) + pair_activations, pair_mask = self._seq_pair_embedding( + batch.token_features, target_feat + ) + pair_activations += self.prev_embedding( + self.prev_embedding_layer_norm( + prev['pair'] + ).astype(pair_activations.dtype) + ) + pair_activations = self._relative_encoding(batch, pair_activations) + pair_activations = self._embed_bonds( + batch=batch, pair_activations=pair_activations + ) + pair_activations, key = self._embed_template_pair( + batch=batch, + pair_activations=pair_activations, + pair_mask=pair_mask, + key=key, + ) + pair_activations, key = self._embed_process_msa( + msa_batch=batch.msa, + pair_activations=pair_activations, + pair_mask=pair_mask, + key=key, + target_feat=target_feat, + ) + del key # Unused after this point. + single_activations = self.single_activations(target_feat) + + single_activations += self.prev_single_embedding( + self.prev_single_embedding_layer_norm( + prev['single'].astype(single_activations.dtype) + ) + ) + for i in range(self.config.pairformer.num_layer): + pair_activations, single_activations = self.pairformer_stack[i]( + pair_activations, pair_mask, single_act=single_activations, + seq_mask=batch.token_features.mask.astype(dtype) + ) + output = { + 'single': single_activations, + 'pair': pair_activations, + 'target_feat': target_feat, + } + + return output diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py new file mode 100644 index 000000000..aab05b1a3 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/modules.py @@ -0,0 +1,567 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Modules for the Diffuser model."""
+
+from collections.abc import Sequence
+from dataclasses import dataclass
+from typing import Literal
+
+import mindspore as ms
+from mindspore import nn, ops, Tensor, mint
+from mindchemistry.e3.utils import Ncon
+from alphafold3.model import base_config
+from alphafold3.utils.attention import attention
+from alphafold3.utils.gated_linear_unit.gated_linear_unit import gated_linear_unit
+from alphafold3.model.components import base_modules as bm
+from alphafold3.model.components import mapping
+from alphafold3.model.diffusion import diffusion_transformer
+from alphafold3.model.diffusion.triangle import TriangleMultiplication as Triangle
+from alphafold3.model.diffusion.triangle import OuterProductMean as ProductMean
+
+
+def get_shard_size(num_residues, shard_spec):
+    """Return the shard size of the first spec entry whose upper bound covers num_residues."""
+    shard_size = shard_spec[0][-1]
+    for num_residues_upper_bound, num_residues_shard_size in shard_spec:
+        shard_size = num_residues_shard_size
+        if (
+            num_residues_upper_bound is None
+            or num_residues <= num_residues_upper_bound
+        ):
+            break
+    return shard_size
+
+
+class TransitionBlock(nn.Cell):
+    """
+    A transition block for transformer networks, implementing either a GLU-based or linear-based transformation.
+
+    Args:
+        config (Config): Configuration object containing parameters for the transition block.
+        global_config (GlobalConfig): Global configuration object.
+        normalized_shape (tuple): Shape of the input tensor for normalization.
+        ndim (int): Number of dimensions of the input tensor. Default: ``3``.
+
+    Inputs:
+        - **act** (Tensor) - Input activation tensor to be processed.
+
+    Outputs:
+        - **output** (Tensor) - Output tensor after processing through the transition block.
+    """
+    @dataclass
+    class Config(base_config.BaseConfig):
+        num_intermediate_factor: int = 4
+        use_glu_kernel: bool = True
+
+    def __init__(
+        self, config, global_config, normalized_shape, ndim=3, dtype=ms.float32
+    ):
+        super().__init__()
+        self.config = config
+        self.global_config = global_config
+        num_channels = normalized_shape[-1]
+        self.num_intermediate = int(
+            num_channels * self.config.num_intermediate_factor)
+        self.layernorm = bm.LayerNorm(
+            normalized_shape, name='input_layer_norm', dtype=ms.float32)
+        if self.config.use_glu_kernel:
+            self.glu_weight = bm.custom_initializer(
+                'relu', (num_channels, 2 * self.num_intermediate), dtype=dtype)
+            self.glu_weight = ms.Parameter(Tensor(self.glu_weight).reshape(
+                num_channels, 2, self.num_intermediate))
+        else:
+            self.linear = bm.CustomDense(num_channels, self.num_intermediate * 2,
+                                         weight_init='ones', ndim=ndim, dtype=dtype)
+            self.linear.weight = bm.custom_initializer(
+                'ones', self.linear.weight.shape, dtype=dtype)  # 'relu' init replaced with 'ones' for testing
+        self.out_linear = bm.CustomDense(self.num_intermediate, num_channels,
+                                         weight_init=self.global_config.final_init, ndim=ndim, dtype=dtype)
+
+    def construct(self, act, broadcast_dim=0):
+        act = self.layernorm(act)
+        if self.config.use_glu_kernel:
+            c = gated_linear_unit(
+                x=act,
+                weight=self.glu_weight,
+                implementation=None,
+                activation=mint.nn.functional.silu,
+                precision=None
+            )
+        else:
+            act = self.linear(act)
+            a, b = mint.split(act, act.shape[-1] // 2, dim=-1)
+            c = mint.nn.functional.silu(a) * b
+        return self.out_linear(c)
+
+
+class MSAAttention(nn.Cell):
+    """
+    MSA (multiple sequence alignment) attention over sequence activations, with attention logits derived from the pair activations.
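+    The gated, per-head weighted average of the MSA values is projected back to
+    the MSA channel dimension.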
+ + Args: + config (Config): Configuration object containing parameters for the attention mechanism. + global_config (GlobalConfig): Global configuration object. + act_shape (tuple): Shape of the activation tensor. + pair_shape (tuple): Shape of the pair tensor. + + Inputs: + - **act** (Tensor) - Input activation tensor. + - **mask** (Tensor) - Mask tensor to prevent attention weights from focusing on invalid positions. + - **pair_act** (Tensor) - Pair activation tensor. + + Outputs: + - **output** (Tensor) - Output tensor after processing through the attention mechanism. + """ + @dataclass + class Config(base_config.BaseConfig): + num_head: int = 8 + + def __init__(self, config, global_config, act_shape, pair_shape, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.actnorm = bm.LayerNorm(act_shape, dtype=ms.float32) + self.pairnorm = bm.LayerNorm(pair_shape, dtype=ms.float32) + num_channel = act_shape[-1] + value_dim = num_channel // self.config.num_head + self.pair_logits = bm.CustomDense(pair_shape[-1], self.config.num_head, use_bias=False, + weight_init='ones', ndim=3, dtype=dtype) # None, change to ones for test + self.v_projection = bm.CustomDense(num_channel, (self.config.num_head, value_dim), + use_bias=False, ndim=len(act_shape), dtype=dtype) + ncon_list1 = [-3, -2, 1] + ncon_list2 = [-1, 1, -3, -4] + self.ncon = Ncon([ncon_list1, ncon_list2]) + self.gating_query = bm.CustomDense( + num_channel, self.config.num_head * value_dim, weight_init='zeros', use_bias=False, ndim=3, dtype=dtype) + self.output_projection = bm.CustomDense(self.config.num_head * value_dim, num_channel, + weight_init=self.global_config.final_init, use_bias=False, ndim=3, dtype=dtype) + + def construct(self, act, mask, pair_act): + act = self.actnorm(act) + pair_act = self.pairnorm(pair_act) + logits = self.pair_logits(pair_act).transpose([2, 0, 1]) + logits += 1e9 * (mint.max(mask, dim=0)[0] - 1.0) + weights = mint.softmax(logits, dim=-1) + v = self.v_projection(act) + v_avg = self.ncon([weights, v]) + v_avg = v_avg.reshape(v_avg.shape[:-2]+(-1,)) + gate_value = self.gating_query(act) + v_avg *= mint.sigmoid(gate_value.astype(ms.float32) + ).astype(gate_value.dtype) + out = self.output_projection(v_avg) + return out + + +class GridSelfAttention(nn.Cell): + """ + Self-attention mechanism that operates either per-sequence or per-residue. + + Args: + config (Config): Configuration object containing parameters for the attention mechanism. + global_config (GlobalConfig): Global configuration object. + transpose (bool): Whether to transpose the activation tensor during processing. + normalized_shape (tuple): Shape of the input tensor for normalization. + + Inputs: + - **act** (Tensor) - Input activation tensor. + - **pair_mask** (Tensor) - Mask tensor indicating valid regions in the input. + + Outputs: + - **output** (Tensor) - Output tensor after processing through the self-attention mechanism. 
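+
+    Notes:
+        - `transpose` selects which residue axis is attended over: when ``True``,
+          the activations are swapped before and after the attention call, turning
+          row-wise attention over the pair representation into column-wise attention.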
+ """ + @dataclass + class Config(base_config.BaseConfig): + num_head: int = 4 + + def __init__( + self, config, global_config, transpose, normalized_shape, dtype=ms.float32 + ): + super().__init__() + self.config = config + self.global_config = global_config + self.transpose = transpose + num_channels = normalized_shape[-1] + in_shape = normalized_shape[-1] + assert num_channels % self.config.num_head == 0 + qkv_dim = max(num_channels // self.config.num_head, 16) + qkv_shape = (self.config.num_head, qkv_dim) + self.q_projection = bm.CustomDense( + in_shape, qkv_shape, use_bias=False, ndim=3, dtype=dtype) + self.k_projection = bm.CustomDense( + in_shape, qkv_shape, use_bias=False, ndim=3, dtype=dtype) + self.v_projection = bm.CustomDense( + in_shape, qkv_shape, use_bias=False, ndim=3, dtype=dtype) + self.gating_query = bm.CustomDense( + num_channels, self.config.num_head * qkv_dim, weight_init='zeros', use_bias=False, ndim=3, dtype=dtype) + self.output_projection = bm.CustomDense( + self.config.num_head * qkv_dim, num_channels, weight_init=self.global_config.final_init, ndim=3, dtype=dtype) + self.act_norm = bm.LayerNorm(normalized_shape, dtype=ms.float32) + self.pair_bias_projection = bm.CustomDense( + num_channels, self.config.num_head, use_bias=False, weight_init='ones', ndim=3, dtype=dtype) # linear, change to ones for test + num_residues = normalized_shape[0] + self.chunk_size = get_shard_size( + num_residues, self.global_config.pair_attention_chunk_size + ) + + def _attention(self, act, mask, bias): + q = self.q_projection(act) + k = self.k_projection(act) + v = self.v_projection(act) + bias = ops.expand_dims(bias, 0) + weighted_avg = attention.dot_product_attention( + q, + k, + v, + mask=mask, + bias=bias, + logits_dtype=ms.float32, + precision=None, + implementation=self.global_config.flash_attention_implementation, + ) + weighted_avg = weighted_avg.reshape(weighted_avg.shape[:-2] + (-1,)) + gate_value = self.gating_query(act) + weighted_avg *= mint.sigmoid(gate_value.astype(ms.float32) + ).astype(gate_value.dtype) + return self.output_projection(weighted_avg) + + def construct(self, act, pair_mask): + """Builds a module. + + Arguments: + act: [num_seq, num_res, channels] activations tensor + pair_mask: [num_seq, num_res] mask of non-padded regions in the tensor. + Only used in inducing points attention currently. + + Returns: + Result of the self-attention operation. + """ + assert len(act.shape) == 3 + assert len(pair_mask.shape) == 2 + pair_mask = mint.swapaxes(pair_mask, -1, -2) + act = self.act_norm(act) + + non_batched_bias = self.pair_bias_projection(act) + non_batched_bias = non_batched_bias.transpose(2, 0, 1) + if self.transpose: + act = mint.swapaxes(act, -2, -3) + pair_mask = pair_mask[:, None, None, :].astype(ms.bool_) + act = self._attention(act, pair_mask, non_batched_bias) + if self.transpose: + act = mint.swapaxes(act, -2, -3) + return act + + +class TriangleMultiplication(nn.Cell): + """ + Implements triangle multiplication for tensor operations. + + Args: + config (Config): Configuration object specifying the equation and whether to use a GLU kernel. + global_config (GlobalConfig): Global configuration object. + in_channel (int): Number of input channels. + normalized_shape (tuple): Shape of the input tensor for normalization. + batch_size (int, optional): Batch size for processing. Default: ``None``. + + Inputs: + - **act** (Tensor) - Input activation tensor. + - **mask** (Tensor) - Mask tensor indicating valid regions in the input. 
+ + Outputs: + - **out** (Tensor) - Output tensor after triangle multiplication. + """ + @dataclass + class Config(base_config.BaseConfig): + equation: Literal['ikc,jkc->ijc', 'kjc,kic->ijc'] + use_glu_kernel: bool = True + + def __init__(self, config, global_config, in_channel, normalized_shape, batch_size=None, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.triangle_multi = Triangle( + self.config, + self.global_config, + num_intermediate_channel=in_channel, + equation=self.config.equation, + normalized_shape=normalized_shape, + batch_size=batch_size, + dtype=dtype) + + def construct(self, act, mask): + out = self.triangle_multi(act, mask) + return out + + +class OuterProductMean(nn.Cell): + """ + Implements the OuterProductMean operation for tensor computations. + + Args: + config (Config): Configuration object containing parameters for the operation. + global_config (GlobalConfig): Global configuration object. + num_output_channel (int): Number of output channels. + in_channel (int): Number of input channels. + + Inputs: + - **act** (Tensor) - Input activation tensor. + - **mask** (Tensor) - Mask tensor indicating valid regions in the input. + + Outputs: + - **out** (Tensor) - Output tensor after applying the outer product mean operation. + """ + @dataclass + class Config(base_config.BaseConfig): + chunk_size: int = 128 + num_outer_channel: int = 32 + + def __init__(self, config, global_config, num_output_channel, in_channel, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.num_output_channel = num_output_channel + self.outer_product_mean = ProductMean(self.config.num_outer_channel, + # self.config.chunk_size, + in_channel, + self.num_output_channel, + dtype=dtype) + + def construct(self, act, mask): + mask_norm = ops.expand_dims(mint.matmul(mask.T, mask), -1) + out = self.outer_product_mean(act, mask, mask_norm) + return out + + +class PairFormerIteration(nn.Cell): + """ + Single Iteration of PairFormer, which processes pairwise and single activations in a single iteration. + + Args: + config (PairFormerIteration.Config): Configuration for the PairFormerIteration module. + global_config: Global configuration for the model. + normalized_shape (tuple): Shape of the input tensor for normalization. + single_shape (tuple | None): Shape of the single activation tensor. Default: ``None``. + with_single (bool): Whether to include single activation processing. Default: ``False``. + + Inputs: + - **act** (Tensor) - Pairwise activations tensor. + - **pair_mask** (Tensor) - Padding mask for pairwise activations. + - **single_act** (Tensor | None) - Single activations tensor, optional. + - **seq_mask** (Tensor | None) - Sequence mask, optional. + + Outputs: + - **act** (Tensor) - Processed pairwise activations tensor. + - **single_act** (Tensor) - Processed single activations tensor (if `with_single` is True). 
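+
+    Notes:
+        - When `with_single` is True, the pair activations are layer-normalized and
+          projected to per-head logits that bias the single-track self-attention.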
+ """ + @dataclass + class Config(base_config.BaseConfig): + """Config for PairFormerIteration.""" + num_layer: int = 1 + pair_attention: GridSelfAttention.Config = base_config.autocreate() + pair_transition: TransitionBlock.Config = base_config.autocreate() + single_attention: diffusion_transformer.SelfAttentionConfig | None = base_config.autocreate() # None + single_transition: TransitionBlock.Config | None = base_config.autocreate() # None + triangle_multiplication_incoming: TriangleMultiplication.Config = ( + base_config.autocreate(equation='kjc,kic->ijc') + ) + triangle_multiplication_outgoing: TriangleMultiplication.Config = ( + base_config.autocreate(equation='ikc,jkc->ijc') + ) + shard_transition_blocks: bool = True + + def __init__(self, config, global_config, normalized_shape, single_shape=None, with_single=False, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.with_single = with_single + num_channel = normalized_shape[-1] + self.triangle_multiplication1 = TriangleMultiplication( + self.config.triangle_multiplication_outgoing, + self.global_config, + num_channel, + normalized_shape, + dtype=dtype + ) + self.triangle_multiplication2 = TriangleMultiplication( + self.config.triangle_multiplication_incoming, + self.global_config, + num_channel, + normalized_shape, + dtype=dtype + ) + self.grid_self_attention1 = GridSelfAttention( + self.config.pair_attention, + self.global_config, + False, + normalized_shape, + dtype=dtype + ) + self.grid_self_attention2 = GridSelfAttention( + self.config.pair_attention, + self.global_config, + True, + normalized_shape, + dtype=dtype + ) + self.transition_block = TransitionBlock( + self.config.pair_transition, self.global_config, normalized_shape, dtype=dtype + ) + num_residues = normalized_shape[0] + if self.config.shard_transition_blocks: + self.transition_block = mapping.sharded_apply( + self.transition_block, + get_shard_size( + num_residues, self.global_config.pair_transition_shard_spec + ) + ) + if self.with_single: + assert self.config.single_attention is not None + self.single_pair_logits_projection = bm.CustomDense( + num_channel, self.config.single_attention.num_head, ndim=3, dtype=dtype + ) + self.single_pair_logits_norm = bm.LayerNorm( + normalized_shape, dtype=ms.float32) + self.single_attention = diffusion_transformer.SelfAttention( + self.config.single_attention, self.global_config, single_shape[-1], normalized_shape, with_single_cond=False, dtype=dtype + ) + self.single_transition = TransitionBlock( + self.config.single_transition, + self.global_config, + single_shape, + 2, + dtype=dtype + ) + + def construct(self, act, pair_mask, single_act=None, seq_mask=None): + num_residues = act.shape[0] + act += self.triangle_multiplication1(act, pair_mask) + act += self.triangle_multiplication2(act, pair_mask) + act += self.grid_self_attention1(act, pair_mask) + act += self.grid_self_attention2(act, pair_mask) + act += self.transition_block(act) + if self.with_single: + norm_act = self.single_pair_logits_norm(act) + pair_logits = self.single_pair_logits_projection(norm_act) + pair_logits = pair_logits.transpose((2, 0, 1)) + single_act += self.single_attention( + single_act, seq_mask, None, pair_logits + ) + single_act += self.single_transition(single_act, + broadcast_dim=None) + return act, single_act + else: + return act + + +class EvoformerIteration(nn.Cell): + """ + EvoformerIteration is a single iteration of the Evoformer main stack, which processes + activations and masks through a 
series of attention and transformation layers to
+    update the MSA (Multiple Sequence Alignment) and pair representations.
+
+    Args:
+        config (EvoformerIteration.Config): Configuration for the EvoformerIteration.
+        global_config (base_config.BaseConfig): Global configuration for the model.
+        act_shape (tuple): Shape of the activation tensor.
+        pair_shape (tuple): Shape of the pair tensor.
+
+    Inputs:
+        - **activations** (dict): A dictionary containing the MSA and pair activations.
+        - **masks** (dict): A dictionary containing the MSA and pair masks.
+
+    Outputs:
+        - **activations** (dict): A dictionary containing the updated MSA and pair activations.
+    """
+    @dataclass
+    class Config(base_config.BaseConfig):
+        """Configuration for EvoformerIteration."""
+
+        num_layer: int = 4
+        msa_attention: MSAAttention.Config = base_config.autocreate()
+        outer_product_mean: OuterProductMean.Config = base_config.autocreate()
+        msa_transition: TransitionBlock.Config = base_config.autocreate()
+        pair_attention: GridSelfAttention.Config = base_config.autocreate()
+        pair_transition: TransitionBlock.Config = base_config.autocreate()
+        triangle_multiplication_incoming: TriangleMultiplication.Config = (
+            base_config.autocreate(equation='kjc,kic->ijc')
+        )
+        triangle_multiplication_outgoing: TriangleMultiplication.Config = (
+            base_config.autocreate(equation='ikc,jkc->ijc')
+        )
+        shard_transition_blocks: bool = False
+
+    def __init__(self, config, global_config, act_shape, pair_shape, dtype=ms.float32):
+        super().__init__()
+        self.config = config
+        self.global_config = global_config
+        num_channel = pair_shape[-1]
+        self.outer_product_mean = OuterProductMean(
+            config=self.config.outer_product_mean,
+            global_config=self.global_config,
+            num_output_channel=num_channel,
+            in_channel=act_shape[-1],
+            dtype=dtype
+        )
+        self.msa_attention = MSAAttention(self.config.msa_attention,
+                                          self.global_config, act_shape, pair_shape, dtype=dtype)
+        self.msa_transition = TransitionBlock(
+            self.config.msa_transition, self.global_config, act_shape, dtype=dtype
+        )
+        self.triangle_multiplication1 = TriangleMultiplication(
+            self.config.triangle_multiplication_outgoing,
+            self.global_config,
+            num_channel,
+            pair_shape,
+            dtype=dtype
+        )
+        self.triangle_multiplication2 = TriangleMultiplication(
+            self.config.triangle_multiplication_incoming,
+            self.global_config,
+            num_channel,
+            pair_shape,
+            dtype=dtype
+        )
+        self.pair_attention1 = GridSelfAttention(
+            self.config.pair_attention,
+            self.global_config,
+            False,
+            pair_shape,
+            dtype=dtype
+        )
+        self.pair_attention2 = GridSelfAttention(
+            self.config.pair_attention,
+            self.global_config,
+            True,
+            pair_shape,
+            dtype=dtype
+        )
+        self.transition_block = TransitionBlock(
+            self.config.pair_transition, self.global_config, pair_shape, dtype=dtype
+        )
+        num_residues = act_shape[0]
+        if self.config.shard_transition_blocks:
+            self.transition_block = mapping.sharded_apply(
+                self.transition_block,
+                get_shard_size(
+                    num_residues, self.global_config.pair_transition_shard_spec
+                )
+            )
+
+    def construct(self, activations, masks):
+        msa_act, pair_act = activations["msa"], activations["pair"]
+        msa_mask, pair_mask = masks['msa'], masks['pair']
+        pair_act += self.outer_product_mean(msa_act, msa_mask)
+        msa_act += self.msa_attention(msa_act, msa_mask, pair_act)
+        msa_act += self.msa_transition(msa_act)
+        pair_act += self.triangle_multiplication1(pair_act, pair_mask)
+        pair_act += self.triangle_multiplication2(pair_act, pair_mask)
+        pair_act += self.pair_attention1(pair_act, pair_mask)
+        pair_act
+= self.pair_attention2(pair_act, pair_mask) + pair_act += self.transition_block(pair_act) + return {"msa": msa_act, "pair": pair_act} diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/bias.npy b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/bias.npy new file mode 100644 index 0000000000000000000000000000000000000000..c7cd7468f857d0a2849491f4a3de779471901d38 GIT binary patch literal 1152 zcmbV}>rc~X0EK@ph{O#9Vi-Y?IuSZ_3Y8&B-*XxS>~L0e;w`n5(xL(Ztd&)*S}zk+ z)H>Zn0!9~vsu?nGt4d5W6T{UJMTePjRX0Rn${Z?UbYuU*o^R**bdu+wCSzXuvJq0T zv@yzPTAOW-nk9=;=EOx!kwxVcnl0w6g3Lm*(e$5B&B|YE8un{fWfhr*w_NdQ{FINU z$TrFTH>0Kc@b=1cuZ3T`}_&Y9a1iXZ{<<$vxYvE^SS$A9M8Y4g{$o^ z#DqK-%JeQIlvJVN@ON-dpTO&V3I^~01Tt(E_hO&H@S%kjHR0@dJ%oPOUeS=cjQmdiz^5x5q%xq1a&Ce??Ve>@r zr}NQV)LYDbHzqN(IE!W70W5m6O$7JGQrZ3p?wS_Jzn>4~%~O!M)(Pv%N0|C1lG3_l zWt}F0vQe|)wfz=^!je*rgS z#?gCjFms+1(Dw8n*jBg0TWh5AhcB=yuLW7v8lhZQGd7@w=7|B6U7n@#jY^oq;=VH~{Bfp4$obJ0vS{j92I+?P> zUbxcY>5|8y`)~&BO*fUHEj=iR9m7)pD%k2qGWmWzoP{X}?b`*F;cFrF*+ts!*_1E7 z4n=to%X1ZQpUM&sA6u!eTEnN#TeM!t|U(TeJq&)pJ)Q{D)Z#g8=b?qX_ K*TUf28vY9@O2t3` literal 0 HcmV?d00001 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/weight.npy b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/random/weight.npy new file mode 100644 index 0000000000000000000000000000000000000000..c595d2a5f23945d87a8589c43cb6879ee3a2de48 GIT binary patch literal 1152 zcmbV||3B1s9KchThevwM*9Se!a*G_>^>`H0$mjjKMOW$wY3h+p7k9o!xf50yN;nzS zs8p_eyVV1EOxnub>m^?sCbigDdziH5OS>#n8;jY$u;*{j=a0{`{=57F_82TWyy%D^ zLZ%2$6nOFlUUD~qkS~xYBq|f7@nH#x5wicfSQ@90Ep$bc^q_3Pom|$t3!R1hWBmWk zvh<;GHmJzrd`*gU+g(+7q}Tx*Z}sA-RJo`());~vH2C7Rp-2+1f%=FGP#920=E3tY zY5tJvm#Dy5(Wq@Q*$)MqU$Kl%H|At~k!h9wV6>SJ-$W*%Nyu?B;EAAbbUVh5o#*zM zz6W<#*;1*}g?vgzaJPJf^pRSx*(L#JRM)_oZ%?6q<|ru#52KZDE;WDj8AZiKvEG@R zutwSrRi!h!t8#z#wcSc^3Q1x$wQ*o^Wda7q!@-I=naMyiJ^1Y$+r(=Dqlp^y-jmE8 zjxHtnuZPgO>;UXYv!^WrJ8BgBLQ%aRU7Q?6o0A&2KAEG-3rvU7?p2hd5wOYSjU>#< z2ai^jHu~o?O!cWJiT#`*W%(=G{=gMxXOl_d@e-e@`{>%`38siMqXwCVRc+Lx=8sJ9 z{%$v%HoX96Rb%uyC2?0}D=0tiHUvKY3Js$57~i{`#F8EY4_9{OO&goeE}#srVhm3) zq_-=tQtJj4MAIC05kDqO$tPo-iFY7hFQ#sWv(k^Ba$p=_)YOm}3^PajH1acVWn zHidy#`HzqudJAS9J&3id>HVleuJe*R4Csn=9g(q=WF~BKRu1F3>IRmv!;-!{kqQ|z zrI>bTKZU86HvP5jCG%FdYM;gCu)XF!coPkqu(O=NvX-s~7>Y_Yjxg-`F`Mw2!1#Oi zta0cHuAOis(@KGWlFQ!Y zH&f4=NjzS+l3Qi9I34-0G=cjVoQyhgK_N{=8kpX*1q|68>9|H#Ubc zK?h)2K{=Ht?4o_#W8}s3vm%=;%A2}FMo(C?BsVxu-FeC{hGDa%zqn saH}A{n!9YsX)Rb=KLY!Eb6oW*YhA2WKJ4M0#z%?M7@p}0F2+Lo7aO`FEC2ui literal 0 HcmV?d00001 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py new file mode 100644 index 000000000..f4287f597 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/template_modules.py @@ -0,0 +1,347 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""template modules""" +from dataclasses import dataclass +import mindspore as ms +from mindspore import nn, ops, Tensor, mint + +from alphafold3.model import base_config +from alphafold3.constants import residue_names +from alphafold3.utils import geometry +from alphafold3.model import features +from alphafold3.model import model_config +from alphafold3.model import protein_data_processing +from alphafold3.model.components import base_modules as bm +from alphafold3.model.components import mapping +from alphafold3.model.diffusion import modules +from alphafold3.model.scoring import scoring + + +@dataclass +class DistogramFeaturesConfig(base_config.BaseConfig): + # The left edge of the first bin. + min_bin: float = 3.25 + # The left edge of the final bin. The final bin catches everything larger than + # `max_bin`. + max_bin: float = 50.75 + # The number of bins in the distogram. + num_bins: int = 39 + + +def dgram_from_positions(positions, config, dtype=ms.float32): + """Compute distogram from amino acid positions. + + Args: + positions: (num_res, 3) Position coordinates. + config: Distogram bin configuration. + + Returns: + Distogram with the specified number of bins. + """ + lower_breaks = mint.linspace( + config.min_bin, config.max_bin, config.num_bins) + lower_breaks = mint.square(lower_breaks) + upper_breaks = mint.concat( + [lower_breaks[1:], Tensor([1e8], dtype=ms.float32)], dim=-1) + dist2 = mint.sum(mint.square(ops.expand_dims(positions, axis=-2) + - ops.expand_dims(positions, axis=-3)), dim=-1, keepdim=True) + dgram = (dist2 > lower_breaks).astype(ms.float32) * \ + (dist2 < upper_breaks).astype(ms.float32) + return dgram + + +def slice_index(x, idx): + return ms.ops.gather_d(x, 1, idx.reshape(-1, 1)).squeeze() + + +def make_backbone_rigid(positions, mask, group_indices,): + """Make backbone Rigid3Array and mask. + + Args: + positions: (num_res, num_atoms) of atom positions as Vec3Array. + mask: (num_res, num_atoms) for atom mask. + group_indices: (num_res, num_group, 3) for atom indices forming groups. + + Returns: + tuple of backbone Rigid3Array and mask (num_res,). + """ + backbone_indices = group_indices[:, 0] + + # main backbone frames differ in sidechain frame convention. + # for sidechain it's (C, CA, N), for backbone it's (N, CA, C) + # Hence using c, b, a, each of shape (num_res,). + c, b, a = [backbone_indices[..., i] for i in range(3)] + + rigid_mask = slice_index(mask, a) * \ + slice_index(mask, b) * slice_index(mask, c) + frame_positions = [] + for indices in [a, b, c]: + frame_positions.append(geometry.vector.tree_map( + lambda x, idx=indices: slice_index(x, idx), positions + )) + rotation = geometry.Rot3Array.from_two_vectors( + frame_positions[2] - frame_positions[1], + frame_positions[0] - frame_positions[1], + ) + rigid = geometry.Rigid3Array(rotation, frame_positions[1]) + return rigid, rigid_mask + + +class TemplateEmbedding(nn.Cell): + """ + Embed a set of templates. + + Args: + config (TemplateEmbedding.Config): Configuration for the template embedding. + global_config (base_config.BaseConfig): Global configuration for the model. + num_templates (int): Number of templates to process. + normalized_shape (tuple): Shape of the normalized input tensor. + num_atoms (int): Number of atoms per residue. Default: ``24``. + + Inputs: + - **query_embedding** (Tensor) - Query tensor of shape [num_res, num_res, num_channel]. 
+ - **templates** (Templates) - Object containing template data. + - **padding_mask_2d** (Tensor) - Pair mask for attention operations of shape [num_res, num_res]. + - **multichain_mask_2d** (Tensor) - Pair mask for multichain operations of shape [num_res, num_res]. + - **key** (int) - Random key generator. + + Outputs: + - **embedding** (Tensor) - Output embedding tensor of shape [num_res, num_res, num_channels]. + """ + @dataclass + class Config(base_config.BaseConfig): + num_channels: int = 64 + template_stack: modules.PairFormerIteration.Config = base_config.autocreate( + num_layer=2, + pair_transition=base_config.autocreate(num_intermediate_factor=2), + ) + dgram_features: DistogramFeaturesConfig = base_config.autocreate() + + def __init__(self, config, global_config, num_templates, normalized_shape, num_atoms=24, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.num_residues = normalized_shape[0] + self.num_templates = num_templates + self.query_num_channels = normalized_shape[2] + self.num_atoms = num_atoms + self.template_embedder = SingleTemplateEmbedding( + self.config, self.global_config, normalized_shape, dtype=dtype) + self.output_linear = bm.CustomDense( + self.config.num_channels, self.query_num_channels, ndim=3, dtype=dtype) + self.output_linear.weight = bm.custom_initializer( + 'relu', (self.config.num_channels, self.query_num_channels), dtype=dtype) + + def construct(self, query_embedding, templates, padding_mask_2d, + multichain_mask_2d, key): + """Generate an embedding for a set of templates. + + Args: + query_embedding: [num_res, num_res, num_channel] a query tensor that will + be used to attend over the templates to remove the num_templates + dimension. + templates: A 'Templates' object. + padding_mask_2d: [num_res, num_res] Pair mask for attention operations. + multichain_mask_2d: [num_res, num_res] Pair mask for multichain. + key: random key generator. + + Returns: + An embedding of size [num_res, num_res, num_channels] + """ + assert query_embedding.shape == ( + self.num_residues, + self.num_residues, + self.query_num_channels, + ) + assert templates.aatype.shape == ( + self.num_templates, self.num_residues) + assert templates.atom_positions.shape == ( + self.num_templates, + self.num_residues, + self.num_atoms, + 3, + ) + assert templates.atom_mask.shape == ( + self.num_templates, self.num_residues, self.num_atoms) + assert padding_mask_2d.shape == (self.num_residues, self.num_residues) + subkeys = mint.arange(key, key + self.num_templates, 1) + summed_template_embeddings = mint.zeros( + (self.num_residues, self.num_residues, + self.config.num_channels), dtype=query_embedding.dtype + ) + + def scan_fn(carry, x): + templates, key = x + embedding = self.template_embedder( + query_embedding, + templates, + padding_mask_2d, + multichain_mask_2d, + key, + ) + return carry + embedding + for i in range(len(subkeys)): + summed_template_embeddings = scan_fn( + summed_template_embeddings, (templates[i], subkeys[i])) + embedding = summed_template_embeddings / (1e-7 + self.num_templates) + embedding = mint.nn.functional.relu(embedding) + embedding = self.output_linear(embedding) + assert embedding.shape == ( + self.num_residues, self.num_residues, self.query_num_channels) + return embedding + + +class SingleTemplateEmbedding(nn.Cell): + """ + Embed a single template. + + Args: + config: Configuration object containing model parameters. + global_config: Global configuration object. 
+        normalized_shape (tuple): Shape for normalization layers.
+
+    Inputs:
+        - **query_embedding** (Tensor) - Query embedding tensor of shape (num_res, num_res, num_channels).
+        - **templates** (Templates object) - Object containing single template data.
+        - **padding_mask_2d** (Tensor) - Padding mask tensor.
+        - **multichain_mask_2d** (Tensor) - Mask indicating intra-chain residue pairs.
+        - **key** (random.KeyArray) - Random key generator.
+
+    Outputs:
+        - **output** (Tensor) - Template embedding tensor of shape (num_res, num_res, num_channels).
+    """
+
+    def __init__(
+        self,
+        config,
+        global_config,
+        normalized_shape,
+        dtype=ms.float32
+    ):
+        super().__init__()
+        self.config = config
+        self.global_config = global_config
+        num_channels = self.config.num_channels
+        self.query_embedding_norm = bm.LayerNorm(
+            normalized_shape, dtype=ms.float32)
+
+        # Input channel counts and ndims for the nine template pair features;
+        # the exact input/output shapes and layer count are still to be determined.
+        num_layers = 9
+        in_shape_list = [39, (), 31, 31, (), (), (), (), 128]
+        ndim_list = [3, 2, 3, 3, 2, 2, 2, 2, 3]
+        self.template_pair_embedding = ms.nn.CellList(
+            [
+                bm.CustomDense(
+                    in_shape_list[i], num_channels, weight_init="relu", ndim=ndim_list[i], dtype=dtype
+                )
+                for i in range(num_layers)
+            ]
+        )
+        self.template_stack = ms.nn.CellList(
+            [
+                modules.PairFormerIteration(
+                    self.config.template_stack, self.global_config, normalized_shape[:-1] + (
+                        num_channels,), dtype=dtype
+                )
+                for _ in range(self.config.template_stack.num_layer)
+            ]
+        )
+        self.output_layer_norm = bm.LayerNorm(
+            normalized_shape[:-1] + (num_channels,), dtype=ms.float32)
+
+    def construct(self, query_embedding, templates, padding_mask_2d, multichain_mask_2d, key):
+        assert padding_mask_2d.dtype == query_embedding.dtype
+        dtype = query_embedding.dtype
+        act = self.construct_input(
+            query_embedding, templates, multichain_mask_2d)
+        if self.config.template_stack.num_layer:
+            for i in range(self.config.template_stack.num_layer):
+                act = self.template_stack[i](act, padding_mask_2d)
+        act = self.output_layer_norm(act)
+        return act
+
+    def construct_input(self, query_embedding, templates, multichain_mask_2d):
+        # Compute distogram feature for the template.
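+        # The distogram one-hot encodes pairwise pseudo-beta distances (compared
+        # in squared space) over config.dgram_features bins and is masked to
+        # intra-chain pairs via multichain_mask_2d.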
+ dtype = multichain_mask_2d.dtype + aatype = templates.aatype + dense_atom_mask = templates.atom_mask + dense_atom_positions = templates.atom_positions + dense_atom_positions *= dense_atom_mask[..., None] + pseudo_beta_positions, pseudo_beta_mask = [ms.Tensor(x) for x in scoring.pseudo_beta_fn( + templates.aatype, dense_atom_positions, dense_atom_mask + )] + pseudo_beta_mask_2d = ( + pseudo_beta_mask[:, None] * pseudo_beta_mask[None, :] + ) + pseudo_beta_mask_2d *= multichain_mask_2d + dgram = dgram_from_positions( + pseudo_beta_positions, self.config.dgram_features + ) + dgram *= pseudo_beta_mask_2d[..., None] + dgram = dgram.astype(dtype) + pseudo_beta_mask_2d = pseudo_beta_mask_2d.astype(dtype) + to_concat = [(dgram, 1), (pseudo_beta_mask_2d, 0)] + aatype = mint.nn.functional.one_hot( + aatype.astype(ms.int64), + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP, + ).astype(dtype) + to_concat.append((aatype[None, :, :].astype(dtype), 1)) + to_concat.append((aatype[:, None, :].astype(dtype), 1)) + template_group_indices = mint.index_select( + ms.Tensor(protein_data_processing.RESTYPE_RIGIDGROUP_DENSE_ATOM_IDX), + 0, + templates.aatype, + ) + rigid, backbone_mask = make_backbone_rigid( + geometry.Vec3Array.from_array(dense_atom_positions), + dense_atom_mask, + template_group_indices.astype(ms.int32), + ) # rigid (256,) backbone_mask (256,) + points = rigid.translation + x = rigid.translation.x.unsqueeze(-1) + y = rigid.translation.y.unsqueeze(-1) + z = rigid.translation.z.unsqueeze(-1) + xx = rigid.rotation.xx.unsqueeze(-1) + xy = rigid.rotation.xy.unsqueeze(-1) + xz = rigid.rotation.xz.unsqueeze(-1) + yx = rigid.rotation.yx.unsqueeze(-1) + yy = rigid.rotation.yy.unsqueeze(-1) + yz = rigid.rotation.yz.unsqueeze(-1) + zx = rigid.rotation.zx.unsqueeze(-1) + zy = rigid.rotation.zy.unsqueeze(-1) + zz = rigid.rotation.zz.unsqueeze(-1) + rigid = geometry.Rigid3Array(geometry.Rot3Array( + xx, xy, xz, yx, yy, yz, zx, zy, zz), geometry.Vec3Array(x, y, z)) + rigid_vec = rigid.inverse().apply_to_point(points) + + unit_vector = rigid_vec.normalized() + unit_vector = [unit_vector.x, unit_vector.y, unit_vector.z] + unit_vector = [x.astype(dtype) for x in unit_vector] + backbone_mask = backbone_mask.astype(dtype) + + backbone_mask_2d = backbone_mask[:, None] * backbone_mask[None, :] + backbone_mask_2d *= multichain_mask_2d + unit_vector = [x * backbone_mask_2d for x in unit_vector] + + # Note that the backbone_mask takes into account C, CA and N (unlike + # pseudo beta mask which just needs CB) so we add both masks as features. + to_concat.extend([(x, 0) for x in unit_vector]) + to_concat.append((backbone_mask_2d, 0)) + query_embedding = self.query_embedding_norm(query_embedding) + # Allow the template embedder to see the query embedding. Note this + # contains the position relative feature, so this is how the network knows + # which residues are next to each other. + to_concat.append((query_embedding, 1)) + + act = 0 + for i, (x, n_input_dims) in enumerate(to_concat): + act += self.template_pair_embedding[i](x) + return act diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/triangle.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/triangle.py new file mode 100644 index 000000000..016122105 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/diffusion/triangle.py @@ -0,0 +1,258 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Triangle""" +import numpy as np +import mindspore as ms +import mindspore.nn as nn +import mindspore.common.dtype as mstype +from mindspore import Parameter, mint +from mindspore.common.tensor import Tensor +import mindspore.ops as ops +from mindspore.ops import operations as P +from mindspore.common.initializer import initializer +from mindsponge.common.utils import _memory_reduce +from mindsponge.cell.basic import Attention +from mindsponge.cell.initializer import lecun_init +from mindsponge.cell.mask import MaskedLayerNorm +from mindchemistry.e3.utils import Ncon + +from alphafold3.utils.gated_linear_unit import gated_linear_unit +from alphafold3.model.components.base_modules import LayerNorm, CustomDense + + +class TriangleMultiplication(nn.Cell): + r""" + Triangle multiplication layer. for the detailed implementation process, refer to + `TriangleMultiplication `_. + + The information between the amino acid pair is integrated through the information of three edges ij, ik, jk, and + the result of the dot product between ik and jk is added to the edge of ij. + + Args: + num_intermediate_channel (float): The number of intermediate channel. + equation (str): The equation used in triangle multiplication layer. edge update forms + corresponding to 'incoming' and 'outgoing', + :math:`(ikc,jkc->ijc, kjc,kic->ijc)`. + layer_norm_dim (int): The last dimension length of the layer norm. + batch_size (int): The batch size of parameters in triangle multiplication. Default: ``None``. + + Inputs: + - **pair_act** (Tensor) - Tensor of pair_act. shape :math:`(N{res}, N{res}, layer\_norm\_dim)`. + - **pair_mask** (Tensor) - The mask for TriangleAttention matrix with shape. shape :math:`(N{res}, N{res})`. + - **index** (Tensor) - The index of while loop, only used in case of while control + flow. + + Outputs: + Tensor, the float tensor of the pair_act of the layer with shape :math:`(N{res}, N{res}, layer\_norm\_dim)`. + + Supported Platforms: + ``Ascend`` + + Examples: + >>> import numpy as np + >>> from mindsponge.cell import TriangleMultiplication + >>> from mindspore import dtype as mstype + >>> from mindspore import Tensor + >>> model = TriangleMultiplication(num_intermediate_channel=64, + ... 
equation="ikc,jkc->ijc", layer_norm_dim=64, batch_size=0) + >>> input_0 = Tensor(np.ones((256, 256, 64)), mstype.float32) + >>> input_1 = Tensor(np.ones((256, 256)), mstype.float32) + >>> out = model(input_0, input_1, index=0) + >>> print(out.shape) + (256, 256, 64) + """ + + def __init__(self, config, global_config, num_intermediate_channel, equation, normalized_shape, batch_size=None, dtype=ms.float32): + super().__init__() + self.config = config + self.global_config = global_config + self.num_intermediate_channel = num_intermediate_channel + self.left_norm_input = LayerNorm(normalized_shape, dtype=ms.float32) + self.center_norm = LayerNorm(normalized_shape, dtype=ms.float32) + self.projection = nn.Dense( + normalized_shape[-1], num_intermediate_channel * 2, has_bias=False, dtype=dtype) + self.gate = nn.Dense(normalized_shape[-1], num_intermediate_channel * 2, + weight_init=self.global_config.final_init, has_bias=False, dtype=dtype) + self.output_projection = CustomDense( + normalized_shape[-1], num_intermediate_channel, weight_init=self.global_config.final_init, ndim=3, dtype=dtype) + self.gating_linear = CustomDense( + num_intermediate_channel, num_intermediate_channel, weight_init=self.global_config.final_init, ndim=3, dtype=dtype) + self.weight_glu = mint.stack( + [self.gate.weight.T, self.projection.weight.T], dim=1) + if self.config.equation == "ikc,jkc->ijc": + ncon_list = [[-1, -2, 1], [-1, -3, 1]] + elif self.config.equation == "kjc,kic->ijc": + ncon_list = [[-1, 1, -3], [-1, 1, -2]] + else: + raise ValueError("Not support this equation.") + self.ncon = Ncon(ncon_list) + + def construct(self, act, mask, use_glu=True): + r""" + Builds triangle multiplication module. + + Args: + act(Tensor): Pair activations. Data type is float. + mask(Tensor): Pair mask. Data type is float. + + Returns: + act(Tensor), the shape is same as act_shape[:-1]. + """ + self.weight_glu = mint.stack( + [self.gate.weight.T, self.projection.weight.T], dim=1) + + mask = mask[None, ...] + act = self.left_norm_input(act) + input_act = act + + if use_glu == True: + projection = gated_linear_unit.gated_linear_unit( + x=act, + weight=self.weight_glu, + activation=ms.mint.sigmoid, + implementation=None, + precision=None, + ) + projection = ops.transpose(projection, (2, 0, 1)) + projection *= mask + else: + projection = self.projection(act) + projection = ms.ops.transpose(projection, (2, 0, 1)) + projection *= mask + gate = self.gate(act) + gate = ms.ops.transpose(gate, (2, 0, 1)) + projection *= ms.mint.sigmoid(gate) + projection = projection.reshape( + self.num_intermediate_channel, 2, *projection.shape[1:]) + a, b = projection[:, 0], projection[:, 1] + act = self.ncon([a, b]) + act = self.center_norm(act.transpose((1, 2, 0))) + act = self.output_projection(act) + gate_out = self.gating_linear(input_act) + act *= mint.sigmoid(gate_out.astype(ms.float32)).astype(gate_out.dtype) + return act + + +class OuterProductMean(nn.Cell): + r""" + Computing the correlation of the input tensor along its second dimension, the computed correlation + could be used to update the correlation features(e.g. the Pair representation). + + .. math:: + OuterProductMean(\mathbf{act}) = Linear(flatten(mean(\mathbf{act}\otimes\mathbf{act}))) + + Args: + num_outer_channel (float): The last dimension size of intermediate layer in OuterProductMean. + act_dim (int): The last dimension size of the input act. + num_output_channel (int): The last dimension size of output. 
+
+    def __init__(self, num_outer_channel, act_dim, num_output_channel,
+                 batch_size=None, slice_num=0, dtype=ms.float32):
+        super(OuterProductMean, self).__init__()
+        self.dtype = dtype
+        self.num_output_channel = num_output_channel
+        self.num_outer_channel = num_outer_channel
+        self.layer_norm_input = MaskedLayerNorm()
+        self.matmul_trans_b = P.MatMul(transpose_b=True)
+        self.matmul = P.MatMul()
+        self.batch_matmul_trans_b = P.BatchMatMul(transpose_b=True)
+        self.act_dim = act_dim
+        self.batch_size = batch_size
+        self.slice_num = slice_num
+        self.idx = Tensor(0, mstype.int32)
+        self._init_parameter()
+
+    def construct(self, act, mask, mask_norm, index=None):
+        """Compute the outer product mean."""
+        mask = P.ExpandDims()(mask, -1)
+        act = self.layer_norm_input(
+            act, self.layer_norm_input_gamma, self.layer_norm_input_beta)
+        act_shape = P.Shape()(act)
+        if len(act_shape) != 2:
+            act = P.Reshape()(act, (-1, act_shape[-1]))
+        out_shape = act_shape[:-1] + (-1,)
+        left_act = mask * P.Reshape()(
+            P.BiasAdd()(self.matmul_trans_b(act, self.left_projection_weight),
+                        self.left_projection_bias), out_shape)
+        right_act = mask * P.Reshape()(
+            P.BiasAdd()(self.matmul_trans_b(act, self.right_projection_weight),
+                        self.right_projection_bias), out_shape)
+        _, d, e = right_act.shape
+        batched_inputs = (left_act,)
+        nonbatched_inputs = (right_act, self.linear_output_weight,
+                             self.o_biases, d, e)
+        act = _memory_reduce(self._compute, batched_inputs,
+                             nonbatched_inputs, self.slice_num, 1)
+        epsilon = 1e-3
+        act = P.RealDiv()(act, epsilon + mask_norm)
+        return act
+
+    def _init_parameter(self):
+        """Initialise the parameters."""
+        self.layer_norm_input_gamma = Parameter(
+            Tensor(np.ones((self.act_dim)), self.dtype))
+        self.layer_norm_input_beta = Parameter(
+            Tensor(np.zeros((self.act_dim)), self.dtype))
+        self.left_projection_weight = Parameter(
+            initializer(lecun_init(self.act_dim), [self.num_outer_channel, self.act_dim], self.dtype))
+        self.left_projection_bias = Tensor(
+            np.zeros((self.num_outer_channel)), self.dtype)
+        self.right_projection_weight = Parameter(
+            initializer(lecun_init(self.act_dim), [self.num_outer_channel, self.act_dim], self.dtype))
+        self.right_projection_bias = Tensor(
+            np.zeros((self.num_outer_channel)), self.dtype)
+        self.linear_output_weight = Parameter(
+            Tensor(np.zeros((self.num_outer_channel, self.num_outer_channel, self.num_output_channel)),
+                   self.dtype))
+        self.o_biases = Parameter(
+            Tensor(np.zeros((self.num_output_channel)), self.dtype))
+
+    def _compute(self, left_act, right_act, linear_output_weight, linear_output_bias, d, e):
+        """Compute the outer product mean for one slice."""
+        left_act = left_act.transpose((0, 2, 1))
+        act = Ncon([[1, -2, -4], [1, -1, -3]])([left_act, right_act])
+        act = Ncon([[-1, 1, 2, -2], [1, 2, -3]]
+                   )([act, linear_output_weight]) + linear_output_bias
+        act = P.Transpose()(act, (1, 0, 2))
+        return act
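As a rough cross-check of what ``OuterProductMean`` computes, here is a NumPy sketch under assumed shapes, with the left/right projections folded into the inputs; this is an illustration, not the sliced MindSpore implementation above:

    import numpy as np

    def outer_product_mean(left, right, w, b, mask, eps=1e-3):
        # left, right: (num_seq, num_res, c_outer) already-projected activations;
        # mask: (num_seq, num_res); w: (c_outer, c_outer, c_out); b: (c_out,).
        left = left * mask[..., None]
        right = right * mask[..., None]
        outer = np.einsum('sic,sjd->ijcd', left, right)   # sum over MSA rows
        out = np.einsum('ijcd,cde->ije', outer, w) + b
        norm = np.einsum('si,sj->ij', mask, mask)[..., None]
        return out / (eps + norm)

    left = np.random.rand(8, 16, 4)
    right = np.random.rand(8, 16, 4)
    mask = np.ones((8, 16))
    w = np.random.rand(4, 4, 32)
    b = np.zeros(32)
    print(outer_product_mean(left, right, w, b, mask).shape)  # (16, 16, 32)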
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/feat_batch.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/feat_batch.py
new file mode 100644
index 000000000..516a0d1ea
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/feat_batch.py
@@ -0,0 +1,180 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Batch dataclass."""
+
+import dataclasses
+import mindspore as ms
+from typing_extensions import Self
+from mindspore import Tensor
+from alphafold3.model import features
+
+
+@dataclasses.dataclass
+class Batch:
+    """Dataclass containing the feature batch."""
+
+    msa: features.MSA
+    templates: features.Templates
+    token_features: features.TokenFeatures
+    ref_structure: features.RefStructure
+    predicted_structure_info: features.PredictedStructureInfo
+    polymer_ligand_bond_info: features.PolymerLigandBondInfo
+    ligand_ligand_bond_info: features.LigandLigandBondInfo
+    pseudo_beta_info: features.PseudoBetaInfo
+    atom_cross_att: features.AtomCrossAtt
+    convert_model_output: features.ConvertModelOutput
+    frames: features.Frames
+
+    @property
+    def num_res(self) -> int:
+        return self.token_features.aatype.shape[-1]
+
+    @staticmethod
+    def gather_to_tensor(input_feat):
+        """Wrap the numpy fields of a GatherInfo in mindspore Tensors."""
+        input_feat.gather_idxs = Tensor(input_feat.gather_idxs)
+        input_feat.gather_mask = Tensor(input_feat.gather_mask)
+        input_feat.input_shape = Tensor(input_feat.input_shape)
+
+    @classmethod
+    def from_data_dict(cls, batch: features.BatchDict) -> Self:
+        """Construct a batch object from a dictionary."""
+        return cls(
+            msa=features.MSA.from_data_dict(batch),
+            templates=features.Templates.from_data_dict(batch),
+            token_features=features.TokenFeatures.from_data_dict(batch),
+            ref_structure=features.RefStructure.from_data_dict(batch),
+            predicted_structure_info=features.PredictedStructureInfo.from_data_dict(
+                batch
+            ),
+            polymer_ligand_bond_info=features.PolymerLigandBondInfo.from_data_dict(
+                batch
+            ),
+            ligand_ligand_bond_info=features.LigandLigandBondInfo.from_data_dict(
+                batch
+            ),
+            pseudo_beta_info=features.PseudoBetaInfo.from_data_dict(batch),
+            atom_cross_att=features.AtomCrossAtt.from_data_dict(batch),
+            convert_model_output=features.ConvertModelOutput.from_data_dict(
+                batch),
+            frames=features.Frames.from_data_dict(batch),
+        )
+
+    def as_data_dict(self) -> features.BatchDict:
+        """Converts the batch object to a dictionary."""
+        output = {
+            **self.msa.as_data_dict(),
+            **self.templates.as_data_dict(),
+            **self.token_features.as_data_dict(),
+            **self.ref_structure.as_data_dict(),
+            **self.predicted_structure_info.as_data_dict(),
+            **self.polymer_ligand_bond_info.as_data_dict(),
+            **self.ligand_ligand_bond_info.as_data_dict(),
+            **self.pseudo_beta_info.as_data_dict(),
+            **self.atom_cross_att.as_data_dict(),
+            **self.convert_model_output.as_data_dict(),
+            **self.frames.as_data_dict(),
+        }
+        return output
+
+    def convert_to_tensor(self, dtype=ms.float32):
+        """Wrap the numpy feature arrays in mindspore Tensors."""
+        # msa: features.MSA
+        self.msa.rows = Tensor(self.msa.rows, dtype=ms.int32)
+        self.msa.mask = Tensor(self.msa.mask, dtype=ms.int32)
+        self.msa.deletion_matrix = Tensor(
+            self.msa.deletion_matrix, dtype=dtype)
+        self.msa.deletion_mean = Tensor(self.msa.deletion_mean, dtype=dtype)
+        self.msa.profile = Tensor(self.msa.profile, dtype=dtype)
+        self.msa.num_alignments = Tensor(
+            self.msa.num_alignments, dtype=ms.int32)
+        # templates: features.Templates
+        self.templates.aatype = Tensor(self.templates.aatype, dtype=ms.int32)
+        self.templates.atom_mask = Tensor(
+            self.templates.atom_mask, dtype=ms.int32)
+        self.templates.atom_positions = Tensor(
+            self.templates.atom_positions, dtype=dtype)
+        # token_features: features.TokenFeatures
+        self.token_features.mask = Tensor(
+            self.token_features.mask, dtype=ms.int32)
+        self.token_features.token_index = Tensor(
+            self.token_features.token_index, dtype=ms.int32)
+        self.token_features.asym_id = Tensor(
+            self.token_features.asym_id, dtype=ms.int32)
+        self.token_features.aatype = Tensor(
+            self.token_features.aatype, dtype=ms.int32)
+        self.token_features.residue_index = Tensor(
+            self.token_features.residue_index, dtype=ms.int32)
+        self.token_features.entity_id = Tensor(
+            self.token_features.entity_id, dtype=ms.int32)
+        self.token_features.sym_id = Tensor(
+            self.token_features.sym_id, dtype=ms.int32)
+        # ref_structure: features.RefStructure
+        self.ref_structure.positions = Tensor(
+            self.ref_structure.positions, dtype=dtype)
+        self.ref_structure.mask = Tensor(self.ref_structure.mask, dtype=dtype)
+        self.ref_structure.element = Tensor(
+            self.ref_structure.element, dtype=ms.int32)
+        self.ref_structure.charge = Tensor(
+            self.ref_structure.charge, dtype=dtype)
+        self.ref_structure.atom_name_chars = Tensor(
+            self.ref_structure.atom_name_chars, dtype=ms.int32)
+        self.ref_structure.ref_space_uid = Tensor(
+            self.ref_structure.ref_space_uid, dtype=dtype)
+
+        # predicted_structure_info: features.PredictedStructureInfo
+        self.predicted_structure_info.atom_mask = Tensor(
+            self.predicted_structure_info.atom_mask, dtype=dtype)
+
+        # polymer_ligand_bond_info: features.PolymerLigandBondInfo
+        self.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds.gather_idxs = Tensor(
+            self.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds.gather_idxs, dtype=ms.int32
+        )
+        self.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds.gather_mask = Tensor(
+            self.polymer_ligand_bond_info.tokens_to_polymer_ligand_bonds.gather_mask, dtype=ms.int32
+        )
+        # ligand_ligand_bond_info: features.LigandLigandBondInfo
+        self.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds.gather_idxs = Tensor(
+            self.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds.gather_idxs, dtype=ms.int32
+        )
+        self.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds.gather_mask = Tensor(
+            self.ligand_ligand_bond_info.tokens_to_ligand_ligand_bonds.gather_mask, dtype=ms.int32
+        )
+
+        self.gather_to_tensor(self.pseudo_beta_info.token_atoms_to_pseudo_beta)
+        self.gather_to_tensor(self.atom_cross_att.queries_to_keys)
+        self.gather_to_tensor(self.atom_cross_att.tokens_to_queries)
+        self.gather_to_tensor(self.atom_cross_att.tokens_to_keys)
+        self.gather_to_tensor(self.atom_cross_att.token_atoms_to_queries)
+        self.gather_to_tensor(self.atom_cross_att.queries_to_token_atoms)
+
+        # frames: features.Frames is left unconverted here.
+
+    def astype(self, dtype=ms.float32):
+        """Cast the floating-point features to the given dtype."""
+        # msa: features.MSA
+        self.msa.deletion_matrix = self.msa.deletion_matrix.astype(dtype)
+        self.msa.deletion_mean = self.msa.deletion_mean.astype(dtype)
+        self.msa.profile = self.msa.profile.astype(dtype)
+        # templates: features.Templates
+        self.templates.atom_positions = self.templates.atom_positions.astype(
+            dtype)
+        # ref_structure: features.RefStructure
+        self.ref_structure.positions = self.ref_structure.positions.astype(
+            dtype)
+        self.ref_structure.mask = self.ref_structure.mask.astype(dtype)
+        self.ref_structure.charge = self.ref_structure.charge.astype(dtype)
+        self.ref_structure.ref_space_uid = self.ref_structure.ref_space_uid.astype(
+            dtype)
+
+        # predicted_structure_info: features.PredictedStructureInfo
+        self.predicted_structure_info.atom_mask = self.predicted_structure_info.atom_mask.astype(
+            dtype)
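To make the tensor-wrapping pattern in ``gather_to_tensor`` concrete, here is a minimal, self-contained sketch; ``GatherInfoStub`` is a hypothetical stand-in for the real ``GatherInfo`` in ``alphafold3.model.atom_layout``:

    import dataclasses
    import numpy as np
    from mindspore import Tensor

    @dataclasses.dataclass
    class GatherInfoStub:
        gather_idxs: np.ndarray
        gather_mask: np.ndarray
        input_shape: np.ndarray

    info = GatherInfoStub(
        gather_idxs=np.array([2, 0, 1]),
        gather_mask=np.array([True, True, False]),
        input_shape=np.array([3]),
    )
    # Mirrors Batch.gather_to_tensor: each numpy field is wrapped in a Tensor
    # in place, so downstream MindSpore ops can consume the layout directly.
    info.gather_idxs = Tensor(info.gather_idxs)
    info.gather_mask = Tensor(info.gather_mask)
    info.input_shape = Tensor(info.input_shape)
    print(type(info.gather_idxs))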
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/features.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/features.py
new file mode 100644
index 000000000..dffa93925
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/features.py
@@ -0,0 +1,2101 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+"""Data-side of the input features processing."""
+
+import dataclasses
+import datetime
+import itertools
+import numpy as np
+from typing_extensions import Any, Self, TypeAlias
+from absl import logging
+from alphafold3 import structure
+from alphafold3.common import folding_input
+from alphafold3.constants import chemical_components
+from alphafold3.constants import mmcif_names
+from alphafold3.constants import periodic_table
+from alphafold3.constants import residue_names
+from alphafold3.data import msa as msa_module
+from alphafold3.data import templates
+from alphafold3.data.tools import rdkit_utils
+from alphafold3.model import data3
+from alphafold3.model import data_constants
+from alphafold3.model import merging_features
+from alphafold3.model import msa_pairing
+from alphafold3.model.atom_layout import atom_layout
+from alphafold3.structure import chemical_components as struc_chem_comps
+from rdkit import Chem
+from rdkit.Chem import AllChem
+
+
+xnp_ndarray: TypeAlias = np.ndarray  # pylint: disable=invalid-name
+BatchDict: TypeAlias = dict[str, xnp_ndarray]
+
+_STANDARD_RESIDUES = frozenset({
+    *residue_names.PROTEIN_TYPES_WITH_UNKNOWN,
+    *residue_names.NUCLEIC_TYPES_WITH_2_UNKS,
+})
+
+
+@dataclasses.dataclass
+class PaddingShapes:
+    num_tokens: int
+    msa_size: int
+    num_chains: int
+    num_templates: int
+    num_atoms: int
+
+
+def _pad_to(
+    arr: np.ndarray, shape: tuple[int | None, ...], **kwargs
+) -> np.ndarray:
+    """Pads an array to a given shape. Wrapper around np.pad().
+
+    Args:
+        arr: numpy array to pad.
+        shape: target shape; use None for axes that should stay the same.
+        **kwargs: additional args for np.pad, e.g. constant_values=-1.
+
+    Returns:
+        The padded array.
+
+    Raises:
+        ValueError: if arr and shape have a different number of axes, or if
+            the target shape is smaller than the input along any axis.
+    """
+    if arr.ndim != len(shape):
+        raise ValueError(
+            f'arr and shape have different number of axes. {arr.shape=}, {shape=}'
+        )
+
+    num_pad = []
+    for axis, width in enumerate(shape):
+        if width is None:
+            num_pad.append((0, 0))
+        elif width >= arr.shape[axis]:
+            num_pad.append((0, width - arr.shape[axis]))
+        else:
+            raise ValueError(
+                f'Cannot pad to a smaller shape. {arr.shape=}, {shape=}'
+            )
+    padded_arr = np.pad(arr, pad_width=num_pad, **kwargs)
+    return padded_arr
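A quick usage note for ``_pad_to`` (run in the module context above; values invented for illustration):

    import numpy as np

    arr = np.arange(6).reshape(2, 3)
    padded = _pad_to(arr, (4, None), constant_values=-1)
    # Axis 0 grows from 2 to 4 and the new rows are filled with -1;
    # None leaves axis 1 unchanged.
    print(padded.shape)  # (4, 3)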
+
+
+def _unwrap(obj):
+    """Unwrap an object from a zero-dim np.ndarray."""
+    if isinstance(obj, np.ndarray) and obj.ndim == 0:
+        return obj.item()
+    else:
+        return obj
+
+
+@dataclasses.dataclass
+class Chains:
+    chain_id: np.ndarray
+    asym_id: np.ndarray
+    entity_id: np.ndarray
+    sym_id: np.ndarray
+
+
+def _compute_asym_entity_and_sym_id(
+    all_tokens: atom_layout.AtomLayout,
+) -> Chains:
+    """Compute asym_id, entity_id and sym_id.
+
+    Args:
+        all_tokens: atom layout containing a representative atom for each token.
+
+    Returns:
+        A Chains object.
+    """
+
+    # Find identical sequences and assign entity_id and sym_id to every chain.
+    seq_to_entity_id_sym_id = {}
+    seen_chain_ids = set()
+    chain_ids = []
+    asym_ids = []
+    entity_ids = []
+    sym_ids = []
+    for chain_id in all_tokens.chain_id:
+        if chain_id not in seen_chain_ids:
+            asym_id = len(seen_chain_ids) + 1
+            seen_chain_ids.add(chain_id)
+            seq = ','.join(
+                all_tokens.res_name[all_tokens.chain_id == chain_id])
+            if seq not in seq_to_entity_id_sym_id:
+                entity_id = len(seq_to_entity_id_sym_id) + 1
+                sym_id = 1
+            else:
+                entity_id, sym_id = seq_to_entity_id_sym_id[seq]
+                sym_id += 1
+            seq_to_entity_id_sym_id[seq] = (entity_id, sym_id)
+
+            chain_ids.append(chain_id)
+            asym_ids.append(asym_id)
+            entity_ids.append(entity_id)
+            sym_ids.append(sym_id)
+
+    return Chains(
+        chain_id=np.array(chain_ids),
+        asym_id=np.array(asym_ids),
+        entity_id=np.array(entity_ids),
+        sym_id=np.array(sym_ids),
+    )
+
+
+def tokenizer(
+    flat_output_layout: atom_layout.AtomLayout,
+    ccd: chemical_components.Ccd,
+    max_atoms_per_token: int,
+    flatten_non_standard_residues: bool,
+    logging_name: str,
+) -> tuple[atom_layout.AtomLayout, atom_layout.AtomLayout, np.ndarray]:
+    """Maps a flat atom layout to tokens for the evoformer.
+
+    Creates the evoformer tokens as one token per polymer residue and one token
+    per ligand atom. The tokens are represented as the AtomLayouts all_tokens,
+    of shape (num_tokens,) with one representative atom per token, and
+    all_token_atoms_layout, of shape (num_tokens, max_atoms_per_token). The
+    atoms in a residue token use the layout of the corresponding CCD entry.
+
+    Args:
+        flat_output_layout: flat AtomLayout containing all atoms that the model
+            wants to predict.
+        ccd: The chemical components dictionary.
+        max_atoms_per_token: number of slots per token.
+        flatten_non_standard_residues: whether to flatten non-standard residues,
+            i.e. whether to use one token per atom for non-standard residues.
+        logging_name: logging name for debugging (usually the mmcif_id).
+
+    Returns:
+        A tuple (all_tokens, all_token_atoms_layout, standard_token_idxs) with
+        all_tokens: AtomLayout of shape (num_tokens,) containing one
+            representative atom per token.
+        all_token_atoms_layout: AtomLayout of shape
+            (num_tokens, max_atoms_per_token) containing all atoms per token.
+        standard_token_idxs: The token index that each token would have if
+            non-standard residues were not flattened.
+    """
+    # Select the representative atom for each token.
+    token_idxs = []
+    single_atom_token = []
+    standard_token_idxs = []
+    current_standard_token_id = 0
+    # Iterate over residues, providing a group_iter over the atoms of each
+    # residue.
+    for key, group_iter in itertools.groupby(
+        zip(
+            flat_output_layout.chain_type,
+            flat_output_layout.chain_id,
+            flat_output_layout.res_id,
+            flat_output_layout.res_name,
+            flat_output_layout.atom_name,
+            np.arange(flat_output_layout.shape[0]),
+        ),
+        key=lambda x: x[:3],
+    ):
+
+        # Get chain type and chain id of this residue.
+        chain_type, chain_id, _ = key
+
+        # Get names and global idxs for all atoms of this residue.
+        _, _, _, res_names, atom_names, idxs = zip(*group_iter)
+
+        # As of March 2023, all OTHER CHAINs in pdb are artificial nucleics.
+        is_nucleic_backbone = (
+            chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES
+            or chain_type == mmcif_names.OTHER_CHAIN
+        )
+        if chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES:
+            res_name = res_names[0]
+            if (
+                flatten_non_standard_residues
+                and res_name not in residue_names.PROTEIN_TYPES_WITH_UNKNOWN
+                and res_name != residue_names.MSE
+            ):
+                # For non-standard protein residues take all atoms.
+                # NOTE: This may get very large if we include hydrogens.
+ token_idxs.extend(idxs) + single_atom_token += [True] * len(idxs) + standard_token_idxs.extend( + [current_standard_token_id] * len(idxs)) + else: + # For standard protein residues take 'CA' if it exists, else first atom. + if 'CA' in atom_names: + token_idxs.append(idxs[atom_names.index('CA')]) + else: + token_idxs.append(idxs[0]) + single_atom_token += [False] + standard_token_idxs.append(current_standard_token_id) + current_standard_token_id += 1 + elif is_nucleic_backbone: + res_name = res_names[0] + if ( + flatten_non_standard_residues + and res_name not in residue_names.NUCLEIC_TYPES_WITH_2_UNKS + ): + # For non-standard nucleic residues take all atoms. + token_idxs.extend(idxs) + single_atom_token += [True] * len(idxs) + standard_token_idxs.extend( + [current_standard_token_id] * len(idxs)) + else: + # For standard nucleic residues take C1' if it exists, else first atom. + if "C1'" in atom_names: + token_idxs.append(idxs[atom_names.index("C1'")]) + else: + token_idxs.append(idxs[0]) + single_atom_token += [False] + standard_token_idxs.append(current_standard_token_id) + current_standard_token_id += 1 + elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES: + # For non-polymers take all atoms + token_idxs.extend(idxs) + single_atom_token += [True] * len(idxs) + standard_token_idxs.extend([current_standard_token_id] * len(idxs)) + current_standard_token_id += len(idxs) + else: + # Chain type that we don't handle yet. + logging.warning( + '%s: ignoring chain %s with chain type %s.', + logging_name, + chain_id, + chain_type, + ) + + assert len(token_idxs) == len(single_atom_token) + assert len(token_idxs) == len(standard_token_idxs) + standard_token_idxs = np.array(standard_token_idxs, dtype=np.int32) + + # Create the list of all tokens, represented as a flat AtomLayout with 1 + # representative atom per token. + all_tokens = flat_output_layout[token_idxs] + + # Create the 2D atoms_per_token layout + num_tokens = all_tokens.shape[0] + + # Target lists. + target_atom_names = [] + target_atom_elements = [] + target_res_ids = [] + target_res_names = [] + target_chain_ids = [] + target_chain_types = [] + + # uids of all atoms in the flat layout, to check whether the dense atoms + # exist -- This is necessary for terminal atoms (e.g. 'OP3' or 'OXT') + all_atoms_uids = set( + zip( + flat_output_layout.chain_id, + flat_output_layout.res_id, + flat_output_layout.atom_name, + ) + ) + + for idx, single_atom in enumerate(single_atom_token): + if not single_atom: + # Standard protein and nucleic residues have many atoms per token + chain_id = all_tokens.chain_id[idx] + res_id = all_tokens.res_id[idx] + res_name = all_tokens.res_name[idx] + atom_names = [] + atom_elements = [] + + res_atoms = struc_chem_comps.get_all_atoms_in_entry( + ccd=ccd, res_name=res_name + ) + atom_names_elements = list( + zip( + res_atoms['_chem_comp_atom.atom_id'], + res_atoms['_chem_comp_atom.type_symbol'], + strict=True, + ) + ) + + for atom_name, atom_element in atom_names_elements: + # Remove hydrogens if they are not in flat layout. + if atom_element in ['H', 'D'] and ( + (chain_id, res_id, atom_name) not in all_atoms_uids + ): + continue + elif (chain_id, res_id, atom_name) in all_atoms_uids: + atom_names.append(atom_name) + atom_elements.append(atom_element) + # Leave spaces for OXT etc. 
+ else: + atom_names.append('') + atom_elements.append('') + + if len(atom_names) > max_atoms_per_token: + logging.warning( + 'Atom list for chain %s ' + 'residue %s %s is too long and will be truncated: ' + '%s to the max atoms limit %s. Dropped atoms: %s', + chain_id, + res_id, + res_name, + len(atom_names), + max_atoms_per_token, + list( + zip( + atom_names[max_atoms_per_token:], + atom_elements[max_atoms_per_token:], + strict=True, + ) + ), + ) + atom_names = atom_names[:max_atoms_per_token] + atom_elements = atom_elements[:max_atoms_per_token] + + num_pad = max_atoms_per_token - len(atom_names) + atom_names.extend([''] * num_pad) + atom_elements.extend([''] * num_pad) + + else: + # ligands have only 1 atom per token + padding = [''] * (max_atoms_per_token - 1) + atom_names = [all_tokens.atom_name[idx]] + padding + atom_elements = [all_tokens.atom_element[idx]] + padding + + # Append the atoms to the target lists. + target_atom_names.append(atom_names) + target_atom_elements.append(atom_elements) + target_res_names.append( + [all_tokens.res_name[idx]] * max_atoms_per_token) + target_res_ids.append([all_tokens.res_id[idx]] * max_atoms_per_token) + target_chain_ids.append( + [all_tokens.chain_id[idx]] * max_atoms_per_token) + target_chain_types.append( + [all_tokens.chain_type[idx]] * max_atoms_per_token + ) + + # Make sure to get the right shape also for 0 tokens + trg_shape = (num_tokens, max_atoms_per_token) + all_token_atoms_layout = atom_layout.AtomLayout( + atom_name=np.array(target_atom_names, dtype=object).reshape(trg_shape), + atom_element=np.array(target_atom_elements, dtype=object).reshape( + trg_shape + ), + res_name=np.array(target_res_names, dtype=object).reshape(trg_shape), + res_id=np.array(target_res_ids, dtype=int).reshape(trg_shape), + chain_id=np.array(target_chain_ids, dtype=object).reshape(trg_shape), + chain_type=np.array(target_chain_types, + dtype=object).reshape(trg_shape), + ) + + return all_tokens, all_token_atoms_layout, standard_token_idxs + + +@dataclasses.dataclass +class MSA: + """Dataclass containing MSA.""" + + rows: xnp_ndarray + mask: xnp_ndarray + deletion_matrix: xnp_ndarray + # Occurrence of each residue type along the sequence, averaged over MSA rows. + profile: xnp_ndarray + # Occurrence of deletions along the sequence, averaged over MSA rows. + deletion_mean: xnp_ndarray + # Number of MSA alignments. 
+ num_alignments: xnp_ndarray + + @classmethod + def compute_features( + cls, + *, + all_tokens: atom_layout.AtomLayout, + standard_token_idxs: np.ndarray, + padding_shapes: PaddingShapes, + fold_input: folding_input.Input, + logging_name: str, + max_paired_sequence_per_species: int, + ) -> Self: + """Compute the msa features.""" + seen_entities = {} + + substruct = atom_layout.make_structure( + flat_layout=all_tokens, + atom_coords=np.zeros(all_tokens.shape + (3,)), + name=logging_name, + ) + prot = substruct.filter_to_entity_type(protein=True) + num_unique_chains = len( + set(prot.chain_single_letter_sequence().values())) + need_msa_pairing = num_unique_chains > 1 + + np_chains_list = [] + input_chains_by_id = {chain.id: chain for chain in fold_input.chains} + nonempty_chain_ids = set(all_tokens.chain_id) + for asym_id, chain_info in enumerate(substruct.iter_chains(), start=1): + b_chain_id = chain_info['chain_id'] + chain_type = chain_info['chain_type'] + chain = input_chains_by_id[b_chain_id] + + # Generalised "sequence" for ligands (can't trust residue name) + chain_tokens = all_tokens[all_tokens.chain_id == b_chain_id] + assert chain_tokens.res_name is not None + three_letter_sequence = ','.join(chain_tokens.res_name.tolist()) + chain_num_tokens = len(chain_tokens.atom_name) + if chain_type in mmcif_names.POLYMER_CHAIN_TYPES: + sequence = substruct.chain_single_letter_sequence()[b_chain_id] + if chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES: + # Only allow nucleic residue types for nucleic chains (can have some + # protein residues in e.g. tRNA, but that causes MSA search failures). + # Replace non nucleic residue types by UNK_NUCLEIC. + nucleic_types_one_letter = ( + residue_names.DNA_TYPES_ONE_LETTER + + residue_names.RNA_TYPES_ONE_LETTER_WITH_UNKNOWN + ) + sequence = ''.join([ + base + if base in nucleic_types_one_letter + else residue_names.UNK_NUCLEIC_ONE_LETTER + for base in sequence + ]) + else: + sequence = 'X' * chain_num_tokens + + skip_chain = ( + chain_type not in mmcif_names.STANDARD_POLYMER_CHAIN_TYPES + or len(sequence) <= 4 + or b_chain_id not in nonempty_chain_ids + ) + if three_letter_sequence in seen_entities: + entity_id = seen_entities[three_letter_sequence] + else: + entity_id = len(seen_entities) + 1 + + if chain_type in mmcif_names.STANDARD_POLYMER_CHAIN_TYPES: + unpaired_a3m = '' + paired_a3m = '' + if not skip_chain: + if need_msa_pairing and isinstance(chain, folding_input.ProteinChain): + paired_a3m = chain.paired_msa + if isinstance( + chain, folding_input.RnaChain | folding_input.ProteinChain + ): + unpaired_a3m = chain.unpaired_msa + unpaired_msa = msa_module.Msa.from_a3m( + query_sequence=sequence, + chain_poly_type=chain_type, + a3m=unpaired_a3m, + deduplicate=True, + ) + + paired_msa = msa_module.Msa.from_a3m( + query_sequence=sequence, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + a3m=paired_a3m, + deduplicate=False, + ) + else: + unpaired_msa = msa_module.Msa.from_empty( + query_sequence='-' * len(sequence), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + paired_msa = msa_module.Msa.from_empty( + query_sequence='-' * len(sequence), + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + ) + + msa_features = unpaired_msa.featurize() + all_seqs_msa_features = paired_msa.featurize() + + msa_features = data3.fix_features(msa_features) + all_seqs_msa_features = data3.fix_features(all_seqs_msa_features) + + msa_features = msa_features | { + f'{k}_all_seq': v for k, v in all_seqs_msa_features.items() + } + feats = msa_features + feats['chain_id'] = 
b_chain_id + feats['asym_id'] = np.full(chain_num_tokens, asym_id) + feats['entity_id'] = entity_id + np_chains_list.append(feats) + + # Add profile features to each chain. + for chain in np_chains_list: + chain.update( + data3.get_profile_features( + chain['msa'], chain['deletion_matrix']) + ) + + # Allow 50% of the MSA to come from MSA pairing. + max_paired_sequences = padding_shapes.msa_size // 2 + if need_msa_pairing: + np_chains_list = list(map(dict, np_chains_list)) + np_chains_list = msa_pairing.create_paired_features( + np_chains_list, + max_paired_sequences=max_paired_sequences, + nonempty_chain_ids=nonempty_chain_ids, + max_hits_per_species=max_paired_sequence_per_species, + ) + np_chains_list = msa_pairing.deduplicate_unpaired_sequences( + np_chains_list + ) + + # Remove all gapped rows from all seqs. + nonempty_asym_ids = [] + for chain in np_chains_list: + if chain['chain_id'] in nonempty_chain_ids: + nonempty_asym_ids.append(chain['asym_id'][0]) + if 'msa_all_seq' in np_chains_list[0]: + np_chains_list = msa_pairing.remove_all_gapped_rows_from_all_seqs( + np_chains_list, asym_ids=nonempty_asym_ids + ) + + # Crop MSA rows. + cropped_chains_list = [] + for chain in np_chains_list: + unpaired_msa_size, paired_msa_size = ( + msa_pairing.choose_paired_unpaired_msa_crop_sizes( + unpaired_msa=chain['msa'], + paired_msa=chain.get('msa_all_seq'), + total_msa_crop_size=padding_shapes.msa_size, + max_paired_sequences=max_paired_sequences, + ) + ) + cropped_chain = { + 'asym_id': chain['asym_id'], + 'chain_id': chain['chain_id'], + 'profile': chain['profile'], + 'deletion_mean': chain['deletion_mean'], + } + for feat in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES: + if feat in chain: + cropped_chain[feat] = chain[feat][:unpaired_msa_size] + if feat + '_all_seq' in chain: + cropped_chain[feat + '_all_seq'] = chain[feat + '_all_seq'][ + :paired_msa_size + ] + cropped_chains_list.append(cropped_chain) + + # Merge Chains. + # Make sure the chain order is unaltered before slicing with tokens. + curr_chain_order = [chain['chain_id'] for chain in cropped_chains_list] + orig_chain_order = [chain['chain_id'] + for chain in substruct.iter_chains()] + assert curr_chain_order == orig_chain_order + np_example = { + 'asym_id': np.concatenate( + [c['asym_id'] for c in cropped_chains_list], axis=0 + ), + } + for feature in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES: + for feat in [feature, feature + '_all_seq']: + if feat in cropped_chains_list[0]: + np_example[feat] = merging_features.merge_msa_features( + feat, cropped_chains_list + ) + for feature in ['profile', 'deletion_mean']: + feature_list = [c[feature] for c in cropped_chains_list] + np_example[feature] = np.concatenate(feature_list, axis=0) + + # Crop MSA rows to maximum size given by chains participating in the crop. + max_allowed_unpaired = max([ + len(chain['msa']) + for chain in cropped_chains_list + if chain['asym_id'][0] in nonempty_asym_ids + ]) + np_example['msa'] = np_example['msa'][:max_allowed_unpaired] + if 'msa_all_seq' in np_example: + max_allowed_paired = max([ + len(chain['msa_all_seq']) + for chain in cropped_chains_list + if chain['asym_id'][0] in nonempty_asym_ids + ]) + np_example['msa_all_seq'] = np_example['msa_all_seq'][:max_allowed_paired] + + np_example = merging_features.merge_paired_and_unpaired_msa(np_example) + + # Crop MSA residues. Need to use the standard token indices, since msa does + # not expand non-standard residues. This means that for expanded residues, + # we get repeated msa columns. 
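As a toy illustration of the column gather described in the comment above (values invented), indexing with ``standard_token_idxs`` repeats the MSA column of every expanded residue:

    import numpy as np

    msa = np.array([[1, 2, 3],
                    [4, 5, 6]])                   # 2 rows x 3 standard tokens
    standard_token_idxs = np.array([0, 1, 1, 2])  # token 1 expanded to 2 tokens
    print(msa[:, standard_token_idxs])
    # [[1 2 2 3]
    #  [4 5 5 6]]  -> the expanded residue repeats its MSA column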
+ new_cropping_idxs = standard_token_idxs + for feature in data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES: + if feature in np_example: + np_example[feature] = np_example[feature][:, + new_cropping_idxs].copy() + for feature in ['profile', 'deletion_mean']: + np_example[feature] = np_example[feature][new_cropping_idxs] + + # Make MSA mask. + np_example['msa_mask'] = np.ones_like( + np_example['msa'], dtype=np.float32) + + # Count MSA size before padding. + num_alignments = np_example['msa'].shape[0] + + # Pad: + msa_size, num_tokens = padding_shapes.msa_size, padding_shapes.num_tokens + + def safe_cast_int8(x): + return np.clip(x, np.iinfo(np.int8).min, np.iinfo(np.int8).max).astype( + np.int8 + ) + + return MSA( + rows=_pad_to(safe_cast_int8( + np_example['msa']), (msa_size, num_tokens)), + mask=_pad_to( + np_example['msa_mask'].astype(bool), (msa_size, num_tokens) + ), + # deletion_matrix may be out of int8 range, but we mostly care about + # small values since we arctan it in the model. + deletion_matrix=_pad_to( + safe_cast_int8(np_example['deletion_matrix']), + (msa_size, num_tokens), + ), + profile=_pad_to(np_example['profile'], (num_tokens, None)), + deletion_mean=_pad_to(np_example['deletion_mean'], (num_tokens,)), + num_alignments=np.array(num_alignments, dtype=np.int32), + ) + + def index_msa_rows(self, indices: xnp_ndarray) -> Self: + assert indices.ndim == 1 + + return MSA( + rows=self.rows[indices, :], + mask=self.mask[indices, :], + deletion_matrix=self.deletion_matrix[indices, :], + profile=self.profile, + deletion_mean=self.deletion_mean, + num_alignments=self.num_alignments, + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + output = cls( + rows=batch['msa'], + mask=batch['msa_mask'], + deletion_matrix=batch['deletion_matrix'], + profile=batch['profile'], + deletion_mean=batch['deletion_mean'], + num_alignments=batch['num_alignments'], + ) + return output + + def as_data_dict(self) -> BatchDict: + return { + 'msa': self.rows, + 'msa_mask': self.mask, + 'deletion_matrix': self.deletion_matrix, + 'profile': self.profile, + 'deletion_mean': self.deletion_mean, + 'num_alignments': self.num_alignments, + } + + +@dataclasses.dataclass +class Templates: + """Dataclass containing templates.""" + + # aatype of templates, int32 w shape [num_templates, num_res] + aatype: xnp_ndarray + # atom positions of templates, float32 w shape [num_templates, num_res, 24, 3] + atom_positions: xnp_ndarray + # atom mask of templates, bool w shape [num_templates, num_res, 24] + atom_mask: xnp_ndarray + def __getitem__(self, idx): + return Templates(self.aatype[idx], self.atom_positions[idx], self.atom_mask[idx]) + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + standard_token_idxs: np.ndarray, + padding_shapes: PaddingShapes, + fold_input: folding_input.Input, + max_templates: int, + logging_name: str, + ) -> Self: + """Compute the template features.""" + + seen_entities = {} + polymer_entity_features = {True: {}, False: {}} + + substruct = atom_layout.make_structure( + flat_layout=all_tokens, + atom_coords=np.zeros(all_tokens.shape + (3,)), + name=logging_name, + ) + np_chains_list = [] + + input_chains_by_id = {chain.id: chain for chain in fold_input.chains} + + nonempty_chain_ids = set(all_tokens.chain_id) + for chain_info in substruct.iter_chains(): + chain_id = chain_info['chain_id'] + chain_type = chain_info['chain_type'] + chain = input_chains_by_id[chain_id] + + # Generalised "sequence" for ligands (can't trust residue name) + 
chain_tokens = all_tokens[all_tokens.chain_id == chain_id] + assert chain_tokens.res_name is not None + three_letter_sequence = ','.join(chain_tokens.res_name.tolist()) + chain_num_tokens = len(chain_tokens.atom_name) + + # Don't compute features for chains not included in the crop, or ligands. + skip_chain = ( + chain_type != mmcif_names.PROTEIN_CHAIN + or chain_num_tokens <= 4 # not cache filled + or chain_id not in nonempty_chain_ids + ) + + if three_letter_sequence in seen_entities: + entity_id = seen_entities[three_letter_sequence] + else: + entity_id = len(seen_entities) + 1 + + if entity_id not in polymer_entity_features[skip_chain]: + if skip_chain: + template_features = data3.empty_template_features( + chain_num_tokens) + else: + assert isinstance(chain, folding_input.ProteinChain) + + sorted_features = [] + for template in chain.templates: + struc = structure.from_mmcif( + template.mmcif, + fix_mse_residues=True, + fix_arginines=True, + include_bonds=False, + include_water=False, + # For non-standard polymer chains. + include_other=True, + ) + hit_features = templates.get_polymer_features( + chain=struc, + chain_poly_type=mmcif_names.PROTEIN_CHAIN, + query_sequence_length=len(chain.sequence), + query_to_hit_mapping=dict( + template.query_to_template_map), + ) + sorted_features.append(hit_features) + + template_features = templates.package_template_features( + hit_features=sorted_features, + include_ligand_features=False, + ) + + template_features = data3.fix_template_features( + sequence=chain.sequence, + template_features=template_features, + ) + + template_features = _reduce_template_features( + template_features, max_templates + ) + polymer_entity_features[skip_chain][entity_id] = template_features + + seen_entities[three_letter_sequence] = entity_id + feats = polymer_entity_features[skip_chain][entity_id].copy() + feats['chain_id'] = chain_id + np_chains_list.append(feats) + + # We pad the num_templates dimension before merging, so that different + # chains can be concatenated on the num_res dimension. Masking will be + # applied so that each chains templates can't see each other. + for chain in np_chains_list: + chain['template_aatype'] = _pad_to( + chain['template_aatype'], (max_templates, None) + ) + chain['template_atom_positions'] = _pad_to( + chain['template_atom_positions'], ( + max_templates, None, None, None) + ) + chain['template_atom_mask'] = _pad_to( + chain['template_atom_mask'], (max_templates, None, None) + ) + + # Merge on token dimension. + np_example = { + ft: np.concatenate([c[ft] for c in np_chains_list], axis=1) + for ft in np_chains_list[0] + if ft in data_constants.TEMPLATE_FEATURES + } + + # Crop template data. Need to use the standard token indices, since msa does + # not expand non-standard residues. This means that for expanded residues, + # we get repeated template information. + for feature_name, v in np_example.items(): + np_example[feature_name] = v[:max_templates, + standard_token_idxs, ...] + + # Pad along the token dimension. 
+ templates_features = Templates( + aatype=_pad_to( + np_example['template_aatype'], (None, + padding_shapes.num_tokens) + ), + atom_positions=_pad_to( + np_example['template_atom_positions'], + (None, padding_shapes.num_tokens, None, None), + ), + atom_mask=_pad_to( + np_example['template_atom_mask'].astype(bool), + (None, padding_shapes.num_tokens, None), + ), + ) + return templates_features + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + """Make Template from batch dictionary.""" + return cls( + aatype=batch['template_aatype'], + atom_positions=batch['template_atom_positions'], + atom_mask=batch['template_atom_mask'], + ) + + def as_data_dict(self) -> BatchDict: + return { + 'template_aatype': self.aatype, + 'template_atom_positions': self.atom_positions, + 'template_atom_mask': self.atom_mask, + } + + +def _reduce_template_features( + template_features: data3.FeatureDict, + max_templates: int, +) -> data3.FeatureDict: + """Reduces template features to max num templates and defined feature set.""" + num_templates = template_features['template_aatype'].shape[0] + template_keep_mask = np.arange(num_templates) < max_templates + template_fields = data_constants.TEMPLATE_FEATURES + ( + 'template_release_timestamp', + ) + template_features = { + k: v[template_keep_mask] + for k, v in template_features.items() + if k in template_fields + } + return template_features + + +@dataclasses.dataclass +class TokenFeatures: + """Dataclass containing features for tokens.""" + + residue_index: xnp_ndarray + token_index: xnp_ndarray + aatype: xnp_ndarray + mask: xnp_ndarray + seq_length: xnp_ndarray + + # Chain symmetry identifiers + # for an A3B2 stoichiometry the meaning of these features is as follows: + # asym_id: 1 2 3 4 5 + # entity_id: 1 1 1 2 2 + # sym_id: 1 2 3 1 2 + asym_id: xnp_ndarray + entity_id: xnp_ndarray + sym_id: xnp_ndarray + + # token type features + is_protein: xnp_ndarray + is_rna: xnp_ndarray + is_dna: xnp_ndarray + is_ligand: xnp_ndarray + is_nonstandard_polymer_chain: xnp_ndarray + is_water: xnp_ndarray + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + padding_shapes: PaddingShapes, + ) -> Self: + """Compute the per-token features.""" + + residue_index = all_tokens.res_id.astype(np.int32) + + token_index = np.arange( + 1, len(all_tokens.atom_name) + 1).astype(np.int32) + + aatype = [] + for res_name, chain_type in zip(all_tokens.res_name, all_tokens.chain_type): + if chain_type in mmcif_names.POLYMER_CHAIN_TYPES: + res_name = mmcif_names.fix_non_standard_polymer_res( + res_name=res_name, chain_type=chain_type + ) + if ( + chain_type == mmcif_names.DNA_CHAIN + and res_name == residue_names.UNK_DNA + ): + res_name = residue_names.UNK_NUCLEIC_ONE_LETTER + elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES: + res_name = residue_names.UNK + else: + raise ValueError( + f'Chain type {chain_type} not polymer or ligand.') + aa = residue_names.POLYMER_TYPES_ORDER_WITH_UNKNOWN_AND_GAP[res_name] + aatype.append(aa) + aatype = np.array(aatype, dtype=np.int32) + + mask = np.ones(all_tokens.shape[0], dtype=bool) + chains = _compute_asym_entity_and_sym_id(all_tokens) + m = dict(zip(chains.chain_id, chains.asym_id)) + asym_id = np.array([m[c] for c in all_tokens.chain_id], dtype=np.int32) + + m = dict(zip(chains.chain_id, chains.entity_id)) + entity_id = np.array([m[c] + for c in all_tokens.chain_id], dtype=np.int32) + + m = dict(zip(chains.chain_id, chains.sym_id)) + sym_id = np.array([m[c] for c in all_tokens.chain_id], 
dtype=np.int32) + + seq_length = np.array(all_tokens.shape[0], dtype=np.int32) + + is_protein = all_tokens.chain_type == mmcif_names.PROTEIN_CHAIN + is_rna = all_tokens.chain_type == mmcif_names.RNA_CHAIN + is_dna = all_tokens.chain_type == mmcif_names.DNA_CHAIN + is_ligand = np.isin( + all_tokens.chain_type, list(mmcif_names.LIGAND_CHAIN_TYPES) + ) + standard_polymer_chain = list(mmcif_names.NON_POLYMER_CHAIN_TYPES) + list( + mmcif_names.STANDARD_POLYMER_CHAIN_TYPES + ) + is_nonstandard_polymer_chain = np.isin( + all_tokens.chain_type, standard_polymer_chain, invert=True + ) + is_water = all_tokens.chain_type == mmcif_names.WATER + + return TokenFeatures( + residue_index=_pad_to(residue_index, (padding_shapes.num_tokens,)), + token_index=_pad_to(token_index, (padding_shapes.num_tokens,)), + aatype=_pad_to(aatype, (padding_shapes.num_tokens,)), + mask=_pad_to(mask, (padding_shapes.num_tokens,)), + asym_id=_pad_to(asym_id, (padding_shapes.num_tokens,)), + entity_id=_pad_to(entity_id, (padding_shapes.num_tokens,)), + sym_id=_pad_to(sym_id, (padding_shapes.num_tokens,)), + seq_length=seq_length, + is_protein=_pad_to(is_protein, (padding_shapes.num_tokens,)), + is_rna=_pad_to(is_rna, (padding_shapes.num_tokens,)), + is_dna=_pad_to(is_dna, (padding_shapes.num_tokens,)), + is_ligand=_pad_to(is_ligand, (padding_shapes.num_tokens,)), + is_nonstandard_polymer_chain=_pad_to( + is_nonstandard_polymer_chain, (padding_shapes.num_tokens,) + ), + is_water=_pad_to(is_water, (padding_shapes.num_tokens,)), + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + residue_index=batch['residue_index'], + token_index=batch['token_index'], + aatype=batch['aatype'], + mask=batch['seq_mask'], + entity_id=batch['entity_id'], + asym_id=batch['asym_id'], + sym_id=batch['sym_id'], + seq_length=batch['seq_length'], + is_protein=batch['is_protein'], + is_rna=batch['is_rna'], + is_dna=batch['is_dna'], + is_ligand=batch['is_ligand'], + is_nonstandard_polymer_chain=batch['is_nonstandard_polymer_chain'], + is_water=batch['is_water'], + ) + + def as_data_dict(self) -> BatchDict: + return { + 'residue_index': self.residue_index, + 'token_index': self.token_index, + 'aatype': self.aatype, + 'seq_mask': self.mask, + 'entity_id': self.entity_id, + 'asym_id': self.asym_id, + 'sym_id': self.sym_id, + 'seq_length': self.seq_length, + 'is_protein': self.is_protein, + 'is_rna': self.is_rna, + 'is_dna': self.is_dna, + 'is_ligand': self.is_ligand, + 'is_nonstandard_polymer_chain': self.is_nonstandard_polymer_chain, + 'is_water': self.is_water, + } + + +@dataclasses.dataclass +class PredictedStructureInfo: + """Contains information necessary to work with predicted structure.""" + + atom_mask: xnp_ndarray + residue_center_index: xnp_ndarray + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + all_token_atoms_layout: atom_layout.AtomLayout, + padding_shapes: PaddingShapes, + ) -> Self: + """Compute the PredictedStructureInfo features. + + Args: + all_tokens: flat AtomLayout with 1 representative atom per token, shape + (num_tokens,) + all_token_atoms_layout: AtomLayout for all atoms per token, shape + (num_tokens, max_atoms_per_token) + padding_shapes: padding shapes. + + Returns: + A PredictedStructureInfo object. 
+ """ + atom_mask = _pad_to( + all_token_atoms_layout.atom_name.astype(bool), + (padding_shapes.num_tokens, None), + ) + residue_center_index = np.zeros( + padding_shapes.num_tokens, dtype=np.int32) + for idx in range(all_tokens.shape[0]): + repr_atom = all_tokens.atom_name[idx] + atoms = list(all_token_atoms_layout.atom_name[idx, :]) + if repr_atom in atoms: + residue_center_index[idx] = atoms.index(repr_atom) + else: + # Representative atoms can be missing if cropping the number of atoms + # per residue. + logging.warning( + 'The representative atom in all_tokens (%s) is not in ' + 'all_token_atoms_layout (%s)', + all_tokens[idx: idx + 1], + all_token_atoms_layout[idx, :], + ) + residue_center_index[idx] = 0 + return cls(atom_mask=atom_mask, residue_center_index=residue_center_index) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + atom_mask=batch['pred_dense_atom_mask'], + residue_center_index=batch['residue_center_index'], + ) + + def as_data_dict(self) -> BatchDict: + return { + 'pred_dense_atom_mask': self.atom_mask, + 'residue_center_index': self.residue_center_index, + } + + +@dataclasses.dataclass +class PolymerLigandBondInfo: + """Contains information about polymer-ligand bonds.""" + + tokens_to_polymer_ligand_bonds: atom_layout.GatherInfo + # Gather indices to convert from cropped dense atom layout to bonds layout + # (num_tokens, 2) + token_atoms_to_bonds: atom_layout.GatherInfo + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + all_token_atoms_layout: atom_layout.AtomLayout, + bond_layout: atom_layout.AtomLayout | None, + padding_shapes: PaddingShapes, + ) -> Self: + """Computes the InterChainBondInfo features. + + Args: + all_tokens: AtomLayout for tokens; shape (num_tokens,). + all_token_atoms_layout: Atom Layout for all atoms (num_tokens, + max_atoms_per_token) + bond_layout: Bond layout for polymer-ligand bonds. + padding_shapes: Padding shapes. + + Returns: + A PolymerLigandBondInfo object. + """ + + if bond_layout is not None: + # Must convert to list before calling np.isin, will not work raw. + peptide_types = list(mmcif_names.PEPTIDE_CHAIN_TYPES) + nucleic_types = list(mmcif_names.NUCLEIC_ACID_CHAIN_TYPES) + [ + mmcif_names.OTHER_CHAIN + ] + # These atom renames are so that we can use the atom layout code with + # all_tokens, which only has a single atom per token. + atom_names = bond_layout.atom_name.copy() + atom_names[np.isin(bond_layout.chain_type, peptide_types)] = 'CA' + atom_names[np.isin(bond_layout.chain_type, nucleic_types)] = "C1'" + adjusted_bond_layout = atom_layout.AtomLayout( + atom_name=atom_names, + res_id=bond_layout.res_id, + chain_id=bond_layout.chain_id, + chain_type=bond_layout.chain_type, + ) + # Remove bonds that are not in the crop. + cropped_tokens_to_bonds = atom_layout.compute_gather_idxs( + source_layout=all_tokens, target_layout=adjusted_bond_layout + ) + bond_is_in_crop = np.all( + cropped_tokens_to_bonds.gather_mask, axis=1 + ).astype(bool) + adjusted_bond_layout = adjusted_bond_layout[bond_is_in_crop, :] + else: + # Create layout with correct shape when bond_layout is None. 
+
+
+@dataclasses.dataclass
+class LigandLigandBondInfo:
+    """Contains information about the location of ligand-ligand bonds."""
+
+    tokens_to_ligand_ligand_bonds: atom_layout.GatherInfo
+
+    @classmethod
+    def compute_features(
+        cls,
+        all_tokens: atom_layout.AtomLayout,
+        bond_layout: atom_layout.AtomLayout | None,
+        padding_shapes: PaddingShapes,
+    ) -> Self:
+        """Computes the LigandLigandBondInfo features.
+
+        Args:
+            all_tokens: AtomLayout for tokens; shape (num_tokens,).
+            bond_layout: Bond layout for ligand-ligand bonds.
+            padding_shapes: Padding shapes.
+
+        Returns:
+            A LigandLigandBondInfo object.
+        """
+
+        if bond_layout is not None:
+            # Discard any bonds that do not join to an existing atom.
+            keep_mask = []
+            all_atom_ids = {
+                uid
+                for uid in zip(
+                    all_tokens.chain_id,
+                    all_tokens.res_id,
+                    all_tokens.atom_name,
+                    strict=True,
+                )
+            }
+            for chain_id, res_id, atom_name in zip(
+                bond_layout.chain_id,
+                bond_layout.res_id,
+                bond_layout.atom_name,
+                strict=True,
+            ):
+                atom_a = (chain_id[0], res_id[0], atom_name[0])
+                atom_b = (chain_id[1], res_id[1], atom_name[1])
+                keep_mask.append(atom_a in all_atom_ids and atom_b in all_atom_ids)
+            keep_mask = np.array(keep_mask).astype(bool)
+            bond_layout = bond_layout[keep_mask]
+            # Remove any bonds to Hydrogen atoms.
+            bond_layout = bond_layout[
+                ~np.char.startswith(bond_layout.atom_name.astype(str), 'H').any(
+                    axis=1
+                )
+            ]
+            atom_names = bond_layout.atom_name
+            adjusted_bond_layout = atom_layout.AtomLayout(
+                atom_name=atom_names,
+                res_id=bond_layout.res_id,
+                chain_id=bond_layout.chain_id,
+                chain_type=bond_layout.chain_type,
+            )
+        else:
+            # Create a layout with the correct shape when bond_layout is None.
+            s = (0, 2)
+            adjusted_bond_layout = atom_layout.AtomLayout(
+                atom_name=np.array([], dtype=object).reshape(s),
+                res_id=np.array([], dtype=int).reshape(s),
+                chain_id=np.array([], dtype=object).reshape(s),
+            )
+        # 10 x num_tokens, since max_inter_bonds_ratio + max_intra_bonds_ratio = 2.061.
+        adjusted_bond_layout = adjusted_bond_layout.copy_and_pad_to(
+            (padding_shapes.num_tokens * 10, 2)
+        )
+        gather_idx = atom_layout.compute_gather_idxs(
+            source_layout=all_tokens, target_layout=adjusted_bond_layout
+        )
+        return cls(tokens_to_ligand_ligand_bonds=gather_idx)
+
+    @classmethod
+    def from_data_dict(cls, batch: BatchDict) -> Self:
+        return cls(
+            tokens_to_ligand_ligand_bonds=atom_layout.GatherInfo.from_dict(
+                batch, key_prefix='tokens_to_ligand_ligand_bonds'
+            )
+        )
+
+    def as_data_dict(self) -> BatchDict:
+        return {
+            **self.tokens_to_ligand_ligand_bonds.as_dict(
+                key_prefix='tokens_to_ligand_ligand_bonds'
+            )
+        }
+
+
+@dataclasses.dataclass
+class PseudoBetaInfo:
+    """Contains information for extracting pseudo-beta and equivalent atoms."""
+
+    token_atoms_to_pseudo_beta: atom_layout.GatherInfo
+
+    @classmethod
+    def compute_features(
+        cls,
+        all_token_atoms_layout: atom_layout.AtomLayout,
+        ccd: chemical_components.Ccd,
+        padding_shapes: PaddingShapes,
+        logging_name: str,
+    ) -> Self:
+        """Compute the PseudoBetaInfo features.
+
+        Args:
+            all_token_atoms_layout: AtomLayout for all atoms per token, shape
+                (num_tokens, max_atoms_per_token).
+            ccd: The chemical components dictionary.
+            padding_shapes: padding shapes.
+            logging_name: logging name for debugging (usually the mmcif_id).
+
+        Returns:
+            A PseudoBetaInfo object.
+        """
+        token_idxs = []
+        atom_idxs = []
+        for token_idx in range(all_token_atoms_layout.shape[0]):
+            chain_type = all_token_atoms_layout.chain_type[token_idx, 0]
+            atom_names = list(all_token_atoms_layout.atom_name[token_idx, :])
+            atom_idx = None
+            is_nucleic_backbone = (
+                chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES
+                or chain_type == mmcif_names.OTHER_CHAIN
+            )
+            if chain_type == mmcif_names.PROTEIN_CHAIN:
+                # Protein chains: prefer CB, fall back to CA.
+                if 'CB' in atom_names:
+                    atom_idx = atom_names.index('CB')
+                elif 'CA' in atom_names:
+                    atom_idx = atom_names.index('CA')
+            elif is_nucleic_backbone:
+                # RNA / DNA chains: C4 for purines, C2 for pyrimidines.
+                res_name = all_token_atoms_layout.res_name[token_idx, 0]
+                cifdict = ccd.get(res_name)
+                if cifdict:
+                    parent = cifdict['_chem_comp.mon_nstd_parent_comp_id'][0]
+                    if parent != '?':
+                        res_name = parent
+                if res_name in {'A', 'G', 'DA', 'DG'}:
+                    if 'C4' in atom_names:
+                        atom_idx = atom_names.index('C4')
+                else:
+                    if 'C2' in atom_names:
+                        atom_idx = atom_names.index('C2')
+            elif chain_type in mmcif_names.NON_POLYMER_CHAIN_TYPES:
+                # Ligands: there is only one atom per token.
+                atom_idx = 0
+            else:
+                logging.warning(
+                    '%s: Unknown chain type for token %i. (%s)',
+                    logging_name,
+                    token_idx,
+                    all_token_atoms_layout[token_idx: token_idx + 1],
+                )
+                atom_idx = 0
+            if atom_idx is None:
+                (valid_atom_idxs,) = np.nonzero(
+                    all_token_atoms_layout.atom_name[token_idx, :]
+                )
+                if valid_atom_idxs.shape[0] > 0:
+                    atom_idx = valid_atom_idxs[0]
+                else:
+                    atom_idx = 0
+                logging.warning(
+                    '%s token %i (%s) does not contain a pseudo-beta atom. '
+                    'Using the first valid atom (%s) instead.',
+                    logging_name,
+                    token_idx,
+                    all_token_atoms_layout[token_idx: token_idx + 1],
+                    all_token_atoms_layout.atom_name[token_idx, atom_idx],
+                )
+
+            token_idxs.append(token_idx)
+            atom_idxs.append(atom_idx)
+
+        pseudo_beta_layout = all_token_atoms_layout[token_idxs, atom_idxs]
+        pseudo_beta_layout = pseudo_beta_layout.copy_and_pad_to((
+            padding_shapes.num_tokens,
+        ))
+        token_atoms_to_pseudo_beta = atom_layout.compute_gather_idxs(
+            source_layout=all_token_atoms_layout, target_layout=pseudo_beta_layout
+        )
+
+        return cls(
+            token_atoms_to_pseudo_beta=token_atoms_to_pseudo_beta,
+        )
+
+    @classmethod
+    def from_data_dict(cls, batch: BatchDict) -> Self:
+        return cls(
+            token_atoms_to_pseudo_beta=atom_layout.GatherInfo.from_dict(
+                batch, key_prefix='token_atoms_to_pseudo_beta'
+            ),
+        )
+
+    def as_data_dict(self) -> BatchDict:
+        return {
+            **self.token_atoms_to_pseudo_beta.as_dict(
+                key_prefix='token_atoms_to_pseudo_beta'
+            ),
+        }
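The pseudo-beta selection rule above can be summarised in a small stand-alone sketch (chain-type strings simplified; the real code uses the ``mmcif_names`` constants and the CCD parent lookup):

    def pseudo_beta_atom(chain_type, res_name, atom_names):
        # Mirrors PseudoBetaInfo.compute_features: CB (else CA) for protein,
        # C4 for purines / C2 for pyrimidines, first atom for ligands.
        if chain_type == 'protein':
            for name in ('CB', 'CA'):
                if name in atom_names:
                    return name
        elif chain_type in ('rna', 'dna'):
            target = 'C4' if res_name in {'A', 'G', 'DA', 'DG'} else 'C2'
            if target in atom_names:
                return target
        return atom_names[0]  # ligand / fallback

    print(pseudo_beta_atom('protein', 'GLY', ['N', 'CA', 'C', 'O']))  # CA
    print(pseudo_beta_atom('rna', 'G', ["C1'", 'C4', 'N9']))          # C4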
+
+
+_DEFAULT_BLANK_REF = {
+    'positions': np.zeros(3),
+    'mask': 0,
+    'element': 0,
+    'charge': 0,
+    'atom_name_chars': np.zeros(4),
+}
+
+
+def random_rotation(random_state: np.random.RandomState) -> np.ndarray:
+    # Create a random rotation (Gram-Schmidt orthogonalization of two
+    # random normal vectors).
+    v0, v1 = random_state.normal(size=(2, 3))
+    e0 = v0 / np.maximum(1e-10, np.linalg.norm(v0))
+    v1 = v1 - e0 * np.dot(v1, e0)
+    e1 = v1 / np.maximum(1e-10, np.linalg.norm(v1))
+    e2 = np.cross(e0, e1)
+    return np.stack([e0, e1, e2])
+
+
+def random_augmentation(
+    positions: np.ndarray,
+    random_state: np.random.RandomState,
+) -> np.ndarray:
+    """Center then apply random translation and rotation."""
+
+    center = np.mean(positions, axis=0)
+    rot = random_rotation(random_state)
+    positions_target = np.einsum('ij,kj->ki', rot, positions - center)
+
+    translation = random_state.normal(size=(3,))
+    positions_target = positions_target + translation
+    return positions_target
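A quick sanity check (assuming the module context above) that ``random_rotation`` returns a proper rotation matrix:

    import numpy as np

    rs = np.random.RandomState(0)
    rot = random_rotation(rs)
    # Orthonormal rows and determinant +1 characterise a proper rotation.
    assert np.allclose(rot @ rot.T, np.eye(3), atol=1e-6)
    assert np.isclose(np.linalg.det(rot), 1.0)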
+def get_reference(
+    res_name: str,
+    chemical_components_data: struc_chem_comps.ChemicalComponentsData,
+    ccd: chemical_components.Ccd,
+    random_state: np.random.RandomState,
+    ref_max_modified_date: datetime.date,
+    intra_ligand_ptm_bonds: bool,
+) -> tuple[dict[str, Any], Any, Any]:
+  """Reference structure for a residue, from the CCD or from SMILES.
+
+  Args:
+    res_name: CCD code of the residue.
+    chemical_components_data: ChemicalComponentsData for making the reference
+      structure.
+    ccd: The chemical components dictionary.
+    random_state: NumPy RandomState.
+    ref_max_modified_date: Date beyond which reference structures must not be
+      modified.
+    intra_ligand_ptm_bonds: Whether to return intra-ligand / PTM bonds.
+
+  Returns:
+    Mapping from atom names to features, from_atoms, dest_atoms.
+  """
+  ccd_cif = ccd.get(res_name)
+  non_ccd_with_smiles = False
+  if not ccd_cif:
+    # If res_name is not in the CCD, try to get SMILES from the chemical
+    # components data.
+    has_smiles = (
+        chemical_components_data.chem_comp
+        and res_name in chemical_components_data.chem_comp
+        and chemical_components_data.chem_comp[res_name].pdbx_smiles
+    )
+    if has_smiles:
+      non_ccd_with_smiles = True
+    else:
+      # If there is neither SMILES nor a CCD entry, return an empty dict.
+      return dict(), None, None
+
+  pos = []
+  elements = []
+  charges = []
+  atom_names = []
+
+  mol_from_smiles = None  # Unused init to keep pylint happy.
+  if non_ccd_with_smiles:
+    smiles_string = chemical_components_data.chem_comp[res_name].pdbx_smiles
+    mol_from_smiles = Chem.MolFromSmiles(smiles_string)
+    if mol_from_smiles is None:
+      logging.warning(
+          'Failed to construct an RDKit Mol from the SMILES string: %s',
+          smiles_string,
+      )
+      return dict(), None, None
+    # Note this does not contain ideal coordinates, just bonds.
+    ccd_cif = rdkit_utils.mol_to_ccd_cif(
+        mol_from_smiles, component_id='fake_cif'
+    )
+
+  # Use RDKit for non-CCD structures, and whenever the reference should be a
+  # random RDKit conformer.
+  try:
+    if non_ccd_with_smiles:
+      m = mol_from_smiles
+      m = Chem.AddHs(m)
+      m = rdkit_utils.assign_atom_names_from_graph(
+          m, keep_existing_names=True)
+      logging.info(
+          'Successfully constructed a SMILES reference structure for: %s',
+          res_name,
+      )
+    else:
+      m = rdkit_utils.mol_from_ccd_cif(ccd_cif, remove_hydrogens=False)
+    # Stochastic conformer search method.
+    # V3 is the latest and supports macrocycles.
+    params = AllChem.ETKDGv3()
+    params.randomSeed = int(random_state.randint(1, 1 << 31))
+    AllChem.EmbedMolecule(m, params)
+    conformer = m.GetConformer()
+    for i, atom in enumerate(m.GetAtoms()):
+      elements.append(atom.GetAtomicNum())
+      charges.append(atom.GetFormalCharge())
+      name = atom.GetProp('atom_name')
+      atom_names.append(name)
+      coords = conformer.GetAtomPosition(i)
+      pos.append([coords.x, coords.y, coords.z])
+    pos = np.array(pos, dtype=np.float32)
+  except (rdkit_utils.MolFromMmcifError, ValueError):
+    logging.warning(
+        'Failed to construct an RDKit reference structure for: %s', res_name
+    )
+
+  if not atom_names:
+    # Fall back to CCD ideal coordinates if RDKit fails.
+    atom_names = ccd_cif['_chem_comp_atom.atom_id']
+    # If mol_from_smiles is set, the CIF has no ideal coordinates by default.
+    if '_chem_comp_atom.pdbx_model_Cartn_x_ideal' in ccd_cif:
+      atom_x = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_x_ideal']
+      atom_y = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_y_ideal']
+      atom_z = ccd_cif['_chem_comp_atom.pdbx_model_Cartn_z_ideal']
+    else:
+      atom_x = np.array(['?'] * len(atom_names))
+      atom_y = np.array(['?'] * len(atom_names))
+      atom_z = np.array(['?'] * len(atom_names))
+    type_symbols = ccd_cif['_chem_comp_atom.type_symbol']
+    charges = ccd_cif['_chem_comp_atom.charge']
+    elements = [
+        periodic_table.ATOMIC_NUMBER.get(elem_type.capitalize(), 0)
+        for elem_type in type_symbols
+    ]
+    pos = np.array([[x, y, z] for x, y, z in zip(atom_x, atom_y, atom_z)])
+    # Unknown reference coordinates are specified by '?' in the chem comp
+    # dictionary. Replace unknown reference coords with 0.
+    if '?' in pos and '_chem_comp.pdbx_modified_date' in ccd_cif:
+      # Use the model coordinates if the modified date is before the cutoff.
+      modified_dates = [
+          datetime.date.fromisoformat(date)
+          for date in ccd_cif['_chem_comp.pdbx_modified_date']
+      ]
+      max_modified_date = max(modified_dates)
+      if max_modified_date < ref_max_modified_date:
+        atom_x = ccd_cif['_chem_comp_atom.model_Cartn_x']
+        atom_y = ccd_cif['_chem_comp_atom.model_Cartn_y']
+        atom_z = ccd_cif['_chem_comp_atom.model_Cartn_z']
+        pos = np.array([[x, y, z]
+                        for x, y, z in zip(atom_x, atom_y, atom_z)])
+    if '?' 
in pos: + if np.all(pos == '?'): + logging.warning('All ref positions unknown for: %s', res_name) + else: + logging.warning('Some ref positions unknown for: %s', res_name) + pos[pos == '?'] = 0 + pos = np.array(pos, dtype=np.float32) + + pos = random_augmentation(pos, random_state) + + if intra_ligand_ptm_bonds: + assert ccd_cif is not None, 'CCD CIF is None' + from_atom = ccd_cif.get('_chem_comp_bond.atom_id_1', None) + dest_atom = ccd_cif.get('_chem_comp_bond.atom_id_2', None) + else: + from_atom = None + dest_atom = None + + features = {} + for atom_name in atom_names: + features[atom_name] = {} + idx = atom_names.index(atom_name) + charge = 0 if charges[idx] == '?' else int(charges[idx]) + atom_name_chars = np.array([ord(c) - 32 for c in atom_name], dtype=int) + atom_name_chars = _pad_to(atom_name_chars, (4,)) + features[atom_name]['positions'] = pos[idx] + features[atom_name]['mask'] = 1 + features[atom_name]['element'] = elements[idx] + features[atom_name]['charge'] = charge + features[atom_name]['atom_name_chars'] = atom_name_chars + return features, from_atom, dest_atom + + +@dataclasses.dataclass +class RefStructure: + """Contains ref structure information.""" + + # Array with positions, float32, shape [num_res, max_atoms_per_token, 3] + positions: xnp_ndarray + # Array with masks, bool, shape [num_res, max_atoms_per_token] + mask: xnp_ndarray + # Array with elements, int32, shape [num_res, max_atoms_per_token] + element: xnp_ndarray + # Array with charges, float32, shape [num_res, max_atoms_per_token] + charge: xnp_ndarray + # Array with atom name characters, int32, [num_res, max_atoms_per_token, 4] + atom_name_chars: xnp_ndarray + # Array with reference space uids, int32, [num_res, max_atoms_per_token] + ref_space_uid: xnp_ndarray + + @classmethod + def compute_features( + cls, + all_token_atoms_layout: atom_layout.AtomLayout, + ccd: chemical_components.Ccd, + padding_shapes: PaddingShapes, + chemical_components_data: struc_chem_comps.ChemicalComponentsData, + random_state: np.random.RandomState, + ref_max_modified_date: datetime.date, + intra_ligand_ptm_bonds: bool, + ligand_ligand_bonds: atom_layout.AtomLayout | None = None, + ) -> tuple[Self, Any]: + """Reference structure information for each residue.""" + + # Get features per atom + padded_shape = (padding_shapes.num_tokens, + all_token_atoms_layout.shape[1]) + result = { + 'positions': np.zeros((*padded_shape, 3), 'float32'), + 'mask': np.zeros(padded_shape, 'bool'), + 'element': np.zeros(padded_shape, 'int32'), + 'charge': np.zeros(padded_shape, 'float32'), + 'atom_name_chars': np.zeros((*padded_shape, 4), 'int32'), + 'ref_space_uid': np.zeros((*padded_shape,), 'int32'), + } + + atom_names_all = [] + chain_ids_all = [] + res_ids_all = [] + + # Cache reference conformations for each residue. 
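+    # Note (illustrative, not from the upstream source): the cache below is
+    # keyed by (chain_id, res_id), so every atom of a residue reuses one
+    # randomly embedded conformer, e.g. atoms ('A', 5, 'CA') and ('A', 5, 'CB')
+    # both read from conformations[('A', 5)]. The ref_space_uid assigned
+    # further down gives all atoms of a residue a shared id, marking which
+    # reference positions live in the same local coordinate frame.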
+ conformations = {} + ref_space_uids = {} + for idx in np.ndindex(all_token_atoms_layout.shape): + chain_id = all_token_atoms_layout.chain_id[idx] + res_id = all_token_atoms_layout.res_id[idx] + res_name = all_token_atoms_layout.res_name[idx] + is_non_standard = res_name not in _STANDARD_RESIDUES + atom_name = all_token_atoms_layout.atom_name[idx] + if not atom_name: + ref = _DEFAULT_BLANK_REF + else: + if (chain_id, res_id) not in conformations: + conf, from_atom, dest_atom = get_reference( + res_name=res_name, + chemical_components_data=chemical_components_data, + ccd=ccd, + random_state=random_state, + ref_max_modified_date=ref_max_modified_date, + intra_ligand_ptm_bonds=intra_ligand_ptm_bonds, + ) + conformations[(chain_id, res_id)] = conf + + if ( + is_non_standard + and (from_atom is not None) + and (dest_atom is not None) + ): + # Add intra-ligand bond graph + atom_names_ligand = np.stack( + [from_atom, dest_atom], axis=1, dtype=object + ) + atom_names_all.append(atom_names_ligand) + res_ids_all.append( + np.full_like(atom_names_ligand, res_id, dtype=int) + ) + chain_ids_all.append( + np.full_like(atom_names_ligand, + chain_id, dtype=object) + ) + + conformation = conformations.get( + (chain_id, res_id), {atom_name: _DEFAULT_BLANK_REF} + ) + if atom_name not in conformation: + logging.warning( + 'Missing atom "%s" for CCD "%s"', + atom_name, + all_token_atoms_layout.res_name[idx], + ) + ref = conformation.get(atom_name, _DEFAULT_BLANK_REF) + for k in ref: + result[k][idx] = ref[k] + + # Assign a unique reference space id to each component, to determine which + # reference positions live in the same reference space. + space_str_id = ( + all_token_atoms_layout.chain_id[idx], + all_token_atoms_layout.res_id[idx], + ) + if space_str_id not in ref_space_uids: + ref_space_uids[space_str_id] = len(ref_space_uids) + result['ref_space_uid'][idx] = ref_space_uids[space_str_id] + + if atom_names_all: + atom_names_all = np.concatenate(atom_names_all, axis=0) + res_ids_all = np.concatenate(res_ids_all, axis=0) + chain_ids_all = np.concatenate(chain_ids_all, axis=0) + if ligand_ligand_bonds is not None: + adjusted_ligand_ligand_bonds = atom_layout.AtomLayout( + atom_name=np.concatenate( + [ligand_ligand_bonds.atom_name, atom_names_all], axis=0 + ), + chain_id=np.concatenate( + [ligand_ligand_bonds.chain_id, chain_ids_all], axis=0 + ), + res_id=np.concatenate( + [ligand_ligand_bonds.res_id, res_ids_all], axis=0 + ), + ) + else: + adjusted_ligand_ligand_bonds = atom_layout.AtomLayout( + atom_name=atom_names_all, + chain_id=chain_ids_all, + res_id=res_ids_all, + ) + else: + adjusted_ligand_ligand_bonds = ligand_ligand_bonds + + return cls(**result), adjusted_ligand_ligand_bonds + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + positions=batch['ref_pos'], + mask=batch['ref_mask'], + element=batch['ref_element'], + charge=batch['ref_charge'], + atom_name_chars=batch['ref_atom_name_chars'], + ref_space_uid=batch['ref_space_uid'], + ) + + def as_data_dict(self) -> BatchDict: + return { + 'ref_pos': self.positions, + 'ref_mask': self.mask, + 'ref_element': self.element, + 'ref_charge': self.charge, + 'ref_atom_name_chars': self.atom_name_chars, + 'ref_space_uid': self.ref_space_uid, + } + + +@dataclasses.dataclass +class ConvertModelOutput: + """Contains atom layout info.""" + + cleaned_struc: structure.Structure + token_atoms_layout: atom_layout.AtomLayout + flat_output_layout: atom_layout.AtomLayout + empty_output_struc: structure.Structure + polymer_ligand_bonds: 
atom_layout.AtomLayout + ligand_ligand_bonds: atom_layout.AtomLayout + + @classmethod + def compute_features( + cls, + all_token_atoms_layout: atom_layout.AtomLayout, + padding_shapes: PaddingShapes, + cleaned_struc: structure.Structure, + flat_output_layout: atom_layout.AtomLayout, + empty_output_struc: structure.Structure, + polymer_ligand_bonds: atom_layout.AtomLayout, + ligand_ligand_bonds: atom_layout.AtomLayout, + ) -> Self: + """Pads the all_token_atoms_layout and stores other data.""" + # Crop and pad the all_token_atoms_layout. + token_atoms_layout = all_token_atoms_layout.copy_and_pad_to( + (padding_shapes.num_tokens, all_token_atoms_layout.shape[1]) + ) + + return cls( + cleaned_struc=cleaned_struc, + token_atoms_layout=token_atoms_layout, + flat_output_layout=flat_output_layout, + empty_output_struc=empty_output_struc, + polymer_ligand_bonds=polymer_ligand_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + """Construct atom layout object from dictionary.""" + + return cls( + cleaned_struc=_unwrap(batch.get('cleaned_struc', None)), + token_atoms_layout=_unwrap(batch.get('token_atoms_layout', None)), + flat_output_layout=_unwrap(batch.get('flat_output_layout', None)), + empty_output_struc=_unwrap(batch.get('empty_output_struc', None)), + polymer_ligand_bonds=_unwrap( + batch.get('polymer_ligand_bonds', None)), + ligand_ligand_bonds=_unwrap( + batch.get('ligand_ligand_bonds', None)), + ) + + def as_data_dict(self) -> BatchDict: + return { + 'cleaned_struc': np.array(self.cleaned_struc, object), + 'token_atoms_layout': np.array(self.token_atoms_layout, object), + 'flat_output_layout': np.array(self.flat_output_layout, object), + 'empty_output_struc': np.array(self.empty_output_struc, object), + 'polymer_ligand_bonds': np.array(self.polymer_ligand_bonds, object), + 'ligand_ligand_bonds': np.array(self.ligand_ligand_bonds, object), + } + + +@dataclasses.dataclass +class AtomCrossAtt: + """Operate on flat atoms.""" + + token_atoms_to_queries: atom_layout.GatherInfo + tokens_to_queries: atom_layout.GatherInfo + tokens_to_keys: atom_layout.GatherInfo + queries_to_keys: atom_layout.GatherInfo + queries_to_token_atoms: atom_layout.GatherInfo + + @classmethod + def compute_features( + cls, + # (num_tokens, num_dense) + all_token_atoms_layout: atom_layout.AtomLayout, + queries_subset_size: int, + keys_subset_size: int, + padding_shapes: PaddingShapes, + ) -> Self: + """Computes gather indices and meta data to work with a flat atom list.""" + + token_atoms_layout = all_token_atoms_layout.copy_and_pad_to( + (padding_shapes.num_tokens, all_token_atoms_layout.shape[1]) + ) + token_atoms_mask = token_atoms_layout.atom_name.astype(bool) + flat_layout = token_atoms_layout[token_atoms_mask] + num_atoms = flat_layout.shape[0] + + padded_flat_layout = flat_layout.copy_and_pad_to(( + padding_shapes.num_atoms, + )) + + # Create the layout for queries + num_subsets = padding_shapes.num_atoms // queries_subset_size + lay_arr = padded_flat_layout.to_array() + queries_layout = atom_layout.AtomLayout.from_array( + lay_arr.reshape((6, num_subsets, queries_subset_size)) + ) + + # Create the layout for the keys (the key subsets are centered around the + # query subsets) + # Create initial gather indices (contain out-of-bound indices) + subset_centers = np.arange( + queries_subset_size / 2, padding_shapes.num_atoms, queries_subset_size + ) + flat_to_key_gathers = ( + subset_centers[:, None] + + np.arange(-keys_subset_size / 2, 
keys_subset_size / 2)[None, :] + ) + flat_to_key_gathers = flat_to_key_gathers.astype(int) + # Shift subsets with out-of-bound indices, such that they are fully within + # the bounds. + for row in range(flat_to_key_gathers.shape[0]): + if flat_to_key_gathers[row, 0] < 0: + flat_to_key_gathers[row, :] -= flat_to_key_gathers[row, 0] + elif flat_to_key_gathers[row, -1] > num_atoms - 1: + overflow = flat_to_key_gathers[row, -1] - (num_atoms - 1) + flat_to_key_gathers[row, :] -= overflow + # Create the keys layout. + keys_layout = padded_flat_layout[flat_to_key_gathers] + + # Create gather indices for conversion between token atoms layout, + # queries layout and keys layout. + token_atoms_to_queries = atom_layout.compute_gather_idxs( + source_layout=token_atoms_layout, target_layout=queries_layout + ) + + token_atoms_to_keys = atom_layout.compute_gather_idxs( + source_layout=token_atoms_layout, target_layout=keys_layout + ) + + queries_to_keys = atom_layout.compute_gather_idxs( + source_layout=queries_layout, target_layout=keys_layout + ) + + queries_to_token_atoms = atom_layout.compute_gather_idxs( + source_layout=queries_layout, target_layout=token_atoms_layout + ) + + # Create gather indices for conversion of tokens layout to + # queries and keys layout + token_idxs = np.arange(padding_shapes.num_tokens).astype(np.int64) + token_idxs = np.broadcast_to( + token_idxs[:, None], token_atoms_layout.shape) + tokens_to_queries = atom_layout.GatherInfo( + gather_idxs=atom_layout.convert( + token_atoms_to_queries, token_idxs, layout_axes=(0, 1) + ), + gather_mask=atom_layout.convert( + token_atoms_to_queries, token_atoms_mask, layout_axes=(0, 1) + ), + input_shape=np.array((padding_shapes.num_tokens,)), + ) + + tokens_to_keys = atom_layout.GatherInfo( + gather_idxs=atom_layout.convert( + token_atoms_to_keys, token_idxs, layout_axes=(0, 1) + ), + gather_mask=atom_layout.convert( + token_atoms_to_keys, token_atoms_mask, layout_axes=(0, 1) + ), + input_shape=np.array((padding_shapes.num_tokens,)), + ) + + return cls( + token_atoms_to_queries=token_atoms_to_queries, + tokens_to_queries=tokens_to_queries, + tokens_to_keys=tokens_to_keys, + queries_to_keys=queries_to_keys, + queries_to_token_atoms=queries_to_token_atoms, + ) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls( + token_atoms_to_queries=atom_layout.GatherInfo.from_dict( + batch, key_prefix='token_atoms_to_queries' + ), + tokens_to_queries=atom_layout.GatherInfo.from_dict( + batch, key_prefix='tokens_to_queries' + ), + tokens_to_keys=atom_layout.GatherInfo.from_dict( + batch, key_prefix='tokens_to_keys' + ), + queries_to_keys=atom_layout.GatherInfo.from_dict( + batch, key_prefix='queries_to_keys' + ), + queries_to_token_atoms=atom_layout.GatherInfo.from_dict( + batch, key_prefix='queries_to_token_atoms' + ), + ) + + def as_data_dict(self) -> BatchDict: + return { + **self.token_atoms_to_queries.as_dict( + key_prefix='token_atoms_to_queries' + ), + **self.tokens_to_queries.as_dict(key_prefix='tokens_to_queries'), + **self.tokens_to_keys.as_dict(key_prefix='tokens_to_keys'), + **self.queries_to_keys.as_dict(key_prefix='queries_to_keys'), + **self.queries_to_token_atoms.as_dict( + key_prefix='queries_to_token_atoms' + ), + } + + +@dataclasses.dataclass +class Frames: + """Features for backbone frames.""" + + mask: xnp_ndarray + + @classmethod + def compute_features( + cls, + all_tokens: atom_layout.AtomLayout, + all_token_atoms_layout: atom_layout.AtomLayout, + ref_structure: RefStructure, + padding_shapes: 
PaddingShapes, + ) -> Self: + """Computes features for backbone frames.""" + num_tokens = padding_shapes.num_tokens + all_token_atoms_layout = all_token_atoms_layout.copy_and_pad_to( + (num_tokens, all_token_atoms_layout.shape[1]) + ) + + all_token_atoms_to_all_tokens = atom_layout.compute_gather_idxs( + source_layout=all_token_atoms_layout, target_layout=all_tokens + ) + ref_coordinates = atom_layout.convert( + all_token_atoms_to_all_tokens, + ref_structure.positions.astype(np.float32), + layout_axes=(0, 1), + ) + ref_mask = atom_layout.convert( + all_token_atoms_to_all_tokens, + ref_structure.mask.astype(bool), + layout_axes=(0, 1), + ) + ref_mask = ref_mask & all_token_atoms_to_all_tokens.gather_mask.astype( + bool) + + all_frame_mask = [] + + # Iterate over tokens + for idx, args in enumerate( + zip(all_tokens.chain_type, all_tokens.chain_id, all_tokens.res_id) + ): + + chain_type, chain_id, res_id = args + + if chain_type in list(mmcif_names.PEPTIDE_CHAIN_TYPES): + frame_mask = True + elif chain_type in list(mmcif_names.NUCLEIC_ACID_CHAIN_TYPES): + frame_mask = True + elif chain_type in list(mmcif_names.NON_POLYMER_CHAIN_TYPES): + # For ligands, build frames from closest atoms from the same molecule. + (local_token_idxs,) = np.where( + (all_tokens.chain_type == chain_type) + & (all_tokens.chain_id == chain_id) + & (all_tokens.res_id == res_id) + ) + + if len(local_token_idxs) < 3: + frame_mask = False + + else: + # [local_tokens] + local_dist = np.linalg.norm( + ref_coordinates[idx] - ref_coordinates[local_token_idxs], axis=-1 + ) + local_mask = ref_mask[local_token_idxs] + cost = local_dist + 1e8 * ~local_mask + cost = cost + 1e8 * (idx == local_token_idxs) + # [local_tokens] + closest_idxs = np.argsort(cost, axis=0) + + # The closest indices index an array of local tokens. Convert this + # to indices of the full (num_tokens,) array. + global_closest_idxs = local_token_idxs[closest_idxs] + + # Construct frame by placing the current token at the origin and two + # nearest atoms on either side. + global_frame_idxs = np.array( + (global_closest_idxs[0], idx, global_closest_idxs[1]) + ) + + # Check that the frame atoms are not colinear. + a, b, c = ref_coordinates[global_frame_idxs] + vec1 = a - b + vec2 = c - b + # Reference coordinates can be all zeros, in which case we have + # to explicitly set colinearity. + if np.isclose(np.linalg.norm(vec1, axis=-1), 0) or np.isclose( + np.linalg.norm(vec2, axis=-1), 0 + ): + is_colinear = True + logging.info( + 'Found identical coordinates: Assigning as colinear.') + else: + vec1 = vec1 / np.linalg.norm(vec1, axis=-1) + vec2 = vec2 / np.linalg.norm(vec2, axis=-1) + cos_angle = np.einsum('...k,...k->...', vec1, vec2) + # <25 degree deviation is considered colinear. + is_colinear = 1 - np.abs(cos_angle) < 0.0937 + + frame_mask = not is_colinear + else: + # No frame for other chain types. 
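+          # Note (illustrative, not from the upstream source): the colinearity
+          # test above rejects frames whose two neighbour directions are
+          # within ~25 degrees of (anti-)parallel, since cos(25 deg) ~= 0.9063
+          # and the check is 1 - |cos_angle| < 0.0937. Chains that are neither
+          # polymers nor ligands get no frame at all: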
+ frame_mask = False + + all_frame_mask.append(frame_mask) + + all_frame_mask = np.array(all_frame_mask, dtype=bool) + + mask = _pad_to(all_frame_mask, (padding_shapes.num_tokens,)) + + return cls(mask=mask) + + @classmethod + def from_data_dict(cls, batch: BatchDict) -> Self: + return cls(mask=batch['frames_mask']) + + def as_data_dict(self) -> BatchDict: + return {'frames_mask': self.mask} diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/load_batch.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/load_batch.py new file mode 100644 index 000000000..41cf2bf48 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/load_batch.py @@ -0,0 +1,22 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +"""load data 'batch' used in test""" +import pickle +import mindspore as ms +from alphafold3.model.feat_batch import Batch + + +def load_batch(dtype=ms.float32): + """Load batch data for test""" + with open('/data/zmmVol2/AF3/test/unit_tests/model/diffusion/example_np.pkl', 'rb') as f: + data = pickle.load(f) + batch = Batch.from_data_dict(data) + batch.convert_to_tensor(dtype=dtype) + return batch diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/merging_features.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/merging_features.py new file mode 100644 index 000000000..3c1fab899 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/merging_features.py @@ -0,0 +1,92 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Methods for merging existing features to create a new example. + +Covers: +- Merging features across chains. +- Merging the paired and unpaired parts of the MSA. +""" + +from typing import TypeAlias + +from alphafold3.model import data_constants +import numpy as np + +NUM_SEQ_NUM_RES_MSA_FEATURES = data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES +NUM_SEQ_MSA_FEATURES = data_constants.NUM_SEQ_MSA_FEATURES +MSA_PAD_VALUES = data_constants.MSA_PAD_VALUES + + +xnp_ndarray: TypeAlias = np.ndarray # pylint: disable=invalid-name +BatchDict: TypeAlias = dict[str, xnp_ndarray] + + +def _pad_features_to_max(feat_name: str, chains: list[BatchDict], axis: int): + """Pad a set of features to the maximum size amongst all chains. + + Args: + feat_name: The feature name to pad. + chains: A list of chains with associated features. + axis: Which axis to pad to the max. + + Returns: + A list of features, all with the same size on the given axis. 
+ """ + max_num_seq = np.max([chain[feat_name].shape[axis] for chain in chains]) + + padded_feats = [] + for chain in chains: + feat = chain[feat_name] + + padding = np.zeros_like(feat.shape) # pytype: disable=attribute-error + # pytype: disable=attribute-error + padding[axis] = max_num_seq - feat.shape[axis] + padding = [(0, p) for p in padding] + padded_feats.append( + np.pad( + feat, + padding, + mode='constant', + constant_values=MSA_PAD_VALUES[feat_name], + ) + ) + return padded_feats + + +def merge_msa_features(feat_name: str, chains: list[BatchDict]) -> np.ndarray: + """Merges MSA features with shape (NUM_SEQ, NUM_RES) across chains.""" + expected_dtype = chains[0][feat_name].dtype + if '_all_seq' in feat_name: + return np.concatenate( + [c.get(feat_name, np.array([], expected_dtype)) for c in chains], axis=1 + ) + else: + # Since each MSA can be of different lengths, we first need to pad them + # all to the size of the largest MSA before concatenating. + padded_feats = _pad_features_to_max(feat_name, chains, axis=0) + return np.concatenate(padded_feats, axis=1) + + +def merge_paired_and_unpaired_msa(example: BatchDict) -> BatchDict: + """Concatenates the paired (all_seq) MSA features with the unpaired ones.""" + new_example = dict(example) + + for feature_name in NUM_SEQ_NUM_RES_MSA_FEATURES + NUM_SEQ_MSA_FEATURES: + if feature_name in example and feature_name + '_all_seq' in example: + feat = example[feature_name] + feat_all_seq = example[feature_name + '_all_seq'] + merged_feat = np.concatenate([feat_all_seq, feat], axis=0) + new_example[feature_name] = merged_feat + + new_example['num_alignments'] = np.array( + new_example['msa'].shape[0], dtype=np.int32 + ) + return new_example diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc new file mode 100644 index 000000000..663e7f303 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc @@ -0,0 +1,63 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include "alphafold3/model/mkdssp_pybind.h" + +#include + +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" + +namespace alphafold3 { +namespace py = pybind11; + +void RegisterModuleMkdssp(pybind11::module m) { + py::module site = py::module::import("site"); + py::list paths = py::cast(site.attr("getsitepackages")()); + // Find the first path that contains the libcifpp components.cif file. 
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc
new file mode 100644
index 000000000..663e7f303
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.cc
@@ -0,0 +1,63 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include "alphafold3/model/mkdssp_pybind.h"
+
+#include <stdlib.h>
+
+#include <dssp.hpp>
+#include <filesystem>
+#include <sstream>
+#include <string>
+
+#include "absl/strings/string_view.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+
+namespace alphafold3 {
+namespace py = pybind11;
+
+void RegisterModuleMkdssp(pybind11::module m) {
+  py::module site = py::module::import("site");
+  py::list paths = py::cast<py::list>(site.attr("getsitepackages")());
+  // Find the first path that contains the libcifpp components.cif file.
+  bool found = false;
+  for (const auto& py_path : paths) {
+    auto path_str =
+        std::filesystem::path(py::cast<std::string>(py_path)) /
+        "share/libcifpp/components.cif";
+    if (std::filesystem::exists(path_str)) {
+      setenv("LIBCIFPP_DATA_DIR", path_str.parent_path().c_str(), 0);
+      found = true;
+      break;
+    }
+  }
+  if (!found) {
+    throw py::type_error("Could not find the libcifpp components.cif file.");
+  }
+  m.def(
+      "get_dssp",
+      [](absl::string_view mmcif, int model_no,
+         int min_poly_proline_stretch_length,
+         bool calculate_surface_accessibility) {
+        cif::file cif_file(mmcif.data(), mmcif.size());
+        dssp result(cif_file.front(), model_no, min_poly_proline_stretch_length,
+                    calculate_surface_accessibility);
+        std::stringstream sstream;
+        result.write_legacy_output(sstream);
+        return sstream.str();
+      },
+      py::arg("mmcif"), py::arg("model_no") = 1,
+      py::arg("min_poly_proline_stretch_length") = 3,
+      py::arg("calculate_surface_accessibility") = false,
+      py::doc("Gets secondary structure from an mmCIF file."));
+}
+
+}  // namespace alphafold3
\ No newline at end of file
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.h
new file mode 100644
index 000000000..a1e4832b8
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mkdssp_pybind.h
@@ -0,0 +1,26 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_
+
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+
+void RegisterModuleMkdssp(pybind11::module m);
+
+}
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_MODEL_MKDSSP_PYBIND_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py
new file mode 100644
index 000000000..f28abda79
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/mmcif_metadata.py
@@ -0,0 +1,199 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Adds mmCIF metadata (to be ModelCIF-conformant) and author and legal info.""" + +from typing import Final + +from alphafold3.structure import mmcif +import numpy as np + +_LICENSE_URL: Final[str] = ( + 'https://github.com/google-deepmind/alphafold3/blob/main/OUTPUT_TERMS_OF_USE.md' +) + +_LICENSE: Final[str] = f"""\ +Non-commercial use only, by using this file you agree to the terms of use found +at {_LICENSE_URL}. +To request access to the AlphaFold 3 model parameters, follow the process set +out at https://github.com/google-deepmind/alphafold3. You may only use these if +received directly from Google. Use is subject to terms of use available at +https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md. +""" + +_DISCLAIMER: Final[str] = """\ +AlphaFold 3 and its output are not intended for, have not been validated for, +and are not approved for clinical use. They are provided "as-is" without any +warranty of any kind, whether expressed or implied. No warranty is given that +use shall not infringe the rights of any third party. +""" + +_MMCIF_PAPER_AUTHORS: Final[tuple[str, ...]] = ( + 'Google DeepMind', + 'Isomorphic Labs', +) + +# Authors of the mmCIF - we set them to be equal to the authors of the paper. +_MMCIF_AUTHORS: Final[tuple[str, ...]] = _MMCIF_PAPER_AUTHORS + + +def add_metadata_to_mmcif( + old_cif: mmcif.Mmcif, model_id: bytes +) -> mmcif.Mmcif: + """Adds metadata to a mmCIF to make it ModelCIF-conformant.""" + cif = {} + + # ModelCIF conformation dictionary. + cif['_audit_conform.dict_name'] = ['mmcif_ma.dic'] +# cif['_audit_conform.dict_version'] = ['1.4.5'] + cif['_audit_conform.dict_location'] = [ + 'https://raw.githubusercontent.com/ihmwg/ModelCIF/master/dist/mmcif_ma.dic' + ] + + cif['_pdbx_data_usage.id'] = ['1', '2'] + cif['_pdbx_data_usage.type'] = ['license', 'disclaimer'] + cif['_pdbx_data_usage.details'] = [_LICENSE, _DISCLAIMER] + cif['_pdbx_data_usage.url'] = [_LICENSE_URL, '?'] + + # Structure author details. + cif['_audit_author.name'] = [] + cif['_audit_author.pdbx_ordinal'] = [] + for author_index, author_name in enumerate(_MMCIF_AUTHORS, start=1): + cif['_audit_author.name'].append(author_name) + cif['_audit_author.pdbx_ordinal'].append(str(author_index)) + + # Paper author details. + cif['_citation_author.citation_id'] = [] + cif['_citation_author.name'] = [] + cif['_citation_author.ordinal'] = [] + for author_index, author_name in enumerate(_MMCIF_PAPER_AUTHORS, start=1): + cif['_citation_author.citation_id'].append('primary') + cif['_citation_author.name'].append(author_name) + cif['_citation_author.ordinal'].append(str(author_index)) + + # Paper citation details. 
+ cif['_citation.id'] = ['primary'] + cif['_citation.title'] = [ + 'Accurate structure prediction of biomolecular interactions with' + ' AlphaFold 3' + ] + cif['_citation.journal_full'] = ['Nature'] + cif['_citation.journal_volume'] = ['630'] + cif['_citation.page_first'] = ['493'] + cif['_citation.page_last'] = ['500'] + cif['_citation.year'] = ['2024'] + cif['_citation.journal_id_ASTM'] = ['NATUAS'] + cif['_citation.country'] = ['UK'] + cif['_citation.journal_id_ISSN'] = ['0028-0836'] + cif['_citation.journal_id_CSD'] = ['0006'] + cif['_citation.book_publisher'] = ['?'] + cif['_citation.pdbx_database_id_PubMed'] = ['38718835'] + cif['_citation.pdbx_database_id_DOI'] = ['10.1038/s41586-024-07487-w'] + + # Type of data in the dataset including data used in the model generation. + cif['_ma_data.id'] = ['1'] + cif['_ma_data.name'] = ['Model'] + cif['_ma_data.content_type'] = ['model coordinates'] + + # Description of number of instances for each entity. + cif['_ma_target_entity_instance.asym_id'] = old_cif['_struct_asym.id'] + cif['_ma_target_entity_instance.entity_id'] = old_cif[ + '_struct_asym.entity_id' + ] + cif['_ma_target_entity_instance.details'] = ['.'] * len( + cif['_ma_target_entity_instance.entity_id'] + ) + + # Details about the target entities. + cif['_ma_target_entity.entity_id'] = cif[ + '_ma_target_entity_instance.entity_id' + ] + cif['_ma_target_entity.data_id'] = ['1'] * len( + cif['_ma_target_entity.entity_id'] + ) + cif['_ma_target_entity.origin'] = ['.'] * len( + cif['_ma_target_entity.entity_id'] + ) + + # Details of the models being deposited. + cif['_ma_model_list.ordinal_id'] = ['1'] + cif['_ma_model_list.model_id'] = ['1'] + cif['_ma_model_list.model_group_id'] = ['1'] + cif['_ma_model_list.model_name'] = ['Top ranked model'] + + cif['_ma_model_list.model_group_name'] = [ + f'AlphaFold-beta-20231127' + ] + cif['_ma_model_list.data_id'] = ['1'] + cif['_ma_model_list.model_type'] = ['Ab initio model'] + + # Software used. + cif['_software.pdbx_ordinal'] = ['1'] + cif['_software.name'] = ['AlphaFold'] +# cif['_software.version'] = [ +# f'AlphaFold-beta-20231127 ({model_id.decode("ascii")})' +# ] + cif['_software.type'] = ['package'] + cif['_software.description'] = ['Structure prediction'] + cif['_software.classification'] = ['other'] + cif['_software.date'] = ['?'] + + # Collection of software into groups. + cif['_ma_software_group.ordinal_id'] = ['1'] + cif['_ma_software_group.group_id'] = ['1'] + cif['_ma_software_group.software_id'] = ['1'] + + # Method description to conform with ModelCIF. + cif['_ma_protocol_step.ordinal_id'] = ['1', '2', '3'] + cif['_ma_protocol_step.protocol_id'] = ['1', '1', '1'] + cif['_ma_protocol_step.step_id'] = ['1', '2', '3'] + cif['_ma_protocol_step.method_type'] = [ + 'coevolution MSA', + 'template search', + 'modeling', + ] + + # Details of the metrics use to assess model confidence. + cif['_ma_qa_metric.id'] = ['1', '2'] + cif['_ma_qa_metric.name'] = ['pLDDT', 'pLDDT'] + # Accepted values are distance, energy, normalised score, other, zscore. + cif['_ma_qa_metric.type'] = ['pLDDT', 'pLDDT'] + cif['_ma_qa_metric.mode'] = ['global', 'local'] + cif['_ma_qa_metric.software_group_id'] = ['1', '1'] + + # Global model confidence metric value. 
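+  # Note (illustrative, not from the upstream source): AlphaFold-family mmCIFs
+  # store per-atom pLDDT in the B-factor column, so the global metric below is
+  # simply their mean, e.g. B_iso values [90.0, 70.0, 80.0] give
+  # (90 + 70 + 80) / 3 = 80.00 after the two-decimal formatting: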
+ cif['_ma_qa_metric_global.ordinal_id'] = ['1'] + cif['_ma_qa_metric_global.model_id'] = ['1'] + cif['_ma_qa_metric_global.metric_id'] = ['1'] + global_plddt = np.mean( + [float(v) for v in old_cif['_atom_site.B_iso_or_equiv']] + ) + cif['_ma_qa_metric_global.metric_value'] = [f'{global_plddt:.2f}'] + + cif['_atom_type.symbol'] = sorted(set(old_cif['_atom_site.type_symbol'])) + + return old_cif.copy_and_update(cif) + + +def add_legal_comment(cif: str) -> str: + """Adds legal comment at the top of the mmCIF.""" + # fmt: off + # pylint: disable=line-too-long + comment = ( + '# By using this file you agree to the legally binding terms of use found at\n' + f'# {_LICENSE_URL}.\n' + '# To request access to the AlphaFold 3 model parameters, follow the process set\n' + '# out at https://github.com/google-deepmind/alphafold3. You may only use these if\n' + '# received directly from Google. Use is subject to terms of use available at\n' + '# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md.' + ) + # pylint: enable=line-too-long + # fmt: on + return f'{comment}\n{cif}' diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/model_config.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/model_config.py new file mode 100644 index 000000000..83cf9ce75 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/model_config.py @@ -0,0 +1,32 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Config for the protein folding model and experiment.""" + +from collections.abc import Sequence +from typing import Literal, TypeAlias + +from alphafold3.model import base_config +from alphafold3.utils.attention import attention + + +_Shape2DType: TypeAlias = tuple[int | None, int | None] + + +class GlobalConfig(base_config.BaseConfig): + bfloat16: Literal['all', 'none', 'intermediate'] = 'none' + final_init: Literal['zeros', 'linear'] = 'zeros' + pair_attention_chunk_size: Sequence[_Shape2DType] = ( + (1536, 128), (None, 32)) + pair_transition_shard_spec: Sequence[_Shape2DType] = ( + (2048, None), + (None, 1024), + ) + flash_attention_implementation: attention.Implementation = 'ms' diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/msa_pairing.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/msa_pairing.py new file mode 100644 index 000000000..720dc2e24 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/msa_pairing.py @@ -0,0 +1,316 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Functions for producing "paired" and "unpaired" MSA features for each chain.
+
+The paired MSA:
+- Is made from the result of the all_seqs MSA query.
+- Is ordered such that you can concatenate features across chains and related
+  sequences will end up on the same row. Related here means "from the same
+  species". Gaps are added to facilitate this whenever a sequence has no
+  suitable pair.
+
+The unpaired MSA:
+- Is made from the results of the remaining MSA queries.
+- Has no special ordering properties.
+- Is deduplicated such that it doesn't contain any sequences in the paired MSA.
+"""
+
+from typing import Mapping, MutableMapping, Sequence
+from alphafold3.model import data_constants
+import numpy as np
+
+
+def _align_species(
+    all_species: Sequence[bytes],
+    chains_species_to_rows: Sequence[Mapping[bytes, np.ndarray]],
+    min_hits_per_species: Mapping[bytes, int],
+) -> np.ndarray:
+  """Aligns MSA row indices based on species.
+
+  Within a species, MSAs are aligned based on their original order (the first
+  sequence for a species in the first chain's MSA is aligned to the first
+  sequence for the same species in the second chain's MSA).
+
+  Args:
+    all_species: A list of all unique species identifiers.
+    chains_species_to_rows: A dictionary for each chain, that maps species to
+      the set of MSA row indices from that species in that chain.
+    min_hits_per_species: A mapping from species id, to the minimum MSA size
+      across chains for that species (ignoring chains with zero hits).
+
+  Returns:
+    A matrix of size [num_msa_rows, num_chains], where the i,j element is an
+    index into the jth chain's MSA. Each row consists of sequences from each
+    chain for the same species (or -1 if that chain has no sequences for that
+    species).
+  """
+  # Each species block is of size [num_seqs x num_chains] and consists of
+  # indices into the respective MSAs that have been aligned and are all for
+  # the same species.
+  species_blocks = []
+  for species in all_species:
+    chain_row_indices = []
+    for species_to_rows in chains_species_to_rows:
+      min_msa_size = min_hits_per_species[species]
+      if species not in species_to_rows:
+        # If a given chain has no hits for a species then we pad it with -1's;
+        # later on these values are used to make sure each feature is padded
+        # with its appropriate pad value.
+        row_indices = np.full(
+            min_msa_size, fill_value=-1, dtype=np.int32)
+      else:
+        # We crop down to the smallest MSA for a given species across chains.
+        row_indices = species_to_rows[species][:min_msa_size]
+      chain_row_indices.append(row_indices)
+    species_block = np.stack(chain_row_indices, axis=1)
+    species_blocks.append(species_block)
+  aligned_matrix = np.concatenate(species_blocks, axis=0)
+  return aligned_matrix
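+# A small worked example (illustrative, not from the upstream source) of what
+# _align_species produces. Suppose species b'9606' has rows [0, 1] in chain 0
+# and row [4] in chain 1, while b'10090' appears only in chain 0 (row [2]):
+#
+#   all_species = [b'10090', b'9606']
+#   chains_species_to_rows = [
+#       {b'9606': np.array([0, 1]), b'10090': np.array([2])},
+#       {b'9606': np.array([4])},
+#   ]
+#   min_hits_per_species = {b'10090': 1, b'9606': 1}
+#
+# _align_species then returns
+#   [[ 2, -1],   # b'10090': chain 1 has no hits, padded with -1.
+#    [ 0,  4]]   # b'9606': cropped to the smallest per-chain MSA size (1).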
+def create_paired_features(
+    chains: Sequence[MutableMapping[str, np.ndarray]],
+    max_paired_sequences: int,
+    nonempty_chain_ids: set[str],
+    max_hits_per_species: int,
+) -> Sequence[MutableMapping[str, np.ndarray]]:
+  """Creates per-chain MSA features where the MSAs have been aligned.
+
+  Args:
+    chains: A list of feature dicts, one for each chain.
+    max_paired_sequences: No more than this many paired sequences will be
+      returned from this function.
+    nonempty_chain_ids: A set of chain ids (str) that are included in the
+      crop; there is no reason to process chains not in this set.
+    max_hits_per_species: No more than this number of sequences will be
+      returned for a given species.
+
+  Returns:
+    An updated feature dictionary for each chain, where the {}_all_seq features
+    have been aligned so that the nth row in chain 1 is aligned to the nth row
+    in chain 2's features.
+  """
+  # The number of chains that the given species appears in - we rank hits
+  # across more chains higher.
+  species_num_chains = {}
+
+  # For each chain we keep a mapping from species to the row indices in the
+  # original MSA for that chain.
+  chains_species_to_rows = []
+
+  # Keep track of the minimum number of hits across chains for a given species.
+  min_hits_per_species = {}
+
+  for chain in chains:
+    species_ids = chain['msa_species_identifiers_all_seq']
+
+    # The query gets an empty species_id, so no pairing happens for this row.
+    if (
+        species_ids.size == 0
+        or (species_ids.size == 1 and not species_ids[0])
+        or chain['chain_id'] not in nonempty_chain_ids
+    ):
+      chains_species_to_rows.append({})
+      continue
+
+    # For each species keep track of which row indices in the original MSA are
+    # from this species.
+    row_indices = np.arange(len(species_ids))
+    # The grouping np.split code requires that the input is already clustered
+    # by species id.
+    sort_idxs = species_ids.argsort()
+    species_ids = species_ids[sort_idxs]
+    row_indices = row_indices[sort_idxs]
+
+    species, unique_row_indices = np.unique(species_ids, return_index=True)
+    grouped_row_indices = np.split(row_indices, unique_row_indices[1:])
+    species_to_rows = dict(zip(species, grouped_row_indices, strict=True))
+    chains_species_to_rows.append(species_to_rows)
+
+    for s in species:
+      species_num_chains[s] = species_num_chains.get(s, 0) + 1
+
+    for species, row_indices in species_to_rows.items():
+      min_hits_per_species[species] = min(
+          min_hits_per_species.get(species, max_hits_per_species),
+          len(row_indices),
+      )
+
+  # Construct a mapping from the number of chains a species appears in to
+  # the list of species with that count.
+  num_chains_to_species = {}
+  for species, num_chains in species_num_chains.items():
+    if not species or num_chains <= 1:
+      continue
+    if num_chains not in num_chains_to_species:
+      num_chains_to_species[num_chains] = []
+    num_chains_to_species[num_chains].append(species)
+
+  num_rows_seen = 0
+  # We always keep the first row as it is the query sequence.
+  all_rows = [np.array([[0] * len(chains)], dtype=np.int32)]
+
+  # We prioritize species that have hits across more chains.
+  for num_chains in sorted(num_chains_to_species, reverse=True):
+    all_species = num_chains_to_species[num_chains]
+
+    # Align all the per-chain row indices by species, so every paired row is
+    # for a single species.
+    rows = _align_species(
+        all_species, chains_species_to_rows, min_hits_per_species
+    )
+    # Sort rows by the product of the original indices in the respective chain
+    # MSAs, so as to rank hits that appear higher in the original MSAs higher.
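+    # Worked example (illustrative, not from the upstream source): for aligned
+    # rows [[1, 2], [3, 1]] the products are [2, 3], so [1, 2] - whose hits sit
+    # nearer the top of both MSAs - ranks first. The np.abs below stops rows
+    # containing -1 pads from being favoured via a negative product.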
+ rank_metric = np.abs(np.prod(rows.astype(np.float32), axis=1)) + sorted_rows = rows[np.argsort(rank_metric), :] + all_rows.append(sorted_rows) + num_rows_seen += rows.shape[0] + if num_rows_seen >= max_paired_sequences: + break + + all_rows = np.concatenate(all_rows, axis=0) + all_rows = all_rows[:max_paired_sequences, :] + + # Now we just have to select the relevant rows from the original msa and + # deletion matrix features + paired_chains = [] + for chain_idx, chain in enumerate(chains): + out_chain = {k: v for k, v in chain.items() if 'all_seq' not in k} + selected_row_indices = all_rows[:, chain_idx] + for feat_name in {'msa', 'deletion_matrix'}: + all_seq_name = f'{feat_name}_all_seq' + feat_value = chain[all_seq_name] + + # The selected row indices are padded to be the same shape for each chain, + # they are padded with -1's, so we add a single row onto the feature with + # the appropriate pad value. This has the effect that we correctly pad + # each feature since all padded indices will select this padding row. + pad_value = data_constants.MSA_PAD_VALUES[feat_name] + feat_value = np.concatenate([ + feat_value, + np.full((1, feat_value.shape[1]), pad_value, feat_value.dtype), + ]) + + feat_value = feat_value[selected_row_indices, :] + out_chain[all_seq_name] = feat_value + out_chain['num_alignments_all_seq'] = np.array( + out_chain['msa_all_seq'].shape[0] + ) + paired_chains.append(out_chain) + return paired_chains + + +def deduplicate_unpaired_sequences( + np_chains: Sequence[MutableMapping[str, np.ndarray]], +) -> Sequence[MutableMapping[str, np.ndarray]]: + """Deduplicates unpaired sequences based on paired sequences.""" + + feature_names = np_chains[0].keys() + msa_features = ( + data_constants.NUM_SEQ_MSA_FEATURES + + data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES + ) + + for chain in np_chains: + sequence_set = set( + hash(s.data.tobytes()) for s in chain['msa_all_seq'].astype(np.int8) + ) + keep_rows = [] + # Go through unpaired MSA seqs and remove any rows that correspond to the + # sequences that are already present in the paired MSA. + for row_num, seq in enumerate(chain['msa'].astype(np.int8)): + if hash(seq.data.tobytes()) not in sequence_set: + keep_rows.append(row_num) + for feature_name in feature_names: + if feature_name in msa_features: + chain[feature_name] = chain[feature_name][keep_rows] + chain['num_alignments'] = np.array( + chain['msa'].shape[0], dtype=np.int32) + return np_chains + + +def choose_paired_unpaired_msa_crop_sizes( + unpaired_msa: np.ndarray, + paired_msa: np.ndarray | None, + total_msa_crop_size: int, + max_paired_sequences: int, +) -> tuple[int, int | None]: + """Returns the sizes of the MSA crop and MSA_all_seq crop. + + NOTE: Unpaired + paired MSA sizes can exceed total_msa_size when + there are lots of gapped rows. Through the pairing logic another chain(s) + will have fewer than total_msa_size. + + Args: + unpaired_msa: The unpaired MSA array (not all_seq). + paired_msa: The paired MSA array (all_seq). + total_msa_crop_size: The maximum total number of sequences to crop to. + max_paired_sequences: The maximum number of sequences that can come from + MSA pairing. + + Returns: + A tuple of: + The size of the reduced MSA crop (not all_seq features). + The size of the unreduced MSA crop (for all_seq features) or None, if + paired_msa is None. 
+ """ + if paired_msa is not None: + paired_crop_size = np.minimum( + paired_msa.shape[0], max_paired_sequences) + + # We reduce the number of un-paired sequences, by the number of times a + # sequence from this chains MSA is included in the paired MSA. This keeps + # the MSA size for each chain roughly constant. + cropped_all_seq_msa = paired_msa[:max_paired_sequences] + num_non_gapped_pairs = cropped_all_seq_msa.shape[0] + + assert num_non_gapped_pairs <= max_paired_sequences + unpaired_crop_size = np.minimum( + unpaired_msa.shape[0], total_msa_crop_size - num_non_gapped_pairs + ) + assert unpaired_crop_size >= 0 + else: + unpaired_crop_size = np.minimum( + unpaired_msa.shape[0], total_msa_crop_size) + paired_crop_size = None + return unpaired_crop_size, paired_crop_size + + +def remove_all_gapped_rows_from_all_seqs( + chains_list: Sequence[dict[str, np.ndarray]], asym_ids: Sequence[float] +) -> Sequence[dict[str, np.ndarray]]: + """Removes all gapped rows from all_seq feat based on selected asym_ids.""" + + merged_msa_all_seq = np.concatenate( + [ + chain['msa_all_seq'] + for chain in chains_list + if chain['asym_id'][0] in asym_ids + ], + axis=1, + ) + + non_gapped_keep_rows = np.any( + merged_msa_all_seq != data_constants.MSA_GAP_IDX, axis=1 + ) + for chain in chains_list: + for feat_name in list(chains_list)[0]: + if '_all_seq' in feat_name: + feat_name_split = feat_name.split('_all_seq')[0] + if feat_name_split in ( + data_constants.NUM_SEQ_NUM_RES_MSA_FEATURES + + data_constants.NUM_SEQ_MSA_FEATURES + ): + # For consistency we do this for all chains even though the + # gapped rows are based on a selected set asym_ids. + chain[feat_name] = chain[feat_name][non_gapped_keep_rows] + chain['num_alignments_all_seq'] = np.sum(non_gapped_keep_rows) + return chains_list diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/params.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/params.py new file mode 100644 index 000000000..3c1d22df6 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/params.py @@ -0,0 +1,218 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Model param loading."""
+
+import bisect
+import collections
+from collections.abc import Iterator
+import contextlib
+import io
+import os
+import pathlib
+import re
+import struct
+import sys
+from typing import IO
+
+# Assumed dependency: importing ml_dtypes registers the 'bfloat16' dtype
+# string with NumPy, which _read_record relies on below.
+import ml_dtypes  # pylint: disable=unused-import
+import numpy as np
+
+
+class RecordError(Exception):
+  """Error reading a record."""
+
+
+def encode_record(scope: str, name: str, arr: np.ndarray) -> bytes:
+  """Encodes a single haiku param as bytes, preserving non-numpy dtypes."""
+  scope = scope.encode('utf-8')
+  name = name.encode('utf-8')
+  shape = arr.shape
+  dtype = str(arr.dtype).encode('utf-8')
+  arr = np.ascontiguousarray(arr)
+  if sys.byteorder == 'big':
+    arr = arr.byteswap()
+  arr_buffer = arr.tobytes('C')
+  header = struct.pack(
+      '<5i', len(scope), len(name), len(dtype), len(shape), len(arr_buffer)
+  )
+  # Shape ints are packed little-endian to match the little-endian header.
+  return header + b''.join(
+      (scope, name, dtype, struct.pack(f'<{len(shape)}i', *shape), arr_buffer)
+  )
+
+
+def _read_record(stream: IO[bytes]) -> tuple[str, str, np.ndarray] | None:
+  """Reads a record encoded by `encode_record` from a byte stream."""
+  header_size = struct.calcsize('<5i')
+  header = stream.read(header_size)
+  if not header:
+    return None
+  if len(header) < header_size:
+    raise RecordError(
+        f'Incomplete header: {len(header)=} < {header_size=}')
+  (scope_len, name_len, dtype_len, shape_len, arr_buffer_len) = struct.unpack(
+      '<5i', header
+  )
+  fmt = f'<{scope_len}s{name_len}s{dtype_len}s{shape_len}i'
+  payload_size = struct.calcsize(fmt) + arr_buffer_len
+  payload = stream.read(payload_size)
+  if len(payload) < payload_size:
+    raise RecordError(
+        f'Incomplete payload: {len(payload)=} < {payload_size=}')
+  scope, name, dtype, *shape = struct.unpack_from(fmt, payload)
+  scope = scope.decode('utf-8')
+  name = name.decode('utf-8')
+  dtype = dtype.decode('utf-8')
+  if dtype == 'bfloat16':
+    # Records are written little-endian; byteswap the uint16 payload on
+    # big-endian hosts, then view as bfloat16 and upcast to float32.
+    arr_uint16 = np.frombuffer(payload[-arr_buffer_len:], dtype=np.uint16)
+    if sys.byteorder == 'big':
+      arr_uint16 = arr_uint16.byteswap()
+    arr = arr_uint16.view('bfloat16').astype(np.float32)
+  else:
+    arr = np.frombuffer(payload[-arr_buffer_len:], dtype=dtype)
+    if sys.byteorder == 'big':
+      arr = arr.byteswap()
+  arr = np.reshape(arr, shape)
+  return scope, name, arr
+
+
+def read_records(stream: IO[bytes]) -> Iterator[tuple[str, str, np.ndarray]]:
+  """Fully reads the contents of a byte stream."""
+  while record := _read_record(stream):
+    yield record
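+# ---------------------------------------------------------------------------
+# Illustrative sketch only (not part of the upstream code): a round-trip
+# through the record format above. The scope/name strings are made up.
+# ---------------------------------------------------------------------------
+def _demo_record_round_trip() -> None:
+  blob = encode_record(
+      'toy_scope', 'weight', np.arange(6, dtype=np.float32).reshape(2, 3)
+  )
+  scope, name, arr = next(read_records(io.BytesIO(blob)))
+  assert (scope, name) == ('toy_scope', 'weight')
+  assert arr.shape == (2, 3) and arr.dtype == np.float32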
+class _MultiFileIO(io.RawIOBase):
+  """A file-like object that presents a concatenated view of multiple files."""
+
+  def __init__(self, files: list[pathlib.Path]):
+    self._files = files
+    self._stack = contextlib.ExitStack()
+    self._handles = [
+        self._stack.enter_context(file.open('rb')) for file in files
+    ]
+    self._sizes = []
+    for handle in self._handles:
+      handle.seek(0, os.SEEK_END)
+      self._sizes.append(handle.tell())
+    self._length = sum(self._sizes)
+    self._offsets = [0]
+    for s in self._sizes[:-1]:
+      self._offsets.append(self._offsets[-1] + s)
+    self._abspos = 0
+    self._relpos = (0, 0)
+
+  def _abs_to_rel(self, pos: int) -> tuple[int, int]:
+    idx = bisect.bisect_right(self._offsets, pos) - 1
+    return idx, pos - self._offsets[idx]
+
+  def close(self):
+    self._stack.close()
+
+  def closed(self) -> bool:
+    return all(handle.closed for handle in self._handles)
+
+  def fileno(self) -> int:
+    return -1
+
+  def readable(self) -> bool:
+    return True
+
+  def tell(self) -> int:
+    return self._abspos
+
+  def seek(self, pos: int, whence: int = os.SEEK_SET, /):
+    match whence:
+      case os.SEEK_SET:
+        pass
+      case os.SEEK_CUR:
+        pos += self._abspos
+      case os.SEEK_END:
+        pos = self._length - pos
+      case _:
+        raise ValueError(f'Invalid whence: {whence}')
+    self._abspos = pos
+    self._relpos = self._abs_to_rel(pos)
+
+  def readinto(self, b: bytearray | memoryview) -> int:
+    result = 0
+    mem = memoryview(b)
+    while mem:
+      self._handles[self._relpos[0]].seek(self._relpos[1])
+      count = self._handles[self._relpos[0]].readinto(mem)
+      result += count
+      self._abspos += count
+      self._relpos = self._abs_to_rel(self._abspos)
+      mem = mem[count:]
+      if self._abspos == self._length:
+        break
+    return result
+
+
+@contextlib.contextmanager
+def open_for_reading(model_files: list[pathlib.Path], is_compressed: bool):
+  if is_compressed:
+    # Zstandard decompression is not wired up in this port; fail loudly
+    # instead of yielding compressed bytes to the record reader.
+    raise NotImplementedError(
+        'Compressed (.zst) model files are not supported.')
+  with contextlib.closing(_MultiFileIO(model_files)) as f:
+    yield f
+
+
+def _match_model(
+    paths: list[pathlib.Path], pattern: re.Pattern[str]
+) -> dict[str, list[pathlib.Path]]:
+  """Match files in a directory with a pattern, and group by model name."""
+  models = collections.defaultdict(list)
+  for path in paths:
+    match = pattern.fullmatch(path.name)
+    if match:
+      models[match.group('model_name')].append(path)
+  return {k: sorted(v) for k, v in models.items()}
+
+
+def select_model_files(
+    model_dir: pathlib.Path, model_name: str | None = None
+) -> tuple[list[pathlib.Path], bool]:
+  """Select the model files from a model directory."""
+  files = [file for file in model_dir.iterdir() if file.is_file()]
+
+  for pattern, is_compressed in (
+      (r'(?P<model_name>.*)\.[0-9]+\.bin\.zst$', True),
+      (r'(?P<model_name>.*)\.bin\.zst\.[0-9]+$', True),
+      (r'(?P<model_name>.*)\.[0-9]+\.bin$', False),
+      (r'(?P<model_name>.*)\.bin\.[0-9]+$', False),
+      (r'(?P<model_name>.*)\.bin\.zst$', True),
+      (r'(?P<model_name>.*)\.bin$', False),
+  ):
+    models = _match_model(files, re.compile(pattern))
+    if model_name is not None:
+      if model_name in models:
+        return models[model_name], is_compressed
+    else:
+      if models:
+        if len(models) > 1:
+          raise RuntimeError(
+              f'Multiple models matched in {model_dir}')
+        _, model_files = models.popitem()
+        return model_files, is_compressed
+  raise FileNotFoundError(f'No models matched in {model_dir}')
+
+
+def get_model_af3_params(model_dir: pathlib.Path):
+  """Get the Haiku parameters from a model directory."""
+  params: dict[str, dict[str, np.ndarray]] = {}
+  model_files, is_compressed = select_model_files(model_dir)
+  with open_for_reading(model_files, is_compressed) as stream:
+    for scope, name, arr in read_records(stream):
+      params.setdefault(scope, {})[name] = np.array(arr)
+  if not params:
+    raise FileNotFoundError(f'Model missing from "{model_dir}"')
+  return params
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/inter_chain_bonds.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/inter_chain_bonds.py
new file mode 100644
index 000000000..08953b2cb
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/inter_chain_bonds.py
@@ -0,0 +1,347 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Functions for handling inter-chain bonds."""
+
+from collections.abc import Collection
+import functools
+from typing import Final, NamedTuple
+from alphafold3 import structure
+from alphafold3.constants import chemical_component_sets
+from alphafold3.constants import mmcif_names
+from alphafold3.model.atom_layout import atom_layout
+import numpy as np
+
+
+BOND_THRESHOLD_GLYCANS_ANGSTROM: Final[float] = 1.7
+# See https://pubs.acs.org/doi/10.1021/ja010331r for P-P atom bond distances.
+BOND_THRESHOLD_ALL_ANGSTROM: Final[float] = 2.4
+
+
+class BondAtomArrays(NamedTuple):
+  chain_id: np.ndarray
+  chain_type: np.ndarray
+  res_id: np.ndarray
+  res_name: np.ndarray
+  atom_name: np.ndarray
+  coords: np.ndarray
+
+
+def _get_bond_atom_arrays(
+    struc: structure.Structure, bond_atom_indices: np.ndarray
+) -> BondAtomArrays:
+  return BondAtomArrays(
+      chain_id=struc.chain_id[bond_atom_indices],
+      chain_type=struc.chain_type[bond_atom_indices],
+      res_id=struc.res_id[bond_atom_indices],
+      res_name=struc.res_name[bond_atom_indices],
+      atom_name=struc.atom_name[bond_atom_indices],
+      coords=struc.coords[..., bond_atom_indices, :],
+  )
+
+
+@functools.lru_cache(maxsize=1)
+def get_polymer_ligand_and_ligand_ligand_bonds(
+    struct: structure.Structure,
+    only_glycan_ligands: bool,
+    allow_multiple_bonds_per_atom: bool,
+) -> tuple[atom_layout.AtomLayout, atom_layout.AtomLayout]:
+  """Return polymer-ligand & ligand-ligand inter-residue bonds.
+
+  Args:
+    struct: Structure object to extract bonds from.
+    only_glycan_ligands: Whether to only include glycans in ligand category.
+    allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first
+      bond seen per atom and discard the remaining bonds on that atom.
+
+  Returns:
+    polymer_ligand_bonds, ligand_ligand_bonds: Each is an AtomLayout object
+      [num_bonds, 2] for the bond-defining atoms.
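+
+  For example, for an N-glycosylated protein one would expect the
+  ASN(ND2)-NAG(C1) link in polymer_ligand_bonds and inter-residue
+  NAG(O4)-NAG(C1) links in ligand_ligand_bonds (illustrative residue and atom
+  names).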
+ """ + if only_glycan_ligands: + allowed_res_names = list({ + *chemical_component_sets.GLYCAN_OTHER_LIGANDS, + *chemical_component_sets.GLYCAN_LINKING_LIGANDS, + }) + else: + allowed_res_names = None + all_bonds = get_bond_layout( + bond_threshold=BOND_THRESHOLD_GLYCANS_ANGSTROM + if only_glycan_ligands + else BOND_THRESHOLD_ALL_ANGSTROM, + struct=struct, + allowed_chain_types1=list({ + *mmcif_names.LIGAND_CHAIN_TYPES, + *mmcif_names.POLYMER_CHAIN_TYPES, + }), + allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES), + allowed_res_names=allowed_res_names, + allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom, + ) + ligand_ligand_bonds_mask = np.isin( + all_bonds.chain_type, list(mmcif_names.LIGAND_CHAIN_TYPES) + ) + polymer_ligand_bonds_mask = np.isin( + all_bonds.chain_type, list(mmcif_names.POLYMER_CHAIN_TYPES) + ) + polymer_ligand_bonds_mask = np.logical_and( + ligand_ligand_bonds_mask.any(axis=1), + polymer_ligand_bonds_mask.any(axis=1), + ) + ligand_ligand_bonds = all_bonds[ligand_ligand_bonds_mask.all(axis=1)] + polymer_ligand_bonds = all_bonds[polymer_ligand_bonds_mask] + return polymer_ligand_bonds, ligand_ligand_bonds + + +def _remove_multi_bonds( + bond_layout: atom_layout.AtomLayout, +) -> atom_layout.AtomLayout: + """Remove instances greedily.""" + uids = {} + keep_indx = [] + for chain_id, res_id, atom_name in zip( + bond_layout.chain_id, + bond_layout.res_id, + bond_layout.atom_name, + strict=True, + ): + key1 = (chain_id[0], res_id[0], atom_name[0]) + key2 = (chain_id[1], res_id[1], atom_name[1]) + keep_indx.append(bool(key1 not in uids) and bool(key2 not in uids)) + if key1 not in uids: + uids[key1] = None + if key2 not in uids: + uids[key2] = None + return bond_layout[np.array(keep_indx, dtype=bool)] + + +@functools.lru_cache(maxsize=1) +def get_ligand_ligand_bonds( + struct: structure.Structure, + only_glycan_ligands: bool, + allow_multiple_bonds_per_atom: bool = False, +) -> atom_layout.AtomLayout: + """Return ligand-ligand inter-residue bonds. + + Args: + struct: Structure object to extract bonds from. + only_glycan_ligands: Whether to only include glycans in ligand category. + allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first + bond seen per atom and discard the remaining on each atom. + + Returns: + bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms. + """ + if only_glycan_ligands: + allowed_res_names = list({ + *chemical_component_sets.GLYCAN_OTHER_LIGANDS, + *chemical_component_sets.GLYCAN_LINKING_LIGANDS, + }) + else: + allowed_res_names = None + return get_bond_layout( + bond_threshold=BOND_THRESHOLD_GLYCANS_ANGSTROM + if only_glycan_ligands + else BOND_THRESHOLD_ALL_ANGSTROM, + struct=struct, + allowed_chain_types1=list(mmcif_names.LIGAND_CHAIN_TYPES), + allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES), + allowed_res_names=allowed_res_names, + allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom, + ) + + +@functools.lru_cache(maxsize=1) +def get_polymer_ligand_bonds( + struct: structure.Structure, + only_glycan_ligands: bool, + allow_multiple_bonds_per_atom: bool = False, + bond_threshold: float | None = None, +) -> atom_layout.AtomLayout: + """Return polymer-ligand interchain bonds. + + Args: + struct: Structure object to extract bonds from. + only_glycan_ligands: Whether to only include glycans in ligand category. + allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first + bond seen per atom and discard the remaining on each atom. 
+    bond_threshold: Maximum allowed bond length in Angstroms. If None, a
+      default is chosen based on `only_glycan_ligands`.
+
+  Returns:
+    bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms.
+  """
+  if only_glycan_ligands:
+    allowed_res_names = list({
+        *chemical_component_sets.GLYCAN_OTHER_LIGANDS,
+        *chemical_component_sets.GLYCAN_LINKING_LIGANDS,
+    })
+  else:
+    allowed_res_names = None
+  if bond_threshold is None:
+    if only_glycan_ligands:
+      bond_threshold = BOND_THRESHOLD_GLYCANS_ANGSTROM
+    else:
+      bond_threshold = BOND_THRESHOLD_ALL_ANGSTROM
+  return get_bond_layout(
+      bond_threshold=bond_threshold,
+      struct=struct,
+      allowed_chain_types1=list(mmcif_names.POLYMER_CHAIN_TYPES),
+      allowed_chain_types2=list(mmcif_names.LIGAND_CHAIN_TYPES),
+      allowed_res_names=allowed_res_names,
+      allow_multiple_bonds_per_atom=allow_multiple_bonds_per_atom,
+  )
+
+
+def get_bond_layout(
+    bond_threshold: float = BOND_THRESHOLD_ALL_ANGSTROM,
+    *,
+    struct: structure.Structure,
+    allowed_chain_types1: Collection[str],
+    allowed_chain_types2: Collection[str],
+    include_bond_types: Collection[str] = ('covale',),
+    allowed_res_names: Collection[str] | None = None,
+    allow_multiple_bonds_per_atom: bool,
+) -> atom_layout.AtomLayout:
+  """Get bond_layout for all bonds between two sets of chain types.
+
+  A single mask (all_mask) is threaded through this function: a bond pair is
+  kept only if it stays True under every condition, and it is removed entirely
+  as soon as any condition sets it to False. Note that oxygen atom bonds are
+  removed as an edge case that causes issues with scoring, due to multiple
+  waters bonding with single residues.
+
+  Args:
+    bond_threshold: Maximum bond distance in Angstrom.
+    struct: Structure object to extract bonds from.
+    allowed_chain_types1: One end of the bonds must be an atom with one of
+      these chain types.
+    allowed_chain_types2: The other end of the bond must be an atom with one of
+      these chain types.
+    include_bond_types: Only include bonds of the specified mmCIF types, e.g.
+      hydrog, metalc, covale, disulf.
+    allowed_res_names: Further restriction on top of chain types: either end of
+      a bond must be an atom that is part of these res_names. If None, all
+      bonds are accepted after chain- and bond-type filtering.
+    allow_multiple_bonds_per_atom: If not allowed, we greedily choose the first
+      bond seen per atom and discard the remaining bonds on that atom.
+
+  Returns:
+    bond_layout: AtomLayout object [num_bonds, 2] for the bond-defining atoms.
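+
+  Example (illustrative) - covalent polymer-ligand bonds only:
+
+    bonds = get_bond_layout(
+        struct=struc,
+        allowed_chain_types1=mmcif_names.POLYMER_CHAIN_TYPES,
+        allowed_chain_types2=mmcif_names.LIGAND_CHAIN_TYPES,
+        allow_multiple_bonds_per_atom=False,
+    )
+    # bonds.atom_name, bonds.res_id and bonds.chain_id have shape
+    # [num_bonds, 2].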
+ """ + if not struct.bonds: + return atom_layout.AtomLayout( + atom_name=np.empty((0, 2), dtype=object), + res_id=np.empty((0, 2), dtype=int), + res_name=np.empty((0, 2), dtype=object), + chain_id=np.empty((0, 2), dtype=object), + chain_type=np.empty((0, 2), dtype=object), + atom_element=np.empty((0, 2), dtype=object), + ) + from_atom_idxs, dest_atom_idxs = struct.bonds.get_atom_indices( + struct.atom_key + ) + from_atoms = _get_bond_atom_arrays(struct, from_atom_idxs) + dest_atoms = _get_bond_atom_arrays(struct, dest_atom_idxs) + # Chain type + chain_mask = np.logical_or( + np.logical_and( + np.isin( + from_atoms.chain_type, + allowed_chain_types1, + ), + np.isin( + dest_atoms.chain_type, + allowed_chain_types2, + ), + ), + np.logical_and( + np.isin( + from_atoms.chain_type, + allowed_chain_types2, + ), + np.isin( + dest_atoms.chain_type, + allowed_chain_types1, + ), + ), + ) + if allowed_res_names: + # Res type + res_mask = np.logical_or( + np.isin(from_atoms.res_name, allowed_res_names), + np.isin(dest_atoms.res_name, allowed_res_names), + ) + # All mask + all_mask = np.logical_and(chain_mask, res_mask) + else: + all_mask = chain_mask + # Bond type mask + type_mask = np.isin(struct.bonds.type, list(include_bond_types)) + np.logical_and(all_mask, type_mask, out=all_mask) + # Bond length check. Work in square length to avoid taking many square roots. + bond_length_squared = np.square(from_atoms.coords - dest_atoms.coords).sum( + axis=1 + ) + bond_threshold_squared = bond_threshold * bond_threshold + np.logical_and( + all_mask, bond_length_squared < bond_threshold_squared, out=all_mask + ) + # Inter-chain and inter-residue bonds for ligands + ligand_types = list(mmcif_names.LIGAND_CHAIN_TYPES) + is_ligand = np.logical_or( + np.isin( + from_atoms.chain_type, + ligand_types, + ), + np.isin( + dest_atoms.chain_type, + ligand_types, + ), + ) + res_id_differs = from_atoms.res_id != dest_atoms.res_id + chain_id_differs = from_atoms.chain_id != dest_atoms.chain_id + is_inter_res = np.logical_or(res_id_differs, chain_id_differs) + is_inter_ligand_res = np.logical_and(is_inter_res, is_ligand) + is_inter_chain_not_ligand = np.logical_and(chain_id_differs, ~is_ligand) + # If ligand then inter-res & inter-chain bonds, otherwise inter-chain only. + combined_allowed_bonds = np.logical_or( + is_inter_chain_not_ligand, is_inter_ligand_res + ) + np.logical_and(all_mask, combined_allowed_bonds, out=all_mask) + bond_layout = atom_layout.AtomLayout( + atom_name=np.stack( + [ + from_atoms.atom_name[all_mask], + dest_atoms.atom_name[all_mask], + ], + axis=1, + dtype=object, + ), + res_id=np.stack( + [from_atoms.res_id[all_mask], dest_atoms.res_id[all_mask]], + axis=1, + dtype=int, + ), + chain_id=np.stack( + [ + from_atoms.chain_id[all_mask], + dest_atoms.chain_id[all_mask], + ], + axis=1, + dtype=object, + ), + ) + if not allow_multiple_bonds_per_atom: + bond_layout = _remove_multi_bonds(bond_layout) + return atom_layout.fill_in_optional_fields( + bond_layout, + reference_atoms=atom_layout.atom_layout_from_structure(struct), + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/pipeline.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/pipeline.py new file mode 100644 index 000000000..0da3d6c41 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/pipeline.py @@ -0,0 +1,446 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""The main featurizer.""" + +import bisect +from collections.abc import Sequence +import datetime +import itertools + +from absl import logging +from alphafold3.common import base_config +from alphafold3.common import folding_input +from alphafold3.constants import chemical_components +from alphafold3.model import feat_batch +from alphafold3.model import features +from alphafold3.model.pipeline import inter_chain_bonds +from alphafold3.model.pipeline import structure_cleaning +from alphafold3.structure import chemical_components as struc_chem_comps +import numpy as np +from alphafold3.common.folding_input import Template + + +_DETERMINISTIC_FRAMES_RANDOM_SEED = 12312837 + + +def calculate_bucket_size( + num_tokens: int, buckets: Sequence[int] | None +) -> int: + """Calculates the bucket size to pad the data to.""" + if buckets is None: + return num_tokens + + if not buckets: + raise ValueError('Buckets must be non-empty.') + + if not all(prev < curr for prev, curr in itertools.pairwise(buckets)): + raise ValueError( + f'Buckets must be in strictly increasing order. Got {buckets=}.' + ) + + bucket_idx = bisect.bisect_left(buckets, num_tokens) + + if bucket_idx == len(buckets): + logging.warning( + 'Creating a new bucket of size %d since the input has more tokens than' + ' the largest bucket size %d. This may trigger a re-compilation of the' + ' model. Consider additional large bucket sizes to avoid excessive' + ' re-compilation.', + num_tokens, + buckets[-1], + ) + return num_tokens + + return buckets[bucket_idx] + + +class NanDataError(Exception): + """Raised if the data pipeline produces data containing nans.""" + + +class TotalNumResOutOfRangeError(Exception): + """Raised if total number of residues for all chains outside allowed range.""" + + +class MmcifNumChainsError(Exception): + """Raised if the mmcif file contains too many / too few chains.""" + + +class WholePdbPipeline: + """Processes an entire mmcif entity and merges the content.""" + + class Config(base_config.BaseConfig): + """Configuration object for `WholePdbPipeline`. + + Properties: + max_atoms_per_token: number of atom slots in one token (was called + num_dense, and semi-hardcoded to 24 before) + pad_num_chains: Size to pad NUM_CHAINS feature dimensions to, only for + protein chains. + buckets: Bucket sizes to pad the data to, to avoid excessive + re-compilation of the model. If None, calculate the appropriate bucket + size from the number of tokens. If not None, must be a sequence of at + least one integer, in strictly increasing order. Will raise an error if + the number of tokens is more than the largest bucket size. + max_total_residues: Any mmCIF with more total residues will be rejected. + If none, then no limit is applied. + min_total_residues: Any mmCIF with less total residues will be rejected. + msa_crop_size: Maximum size of MSA to take across all chains. + max_template_date: Optional max template date to prevent data leakage in + validation. + max_templates: The maximum number of templates to send through the network + set to 0 to switch off templates. 
+      filter_clashes: If true then will remove clashing chains.
+      filter_crystal_aids: If true, ligands in the crystallization aid list are
+        removed.
+      max_paired_sequence_per_species: The maximum number of sequences per
+        species that will be used for MSA pairing.
+      drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands.
+      intra_ligand_ptm_bonds: Whether to embed intra ligand covalent bond graph.
+      average_num_atoms_per_token: Target average number of atoms per token to
+        compute the padding size for flat atoms.
+      atom_cross_att_queries_subset_size: Queries subset size in atom cross
+        attention.
+      atom_cross_att_keys_subset_size: Keys subset size in atom cross attention.
+      flatten_non_standard_residues: Whether to expand non-standard polymer
+        residues into flat-atom format.
+      remove_nonsymmetric_bonds: Whether to remove nonsymmetric bonds from
+        symmetric polymer chains.
+      deterministic_frames: Whether to use fixed-seed reference positions to
+        construct deterministic frames.
+    """
+
+    max_atoms_per_token: int = 24
+    pad_num_chains: int = 1000
+    buckets: list[int] | None = None
+    max_total_residues: int | None = None
+    min_total_residues: int | None = None
+    msa_crop_size: int = 16384
+    max_template_date: datetime.date | None = None
+    max_templates: int = 4
+    filter_clashes: bool = False
+    filter_crystal_aids: bool = False
+    max_paired_sequence_per_species: int = 600
+    drop_ligand_leaving_atoms: bool = True
+    intra_ligand_ptm_bonds: bool = True
+    average_num_atoms_per_token: int = 24
+    atom_cross_att_queries_subset_size: int = 32
+    atom_cross_att_keys_subset_size: int = 128
+    flatten_non_standard_residues: bool = True
+    remove_nonsymmetric_bonds: bool = False
+    deterministic_frames: bool = True
+
+  def __init__(
+      self,
+      *,
+      config: Config,
+  ):
+    """Init WholePdb.
+
+    Args:
+      config: Pipeline configuration.
+    """
+    self._config = config
+
+  def process_item(
+      self,
+      fold_input: folding_input.Input,
+      random_state: np.random.RandomState,
+      ccd: chemical_components.Ccd,
+      random_seed: int | None = None,
+  ) -> features.BatchDict:
+    """Featurises a single fold input into a padded batch of model features."""
+    if random_seed is None:
+      random_seed = random_state.randint(2**31)
+
+    random_state = np.random.RandomState(seed=random_seed)
+
+    logging_name = f'{fold_input.name}, random_seed={random_seed}'
+    logging.info('processing %s', logging_name)
+    struct = fold_input.to_structure(ccd=ccd)
+
+    # Clean structure.
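+    # Cleaning drops waters, hydrogens and chains with unknown sequence and,
+    # depending on the config, clashing chains and crystallization aids; it
+    # also prunes bad and polymer-polymer bonds (see
+    # structure_cleaning.clean_structure).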
+    cleaned_struc, cleaning_metadata = structure_cleaning.clean_structure(
+        struct,
+        ccd=ccd,
+        drop_non_standard_atoms=True,
+        drop_missing_sequence=True,
+        filter_clashes=self._config.filter_clashes,
+        filter_crystal_aids=self._config.filter_crystal_aids,
+        filter_waters=True,
+        filter_hydrogens=True,
+        filter_leaving_atoms=self._config.drop_ligand_leaving_atoms,
+        only_glycan_ligands_for_leaving_atoms=True,
+        covalent_bonds_only=True,
+        remove_polymer_polymer_bonds=True,
+        remove_bad_bonds=True,
+        remove_nonsymmetric_bonds=self._config.remove_nonsymmetric_bonds,
+    )
+
+    num_clashing_chains_removed = cleaning_metadata[
+        'num_clashing_chains_removed'
+    ]
+
+    if num_clashing_chains_removed:
+      logging.info(
+          'Removed %d clashing chains from %s',
+          num_clashing_chains_removed,
+          logging_name,
+      )
+
+    # No chains after fixes
+    # if cleaned_struc.num_chains == 0:
+    #   raise MmcifNumChainsError(f'{logging_name}: No chains in structure!')
+
+    polymer_ligand_bonds, ligand_ligand_bonds = (
+        inter_chain_bonds.get_polymer_ligand_and_ligand_ligand_bonds(
+            cleaned_struc,
+            only_glycan_ligands=False,
+            allow_multiple_bonds_per_atom=True,
+        )
+    )
+
+    # If empty, replace with None as empty layouts cause errors downstream.
+    if ligand_ligand_bonds and not ligand_ligand_bonds.atom_name.size:
+      ligand_ligand_bonds = None
+    if polymer_ligand_bonds and not polymer_ligand_bonds.atom_name.size:
+      polymer_ligand_bonds = None
+
+    # Create the flat output AtomLayout
+    empty_output_struc, flat_output_layout = (
+        structure_cleaning.create_empty_output_struc_and_layout(
+            struc=cleaned_struc,
+            ccd=ccd,
+            polymer_ligand_bonds=polymer_ligand_bonds,
+            ligand_ligand_bonds=ligand_ligand_bonds,
+            drop_ligand_leaving_atoms=self._config.drop_ligand_leaving_atoms,
+        )
+    )
+
+    # Select the tokens for Evoformer.
+    # Each token (e.g. a residue) is encoded as one representative atom. This
+    # is flexible enough to allow the 1-token-per-atom ligand representation
+    # in the future.
+    all_tokens, all_token_atoms_layout, standard_token_idxs = (
+        features.tokenizer(
+            flat_output_layout,
+            ccd=ccd,
+            max_atoms_per_token=self._config.max_atoms_per_token,
+            flatten_non_standard_residues=self._config.flatten_non_standard_residues,
+            logging_name=logging_name,
+        )
+    )
+    total_tokens = len(all_tokens.atom_name)
+    if (
+        self._config.max_total_residues
+        and total_tokens > self._config.max_total_residues
+    ):
+      raise TotalNumResOutOfRangeError(
+          'Total Number of Residues > max_total_residues: '
+          f'({total_tokens} > {self._config.max_total_residues})'
+      )
+
+    if (
+        self._config.min_total_residues
+        and total_tokens < self._config.min_total_residues
+    ):
+      raise TotalNumResOutOfRangeError(
+          'Total Number of Residues < min_total_residues: '
+          f'({total_tokens} < {self._config.min_total_residues})'
+      )
+
+    logging.info(
+        'Calculating bucket size for input with %d tokens.', total_tokens
+    )
+    padded_token_length = calculate_bucket_size(
+        total_tokens, self._config.buckets
+    )
+    logging.info(
+        'Got bucket size %d for input with %d tokens, resulting in %d padding'
+        ' tokens.',
+        padded_token_length,
+        total_tokens,
+        padded_token_length - total_tokens,
+    )
+
+    # Padding shapes for all features.
+    num_atoms = padded_token_length * self._config.average_num_atoms_per_token
+    # Round up to next multiple of subset size.
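+    # E.g. padded_token_length=256 with the default
+    # average_num_atoms_per_token=24 gives 6144 atom slots, which is already a
+    # multiple of the default queries_subset_size=32 (6144 = 192 * 32).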
+ num_atoms = int( + np.ceil(num_atoms / self._config.atom_cross_att_queries_subset_size) + * self._config.atom_cross_att_queries_subset_size + ) + padding_shapes = features.PaddingShapes( + num_tokens=padded_token_length, + msa_size=self._config.msa_crop_size, + num_chains=self._config.pad_num_chains, + num_templates=self._config.max_templates, + num_atoms=num_atoms, + ) + + # Create the atom layouts for flat atom cross attention + batch_atom_cross_att = features.AtomCrossAtt.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + queries_subset_size=self._config.atom_cross_att_queries_subset_size, + keys_subset_size=self._config.atom_cross_att_keys_subset_size, + padding_shapes=padding_shapes, + ) + + # Extract per-token features + batch_token_features = features.TokenFeatures.compute_features( + all_tokens=all_tokens, + padding_shapes=padding_shapes, + ) + + # Create reference structure features + chemical_components_data = struc_chem_comps.populate_missing_ccd_data( + ccd=ccd, + chemical_components_data=cleaned_struc.chemical_components_data, + populate_pdbx_smiles=True, + ) + + # Add smiles info to empty_output_struc. + empty_output_struc = empty_output_struc.copy_and_update_globals( + chemical_components_data=chemical_components_data + ) + # Create layouts and store structures for model output conversion. + batch_convert_model_output = features.ConvertModelOutput.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + padding_shapes=padding_shapes, + cleaned_struc=cleaned_struc, + flat_output_layout=flat_output_layout, + empty_output_struc=empty_output_struc, + polymer_ligand_bonds=polymer_ligand_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + ) + + # Create the PredictedStructureInfo + batch_predicted_structure_info = ( + features.PredictedStructureInfo.compute_features( + all_tokens=all_tokens, + all_token_atoms_layout=all_token_atoms_layout, + padding_shapes=padding_shapes, + ) + ) + + # Create MSA features + batch_msa = features.MSA.compute_features( + all_tokens=all_tokens, + standard_token_idxs=standard_token_idxs, + padding_shapes=padding_shapes, + fold_input=fold_input, + logging_name=logging_name, + max_paired_sequence_per_species=self._config.max_paired_sequence_per_species, + ) + + # Create template features + batch_templates = features.Templates.compute_features( + all_tokens=all_tokens, + standard_token_idxs=standard_token_idxs, + padding_shapes=padding_shapes, + fold_input=fold_input, + max_templates=self._config.max_templates, + logging_name=logging_name, + ) + + ref_max_modified_date = self._config.max_template_date + batch_ref_structure, ligand_ligand_bonds = ( + features.RefStructure.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + ccd=ccd, + padding_shapes=padding_shapes, + chemical_components_data=chemical_components_data, + random_state=random_state, + ref_max_modified_date=ref_max_modified_date, + intra_ligand_ptm_bonds=self._config.intra_ligand_ptm_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + ) + ) + deterministic_ref_structure = None + if self._config.deterministic_frames: + deterministic_ref_structure, _ = features.RefStructure.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + ccd=ccd, + padding_shapes=padding_shapes, + chemical_components_data=chemical_components_data, + random_state=( + np.random.RandomState(_DETERMINISTIC_FRAMES_RANDOM_SEED) + ), + ref_max_modified_date=ref_max_modified_date, + intra_ligand_ptm_bonds=self._config.intra_ligand_ptm_bonds, + 
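# The inputs match the call above except for the fixed-seed random_state,
+          # so the resulting deterministic frames are reproducible across runs.
+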
ligand_ligand_bonds=ligand_ligand_bonds, + ) + + # Create ligand-polymer bond features. + polymer_ligand_bond_info = features.PolymerLigandBondInfo.compute_features( + all_tokens=all_tokens, + all_token_atoms_layout=all_token_atoms_layout, + bond_layout=polymer_ligand_bonds, + padding_shapes=padding_shapes, + ) + # Create ligand-ligand bond features. + ligand_ligand_bond_info = features.LigandLigandBondInfo.compute_features( + all_tokens, + ligand_ligand_bonds, + padding_shapes, + ) + + # Create the Pseudo-beta layout for distogram head and distance error head. + batch_pseudo_beta_info = features.PseudoBetaInfo.compute_features( + all_token_atoms_layout=all_token_atoms_layout, + ccd=ccd, + padding_shapes=padding_shapes, + logging_name=logging_name, + ) + + # Frame construction. + batch_frames = features.Frames.compute_features( + all_tokens=all_tokens, + all_token_atoms_layout=all_token_atoms_layout, + ref_structure=( + deterministic_ref_structure + if self._config.deterministic_frames + else batch_ref_structure + ), + padding_shapes=padding_shapes, + ) + + # Assemble the Batch object. + batch = feat_batch.Batch( + msa=batch_msa, + templates=batch_templates, + token_features=batch_token_features, + ref_structure=batch_ref_structure, + predicted_structure_info=batch_predicted_structure_info, + polymer_ligand_bond_info=polymer_ligand_bond_info, + ligand_ligand_bond_info=ligand_ligand_bond_info, + pseudo_beta_info=batch_pseudo_beta_info, + atom_cross_att=batch_atom_cross_att, + convert_model_output=batch_convert_model_output, + frames=batch_frames, + ) + + np_example = batch.as_data_dict() + if 'num_iter_recycling' in np_example: + del np_example['num_iter_recycling'] # that does not belong here + + for name, value in np_example.items(): + if ( + value.dtype.kind not in {'U', 'S'} + and value.dtype.name != 'object' + and np.isnan(np.sum(value)) + ): + raise NanDataError( + 'The output of the data pipeline contained nans. ' + f'nan feature: {name}, fold input name: {fold_input.name}, ' + f'random_seed {random_seed}' + ) + + return np_example diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/structure_cleaning.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/structure_cleaning.py new file mode 100644 index 000000000..3f0321a9c --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/pipeline/structure_cleaning.py @@ -0,0 +1,371 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Prepare PDB structure for training or inference."""
+
+from typing import Any
+
+from absl import logging
+from alphafold3 import structure
+from alphafold3.constants import chemical_component_sets
+from alphafold3.constants import chemical_components
+from alphafold3.constants import mmcif_names
+from alphafold3.model.atom_layout import atom_layout
+from alphafold3.model.pipeline import inter_chain_bonds
+from alphafold3.model.scoring import covalent_bond_cleaning
+from alphafold3.structure import sterics
+import numpy as np
+
+
+def _get_leaving_atom_mask(
+    struc: structure.Structure,
+    polymer_ligand_bonds: atom_layout.AtomLayout | None,
+    ligand_ligand_bonds: atom_layout.AtomLayout | None,
+    chain_id: str,
+    chain_type: str,
+    res_id: int,
+    res_name: str,
+) -> np.ndarray:
+  """Builds a mask marking the leaving atoms of one residue for removal."""
+  bonded_atoms = atom_layout.get_bonded_atoms(
+      polymer_ligand_bonds,
+      ligand_ligand_bonds,
+      res_id,
+      chain_id,
+  )
+  # Connect the amino-acids, i.e. remove OXT, HXT and H2.
+  drop_atoms = atom_layout.get_link_drop_atoms(
+      res_name=res_name,
+      chain_type=chain_type,
+      is_start_terminus=False,
+      is_end_terminus=False,
+      bonded_atoms=bonded_atoms,
+      drop_ligand_leaving_atoms=True,
+  )
+  # Default mask where everything is false, which equates to being kept.
+  drop_atom_filter_atoms = struc.chain_id != struc.chain_id
+  for drop_atom in drop_atoms:
+    drop_atom_filter_atom = np.logical_and(
+        np.logical_and(
+            struc.atom_name == drop_atom,
+            struc.chain_id == chain_id,
+        ),
+        struc.res_id == res_id,
+    )
+    drop_atom_filter_atoms = np.logical_or(
+        drop_atom_filter_atoms, drop_atom_filter_atom
+    )
+  return drop_atom_filter_atoms
+
+
+def clean_structure(
+    struc: structure.Structure,
+    ccd: chemical_components.Ccd,
+    *,
+    drop_missing_sequence: bool,
+    filter_clashes: bool,
+    drop_non_standard_atoms: bool,
+    filter_crystal_aids: bool,
+    filter_waters: bool,
+    filter_hydrogens: bool,
+    filter_leaving_atoms: bool,
+    only_glycan_ligands_for_leaving_atoms: bool,
+    covalent_bonds_only: bool,
+    remove_polymer_polymer_bonds: bool,
+    remove_bad_bonds: bool,
+    remove_nonsymmetric_bonds: bool,
+) -> tuple[structure.Structure, dict[str, Any]]:
+  """Cleans structure.
+
+  Args:
+    struc: Structure to clean.
+    ccd: The chemical components dictionary.
+    drop_missing_sequence: Whether to drop chains without specified sequences.
+    filter_clashes: Whether to drop clashing chains.
+    drop_non_standard_atoms: Whether to drop non CCD standard atoms.
+    filter_crystal_aids: Whether to drop ligands in the crystal aid set.
+    filter_waters: Whether to drop water chains.
+    filter_hydrogens: Whether to drop hydrogen atoms.
+    filter_leaving_atoms: Whether to drop leaving atoms based on heuristics.
+    only_glycan_ligands_for_leaving_atoms: Whether to only include glycan
+      ligands when filtering leaving atoms.
+    covalent_bonds_only: Only include covalent bonds.
+    remove_polymer_polymer_bonds: Remove polymer-polymer bonds.
+    remove_bad_bonds: Whether to remove badly bonded ligands.
+    remove_nonsymmetric_bonds: Whether to remove nonsymmetric polymer-ligand
+      bonds from symmetric polymer chains.
+
+  Returns:
+    Tuple of structure and metadata dict. The metadata dict has
+    information about what was cleaned from the original.
+  """
+
+  metadata = {}
+  # Drop crystallization aids.
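+  # Crystallization aids (e.g. sulfate or glycerol) are only removed when the
+  # structure comes from a crystallographic experiment.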
+ if ( + filter_crystal_aids + and struc.structure_method in mmcif_names.CRYSTALLIZATION_METHODS + ): + struc = struc.filter_out( + res_name=chemical_component_sets.COMMON_CRYSTALLIZATION_AIDS + ) + + # Drop chains without specified sequences. + if drop_missing_sequence: + chains_with_unk_sequence = struc.find_chains_with_unknown_sequence() + num_with_unk_sequence = len(chains_with_unk_sequence) + if chains_with_unk_sequence: + struc = struc.filter_out(chain_id=chains_with_unk_sequence) + else: + num_with_unk_sequence = 0 + metadata['num_with_unk_sequence'] = num_with_unk_sequence + + # Remove intersecting chains. + if filter_clashes and struc.num_chains > 1: + clashing_chains = sterics.find_clashing_chains(struc) + if clashing_chains: + struc = struc.filter_out(chain_id=clashing_chains) + else: + clashing_chains = [] + metadata['num_clashing_chains_removed'] = len(clashing_chains) + metadata['chains_removed'] = clashing_chains + + # Drop non-standard atoms + if drop_non_standard_atoms: + struc = struc.drop_non_standard_atoms( + ccd=ccd, drop_unk=False, drop_non_ccd=False + ) + + # Sort chains in "reverse-spreadsheet" order. + struc = struc.with_sorted_chains + + if filter_hydrogens: + struc = struc.without_hydrogen() + + if filter_waters: + struc = struc.filter_out(chain_type=mmcif_names.WATER) + + if filter_leaving_atoms: + drop_leaving_atoms_all = struc.chain_id != struc.chain_id + polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds( + struc, + only_glycan_ligands=only_glycan_ligands_for_leaving_atoms, + ) + ligand_ligand_bonds = inter_chain_bonds.get_ligand_ligand_bonds( + struc, + only_glycan_ligands=only_glycan_ligands_for_leaving_atoms, + ) + all_glycans = { + *chemical_component_sets.GLYCAN_OTHER_LIGANDS, + *chemical_component_sets.GLYCAN_LINKING_LIGANDS, + } + # If only glycan ligands and no O1 atoms, we can do parallel drop. + if ( + only_glycan_ligands_for_leaving_atoms + and (not (ligand_ligand_bonds.atom_name == 'O1').any()) + and (not (polymer_ligand_bonds.atom_name == 'O1').any()) + ): + drop_leaving_atoms_all = np.logical_and( + np.isin(struc.atom_name, 'O1'), + np.isin(struc.res_name, list(all_glycans)), + ) + else: + substruct = struc.group_by_residue + glycan_mask = np.isin(substruct.res_name, list(all_glycans)) + substruct = substruct.filter(glycan_mask) + # We need to iterate over all glycan residues for this. + for res in substruct.iter_residues(): + # Only need to do drop leaving atoms for glycans depending on bonds. + if (res_name := res['res_name']) in all_glycans: + drop_atom_filter = _get_leaving_atom_mask( + struc=struc, + polymer_ligand_bonds=polymer_ligand_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + chain_id=res['chain_id'], + chain_type=res['chain_type'], + res_id=res['res_id'], + res_name=res_name, + ) + drop_leaving_atoms_all = np.logical_or( + drop_leaving_atoms_all, drop_atom_filter + ) + + num_atoms_before = struc.num_atoms + struc = struc.filter_out(drop_leaving_atoms_all) + num_atoms_after = struc.num_atoms + + if num_atoms_before > num_atoms_after: + logging.error( + 'Dropped %s atoms from GT struc: chain_id %s res_id %s res_name %s', + num_atoms_before - num_atoms_after, + struc.chain_id, + struc.res_id, + struc.res_name, + ) + + # Can filter by bond type without having to iterate over bonds. 
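+  # np.isin vectorises the bond-type check; 'covale' is the mmCIF
+  # _struct_conn.conn_type_id value for covalent bonds.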
+ if struc.bonds and covalent_bonds_only: + is_covalent = np.isin(struc.bonds.type, ['covale']) + if sum(is_covalent) > 0: + new_bonds = struc.bonds[is_covalent] + else: + new_bonds = structure.Bonds.make_empty() + struc = struc.copy_and_update(bonds=new_bonds) + + # Other bond filters require iterating over individual bonds. + if struc.bonds and (remove_bad_bonds or remove_polymer_polymer_bonds): + include_bond = [] + num_pp_bonds = 0 + num_bad_bonds = 0 + for bond in struc.iter_bonds(): + dest_atom = bond.dest_atom + from_atom = bond.from_atom + if remove_polymer_polymer_bonds: + if ( + from_atom['chain_type'] in mmcif_names.POLYMER_CHAIN_TYPES + and dest_atom['chain_type'] in mmcif_names.POLYMER_CHAIN_TYPES + ): + num_pp_bonds += 1 + include_bond.append(False) + continue + if remove_bad_bonds: + dest_coords = np.array( + [dest_atom['atom_x'], dest_atom['atom_y'], dest_atom['atom_z']] + ) + from_coords = np.array( + [from_atom['atom_x'], from_atom['atom_y'], from_atom['atom_z']] + ) + squared_dist = np.sum(np.square(dest_coords - from_coords)) + squared_threshold = 2.4 * 2.4 + if squared_dist > squared_threshold: + num_bad_bonds += 1 + include_bond.append(False) + continue + include_bond.append(True) + if sum(include_bond) < len(struc.bonds): + logging.info( + 'Reducing number of bonds for %s from %s to %s, of which %s are' + ' polymer-polymer bonds and %s are bad bonds.', + struc.name, + len(struc.bonds), + sum(include_bond), + num_pp_bonds, + num_bad_bonds, + ) + if sum(include_bond) > 0: + # Need to index bonds with bond keys or arrays of bools with same length + # as num bonds. In this case, we use array of bools (as elsewhere in the + # cleaning code). + new_bonds = struc.bonds[np.array(include_bond, dtype=bool)] + else: + new_bonds = structure.Bonds.make_empty() + struc = struc.copy_and_update(bonds=new_bonds) + + if struc.bonds and remove_nonsymmetric_bonds: + # Check for asymmetric polymer-ligand bonds and remove if these exist. + polymer_ligand_bonds = inter_chain_bonds.get_polymer_ligand_bonds( + struc, + only_glycan_ligands=False, + ) + if polymer_ligand_bonds: + if covalent_bond_cleaning.has_nonsymmetric_bonds_on_symmetric_polymer_chains( + struc, polymer_ligand_bonds + ): + from_atom_idxs, dest_atom_idxs = struc.bonds.get_atom_indices( + struc.atom_key + ) + poly_chain_types = list(mmcif_names.POLYMER_CHAIN_TYPES) + is_polymer_bond = np.logical_or( + np.isin( + struc.chain_type[from_atom_idxs], poly_chain_types), + np.isin( + struc.chain_type[dest_atom_idxs], poly_chain_types), + ) + struc = struc.copy_and_update( + bonds=struc.bonds[~is_polymer_bond]) + + return struc, metadata + + +def create_empty_output_struc_and_layout( + struc: structure.Structure, + ccd: chemical_components.Ccd, + *, + with_hydrogens: bool = False, + skip_unk: bool = False, + polymer_ligand_bonds: atom_layout.AtomLayout | None = None, + ligand_ligand_bonds: atom_layout.AtomLayout | None = None, + drop_ligand_leaving_atoms: bool = False, +) -> tuple[structure.Structure, atom_layout.AtomLayout]: + """Make zero-coordinate structure from all physical residues. + + Args: + struc: Structure object. + ccd: The chemical components dictionary. + with_hydrogens: Whether to keep hydrogen atoms in structure. + skip_unk: Whether to remove unknown residues from structure. + polymer_ligand_bonds: Bond information for polymer-ligand pairs. + ligand_ligand_bonds: Bond information for ligand-ligand pairs. + drop_ligand_leaving_atoms: Flag for handling leaving atoms for ligands. 
+ + Returns: + Tuple of structure with all bonds, physical residues and coordinates set to + 0 and a flat atom layout of empty structure. + """ + bonded_atom_pairs = [] + if polymer_ligand_bonds: + for chain_ids, res_ids, atom_names in zip( + polymer_ligand_bonds.chain_id, + polymer_ligand_bonds.res_id, + polymer_ligand_bonds.atom_name, + strict=True, + ): + bonded_atom_pairs.append(( + (chain_ids[0], res_ids[0], atom_names[0]), + (chain_ids[1], res_ids[1], atom_names[1]), + )) + if ligand_ligand_bonds: + for chain_ids, res_ids, atom_names in zip( + ligand_ligand_bonds.chain_id, + ligand_ligand_bonds.res_id, + ligand_ligand_bonds.atom_name, + strict=True, + ): + bonded_atom_pairs.append(( + (chain_ids[0], res_ids[0], atom_names[0]), + (chain_ids[1], res_ids[1], atom_names[1]), + )) + residues = atom_layout.residues_from_structure( + struc, include_missing_residues=True + ) + + flat_output_layout = atom_layout.make_flat_atom_layout( + residues, + ccd=ccd, + with_hydrogens=with_hydrogens, + skip_unk_residues=skip_unk, + polymer_ligand_bonds=polymer_ligand_bonds, + ligand_ligand_bonds=ligand_ligand_bonds, + drop_ligand_leaving_atoms=drop_ligand_leaving_atoms, + ) + + empty_output_struc = atom_layout.make_structure( + flat_layout=flat_output_layout, + atom_coords=np.zeros((flat_output_layout.shape[0], 3)), + name=struc.name, + atom_b_factors=None, + all_physical_residues=residues, + ) + if bonded_atom_pairs: + empty_output_struc = empty_output_struc.add_bonds( + bonded_atom_pairs, bond_type=mmcif_names.COVALENT_BOND + ) + + return empty_output_struc, flat_output_layout diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py new file mode 100644 index 000000000..d200279e7 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/post_processing.py @@ -0,0 +1,114 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Post-processing utilities for AlphaFold inference results.""" + +import dataclasses +import datetime +import os + +# from alphafold3 import version +from alphafold3.model import confidence_types +from alphafold3.model import mmcif_metadata +from alphafold3.model.components import base_model +import numpy as np + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class ProcessedInferenceResult: + """Stores attributes of a processed inference result. + + Attributes: + cif: CIF file containing an inference result. + mean_confidence_1d: Mean 1D confidence calculated from confidence_1d. + ranking_score: Ranking score extracted from CIF metadata. + structure_confidence_summary_json: Content of JSON file with structure + confidences summary calculated from CIF file. + structure_full_data_json: Content of JSON file with structure full + confidences calculated from CIF file. + model_id: Identifier of the model that produced the inference result. 
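+
+  All bytes attributes hold UTF-8 encoded text, ready to be written to disk in
+  binary mode (see `write_output`).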
+ """ + + cif: bytes + mean_confidence_1d: float + ranking_score: float + structure_confidence_summary_json: bytes + structure_full_data_json: bytes + model_id: bytes + + +def post_process_inference_result( + inference_result: base_model.InferenceResult, +) -> ProcessedInferenceResult: + """Returns cif, confidence_1d_json, confidence_2d_json, mean_confidence_1d, and ranking confidence.""" + + # Add mmCIF metadata fields. + timestamp = datetime.datetime.now().isoformat(sep=' ', timespec='seconds') + cif_with_metadata = mmcif_metadata.add_metadata_to_mmcif( + old_cif=inference_result.predicted_structure.to_mmcif_dict(), + # version=f'{version.__version__} @ {timestamp}', + # version=None, + model_id=inference_result.model_id, + ) + cif = mmcif_metadata.add_legal_comment(cif_with_metadata.to_string()) + cif = cif.encode('utf-8') + confidence_1d = confidence_types.AtomConfidence.from_inference_result( + inference_result + ) + mean_confidence_1d = np.mean(confidence_1d.confidence) + structure_confidence_summary_json = ( + confidence_types.StructureConfidenceSummary.from_inference_result( + inference_result + ) + .to_json() + .encode('utf-8') + ) + structure_full_data_json = ( + confidence_types.StructureConfidenceFull.from_inference_result( + inference_result + ) + .to_json() + .encode('utf-8') + ) + return ProcessedInferenceResult( + cif=cif, + mean_confidence_1d=mean_confidence_1d, + ranking_score=float(inference_result.metadata['ranking_score']), + structure_confidence_summary_json=structure_confidence_summary_json, + structure_full_data_json=structure_full_data_json, + model_id=inference_result.model_id, + ) + + +def write_output( + inference_result: base_model.InferenceResult, + output_dir: os.PathLike[str] | str, + terms_of_use: str | None = None, + name: str | None = None, +) -> None: + """Writes processed inference result to a directory.""" + processed_result = post_process_inference_result(inference_result) + + prefix = f'{name}_' if name is not None else '' + + with open(os.path.join(output_dir, f'{prefix}model.cif'), 'wb') as f: + f.write(processed_result.cif) + + with open( + os.path.join(output_dir, f'{prefix}summary_confidences.json'), 'wb' + ) as f: + f.write(processed_result.structure_confidence_summary_json) + + with open(os.path.join(output_dir, f'{prefix}confidences.json'), 'wb') as f: + f.write(processed_result.structure_full_data_json) + + if terms_of_use is not None: + with open(os.path.join(output_dir, 'TERMS_OF_USE.md'), 'wt') as f: + f.write(terms_of_use) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/protein_data_processing.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/protein_data_processing.py new file mode 100644 index 000000000..195db4c27 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/protein_data_processing.py @@ -0,0 +1,128 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Process Structure Data.""" + +from alphafold3.constants import atom_types +from alphafold3.constants import residue_names +from alphafold3.constants import side_chains +import numpy as np + + +NUM_DENSE = atom_types.DENSE_ATOM_NUM +NUM_AA = len(residue_names.PROTEIN_TYPES) +NUM_AA_WITH_UNK_AND_GAP = len( + residue_names.PROTEIN_TYPES_ONE_LETTER_WITH_UNKNOWN_AND_GAP +) +NUM_RESTYPES_WITH_UNK_AND_GAP = ( + residue_names.POLYMER_TYPES_NUM_WITH_UNKNOWN_AND_GAP +) + + +def _make_restype_rigidgroup_dense_atom_idx(): + """Create Mapping from rigid_groups to dense_atom indices.""" + # Create an array with the atom names. + # shape (num_restypes, num_rigidgroups, 3_atoms): + # (31, 8, 3) + base_atom_indices = np.zeros( + (NUM_RESTYPES_WITH_UNK_AND_GAP, 8, 3), dtype=np.int32 + ) + + # 4,5,6,7: 'chi1,2,3,4-group' + for restype, restype_letter in enumerate( + residue_names.PROTEIN_TYPES_ONE_LETTER + ): + resname = residue_names.PROTEIN_COMMON_ONE_TO_THREE[restype_letter] + + dense_atom_names = atom_types.ATOM14[resname] + # 0: backbone frame + base_atom_indices[restype, 0, :] = [ + dense_atom_names.index(atom) for atom in ['C', 'CA', 'N'] + ] + + # 3: 'psi-group' + base_atom_indices[restype, 3, :] = [ + dense_atom_names.index(atom) for atom in ['CA', 'C', 'O'] + ] + for chi_idx in range(4): + if side_chains.CHI_ANGLES_MASK[restype][chi_idx]: + atom_names = side_chains.CHI_ANGLES_ATOMS[resname][chi_idx] + base_atom_indices[restype, chi_idx + 4, :] = [ + dense_atom_names.index(atom) for atom in atom_names[1:] + ] + dense_atom_names = atom_types.DENSE_ATOM['A'] + nucleic_rigid_atoms = [ + dense_atom_names.index(atom) for atom in ["C1'", "C3'", "C4'"] + ] + for nanum, _ in enumerate(residue_names.NUCLEIC_TYPES): + # 0: backbone frame only. + # we have aa + unk + gap, so we want to start after those + resnum = nanum + NUM_AA_WITH_UNK_AND_GAP + base_atom_indices[resnum, 0, :] = nucleic_rigid_atoms + + return base_atom_indices + + +RESTYPE_RIGIDGROUP_DENSE_ATOM_IDX = _make_restype_rigidgroup_dense_atom_idx() + + +def _make_restype_pseudobeta_idx(): + """Returns indices of residue's pseudo-beta.""" + restype_pseudobeta_index = np.zeros( + (NUM_RESTYPES_WITH_UNK_AND_GAP,), dtype=np.int32 + ) + for restype, restype_letter in enumerate( + residue_names.PROTEIN_TYPES_ONE_LETTER + ): + restype_name = residue_names.PROTEIN_COMMON_ONE_TO_THREE[restype_letter] + atom_names = list(atom_types.ATOM14[restype_name]) + if restype_name in {'GLY'}: + restype_pseudobeta_index[restype] = atom_names.index('CA') + else: + restype_pseudobeta_index[restype] = atom_names.index('CB') + for nanum, resname in enumerate(residue_names.NUCLEIC_TYPES): + atom_names = list(atom_types.DENSE_ATOM[resname]) + # 0: backbone frame only. 
+    # Nucleic types come after aa + unk + gap, so start after those.
+    restype = nanum + NUM_AA_WITH_UNK_AND_GAP
+    if resname in {'A', 'G', 'DA', 'DG'}:
+      restype_pseudobeta_index[restype] = atom_names.index('C4')
+    else:
+      restype_pseudobeta_index[restype] = atom_names.index('C2')
+  return restype_pseudobeta_index
+
+
+RESTYPE_PSEUDOBETA_INDEX = _make_restype_pseudobeta_idx()
+
+
+def _make_aatype_dense_atom_to_atom37():
+  """Map from dense_atom to atom37 per residue type."""
+  restype_dense_atom_to_atom37 = []  # mapping (restype, dense_atom) --> atom37
+  for rt in residue_names.PROTEIN_TYPES_ONE_LETTER:
+    atom_names = list(
+        atom_types.ATOM14_PADDED[residue_names.PROTEIN_COMMON_ONE_TO_THREE[rt]]
+    )
+    atom_names.extend([''] * (NUM_DENSE - len(atom_names)))
+    restype_dense_atom_to_atom37.append(
+        [(atom_types.ATOM37_ORDER[name] if name else 0) for name in atom_names]
+    )
+  # Add dummy mapping for restype 'UNK', '-' (gap), and nucleics [but not DN].
+  for _ in range(2 + len(residue_names.NUCLEIC_TYPES_WITH_UNKNOWN)):
+    restype_dense_atom_to_atom37.append([0] * NUM_DENSE)
+
+  restype_dense_atom_to_atom37 = np.array(
+      restype_dense_atom_to_atom37, dtype=np.int32
+  )
+  return restype_dense_atom_to_atom37
+
+
+PROTEIN_AATYPE_DENSE_ATOM_TO_ATOM37 = _make_aatype_dense_atom_to_atom37()
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/alignment.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/alignment.py
new file mode 100644
index 000000000..4224a7b23
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/alignment.py
@@ -0,0 +1,146 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Alignment based metrics."""
+
+import numpy as np
+
+
+def transform_ls(
+    x: np.ndarray,
+    b: np.ndarray,
+    *,
+    allow_reflection: bool = False,
+) -> np.ndarray:
+  """Find the least squares best fit rotation between two sets of N points.
+
+  Solve Ax = b for A. Where A is the transform rotating x^T into b^T.
+
+  Args:
+    x: NxD numpy array of coordinates. Usually dimension D is 3.
+    b: NxD numpy array of coordinates. Usually dimension D is 3.
+    allow_reflection: Whether the returned transformation can reflect as well
+      as rotate.
+
+  Returns:
+    Matrix A transforming x into b, i.e. s.t. Ax^T = b^T.
+  """
+  assert x.shape[1] >= b.shape[1]
+  assert b.shape[0] == x.shape[0], '%d, %d' % (b.shape[0], x.shape[0])
+  # First postmultiply by x:
+  # Axx^t = b x^t
+  bxt = np.dot(b.transpose(), x) / b.shape[0]
+
+  u, _, v = np.linalg.svd(bxt)
+
+  r = np.dot(u, v)
+  if not allow_reflection:
+    flip = np.ones((v.shape[1], 1))
+    flip[v.shape[1] - 1, 0] = np.sign(np.linalg.det(r))
+    r = np.dot(u, v * flip)
+
+  return r
+
+
+def align(
+    *,
+    x: np.ndarray,
+    y: np.ndarray,
+    x_indices: np.ndarray,
+    y_indices: np.ndarray,
+) -> np.ndarray:
+  """Align x to y considering only the included indices.
+
+  Args:
+    x: NxD np array of coordinates.
+    y: NxD np array of coordinates.
+    x_indices: An np array of indices for `x` that will be used in the
+      alignment.
+      Must be of the same length as `y_indices`.
+    y_indices: An np array of indices for `y` that will be used in the
+      alignment. Must be of the same length as `x_indices`.
+
+  Returns:
+    NxD np array of points obtained by applying a rigid transformation to x.
+    These points are aligned to y and the alignment is the optimal alignment
+    over the points in the included indices.
+
+  Raises:
+    ValueError: If the number of included indices is not the same for both
+      input arrays.
+  """
+  if len(x_indices) != len(y_indices):
+    raise ValueError(
+        'Number of included indices must be the same for both input arrays,'
+        f' but got for x: {len(x_indices)}, and for y: {len(y_indices)}.'
+    )
+
+  x_mean = np.mean(x[x_indices, :], axis=0)
+  y_mean = np.mean(y[y_indices, :], axis=0)
+
+  centered_x = x - x_mean
+  centered_y = y - y_mean
+  t = transform_ls(centered_x[x_indices, :], centered_y[y_indices, :])
+  transformed_x = np.dot(centered_x, t.transpose()) + y_mean
+
+  return transformed_x
+
+
+def deviations_from_coords(
+    decoy_coords: np.ndarray,
+    gt_coords: np.ndarray,
+    align_idxs: np.ndarray | None = None,
+    include_idxs: np.ndarray | None = None,
+) -> np.ndarray:
+  """Returns the raw per-atom deviations used in RMSD computation."""
+  if decoy_coords.shape != gt_coords.shape:
+    raise ValueError(
+        'decoy_coords.shape and gt_coords.shape must match. Found: %s and %s.'
+        % (decoy_coords.shape, gt_coords.shape)
+    )
+  # Include and align all residues unless specified otherwise.
+  if include_idxs is None:
+    include_idxs = np.arange(decoy_coords.shape[0])
+  if align_idxs is None:
+    align_idxs = include_idxs
+  aligned_decoy_coords = align(
+      x=decoy_coords,
+      y=gt_coords,
+      x_indices=align_idxs,
+      y_indices=align_idxs,
+  )
+  deviations = np.linalg.norm(
+      aligned_decoy_coords[include_idxs] - gt_coords[include_idxs], axis=1
+  )
+  return deviations
+
+
+def rmsd_from_coords(
+    decoy_coords: np.ndarray,
+    gt_coords: np.ndarray,
+    align_idxs: np.ndarray | None = None,
+    include_idxs: np.ndarray | None = None,
+) -> float:
+  """Computes the *aligned* RMSD of two Mx3 np arrays of coordinates.
+
+  Args:
+    decoy_coords: [M, 3] np array of decoy atom coordinates.
+    gt_coords: [M, 3] np array of gt atom coordinates.
+    align_idxs: [M] np array of indices specifying coordinates to align on.
+      Defaults to None, in which case all the include_idxs (see after) are
+      used.
+    include_idxs: [M] np array of indices specifying coordinates to score.
+      Defaults to None, in which case all indices are used for scoring.
+
+  Returns:
+    rmsd value of the aligned decoy and gt coordinates.
+  """
+  deviations = deviations_from_coords(
+      decoy_coords, gt_coords, align_idxs, include_idxs
+  )
+  return np.sqrt(np.mean(np.square(deviations)))
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/covalent_bond_cleaning.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/covalent_bond_cleaning.py
new file mode 100644
index 000000000..376340110
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/covalent_bond_cleaning.py
@@ -0,0 +1,265 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Some methods to compute metrics for PTMs."""
+
+import collections
+from collections.abc import Mapping
+import dataclasses
+
+from alphafold3 import structure
+from alphafold3.constants import mmcif_names
+from alphafold3.model.atom_layout import atom_layout
+import numpy as np
+
+
+@dataclasses.dataclass(frozen=True)
+class ResIdMapping:
+  old_res_ids: np.ndarray
+  new_res_ids: np.ndarray
+
+
+def _count_symmetric_chains(struc: structure.Structure) -> Mapping[str, int]:
+  """Maps each chain ID to the count of chains sharing its sequence."""
+  chain_res_name_sequence_from_chain_id = struc.chain_res_name_sequence(
+      include_missing_residues=True, fix_non_standard_polymer_res=False
+  )
+  counts_for_chain_res_name_sequence = collections.Counter(
+      chain_res_name_sequence_from_chain_id.values()
+  )
+  chain_symmetric_count = {}
+  for chain_id, chain_res_name in chain_res_name_sequence_from_chain_id.items():
+    chain_symmetric_count[chain_id] = counts_for_chain_res_name_sequence[
+        chain_res_name
+    ]
+  return chain_symmetric_count
+
+
+def has_nonsymmetric_bonds_on_symmetric_polymer_chains(
+    struc: structure.Structure, polymer_ligand_bonds: atom_layout.AtomLayout
+) -> bool:
+  """Returns true if nonsymmetric bonds are found on polymer chains."""
+  try:
+    _get_polymer_dim(polymer_ligand_bonds)
+  except ValueError:
+    return True
+  if _has_non_polymer_ligand_ptm_bonds(polymer_ligand_bonds):
+    return True
+  if _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds):
+    return True
+  combined_struc, _ = _combine_polymer_ligand_ptm_chains(
+      struc, polymer_ligand_bonds
+  )
+  struc = struc.filter(chain_type=mmcif_names.POLYMER_CHAIN_TYPES)
+  combined_struc = combined_struc.filter(
+      chain_type=mmcif_names.POLYMER_CHAIN_TYPES
+  )
+  return _count_symmetric_chains(struc) != _count_symmetric_chains(
+      combined_struc
+  )
+
+
+def _has_non_polymer_ligand_ptm_bonds(
+    polymer_ligand_bonds: atom_layout.AtomLayout,
+) -> bool:
+  """Returns True if any bond is not between a polymer and a ligand chain."""
+  for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type:
+    if (
+        start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
+        and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
+    ):
+      continue
+    elif (
+        start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES
+        and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES
+    ):
+      continue
+    else:
+      return True
+  return False
+
+
+def _combine_polymer_ligand_ptm_chains(
+    struc: structure.Structure,
+    polymer_ligand_bonds: atom_layout.AtomLayout,
+) -> tuple[structure.Structure, dict[tuple[str, str], ResIdMapping]]:
+  """Combines the ptm polymer-ligand chains together.
+
+  This will prevent them from being permuted away from each other when chains
+  are matched to the ground truth. This function also returns the res_id
+  mapping from the separate ligand res_ids to their res_ids in the combined
+  polymer-ligand chain; this information is needed to later separate the
+  combined polymer-ligand chain.
+
+  Args:
+    struc: Structure to be modified.
+    polymer_ligand_bonds: AtomLayout with polymer-ligand bond info.
+
+  Returns:
+    A tuple of a Structure with each ptm polymer-ligand chain relabelled as one
+    chain and a dict from bond chain pair to the res_id mapping.
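+
+  For example, if polymer chain A is covalently bonded to ligand chain B, B's
+  residues are renumbered to continue A's numbering and both are merged into a
+  single chain with A's ID (illustrative chain IDs).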
+ """ + if not _has_only_single_bond_from_each_chain(polymer_ligand_bonds): + if _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds): + # For structures where a polymer chain is connected to multiple ligands, + # we need to sort the multiple bonds from the same chain by res_id to + # ensure that the combined polymer-ligand chain will always be the same + # when you have repeated symmetric polymer-ligand chains. + polymer_ligand_bonds = ( + _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id( + polymer_ligand_bonds + ) + ) + else: + raise ValueError( + 'Code cannot handle multiple bonds from one chain unless' + ' its several ligands bonded to a polymer.' + ) + res_id_mappings_for_bond_chain_pair = dict() + for (start_chain_id, end_chain_id), (start_chain_type, end_chain_type) in zip( + polymer_ligand_bonds.chain_id, polymer_ligand_bonds.chain_type + ): + poly_info, ligand_info = _get_polymer_and_ligand_chain_ids_and_types( + start_chain_id, end_chain_id, start_chain_type, end_chain_type + ) + polymer_chain_id, polymer_chain_type = poly_info + ligand_chain_id, _ = ligand_info + + # Join the ligand chain to the polymer chain. + ligand_res_ids = struc.filter(chain_id=ligand_chain_id).res_id + new_res_ids = ligand_res_ids + len(struc.all_residues[polymer_chain_id]) + res_id_mappings_for_bond_chain_pair[(polymer_chain_id, ligand_chain_id)] = ( + ResIdMapping(old_res_ids=ligand_res_ids, new_res_ids=new_res_ids) + ) + chain_groups = [] + chain_group_ids = [] + chain_group_types = [] + for chain_id, chain_type in zip( + struc.chains_table.id, struc.chains_table.type + ): + if chain_id == ligand_chain_id: + continue + elif chain_id == polymer_chain_id: + chain_groups.append([polymer_chain_id, ligand_chain_id]) + chain_group_ids.append(polymer_chain_id) + chain_group_types.append(polymer_chain_type) + else: + chain_groups.append([chain_id]) + chain_group_ids.append(chain_id) + chain_group_types.append(chain_type) + + struc = struc.merge_chains( + chain_groups=chain_groups, + chain_group_ids=chain_group_ids, + chain_group_types=chain_group_types, + ) + + return struc, res_id_mappings_for_bond_chain_pair + + +def _has_only_single_bond_from_each_chain( + polymer_ligand_bonds: atom_layout.AtomLayout, +) -> bool: + """Checks that there is at most one bond from each chain.""" + chain_ids = [] + for chains in polymer_ligand_bonds.chain_id: + chain_ids.extend(chains) + if len(chain_ids) != len(set(chain_ids)): + return False + return True + + +def _get_polymer_and_ligand_chain_ids_and_types( + start_chain_id: str, + end_chain_id: str, + start_chain_type: str, + end_chain_type: str, +) -> tuple[tuple[str, str], tuple[str, str]]: + """Finds polymer and ligand chain ids from chain types.""" + if ( + start_chain_type in mmcif_names.POLYMER_CHAIN_TYPES + and end_chain_type in mmcif_names.LIGAND_CHAIN_TYPES + ): + return (start_chain_id, start_chain_type), (end_chain_id, end_chain_type) + elif ( + start_chain_type in mmcif_names.LIGAND_CHAIN_TYPES + and end_chain_type in mmcif_names.POLYMER_CHAIN_TYPES + ): + return (end_chain_id, end_chain_type), (start_chain_id, start_chain_type) + else: + raise ValueError( + 'This code only handles PTM-bonds from polymer chain to ligands.' 
+    )
+
+
+def _get_polymer_dim(polymer_ligand_bonds: atom_layout.AtomLayout) -> int:
+  """Gets the polymer dimension (0 or 1) from the polymer-ligand bond layout."""
+  start_chain_types = []
+  end_chain_types = []
+  for start_chain_type, end_chain_type in polymer_ligand_bonds.chain_type:
+    start_chain_types.append(start_chain_type)
+    end_chain_types.append(end_chain_type)
+  if set(start_chain_types).issubset(
+      set(mmcif_names.POLYMER_CHAIN_TYPES)
+  ) and set(end_chain_types).issubset(set(mmcif_names.LIGAND_CHAIN_TYPES)):
+    return 0
+  elif set(start_chain_types).issubset(
+      set(mmcif_names.LIGAND_CHAIN_TYPES)
+  ) and set(end_chain_types).issubset(set(mmcif_names.POLYMER_CHAIN_TYPES)):
+    return 1
+  else:
+    raise ValueError(
+        'Polymer and ligand dimensions are not consistent within the structure.'
+    )
+
+
+def _has_multiple_ligands_bonded_to_one_polymer(polymer_ligand_bonds) -> bool:
+  """Checks if there are multiple ligands bonded to one polymer chain."""
+  polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
+  polymer_chain_ids = [
+      chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id
+  ]
+  return len(polymer_chain_ids) != len(set(polymer_chain_ids))
+
+
+def _has_multiple_polymers_bonded_to_one_ligand(polymer_ligand_bonds) -> bool:
+  """Checks if there are multiple polymer chains bonded to one ligand."""
+  polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
+  ligand_dim = 1 - polymer_dim
+  ligand_chain_ids = [
+      chains[ligand_dim] for chains in polymer_ligand_bonds.chain_id
+  ]
+  return len(ligand_chain_ids) != len(set(ligand_chain_ids))
+
+
+def _sort_polymer_ligand_bonds_by_polymer_chain_and_res_id(
+    polymer_ligand_bonds,
+):
+  """Sorts bonds by (polymer chain ID, res_id); used when a polymer chain has multiple bonded ligands."""
+
+  polymer_dim = _get_polymer_dim(polymer_ligand_bonds)
+
+  polymer_chain_ids = [
+      chains[polymer_dim] for chains in polymer_ligand_bonds.chain_id
+  ]
+  polymer_res_ids = [res[polymer_dim] for res in polymer_ligand_bonds.res_id]
+
+  polymer_chain_and_res_id = zip(polymer_chain_ids, polymer_res_ids)
+  sorted_indices = [
+      idx
+      for idx, _ in sorted(
+          enumerate(polymer_chain_and_res_id), key=lambda x: x[1]
+      )
+  ]
+  return polymer_ligand_bonds[sorted_indices]
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py
new file mode 100644
index 000000000..5b3caeb54
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/model/scoring/scoring.py
@@ -0,0 +1,67 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Library of scoring methods for the model outputs."""
+
+from alphafold3.model import protein_data_processing
+import numpy as np
+
+
+Array = np.ndarray
+
+
+def pseudo_beta_fn(
+    aatype: Array,
+    dense_atom_positions: Array,
+    dense_atom_masks: Array,
+    is_ligand: Array | None = None,
+    use_jax: bool | None = True,
+) -> tuple[Array, Array] | Array:
+  """Create pseudo beta atom positions and optionally mask.
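+
+  Note: in this port the computation is done entirely in NumPy, so `use_jax`
+  is currently ignored and only accepted for interface compatibility.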
+ + Args: + aatype: [num_res] amino acid types. + dense_atom_positions: [num_res, NUM_DENSE, 3] vector of all atom positions. + dense_atom_masks: [num_res, NUM_DENSE] mask. + is_ligand: [num_res] flag if something is a ligand. + use_jax: whether to use jax for the computations. + + Returns: + Pseudo beta dense atom positions and the corresponding mask. + """ + + if is_ligand is None: + is_ligand = np.zeros_like(aatype) + + pseudobeta_index_polymer = np.take( + protein_data_processing.RESTYPE_PSEUDOBETA_INDEX, aatype, axis=0 + ).astype(np.int32) + + pseudobeta_index = np.where( + is_ligand, + np.zeros_like(pseudobeta_index_polymer), + pseudobeta_index_polymer, + ) + + if not isinstance(dense_atom_positions, Array): + dense_atom_positions = dense_atom_positions.asnumpy() + if not isinstance(dense_atom_masks, Array): + dense_atom_masks = dense_atom_masks.asnumpy() + pseudo_beta = np.take_along_axis( + dense_atom_positions, pseudobeta_index[..., None, None], axis=-2 + ) + pseudo_beta = np.squeeze(pseudo_beta, axis=-2) + + pseudo_beta_mask = np.take_along_axis( + dense_atom_masks, pseudobeta_index[..., None], axis=-1 + ).astype(np.float32) + pseudo_beta_mask = np.squeeze(pseudo_beta_mask, axis=-1) + + return pseudo_beta, pseudo_beta_mask diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict.pyi new file mode 100644 index 000000000..09d915c84 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict.pyi @@ -0,0 +1,125 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from typing import Any, ClassVar, Iterable, Iterator, TypeVar, overload + +import numpy as np + +_T = TypeVar('_T') + +class CifDict: + class ItemView: + def __iter__(self) -> Iterator[tuple[str, list[str]]]: ... + def __len__(self) -> int: ... + + class KeyView: + @overload + def __contains__(self, key: str) -> bool: ... + @overload + def __contains__(self, key: object) -> bool: ... + def __iter__(self) -> Iterator[str]: ... + def __len__(self) -> int: ... + + class ValueView: + def __iter__(self) -> Iterator[list[str]]: ... + def __len__(self) -> int: ... + + def __init__(self, d: dict[str, Iterable[str]]) -> None: ... + def copy_and_update(self, d: dict[str, Iterable[str]]) -> CifDict: ... + def extract_loop_as_dict(self, prefix: str, index: str) -> dict: + """Extracts loop associated with a prefix from mmCIF data as a dict. + + For instance for an mmCIF with these fields: + '_a.ix': ['1', '2', '3'] + '_a.1': ['a.1.1', 'a.1.2', 'a.1.3'] + '_a.2': ['a.2.1', 'a.2.2', 'a.2.3'] + + this function called with prefix='_a.', index='_a.ix' extracts: + {'1': {'a.ix': '1', 'a.1': 'a.1.1', 'a.2': 'a.2.1'} + '2': {'a.ix': '2', 'a.1': 'a.1.2', 'a.2': 'a.2.2'} + '3': {'a.ix': '3', 'a.1': 'a.1.3', 'a.2': 'a.2.3'}} + + Args: + prefix: Prefix shared by each of the data items in the loop. The prefix + should include the trailing period. + index: Which item of loop data should serve as the key. 
+ + Returns: + Dict of dicts; each dict represents 1 entry from an mmCIF loop, + indexed by the index column. + """ + + def extract_loop_as_list(self, prefix: str) -> list: + """Extracts loop associated with a prefix from mmCIF data as a list. + + Reference for loop_ in mmCIF: + http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + + For instance for an mmCIF with these fields: + '_a.1': ['a.1.1', 'a.1.2', 'a.1.3'] + '_a.2': ['a.2.1', 'a.2.2', 'a.2.3'] + + this function called with prefix='_a.' extracts: + [{'_a.1': 'a.1.1', '_a.2': 'a.2.1'} + {'_a.1': 'a.1.2', '_a.2': 'a.2.2'} + {'_a.1': 'a.1.3', '_a.2': 'a.2.3'}] + + Args: + prefix: Prefix shared by each of the data items in the loop. The prefix + should include the trailing period. + + Returns: + A list of dicts; each dict represents 1 entry from an mmCIF loop. + """ + + def get(self, key: str, default_value: _T = ...) -> list[str] | _T: ... + def get_array( + self, key: str, dtype: object = ..., gather: object = ... + ) -> np.ndarray: + """Returns values looked up in dict converted to a NumPy array. + + Args: + key: Key in dictionary. + dtype: Optional (default `object`) Specifies output dtype of array. One of + [object, np.{int,uint}{8,16,32,64} np.float{32,64}]. As with NumPy use + `object` to return a NumPy array of strings. + gather: Optional one of [slice, np.{int,uint}{32,64}] non-intermediate + version of get_array(key, dtype)[gather]. + + Returns: + A NumPy array of given dtype. An optimised equivalent to + np.array(cif[key]).astype(dtype). With support of '.' being treated + as np.nan if dtype is one of np.float{32,64}. + Identical strings will all reference the same object to save space. + + Raises: + KeyError - if key is not found. + TypeError - if dtype is not valid or supported. + ValueError - if string cannot convert to dtype. + """ + + def get_data_name(self) -> str: ... + def items(self) -> CifDict.ItemView: ... + def keys(self) -> CifDict.KeyView: ... + def to_string(self) -> str: ... + def value_length(self, key: str) -> int: ... + def values(self) -> CifDict.ValueView: ... + def __bool__(self) -> bool: ... + def __contains__(self, key: str) -> bool: ... + def __getitem__(self, key: str) -> list[str]: ... + def __getstate__(self) -> tuple: ... + def __iter__(self) -> Iterator[str]: ... + def __len__(self) -> int: ... + def __setstate__(self, state: tuple) -> None: ... + +def tokenize(cif_string: str) -> list[str]: ... +def split_line(line: str) -> list[str]: ... +def from_string(mmcif_string: str | bytes) -> CifDict: ... +def parse_multi_data_cif(cif_string: str | bytes) -> dict[str, CifDict]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc new file mode 100644 index 000000000..b0b0fdbc3 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.cc @@ -0,0 +1,648 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. 
Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include "alphafold3/parsers/cpp/cif_dict_lib.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/btree_map.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/container/node_hash_map.h" +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/ascii.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "absl/strings/strip.h" + +namespace alphafold3 { +namespace { + +bool IsQuote(const char symbol) { return symbol == '\'' || symbol == '"'; } +bool IsWhitespace(const char symbol) { return symbol == ' ' || symbol == '\t'; } + +// Splits line into tokens, returns whether successful. +bool SplitLineInline(absl::string_view line, + std::vector* tokens) { + // See https://www.iucr.org/resources/cif/spec/version1.1/cifsyntax + for (int i = 0, line_length = line.length(); i < line_length;) { + // Skip whitespace (spaces or tabs). + while (IsWhitespace(line[i])) { + if (++i == line_length) { + break; + } + } + if (i == line_length) { + break; + } + + // Skip comments (from # until the end of the line). If # is a non-comment + // character, it must be inside a quoted token. + if (line[i] == '#') { + break; + } + + int start_index; + int end_index; + if (IsQuote(line[i])) { + // Token in single or double quotes. CIF v1.1 specification considers a + // quote to be an opening quote only if it is at the beginning of a token. + // So e.g. A' B has tokens A' and B. Also, ""A" is a token "A. + const char quote_char = line[i++]; + start_index = i; + + // Find matching quote. The double loop is not strictly necessary, but + // optimises a bit better. + while (true) { + while (i < line_length && line[i] != quote_char) { + ++i; + } + if (i == line_length) { + // Reached the end of the line while still being inside a token. + return false; + } + if (i + 1 == line_length || IsWhitespace(line[i + 1])) { + break; + } + ++i; + } + end_index = i++; + } else { + // Non-quoted token. Read until reaching whitespace. + start_index = i++; + while (i < line_length && !IsWhitespace(line[i])) { + ++i; + } + end_index = i; + } + + tokens->push_back(line.substr(start_index, end_index - start_index)); + } + + return true; +} + +using HeapStrings = std::vector>; + +// The majority of strings can be viewed on original cif_string. +// heap_strings store multi-line tokens that have internal white-space stripped. +absl::StatusOr> TokenizeInternal( + absl::string_view cif_string, HeapStrings* heap_strings) { + const std::vector lines = absl::StrSplit(cif_string, '\n'); + std::vector tokens; + // Heuristic: Most lines in an mmCIF are _atom_site lines with 21 tokens. + tokens.reserve(lines.size() * 21); + int line_num = 0; + while (line_num < lines.size()) { + auto line = lines[line_num]; + line_num++; + + if (line.empty() || line[0] == '#') { + // Skip empty lines or lines that contain only comments. + continue; + } else if (line[0] == ';') { + // Leading whitespace on each line must be preserved while trailing + // whitespace may be stripped. + std::vector multiline_tokens; + // Strip the leading ";". 
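+      // A ';' in the first column opens a multi-line CIF text field that runs
+      // until the next line starting with ';'. The collected lines are joined
+      // with '\n' into a heap-owned string below, so the returned string_view
+      // token stays valid after this function exits.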
+ multiline_tokens.push_back( + absl::StripTrailingAsciiWhitespace(line.substr(1))); + while (line_num < lines.size()) { + auto multiline = absl::StripTrailingAsciiWhitespace(lines[line_num]); + line_num++; + if (!multiline.empty() && multiline[0] == ';') { + break; + } + multiline_tokens.push_back(multiline); + } + heap_strings->push_back( + std::make_unique(absl::StrJoin(multiline_tokens, "\n"))); + tokens.emplace_back(*heap_strings->back()); + } else { + if (!SplitLineInline(line, &tokens)) { + return absl::InvalidArgumentError( + absl::StrCat("Line ended with quote open: ", line)); + } + } + } + return tokens; +} + +absl::string_view GetEscapeQuote(const absl::string_view value) { + // Empty values should not happen, but if so, they should be quoted. + if (value.empty()) { + return "\""; + } + + // Shortcut for the most common cases where no quoting needed. + if (std::all_of(value.begin(), value.end(), [](char c) { + return absl::ascii_isalnum(c) || c == '.' || c == '?' || c == '-'; + })) { + return ""; + } + + // The value must not start with one of these CIF keywords. + if (absl::StartsWithIgnoreCase(value, "data_") || + absl::StartsWithIgnoreCase(value, "loop_") || + absl::StartsWithIgnoreCase(value, "save_") || + absl::StartsWithIgnoreCase(value, "stop_") || + absl::StartsWithIgnoreCase(value, "global_")) { + return "\""; + } + + // The first character must not be a special character. + const char first = value.front(); + if (first == '_' || first == '#' || first == '$' || first == '[' || + first == ']' || first == ';') { + return "\""; + } + + // No quotes or whitespace allowed inside. + for (const char c : value) { + if (c == '"') { + return "'"; + } else if (c == '\'' || c == ' ' || c == '\t') { + return "\""; + } + } + return ""; +} + +int RecordIndex(absl::string_view record) { + if (record == "_entry") { + return 0; // _entry is always first. + } + if (record == "_atom_site") { + return 2; // _atom_site is always last. + } + return 1; // other records are between _entry and _atom_site. +} + +struct RecordOrder { + using is_transparent = void; // Enable heterogeneous lookup. + bool operator()(absl::string_view lhs, absl::string_view rhs) const { + std::size_t lhs_index = RecordIndex(lhs); + std::size_t rhs_index = RecordIndex(rhs); + return std::tie(lhs_index, lhs) < std::tie(rhs_index, rhs); + } +}; + +// Make sure the _atom_site loop columns are sorted in the PDB-standard way. 
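+// Any _atom_site column not listed here sorts after all listed ones, because
+// AtomSiteIndex() returns the array length for unknown names; ties are then
+// broken alphabetically by AtomSiteOrder.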
+constexpr absl::string_view kAtomSiteSortOrder[] = { + "_atom_site.group_PDB", + "_atom_site.id", + "_atom_site.type_symbol", + "_atom_site.label_atom_id", + "_atom_site.label_alt_id", + "_atom_site.label_comp_id", + "_atom_site.label_asym_id", + "_atom_site.label_entity_id", + "_atom_site.label_seq_id", + "_atom_site.pdbx_PDB_ins_code", + "_atom_site.Cartn_x", + "_atom_site.Cartn_y", + "_atom_site.Cartn_z", + "_atom_site.occupancy", + "_atom_site.B_iso_or_equiv", + "_atom_site.pdbx_formal_charge", + "_atom_site.auth_seq_id", + "_atom_site.auth_comp_id", + "_atom_site.auth_asym_id", + "_atom_site.auth_atom_id", + "_atom_site.pdbx_PDB_model_num", +}; + +size_t AtomSiteIndex(absl::string_view atom_site) { + return std::distance(std::begin(kAtomSiteSortOrder), + absl::c_find(kAtomSiteSortOrder, atom_site)); +} + +struct AtomSiteOrder { + bool operator()(absl::string_view lhs, absl::string_view rhs) const { + auto lhs_index = AtomSiteIndex(lhs); + auto rhs_index = AtomSiteIndex(rhs); + return std::tie(lhs_index, lhs) < std::tie(rhs_index, rhs); + } +}; + +class Column { + public: + Column(absl::string_view key, const std::vector* values) + : key_(key), values_(values) { + int max_value_length = 0; + for (size_t i = 0; i < values->size(); ++i) { + absl::string_view value = (*values)[i]; + if (absl::StrContains(value, '\n')) { + values_with_newlines_.insert(i); + } else { + absl::string_view quote = GetEscapeQuote(value); + if (!quote.empty()) { + values_with_quotes_[i] = quote; + } + max_value_length = + std::max(max_value_length, value.size() + quote.size() * 2); + } + } + max_value_length_ = max_value_length; + } + + absl::string_view key() const { return key_; } + + const std::vector* values() const { return values_; } + + int max_value_length() const { return max_value_length_; } + + bool has_newlines(size_t index) const { + return values_with_newlines_.contains(index); + } + + absl::string_view quote(size_t index) const { + if (auto it = values_with_quotes_.find(index); + it != values_with_quotes_.end()) { + return it->second; + } + return ""; + } + + private: + absl::string_view key_; + const std::vector* values_; + int max_value_length_; + // Values with newlines or quotes are very rare in a typical CIF file. + absl::flat_hash_set values_with_newlines_; + absl::flat_hash_map values_with_quotes_; +}; + +struct GroupedKeys { + std::vector grouped_columns; + int max_key_length; + int value_size; +}; + +} // namespace + +absl::StatusOr CifDict::FromString(absl::string_view cif_string) { + CifDict::Dict cif; + + bool loop_flag = false; + absl::string_view key; + + HeapStrings heap_strings; + auto tokens = TokenizeInternal(cif_string, &heap_strings); + if (!tokens.ok()) { + return tokens.status(); + } + + if (tokens->empty()) { + return absl::InvalidArgumentError("The CIF file must not be empty."); + } + + // The first token should be data_XXX. Split into key = data, value = XXX. + absl::string_view first_token = tokens->front(); + if (!absl::ConsumePrefix(&first_token, "data_")) { + return absl::InvalidArgumentError( + "The CIF file does not start with the data_ field."); + } + cif["data_"].emplace_back(first_token); + + // Counters for CIF loop_ regions. + int loop_token_index = 0; + int num_loop_keys = 0; + // Loops have usually O(10) columns but could have up to O(10^6) rows. It is + // therefore wasteful to look up the cif vector where to add a loop value + // since that means doing `columns * rows` map lookups. 
If we save pointers to + // these loop column fields instead, we need only 1 cif lookup per column. + std::vector*> loop_column_values; + + // Skip the first element since we already processed it above. + for (auto token_itr = tokens->begin() + 1; token_itr != tokens->end(); + ++token_itr) { + auto token = *token_itr; + if (absl::EqualsIgnoreCase(token, "loop_")) { + // A new loop started, get rid of old loop's data. + loop_flag = true; + loop_column_values.clear(); + loop_token_index = 0; + num_loop_keys = 0; + continue; + } else if (loop_flag) { + // The second condition checks we are in the first column. Some mmCIF + // files (e.g. 4q9r) have values in later columns starting with an + // underscore and we don't want to read these as keys. + int token_column_index = + num_loop_keys == 0 ? 0 : loop_token_index % num_loop_keys; + if (token_column_index == 0 && !token.empty() && token[0] == '_') { + if (loop_token_index > 0) { + // We are out of the loop. + loop_flag = false; + } else { + // We are in the keys (column names) section of the loop. + auto& columns = cif[token]; + columns.clear(); + + // Heuristic: _atom_site is typically the largest table in an mmCIF + // with ~16 columns. Make sure we reserve enough space for its values. + if (absl::StartsWith(token, "_atom_site.")) { + columns.reserve(tokens->size() / 20); + } + + // Save the pointer to the loop column values. + loop_column_values.push_back(&columns); + num_loop_keys += 1; + continue; + } + } else { + // We are in the values section of the loop. We have a pointer to the + // loops' values, add the new token in there. + if (token_column_index >= loop_column_values.size()) { + return absl::InvalidArgumentError( + absl::StrCat("Too many columns at: '", token, + "' at column index: ", token_column_index, + " expected at most: ", loop_column_values.size())); + } + loop_column_values[token_column_index]->emplace_back(token); + loop_token_index++; + continue; + } + } + if (key.empty()) { + key = token; + } else { + cif[key].emplace_back(token); + key = ""; + } + } + return CifDict(std::move(cif)); +} + +absl::StatusOr CifDict::ToString() const { + std::string output; + + absl::string_view data_name; + // Check that the data_ field exists. + if (auto name_it = (*dict_).find("data_"); + name_it == (*dict_).end() || name_it->second.empty()) { + return absl::InvalidArgumentError( + "The CIF must contain a valid name for this data block in the special " + "data_ field."); + } else { + data_name = name_it->second.front(); + } + + if (absl::c_any_of(data_name, + [](char i) { return absl::ascii_isspace(i); })) { + return absl::InvalidArgumentError(absl::StrFormat( + "The CIF data block name must not contain any whitespace characters, " + "got '%s'.", + data_name)); + } + absl::StrAppend(&output, "data_", data_name, "\n#\n"); + + // Group keys by their prefix. Use btree_map to iterate in alphabetical order, + // but with some keys being placed at the end (e.g. _atom_site). + absl::btree_map grouped_keys; + for (const auto& [key, values] : *dict_) { + if (key == "data_") { + continue; // Skip the special data_ key, we are already done with it. 
+ } + const std::pair key_parts = + absl::StrSplit(key, absl::MaxSplits('.', 1)); + const absl::string_view key_prefix = key_parts.first; + auto [it, inserted] = grouped_keys.emplace(key_prefix, GroupedKeys{}); + GroupedKeys& grouped_key = it->second; + grouped_key.grouped_columns.push_back(Column(key, &values)); + if (inserted) { + grouped_key.max_key_length = key.length(); + grouped_key.value_size = values.size(); + } else { + grouped_key.max_key_length = + std::max(key.length(), grouped_key.max_key_length); + if (grouped_key.value_size != values.size()) { + return absl::InvalidArgumentError( + absl::StrFormat("Values for key %s have different length (%d) than " + "the other values with the same key prefix (%d).", + key, values.size(), grouped_key.value_size)); + } + } + } + + for (auto& [key_prefix, group_info] : grouped_keys) { + if (key_prefix == "_atom_site") { + // Make sure we sort the _atom_site loop in the standard way. + absl::c_sort(group_info.grouped_columns, + [](const Column& lhs, const Column& rhs) { + return AtomSiteOrder{}(lhs.key(), rhs.key()); + }); + } else { + // Make the key ordering within a key group deterministic. + absl::c_sort(group_info.grouped_columns, + [](const Column& lhs, const Column& rhs) { + return lhs.key() < rhs.key(); + }); + } + + // Force `_atom_site` field to always be a loop. This resolves issues with + // third party mmCIF parsers such as OpenBabel which always expect a loop + // even when there is only a single atom present. + if (group_info.value_size == 1 && key_prefix != "_atom_site") { + // Plain key-value pairs, output them as they are. + for (const Column& grouped_column : group_info.grouped_columns) { + int width = group_info.max_key_length + 1; + size_t start_pos = output.size(); + output.append(width, ' '); + auto out_it = output.begin() + start_pos; + absl::c_copy(grouped_column.key(), out_it); + // Append the value, handle multi-line/quoting. + absl::string_view value = grouped_column.values()->front(); + if (grouped_column.has_newlines(0)) { + absl::StrAppend(&output, "\n;", value, "\n;\n"); // Multi-line value. + } else { + const absl::string_view quote_char = grouped_column.quote(0); + absl::StrAppend(&output, quote_char, value, quote_char, "\n"); + } + } + } else { + // CIF loop. Output the column names, then the rows with data. + absl::StrAppend(&output, "loop_\n"); + for (Column& grouped_column : group_info.grouped_columns) { + absl::StrAppend(&output, grouped_column.key(), "\n"); + } + // Write the loop values, line by line. This is the most expensive part + // since this path is taken to write the entire atom site table which has + // about 20 columns, but thousands of rows. + for (int i = 0; i < group_info.value_size; i++) { + for (int column_index = 0; + column_index < group_info.grouped_columns.size(); ++column_index) { + const Column& grouped_column = + group_info.grouped_columns[column_index]; + const absl::string_view value = (*grouped_column.values())[i]; + if (grouped_column.has_newlines(i)) { + // Multi-line. This is very rarely taken path. + if (column_index == 0) { + // No extra newline before leading ;, already inserted. + absl::StrAppend(&output, ";", value, "\n;\n"); + } else if (column_index == group_info.grouped_columns.size() - 1) { + // No extra newline after trailing ;, will be inserted. 
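+              // The row's closing '\n' is appended once after the column
+              // loop, which also terminates this trailing text field.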
+ absl::StrAppend(&output, "\n;", value, "\n;"); + } else { + absl::StrAppend(&output, "\n;", value, "\n;\n"); + } + } else { + size_t start_pos = output.size(); + output.append(grouped_column.max_value_length() + 1, ' '); + auto out_it = output.begin() + start_pos; + absl::string_view quote = grouped_column.quote(i); + if (!quote.empty()) { + out_it = absl::c_copy(quote, out_it); + out_it = absl::c_copy(value, out_it); + absl::c_copy(quote, out_it); + } else { + absl::c_copy(value, out_it); + } + } + } + absl::StrAppend(&output, "\n"); + } + } + absl::StrAppend(&output, "#\n"); // Comment token after every key group. + } + return output; +} + +absl::StatusOr< + std::vector>> +CifDict::ExtractLoopAsList(absl::string_view prefix) const { + std::vector column_names; + std::vector> column_data; + + for (const auto& element : *dict_) { + if (absl::StartsWith(element.first, prefix)) { + column_names.emplace_back(element.first); + auto& cells = column_data.emplace_back(); + cells.insert(cells.begin(), element.second.begin(), element.second.end()); + } + } + // Make sure all columns have the same number of rows. + const std::size_t num_rows = column_data.empty() ? 0 : column_data[0].size(); + for (const auto& column : column_data) { + if (column.size() != num_rows) { + return absl::InvalidArgumentError(absl::StrCat( + GetDataName(), + ": Columns do not have the same number of rows for prefix: '", prefix, + "'. One possible reason could be not including the trailing dot, " + "e.g. '_atom_site.'.")); + } + } + + std::vector> result; + result.reserve(num_rows); + CHECK_EQ(column_names.size(), column_data.size()); + for (std::size_t row_index = 0; row_index < num_rows; ++row_index) { + auto& row_dict = result.emplace_back(); + row_dict.reserve(column_names.size()); + for (int col_index = 0; col_index < column_names.size(); ++col_index) { + row_dict[column_names[col_index]] = column_data[col_index][row_index]; + } + } + return result; +} + +absl::StatusOr>> +CifDict::ExtractLoopAsDict(absl::string_view prefix, + absl::string_view index) const { + if (!absl::StartsWith(index, prefix)) { + return absl::InvalidArgumentError( + absl::StrCat(GetDataName(), ": The loop index '", index, + "' must start with the loop prefix '", prefix, "'.")); + } + absl::flat_hash_map> + result; + auto loop_as_list = ExtractLoopAsList(prefix); + if (!loop_as_list.ok()) { + return loop_as_list.status(); + } + result.reserve(loop_as_list->size()); + for (auto& entry : *loop_as_list) { + if (const auto it = entry.find(index); it != entry.end()) { + result[it->second] = entry; + } else { + return absl::InvalidArgumentError(absl::StrCat( + GetDataName(), ": The index column '", index, + "' could not be found in the loop with prefix '", prefix, "'.")); + } + } + return result; +} + +absl::StatusOr> Tokenize( + absl::string_view cif_string) { + HeapStrings heap_strings; + auto tokens = TokenizeInternal(cif_string, &heap_strings); + if (!tokens.ok()) { + return tokens.status(); + } + return std::vector(tokens->begin(), tokens->end()); +} + +absl::StatusOr> SplitLine( + absl::string_view line) { + std::vector tokens; + if (!SplitLineInline(line, &tokens)) { + return absl::InvalidArgumentError( + absl::StrCat("Line ended with quote open: ", line)); + } + return tokens; +} + +absl::StatusOr> ParseMultiDataCifDict( + absl::string_view cif_string) { + absl::flat_hash_map mapping; + constexpr absl::string_view delimitor = "data_"; + // Check cif_string starts with correct offset. 
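+  // StrSplit below drops each "data_" delimiter, and every block is rebuilt
+  // by widening the string_view back over the delimiter it followed; that
+  // pointer arithmetic is only safe if the string really starts with "data_".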
+ if (!cif_string.empty() && !absl::StartsWith(cif_string, delimitor)) { + return absl::InvalidArgumentError( + "Invalid format. MultiDataCifDict must start with 'data_'"); + } + for (absl::string_view data_block : + absl::StrSplit(cif_string, delimitor, absl::SkipEmpty())) { + absl::string_view block_with_delimitor( + data_block.data() - delimitor.size(), + data_block.size() + delimitor.size()); + absl::StatusOr parsed_block = + CifDict::FromString(block_with_delimitor); + if (!parsed_block.ok()) { + return parsed_block.status(); + } + absl::string_view data_name = parsed_block->GetDataName(); + mapping[data_name] = *std::move(parsed_block); + } + + return mapping; +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.h new file mode 100644 index 000000000..5c16eaa87 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_lib.h @@ -0,0 +1,149 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +// A C++ implementation of a CIF parser. For the format specification see +// https://www.iucr.org/resources/cif/spec/version1.1/cifsyntax +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/node_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace alphafold3 { + +class CifDict { + public: + // Use absl::node_hash_map since it guarantees pointer stability. + using Dict = absl::node_hash_map>; + + CifDict() = default; + + explicit CifDict(Dict dict) + : dict_(std::make_shared(std::move(dict))) {} + + // Converts a CIF string into a dictionary mapping each CIF field to a list of + // values that field contains. + static absl::StatusOr FromString(absl::string_view cif_string); + + // Converts the CIF into into a string that is a valid CIF file. + absl::StatusOr ToString() const; + + // Extracts loop associated with a prefix from mmCIF data as a list. + // Reference for loop_ in mmCIF: + // http://mmcif.wwpdb.org/docs/tutorials/mechanics/pdbx-mmcif-syntax.html + // Args: + // prefix: Prefix shared by each of the data items in the loop. + // e.g. '_entity_poly_seq.', where the data items are _entity_poly_seq.num, + // _entity_poly_seq.mon_id. Should include the trailing period. + // + // Returns a list of dicts; each dict represents 1 entry from an mmCIF loop. + // Lifetime of string_views tied to this. + absl::StatusOr< + std::vector>> + ExtractLoopAsList(absl::string_view prefix) const; + + // Extracts loop associated with a prefix from mmCIF data as a dictionary. + // Args: + // prefix: Prefix shared by each of the data items in the loop. + // e.g. 
'_entity_poly_seq.', where the data items are _entity_poly_seq.num, + // _entity_poly_seq.mon_id. Should include the trailing period. + // index: Which item of loop data should serve as the key. + // + // Returns a dict of dicts; each dict represents 1 entry from an mmCIF loop, + // indexed by the index column. + // Lifetime of string_views tied to this. + absl::StatusOr>> + ExtractLoopAsDict(absl::string_view prefix, absl::string_view index) const; + + // Returns value at key if present or an empty list. + absl::Span operator[](absl::string_view key) const { + auto it = dict_->find(key); + if (it != dict_->end()) { + return it->second; + } + return {}; + } + + // Returns boolean of whether dict contains key. + bool Contains(absl::string_view key) const { return dict_->contains(key); } + + // Returns number of values for the given key if present, 0 otherwise. + size_t ValueLength(absl::string_view key) const { + return (*this)[key].size(); + } + + // Returns the size of the underlying dictionary. + std::size_t Length() { return dict_->size(); } + + // Creates a copy of this CifDict object that will contain the original values + // but only if not updated by the given dictionary. + // E.g. if the CifDict = {a: [a1, a2], b: [b1]} and other = {a: [x], c: [z]}, + // you will get {a: [x], b: [b1], c: [z]}. + CifDict CopyAndUpdate(Dict other) const { + other.insert(dict_->begin(), dict_->end()); + return CifDict(std::move(other)); + } + + // Returns the value of the special CIF data_ field. + absl::string_view GetDataName() const { + // The data_ element has to be present by construction. + if (auto it = dict_->find("data_"); + it != dict_->end() && !it->second.empty()) { + return it->second.front(); + } else { + return ""; + } + } + + const std::shared_ptr& dict() const { return dict_; } + + private: + std::shared_ptr dict_; +}; + +// Tokenizes a CIF string into a list of string tokens. This is more involved +// than just a simple split on whitespace as CIF allows comments and quoting. +absl::StatusOr> Tokenize(absl::string_view cif_string); + +// Tokenizes a single line of a CIF string. +absl::StatusOr> SplitLine( + absl::string_view line); + +// Parses a CIF string with multiple data records and returns a mapping from +// record names to CifDict objects. For instance, the following CIF string: +// +// data_001 +// _foo bar +// +// data_002 +// _foo baz +// +// will be parsed as: +// {'001': CifDict({'_foo': ['bar']}), +// '002': CifDict({'_foo': ['baz']})} +absl::StatusOr> ParseMultiDataCifDict( + absl::string_view cif_string); + +} // namespace alphafold3 + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_LIB_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc new file mode 100644 index 000000000..130a8215a --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.cc @@ -0,0 +1,652 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. 
Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/npy_common.h" +#include "absl/base/no_destructor.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/numbers.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "pybind11/attr.h" +#include "pybind11/cast.h" +#include "pybind11/gil.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace alphafold3 { +namespace { +namespace py = pybind11; + +template +bool GatherArray(size_t num_dims, npy_intp* shape_array, npy_intp* stride_array, + const char* data, absl::Span values, + ForEach&& for_each_cb) { + if (num_dims == 1) { + const npy_intp shape = shape_array[0]; + const npy_intp stride = stride_array[0]; + for (size_t i = 0; i < shape; ++i) { + Item index; + std::memcpy(&index, data + stride * i, sizeof(Item)); + if (index < 0 || index >= values.size()) { + PyErr_SetString(PyExc_IndexError, + absl::StrCat("index ", index, + " is out of bounds for column with size ", + values.size()) + .c_str()); + return false; + } + if (!for_each_cb(values[index])) { + return false; + } + } + } else if (num_dims == 0) { + Item index; + std::memcpy(&index, data, sizeof(Item)); + if (index < 0 || index >= values.size()) { + PyErr_SetString( + PyExc_IndexError, + absl::StrCat("index ", index, + " is out of bounds for column with size ", values.size()) + .c_str()); + return false; + } + if (!for_each_cb(values[index])) { + return false; + } + } else { + const npy_intp shape = shape_array[0]; + const npy_intp stride = stride_array[0]; + for (size_t i = 0; i < shape; ++i) { + if (!GatherArray(num_dims - 1, shape_array + 1, stride_array + 1, + data + stride * i, values, for_each_cb)) { + return false; + } + } + } + return true; +} + +template +bool Gather(PyObject* gather, absl::Span values, + Size&& size_cb, ForEach&& for_each_cb) { + if (gather == Py_None) { + npy_intp dim = static_cast(values.size()); + if (!size_cb(absl::MakeSpan(&dim, 1))) { + return false; + } + for (const std::string& v : values) { + if (!for_each_cb(v)) { + return false; + } + } + return true; + } + if (PySlice_Check(gather)) { + Py_ssize_t start, stop, step, slice_length; + if (PySlice_GetIndicesEx(gather, values.size(), &start, &stop, &step, + &slice_length) != 0) { + return false; + } + npy_intp dim = static_cast(slice_length); + if (!size_cb(absl::MakeSpan(&dim, 1))) { + return false; + } + for (size_t i = 0; i < slice_length; ++i) { + if (!for_each_cb(values[start + i * step])) { + return false; + } + } + return true; + } + if (PyArray_Check(gather)) { + PyArrayObject* gather_array = reinterpret_cast(gather); + auto shape = + absl::MakeSpan(PyArray_DIMS(gather_array), PyArray_NDIM(gather_array)); + switch (PyArray_TYPE(gather_array)) { + case NPY_INT16: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_UINT16: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + 
PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_INT32: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_UINT32: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_INT64: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + case NPY_UINT64: + if (!size_cb(shape)) { + return false; + } + return GatherArray(shape.size(), shape.data(), + PyArray_STRIDES(gather_array), + PyArray_BYTES(gather_array), values, + std::forward(for_each_cb)); + default: + PyErr_SetString(PyExc_TypeError, "Unsupported NumPy array type."); + return false; + } + } + + PyErr_Format(PyExc_TypeError, "Invalid gather %R", gather); + return false; +} + +// Creates a NumPy array of objects of given strings. Reusing duplicates where +// possible. +PyObject* ConvertStrings(PyObject* gather, PyArray_Descr* type, + absl::Span values) { + absl::flat_hash_map existing; + + PyObject* ret = nullptr; + PyObject** dst; + if (Gather( + gather, values, + [&dst, &ret, type](absl::Span size) { + ret = PyArray_NewFromDescr( + /*subtype=*/&PyArray_Type, + /*type=*/type, + /*nd=*/size.size(), + /*dims=*/size.data(), + /*strides=*/nullptr, + /*data=*/nullptr, + /*flags=*/0, + /*obj=*/nullptr); + dst = static_cast( + PyArray_DATA(reinterpret_cast(ret))); + return true; + }, + [&dst, &existing](absl::string_view value) { + auto [it, inserted] = existing.emplace(value, nullptr); + if (inserted) { + it->second = + PyUnicode_FromStringAndSize(value.data(), value.size()); + PyUnicode_InternInPlace(&it->second); + } else { + Py_INCREF(it->second); + } + *dst++ = it->second; + return true; + })) { + return ret; + } else { + Py_XDECREF(ret); + return nullptr; + } +} + +// Creates NumPy array with given dtype given specified converter. +// `converter` shall have the following signature: +// bool converter(const std::string& value, T* result); +// It must return whether conversion is successful and store conversion in +// result. +template +inline PyObject* Convert(PyObject* gather, PyArray_Descr* type, + absl::Span values, C&& converter) { + py::object ret; + T* dst; + if (Gather( + gather, values, + [&dst, &ret, type](absl::Span size) { + // Construct uninitialised NumPy array of type T. 
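+            // PyArray_NewFromDescr steals a reference to `type`, and wrapping
+            // the result with reinterpret_steal lets pybind11 release the
+            // array if a later conversion step fails.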
+ ret = py::reinterpret_steal(PyArray_NewFromDescr( + /*subtype=*/&PyArray_Type, + /*type=*/type, + /*nd=*/size.size(), + /*dims=*/size.data(), + /*strides=*/nullptr, + /*data=*/nullptr, + /*flags=*/0, + /*obj=*/nullptr)); + + dst = static_cast( + PyArray_DATA(reinterpret_cast(ret.ptr()))); + return true; + }, + [&dst, &converter](const std::string& value) { + if (!converter(value, dst++)) { + PyErr_SetString(PyExc_ValueError, value.c_str()); + return false; + } + return true; + })) { + return ret.release().ptr(); + } + return nullptr; +} + +PyObject* CifDictGetArray(const CifDict& self, absl::string_view key, + PyObject* dtype, PyObject* gather) { + import_array(); + PyArray_Descr* type = nullptr; + if (dtype == Py_None) { + type = PyArray_DescrFromType(NPY_OBJECT); + } else if (PyArray_DescrConverter(dtype, &type) == NPY_FAIL || !type) { + PyErr_Format(PyExc_TypeError, "Invalid dtype %R", dtype); + Py_XDECREF(type); + return nullptr; + } + auto entry = self.dict()->find(key); + if (entry == self.dict()->end()) { + Py_DECREF(type); + PyErr_SetObject(PyExc_KeyError, + PyUnicode_FromStringAndSize(key.data(), key.size())); + return nullptr; + } + + auto int_convert = [](absl::string_view str, auto* value) { + return absl::SimpleAtoi(str, value); + }; + + auto int_convert_bounded = [](absl::string_view str, auto* value) { + int64_t v; + if (absl::SimpleAtoi(str, &v)) { + using limits = + std::numeric_limits>; + if (limits::min() <= v && v <= limits::max()) { + *value = v; + return true; + } + } + return false; + }; + + absl::Span values = entry->second; + + switch (type->type_num) { + case NPY_DOUBLE: + return Convert( + gather, type, values, [](absl::string_view str, double* value) { + if (str == ".") { + *value = std::numeric_limits::quiet_NaN(); + return true; + } + return absl::SimpleAtod(str, value); + }); + case NPY_FLOAT: + return Convert( + gather, type, values, [](absl::string_view str, float* value) { + if (str == ".") { + *value = std::numeric_limits::quiet_NaN(); + return true; + } + return absl::SimpleAtof(str, value); + }); + case NPY_INT8: + return Convert(gather, type, values, int_convert_bounded); + case NPY_INT16: + return Convert(gather, type, values, int_convert_bounded); + case NPY_INT32: + return Convert(gather, type, values, int_convert); + case NPY_INT64: + return Convert(gather, type, values, int_convert); + case NPY_UINT8: + return Convert(gather, type, values, int_convert_bounded); + case NPY_UINT16: + return Convert(gather, type, values, int_convert_bounded); + case NPY_UINT32: + return Convert(gather, type, values, int_convert); + case NPY_UINT64: + return Convert(gather, type, values, int_convert); + case NPY_BOOL: + return Convert(gather, type, values, + [](absl::string_view str, bool* value) { + if (str == "n" || str == "no") { + *value = false; + return true; + } + if (str == "y" || str == "yes") { + *value = true; + return true; + } + return false; + }); + case NPY_OBJECT: + return ConvertStrings(gather, type, values); + default: { + PyErr_Format(PyExc_TypeError, "Unsupported dtype %R", dtype); + Py_XDECREF(type); + return nullptr; + } + } +} + +} // namespace + +void RegisterModuleCifDict(pybind11::module m) { + using Value = std::vector; + static absl::NoDestructor> empty_values; + + m.def( + "from_string", + [](absl::string_view s) { + absl::StatusOr dict = CifDict::FromString(s); + if (!dict.ok()) { + throw py::value_error(dict.status().ToString()); + } + return *dict; + }, + py::call_guard()); + + m.def( + "tokenize", + [](absl::string_view cif_string) { 
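+        // Tokenization failures (e.g. an unterminated quote) surface to
+        // Python as ValueError; on success the token vector is moved out.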
+ absl::StatusOr> tokens = Tokenize(cif_string); + if (!tokens.ok()) { + throw py::value_error(tokens.status().ToString()); + } + return *std::move(tokens); + }, + py::arg("cif_string")); + + m.def("split_line", [](absl::string_view line) { + absl::StatusOr> tokens = SplitLine(line); + if (!tokens.ok()) { + throw py::value_error(tokens.status().ToString()); + } + return *std::move(tokens); + }); + + m.def( + "parse_multi_data_cif", + [](absl::string_view cif_string) { + auto result = ParseMultiDataCifDict(cif_string); + if (!result.ok()) { + throw py::value_error(result.status().ToString()); + } + py::dict dict; + for (auto& [key, value] : *result) { + dict[py::cast(key)] = py::cast(value); + } + return dict; + }, + py::arg("cif_string")); + + auto cif_dict = + py::class_(m, "CifDict") + .def(py::init<>([](py::dict dict) { + CifDict::Dict result; + for (const auto& [key, value] : dict) { + result.emplace(py::cast(key), + py::cast>(value)); + } + return CifDict(std::move(result)); + }), + "Initialise with a map") + .def("copy_and_update", + [](const CifDict& self, py::dict dict) { + CifDict::Dict result; + for (const auto& [key, value] : dict) { + result.emplace(py::cast(key), + py::cast>(value)); + } + { + py::gil_scoped_release gil_release; + return self.CopyAndUpdate(std::move(result)); + } + }) + .def( + "__str__", + [](const CifDict& self) { + absl::StatusOr result = self.ToString(); + if (!result.ok()) { + throw py::value_error(result.status().ToString()); + } + return *result; + }, + "Serialize to a string", py::call_guard()) + .def( + "to_string", + [](const CifDict& self) { + absl::StatusOr result = self.ToString(); + if (!result.ok()) { + throw py::value_error(result.status().ToString()); + } + return *result; + }, + "Serialize to a string", py::call_guard()) + .def("value_length", &CifDict::ValueLength, py::arg("key"), + "Num elements in value") + .def("__len__", + [](const CifDict& self) { return self.dict()->size(); }) + .def( + "__bool__", + [](const CifDict& self) { return !self.dict()->empty(); }, + "Check whether the map is nonempty") + .def( + "__contains__", + [](const CifDict& self, absl::string_view k) { + return self.dict()->find(k) != self.dict()->end(); + }, + py::arg("key"), py::call_guard()) + .def("get_data_name", &CifDict::GetDataName) + .def( + "get", + [](const CifDict& self, absl::string_view k, + py::object default_value) -> py::object { + auto it = self.dict()->find(k); + if (it == self.dict()->end()) return default_value; + py::list result(it->second.size()); + size_t index = 0; + for (const std::string& v : it->second) { + result[index++] = py::cast(v); + } + return result; + }, + py::arg("key"), py::arg("default_value") = py::none()) + .def( + "get_array", + [](const CifDict& self, absl::string_view key, py::handle dtype, + py::handle gather) -> py::object { + PyObject* obj = + CifDictGetArray(self, key, dtype.ptr(), gather.ptr()); + if (obj == nullptr) { + throw py::error_already_set(); + } + return py::reinterpret_steal(obj); + }, + py::arg("key"), py::arg("dtype") = py::none(), + py::arg("gather") = py::none()) + .def( + "__getitem__", + [](const CifDict& self, absl::string_view k) -> const Value& { + auto it = self.dict()->find(k); + if (it == self.dict()->end()) { + throw py::key_error(std::string(k).c_str()); + } + return it->second; + }, + py::arg("key"), py::call_guard()) + .def( + "extract_loop_as_dict", + [](const CifDict& self, absl::string_view prefix, + absl::string_view index) { + absl::StatusOr>> + dict; + { + py::gil_scoped_release 
gil_release; + dict = self.ExtractLoopAsDict(prefix, index); + if (!dict.ok()) { + throw py::value_error(dict.status().ToString()); + } + } + py::dict key_value_dict; + for (const auto& [key, value] : *dict) { + py::dict value_dict; + for (const auto& [key2, value2] : value) { + value_dict[py::cast(key2)] = py::cast(value2); + } + key_value_dict[py::cast(key)] = std::move(value_dict); + } + return key_value_dict; + }, + py::arg("prefix"), py::arg("index")) + .def( + "extract_loop_as_list", + [](const CifDict& self, absl::string_view prefix) { + absl::StatusOr>> + list_dict; + { + py::gil_scoped_release gil_release; + list_dict = self.ExtractLoopAsList(prefix); + if (!list_dict.ok()) { + throw py::value_error(list_dict.status().ToString()); + } + } + py::list list_obj(list_dict->size()); + size_t index = 0; + for (const auto& value : *list_dict) { + py::dict value_dict; + for (const auto& [key, value] : value) { + value_dict[py::cast(key)] = py::cast(value); + } + list_obj[index++] = std::move(value_dict); + } + return list_obj; + }, + py::arg("prefix")) + .def(py::pickle( + [](const CifDict& self) { // __getstate__. + py::tuple result_tuple(1); + py::dict result; + for (const auto& [key, value] : *self.dict()) { + result[py::cast(key)] = py::cast(value); + } + result_tuple[0] = std::move(result); + return result_tuple; + }, + [](py::tuple t) { // __setstate__. + py::dict dict = t[0].cast(); + CifDict::Dict result; + for (const auto& [key, value] : dict) { + result.emplace(py::cast(key), + py::cast>(value)); + } + return CifDict(std::move(result)); + })); + + // Item, value, and key views + struct KeyView { + CifDict map; + }; + + struct ValueView { + CifDict map; + }; + struct ItemView { + CifDict map; + }; + + py::class_(cif_dict, "ItemView") + .def("__len__", [](const ItemView& v) { return v.map.dict()->size(); }) + .def( + "__iter__", + [](const ItemView& v) { + return py::make_iterator(v.map.dict()->begin(), + v.map.dict()->end()); + }, + py::keep_alive<0, 1>()); + + py::class_(cif_dict, "KeyView") + .def("__contains__", + [](const KeyView& v, absl::string_view k) { + return v.map.dict()->find(k) != v.map.dict()->end(); + }) + .def("__contains__", [](const KeyView&, py::handle) { return false; }) + .def("__len__", [](const KeyView& v) { return v.map.dict()->size(); }) + .def( + "__iter__", + [](const KeyView& v) { + return py::make_key_iterator(v.map.dict()->begin(), + v.map.dict()->end()); + }, + py::keep_alive<0, 1>()); + + py::class_(cif_dict, "ValueView") + .def("__len__", [](const ValueView& v) { return v.map.dict()->size(); }) + .def( + "__iter__", + [](const ValueView& v) { + return py::make_value_iterator(v.map.dict()->begin(), + v.map.dict()->end()); + }, + py::keep_alive<0, 1>()); + + cif_dict + .def( + "__iter__", + [](CifDict& self) { + return py::make_key_iterator(self.dict()->begin(), + self.dict()->end()); + }, + py::keep_alive<0, 1>()) + .def( + "keys", [](CifDict& self) { return KeyView{self}; }, + "Returns an iterable view of the map's keys.") + .def( + "values", [](CifDict& self) { return ValueView{self}; }, + "Returns an iterable view of the map's values.") + .def( + "items", [](CifDict& self) { return ItemView{self}; }, + "Returns an iterable view of the map's items."); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/cif_dict_pybind.h new file mode 100644 index 000000000..ca4f94702 --- /dev/null +++ 
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_
+
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+
+void RegisterModuleCifDict(pybind11::module m);
+
+}
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_CIF_DICT_PYBIND_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator.pyi
new file mode 100644
index 000000000..d5da60ec8
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator.pyi
@@ -0,0 +1,22 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+class FastaFileIterator:
+    def __init__(self, fasta_path: str) -> None: ...
+    def __iter__(self) -> FastaFileIterator: ...
+    def __next__(self) -> tuple[str, str]: ...
+
+class FastaStringIterator:
+    def __init__(self, fasta_string: str | bytes) -> None: ...
+    def __iter__(self) -> FastaStringIterator: ...
+    def __next__(self) -> tuple[str, str]: ...
+
+def parse_fasta(fasta_string: str | bytes) -> list[str]: ...
+def parse_fasta_include_descriptions(fasta_string: str | bytes) -> tuple[list[str], list[str]]: ...
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.cc
new file mode 100644
index 000000000..82cac9343
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.cc
@@ -0,0 +1,121 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include "alphafold3/parsers/cpp/fasta_iterator_lib.h"
+
+#include <cstddef>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_split.h"
+#include "absl/strings/string_view.h"
+#include "absl/strings/strip.h"
+
+namespace alphafold3 {
+
+// Parse FASTA string and return list of strings with amino acid sequences.
+// Returns a list of amino acid sequences only.
+std::vector<std::string> ParseFasta(absl::string_view fasta_string) {
+  std::vector<std::string> sequences;
+  std::string* sequence = nullptr;
+  for (absl::string_view line_raw : absl::StrSplit(fasta_string, '\n')) {
+    absl::string_view line = absl::StripAsciiWhitespace(line_raw);
+    if (absl::ConsumePrefix(&line, ">")) {
+      sequence = &sequences.emplace_back();
+    } else if (!line.empty() && sequence != nullptr) {
+      absl::StrAppend(sequence, line);
+    }
+  }
+  return sequences;
+}
+
+// Parse FASTA string and return list of strings with amino acid sequences.
+// Returns two lists: The first one with amino acid sequences, the second with
+// the descriptions associated with each sequence.
+std::pair<std::vector<std::string>, std::vector<std::string>>
+ParseFastaIncludeDescriptions(absl::string_view fasta_string) {
+  std::pair<std::vector<std::string>, std::vector<std::string>> result;
+  auto& [sequences, descriptions] = result;
+  std::string* sequence = nullptr;
+  for (absl::string_view line_raw : absl::StrSplit(fasta_string, '\n')) {
+    absl::string_view line = absl::StripAsciiWhitespace(line_raw);
+    if (absl::ConsumePrefix(&line, ">")) {
+      descriptions.emplace_back(line);
+      sequence = &sequences.emplace_back();
+    } else if (!line.empty() && sequence != nullptr) {
+      absl::StrAppend(sequence, line);
+    }
+  }
+  return result;
+}
+
+absl::StatusOr<std::pair<std::string, std::string>> FastaFileIterator::Next() {
+  std::string line_str;
+  while (std::getline(reader_, line_str)) {
+    absl::string_view line = line_str;
+    line = absl::StripAsciiWhitespace(line);
+    if (absl::ConsumePrefix(&line, ">")) {
+      if (!description_.has_value()) {
+        description_ = line;
+      } else {
+        std::pair<std::string, std::string> output(sequence_, *description_);
+        description_ = line;
+        sequence_ = "";
+        return output;
+      }
+    } else if (description_.has_value()) {
+      absl::StrAppend(&sequence_, line);
+    }
+  }
+  has_next_ = false;
+  reader_.close();
+  if (description_.has_value()) {
+    return std::pair(sequence_, *description_);
+  } else {
+    return absl::InvalidArgumentError(
+        absl::StrCat("Invalid FASTA file: ", filename_));
+  }
+}
+
+absl::StatusOr<std::pair<std::string, std::string>>
+FastaStringIterator::Next() {
+  size_t consumed = 0;
+  for (absl::string_view line_raw : absl::StrSplit(fasta_string_, '\n')) {
+    consumed += line_raw.size() + 1;  // +1 for the newline character.
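+    // `consumed` counts the bytes scanned so far. Once a full record is
+    // emitted, the scanned prefix is removed from fasta_string_ below, so
+    // the next call to Next() resumes at the start of the following record
+    // instead of re-parsing from the beginning.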
+    absl::string_view line = absl::StripAsciiWhitespace(line_raw);
+    if (absl::ConsumePrefix(&line, ">")) {
+      if (!description_.has_value()) {
+        description_ = line;
+      } else {
+        std::pair<std::string, std::string> output(sequence_, *description_);
+        description_ = line;
+        sequence_ = "";
+        fasta_string_.remove_prefix(consumed);
+        return output;
+      }
+    } else if (description_.has_value()) {
+      absl::StrAppend(&sequence_, line);
+    }
+  }
+  has_next_ = false;
+  if (description_.has_value()) {
+    return std::pair(sequence_, *description_);
+  } else {
+    return absl::InvalidArgumentError("Invalid FASTA string");
+  }
+}
+
+}  // namespace alphafold3
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.h
new file mode 100644
index 000000000..486d05f20
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_lib.h
@@ -0,0 +1,94 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+// A C++ implementation of a FASTA parser.
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_LIB_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_LIB_H_
+
+#include <fstream>
+#include <ios>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+
+namespace alphafold3 {
+
+// Parse FASTA string and return list of strings with amino acid sequences.
+// Returns a list of amino acid sequences only.
+std::vector<std::string> ParseFasta(absl::string_view fasta_string);
+
+// Parse FASTA string and return list of strings with amino acid sequences.
+// Returns two lists: The first one with amino acid sequences, the second with
+// the descriptions associated with each sequence.
+std::pair<std::vector<std::string>, std::vector<std::string>>
+ParseFastaIncludeDescriptions(absl::string_view fasta_string);
+
+// Lazy FASTA parser for memory efficient FASTA parsing from a path.
+class FastaFileIterator {
+ public:
+  // Initialise FastaFileIterator with filename of fasta. If you initialize
+  // reader_ with an invalid path or empty file, it won't fail; only
+  // std::getline within the Next method will then return false. That will
+  // then trigger the "Invalid FASTA file" error.
+  explicit FastaFileIterator(absl::string_view fasta_path)
+      : filename_(fasta_path),
+        reader_(filename_, std::ios::in),
+        has_next_(true) {}
+
+  // Returns whether there are more sequences. Returns true before the first
+  // call to Next() even if the file is empty.
+  bool HasNext() const { return has_next_; }
+
+  // Fetches the next (sequence, description) from the file.
+  absl::StatusOr<std::pair<std::string, std::string>> Next();
+
+ private:
+  std::string filename_;
+  // File stream read line by line by Next().
+  std::fstream reader_;
+  std::optional<std::string> description_;
+  std::string sequence_;
+  bool has_next_;
+};
+
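+// Usage sketch for FastaFileIterator (illustrative path):
+//
+//   FastaFileIterator it("/tmp/example.fasta");
+//   while (it.HasNext()) {
+//     absl::StatusOr<std::pair<std::string, std::string>> record = it.Next();
+//     if (!record.ok()) break;  // Empty or malformed input.
+//     const auto& [sequence, description] = *record;
+//   }
+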
+// Lazy FASTA parser for memory efficient FASTA parsing from a string.
+class FastaStringIterator {
+ public:
+  // Initialise FastaStringIterator with a string_view of a FASTA. If you
+  // initialize it with an invalid FASTA string, it won't fail; the Next
+  // method will instead return the "Invalid FASTA string" error.
+  // WARNING: The object backing the fasta_string string_view must not be
+  // deleted while this Iterator is alive.
+  explicit FastaStringIterator(absl::string_view fasta_string)
+      : fasta_string_(fasta_string), has_next_(true) {}
+
+  // Returns whether there are more sequences. Returns true before the first
+  // call to Next() even if the string is empty.
+  bool HasNext() const { return has_next_; }
+
+  // Fetches the next (sequence, description) from the string.
+  absl::StatusOr<std::pair<std::string, std::string>> Next();
+
+ private:
+  absl::string_view fasta_string_;
+  bool has_next_;
+  std::optional<std::string> description_;
+  std::string sequence_;
+};
+
+}  // namespace alphafold3
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_LIB_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.cc
new file mode 100644
index 000000000..0b47933d4
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.cc
@@ -0,0 +1,127 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include <utility>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "alphafold3/parsers/cpp/fasta_iterator_lib.h"
+#include "pybind11/attr.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+
+namespace alphafold3 {
+namespace {
+
+namespace py = pybind11;
+
+template <typename T>
+T ValueOrThrowValueError(absl::StatusOr<T> value) {
+  if (!value.ok()) throw py::value_error(value.status().ToString());
+  return *std::move(value);
+}
+
+constexpr char kFastaFileIteratorDoc[] = R"(
+Lazy FASTA parser for memory efficient FASTA parsing from a path.)";
+
+constexpr char kFastaStringIteratorDoc[] = R"(
+Lazy FASTA parser for memory efficient FASTA parsing from a string.
+
+WARNING: The object backing the fasta_string string_view must not be
+deleted while the FastaStringIterator is alive. E.g. this will break:
+
+```
+# Make sure the fasta_string is not interned.
+fasta_string = '\n'.join(['>d\nS' for _ in range(10)])
+iterator = fasta_iterator.FastaStringIterator(fasta_string)
+del fasta_string
+next(iterator)  # Heap use-after-free.
+```
+)";
+
+constexpr char kParseFastaDoc[] = R"(
+Parses a FASTA string and returns a list of amino-acid sequences.
+
+Args:
+  fasta_string: The contents of a FASTA file.
+
+Returns:
+  List of sequences in the FASTA file. Descriptions are ignored.
+)";
+
+constexpr char kParseFastaIncludeDescriptionsDoc[] = R"(
+Parses a FASTA string, returns amino-acid sequences with descriptions.
+
+Args:
+  fasta_string: The contents of a FASTA file.
+
+Returns:
+  A tuple with two lists (sequences, descriptions):
+   * A list of sequences.
+   * A list of sequence descriptions taken from the comment lines. In the
+     same order as the sequences.
+)";
+
+class PythonFastaStringIterator : public FastaStringIterator {
+ public:
+  explicit PythonFastaStringIterator(py::object fasta_string)
+      : FastaStringIterator(py::cast<absl::string_view>(fasta_string)),
+        fasta_string_(std::move(fasta_string)) {}
+
+ private:
+  py::object fasta_string_;
+};
+
+}  // namespace
+
+void RegisterModuleFastaIterator(pybind11::module m) {
+  py::class_<FastaFileIterator>(m, "FastaFileIterator", kFastaFileIteratorDoc)
+      .def(py::init<absl::string_view>(), py::arg("fasta_path"))
+      .def("__iter__",
+           [](FastaFileIterator& iterator) -> FastaFileIterator& {
+             return iterator;
+           })
+      .def(
+          "__next__",
+          [](FastaFileIterator& iterator) {
+            if (iterator.HasNext()) {
+              return ValueOrThrowValueError(iterator.Next());
+            } else {
+              throw py::stop_iteration();
+            }
+          },
+          py::call_guard<py::gil_scoped_release>());
+
+  py::class_<PythonFastaStringIterator>(m, "FastaStringIterator",
+                                        kFastaStringIteratorDoc)
+      .def(py::init<py::object>(), py::arg("fasta_string"))
+      .def("__iter__",
+           [](PythonFastaStringIterator& iterator)
+               -> PythonFastaStringIterator& { return iterator; })
+      .def(
+          "__next__",
+          [](PythonFastaStringIterator& iterator) {
+            if (iterator.HasNext()) {
+              return ValueOrThrowValueError(iterator.Next());
+            } else {
+              throw py::stop_iteration();
+            }
+          },
+          py::call_guard<py::gil_scoped_release>());
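+  // Illustrative Python usage of the iterators registered above (a sketch;
+  // the import path depends on how the extension module is built):
+  //
+  //   from alphafold3.parsers.cpp import fasta_iterator
+  //   for sequence, description in fasta_iterator.FastaStringIterator(
+  //       '>query\nMKV\n>hit\nMKL\n'):
+  //     ...
+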
+  m.def("parse_fasta", &ParseFasta, py::arg("fasta_string"),
+        py::call_guard<py::gil_scoped_release>(), py::doc(kParseFastaDoc + 1));
+  m.def("parse_fasta_include_descriptions", &ParseFastaIncludeDescriptions,
+        py::arg("fasta_string"), py::call_guard<py::gil_scoped_release>(),
+        py::doc(kParseFastaIncludeDescriptionsDoc + 1));
+}
+
+}  // namespace alphafold3
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.h
new file mode 100644
index 000000000..091ea3fa2
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/fasta_iterator_pybind.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_PYBIND_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_PYBIND_H_
+
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+
+void RegisterModuleFastaIterator(pybind11::module m);
+
+}
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_FASTA_ITERATOR_PYBIND_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion.pyi
new file mode 100644
index 000000000..3602032b9
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion.pyi
@@ -0,0 +1,26 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Type annotations for Python bindings for `msa_conversion`.
+
+The type annotations in this file were modified from the automatically
+generated stubgen output.
+"""
+
+from collections.abc import Iterable
+
+
+def align_sequence_to_gapless_query(
+    sequence: str | bytes,
+    query_sequence: str | bytes,
+) -> str: ...
+
+
+def convert_a3m_to_stockholm(a3m_sequences: Iterable[str]) -> list[str]: ...
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc
new file mode 100644
index 000000000..c192052f0
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.cc
@@ -0,0 +1,162 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include <algorithm>
+#include <cstddef>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+#include "absl/strings/ascii.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/stl.h"
+
+namespace {
+
+namespace py = pybind11;
+
+std::vector<std::string> ConvertA3MToStockholm(
+    std::vector<absl::string_view> a3m_sequences) {
+  std::vector<std::string> stockholm_sequences(a3m_sequences.size());
+  auto max_length_element =
+      std::max_element(a3m_sequences.begin(), a3m_sequences.end(),
+                       [](absl::string_view lhs, absl::string_view rhs) {
+                         return lhs.size() < rhs.size();
+                       });
+
+  for (auto& out : stockholm_sequences) {
+    out.reserve(max_length_element->size());
+  }
+
+  // While any sequence has remaining columns.
+  while (std::any_of(a3m_sequences.begin(), a3m_sequences.end(),
+                     [](absl::string_view in) { return !in.empty(); })) {
+    if (std::any_of(a3m_sequences.begin(), a3m_sequences.end(),
+                    [](absl::string_view in) {
+                      return !in.empty() && absl::ascii_islower(in.front());
+                    })) {
+      // Insertion(s) found at column.
+      for (std::size_t i = 0; i < a3m_sequences.size(); ++i) {
+        absl::string_view& in = a3m_sequences[i];
+        std::string& out = stockholm_sequences[i];
+        if (!in.empty() && absl::ascii_islower(in.front())) {
+          // Consume insertion.
+          out.push_back(absl::ascii_toupper(in.front()));
+          in.remove_prefix(1);
+        } else {
+          // Row requires padding.
+          out.push_back('-');
+        }
+      }
+    } else {
+      // No insertions found.
+      for (std::size_t i = 0; i < a3m_sequences.size(); ++i) {
+        absl::string_view& in = a3m_sequences[i];
+        std::string& out = stockholm_sequences[i];
+        if (!in.empty()) {
+          // Consume entire column.
+          out.push_back(in.front());
+          in.remove_prefix(1);
+        } else {
+          // One alignment is shorter than the others. Should not happen with
+          // valid A3M input.
+          throw std::invalid_argument(absl::StrFormat(
+              "a3m rows have inconsistent lengths; row %d has no columns left "
+              "but not all rows are exhausted",
+              i));
+        }
+      }
+    }
+  }
+  return stockholm_sequences;
+}
+
+std::string AlignSequenceToGaplessQuery(absl::string_view sequence,
+                                        absl::string_view query_sequence) {
+  if (sequence.size() != query_sequence.size()) {
+    throw py::value_error(
+        absl::StrFormat("The sequence (%d) and the query sequence (%d) don't "
+                        "have the same length.",
+                        sequence.size(), query_sequence.size()));
+  }
+  std::string output;
+  for (std::size_t residue_index = 0, sequence_length = sequence.size();
+       residue_index < sequence_length; ++residue_index) {
+    const char query_residue = query_sequence[residue_index];
+    const char residue = sequence[residue_index];
+    if (query_residue != '-') {
+      // No gap in the query, so the residue is aligned.
+      output += residue;
+    } else if (residue == '-') {
+      // Gap in both sequence and query, simply skip.
+      continue;
+    } else {
+      // Gap only in the query, so this must be an inserted residue.
+      output += absl::ascii_tolower(residue);
+    }
+  }
+  return output;
+}
+
+constexpr char kConvertA3mToStockholm[] = R"(
+Converts a list of sequences in a3m format to stockholm format sequences.
+
+As an example if the input is:
+abCD
+CgD
+fCDa
+
+Then the output will be:
+ABC-D-
+--CGD-
+F-C-DA
+
+Args:
+  a3m_sequences: A list of strings in a3m format.
+
+Returns:
+  A list of strings converted to stockholm format.
+)";
+
+constexpr char kAlignSequenceToGaplessQuery[] = R"(
+Aligns a sequence to a gapless query sequence.
+
+This is useful when converting Stockholm MSA to A3M MSA. Example:
+Seq   : AB--E
+Query : A--DE
+Output: Ab-E.
+
+Args:
+  sequence: A string containing the sequence to be aligned.
+  query_sequence: A string containing the reference sequence to align to.
+
+Returns:
+  The input sequence with gaps dropped where both the `sequence` and
+  `query_sequence` have gaps, and sequence elements non-capitalized where the
+  `query_sequence` has a gap, but the `sequence` does not.
+)";
+
+}  // namespace
+
+namespace alphafold3 {
+
+void RegisterModuleMsaConversion(pybind11::module m) {
+  m.def("convert_a3m_to_stockholm", &ConvertA3MToStockholm,
+        py::arg("a3m_sequences"), py::call_guard<py::gil_scoped_release>(),
+        py::doc(kConvertA3mToStockholm + 1));
+  m.def("align_sequence_to_gapless_query", &AlignSequenceToGaplessQuery,
+        py::arg("sequence"), py::arg("query_sequence"),
+        py::call_guard<py::gil_scoped_release>(),
+        py::doc(kAlignSequenceToGaplessQuery + 1));
+}
+
+}  // namespace alphafold3
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.h
new file mode 100644
index 000000000..65f5fe99e
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/parsers/cpp/msa_conversion_pybind.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google.
Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_MSA_CONVERSION_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_MSA_CONVERSION_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMsaConversion(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_PARSERS_PYTHON_MSA_CONVERSION_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/__init__.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/__init__.py new file mode 100644 index 000000000..17f44cd06 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/__init__.py @@ -0,0 +1,46 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Structure module initialization.""" + +# pylint: disable=g-importing-member +from alphafold3.structure.bioassemblies import BioassemblyData +from alphafold3.structure.bonds import Bonds +from alphafold3.structure.chemical_components import ChemCompEntry +from alphafold3.structure.chemical_components import ChemicalComponentsData +from alphafold3.structure.chemical_components import get_data_for_ccd_components +from alphafold3.structure.chemical_components import populate_missing_ccd_data +from alphafold3.structure.mmcif import BondParsingError +from alphafold3.structure.parsing import BondAtomId +from alphafold3.structure.parsing import from_atom_arrays +from alphafold3.structure.parsing import from_mmcif +from alphafold3.structure.parsing import from_parsed_mmcif +from alphafold3.structure.parsing import from_res_arrays +from alphafold3.structure.parsing import from_sequences_and_bonds +from alphafold3.structure.parsing import ModelID +from alphafold3.structure.parsing import SequenceFormat +from alphafold3.structure.structure import ARRAY_FIELDS +from alphafold3.structure.structure import AuthorNamingScheme +from alphafold3.structure.structure import Bond +from alphafold3.structure.structure import CascadeDelete +from alphafold3.structure.structure import concat +from alphafold3.structure.structure import enumerate_residues +from alphafold3.structure.structure import fix_non_standard_polymer_residues +from alphafold3.structure.structure import GLOBAL_FIELDS +from alphafold3.structure.structure import make_empty_structure +from alphafold3.structure.structure import MissingAtomError +from alphafold3.structure.structure import MissingAuthorResidueIdError +from alphafold3.structure.structure import multichain_residue_index +from alphafold3.structure.structure import stack +from alphafold3.structure.structure import Structure +from alphafold3.structure.structure_tables import Atoms +from alphafold3.structure.structure_tables import Chains +from alphafold3.structure.structure_tables import Residues diff --git 
a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bioassemblies.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bioassemblies.py
new file mode 100644
index 000000000..166fff5f2
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bioassemblies.py
@@ -0,0 +1,333 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Utilities for parsing and manipulating bioassembly data."""
+
+from collections.abc import Mapping, Sequence
+import copy
+import dataclasses
+from typing_extensions import Self
+
+from alphafold3.structure import mmcif
+import numpy as np
+
+
+@dataclasses.dataclass(frozen=True)
+class Operation:
+  """A rigid transformation operation."""
+
+  trans: np.ndarray  # shape: (3,)
+  rot: np.ndarray  # shape: (3, 3)
+
+  def apply_to_coords(self, coords: np.ndarray) -> np.ndarray:
+    """Applies the rotation followed by the translation to `coords`."""
+    return np.dot(coords, self.rot.T) + self.trans[np.newaxis, :]
+
+
+@dataclasses.dataclass(frozen=True)
+class Transform:
+  """A rigid transformation composed of a sequence of `Operation`s."""
+
+  # The sequence of operations that form the transform. These will be applied
+  # right-to-left (last-to-first).
+  operations: Sequence[Operation]
+
+  # The chain IDs that this transform should be applied to. These are
+  # label_asym_ids in the mmCIF spec.
+  chain_ids: Sequence[str]
+
+  # A mapping from chain IDs (of chains that participate in this transform)
+  # to their new values in the bioassembly.
+  chain_id_rename_map: Mapping[str, str]
+
+  def apply_to_coords(self, coords: np.ndarray) -> np.ndarray:
+    """Applies the `operations` in right-to-left order."""
+    for operation in reversed(self.operations):
+      coords = operation.apply_to_coords(coords)
+    return coords
+
+
+def _get_operation(oper_data: Mapping[str, str]) -> Operation:
+  """Parses an `Operation` from an mmCIF _pdbx_struct_oper_list row."""
+  trans = np.zeros((3,), dtype=np.float32)
+  rot = np.zeros((3, 3), dtype=np.float32)
+  for i in range(3):
+    trans[i] = float(oper_data[f'_pdbx_struct_oper_list.vector[{i + 1}]'])
+  for i in range(3):
+    for j in range(3):
+      rot[i][j] = float(
+          oper_data[f'_pdbx_struct_oper_list.matrix[{i + 1}][{j + 1}]']
+      )
+  return Operation(trans=trans, rot=rot)
+
+
+class MissingBioassemblyDataError(Exception):
+  """Raised when bioassembly data is missing from an mmCIF."""
+
+
+class BioassemblyData:
+  """Stores and processes bioassembly data from mmCIF tables."""
+
+  # Not all of these columns are required for internal operations, but all
+  # should be present whenever bioassemblies are defined in an mmCIF to stay
+  # consistent with external mmCIFs.
+ _REQUIRED_COLUMNS = ( + '_pdbx_struct_assembly.id', + '_pdbx_struct_assembly.details', + '_pdbx_struct_assembly.method_details', + '_pdbx_struct_assembly.oligomeric_details', + '_pdbx_struct_assembly.oligomeric_count', + '_pdbx_struct_assembly_gen.assembly_id', + '_pdbx_struct_assembly_gen.oper_expression', + '_pdbx_struct_assembly_gen.asym_id_list', + '_pdbx_struct_oper_list.id', + '_pdbx_struct_oper_list.type', + '_pdbx_struct_oper_list.name', + '_pdbx_struct_oper_list.symmetry_operation', + '_pdbx_struct_oper_list.matrix[1][1]', + '_pdbx_struct_oper_list.matrix[1][2]', + '_pdbx_struct_oper_list.matrix[1][3]', + '_pdbx_struct_oper_list.vector[1]', + '_pdbx_struct_oper_list.matrix[2][1]', + '_pdbx_struct_oper_list.matrix[2][2]', + '_pdbx_struct_oper_list.matrix[2][3]', + '_pdbx_struct_oper_list.vector[2]', + '_pdbx_struct_oper_list.matrix[3][1]', + '_pdbx_struct_oper_list.matrix[3][2]', + '_pdbx_struct_oper_list.matrix[3][3]', + '_pdbx_struct_oper_list.vector[3]', + ) + + def __init__( + self, + *, + pdbx_struct_assembly: Mapping[str, Mapping[str, str]], + pdbx_struct_assembly_gen: Mapping[str, Sequence[Mapping[str, str]]], + pdbx_struct_oper_list: Mapping[str, Mapping[str, str]], + assembly_ids: Sequence[str], + oper_ids: Sequence[str], + ): + for assembly_id in assembly_ids: + for table, table_name in ( + (pdbx_struct_assembly, '_pdbx_struct_assembly'), + (pdbx_struct_assembly_gen, '_pdbx_struct_assembly_gen'), + ): + if assembly_id not in table: + raise ValueError( + f'Assembly ID "{assembly_id}" missing from {table_name} ' + f'with keys: {table.keys()}' + ) + for oper_id in oper_ids: + if oper_id not in pdbx_struct_oper_list: + raise ValueError( + f'Oper ID "{oper_id}" missing from _pdbx_struct_oper_list ' + f'with keys: {pdbx_struct_oper_list.keys()}' + ) + + self._pdbx_struct_assembly = pdbx_struct_assembly + self._pdbx_struct_assembly_gen = pdbx_struct_assembly_gen + self._pdbx_struct_oper_list = pdbx_struct_oper_list + self._operations = { + oper_id: _get_operation(oper_data) + for oper_id, oper_data in self._pdbx_struct_oper_list.items() + } + self._assembly_ids = assembly_ids + self._oper_ids = oper_ids + + @classmethod + def from_mmcif(cls, cif: mmcif.Mmcif) -> Self: + """Constructs an instance of `BioassemblyData` from an `Mmcif` object.""" + for col in cls._REQUIRED_COLUMNS: + if col not in cif: + raise MissingBioassemblyDataError(col) + + pdbx_struct_assembly = cif.extract_loop_as_dict( + prefix='_pdbx_struct_assembly.', index='_pdbx_struct_assembly.id' + ) + pdbx_struct_oper_list = cif.extract_loop_as_dict( + prefix='_pdbx_struct_oper_list.', index='_pdbx_struct_oper_list.id' + ) + + # _pdbx_struct_assembly_gen is unlike the other two tables because it can + # have multiple rows share the same assembly ID. This can happen when an + # assembly is constructed by applying different sets of transforms to + # different sets of chain IDs. Each of these would have its own row. + # Here we group rows by their assembly_id. 
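+    # For example, two generator rows for assembly '1' (one applying oper '1'
+    # to chains A,B and one applying oper '2' to chain C) are grouped as:
+    #   {'1': [{..., 'oper_expression': '1', 'asym_id_list': 'A,B'},
+    #          {..., 'oper_expression': '2', 'asym_id_list': 'C'}]}
+    # (the '_pdbx_struct_assembly_gen.' key prefixes are elided here).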
+    pdbx_struct_assembly_gen = {}
+    for assembly_id, oper_expression, asym_id_list in zip(
+        cif['_pdbx_struct_assembly_gen.assembly_id'],
+        cif['_pdbx_struct_assembly_gen.oper_expression'],
+        cif['_pdbx_struct_assembly_gen.asym_id_list'],
+    ):
+      pdbx_struct_assembly_gen.setdefault(assembly_id, []).append({
+          '_pdbx_struct_assembly_gen.assembly_id': assembly_id,
+          '_pdbx_struct_assembly_gen.oper_expression': oper_expression,
+          '_pdbx_struct_assembly_gen.asym_id_list': asym_id_list,
+      })
+
+    # We provide these separately to keep track of the original order in which
+    # they appear in the mmCIF.
+    assembly_ids = cif['_pdbx_struct_assembly.id']
+    oper_ids = cif['_pdbx_struct_oper_list.id']
+    return cls(
+        pdbx_struct_assembly=pdbx_struct_assembly,
+        pdbx_struct_assembly_gen=pdbx_struct_assembly_gen,
+        pdbx_struct_oper_list=pdbx_struct_oper_list,
+        assembly_ids=assembly_ids,
+        oper_ids=oper_ids,
+    )
+
+  @property
+  def assembly_ids(self) -> Sequence[str]:
+    return self._assembly_ids
+
+  def asym_id_by_assembly_chain_id(self, assembly_id: str) -> Mapping[str, str]:
+    asym_id_by_assembly_chain_id = {}
+    for transform in self.get_transforms(assembly_id):
+      for asym_id, assembly_chain_id in transform.chain_id_rename_map.items():
+        asym_id_by_assembly_chain_id[assembly_chain_id] = asym_id
+    return asym_id_by_assembly_chain_id
+
+  def assembly_chain_ids_by_asym_id(
+      self, assembly_id: str
+  ) -> Mapping[str, set[str]]:
+    assembly_chain_ids_by_asym_id = {}
+    for transform in self.get_transforms(assembly_id):
+      for asym_id, assembly_chain_id in transform.chain_id_rename_map.items():
+        assembly_chain_ids_by_asym_id.setdefault(asym_id, set()).add(
+            assembly_chain_id
+        )
+    return assembly_chain_ids_by_asym_id
+
+  def get_default_assembly_id(self) -> str:
+    """Gets a default assembly ID."""
+    # The first assembly is usually (though not always) the best choice.
+    # If we find a better heuristic for picking bioassemblies then this
+    # method should be updated.
+    return min(self._assembly_ids)
+
+  def get_assembly_info(self, assembly_id: str) -> Mapping[str, str]:
+    return {
+        k.replace('_pdbx_struct_assembly.', ''): v
+        for k, v in self._pdbx_struct_assembly[assembly_id].items()
+    }
+
+  def get_transforms(self, assembly_id: str) -> Sequence[Transform]:
+    """Returns the transforms required to generate the given assembly."""
+    partial_transforms = []
+    all_chain_ids = set()
+    for row in self._pdbx_struct_assembly_gen[assembly_id]:
+      oper_expression = row['_pdbx_struct_assembly_gen.oper_expression']
+      parsed_oper_id_seqs = mmcif.parse_oper_expr(oper_expression)
+      label_asym_ids = row['_pdbx_struct_assembly_gen.asym_id_list'].split(',')
+      all_chain_ids |= set(label_asym_ids)
+      for parsed_oper_id_seq in parsed_oper_id_seqs:
+        partial_transforms.append((parsed_oper_id_seq, label_asym_ids))
+
+    # We start assigning new chain IDs by finding the largest chain ID in
+    # the original structure that is involved in this bioassembly, and then
+    # starting from the next one.
+    max_int_chain_id = max(mmcif.str_id_to_int_id(c) for c in all_chain_ids)
+    next_int_chain_id = max_int_chain_id + 1
+
+    transforms = []
+    has_been_renamed = set()
+    for parsed_oper_id_seq, label_asym_ids in partial_transforms:
+      chain_id_rename_map = {}
+      for label_asym_id in label_asym_ids:
+        if label_asym_id not in has_been_renamed:
+          # The first time we see a label_asym_id we don't need to rename it.
+          # This isn't strictly necessary since we don't provide any
+          # guarantees about chain naming after bioassembly extraction but
+          # can make it a bit easier to inspect and compare structures
+          # pre and post bioassembly extraction.
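+          # E.g. applying two operations to chains ['A', 'B'] yields one
+          # transform that keeps 'A' and 'B', and a second transform that
+          # renames them to fresh IDs (say 'C' and 'D', if those are the next
+          # free integer-encoded chain IDs).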
+          chain_id_rename_map[label_asym_id] = label_asym_id
+          has_been_renamed.add(label_asym_id)
+        else:
+          chain_id_rename_map[label_asym_id] = mmcif.int_id_to_str_id(
+              next_int_chain_id
+          )
+          next_int_chain_id += 1
+      transforms.append(
+          Transform(
+              operations=[
+                  self._operations[oper_id] for oper_id in parsed_oper_id_seq
+              ],
+              chain_ids=label_asym_ids,
+              chain_id_rename_map=chain_id_rename_map,
+          )
+      )
+    return transforms
+
+  def to_mmcif_dict(self) -> Mapping[str, Sequence[str]]:
+    """Returns the bioassembly data as a dict suitable for `mmcif.Mmcif`."""
+    mmcif_dict = {}
+    for assembly_id in self._assembly_ids:
+      for column, val in self._pdbx_struct_assembly[assembly_id].items():
+        mmcif_dict.setdefault(column, []).append(val)
+      for row in self._pdbx_struct_assembly_gen[assembly_id]:
+        for column, val in row.items():
+          mmcif_dict.setdefault(column, []).append(val)
+    for oper_id in self._oper_ids:
+      for column, val in self._pdbx_struct_oper_list[oper_id].items():
+        mmcif_dict.setdefault(column, []).append(val)
+    return mmcif_dict
+
+  def rename_label_asym_ids(
+      self,
+      mapping: Mapping[str, str],
+      present_chains: set[str],
+  ) -> Self:
+    """Returns a new BioassemblyData with renamed label_asym_ids.
+
+    Args:
+      mapping: A mapping from original label_asym_ids to their new values. Any
+        label_asym_ids in this BioassemblyData that are not in this mapping
+        will remain unchanged.
+      present_chains: A set of label_asym_ids that are actually present in the
+        atom site list. All label_asym_ids that are in the BioassemblyData but
+        not in present_chains won't be included in the output BioassemblyData.
+
+    Returns:
+      A new BioassemblyData with renamed label_asym_ids.
+
+    Raises:
+      ValueError: If any two previously distinct chains do not have unique
+        names anymore after the rename.
+    """
+    new_pdbx_struct_assembly_gen = copy.deepcopy(self._pdbx_struct_assembly_gen)
+    for rows in new_pdbx_struct_assembly_gen.values():
+      for row in rows:
+        old_asym_ids = row['_pdbx_struct_assembly_gen.asym_id_list'].split(',')
+        new_asym_ids = [
+            mapping.get(label_asym_id, label_asym_id)
+            for label_asym_id in old_asym_ids
+            if label_asym_id in present_chains
+        ]
+        if len(set(old_asym_ids) & present_chains) != len(set(new_asym_ids)):
+          raise ValueError(
+              'Cannot rename chains, the new names are not unique: '
+              f'{sorted(new_asym_ids)}.'
+          )
+        row['_pdbx_struct_assembly_gen.asym_id_list'] = ','.join(
+            new_asym_ids)  # pytype: disable=unsupported-operands
+
+    return BioassemblyData(
+        pdbx_struct_assembly=copy.deepcopy(self._pdbx_struct_assembly),
+        pdbx_struct_assembly_gen=new_pdbx_struct_assembly_gen,
+        pdbx_struct_oper_list=copy.deepcopy(self._pdbx_struct_oper_list),
+        assembly_ids=copy.deepcopy(self._assembly_ids),
+        oper_ids=copy.deepcopy(self._oper_ids),
+    )
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bonds.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bonds.py
new file mode 100644
index 000000000..ce21863b1
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/bonds.py
@@ -0,0 +1,237 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0.
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Bond representation for structure module.""" + +import collections +from collections.abc import Mapping, Sequence +import dataclasses +import typing +from typing_extensions import Self + +from alphafold3.structure import table +import numpy as np + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class Bonds(table.Table): + """Table of atomic bonds.""" + + # mmCIF column: _struct_conn.conn_type_id + # mmCIF desc: This data item is a pointer to _struct_conn_type.id in the + # STRUCT_CONN_TYPE category. + # E.g.: "covale", "disulf", "hydrog", "metalc". + type: np.ndarray + + # mmCIF column: _struct_conn.pdbx_role + # mmCIF desc: The chemical or structural role of the interaction. + # E.g.: "N-Glycosylation", "O-Glycosylation". + role: np.ndarray + + # mmCIF columns: _struct_conn.ptnr1_* + from_atom_key: np.ndarray + + # mmCIF columns: _struct_conn.ptnr2_* + dest_atom_key: np.ndarray + + @classmethod + def make_empty(cls) -> Self: + return cls( + key=np.empty((0,), dtype=np.int64), + from_atom_key=np.empty((0,), dtype=np.int64), + dest_atom_key=np.empty((0,), dtype=np.int64), + type=np.empty((0,), dtype=object), + role=np.empty((0,), dtype=object), + ) + + def get_atom_indices( + self, + atom_key: np.ndarray, + ) -> tuple[np.ndarray, np.ndarray]: + """Returns the indices of the from/dest atoms in the atom_key array.""" + from_atom_missing = ~np.isin(self.from_atom_key, atom_key) + dest_atom_missing = ~np.isin(self.dest_atom_key, atom_key) + if np.any(from_atom_missing): + raise ValueError( + f'No atoms for from_atom_key {self.from_atom_key[from_atom_missing]}' + ) + if np.any(dest_atom_missing): + raise ValueError( + f'No atoms for dest_atom_key {self.dest_atom_key[dest_atom_missing]}' + ) + sort_indices = np.argsort(atom_key) + from_indices_sorted = np.searchsorted( + atom_key, self.from_atom_key, sorter=sort_indices + ) + dest_indices_sorted = np.searchsorted( + atom_key, self.dest_atom_key, sorter=sort_indices + ) + from_indices = sort_indices[from_indices_sorted] + dest_indices = sort_indices[dest_indices_sorted] + return from_indices, dest_indices + + def restrict_to_atoms(self, atom_key: np.ndarray) -> Self: + if not self.size: # Early-out for empty table. + return self + from_atom_mask = np.isin(self.from_atom_key, atom_key) + dest_atom_mask = np.isin(self.dest_atom_key, atom_key) + mask = np.logical_and(from_atom_mask, dest_atom_mask) + return typing.cast(Bonds, self.filter(mask=mask)) + + def to_mmcif_dict_from_atom_arrays( + self, + atom_key: np.ndarray, + chain_id: np.ndarray, + res_id: np.ndarray, + res_name: np.ndarray, + atom_name: np.ndarray, + auth_asym_id: np.ndarray, + auth_seq_id: np.ndarray, + insertion_code: np.ndarray, + ) -> Mapping[str, Sequence[str] | np.ndarray]: + """Returns a dict suitable for building a CifDict, representing bonds. + + Args: + atom_key: A (num_atom,) integer array of atom_keys. + chain_id: A (num_atom,) array of label_asym_id strings. + res_id: A (num_atom,) array of label_seq_id strings. + res_name: A (num_atom,) array of label_comp_id strings. 
+      atom_name: A (num_atom,) array of label_atom_id strings.
+      auth_asym_id: A (num_atom,) array of auth_asym_id strings.
+      auth_seq_id: A (num_atom,) array of auth_seq_id strings.
+      insertion_code: A (num_atom,) array of insertion code strings.
+    """
+    mmcif_dict = collections.defaultdict(list)
+    ptnr1_indices, ptnr2_indices = self.get_atom_indices(atom_key)
+
+    mmcif_dict['_struct_conn.ptnr1_label_asym_id'] = chain_id[ptnr1_indices]
+    mmcif_dict['_struct_conn.ptnr2_label_asym_id'] = chain_id[ptnr2_indices]
+    mmcif_dict['_struct_conn.ptnr1_label_comp_id'] = res_name[ptnr1_indices]
+    mmcif_dict['_struct_conn.ptnr2_label_comp_id'] = res_name[ptnr2_indices]
+    mmcif_dict['_struct_conn.ptnr1_label_seq_id'] = res_id[ptnr1_indices]
+    mmcif_dict['_struct_conn.ptnr2_label_seq_id'] = res_id[ptnr2_indices]
+    mmcif_dict['_struct_conn.ptnr1_label_atom_id'] = atom_name[ptnr1_indices]
+    mmcif_dict['_struct_conn.ptnr2_label_atom_id'] = atom_name[ptnr2_indices]
+
+    mmcif_dict['_struct_conn.ptnr1_auth_asym_id'] = auth_asym_id[ptnr1_indices]
+    mmcif_dict['_struct_conn.ptnr2_auth_asym_id'] = auth_asym_id[ptnr2_indices]
+    mmcif_dict['_struct_conn.ptnr1_auth_seq_id'] = auth_seq_id[ptnr1_indices]
+    mmcif_dict['_struct_conn.ptnr2_auth_seq_id'] = auth_seq_id[ptnr2_indices]
+    mmcif_dict['_struct_conn.pdbx_ptnr1_PDB_ins_code'] = insertion_code[
+        ptnr1_indices
+    ]
+    mmcif_dict['_struct_conn.pdbx_ptnr2_PDB_ins_code'] = insertion_code[
+        ptnr2_indices
+    ]
+
+    label_alt_id = ['?'] * self.size
+    mmcif_dict['_struct_conn.pdbx_ptnr1_label_alt_id'] = label_alt_id
+    mmcif_dict['_struct_conn.pdbx_ptnr2_label_alt_id'] = label_alt_id
+
+    # We need to set this to make visualisation work in NGL/PyMOL.
+    mmcif_dict['_struct_conn.pdbx_value_order'] = ['?'] * self.size
+
+    # We use a symmetry of 1_555 which is the no-op transformation. Other
+    # values are used when bonds involve atoms that only exist after expanding
+    # the bioassembly, but we don't support this kind of bond at the moment.
+    symmetry = ['1_555'] * self.size
+    mmcif_dict['_struct_conn.ptnr1_symmetry'] = symmetry
+    mmcif_dict['_struct_conn.ptnr2_symmetry'] = symmetry
+    bond_type_counter = collections.Counter()
+    for bond_row in self.iterrows():
+      bond_type = bond_row['type']
+      bond_type_counter[bond_type] += 1
+      mmcif_dict['_struct_conn.id'].append(
+          f'{bond_type}{bond_type_counter[bond_type]}'
+      )
+      mmcif_dict['_struct_conn.pdbx_role'].append(bond_row['role'])
+      mmcif_dict['_struct_conn.conn_type_id'].append(bond_type)
+
+    bond_types = np.unique(self.type)
+    mmcif_dict['_struct_conn_type.id'] = bond_types
+    unknown = ['?'] * len(bond_types)
+    mmcif_dict['_struct_conn_type.criteria'] = unknown
+    mmcif_dict['_struct_conn_type.reference'] = unknown
+
+    return dict(mmcif_dict)
+
+
+def concat_with_atom_keys(
+    bonds_tables: Sequence[Bonds | None],
+    atom_key_arrays: Sequence[np.ndarray],
+) -> tuple[Bonds | None, np.ndarray]:
+  """Concatenates bonds tables and atom keys simultaneously.
+
+  Args:
+    bonds_tables: A sequence of `Bonds` instances to concatenate. If any are
+      None then these are skipped.
+    atom_key_arrays: A sequence of integer `atom_key` arrays, where the n-th
+      bonds_table refers to the atoms in the n-th atom_key array. These must
+      all be non-None.
+
+  Returns:
+    A pair of (bonds, atom_key) where atom_key is a unique atom_key array with
+    length equal to the sum of the input atom array sizes, and the bonds table
+    contains all the bonds from the individual bonds table inputs.
+  """
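+  # Keys are made globally unique by offsetting each input array past the
+  # running maximum. E.g. atom keys [0, 1] and [0, 1, 2] concatenate to
+  # [0, 1, 2, 3, 4], with the second table's from/dest atom keys shifted by
+  # the same offset (+2).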
+ """ + if not bonds_tables or not atom_key_arrays: + if bonds_tables or atom_key_arrays: + raise ValueError( + 'bonds_tables and atom_keys must have same length but got' + f' {len(bonds_tables)=} and {len(atom_key_arrays)=}' + ) + return None, np.array([], dtype=np.int64) + max_key = -1 + atom_keys_to_concat = [] + types_to_concat = [] + roles_to_concat = [] + from_atom_keys_to_concat = [] + dest_atom_keys_to_concat = [] + for bonds, atom_key in zip(bonds_tables, atom_key_arrays, strict=True): + if not atom_key.size: + assert bonds is None or bonds.size == 0 + continue + # Should always be non-negative! + assert np.min(atom_key, initial=0) >= 0 + offset = max_key + 1 + offset_atom_key = atom_key + offset + atom_keys_to_concat.append(offset_atom_key) + max_key = np.max(offset_atom_key) + if bonds is not None: + types_to_concat.append(bonds.type) + roles_to_concat.append(bonds.role) + from_atom_keys_to_concat.append(bonds.from_atom_key + offset) + dest_atom_keys_to_concat.append(bonds.dest_atom_key + offset) + + if atom_keys_to_concat: + concatted_atom_keys = np.concatenate(atom_keys_to_concat, axis=0) + else: + concatted_atom_keys = np.array([], dtype=np.int64) + + if types_to_concat: + assert ( + len(types_to_concat) + == len(roles_to_concat) + == len(from_atom_keys_to_concat) + == len(dest_atom_keys_to_concat) + ) + num_bonds = sum(b.size for b in bonds_tables if b is not None) + concatted_bonds = Bonds( + key=np.arange(num_bonds, dtype=np.int64), + type=np.concatenate(types_to_concat, axis=0), + role=np.concatenate(roles_to_concat, axis=0), + from_atom_key=np.concatenate(from_atom_keys_to_concat, axis=0), + dest_atom_key=np.concatenate(dest_atom_keys_to_concat, axis=0), + ) + else: + concatted_bonds = None + + return concatted_bonds, concatted_atom_keys diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/chemical_components.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/chemical_components.py new file mode 100644 index 000000000..e5a2af3e1 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/chemical_components.py @@ -0,0 +1,286 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Utilities for manipulating chemical components data.""" + +from collections.abc import Iterable, Mapping, Sequence +import dataclasses +import functools +from typing_extensions import Self + +from alphafold3.constants import chemical_components +from alphafold3.constants import residue_names +from alphafold3.structure import mmcif +import rdkit.Chem as rd_chem + + +@dataclasses.dataclass(frozen=True) +class ChemCompEntry: + """Items of _chem_comp category. + + For the full list of items and their semantics see + http://mmcif.rcsb.org/dictionaries/mmcif_pdbx_v50.dic/Categories/chem_comp.html + """ + + type: str + name: str = '?' + pdbx_synonyms: str = '?' + formula: str = '?' + formula_weight: str = '?' + mon_nstd_flag: str = '?' 
+  pdbx_smiles: str | None = None
+
+  def __post_init__(self):
+    for field, value in vars(self).items():
+      if not value and value is not None:
+        raise ValueError(f"{field} value can't be an empty string.")
+
+  def extends(self, other: Self) -> bool:
+    """Checks whether this ChemCompEntry extends another one."""
+    for field, value in vars(self).items():
+      other_value = getattr(other, field)
+      if _value_is_missing(other_value):
+        continue
+      if value != other_value:
+        return False
+    return True
+
+  @property
+  def rdkit_mol(self) -> rd_chem.Mol:
+    """Returns an RDKit Mol, created via RDKit from the entry's SMILES string."""
+    if not self.pdbx_smiles:
+      raise ValueError('Cannot construct RDKit Mol with empty pdbx_smiles')
+    return rd_chem.MolFromSmiles(self.pdbx_smiles)
+
+
+_REQUIRED_MMCIF_COLUMNS = ('_chem_comp.id', '_chem_comp.type')
+
+
+class MissingChemicalComponentsDataError(Exception):
+  """Raised when chemical components data is missing from an mmCIF."""
+
+
+@dataclasses.dataclass(frozen=True)
+class ChemicalComponentsData:
+  """Extra information for chemical components occurring in mmCIF.
+
+  Fields:
+    chem_comp: A mapping from _chem_comp.id to associated items in the
+      chem_comp category.
+  """
+
+  chem_comp: Mapping[str, ChemCompEntry]
+
+  @classmethod
+  def from_mmcif(
+      cls, cif: mmcif.Mmcif, fix_mse: bool, fix_unknown_dna: bool
+  ) -> Self:
+    """Constructs an instance of ChemicalComponentsData from an Mmcif object."""
+    for col in _REQUIRED_MMCIF_COLUMNS:
+      if col not in cif:
+        raise MissingChemicalComponentsDataError(col)
+
+    id_ = cif['_chem_comp.id']  # Guaranteed to be present.
+    type_ = cif['_chem_comp.type']  # Guaranteed to be present.
+    name = cif.get('_chem_comp.name', ['?'] * len(id_))
+    synonyms = cif.get('_chem_comp.pdbx_synonyms', ['?'] * len(id_))
+    formula = cif.get('_chem_comp.formula', ['?'] * len(id_))
+    weight = cif.get('_chem_comp.formula_weight', ['?'] * len(id_))
+    mon_nstd_flag = cif.get('_chem_comp.mon_nstd_flag', ['?'] * len(id_))
+    smiles = cif.get('_chem_comp.pdbx_smiles', ['?'] * len(id_))
+    smiles = [None if s == '?' else s for s in smiles]
+
+    chem_comp = {
+        component_name: ChemCompEntry(*entry)
+        for component_name, *entry in zip(
+            id_, type_, name, synonyms, formula, weight, mon_nstd_flag, smiles
+        )
+    }
+
+    if fix_mse and 'MSE' in chem_comp:
+      if 'MET' not in chem_comp:
+        chem_comp['MET'] = ChemCompEntry(
+            type='L-PEPTIDE LINKING',
+            name='METHIONINE',
+            pdbx_synonyms='?',
+            formula='C5 H11 N O2 S',
+            formula_weight='149.211',
+            mon_nstd_flag='y',
+            pdbx_smiles=None,
+        )
+
+    if fix_unknown_dna and 'N' in chem_comp:
+      # Do not delete 'N' as it may be needed for RNA in the system.
+      if 'DN' not in chem_comp:
+        chem_comp['DN'] = ChemCompEntry(
+            type='DNA LINKING',
+            name="UNKNOWN 2'-DEOXYNUCLEOTIDE",
+            pdbx_synonyms='?',
+            formula='C5 H11 O6 P',
+            formula_weight='198.111',
+            mon_nstd_flag='y',
+            pdbx_smiles=None,
+        )
+
+    return ChemicalComponentsData(chem_comp)
+
+  def to_mmcif_dict(self) -> Mapping[str, Sequence[str]]:
+    """Returns chemical components data as a dict suitable for `mmcif.Mmcif`."""
+    mmcif_dict = {}
+
+    mmcif_fields = set()
+    for entry in self.chem_comp.values():
+      for field, value in vars(entry).items():
+        if value:
+          mmcif_fields.add(field)
+    chem_comp_ids = []
+    for component_id in sorted(self.chem_comp):
+      entry = self.chem_comp[component_id]
+      chem_comp_ids.append(component_id)
+      for field in mmcif_fields:
+        mmcif_dict.setdefault(f'_chem_comp.{field}', []).append(
+            getattr(entry, field) or '?'
+        )
+    if chem_comp_ids:
+      mmcif_dict['_chem_comp.id'] = chem_comp_ids
+    return mmcif_dict
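+
+
+# Illustrative semantics of ChemCompEntry.extends (a sketch): an entry that
+# only pins the type, e.g. ChemCompEntry(type='L-PEPTIDE LINKING'), is
+# extended by any entry with the same type and further fields filled in,
+# because '?' and '.' count as missing values (see _value_is_missing below).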
+ ) + if chem_comp_ids: + mmcif_dict['_chem_comp.id'] = chem_comp_ids + return mmcif_dict + + +def _value_is_missing(value: str) -> bool: + return not value or value in ('.', '?') + + +def get_data_for_ccd_components( + ccd: chemical_components.Ccd, + chemical_component_ids: Iterable[str], + populate_pdbx_smiles: bool = False, +) -> ChemicalComponentsData: + """Returns `ChemicalComponentsData` for chemical components known by PDB.""" + chem_comp = {} + for chemical_component_id in chemical_component_ids: + chem_data = chemical_components.component_name_to_info( + ccd=ccd, res_name=chemical_component_id + ) + if not chem_data: + continue + chem_comp[chemical_component_id] = ChemCompEntry( + type=chem_data.type, + name=chem_data.name, + pdbx_synonyms=chem_data.pdbx_synonyms, + formula=chem_data.formula, + formula_weight=chem_data.formula_weight, + mon_nstd_flag=chem_data.mon_nstd_flag, + pdbx_smiles=( + chem_data.pdbx_smiles or None if populate_pdbx_smiles else None + ), + ) + return ChemicalComponentsData(chem_comp=chem_comp) + + +def populate_missing_ccd_data( + ccd: chemical_components.Ccd, + chemical_components_data: ChemicalComponentsData, + chemical_component_ids: Iterable[str] | None = None, + populate_pdbx_smiles: bool = False, +) -> ChemicalComponentsData: + """Populates missing data for the chemical components from CCD. + + Args: + ccd: The chemical components database. + chemical_components_data: ChemicalComponentsData to populate missing values + for. This function doesn't modify the object, extended version is provided + as a return value. + chemical_component_ids: chemical components to populate missing values for. + If not specified, the function will consider all chemical components which + are already present in `chemical_components_data`. + populate_pdbx_smiles: whether to populate `pdbx_smiles` field using SMILES + descriptors from _pdbx_chem_comp_descriptor CCD table. If CCD provides + multiple SMILES strings, any of them could be used. + + Returns: + New instance of ChemicalComponentsData without missing values for CCD + entries. + """ + if chemical_component_ids is None: + chemical_component_ids = chemical_components_data.chem_comp.keys() + + ccd_data = get_data_for_ccd_components( + ccd, chemical_component_ids, populate_pdbx_smiles + ) + chem_comp = dict(chemical_components_data.chem_comp) + for component_id, ccd_entry in ccd_data.chem_comp.items(): + if component_id not in chem_comp: + chem_comp[component_id] = ccd_entry + else: + already_specified_fields = { + field: value + for field, value in vars(chem_comp[component_id]).items() + if not _value_is_missing(value) + } + chem_comp[component_id] = ChemCompEntry( + **{**vars(ccd_entry), **already_specified_fields} + ) + return ChemicalComponentsData(chem_comp=chem_comp) + + +def get_all_atoms_in_entry( + ccd: chemical_components.Ccd, res_name: str +) -> Mapping[str, Sequence[str]]: + """Get all possible atoms and bonds for this residue in a standard order. + + Args: + ccd: The chemical components dictionary. + res_name: Full CCD name. + + Returns: + A dictionary table of the atoms and bonds for this residue in this residue + type. + """ + # The CCD version of 'UNK' is weird. It has a CB and a CG atom. We just want + # the minimal amino-acid here which is GLY. 
+  if res_name == 'UNK':
+    res_name = 'GLY'
+  ccd_data = ccd.get(res_name)
+  if not ccd_data:
+    raise ValueError(f'Unknown residue type {res_name}')
+
+  keys = (
+      '_chem_comp_atom.atom_id',
+      '_chem_comp_atom.type_symbol',
+      '_chem_comp_bond.atom_id_1',
+      '_chem_comp_bond.atom_id_2',
+  )
+
+  # Add terminal hydrogens for protonation of the N-terminus.
+  if res_name == 'PRO':
+    res_atoms = {key: [*ccd_data.get(key, [])] for key in keys}
+    res_atoms['_chem_comp_atom.atom_id'].extend(['H2', 'H3'])
+    res_atoms['_chem_comp_atom.type_symbol'].extend(['H', 'H'])
+    res_atoms['_chem_comp_bond.atom_id_1'].extend(['N', 'N'])
+    res_atoms['_chem_comp_bond.atom_id_2'].extend(['H2', 'H3'])
+  elif res_name in residue_names.PROTEIN_TYPES_WITH_UNKNOWN:
+    res_atoms = {key: [*ccd_data.get(key, [])] for key in keys}
+    res_atoms['_chem_comp_atom.atom_id'].append('H3')
+    res_atoms['_chem_comp_atom.type_symbol'].append('H')
+    res_atoms['_chem_comp_bond.atom_id_1'].append('N')
+    res_atoms['_chem_comp_bond.atom_id_2'].append('H3')
+  else:
+    res_atoms = {key: ccd_data.get(key, []) for key in keys}
+
+  return res_atoms
+
+
+@functools.lru_cache(maxsize=128)
+def get_res_atom_names(ccd: chemical_components.Ccd, res_name: str) -> set[str]:
+  """Gets the names of the atoms in a given CCD residue."""
+  atoms = get_all_atoms_in_entry(ccd, res_name)['_chem_comp_atom.atom_id']
+  return set(atoms)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation.pyi
new file mode 100644
index 000000000..8f4a8b375
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation.pyi
@@ -0,0 +1,13 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+from collections.abc import Sequence
+
+def indices_grouped_by_value(values: Sequence[int]) -> dict[int, list[int]]: ...
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.cc
new file mode 100644
index 000000000..5ac46d62c
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.cc
@@ -0,0 +1,54 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/types/span.h"
+#include "pybind11/cast.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11_abseil/absl_casters.h"
+
+namespace {
+
+namespace py = pybind11;
+
+absl::flat_hash_map<int64_t, std::vector<int64_t>> IndicesGroupedByValue(
+    absl::Span<const int64_t> values) {
+  absl::flat_hash_map<int64_t, std::vector<int64_t>> group_indices;
+  for (int64_t i = 0, e = values.size(); i < e; ++i) {
+    group_indices[values[i]].push_back(i);
+  }
+  return group_indices;
+}
+
+constexpr char kIndicesGroupedByValue[] = R"(
+Returns a map from value to a list of indices this value occupies.
+
+E.g. indices_grouped_by_value([1, 1, 2, 3, 3, 1, 1]) returns:
+{1: [0, 1, 5, 6], 2: [2], 3: [3, 4]}
+
+Args:
+  values: a list of values to group.
+)";
+
+}  // namespace
+
+namespace alphafold3 {
+
+void RegisterModuleAggregation(py::module m) {
+  m.def("indices_grouped_by_value", &IndicesGroupedByValue, py::arg("values"),
+        py::doc(kIndicesGroupedByValue + 1),
+        py::call_guard<py::gil_scoped_release>());
+}
+
+}  // namespace alphafold3
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.h
new file mode 100644
index 000000000..9547b9448
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/aggregation_pybind.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_AGGREGATION_PYBIND_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_AGGREGATION_PYBIND_H_
+
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+
+void RegisterModuleAggregation(pybind11::module m);
+
+}
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_AGGREGATION_PYBIND_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership.pyi
new file mode 100644
index 000000000..305f36600
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership.pyi
@@ -0,0 +1,18 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+import numpy
+
+
+def isin(
+    array: numpy.ndarray[numpy.int64],
+    test_elements: set[int],
+    invert: bool = ...,
+) -> numpy.ndarray[bool]: ...
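Editor's note: for orientation, a minimal usage sketch of the two helpers bound above. The `alphafold3.structure.cpp` import path and module names are assumptions inferred from this patch's layout and .pyi stubs, not confirmed by it:

import numpy as np
from alphafold3.structure.cpp import aggregation, membership  # assumed path

# Mirrors the docstring example for indices_grouped_by_value.
groups = aggregation.indices_grouped_by_value([1, 1, 2, 3, 3, 1, 1])
assert groups == {1: [0, 1, 5, 6], 2: [2], 3: [3, 4]}

# O(n) set-membership test over an int64 array, with optional inversion.
values = np.array([5, 2, 9, 2], dtype=np.int64)
assert membership.isin(values, {2}).tolist() == [False, True, False, True]
assert membership.isin(values, {2}, invert=True).tolist() == [True, False, True, False]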
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.cc
new file mode 100644
index 000000000..2b3faf8a2
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.cc
@@ -0,0 +1,82 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <vector>
+
+#include "absl/container/flat_hash_set.h"
+#include "pybind11/cast.h"
+#include "pybind11/numpy.h"
+#include "pybind11/pybind11.h"
+#include "pybind11_abseil/absl_casters.h"
+
+namespace {
+
+namespace py = pybind11;
+
+py::array_t<bool> IsIn(const py::array_t<int64_t>& array,
+                       const absl::flat_hash_set<int64_t>& test_elements,
+                       bool invert) {
+  const size_t num_elements = array.size();
+
+  py::array_t<bool> output(num_elements);
+  std::fill(output.mutable_data(), output.mutable_data() + output.size(),
+            invert);
+
+  // Shortcut: if test_elements is empty, nothing can match, so the output
+  // trivially stays filled with `invert`.
+  if (test_elements.empty()) {
+    return output;
+  }
+
+  for (size_t i = 0; i < num_elements; ++i) {
+    if (test_elements.contains(array.data()[i])) {
+      output.mutable_data()[i] = !invert;
+    }
+  }
+  if (array.ndim() > 1) {
+    auto shape =
+        std::vector<py::ssize_t>(array.shape(), array.shape() + array.ndim());
+    return output.reshape(shape);
+  }
+  return output;
+}
+
+constexpr char kIsInDoc[] = R"(
+Computes whether each element is in test_elements.
+
+Same use as np.isin, but much faster. If len(array) = n, len(test_elements) = m:
+* This function has complexity O(n).
+* np.isin with kind='sort' has complexity O(m*log(m) + n * log(m)).
+
+Args:
+  array: Input NumPy array with dtype=np.int64.
+  test_elements: The values against which to test each value of array.
+  invert: If True, the values in the returned array are inverted, as if
+    calculating `element not in test_elements`. Default is False.
+    `isin(a, b, invert=True)` is equivalent to but faster than `~isin(a, b)`.
+
+Returns:
+  A boolean array of the same shape as the input array. Each value `val` is:
+  * `val in test_elements` if `invert=False`,
+  * `val not in test_elements` if `invert=True`.
+)";
+
+}  // namespace
+
+namespace alphafold3 {
+
+void RegisterModuleMembership(pybind11::module m) {
+  m.def("isin", &IsIn, py::arg("array"), py::arg("test_elements"),
+        py::kw_only(), py::arg("invert") = false, py::doc(kIsInDoc + 1));
+}
+
+}  // namespace alphafold3
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.h
new file mode 100644
index 000000000..d224fb1f6
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/membership_pybind.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MEMBERSHIP_PYBIND_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MEMBERSHIP_PYBIND_H_
+
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+
+void RegisterModuleMembership(pybind11::module m);
+
+}
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MEMBERSHIP_PYBIND_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc
new file mode 100644
index 000000000..cea9a1b1c
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.cc
@@ -0,0 +1,249 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include "alphafold3/structure/cpp/mmcif_altlocs.h"
+
+#include <algorithm>
+#include <cstddef>
+#include <functional>
+#include <iterator>
+#include <limits>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/log/log.h"
+#include "absl/strings/numbers.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "alphafold3/structure/cpp/mmcif_layout.h"
+
+namespace alphafold3 {
+namespace {
+
+float OccupancyToFloat(absl::string_view occupancy) {
+  float result = 0.0f;
+  LOG_IF(ERROR, !absl::SimpleAtof(occupancy, &result))
+      << "Invalid Occupancy: " << occupancy;
+  return result;
+}
+
+// Deuterium is the same atom as Hydrogen so keep equivalent for grouping.
+bool AtomEquiv(absl::string_view lhs, absl::string_view rhs) {
+  if (lhs == rhs) return true;
+  if (lhs.empty() != rhs.empty()) return false;
+  // Both lhs and rhs are guaranteed to be non-empty after this.
+  char first_lhs = lhs.front();
+  char first_rhs = rhs.front();
+  if ((first_lhs == 'H' && first_rhs == 'D') ||
+      (first_lhs == 'D' && first_rhs == 'H')) {
+    lhs.remove_prefix(1);
+    rhs.remove_prefix(1);
+    return lhs == rhs;
+  }
+  return false;
+}
+
+// Calls group_callback with the start index and count of each group of
+// equivalent values in `values`, starting at `start` and covering `count`
+// values.
+// Example:
+//   GroupBy({"B", "B", "B", "C", "C"}, 0, 5, [](size_t start, size_t count) {
+//     absl::PrintF("start=%d, count=%d\n", start, count);
+//   });
+// Would print:
+//   start=0, count=3
+//   start=3, count=2
+template <typename T, typename GroupCallback,
+          typename IsEqual = std::equal_to<T>>
+void GroupBy(absl::Span<const T> values, std::size_t start, std::size_t count,
+             GroupCallback&& group_callback,
+             IsEqual&& is_equal = std::equal_to<T>{}) {
+  std::size_t span_start = start;
+  if (count > 0) {
+    for (std::size_t i = start + 1; i < start + count; ++i) {
+      if (!is_equal(values[i], values[span_start])) {
+        group_callback(span_start, i - span_start);
+        span_start = i;
+      }
+    }
+    group_callback(span_start, start + count - span_start);
+  }
+}
+
+void ProcessAltLocGroupsWhole(std::size_t alt_loc_start,
+                              std::size_t alt_loc_count,
+                              absl::Span<const absl::string_view> comp_ids,
+                              absl::Span<const absl::string_view> atom_ids,
+                              absl::Span<const absl::string_view> alt_ids,
+                              absl::Span<const absl::string_view> occupancies,
+                              std::vector<std::size_t>& in_out_keep_indices) {
+  std::pair<std::size_t, std::size_t> best_split = {alt_loc_start,
+                                                    alt_loc_count};
+  std::vector<char> alt_loc_groups;
+  float best_occupancy = -std::numeric_limits<float>::infinity();
+  char best_group = alt_ids[alt_loc_start].front();
+  std::vector<std::pair<std::size_t, float>> occupancy_stats;
+
+  // Group by residue type.
+  GroupBy(comp_ids, alt_loc_start, alt_loc_count,
+          [&](std::size_t start, std::size_t count) {
+            // This callback selects the best residue group and the best
+            // alt-loc char within that group.
+            alt_loc_groups.clear();
+            occupancy_stats.clear();
+            // Calculate total occupancy for residue type.
+            for (std::size_t i = 0; i < count; ++i) {
+              char alt_loc_id = alt_ids[start + i].front();
+              float occupancy = OccupancyToFloat(occupancies[start + i]);
+              if (auto loc = absl::c_find(alt_loc_groups, alt_loc_id);
+                  loc == alt_loc_groups.end()) {
+                occupancy_stats.emplace_back(1, occupancy);
+                alt_loc_groups.push_back(alt_loc_id);
+              } else {
+                auto& stat =
+                    occupancy_stats[std::distance(alt_loc_groups.begin(), loc)];
+                ++stat.first;
+                stat.second += occupancy;
+              }
+            }
+            float total_occupancy = 0.0;
+            for (auto& stat : occupancy_stats) {
+              total_occupancy += stat.second / stat.first;
+            }
+            char group = *absl::c_min_element(alt_loc_groups);
+            // Compares occupancy of residue to best seen so far.
+            // Tie breaks alphabetic.
+            if (total_occupancy > best_occupancy ||
+                (total_occupancy == best_occupancy && group < best_group)) {
+              // Selects the best sub group.
+              best_group = alt_loc_groups.front();
+              float best_amount = occupancy_stats.front().second /
+                                  occupancy_stats.front().first;
+              for (std::size_t i = 1; i < occupancy_stats.size(); ++i) {
+                float amount =
+                    occupancy_stats[i].second / occupancy_stats[i].first;
+                char group = alt_loc_groups[i];
+                if (amount > best_amount ||
+                    (amount == best_amount && group < best_group)) {
+                  best_amount = amount;
+                  best_group = group;
+                }
+              }
+              best_occupancy = total_occupancy;
+              best_split = {start, count};
+            }
+          });
+
+  // Now that the best residue type has been selected and the best alt-loc
+  // within it has been selected, add the indices to keep to the keep list.
+  auto [split_start, split_count] = best_split;
+  GroupBy(
+      atom_ids, split_start, split_count,
+      [&in_out_keep_indices, &alt_ids, best_group](std::size_t start,
+                                                   std::size_t count) {
+        // This makes sure we select an atom for each atom id even if it does
+        // not have our selected alt-loc char.
+        std::size_t best_index = start;
+        for (std::size_t i = 1; i < count; ++i) {
+          if (alt_ids[start + i].front() == best_group) {
+            best_index = start + i;
+            break;
+          }
+        }
+        in_out_keep_indices.push_back(best_index);
+      },
+      AtomEquiv);
+}
+
+// Finds the alt-loc group with the highest score and pushes the indices onto
+// the back of in_out_keep_indices.
+void ProcessAltLocGroupPartial(
+    std::size_t alt_loc_start, std::size_t alt_loc_count,
+    absl::Span<const absl::string_view> atom_ids,
+    absl::Span<const absl::string_view> alt_ids,
+    absl::Span<const absl::string_view> occupancies,
+    std::vector<std::size_t>& in_out_keep_indices) {
+  GroupBy(
+      atom_ids, alt_loc_start, alt_loc_count,
+      [&](std::size_t start, std::size_t count) {
+        if (count == 1) {
+          in_out_keep_indices.push_back(start);
+        } else {
+          float best_occ = OccupancyToFloat(occupancies[start]);
+          std::size_t best_index = start;
+          char best_group = alt_ids[start].front();
+          for (std::size_t i = 0; i < count; ++i) {
+            float occ = OccupancyToFloat(occupancies[start + i]);
+            char group = alt_ids[start + i].front();
+            if (occ > best_occ || (occ == best_occ && group < best_group)) {
+              best_group = group;
+              best_index = start + i;
+              best_occ = occ;
+            }
+          }
+          in_out_keep_indices.push_back(best_index);
+        }
+      },
+      AtomEquiv);
+}
+
+}  // namespace
+
+// Resolves alt-locs returning the atom indices that will be left.
+std::vector<std::size_t> ResolveMmcifAltLocs(
+    const MmcifLayout& layout, absl::Span<const absl::string_view> comp_ids,
+    absl::Span<const absl::string_view> atom_ids,
+    absl::Span<const absl::string_view> alt_ids,
+    absl::Span<const absl::string_view> occupancies,
+    absl::Span<const std::size_t> chain_indices) {
+  std::vector<std::size_t> keep_indices;
+  keep_indices.reserve(layout.num_atoms());
+  std::size_t alt_loc_start = 0;
+  for (std::size_t chain_index : chain_indices) {
+    auto [residues_start, residues_end] = layout.residue_range(chain_index);
+    for (std::size_t residue = residues_start; residue < residues_end;
+         ++residue) {
+      std::size_t alt_loc_count = 0;
+      auto [atom_start, atom_end] = layout.atom_range(residue);
+      for (std::size_t i = atom_start; i < atom_end; ++i) {
+        char alt_loc_id = alt_ids[i].front();
+        if (alt_loc_id == '.' || alt_loc_id == '?') {
+          if (alt_loc_count > 0) {
+            ProcessAltLocGroupPartial(alt_loc_start, alt_loc_count, atom_ids,
+                                      alt_ids, occupancies, keep_indices);
+            alt_loc_count = 0;
+          }
+          keep_indices.push_back(i);
+        } else {
+          if (alt_loc_count == 0) {
+            alt_loc_start = i;
+          }
+          ++alt_loc_count;
+        }
+      }
+      if (alt_loc_count > 0) {
+        if (atom_end - atom_start == alt_loc_count) {
+          ProcessAltLocGroupsWhole(alt_loc_start, alt_loc_count, comp_ids,
+                                   atom_ids, alt_ids, occupancies,
+                                   keep_indices);
+        } else {
+          ProcessAltLocGroupPartial(alt_loc_start, alt_loc_count, atom_ids,
+                                    alt_ids, occupancies, keep_indices);
+        }
+      }
+    }
+  }
+
+  return keep_indices;
+}
+
+}  // namespace alphafold3
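Editor's note: the partial-residue rule implemented by ProcessAltLocGroupPartial can be paraphrased in a few lines of Python. This is an illustrative sketch of the selection rule only (highest occupancy wins, ties break to the alphabetically smallest alt-loc), not the binding's API:

def pick_altloc(rows):
    # rows: (atom_name, alt_id, occupancy) per _atom_site row of one residue.
    best = {}
    for i, (atom, alt, occ) in enumerate(rows):
        kept = best.get(atom)
        if kept is None or occ > kept[2] or (occ == kept[2] and alt < kept[1]):
            best[atom] = (i, alt, occ)
    return sorted(i for i, _, _ in best.values())

# CA: 'B' has higher occupancy; CB: tie, so 'A' wins alphabetically.
assert pick_altloc([('CA', 'A', 0.4), ('CA', 'B', 0.6),
                    ('CB', 'A', 0.5), ('CB', 'B', 0.5)]) == [1, 2]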
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.h
new file mode 100644
index 000000000..fab57817c
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_altlocs.h
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ALTLOCS_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ALTLOCS_H_
+
+#include <cstddef>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/types/span.h"
+#include "alphafold3/structure/cpp/mmcif_layout.h"
+
+namespace alphafold3 {
+
+// Returns the list of indices that should be kept after resolving alt-locs.
+// 1) Partial Residue. Each cycle of alt-locs is resolved separately with the
+//    highest-occupancy alt-loc. Tie-breaks are resolved alphabetically. See
+//    tests for examples.
+// 2) Whole Residue. These are resolved in two passes.
+//    a) The residue with the highest occupancy is chosen.
+//    b) The locations for a given residue are resolved.
+//    All tie-breaks are resolved alphabetically. See tests for examples.
+//
+// Preconditions: layout, comp_ids, alt_ids and occupancies are all from the
+// same mmCIF file, and chain_indices are monotonically increasing and less
+// than layout.num_chains().
+//
+// comp_ids from '_atom_site.label_comp_id'.
+// alt_ids from '_atom_site.label_alt_id'.
+// occupancies from '_atom_site.occupancy'.
+std::vector<std::size_t> ResolveMmcifAltLocs(
+    const MmcifLayout& layout, absl::Span<const absl::string_view> comp_ids,
+    absl::Span<const absl::string_view> atom_ids,
+    absl::Span<const absl::string_view> alt_ids,
+    absl::Span<const absl::string_view> occupancies,
+    absl::Span<const std::size_t> chain_indices);
+
+}  // namespace alphafold3
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ALTLOCS_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site.pyi
new file mode 100644
index 000000000..5f0ba34b0
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site.pyi
@@ -0,0 +1,23 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+from collections.abc import Callable
+from alphafold3.cpp import cif_dict
+
+
+def get_internal_to_author_chain_id_map(
+    mmcif: cif_dict.CifDict
+) -> dict[str, str]: ...
+
+
+def get_or_infer_type_symbol(
+    mmcif: cif_dict.CifDict,
+    atom_id_to_type_symbol: Callable[[str, str], str],
+) -> list[str]: ...
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.cc
new file mode 100644
index 000000000..6037fe08b
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.cc
@@ -0,0 +1,83 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include <cstddef>
+
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/check.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "alphafold3/parsers/cpp/cif_dict_lib.h"
+#include "pybind11/gil.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+#include "pybind11_abseil/absl_casters.h"
+
+namespace alphafold3 {
+namespace {
+namespace py = pybind11;
+
+// If present, returns the _atom_site.type_symbol. If not, infers it using
+// _atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom
+// name) and the CCD.
+py::list GetOrInferTypeSymbol(const CifDict& mmcif,
+                              const py::object& atom_id_to_type_symbol) {
+  const auto& type_symbol = mmcif["_atom_site.type_symbol"];
+  const int num_atom = mmcif["_atom_site.id"].size();
+  py::list patched_type_symbol(num_atom);
+  if (type_symbol.empty()) {
+    const auto& label_comp_id = mmcif["_atom_site.label_comp_id"];
+    const auto& label_atom_id = mmcif["_atom_site.label_atom_id"];
+    CHECK_EQ(label_comp_id.size(), num_atom);
+    CHECK_EQ(label_atom_id.size(), num_atom);
+    for (int i = 0; i < num_atom; i++) {
+      patched_type_symbol[i] =
+          atom_id_to_type_symbol(label_comp_id[i], label_atom_id[i]);
+    }
+  } else {
+    for (int i = 0; i < num_atom; i++) {
+      patched_type_symbol[i] = type_symbol[i];
+    }
+  }
+  return patched_type_symbol;
+}
+
+absl::flat_hash_map<absl::string_view, absl::string_view>
+GetInternalToAuthorChainIdMap(const CifDict& mmcif) {
+  const auto& label_asym_ids = mmcif["_atom_site.label_asym_id"];
+  const auto& auth_asym_ids = mmcif["_atom_site.auth_asym_id"];
+  CHECK_EQ(label_asym_ids.size(), auth_asym_ids.size());
+
+  absl::flat_hash_map<absl::string_view, absl::string_view> mapping;
+  for (size_t i = 0, num_rows = label_asym_ids.size(); i < num_rows; ++i) {
+    // Use only the first internal_chain_id occurrence to generate the mapping.
+    // It should not matter as there should not be a case where a single
+    // internal chain ID would map to more than one author chain ID (i.e. the
+    // mapping should be injective). Since we need this method to be fast, we
+    // choose not to check it.
+    mapping.emplace(label_asym_ids[i], auth_asym_ids[i]);
+  }
+  return mapping;
+}
+
+}  // namespace
+
+namespace py = pybind11;
+
+void RegisterModuleMmcifAtomSite(pybind11::module m) {
+  m.def("get_or_infer_type_symbol", &GetOrInferTypeSymbol, py::arg("mmcif"),
+        py::arg("atom_id_to_type_symbol"));
+
+  m.def("get_internal_to_author_chain_id_map", &GetInternalToAuthorChainIdMap,
+        py::arg("mmcif"), py::call_guard<py::gil_scoped_release>());
+}
+
+}  // namespace alphafold3
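Editor's note: a short Python sketch of how these two bindings are typically driven. The import paths follow the .pyi stubs, and `cif_dict.from_string` is an assumed constructor from alphafold3's parser module:

from alphafold3.cpp import cif_dict
from alphafold3.structure.cpp import mmcif_atom_site  # assumed path

cif = cif_dict.from_string(open('example.cif').read())  # placeholder input
chain_map = mmcif_atom_site.get_internal_to_author_chain_id_map(cif)
# Naive fallback callback for illustration: first letter of the atom name.
# A real callback would consult the CCD via chemical_components instead.
symbols = mmcif_atom_site.get_or_infer_type_symbol(
    cif, lambda res_name, atom_name: atom_name[:1])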
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.h
new file mode 100644
index 000000000..1f2104ecf
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_atom_site_pybind.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ATOM_SITE_PYBIND_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ATOM_SITE_PYBIND_H_
+
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+
+void RegisterModuleMmcifAtomSite(pybind11::module m);
+
+}
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_ATOM_SITE_PYBIND_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.h
new file mode 100644
index 000000000..51c67c528
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.h
@@ -0,0 +1,146 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_H_
+
+#include <cstddef>
+#include <cstdint>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "alphafold3/parsers/cpp/cif_dict_lib.h"
+
+namespace alphafold3 {
+
+// Holds the layout of a parsed mmCIF file.
+class MmcifLayout {
+ public:
+  MmcifLayout(std::vector<std::size_t> chain_ends,
+              std::vector<std::size_t> residues, std::size_t model_offset,
+              std::size_t num_models)
+      : chain_ends_(std::move(chain_ends)),
+        residue_ends_(std::move(residues)),
+        model_offset_(model_offset),
+        num_models_(num_models) {}
+
+  // Reads a layout from a valid parsed mmCIF. If a valid model_id is provided
+  // the offsets will select that model from the mmCIF.
+  // If no model_id is specified, we calculate the layout of the first model
+  // only. Therefore it is a requirement that each model has identical atom
+  // layouts. An error is returned if the atom counts do not match between
+  // models.
+  static absl::StatusOr<MmcifLayout> Create(const CifDict& mmcif,
+                                            absl::string_view model_id = "");
+
+  std::string ToDebugString() const;
+
+  // Returns the start index and one past the last residue index of a given
+  // chain. A chain_index of n refers to the n-th chain in the mmCIF. The
+  // returned residue indices are 0-based enumerations of residues in the
+  // _atom_site records, and therefore do not include missing residues.
+  std::pair<std::size_t, std::size_t> residue_range(
+      std::size_t chain_index) const {
+    if (chain_index > 0) {
+      return {chain_ends_[chain_index - 1], chain_ends_[chain_index]};
+    } else {
+      return {0, chain_ends_[0]};
+    }
+  }
+
+  // Returns the start index and one past the last index of a given residue.
+  // A residue_index of n refers to the n-th residue in the mmCIF, not
+  // including residues that are unresolved (i.e. only using _atom_site).
+  std::pair<std::size_t, std::size_t> atom_range(
+      std::size_t residue_index) const {
+    if (residue_index > 0) {
+      return {residue_ends_[residue_index - 1], residue_ends_[residue_index]};
+    } else {
+      return {model_offset_, residue_ends_[residue_index]};
+    }
+  }
+
+  // If model_id was provided during construction then this is 1, otherwise
+  // it is the number of models present in the mmCIF.
+  std::size_t num_models() const { return num_models_; }
+  // The number of atoms in the chosen model.
+  std::size_t num_atoms() const {
+    return residue_ends_.empty() ? 0 : residue_ends_.back() - model_offset_;
+  }
+  // The number of chains in the chosen model.
+  std::size_t num_chains() const { return chain_ends_.size(); }
+  // The number of residues in the chosen model, not counting unresolved
+  // residues.
+  std::size_t num_residues() const { return residue_ends_.size(); }
+
+  // Returns the first atom index that is part of the specified chain.
+  // The chain is specified using chain_index, which is a 0-based
+  // enumeration of the chains in the _atom_site table.
+  std::size_t atom_site_from_chain_index(std::size_t chain_index) const {
+    if (chain_index == 0) {
+      return model_offset_;
+    }
+    return atom_site_from_residue_index(chain_ends_[chain_index - 1]);
+  }
+
+  // Returns the first atom index that is part of the specified residue.
+  // The residue is specified using residue_index, which is a 0-based
+  // enumeration of the residues in the _atom_site table.
+  std::size_t atom_site_from_residue_index(std::size_t residues_index) const {
+    if (residues_index == 0) {
+      return model_offset_;
+    }
+    return residue_ends_[residues_index - 1];
+  }
+
+  // One past last residue index of each chain. The residue index does not
+  // include unresolved residues and is a simple 0-based enumeration of the
+  // residues in _atom_site table.
+  const std::vector<std::size_t>& chains() const { return chain_ends_; }
+
+  // Indices of the first atom of each chain. Note that this returns atom
+  // indices (like residue_starts()), not residue indices (like chains()).
+  std::vector<std::size_t> chain_starts() const;
+
+  // One past last atom index of each residue.
+  const std::vector<std::size_t>& residues() const { return residue_ends_; }
+
+  // Indices of the first atom of each residue.
+  std::vector<std::size_t> residue_starts() const {
+    std::vector<std::size_t> residue_starts;
+    if (!residue_ends_.empty()) {
+      residue_starts.reserve(residue_ends_.size());
+      residue_starts.push_back(model_offset_);
+      residue_starts.insert(residue_starts.end(), residue_ends_.begin(),
+                            residue_ends_.end() - 1);
+    }
+    return residue_starts;
+  }
+
+  // The first atom index that is part of the specified model.
+  std::size_t model_offset() const { return model_offset_; }
+
+  void Filter(absl::Span<const std::size_t> keep_indices);
+
+ private:
+  std::vector<std::size_t> chain_ends_;
+  std::vector<std::size_t> residue_ends_;
+  std::size_t model_offset_;
+  std::size_t num_models_;
+};
+
+}  // namespace alphafold3
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_H_
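Editor's note: to make the index conventions concrete, a sketch of walking a parsed file chain by chain through the Python binding (see the .pyi stub that follows). All ranges are half-open; import paths are assumptions:

from alphafold3.cpp import cif_dict
from alphafold3.structure.cpp import mmcif_layout  # assumed path

layout = mmcif_layout.from_mmcif(cif)  # cif: a parsed cif_dict.CifDict
for chain in range(layout.num_chains()):
    res_start, res_end = layout.residue_range(chain)   # residue indices
    for res in range(res_start, res_end):
        atom_start, atom_end = layout.atom_range(res)  # _atom_site rows
        # Atoms of residue `res` occupy rows [atom_start, atom_end).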
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.pyi
new file mode 100644
index 000000000..add1b05ea
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout.pyi
@@ -0,0 +1,26 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+from alphafold3.cpp import cif_dict
+
+class MmcifLayout:
+  def atom_range(self, residue_index: int) -> tuple[int, int]: ...
+  def chain_starts(self) -> list[int]: ...
+  def chains(self) -> list[int]: ...
+  def model_offset(self) -> int: ...
+  def num_atoms(self) -> int: ...
+  def num_chains(self) -> int: ...
+  def num_models(self) -> int: ...
+  def num_residues(self) -> int: ...
+  def residue_range(self, chain_index: int) -> tuple[int, int]: ...
+  def residue_starts(self) -> list[int]: ...
+  def residues(self) -> list[int]: ...
+
+def from_mmcif(mmcif: cif_dict.CifDict, model_id: str = ...) -> MmcifLayout: ...
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc
new file mode 100644
index 000000000..91ad70c0b
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_lib.cc
@@ -0,0 +1,213 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include <algorithm>
+#include <cstddef>
+#include <cstdint>
+#include <functional>
+#include <iterator>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "alphafold3/parsers/cpp/cif_dict_lib.h"
+#include "alphafold3/structure/cpp/mmcif_layout.h"
+
+namespace alphafold3 {
+
+std::string MmcifLayout::ToDebugString() const {
+  return absl::StrFormat(
+      "MmcifLayout(models=%d, chains=%d, num_residues=%d, atoms=%d)",
+      num_models(), num_chains(), num_residues(), num_atoms());
+}
+
+// Changes layout to match keep_indices removing empty chains/residues.
+void MmcifLayout::Filter(absl::Span<const std::size_t> keep_indices) {
+  if (num_chains() == 0) {
+    return;
+  }
+  // Update residue indices.
+  auto keep_it = absl::c_lower_bound(keep_indices, residue_ends_.front());
+  for (auto& residue : residue_ends_) {
+    while (keep_it != keep_indices.end() && *keep_it < residue) {
+      ++keep_it;
+    }
+    residue = std::distance(keep_indices.begin(), keep_it);
+  }
+  // Deduplicate residue_ends_, updating chain ends as we go.
+  auto first = residue_ends_.begin();
+  auto tail = first;
+  std::size_t num_skipped = 0;
+  std::size_t current = 0;
+  for (std::size_t& chain_end : chain_ends_) {
+    for (auto e = residue_ends_.begin() + chain_end; first != e; ++first) {
+      std::size_t next = *first;
+      *tail = next;
+      if (current != next) {
+        current = next;
+        ++tail;
+      } else {
+        ++num_skipped;
+      }
+    }
+    chain_end -= num_skipped;
+  }
+  residue_ends_.erase(tail, residue_ends_.end());
+
+  current = 0;
+  chain_ends_.erase(std::remove_if(chain_ends_.begin(), chain_ends_.end(),
+                                   [&current](std::size_t next) {
+                                     bool result = current == next;
+                                     current = next;
+                                     return result;
+                                   }),
+                    chain_ends_.end());
+  model_offset_ = 0;
+}
+
+absl::StatusOr<MmcifLayout> MmcifLayout::Create(const CifDict& mmcif,
+                                                absl::string_view model_id) {
+  auto model_ids = mmcif["_atom_site.pdbx_PDB_model_num"];
+  auto chain_ids = mmcif["_atom_site.label_asym_id"];      // chain ID.
+  auto label_seq_ids = mmcif["_atom_site.label_seq_id"];   // residue ID.
+  auto auth_seq_ids = mmcif["_atom_site.auth_seq_id"];     // author residue ID.
+  auto insertion_codes = mmcif["_atom_site.pdbx_PDB_ins_code"];
+
+  if (model_ids.size() != chain_ids.size() ||
+      model_ids.size() != label_seq_ids.size() ||
+      (model_ids.size() != auth_seq_ids.size() && !auth_seq_ids.empty()) ||
+      (model_ids.size() != insertion_codes.size() &&
+       !insertion_codes.empty())) {
+    return absl::InvalidArgumentError(absl::StrCat(
+        "Invalid _atom_site table.",  //
+        " len(_atom_site.pdbx_PDB_model_num): ", model_ids.size(),
+        " len(_atom_site.label_asym_id): ", chain_ids.size(),
+        " len(_atom_site.label_seq_id): ", label_seq_ids.size(),
+        " len(_atom_site.auth_seq_id): ", auth_seq_ids.size(),
+        " len(_atom_site.pdbx_PDB_ins_code): ", insertion_codes.size()));
+  }
+  std::size_t num_atoms = model_ids.size();
+  if (num_atoms == 0) {
+    return MmcifLayout({}, {}, 0, 0);
+  }
+  std::size_t model_offset = 0;
+  std::size_t num_models;
+  std::size_t num_atoms_per_model;
+  if (model_id.empty()) {
+    absl::string_view first_model_id = model_ids.front();
+
+    // Binary search for where the first model ends.
+    num_atoms_per_model = std::distance(
+        model_ids.begin(),
+        absl::c_upper_bound(model_ids, first_model_id, std::not_equal_to<>{}));
+    if (num_atoms % num_atoms_per_model != 0) {
+      return absl::InvalidArgumentError(absl::StrCat(
+          "Each model must have the same number of atoms: (", num_atoms, " % ",
+          num_atoms_per_model, " == ", num_atoms % num_atoms_per_model, ")."));
+    }
+    num_models = num_atoms / num_atoms_per_model;
+    // Test that boundary conditions hold for each model.
+    for (std::size_t i = 1; i < num_models; ++i) {
+      if ((model_ids[i * num_atoms_per_model] !=
+           model_ids[(i + 1) * num_atoms_per_model - 1]) ||
+          (model_ids[i * num_atoms_per_model - 1] ==
+           model_ids[i * num_atoms_per_model])) {
+        return absl::InvalidArgumentError(
+            absl::StrCat("Each model must have the same number of atoms: (",
+                         num_atoms, " % ", num_atoms_per_model,
+                         " == ", num_atoms % num_atoms_per_model, ")."));
+      }
+    }
+  } else {
+    num_models = 1;
+    model_offset =
+        std::distance(model_ids.begin(), absl::c_find(model_ids, model_id));
+    if (model_offset == model_ids.size()) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("Unknown model_id: ", model_id));
+    }
+    model_ids.remove_prefix(model_offset);
+    chain_ids.remove_prefix(model_offset);
+    label_seq_ids.remove_prefix(model_offset);
+    if (!auth_seq_ids.empty()) auth_seq_ids.remove_prefix(model_offset);
+    if (!insertion_codes.empty()) insertion_codes.remove_prefix(model_offset);
+
+    num_atoms_per_model = std::distance(
+        model_ids.begin(), std::upper_bound(model_ids.begin(), model_ids.end(),
+                                            model_id, std::not_equal_to<>{}));
+    num_atoms = num_atoms_per_model;
+  }
+  std::vector<std::size_t> residues;
+  std::vector<std::size_t> chains;
+  absl::string_view chain_id = chain_ids.front();
+  if (!auth_seq_ids.empty() && !insertion_codes.empty()) {
+    // If author residue IDs are present then these are preferred to
+    // label residue IDs because they work for multi-residue ligands (which
+    // are given constant "." label residue IDs).
+    // NB: Author residue IDs require both the auth_seq_id and the insertion
+    // code to be unique.
+    absl::string_view auth_seq_id = auth_seq_ids.front();
+    absl::string_view insertion_code = insertion_codes.front();
+    for (std::size_t i = 1; i < num_atoms_per_model; ++i) {
+      if (absl::string_view current_chain_id = chain_ids[i];
+          current_chain_id != chain_id) {
+        residues.push_back(i + model_offset);
+        chains.push_back(residues.size());
+        chain_id = current_chain_id;
+        auth_seq_id = auth_seq_ids[i];
+        insertion_code = insertion_codes[i];
+      } else if (absl::string_view current_seq_id = auth_seq_ids[i],
+                     current_insertion_code = insertion_codes[i];
+                 insertion_code != current_insertion_code ||
+                 auth_seq_id != current_seq_id) {
+        residues.push_back(i + model_offset);
+        auth_seq_id = current_seq_id;
+        insertion_code = current_insertion_code;
+      }
+    }
+  } else {
+    absl::string_view label_seq_id = label_seq_ids.front();
+    for (std::size_t i = 1; i < num_atoms_per_model; ++i) {
+      if (absl::string_view current_chain_id = chain_ids[i];
+          current_chain_id != chain_id) {
+        residues.push_back(i + model_offset);
+        chains.push_back(residues.size());
+        chain_id = current_chain_id;
+        label_seq_id = label_seq_ids[i];
+      } else if (absl::string_view current_seq_id = label_seq_ids[i];
+                 label_seq_id != current_seq_id) {
+        residues.push_back(i + model_offset);
+        label_seq_id = current_seq_id;
+      }
+    }
+  }
+  residues.push_back(num_atoms_per_model + model_offset);
+  chains.push_back(residues.size());
+  return MmcifLayout(std::move(chains), std::move(residues), model_offset,
+                     num_models);
+}
+
+std::vector<std::size_t> MmcifLayout::chain_starts() const {
+  std::vector<std::size_t> chain_starts;
+  chain_starts.reserve(chain_ends_.size());
+  for (std::size_t index = 0; index < chain_ends_.size(); ++index) {
+    chain_starts.push_back(atom_site_from_chain_index(index));
+  }
+  return chain_starts;
+}
+
+}  // namespace alphafold3
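Editor's note: the end-index encoding that Create builds can be checked on a toy example. The numbers below are illustrative only (pure Python, no bindings required):

# residue_ends holds one-past-last atom per residue; chain_ends holds
# one-past-last residue per chain (cf. the members of MmcifLayout).
residue_ends = [3, 5, 9]   # three residues ending at atoms 3, 5 and 9
chain_ends = [2, 3]        # chain 0 = residues [0, 2), chain 1 = [2, 3)
assert residue_ends[-1] == 9                  # num_atoms (model_offset == 0)
assert residue_ends[2 - 1] == 5               # first atom of residue 2
assert residue_ends[chain_ends[0] - 1] == 5   # first atom of chain 1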
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.cc
new file mode 100644
index 000000000..8eb69befc
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.cc
@@ -0,0 +1,49 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include "alphafold3/structure/cpp/mmcif_layout.h"
+#include "pybind11/pybind11.h"
+#include "pybind11/pytypes.h"
+#include "pybind11/stl.h"
+
+namespace alphafold3 {
+
+namespace py = pybind11;
+
+void RegisterModuleMmcifLayout(pybind11::module m) {
+  py::class_<MmcifLayout>(m, "MmcifLayout")
+      .def("__str__", &MmcifLayout::ToDebugString)
+      .def("num_models", &MmcifLayout::num_models)
+      .def("num_chains", &MmcifLayout::num_chains)
+      .def("num_residues", &MmcifLayout::num_residues)
+      .def("num_atoms", &MmcifLayout::num_atoms)
+      .def("residue_range", &MmcifLayout::residue_range, py::arg("chain_index"))
+      .def("atom_range", &MmcifLayout::atom_range, py::arg("residue_index"))
+      .def("chains", &MmcifLayout::chains,
+           py::doc("Returns a list of indices one past the last residue of "
+                   "each chain."))
+      .def(
+          "chain_starts", &MmcifLayout::chain_starts,
+          py::doc("Returns a list of indices of the first atom of each chain."))
+      .def("residues", &MmcifLayout::residues,
+           py::doc("Returns a list of indices one past the last atom of each "
+                   "residue."))
+      .def("residue_starts", &MmcifLayout::residue_starts,
+           py::doc(
+               "Returns a list of indices of the first atom of each residue."))
+      .def("model_offset", &MmcifLayout::model_offset,
+           py::doc("Returns the first atom index that is part of the specified "
+                   "model."));
+
+  m.def("from_mmcif", &MmcifLayout::Create, py::arg("mmcif"),
+        py::arg("model_id") = "");
+}
+
+}  // namespace alphafold3
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.h
new file mode 100644
index 000000000..c79b2dd50
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_layout_pybind.h
@@ -0,0 +1,24 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_PYBIND_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_PYBIND_H_
+
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+
+void RegisterModuleMmcifLayout(pybind11::module m);
+
+}
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_LAYOUT_PYBIND_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.h
new file mode 100644
index 000000000..821be658d
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.h
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2024 DeepMind Technologies Limited
+ *
+ * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_H_
+
+#include <utility>
+#include <vector>
+
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "alphafold3/parsers/cpp/cif_dict_lib.h"
+
+namespace alphafold3 {
+
+// Returns a pair of atom indices for each row in the bonds table (aka
+// _struct_conn). The indices are simple 0-based indexes into the columns of
+// the _atom_site table in the input mmCIF, and do not necessarily correspond
+// to the values in _atom_site.id, or any other column.
+absl::StatusOr<std::pair<std::vector<std::size_t>, std::vector<std::size_t>>>
+GetBondAtomIndices(const CifDict& mmcif, absl::string_view model_id);
+
+}  // namespace alphafold3
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.pyi
new file mode 100644
index 000000000..d293e666a
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn.pyi
@@ -0,0 +1,13 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+from alphafold3.cpp import cif_dict
+
+def get_bond_atom_indices(
+    mmcif_dict: cif_dict.CifDict, model_id: str
+) -> tuple[list[int], list[int]]: ...
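Editor's note: a usage sketch for the bond-index helper declared above; module paths follow the .pyi stub and are assumptions:

from alphafold3.cpp import cif_dict
from alphafold3.structure.cpp import mmcif_struct_conn  # assumed path

from_atoms, to_atoms = mmcif_struct_conn.get_bond_atom_indices(cif, '1')
# Bond i (row i of _struct_conn) joins _atom_site row from_atoms[i] to row
# to_atoms[i] within model '1'; indices equal to the atom count mark bonds
# whose atoms could not be resolved (see the workaround described in the
# pybind docstring further down).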
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc
new file mode 100644
index 000000000..afb930fab
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_lib.cc
@@ -0,0 +1,380 @@
+// Copyright 2024 DeepMind Technologies Limited
+//
+// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+//
+// To request access to the AlphaFold 3 model parameters, follow the process set
+// out at https://github.com/google-deepmind/alphafold3. You may only use these
+// if received directly from Google. Use is subject to terms of use available at
+// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+#include <cstddef>
+#include <iterator>
+#include <optional>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/container/flat_hash_set.h"
+#include "absl/status/status.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/string_view.h"
+#include "absl/types/span.h"
+#include "alphafold3/parsers/cpp/cif_dict_lib.h"
+#include "alphafold3/structure/cpp/mmcif_struct_conn.h"
+
+namespace alphafold3 {
+
+namespace {
+
+struct AtomId {
+  absl::string_view chain_id;
+  absl::string_view res_id_1;
+  absl::string_view res_id_2;
+  absl::string_view atom_name;
+  absl::string_view alt_id;
+
+  friend bool operator==(const AtomId&, const AtomId&) = default;
+  template <typename H>
+  friend H AbslHashValue(H h, const AtomId& m) {
+    return H::combine(std::move(h), m.chain_id, m.res_id_1, m.res_id_2,
+                      m.atom_name, m.alt_id);
+  }
+};
+
+using StringArrayRef = absl::Span<const absl::string_view>;
+using BondIndexByAtom = absl::flat_hash_map<AtomId, std::vector<std::size_t>>;
+using BondAtomIndices = std::vector<std::size_t>;
+
+// Returns whether each container is the same size.
+template <typename C, typename... Cs>
+bool AreSameSize(const C& c, const Cs&... cs) {
+  return ((c.size() == cs.size()) && ...);
+}
+
+struct ColumnSpec {
+  absl::string_view chain_id_col;
+  absl::string_view res_id_1_col;
+  absl::string_view res_id_2_col;
+  absl::string_view atom_name_col;
+  std::optional<absl::string_view> alt_id_col;  // Not used by OpenMM.
+};
+
+class AtomColumns {
+ public:
+  static absl::StatusOr<AtomColumns> Create(const CifDict& mmcif,
+                                            const ColumnSpec& column_spec) {
+    StringArrayRef chain_id = mmcif[column_spec.chain_id_col];
+    StringArrayRef res_id_1 = mmcif[column_spec.res_id_1_col];
+    StringArrayRef res_id_2 = mmcif[column_spec.res_id_2_col];
+    StringArrayRef atom_name = mmcif[column_spec.atom_name_col];
+    if (!AreSameSize(chain_id, res_id_1, res_id_2, atom_name)) {
+      return absl::InvalidArgumentError(absl::StrCat(
+          "Atom columns are not the same size. ",                       //
+          "len(", column_spec.chain_id_col, ")=", chain_id.size(),      //
+          ", len(", column_spec.res_id_1_col, ")=", res_id_1.size(),    //
+          ", len(", column_spec.res_id_2_col, ")=", res_id_2.size(),    //
+          ", len(", column_spec.atom_name_col, ")=", atom_name.size(),  //
+          "."));
+    }
", // + "len(", column_spec.chain_id_col, ")=", chain_id.size(), // + ", len(", *column_spec.alt_id_col, ")=", alt_id.size(), // + ".")); + } + return AtomColumns(chain_id, res_id_1, res_id_2, atom_name, alt_id, + column_spec); + } else { + return AtomColumns(chain_id, res_id_1, res_id_2, atom_name, std::nullopt, + column_spec); + } + } + + inline std::size_t size() const { return size_; } + + absl::string_view GetNormalizedAltId(const std::size_t index) const { + constexpr absl::string_view kFullStop = "."; + if (alt_id_.has_value()) { + absl::string_view alt_id = (*alt_id_)[index]; + return alt_id == "?" ? kFullStop : alt_id; + } else { + return kFullStop; + } + } + + AtomId GetAtom(const std::size_t index) const { + return {.chain_id = chain_id_[index], + .res_id_1 = res_id_1_[index], + .res_id_2 = res_id_2_[index], + .atom_name = atom_name_[index], + .alt_id = GetNormalizedAltId(index)}; + } + + std::string GetAtomString(const std::size_t index) const { + std::string alt_id_col; + if (column_spec_.alt_id_col.has_value()) { + alt_id_col = *column_spec_.alt_id_col; + } else { + alt_id_col = "default label_alt_id"; + } + return absl::StrCat( + column_spec_.chain_id_col, "=", chain_id_[index], ", ", // + column_spec_.res_id_1_col, "=", res_id_1_[index], ", ", // + column_spec_.res_id_2_col, "=", res_id_2_[index], ", ", // + column_spec_.atom_name_col, "=", atom_name_[index], ", ", // + alt_id_col, "=", GetNormalizedAltId(index)); // + } + + private: + AtomColumns(StringArrayRef chain_id, StringArrayRef res_id_1, + StringArrayRef res_id_2, StringArrayRef atom_name, + std::optional alt_id, + const ColumnSpec& column_spec) + : chain_id_(chain_id), + res_id_1_(res_id_1), + res_id_2_(res_id_2), + atom_name_(atom_name), + alt_id_(alt_id), + column_spec_(column_spec), + size_(chain_id.size()) {} + StringArrayRef chain_id_; + StringArrayRef res_id_1_; + StringArrayRef res_id_2_; + StringArrayRef atom_name_; + std::optional alt_id_; + ColumnSpec column_spec_; + std::size_t size_; +}; + +// Adds the atom index to any rows in the bond table involving that atom. +absl::Status FillInBondsForAtom(const BondIndexByAtom& bond_index_by_atom, + const AtomId& atom, + const std::size_t atom_index, + BondAtomIndices& bond_atom_indices) { + if (auto bond_index_it = bond_index_by_atom.find(atom); + bond_index_it != bond_index_by_atom.end()) { + for (std::size_t bond_index : bond_index_it->second) { + if (bond_index < 0 || bond_index >= bond_atom_indices.size()) { + return absl::OutOfRangeError( + absl::StrCat("Bond index out of range: ", bond_index)); + } + bond_atom_indices[bond_index] = atom_index; + } + } + return absl::OkStatus(); +} + +// Checks that the CifDict has all of the columns in the column spec. +bool HasAllColumns(const CifDict& mmcif, const ColumnSpec& columns) { + return mmcif.Contains(columns.chain_id_col) && + mmcif.Contains(columns.res_id_1_col) && + mmcif.Contains(columns.res_id_2_col) && + mmcif.Contains(columns.atom_name_col) && + (!columns.alt_id_col.has_value() || + mmcif.Contains(*columns.alt_id_col)); +} + +// Fully specified ptnr1 atom. +constexpr ColumnSpec kStructConnPtnr1ColumnsFull{ + .chain_id_col = "_struct_conn.ptnr1_label_asym_id", + .res_id_1_col = "_struct_conn.ptnr1_auth_seq_id", + .res_id_2_col = "_struct_conn.pdbx_ptnr1_PDB_ins_code", + .atom_name_col = "_struct_conn.ptnr1_label_atom_id", + .alt_id_col = "_struct_conn.pdbx_ptnr1_label_alt_id", +}; + +// Fully specified ptnr2 atom. 
+constexpr ColumnSpec kStructConnPtnr2ColumnsFull{
+    .chain_id_col = "_struct_conn.ptnr2_label_asym_id",
+    .res_id_1_col = "_struct_conn.ptnr2_auth_seq_id",
+    .res_id_2_col = "_struct_conn.pdbx_ptnr2_PDB_ins_code",
+    .atom_name_col = "_struct_conn.ptnr2_label_atom_id",
+    .alt_id_col = "_struct_conn.pdbx_ptnr2_label_alt_id",
+};
+
+// Columns used by OpenMM for ptnr1 atoms.
+constexpr ColumnSpec kStructConnPtnr1OpenMM{
+    .chain_id_col = "_struct_conn.ptnr1_label_asym_id",
+    .res_id_1_col = "_struct_conn.ptnr1_label_seq_id",
+    .res_id_2_col = "_struct_conn.ptnr1_label_comp_id",
+    .atom_name_col = "_struct_conn.ptnr1_label_atom_id",
+    .alt_id_col = std::nullopt,
+};
+
+// Columns used by OpenMM for ptnr2 atoms.
+constexpr ColumnSpec kStructConnPtnr2OpenMM{
+    .chain_id_col = "_struct_conn.ptnr2_label_asym_id",
+    .res_id_1_col = "_struct_conn.ptnr2_label_seq_id",
+    .res_id_2_col = "_struct_conn.ptnr2_label_comp_id",
+    .atom_name_col = "_struct_conn.ptnr2_label_atom_id",
+    .alt_id_col = std::nullopt,
+};
+
+// Fully specified atom sites.
+constexpr ColumnSpec kAtomSiteColumnsFull{
+    .chain_id_col = "_atom_site.label_asym_id",
+    .res_id_1_col = "_atom_site.auth_seq_id",
+    .res_id_2_col = "_atom_site.pdbx_PDB_ins_code",
+    .atom_name_col = "_atom_site.label_atom_id",
+    .alt_id_col = "_atom_site.label_alt_id",
+};
+
+// Atom site columns used to match OpenMM _struct_conn tables.
+constexpr ColumnSpec kAtomSiteColumnsOpenMM{
+    .chain_id_col = "_atom_site.label_asym_id",
+    .res_id_1_col = "_atom_site.label_seq_id",
+    .res_id_2_col = "_atom_site.label_comp_id",
+    .atom_name_col = "_atom_site.label_atom_id",
+    .alt_id_col = "_atom_site.label_alt_id",
+};
+
+}  // namespace
+
+absl::StatusOr<std::pair<std::vector<std::size_t>, std::vector<std::size_t>>>
+GetBondAtomIndices(const CifDict& mmcif, absl::string_view model_id) {
+  ColumnSpec ptnr1_columns, ptnr2_columns, atom_site_columns;
+
+  if (HasAllColumns(mmcif, kStructConnPtnr1ColumnsFull) &&
+      HasAllColumns(mmcif, kStructConnPtnr2ColumnsFull)) {
+    ptnr1_columns = kStructConnPtnr1ColumnsFull;
+    ptnr2_columns = kStructConnPtnr2ColumnsFull;
+    atom_site_columns = kAtomSiteColumnsFull;
+  } else {
+    ptnr1_columns = kStructConnPtnr1OpenMM;
+    ptnr2_columns = kStructConnPtnr2OpenMM;
+    atom_site_columns = kAtomSiteColumnsOpenMM;
+  }
+
+  absl::StatusOr<AtomColumns> ptnr1_atoms =
+      AtomColumns::Create(mmcif, ptnr1_columns);
+  if (!ptnr1_atoms.ok()) {
+    return ptnr1_atoms.status();
+  }
+  absl::StatusOr<AtomColumns> ptnr2_atoms =
+      AtomColumns::Create(mmcif, ptnr2_columns);
+  if (!ptnr2_atoms.ok()) {
+    return ptnr2_atoms.status();
+  }
+  StringArrayRef struct_conn_id = mmcif["_struct_conn.id"];
+  if (!AreSameSize(struct_conn_id, *ptnr1_atoms, *ptnr2_atoms)) {
+    return absl::InvalidArgumentError(absl::StrCat(
+        "Invalid '_struct_conn.' loop. ",                  //
+        "len(id) = ", struct_conn_id.size(), ", ",         //
+        "len(ptnr1_atoms) = ", ptnr1_atoms->size(), ", ",  //
+        "len(ptnr2_atoms) = ", ptnr2_atoms->size(), "."    //
+        ));
+  }
+
+  absl::StatusOr<AtomColumns> atoms =
+      AtomColumns::Create(mmcif, atom_site_columns);
+  if (!atoms.ok()) {
+    return atoms.status();
+  }
+  StringArrayRef atom_site_id = mmcif["_atom_site.id"];
+  StringArrayRef atom_site_model_id = mmcif["_atom_site.pdbx_PDB_model_num"];
+  if (!AreSameSize(atom_site_id, atom_site_model_id, *atoms)) {
+    return absl::InvalidArgumentError(absl::StrCat(
+        "Invalid '_atom_site.' loop. ",                                //
+        "len(id)= ", atom_site_id.size(), ", ",                        //
+        "len(pdbx_PDB_model_num)= ", atom_site_model_id.size(), ", ",  //
+        "len(atoms)= ", atoms->size(), "."));                          //
+  }
", // + "len(id)= ", atom_site_id.size(), ", ", // + "len(pdbx_PDB_model_num)= ", atom_site_model_id.size(), ", ", // + "len(atoms)= ", atoms->size(), ".")); // + } + + // Build maps from atom ID tuples to the rows in _struct_conn where that + // atom appears (NB could be multiple). + const std::size_t struct_conn_size = struct_conn_id.size(); + BondIndexByAtom ptnr1_rows_by_atom(struct_conn_size); + BondIndexByAtom ptnr2_rows_by_atom(struct_conn_size); + for (std::size_t i = 0; i < struct_conn_size; ++i) { + ptnr1_rows_by_atom[ptnr1_atoms->GetAtom(i)].push_back(i); + ptnr2_rows_by_atom[ptnr2_atoms->GetAtom(i)].push_back(i); + } + + // Allocate two output arrays with one element per row in struct_conn, where + // each element will be the index of that atom in the atom_site table. + // Fill the arrays with atom_site_size, which is an invalid value, so that + // we can check at the end that each atom has been found. + const std::size_t atom_site_size = atom_site_id.size(); + BondAtomIndices ptnr1_atom_indices(struct_conn_size, atom_site_size); + BondAtomIndices ptnr2_atom_indices(struct_conn_size, atom_site_size); + + bool model_id_ecountered = false; + absl::flat_hash_set seen_alt_ids; + for (std::size_t atom_i = 0; atom_i < atom_site_size; ++atom_i) { + if (atom_site_model_id[atom_i] != model_id) { + if (!model_id_ecountered) { + continue; + } else { + // Models are contiguous so once we see a different model ID after + // encountering our model ID then we can exit early. + break; + } + } else { + model_id_ecountered = true; + } + AtomId atom = atoms->GetAtom(atom_i); + seen_alt_ids.insert(atom.alt_id); + + if (auto fill_in_bonds_status1 = FillInBondsForAtom( + ptnr1_rows_by_atom, atom, atom_i, ptnr1_atom_indices); + !fill_in_bonds_status1.ok()) { + return fill_in_bonds_status1; + } + if (auto fill_in_bonds_status2 = FillInBondsForAtom( + ptnr2_rows_by_atom, atom, atom_i, ptnr2_atom_indices); + !fill_in_bonds_status2.ok()) { + return fill_in_bonds_status2; + } + } + // The seen_alt_ids check is a workaround for a known PDB issue: some mmCIFs + // (2evw, 2g0v, 2g0x, 2g0z, 2g10, 2g11, 2g12, 2g14, 2grz, 2ntw as of 2024) + // have multiple models and they set different whole-chain altloc in each + // model. The bond table however doesn't distinguish between models, so there + // are bonds that are valid only for some models. E.g. 2grz has model 1 with + // chain A with altloc A, and model 2 with chain A with altloc B. The bonds + // table lists a bond for each of these. + + // Check that a ptnr1 atom was found for every bond. + if (auto row_it = absl::c_find(ptnr1_atom_indices, atom_site_size); + row_it != ptnr1_atom_indices.end()) { + if (seen_alt_ids.size() > 1 || seen_alt_ids.contains(".") || + seen_alt_ids.contains("?")) { + std::size_t i = std::distance(ptnr1_atom_indices.begin(), row_it); + return absl::InvalidArgumentError( + absl::StrCat("Error parsing \"", mmcif.GetDataName(), "\". ", + "Cannot find atom for bond ID ", struct_conn_id[i], ": ", + ptnr1_atoms->GetAtomString(i))); + } + } + + // Check that a ptnr2 atom was found for every bond. + if (auto row_it = absl::c_find(ptnr2_atom_indices, atom_site_size); + row_it != ptnr2_atom_indices.end()) { + if (seen_alt_ids.size() > 1 || seen_alt_ids.contains(".") || + seen_alt_ids.contains("?")) { + std::size_t i = std::distance(ptnr2_atom_indices.begin(), row_it); + return absl::InvalidArgumentError( + absl::StrCat("Error parsing \"", mmcif.GetDataName(), "\". 
", + "Cannot find atom for bond ID ", struct_conn_id[i], ": ", + ptnr2_atoms->GetAtomString(i))); + } + } + + if (!model_id_ecountered) { + return absl::InvalidArgumentError(absl::StrCat( + "Error parsing \"", mmcif.GetDataName(), "\". model_id \"", model_id, + "\" not found in _atom_site.pdbx_PDB_model_num.")); + } + + return std::make_pair(std::move(ptnr1_atom_indices), + std::move(ptnr2_atom_indices)); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.cc new file mode 100644 index 000000000..111715ab5 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.cc @@ -0,0 +1,68 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include "absl/strings/string_view.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "alphafold3/structure/cpp/mmcif_struct_conn.h" +#include "pybind11/gil.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" + +namespace alphafold3 { + +namespace py = pybind11; + +constexpr char kGetBondAtomIndices[] = R"( +Extracts the indices of the atoms that participate in bonds. + +This function has a workaround for a known PDB issue: some mmCIFs have +(2evw, 2g0v, 2g0x, 2g0z, 2g10, 2g11, 2g12, 2g14, 2grz, 2ntw as of 2024) +multiple models and they set different whole-chain altloc in each model. +The bond table however doesn't distinguish between models, so there are +bonds that are valid only for some models. E.g. 2grz has model 1 with +chain A with altloc A, and model 2 with chain A with altloc B. The bonds +table lists a bond for each of these. This case is rather rare (10 cases +in PDB as of 2024). For the offending bonds, the returned atom index is +set to the size of the atom_site table, i.e. it is an invalid index. + +Args: + mmcif: The mmCIF object to process. + model_id: The ID of the model that the returned atoms will belong to. This + should be a value in the mmCIF's _atom_site.pdbx_PDB_model_num column. + +Returns: + Two lists of atom indices, `from_atoms` and `to_atoms`, each one having + length num_bonds (as defined by _struct_conn, the bonds table). The bond + i, defined by the i'th row in _struct_conn, is a bond from atom at index + from_atoms[i], to the atom at index to_atoms[i]. The indices are simple + 0-based indexes into the columns of the _atom_site table in the input + mmCIF, and do not necessarily correspond to the values in _atom_site.id, + or any other column. 
+)"; + +void RegisterModuleMmcifStructConn(pybind11::module m) { + m.def( + "get_bond_atom_indices", + [](const CifDict& mmcif, absl::string_view model_id) { + auto result = GetBondAtomIndices(mmcif, model_id); + if (result.ok()) { + return *result; + } + throw py::value_error(std::string(result.status().message())); + }, + py::arg("mmcif_dict"), py::arg("model_id"), + py::doc(kGetBondAtomIndices + 1), + py::call_guard()); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.h new file mode 100644 index 000000000..acdbf7b77 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_struct_conn_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMmcifStructConn(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_STRUCT_CONN_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils.pyi new file mode 100644 index 000000000..aa2dc23e9 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils.pyi @@ -0,0 +1,71 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from collections.abc import Sequence + +import numpy as np + +from alphafold3.cpp import cif_dict +from alphafold3.structure.python import mmcif_layout + + +def filter( + mmcif: cif_dict.CifDict, + include_nucleotides: bool, + include_ligands: bool = ..., + include_water: bool = ..., + include_other: bool = ..., + model_id: str = ..., +) -> tuple[np.ndarray[int], mmcif_layout.MmcifLayout]: ... + + +def fix_residues( + layout: mmcif_layout.MmcifLayout, + comp_id: Sequence[str], + atom_id: Sequence[str], + atom_x: Sequence[float], + atom_y: Sequence[float], + atom_z: Sequence[float], + fix_arg: bool = ..., +) -> None: ... + + +def read_layout( + mmcif: cif_dict.CifDict, model_id: str = ... +) -> mmcif_layout.MmcifLayout: ... 
+ + +def selected_ligand_residue_mask( + layout: mmcif_layout.MmcifLayout, + atom_site_label_asym_ids: list[str], + atom_site_label_seq_ids: list[str], + atom_site_auth_seq_ids: list[str], + atom_site_label_comp_ids: list[str], + atom_site_pdbx_pdb_ins_codes: list[str], + nonpoly_asym_ids: list[str], + nonpoly_auth_seq_ids: list[str], + nonpoly_pdb_ins_codes: list[str], + nonpoly_mon_ids: list[str], + branch_asym_ids: list[str], + branch_auth_seq_ids: list[str], + branch_pdb_ins_codes: list[str], + branch_mon_ids: list[str], +) -> tuple[list[bool], list[bool]]: ... + + +def selected_polymer_residue_mask( + layout: mmcif_layout.MmcifLayout, + atom_site_label_asym_ids: list[str], + atom_site_label_seq_ids: list[str], + atom_site_label_comp_ids: list[str], + poly_seq_asym_ids: list[str], + poly_seq_seq_ids: list[str], + poly_seq_mon_ids: list[str], +) -> list[bool]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc new file mode 100644 index 000000000..52bd039b2 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.cc @@ -0,0 +1,787 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/npy_common.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/memory/memory.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "alphafold3/parsers/cpp/cif_dict_lib.h" +#include "alphafold3/structure/cpp/mmcif_altlocs.h" +#include "alphafold3/structure/cpp/mmcif_layout.h" +#include "pybind11/cast.h" +#include "pybind11/gil.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11/stl.h" +#include "pybind11_abseil/absl_casters.h" + +namespace alphafold3 { +namespace { +namespace py = pybind11; + +struct PyObjectDeleter { + inline void operator()(PyObject* obj) const { Py_CLEAR(obj); } +}; + +using ScopedPyObject = std::unique_ptr; + +using StringArrayRef = absl::Span; +using Indexer = absl::flat_hash_map; + +// Returns the reverse look-up map of name to index. +Indexer MakeIndex(StringArrayRef col) { + Indexer index; + index.reserve(col.size()); + for (std::size_t i = 0; i < col.size(); ++i) { + index[col[i]] = i; + } + return index; +} + +// Returns whether each container is the same size. +template +bool AreSameSize(C c, const Cs&... cs) { + return ((c.size() == cs.size()) && ...); +} + +// Stores references to columns in `_atom_site` ensuring they all exist and +// are the same size. 
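+// For reference, a truncated and purely illustrative `_atom_site` loop with
+// the columns read below:
+//
+//   loop_
+//   _atom_site.id
+//   _atom_site.pdbx_PDB_model_num
+//   _atom_site.label_asym_id
+//   _atom_site.label_seq_id
+//   _atom_site.label_comp_id
+//   _atom_site.label_atom_id
+//   _atom_site.label_alt_id
+//   _atom_site.occupancy
+//   1 1 A 1 MET N . 1.00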
+struct AtomSiteLoop {
+  explicit AtomSiteLoop(const CifDict& cif_dict)
+      : id(cif_dict["_atom_site.id"]),
+        model_id(cif_dict["_atom_site.pdbx_PDB_model_num"]),
+        chain_id(cif_dict["_atom_site.label_asym_id"]),
+        seq_id(cif_dict["_atom_site.label_seq_id"]),
+        comp_id(cif_dict["_atom_site.label_comp_id"]),
+        atom_id(cif_dict["_atom_site.label_atom_id"]),
+        alt_id(cif_dict["_atom_site.label_alt_id"]),
+        occupancy(cif_dict["_atom_site.occupancy"]) {
+    if (!AreSameSize(id, model_id, chain_id, seq_id, comp_id, atom_id, alt_id,
+                     occupancy)) {
+      throw py::value_error(
+          absl::StrCat("Invalid '_atom_site.' loop. ",                     //
+                       "len(id)=", id.size(), ", ",                        //
+                       "len(pdbx_PDB_model_num)=", model_id.size(), ", ",  //
+                       "len(label_asym_id)=", chain_id.size(), ", ",       //
+                       "len(label_seq_id)=", seq_id.size(), ", ",          //
+                       "len(label_comp_id)=", comp_id.size(), ", ",        //
+                       "len(label_atom_id)=", atom_id.size(), ", ",        //
+                       "len(label_alt_id)=", alt_id.size(), ", ",          //
+                       "len(occupancy)=", occupancy.size()));
+    }
+  }
+  StringArrayRef id;
+  StringArrayRef model_id;
+  StringArrayRef chain_id;
+  StringArrayRef seq_id;
+  StringArrayRef comp_id;
+  StringArrayRef atom_id;
+  StringArrayRef alt_id;
+  StringArrayRef occupancy;
+};
+
+// Stores references to columns in `_entity` ensuring they all exist and are
+// the same size.
+struct EntityLoop {
+  explicit EntityLoop(const CifDict& cif_dict)
+      : id(cif_dict["_entity.id"]), type(cif_dict["_entity.type"]) {
+    if (!AreSameSize(id, type)) {
+      throw py::value_error(absl::StrCat("Invalid '_entity.' loop. ",  //
+                                         "len(id)=", id.size(), ", ",  //
+                                         "len(type)=", type.size()));
+    }
+  }
+  StringArrayRef id;
+  StringArrayRef type;
+};
+
+// Stores references to columns in `_entity_poly` ensuring they all exist and
+// are the same size.
+struct EntityPolyLoop {
+  explicit EntityPolyLoop(const CifDict& cif_dict)
+      : entity_id(cif_dict["_entity_poly.entity_id"]),
+        type(cif_dict["_entity_poly.type"]) {
+    if (!AreSameSize(entity_id, type)) {
+      throw py::value_error(absl::StrCat("Invalid '_entity_poly.' loop. ",  //
+                                         "len(entity_id)=", entity_id.size(),
+                                         ", ",  //
+                                         "len(type)=", type.size()));
+    }
+  }
+  StringArrayRef entity_id;
+  StringArrayRef type;
+};
+
+// Returns the set of chain IDs to keep: chains whose entity type is not
+// enabled by the given include_* flags are dropped.
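+// For illustration, chain selection is driven by `_struct_asym`, `_entity`
+// and `_entity_poly` (values hypothetical):
+//
+//   _struct_asym.id         A        B
+//   _struct_asym.entity_id  1        2
+//   _entity.type            polymer  non-polymer
+//   _entity_poly.type       polypeptide(L)
+//
+// With include_ligands=false (and include_other=false), only chain A is kept.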
+absl::flat_hash_set<absl::string_view> SelectChains(const CifDict& mmcif,
+                                                    bool include_nucleotides,
+                                                    bool include_ligands,
+                                                    bool include_water,
+                                                    bool include_other) {
+  EntityLoop entity_loop(mmcif);
+  EntityPolyLoop entity_poly(mmcif);
+  absl::flat_hash_set<absl::string_view> permitted_polymers{"polypeptide(L)"};
+  absl::flat_hash_set<absl::string_view> forbidden_polymers;
+  for (absl::string_view type :
+       {"polydeoxyribonucleotide", "polyribonucleotide",
+        "polydeoxyribonucleotide/polyribonucleotide hybrid"}) {
+    if (include_nucleotides) {
+      permitted_polymers.emplace(type);
+    } else {
+      forbidden_polymers.emplace(type);
+    }
+  }
+
+  absl::flat_hash_set<absl::string_view> permitted_nonpoly_entity_types;
+  absl::flat_hash_set<absl::string_view> forbidden_nonpoly_entity_types;
+  for (absl::string_view type : {"non-polymer", "branched"}) {
+    if (include_ligands) {
+      permitted_nonpoly_entity_types.emplace(type);
+    } else {
+      forbidden_nonpoly_entity_types.emplace(type);
+    }
+  }
+  absl::string_view water_type = "water";
+  if (include_water) {
+    permitted_nonpoly_entity_types.emplace(water_type);
+  } else {
+    forbidden_nonpoly_entity_types.emplace(water_type);
+  }
+
+  StringArrayRef chain_ids = mmcif["_struct_asym.id"];
+  StringArrayRef entity_ids = mmcif["_struct_asym.entity_id"];
+  Indexer chain_index = MakeIndex(chain_ids);
+  Indexer entity_poly_index = MakeIndex(entity_poly.entity_id);
+  Indexer entity_id_to_index = MakeIndex(entity_loop.id);
+
+  absl::flat_hash_set<absl::string_view> keep_chain_id;
+  for (std::size_t i = 0; i < chain_ids.size(); ++i) {
+    absl::string_view chain_id = chain_ids[i];
+    absl::string_view entity_id = entity_ids[i];
+    if (entity_id_to_index.empty() ||
+        entity_loop.type[entity_id_to_index[entity_id]] == "polymer") {
+      if (auto it = entity_poly_index.find(entity_id);
+          it != entity_poly_index.end()) {
+        absl::string_view poly_type = entity_poly.type[it->second];
+        if (include_other) {
+          if (!forbidden_polymers.contains(poly_type)) {
+            keep_chain_id.insert(chain_id);
+          }
+        } else {
+          if (permitted_polymers.contains(poly_type)) {
+            keep_chain_id.insert(chain_id);
+          }
+        }
+      }
+    } else {
+      absl::string_view entity_type =
+          entity_loop.type[entity_id_to_index[entity_id]];
+      if (include_other) {
+        if (!forbidden_nonpoly_entity_types.contains(entity_type)) {
+          keep_chain_id.insert(chain_id);
+          continue;
+        }
+      } else {
+        if (permitted_nonpoly_entity_types.contains(entity_type)) {
+          keep_chain_id.insert(chain_id);
+          continue;
+        }
+      }
+    }
+  }
+  return keep_chain_id;
+}
+
+class ProcessResidue {
+ public:
+  explicit ProcessResidue(const char* residue)
+      : residue_(PyUnicode_InternFromString(residue)) {}
+  bool IsResidue(PyObject* residue) {
+    return ArePyObjectsEqual(residue_.get(), residue);
+  }
+
+  static bool ArePyObjectsEqual(PyObject* lhs, PyObject* rhs) {
+    switch (PyObject_RichCompareBool(lhs, rhs, Py_EQ)) {
+      case -1:
+        PyErr_Clear();
+        return false;
+      case 0:
+        return false;
+      default:
+        return true;
+    }
+  }
+
+ private:
+  ScopedPyObject residue_;
+};
+
+struct Position3 {
+  float x;
+  float y;
+  float z;
+};
+
+float DistanceSquared(Position3 v1, Position3 v2) {
+  float dx = v1.x - v2.x;
+  float dy = v1.y - v2.y;
+  float dz = v1.z - v2.z;
+  return dx * dx + dy * dy + dz * dz;
+}
+
+class FixArginine : public ProcessResidue {
+ public:
+  FixArginine()
+      : ProcessResidue("ARG"),
+        cd_(PyUnicode_InternFromString("CD")),
+        nh1_(PyUnicode_InternFromString("NH1")),
+        nh2_(PyUnicode_InternFromString("NH2")),
+        hh11_(PyUnicode_InternFromString("HH11")),
+        hh21_(PyUnicode_InternFromString("HH21")),
+        hh12_(PyUnicode_InternFromString("HH12")),
+        hh22_(PyUnicode_InternFromString("HH22")) {}
+  void Fix(absl::Span<PyObject*> atom_ids, absl::Span<const float> atom_x,
+           absl::Span<const float> atom_y, absl::Span<const float> atom_z) {
+    std::ptrdiff_t cd_index = -1;
+    std::ptrdiff_t nh1_index = -1;
+    std::ptrdiff_t nh2_index = -1;
+    std::ptrdiff_t hh11_index = -1;
+    std::ptrdiff_t hh21_index = -1;
+    std::ptrdiff_t hh12_index = -1;
+    std::ptrdiff_t hh22_index = -1;
+    for (std::ptrdiff_t index = 0; index < atom_ids.size(); ++index) {
+      PyObject* atom_id = atom_ids[index];
+      if (cd_index == -1 && ArePyObjectsEqual(atom_id, cd_.get())) {
+        cd_index = index;
+      } else if (nh1_index == -1 && ArePyObjectsEqual(atom_id, nh1_.get())) {
+        nh1_index = index;
+      } else if (nh2_index == -1 && ArePyObjectsEqual(atom_id, nh2_.get())) {
+        nh2_index = index;
+      } else if (hh11_index == -1 && ArePyObjectsEqual(atom_id, hh11_.get())) {
+        hh11_index = index;
+      } else if (hh21_index == -1 && ArePyObjectsEqual(atom_id, hh21_.get())) {
+        hh21_index = index;
+      } else if (hh12_index == -1 && ArePyObjectsEqual(atom_id, hh12_.get())) {
+        hh12_index = index;
+      } else if (hh22_index == -1 && ArePyObjectsEqual(atom_id, hh22_.get())) {
+        hh22_index = index;
+      }
+    }
+    if (cd_index < 0 || nh1_index < 0 || nh2_index < 0) {
+      return;
+    }
+    Position3 cd_pos(atom_x[cd_index], atom_y[cd_index], atom_z[cd_index]);
+    Position3 nh1_pos(atom_x[nh1_index], atom_y[nh1_index], atom_z[nh1_index]);
+    Position3 nh2_pos(atom_x[nh2_index], atom_y[nh2_index], atom_z[nh2_index]);
+    if (DistanceSquared(nh1_pos, cd_pos) <= DistanceSquared(nh2_pos, cd_pos)) {
+      return;
+    }
+    std::swap(atom_ids[nh1_index], atom_ids[nh2_index]);
+    if (hh11_index >= 0 && hh21_index >= 0) {
+      std::swap(atom_ids[hh11_index], atom_ids[hh21_index]);
+    } else if (hh11_index >= 0) {
+      Py_DECREF(atom_ids[hh11_index]);
+      Py_INCREF(hh21_.get());
+      atom_ids[hh11_index] = hh21_.get();
+    } else if (hh21_index >= 0) {
+      Py_DECREF(atom_ids[hh21_index]);
+      Py_INCREF(hh11_.get());
+      atom_ids[hh21_index] = hh11_.get();
+    }
+    if (hh12_index >= 0 && hh22_index >= 0) {
+      std::swap(atom_ids[hh12_index], atom_ids[hh22_index]);
+    } else if (hh12_index >= 0) {
+      Py_DECREF(atom_ids[hh12_index]);
+      Py_INCREF(hh22_.get());
+      atom_ids[hh12_index] = hh22_.get();
+    } else if (hh22_index >= 0) {
+      Py_DECREF(atom_ids[hh22_index]);
+      // Renaming HH22 -> HH12 to mirror the HH21 -> HH11 branch above.
+      Py_INCREF(hh12_.get());
+      atom_ids[hh22_index] = hh12_.get();
+    }
+  }
+
+ private:
+  ScopedPyObject cd_;
+  ScopedPyObject nh1_;
+  ScopedPyObject nh2_;
+  ScopedPyObject hh11_;
+  ScopedPyObject hh21_;
+  ScopedPyObject hh12_;
+  ScopedPyObject hh22_;
+};
+
+// Returns the layout of the mmCIF `_atom_site` table.
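+// Python-side sketch (read_layout is registered below; the layout accessors
+// shown are assumptions based on the C++ MmcifLayout API):
+//
+//   layout = mmcif_utils.read_layout(mmcif, model_id='1')
+//   n_models, n_atoms = layout.num_models(), layout.num_atoms()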
+inline MmcifLayout ReadMmcifLayout(const CifDict& mmcif,
+                                   absl::string_view model_id = "") {
+  py::gil_scoped_release release;
+  auto mmcif_layout = MmcifLayout::Create(mmcif, model_id);
+  if (mmcif_layout.ok()) {
+    return *mmcif_layout;
+  }
+
+  throw py::value_error(std::string(mmcif_layout.status().message()));
+}
+
+std::pair<py::object, MmcifLayout> MmcifFilter(  //
+    const CifDict& mmcif,                        //
+    bool include_nucleotides,                    //
+    bool include_ligands,                        //
+    bool include_water,                          //
+    bool include_other,                          //
+    absl::string_view model_id) {
+  if (_import_array() < 0) {
+    throw py::import_error("Failed to import NumPy.");
+  }
+  auto layout = ReadMmcifLayout(mmcif, model_id);
+  std::unique_ptr<std::vector<std::uint64_t>> keep_indices;
+  size_t new_num_atoms;
+
+  {
+    py::gil_scoped_release release;
+
+    AtomSiteLoop atom_site(mmcif);
+
+    auto keep_chain_ids =
+        SelectChains(mmcif, include_nucleotides, include_ligands, include_water,
+                     include_other);
+
+    std::vector<std::size_t> chain_indices;
+    chain_indices.reserve(keep_chain_ids.size());
+    for (std::size_t i = 0; i < layout.num_chains(); ++i) {
+      if (keep_chain_ids.contains(
+              atom_site.chain_id[layout.atom_site_from_chain_index(i)])) {
+        chain_indices.push_back(i);
+      }
+    }
+
+    keep_indices =
+        absl::WrapUnique(new std::vector<std::uint64_t>(ResolveMmcifAltLocs(
+            layout, atom_site.comp_id, atom_site.atom_id, atom_site.alt_id,
+            atom_site.occupancy, chain_indices)));
+    new_num_atoms = keep_indices->size();
+
+    if (layout.num_models() > 1) {
+      keep_indices->reserve(layout.num_models() * new_num_atoms);
+      std::uint64_t* start = &(*keep_indices->begin());
+      std::size_t num_atom = keep_indices->size();
+      // Copy first model indices into all model indices offsetting each copy.
+      for (std::size_t i = 1; i < layout.num_models(); ++i) {
+        std::size_t offset = i * layout.num_atoms();
+        std::transform(start, start + num_atom,
+                       std::back_inserter(*keep_indices),
+                       [offset](std::size_t v) { return v + offset; });
+      }
+    }
+  }
+
+  layout.Filter(*keep_indices);
+
+  npy_intp shape[] = {static_cast<npy_intp>(layout.num_models()),
+                      static_cast<npy_intp>(new_num_atoms)};
+  PyObject* arr =
+      PyArray_SimpleNewFromData(2, shape, NPY_INT64, keep_indices->data());
+  // Create a capsule to hold the memory of the buffer so NumPy knows how to
+  // delete it when done with it.
+  PyObject* capsule = PyCapsule_New(
+      keep_indices.release(), nullptr, +[](PyObject* capsule_cleanup) {
+        void* memory = PyCapsule_GetPointer(capsule_cleanup, nullptr);
+        delete static_cast<std::vector<std::uint64_t>*>(memory);
+      });
+  PyArray_SetBaseObject(reinterpret_cast<PyArrayObject*>(arr), capsule);
+
+  return std::make_pair(py::reinterpret_steal<py::object>(arr),
+                        std::move(layout));
+}
+
+void MmcifFixResidues(                  //
+    const MmcifLayout& layout,          //
+    absl::Span<PyObject*> comp_id,      //
+    absl::Span<PyObject*> atom_id,      //
+    absl::Span<const float> atom_x,     //
+    absl::Span<const float> atom_y,     //
+    absl::Span<const float> atom_z,     //
+    bool fix_arginine                   //
+) {
+  std::optional<FixArginine> arginine;
+  std::size_t num_atoms = layout.num_atoms();
+  if (comp_id.size() != num_atoms || atom_id.size() != num_atoms ||
+      atom_x.size() != num_atoms || atom_y.size() != num_atoms ||
+      atom_z.size() != num_atoms) {
+    throw py::value_error(
+        absl::StrCat("Sizes must match. 
", // + "num_atoms=", num_atoms, ", ", // + "len(comp_id)=", comp_id.size(), ", ", // + "len(atom_id)=", atom_id.size(), ", ", // + "len(atom_x)=", atom_x.size(), ", ", // + "len(atom_y)=", atom_y.size(), ", ", // + "len(atom_z)=", atom_z.size())); + } + + if (fix_arginine) { + arginine.emplace(); + } + if (!arginine.has_value()) { + return; + } + + for (std::size_t res_index = 0; res_index < layout.num_residues(); + ++res_index) { + auto [atom_start, atom_end] = layout.atom_range(res_index); + std::size_t atom_count = atom_end - atom_start; + PyObject* resname = comp_id[atom_start]; + if (arginine.has_value() && arginine->IsResidue(resname)) { + arginine->Fix(atom_id.subspan(atom_start, atom_count), + atom_x.subspan(atom_start, atom_count), + atom_y.subspan(atom_start, atom_count), + atom_z.subspan(atom_start, atom_count)); + } + } +} + +std::vector SelectedPolymerResidueMask( + const MmcifLayout& layout, + const std::vector& atom_site_label_asym_ids, // + const std::vector& atom_site_label_seq_ids, // + const std::vector& atom_site_label_comp_ids, // + const std::vector& poly_seq_asym_ids, // + const std::vector& poly_seq_seq_ids, // + const std::vector& poly_seq_mon_ids // +) { + absl::flat_hash_map, + absl::string_view> + selected; + selected.reserve(layout.num_residues()); + // layout.residues() is O(1) while layout.residue_starts() is O(num_res). + const std::vector& residue_starts = layout.residue_starts(); + for (int i = 0; i < layout.residues().size(); ++i) { + std::size_t res_start = residue_starts[i]; + std::size_t res_end = layout.residues()[i]; + if (res_start == res_end) { + continue; // Skip empty residues (containing no atoms). + } + + absl::string_view label_seq_id = atom_site_label_seq_ids[i]; + if (label_seq_id == ".") { + continue; // Skip non-polymers. + } + + absl::string_view label_asym_id = atom_site_label_asym_ids[i]; + absl::string_view label_comp_id = atom_site_label_comp_ids[i]; + selected[std::make_pair(label_asym_id, label_seq_id)] = label_comp_id; + } + + std::vector mask; + mask.reserve(poly_seq_mon_ids.size()); + for (int i = 0; i < poly_seq_mon_ids.size(); ++i) { + absl::string_view poly_seq_asym_id = poly_seq_asym_ids[i]; + absl::string_view poly_seq_seq_id = poly_seq_seq_ids[i]; + absl::string_view poly_seq_mon_id = poly_seq_mon_ids[i]; + + auto it = selected.find(std::make_pair(poly_seq_asym_id, poly_seq_seq_id)); + if (it != selected.end()) { + mask.push_back(it->second == poly_seq_mon_id); + } else { + mask.push_back(true); // Missing residues are not heterogeneous. + } + } + return mask; +} + +std::pair, std::vector> SelectedLigandResidueMask( + const MmcifLayout& layout, // + const std::vector& atom_site_label_asym_ids, // + const std::vector& atom_site_label_seq_ids, // + const std::vector& atom_site_auth_seq_ids, // + const std::vector& atom_site_label_comp_ids, // + const std::vector& atom_site_pdbx_pdb_ins_codes, // + const std::vector& nonpoly_asym_ids, // + const std::vector& nonpoly_auth_seq_ids, // + const std::vector& nonpoly_pdb_ins_codes, // + const std::vector& nonpoly_mon_ids, // + const std::vector& branch_asym_ids, // + const std::vector& branch_auth_seq_ids, // + const std::vector& branch_pdb_ins_codes, // + const std::vector& branch_mon_ids) { + absl::flat_hash_map< + std::tuple, + absl::string_view> + selected; + selected.reserve(layout.num_residues()); + // layout.residues() is O(1) while layout.residue_starts() is O(num_res). 
+  const std::vector<std::size_t>& residue_starts = layout.residue_starts();
+  for (int i = 0; i < layout.residues().size(); ++i) {
+    std::size_t res_start = residue_starts[i];
+    std::size_t res_end = layout.residues()[i];
+    if (res_start == res_end) {
+      continue;  // Skip empty residues (containing no atoms).
+    }
+
+    absl::string_view label_seq_id = atom_site_label_seq_ids[i];
+    if (label_seq_id != ".") {
+      continue;  // Skip polymers.
+    }
+
+    absl::string_view label_asym_id = atom_site_label_asym_ids[i];
+    absl::string_view auth_seq_id = atom_site_auth_seq_ids[i];
+    absl::string_view ins_code = atom_site_pdbx_pdb_ins_codes[i];
+    ins_code = ins_code == "?" ? "." : ins_code;  // Remap unknown to unset.
+    absl::string_view label_comp_id = atom_site_label_comp_ids[i];
+    selected[std::make_tuple(label_asym_id, auth_seq_id, ins_code)] =
+        label_comp_id;
+  }
+
+  std::vector<bool> nonpoly_mask;
+  nonpoly_mask.reserve(nonpoly_asym_ids.size());
+  for (int i = 0; i < nonpoly_asym_ids.size(); ++i) {
+    absl::string_view nonpoly_asym_id = nonpoly_asym_ids[i];
+    absl::string_view nonpoly_auth_seq_id = nonpoly_auth_seq_ids[i];
+    absl::string_view nonpoly_ins_code = nonpoly_pdb_ins_codes[i];
+    // Remap unknown to unset.
+    nonpoly_ins_code = nonpoly_ins_code == "?" ? "." : nonpoly_ins_code;
+    absl::string_view nonpoly_mon_id = nonpoly_mon_ids[i];
+
+    auto it = selected.find(std::make_tuple(
+        nonpoly_asym_id, nonpoly_auth_seq_id, nonpoly_ins_code));
+    if (it != selected.end()) {
+      nonpoly_mask.push_back(it->second == nonpoly_mon_id);
+    } else {
+      nonpoly_mask.push_back(true);  // Missing residues are not heterogeneous.
+    }
+  }
+
+  std::vector<bool> branch_mask;
+  branch_mask.reserve(branch_asym_ids.size());
+  for (int i = 0; i < branch_asym_ids.size(); ++i) {
+    absl::string_view branch_asym_id = branch_asym_ids[i];
+    absl::string_view branch_auth_seq_id = branch_auth_seq_ids[i];
+
+    // Insertion codes in _pdbx_branch_scheme are not required and can be
+    // missing. Default to unset ('.') in such case.
+    absl::string_view branch_ins_code;
+    if (i < branch_pdb_ins_codes.size()) {
+      branch_ins_code = branch_pdb_ins_codes[i];
+      // Remap unknown to unset.
+      branch_ins_code = branch_ins_code == "?" ? "." : branch_ins_code;
+    } else {
+      branch_ins_code = ".";
+    }
+
+    absl::string_view branch_mon_id = branch_mon_ids[i];
+
+    auto it = selected.find(
+        std::make_tuple(branch_asym_id, branch_auth_seq_id, branch_ins_code));
+    if (it != selected.end()) {
+      branch_mask.push_back(it->second == branch_mon_id);
+    } else {
+      branch_mask.push_back(true);  // Missing residues are not heterogeneous.
+    }
+  }
+
+  return std::make_pair(nonpoly_mask, branch_mask);
+}
+
+constexpr char kReadMmcifLayout[] = R"(
+Returns the layout of the cif_dict.
+
+Args:
+  mmcif: mmCIF to calculate the layout for.
+  model_id: If non-empty the layout of the given model is returned
+    otherwise the layout of all models are returned.
+Raises:
+  ValueError: if the mmCIF is malformed or the number of atoms in each
+    model are inconsistent.
+)";
+
+constexpr char kMmcifFilter[] = R"(
+Returns NumpyArray of selected rows in `_atom_site` and new layout.
+
+Args:
+  mmcif: mmCIF to filter.
+  include_nucleotides: Whether to include polymer entities of type:
+    "polypeptide(L)", "polydeoxyribonucleotide", "polyribonucleotide",
+    "polydeoxyribonucleotide/polyribonucleotide hybrid". Otherwise only
+    "polypeptide(L)". ("polypeptide(D)" is never included.)
+  include_ligands: Whether to include non-polymer entities of type:
+    "non-polymer", "branched".
+  include_water: Whether to include entities of type water.
+  include_other: Whether to include other (non-standard) entity types
+    that are not covered by any of the above parameters.
+  model_id: If non-empty the model with given name is selected otherwise
+    all models are selected.
+
+Returns:
+  A tuple containing a numpy array with shape (num_models, num_atoms)
+  with the atom_site indices selected and the new layout.
+
+Raises:
+  ValueError: If the mmCIF dict does not have all required fields.
+)";
+
+constexpr char kMmcifFixResidues[] = R"(
+Fixes residue columns in-place.
+
+Args:
+  layout: layout from filter command.
+  comp_id: '_atom_site.label_comp_id' of first model.
+  atom_id: '_atom_site.label_atom_id' of first model.
+  atom_x: '_atom_site.Cartn_x' of first model.
+  atom_y: '_atom_site.Cartn_y' of first model.
+  atom_z: '_atom_site.Cartn_z' of first model.
+  fix_arg: Whether to ensure the atoms in ARG are in the correct order.
+
+Raises:
+  ValueError: If shapes are invalid.
+)";
+
+constexpr char kSelectedPolymerResidueMask[] = R"(
+Returns a _pdbx_poly_seq_scheme mask for selected hetero residues.
+
+Should be called after filtering the layout using mmcif_utils.filter.
+
+Args:
+  layout: Layout defining the _atom_site residue selection.
+  atom_site_label_asym_ids: Internal (label) chain ID, per selected residue.
+  atom_site_label_seq_ids: Internal (label) residue ID, per selected residue.
+  atom_site_label_comp_ids: Residue name, per selected residue.
+  poly_seq_asym_ids: Internal (label) chain ID, per residue.
+  poly_seq_seq_ids: Internal (label) residue ID, per residue.
+  poly_seq_mon_ids: Residue name, per residue.
+
+Returns:
+  A mask for the _pdbx_poly_seq_scheme table. If residues are selected
+  using this mask, they will have consistent heterogeneous residue
+  selection with the _atom_site table.
+)";
+
+constexpr char kSelectedLigandResidueMask[] = R"(
+Returns masks for selected ligand hetero residues.
+
+Should be called after filtering the layout using mmcif_utils.filter.
+
+Args:
+  layout: Layout defining the _atom_site residue selection.
+  atom_site_label_asym_ids: Internal (label) chain ID, per selected residue.
+  atom_site_label_seq_ids: Internal (label) residue ID, per selected residue.
+  atom_site_auth_seq_ids: External (author) residue ID, per selected residue.
+  atom_site_label_comp_ids: Residue name, per selected residue.
+  atom_site_pdbx_pdb_ins_codes: Insertion code, per selected residue.
+  nonpoly_asym_ids: Internal (label) chain ID, per residue from
+    _pdbx_nonpoly_scheme.
+  nonpoly_auth_seq_ids: External (author) residue ID, per residue from
+    _pdbx_nonpoly_scheme.
+  nonpoly_pdb_ins_codes: Insertion code, per residue from
+    _pdbx_nonpoly_scheme.
+  nonpoly_mon_ids: Residue name, per residue from _pdbx_nonpoly_scheme.
+  branch_asym_ids: Internal (label) chain ID, per residue from
+    _pdbx_branch_scheme.
+  branch_auth_seq_ids: External (author) residue ID, per residue from
+    _pdbx_branch_scheme.
+  branch_pdb_ins_codes: Insertion code, per residue from _pdbx_branch_scheme.
+  branch_mon_ids: Residue name, per residue from _pdbx_branch_scheme.
+
+Returns:
+  A tuple with masks for _pdbx_nonpoly_scheme and _pdbx_branch_scheme.
If + residues are selected using these masks, they will have consistent + heterogeneous residue selection with the _atom_site table. +)"; + +} // namespace + +void RegisterModuleMmcifUtils(pybind11::module m) { + m.def("read_layout", ReadMmcifLayout, + py::arg("mmcif"), // + py::arg("model_id") = "", // + py::doc(kReadMmcifLayout + 1) // + ); + + m.def("filter", MmcifFilter, // + py::arg("mmcif"), // + py::arg("include_nucleotides"), // + py::arg("include_ligands") = false, // + py::arg("include_water") = false, // + py::arg("include_other") = false, // + py::arg("model_id") = "", // + py::doc(kMmcifFilter + 1) // + ); + + m.def("fix_residues", MmcifFixResidues, + py::arg("layout"), // + py::arg("comp_id"), // + py::arg("atom_id"), // + py::arg("atom_x"), // + py::arg("atom_y"), // + py::arg("atom_z"), // + py::arg("fix_arg") = false, // + py::doc(kMmcifFixResidues + 1) // + ); + + m.def("selected_polymer_residue_mask", SelectedPolymerResidueMask, + py::arg("layout"), // + py::arg("atom_site_label_asym_ids"), // + py::arg("atom_site_label_seq_ids"), // + py::arg("atom_site_label_comp_ids"), // + py::arg("poly_seq_asym_ids"), // + py::arg("poly_seq_seq_ids"), // + py::arg("poly_seq_mon_ids"), // + py::call_guard(), // + py::doc(kSelectedPolymerResidueMask + 1) // + ); + + m.def("selected_ligand_residue_mask", SelectedLigandResidueMask, + py::arg("layout"), // + py::arg("atom_site_label_asym_ids"), // + py::arg("atom_site_label_seq_ids"), // + py::arg("atom_site_auth_seq_ids"), // + py::arg("atom_site_label_comp_ids"), // + py::arg("atom_site_pdbx_pdb_ins_codes"), // + py::arg("nonpoly_asym_ids"), // + py::arg("nonpoly_auth_seq_ids"), // + py::arg("nonpoly_pdb_ins_codes"), // + py::arg("nonpoly_mon_ids"), // + py::arg("branch_asym_ids"), // + py::arg("branch_auth_seq_ids"), // + py::arg("branch_pdb_ins_codes"), // + py::arg("branch_mon_ids"), // + py::call_guard(), // + py::doc(kSelectedLigandResidueMask + 1) // + ); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.h new file mode 100644 index 000000000..7ba19420b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/mmcif_utils_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of + * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ + * + * To request access to the AlphaFold 3 model parameters, follow the process set + * out at https://github.com/google-deepmind/alphafold3. You may only use these + * if received directly from Google. 
Use is subject to terms of use available at + * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + */ + +#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_UTILS_PYBIND_H_ +#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_UTILS_PYBIND_H_ + +#include "pybind11/pybind11.h" + +namespace alphafold3 { + +void RegisterModuleMmcifUtils(pybind11::module m); + +} + +#endif // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_MMCIF_UTILS_PYBIND_H_ diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array.pyi b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array.pyi new file mode 100644 index 000000000..b4b76c27f --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array.pyi @@ -0,0 +1,50 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from collections.abc import Sequence +from typing import Any, overload + +import numpy as np + + +def format_float_array( + values: Sequence[float], num_decimal_places: int +) -> list[str]: ... + + +def isin( + array: np.ndarray[object], + test_elements: set[str | bytes], + *, + invert: bool = ..., +) -> np.ndarray[bool]: ... + + +@overload +def remap( + array: np.ndarray[object], + mapping: dict[str, str], + default_value: str, + inplace: bool = ..., +) -> np.ndarray[object]: ... + + +@overload +def remap( + array: np.ndarray[object], + mapping: dict[str, str], + inplace: bool = ..., +) -> np.ndarray[object]: ... + + +def remap_multiple( + arrays: Sequence[np.ndarray[object]], + mapping: dict[tuple[Any], int], +) -> np.ndarray[int]: ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc new file mode 100644 index 000000000..29fac727a --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.cc @@ -0,0 +1,329 @@ +// Copyright 2024 DeepMind Technologies Limited +// +// AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +// this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +// +// To request access to the AlphaFold 3 model parameters, follow the process set +// out at https://github.com/google-deepmind/alphafold3. You may only use these +// if received directly from Google. 
Use is subject to terms of use available at +// https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "numpy/arrayobject.h" +#include "numpy/ndarrayobject.h" +#include "numpy/ndarraytypes.h" +#include "numpy/npy_common.h" +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "pybind11/cast.h" +#include "pybind11/numpy.h" +#include "pybind11/pybind11.h" +#include "pybind11/pytypes.h" +#include "pybind11_abseil/absl_casters.h" + +namespace { + +namespace py = pybind11; + +PyObject* RemapNumpyArrayObjects(PyObject* array, PyObject* mapping, + bool inplace, PyObject* default_value) { + import_array(); + if (!PyArray_Check(array)) { + PyErr_SetString(PyExc_TypeError, "'array' must be a np.ndarray."); + return nullptr; + } + if (!PyDict_Check(mapping)) { + PyErr_SetString(PyExc_TypeError, "'mapping' must be a Python dict."); + return nullptr; + } + + PyArrayObject* array_obj = reinterpret_cast(array); + if (PyArray_TYPE(array_obj) != NPY_OBJECT) { + PyErr_SetString(PyExc_TypeError, "`array` must be an array of objects."); + return nullptr; + } + + if (inplace) { + // We are returning original array so we need to increase the ref count. + Py_INCREF(array); + } else { + // We are returning a fresh copy. + array = PyArray_NewCopy(array_obj, NPY_CORDER); + if (array == nullptr) { + PyErr_SetString(PyExc_MemoryError, "Out of memory!"); + return nullptr; + } + array_obj = reinterpret_cast(array); + } + + if (PyArray_SIZE(array_obj) == 0) { + return array; + } + + if (default_value == nullptr && PyDict_Size(mapping) == 0) { + return array; + } + + NpyIter* iter = NpyIter_New( + array_obj, NPY_ITER_READWRITE | NPY_ITER_EXTERNAL_LOOP | NPY_ITER_REFS_OK, + NPY_KEEPORDER, NPY_NO_CASTING, nullptr); + if (iter == nullptr) { + PyErr_SetString(PyExc_MemoryError, "Out of memory!"); + Py_XDECREF(array); + return nullptr; + } + + NpyIter_IterNextFunc* iter_next = NpyIter_GetIterNext(iter, nullptr); + if (iter_next == nullptr) { + NpyIter_Deallocate(iter); + Py_XDECREF(array); + PyErr_SetString(PyExc_MemoryError, "Out of memory!"); + return nullptr; + } + + // Iterating arrays taken from: + // https://numpy.org/doc/stable/reference/c-api/iterator.html + char** data_pointer = NpyIter_GetDataPtrArray(iter); + npy_intp* stride_pointer = NpyIter_GetInnerStrideArray(iter); + npy_intp* inner_size_pointer = NpyIter_GetInnerLoopSizePtr(iter); + do { + char* data = *data_pointer; + npy_intp stride = *stride_pointer; + npy_intp count = *inner_size_pointer; + for (size_t i = 0; i < count; ++i) { + PyObject* entry; + std::memcpy(&entry, data, sizeof(PyObject*)); + PyObject* result = PyDict_GetItem(mapping, entry); + if (result != nullptr) { + // Replace entry. + Py_INCREF(result); + Py_XDECREF(entry); + std::memcpy(data, &result, sizeof(PyObject*)); + } else if (default_value != nullptr) { + // Replace entry with a default value. + Py_INCREF(default_value); + Py_XDECREF(entry); + std::memcpy(data, &default_value, sizeof(PyObject*)); + } + data += stride; + } + } while (iter_next(iter)); + + NpyIter_Deallocate(iter); + return array; +} + +// Convert 1D Numpy float array to a list of strings where each string has fixed +// number of decimal points. This is faster than Python list comprehension. 
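+// Python-side example (values illustrative):
+//
+//   string_array.format_float_array([1.07, 2.0], num_decimal_places=3)
+//   # -> ['1.070', '2.000']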
+std::vector FormatFloatArray(absl::Span values, + int num_decimal_places) { + std::vector output; + output.reserve(values.size()); + + absl::c_transform(values, std::back_inserter(output), + [num_decimal_places](float value) { + return absl::StrFormat("%.*f", num_decimal_places, value); + }); + return output; +} + +py::array_t IsIn( + const py::array_t& array, + const absl::flat_hash_set& test_elements, bool invert) { + const size_t num_elements = array.size(); + py::array_t output(num_elements); + std::fill(output.mutable_data(), output.mutable_data() + output.size(), + invert); + + // Shortcut: The output will be trivially always false if test_elements empty. + if (test_elements.empty()) { + return output; + } + + for (size_t i = 0; i < num_elements; ++i) { + // Compare the string values instead of comparing just object pointers. + py::handle handle = array.data()[i]; + if (!PyUnicode_Check(handle.ptr()) && !PyBytes_Check(handle.ptr())) { + continue; + } + if (test_elements.contains(py::cast(handle))) { + output.mutable_data()[i] = !invert; + } + } + if (array.ndim() > 1) { + auto shape = + std::vector(array.shape(), array.shape() + array.ndim()); + return output.reshape(shape); + } + return output; +} + +py::array RemapMultipleArrays( + const std::vector>& arrays, + const py::dict& mapping) { + size_t array_size = arrays[0].size(); + for (const auto& array : arrays) { + if (array.size() != array_size) { + throw py::value_error("All arrays must have the same length."); + } + } + + // Create a result buffer. + auto result = py::array_t(array_size); + absl::Span result_buffer(result.mutable_data(), array_size); + PyObject* entry = PyTuple_New(arrays.size()); + if (entry == nullptr) { + throw py::error_already_set(); + } + std::vector> array_spans; + array_spans.reserve(arrays.size()); + for (const auto& array : arrays) { + array_spans.emplace_back(array.data(), array.size()); + } + + // Iterate over arrays and look up elements in the `py_dict`. + bool fail = false; + for (size_t i = 0; i < array_size; ++i) { + for (size_t j = 0; j < array_spans.size(); ++j) { + PyTuple_SET_ITEM(entry, j, array_spans[j][i]); + } + PyObject* result = PyDict_GetItem(mapping.ptr(), entry); + if (result != nullptr) { + int64_t result_value = PyLong_AsLongLong(result); + if (result_value == -1 && PyErr_Occurred()) { + fail = true; + break; + } + if (result_value > std::numeric_limits::max() || + result_value < std::numeric_limits::lowest()) { + PyErr_SetString(PyExc_OverflowError, "Result value too large."); + fail = true; + break; + } + result_buffer[i] = result_value; + } else { + PyErr_Format(PyExc_KeyError, "%R", entry); + fail = true; + break; + } + } + + for (size_t j = 0; j < array_spans.size(); ++j) { + PyTuple_SET_ITEM(entry, j, nullptr); + } + Py_XDECREF(entry); + if (fail) { + throw py::error_already_set(); + } + return result; +} + +constexpr char kRemapNumpyArrayObjects[] = R"( +Replace objects in NumPy array of objects using mapping. + +Args: + array: NumPy array with dtype=object. + mapping: Dict mapping old values to new values. + inplace: Bool (default False) whether to replace values inplace or to + create a new array. + default_value: If given, what value to map to if the mapping is missing + for that particular item. If not given, such items are left unchanged. + +Returns + NumPy array of dtype object with values replaced according to mapping. + If inplace is True the original array is modified inplace otherwise a + new array is returned. 
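+
+Example (illustrative):
+
+  arr = np.array(['A', 'B', 'Q'], dtype=object)
+  string_array.remap(arr, {'Q': 'X'}, default_value='?')
+  # -> array(['?', '?', 'X'], dtype=object)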
+)"; + +constexpr char kFormatFloatArrayDoc[] = R"( +Converts float -> string array with given number of decimal places. +)"; + +constexpr char kIsInDoc[] = R"( +Computes whether each element is in test_elements. + +Same use as np.isin, but much faster. If len(array) = n, len(test_elements) = m: +* This function has complexity O(n). +* np.isin with arrays of objects has complexity O(m*log(m) + n * log(m)). + +Args: + array: Input NumPy array with dtype=object. + test_elements: The values against which to test each value of array. + invert: If True, the values in the returned array are inverted, as if + calculating `element not in test_elements`. Default is False. + `isin(a, b, invert=True)` is equivalent to but faster than `~isin(a, b)`. + +Returns + A boolean array of the same shape as the input array. Each value `val` is: + * `val in test_elements` if `invert=False`, + * `val not in test_elements` if `invert=True`. +)"; + +constexpr char kRemapMultipleDoc[] = R"( +Maps keys from multiple aligned arrays to a single array. + +Args: + arrays: Numpy arrays of the same length. The tuple of aligned entries is used + as key for the mapping. + mapping: Dict mapping from tuples to integer values. + +Returns + NumPy array of dtype `int` with values looked up in mapping according to the + tuple of aligned array entries as keys. +)"; + +} // namespace + +namespace alphafold3 { + +void RegisterModuleStringArray(pybind11::module m) { + m.def( + "remap", + [](py::object array, py::object mapping, bool inplace, + py::object default_value) -> py::object { + PyObject* result = RemapNumpyArrayObjects(array.ptr(), mapping.ptr(), + inplace, default_value.ptr()); + if (result == nullptr) { + throw py::error_already_set(); + } + return py::reinterpret_steal(result); + }, + py::return_value_policy::take_ownership, py::arg("array"), + py::arg("mapping"), py::arg("inplace") = false, py::arg("default_value"), + py::doc(kRemapNumpyArrayObjects + 1)); + m.def( + "remap", + [](py::object array, py::object mapping, bool inplace) -> py::object { + PyObject* result = RemapNumpyArrayObjects(array.ptr(), mapping.ptr(), + inplace, nullptr); + if (result == nullptr) { + throw py::error_already_set(); + } + return py::reinterpret_steal(result); + }, + py::return_value_policy::take_ownership, py::arg("array"), + py::arg("mapping"), py::arg("inplace") = false, + py::doc(kRemapNumpyArrayObjects + 1)); + m.def("format_float_array", &FormatFloatArray, py::arg("values"), + py::arg("num_decimal_places"), py::doc(kFormatFloatArrayDoc + 1), + py::call_guard()); + m.def("isin", &IsIn, py::arg("array"), py::arg("test_elements"), + py::kw_only(), py::arg("invert") = false, py::doc(kIsInDoc + 1)); + m.def("remap_multiple", &RemapMultipleArrays, py::arg("arrays"), + py::arg("mapping"), py::doc(kRemapMultipleDoc + 1)); +} + +} // namespace alphafold3 diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.h b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.h new file mode 100644 index 000000000..85790ddd8 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/cpp/string_array_pybind.h @@ -0,0 +1,24 @@ +/* + * Copyright 2024 DeepMind Technologies Limited + * + * AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of
+ * this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+ *
+ * To request access to the AlphaFold 3 model parameters, follow the process set
+ * out at https://github.com/google-deepmind/alphafold3. You may only use these
+ * if received directly from Google. Use is subject to terms of use available at
+ * https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+ */
+
+#ifndef ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_STRING_ARRAY_PYBIND_H_
+#define ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_STRING_ARRAY_PYBIND_H_
+
+#include "pybind11/pybind11.h"
+
+namespace alphafold3 {
+
+void RegisterModuleStringArray(pybind11::module m);
+
+}
+
+#endif  // ALPHAFOLD3_SRC_ALPHAFOLD3_STRUCTURE_PYTHON_STRING_ARRAY_PYBIND_H_
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py
new file mode 100644
index 000000000..d1b71c028
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/mmcif.py
@@ -0,0 +1,333 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Low level mmCIF parsing operations and wrappers for nicer C++/Py errors.
+
+Note that the cif_dict.CifDict class has many useful methods to help with data
+extraction which are not shown in this file. You can find them in cif_dict.clif
+together with docstrings. The cif_dict.CifDict class behaves like an immutable
+Python dictionary (some methods are not implemented though).
+"""
+from collections.abc import Callable, Mapping, Sequence
+import functools
+import itertools
+import re
+from typing import ParamSpec, TypeAlias, TypeVar
+
+from alphafold3.constants import chemical_components
+from alphafold3.cpp import cif_dict
+from alphafold3.cpp import mmcif_atom_site
+from alphafold3.cpp import mmcif_struct_conn
+from alphafold3.cpp import string_array
+import numpy as np
+
+Mmcif = cif_dict.CifDict
+
+
+_P = ParamSpec('_P')
+_T = TypeVar('_T')
+_WrappedFn: TypeAlias = Callable[_P, _T]
+
+
+@functools.lru_cache(maxsize=256)
+def int_id_to_str_id(num: int) -> str:
+  """Encodes a number as a string, using reverse spreadsheet style naming.
+
+  Args:
+    num: A positive integer.
+
+  Returns:
+    A string that encodes the positive integer using reverse spreadsheet style
+    naming, e.g. 1 = A, 2 = B, ..., 27 = AA, 28 = BA, 29 = CA, ... This is the
+    usual way to encode chain IDs in mmCIF files.
+  """
+  if num <= 0:
+    raise ValueError(f'Only positive integers allowed, got {num}.')
+
+  num = num - 1  # 1-based indexing.
+  output = []
+  while num >= 0:
+    output.append(chr(num % 26 + ord('A')))
+    num = num // 26 - 1
+  return ''.join(output)
+
+
+@functools.lru_cache(maxsize=256)
+def str_id_to_int_id(str_id: str) -> int:
+  """Decodes an mmCIF-style string chain ID back to an integer.
+
+  The integer IDs are one based so this function is the inverse of
+  int_id_to_str_id.
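+
+  For example, 'A' -> 1, 'Z' -> 26 and 'AA' -> 27, matching
+  int_id_to_str_id(27) == 'AA'.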
+ + Args: + str_id: A string chain ID consisting only of upper case letters A-Z. + + Returns: + An integer that can be used to order mmCIF chain IDs in the standard + (reverse spreadsheet style) ordering. + """ + if not re.match('^[A-Z]+$', str_id): + raise ValueError( + f'String ID must be upper case letters, got {str_id}.') + + offset = ord('A') - 1 + output = 0 + for i, c in enumerate(str_id): + output += (ord(c) - offset) * int(26**i) + return output + + +def from_string(mmcif_string: str | bytes) -> Mmcif: + return cif_dict.from_string(mmcif_string) + + +def parse_multi_data_cif(cif_string: str) -> dict[str, Mmcif]: + """Parses a CIF string with multiple data records. + + For instance, the CIF string: + + ``` + data_001 + _foo bar + # + data_002 + _foo baz + ``` + + is parsed as: + + ``` + {'001': Mmcif({'_foo': ['bar']}), '002': Mmcif({'_foo': ['baz']})} + ``` + + Args: + cif_string: The multi-data CIF string to be parsed. + + Returns: + A dictionary mapping record names to Mmcif objects with data. + """ + return cif_dict.parse_multi_data_cif(cif_string) + + +def tokenize(mmcif_string: str) -> list[str]: + return cif_dict.tokenize(mmcif_string) + + +def split_line(line: str) -> list[str]: + return cif_dict.split_line(line) + + +class BondParsingError(Exception): + """Exception raised by errors when getting bond atom indices.""" + + +def get_bond_atom_indices( + mmcif: Mmcif, + model_id: str = '1', +) -> tuple[Sequence[int], Sequence[int]]: + """Extracts the indices of the atoms that participate in bonds. + + Args: + mmcif: The mmCIF object to process. + model_id: The ID of the model that the returned atoms will belong to. This + should be a value in the mmCIF's _atom_site.pdbx_PDB_model_num column. + + Returns: + Two lists of atom indices, `from_atoms` and `to_atoms`, each one having + length num_bonds (as defined by _struct_conn, the bonds table). The bond + i, defined by the i'th row in _struct_conn, is a bond from atom at index + from_atoms[i], to the atom at index to_atoms[i]. The indices are simple + 0-based indexes into the columns of the _atom_site table in the input + mmCIF, and do not necessarily correspond to the values in _atom_site.id, + or any other column. + + Raises: + BondParsingError: If any of the required tables or columns are not present + in + the mmCIF, or if the _struct_conn table refers to atoms that cannot + be found in the _atom_site table. + """ + try: + return mmcif_struct_conn.get_bond_atom_indices(mmcif, model_id) + except ValueError as e: + raise BondParsingError(str(e)) from e + + +def get_or_infer_type_symbol( + mmcif: Mmcif, ccd: chemical_components.Ccd | None = None +) -> Sequence[str]: + """Returns the type symbol (element) for all of the atoms. + + Args: + mmcif: A parsed mmCIF file in the Mmcif format. + ccd: The chemical component dictionary. If not provided, defaults to the + cached CCD. + + If present, returns the _atom_site.type_symbol. If not, infers it using + _atom_site.label_comp_id (residue name), _atom_site.label_atom_id (atom name) + and the CCD. + """ + ccd = ccd or chemical_components.cached_ccd() + + def type_symbol_fn(res_name, atom_name): return chemical_components.type_symbol( + ccd, res_name, atom_name + ) + return mmcif_atom_site.get_or_infer_type_symbol(mmcif, type_symbol_fn) + + +def get_chain_type_by_entity_id(mmcif: Mmcif) -> Mapping[str, str]: + """Returns mapping from entity ID to its type or polymer type if available. + + If the entity is in the _entity_poly table, returns its polymer chain type. 
If not, returns the type as specified in the _entity table.
+
+  Args:
+    mmcif: CifDict holding the mmCIF.
+  """
+  poly_entity_id = mmcif.get('_entity_poly.entity_id', [])
+  poly_type = mmcif.get('_entity_poly.type', [])
+  poly_type_by_entity_id = dict(zip(poly_entity_id, poly_type, strict=True))
+
+  chain_type_by_entity_id = {}
+  for entity_id, entity_type in zip(
+      mmcif.get('_entity.id', []), mmcif.get('_entity.type', []), strict=True
+  ):
+    chain_type = poly_type_by_entity_id.get(entity_id) or entity_type
+    chain_type_by_entity_id[entity_id] = chain_type
+
+  return chain_type_by_entity_id
+
+
+def get_internal_to_author_chain_id_map(mmcif: Mmcif) -> Mapping[str, str]:
+  """Returns a mapping from internal chain ID to the author chain ID.
+
+  Note that this is not a bijection. One author chain ID can map to multiple
+  internal chain IDs. For example, a protein chain and a ligand bound to it
+  will share the same author chain ID, but they will each have a unique
+  internal chain ID.
+
+  Args:
+    mmcif: CifDict holding the mmCIF.
+  """
+  return mmcif_atom_site.get_internal_to_author_chain_id_map(mmcif)
+
+
+def get_experimental_method(mmcif: Mmcif) -> str | None:
+  field = '_exptl.method'
+  return ','.join(mmcif[field]).lower() if field in mmcif else None
+
+
+def get_release_date(mmcif: Mmcif) -> str | None:
+  """Returns the oldest revision date."""
+  if '_pdbx_audit_revision_history.revision_date' not in mmcif:
+    return None
+
+  # Release dates are ISO-8601, hence sort well.
+  return min(mmcif['_pdbx_audit_revision_history.revision_date'])
+
+
+def get_resolution(mmcif: Mmcif) -> float | None:
+  """Returns the resolution of the structure.
+
+  More than one resolution can be reported in an mmCIF. This function returns
+  the first one (in the order _refine.ls_d_res_high,
+  _em_3d_reconstruction.resolution, _reflns.d_resolution_high) that appears
+  in the mmCIF and is parseable as a float.
+
+  Args:
+    mmcif: An `Mmcif` object.
+
+  Returns:
+    The resolution as reported in the mmCIF.
+  """
+  for res_key in ('_refine.ls_d_res_high',
+                  '_em_3d_reconstruction.resolution',
+                  '_reflns.d_resolution_high'):
+    if res_key in mmcif:
+      try:
+        raw_resolution = mmcif[res_key][0]
+        return float(raw_resolution)
+      except ValueError:
+        continue
+  return None
+
+
+def parse_oper_expr(oper_expression: str) -> list[tuple[str, ...]]:
+  """Determines which transforms to apply based on an mmCIF oper_expression str.
+
+  Args:
+    oper_expression: the field oper_expression from mmCIF format data.
+      Transform ids may be either numbers or single letters. Hyphens are used to
+      denote a numeric range of transforms to apply, and commas are used to
+      delimit a sequence of transforms. Where two sets of parentheses are
+      adjacent without a comma, the two sets of transforms should be combined as
+      a cartesian product, i.e. all possible pairs.
+      example 1,2,3 -> generate 3 copies of each chain by applying 1, 2 or 3.
+      example (1-3) -> generate 3 copies of each chain by applying 1, 2 or 3.
+      example (1-3)(4-6) -> generate 9 copies of each chain by applying one of
+        [(1,4), (1,5), (1,6),
+         (2,4), (2,5), (2,6),
+         (3,4), (3,5), (3,6)]
+      example (P) -> apply transform with id P.
+
+  Raises:
+    ValueError: Failure to parse oper_expression.
+
+  Returns:
+    A list with one element for each chain copy that should be generated.
+    Each element is a tuple of transform ids to apply.
+  """
+  # Expand ranges, e.g. 1-4 -> 1,2,3,4.
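+  # For instance, '(1-3)(4-6)' becomes '(1,2,3)(4,5,6)' here, which the
+  # cartesian product branch below expands into the 9 transform pairs.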
+ def range_expander(match): + return ','.join( + [str(i) for i in range(int(match.group(1)), + int(match.group(2)) + 1)]) + + ranges_expanded = re.sub(r'\b(\d+)-(\d+)', range_expander, oper_expression) + + if re.fullmatch(r'(\w+,)*\w+', ranges_expanded): + # No brackets, just a single range, e.g. "1,2,3". + return [(t,) for t in ranges_expanded.split(',')] + elif re.fullmatch(r'\((\w+,)*\w+\)', ranges_expanded): + # Single range in brackets, e.g. "(1,2,3)". + return [(t,) for t in ranges_expanded[1:-1].split(',')] + elif re.fullmatch(r'\((\w+,)*\w+\)\((\w+,)*\w+\)', ranges_expanded): + # Cartesian product of two ranges, e.g. "(1,2,3)(4,5)". + part1, part2 = ranges_expanded[1:-1].split(')(') + return list(itertools.product(part1.split(','), part2.split(','))) + else: + raise ValueError( + f'Unsupported oper_expression format: {oper_expression}') + + +def format_float_array( + values: np.ndarray, num_decimal_places: int) -> Sequence[str]: + """Converts 1D array to a list of strings with the given number of decimals. + + This function is faster than converting via Python list comprehension, e.g.: + atoms_x = ['%.3f' % x for x in atoms_x] + + Args: + values: A numpy array with values to convert. This array is casted to + float32 before doing the conversion. + num_decimal_places: The number of decimal points to keep, including trailing + zeros. E.g. for 1.07 and num_decimal_places=1: 1.1, + num_decimal_places=2: 1.07, num_decimal_places=3: 1.070. + + Returns: + A list of formatted strings. + """ + if values.ndim != 1: + raise ValueError(f'The given array must be 1D, got {values.ndim}D') + + return string_array.format_float_array( + values=values.astype(np.float32), num_decimal_places=num_decimal_places + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py new file mode 100644 index 000000000..f466cf6a1 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/parsing.py @@ -0,0 +1,1806 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Module for parsing various data sources and producing Structures.""" + +from collections.abc import Collection, Mapping, MutableMapping, Sequence +import dataclasses +import datetime +import enum +import functools +import itertools +from typing import TypeAlias + +from alphafold3.constants import chemical_components +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.cpp import mmcif_utils +from alphafold3.cpp import string_array +from alphafold3.structure import bioassemblies +from alphafold3.structure import bonds +from alphafold3.structure import chemical_components as struc_chem_comps +from alphafold3.structure import mmcif +from alphafold3.structure import structure +from alphafold3.structure import structure_tables +import numpy as np + + +ChainIndex: TypeAlias = int +ResIndex: TypeAlias = int +AtomName: TypeAlias = str +BondAtomId: TypeAlias = tuple[ChainIndex, ResIndex, AtomName] + +_INSERTION_CODE_REMAP: Mapping[str, str] = {'.': '?'} + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class BondIndices: + from_indices: list[int] + dest_indices: list[int] + + +@enum.unique +class ModelID(enum.Enum): + """Values for specifying model IDs when parsing.""" + + FIRST = 1 # The first model in the file. + ALL = 2 # All models in the file. + + +@enum.unique +class SequenceFormat(enum.Enum): + """The possible formats for an input sequence.""" + + FASTA = 'fasta' # One-letter code used in FASTA. + # Multiple-letter chemical components dictionary ids. + CCD_CODES = 'ccd_codes' + LIGAND_SMILES = 'ligand_smiles' # SMILES string defining a molecule. + + +def _create_bond_lookup( + bonded_atom_pairs: Sequence[tuple[BondAtomId, BondAtomId]], +) -> Mapping[tuple[ChainIndex, ResIndex], Mapping[AtomName, BondIndices]]: + """Creates maps to help find bonds during a loop over residues.""" + bond_lookup = {} + for bond_i, (from_atom_id, dest_atom_id) in enumerate(bonded_atom_pairs): + from_chain_i, from_res_i, from_atom_name = from_atom_id + dest_chain_i, dest_res_i, dest_atom_name = dest_atom_id + bonds_by_from_atom_name = bond_lookup.setdefault( + (from_chain_i, from_res_i), {} + ) + bonds_by_dest_atom_name = bond_lookup.setdefault( + (dest_chain_i, dest_res_i), {} + ) + bonds_by_from_atom_name.setdefault( + from_atom_name, BondIndices(from_indices=[], dest_indices=[]) + ).from_indices.append(bond_i) + bonds_by_dest_atom_name.setdefault( + dest_atom_name, BondIndices(from_indices=[], dest_indices=[]) + ).dest_indices.append(bond_i) + return bond_lookup + + +def _get_atom_element( + ccd: chemical_components.Ccd, res_name: str, atom_name: str +) -> str: + return ( + chemical_components.type_symbol( + ccd=ccd, res_name=res_name, atom_name=atom_name + ) + or '?' + ) + + +def _get_representative_atom( + ccd: chemical_components.Ccd, + res_name: str, + chain_type: str, + sequence_format: SequenceFormat, +) -> tuple[str, str]: + match sequence_format: + case SequenceFormat.CCD_CODES: + atom_name = _get_first_non_leaving_atom(ccd=ccd, res_name=res_name) + atom_element = _get_atom_element( + ccd=ccd, res_name=res_name, atom_name=atom_name + ) + return atom_name, atom_element + case SequenceFormat.LIGAND_SMILES: + return '', '?' 
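+      # A SMILES-defined ligand has no CCD entry to pick a representative atom
+      # from, so an empty atom name and an unknown element ('?') serve as
+      # placeholders.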
+ case SequenceFormat.FASTA: + if chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES: + return 'CA', 'C' + if chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES: + return "C1'", 'C' + else: + raise ValueError(chain_type) + case _: + raise ValueError(sequence_format) + + +@functools.lru_cache(maxsize=128) +def _get_first_non_leaving_atom( + ccd: chemical_components.Ccd, res_name: str +) -> str: + """Returns first definitely non-leaving atom if exists, as a stand-in.""" + all_atoms = struc_chem_comps.get_all_atoms_in_entry(ccd, res_name=res_name)[ + '_chem_comp_atom.atom_id' + ] + representative_atom = all_atoms[0] + if representative_atom == 'O1' and len(all_atoms) > 1: + representative_atom = all_atoms[1] + return representative_atom + + +def _add_ligand_to_chem_comp( + chem_comp: MutableMapping[str, struc_chem_comps.ChemCompEntry], + ligand_id: str, + ligand_smiles: str, +): + """Adds a ligand to chemical components. Raises ValueError on mismatch.""" + new_entry = struc_chem_comps.ChemCompEntry( + type='non-polymer', pdbx_smiles=ligand_smiles + ) + + existing_entry = chem_comp.get(ligand_id) + if existing_entry is None: + chem_comp[ligand_id] = new_entry + elif existing_entry != new_entry: + raise ValueError( + f'Mismatching data for ligand {ligand_id}: ' + f'{new_entry} != {existing_entry}' + ) + + +def _get_first_model_id(cif: mmcif.Mmcif) -> str: + """Returns cheaply the first model ID from the mmCIF.""" + return cif.get_array( + '_atom_site.pdbx_PDB_model_num', dtype=object, gather=slice(1) + )[0] + + +def _get_str_model_id( + cif: mmcif.Mmcif, + model_id: ModelID | int, +) -> str: + """Converts a user-specified model_id argument into a string.""" + match model_id: + case int(): + str_model_id = str(model_id) + case enum.Enum(): + # We compare the enum's value attribute since regular enum comparison + # breaks when adhoc importing. + match model_id.value: + case ModelID.FIRST.value: + str_model_id = _get_first_model_id(cif) + case ModelID.ALL.value: + str_model_id = '' + case _: + raise ValueError( + f'Model ID {model_id} with value {model_id.value} not recognized.' + ) + case _: + raise ValueError( + f'Model ID {model_id} with type {type(model_id)} not recognized.' + ) + return str_model_id + + +def _parse_bonds( + cif: mmcif.Mmcif, + atom_key: np.ndarray, + model_id: str, +) -> bonds.Bonds: + """Returns the bonds table extracted from the mmCIF. + + Args: + cif: The raw mmCIF to extract the bond information from. + atom_key: A numpy array defining atom key for each atom in _atom_site. Note + that the atom key must be computed before resolving alt-locs since this + function operates on the raw mmCIF! + model_id: The ID of the model to get bonds for. + """ + if '_struct_conn.id' not in cif: + # This is the category key item for the _struct_conn table, therefore + # we use it to determine whether to parse bond info. + return bonds.Bonds.make_empty() + from_atom, dest_atom = mmcif.get_bond_atom_indices(cif, model_id) + from_atom = np.array(from_atom, dtype=np.int64) + dest_atom = np.array(dest_atom, dtype=np.int64) + num_bonds = from_atom.shape[0] + bond_key = np.arange(num_bonds, dtype=np.int64) + bond_type = cif.get_array('_struct_conn.conn_type_id', dtype=object) + if '_struct_conn.pdbx_role' in cif: # This column isn't always present. 
+ bond_role = cif.get_array('_struct_conn.pdbx_role', dtype=object) + else: + bond_role = np.full((num_bonds,), '?', dtype=object) + + bonds_mask = np.ones((num_bonds,), dtype=bool) + # Symmetries other than 1_555 imply the atom is not part of the asymmetric + # unit, and therefore this is a bond that only exists in the expanded + # bioassembly. + # We do not currently support parsing these types of bonds. + if '_struct_conn.ptnr1_symmetry' in cif: + ptnr1_symmetry = cif.get_array( + '_struct_conn.ptnr1_symmetry', dtype=object) + np.logical_and(bonds_mask, ptnr1_symmetry == '1_555', out=bonds_mask) + if '_struct_conn.ptnr2_symmetry' in cif: + ptnr2_symmetry = cif.get_array( + '_struct_conn.ptnr2_symmetry', dtype=object) + np.logical_and(bonds_mask, ptnr2_symmetry == '1_555', out=bonds_mask) + # Remove bonds that involve atoms that are not part of the structure, + # e.g. waters if include_water=False. In a rare case this also removes invalid + # bonds that are indicated by a key that is set to _atom_site size. + np.logical_and(bonds_mask, np.isin(from_atom, atom_key), out=bonds_mask) + np.logical_and(bonds_mask, np.isin(dest_atom, atom_key), out=bonds_mask) + return bonds.Bonds( + key=bond_key[bonds_mask], + type=bond_type[bonds_mask], + role=bond_role[bonds_mask], + from_atom_key=from_atom[bonds_mask], + dest_atom_key=dest_atom[bonds_mask], + ) + + +@dataclasses.dataclass(frozen=True, slots=True) +class _MmcifHeader: + name: str + resolution: float | None + release_date: datetime.date | None + structure_method: str | None + bioassembly_data: bioassemblies.BioassemblyData | None + chemical_components_data: struc_chem_comps.ChemicalComponentsData | None + + +def _get_mmcif_header( + cif: mmcif.Mmcif, + fix_mse: bool, + fix_unknown_dna: bool, +) -> _MmcifHeader: + """Extract header fields from an mmCIF object.""" + name = cif.get_data_name() + resolution = mmcif.get_resolution(cif) + + release_date = mmcif.get_release_date(cif) + if release_date is not None: + release_date = datetime.date.fromisoformat(release_date) + + experiments = cif.get('_exptl.method') + structure_method = ','.join(experiments) if experiments else None + + try: + bioassembly_data = bioassemblies.BioassemblyData.from_mmcif(cif) + except bioassemblies.MissingBioassemblyDataError: + bioassembly_data = None + + try: + chemical_components_data = ( + struc_chem_comps.ChemicalComponentsData.from_mmcif( + cif, fix_mse=fix_mse, fix_unknown_dna=fix_unknown_dna + ) + ) + except struc_chem_comps.MissingChemicalComponentsDataError: + chemical_components_data = None + + return _MmcifHeader( + name=name, + resolution=resolution, + release_date=release_date, + structure_method=structure_method, + bioassembly_data=bioassembly_data, + chemical_components_data=chemical_components_data, + ) + + +def from_parsed_mmcif( + mmcif_object: mmcif.Mmcif, + *, + name: str | None = None, + fix_mse_residues: bool = False, + fix_arginines: bool = False, + fix_unknown_dna: bool = False, + include_water: bool = False, + include_other: bool = False, + include_bonds: bool = False, + model_id: int | ModelID = ModelID.FIRST, +) -> structure.Structure: + """Construct a Structure from a parsed mmCIF object. + + This function is called by `from_mmcif` but can be useful when an mmCIF has + already been parsed e.g. to extract extra information from the header before + then converting to Structure for further manipulation. + + Args: + mmcif_object: A parsed mmcif.Mmcif object. + name: Optional name for the structure. 
If not provided, the name will be + taken from the mmCIF data_ field. + fix_mse_residues: If True, selenium atom sites (SE) in selenomethionine + (MSE) residues will be changed to sulphur atom sites (SD). This is because + methionine (MET) residues are often replaced with MSE to aid X-Ray + crystallography. If False, the SE MSE atom sites won't be modified. + fix_arginines: If True, NH1 and NH2 in arginine will be swapped if needed so + that NH1 is always closer to CD than NH2. If False, no atom sites in + arginine will be touched. Note that HH11, HH12, HH21, HH22 are fixed too. + fix_unknown_dna: If True, residues with name N in DNA chains will have their + res_name replaced with DN. Atoms are not changed. + include_water: If True, water (HOH) molecules will be parsed. Water + molecules may be grouped into chains, where number of residues > 1. Water + molecules are usually grouped into chains but do not necessarily all share + the same chain ID. + include_other: If True, all other atoms that are not included by any of the + above parameters will be included. This covers e.g. "polypeptide(D)" and + "macrolide" entities, as well as all other non-standard types. + include_bonds: If True, bond information will be parsed from the mmCIF and + stored in the Structure. + model_id: Either the integer model ID to parse, or one of ModelID.FIRST to + parse the first model, or ModelID.ALL to parse all models. + + Returns: + A Structure representation of the mmCIF object. + """ + str_model_id = _get_str_model_id(cif=mmcif_object, model_id=model_id) + header = _get_mmcif_header( + mmcif_object, fix_mse=fix_mse_residues, fix_unknown_dna=fix_unknown_dna + ) + + chains, residues, atoms = get_tables( + cif=mmcif_object, + fix_mse_residues=fix_mse_residues, + fix_arginines=fix_arginines, + fix_unknown_dna=fix_unknown_dna, + include_water=include_water, + include_other=include_other, + model_id=str_model_id, + ) + + if include_bonds: + # NB: parsing the atom table before the bonds table allows for a more + # informative error message when dealing with bad multi-model mmCIFs. + # We also ensure that we always use a specific model ID, even when parsing + # all models. + if str_model_id == '': # pylint: disable=g-explicit-bool-comparison + bonds_model_id = _get_first_model_id(mmcif_object) + else: + bonds_model_id = str_model_id + + bonds_table = _parse_bonds( + mmcif_object, + atom_key=atoms.key, + model_id=bonds_model_id, + ) + else: + bonds_table = bonds.Bonds.make_empty() + + return structure.Structure( + name=name if name is not None else header.name, + resolution=header.resolution, + release_date=header.release_date, + structure_method=header.structure_method, + bioassembly_data=header.bioassembly_data, + chemical_components_data=header.chemical_components_data, + bonds=bonds_table, + chains=chains, + residues=residues, + atoms=atoms, + ) + + +def from_mmcif( + mmcif_string: str | bytes, + *, + name: str | None = None, + fix_mse_residues: bool = False, + fix_arginines: bool = False, + fix_unknown_dna: bool = False, + include_water: bool = False, + include_other: bool = False, + include_bonds: bool = False, + model_id: int | ModelID = ModelID.FIRST, +) -> structure.Structure: + """Construct a Structure from a mmCIF string. + + Args: + mmcif_string: The string contents of an mmCIF file. + name: Optional name for the structure. If not provided, the name will be + taken from the mmCIF data_ field. 
+    fix_mse_residues: If True, selenium atom sites (SE) in selenomethionine
+      (MSE) residues will be changed to sulphur atom sites (SD). This is because
+      methionine (MET) residues are often replaced with MSE to aid X-Ray
+      crystallography. If False, the SE MSE atom sites won't be modified.
+    fix_arginines: If True, NH1 and NH2 in arginine will be swapped if needed so
+      that NH1 is always closer to CD than NH2. If False, no atom sites in
+      arginine will be touched. Note that HH11, HH12, HH21, HH22 are fixed too.
+    fix_unknown_dna: If True, residues with name N in DNA chains will have their
+      res_name replaced with DN. Atoms are not changed.
+    include_water: If True, water (HOH) molecules will be parsed. Water
+      molecules may be grouped into chains, where number of residues > 1. Water
+      molecules are usually grouped into chains but do not necessarily all share
+      the same chain ID.
+    include_other: If True, all other atoms that are not included by any of the
+      above parameters will be included. This covers e.g. "polypeptide(D)" and
+      "macrolide" entities, as well as all other non-standard types.
+    include_bonds: If True, bond information will be parsed from the mmCIF and
+      stored in the Structure.
+    model_id: Either the integer model ID to parse, or one of ModelID.FIRST to
+      parse the first model, or ModelID.ALL to parse all models.
+
+  Returns:
+    A Structure representation of the mmCIF string.
+  """
+  mmcif_object = mmcif.from_string(mmcif_string)
+
+  return from_parsed_mmcif(
+      mmcif_object,
+      name=name,
+      fix_mse_residues=fix_mse_residues,
+      fix_arginines=fix_arginines,
+      fix_unknown_dna=fix_unknown_dna,
+      include_water=include_water,
+      include_other=include_other,
+      include_bonds=include_bonds,
+      model_id=model_id,
+  )
+
+
+def from_res_arrays(atom_mask: np.ndarray, **kwargs) -> structure.Structure:
+  """Returns a Structure created from arrays with a residue dimension.
+
+  All unset fields are filled with defaults (e.g. 1.0 for occupancy) or
+  unset/unknown values (e.g. UNK for residue type, or '.' for atom element).
+
+  Args:
+    atom_mask: An array with shape (num_res, num_atom). This is used to decide
+      which atoms in the atom dimension are present in a given residue. Present
+      atoms should have a nonzero value, e.g. 1.0 or True.
+    **kwargs: A mapping from field name to values. For all array-valued fields
+      these arrays must have a dimension of length num_res. Chain and residue
+      fields should have this as their only dimension and atom fields should be
+      shaped (num_res, num_atom). Coordinate fields may also have arbitrary
+      leading dimensions (they must be the same across all coordinate fields).
+      See structure.{CHAIN,RESIDUE,ATOM}_FIELDS for a list of allowed fields.
+  """
+  num_res, num_atom = atom_mask.shape
+  included_indices = np.flatnonzero(atom_mask)
+
+  array_fields = (
+      structure.CHAIN_FIELDS.keys()
+      | structure.RESIDUE_FIELDS.keys()
+      | structure.ATOM_FIELDS.keys()
+  )
+  initializer_kwargs = {}
+  fields = {}
+  for k, val in kwargs.items():
+    if k not in array_fields:
+      # The kwarg key isn't an array field name. Such kwargs are forwarded as-is
+      # to the constructor. They are expected to be global fields (e.g. name).
+      # Other values will raise an error when the constructor is called.
+      if k in structure.TABLE_FIELDS:
+        raise ValueError(f'Table fields must not be set. Got {k}.')
+      initializer_kwargs[k] = val
+      continue
+    elif val is None:
+      raise ValueError(f'{k} must be non-None.')
+
+    if not isinstance(val, np.ndarray):
+      raise TypeError(
+          f'Value for {k} must be a NumPy array. Got {type(val)}.')
+    if k in structure.CHAIN_FIELDS or k in structure.RESIDUE_FIELDS:
+      if val.shape != (num_res,):
+        raise ValueError(
+            f'{k} must have shape ({num_res=},). Got {val.shape=}.'
+        )
+      # Do not reshape the chain/residue arrays, they have the shape we need.
+      fields[k] = val
+    else:
+      assert k in structure.ATOM_FIELDS
+      if val.shape[-2:] != (num_res, num_atom):
+        raise ValueError(
+            f'{k} must have final two dimensions of length '
+            f'{(num_res, num_atom)=}. Got {val.shape=}.'
+        )
+      leading_dims = val.shape[:-2]
+      flat_val = val.reshape(leading_dims + (-1,), order='C')
+      masked_val = flat_val[..., included_indices]
+      fields[k] = masked_val
+
+  # Get chain IDs or assume this is a single-chain structure.
+  chain_id = kwargs.get('chain_id', np.array(['A'] * num_res, dtype=object))
+  # Find chain starts in res-sized arrays, use these to make chain-sized arrays.
+  chain_start = np.concatenate(
+      ([0], np.where(chain_id[1:] != chain_id[:-1])[0] + 1)
+  )
+  if len(set(chain_id)) != len(chain_start):
+    raise ValueError(f'Chain IDs must be contiguous, but got {chain_id}')
+
+  chain_lengths = np.diff(chain_start, append=len(chain_id))
+  chain_key = np.repeat(np.arange(len(chain_start)), chain_lengths)
+
+  chain_entity_id = fields.get('chain_entity_id')
+  if chain_entity_id is not None:
+    entity_id = chain_entity_id[chain_start]
+  else:
+    entity_id = np.array(
+        [str(mmcif.str_id_to_int_id(cid)) for cid in chain_id[chain_start]],
+        dtype=object,
+    )
+  chain_str_empty = np.full((num_res,), '.', dtype=object)
+  chains_table = structure_tables.Chains(
+      key=chain_key[chain_start],
+      id=chain_id[chain_start],
+      type=fields.get('chain_type', chain_str_empty)[chain_start],
+      auth_asym_id=fields.get('chain_auth_asym_id', chain_id)[chain_start],
+      entity_id=entity_id,
+      entity_desc=fields.get('chain_entity_desc', chain_str_empty)[chain_start],
+  )
+
+  # Since all arrays are residue-shaped, we can use them directly.
+  res_key = np.arange(num_res, dtype=np.int64)
+  res_id = fields.get('res_id', res_key + 1).astype(np.int32)
+  residues_table = structure_tables.Residues(
+      key=res_key,
+      chain_key=chain_key,
+      id=res_id,
+      name=fields.get('res_name', np.full(num_res, 'UNK', dtype=object)),
+      auth_seq_id=fields.get(
+          'res_auth_seq_id', np.char.mod('%d', res_id).astype(object)
+      ),
+      insertion_code=fields.get(
+          'res_insertion_code', np.full(num_res, '?', dtype=object)
+      ),
+  )
+
+  # The atom-sized arrays have already been masked and reshaped.
+  num_atoms_per_res = np.sum(atom_mask, axis=1, dtype=np.int32)
+  num_atoms_total = np.sum(num_atoms_per_res, dtype=np.int32)
+  # Structure is immutable, so use the same array multiple times to save RAM.
+ atom_str_empty = np.full(num_atoms_total, '.', dtype=object) + atom_float32_zeros = np.zeros(num_atoms_total, dtype=np.float32) + atom_float32_ones = np.ones(num_atoms_total, dtype=np.float32) + atoms_table = structure_tables.Atoms( + key=np.arange(num_atoms_total, dtype=np.int64), + chain_key=np.repeat(chain_key, num_atoms_per_res), + res_key=np.repeat(res_key, num_atoms_per_res), + name=fields.get('atom_name', atom_str_empty), + element=fields.get('atom_element', atom_str_empty), + x=fields.get('atom_x', atom_float32_zeros), + y=fields.get('atom_y', atom_float32_zeros), + z=fields.get('atom_z', atom_float32_zeros), + b_factor=fields.get('atom_b_factor', atom_float32_zeros), + occupancy=fields.get('atom_occupancy', atom_float32_ones), + ) + + return structure.Structure( + chains=chains_table, + residues=residues_table, + atoms=atoms_table, + bonds=structure_tables.Bonds.make_empty(), # Currently not set. + **initializer_kwargs, + ) + + +def expand_sequence( + sequence: str, chain_type: str, sequence_format: SequenceFormat +) -> Sequence[str]: + """Returns full residue names based on a sequence string. + + Args: + sequence: A string representing the sequence. + chain_type: The chain type of the sequence. + sequence_format: The format of the sequence argument. + """ + match sequence_format: + case SequenceFormat.FASTA: + if not all(c.isalpha() for c in sequence): + raise ValueError( + f'Sequence "{sequence}" has non-alphabetic characters') + match chain_type: + case mmcif_names.PROTEIN_CHAIN: + res_name_map = residue_names.PROTEIN_COMMON_ONE_TO_THREE + default_res_name = residue_names.UNK + case mmcif_names.RNA_CHAIN: + res_name_map = {r: r for r in residue_names.RNA_TYPES} + default_res_name = residue_names.UNK_RNA + case mmcif_names.DNA_CHAIN: + res_name_map = residue_names.DNA_COMMON_ONE_TO_TWO + default_res_name = residue_names.UNK_DNA + case _: + raise ValueError( + f'{chain_type=} not supported for FASTA format.') + return [ + res_name_map.get(one_letter_res, default_res_name) + for one_letter_res in sequence + ] + case SequenceFormat.CCD_CODES: + return sequence.strip('()').split(')(') + case SequenceFormat.LIGAND_SMILES: + ligand_id, _ = sequence.split(':', maxsplit=1) + return [ligand_id] + + +def from_sequences_and_bonds( + sequences: Sequence[str], + chain_types: Sequence[str], + sequence_formats: Sequence[SequenceFormat], + bonded_atom_pairs: Sequence[tuple[BondAtomId, BondAtomId]] | None, + ccd: chemical_components.Ccd, + name: str = 'from_sequences_and_bonds', + bond_type: str | None = None, + **constructor_args, +) -> structure.Structure: + """Returns a minimal structure for the input sequences and bonds. + + The returned structure will have at least one atom per residue. If the + residue has any bonded atoms, according to `bonded_atom_pairs`, then + all (and only) those atoms will be present for that residue. If the residue + is not involved in any bond then an arbitrary atom will be created. + + Args: + sequences: A sequence of strings, each one representing a single chain. + chain_types: The types of each chain, e.g. polypeptide(L). The n-th element + describes the n-th sequence in `sequences`. + sequence_formats: The format of each sequence. The n-th element describes + the n-th sequence in `sequences`. + bonded_atom_pairs: A sequence of bonded atom pairs. Each atom is described + as a tuple of (chain_index, res_index, atom_name), where the first two + values are 0-based indices. 
The chain_index is the index of the chain in
+      the `sequences` argument, and the res_index is the index of the residue in
+      that sequence. The atom_name is the name of the atom in the residue, e.g.
+      CA. If the atom is not found in the standard atoms for that residue
+      (according to the CCD) then an error is raised.
+    ccd: The chemical components dictionary.
+    name: A name for the returned structure.
+    bond_type: This type will be used for all bonds in the structure, where the
+      type follows the PDB scheme, e.g. unknown (?), hydrog, metalc, covale,
+      disulf.
+    **constructor_args: These arguments are passed directly to the
+      structure.Structure constructor.
+  """
+  chain_id = []
+  chain_type = []
+  chain_res_count = []
+  res_id = []
+  res_name = []
+  res_atom_count = []
+  atom_name = []
+  atom_element = []
+  chem_comp = {}
+
+  num_bonds = len(bonded_atom_pairs or ())
+  from_atom_key = np.full((num_bonds,), -1, dtype=np.int64)
+  dest_atom_key = np.full((num_bonds,), -1, dtype=np.int64)
+
+  # Create map (chain_i, res_i) -> {atom_name -> (from_idxs, dest_idxs)}.
+  # This allows quick lookup of whether a residue has any bonded atoms, and
+  # which bonds those atoms participate in.
+  bond_lookup = _create_bond_lookup(bonded_atom_pairs or ())
+
+  current_atom_key = 0
+  for chain_i, (sequence, curr_chain_type, sequence_format) in enumerate(
+      zip(sequences, chain_types, sequence_formats, strict=True)
+  ):
+    current_chain_id = mmcif.int_id_to_str_id(chain_i + 1)
+    num_chain_residues = 0
+    for res_i, full_res_name in enumerate(
+        expand_sequence(sequence, curr_chain_type, sequence_format)
+    ):
+      current_res_id = res_i + 1
+      num_res_atoms = 0
+
+      # Look for bonded atoms in the bond lookup and if any are found, add
+      # their atom keys to the bond atom_key columns.
+      if bond_indices_by_atom_name := bond_lookup.get((chain_i, res_i)):
+        for bond_atom_name, bond_indices in bond_indices_by_atom_name.items():
+          atom_name.append(bond_atom_name)
+          atom_element.append(
+              _get_atom_element(
+                  ccd=ccd, res_name=full_res_name, atom_name=bond_atom_name
+              )
+          )
+          for from_bond_i in bond_indices.from_indices:
+            from_atom_key[from_bond_i] = current_atom_key
+          for dest_bond_i in bond_indices.dest_indices:
+            dest_atom_key[dest_bond_i] = current_atom_key
+          current_atom_key += 1
+          num_res_atoms += 1
+      else:
+        # If this residue has no bonded atoms then we need to add one atom
+        # like in from_sequences.
+        assert num_res_atoms == 0
+        rep_atom_name, rep_atom_element = _get_representative_atom(
+            ccd=ccd,
+            res_name=full_res_name,
+            chain_type=curr_chain_type,
+            sequence_format=sequence_format,
+        )
+        atom_name.append(rep_atom_name)
+        atom_element.append(rep_atom_element)
+        num_res_atoms += 1
+        current_atom_key += 1
+
+      if sequence_format == SequenceFormat.LIGAND_SMILES:
+        # The sequence is expected to be in the format <ligand_id>:<smiles>,
+        # which always corresponds to a single-residue chain.
+        ligand_id, ligand_smiles = sequence.split(':', maxsplit=1)
+        if ccd.get(ligand_id) is not None:
+          raise ValueError(
+              f'Ligand name {ligand_id} is in CCD - it is not supported to give'
+              ' ligands created from SMILES the same name as CCD components.'
+          )
+        # We need to provide additional chemical components metadata for
+        # ligands specified via SMILES strings since they might not be in CCD.
+ _add_ligand_to_chem_comp(chem_comp, ligand_id, ligand_smiles) + + assert num_res_atoms >= 1 + res_atom_count.append(num_res_atoms) + num_chain_residues += 1 + res_id.append(current_res_id) + res_name.append(full_res_name) + + chain_id.append(current_chain_id) + chain_type.append(curr_chain_type) + chain_res_count.append(num_chain_residues) + + chem_comp_data = struc_chem_comps.ChemicalComponentsData(chem_comp) + chem_comp_data = struc_chem_comps.populate_missing_ccd_data( + ccd=ccd, + chemical_components_data=chem_comp_data, + chemical_component_ids=set(res_name), + ) + + if bonded_atom_pairs is not None: + unknown_bond_col = np.full((num_bonds,), '?', dtype=object) + if bond_type is None: + bond_type_col = unknown_bond_col + else: + bond_type_col = np.full((num_bonds,), bond_type, dtype=object) + bonds_table = bonds.Bonds( + key=np.arange(num_bonds, dtype=np.int64), + type=bond_type_col, + role=unknown_bond_col, + from_atom_key=from_atom_key, + dest_atom_key=dest_atom_key, + ) + else: + bonds_table = structure_tables.Bonds.make_empty() + + # 1 chain per sequence. + chain_key = np.arange(len(sequences), dtype=np.int64) + chain_id = np.array(chain_id, dtype=object) + chains_table = structure_tables.Chains( + key=chain_key, + id=chain_id, + type=np.array(chain_type, dtype=object), + auth_asym_id=chain_id, + entity_id=np.char.mod('%d', chain_key + 1).astype(object), + entity_desc=np.array(['.'] * len(chain_key), dtype=object), + ) + + res_key = np.arange(len(res_name), dtype=np.int64) + res_chain_key = np.repeat(chain_key, chain_res_count) + residues_table = structure_tables.Residues( + key=res_key, + chain_key=res_chain_key, + id=np.array(res_id, dtype=np.int32), + name=np.array(res_name, dtype=object), + auth_seq_id=np.char.mod('%d', res_id).astype(object), + insertion_code=np.full(len(res_name), '?', dtype=object), + ) + + num_atoms = current_atom_key + atom_float32_zeros = np.zeros(num_atoms, dtype=np.float32) + atoms_table = structure_tables.Atoms( + key=np.arange(num_atoms, dtype=np.int64), + chain_key=np.repeat(res_chain_key, res_atom_count), + res_key=np.repeat(res_key, res_atom_count), + name=np.array(atom_name, dtype=object), + element=np.array(atom_element, dtype=object), + x=atom_float32_zeros, + y=atom_float32_zeros, + z=atom_float32_zeros, + b_factor=atom_float32_zeros, + occupancy=np.ones(num_atoms, np.float32), + ) + + return structure.Structure( + name=name, + atoms=atoms_table, + residues=residues_table, + chains=chains_table, + bonds=bonds_table, + chemical_components_data=chem_comp_data, + **constructor_args, + ) + + +class _ChainResBuilder: + """Class for incrementally building chain and residue tables.""" + + def __init__( + self, + *, + chain_key_by_chain_id: Mapping[str, int], + entity_id_by_chain_id: Mapping[str, str], + chain_type_by_entity_id: Mapping[str, str], + entity_desc_by_entity_id: Mapping[str, str], + fix_mse_residues: bool, + fix_unknown_dna: bool, + ): + # Len: num_chains. + self.chain_key = [] + self.chain_id = [] + self.chain_type = [] + self.chain_auth_asym_id = [] + self.chain_entity_id = [] + self.chain_entity_desc = [] + + # Len: num_residues. 
+ self.res_key = [] + self.res_chain_key = [] + self.res_id = [] + self.res_name = [] + self.res_auth_seq_id = [] + self.res_insertion_code = [] + + self.chain_key_by_chain_id = chain_key_by_chain_id + self.entity_id_by_chain_id = entity_id_by_chain_id + self.chain_type_by_entity_id = chain_type_by_entity_id + self.entity_desc_by_entity_id = entity_desc_by_entity_id + self.key_for_res: dict[tuple[str, str, str, str], int] = {} + + self._fix_mse_residues = fix_mse_residues + self._fix_unknown_dna = fix_unknown_dna + + def add_residues( + self, + *, + chain_ids: np.ndarray, + chain_auth_asym_ids: np.ndarray, + res_ids: np.ndarray, + res_names: np.ndarray, + res_auth_seq_ids: np.ndarray, + res_ins_codes: np.ndarray, + ): + """Adds a residue (and its chain) to the tables.""" + # Create chain table data. + if chain_ids.size == 0: + return + + chain_ids_with_prev = np.concatenate( + (([self.chain_id[-1] if self.chain_id else None], chain_ids)) + ) + chain_change_mask = chain_ids_with_prev[:-1] != chain_ids_with_prev[1:] + chain_change_ids = chain_ids[chain_change_mask] + chain_keys = string_array.remap( + chain_change_ids, self.chain_key_by_chain_id, inplace=False + ) + self.chain_key.extend(chain_keys) + self.chain_id.extend(chain_change_ids) + self.chain_auth_asym_id.extend(chain_auth_asym_ids[chain_change_mask]) + chain_entity_id = string_array.remap( + chain_change_ids, self.entity_id_by_chain_id, inplace=False + ) + self.chain_entity_id.extend(chain_entity_id) + chain_type = string_array.remap( + chain_entity_id, self.chain_type_by_entity_id, inplace=False + ) + self.chain_type.extend(chain_type) + chain_entity_desc = string_array.remap( + chain_entity_id, self.entity_desc_by_entity_id, inplace=False + ) + self.chain_entity_desc.extend(chain_entity_desc) + + # Create residue table data. + num_prev_res = len(self.res_id) + res_keys = np.arange(num_prev_res, num_prev_res + len(res_ids)) + res_iter = zip( + chain_ids, + res_auth_seq_ids, + res_names, + res_ins_codes, + strict=True, + ) + key_for_res_update = { + res_unique_id: res_key + for res_key, res_unique_id in enumerate(res_iter, num_prev_res) + } + self.key_for_res.update(key_for_res_update) + self.res_key.extend(res_keys) + self.res_chain_key.extend( + string_array.remap( + chain_ids, self.chain_key_by_chain_id, inplace=False) + ) + self.res_id.extend(res_ids) + self.res_name.extend(res_names) + self.res_auth_seq_id.extend(res_auth_seq_ids) + self.res_insertion_code.extend(res_ins_codes) + + def make_chains_table(self) -> structure_tables.Chains: + """Returns the Structure chains table.""" + chain_key = np.array(self.chain_key, dtype=np.int64) + if not np.all(chain_key[:-1] <= chain_key[1:]): + # If the order is inconsistent with the atoms table, sort so that it is. 
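+      # A stable argsort keeps the relative order of rows that compare equal.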
+ order = np.argsort(self.chain_key, kind='stable') + return structure_tables.Chains( + key=chain_key[order], + id=np.array(self.chain_id, dtype=object)[order], + type=np.array(self.chain_type, dtype=object)[order], + auth_asym_id=np.array( + self.chain_auth_asym_id, dtype=object)[order], + entity_id=np.array(self.chain_entity_id, dtype=object)[order], + entity_desc=np.array( + self.chain_entity_desc, dtype=object)[order], + ) + return structure_tables.Chains( + key=chain_key, + id=np.array(self.chain_id, dtype=object), + type=np.array(self.chain_type, dtype=object), + auth_asym_id=np.array(self.chain_auth_asym_id, dtype=object), + entity_id=np.array(self.chain_entity_id, dtype=object), + entity_desc=np.array(self.chain_entity_desc, dtype=object), + ) + + def make_residues_table(self) -> structure_tables.Residues: + """Returns the Structure residues table.""" + res_name = np.array(self.res_name, dtype=object) + res_chain_key = np.array(self.res_chain_key, dtype=np.int64) + + if self._fix_mse_residues: + string_array.remap(res_name, mapping={'MSE': 'MET'}, inplace=True) + + if self._fix_unknown_dna: + # Remap residues from N -> DN in DNA chains only. + dna_chain_mask = ( + np.array(self.chain_type, dtype=object) == mmcif_names.DNA_CHAIN + ) + dna_chain_key = np.array(self.chain_key, dtype=object)[ + dna_chain_mask] + res_name[(res_name == 'N') & np.isin( + res_chain_key, dna_chain_key)] = 'DN' + + if not np.all(res_chain_key[:-1] <= res_chain_key[1:]): + # If the order is inconsistent with the atoms table, sort so that it is. + order = np.argsort(res_chain_key, kind='stable') + return structure_tables.Residues( + key=np.array(self.res_key, dtype=np.int64)[order], + chain_key=res_chain_key[order], + id=np.array(self.res_id, dtype=np.int32)[order], + name=res_name[order], + auth_seq_id=np.array(self.res_auth_seq_id, + dtype=object)[order], + insertion_code=np.array( + self.res_insertion_code, dtype=object)[order], + ) + return structure_tables.Residues( + key=np.array(self.res_key, dtype=np.int64), + chain_key=res_chain_key, + id=np.array(self.res_id, dtype=np.int32), + name=res_name, + auth_seq_id=np.array(self.res_auth_seq_id, dtype=object), + insertion_code=np.array(self.res_insertion_code, dtype=object), + ) + + +def _get_string_array_default(cif: mmcif.Mmcif, key: str, default: list[str]): + try: + return cif.get_array(key, dtype=object) + except KeyError: + return default + + +def _generate_required_tables_if_missing( + cif: mmcif.Mmcif, +) -> Mapping[str, Sequence[str]]: + """Generates all required tables and columns if missing.""" + update = {} + + atom_site_entities = _get_string_array_default( + cif, '_atom_site.label_entity_id', [] + ) + + # OpenMM produces files that don't have any of the tables and also have + # _atom_site.label_entity_id set to '?' for all atoms. We infer the entities + # based on the _atom_site.label_asym_id column. We start with cheaper O(1) + # checks to prevent running the expensive O(n) check on most files. + if ( + len(atom_site_entities) > 0 # pylint: disable=g-explicit-length-test + and '_entity.id' not in cif # Ignore if the _entity table exists. + and atom_site_entities[0] == '?' # Cheap check. + and set(atom_site_entities) == {'?'} # Expensive check. + ): + label_asym_ids = cif.get_array( + '_atom_site.label_asym_id', dtype=object) + atom_site_entities = [ + str(mmcif.str_id_to_int_id(cid)) for cid in label_asym_ids + ] + # Update _atom_site.label_entity_id to be consistent with the new tables. 
+ update['_atom_site.label_entity_id'] = atom_site_entities + + # Check table existence by checking the presence of its primary key. + if '_struct_asym.id' not in cif: + # Infer the _struct_asym table using the _atom_site table. + asym_ids = _get_string_array_default( + cif, '_atom_site.label_asym_id', []) + + if len(atom_site_entities) == 0 or len(asym_ids) == 0: # pylint: disable=g-explicit-length-test + raise ValueError( + 'Could not parse an mmCIF with no _struct_asym table and also no ' + '_atom_site.label_entity_id or _atom_site.label_asym_id columns.' + ) + + # Deduplicate, but keep the order intact - dict.fromkeys maintains order. + entity_id_chain_id_pairs = list( + dict.fromkeys(zip(atom_site_entities, asym_ids, strict=True)) + ) + update['_struct_asym.entity_id'] = [ + e for e, _ in entity_id_chain_id_pairs] + update['_struct_asym.id'] = [c for _, c in entity_id_chain_id_pairs] + + if '_entity.id' not in cif: + # Infer the _entity_poly and _entity tables using the _atom_site table. + residues = _get_string_array_default( + cif, '_atom_site.label_comp_id', []) + group_pdb = _get_string_array_default(cif, '_atom_site.group_PDB', []) + if '_atom_site.label_entity_id' in cif: + entities = atom_site_entities + else: + # If _atom_site.label_entity_id not set, use the asym_id -> entity_id map. + asym_to_entity = dict( + zip( + cif['_struct_asym.id'], cif['_struct_asym.entity_id'], strict=True + ) + ) + entities = string_array.remap( + cif.get_array('_atom_site.label_asym_id', dtype=object), + mapping=asym_to_entity, + ) + + entity_ids = [] + entity_types = [] + entity_poly_entity_ids = [] + entity_poly_types = [] + entity_poly_table_missing = '_entity_poly.entity_id' not in cif + for entity_id, group in itertools.groupby( + zip(entities, residues, group_pdb, strict=True), key=lambda e: e[0] + ): + _, entity_residues, entity_group_pdb = zip(*group, strict=True) + entity_type = _guess_entity_type( + chain_residues=entity_residues, atom_types=entity_group_pdb + ) + entity_ids.append(entity_id) + entity_types.append(entity_type) + + if entity_poly_table_missing and entity_type == mmcif_names.POLYMER_CHAIN: + polymer_type = mmcif_names.guess_polymer_type(entity_residues) + entity_poly_entity_ids.append(entity_id) + entity_poly_types.append(polymer_type) + + update['_entity.id'] = entity_ids + update['_entity.type'] = entity_types + if entity_poly_table_missing: + update['_entity_poly.entity_id'] = entity_poly_entity_ids + update['_entity_poly.type'] = entity_poly_types + + if '_atom_site.type_symbol' not in cif: + update['_atom_site.type_symbol'] = mmcif.get_or_infer_type_symbol(cif) + + return update + + +def _maybe_add_missing_scheme_tables( + cif: mmcif.Mmcif, + res_starts: Sequence[int], + label_asym_ids: np.ndarray, + label_seq_ids: np.ndarray, + label_comp_ids: np.ndarray, + auth_seq_ids: np.ndarray, + pdb_ins_codes: np.ndarray, +) -> Mapping[str, Sequence[str]]: + """If missing, infers the scheme tables from the _atom_site table.""" + update = {} + + required_poly_seq_scheme_cols = ( + '_pdbx_poly_seq_scheme.asym_id', + '_pdbx_poly_seq_scheme.pdb_seq_num', + '_pdbx_poly_seq_scheme.pdb_ins_code', + '_pdbx_poly_seq_scheme.seq_id', + '_pdbx_poly_seq_scheme.mon_id', + '_pdbx_poly_seq_scheme.pdb_strand_id', + ) + if not all(col in cif for col in required_poly_seq_scheme_cols): + # Create a mask for atoms where each polymer residue start. 
+ entity_id_by_chain_id = dict( + zip(cif['_struct_asym.id'], + cif['_struct_asym.entity_id'], strict=True) + ) + chain_type_by_entity_id = dict( + zip(cif['_entity.id'], cif['_entity.type'], strict=True) + ) + # Remap asym ID -> entity ID. + chain_type = string_array.remap( + label_asym_ids, mapping=entity_id_by_chain_id, inplace=False + ) + # Remap entity ID -> chain type. + string_array.remap( + chain_type, mapping=chain_type_by_entity_id, inplace=True + ) + res_mask = np.zeros_like(label_seq_ids, dtype=bool) + res_mask[res_starts] = True + res_mask &= chain_type == mmcif_names.POLYMER_CHAIN + + entity_poly_seq_cols = ( + '_entity_poly_seq.entity_id', + '_entity_poly_seq.num', + '_entity_poly_seq.mon_id', + ) + if all(col in cif for col in entity_poly_seq_cols): + # Use _entity_poly_seq if available. + poly_seq_num = cif.get_array('_entity_poly_seq.num', dtype=object) + poly_seq_mon_id = cif.get_array( + '_entity_poly_seq.mon_id', dtype=object) + poly_seq_entity_id = cif.get_array( + '_entity_poly_seq.entity_id', dtype=object + ) + label_seq_id_to_auth_seq_id = dict( + zip(label_seq_ids[res_mask], + auth_seq_ids[res_mask], strict=True) + ) + scheme_pdb_seq_num = string_array.remap( + poly_seq_num, mapping=label_seq_id_to_auth_seq_id, default_value='.' + ) + label_seq_id_to_ins_code = dict( + zip(label_seq_ids[res_mask], + pdb_ins_codes[res_mask], strict=True) + ) + scheme_pdb_ins_code = string_array.remap( + poly_seq_num, mapping=label_seq_id_to_ins_code, default_value='.' + ) + + # The _entity_poly_seq table is entity-based, while _pdbx_poly_seq_scheme + # is chain-based. A single entity could mean multiple chains (asym_ids), + # we therefore need to replicate each entity for all of the chains. + scheme_asym_id = [] + select = [] + indices = np.arange(len(poly_seq_entity_id), dtype=np.int32) + for asym_id, entity_id in zip( + cif['_struct_asym.id'], cif['_struct_asym.entity_id'], strict=True + ): + entity_mask = poly_seq_entity_id == entity_id + select.extend(indices[entity_mask]) + scheme_asym_id.extend([asym_id] * sum(entity_mask)) + + scheme_pdb_strand_id = string_array.remap( + np.array(scheme_asym_id, dtype=object), + mapping=mmcif.get_internal_to_author_chain_id_map(cif), + inplace=False, + ) + + update['_pdbx_poly_seq_scheme.asym_id'] = scheme_asym_id + update['_pdbx_poly_seq_scheme.pdb_strand_id'] = scheme_pdb_strand_id + update['_pdbx_poly_seq_scheme.pdb_seq_num'] = scheme_pdb_seq_num[select] + update['_pdbx_poly_seq_scheme.pdb_ins_code'] = scheme_pdb_ins_code[select] + update['_pdbx_poly_seq_scheme.seq_id'] = poly_seq_num[select] + update['_pdbx_poly_seq_scheme.mon_id'] = poly_seq_mon_id[select] + else: + # _entity_poly_seq not available, fallback to _atom_site. 
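+      # Note that only residues observed in _atom_site can be recovered here,
+      # so fully unresolved residues will be absent from the inferred scheme.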
+ res_asym_ids = label_asym_ids[res_mask] + res_strand_ids = string_array.remap( + array=res_asym_ids, + mapping=mmcif.get_internal_to_author_chain_id_map(cif), + inplace=False, + ) + update['_pdbx_poly_seq_scheme.asym_id'] = res_asym_ids + update['_pdbx_poly_seq_scheme.pdb_seq_num'] = auth_seq_ids[res_mask] + update['_pdbx_poly_seq_scheme.pdb_ins_code'] = pdb_ins_codes[res_mask] + update['_pdbx_poly_seq_scheme.seq_id'] = label_seq_ids[res_mask] + update['_pdbx_poly_seq_scheme.mon_id'] = label_comp_ids[res_mask] + update['_pdbx_poly_seq_scheme.pdb_strand_id'] = res_strand_ids + + required_nonpoly_scheme_cols = ( + '_pdbx_nonpoly_scheme.mon_id', + '_pdbx_nonpoly_scheme.asym_id', + '_pdbx_nonpoly_scheme.pdb_seq_num', + '_pdbx_nonpoly_scheme.pdb_ins_code', + ) + required_branch_scheme_cols = ( + '_pdbx_branch_scheme.mon_id', + '_pdbx_branch_scheme.asym_id', + '_pdbx_branch_scheme.pdb_seq_num', + ) + + # Generate _pdbx_nonpoly_scheme only if both tables are missing. + if not ( + all(col in cif for col in required_nonpoly_scheme_cols) + or all(col in cif for col in required_branch_scheme_cols) + ): + # To be strictly semantically correct, multi-residue ligands should be + # written in _pdbx_branch_scheme. However, Structure parsing handles + # correctly multi-residue ligands in _pdbx_nonpoly_scheme and the tables + # constructed here live only while parsing, hence this is unnecessary. + entity_id_by_chain_id = dict( + zip(cif['_struct_asym.id'], + cif['_struct_asym.entity_id'], strict=True) + ) + chain_type_by_entity_id = dict( + zip(cif['_entity.id'], cif['_entity.type'], strict=True) + ) + # Remap asym ID -> entity ID. + chain_type = string_array.remap( + label_asym_ids, mapping=entity_id_by_chain_id, inplace=False + ) + # Remap entity ID -> chain type. + string_array.remap( + chain_type, mapping=chain_type_by_entity_id, inplace=True + ) + res_mask = np.zeros_like(label_seq_ids, dtype=bool) + res_mask[res_starts] = True + res_mask &= chain_type != mmcif_names.POLYMER_CHAIN + + if not np.any(res_mask): + return update # Shortcut: no non-polymer residues. + + ins_codes = string_array.remap( + pdb_ins_codes[res_mask], mapping={'?': '.'}, inplace=False + ) + + update['_pdbx_nonpoly_scheme.asym_id'] = label_asym_ids[res_mask] + update['_pdbx_nonpoly_scheme.pdb_seq_num'] = auth_seq_ids[res_mask] + update['_pdbx_nonpoly_scheme.pdb_ins_code'] = ins_codes + update['_pdbx_nonpoly_scheme.mon_id'] = label_comp_ids[res_mask] + + return update + + +def _get_chain_key_by_chain_id( + resolved_chain_ids: np.ndarray, struct_asym_chain_ids: np.ndarray +) -> Mapping[str, int]: + """Returns chain key for each chain ID respecting resolved chain ordering.""" + # Check that all chain IDs found in the (potentially filtered) _atom_site + # table are present in the _struct_asym table. + unique_resolved_chain_ids = set(resolved_chain_ids) + if not unique_resolved_chain_ids.issubset(set(struct_asym_chain_ids)): + unique_resolved_chain_ids = sorted(unique_resolved_chain_ids) + unique_struct_asym_chain_ids = sorted(set(struct_asym_chain_ids)) + raise ValueError( + 'Bad mmCIF: chain IDs in _atom_site.label_asym_id ' + f'{unique_resolved_chain_ids} is not a subset of chain IDs in ' + f'_struct_asym.id {unique_struct_asym_chain_ids}.' + ) + + resolved_mask = string_array.isin( + struct_asym_chain_ids, unique_resolved_chain_ids + ) + # For all resolved chains, use the _atom_site order they appear in. E.g. 
# resolved_chain_ids = [B A E D F]
+  # struct_asym_chain_ids = [A B C D E F]
+  # consistent_chain_order = [B A C E D F]
+  # chain_keys = [0 1 2 3 4 5]
+  consistent_chain_order = struct_asym_chain_ids.copy()
+  consistent_chain_order[resolved_mask] = resolved_chain_ids
+  return dict(zip(consistent_chain_order, range(len(struct_asym_chain_ids))))
+
+
+def get_tables(
+    cif: mmcif.Mmcif,
+    fix_mse_residues: bool,
+    fix_arginines: bool,
+    fix_unknown_dna: bool,
+    include_water: bool,
+    include_other: bool,
+    model_id: str,
+) -> tuple[
+    structure_tables.Chains, structure_tables.Residues, structure_tables.Atoms
+]:
+  """Returns chain, residue, and atom tables from a parsed mmCIF.
+
+  Args:
+    cif: A parsed mmcif.Mmcif.
+    fix_mse_residues: See from_mmcif.
+    fix_arginines: See from_mmcif.
+    fix_unknown_dna: See from_mmcif.
+    include_water: See from_mmcif.
+    include_other: See from_mmcif.
+    model_id: A string defining which model ID to use. If set, only coordinates,
+      b-factors and occupancies for the given model are returned. If empty,
+      coordinates, b-factors and occupancies for all models are returned with a
+      leading dimension of num_models. Note that the model_id argument in
+      from_mmcif is an integer and has a slightly different use (see
+      from_mmcif).
+  """
+  # Add any missing tables and columns we require for parsing.
+  if cif_update := _generate_required_tables_if_missing(cif):
+    cif = cif.copy_and_update(cif_update)
+
+  # Resolve alt-locs, selecting only a single option for each residue. Also
+  # computes the layout, which defines where chain and residue boundaries are.
+  atom_site_all_models, layout = mmcif_utils.filter(
+      cif,
+      include_nucleotides=True,
+      include_ligands=True,
+      include_water=include_water,
+      include_other=include_other,
+      model_id=model_id,
+  )
+  atom_site_first_model = atom_site_all_models[0]
+
+  # Get atom information from the _atom_site table.
+  def _first_model_string_array(col: str) -> np.ndarray:
+    return cif.get_array(col, dtype=object, gather=atom_site_first_model)
+
+  def _requested_models_float_array(col: str) -> np.ndarray:
+    if not model_id:
+      # Return data for all models with a leading dimension of num_models.
+      return cif.get_array(col, dtype=np.float32, gather=atom_site_all_models)
+    else:
+      # Return data only for the single requested model.
+      return cif.get_array(col, dtype=np.float32, gather=atom_site_first_model)
+
+  # These columns are the same for all models, fetch them just for the 1st one.
+  label_comp_ids = _first_model_string_array('_atom_site.label_comp_id')
+  label_asym_ids = _first_model_string_array('_atom_site.label_asym_id')
+  label_seq_ids = _first_model_string_array('_atom_site.label_seq_id')
+  label_atom_ids = _first_model_string_array('_atom_site.label_atom_id')
+  if '_atom_site.auth_seq_id' in cif:
+    auth_seq_ids = _first_model_string_array('_atom_site.auth_seq_id')
+  else:
+    # auth_seq_id unset, fall back to label_seq_id.
+    auth_seq_ids = label_seq_ids
+  type_symbols = _first_model_string_array('_atom_site.type_symbol')
+  pdbx_pdb_ins_codes = _first_model_string_array(
+      '_atom_site.pdbx_PDB_ins_code')
+
+  # These columns are different for all models, fetch them as requested.
+ atom_x = _requested_models_float_array('_atom_site.Cartn_x') + atom_y = _requested_models_float_array('_atom_site.Cartn_y') + atom_z = _requested_models_float_array('_atom_site.Cartn_z') + atom_b_factor = _requested_models_float_array('_atom_site.B_iso_or_equiv') + atom_occupancy = _requested_models_float_array('_atom_site.occupancy') + + # Make sure the scheme (residue) tables exist in case they are not present. + if cif_update := _maybe_add_missing_scheme_tables( + cif, + res_starts=layout.residue_starts(), + label_asym_ids=label_asym_ids, + label_seq_ids=label_seq_ids, + label_comp_ids=label_comp_ids, + auth_seq_ids=auth_seq_ids, + pdb_ins_codes=pdbx_pdb_ins_codes, + ): + cif = cif.copy_and_update(cif_update) + + # Fix common issues found in mmCIF files, like swapped arginine NH atoms. + mmcif_utils.fix_residues( + layout, + comp_id=label_comp_ids, + atom_id=label_atom_ids, + atom_x=atom_x[0] if not model_id else atom_x, + atom_y=atom_y[0] if not model_id else atom_y, + atom_z=atom_z[0] if not model_id else atom_z, + fix_arg=fix_arginines, + ) + + # Get keys for chains in the order they appear in _atom_site while also + # dealing with empty chains. + resolved_chain_ids = label_asym_ids[layout.chain_starts()] + struct_asym_chain_ids = cif.get_array('_struct_asym.id', dtype=object) + + chain_key_by_chain_id = _get_chain_key_by_chain_id( + resolved_chain_ids=resolved_chain_ids, + struct_asym_chain_ids=struct_asym_chain_ids, + ) + entity_id_by_chain_id = dict( + zip(struct_asym_chain_ids, cif['_struct_asym.entity_id'], strict=True) + ) + entity_description = cif.get( + '_entity.pdbx_description', ['?'] * len(cif['_entity.id']) + ) + entity_desc_by_entity_id = dict( + zip(cif['_entity.id'], entity_description, strict=True) + ) + chain_type_by_entity_id = mmcif.get_chain_type_by_entity_id(cif) + auth_asym_id_by_chain_id = mmcif.get_internal_to_author_chain_id_map(cif) + + chain_res_builder = _ChainResBuilder( + chain_key_by_chain_id=chain_key_by_chain_id, + entity_id_by_chain_id=entity_id_by_chain_id, + chain_type_by_entity_id=chain_type_by_entity_id, + entity_desc_by_entity_id=entity_desc_by_entity_id, + fix_mse_residues=fix_mse_residues, + fix_unknown_dna=fix_unknown_dna, + ) + + # Collect data for polymer chain and residue tables. _pdbx_poly_seq_scheme is + # guaranteed to be present thanks to _maybe_add_missing_scheme_tables. + def _get_poly_seq_scheme_col(col: str) -> np.ndarray: + return cif.get_array(key=f'_pdbx_poly_seq_scheme.{col}', dtype=object) + + poly_seq_asym_ids = _get_poly_seq_scheme_col('asym_id') + poly_seq_pdb_seq_nums = _get_poly_seq_scheme_col('pdb_seq_num') + poly_seq_seq_ids = _get_poly_seq_scheme_col('seq_id') + poly_seq_mon_ids = _get_poly_seq_scheme_col('mon_id') + poly_seq_pdb_strand_ids = _get_poly_seq_scheme_col('pdb_strand_id') + poly_seq_pdb_ins_codes = _get_poly_seq_scheme_col('pdb_ins_code') + string_array.remap( + poly_seq_pdb_ins_codes, mapping=_INSERTION_CODE_REMAP, inplace=True + ) + + # We resolved alt-locs earlier for the atoms table. In cases of heterogeneous + # residues (a residue with an alt-loc that is of different residue type), we + # need to also do the same resolution in the residues table. Compute a mask + # for the residues that were selected in the atoms table. 
+  poly_seq_mask = mmcif_utils.selected_polymer_residue_mask(
+      layout=layout,
+      atom_site_label_asym_ids=label_asym_ids[layout.residue_starts()],
+      atom_site_label_seq_ids=label_seq_ids[layout.residue_starts()],
+      atom_site_label_comp_ids=label_comp_ids[layout.residue_starts()],
+      poly_seq_asym_ids=poly_seq_asym_ids,
+      poly_seq_seq_ids=poly_seq_seq_ids,
+      poly_seq_mon_ids=poly_seq_mon_ids,
+  )
+
+  if not include_other and poly_seq_mask:
+    # Mask out filtered residues so that they are not treated as missing; we
+    # don't want them included in the chains/residues tables at all.
+    keep_mask = string_array.remap(
+        poly_seq_asym_ids,
+        mapping={cid: True for cid in resolved_chain_ids},
+        default_value=False,
+        inplace=False,
+    ).astype(bool)
+    poly_seq_mask &= keep_mask
+
+  chain_res_builder.add_residues(
+      chain_ids=poly_seq_asym_ids[poly_seq_mask],
+      chain_auth_asym_ids=poly_seq_pdb_strand_ids[poly_seq_mask],
+      res_ids=poly_seq_seq_ids[poly_seq_mask].astype(np.int32),
+      res_names=poly_seq_mon_ids[poly_seq_mask],
+      res_auth_seq_ids=poly_seq_pdb_seq_nums[poly_seq_mask],
+      res_ins_codes=poly_seq_pdb_ins_codes[poly_seq_mask],
+  )
+
+  # Collect data for ligand chain and residue tables. _pdbx_nonpoly_scheme
+  # could be empty/unset if there are only branched ligands.
+  def _get_nonpoly_scheme_col(col: str) -> np.ndarray:
+    key = f'_pdbx_nonpoly_scheme.{col}'
+    if key in cif:
+      return cif.get_array(key=key, dtype=object)
+    else:
+      return np.array([], dtype=object)
+
+  nonpoly_asym_ids = _get_nonpoly_scheme_col('asym_id')
+  nonpoly_auth_seq_ids = _get_nonpoly_scheme_col('pdb_seq_num')
+  nonpoly_pdb_ins_codes = _get_nonpoly_scheme_col('pdb_ins_code')
+  nonpoly_mon_ids = _get_nonpoly_scheme_col('mon_id')
+  nonpoly_auth_asym_id = string_array.remap(
+      nonpoly_asym_ids, mapping=auth_asym_id_by_chain_id, inplace=False
+  )
+
+  def _get_branch_scheme_col(col: str) -> np.ndarray:
+    key = f'_pdbx_branch_scheme.{col}'
+    if key in cif:
+      return cif.get_array(key=key, dtype=object)
+    else:
+      return np.array([], dtype=object)
+
+  branch_asym_ids = _get_branch_scheme_col('asym_id')
+  branch_auth_seq_ids = _get_branch_scheme_col('pdb_seq_num')
+  branch_pdb_ins_codes = _get_branch_scheme_col('pdb_ins_code')
+  branch_mon_ids = _get_branch_scheme_col('mon_id')
+  branch_auth_asym_id = string_array.remap(
+      branch_asym_ids, mapping=auth_asym_id_by_chain_id, inplace=False
+  )
+
+  if branch_asym_ids.size > 0 and branch_pdb_ins_codes.size == 0:
+    branch_pdb_ins_codes = np.array(['.'] * branch_asym_ids.size, dtype=object)
+
+  # Compute the heterogeneous residue masks as above, this time for ligands.
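+  # As for polymers above, a ligand residue with heterogeneous alt-locs can
+  # appear as several scheme rows; the masks computed below keep only the rows
+  # matching the conformers selected for the atoms table.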
+  nonpoly_mask, branch_mask = mmcif_utils.selected_ligand_residue_mask(
+      layout=layout,
+      atom_site_label_asym_ids=label_asym_ids[layout.residue_starts()],
+      atom_site_label_seq_ids=label_seq_ids[layout.residue_starts()],
+      atom_site_auth_seq_ids=auth_seq_ids[layout.residue_starts()],
+      atom_site_label_comp_ids=label_comp_ids[layout.residue_starts()],
+      atom_site_pdbx_pdb_ins_codes=pdbx_pdb_ins_codes[layout.residue_starts()],
+      nonpoly_asym_ids=nonpoly_asym_ids,
+      nonpoly_auth_seq_ids=nonpoly_auth_seq_ids,
+      nonpoly_pdb_ins_codes=nonpoly_pdb_ins_codes,
+      nonpoly_mon_ids=nonpoly_mon_ids,
+      branch_asym_ids=branch_asym_ids,
+      branch_auth_seq_ids=branch_auth_seq_ids,
+      branch_pdb_ins_codes=branch_pdb_ins_codes,
+      branch_mon_ids=branch_mon_ids,
+  )
+
+  if not include_water:
+    if nonpoly_mask:
+      nonpoly_mask &= (nonpoly_mon_ids != 'HOH') & (nonpoly_mon_ids != 'DOD')
+    if branch_mask:
+      # Fix for bad mmCIFs that have water in the branch scheme table.
+      branch_mask &= (branch_mon_ids != 'HOH') & (branch_mon_ids != 'DOD')
+
+  string_array.remap(
+      pdbx_pdb_ins_codes, mapping=_INSERTION_CODE_REMAP, inplace=True
+  )
+  string_array.remap(
+      nonpoly_pdb_ins_codes, mapping=_INSERTION_CODE_REMAP, inplace=True
+  )
+  string_array.remap(
+      branch_pdb_ins_codes, mapping=_INSERTION_CODE_REMAP, inplace=True
+  )
+
+  def _ligand_residue_ids(chain_ids: np.ndarray) -> np.ndarray:
+    """Computes internal residue IDs for ligand residues that don't have them."""
+
+    # E.g. chain_ids=[A, A, A, B, C, C, D, D, D] -> [1, 2, 3, 1, 1, 2, 1, 2, 3].
+    indices = np.arange(chain_ids.size, dtype=np.int32)
+    return (indices + 1) - np.maximum.accumulate(
+        indices * (chain_ids != np.roll(chain_ids, 1))
+    )
+
+  branch_residue_ids = _ligand_residue_ids(branch_asym_ids[branch_mask])
+  nonpoly_residue_ids = _ligand_residue_ids(nonpoly_asym_ids[nonpoly_mask])
+
+  chain_res_builder.add_residues(
+      chain_ids=branch_asym_ids[branch_mask],
+      chain_auth_asym_ids=branch_auth_asym_id[branch_mask],
+      res_ids=branch_residue_ids,
+      res_names=branch_mon_ids[branch_mask],
+      res_auth_seq_ids=branch_auth_seq_ids[branch_mask],
+      res_ins_codes=branch_pdb_ins_codes[branch_mask],
+  )
+
+  chain_res_builder.add_residues(
+      chain_ids=nonpoly_asym_ids[nonpoly_mask],
+      chain_auth_asym_ids=nonpoly_auth_asym_id[nonpoly_mask],
+      res_ids=nonpoly_residue_ids,
+      res_names=nonpoly_mon_ids[nonpoly_mask],
+      res_auth_seq_ids=nonpoly_auth_seq_ids[nonpoly_mask],
+      res_ins_codes=nonpoly_pdb_ins_codes[nonpoly_mask],
+  )
+
+  chains = chain_res_builder.make_chains_table()
+  residues = chain_res_builder.make_residues_table()
+
+  # Construct foreign residue keys for the atoms table.
+  res_ends = np.array(layout.residues(), dtype=np.int32)
+  res_starts = np.array(layout.residue_starts(), dtype=np.int32)
+  res_lengths = res_ends - res_starts
+
+  # Check just for HOH; DOD can be part of e.g. hydroxycysteine.
+  if include_water:
+    res_chain_types = chains.apply_array_to_column(
+        column_name='type', arr=residues.chain_key
+    )
+    water_mask = res_chain_types != mmcif_names.WATER
+    if 'HOH' in set(residues.name[water_mask]):
+      raise ValueError(
+          'Bad mmCIF file: non-water entity has water molecules.'
+      )
+  else:
+    # Include resolved and unresolved residues.
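+    # residues.name includes unresolved residues, while label_comp_ids covers
+    # everything present in _atom_site; checking the union catches water in a
+    # non-water entity in either table.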
+    if 'HOH' in set(residues.name) | set(label_comp_ids[res_starts]):
+      raise ValueError(
+          'Bad mmCIF file: non-water entity has water molecules.'
+      )
+
+  atom_chain_key = string_array.remap(
+      label_asym_ids, mapping=chain_res_builder.chain_key_by_chain_id
+  ).astype(int)
+
+  # If any of the residue lookups failed, the mmCIF is corrupted.
+  try:
+    atom_res_key_per_res = string_array.remap_multiple(
+        (
+            label_asym_ids[res_starts],
+            auth_seq_ids[res_starts],
+            label_comp_ids[res_starts],
+            pdbx_pdb_ins_codes[res_starts],
+        ),
+        mapping=chain_res_builder.key_for_res,
+    )
+  except KeyError as e:
+    raise ValueError(
+        'Lookup for the following atom from the _atom_site table failed: '
+        f'(chain_id, auth_seq_id, res_name, ins_code)={e}. This is '
+        'likely due to a known issue with some multi-model mmCIFs that only '
+        'match the first model in the _atom_site table to the '
+        '_pdbx_poly_seq_scheme, _pdbx_nonpoly_scheme, or _pdbx_branch_scheme '
+        'tables.'
+    ) from e
+
+  # The residue key will be shared by all atoms within that residue.
+  atom_res_key = np.repeat(atom_res_key_per_res, repeats=res_lengths)
+
+  if fix_mse_residues:
+    met_residues_mask = (residues.name == 'MET')[atom_res_key]
+    unfixed_mse_selenium_mask = met_residues_mask & (label_atom_ids == 'SE')
+    label_atom_ids[unfixed_mse_selenium_mask] = 'SD'
+    type_symbols[unfixed_mse_selenium_mask] = 'S'
+
+  atoms = structure_tables.Atoms(
+      key=atom_site_first_model,
+      chain_key=atom_chain_key,
+      res_key=atom_res_key,
+      name=label_atom_ids,
+      element=type_symbols,
+      x=atom_x,
+      y=atom_y,
+      z=atom_z,
+      b_factor=atom_b_factor,
+      occupancy=atom_occupancy,
+  )
+
+  return chains, residues, atoms
+
+
+def from_atom_arrays(
+    *,
+    res_id: np.ndarray,
+    name: str = 'unset',
+    release_date: datetime.date | None = None,
+    resolution: float | None = None,
+    structure_method: str | None = None,
+    all_residues: Mapping[str, Sequence[tuple[str, int]]] | None = None,
+    bioassembly_data: bioassemblies.BioassemblyData | None = None,
+    chemical_components_data: (
+        struc_chem_comps.ChemicalComponentsData | None
+    ) = None,
+    bond_table: structure_tables.Bonds | None = None,
+    chain_id: np.ndarray | None = None,
+    chain_type: np.ndarray | None = None,
+    res_name: np.ndarray | None = None,
+    atom_key: np.ndarray | None = None,
+    atom_name: np.ndarray | None = None,
+    atom_element: np.ndarray | None = None,
+    atom_x: np.ndarray | None = None,
+    atom_y: np.ndarray | None = None,
+    atom_z: np.ndarray | None = None,
+    atom_b_factor: np.ndarray | None = None,
+    atom_occupancy: np.ndarray | None = None,
+) -> structure.Structure:
+  """Returns a Structure constructed from atom array level data.
+
+  All fields except name and res_id are optional. All array fields consist of
+  one value per atom in the structure, so residue and chain values should hold
+  the same value for each atom in the residue or chain. Fields which are not
+  defined are filled with default values.
+
+  Validation is performed by the Structure constructor where possible, but the
+  author naming scheme and all_residues must be checked in this function.
+
+  It is not possible to construct structures with chains that do not contain
+  any resolved residues using this function. If this is necessary, use the
+  structure.Structure constructor directly.
+
+  Args:
+    res_id: Integer array of shape [num_atom]. The unique residue identifier
+      for each residue. mmCIF field - _atom_site.label_seq_id.
+    name: The name of the structure. E.g. a PDB ID.
+    release_date: The release date of the structure as a `datetime.date`.
+    resolution: The resolution of the structure in Angstroms.
+    structure_method: The method used to solve this structure's coordinates.
+    all_residues: An optional mapping from each chain ID (i.e. label_asym_id)
+      to a sequence of (label_comp_id, label_seq_id) tuples, one per residue.
+      This can contain residues that aren't present in the atom arrays. This is
+      common in experimental data where some residues are not resolved but are
+      known to be present.
+    bioassembly_data: An optional instance of bioassembly.BioassemblyData. If
+      present then a new Structure representing a specific bioassembly can be
+      extracted using `Structure.generate_bioassembly(assembly_id)`.
+    chemical_components_data: An optional instance of ChemicalComponentsData.
+      Its content will be used for providing metadata about chemical components
+      in this Structure instance. If not specified, information will be
+      retrieved from the standard chemical component dictionary (CCD, for more
+      details see https://www.wwpdb.org/data/ccd).
+    bond_table: A table representing manually-specified bonds. This corresponds
+      to the _struct_conn table in an mmCIF. Atoms are identified by their key,
+      as specified by the atom_key column. If this table is provided then the
+      atom_key column must also be defined.
+    chain_id: String array of shape [num_atom] of unique chain identifiers.
+      mmCIF field - _atom_site.label_asym_id.
+    chain_type: String array of shape [num_atom]. The molecular type of the
+      current chain (e.g. polyribonucleotide). mmCIF field - _entity_poly.type
+      OR _entity.type (for non-polymers).
+    res_name: String array of shape [num_atom]. The name of each residue,
+      typically a 3-letter string for polypeptides or 1-2 letter strings for
+      polynucleotides. mmCIF field - _atom_site.label_comp_id.
+    atom_key: A unique sorted integer array, used only by the bonds table to
+      identify the atoms participating in each bond. If the bonds table is
+      specified then this column must be non-None.
+    atom_name: String array of shape [num_atom]. The name of each atom (e.g.
+      CA, O2', etc.). mmCIF field - _atom_site.label_atom_id.
+    atom_element: String array of shape [num_atom]. The element type of each
+      atom (e.g. C, O, N, etc.). mmCIF field - _atom_site.type_symbol.
+    atom_x: Float array of shape [..., num_atom] of atom x coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_y: Float array of shape [..., num_atom] of atom y coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_z: Float array of shape [..., num_atom] of atom z coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_b_factor: Float array of shape [..., num_atom] or [num_atom] of atom
+      b-factors or equivalent. If there are no extra leading dimensions then
+      these values are assumed to apply to all coordinates for a given atom. If
+      there are leading dimensions then these must match those used by the
+      coordinate fields.
+    atom_occupancy: Float array of shape [..., num_atom] or [num_atom] of atom
+      occupancies or equivalent. If there are no extra leading dimensions then
+      these values are assumed to apply to all coordinates for a given atom. If
+      there are leading dimensions then these must match those used by the
+      coordinate fields.
+ """ + + atoms, residues, chains = structure_tables.tables_from_atom_arrays( + res_id=res_id, + all_residues=all_residues, + chain_id=chain_id, + chain_type=chain_type, + res_name=res_name, + atom_key=atom_key, + atom_name=atom_name, + atom_element=atom_element, + atom_x=atom_x, + atom_y=atom_y, + atom_z=atom_z, + atom_b_factor=atom_b_factor, + atom_occupancy=atom_occupancy, + ) + + return structure.Structure( + name=name, + release_date=release_date, + resolution=resolution, + structure_method=structure_method, + bioassembly_data=bioassembly_data, + chemical_components_data=chemical_components_data, + atoms=atoms, + chains=chains, + residues=residues, + bonds=bond_table or structure_tables.Bonds.make_empty(), + ) + + +def _guess_entity_type( + chain_residues: Collection[str], atom_types: Collection[str] +) -> str: + """Guess the entity type (polymer/non-polymer/water) based on residues/atoms. + + We treat both arguments as unordered collections since we care only whether + all elements satisfy come conditions. The chain_residues can be either + grouped by residue (length num_res), or it can be raw (length num_atoms). + Atom type is unique for each atom in a residue, so don't group atom_types. + + Args: + chain_residues: A sequence of full residue name (1-letter for DNA, 2-letters + for RNA, 3 for protein). The _atom_site.label_comp_id column in mmCIF. + atom_types: Atom type: ATOM or HETATM. The _atom_site.group_PDB column in + mmCIF. + + Returns: + One of polymer/non-polymer/water based on the following criteria: + * If all atoms are HETATMs and all residues are water -> water. + * If all atoms are HETATMs and not all residues are water -> non-polymer. + * Otherwise -> polymer. + """ + if not chain_residues or not atom_types: + raise ValueError( + f'chain_residues (len {len(chain_residues)}) and atom_types (len ' + f'{len(atom_types)}) must be both non-empty. Got: {chain_residues=} ' + f'and {atom_types=}' + ) + + if all(a == 'HETATM' for a in atom_types): + if all(c in residue_names.WATER_TYPES for c in chain_residues): + return mmcif_names.WATER + return mmcif_names.NON_POLYMER_CHAIN + return mmcif_names.POLYMER_CHAIN diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/sterics.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/sterics.py new file mode 100644 index 000000000..12fbd7ae4 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/sterics.py @@ -0,0 +1,142 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Functions relating to spatial locations of atoms within a structure.""" + +from collections.abc import Collection, Sequence + +from alphafold3 import structure +from alphafold3.structure import mmcif +import numpy as np +import scipy + + +def _make_atom_has_clash_mask( + kd_query_result: np.ndarray, + struc: structure.Structure, + ignore_chains: Collection[str], +) -> np.ndarray: + """Returns a boolean NumPy array representing whether each atom has a clash. + + Args: + kd_query_result: NumPy array containing N-atoms arrays, each array + containing indices to atoms that clash with the N'th atom. + struc: Structure over which clashes were detected. + ignore_chains: Collection of chains that should not be considered clashing. + A boolean NumPy array of length N atoms. + """ + atom_is_clashing = np.zeros((struc.num_atoms,), dtype=bool) + for atom_index, clashes in enumerate(kd_query_result): + chain_i = struc.chain_id[atom_index] + if chain_i in ignore_chains: + continue + islig_i = struc.is_ligand_mask[atom_index] + for clashing_atom_index in clashes: + chain_c = struc.chain_id[clashing_atom_index] + if chain_c in ignore_chains: + continue + islig_c = struc.is_ligand_mask[clashing_atom_index] + if ( + clashing_atom_index == atom_index + or chain_i == chain_c + or islig_i != islig_c + ): + # Ignore clashes within chain or between ligand and polymer. + continue + atom_is_clashing[atom_index] = True + return atom_is_clashing + + +def find_clashing_chains( + struc: structure.Structure, + clash_thresh_angstrom: float = 1.7, + clash_thresh_fraction: float = 0.3, +) -> Sequence[str]: + """Finds chains that clash with others. + + Clashes are defined by polymer backbone atoms and all ligand atoms. + Ligand-polymer clashes are not dropped. + + Will not find clashes if all coordinates are 0. Coordinates are all 0s if + the structure is generated from sequences only, as done for inference in + dendro for example. + + Args: + struc: The structure defining the chains and atom positions. + clash_thresh_angstrom: Below this distance, atoms are considered clashing. + clash_thresh_fraction: Chains with more than this fraction of their atoms + considered clashing will be dropped. This value should be in the range (0, + 1]. + + Returns: + A sequence of chain ids for chains that clash. + + Raises: + ValueError: If `clash_thresh_fraction` is not in range (0,1]. + """ + if not 0 < clash_thresh_fraction <= 1: + raise ValueError('clash_thresh_fraction must be in range (0,1]') + + struc_backbone = struc.filter_polymers_to_single_atom_per_res() + if struc_backbone.num_chains == 0: + return [] + + # If the coordinates are all 0, do not search for clashes. + if not np.any(struc_backbone.coords): + return [] + + coord_kdtree = scipy.spatial.cKDTree(struc_backbone.coords) + + # For each atom coordinate, find all atoms within the clash thresh radius. + clashing_per_atom = coord_kdtree.query_ball_point( + struc_backbone.coords, r=clash_thresh_angstrom + ) + chain_ids = struc_backbone.chains + if struc_backbone.atom_occupancy is not None: + chain_occupancy = np.array([ + np.mean(struc_backbone.atom_occupancy[start:end]) + for start, end in struc_backbone.iter_chain_ranges() + ]) + else: + chain_occupancy = None + + # Remove chains until no more significant clashing. 
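+  # Greedy-removal sketch (hypothetical values): if chains A and B mostly
+  # clash with each other and B has the lower mean occupancy, B is marked for
+  # removal first; clash fractions are then recomputed with B ignored before
+  # any further chain is considered.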
+  chains_to_remove = set()
+  for _ in range(len(chain_ids)):
+    # Recompute per-chain clash fractions, ignoring already-removed chains.
+    atom_has_clash = _make_atom_has_clash_mask(
+        clashing_per_atom, struc_backbone, chains_to_remove
+    )
+    clashes_per_chain = np.array([
+        atom_has_clash[start:end].mean()
+        for start, end in struc_backbone.iter_chain_ranges()
+    ])
+    max_clash = np.max(clashes_per_chain)
+    if max_clash <= clash_thresh_fraction:
+      # None of the remaining chains exceed the clash fraction threshold, so
+      # we can exit.
+      break
+
+    # Greedily remove the worst-clashing chain, breaking ties by taking the
+    # last chain with the lowest occupancy.
+    most_clashes = np.nonzero(clashes_per_chain == max_clash)[0]
+    if chain_occupancy is not None:
+      occupancy_clashing = chain_occupancy[most_clashes]
+      last_lowest_occupancy = (
+          len(occupancy_clashing) - np.argmin(occupancy_clashing[::-1]) - 1
+      )
+      worst_and_last = most_clashes[last_lowest_occupancy]
+    else:
+      worst_and_last = most_clashes[-1]
+
+    chains_to_remove.add(chain_ids[worst_and_last])
+
+  return sorted(chains_to_remove, key=mmcif.str_id_to_int_id)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure.py
new file mode 100644
index 000000000..192cf84e0
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure.py
@@ -0,0 +1,3181 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Structure class for representing and processing molecular structures."""
+
+import collections
+from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence, Set
+import dataclasses
+import datetime
+import enum
+import functools
+import itertools
+import typing
+from typing_extensions import Any, ClassVar, Final, Literal, NamedTuple, Self, TypeAlias, TypeVar
+
+from alphafold3.constants import atom_types
+from alphafold3.constants import chemical_components
+from alphafold3.constants import mmcif_names
+from alphafold3.constants import residue_names
+from alphafold3.cpp import membership
+from alphafold3.cpp import string_array
+from alphafold3.structure import bioassemblies
+from alphafold3.structure import chemical_components as struc_chem_comps
+from alphafold3.structure import mmcif
+from alphafold3.structure import structure_tables
+from alphafold3.structure import table
+import numpy as np
+
+# Controls the default number of decimal places for coordinates when writing to
+# mmCIF.
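+# E.g. with 3 decimal places, an x coordinate of 12.3456 is written as 12.346.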
+_COORDS_DECIMAL_PLACES: Final[int] = 3
+
+
+@enum.unique
+class CascadeDelete(enum.Enum):
+  NONE = 0
+  FULL = 1
+  CHAINS = 2
+
+
+# See www.python.org/dev/peps/pep-0484/#support-for-singleton-types-in-unions
+class _UnsetSentinel(enum.Enum):
+  UNSET = object()
+
+
+_UNSET = _UnsetSentinel.UNSET
+
+
+class Bond(NamedTuple):
+  """Describes a bond between two atoms."""
+
+  from_atom: Mapping[str, str | int | float | np.ndarray]
+  dest_atom: Mapping[str, str | int | float | np.ndarray]
+  bond_info: Mapping[str, str | int]
+
+
+class MissingAtomError(Exception):
+  """Error raised when an atom is missing during alignment."""
+
+
+class MissingAuthorResidueIdError(Exception):
+  """Raised when author naming data is missing for a residue.
+
+  This can occur in certain edge cases where missing residue data is provided
+  without also providing author IDs for those missing residues.
+  """
+
+
+# AllResidues is a mapping from label_asym_id to a sequence of (label_comp_id,
+# label_seq_id) pairs. These represent the full sequence including residues
+# that might be missing (e.g. unresolved residues in X-ray data).
+AllResidues: TypeAlias = Mapping[str, Sequence[tuple[str, int]]]
+AuthorNamingScheme: TypeAlias = structure_tables.AuthorNamingScheme
+
+
+# External residue ID given to missing residues that don't have an ID
+# already provided. In mmCIFs this data is found in _pdbx_poly_seq_scheme.
+MISSING_AUTH_SEQ_ID: Final[str] = '.'
+
+
+# Maps from structure fields to column names in the relevant table.
+CHAIN_FIELDS: Final[Mapping[str, str]] = {
+    'chain_id': 'id',
+    'chain_type': 'type',
+    'chain_auth_asym_id': 'auth_asym_id',
+    'chain_entity_id': 'entity_id',
+    'chain_entity_desc': 'entity_desc',
+}
+
+
+RESIDUE_FIELDS: Final[Mapping[str, str]] = {
+    'res_id': 'id',
+    'res_name': 'name',
+    'res_auth_seq_id': 'auth_seq_id',
+    'res_insertion_code': 'insertion_code',
+}
+
+ATOM_FIELDS: Final[Mapping[str, str]] = {
+    'atom_name': 'name',
+    'atom_element': 'element',
+    'atom_x': 'x',
+    'atom_y': 'y',
+    'atom_z': 'z',
+    'atom_b_factor': 'b_factor',
+    'atom_occupancy': 'occupancy',
+    'atom_key': 'key',
+}
+
+# Fields in structure.
+ARRAY_FIELDS = frozenset({
+    'atom_b_factor',
+    'atom_element',
+    'atom_key',
+    'atom_name',
+    'atom_occupancy',
+    'atom_x',
+    'atom_y',
+    'atom_z',
+    'chain_id',
+    'chain_type',
+    'res_id',
+    'res_name',
+})
+
+GLOBAL_FIELDS = frozenset({
+    'name',
+    'release_date',
+    'resolution',
+    'structure_method',
+    'bioassembly_data',
+    'chemical_components_data',
+})
+
+# Fields which can be updated in copy_and_update.
+_UPDATEABLE_FIELDS: Final[Set[str]] = frozenset({
+    'all_residues',
+    'atom_b_factor',
+    'atom_element',
+    'atom_key',
+    'atom_name',
+    'atom_occupancy',
+    'atom_x',
+    'atom_y',
+    'atom_z',
+    'bioassembly_data',
+    'bonds',
+    'chain_id',
+    'chain_type',
+    'chemical_components_data',
+    'name',
+    'release_date',
+    'res_id',
+    'res_name',
+    'resolution',
+    'structure_method',
+})
+
+
+def fix_non_standard_polymer_residues(
+    res_names: np.ndarray, chain_type: str
+) -> np.ndarray:
+  """Remaps residue names to the closest standard protein/RNA/DNA residue.
+
+  If a residue name is already a standard type, it is not altered. If a match
+  cannot be found, returns 'UNK' for protein chain residues and 'N' for RNA/DNA
+  chain residues.
+
+  Args:
+    res_names: A numpy array of string residue names (CCD monomer codes). E.g.
+      'ARG' (protein), 'DT' (DNA), 'N' (RNA).
+    chain_type: The type of the chain, must be PROTEIN_CHAIN, RNA_CHAIN or
+      DNA_CHAIN.
+ + Returns: + An array remapped so that its elements are all from + PROTEIN_TYPES_WITH_UNKNOWN | RNA_TYPES | DNA_TYPES | {'N'}. + + Raises: + ValueError: If chain_type not in PEPTIDE_CHAIN_TYPES or + {OTHER_CHAIN, RNA_CHAIN, DNA_CHAIN, DNA_RNA_HYBRID_CHAIN}. + """ + # Map to one letter code, then back to common res_names. + one_letter_codes = string_array.remap( + res_names, mapping=residue_names.CCD_NAME_TO_ONE_LETTER, default_value='X' + ) + + if ( + chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES + or chain_type == mmcif_names.OTHER_CHAIN + ): + mapping = residue_names.PROTEIN_COMMON_ONE_TO_THREE + default_value = 'UNK' + elif chain_type == mmcif_names.RNA_CHAIN: + # RNA has single-letter CCD monomer codes. + mapping = {r: r for r in residue_names.RNA_TYPES} + default_value = 'N' + elif chain_type == mmcif_names.DNA_CHAIN: + mapping = residue_names.DNA_COMMON_ONE_TO_TWO + default_value = 'N' + elif chain_type == mmcif_names.DNA_RNA_HYBRID_CHAIN: + mapping = {r: r for r in residue_names.NUCLEIC_TYPES_WITH_UNKNOWN} + default_value = 'N' + else: + raise ValueError( + f'Expected a protein/DNA/RNA chain but got {chain_type}') + + return string_array.remap( + one_letter_codes, mapping=mapping, default_value=default_value + ) + + +def _get_change_indices(arr: np.ndarray) -> np.ndarray: + if arr.size == 0: + return np.array([], dtype=np.int32) + else: + changing_idxs = np.where(arr[1:] != arr[:-1])[0] + 1 + return np.concatenate(([0], changing_idxs), axis=0) + + +def _unpack_filter_predicates( + predicate_by_field_name: Mapping[str, table.FilterPredicate], +) -> tuple[ + Mapping[str, table.FilterPredicate], + Mapping[str, table.FilterPredicate], + Mapping[str, table.FilterPredicate], +]: + """Unpacks filter kwargs into predicates for each table.""" + chain_predicates = {} + res_predicates = {} + atom_predicates = {} + for k, pred in predicate_by_field_name.items(): + if col := CHAIN_FIELDS.get(k): + chain_predicates[col] = pred + elif col := RESIDUE_FIELDS.get(k): + res_predicates[col] = pred + elif col := ATOM_FIELDS.get(k): + atom_predicates[col] = pred + else: + raise ValueError(k) + return chain_predicates, res_predicates, atom_predicates + + +_T = TypeVar('_T') + + +SCALAR_FIELDS: Final[Collection[str]] = frozenset({ + 'name', + 'release_date', + 'resolution', + 'structure_method', + 'bioassembly_data', + 'chemical_components_data', +}) + + +TABLE_FIELDS: Final[Collection[str]] = frozenset( + {'chains', 'residues', 'atoms', 'bonds'} +) + + +V2_FIELDS: Final[Collection[str]] = frozenset({*SCALAR_FIELDS, *TABLE_FIELDS}) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class StructureTables: + chains: structure_tables.Chains + residues: structure_tables.Residues + atoms: structure_tables.Atoms + bonds: structure_tables.Bonds + + +class Structure(table.Database): + """Structure class for representing and processing molecular structures.""" + + tables: ClassVar[Collection[str]] = TABLE_FIELDS + + foreign_keys: ClassVar[Mapping[str, Collection[tuple[str, str]]]] = { + 'residues': (('chain_key', 'chains'),), + 'atoms': (('chain_key', 'chains'), ('res_key', 'residues')), + 'bonds': (('from_atom_key', 'atoms'), ('dest_atom_key', 'atoms')), + } + + def __init__( + self, + *, + name: str = 'unset', + release_date: datetime.date | None = None, + resolution: float | None = None, + structure_method: str | None = None, + bioassembly_data: bioassemblies.BioassemblyData | None = None, + chemical_components_data: ( + struc_chem_comps.ChemicalComponentsData | None + ) = None, + chains: 
structure_tables.Chains, + residues: structure_tables.Residues, + atoms: structure_tables.Atoms, + bonds: structure_tables.Bonds, + skip_validation: bool = False, + ): + # Version number is written to mmCIF and should be incremented when changes + # are made to mmCIF writing or internals that affect this. + # b/345221494 Rename this variable when structure_v1 compatibility code + # is removed. + self._VERSION = '2.0.0' # pylint: disable=invalid-name + self._name = name + self._release_date = release_date + self._resolution = resolution + self._structure_method = structure_method + self._bioassembly_data = bioassembly_data + self._chemical_components_data = chemical_components_data + + self._chains = chains + self._residues = residues + self._atoms = atoms + self._bonds = bonds + + if not skip_validation: + self._validate_table_foreign_keys() + self._validate_consistent_table_ordering() + + def _validate_table_foreign_keys(self): + """Validates that all foreign keys are present in the referred tables.""" + residue_keys = set(self._residues.key) + chain_keys = set(self._chains.key) + if np.any(membership.isin(self._atoms.res_key, residue_keys, invert=True)): + raise ValueError( + 'Atom residue keys not in the residues table: ' + f'{set(self._atoms.res_key).difference(self._residues.key)}' + ) + if np.any(membership.isin(self._atoms.chain_key, chain_keys, invert=True)): + raise ValueError( + 'Atom chain keys not in the chains table: ' + f'{set(self._atoms.chain_key).difference(self._chains.key)}' + ) + if np.any( + membership.isin(self._residues.chain_key, chain_keys, invert=True) + ): + raise ValueError( + 'Residue chain keys not in the chains table: ' + f'{set(self._residues.chain_key).difference(self._chains.key)}' + ) + + def _validate_consistent_table_ordering(self): + """Validates that all tables have the same ordering.""" + atom_chain_keys = self._atoms.chain_key[self.chain_boundaries] + atom_res_keys = self._atoms.res_key[self.res_boundaries] + + if not np.array_equal(self.present_chains.key, atom_chain_keys): + raise ValueError( + f'Atom table chain order\n{atom_chain_keys}\ndoes not match the ' + f'chain table order\n{self._chains.key}' + ) + if not np.array_equal(self.present_residues.key, atom_res_keys): + raise ValueError( + f'Atom table residue order\n{atom_res_keys}\ndoes not match the ' + f'present residue table order\n{self.present_residues.key}' + ) + + def get_table(self, table_name: str) -> table.Table: + match table_name: + case 'chains': + return self.chains_table + case 'residues': + return self.residues_table + case 'atoms': + return self.atoms_table + case 'bonds': + return self.bonds_table + case _: + raise ValueError(table_name) + + @property + def chains_table(self) -> structure_tables.Chains: + """Chains table.""" + return self._chains + + @property + def residues_table(self) -> structure_tables.Residues: + """Residues table.""" + return self._residues + + @property + def atoms_table(self) -> structure_tables.Atoms: + """Atoms table.""" + return self._atoms + + @property + def bonds_table(self) -> structure_tables.Bonds: + """Bonds table.""" + return self._bonds + + @property + def name(self) -> str: + return self._name + + @property + def release_date(self) -> datetime.date | None: + return self._release_date + + @property + def resolution(self) -> float | None: + return self._resolution + + @property + def structure_method(self) -> str | None: + return self._structure_method + + @property + def bioassembly_data(self) -> bioassemblies.BioassemblyData | None: + 
return self._bioassembly_data + + @property + def chemical_components_data( + self, + ) -> struc_chem_comps.ChemicalComponentsData | None: + return self._chemical_components_data + + @property + def bonds(self) -> structure_tables.Bonds: + return self._bonds + + @functools.cached_property + def author_naming_scheme(self) -> AuthorNamingScheme: + auth_asym_id = {} + entity_id = {} + entity_desc = {} + auth_seq_id = collections.defaultdict(dict) + insertion_code = collections.defaultdict(dict) + + for chain_i in range(self._chains.size): + chain_id = self._chains.id[chain_i] + auth_asym_id[chain_id] = self._chains.auth_asym_id[chain_i] + chain_entity_id = self._chains.entity_id[chain_i] + entity_id[chain_id] = chain_entity_id + entity_desc[chain_entity_id] = self._chains.entity_desc[chain_i] + + chain_index_by_key = self._chains.index_by_key + for res_i in range(self._residues.size): + chain_key = self._residues.chain_key[res_i] + chain_id = self._chains.id[chain_index_by_key[chain_key]] + res_id = self._residues.id[res_i] + res_auth_seq_id = self._residues.auth_seq_id[res_i] + if res_auth_seq_id == MISSING_AUTH_SEQ_ID: + continue + auth_seq_id[chain_id][res_id] = res_auth_seq_id + ins_code = self._residues.insertion_code[res_i] + # Compatibility with Structure v1 which used None to represent . or ?. + insertion_code[chain_id][res_id] = ( + ins_code if ins_code not in {'.', '?'} else None + ) + + return AuthorNamingScheme( + auth_asym_id=auth_asym_id, + entity_id=entity_id, + entity_desc=entity_desc, + auth_seq_id=dict(auth_seq_id), + insertion_code=dict(insertion_code), + ) + + @functools.cached_property + def all_residues(self) -> AllResidues: + chain_id_by_key = dict(zip(self._chains.key, self._chains.id)) + residue_chain_boundaries = _get_change_indices( + self._residues.chain_key) + boundaries = self._iter_residue_ranges( + residue_chain_boundaries, count_unresolved=True + ) + return { + chain_id_by_key[self._residues.chain_key[start]]: list( + zip(self._residues.name[start:end], + self._residues.id[start:end]) + ) + for start, end in boundaries + } + + @functools.cached_property + def label_asym_id_to_entity_id(self) -> Mapping[str, str]: + return dict(zip(self._chains.id, self._chains.entity_id)) + + @functools.cached_property + def chain_entity_id(self) -> np.ndarray: + """Returns the entity ID for each atom in the structure.""" + return self.chains_table.apply_array_to_column( + 'entity_id', self._atoms.chain_key + ) + + @functools.cached_property + def chain_entity_desc(self) -> np.ndarray: + """Returns the entity description for each atom in the structure.""" + return self.chains_table.apply_array_to_column( + 'entity_desc', self._atoms.chain_key + ) + + @functools.cached_property + def chain_auth_asym_id(self) -> np.ndarray: + """Returns the chain auth asym ID for each atom in the structure.""" + return self.chains_table.apply_array_to_column( + 'auth_asym_id', self._atoms.chain_key + ) + + @functools.cached_property + def chain_id(self) -> np.ndarray: + chain_index_by_key = self._chains.index_by_key + return self._chains.id[chain_index_by_key[self._atoms.chain_key]] + + @functools.cached_property + def chain_type(self) -> np.ndarray: + chain_index_by_key = self._chains.index_by_key + return self._chains.type[chain_index_by_key[self._atoms.chain_key]] + + @functools.cached_property + def res_id(self) -> np.ndarray: + return self._residues['id', self._atoms.res_key] + + @functools.cached_property + def res_name(self) -> np.ndarray: + return self._residues['name', 
self._atoms.res_key]
+
+  @functools.cached_property
+  def res_auth_seq_id(self) -> np.ndarray:
+    """Returns the residue auth seq ID for each atom in the structure."""
+    return self.residues_table.apply_array_to_column(
+        'auth_seq_id', self._atoms.res_key
+    )
+
+  @functools.cached_property
+  def res_insertion_code(self) -> np.ndarray:
+    """Returns the residue insertion code for each atom in the structure."""
+    return self.residues_table.apply_array_to_column(
+        'insertion_code', self._atoms.res_key
+    )
+
+  @property
+  def atom_key(self) -> np.ndarray:
+    return self._atoms.key
+
+  @property
+  def atom_name(self) -> np.ndarray:
+    return self._atoms.name
+
+  @property
+  def atom_element(self) -> np.ndarray:
+    return self._atoms.element
+
+  @property
+  def atom_x(self) -> np.ndarray:
+    return self._atoms.x
+
+  @property
+  def atom_y(self) -> np.ndarray:
+    return self._atoms.y
+
+  @property
+  def atom_z(self) -> np.ndarray:
+    return self._atoms.z
+
+  @property
+  def atom_b_factor(self) -> np.ndarray:
+    return self._atoms.b_factor
+
+  @property
+  def atom_occupancy(self) -> np.ndarray:
+    return self._atoms.occupancy
+
+  @functools.cached_property
+  def chain_boundaries(self) -> np.ndarray:
+    """The indices in the atom fields where each chain begins."""
+    return _get_change_indices(self._atoms.chain_key)
+
+  @functools.cached_property
+  def res_boundaries(self) -> np.ndarray:
+    """The indices in the atom fields where each residue begins."""
+    return _get_change_indices(self._atoms.res_key)
+
+  @functools.cached_property
+  def present_chains(self) -> structure_tables.Chains:
+    """Returns table of chains which have at least 1 resolved atom."""
+    is_present_mask = np.isin(self._chains.key, self._atoms.chain_key)
+    return typing.cast(structure_tables.Chains, self._chains[is_present_mask])
+
+  @functools.cached_property
+  def present_residues(self) -> structure_tables.Residues:
+    """Returns table of residues which have at least 1 resolved atom."""
+    is_present_mask = np.isin(self._residues.key, self._atoms.res_key)
+    return typing.cast(
+        structure_tables.Residues, self._residues[is_present_mask]
+    )
+
+  @functools.cached_property
+  def unresolved_residues(self) -> structure_tables.Residues:
+    """Returns table of residues which have no resolved atoms."""
+    is_unresolved_mask = np.isin(
+        self._residues.key, self._atoms.res_key, invert=True
+    )
+    return typing.cast(
+        structure_tables.Residues, self._residues[is_unresolved_mask]
+    )
+
+  def __getitem__(self, field: str) -> Any:
+    """Gets raw field data using field name as a string."""
+    if field in TABLE_FIELDS:
+      return self.get_table(field)
+    else:
+      return getattr(self, field)
+
+  def __getstate__(self) -> dict[str, Any]:
+    """Pickle calls this on dump.
+
+    Returns:
+      Members with cached properties removed.
+    """
+    cached_props = {
+        k
+        for k, v in self.__class__.__dict__.items()
+        if isinstance(v, functools.cached_property)
+    }
+    return {k: v for k, v in self.__dict__.items() if k not in cached_props}
+
+  def __repr__(self):
+    return (
+        f'Structure({self._name}: {self.num_chains} chains, '
+        f'{self.num_residues(count_unresolved=False)} residues, '
+        f'{self.num_atoms} atoms)'
+    )
+
+  @property
+  def num_atoms(self) -> int:
+    return self._atoms.size
+
+  def num_residues(self, *, count_unresolved: bool) -> int:
+    """Returns the number of residues in this Structure.
+
+    Args:
+      count_unresolved: Whether to include unresolved (empty) residues.
+
+    Returns:
+      Number of residues in the Structure.
+ """ + if count_unresolved: + return self._residues.size + else: + return self.present_residues.size + + @property + def num_chains(self) -> int: + return self._chains.size + + @property + def num_models(self) -> int: + """The number of models of this Structure.""" + return self._atoms.num_models + + def _atom_mask(self, entities: Set[str]) -> np.ndarray: + """Boolean label indicating if each atom is from entities or not.""" + mask = np.zeros(self.num_atoms, dtype=bool) + chain_index_by_key = self._chains.index_by_key + for start, end in self.iter_chain_ranges(): + chain_index = chain_index_by_key[self._atoms.chain_key[start]] + chain_type = self._chains.type[chain_index] + mask[start:end] = chain_type in entities + return mask + + @functools.cached_property + def is_protein_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is from protein or not.""" + return self._atom_mask(entities={mmcif_names.PROTEIN_CHAIN}) + + @functools.cached_property + def is_dna_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is from DNA or not.""" + return self._atom_mask(entities={mmcif_names.DNA_CHAIN}) + + @functools.cached_property + def is_rna_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is from RNA or not.""" + return self._atom_mask(entities={mmcif_names.RNA_CHAIN}) + + @functools.cached_property + def is_nucleic_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is a nucleic acid or not.""" + return self._atom_mask(entities=mmcif_names.NUCLEIC_ACID_CHAIN_TYPES) + + @functools.cached_property + def is_ligand_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is a ligand or not.""" + return self._atom_mask(entities=mmcif_names.LIGAND_CHAIN_TYPES) + + @functools.cached_property + def is_water_mask(self) -> np.ndarray: + """Boolean label indicating if each atom is from water or not.""" + return self._atom_mask(entities={mmcif_names.WATER}) + + def iter_atoms(self) -> Iterator[Mapping[str, Any]]: + """Iterates over the atoms in the structure.""" + if self._atoms.size == 0: + return + + current_chain = self._chains.get_row_by_key( + column_name_map=CHAIN_FIELDS, key=self._atoms.chain_key[0] + ) + current_chain_key = self._atoms.chain_key[0] + current_res = self._residues.get_row_by_key( + column_name_map=RESIDUE_FIELDS, key=self._atoms.res_key[0] + ) + current_res_key = self._atoms.res_key[0] + for atom_i in range(self._atoms.size): + atom_chain_key = self._atoms.chain_key[atom_i] + atom_res_key = self._atoms.res_key[atom_i] + + if atom_chain_key != current_chain_key: + chain_index = self._chains.index_by_key[atom_chain_key] + current_chain = { + 'chain_id': self._chains.id[chain_index], + 'chain_type': self._chains.type[chain_index], + 'chain_auth_asym_id': self._chains.auth_asym_id[chain_index], + 'chain_entity_id': self._chains.entity_id[chain_index], + 'chain_entity_desc': self._chains.entity_desc[chain_index], + } + current_chain_key = atom_chain_key + if atom_res_key != current_res_key: + res_index = self._residues.index_by_key[atom_res_key] + current_res = { + 'res_id': self._residues.id[res_index], + 'res_name': self._residues.name[res_index], + 'res_auth_seq_id': self._residues.auth_seq_id[res_index], + 'res_insertion_code': self._residues.insertion_code[res_index], + } + current_res_key = atom_res_key + + yield { + 'atom_name': self._atoms.name[atom_i], + 'atom_element': self._atoms.element[atom_i], + 'atom_x': self._atoms.x[..., atom_i], + 'atom_y': self._atoms.y[..., atom_i], + 'atom_z': 
self._atoms.z[..., atom_i], + 'atom_b_factor': self._atoms.b_factor[..., atom_i], + 'atom_occupancy': self._atoms.occupancy[..., atom_i], + 'atom_key': self._atoms.key[atom_i], + **current_res, + **current_chain, + } + + def iter_residues( + self, + include_unresolved: bool = False, + ) -> Iterator[Mapping[str, Any]]: + """Iterates over the residues in the structure.""" + res_table = self._residues if include_unresolved else self.present_residues + if res_table.size == 0: + return + + current_chain = self._chains.get_row_by_key( + column_name_map=CHAIN_FIELDS, key=res_table.chain_key[0] + ) + current_chain_key = res_table.chain_key[0] + for res_i in range(res_table.size): + res_chain_key = res_table.chain_key[res_i] + + if res_chain_key != current_chain_key: + current_chain = self._chains.get_row_by_key( + column_name_map=CHAIN_FIELDS, key=res_table.chain_key[res_i] + ) + current_chain_key = res_chain_key + + row = { + 'res_id': res_table.id[res_i], + 'res_name': res_table.name[res_i], + 'res_auth_seq_id': res_table.auth_seq_id[res_i], + 'res_insertion_code': res_table.insertion_code[res_i], + } + yield row | current_chain + + def _iter_atom_ranges( + self, boundaries: Sequence[int] + ) -> Iterator[tuple[int, int]]: + """Iterator for (start, end) pairs from an array of start indices.""" + yield from itertools.pairwise(boundaries) + # Use explicit length test as boundaries can be a NumPy array. + if len(boundaries) > 0: # pylint: disable=g-explicit-length-test + yield boundaries[-1], self.num_atoms + + def _iter_residue_ranges( + self, + boundaries: Sequence[int], + *, + count_unresolved: bool, + ) -> Iterator[tuple[int, int]]: + """Iterator for (start, end) pairs from an array of start indices.""" + yield from itertools.pairwise(boundaries) + # Use explicit length test as boundaries can be a NumPy array. + if len(boundaries) > 0: # pylint: disable=g-explicit-length-test + yield boundaries[-1], self.num_residues(count_unresolved=count_unresolved) + + def iter_chain_ranges(self) -> Iterator[tuple[int, int]]: + """Iterates pairs of (chain_start, chain_end) indices. + + Yields: + Pairs of (start, end) indices for each chain, where end is not inclusive. + i.e. struc.chain_id[start:end] would be a constant array with length + equal to the number of atoms in the chain. + """ + yield from self._iter_atom_ranges(self.chain_boundaries) + + def iter_residue_ranges(self) -> Iterator[tuple[int, int]]: + """Iterates pairs of (residue_start, residue_end) indices. + + Yields: + Pairs of (start, end) indices for each residue, where end is not + inclusive. i.e. struc.res_id[start:end] would be a constant array with + length equal to the number of atoms in the residue. + """ + yield from self._iter_atom_ranges(self.res_boundaries) + + def iter_chains(self) -> Iterator[Mapping[str, Any]]: + """Iterates over the chains in the structure.""" + for chain_i in range(self.present_chains.size): + yield { + 'chain_id': self.present_chains.id[chain_i], + 'chain_type': self.present_chains.type[chain_i], + 'chain_auth_asym_id': self.present_chains.auth_asym_id[chain_i], + 'chain_entity_id': self.present_chains.entity_id[chain_i], + 'chain_entity_desc': self.present_chains.entity_desc[chain_i], + } + + def iter_bonds(self) -> Iterator[Bond]: + """Iterates over the atoms and bond information. + + Example usage: + + ``` + for from_atom, dest_atom, bond_info in struc.iter_bonds(): + print( + f'From atom: name={from_atom["atom_name"]}, ' + f'chain={from_atom["chain_id"]}, ...' 
+ ) + # Same for dest_atom + print(f'Bond info: type={bond_info["type"]}, role={bond_info["role"]}') + ``` + + Yields: + A `Bond` NamedTuple for each bond in the bonds table. + These have fields `from_atom`, `dest_atom`, `bond_info` where each + is a dictionary. The first two have the same keys as the atom dicts + returned by self.iter_atoms() -- i.e. one key per non-None field. + The final dict has the same keys as self.bonds.iterrows() -- i.e. one + key per column in the bonds table. + """ + from_atom_iter = self._atoms.iterrows( + row_keys=self._bonds.from_atom_key, + column_name_map=ATOM_FIELDS, + chain_key=self._chains.with_column_names(CHAIN_FIELDS), + res_key=self._residues.with_column_names(RESIDUE_FIELDS), + ) + dest_atom_iter = self._atoms.iterrows( + row_keys=self._bonds.dest_atom_key, + column_name_map=ATOM_FIELDS, + chain_key=self._chains.with_column_names(CHAIN_FIELDS), + res_key=self._residues.with_column_names(RESIDUE_FIELDS), + ) + + for from_atom, dest_atom, bond_info in zip( + from_atom_iter, dest_atom_iter, self._bonds.iterrows(), strict=True + ): + yield Bond(from_atom=from_atom, dest_atom=dest_atom, bond_info=bond_info) + + def _apply_atom_index_array( + self, + index_arr: np.ndarray, + chain_boundaries: np.ndarray | None = None, + res_boundaries: np.ndarray | None = None, + skip_validation: bool = False, + ) -> Self: + """Applies index_arr to the atom table using NumPy-style array indexing. + + Args: + index_arr: A 1D NumPy array that will be used to index into the atoms + table. This can either be a boolean array to act as a mask, or an + integer array to perform a gather operation. + chain_boundaries: Unused in structure v2. + res_boundaries: Unused in structure v2. + skip_validation: Whether to skip the validation step that checks internal + consistency after applying atom index array. Do not set to True unless + you are certain the transform is safe, e.g. when the order of atoms is + guaranteed to not change. + + Returns: + A new Structure with an updated atoms table. + """ + del chain_boundaries, res_boundaries + + if index_arr.ndim != 1: + raise ValueError( + f'index_arr must be a 1D NumPy array, but has shape {index_arr.shape}' + ) + + if index_arr.dtype == bool and np.all(index_arr): + # Shortcut: The operation is a no-op, so just return itself. + return self + + atoms = structure_tables.Atoms( + **{col: self._atoms[col][..., index_arr] for col in self._atoms.columns} + ) + updated_tables = self._cascade_delete(atoms=atoms) + return self.copy_and_update( + atoms=updated_tables.atoms, + bonds=updated_tables.bonds, + skip_validation=skip_validation, + ) + + @property + def group_by_residue(self) -> Self: + """Returns a Structure with one atom per residue. + + e.g. restypes = struc.group_by_residue['res_id'] + + Returns: + A new Structure with one atom per residue such that per-atom arrays + such as res_name (i.e. Structure v1 fields) have one element per residue. + """ + # This use of _apply_atom_index_array is safe because the chain/residue/atom + # ordering won't change (essentially applying a residue start mask). + return self._apply_atom_index_array( + self.res_boundaries, skip_validation=True + ) + + @property + def group_by_chain(self) -> Self: + """Returns a Structure where all fields are per-chain. + + e.g. chains = struc.group_by_chain['chain_id'] + + Returns: + A new Structure with one atom per chain such that per-atom arrays + such as res_name (i.e. Structure v1 fields) have one element per chain. 
+ """ + # This use of _apply_atom_index_array is safe because the chain/residue/atom + # ordering won't change (essentially applying a chain start mask). + return self._apply_atom_index_array( + self.chain_boundaries, skip_validation=True + ) + + @property + def with_sorted_chains(self) -> Self: + """Returns a new structure with the chains are in reverse spreadsheet style. + + This is the usual order to write chains in an mmCIF: + (A < B < ... < AA < BA < CA < ... < AB < BB < CB ...) + + NB: this method will fail if chains do not conform to this mmCIF naming + convention. + + Only to be used for third party metrics that rely on the chain order. + Elsewhere chains should be identified by name and code should be agnostic to + the order. + """ + sorted_chains = sorted(self.chains, key=mmcif.str_id_to_int_id) + return self.reorder_chains(new_order=sorted_chains) + + @functools.cached_property + def atom_ids(self) -> Sequence[tuple[str, str, None, str]]: + """Gets a list of atom ID tuples from Structure class arrays. + + Returns: + A list of tuples of (chain_id, res_id, insertion_code, atom_name) where + insertion code is always None. There is one element per atom, and the + list is ordered according to the order of atoms in the input arrays. + """ + # Convert to Numpy strings, then to Python strings (dtype=object). + res_ids = self.residues_table.id.astype(str).astype(object) + res_ids = res_ids[ + self.residues_table.index_by_key[self.atoms_table.res_key] + ] + ins_codes = [None] * self.num_atoms + return list( + zip(self.chain_id, res_ids, ins_codes, self.atom_name, strict=True) + ) + + def order_and_drop_atoms_to_match( + self, + other: 'Structure', + *, + allow_missing_atoms: bool = False, + ) -> Self: + """Returns a new structure with atoms ordered & dropped to match another's. + + This performs two operations simultaneously: + * Ordering the atoms in this structure to match the order in the other. + * Dropping atoms in this structure that do not appear in the other. + + Example: + Consider a prediction and ground truth with the following atoms, described + using tuples of `(chain_id, res_id, atom_name)`: + * `prediction: [(A, 1, CA), (A, 1, N), (A, 2, CA), (B, 1, CA)]` + * `ground_truth: [(B, 1, CA), (A, 1, N), (A, 1, CA)]` + Note how the ground truth is missing the `(A, 2, CA)` atom and also + has the atoms in a different order. This method returns a modified + prediction that has reordered atoms and without any atoms not in the ground + truth so that its atom list looks the same as the ground truth atom list. + This means `prediction.coords` and `ground_truth.coords` now have the + same shape and can be compared across the atom dimension. + + Note that matching residues with no atoms and matching chains with no + residues will also be kept. E.g. in the example above, if prediction and + ground truth both had an unresolved residue (A, 3), the output structure + will also have an unresolved residue (A, 3). + + Args: + other: Another `Structure`. This provides the reference ordering that is + used to sort this structure's atom arrays. + allow_missing_atoms: Whether to skip atoms present in `other` but not this + structure and return a structure containing a subset of the atoms in the + other structure. + + Returns: + A new `Structure`, based on this structure, which, if + `allow_missing_atoms` is False, contains exactly the same atoms as in + the `other` structure and which matches the `other` structure in terms + of the order of the atoms in the field arrays. 
Otherwise, if missing + atoms are allowed then the resulting structure contains a subset of + those atoms in the other structure. + + Raises: + MissingAtomError: If there are atoms present in the other structure that + cannot be found in this structure. + """ + atom_index_map = {atom_id: i for i, + atom_id in enumerate(self.atom_ids)} + try: + if allow_missing_atoms: + # Only include atoms that were found in the other structure. + atom_indices = [ + atom_index + for atom_id in other.atom_ids + if (atom_index := atom_index_map.get(atom_id)) is not None + ] + else: + atom_indices = [ + atom_index_map[atom_id] # Hard fail on missing. + for atom_id in other.atom_ids + ] + except KeyError as e: + if len(e.args[0]) == 4: + chain_id, res_id, ins_code, atom_name = e.args[0] + raise MissingAtomError( + f'No atom in this structure (name: {self._name}) matches atom in ' + f'other structure (name: {other.name}) with internal (label) chain ' + f'ID {chain_id}, residue ID {res_id}, insertion code {ins_code} ' + f'and atom name {atom_name}.' + ) from e + else: + raise + + def _iter_residues(struc: Self) -> Iterable[tuple[str, str]]: + yield from zip( + struc.chains_table['id', struc.residues_table.chain_key], + struc.residues_table.id, + strict=True, + ) + + chain_index_map = { + chain_id: i for i, chain_id in enumerate(self._chains.id) + } + chain_indices = [ + chain_index + for chain_id in other.chains_table.id + if (chain_index := chain_index_map.get(chain_id)) is not None + ] + residue_index_map = { + res_id: i for i, res_id in enumerate(_iter_residues(self)) + } + res_indices = [ + residue_index + for res_id in _iter_residues(other) + if (residue_index := residue_index_map.get(res_id)) is not None + ] + + # Reorder all tables. + chains = self._chains.apply_index( + np.array(chain_indices, dtype=np.int64)) + residues = self._residues.apply_index( + np.array(res_indices, dtype=np.int64)) + atoms = self._atoms.apply_index(np.array(atom_indices, dtype=np.int64)) + + # Get chain keys in the order they appear in the atoms table. + new_chain_boundaries = _get_change_indices(atoms.chain_key) + new_chain_key_order = atoms.chain_key[new_chain_boundaries] + if len(new_chain_key_order) != len(set(new_chain_key_order)): + raise ValueError( + f'Chain keys not contiguous after reordering: {new_chain_key_order}' + ) + + # Get residue keys in the order they appear in the atoms table. + new_res_boundaries = _get_change_indices(atoms.res_key) + new_res_key_order = atoms.res_key[new_res_boundaries] + if len(new_res_key_order) != len(set(new_res_key_order)): + raise ValueError( + f'Residue keys not contiguous after reordering: {new_res_key_order}' + ) + + # If any atoms were deleted, propagate that into the bonds table. 
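+    # E.g. a bond row whose from_atom_key or dest_atom_key refers to a dropped
+    # atom is itself removed by _cascade_delete, keeping the bonds table valid.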
+ updated_tables = self._cascade_delete( + chains=chains, + residues=residues, + atoms=atoms, + ) + return self.copy_and_update( + chains=chains, + residues=residues, + atoms=updated_tables.atoms, + bonds=updated_tables.bonds, + ) + + def copy_and_update( + self, + *, + name: str | Literal[_UNSET] = _UNSET, + release_date: datetime.date | None | Literal[_UNSET] = _UNSET, + resolution: float | None | Literal[_UNSET] = _UNSET, + structure_method: str | None | Literal[_UNSET] = _UNSET, + bioassembly_data: ( + bioassemblies.BioassemblyData | None | Literal[_UNSET] + ) = _UNSET, + chemical_components_data: ( + struc_chem_comps.ChemicalComponentsData | None | Literal[_UNSET] + ) = _UNSET, + chains: structure_tables.Chains | None | Literal[_UNSET] = _UNSET, + residues: structure_tables.Residues | None | Literal[_UNSET] = _UNSET, + atoms: structure_tables.Atoms | None | Literal[_UNSET] = _UNSET, + bonds: structure_tables.Bonds | None | Literal[_UNSET] = _UNSET, + skip_validation: bool = False, + ) -> Self: + """Performs a shallow copy but with specified fields updated.""" + + def all_unset(fields): + return all(field == _UNSET for field in fields) + + if all_unset((chains, residues, atoms, bonds)): + if all_unset(( + name, + release_date, + resolution, + structure_method, + bioassembly_data, + chemical_components_data, + )): + raise ValueError( + 'Unnecessary call to copy_and_update with no changes. As Structure' + ' and its component tables are immutable, there is no need to copy' + ' it. Any subsequent operation that modifies structure will return' + ' a new object.' + ) + else: + raise ValueError( + 'When only changing global fields, prefer to use the specialised ' + 'copy_and_update_globals.' + ) + + def select(field, default): + return field if field != _UNSET else default + + return Structure( + name=select(name, self.name), + release_date=select(release_date, self.release_date), + resolution=select(resolution, self.resolution), + structure_method=select(structure_method, self.structure_method), + bioassembly_data=select(bioassembly_data, self.bioassembly_data), + chemical_components_data=select( + chemical_components_data, self.chemical_components_data + ), + chains=select(chains, self._chains), + residues=select(residues, self._residues), + atoms=select(atoms, self._atoms), + bonds=select(bonds, self._bonds), + skip_validation=skip_validation, + ) + + def _copy_and_update( + self, skip_validation: bool = False, **changes: Any + ) -> Self: + """Performs a shallow copy but with specified fields updated.""" + if not changes: + raise ValueError( + 'Unnecessary call to copy_and_update with no changes. As Structure ' + 'and its component tables are immutable, there is no need to copy ' + 'it. Any subsequent operation that modifies structure will return a ' + 'new object.' + ) + + if 'author_naming_scheme' in changes: + raise ValueError( + 'Updating using author_naming_scheme is not supported. Update ' + 'auth_asym_id, entity_id, entity_desc fields directly in the chains ' + 'table and auth_seq_id, insertion_code in the residues table.' + ) + + if all(k in GLOBAL_FIELDS for k in changes): + raise ValueError( + 'When only changing global fields, prefer to use the specialised ' + 'copy_and_update_globals.' 
+    )
+
+    if all(k in V2_FIELDS for k in changes):
+      constructor_kwargs = {field: self[field] for field in V2_FIELDS}
+      constructor_kwargs.update(changes)
+    elif any(k in ('atoms', 'residues', 'chains') for k in changes):
+      raise ValueError(
+          'Cannot specify atoms/chains/residues table changes with non-v2'
+          f' constructor params: {changes.keys()}'
+      )
+    elif all(k in ATOM_FIELDS for k in changes):
+      if 'atom_key' not in changes:
+        raise ValueError(
+            'When only changing atom fields, prefer to use the specialised '
+            'copy_and_update_atoms.'
+        )
+      # Only atom fields are being updated, do that directly on the atoms
+      # table.
+      updated_atoms = self._atoms.copy_and_update(
+          **{ATOM_FIELDS[k]: v for k, v in changes.items()}
+      )
+      constructor_kwargs = {
+          field: self[field] for field in V2_FIELDS if field != 'atoms'
+      }
+      constructor_kwargs['atoms'] = updated_atoms
+    else:
+      constructor_kwargs = {field: self[field] for field in _UPDATEABLE_FIELDS}
+      constructor_kwargs.update(changes)
+    return Structure(skip_validation=skip_validation, **constructor_kwargs)
+
+  def copy_and_update_coords(self, coords: np.ndarray) -> Self:
+    """Performs a shallow copy but with coordinates updated."""
+    if coords.shape[-2:] != (self.num_atoms, 3):
+      raise ValueError(
+          f'{coords.shape=} does not have last dimensions ({self.num_atoms}, 3)'
+      )
+    updated_atoms = self._atoms.copy_and_update_coords(coords)
+    return self.copy_and_update(atoms=updated_atoms, skip_validation=True)
+
+  def copy_and_update_from_res_arrays(self, **changes: np.ndarray) -> Self:
+    """Like copy_and_update but changes are arrays of length num_residues.
+
+    These changes are first scattered into arrays of length num_atoms such
+    that each value is repeated across the residue at that index, then they
+    are used as the new values of these fields.
+
+    E.g.
+      * This structure's res_id: 1, 1, 1, 2, 3, 3 (3 res, 6 atoms)
+      * new atom_b_factor: 7, 8, 9
+      * Returned structure's atom_b_factor: 7, 7, 7, 8, 9, 9
+
+    Args:
+      **changes: kwargs corresponding to atom array fields, e.g. atom_x or
+        atom_b_factor, but with length num_residues rather than num_atoms.
+        Note that changing atom_key this way is not supported.
+
+    Returns:
+      A new `Structure` with all fields other than those specified as kwargs
+      shallow copied from this structure. The values of the kwargs are
+      scattered across the atom arrays and then used to overwrite these
+      fields for the returned structure.
+    """
+    # We create scatter indices by (1) starting from zeros, then (2) setting
+    # the position where each residue starts to 1 and then (3) doing a
+    # cumulative sum. Finally, since self.res_boundaries always starts with 0
+    # the result of the cumulative sum will start from 1, so (4) we subtract
+    # 1 to get the final array of zero-based indices.
+    # Example, 6 atoms, 3 residues at indices 0, 2 and 5.
+ # (1) 0 0 0 0 0 0 + # (2) 1 0 1 0 0 1 + # (3) 1 1 2 2 2 3 + # (4) 0 0 1 1 1 2 + if not all(c in set(ATOM_FIELDS) - {'atom_key'} for c in changes): + raise ValueError( + 'Changes must only be to atom fields, got changes to' + f' {changes.keys()}' + ) + scatter_idxs = np.zeros((self.num_atoms,), dtype=int) + scatter_idxs[self.res_boundaries] = 1 + scatter_idxs = scatter_idxs.cumsum() - 1 + atom_array_changes = { + ATOM_FIELDS[field]: new_val[scatter_idxs] + for field, new_val in changes.items() + } + updated_atoms = self._atoms.copy_and_update(**atom_array_changes) + return self.copy_and_update(atoms=updated_atoms, skip_validation=True) + + def copy_and_update_globals( + self, + *, + name: str | Literal[_UNSET] = _UNSET, + release_date: datetime.date | Literal[_UNSET] | None = _UNSET, + resolution: float | Literal[_UNSET] | None = _UNSET, + structure_method: str | Literal[_UNSET] | None = _UNSET, + bioassembly_data: ( + bioassemblies.BioassemblyData | Literal[_UNSET] | None + ) = _UNSET, + chemical_components_data: ( + struc_chem_comps.ChemicalComponentsData | Literal[_UNSET] | None + ) = _UNSET, + ) -> Self: + """Returns a shallow copy with the global columns updated.""" + + def select(field, default): + return field if field != _UNSET else default + + name = select(name, self.name) + release_date = select(release_date, self.release_date) + resolution = select(resolution, self.resolution) + structure_method = select(structure_method, self.structure_method) + bioassembly_data = select(bioassembly_data, self.bioassembly_data) + chem_data = select(chemical_components_data, + self.chemical_components_data) + + return Structure( + name=name, + release_date=release_date, + resolution=resolution, + structure_method=structure_method, + bioassembly_data=bioassembly_data, + chemical_components_data=chem_data, + atoms=self._atoms, + residues=self._residues, + chains=self._chains, + bonds=self._bonds, + ) + + def copy_and_update_atoms( + self, + *, + atom_name: np.ndarray | None = None, + atom_element: np.ndarray | None = None, + atom_x: np.ndarray | None = None, + atom_y: np.ndarray | None = None, + atom_z: np.ndarray | None = None, + atom_b_factor: np.ndarray | None = None, + atom_occupancy: np.ndarray | None = None, + ) -> Self: + """Returns a shallow copy with the atoms table updated.""" + new_atoms = structure_tables.Atoms( + key=self._atoms.key, + res_key=self._atoms.res_key, + chain_key=self._atoms.chain_key, + name=atom_name if atom_name is not None else self.atom_name, + element=atom_element if atom_element is not None else self.atom_element, + x=atom_x if atom_x is not None else self.atom_x, + y=atom_y if atom_y is not None else self.atom_y, + z=atom_z if atom_z is not None else self.atom_z, + b_factor=( + atom_b_factor if atom_b_factor is not None else self.atom_b_factor + ), + occupancy=( + atom_occupancy + if atom_occupancy is not None + else self.atom_occupancy + ), + ) + return self.copy_and_update(atoms=new_atoms) + + def _cascade_delete( + self, + *, + chains: structure_tables.Chains | None = None, + residues: structure_tables.Residues | None = None, + atoms: structure_tables.Atoms | None = None, + bonds: structure_tables.Bonds | None = None, + ) -> StructureTables: + """Performs a cascade delete operation on the structure's tables. + + Cascade delete ensures all the tables are consistent after any table fields + are being updated by cascading any deletions down the hierarchy of tables: + chains > residues > atoms > bonds. 
+
+    E.g.: if a row from the residues table is removed then all the atoms in
+    that residue will also be removed from the atoms table. In turn this
+    cascades to the bonds table, by removing any bond row which involves any
+    of those removed atoms. However the chains table will not be modified,
+    even if that was the only residue in its chain, because the chains table
+    is above the residues table in the hierarchy.
+
+    Args:
+      chains: An optional new chains table.
+      residues: An optional new residues table.
+      atoms: An optional new atoms table.
+      bonds: An optional new bonds table.
+
+    Returns:
+      A StructureTables object with the updated tables.
+    """
+    if chains_unchanged := chains is None:
+      chains = self._chains
+    if residues_unchanged := residues is None:
+      residues = self._residues
+    if atoms_unchanged := atoms is None:
+      atoms = self._atoms
+    if bonds is None:
+      bonds = self._bonds
+
+    if not chains_unchanged:
+      residues_mask = membership.isin(
+          residues.chain_key, set(chains.key))  # pylint:disable=attribute-error
+      if not np.all(residues_mask):  # Only apply if this is not a no-op.
+        residues = residues[residues_mask]
+        residues_unchanged = False
+    if not residues_unchanged:
+      atoms_mask = membership.isin(
+          atoms.res_key, set(residues.key))  # pylint:disable=attribute-error
+      if not np.all(atoms_mask):  # Only apply if this is not a no-op.
+        atoms = atoms[atoms_mask]
+        atoms_unchanged = False
+    if not atoms_unchanged:
+      bonds = bonds.restrict_to_atoms(atoms.key)
+    return StructureTables(
+        chains=chains, residues=residues, atoms=atoms, bonds=bonds
+    )
+
+  def filter(
+      self,
+      mask: np.ndarray | None = None,
+      *,
+      apply_per_element: bool = False,
+      invert: bool = False,
+      cascade_delete: CascadeDelete = CascadeDelete.CHAINS,
+      **predicate_by_field_name: table.FilterPredicate,
+  ) -> Self:
+    """Filters the structure by field values and returns a new structure.
+
+    Predicates are specified as keyword arguments, with names following the
+    pattern: <table_name>_<column_name>, where table_name := (chain|res|atom).
+    For instance the auth_seq_id column in the residues table can be filtered
+    by passing `res_auth_seq_id=pred_value`. The full list of valid options
+    is defined in the `col_by_field_name` fields on the different Table
+    dataclasses.
+
+    Predicate values can be either:
+      1. A constant value, e.g. 'CA'. In this case only rows that match this
+         value for the given field are retained.
+      2. A (non-string) iterable, e.g. ('A', 'B'). In this case rows are
+         retained if they match any of the provided values for the given
+         field.
+      3. A boolean function, e.g. lambda b_fac: b_fac < 100.0. In this case
+         only rows that evaluate to True are retained. By default this
+         function's parameter is expected to be an array, unless
+         apply_per_element=True.
+
+    Example usage:
+      # Filter to backbone atoms in residues up to 100 in chain B.
+      filtered_struc = struc.filter(
+          chain_id='B',
+          atom_name=('N', 'CA', 'C'),
+          res_id=lambda res_id: res_id < 100)
+
+    Example usage where predicate must be applied per-element:
+      # Filter to residues with IDs in either [1, 100) or [300, 400).
+      ranges = ((1, 100), (300, 400))
+      filtered_struc = struc.filter(
+          res_id=lambda i: np.any([start <= i < end for start, end in ranges]),
+          apply_per_element=True)
+
+    Example usage of providing a raw mask:
+      filtered_struc = struc.filter(struc.atom_b_factor < 10.0)
+
+    Args:
+      mask: An optional boolean NumPy array with length equal to num_atoms. If
+        provided then this will be combined with the other predicates so that
+        an atom is included if it is masked-in *and* matches all the
+        predicates.
+      apply_per_element: Whether to apply predicates to each element
+        individually, or to pass the whole column array to the predicate.
+      invert: Whether to remove, rather than retain, the entities which match
+        the specified predicates.
+      cascade_delete: Whether to remove residues and chains which are left
+        unresolved in a cascade. filter operates on the atoms table, removing
+        atoms which match the predicate. If all atoms in a residue are
+        removed, the residue is "unresolved". The value of this argument then
+        determines whether such residues and their parent chains should be
+        deleted. FULL implies that all unresolved residues should be deleted,
+        and any chains which are left with no resolved residues should be
+        deleted. CHAINS is the default behaviour - only chains with no
+        resolved residues, and their child residues, are deleted. Unresolved
+        residues in partially resolved chains remain. NONE implies that no
+        unresolved residues or chains should be deleted.
+      **predicate_by_field_name: A mapping from field name to a predicate.
+        Filtered columns must be 1D arrays. If multiple fields are provided
+        as keyword arguments then each predicate is applied and the results
+        are combined using a boolean AND operation, so an atom is only
+        retained if it passes all predicates.
+
+    Returns:
+      A new structure representing a filtered version of the current
+      structure.
+
+    Raises:
+      ValueError: If mask is provided and is not a bool array with shape
+        (num_atoms,).
+    """
+    chain_predicates, res_predicates, atom_predicates = (
+        _unpack_filter_predicates(predicate_by_field_name)
+    )
+    # Get boolean masks for each table. These are None if none of the filter
+    # parameters affect the table in question.
+    chain_mask = self._chains.make_filter_mask(
+        **chain_predicates, apply_per_element=apply_per_element
+    )
+    res_mask = self._residues.make_filter_mask(
+        **res_predicates, apply_per_element=apply_per_element
+    )
+    atom_mask = self._atoms.make_filter_mask(
+        mask, **atom_predicates, apply_per_element=apply_per_element
+    )
+    if atom_mask is None:
+      atom_mask = np.ones((self._atoms.size,), dtype=bool)
+
+    # Remove atoms that belong to filtered out chains.
+    if chain_mask is not None:
+      atom_chain_mask = membership.isin(
+          self._atoms.chain_key, set(self._chains.key[chain_mask])
+      )
+      np.logical_and(atom_mask, atom_chain_mask, out=atom_mask)
+
+    # Remove atoms that belong to filtered out residues.
+    if res_mask is not None:
+      atom_res_mask = membership.isin(
+          self._atoms.res_key, set(self._residues.key[res_mask])
+      )
+      np.logical_and(atom_mask, atom_res_mask, out=atom_mask)
+
+    final_atom_mask = ~atom_mask if invert else atom_mask
+
+    if cascade_delete == CascadeDelete.NONE and np.all(final_atom_mask):
+      # Shortcut: The filter is a no-op, so just return itself.
+ return self + + filtered_atoms = typing.cast( + structure_tables.Atoms, self._atoms[final_atom_mask] + ) + + match cascade_delete: + case CascadeDelete.FULL: + nonempty_residues_mask = np.isin( + self._residues.key, filtered_atoms.res_key + ) + filtered_residues = self._residues[nonempty_residues_mask] + nonempty_chain_mask = np.isin( + self._chains.key, filtered_atoms.chain_key + ) + filtered_chains = self._chains[nonempty_chain_mask] + updated_tables = self._cascade_delete( + chains=filtered_chains, + residues=filtered_residues, + atoms=filtered_atoms, + ) + case CascadeDelete.CHAINS: + # To match v1 behavior we remove chains that have no atoms remaining, + # and we remove residues in those chains. + # NB we do not remove empty residues. + nonempty_chain_mask = membership.isin( + self._chains.key, set(filtered_atoms.chain_key) + ) + filtered_chains = self._chains[nonempty_chain_mask] + updated_tables = self._cascade_delete( + chains=filtered_chains, atoms=filtered_atoms + ) + case CascadeDelete.NONE: + updated_tables = self._cascade_delete(atoms=filtered_atoms) + case _: + raise ValueError( + f'Unknown cascade_delete behaviour: {cascade_delete}') + return self.copy_and_update( + chains=updated_tables.chains, + residues=updated_tables.residues, + atoms=updated_tables.atoms, + bonds=updated_tables.bonds, + skip_validation=True, + ) + + def filter_out(self, *args, **kwargs) -> Self: + """Returns a new structure with the specified elements removed.""" + return self.filter(*args, invert=True, **kwargs) + + def filter_to_entity_type( + self, + *, + protein: bool = False, + rna: bool = False, + dna: bool = False, + dna_rna_hybrid: bool = False, + ligand: bool = False, + water: bool = False, + ) -> Self: + """Filters the structure to only include the selected entity types. + + This convenience method abstracts away the specifics of mmCIF entity + type names which, especially for ligands, are non-trivial. + + Args: + protein: Whether to include protein (polypeptide(L)) chains. + rna: Whether to include RNA chains. + dna: Whether to include DNA chains. + dna_rna_hybrid: Whether to include DNA RNA hybrid chains. + ligand: Whether to include ligand (i.e. not polymer) chains. + water: Whether to include water chains. + + Returns: + The filtered structure. + """ + include_types = [] + if protein: + include_types.append(mmcif_names.PROTEIN_CHAIN) + if rna: + include_types.append(mmcif_names.RNA_CHAIN) + if dna: + include_types.append(mmcif_names.DNA_CHAIN) + if dna_rna_hybrid: + include_types.append(mmcif_names.DNA_RNA_HYBRID_CHAIN) + if ligand: + include_types.extend(mmcif_names.LIGAND_CHAIN_TYPES) + if water: + include_types.append(mmcif_names.WATER) + return self.filter(chain_type=include_types) + + def get_stoichiometry( + self, *, fix_non_standard_polymer_res: bool = False + ) -> Sequence[int]: + """Returns the structure's stoichiometry using chain_res_name_sequence. + + Note that everything is considered (protein, RNA, DNA, ligands) except for + water molecules. If you are interested only in a certain type of entities, + filter them out before calling this method. + + Args: + fix_non_standard_polymer_res: If True, maps non standard residues in + protein / RNA / DNA chains to standard residues (e.g. MSE -> MET) or UNK + / N if a match is not found. + + Returns: + A list of integers, one for each unique chain in the structure, + determining the number of that chain appearing in the structure. The + numbers are sorted highest to lowest. E.g. for an A3B2 protein this method + will return [3, 2]. 
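+
+      Example usage (illustrative):
+        # Count MSE-containing chains together with their MET equivalents:
+        counts = struc.get_stoichiometry(fix_non_standard_polymer_res=True)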
+ """ + filtered = self.filter_to_entity_type( + protein=True, + rna=True, + dna=True, + dna_rna_hybrid=True, + ligand=True, + water=False, + ) + seqs = filtered.chain_res_name_sequence( + include_missing_residues=True, + fix_non_standard_polymer_res=fix_non_standard_polymer_res, + ) + + unique_seq_counts = collections.Counter(seqs.values()) + return sorted(unique_seq_counts.values(), reverse=True) + + def without_hydrogen(self) -> Self: + """Returns the structure without hydrogen atoms.""" + return self.filter( + np.logical_and(self._atoms.element != 'H', + self._atoms.element != 'D') + ) + + def without_terminal_oxygens(self) -> Self: + """Returns the structure without terminal oxygen atoms.""" + terminal_oxygen_filter = np.zeros(self.num_atoms, dtype=bool) + for chain_type, atom_name in mmcif_names.TERMINAL_OXYGENS.items(): + chain_keys = self._chains.key[self._chains.type == chain_type] + chain_atom_filter = np.logical_and( + self._atoms.name == atom_name, + np.isin(self._atoms.chain_key, chain_keys), + ) + np.logical_or( + terminal_oxygen_filter, chain_atom_filter, out=terminal_oxygen_filter + ) + return self.filter_out(terminal_oxygen_filter) + + def reset_author_naming_scheme(self) -> Self: + """Remove author chain/residue ids, entity info and use internal ids.""" + new_chains = structure_tables.Chains( + key=self._chains.key, + id=self._chains.id, + type=self._chains.type, + auth_asym_id=self._chains.id, + entity_id=np.arange(1, self.num_chains + + 1).astype(str).astype(object), + entity_desc=np.full(self.num_chains, '.', dtype=object), + ) + new_residues = structure_tables.Residues( + key=self._residues.key, + chain_key=self._residues.chain_key, + id=self._residues.id, + name=self._residues.name, + auth_seq_id=self._residues.id.astype(str).astype(object), + insertion_code=np.full( + self.num_residues(count_unresolved=True), '?', dtype=object + ), + ) + return self.copy_and_update( + chains=new_chains, residues=new_residues, skip_validation=True + ) + + def filter_residues(self, res_mask: np.ndarray) -> Self: + """Filter resolved residues using a boolean mask.""" + required_shape = (self.num_residues(count_unresolved=False),) + if res_mask.shape != required_shape: + raise ValueError( + f'res_mask must have shape {required_shape}. Got: {res_mask.shape}.' + ) + if res_mask.dtype != bool: + raise ValueError( + f'res_mask must have dtype bool. Got: {res_mask.dtype}.') + + filtered_residues = self.present_residues.filter(res_mask) + atom_mask = np.isin(self._atoms.res_key, filtered_residues.key) + return self.filter(atom_mask) + + def filter_coords( + self, coord_predicate: Callable[[np.ndarray], bool] + ) -> Self: + """Filter a structure's atoms by a function of their coordinates. + + Args: + coord_predicate: A boolean function of coordinate vectors (shape (3,)). + + Returns: + A Structure filtered so that only atoms with coords passing the predicate + function are present. + + Raises: + ValueError: If the coords are not shaped (num_atom, 3). + """ + coords = self.coords + if coords.ndim != 2 or coords.shape[-1] != 3: + raise ValueError( + f'coords should have shape (num_atom, 3). Got {coords.shape}.' + ) + mask = np.vectorize(coord_predicate, signature='(n)->()')(coords) + # This use of _apply_atom_index_array is safe because a boolean mask is + # used, which means the chain/residue/atom ordering will stay unchanged. 
+    return self._apply_atom_index_array(mask, skip_validation=True)
+
+  def filter_polymers_to_single_atom_per_res(
+      self,
+      representative_atom_by_chain_type: Mapping[
+          str, str
+      ] = mmcif_names.RESIDUE_REPRESENTATIVE_ATOMS,
+  ) -> Self:
+    """Filter to one representative atom per polymer residue, ligands unchanged.
+
+    Args:
+      representative_atom_by_chain_type: A mapping from chain type to atom
+        name; only atoms with this name will be kept for this chain type.
+        Chain types from the structure not found in this mapping will keep
+        all their atoms.
+
+    Returns:
+      A Structure filtered so that, per chain type, only the specified atoms
+      are present.
+    """
+    polymer_chain_keys = self._chains.key[
+        string_array.isin(
+            self._chains.type, set(representative_atom_by_chain_type)
+        )
+    ]
+    polymer_atoms_mask = np.isin(self._atoms.chain_key, polymer_chain_keys)
+
+    wanted_atom_by_chain_key = {
+        chain_key: representative_atom_by_chain_type.get(chain_type, None)
+        for chain_key, chain_type in zip(self._chains.key, self._chains.type)
+    }
+    wanted_atoms = string_array.remap(
+        self._atoms.chain_key.astype(object), mapping=wanted_atom_by_chain_key
+    )
+
+    representative_polymer_atoms_mask = polymer_atoms_mask & (
+        wanted_atoms == self._atoms.name
+    )
+
+    return self.filter(representative_polymer_atoms_mask | ~polymer_atoms_mask)
+
+  def drop_non_standard_protein_atoms(self, *, drop_oxt: bool = True) -> Self:
+    """Drops non-standard atom names from protein chains.
+
+    Args:
+      drop_oxt: If True, also drop terminal oxygens (OXT).
+
+    Returns:
+      A new Structure object where the protein chains have been filtered to
+      only contain atoms with names listed in `atom_types`
+      (including OXT unless `drop_oxt` is `True`). Non-protein chains are
+      unaltered.
+    """
+    allowed_names = set(atom_types.ATOM37)
+    if drop_oxt:
+      allowed_names = {n for n in allowed_names if n != atom_types.OXT}
+
+    return self.filter_out(
+        chain_type=mmcif_names.PROTEIN_CHAIN,
+        atom_name=lambda n: string_array.isin(n, allowed_names, invert=True),
+    )
+
+  def drop_non_standard_atoms(
+      self,
+      *,
+      ccd: chemical_components.Ccd,
+      drop_unk: bool,
+      drop_non_ccd: bool,
+      drop_terminal_oxygens: bool = False,
+  ) -> Self:
+    """Drops atoms that are not in the CCD for the given residue type."""
+
+    # We don't remove any atoms in UNL, as it has no standard atoms.
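+    # Editor's sketch of the decision implemented by _keep below:
+    #   - residue name is an unknown type and drop_unk=True   -> drop
+    #   - residue absent from the CCD and drop_non_ccd=False  -> keep
+    #   - atom name is listed for the residue in the CCD      -> keep
+    #   - residue is UNL (no standard atom set)               -> keep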
+ def _keep(atom_index: int) -> bool: + atom_name = self._atoms.name[atom_index] + res_name = self._residues.name[ + self._residues.index_by_key[self._atoms.res_key[atom_index]] + ] + if drop_unk and res_name in residue_names.UNKNOWN_TYPES: + return False + else: + return ( + (not drop_non_ccd and not ccd.get(res_name)) + or atom_name in struc_chem_comps.get_res_atom_names(ccd, res_name) + or res_name == residue_names.UNL + ) + + standard_atom_mask = np.array( + [_keep(atom_i) for atom_i in range(self.num_atoms)], dtype=bool + ) + standard_atoms = self.filter(mask=standard_atom_mask) + if drop_terminal_oxygens: + standard_atoms = standard_atoms.without_terminal_oxygens() + return standard_atoms + + def find_chains_with_unknown_sequence(self) -> Sequence[str]: + """Returns a sequence of chain IDs that contain only unknown residues.""" + unknown_sequences = [] + for start, end in self.iter_chain_ranges(): + try: + unknown_id = residue_names.UNKNOWN_TYPES.index( + self.res_name[start]) + if start + 1 == end or np.all( + self.res_name[start + 1: end] + == residue_names.UNKNOWN_TYPES[unknown_id] + ): + unknown_sequences.append(self.chain_id[start]) + except ValueError: + pass + return unknown_sequences + + def add_bonds( + self, + bonded_atom_pairs: Sequence[ + tuple[tuple[str, int, str], tuple[str, int, str]], + ], + bond_type: str | None = None, + ) -> Self: + """Returns a structure with new bonds added. + + Args: + bonded_atom_pairs: A sequence of pairs of atoms, with one pair per bond. + Each element of the pair is a tuple of (chain_id, res_id, atom_name), + matching values from the respective fields of this structure. The first + element is the start atom, and the second atom is the end atom of the + bond. + bond_type: This type will be used for all bonds in the structure, where + type follows PDB scheme, e.g. unknown (?), hydrog, metalc, covale, + disulf. + + Returns: + A copy of this structure with the new bonds added. If this structure has + bonds already then the new bonds are concatenated onto the end of the + old bonds. NB: bonds are not deduplicated. + """ + atom_key_lookup: dict[tuple[str, str, None, str], int] = dict( + zip(self.atom_ids, self._atoms.key, strict=True) + ) + + # iter_atoms returns a 4-tuple (chain_id, res_id, ins_code, atom_name) but + # the insertion code is always None. It also uses string residue IDs. 
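+    # Illustrative call (hypothetical IDs): a disulfide between chain A
+    # residue 12 and chain B residue 45 could be added with:
+    #   struc.add_bonds([(('A', 12, 'SG'), ('B', 45, 'SG'))],
+    #                   bond_type='disulf')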
+ def _to_internal_res_id( + bonded_atom_id: tuple[str, int, str], + ) -> tuple[str, str, None, str]: + return bonded_atom_id[0], str(bonded_atom_id[1]), None, bonded_atom_id[2] + + from_atom_key = [] + dest_atom_key = [] + for from_atom, dest_atom in bonded_atom_pairs: + from_atom_key.append( + atom_key_lookup[_to_internal_res_id(from_atom)]) + dest_atom_key.append( + atom_key_lookup[_to_internal_res_id(dest_atom)]) + num_bonds = len(bonded_atom_pairs) + bonds_key = np.arange(num_bonds, dtype=np.int64) + from_atom_key = np.array(from_atom_key, dtype=np.int64) + dest_atom_key = np.array(dest_atom_key, dtype=np.int64) + all_unk_col = np.array(['?'] * num_bonds, dtype=object) + if bond_type is None: + bond_type_col = all_unk_col + else: + bond_type_col = np.full((num_bonds,), bond_type, dtype=object) + + max_key = -1 if not self._bonds.size else np.max(self._bonds.key) + new_bonds = structure_tables.Bonds( + key=np.concatenate([self._bonds.key, bonds_key + max_key + 1]), + from_atom_key=np.concatenate( + [self._bonds.from_atom_key, from_atom_key] + ), + dest_atom_key=np.concatenate( + [self._bonds.dest_atom_key, dest_atom_key] + ), + type=np.concatenate([self._bonds.type, bond_type_col]), + role=np.concatenate([self._bonds.role, all_unk_col]), + ) + return self.copy_and_update(bonds=new_bonds) + + @property + def coords(self) -> np.ndarray: + """A [..., num_atom, 3] shaped array of atom coordinates.""" + return np.stack([self._atoms.x, self._atoms.y, self._atoms.z], axis=-1) + + def chain_single_letter_sequence( + self, include_missing_residues: bool = True + ) -> Mapping[str, str]: + """Returns a mapping from chain ID to a single letter residue sequence. + + Args: + include_missing_residues: Whether to include residues that have no atoms. + """ + res_table = ( + self._residues if include_missing_residues else self.present_residues + ) + residue_chain_boundaries = _get_change_indices(res_table.chain_key) + boundaries = self._iter_residue_ranges( + residue_chain_boundaries, + count_unresolved=include_missing_residues, + ) + chain_keys = res_table.chain_key[residue_chain_boundaries] + chain_ids = self._chains.apply_array_to_column('id', chain_keys) + chain_types = self._chains.apply_array_to_column('type', chain_keys) + chain_seqs = {} + for idx, (start, end) in enumerate(boundaries): + chain_id = chain_ids[idx] + chain_type = chain_types[idx] + chain_res = res_table.name[start:end] + if chain_type in mmcif_names.PEPTIDE_CHAIN_TYPES: + unknown_default = 'X' + elif chain_type in mmcif_names.NUCLEIC_ACID_CHAIN_TYPES: + unknown_default = 'N' + else: + chain_seqs[chain_id] = 'X' * chain_res.size + continue + + chain_res = string_array.remap( + chain_res, + mapping=residue_names.CCD_NAME_TO_ONE_LETTER, + inplace=False, + default_value=unknown_default, + ) + chain_seqs[chain_id] = ''.join(chain_res.tolist()) + + return chain_seqs + + def polymer_auth_asym_id_to_label_asym_id( + self, + *, + protein: bool = True, + rna: bool = True, + dna: bool = True, + other: bool = True, + ) -> Mapping[str, str]: + """Mapping from author chain ID to internal chain ID, polymers only. + + This mapping is well defined only for polymers (protein, DNA, RNA), but not + for ligands or water. + + E.g. if a structure had the following internal chain IDs (label_asym_id): + A (protein), B (DNA), C (ligand bound to A), D (ligand bound to A), + E (ligand bound to B). 
+ + Such structure would have this internal chain ID (label_asym_id) -> author + chain ID (auth_asym_id) mapping: + A -> A, B -> B, C -> A, D -> A, E -> B + + This is a bijection only for polymers (A, B), but not for ligands. + + Args: + protein: Whether to include protein (polypeptide(L)) chains. + rna: Whether to include RNA chains. + dna: Whether to include DNA chains. + other: Whether to include other polymer chains, e.g. RNA/DNA hybrid or + polypeptide(D). Note that include_other=True must be set in from_mmcif. + + Returns: + A mapping from author chain ID to the internal (label) chain ID for the + given polymer types in the Structure, ligands/water are ignored. + + Raises: + ValueError: If the mapping from internal chain IDs to author chain IDs is + not a bijection for polymer chains. + """ + allowed_types = set() + if protein: + allowed_types.add(mmcif_names.PROTEIN_CHAIN) + if rna: + allowed_types.add(mmcif_names.RNA_CHAIN) + if dna: + allowed_types.add(mmcif_names.DNA_CHAIN) + if other: + non_standard_chain_types = ( + mmcif_names.POLYMER_CHAIN_TYPES + - mmcif_names.STANDARD_POLYMER_CHAIN_TYPES + ) + allowed_types |= non_standard_chain_types + + auth_asym_id_to_label_asym_id = {} + for chain in self.iter_chains(): + if chain['chain_type'] not in allowed_types: + continue + label_asym_id = chain['chain_id'] + auth_asym_id = chain['chain_auth_asym_id'] + # The mapping from author chain id to label chain id is only one-to-one if + # we restrict our attention to polymers. But check nevertheless. + if auth_asym_id in auth_asym_id_to_label_asym_id: + raise ValueError( + f'Author chain ID "{auth_asym_id}" does not have a unique mapping ' + f'to internal chain ID "{label_asym_id}", it is already mapped to ' + f'"{auth_asym_id_to_label_asym_id[auth_asym_id]}".' + ) + auth_asym_id_to_label_asym_id[auth_asym_id] = label_asym_id + + return auth_asym_id_to_label_asym_id + + def polymer_author_chain_single_letter_sequence( + self, + *, + include_missing_residues: bool = True, + protein: bool = True, + rna: bool = True, + dna: bool = True, + other: bool = True, + ) -> Mapping[str, str]: + """Mapping from author chain ID to single letter aa sequence, polymers only. + + This mapping is well defined only for polymers (protein, DNA, RNA), but not + for ligands or water. + + Args: + include_missing_residues: If True then all residues will be returned for + each polymer chain present in the structure. This uses the all_residues + field and will include residues missing due to filtering operations as + well as e.g. unresolved residues specified in an mmCIF header. + protein: Whether to include protein (polypeptide(L)) chains. + rna: Whether to include RNA chains. + dna: Whether to include DNA chains. + other: Whether to include other polymer chains, e.g. RNA/DNA hybrid or + polypeptide(D). Note that include_other=True must be set in from_mmcif. + + Returns: + A mapping from (author) chain IDs to their single-letter sequences for all + polymers in the Structure, ligands/water are ignored. + + Raises: + ValueError: If the mapping from internal chain IDs to author chain IDs is + not a bijection for polymer chains. 
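+
+    Example (illustrative): a protein/RNA complex might return
+    {'A': 'MKTAYIAK', 'B': 'ACGU'}.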
+ """ + label_chain_id_to_seq = self.chain_single_letter_sequence( + include_missing_residues=include_missing_residues + ) + auth_to_label = self.polymer_auth_asym_id_to_label_asym_id( + protein=protein, rna=rna, dna=dna, other=other + ) + return { + auth: label_chain_id_to_seq[label] + for auth, label in auth_to_label.items() + } + + def chain_res_name_sequence( + self, + *, + include_missing_residues: bool = True, + fix_non_standard_polymer_res: bool = False, + ) -> Mapping[str, Sequence[str]]: + """A mapping from internal chain ID to a sequence of residue names. + + The residue names are the full residue names rather than single letter + codes. For instance, for proteins these are the 3 letter CCD codes. + + Args: + include_missing_residues: Whether to include residues with no atoms in the + returned sequences. + fix_non_standard_polymer_res: Whether to map non standard residues in + protein / RNA / DNA chains to standard residues (e.g. MSE -> MET) or UNK + / N if a match is not found. + + Returns: + A mapping from (internal) chain IDs to a sequence of residue names. + """ + res_table = ( + self._residues if include_missing_residues else self.present_residues + ) + residue_chain_boundaries = _get_change_indices(res_table.chain_key) + boundaries = self._iter_residue_ranges( + residue_chain_boundaries, count_unresolved=include_missing_residues + ) + chain_keys = res_table.chain_key[residue_chain_boundaries] + chain_ids = self._chains.apply_array_to_column('id', chain_keys) + chain_types = self._chains.apply_array_to_column('type', chain_keys) + chain_seqs = {} + for idx, (start, end) in enumerate(boundaries): + chain_id = chain_ids[idx] + chain_type = chain_types[idx] + chain_res = res_table.name[start:end] + if ( + fix_non_standard_polymer_res + and chain_type in mmcif_names.POLYMER_CHAIN_TYPES + ): + chain_seqs[chain_id] = tuple( + fix_non_standard_polymer_residues( + res_names=chain_res, chain_type=chain_type + ) + ) + else: + chain_seqs[chain_id] = tuple(chain_res) + + return chain_seqs + + def fix_non_standard_polymer_res( + self, + res_mapper: Callable[ + [np.ndarray, str], np.ndarray + ] = fix_non_standard_polymer_residues, + ) -> Self: + """Replaces non-standard polymer residues with standard alternatives or UNK. + + e.g. maps 'ACE' -> 'UNK', 'MSE' -> 'MET'. + + NB: Only fixes the residue names, but does not fix the atom names. + E.g., 'MSE' will be renamed to 'MET' but its 'SE' atom will not be renamed + to 'S'. Fixing MSE should be done during conversion from mmcif with the + `fix_mse_residues` flag. + + Args: + res_mapper: An optional function that accepts a numpy array of residue + names and chain_type, and returns an array with fixed res_names. This + defaults to fix_non_standard_polymer_residues. + + Returns: + A Structure containing only standard residue types (or 'UNK') in its + polymer chains. + """ + fixed_res_name = self._residues.name.copy() + chain_change_indices = _get_change_indices(self._residues.chain_key) + for start, end in self._iter_atom_ranges(chain_change_indices): + chain_key = self._residues.chain_key[start] + chain_type = self._chains.type[self._chains.index_by_key[chain_key]] + if chain_type not in mmcif_names.POLYMER_CHAIN_TYPES: + continue # We don't need to change anything for non-polymers. 
+      fixed_res_name[start:end] = res_mapper(
+          fixed_res_name[start:end], chain_type
+      )
+    fixed_residues = self._residues.copy_and_update(name=fixed_res_name)
+    return self.copy_and_update(residues=fixed_residues, skip_validation=True)
+
+  @property
+  def slice_leading_dims(self) -> '_LeadingDimSlice':
+    """Used to create a new Structure by slicing into the leading dimensions.
+
+    Example usage 1:
+
+    ```
+    final_state = multi_state_struc.slice_leading_dims[-1]
+    ```
+
+    Example usage 2:
+
+    ```
+    # Structure has leading batch and time dimensions.
+    # Get final 3 time frames from first two batch elements.
+    sliced_strucs = batched_trajectories.slice_leading_dims[:2, -3:]
+    ```
+    """
+    return _LeadingDimSlice(self)
+
+  def unstack(self, axis: int = 0) -> Sequence[Self]:
+    """Unstacks a multi-model structure into a list of Structures.
+
+    This method is the inverse of `stack`.
+
+    Example usage:
+    ```
+    strucs = multi_dim_struc.unstack(axis=0)
+    ```
+
+    Args:
+      axis: The axis to unstack over. The structures in the returned list
+        won't have this axis in their coordinate or b-factor fields.
+
+    Returns:
+      A list of `Structure`s with length equal to the size of the specified
+      axis in the coordinate field arrays.
+
+    Raises:
+      IndexError: If axis does not refer to one of the leading dimensions of
+        `self.atoms_table.size`.
+    """
+    ndim = self._atoms.ndim
+    if not (-ndim <= axis < ndim):
+      raise IndexError(
+          f'{axis=} is out of range for atom coordinate fields with {ndim=}.'
+      )
+    elif axis < 0:
+      axis += ndim
+    if axis == ndim - 1:
+      raise IndexError(
+          'axis must refer to one of the leading dimensions, not the final '
+          f'dimension. The atom fields have {ndim=} and {axis=} was specified.'
+      )
+    unstacked = []
+    leading_dim_slice = self.slice_leading_dims  # Compute once here.
+    for i in range(self._atoms.shape[axis]):
+      slice_i = (slice(None),) * axis + (i,)
+      unstacked.append(leading_dim_slice[slice_i])
+    return unstacked
+
+  def split_by_chain(self) -> Sequence[Self]:
+    """Splits a Structure into single-chain Structures, one for each chain.
+
+    The obtained structures can be merged back together into the original
+    structure using the `concat` function.
+
+    Returns:
+      A list of `Structure`s, one for each chain. The order is the same as
+      the chain order in the original Structure.
+    """
+    return [self.filter(chain_id=chain_id) for chain_id in self.chains]
+
+  def transform_states_to_chains(self) -> Self:
+    """Transforms states to chains.
+
+    A multi-state protein structure will be transformed to a multi-chain
+    single-state protein structure. Useful for visualising multiple states
+    to examine diversity. This structure's coordinate fields must have shape
+    `(num_states, num_atoms)`.
+
+    Returns:
+      A new `Structure`, based on this structure, but with the multiple
+      states now represented as `num_states * num_chains` chains in a
+      single-state protein.
+
+    Raises:
+      ValueError: If this structure's array fields don't have shape
+        `(num_states, num_atoms)`.
+    """
+    if self._atoms.ndim != 2:
+      raise ValueError(
+          'Coordinate field tensor must have 2 dimensions: '
+          f'(num_states, num_atoms), got {self._atoms.ndim}.'
+      )
+    return concat(self.unstack(axis=0))
+
+  def merge_chains(
+      self,
+      *,
+      chain_groups: Sequence[Sequence[str]],
+      chain_group_ids: Sequence[str] | None = None,
+      chain_group_types: Sequence[str] | None = None,
+  ) -> Self:
+    """Merges chains in each group into a single chain.
+
+    If a Structure has chains A, B, C, D, E, and
+    `merge_chains([[A, C], [B, D], [E]])` is called, the new Structure will
+    have 3 chains A, B, C, the first being the concatenation A+C, the second
+    B+D, and the third just the original chain E.
+
+    Args:
+      chain_groups: Each group defines what chains should be merged into a
+        single chain. The output structure will therefore have
+        len(chain_groups) chains. Residue IDs are renumbered to preserve
+        uniqueness within new chains. Order of chain groups and within each
+        group matters.
+      chain_group_ids: Optional sequence of new chain IDs for each group. If
+        not given, the new internal chain IDs (label_asym_id) are assigned in
+        the standard mmCIF order (i.e. A, B, ..., Z, AA, BA, CA, ...). Author
+        chain names (auth_asym_id) are set to be equal to the new internal
+        chain IDs.
+      chain_group_types: Optional sequence of new chain types for each group.
+        If not given, only chains with the same type can be merged.
+
+    Returns:
+      A new `Structure` with chains merged together into a single chain
+      within each chain group.
+
+    Raises:
+      ValueError: If chain_group_ids or chain_group_types are given but don't
+        match the length of chain_groups.
+      ValueError: If the chain IDs in the flattened chain_groups don't match
+        the chain IDs in the Structure.
+      ValueError: If chains in any of the groups don't have the same chain
+        type.
+    """
+    if chain_group_ids and len(chain_group_ids) != len(chain_groups):
+      raise ValueError(
+          'chain_group_ids must be the same length as chain_groups: '
+          f'{len(chain_group_ids)=} != {len(chain_groups)=}'
+      )
+    if chain_group_types and len(chain_group_types) != len(chain_groups):
+      raise ValueError(
+          'chain_group_types must be the same length as chain_groups: '
+          f'{len(chain_group_types)=} != {len(chain_groups)=}'
+      )
+    flattened = sorted(itertools.chain.from_iterable(chain_groups))
+    if flattened != sorted(self.chains):
+      raise ValueError(
+          'IDs in chain groups do not match Structure chain IDs: '
+          f'{chain_groups=}, chains={self.chains}'
+      )
+
+    new_chain_key_by_chain_id = {}
+    for new_chain_key, group_chain_ids in enumerate(chain_groups):
+      for chain_id in group_chain_ids:
+        new_chain_key_by_chain_id[chain_id] = new_chain_key
+
+    chain_key_remap = {}
+    new_chain_type_by_chain_key = {}
+    for old_chain_key, old_chain_id, old_chain_type in zip(
+        self._chains.key, self._chains.id, self._chains.type
+    ):
+      new_chain_key = new_chain_key_by_chain_id[old_chain_id]
+      chain_key_remap[old_chain_key] = new_chain_key
+
+      if new_chain_key not in new_chain_type_by_chain_key:
+        new_chain_type_by_chain_key[new_chain_key] = old_chain_type
+      elif not chain_group_types:
+        if new_chain_type_by_chain_key[new_chain_key] != old_chain_type:
+          bad_types = [
+              f'{cid}: {self._chains.type[np.where(self._chains.id == cid)][0]}'
+              for cid in chain_groups[new_chain_key]
+          ]
+          raise ValueError(
+              'Inconsistent chain types within group:\n' + '\n'.join(bad_types)
+          )
+
+    new_chain_key = np.arange(len(chain_groups), dtype=np.int64)
+    if chain_group_ids:
+      new_chain_id = np.array(chain_group_ids, dtype=object)
+    else:
+      new_chain_id = np.array(
+          [mmcif.int_id_to_str_id(k) for k in new_chain_key + 1], dtype=object
+      )
+    if chain_group_types:
+      new_chain_type = np.array(chain_group_types, dtype=object)
+    else:
+      new_chain_type = np.array(
+          [new_chain_type_by_chain_key[k] for k in new_chain_key], dtype=object
+      )
+    new_chains = structure_tables.Chains(
+        key=new_chain_key,
+        id=new_chain_id,
+        type=new_chain_type,
+        auth_asym_id=new_chain_id,
+        entity_id=np.char.mod('%d', new_chain_key + 1).astype(object),
+        entity_desc=np.full(len(chain_groups), fill_value='.', dtype=object),
+    )
+
+    # Remap chain keys and sort residues to match the chain table order.
+    new_residues = self._residues.copy_and_remap(chain_key=chain_key_remap)
+    new_residues = new_residues.apply_index(
+        np.argsort(new_residues.chain_key, kind='stable')
+    )
+    # Renumber residues uniquely within each chain.
+    indices = np.arange(new_residues.chain_key.size, dtype=np.int32)
+    new_res_ids = (indices + 1) - np.maximum.accumulate(
+        indices * (new_residues.chain_key != np.roll(new_residues.chain_key, 1))
+    )
+    new_residues = new_residues.copy_and_update(id=new_res_ids)
+
+    # Remap chain keys and sort atoms to match the chain table order.
+    new_atoms = self._atoms.copy_and_remap(chain_key=chain_key_remap)
+    new_atoms = new_atoms.apply_index(
+        np.argsort(new_atoms.chain_key, kind='stable')
+    )
+
+    return self.copy_and_update(
+        chains=new_chains,
+        residues=new_residues,
+        atoms=new_atoms,
+        bonds=self._bonds,
+    )
+
+  def to_res_arrays(
+      self,
+      *,
+      include_missing_residues: bool,
+      atom_order: Mapping[str, int] = atom_types.ATOM37_ORDER,
+  ) -> tuple[np.ndarray, np.ndarray]:
+    """Returns an atom position and atom mask array with a num_res dimension.
+
+    NB: All residues in the structure will appear in the residue dimension
+    but atoms will only have a True (1.0) mask value if they are defined in
+    `atom_order`.
+
+    Args:
+      include_missing_residues: If True then the res arrays will include rows
+        for missing residues where all atoms will be masked out. Otherwise
+        these will simply be skipped.
+      atom_order: Atom order mapping atom names to their index in the atom
+        dimension of the returned arrays. Defaults to the protein atom order
+        (atom_types.ATOM37_ORDER); choose atom_types.ATOM29_ORDER for
+        nucleics.
+
+    Returns:
+      A pair of arrays:
+        * atom_positions: [num_res, atom_type_num, 3] float32 array of
+          coords.
+        * atom_mask: [num_res, atom_type_num] float32 atom mask denoting
+          which atoms are present in this Structure.
+    """
+    num_res = self.num_residues(count_unresolved=include_missing_residues)
+    atom_type_num = len(atom_order)
+    atom_positions = np.zeros((num_res, atom_type_num, 3), dtype=np.float32)
+    atom_mask = np.zeros((num_res, atom_type_num), dtype=np.float32)
+
+    all_residues = None if not include_missing_residues else self.all_residues
+    for i, atom in enumerate_residues(self.iter_atoms(), all_residues):
+      atom_idx = atom_order.get(atom['atom_name'])
+      if atom_idx is not None:
+        atom_positions[i, atom_idx, 0] = atom['atom_x']
+        atom_positions[i, atom_idx, 1] = atom['atom_y']
+        atom_positions[i, atom_idx, 2] = atom['atom_z']
+        atom_mask[i, atom_idx] = 1.0
+
+    return atom_positions, atom_mask
+
+  def to_res_atom_lists(
+      self, *, include_missing_residues: bool
+  ) -> Sequence[Sequence[Mapping[str, Any]]]:
+    """Returns a list of atom dictionaries grouped by residue.
+
+    If this is a multi-model structure, each atom will store its atom_x,
+    atom_y, atom_z, and atom_b_factor fields as NumPy arrays with the shape
+    of the leading dimension(s). If this is a single-model structure, these
+    fields will just be scalars.
+
+    Args:
+      include_missing_residues: If True, then the output list will contain an
+        empty list of atoms for missing residues. Otherwise missing residues
+        will simply be skipped.
+
+    Returns:
+      A list of size `num_res`. Each element in the list represents the atoms
+      of one residue. If a residue is present, the list will contain an atom
+      dictionary for every atom present in that residue. If a residue is
+      missing and `include_missing_residues=True`, the list for that missing
+      residue will be empty.
+    """
+    num_res = self.num_residues(count_unresolved=include_missing_residues)
+    residue_atoms = [[] for _ in range(num_res)]
+    all_residues = None if not include_missing_residues else self.all_residues
+
+    # We could yield directly in this loop but the code would be more
+    # complex. Let's optimise if memory usage is an issue.
+    for res_index, atom in enumerate_residues(self.iter_atoms(), all_residues):
+      residue_atoms[res_index].append(atom)
+
+    return residue_atoms
+
+  def reorder_chains(self, new_order: Sequence[str]) -> Self:
+    """Reorders tables so that the label_asym_ids are in the given order.
+
+    This method changes the order of the chains, residues, and atoms tables
+    so that they are all consistent with each other. Moreover, it remaps
+    chain keys so that they stay monotonically increasing in the
+    chains/residues/atoms tables.
+
+    Args:
+      new_order: The order in which the chain IDs (label_asym_id) should be.
+        This must be a permutation of the current chain IDs.
+
+    Returns:
+      A structure with chains reordered.
+    """
+    if len(new_order) != len(self.chains):
+      raise ValueError(
+          f'The new number of chains ({len(new_order)}) does not match the '
+          f'current number of chains ({len(self.chains)}).'
+      )
+    new_chain_set = set(new_order)
+    if len(new_chain_set) != len(new_order):
+      raise ValueError(f'The new order {new_order} contains non-unique IDs.')
+    if new_chain_set.symmetric_difference(set(self.chains)):
+      raise ValueError(
+          f'New chain IDs {new_order} do not match the old {set(self.chains)}'
+      )
+
+    if self.chains == tuple(new_order):
+      # Shortcut: the new order is the same as the current one.
+      return self
+
+    desired_chain_id_pos = {chain_id: i for i, chain_id in enumerate(new_order)}
+
+    current_chain_index_order = np.empty(self.num_chains, dtype=np.int64)
+    for index, old_chain_id in enumerate(self._chains.id):
+      current_chain_index_order[index] = desired_chain_id_pos[old_chain_id]
+    chain_reorder = np.argsort(current_chain_index_order, kind='stable')
+    chain_key_map = dict(
+        zip(self._chains.key[chain_reorder], range(self.num_chains))
+    )
+    chains = self._chains.apply_index(chain_reorder)
+    chains = chains.copy_and_remap(key=chain_key_map)
+
+    # The stable sort keeps the original residue ordering within each chain.
+    residues = self._residues.copy_and_remap(chain_key=chain_key_map)
+    residue_reorder = np.argsort(residues.chain_key, kind='stable')
+    residues = residues.apply_index(residue_reorder)
+
+    # The stable sort keeps the original atom ordering within each chain.
+    atoms = self._atoms.copy_and_remap(chain_key=chain_key_map)
+    atoms_reorder = np.argsort(atoms.chain_key, kind='stable')
+    atoms = atoms.apply_index(atoms_reorder)
+
+    # Bonds unchanged - each references 2 atom keys, hence ordering not
+    # defined.
+    return self.copy_and_update(chains=chains, residues=residues, atoms=atoms)
+
+  def rename_auth_asym_ids(self, new_id_by_old_id: Mapping[str, str]) -> Self:
+    """Returns a new structure with renamed auth_asym_ids.
+
+    Args:
+      new_id_by_old_id: A mapping from original auth_asym_ids to their new
+        values. Any auth_asym_ids in this structure that are not in the
+        mapping will remain unchanged.
+
+    Raises:
+      ValueError: If any two previously distinct polymer chains do not have
+        unique names anymore after the rename.
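+
+    Example (illustrative):
+      renamed = struc.rename_auth_asym_ids({'A': 'H', 'B': 'L'})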
+ """ + mapped_chains = self._chains.copy_and_remap( + auth_asym_id=new_id_by_old_id) + mapped_polymer_ids = mapped_chains.filter( + type=mmcif_names.POLYMER_CHAIN_TYPES + ).auth_asym_id + if len(mapped_polymer_ids) != len(set(mapped_polymer_ids)): + raise ValueError( + 'The new polymer auth_asym_ids are not unique:' + f' {sorted(mapped_polymer_ids)}.' + ) + return self.copy_and_update(chains=mapped_chains, skip_validation=True) + + def rename_chain_ids(self, new_id_by_old_id: Mapping[str, str]) -> Self: + """Returns a new structure with renamed chain IDs (label_asym_ids). + + The chains' auth_asym_ids will be updated to be identical to the chain ID + since there isn't one unambiguous way to maintain the auth_asym_ids after + renaming the chain IDs (depending on whether you view the auth_asym_id as + more strongly associated with a given physical chain, or with a given + chain ID). + + The residues' auth_seq_id will be updated to be identical to the residue ID + since they are strongly tied to the original author chain naming and keeping + them would be misleading. + + Args: + new_id_by_old_id: A mapping from original chain ID to their new values. + Any chain IDs in this structure that are not in this mapping will remain + unchanged. + + Returns: + A new structure with renamed chains (and bioassembly data if it is + present). + + Raises: + ValueError: If any two previously distinct chains do not have unique names + anymore after the rename. + """ + new_chain_id = string_array.remap(self._chains.id, new_id_by_old_id) + if len(new_chain_id) != len(set(new_chain_id)): + raise ValueError( + f"New chain names aren't unique: {sorted(new_chain_id)}") + + # Map label_asym_ids in the bioassembly data. + if self._bioassembly_data is None: + new_bioassembly_data = None + else: + new_bioassembly_data = self._bioassembly_data.rename_label_asym_ids( + new_id_by_old_id, present_chains=set(self.present_chains.id) + ) + + # Set author residue IDs to be the string version of internal residue IDs. + new_residues = self._residues.copy_and_update( + auth_seq_id=self._residues.id.astype(str).astype(object) + ) + + new_chains = self._chains.copy_and_update( + id=new_chain_id, auth_asym_id=new_chain_id + ) + + return self.copy_and_update( + bioassembly_data=new_bioassembly_data, + chains=new_chains, + residues=new_residues, + skip_validation=True, + ) + + @functools.cached_property + def chains(self) -> tuple[str, ...]: + """Ordered internal chain IDs (label_asym_id) present in the Structure.""" + return tuple(self._chains.id) + + def rename_res_name( + self, + res_name_map: Mapping[str, str], + fail_if_not_found: bool = True, + ) -> Self: + """Returns a copy of this structure with residues renamed. + + Residue names in chemical components data will also be renamed. + + Args: + res_name_map: A mapping from old residue names to new residue names. Any + residues that are not in this mapping will be left unchanged. + fail_if_not_found: Whether to fail if keys in the res_name_map mapping are + not found in this structure's residues' `name` column. + + Raises: + ValueError: If `fail_if_not_found=True` and a residue name isn't found in + the residues table's `name` field. 
+ """ + res_name_set = set(self._residues.name) + if fail_if_not_found: + for res_name in res_name_map: + if res_name not in res_name_set: + raise ValueError( + f'"{res_name}" not found in this structure.') + new_residues = self._residues.copy_and_remap(name=res_name_map) + + if self._chemical_components_data is not None: + chem_comp = { + res_name_map.get(res_name, res_name): data + for res_name, data in self._chemical_components_data.chem_comp.items() + } + new_chem_comp = struc_chem_comps.ChemicalComponentsData(chem_comp) + else: + new_chem_comp = None + + return self.copy_and_update( + residues=new_residues, + chemical_components_data=new_chem_comp, + skip_validation=True, + ) + + def rename_chains_to_match( + self, + other: 'Structure', + *, + fuzzy_match_non_standard_res: bool = True, + ) -> Self: + """Returns a new structure with renamed chains to match another's. + + Example: + This structure has chains: {'A': 'DEEP', 'B': 'MIND', 'C': 'MIND'} + Other structure has chains: {'X': 'DEEP', 'Z': 'MIND', 'Y': 'MIND'} + + After calling this method, you will get a structure that has chains named: + {'X': 'DEEP', 'Z': 'MIND', Y: 'MIND'} + + Args: + other: Another `Structure`. This provides the reference chain names that + is used to rename this structure's chains. + fuzzy_match_non_standard_res: If True, protein/RNA/DNA chains with the + same one letter sequence will be matched. e.g. "MET-MET-UNK1" will match + "MET-MSE-UNK2", since both will be mapped to "MMX". If False, we require + the full res_names to match. + + Returns: + A new `Structure`, based on this structure, which has chains renamed to + match the other structure. + """ + sequences = self.chain_res_name_sequence( + include_missing_residues=True, + fix_non_standard_polymer_res=fuzzy_match_non_standard_res, + ) + + other_sequences = other.chain_res_name_sequence( + include_missing_residues=True, + fix_non_standard_polymer_res=fuzzy_match_non_standard_res, + ) + + # Check that the sequences are the same. + sequence_counts = collections.Counter(sequences.values()) + other_sequence_counts = collections.Counter(other_sequences.values()) + if other_sequence_counts != sequence_counts: + raise ValueError( + 'The other structure does not have the same sequences\n' + f' other: {other_sequence_counts}\n self: {sequence_counts}' + ) + + new_decoy_id_by_old_id = {} + used_chain_ids = set() + # Sort self keys and take min over other to make matching deterministic. + # The matching is arbitrary but this helps debugging. + for self_chain_id, self_seq in sorted(sequences.items()): + # Find corresponding chains in the other structure. + other_chain_id = min( + k + for k, v in other_sequences.items() + if v == self_seq and k not in used_chain_ids + ) + + new_decoy_id_by_old_id[self_chain_id] = other_chain_id + used_chain_ids.add(other_chain_id) + + return self.rename_chain_ids(new_decoy_id_by_old_id) + + def _apply_bioassembly_transform( + self, transform: bioassemblies.Transform + ) -> Self: + """Applies a bioassembly transform to this structure.""" + base_struc = self.filter(chain_id=transform.chain_ids) + transformed_atoms = base_struc.atoms_table.copy_and_update_coords( + transform.apply_to_coords(base_struc.coords) + ) + transformed_chains = base_struc.chains_table.copy_and_remap( + id=transform.chain_id_rename_map + ) + # Set the transformed author chain ID to match the label chain ID. 
+ transformed_chains = transformed_chains.copy_and_update( + auth_asym_id=transformed_chains.id + ) + return base_struc.copy_and_update( + chains=transformed_chains, + atoms=transformed_atoms, + skip_validation=True, + ) + + def generate_bioassembly(self, assembly_id: str | None = None) -> Self: + """Generates a biological assembly as a new `Structure`. + + When no assembly ID is provided this method produces a default assembly. + If this structure has no `bioassembly_data` then this returns itself + unchanged. Otherwise a default assembly ID is picked with + `BioassemblyData.get_default_assembly_id()`. + + Args: + assembly_id: The assembly ID to generate, or None to generate a default + bioassembly. + + Returns: + A new `Structure`, based on this one, representing the specified + bioassembly. Note that if the bioassembly contains copies of chains + in the original structure then they will be given new unique chain IDs. + + Raises: + ValueError: If this structure's `bioassembly_data` is `None` and + `assembly_id` is not `None`. + """ + if self._bioassembly_data is None: + if assembly_id is None: + return self + else: + raise ValueError( + f'Unset bioassembly_data, cannot generate assembly {assembly_id}' + ) + + if assembly_id is None: + assembly_id = self._bioassembly_data.get_default_assembly_id() + + transformed_strucs = [ + self._apply_bioassembly_transform(transform) + for transform in self._bioassembly_data.get_transforms(assembly_id) + ] + + # We don't need to assign unique chain IDs because the bioassembly + # transform takes care of remapping chain IDs to be unique. + concatenated = concat(transformed_strucs, + assign_unique_chain_ids=False) + + # Copy over all scalar fields (e.g. name, release date, etc.) other than + # bioassembly_data because it relates only to the pre-transformed structure. + return concatenated.copy_and_update_globals( + name=self.name, + release_date=self.release_date, + resolution=self.resolution, + structure_method=self.structure_method, + bioassembly_data=None, + chemical_components_data=self.chemical_components_data, + ) + + def _to_mmcif_header(self) -> Mapping[str, Sequence[str]]: + raw_mmcif = collections.defaultdict(list) + raw_mmcif['data_'] = [self._name] + raw_mmcif['_entry.id'] = [self._name] + + if self._release_date is not None: + date = [datetime.datetime.strftime(self._release_date, '%Y-%m-%d')] + raw_mmcif['_pdbx_audit_revision_history.revision_date'] = date + raw_mmcif['_pdbx_database_status.recvd_initial_deposition_date'] = date + + if self._resolution is not None: + raw_mmcif['_refine.ls_d_res_high'] = ['%.2f' % self._resolution] + + if self._structure_method is not None: + for method in self._structure_method.split(','): + raw_mmcif['_exptl.method'].append(method) + + if self._bioassembly_data is not None: + raw_mmcif.update(self._bioassembly_data.to_mmcif_dict()) + + # Populate chemical components data for all residues of this Structure. + if self._chemical_components_data: + raw_mmcif.update(self._chemical_components_data.to_mmcif_dict()) + + # Add _software table to store version number used to generate mmCIF. + # Only required data items are used (+ _software.version). + raw_mmcif['_software.pdbx_ordinal'] = ['1'] + raw_mmcif['_software.name'] = ['DeepMind Structure Class'] + raw_mmcif['_software.version'] = [self._VERSION] + raw_mmcif['_software.classification'] = ['other'] # Required. 
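+    # Illustrative shape of the result (hypothetical values): a mapping from
+    # raw mmCIF tags to column lists, e.g.
+    #   {'data_': ['my_struc'], '_entry.id': ['my_struc'],
+    #    '_refine.ls_d_res_high': ['2.10'], ...}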
+ + return raw_mmcif + + def to_mmcif_dict( + self, + *, + coords_decimal_places: int = _COORDS_DECIMAL_PLACES, + ) -> mmcif.Mmcif: + """Returns an Mmcif representing the structure.""" + header = self._to_mmcif_header() + sequence_tables = structure_tables.to_mmcif_sequence_and_entity_tables( + self._chains, self._residues, self._atoms.res_key + ) + atom_and_bond_tables = structure_tables.to_mmcif_atom_site_and_bonds_table( + chains=self._chains, + residues=self._residues, + atoms=self._atoms, + bonds=self._bonds, + coords_decimal_places=coords_decimal_places, + ) + return mmcif.Mmcif({**header, **sequence_tables, **atom_and_bond_tables}) + + def to_mmcif( + self, *, coords_decimal_places: int = _COORDS_DECIMAL_PLACES + ) -> str: + """Returns an mmCIF string representing the structure. + + Args: + coords_decimal_places: The number of decimal places to keep for atom + coordinates, including trailing zeros. + """ + return self.to_mmcif_dict( + coords_decimal_places=coords_decimal_places + ).to_string() + + +class _LeadingDimSlice: + """Helper class for slicing the leading dimensions of a `Structure`. + + Wraps a `Structure` instance and applies a slice operation to the coordinate + fields and other fields that may have leading dimensions (e.g. b_factor). + + Example usage: + t0_struc = multi_state_struc.slice_leading_dims[0] + """ + + def __init__(self, struc: Structure): + self._struc = struc + + def __getitem__(self, *args, **kwargs) -> Structure: + sliced_atom_cols = {} + for col_name in structure_tables.Atoms.multimodel_cols: + if (col := self._struc.atoms_table.get_column(col_name)).ndim > 1: + sliced_col = col.__getitem__(*args, **kwargs) + if ( + not sliced_col.shape + or sliced_col.shape[-1] != self._struc.num_atoms + ): + raise ValueError( + 'Coordinate slice cannot change final (atom) dimension.' + ) + sliced_atom_cols[col_name] = sliced_col + sliced_atoms = self._struc.atoms_table.copy_and_update( + **sliced_atom_cols) + return self._struc.copy_and_update(atoms=sliced_atoms, skip_validation=True) + + +def stack(strucs: Sequence[Structure], axis: int = 0) -> Structure: + """Stacks multiple structures into a single multi-model Structure. + + This function is the inverse of `Structure.unstack()`. + + NB: this function assumes that every structure in `strucs` is identical + other than the coordinates and b-factors. Under this assumption we can safely + copy all these identical fields from the first element of strucs w.l.o.g. + However this is not checked in full detail as full comparison is expensive. + Instead this only checks that the `atom_name` field is identical, and that + the coordinates have the same shape. + + Usage example: + ``` + multi_model_struc = structure.stack(strucs, axis=0) + ``` + + Args: + strucs: A sequence of structures, each with the same atoms, but they may + have different coordinates and b-factors. If any b-factors are not None + then they must have the same shape as each of the coordinate fields. + axis: The axis in the returned structure that represents the different + structures in `strucs` and will have size `len(strucs)`. This cannot be + the final dimension as this is reserved for `num_atoms`. + + Returns: + A `Structure` with the same atoms as the structures in `strucs` but with + all of their coordinates stacked into a new leading axis. + + Raises: + ValueError: If `strucs` is empty. + ValueError: If `strucs` do not all have the same `atom_name` field. 
+ """ + if not strucs: + raise ValueError('Need at least one Structure to stack.') + struc_0, *other_strucs = strucs + for i, struc in enumerate(other_strucs, start=1): + # Check that every structure has the same atom name column. + # This check is intended to catch cases where the input structures might + # contain the same atoms, but in different orders. This won't catch every + # such case, e.g. if these are carbon-alpha-only structures, but should + # catch most cases. + if np.any(struc.atoms_table.name != struc_0.atoms_table.name): + raise ValueError( + f'strucs[0] and strucs[{i}] have mismatching atom name columns.' + ) + + stacked_atoms = struc_0.atoms_table.copy_and_update( + x=np.stack([s.atoms_table.x for s in strucs], axis=axis), + y=np.stack([s.atoms_table.y for s in strucs], axis=axis), + z=np.stack([s.atoms_table.z for s in strucs], axis=axis), + b_factor=np.stack([s.atoms_table.b_factor for s in strucs], axis=axis), + occupancy=np.stack( + [s.atoms_table.occupancy for s in strucs], axis=axis), + ) + return struc_0.copy_and_update(atoms=stacked_atoms, skip_validation=True) + + +def _assign_unique_chain_ids( + strucs: Iterable[Structure], +) -> Sequence[Structure]: + """Creates a sequence of `Structure` objects with unique chain IDs. + + Let e.g. [A, B] denote a structure of two chains A and B, then this function + performs the following kind of renaming operation: + + e.g.: [Z], [C], [B, C] -> [A], [B], [C, D] + + NB: This function uses Structure.rename_chain_ids which will define each + structure's chains.auth_asym_id to be identical to its chains.id columns. + + Args: + strucs: Structures whose chains ids are to be uniquified. + + Returns: + A sequence with the same number of elements as `strucs` but where each + element has had its chains renamed so that they aren't shared with any + other `Structure` in the sequence. + """ + # Start counting at 1 because mmcif.int_id_to_str_id expects integers >= 1. + chain_counter = 1 + strucs_with_new_chain_ids = [] + for struc in strucs: + rename_map = {} + for chain_id in struc.chains: + rename_map[chain_id] = mmcif.int_id_to_str_id(chain_counter) + chain_counter += 1 + renamed = struc.rename_chain_ids(rename_map) + strucs_with_new_chain_ids.append(renamed) + return strucs_with_new_chain_ids + + +def concat( + strucs: Sequence[Structure], + *, + name: str | None = None, + assign_unique_chain_ids: bool = True, +) -> Structure: + """Concatenates structures along the atom dimension. + + NB: By default this function will first assign unique chain IDs to all chains + in `strucs` so that the resulting structure does not contain duplicate chain + IDs. This will also fix entity IDs and author chain IDs. If this is disabled + via `assign_unique_chain_ids=False` the user must ensure that there are no + duplicate chains (label_asym_id). However, duplicate entity IDs and author + chain IDs are allowed as that might be the desired behavior. + + If `assign_unique_chain_ids=True`, note also that the chain_ids may be + overwritten even if they are already unique. + + Let e.g. 
+  Let e.g. [A, B] denote a structure of two chains A and B, then this function
+  performs the following kind of concatenation operation:
+
+  assign_unique_chain_ids=True:
+    label chain IDs : [Z], [C], [B, C] -> [A, B, C, D]
+    author chain IDs: [U], [V], [V, C] -> [A, B, C, D]
+    entity IDs      : [1], [1], [3, 3] -> [1, 2, 3, 4]
+  assign_unique_chain_ids=False:
+    label chain IDs : [D], [B], [C, A] -> [D, B, C, A] (inputs must be unique)
+    author chain IDs: [U], [V], [V, A] -> [U, V, V, A]
+    entity IDs      : [1], [1], [3, 3] -> [1, 1, 3, 3]
+
+  NB: This operation loses some information from the elements of `strucs`,
+  namely the `name`, `resolution`, `release_date` and `bioassembly_data` fields.
+
+  Args:
+    strucs: The `Structure` instances to concatenate. These should all have the
+      same number and shape of leading dimensions (i.e. if any are multi-model
+      structures then they should all have the same number of models).
+    name: Optional name to give to the concatenated structure. If None, the
+      name will be a concatenation of the names of all concatenated structures,
+      joined with underscores.
+    assign_unique_chain_ids: Whether this function will first assign new unique
+      chain IDs, entity IDs and author chain IDs to every chain in `strucs`. If
+      `False` then users must ensure chain IDs are already unique, otherwise an
+      exception is raised. See `_assign_unique_chain_ids` for more information
+      on how this is performed.
+
+  Returns:
+    A new concatenated `Structure` with all of the chains in `strucs` combined
+    into one new structure.
+
+  Raises:
+    ValueError: If `strucs` is empty.
+    ValueError: If `assign_unique_chain_ids=False` and not all chains in
+      `strucs` have unique chain IDs.
+  """
+  if not strucs:
+    raise ValueError('Need at least one Structure to concatenate.')
+  if assign_unique_chain_ids:
+    strucs = _assign_unique_chain_ids(strucs)
+
+  chemical_components_data = {}
+  seen_label_chain_ids = set()
+  for i, struc in enumerate(strucs):
+    if not assign_unique_chain_ids:
+      if seen_cid := seen_label_chain_ids.intersection(struc.chains):
+        raise ValueError(
+            f'Chain IDs {seen_cid} from strucs[{i}] also exist in other'
+            ' members of strucs. All given structures must have unique chain'
+            ' IDs. Consider setting assign_unique_chain_ids=True.'
+        )
+      seen_label_chain_ids.update(struc.chains)
+
+    if struc.chemical_components_data is not None:
+      # pytype: disable=attribute-error  # always-use-property-annotation
+      chemical_components_data.update(
+          struc.chemical_components_data.chem_comp)
+
+  concatted_struc = table.concat_databases(strucs)
+  name = name if name is not None else '_'.join(s.name for s in strucs)
+  # Chain IDs (label and author) are fixed at this point, fix also entity IDs.
+  if assign_unique_chain_ids:
+    entity_id = np.char.mod('%d', np.arange(
+        1, concatted_struc.num_chains + 1))
+    chains = concatted_struc.chains_table.copy_and_update(
+        entity_id=entity_id)
+  else:
+    chains = concatted_struc.chains_table
+  return concatted_struc.copy_and_update(
+      name=name,
+      release_date=None,
+      resolution=None,
+      structure_method=None,
+      bioassembly_data=None,
+      chemical_components_data=(
+          struc_chem_comps.ChemicalComponentsData(chemical_components_data)
+          if chemical_components_data
+          else None
+      ),
+      chains=chains,
+      skip_validation=True,  # Already validated by table.concat_databases.
+  )
+
+
+def multichain_residue_index(
+    struc: Structure, chain_offset: int = 9000, between_chain_buffer: int = 1000
+) -> np.ndarray:
+  """Compute a residue index array that is monotonic across all chains.
+
+  Lots of metrics (lddt, l1_long, etc.) require computing a
+  distance-along-chain between two residues. For multimers we want to ensure
+  that any residues on different chains have a high along-chain distance
+  (i.e. they should always count as long-range contacts for example). To do
+  this we offset each successive chain's residue indices by
+  `chain_offset + between_chain_buffer` (10000 by default), and enforce that
+  the residue index is monotonically increasing across the whole complex.
+
+  Note: This returns the same as struc.res_id for monomers.
+
+  Args:
+    struc: The structure to make a multichain residue index for.
+    chain_offset: The start of each chain is offset by at least this amount.
+      This must be larger than the absolute range of standard residue IDs.
+    between_chain_buffer: The final residue in one chain will have at least this
+      much of a buffer before the first residue in the next chain.
+
+  Returns:
+    A monotonically increasing residue index, with at least
+    `between_chain_buffer` residues in between each chain.
+  """
+  if struc.num_atoms:
+    res_id_range = np.max(struc.res_id) - np.min(struc.res_id)
+    assert res_id_range < chain_offset
+  chain_id_int = struc.chain_id
+  monotonic_chain_id_int = np.concatenate(
+      ([0], np.cumsum(chain_id_int[1:] != chain_id_int[:-1]))
+  )
+  return struc.res_id + monotonic_chain_id_int * (
+      chain_offset + between_chain_buffer
+  )
+
+
+def make_empty_structure() -> Structure:
+  """Returns a new structure consisting of empty array fields."""
+  return Structure(
+      chains=structure_tables.Chains.make_empty(),
+      residues=structure_tables.Residues.make_empty(),
+      atoms=structure_tables.Atoms.make_empty(),
+      bonds=structure_tables.Bonds.make_empty(),
+  )
+
+
+def enumerate_residues(
+    atom_iter: Iterable[Mapping[str, Any]],
+    all_residues: AllResidues | None = None,
+) -> Iterator[tuple[int, Mapping[str, Any]]]:
+  """Provides a zero-indexed enumeration of residues in an atom iterable.
+
+  Args:
+    atom_iter: An iterable of atom dicts as returned by Structure.iter_atoms().
+    all_residues: (Optional) A structure's all_residues field. If present then
+      this will be used to count missing residues by adding appropriate gaps in
+      the residue enumeration.
+
+  Yields:
+    (res_i, atom) pairs where atom is the unmodified atom dict and res_i is a
+    zero-based index for the residue that the atom belongs to.
+  """
+  if all_residues is None:
+    prev_res = None
+    res_i = -1
+    for atom in atom_iter:
+      res = (atom['chain_id'], atom['res_id'])
+      if res != prev_res:
+        prev_res = res
+        res_i += 1
+      yield res_i, atom
+  else:
+    all_res_seq = []  # Sequence of (chain_id, res_id) for all chains.
+    prev_chain = None
+    res_i = 0
+    for atom in atom_iter:
+      chain_id = atom['chain_id']
+      if chain_id not in all_residues:
+        raise ValueError(
+            f'Atom {atom} does not belong to any residue in all_residues.'
+        )
+      if chain_id != prev_chain:
+        prev_chain = chain_id
+        all_res_seq.extend(
+            (chain_id, res_id) for (_, res_id) in all_residues[chain_id]
+        )
+      res = (chain_id, atom['res_id'])
+      while res_i < len(all_res_seq) and res != all_res_seq[res_i]:
+        res_i += 1
+      if res_i == len(all_res_seq):
+        raise ValueError(
+            f'Atom {atom} does not belong to a residue in all_residues.'
+ ) + yield res_i, atom diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure_tables.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure_tables.py new file mode 100644 index 000000000..690454c39 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/structure_tables.py @@ -0,0 +1,843 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md +# ============================================================================ + +"""Table implementations for the Structure class.""" + +import collections +from collections.abc import Mapping, Sequence +import dataclasses +import functools +import itertools +import typing +from typing_extensions import Any, ClassVar, Self + +from alphafold3.constants import mmcif_names +from alphafold3.constants import residue_names +from alphafold3.cpp import aggregation +from alphafold3.cpp import string_array +from alphafold3.structure import bonds as bonds_module +from alphafold3.structure import mmcif +from alphafold3.structure import table +import numpy as np + + +Bonds = bonds_module.Bonds + + +def _residue_name_to_record_name( + residue_name: np.ndarray, + polymer_mask: np.ndarray, +) -> np.ndarray: + """Returns record names (ATOM/HETATM) given residue names and polymer mask.""" + record_name = np.array(['HETATM'] * len(residue_name), dtype=object) + record_name[polymer_mask] = string_array.remap( + residue_name[polymer_mask], + mapping={r: 'ATOM' for r in residue_names.STANDARD_POLYMER_TYPES}, + default_value='HETATM', + ) + return record_name + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class AuthorNamingScheme: + """A mapping from internal values to author values in a mmCIF. + + Fields: + auth_asym_id: A mapping from label_asym_id to auth_asym_id. + auth_seq_id: A mapping from label_asym_id to a mapping from + label_seq_id to auth_seq_id. + insertion_code: A mapping from label_asym_id to a mapping from + label_seq_id to insertion codes. + entity_id: A mapping from label_asym_id to _entity.id. + entity_desc: A mapping from _entity.id to _entity.pdbx_description. 
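+
+  For example (illustrative values only), a naming scheme for a single protein
+  chain might look like:
+
+    AuthorNamingScheme(
+        auth_asym_id={'A': 'A'},
+        auth_seq_id={'A': {1: '10', 2: '11'}},
+        insertion_code={'A': {1: None, 2: 'A'}},
+        entity_id={'A': '1'},
+        entity_desc={'1': 'example protein'},
+    )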
+ """ + + auth_asym_id: Mapping[str, str] + auth_seq_id: Mapping[str, Mapping[int, str]] + insertion_code: Mapping[str, Mapping[int, str | None]] + entity_id: Mapping[str, str] + entity_desc: Mapping[str, str] + + +def _default( + candidate_value: np.ndarray | None, default_value: Sequence[Any], dtype: Any +) -> np.ndarray: + if candidate_value is None: + return np.array(default_value, dtype=dtype) + return np.array(candidate_value, dtype=dtype) + + +@dataclasses.dataclass(frozen=True, slots=True, kw_only=True) +class Atoms(table.Table): + """Table of atoms in a Structure.""" + + chain_key: np.ndarray + res_key: np.ndarray + name: np.ndarray + element: np.ndarray + x: np.ndarray + y: np.ndarray + z: np.ndarray + b_factor: np.ndarray + occupancy: np.ndarray + multimodel_cols: ClassVar[tuple[str, ...]] = ( + 'x', + 'y', + 'z', + 'b_factor', + 'occupancy', + ) + + def __post_init__(self): + # Validates that the atom coordinates, b-factors and occupancies are finite. + for column_name in ('x', 'y', 'z', 'b_factor', 'occupancy'): + column = self.get_column(column_name) + if not np.isfinite(column).all(): + raise ValueError( + f'Column {column_name} must not contain NaN/inf values.' + ) + # super().__post_init__() can't be used as that causes the following error: + # TypeError: super(type, obj): obj must be an instance or subtype of type + super(Atoms, self).__post_init__() + + @classmethod + def make_empty(cls) -> Self: + return cls( + key=np.array([], dtype=np.int64), + chain_key=np.array([], dtype=np.int64), + res_key=np.array([], dtype=np.int64), + name=np.array([], dtype=object), + element=np.array([], dtype=object), + x=np.array([], dtype=np.float32), + y=np.array([], dtype=np.float32), + z=np.array([], dtype=np.float32), + b_factor=np.array([], dtype=np.float32), + occupancy=np.array([], dtype=np.float32), + ) + + @classmethod + def from_defaults( + cls, + *, + chain_key: np.ndarray, + res_key: np.ndarray, + key: np.ndarray | None = None, + name: np.ndarray | None = None, + element: np.ndarray | None = None, + x: np.ndarray | None = None, + y: np.ndarray | None = None, + z: np.ndarray | None = None, + b_factor: np.ndarray | None = None, + occupancy: np.ndarray | None = None, + ) -> Self: + """Create an Atoms table with minimal user inputs.""" + num_atoms = len(chain_key) + if not num_atoms: + return cls.make_empty() + return Atoms( + chain_key=chain_key, + res_key=res_key, + key=_default(key, np.arange(num_atoms), np.int64), + name=_default(name, ['?'] * num_atoms, object), + element=_default(element, ['?'] * num_atoms, object), + x=_default(x, [0.0] * num_atoms, np.float32), + y=_default(y, [0.0] * num_atoms, np.float32), + z=_default(z, [0.0] * num_atoms, np.float32), + b_factor=_default(b_factor, [0.0] * num_atoms, np.float32), + occupancy=_default(occupancy, [1.0] * num_atoms, np.float32), + ) + + def get_value_by_index( + self, column_name: str, index: int + ) -> table.TableEntry | np.ndarray: + if column_name in self.multimodel_cols: + return self.get_column(column_name)[..., index] + else: + return self.get_column(column_name)[index] + + def copy_and_update_coords(self, coords: np.ndarray) -> Self: + """Returns a copy with the x, y and z columns updated.""" + if coords.shape[-1] != 3: + raise ValueError( + f'Expecting 3-dimensional coordinates, got {coords.shape}' + ) + return typing.cast( + Atoms, + self.copy_and_update( + x=coords[..., 0], y=coords[..., 1], z=coords[..., 2] + ), + ) + + @property + def shape(self) -> tuple[int, ...]: + return self.x.shape + + @property + def 
ndim(self) -> int:
+    return len(self.shape)
+
+  @functools.cached_property
+  def num_models(self) -> int:
+    """The number of models of this Structure."""
+    leading_dims = self.shape[:-1]
+    match leading_dims:
+      case ():
+        return 1
+      case (single_leading_dim_size,):
+        return single_leading_dim_size
+      case _:
+        raise ValueError(
+            'num_models not defined for atom tables with more than one '
+            'leading dimension.'
+        )
+
+
+@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
+class Residues(table.Table):
+  """Table of residues in a Structure."""
+
+  chain_key: np.ndarray
+  id: np.ndarray
+  name: np.ndarray
+  auth_seq_id: np.ndarray
+  insertion_code: np.ndarray
+
+  @classmethod
+  def make_empty(cls) -> Self:
+    return cls(
+        key=np.array([], dtype=np.int64),
+        chain_key=np.array([], dtype=np.int64),
+        id=np.array([], dtype=np.int32),
+        name=np.array([], dtype=object),
+        auth_seq_id=np.array([], dtype=object),
+        insertion_code=np.array([], dtype=object),
+    )
+
+  @classmethod
+  def from_defaults(
+      cls,
+      *,
+      id: np.ndarray,  # pylint:disable=redefined-builtin
+      chain_key: np.ndarray,
+      key: np.ndarray | None = None,
+      name: np.ndarray | None = None,
+      auth_seq_id: np.ndarray | None = None,
+      insertion_code: np.ndarray | None = None,
+  ) -> Self:
+    """Create a Residues table with minimal user inputs."""
+    num_res = len(id)
+    if not num_res:
+      return cls.make_empty()
+    return Residues(
+        key=_default(key, np.arange(num_res), np.int64),
+        id=id,
+        chain_key=chain_key,
+        name=_default(name, ['UNK'] * num_res, object),
+        auth_seq_id=_default(auth_seq_id, id.astype(str), object),
+        insertion_code=_default(insertion_code, ['?'] * num_res, object),
+    )
+
+
+@dataclasses.dataclass(frozen=True, slots=True, kw_only=True)
+class Chains(table.Table):
+  """Table of chains in a Structure."""
+
+  id: np.ndarray
+  type: np.ndarray
+  auth_asym_id: np.ndarray
+  entity_id: np.ndarray
+  entity_desc: np.ndarray
+
+  @classmethod
+  def make_empty(cls) -> Self:
+    return cls(
+        key=np.array([], dtype=np.int64),
+        id=np.array([], dtype=object),
+        type=np.array([], dtype=object),
+        auth_asym_id=np.array([], dtype=object),
+        entity_id=np.array([], dtype=object),
+        entity_desc=np.array([], dtype=object),
+    )
+
+  @classmethod
+  def from_defaults(
+      cls,
+      *,
+      id: np.ndarray,  # pylint:disable=redefined-builtin
+      key: np.ndarray | None = None,
+      type: np.ndarray | None = None,  # pylint:disable=redefined-builtin
+      auth_asym_id: np.ndarray | None = None,
+      entity_id: np.ndarray | None = None,
+      entity_desc: np.ndarray | None = None,
+  ) -> Self:
+    """Create a Chains table with minimal user inputs."""
+    num_chains = len(id)
+    if not num_chains:
+      return cls.make_empty()
+
+    return Chains(
+        key=_default(key, np.arange(num_chains), np.int64),
+        id=id,
+        type=_default(type, [mmcif_names.PROTEIN_CHAIN] * num_chains, object),
+        auth_asym_id=_default(auth_asym_id, id, object),
+        entity_id=_default(
+            entity_id, np.arange(1, num_chains + 1).astype(str), object
+        ),
+        entity_desc=_default(entity_desc, ['.'] * num_chains, object),
+    )
+
+
+def to_mmcif_sequence_and_entity_tables(
+    chains: Chains,
+    residues: Residues,
+    atom_res_key: np.ndarray,
+) -> Mapping[str, Sequence[str]]:
+  """Returns raw sequence and entity mmCIF tables."""
+  raw_mmcif = collections.defaultdict(list)
+  chains_by_entity_id = {}
+  written_entity_poly_seq_ids = set()
+  present_res_keys = set(atom_res_key)
+
+  # Performance optimisation: Find residue indices for each chain in advance,
+  # so that we don't have to do redundant masking work for each chain.
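+  # indices_grouped_by_value is assumed to return, for each distinct value in
+  # residues.chain_key, the array of residue-row indices holding that value, so
+  # each chain's residues can be looked up below with a single indexing step.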
+ res_indices_for_chain = aggregation.indices_grouped_by_value( + residues.chain_key + ) + + for chain in chains.iterrows(): + # Add all chain information to the _struct_asym table. + chain_id = chain['id'] # Saves multiple dict lookups. + auth_asym_id = chain['auth_asym_id'] + entity_id = chain['entity_id'] + chains_by_entity_id.setdefault(entity_id, []).append(chain) + raw_mmcif['_struct_asym.id'].append(chain_id) + raw_mmcif['_struct_asym.entity_id'].append(entity_id) + + res_chain_indices = res_indices_for_chain[chain['key']] + chain_type = chain['type'] + is_polymer = chain_type in mmcif_names.POLYMER_CHAIN_TYPES + is_water = chain_type == mmcif_names.WATER + is_branched = len( + res_chain_indices) > 1 and not is_polymer and not is_water + write_entity_poly_seq = entity_id not in written_entity_poly_seq_ids + + # Iterate over the individual masked residue table columns, as that doesn't + # create a copy (only a view), while residues[res_chain_indices] does. + for res_key, res_name, res_id, pdb_seq_num, res_ins_code in zip( + residues.key[res_chain_indices], + residues.name[res_chain_indices], + residues.id[res_chain_indices], + residues.auth_seq_id[res_chain_indices], + residues.insertion_code[res_chain_indices], + strict=True, + ): + is_missing = res_key not in present_res_keys + str_res_id = str(res_id) + # While atom_site uses "?" for insertion codes, scheme tables use ".". + ins_code = (res_ins_code or '.').replace('?', '.') + auth_seq_num = '?' if is_missing else pdb_seq_num + + if is_polymer: + raw_mmcif['_pdbx_poly_seq_scheme.asym_id'].append(chain_id) + raw_mmcif['_pdbx_poly_seq_scheme.entity_id'].append(entity_id) + raw_mmcif['_pdbx_poly_seq_scheme.seq_id'].append(str_res_id) + raw_mmcif['_pdbx_poly_seq_scheme.mon_id'].append(res_name) + raw_mmcif['_pdbx_poly_seq_scheme.pdb_seq_num'].append( + pdb_seq_num) + raw_mmcif['_pdbx_poly_seq_scheme.auth_seq_num'].append( + auth_seq_num) + raw_mmcif['_pdbx_poly_seq_scheme.pdb_strand_id'].append( + auth_asym_id) + raw_mmcif['_pdbx_poly_seq_scheme.pdb_ins_code'].append( + ins_code) + # Structure doesn't support heterogeneous sequences. + raw_mmcif['_pdbx_poly_seq_scheme.hetero'].append('n') + if write_entity_poly_seq: + raw_mmcif['_entity_poly_seq.entity_id'].append(entity_id) + raw_mmcif['_entity_poly_seq.num'].append(str_res_id) + raw_mmcif['_entity_poly_seq.mon_id'].append(res_name) + # Structure doesn't support heterogeneous sequences. + raw_mmcif['_entity_poly_seq.hetero'].append('n') + written_entity_poly_seq_ids.add(entity_id) + elif is_branched: + raw_mmcif['_pdbx_branch_scheme.asym_id'].append(chain_id) + raw_mmcif['_pdbx_branch_scheme.entity_id'].append(entity_id) + raw_mmcif['_pdbx_branch_scheme.mon_id'].append(res_name) + raw_mmcif['_pdbx_branch_scheme.num'].append(str_res_id) + raw_mmcif['_pdbx_branch_scheme.pdb_asym_id'].append( + auth_asym_id) + raw_mmcif['_pdbx_branch_scheme.pdb_seq_num'].append( + pdb_seq_num) + raw_mmcif['_pdbx_branch_scheme.auth_asym_id'].append( + auth_asym_id) + raw_mmcif['_pdbx_branch_scheme.auth_seq_num'].append( + auth_seq_num) + raw_mmcif['_pdbx_branch_scheme.pdb_ins_code'].append(ins_code) + # Structure doesn't support heterogeneous sequences. 
+ raw_mmcif['_pdbx_branch_scheme.hetero'].append('n') + else: + raw_mmcif['_pdbx_nonpoly_scheme.asym_id'].append(chain_id) + raw_mmcif['_pdbx_nonpoly_scheme.entity_id'].append(entity_id) + raw_mmcif['_pdbx_nonpoly_scheme.mon_id'].append(res_name) + raw_mmcif['_pdbx_nonpoly_scheme.pdb_seq_num'].append( + pdb_seq_num) + raw_mmcif['_pdbx_nonpoly_scheme.auth_seq_num'].append( + auth_seq_num) + raw_mmcif['_pdbx_nonpoly_scheme.pdb_strand_id'].append( + auth_asym_id) + raw_mmcif['_pdbx_nonpoly_scheme.pdb_ins_code'].append(ins_code) + + # Add _entity and _entity_poly tables. + for entity_id, chains in chains_by_entity_id.items(): + # chains should always be a non-empty list because of how we constructed + # chains_by_entity_id. + assert chains + # All chains for a given entity should have the same type and sequence + # so we can pick the first one without losing information. + key_chain = chains[0] + raw_mmcif['_entity.id'].append(entity_id) + raw_mmcif['_entity.pdbx_description'].append(key_chain['entity_desc']) + entity_type = key_chain['type'] + if entity_type not in mmcif_names.POLYMER_CHAIN_TYPES: + raw_mmcif['_entity.type'].append(entity_type) + else: + raw_mmcif['_entity.type'].append('polymer') + raw_mmcif['_entity_poly.entity_id'].append(entity_id) + raw_mmcif['_entity_poly.type'].append(entity_type) + + # _entity_poly.pdbx_strand_id is a comma-separated list of + # auth_asym_ids that are part of the entity. + raw_mmcif['_entity_poly.pdbx_strand_id'].append( + ','.join(chain['auth_asym_id'] for chain in chains) + ) + return raw_mmcif + + +def to_mmcif_atom_site_and_bonds_table( + *, + chains: Chains, + residues: Residues, + atoms: Atoms, + bonds: Bonds, + coords_decimal_places: int, +) -> Mapping[str, Sequence[str]]: + """Returns raw _atom_site and _struct_conn mmCIF tables.""" + raw_mmcif = collections.defaultdict(list) + # Use [value] * num wherever possible since it is about 10x faster than list + # comprehension in such cases. Also use f-strings instead of str() - faster. + total_atoms = atoms.size * atoms.num_models + raw_mmcif['_atom_site.id'] = [f'{i}' for i in range(1, total_atoms + 1)] + raw_mmcif['_atom_site.label_alt_id'] = ['.'] * total_atoms + # Use format_float_array instead of list comprehension for performance. + raw_mmcif['_atom_site.Cartn_x'] = mmcif.format_float_array( + values=atoms.x.ravel(), num_decimal_places=coords_decimal_places + ) + raw_mmcif['_atom_site.Cartn_y'] = mmcif.format_float_array( + values=atoms.y.ravel(), num_decimal_places=coords_decimal_places + ) + raw_mmcif['_atom_site.Cartn_z'] = mmcif.format_float_array( + values=atoms.z.ravel(), num_decimal_places=coords_decimal_places + ) + + # atoms.b_factor or atoms.occupancy can be flat even when the coordinates have + # leading dimensions. In this case we tile it to match. 
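+  # E.g. coordinates of shape (num_models, num_atoms) with a flat (num_atoms,)
+  # b_factor are tiled to a (num_models * num_atoms,) column, matching the
+  # ravelled coordinate columns above.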
+  if atoms.b_factor.ndim == 1:
+    atom_b_factor = np.tile(atoms.b_factor, atoms.num_models)
+  else:
+    atom_b_factor = atoms.b_factor.ravel()
+  raw_mmcif['_atom_site.B_iso_or_equiv'] = mmcif.format_float_array(
+      values=atom_b_factor, num_decimal_places=2
+  )
+
+  if atoms.occupancy.ndim == 1:
+    atom_occupancy = np.tile(atoms.occupancy, atoms.num_models)
+  else:
+    atom_occupancy = atoms.occupancy.ravel()
+  raw_mmcif['_atom_site.occupancy'] = mmcif.format_float_array(
+      values=atom_occupancy.ravel(), num_decimal_places=2
+  )
+
+  label_atom_id = atoms.name
+  type_symbol = atoms.element
+  label_comp_id = residues.apply_array_to_column('name', atoms.res_key)
+  label_asym_id = chains.apply_array_to_column('id', atoms.chain_key)
+  label_entity_id = chains.apply_array_to_column('entity_id', atoms.chain_key)
+  # Performance optimisation: Do the int->str conversion on the
+  # num_residue-sized array, then select, instead of selecting and then
+  # converting.
+  label_seq_id = residues.id.astype('str').astype(object)[
+      ..., residues.index_by_key[atoms.res_key]
+  ]
+
+  # _atom_site.label_seq_id is '.' for non-polymers.
+  non_polymer_chain_mask = string_array.isin(
+      chains.type, mmcif_names.POLYMER_CHAIN_TYPES, invert=True
+  )
+  non_polymer_chain_keys = chains.key[non_polymer_chain_mask]
+  non_polymer_atom_mask = np.isin(atoms.chain_key, non_polymer_chain_keys)
+  label_seq_id[non_polymer_atom_mask] = '.'
+
+  auth_asym_id = chains.apply_array_to_column('auth_asym_id', atoms.chain_key)
+  auth_seq_id = residues.apply_array_to_column('auth_seq_id', atoms.res_key)
+  pdbx_pdb_ins_code = residues.apply_array_to_column(
+      'insertion_code', atoms.res_key
+  )
+  string_array.remap(pdbx_pdb_ins_code, mapping={None: '?'}, inplace=True)
+
+  group_pdb = _residue_name_to_record_name(
+      residue_name=label_comp_id, polymer_mask=~non_polymer_atom_mask
+  )
+
+  def tile_for_models(arr: np.ndarray) -> list[str]:
+    if atoms.num_models == 1:
+      # Memory optimisation: np.tile(arr, 1) does a copy.
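+      # With a single model the per-atom array is already the right length, so
+      # convert it directly.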
+      return arr.tolist()
+    return np.tile(arr, atoms.num_models).tolist()
+
+  raw_mmcif['_atom_site.group_PDB'] = tile_for_models(group_pdb)
+  raw_mmcif['_atom_site.label_atom_id'] = tile_for_models(label_atom_id)
+  raw_mmcif['_atom_site.type_symbol'] = tile_for_models(type_symbol)
+  raw_mmcif['_atom_site.label_comp_id'] = tile_for_models(label_comp_id)
+  raw_mmcif['_atom_site.label_asym_id'] = tile_for_models(label_asym_id)
+  raw_mmcif['_atom_site.label_entity_id'] = tile_for_models(label_entity_id)
+  raw_mmcif['_atom_site.label_seq_id'] = tile_for_models(label_seq_id)
+  raw_mmcif['_atom_site.auth_asym_id'] = tile_for_models(auth_asym_id)
+  raw_mmcif['_atom_site.auth_seq_id'] = tile_for_models(auth_seq_id)
+  raw_mmcif['_atom_site.pdbx_PDB_ins_code'] = tile_for_models(
+      pdbx_pdb_ins_code)
+  model_id = np.array(
+      [str(i + 1) for i in range(atoms.num_models)], dtype=object
+  )
+  raw_mmcif['_atom_site.pdbx_PDB_model_num'] = np.repeat(
+      model_id, [atoms.size] * atoms.num_models
+  ).tolist()
+
+  if bonds.key.size > 0:
+    raw_mmcif.update(
+        bonds.to_mmcif_dict_from_atom_arrays(
+            atom_key=atoms.key,
+            chain_id=label_asym_id,
+            res_id=label_seq_id,
+            res_name=label_comp_id,
+            atom_name=label_atom_id,
+            auth_asym_id=auth_asym_id,
+            auth_seq_id=auth_seq_id,
+            insertion_code=np.array(pdbx_pdb_ins_code),
+        )
+    )
+  return raw_mmcif
+
+
+def _flatten_author_naming_scheme_table(
+    res_table: Mapping[str, Mapping[int, str]],
+    chain_ids: np.ndarray,
+    res_chain_ids: np.ndarray,
+    res_ids: np.ndarray,
+    default_if_missing: str,
+    table_name: str,
+) -> np.ndarray:
+  """Flattens an author naming scheme table consistently with res_ids."""
+  if not set(chain_ids).issubset(res_table):
+    raise ValueError(
+        f'Chain IDs in the chain_id array must be a subset of {table_name} in '
+        'author naming scheme:\n'
+        f'chain_ids: {sorted(chain_ids)}\n'
+        f'{table_name} keys: {sorted(res_table.keys())}'
+    )
+
+  chain_change_mask = res_chain_ids[1:] != res_chain_ids[:-1]
+  res_chain_boundaries = np.concatenate(
+      ([0], np.where(chain_change_mask)[0] + 1, [len(res_chain_ids)])
+  )
+
+  flat_vals = np.empty(len(res_ids), dtype=object)
+  for chain_start, chain_end in itertools.pairwise(res_chain_boundaries):
+    chain_id = res_chain_ids[chain_start]
+    chain_res_ids = res_ids[chain_start:chain_end]
+    chain_mapping = res_table[chain_id]
+    flat_vals[chain_start:chain_end] = [
+        chain_mapping.get(r, default_if_missing) for r in chain_res_ids
+    ]
+
+  return flat_vals
+
+
+def tables_from_atom_arrays(
+    *,
+    res_id: np.ndarray,
+    author_naming_scheme: AuthorNamingScheme | None = None,
+    all_residues: Mapping[str, Sequence[tuple[str, int]]] | None = None,
+    chain_id: np.ndarray | None = None,
+    chain_type: np.ndarray | None = None,
+    res_name: np.ndarray | None = None,
+    atom_key: np.ndarray | None = None,
+    atom_name: np.ndarray | None = None,
+    atom_element: np.ndarray | None = None,
+    atom_x: np.ndarray | None = None,
+    atom_y: np.ndarray | None = None,
+    atom_z: np.ndarray | None = None,
+    atom_b_factor: np.ndarray | None = None,
+    atom_occupancy: np.ndarray | None = None,
+) -> tuple[Atoms, Residues, Chains]:
+  """Returns Structure tables constructed from atom array level data.
+
+  All fields except res_id are optional. All array fields hold one value per
+  atom in the structure, so residue- and chain-level fields should repeat the
+  same value for every atom in that residue or chain. Fields which are not
+  defined are filled with default values.
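+
+  A minimal illustrative call (hypothetical values; only res_id is required):
+
+    atoms, residues, chains = tables_from_atom_arrays(
+        res_id=np.array([1, 1, 2], dtype=np.int32),
+        chain_id=np.array(['A', 'A', 'A'], dtype=object),
+        res_name=np.array(['GLY', 'GLY', 'ALA'], dtype=object),
+        atom_name=np.array(['N', 'CA', 'N'], dtype=object),
+    )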
+
+  Validation is performed by the Structure constructor where possible - but
+  author_naming_scheme and all_residues must be checked in this function.
+
+  It is not possible to construct structures with chains that do not contain
+  any resolved residues using this function. If this is necessary, use the
+  structure.Structure constructor directly.
+
+  Args:
+    res_id: Integer array of shape [num_atom]. The unique residue identifier
+      for each residue. mmCIF field - _atom_site.label_seq_id.
+    author_naming_scheme: An optional instance of AuthorNamingScheme to use
+      when converting this structure to mmCIF.
+    all_residues: An optional mapping from each chain ID (i.e. label_asym_id)
+      to a sequence of (label_comp_id, label_seq_id) tuples, one per residue.
+      This can contain residues that aren't present in the atom arrays. This is
+      common in experimental data where some residues are not resolved but are
+      known to be present.
+    chain_id: String array of shape [num_atom] of unique chain identifiers.
+      mmCIF field - _atom_site.label_asym_id.
+    chain_type: String array of shape [num_atom]. The molecular type of the
+      current chain (e.g. polyribonucleotide). mmCIF field - _entity_poly.type
+      OR _entity.type (for non-polymers).
+    res_name: String array of shape [num_atom]. The name of each residue,
+      typically a 3 letter string for polypeptides or 1-2 letter strings for
+      polynucleotides. mmCIF field - _atom_site.label_comp_id.
+    atom_key: A unique sorted integer array, used only by the bonds table to
+      identify the atoms participating in each bond. If the bonds table is
+      specified then this column must be non-None.
+    atom_name: String array of shape [num_atom]. The name of each atom (e.g CA,
+      O2', etc.). mmCIF field - _atom_site.label_atom_id.
+    atom_element: String array of shape [num_atom]. The element type of each
+      atom (e.g. C, O, N, etc.). mmCIF field - _atom_site.type_symbol.
+    atom_x: Float array of shape [..., num_atom] of atom x coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_y: Float array of shape [..., num_atom] of atom y coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_z: Float array of shape [..., num_atom] of atom z coordinates. May
+      have arbitrary leading dimensions, provided that these are consistent
+      across all coordinate fields.
+    atom_b_factor: Float array of shape [..., num_atom] or [num_atom] of atom
+      b-factors or equivalent. If there are no extra leading dimensions then
+      these values are assumed to apply to all coordinates for a given atom. If
+      there are leading dimensions then these must match those used by the
+      coordinate fields.
+    atom_occupancy: Float array of shape [..., num_atom] or [num_atom] of atom
+      occupancies or equivalent. If there are no extra leading dimensions then
+      these values are assumed to apply to all coordinates for a given atom. If
+      there are leading dimensions then these must match those used by the
+      coordinate fields.
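+
+  Returns:
+    An (atoms, residues, chains) tuple of Structure tables describing the same
+    data as the input atom arrays.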
+ """ + num_atoms = len(res_id) + + for arr_name, array, dtype in ( + ('chain_id', chain_id, object), + ('chain_type', chain_type, object), + ('res_id', res_id, np.int32), + ('res_name', res_name, object), + ('atom_key', atom_key, np.int64), + ('atom_name', atom_name, object), + ('atom_element', atom_element, object), + ): + if array is not None and array.shape != (num_atoms,): + raise ValueError( + f'{arr_name} shape {array.shape} != ({num_atoms},)') + if array is not None and array.dtype != dtype: + raise ValueError(f'{arr_name} dtype {array.dtype} != {dtype}') + + for arr_name, array in ( + ('atom_x', atom_x), + ('atom_y', atom_y), + ('atom_z', atom_z), + ('atom_b_factor', atom_b_factor), + ('atom_occupancy', atom_occupancy), + ): + if array is not None and array.shape[-1] != num_atoms: + raise ValueError( + f'{arr_name} last dim {array.shape[-1]} != {num_atoms=}') + if ( + array is not None + and array.dtype != np.float32 + and array.dtype != np.float64 + ): + raise ValueError( + f'{arr_name} must be np.float32 or np.float64, got {array.dtype=}' + ) + + if all_residues is not None and (res_name is None or res_id is None): + raise ValueError( + 'If all_residues != None, res_name and res_id must not be None either.' + ) + + if num_atoms == 0: + return Atoms.make_empty(), Residues.make_empty(), Chains.make_empty() + + if chain_id is None: + chain_id = np.full(shape=num_atoms, fill_value='A', dtype=object) + if res_name is None: + res_name = np.full(shape=num_atoms, fill_value='UNK', dtype=object) + + chain_change_mask = chain_id[1:] != chain_id[:-1] + chain_start = np.concatenate(([0], np.where(chain_change_mask)[0] + 1)) + res_start = np.concatenate( + ([0], np.where((res_id[1:] != res_id[:-1]) | chain_change_mask)[0] + 1) + ) + + if len(set(chain_id)) != len(chain_start): + raise ValueError(f'Chain IDs must be contiguous, but got {chain_id}') + + # We do not support chains with unresolved residues-only in this function. + chain_ids = chain_id[chain_start] + if all_residues and set(all_residues.keys()) != set(chain_ids): + raise ValueError( + 'all_residues must contain the same set of chain IDs as the chain_id ' + f'array:\nall_residues keys: {sorted(all_residues.keys())}\n' + f'chain_ids: {sorted(chain_ids)}.' + ) + # Make sure all_residue ordering is consistent with chain_id. + if all_residues and np.any(list(all_residues.keys()) != chain_ids): + all_residues = {cid: all_residues[cid] for cid in chain_ids} + + # Create the chains table. + num_chains = len(chain_ids) + chain_keys = np.arange(num_chains, dtype=np.int64) + chain_key_by_chain_id = dict(zip(chain_ids, chain_keys, strict=True)) + + if chain_type is not None: + chain_types = chain_type[chain_start] + else: + chain_types = np.full( + num_chains, mmcif_names.PROTEIN_CHAIN, dtype=object) + + if author_naming_scheme is not None: + auth_asym_id = string_array.remap( + chain_ids, author_naming_scheme.auth_asym_id + ) + entity_id = string_array.remap( + chain_ids, author_naming_scheme.entity_id, default_value='.' + ) + entity_desc = string_array.remap( + entity_id, author_naming_scheme.entity_desc, default_value='.' + ) + else: + auth_asym_id = chain_ids + entity_id = (chain_keys + 1).astype(str).astype(object) + entity_desc = np.full(num_chains, '.', dtype=object) + + chains = Chains( + key=chain_keys, + id=chain_ids, + type=chain_types, + auth_asym_id=auth_asym_id, + entity_id=entity_id, + entity_desc=entity_desc, + ) + + # Create the residues table. 
+  if all_residues is not None:
+    residue_order = []
+    for cid, residues in all_residues.items():
+      residue_order.extend((cid, rname, int(rid)) for (rname, rid) in residues)
+    res_chain_ids, res_names, res_ids = zip(*residue_order)
+    res_chain_ids = np.array(res_chain_ids, dtype=object)
+    res_ids = np.array(res_ids, dtype=np.int32)
+    res_names = np.array(res_names, dtype=object)
+  else:
+    res_chain_ids = chain_id[res_start]
+    res_ids = res_id[res_start]
+    res_names = res_name[res_start]
+    residue_order = list(zip(res_chain_ids, res_names, res_ids))
+
+  if author_naming_scheme is not None and author_naming_scheme.auth_seq_id:
+    auth_seq_id = _flatten_author_naming_scheme_table(
+        author_naming_scheme.auth_seq_id,
+        chain_ids=chain_ids,
+        res_chain_ids=res_chain_ids,
+        res_ids=res_ids,
+        default_if_missing='.',
+        table_name='auth_seq_id',
+    )
+  else:
+    auth_seq_id = res_ids.astype(str).astype(object)
+
+  if author_naming_scheme is not None and author_naming_scheme.insertion_code:
+    insertion_code = _flatten_author_naming_scheme_table(
+        author_naming_scheme.insertion_code,
+        chain_ids=chain_ids,
+        res_chain_ids=res_chain_ids,
+        res_ids=res_ids,
+        default_if_missing='?',
+        table_name='insertion_code',
+    )
+    # Make sure an insertion code of None is mapped to '?'.
+    insertion_code = string_array.remap(insertion_code, {None: '?'})
+  else:
+    insertion_code = np.full(shape=len(res_ids), fill_value='?', dtype=object)
+
+  res_key_by_res = {res: i for i, res in enumerate(residue_order)}
+  res_keys = np.arange(len(residue_order), dtype=np.int64)
+  res_chain_keys = string_array.remap(
+      res_chain_ids, chain_key_by_chain_id
+  ).astype(np.int64)
+  residues = Residues(
+      chain_key=res_chain_keys,
+      key=res_keys,
+      id=res_ids,
+      name=res_names,
+      auth_seq_id=auth_seq_id,
+      insertion_code=insertion_code,
+  )
+
+  if atom_key is None:
+    atom_key = np.arange(num_atoms, dtype=np.int64)
+
+  atom_chain_keys = string_array.remap(chain_id, chain_key_by_chain_id).astype(
+      np.int64
+  )
+
+  try:
+    atom_res_keys = [res_key_by_res[r] for r in zip(chain_id, res_name, res_id)]
+  except KeyError as e:
+    missing_chain_id, missing_res_name, missing_res_id = e.args[0]
+    raise ValueError(
+        'Inconsistent res_name, res_id and all_residues. Could not find '
+        f'residue with chain_id={missing_chain_id}, '
+        f'res_name={missing_res_name}, res_id={missing_res_id} in '
+        'all_residues.'
+    ) from e
+
+  atoms = Atoms(
+      key=atom_key,
+      chain_key=atom_chain_keys,
+      res_key=np.array(atom_res_keys, dtype=np.int64),
+      name=_default(atom_name, ['?'] * num_atoms, object),
+      element=_default(atom_element, ['?'] * num_atoms, object),
+      x=_default(atom_x, [0.0] * num_atoms, np.float32),
+      y=_default(atom_y, [0.0] * num_atoms, np.float32),
+      z=_default(atom_z, [0.0] * num_atoms, np.float32),
+      b_factor=_default(atom_b_factor, [0.0] * num_atoms, np.float32),
+      occupancy=_default(atom_occupancy, [1.0] * num_atoms, np.float32),
+  )
+  return atoms, residues, chains
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/table.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/table.py
new file mode 100644
index 000000000..7cad4a27d
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/table.py
@@ -0,0 +1,565 @@
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Table module for atom/residue/chain tables in Structure.
+
+Tables are intended to be lightweight collections of columns, loosely based
+on a pandas dataframe, for use in the Structure class.
+"""
+
+import abc
+from collections.abc import Callable, Collection, Iterable, Iterator, Mapping, Sequence
+import dataclasses
+import functools
+import graphlib
+import typing
+from typing_extensions import Any, Protocol, Self, TypeAlias, TypeVar, overload
+
+from alphafold3.cpp import string_array
+import numpy as np
+
+
+TableEntry: TypeAlias = str | int | float | None
+FilterPredicate: TypeAlias = (
+    TableEntry
+    | Iterable[Any]  # Workaround for b/326384670. Tighten once fixed.
+    | Callable[[Any], bool]  # Workaround for b/326384670. Tighten once fixed.
+    | Callable[[np.ndarray], bool]
+)
+
+
+class RowLookup(Protocol):
+
+  def get_row_by_key(
+      self,
+      key: int,
+      column_name_map: Mapping[str, str] | None = None,
+  ) -> Mapping[str, Any]:
+    ...
+
+
+@dataclasses.dataclass(frozen=True, kw_only=True)
+class Table:
+  """Parent class for structure tables.
+
+  A table is a collection of columns of equal length, where one column is the
+  key. The key uniquely identifies each row in the table.
+
+  A table can refer to other tables by including a foreign key column, whose
+  values are key values from the other table's key column. These columns can
+  have arbitrary names and are treated like any other integer-valued column.
+
+  See the `Database` class in this module for utilities for handling sets of
+  tables that are related via foreign keys.
+
+  NB: This does not correspond to an mmCIF table.
+  """
+
+  key: np.ndarray
+
+  def __post_init__(self):
+    for col_name in self.columns:
+      if (col_len := self.get_column(col_name).shape[-1]) != self.size:
+        raise ValueError(
+            f'All columns should have length {self.size} but got "{col_name}"'
+            f' with length {col_len}.'
+        )
+      # Make col immutable.
+      self.get_column(col_name).flags.writeable = False
+    if self.key.size and self.key.min() < 0:
+      raise ValueError(
+          'Key values must be non-negative. Got negative values:'
+          f' {set(self.key[self.key < 0])}'
+      )
+    self.key.flags.writeable = False  # Make key immutable.
+
+  def __getstate__(self) -> dict[str, Any]:
+    """Returns members with cached properties removed for pickling."""
+    cached_props = {
+        k
+        for k, v in self.__class__.__dict__.items()
+        if isinstance(v, functools.cached_property)
+    }
+    return {k: v for k, v in self.__dict__.items() if k not in cached_props}
+
+  @functools.cached_property
+  def index_by_key(self) -> np.ndarray:
+    """Mapping from key values to their index in the column arrays.
+ + i.e.: self.key[index_by_key[k]] == k + """ + if not self.key.size: + return np.array([], dtype=np.int64) + else: + index_by_key = np.zeros(np.max(self.key) + 1, dtype=np.int64) + index_by_key[self.key] = np.arange(self.size) + return index_by_key + + @functools.cached_property + def columns(self) -> tuple[str, ...]: + """The names of the columns in the table, including the key column.""" + return tuple(field.name for field in dataclasses.fields(self)) + + @functools.cached_property + def items(self) -> Mapping[str, np.ndarray]: + """Returns the mapping from column names to column values.""" + return {col: getattr(self, col) for col in self.columns} + + @functools.cached_property + def size(self) -> int: + """The number of rows in the table.""" + return self.key.shape[-1] + + def __len__(self) -> int: + return self.size + + def get_column(self, column_name: str) -> np.ndarray: + """Gets a column by name.""" + # Performance optimisation: use the cached columns, instead of getattr. + return self.items[column_name] + + def apply_array(self, arr: np.ndarray) -> Self: + """Returns a sliced table using a key (!= index) array or a boolean mask.""" + if arr.dtype == bool and np.all(arr): + return self # Shortcut: No-op, so just return. + + return self.copy_and_update(**{ + column_name: self.apply_array_to_column(column_name, arr) + for column_name in self.columns + }) + + def apply_index(self, index_arr: np.ndarray) -> Self: + """Returns a sliced table using an index (!= key) array.""" + if index_arr.dtype == bool: + raise ValueError('The index array must not be a boolean mask.') + + return self.copy_and_update( + **{col: self.get_column(col)[..., index_arr] for col in self.columns} + ) + + def apply_array_to_column( + self, + column_name: str, + arr: np.ndarray, + ) -> np.ndarray: + """Returns a sliced column array using a key array or a boolean mask.""" + if arr.dtype == bool: + return self.get_column(column_name)[..., arr] + else: + return self.get_column(column_name)[..., self.index_by_key[arr]] + + def get_value_by_index(self, column_name: str, index: int) -> Any: + return self.get_column(column_name)[index] + + def get_value_by_key( + self, + column_name: str, + key: int | np.integer, + ) -> TableEntry: + """Gets the value of a column at the row with specified key value.""" + return self.get_value_by_index(column_name, self.index_by_key[key]) + + @overload + def __getitem__(self, key: str) -> np.ndarray: + ... + + @overload + def __getitem__(self, key: np.ndarray) -> 'Table': + ... + + @overload + def __getitem__(self, key: tuple[str, int | np.integer]) -> TableEntry: + ... + + @overload + def __getitem__(self, key: tuple[str, np.ndarray]) -> np.ndarray: + ... 
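+  # The overloads above correspond to four access patterns, e.g. (illustrative):
+  #   table['name']             -> column array
+  #   table[mask_or_key_array]  -> new sliced Table
+  #   table['name', key]        -> single column entry
+  #   table['name', key_array]  -> sliced column array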
+ + def __getitem__(self, key): + match key: + case str(): + return self.get_column(key) + case np.ndarray() as key_arr_or_mask: + return self.apply_array(key_arr_or_mask) + case str() as col, int() | np.integer() as key_val: + return self.get_value_by_key(col, key_val) + case str() as col, np.ndarray() as key_arr_or_mask: + return self.apply_array_to_column(col, key_arr_or_mask) + case _: + if isinstance(key, tuple): + err_msg = f'{key}, type: tuple({[type(v) for v in key]})' + else: + err_msg = f'{key}, type: {type(key)}' + raise KeyError(err_msg) + + def get_row_by_key( + self, + key: int, + column_name_map: Mapping[str, str] | None = None, + ) -> dict[str, Any]: + """Gets the row with specified key value.""" + return self.get_row_by_index( + self.index_by_key[key], column_name_map=column_name_map + ) + + def get_row_by_index( + self, + index: int, + column_name_map: Mapping[str, str] | None = None, + ) -> dict[str, Any]: + """Gets the row at the specified index.""" + if column_name_map is not None: + return { + renamed_col: self.get_value_by_index(col, index) + for renamed_col, col in column_name_map.items() + } + else: + return {col: self.get_value_by_index(col, index) for col in self.columns} + + def iterrows( + self, + *, + row_keys: np.ndarray | None = None, + column_name_map: Mapping[str, str] | None = None, + **table_by_foreign_key_col: RowLookup, + ) -> Iterator[Mapping[str, Any]]: + """Yields rows from the table. + + Args: + row_keys: An optional array of keys of rows to yield. If None, all rows + will be yielded. + column_name_map: An optional mapping from desired keys in the row dicts to + the names of the columns they correspond to. + **table_by_foreign_key_col: An optional mapping from column names in this + table, which are expected to be columns of foreign keys, to the table + that the foreign keys point into. If provided, then the yielded rows + will include data from the foreign tables at the appropriate key. + """ + if row_keys is not None: + row_indices = self.index_by_key[row_keys] + else: + row_indices = range(self.size) + for i in row_indices: + row = self.get_row_by_index(i, column_name_map=column_name_map) + for key_col, table in table_by_foreign_key_col.items(): + foreign_key = self[key_col][i] + foreign_row = table.get_row_by_key(foreign_key) + row.update(foreign_row) + yield row + + def with_column_names( + self, column_name_map: Mapping[str, str] + ) -> 'RenamedTableView': + """Returns a view of this table with mapped column names.""" + return RenamedTableView(self, column_name_map=column_name_map) + + def make_filter_mask( + self, + mask: np.ndarray | None = None, + *, + apply_per_element: bool = False, + **predicate_by_col: FilterPredicate, + ) -> np.ndarray | None: + """Returns a boolean array of rows to keep, or None if all can be kept. + + Args: + mask: See `Table.filter`. + apply_per_element: See `Table.filter`. + **predicate_by_col: See `Table.filter`. + + Returns: + Either a boolean NumPy array of length `(self.size,)` denoting which rows + should be kept according to the input mask and predicates, or None. None + implies there is no filtering required, and is used where possible + instead of an all-True array to save time and space. + """ + if mask is None: + if not predicate_by_col: + return None + else: + mask = np.ones((self.size,), dtype=bool) + else: + if mask.shape != (self.size,): + raise ValueError( + f'mask must have shape ({self.size},). Got: {mask.shape}.' + ) + if mask.dtype != bool: + raise ValueError( + f'mask must have dtype bool. 
Got: {mask.dtype}.')
+
+    for col, predicate in predicate_by_col.items():
+      if self[col].ndim > 1:
+        raise ValueError(
+            f'Cannot filter by column {col} with more than 1 dimension.'
+        )
+
+      callable_predicates = []
+      if not callable(predicate):
+        if isinstance(predicate, Iterable) and not isinstance(predicate, str):
+          target_vals = predicate
+        else:
+          target_vals = [predicate]
+        for target_val in target_vals:
+          callable_predicates.append(lambda x, target=target_val: x == target)
+      else:
+        callable_predicates.append(predicate)
+
+      field_mask = np.zeros_like(mask)
+      for callable_predicate in callable_predicates:
+        if not apply_per_element:
+          callable_predicate = typing.cast(
+              Callable[[np.ndarray], bool], callable_predicate
+          )
+          predicate_result = callable_predicate(self.get_column(col))
+        else:
+          predicate_result = np.array(
+              [callable_predicate(elem) for elem in self.get_column(col)]
+          )
+        np.logical_or(field_mask, predicate_result, out=field_mask)
+      np.logical_and(mask, field_mask, out=mask)  # Update in-place.
+    return mask
+
+  def filter(
+      self,
+      mask: np.ndarray | None = None,
+      *,
+      apply_per_element: bool = False,
+      invert: bool = False,
+      **predicate_by_col: FilterPredicate,
+  ) -> Self:
+    """Filters the table using mask and/or predicates and returns a new table.
+
+    Predicates can be either:
+    1. A constant value, e.g. `'CA'`. In this case only rows that match this
+       value for the given column are retained.
+    2. A (non-string) iterable, e.g. `('A', 'B')`. In this case rows are
+       retained if they match any of the provided values for the given column.
+    3. A boolean function, e.g. `lambda b_fac: b_fac < 100.0`. In this case
+       only rows that evaluate to `True` are retained. By default this
+       function's parameter is expected to be an array, unless
+       `apply_per_element=True`.
+
+    Args:
+      mask: An optional boolean NumPy array with length equal to the table
+        size. If provided then this will be combined with the other predicates
+        so that a row is included if it is masked-in *and* matches all the
+        predicates.
+      apply_per_element: Whether to apply predicates to each element in the
+        column individually, or to pass the whole column array to the
+        predicate.
+      invert: If True then the returned table will contain exactly those rows
+        that would be removed if this was `False`.
+      **predicate_by_col: A mapping from column name to a predicate. Filtered
+        columns must be 1D arrays. If multiple columns are provided as keyword
+        arguments then each predicate is applied and the results are combined
+        using a boolean AND operation, so an atom is only retained if it passes
+        all predicates.
+
+    Returns:
+      A new table with the desired rows retained (or filtered out if
+      `invert=True`).
+
+    Raises:
+      ValueError: If mask is provided and is not a bool array with shape
+        `(num_atoms,)`.
+    """
+    filter_mask = self.make_filter_mask(
+        mask, apply_per_element=apply_per_element, **predicate_by_col
+    )
+    if filter_mask is None:
+      # No mask or predicate was specified, so we can return early.
+ if not invert: + return self + else: + return self[np.array((), dtype=np.int64)] + else: + return self[~filter_mask if invert else filter_mask] + + def _validate_keys_are_column_names(self, keys: Collection[str]) -> None: + """Raises an error if any of the keys are not column names.""" + if mismatches := set(keys) - set(self.columns): + raise ValueError(f'Invalid column names: {sorted(mismatches)}.') + + def copy_and_update(self, **new_column_by_column_name: np.ndarray) -> Self: + """Returns a copy of this table with the specified changes applied. + + Args: + **new_column_by_column_name: New values for the specified columns. + + Raises: + ValueError: If a specified column name is not a column in this table. + """ + self._validate_keys_are_column_names(new_column_by_column_name) + return dataclasses.replace(self, **new_column_by_column_name) + + def copy_and_remap( + self, **mapping_by_col: Mapping[TableEntry, TableEntry] + ) -> Self: + """Returns a copy of the table with the specified columns remapped. + + Args: + **mapping_by_col: Each kwarg key should be the name of one of this table's + columns, and each value should be a mapping. The values in the column + will be looked up in the mapping and replaced with the result if one is + found. + + Raises: + ValueError: If a specified column name is not a column in this table. + """ + self._validate_keys_are_column_names(mapping_by_col) + if not self.size: + return self + remapped_cols = {} + for column_name, mapping in mapping_by_col.items(): + col_arr = self.get_column(column_name) + if col_arr.dtype == object: + remapped = string_array.remap(col_arr, mapping) + else: + remapped = np.vectorize(lambda x: mapping.get(x, x))( + col_arr) # pylint: disable=cell-var-from-loop + remapped_cols[column_name] = remapped + return self.copy_and_update(**remapped_cols) + + +class RenamedTableView: + """View of a table with renamed column names.""" + + def __init__(self, table: Table, column_name_map: Mapping[str, str]): + self._table = table + self._column_name_map = column_name_map + + def get_row_by_key( + self, + key: int, + column_name_map: Mapping[str, str] | None = None, + ) -> Mapping[str, Any]: + del column_name_map + return self._table.get_row_by_key( + key, column_name_map=self._column_name_map + ) + + +_DatabaseT = TypeVar('_DatabaseT', bound='Database') + + +class Database(abc.ABC): + """Relational database base class.""" + + @property + @abc.abstractmethod + def tables(self) -> Collection[str]: + """The names of the tables in this database.""" + + @abc.abstractmethod + def get_table(self, table_name: str) -> Table: + """Gets the table with the given name.""" + + @property + @abc.abstractmethod + def foreign_keys(self) -> Mapping[str, Collection[tuple[str, str]]]: + """Describes the relationship between keys in the database. + + Returns: + A map from table names to pairs of `(column_name, foreign_table_name)` + where `column_name` is a column containing foreign keys in the table named + by the key, and the `foreign_table_name` is the name of the table that + those foreign keys refer to. + """ + + @abc.abstractmethod + def copy_and_update( + self: _DatabaseT, + **new_field_by_field_name: ..., + ) -> _DatabaseT: + """Returns a copy of this database with the specified changes applied.""" + + +def table_dependency_order(db: Database) -> Iterable[str]: + """Yields the names of the tables in the database in dependency order. + + This order guarantees that a table appears after all other tables that + it refers to using foreign keys. 
Specifically A < B implies that A contains + no column that refers to B.key as a foreign key. + + Args: + db: The database that defines the table names and foreign keys. + """ + connections: dict[str, set[str]] = {} + for table_name in db.tables: + connection_set = set() + for _, foreign_table in db.foreign_keys.get(table_name, ()): + connection_set.add(foreign_table) + connections[table_name] = connection_set + yield from graphlib.TopologicalSorter(connections).static_order() + + +def concat_databases(dbs: Sequence[_DatabaseT]) -> _DatabaseT: + """Concatenates the tables across a sequence of databases. + + Args: + dbs: A non-empty sequence of database instances of the same type. + + Returns: + A new database containing the concatenated tables from the input databases. + + Raises: + ValueError: If `dbs` is empty or `dbs` contains different Database + types. + """ + if not dbs: + raise ValueError('Need at least one value to concatenate.') + distinct_db_types = {type(db) for db in dbs} + if len(distinct_db_types) > 1: + raise ValueError( + f'All `dbs` must be of the same type, got: {distinct_db_types}' + ) + + first_db, *other_dbs = dbs + concatted_tables: dict[str, Table] = {} + key_offsets: dict[str, list[int]] = {} + for table_name in table_dependency_order(first_db): + first_table = first_db.get_table(table_name) + columns: dict[str, list[np.ndarray]] = { + column_name: [first_table.get_column(column_name)] + for column_name in first_table.columns + } + key_offsets[table_name] = [ + first_table.key.max() + 1 if first_table.size else 0 + ] + + for prev_index, db in enumerate(other_dbs): + table = db.get_table(table_name) + for col_name in table.columns: + columns[col_name].append(table.get_column(col_name)) + key_offset = key_offsets[table_name][prev_index] + offset_key = table.key + key_offset + columns['key'][-1] = offset_key + if table.size: + key_offsets[table_name].append(offset_key.max() + 1) + else: + key_offsets[table_name].append( + key_offsets[table_name][prev_index]) + for fkey_col_name, foreign_table_name in first_db.foreign_keys.get( + table_name, [] + ): + fkey_columns = columns[fkey_col_name] + fkey_columns[-1] = ( + fkey_columns[-1] + + key_offsets[foreign_table_name][prev_index] + ) + + concatted_columns = { + column_name: np.concatenate(values, axis=-1) + for column_name, values in columns.items() + } + concatted_tables[table_name] = (type(first_table))(**concatted_columns) + return first_db.copy_and_update(**concatted_tables) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/test_utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/test_utils.py new file mode 100644 index 000000000..8cc6ec498 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/structure/test_utils.py @@ -0,0 +1,358 @@ +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+# ============================================================================
+
+"""Utilities for structure module testing."""
+
+import contextlib
+import dataclasses
+import datetime
+import difflib
+import functools
+import hashlib
+import json
+import os
+import pathlib
+import shutil
+from typing import Any
+
+from absl.testing import absltest
+from absl.testing import parameterized
+import mindspore as ms
+import numpy as np
+import tree
+
+from alphafold3 import structure
+from alphafold3.common import resources
+from alphafold3.common.testing import data as testing_data
+from alphafold3.data import pipeline
+from alphafold3.model.atom_layout import atom_layout
+
+_JACKHMMER_BINARY_PATH = shutil.which('jackhmmer')
+_NHMMER_BINARY_PATH = shutil.which('nhmmer')
+_HMMALIGN_BINARY_PATH = shutil.which('hmmalign')
+_HMMSEARCH_BINARY_PATH = shutil.which('hmmsearch')
+_HMMBUILD_BINARY_PATH = shutil.which('hmmbuild')
+
+
+@contextlib.contextmanager
+def _output(name: str):
+  with open(result_path := f'{absltest.TEST_TMPDIR.value}/{name}', "wb") as f:
+    yield result_path, f
+
+
+@functools.singledispatch
+def _hash_data(x: Any, /) -> str:
+  if x is None:
+    return '<>'
+  return _hash_data(json.dumps(x).encode('utf-8'))
+
+
+@_hash_data.register
+def _(x: bytes, /) -> str:
+  return hashlib.sha256(x).hexdigest()
+
+
+@_hash_data.register
+def _(x: ms.Tensor) -> str:
+  return _hash_data(x.asnumpy())
+
+
+@_hash_data.register
+def _(x: np.ndarray) -> str:
+  if x.dtype == object:
+    return ';'.join(map(_hash_data, x.ravel().tolist()))
+  return _hash_data(x.tobytes())
+
+
+@_hash_data.register
+def _(_: structure.Structure) -> str:
+  return '<>'
+
+
+@_hash_data.register
+def _(_: atom_layout.AtomLayout) -> str:
+  return '<>'
+
+
+def _generate_diff(actual: str, expected: str) -> str:
+  return '\n'.join(
+      difflib.unified_diff(
+          expected.split('\n'),
+          actual.split('\n'),
+          fromfile='expected',
+          tofile='actual',
+          lineterm='',
+      )
+  )
+
+
+def tree_map(func, dict_tree):
+  if isinstance(dict_tree, dict):
+    return {k: tree_map(func, v) for k, v in dict_tree.items()}
+  else:
+    if func == "asnumpy":
+      return dict_tree.asnumpy()
+    elif func == "float32":
+      return dict_tree.astype(ms.float32)
+    elif func == "bfloat16":
+      return dict_tree.astype(ms.bfloat16)
+    else:
+      return func(dict_tree)
+
+
+class StructureTestCase(parameterized.TestCase):
+  """Testing utilities for working with structure.Structure."""
+
+  def set_path(self, use_full_database=False):
+    if use_full_database:
+      small_bfd_database_path = testing_data.Data(
+          '/data/zmmVol2/AF3/public_databases/bfd-first_non_consensus_sequences.fasta'
+      ).path()
+      mgnify_database_path = testing_data.Data(
+          '/data/zmmVol2/AF3/public_databases/mgy_clusters_2022_05.fa'
+      ).path()
+      uniprot_cluster_annot_database_path = testing_data.Data(
+          '/data/zmmVol2/AF3/public_databases/uniprot_all_2021_04.fa'
+      ).path()
+      uniref90_database_path = testing_data.Data(
+          '/data/zmmVol2/AF3/public_databases/uniref90_2022_05.fa'
+      ).path()
+      ntrna_database_path = testing_data.Data(
+          '/data/zmmVol2/AF3/public_databases/nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq.fasta'
+      ).path()
+      rfam_database_path = testing_data.Data(
+          '/data/zmmVol2/AF3/public_databases/rfam_14_9_clust_seq_id_90_cov_80_rep_seq.fasta'
+      ).path()
+      rna_central_database_path = testing_data.Data(
'/data/zmmVol2/AF3/public_databases/rnacentral_active_seq_id_90_cov_80_linclust.fasta' + ).path() + pdb_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/mmcif_files' + ).path() + seqres_database_path = testing_data.Data( + '/data/zmmVol2/AF3/public_databases/pdb_seqres_2022_09_28.fasta' + ).path() + else: + small_bfd_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/bfd-first_non_consensus_sequences__subsampled_1000.fasta' + ).path() + mgnify_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/mgy_clusters__subsampled_1000.fa' + ).path() + uniprot_cluster_annot_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/uniprot_all__subsampled_1000.fasta' + ).path() + uniref90_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/uniref90__subsampled_1000.fasta' + ).path() + ntrna_database_path = testing_data.Data( + resources.ROOT + / ('test_data/miniature_databases/' + 'nt_rna_2023_02_23_clust_seq_id_90_cov_80_rep_seq__subsampled_1000.fasta') + ).path() + rfam_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/rfam_14_4_clustered_rep_seq__subsampled_1000.fasta' + ).path() + rna_central_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/rnacentral_active_seq_id_90_cov_80_linclust__subsampled_1000.fasta' + ).path() + pdb_database_path = testing_data.Data( + resources.ROOT / 'test_data/miniature_databases/pdb_mmcif' + ).path() + seqres_database_path = testing_data.Data( + resources.ROOT + / 'test_data/miniature_databases/pdb_seqres_2022_09_28__subsampled_1000.fasta' + ).path() + + self._data_pipeline_config = pipeline.DataPipelineConfig( + jackhmmer_binary_path=_JACKHMMER_BINARY_PATH, + nhmmer_binary_path=_NHMMER_BINARY_PATH, + hmmalign_binary_path=_HMMALIGN_BINARY_PATH, + hmmsearch_binary_path=_HMMSEARCH_BINARY_PATH, + hmmbuild_binary_path=_HMMBUILD_BINARY_PATH, + small_bfd_database_path=small_bfd_database_path, + mgnify_database_path=mgnify_database_path, + uniprot_cluster_annot_database_path=uniprot_cluster_annot_database_path, + uniref90_database_path=uniref90_database_path, + ntrna_database_path=ntrna_database_path, + rfam_database_path=rfam_database_path, + rna_central_database_path=rna_central_database_path, + pdb_database_path=pdb_database_path, + seqres_database_path=seqres_database_path, + max_template_date=datetime.date(2021, 9, 30), + ) + self.data_path = "/data/zmmVol2/AF3/run_test/src/alphafold3/test_data" + + def compare_golden(self, result_path: str, golden_path) -> None: + filename = os.path.split(result_path)[1] + golden_path = pathlib.Path(golden_path) + with open(golden_path, 'r') as golden_file: + golden_text = golden_file.read() + with open(result_path, 'r') as result_file: + result_text = result_file.read() + + diff = _generate_diff(result_text, golden_text) + + self.assertEqual(diff, "", f"Result differs from golden:\n{diff}") + + def assertAuthorNamingSchemeEqual(self, ans1, ans2): # pylint: disable=invalid-name + """Walks naming scheme, making sure all elements are equal.""" + if ans1 is None or ans2 is None: + self.assertIsNone(ans1) + self.assertIsNone(ans2) + return + flat_ans1 = dict(tree.flatten_with_path(dataclasses.asdict(ans1))) + flat_ans2 = dict(tree.flatten_with_path(dataclasses.asdict(ans2))) + for k, v in flat_ans1.items(): + self.assertEqual(v, flat_ans2[k], msg=str(k)) + for k, v in flat_ans2.items(): + self.assertEqual(v, 
flat_ans1[k], msg=str(k))
+
+  def assertAllResiduesEqual(self, all_res1, all_res2):  # pylint: disable=invalid-name
+    """Walks all residues, making sure all elements are equal."""
+    if all_res1 is None or all_res2 is None:
+      self.assertIsNone(all_res1)
+      self.assertIsNone(all_res2)
+      return
+    self.assertSameElements(all_res1.keys(), all_res2.keys())
+    for chain_id, chain_res in all_res1.items():
+      self.assertSequenceEqual(
+          chain_res, all_res2[chain_id], msg=chain_id)
+
+  def assertBioassemblyDataEqual(self, data1, data2):  # pylint: disable=invalid-name
+    if data1 is None or data2 is None:
+      self.assertIsNone(data1)
+      self.assertIsNone(data2)
+      return
+    self.assertDictEqual(data1.to_mmcif_dict(), data2.to_mmcif_dict())
+
+  def assertChemicalComponentsDataEqual(  # pylint: disable=invalid-name
+      self,
+      data1,
+      data2,
+      allow_chem_comp_data_extension,
+  ):
+    """Checks whether two ChemicalComponentData objects are considered equal."""
+    if data1 is None or data2 is None:
+      self.assertIsNone(data1)
+      self.assertIsNone(data2)
+      return
+    if (not allow_chem_comp_data_extension) or (
+        data1.chem_comp.keys() ^ data2.chem_comp.keys()
+    ):
+      self.assertDictEqual(data1.chem_comp, data2.chem_comp)
+    else:
+      mismatching_values = []
+      for component_id in data1.chem_comp:
+        found = data1.chem_comp[component_id]
+        expected = data2.chem_comp[component_id]
+        if not found.extends(expected):
+          mismatching_values.append((component_id, expected, found))
+
+      if mismatching_values:
+        mismatch_err_msgs = '\n'.join(
+            f'{component_id}: {expected} or its extension expected,'
+            f' but {found} found.'
+            for component_id, expected, found in mismatching_values
+        )
+        self.fail(
+            f'Mismatching values for `_chem_comp` table: {mismatch_err_msgs}',
+        )
+
+  def assertBondsEqual(self, bonds1, bonds2, atom_key1, atom_key2):  # pylint: disable=invalid-name
+    """Checks whether two Bonds objects are considered equal."""
+    # An empty bonds table is functionally equivalent to a missing (None)
+    # bonds table. NB: bonds can only ever be None in structure v1.
+    if bonds1 is None or not bonds1.size or bonds2 is None or not bonds2.size:
+      self.assertTrue(bonds1 is None or not bonds1.size,
+                      msg=f'{bonds1=}')
+      self.assertTrue(bonds2 is None or not bonds2.size,
+                      msg=f'{bonds2=}')
+      return
+
+    ptnr1_indices1, ptnr2_indices1 = bonds1.get_atom_indices(atom_key1)
+    ptnr1_indices2, ptnr2_indices2 = bonds2.get_atom_indices(atom_key2)
+    np.testing.assert_array_equal(ptnr1_indices1, ptnr1_indices2)
+    np.testing.assert_array_equal(ptnr2_indices1, ptnr2_indices2)
+    np.testing.assert_array_equal(bonds1.type, bonds2.type)
+    np.testing.assert_array_equal(bonds1.role, bonds2.role)
+
+  def assertStructuresEqual(  # pylint: disable=invalid-name
+      self,
+      struc1,
+      struc2,
+      *,
+      ignore_fields=None,
+      allow_chem_comp_data_extension=False,
+      atol=0,
+  ):
+    """Checks whether two Structure objects could be considered equal.
+
+    Args:
+      struc1: First Structure object.
+      struc2: Second Structure object.
+      ignore_fields: Fields not taken into account during comparison.
+      allow_chem_comp_data_extension: Whether to allow data of the `_chem_comp`
+        table to differ if `struc2` is missing some fields, but `struc1` has
+        specific values for them.
+      atol: Absolute tolerance for floating point comparisons (in
+        np.testing.assert_allclose).
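+
+    Example (illustrative only; `struc1` and `struc2` are hypothetical
+    structures under test):
+      self.assertStructuresEqual(struc1, struc2, atol=1e-6)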
+ """ + for field in sorted(structure.GLOBAL_FIELDS): + if ignore_fields and field in ignore_fields: + continue + if field == 'author_naming_scheme': + self.assertAuthorNamingSchemeEqual( + struc1[field], struc2[field]) + elif field == 'all_residues': + self.assertAllResiduesEqual(struc1[field], struc2[field]) + elif field == 'bioassembly_data': + self.assertBioassemblyDataEqual(struc1[field], struc2[field]) + elif field == 'chemical_components_data': + self.assertChemicalComponentsDataEqual( + struc1[field], struc2[field], allow_chem_comp_data_extension + ) + elif field == 'bonds': + self.assertBondsEqual( + struc1.bonds, struc2.bonds, struc1.atom_key, struc2.atom_key + ) + else: + self.assertEqual(struc1[field], struc2[field], msg=field) + + # The chain order within a structure is arbitrary so in order to + # directly compare arrays we first align struc1 to struc2 and check that + # the number of atoms doesn't change. + num_atoms = struc1.num_atoms + self.assertEqual(struc2.num_atoms, num_atoms) + struc1 = struc1.order_and_drop_atoms_to_match(struc2) + self.assertEqual(struc1.num_atoms, num_atoms) + + for field in sorted(structure.ARRAY_FIELDS): + if field == 'atom_key': + # atom_key has no external meaning, so it doesn't matter whether it + # differs between two structures. + continue + if ignore_fields and field in ignore_fields: + continue + self.assertEqual(struc1[field] is None, + struc2[field] is None, msg=field) + + if np.issubdtype(struc1[field].dtype, np.inexact): + np.testing.assert_allclose( + struc1[field], struc2[field], err_msg=field, atol=atol + ) + else: + np.testing.assert_array_equal( + struc1[field], struc2[field], err_msg=field + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention.py new file mode 100644 index 000000000..3f4e41c75 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention.py @@ -0,0 +1,78 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +import mindspore as ms +from typing import Literal, TypeAlias +import typing +import alphafold3.utils.attention.attention_base as base +import alphafold3.utils.attention.ms_attention as ms_attention + +Implementation: TypeAlias = Literal["ms"] + + +def dot_product_attention(query, key, value, *, bias, mask, implementation, + logits_dtype=None, precision=None): + """Performs scaled dot-product attention. + + Scaled dot-product attention from "Attention is all you need" + https://arxiv.org/abs/1706.03762. + + Computes self- or cross-attention. The following is computed: + softmax(qk_scale * query @ key^T + bias) @ value. + + Supports both multi-head and multi-query attention + (https://arxiv.org/abs/1911.02150). + + Arguments: + query: Query array of shape `[batch, seq_len_q, num_heads, head_dim]`. + key: Key array of shape `[batch, seq_len_kv, num_heads, head_dim]`. 
+      `num_heads` can be 1 for multi-query attention.
+    value: Value array of shape `[batch, seq_len_kv, num_heads, head_dim]`.
+      `num_heads` can be 1 for multi-query attention.
+    bias: Optional bias array, broadcastable to shape `[batch, num_heads,
+      seq_len_q, seq_len_kv]`.
+    mask: Optional boolean mask, broadcastable to `[batch, num_heads,
+      seq_len_q, seq_len_kv]`. Attention weights are masked out if the
+      corresponding mask value is `False`.
+    implementation: if `None` (default), an implementation is automatically
+      chosen. 'ms' will use standard MindSpore ops and work on any platform.
+    logits_dtype: Data type for attention logits (`query @ key^T`). If `None`
+      is passed (the default), the accumulator type from the `query @ key^T`
+      dot product will be used, which is FP32 for BF16/FP16/FP32 inputs. Note
+      that this default increases the memory usage for BF16/FP16 inputs when
+      using `implementation='ms'`.
+    precision: The precision for the dot products. Either a single precision
+      applied to both dot products, or a tuple `(query_key_dot_precision,
+      weights_value_dot_precision)`.
+
+  Returns:
+    An array with the same shape as `query`.
+  """
+
+  if implementation is not None:
+    named_args = typing.get_args(Implementation)
+    if implementation not in named_args:
+      raise ValueError(
+          f"Unsupported named implementation. Must be one of {named_args}."
+      )
+
+  logits_dtype = base.AUTO if logits_dtype is None else logits_dtype
+  precision = "DEFAULT" if precision is None else precision
+
+  args = (query, key, value)
+  kwargs = dict(
+      precision=precision,
+      logits_dtype=logits_dtype,
+      bias=bias,
+      mask=mask,
+  )
+
+  return ms_attention.MsDotProductAttention()(*args, **kwargs)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_base.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_base.py
new file mode 100644
index 000000000..e739817d4
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_base.py
@@ -0,0 +1,275 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+import abc
+import enum
+import math
+import dataclasses
+import functools
+from dataclasses import dataclass, KW_ONLY
+from typing import Any
+import numpy as np
+import mindspore as ms
+from mindspore import ops, Tensor
+from alphafold3.utils.common import precision as precision_lib
+
+
+class AUTO:  # Used as a sentinel value.
+  pass
+
+
+@dataclasses.dataclass(frozen=True)
+class Mask:
+  """An attention mask.
+
+  `k_start` (inclusive) and `k_end` (exclusive) define the range of enabled
+  k-sequence values for each row of logits.
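+
+  All supplied constraints (`bool_mask`, the start/end ranges and
+  `is_causal`) are combined with a logical AND: a logits position is enabled
+  only if it satisfies every constraint.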
+
+  For example, a local attention mask could be defined as follows:
+  ```
+  seq_len_q = seq_len_k = 4
+  window_size = 2
+  k_start = Tensor(np.maximum(0, np.arange(seq_len_q) + 1 - window_size))
+  mask = Mask(k_start=k_start, is_causal=True)
+  assert mask.as_array(seq_len_q, seq_len_k) == Tensor(np.array(
+      [[1, 0, 0, 0],
+       [1, 1, 0, 0],
+       [0, 1, 1, 0],
+       [0, 0, 1, 1]], dtype=bool))
+  ```
+  """
+  bool_mask: ms.Tensor | None = None
+  _: dataclasses.KW_ONLY
+  q_start: ms.Tensor | None = None
+  q_end: ms.Tensor | None = None
+  k_start: ms.Tensor | None = None
+  k_end: ms.Tensor | None = None
+  is_causal: bool = False
+
+  def tree_flatten(self):
+    return (
+        self.bool_mask,
+        self.q_start,
+        self.q_end,
+        self.k_start,
+        self.k_end,
+    ), (self.is_causal,)
+
+  @classmethod
+  def tree_unflatten(cls, aux, children):
+    (is_causal,) = aux
+    bool_mask, q_start, q_end, k_start, k_end = children
+    return cls(
+        bool_mask,
+        q_start=q_start,
+        q_end=q_end,
+        k_start=k_start,
+        k_end=k_end,
+        is_causal=is_causal,
+    )
+
+  def as_array(self, q_len_or_indices, k_len_or_indices):
+    """Returns the mask as a boolean array."""
+    q_indices = ops.arange(q_len_or_indices) if isinstance(
+        q_len_or_indices, int) else q_len_or_indices
+    q_indices = q_indices[..., None]
+
+    k_indices = ops.arange(k_len_or_indices) if isinstance(
+        k_len_or_indices, int) else k_len_or_indices
+    k_indices = k_indices[..., None, :]
+
+    mask = []
+    if self.bool_mask is not None:
+      mask.append(self.bool_mask)
+
+    if self.q_start is not None:
+      mask.append(q_indices >= self.q_start[..., None, :])
+
+    if self.q_end is not None:
+      mask.append(q_indices < self.q_end[..., None, :])
+
+    if self.k_start is not None:
+      mask.append(k_indices >= self.k_start[..., None])
+
+    if self.k_end is not None:
+      mask.append(k_indices < self.k_end[..., None])
+
+    if self.is_causal:
+      mask.append(q_indices >= k_indices)
+
+    logical_and = functools.partial(functools.reduce, ops.logical_and)
+
+    if mask:
+      return logical_and(mask)
+    else:
+      return None
+
+  def take(self, *attrs):
+    """Returns a mask with attrs removed and the removed attrs."""
+    default_mask = type(self)()
+    replacements = {attr: getattr(default_mask, attr) for attr in attrs}
+    values = (getattr(self, attr) for attr in attrs)
+    return dataclasses.replace(self, **replacements), *values
+
+  def __and__(self, other):
+    """Returns the intersection of two masks."""
+    if not isinstance(other, Mask):
+      other = Mask(other)
+
+    def combine(op):
+      return lambda a, b: b if a is None else a if b is None else op(a, b)
+
+    return Mask(
+        bool_mask=combine(ops.logical_and)(
+            self.bool_mask, other.bool_mask),
+        # `q_start` must be intersected as well; this mirrors the `k_start`
+        # handling below.
+        q_start=combine(ops.maximum)(self.q_start, other.q_start),
+        q_end=combine(ops.minimum)(self.q_end, other.q_end),
+        k_start=combine(ops.maximum)(self.k_start, other.k_start),
+        k_end=combine(ops.minimum)(self.k_end, other.k_end),
+        is_causal=self.is_causal or other.is_causal,
+    )
+
+
+CAUSAL_MASK = Mask(is_causal=True)
+
+
+@enum.unique
+class SoftmaxResidualMode(enum.Enum):
+  """The mode of storing softmax residuals for the backwards pass.
+
+  The stable softmax calculation performs two reductions calculating:
+  - the maximum input value (`x_max`),
+  - the sum of exponentiated values (`denom`).
+
+  We can store these values as residuals to avoid the need to recompute them
+  in the backwards pass.
+
+  It is also possible to combine the two residuals into a single residual,
+  `res = x_max + log(denom)`, as `exp(x - res) === exp(x - x_max - log(denom))
+  === exp(x - x_max) / denom`.
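+  For instance, with `x_max = 100.0` and `denom = 2.0` the combined residual
+  is `res = 100.0 + log(2.0) ≈ 100.69`.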
+  Combining the residuals reduces the memory usage of the residuals, but will
+  reduce the accuracy of the backwards pass if `abs(x_max) >> log(denom)`.
+  """
+
+  SEPARATE = "separate"
+  COMBINED = "combined"
+
+  def conform(self, aux):
+    match self, aux:
+      case _, None:
+        # No residuals to conform.
+        return None
+      case SoftmaxResidualMode.SEPARATE, (_, _):
+        return aux
+      case SoftmaxResidualMode.SEPARATE, _:  # pytype: disable=redundant-match  # b/300135240
+        raise ValueError("`aux` has been combined.")
+      case SoftmaxResidualMode.COMBINED, (x_max, denom):
+        return x_max + ops.log(denom)
+      case SoftmaxResidualMode.COMBINED, _:  # pytype: disable=redundant-match  # b/300135240
+        return aux
+
+
+class DotProductAttention(abc.ABC):
+  """Dot product attention function."""
+
+  def __call__(self, query, key, value, *, precision, logits_dtype, bias, mask, q_indices=None, k_indices=None):
+    """Performs scaled dot-product attention.
+
+    Scaled dot-product attention from "Attention is all you need"
+    https://arxiv.org/abs/1706.03762.
+
+    Computes self- or cross-attention. The following is computed:
+    softmax(qk_scale * query @ key^T + bias) @ value.
+
+    Supports both multi-head and multi-query attention
+    (https://arxiv.org/abs/1911.02150).
+
+    Arguments:
+      query: Query array of shape `[batch, seq_len_q, num_heads_q, head_dim]`.
+        `num_heads_q` must be a multiple of `num_heads_kv`.
+        Here's an example of how q/kv heads are grouped.
+        For 8 query heads and 4 key/value heads:
+        - query heads [0, 1] see key/value head 0
+        - query heads [2, 3] see key/value head 1
+        - query heads [4, 5] see key/value head 2
+        - query heads [6, 7] see key/value head 3
+      key: Key array of shape `[batch, seq_len_kv, num_heads_kv, head_dim]`.
+        `num_heads_kv` must divide `num_heads_q`.
+      value: Value array of shape `[batch, seq_len_kv, num_heads_kv, head_dim]`.
+      precision: The precision for the dot products. Either a tuple `(
+        query_key_dot_precision, weights_value_dot_precision)` or a single
+        precision applied to both dot products.
+      logits_dtype: Data type for attention logits (`query @ key^T`). If `AUTO`
+        is passed (the default), the accumulator type from the `query @ key^T`
+        dot product will be used.
+      bias: Optional bias array, broadcastable to shape `[batch, num_heads,
+        seq_len_q, seq_len_kv]`.
+      mask: Optional boolean mask, broadcastable to `[batch, num_heads,
+        seq_len_q, seq_len_kv]`. Attention weights are masked out if the
+        corresponding mask value is `False`.
+      q_indices: Optional indices for each token in the query sequence.
+      k_indices: Optional indices for each token in the key/value sequence.
+
+    Returns:
+      An array with the same shape as `query`.
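+
+    Example (shapes only, illustrative): a `query` of shape
+    `[2, 128, 8, 64]` with `key`/`value` of shape `[2, 128, 2, 64]` groups
+    four query heads per key/value head and yields an output of shape
+    `[2, 128, 8, 64]`.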
+ """ + return self.fwd( + query, + key, + value, + precision=precision, + logits_dtype=logits_dtype, + bias=bias, + mask=mask, + q_indices=q_indices, + k_indices=k_indices, + ) + + def fwd(self, query, key, value, *, precision, logits_dtype, bias, mask, q_indices, k_indices): + """Performs attention.""" + if not isinstance(precision, tuple): + precision = (precision, precision) + + q_k_dot_precision, weights_v_dot_precision = precision + + if not isinstance(q_k_dot_precision, precision_lib.DotPrecision): + q_k_dot_precision = precision_lib.get_equivalent_dot_precision( + query.dtype, key.dtype, q_k_dot_precision + ) + + if not isinstance(weights_v_dot_precision, precision_lib.DotPrecision): + weights_v_dot_precision = precision_lib.get_equivalent_dot_precision( + value.dtype, value.dtype, weights_v_dot_precision + ) + + if logits_dtype is AUTO: + logits_dtype = q_k_dot_precision.accumulator_dtype + + if not isinstance(mask, Mask): + mask = Mask(mask) + + return self._fwd( + Tensor(query), + Tensor(key), + Tensor(value), + q_k_dot_precision=q_k_dot_precision, + logits_dtype=logits_dtype, + logits_scale=1 / math.sqrt(query.shape[-1]), + bias=bias, + mask=mask, + weights_v_dot_precision=weights_v_dot_precision, + q_indices=q_indices, + k_indices=k_indices, + ) + + @abc.abstractmethod + def _fwd(self, q, k, v, *, q_k_dot_precision, logits_dtype, logits_scale, bias, mask, + weights_v_dot_precision, q_indices, k_indices): + """Performs attention.""" + ... diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_call_arg_specs.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_call_arg_specs.py new file mode 100644 index 000000000..6e8db1bde --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/attention_call_arg_specs.py @@ -0,0 +1,61 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Attention call argument specifications. + +Attention argument specifications used by users of the library. +They are the most important test cases, and also cases for optimize +performance of via autotuning. +""" + +from typing import Any + + +def _make_argspec( + *, + q_shape, + dtype, + k_shape=None, + v_shape=None, + bias_shape=None, + mask_shape=None, + **kwargs, +) -> dict[str, Any]: + """Make argspec from shapes and kwargs.""" + if k_shape is None: + k_shape = q_shape + if v_shape is None: + v_shape = k_shape + + return dict( + query=q_shape, + key=k_shape, + value=v_shape, + bias=bias_shape, + mask=mask_shape, + dtype=dtype, + **kwargs, + ) + + +# A subset of the full set of argument specifications. Useful for tap-tests and +# microbenchmarks. 
+CALL_ARG_SPECS = dict(
+    vanilla_f32=_make_argspec(q_shape=(8, 1024, 4, 128), dtype='float32'),
+    vanilla_bf16=_make_argspec(q_shape=(8, 1024, 4, 128), dtype='bfloat16'),
+    alphafold=_make_argspec(
+        q_shape=(384, 384, 4, 32),
+        bias_shape=(1, 4, 384, 384),
+        mask_shape=(384, 1, 1, 384),
+        dtype='bfloat16',
+    ),
+)
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py
new file mode 100644
index 000000000..3b5acda5e
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/attention/ms_attention.py
@@ -0,0 +1,97 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+import dataclasses
+import numpy as np
+import mindspore as ms
+from mindspore import Tensor, ops
+import alphafold3.utils.attention.attention_base as base
+
+
+def _softmax(x):
+  """Computes a numerically stable softmax over the last axis."""
+  dtype = ms.float32
+  x_max, _ = ops.max(x.astype(dtype), axis=-1, keepdims=True)
+  unnormalized = ops.exp(x - x_max)
+  denom = ops.sum(unnormalized, dim=-1, keepdim=True)
+  return (unnormalized / denom).astype(x.dtype)
+
+
+def cal_logits(q, k, use_bf16=False):
+  # Computes `...qhd,...khd->...hqk` via transposes and matmul.
+  dtype = q.dtype
+  if use_bf16:
+    q = q.astype(ms.bfloat16)
+    k = k.astype(ms.bfloat16)
+  q_trans = ops.transpose(q, (0, 2, 1, 3))  # ...qhd -> ...hqd
+  k_trans = ops.transpose(k, (0, 2, 3, 1))  # ...khd -> ...hdk
+  logits = ops.matmul(q_trans, k_trans)
+  if use_bf16:
+    logits = logits.astype(dtype)
+  return logits
+
+
+def cal_out(weights, v, use_bf16=False):
+  # Computes `...hqk,...khd->...qhd` via transposes and matmul.
+  dtype = v.dtype
+  if use_bf16:
+    weights = weights.astype(ms.bfloat16)
+    v = v.astype(ms.bfloat16)
+  v_trans = ops.transpose(v, (0, 2, 1, 3))  # ...khd -> ...hkd
+  out_temp = ops.matmul(weights, v_trans)  # ...hqk,...hkd->...hqd
+  out = ops.transpose(out_temp, (0, 2, 1, 3))
+  if use_bf16:
+    out = out.astype(dtype)  # Cast the output back to the original dtype.
+  return out
+
+
+def _attend(
+    q, k, v, *, q_k_dot_precision, logits_dtype, logits_scale,
+    bias, mask, weights_v_dot_precision, q_indices, k_indices,
+):
+  logits = cal_logits(q, k)
+
+  logits *= logits_scale
+
+  if bias is not None:
+    logits += bias
+
+  if mask is not None:
+    q_len_or_indices = q.shape[-3] if q_indices is None else q_indices
+    k_len_or_indices = k.shape[-3] if k_indices is None else k_indices
+    mask = mask.as_array(q_len_or_indices, k_len_or_indices)
+
+  if mask is not None:  # TBD in ms
+    # A large negative value (representable in bfloat16/float32) used to
+    # disable masked-out positions before the softmax.
+    mask_value = -3.4028235e+37
+    logits = ops.where(mask.bool(), logits, mask_value)
+
+  weights = _softmax(logits)
+
+  out = cal_out(weights, v)
+
+  return out
+
+
+@dataclasses.dataclass(frozen=True)
+class MsDotProductAttention(base.DotProductAttention):
+  """MS dot product attention function."""
+
+  _: dataclasses.KW_ONLY
+
+  def _fwd(
+      self, q, k, v, *, q_k_dot_precision, logits_dtype, logits_scale,
+      bias, mask, weights_v_dot_precision, q_indices, k_indices,
+  ):
+
+    return _attend(
+        q, k, v, bias=bias, mask=mask, q_indices=q_indices, k_indices=k_indices,
q_k_dot_precision=q_k_dot_precision, logits_dtype=logits_dtype, logits_scale=logits_scale, + weights_v_dot_precision=weights_v_dot_precision, + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/common/precision.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/common/precision.py new file mode 100644 index 000000000..ee9da311b --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/common/precision.py @@ -0,0 +1,90 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Precision classes and utilities.""" + +import enum +import mindspore as ms + + +@enum.unique +class DotPrecision(enum.Enum): + """Precision for `dot` operation. + + Naming scheme: {OPERAND_DTYPE}_{ACCUMULATOR_DTYPE}[_{NUM_PASSES}x] + """ + + BF16_F32 = "bf16_f32" + + # NPU only precisions. + F32_F32 = "f32_f32" # Full f32 precision (doesn't use TensorCores). + F16_F16 = "f16_f16" + F16_F32 = "f16_f32" + + @property + def operand_dtype(self) -> ms.dtype: + match self: + case DotPrecision.BF16_F32: + return ms.bfloat16 + case DotPrecision.F16_F16 | DotPrecision.F16_F32: + return ms.float16 + case _: + return ms.float32 + + @property + def accumulator_dtype(self) -> ms.dtype: + return ms.float16 if (self == DotPrecision.F16_F16) else ms.float32 + + +_MS_NPU_PRECISION_MAP = { + (ms.float16, "DEFAULT"): DotPrecision.F16_F32, + (ms.bfloat16, "DEFAULT"): DotPrecision.BF16_F32, + (ms.float32, "DEFAULT"): DotPrecision.F32_F32, + (ms.float32, "HIGH"): DotPrecision.F32_F32, + (ms.float32, "HIGHEST"): DotPrecision.F32_F32, +} + +_MS_CPU_PRECISION_MAP = { + (ms.float16, "DEFAULT"): DotPrecision.F16_F32, + (ms.bfloat16, "DEFAULT"): DotPrecision.F32_F32, + (ms.float32, "DEFAULT"): DotPrecision.F32_F32, + (ms.float32, "HIGH"): DotPrecision.F32_F32, + (ms.float32, "HIGHEST"): DotPrecision.F32_F32, +} + + +def _create_ms_precision_map(): + precision_map = {} + for (dtype, ms_precision), dot_precision in _MS_NPU_PRECISION_MAP.items(): + precision_map[("ascend", dtype, ms_precision)] = dot_precision + for (dtype, ms_precision), dot_precision in _MS_CPU_PRECISION_MAP.items(): + precision_map[("cpu", dtype, ms_precision)] = dot_precision + return precision_map + + +_MS_PRECISION_MAP = _create_ms_precision_map() + + +def get_equivalent_dot_precision( + a_dtype: ms.dtype, b_dtype: ms.dtype, ms_precision: str +) -> DotPrecision: + """Returns `DotPrecision` replicating default behaviour.""" + if a_dtype != b_dtype: + raise ValueError("Cannot infer precision if operand types differ.") + + backend = ms.context.get_context("device_target").lower() + if (ms_precision != "DEFAULT") and (a_dtype != ms.float32): + raise ValueError( + "`Precision` values other than `DEFAULT` only have an effect if" + " the operand type is `float32`." 
+ ) + return _MS_PRECISION_MAP[(backend, a_dtype, ms_precision)] diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit.py new file mode 100644 index 000000000..deb0c0ef0 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit.py @@ -0,0 +1,69 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Public API for gated linear unit functions.""" + +import typing +from typing import Literal, TypeAlias +import numpy as np +import mindspore as ms +from mindspore import Tensor +from alphafold3.utils.gated_linear_unit import gated_linear_unit_base + +Implementation: TypeAlias = Literal['ms'] + + +def gated_linear_unit(x, weight, *, activation, precision, implementation=None): + """Applies a gated linear unit (https://arxiv.org/abs/1612.08083). + + Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`. + + This is SwiGLU when `activation=swish`, GEGLU when + `activation=gelu`, REGLU when `activation=relu`, and GLU when + `activation=sigmoid` (https://arxiv.org/abs/2002.05202). + + Args: + x: the input array. + weight: the combined weight array. + activation: optional activation function. + precision: specifies the matrix multiplication precision. Either `None` + (default), which means the default precision for the backend, or an + enum of "DEFAULT/HIGH/...". + implementation: if `None` (default), an implementation is automatically + chosen. 'ms' will use standard MS and work on any platform. + + Raises: + ValueError: if the arguments are invalid. + + Returns: + The output array. + """ + + if x.dtype != weight.dtype: + raise ValueError( + f'Input and weight must have the same dtype. {x.dtype} !=' + f' {weight.dtype}' + ) + + if implementation is not None: + named_args = typing.get_args(Implementation) + if implementation not in named_args: + raise ValueError( + f'Unsupported named implementation. Must be one of {named_args}.' + ) + + return gated_linear_unit_base.gated_linear_unit_ms( + x=x, + weight=weight, + activation=activation, + precision=precision, + ) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py new file mode 100644 index 000000000..d899c6714 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/gated_linear_unit/gated_linear_unit_base.py @@ -0,0 +1,84 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. 
To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Common types for gated linear unit kernels.""" +import abc +import numpy as np +import mindspore as ms +from mindspore import Tensor, ops, mint + + +class GatedLinearUnit(abc.ABC): + """Gated linear unit.""" + + def __call__(self, x, weight, *, activation, precision, **kwargs): + """Applies a gated linear unit (https://arxiv.org/abs/1612.08083). + + Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`. + + This is SwiGLU when `activation=swish`, GEGLU when + `activation=gelu`, REGLU when `activation=relu`, and GLU when + `activation=sigmoid` (https://arxiv.org/abs/2002.05202). + + Args: + x: the input array. + weight: the combined weight array. + activation: optional activation function. + precision: specifies the matrix multiplication precision. Either `None` + (default), which means the default precision for the backend, or an + enum of "DEFAULT/HIGH/...". + + Returns: + The output array. + """ + + return self._fwd( + x, weight, activation=activation, precision=precision, **kwargs + ) + + @abc.abstractmethod + def _fwd(self, x, weight, *, activation, precision): + """Gated linear unit.""" + ... + + +def gated_linear_unit_ms(x, weight, *, activation, precision=None): + """Applies a gated linear unit (https://arxiv.org/abs/1612.08083). + + Computes `activation(x @ weight[:, 0]) * x @ weight[:, 1]`. + + This is SwiGLU when `activation=swish`, GEGLU when + `activation=gelu`, REGLU when `activation=relu`, and GLU when + `activation=sigmoid` (https://arxiv.org/abs/2002.05202). + + Args: + x: the input array. + weight: the combined weight array. + activation: optional activation function. + precision: specifies the matrix multiplication precision. Either `None` + (default), which means the default precision for the backend, or an + enum of "DEFAULT/HIGH/...". + + Returns: + The output array. + """ + + weight_reshaped = mint.reshape( + weight, (-1, weight.shape[-2] * weight.shape[-1])) + # y = ops.dot(x.astype('float32'), weight_reshaped.astype('float32')) + y1 = mint.matmul(x, weight_reshaped) + y = y1.astype(ms.float32) + a, b = y.split(y.shape[-1] // 2, axis=-1) + out = mint.mul(a, b) if activation is None else mint.mul(activation(a), b) + out = out.astype(x.dtype) + + return out diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py new file mode 100644 index 000000000..910ccfe9f --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. 
Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +from alphafold3.utils.geometry import rigid_matrix_vector +from alphafold3.utils.geometry import rotation_matrix +from alphafold3.utils.geometry import struct_of_array +from alphafold3.utils.geometry import vector + +Rot3Array = rotation_matrix.Rot3Array +Rigid3Array = rigid_matrix_vector.Rigid3Array + +StructOfArray = struct_of_array.StructOfArray + +Vec3Array = vector.Vec3Array +square_euclidean_distance = vector.square_euclidean_distance +euclidean_distance = vector.euclidean_distance +dihedral_angle = vector.dihedral_angle +dot = vector.dot +cross = vector.cross diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rigid_matrix_vector.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rigid_matrix_vector.py new file mode 100644 index 000000000..5eab7b905 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rigid_matrix_vector.py @@ -0,0 +1,192 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Rigid3Array Transformations represented by a Matrix and a Vector.""" + +from typing import Any, Final, TypeAlias +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import Tensor +from mindspore.ops import operations as P + +from alphafold3.utils.geometry import rotation_matrix, struct_of_array, utils, vector + +Float: TypeAlias = float | Tensor + +VERSION: Final[str] = '0.1' + + +def _compute_covariance_matrix( + row_values: vector.Vec3Array, + col_values: vector.Vec3Array, + weights: Tensor, + epsilon=1e-6, +) -> Tensor: + """Compute covariance matrix.""" + weights = mnp.asarray(weights) + + weights = mnp.broadcast_to(weights, row_values.shape) + + normalized_weights = weights / \ + (mnp.sum(weights, axis=-1, keepdims=True) + epsilon) + + def weighted_average(x): return mnp.sum(normalized_weights * x, axis=-1) + + out = [ + mnp.stack( + ( + weighted_average(row_values.x * col_values.x), + weighted_average(row_values.x * col_values.y), + weighted_average(row_values.x * col_values.z), + ), + axis=-1, + ) + ] + + out.append( + mnp.stack( + ( + weighted_average(row_values.y * col_values.x), + weighted_average(row_values.y * col_values.y), + weighted_average(row_values.y * col_values.z), + ), + axis=-1, + ) + ) + + out.append( + mnp.stack( + ( + weighted_average(row_values.z * col_values.x), + weighted_average(row_values.z * col_values.y), + weighted_average(row_values.z * col_values.z), + ), + axis=-1, + ) + ) + + return mnp.stack(out, axis=-2) + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rigid3Array: + """Rigid Transformation, i.e. 
element of special euclidean group.""" + + rotation: rotation_matrix.Rot3Array + translation: vector.Vec3Array + + def __matmul__(self, other: 'Rigid3Array') -> 'Rigid3Array': + new_rotation = self.rotation @ other.rotation + new_translation = self.apply_to_point(other.translation) + return Rigid3Array(new_rotation, new_translation) + + def inverse(self) -> 'Rigid3Array': + """Return Rigid3Array corresponding to inverse transform.""" + inv_rotation = self.rotation.inverse() + inv_translation = inv_rotation.apply_to_point(-self.translation) + return Rigid3Array(inv_rotation, inv_translation) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply Rigid3Array transform to point.""" + return self.rotation.apply_to_point(point) + self.translation + + def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Apply inverse Rigid3Array transform to point.""" + new_point = point - self.translation + return self.rotation.apply_inverse_to_point(new_point) + + def compose_rotation(self, other_rotation: rotation_matrix.Rot3Array) -> 'Rigid3Array': + rot = self.rotation @ other_rotation + trans = P.BroadcastTo(rot.shape)(self.translation) + return Rigid3Array(rot, trans) + + @classmethod + def identity(cls, shape: Any, dtype: ms.dtype = ms.float32) -> 'Rigid3Array': + """Return identity Rigid3Array of given shape.""" + + return cls( + rotation_matrix.Rot3Array.identity(shape, dtype=dtype), + vector.Vec3Array.zeros(shape, dtype=dtype), + ) + + def scale_translation(self, factor: Float) -> 'Rigid3Array': + """Scale translation in Rigid3Array by 'factor'.""" + return Rigid3Array(self.rotation, self.translation * factor) + + def to_array(self): + rot_array = self.rotation.to_array() + vec_array = self.translation.to_array() + return mnp.concatenate([rot_array, vec_array[..., None]], axis=-1) + + @classmethod + def from_array(cls, array): + rot = rotation_matrix.Rot3Array.from_array(array[..., :3]) + vec = vector.Vec3Array.from_array(array[..., -1]) + return cls(rot, vec) + + @classmethod + def from_array4x4(cls, array: Tensor) -> 'Rigid3Array': + """Construct Rigid3Array from homogeneous 4x4 array.""" + if array.shape[-2:] != (4, 4): + raise ValueError(f'array.shape({array.shape}) must be [..., 4, 4]') + rotation = rotation_matrix.Rot3Array( + *(array[..., 0, 0], array[..., 0, 1], array[..., 0, 2]), + *(array[..., 1, 0], array[..., 1, 1], array[..., 1, 2]), + *(array[..., 2, 0], array[..., 2, 1], array[..., 2, 2]), + ) + translation = vector.Vec3Array( + array[..., 0, 3], array[..., 1, 3], array[..., 2, 3] + ) + return cls(rotation, translation) + + @classmethod + def from_point_alignment( + cls, + points_to: vector.Vec3Array, + points_from: vector.Vec3Array, + weights: Float | None = None, + epsilon: float = 1e-6, + ) -> 'Rigid3Array': + """Constructs Rigid3Array by finding transform aligning points.""" + if weights is None: + weights = 1.0 + + def compute_center(value): + return utils.weighted_mean(value=value, weights=weights, axis=-1) + + points_to_center = P.Map()(compute_center, points_to) + points_from_center = P.Map()(compute_center, points_from) + centered_points_to = points_to - points_to_center[..., None] + centered_points_from = points_from - points_from_center[..., None] + cov_mat = _compute_covariance_matrix( + centered_points_to, + centered_points_from, + weights=weights, + epsilon=epsilon, + ) + rots = rotation_matrix.Rot3Array.from_svd( + mnp.reshape(cov_mat, cov_mat.shape[:-2] + (9,)) + ) + + translations = points_to_center - \ + 
rots.apply_to_point(points_from_center) + + return cls(rots, translations) + + def __getstate__(self): + return (VERSION, (self.rotation, self.translation)) + + def __setstate__(self, state): + version, (rot, trans) = state + del version + object.__setattr__(self, 'rotation', rot) + object.__setattr__(self, 'translation', trans) diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rotation_matrix.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rotation_matrix.py new file mode 100644 index 000000000..481be9989 --- /dev/null +++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/rotation_matrix.py @@ -0,0 +1,257 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Copyright 2024 DeepMind Technologies Limited +# +# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of +# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/ +# +# To request access to the AlphaFold 3 model parameters, follow the process set +# out at https://github.com/google-deepmind/alphafold3. You may only use these +# if received directly from Google. Use is subject to terms of use available at +# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md + +"""Rot3Array Matrix Class.""" + +import dataclasses +from typing import Any, Final +from mindspore import ops, mint +import numpy as np +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import Tensor +from alphafold3.utils.geometry import struct_of_array, utils, vector + +COMPONENTS: Final[tuple[str, ...]] = ( + *('xx', 'xy', 'xz'), + *('yx', 'yy', 'yz'), + *('zx', 'zy', 'zz'), +) +VERSION: Final[str] = '0.1' + + +def make_matrix_svd_factors() -> Tensor: + """Generates factors for converting 3x3 matrix to symmetric 4x4 matrix.""" + factors = mnp.zeros((16, 9), dtype=ms.float32) + + indices = [(0, [0, 4, 8]), ([1, 4], 5), ([1, 4], 7), ([2, 8], 6), ([2, 8], 2), + ([3, 12], 1), ([3, 12], 3), (5, 0), (5, [4, 8]), + ([6, 9], 1), ([6, 9], 3), ([7, 13], 2), ([7, 13], 6), + (10, 4), (10, [0, 8]), ([11, 14], 5), ([11, 14], 7), (15, 8), (15, [0, 4])] + + values = [[1.0], [1.0, -1.0], [1.0, -1.0], [1.0, -1.0], [1.0, -1.0], + [1.0, -1.0], [1.0, 1.0], [1.0, -1.0], [-1.0, -1.0], + [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], + [1.0, -1.0], [1.0, -1.0], [1.0, 1.0], [1.0, 1.0]] + + for idx, val in zip(indices, values): + if isinstance(idx[1], list): + for i in idx[1]: + factors[idx[0], i] = val[i % len(val)] + else: + factors[idx[0], idx[1]] = val[0] + + return factors + + +def largest_evec(m): + _, eigvecs = np.linalg.eigh(m.asnumpy()) + + return Tensor(eigvecs[..., -1]) + + +MATRIX_SVD_QUAT_FACTORS = make_matrix_svd_factors() + + +@struct_of_array.StructOfArray(same_dtype=True) +class Rot3Array: + """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays.""" + + xx: Tensor = dataclasses.field(metadata={'dtype': ms.float32}) + xy: Tensor + xz: Tensor + yx: Tensor + yy: Tensor + yz: Tensor + zx: Tensor + zy: Tensor + zz: Tensor + + __array_ufunc__ = None + + def inverse(self): + """Returns inverse of Rot3Array.""" + return Rot3Array( + *(self.xx, self.yx, self.zx), + *(self.xy, self.yy, self.zy), + *(self.xz, self.yz, self.zz), + ) + + def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array: + """Applies Rot3Array to point.""" + x = self.xx * point.x + self.xy * point.y + self.xz * point.z + y = self.yx * point.x + self.yy * point.y + self.yz * point.z + z = self.zx * point.x + 
+
+
+@struct_of_array.StructOfArray(same_dtype=True)
+class Rot3Array:
+    """Rot3Array Matrix in 3 dimensional Space implemented as struct of arrays."""
+
+    xx: Tensor = dataclasses.field(metadata={'dtype': ms.float32})
+    xy: Tensor
+    xz: Tensor
+    yx: Tensor
+    yy: Tensor
+    yz: Tensor
+    zx: Tensor
+    zy: Tensor
+    zz: Tensor
+
+    __array_ufunc__ = None
+
+    def inverse(self):
+        """Returns inverse of Rot3Array."""
+        return Rot3Array(
+            *(self.xx, self.yx, self.zx),
+            *(self.xy, self.yy, self.zy),
+            *(self.xz, self.yz, self.zz),
+        )
+
+    def apply_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+        """Applies Rot3Array to point."""
+        x = self.xx * point.x + self.xy * point.y + self.xz * point.z
+        y = self.yx * point.x + self.yy * point.y + self.yz * point.z
+        z = self.zx * point.x + self.zy * point.y + self.zz * point.z
+        return vector.Vec3Array(x, y, z)
+
+    def apply_inverse_to_point(self, point: vector.Vec3Array) -> vector.Vec3Array:
+        """Applies inverse Rot3Array to point."""
+        return self.inverse().apply_to_point(point)
+
+    def __matmul__(self, other):
+        """Composes two Rot3Arrays."""
+        c0 = self.apply_to_point(
+            vector.Vec3Array(other.xx, other.yx, other.zx))
+        c1 = self.apply_to_point(
+            vector.Vec3Array(other.xy, other.yy, other.zy))
+        c2 = self.apply_to_point(
+            vector.Vec3Array(other.xz, other.yz, other.zz))
+        return Rot3Array(c0.x, c1.x, c2.x, c0.y, c1.y, c2.y, c0.z, c1.z, c2.z)
+
+    @classmethod
+    def identity(cls, shape: Any, dtype: ms.dtype = ms.float32):
+        """Returns identity of given shape."""
+        ones = mint.ones(shape, dtype=dtype)
+        zeros = mint.zeros(shape, dtype=dtype)
+        return cls(ones, zeros, zeros, zeros, ones, zeros, zeros, zeros, ones)
+
+    @classmethod
+    def from_two_vectors(cls, e0: vector.Vec3Array, e1: vector.Vec3Array):
+        """Construct Rot3Array from two Vectors.
+
+        Rot3Array is constructed such that in the corresponding frame 'e0'
+        lies on the positive x-axis and 'e1' lies in the xy plane with
+        positive sign of y.
+
+        Args:
+            e0: Vector
+            e1: Vector
+
+        Returns:
+            Rot3Array
+        """
+        # Normalize the unit vector for the x-axis, e0.
+        e0 = e0.normalized()
+        # Make e1 perpendicular to e0.
+        c = e1.dot(e0)
+        e1 = (e1 - e0 * c).normalized()
+        # Compute e2 as cross product of e0 and e1.
+        e2 = e0.cross(e1)
+        return cls(e0.x, e1.x, e2.x, e0.y, e1.y, e2.y, e0.z, e1.z, e2.z)
+
+    @classmethod
+    def from_array(cls, array: Tensor):
+        """Construct Rot3Array Matrix from array of shape [..., 3, 3]."""
+        unstacked = utils.unstack(array, axis=-2)
+        unstacked = sum([utils.unstack(x, axis=-1) for x in unstacked], [])
+        return cls(*unstacked)
+
+    def to_array(self) -> Tensor:
+        """Convert Rot3Array to array of shape [..., 3, 3]."""
+        return ops.stack(
+            [
+                ops.stack([self.xx, self.xy, self.xz], axis=-1),
+                ops.stack([self.yx, self.yy, self.yz], axis=-1),
+                ops.stack([self.zx, self.zy, self.zz], axis=-1),
+            ],
+            axis=-2,
+        )
+
+    @classmethod
+    def from_quaternion(
+        cls,
+        w: Tensor,
+        x: Tensor,
+        y: Tensor,
+        z: Tensor,
+        normalize: bool = True,
+        epsilon: float = 1e-6,
+    ):
+        """Construct Rot3Array from components of quaternion."""
+        if normalize:
+            inv_norm = ops.rsqrt(ops.maximum(
+                epsilon, w**2 + x**2 + y**2 + z**2))
+            w *= inv_norm
+            x *= inv_norm
+            y *= inv_norm
+            z *= inv_norm
+        xx = 1 - 2 * (y**2 + z**2)
+        xy = 2 * (x * y - w * z)
+        xz = 2 * (x * z + w * y)
+        yx = 2 * (x * y + w * z)
+        yy = 1 - 2 * (x**2 + z**2)
+        yz = 2 * (y * z - w * x)
+        zx = 2 * (x * z - w * y)
+        zy = 2 * (y * z + w * x)
+        zz = 1 - 2 * (x**2 + y**2)
+        return cls(xx, xy, xz, yx, yy, yz, zx, zy, zz)
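+
+    # Illustrative check (a sketch): the quaternion for a 90-degree rotation
+    # about z is (cos 45deg, 0, 0, sin 45deg) and should map x-hat to y-hat.
+    #
+    #   s = Tensor(2 ** -0.5, ms.float32)
+    #   zero = Tensor(0.0, ms.float32)
+    #   rot = Rot3Array.from_quaternion(s, zero, zero, s)
+    #   v = vector.Vec3Array(Tensor(1.0), Tensor(0.0), Tensor(0.0))
+    #   rot.apply_to_point(v)   # approximately (0, 1, 0)
+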
+    @classmethod
+    def from_svd(cls, mat: Tensor, use_quat_formula: bool = True):
+        """Constructs Rot3Array from arbitrary array of shape [3 * 3] using SVD.
+
+        The case when 'use_quat_formula' is True rephrases the problem of
+        projecting the matrix to a rotation matrix as a problem of finding the
+        largest eigenvector of a certain 4x4 matrix. This has the advantage of
+        having fewer numerical issues.
+        This approach follows:
+        https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.65.971&rep=rep1&type=pdf
+        In the other case we construct it via SVD following
+        https://arxiv.org/pdf/2006.14616.pdf
+        In that case [dL/dM] is large if the two smallest singular values are
+        close to each other, or if they are close to 0.
+
+        Args:
+            mat: Array of shape [..., 3 * 3]
+            use_quat_formula: Whether to construct the matrix via the 4x4
+                eigenvalue problem.
+
+        Returns:
+            Rot3Array of shape [...]
+        """
+        assert mat.shape[-1] == 9
+        if use_quat_formula:
+            # Assemble the symmetric 4x4 matrix and take its largest
+            # eigenvector as the quaternion; the eigendecomposition runs
+            # host-side via NumPy (see largest_evec).
+            symmetric_4by4 = ops.einsum(
+                'ji, ...i -> ...j',
+                MATRIX_SVD_QUAT_FACTORS,
+                mat,
+            )
+            symmetric_4by4 = ops.reshape(
+                symmetric_4by4, mat.shape[:-1] + (4, 4))
+            largest_eigvec = largest_evec(symmetric_4by4)
+            return cls.from_quaternion(
+                *utils.unstack(largest_eigvec, axis=-1)
+            ).inverse()
+
+        else:
+            mat = ops.reshape(mat, mat.shape[:-1] + (3, 3))
+            # SVD also runs host-side via NumPy, so this branch is not
+            # differentiable end to end.
+            u, _, v_t = np.linalg.svd(mat.asnumpy(), full_matrices=False)
+            u = Tensor(u)
+            v_t = Tensor(v_t)
+            det_uv_t = ops.det(ops.matmul(u, v_t))
+            ones = ops.ones_like(det_uv_t)
+            diag_array = ops.stack([ones, ones, det_uv_t], axis=-1)
+            # This is equivalent to making diag_array into a diagonal matrix
+            # and matrix multiplying.
+            diag_times_v_t = diag_array[..., None] * v_t
+            out = ops.matmul(u, diag_times_v_t)
+            return cls.from_array(out)
+
+    @classmethod
+    def random_uniform(cls, key, shape, dtype=ms.float32):
+        """Samples uniform random Rot3Array according to Haar Measure."""
+        # Normalised i.i.d. Gaussian quaternions are uniform on SO(3);
+        # 'key' seeds MindSpore's StandardNormal primitive.
+        stdnormal = ops.StandardNormal(seed=key)
+        quat_array = stdnormal(tuple(shape) + (4,)).astype(dtype)
+        quats = utils.unstack(quat_array)
+        return cls.from_quaternion(*quats)
+
+    def __getstate__(self):
+        return (VERSION, [getattr(self, field) for field in COMPONENTS])
+
+    def __setstate__(self, state):
+        version, state = state
+        del version
+        for i, field in enumerate(COMPONENTS):
+            object.__setattr__(self, field, state[i])
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py
new file mode 100644
index 000000000..604b657b0
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/struct_of_array.py
@@ -0,0 +1,291 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Class decorator to represent (nested) struct of arrays."""
+
+import dataclasses
+
+import mindspore as ms
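+
+
+# Minimal usage sketch (illustrative; the class below is hypothetical):
+#
+#   @StructOfArray(same_dtype=True)
+#   class Pair:
+#       a: ms.Tensor
+#       b: ms.Tensor
+#
+#   p = Pair(ms.Tensor([1.0, 2.0]), ms.Tensor([3.0, 4.0]))
+#   p.shape    # (2,)
+#   p[0]       # Pair of scalars, sliced field-by-field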
+
+
+def get_item(instance, key):
+    sliced = {}
+    for field in get_array_fields(instance):
+        num_trailing_dims = field.metadata.get('num_trailing_dims', 0)
+        this_key = key
+        if isinstance(key, tuple) and Ellipsis in this_key:
+            # Pad an Ellipsis key so it does not consume the trailing dims.
+            this_key += (slice(None),) * num_trailing_dims
+
+        def apply_slice(x):
+            if isinstance(x, ms.Tensor):
+                return x[this_key]
+            elif isinstance(x, dict):
+                return {k: apply_slice(v) for k, v in x.items()}
+            elif isinstance(x, list):
+                return [apply_slice(item) for item in x]
+            else:
+                return x
+
+        sliced[field.name] = apply_slice(getattr(instance, field.name))
+
+    return dataclasses.replace(instance, **sliced)
+
+
+@property
+def get_shape(instance):
+    """Returns Shape for given instance of dataclass."""
+    first_field = dataclasses.fields(instance)[0]
+    num_trailing_dims = first_field.metadata.get('num_trailing_dims', None)
+    value = getattr(instance, first_field.name)
+    if num_trailing_dims:
+        return value.shape[:-num_trailing_dims]
+    else:
+        return value.shape
+
+
+def get_len(instance):
+    """Returns length for given instance of dataclass."""
+    shape = instance.shape
+    if shape:
+        return shape[0]
+    else:
+        # Match numpy behavior.
+        raise TypeError('len() of unsized object')
+
+
+@property
+def get_dtype(instance):
+    """Returns Dtype for given instance of dataclass."""
+    fields = dataclasses.fields(instance)
+    sets_dtype = [
+        field.name for field in fields if field.metadata.get('sets_dtype', False)
+    ]
+    if sets_dtype:
+        assert len(sets_dtype) == 1, 'at most one field can set dtype'
+        field_value = getattr(instance, sets_dtype[0])
+    elif instance.same_dtype:
+        field_value = getattr(instance, fields[0].name)
+    else:
+        raise AttributeError(
+            'Trying to access Dtype on Struct of Array without '
+            'either "same_dtype" or a field setting dtype'
+        )
+
+    if hasattr(field_value, 'dtype'):
+        return field_value.dtype
+    else:
+        raise AttributeError(f'field_value {field_value} does not have dtype')
+
+
+def replace(instance, **kwargs):
+    return dataclasses.replace(instance, **kwargs)
+
+
+def post_init(instance):
+    """Validate that instance has consistent shapes and dtypes."""
+    array_fields = get_array_fields(instance)
+    arrays = list(get_array_fields(instance, return_values=True).values())
+    first_field = array_fields[0]
+    try:
+        dtype = instance.dtype
+    except AttributeError:
+        dtype = None
+    if dtype is not None:
+        first_shape = instance.shape
+        for array, field in zip(arrays, array_fields, strict=True):
+            num_trailing_dims = field.metadata.get('num_trailing_dims', None)
+            if num_trailing_dims:
+                array_shape = array.shape
+                field_shape = array_shape[:-num_trailing_dims]
+                msg = (
+                    f'field {field} should have number of trailing dims'
+                    f' {num_trailing_dims}'
+                )
+                assert len(array_shape) == len(
+                    first_shape) + num_trailing_dims, msg
+            else:
+                field_shape = array.shape
+
+            shape_msg = (
+                f"Stripped Shape {field_shape} of field {field} doesn't "
+                f'match shape {first_shape} of field {first_field}'
+            )
+            assert field_shape == first_shape, shape_msg
+
+            field_dtype = array.dtype
+
+            allowed_metadata_dtypes = field.metadata.get('allowed_dtypes', [])
+            if allowed_metadata_dtypes:
+                msg = f'Dtype is {field_dtype} but must be in {allowed_metadata_dtypes}'
+                assert field_dtype in allowed_metadata_dtypes, msg
+
+            if 'dtype' in field.metadata:
+                target_dtype = field.metadata['dtype']
+            else:
+                target_dtype = dtype
+
+            msg = f'Dtype is {field_dtype} but must be {target_dtype}'
+            assert field_dtype == target_dtype, msg
+
+
+def flatten(instance):
+    """Flatten Struct Of Array instance."""
+    array_likes = get_array_fields(instance, return_values=True).values()
+    flat_array_likes = []
+    inner_treedefs = []
+    num_arrays = []
+    for array_like in array_likes:
+        flat_array_like, inner_treedef = tree_flatten(array_like)
+        inner_treedefs.append(inner_treedef)
+        flat_array_likes += flat_array_like
+        num_arrays.append(len(flat_array_like))
+    metadata = get_metadata_fields(instance, return_values=True)
+    metadata = type(instance).metadata_cls(**metadata)
+    return flat_array_likes, (inner_treedefs, metadata, num_arrays)
+
+
+def make_metadata_class(cls):
+    metadata_fields = get_fields(
+        cls, lambda x: x.metadata.get('is_metadata', False)
+    )
+    metadata_cls = dataclasses.make_dataclass(
+        cls_name='Meta' + cls.__name__,
+        fields=[(field.name, field.type, field) for field in metadata_fields],
+        frozen=True,
+        eq=True,
+    )
+    return metadata_cls
+
+
+def get_fields(cls_or_instance, filterfn, return_values=False):
+    fields = dataclasses.fields(cls_or_instance)
+    fields = [field for field in fields if filterfn(field)]
+    if return_values:
+        return {
+            field.name: getattr(cls_or_instance, field.name) for field in fields
+        }
+    else:
+        return fields
+
+
+def get_array_fields(cls, return_values=False):
+    return get_fields(
+        cls,
+        lambda x: not x.metadata.get('is_metadata', False),
+        return_values=return_values,
+    )
+
+
+def get_metadata_fields(cls, return_values=False):
+    return get_fields(
+        cls,
+        lambda x: x.metadata.get('is_metadata', False),
+        return_values=return_values,
+    )
+
+
+def tree_flatten(pytree):
+    """Flattens nested dicts/lists of tensors into (leaves, treedef)."""
+    if isinstance(pytree, dict):
+        if not pytree:
+            return [], {'keys': (), 'treedefs': ()}
+        keys, values = zip(*pytree.items())
+        flat_values, treedefs = zip(*(tree_flatten(v) for v in values))
+        return sum(flat_values, []), {'keys': keys, 'treedefs': treedefs}
+    elif isinstance(pytree, list):
+        if not pytree:
+            return [], {'treedefs': ()}
+        flat_items, treedefs = zip(*(tree_flatten(item) for item in pytree))
+        return sum(flat_items, []), {'treedefs': treedefs}
+    else:
+        # Anything that is not a recognised container is a leaf (tensors
+        # included), so the round trip preserves it.
+        return [pytree], None
+
+
+def _num_leaves(treedef):
+    """Returns the number of leaves described by a treedef."""
+    if treedef is None:
+        return 1
+    return sum(_num_leaves(td) for td in treedef['treedefs'])
+
+
+def tree_unflatten(treedef, leaves):
+    """Inverse of tree_flatten: rebuilds the nested structure from leaves."""
+    if treedef is None:
+        return leaves[0]
+    items = []
+    start = 0
+    for td in treedef['treedefs']:
+        size = _num_leaves(td)
+        items.append(tree_unflatten(td, leaves[start:start + size]))
+        start += size
+    if 'keys' in treedef:
+        return dict(zip(treedef['keys'], items))
+    return items
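+
+
+# Round-trip sketch (illustrative): flattening and unflattening reproduces the
+# original nesting, which is what flatten()/unflatten() below rely on.
+#
+#   tree = {'a': ms.Tensor([1.0]), 'b': [ms.Tensor([2.0]), ms.Tensor([3.0])]}
+#   leaves, treedef = tree_flatten(tree)   # three leaves
+#   rebuilt = tree_unflatten(treedef, leaves)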
+
+
+class StructOfArray:
+    """Class Decorator for Struct Of Arrays."""
+
+    def __init__(self, same_dtype=True):
+        self.same_dtype = same_dtype
+
+    def __call__(self, cls):
+        cls.__array_ufunc__ = None
+        cls.replace = replace
+        cls.same_dtype = self.same_dtype
+        cls.dtype = get_dtype
+        cls.shape = get_shape
+        cls.__len__ = get_len
+        cls.__getitem__ = get_item
+        cls.__post_init__ = post_init
+        new_cls = dataclasses.dataclass(cls, frozen=True, eq=False)
+        # Pytree metadata must be hashable, so it is held in a derived frozen
+        # dataclass.
+        new_cls.metadata_cls = make_metadata_class(new_cls)
+
+        def unflatten(cls, params):
+            aux, data = params
+            inner_treedefs, metadata, num_arrays = aux
+            array_fields = [field.name for field in get_array_fields(new_cls)]
+            value_dict = {}
+            array_start = 0
+            for num_array, inner_treedef, array_field in zip(
+                num_arrays, inner_treedefs, array_fields, strict=True
+            ):
+                value_dict[array_field] = tree_unflatten(
+                    inner_treedef, data[array_start: array_start + num_array]
+                )
+                array_start += num_array
+            metadata_fields = get_metadata_fields(new_cls)
+            for field in metadata_fields:
+                value_dict[field.name] = getattr(metadata, field.name)
+
+            return new_cls(**value_dict)
+
+        # Attach __flatten__ and __unflatten__ for pytree-style handling.
+        new_cls.__flatten__ = flatten
+        new_cls.__unflatten__ = unflatten
+
+        return new_cls
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/utils.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/utils.py
new file mode 100644
index 000000000..3bf99c25b
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/utils.py
@@ -0,0 +1,149 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Utils for geometry library."""
+
+from collections.abc import Iterable
+import numbers
+
+import mindspore as ms
+import mindspore.ops as ops
+import mindspore.numpy as mnp
+
+
+def safe_select(condition, true_fn, false_fn):
+    """Safe version of selection (i.e. `where`).
+
+    This applies the double-where trick. Like mnp.where, this function still
+    executes both branches. Other than NaN semantics,
+    safe_select(condition, true_fn, false_fn) is equivalent to
+
+        ops.select(condition, true_fn(), false_fn())
+
+    Compared to the naive implementation above, safe_select provides the
+    following guarantee: in either the forward or backward pass, a NaN
+    produced *during the execution of true_fn()* will not propagate to the
+    rest of the computation, and similarly for false_fn. It is very important
+    to note that while true_fn and false_fn will typically close over other
+    tensors (i.e. they use values computed prior to the safe_select call),
+    there is no NaN-safety for the backward pass of closed-over values. It is
+    important that any NaNs are produced within the branch functions and not
+    before them. For example,
+
+        safe_select(x < eps, lambda: ops.zeros_like(x), lambda: ops.sqrt(x))
+
+    will not produce NaN on the backward pass even if x == 0, since the sqrt
+    happens within the false_fn, but the very similar
+
+        y = ops.sqrt(x)
+        safe_select(x < eps, lambda: ops.zeros_like(x), lambda: y)
+
+    will produce a NaN on the backward pass if x == 0 because the sqrt happens
+    prior to the false_fn.
+
+    Args:
+        condition: Boolean array to use in the selection.
+        true_fn: Zero-argument function constructing the values used in the
+            True condition; NaNs produced inside it are suppressed by the
+            double-where trick.
+        false_fn: False branch equivalent of true_fn.
+
+    Returns:
+        Result equivalent to the naive selection above, with NaN-safe
+        gradients for both branches.
+    """
+    true_result = true_fn()
+    false_result = false_fn()
+
+    # Apply the double-where trick: in each branch, replace the not-taken
+    # values by a stop-gradient copy so their NaNs cannot leak into gradients.
+    true_part = ops.select(condition, true_result,
+                           ops.stop_gradient(true_result))
+    false_part = ops.select(
+        condition, ops.stop_gradient(false_result), false_result)
+
+    return ops.select(condition, true_part, false_part)
+
+
+def unstack(value: ms.Tensor, axis: int = -1) -> list[ms.Tensor]:
+    """Splits a tensor of any rank into a list of slices along `axis`."""
+    if axis < 0:
+        axis += value.ndim
+    index = [slice(None)] * value.ndim
+    split_tensors = []
+    for i in range(value.shape[axis]):
+        index[axis] = i
+        split_tensors.append(value[tuple(index)])
+    return split_tensors
+
+
+def angdiff(alpha: ms.Tensor, beta: ms.Tensor) -> ms.Tensor:
+    """Compute absolute difference between two angles."""
+    d = alpha - beta
+    # Wrap into [-pi, pi).
+    d = (d + mnp.pi) % (2 * mnp.pi) - mnp.pi
+    return d
+
+
+def safe_arctan2(
+    x1: ms.Tensor, x2: ms.Tensor, eps: float = 1e-8
+) -> ms.Tensor:
+    """Safe version of arctan2 that avoids NaN gradients when x1 = x2 = 0."""
+    return safe_select(
+        ops.abs(x1) + ops.abs(x2) < eps,
+        lambda: ops.zeros_like(ops.atan2(x1, x2)),
+        lambda: ops.atan2(x1, x2),
+    )
+
+
+def weighted_mean(
+    *,
+    weights: ms.Tensor,
+    value: ms.Tensor,
+    axis: int | Iterable[int] | None = None,
+    eps: float = 1e-10,
+) -> ms.Tensor:
+    """Computes weighted mean in a safe way that avoids NaNs.
+
+    This is equivalent to np.average for the case eps=0.0, but adds a small
+    constant to the denominator of the weighted average to avoid NaNs.
+    'weights' should be broadcastable to the shape of 'value'.
+
+    Args:
+        weights: Weights to weight value by; a Python scalar is also accepted.
+        value: Values to average.
+        axis: Axes to average over.
+        eps: Epsilon to add to the denominator.
+
+    Returns:
+        Weighted average.
+    """
+    # Accept scalar weights (e.g. the default 1.0 used by callers).
+    if not isinstance(weights, ms.Tensor):
+        weights = ms.Tensor(weights, dtype=value.dtype)
+    weights = ops.cast(weights, value.dtype)
+    weights = ops.broadcast_to(weights, value.shape)
+
+    weights_shape = weights.shape
+
+    if isinstance(axis, numbers.Integral):
+        axis = [axis]
+    elif axis is None:
+        axis = list(range(len(weights_shape)))
+
+    numerator = ops.reduce_sum(weights * value, axis=tuple(axis))
+    denominator = ops.reduce_sum(weights, axis=tuple(axis)) + eps
+
+    return numerator / denominator
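+
+
+# Usage sketch (illustrative): masked averaging, where eps keeps an all-zero
+# mask from producing NaN.
+#
+#   value = ms.Tensor([1.0, 2.0, 3.0])
+#   mask = ms.Tensor([1.0, 1.0, 0.0])
+#   weighted_mean(weights=mask, value=value, axis=0)   # ~1.5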
diff --git a/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py
new file mode 100644
index 000000000..0e6c672f2
--- /dev/null
+++ b/MindSPONGE/applications/research/AlphaFold3/src/alphafold3/utils/geometry/vector.py
@@ -0,0 +1,258 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Copyright 2024 DeepMind Technologies Limited
+#
+# AlphaFold 3 source code is licensed under CC BY-NC-SA 4.0. To view a copy of
+# this license, visit https://creativecommons.org/licenses/by-nc-sa/4.0/
+#
+# To request access to the AlphaFold 3 model parameters, follow the process set
+# out at https://github.com/google-deepmind/alphafold3. You may only use these
+# if received directly from Google. Use is subject to terms of use available at
+# https://github.com/google-deepmind/alphafold3/blob/main/WEIGHTS_TERMS_OF_USE.md
+
+"""Vec3Array Class."""
+
+import dataclasses
+from typing import Final, TypeAlias, TypeVar
+
+import mindspore as ms
+from mindspore import ops, mint
+
+from alphafold3.utils.geometry import struct_of_array
+
+Self = TypeVar('Self', bound='Vec3Array')
+Float: TypeAlias = float | ms.Tensor
+VERSION: Final[str] = '0.1'
+
+
+def tree_map(func, *trees):
+    """Recursively applies a function to each leaf of the input trees.
+
+    Args:
+        func: A function to apply to each leaf.
+        *trees: One or more tree structures (nested lists/tuples/dicts or
+            Vec3Array instances) with identical layout.
+
+    Returns:
+        A new tree with the same structure where `func` has been applied to
+        each leaf.
+    """
+    if isinstance(trees[0], Vec3Array):
+        return Vec3Array(
+            x=tree_map(func, *(t.x for t in trees)),
+            y=tree_map(func, *(t.y for t in trees)),
+            z=tree_map(func, *(t.z for t in trees))
+        )
+    elif isinstance(trees[0], dict):
+        return {key: tree_map(func, *(t[key] for t in trees)) for key in trees[0]}
+    elif isinstance(trees[0], (list, tuple)):
+        return type(trees[0])(tree_map(func, *args) for args in zip(*trees))
+    else:
+        return func(*trees)
+
+
+@struct_of_array.StructOfArray(same_dtype=True)
+class Vec3Array:
+    """Vec3Array in 3 dimensional Space implemented as struct of arrays.
+
+    This is done in order to improve performance and precision.
+    """
+
+    x: ms.Tensor = dataclasses.field(metadata={'dtype': ms.float32})
+    y: ms.Tensor
+    z: ms.Tensor
+
+    def __post_init__(self):
+        if hasattr(self.x, 'dtype'):
+            if not self.x.dtype == self.y.dtype == self.z.dtype:
+                raise ValueError(
+                    f'Type mismatch: {self.x.dtype}, {self.y.dtype}, {self.z.dtype}'
+                )
+            if not self.x.shape == self.y.shape == self.z.shape:
+                raise ValueError(
+                    f'Shape mismatch: {self.x.shape}, {self.y.shape}, {self.z.shape}'
+                )
+
+    @property
+    def shape(self):
+        """Return the shape of the Vec3Array."""
+        return self.x.shape
+
+    def __add__(self, other: Self) -> Self:
+        return tree_map(ops.add, self, other)
+
+    def __sub__(self, other: Self) -> Self:
+        return tree_map(ops.sub, self, other)
+
+    def __mul__(self, other: Float) -> Self:
+        if isinstance(other, float):
+            return tree_map(lambda x: ops.mul(x, other), self)
+        x = ops.mul(self.x, other)
+        y = ops.mul(self.y, other)
+        z = ops.mul(self.z, other)
+        return Vec3Array(x, y, z)
+
+    def __rmul__(self, other: Float) -> Self:
+        return self * other
+
+    def __truediv__(self, other: Float) -> Self:
+        return tree_map(lambda x: ops.div(x, other), self)
+
+    def __neg__(self) -> Self:
+        return tree_map(lambda x: -x, self)
+
+    def __pos__(self) -> Self:
+        return tree_map(lambda x: x, self)
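+
+    # Arithmetic sketch (illustrative): operators act component-wise, so
+    # vectors combine without stacking into a single (..., 3) tensor.
+    #
+    #   a = Vec3Array(ms.Tensor([1.0]), ms.Tensor([0.0]), ms.Tensor([0.0]))
+    #   b = Vec3Array(ms.Tensor([0.0]), ms.Tensor([1.0]), ms.Tensor([0.0]))
+    #   (a + b).to_array()       # [[1., 1., 0.]]
+    #   a.cross(b).to_array()    # [[0., 0., 1.]] (see cross below)
+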
+    def cross(self, other: Self) -> Self:
+        """Compute cross product between 'self' and 'other'."""
+        new_x = ops.sub(ops.mul(self.y, other.z), ops.mul(self.z, other.y))
+        new_y = ops.sub(ops.mul(self.z, other.x), ops.mul(self.x, other.z))
+        new_z = ops.sub(ops.mul(self.x, other.y), ops.mul(self.y, other.x))
+        return Vec3Array(new_x, new_y, new_z)
+
+    def dot(self, other: Self) -> ms.Tensor:
+        """Compute dot product between 'self' and 'other'."""
+        return ops.add(
+            ops.add(ops.mul(self.x, other.x), ops.mul(self.y, other.y)),
+            ops.mul(self.z, other.z),
+        )
+
+    def norm(self, epsilon: float = 1e-6) -> ms.Tensor:
+        """Compute Norm of Vec3Array, clipped to epsilon."""
+        # To avoid NaN on the backward pass, we must use maximum before the
+        # sqrt.
+        norm2 = self.dot(self)
+        if epsilon:
+            norm2 = ops.maximum(norm2, epsilon**2)
+        return ops.sqrt(norm2)
+
+    def norm2(self) -> ms.Tensor:
+        return self.dot(self)
+
+    def normalized(self, epsilon: float = 1e-6) -> Self:
+        """Return unit vector with optional clipping."""
+        return self / self.norm(epsilon)
+
+    @classmethod
+    def zeros(cls, shape, dtype=ms.float32):
+        """Return Vec3Array corresponding to zeros of given shape."""
+        return cls(
+            mint.zeros(shape, dtype=dtype),
+            mint.zeros(shape, dtype=dtype),
+            mint.zeros(shape, dtype=dtype),
+        )
+
+    def to_array(self) -> ms.Tensor:
+        return ops.stack([self.x, self.y, self.z], axis=-1)
+
+    @classmethod
+    def from_array(cls, array):
+        unstacked = ops.unstack(array, axis=-1)
+        return cls(unstacked[0], unstacked[1], unstacked[2])
+
+    def __getstate__(self):
+        return (
+            VERSION,
+            [self.x.asnumpy(), self.y.asnumpy(), self.z.asnumpy()],
+        )
+
+    def __setstate__(self, state):
+        version, state = state
+        del version
+        for i, letter in enumerate('xyz'):
+            object.__setattr__(self, letter, ms.Tensor(state[i]))
+
+
+def square_euclidean_distance(
+    vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6
+) -> Float:
+    """Computes square of euclidean distance between 'vec1' and 'vec2'.
+
+    Args:
+        vec1: Vec3Array to compute distance to.
+        vec2: Vec3Array to compute distance from; should be broadcast
+            compatible with 'vec1'.
+        epsilon: Distance is clipped from below to be at least epsilon.
+
+    Returns:
+        Array of square euclidean distances; shape will be the result of
+        broadcasting 'vec1' and 'vec2'.
+    """
+    difference = vec1 - vec2
+    distance = difference.dot(difference)
+    if epsilon:
+        distance = ops.maximum(distance, epsilon)
+    return distance
+
+
+def dot(vector1: Vec3Array, vector2: Vec3Array) -> Float:
+    return vector1.dot(vector2)
+
+
+def cross(vector1: Vec3Array, vector2: Vec3Array) -> Vec3Array:
+    return vector1.cross(vector2)
+
+
+def norm(vector: Vec3Array, epsilon: float = 1e-6) -> Float:
+    return vector.norm(epsilon)
+
+
+def normalized(vector: Vec3Array, epsilon: float = 1e-6) -> Vec3Array:
+    return vector.normalized(epsilon)
+
+
+def euclidean_distance(
+    vec1: Vec3Array, vec2: Vec3Array, epsilon: float = 1e-6
+) -> Float:
+    """Computes euclidean distance between 'vec1' and 'vec2'.
+
+    Args:
+        vec1: Vec3Array to compute euclidean distance to.
+        vec2: Vec3Array to compute euclidean distance from; should be
+            broadcast compatible with 'vec1'.
+        epsilon: Distance is clipped from below to be at least epsilon.
+
+    Returns:
+        Array of euclidean distances; shape will be the result of broadcasting
+        'vec1' and 'vec2'.
+    """
+    distance_sq = square_euclidean_distance(vec1, vec2, epsilon**2)
+    distance = ops.sqrt(distance_sq)
+    return distance
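+
+
+# Usage sketch (illustrative): the epsilon floor keeps gradients finite when
+# two points coincide.
+#
+#   p = Vec3Array(ms.Tensor([0.0]), ms.Tensor([0.0]), ms.Tensor([0.0]))
+#   q = Vec3Array(ms.Tensor([3.0]), ms.Tensor([4.0]), ms.Tensor([0.0]))
+#   euclidean_distance(p, q)   # ~5.0
+#   euclidean_distance(p, p)   # ~epsilon rather than 0, so sqrt' stays finite
+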
+
+def dihedral_angle(
+    a: Vec3Array, b: Vec3Array, c: Vec3Array, d: Vec3Array
+) -> Float:
+    """Computes torsion angle for a quadruple of points.
+
+    For points (a, b, c, d), this is the angle between the planes defined by
+    points (a, b, c) and (b, c, d). It is also known as the dihedral angle.
+
+    Args:
+        a: A Vec3Array of coordinates.
+        b: A Vec3Array of coordinates.
+        c: A Vec3Array of coordinates.
+        d: A Vec3Array of coordinates.
+
+    Returns:
+        A tensor of angles in radians: [-pi, pi].
+    """
+    v1 = a - b
+    v2 = b - c
+    v3 = d - c
+
+    c1 = v1.cross(v2)
+    c2 = v3.cross(v2)
+    c3 = c2.cross(c1)
+
+    v2_mag = v2.norm()
+    return ops.atan2(c3.dot(v2), v2_mag * c1.dot(c2))
+
+
+def random_gaussian_vector(shape, key: int = 0, dtype=ms.float32) -> Vec3Array:
+    """Returns a Vec3Array with i.i.d. standard-normal components."""
+    # StandardNormal requires an integer seed, so default to 0 rather than
+    # None.
+    stdnormal = ops.StandardNormal(seed=key)
+    vec_array = stdnormal(tuple(shape) + (3,)).astype(dtype)
+    return Vec3Array.from_array(vec_array)
--
Gitee