diff --git a/MindSPONGE/requirements.txt b/MindSPONGE/requirements.txt
index 9b940d1b58dbe91fe925d1b48212b4d50e90064b..174bca955d4bcb333c63557895df9fb49ee37a6c 100644
--- a/MindSPONGE/requirements.txt
+++ b/MindSPONGE/requirements.txt
@@ -1,5 +1,6 @@
numpy >= 1.17.0,<=1.23.4
scipy >= 1.7.0
+biopandas == 0.4.1
biopython == 1.81
pyyaml >= 5.4.1
dataclasses >= 0.6
@@ -9,8 +10,11 @@ absl-py >= 1.1.0
biotite == 0.40.0
descriptastorus == 2.6.1
pyparsing >= 3.0.7
+POT == 0.9.3
+tokenizers
+joblib
rdkit
bio
scikit-learn
mindformers >= 1.2.0
-sentencepiece >= 0.2.0
\ No newline at end of file
+sentencepiece >= 0.2.0
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py
index 3a01e80770f0372da696b2f99c7f1da816e56252..806c476bc316640d0dd3bae5620b4f7a728bc304 100644
--- a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py
+++ b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py
@@ -37,3 +37,5 @@ from .ufold import UFold, UFoldDataSet, ufold_configuration
from .rasp import RASP, RASPDataSet, rasp_configuration
from .prot_t5 import ProtT5, ProtT5TrainDataSet, prott5pretrain_configuration
from .prot_t5 import ProtT5DownstreamTasks, ProtT5TaskDataSet, prott5downtask_configuration
+from .equidock import EquiDock, EquiDockDataSet, equidock_configuration
+from .progen import ProGen, ProGenDataSet, progen_configuration
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c604d06e3821c6826fb93c7cfffe8858dde27ac
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""equidock"""
+
+from .equidock import EquiDock
+from .equidock_dataset import EquiDockDataSet
+from .equidock_configuration import equidock_configuration
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f951efb5e8797ba834f67d2415bc54d8ca140c3
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py
@@ -0,0 +1,383 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""equidock"""
+
+import os
+import random
+from datetime import datetime as dt
+
+import numpy as np
+from biopandas.pdb import PandasPdb
+import mindspore as ms
+from mindspore import nn, Tensor, ops, save_checkpoint
+from mindspore.experimental import optim
+
+from .train_utils import create_dir, prepare_graphs, graph_to_tensor, \
+ pretty_print_stats, get_rot_mat, compute_sq_dist_mat, compute_ot_emd, \
+ compute_body_intersection_loss
+from .nn_arch import MeterUnboundBound, RigidBodyDockingNet, log, FLAGS
+from .equidock_dataset import EquiDockDataSet
+from ..model import Model
+
+
+class EquiDock(Model):
+ """
+    Pipeline wrapper for the EquiDock rigid-body protein-protein docking network
+    (RigidBodyDockingNet), handling training, evaluation and inference.
+ """
+ name = "EquiDock"
+
+ def __init__(self, config):
+
+ self.config = config
+ self.mixed_precision = False
+ self.network = RigidBodyDockingNet(self.config)
+ self.log(self.network)
+ self.checkpoint_url = "Local_Checkpoint_Used"
+ self.checkpoint_path = self.config.ckpt_dir
+
+ if self.config.is_train:
+ self.config.cache_path = './cache/' + self.config.data + '_' + \
+ self.config.graph_nodes + '_maxneighbor_' + str(self.config.graph_max_neighbor) + \
+ '_cutoff_' + str(self.config.graph_cutoff) + '_pocketCut_' + \
+ str(self.config.pocket_cutoff) + '/'
+ self.config.cache_path = os.path.join(self.config.cache_path, 'cv_' + str(self.config.split))
+ banner = 'EQUIDOCK'
+ self.config.checkpoint_dir = './checkpts/' + banner
+ self.config.log_files_dir = './stdouterr/'
+
+ ### Create log files only when in train mode
+ create_dir(self.config.log_files_dir)
+ create_dir(self.config.checkpoint_dir)
+
+ log_file_name = os.path.join(self.config.log_files_dir, banner + ".txt")
+            with os.fdopen(os.open(log_file_name, FLAGS, 0o777), 'a+') as fout:
+ fout.write('[' + str(dt.now()) + '] START\n')
+ self.config.checkpoint_filename = os.path.join(
+ self.config.checkpoint_dir,
+ self.config.data + '_model_best.pth'
+ )
+
+ self.loss_fn_coors = nn.MSELoss(reduction='mean')
+ self.optimizer = optim.Adam(
+ self.network.trainable_params(),
+ lr=self.config.lr,
+ weight_decay=self.config.w_decay
+ )
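+        # LambdaLR with a cubic warmup: the learning rate is scaled by
+        # min(1, ((epoch + 1) / warmup) ** 3).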
+ self.scheduler = optim.lr_scheduler.LambdaLR(
+ self.optimizer,
+ lr_lambda=[lambda epoch: min(1., ((epoch + 1) / self.config.warmup) ** 3)]
+ )
+
+ self.best_epoch = 0
+ self.best_val_rmsd_median = float('inf')
+ self.corr_val_rmsd_mean = float('inf')
+
+ dataset_util = EquiDockDataSet(self.config)
+ self.train_data_batched, self.train_loader, self.val_data_batched, self.val_loader, \
+ self.test_data_batched, self.test_loader = dataset_util.set_training_data_src("Train Mode")
+
+ super().__init__(checkpoint_url=self.checkpoint_url, checkpoint_path=self.checkpoint_path,
+ network=self.network, mixed_precision=self.mixed_precision)
+
+
+ def from_pretrained(self, ckpt_path=None):
+ "from_pretrained"
+ self.get_checkpoint_path(ckpt_path)
+ if not ckpt_path:
+ param_dict = ms.load_checkpoint(self.checkpoint_path)
+ else:
+ param_dict = ms.load_checkpoint(ckpt_path)
+ param_not_load, _ = ms.load_param_into_net(self.network, param_dict)
+ print(f'param not load: {param_not_load}')
+
+
+ def predict(self, data, **kwargs):
+ time_list = []
+ input_dir = self.config.input_dir
+ ground_truth_dir = self.config.ground_truth_dir
+ output_dir = self.config.output_dir
+ file_type = '.pdb'
+ l_b_str = '_l_b'
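+        # The '_l_b' / '_r_b' suffixes appear to follow the DB5-style naming used by EquiDock
+        # (bound ligand and bound receptor structures, respectively).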
+
+ for file in data:
+ if not file.endswith(l_b_str + file_type):
+ continue
+ ll = len(l_b_str + file_type)
+ ligand_filename = os.path.join(input_dir, file[:-ll] + l_b_str + file_type)
+ receptor_filename = os.path.join(ground_truth_dir, file[:-ll] + '_r_b' + '_COMPLEX' + file_type)
+ out_filename = file[:-ll] + l_b_str + '_' + "EQUIDOCK" + file_type
+
+ self.log(' inference on file = ', ligand_filename)
+
+ start = dt.now()
+
+ ppdb_ligand = PandasPdb().read_pdb(ligand_filename)
+
+ ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, _\
+ = prepare_graphs(self.config, ppdb_ligand, ligand_filename, receptor_filename)
+
+ if self.config.input_edge_feats_dim < 0:
+ self.config.input_edge_feats_dim = ligand_graph.edata['he'].shape[1]
+
+ ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, \
+ input_tensor_tuple = graph_to_tensor(ligand_graph, receptor_graph)
+
+ _, _, _, all_rotation_list, all_translation_list = self.network(
+ ligand_graph_node_tensor,
+ receptor_graph_node_tensor,
+ unbatch_list,
+ input_tensor_tuple,
+ )
+
+ rotation = all_rotation_list[0].asnumpy()
+ translation = all_translation_list[0].asnumpy()
+
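+            # Apply the predicted rigid-body transform (rotation and translation) to all ligand atoms.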
+ unbound_ligand_new_pos = (rotation @ unbound_ligand_all_atoms_pre_pos.T).T + translation
+
+ euler_angles_finetune = ops.zeros([3])
+ translation_finetune = ops.zeros([3])
+ ligand_th = (get_rot_mat(euler_angles_finetune) @ Tensor(unbound_ligand_new_pos).T).T + translation_finetune
+
+ ppdb_ligand.df['ATOM'][['x_coord', 'y_coord', 'z_coord']] = ligand_th.asnumpy() # unbound_ligand_new_pos
+ unbound_ligand_save_filename = os.path.join(output_dir, out_filename)
+ ppdb_ligand.to_pdb(path=unbound_ligand_save_filename, records=['ATOM'], gz=False)
+
+ end = dt.now()
+ time_list.append((end - start).total_seconds())
+
+ time_array = np.array(time_list)
+ self.log("Mean runtime:", np.mean(time_array), "std runtime:", np.std(time_array))
+ self.log('Time list = ', time_list)
+
+
+ def run_a_train_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args):
+ train_complex_rmsd_mean, train_complex_rmsd_median, \
+ train_ligand_rmsd_mean, train_ligand_rmsd_median, \
+ train_receptor_rmsd_mean, train_receptor_rmsd_median, \
+ train_avg_loss, train_avg_loss_ligand_coors, train_avg_loss_receptor_coors, \
+ train_avg_loss_ot, train_avg_loss_intersection = \
+ self.run_a_generic_epoch('train', run_epoch_tuple, data_batched, data_loader)
+
+ pretty_print_stats('TRAIN', epoch, args.num_epochs,
+ train_complex_rmsd_mean, train_complex_rmsd_median,
+ train_ligand_rmsd_mean, train_ligand_rmsd_median,
+ train_receptor_rmsd_mean, train_receptor_rmsd_median,
+ train_avg_loss, train_avg_loss_ligand_coors, train_avg_loss_receptor_coors,
+ train_avg_loss_ot, train_avg_loss_intersection,
+ self.log)
+
+
+ def run_an_eval_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args):
+ """
+        Run one evaluation epoch and return the complex RMSD mean/median and the average loss.
+ """
+ val_complex_rmsd_mean, val_complex_rmsd_median, \
+ val_ligand_rmsd_mean, val_ligand_rmsd_median, \
+ val_receptor_rmsd_mean, val_receptor_rmsd_median, \
+ val_avg_loss, val_avg_loss_ligand_coors, \
+ val_avg_loss_receptor_coors, \
+ val_avg_loss_ot, val_avg_loss_intersection = \
+ self.run_a_generic_epoch('eval', run_epoch_tuple, data_batched, data_loader)
+
+ pretty_print_stats('VALIDATION', epoch, args.num_epochs,
+ val_complex_rmsd_mean, val_complex_rmsd_median,
+ val_ligand_rmsd_mean, val_ligand_rmsd_median,
+ val_receptor_rmsd_mean, val_receptor_rmsd_median,
+ val_avg_loss, val_avg_loss_ligand_coors, val_avg_loss_receptor_coors,
+ val_avg_loss_ot, val_avg_loss_intersection,
+ self.log)
+
+ return val_complex_rmsd_mean, val_complex_rmsd_median, val_avg_loss
+
+
+ def run_a_generic_epoch(self, ep_type, run_epoch_tuple, data_batched, data_loader):
+ """
+        Run a single training or evaluation epoch and return RMSD and loss statistics.
+ """
+ args, self.network, loss_fn_coors, optimizer = run_epoch_tuple
+
+ meter = MeterUnboundBound()
+
+ avg_loss, total_loss, num_batches = 0., 0., 0
+
+ total_loss_ligand_coors = 0.
+ avg_loss_ligand_coors = 0.
+
+ total_loss_receptor_coors = 0.
+ avg_loss_receptor_coors = 0.
+
+ total_loss_ot = 0.
+ avg_loss_ot = 0.
+
+ total_loss_intersection = 0.
+ avg_loss_intersection = 0.
+
+ def forward(batch_id):
+
+ ######## RUN MODEL ##############
+ ligand_graph_node_tensor = Tensor(data_loader[batch_id][6], ms.float32)
+ receptor_graph_node_tensor = Tensor(data_loader[batch_id][7], ms.float32)
+ unbatch_list_tensor = Tensor(data_loader[batch_id][8], ms.int32)
+ input_tensor_tuple = (
+ Tensor(data_loader[batch_id][4], ms.int32), # ligand_graph_num_nodes
+ Tensor(data_loader[batch_id][5], ms.int32), # receptor_graph_num_nodes
+ Tensor(data_loader[batch_id][2], ms.float32), # ligand_graph_edge_tensor
+ Tensor(data_loader[batch_id][3], ms.float32), # receptor_graph_edge_tensor
+ Tensor(data_loader[batch_id][0], ms.int32), # ll_connection_tensor
+ Tensor(data_loader[batch_id][1], ms.int32), # rr_connection_tensor
+ )
+
+ model_ligand_coors_deform_list, \
+ model_keypts_ligand_list, model_keypts_receptor_list, \
+ _, _, = self.network(
+ ligand_graph_node_tensor,
+ receptor_graph_node_tensor,
+ unbatch_list_tensor,
+ input_tensor_tuple,
+ )
+ ################################
+
+            batch_ligand_coors_loss = Tensor(0.0)
+            batch_receptor_coors_loss = Tensor(0.0)
+            batch_ot_loss = Tensor(0.0)
+            batch_intersection_loss = Tensor(0.0)
+
+ for i, _ in enumerate(model_ligand_coors_deform_list):
+ ## Compute average MSE loss (which is 3 times smaller than average squared RMSD)
+
+ batch_ligand_coors_loss = batch_ligand_coors_loss + loss_fn_coors(model_ligand_coors_deform_list[i],
+ Tensor(data_batched[8][batch_id][i],
+ ms.float32))
+ # Compute the OT loss for the binding pocket:
+ ligand_pocket_coors = Tensor(data_batched[10][batch_id][i], ms.float32) # (N, 3), N = num pocket nodes
+ receptor_pocket_coors = Tensor(data_batched[11][batch_id][i], ms.float32) # (N, 3), N = num pocket nodes
+
+ ligand_keypts_coors = model_keypts_ligand_list[i] # (K, 3), K = num keypoints
+ receptor_keypts_coors = model_keypts_receptor_list[i] # (K, 3), K = num keypoints
+
+ ## (N, K) cost matrix
+ cost_mat_ligand = compute_sq_dist_mat(ligand_pocket_coors, ligand_keypts_coors)
+ cost_mat_receptor = compute_sq_dist_mat(receptor_pocket_coors, receptor_keypts_coors)
+
+ ot_dist, _ = compute_ot_emd(cost_mat_ligand + cost_mat_receptor)
+ batch_ot_loss = batch_ot_loss + ot_dist
+
+ batch_intersection_loss = batch_intersection_loss + compute_body_intersection_loss(
+ model_ligand_coors_deform_list[i], data_batched[9][batch_id][i],
+ args.intersection_sigma, args.intersection_surface_ct)
+
+ ### Add new stats to the meter
+ if ep_type != 'train' or random.random() < 0.1:
+ meter.update_rmsd(model_ligand_coors_deform_list[i],
+ data_batched[9][batch_id][i],
+ data_batched[8][batch_id][i],
+ data_batched[9][batch_id][i])
+
+ batch_ligand_coors_loss = batch_ligand_coors_loss / float(len(model_ligand_coors_deform_list))
+ batch_receptor_coors_loss = batch_receptor_coors_loss / float(len(model_ligand_coors_deform_list))
+ batch_ot_loss = batch_ot_loss / float(len(model_ligand_coors_deform_list))
+ batch_intersection_loss = batch_intersection_loss / float(len(model_ligand_coors_deform_list))
+
+ loss_coors = batch_ligand_coors_loss + batch_receptor_coors_loss
+
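+            # Total loss: coordinate MSE plus weighted pocket OT and body-intersection terms.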
+ loss = loss_coors + args.pocket_ot_loss_weight * batch_ot_loss \
+ + args.intersection_loss_weight * batch_intersection_loss
+
+ return loss, batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss, batch_intersection_loss
+
+ backward = ms.value_and_grad(forward, None, optimizer.parameters)
+
+ # Iterate over all batches of the epoch
+ for step in range(len(data_batched[0])):
+ num_batches += 1
+
+ (loss, batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss,
+ batch_intersection_loss), grads = backward(step)
+
+ if ep_type == 'train':
+ grads = ms.ops.clip_by_norm(grads, max_norm=args.clip, norm_type=2)
+ optimizer(grads)
+
+ total_loss += loss.asnumpy()
+ total_loss_ligand_coors += batch_ligand_coors_loss
+ total_loss_receptor_coors += batch_receptor_coors_loss
+ total_loss_ot += batch_ot_loss
+ total_loss_intersection += batch_intersection_loss
+
+ if num_batches != 0:
+ avg_loss = total_loss / num_batches
+ avg_loss_ligand_coors = total_loss_ligand_coors / num_batches
+ avg_loss_receptor_coors = total_loss_receptor_coors / num_batches
+ avg_loss_ot = total_loss_ot / num_batches
+ avg_loss_intersection = total_loss_intersection / num_batches
+
+ ligand_rmsd_mean, receptor_rmsd_mean, complex_rmsd_mean = meter.summarize(reduction_rmsd='mean')
+ ligand_rmsd_median, receptor_rmsd_median, complex_rmsd_median = meter.summarize(reduction_rmsd='median')
+
+ return complex_rmsd_mean, complex_rmsd_median, \
+ ligand_rmsd_mean, ligand_rmsd_median, \
+ receptor_rmsd_mean, receptor_rmsd_median, \
+ avg_loss.item(), avg_loss_ligand_coors.item(), \
+ avg_loss_receptor_coors.item(), \
+ avg_loss_ot.item(), avg_loss_intersection.item()
+
+
+ def train_step(self, data):
+
+ self.log('+' * 100)
+ epoch = data
+ epoch_start = dt.now()
+
+ run_epoch_tuple = (self.config, self.network, self.loss_fn_coors, self.optimizer)
+ self.run_a_train_epoch(run_epoch_tuple, self.train_data_batched, self.train_loader, epoch, self.config)
+ val_complex_rmsd_mean, val_complex_rmsd_median, _ = \
+ self.run_an_eval_epoch(run_epoch_tuple, self.val_data_batched, self.val_loader, epoch, self.config)
+
+ self.scheduler.step()
+
+ if val_complex_rmsd_median < self.best_val_rmsd_median * 0.98: # We do this to avoid "pure luck"
+ # Best validation so far
+ self.best_val_rmsd_median = val_complex_rmsd_median
+ self.corr_val_rmsd_mean = val_complex_rmsd_mean
+ self.best_epoch = epoch
+ save_checkpoint(self.optimizer, self.config.checkpoint_dir + "/best_model.ckpt")
+
+ log('[BEST SO FAR] ', self.config.data,
+ '|| At best epoch {} we have: best_val_rmsd_median {:.4f}, '
+ '|| Current val rmsd_median {:.4f} '
+ '|| Train time: {}\n'.
+ format(self.best_epoch, self.best_val_rmsd_median,
+ val_complex_rmsd_median,
+ dt.now() - epoch_start))
+
+ return '\n'
+
+
+ def log(self, *pargs):
+ print('[' + str(dt.now()) + '] ', *pargs)
+
+
+ def forward(self, data):
+ return None
+
+
+ def backward(self, data):
+ return None
+
+
+ def _jit_forward(self, data):
+ return None
+
+
+ def _pynative_forward(self, data):
+ return None
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..5695757f89cfad9c13948ebf2ca477844545542a
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py
@@ -0,0 +1,24 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""equidock_configuration"""
+
+equidock_configuration = {
+ "predict_dips":
+ "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/predict_dips.yaml",
+ "predict_db5":
+ "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/predict_db5.yaml",
+ "train_db5":
+ "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml",
+}
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..9726490e0edaadae491b35b28c2266061b56c215
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py
@@ -0,0 +1,286 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""equidock_data"""
+
+import os
+import numpy as np
+from biopandas.pdb import PandasPdb
+from joblib import Parallel, delayed, cpu_count
+from scipy.spatial.transform import Rotation
+
+from .nn_arch import preprocess_unbound_bound, protein_to_graph_unbound_bound
+from .train_utils import log
+
+
+def one_hot_encoding(x, allowable_set, encode_unknown=False):
+ if encode_unknown and (allowable_set[-1] is not None):
+ allowable_set.append(None)
+
+ if encode_unknown and (x not in allowable_set):
+ x = None
+
+ return list(map(lambda s: x == s, allowable_set))
+
+
+def uniformrotation_translation(translation_interval):
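+    """Sample a uniformly random rotation and a translation with random direction whose
+    length is drawn uniformly from [0, translation_interval]."""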
+ rotation = Rotation.random(num=1)
+ rotation_matrix = rotation.as_matrix().squeeze()
+
+ t = np.random.randn(1, 3)
+ t = t / np.sqrt(np.sum(t * t))
+ length = np.random.uniform(low=0, high=translation_interval)
+ t = t * length
+ return rotation_matrix.astype(np.float32), t.astype(np.float32)
+
+
+def get_residues_db5(pdb_filename):
+ df = PandasPdb().read_pdb(pdb_filename).df['ATOM']
+ df.rename(columns={'chain_id': 'chain', 'residue_number': 'residue', 'residue_name': 'resname',
+ 'x_coord': 'x', 'y_coord': 'y', 'z_coord': 'z', 'element_symbol': 'element'}, inplace=True)
+ residues = list(df.groupby(['chain', 'residue', 'resname'])) # Not the same as sequence order !
+ return residues
+
+
+def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, **kwargs):
+ if n_jobs is None:
+ n_jobs = cpu_count() - 1
+ results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)(
+ delayed(pickleable_fn)(d, **kwargs) for i, d in enumerate(data)
+ )
+
+ return results
+
+
+class UnboundBoundData():
+ """
+ UnboundBoundData Class
+ """
+ def __init__(self, args, reload_mode='train', raw_data_path=None, split_files_path=None):
+
+ os.makedirs(args.processed_dataset_path + reload_mode, exist_ok=True)
+
+ if args.data == 'db5':
+ onlyfiles = [f for f in os.listdir(raw_data_path) if os.path.isfile(os.path.join(raw_data_path, f))]
+ code_set = {file.split('_')[0] for file in onlyfiles}
+ split_code_set = set()
+ with open(os.path.join(split_files_path, reload_mode + '.txt'), 'r') as f:
+ for line in f.readlines():
+ split_code_set.add(line.rstrip())
+
+ code_set = code_set & split_code_set
+ code_list = list(code_set)
+
+ bound_ligand_residues_list = [get_residues_db5(os.path.join(raw_data_path, code + '_l_b.pdb'))
+ for code in code_list]
+ bound_receptor_residues_list = [get_residues_db5(os.path.join(raw_data_path, code + '_r_b.pdb'))
+ for code in code_list]
+
+ input_residues_lists = [(bound_ligand_residues_list[i], bound_receptor_residues_list[i])
+ for i in range(len(bound_ligand_residues_list))]
+
+ log('Start preprocess_unbound_bound')
+ preprocess_result = pmap_multi(preprocess_unbound_bound,
+ input_residues_lists,
+ n_jobs=args.n_jobs,
+ graph_nodes=args.graph_nodes,
+ pos_cutoff=args.pocket_cutoff,
+ inference=False)
+ log('Done preprocess_unbound_bound\n\n')
+
+ unbound_predic_ligand_list, unbound_predic_receptor_list = [], []
+ bound_ligand_repres_nodes_loc_array_list, bound_receptor_repres_nodes_loc_array_list = [], []
+ pocket_coors_list = []
+ for result in preprocess_result:
+ unbound_predic_ligand, unbound_predic_receptor, \
+ bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, pocket_coors = result
+ if pocket_coors is not None:
+ unbound_predic_ligand_list.append(unbound_predic_ligand)
+ unbound_predic_receptor_list.append(unbound_predic_receptor)
+ bound_ligand_repres_nodes_loc_array_list.append(bound_ligand_repres_nodes_loc_array)
+ bound_receptor_repres_nodes_loc_array_list.append(bound_receptor_repres_nodes_loc_array)
+ pocket_coors_list.append(pocket_coors)
+
+ protein_to_graph_input = [(unbound_predic_ligand_list[i],
+ unbound_predic_receptor_list[i],
+ bound_ligand_repres_nodes_loc_array_list[i],
+ bound_receptor_repres_nodes_loc_array_list[i]) for i in
+ range(len(unbound_predic_ligand_list))]
+ log('Start protein_to_graph_unbound_bound')
+
+ both_proteins_to_graph_pair_list = pmap_multi(protein_to_graph_unbound_bound,
+ protein_to_graph_input,
+ n_jobs=args.n_jobs,
+ cutoff=args.graph_cutoff,
+ max_neighbor=args.graph_max_neighbor,
+ one_hot=False,
+ residue_loc_is_alphac=args.graph_residue_loc_is_alphaC
+ )
+
+ log('Done protein_to_graph_unbound_bound')
+
+ self.save_processed_data(
+ args,
+ reload_mode,
+ both_proteins_to_graph_pair_list,
+ bound_ligand_repres_nodes_loc_array_list,
+ bound_receptor_repres_nodes_loc_array_list,
+ pocket_coors_list,
+ )
+
+ def single_graph_data(self, ligand_graph, receptor_graph, ligand_new_loc):
+ """
+ Process single graph data in the list
+ """
+ for k in ligand_graph.ndata.keys():
+ ligand_graph.ndata[k] = ligand_graph.ndata[k].asnumpy()
+ for k in receptor_graph.ndata.keys():
+ receptor_graph.ndata[k] = receptor_graph.ndata[k].asnumpy()
+ for k in ligand_graph.edata.keys():
+ ligand_graph.edata[k] = ligand_graph.edata[k].asnumpy()
+ for k in receptor_graph.edata.keys():
+ receptor_graph.edata[k] = receptor_graph.edata[k].asnumpy()
+ ligand_graph.ndata['new_x'] = ligand_new_loc.astype(np.float32)
+
+ ##### Create a batch of a single heterograph
+ ligand_graph_node_tensor = np.concatenate(
+ (ligand_graph.ndata["res_feat"],
+ ligand_graph.ndata["x"],
+ ligand_graph.ndata["new_x"],
+ ligand_graph.ndata["mu_r_norm"]),
+ axis=1
+ )
+
+ receptor_graph_node_tensor = np.concatenate(
+ (receptor_graph.ndata["res_feat"],
+ receptor_graph.ndata["x"],
+ receptor_graph.ndata["mu_r_norm"]),
+ axis=1
+ )
+
+ ligand_graph_num_nodes = [ligand_graph.ndata["x"].shape[0]]
+ receptor_graph_num_nodes = [receptor_graph.ndata["x"].shape[0]]
+ ligand_graph_edge_tensor = ligand_graph.edata['he']
+ receptor_graph_edge_tensor = receptor_graph.edata['he']
+
+ ll_connection_tensor = np.stack(
+ (ligand_graph.src_list, ligand_graph.dst_list))
+
+ rr_connection_tensor = np.stack(
+ (receptor_graph.src_list, receptor_graph.dst_list))
+ ##### Create a batch of a single heterograph
+
+ return ll_connection_tensor, rr_connection_tensor, ligand_graph_edge_tensor, \
+ receptor_graph_edge_tensor, ligand_graph_num_nodes, receptor_graph_num_nodes, \
+ ligand_graph_node_tensor, receptor_graph_node_tensor
+
+ def save_npz_files(self, args, saved_lists, reload_mode):
+ """
+ save data into npz file format
+ """
+ ll_connection_tensor_list, rr_connection_tensor_list, ligand_graph_edge_tensor_list, \
+ receptor_graph_edge_tensor_list, ligand_graph_num_nodes_list, receptor_graph_num_nodes_list, \
+ ligand_graph_node_tensor_list, receptor_graph_node_tensor_list, bound_ligand_repres_nodes_loc_array_list_new, \
+ bound_receptor_repres_nodes_loc_array_list_new, pocket_coors_ligand_list_new, \
+ pocket_coors_receptor_list_new = saved_lists
+
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ll_connection_tensor_list.npz"),
+ *ll_connection_tensor_list)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "rr_connection_tensor_list.npz"),
+ *rr_connection_tensor_list)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_edge_tensor_list.npz"),
+ *ligand_graph_edge_tensor_list)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "receptor_graph_edge_tensor_list.npz"),
+ *receptor_graph_edge_tensor_list)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_num_nodes_list.npz"),
+ *ligand_graph_num_nodes_list)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "receptor_graph_num_nodes_list.npz"),
+ *receptor_graph_num_nodes_list)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode,
+ "bound_ligand_repres_nodes_loc_array_list_new.npz"),
+ *bound_ligand_repres_nodes_loc_array_list_new)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode,
+ "bound_receptor_repres_nodes_loc_array_list_new.npz"),
+ *bound_receptor_repres_nodes_loc_array_list_new)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "pocket_coors_ligand_list_new.npz"),
+ *pocket_coors_ligand_list_new)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "pocket_coors_receptor_list_new.npz"),
+ *pocket_coors_receptor_list_new)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_node_tensor_list.npz"),
+ *ligand_graph_node_tensor_list)
+ np.savez(os.path.join(args.processed_dataset_path, reload_mode, "receptor_graph_node_tensor_list.npz"),
+ *receptor_graph_node_tensor_list)
+
+ def save_processed_data(
+ self,
+ args,
+ reload_mode,
+ both_proteins_to_graph_pair_list,
+ bound_ligand_repres_nodes_loc_array_list,
+ bound_receptor_repres_nodes_loc_array_list,
+ pocket_coors_list,
+ ):
+ """
+ save_processed_data
+ """
+ ligand_graph_list, receptor_graph_list = [], []
+ for result in both_proteins_to_graph_pair_list:
+ ligand_graph, receptor_graph = result
+ ligand_graph_list.append(ligand_graph)
+ receptor_graph_list.append(receptor_graph)
+
+ ligand_graph_node_tensor_list, receptor_graph_node_tensor_list, ligand_graph_num_nodes_list = [], [], []
+ receptor_graph_num_nodes_list, ligand_graph_edge_tensor_list, receptor_graph_edge_tensor_list = [], [], []
+ ll_connection_tensor_list, rr_connection_tensor_list, bound_ligand_repres_nodes_loc_array_list_new = [], [], []
+ pocket_coors_ligand_list_new, pocket_coors_receptor_list_new = [], []
+ bound_receptor_repres_nodes_loc_array_list_new = []
+
+ for i, _ in enumerate(ligand_graph_list):
+ ligand_graph, receptor_graph = ligand_graph_list[i], receptor_graph_list[i]
+ bound_ligand_repres_nodes_loc_array = bound_ligand_repres_nodes_loc_array_list[i]
+ bound_receptor_repres_nodes_loc_array = bound_receptor_repres_nodes_loc_array_list[i]
+ pocket_coors_ligand, pocket_coors_receptor = pocket_coors_list[i], pocket_coors_list[i]
+
+ # Randomly rotate and translate the ligand.
+ rot_t, rot_b = uniformrotation_translation(translation_interval=args.translation_interval)
+ ligand_original_loc = ligand_graph.ndata['x']
+ mean_to_remove = ligand_original_loc.mean(axis=0, keep_dims=True)
+ pocket_coors_ligand = (rot_t @ (pocket_coors_ligand - mean_to_remove).T).T + rot_b
+ ligand_new_loc = (rot_t @ (ligand_original_loc - mean_to_remove).T).T + rot_b
+
+ ll_connection_tensor, rr_connection_tensor, ligand_graph_edge_tensor, receptor_graph_edge_tensor, \
+ ligand_graph_num_nodes, receptor_graph_num_nodes, ligand_graph_node_tensor, \
+ receptor_graph_node_tensor = self.single_graph_data(ligand_graph, receptor_graph, ligand_new_loc)
+
+ ll_connection_tensor_list.append(ll_connection_tensor)
+ rr_connection_tensor_list.append(rr_connection_tensor)
+ ligand_graph_edge_tensor_list.append(ligand_graph_edge_tensor)
+ receptor_graph_edge_tensor_list.append(receptor_graph_edge_tensor)
+ ligand_graph_num_nodes_list.append(np.array(ligand_graph_num_nodes).astype(np.int32))
+ receptor_graph_num_nodes_list.append(np.array(receptor_graph_num_nodes).astype(np.int32))
+ ligand_graph_node_tensor_list.append(ligand_graph_node_tensor)
+ receptor_graph_node_tensor_list.append(receptor_graph_node_tensor)
+ bound_ligand_repres_nodes_loc_array_list_new.append(bound_ligand_repres_nodes_loc_array.astype(np.float32))
+ bound_receptor_repres_nodes_loc_array_list_new.append(
+ bound_receptor_repres_nodes_loc_array.astype(np.float32))
+ pocket_coors_ligand_list_new.append(pocket_coors_ligand.astype(np.float32))
+ pocket_coors_receptor_list_new.append(pocket_coors_receptor.astype(np.float32))
+
+ saved_lists = [ll_connection_tensor_list, rr_connection_tensor_list, ligand_graph_edge_tensor_list,
+ receptor_graph_edge_tensor_list, ligand_graph_num_nodes_list, receptor_graph_num_nodes_list,
+ ligand_graph_node_tensor_list, receptor_graph_node_tensor_list,
+ bound_ligand_repres_nodes_loc_array_list_new, bound_receptor_repres_nodes_loc_array_list_new,
+ pocket_coors_ligand_list_new, pocket_coors_receptor_list_new]
+
+ self.save_npz_files(args, saved_lists, reload_mode)
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa27ccd93f7efaaa665bdac64b12b513a4170984
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py
@@ -0,0 +1,229 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+""""equidock_dataset"""
+
+#pylint: disable=W0221
+
+import os
+from datetime import datetime as dt
+
+import numpy as np
+
+from .equidock_data import UnboundBoundData
+from ...dataset import PSP
+
+class EquiDockDataSet(PSP):
+ "EquiDockDataSet"
+ def __init__(self, config):
+ self.config = config
+ self.train_data_batched = None
+ self.train_loader = None
+ self.val_data_batched = None
+ self.val_loader = None
+ self.test_data_batched = None
+ self.test_loader = None
+ super().__init__()
+
+ def process(self, data, **kwargs):
+ self.log(data, **kwargs)
+ input_dir = self.config.input_dir
+        files = [f for f in os.listdir(input_dir)
+                 if os.path.isfile(os.path.join(input_dir, f)) and f.endswith('.pdb')]
+ return files
+
+ def set_training_data_src(self, data_source, **kwargs):
+ """
+ set_training_data_src
+ """
+ self.log(data_source, **kwargs)
+
+ datasets = ['train', 'val', 'test']
+
+ if not os.path.exists(self.config.processed_dataset_path):
+ for _, dataset in enumerate(datasets):
+ UnboundBoundData(
+ self.config,
+ reload_mode=dataset,
+ raw_data_path=self.config.raw_data_path,
+ split_files_path=self.config.split_files_path,
+ )
+
+ self.train_data_batched, self.train_loader = self.create_dataloader(
+ self.config.train_dir,
+ shuffle=True,
+ )
+ self.val_data_batched, self.val_loader = self.create_dataloader(
+ self.config.val_dir,
+ shuffle=False,
+ )
+ self.test_data_batched, self.test_loader = self.create_dataloader(
+ self.config.test_dir,
+ shuffle=False,
+ )
+
+ return (
+ self.train_data_batched,
+ self.train_loader,
+ self.val_data_batched,
+ self.val_loader,
+ self.test_data_batched,
+ self.test_loader,
+ )
+
+ def create_iterator(self, num_epochs, **kwargs):
+ self.log(**kwargs)
+ if num_epochs == self.config.num_epochs:
+            return list(range(num_epochs))
+        return list(range(self.config.num_epochs))
+
+ def data_parse(self, idx):
+ return self.train_data_batched[idx]
+
+ def log(self, *pargs):
+ print('[' + str(dt.now()) + '] ', *pargs)
+
+ def load_data(self, files_dir):
+ """
+ load_data
+ """
+ npz_files_list = [
+ np.load(files_dir + "/ll_connection_tensor_list.npz"),
+ np.load(files_dir + "/rr_connection_tensor_list.npz"),
+ np.load(files_dir + "/ligand_graph_edge_tensor_list.npz"),
+ np.load(files_dir + "/receptor_graph_edge_tensor_list.npz"),
+ np.load(files_dir + "/ligand_graph_num_nodes_list.npz"),
+ np.load(files_dir + "/receptor_graph_num_nodes_list.npz"),
+ np.load(files_dir + "/ligand_graph_node_tensor_list.npz"),
+ np.load(files_dir + "/receptor_graph_node_tensor_list.npz"),
+ np.load(files_dir + "/bound_ligand_repres_nodes_loc_array_list_new.npz"),
+ np.load(files_dir + "/bound_receptor_repres_nodes_loc_array_list_new.npz"),
+ np.load(files_dir + "/pocket_coors_ligand_list_new.npz"),
+ np.load(files_dir + "/pocket_coors_receptor_list_new.npz"),
+ ]
+
+ extracted_files = []
+ for _, npz_file in enumerate(npz_files_list):
+            extracted_files.append([npz_file[key] for key in npz_file.files])
+
+ unbatch_list = []
+ for i, _ in enumerate(extracted_files[0]):
+ temp_length = []
+ temp_length.append(len(extracted_files[0][i][0]))
+ temp_length.append(len(extracted_files[1][i][0]))
+ temp_length.append(len(extracted_files[6][i]))
+ temp_length.append(len(extracted_files[7][i]))
+ temp_length.append(len(extracted_files[11][i]))
+ unbatch_list.append(np.array(temp_length))
+ extracted_files.append(unbatch_list)
+
+ index_shuffle = np.arange(len(npz_files_list[0]))
+
+ return extracted_files, index_shuffle
+
+
+ def shuffle_list(self, source_list, index_shuffle):
+ shuffled_list = []
+ for _, idx in enumerate(index_shuffle):
+ shuffled_list.append(source_list[idx])
+
+ return shuffled_list
+
+
+    def batch(self, bs, source_list):
+        output_list = []
+        num = len(source_list) // bs
+        for i in range(num):
+            output_list.append(source_list[i * bs: (i + 1) * bs])
+        if num * bs < len(source_list):
+            output_list.append(source_list[num * bs:])
+
+ return output_list
+
+
+ def shuffle_batch_dataset(self, input_dataset, index_shuffle, batch_size, shuffle):
+ """
+ shuffle_batch_dataset
+ """
+ if shuffle:
+ shuffled_dataset = []
+ np.random.shuffle(index_shuffle)
+ for i, _ in enumerate(input_dataset):
+ shuffled_dataset.append(self.shuffle_list(input_dataset[i], index_shuffle))
+ else:
+ shuffled_dataset = input_dataset[:]
+
+ dataset_batched = []
+ for _, data in enumerate(shuffled_dataset):
+ dataset_batched.append(self.batch(batch_size, data))
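+        # The last list holds per-graph size vectors; convert them in place to cumulative
+        # offsets (prefix sums) used to split the batched graph tensors back into single graphs.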
+ unbatch_list = dataset_batched[-1]
+ for j, _ in enumerate(unbatch_list):
+ unbatch_list[j].insert(0, np.array([0, 0, 0, 0, 0]))
+ for i, _ in enumerate(unbatch_list[j]):
+ if i >= 1:
+ unbatch_list[j][i][0] += unbatch_list[j][i - 1][0]
+ unbatch_list[j][i][1] += unbatch_list[j][i - 1][1]
+ unbatch_list[j][i][2] += unbatch_list[j][i - 1][2]
+ unbatch_list[j][i][3] += unbatch_list[j][i - 1][3]
+ unbatch_list[j][i][4] += unbatch_list[j][i - 1][4]
+ dataset_batched[-1] = unbatch_list
+
+ return dataset_batched
+
+
+ def cat_properties(self, input_dataset_batched):
+ """
+ cat_properties
+ """
+ dataset_batch_cat = []
+ for i in range(len(input_dataset_batched[0])):
+ dataset_batch_cat.append(
+ [
+ np.concatenate(input_dataset_batched[0][i], axis=1),
+ np.concatenate(input_dataset_batched[1][i], axis=1),
+ np.concatenate(input_dataset_batched[2][i], axis=0),
+ np.concatenate(input_dataset_batched[3][i], axis=0),
+ np.concatenate(input_dataset_batched[4][i], axis=0),
+ np.concatenate(input_dataset_batched[5][i], axis=0),
+ np.concatenate(input_dataset_batched[6][i], axis=0),
+ np.concatenate(input_dataset_batched[7][i], axis=0),
+ input_dataset_batched[-1][i],
+ ]
+ )
+
+ return dataset_batch_cat
+
+
+ def create_dataloader(self, dataset_dir, shuffle):
+ """
+ create_dataloader
+ """
+ dataset, index_shuffle = self.load_data(dataset_dir)
+ dataset_batched = self.shuffle_batch_dataset(
+ input_dataset=dataset,
+ index_shuffle=index_shuffle,
+ batch_size=self.config.bs,
+ shuffle=shuffle,
+ )
+ dataloader = self.cat_properties(dataset_batched)
+
+ return dataset_batched, dataloader
+
+
+ def __getitem__(self, idx):
+ return self.data_parse(idx)
+
+
+ def __len__(self):
+ return len(self.train_data_batched)
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5d5f0a2cb5335109fe49250157dae150b4b5fa5
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py
@@ -0,0 +1,1326 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""nn_arch"""
+
+import os
+import math
+from datetime import datetime as dt
+
+import numpy as np
+from numpy import linalg as LA
+from biopandas.pdb import PandasPdb
+import scipy.spatial as spa
+from scipy.special import softmax
+from mindspore import nn, ops, Tensor
+
+
+ATOM_NAME = 'atom_name'
+FLAGS = os.O_RDWR | os.O_CREAT
+
+
+def log(*pargs):
+ banner = 'EQUIDOCK'
+ log_file_name = os.path.join('stdouterr/', banner + ".txt")
+    with os.fdopen(os.open(log_file_name, FLAGS, 0o777), 'a+') as w:
+ w.write('[' + str(dt.now()) + '] ')
+ w.write(" ".join(["{}".format(t) for t in pargs]))
+ w.write("\n")
+ print('[' + str(dt.now()) + '] ', *pargs)
+
+
+class MeterUnboundBound():
+ """
+    Accumulates ligand, receptor and Kabsch-aligned complex RMSD statistics over an epoch.
+ """
+ def __init__(self):
+ self.complex_rmsd_list = []
+ self.ligand_rmsd_list = []
+ self.receptor_rmsd_list = []
+
+ def update_rmsd(self, ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true):
+ """
+ update_rmsd
+ """
+ ligand_coors_pred = ligand_coors_pred.asnumpy()
+
+ ligand_rmsd = np.sqrt(np.mean(np.sum((ligand_coors_pred - ligand_coors_true) ** 2, axis=1)))
+ receptor_rmsd = np.sqrt(np.mean(np.sum((receptor_coors_pred - receptor_coors_true) ** 2, axis=1)))
+
+ complex_coors_pred = np.concatenate((ligand_coors_pred, receptor_coors_pred), axis=0)
+ complex_coors_true = np.concatenate((ligand_coors_true, receptor_coors_true), axis=0)
+
+ r, b = rigid_transform_kabsch_3d(complex_coors_pred.T, complex_coors_true.T)
+ complex_coors_pred_aligned = ((r @ complex_coors_pred.T) + b).T
+
+ complex_rmsd = np.sqrt(np.mean(np.sum((complex_coors_pred_aligned - complex_coors_true) ** 2, axis=1)))
+
+ self.complex_rmsd_list.append(complex_rmsd)
+ self.ligand_rmsd_list.append(ligand_rmsd)
+ self.receptor_rmsd_list.append(receptor_rmsd)
+
+ return complex_rmsd
+
+ def summarize(self, reduction_rmsd='median'):
+ """
+ summarize
+ """
+ if reduction_rmsd == 'mean':
+ complex_rmsd_array = np.array(self.complex_rmsd_list)
+ complex_rmsd_summarized = np.mean(complex_rmsd_array)
+
+ ligand_rmsd_array = np.array(self.ligand_rmsd_list)
+ ligand_rmsd_summarized = np.mean(ligand_rmsd_array)
+
+ receptor_rmsd_array = np.array(self.receptor_rmsd_list)
+ receptor_rmsd_summarized = np.mean(receptor_rmsd_array)
+ elif reduction_rmsd == 'median':
+ complex_rmsd_array = np.array(self.complex_rmsd_list)
+ complex_rmsd_summarized = np.median(complex_rmsd_array)
+
+ ligand_rmsd_array = np.array(self.ligand_rmsd_list)
+ ligand_rmsd_summarized = np.median(ligand_rmsd_array)
+
+ receptor_rmsd_array = np.array(self.receptor_rmsd_list)
+ receptor_rmsd_summarized = np.median(receptor_rmsd_array)
+ else:
+ raise ValueError("Meter_Unbound_Bound: reduction_rmsd mis specified!")
+
+ return ligand_rmsd_summarized, receptor_rmsd_summarized, complex_rmsd_summarized
+
+ def summarize_with_std(self, reduction_rmsd='median'):
+ complex_rmsd_array = np.array(self.complex_rmsd_list)
+ if reduction_rmsd == 'mean':
+ complex_rmsd_summarized = np.mean(complex_rmsd_array)
+ elif reduction_rmsd == 'median':
+ complex_rmsd_summarized = np.median(complex_rmsd_array)
+ else:
+ raise ValueError("Meter_Unbound_Bound: reduction_rmsd mis specified!")
+
+ return complex_rmsd_summarized, np.std(complex_rmsd_array)
+
+
+def get_nodes_coors_numpy(filename, all_atoms=False):
+ df = PandasPdb().read_pdb(filename).df['ATOM']
+ if not all_atoms:
+ return Tensor(
+ df[df['atom_name'] == 'CA'][['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32))
+ return Tensor(df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32))
+
+
+def get_residues(pdb_filename):
+ df = PandasPdb().read_pdb(pdb_filename).df['ATOM']
+ df.rename(columns={'chain_id': 'chain', 'residue_number': 'residue', 'residue_name': 'resname',
+ 'x_coord': 'x', 'y_coord': 'y', 'z_coord': 'z', 'element_symbol': 'element'}, inplace=True)
+ residues = list(df.groupby(['chain', 'residue', 'resname'])) # Not the same as sequence order !
+ return residues
+
+
+def distance_list_featurizer(dist_list):
+ """
+ distance_list_featurizer
+ """
+ length_scale_list = [1.5 ** x for x in range(15)]
+ center_list = [0. for _ in range(15)]
+
+ num_edge = len(dist_list)
+ dist_list = np.array(dist_list)
+
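+    # Expand each distance into 15 Gaussian radial basis functions
+    # exp(-(d - center)^2 / length_scale), with length scales 1.5^k (k = 0..14) and centers at 0.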
+ transformed_dist = [np.exp(- ((dist_list - center) ** 2) / float(length_scale))
+ for length_scale, center in zip(length_scale_list, center_list)]
+
+ transformed_dist = np.array(transformed_dist).T
+ transformed_dist = transformed_dist.reshape((num_edge, -1))
+
+ processed_features = dict()
+ processed_features['he'] = Tensor(transformed_dist.astype(np.float32))
+ return processed_features
+
+
+def rigid_transform_kabsch_3d(a, b):
+ """
+    Kabsch algorithm: compute the rotation r and translation t that best align the
+    3xN point set a onto b in the least-squares sense.
+ """
+ if a.shape[1] != b.shape[1]:
+ raise ValueError("a.shape[1] not equals b.shape[1]")
+ num_rows, num_cols = a.shape
+ if num_rows != 3:
+ raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}")
+ num_rows, num_cols = b.shape
+ if num_rows != 3:
+ raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}")
+
+ # find mean column wise: 3 x 1
+ centroid_a = np.mean(a, axis=1, keepdims=True)
+ centroid_b = np.mean(b, axis=1, keepdims=True)
+
+ # subtract mean
+ am = a - centroid_a
+ bm = b - centroid_b
+
+ h = am @ bm.T
+
+ # find rotation
+ u, _, vt = np.linalg.svd(h)
+
+ r = vt.T @ u.T
+
+ # special reflection case
+ if np.linalg.det(r) < 0:
+ ss = np.diag([1., 1., -1.])
+ r = (vt.T @ ss) @ u.T
+ if math.fabs(np.linalg.det(r) - 1) >= 1e-5:
+ raise ValueError("The value should be smaller than 1e-5")
+
+ t = -r @ centroid_a + centroid_b
+ return r, t
+
+
+def one_hot_encoding(x, allowable_set, encode_unknown=False):
+ if encode_unknown and (allowable_set[-1] is not None):
+ allowable_set.append(None)
+
+ if encode_unknown and (x not in allowable_set):
+ x = None
+
+ return list(map(lambda s: x == s, allowable_set))
+
+
+def residue_list_featurizer_dips_one_hot(predic):
+ residue_list = [term[1]['resname'].iloc[0] for term in predic]
+ feature_list = [residue_type_one_hot_dips(residue) for residue in residue_list]
+ feature_list = np.stack(feature_list)
+ processed_features = dict()
+ processed_features['res_feat'] = Tensor(feature_list.astype(np.float32))
+ return processed_features
+
+
+def residue_list_featurizer_dips_not_one_hot(predic):
+ residue_list = [term[1]['resname'].iloc[0] for term in predic]
+ feature_list = [[residue_type_one_hot_dips_not_one_hot(residue)] for residue in residue_list]
+ feature_list = np.array(feature_list)
+ processed_features = dict()
+ processed_features['res_feat'] = Tensor(feature_list.astype(np.float32)) # (N_res, 1)
+ return processed_features
+
+
+def residue_type_one_hot_dips(residue):
+ """
+ residue_type_one_hot_dips
+ """
+ dit = {
+ 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E',
+ 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F',
+ 'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
+ 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y',
+ 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D',
+ 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G',
+ }
+ allowable_set = [
+ 'Y', 'R', 'F', 'G', 'I', 'V', 'A', 'W', 'E', 'H', 'C', 'N', 'M', 'D', 'T', 'S', 'K', 'L', 'Q', 'P'
+ ]
+ res_name = residue
+ if res_name not in dit.keys():
+ res_name = None
+ else:
+ res_name = dit[res_name]
+ return one_hot_encoding(res_name, allowable_set, encode_unknown=True)
+
+
+def residue_type_one_hot_dips_not_one_hot(residue):
+ """
+ residue_type_one_hot_dips_not_one_hot
+ """
+ dit = {
+ 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E',
+ 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F',
+ 'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V',
+ 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y',
+ 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D',
+ 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G'
+ }
+
+ rare_residues = {
+ 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y',
+ 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D',
+ 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G'
+ }
+
+ if residue in rare_residues.keys():
+ log('Some rare residue: ', residue)
+
+ indicator = {
+ 'Y': 0, 'R': 1, 'F': 2, 'G': 3, 'I': 4, 'V': 5,
+ 'A': 6, 'W': 7, 'E': 8, 'H': 9, 'C': 10, 'N': 11,
+ 'M': 12, 'D': 13, 'T': 14, 'S': 15, 'K': 16, 'L': 17, 'Q': 18, 'P': 19
+ }
+ res_name = residue
+ if res_name not in dit.keys():
+ return 20
+ res_name = dit[res_name]
+ return indicator.get(res_name)
+
+
+def preprocess_unbound_bound(input_residues_tuple, graph_nodes, pos_cutoff=8.0, inference=False):
+ """
+    Filter residues with complete N/CA/C backbones, extract alpha-carbon coordinates and,
+    unless running inference, compute binding-pocket coordinates for a ligand-receptor pair.
+ """
+ #######################
+ def filter_residues(residues):
+ residues_filtered = []
+ for residue in residues:
+ df = residue[1]
+ natom = df[df[ATOM_NAME] == 'N']
+ alphacatom = df[df[ATOM_NAME] == 'CA']
+ catom = df[df[ATOM_NAME] == 'C']
+
+ if natom.shape[0] == 1 and alphacatom.shape[0] == 1 and catom.shape[0] == 1:
+ residues_filtered.append(residue)
+ return residues_filtered
+
+ ##########################
+ bound_ligand_residues, bound_receptor_residues = input_residues_tuple
+ bound_predic_ligand_filtered = filter_residues(bound_ligand_residues)
+ unbound_predic_ligand_filtered = bound_predic_ligand_filtered
+
+ bound_predic_receptor_filtered = filter_residues(bound_receptor_residues)
+ unbound_predic_receptor_filtered = bound_predic_receptor_filtered
+
+ bound_predic_ligand_clean_list = bound_predic_ligand_filtered
+ unbound_predic_ligand_clean_list = unbound_predic_ligand_filtered
+
+ bound_predic_receptor_clean_list = bound_predic_receptor_filtered
+ unbound_predic_receptor_clean_list = unbound_predic_receptor_filtered
+
+ ###################
+ def get_alphac_loc_array(bound_predic_clean_list):
+ bound_alphac_loc_clean_list = []
+ for residue in bound_predic_clean_list:
+ df = residue[1]
+ alphacatom = df[df[ATOM_NAME] == 'CA']
+ alphac_loc = alphacatom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32)
+ bound_alphac_loc_clean_list.append(alphac_loc)
+ if len(bound_alphac_loc_clean_list) <= 1:
+ bound_alphac_loc_clean_list.append(np.zeros(3))
+ return np.stack(bound_alphac_loc_clean_list, axis=0) # (N_res,3)
+
+ ####################
+ if graph_nodes != 'residues':
+ raise TypeError("graph_nodes should be residues")
+ bound_receptor_repres_nodes_loc_array = get_alphac_loc_array(bound_predic_receptor_clean_list)
+ bound_ligand_repres_nodes_loc_array = get_alphac_loc_array(bound_predic_ligand_clean_list)
+
+ if not inference:
+
+ # Keep pairs of ligand and receptor residues/atoms that have pairwise distances < threshold
+ ligand_receptor_distance = spa.distance.cdist(bound_ligand_repres_nodes_loc_array,
+ bound_receptor_repres_nodes_loc_array)
+ positive_tuple = np.where(ligand_receptor_distance < pos_cutoff)
+ active_ligand = positive_tuple[0]
+ active_receptor = positive_tuple[1]
+ if active_ligand.size <= 3: # We need: active_ligand.size > 0 '
+ pocket_coors = None # Will be filtered out later
+ else:
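+            # Pocket points are midpoints of ligand-receptor residue pairs closer than pos_cutoff.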
+ ligand_pocket_coors = bound_ligand_repres_nodes_loc_array[active_ligand, :]
+ receptor_pocket_coors = bound_receptor_repres_nodes_loc_array[active_receptor, :]
+ if np.max(np.linalg.norm(ligand_pocket_coors - receptor_pocket_coors, axis=1)) > pos_cutoff:
+ raise ValueError("The value should <= pos_cutoff")
+ pocket_coors = 0.5 * (ligand_pocket_coors + receptor_pocket_coors)
+ log('Num pocket nodes = ', len(active_ligand), ' total nodes = ',
+ bound_ligand_repres_nodes_loc_array.shape[0], ' graph_nodes = ', graph_nodes)
+
+ return unbound_predic_ligand_clean_list, unbound_predic_receptor_clean_list, \
+ bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, \
+ pocket_coors
+
+ return unbound_predic_ligand_clean_list, unbound_predic_receptor_clean_list, \
+ bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, 0
+
+
+def protein_to_graph_unbound_bound(
+ unbound_bound_tuple,
+ cutoff=20,
+ max_neighbor=None,
+ one_hot=False,
+ residue_loc_is_alphac=True,
+):
+ return protein_to_graph_unbound_bound_residuesonly(
+ unbound_bound_tuple,
+ cutoff,
+ max_neighbor,
+ one_hot,
+ residue_loc_is_alphac
+ )
+
+
+def protein_to_graph_unbound_bound_residuesonly(
+ unbound_bound_tuple,
+ cutoff=20,
+ max_neighbor=None,
+ one_hot=False,
+ residue_loc_is_alphac=True
+):
+ """
+ protein_to_graph_unbound_bound_residuesonly
+ """
+ unbound_ligand_predic, unbound_receptor_predic, bound_ligand_repres_nodes_loc_clean_array, \
+ bound_receptor_repres_nodes_loc_clean_array = unbound_bound_tuple
+
+ ################## Extract 3D coordinates and n_i,u_i,v_i vectors of representative residues ################
+ def l_or_r_extract_3d_coord_and_n_u_v_vecs(l_or_r_predic):
+ l_or_r_all_atom_coords_in_residue_list = []
+ l_or_r_residue_representatives_loc_list = []
+ l_or_r_n_i_list = []
+ l_or_r_u_i_list = []
+ l_or_r_v_i_list = []
+
+ for residue in l_or_r_predic:
+ df = residue[1]
+ coord = df[['x', 'y', 'z']].to_numpy().astype(np.float32) # (N_atoms, 3)
+ l_or_r_all_atom_coords_in_residue_list.append(coord)
+
+ natom = df[df[ATOM_NAME] == 'N']
+ alphacatom = df[df[ATOM_NAME] == 'CA']
+ catom = df[df[ATOM_NAME] == 'C']
+
+ if natom.shape[0] != 1 or alphacatom.shape[0] != 1 or catom.shape[0] != 1:
+ raise ValueError("protein utils protein_to_graph_unbound_bound, no N/CA/C exists")
+
+ n_loc = natom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32)
+ alphac_loc = alphacatom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32)
+ c_loc = catom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32)
+
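+            # Build an orthonormal local frame per residue from its N, CA and C atoms:
+            # u_i points from CA to N, n_i is normal to the N-CA-C plane, v_i = n_i x u_i.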
+ u_i = (n_loc - alphac_loc) / LA.norm(n_loc - alphac_loc)
+ t_i = (c_loc - alphac_loc) / LA.norm(c_loc - alphac_loc)
+ n_i = np.cross(u_i, t_i) / LA.norm(np.cross(u_i, t_i))
+ v_i = np.cross(n_i, u_i)
+
+ l_or_r_n_i_list.append(n_i)
+ l_or_r_u_i_list.append(u_i)
+ l_or_r_v_i_list.append(v_i)
+
+ if residue_loc_is_alphac:
+ l_or_r_residue_representatives_loc_list.append(alphac_loc)
+ else:
+ heavy_df = df[df['element'] != 'H']
+ residue_loc = heavy_df[['x', 'y', 'z']].mean(axis=0).to_numpy().astype(
+ np.float32) # average of all atom coordinates
+ l_or_r_residue_representatives_loc_list.append(residue_loc)
+
+ l_or_r_residue_representatives_loc_feat = np.stack(l_or_r_residue_representatives_loc_list,
+ axis=0) # (N_res, 3)
+ l_or_r_n_i_feat = np.stack(l_or_r_n_i_list, axis=0)
+ l_or_r_u_i_feat = np.stack(l_or_r_u_i_list, axis=0)
+ l_or_r_v_i_feat = np.stack(l_or_r_v_i_list, axis=0)
+
+ l_or_r_num_residues = len(l_or_r_predic)
+ if l_or_r_num_residues <= 1:
+ raise ValueError(f"l_or_r contains only 1 residue!")
+ return l_or_r_all_atom_coords_in_residue_list, \
+ l_or_r_residue_representatives_loc_feat, \
+ l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat, l_or_r_num_residues
+
+ (ligand_all_atom_coords_in_residue_list, # list of (N_atoms,3) arrays, for each residue
+ ligand_residue_representatives_loc_feat, # (N_res, 3)
+ ligand_n_i_feat, # (N_res, 3)
+ ligand_u_i_feat, # (N_res, 3)
+ ligand_v_i_feat, # (N_res, 3)
+ ligand_num_residues) = l_or_r_extract_3d_coord_and_n_u_v_vecs(unbound_ligand_predic)
+
+ (receptor_all_atom_coords_in_residue_list,
+ receptor_residue_representatives_loc_feat,
+ receptor_n_i_feat,
+ receptor_u_i_feat,
+ receptor_v_i_feat,
+ receptor_num_residues) = l_or_r_extract_3d_coord_and_n_u_v_vecs(unbound_receptor_predic)
+
+ ################# Align unbound and bound structures, if needed ################################
+ def l_or_r_align_unbound_and_bound(l_or_r_residue_representatives_loc_feat,
+ l_or_r_n_i_feat,
+ l_or_r_u_i_feat,
+ l_or_r_v_i_feat,
+ bound_l_or_r_alphac_loc_clean_array):
+
+ ret_r_l_or_r, ret_t_l_or_r = rigid_transform_kabsch_3d(l_or_r_residue_representatives_loc_feat.T,
+ bound_l_or_r_alphac_loc_clean_array.T)
+ l_or_r_residue_representatives_loc_feat = ((ret_r_l_or_r @ (l_or_r_residue_representatives_loc_feat).T) +
+ ret_t_l_or_r).T
+ l_or_r_n_i_feat = ((ret_r_l_or_r @ (l_or_r_n_i_feat).T)).T
+ l_or_r_u_i_feat = ((ret_r_l_or_r @ (l_or_r_u_i_feat).T)).T
+ l_or_r_v_i_feat = ((ret_r_l_or_r @ (l_or_r_v_i_feat).T)).T
+ return l_or_r_residue_representatives_loc_feat, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat
+
+ (ligand_residue_representatives_loc_feat,
+ ligand_n_i_feat,
+ ligand_u_i_feat,
+ ligand_v_i_feat) = l_or_r_align_unbound_and_bound(ligand_residue_representatives_loc_feat,
+ ligand_n_i_feat, ligand_u_i_feat, ligand_v_i_feat,
+ bound_ligand_repres_nodes_loc_clean_array)
+ (receptor_residue_representatives_loc_feat,
+ receptor_n_i_feat,
+ receptor_u_i_feat,
+ receptor_v_i_feat) = l_or_r_align_unbound_and_bound(receptor_residue_representatives_loc_feat,
+ receptor_n_i_feat, receptor_u_i_feat, receptor_v_i_feat,
+ bound_receptor_repres_nodes_loc_clean_array)
+
+ ################### Build the k-NN graph ##############################
+ def loop_edges(input_tuple):
+
+ l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat, \
+ l_or_r_residue_representatives_loc_feat, l_or_r_protein_graph, l_or_r_mean_norm_list = input_tuple
+
+ # Loop over all edges of the graph and build the various p_ij, q_ij, k_ij, t_ij pairs
+ l_or_r_edge_feat_ori_list = []
+ for i in range(len(l_or_r_dist_list)):
+ src = l_or_r_src_list[i]
+ dst = l_or_r_dst_list[i]
+
+ # place n_i, u_i, v_i as lines in a 3x3 basis matrix
+ basis_matrix = np.stack((l_or_r_n_i_feat[dst, :], l_or_r_u_i_feat[dst, :], l_or_r_v_i_feat[dst, :]), axis=0)
+ p_ij = np.matmul(basis_matrix, l_or_r_residue_representatives_loc_feat[src, :] -
+ l_or_r_residue_representatives_loc_feat[dst, :])
+ q_ij = np.matmul(basis_matrix, l_or_r_n_i_feat[src, :]) # shape (3,)
+ k_ij = np.matmul(basis_matrix, l_or_r_u_i_feat[src, :])
+ t_ij = np.matmul(basis_matrix, l_or_r_v_i_feat[src, :])
+ s_ij = np.concatenate((p_ij, q_ij, k_ij, t_ij), axis=0) # shape (12,)
+ l_or_r_edge_feat_ori_list.append(s_ij)
+        l_or_r_edge_feat_ori_feat = np.stack(l_or_r_edge_feat_ori_list, axis=0)  # shape (num_edges, 12)
+ l_or_r_edge_feat_ori_feat = Tensor(l_or_r_edge_feat_ori_feat.astype(np.float32))
+ l_or_r_protein_graph.edata['he'] = Tensor(
+ np.concatenate((l_or_r_protein_graph.edata['he'].asnumpy(), l_or_r_edge_feat_ori_feat.asnumpy()), axis=1)
+ )
+
+ l_or_r_residue_representatives_loc_feat = Tensor(
+ l_or_r_residue_representatives_loc_feat.astype(np.float32))
+ l_or_r_protein_graph.ndata['x'] = l_or_r_residue_representatives_loc_feat
+ l_or_r_protein_graph.ndata['mu_r_norm'] = Tensor(
+ np.array(l_or_r_mean_norm_list).astype(np.float32))
+
+ return l_or_r_protein_graph
+
+
+ def compute_dig_knn_graph(input_tuple):
+
+ l_or_r_num_residues, l_or_r_all_atom_coords_in_residue_list, unbound_l_or_r_predic,\
+ l_or_r_residue_representatives_loc_feat, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat = input_tuple
+
+ l_or_r_distance = np.full((l_or_r_num_residues, l_or_r_num_residues), np.inf)
+
+ for i in range(l_or_r_num_residues - 1):
+ for j in range((i + 1), l_or_r_num_residues):
+ l_or_r_pairwise_dis = spa.distance.cdist(l_or_r_all_atom_coords_in_residue_list[i],
+ l_or_r_all_atom_coords_in_residue_list[j])
+ l_or_r_distance[i, j] = np.mean(l_or_r_pairwise_dis)
+ l_or_r_distance[j, i] = np.mean(l_or_r_pairwise_dis)
+
+ l_or_r_protein_graph = Graph(num_nodes=l_or_r_num_residues)
+
+ l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_mean_norm_list = [], [], [], []
+
+ for i in range(l_or_r_num_residues):
+ valid_src = list(np.where(l_or_r_distance[i, :] < cutoff)[0])
+ if len(valid_src) > max_neighbor:
+ valid_src = list(np.argsort(l_or_r_distance[i, :]))[0: max_neighbor]
+ valid_dst = [i] * len(valid_src)
+ l_or_r_dst_list.extend(valid_dst)
+ l_or_r_src_list.extend(valid_src)
+
+ valid_dist = list(l_or_r_distance[i, valid_src])
+ l_or_r_dist_list.extend(valid_dist)
+
+ valid_dist_np = l_or_r_distance[i, valid_src]
+ sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1))
+ weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1) # (sigma_num, neigh_num)
+ diff_vecs = l_or_r_residue_representatives_loc_feat[valid_dst, :] - \
+ l_or_r_residue_representatives_loc_feat[valid_src, :] # (neigh_num, 3)
+ mean_vec = weights.dot(diff_vecs) # (sigma_num, 3)
+ denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1)) # (sigma_num,)
+ mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator # (sigma_num,)
+ l_or_r_mean_norm_list.append(mean_vec_ratio_norm)
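+            # Each entry is a 5-vector (one value per sigma) measuring how directionally
+            # coherent the neighbourhood is; it becomes the 'mu_r_norm' node feature later on.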
+
+ l_or_r_protein_graph.add_edges(l_or_r_src_list, l_or_r_dst_list)
+
+ if one_hot:
+ l_or_r_protein_graph.ndata = residue_list_featurizer_dips_one_hot(unbound_l_or_r_predic)
+ else:
+ l_or_r_protein_graph.ndata = residue_list_featurizer_dips_not_one_hot(unbound_l_or_r_predic)
+
+ l_or_r_protein_graph.edata = distance_list_featurizer(l_or_r_dist_list)
+
+ loop_edges_input = l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_n_i_feat, l_or_r_u_i_feat, \
+ l_or_r_v_i_feat, l_or_r_residue_representatives_loc_feat, l_or_r_protein_graph, l_or_r_mean_norm_list
+
+ return loop_edges(loop_edges_input)
+
+
+ ligand_protein_graph_tuple = (
+ ligand_num_residues, ligand_all_atom_coords_in_residue_list, unbound_ligand_predic,
+ ligand_residue_representatives_loc_feat, ligand_n_i_feat, ligand_u_i_feat, ligand_v_i_feat,
+ )
+ ligand_protein_graph = compute_dig_knn_graph(ligand_protein_graph_tuple)
+
+ receptor_protein_graph_tuple = (
+ receptor_num_residues, receptor_all_atom_coords_in_residue_list, unbound_receptor_predic,
+ receptor_residue_representatives_loc_feat, receptor_n_i_feat, receptor_u_i_feat, receptor_v_i_feat,
+ )
+ receptor_protein_graph = compute_dig_knn_graph(receptor_protein_graph_tuple)
+
+ return ligand_protein_graph, receptor_protein_graph
+
+
+def unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor, h_feats_ligand, coors_receptor, coors_ligand):
+    """
+    Split batched node features and coordinates back into per-complex dictionaries using the
+    node-count offsets (columns 2 and 3) stored in unbatch_list_tensor.
+    """
+ list_hetero_graph = []
+ for i, _ in enumerate(unbatch_list_tensor):
+ if i < len(unbatch_list_tensor) - 1:
+ hetero_graph = {
+ "receptor_hv_iegmn_out": \
+ h_feats_receptor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :],
+ "ligand_hv_iegmn_out": \
+ h_feats_ligand[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :],
+ "receptor_x_iegmn_out": \
+ coors_receptor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :],
+ "ligand_x_iegmn_out": \
+ coors_ligand[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :],
+ }
+ list_hetero_graph.append(hetero_graph)
+ return list_hetero_graph
+
+
+def get_non_lin(input_type, negative_slope):
+ if input_type == 'swish':
+ return nn.SiLU()
+ if input_type == 'lkyrelu':
+ return nn.LeakyReLU(alpha=negative_slope)
+ raise NotImplementedError
+
+
+def get_layer_norm(layer_norm_type, dim):
+ if layer_norm_type == 'BN':
+        return nn.BatchNorm1d(dim)
+ if layer_norm_type == 'LN':
+ return nn.LayerNorm([dim], begin_norm_axis=1, begin_params_axis=1,
+ epsilon=1e-5)
+ return nn.Identity()
+
+
+def get_final_h_layer_norm(layer_norm_type, dim):
+ if layer_norm_type == 'BN':
+ return nn.BatchNorm1d(dim)
+ if layer_norm_type == 'LN':
+ return nn.LayerNorm([dim], begin_norm_axis=1, begin_params_axis=1, epsilon=1e-5)
+ if layer_norm_type == '0':
+ return nn.Identity()
+ raise NotImplementedError
+
+
+def apply_final_h_layer_norm(h, norm_layer):
+ return norm_layer(h)
+
+
+def compute_cross_attention(queries, keys, values, mask, cross_msgs):
+    """
+    Masked cross-attention: each query node attends over the other protein's nodes.
+    Returns 0 when cross messages are disabled (cross_msgs is falsy).
+    """
+ if not cross_msgs:
+ return 0
+ a = mask * ops.mm(queries, ops.transpose(keys, (1, 0))) - 1000. * (1. - mask)
+ a_x = ops.softmax(a) # i->j, NxM
+ attention_x = ops.mm(a_x, values) # (N,d)
+ return attention_x
+
+
+def get_mask(ligand_batch_num_nodes, receptor_batch_num_nodes):
+    """
+    Build a block-diagonal ligand-to-receptor attention mask so nodes only attend within
+    their own complex of the batch.
+    """
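+    # Illustrative example: ligand_batch_num_nodes=[2, 3] and receptor_batch_num_nodes=[1, 2]
+    # yield a 5x3 mask with ones in blocks [0:2, 0:1] and [2:5, 1:3].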
+ rows = sum(ligand_batch_num_nodes)
+ cols = sum(receptor_batch_num_nodes)
+ mask = ops.zeros((int(rows), int(cols)))
+ partial_l = 0
+ partial_r = 0
+ for l_n, r_n in zip(ligand_batch_num_nodes, receptor_batch_num_nodes):
+ mask[partial_l: partial_l + l_n, partial_r: partial_r + r_n] = 1
+ partial_l = partial_l + l_n
+ partial_r = partial_r + r_n
+ return mask
+
+
+class Graph:
+ """
+ Graph class for wrapping data
+ """
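+    # Minimal DGL-like container: only node/edge feature dicts plus explicit src/dst index
+    # lists are stored; message passing is done by hand with ops.index_select in IEGMNLayer.
+    # Usage sketch: g = Graph(num_nodes=3); g.add_edges([0, 1], [1, 2]); g.ndata['x'] = feats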
+ def __init__(
+ self,
+ num_nodes=0,
+ num_edges=0,
+ ):
+ self.num_nodes = num_nodes
+ self.num_edges = num_edges
+ self.ndata = {}
+ self.edata = {}
+ self.src_list = []
+ self.dst_list = []
+
+ def add_edges(self, src_list, dst_list):
+        if len(src_list) != len(dst_list):
+            raise ValueError("src_list and dst_list must have the same length")
+ self.src_list = src_list
+ self.dst_list = dst_list
+ self.num_edges = len(dst_list)
+
+
+class IEGMNLayer(nn.Cell):
+ """
+ IEGMNLayer class
+ """
+ def __init__(
+ self,
+ orig_h_feats_dim,
+ h_feats_dim,
+ out_feats_dim,
+ fine_tune,
+ args,
+ log_input=None):
+
+ super(IEGMNLayer, self).__init__()
+
+ dropout = args.dropout
+ nonlin = args.nonlin
+ layer_norm = args.layer_norm
+ leakyrelu_neg_slope = args.leakyrelu_neg_slope
+ self.input_edge_feats_dim = args.input_edge_feats_dim
+ self.layer_norm_coors = args.layer_norm_coors
+ self.cross_msgs = args.cross_msgs
+ self.final_h_layer_norm = args.final_h_layer_norm
+ self.use_dist_in_layers = args.use_dist_in_layers
+ self.skip_weight_h = args.skip_weight_h
+ self.x_connection_init = args.x_connection_init
+ self.debug = args.debug
+ self.fine_tune = fine_tune
+ self.log = log_input
+ self.h_feats_dim = h_feats_dim
+ self.out_feats_dim = out_feats_dim
+ self.all_sigmas_dist = [1.5 ** x for x in range(15)]
+ self.orig_h_feats_dim = orig_h_feats_dim
+
+ # EDGES
+ self.edge_mlp = self.edge_mlp_func(dropout, nonlin, layer_norm, leakyrelu_neg_slope)
+
+ # NODES
+ self.node_norm = nn.Identity()
+
+ self.att_mlp_q = nn.SequentialCell(
+ nn.Dense(h_feats_dim, h_feats_dim, has_bias=False),
+ get_non_lin(nonlin, leakyrelu_neg_slope),
+ )
+
+ self.att_mlp_k = nn.SequentialCell(
+ nn.Dense(h_feats_dim, h_feats_dim, has_bias=False),
+ get_non_lin(nonlin, leakyrelu_neg_slope),
+ )
+
+ self.att_mlp_v = nn.SequentialCell(
+ nn.Dense(h_feats_dim, h_feats_dim, has_bias=False),
+ )
+
+ self.node_mlp = self.node_mlp_func(dropout, nonlin, layer_norm, leakyrelu_neg_slope)
+
+ self.final_h_layernorm_layer = get_final_h_layer_norm(self.final_h_layer_norm, self.out_feats_dim)
+
+ ## The scalar weight to be multiplied by (x_i - x_j)
+ self.coors_mlp = self.coors_mlp_func(dropout, nonlin, leakyrelu_neg_slope)
+
+ if self.fine_tune:
+ self.att_mlp_cross_coors_q = nn.SequentialCell(
+ nn.Dense(h_feats_dim, h_feats_dim, has_bias=False),
+ get_non_lin(nonlin, leakyrelu_neg_slope),
+ )
+ self.att_mlp_cross_coors_k = nn.SequentialCell(
+ nn.Dense(h_feats_dim, h_feats_dim, has_bias=False),
+ get_non_lin(nonlin, leakyrelu_neg_slope),
+ )
+ self.att_mlp_cross_coors_v = nn.SequentialCell(
+ nn.Dense(h_feats_dim, h_feats_dim),
+ get_non_lin(nonlin, leakyrelu_neg_slope),
+ nn.Dense(h_feats_dim, 1),
+ )
+
+ def __repr__(self):
+ return "IEGMNLayer " + str(self.__dict__)
+
+ def edge_mlp_func(self, dropout, nonlin, layer_norm, leakyrelu_neg_slope):
+ """
+ IEGMNLayer edge_mlp_func
+ """
+ edge_mlp = nn.SequentialCell(
+ nn.Dense((self.h_feats_dim * 2) + self.input_edge_feats_dim
+ + len(self.all_sigmas_dist), self.out_feats_dim),
+ nn.Dropout(dropout),
+ get_non_lin(nonlin, leakyrelu_neg_slope),
+ get_layer_norm(layer_norm, self.out_feats_dim),
+ nn.Dense(self.out_feats_dim, self.out_feats_dim),
+ )
+ return edge_mlp
+
+ def node_mlp_func(self, dropout, nonlin, layer_norm, leakyrelu_neg_slope):
+
+ node_mlp = nn.SequentialCell(
+ nn.Dense(self.orig_h_feats_dim + 2 * self.h_feats_dim + self.out_feats_dim, self.h_feats_dim),
+ nn.Dropout(dropout),
+ get_non_lin(nonlin, leakyrelu_neg_slope),
+ get_layer_norm(layer_norm, self.h_feats_dim),
+ nn.Dense(self.h_feats_dim, self.out_feats_dim),
+ )
+ return node_mlp
+
+ def coors_mlp_func(self, dropout, nonlin, leakyrelu_neg_slope):
+ coors_mlp = nn.SequentialCell(
+ nn.Dense(self.out_feats_dim, self.out_feats_dim),
+ nn.Dropout(dropout),
+ get_non_lin(nonlin, leakyrelu_neg_slope),
+ get_layer_norm(self.layer_norm_coors, self.out_feats_dim),
+ nn.Dense(self.out_feats_dim, 1)
+ )
+ return coors_mlp
+
+ def x_rel_mag(self, nodes_ligand_x_now, ll_connection_tensor, nodes_receptor_x_now, rr_connection_tensor):
+ """
+ IEGMNLayer x_rel_mag
+ """
+ nodes_ligand_x_now_src = ops.index_select(nodes_ligand_x_now, 0, ll_connection_tensor[0])
+ nodes_ligand_x_now_dst = ops.index_select(nodes_ligand_x_now, 0, ll_connection_tensor[1])
+ ll_edge_x_rel = ops.sub(nodes_ligand_x_now_src, nodes_ligand_x_now_dst)
+
+ nodes_receptor_x_now_src = ops.index_select(nodes_receptor_x_now, 0, rr_connection_tensor[0])
+ nodes_receptor_x_now_dst = ops.index_select(nodes_receptor_x_now, 0, rr_connection_tensor[1])
+ rr_edge_x_rel = ops.sub(nodes_receptor_x_now_src, nodes_receptor_x_now_dst)
+
+ x_rel_mag_ligand = ll_edge_x_rel ** 2
+        x_rel_mag_ligand = ops.sum(x_rel_mag_ligand, dim=1, keepdim=True)  # ||x_i - x_j||^2 : (N_edges, 1)
+ x_rel_mag_ligand = ops.cat([ops.exp(-x_rel_mag_ligand / sigma) for sigma in self.all_sigmas_dist], axis=-1)
+
+ x_rel_mag_receptor = rr_edge_x_rel ** 2
+ x_rel_mag_receptor = ops.sum(x_rel_mag_receptor, dim=1, keepdim=True)
+ x_rel_mag_receptor = ops.cat([ops.exp(-x_rel_mag_receptor / sigma) for sigma in self.all_sigmas_dist], axis=-1)
+
+ return ll_edge_x_rel, rr_edge_x_rel, x_rel_mag_ligand, x_rel_mag_receptor
+
+ def cat_input_for_msg(self, nodes_feat, original_edge_feats, connection_tensor, x_rel_mag):
+ nodes_feat_src = ops.index_select(nodes_feat, 0, connection_tensor[0])
+ nodes_feat_dst = ops.index_select(nodes_feat, 0, connection_tensor[1])
+ nodes_cat_feat = ops.cat((nodes_feat_src, nodes_feat_dst), axis=1)
+ cat_input_for_msg = ops.cat((nodes_cat_feat, # [h_i h_j]
+ original_edge_feats,
+ x_rel_mag), axis=-1)
+ return cat_input_for_msg
+
+ def nodes_aggr_cross_msg(self, h_feats_ligand, h_feats_receptor, mask):
+ """
+ IEGMNLayer nodes_aggr_cross_msg
+ """
+ nodes_ligand_aggr_cross_msg = compute_cross_attention(self.att_mlp_q(h_feats_ligand),
+ self.att_mlp_k(h_feats_receptor),
+ self.att_mlp_v(h_feats_receptor),
+ mask,
+ self.cross_msgs)
+ nodes_receptor_aggr_cross_msg = compute_cross_attention(self.att_mlp_q(h_feats_receptor),
+ self.att_mlp_k(h_feats_ligand),
+ self.att_mlp_v(h_feats_ligand),
+ mask.transpose(1, 0),
+ self.cross_msgs)
+ return nodes_ligand_aggr_cross_msg, nodes_receptor_aggr_cross_msg
+
+ def cal_x_final(self, edges_x_moment, edges_msg, orig_coors, nodes_x_now):
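+        # The reshapes to (-1, 10, ...) below assume every node has exactly 10 incident edges,
+        # i.e. the k-NN graphs were built with graph_max_neighbor = 10 (the EquiDock default);
+        # a different neighbour count would scramble these per-node means.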
+ nodes_x_update = ops.mean(ops.reshape(edges_x_moment, (-1, 10, 3)), axis=1)
+ nodes_aggr_msg = ops.mean(ops.reshape(edges_msg, (-1, 10, edges_msg.shape[-1])), axis=1)
+ x_final = self.x_connection_init * orig_coors + (1. - self.x_connection_init) * nodes_x_now + \
+ nodes_x_update
+ return nodes_aggr_msg, x_final
+
+ def fine_tune_final_lr(self, input_fine_tune_tuple):
+ """
+ IEGMNLayer fine_tune_final_lr
+ """
+ x_final_ligand, x_final_receptor, h_feats_ligand, nodes_ligand_x_now, h_feats_receptor, \
+ nodes_receptor_x_now, mask = input_fine_tune_tuple
+
+ x_final_ligand = x_final_ligand + self.att_mlp_cross_coors_v(h_feats_ligand) * \
+ (nodes_ligand_x_now - compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_ligand),
+ self.att_mlp_cross_coors_k(h_feats_receptor),
+ nodes_receptor_x_now,
+ mask,
+ self.cross_msgs))
+ x_final_receptor = x_final_receptor + self.att_mlp_cross_coors_v(h_feats_receptor) * \
+ (nodes_receptor_x_now - compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_receptor),
+ self.att_mlp_cross_coors_k(h_feats_ligand),
+ nodes_ligand_x_now,
+ mask.transpose(1, 0),
+ self.cross_msgs))
+ return x_final_ligand, x_final_receptor
+
+ def skip_connections(self, h_feats_ligand, h_feats_receptor, input_node_upd_ligand, input_node_upd_receptor):
+ """
+ IEGMNLayer skip_connections
+ """
+ if self.h_feats_dim == self.out_feats_dim:
+ node_upd_ligand = self.skip_weight_h * self.node_mlp(input_node_upd_ligand) + (
+ 1. - self.skip_weight_h) * h_feats_ligand
+ node_upd_receptor = self.skip_weight_h * self.node_mlp(input_node_upd_receptor) + (
+ 1. - self.skip_weight_h) * h_feats_receptor
+ else:
+ node_upd_ligand = self.node_mlp(input_node_upd_ligand)
+ node_upd_receptor = self.node_mlp(input_node_upd_receptor)
+
+ node_upd_ligand = apply_final_h_layer_norm(node_upd_ligand, self.final_h_layernorm_layer)
+ node_upd_receptor = apply_final_h_layer_norm(node_upd_receptor, self.final_h_layernorm_layer)
+
+ return node_upd_ligand, node_upd_receptor
+
+ def log_debug_info(self, debug_variable_tuple):
+ """
+ IEGMNLayer log_debug_info
+ """
+ nodes_ligand_x_now, nodes_ligand_feat, ll_edge_x_rel, x_rel_mag_ligand, edges_ll_msg, \
+ nodes_ligand_aggr_cross_msg, edge_coef_ligand, edges_ll_x_moment, nodes_ligand_aggr_msg, \
+ x_final_ligand = debug_variable_tuple
+ self.log(ops.max(nodes_ligand_x_now)[0], 'x_now : x_i at layer entrance')
+ self.log(ops.max(nodes_ligand_feat)[0], 'data[feat] = h_i at layer entrance')
+ self.log(ops.max(ll_edge_x_rel)[0], 'x_rel : x_i - x_j')
+ self.log(ops.max(x_rel_mag_ligand, axis=0)[0],
+ 'x_rel_mag_ligand = [exp(-||x_i - x_j||^2 / sigma) for sigma = 1.5 ** x, x = [0, 15]]')
+ self.log(ops.max(edges_ll_msg)[0], 'data[msg] = m_{i->j} = phi^e(h_i, h_j, f_{i,j}, x_rel_mag_ligand)')
+ self.log(ops.max(nodes_ligand_aggr_cross_msg)[0], 'aggr_cross_msg(i) = sum_j a_{i,j} * h_j')
+ self.log(ops.max(edge_coef_ligand)[0], 'edge_coef_ligand : ' + r'\p' + 'hi^x(m_{i->j})')
+ self.log(ops.max(edges_ll_x_moment)[0], 'data[x_moment] = (x_i - x_j) * ' + r'\p' + 'hi^x(m_{i->j})')
+ self.log(ops.max(nodes_ligand_aggr_msg)[0], 'data[aggr_msg]: ' + r'\s' + 'um_j m_{i->j}')
+ self.log(ops.max(x_final_ligand)[0], 'x_i new = x_final_ligand : x_i + data[x_update]')
+
+ def construct(self, iegmn_layer_tuple, input_tensor_tuple):
+ """
+ IEGMNLayer construct
+ """
+ ligand_graph_num_nodes, receptor_graph_num_nodes, ll_connection_tensor, rr_connection_tensor = \
+ input_tensor_tuple[0], input_tensor_tuple[1], input_tensor_tuple[4], input_tensor_tuple[5]
+
+ coors_ligand, h_feats_ligand, original_ligand_node_features, original_edge_feats_ligand, \
+ orig_coors_ligand, coors_receptor, h_feats_receptor, original_receptor_node_features, \
+ original_edge_feats_receptor, orig_coors_receptor = iegmn_layer_tuple
+
+ nodes_ligand_x_now, nodes_receptor_x_now, nodes_ligand_feat, nodes_receptor_feat = \
+ coors_ligand, coors_receptor, h_feats_ligand, h_feats_receptor
+
+ ll_edge_x_rel, rr_edge_x_rel, x_rel_mag_ligand, x_rel_mag_receptor = self.x_rel_mag(nodes_ligand_x_now, \
+ ll_connection_tensor, nodes_receptor_x_now, rr_connection_tensor)
+
+ if not self.use_dist_in_layers:
+ x_rel_mag_ligand = 0
+ x_rel_mag_receptor = 0
+
+ cat_input_for_msg_ligand = self.cat_input_for_msg(
+ nodes_ligand_feat, original_edge_feats_ligand, ll_connection_tensor, x_rel_mag_ligand)
+ cat_input_for_msg_receptor = self.cat_input_for_msg(
+ nodes_receptor_feat, original_edge_feats_receptor, rr_connection_tensor, x_rel_mag_receptor)
+
+ edges_ll_msg = self.edge_mlp(cat_input_for_msg_ligand) # m_{i->j}
+ edges_rr_msg = self.edge_mlp(cat_input_for_msg_receptor)
+
+ mask = get_mask(ligand_graph_num_nodes, receptor_graph_num_nodes)
+
+ nodes_ligand_aggr_cross_msg, nodes_receptor_aggr_cross_msg = self.nodes_aggr_cross_msg(h_feats_ligand,
+ h_feats_receptor, mask)
+ edge_coef_ligand = self.coors_mlp(edges_ll_msg) # \phi^x(m_{i->j})
+ edges_ll_x_moment = ll_edge_x_rel * edge_coef_ligand # (x_i - x_j) * \phi^x(m_{i->j})
+
+ edge_coef_receptor = self.coors_mlp(edges_rr_msg)
+ edges_rr_x_moment = rr_edge_x_rel * edge_coef_receptor
+
+ nodes_ligand_aggr_msg, x_final_ligand = self.cal_x_final(edges_ll_x_moment, edges_ll_msg,
+ orig_coors_ligand, nodes_ligand_x_now)
+ nodes_receptor_aggr_msg, x_final_receptor = self.cal_x_final(edges_rr_x_moment, edges_rr_msg,
+ orig_coors_receptor, nodes_receptor_x_now)
+
+ if self.fine_tune:
+ input_fine_tune = (
+ x_final_ligand, x_final_receptor, h_feats_ligand, nodes_ligand_x_now,
+ h_feats_receptor, nodes_receptor_x_now, mask
+ )
+ x_final_ligand, x_final_receptor = self.fine_tune_final_lr(input_fine_tune)
+
+ if self.debug:
+ debug_variable_tuple = (
+ nodes_ligand_x_now, nodes_ligand_feat, ll_edge_x_rel, x_rel_mag_ligand, edges_ll_msg,
+ nodes_ligand_aggr_cross_msg, edge_coef_ligand, edges_ll_x_moment, nodes_ligand_aggr_msg, x_final_ligand,
+ )
+ self.log_debug_info(debug_variable_tuple)
+
+ input_node_upd_ligand = ops.cat((self.node_norm(nodes_ligand_feat), nodes_ligand_aggr_msg,
+ nodes_ligand_aggr_cross_msg, original_ligand_node_features), axis=-1)
+
+ input_node_upd_receptor = ops.cat((self.node_norm(nodes_receptor_feat), nodes_receptor_aggr_msg,
+ nodes_receptor_aggr_cross_msg, original_receptor_node_features), axis=-1)
+
+ node_upd_ligand, node_upd_receptor = self.skip_connections(h_feats_ligand, h_feats_receptor,
+ input_node_upd_ligand, input_node_upd_receptor)
+
+ return x_final_ligand, node_upd_ligand, x_final_receptor, node_upd_receptor
+
+
+class IEGMN(nn.Cell):
+ """
+ IEGMN class
+ """
+ def __init__(self, args, n_lays, fine_tune, log_input=None):
+
+ super(IEGMN, self).__init__()
+
+ self.debug = args.debug
+ self.log = log_input
+ self.graph_nodes = args.graph_nodes
+ self.rot_model = args.rot_model
+ self.iegmn_lay_hid_dim = args.iegmn_lay_hid_dim
+ self.noise_decay_rate = args.noise_decay_rate
+ self.noise_initial = args.noise_initial
+ self.use_edge_features_in_gmn = args.use_edge_features_in_gmn
+ self.use_mean_node_features = args.use_mean_node_features
+
+ # 21 types of amino-acid types
+ self.residue_emb_layer = nn.Embedding(vocab_size=21, embedding_size=args.residue_emb_dim, use_one_hot=False)
+ if self.graph_nodes != 'residues':
+ raise NotImplementedError
+ input_node_feats_dim = args.residue_emb_dim # One residue type
+
+ if self.use_mean_node_features:
+ input_node_feats_dim += 5 # Additional features from mu_r_norm
+
+ self.iegmn_layers = self.iegmn_layers_func(args, input_node_feats_dim, n_lays, fine_tune)
+
+ if args.rot_model != 'kb_att':
+ raise ValueError("args rot_model should be kb_att")
+
+ # Attention layers
+ self.num_att_heads = args.num_att_heads
+ self.out_feats_dim = self.iegmn_lay_hid_dim
+
+ self.att_mlp_key_rot = nn.SequentialCell(
+ nn.Dense(self.out_feats_dim, self.num_att_heads * self.out_feats_dim, has_bias=False),
+ )
+ self.att_mlp_query_rot = nn.SequentialCell(
+ nn.Dense(self.out_feats_dim, self.num_att_heads * self.out_feats_dim, has_bias=False),
+ )
+
+ self.mlp_h_mean_rot = nn.SequentialCell(
+ nn.Dense(self.out_feats_dim, self.out_feats_dim),
+ nn.Dropout(args.dropout),
+ get_non_lin(args.nonlin, args.leakyrelu_neg_slope),
+ )
+
+ def __repr__(self):
+ return "IEGMN " + str(self.__dict__)
+
+ def iegmn_layers_func(self, args, input_node_feats_dim, n_lays, fine_tune):
+ """
+ IEGMN iegmn_layers_func
+ """
+
+ iegmn_layers = nn.CellList()
+
+ iegmn_layers.append(
+ IEGMNLayer(orig_h_feats_dim=input_node_feats_dim,
+ h_feats_dim=input_node_feats_dim,
+ out_feats_dim=self.iegmn_lay_hid_dim,
+ fine_tune=fine_tune,
+ args=args,
+ log_input=self.log))
+
+ if args.shared_layers:
+ interm_lay = IEGMNLayer(orig_h_feats_dim=input_node_feats_dim,
+ h_feats_dim=self.iegmn_lay_hid_dim,
+ out_feats_dim=self.iegmn_lay_hid_dim,
+ args=args,
+ fine_tune=fine_tune,
+ log_input=self.log)
+ for _ in range(1, n_lays):
+ iegmn_layers.append(interm_lay)
+
+ else:
+ for _ in range(1, n_lays):
+ iegmn_layers.append(
+ IEGMNLayer(orig_h_feats_dim=input_node_feats_dim,
+ h_feats_dim=self.iegmn_lay_hid_dim,
+ out_feats_dim=self.iegmn_lay_hid_dim,
+ args=args,
+ fine_tune=fine_tune,
+ log_input=self.log))
+
+ return iegmn_layers
+
+ def att_weights_rot(self, h_feats, h_feats_att_mean_rot, d, z_coors, all_y_att_rot_list):
+ """
+ IEGMN att_weights_rot
+ """
+ after_key_rot = self.att_mlp_key_rot(h_feats).view(-1, self.num_att_heads, d).transpose(1, 0, 2)
+ after_query_rot = self.att_mlp_query_rot(h_feats_att_mean_rot).view(1, self.num_att_heads, d).transpose(1, 2, 0)
+ att_weights_rot = ops.softmax(after_key_rot @ after_query_rot / math.sqrt(d), axis=1)
+ att_weights_rot = att_weights_rot.view(self.num_att_heads, -1)
+
+ y_att_rot = att_weights_rot @ z_coors # K_heads, 3
+ all_y_att_rot_list.append(y_att_rot)
+
+ return y_att_rot, all_y_att_rot_list
+
+ def ap_compute(self, list_hetero_graph):
+ """
+ IEGMN ap_compute
+ """
+ all_t_align_list, all_b_align_list, all_y_receptor_att_rot_list, all_y_ligand_att_rot_list = [], [], [], []
+
+ for _, hetero_graph in enumerate(list_hetero_graph):
+
+ # Get H vectors
+ h_receptor_feats = hetero_graph["receptor_hv_iegmn_out"] # (m, d)
+ h_receptor_feats_att_mean_rot = ops.mean(self.mlp_h_mean_rot(h_receptor_feats), axis=0,
+ keep_dims=True) # (1, d)
+
+ h_ligand_feats = hetero_graph["ligand_hv_iegmn_out"] # (n, d)
+ h_ligand_feats_att_mean_rot = ops.mean(self.mlp_h_mean_rot(h_ligand_feats), axis=0,
+ keep_dims=True) # (1, d)
+
+ d = h_ligand_feats.shape[1]
+
+ # Z coordinates
+ z_receptor_coors = hetero_graph["receptor_x_iegmn_out"]
+ z_ligand_coors = hetero_graph["ligand_x_iegmn_out"]
+
+ #### AP 1: compute two point clouds of K_heads points each, then do Kabsch #####
+ # Att weights to compute the receptor centroid.
+            # The query is the average h_ligand. Keys are each h_receptor_j.
+
+ y_receptor_att_rot, all_y_receptor_att_rot_list = self.att_weights_rot(
+ h_receptor_feats, h_ligand_feats_att_mean_rot, d, z_receptor_coors, all_y_receptor_att_rot_list
+ )
+
+ y_ligand_att_rot, all_y_ligand_att_rot_list = self.att_weights_rot(
+ h_ligand_feats, h_receptor_feats_att_mean_rot, d, z_ligand_coors, all_y_ligand_att_rot_list
+ )
+
+ ## Apply Kabsch algorithm
+ y_receptor_att_rot_mean = y_receptor_att_rot.mean(axis=0, keep_dims=True) # (1,3)
+ y_ligand_att_rot_mean = y_ligand_att_rot.mean(axis=0, keep_dims=True) # (1,3)
+
+ a = (y_receptor_att_rot - y_receptor_att_rot_mean).transpose(1, 0) @ (
+ y_ligand_att_rot - y_ligand_att_rot_mean) # 3, 3
+
+ if ops.isnan(a).any():
+ raise ValueError("There is Nan in a")
+ u, s, vt = np.linalg.svd(a.asnumpy())
+ u, s, vt = Tensor(u), Tensor(s), Tensor(vt)
+
+ num_it = 0
+ while ops.min(s)[0] < 1e-3 or ops.min(ops.abs((s ** 2).view(1, 3) -
+ (s ** 2).view(3, 1) + ops.eye(3)))[0] < 1e-2:
+ if self.debug:
+ self.log('S inside loop ', num_it, ' is ', s, ' and A = ', a)
+
+ a = a + ops.rand(3, 3) * ops.eye(3)
+ u, s, vt = np.linalg.svd(a.asnumpy())
+ u, s, vt = Tensor(u), Tensor(s), Tensor(vt)
+ num_it += 1
+
+ if num_it > 10:
+                    self.log('SVD consistently numerically unstable! Exiting ... ')
+ raise ValueError('SVD consistently numerically unstable!')
+
+ corr_mat = ops.diag(Tensor([1, 1, float(ops.sign(ops.det(a)))]))
+ t_align = (u @ corr_mat) @ vt
+
+ b_align = y_receptor_att_rot_mean - ops.t(t_align @ y_ligand_att_rot_mean.t()) # (1,3)
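+            # t_align is the Kabsch rotation (corr_mat flips the last singular direction when
+            # det(a) < 0, so det(t_align) = +1) and b_align is the translation mapping the
+            # ligand key points onto the receptor key points.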
+
+ #################### end AP 1 #########################
+
+ if self.debug:
+ self.log('Y_receptor_att_ROT_mean', y_receptor_att_rot_mean)
+ self.log('Y_ligand_att_ROT_mean', y_ligand_att_rot_mean)
+
+ all_t_align_list.append(t_align)
+ all_b_align_list.append(b_align)
+
+ return [all_t_align_list, all_b_align_list, all_y_ligand_att_rot_list, all_y_receptor_att_rot_list]
+
+ def construct(
+ self,
+ ligand_graph_node_tensor,
+ receptor_graph_node_tensor,
+ unbatch_list_tensor,
+ input_tensor_tuple,
+ ):
+ """
+ IEGMN construct
+ """
+ ligand_graph_edge_tensor = input_tensor_tuple[2]
+ receptor_graph_edge_tensor = input_tensor_tuple[3]
+
+ orig_coors_ligand = ligand_graph_node_tensor[:, 4:7]
+ orig_coors_receptor = receptor_graph_node_tensor[:, 1:4]
+
+ coors_ligand = ligand_graph_node_tensor[:, 4:7]
+ coors_receptor = receptor_graph_node_tensor[:, 1:4]
+
+ ## Embed residue types with a lookup table.
+ h_feats_ligand = self.residue_emb_layer(
+ ligand_graph_node_tensor[:, 0].view(-1).long()) # (N_res, emb_dim)
+ h_feats_receptor = self.residue_emb_layer(
+ receptor_graph_node_tensor[:, 0].view(-1).long()) # (N_res, emb_dim)
+
+ if self.use_mean_node_features:
+ h_feats_ligand = ops.cat([h_feats_ligand,
+ ops.log(ligand_graph_node_tensor[:, 7:12])], axis=1)
+ h_feats_receptor = ops.cat([h_feats_receptor,
+ ops.log(receptor_graph_node_tensor[:, 4:9])], axis=1)
+
+ original_ligand_node_features = h_feats_ligand
+ original_receptor_node_features = h_feats_receptor
+
+ original_edge_feats_ligand = ligand_graph_edge_tensor * self.use_edge_features_in_gmn
+ original_edge_feats_receptor = receptor_graph_edge_tensor * self.use_edge_features_in_gmn
+
+ for i, layer in enumerate(self.iegmn_layers):
+ if self.debug:
+ self.log('layer ', i)
+
+ iegmn_layer_tuple = (
+ coors_ligand, h_feats_ligand, original_ligand_node_features, original_edge_feats_ligand,
+ orig_coors_ligand, coors_receptor, h_feats_receptor, original_receptor_node_features,
+ original_edge_feats_receptor, orig_coors_receptor
+ )
+
+ coors_ligand, h_feats_ligand, coors_receptor, h_feats_receptor = \
+ layer(iegmn_layer_tuple=iegmn_layer_tuple, input_tensor_tuple=input_tensor_tuple)
+
+ if self.debug:
+ self.log(ops.max(h_feats_ligand)[0], 'h_feats_ligand before layers ')
+ self.log(ops.max(h_feats_ligand)[0], ops.norm(h_feats_ligand),
+ 'h_feats_ligand before layers but after mu_r_norm')
+ self.log(ops.max(h_feats_ligand)[0], 'h_feats_ligand after MPNN')
+ self.log(ops.max(coors_ligand)[0], 'coors_ligand before after MPNN')
+
+ list_hetero_graph = unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor,
+ h_feats_ligand, coors_receptor, coors_ligand)
+
+ ap_res = self.ap_compute(list_hetero_graph)
+
+ return ap_res
+
+
+class RigidBodyDockingNet(nn.Cell):
+ """
+ RigidBodyDockingNet
+ """
+ def __init__(self, args, log_input=None):
+
+ super(RigidBodyDockingNet, self).__init__()
+
+ self.debug = args.debug
+ self.log = log_input
+
+ self.iegmn_original = IEGMN(args, n_lays=args.iegmn_n_lays, fine_tune=False, log_input=log_input)
+ if args.fine_tune:
+ self.iegmn_fine_tune = IEGMN(args, n_lays=2, fine_tune=True, log_input=log_input)
+ self.list_iegmns = [('original', self.iegmn_original), ('finetune', self.iegmn_fine_tune)]
+ else:
+ self.list_iegmns = [('finetune', self.iegmn_original)] # just original
+
+ def __repr__(self):
+ return "RigidBodyDockingNet " + str(self.__dict__)
+
+ ####### FORWARD for RigidBodyDockingNet
+ def construct(
+ self,
+ ligand_graph_node_tensor,
+ receptor_graph_node_tensor,
+ unbatch_list_tensor,
+ input_tensor_tuple,
+ ):
+ """
+ construct
+ """
+ last_outputs = None
+ all_ligand_coors_deform_list = []
+
+ for stage, iegmn in self.list_iegmns:
+ outputs = iegmn(
+ ligand_graph_node_tensor,
+ receptor_graph_node_tensor,
+ unbatch_list_tensor,
+ input_tensor_tuple,
+ )
+
+ if len(outputs) != 4:
+ raise ValueError("Number of outputs not correct")
+
+ if stage == 'finetune':
+ last_outputs = outputs
+
+ if stage == 'original':
+ new_list_hetero_graph = []
+
+ list_hetero_graph = []
+ for i, _ in enumerate(unbatch_list_tensor):
+ if i < len(unbatch_list_tensor) - 1:
+ list_hetero_graph.append(
+ [
+ ligand_graph_node_tensor[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :],
+ receptor_graph_node_tensor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :],
+ ]
+ )
+
+ for the_idx, hetero_graph in enumerate(list_hetero_graph):
+ orig_coors_ligand = hetero_graph[0][:, 4:7]
+
+ t_align = outputs[0][the_idx]
+ b_align = outputs[1][the_idx]
+ if b_align.shape[0] != 1 or b_align.shape[1] != 3:
+ raise ValueError("Shape not correct")
+
+ inner_coors_ligand = (t_align @ orig_coors_ligand.t()).t() + b_align # (n,3)
+
+ if stage == 'original':
+ hetero_graph[0][:, 4:7] = inner_coors_ligand
+ new_list_hetero_graph.append(hetero_graph)
+
+ if self.debug:
+ self.log('T_align', t_align)
+ self.log('T_align @ T_align.t() - eye(3)', t_align @ t_align.t() - ops.eye(3))
+ self.log('b_align', b_align)
+ self.log('\n ---> inner_coors_ligand mean - true ligand mean ',
+ inner_coors_ligand.mean(axis=0)[0] - ligand_graph_node_tensor[:, 1:4].mean(axis=0)[0],
+ '\n')
+
+ if stage == 'finetune':
+ all_ligand_coors_deform_list.append(inner_coors_ligand)
+
+ all_keypts_ligand_list = last_outputs[2]
+ all_keypts_receptor_list = last_outputs[3]
+ all_rotation_list = last_outputs[0]
+ all_translation_list = last_outputs[1]
+
+ return all_ligand_coors_deform_list, \
+ all_keypts_ligand_list, all_keypts_receptor_list, \
+ all_rotation_list, all_translation_list
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0d422497ddf209e0b1e64251f562034f0e14361
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py
@@ -0,0 +1,236 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""train_utils"""
+
+import os
+import math
+
+import ot
+import numpy as np
+import mindspore as ms
+from mindspore import ops, Tensor
+
+from .nn_arch import (
+ preprocess_unbound_bound,
+ protein_to_graph_unbound_bound,
+ get_residues,
+ log,
+)
+
+
+def create_dir(path):
+ if os.path.exists(path):
+ raise FileExistsError('Path already exists. Please delete and restart your job.')
+ os.makedirs(path, exist_ok=False)
+
+
+def graph_to_tensor(ligand_graph, receptor_graph):
+ """
+ graph_to_tensor
+ """
+ ligand_graph.ndata['new_x'] = ligand_graph.ndata['x']
+
+ # Create a batch of a single heterograph
+ ligand_graph_node_tensor = ops.cat(
+ (ligand_graph.ndata["res_feat"],
+ ligand_graph.ndata["x"],
+ ligand_graph.ndata["new_x"],
+ ligand_graph.ndata["mu_r_norm"]),
+ axis=1
+ )
+
+ receptor_graph_node_tensor = ops.cat(
+ (receptor_graph.ndata["res_feat"],
+ receptor_graph.ndata["x"],
+ receptor_graph.ndata["mu_r_norm"]),
+ axis=1
+ )
+
+ ligand_graph_num_nodes = Tensor([ligand_graph.num_nodes])
+ receptor_graph_num_nodes = Tensor([receptor_graph.num_nodes])
+ ligand_graph_edge_tensor = ligand_graph.edata["he"]
+ receptor_graph_edge_tensor = receptor_graph.edata["he"]
+
+ ll_connection_tensor = ops.stack(
+ (Tensor(ligand_graph.src_list), Tensor(ligand_graph.dst_list)))
+
+ rr_connection_tensor = ops.stack(
+ (Tensor(receptor_graph.src_list), Tensor(receptor_graph.dst_list)))
+
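+    # Each unbatch_list row stores per-complex sizes [n_ligand_edges, n_receptor_edges,
+    # n_ligand_nodes, n_receptor_nodes, 0]; an all-zero row is prepended below so the node-count
+    # columns (indices 2 and 3) can be used directly as slice boundaries when unbatching.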
+ unbatch_list = [
+ [
+ len(ll_connection_tensor[0]),
+ len(rr_connection_tensor[0]),
+ len(ligand_graph_node_tensor),
+ len(receptor_graph_node_tensor),
+ 0,
+ ]
+ ]
+ unbatch_list.insert(0, np.array([0, 0, 0, 0, 0]))
+
+ unbatch_list = Tensor(unbatch_list, ms.int32)
+
+ input_tensor_tuple = (
+ ligand_graph_num_nodes,
+ receptor_graph_num_nodes,
+ ligand_graph_edge_tensor,
+ receptor_graph_edge_tensor,
+ ll_connection_tensor,
+ rr_connection_tensor,
+ )
+
+ return ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, input_tensor_tuple
+
+
+def prepare_graphs(args, ppdb_ligand, ligand_filename, receptor_filename):
+ """
+ prepare_graphs
+ """
+ unbound_ligand_all_atoms_pre_pos = ppdb_ligand.df["ATOM"][
+ ['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)
+
+ ligand_input = get_residues(ligand_filename)
+ receptor_input = get_residues(receptor_filename)
+ input_tuple = (ligand_input, receptor_input)
+ unbound_predic_ligand, \
+ unbound_predic_receptor, \
+ bound_ligand_repres_nodes_loc_clean_array, \
+ bound_receptor_repres_nodes_loc_clean_array, _ = preprocess_unbound_bound(
+ input_residues_tuple=input_tuple,
+ graph_nodes=args.graph_nodes, pos_cutoff=args.pocket_cutoff, inference=True)
+
+ protein_to_graph_unbound_bound_input_tuple = (
+ unbound_predic_ligand, unbound_predic_receptor,
+ bound_ligand_repres_nodes_loc_clean_array, bound_receptor_repres_nodes_loc_clean_array,
+ )
+
+ ligand_graph, receptor_graph = protein_to_graph_unbound_bound(
+ protein_to_graph_unbound_bound_input_tuple,
+ cutoff=args.graph_cutoff,
+ max_neighbor=args.graph_max_neighbor,
+ one_hot=False,
+ residue_loc_is_alphac=args.graph_residue_loc_is_alphaC,
+ )
+
+ return ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, bound_ligand_repres_nodes_loc_clean_array
+
+
+def get_rot_mat(euler_angles):
+ """
+ get_rot_mat
+ """
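+    # Composes R = Rz(yaw) @ Ry(pitch) @ Rx(roll) from a length-3 tensor of Euler angles;
+    # e.g. get_rot_mat(Tensor([0., 0., 0.])) returns the 3x3 identity (illustrative example).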
+ roll = euler_angles[0]
+ yaw = euler_angles[1]
+ pitch = euler_angles[2]
+
+ tensor_0 = Tensor(0.0)
+ tensor_1 = Tensor(1.0)
+ cos = ops.cos
+ sin = ops.sin
+
+ rx = ops.stack([
+ ops.stack([tensor_1, tensor_0, tensor_0]),
+ ops.stack([tensor_0, cos(roll), -sin(roll)]),
+ ops.stack([tensor_0, sin(roll), cos(roll)])]).reshape(3, 3)
+
+ ry = ops.stack([
+ ops.stack([cos(pitch), tensor_0, sin(pitch)]),
+ ops.stack([tensor_0, tensor_1, tensor_0]),
+ ops.stack([-sin(pitch), tensor_0, cos(pitch)])]).reshape(3, 3)
+
+ rz = ops.stack([
+ ops.stack([cos(yaw), -sin(yaw), tensor_0]),
+ ops.stack([sin(yaw), cos(yaw), tensor_0]),
+ ops.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3)
+
+ r = ops.mm(rz, ry)
+ r = ops.mm(r, rx)
+
+ return r
+
+
+def compute_ot_emd(cost_mat):
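+    """
+    Exact earth mover's distance between the two point sets under uniform marginals, solved
+    with POT's ot.emd; the transport plan is treated as a constant, so gradients only flow
+    through cost_mat.
+    """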
+ cost_mat_detach = cost_mat.asnumpy()
+ a = np.ones([cost_mat.shape[0]]) / cost_mat.shape[0]
+ b = np.ones([cost_mat.shape[1]]) / cost_mat.shape[1]
+ ot_mat = ot.emd(a=a, b=b, M=cost_mat_detach, numItermax=10000)
+ ot_mat_attached = Tensor(ot_mat, ms.float32)
+ ot_dist = ops.sum(ot_mat_attached * cost_mat)
+ return ot_dist, ot_mat_attached
+
+
+def g_fn(protein_coords, x, sigma):
+ # protein_coords: (n,3) , x: (m,3), output: (m,)
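+    # Soft-min of squared distances: as sigma shrinks, this tends to the squared distance from
+    # each point in x to its nearest protein point; used as a smooth surface field by
+    # compute_body_intersection_loss below.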
+ e = ops.exp(- ops.sum((protein_coords.view(1, -1, 3) - x.view(-1, 1, 3)) ** 2, dim=2) / float(sigma)) # (m, n)
+
+ return - sigma * ops.log(1e-3 + ops.sum(e, dim=1))
+
+
+def compute_body_intersection_loss(
+ model_ligand_coors_deform,
+ bound_receptor_repres_nodes_loc_array,
+ sigma,
+ surface_ct,
+ ):
+ """
+ compute_body_intersection_loss
+ """
+ g_fn_out1 = g_fn(Tensor(bound_receptor_repres_nodes_loc_array), model_ligand_coors_deform, sigma)
+ g_fn_out2 = g_fn(model_ligand_coors_deform, Tensor(bound_receptor_repres_nodes_loc_array), sigma)
+ loss1 = ops.clamp(surface_ct - g_fn_out1, min=0)
+ loss2 = ops.clamp(surface_ct - g_fn_out2, min=0)
+ loss = ops.mean(loss1) + ops.mean(loss2)
+
+ return loss
+
+
+def compute_sq_dist_mat(x_1, x_2):
+    """Computes the squared l2 cost matrix between two point cloud inputs.
+    Args:
+        x_1: [n, #features] point cloud, tensor
+        x_2: [m, #features] point cloud, tensor
+    Output:
+        [n, m] matrix of squared l2 distances between point pairs
+    """
+ n_1, _ = x_1.shape
+ n_2, _ = x_2.shape
+ x_1 = x_1.view(n_1, 1, -1)
+ x_2 = x_2.view(1, n_2, -1)
+ squared_dist = (x_1 - x_2) ** 2
+ cost_mat = ops.sum(squared_dist, dim=2)
+ return cost_mat
+
+
+def pretty_print_stats(*args):
+
+ split_type, epoch, total_num_epochs,\
+ complex_rmsd_mean, complex_rmsd_median,\
+ ligand_rmsd_mean, ligand_rmsd_median,\
+ receptor_rmsd_mean, receptor_rmsd_median,\
+ _, _, _, avg_loss_ot, avg_loss_intersection, _ = args
+
+ log('[{:s}] --> epoch {:d}/{:d} '
+ '|| mean/median complex rmsd {:.4f} / {:.4f} '
+ '|| mean/median ligand rmsd {:.4f} / {:.4f} '
+ '|| mean/median sqrt pocket OT loss {:.4f} '
+ '|| intersection loss {:.4f} '
+ '|| mean/median receptor rmsd {:.4f} / {:.4f} '.
+ format(split_type,
+ epoch, total_num_epochs,
+ complex_rmsd_mean, complex_rmsd_median,
+ ligand_rmsd_mean, ligand_rmsd_median,
+ math.sqrt(avg_loss_ot),
+ avg_loss_intersection,
+ receptor_rmsd_mean, receptor_rmsd_median))
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc1d55347858e35580e54fe2626596072aaca6c8
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py
@@ -0,0 +1,19 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""progen"""
+
+from .progen import ProGen
+from .progen_dataset import ProGenDataSet
+from .progen_configuration import progen_configuration
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0772554a6104e644fe8b593b9da474b04c67fc88
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""module"""
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd34b9e754de074c10ab83285db6861f4d3884e8
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py
@@ -0,0 +1,2295 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""configuration_utils"""
+
+import copy
+from copy import deepcopy
+import json
+import os
+import warnings
+import inspect
+from typing import Optional, List, Callable, Dict, Any, Tuple, Union
+from enum import Enum
+from collections import OrderedDict, UserDict
+from dataclasses import fields
+from dataclasses import dataclass
+
+import numpy as np
+import mindspore
+from mindspore import nn, ops, Tensor, jit_class
+from mindspore.common.api import _pynative_executor  # used by the no_grad wrapper below
+
+from .logits_process import (
+ ForcedEOSTokenLogitsProcessor,
+ LogitsProcessorList,
+ MinLengthLogitsProcessor,
+ MinNewTokensLengthLogitsProcessor,
+ NoBadWordsLogitsProcessor,
+ TemperatureLogitsWarper,
+ TopKLogitsWarper,
+ TopPLogitsWarper,
+ TypicalLogitsWarper,
+)
+
+DEFAULT_DTYPE = mindspore.float32
+INIT_WEIGHTS_FLAG = True
+
+
+def _is_mindspore(x):
+
+ return isinstance(x, mindspore.Tensor)
+
+def is_mindspore_available():
+ return mindspore.get_context('device_target') == 'Ascend'
+
+def is_mindspore_tensor(x):
+ """
+    Tests if `x` is a MindSpore tensor or not.
+ """
+ return False if not is_mindspore_available() else _is_mindspore(x)
+
+def no_grad(func):
+    """no grad wrapper: disables gradient recording around `func` in PyNative mode"""
+    def wrapper(*args, **kwargs):
+        _pynative_executor.set_enable_grad(False)
+        try:
+            outputs = func(*args, **kwargs)
+        finally:
+            # restore gradient recording even if func raises
+            _pynative_executor.set_enable_grad(True)
+        return outputs
+    return wrapper
+
+
+class CellUtilMixin:
+ """
+ A few utilities to be used as a mixin.
+ """
+
+ def get_head_mask(
+ self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
+ ) -> Tensor:
+ """
+ Prepare the head mask if needed.
+ """
+ if head_mask is not None:
+ head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
+ if is_attention_chunked is True:
+ head_mask = head_mask.expand_dims(-1)
+ else:
+ head_mask = ()
+ for _ in range(num_hidden_layers):
+ head_mask += (None,)
+
+ return head_mask
+
+ def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
+ """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]"""
+ if head_mask.ndim == 1:
+ head_mask = head_mask.expand_dims(0).expand_dims(0).expand_dims(-1).expand_dims(-1)
+            head_mask = head_mask.broadcast_to((num_hidden_layers, -1, -1, -1, -1))
+ elif head_mask.ndim == 2:
+ head_mask = head_mask.expand_dims(1).expand_dims(-1)\
+ .expand_dims(-1) # We can specify head_mask for each layer
+ assert head_mask.ndim == 5, f"head_mask.dim != 5, instead {head_mask.ndim}"
+ head_mask = head_mask.astype(dtype=self.dtype) # switch to float if need + fp16 compatibility
+ return head_mask
+
+ @property
+ def dtype(self) -> mindspore.TensorType:
+ """
+ `mindspore.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
+ """
+ return get_parameter_dtype(self)
+
+
+class ModelOutputMindnlp(OrderedDict):
+ """
+ Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a
+ tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular
+ python dictionary.
+
+
+    You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple
+    before.
+ """
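+    # Note: subclasses below are @dataclass-decorated; items can be read by key or by position
+    # over the non-None fields, e.g. out["sequences"] and out[0] refer to the same tensor when
+    # `sequences` is the first populated field.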
+
+ def __post_init__(self):
+ class_fields = fields(self)
+
+ # Safety and consistency checks
+ if not class_fields:
+ raise ValueError(f"{self.__class__.__name__} has no fields.")
+ if not all(field.default is None for field in class_fields[1:]):
+ raise ValueError(f"{self.__class__.__name__} should not have more than one required field.")
+
+ first_field = getattr(self, class_fields[0].name)
+ other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:])
+
+ if other_fields_are_none and not is_tensor(first_field):
+ if isinstance(first_field, dict):
+ iterator = first_field.items()
+ first_field_iterator = True
+ else:
+ try:
+ iterator = iter(first_field)
+ first_field_iterator = True
+ except TypeError:
+ first_field_iterator = False
+
+ # if we provided an iterator as first field and the iterator is a (key, value) iterator
+ # set the associated fields
+ if first_field_iterator:
+ for idx, element in enumerate(iterator):
+ if (
+ not isinstance(element, (list, tuple))
+ or not len(element) == 2
+ or not isinstance(element[0], str)
+ ):
+ if idx == 0:
+ # If we do not have an iterator of key/values, set it as attribute
+ self[class_fields[0].name] = first_field
+ else:
+ # If we have a mixed iterator, raise an error
+ raise ValueError(
+ f"Cannot set key/value for {element}. It needs to be a tuple (key, value)."
+ )
+ break
+ setattr(self, element[0], element[1])
+ if element[1] is not None:
+ self[element[0]] = element[1]
+ elif first_field is not None:
+ self[class_fields[0].name] = first_field
+ else:
+ for field in class_fields:
+ v = getattr(self, field.name)
+ if v is not None:
+ self[field.name] = v
+
+ def __delitem__(self, *args, **kwargs):
+ raise RuntimeError(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")
+
+ def setdefault(self, *args, **kwargs):
+ raise RuntimeError(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")
+
+ def pop(self, *args, **kwargs):
+ raise RuntimeError(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")
+
+ def update(self, *args, **kwargs):
+ raise RuntimeError(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")
+
+ def __getitem__(self, k):
+ if isinstance(k, str):
+ inner_dict = dict(self.items())
+ return inner_dict[k]
+ return self.to_tuple()[k]
+
+ def __setattr__(self, name, value):
+ if name in self.keys() and value is not None:
+ # Don't call self.__setitem__ to avoid recursion errors
+ super().__setitem__(name, value)
+ super().__setattr__(name, value)
+
+ def __setitem__(self, key, value):
+ # Will raise a KeyException if needed
+ super().__setitem__(key, value)
+ # Don't call self.__setattr__ to avoid recursion errors
+ super().__setattr__(key, value)
+
+ def to_tuple(self) -> Tuple[Any]:
+ """
+ Convert self to a tuple containing all the attributes/keys that are not `None`.
+ """
+ return tuple(v for _, v in self.items())
+
+
+@dataclass
+class GenerateDecoderOnlyOutput(ModelOutputMindnlp):
+ """
+ Outputs of decoder-only generation models, when using non-beam methods.
+ """
+
+ sequences: mindspore.Tensor = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ logits: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ past_key_values: Optional[Tuple[Tuple[Tuple[mindspore.Tensor]]]] = None
+
+
+@dataclass
+class GenerateEncoderDecoderOutput(ModelOutputMindnlp):
+ """
+ Outputs of encoder-decoder generation models, when using non-beam methods.
+ """
+
+ sequences: mindspore.Tensor = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ logits: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ past_key_values: Optional[Tuple[Tuple[Tuple[mindspore.Tensor]]]] = None
+
+
+@dataclass
+class SampleDecoderOnlyOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of decoder-only generation models using sampling.
+ """
+
+ sequences: mindspore.Tensor = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+
+@dataclass
+class SampleEncoderDecoderOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of
+ the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
+ attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
+ """
+
+ sequences: mindspore.Tensor = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+@dataclass
+class BeamSampleDecoderOnlyOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of decoder-only generation models using beam sample.
+ """
+
+ sequences: mindspore.Tensor = None
+ sequences_scores: Optional[mindspore.Tensor] = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ beam_indices: Optional[mindspore.Tensor] = None
+ attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+
+@dataclass
+class BeamSampleEncoderDecoderOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of encoder-decoder generation models using beam sampling. Hidden states and attention
+ weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
+ encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
+ """
+
+ sequences: mindspore.Tensor = None
+ sequences_scores: Optional[mindspore.Tensor] = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ beam_indices: Optional[mindspore.Tensor] = None
+ encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+
+@dataclass
+class BeamSearchDecoderOnlyOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of decoder-only generation models using beam search.
+ """
+
+ sequences: mindspore.Tensor = None
+ sequences_scores: Optional[mindspore.Tensor] = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ beam_indices: Optional[mindspore.Tensor] = None
+ attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+
+@dataclass
+class BeamSearchEncoderDecoderOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights
+ of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states
+ attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
+ """
+
+ sequences: mindspore.Tensor = None
+ sequences_scores: Optional[mindspore.Tensor] = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ beam_indices: Optional[mindspore.Tensor] = None
+ encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+@dataclass
+class GreedySearchDecoderOnlyOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of decoder-only generation models using greedy search.
+ """
+
+ sequences: mindspore.Tensor = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+@dataclass
+class GreedySearchEncoderDecoderOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of encoder-decoder generation models using greedy search. Hidden states and attention
+ weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the
+ encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes)
+ """
+
+ sequences: mindspore.Tensor = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+@dataclass
+class ContrastiveSearchEncoderDecoderOutput(ModelOutputMindnlp):
+ """
+    Base class for outputs of encoder-decoder generation models using contrastive search.
+ """
+
+ sequences: mindspore.Tensor = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None
+ encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+
+@dataclass
+class ContrastiveSearchDecoderOnlyOutput(ModelOutputMindnlp):
+ """
+ Base class for outputs of decoder-only generation models using contrastive search.
+ """
+
+ sequences: mindspore.Tensor = None
+ scores: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+
+
+GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput]
+SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput]
+BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput]
+BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput]
+ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput]
+GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput]
+
+
+class StoppingCriteriaList(list):
+ def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
+ return any(criteria(input_ids, scores) for criteria in self)
+
+ @property
+ def max_length(self) -> Optional[int]:
+        for stopping_criterium in self:
+            if isinstance(stopping_criterium, MaxLengthCriteria):
+                return stopping_criterium.max_length
+        return None
+
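+# Illustrative sketch for `StoppingCriteriaList` above (not part of the API): the list
+# is satisfied as soon as any criterion fires, so combining criteria behaves like a
+# logical OR.
+#
+#     criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=64)])
+#     done = criteria(input_ids, scores)   # True once input_ids.shape[-1] >= 64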
+
+class GenerationConfig:
+ """
+ Class that holds a configuration for a generation task.
+ """
+ def __init__(self, **kwargs):
+ # Parameters that control the length of the output
+ self.max_length = kwargs.pop("max_length", 20)
+ self.max_new_tokens = kwargs.pop("max_new_tokens", None)
+ self.min_length = kwargs.pop("min_length", 0)
+ self.min_new_tokens = kwargs.pop("min_new_tokens", None)
+ self.early_stopping = kwargs.pop("early_stopping", False)
+ self.use_cache = kwargs.pop("use_cache", True)
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
+
+        # Parameters that define the output variables of `generate`
+ self.output_attentions = kwargs.pop("output_attentions", False)
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
+ self.output_scores = kwargs.pop("output_scores", False)
+ self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
+
+ # Special tokens that can be used at generation time
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
+ self._from_model_config = kwargs.pop("_from_model_config", False)
+
+ # Additional attributes without default values
+ if not self._from_model_config:
+ # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
+ # model's default configuration file
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error("Can't set %s with value %s", key, value)
+ raise err
+
+ # Validate the values of the attributes
+ self.validate()
+
+ @classmethod
+ def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig":
+ """
+ Instantiates a [`GenerationConfig`] from a Python dictionary of parameters.
+ """
+ # Those arguments may be passed along for our internal telemetry.
+ # We remove them so they don't appear in `return_unused_kwargs`.
+ kwargs.pop("_from_auto", None)
+ kwargs.pop("_from_pipeline", None)
+ # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update.
+ commit_hash_str = "_commit_hash"
+ if commit_hash_str in kwargs and commit_hash_str in config_dict:
+ kwargs[commit_hash_str] = config_dict[commit_hash_str]
+
+ config = cls(**config_dict)
+ unused_kwargs = config.update(**kwargs)
+
+ return config
+
+ @classmethod
+ def from_model_config(cls, model_config) -> "GenerationConfig":
+ """
+ Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy
+ [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`].
+ """
+ config_dict = model_config.to_dict()
+ config = cls.from_dict(config_dict, return_unused_kwargs=False)
+ config.set_from_model_config(True)
+
+ return config
+
+
+ def set_from_model_config(self, value: bool):
+ """set _from_model_config"""
+ if not isinstance(value, bool):
+ raise TypeError
+ self._from_model_config = value
+
+ def update(self, **kwargs):
+ """
+        Updates attributes of this class instance with attributes from `kwargs` if they match existing attributes,
+ returning all the unused kwargs.
+ """
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(self, key):
+ setattr(self, key, value)
+ to_remove.append(key)
+
+ # remove all the attributes that were updated, without modifying the input dict
+ unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove}
+ return unused_kwargs
+
+ def validate(self):
+ """
+ Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence
+ of parameterization that can be detected as incorrect from the configuration instance alone.
+
+ Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the
+ model, such as parameters related to the generation length.
+ """
+
+ # Validation of individual attributes
+ if self.early_stopping not in {True, False, "never"}:
+ raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.")
+
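+# A minimal usage sketch for `GenerationConfig` above (illustrative values only): keys
+# unknown to the instance that are passed to `update` are returned untouched instead of
+# being set on the config.
+#
+#     gen_config = GenerationConfig(max_new_tokens=32, top_k=50, top_p=0.9)
+#     unused = gen_config.update(top_k=20, not_a_real_option=1)
+#     # gen_config.top_k == 20 and unused == {"not_a_real_option": 1}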
+
+class GenerationMixin:
+ """
+    A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`].
+ """
+
+ @staticmethod
+ def _expand_inputs_for_generation(
+ expand_size: int = 1,
+ input_ids: Optional[mindspore.Tensor] = None,
+ **model_kwargs,
+ ) -> Tuple[mindspore.Tensor, Dict[str, Any]]:
+ """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
+
+ def _expand_dict_for_generation(dict_to_expand):
+ for key in dict_to_expand:
+ if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], mindspore.Tensor):
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+ return dict_to_expand
+
+ if input_ids is not None:
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+ return input_ids, model_kwargs
+
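+    # Illustrative sketch for `_expand_inputs_for_generation` (hypothetical shapes): for
+    # `input_ids` of shape (2, 5) and expand_size=3, `repeat_interleave(3, dim=0)`
+    # repeats each row consecutively, so
+    #
+    #     input_ids, model_kwargs = self._expand_inputs_for_generation(
+    #         expand_size=3, input_ids=input_ids, attention_mask=attention_mask)
+    #
+    # returns `input_ids` of shape (6, 5); tensor-valued kwargs such as the attention
+    # mask are expanded the same way.
+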
+ @staticmethod
+ def prepare_inputs_for_generation(*args, **kwargs):
+ """
+ prepare_inputs_for_generation
+ """
+ raise NotImplementedError(
+ "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`."
+ )
+
+ @classmethod
+ def _maybe_initialize_input_ids_for_generation(
+ cls,
+ inputs: Optional[mindspore.Tensor] = None,
+ bos_token_id: Optional[int] = None,
+ model_kwargs: Optional[Dict[str, mindspore.Tensor]] = None,
+ ) -> mindspore.Tensor:
+ """Initializes input ids for generation, if necessary."""
+ if inputs is not None:
+ return inputs
+
+ if bos_token_id is None:
+ raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.")
+
+ # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with
+ # soft-prompting or in multimodal implementations built on top of decoder-only language models.
+ batch_size = 1
+ for value in model_kwargs.values():
+ if isinstance(value, mindspore.Tensor):
+ batch_size = value.shape[0]
+ break
+ return ops.ones((batch_size, 1), dtype=mindspore.int64) * bos_token_id
+
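+    # Illustrative sketch for `_maybe_initialize_input_ids_for_generation` (hypothetical
+    # values): with no `inputs`, `bos_token_id=1` and a tensor-valued kwarg of batch
+    # size 4, it returns `ops.ones((4, 1), dtype=mindspore.int64) * 1`, i.e. a (4, 1)
+    # tensor filled with the BOS id.
+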
+ @classmethod
+ def _prepare_decoder_input_ids_for_generation(
+ cls,
+ model_input_name: str,
+ model_kwargs: Dict[str, mindspore.Tensor],
+ ) -> Tuple[mindspore.Tensor, Dict[str, mindspore.Tensor]]:
+ """Prepares `decoder_input_ids` for generation with encoder-decoder models"""
+ # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming,
+ # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input.
+        input_ids_str = "input_ids"
+ if model_kwargs is not None and "decoder_input_ids" in model_kwargs:
+ decoder_input_ids = model_kwargs.pop("decoder_input_ids")
+ elif input_ids_str in model_kwargs and model_input_name != input_ids_str:
+ decoder_input_ids = model_kwargs.pop(input_ids_str)
+ else:
+ decoder_input_ids = None
+
+ return decoder_input_ids, model_kwargs
+
+ @classmethod
+ def _merge_criteria_processor_list(
+ cls,
+ default_list: Union[LogitsProcessorList, StoppingCriteriaList],
+ custom_list: Union[LogitsProcessorList, StoppingCriteriaList],
+ ) -> Union[LogitsProcessorList, StoppingCriteriaList]:
+        """
+ merge the criteria processor list
+ """
+ if not custom_list:
+ return default_list
+ for default in default_list:
+ for custom in custom_list:
+ if type(custom) is type(default):
+ object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor"
+ raise ValueError(
+ f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to"
+ f" `.generate()`, but it has already been created with the values {default}. {default} has been"
+ " created by passing the corresponding arguments to generate or by the model's config default"
+ f" values. If you just want to change the default values of {object_type} consider passing"
+ f" them as arguments to `.generate()` instead of using a custom {object_type}."
+ )
+ default_list.extend(custom_list)
+ return default_list
+
+ def _get_logits_warper(
+ self,
+ generation_config: GenerationConfig,
+ ) -> LogitsProcessorList:
+ """
+ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances
+ used for multinomial sampling.
+ """
+
+ # instantiate warpers list
+ warpers = LogitsProcessorList()
+
+ # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a
+ # better score (i.e. keep len(list(generation_config.eos_token_id)) + 1)
+ if generation_config.num_beams > 1:
+ if isinstance(generation_config.eos_token_id, list):
+ min_tokens_to_keep = len(generation_config.eos_token_id) + 1
+ else:
+ min_tokens_to_keep = 2
+ else:
+ min_tokens_to_keep = 1
+
+ # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files
+ # all samplers can be found in `generation_utils_samplers.py`
+ if generation_config.temperature is not None and generation_config.temperature != 1.0:
+ warpers.append(TemperatureLogitsWarper(generation_config.temperature))
+ if generation_config.top_k is not None and generation_config.top_k != 0:
+ warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.top_p is not None and generation_config.top_p < 1.0:
+ warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep))
+ if generation_config.typical_p is not None and generation_config.typical_p < 1.0:
+ warpers.append(
+ TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep)
+ )
+ return warpers
+
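+    # Illustrative sketch for `_get_logits_warper` (hypothetical config values): with
+    # temperature=0.7, top_k=50 and top_p=0.95, the returned list applies, in order,
+    #
+    #     TemperatureLogitsWarper(0.7) -> TopKLogitsWarper(50) -> TopPLogitsWarper(0.95)
+    #
+    # while typical_p=1.0 (the default) adds no TypicalLogitsWarper.
+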
+ def generate(
+ self,
+ inputs: Optional[mindspore.Tensor] = None,
+ generation_config: Optional[GenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **kwargs,
+ ) -> Union[GreedySearchDecoderOnlyOutput, mindspore.Tensor]:
+ """
+        Generates sequences of token ids for a model with a language modeling head. This implementation resolves the
+        generation config and model inputs, builds the logits processors/warpers and stopping criteria, then runs
+        multinomial sampling via `sample`.
+ """
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+ self._validate_model_class()
+        if generation_config is None:
+            generation_config = self.generation_config
+        generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
+ generation_config.validate()
+ self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+
+ # 3. Define model inputs
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs)
+
+ # 4. Define other model kwargs
+ model_kwargs["output_attentions"] = generation_config.output_attentions
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
+ model_kwargs["use_cache"] = True
+ else:
+ model_kwargs["use_cache"] = generation_config.use_cache
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ if not self.config.is_encoder_decoder:
+ input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids")
+ if streamer is not None:
+ streamer.put(input_ids)
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_length = input_ids.shape[-1]
+ has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
+ self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)
+
+ # 7. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_length,
+ logits_processor=logits_processor)
+
+ # 8. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria)
+
+        # 9. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+        # 10. expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ **model_kwargs,
+ )
+
+        # 11. run sample
+ return self.sample(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ streamer=streamer,
+ **model_kwargs,
+ )
+
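+    # A minimal usage sketch for `generate` (illustrative only; `model` is assumed to be
+    # a generation-capable subclass and `input_ids` a (batch, seq_len) int64 tensor of
+    # prompt token ids prepared elsewhere):
+    #
+    #     output_ids = model.generate(input_ids, max_new_tokens=64, top_p=0.9, temperature=0.8)
+    #
+    # With return_dict_in_generate=False (the default) the result is the tensor of
+    # generated token ids, prompt included.
+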
+ def sample(
+ self,
+ input_ids: mindspore.Tensor,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ logits_warper: Optional[LogitsProcessorList] = None,
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_scores: Optional[bool] = None,
+ return_dict_in_generate: Optional[bool] = None,
+ streamer: Optional["BaseStreamer"] = None,
+ **model_kwargs,
+ ) -> Union[SampleOutput, mindspore.Tensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+ """
+        # init values
+        synced_gpus = None
+ logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
+ stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList()
+ if max_length is not None:
+ warnings.warn(
+ "`max_length` is deprecated in this function, use"
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+ UserWarning,
+ )
+ stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length)
+ logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList()
+ pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id
+ eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ eos_token_id_tensor = mindspore.tensor(eos_token_id) if eos_token_id is not None else None
+ output_scores = output_scores if output_scores is not None else self.generation_config.output_scores
+ output_attentions = (
+ output_attentions if output_attentions is not None else self.generation_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states
+ )
+ return_dict_in_generate = (
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
+ )
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+ cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+ decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+ # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+ if return_dict_in_generate and self.config.is_encoder_decoder:
+ encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+ encoder_hidden_states = (
+ model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+ )
+
+ # keep track of which sequences are already finished
+ unfinished_sequences = ops.ones(input_ids.shape[0], dtype=mindspore.int64)
+
+ this_peer_finished = False # used by synced_gpus only
+ # auto-regressive generation
+ while True:
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation_new(input_ids, **model_kwargs)
+ # forward pass to get next token
+ outputs = self.construct(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't need
+
+ next_token_logits = outputs.logits[:, -1, :]
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,)
+ if self.config.is_encoder_decoder
+ else (outputs.hidden_states,)
+ )
+
+ # sample
+ probs = ops.softmax(next_token_scores, axis=-1)
+ next_tokens = ops.multinomial(probs, num_samples=1).squeeze(1).astype(mindspore.int64)
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+ # update generated ids, model inputs, and length for next step
+ input_ids = ops.cat([input_ids, next_tokens[:, None]], axis=-1)
+ if streamer is not None:
+ streamer.put(next_tokens)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id_tensor is not None:
+ unfinished_sequences = unfinished_sequences.mul(
+ next_tokens.tile((eos_token_id_tensor.shape[0], 1)).ne(
+ eos_token_id_tensor.unsqueeze(1)).prod(axis=0)
+ )
+
+ # stop when each sentence is finished
+ if unfinished_sequences.max() == 0:
+ this_peer_finished = True
+
+ # stop if we exceed the maximum length
+ if stopping_criteria(input_ids, scores):
+ this_peer_finished = True
+
+ if this_peer_finished and not synced_gpus:
+ break
+ if streamer is not None:
+ streamer.end()
+ if return_dict_in_generate:
+ if self.config.is_encoder_decoder:
+ return SampleEncoderDecoderOutput(
+ sequences=input_ids,
+ scores=scores,
+ encoder_attentions=encoder_attentions,
+ encoder_hidden_states=encoder_hidden_states,
+ decoder_attentions=decoder_attentions,
+ cross_attentions=cross_attentions,
+ decoder_hidden_states=decoder_hidden_states,
+ )
+ return SampleDecoderOnlyOutput(
+ sequences=input_ids,
+ scores=scores,
+ attentions=decoder_attentions,
+ hidden_states=decoder_hidden_states,
+ )
+ return input_ids
+
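+    # Worked example for the padding step in `sample` above (illustrative values): with
+    # unfinished_sequences = [1, 0] and pad_token_id = 0, a sampled next_tokens = [7, 3]
+    # becomes 7 * 1 + 0 * 0 = 7 for the active sequence and 3 * 0 + 0 * 1 = 0 for the
+    # finished one, so finished rows keep emitting the padding token.
+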
+ def _prepare_model_inputs(
+ self,
+ inputs: Optional[mindspore.Tensor] = None,
+ bos_token_id: Optional[int] = None,
+ model_kwargs: Optional[Dict[str, mindspore.Tensor]] = None,
+ ) -> Tuple[mindspore.Tensor, Optional[str], Dict[str, mindspore.Tensor]]:
+ """
+ This function extracts the model-specific `inputs` for generation.
+ """
+ # 1. retrieve all kwargs that are non-None or non-model input related.
+ # some encoder-decoder models have different names for model and encoder
+ if not self.config.is_encoder_decoder:
+ input_name = self.main_input_name
+
+ model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name}
+
+ # 2. check whether model_input_name is passed as kwarg
+ # if yes and `inputs` is None use kwarg inputs
+ inputs_kwarg = model_kwargs.pop(input_name, None)
+ if inputs_kwarg is not None and inputs is not None:
+ raise ValueError(
+                f"`inputs`: {inputs} were passed alongside {input_name} which is not allowed. "
+ f"Make sure to either pass {inputs} or {input_name}=..."
+ )
+ if inputs_kwarg is not None:
+ inputs = inputs_kwarg
+
+ # 3. In the presence of `inputs_embeds` for text models:
+ # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model
+ # doesn't have its forwarding implemented. `inputs_embeds` is kept in `model_kwargs` and can coexist with
+ # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`)
+ inputs_embeds_str = "inputs_embeds"
+ input_ids_str = "input_ids"
+ if input_name == input_ids_str and inputs_embeds_str in model_kwargs:
+ if not self.config.is_encoder_decoder:
+ has_inputs_embeds_forwarding = inputs_embeds_str in set(
+ inspect.signature(self.prepare_inputs_for_generation).parameters.keys()
+ )
+ if not has_inputs_embeds_forwarding:
+ raise ValueError(
+ f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} "
+ "doesn't have its forwarding implemented. See the GPT2 implementation for an example "
+ "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!"
+ )
+ # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of
+ # the attention mask) can rely on the actual model input.
+ model_kwargs[input_ids_str] = self._maybe_initialize_input_ids_for_generation(
+ inputs, bos_token_id, model_kwargs=model_kwargs
+ )
+ else:
+ if inputs is not None:
+ raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.")
+                inputs, input_name = model_kwargs[inputs_embeds_str], inputs_embeds_str
+
+ # 4. if `inputs` is still None, try to create `input_ids` from BOS token
+ inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs)
+ return inputs, input_name, model_kwargs
+
+ def _extract_past_from_model_output(self, outputs: ModelOutputMindnlp, standardize_cache_format: bool = False):
+ """
+ extract past from model output
+ """
+ past_key_values = None
+ if "past_key_values" in outputs:
+ past_key_values = outputs.past_key_values
+ elif "mems" in outputs:
+ past_key_values = outputs.mems
+ elif "past_buckets_states" in outputs:
+ past_key_values = outputs.past_buckets_states
+
+ # Bloom fix: standardizes the cache format when requested
+ if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"):
+ batch_size = outputs.logits.shape[0]
+ past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size)
+ return past_key_values
+
+ def _update_model_kwargs_for_generation(
+ self,
+ outputs,
+ model_kwargs: Dict[str, Any],
+ is_encoder_decoder: bool = False,
+ standardize_cache_format: bool = False,
+ ) -> Dict[str, Any]:
+ """
+ update past_key_values and update token_type_ids with last value
+ """
+ model_kwargs["past_key_values"] = self._extract_past_from_model_output(
+ outputs, standardize_cache_format=standardize_cache_format
+ )
+ model_kwargs["is_encoder_decoder"] = is_encoder_decoder
+
+ token_type_str = "token_type_ids"
+ if token_type_str in model_kwargs:
+ token_type_ids = model_kwargs[token_type_str]
+ model_kwargs[token_type_str] = ops.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1)
+
+ return model_kwargs
+
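+    # Illustrative sketch for `_update_model_kwargs_for_generation`: if `token_type_ids`
+    # is [[0, 0, 1]], the concatenation above appends the last column again, yielding
+    # [[0, 0, 1, 1]] so its length keeps up with the growing `input_ids`.
+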
+ def _get_logits_processor(
+ self,
+ generation_config: GenerationConfig,
+ input_ids_seq_length: int,
+ logits_processor: Optional[LogitsProcessorList],
+ ) -> LogitsProcessorList:
+ """
+ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`]
+ instances used to modify the scores of the language model head.
+ """
+ # instantiate processors list
+ processors = LogitsProcessorList()
+
+ if generation_config.bad_words_ids is not None:
+ processors.append(
+ NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id)
+ )
+ if (
+ generation_config.min_length is not None
+ and generation_config.eos_token_id is not None
+ and generation_config.min_length > 0
+ ):
+ processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id))
+ if (
+ generation_config.min_new_tokens is not None
+ and generation_config.eos_token_id is not None
+ and generation_config.min_new_tokens > 0
+ ):
+ processors.append(
+ MinNewTokensLengthLogitsProcessor(
+ input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id
+ )
+ )
+
+ if generation_config.forced_eos_token_id is not None:
+ processors.append(
+ ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id)
+ )
+
+ processors = self._merge_criteria_processor_list(processors, logits_processor)
+
+ return processors
+
+ def _get_stopping_criteria(
+ self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList]
+ ) -> StoppingCriteriaList:
+ criteria = StoppingCriteriaList()
+ if generation_config.max_length is not None:
+ criteria.append(MaxLengthCriteria(max_length=generation_config.max_length))
+ criteria = self._merge_criteria_processor_list(criteria, stopping_criteria)
+ return criteria
+
+ def _validate_model_class(self):
+ """
+ Confirms that the model class is compatible with generation. If not, raises an exception that points to the
+ right class to use.
+ """
+ if not self.can_generate():
+            raise NotImplementedError(
+                f"The current model class ({self.__class__.__name__}) is not compatible with `.generate()`: it "
+                "neither overrides `prepare_inputs_for_generation` nor provides a custom `generate` method."
+            )
+
+ def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]):
+ """Validates model kwargs for generation. Generate argument typos will also be caught here."""
+ # Excludes arguments that are handled before calling any model function
+ if self.config.is_encoder_decoder:
+ for key in ["decoder_input_ids"]:
+ model_kwargs.pop(key, None)
+
+ def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
+ """Performs validation related to the resulting generated length"""
+
+ # 1. Max length warnings related to poor parameterization
+ if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
+ # 20 is the default max_length of the generation config
+ warnings.warn(
+ f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
+ "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
+ "generation.",
+ UserWarning,
+ )
+ if input_ids_length >= generation_config.max_length:
+ input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ warnings.warn(
+ f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`.",
+ UserWarning,
+ )
+
+ # 2. Min length warnings due to unfeasible parameter combinations
+ min_length_error_suffix = (
+ " Generation will stop at the defined maximum length. You should decrease the minimum length and/or "
+ "increase the maximum length."
+ )
+ if has_default_max_length:
+ min_length_error_suffix += (
+ f" Note that `max_length` is set to {generation_config.max_length}, its default value."
+ )
+ if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length:
+ warnings.warn(
+ f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than"
+ f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
+ UserWarning,
+ )
+ if generation_config.min_new_tokens is not None:
+ min_length = generation_config.min_new_tokens + input_ids_length
+ if min_length > generation_config.max_length:
+ warnings.warn(
+ f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when "
+ f"added to the prompt length ({input_ids_length}), is larger than"
+ f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix,
+ UserWarning,
+ )
+
+
+class ExplicitEnum(str, Enum):
+ """
+ Enum with more explicit error message for missing values.
+ """
+
+ @classmethod
+ def _missing_(cls, value):
+ raise ValueError(
+ f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
+ )
+
+
+class GenerationMode(ExplicitEnum):
+ """
+ Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method.
+ """
+
+ # Non-beam methods
+ CONTRASTIVE_SEARCH = "contrastive_search"
+ GREEDY_SEARCH = "greedy_search"
+ SAMPLE = "sample"
+ ASSISTED_GENERATION = "assisted_generation"
+ # Beam methods
+ BEAM_SEARCH = "beam_search"
+ BEAM_SAMPLE = "beam_sample"
+ CONSTRAINED_BEAM_SEARCH = "constrained_beam_search"
+ GROUP_BEAM_SEARCH = "group_beam_search"
+
+
+@jit_class
+class PretrainedConfig:
+ """
+ Abstract class for Pretrained models config.
+ """
+ is_composition = False
+    # Added to handle `attribute_map` aliasing of config attributes
+ attribute_map: Dict[str, str] = {}
+
+ def __init__(self, **kwargs):
+ self.ms_dtype = kwargs.pop("ms_dtype", None)
+ if 'torch_dtype' in kwargs:
+ self.ms_dtype = kwargs.pop("torch_dtype", None)
+ self.return_dict = kwargs.pop("return_dict", True)
+ self.output_hidden_states = kwargs.pop("output_hidden_states", False)
+ self.output_attentions = kwargs.pop("output_attentions", False)
+
+ self.pruned_heads = kwargs.pop("pruned_heads", {})
+ self.tie_word_embeddings = kwargs.pop(
+ "tie_word_embeddings", True
+ ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models.
+
+        # `is_decoder` is used in encoder-decoder models to differentiate the decoder from the encoder
+ self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False)
+ self.is_decoder = kwargs.pop("is_decoder", False)
+ self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None)
+ self.add_cross_attention = kwargs.pop("add_cross_attention", False)
+ self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False)
+
+ # Parameters for sequence generation
+ self.max_length = kwargs.pop("max_length", 20)
+ self.min_length = kwargs.pop("min_length", 0)
+ self.do_sample = kwargs.pop("do_sample", False)
+ self.early_stopping = kwargs.pop("early_stopping", False)
+ self.num_beams = kwargs.pop("num_beams", 1)
+ self.num_beam_groups = kwargs.pop("num_beam_groups", 1)
+ self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0)
+ self.temperature = kwargs.pop("temperature", 1.0)
+ self.top_k = kwargs.pop("top_k", 50)
+ self.top_p = kwargs.pop("top_p", 1.0)
+ self.typical_p = kwargs.pop("typical_p", 1.0)
+ self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0)
+ self.length_penalty = kwargs.pop("length_penalty", 1.0)
+ self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0)
+ self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0)
+ self.bad_words_ids = kwargs.pop("bad_words_ids", None)
+ self.num_return_sequences = kwargs.pop("num_return_sequences", 1)
+ self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0)
+ self.output_scores = kwargs.pop("output_scores", False)
+ self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False)
+ self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None)
+ self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None)
+ self.remove_invalid_values = kwargs.pop("remove_invalid_values", False)
+ self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None)
+ self.suppress_tokens = kwargs.pop("suppress_tokens", None)
+ self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None)
+
+ # Fine-tuning task arguments
+ self.architectures = kwargs.pop("architectures", None)
+ self.finetuning_task = kwargs.pop("finetuning_task", None)
+ self.id2label = kwargs.pop("id2label", None)
+ self.label2id = kwargs.pop("label2id", None)
+ if self.label2id is not None and not isinstance(self.label2id, dict):
+ raise ValueError("Argument label2id should be a dictionary.")
+ if self.id2label is not None:
+ if not isinstance(self.id2label, dict):
+ raise ValueError("Argument id2label should be a dictionary.")
+ num_labels = kwargs.pop("num_labels", None)
+
+ if num_labels is not None and len(self.id2label) != num_labels:
+ logger.warning(
+ f"You passed along `num_labels={num_labels}` with an incompatible id to label map: "
+                    f"{self.id2label}. The number of labels will be overwritten to {self.num_labels}."
+ )
+ self.id2label = {int(key): value for key, value in self.id2label.items()}
+ # Keys are always strings in JSON so convert ids to int here.
+ else:
+ self.num_labels = kwargs.pop("num_labels", 2)
+
+ if self.ms_dtype is not None and isinstance(self.ms_dtype, str):
+ if is_mindspore_available():
+ self.ms_dtype = getattr(mindspore, self.ms_dtype)
+
+ # Tokenizer arguments TODO: eventually tokenizer and models should share the same config
+ self.tokenizer_class = kwargs.pop("tokenizer_class", None)
+ self.prefix = kwargs.pop("prefix", None)
+ self.bos_token_id = kwargs.pop("bos_token_id", None)
+ self.pad_token_id = kwargs.pop("pad_token_id", None)
+ self.eos_token_id = kwargs.pop("eos_token_id", None)
+ self.sep_token_id = kwargs.pop("sep_token_id", None)
+
+ self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None)
+
+ # task specific arguments
+ self.task_specific_params = kwargs.pop("task_specific_params", None)
+
+ # regression / multi-label classification
+ self.problem_type = kwargs.pop("problem_type", None)
+ allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification")
+ if self.problem_type is not None and self.problem_type not in allowed_problem_types:
+ raise ValueError(
+ f"The config parameter `problem_type` was not understood: received {self.problem_type} "
+ "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid."
+ )
+
+ # Name or path to the pretrained checkpoint
+ self._name_or_path = str(kwargs.pop("name_or_path", ""))
+
+ # Additional attributes without default values
+ for key, value in kwargs.items():
+ try:
+ setattr(self, key, value)
+ except AttributeError as err:
+ logger.error(f"Can't set {key} with value {value} for {self}")
+ raise err
+
+ def __setattr__(self, key, value):
+ if key in super().__getattribute__("attribute_map"):
+ key = super().__getattribute__("attribute_map")[key]
+ super().__setattr__(key, value)
+
+ def __getattribute__(self, key):
+ if key != "attribute_map" and key in super().__getattribute__("attribute_map"):
+ key = super().__getattribute__("attribute_map")[key]
+ return super().__getattribute__(key)
+
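+    # Illustrative sketch for the `attribute_map` overrides above (hypothetical
+    # subclass): with `attribute_map = {"hidden_size": "n_embd"}`, reading or writing
+    # `config.hidden_size` is transparently redirected to `config.n_embd`.
+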
+ @property
+ def name_or_path(self) -> str:
+ """get name_or_path"""
+ return getattr(self, "_name_or_path", None)
+
+ @name_or_path.setter
+ def name_or_path(self, value):
+ """set name_or_path"""
+ self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding)
+
+ @classmethod
+ def from_json(cls, file_path):
+ """load config from json."""
+ with open(file_path, "r", encoding="utf-8") as file:
+ text = file.read()
+ config_map = json.loads(text)
+ config = cls()
+ for key, value in config_map.items():
+ setattr(config, key, value)
+ return config
+
+ @classmethod
+ def from_json_file(cls, json_file):
+ """Constructs a `Config` from a json file of parameters."""
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ dict_obj = json.loads(text)
+ return cls(**dict_obj)
+
+ @classmethod
+ def load(cls, pretrained_model_name_or_path):
+ """load config."""
+ return cls.from_pretrained(pretrained_model_name_or_path)
+
+ @property
+ def use_return_dict(self) -> bool:
+ """
+ `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples.
+ """
+        # No torchscript handling here: simply honor `return_dict`.
+ return self.return_dict
+
+ @classmethod
+ def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig":
+ """
+ Constructs a `Config` from a Python dictionary of parameters.
+ """
+ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False)
+
+ config = cls(**config_dict)
+
+ if hasattr(config, "pruned_heads"):
+ config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items())
+
+ # Update config with kwargs if needed
+ if "num_labels" in kwargs and "id2label" in kwargs:
+ num_labels = kwargs["num_labels"]
+ id2label = kwargs["id2label"] if kwargs["id2label"] is not None else []
+ if len(id2label) != num_labels:
+ raise ValueError(
+ f"You passed along `num_labels={num_labels }` with an incompatible id to label map: "
+ f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove "
+ "one of them.")
+
+ to_remove = []
+ for key, value in kwargs.items():
+ if hasattr(config, key):
+ setattr(config, key, value)
+ to_remove.append(key)
+ for key in to_remove:
+ kwargs.pop(key, None)
+
+ logger.info("Model config %s", str(config))
+ if return_unused_kwargs:
+ return config, kwargs
+ return config
+
+ @classmethod
+ def from_pretrained(
+ cls,
+ pretrained_model_name_or_path: Union[str, os.PathLike],
+ cache_dir: Optional[Union[str, os.PathLike]] = None,
+ force_download: bool = False,
+ local_files_only: bool = False,
+ mirror: str = 'huggingface',
+ **kwargs,
+ ) -> "PretrainedConfig":
+ r"""
+ Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration.
+ """
+ kwargs["cache_dir"] = cache_dir
+ kwargs["force_download"] = force_download
+ kwargs["local_files_only"] = local_files_only
+ kwargs['mirror'] = mirror
+
+ config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
+ if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type:
+ logger.warning(
+ f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
+ f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
+ )
+
+ return cls.from_dict(config_dict, **kwargs)
+
+ @classmethod
+ def get_config_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a
+ [`PretrainedConfig`] using `from_dict`.
+ """
+ original_kwargs = copy.deepcopy(kwargs)
+ # Get config dict associated with the base config file
+ config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs)
+
+ # That config file may point us toward another config file to use.
+ if "configuration_files" in config_dict:
+ configuration_file = get_configuration_file(config_dict["configuration_files"])
+ config_dict, kwargs = cls._get_config_dict(
+ pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs
+ )
+
+ return config_dict, kwargs
+
+ @classmethod
+ def _get_config_dict(
+ cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
+ ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+ """
+ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used
+ for instantiating a Config using `from_dict`.
+
+ Parameters:
+ pretrained_model_name_or_path (:obj:`string`):
+ The identifier of the pre-trained checkpoint from which we want the dictionary of parameters.
+
+ Returns:
+ :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object.
+
+ """
+ cache_dir = kwargs.pop("cache_dir", None)
+ force_download = kwargs.pop("force_download", False)
+ resume_download = kwargs.pop("resume_download", False)
+ proxies = kwargs.pop("proxies", None)
+ local_files_only = kwargs.pop("local_files_only", False)
+ subfolder = kwargs.pop("subfolder", "")
+ token = kwargs.pop('token', None)
+ revision = kwargs.pop('revision', 'main')
+ mirror = kwargs.pop('mirror', 'huggingface')
+
+ pretrained_model_name_or_path = str(pretrained_model_name_or_path)
+
+ is_local = os.path.isdir(pretrained_model_name_or_path)
+ if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
+ # Special case when pretrained_model_name_or_path is a local file
+ resolved_config_file = pretrained_model_name_or_path
+ is_local = True
+
+ elif is_remote_url(pretrained_model_name_or_path):
+ configuration_file = pretrained_model_name_or_path
+ resolved_config_file = download_url(pretrained_model_name_or_path)
+
+ else:
+ configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME)
+
+ try:
+ # Load from local folder or from cache or download from model Hub and cache
+ resolved_config_file = cached_file(
+ pretrained_model_name_or_path,
+ configuration_file,
+ cache_dir=cache_dir,
+ force_download=force_download,
+ proxies=proxies,
+ resume_download=resume_download,
+ local_files_only=local_files_only,
+ revision=revision,
+ token=token,
+ subfolder=subfolder,
+ mirror=mirror
+ )
+ except EnvironmentError:
+                # Re-raise any environment error raised by `cached_file`. It will have a helpful error message
+                # adapted to the original exception.
+ raise
+ except Exception as exc:
+ # For any other exception, we throw a generic error.
+ raise EnvironmentError(
+ f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it"
+ ", make sure you don't have a local directory with the same"
+ f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory"
+ f" containing a {configuration_file} file"
+ ) from exc
+
+ try:
+ # Load config dict
+ config_dict = cls._dict_from_json_file(resolved_config_file)
+ except (json.JSONDecodeError, UnicodeDecodeError) as exc:
+ raise EnvironmentError(
+ f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file."
+ ) from exc
+
+ if is_local:
+ logger.info(f"loading configuration file {resolved_config_file}")
+ else:
+ logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}")
+
+ return config_dict, kwargs
+
+ @classmethod
+ def _dict_from_json_file(cls, json_file: str):
+ """_dict_from_json_file"""
+ with open(json_file, "r", encoding="utf-8") as reader:
+ text = reader.read()
+ return json.loads(text)
+
+ def dict_ms_dtype_to_str(self, d: Dict[str, Any]) -> None:
+ """
+        Checks whether the passed dictionary and its nested dicts have an *ms_dtype* key and, if it's not None,
+        converts the dtype object to a string of just the type. For example, `mindspore.float32` gets converted into
+        the *"float32"* string, which can then be stored in the json format.
+ """
+ if d.get("ms_dtype", None) is not None and not isinstance(d["ms_dtype"], str):
+ d["ms_dtype"] = str(d["ms_dtype"]).lower()
+ for value in d.values():
+ if isinstance(value, dict):
+ self.dict_ms_dtype_to_str(value)
+
+ def to_dict(self) -> Dict[str, Any]:
+ """
+ Serializes this instance to a Python dictionary.
+
+ Returns:
+ `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ output = copy.deepcopy(self.__dict__)
+ if hasattr(self.__class__, "model_type"):
+ output["model_type"] = self.__class__.model_type
+ if "_auto_class" in output:
+ del output["_auto_class"]
+ if "_commit_hash" in output:
+ del output["_commit_hash"]
+ if "_attn_implementation_internal" in output:
+ del output["_attn_implementation_internal"]
+
+ for key, value in output.items():
+ # Deal with nested configs like CLIP
+ if isinstance(value, PretrainedConfig):
+ value = value.to_dict()
+
+ output[key] = value
+
+ if hasattr(self, "quantization_config"):
+ output["quantization_config"] = (
+ self.quantization_config.to_dict()
+ if not isinstance(self.quantization_config, dict)
+ else self.quantization_config
+ )
+
+ # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
+ _ = output.pop("_pre_quantization_dtype", None)
+
+ self.dict_ms_dtype_to_str(output)
+
+ return output
+
+ def to_diff_dict(self) -> Dict[str, Any]:
+ """
+ Removes all attributes from config which correspond to the default config attributes for better readability and
+ serializes to a Python dictionary.
+
+ Returns:
+            `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance.
+ """
+ config_dict = self.to_dict()
+
+ # get the default config dict
+ default_config_dict = PretrainedConfig().to_dict()
+
+ # get class specific config dict
+ class_config_dict = self.__class__().to_dict() if not self.is_composition else {}
+
+ serializable_config_dict = {}
+
+ # only serialize values that differ from the default config
+ for key, value in config_dict.items():
+ if (
+ isinstance(getattr(self, key, None), PretrainedConfig)
+ and key in class_config_dict
+ and isinstance(class_config_dict[key], dict)
+ ):
+ # For nested configs we need to clean the diff recursively
+ diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None))
+ if "model_type" in value:
+ # Needs to be set even if it's not in the diff
+ diff["model_type"] = value["model_type"]
+ if diff:
+ serializable_config_dict[key] = diff
+ elif (
+ key not in default_config_dict
+ or value != default_config_dict[key]
+ or (key in class_config_dict and value != class_config_dict[key])
+ ):
+ serializable_config_dict[key] = value
+
+ return serializable_config_dict
+
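+    # Illustrative sketch for `to_diff_dict`: for a config created as
+    # `PretrainedConfig(num_beams=4)`, the diff keeps `num_beams` (it differs from the
+    # default of 1) but drops untouched defaults such as `temperature`, so the
+    # serialized JSON stays small.
+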
+ def to_json_string(self, use_diff: bool = True) -> str:
+ """
+ Serializes this instance to a JSON string.
+ """
+ if use_diff is True:
+ config_dict = self.to_diff_dict()
+ else:
+ config_dict = self.to_dict()
+
+ return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
+
+ def to_file(self, save_path):
+ """Serializes this instance to a JSON file."""
+ output_dict = self.to_dict()
+        with open(os.path.join(save_path, 'config.json'), 'w', encoding='utf-8') as f:
+ json.dump(output_dict, f, sort_keys=True, indent=2)
+
+ def update(self, config_dict: Dict[str, Any]):
+ """
+ Updates attributes of this class with attributes from `config_dict`.
+ """
+ for key, value in config_dict.items():
+ setattr(self, key, value)
+
+ def save_pretrained(self, save_directory: Union[str, os.PathLike]):
+ """
+ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the
+ [`~PretrainedConfig.from_pretrained`] class method.
+ """
+
+ if os.path.isfile(save_directory):
+ raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
+
+ os.makedirs(save_directory, exist_ok=True)
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_config_file = os.path.join(save_directory, CONFIG_NAME)
+
+ self.to_json_file(output_config_file, use_diff=True)
+ logger.info(f"Configuration saved in {output_config_file}")
+
+ def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True):
+ """
+ Save this instance to a JSON file.
+ """
+ with open(json_file_path, "w", encoding="utf-8") as writer:
+ writer.write(self.to_json_string(use_diff=use_diff))
+
+ @property
+ def num_labels(self) -> int:
+ """
+ `int`: The number of labels for classification models.
+ """
+ return len(self.id2label)
+
+ @num_labels.setter
+ def num_labels(self, num_labels: int):
+ if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels:
+ self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)}
+ self.label2id = dict(zip(self.id2label.values(), self.id2label.keys()))
+
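+    # Illustrative sketch for the `num_labels` setter above: setting
+    # `config.num_labels = 3` with no id2label provided builds
+    # `id2label = {0: "LABEL_0", 1: "LABEL_1", 2: "LABEL_2"}` and the inverse
+    # `label2id`, and `config.num_labels` then reads back as `len(id2label) == 3`.
+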
+ @property
+ def _attn_implementation(self):
+ """
+        This property is kept private for now; it cannot be changed externally until a
+        `PreTrainedModel.use_attn_implementation` method is implemented.
+ """
+ if hasattr(self, "_attn_implementation_internal"):
+ if not self._attn_implementation_internal:
+ # `config.attn_implementation` should never be None, for backward compatibility.
+ return "eager"
+ return self._attn_implementation_internal
+ return "eager"
+
+ @_attn_implementation.setter
+ def _attn_implementation(self, value):
+ self._attn_implementation_internal = value
+
+
+class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin):
+ """
+ Abstract class for Pretrained models
+ """
+ config_class = None
+ base_model_prefix = ""
+ main_input_name = "input_ids"
+
+ # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
+ # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
+ _keys_to_ignore_on_load_missing = None
+ # a list of `re` patterns of `state_dict` keys that should be removed from the list of
+ # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
+ # warnings.
+ _keys_to_ignore_on_load_unexpected = None
+
+ _tied_weights_keys = None
+
+ _keep_in_fp32_modules = None
+
+ supports_recompute = False
+
+ def __init__(self, config):
+ super().__init__(config)
+ self._check_and_unset_acl()
+ # Save config in model
+ self.config = config
+ self.name_or_path = config.name_or_path
+ self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
+
+ def _check_and_unset_acl(self):
+ if "MS" in str(self.__class__.__name__) and \
+ 'MS_DEV_FORCE_ACL' in os.environ:
+ del os.environ['MS_DEV_FORCE_ACL']
+
+ def post_init(self):
+ """
+ A method executed at the end of each Transformer model initialization, to execute code that needs the model's
+ modules properly initialized (such as weight initialization).
+
+ """
+ self.init_weights()
+
+ @staticmethod
+ def prepare_inputs_for_generation(*args, **kwargs):
+ """
+ prepare_inputs_for_generation
+ """
+ return
+
+ @classmethod
+ def _from_config(cls, config, **kwargs):
+ """
+ All context managers that the model should be initialized under go here.
+ """
+ model = cls(config, **kwargs)
+
+ return model
+
+ @property
+ def base_model(self):
+ """
+ to get base_model
+ """
+ return getattr(self, self.base_model_prefix, self)
+
+ def get_input_embeddings(self) -> "nn.Cell":
+ """
+ Returns the model's input embeddings.
+
+ Returns:
+ :obj:`nn.Cell`: A mindspore cell mapping vocabulary to hidden states.
+ """
+ base_model = getattr(self, self.base_model_prefix, self)
+ if base_model is not self:
+ return base_model.get_input_embeddings()
+ raise NotImplementedError
+
+ def set_input_embeddings(self, new_embeddings: nn.Cell):
+ """
+ Set model's input embeddings.
+
+ Args:
+            new_embeddings (:obj:`nn.Cell`): A mindspore cell mapping vocabulary to hidden states.
+ """
+ base_model = getattr(self, self.base_model_prefix, self)
+ if base_model is not self:
+ return base_model.set_input_embeddings(new_embeddings)
+ raise NotImplementedError
+
+ def get_output_embeddings(self):
+ """ Get model's output embeddings
+ Return None if the model doesn't have output embeddings
+ """
+ return None # Overwrite for models with output embeddings
+
+ def set_output_embeddings(self, new_embeddings: nn.Cell):
+ """
+ Set model's output embeddings.
+
+ Args:
+            new_embeddings (:obj:`nn.Cell`): A mindspore cell mapping vocabulary to hidden states.
+ """
+ base_model = getattr(self, self.base_model_prefix, self)
+ if base_model is not self:
+ return base_model.set_output_embeddings(new_embeddings)
+ raise NotImplementedError
+
+ def resize_token_embeddings(
+ self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None
+ ) -> nn.Embedding:
+ """
+ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
+ Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
+ """
+ model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+ if new_num_tokens is None and pad_to_multiple_of is None:
+ return model_embeds
+ # Update base model and current model config
+ self.config.vocab_size = model_embeds.weight.shape[0]
+ self.vocab_size = model_embeds.weight.shape[0]
+ # Tie weights again if needed
+ self.tie_weights()
+
+ return model_embeds
+
+ def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None):
+ """
+ Update new_num_tokens with the actual size of new_embeddings.
+ If word embeddings are not tied, make sure that lm head is resized as well.
+ """
+ old_embeddings = self.get_input_embeddings()
+ new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of)
+
+ self.set_input_embeddings(new_embeddings)
+ self.get_input_embeddings().weight.data_sync(True)
+
+ if pad_to_multiple_of is not None:
+ new_num_tokens = new_embeddings.weight.shape[0]
+
+ if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings:
+ old_lm_head = self.get_output_embeddings() # pylint: disable=assignment-from-none
+ new_lm_head = self._get_resized_lm_head(
+ old_lm_head, new_num_tokens)
+ self.set_output_embeddings(new_lm_head)
+ self.get_output_embeddings().weight.data_sync(True)
+
+ return self.get_input_embeddings()
+
+ def resize_tokenizer_embeddings(self, new_num_tokens):
+ """
+ Obtain a new embedding layer or use the original one without updating it.
+ """
+ old_embeddings = self.get_input_embeddings()
+ new_embeddings = self._get_resized_embeddings(
+ old_embeddings, new_num_tokens)
+ self.set_input_embeddings(new_embeddings)
+ return self.get_input_embeddings()
+
+ def _get_resized_embeddings(
+ self,
+ old_embeddings: nn.Embedding,
+ new_num_tokens: Optional[int] = None,
+ pad_to_multiple_of: Optional[int] = None,
+ ) -> nn.Embedding:
+ """ Build a resized Embedding Module from a provided token Embedding Module.
+ Increasing the size will add newly initialized vectors at the end
+ Reducing the size will remove vectors from the end
+ """
+ if pad_to_multiple_of is not None:
+ if not isinstance(pad_to_multiple_of, int):
+ raise ValueError(
+ f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, "
+                    "which is not an integer. Please make sure to pass an integer"
+ )
+ if new_num_tokens is None:
+ new_num_tokens = old_embeddings.weight.shape[0]
+ new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
+ else:
+ logger.info(
+ "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. "
+ f"This means that the new embedding dimension will be {new_num_tokens}. This might induce "
+ "some performance reduction as *Tensor Cores* will not be available. For more details about "
+ "this, or help on choosing the correct value for resizing, refer to this guide:"
+ " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
+ )
+
+ if new_num_tokens is None:
+ return old_embeddings
+
+ old_num_tokens, old_embedding_dim = old_embeddings.weight.shape
+ if old_num_tokens == new_num_tokens:
+ return old_embeddings
+
+ # Build new embeddings
+ new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim)
+
+ # initialize all new embeddings (in particular added tokens)
+ self._init_weights(new_embeddings)
+
+ # Copy word embeddings from the previous weights
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+ new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[
+ :num_tokens_to_copy, :]
+ return new_embeddings
+
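+    # Worked example for `pad_to_multiple_of` in `_get_resized_embeddings` (illustrative
+    # values): with new_num_tokens=50_261 and pad_to_multiple_of=64,
+    # ((50_261 + 63) // 64) * 64 == 50_304, so the resized embedding gets 50_304 rows
+    # and the rows beyond the copied ones are freshly initialized.
+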
+ def _get_resized_lm_head(
+ self, old_lm_head: nn.Dense, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False
+ ) -> nn.Dense:
+ """
+ Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
+ vectors at the end. Reducing the size will remove vectors from the end
+ """
+ if new_num_tokens is None:
+ return old_lm_head
+
+ old_num_tokens, old_lm_head_dim = (
+ old_lm_head.weight.shape if not transposed else old_lm_head.weight.T.shape
+ )
+
+ if old_num_tokens == new_num_tokens:
+ return old_lm_head
+
+ if not isinstance(old_lm_head, nn.Dense):
+ raise TypeError(
+ f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Dense}. You"
+ " should either use a different resize function or make sure that `old_lm_head` are an instance of"
+ f" {nn.Dense}."
+ )
+
+ # Build new lm head
+ new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
+ has_new_lm_head_bias = old_lm_head.bias is not None
+
+ # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
+ # because the shape of the new embedding layer is used across various modeling files
+ # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
+ # to errors when training.
+ new_lm_head = nn.Dense(
+ *new_lm_head_shape,
+ has_bias=has_new_lm_head_bias,
+ )
+
+ # initialize new lm head (in particular added tokens)
+ self._init_weights(new_lm_head)
+
+ num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
+
+ self._copy_lm_head_original_to_resized(
+ new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+ )
+
+ return new_lm_head
+
+ def _copy_lm_head_original_to_resized(
+ self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
+ ):
+ """
+ Copy old lm head weights to new lm head.
+ Copy bias weights to new lm head.
+ """
+ if not transposed:
+ new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
+ else:
+ new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
+
+ if has_new_lm_head_bias:
+ new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
+
+
+ @classmethod
+ def load(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
+ *args, **kwargs):
+        """
+        Load a pre-trained checkpoint from a pre-trained model file or url;
+        download and cache the pre-trained model file if the model name is in the model list.
+
+        Params:
+            pretrained_model_name_or_path: path, url or name of the pre-trained checkpoint.
+        """
+        return cls.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
+
+ def save(self, save_dir):
+ """ Save a model and its configuration file to a directory, so that
+        it can be re-loaded using the :func:`PreTrainedModel.from_pretrained` class method.
+
+ Arguments:
+ save_dir: directory to which to save.
+ """
+ if os.path.isfile(save_dir):
+ logger.error(f"Provided path ({save_dir}) should be a directory, not a file")
+ return
+ os.makedirs(save_dir, exist_ok=True)
+
+ # Only save the model itself if we are using distributed training
+ model_to_save = self.cell if hasattr(self, "cell") else self
+
+ # Attach architecture to the config
+ model_to_save.config.architectures = [model_to_save.__class__.__name__]
+
+ # If we save using the predefined names, we can load using `from_pretrained`
+ output_model_file = os.path.join(save_dir, WEIGHTS_NAME)
+ save_checkpoint(model_to_save, output_model_file)
+
+ logger.info(f"Model weights saved in {output_model_file}")
+
+ @classmethod
+ def can_generate(cls) -> bool:
+ """
+ Returns whether this model can generate sequences with `.generate()`.
+
+ Returns:
+ `bool`: Whether this model can generate sequences with `.generate()`.
+ """
+ # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation.
+        # Alternatively, the model can also have a custom `generate` function.
+ if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate):
+ return False
+ return True
+
+ def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
+ """
+ Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given.
+ """
+ if (attention_mask is not None) or (self.config.pad_token_id is None):
+ return
+
+ # Check only the first and last input IDs to reduce overhead.
+ if self.config.pad_token_id in input_ids[:, [-1, 0]]:
+ warn_string = (
+ "We strongly recommend passing in an `attention_mask` since your input_ids may be padded."
+ )
+
+ # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
+ # attention_mask or not. In this case, we should still show a warning because this is a rare case.
+ if (
+ (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
+ or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
+ or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
+ ):
+ warn_string += (
+ f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
+ f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
+ f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
+ )
+
+ logger.warning(warn_string)
+
+ def num_parameters(self, only_trainable=False):
+ """return parameters count"""
+ total = 0
+ param_set = set()
+ for param in self.get_parameters():
+ param_id = param.uuid
+            if param_id not in param_set and (not only_trainable or param.requires_grad):
+ total += param.size
+ param_set.add(param_id)
+ return total
+
+ def trainable_params(self, recurse=True):
+ """
+ fix duplicated weights
+ """
+ return list(set(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse))))
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ state_dict: Optional[dict] = None,
+ save_function: Callable = mindspore.save_checkpoint,
+ max_shard_size: Union[int, str] = "5GB",
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ ):
+ """
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
+ [`~PreTrainedModel.from_pretrained`] class method.
+ """
+
+ if os.path.isfile(save_directory):
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
+ return
+
+ os.makedirs(save_directory, exist_ok=True)
+
+
+ # Only save the model itself if we are using distributed training
+ model_to_save = self
+
+ # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
+ # we currently don't use this setting automatically, but may start to use with v5
+ dtype = get_parameter_dtype(model_to_save)
+ model_to_save.config.ms_dtype = str(dtype).lower()
+
+ # Attach architecture to the config
+ model_to_save.config.architectures = [model_to_save.__class__.__name__]
+
+ # Save the config
+ if is_main_process:
+ model_to_save.config.save_pretrained(save_directory)
+ if self.can_generate():
+ model_to_save.generation_config.save_pretrained(save_directory)
+
+
+ # Save the model
+ if state_dict is None:
+ state_dict = model_to_save.parameters_dict()
+
+ # Shard the model if it is too big.
+ # weights_name = _add_variant(WEIGHTS_NAME, variant)
+ weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+
+ shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name)
+
+ # Clean the folder from a previous save
+ for filename in os.listdir(save_directory):
+ full_filename = os.path.join(save_directory, filename)
+ # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
+ # in distributed settings to avoid race conditions.
+ weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")
+
+ # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+ filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
+ reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
+ if (
+ filename.startswith(weights_no_suffix)
+ and os.path.isfile(full_filename)
+ and filename not in shards
+ and is_main_process
+ and reg.fullmatch(filename_no_suffix) is not None
+ ):
+ os.remove(full_filename)
+
+ # Save the model
+ for shard_file, shard in shards.items():
+ if safe_serialization:
+ # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this is enough.
+ safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "np"})
+ else:
+ save_function(shard, os.path.join(save_directory, shard_file))
+
+ if index is None:
+ path_to_weights = os.path.join(save_directory, _add_variant(WEIGHTS_NAME, variant))
+ logger.info(f"Model weights saved in {path_to_weights}")
+ else:
+ save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+ save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+ # Save the index as well
+ with open(save_index_file, "w", encoding="utf-8") as f:
+ content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+ f.write(content)
+ logger.info(
+ f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(shards)} checkpoint shards. You can find where each parameter has been saved in the "
+ f"index located at {save_index_file}."
+ )
+
+ def check_names(self):
+ pass
+
+
+class TensorType(ExplicitEnum):
+ """
+ Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for
+ tab-completion in an IDE.
+ """
+
+ MINDSPORE = "ms"
+ NUMPY = "np"
+
+
+class PaddingStrategy(ExplicitEnum):
+ """
+ Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an
+ IDE.
+ """
+
+ LONGEST = "longest"
+ MAX_LENGTH = "max_length"
+ DO_NOT_PAD = "do_not_pad"
+
+
+class StoppingCriteria():
+ """Abstract base class for all stopping criteria that can be applied during generation."""
+
+ @staticmethod
+ def __call__(input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
+ raise NotImplementedError("StoppingCriteria needs to be subclassed")
+
+
+class MaxLengthCriteria(StoppingCriteria):
+ """
+ This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep
+ in mind for decoder-only type of transformers, this will include the initial prompted tokens.
+ """
+
+ def __init__(self, max_length: int):
+ self.max_length = max_length
+
+ def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool:
+ return input_ids.shape[-1] >= self.max_length
+
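+# A minimal usage sketch (hypothetical `input_ids`/`scores`, not part of this module):
+#
+#     criteria = MaxLengthCriteria(max_length=128)
+#     stop = criteria(input_ids, scores)   # True once input_ids holds >= 128 tokens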
+
+@dataclass
+class BaseModelOutput(ModelOutputMindnlp):
+ """
+ Base class for model's outputs, with potential hidden states and attentions.
+ """
+
+ last_hidden_state: mindspore.Tensor = None
+ hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[mindspore.Tensor]] = None
+
+
+@dataclass
+class BaseModelOutputWithPast(ModelOutputMindnlp):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+ """
+
+ last_hidden_state: mindspore.Tensor = None
+ past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[mindspore.Tensor]] = None
+
+
+@dataclass
+class BaseModelOutputWithPastAndCrossAttentions(ModelOutputMindnlp):
+ """
+ Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding).
+ """
+
+ last_hidden_state: mindspore.Tensor = None
+ past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[mindspore.Tensor]] = None
+ cross_attentions: Optional[Tuple[mindspore.Tensor]] = None
+
+
+@dataclass
+class CausalLMOutputWithPast(ModelOutputMindnlp):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+ """
+
+ loss: Optional[mindspore.Tensor] = None
+ logits: mindspore.Tensor = None
+ past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[mindspore.Tensor]] = None
+
+
+@dataclass
+class CausalLMOutputWithCrossAttentions(ModelOutputMindnlp):
+ """
+ Base class for causal language model (or autoregressive) outputs.
+ """
+
+ loss: Optional[mindspore.Tensor] = None
+ logits: mindspore.Tensor = None
+ past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None
+ hidden_states: Optional[Tuple[mindspore.Tensor]] = None
+ attentions: Optional[Tuple[mindspore.Tensor]] = None
+ cross_attentions: Optional[Tuple[mindspore.Tensor]] = None
+
+
+def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList:
+ """validate stopping criteria"""
+ stopping_max_length = stopping_criteria.max_length
+ new_stopping_criteria = deepcopy(stopping_criteria)
+ if stopping_max_length is not None and stopping_max_length != max_length:
+ warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning)
+ elif stopping_max_length is None:
+ new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length))
+ return new_stopping_criteria
+
+def _is_numpy(x):
+ return isinstance(x, np.ndarray)
+
+def is_numpy_array(x):
+ """
+ Tests if `x` is a numpy array or not.
+ """
+ return _is_numpy(x)
+
+def infer_framework_from_repr(x):
+ """
+ Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the
+ frameworks in a smart order, without the need to import the frameworks).
+ """
+ representation = str(type(x))
+ if representation.startswith("= input_rank:
+ raise ValueError(f"`dim` should be in range [{-input_rank}, {input_rank}), but got {input_rank, dim}")
+
+ sizes_mul = reduce(operator.mul, list(sizes))
+ if -1 not in sizes and sizes_mul != input_shape[dim_new]:
+ raise ValueError(f"unflatten: Provided `sizes` {sizes} don't multiply up to the"
+ f"size of dim {dim} ({input_shape[dim_new]}) in the input tensor")
+
+ out_shape = input_shape[:dim_new] + tuple(sizes) + input_shape[dim_new + 1:]
+ return out_shape
+
+
+# For all backend
+# For functional api
+# matmul
+origin_matmul = ops.matmul
+ops.matmulcus = fp16_patch_decorator(origin_matmul)
+# mm
+ops.mmcus = fp16_patch_decorator(origin_matmul)
+# addbmm
+origin_addbmm = ops.addbmm
+ops.addbmmcus = fp16_patch_decorator(origin_addbmm)
+# addmm
+origin_addmm = ops.addmm
+ops.addmmcus = fp16_patch_decorator(origin_addmm)
+# addmv
+origin_addmv = ops.addmv
+ops.addmvcus = fp16_patch_decorator(origin_addmv)
+# addr
+origin_addr = ops.addr
+ops.addrcus = fp16_patch_decorator(origin_addr)
+# baddbmm
+origin_baddbmm = ops.baddbmm
+ops.baddbmmcus = fp16_patch_decorator(origin_baddbmm)
+# bmm
+origin_bmm = ops.bmm
+ops.bmmcus = fp16_patch_decorator(origin_bmm)
+
+
+# dense
+def origin_dense(input_dense, weight, bias=None):
+ """patched dense"""
+ dense_ = _get_cache_prim(ops.Dense)()
+ return dense_(input_dense, weight, bias)
+
+ops.densecus = fp16_patch_decorator(origin_dense)
+
+def einsum_label_to_index(label):
+ if label == '.':
+ return 52
+ num_of_letters = ord('z') - ord('a') + 1
+ return (ord(label) - ord('A')) if (label.isupper()) else (num_of_letters + (ord(label) - ord('a')))
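+
+# Worked example (for reference): 'A' -> 0, 'Z' -> 25, 'a' -> 26, 'z' -> 51 and '.' -> 52,
+# which appears to match the ELLIPSIS sentinel used below.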
+
+def enumerate_lhs(op_labels, lhs, num_ops):
+ """
+ enumerate lhs in einsum function
+ """
+ curr_op = 0
+ found_ell = False
+ ell_skip = 0
+
+ for _, label in enumerate(lhs):
+ if label == ' ':
+ continue
+ if label == '.':
+ if ell_skip != 0:
+ ell_skip -= 1
+ continue
+            assert not found_ell, f"einsum(): found '.' for operand {curr_op} for which an ellipsis was already found"
+
+ ell_skip = 2
+ op_labels[curr_op].append(ELLIPSIS)
+ found_ell = True
+ elif label == ',':
+ curr_op += 1
+ assert curr_op < num_ops, "einsum(): fewer operands were provided than specified in the equation"
+ found_ell = False
+ else:
+
+ op_labels[curr_op].append(einsum_label_to_index(label))
+
+ return op_labels, lhs, num_ops, curr_op, found_ell
+
+def enumerate_operands(op_labels, operands, label_count):
+ """
+ enumerate operands in einsum function
+ """
+ # Compute label frequency and number of dimensions covered by ellipsis
+ # We do this after parsing labels to make it more readable and simpler
+ # to compute the number of dimensions covered by ellipsis.
+ ell_num_dim = 0
+
+ for i, operand in enumerate(operands):
+ labels = op_labels[i]
+ ndims = operand.ndim
+ nlabels = len(labels)
+ has_ellipsis = False
+
+ for label in labels:
+ if label == ELLIPSIS:
+ nlabels -= 1
+ has_ellipsis = True
+ ell_num_dim = max(ell_num_dim, ndims - nlabels)
+ else:
+ label_count[label] += 1
+ if has_ellipsis:
+ assert nlabels <= ndims, f"einsum(): the number of subscripts in the equation ({nlabels}" \
+ f") is more than the number of dimensions ({ndims}) for operand {i}"
+ else:
+ assert nlabels == ndims, f"einsum(): the number of subscripts in the equation ({nlabels}" \
+ f") does not match the number of dimensions (" \
+ f"{ndims}) for operand {i} and no ellipsis was given"
+
+ return ell_num_dim, label_count
+
+def unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels,
+ ell_num_dim, ell_index, label_perm_index):
+ """
+ unsqueeze missing dimension in einsum function
+ """
+ permuted_operands = []
+ for i, operand in enumerate(operands):
+ perm_shape = [-1] * perm_index
+ label_dim = [-1] * total_labels
+ operand = operands[i]
+ labels = op_labels[i]
+ original_sizes = operand.shape
+
+ j = 0
+ for label in labels:
+ if label == ELLIPSIS:
+ # Add missing dimensions covered by the ellipsis
+ num_missing_dim = ell_num_dim - \
+ (len(original_sizes) - len(labels) + 1)
+ for k in range(num_missing_dim):
+ operand = ops.unsqueeze(operand, j)
+ for k in range(ell_num_dim):
+ perm_shape[ell_index + k] = j
+ j += 1
+ elif label_dim[label] != -1:
+ dim = label_dim[label]
+ operand = ops.diagonal(operand, offset=0, dim1=dim, dim2=j)
+ operand = ops.moveaxis(operand, -1, dim)
+ else:
+ label_dim[label] = j
+ perm_shape[label_perm_index[label]] = j
+ j += 1
+
+ # Add dimensions for missing labels
+ for idx, index in enumerate(perm_shape):
+ if index == -1:
+ operand = ops.unsqueeze(operand, -1)
+ perm_shape[idx] = j
+ j += 1
+
+ operand = ops.transpose(operand, tuple(perm_shape))
+ permuted_operands.append(operand)
+
+ # Check if operands broadcast and keep track of last operand with
+ # dimension size != 1 for optimizing reductions
+ dim_last_op = [0] * perm_index
+ has_zero_size_dim = False
+ for dim in range(perm_index):
+ broadcast_size = permuted_operands[0].shape[dim]
+ for i in range(1, len(operands)):
+ dim_size = permuted_operands[i].shape[dim]
+ if broadcast_size != dim_size and broadcast_size != 1 and dim_size != 1:
+ raise RuntimeError("einsum(): operands do not broadcast with remapped shapes [original->remapped]")
+ if dim_size != 1:
+ broadcast_size = dim_size
+ dim_last_op[dim] = i
+ has_zero_size_dim = has_zero_size_dim or (broadcast_size == 0)
+
+ return permuted_operands, has_zero_size_dim, dim_last_op
+
+def einsum_operate(arrow_pos, ell_num_dim, label_perm_index, label_count, equation, lhs):
+ """
+ intermediate operation in einsum function
+ """
+ # Current index in the permuted shape
+ perm_index = 0
+ # Start index of ellipsis dimensions in the permuted shape
+ ell_index = 0
+ found_ell = False
+
+ if arrow_pos == -1:
+ # Implicit output is ellipsis (...) + labels seen only once
+ perm_index = ell_num_dim
+ found_ell = True
+        for label, count in enumerate(label_count):
+            if count == 1:
+ label_perm_index[label] = perm_index
+ perm_index += 1
+ else:
+ rhs = equation[arrow_pos + 2:]
+ ell_skip = 0
+ for i, label in enumerate(rhs):
+ if label == ' ':
+ continue
+ if label == '.':
+ if ell_skip != 0:
+ ell_skip -= 1
+ continue
+ assert not found_ell, "einsum(): found \'.\' for output but an ellipsis (...) was already found"
+
+ ell_skip = 2
+ ell_index = perm_index
+ perm_index += ell_num_dim
+ found_ell = True
+ else:
+ assert str.isalpha(label), f"einsum(): invalid subscript given at index {len(lhs) + 2 + i} " \
+ f"in the equation string, subscripts must be in [a-zA-Z]"
+
+ index = einsum_label_to_index(label)
+ label_perm_index[index] = perm_index
+ perm_index += 1
+ return perm_index, found_ell, label_perm_index, ell_index
+
+def sum_result(num_ops, permuted_operands, out_size, perm_index, dim_last_op, result):
+ """
+ sum out or squeeze dimensions that are size 1 for all later operands
+ """
+ for i in range(1, num_ops):
+ operand = permuted_operands[i]
+ sum_dims = []
+
+ dim = out_size
+ for j in range(dim, perm_index):
+ if dim_last_op[j] < i:
+ operand = ops.squeeze(operand, dim)
+ dim -= 1
+ elif dim_last_op[j] == i:
+ if result.shape[dim] == 1:
+ operand = ops.sum(operand, dim)
+ result = ops.squeeze(result, dim)
+ dim -= 1
+ else:
+ sum_dims.append(dim)
+ dim += 1
+ if not sum_dims:
+ result = result.mul(operand)
+ elif len(sum_dims) == len(result.shape):
+ result = result.flatten().dot(operand.flatten())
+
+ return result
+
+def einsum(equation, *operands):
+ """
+ einsum method
+ """
+ assert operands, "einsum(): must provide at least one operand"
+ if isinstance(operands[0], tuple):
+ operands = operands[0]
+
+ arrow_pos = equation.find("->")
+ num_ops = len(operands)
+ op_labels = [[] for _ in range(num_ops)]
+    lhs = equation if arrow_pos == -1 else equation[0: arrow_pos]
+
+ op_labels, lhs, num_ops, curr_op, found_ell = enumerate_lhs(op_labels, lhs, num_ops)
+
+ assert curr_op == num_ops - 1, "einsum(): more operands were provided than specified in the equation"
+ # Labels must be within [a-zA-Z].
+ total_labels = 52
+ label_count = [0] * total_labels
+ # The maximum number of dimensions covered by any ellipsis, needed when
+ # unsqueezing missing dimensions from operands to permute and broadcast
+ ell_num_dim, label_count = enumerate_operands(op_labels, operands, label_count)
+
+ # We want to align the dimensions of every input tensor to have
+ # shape out_dims + sum_dims. For this, we create a mapping of label
+ # to index into the permuted shape.
+ label_perm_index = [-1] * total_labels
+
+    perm_index, found_ell, label_perm_index, ell_index = \
+        einsum_operate(arrow_pos, ell_num_dim, label_perm_index, label_count, equation, lhs)
+
+ out_size = perm_index
+ if not found_ell:
+ ell_index = perm_index
+ perm_index += ell_num_dim
+
+ for label in range(total_labels):
+ if label_count[label] > 0 and label_perm_index[label] == -1:
+ label_perm_index[label] = perm_index
+ perm_index += 1
+
+ # Here we unsqueeze missing dimensions to make all operands have the same
+ # number of dimensions. We take diagonals for repeated labels within the
+ # same operand. Finally we permute the operands to align dimensions as
+ # per the perm_out_index we computed above.
+ permuted_operands, has_zero_size_dim, dim_last_op = \
+ unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels,
+ ell_num_dim, ell_index, label_perm_index)
+
+ # Compute result
+ result = permuted_operands[0]
+ if has_zero_size_dim:
+ out_shape = [-1] * out_size
+ for i in range(out_size):
+ out_shape[i] = permuted_operands[dim_last_op[i]].shape[i]
+ return ops.zeros(out_shape)
+
+ # Sum out or squeeze dimensions that are size 1 for all later operands
+ dim = out_size
+ for i in range(dim, perm_index):
+ if dim_last_op[i] == 0:
+ if result.shape[dim] == 1:
+ result = ops.squeeze(result, dim)
+ dim -= 1
+ else:
+ result = ops.sum(result, dim)
+ dim -= 1
+ dim += 1
+
+ result = sum_result(num_ops, permuted_operands,
+ out_size, perm_index, dim_last_op, result)
+
+ return result
+
+ops.einsum = einsum
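+
+# Minimal usage sketch for the patched einsum (illustrative shapes only), e.g. a plain
+# transpose and an elementwise product:
+#
+#     a = ops.ones((2, 3))
+#     ops.einsum("ij->ji", a).shape         # (3, 2)
+#     ops.einsum("ij,ij->ij", a, a).shape   # (2, 3)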
+
+# conv1d
+ops.conv1dcus = fp16_patch_decorator(ops.conv1d)
+
+
+def _ones(*size, dtype=None):
+ if dtype is None:
+ dtype = mindspore.float32
+ if isinstance(size[0], tuple):
+ size = size[0]
+ ones_ = _get_cache_prim(ops.Ones)()
+ return ones_(size, dtype)
+
+ops.onescus = _ones
+
+
+def _zeros(*size, dtype=None):
+ if dtype is None:
+ dtype = mindspore.float32
+ if isinstance(size[0], tuple):
+ size = size[0]
+ zeros_ = _get_cache_prim(ops.Zeros)()
+ return zeros_(size, dtype)
+
+
+ops.zeroscus = _zeros
+
+# for Tensor
+# unfold
+def _get_unfold_indices(input_shape, dimension, size, step):
+ if dimension < 0:
+ dimension += len(input_shape)
+ indices = []
+ for i in range(0, input_shape[dimension] - size + 1, step):
+ indices.append(list(range(i, i + size)))
+
+ return indices, dimension
+
+
+def unfold(self, dimension, size, step):
+ """unfold"""
+ indices_new, dimension_new = _get_unfold_indices(self.shape, dimension, size, step)
+ indices = mindspore.Tensor(indices_new).astype(mindspore.int32)
+ output = ops.gather(self, indices, axis=dimension_new)
+ output = ops.moveaxis(output, dimension_new + 1, -1)
+ return output
+
+Tensor.unfold = unfold
+StubTensor.unfold = unfold
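+
+# Usage sketch for the patched `unfold` (illustrative values): sliding windows of
+# length `size` taken every `step` elements along `dimension`:
+#
+#     x = mindspore.Tensor([0., 1., 2., 3., 4.])
+#     x.unfold(0, 3, 1)   # [[0., 1., 2.], [1., 2., 3.], [2., 3., 4.]]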
+
+
+# var_mean
+def var_mean(input_vm, axis=None, *, correction=1, keepdims=False):
+ """var_mean"""
+    axis = Validator.check_and_canonicalize_axes(axis, input_vm.ndim)
+ x_mean = ops.mean(input_vm, axis, True)
+ x_sub = ops.sub(input_vm, x_mean)
+ x_pow = ops.pow(x_sub, 2)
+ x_sum = ops.sum(x_pow, axis, keepdims)
+ res_mean = ops.mean(input_vm, axis, keepdims)
+ nums = 1
+ if not axis:
+ nums = input_vm.size
+ else:
+ for ax in axis:
+ nums *= input_vm.shape[ax]
+ return ops.true_divide(x_sum, nums - correction), res_mean
+
+ops.var_mean = var_mean
+
+
+# std_mean
+def std_mean(input_sm, axis=None, *, correction=1, keepdims=False):
+ """std_mean"""
+ output = var_mean(input_sm, axis, correction=correction, keepdims=keepdims)
+ return ops.pow(output[0], 0.5), output[1]
+
+ops.std_mean = std_mean
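+
+# Usage sketch (illustrative values): both helpers return a (dispersion, mean) pair,
+# with the default `correction=1` giving the Bessel-corrected sample statistic:
+#
+#     x = mindspore.Tensor([1., 2., 3., 4.])
+#     var, mean = ops.var_mean(x, 0)   # var ~ 1.667, mean = 2.5
+#     std, mean = ops.std_mean(x, 0)   # std ~ 1.291, mean = 2.5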
+
+
+# masked_fill
+def masked_fill(inputs, mask, value):
+ """patched masked_fill"""
+ masked_value = ops.fill(inputs.dtype, inputs.shape, value)
+ return ops.select(mask, masked_value, inputs)
+
+
+def _masked_fill(self, mask, value):
+ return masked_fill(self, mask, value)
+
+
+ops.masked_fill = masked_fill
+Tensor.masked_fill = _masked_fill
+StubTensor.masked_fill = _masked_fill
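+
+# Usage sketch (illustrative values): positions where the mask is True are overwritten
+# with `value`, everything else is kept:
+#
+#     x = mindspore.Tensor([1., 2., 3.])
+#     m = mindspore.Tensor([True, False, True])
+#     x.masked_fill(m, 0.0)   # [0., 2., 0.]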
+
+
+# ops.std
+def std(input_std, axis=None, ddof=0, keepdims=False):
+ """patched std"""
+ # Calculate mean
+ mean = ops.mean(input_std, axis=axis, keep_dims=keepdims)
+
+ # Squared differences from the mean
+ squared_diff = (input_std - mean)**2
+
+ # Sum along the specified dimension
+ if axis is not None:
+ sum_along_dim = ops.sum(squared_diff, dim=axis, keepdim=keepdims)
+ else:
+ sum_along_dim = squared_diff.sum()
+
+ # Calculate the correction factor
+ factor = 1.0 / (input_std.shape[axis] - ddof) if axis is not None else 1.0 / (input_std.size - ddof)
+
+ # Calculate the standard deviation
+ out = ops.sqrt(factor * sum_along_dim)
+
+ return out
+
+
+def _std(self, axis=None, ddof=0, keepdims=False):
+ return std(self, axis, ddof, keepdims)
+
+
+ops.std = std
+Tensor.std = _std
+StubTensor.std = _std
+
+
+# Tensor.__contains__
+def _contains(self, key):
+ eq_res = ops.equal(self, key)
+ res = ops.any(eq_res)
+ return bool(res)
+
+
+Tensor.__contains__ = _contains
+StubTensor.__contains__ = _contains
+
+
+def unflatten(self, dim, sizes):
+ """Tensor.unflatten"""
+ out_shape = _get_unflatten_size(self.shape, dim, sizes)
+ return self.reshape(out_shape)
+
+
+Tensor.unflatten = unflatten
+StubTensor.unflatten = unflatten
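+
+# Usage sketch (illustrative shape): `unflatten` splits one dimension into several,
+# the inverse of flattening those dimensions:
+#
+#     ops.ones((2, 6)).unflatten(1, (2, 3)).shape   # (2, 2, 3)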
+
+
+def _as_strided(self, size, stride, storage_offset=None):
+ """
+ replace as_strided
+ """
+ if len(size) != len(stride):
+ raise RuntimeError("mismatch in length of strides and shape.")
+ index = np.arange(0, size[0] * stride[0], stride[0])
+ for i in range(1, len(size)):
+ tmp = np.arange(0, size[i] * stride[i], stride[i])
+ index = np.expand_dims(index, -1)
+ index = index + tmp
+ if storage_offset is not None:
+ index = index + storage_offset
+ if index.size == 0:
+ input_indices = mindspore.numpy.empty(index.shape, dtype=mstype.int32)
+ else:
+ input_indices = Tensor(index)
+ out = ops.gather(self.reshape(-1), input_indices, 0)
+ return out
+
+
+Tensor.as_strided = _as_strided
+StubTensor.as_strided = _as_strided
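+
+# Usage sketch (illustrative values): the replacement `as_strided` builds a gather
+# index over the flattened tensor from `size`/`stride`, so for a row-major (2, 3)
+# tensor the strides (1, 3) read out its transpose:
+#
+#     x = mindspore.Tensor([[0., 1., 2.], [3., 4., 5.]])
+#     x.as_strided((3, 2), (1, 3))   # [[0., 3.], [1., 4.], [2., 5.]]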
+
+
+def _nonzero(self, as_tuple=False):
+ if self.dtype == mstype.bool_:
+ self = self.astype(mstype.int64)
+ outs = ops.nonzero(self)
+ if as_tuple:
+ outs = ops.tensor_split(outs, self.ndim, -1)
+ outs = tuple(out.squeeze(-1) for out in outs)
+ return outs
+
+
+Tensor.nonzero = _nonzero
+StubTensor.nonzero = _nonzero
+
+
+def _expand(self, *size):
+ if len(size) == 1:
+ size = size[0]
+ return ops.broadcast_to(self, size)
+
+
+Tensor.expand = _expand
+StubTensor.expand = _expand
+
+
+mindspore.tensor = mindspore.Tensor
+ops.prod = bool_patch_decorator(ops.prod)
+
+
+def _prod(self, axis=None, keep_dims=False):
+ return ops.prod(self, axis, keep_dims)
+Tensor.prod = _prod
+StubTensor.prod = _prod
+
+
+def _eq(self, other):
+ """
+ replace __eq__
+ """
+ if not isinstance(other, (int, float, Tensor)):
+ return False
+ if isinstance(other, Tensor) and self.shape != other.shape:
+ return False
+ if id(self) == id(other):
+ return True
+ # bool type is not supported for `Equal` operator in backend.
+ if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_):
+ self = self.to(mstype.int32)
+ other = other.to(mstype.int32)
+ return ops.eq(self, other)
+
+
+Parameter.__eq__ = _eq
+
+old_repeat = Tensor.repeat
+
+
+def new_repeat_interleave(input_ri, repeats, axis=None):
+ """new repeat_interleave"""
+ if axis is None:
+ input_ri = input_ri.reshape(-1)
+ axis = 0
+ if isinstance(repeats, Tensor):
+ repeats = repeats.asnumpy().tolist()
+ output = old_repeat(input_ri, repeats, axis)
+ return output
+
+
+ops.repeat_interleave = bool_io_patch_decorator(new_repeat_interleave)
+
+
+def _repeat_interleave(self, repeats, dim):
+ return old_repeat(self, repeats, axis=dim)
+
+
+Tensor.repeat_interleave = _repeat_interleave
+StubTensor.repeat_interleave = _repeat_interleave
+
+
+def _repeat(self, *sizes):
+ return ops.tile(self, tuple(sizes))
+
+
+Tensor.repeat = _repeat
+StubTensor.repeat = _repeat
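+
+# Usage sketch (illustrative values): after these patches `repeat` tiles the whole
+# tensor (PyTorch-style) while `repeat_interleave` repeats individual elements:
+#
+#     x = mindspore.Tensor([1, 2])
+#     x.repeat(3)                     # [1, 2, 1, 2, 1, 2]
+#     x.repeat_interleave(3, dim=0)   # [1, 1, 1, 2, 2, 2]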
+
+
+if LESS_MS_2_2:
+ mindspore.bfloat16 = None
+
+ def eq(self, other):
+ """patched eq"""
+ return ops.equal(self, other)
+
+ Tensor.eq = eq
+ StubTensor.eq = eq
+
+ def _item(self):
+ return self.asnumpy().item()
+ Tensor.item = _item
+ StubTensor.item = _item
+
+ def _tolist(self):
+ return self.asnumpy().tolist()
+ Tensor.tolist = _tolist
+ StubTensor.tolist = _tolist
+
+
+# For Cells
+class DenseMindnlp(nn.Cell):
+ """patched Dense"""
+ def __init__(self,
+ in_channels,
+ out_channels,
+ has_bias=True,
+ dtype=mstype.float32):
+ """Initialize Dense."""
+ super().__init__()
+ self.in_channels = Validator.check_positive_int(
+ in_channels, "in_channels", self.cls_name)
+ self.out_channels = Validator.check_positive_int(
+ out_channels, "out_channels", self.cls_name)
+ self.has_bias = Validator.check_bool(
+ has_bias, "has_bias", self.cls_name)
+
+ self.weight = Parameter(initializer(
+ HeUniform(math.sqrt(5)), [out_channels, in_channels], dtype=dtype), name="weight")
+
+ self.bias = None
+ if self.has_bias:
+ fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape)
+ bound = 1 / math.sqrt(fan_in)
+ self.bias = Parameter(initializer(
+ Uniform(bound), [out_channels], dtype=dtype), name="bias")
+
+ def construct(self, x):
+ """
+ construct method of DenseMindnlp
+ """
+ if LESS_MS_2_2:
+ x_shape = x.shape
+ if len(x_shape) != 2:
+ x = x.reshape(-1, x.shape[-1])
+ x = ops.matmul(x, self.weight.T)
+ if self.has_bias:
+ x = ops.add(x, self.bias)
+ if len(x_shape) != 2:
+ out_shape = x_shape[:-1] + (x.shape[-1],)
+ x = x.reshape(out_shape)
+ return x
+ return ops.dense(x, self.weight, self.bias)
+
+ def extend_repr(self):
+ s = f'input_channels={self.in_channels}, output_channels={self.out_channels}'
+ if self.has_bias:
+ s += f', has_bias={self.has_bias}'
+ return s
+
+
+class EmbeddingMindnlp(nn.Cell):
+ """patched Embedding"""
+ def __init__(self, vocab_size, embedding_size, padding_idx=None, use_one_hot=False, dtype=mstype.float32):
+ """Initialize Embedding."""
+ super().__init__()
+ self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name)
+ self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name)
+ Validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name)
+ Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name)
+ self.use_one_hot = use_one_hot
+ self.dtype = dtype
+ self.padding_idx = padding_idx
+ self.weight = Parameter(initializer(Normal(1.0), [vocab_size, embedding_size]), name='weight')
+        if self.padding_idx is not None and self.weight.init_flag:
+ self.weight[self.padding_idx] = 0
+
+ def construct(self, ids):
+ """
+ construct method of EmbeddingMindnlp
+ """
+ out_shape = ids.shape + (self.embedding_size,)
+ flat_ids = ids.reshape((-1,))
+
+ if self.use_one_hot:
+ one_hot_ids = ops.one_hot(flat_ids, self.vocab_size)
+ output_for_reshape = ops.matmul(one_hot_ids, self.weight)
+ else:
+ output_for_reshape = ops.gather(self.weight, flat_ids, 0)
+
+ output = output_for_reshape.reshape(out_shape)
+ return output
+
+ def extend_repr(self):
+ return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \
+ f'weight={self.weight}, dtype={self.dtype}, padding_idx={self.padding_idx}'
+
+
+class LayerNormMindnlp(nn.Cell):
+ r"""
+ Applies Layer Normalization over a mini-batch of inputs.
+ """
+
+ def __init__(self,
+ normalized_shape,
+ begin_norm_axis=-1,
+ begin_params_axis=-1,
+ gamma_init='ones',
+ beta_init='zeros',
+ epsilon=1e-5,
+ dtype=mstype.float32,
+ elementwise_affine=True
+ ):
+ """Initialize LayerNorm."""
+ super().__init__()
+ if isinstance(normalized_shape, int):
+ normalized_shape = [normalized_shape]
+ if not isinstance(normalized_shape, (tuple, list)):
+ raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], "
+ f"but got {normalized_shape} and the type is {type(normalized_shape)}.")
+ if not normalized_shape:
+ raise ValueError(
+ f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at "
+ f"least one element, but got normalized_shape = {normalized_shape}"
+ )
+ self.normalized_shape = normalized_shape
+ self.begin_norm_axis = begin_norm_axis
+ self.begin_params_axis = begin_params_axis
+ self.epsilon = epsilon
+ self.weight = Parameter(initializer(
+ gamma_init, normalized_shape, dtype=dtype), name="weight")
+ self.bias = Parameter(initializer(
+ beta_init, normalized_shape, dtype=dtype), name="bias")
+ self.layer_norm = ops.LayerNorm(begin_norm_axis=self.begin_norm_axis,
+ begin_params_axis=self.begin_params_axis,
+ epsilon=self.epsilon)
+ self.elementwise_affine = elementwise_affine
+
+ def construct(self, input_x):
+ """
+ construct method of LayerNormMindnlp
+ """
+ if self.elementwise_affine:
+ y, _, _ = self.layer_norm(input_x, self.weight.astype(input_x.dtype), self.bias.astype(input_x.dtype))
+ else:
+ y, _, _ = self.layer_norm(input_x, ops.ones(self.normalized_shape, input_x.dtype),
+ ops.zeros(self.normalized_shape, input_x.dtype),)
+ return y
+
+ def extend_repr(self):
+ return f'normalized_shape={self.normalized_shape}, begin_norm_axis={self.begin_norm_axis}, ' \
+ f'begin_params_axis={self.begin_params_axis}, weight={self.weight}, bias={self.bias}'
+
+
+class BatchNorm1dMindnlp(nn.Cell):
+ """Batch Normalization base class."""
+ def __init__(self,
+ num_features,
+ eps=1e-5,
+ momentum=0.9,
+ affine=True,
+ weight_init='ones',
+ bias_init='zeros',
+ moving_mean_init='zeros',
+ moving_var_init='ones',
+ use_batch_statistics=None,
+ dtype=mstype.float32):
+ """Initialize _BatchNorm."""
+ super().__init__()
+ if num_features < 1:
+ raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.")
+
+ if momentum < 0 or momentum > 1:
+ raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], "
+ f"but got {momentum}.")
+ self.use_batch_statistics = use_batch_statistics
+ if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool):
+ raise ValueError(f"For '{self.cls_name}', the 'use_batch_statistics' must be a boolean value or None,"
+ f" but got {use_batch_statistics}.")
+ self.num_features = num_features
+ self.eps = eps
+ self.moving_mean_init = moving_mean_init
+ self.moving_var_init = moving_var_init
+ self.running_mean = Parameter(initializer(
+ moving_mean_init, num_features, dtype=dtype), name="running_mean", requires_grad=False)
+ self.running_var = Parameter(initializer(
+ moving_var_init, num_features, dtype=dtype), name="running_var", requires_grad=False)
+ self.weight = Parameter(initializer(
+ weight_init, num_features, dtype=dtype), name="weight", requires_grad=affine)
+ self.bias = Parameter(initializer(
+ bias_init, num_features, dtype=dtype), name="bias", requires_grad=affine)
+
+ self.momentum = 1.0 - momentum
+
+ self.bn_train = ops.BatchNorm(is_training=True,
+ epsilon=self.eps,
+ momentum=self.momentum)
+
+ self.bn_infer = ops.BatchNorm(is_training=False, epsilon=self.eps)
+
+ def construct(self, x):
+ """
+ construct method of BatchNorm1dMindnlp
+ """
+ if self.use_batch_statistics is None:
+ if self.training:
+ return self.bn_train(x,
+ self.weight,
+ self.bias,
+ self.running_mean,
+ self.running_var)[0]
+ if not self.training:
+ return self.bn_infer(x,
+ self.weight,
+ self.bias,
+ self.running_mean,
+ self.running_var)[0]
+
+ if self.use_batch_statistics:
+ return self.bn_train(x,
+ self.weight,
+ self.bias,
+ self.running_mean,
+ self.running_var)[0]
+
+ return self.bn_infer(x,
+ self.weight,
+ self.bias,
+ self.running_mean,
+ self.running_var)[0]
+
+ def extend_repr(self):
+ return f'num_features={self.num_features}, eps={self.eps}, momentum={1.0 - self.momentum}, ' \
+ f'weight={self.weight}, bias={self.bias}, running_mean={self.running_mean}, ' \
+ f'running_var={self.running_var}'
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..21efd44a1f6e5c943067078d3f03d7cdad56ac7d
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py
@@ -0,0 +1,1037 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Logits process"""
+
+import math
+import inspect
+import logging
+from typing import List, Optional
+import mindspore
+from mindspore import ops
+import numpy as np
+
+
+class LogitsProcessor:
+ """Abstract base class for all logit processors that can be applied during generation."""
+
+ def __call__(self, input_ids, scores):
+ """Method for processing logits."""
+ raise NotImplementedError(
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+ )
+
+
+class LogitsWarper:
+ """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling."""
+
+ def __call__(self, input_ids, scores):
+ """Method for warping logits."""
+ raise NotImplementedError(
+ f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
+ )
+
+
+class LogitsProcessorList(list):
+ """
+ This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a
+ `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each
+ [`LogitsProcessor`] or [`LogitsWarper`] to the inputs.
+ """
+
+ def __call__(self, input_ids, scores, **kwargs):
+ for processor in self:
+ function_args = inspect.signature(processor.__call__).parameters
+ if len(function_args) > 2:
+ if not all(arg in kwargs for arg in list(function_args.keys())[2:]):
+ raise ValueError(
+ f"Make sure that all the required parameters: {list(function_args.keys())} for "
+ f"{processor.__class__} are passed to the logits processor."
+ )
+ scores = processor(input_ids, scores, **kwargs)
+ else:
+ scores = processor(input_ids, scores)
+ return scores
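+
+# A minimal usage sketch (hypothetical `input_ids`/`next_token_logits`): chaining a
+# processor and a warper defined later in this module for one decoding step:
+#
+#     processors = LogitsProcessorList([
+#         RepetitionPenaltyLogitsProcessor(penalty=1.2),
+#         TemperatureLogitsWarper(temperature=0.8),
+#     ])
+#     next_token_scores = processors(input_ids, next_token_logits)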
+
+
+class HammingDiversityLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that enforces diverse beam search. Note that this logits processor is only effective for
+ [`PreTrainedModel.group_beam_search`]. See [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence
+ Models](https://arxiv.org/pdf/1610.02424.pdf) for more details.
+ """
+
+ def __init__(self, diversity_penalty, num_beams, num_beam_groups):
+ if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0):
+ raise ValueError("`diversity_penalty` should be a float strictly larger than 0.")
+ self._diversity_penalty = diversity_penalty
+ if not isinstance(num_beams, int) or num_beams < 2:
+ raise ValueError("`num_beams` should be an integer strictly larger than 1.")
+ self._num_beams = num_beams
+ if not isinstance(num_beam_groups, int) or num_beam_groups < 2:
+ raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.")
+ if num_beam_groups > num_beams:
+            raise ValueError("`num_beam_groups` has to be smaller than or equal to `num_beams`.")
+ self._num_sub_beams = num_beams // num_beam_groups
+
+ def __call__(self, input_ids, scores, current_tokens, beam_group_idx):
+ # hamming diversity: penalise using same token in current group which was used in previous groups at
+ # the same time step
+ batch_size = current_tokens.shape[0] // self._num_beams
+ group_start_idx = beam_group_idx * self._num_sub_beams
+ group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams)
+ group_size = group_end_idx - group_start_idx
+ vocab_size = scores.shape[-1]
+
+ if group_start_idx == 0:
+ return scores
+
+ for batch_idx in range(batch_size):
+ # predicted tokens of last time step of previous groups
+ previous_group_tokens = current_tokens[
+ batch_idx * self._num_beams: batch_idx * self._num_beams + group_start_idx
+ ]
+ token_frequency = ops.bincount(previous_group_tokens, minlength=vocab_size)
+ scores[batch_idx * group_size: (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency
+
+ return scores
+
+
+class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input.
+ """
+
+ def __init__(self, penalty, encoder_input_ids):
+ if not isinstance(penalty, float) or (penalty <= 0):
+ raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+ self.penalty = 1 / penalty
+ self.encoder_input_ids = encoder_input_ids
+
+ def __call__(self, input_ids, scores):
+ score = ops.gather_elements(scores, 1, self.encoder_input_ids)
+
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+ score = ops.where(score < 0, score * self.penalty, score / self.penalty)
+
+        scores = ops.tensor_scatter_elements(scores, self.encoder_input_ids, score, axis=1)
+        return scores
+
+
+class RepetitionPenaltyLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences.
+ """
+
+ def __init__(self, penalty):
+ if not isinstance(penalty, float) or (penalty <= 0):
+ raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
+
+ self.penalty = penalty
+
+ def __call__(self, input_ids, scores):
+ input_ids = ops.where(input_ids >= scores.shape[1], input_ids - scores.shape[1], input_ids)
+ score = ops.gather_elements(scores, 1, input_ids)
+
+ # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability
+ score = ops.where(score < 0, score * self.penalty, score / self.penalty)
+
+ scores = ops.tensor_scatter_elements(scores, input_ids, score, axis=1)
+ return scores
+
+
+def _get_ngrams(ngram_size: int, prev_input_ids: mindspore.Tensor, num_hypos: int):
+ generated_ngrams = [{} for _ in range(num_hypos)]
+ for idx in range(num_hypos):
+ gen_tokens = prev_input_ids[idx].asnumpy().tolist()
+ generated_ngram = generated_ngrams[idx]
+ for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
+ prev_ngram_tuple = tuple(ngram[:-1])
+ generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
+ return generated_ngrams
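+
+# Worked example (illustrative): for a single hypothesis with tokens [5, 7, 5, 7] and
+# ngram_size=2, `_get_ngrams` returns [{(5,): [7, 7], (7,): [5]}], i.e. each (n-1)-gram
+# prefix maps to the tokens that have followed it so far.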
+
+
+def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
+ # Before decoding the next token, prevent decoding of ngrams that have already appeared
+ start_idx = cur_len + 1 - ngram_size
+ ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
+ return banned_ngrams.get(ngram_idx, [])
+
+
+def _calc_banned_ngram_tokens(ngram_size, prev_input_ids, num_hypos, cur_len):
+ """Copied from fairseq for no_repeat_ngram in beam_search"""
+ if cur_len + 1 < ngram_size:
+ # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
+ return [[] for _ in range(num_hypos)]
+
+ generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
+
+ banned_tokens = [
+ _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
+ for hypo_idx in range(num_hypos)
+ ]
+ return banned_tokens
+
+
+class NoRepeatNGramLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that enforces no repetition of n-grams.
+ """
+
+ def __init__(self, ngram_size):
+ if not isinstance(ngram_size, int) or ngram_size <= 0:
+ raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
+ self.ngram_size = ngram_size
+
+ def __call__(self, input_ids, scores):
+ num_batch_hypotheses = scores.shape[0]
+ cur_len = input_ids.shape[-1]
+ banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len)
+
+ for i, banned_tokens in enumerate(banned_batch_tokens):
+ scores[i, banned_tokens] = -float("inf")
+
+ return scores
+
+
+class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that enforces no repetition of encoder input ids n-grams for the decoder ids.
+ """
+
+ def __init__(self, encoder_ngram_size, encoder_input_ids):
+ if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0:
+ raise ValueError(
+ f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}"
+ )
+ self.ngram_size = encoder_ngram_size
+ if len(encoder_input_ids.shape) == 1:
+ encoder_input_ids = encoder_input_ids.unsqueeze(0)
+ self.batch_size = encoder_input_ids.shape[0]
+ self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size)
+
+ def __call__(self, input_ids, scores):
+ # B x num_beams
+ num_hypos = scores.shape[0]
+ num_beams = num_hypos // self.batch_size
+ cur_len = input_ids.shape[-1]
+ banned_batch_tokens = [
+ _get_generated_ngrams(
+ self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len
+ )
+ for hypo_idx in range(num_hypos)
+ ]
+
+ for i, banned_tokens in enumerate(banned_batch_tokens):
+ scores[i, banned_tokens] = -float("inf")
+
+ return scores
+
+
+class NoBadWordsLogitsProcessor(LogitsProcessor):
+ """
+ [`LogitsProcessor`] that enforces that specified sequences will never be sampled.
+ """
+
+ def __init__(self, bad_words_ids, eos_token_id):
+ if not isinstance(bad_words_ids, List) or not bad_words_ids:
+ raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
+ if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
+ raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
+ if any(
+ any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
+ for bad_word_ids in bad_words_ids
+ ):
+ raise ValueError(
+ f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
+ )
+
+ if eos_token_id is None:
+ eos_token_id = []
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+
+ bad_words_ids = list(
+ filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids)
+ )
+ self.bad_words_id_length_1 = []
+ self.bad_words_id_length_greater_than_1 = []
+ for word in bad_words_ids:
+ if len(word) == 1:
+ self.bad_words_id_length_1.append(word[0])
+ else:
+ self.bad_words_id_length_greater_than_1.append(word)
+
+ self.static_bad_words_mask: Optional[mindspore.Tensor] = None
+
+ for banned_token_seq in self.bad_words_id_length_greater_than_1:
+ if not banned_token_seq:
+ raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
+
+ def __call__(self, input_ids, scores):
+        if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
+ self.static_bad_words_mask = self._calc_static_bad_word_mask(scores)
+
+ dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist())
+ scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens)
+
+ return scores
+
+ def _calc_static_bad_word_mask(self, scores):
+ static_bad_words_mask = ops.zeros(scores.shape[1])
+ static_bad_words_mask[self.bad_words_id_length_1] = 1
+ return static_bad_words_mask.unsqueeze(0).bool()
+
+ def _tokens_match(self, prev_tokens, tokens):
+ if not tokens:
+ # if bad word tokens is just one token always ban it
+ return True
+ if len(tokens) > len(prev_tokens):
+ # if bad word tokens are longer then prev input_ids they can't be equal
+ return False
+ return prev_tokens[-len(tokens):] == tokens
+
+ def _calc_banned_bad_words_ids(self, prev_input_ids):
+ """
+ calculate banned bad words ids
+ """
+ banned_tokens = []
+ for prev_input_ids_slice in prev_input_ids:
+ banned_tokens_slice = []
+ for banned_token_seq in self.bad_words_id_length_greater_than_1:
+ if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]):
+ banned_tokens_slice.append(banned_token_seq[-1])
+
+ banned_tokens.append(banned_tokens_slice)
+
+ return banned_tokens
+
+ def _set_scores_to_inf_for_banned_tokens(self, scores, banned_tokens):
+ """
+        Modifies the scores in place by setting the banned token positions to `-inf`. Banned tokens are expected to be
+        a list of lists in the format `[[batch index, vocabulary position], ...]`.
+ """
+ banned_mask_list = []
+ for idx, batch_banned_tokens in enumerate(banned_tokens):
+ for token in batch_banned_tokens:
+ # Eliminates invalid bad word IDs that are over the vocabulary size.
+ if token <= scores.shape[1]:
+ banned_mask_list.append([idx, token])
+ else:
+ logging.error(
+ "An invalid bad word ID is defined: %d. This ID is not contained in the "
+ "vocabulary, and is therefore ignored.", token
+ )
+ if not banned_mask_list and self.static_bad_words_mask is None:
+ return scores
+
+ if banned_mask_list:
+ banned_mask = mindspore.Tensor(banned_mask_list)
+ indices = ops.ones(len(banned_mask))
+ # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. A conversion to dense tensor generates:
+ # [ 0 1 1 ]
+ # [ 0 0 0 ]
+ # [ 1 0 0 ]
+
+ banned_mask = (
+ mindspore.COOTensor(banned_mask,
+ indices, scores.shape)
+ .to_dense()
+ .bool()
+ )
+
+ if self.static_bad_words_mask is not None:
+ banned_mask = ops.bitwise_or(banned_mask, self.static_bad_words_mask)
+ else:
+ banned_mask = self.static_bad_words_mask
+
+ scores = scores.masked_fill(banned_mask, -float("inf"))
+ return scores
+
+
+class MinLengthLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
+ """
+
+ def __init__(self, min_length, eos_token_id):
+ if not isinstance(min_length, int) or min_length < 0:
+ raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
+
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id):
+ raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}")
+
+ self.min_length = min_length
+ self.eos_token_id = eos_token_id
+
+ def __call__(self, input_ids, scores):
+ vocab_tensor = ops.arange(scores.shape[-1])
+ eos_token_id = mindspore.tensor(self.eos_token_id)
+ eos_token_mask = vocab_tensor == eos_token_id
+ scores_processed = scores.copy()
+ if input_ids.shape[-1] < self.min_length:
+ scores_processed = ops.where(eos_token_mask, -math.inf, scores)
+ return scores_processed
+
+
+class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
+ """
+
+ def __init__(self, prompt_length_to_skip, min_new_tokens, eos_token_id):
+ for arg_name, arg_value in [
+ ("prompt_length_to_skip", prompt_length_to_skip),
+ ("min_new_tokens", min_new_tokens),
+ ("eos_token_id", eos_token_id),
+ ]:
+ if not isinstance(arg_value, int) or arg_value < 0:
+ raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}")
+
+ self.prompt_length_to_skip = prompt_length_to_skip
+ self.min_new_tokens = min_new_tokens
+ self.eos_token_id = eos_token_id
+
+ def __call__(self, input_ids, scores):
+ new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip
+ if new_tokens_length < self.min_new_tokens:
+ scores[:, self.eos_token_id] = -float("inf")
+
+ return scores
+
+
+class PrefixConstrainedLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
+ generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information.
+ """
+
+ def __init__(self, prefix_allowed_tokens_fn, num_beams):
+ self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn
+ self._num_beams = num_beams
+
+ def __call__(self, input_ids, scores):
+ mask = ops.full_like(scores, -math.inf)
+ for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])):
+ for beam_id, sent in enumerate(beam_sent):
+ mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0
+
+ return scores + mask
+
+
+class ForcedBOSTokenLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that enforces the specified token as the first generated token.
+
+ Args:
+ bos_token_id (`int`):
+ The id of the token to force as the first generated token.
+ """
+
+ def __init__(self, bos_token_id):
+ self.bos_token_id = bos_token_id
+
+ def __call__(self, input_ids, scores):
+ cur_len = input_ids.shape[-1]
+ if cur_len == 1:
+ num_tokens = scores.shape[1]
+ scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf")
+ scores[:, self.bos_token_id] = 0
+ return scores
+
+
+class ForcedEOSTokenLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.
+ """
+
+ def __init__(self, max_length, eos_token_id):
+ self.max_length = max_length
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ self.eos_token_id = eos_token_id
+
+ def __call__(self, input_ids, scores):
+ cur_len = input_ids.shape[-1]
+ if cur_len == self.max_length - 1:
+ num_tokens = scores.shape[1]
+ scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = \
+ float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min)
+ for i in self.eos_token_id:
+ scores[:, i] = 0
+ return scores
+
+
+class InfNanRemoveLogitsProcessor(LogitsProcessor):
+ r"""
+    [`LogitsProcessor`] that removes all `nan` and `inf` values to keep the generation method from failing. Note that
+    this logits processor should only be used if necessary since it can slow down the generation method.
+ """
+
+ def __call__(self, input_ids, scores):
+ # set all nan values to 0.0
+ scores[ops.isnan(scores)] = 0.0
+
+ # set all inf values to max possible value
+ scores[scores == float("inf")] = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).max)
+ return scores
+
+
+class ExponentialDecayLengthPenalty(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been
+ reached.
+ """
+
+ def __init__(self, exponential_decay_length_penalty, eos_token_id, input_ids_seq_length):
+ self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length
+ self.regulation_factor = exponential_decay_length_penalty[1]
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ self.eos_token_id = eos_token_id
+
+ def __call__(self, input_ids, scores):
+ cur_len = input_ids.shape[-1]
+ if cur_len > self.regulation_start:
+ for i in self.eos_token_id:
+ scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start)
+ return scores
+
+
+class SuppressTokensLogitsProcessor(LogitsProcessor):
+ r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
+ are not sampled."""
+
+ def __init__(self, suppress_tokens):
+ self.suppress_tokens = list(suppress_tokens)
+
+ def __call__(self, input_ids, scores):
+ scores[:, self.suppress_tokens] = -float("inf")
+ return scores
+
+
+class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor):
+ r"""
+ [`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
+    generating using `begin_index` tokens. This ensures that the tokens defined by `begin_suppress_tokens` are not
+    sampled at the beginning of the generation.
+ """
+
+ def __init__(self, begin_suppress_tokens, begin_index):
+ self.begin_suppress_tokens = list(begin_suppress_tokens)
+ self.begin_index = begin_index
+
+ def __call__(self, input_ids, scores):
+ if input_ids.shape[1] == self.begin_index:
+ scores[:, self.begin_suppress_tokens] = -float("inf")
+
+ return scores
+
+
+class ForceTokensLogitsProcessor(LogitsProcessor):
+ r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
+ indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are
+ sampled at their corresponding index."""
+
+ def __init__(self, force_token_map):
+ self.force_token_map = dict(force_token_map)
+
+ def __call__(self, input_ids, scores):
+ generation_idx = input_ids.shape[-1]
+ current_token = self.force_token_map.get(generation_idx, None)
+ if current_token is not None:
+ scores[:, :] = -float("inf")
+ scores[:, current_token] = 0
+ return scores
+
+
+class LogitNormalization(LogitsProcessor, LogitsWarper):
+ r"""
+    [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize
+    the scores during beam search, after applying the logits processors or warpers, because the search algorithm used
+    in this library assumes the scores are normalized when comparing hypotheses but does not re-normalize them itself.
+ """
+
+ def __call__(self, input_ids, scores):
+ scores = ops.log_softmax(scores, axis=-1)
+ return scores
+
+
+class TemperatureLogitsWarper(LogitsWarper):
+ r"""
+ [`TemperatureLogitsWarper`] for temperature (exponential scaling output probability distribution).
+ Args:
+ temperature (:obj:`float`):
+            The value used to modulate the logits distribution.
+ """
+
+ def __init__(self, temperature):
+ if not isinstance(temperature, float) or not temperature > 0:
+ raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
+
+ self.temperature = temperature
+
+ def __call__(self, input_ids, scores):
+ scores = scores / self.temperature
+ return scores
+
+
+class TopPLogitsWarper(LogitsWarper):
+ """
+    [`LogitsWarper`] that performs top-p (nucleus) filtering, i.e. restricting sampling to the smallest set of the
+    most probable tokens whose probabilities add up to `top_p` or higher.
+ """
+
+ def __init__(self, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
+ top_p = float(top_p)
+ if top_p < 0 or top_p > 1.0:
+ raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}")
+ if not isinstance(min_tokens_to_keep, int) or min_tokens_to_keep < 1:
+ raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
+
+ self.top_p = top_p
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids, scores):
+ if self.filter_value == -float("Inf"):
+ self.filter_value = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min)
+ # scores = ops.select(ops.isneginf(scores), mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min), scores)
+ sorted_logits, sorted_indices = ops.sort(scores, descending=False)
+ cumulative_probs = ops.softmax(sorted_logits, axis=-1).cumsum(axis=-1)
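+        # Worked example (assumed values): for a 4-token vocabulary with ascending sorted probs [0.1, 0.2, 0.3, 0.4]
+        # and top_p=0.9, the cumulative sums are [0.1, 0.3, 0.6, 1.0]; only the first entry is <= 1 - 0.9, so only
+        # the least probable token is dropped.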
+
+        # With an ascending sort, drop tokens whose cumulative probability falls within the lowest (1 - top_p) mass
+ sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p)
+
+        # Always keep at least `min_tokens_to_keep` of the highest-probability (last-sorted) tokens
+        sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0
+
+ if isinstance(sorted_indices_to_remove[0][0].item(), bool):
+ sorted_indices_to_remove = sorted_indices_to_remove.astype("int32")
+
+ indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+
+ if isinstance(indices_to_remove[0][0].item(), int):
+ indices_to_remove = indices_to_remove.astype("bool")
+
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
+ return scores
+
+
+class TopKLogitsWarper(LogitsWarper):
+ r"""
+ [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.
+ """
+
+ def __init__(self, top_k, filter_value=-float("Inf"), min_tokens_to_keep=1):
+ if not isinstance(top_k, int) or top_k <= 0:
+ raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
+
+ self.top_k = max(top_k, min_tokens_to_keep)
+ self.filter_value = filter_value
+
+ def __call__(self, input_ids, scores):
+ if self.filter_value == -float("Inf"):
+ self.filter_value = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min)
+ top_k = min(self.top_k, scores.shape[-1]) # Safety check
+ # Remove all tokens with a probability less than the last token of the top-k
+ indices_to_remove = scores < ops.topk(scores, top_k)[0][..., -1, None]
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
+ return scores
+
+
+class TypicalLogitsWarper(LogitsWarper):
+ r"""
+ [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language
+ Generation](https://arxiv.org/abs/2202.00666) for more information.
+ """
+
+ def __init__(self, mass=0.9, filter_value=-float("Inf"), min_tokens_to_keep=1):
+ mass = float(mass)
+ if mass <= 0 or mass >= 1:
+ raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}")
+ if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
+ raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
+
+ self.filter_value = filter_value
+ self.mass = mass
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids, scores):
+ # calculate entropy
+ normalized = ops.log_softmax(scores, axis=-1)
+ p = ops.exp(normalized)
+        ent = -(normalized * p).nansum(-1, keepdims=True)
+
+ # shift and sort
+ shifted_scores = ops.abs((-normalized) - ent)
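+        # |-log p(x) - H| measures how far each token's information content deviates from the distribution's
+        # entropy; tokens closest to the "typical" information content are kept first.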
+ sorted_scores, sorted_indices = ops.sort(shifted_scores, descending=False)
+ sorted_logits = scores.gather(-1, sorted_indices)
+ cumulative_probs = sorted_logits.softmax(axis=-1).cumsum(axis=-1)
+
+ # Remove tokens with cumulative mass above the threshold
+        last_ind = (cumulative_probs < self.mass).astype(mindspore.int32).sum(axis=1)
+        last_ind = ops.minimum(last_ind, sorted_scores.shape[-1] - 1)
+ sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1))
+ sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0
+ indices_to_remove = ops.tensor_scatter_elements(sorted_indices_to_remove,
+ sorted_indices, sorted_indices_to_remove, axis=1)
+
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
+ return scores
+
+
+class EpsilonLogitsWarper(LogitsWarper):
+ r"""
+ [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the
+ largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model
+ Desmoothing](https://arxiv.org/abs/2210.15191) for more information.
+ """
+
+ def __init__(self, epsilon, filter_value=-float("Inf"), min_tokens_to_keep=1):
+ epsilon = float(epsilon)
+ if epsilon <= 0 or epsilon >= 1:
+ raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
+
+ min_tokens_to_keep = int(min_tokens_to_keep)
+ if min_tokens_to_keep < 1:
+ raise ValueError(
+ f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
+ )
+
+ self.epsilon = epsilon
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids, scores):
+ # Determine which indices to remove
+ probabilities = ops.softmax(scores, axis=-1)
+ indices_to_remove = probabilities < self.epsilon
+
+ # Keep the words with the 'min_tokens_to_keep'-highest probabilities
+        top_k = min(self.min_tokens_to_keep, scores.shape[-1])  # Safety check
+ indices_to_remove = indices_to_remove & (scores < ops.topk(scores, top_k)[0][..., -1, None])
+
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
+ return scores
+
+
+class EtaLogitsWarper(LogitsWarper):
+ r"""
+ [`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic
+ cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of
+ the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest
+ min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long
+ samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation
+ Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample`
+ must be set to `True` for this `LogitsWarper` to work.
+ """
+
+ def __init__(self, epsilon, filter_value=-float("Inf"), min_tokens_to_keep=1):
+ epsilon = float(epsilon)
+ if epsilon <= 0 or epsilon >= 1:
+ raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}")
+
+ min_tokens_to_keep = int(min_tokens_to_keep)
+ if min_tokens_to_keep < 1:
+ raise ValueError(
+ f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}"
+ )
+
+ self.epsilon = mindspore.tensor(epsilon, mindspore.float32)
+ self.filter_value = filter_value
+ self.min_tokens_to_keep = min_tokens_to_keep
+
+ def __call__(self, input_ids, scores):
+        # Calculate the adaptive cutoff
+        probabilities = ops.softmax(scores, axis=-1)
+        # Entropy of the next-token distribution, computed directly from the (log-)softmax
+        log_probabilities = ops.log_softmax(scores, axis=-1)
+        entropy = -(probabilities * log_probabilities).sum(axis=-1)
+        eta = ops.minimum(self.epsilon, ops.sqrt(self.epsilon) * ops.exp(-entropy))[..., None]
+        indices_to_remove = probabilities < eta
+
+        # Keep the words with the 'min_tokens_to_keep'-highest probabilities
+        top_k = min(self.min_tokens_to_keep, scores.shape[-1])  # Safety check
+ indices_to_remove = indices_to_remove & (scores < ops.topk(scores, top_k)[0][..., -1, None])
+
+ scores = scores.masked_fill(indices_to_remove, self.filter_value)
+ return scores
+
+
+class SequenceBiasLogitsProcessor(LogitsProcessor):
+ """
+ [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence
+    when the next generated token can complete it. Consequently, to get the most out of biasing sequences with more than
+ one token, consider using beam methods (to gracefully work around partially completed sequences that have a
+ negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier).
+ """
+
+ def __init__(self, sequence_bias):
+ self.sequence_bias = sequence_bias
+ self._validate_arguments()
+
+        # Bias variables that will be populated on the first call (for backwards compatibility, the vocabulary size
+        # is inferred on first use, which prevents initializing them here)
+ self.length_1_bias = None
+ self.prepared_bias_variables = False
+
+ def __call__(self, input_ids, scores):
+ # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called.
+ if not self.prepared_bias_variables:
+ self._prepare_bias_variables(scores)
+
+ # 2 - prepares an empty bias to add
+ bias = ops.zeros_like(scores)
+
+ # 3 - include the bias from length = 1
+ bias += self.length_1_bias
+
+ # 4 - include the bias from length > 1, after determining which biased sequences may be completed.
+ for sequence_ids, sequence_bias in self.sequence_bias.items():
+ if len(sequence_ids) == 1: # the sequence is of length 1, already applied
+ continue
+ if len(sequence_ids) > input_ids.shape[1]: # the sequence is longer than the context, ignore
+ continue
+ prefix_length = len(sequence_ids) - 1
+ last_token = sequence_ids[-1]
+ matching_rows = ops.eq(
+ input_ids[:, -prefix_length:],
+ mindspore.tensor(sequence_ids[:-1], dtype=input_ids.dtype),
+            ).prod(axis=1)
+ bias[:, last_token] += ops.where(
+ matching_rows.bool(),
+ mindspore.tensor(sequence_bias),
+ mindspore.tensor(0.0),
+ )
+
+ # 5 - apply the bias to the scores
+ scores = scores + bias
+ return scores
+
+    def _prepare_bias_variables(self, scores):
+        # Minimal helper (reconstructed, as the method is called above but was missing): build the single-token
+        # bias vector once the vocabulary size is known from the scores.
+        vocabulary_size = scores.shape[-1]
+        self.length_1_bias = ops.zeros((vocabulary_size,), scores.dtype)
+        for sequence_ids, bias in self.sequence_bias.items():
+            if len(sequence_ids) == 1:
+                self.length_1_bias[sequence_ids[-1]] = bias
+        self.prepared_bias_variables = True
+
+    def _validate_arguments(self):
+        # Minimal helper (reconstructed): `sequence_bias` must be a non-empty dict of token-id tuples to float biases.
+        if not isinstance(self.sequence_bias, dict) or not self.sequence_bias:
+            raise ValueError(f"`sequence_bias` has to be a non-empty dictionary, but is {self.sequence_bias}.")
+        if any(not isinstance(ids, tuple) or not ids for ids in self.sequence_bias):
+            raise ValueError("`sequence_bias` keys have to be non-empty tuples of token ids.")
+        if any(not isinstance(bias, float) for bias in self.sequence_bias.values()):
+            raise ValueError("`sequence_bias` values have to be floats.")
+
+
+class AlternatingCodebooksLogitsProcessor(LogitsProcessor):
+ r"""
+ [`LogitsProcessor`] enforcing alternated generation between the two codebooks of [`Bark`]'s fine submodel.
+ """
+
+ def __init__(self, input_start_len, semantic_vocab_size, codebook_size):
+ if not isinstance(input_start_len, int) or input_start_len < 0:
+ raise ValueError(f"`input_starting_length` has to be a non-negative integer, but is {input_start_len}")
+
+ self.input_start_len = input_start_len
+ self.semantic_vocab_size = semantic_vocab_size
+ self.codebook_size = codebook_size
+
+ def __call__(self, input_ids, scores):
+ curr_len = input_ids.shape[-1]
+
+ # even -> first codebook, odd -> second codebook
+ is_first_codebook = ((curr_len - self.input_start_len) % 2) == 0
+
+ if is_first_codebook:
+ scores[:, : self.semantic_vocab_size] = -float("inf")
+ scores[:, self.semantic_vocab_size + self.codebook_size :] = -float("inf")
+ else:
+ scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf")
+
+ return scores
+
+
+class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
+ r"""Logits processor for Classifier-Free Guidance (CFG). The processors
+ computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits,
+ parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with
+ the `unconditional_ids` branch.
+
+ See [the paper](https://arxiv.org/abs/2306.17806) for more information.
+ """
+
+ def __init__(self, guidance_scale, model, unconditional_ids, unconditional_attention_mask, use_cache):
+ self.guidance_scale = guidance_scale
+ self.model = model
+ self.unconditional_context = {
+ "input_ids": unconditional_ids,
+ "attention_mask": unconditional_attention_mask,
+ "use_cache": use_cache,
+ "past_key_values": None,
+ "first_pass": True,
+ }
+
+ def get_unconditional_logits(self, input_ids):
+ """get_unconditional_logits"""
+ if self.unconditional_context["first_pass"]:
+ if self.unconditional_context["input_ids"] is None:
+ self.unconditional_context["input_ids"] = input_ids[:, -1:]
+ if self.unconditional_context["attention_mask"] is None:
+ self.unconditional_context["attention_mask"] = ops.ones_like(
+ self.unconditional_context["input_ids"], dtype=mindspore.int64
+ )
+ input_ids = self.unconditional_context["input_ids"]
+ attention_mask = self.unconditional_context["attention_mask"]
+ self.unconditional_context["first_pass"] = False
+ else:
+ attention_mask = ops.cat(
+ [
+ self.unconditional_context["attention_mask"],
+ ops.ones_like(input_ids[:, -1:], dtype=mindspore.int64),
+ ],
+ axis=1,
+ )
+ if not self.unconditional_context["use_cache"]:
+ input_ids = ops.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], axis=1)
+ else:
+ input_ids = input_ids[:, -1:]
+ self.unconditional_context["input_ids"] = input_ids
+ self.unconditional_context["attention_mask"] = attention_mask
+
+ out = self.model(
+ input_ids,
+ attention_mask=attention_mask,
+ use_cache=self.unconditional_context["use_cache"],
+ past_key_values=self.unconditional_context["past_key_values"],
+ )
+ self.unconditional_context["past_key_values"] = out.get("past_key_values", None)
+
+ return out.logits
+
+ def __call__(self, input_ids, scores):
+ scores = ops.log_softmax(scores, axis=-1)
+ if self.guidance_scale == 1:
+ return scores
+
+ logits = self.get_unconditional_logits(input_ids)
+
+ unconditional_logits = ops.log_softmax(logits[:, -1], axis=-1)
+ out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
+ return out
+
+
+class WhisperTimeStampLogitsProcessor(LogitsProcessor):
+ r"""
+
+ [`LogitsProcessor`] that modifies the logits for the generation of timestamps in the transcription. When the input
+ tokens are at a specific threshold, the processor sets the scores to negative infinity. The processor makes sure
+ that timestamp tokens appear in pairs, by masking out the logits that would break this pairing pattern. This is
+ done to maintain the consistency and structure of generated timestamps. It also ensures that when the predicted
+ probability of sampling any of the timestamp token is greater than any individual non-timestamp token, those
+ non-timestamp logits are set to negative infinity. This is done to ensure the generation of timestamps over other
+ potential tokens.
+
+
+ See [the paper](https://arxiv.org/abs/2212.04356) for more information.
+ """
+
+ def __init__(self, generate_config): # support for the kwargs
+ self.eos_token_id = generate_config.eos_token_id
+ self.no_timestamps_token_id = generate_config.no_timestamps_token_id
+ self.timestamp_begin = generate_config.no_timestamps_token_id + 1
+
+ self.begin_index = len(generate_config.forced_decoder_ids) + 2
+ if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id:
+ self.begin_index -= 1
+ self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index
+
+ def __call__(self, input_ids, scores):
+ # suppress <|notimestamps|> which is handled by without_timestamps
+ scores[:, self.no_timestamps_token_id] = -float("inf")
+
+ if input_ids.shape[1] == self.begin_index - 1:
+ scores[:, :] = -float("inf")
+ scores[:, self.timestamp_begin] = 0
+ return scores
+
+ # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly
+ for k in range(input_ids.shape[0]):
+ seq = list(input_ids[k, self.begin_index :].tolist())
+ last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin
+ penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin
+
+ if last_was_timestamp:
+ if penultimate_was_timestamp: # has to be non-timestamp
+ scores[k, self.timestamp_begin :] = -float("inf")
+ else: # cannot be normal text tokens
+ scores[k, : self.eos_token_id] = -float("inf")
+
+ # apply the `max_initial_timestamp` option
+ if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None:
+ last_allowed = self.timestamp_begin + self.max_initial_timestamp_index
+ scores[:, last_allowed + 1 :] = -float("inf")
+
+ # if sum of probability over timestamps is above any other token, sample timestamp
+ logprobs = ops.log_softmax(scores.float(), axis=-1)
+ for k in range(input_ids.shape[0]):
+ timestamp_logprob = logprobs[k, self.timestamp_begin :].logsumexp(axis=-1)
+ max_text_token_logprob = logprobs[k, : self.timestamp_begin].max()
+ if not ops.isnan(timestamp_logprob) and timestamp_logprob > max_text_token_logprob:
+ scores[k, : self.timestamp_begin] = -float("inf")
+
+ return scores
+
+
+class BarkEosPrioritizerLogitsProcessor(LogitsProcessor):
+ r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`.
+ """
+
+ def __init__(self, eos_token_id, min_eos_p):
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ self.eos_token_id = eos_token_id
+ if min_eos_p is not None and min_eos_p <= 0:
+ raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}")
+ self.min_eos_p = min_eos_p
+
+ def __call__(self, input_ids, scores):
+ if self.min_eos_p:
+ probs = ops.softmax(scores.float(), axis=-1)
+ # create scores full of -inf except for the eos_token_id
+ early_stop_scores = ops.ones_like(scores) * -float("inf")
+ early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id]
+
+ do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p
+ do_early_stop = ops.any(do_early_stop, axis=1, keep_dims=True)
+ scores = ops.where(do_early_stop, early_stop_scores, scores)
+
+ return scores
+
+
+class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor):
+ r"""
+    [`LogitsProcessor`] for classifier-free guidance (CFG). The scores are split over the batch dimension,
+    where the first half corresponds to the conditional logits (predicted from the input prompt) and the second half
+    corresponds to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a
+    weighted average across the conditional and unconditional logits, parameterized by the `guidance_scale`.
+
+ See [the paper](https://arxiv.org/abs/2306.05284) for more information.
+ """
+
+ def __init__(self, guidance_scale):
+ if guidance_scale > 1:
+ self.guidance_scale = guidance_scale
+ else:
+ raise ValueError(
+ "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale "
+ f"{guidance_scale}."
+ )
+
+ def __call__(self, input_ids, scores):
+ # simple check to make sure we have compatible batch sizes between our
+ # logits scores (cond + uncond) and input ids (cond only)
+ if scores.shape[0] != 2 * input_ids.shape[0]:
+ raise ValueError(
+ f"Logits should have twice the batch size of the input ids, the first half of batches "
+ "corresponding to the conditional inputs, and the second half of batches corresponding to "
+ f"the unconditional inputs. Got batch size {scores.shape[0]} for the logits and "
+ f" {input_ids.shape[0]} for the input ids."
+ )
+ unguided_bsz = scores.shape[0] // 2
+ cond_logits, uncond_logits = scores.split(unguided_bsz, axis=0)
+ scores_processed = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale
+ return scores_processed
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e19f36b20d40562f7e78ec2434c858480b21d80
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py
@@ -0,0 +1,718 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""progen"""
+
+import math
+import time
+from typing import Tuple
+from collections import OrderedDict
+
+import mindspore as ms
+from mindspore import Tensor, nn, ops, Parameter
+from mindspore.nn import CrossEntropyLoss
+
+from .module.injection import EmbeddingMindnlp, DenseMindnlp, LayerNormMindnlp
+from .module.configuration_utils import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ PreTrainedModelMindnlp,
+ PretrainedConfig,
+)
+
+
+class ClassInstantier(OrderedDict):
+ r"""
+ Class Instantier
+ """
+
+ def __getitem__(self, key):
+ content = super().__getitem__(key)
+ cls, kwargs = content if isinstance(content, tuple) else (content, {})
+ return cls(**kwargs)
+
+
+# Mapping from activation-function names to MindSpore activation cells (with optional constructor kwargs).
+ACT2CLS = {
+ 'gelu': (nn.GELU, {"approximate": False}),
+ 'gelu_new': nn.GELU,
+ 'gelu_approximate': nn.GELU,
+ 'gelu_pytorch_tanh': nn.GELU,
+ "swish": nn.SiLU,
+ "gelu_10": nn.GELU,
+ "gelu_fast": nn.FastGelu,
+ "gelu_python": nn.GELU,
+ "linear": nn.ReLU,
+ "mish": nn.Mish,
+ "quick_gelu": nn.FastGelu,
+ "relu": nn.ReLU,
+ "relu6": nn.ReLU6,
+ "sigmoid": nn.Sigmoid,
+ "silu": nn.SiLU,
+ "tanh": nn.Tanh,
+}
+ACT2FN = ClassInstantier(ACT2CLS)
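+
+# Indexing ACT2FN instantiates the cell, e.g. ACT2FN['gelu'] builds nn.GELU(approximate=False) and
+# ACT2FN['relu'] builds nn.ReLU().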
+
+def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
+ dim = x.shape[-1]
+ if seq_len is None:
+ seq_len = x.shape[seq_dim]
+ inv_freq = 1.0 / (10000 ** (ops.arange(0, dim, 2.0) / (dim * 1.0)))
+    sinusoid_inp = ops.einsum("i , j -> i j", ops.arange(seq_len).astype(inv_freq.dtype), inv_freq).float()
+ return ops.sin(sinusoid_inp), ops.cos(sinusoid_inp)
+
+
+def rotate_every_two(x):
+ x1 = x[:, :, :, ::2]
+ x2 = x[:, :, :, 1::2]
+ x = ops.stack((-x2, x1), axis=-1)
+ x_flatten = x.flatten(start_dim=-2)
+ return x_flatten # in einsum notation: rearrange(x, '... d j -> ... (d j)')
+
+
+def apply_rotary_pos_emb(x, sincos, offset=0):
+ sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos)
+ # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
+ return (x * cos) + (rotate_every_two(x) * sin)
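+
+# Together, fixed_pos_embedding, rotate_every_two and apply_rotary_pos_emb implement rotary position embeddings
+# (RoPE): a sin/cos table is built per position, even/odd feature channels are paired, and each pair is rotated by a
+# position-dependent angle so that query-key dot products depend only on the relative offset between positions.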
+
+
+class ProGenConfig(PretrainedConfig):
+ """
+ ProGenConfig class
+ """
+ model_type = "progen"
+
+ def __init__(
+ self, vocab_size, n_positions, n_ctx, n_embd, n_layer, n_head, rotary_dim,
+ n_inner, activation_function, resid_pdrop, embd_pdrop, attn_pdrop,
+ layer_norm_epsilon, initializer_range, scale_attn_weights, gradient_checkpointing,
+ use_cache, bos_token_id, eos_token_id, min_length, **kwargs,
+ ):
+ super().__init__(**kwargs)
+ self.vocab_size = vocab_size
+ self.n_ctx = n_ctx
+ self.n_positions = n_positions
+ self.n_embd = n_embd
+ self.n_layer = n_layer
+ self.n_head = n_head
+ self.n_inner = n_inner
+ self.rotary_dim = rotary_dim
+ self.activation_function = activation_function
+ self.resid_pdrop = resid_pdrop
+ self.embd_pdrop = embd_pdrop
+ self.attn_pdrop = attn_pdrop
+ self.layer_norm_epsilon = layer_norm_epsilon
+ self.initializer_range = initializer_range
+ self.gradient_checkpointing = gradient_checkpointing
+ self.scale_attn_weights = scale_attn_weights
+ self.use_cache = use_cache
+ self.bos_token_id = bos_token_id
+ self.eos_token_id = eos_token_id
+ self.min_length = min_length
+
+ @property
+ def max_position_embeddings(self):
+ return self.n_positions
+
+ @property
+ def hidden_size(self):
+ return self.n_embd
+
+ @property
+ def num_attention_heads(self):
+ return self.n_head
+
+ @property
+ def num_hidden_layers(self):
+ return self.n_layer
+
+
+class NewGELUActivation(nn.Cell):
+ """
+ Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
+ the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
+ """
+ def construct(self, inputs: Tensor) -> Tensor:
+ return 0.5 * inputs * (1.0 + ops.tanh(math.sqrt(2.0 / math.pi) * (inputs + 0.044715 * ops.pow(inputs, 3.0))))
+
+
+class ProGenAttention(nn.Cell):
+ """
+ ProGenAttention class
+ """
+ def __init__(self, config):
+ super().__init__()
+
+ max_positions = config.max_position_embeddings
+ temp_tri = ops.tril(
+ ops.ones((max_positions, max_positions), dtype=ms.int32)).view(
+ 1, 1, max_positions, max_positions
+ )
+ self.bias = Parameter(
+ Tensor(temp_tri, dtype=ms.bool_),
+ name="bias",
+ requires_grad=False,
+ )
+
+ self.masked_bias = Parameter(
+ Tensor(-1e9),
+ name="masked_bias",
+ requires_grad=False,
+ )
+
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
+
+ self.embed_dim = config.hidden_size
+ self.num_attention_heads = config.num_attention_heads
+ self.head_dim = self.embed_dim // self.num_attention_heads
+ if self.head_dim * self.num_attention_heads != self.embed_dim:
+ raise ValueError(
+ f"embed_dim must be divisible by num_attention_heads (got `embed_dim`:\
+ {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
+ )
+ self.scale_attn = ops.sqrt(Tensor(self.head_dim, dtype=ms.float32)).to(ms.float16)
+ self.qkv_proj = DenseMindnlp(self.embed_dim, self.embed_dim * 3, has_bias=False)
+
+ self.out_proj = DenseMindnlp(self.embed_dim, self.embed_dim, has_bias=False)
+ self.rotary_dim = None
+ if config.rotary_dim is not None:
+ self.rotary_dim = config.rotary_dim
+
+
+ def construct(
+ self,
+ hidden_states,
+ use_cache=False,
+ output_attentions=False,
+ add_input=None,
+ ):
+ """
+ construct method of ProGenAttention
+ """
+ layer_past, attention_mask, head_mask = add_input
+ qkv = self.qkv_proj(hidden_states)
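+        # The fused qkv projection is stored in `mp_num` interleaved groups; it is split back into per-group
+        # query/value/key chunks below before being reshaped into attention heads.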
+ mp_num = 8
+ qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1))
+
+ local_dim = self.head_dim * self.num_attention_heads // mp_num
+ query, value, key = ops.split(qkv_split, local_dim, axis=-1)
+ query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+ key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+
+ value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
+ value = value.permute(0, 2, 1, 3)
+
+ seq_len = key.shape[1]
+ offset = 0
+
+ if layer_past is not None:
+ offset = layer_past[0].shape[-2]
+ seq_len += offset
+
+ if self.rotary_dim is not None:
+ k_rot = key[:, :, :, : self.rotary_dim]
+ k_pass = key[:, :, :, self.rotary_dim :]
+
+ q_rot = query[:, :, :, : self.rotary_dim]
+ q_pass = query[:, :, :, self.rotary_dim :]
+
+ sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len)
+ k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset)
+ q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset)
+
+ key = ops.cat((k_rot, k_pass), axis=-1)
+ query = ops.cat((q_rot, q_pass), axis=-1)
+ else:
+ sincos = fixed_pos_embedding(key, 1, seq_len=seq_len)
+ key = apply_rotary_pos_emb(key, sincos, offset=offset)
+ query = apply_rotary_pos_emb(query, sincos, offset=offset)
+
+ key = key.permute(0, 2, 1, 3)
+ query = query.permute(0, 2, 1, 3)
+
+ if layer_past is not None:
+ past_key = layer_past[0]
+ past_value = layer_past[1]
+ key = ops.cat((past_key, key), axis=-2)
+ value = ops.cat((past_value, value), axis=-2)
+
+ if use_cache is True:
+ present = (key, value)
+ else:
+ present = None
+
+ # compute self-attention: V x Softmax(QK^T)
+ attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
+
+ attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
+
+ attn_output = self.out_proj(attn_output)
+ attn_output = self.resid_dropout(attn_output)
+
+ outputs = (attn_output, present)
+ if output_attentions:
+ outputs += (attn_weights,)
+
+ return outputs # a, present, (attentions)
+
+ def _split_heads(self, x, n_head, dim_head, mp_num):
+ reshaped = x.reshape(x.shape[: -1] + (n_head // mp_num, dim_head))
+ reshaped = reshaped.reshape(x.shape[: -2] + (-1,) + reshaped.shape[-1: ])
+ return reshaped
+
+ def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
+ """
+ Merges attn_head_size dim and num_attn_heads dim into n_ctx
+ """
+ if len(tensor.shape) == 5:
+ tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
+ elif len(tensor.shape) == 4:
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ else:
+ raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}")
+ new_shape = tensor.shape[:-2] + (num_attention_heads * attn_head_size,)
+ return tensor.view(new_shape)
+
+ def _attn(
+ self,
+ query,
+ key,
+ value,
+ attention_mask=None,
+ head_mask=None,
+ ):
+ """
+ _attn method of ProGenAttention
+ """
+ # compute causal mask from causal mask buffer
+ query_length, key_length = query.shape[-2], key.shape[-2]
+ causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]
+
+ # Keep the attention weights computation in fp32 to avoid overflow issues
+ query = Tensor(query, dtype=ms.float32)
+ key = Tensor(key, dtype=ms.float32)
+
+ new_key = key.transpose(0, 1, 3, 2)
+ attn_weights = ops.matmul(query, new_key)
+
+ attn_weights = attn_weights / self.scale_attn
+ attn_weights = ops.where(causal_mask, attn_weights, self.masked_bias)
+
+ if attention_mask is not None:
+ # Apply the attention mask
+ attn_weights = attn_weights + attention_mask
+
+ attn_weights = nn.Softmax(axis=-1)(attn_weights)
+ attn_weights = self.attn_dropout(attn_weights)
+
+ # Mask heads if we want to
+ if head_mask is not None:
+ attn_weights = attn_weights * head_mask
+
+ attn_output = ops.matmul(attn_weights, value)
+
+ return attn_output, attn_weights
+
+
+class ProGenMLP(nn.Cell):
+ """
+ ProGenMLP class
+ """
+ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
+ super().__init__()
+ embed_dim = config.n_embd # n_embd=4096
+
+ self.fc_in = DenseMindnlp(embed_dim, intermediate_size)
+ self.fc_out = DenseMindnlp(intermediate_size, embed_dim)
+
+ self.act = NewGELUActivation()
+ self.dropout = nn.Dropout(config.resid_pdrop) # config.resid_pdrop=0.0
+
+ def construct(self, hidden_states):
+ hidden_states = self.fc_in(hidden_states)
+ hidden_states = self.act(hidden_states)
+ hidden_states = self.fc_out(hidden_states)
+ hidden_states = self.dropout(hidden_states)
+ return hidden_states
+
+
+class ProGenBlock(nn.Cell):
+ """
+ ProGenBlock class
+ """
+ def __init__(self, config):
+ super().__init__()
+ if config.n_inner is not None and config.n_inner != "None":
+ inner_dim = config.n_inner
+ else:
+ inner_dim = 4 * config.n_embd
+ self.ln_1 = LayerNormMindnlp(config.n_embd, epsilon=config.layer_norm_epsilon)
+ self.attn = ProGenAttention(config)
+ self.mlp = ProGenMLP(inner_dim, config)
+
+ def construct(
+ self,
+ hidden_states,
+ use_cache=False,
+ output_attentions=False,
+ add_input=None,
+ ):
+ """
+ construct method of ProGenBlock
+ """
+ residual = hidden_states
+ hidden_states = self.ln_1(hidden_states)
+ attn_outputs = self.attn(
+ hidden_states,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ add_input=add_input,
+ )
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
+ outputs = attn_outputs[1:]
+
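+        # GPT-J-style parallel residual: the attention and MLP branches both read the same layer-normed input
+        # and are added to the residual in a single step below.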
+ feed_forward_hidden_states = self.mlp(hidden_states)
+ hidden_states = attn_output + feed_forward_hidden_states + residual
+
+ if use_cache:
+ outputs = (hidden_states,) + outputs
+ else:
+ outputs = (hidden_states,) + outputs[1:]
+
+ return outputs # hidden_states, present, (attentions)
+
+
+class ProGenPreTrainedModel(PreTrainedModelMindnlp):
+ """
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+ models.
+ """
+ config_class = ProGenConfig
+ base_model_prefix = "transformer"
+ is_parallelizable = True
+
+ def _init_weights(self, module):
+ """Initialize the weights."""
+ if isinstance(module, (DenseMindnlp,)):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.bias is not None:
+ module.bias.data.zero_()
+ elif isinstance(module, EmbeddingMindnlp):
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+ if module.padding_idx is not None:
+ module.weight.data[module.padding_idx].zero_()
+ elif isinstance(module, LayerNormMindnlp):
+ module.bias.data.zero_()
+ module.weight.data.fill_(1.0)
+ else:
+ return
+
+
+class ProGenModel(ProGenPreTrainedModel):
+ """
+ ProGenModel class
+ """
+ def __init__(self, config):
+ super().__init__(config)
+
+ self.embed_dim = config.n_embd
+ self.vocab_size = config.vocab_size
+ self.wte = EmbeddingMindnlp(config.vocab_size, self.embed_dim)
+ self.drop = nn.Dropout(config.embd_pdrop)
+ self.h = nn.CellList([ProGenBlock(config) for _ in range(config.n_layer)])
+ self.ln_f = LayerNormMindnlp(self.embed_dim, epsilon=config.layer_norm_epsilon)
+ self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
+
+
+ def get_input_embeddings(self):
+ return self.wte
+
+ def set_input_embeddings(self, new_embeddings):
+ self.wte = new_embeddings
+
+ def construct(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ """
+ construct method of progen model
+ """
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
+ if input_ids is not None:
+ input_shape = input_ids.shape
+ input_ids = input_ids.view(-1, input_shape[-1])
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ input_shape = inputs_embeds.shape[:-1]
+ batch_size = inputs_embeds.shape[0]
+ else:
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
+
+ # device = input_ids.device if input_ids is not None else inputs_embeds.device
+
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
+
+ if position_ids is not None:
+ position_ids = position_ids.view(-1, input_shape[-1])
+
+ if past_key_values is None:
+ past_length = 0
+ past_key_values = tuple([None] * len(self.h))
+ else:
+ past_length = past_key_values[0][0].shape[-2]
+
+ if position_ids is None:
+ position_ids = ops.arange(past_length, input_shape[-1] + past_length)
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
+
+ # Attention mask.
+ if attention_mask is not None:
+ attention_mask = attention_mask.view(batch_size, -1)
+ # We create a 3D attention mask from a 2D tensor mask.
+ # Sizes are [batch_size, 1, 1, to_seq_length]
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
+ # this attention mask is more simple than the triangular masking of causal attention
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
+ attention_mask = attention_mask[:, None, None, :]
+
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
+ # masked positions, this operation will create a tensor which is 0.0 for
+ # positions we want to attend and -10000.0 for masked positions.
+ # Since we are adding it to the raw scores before the softmax, this is
+ # effectively the same as removing these entirely.
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
+ attention_mask = (1.0 - attention_mask) * -10000.0
+
+ # Prepare head mask if needed
+ # 1.0 in head_mask indicate we keep the head
+ # attention_probs has shape bsz x num_attention_heads x N x N
+ # head_mask has shape n_layer x batch x num_attention_heads x N x N
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
+
+ if inputs_embeds is None:
+ inputs_embeds = self.wte(input_ids)
+
+ hidden_states = inputs_embeds
+
+ if token_type_ids is not None:
+ token_type_embeds = self.wte(token_type_ids)
+ hidden_states = hidden_states + token_type_embeds
+
+ hidden_states = self.drop(hidden_states)
+
+ output_shape = input_shape + (hidden_states.shape[-1],)
+
+ presents = () if use_cache else None
+ all_self_attentions = () if output_attentions else None
+ all_hidden_states = () if output_hidden_states else None
+
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
+
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if getattr(self.config, "gradient_checkpointing", False) and self.training:
+
+ if use_cache:
+ use_cache = False
+
+ else:
+ add_input = (layer_past, attention_mask, head_mask[i])
+ outputs = block(
+ hidden_states,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ add_input=add_input
+ )
+
+ hidden_states = outputs[0]
+ if use_cache is True:
+ presents = presents + (outputs[1],)
+
+ if output_attentions:
+ all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)
+
+ hidden_states = self.ln_f(hidden_states)
+
+ hidden_states = hidden_states.view(*output_shape)
+ # Add last hidden state
+ if output_hidden_states:
+ all_hidden_states = all_hidden_states + (hidden_states,)
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
+
+
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=presents,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attentions,
+ )
+
+
+class ProGenForCausalLM(ProGenPreTrainedModel):
+ """
+ ProGenForCausalLM class
+ """
+ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"]
+
+ def __init__(self, config):
+ super().__init__(config)
+ self.transformer = ProGenModel(config)
+ self.lm_head = DenseMindnlp(config.n_embd, config.vocab_size)
+ self.loss_fct = CrossEntropyLoss()
+
+ @staticmethod
+ def _reorder_cache(past: Tuple[Tuple[Tensor]], beam_idx: Tensor) -> Tuple[Tuple[Tensor]]:
+ """
+ This function is used to re-order the :obj:`past_key_values` cache if
+ :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
+ called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
+ """
+ return tuple(
+            tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
+ for layer_past in past
+ )
+
+ def construct(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ ):
+ """
+ construct method of ProGenForCausalLM class
+ """
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ transformer_outputs = self.transformer(
+ input_ids,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ token_type_ids=token_type_ids,
+ position_ids=position_ids,
+ head_mask=head_mask,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+
+ lm_logits = self.lm_head(hidden_states).to(ms.float32)
+
+ loss = None
+ if labels is not None:
+ # Shift so that tokens < n predict n
+ shift_logits = lm_logits[..., :-1, :].contiguous()
+ shift_labels = labels[..., 1:].contiguous()
+ # Flatten the tokens
+ loss = self.loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1).astype("int32"))
+ loss = loss.to(hidden_states.dtype)
+
+ if not return_dict:
+ output = (lm_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+
+ def get_output_embeddings(self):
+ return None
+
+ def set_output_embeddings(self, new_embeddings):
+ return
+
+ def prepare_inputs_for_generation_new(self, input_ids, past_key_values=None, **kwargs):
+ """
+ new method of preparing inputs for generation
+ """
+ token_type_ids = kwargs.get("token_type_ids", None)
+ # only last token for inputs_ids if past is defined in kwargs
+ if past_key_values:
+ input_ids = input_ids[:, -1].unsqueeze(-1)
+ if token_type_ids is not None:
+ token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
+
+ attention_mask = kwargs.get("attention_mask", None)
+ position_ids = kwargs.get("position_ids", None)
+
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+            position_ids = position_ids.masked_fill(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -1].unsqueeze(-1)
+ else:
+ position_ids = None
+ return {
+ "input_ids": input_ids,
+ "past_key_values": past_key_values,
+ "use_cache": kwargs.get("use_cache"),
+ "position_ids": position_ids,
+ "attention_mask": attention_mask,
+ "token_type_ids": token_type_ids,
+ }
+
+
+class PrintTime:
+    """Context manager that prints a description on entry and the elapsed time on exit."""
+
+    def __init__(self, desc):
+ self.desc = desc
+
+ def __enter__(self):
+ print(self.desc)
+ self.t = time.time()
+
+ def __exit__(self, input_type, value, traceback):
+ print(f'{self.desc} took {time.time()-self.t:.02f}s')
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc71fa63e8d7ceeb9dee3e3d4d67d209fb8c004e
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py
@@ -0,0 +1,283 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""progen"""
+
+import os
+import random
+
+import mindspore as ms
+from mindspore import ops, Tensor
+from tokenizers import Tokenizer
+
+from .nn_arch import ProGenForCausalLM, ProGenConfig, PrintTime
+from ..model import Model
+
+FIRST_TOKEN = 5
+LAST_TOKEN = 29
+BOS_TOKEN = 3
+EOS_TOKEN = 4
+ALPHABET = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N',
+ 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z']
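+# Token id layout from the bundled tokenizer.json: ids 0-2 are '<|pad|>', '<|bos|>' and '<|eos|>', ids 3 and 4 are
+# the sequence terminals '1' and '2' (used here as BOS_TOKEN/EOS_TOKEN), and ids 5-29 (FIRST_TOKEN..LAST_TOKEN)
+# cover the amino-acid letters in ALPHABET.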
+
+
+class ProGen(Model):
+ """
+ ProGen class
+ """
+ name = "ProGen"
+
+ def __init__(self, config):
+ self.args = config
+ self.config = ProGenConfig(
+ config.vocab_size, config.n_positions, config.n_ctx, config.n_embd,
+ config.n_layer, config.n_head, config.rotary_dim, config.n_inner,
+ config.activation_function, config.resid_pdrop, config.embd_pdrop,
+ config.attn_pdrop, config.layer_norm_epsilon, config.initializer_range,
+ config.scale_attn_weights, config.gradient_checkpointing, config.use_cache,
+ config.bos_token_id, config.eos_token_id, config.min_length,
+ )
+ self.mixed_precision = False
+ self.network = ProGenForCausalLM(self.config)
+ self.checkpoint_url = "Local_Checkpoint_Used"
+ self.checkpoint_path = config.ckpt_dir
+ with PrintTime('loading tokenizer'):
+ self.tokenizer = self.create_tokenizer_custom(file=self.args.tokenizer_file)
+ self.set_env()
+ self.set_seed(self.args.rng_seed, deterministic=self.args.rng_deterministic)
+
+ super().__init__(checkpoint_url=self.checkpoint_url, checkpoint_path=self.checkpoint_path,
+ network=self.network, mixed_precision=self.mixed_precision)
+
+ def set_env(self):
+ os.environ['TOKENIZERS_PARALLELISM'] = 'false'
+
+    def set_seed(self, seed, deterministic=True):
+        print("deterministic", deterministic)
+        random.seed(seed)
+        # Also seed MindSpore's global RNG so sampling is reproducible
+        ms.set_seed(seed)
+        os.environ['PYTHONHASHSEED'] = str(seed)
+
+ def create_tokenizer_custom(self, file):
+ with open(file, 'r') as f:
+ return Tokenizer.from_str(f.read())
+
+ def cross_entropy(self, logits, target, reduction='mean'):
+ return ops.cross_entropy(input=logits, target=target, reduction=reduction)
+
+ def log_likelihood(self, logits, target, reduction='mean'):
+ return -self.cross_entropy(logits.view(-1, logits.shape[-1]), target.view(-1), reduction=reduction)
+
+ def log_likelihood_custom_1(self, logits, target, reduction='mean'):
+ return -ops.nll_loss(inputs=ops.log_softmax(logits, axis=1), target=target, reduction=reduction)
+
+ def log_likelihood_custom_2(self, logits, target, reduction='mean'):
+ log_likelihood = 0.0
+ n = logits.shape[0]
+ for i in range(n):
+ log_likelihood += ops.log_softmax(logits, axis=1)[i, target[i]] / (1. if reduction == 'sum' else n)
+ return log_likelihood
+
+ def cal_cross_entropy(self, tokens):
+ target = ms.tensor(self.tokenizer.encode(tokens).ids)
+ logits = self.network(target, labels=target).logits
+ # shift
+ logits = logits[:-1, ...]
+ target = target[1:]
+ return self.cross_entropy(logits=logits, target=target).item()
+
+ def ll(self, tokens, f, reduction):
+ """
+ shift, remove terminals and remove unused logits
+ """
+ target = Tensor(self.tokenizer.encode(tokens).ids)
+ logits = self.network(target, labels=target).logits
+
+ # shift
+ logits = logits[:-1, ...]
+ target = target[1:]
+
+ # remove terminals
+        if target[-1] in [BOS_TOKEN, EOS_TOKEN]:
+ logits = logits[:-1, ...]
+ target = target[:-1]
+
+ # remove unused logits
+        logits = logits[:, FIRST_TOKEN:(LAST_TOKEN + 1)]
+        target = target - FIRST_TOKEN
+
+ return f(logits=logits, target=target, reduction=reduction).item()
+
+ def from_pretrained(self, ckpt_path=None):
+ "from_pretrained"
+ self.get_checkpoint_path(ckpt_path)
+ if not ckpt_path:
+ param_dict = ms.load_checkpoint(self.checkpoint_path)
+ else:
+ param_dict = ms.load_checkpoint(ckpt_path)
+ param_not_load, _ = ms.load_param_into_net(self.network, param_dict)
+ print(f'param not load: {param_not_load}')
+
+ def predict(self, data, **kwargs):
+ if self.args.sanity:
+
+ with PrintTime('sanity cross-entropy'):
+
+ x_uniref90bfd30 = self.args.x_uniref90bfd30
+ x_oas = self.args.x_oas
+ x_bfd90 = self.args.x_bfd90
+
+ checkpoint_x_ce = {
+ 'progen2-small': (x_uniref90bfd30, 2.4),
+ 'progen2-medium': (x_uniref90bfd30, 1.9),
+ 'progen2-base': (x_uniref90bfd30, 1.9),
+ 'progen2-large': (x_uniref90bfd30, 1.8),
+ 'progen2-xlarge': (x_uniref90bfd30, 1.0),
+ 'progen2-oas': (x_oas, 0.3),
+ 'progen2-BFD90': (x_bfd90, 1.3),
+ }
+
+ ce_eval = self.cal_cross_entropy(checkpoint_x_ce[self.args.model][0])
+ ce_target = checkpoint_x_ce[self.args.model][1]
+
+ print(ce_target, ce_eval, abs(ce_eval - ce_target))
+
+ with PrintTime('sanity log-likelihood'):
+
+ x_data = self.args.x_data
+
+ ll_0 = self.ll(x_data, f=self.log_likelihood, reduction='mean')
+ ll_1 = self.ll(x_data, f=self.log_likelihood_custom_1, reduction='mean')
+ ll_2 = self.ll(x_data, f=self.log_likelihood_custom_2, reduction='mean')
+
+ print(f'll_0={ll_0}')
+ print(f'll_1={ll_1}')
+ print(f'll_2={ll_2}')
+
+ with PrintTime('sanity model'):
+
+ x_data = self.args.x_data
+ x_random = '2' + ''.join([random.choice(ALPHABET) for _ in range(len(x_data)-2)]) + '1'
+ x_perturb = x_random[:64] + x_data[len(x_random[:64]):]
+
+ ll_x_data = self.ll(x_data, f=self.log_likelihood, reduction='mean')
+ ll_x_random = self.ll(x_random, f=self.log_likelihood, reduction='mean')
+ ll_x_perturb = self.ll(x_perturb, f=self.log_likelihood, reduction='mean')
+
+ print(f'll_x_data={ll_x_data}')
+ print(f'll_x_random={ll_x_random}')
+ print(f'll_x_perturb={ll_x_perturb}')
+
+ with PrintTime('log-likelihood (left-to-right, right-to-left)'):
+
+ reverse = lambda s: s[::-1]
+
+ ll_lr_sum = self.ll(tokens=data, f=self.log_likelihood, reduction='sum')
+ ll_rl_sum = self.ll(tokens=reverse(data), f=self.log_likelihood, reduction='sum')
+
+ ll_lr_mean = self.ll(tokens=data, f=self.log_likelihood, reduction='mean')
+ ll_rl_mean = self.ll(tokens=reverse(data), f=self.log_likelihood, reduction='mean')
+
+ ll_sum = .5 * (ll_lr_sum + ll_rl_sum)
+ ll_mean = .5 * (ll_lr_mean + ll_rl_mean)
+
+ print(f'll_sum={(ll_sum)}')
+ print(f'll_mean={ll_mean}')
+
+ def sample(self, model, tokenizer, context, max_length, num_return_sequences, top_p, temp, pad_token_id):
+ """
+ sample method of progen model
+ """
+ input_ids = Tensor([tokenizer.encode(context).ids])
+ tokens_batch = model.generate(
+ input_ids,
+ do_sample=True,
+ temperature=temp,
+ max_length=max_length,
+ top_p=top_p,
+ num_return_sequences=num_return_sequences,
+ pad_token_id=pad_token_id,
+ )
+ as_lists = lambda batch: [batch[i, ...].asnumpy().tolist() for i in range(batch.shape[0])]
+ return tokenizer.decode_batch(as_lists(tokens_batch))
+
+ def truncate(self, input_sample, terminals):
+ """
+ truncate method
+ """
+ pos = []
+ for terminal in terminals:
+ find_pos = input_sample.find(terminal, 1)
+ if find_pos != -1:
+ pos.append(find_pos)
+ if pos:
+ return input_sample[:(min(pos) + 1)]
+ return input_sample
+
+ def generate(self):
+ """
+ generate method
+ """
+ if self.args.sanity:
+ with PrintTime('sanity cross-entropy'):
+
+ x_uniref90bfd30 = self.args.x_uniref90bfd30
+ x_oas = self.args.x_oas
+ x_bfd90 = self.args.x_bfd90
+
+ checkpoint_x_ce = {
+ 'progen2-small': (x_uniref90bfd30, 2.4),
+ 'progen2-medium': (x_uniref90bfd30, 1.9),
+ 'progen2-base': (x_uniref90bfd30, 1.9),
+ 'progen2-large': (x_uniref90bfd30, 1.8),
+ 'progen2-xlarge': (x_uniref90bfd30, 1.0),
+ 'progen2-oas': (x_oas, 0.3),
+ 'progen2-BFD90': (x_bfd90, 1.3),
+ }
+
+ ce_eval = self.cal_cross_entropy(checkpoint_x_ce.get(self.args.model)[0])
+ ce_target = checkpoint_x_ce.get(self.args.model)[1]
+
+ if abs(ce_eval - ce_target) >= 0.1:
+ raise ValueError("Difference should be within 0.1")
+
+ with PrintTime('sampling'):
+ completions = self.sample(model=self.network, tokenizer=self.tokenizer, context=self.args.context,
+ pad_token_id=self.tokenizer.encode('<|pad|>').ids[0],
+ num_return_sequences=self.args.num_samples, temp=self.args.t,
+ top_p=self.args.p, max_length=self.args.max_length)
+ truncations = [self.truncate(completion, terminals=['1', '2']) for completion in completions]
+
+ print(self.args.num_samples, "sequences are sampled in total.")
+ for (i, truncation) in enumerate(truncations):
+ print("The generated sequence with index", i, "is: ", truncation)
+
+
+ def train_step(self, data):
+ return None
+
+
+ def forward(self, data):
+ return None
+
+
+ def backward(self, data):
+ return None
+
+
+ def _jit_forward(self, data):
+ return None
+
+
+ def _pynative_forward(self, data):
+ return None
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..b08eca21e437252e6480b23a1415daba4e901a0a
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py
@@ -0,0 +1,20 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""progen_configuration"""
+
+progen_configuration = {
+ "small":
+ "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/ProGen/small.yaml",
+}
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8232e64ca2ea6888c70b5a54ac441c94a905ea6
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
@@ -0,0 +1,46 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""progen_dataset"""
+
+from .nn_arch import PrintTime
+from ...dataset import PSP
+
+
+class ProGenDataSet(PSP):
+ "EquiDockDataSet"
+ def __init__(self, config):
+ self.config = config
+ super().__init__()
+
+ def process(self, data, **kwargs):
+ return data
+
+ def set_training_data_src(self, data_source, **kwargs):
+ with PrintTime('set_training_data_src'):
+ print(data_source)
+            print(kwargs)
+
+ def create_iterator(self, num_epochs, **kwargs):
+ return None
+
+ def data_parse(self, idx):
+ return None
+
+ #pylint: disable=arguments-differ
+ def __getitem__(self, idx):
+ pass
+
+ def __len__(self):
+ pass
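+
+
+# Note: the dataset hooks above are placeholders that satisfy the PSP interface; ProGen is
+# currently wired into the pipeline for generation only, and PrintTime (from nn_arch) is
+# assumed to be a simple timing context manager used for logging.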
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json b/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json
new file mode 100644
index 0000000000000000000000000000000000000000..cd3d5a9d0791354e3d0a1d27ddc85c89acac022d
--- /dev/null
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json
@@ -0,0 +1,91 @@
+{
+ "version": "1.0",
+ "truncation": null,
+ "padding": null,
+ "added_tokens": [
+ {
+ "id": 0,
+ "special": true,
+ "content": "<|pad|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 1,
+ "special": true,
+ "content": "<|bos|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ },
+ {
+ "id": 2,
+ "special": true,
+ "content": "<|eos|>",
+ "single_word": false,
+ "lstrip": false,
+ "rstrip": false,
+ "normalized": false
+ }
+ ],
+ "normalizer": null,
+ "pre_tokenizer": {
+ "type": "ByteLevel",
+ "add_prefix_space": false,
+ "trim_offsets": true
+ },
+ "post_processor": {
+ "type": "ByteLevel",
+ "add_prefix_space": true,
+ "trim_offsets": true
+ },
+ "decoder": {
+ "type": "ByteLevel",
+ "add_prefix_space": true,
+ "trim_offsets": true
+ },
+ "model": {
+ "type": "BPE",
+ "dropout": null,
+ "unk_token": null,
+ "continuing_subword_prefix": null,
+ "end_of_word_suffix": null,
+ "fuse_unk": false,
+ "vocab": {
+ "<|pad|>": 0,
+ "<|bos|>": 1,
+ "<|eos|>": 2,
+ "1": 3,
+ "2": 4,
+ "A": 5,
+ "B": 6,
+ "C": 7,
+ "D": 8,
+ "E": 9,
+ "F": 10,
+ "G": 11,
+ "H": 12,
+ "I": 13,
+ "K": 14,
+ "L": 15,
+ "M": 16,
+ "N": 17,
+ "O": 18,
+ "P": 19,
+ "Q": 20,
+ "R": 21,
+ "S": 22,
+ "T": 23,
+ "U": 24,
+ "V": 25,
+ "W": 26,
+ "X": 27,
+ "Y": 28,
+ "Z": 29
+ },
+ "merges": []
+ }
+}
\ No newline at end of file
diff --git a/MindSPONGE/src/mindsponge/pipeline/pipeline.py b/MindSPONGE/src/mindsponge/pipeline/pipeline.py
index 60a41859b82b8bb4a88188d1174f6cfec8408c65..850a55fb693236b2ba636d6881839db1bc4f4baf 100644
--- a/MindSPONGE/src/mindsponge/pipeline/pipeline.py
+++ b/MindSPONGE/src/mindsponge/pipeline/pipeline.py
@@ -35,10 +35,11 @@ from .models import RASP, RASPDataSet, rasp_configuration
from .models import Multimer, MultimerDataSet, multimer_configuration
from .models import ProteinMpnn, ProteinMpnnDataset, proteinmpnn_configuration
from .models import UFold, UFoldDataSet, ufold_configuration
+from .models import EquiDock, EquiDockDataSet, equidock_configuration
+from .models import ProGen, ProGenDataSet, progen_configuration
from .models import ProtT5, ProtT5TrainDataSet, prott5pretrain_configuration
from .models import ProtT5DownstreamTasks, ProtT5TaskDataSet, prott5downtask_configuration
-
model_card = {
"ColabDesign": {"model": COLABDESIGN, "dataset": ColabDesignDataSet, "config": colabdesign_configuration},
"DeepDR": {"model": DeepDR, "dataset": DeepDRDataSet, "config": deepdr_configuration},
@@ -55,6 +56,8 @@ model_card = {
"Proteinmpnn": {"model": ProteinMpnn, "dataset": ProteinMpnnDataset, "config": proteinmpnn_configuration},
"RASP": {"model": RASP, "dataset": RASPDataSet, "config": rasp_configuration},
"UFold": {"model": UFold, "dataset": UFoldDataSet, "config": ufold_configuration},
+ "EquiDock": {"model": EquiDock, "dataset": EquiDockDataSet, "config": equidock_configuration},
+ "ProGen": {"model": ProGen, "dataset": ProGenDataSet, "config": progen_configuration},
"ProtT5": {"model": ProtT5, "dataset": ProtT5TrainDataSet, "config": prott5pretrain_configuration},
"ProtT5Downstream": {"model": ProtT5DownstreamTasks, "dataset": ProtT5TaskDataSet,
"config": prott5downtask_configuration}