From 543b0bd98c1ae9ef87a60e3f66c0b81355a6d62b Mon Sep 17 00:00:00 2001
From: yzhang
Date: Wed, 27 Nov 2024 08:39:23 +0000
Subject: [PATCH 1/5] Create progen
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 MindSPONGE/src/mindsponge/pipeline/models/progen/.keep | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/.keep

diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/.keep b/MindSPONGE/src/mindsponge/pipeline/models/progen/.keep
new file mode 100644
index 000000000..e69de29bb
--
Gitee

From 0b022d03a3ed88cbcc213cdfeef049a1bc67d42b Mon Sep 17 00:00:00 2001
From: yzhang
Date: Wed, 27 Nov 2024 16:47:47 +0800
Subject: [PATCH 2/5] add equidock and progen

---
 MindSPONGE/requirements.txt                    |    6 +-
 .../pipeline/models/equidock/__init__.py       |   19 +
 .../pipeline/models/equidock/equidock.py       |  383 +++
 .../models/equidock/equidock_configuration.py  |   24 +
 .../pipeline/models/equidock/equidock_data.py  |  286 ++
 .../models/equidock/equidock_dataset.py        |  229 ++
 .../pipeline/models/equidock/nn_arch.py        | 1326 ++++++++++
 .../pipeline/models/equidock/train_utils.py    |  236 ++
 .../pipeline/models/progen/__init__.py         |   19 +
 .../pipeline/models/progen/module/__init__.py  |   15 +
 .../progen/module/configuration_utils.py       | 2295 +++++++++++++++++
 .../models/progen/module/injection.py          |  979 +++++++
 .../models/progen/module/logits_process.py     | 1037 ++++++++
 .../pipeline/models/progen/nn_arch.py          |  718 ++++++
 .../pipeline/models/progen/progen.py           |  282 ++
 .../models/progen/progen_configuration.py      |   20 +
 .../pipeline/models/progen/progen_dataset.py   |   46 +
 .../pipeline/models/progen/tokenizer.json      |   91 +
 .../src/mindsponge/pipeline/pipeline.py        |    5 +-
 19 files changed, 8014 insertions(+), 2 deletions(-)
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/module/__init__.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
 create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json

diff --git a/MindSPONGE/requirements.txt b/MindSPONGE/requirements.txt
index 9b940d1b5..174bca955 100644
--- a/MindSPONGE/requirements.txt
+++ b/MindSPONGE/requirements.txt
@@ -1,5 +1,6 @@
 numpy >= 1.17.0,<=1.23.4
 scipy >= 1.7.0
+biopandas == 0.4.1
biopython == 1.81 pyyaml >= 5.4.1 dataclasses >= 0.6 @@ -9,8 +10,11 @@ absl-py >= 1.1.0 biotite == 0.40.0 descriptastorus == 2.6.1 pyparsing >= 3.0.7 +POT == 0.9.3 +tokenizers +joblib rdkit bio scikit-learn mindformers >= 1.2.0 -sentencepiece >= 0.2.0 \ No newline at end of file +sentencepiece >= 0.2.0 diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py new file mode 100644 index 000000000..4c604d06e --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""equidock""" + +from .equidock import EquiDock +from .equidock_dataset import EquiDockDataSet +from .equidock_configuration import equidock_configuration diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py new file mode 100644 index 000000000..7f951efb5 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py @@ -0,0 +1,383 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""equidock""" + +import os +import random +from datetime import datetime as dt + +import numpy as np +from biopandas.pdb import PandasPdb +import mindspore as ms +from mindspore import nn, Tensor, ops, save_checkpoint +from mindspore.experimental import optim + +from .train_utils import create_dir, prepare_graphs, graph_to_tensor, \ + pretty_print_stats, get_rot_mat, compute_sq_dist_mat, compute_ot_emd, \ + compute_body_intersection_loss +from .nn_arch import MeterUnboundBound, RigidBodyDockingNet, log, FLAGS +from .equidock_dataset import EquiDockDataSet +from ..model import Model + + +class EquiDock(Model): + """ + EquiDock class + """ + name = "EquiDock" + + def __init__(self, config): + + self.config = config + self.mixed_precision = False + self.network = RigidBodyDockingNet(self.config) + self.log(self.network) + self.checkpoint_url = "Local_Checkpoint_Used" + self.checkpoint_path = self.config.ckpt_dir + + if self.config.is_train: + self.config.cache_path = './cache/' + self.config.data + '_' + \ + self.config.graph_nodes + '_maxneighbor_' + str(self.config.graph_max_neighbor) + \ + '_cutoff_' + str(self.config.graph_cutoff) + '_pocketCut_' + \ + str(self.config.pocket_cutoff) + '/' + self.config.cache_path = os.path.join(self.config.cache_path, 'cv_' + str(self.config.split)) + banner = 'EQUIDOCK' + self.config.checkpoint_dir = './checkpts/' + banner + self.config.log_files_dir = './stdouterr/' + + ### Create log files only when in train mode + create_dir(self.config.log_files_dir) + create_dir(self.config.checkpoint_dir) + + log_file_name = os.path.join(self.config.log_files_dir, banner + ".txt") + with os.fdopen(os.open(log_file_name, FLAGS, 777), 'a+') as fout: + fout.write('[' + str(dt.now()) + '] START\n') + self.config.checkpoint_filename = os.path.join( + self.config.checkpoint_dir, + self.config.data + '_model_best.pth' + ) + + self.loss_fn_coors = nn.MSELoss(reduction='mean') + self.optimizer = optim.Adam( + self.network.trainable_params(), + lr=self.config.lr, + weight_decay=self.config.w_decay + ) + self.scheduler = optim.lr_scheduler.LambdaLR( + self.optimizer, + lr_lambda=[lambda epoch: min(1., ((epoch + 1) / self.config.warmup) ** 3)] + ) + + self.best_epoch = 0 + self.best_val_rmsd_median = float('inf') + self.corr_val_rmsd_mean = float('inf') + + dataset_util = EquiDockDataSet(self.config) + self.train_data_batched, self.train_loader, self.val_data_batched, self.val_loader, \ + self.test_data_batched, self.test_loader = dataset_util.set_training_data_src("Train Mode") + + super().__init__(checkpoint_url=self.checkpoint_url, checkpoint_path=self.checkpoint_path, + network=self.network, mixed_precision=self.mixed_precision) + + + def from_pretrained(self, ckpt_path=None): + "from_pretrained" + self.get_checkpoint_path(ckpt_path) + if not ckpt_path: + param_dict = ms.load_checkpoint(self.checkpoint_path) + else: + param_dict = ms.load_checkpoint(ckpt_path) + param_not_load, _ = ms.load_param_into_net(self.network, param_dict) + print(f'param not load: {param_not_load}') + + + def predict(self, data, **kwargs): + time_list = [] + input_dir = self.config.input_dir + ground_truth_dir = self.config.ground_truth_dir + output_dir = self.config.output_dir + file_type = '.pdb' + l_b_str = '_l_b' + + for file in data: + if not file.endswith(l_b_str + file_type): + continue + ll = len(l_b_str + file_type) + ligand_filename = os.path.join(input_dir, file[:-ll] + l_b_str + file_type) + 
receptor_filename = os.path.join(ground_truth_dir, file[:-ll] + '_r_b' + '_COMPLEX' + file_type) + out_filename = file[:-ll] + l_b_str + '_' + "EQUIDOCK" + file_type + + self.log(' inference on file = ', ligand_filename) + + start = dt.now() + + ppdb_ligand = PandasPdb().read_pdb(ligand_filename) + + ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, _\ + = prepare_graphs(self.config, ppdb_ligand, ligand_filename, receptor_filename) + + if self.config.input_edge_feats_dim < 0: + self.config.input_edge_feats_dim = ligand_graph.edata['he'].shape[1] + + ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, \ + input_tensor_tuple = graph_to_tensor(ligand_graph, receptor_graph) + + _, _, _, all_rotation_list, all_translation_list = self.network( + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list, + input_tensor_tuple, + ) + + rotation = all_rotation_list[0].asnumpy() + translation = all_translation_list[0].asnumpy() + + unbound_ligand_new_pos = (rotation @ unbound_ligand_all_atoms_pre_pos.T).T + translation + + euler_angles_finetune = ops.zeros([3]) + translation_finetune = ops.zeros([3]) + ligand_th = (get_rot_mat(euler_angles_finetune) @ Tensor(unbound_ligand_new_pos).T).T + translation_finetune + + ppdb_ligand.df['ATOM'][['x_coord', 'y_coord', 'z_coord']] = ligand_th.asnumpy() # unbound_ligand_new_pos + unbound_ligand_save_filename = os.path.join(output_dir, out_filename) + ppdb_ligand.to_pdb(path=unbound_ligand_save_filename, records=['ATOM'], gz=False) + + end = dt.now() + time_list.append((end - start).total_seconds()) + + time_array = np.array(time_list) + self.log("Mean runtime:", np.mean(time_array), "std runtime:", np.std(time_array)) + self.log('Time list = ', time_list) + + + def run_a_train_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args): + train_complex_rmsd_mean, train_complex_rmsd_median, \ + train_ligand_rmsd_mean, train_ligand_rmsd_median, \ + train_receptor_rmsd_mean, train_receptor_rmsd_median, \ + train_avg_loss, train_avg_loss_ligand_coors, train_avg_loss_receptor_coors, \ + train_avg_loss_ot, train_avg_loss_intersection = \ + self.run_a_generic_epoch('train', run_epoch_tuple, data_batched, data_loader) + + pretty_print_stats('TRAIN', epoch, args.num_epochs, + train_complex_rmsd_mean, train_complex_rmsd_median, + train_ligand_rmsd_mean, train_ligand_rmsd_median, + train_receptor_rmsd_mean, train_receptor_rmsd_median, + train_avg_loss, train_avg_loss_ligand_coors, train_avg_loss_receptor_coors, + train_avg_loss_ot, train_avg_loss_intersection, + self.log) + + + def run_an_eval_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args): + """ + run_an_eval_epoch + """ + val_complex_rmsd_mean, val_complex_rmsd_median, \ + val_ligand_rmsd_mean, val_ligand_rmsd_median, \ + val_receptor_rmsd_mean, val_receptor_rmsd_median, \ + val_avg_loss, val_avg_loss_ligand_coors, \ + val_avg_loss_receptor_coors, \ + val_avg_loss_ot, val_avg_loss_intersection = \ + self.run_a_generic_epoch('eval', run_epoch_tuple, data_batched, data_loader) + + pretty_print_stats('VALIDATION', epoch, args.num_epochs, + val_complex_rmsd_mean, val_complex_rmsd_median, + val_ligand_rmsd_mean, val_ligand_rmsd_median, + val_receptor_rmsd_mean, val_receptor_rmsd_median, + val_avg_loss, val_avg_loss_ligand_coors, val_avg_loss_receptor_coors, + val_avg_loss_ot, val_avg_loss_intersection, + self.log) + + return val_complex_rmsd_mean, val_complex_rmsd_median, val_avg_loss + + + def run_a_generic_epoch(self, ep_type, run_epoch_tuple, 
data_batched, data_loader): + """ + run_a_generic_epoch + """ + args, self.network, loss_fn_coors, optimizer = run_epoch_tuple + + meter = MeterUnboundBound() + + avg_loss, total_loss, num_batches = 0., 0., 0 + + total_loss_ligand_coors = 0. + avg_loss_ligand_coors = 0. + + total_loss_receptor_coors = 0. + avg_loss_receptor_coors = 0. + + total_loss_ot = 0. + avg_loss_ot = 0. + + total_loss_intersection = 0. + avg_loss_intersection = 0. + + def forward(batch_id): + + ######## RUN MODEL ############## + ligand_graph_node_tensor = Tensor(data_loader[batch_id][6], ms.float32) + receptor_graph_node_tensor = Tensor(data_loader[batch_id][7], ms.float32) + unbatch_list_tensor = Tensor(data_loader[batch_id][8], ms.int32) + input_tensor_tuple = ( + Tensor(data_loader[batch_id][4], ms.int32), # ligand_graph_num_nodes + Tensor(data_loader[batch_id][5], ms.int32), # receptor_graph_num_nodes + Tensor(data_loader[batch_id][2], ms.float32), # ligand_graph_edge_tensor + Tensor(data_loader[batch_id][3], ms.float32), # receptor_graph_edge_tensor + Tensor(data_loader[batch_id][0], ms.int32), # ll_connection_tensor + Tensor(data_loader[batch_id][1], ms.int32), # rr_connection_tensor + ) + + model_ligand_coors_deform_list, \ + model_keypts_ligand_list, model_keypts_receptor_list, \ + _, _, = self.network( + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ) + ################################ + + batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss, batch_intersection_loss = Tensor(0.0), \ + Tensor(0.0), Tensor(0.0), Tensor(0.0) + + for i, _ in enumerate(model_ligand_coors_deform_list): + ## Compute average MSE loss (which is 3 times smaller than average squared RMSD) + + batch_ligand_coors_loss = batch_ligand_coors_loss + loss_fn_coors(model_ligand_coors_deform_list[i], + Tensor(data_batched[8][batch_id][i], + ms.float32)) + # Compute the OT loss for the binding pocket: + ligand_pocket_coors = Tensor(data_batched[10][batch_id][i], ms.float32) # (N, 3), N = num pocket nodes + receptor_pocket_coors = Tensor(data_batched[11][batch_id][i], ms.float32) # (N, 3), N = num pocket nodes + + ligand_keypts_coors = model_keypts_ligand_list[i] # (K, 3), K = num keypoints + receptor_keypts_coors = model_keypts_receptor_list[i] # (K, 3), K = num keypoints + + ## (N, K) cost matrix + cost_mat_ligand = compute_sq_dist_mat(ligand_pocket_coors, ligand_keypts_coors) + cost_mat_receptor = compute_sq_dist_mat(receptor_pocket_coors, receptor_keypts_coors) + + ot_dist, _ = compute_ot_emd(cost_mat_ligand + cost_mat_receptor) + batch_ot_loss = batch_ot_loss + ot_dist + + batch_intersection_loss = batch_intersection_loss + compute_body_intersection_loss( + model_ligand_coors_deform_list[i], data_batched[9][batch_id][i], + args.intersection_sigma, args.intersection_surface_ct) + + ### Add new stats to the meter + if ep_type != 'train' or random.random() < 0.1: + meter.update_rmsd(model_ligand_coors_deform_list[i], + data_batched[9][batch_id][i], + data_batched[8][batch_id][i], + data_batched[9][batch_id][i]) + + batch_ligand_coors_loss = batch_ligand_coors_loss / float(len(model_ligand_coors_deform_list)) + batch_receptor_coors_loss = batch_receptor_coors_loss / float(len(model_ligand_coors_deform_list)) + batch_ot_loss = batch_ot_loss / float(len(model_ligand_coors_deform_list)) + batch_intersection_loss = batch_intersection_loss / float(len(model_ligand_coors_deform_list)) + + loss_coors = batch_ligand_coors_loss + batch_receptor_coors_loss + + loss = loss_coors + 
args.pocket_ot_loss_weight * batch_ot_loss \ + + args.intersection_loss_weight * batch_intersection_loss + + return loss, batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss, batch_intersection_loss + + backward = ms.value_and_grad(forward, None, optimizer.parameters) + + # Iterate over all batches of the epoch + for step in range(len(data_batched[0])): + num_batches += 1 + + (loss, batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss, + batch_intersection_loss), grads = backward(step) + + if ep_type == 'train': + grads = ms.ops.clip_by_norm(grads, max_norm=args.clip, norm_type=2) + optimizer(grads) + + total_loss += loss.asnumpy() + total_loss_ligand_coors += batch_ligand_coors_loss + total_loss_receptor_coors += batch_receptor_coors_loss + total_loss_ot += batch_ot_loss + total_loss_intersection += batch_intersection_loss + + if num_batches != 0: + avg_loss = total_loss / num_batches + avg_loss_ligand_coors = total_loss_ligand_coors / num_batches + avg_loss_receptor_coors = total_loss_receptor_coors / num_batches + avg_loss_ot = total_loss_ot / num_batches + avg_loss_intersection = total_loss_intersection / num_batches + + ligand_rmsd_mean, receptor_rmsd_mean, complex_rmsd_mean = meter.summarize(reduction_rmsd='mean') + ligand_rmsd_median, receptor_rmsd_median, complex_rmsd_median = meter.summarize(reduction_rmsd='median') + + return complex_rmsd_mean, complex_rmsd_median, \ + ligand_rmsd_mean, ligand_rmsd_median, \ + receptor_rmsd_mean, receptor_rmsd_median, \ + avg_loss.item(), avg_loss_ligand_coors.item(), \ + avg_loss_receptor_coors.item(), \ + avg_loss_ot.item(), avg_loss_intersection.item() + + + def train_step(self, data): + + self.log('+' * 100) + epoch = data + epoch_start = dt.now() + + run_epoch_tuple = (self.config, self.network, self.loss_fn_coors, self.optimizer) + self.run_a_train_epoch(run_epoch_tuple, self.train_data_batched, self.train_loader, epoch, self.config) + val_complex_rmsd_mean, val_complex_rmsd_median, _ = \ + self.run_an_eval_epoch(run_epoch_tuple, self.val_data_batched, self.val_loader, epoch, self.config) + + self.scheduler.step() + + if val_complex_rmsd_median < self.best_val_rmsd_median * 0.98: # We do this to avoid "pure luck" + # Best validation so far + self.best_val_rmsd_median = val_complex_rmsd_median + self.corr_val_rmsd_mean = val_complex_rmsd_mean + self.best_epoch = epoch + save_checkpoint(self.optimizer, self.config.checkpoint_dir + "/best_model.ckpt") + + log('[BEST SO FAR] ', self.config.data, + '|| At best epoch {} we have: best_val_rmsd_median {:.4f}, ' + '|| Current val rmsd_median {:.4f} ' + '|| Train time: {}\n'. + format(self.best_epoch, self.best_val_rmsd_median, + val_complex_rmsd_median, + dt.now() - epoch_start)) + + return '\n' + + + def log(self, *pargs): + print('[' + str(dt.now()) + '] ', *pargs) + + + def forward(self, data): + return None + + + def backward(self, data): + return None + + + def _jit_forward(self, data): + return None + + + def _pynative_forward(self, data): + return None diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py new file mode 100644 index 000000000..5695757f8 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py @@ -0,0 +1,24 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""equidock_configuration""" + +equidock_configuration = { + "predict_dips": + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/predict_dips.yaml", + "predict_db5": + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/predict_db5.yaml", + "train_db5": + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml", +} diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py new file mode 100644 index 000000000..9726490e0 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py @@ -0,0 +1,286 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""equidock_data""" + +import os +import numpy as np +from biopandas.pdb import PandasPdb +from joblib import Parallel, delayed, cpu_count +from scipy.spatial.transform import Rotation + +from .nn_arch import preprocess_unbound_bound, protein_to_graph_unbound_bound +from .train_utils import log + + +def one_hot_encoding(x, allowable_set, encode_unknown=False): + if encode_unknown and (allowable_set[-1] is not None): + allowable_set.append(None) + + if encode_unknown and (x not in allowable_set): + x = None + + return list(map(lambda s: x == s, allowable_set)) + + +def uniformrotation_translation(translation_interval): + rotation = Rotation.random(num=1) + rotation_matrix = rotation.as_matrix().squeeze() + + t = np.random.randn(1, 3) + t = t / np.sqrt(np.sum(t * t)) + length = np.random.uniform(low=0, high=translation_interval) + t = t * length + return rotation_matrix.astype(np.float32), t.astype(np.float32) + + +def get_residues_db5(pdb_filename): + df = PandasPdb().read_pdb(pdb_filename).df['ATOM'] + df.rename(columns={'chain_id': 'chain', 'residue_number': 'residue', 'residue_name': 'resname', + 'x_coord': 'x', 'y_coord': 'y', 'z_coord': 'z', 'element_symbol': 'element'}, inplace=True) + residues = list(df.groupby(['chain', 'residue', 'resname'])) # Not the same as sequence order ! 
+ return residues + + +def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, **kwargs): + if n_jobs is None: + n_jobs = cpu_count() - 1 + results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)( + delayed(pickleable_fn)(d, **kwargs) for i, d in enumerate(data) + ) + + return results + + +class UnboundBoundData(): + """ + UnboundBoundData Class + """ + def __init__(self, args, reload_mode='train', raw_data_path=None, split_files_path=None): + + os.makedirs(args.processed_dataset_path + reload_mode, exist_ok=True) + + if args.data == 'db5': + onlyfiles = [f for f in os.listdir(raw_data_path) if os.path.isfile(os.path.join(raw_data_path, f))] + code_set = {file.split('_')[0] for file in onlyfiles} + split_code_set = set() + with open(os.path.join(split_files_path, reload_mode + '.txt'), 'r') as f: + for line in f.readlines(): + split_code_set.add(line.rstrip()) + + code_set = code_set & split_code_set + code_list = list(code_set) + + bound_ligand_residues_list = [get_residues_db5(os.path.join(raw_data_path, code + '_l_b.pdb')) + for code in code_list] + bound_receptor_residues_list = [get_residues_db5(os.path.join(raw_data_path, code + '_r_b.pdb')) + for code in code_list] + + input_residues_lists = [(bound_ligand_residues_list[i], bound_receptor_residues_list[i]) + for i in range(len(bound_ligand_residues_list))] + + log('Start preprocess_unbound_bound') + preprocess_result = pmap_multi(preprocess_unbound_bound, + input_residues_lists, + n_jobs=args.n_jobs, + graph_nodes=args.graph_nodes, + pos_cutoff=args.pocket_cutoff, + inference=False) + log('Done preprocess_unbound_bound\n\n') + + unbound_predic_ligand_list, unbound_predic_receptor_list = [], [] + bound_ligand_repres_nodes_loc_array_list, bound_receptor_repres_nodes_loc_array_list = [], [] + pocket_coors_list = [] + for result in preprocess_result: + unbound_predic_ligand, unbound_predic_receptor, \ + bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, pocket_coors = result + if pocket_coors is not None: + unbound_predic_ligand_list.append(unbound_predic_ligand) + unbound_predic_receptor_list.append(unbound_predic_receptor) + bound_ligand_repres_nodes_loc_array_list.append(bound_ligand_repres_nodes_loc_array) + bound_receptor_repres_nodes_loc_array_list.append(bound_receptor_repres_nodes_loc_array) + pocket_coors_list.append(pocket_coors) + + protein_to_graph_input = [(unbound_predic_ligand_list[i], + unbound_predic_receptor_list[i], + bound_ligand_repres_nodes_loc_array_list[i], + bound_receptor_repres_nodes_loc_array_list[i]) for i in + range(len(unbound_predic_ligand_list))] + log('Start protein_to_graph_unbound_bound') + + both_proteins_to_graph_pair_list = pmap_multi(protein_to_graph_unbound_bound, + protein_to_graph_input, + n_jobs=args.n_jobs, + cutoff=args.graph_cutoff, + max_neighbor=args.graph_max_neighbor, + one_hot=False, + residue_loc_is_alphac=args.graph_residue_loc_is_alphaC + ) + + log('Done protein_to_graph_unbound_bound') + + self.save_processed_data( + args, + reload_mode, + both_proteins_to_graph_pair_list, + bound_ligand_repres_nodes_loc_array_list, + bound_receptor_repres_nodes_loc_array_list, + pocket_coors_list, + ) + + def single_graph_data(self, ligand_graph, receptor_graph, ligand_new_loc): + """ + Process single graph data in the list + """ + for k in ligand_graph.ndata.keys(): + ligand_graph.ndata[k] = ligand_graph.ndata[k].asnumpy() + for k in receptor_graph.ndata.keys(): + receptor_graph.ndata[k] = receptor_graph.ndata[k].asnumpy() + for k in 
ligand_graph.edata.keys(): + ligand_graph.edata[k] = ligand_graph.edata[k].asnumpy() + for k in receptor_graph.edata.keys(): + receptor_graph.edata[k] = receptor_graph.edata[k].asnumpy() + ligand_graph.ndata['new_x'] = ligand_new_loc.astype(np.float32) + + ##### Create a batch of a single heterograph + ligand_graph_node_tensor = np.concatenate( + (ligand_graph.ndata["res_feat"], + ligand_graph.ndata["x"], + ligand_graph.ndata["new_x"], + ligand_graph.ndata["mu_r_norm"]), + axis=1 + ) + + receptor_graph_node_tensor = np.concatenate( + (receptor_graph.ndata["res_feat"], + receptor_graph.ndata["x"], + receptor_graph.ndata["mu_r_norm"]), + axis=1 + ) + + ligand_graph_num_nodes = [ligand_graph.ndata["x"].shape[0]] + receptor_graph_num_nodes = [receptor_graph.ndata["x"].shape[0]] + ligand_graph_edge_tensor = ligand_graph.edata['he'] + receptor_graph_edge_tensor = receptor_graph.edata['he'] + + ll_connection_tensor = np.stack( + (ligand_graph.src_list, ligand_graph.dst_list)) + + rr_connection_tensor = np.stack( + (receptor_graph.src_list, receptor_graph.dst_list)) + ##### Create a batch of a single heterograph + + return ll_connection_tensor, rr_connection_tensor, ligand_graph_edge_tensor, \ + receptor_graph_edge_tensor, ligand_graph_num_nodes, receptor_graph_num_nodes, \ + ligand_graph_node_tensor, receptor_graph_node_tensor + + def save_npz_files(self, args, saved_lists, reload_mode): + """ + save data into npz file format + """ + ll_connection_tensor_list, rr_connection_tensor_list, ligand_graph_edge_tensor_list, \ + receptor_graph_edge_tensor_list, ligand_graph_num_nodes_list, receptor_graph_num_nodes_list, \ + ligand_graph_node_tensor_list, receptor_graph_node_tensor_list, bound_ligand_repres_nodes_loc_array_list_new, \ + bound_receptor_repres_nodes_loc_array_list_new, pocket_coors_ligand_list_new, \ + pocket_coors_receptor_list_new = saved_lists + + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ll_connection_tensor_list.npz"), + *ll_connection_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "rr_connection_tensor_list.npz"), + *rr_connection_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_edge_tensor_list.npz"), + *ligand_graph_edge_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "receptor_graph_edge_tensor_list.npz"), + *receptor_graph_edge_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_num_nodes_list.npz"), + *ligand_graph_num_nodes_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "receptor_graph_num_nodes_list.npz"), + *receptor_graph_num_nodes_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, + "bound_ligand_repres_nodes_loc_array_list_new.npz"), + *bound_ligand_repres_nodes_loc_array_list_new) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, + "bound_receptor_repres_nodes_loc_array_list_new.npz"), + *bound_receptor_repres_nodes_loc_array_list_new) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "pocket_coors_ligand_list_new.npz"), + *pocket_coors_ligand_list_new) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "pocket_coors_receptor_list_new.npz"), + *pocket_coors_receptor_list_new) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_node_tensor_list.npz"), + *ligand_graph_node_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, 
"receptor_graph_node_tensor_list.npz"), + *receptor_graph_node_tensor_list) + + def save_processed_data( + self, + args, + reload_mode, + both_proteins_to_graph_pair_list, + bound_ligand_repres_nodes_loc_array_list, + bound_receptor_repres_nodes_loc_array_list, + pocket_coors_list, + ): + """ + save_processed_data + """ + ligand_graph_list, receptor_graph_list = [], [] + for result in both_proteins_to_graph_pair_list: + ligand_graph, receptor_graph = result + ligand_graph_list.append(ligand_graph) + receptor_graph_list.append(receptor_graph) + + ligand_graph_node_tensor_list, receptor_graph_node_tensor_list, ligand_graph_num_nodes_list = [], [], [] + receptor_graph_num_nodes_list, ligand_graph_edge_tensor_list, receptor_graph_edge_tensor_list = [], [], [] + ll_connection_tensor_list, rr_connection_tensor_list, bound_ligand_repres_nodes_loc_array_list_new = [], [], [] + pocket_coors_ligand_list_new, pocket_coors_receptor_list_new = [], [] + bound_receptor_repres_nodes_loc_array_list_new = [] + + for i, _ in enumerate(ligand_graph_list): + ligand_graph, receptor_graph = ligand_graph_list[i], receptor_graph_list[i] + bound_ligand_repres_nodes_loc_array = bound_ligand_repres_nodes_loc_array_list[i] + bound_receptor_repres_nodes_loc_array = bound_receptor_repres_nodes_loc_array_list[i] + pocket_coors_ligand, pocket_coors_receptor = pocket_coors_list[i], pocket_coors_list[i] + + # Randomly rotate and translate the ligand. + rot_t, rot_b = uniformrotation_translation(translation_interval=args.translation_interval) + ligand_original_loc = ligand_graph.ndata['x'] + mean_to_remove = ligand_original_loc.mean(axis=0, keep_dims=True) + pocket_coors_ligand = (rot_t @ (pocket_coors_ligand - mean_to_remove).T).T + rot_b + ligand_new_loc = (rot_t @ (ligand_original_loc - mean_to_remove).T).T + rot_b + + ll_connection_tensor, rr_connection_tensor, ligand_graph_edge_tensor, receptor_graph_edge_tensor, \ + ligand_graph_num_nodes, receptor_graph_num_nodes, ligand_graph_node_tensor, \ + receptor_graph_node_tensor = self.single_graph_data(ligand_graph, receptor_graph, ligand_new_loc) + + ll_connection_tensor_list.append(ll_connection_tensor) + rr_connection_tensor_list.append(rr_connection_tensor) + ligand_graph_edge_tensor_list.append(ligand_graph_edge_tensor) + receptor_graph_edge_tensor_list.append(receptor_graph_edge_tensor) + ligand_graph_num_nodes_list.append(np.array(ligand_graph_num_nodes).astype(np.int32)) + receptor_graph_num_nodes_list.append(np.array(receptor_graph_num_nodes).astype(np.int32)) + ligand_graph_node_tensor_list.append(ligand_graph_node_tensor) + receptor_graph_node_tensor_list.append(receptor_graph_node_tensor) + bound_ligand_repres_nodes_loc_array_list_new.append(bound_ligand_repres_nodes_loc_array.astype(np.float32)) + bound_receptor_repres_nodes_loc_array_list_new.append( + bound_receptor_repres_nodes_loc_array.astype(np.float32)) + pocket_coors_ligand_list_new.append(pocket_coors_ligand.astype(np.float32)) + pocket_coors_receptor_list_new.append(pocket_coors_receptor.astype(np.float32)) + + saved_lists = [ll_connection_tensor_list, rr_connection_tensor_list, ligand_graph_edge_tensor_list, + receptor_graph_edge_tensor_list, ligand_graph_num_nodes_list, receptor_graph_num_nodes_list, + ligand_graph_node_tensor_list, receptor_graph_node_tensor_list, + bound_ligand_repres_nodes_loc_array_list_new, bound_receptor_repres_nodes_loc_array_list_new, + pocket_coors_ligand_list_new, pocket_coors_receptor_list_new] + + self.save_npz_files(args, saved_lists, reload_mode) diff --git 
a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py new file mode 100644 index 000000000..fa27ccd93 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py @@ -0,0 +1,229 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""""equidock_dataset""" + +#pylint: disable=W0221 + +import os +from datetime import datetime as dt + +import numpy as np + +from .equidock_data import UnboundBoundData +from ...dataset import PSP + +class EquiDockDataSet(PSP): + "EquiDockDataSet" + def __init__(self, config): + self.config = config + self.train_data_batched = None + self.train_loader = None + self.val_data_batched = None + self.val_loader = None + self.test_data_batched = None + self.test_loader = None + super().__init__() + + def process(self, data, **kwargs): + self.log(data, **kwargs) + input_dir = self.config.input_dir + files = [f for f in os.listdir(input_dir) \ + if os.path.isfile(os.path.join(input_dir, f)) and f.endswith('.pdb')] + return files + + def set_training_data_src(self, data_source, **kwargs): + """ + set_training_data_src + """ + self.log(data_source, **kwargs) + + datasets = ['train', 'val', 'test'] + + if not os.path.exists(self.config.processed_dataset_path): + for _, dataset in enumerate(datasets): + UnboundBoundData( + self.config, + reload_mode=dataset, + raw_data_path=self.config.raw_data_path, + split_files_path=self.config.split_files_path, + ) + + self.train_data_batched, self.train_loader = self.create_dataloader( + self.config.train_dir, + shuffle=True, + ) + self.val_data_batched, self.val_loader = self.create_dataloader( + self.config.val_dir, + shuffle=False, + ) + self.test_data_batched, self.test_loader = self.create_dataloader( + self.config.test_dir, + shuffle=False, + ) + + return ( + self.train_data_batched, + self.train_loader, + self.val_data_batched, + self.val_loader, + self.test_data_batched, + self.test_loader, + ) + + def create_iterator(self, num_epochs, **kwargs): + self.log(**kwargs) + if num_epochs == self.config.num_epochs: + return [_ for _ in range(num_epochs)] + return [_ for _ in range(self.config.num_epochs)] + + def data_parse(self, idx): + return self.train_data_batched[idx] + + def log(self, *pargs): + print('[' + str(dt.now()) + '] ', *pargs) + + def load_data(self, files_dir): + """ + load_data + """ + npz_files_list = [ + np.load(files_dir + "/ll_connection_tensor_list.npz"), + np.load(files_dir + "/rr_connection_tensor_list.npz"), + np.load(files_dir + "/ligand_graph_edge_tensor_list.npz"), + np.load(files_dir + "/receptor_graph_edge_tensor_list.npz"), + np.load(files_dir + "/ligand_graph_num_nodes_list.npz"), + np.load(files_dir + "/receptor_graph_num_nodes_list.npz"), + np.load(files_dir + "/ligand_graph_node_tensor_list.npz"), + np.load(files_dir + "/receptor_graph_node_tensor_list.npz"), 
+ np.load(files_dir + "/bound_ligand_repres_nodes_loc_array_list_new.npz"), + np.load(files_dir + "/bound_receptor_repres_nodes_loc_array_list_new.npz"), + np.load(files_dir + "/pocket_coors_ligand_list_new.npz"), + np.load(files_dir + "/pocket_coors_receptor_list_new.npz"), + ] + + extracted_files = [] + for _, npz_file in enumerate(npz_files_list): + extracted_files.append([npz_file[id] for id in npz_file.files]) + + unbatch_list = [] + for i, _ in enumerate(extracted_files[0]): + temp_length = [] + temp_length.append(len(extracted_files[0][i][0])) + temp_length.append(len(extracted_files[1][i][0])) + temp_length.append(len(extracted_files[6][i])) + temp_length.append(len(extracted_files[7][i])) + temp_length.append(len(extracted_files[11][i])) + unbatch_list.append(np.array(temp_length)) + extracted_files.append(unbatch_list) + + index_shuffle = np.arange(len(npz_files_list[0])) + + return extracted_files, index_shuffle + + + def shuffle_list(self, source_list, index_shuffle): + shuffled_list = [] + for _, idx in enumerate(index_shuffle): + shuffled_list.append(source_list[idx]) + + return shuffled_list + + + def batch(self, bs, souce_list): + output_list = [] + num = len(souce_list) // bs + for i in range(num): + output_list.append(souce_list[i * bs: (i + 1) * bs]) + if num * bs < len(souce_list): + output_list.append(souce_list[num * bs:]) + + return output_list + + + def shuffle_batch_dataset(self, input_dataset, index_shuffle, batch_size, shuffle): + """ + shuffle_batch_dataset + """ + if shuffle: + shuffled_dataset = [] + np.random.shuffle(index_shuffle) + for i, _ in enumerate(input_dataset): + shuffled_dataset.append(self.shuffle_list(input_dataset[i], index_shuffle)) + else: + shuffled_dataset = input_dataset[:] + + dataset_batched = [] + for _, data in enumerate(shuffled_dataset): + dataset_batched.append(self.batch(batch_size, data)) + unbatch_list = dataset_batched[-1] + for j, _ in enumerate(unbatch_list): + unbatch_list[j].insert(0, np.array([0, 0, 0, 0, 0])) + for i, _ in enumerate(unbatch_list[j]): + if i >= 1: + unbatch_list[j][i][0] += unbatch_list[j][i - 1][0] + unbatch_list[j][i][1] += unbatch_list[j][i - 1][1] + unbatch_list[j][i][2] += unbatch_list[j][i - 1][2] + unbatch_list[j][i][3] += unbatch_list[j][i - 1][3] + unbatch_list[j][i][4] += unbatch_list[j][i - 1][4] + dataset_batched[-1] = unbatch_list + + return dataset_batched + + + def cat_properties(self, input_dataset_batched): + """ + cat_properties + """ + dataset_batch_cat = [] + for i in range(len(input_dataset_batched[0])): + dataset_batch_cat.append( + [ + np.concatenate(input_dataset_batched[0][i], axis=1), + np.concatenate(input_dataset_batched[1][i], axis=1), + np.concatenate(input_dataset_batched[2][i], axis=0), + np.concatenate(input_dataset_batched[3][i], axis=0), + np.concatenate(input_dataset_batched[4][i], axis=0), + np.concatenate(input_dataset_batched[5][i], axis=0), + np.concatenate(input_dataset_batched[6][i], axis=0), + np.concatenate(input_dataset_batched[7][i], axis=0), + input_dataset_batched[-1][i], + ] + ) + + return dataset_batch_cat + + + def create_dataloader(self, dataset_dir, shuffle): + """ + create_dataloader + """ + dataset, index_shuffle = self.load_data(dataset_dir) + dataset_batched = self.shuffle_batch_dataset( + input_dataset=dataset, + index_shuffle=index_shuffle, + batch_size=self.config.bs, + shuffle=shuffle, + ) + dataloader = self.cat_properties(dataset_batched) + + return dataset_batched, dataloader + + + def __getitem__(self, idx): + return self.data_parse(idx) + + + 
def __len__(self):
+        return len(self.train_data_batched)
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py
new file mode 100644
index 000000000..e5d5f0a2c
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py
@@ -0,0 +1,1326 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""nn_arch"""
+
+import os
+import math
+from datetime import datetime as dt
+
+import numpy as np
+from numpy import linalg as LA
+from biopandas.pdb import PandasPdb
+import scipy.spatial as spa
+from scipy.special import softmax
+from mindspore import nn, ops, Tensor
+
+
+ATOM_NAME = 'atom_name'
+FLAGS = os.O_RDWR | os.O_CREAT
+
+
+def log(*pargs):
+    banner = 'EQUIDOCK'
+    log_file_name = os.path.join('stdouterr/', banner + ".txt")
+    with os.fdopen(os.open(log_file_name, FLAGS, 0o644), 'a+') as w:
+        w.write('[' + str(dt.now()) + '] ')
+        w.write(" ".join(["{}".format(t) for t in pargs]))
+        w.write("\n")
+    print('[' + str(dt.now()) + '] ', *pargs)
+
+
+class MeterUnboundBound():
+    """
+    MeterUnboundBound
+    """
+    def __init__(self):
+        self.complex_rmsd_list = []
+        self.ligand_rmsd_list = []
+        self.receptor_rmsd_list = []
+
+    def update_rmsd(self, ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true):
+        """
+        update_rmsd
+        """
+        ligand_coors_pred = ligand_coors_pred.asnumpy()
+        receptor_coors_pred = receptor_coors_pred
+
+        ligand_coors_true = ligand_coors_true
+        receptor_coors_true = receptor_coors_true
+
+        ligand_rmsd = np.sqrt(np.mean(np.sum((ligand_coors_pred - ligand_coors_true) ** 2, axis=1)))
+        receptor_rmsd = np.sqrt(np.mean(np.sum((receptor_coors_pred - receptor_coors_true) ** 2, axis=1)))
+
+        complex_coors_pred = np.concatenate((ligand_coors_pred, receptor_coors_pred), axis=0)
+        complex_coors_true = np.concatenate((ligand_coors_true, receptor_coors_true), axis=0)
+
+        r, b = rigid_transform_kabsch_3d(complex_coors_pred.T, complex_coors_true.T)
+        complex_coors_pred_aligned = ((r @ complex_coors_pred.T) + b).T
+
+        complex_rmsd = np.sqrt(np.mean(np.sum((complex_coors_pred_aligned - complex_coors_true) ** 2, axis=1)))
+
+        self.complex_rmsd_list.append(complex_rmsd)
+        self.ligand_rmsd_list.append(ligand_rmsd)
+        self.receptor_rmsd_list.append(receptor_rmsd)
+
+        return complex_rmsd
+
+    def summarize(self, reduction_rmsd='median'):
+        """
+        summarize
+        """
+        if reduction_rmsd == 'mean':
+            complex_rmsd_array = np.array(self.complex_rmsd_list)
+            complex_rmsd_summarized = np.mean(complex_rmsd_array)
+
+            ligand_rmsd_array = np.array(self.ligand_rmsd_list)
+            ligand_rmsd_summarized = np.mean(ligand_rmsd_array)
+
+            receptor_rmsd_array = np.array(self.receptor_rmsd_list)
+            receptor_rmsd_summarized = np.mean(receptor_rmsd_array)
+        elif reduction_rmsd == 'median':
+            complex_rmsd_array = np.array(self.complex_rmsd_list)
+            complex_rmsd_summarized = np.median(complex_rmsd_array)
+
+            ligand_rmsd_array = np.array(self.ligand_rmsd_list)
+            ligand_rmsd_summarized = np.median(ligand_rmsd_array)
+
+            receptor_rmsd_array = np.array(self.receptor_rmsd_list)
+            receptor_rmsd_summarized = np.median(receptor_rmsd_array)
+        else:
+            raise ValueError("Meter_Unbound_Bound: reduction_rmsd misspecified!")
+
+        return ligand_rmsd_summarized, receptor_rmsd_summarized, complex_rmsd_summarized
+
+    def summarize_with_std(self, reduction_rmsd='median'):
+        complex_rmsd_array = np.array(self.complex_rmsd_list)
+        if reduction_rmsd == 'mean':
+            complex_rmsd_summarized = np.mean(complex_rmsd_array)
+        elif reduction_rmsd == 'median':
+            complex_rmsd_summarized = np.median(complex_rmsd_array)
+        else:
+            raise ValueError("Meter_Unbound_Bound: reduction_rmsd misspecified!")
+
+        return complex_rmsd_summarized, np.std(complex_rmsd_array)
+
+
+def get_nodes_coors_numpy(filename, all_atoms=False):
+    df = PandasPdb().read_pdb(filename).df['ATOM']
+    if not all_atoms:
+        return Tensor(
+            df[df['atom_name'] == 'CA'][['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32))
+    return Tensor(df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32))
+
+
+def get_residues(pdb_filename):
+    df = PandasPdb().read_pdb(pdb_filename).df['ATOM']
+    df.rename(columns={'chain_id': 'chain', 'residue_number': 'residue', 'residue_name': 'resname',
+                       'x_coord': 'x', 'y_coord': 'y', 'z_coord': 'z', 'element_symbol': 'element'}, inplace=True)
+    residues = list(df.groupby(['chain', 'residue', 'resname']))  # Not the same as sequence order!
+    return residues
+
+
+def distance_list_featurizer(dist_list):
+    """
+    distance_list_featurizer
+    """
+    length_scale_list = [1.5 ** x for x in range(15)]
+    center_list = [0. for _ in range(15)]
+
+    num_edge = len(dist_list)
+    dist_list = np.array(dist_list)
+
+    transformed_dist = [np.exp(- ((dist_list - center) ** 2) / float(length_scale))
+                        for length_scale, center in zip(length_scale_list, center_list)]
+
+    transformed_dist = np.array(transformed_dist).T
+    transformed_dist = transformed_dist.reshape((num_edge, -1))
+
+    processed_features = dict()
+    processed_features['he'] = Tensor(transformed_dist.astype(np.float32))
+    return processed_features
+
+
+def rigid_transform_kabsch_3d(a, b):
+    """
+    rigid_transform_kabsch_3d
+    """
+    if a.shape[1] != b.shape[1]:
+        raise ValueError("a.shape[1] does not equal b.shape[1]")
+    num_rows, num_cols = a.shape
+    if num_rows != 3:
+        raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
+    num_rows, num_cols = b.shape
+    if num_rows != 3:
+        raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
+
+    # find mean column wise: 3 x 1
+    centroid_a = np.mean(a, axis=1, keepdims=True)
+    centroid_b = np.mean(b, axis=1, keepdims=True)
+
+    # subtract mean
+    am = a - centroid_a
+    bm = b - centroid_b
+
+    h = am @ bm.T
+
+    # find rotation
+    u, _, vt = np.linalg.svd(h)
+
+    r = vt.T @ u.T
+
+    # special reflection case
+    if np.linalg.det(r) < 0:
+        ss = np.diag([1., 1., -1.])
+        r = (vt.T @ ss) @ u.T
+    if math.fabs(np.linalg.det(r) - 1) >= 1e-5:
+        raise ValueError("Determinant of the rotation matrix should be within 1e-5 of 1")
+
+    t = -r @ centroid_a + centroid_b
+    return r, t
+
+
+def one_hot_encoding(x, allowable_set, encode_unknown=False):
+    if encode_unknown and (allowable_set[-1] is not None):
+        allowable_set.append(None)
+
+    if encode_unknown and (x not in allowable_set):
+        x = None
+
+    return list(map(lambda s: x == s, allowable_set))
+
+
+def residue_list_featurizer_dips_one_hot(predic):
+    residue_list
= [term[1]['resname'].iloc[0] for term in predic] + feature_list = [residue_type_one_hot_dips(residue) for residue in residue_list] + feature_list = np.stack(feature_list) + processed_features = dict() + processed_features['res_feat'] = Tensor(feature_list.astype(np.float32)) + return processed_features + + +def residue_list_featurizer_dips_not_one_hot(predic): + residue_list = [term[1]['resname'].iloc[0] for term in predic] + feature_list = [[residue_type_one_hot_dips_not_one_hot(residue)] for residue in residue_list] + feature_list = np.array(feature_list) + processed_features = dict() + processed_features['res_feat'] = Tensor(feature_list.astype(np.float32)) # (N_res, 1) + return processed_features + + +def residue_type_one_hot_dips(residue): + """ + residue_type_one_hot_dips + """ + dit = { + 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', + 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', + 'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V', + 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y', + 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D', + 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G', + } + allowable_set = [ + 'Y', 'R', 'F', 'G', 'I', 'V', 'A', 'W', 'E', 'H', 'C', 'N', 'M', 'D', 'T', 'S', 'K', 'L', 'Q', 'P' + ] + res_name = residue + if res_name not in dit.keys(): + res_name = None + else: + res_name = dit[res_name] + return one_hot_encoding(res_name, allowable_set, encode_unknown=True) + + +def residue_type_one_hot_dips_not_one_hot(residue): + """ + residue_type_one_hot_dips_not_one_hot + """ + dit = { + 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', + 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', + 'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V', + 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y', + 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D', + 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G' + } + + rare_residues = { + 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y', + 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D', + 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G' + } + + if residue in rare_residues.keys(): + log('Some rare residue: ', residue) + + indicator = { + 'Y': 0, 'R': 1, 'F': 2, 'G': 3, 'I': 4, 'V': 5, + 'A': 6, 'W': 7, 'E': 8, 'H': 9, 'C': 10, 'N': 11, + 'M': 12, 'D': 13, 'T': 14, 'S': 15, 'K': 16, 'L': 17, 'Q': 18, 'P': 19 + } + res_name = residue + if res_name not in dit.keys(): + return 20 + res_name = dit[res_name] + return indicator.get(res_name) + + +def preprocess_unbound_bound(input_residues_tuple, graph_nodes, pos_cutoff=8.0, inference=False): + """ + preprocess_unbound_bound + """ + ####################### + def filter_residues(residues): + residues_filtered = [] + for residue in residues: + df = residue[1] + natom = df[df[ATOM_NAME] == 'N'] + alphacatom = df[df[ATOM_NAME] == 'CA'] + catom = df[df[ATOM_NAME] == 'C'] + + if natom.shape[0] == 1 and alphacatom.shape[0] == 1 and catom.shape[0] == 1: + residues_filtered.append(residue) + return residues_filtered + + ########################## + bound_ligand_residues, bound_receptor_residues = input_residues_tuple + bound_predic_ligand_filtered = filter_residues(bound_ligand_residues) + unbound_predic_ligand_filtered = bound_predic_ligand_filtered + + 
bound_predic_receptor_filtered = filter_residues(bound_receptor_residues) + unbound_predic_receptor_filtered = bound_predic_receptor_filtered + + bound_predic_ligand_clean_list = bound_predic_ligand_filtered + unbound_predic_ligand_clean_list = unbound_predic_ligand_filtered + + bound_predic_receptor_clean_list = bound_predic_receptor_filtered + unbound_predic_receptor_clean_list = unbound_predic_receptor_filtered + + ################### + def get_alphac_loc_array(bound_predic_clean_list): + bound_alphac_loc_clean_list = [] + for residue in bound_predic_clean_list: + df = residue[1] + alphacatom = df[df[ATOM_NAME] == 'CA'] + alphac_loc = alphacatom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32) + bound_alphac_loc_clean_list.append(alphac_loc) + if len(bound_alphac_loc_clean_list) <= 1: + bound_alphac_loc_clean_list.append(np.zeros(3)) + return np.stack(bound_alphac_loc_clean_list, axis=0) # (N_res,3) + + #################### + if graph_nodes != 'residues': + raise TypeError("graph_nodes should be residues") + bound_receptor_repres_nodes_loc_array = get_alphac_loc_array(bound_predic_receptor_clean_list) + bound_ligand_repres_nodes_loc_array = get_alphac_loc_array(bound_predic_ligand_clean_list) + + if not inference: + + # Keep pairs of ligand and receptor residues/atoms that have pairwise distances < threshold + ligand_receptor_distance = spa.distance.cdist(bound_ligand_repres_nodes_loc_array, + bound_receptor_repres_nodes_loc_array) + positive_tuple = np.where(ligand_receptor_distance < pos_cutoff) + active_ligand = positive_tuple[0] + active_receptor = positive_tuple[1] + if active_ligand.size <= 3: # We need: active_ligand.size > 0 ' + pocket_coors = None # Will be filtered out later + else: + ligand_pocket_coors = bound_ligand_repres_nodes_loc_array[active_ligand, :] + receptor_pocket_coors = bound_receptor_repres_nodes_loc_array[active_receptor, :] + if np.max(np.linalg.norm(ligand_pocket_coors - receptor_pocket_coors, axis=1)) > pos_cutoff: + raise ValueError("The value should <= pos_cutoff") + pocket_coors = 0.5 * (ligand_pocket_coors + receptor_pocket_coors) + log('Num pocket nodes = ', len(active_ligand), ' total nodes = ', + bound_ligand_repres_nodes_loc_array.shape[0], ' graph_nodes = ', graph_nodes) + + return unbound_predic_ligand_clean_list, unbound_predic_receptor_clean_list, \ + bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, \ + pocket_coors + + return unbound_predic_ligand_clean_list, unbound_predic_receptor_clean_list, \ + bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, 0 + + +def protein_to_graph_unbound_bound( + unbound_bound_tuple, + cutoff=20, + max_neighbor=None, + one_hot=False, + residue_loc_is_alphac=True, +): + return protein_to_graph_unbound_bound_residuesonly( + unbound_bound_tuple, + cutoff, + max_neighbor, + one_hot, + residue_loc_is_alphac + ) + + +def protein_to_graph_unbound_bound_residuesonly( + unbound_bound_tuple, + cutoff=20, + max_neighbor=None, + one_hot=False, + residue_loc_is_alphac=True +): + """ + protein_to_graph_unbound_bound_residuesonly + """ + unbound_ligand_predic, unbound_receptor_predic, bound_ligand_repres_nodes_loc_clean_array, \ + bound_receptor_repres_nodes_loc_clean_array = unbound_bound_tuple + + ################## Extract 3D coordinates and n_i,u_i,v_i vectors of representative residues ################ + def l_or_r_extract_3d_coord_and_n_u_v_vecs(l_or_r_predic): + l_or_r_all_atom_coords_in_residue_list = [] + l_or_r_residue_representatives_loc_list = [] + 
l_or_r_n_i_list = [] + l_or_r_u_i_list = [] + l_or_r_v_i_list = [] + + for residue in l_or_r_predic: + df = residue[1] + coord = df[['x', 'y', 'z']].to_numpy().astype(np.float32) # (N_atoms, 3) + l_or_r_all_atom_coords_in_residue_list.append(coord) + + natom = df[df[ATOM_NAME] == 'N'] + alphacatom = df[df[ATOM_NAME] == 'CA'] + catom = df[df[ATOM_NAME] == 'C'] + + if natom.shape[0] != 1 or alphacatom.shape[0] != 1 or catom.shape[0] != 1: + raise ValueError("protein utils protein_to_graph_unbound_bound, no N/CA/C exists") + + n_loc = natom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32) + alphac_loc = alphacatom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32) + c_loc = catom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32) + + u_i = (n_loc - alphac_loc) / LA.norm(n_loc - alphac_loc) + t_i = (c_loc - alphac_loc) / LA.norm(c_loc - alphac_loc) + n_i = np.cross(u_i, t_i) / LA.norm(np.cross(u_i, t_i)) + v_i = np.cross(n_i, u_i) + + l_or_r_n_i_list.append(n_i) + l_or_r_u_i_list.append(u_i) + l_or_r_v_i_list.append(v_i) + + if residue_loc_is_alphac: + l_or_r_residue_representatives_loc_list.append(alphac_loc) + else: + heavy_df = df[df['element'] != 'H'] + residue_loc = heavy_df[['x', 'y', 'z']].mean(axis=0).to_numpy().astype( + np.float32) # average of all atom coordinates + l_or_r_residue_representatives_loc_list.append(residue_loc) + + l_or_r_residue_representatives_loc_feat = np.stack(l_or_r_residue_representatives_loc_list, + axis=0) # (N_res, 3) + l_or_r_n_i_feat = np.stack(l_or_r_n_i_list, axis=0) + l_or_r_u_i_feat = np.stack(l_or_r_u_i_list, axis=0) + l_or_r_v_i_feat = np.stack(l_or_r_v_i_list, axis=0) + + l_or_r_num_residues = len(l_or_r_predic) + if l_or_r_num_residues <= 1: + raise ValueError(f"l_or_r contains only 1 residue!") + return l_or_r_all_atom_coords_in_residue_list, \ + l_or_r_residue_representatives_loc_feat, \ + l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat, l_or_r_num_residues + + (ligand_all_atom_coords_in_residue_list, # list of (N_atoms,3) arrays, for each residue + ligand_residue_representatives_loc_feat, # (N_res, 3) + ligand_n_i_feat, # (N_res, 3) + ligand_u_i_feat, # (N_res, 3) + ligand_v_i_feat, # (N_res, 3) + ligand_num_residues) = l_or_r_extract_3d_coord_and_n_u_v_vecs(unbound_ligand_predic) + + (receptor_all_atom_coords_in_residue_list, + receptor_residue_representatives_loc_feat, + receptor_n_i_feat, + receptor_u_i_feat, + receptor_v_i_feat, + receptor_num_residues) = l_or_r_extract_3d_coord_and_n_u_v_vecs(unbound_receptor_predic) + + ################# Align unbound and bound structures, if needed ################################ + def l_or_r_align_unbound_and_bound(l_or_r_residue_representatives_loc_feat, + l_or_r_n_i_feat, + l_or_r_u_i_feat, + l_or_r_v_i_feat, + bound_l_or_r_alphac_loc_clean_array): + + ret_r_l_or_r, ret_t_l_or_r = rigid_transform_kabsch_3d(l_or_r_residue_representatives_loc_feat.T, + bound_l_or_r_alphac_loc_clean_array.T) + l_or_r_residue_representatives_loc_feat = ((ret_r_l_or_r @ (l_or_r_residue_representatives_loc_feat).T) + + ret_t_l_or_r).T + l_or_r_n_i_feat = ((ret_r_l_or_r @ (l_or_r_n_i_feat).T)).T + l_or_r_u_i_feat = ((ret_r_l_or_r @ (l_or_r_u_i_feat).T)).T + l_or_r_v_i_feat = ((ret_r_l_or_r @ (l_or_r_v_i_feat).T)).T + return l_or_r_residue_representatives_loc_feat, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat + + (ligand_residue_representatives_loc_feat, + ligand_n_i_feat, + ligand_u_i_feat, + ligand_v_i_feat) = l_or_r_align_unbound_and_bound(ligand_residue_representatives_loc_feat, + 
ligand_n_i_feat, ligand_u_i_feat, ligand_v_i_feat, + bound_ligand_repres_nodes_loc_clean_array) + (receptor_residue_representatives_loc_feat, + receptor_n_i_feat, + receptor_u_i_feat, + receptor_v_i_feat) = l_or_r_align_unbound_and_bound(receptor_residue_representatives_loc_feat, + receptor_n_i_feat, receptor_u_i_feat, receptor_v_i_feat, + bound_receptor_repres_nodes_loc_clean_array) + + ################### Build the k-NN graph ############################## + def loop_edges(input_tuple): + + l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat, \ + l_or_r_residue_representatives_loc_feat, l_or_r_protein_graph, l_or_r_mean_norm_list = input_tuple + + # Loop over all edges of the graph and build the various p_ij, q_ij, k_ij, t_ij pairs + l_or_r_edge_feat_ori_list = [] + for i in range(len(l_or_r_dist_list)): + src = l_or_r_src_list[i] + dst = l_or_r_dst_list[i] + + # place n_i, u_i, v_i as lines in a 3x3 basis matrix + basis_matrix = np.stack((l_or_r_n_i_feat[dst, :], l_or_r_u_i_feat[dst, :], l_or_r_v_i_feat[dst, :]), axis=0) + p_ij = np.matmul(basis_matrix, l_or_r_residue_representatives_loc_feat[src, :] - + l_or_r_residue_representatives_loc_feat[dst, :]) + q_ij = np.matmul(basis_matrix, l_or_r_n_i_feat[src, :]) # shape (3,) + k_ij = np.matmul(basis_matrix, l_or_r_u_i_feat[src, :]) + t_ij = np.matmul(basis_matrix, l_or_r_v_i_feat[src, :]) + s_ij = np.concatenate((p_ij, q_ij, k_ij, t_ij), axis=0) # shape (12,) + l_or_r_edge_feat_ori_list.append(s_ij) + l_or_r_edge_feat_ori_feat = np.stack(l_or_r_edge_feat_ori_list, axis=0) # shape (num_edges, 4, 3) + l_or_r_edge_feat_ori_feat = Tensor(l_or_r_edge_feat_ori_feat.astype(np.float32)) + l_or_r_protein_graph.edata['he'] = Tensor( + np.concatenate((l_or_r_protein_graph.edata['he'].asnumpy(), l_or_r_edge_feat_ori_feat.asnumpy()), axis=1) + ) + + l_or_r_residue_representatives_loc_feat = Tensor( + l_or_r_residue_representatives_loc_feat.astype(np.float32)) + l_or_r_protein_graph.ndata['x'] = l_or_r_residue_representatives_loc_feat + l_or_r_protein_graph.ndata['mu_r_norm'] = Tensor( + np.array(l_or_r_mean_norm_list).astype(np.float32)) + + return l_or_r_protein_graph + + + def compute_dig_knn_graph(input_tuple): + + l_or_r_num_residues, l_or_r_all_atom_coords_in_residue_list, unbound_l_or_r_predic,\ + l_or_r_residue_representatives_loc_feat, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat = input_tuple + + l_or_r_distance = np.full((l_or_r_num_residues, l_or_r_num_residues), np.inf) + + for i in range(l_or_r_num_residues - 1): + for j in range((i + 1), l_or_r_num_residues): + l_or_r_pairwise_dis = spa.distance.cdist(l_or_r_all_atom_coords_in_residue_list[i], + l_or_r_all_atom_coords_in_residue_list[j]) + l_or_r_distance[i, j] = np.mean(l_or_r_pairwise_dis) + l_or_r_distance[j, i] = np.mean(l_or_r_pairwise_dis) + + l_or_r_protein_graph = Graph(num_nodes=l_or_r_num_residues) + + l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_mean_norm_list = [], [], [], [] + + for i in range(l_or_r_num_residues): + valid_src = list(np.where(l_or_r_distance[i, :] < cutoff)[0]) + if len(valid_src) > max_neighbor: + valid_src = list(np.argsort(l_or_r_distance[i, :]))[0: max_neighbor] + valid_dst = [i] * len(valid_src) + l_or_r_dst_list.extend(valid_dst) + l_or_r_src_list.extend(valid_src) + + valid_dist = list(l_or_r_distance[i, valid_src]) + l_or_r_dist_list.extend(valid_dist) + + valid_dist_np = l_or_r_distance[i, valid_src] + sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1)) + weights = softmax(- 
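The 12-dimensional edge feature assembled in loop_edges below expresses the source residue's position and frame axes in the destination residue's local basis, which makes it invariant to global rotations. A self-contained NumPy sketch, where random orthonormal frames obtained via QR stand in for the real (n_i, u_i, v_i) vectors:

import numpy as np

rng = np.random.default_rng(0)

def toy_frame():
    # random orthonormal frame; its rows stand in for one residue's (n_i, u_i, v_i)
    q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
    return q

x_src, x_dst = rng.normal(size=3), rng.normal(size=3)
frame_src, frame_dst = toy_frame(), toy_frame()

p_ij = frame_dst @ (x_src - x_dst)        # source position expressed in the destination frame
q_ij = frame_dst @ frame_src[0]           # source n_i in the destination frame
k_ij = frame_dst @ frame_src[1]           # source u_i in the destination frame
t_ij = frame_dst @ frame_src[2]           # source v_i in the destination frame
s_ij = np.concatenate((p_ij, q_ij, k_ij, t_ij))   # 12-dim edge feature, one per edge in loop_edges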
valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1) # (sigma_num, neigh_num) + diff_vecs = l_or_r_residue_representatives_loc_feat[valid_dst, :] - \ + l_or_r_residue_representatives_loc_feat[valid_src, :] # (neigh_num, 3) + mean_vec = weights.dot(diff_vecs) # (sigma_num, 3) + denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1)) # (sigma_num,) + mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator # (sigma_num,) + l_or_r_mean_norm_list.append(mean_vec_ratio_norm) + + l_or_r_protein_graph.add_edges(l_or_r_src_list, l_or_r_dst_list) + + if one_hot: + l_or_r_protein_graph.ndata = residue_list_featurizer_dips_one_hot(unbound_l_or_r_predic) + else: + l_or_r_protein_graph.ndata = residue_list_featurizer_dips_not_one_hot(unbound_l_or_r_predic) + + l_or_r_protein_graph.edata = distance_list_featurizer(l_or_r_dist_list) + + loop_edges_input = l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_n_i_feat, l_or_r_u_i_feat, \ + l_or_r_v_i_feat, l_or_r_residue_representatives_loc_feat, l_or_r_protein_graph, l_or_r_mean_norm_list + + return loop_edges(loop_edges_input) + + + ligand_protein_graph_tuple = ( + ligand_num_residues, ligand_all_atom_coords_in_residue_list, unbound_ligand_predic, + ligand_residue_representatives_loc_feat, ligand_n_i_feat, ligand_u_i_feat, ligand_v_i_feat, + ) + ligand_protein_graph = compute_dig_knn_graph(ligand_protein_graph_tuple) + + receptor_protein_graph_tuple = ( + receptor_num_residues, receptor_all_atom_coords_in_residue_list, unbound_receptor_predic, + receptor_residue_representatives_loc_feat, receptor_n_i_feat, receptor_u_i_feat, receptor_v_i_feat, + ) + receptor_protein_graph = compute_dig_knn_graph(receptor_protein_graph_tuple) + + return ligand_protein_graph, receptor_protein_graph + + +def unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor, h_feats_ligand, coors_receptor, coors_ligand): + """ + unbatch_hetero_graph + """ + list_hetero_graph = [] + for i, _ in enumerate(unbatch_list_tensor): + if i < len(unbatch_list_tensor) - 1: + hetero_graph = { + "receptor_hv_iegmn_out": \ + h_feats_receptor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :], + "ligand_hv_iegmn_out": \ + h_feats_ligand[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :], + "receptor_x_iegmn_out": \ + coors_receptor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :], + "ligand_x_iegmn_out": \ + coors_ligand[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :], + } + list_hetero_graph.append(hetero_graph) + return list_hetero_graph + + +def get_non_lin(input_type, negative_slope): + if input_type == 'swish': + return nn.SiLU() + if input_type == 'lkyrelu': + return nn.LeakyReLU(alpha=negative_slope) + raise NotImplementedError + + +def get_layer_norm(layer_norm_type, dim): + if layer_norm_type == 'BN': + return nn.BatchNorm1d([dim]) + if layer_norm_type == 'LN': + return nn.LayerNorm([dim], begin_norm_axis=1, begin_params_axis=1, + epsilon=1e-5) + return nn.Identity() + + +def get_final_h_layer_norm(layer_norm_type, dim): + if layer_norm_type == 'BN': + return nn.BatchNorm1d(dim) + if layer_norm_type == 'LN': + return nn.LayerNorm([dim], begin_norm_axis=1, begin_params_axis=1, epsilon=1e-5) + if layer_norm_type == '0': + return nn.Identity() + raise NotImplementedError + + +def apply_final_h_layer_norm(h, norm_layer): + return norm_layer(h) + + +def compute_cross_attention(queries, keys, values, mask, cross_msgs): + """ + compute_cross_attention + """ + if not cross_msgs: + return 0 + a = mask * ops.mm(queries, 
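The mu_r_norm node feature computed here is a multi-scale surface-awareness score: Gaussian weights over neighbour distances (one row per sigma) average the difference vectors, and the ratio of the averaged vector's norm to the averaged norms is, intuitively, near 0 for buried residues and closer to 1 for surface residues. A toy NumPy version with made-up neighbour data:

import numpy as np
from scipy.special import softmax

sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))          # same scales as above
rng = np.random.default_rng(0)
neigh_dist = np.array([3.2, 4.1, 5.0, 7.5])                         # toy distances to 4 neighbours
diff_vecs = rng.normal(size=(4, 3))                                  # toy x_i - x_j vectors

weights = softmax(-neigh_dist.reshape((1, -1)) ** 2 / sigma, axis=1)     # (5, 4)
mean_vec = weights.dot(diff_vecs)                                         # (5, 3)
denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1))              # (5,)
mu_r_norm = np.linalg.norm(mean_vec, axis=1) / denominator                # (5,), one value per sigma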
ops.transpose(keys, (1, 0))) - 1000. * (1. - mask) + a_x = ops.softmax(a) # i->j, NxM + attention_x = ops.mm(a_x, values) # (N,d) + return attention_x + + +def get_mask(ligand_batch_num_nodes, receptor_batch_num_nodes): + """ + get_mask + """ + rows = sum(ligand_batch_num_nodes) + cols = sum(receptor_batch_num_nodes) + mask = ops.zeros((int(rows), int(cols))) + partial_l = 0 + partial_r = 0 + for l_n, r_n in zip(ligand_batch_num_nodes, receptor_batch_num_nodes): + mask[partial_l: partial_l + l_n, partial_r: partial_r + r_n] = 1 + partial_l = partial_l + l_n + partial_r = partial_r + r_n + return mask + + +class Graph: + """ + Graph class for wrapping data + """ + def __init__( + self, + num_nodes=0, + num_edges=0, + ): + self.num_nodes = num_nodes + self.num_edges = num_edges + self.ndata = {} + self.edata = {} + self.src_list = [] + self.dst_list = [] + + def add_edges(self, src_list, dst_list): + if len(src_list) != len(dst_list): + raise NotImplementedError + self.src_list = src_list + self.dst_list = dst_list + self.num_edges = len(dst_list) + + +class IEGMNLayer(nn.Cell): + """ + IEGMNLayer class + """ + def __init__( + self, + orig_h_feats_dim, + h_feats_dim, + out_feats_dim, + fine_tune, + args, + log_input=None): + + super(IEGMNLayer, self).__init__() + + dropout = args.dropout + nonlin = args.nonlin + layer_norm = args.layer_norm + leakyrelu_neg_slope = args.leakyrelu_neg_slope + self.input_edge_feats_dim = args.input_edge_feats_dim + self.layer_norm_coors = args.layer_norm_coors + self.cross_msgs = args.cross_msgs + self.final_h_layer_norm = args.final_h_layer_norm + self.use_dist_in_layers = args.use_dist_in_layers + self.skip_weight_h = args.skip_weight_h + self.x_connection_init = args.x_connection_init + self.debug = args.debug + self.fine_tune = fine_tune + self.log = log_input + self.h_feats_dim = h_feats_dim + self.out_feats_dim = out_feats_dim + self.all_sigmas_dist = [1.5 ** x for x in range(15)] + self.orig_h_feats_dim = orig_h_feats_dim + + # EDGES + self.edge_mlp = self.edge_mlp_func(dropout, nonlin, layer_norm, leakyrelu_neg_slope) + + # NODES + self.node_norm = nn.Identity() + + self.att_mlp_q = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + get_non_lin(nonlin, leakyrelu_neg_slope), + ) + + self.att_mlp_k = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + get_non_lin(nonlin, leakyrelu_neg_slope), + ) + + self.att_mlp_v = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + ) + + self.node_mlp = self.node_mlp_func(dropout, nonlin, layer_norm, leakyrelu_neg_slope) + + self.final_h_layernorm_layer = get_final_h_layer_norm(self.final_h_layer_norm, self.out_feats_dim) + + ## The scalar weight to be multiplied by (x_i - x_j) + self.coors_mlp = self.coors_mlp_func(dropout, nonlin, leakyrelu_neg_slope) + + if self.fine_tune: + self.att_mlp_cross_coors_q = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + get_non_lin(nonlin, leakyrelu_neg_slope), + ) + self.att_mlp_cross_coors_k = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + get_non_lin(nonlin, leakyrelu_neg_slope), + ) + self.att_mlp_cross_coors_v = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim), + get_non_lin(nonlin, leakyrelu_neg_slope), + nn.Dense(h_feats_dim, 1), + ) + + def __repr__(self): + return "IEGMNLayer " + str(self.__dict__) + + def edge_mlp_func(self, dropout, nonlin, layer_norm, leakyrelu_neg_slope): + """ + IEGMNLayer edge_mlp_func + """ + edge_mlp = nn.SequentialCell( + 
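compute_cross_attention and get_mask together implement masked cross-attention between the two proteins: the block-diagonal mask keeps ligand nodes from attending to receptor nodes of a different complex in the batch. A NumPy sketch for a single complex with toy embeddings:

import numpy as np
from scipy.special import softmax

def masked_cross_attention(queries, keys, values, mask):
    # mask[i, j] = 1 only when ligand node i and receptor node j belong to the same complex
    logits = mask * (queries @ keys.T) - 1000.0 * (1.0 - mask)   # forbid attention across complexes
    att = softmax(logits, axis=-1)                                # ligand -> receptor attention weights
    return att @ values                                           # (N_ligand, d) aggregated messages

rng = np.random.default_rng(0)
q, k, v = rng.normal(size=(5, 8)), rng.normal(size=(7, 8)), rng.normal(size=(7, 8))
mask = np.ones((5, 7))                                            # a single complex in the batch
out = masked_cross_attention(q, k, v, mask)                       # (5, 8)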
nn.Dense((self.h_feats_dim * 2) + self.input_edge_feats_dim + + len(self.all_sigmas_dist), self.out_feats_dim), + nn.Dropout(dropout), + get_non_lin(nonlin, leakyrelu_neg_slope), + get_layer_norm(layer_norm, self.out_feats_dim), + nn.Dense(self.out_feats_dim, self.out_feats_dim), + ) + return edge_mlp + + def node_mlp_func(self, dropout, nonlin, layer_norm, leakyrelu_neg_slope): + + node_mlp = nn.SequentialCell( + nn.Dense(self.orig_h_feats_dim + 2 * self.h_feats_dim + self.out_feats_dim, self.h_feats_dim), + nn.Dropout(dropout), + get_non_lin(nonlin, leakyrelu_neg_slope), + get_layer_norm(layer_norm, self.h_feats_dim), + nn.Dense(self.h_feats_dim, self.out_feats_dim), + ) + return node_mlp + + def coors_mlp_func(self, dropout, nonlin, leakyrelu_neg_slope): + coors_mlp = nn.SequentialCell( + nn.Dense(self.out_feats_dim, self.out_feats_dim), + nn.Dropout(dropout), + get_non_lin(nonlin, leakyrelu_neg_slope), + get_layer_norm(self.layer_norm_coors, self.out_feats_dim), + nn.Dense(self.out_feats_dim, 1) + ) + return coors_mlp + + def x_rel_mag(self, nodes_ligand_x_now, ll_connection_tensor, nodes_receptor_x_now, rr_connection_tensor): + """ + IEGMNLayer x_rel_mag + """ + nodes_ligand_x_now_src = ops.index_select(nodes_ligand_x_now, 0, ll_connection_tensor[0]) + nodes_ligand_x_now_dst = ops.index_select(nodes_ligand_x_now, 0, ll_connection_tensor[1]) + ll_edge_x_rel = ops.sub(nodes_ligand_x_now_src, nodes_ligand_x_now_dst) + + nodes_receptor_x_now_src = ops.index_select(nodes_receptor_x_now, 0, rr_connection_tensor[0]) + nodes_receptor_x_now_dst = ops.index_select(nodes_receptor_x_now, 0, rr_connection_tensor[1]) + rr_edge_x_rel = ops.sub(nodes_receptor_x_now_src, nodes_receptor_x_now_dst) + + x_rel_mag_ligand = ll_edge_x_rel ** 2 + x_rel_mag_ligand = ops.sum(x_rel_mag_ligand, dim=1, keepdim=True) # ||x_i - x_j||^2 : (N_res, 1) + x_rel_mag_ligand = ops.cat([ops.exp(-x_rel_mag_ligand / sigma) for sigma in self.all_sigmas_dist], axis=-1) + + x_rel_mag_receptor = rr_edge_x_rel ** 2 + x_rel_mag_receptor = ops.sum(x_rel_mag_receptor, dim=1, keepdim=True) + x_rel_mag_receptor = ops.cat([ops.exp(-x_rel_mag_receptor / sigma) for sigma in self.all_sigmas_dist], axis=-1) + + return ll_edge_x_rel, rr_edge_x_rel, x_rel_mag_ligand, x_rel_mag_receptor + + def cat_input_for_msg(self, nodes_feat, original_edge_feats, connection_tensor, x_rel_mag): + nodes_feat_src = ops.index_select(nodes_feat, 0, connection_tensor[0]) + nodes_feat_dst = ops.index_select(nodes_feat, 0, connection_tensor[1]) + nodes_cat_feat = ops.cat((nodes_feat_src, nodes_feat_dst), axis=1) + cat_input_for_msg = ops.cat((nodes_cat_feat, # [h_i h_j] + original_edge_feats, + x_rel_mag), axis=-1) + return cat_input_for_msg + + def nodes_aggr_cross_msg(self, h_feats_ligand, h_feats_receptor, mask): + """ + IEGMNLayer nodes_aggr_cross_msg + """ + nodes_ligand_aggr_cross_msg = compute_cross_attention(self.att_mlp_q(h_feats_ligand), + self.att_mlp_k(h_feats_receptor), + self.att_mlp_v(h_feats_receptor), + mask, + self.cross_msgs) + nodes_receptor_aggr_cross_msg = compute_cross_attention(self.att_mlp_q(h_feats_receptor), + self.att_mlp_k(h_feats_ligand), + self.att_mlp_v(h_feats_ligand), + mask.transpose(1, 0), + self.cross_msgs) + return nodes_ligand_aggr_cross_msg, nodes_receptor_aggr_cross_msg + + def cal_x_final(self, edges_x_moment, edges_msg, orig_coors, nodes_x_now): + nodes_x_update = ops.mean(ops.reshape(edges_x_moment, (-1, 10, 3)), axis=1) + nodes_aggr_msg = ops.mean(ops.reshape(edges_msg, (-1, 10, edges_msg.shape[-1])), axis=1) + x_final 
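The len(self.all_sigmas_dist) extra inputs to the edge MLP come from a radial-basis expansion of the squared edge length, exp(-||x_i - x_j||^2 / sigma) with sigma = 1.5^k, as computed in x_rel_mag below. A short NumPy sketch with toy edge vectors:

import numpy as np

all_sigmas_dist = [1.5 ** x for x in range(15)]                    # same scales as in the layer
rng = np.random.default_rng(0)
edge_x_rel = rng.normal(size=(20, 3))                               # toy x_i - x_j, one row per edge

sq_dist = np.sum(edge_x_rel ** 2, axis=1, keepdims=True)            # ||x_i - x_j||^2, shape (E, 1)
x_rel_mag = np.concatenate([np.exp(-sq_dist / s) for s in all_sigmas_dist], axis=-1)   # (E, 15)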
= self.x_connection_init * orig_coors + (1. - self.x_connection_init) * nodes_x_now + \ + nodes_x_update + return nodes_aggr_msg, x_final + + def fine_tune_final_lr(self, input_fine_tune_tuple): + """ + IEGMNLayer fine_tune_final_lr + """ + x_final_ligand, x_final_receptor, h_feats_ligand, nodes_ligand_x_now, h_feats_receptor, \ + nodes_receptor_x_now, mask = input_fine_tune_tuple + + x_final_ligand = x_final_ligand + self.att_mlp_cross_coors_v(h_feats_ligand) * \ + (nodes_ligand_x_now - compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_ligand), + self.att_mlp_cross_coors_k(h_feats_receptor), + nodes_receptor_x_now, + mask, + self.cross_msgs)) + x_final_receptor = x_final_receptor + self.att_mlp_cross_coors_v(h_feats_receptor) * \ + (nodes_receptor_x_now - compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_receptor), + self.att_mlp_cross_coors_k(h_feats_ligand), + nodes_ligand_x_now, + mask.transpose(1, 0), + self.cross_msgs)) + return x_final_ligand, x_final_receptor + + def skip_connections(self, h_feats_ligand, h_feats_receptor, input_node_upd_ligand, input_node_upd_receptor): + """ + IEGMNLayer skip_connections + """ + if self.h_feats_dim == self.out_feats_dim: + node_upd_ligand = self.skip_weight_h * self.node_mlp(input_node_upd_ligand) + ( + 1. - self.skip_weight_h) * h_feats_ligand + node_upd_receptor = self.skip_weight_h * self.node_mlp(input_node_upd_receptor) + ( + 1. - self.skip_weight_h) * h_feats_receptor + else: + node_upd_ligand = self.node_mlp(input_node_upd_ligand) + node_upd_receptor = self.node_mlp(input_node_upd_receptor) + + node_upd_ligand = apply_final_h_layer_norm(node_upd_ligand, self.final_h_layernorm_layer) + node_upd_receptor = apply_final_h_layer_norm(node_upd_receptor, self.final_h_layernorm_layer) + + return node_upd_ligand, node_upd_receptor + + def log_debug_info(self, debug_variable_tuple): + """ + IEGMNLayer log_debug_info + """ + nodes_ligand_x_now, nodes_ligand_feat, ll_edge_x_rel, x_rel_mag_ligand, edges_ll_msg, \ + nodes_ligand_aggr_cross_msg, edge_coef_ligand, edges_ll_x_moment, nodes_ligand_aggr_msg, \ + x_final_ligand = debug_variable_tuple + self.log(ops.max(nodes_ligand_x_now)[0], 'x_now : x_i at layer entrance') + self.log(ops.max(nodes_ligand_feat)[0], 'data[feat] = h_i at layer entrance') + self.log(ops.max(ll_edge_x_rel)[0], 'x_rel : x_i - x_j') + self.log(ops.max(x_rel_mag_ligand, axis=0)[0], + 'x_rel_mag_ligand = [exp(-||x_i - x_j||^2 / sigma) for sigma = 1.5 ** x, x = [0, 15]]') + self.log(ops.max(edges_ll_msg)[0], 'data[msg] = m_{i->j} = phi^e(h_i, h_j, f_{i,j}, x_rel_mag_ligand)') + self.log(ops.max(nodes_ligand_aggr_cross_msg)[0], 'aggr_cross_msg(i) = sum_j a_{i,j} * h_j') + self.log(ops.max(edge_coef_ligand)[0], 'edge_coef_ligand : ' + r'\p' + 'hi^x(m_{i->j})') + self.log(ops.max(edges_ll_x_moment)[0], 'data[x_moment] = (x_i - x_j) * ' + r'\p' + 'hi^x(m_{i->j})') + self.log(ops.max(nodes_ligand_aggr_msg)[0], 'data[aggr_msg]: ' + r'\s' + 'um_j m_{i->j}') + self.log(ops.max(x_final_ligand)[0], 'x_i new = x_final_ligand : x_i + data[x_update]') + + def construct(self, iegmn_layer_tuple, input_tensor_tuple): + """ + IEGMNLayer construct + """ + ligand_graph_num_nodes, receptor_graph_num_nodes, ll_connection_tensor, rr_connection_tensor = \ + input_tensor_tuple[0], input_tensor_tuple[1], input_tensor_tuple[4], input_tensor_tuple[5] + + coors_ligand, h_feats_ligand, original_ligand_node_features, original_edge_feats_ligand, \ + orig_coors_ligand, coors_receptor, h_feats_receptor, original_receptor_node_features, \ + 
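cal_x_final above averages the per-edge coordinate moments over each node's fixed-size neighbourhood (the reshape assumes 10 edges per node) and blends the result with the layer input and the original coordinates through x_connection_init. A toy NumPy version; the 0.2 blending weight is a hypothetical config value:

import numpy as np

rng = np.random.default_rng(0)
n_nodes, k_neigh = 6, 10                                   # the reshape above assumes 10 edges per node
edges_x_moment = rng.normal(size=(n_nodes * k_neigh, 3))   # (x_i - x_j) * phi^x(m_ij), one row per edge

x_update = np.mean(edges_x_moment.reshape(n_nodes, k_neigh, 3), axis=1)   # per-node mean, (n_nodes, 3)

x_connection_init = 0.2                                    # hypothetical value of the config entry
orig_coors = rng.normal(size=(n_nodes, 3))                 # coordinates at the network input
x_now = rng.normal(size=(n_nodes, 3))                      # coordinates entering this layer
x_final = x_connection_init * orig_coors + (1.0 - x_connection_init) * x_now + x_update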
original_edge_feats_receptor, orig_coors_receptor = iegmn_layer_tuple + + nodes_ligand_x_now, nodes_receptor_x_now, nodes_ligand_feat, nodes_receptor_feat = \ + coors_ligand, coors_receptor, h_feats_ligand, h_feats_receptor + + ll_edge_x_rel, rr_edge_x_rel, x_rel_mag_ligand, x_rel_mag_receptor = self.x_rel_mag(nodes_ligand_x_now, \ + ll_connection_tensor, nodes_receptor_x_now, rr_connection_tensor) + + if not self.use_dist_in_layers: + x_rel_mag_ligand = 0 + x_rel_mag_receptor = 0 + + cat_input_for_msg_ligand = self.cat_input_for_msg( + nodes_ligand_feat, original_edge_feats_ligand, ll_connection_tensor, x_rel_mag_ligand) + cat_input_for_msg_receptor = self.cat_input_for_msg( + nodes_receptor_feat, original_edge_feats_receptor, rr_connection_tensor, x_rel_mag_receptor) + + edges_ll_msg = self.edge_mlp(cat_input_for_msg_ligand) # m_{i->j} + edges_rr_msg = self.edge_mlp(cat_input_for_msg_receptor) + + mask = get_mask(ligand_graph_num_nodes, receptor_graph_num_nodes) + + nodes_ligand_aggr_cross_msg, nodes_receptor_aggr_cross_msg = self.nodes_aggr_cross_msg(h_feats_ligand, + h_feats_receptor, mask) + edge_coef_ligand = self.coors_mlp(edges_ll_msg) # \phi^x(m_{i->j}) + edges_ll_x_moment = ll_edge_x_rel * edge_coef_ligand # (x_i - x_j) * \phi^x(m_{i->j}) + + edge_coef_receptor = self.coors_mlp(edges_rr_msg) + edges_rr_x_moment = rr_edge_x_rel * edge_coef_receptor + + nodes_ligand_aggr_msg, x_final_ligand = self.cal_x_final(edges_ll_x_moment, edges_ll_msg, + orig_coors_ligand, nodes_ligand_x_now) + nodes_receptor_aggr_msg, x_final_receptor = self.cal_x_final(edges_rr_x_moment, edges_rr_msg, + orig_coors_receptor, nodes_receptor_x_now) + + if self.fine_tune: + input_fine_tune = ( + x_final_ligand, x_final_receptor, h_feats_ligand, nodes_ligand_x_now, + h_feats_receptor, nodes_receptor_x_now, mask + ) + x_final_ligand, x_final_receptor = self.fine_tune_final_lr(input_fine_tune) + + if self.debug: + debug_variable_tuple = ( + nodes_ligand_x_now, nodes_ligand_feat, ll_edge_x_rel, x_rel_mag_ligand, edges_ll_msg, + nodes_ligand_aggr_cross_msg, edge_coef_ligand, edges_ll_x_moment, nodes_ligand_aggr_msg, x_final_ligand, + ) + self.log_debug_info(debug_variable_tuple) + + input_node_upd_ligand = ops.cat((self.node_norm(nodes_ligand_feat), nodes_ligand_aggr_msg, + nodes_ligand_aggr_cross_msg, original_ligand_node_features), axis=-1) + + input_node_upd_receptor = ops.cat((self.node_norm(nodes_receptor_feat), nodes_receptor_aggr_msg, + nodes_receptor_aggr_cross_msg, original_receptor_node_features), axis=-1) + + node_upd_ligand, node_upd_receptor = self.skip_connections(h_feats_ligand, h_feats_receptor, + input_node_upd_ligand, input_node_upd_receptor) + + return x_final_ligand, node_upd_ligand, x_final_receptor, node_upd_receptor + + +class IEGMN(nn.Cell): + """ + IEGMN class + """ + def __init__(self, args, n_lays, fine_tune, log_input=None): + + super(IEGMN, self).__init__() + + self.debug = args.debug + self.log = log_input + self.graph_nodes = args.graph_nodes + self.rot_model = args.rot_model + self.iegmn_lay_hid_dim = args.iegmn_lay_hid_dim + self.noise_decay_rate = args.noise_decay_rate + self.noise_initial = args.noise_initial + self.use_edge_features_in_gmn = args.use_edge_features_in_gmn + self.use_mean_node_features = args.use_mean_node_features + + # 21 types of amino-acid types + self.residue_emb_layer = nn.Embedding(vocab_size=21, embedding_size=args.residue_emb_dim, use_one_hot=False) + if self.graph_nodes != 'residues': + raise NotImplementedError + input_node_feats_dim = 
args.residue_emb_dim # One residue type + + if self.use_mean_node_features: + input_node_feats_dim += 5 # Additional features from mu_r_norm + + self.iegmn_layers = self.iegmn_layers_func(args, input_node_feats_dim, n_lays, fine_tune) + + if args.rot_model != 'kb_att': + raise ValueError("args rot_model should be kb_att") + + # Attention layers + self.num_att_heads = args.num_att_heads + self.out_feats_dim = self.iegmn_lay_hid_dim + + self.att_mlp_key_rot = nn.SequentialCell( + nn.Dense(self.out_feats_dim, self.num_att_heads * self.out_feats_dim, has_bias=False), + ) + self.att_mlp_query_rot = nn.SequentialCell( + nn.Dense(self.out_feats_dim, self.num_att_heads * self.out_feats_dim, has_bias=False), + ) + + self.mlp_h_mean_rot = nn.SequentialCell( + nn.Dense(self.out_feats_dim, self.out_feats_dim), + nn.Dropout(args.dropout), + get_non_lin(args.nonlin, args.leakyrelu_neg_slope), + ) + + def __repr__(self): + return "IEGMN " + str(self.__dict__) + + def iegmn_layers_func(self, args, input_node_feats_dim, n_lays, fine_tune): + """ + IEGMN iegmn_layers_func + """ + + iegmn_layers = nn.CellList() + + iegmn_layers.append( + IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, + h_feats_dim=input_node_feats_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + fine_tune=fine_tune, + args=args, + log_input=self.log)) + + if args.shared_layers: + interm_lay = IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, + h_feats_dim=self.iegmn_lay_hid_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + args=args, + fine_tune=fine_tune, + log_input=self.log) + for _ in range(1, n_lays): + iegmn_layers.append(interm_lay) + + else: + for _ in range(1, n_lays): + iegmn_layers.append( + IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, + h_feats_dim=self.iegmn_lay_hid_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + args=args, + fine_tune=fine_tune, + log_input=self.log)) + + return iegmn_layers + + def att_weights_rot(self, h_feats, h_feats_att_mean_rot, d, z_coors, all_y_att_rot_list): + """ + IEGMN att_weights_rot + """ + after_key_rot = self.att_mlp_key_rot(h_feats).view(-1, self.num_att_heads, d).transpose(1, 0, 2) + after_query_rot = self.att_mlp_query_rot(h_feats_att_mean_rot).view(1, self.num_att_heads, d).transpose(1, 2, 0) + att_weights_rot = ops.softmax(after_key_rot @ after_query_rot / math.sqrt(d), axis=1) + att_weights_rot = att_weights_rot.view(self.num_att_heads, -1) + + y_att_rot = att_weights_rot @ z_coors # K_heads, 3 + all_y_att_rot_list.append(y_att_rot) + + return y_att_rot, all_y_att_rot_list + + def ap_compute(self, list_hetero_graph): + """ + IEGMN ap_compute + """ + all_t_align_list, all_b_align_list, all_y_receptor_att_rot_list, all_y_ligand_att_rot_list = [], [], [], [] + + for _, hetero_graph in enumerate(list_hetero_graph): + + # Get H vectors + h_receptor_feats = hetero_graph["receptor_hv_iegmn_out"] # (m, d) + h_receptor_feats_att_mean_rot = ops.mean(self.mlp_h_mean_rot(h_receptor_feats), axis=0, + keep_dims=True) # (1, d) + + h_ligand_feats = hetero_graph["ligand_hv_iegmn_out"] # (n, d) + h_ligand_feats_att_mean_rot = ops.mean(self.mlp_h_mean_rot(h_ligand_feats), axis=0, + keep_dims=True) # (1, d) + + d = h_ligand_feats.shape[1] + + # Z coordinates + z_receptor_coors = hetero_graph["receptor_x_iegmn_out"] + z_ligand_coors = hetero_graph["ligand_x_iegmn_out"] + + #### AP 1: compute two point clouds of K_heads points each, then do Kabsch ##### + # Att weights to compute the receptor centroid. + # They query is the average_h_ligand. Keys are each h_receptor_j. 
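att_weights_rot pools a protein's node coordinates into num_att_heads keypoints: keys come from that protein's node embeddings, the single query comes from the mean embedding of the other protein, and each head's attention weights produce one 3D point. A simplified NumPy sketch; random matrices stand in for the Dense layers, and for brevity the query is taken from the same point cloud rather than the partner protein:

import numpy as np
from scipy.special import softmax

rng = np.random.default_rng(0)
n_nodes, d, num_heads = 30, 16, 4                        # num_heads stands in for args.num_att_heads

h = rng.normal(size=(n_nodes, d))                         # node embeddings of one protein
coords = rng.normal(size=(n_nodes, 3))                    # node coordinates of the same protein
w_key = rng.normal(size=(d, num_heads * d))               # stand-ins for the Dense key/query layers
w_query = rng.normal(size=(d, num_heads * d))

keys = (h @ w_key).reshape(n_nodes, num_heads, d).transpose(1, 0, 2)                            # (H, N, d)
query = (h.mean(axis=0, keepdims=True) @ w_query).reshape(1, num_heads, d).transpose(1, 2, 0)   # (H, d, 1)

att = softmax(keys @ query / np.sqrt(d), axis=1).reshape(num_heads, n_nodes)   # one distribution per head
keypoints = att @ coords                                  # (H, 3): one attention-pooled point per head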
+ + y_receptor_att_rot, all_y_receptor_att_rot_list = self.att_weights_rot( + h_receptor_feats, h_ligand_feats_att_mean_rot, d, z_receptor_coors, all_y_receptor_att_rot_list + ) + + y_ligand_att_rot, all_y_ligand_att_rot_list = self.att_weights_rot( + h_ligand_feats, h_receptor_feats_att_mean_rot, d, z_ligand_coors, all_y_ligand_att_rot_list + ) + + ## Apply Kabsch algorithm + y_receptor_att_rot_mean = y_receptor_att_rot.mean(axis=0, keep_dims=True) # (1,3) + y_ligand_att_rot_mean = y_ligand_att_rot.mean(axis=0, keep_dims=True) # (1,3) + + a = (y_receptor_att_rot - y_receptor_att_rot_mean).transpose(1, 0) @ ( + y_ligand_att_rot - y_ligand_att_rot_mean) # 3, 3 + + if ops.isnan(a).any(): + raise ValueError("There is Nan in a") + u, s, vt = np.linalg.svd(a.asnumpy()) + u, s, vt = Tensor(u), Tensor(s), Tensor(vt) + + num_it = 0 + while ops.min(s)[0] < 1e-3 or ops.min(ops.abs((s ** 2).view(1, 3) - + (s ** 2).view(3, 1) + ops.eye(3)))[0] < 1e-2: + if self.debug: + self.log('S inside loop ', num_it, ' is ', s, ' and A = ', a) + + a = a + ops.rand(3, 3) * ops.eye(3) + u, s, vt = np.linalg.svd(a.asnumpy()) + u, s, vt = Tensor(u), Tensor(s), Tensor(vt) + num_it += 1 + + if num_it > 10: + self.log('SVD consistently numerically unstable! Exitting ... ') + raise ValueError('SVD consistently numerically unstable!') + + corr_mat = ops.diag(Tensor([1, 1, float(ops.sign(ops.det(a)))])) + t_align = (u @ corr_mat) @ vt + + b_align = y_receptor_att_rot_mean - ops.t(t_align @ y_ligand_att_rot_mean.t()) # (1,3) + + #################### end AP 1 ######################### + + if self.debug: + self.log('Y_receptor_att_ROT_mean', y_receptor_att_rot_mean) + self.log('Y_ligand_att_ROT_mean', y_ligand_att_rot_mean) + + all_t_align_list.append(t_align) + all_b_align_list.append(b_align) + + return [all_t_align_list, all_b_align_list, all_y_ligand_att_rot_list, all_y_receptor_att_rot_list] + + def construct( + self, + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ): + """ + IEGMN construct + """ + ligand_graph_edge_tensor = input_tensor_tuple[2] + receptor_graph_edge_tensor = input_tensor_tuple[3] + + orig_coors_ligand = ligand_graph_node_tensor[:, 4:7] + orig_coors_receptor = receptor_graph_node_tensor[:, 1:4] + + coors_ligand = ligand_graph_node_tensor[:, 4:7] + coors_receptor = receptor_graph_node_tensor[:, 1:4] + + ## Embed residue types with a lookup table. 
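The SVD step in ap_compute is the Kabsch algorithm: it fits the rotation and translation that best superimpose the ligand keypoints onto the receptor keypoints, with a sign correction to rule out reflections (the model code additionally perturbs the covariance when the SVD is near-degenerate, which the sketch omits). A NumPy sketch on toy keypoints:

import numpy as np

def kabsch(p, q):
    """Rotation r and translation b with r @ p_i + b ~= q_i, mirroring the SVD step above."""
    p_mean, q_mean = p.mean(axis=0, keepdims=True), q.mean(axis=0, keepdims=True)
    a = (q - q_mean).T @ (p - p_mean)                          # 3x3 cross-covariance
    u, _, vt = np.linalg.svd(a)
    corr = np.diag([1.0, 1.0, np.sign(np.linalg.det(a))])      # reflection correction
    r = u @ corr @ vt
    b = q_mean - (r @ p_mean.T).T
    return r, b

rng = np.random.default_rng(0)
p = rng.normal(size=(4, 3))                                     # toy ligand keypoints
true_r, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(true_r) < 0:                                   # make it a proper rotation
    true_r[:, 0] *= -1
q = (true_r @ p.T).T + np.array([1.0, -2.0, 0.5])               # toy receptor keypoints
r, b = kabsch(p, q)
assert np.allclose((r @ p.T).T + b, q, atol=1e-5)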
+ h_feats_ligand = self.residue_emb_layer( + ligand_graph_node_tensor[:, 0].view(-1).long()) # (N_res, emb_dim) + h_feats_receptor = self.residue_emb_layer( + receptor_graph_node_tensor[:, 0].view(-1).long()) # (N_res, emb_dim) + + if self.use_mean_node_features: + h_feats_ligand = ops.cat([h_feats_ligand, + ops.log(ligand_graph_node_tensor[:, 7:12])], axis=1) + h_feats_receptor = ops.cat([h_feats_receptor, + ops.log(receptor_graph_node_tensor[:, 4:9])], axis=1) + + original_ligand_node_features = h_feats_ligand + original_receptor_node_features = h_feats_receptor + + original_edge_feats_ligand = ligand_graph_edge_tensor * self.use_edge_features_in_gmn + original_edge_feats_receptor = receptor_graph_edge_tensor * self.use_edge_features_in_gmn + + for i, layer in enumerate(self.iegmn_layers): + if self.debug: + self.log('layer ', i) + + iegmn_layer_tuple = ( + coors_ligand, h_feats_ligand, original_ligand_node_features, original_edge_feats_ligand, + orig_coors_ligand, coors_receptor, h_feats_receptor, original_receptor_node_features, + original_edge_feats_receptor, orig_coors_receptor + ) + + coors_ligand, h_feats_ligand, coors_receptor, h_feats_receptor = \ + layer(iegmn_layer_tuple=iegmn_layer_tuple, input_tensor_tuple=input_tensor_tuple) + + if self.debug: + self.log(ops.max(h_feats_ligand)[0], 'h_feats_ligand before layers ') + self.log(ops.max(h_feats_ligand)[0], ops.norm(h_feats_ligand), + 'h_feats_ligand before layers but after mu_r_norm') + self.log(ops.max(h_feats_ligand)[0], 'h_feats_ligand after MPNN') + self.log(ops.max(coors_ligand)[0], 'coors_ligand before after MPNN') + + list_hetero_graph = unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor, + h_feats_ligand, coors_receptor, coors_ligand) + + ap_res = self.ap_compute(list_hetero_graph) + + return ap_res + + +class RigidBodyDockingNet(nn.Cell): + """ + RigidBodyDockingNet + """ + def __init__(self, args, log_input=None): + + super(RigidBodyDockingNet, self).__init__() + + self.debug = args.debug + self.log = log_input + + self.iegmn_original = IEGMN(args, n_lays=args.iegmn_n_lays, fine_tune=False, log_input=log_input) + if args.fine_tune: + self.iegmn_fine_tune = IEGMN(args, n_lays=2, fine_tune=True, log_input=log_input) + self.list_iegmns = [('original', self.iegmn_original), ('finetune', self.iegmn_fine_tune)] + else: + self.list_iegmns = [('finetune', self.iegmn_original)] # just original + + def __repr__(self): + return "RigidBodyDockingNet " + str(self.__dict__) + + ####### FORWARD for RigidBodyDockingNet + def construct( + self, + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ): + """ + construct + """ + last_outputs = None + all_ligand_coors_deform_list = [] + + for stage, iegmn in self.list_iegmns: + outputs = iegmn( + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ) + + if len(outputs) != 4: + raise ValueError("Number of outputs not correct") + + if stage == 'finetune': + last_outputs = outputs + + if stage == 'original': + new_list_hetero_graph = [] + + list_hetero_graph = [] + for i, _ in enumerate(unbatch_list_tensor): + if i < len(unbatch_list_tensor) - 1: + list_hetero_graph.append( + [ + ligand_graph_node_tensor[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :], + receptor_graph_node_tensor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :], + ] + ) + + for the_idx, hetero_graph in enumerate(list_hetero_graph): + orig_coors_ligand = hetero_graph[0][:, 4:7] + + t_align = 
outputs[0][the_idx] + b_align = outputs[1][the_idx] + if b_align.shape[0] != 1 or b_align.shape[1] != 3: + raise ValueError("Shape not correct") + + inner_coors_ligand = (t_align @ orig_coors_ligand.t()).t() + b_align # (n,3) + + if stage == 'original': + hetero_graph[0][:, 4:7] = inner_coors_ligand + new_list_hetero_graph.append(hetero_graph) + + if self.debug: + self.log('T_align', t_align) + self.log('T_align @ T_align.t() - eye(3)', t_align @ t_align.t() - ops.eye(3)) + self.log('b_align', b_align) + self.log('\n ---> inner_coors_ligand mean - true ligand mean ', + inner_coors_ligand.mean(axis=0)[0] - ligand_graph_node_tensor[:, 1:4].mean(axis=0)[0], + '\n') + + if stage == 'finetune': + all_ligand_coors_deform_list.append(inner_coors_ligand) + + all_keypts_ligand_list = last_outputs[2] + all_keypts_receptor_list = last_outputs[3] + all_rotation_list = last_outputs[0] + all_translation_list = last_outputs[1] + + return all_ligand_coors_deform_list, \ + all_keypts_ligand_list, all_keypts_receptor_list, \ + all_rotation_list, all_translation_list diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py new file mode 100644 index 000000000..b0d422497 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py @@ -0,0 +1,236 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""train_utils""" + +import os +import math + +import ot +import numpy as np +import mindspore as ms +from mindspore import ops, Tensor + +from .nn_arch import ( + preprocess_unbound_bound, + protein_to_graph_unbound_bound, + get_residues, + log, +) + + +def create_dir(path): + if os.path.exists(path): + raise FileExistsError('Path already exists. 
Please delete and restart your job.') + os.makedirs(path, exist_ok=False) + + +def graph_to_tensor(ligand_graph, receptor_graph): + """ + graph_to_tensor + """ + ligand_graph.ndata['new_x'] = ligand_graph.ndata['x'] + + # Create a batch of a single heterograph + ligand_graph_node_tensor = ops.cat( + (ligand_graph.ndata["res_feat"], + ligand_graph.ndata["x"], + ligand_graph.ndata["new_x"], + ligand_graph.ndata["mu_r_norm"]), + axis=1 + ) + + receptor_graph_node_tensor = ops.cat( + (receptor_graph.ndata["res_feat"], + receptor_graph.ndata["x"], + receptor_graph.ndata["mu_r_norm"]), + axis=1 + ) + + ligand_graph_num_nodes = Tensor([ligand_graph.num_nodes]) + receptor_graph_num_nodes = Tensor([receptor_graph.num_nodes]) + ligand_graph_edge_tensor = ligand_graph.edata["he"] + receptor_graph_edge_tensor = receptor_graph.edata["he"] + + ll_connection_tensor = ops.stack( + (Tensor(ligand_graph.src_list), Tensor(ligand_graph.dst_list))) + + rr_connection_tensor = ops.stack( + (Tensor(receptor_graph.src_list), Tensor(receptor_graph.dst_list))) + + unbatch_list = [ + [ + len(ll_connection_tensor[0]), + len(rr_connection_tensor[0]), + len(ligand_graph_node_tensor), + len(receptor_graph_node_tensor), + 0, + ] + ] + unbatch_list.insert(0, np.array([0, 0, 0, 0, 0])) + + unbatch_list = Tensor(unbatch_list, ms.int32) + + input_tensor_tuple = ( + ligand_graph_num_nodes, + receptor_graph_num_nodes, + ligand_graph_edge_tensor, + receptor_graph_edge_tensor, + ll_connection_tensor, + rr_connection_tensor, + ) + + return ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, input_tensor_tuple + + +def prepare_graphs(args, ppdb_ligand, ligand_filename, receptor_filename): + """ + prepare_graphs + """ + unbound_ligand_all_atoms_pre_pos = ppdb_ligand.df["ATOM"][ + ['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32) + + ligand_input = get_residues(ligand_filename) + receptor_input = get_residues(receptor_filename) + input_tuple = (ligand_input, receptor_input) + unbound_predic_ligand, \ + unbound_predic_receptor, \ + bound_ligand_repres_nodes_loc_clean_array, \ + bound_receptor_repres_nodes_loc_clean_array, _ = preprocess_unbound_bound( + input_residues_tuple=input_tuple, + graph_nodes=args.graph_nodes, pos_cutoff=args.pocket_cutoff, inference=True) + + protein_to_graph_unbound_bound_input_tuple = ( + unbound_predic_ligand, unbound_predic_receptor, + bound_ligand_repres_nodes_loc_clean_array, bound_receptor_repres_nodes_loc_clean_array, + ) + + ligand_graph, receptor_graph = protein_to_graph_unbound_bound( + protein_to_graph_unbound_bound_input_tuple, + cutoff=args.graph_cutoff, + max_neighbor=args.graph_max_neighbor, + one_hot=False, + residue_loc_is_alphac=args.graph_residue_loc_is_alphaC, + ) + + return ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, bound_ligand_repres_nodes_loc_clean_array + + +def get_rot_mat(euler_angles): + """ + get_rot_mat + """ + roll = euler_angles[0] + yaw = euler_angles[1] + pitch = euler_angles[2] + + tensor_0 = Tensor(0.0) + tensor_1 = Tensor(1.0) + cos = ops.cos + sin = ops.sin + + rx = ops.stack([ + ops.stack([tensor_1, tensor_0, tensor_0]), + ops.stack([tensor_0, cos(roll), -sin(roll)]), + ops.stack([tensor_0, sin(roll), cos(roll)])]).reshape(3, 3) + + ry = ops.stack([ + ops.stack([cos(pitch), tensor_0, sin(pitch)]), + ops.stack([tensor_0, tensor_1, tensor_0]), + ops.stack([-sin(pitch), tensor_0, cos(pitch)])]).reshape(3, 3) + + rz = ops.stack([ + ops.stack([cos(yaw), -sin(yaw), tensor_0]), + ops.stack([sin(yaw), cos(yaw), 
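get_rot_mat composes the rotation as rz(yaw) @ ry(pitch) @ rx(roll). The NumPy equivalent below makes the convention explicit and checks that the result is a proper rotation; the angles are arbitrary test values:

import numpy as np

def rot_from_euler(roll, yaw, pitch):
    """Compose r = rz(yaw) @ ry(pitch) @ rx(roll), the same convention as get_rot_mat."""
    cr, sr = np.cos(roll), np.sin(roll)
    cp, sp = np.cos(pitch), np.sin(pitch)
    cy, sy = np.cos(yaw), np.sin(yaw)
    rx = np.array([[1, 0, 0], [0, cr, -sr], [0, sr, cr]])
    ry = np.array([[cp, 0, sp], [0, 1, 0], [-sp, 0, cp]])
    rz = np.array([[cy, -sy, 0], [sy, cy, 0], [0, 0, 1]])
    return rz @ ry @ rx

r = rot_from_euler(0.1, 0.7, -0.3)
assert np.allclose(r @ r.T, np.eye(3), atol=1e-10) and np.isclose(np.linalg.det(r), 1.0)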
tensor_0]), + ops.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3) + + r = ops.mm(rz, ry) + r = ops.mm(r, rx) + + return r + + +def compute_ot_emd(cost_mat): + cost_mat_detach = cost_mat.asnumpy() + a = np.ones([cost_mat.shape[0]]) / cost_mat.shape[0] + b = np.ones([cost_mat.shape[1]]) / cost_mat.shape[1] + ot_mat = ot.emd(a=a, b=b, M=cost_mat_detach, numItermax=10000) + ot_mat_attached = Tensor(ot_mat, ms.float32) + ot_dist = ops.sum(ot_mat_attached * cost_mat) + return ot_dist, ot_mat_attached + + +def g_fn(protein_coords, x, sigma): + # protein_coords: (n,3) , x: (m,3), output: (m,) + e = ops.exp(- ops.sum((protein_coords.view(1, -1, 3) - x.view(-1, 1, 3)) ** 2, dim=2) / float(sigma)) # (m, n) + + return - sigma * ops.log(1e-3 + ops.sum(e, dim=1)) + + +def compute_body_intersection_loss( + model_ligand_coors_deform, + bound_receptor_repres_nodes_loc_array, + sigma, + surface_ct, + ): + """ + compute_body_intersection_loss + """ + g_fn_out1 = g_fn(Tensor(bound_receptor_repres_nodes_loc_array), model_ligand_coors_deform, sigma) + g_fn_out2 = g_fn(model_ligand_coors_deform, Tensor(bound_receptor_repres_nodes_loc_array), sigma) + loss1 = ops.clamp(surface_ct - g_fn_out1, min=0) + loss2 = ops.clamp(surface_ct - g_fn_out2, min=0) + loss = ops.mean(loss1) + ops.mean(loss2) + + return loss + + +def compute_sq_dist_mat(x_1, x_2): + '''Computes the l2 squared cost matrix between two point cloud inputs. + Args: + X_1: [n, #features] point cloud, tensor + X_2: [m, #features] point cloud, tensor + Output: + [n, m] matrix of the l2 distance between point pairs + ''' + n_1, _ = x_1.shape + n_2, _ = x_2.shape + x_1 = x_1.view(n_1, 1, -1) + x_2 = x_2.view(1, n_2, -1) + squared_dist = (x_1 - x_2) ** 2 + cost_mat = ops.sum(squared_dist, dim=2) + return cost_mat + + +def pretty_print_stats(*args): + + split_type, epoch, total_num_epochs,\ + complex_rmsd_mean, complex_rmsd_median,\ + ligand_rmsd_mean, ligand_rmsd_median,\ + receptor_rmsd_mean, receptor_rmsd_median,\ + _, _, _, avg_loss_ot, avg_loss_intersection, _ = args + + log('[{:s}] --> epoch {:d}/{:d} ' + '|| mean/median complex rmsd {:.4f} / {:.4f} ' + '|| mean/median ligand rmsd {:.4f} / {:.4f} ' + '|| mean/median sqrt pocket OT loss {:.4f} ' + '|| intersection loss {:.4f} ' + '|| mean/median receptor rmsd {:.4f} / {:.4f} '. + format(split_type, + epoch, total_num_epochs, + complex_rmsd_mean, complex_rmsd_median, + ligand_rmsd_mean, ligand_rmsd_median, + math.sqrt(avg_loss_ot), + avg_loss_intersection, + receptor_rmsd_mean, receptor_rmsd_median)) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py new file mode 100644 index 000000000..fc1d55347 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
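compute_ot_emd and compute_sq_dist_mat above implement the pocket optimal-transport loss using the POT package (the ot import, pinned as POT == 0.9.3 in requirements.txt). A standalone sketch of the same computation on toy point clouds:

import numpy as np
import ot                                                    # POT, as pinned in requirements.txt

rng = np.random.default_rng(0)
pred_pocket = rng.normal(size=(12, 3))                       # predicted ligand-side pocket points (toy)
true_pocket = rng.normal(size=(15, 3))                       # ground-truth pocket points (toy)

# squared Euclidean cost matrix, as in compute_sq_dist_mat
cost = ((pred_pocket[:, None, :] - true_pocket[None, :, :]) ** 2).sum(axis=2)   # (12, 15)

a = np.ones(cost.shape[0]) / cost.shape[0]                   # uniform marginals
b = np.ones(cost.shape[1]) / cost.shape[1]
plan = ot.emd(a=a, b=b, M=cost, numItermax=10000)            # exact optimal-transport plan
ot_loss = float((plan * cost).sum())                          # the pocket OT loss used in training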
+# ============================================================================ +"""progen""" + +from .progen import ProGen +from .progen_dataset import ProGenDataSet +from .progen_configuration import progen_configuration diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/__init__.py new file mode 100644 index 000000000..0772554a6 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""module""" diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py new file mode 100644 index 000000000..dd34b9e75 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py @@ -0,0 +1,2295 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""configuration_utils""" + +import copy +from copy import deepcopy +import json +import os +import warnings +import inspect +from typing import Optional, List, Callable, Dict, Any, Tuple, Union +from enum import Enum +from collections import OrderedDict, UserDict +from dataclasses import fields +from dataclasses import dataclass + +import numpy as np +import mindspore +from mindspore import nn, ops, Tensor, jit_class + +from .logits_process import ( + ForcedEOSTokenLogitsProcessor, + LogitsProcessorList, + MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, +) + +DEFAULT_DTYPE = mindspore.float32 +INIT_WEIGHTS_FLAG = True + + +def _is_mindspore(x): + + return isinstance(x, mindspore.Tensor) + +def is_mindspore_available(): + return mindspore.get_context('device_target') == 'Ascend' + +def is_mindspore_tensor(x): + """ + Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed. 
+ """ + return False if not is_mindspore_available() else _is_mindspore(x) + +def no_grad(func): + """no grad wrapper""" + def wrapper(*args, **kwargs): + _pynative_executor.set_enable_grad(False) + outputs = func(*args, **kwargs) + _pynative_executor.set_enable_grad(True) + return outputs + return wrapper + + +class CellUtilMixin: + """ + A few utilities to be used as a mixin. + """ + + def get_head_mask( + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False + ) -> Tensor: + """ + Prepare the head mask if needed. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.expand_dims(-1) + else: + head_mask = () + for _ in range(num_hidden_layers): + head_mask += (None,) + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.ndim == 1: + head_mask = head_mask.expand_dims(0).expand_dims(0).expand_dims(-1).expand_dims(-1) + head_mask = head_mask.broadcast_to(num_hidden_layers, -1, -1, -1, -1) + elif head_mask.ndim == 2: + head_mask = head_mask.expand_dims(1).expand_dims(-1)\ + .expand_dims(-1) # We can specify head_mask for each layer + assert head_mask.ndim == 5, f"head_mask.dim != 5, instead {head_mask.ndim}" + head_mask = head_mask.astype(dtype=self.dtype) # switch to float if need + fp16 compatibility + return head_mask + + @property + def dtype(self) -> mindspore.TensorType: + """ + `mindspore.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + +class ModelOutputMindnlp(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple + before. + + + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + if not class_fields: + raise ValueError(f"{self.__class__.__name__} has no fields.") + if not all(field.default is None for field in class_fields[1:]): + raise ValueError(f"{self.__class__.__name__} should not have more than one required field.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and not is_tensor(first_field): + if isinstance(first_field, dict): + iterator = first_field.items() + first_field_iterator = True + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False + + # if we provided an iterator as first field and the iterator is a (key, value) iterator + # set the associated fields + if first_field_iterator: + for idx, element in enumerate(iterator): + if ( + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) + ): + if idx == 0: + # If we do not have an iterator of key/values, set it as attribute + self[class_fields[0].name] = first_field + else: + # If we have a mixed iterator, raise an error + raise ValueError( + f"Cannot set key/value for {element}. It needs to be a tuple (key, value)." 
+ ) + break + setattr(self, element[0], element[1]) + if element[1] is not None: + self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise RuntimeError(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise RuntimeError(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise RuntimeError(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise RuntimeError(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(v for _, v in self.items()) + + +@dataclass +class GenerateDecoderOnlyOutput(ModelOutputMindnlp): + """ + Outputs of decoder-only generation models, when using non-beam methods. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + logits: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[mindspore.Tensor]]]] = None + + +@dataclass +class GenerateEncoderDecoderOutput(ModelOutputMindnlp): + """ + Outputs of encoder-decoder generation models, when using non-beam methods. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + logits: Optional[Tuple[mindspore.Tensor]] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[mindspore.Tensor]]]] = None + + +@dataclass +class SampleDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using sampling. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class SampleEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of encoder-decoder generation models using sampling. 
Hidden states and attention weights of + the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +@dataclass +class BeamSampleDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using beam sample. + """ + + sequences: mindspore.Tensor = None + sequences_scores: Optional[mindspore.Tensor] = None + scores: Optional[Tuple[mindspore.Tensor]] = None + beam_indices: Optional[mindspore.Tensor] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class BeamSampleEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention + weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + """ + + sequences: mindspore.Tensor = None + sequences_scores: Optional[mindspore.Tensor] = None + scores: Optional[Tuple[mindspore.Tensor]] = None + beam_indices: Optional[mindspore.Tensor] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class BeamSearchDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using beam search. + """ + + sequences: mindspore.Tensor = None + sequences_scores: Optional[mindspore.Tensor] = None + scores: Optional[Tuple[mindspore.Tensor]] = None + beam_indices: Optional[mindspore.Tensor] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class BeamSearchEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of encoder-decoder generation models using beam search. 
Hidden states and attention weights + of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + """ + + sequences: mindspore.Tensor = None + sequences_scores: Optional[mindspore.Tensor] = None + scores: Optional[Tuple[mindspore.Tensor]] = None + beam_indices: Optional[mindspore.Tensor] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +@dataclass +class GreedySearchDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using greedy search. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +@dataclass +class GreedySearchEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention + weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +@dataclass +class ContrastiveSearchEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using contrastive search. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class ContrastiveSearchDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using contrastive search. 
+ """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] +BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] +BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] +ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] +GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput] + + +class StoppingCriteriaList(list): + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + return any(criteria(input_ids, scores) for criteria in self) + + @property + def max_length(self) -> Optional[int]: + for stopping_criterium in self: + if not isinstance(stopping_criterium, MaxLengthCriteria): + raise TypeError + return stopping_criterium.max_length + + +class GenerationConfig: + """ + Class that holds a configuration for a generation task. + """ + def __init__(self, **kwargs): + # Parameters that control the length of the output + self.max_length = kwargs.pop("max_length", 20) + self.max_new_tokens = kwargs.pop("max_new_tokens", None) + self.min_length = kwargs.pop("min_length", 0) + self.min_new_tokens = kwargs.pop("min_new_tokens", None) + self.early_stopping = kwargs.pop("early_stopping", False) + self.use_cache = kwargs.pop("use_cache", True) + self.bad_words_ids = kwargs.pop("bad_words_ids", None) + + # # Parameters that define the output variables of `generate` + self.output_attentions = kwargs.pop("output_attentions", False) + self.output_hidden_states = kwargs.pop("output_hidden_states", False) + self.output_scores = kwargs.pop("output_scores", False) + self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) + + # Special tokens that can be used at generation time + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self._from_model_config = kwargs.pop("_from_model_config", False) + + # Additional attributes without default values + if not self._from_model_config: + # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a + # model's default configuration file + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error("Can't set %s with value %s", key, value) + raise err + + # Validate the values of the attributes + self.validate() + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": + """ + Instantiates a [`GenerationConfig`] from a Python dictionary of parameters. + """ + # Those arguments may be passed along for our internal telemetry. + # We remove them so they don't appear in `return_unused_kwargs`. + kwargs.pop("_from_auto", None) + kwargs.pop("_from_pipeline", None) + # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update. 
+ commit_hash_str = "_commit_hash" + if commit_hash_str in kwargs and commit_hash_str in config_dict: + kwargs[commit_hash_str] = config_dict[commit_hash_str] + + config = cls(**config_dict) + unused_kwargs = config.update(**kwargs) + + return config + + @classmethod + def from_model_config(cls, model_config) -> "GenerationConfig": + """ + Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy + [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`]. + """ + config_dict = model_config.to_dict() + config = cls.from_dict(config_dict, return_unused_kwargs=False) + config.set_from_model_config(True) + + return config + + + def set_from_model_config(self, value: bool): + """set _from_model_config""" + if not isinstance(value, bool): + raise TypeError + self._from_model_config = value + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, + returning all the unused kwargs. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + def validate(self): + """ + Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence + of parameterization that can be detected as incorrect from the configuration instance alone. + + Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the + model, such as parameters related to the generation length. + """ + + # Validation of individual attributes + if self.early_stopping not in {True, False, "never"}: + raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.") + + +class GenerationMixin: + """ + class GenerationMixin + A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. + """ + + @staticmethod + def _expand_inputs_for_generation( + expand_size: int = 1, + input_ids: Optional[mindspore.Tensor] = None, + **model_kwargs, + ) -> Tuple[mindspore.Tensor, Dict[str, Any]]: + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], mindspore.Tensor): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + return input_ids, model_kwargs + + @staticmethod + def prepare_inputs_for_generation(*args, **kwargs): + """ + prepare_inputs_for_generation + """ + raise NotImplementedError( + "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." 
+ ) + + @classmethod + def _maybe_initialize_input_ids_for_generation( + cls, + inputs: Optional[mindspore.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, mindspore.Tensor]] = None, + ) -> mindspore.Tensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, mindspore.Tensor): + batch_size = value.shape[0] + break + return ops.ones((batch_size, 1), dtype=mindspore.int64) * bos_token_id + + @classmethod + def _prepare_decoder_input_ids_for_generation( + cls, + model_input_name: str, + model_kwargs: Dict[str, mindspore.Tensor], + ) -> Tuple[mindspore.Tensor, Dict[str, mindspore.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + input_ids_str = "input_ids_str" + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif input_ids_str in model_kwargs and model_input_name != input_ids_str: + decoder_input_ids = model_kwargs.pop(input_ids_str) + else: + decoder_input_ids = None + + return decoder_input_ids, model_kwargs + + @classmethod + def _merge_criteria_processor_list( + cls, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + """" + merge the criteria processor list + """ + if not custom_list: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" + f" `.generate()`, but it has already been created with the values {default}. {default} has been" + " created by passing the corresponding arguments to generate or by the model's config default" + f" values. If you just want to change the default values of {object_type} consider passing" + f" them as arguments to `.generate()` instead of using a custom {object_type}." + ) + default_list.extend(custom_list) + return default_list + + def _get_logits_warper( + self, + generation_config: GenerationConfig, + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances + used for multinomial sampling. + """ + + # instantiate warpers list + warpers = LogitsProcessorList() + + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. 
keep len(list(generation_config.eos_token_id)) + 1) + if generation_config.num_beams > 1: + if isinstance(generation_config.eos_token_id, list): + min_tokens_to_keep = len(generation_config.eos_token_id) + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + warpers.append( + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + return warpers + + def generate( + self, + inputs: Optional[mindspore.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, + ) -> Union[GreedySearchDecoderOnlyOutput, mindspore.Tensor]: + """ + generate method + """ + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + generation_config = copy.deepcopy(self.generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() + self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + # 3. Define model inputs + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs) + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + model_kwargs["use_cache"] = True + else: + model_kwargs["use_cache"] = generation_config.use_cache + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if not self.config.is_encoder_decoder: + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + if streamer is not None: + streamer.put(input_ids) + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. 
prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + logits_processor=logits_processor) + + # 8. prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria) + + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + **model_kwargs, + ) + + # 13. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + streamer=streamer, + **model_kwargs, + ) + + def sample( + self, + input_ids: mindspore.Tensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, + ) -> Union[SampleOutput, mindspore.Tensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 
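+
+        A minimal usage sketch (hypothetical names: `model` is a causal LM exposing this mixin and
+        `input_ids` is an already tokenized prompt; the warper and criteria values are illustrative only):
+
+        >>> warpers = LogitsProcessorList([TemperatureLogitsWarper(0.8), TopPLogitsWarper(top_p=0.95)])
+        >>> criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
+        >>> output_ids = model.sample(
+        ...     input_ids,
+        ...     logits_warper=warpers,
+        ...     stopping_criteria=criteria,
+        ...     pad_token_id=0,
+        ...     eos_token_id=2,
+        ... )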
+ """ + # init values + synced_gpus = None + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = mindspore.tensor(eos_token_id) if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = ops.ones(input_ids.shape[0], dtype=mindspore.int64) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation_new(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self.construct( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + next_token_logits = outputs.logits[:, -1, :] + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += (outputs.cross_attentions,) + + if output_hidden_states: 
+ decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = ops.softmax(next_token_scores, axis=-1) + next_tokens = ops.multinomial(probs, num_samples=1).squeeze(1).astype(mindspore.int64) + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + # update generated ids, model inputs, and length for next step + input_ids = ops.cat([input_ids, next_tokens[:, None]], axis=-1) + if streamer is not None: + streamer.put(next_tokens) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile((eos_token_id_tensor.shape[0], 1)).ne( + eos_token_id_tensor.unsqueeze(1)).prod(axis=0) + ) + + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + if streamer is not None: + streamer.end() + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return SampleEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + return SampleDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + return input_ids + + def _prepare_model_inputs( + self, + inputs: Optional[mindspore.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, mindspore.Tensor]] = None, + ) -> Tuple[mindspore.Tensor, Optional[str], Dict[str, mindspore.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if not self.config.is_encoder_decoder: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " + f"Make sure to either pass {inputs} or {input_name}=..." + ) + if inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. In the presence of `inputs_embeds` for text models: + # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model + # doesn't have its forwarding implemented. 
`inputs_embeds` is kept in `model_kwargs` and can coexist with + # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) + inputs_embeds_str = "inputs_embeds" + input_ids_str = "input_ids" + if input_name == input_ids_str and inputs_embeds_str in model_kwargs: + if not self.config.is_encoder_decoder: + has_inputs_embeds_forwarding = inputs_embeds_str in set( + inspect.signature(self.prepare_inputs_for_generation).parameters.keys() + ) + if not has_inputs_embeds_forwarding: + raise ValueError( + f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " + "doesn't have its forwarding implemented. See the GPT2 implementation for an example " + "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" + ) + # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of + # the attention mask) can rely on the actual model input. + model_kwargs[input_ids_str] = self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs=model_kwargs + ) + else: + if inputs is not None: + raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") + inputs, input_name = model_kwargs[inputs_embeds_str][0], model_kwargs[inputs_embeds_str][1] + + # 4. if `inputs` is still None, try to create `input_ids` from BOS token + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) + return inputs, input_name, model_kwargs + + def _extract_past_from_model_output(self, outputs: ModelOutputMindnlp, standardize_cache_format: bool = False): + """ + extract past from model output + """ + past_key_values = None + if "past_key_values" in outputs: + past_key_values = outputs.past_key_values + elif "mems" in outputs: + past_key_values = outputs.mems + elif "past_buckets_states" in outputs: + past_key_values = outputs.past_buckets_states + + # Bloom fix: standardizes the cache format when requested + if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"): + batch_size = outputs.logits.shape[0] + past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size) + return past_key_values + + def _update_model_kwargs_for_generation( + self, + outputs, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + """ + update past_key_values and update token_type_ids with last value + """ + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + model_kwargs["is_encoder_decoder"] = is_encoder_decoder + + token_type_str = "token_type_ids" + if token_type_str in model_kwargs: + token_type_ids = model_kwargs[token_type_str] + model_kwargs[token_type_str] = ops.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1) + + return model_kwargs + + def _get_logits_processor( + self, + generation_config: GenerationConfig, + input_ids_seq_length: int, + logits_processor: Optional[LogitsProcessorList], + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] + instances used to modify the scores of the language model head. 
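+
+        A rough sketch of the expected behaviour (`model` is a hypothetical `PreTrainedModelMindnlp`
+        subclass; the values are illustrative only):
+
+        >>> cfg = GenerationConfig.from_model_config(model.config)
+        >>> _ = cfg.update(min_length=10, eos_token_id=2)
+        >>> processors = model._get_logits_processor(
+        ...     generation_config=cfg, input_ids_seq_length=1, logits_processor=LogitsProcessorList()
+        ... )
+        >>> # `processors` now contains a MinLengthLogitsProcessor (and, when set in the config,
+        >>> # a NoBadWordsLogitsProcessor / ForcedEOSTokenLogitsProcessor), merged with any user-supplied list.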
+ """ + # instantiate processors list + processors = LogitsProcessorList() + + if generation_config.bad_words_ids is not None: + processors.append( + NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) + ) + if ( + generation_config.min_length is not None + and generation_config.eos_token_id is not None + and generation_config.min_length > 0 + ): + processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) + if ( + generation_config.min_new_tokens is not None + and generation_config.eos_token_id is not None + and generation_config.min_new_tokens > 0 + ): + processors.append( + MinNewTokensLengthLogitsProcessor( + input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id + ) + ) + + if generation_config.forced_eos_token_id is not None: + processors.append( + ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) + ) + + processors = self._merge_criteria_processor_list(processors, logits_processor) + + return processors + + def _get_stopping_criteria( + self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] + ) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if generation_config.max_length is not None: + criteria.append(MaxLengthCriteria(max_length=generation_config.max_length)) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) + return criteria + + def _validate_model_class(self): + """ + Confirms that the model class is compatible with generation. If not, raises an exception that points to the + right class to use. + """ + if not self.can_generate(): + raise NotImplementedError( + "TODO: You need to implement this function." + ) + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + # Excludes arguments that are handled before calling any model function + if self.config.is_encoder_decoder: + for key in ["decoder_input_ids"]: + model_kwargs.pop(key, None) + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + """Performs validation related to the resulting generated length""" + + # 1. Max length warnings related to poor parameterization + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: + # 20 is the default max_length of the generation config + warnings.warn( + f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " + "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " + "generation.", + UserWarning, + ) + if input_ids_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + warnings.warn( + f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`.", + UserWarning, + ) + + # 2. Min length warnings due to unfeasible parameter combinations + min_length_error_suffix = ( + " Generation will stop at the defined maximum length. You should decrease the minimum length and/or " + "increase the maximum length." 
+ ) + if has_default_max_length: + min_length_error_suffix += ( + f" Note that `max_length` is set to {generation_config.max_length}, its default value." + ) + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + warnings.warn( + f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than" + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + UserWarning, + ) + if generation_config.min_new_tokens is not None: + min_length = generation_config.min_new_tokens + input_ids_length + if min_length > generation_config.max_length: + warnings.warn( + f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " + f"added to the prompt length ({input_ids_length}), is larger than" + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + UserWarning, + ) + + +class ExplicitEnum(str, Enum): + """ + Enum with more explicit error message for missing values. + """ + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" + ) + + +class GenerationMode(ExplicitEnum): + """ + Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method. + """ + + # Non-beam methods + CONTRASTIVE_SEARCH = "contrastive_search" + GREEDY_SEARCH = "greedy_search" + SAMPLE = "sample" + ASSISTED_GENERATION = "assisted_generation" + # Beam methods + BEAM_SEARCH = "beam_search" + BEAM_SAMPLE = "beam_sample" + CONSTRAINED_BEAM_SEARCH = "constrained_beam_search" + GROUP_BEAM_SEARCH = "group_beam_search" + + +@jit_class +class PretrainedConfig: + """ + Abstract class for Pretrained models config. + """ + is_composition = False + # Add for handle attribute_map + attribute_map: Dict[str, str] = {} + + def __init__(self, **kwargs): + self.ms_dtype = kwargs.pop("ms_dtype", None) + if 'torch_dtype' in kwargs: + self.ms_dtype = kwargs.pop("torch_dtype", None) + self.return_dict = kwargs.pop("return_dict", True) + self.output_hidden_states = kwargs.pop("output_hidden_states", False) + self.output_attentions = kwargs.pop("output_attentions", False) + + self.pruned_heads = kwargs.pop("pruned_heads", {}) + self.tie_word_embeddings = kwargs.pop( + "tie_word_embeddings", True + ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models. 
+ + # Is decoder is used in encoder-decoder models to differentiate encoder from decoder + self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) + self.is_decoder = kwargs.pop("is_decoder", False) + self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None) + self.add_cross_attention = kwargs.pop("add_cross_attention", False) + self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False) + + # Parameters for sequence generation + self.max_length = kwargs.pop("max_length", 20) + self.min_length = kwargs.pop("min_length", 0) + self.do_sample = kwargs.pop("do_sample", False) + self.early_stopping = kwargs.pop("early_stopping", False) + self.num_beams = kwargs.pop("num_beams", 1) + self.num_beam_groups = kwargs.pop("num_beam_groups", 1) + self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0) + self.temperature = kwargs.pop("temperature", 1.0) + self.top_k = kwargs.pop("top_k", 50) + self.top_p = kwargs.pop("top_p", 1.0) + self.typical_p = kwargs.pop("typical_p", 1.0) + self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) + self.length_penalty = kwargs.pop("length_penalty", 1.0) + self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) + self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0) + self.bad_words_ids = kwargs.pop("bad_words_ids", None) + self.num_return_sequences = kwargs.pop("num_return_sequences", 1) + self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0) + self.output_scores = kwargs.pop("output_scores", False) + self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) + self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) + self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) + self.remove_invalid_values = kwargs.pop("remove_invalid_values", False) + self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None) + self.suppress_tokens = kwargs.pop("suppress_tokens", None) + self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None) + + # Fine-tuning task arguments + self.architectures = kwargs.pop("architectures", None) + self.finetuning_task = kwargs.pop("finetuning_task", None) + self.id2label = kwargs.pop("id2label", None) + self.label2id = kwargs.pop("label2id", None) + if self.label2id is not None and not isinstance(self.label2id, dict): + raise ValueError("Argument label2id should be a dictionary.") + if self.id2label is not None: + if not isinstance(self.id2label, dict): + raise ValueError("Argument id2label should be a dictionary.") + num_labels = kwargs.pop("num_labels", None) + + if num_labels is not None and len(self.id2label) != num_labels: + logger.warning( + f"You passed along `num_labels={num_labels}` with an incompatible id to label map: " + f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}." + ) + self.id2label = {int(key): value for key, value in self.id2label.items()} + # Keys are always strings in JSON so convert ids to int here. 
+ else: + self.num_labels = kwargs.pop("num_labels", 2) + + if self.ms_dtype is not None and isinstance(self.ms_dtype, str): + if is_mindspore_available(): + self.ms_dtype = getattr(mindspore, self.ms_dtype) + + # Tokenizer arguments TODO: eventually tokenizer and models should share the same config + self.tokenizer_class = kwargs.pop("tokenizer_class", None) + self.prefix = kwargs.pop("prefix", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self.sep_token_id = kwargs.pop("sep_token_id", None) + + self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + + # task specific arguments + self.task_specific_params = kwargs.pop("task_specific_params", None) + + # regression / multi-label classification + self.problem_type = kwargs.pop("problem_type", None) + allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification") + if self.problem_type is not None and self.problem_type not in allowed_problem_types: + raise ValueError( + f"The config parameter `problem_type` was not understood: received {self.problem_type} " + "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid." + ) + + # Name or path to the pretrained checkpoint + self._name_or_path = str(kwargs.pop("name_or_path", "")) + + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + def __setattr__(self, key, value): + if key in super().__getattribute__("attribute_map"): + key = super().__getattribute__("attribute_map")[key] + super().__setattr__(key, value) + + def __getattribute__(self, key): + if key != "attribute_map" and key in super().__getattribute__("attribute_map"): + key = super().__getattribute__("attribute_map")[key] + return super().__getattribute__(key) + + @property + def name_or_path(self) -> str: + """get name_or_path""" + return getattr(self, "_name_or_path", None) + + @name_or_path.setter + def name_or_path(self, value): + """set name_or_path""" + self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding) + + @classmethod + def from_json(cls, file_path): + """load config from json.""" + with open(file_path, "r", encoding="utf-8") as file: + text = file.read() + config_map = json.loads(text) + config = cls() + for key, value in config_map.items(): + setattr(config, key, value) + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `Config` from a json file of parameters.""" + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + dict_obj = json.loads(text) + return cls(**dict_obj) + + @classmethod + def load(cls, pretrained_model_name_or_path): + """load config.""" + return cls.from_pretrained(pretrained_model_name_or_path) + + @property + def use_return_dict(self) -> bool: + """ + `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples. + """ + # If torchscript is set, force `return_dict=False` to avoid jit errors + return self.return_dict + + @classmethod + def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig": + """ + Constructs a `Config` from a Python dictionary of parameters. 
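+
+        A small illustrative sketch (the keys are ordinary `PretrainedConfig` attributes; pass
+        `return_unused_kwargs=True` to also get back the kwargs that did not match any attribute):
+
+        >>> config = PretrainedConfig.from_dict({"max_length": 128, "num_beams": 4})
+        >>> config.max_length
+        128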
+ """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + config = cls(**config_dict) + + if hasattr(config, "pruned_heads"): + config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) + + # Update config with kwargs if needed + if "num_labels" in kwargs and "id2label" in kwargs: + num_labels = kwargs["num_labels"] + id2label = kwargs["id2label"] if kwargs["id2label"] is not None else [] + if len(id2label) != num_labels: + raise ValueError( + f"You passed along `num_labels={num_labels }` with an incompatible id to label map: " + f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove " + "one of them.") + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info("Model config %s", str(config)) + if return_unused_kwargs: + return config, kwargs + return config + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + mirror: str = 'huggingface', + **kwargs, + ) -> "PretrainedConfig": + r""" + Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration. + """ + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs['mirror'] = mirror + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + @classmethod + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + [`PretrainedConfig`] using `from_dict`. + """ + original_kwargs = copy.deepcopy(kwargs) + # Get config dict associated with the base config file + config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + + # That config file may point us toward another config file to use. + if "configuration_files" in config_dict: + configuration_file = get_configuration_file(config_dict["configuration_files"]) + config_dict, kwargs = cls._get_config_dict( + pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs + ) + + return config_dict, kwargs + + @classmethod + def _get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used + for instantiating a Config using `from_dict`. + + Parameters: + pretrained_model_name_or_path (:obj:`string`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object. 
+ + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + subfolder = kwargs.pop("subfolder", "") + token = kwargs.pop('token', None) + revision = kwargs.pop('revision', 'main') + mirror = kwargs.pop('mirror', 'huggingface') + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + # Special case when pretrained_model_name_or_path is a local file + resolved_config_file = pretrained_model_name_or_path + is_local = True + + elif is_remote_url(pretrained_model_name_or_path): + configuration_file = pretrained_model_name_or_path + resolved_config_file = download_url(pretrained_model_name_or_path) + + else: + configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) + + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_config_file = cached_file( + pretrained_model_name_or_path, + configuration_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + revision=revision, + token=token, + subfolder=subfolder, + mirror=mirror + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception as exc: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it" + ", make sure you don't have a local directory with the same" + f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory" + f" containing a {configuration_file} file" + ) from exc + + try: + # Load config dict + config_dict = cls._dict_from_json_file(resolved_config_file) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + raise EnvironmentError( + f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file." + ) from exc + + if is_local: + logger.info(f"loading configuration file {resolved_config_file}") + else: + logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}") + + return config_dict, kwargs + + @classmethod + def _dict_from_json_file(cls, json_file: str): + """_dict_from_json_file""" + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def dict_ms_dtype_to_str(self, d: Dict[str, Any]) -> None: + """ + Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, + converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* + string, which can then be stored in the json format. + """ + if d.get("ms_dtype", None) is not None and not isinstance(d["ms_dtype"], str): + d["ms_dtype"] = str(d["ms_dtype"]).lower() + for value in d.values(): + if isinstance(value, dict): + self.dict_ms_dtype_to_str(value) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. 
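+
+        For example (an illustrative sketch; only a couple of the serialized keys are shown):
+
+        >>> config = PretrainedConfig(max_length=32)
+        >>> d = config.to_dict()
+        >>> d["max_length"], d["num_beams"]
+        (32, 1)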
+ """ + output = copy.deepcopy(self.__dict__) + if hasattr(self.__class__, "model_type"): + output["model_type"] = self.__class__.model_type + if "_auto_class" in output: + del output["_auto_class"] + if "_commit_hash" in output: + del output["_commit_hash"] + if "_attn_implementation_internal" in output: + del output["_attn_implementation_internal"] + + for key, value in output.items(): + # Deal with nested configs like CLIP + if isinstance(value, PretrainedConfig): + value = value.to_dict() + + output[key] = value + + if hasattr(self, "quantization_config"): + output["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = output.pop("_pre_quantization_dtype", None) + + self.dict_ms_dtype_to_str(output) + + return output + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = PretrainedConfig().to_dict() + + # get class specific config dict + class_config_dict = self.__class__().to_dict() if not self.is_composition else {} + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if ( + isinstance(getattr(self, key, None), PretrainedConfig) + and key in class_config_dict + and isinstance(class_config_dict[key], dict) + ): + # For nested configs we need to clean the diff recursively + diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None)) + if "model_type" in value: + # Needs to be set even if it's not in the diff + diff["model_type"] = value["model_type"] + if diff: + serializable_config_dict[key] = diff + elif ( + key not in default_config_dict + or value != default_config_dict[key] + or (key in class_config_dict and value != class_config_dict[key]) + ): + serializable_config_dict[key] = value + + return serializable_config_dict + + def to_json_string(self, use_diff: bool = True) -> str: + """ + Serializes this instance to a JSON string. + """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_file(self, save_path): + """Serializes this instance to a JSON file.""" + output_dict = self.to_dict() + with open(os.path.join(save_path, 'config.json'), encoding='utf-8') as f: + json.dump(output_dict, f, sort_keys=True, indent=2) + + def update(self, config_dict: Dict[str, Any]): + """ + Updates attributes of this class with attributes from `config_dict`. + """ + for key, value in config_dict.items(): + setattr(self, key, value) + + def save_pretrained(self, save_directory: Union[str, os.PathLike]): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~PretrainedConfig.from_pretrained`] class method. 
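+
+        A round-trip sketch (the directory path is an illustrative assumption):
+
+        >>> config = PretrainedConfig(max_length=32)
+        >>> config.save_pretrained("./tmp_progen_config")
+        >>> reloaded = PretrainedConfig.from_pretrained("./tmp_progen_config")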
+ """ + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file, use_diff=True) + logger.info(f"Configuration saved in {output_config_file}") + + def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True): + """ + Save this instance to a JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string(use_diff=use_diff)) + + @property + def num_labels(self) -> int: + """ + `int`: The number of labels for classification models. + """ + return len(self.id2label) + + @num_labels.setter + def num_labels(self, num_labels: int): + if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels: + self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)} + self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) + + @property + def _attn_implementation(self): + """ + This property is made private for now (as it cannot be changed + and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + """ + if hasattr(self, "_attn_implementation_internal"): + if not self._attn_implementation_internal: + # `config.attn_implementation` should never be None, for backward compatibility. + return "eager" + return self._attn_implementation_internal + return "eager" + + @_attn_implementation.setter + def _attn_implementation(self, value): + self._attn_implementation_internal = value + + +class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin): + """ + Abstract class for Pretrained models + """ + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + + # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing + # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. + _keys_to_ignore_on_load_missing = None + # a list of `re` patterns of `state_dict` keys that should be removed from the list of + # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary + # warnings. + _keys_to_ignore_on_load_unexpected = None + + _tied_weights_keys = None + + _keep_in_fp32_modules = None + + supports_recompute = False + + def __init__(self, config): + super().__init__(config) + self._check_and_unset_acl() + # Save config in model + self.config = config + self.name_or_path = config.name_or_path + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + + def _check_and_unset_acl(self): + if "MS" in str(self.__class__.__name__) and \ + 'MS_DEV_FORCE_ACL' in os.environ: + del os.environ['MS_DEV_FORCE_ACL'] + + def post_init(self): + """ + A method executed at the end of each Transformer model initialization, to execute code that needs the model's + modules properly initialized (such as weight initialization). + + """ + self.init_weights() + + @staticmethod + def prepare_inputs_for_generation(*args, **kwargs): + """ + prepare_inputs_for_generation + """ + return + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. 
+ """ + model = cls(config, **kwargs) + + return model + + @property + def base_model(self): + """ + to get base_model + """ + return getattr(self, self.base_model_prefix, self) + + def get_input_embeddings(self) -> "nn.Cell": + """ + Returns the model's input embeddings. + + Returns: + :obj:`nn.Cell`: A mindspore cell mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.get_input_embeddings() + raise NotImplementedError + + def set_input_embeddings(self, new_embeddings: nn.Cell): + """ + Set model's input embeddings. + + Args: + value (:obj:`nn.Cell`): A mindspore cell mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.set_input_embeddings(new_embeddings) + raise NotImplementedError + + def get_output_embeddings(self): + """ Get model's output embeddings + Return None if the model doesn't have output embeddings + """ + return None # Overwrite for models with output embeddings + + def set_output_embeddings(self, new_embeddings: nn.Cell): + """ + Set model's output embeddings. + + Args: + value (:obj:`nn.Cell`): A mindspore cell mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.set_output_embeddings(new_embeddings) + raise NotImplementedError + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> nn.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + # Update base model and current model config + self.config.vocab_size = model_embeds.weight.shape[0] + self.vocab_size = model_embeds.weight.shape[0] + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): + """ + Update new_num_tokens with the actual size of new_embeddings. + If word embeddings are not tied, make sure that lm head is resized as well. + """ + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) + + self.set_input_embeddings(new_embeddings) + self.get_input_embeddings().weight.data_sync(True) + + if pad_to_multiple_of is not None: + new_num_tokens = new_embeddings.weight.shape[0] + + if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: + old_lm_head = self.get_output_embeddings() # pylint: disable=assignment-from-none + new_lm_head = self._get_resized_lm_head( + old_lm_head, new_num_tokens) + self.set_output_embeddings(new_lm_head) + self.get_output_embeddings().weight.data_sync(True) + + return self.get_input_embeddings() + + def resize_tokenizer_embeddings(self, new_num_tokens): + """ + Obtain a new embedding layer or use the original one without updating it. 
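+
+        A usage sketch (`model` and `tokenizer` are hypothetical objects, not defined in this module):
+
+        >>> # grow the input embedding table so it covers newly added tokenizer tokens
+        >>> embeddings = model.resize_tokenizer_embeddings(new_num_tokens=len(tokenizer))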
+ """ + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings( + old_embeddings, new_num_tokens) + self.set_input_embeddings(new_embeddings) + return self.get_input_embeddings() + + def _get_resized_embeddings( + self, + old_embeddings: nn.Embedding, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + ) -> nn.Embedding: + """ Build a resized Embedding Module from a provided token Embedding Module. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + """ + if pad_to_multiple_of is not None: + if not isinstance(pad_to_multiple_of, int): + raise ValueError( + f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, " + "which is not and integer. Please make sure to pass an integer" + ) + if new_num_tokens is None: + new_num_tokens = old_embeddings.weight.shape[0] + new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + else: + logger.info( + "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. " + f"This means that the new embedding dimension will be {new_num_tokens}. This might induce " + "some performance reduction as *Tensor Cores* will not be available. For more details about " + "this, or help on choosing the correct value for resizing, refer to this guide:" + " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" + ) + + if new_num_tokens is None: + return old_embeddings + + old_num_tokens, old_embedding_dim = old_embeddings.weight.shape + if old_num_tokens == new_num_tokens: + return old_embeddings + + # Build new embeddings + new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) + + # initialize all new embeddings (in particular added tokens) + self._init_weights(new_embeddings) + + # Copy word embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[ + :num_tokens_to_copy, :] + return new_embeddings + + def _get_resized_lm_head( + self, old_lm_head: nn.Dense, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False + ) -> nn.Dense: + """ + Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end + """ + if new_num_tokens is None: + return old_lm_head + + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.shape if not transposed else old_lm_head.weight.T.shape + ) + + if old_num_tokens == new_num_tokens: + return old_lm_head + + if not isinstance(old_lm_head, nn.Dense): + raise TypeError( + f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Dense}. You" + " should either use a different resize function or make sure that `old_lm_head` are an instance of" + f" {nn.Dense}." + ) + + # Build new lm head + new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) + has_new_lm_head_bias = old_lm_head.bias is not None + + # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init + # because the shape of the new embedding layer is used across various modeling files + # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading + # to errors when training. 
+ new_lm_head = nn.Dense( + *new_lm_head_shape, + has_bias=has_new_lm_head_bias, + ) + + # initialize new lm head (in particular added tokens) + self._init_weights(new_lm_head) + + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + + return new_lm_head + + def _copy_lm_head_original_to_resized( + self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ): + """ + Copy old lm head weights to new lm head. + Copy bias weights to new lm head. + """ + if not transposed: + new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] + else: + new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] + + if has_new_lm_head_bias: + new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] + + + @classmethod + def load(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, **kwargs): + """ + Load a pre-trained checkpoint from a pre-trained model file or url, + download and cache the pre-trained model file if model name in model list. + + Params: + pretrained_model_name_or_path: + """ + return cls.from_pretrained(pretrained_model_name_or_path, args, kwargs) + + def save(self, save_dir): + """ Save a model and its configuration file to a directory, so that + it can be re-loaded using the `:func:`PreTrainedModel.from_pretrained`` class method. + + Arguments: + save_dir: directory to which to save. + """ + if os.path.isfile(save_dir): + logger.error(f"Provided path ({save_dir}) should be a directory, not a file") + return + os.makedirs(save_dir, exist_ok=True) + + # Only save the model itself if we are using distributed training + model_to_save = self.cell if hasattr(self, "cell") else self + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(save_dir, WEIGHTS_NAME) + save_checkpoint(model_to_save, output_model_file) + + logger.info(f"Model weights saved in {output_model_file}") + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. + # Alternativelly, the model can also have a custom `generate` function. + if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): + return False + return True + + def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): + """ + Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given. + """ + if (attention_mask is not None) or (self.config.pad_token_id is None): + return + + # Check only the first and last input IDs to reduce overhead. + if self.config.pad_token_id in input_ids[:, [-1, 0]]: + warn_string = ( + "We strongly recommend passing in an `attention_mask` since your input_ids may be padded." + ) + + # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an + # attention_mask or not. In this case, we should still show a warning because this is a rare case. 
+ if ( + (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) + or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) + or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) + ): + warn_string += ( + f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical " + f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), " + f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded." + ) + + logger.warning(warn_string) + + def num_parameters(self, only_trainable=False): + """return parameters count""" + total = 0 + param_set = set() + for param in self.get_parameters(): + param_id = param.uuid + if param_id not in param_set and (only_trainable or param.requires_grad): + total += param.size + param_set.add(param_id) + return total + + def trainable_params(self, recurse=True): + """ + fix duplicated weights + """ + return list(set(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))) + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = mindspore.save_checkpoint, + max_shard_size: Union[int, str] = "5GB", + safe_serialization: bool = True, + variant: Optional[str] = None, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~PreTrainedModel.from_pretrained`] class method. + """ + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + + # Only save the model itself if we are using distributed training + model_to_save = self + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.ms_dtype = str(dtype).lower() + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Save the config + if is_main_process: + model_to_save.config.save_pretrained(save_directory) + if self.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) + + + # Save the model + if state_dict is None: + state_dict = model_to_save.parameters_dict() + + # Shard the model if it is too big. + # weights_name = _add_variant(WEIGHTS_NAME, variant) + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + + # make sure that file to be deleted matches format of sharded file, e.g. 
pytorch_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shards + and is_main_process + and reg.fullmatch(filename_no_suffix) is not None + ): + os.remove(full_filename) + + # Save the model + for shard_file, shard in shards.items(): + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. + safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "np"}) + else: + save_function(shard, os.path.join(save_directory, shard_file)) + + if index is None: + path_to_weights = os.path.join(save_directory, _add_variant(WEIGHTS_NAME, variant)) + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def check_names(self): + pass + + +class TensorType(ExplicitEnum): + """ + Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for + tab-completion in an IDE. + """ + + MINDSPORE = "ms" + NUMPY = "np" + + +class PaddingStrategy(ExplicitEnum): + """ + Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an + IDE. + """ + + LONGEST = "longest" + MAX_LENGTH = "max_length" + DO_NOT_PAD = "do_not_pad" + + +class StoppingCriteria(): + """Abstract base class for all stopping criteria that can be applied during generation.""" + + @staticmethod + def __call__(input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + raise NotImplementedError("StoppingCriteria needs to be subclassed") + + +class MaxLengthCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep + in mind for decoder-only type of transformers, this will include the initial prompted tokens. + """ + + def __init__(self, max_length: int): + self.max_length = max_length + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + return input_ids.shape[-1] >= self.max_length + + +@dataclass +class BaseModelOutput(ModelOutputMindnlp): + """ + Base class for model's outputs, with potential hidden states and attentions. + """ + + last_hidden_state: mindspore.Tensor = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +class BaseModelOutputWithPast(ModelOutputMindnlp): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). 
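+
+    `past_key_values` caches the per-layer key/value states so that later decoding steps only need to feed the
+    newly generated token instead of re-encoding the whole prefix.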
+ """ + + last_hidden_state: mindspore.Tensor = None + past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutputMindnlp): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + """ + + last_hidden_state: mindspore.Tensor = None + past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + cross_attentions: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +class CausalLMOutputWithPast(ModelOutputMindnlp): + """ + Base class for causal language model (or autoregressive) outputs. + """ + + loss: Optional[mindspore.Tensor] = None + logits: mindspore.Tensor = None + past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutputMindnlp): + """ + Base class for causal language model (or autoregressive) outputs. + """ + + loss: Optional[mindspore.Tensor] = None + logits: mindspore.Tensor = None + past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + cross_attentions: Optional[Tuple[mindspore.Tensor]] = None + + +def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList: + """validate stopping criteria""" + stopping_max_length = stopping_criteria.max_length + new_stopping_criteria = deepcopy(stopping_criteria) + if stopping_max_length is not None and stopping_max_length != max_length: + warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) + elif stopping_max_length is None: + new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) + return new_stopping_criteria + +def _is_numpy(x): + return isinstance(x, np.ndarray) + +def is_numpy_array(x): + """ + Tests if `x` is a numpy array or not. + """ + return _is_numpy(x) + +def infer_framework_from_repr(x): + """ + Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the + frameworks in a smart order, without the need to import the frameworks). 
+ """ + representation = str(type(x)) + if representation.startswith("= input_rank: + raise ValueError(f"`dim` should be in range [{-input_rank}, {input_rank}), but got {input_rank, dim}") + + sizes_mul = reduce(operator.mul, list(sizes)) + if -1 not in sizes and sizes_mul != input_shape[dim_new]: + raise ValueError(f"unflatten: Provided `sizes` {sizes} don't multiply up to the" + f"size of dim {dim} ({input_shape[dim_new]}) in the input tensor") + + out_shape = input_shape[:dim_new] + tuple(sizes) + input_shape[dim_new + 1:] + return out_shape + + +# For all backend +# For functional api +# matmul +origin_matmul = ops.matmul +ops.matmulcus = fp16_patch_decorator(origin_matmul) +# mm +ops.mmcus = fp16_patch_decorator(origin_matmul) +# addbmm +origin_addbmm = ops.addbmm +ops.addbmmcus = fp16_patch_decorator(origin_addbmm) +# addmm +origin_addmm = ops.addmm +ops.addmmcus = fp16_patch_decorator(origin_addmm) +# addmv +origin_addmv = ops.addmv +ops.addmvcus = fp16_patch_decorator(origin_addmv) +# addr +origin_addr = ops.addr +ops.addrcus = fp16_patch_decorator(origin_addr) +# baddbmm +origin_baddbmm = ops.baddbmm +ops.baddbmmcus = fp16_patch_decorator(origin_baddbmm) +# bmm +origin_bmm = ops.bmm +ops.bmmcus = fp16_patch_decorator(origin_bmm) + + +# dense +def origin_dense(input_dense, weight, bias=None): + """patched dense""" + dense_ = _get_cache_prim(ops.Dense)() + return dense_(input_dense, weight, bias) + +ops.densecus = fp16_patch_decorator(origin_dense) + +def einsum_label_to_index(label): + if label == '.': + return 52 + num_of_letters = ord('z') - ord('a') + 1 + return (ord(label) - ord('A')) if (label.isupper()) else (num_of_letters + (ord(label) - ord('a'))) + +def enumerate_lhs(op_labels, lhs, num_ops): + """ + enumerate lhs in einsum function + """ + curr_op = 0 + found_ell = False + ell_skip = 0 + + for _, label in enumerate(lhs): + if label == ' ': + continue + if label == '.': + if ell_skip != 0: + ell_skip -= 1 + continue + assert not found_ell, f"einsum(): found {curr_op} for operand for which an ellipsis was already found" + + ell_skip = 2 + op_labels[curr_op].append(ELLIPSIS) + found_ell = True + elif label == ',': + curr_op += 1 + assert curr_op < num_ops, "einsum(): fewer operands were provided than specified in the equation" + found_ell = False + else: + + op_labels[curr_op].append(einsum_label_to_index(label)) + + return op_labels, lhs, num_ops, curr_op, found_ell + +def enumerate_operands(op_labels, operands, label_count): + """ + enumerate operands in einsum function + """ + # Compute label frequency and number of dimensions covered by ellipsis + # We do this after parsing labels to make it more readable and simpler + # to compute the number of dimensions covered by ellipsis. 
+ ell_num_dim = 0 + + for i, operand in enumerate(operands): + labels = op_labels[i] + ndims = operand.ndim + nlabels = len(labels) + has_ellipsis = False + + for label in labels: + if label == ELLIPSIS: + nlabels -= 1 + has_ellipsis = True + ell_num_dim = max(ell_num_dim, ndims - nlabels) + else: + label_count[label] += 1 + if has_ellipsis: + assert nlabels <= ndims, f"einsum(): the number of subscripts in the equation ({nlabels}" \ + f") is more than the number of dimensions ({ndims}) for operand {i}" + else: + assert nlabels == ndims, f"einsum(): the number of subscripts in the equation ({nlabels}" \ + f") does not match the number of dimensions (" \ + f"{ndims}) for operand {i} and no ellipsis was given" + + return ell_num_dim, label_count + +def unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels, + ell_num_dim, ell_index, label_perm_index): + """ + unsqueeze missing dimension in einsum function + """ + permuted_operands = [] + for i, operand in enumerate(operands): + perm_shape = [-1] * perm_index + label_dim = [-1] * total_labels + operand = operands[i] + labels = op_labels[i] + original_sizes = operand.shape + + j = 0 + for label in labels: + if label == ELLIPSIS: + # Add missing dimensions covered by the ellipsis + num_missing_dim = ell_num_dim - \ + (len(original_sizes) - len(labels) + 1) + for k in range(num_missing_dim): + operand = ops.unsqueeze(operand, j) + for k in range(ell_num_dim): + perm_shape[ell_index + k] = j + j += 1 + elif label_dim[label] != -1: + dim = label_dim[label] + operand = ops.diagonal(operand, offset=0, dim1=dim, dim2=j) + operand = ops.moveaxis(operand, -1, dim) + else: + label_dim[label] = j + perm_shape[label_perm_index[label]] = j + j += 1 + + # Add dimensions for missing labels + for idx, index in enumerate(perm_shape): + if index == -1: + operand = ops.unsqueeze(operand, -1) + perm_shape[idx] = j + j += 1 + + operand = ops.transpose(operand, tuple(perm_shape)) + permuted_operands.append(operand) + + # Check if operands broadcast and keep track of last operand with + # dimension size != 1 for optimizing reductions + dim_last_op = [0] * perm_index + has_zero_size_dim = False + for dim in range(perm_index): + broadcast_size = permuted_operands[0].shape[dim] + for i in range(1, len(operands)): + dim_size = permuted_operands[i].shape[dim] + if broadcast_size != dim_size and broadcast_size != 1 and dim_size != 1: + raise RuntimeError("einsum(): operands do not broadcast with remapped shapes [original->remapped]") + if dim_size != 1: + broadcast_size = dim_size + dim_last_op[dim] = i + has_zero_size_dim = has_zero_size_dim or (broadcast_size == 0) + + return permuted_operands, has_zero_size_dim, dim_last_op + +def einsum_operate(arrow_pos, ell_num_dim, label_perm_index, equation, lhs): + """ + intermediate operation in einsum function + """ + # Current index in the permuted shape + perm_index = 0 + # Start index of ellipsis dimensions in the permuted shape + ell_index = 0 + found_ell = False + + if arrow_pos == -1: + # Implicit output is ellipsis (...) + labels seen only once + perm_index = ell_num_dim + found_ell = True + for label, label_count in enumerate(label_count): + if label_count == 1: + label_perm_index[label] = perm_index + perm_index += 1 + else: + rhs = equation[arrow_pos + 2:] + ell_skip = 0 + for i, label in enumerate(rhs): + if label == ' ': + continue + if label == '.': + if ell_skip != 0: + ell_skip -= 1 + continue + assert not found_ell, "einsum(): found \'.\' for output but an ellipsis (...) 
was already found" + + ell_skip = 2 + ell_index = perm_index + perm_index += ell_num_dim + found_ell = True + else: + assert str.isalpha(label), f"einsum(): invalid subscript given at index {len(lhs) + 2 + i} " \ + f"in the equation string, subscripts must be in [a-zA-Z]" + + index = einsum_label_to_index(label) + label_perm_index[index] = perm_index + perm_index += 1 + return perm_index, found_ell, label_perm_index, ell_index + +def sum_result(num_ops, permuted_operands, out_size, perm_index, dim_last_op, result): + """ + sum out or squeeze dimensions that are size 1 for all later operands + """ + for i in range(1, num_ops): + operand = permuted_operands[i] + sum_dims = [] + + dim = out_size + for j in range(dim, perm_index): + if dim_last_op[j] < i: + operand = ops.squeeze(operand, dim) + dim -= 1 + elif dim_last_op[j] == i: + if result.shape[dim] == 1: + operand = ops.sum(operand, dim) + result = ops.squeeze(result, dim) + dim -= 1 + else: + sum_dims.append(dim) + dim += 1 + if not sum_dims: + result = result.mul(operand) + elif len(sum_dims) == len(result.shape): + result = result.flatten().dot(operand.flatten()) + + return result + +def einsum(equation, *operands): + """ + einsum method + """ + assert operands, "einsum(): must provide at least one operand" + if isinstance(operands[0], tuple): + operands = operands[0] + + arrow_pos = equation.find("->") + num_ops = len(operands) + op_labels = [[] for _ in range(num_ops)] + lhs = equation[0: arrow_pos] + + op_labels, lhs, num_ops, curr_op, found_ell = enumerate_lhs(op_labels, lhs, num_ops) + + assert curr_op == num_ops - 1, "einsum(): more operands were provided than specified in the equation" + # Labels must be within [a-zA-Z]. + total_labels = 52 + label_count = [0] * total_labels + # The maximum number of dimensions covered by any ellipsis, needed when + # unsqueezing missing dimensions from operands to permute and broadcast + ell_num_dim, label_count = enumerate_operands(op_labels, operands, label_count) + + # We want to align the dimensions of every input tensor to have + # shape out_dims + sum_dims. For this, we create a mapping of label + # to index into the permuted shape. + label_perm_index = [-1] * total_labels + + perm_index, found_ell, label_perm_index, ell_index = \ + einsum_operate(arrow_pos, ell_num_dim, label_perm_index, equation, lhs) + + out_size = perm_index + if not found_ell: + ell_index = perm_index + perm_index += ell_num_dim + + for label in range(total_labels): + if label_count[label] > 0 and label_perm_index[label] == -1: + label_perm_index[label] = perm_index + perm_index += 1 + + # Here we unsqueeze missing dimensions to make all operands have the same + # number of dimensions. We take diagonals for repeated labels within the + # same operand. Finally we permute the operands to align dimensions as + # per the perm_out_index we computed above. 
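+    # e.g. for "ij,jk->ik" the output labels i and k take permuted positions 0 and 1 and the contracted
+    # label j is appended at position 2, so both operands are broadcast-aligned to the layout (i, k, j).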
+ permuted_operands, has_zero_size_dim, dim_last_op = \ + unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels, + ell_num_dim, ell_index, label_perm_index) + + # Compute result + result = permuted_operands[0] + if has_zero_size_dim: + out_shape = [-1] * out_size + for i in range(out_size): + out_shape[i] = permuted_operands[dim_last_op[i]].shape[i] + return ops.zeros(out_shape) + + # Sum out or squeeze dimensions that are size 1 for all later operands + dim = out_size + for i in range(dim, perm_index): + if dim_last_op[i] == 0: + if result.shape[dim] == 1: + result = ops.squeeze(result, dim) + dim -= 1 + else: + result = ops.sum(result, dim) + dim -= 1 + dim += 1 + + result = sum_result(num_ops, permuted_operands, + out_size, perm_index, dim_last_op, result) + + return result + +ops.einsum = einsum + +# conv1d +ops.conv1dcus = fp16_patch_decorator(ops.conv1d) + + +def _ones(*size, dtype=None): + if dtype is None: + dtype = mindspore.float32 + if isinstance(size[0], tuple): + size = size[0] + ones_ = _get_cache_prim(ops.Ones)() + return ones_(size, dtype) + +ops.onescus = _ones + + +def _zeros(*size, dtype=None): + if dtype is None: + dtype = mindspore.float32 + if isinstance(size[0], tuple): + size = size[0] + zeros_ = _get_cache_prim(ops.Zeros)() + return zeros_(size, dtype) + + +ops.zeroscus = _zeros + +# for Tensor +# unfold +def _get_unfold_indices(input_shape, dimension, size, step): + if dimension < 0: + dimension += len(input_shape) + indices = [] + for i in range(0, input_shape[dimension] - size + 1, step): + indices.append(list(range(i, i + size))) + + return indices, dimension + + +def unfold(self, dimension, size, step): + """unfold""" + indices_new, dimension_new = _get_unfold_indices(self.shape, dimension, size, step) + indices = mindspore.Tensor(indices_new).astype(mindspore.int32) + output = ops.gather(self, indices, axis=dimension_new) + output = ops.moveaxis(output, dimension_new + 1, -1) + return output + +Tensor.unfold = unfold +StubTensor.unfold = unfold + + +# var_mean +def var_mean(input_vm, axis=None, *, correction=1, keepdims=False): + """var_mean""" + axis = Validator.check_and_canonicalize_axes(axis, input.ndim) + x_mean = ops.mean(input_vm, axis, True) + x_sub = ops.sub(input_vm, x_mean) + x_pow = ops.pow(x_sub, 2) + x_sum = ops.sum(x_pow, axis, keepdims) + res_mean = ops.mean(input_vm, axis, keepdims) + nums = 1 + if not axis: + nums = input_vm.size + else: + for ax in axis: + nums *= input_vm.shape[ax] + return ops.true_divide(x_sum, nums - correction), res_mean + +ops.var_mean = var_mean + + +# std_mean +def std_mean(input_sm, axis=None, *, correction=1, keepdims=False): + """std_mean""" + output = var_mean(input_sm, axis, correction=correction, keepdims=keepdims) + return ops.pow(output[0], 0.5), output[1] + +ops.std_mean = std_mean + + +# masked_fill +def masked_fill(inputs, mask, value): + """patched masked_fill""" + masked_value = ops.fill(inputs.dtype, inputs.shape, value) + return ops.select(mask, masked_value, inputs) + + +def _masked_fill(self, mask, value): + return masked_fill(self, mask, value) + + +ops.masked_fill = masked_fill +Tensor.masked_fill = _masked_fill +StubTensor.masked_fill = _masked_fill + + +# ops.std +def std(input_std, axis=None, ddof=0, keepdims=False): + """patched std""" + # Calculate mean + mean = ops.mean(input_std, axis=axis, keep_dims=keepdims) + + # Squared differences from the mean + squared_diff = (input_std - mean)**2 + + # Sum along the specified dimension + if axis is not None: + sum_along_dim = 
ops.sum(squared_diff, dim=axis, keepdim=keepdims) + else: + sum_along_dim = squared_diff.sum() + + # Calculate the correction factor + factor = 1.0 / (input_std.shape[axis] - ddof) if axis is not None else 1.0 / (input_std.size - ddof) + + # Calculate the standard deviation + out = ops.sqrt(factor * sum_along_dim) + + return out + + +def _std(self, axis=None, ddof=0, keepdims=False): + return std(self, axis, ddof, keepdims) + + +ops.std = std +Tensor.std = _std +StubTensor.std = _std + + +# Tensor.__contains__ +def _contains(self, key): + eq_res = ops.equal(self, key) + res = ops.any(eq_res) + return bool(res) + + +Tensor.__contains__ = _contains +StubTensor.__contains__ = _contains + + +def unflatten(self, dim, sizes): + """Tensor.unflatten""" + out_shape = _get_unflatten_size(self.shape, dim, sizes) + return self.reshape(out_shape) + + +Tensor.unflatten = unflatten +StubTensor.unflatten = unflatten + + +def _as_strided(self, size, stride, storage_offset=None): + """ + replace as_strided + """ + if len(size) != len(stride): + raise RuntimeError("mismatch in length of strides and shape.") + index = np.arange(0, size[0] * stride[0], stride[0]) + for i in range(1, len(size)): + tmp = np.arange(0, size[i] * stride[i], stride[i]) + index = np.expand_dims(index, -1) + index = index + tmp + if storage_offset is not None: + index = index + storage_offset + if index.size == 0: + input_indices = mindspore.numpy.empty(index.shape, dtype=mstype.int32) + else: + input_indices = Tensor(index) + out = ops.gather(self.reshape(-1), input_indices, 0) + return out + + +Tensor.as_strided = _as_strided +StubTensor.as_strided = _as_strided + + +def _nonzero(self, as_tuple=False): + if self.dtype == mstype.bool_: + self = self.astype(mstype.int64) + outs = ops.nonzero(self) + if as_tuple: + outs = ops.tensor_split(outs, self.ndim, -1) + outs = tuple(out.squeeze(-1) for out in outs) + return outs + + +Tensor.nonzero = _nonzero +StubTensor.nonzero = _nonzero + + +def _expand(self, *size): + if len(size) == 1: + size = size[0] + return ops.broadcast_to(self, size) + + +Tensor.expand = _expand +StubTensor.expand = _expand + + +mindspore.tensor = mindspore.Tensor +ops.prod = bool_patch_decorator(ops.prod) + + +def _prod(self, axis=None, keep_dims=False): + return ops.prod(self, axis, keep_dims) +Tensor.prod = _prod +StubTensor.prod = _prod + + +def _eq(self, other): + """ + replace __eq__ + """ + if not isinstance(other, (int, float, Tensor)): + return False + if isinstance(other, Tensor) and self.shape != other.shape: + return False + if id(self) == id(other): + return True + # bool type is not supported for `Equal` operator in backend. 
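+    # e.g. comparing two bool masks: both operands are cast to int32 below before calling ops.eq.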
+ if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_): + self = self.to(mstype.int32) + other = other.to(mstype.int32) + return ops.eq(self, other) + + +Parameter.__eq__ = _eq + +old_repeat = Tensor.repeat + + +def new_repeat_interleave(input_ri, repeats, axis=None): + """new repeat_interleave""" + if axis is None: + input_ri = input_ri.reshape(-1) + axis = 0 + if isinstance(repeats, Tensor): + repeats = repeats.asnumpy().tolist() + output = old_repeat(input_ri, repeats, axis) + return output + + +ops.repeat_interleave = bool_io_patch_decorator(new_repeat_interleave) + + +def _repeat_interleave(self, repeats, dim): + return old_repeat(self, repeats, axis=dim) + + +Tensor.repeat_interleave = _repeat_interleave +StubTensor.repeat_interleave = _repeat_interleave + + +def _repeat(self, *sizes): + return ops.tile(self, tuple(sizes)) + + +Tensor.repeat = _repeat +StubTensor.repeat = _repeat + + +if LESS_MS_2_2: + mindspore.bfloat16 = None + + def eq(self, other): + """patched eq""" + return ops.equal(self, other) + + Tensor.eq = eq + StubTensor.eq = eq + + def _item(self): + return self.asnumpy().item() + Tensor.item = _item + StubTensor.item = _item + + def _tolist(self): + return self.asnumpy().tolist() + Tensor.tolist = _tolist + StubTensor.tolist = _tolist + + +# For Cells +class DenseMindnlp(nn.Cell): + """patched Dense""" + def __init__(self, + in_channels, + out_channels, + has_bias=True, + dtype=mstype.float32): + """Initialize Dense.""" + super().__init__() + self.in_channels = Validator.check_positive_int( + in_channels, "in_channels", self.cls_name) + self.out_channels = Validator.check_positive_int( + out_channels, "out_channels", self.cls_name) + self.has_bias = Validator.check_bool( + has_bias, "has_bias", self.cls_name) + + self.weight = Parameter(initializer( + HeUniform(math.sqrt(5)), [out_channels, in_channels], dtype=dtype), name="weight") + + self.bias = None + if self.has_bias: + fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape) + bound = 1 / math.sqrt(fan_in) + self.bias = Parameter(initializer( + Uniform(bound), [out_channels], dtype=dtype), name="bias") + + def construct(self, x): + """ + construct method of DenseMindnlp + """ + if LESS_MS_2_2: + x_shape = x.shape + if len(x_shape) != 2: + x = x.reshape(-1, x.shape[-1]) + x = ops.matmul(x, self.weight.T) + if self.has_bias: + x = ops.add(x, self.bias) + if len(x_shape) != 2: + out_shape = x_shape[:-1] + (x.shape[-1],) + x = x.reshape(out_shape) + return x + return ops.dense(x, self.weight, self.bias) + + def extend_repr(self): + s = f'input_channels={self.in_channels}, output_channels={self.out_channels}' + if self.has_bias: + s += f', has_bias={self.has_bias}' + return s + + +class EmbeddingMindnlp(nn.Cell): + """patched Embedding""" + def __init__(self, vocab_size, embedding_size, padding_idx=None, use_one_hot=False, dtype=mstype.float32): + """Initialize Embedding.""" + super().__init__() + self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name) + self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name) + Validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name) + Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) + self.use_one_hot = use_one_hot + self.dtype = dtype + self.padding_idx = padding_idx + self.weight = Parameter(initializer(Normal(1.0), [vocab_size, embedding_size]), name='weight') + if self.padding_idx and 
self.weight.init_flag: + self.weight[self.padding_idx] = 0 + + def construct(self, ids): + """ + construct method of EmbeddingMindnlp + """ + out_shape = ids.shape + (self.embedding_size,) + flat_ids = ids.reshape((-1,)) + + if self.use_one_hot: + one_hot_ids = ops.one_hot(flat_ids, self.vocab_size) + output_for_reshape = ops.matmul(one_hot_ids, self.weight) + else: + output_for_reshape = ops.gather(self.weight, flat_ids, 0) + + output = output_for_reshape.reshape(out_shape) + return output + + def extend_repr(self): + return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \ + f'weight={self.weight}, dtype={self.dtype}, padding_idx={self.padding_idx}' + + +class LayerNormMindnlp(nn.Cell): + r""" + Applies Layer Normalization over a mini-batch of inputs. + """ + + def __init__(self, + normalized_shape, + begin_norm_axis=-1, + begin_params_axis=-1, + gamma_init='ones', + beta_init='zeros', + epsilon=1e-5, + dtype=mstype.float32, + elementwise_affine=True + ): + """Initialize LayerNorm.""" + super().__init__() + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + if not isinstance(normalized_shape, (tuple, list)): + raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], " + f"but got {normalized_shape} and the type is {type(normalized_shape)}.") + if not normalized_shape: + raise ValueError( + f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at " + f"least one element, but got normalized_shape = {normalized_shape}" + ) + self.normalized_shape = normalized_shape + self.begin_norm_axis = begin_norm_axis + self.begin_params_axis = begin_params_axis + self.epsilon = epsilon + self.weight = Parameter(initializer( + gamma_init, normalized_shape, dtype=dtype), name="weight") + self.bias = Parameter(initializer( + beta_init, normalized_shape, dtype=dtype), name="bias") + self.layer_norm = ops.LayerNorm(begin_norm_axis=self.begin_norm_axis, + begin_params_axis=self.begin_params_axis, + epsilon=self.epsilon) + self.elementwise_affine = elementwise_affine + + def construct(self, input_x): + """ + construct method of LayerNormMindnlp + """ + if self.elementwise_affine: + y, _, _ = self.layer_norm(input_x, self.weight.astype(input_x.dtype), self.bias.astype(input_x.dtype)) + else: + y, _, _ = self.layer_norm(input_x, ops.ones(self.normalized_shape, input_x.dtype), + ops.zeros(self.normalized_shape, input_x.dtype),) + return y + + def extend_repr(self): + return f'normalized_shape={self.normalized_shape}, begin_norm_axis={self.begin_norm_axis}, ' \ + f'begin_params_axis={self.begin_params_axis}, weight={self.weight}, bias={self.bias}' + + +class BatchNorm1dMindnlp(nn.Cell): + """Batch Normalization base class.""" + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + weight_init='ones', + bias_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=None, + dtype=mstype.float32): + """Initialize _BatchNorm.""" + super().__init__() + if num_features < 1: + raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.") + + if momentum < 0 or momentum > 1: + raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], " + f"but got {momentum}.") + self.use_batch_statistics = use_batch_statistics + if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool): + raise ValueError(f"For 
'{self.cls_name}', the 'use_batch_statistics' must be a boolean value or None," + f" but got {use_batch_statistics}.") + self.num_features = num_features + self.eps = eps + self.moving_mean_init = moving_mean_init + self.moving_var_init = moving_var_init + self.running_mean = Parameter(initializer( + moving_mean_init, num_features, dtype=dtype), name="running_mean", requires_grad=False) + self.running_var = Parameter(initializer( + moving_var_init, num_features, dtype=dtype), name="running_var", requires_grad=False) + self.weight = Parameter(initializer( + weight_init, num_features, dtype=dtype), name="weight", requires_grad=affine) + self.bias = Parameter(initializer( + bias_init, num_features, dtype=dtype), name="bias", requires_grad=affine) + + self.momentum = 1.0 - momentum + + self.bn_train = ops.BatchNorm(is_training=True, + epsilon=self.eps, + momentum=self.momentum) + + self.bn_infer = ops.BatchNorm(is_training=False, epsilon=self.eps) + + def construct(self, x): + """ + construct method of BatchNorm1dMindnlp + """ + if self.use_batch_statistics is None: + if self.training: + return self.bn_train(x, + self.weight, + self.bias, + self.running_mean, + self.running_var)[0] + if not self.training: + return self.bn_infer(x, + self.weight, + self.bias, + self.running_mean, + self.running_var)[0] + + if self.use_batch_statistics: + return self.bn_train(x, + self.weight, + self.bias, + self.running_mean, + self.running_var)[0] + + return self.bn_infer(x, + self.weight, + self.bias, + self.running_mean, + self.running_var)[0] + + def extend_repr(self): + return f'num_features={self.num_features}, eps={self.eps}, momentum={1.0 - self.momentum}, ' \ + f'weight={self.weight}, bias={self.bias}, running_mean={self.running_mean}, ' \ + f'running_var={self.running_var}' diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py new file mode 100644 index 000000000..21efd44a1 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py @@ -0,0 +1,1037 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Logits process""" + +import math +import inspect +import logging +from typing import List, Optional +import mindspore +from mindspore import ops +import numpy as np + + +class LogitsProcessor: + """Abstract base class for all logit processors that can be applied during generation.""" + + def __call__(self, input_ids, scores): + """Method for processing logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." 
+ ) + + +class LogitsWarper: + """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + + def __call__(self, input_ids, scores): + """Method for warping logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class LogitsProcessorList(list): + """ + This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a + `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each + [`LogitsProcessor`] or [`LogitsWarper`] to the inputs. + """ + + def __call__(self, input_ids, scores, **kwargs): + for processor in self: + function_args = inspect.signature(processor.__call__).parameters + if len(function_args) > 2: + if not all(arg in kwargs for arg in list(function_args.keys())[2:]): + raise ValueError( + f"Make sure that all the required parameters: {list(function_args.keys())} for " + f"{processor.__class__} are passed to the logits processor." + ) + scores = processor(input_ids, scores, **kwargs) + else: + scores = processor(input_ids, scores) + return scores + + +class HammingDiversityLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces diverse beam search. Note that this logits processor is only effective for + [`PreTrainedModel.group_beam_search`]. See [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence + Models](https://arxiv.org/pdf/1610.02424.pdf) for more details. + """ + + def __init__(self, diversity_penalty, num_beams, num_beam_groups): + if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0): + raise ValueError("`diversity_penalty` should be a float strictly larger than 0.") + self._diversity_penalty = diversity_penalty + if not isinstance(num_beams, int) or num_beams < 2: + raise ValueError("`num_beams` should be an integer strictly larger than 1.") + self._num_beams = num_beams + if not isinstance(num_beam_groups, int) or num_beam_groups < 2: + raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") + if num_beam_groups > num_beams: + raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.") + self._num_sub_beams = num_beams // num_beam_groups + + def __call__(self, input_ids, scores, current_tokens, beam_group_idx): + # hamming diversity: penalise using same token in current group which was used in previous groups at + # the same time step + batch_size = current_tokens.shape[0] // self._num_beams + group_start_idx = beam_group_idx * self._num_sub_beams + group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) + group_size = group_end_idx - group_start_idx + vocab_size = scores.shape[-1] + + if group_start_idx == 0: + return scores + + for batch_idx in range(batch_size): + # predicted tokens of last time step of previous groups + previous_group_tokens = current_tokens[ + batch_idx * self._num_beams: batch_idx * self._num_beams + group_start_idx + ] + token_frequency = ops.bincount(previous_group_tokens, minlength=vocab_size) + scores[batch_idx * group_size: (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency + + return scores + + +class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input. 
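+
+    With `penalty > 1.0`, tokens that appear in `encoder_input_ids` have their scores boosted (the stored
+    factor is `1 / penalty`), biasing generation back towards the original input and damping hallucination.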
+ """ + + def __init__(self, penalty, encoder_input_ids): + if not isinstance(penalty, float) or (penalty <= 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = 1 / penalty + self.encoder_input_ids = encoder_input_ids + + def __call__(self, input_ids, scores): + score = ops.gather_elements(scores, 1, self.encoder_input_ids) + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = ops.where(score < 0, score * self.penalty, score / self.penalty) + + scores.scatter_(1, self.encoder_input_ids, score) + return scores + + +class RepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. + """ + + def __init__(self, penalty): + if not isinstance(penalty, float) or (penalty <= 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = penalty + + def __call__(self, input_ids, scores): + input_ids = ops.where(input_ids >= scores.shape[1], input_ids - scores.shape[1], input_ids) + score = ops.gather_elements(scores, 1, input_ids) + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = ops.where(score < 0, score * self.penalty, score / self.penalty) + + scores = ops.tensor_scatter_elements(scores, input_ids, score, axis=1) + return scores + + +def _get_ngrams(ngram_size: int, prev_input_ids: mindspore.Tensor, num_hypos: int): + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].asnumpy().tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + return generated_ngrams + + +def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - ngram_size + ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist()) + return banned_ngrams.get(ngram_idx, []) + + +def _calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len): + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + + generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos) + + banned_tokens = [ + _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len) + for hypo_idx in range(num_hypos) + ] + return banned_tokens + + +class NoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces no repetition of n-grams. 
+ """ + + def __init__(self, ngram_size): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def __call__(self, input_ids, scores): + num_batch_hypotheses = scores.shape[0] + cur_len = input_ids.shape[-1] + banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + +class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces no repetition of encoder input ids n-grams for the decoder ids. + """ + + def __init__(self, encoder_ngram_size, encoder_input_ids): + if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0: + raise ValueError( + f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}" + ) + self.ngram_size = encoder_ngram_size + if len(encoder_input_ids.shape) == 1: + encoder_input_ids = encoder_input_ids.unsqueeze(0) + self.batch_size = encoder_input_ids.shape[0] + self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size) + + def __call__(self, input_ids, scores): + # B x num_beams + num_hypos = scores.shape[0] + num_beams = num_hypos // self.batch_size + cur_len = input_ids.shape[-1] + banned_batch_tokens = [ + _get_generated_ngrams( + self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len + ) + for hypo_idx in range(num_hypos) + ] + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + +class NoBadWordsLogitsProcessor(LogitsProcessor): + """ + [`LogitsProcessor`] that enforces that specified sequences will never be sampled. + """ + + def __init__(self, bad_words_ids, eos_token_id): + if not isinstance(bad_words_ids, List) or not bad_words_ids: + raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.") + if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): + raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") + if any( + any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids) + for bad_word_ids in bad_words_ids + ): + raise ValueError( + f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." 
+ ) + + if eos_token_id is None: + eos_token_id = [] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + bad_words_ids = list( + filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids) + ) + self.bad_words_id_length_1 = [] + self.bad_words_id_length_greater_than_1 = [] + for word in bad_words_ids: + if len(word) == 1: + self.bad_words_id_length_1.append(word[0]) + else: + self.bad_words_id_length_greater_than_1.append(word) + + self.static_bad_words_mask: Optional[mindspore.Tensor] = None + + for banned_token_seq in self.bad_words_id_length_greater_than_1: + if not banned_token_seq: + raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list") + + def __call__(self, input_ids, scores): + if self.static_bad_words_mask is None and self.bad_words_id_length_1 > 0: + self.static_bad_words_mask = self._calc_static_bad_word_mask(scores) + + dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist()) + scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens) + + return scores + + def _calc_static_bad_word_mask(self, scores): + static_bad_words_mask = ops.zeros(scores.shape[1]) + static_bad_words_mask[self.bad_words_id_length_1] = 1 + return static_bad_words_mask.unsqueeze(0).bool() + + def _tokens_match(self, prev_tokens, tokens): + if not tokens: + # if bad word tokens is just one token always ban it + return True + if len(tokens) > len(prev_tokens): + # if bad word tokens are longer then prev input_ids they can't be equal + return False + return prev_tokens[-len(tokens):] == tokens + + def _calc_banned_bad_words_ids(self, prev_input_ids): + """ + calculate banned bad words ids + """ + banned_tokens = [] + for prev_input_ids_slice in prev_input_ids: + banned_tokens_slice = [] + for banned_token_seq in self.bad_words_id_length_greater_than_1: + if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]): + banned_tokens_slice.append(banned_token_seq[-1]) + + banned_tokens.append(banned_tokens_slice) + + return banned_tokens + + def _set_scores_to_inf_for_banned_tokens(self, scores, banned_tokens): + """ + Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a + list of list of banned tokens to ban in the format [[batch index, vocabulary position],... + """ + banned_mask_list = [] + for idx, batch_banned_tokens in enumerate(banned_tokens): + for token in batch_banned_tokens: + # Eliminates invalid bad word IDs that are over the vocabulary size. + if token <= scores.shape[1]: + banned_mask_list.append([idx, token]) + else: + logging.error( + "An invalid bad word ID is defined: %d. This ID is not contained in the " + "vocabulary, and is therefore ignored.", token + ) + if not banned_mask_list and self.static_bad_words_mask is None: + return scores + + if banned_mask_list: + banned_mask = mindspore.Tensor(banned_mask_list) + indices = ops.ones(len(banned_mask)) + # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. 
A conversion to dense tensor generates: + # [ 0 1 1 ] + # [ 0 0 0 ] + # [ 1 0 0 ] + + banned_mask = ( + mindspore.COOTensor(banned_mask, + indices, scores.shape) + .to_dense() + .bool() + ) + + if self.static_bad_words_mask is not None: + banned_mask = ops.bitwise_or(banned_mask, self.static_bad_words_mask) + else: + banned_mask = self.static_bad_words_mask + + scores = scores.masked_fill(banned_mask, -float("inf")) + return scores + + +class MinLengthLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. + """ + + def __init__(self, min_length, eos_token_id): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id): + raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def __call__(self, input_ids, scores): + vocab_tensor = ops.arange(scores.shape[-1]) + eos_token_id = mindspore.tensor(self.eos_token_id) + eos_token_mask = vocab_tensor == eos_token_id + scores_processed = scores.copy() + if input_ids.shape[-1] < self.min_length: + scores_processed = ops.where(eos_token_mask, -math.inf, scores) + return scores_processed + + +class MinNewTokensLengthLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0. + """ + + def __init__(self, prompt_length_to_skip, min_new_tokens, eos_token_id): + for arg_name, arg_value in [ + ("prompt_length_to_skip", prompt_length_to_skip), + ("min_new_tokens", min_new_tokens), + ("eos_token_id", eos_token_id), + ]: + if not isinstance(arg_value, int) or arg_value < 0: + raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}") + + self.prompt_length_to_skip = prompt_length_to_skip + self.min_new_tokens = min_new_tokens + self.eos_token_id = eos_token_id + + def __call__(self, input_ids, scores): + new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip + if new_tokens_length < self.min_new_tokens: + scores[:, self.eos_token_id] = -float("inf") + + return scores + + +class PrefixConstrainedLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained + generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information. + """ + + def __init__(self, prefix_allowed_tokens_fn, num_beams): + self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self._num_beams = num_beams + + def __call__(self, input_ids, scores): + mask = ops.full_like(scores, -math.inf) + for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])): + for beam_id, sent in enumerate(beam_sent): + mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0 + + return scores + mask + + +class ForcedBOSTokenLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces the specified token as the first generated token. + + Args: + bos_token_id (`int`): + The id of the token to force as the first generated token. 
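+
+    For example, with `bos_token_id=3` every score except index 3 is set to `-inf` at the first decoding step,
+    so token 3 is always generated first.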
+ """ + + def __init__(self, bos_token_id): + self.bos_token_id = bos_token_id + + def __call__(self, input_ids, scores): + cur_len = input_ids.shape[-1] + if cur_len == 1: + num_tokens = scores.shape[1] + scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf") + scores[:, self.bos_token_id] = 0 + return scores + + +class ForcedEOSTokenLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached. + """ + + def __init__(self, max_length, eos_token_id): + self.max_length = max_length + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id = eos_token_id + + def __call__(self, input_ids, scores): + cur_len = input_ids.shape[-1] + if cur_len == self.max_length - 1: + num_tokens = scores.shape[1] + scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = \ + float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min) + for i in self.eos_token_id: + scores[:, i] = 0 + return scores + + +class InfNanRemoveLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using + the logits processor should only be used if necessary since it can slow down the generation method. `max_length` is + reached. + """ + + def __call__(self, input_ids, scores): + # set all nan values to 0.0 + scores[ops.isnan(scores)] = 0.0 + + # set all inf values to max possible value + scores[scores == float("inf")] = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).max) + return scores + + +class ExponentialDecayLengthPenalty(LogitsProcessor): + r""" + [`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been + reached. + """ + + def __init__(self, exponential_decay_length_penalty, eos_token_id, input_ids_seq_length): + self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length + self.regulation_factor = exponential_decay_length_penalty[1] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id = eos_token_id + + def __call__(self, input_ids, scores): + cur_len = input_ids.shape[-1] + if cur_len > self.regulation_start: + for i in self.eos_token_id: + scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start) + return scores + + +class SuppressTokensLogitsProcessor(LogitsProcessor): + r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they + are not sampled.""" + + def __init__(self, suppress_tokens): + self.suppress_tokens = list(suppress_tokens) + + def __call__(self, input_ids, scores): + scores[:, self.suppress_tokens] = -float("inf") + return scores + + +class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): + r""" + [`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts + generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not + sampled at the beginning of the generation. 
+ """ + + def __init__(self, begin_suppress_tokens, begin_index): + self.begin_suppress_tokens = list(begin_suppress_tokens) + self.begin_index = begin_index + + def __call__(self, input_ids, scores): + if input_ids.shape[1] == self.begin_index: + scores[:, self.begin_suppress_tokens] = -float("inf") + + return scores + + +class ForceTokensLogitsProcessor(LogitsProcessor): + r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token + indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are + sampled at their corresponding index.""" + + def __init__(self, force_token_map): + self.force_token_map = dict(force_token_map) + + def __call__(self, input_ids, scores): + generation_idx = input_ids.shape[-1] + current_token = self.force_token_map.get(generation_idx, None) + if current_token is not None: + scores[:, :] = -float("inf") + scores[:, current_token] = 0 + return scores + + +class LogitNormalization(LogitsProcessor, LogitsWarper): + r""" + [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize + the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in + this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that + the scores are normalized when comparing the hypotheses. + """ + + def __call__(self, input_ids, scores): + scores = ops.log_softmax(scores, axis=-1) + return scores + + +class TemperatureLogitsWarper(LogitsWarper): + r""" + [`TemperatureLogitsWarper`] for temperature (exponential scaling output probability distribution). + Args: + temperature (:obj:`float`): + The value used to module the logits distribution. + """ + + def __init__(self, temperature): + if not isinstance(temperature, float) or not temperature > 0: + raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") + + self.temperature = temperature + + def __call__(self, input_ids, scores): + scores = scores / self.temperature + return scores + + +class TopPLogitsWarper(LogitsWarper): + """ + [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. 
+ """ + + def __init__(self, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1): + top_p = float(top_p) + if top_p < 0 or top_p > 1.0: + raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") + if not isinstance(min_tokens_to_keep, int) or min_tokens_to_keep < 1: + raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}") + + self.top_p = top_p + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids, scores): + if self.filter_value == -float("Inf"): + self.filter_value = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min) + # scores = ops.select(ops.isneginf(scores), mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min), scores) + sorted_logits, sorted_indices = ops.sort(scores, descending=False) + cumulative_probs = ops.softmax(sorted_logits, axis=-1).cumsum(axis=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p) + + # scatter sorted tensors to original indexing + sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0 + + if isinstance(sorted_indices_to_remove[0][0].item(), bool): + sorted_indices_to_remove = sorted_indices_to_remove.astype("int32") + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + if isinstance(indices_to_remove[0][0].item(), int): + indices_to_remove = indices_to_remove.astype("bool") + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + +class TopKLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. + """ + + def __init__(self, top_k, filter_value=-float("Inf"), min_tokens_to_keep=1): + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") + + self.top_k = max(top_k, min_tokens_to_keep) + self.filter_value = filter_value + + def __call__(self, input_ids, scores): + if self.filter_value == -float("Inf"): + self.filter_value = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min) + top_k = min(self.top_k, scores.shape[-1]) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < ops.topk(scores, top_k)[0][..., -1, None] + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + +class TypicalLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language + Generation](https://arxiv.org/abs/2202.00666) for more information. 
+    """
+
+    def __init__(self, mass=0.9, filter_value=-float("Inf"), min_tokens_to_keep=1):
+        mass = float(mass)
+        if mass <= 0 or mass >= 1:
+            raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
+        if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+            raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
+
+        self.filter_value = filter_value
+        self.mass = mass
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids, scores):
+        # calculate entropy
+        normalized = ops.log_softmax(scores, axis=-1)
+        p = ops.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdim=True)
+
+        # shift and sort
+        shifted_scores = ops.abs((-normalized) - ent)
+        sorted_scores, sorted_indices = ops.sort(shifted_scores, descending=False)
+        sorted_logits = scores.gather(-1, sorted_indices)
+        cumulative_probs = sorted_logits.softmax(axis=-1).cumsum(axis=-1)
+
+        # Remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).sum(axis=1)
+        last_ind = ops.clamp(last_ind, max=sorted_scores.shape[-1] - 1)
+        sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+        sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+        indices_to_remove = ops.tensor_scatter_elements(sorted_indices_to_remove,
+                                                        sorted_indices, sorted_indices_to_remove, axis=1)
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
+
+
+class EpsilonLogitsWarper(LogitsWarper):
+    r"""
+    [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
+    largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
+    Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
+    """
+
+    def __init__(self, epsilon, filter_value=-float("Inf"), min_tokens_to_keep=1):
+        epsilon = float(epsilon)
+        if epsilon <= 0 or epsilon >= 1:
+            raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
+
+        min_tokens_to_keep = int(min_tokens_to_keep)
+        if min_tokens_to_keep < 1:
+            raise ValueError(
+                f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
+            )
+
+        self.epsilon = epsilon
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids, scores):
+        # Determine which indices to remove
+        probabilities = ops.softmax(scores, axis=-1)
+        indices_to_remove = probabilities < self.epsilon
+
+        # Keep the words with the 'min_tokens_to_keep'-highest probabilities
+        top_k = min(self.min_tokens_to_keep, scores.shape[-1])  # Safety check
+        indices_to_remove = indices_to_remove & (scores < ops.topk(scores, top_k)[0][..., -1, None])
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
+
+
+class EtaLogitsWarper(LogitsWarper):
+    r"""
+    [`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
+    cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
+    the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
+    min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
+    samples of text generated by neural language models, leading to more coherent and fluent text. See [Truncation
+    Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
+    must be set to `True` for this `LogitsWarper` to work.
+    """
+
+    def __init__(self, epsilon, filter_value=-float("Inf"), min_tokens_to_keep=1):
+        epsilon = float(epsilon)
+        if epsilon <= 0 or epsilon >= 1:
+            raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
+
+        min_tokens_to_keep = int(min_tokens_to_keep)
+        if min_tokens_to_keep < 1:
+            raise ValueError(
+                f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
+            )
+
+        self.epsilon = mindspore.tensor(epsilon, mindspore.float32)
+        self.filter_value = filter_value
+        self.min_tokens_to_keep = min_tokens_to_keep
+
+    def __call__(self, input_ids, scores):
+        # Calculate the adaptive cutoff; the entropy is computed from the normalized probabilities of `scores`
+        probabilities = ops.softmax(scores, axis=-1)
+        entropy = -(probabilities * ops.log_softmax(scores, axis=-1)).sum(axis=-1)
+        eta = ops.minimum(self.epsilon, ops.sqrt(self.epsilon) * ops.exp(-entropy))[..., None]
+        indices_to_remove = probabilities < eta
+
+        # Keep the words with the 'min_tokens_to_keep'-highest probabilities
+        top_k = min(self.min_tokens_to_keep, scores.shape[-1])  # Safety check
+        indices_to_remove = indices_to_remove & (scores < ops.topk(scores, top_k)[0][..., -1, None])
+
+        scores = scores.masked_fill(indices_to_remove, self.filter_value)
+        return scores
+
+class SequenceBiasLogitsProcessor(LogitsProcessor):
+    """
+    [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
+    when the next generated token can complete it. Consequently, to get the most out of biasing sequences with more
+    than one token, consider using beam methods (to gracefully work around partially completed sequences that have a
+    negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).
+    """
+
+    def __init__(self, sequence_bias):
+        self.sequence_bias = sequence_bias
+        self._validate_arguments()
+
+        # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size
+        # is inferred in the first usage, which inhibits initializing here)
+        self.length_1_bias = None
+        self.prepared_bias_variables = False
+
+    def __call__(self, input_ids, scores):
+        # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
+        if not self.prepared_bias_variables:
+            self._prepare_bias_variables(scores)
+
+        # 2 - prepares an empty bias to add
+        bias = ops.zeros_like(scores)
+
+        # 3 - include the bias from length = 1
+        bias += self.length_1_bias
+
+        # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
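+        # For example, with a biased sequence (5, 8, 3) and bias b: whenever the last two tokens of a row in
+        # `input_ids` equal (5, 8), b is added to the logit of token 3 for that row, encouraging or discouraging
+        # the completion of the sequence depending on the sign of b.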
+        for sequence_ids, sequence_bias in self.sequence_bias.items():
+            if len(sequence_ids) == 1:  # the sequence is of length 1, already applied
+                continue
+            if len(sequence_ids) > input_ids.shape[1]:  # the sequence is longer than the context, ignore
+                continue
+            prefix_length = len(sequence_ids) - 1
+            last_token = sequence_ids[-1]
+            matching_rows = ops.eq(
+                input_ids[:, -prefix_length:],
+                mindspore.tensor(sequence_ids[:-1], dtype=input_ids.dtype),
+            ).prod(axis=1)
+            bias[:, last_token] += ops.where(
+                matching_rows.bool(),
+                mindspore.tensor(sequence_bias),
+                mindspore.tensor(0.0),
+            )
+
+        # 5 - apply the bias to the scores
+        scores = scores + bias
+        return scores
+
+class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
+    r"""
+    [`LogitsProcessor`] enforcing alternated generation between the two codebooks of [`Bark`]'s fine submodel.
+    """
+
+    def __init__(self, input_start_len, semantic_vocab_size, codebook_size):
+        if not isinstance(input_start_len, int) or input_start_len < 0:
+            raise ValueError(f"`input_start_len` has to be a non-negative integer, but is {input_start_len}")
+
+        self.input_start_len = input_start_len
+        self.semantic_vocab_size = semantic_vocab_size
+        self.codebook_size = codebook_size
+
+    def __call__(self, input_ids, scores):
+        curr_len = input_ids.shape[-1]
+
+        # even -> first codebook, odd -> second codebook
+        is_first_codebook = ((curr_len - self.input_start_len) % 2) == 0
+
+        if is_first_codebook:
+            scores[:, : self.semantic_vocab_size] = -float("inf")
+            scores[:, self.semantic_vocab_size + self.codebook_size :] = -float("inf")
+        else:
+            scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")
+
+        return scores
+
+class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
+    r"""Logits processor for Classifier-Free Guidance (CFG). The processor
+    computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
+    parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
+    the `unconditional_ids` branch.
+
+    See [the paper](https://arxiv.org/abs/2306.17806) for more information.
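+    The processor keeps its own `past_key_values` cache for the unconditional branch, so when `use_cache` is enabled
+    only the newly generated token is fed to `model` at each decoding step.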
+ """ + + def __init__(self, guidance_scale, model, unconditional_ids, unconditional_attention_mask, use_cache): + self.guidance_scale = guidance_scale + self.model = model + self.unconditional_context = { + "input_ids": unconditional_ids, + "attention_mask": unconditional_attention_mask, + "use_cache": use_cache, + "past_key_values": None, + "first_pass": True, + } + + def get_unconditional_logits(self, input_ids): + """get_unconditional_logits""" + if self.unconditional_context["first_pass"]: + if self.unconditional_context["input_ids"] is None: + self.unconditional_context["input_ids"] = input_ids[:, -1:] + if self.unconditional_context["attention_mask"] is None: + self.unconditional_context["attention_mask"] = ops.ones_like( + self.unconditional_context["input_ids"], dtype=mindspore.int64 + ) + input_ids = self.unconditional_context["input_ids"] + attention_mask = self.unconditional_context["attention_mask"] + self.unconditional_context["first_pass"] = False + else: + attention_mask = ops.cat( + [ + self.unconditional_context["attention_mask"], + ops.ones_like(input_ids[:, -1:], dtype=mindspore.int64), + ], + axis=1, + ) + if not self.unconditional_context["use_cache"]: + input_ids = ops.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], axis=1) + else: + input_ids = input_ids[:, -1:] + self.unconditional_context["input_ids"] = input_ids + self.unconditional_context["attention_mask"] = attention_mask + + out = self.model( + input_ids, + attention_mask=attention_mask, + use_cache=self.unconditional_context["use_cache"], + past_key_values=self.unconditional_context["past_key_values"], + ) + self.unconditional_context["past_key_values"] = out.get("past_key_values", None) + + return out.logits + + def __call__(self, input_ids, scores): + scores = ops.log_softmax(scores, axis=-1) + if self.guidance_scale == 1: + return scores + + logits = self.get_unconditional_logits(input_ids) + + unconditional_logits = ops.log_softmax(logits[:, -1], axis=-1) + out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits + return out + +class WhisperTimeStampLogitsProcessor(LogitsProcessor): + r""" + + [`LogitsProcessor`] that modifies the logits for the generation of timestamps in the transcription. When the input + tokens are at a specific threshold, the processor sets the scores to negative infinity. The processor makes sure + that timestamp tokens appear in pairs, by masking out the logits that would break this pairing pattern. This is + done to maintain the consistency and structure of generated timestamps. It also ensures that when the predicted + probability of sampling any of the timestamp token is greater than any individual non-timestamp token, those + non-timestamp logits are set to negative infinity. This is done to ensure the generation of timestamps over other + potential tokens. + + + See [the paper](https://arxiv.org/abs/2212.04356) for more information. 
+ """ + + def __init__(self, generate_config): # support for the kwargs + self.eos_token_id = generate_config.eos_token_id + self.no_timestamps_token_id = generate_config.no_timestamps_token_id + self.timestamp_begin = generate_config.no_timestamps_token_id + 1 + + self.begin_index = len(generate_config.forced_decoder_ids) + 2 + if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id: + self.begin_index -= 1 + self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index + + def __call__(self, input_ids, scores): + # suppress <|notimestamps|> which is handled by without_timestamps + scores[:, self.no_timestamps_token_id] = -float("inf") + + if input_ids.shape[1] == self.begin_index - 1: + scores[:, :] = -float("inf") + scores[:, self.timestamp_begin] = 0 + return scores + + # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly + for k in range(input_ids.shape[0]): + seq = list(input_ids[k, self.begin_index :].tolist()) + last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin + penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin + + if last_was_timestamp: + if penultimate_was_timestamp: # has to be non-timestamp + scores[k, self.timestamp_begin :] = -float("inf") + else: # cannot be normal text tokens + scores[k, : self.eos_token_id] = -float("inf") + + # apply the `max_initial_timestamp` option + if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None: + last_allowed = self.timestamp_begin + self.max_initial_timestamp_index + scores[:, last_allowed + 1 :] = -float("inf") + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = ops.log_softmax(scores.float(), axis=-1) + for k in range(input_ids.shape[0]): + timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(axis=-1) + max_text_token_logprob = logprobs[k, : self.timestamp_begin].max() + if not ops.isnan(timestamp_logprob) and timestamp_logprob > max_text_token_logprob: + scores[k, : self.timestamp_begin] = -float("inf") + + return scores + +class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): + r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`. + """ + + def __init__(self, eos_token_id, min_eos_p): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id = eos_token_id + if min_eos_p is not None and min_eos_p <= 0: + raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}") + self.min_eos_p = min_eos_p + + def __call__(self, input_ids, scores): + if self.min_eos_p: + probs = ops.softmax(scores.float(), axis=-1) + # create scores full of -inf except for the eos_token_id + early_stop_scores = ops.ones_like(scores) * -float("inf") + early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] + + do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p + do_early_stop = ops.any(do_early_stop, axis=1, keep_dims=True) + scores = ops.where(do_early_stop, early_stop_scores, scores) + + return scores + + +class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, + where the first half correspond to the conditional logits (predicted from the input prompt) and the second half + correspond to the unconditional logits (predicted from an empty or 'null' prompt). 
The processor computes a + weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`. + + See [the paper](https://arxiv.org/abs/2306.05284) for more information. + """ + + def __init__(self, guidance_scale): + if guidance_scale > 1: + self.guidance_scale = guidance_scale + else: + raise ValueError( + "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale " + f"{guidance_scale}." + ) + + def __call__(self, input_ids, scores): + # simple check to make sure we have compatible batch sizes between our + # logits scores (cond + uncond) and input ids (cond only) + if scores.shape[0] != 2 * input_ids.shape[0]: + raise ValueError( + f"Logits should have twice the batch size of the input ids, the first half of batches " + "corresponding to the conditional inputs, and the second half of batches corresponding to " + f"the unconditional inputs. Got batch size {scores.shape[0]} for the logits and " + f" {input_ids.shape[0]} for the input ids." + ) + unguided_bsz = scores.shape[0] // 2 + cond_logits, uncond_logits = scores.split(unguided_bsz, axis=0) + scores_processed = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale + return scores_processed diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py new file mode 100644 index 000000000..2e19f36b2 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py @@ -0,0 +1,718 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""progen""" + +import math +import time +from typing import Tuple +from collections import OrderedDict + +import mindspore as ms +from mindspore import Tensor, nn, ops, Parameter +from mindspore.nn import CrossEntropyLoss + +from .module.injection import EmbeddingMindnlp, DenseMindnlp, LayerNormMindnlp +from .module.configuration_utils import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + PreTrainedModelMindnlp, + PretrainedConfig, +) + + +class ClassInstantier(OrderedDict): + r""" + Class Instantier + """ + + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + """ + Excitation equation matrix + """ + 'relu': nn.ReLU, + 'gelu': (nn.GELU, {"approximate": False}), + 'gelu_new': nn.GELU, + 'gelu_approximate': nn.GELU, + 'gelu_pytorch_tanh': nn.GELU, + "swish": nn.SiLU, + "gelu_10": nn.GELU, + "gelu_fast": nn.FastGelu, + "gelu_python": nn.GELU, + "linear": nn.ReLU, + "mish": nn.Mish, + "quick_gelu": nn.FastGelu, + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "sigmoid": nn.Sigmoid, + "silu": nn.SiLU, + "tanh": nn.Tanh, +} +ACT2FN = ClassInstantier(ACT2CLS) + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (ops.arange(0, dim, 2.0) / (dim * 1.0))) + sinusoid_inp = ops.einsum("i , j -> i j", ops.arange(seq_len), inv_freq).float() + return ops.sin(sinusoid_inp), ops.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = ops.stack((-x2, x1), axis=-1) + x_flatten = x.flatten(start_dim=-2) + return x_flatten # in einsum notation: rearrange(x, '... d j -> ... 
(d j)') + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class ProGenConfig(PretrainedConfig): + """ + ProGenConfig class + """ + model_type = "progen" + + def __init__( + self, vocab_size, n_positions, n_ctx, n_embd, n_layer, n_head, rotary_dim, + n_inner, activation_function, resid_pdrop, embd_pdrop, attn_pdrop, + layer_norm_epsilon, initializer_range, scale_attn_weights, gradient_checkpointing, + use_cache, bos_token_id, eos_token_id, min_length, **kwargs, + ): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.rotary_dim = rotary_dim + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.gradient_checkpointing = gradient_checkpointing + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.min_length = min_length + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.n_embd + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer + + +class NewGELUActivation(nn.Cell): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + def construct(self, inputs: Tensor) -> Tensor: + return 0.5 * inputs * (1.0 + ops.tanh(math.sqrt(2.0 / math.pi) * (inputs + 0.044715 * ops.pow(inputs, 3.0)))) + + +class ProGenAttention(nn.Cell): + """ + ProGenAttention class + """ + def __init__(self, config): + super().__init__() + + max_positions = config.max_position_embeddings + temp_tri = ops.tril( + ops.ones((max_positions, max_positions), dtype=ms.int32)).view( + 1, 1, max_positions, max_positions + ) + self.bias = Parameter( + Tensor(temp_tri, dtype=ms.bool_), + name="bias", + requires_grad=False, + ) + + self.masked_bias = Parameter( + Tensor(-1e9), + name="masked_bias", + requires_grad=False, + ) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`:\ + {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." 
+ ) + self.scale_attn = ops.sqrt(Tensor(self.head_dim, dtype=ms.float32)).to(ms.float16) + self.qkv_proj = DenseMindnlp(self.embed_dim, self.embed_dim * 3, has_bias=False) + + self.out_proj = DenseMindnlp(self.embed_dim, self.embed_dim, has_bias=False) + self.rotary_dim = None + if config.rotary_dim is not None: + self.rotary_dim = config.rotary_dim + + + def construct( + self, + hidden_states, + use_cache=False, + output_attentions=False, + add_input=None, + ): + """ + construct method of ProGenAttention + """ + layer_past, attention_mask, head_mask = add_input + qkv = self.qkv_proj(hidden_states) + mp_num = 8 + qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) + + local_dim = self.head_dim * self.num_attention_heads // mp_num + query, value, key = ops.split(qkv_split, local_dim, axis=-1) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) + + value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) + value = value.permute(0, 2, 1, 3) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = ops.cat((k_rot, k_pass), axis=-1) + query = ops.cat((q_rot, q_pass), axis=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = ops.cat((past_key, key), axis=-2) + value = ops.cat((past_value, value), axis=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + def _split_heads(self, x, n_head, dim_head, mp_num): + reshaped = x.reshape(x.shape[: -1] + (n_head // mp_num, dim_head)) + reshaped = reshaped.reshape(x.shape[: -2] + (-1,) + reshaped.shape[-1: ]) + return reshaped + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into n_ctx + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.shape[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + """ + _attn method of ProGenAttention 
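+        Computes causal scaled dot-product attention: the query/key product is evaluated in float32, masked with the
+        cached causal `bias` buffer (and the optional additive `attention_mask`) before the softmax, and the optional
+        multiplicative `head_mask` is applied to the attention weights.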
+ """ + # compute causal mask from causal mask buffer + query_length, key_length = query.shape[-2], key.shape[-2] + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = Tensor(query, dtype=ms.float32) + key = Tensor(key, dtype=ms.float32) + + new_key = key.transpose(0, 1, 3, 2) + attn_weights = ops.matmul(query, new_key) + + attn_weights = attn_weights / self.scale_attn + attn_weights = ops.where(causal_mask, attn_weights, self.masked_bias) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.Softmax(axis=-1)(attn_weights) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = ops.matmul(attn_weights, value) + + return attn_output, attn_weights + + +class ProGenMLP(nn.Cell): + """ + ProGenMLP class + """ + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + embed_dim = config.n_embd # n_embd=4096 + + self.fc_in = DenseMindnlp(embed_dim, intermediate_size) + self.fc_out = DenseMindnlp(intermediate_size, embed_dim) + + self.act = NewGELUActivation() + self.dropout = nn.Dropout(config.resid_pdrop) # config.resid_pdrop=0.0 + + def construct(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class ProGenBlock(nn.Cell): + """ + ProGenBlock class + """ + def __init__(self, config): + super().__init__() + if config.n_inner is not None and config.n_inner != "None": + inner_dim = config.n_inner + else: + inner_dim = 4 * config.n_embd + self.ln_1 = LayerNormMindnlp(config.n_embd, epsilon=config.layer_norm_epsilon) + self.attn = ProGenAttention(config) + self.mlp = ProGenMLP(inner_dim, config) + + def construct( + self, + hidden_states, + use_cache=False, + output_attentions=False, + add_input=None, + ): + """ + construct method of ProGenBlock + """ + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + add_input=add_input, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +class ProGenPreTrainedModel(PreTrainedModelMindnlp): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + config_class = ProGenConfig + base_model_prefix = "transformer" + is_parallelizable = True + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (DenseMindnlp,)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, EmbeddingMindnlp): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LayerNormMindnlp): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + else: + return + + +class ProGenModel(ProGenPreTrainedModel): + """ + ProGenModel class + """ + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = EmbeddingMindnlp(config.vocab_size, self.embed_dim) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.CellList([ProGenBlock(config) for _ in range(config.n_layer)]) + self.ln_f = LayerNormMindnlp(self.embed_dim, epsilon=config.layer_norm_epsilon) + self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads) + + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def construct( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + construct method of progen model + """ + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + if input_ids is not None: + input_shape = input_ids.shape + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.shape[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].shape[-2] + + if position_ids is None: + position_ids = ops.arange(past_length, input_shape[-1] + past_length) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. 
+ # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.shape[-1],) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + use_cache = False + + else: + add_input = (layer_past, attention_mask, head_mask[i]) + outputs = block( + hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + add_input=add_input + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ProGenForCausalLM(ProGenPreTrainedModel): + """ + ProGenForCausalLM class + """ + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = ProGenModel(config) + self.lm_head = DenseMindnlp(config.n_embd, config.vocab_size) + self.loss_fct = CrossEntropyLoss() + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[Tensor]], beam_idx: Tensor) -> Tuple[Tuple[Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. 
This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. + """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + def construct( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + """ + construct method of ProGenForCausalLM class + """ + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).to(ms.float32) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss = self.loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1).astype("int32")) + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + + def get_output_embeddings(self): + return None + + def set_output_embeddings(self, new_embeddings): + return + + def prepare_inputs_for_generation_new(self, input_ids, past_key_values=None, **kwargs): + """ + new method of preparing inputs for generation + """ + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + +class PrintTime: + def __init__(self, desc): + self.desc = desc + + def __enter__(self): + print(self.desc) + self.t = time.time() + + def __exit__(self, input_type, value, traceback): + print(f'{self.desc} took {time.time()-self.t:.02f}s') diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py new file mode 100644 index 000000000..8f01cf3aa --- /dev/null +++ 
b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py @@ -0,0 +1,282 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""progen""" + +import os +import random + +import mindspore as ms +from mindspore import ops, Tensor +from tokenizers import Tokenizer + +from .nn_arch import ProGenForCausalLM, ProGenConfig, PrintTime +from ..model import Model + + +class ProGen(Model): + """ + ProGen class + """ + name = "ProGen" + + def __init__(self, config): + self.args = config + self.config = ProGenConfig( + config.vocab_size, config.n_positions, config.n_ctx, config.n_embd, + config.n_layer, config.n_head, config.rotary_dim, config.n_inner, + config.activation_function, config.resid_pdrop, config.embd_pdrop, + config.attn_pdrop, config.layer_norm_epsilon, config.initializer_range, + config.scale_attn_weights, config.gradient_checkpointing, config.use_cache, + config.bos_token_id, config.eos_token_id, config.min_length, + ) + self.mixed_precision = False + self.network = ProGenForCausalLM(self.config) + self.checkpoint_url = "Local_Checkpoint_Used" + self.FIRST_TOKEN = 5 + self.LAST_TOKEN = 29 + self.BOS_TOKEN = 3 + self.EOS_TOKEN = 4 + self.checkpoint_path = config.ckpt_dir + with PrintTime('loading tokenizer'): + self.tokenizer = self.create_tokenizer_custom(file=self.args.tokenizer_file) + self.set_env() + self.set_seed(self.args.rng_seed, deterministic=self.args.rng_deterministic) + + super().__init__(checkpoint_url=self.checkpoint_url, checkpoint_path=self.checkpoint_path, + network=self.network, mixed_precision=self.mixed_precision) + + def set_env(self): + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + def set_seed(self, seed, deterministic=True): + print("deterministic", deterministic) + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + + def create_tokenizer_custom(self, file): + with open(file, 'r') as f: + return Tokenizer.from_str(f.read()) + + def cross_entropy(self, logits, target, reduction='mean'): + return ops.cross_entropy(input=logits, target=target, reduction=reduction) + + def log_likelihood(self, logits, target, reduction='mean'): + return -self.cross_entropy(logits.view(-1, logits.shape[-1]), target.view(-1), reduction=reduction) + + def log_likelihood_custom_1(self, logits, target, reduction='mean'): + return -ops.nll_loss(inputs=ops.log_softmax(logits, axis=1), target=target, reduction=reduction) + + def log_likelihood_custom_2(self, logits, target, reduction='mean'): + log_likelihood = 0.0 + n = logits.shape[0] + for i in range(n): + log_likelihood += ops.log_softmax(logits, axis=1)[i, target[i]] / (1. if reduction == 'sum' else n) + return log_likelihood + + def cal_cross_entropy(self, tokens): + target = ms.tensor(self.tokenizer.encode(tokens).ids) + logits = self.network(target, labels=target).logits + # shift + logits = logits[:-1, ...] 
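+        # next-token objective: the logit at position i predicts the token at position i + 1, so the last logit row
+        # (above) and the first target token (below) are dropped before computing the loss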
+ target = target[1:] + return self.cross_entropy(logits=logits, target=target).item() + + def ll(self, tokens, f, reduction): + """ + shift, remove terminals and remove unused logits + """ + target = Tensor(self.tokenizer.encode(tokens).ids) + logits = self.network(target, labels=target).logits + + # shift + logits = logits[:-1, ...] + target = target[1:] + + # remove terminals + if target[-1] in [self.BOS_TOKEN, self.EOS_TOKEN]: + logits = logits[:-1, ...] + target = target[:-1] + + # remove unused logits + logits = logits[:, self.FIRST_TOKEN:(self.LAST_TOKEN + 1)] + target = target - self.FIRST_TOKEN + + return f(logits=logits, target=target, reduction=reduction).item() + + def from_pretrained(self, ckpt_path=None): + "from_pretrained" + self.get_checkpoint_path(ckpt_path) + if not ckpt_path: + param_dict = ms.load_checkpoint(self.checkpoint_path) + else: + param_dict = ms.load_checkpoint(ckpt_path) + param_not_load, _ = ms.load_param_into_net(self.network, param_dict) + print(f'param not load: {param_not_load}') + + def predict(self, data, **kwargs): + if self.args.sanity: + + with PrintTime('sanity cross-entropy'): + + x_uniref90bfd30 = self.args.x_uniref90bfd30 + x_oas = self.args.x_oas + x_bfd90 = self.args.x_bfd90 + + checkpoint_x_ce = { + 'progen2-small': (x_uniref90bfd30, 2.4), + 'progen2-medium': (x_uniref90bfd30, 1.9), + 'progen2-base': (x_uniref90bfd30, 1.9), + 'progen2-large': (x_uniref90bfd30, 1.8), + 'progen2-xlarge': (x_uniref90bfd30, 1.0), + 'progen2-oas': (x_oas, 0.3), + 'progen2-BFD90': (x_bfd90, 1.3), + } + + ce_eval = self.cal_cross_entropy(checkpoint_x_ce[self.args.model][0]) + ce_target = checkpoint_x_ce[self.args.model][1] + + print(ce_target, ce_eval, abs(ce_eval - ce_target)) + + with PrintTime('sanity log-likelihood'): + + x_data = self.args.x_data + + ll_0 = self.ll(x_data, f=self.log_likelihood, reduction='mean') + ll_1 = self.ll(x_data, f=self.log_likelihood_custom_1, reduction='mean') + ll_2 = self.ll(x_data, f=self.log_likelihood_custom_2, reduction='mean') + + print(f'll_0={ll_0}') + print(f'll_1={ll_1}') + print(f'll_2={ll_2}') + + with PrintTime('sanity model'): + + ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', + 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] + x_data = self.args.x_data + x_random = '2' + ''.join([random.choice(ALPHABET) for _ in range(len(x_data)-2)]) + '1' + x_perturb = x_random[:64] + x_data[len(x_random[:64]):] + + ll_x_data = self.ll(x_data, f=self.log_likelihood, reduction='mean') + ll_x_random = self.ll(x_random, f=self.log_likelihood, reduction='mean') + ll_x_perturb = self.ll(x_perturb, f=self.log_likelihood, reduction='mean') + + print(f'll_x_data={ll_x_data}') + print(f'll_x_random={ll_x_random}') + print(f'll_x_perturb={ll_x_perturb}') + + with PrintTime('log-likelihood (left-to-right, right-to-left)'): + + reverse = lambda s: s[::-1] + + ll_lr_sum = self.ll(tokens=data, f=self.log_likelihood, reduction='sum') + ll_rl_sum = self.ll(tokens=reverse(data), f=self.log_likelihood, reduction='sum') + + ll_lr_mean = self.ll(tokens=data, f=self.log_likelihood, reduction='mean') + ll_rl_mean = self.ll(tokens=reverse(data), f=self.log_likelihood, reduction='mean') + + ll_sum = .5 * (ll_lr_sum + ll_rl_sum) + ll_mean = .5 * (ll_lr_mean + ll_rl_mean) + + print(f'll_sum={(ll_sum)}') + print(f'll_mean={ll_mean}') + + def sample(self, model, tokenizer, context, max_length, num_return_sequences, top_p, temp, pad_token_id): + """ + sample method of progen model + """ + input_ids = 
Tensor([tokenizer.encode(context).ids]) + tokens_batch = model.generate( + input_ids, + do_sample=True, + temperature=temp, + max_length=max_length, + top_p=top_p, + num_return_sequences=num_return_sequences, + pad_token_id=pad_token_id, + ) + as_lists = lambda batch: [batch[i, ...].asnumpy().tolist() for i in range(batch.shape[0])] + return tokenizer.decode_batch(as_lists(tokens_batch)) + + def truncate(self, input_sample, terminals): + """ + truncate method + """ + pos = [] + for terminal in terminals: + find_pos = input_sample.find(terminal, 1) + if find_pos != -1: + pos.append(find_pos) + if pos: + return input_sample[:(min(pos) + 1)] + return input_sample + + def generate(self): + """ + generate method + """ + if self.args.sanity: + with PrintTime('sanity cross-entropy'): + + x_uniref90bfd30 = self.args.x_uniref90bfd30 + x_oas = self.args.x_oas + x_bfd90 = self.args.x_bfd90 + + checkpoint_x_ce = { + 'progen2-small': (x_uniref90bfd30, 2.4), + 'progen2-medium': (x_uniref90bfd30, 1.9), + 'progen2-base': (x_uniref90bfd30, 1.9), + 'progen2-large': (x_uniref90bfd30, 1.8), + 'progen2-xlarge': (x_uniref90bfd30, 1.0), + 'progen2-oas': (x_oas, 0.3), + 'progen2-BFD90': (x_bfd90, 1.3), + } + + ce_eval = self.cal_cross_entropy(checkpoint_x_ce.get(self.args.model)[0]) + ce_target = checkpoint_x_ce.get(self.args.model)[1] + + if abs(ce_eval - ce_target) >= 0.1: + raise ValueError("Difference should be within 0.1") + + with PrintTime('sampling'): + completions = self.sample(model=self.network, tokenizer=self.tokenizer, context=self.args.context, + pad_token_id=self.tokenizer.encode('<|pad|>').ids[0], + num_return_sequences=self.args.num_samples, temp=self.args.t, + top_p=self.args.p, max_length=self.args.max_length) + truncations = [self.truncate(completion, terminals=['1', '2']) for completion in completions] + + print(self.args.num_samples, "sequences are sampled in total.") + for (i, truncation) in enumerate(truncations): + print("The generated sequence with index", i, "is: ", truncation) + + + def train_step(self, data): + return None + + + def forward(self, data): + return None + + + def backward(self, data): + return None + + + def _jit_forward(self, data): + return None + + + def _pynative_forward(self, data): + return None diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py new file mode 100644 index 000000000..b08eca21e --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py @@ -0,0 +1,20 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""progen_configuration""" + +progen_configuration = { + "small": + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/ProGen/small.yaml", +} diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py new file mode 100644 index 000000000..f8232e64c --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py @@ -0,0 +1,46 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""progen_dataset""" + +from .nn_arch import PrintTime +from ...dataset import PSP + + +class ProGenDataSet(PSP): + "EquiDockDataSet" + def __init__(self, config): + self.config = config + super().__init__() + + def process(self, data, **kwargs): + return data + + def set_training_data_src(self, data_source, **kwargs): + with PrintTime('set_training_data_src'): + print(data_source) + print(**kwargs) + + def create_iterator(self, num_epochs, **kwargs): + return None + + def data_parse(self, idx): + return None + + #pylint: disable=arguments-differ + def __getitem__(self, idx): + pass + + def __len__(self): + pass diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json b/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json new file mode 100644 index 000000000..cd3d5a9d0 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json @@ -0,0 +1,91 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "special": true, + "content": "<|pad|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 1, + "special": true, + "content": "<|bos|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 2, + "special": true, + "content": "<|eos|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true + }, + "post_processor": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "vocab": { + "<|pad|>": 0, + "<|bos|>": 1, + "<|eos|>": 2, + "1": 3, + "2": 4, + "A": 5, + "B": 6, + "C": 7, + "D": 8, + "E": 9, + "F": 10, + "G": 11, + "H": 12, + "I": 13, + "K": 14, + "L": 15, + "M": 16, + "N": 17, + "O": 18, + "P": 19, + "Q": 20, + "R": 21, + "S": 22, + "T": 23, + "U": 24, + "V": 25, + "W": 26, + "X": 27, + "Y": 28, + "Z": 29 + }, + "merges": [] + } +} \ No newline at 
end of file diff --git a/MindSPONGE/src/mindsponge/pipeline/pipeline.py b/MindSPONGE/src/mindsponge/pipeline/pipeline.py index 60a41859b..850a55fb6 100644 --- a/MindSPONGE/src/mindsponge/pipeline/pipeline.py +++ b/MindSPONGE/src/mindsponge/pipeline/pipeline.py @@ -35,10 +35,11 @@ from .models import RASP, RASPDataSet, rasp_configuration from .models import Multimer, MultimerDataSet, multimer_configuration from .models import ProteinMpnn, ProteinMpnnDataset, proteinmpnn_configuration from .models import UFold, UFoldDataSet, ufold_configuration +from .models import EquiDock, EquiDockDataSet, equidock_configuration +from .models import ProGen, ProGenDataSet, progen_configuration from .models import ProtT5, ProtT5TrainDataSet, prott5pretrain_configuration from .models import ProtT5DownstreamTasks, ProtT5TaskDataSet, prott5downtask_configuration - model_card = { "ColabDesign": {"model": COLABDESIGN, "dataset": ColabDesignDataSet, "config": colabdesign_configuration}, "DeepDR": {"model": DeepDR, "dataset": DeepDRDataSet, "config": deepdr_configuration}, @@ -55,6 +56,8 @@ model_card = { "Proteinmpnn": {"model": ProteinMpnn, "dataset": ProteinMpnnDataset, "config": proteinmpnn_configuration}, "RASP": {"model": RASP, "dataset": RASPDataSet, "config": rasp_configuration}, "UFold": {"model": UFold, "dataset": UFoldDataSet, "config": ufold_configuration}, + "EquiDock": {"model": EquiDock, "dataset": EquiDockDataSet, "config": equidock_configuration}, + "ProGen": {"model": ProGen, "dataset": ProGenDataSet, "config": progen_configuration}, "ProtT5": {"model": ProtT5, "dataset": ProtT5TrainDataSet, "config": prott5pretrain_configuration}, "ProtT5Downstream": {"model": ProtT5DownstreamTasks, "dataset": ProtT5TaskDataSet, "config": prott5downtask_configuration} -- Gitee From b0660638fbae70fbdc0a55e8dcfb5cf5f25c497a Mon Sep 17 00:00:00 2001 From: yzhang Date: Wed, 27 Nov 2024 16:57:59 +0800 Subject: [PATCH 3/5] add EquiDock and Progen --- .../src/mindsponge/pipeline/models/progen/.keep | 0 .../src/mindsponge/pipeline/models/progen/progen.py | 13 +++++++------ 2 files changed, 7 insertions(+), 6 deletions(-) delete mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/.keep diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/.keep b/MindSPONGE/src/mindsponge/pipeline/models/progen/.keep deleted file mode 100644 index e69de29bb..000000000 diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py index 8f01cf3aa..dc71fa63e 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py @@ -24,6 +24,13 @@ from tokenizers import Tokenizer from .nn_arch import ProGenForCausalLM, ProGenConfig, PrintTime from ..model import Model +FIRST_TOKEN = 5 +LAST_TOKEN = 29 +BOS_TOKEN = 3 +EOS_TOKEN = 4 +ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', + 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] + class ProGen(Model): """ @@ -44,10 +51,6 @@ class ProGen(Model): self.mixed_precision = False self.network = ProGenForCausalLM(self.config) self.checkpoint_url = "Local_Checkpoint_Used" - self.FIRST_TOKEN = 5 - self.LAST_TOKEN = 29 - self.BOS_TOKEN = 3 - self.EOS_TOKEN = 4 self.checkpoint_path = config.ckpt_dir with PrintTime('loading tokenizer'): self.tokenizer = self.create_tokenizer_custom(file=self.args.tokenizer_file) @@ -163,8 +166,6 @@ class ProGen(Model): with PrintTime('sanity model'): - ALPHABET = ['A', 'B', 'C', 
'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', - 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] x_data = self.args.x_data x_random = '2' + ''.join([random.choice(ALPHABET) for _ in range(len(x_data)-2)]) + '1' x_perturb = x_random[:64] + x_data[len(x_random[:64]):] -- Gitee From 542391228800af689f6aa1e0fffef469df4e9c78 Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Wed, 27 Nov 2024 20:49:21 +0800 Subject: [PATCH 4/5] add EquiDock and ProGen models --- MindSPONGE/src/mindsponge/pipeline/models/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py index 3a01e8077..8df8379a5 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py @@ -37,3 +37,6 @@ from .ufold import UFold, UFoldDataSet, ufold_configuration from .rasp import RASP, RASPDataSet, rasp_configuration from .prot_t5 import ProtT5, ProtT5TrainDataSet, prott5pretrain_configuration from .prot_t5 import ProtT5DownstreamTasks, ProtT5TaskDataSet, prott5downtask_configuration +from .equidock import EquiDock, EquiDockDataSet, equidock_configuration +from .progen import ProGen, ProGenDataSet, progen_configuration + -- Gitee From 4a108d359a4ab5dd6fee0a3295a15583cd238904 Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Wed, 27 Nov 2024 20:56:13 +0800 Subject: [PATCH 5/5] add new model ProGen and EquiDock --- MindSPONGE/src/mindsponge/pipeline/models/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py index 8df8379a5..806c476bc 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py @@ -39,4 +39,3 @@ from .prot_t5 import ProtT5, ProtT5TrainDataSet, prott5pretrain_configuration from .prot_t5 import ProtT5DownstreamTasks, ProtT5TaskDataSet, prott5downtask_configuration from .equidock import EquiDock, EquiDockDataSet, equidock_configuration from .progen import ProGen, ProGenDataSet, progen_configuration - -- Gitee
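
A quick way to sanity-check the FIRST_TOKEN/LAST_TOKEN/BOS_TOKEN/EOS_TOKEN constants introduced above against the
bundled tokenizer.json (a minimal sketch; the relative file path and the example sequence are illustrative):

    from tokenizers import Tokenizer

    # mirrors ProGen.create_tokenizer_custom in progen.py
    with open("MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json", "r") as f:
        tokenizer = Tokenizer.from_str(f.read())

    ids = tokenizer.encode("1MKLV2").ids
    # per the vocab: '1' -> 3 (BOS_TOKEN), '2' -> 4 (EOS_TOKEN), 'A'..'Z' -> 5..29 (FIRST_TOKEN..LAST_TOKEN)
    print(ids)                    # expected: [3, 16, 14, 15, 25, 4]
    print(tokenizer.decode(ids))  # expected: 1MKLV2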