From 14b61d03a238591e284e5e5954aa964bafe8c1d6 Mon Sep 17 00:00:00 2001 From: yuchengzhang Date: Mon, 5 Aug 2024 15:05:03 +0800 Subject: [PATCH 01/16] equidock model first commit --- .../mindsponge/pipeline/models/__init__.py | 1 + .../pipeline/models/equidock/__init__.py | 27 + .../pipeline/models/equidock/equidock.py | 398 +++++ .../models/equidock/equidock_configuration.py | 32 + .../pipeline/models/equidock/equidock_data.py | 251 +++ .../models/equidock/equidock_dataset.py | 38 + .../pipeline/models/equidock/nn_arch.py | 1580 +++++++++++++++++ .../src/mindsponge/pipeline/pipeline.py | 2 + 8 files changed, 2329 insertions(+) create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py diff --git a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py index 3e0858a40..8e8fff3c5 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py @@ -35,3 +35,4 @@ from .multimer import Multimer, MultimerDataSet, multimer_configuration from .proteinmpnn import ProteinMpnn, ProteinMpnnDataset, proteinmpnn_configuration from .ufold import UFold, UFoldDataSet, ufold_configuration from .rasp import RASP, RASPDataSet, rasp_configuration +from .equidock import EquiDock, EquiDockDataSet, equidock_configuration diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py new file mode 100644 index 000000000..261ae3475 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2023 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""equidock""" + +from .equidock import EquiDock +from .equidock_dataset import EquiDockDataSet +from .equidock_configuration import equidock_configuration diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py new file mode 100644 index 000000000..9089dd096 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py @@ -0,0 +1,398 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import random +from datetime import datetime as dt + +import numpy as np +from biopandas.pdb import PandasPdb +import mindspore as ms +from mindspore import jit, context, nn, Tensor, ops, save_checkpoint +from mindspore.experimental import optim + +from .nn_arch import * +from .equidock_data import UnboundBoundData +from ..model import Model + + +class EquiDock(Model): + "EquiDock" + name = "EquiDock" + + def __init__(self, config): + + self.config = config + self.mixed_precision = False + self.network = RigidBodyDockingNet(self.config) + self.log(self.network) + self.checkpoint_url = "Local_Checkpoint_Used" + self.checkpoint_path = self.config.ckpt_dir + + if self.config.is_train: + self.config.cache_path = './cache/' + self.config.data + '_' + \ + self.config.graph_nodes + '_maxneighbor_' + str(self.config.graph_max_neighbor) + \ + '_cutoff_' + str(self.config.graph_cutoff) + '_pocketCut_' + \ + str(self.config.pocket_cutoff) + '/' + self.config.cache_path = os.path.join(self.config.cache_path, 'cv_' + str(self.config.split)) + banner = 'EQUIDOCK' + self.config.checkpoint_dir = './checkpts/' + banner + self.config.log_files_dir = './stdouterr/' + + ### Create log files only when in train mode + create_dir(self.config.log_files_dir) + create_dir(self.config.checkpoint_dir) + + log_file_name = os.path.join(self.config.log_files_dir, banner + ".txt") + with os.fdopen(os.open(log_file_name, FLAGS, 777), 'a+') as fout: + fout.write('[' + str(dt.now()) + '] START\n') + self.config.checkpoint_filename = os.path.join(self.config.checkpoint_dir, self.config.data + '_model_best.pth') + + self.loss_fn_coors = nn.MSELoss(reduction='mean') + self.optimizer = optim.Adam(self.network.trainable_params(), lr=self.config.lr, weight_decay=self.config.w_decay) + self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, + lr_lambda=[lambda epoch: min(1., ((epoch + 1) / self.config.warmup) ** 3)]) + + self.best_epoch = 0 + self.best_val_rmsd_median = float('inf') + self.corr_val_rmsd_mean = float('inf') + + if not 
os.path.exists(self.config.processed_dataset_path): + UnboundBoundData(self.config, reload_mode='val', raw_data_path=self.config.raw_data_path, + split_files_path=self.config.split_files_path, data_fraction=self.config.data_fraction) + UnboundBoundData(self.config, reload_mode='test', raw_data_path=self.config.raw_data_path, + split_files_path=self.config.split_files_path, data_fraction=self.config.data_fraction) + UnboundBoundData(self.config, reload_mode='train', raw_data_path=self.config.raw_data_path, + split_files_path=self.config.split_files_path, data_fraction=self.config.data_fraction) + + self.train_data_batched, self.train_loader = create_dataloader(self.config.train_dir, bs=self.config.bs, shuffle=True) + self.val_data_batched, self.val_loader = create_dataloader(self.config.val_dir, bs=self.config.bs, shuffle=False) + self.test_data_batched, self.test_loader = create_dataloader(self.config.test_dir, bs=self.config.bs, shuffle=False) + + super().__init__(checkpoint_url=self.checkpoint_url, checkpoint_path=self.checkpoint_path, + network=self.network, mixed_precision=self.mixed_precision) + + + def from_pretrained(self, ckpt_path=None): + "from_pretrained" + self.get_checkpoint_path(ckpt_path) + if not ckpt_path: + param_dict = ms.load_checkpoint(self.checkpoint_path) + else: + param_dict = ms.load_checkpoint(ckpt_path) + param_not_load, _ = ms.load_param_into_net(self.network, param_dict) + print(f'param not load: {param_not_load}') + + + def predict(self, data, **kwargs): + time_list = [] + input_dir = self.config.input_dir + ground_truth_dir = self.config.ground_truth_dir + output_dir = self.config.output_dir + file_type = '.pdb' + l_b_str = '_l_b' + + files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(file_type)] + + for file in files: + if not file.endswith(l_b_str + file_type): + continue + ll = len(l_b_str + file_type) + ligand_filename = os.path.join(input_dir, file[:-ll] + l_b_str + file_type) + receptor_filename = os.path.join(ground_truth_dir, file[:-ll] + '_r_b' + '_COMPLEX' + file_type) + gt_ligand_filename = os.path.join(ground_truth_dir, file[:-ll] + l_b_str + '_COMPLEX' + file_type) + out_filename = file[:-ll] + l_b_str + '_' + "EQUIDOCK" + file_type + + self.log(' inference on file = ', ligand_filename) + + start = dt.now() + + ppdb_ligand = PandasPdb().read_pdb(ligand_filename) + + ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, bound_ligand_repres_nodes_loc_clean_array\ + = prepare_graphs(self.config, ppdb_ligand, ligand_filename, receptor_filename) + + if self.config.input_edge_feats_dim < 0: + self.config.input_edge_feats_dim = ligand_graph.edata['he'].shape[1] + + ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, \ + input_tensor_tuple = graph_to_tensor(ligand_graph, receptor_graph) + + model_ligand_coors_deform_list, model_keypts_ligand_list, model_keypts_receptor_list, \ + all_rotation_list, all_translation_list = self.network( + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list, + input_tensor_tuple, + ) + + rotation = all_rotation_list[0].asnumpy() + translation = all_translation_list[0].asnumpy() + + new_residues = (rotation @ bound_ligand_repres_nodes_loc_clean_array.T).T + translation + + unbound_ligand_new_pos = (rotation @ unbound_ligand_all_atoms_pre_pos.T).T + translation + + euler_angles_finetune = ops.zeros([3]) + translation_finetune = ops.zeros([3]) + ligand_th = (get_rot_mat(euler_angles_finetune) @ Tensor(unbound_ligand_new_pos).T).T + 
translation_finetune + + ppdb_ligand.df['ATOM'][['x_coord', 'y_coord', 'z_coord']] = ligand_th.asnumpy() # unbound_ligand_new_pos + unbound_ligand_save_filename = os.path.join(output_dir, out_filename) + ppdb_ligand.to_pdb(path=unbound_ligand_save_filename, records=['ATOM'], gz=False) + + end = dt.now() + time_list.append((end - start).total_seconds()) + + time_array = np.array(time_list) + self.log("Mean runtime:", np.mean(time_array), "std runtime:", np.std(time_array)) + self.log('Time list = ', time_list) + + + def run_a_train_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args): + train_complex_rmsd_mean, train_complex_rmsd_median, \ + train_ligand_rmsd_mean, train_ligand_rmsd_median, \ + train_receptor_rmsd_mean, train_receptor_rmsd_median, \ + train_avg_loss, train_avg_loss_ligand_coors, train_avg_loss_receptor_coors, \ + train_avg_loss_ot, train_avg_loss_intersection = \ + self.run_a_generic_epoch('train', run_epoch_tuple, data_batched, data_loader) + + pretty_print_stats('TRAIN', epoch, args.num_epochs, + train_complex_rmsd_mean, train_complex_rmsd_median, + train_ligand_rmsd_mean, train_ligand_rmsd_median, + train_receptor_rmsd_mean, train_receptor_rmsd_median, + train_avg_loss, train_avg_loss_ligand_coors, train_avg_loss_receptor_coors, + train_avg_loss_ot, train_avg_loss_intersection, + self.log) + + + def run_an_eval_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args): + val_complex_rmsd_mean, val_complex_rmsd_median, \ + val_ligand_rmsd_mean, val_ligand_rmsd_median, \ + val_receptor_rmsd_mean, val_receptor_rmsd_median, \ + val_avg_loss, val_avg_loss_ligand_coors, \ + val_avg_loss_receptor_coors, \ + val_avg_loss_ot, val_avg_loss_intersection = \ + self.run_a_generic_epoch('eval', run_epoch_tuple, data_batched, data_loader) + + pretty_print_stats('VALIDATION', epoch, args.num_epochs, + val_complex_rmsd_mean, val_complex_rmsd_median, + val_ligand_rmsd_mean, val_ligand_rmsd_median, + val_receptor_rmsd_mean, val_receptor_rmsd_median, + val_avg_loss, val_avg_loss_ligand_coors, val_avg_loss_receptor_coors, + val_avg_loss_ot, val_avg_loss_intersection, + self.log) + return val_complex_rmsd_mean, val_complex_rmsd_median, val_avg_loss + + + def run_a_test_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args): + test_complex_rmsd_mean, test_complex_rmsd_median, \ + test_ligand_rmsd_mean, test_ligand_rmsd_median, \ + test_receptor_rmsd_mean, test_receptor_rmsd_median, \ + test_avg_loss, test_avg_loss_ligand_coors, \ + test_avg_loss_receptor_coors, \ + test_avg_loss_ot, test_avg_loss_intersection = \ + self.run_a_generic_epoch('eval', run_epoch_tuple, data_batched, data_loader) + + pretty_print_stats('FINAL TEST for ' + args.data, -1, args.num_epochs, + test_complex_rmsd_mean, test_complex_rmsd_median, + test_ligand_rmsd_mean, test_ligand_rmsd_median, + test_receptor_rmsd_mean, test_receptor_rmsd_median, + test_avg_loss, test_avg_loss_ligand_coors, test_avg_loss_receptor_coors, + test_avg_loss_ot, test_avg_loss_intersection, self.log) + + + def run_a_generic_epoch(self, ep_type, run_epoch_tuple, data_batched, data_loader): + args, epoch, self.network, loss_fn_coors, optimizer = run_epoch_tuple + + meter = MeterUnboundBound() + + avg_loss, total_loss, num_batches = 0., 0., 0 + + total_loss_ligand_coors = 0. + avg_loss_ligand_coors = 0. + + total_loss_receptor_coors = 0. + avg_loss_receptor_coors = 0. + + total_loss_ot = 0. + avg_loss_ot = 0. + + total_loss_intersection = 0. + avg_loss_intersection = 0. 
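+        # forward(batch_id) below rebuilds the batch tensors from data_loader, runs the network and
+        # combines three loss terms: the coordinate MSE on the deformed ligand coordinates, the pocket
+        # optimal-transport loss (weighted by args.pocket_ot_loss_weight) and the body-intersection
+        # penalty (weighted by args.intersection_loss_weight). ms.value_and_grad wraps this closure so
+        # a single call returns both the loss tuple and the gradients w.r.t. the optimizer parameters.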
+ + def forward(batch_id): + + ######## RUN MODEL ############## + ligand_graph_node_tensor = Tensor(data_loader[batch_id][6], ms.float32) + receptor_graph_node_tensor = Tensor(data_loader[batch_id][7], ms.float32) + unbatch_list_tensor = Tensor(data_loader[batch_id][8], ms.int32) + input_tensor_tuple = ( + Tensor(data_loader[batch_id][4], ms.int32), # ligand_graph_num_nodes + Tensor(data_loader[batch_id][5], ms.int32), # receptor_graph_num_nodes + Tensor(data_loader[batch_id][2], ms.float32), # ligand_graph_edge_tensor + Tensor(data_loader[batch_id][3], ms.float32), # receptor_graph_edge_tensor + Tensor(data_loader[batch_id][0], ms.int32), # ll_connection_tensor + Tensor(data_loader[batch_id][1], ms.int32), # rr_connection_tensor + ) + + model_ligand_coors_deform_list, \ + model_keypts_ligand_list, model_keypts_receptor_list, \ + _, _, = self.network( + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ) + ################################ + + batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss, batch_intersection_loss = Tensor(0.0), \ + Tensor(0.0), Tensor(0.0), Tensor(0.0) + + for i, _ in enumerate(model_ligand_coors_deform_list): + ## Compute average MSE loss (which is 3 times smaller than average squared RMSD) + + batch_ligand_coors_loss = batch_ligand_coors_loss + loss_fn_coors(model_ligand_coors_deform_list[i], + Tensor(data_batched[8][batch_id][i], + ms.float32)) + # Compute the OT loss for the binding pocket: + ligand_pocket_coors = Tensor(data_batched[10][batch_id][i], ms.float32) # (N, 3), N = num pocket nodes + receptor_pocket_coors = Tensor(data_batched[11][batch_id][i], ms.float32) # (N, 3), N = num pocket nodes + + ligand_keypts_coors = model_keypts_ligand_list[i] # (K, 3), K = num keypoints + receptor_keypts_coors = model_keypts_receptor_list[i] # (K, 3), K = num keypoints + + ## (N, K) cost matrix + cost_mat_ligand = compute_sq_dist_mat(ligand_pocket_coors, ligand_keypts_coors) + cost_mat_receptor = compute_sq_dist_mat(receptor_pocket_coors, receptor_keypts_coors) + + ot_dist, _ = compute_ot_emd(cost_mat_ligand + cost_mat_receptor) + batch_ot_loss = batch_ot_loss + ot_dist + + batch_intersection_loss = batch_intersection_loss + compute_body_intersection_loss( + model_ligand_coors_deform_list[i], data_batched[9][batch_id][i], + args.intersection_sigma, args.intersection_surface_ct) + + ### Add new stats to the meter + if ep_type != 'train' or random.random() < 0.1: + meter.update_rmsd(model_ligand_coors_deform_list[i], + data_batched[9][batch_id][i], + data_batched[8][batch_id][i], + data_batched[9][batch_id][i]) + + batch_ligand_coors_loss = batch_ligand_coors_loss / float(len(model_ligand_coors_deform_list)) + batch_receptor_coors_loss = batch_receptor_coors_loss / float(len(model_ligand_coors_deform_list)) + batch_ot_loss = batch_ot_loss / float(len(model_ligand_coors_deform_list)) + batch_intersection_loss = batch_intersection_loss / float(len(model_ligand_coors_deform_list)) + + loss_coors = batch_ligand_coors_loss + batch_receptor_coors_loss + + loss = loss_coors + args.pocket_ot_loss_weight * batch_ot_loss \ + + args.intersection_loss_weight * batch_intersection_loss + + return loss, batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss, batch_intersection_loss + + backward = ms.value_and_grad(forward, None, optimizer.parameters) + + # Iterate over all batches of the epoch + for step in range(len(data_batched[0])): + num_batches += 1 + + (loss, batch_ligand_coors_loss, 
batch_receptor_coors_loss, batch_ot_loss, + batch_intersection_loss), grads = backward(step) + + if ep_type == 'train': + grads = ms.ops.clip_by_norm(grads, max_norm=args.clip, norm_type=2) + optimizer(grads) + + total_loss += loss.asnumpy() + total_loss_ligand_coors += batch_ligand_coors_loss + total_loss_receptor_coors += batch_receptor_coors_loss + total_loss_ot += batch_ot_loss + total_loss_intersection += batch_intersection_loss + + if num_batches != 0: + avg_loss = total_loss / num_batches + avg_loss_ligand_coors = total_loss_ligand_coors / num_batches + avg_loss_receptor_coors = total_loss_receptor_coors / num_batches + avg_loss_ot = total_loss_ot / num_batches + avg_loss_intersection = total_loss_intersection / num_batches + + ligand_rmsd_mean, receptor_rmsd_mean, complex_rmsd_mean = meter.summarize(reduction_rmsd='mean') + ligand_rmsd_median, receptor_rmsd_median, complex_rmsd_median = meter.summarize(reduction_rmsd='median') + + return complex_rmsd_mean, complex_rmsd_median, \ + ligand_rmsd_mean, ligand_rmsd_median, \ + receptor_rmsd_mean, receptor_rmsd_median, \ + avg_loss.item(), avg_loss_ligand_coors.item(), \ + avg_loss_receptor_coors.item(), \ + avg_loss_ot.item(), avg_loss_intersection.item() + + + def train_step(self, epoch): + + self.log('+' * 100) + epoch_start = dt.now() + + run_epoch_tuple = (self.config, epoch, self.network, self.loss_fn_coors, self.optimizer) + self.run_a_train_epoch(run_epoch_tuple, self.train_data_batched, self.train_loader, epoch, self.config) + val_complex_rmsd_mean, val_complex_rmsd_median, val_avg_loss = \ + self.run_an_eval_epoch(run_epoch_tuple, self.val_data_batched, self.val_loader, epoch, self.config) + + self.scheduler.step() + + if val_complex_rmsd_median < self.best_val_rmsd_median * 0.98: # We do this to avoid "pure luck" + # Best validation so far + self.best_val_rmsd_median = val_complex_rmsd_median + self.corr_val_rmsd_mean = val_complex_rmsd_mean + self.best_epoch = epoch + save_checkpoint(self.optimizer, self.config.checkpoint_dir + "/best_model.ckpt") + + log('[BEST SO FAR] ', self.config.data, + '|| At best epoch {} we have: best_val_rmsd_median {:.4f}, ' + '|| Current val rmsd_median {:.4f} ' + '|| Train time: {}\n'. + format(self.best_epoch, self.best_val_rmsd_median, + val_complex_rmsd_median, + dt.now() - epoch_start)) + + return '\n' + + + def log(self, *pargs): + print('[' + str(dt.now()) + '] ', *pargs) + + + def forward(self, data): + return None + + + def backward(self, data): + return None + + + def _jit_forward(self, data): + return None + + + def _pynative_forward(self, data): + return None \ No newline at end of file diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py new file mode 100644 index 000000000..07e21e270 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py @@ -0,0 +1,32 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""equidock_configuration""" + +equidock_configuration = { + "predict_dips": + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/predict_dips.yaml", + "predict_db5": + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/predict_db5.yaml", + "train_db5": + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml", +} \ No newline at end of file diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py new file mode 100644 index 000000000..5adf98836 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py @@ -0,0 +1,251 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import numpy as np +from biopandas.pdb import PandasPdb +from joblib import Parallel, delayed, cpu_count +from scipy.spatial.transform import Rotation + +from .nn_arch import * + + +def one_hot_encoding(x, allowable_set, encode_unknown=False): + if encode_unknown and (allowable_set[-1] is not None): + allowable_set.append(None) + + if encode_unknown and (x not in allowable_set): + x = None + + return list(map(lambda s: x == s, allowable_set)) + + +def uniformrotation_translation(translation_interval): + rotation = Rotation.random(num=1) + rotation_matrix = rotation.as_matrix().squeeze() + + t = np.random.randn(1, 3) + t = t / np.sqrt(np.sum(t * t)) + length = np.random.uniform(low=0, high=translation_interval) + t = t * length + return rotation_matrix.astype(np.float32), t.astype(np.float32) + + +def get_residues_db5(pdb_filename): + df = PandasPdb().read_pdb(pdb_filename).df['ATOM'] + df.rename(columns={'chain_id': 'chain', 'residue_number': 'residue', 'residue_name': 'resname', + 'x_coord': 'x', 'y_coord': 'y', 'z_coord': 'z', 'element_symbol': 'element'}, inplace=True) + residues = list(df.groupby(['chain', 'residue', 'resname'])) # Not the same as sequence order ! 
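+    # Each entry is a ((chain, residue, resname), DataFrame) pair from groupby; downstream code
+    # reads the per-atom rows through residue[1] (e.g. residue[1]['resname'].iloc[0]).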
+ return residues + + +def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, **kwargs): + if n_jobs is None: + n_jobs = cpu_count() - 1 + results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)( + delayed(pickleable_fn)(*d, **kwargs) for i, d in enumerate(data) + ) + + return results + + +def pmap_multi2(pickleable_fn, data, n_jobs=None, verbose=1, **kwargs): + if n_jobs is None: + n_jobs = cpu_count() - 1 + results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)( + delayed(pickleable_fn)(d, **kwargs) for i, d in enumerate(data) + ) + + return results + + +class UnboundBoundData(): + + def __init__(self, args, reload_mode='train', raw_data_path=None, split_files_path=None, data_fraction=1.): + + os.makedirs(args.processed_dataset_path + reload_mode, exist_ok=True) + + if args.data == 'db5': + onlyfiles = [f for f in os.listdir(raw_data_path) if os.path.isfile(os.path.join(raw_data_path, f))] + code_set = set([file.split('_')[0] for file in onlyfiles]) + split_code_set = set() + with open(os.path.join(split_files_path, reload_mode + '.txt'), 'r') as f: + for line in f.readlines(): + split_code_set.add(line.rstrip()) + + code_set = code_set & split_code_set + code_list = list(code_set) + + bound_ligand_residues_list = [get_residues_db5(os.path.join(raw_data_path, code + '_l_b.pdb')) + for code in code_list] + bound_receptor_residues_list = [get_residues_db5(os.path.join(raw_data_path, code + '_r_b.pdb')) + for code in code_list] + + input_residues_lists = [(bound_ligand_residues_list[i], bound_receptor_residues_list[i]) + for i in range(len(bound_ligand_residues_list))] + + log('Start preprocess_unbound_bound') + preprocess_result = pmap_multi(preprocess_unbound_bound, + input_residues_lists, + n_jobs=args.n_jobs, + graph_nodes=args.graph_nodes, + pos_cutoff=args.pocket_cutoff, + inference=False) + log('Done preprocess_unbound_bound\n\n') + + unbound_predic_ligand_list, unbound_predic_receptor_list = [], [] + bound_ligand_repres_nodes_loc_array_list, bound_receptor_repres_nodes_loc_array_list = [], [] + pocket_coors_list = [] + for result in preprocess_result: + unbound_predic_ligand, unbound_predic_receptor, \ + bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, pocket_coors = result + if pocket_coors is not None: + unbound_predic_ligand_list.append(unbound_predic_ligand) + unbound_predic_receptor_list.append(unbound_predic_receptor) + bound_ligand_repres_nodes_loc_array_list.append(bound_ligand_repres_nodes_loc_array) + bound_receptor_repres_nodes_loc_array_list.append(bound_receptor_repres_nodes_loc_array) + pocket_coors_list.append(pocket_coors) + + protein_to_graph_input = [(unbound_predic_ligand_list[i], + unbound_predic_receptor_list[i], + bound_ligand_repres_nodes_loc_array_list[i], + bound_receptor_repres_nodes_loc_array_list[i]) for i in + range(len(unbound_predic_ligand_list))] + log('Start protein_to_graph_unbound_bound') + + both_proteins_to_graph_pair_list = pmap_multi2(protein_to_graph_unbound_bound, + protein_to_graph_input, + n_jobs=args.n_jobs, + cutoff=args.graph_cutoff, + max_neighbor=args.graph_max_neighbor, + one_hot=False, + residue_loc_is_alphac=args.graph_residue_loc_is_alphaC + ) + + log('Done protein_to_graph_unbound_bound') + + ligand_graph_list, receptor_graph_list = [], [] + for result in both_proteins_to_graph_pair_list: + ligand_graph, receptor_graph = result + ligand_graph_list.append(ligand_graph) + receptor_graph_list.append(receptor_graph) + + ligand_graph_node_tensor_list, receptor_graph_node_tensor_list, 
ligand_graph_num_nodes_list = [], [], [] + receptor_graph_num_nodes_list, ligand_graph_edge_tensor_list, receptor_graph_edge_tensor_list = [], [], [] + ll_connection_tensor_list, rr_connection_tensor_list, bound_ligand_repres_nodes_loc_array_list_new = [], [], [] + pocket_coors_ligand_list_new, pocket_coors_receptor_list_new = [], [] + bound_receptor_repres_nodes_loc_array_list_new = [] + + for i, _ in enumerate(ligand_graph_list): + ligand_graph, receptor_graph = ligand_graph_list[i], receptor_graph_list[i] + bound_ligand_repres_nodes_loc_array = bound_ligand_repres_nodes_loc_array_list[i] + bound_receptor_repres_nodes_loc_array = bound_receptor_repres_nodes_loc_array_list[i] + pocket_coors_ligand, pocket_coors_receptor = pocket_coors_list[i], pocket_coors_list[i] + + # Randomly rotate and translate the ligand. + rot_t, rot_b = uniformrotation_translation(translation_interval=args.translation_interval) + + ligand_original_loc = ligand_graph.ndata['x'] + + mean_to_remove = ligand_original_loc.mean(axis=0, keep_dims=True) + + pocket_coors_ligand = (rot_t @ (pocket_coors_ligand - mean_to_remove).T).T + rot_b + ligand_new_loc = (rot_t @ (ligand_original_loc - mean_to_remove).T).T + rot_b + + for k in ligand_graph.ndata.keys(): + ligand_graph.ndata[k] = ligand_graph.ndata[k].asnumpy() + for k in receptor_graph.ndata.keys(): + receptor_graph.ndata[k] = receptor_graph.ndata[k].asnumpy() + for k in ligand_graph.edata.keys(): + ligand_graph.edata[k] = ligand_graph.edata[k].asnumpy() + for k in receptor_graph.edata.keys(): + receptor_graph.edata[k] = receptor_graph.edata[k].asnumpy() + ligand_graph.ndata['new_x'] = ligand_new_loc.astype(np.float32) + + ##### Create a batch of a single heterograph + ligand_graph_node_tensor = np.concatenate( + (ligand_graph.ndata["res_feat"], + ligand_graph.ndata["x"], + ligand_graph.ndata["new_x"], + ligand_graph.ndata["mu_r_norm"]), + axis=1 + ) + + receptor_graph_node_tensor = np.concatenate( + (receptor_graph.ndata["res_feat"], + receptor_graph.ndata["x"], + receptor_graph.ndata["mu_r_norm"]), + axis=1 + ) + + ligand_graph_num_nodes = [ligand_graph.ndata["x"].shape[0]] + receptor_graph_num_nodes = [receptor_graph.ndata["x"].shape[0]] + ligand_graph_edge_tensor = ligand_graph.edata['he'] + receptor_graph_edge_tensor = receptor_graph.edata['he'] + + ll_connection_tensor = np.stack( + (ligand_graph.src_list, ligand_graph.dst_list)) + + rr_connection_tensor = np.stack( + (receptor_graph.src_list, receptor_graph.dst_list)) + ##### Create a batch of a single heterograph + + ll_connection_tensor_list.append(ll_connection_tensor) + rr_connection_tensor_list.append(rr_connection_tensor) + ligand_graph_edge_tensor_list.append(ligand_graph_edge_tensor) + receptor_graph_edge_tensor_list.append(receptor_graph_edge_tensor) + ligand_graph_num_nodes_list.append(np.array(ligand_graph_num_nodes).astype(np.int32)) + receptor_graph_num_nodes_list.append(np.array(receptor_graph_num_nodes).astype(np.int32)) + ligand_graph_node_tensor_list.append(ligand_graph_node_tensor) + receptor_graph_node_tensor_list.append(receptor_graph_node_tensor) + bound_ligand_repres_nodes_loc_array_list_new.append(bound_ligand_repres_nodes_loc_array.astype(np.float32)) + bound_receptor_repres_nodes_loc_array_list_new.append( + bound_receptor_repres_nodes_loc_array.astype(np.float32)) + pocket_coors_ligand_list_new.append(pocket_coors_ligand.astype(np.float32)) + pocket_coors_receptor_list_new.append(pocket_coors_receptor.astype(np.float32)) + + np.savez(os.path.join(args.processed_dataset_path, 
reload_mode, "ll_connection_tensor_list.npz"), + *ll_connection_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "rr_connection_tensor_list.npz"), + *rr_connection_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_edge_tensor_list.npz"), + *ligand_graph_edge_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "receptor_graph_edge_tensor_list.npz"), + *receptor_graph_edge_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_num_nodes_list.npz"), + *ligand_graph_num_nodes_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "receptor_graph_num_nodes_list.npz"), + *receptor_graph_num_nodes_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, + "bound_ligand_repres_nodes_loc_array_list_new.npz"), + *bound_ligand_repres_nodes_loc_array_list_new) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, + "bound_receptor_repres_nodes_loc_array_list_new.npz"), + *bound_receptor_repres_nodes_loc_array_list_new) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "pocket_coors_ligand_list_new.npz"), + *pocket_coors_ligand_list_new) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "pocket_coors_receptor_list_new.npz"), + *pocket_coors_receptor_list_new) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "ligand_graph_node_tensor_list.npz"), + *ligand_graph_node_tensor_list) + np.savez(os.path.join(args.processed_dataset_path, reload_mode, "receptor_graph_node_tensor_list.npz"), + *receptor_graph_node_tensor_list) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py new file mode 100644 index 000000000..cedb95442 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py @@ -0,0 +1,38 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +import os +from ...dataset import PSP, data_process_run + +class EquiDockDataSet(PSP): + "EquiDockDataSet" + def __init__(self, config, seed=0): + self.config = config + + def process(self, data, **kwargs): + return + + def set_training_data_src(self, data_source, **kwargs): + return + + def create_iterator(self, num_epochs, **kwargs): + return [_ for _ in range(self.config.num_epochs)] diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py new file mode 100644 index 000000000..36516f725 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py @@ -0,0 +1,1580 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import math +from datetime import datetime as dt + +import ot +import numpy as np +from numpy import linalg as LA +from biopandas.pdb import PandasPdb +import scipy.spatial as spa +from scipy.special import softmax +import mindspore as ms +from mindspore import nn, ops, Tensor, Parameter + + +ATOM_NAME = 'atom_name' +FLAGS = os.O_RDWR | os.O_CREAT + + +class MeterUnboundBound(object): + def __init__(self): + self.complex_rmsd_list = [] + self.ligand_rmsd_list = [] + self.receptor_rmsd_list = [] + + def update_rmsd(self, ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true): + ligand_coors_pred = ligand_coors_pred.asnumpy() + receptor_coors_pred = receptor_coors_pred + + ligand_coors_true = ligand_coors_true + receptor_coors_true = receptor_coors_true + + ligand_rmsd = np.sqrt(np.mean(np.sum((ligand_coors_pred - ligand_coors_true) ** 2, axis=1))) + receptor_rmsd = np.sqrt(np.mean(np.sum((receptor_coors_pred - receptor_coors_true) ** 2, axis=1))) + + complex_coors_pred = np.concatenate((ligand_coors_pred, receptor_coors_pred), axis=0) + complex_coors_true = np.concatenate((ligand_coors_true, receptor_coors_true), axis=0) + + r, b = rigid_transform_kabsch_3d(complex_coors_pred.T, complex_coors_true.T) + complex_coors_pred_aligned = ((r @ complex_coors_pred.T) + b).T + + complex_rmsd = np.sqrt(np.mean(np.sum((complex_coors_pred_aligned - complex_coors_true) ** 2, axis=1))) + + self.complex_rmsd_list.append(complex_rmsd) + self.ligand_rmsd_list.append(ligand_rmsd) + self.receptor_rmsd_list.append(receptor_rmsd) + + return complex_rmsd + + def summarize(self, reduction_rmsd='median'): + if reduction_rmsd == 'mean': + complex_rmsd_array = np.array(self.complex_rmsd_list) + complex_rmsd_summarized = np.mean(complex_rmsd_array) + + ligand_rmsd_array = 
np.array(self.ligand_rmsd_list) + ligand_rmsd_summarized = np.mean(ligand_rmsd_array) + + receptor_rmsd_array = np.array(self.receptor_rmsd_list) + receptor_rmsd_summarized = np.mean(receptor_rmsd_array) + elif reduction_rmsd == 'median': + complex_rmsd_array = np.array(self.complex_rmsd_list) + complex_rmsd_summarized = np.median(complex_rmsd_array) + + ligand_rmsd_array = np.array(self.ligand_rmsd_list) + ligand_rmsd_summarized = np.median(ligand_rmsd_array) + + receptor_rmsd_array = np.array(self.receptor_rmsd_list) + receptor_rmsd_summarized = np.median(receptor_rmsd_array) + else: + raise ValueError("Meter_Unbound_Bound: reduction_rmsd mis specified!") + + return ligand_rmsd_summarized, receptor_rmsd_summarized, complex_rmsd_summarized + + def summarize_with_std(self, reduction_rmsd='median'): + complex_rmsd_array = np.array(self.complex_rmsd_list) + if reduction_rmsd == 'mean': + complex_rmsd_summarized = np.mean(complex_rmsd_array) + elif reduction_rmsd == 'median': + complex_rmsd_summarized = np.median(complex_rmsd_array) + else: + raise ValueError("Meter_Unbound_Bound: reduction_rmsd mis specified!") + + return complex_rmsd_summarized, np.std(complex_rmsd_array) + + +def create_dir(path): + if os.path.exists(path): + raise FileExistsError('Path already exists. Please delete and restart your job.') + else: + os.makedirs(path, exist_ok=False) + + +def log(*pargs): + banner = 'EQUIDOCK' + log_file_name = os.path.join('stdouterr/', banner + ".txt") + with os.fdopen(os.open(log_file_name, FLAGS, 777), 'a+') as w: + w.write('[' + str(dt.now()) + '] ') + w.write(" ".join(["{}".format(t) for t in pargs])) + w.write("\n") + print('[' + str(dt.now()) + '] ', *pargs) + + +def g_fn(protein_coords, x, sigma): + # protein_coords: (n,3) , x: (m,3), output: (m,) + e = ops.exp(- ops.sum((protein_coords.view(1, -1, 3) - x.view(-1, 1, 3)) ** 2, dim=2) / float(sigma)) # (m, n) + + return - sigma * ops.log(1e-3 + ops.sum(e, dim=1)) + + +def compute_body_intersection_loss( + model_ligand_coors_deform, + bound_receptor_repres_nodes_loc_array, + sigma, + surface_ct, +): + loss = ops.mean( + ops.clamp(surface_ct - g_fn(Tensor(bound_receptor_repres_nodes_loc_array), model_ligand_coors_deform, sigma), + min=0)) + \ + ops.mean(ops.clamp( + surface_ct - g_fn(model_ligand_coors_deform, Tensor(bound_receptor_repres_nodes_loc_array), sigma), + min=0)) + + return loss + + +def pretty_print_stats(*args): + + split_type, epoch, total_num_epochs,\ + complex_rmsd_mean, complex_rmsd_median,\ + ligand_rmsd_mean, ligand_rmsd_median,\ + receptor_rmsd_mean, receptor_rmsd_median,\ + avg_loss, avg_loss_ligand_coors,\ + avg_loss_receptor_coors, avg_loss_ot, avg_loss_intersection, args_input = args + + log('[{:s}] --> epoch {:d}/{:d} ' + '|| mean/median complex rmsd {:.4f} / {:.4f} ' + '|| mean/median ligand rmsd {:.4f} / {:.4f} ' + '|| mean/median sqrt pocket OT loss {:.4f} ' + '|| intersection loss {:.4f} ' + '|| mean/median receptor rmsd {:.4f} / {:.4f} '. 
+ format(split_type, + epoch, total_num_epochs, + complex_rmsd_mean, complex_rmsd_median, + ligand_rmsd_mean, ligand_rmsd_median, + math.sqrt(avg_loss_ot), + avg_loss_intersection, + receptor_rmsd_mean, receptor_rmsd_median)) + + +def load_data(files_dir): + + npz_files_list = [ + np.load(files_dir + "/ll_connection_tensor_list.npz"), + np.load(files_dir + "/rr_connection_tensor_list.npz"), + np.load(files_dir + "/ligand_graph_edge_tensor_list.npz"), + np.load(files_dir + "/receptor_graph_edge_tensor_list.npz"), + np.load(files_dir + "/ligand_graph_num_nodes_list.npz"), + np.load(files_dir + "/receptor_graph_num_nodes_list.npz"), + np.load(files_dir + "/ligand_graph_node_tensor_list.npz"), + np.load(files_dir + "/receptor_graph_node_tensor_list.npz"), + np.load(files_dir + "/bound_ligand_repres_nodes_loc_array_list_new.npz"), + np.load(files_dir + "/bound_receptor_repres_nodes_loc_array_list_new.npz"), + np.load(files_dir + "/pocket_coors_ligand_list_new.npz"), + np.load(files_dir + "/pocket_coors_receptor_list_new.npz"), + ] + + extracted_files = [] + for _, npz_file in enumerate(npz_files_list): + extracted_files.append([npz_file[id] for id in npz_file.files]) + + unbatch_list = [] + for i, _ in enumerate(extracted_files[0]): + temp_length = [] + temp_length.append(len(extracted_files[0][i][0])) + temp_length.append(len(extracted_files[1][i][0])) + temp_length.append(len(extracted_files[6][i])) + temp_length.append(len(extracted_files[7][i])) + temp_length.append(len(extracted_files[11][i])) + unbatch_list.append(np.array(temp_length)) + extracted_files.append(unbatch_list) + + index_shuffle = np.arange(len(npz_files_list[0])) + + return extracted_files, index_shuffle + + +def shuffle_list(source_list, index_shuffle): + shuffled_list = [] + for _, idx in enumerate(index_shuffle): + shuffled_list.append(source_list[idx]) + + return shuffled_list + + +def batch(bs, souce_list): + output_list = [] + num = len(souce_list) // bs + for i in range(num): + output_list.append(souce_list[i * bs: (i + 1) * bs]) + if num * bs < len(souce_list): + output_list.append(souce_list[num * bs:]) + + return output_list + + +def shuffle_batch_dataset(input_dataset, index_shuffle, batch_size, shuffle): + if shuffle: + shuffled_dataset = [] + np.random.shuffle(index_shuffle) + for i, _ in enumerate(input_dataset): + shuffled_dataset.append(shuffle_list(input_dataset[i], index_shuffle)) + else: + shuffled_dataset = input_dataset[:] + + dataset_batched = [] + for _, data in enumerate(shuffled_dataset): + dataset_batched.append(batch(batch_size, data)) + unbatch_list = dataset_batched[-1] + for j, _ in enumerate(unbatch_list): + unbatch_list[j].insert(0, np.array([0, 0, 0, 0, 0])) + for i, _ in enumerate(unbatch_list[j]): + if i >= 1: + unbatch_list[j][i][0] += unbatch_list[j][i - 1][0] + unbatch_list[j][i][1] += unbatch_list[j][i - 1][1] + unbatch_list[j][i][2] += unbatch_list[j][i - 1][2] + unbatch_list[j][i][3] += unbatch_list[j][i - 1][3] + unbatch_list[j][i][4] += unbatch_list[j][i - 1][4] + dataset_batched[-1] = unbatch_list + + return dataset_batched + + +def cat_properties(input_dataset_batched): + dataset_batch_cat = [] + for i in range(len(input_dataset_batched[0])): + dataset_batch_cat.append( + [ + np.concatenate(input_dataset_batched[0][i], axis=1), + np.concatenate(input_dataset_batched[1][i], axis=1), + np.concatenate(input_dataset_batched[2][i], axis=0), + np.concatenate(input_dataset_batched[3][i], axis=0), + np.concatenate(input_dataset_batched[4][i], axis=0), + 
np.concatenate(input_dataset_batched[5][i], axis=0), + np.concatenate(input_dataset_batched[6][i], axis=0), + np.concatenate(input_dataset_batched[7][i], axis=0), + input_dataset_batched[-1][i], + ] + ) + + return dataset_batch_cat + + +def create_dataloader(dataset_dir, bs, shuffle): + dataset, index_shuffle = load_data(dataset_dir) + dataset_batched = shuffle_batch_dataset( + input_dataset=dataset, + index_shuffle=index_shuffle, + batch_size=bs, + shuffle=shuffle, + ) + dataloader = cat_properties(dataset_batched) + + return dataset_batched, dataloader + + +def compute_sq_dist_mat(x_1, x_2): + '''Computes the l2 squared cost matrix between two point cloud inputs. + Args: + X_1: [n, #features] point cloud, tensor + X_2: [m, #features] point cloud, tensor + Output: + [n, m] matrix of the l2 distance between point pairs + ''' + n_1, _ = x_1.shape + n_2, _ = x_2.shape + x_1 = x_1.view(n_1, 1, -1) + x_2 = x_2.view(1, n_2, -1) + squared_dist = (x_1 - x_2) ** 2 + cost_mat = ops.sum(squared_dist, dim=2) + return cost_mat + + +def compute_ot_emd(cost_mat): + cost_mat_detach = cost_mat.asnumpy() + a = np.ones([cost_mat.shape[0]]) / cost_mat.shape[0] + b = np.ones([cost_mat.shape[1]]) / cost_mat.shape[1] + ot_mat = ot.emd(a=a, b=b, M=cost_mat_detach, numItermax=10000) + ot_mat_attached = Parameter(Tensor(ot_mat, ms.float32), requires_grad=False) + ot_dist = ops.sum(ot_mat_attached * cost_mat) + return ot_dist, ot_mat_attached + + +def get_nodes_coors_numpy(filename, all_atoms=False): + df = PandasPdb().read_pdb(filename).df['ATOM'] + if not all_atoms: + return Tensor( + df[df['atom_name'] == 'CA'][['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)) + return Tensor(df[['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32)) + + +def get_residues(pdb_filename): + df = PandasPdb().read_pdb(pdb_filename).df['ATOM'] + df.rename(columns={'chain_id': 'chain', 'residue_number': 'residue', 'residue_name': 'resname', + 'x_coord': 'x', 'y_coord': 'y', 'z_coord': 'z', 'element_symbol': 'element'}, inplace=True) + residues = list(df.groupby(['chain', 'residue', 'resname'])) # Not the same as sequence order ! 
+ return residues + + +def prepare_graphs(args, ppdb_ligand, ligand_filename, receptor_filename): + unbound_ligand_all_atoms_pre_pos = ppdb_ligand.df["ATOM"][ + ['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32) + + unbound_predic_ligand, \ + unbound_predic_receptor, \ + bound_ligand_repres_nodes_loc_clean_array, \ + bound_receptor_repres_nodes_loc_clean_array, _ = preprocess_unbound_bound( + get_residues(ligand_filename), get_residues(receptor_filename), + graph_nodes=args.graph_nodes, pos_cutoff=args.pocket_cutoff, inference=True) + + protein_to_graph_unbound_bound_input_tuple = ( + unbound_predic_ligand, unbound_predic_receptor, + bound_ligand_repres_nodes_loc_clean_array, bound_receptor_repres_nodes_loc_clean_array, + ) + + ligand_graph, receptor_graph = protein_to_graph_unbound_bound( + protein_to_graph_unbound_bound_input_tuple, + cutoff=args.graph_cutoff, + max_neighbor=args.graph_max_neighbor, + one_hot=False, + residue_loc_is_alphac=args.graph_residue_loc_is_alphaC, + ) + + return ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, bound_ligand_repres_nodes_loc_clean_array + + +def graph_to_tensor(ligand_graph, receptor_graph): + ligand_graph.ndata['new_x'] = ligand_graph.ndata['x'] + + # Create a batch of a single heterograph + ligand_graph_node_tensor = ops.cat( + (ligand_graph.ndata["res_feat"], + ligand_graph.ndata["x"], + ligand_graph.ndata["new_x"], + ligand_graph.ndata["mu_r_norm"]), + axis=1 + ) + + receptor_graph_node_tensor = ops.cat( + (receptor_graph.ndata["res_feat"], + receptor_graph.ndata["x"], + receptor_graph.ndata["mu_r_norm"]), + axis=1 + ) + + ligand_graph_num_nodes = Tensor([ligand_graph.num_nodes]) + receptor_graph_num_nodes = Tensor([receptor_graph.num_nodes]) + ligand_graph_edge_tensor = ligand_graph.edata["he"] + receptor_graph_edge_tensor = receptor_graph.edata["he"] + + ll_connection_tensor = ops.stack( + (Tensor(ligand_graph.src_list), Tensor(ligand_graph.dst_list))) + + rr_connection_tensor = ops.stack( + (Tensor(receptor_graph.src_list), Tensor(receptor_graph.dst_list))) + + unbatch_list = [ + [ + len(ll_connection_tensor[0]), + len(rr_connection_tensor[0]), + len(ligand_graph_node_tensor), + len(receptor_graph_node_tensor), + 0, + ] + ] + unbatch_list.insert(0, np.array([0, 0, 0, 0, 0])) + + unbatch_list = Tensor(unbatch_list, ms.int32) + + input_tensor_tuple = ( + ligand_graph_num_nodes, + receptor_graph_num_nodes, + ligand_graph_edge_tensor, + receptor_graph_edge_tensor, + ll_connection_tensor, + rr_connection_tensor, + ) + + return ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, input_tensor_tuple + + +def get_rot_mat(euler_angles): + roll = euler_angles[0] + yaw = euler_angles[1] + pitch = euler_angles[2] + + tensor_0 = Tensor(0.0) + tensor_1 = Tensor(1.0) + cos = ops.cos + sin = ops.sin + + rx = ops.stack([ + ops.stack([tensor_1, tensor_0, tensor_0]), + ops.stack([tensor_0, cos(roll), -sin(roll)]), + ops.stack([tensor_0, sin(roll), cos(roll)])]).reshape(3, 3) + + ry = ops.stack([ + ops.stack([cos(pitch), tensor_0, sin(pitch)]), + ops.stack([tensor_0, tensor_1, tensor_0]), + ops.stack([-sin(pitch), tensor_0, cos(pitch)])]).reshape(3, 3) + + rz = ops.stack([ + ops.stack([cos(yaw), -sin(yaw), tensor_0]), + ops.stack([sin(yaw), cos(yaw), tensor_0]), + ops.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3) + + r = ops.mm(rz, ry) + r = ops.mm(r, rx) + + return r + + +def distance_list_featurizer(dist_list): + length_scale_list = [1.5 ** x for x in range(15)] + center_list = [0. 
for _ in range(15)] + + num_edge = len(dist_list) + dist_list = np.array(dist_list) + + transformed_dist = [np.exp(- ((dist_list - center) ** 2) / float(length_scale)) + for length_scale, center in zip(length_scale_list, center_list)] + + transformed_dist = np.array(transformed_dist).T + transformed_dist = transformed_dist.reshape((num_edge, -1)) + + processed_features = dict() + processed_features['he'] = Tensor(transformed_dist.astype(np.float32)) + return processed_features + + +def rigid_transform_kabsch_3d(a, b): + if a.shape[1] != b.shape[1]: + raise ValueError("a.shape[1] not equals b.shape[1]") + num_rows, num_cols = a.shape + if num_rows != 3: + raise Exception(f"matrix A is not 3xN, it is {num_rows}x{num_cols}") + num_rows, num_cols = b.shape + if num_rows != 3: + raise Exception(f"matrix B is not 3xN, it is {num_rows}x{num_cols}") + + # find mean column wise: 3 x 1 + centroid_a = np.mean(a, axis=1, keepdims=True) + centroid_b = np.mean(b, axis=1, keepdims=True) + + # subtract mean + am = a - centroid_a + bm = b - centroid_b + + h = am @ bm.T + + # find rotation + u, s, vt = np.linalg.svd(h) + + r = vt.T @ u.T + + # special reflection case + if np.linalg.det(r) < 0: + ss = np.diag([1., 1., -1.]) + r = (vt.T @ ss) @ u.T + if math.fabs(np.linalg.det(r) - 1) >= 1e-5: + raise ValueError("The value should be smaller than 1e-5") + + t = -r @ centroid_a + centroid_b + return r, t + + +def one_hot_encoding(x, allowable_set, encode_unknown=False): + if encode_unknown and (allowable_set[-1] is not None): + allowable_set.append(None) + + if encode_unknown and (x not in allowable_set): + x = None + + return list(map(lambda s: x == s, allowable_set)) + + +def residue_list_featurizer_dips_one_hot(predic): + residue_list = [term[1]['resname'].iloc[0] for term in predic] + feature_list = [residue_type_one_hot_dips(residue) for residue in residue_list] + feature_list = np.stack(feature_list) + processed_features = dict() + processed_features['res_feat'] = Tensor(feature_list.astype(np.float32)) + return processed_features + + +def residue_list_featurizer_dips_not_one_hot(predic): + residue_list = [term[1]['resname'].iloc[0] for term in predic] + feature_list = [[residue_type_one_hot_dips_not_one_hot(residue)] for residue in residue_list] + feature_list = np.array(feature_list) + processed_features = dict() + processed_features['res_feat'] = Tensor(feature_list.astype(np.float32)) # (N_res, 1) + return processed_features + + +def residue_type_one_hot_dips(residue): + dit = { + 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', + 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', + 'PRO': 'P', 'SER': 'S', 'THR': 'T', 'TRP': 'W', 'TYR': 'Y', 'VAL': 'V', + 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y', + 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D', + 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G', + } + allowable_set = [ + 'Y', 'R', 'F', 'G', 'I', 'V', 'A', 'W', 'E', 'H', 'C', 'N', 'M', 'D', 'T', 'S', 'K', 'L', 'Q', 'P' + ] + res_name = residue + if res_name not in dit.keys(): + res_name = None + else: + res_name = dit[res_name] + return one_hot_encoding(res_name, allowable_set, encode_unknown=True) + + +def residue_type_one_hot_dips_not_one_hot(residue): + dit = { + 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', + 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', + 'PRO': 'P', 'SER': 'S', 'THR': 'T', 
'TRP': 'W', 'TYR': 'Y', 'VAL': 'V', + 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y', + 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D', + 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G' + } + + rare_residues = { + 'HIP': 'H', 'HIE': 'H', 'TPO': 'T', 'HID': 'H', 'LEV': 'L', 'MEU': 'M', 'PTR': 'Y', + 'GLV': 'E', 'CYT': 'C', 'SEP': 'S', 'HIZ': 'H', 'CYM': 'C', 'GLM': 'E', 'ASQ': 'D', + 'TYS': 'Y', 'CYX': 'C', 'GLZ': 'G' + } + + if residue in rare_residues.keys(): + log('Some rare residue: ', residue) + + indicator = { + 'Y': 0, 'R': 1, 'F': 2, 'G': 3, 'I': 4, 'V': 5, + 'A': 6, 'W': 7, 'E': 8, 'H': 9, 'C': 10, 'N': 11, + 'M': 12, 'D': 13, 'T': 14, 'S': 15, 'K': 16, 'L': 17, 'Q': 18, 'P': 19 + } + res_name = residue + if res_name not in dit.keys(): + return 20 + else: + res_name = dit[res_name] + return indicator.get(res_name) + + +def preprocess_unbound_bound(bound_ligand_residues, bound_receptor_residues, graph_nodes, pos_cutoff=8.0, + inference=False): + + ####################### + def filter_residues(residues): + residues_filtered = [] + for residue in residues: + df = residue[1] + natom = df[df[ATOM_NAME] == 'N'] + alphacatom = df[df[ATOM_NAME] == 'CA'] + catom = df[df[ATOM_NAME] == 'C'] + + if natom.shape[0] == 1 and alphacatom.shape[0] == 1 and catom.shape[0] == 1: + residues_filtered.append(residue) + return residues_filtered + + ########################## + + bound_predic_ligand_filtered = filter_residues(bound_ligand_residues) + unbound_predic_ligand_filtered = bound_predic_ligand_filtered + + bound_predic_receptor_filtered = filter_residues(bound_receptor_residues) + unbound_predic_receptor_filtered = bound_predic_receptor_filtered + + bound_predic_ligand_clean_list = bound_predic_ligand_filtered + unbound_predic_ligand_clean_list = unbound_predic_ligand_filtered + + bound_predic_receptor_clean_list = bound_predic_receptor_filtered + unbound_predic_receptor_clean_list = unbound_predic_receptor_filtered + + ################### + def get_alphac_loc_array(bound_predic_clean_list): + bound_alphac_loc_clean_list = [] + for residue in bound_predic_clean_list: + df = residue[1] + alphacatom = df[df[ATOM_NAME] == 'CA'] + alphac_loc = alphacatom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32) + bound_alphac_loc_clean_list.append(alphac_loc) + if len(bound_alphac_loc_clean_list) <= 1: + bound_alphac_loc_clean_list.append(np.zeros(3)) + return np.stack(bound_alphac_loc_clean_list, axis=0) # (N_res,3) + + #################### + if graph_nodes != 'residues': + raise TypeError("graph_nodes should be residues") + bound_receptor_repres_nodes_loc_array = get_alphac_loc_array(bound_predic_receptor_clean_list) + bound_ligand_repres_nodes_loc_array = get_alphac_loc_array(bound_predic_ligand_clean_list) + + if not inference: + + # Keep pairs of ligand and receptor residues/atoms that have pairwise distances < threshold + ligand_receptor_distance = spa.distance.cdist(bound_ligand_repres_nodes_loc_array, + bound_receptor_repres_nodes_loc_array) + positive_tuple = np.where(ligand_receptor_distance < pos_cutoff) + active_ligand = positive_tuple[0] + active_receptor = positive_tuple[1] + if active_ligand.size <= 3: # We need: active_ligand.size > 0 ' + pocket_coors = None # Will be filtered out later + else: + ligand_pocket_coors = bound_ligand_repres_nodes_loc_array[active_ligand, :] + receptor_pocket_coors = bound_receptor_repres_nodes_loc_array[active_receptor, :] + if np.max(np.linalg.norm(ligand_pocket_coors - receptor_pocket_coors, 
axis=1)) > pos_cutoff: + raise ValueError("The value should <= pos_cutoff") + pocket_coors = 0.5 * (ligand_pocket_coors + receptor_pocket_coors) + log('Num pocket nodes = ', len(active_ligand), ' total nodes = ', + bound_ligand_repres_nodes_loc_array.shape[0], ' graph_nodes = ', graph_nodes) + + return unbound_predic_ligand_clean_list, unbound_predic_receptor_clean_list, \ + bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, \ + pocket_coors + + return unbound_predic_ligand_clean_list, unbound_predic_receptor_clean_list, \ + bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, 0 + + +def protein_to_graph_unbound_bound( + unbound_bound_tuple, + cutoff=20, + max_neighbor=None, + one_hot=False, + residue_loc_is_alphac=True, +): + return protein_to_graph_unbound_bound_residuesonly( + unbound_bound_tuple, + cutoff, + max_neighbor, + one_hot, + residue_loc_is_alphac + ) + + +def protein_to_graph_unbound_bound_residuesonly( + unbound_bound_tuple, + cutoff=20, + max_neighbor=None, + one_hot=False, + residue_loc_is_alphac=True +): + + unbound_ligand_predic, unbound_receptor_predic, bound_ligand_repres_nodes_loc_clean_array, \ + bound_receptor_repres_nodes_loc_clean_array = unbound_bound_tuple + + ################## Extract 3D coordinates and n_i,u_i,v_i vectors of representative residues ################ + def l_or_r_extract_3d_coord_and_n_u_v_vecs(l_or_r_predic): + l_or_r_all_atom_coords_in_residue_list = [] + l_or_r_residue_representatives_loc_list = [] + l_or_r_n_i_list = [] + l_or_r_u_i_list = [] + l_or_r_v_i_list = [] + + for residue in l_or_r_predic: + df = residue[1] + coord = df[['x', 'y', 'z']].to_numpy().astype(np.float32) # (N_atoms, 3) + l_or_r_all_atom_coords_in_residue_list.append(coord) + + natom = df[df[ATOM_NAME] == 'N'] + alphacatom = df[df[ATOM_NAME] == 'CA'] + catom = df[df[ATOM_NAME] == 'C'] + + if natom.shape[0] != 1 or alphacatom.shape[0] != 1 or catom.shape[0] != 1: + raise ValueError("protein utils protein_to_graph_unbound_bound, no N/CA/C exists") + + n_loc = natom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32) + alphac_loc = alphacatom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32) + c_loc = catom[['x', 'y', 'z']].to_numpy().squeeze().astype(np.float32) + + u_i = (n_loc - alphac_loc) / LA.norm(n_loc - alphac_loc) + t_i = (c_loc - alphac_loc) / LA.norm(c_loc - alphac_loc) + n_i = np.cross(u_i, t_i) / LA.norm(np.cross(u_i, t_i)) + v_i = np.cross(n_i, u_i) + + l_or_r_n_i_list.append(n_i) + l_or_r_u_i_list.append(u_i) + l_or_r_v_i_list.append(v_i) + + if residue_loc_is_alphac: + l_or_r_residue_representatives_loc_list.append(alphac_loc) + else: + heavy_df = df[df['element'] != 'H'] + residue_loc = heavy_df[['x', 'y', 'z']].mean(axis=0).to_numpy().astype( + np.float32) # average of all atom coordinates + l_or_r_residue_representatives_loc_list.append(residue_loc) + + l_or_r_residue_representatives_loc_feat = np.stack(l_or_r_residue_representatives_loc_list, + axis=0) # (N_res, 3) + l_or_r_n_i_feat = np.stack(l_or_r_n_i_list, axis=0) + l_or_r_u_i_feat = np.stack(l_or_r_u_i_list, axis=0) + l_or_r_v_i_feat = np.stack(l_or_r_v_i_list, axis=0) + + l_or_r_num_residues = len(l_or_r_predic) + if l_or_r_num_residues <= 1: + raise ValueError(f"l_or_r contains only 1 residue!") + return l_or_r_all_atom_coords_in_residue_list, \ + l_or_r_residue_representatives_loc_feat, \ + l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat, l_or_r_num_residues + + (ligand_all_atom_coords_in_residue_list, # list of (N_atoms,3) arrays, for 
each residue + ligand_residue_representatives_loc_feat, # (N_res, 3) + ligand_n_i_feat, # (N_res, 3) + ligand_u_i_feat, # (N_res, 3) + ligand_v_i_feat, # (N_res, 3) + ligand_num_residues) = l_or_r_extract_3d_coord_and_n_u_v_vecs(unbound_ligand_predic) + + (receptor_all_atom_coords_in_residue_list, + receptor_residue_representatives_loc_feat, + receptor_n_i_feat, + receptor_u_i_feat, + receptor_v_i_feat, + receptor_num_residues) = l_or_r_extract_3d_coord_and_n_u_v_vecs(unbound_receptor_predic) + + ################# Align unbound and bound structures, if needed ################################ + def l_or_r_align_unbound_and_bound(l_or_r_residue_representatives_loc_feat, + l_or_r_n_i_feat, + l_or_r_u_i_feat, + l_or_r_v_i_feat, + bound_l_or_r_alphac_loc_clean_array): + + ret_r_l_or_r, ret_t_l_or_r = rigid_transform_kabsch_3d(l_or_r_residue_representatives_loc_feat.T, + bound_l_or_r_alphac_loc_clean_array.T) + l_or_r_residue_representatives_loc_feat = ((ret_r_l_or_r @ (l_or_r_residue_representatives_loc_feat).T) + + ret_t_l_or_r).T + l_or_r_n_i_feat = ((ret_r_l_or_r @ (l_or_r_n_i_feat).T)).T + l_or_r_u_i_feat = ((ret_r_l_or_r @ (l_or_r_u_i_feat).T)).T + l_or_r_v_i_feat = ((ret_r_l_or_r @ (l_or_r_v_i_feat).T)).T + return l_or_r_residue_representatives_loc_feat, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat + + (ligand_residue_representatives_loc_feat, + ligand_n_i_feat, + ligand_u_i_feat, + ligand_v_i_feat) = l_or_r_align_unbound_and_bound(ligand_residue_representatives_loc_feat, + ligand_n_i_feat, ligand_u_i_feat, ligand_v_i_feat, + bound_ligand_repres_nodes_loc_clean_array) + (receptor_residue_representatives_loc_feat, + receptor_n_i_feat, + receptor_u_i_feat, + receptor_v_i_feat) = l_or_r_align_unbound_and_bound(receptor_residue_representatives_loc_feat, + receptor_n_i_feat, receptor_u_i_feat, receptor_v_i_feat, + bound_receptor_repres_nodes_loc_clean_array) + + ################### Build the k-NN graph ############################## + def loop_edges(input_tuple): + + l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat, \ + l_or_r_residue_representatives_loc_feat, l_or_r_protein_graph, l_or_r_mean_norm_list = input_tuple + + # Loop over all edges of the graph and build the various p_ij, q_ij, k_ij, t_ij pairs + l_or_r_edge_feat_ori_list = [] + for i in range(len(l_or_r_dist_list)): + src = l_or_r_src_list[i] + dst = l_or_r_dst_list[i] + + # place n_i, u_i, v_i as lines in a 3x3 basis matrix + basis_matrix = np.stack((l_or_r_n_i_feat[dst, :], l_or_r_u_i_feat[dst, :], l_or_r_v_i_feat[dst, :]), axis=0) + p_ij = np.matmul(basis_matrix, + l_or_r_residue_representatives_loc_feat[src, :] - l_or_r_residue_representatives_loc_feat[ + dst, :]) + q_ij = np.matmul(basis_matrix, l_or_r_n_i_feat[src, :]) # shape (3,) + k_ij = np.matmul(basis_matrix, l_or_r_u_i_feat[src, :]) + t_ij = np.matmul(basis_matrix, l_or_r_v_i_feat[src, :]) + s_ij = np.concatenate((p_ij, q_ij, k_ij, t_ij), axis=0) # shape (12,) + l_or_r_edge_feat_ori_list.append(s_ij) + l_or_r_edge_feat_ori_feat = np.stack(l_or_r_edge_feat_ori_list, axis=0) # shape (num_edges, 4, 3) + l_or_r_edge_feat_ori_feat = Tensor(l_or_r_edge_feat_ori_feat.astype(np.float32)) + l_or_r_protein_graph.edata['he'] = Tensor( + np.concatenate((l_or_r_protein_graph.edata['he'].asnumpy(), l_or_r_edge_feat_ori_feat.asnumpy()), axis=1) + ) + + l_or_r_residue_representatives_loc_feat = Tensor( + l_or_r_residue_representatives_loc_feat.astype(np.float32)) + l_or_r_protein_graph.ndata['x'] = 
l_or_r_residue_representatives_loc_feat + l_or_r_protein_graph.ndata['mu_r_norm'] = Tensor( + np.array(l_or_r_mean_norm_list).astype(np.float32)) + + return l_or_r_protein_graph + + + def compute_dig_knn_graph(input_tuple): + + l_or_r_num_residues, l_or_r_all_atom_coords_in_residue_list, unbound_l_or_r_predic,\ + l_or_r_residue_representatives_loc_feat, l_or_r_n_i_feat, l_or_r_u_i_feat, l_or_r_v_i_feat = input_tuple + + l_or_r_distance = np.full((l_or_r_num_residues, l_or_r_num_residues), np.inf) + + for i in range(l_or_r_num_residues - 1): + for j in range((i + 1), l_or_r_num_residues): + l_or_r_pairwise_dis = spa.distance.cdist(l_or_r_all_atom_coords_in_residue_list[i], + l_or_r_all_atom_coords_in_residue_list[j]) + l_or_r_distance[i, j] = np.mean(l_or_r_pairwise_dis) + l_or_r_distance[j, i] = np.mean(l_or_r_pairwise_dis) + + l_or_r_protein_graph = Graph(num_nodes=l_or_r_num_residues) + + l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_mean_norm_list = [], [], [], [] + + for i in range(l_or_r_num_residues): + valid_src = list(np.where(l_or_r_distance[i, :] < cutoff)[0]) + if len(valid_src) > max_neighbor: + valid_src = list(np.argsort(l_or_r_distance[i, :]))[0: max_neighbor] + valid_dst = [i] * len(valid_src) + l_or_r_dst_list.extend(valid_dst) + l_or_r_src_list.extend(valid_src) + + valid_dist = list(l_or_r_distance[i, valid_src]) + l_or_r_dist_list.extend(valid_dist) + + valid_dist_np = l_or_r_distance[i, valid_src] + sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1)) + weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1) # (sigma_num, neigh_num) + diff_vecs = l_or_r_residue_representatives_loc_feat[valid_dst, :] - l_or_r_residue_representatives_loc_feat[ + valid_src, :] # (neigh_num, 3) + mean_vec = weights.dot(diff_vecs) # (sigma_num, 3) + denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1)) # (sigma_num,) + mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator # (sigma_num,) + l_or_r_mean_norm_list.append(mean_vec_ratio_norm) + + l_or_r_protein_graph.add_edges(l_or_r_src_list, l_or_r_dst_list) + + if one_hot: + l_or_r_protein_graph.ndata = residue_list_featurizer_dips_one_hot(unbound_l_or_r_predic) + else: + l_or_r_protein_graph.ndata = residue_list_featurizer_dips_not_one_hot(unbound_l_or_r_predic) + + l_or_r_protein_graph.edata = distance_list_featurizer(l_or_r_dist_list) + + loop_edges_input = l_or_r_src_list, l_or_r_dst_list, l_or_r_dist_list, l_or_r_n_i_feat, l_or_r_u_i_feat, \ + l_or_r_v_i_feat, l_or_r_residue_representatives_loc_feat, l_or_r_protein_graph, l_or_r_mean_norm_list + + return loop_edges(loop_edges_input) + + + ligand_protein_graph_tuple = ( + ligand_num_residues, ligand_all_atom_coords_in_residue_list, unbound_ligand_predic, + ligand_residue_representatives_loc_feat, ligand_n_i_feat, ligand_u_i_feat, ligand_v_i_feat, + ) + ligand_protein_graph = compute_dig_knn_graph(ligand_protein_graph_tuple) + + receptor_protein_graph_tuple = ( + receptor_num_residues, receptor_all_atom_coords_in_residue_list, unbound_receptor_predic, + receptor_residue_representatives_loc_feat, receptor_n_i_feat, receptor_u_i_feat, receptor_v_i_feat, + ) + receptor_protein_graph = compute_dig_knn_graph(receptor_protein_graph_tuple) + + return ligand_protein_graph, receptor_protein_graph + + +def unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor, h_feats_ligand, coors_receptor, coors_ligand): + list_hetero_graph = [] + for i, _ in enumerate(unbatch_list_tensor): + if i < len(unbatch_list_tensor) - 1: + hetero_graph = { + 
"receptor_hv_iegmn_out": \ + h_feats_receptor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :], + "ligand_hv_iegmn_out": \ + h_feats_ligand[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :], + "receptor_x_iegmn_out": \ + coors_receptor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :], + "ligand_x_iegmn_out": \ + coors_ligand[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :], + } + list_hetero_graph.append(hetero_graph) + return list_hetero_graph + + +def get_non_lin(input_type, negative_slope): + if input_type == 'swish': + return nn.SiLU() + elif input_type == 'lkyrelu': + return nn.LeakyReLU(alpha=negative_slope) + else: + raise NotImplementedError + + +def get_layer_norm(layer_norm_type, dim): + if layer_norm_type == 'BN': + return nn.BatchNorm1d([dim]) + elif layer_norm_type == 'LN': + return nn.LayerNorm([dim], begin_norm_axis=1, begin_params_axis=1, + epsilon=1e-5) + else: + return nn.Identity() + + +def get_final_h_layer_norm(layer_norm_type, dim): + if layer_norm_type == 'BN': + return nn.BatchNorm1d(dim) + elif layer_norm_type == 'LN': + return nn.LayerNorm([dim], begin_norm_axis=1, begin_params_axis=1, epsilon=1e-5) + elif layer_norm_type == '0': + return nn.Identity() + else: + raise NotImplementedError + + +def apply_final_h_layer_norm(h, node_type, norm_type, norm_layer): + return norm_layer(h) + + +def compute_cross_attention(queries, keys, values, mask, cross_msgs): + # Compute cross attention + if not cross_msgs: + return queries * 0. + a = mask * ops.mm(queries, ops.transpose(keys, (1, 0))) - 1000. * (1. - mask) + a_x = ops.softmax(a) # i->j, NxM + attention_x = ops.mm(a_x, values) # (N,d) + return attention_x + + +def get_mask(ligand_batch_num_nodes, receptor_batch_num_nodes): + rows = sum(ligand_batch_num_nodes) + cols = sum(receptor_batch_num_nodes) + mask = ops.zeros((int(rows), int(cols))) + partial_l = 0 + partial_r = 0 + for l_n, r_n in zip(ligand_batch_num_nodes, receptor_batch_num_nodes): + mask[partial_l: partial_l + l_n, partial_r: partial_r + r_n] = 1 + partial_l = partial_l + l_n + partial_r = partial_r + r_n + return mask + + +class Graph: + def __init__( + self, + num_nodes=0, + num_edges=0, + ): + self.num_nodes = num_nodes + self.num_edges = num_edges + self.ndata = {} + self.edata = {} + self.src_list = [] + self.dst_list = [] + + def add_edges(self, src_list, dst_list): + if len(src_list) != len(dst_list): + raise NotImplementedError + self.src_list = src_list + self.dst_list = dst_list + self.num_edges = len(dst_list) + + +class IEGMNLayer(nn.Cell): + + def __init__(self, orig_h_feats_dim, h_feats_dim, out_feats_dim, fine_tune, args, log=None): + + super(IEGMNLayer, self).__init__() + + dropout = args.dropout + nonlin = args.nonlin + layer_norm = args.layer_norm + leakyrelu_neg_slope = args.leakyrelu_neg_slope + self.input_edge_feats_dim = args.input_edge_feats_dim + self.layer_norm_coors = args.layer_norm_coors + self.cross_msgs = args.cross_msgs + self.final_h_layer_norm = args.final_h_layer_norm + self.use_dist_in_layers = args.use_dist_in_layers + self.skip_weight_h = args.skip_weight_h + self.x_connection_init = args.x_connection_init + self.debug = args.debug + self.fine_tune = fine_tune + self.log = log + self.h_feats_dim = h_feats_dim + self.out_feats_dim = out_feats_dim + self.all_sigmas_dist = [1.5 ** x for x in range(15)] + self.orig_h_feats_dim = orig_h_feats_dim + + # EDGES + self.edge_mlp = self.edge_mlp_func(dropout, nonlin, layer_norm, leakyrelu_neg_slope) + + # NODES + self.node_norm = 
nn.Identity() + + self.att_mlp_q = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + get_non_lin(nonlin, leakyrelu_neg_slope), + ) + + self.att_mlp_k = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + get_non_lin(nonlin, leakyrelu_neg_slope), + ) + + self.att_mlp_v = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + ) + + self.node_mlp = self.node_mlp_func(dropout, nonlin, layer_norm, leakyrelu_neg_slope) + + self.final_h_layernorm_layer = get_final_h_layer_norm(self.final_h_layer_norm, self.out_feats_dim) + + ## The scalar weight to be multiplied by (x_i - x_j) + self.coors_mlp = self.coors_mlp_func(dropout, nonlin, leakyrelu_neg_slope) + + if self.fine_tune: + self.att_mlp_cross_coors_q = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + get_non_lin(nonlin, leakyrelu_neg_slope), + ) + self.att_mlp_cross_coors_k = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim, has_bias=False), + get_non_lin(nonlin, leakyrelu_neg_slope), + ) + self.att_mlp_cross_coors_v = nn.SequentialCell( + nn.Dense(h_feats_dim, h_feats_dim), + get_non_lin(nonlin, leakyrelu_neg_slope), + nn.Dense(h_feats_dim, 1), + ) + + def __repr__(self): + return "IEGMNLayer " + str(self.__dict__) + + def edge_mlp_func(self, dropout, nonlin, layer_norm, leakyrelu_neg_slope): + + edge_mlp = nn.SequentialCell( + nn.Dense((self.h_feats_dim * 2) + self.input_edge_feats_dim + + len(self.all_sigmas_dist), self.out_feats_dim), + nn.Dropout(dropout), + get_non_lin(nonlin, leakyrelu_neg_slope), + get_layer_norm(layer_norm, self.out_feats_dim), + nn.Dense(self.out_feats_dim, self.out_feats_dim), + ) + return edge_mlp + + def node_mlp_func(self, dropout, nonlin, layer_norm, leakyrelu_neg_slope): + + node_mlp = nn.SequentialCell( + nn.Dense(self.orig_h_feats_dim + 2 * self.h_feats_dim + self.out_feats_dim, self.h_feats_dim), + nn.Dropout(dropout), + get_non_lin(nonlin, leakyrelu_neg_slope), + get_layer_norm(layer_norm, self.h_feats_dim), + nn.Dense(self.h_feats_dim, self.out_feats_dim), + ) + return node_mlp + + def coors_mlp_func(self, dropout, nonlin, leakyrelu_neg_slope): + coors_mlp = nn.SequentialCell( + nn.Dense(self.out_feats_dim, self.out_feats_dim), + nn.Dropout(dropout), + get_non_lin(nonlin, leakyrelu_neg_slope), + get_layer_norm(self.layer_norm_coors, self.out_feats_dim), + nn.Dense(self.out_feats_dim, 1) + ) + return coors_mlp + + def x_rel_mag(self, nodes_ligand_x_now, ll_connection_tensor, nodes_receptor_x_now, rr_connection_tensor): + nodes_ligand_x_now_src = ops.index_select(nodes_ligand_x_now, 0, ll_connection_tensor[0]) + nodes_ligand_x_now_dst = ops.index_select(nodes_ligand_x_now, 0, ll_connection_tensor[1]) + ll_edge_x_rel = ops.sub(nodes_ligand_x_now_src, nodes_ligand_x_now_dst) + + nodes_receptor_x_now_src = ops.index_select(nodes_receptor_x_now, 0, rr_connection_tensor[0]) + nodes_receptor_x_now_dst = ops.index_select(nodes_receptor_x_now, 0, rr_connection_tensor[1]) + rr_edge_x_rel = ops.sub(nodes_receptor_x_now_src, nodes_receptor_x_now_dst) + + x_rel_mag_ligand = ll_edge_x_rel ** 2 + x_rel_mag_ligand = ops.sum(x_rel_mag_ligand, dim=1, keepdim=True) # ||x_i - x_j||^2 : (N_res, 1) + x_rel_mag_ligand = ops.cat([ops.exp(-x_rel_mag_ligand / sigma) for sigma in self.all_sigmas_dist], axis=-1) + + x_rel_mag_receptor = rr_edge_x_rel ** 2 + x_rel_mag_receptor = ops.sum(x_rel_mag_receptor, dim=1, keepdim=True) + x_rel_mag_receptor = ops.cat([ops.exp(-x_rel_mag_receptor / sigma) for sigma in self.all_sigmas_dist], 
axis=-1) + + return ll_edge_x_rel, rr_edge_x_rel, x_rel_mag_ligand, x_rel_mag_receptor + + def cat_input_for_msg(self, nodes_feat, original_edge_feats, connection_tensor, x_rel_mag): + nodes_feat_src = ops.index_select(nodes_feat, 0, connection_tensor[0]) + nodes_feat_dst = ops.index_select(nodes_feat, 0, connection_tensor[1]) + nodes_cat_feat = ops.cat((nodes_feat_src, nodes_feat_dst), axis=1) + cat_input_for_msg = ops.cat((nodes_cat_feat, # [h_i h_j] + original_edge_feats, + x_rel_mag), axis=-1) + return cat_input_for_msg + + def nodes_aggr_cross_msg(self, h_feats_ligand, h_feats_receptor, mask): + # \mu_i + nodes_ligand_aggr_cross_msg = compute_cross_attention(self.att_mlp_q(h_feats_ligand), + self.att_mlp_k(h_feats_receptor), + self.att_mlp_v(h_feats_receptor), + mask, + self.cross_msgs) + nodes_receptor_aggr_cross_msg = compute_cross_attention(self.att_mlp_q(h_feats_receptor), + self.att_mlp_k(h_feats_ligand), + self.att_mlp_v(h_feats_ligand), + mask.transpose(1, 0), + self.cross_msgs) + return nodes_ligand_aggr_cross_msg, nodes_receptor_aggr_cross_msg + + def cal_x_final(self, edges_x_moment, edges_msg, orig_coors, nodes_x_now): + nodes_x_update = ops.mean(ops.reshape(edges_x_moment, (-1, 10, 3)), axis=1) + nodes_aggr_msg = ops.mean(ops.reshape(edges_msg, (-1, 10, edges_msg.shape[-1])), axis=1) + x_final = self.x_connection_init * orig_coors + (1. - self.x_connection_init) * nodes_x_now + \ + nodes_x_update + return nodes_aggr_msg, x_final + + def fine_tune_final_lr(self, input_fine_tune_tuple): + x_final_ligand, x_final_receptor, h_feats_ligand, nodes_ligand_x_now, h_feats_receptor, \ + nodes_receptor_x_now, mask = input_fine_tune_tuple + + x_final_ligand = x_final_ligand + \ + self.att_mlp_cross_coors_v(h_feats_ligand) * ( + nodes_ligand_x_now - + compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_ligand), + self.att_mlp_cross_coors_k(h_feats_receptor), + nodes_receptor_x_now, + mask, + self.cross_msgs)) + x_final_receptor = x_final_receptor + \ + self.att_mlp_cross_coors_v(h_feats_receptor) * ( + nodes_receptor_x_now - + compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_receptor), + self.att_mlp_cross_coors_k(h_feats_ligand), + nodes_ligand_x_now, + mask.transpose(1, 0), + self.cross_msgs)) + return x_final_ligand, x_final_receptor + + def skip_connections(self, h_feats_ligand, h_feats_receptor, input_node_upd_ligand, input_node_upd_receptor): + # Skip connections + if self.h_feats_dim == self.out_feats_dim: + node_upd_ligand = self.skip_weight_h * self.node_mlp(input_node_upd_ligand) + ( + 1. - self.skip_weight_h) * h_feats_ligand + node_upd_receptor = self.skip_weight_h * self.node_mlp(input_node_upd_receptor) + ( + 1. 
- self.skip_weight_h) * h_feats_receptor + else: + node_upd_ligand = self.node_mlp(input_node_upd_ligand) + node_upd_receptor = self.node_mlp(input_node_upd_receptor) + + node_upd_ligand = apply_final_h_layer_norm(node_upd_ligand, 'ligand', self.final_h_layer_norm, + self.final_h_layernorm_layer) + node_upd_receptor = apply_final_h_layer_norm(node_upd_receptor, 'receptor', self.final_h_layer_norm, + self.final_h_layernorm_layer) + + return node_upd_ligand, node_upd_receptor + + def log_debug_info(self, debug_variable_tuple): + nodes_ligand_x_now, nodes_ligand_feat, ll_edge_x_rel, x_rel_mag_ligand, edges_ll_msg, \ + nodes_ligand_aggr_cross_msg, edge_coef_ligand, edges_ll_x_moment, nodes_ligand_aggr_msg, \ + x_final_ligand = debug_variable_tuple + self.log(ops.max(nodes_ligand_x_now)[0], 'x_now : x_i at layer entrance') + self.log(ops.max(nodes_ligand_feat)[0], 'data[feat] = h_i at layer entrance') + self.log(ops.max(ll_edge_x_rel)[0], 'x_rel : x_i - x_j') + self.log(ops.max(x_rel_mag_ligand, axis=0)[0], + 'x_rel_mag_ligand = [exp(-||x_i - x_j||^2 / sigma) for sigma = 1.5 ** x, x = [0, 15]]') + self.log(ops.max(edges_ll_msg)[0], 'data[msg] = m_{i->j} = phi^e(h_i, h_j, f_{i,j}, x_rel_mag_ligand)') + self.log(ops.max(nodes_ligand_aggr_cross_msg)[0], 'aggr_cross_msg(i) = sum_j a_{i,j} * h_j') + self.log(ops.max(edge_coef_ligand)[0], 'edge_coef_ligand : \phi^x(m_{i->j})') + self.log(ops.max(edges_ll_x_moment)[0], 'data[x_moment] = (x_i - x_j) * \phi^x(m_{i->j})') + self.log(ops.max(nodes_ligand_aggr_msg)[0], 'data[aggr_msg]: \sum_j m_{i->j} ') + self.log(ops.max(x_final_ligand)[0], 'x_i new = x_final_ligand : x_i + data[x_update]') + + def construct(self, iegmn_layer_tuple, input_tensor_tuple): + + ligand_graph_num_nodes, receptor_graph_num_nodes, ll_connection_tensor, rr_connection_tensor = \ + input_tensor_tuple[0], input_tensor_tuple[1], input_tensor_tuple[4], input_tensor_tuple[5] + + coors_ligand, h_feats_ligand, original_ligand_node_features, original_edge_feats_ligand, \ + orig_coors_ligand, coors_receptor, h_feats_receptor, original_receptor_node_features, \ + original_edge_feats_receptor, orig_coors_receptor = iegmn_layer_tuple + + nodes_ligand_x_now, nodes_receptor_x_now, nodes_ligand_feat, nodes_receptor_feat = \ + coors_ligand, coors_receptor, h_feats_ligand, h_feats_receptor + + ll_edge_x_rel, rr_edge_x_rel, x_rel_mag_ligand, x_rel_mag_receptor = self.x_rel_mag(nodes_ligand_x_now, \ + ll_connection_tensor, nodes_receptor_x_now, rr_connection_tensor) + + if not self.use_dist_in_layers: + x_rel_mag_ligand = x_rel_mag_ligand * 0. + x_rel_mag_receptor = x_rel_mag_receptor * 0. 
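+
+        # Intra-graph message passing: for every ligand-ligand and receptor-receptor edge
+        # (i -> j), the concatenation [h_i, h_j, f_ij, exp(-||x_i - x_j||^2 / sigma)] is fed
+        # through edge_mlp to produce the message m_{i->j}. coors_mlp maps each message to a
+        # scalar weight for (x_i - x_j), and cal_x_final averages these weighted displacements
+        # per node and adds them to a mix of the original and current coordinates controlled
+        # by x_connection_init. Cross-graph information enters separately through
+        # compute_cross_attention with the block mask from get_mask. Note that the reshape to
+        # (-1, 10, 3) in cal_x_final assumes exactly 10 incident edges per node, presumably
+        # the k-NN neighbour cap used when the graphs were built.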
+ + cat_input_for_msg_ligand = self.cat_input_for_msg( + nodes_ligand_feat, original_edge_feats_ligand, ll_connection_tensor, x_rel_mag_ligand) + cat_input_for_msg_receptor = self.cat_input_for_msg( + nodes_receptor_feat, original_edge_feats_receptor, rr_connection_tensor, x_rel_mag_receptor) + + edges_ll_msg = self.edge_mlp(cat_input_for_msg_ligand) # m_{i->j} + edges_rr_msg = self.edge_mlp(cat_input_for_msg_receptor) + + mask = get_mask(ligand_graph_num_nodes, receptor_graph_num_nodes) + + nodes_ligand_aggr_cross_msg, nodes_receptor_aggr_cross_msg = self.nodes_aggr_cross_msg(h_feats_ligand, + h_feats_receptor, mask) + edge_coef_ligand = self.coors_mlp(edges_ll_msg) # \phi^x(m_{i->j}) + edges_ll_x_moment = ll_edge_x_rel * edge_coef_ligand # (x_i - x_j) * \phi^x(m_{i->j}) + + edge_coef_receptor = self.coors_mlp(edges_rr_msg) + edges_rr_x_moment = rr_edge_x_rel * edge_coef_receptor + + nodes_ligand_aggr_msg, x_final_ligand = self.cal_x_final(edges_ll_x_moment, edges_ll_msg, + orig_coors_ligand, nodes_ligand_x_now) + nodes_receptor_aggr_msg, x_final_receptor = self.cal_x_final(edges_rr_x_moment, edges_rr_msg, + orig_coors_receptor, nodes_receptor_x_now) + + if self.fine_tune: + input_fine_tune = ( + x_final_ligand, x_final_receptor, h_feats_ligand, nodes_ligand_x_now, + h_feats_receptor, nodes_receptor_x_now, mask + ) + x_final_ligand, x_final_receptor = self.fine_tune_final_lr(input_fine_tune) + + if self.debug: + debug_variable_tuple = ( + nodes_ligand_x_now, nodes_ligand_feat, ll_edge_x_rel, x_rel_mag_ligand, edges_ll_msg, + nodes_ligand_aggr_cross_msg, edge_coef_ligand, edges_ll_x_moment, nodes_ligand_aggr_msg, x_final_ligand, + ) + self.log_debug_info(debug_variable_tuple) + + input_node_upd_ligand = ops.cat((self.node_norm(nodes_ligand_feat), nodes_ligand_aggr_msg, + nodes_ligand_aggr_cross_msg, original_ligand_node_features), axis=-1) + + input_node_upd_receptor = ops.cat((self.node_norm(nodes_receptor_feat), nodes_receptor_aggr_msg, + nodes_receptor_aggr_cross_msg, original_receptor_node_features), axis=-1) + + node_upd_ligand, node_upd_receptor = self.skip_connections(h_feats_ligand, h_feats_receptor, + input_node_upd_ligand, input_node_upd_receptor) + + return x_final_ligand, node_upd_ligand, x_final_receptor, node_upd_receptor + + +class IEGMN(nn.Cell): + + def __init__(self, args, n_lays, fine_tune, log=None): + + super(IEGMN, self).__init__() + + self.debug = args.debug + self.log = log + self.graph_nodes = args.graph_nodes + self.rot_model = args.rot_model + self.iegmn_lay_hid_dim = args.iegmn_lay_hid_dim + self.noise_decay_rate = args.noise_decay_rate + self.noise_initial = args.noise_initial + self.use_edge_features_in_gmn = args.use_edge_features_in_gmn + self.use_mean_node_features = args.use_mean_node_features + + # 21 types of amino-acid types + self.residue_emb_layer = nn.Embedding(vocab_size=21, embedding_size=args.residue_emb_dim, use_one_hot=False) + if self.graph_nodes != 'residues': + raise NotImplementedError + input_node_feats_dim = args.residue_emb_dim # One residue type + + if self.use_mean_node_features: + input_node_feats_dim += 5 # Additional features from mu_r_norm + + self.iegmn_layers = self.iegmn_layers_func(args, input_node_feats_dim, n_lays, fine_tune) + + if args.rot_model != 'kb_att': + raise ValueError("args rot_model should be kb_att") + + # Attention layers + self.num_att_heads = args.num_att_heads + self.out_feats_dim = self.iegmn_lay_hid_dim + + self.att_mlp_key_rot = nn.SequentialCell( + nn.Dense(self.out_feats_dim, self.num_att_heads * 
self.out_feats_dim, has_bias=False), + ) + self.att_mlp_query_rot = nn.SequentialCell( + nn.Dense(self.out_feats_dim, self.num_att_heads * self.out_feats_dim, has_bias=False), + ) + + self.mlp_h_mean_rot = nn.SequentialCell( + nn.Dense(self.out_feats_dim, self.out_feats_dim), + nn.Dropout(args.dropout), + get_non_lin(args.nonlin, args.leakyrelu_neg_slope), + ) + + def __repr__(self): + return "IEGMN " + str(self.__dict__) + + def iegmn_layers_func(self, args, input_node_feats_dim, n_lays, fine_tune): + + iegmn_layers = nn.CellList() + + iegmn_layers.append( + IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, + h_feats_dim=input_node_feats_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + fine_tune=fine_tune, + args=args, + log=self.log)) + + if args.shared_layers: + interm_lay = IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, + h_feats_dim=self.iegmn_lay_hid_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + args=args, + fine_tune=fine_tune, + log=self.log) + for _ in range(1, n_lays): + iegmn_layers.append(interm_lay) + + else: + for _ in range(1, n_lays): + iegmn_layers.append( + IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, + h_feats_dim=self.iegmn_lay_hid_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + args=args, + fine_tune=fine_tune, + log=self.log)) + + return iegmn_layers + + def att_weights_rot(self, h_feats, h_feats_att_mean_rot, d, z_coors, all_y_att_rot_list): + att_weights_rot = ops.softmax( + self.att_mlp_key_rot(h_feats).view(-1, self.num_att_heads, d) \ + .transpose(1, 0, 2) @ # (K_heads, m_rec, d) + self.att_mlp_query_rot(h_feats_att_mean_rot) \ + .view(1, self.num_att_heads, d).transpose(1, 2, 0) / # (K_heads, d, 1) + math.sqrt(d), # (K_heads, m_receptor, 1) + axis=1).view(self.num_att_heads, -1) + + y_att_rot = att_weights_rot @ z_coors # K_heads, 3 + all_y_att_rot_list.append(y_att_rot) + + return y_att_rot, all_y_att_rot_list + + def ap_compute(self, list_hetero_graph): + + all_t_align_list, all_b_align_list, all_y_receptor_att_rot_list, all_y_ligand_att_rot_list = [], [], [], [] + + for _, hetero_graph in enumerate(list_hetero_graph): + + # Get H vectors + h_receptor_feats = hetero_graph["receptor_hv_iegmn_out"] # (m, d) + h_receptor_feats_att_mean_rot = ops.mean(self.mlp_h_mean_rot(h_receptor_feats), axis=0, + keep_dims=True) # (1, d) + + h_ligand_feats = hetero_graph["ligand_hv_iegmn_out"] # (n, d) + h_ligand_feats_att_mean_rot = ops.mean(self.mlp_h_mean_rot(h_ligand_feats), axis=0, + keep_dims=True) # (1, d) + + d = h_ligand_feats.shape[1] + + # Z coordinates + z_receptor_coors = hetero_graph["receptor_x_iegmn_out"] + z_ligand_coors = hetero_graph["ligand_x_iegmn_out"] + + #### AP 1: compute two point clouds of K_heads points each, then do Kabsch ##### + # Att weights to compute the receptor centroid. + # They query is the average_h_ligand. Keys are each h_receptor_j. 
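+            # att_weights_rot uses each node's transformed features as keys and the
+            # mlp_h_mean_rot-averaged feature of the opposite protein as the single query, so
+            # each of the num_att_heads attention heads yields one soft keypoint: a convex
+            # combination of that protein's coordinates. The ligand and receptor keypoint
+            # clouds are then mean-centred and aligned by the Kabsch step below (SVD of their
+            # 3x3 cross-covariance with a sign correction so t_align is a proper rotation);
+            # the while-loop retries the SVD on a slightly perturbed matrix whenever the
+            # singular values are near zero or degenerate. Together, t_align and b_align map
+            # the ligand keypoints onto the receptor keypoints.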
+ + y_receptor_att_rot, all_y_receptor_att_rot_list = self.att_weights_rot( + h_receptor_feats, h_ligand_feats_att_mean_rot, d, z_receptor_coors, all_y_receptor_att_rot_list + ) + + y_ligand_att_rot, all_y_ligand_att_rot_list = self.att_weights_rot( + h_ligand_feats, h_receptor_feats_att_mean_rot, d, z_ligand_coors, all_y_ligand_att_rot_list + ) + + ## Apply Kabsch algorithm + y_receptor_att_rot_mean = y_receptor_att_rot.mean(axis=0, keep_dims=True) # (1,3) + y_ligand_att_rot_mean = y_ligand_att_rot.mean(axis=0, keep_dims=True) # (1,3) + + a = (y_receptor_att_rot - y_receptor_att_rot_mean).transpose(1, 0) @ ( + y_ligand_att_rot - y_ligand_att_rot_mean) # 3, 3 + + if ops.isnan(a).any(): + raise ValueError("There is Nan in a") + u, s, vt = np.linalg.svd(a.asnumpy()) + u, s, vt = Tensor(u), Tensor(s), Tensor(vt) + + num_it = 0 + while ops.min(s)[0] < 1e-3 or ops.min(ops.abs((s ** 2).view(1, 3) - (s ** 2).view(3, 1) + ops.eye(3)))[ + 0] < 1e-2: + if self.debug: + self.log('S inside loop ', num_it, ' is ', s, ' and A = ', a) + + a = a + ops.rand(3, 3) * ops.eye(3) + u, s, vt = np.linalg.svd(a.asnumpy()) + u, s, vt = Tensor(u), Tensor(s), Tensor(vt) + num_it += 1 + + if num_it > 10: + self.log('SVD consistently numerically unstable! Exitting ... ') + raise ValueError('SVD consistently numerically unstable!') + + corr_mat = ops.diag(Tensor([1, 1, float(ops.sign(ops.det(a)))])) + t_align = (u @ corr_mat) @ vt + + b_align = y_receptor_att_rot_mean - ops.t(t_align @ y_ligand_att_rot_mean.t()) # (1,3) + + #################### end AP 1 ######################### + + if self.debug: + self.log('Y_receptor_att_ROT_mean', y_receptor_att_rot_mean) + self.log('Y_ligand_att_ROT_mean', y_ligand_att_rot_mean) + + all_t_align_list.append(t_align) + all_b_align_list.append(b_align) + + return [all_t_align_list, all_b_align_list, all_y_ligand_att_rot_list, all_y_receptor_att_rot_list] + + def construct( + self, + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ): + ligand_graph_edge_tensor = input_tensor_tuple[2] + receptor_graph_edge_tensor = input_tensor_tuple[3] + + orig_coors_ligand = ligand_graph_node_tensor[:, 4:7] + orig_coors_receptor = receptor_graph_node_tensor[:, 1:4] + + coors_ligand = ligand_graph_node_tensor[:, 4:7] + coors_receptor = receptor_graph_node_tensor[:, 1:4] + + ## Embed residue types with a lookup table. 
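+        # Column 0 of each node tensor is the residue-type index (21 classes) fed to
+        # residue_emb_layer; coordinates sit in columns 4:7 of the ligand node tensor and
+        # 1:4 of the receptor node tensor. When use_mean_node_features is set, the five
+        # mu_r_norm surface features (ligand columns 7:12, receptor columns 4:9) are
+        # log-scaled and concatenated onto the residue embedding.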
+ h_feats_ligand = self.residue_emb_layer( + ligand_graph_node_tensor[:, 0].view(-1).long()) # (N_res, emb_dim) + h_feats_receptor = self.residue_emb_layer( + receptor_graph_node_tensor[:, 0].view(-1).long()) # (N_res, emb_dim) + + if self.use_mean_node_features: + h_feats_ligand = ops.cat([h_feats_ligand, + ops.log(ligand_graph_node_tensor[:, 7:12])], axis=1) + h_feats_receptor = ops.cat([h_feats_receptor, + ops.log(receptor_graph_node_tensor[:, 4:9])], axis=1) + + original_ligand_node_features = h_feats_ligand + original_receptor_node_features = h_feats_receptor + + original_edge_feats_ligand = ligand_graph_edge_tensor * self.use_edge_features_in_gmn + original_edge_feats_receptor = receptor_graph_edge_tensor * self.use_edge_features_in_gmn + + for i, layer in enumerate(self.iegmn_layers): + if self.debug: + self.log('layer ', i) + + iegmn_layer_tuple = ( + coors_ligand, h_feats_ligand, original_ligand_node_features, original_edge_feats_ligand, + orig_coors_ligand, coors_receptor, h_feats_receptor, original_receptor_node_features, + original_edge_feats_receptor, orig_coors_receptor + ) + + coors_ligand, h_feats_ligand, coors_receptor, h_feats_receptor = \ + layer(iegmn_layer_tuple=iegmn_layer_tuple, input_tensor_tuple=input_tensor_tuple) + + if self.debug: + self.log(ops.max(h_feats_ligand)[0], 'h_feats_ligand before layers ') + self.log(ops.max(h_feats_ligand)[0], ops.norm(h_feats_ligand), + 'h_feats_ligand before layers but after mu_r_norm') + self.log(ops.max(h_feats_ligand)[0], 'h_feats_ligand after MPNN') + self.log(ops.max(coors_ligand)[0], 'coors_ligand before after MPNN') + + list_hetero_graph = unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor, + h_feats_ligand, coors_receptor, coors_ligand) + + ap_res = self.ap_compute(list_hetero_graph) + + return ap_res + + +class RigidBodyDockingNet(nn.Cell): + + def __init__(self, args, log=None): + + super(RigidBodyDockingNet, self).__init__() + + self.debug = args.debug + self.log = log + + self.iegmn_original = IEGMN(args, n_lays=args.iegmn_n_lays, fine_tune=False, log=log) + if args.fine_tune: + self.iegmn_fine_tune = IEGMN(args, n_lays=2, fine_tune=True, log=log) + self.list_iegmns = [('original', self.iegmn_original), ('finetune', self.iegmn_fine_tune)] + else: + self.list_iegmns = [('finetune', self.iegmn_original)] # just original + + def __repr__(self): + return "RigidBodyDockingNet " + str(self.__dict__) + + ####### FORWARD for RigidBodyDockingNet + def construct( + self, + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ): + last_outputs = None + all_ligand_coors_deform_list = [] + + for stage, iegmn in self.list_iegmns: + outputs = iegmn( + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ) + + if len(outputs) != 4: + raise ValueError("Number of outputs not correct") + + if stage == 'finetune': + last_outputs = outputs + + if stage == 'original': + new_list_hetero_graph = [] + + list_hetero_graph = [] + for i, _ in enumerate(unbatch_list_tensor): + if i < len(unbatch_list_tensor) - 1: + list_hetero_graph.append( + [ + ligand_graph_node_tensor[unbatch_list_tensor[i][2]:unbatch_list_tensor[i + 1][2], :], + receptor_graph_node_tensor[unbatch_list_tensor[i][3]:unbatch_list_tensor[i + 1][3], :], + ] + ) + + for the_idx, hetero_graph in enumerate(list_hetero_graph): + orig_coors_ligand = hetero_graph[0][:, 4:7] + + t_align = outputs[0][the_idx] + b_align = outputs[1][the_idx] + if b_align.shape[0] != 1 or 
b_align.shape[1] != 3: + raise ValueError("Shape not correct") + + inner_coors_ligand = (t_align @ orig_coors_ligand.t()).t() + b_align # (n,3) + + if stage == 'original': + hetero_graph[0][:, 4:7] = inner_coors_ligand + new_list_hetero_graph.append(hetero_graph) + + if self.debug: + self.log('T_align', t_align) + self.log('T_align @ T_align.t() - eye(3)', t_align @ t_align.t() - ops.eye(3)) + self.log('b_align', b_align) + self.log('\n ---> inner_coors_ligand mean - true ligand mean ', + inner_coors_ligand.mean(axis=0)[0] - ligand_graph_node_tensor[:, 1:4].mean(axis=0)[0], + '\n') + + if stage == 'finetune': + all_ligand_coors_deform_list.append(inner_coors_ligand) + + all_keypts_ligand_list = last_outputs[2] + all_keypts_receptor_list = last_outputs[3] + all_rotation_list = last_outputs[0] + all_translation_list = last_outputs[1] + + return all_ligand_coors_deform_list, \ + all_keypts_ligand_list, all_keypts_receptor_list, \ + all_rotation_list, all_translation_list + diff --git a/MindSPONGE/src/mindsponge/pipeline/pipeline.py b/MindSPONGE/src/mindsponge/pipeline/pipeline.py index 8efdf2752..1de798558 100644 --- a/MindSPONGE/src/mindsponge/pipeline/pipeline.py +++ b/MindSPONGE/src/mindsponge/pipeline/pipeline.py @@ -35,6 +35,7 @@ from .models import RASP, RASPDataSet, rasp_configuration from .models import Multimer, MultimerDataSet, multimer_configuration from .models import ProteinMpnn, ProteinMpnnDataset, proteinmpnn_configuration from .models import UFold, UFoldDataSet, ufold_configuration +from .models import EquiDock, EquiDockDataSet, equidock_configuration model_card = { @@ -53,6 +54,7 @@ model_card = { "Proteinmpnn": {"model": ProteinMpnn, "dataset": ProteinMpnnDataset, "config": proteinmpnn_configuration}, "RASP": {"model": RASP, "dataset": RASPDataSet, "config": rasp_configuration}, "UFold": {"model": UFold, "dataset": UFoldDataSet, "config": ufold_configuration}, + "EquiDock": {"model": EquiDock, "dataset": EquiDockDataSet, "config": equidock_configuration}, } -- Gitee From 46eb56a4ddde8cd0bc1b480ef0a31a43858a64ae Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Tue, 13 Aug 2024 20:37:19 +0800 Subject: [PATCH 02/16] modify codes --- .../pipeline/models/equidock/equidock.py | 146 ++--- .../models/equidock/equidock_configuration.py | 4 +- .../pipeline/models/equidock/equidock_data.py | 33 +- .../models/equidock/equidock_dataset.py | 217 ++++++- .../pipeline/models/equidock/nn_arch.py | 572 +++++------------- .../pipeline/models/equidock/train_utils.py | 246 ++++++++ 6 files changed, 719 insertions(+), 499 deletions(-) create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py index 9089dd096..2a74e2956 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py @@ -1,3 +1,6 @@ +""" +equidock +""" # Copyright 2024 @ Shenzhen Bay Laboratory & # Peking University & # Huawei Technologies Co., Ltd @@ -27,20 +30,23 @@ from datetime import datetime as dt import numpy as np from biopandas.pdb import PandasPdb import mindspore as ms -from mindspore import jit, context, nn, Tensor, ops, save_checkpoint +from mindspore import nn, Tensor, ops, save_checkpoint from mindspore.experimental import optim -from .nn_arch import * -from .equidock_data import UnboundBoundData +from .train_utils import * +from .nn_arch import MeterUnboundBound, 
RigidBodyDockingNet, log, FLAGS +from .equidock_dataset import EquiDockDataSet from ..model import Model class EquiDock(Model): - "EquiDock" + """ + EquiDock class + """ name = "EquiDock" def __init__(self, config): - + self.config = config self.mixed_precision = False self.network = RigidBodyDockingNet(self.config) @@ -65,31 +71,32 @@ class EquiDock(Model): log_file_name = os.path.join(self.config.log_files_dir, banner + ".txt") with os.fdopen(os.open(log_file_name, FLAGS, 777), 'a+') as fout: fout.write('[' + str(dt.now()) + '] START\n') - self.config.checkpoint_filename = os.path.join(self.config.checkpoint_dir, self.config.data + '_model_best.pth') + self.config.checkpoint_filename = os.path.join( + self.config.checkpoint_dir, + self.config.data + '_model_best.pth' + ) self.loss_fn_coors = nn.MSELoss(reduction='mean') - self.optimizer = optim.Adam(self.network.trainable_params(), lr=self.config.lr, weight_decay=self.config.w_decay) - self.scheduler = optim.lr_scheduler.LambdaLR(self.optimizer, - lr_lambda=[lambda epoch: min(1., ((epoch + 1) / self.config.warmup) ** 3)]) + self.optimizer = optim.Adam( + self.network.trainable_params(), + lr=self.config.lr, + weight_decay=self.config.w_decay + ) + self.scheduler = optim.lr_scheduler.LambdaLR( + self.optimizer, + lr_lambda=[lambda epoch: min(1., ((epoch + 1) / self.config.warmup) ** 3)] + ) self.best_epoch = 0 self.best_val_rmsd_median = float('inf') self.corr_val_rmsd_mean = float('inf') - if not os.path.exists(self.config.processed_dataset_path): - UnboundBoundData(self.config, reload_mode='val', raw_data_path=self.config.raw_data_path, - split_files_path=self.config.split_files_path, data_fraction=self.config.data_fraction) - UnboundBoundData(self.config, reload_mode='test', raw_data_path=self.config.raw_data_path, - split_files_path=self.config.split_files_path, data_fraction=self.config.data_fraction) - UnboundBoundData(self.config, reload_mode='train', raw_data_path=self.config.raw_data_path, - split_files_path=self.config.split_files_path, data_fraction=self.config.data_fraction) - - self.train_data_batched, self.train_loader = create_dataloader(self.config.train_dir, bs=self.config.bs, shuffle=True) - self.val_data_batched, self.val_loader = create_dataloader(self.config.val_dir, bs=self.config.bs, shuffle=False) - self.test_data_batched, self.test_loader = create_dataloader(self.config.test_dir, bs=self.config.bs, shuffle=False) + dataset_util = EquiDockDataSet(self.config) + self.train_data_batched, self.train_loader, self.val_data_batched, self.val_loader, \ + self.test_data_batched, self.test_loader = dataset_util.set_training_data_src("Train Mode") - super().__init__(checkpoint_url=self.checkpoint_url, checkpoint_path=self.checkpoint_path, - network=self.network, mixed_precision=self.mixed_precision) + super().__init__(checkpoint_url=self.checkpoint_url, checkpoint_path=self.checkpoint_path, + network=self.network, mixed_precision=self.mixed_precision) def from_pretrained(self, ckpt_path=None): @@ -111,15 +118,12 @@ class EquiDock(Model): file_type = '.pdb' l_b_str = '_l_b' - files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f)) and f.endswith(file_type)] - - for file in files: + for file in data: if not file.endswith(l_b_str + file_type): continue ll = len(l_b_str + file_type) ligand_filename = os.path.join(input_dir, file[:-ll] + l_b_str + file_type) receptor_filename = os.path.join(ground_truth_dir, file[:-ll] + '_r_b' + '_COMPLEX' + file_type) - gt_ligand_filename = 
os.path.join(ground_truth_dir, file[:-ll] + l_b_str + '_COMPLEX' + file_type) out_filename = file[:-ll] + l_b_str + '_' + "EQUIDOCK" + file_type self.log(' inference on file = ', ligand_filename) @@ -137,19 +141,16 @@ class EquiDock(Model): ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, \ input_tensor_tuple = graph_to_tensor(ligand_graph, receptor_graph) - model_ligand_coors_deform_list, model_keypts_ligand_list, model_keypts_receptor_list, \ - all_rotation_list, all_translation_list = self.network( + _, _, _, all_rotation_list, all_translation_list = self.network( ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, input_tensor_tuple, - ) + ) rotation = all_rotation_list[0].asnumpy() translation = all_translation_list[0].asnumpy() - new_residues = (rotation @ bound_ligand_repres_nodes_loc_clean_array.T).T + translation - unbound_ligand_new_pos = (rotation @ unbound_ligand_all_atoms_pre_pos.T).T + translation euler_angles_finetune = ops.zeros([3]) @@ -177,12 +178,12 @@ class EquiDock(Model): self.run_a_generic_epoch('train', run_epoch_tuple, data_batched, data_loader) pretty_print_stats('TRAIN', epoch, args.num_epochs, - train_complex_rmsd_mean, train_complex_rmsd_median, - train_ligand_rmsd_mean, train_ligand_rmsd_median, - train_receptor_rmsd_mean, train_receptor_rmsd_median, - train_avg_loss, train_avg_loss_ligand_coors, train_avg_loss_receptor_coors, - train_avg_loss_ot, train_avg_loss_intersection, - self.log) + train_complex_rmsd_mean, train_complex_rmsd_median, + train_ligand_rmsd_mean, train_ligand_rmsd_median, + train_receptor_rmsd_mean, train_receptor_rmsd_median, + train_avg_loss, train_avg_loss_ligand_coors, train_avg_loss_receptor_coors, + train_avg_loss_ot, train_avg_loss_intersection, + self.log) def run_an_eval_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args): @@ -195,34 +196,18 @@ class EquiDock(Model): self.run_a_generic_epoch('eval', run_epoch_tuple, data_batched, data_loader) pretty_print_stats('VALIDATION', epoch, args.num_epochs, - val_complex_rmsd_mean, val_complex_rmsd_median, - val_ligand_rmsd_mean, val_ligand_rmsd_median, - val_receptor_rmsd_mean, val_receptor_rmsd_median, - val_avg_loss, val_avg_loss_ligand_coors, val_avg_loss_receptor_coors, - val_avg_loss_ot, val_avg_loss_intersection, - self.log) - return val_complex_rmsd_mean, val_complex_rmsd_median, val_avg_loss - - - def run_a_test_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args): - test_complex_rmsd_mean, test_complex_rmsd_median, \ - test_ligand_rmsd_mean, test_ligand_rmsd_median, \ - test_receptor_rmsd_mean, test_receptor_rmsd_median, \ - test_avg_loss, test_avg_loss_ligand_coors, \ - test_avg_loss_receptor_coors, \ - test_avg_loss_ot, test_avg_loss_intersection = \ - self.run_a_generic_epoch('eval', run_epoch_tuple, data_batched, data_loader) + val_complex_rmsd_mean, val_complex_rmsd_median, + val_ligand_rmsd_mean, val_ligand_rmsd_median, + val_receptor_rmsd_mean, val_receptor_rmsd_median, + val_avg_loss, val_avg_loss_ligand_coors, val_avg_loss_receptor_coors, + val_avg_loss_ot, val_avg_loss_intersection, + self.log) - pretty_print_stats('FINAL TEST for ' + args.data, -1, args.num_epochs, - test_complex_rmsd_mean, test_complex_rmsd_median, - test_ligand_rmsd_mean, test_ligand_rmsd_median, - test_receptor_rmsd_mean, test_receptor_rmsd_median, - test_avg_loss, test_avg_loss_ligand_coors, test_avg_loss_receptor_coors, - test_avg_loss_ot, test_avg_loss_intersection, self.log) + return val_complex_rmsd_mean, 
val_complex_rmsd_median, val_avg_loss def run_a_generic_epoch(self, ep_type, run_epoch_tuple, data_batched, data_loader): - args, epoch, self.network, loss_fn_coors, optimizer = run_epoch_tuple + args, self.network, loss_fn_coors, optimizer = run_epoch_tuple meter = MeterUnboundBound() @@ -258,11 +243,11 @@ class EquiDock(Model): model_ligand_coors_deform_list, \ model_keypts_ligand_list, model_keypts_receptor_list, \ _, _, = self.network( - ligand_graph_node_tensor, - receptor_graph_node_tensor, - unbatch_list_tensor, - input_tensor_tuple, - ) + ligand_graph_node_tensor, + receptor_graph_node_tensor, + unbatch_list_tensor, + input_tensor_tuple, + ) ################################ batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss, batch_intersection_loss = Tensor(0.0), \ @@ -272,8 +257,8 @@ class EquiDock(Model): ## Compute average MSE loss (which is 3 times smaller than average squared RMSD) batch_ligand_coors_loss = batch_ligand_coors_loss + loss_fn_coors(model_ligand_coors_deform_list[i], - Tensor(data_batched[8][batch_id][i], - ms.float32)) + Tensor(data_batched[8][batch_id][i], + ms.float32)) # Compute the OT loss for the binding pocket: ligand_pocket_coors = Tensor(data_batched[10][batch_id][i], ms.float32) # (N, 3), N = num pocket nodes receptor_pocket_coors = Tensor(data_batched[11][batch_id][i], ms.float32) # (N, 3), N = num pocket nodes @@ -295,9 +280,9 @@ class EquiDock(Model): ### Add new stats to the meter if ep_type != 'train' or random.random() < 0.1: meter.update_rmsd(model_ligand_coors_deform_list[i], - data_batched[9][batch_id][i], - data_batched[8][batch_id][i], - data_batched[9][batch_id][i]) + data_batched[9][batch_id][i], + data_batched[8][batch_id][i], + data_batched[9][batch_id][i]) batch_ligand_coors_loss = batch_ligand_coors_loss / float(len(model_ligand_coors_deform_list)) batch_receptor_coors_loss = batch_receptor_coors_loss / float(len(model_ligand_coors_deform_list)) @@ -318,7 +303,7 @@ class EquiDock(Model): num_batches += 1 (loss, batch_ligand_coors_loss, batch_receptor_coors_loss, batch_ot_loss, - batch_intersection_loss), grads = backward(step) + batch_intersection_loss), grads = backward(step) if ep_type == 'train': grads = ms.ops.clip_by_norm(grads, max_norm=args.clip, norm_type=2) @@ -348,15 +333,16 @@ class EquiDock(Model): avg_loss_ot.item(), avg_loss_intersection.item() - def train_step(self, epoch): + def train_step(self, data): self.log('+' * 100) + epoch = data epoch_start = dt.now() - run_epoch_tuple = (self.config, epoch, self.network, self.loss_fn_coors, self.optimizer) + run_epoch_tuple = (self.config, self.network, self.loss_fn_coors, self.optimizer) self.run_a_train_epoch(run_epoch_tuple, self.train_data_batched, self.train_loader, epoch, self.config) - val_complex_rmsd_mean, val_complex_rmsd_median, val_avg_loss = \ - self.run_an_eval_epoch(run_epoch_tuple, self.val_data_batched, self.val_loader, epoch, self.config) + val_complex_rmsd_mean, val_complex_rmsd_median, _ = \ + self.run_an_eval_epoch(run_epoch_tuple, self.val_data_batched, self.val_loader, epoch, self.config) self.scheduler.step() @@ -366,14 +352,14 @@ class EquiDock(Model): self.corr_val_rmsd_mean = val_complex_rmsd_mean self.best_epoch = epoch save_checkpoint(self.optimizer, self.config.checkpoint_dir + "/best_model.ckpt") - + log('[BEST SO FAR] ', self.config.data, '|| At best epoch {} we have: best_val_rmsd_median {:.4f}, ' '|| Current val rmsd_median {:.4f} ' '|| Train time: {}\n'. 
format(self.best_epoch, self.best_val_rmsd_median, - val_complex_rmsd_median, - dt.now() - epoch_start)) + val_complex_rmsd_median, + dt.now() - epoch_start)) return '\n' @@ -395,4 +381,4 @@ class EquiDock(Model): def _pynative_forward(self, data): - return None \ No newline at end of file + return None diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py index 07e21e270..e1fb7201d 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_configuration.py @@ -28,5 +28,5 @@ equidock_configuration = { "predict_db5": "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/predict_db5.yaml", "train_db5": - "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml", -} \ No newline at end of file + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml", +} diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py index 5adf98836..a71b96d0c 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py @@ -1,3 +1,6 @@ +""" +equidock_data +""" # Copyright 2024 @ Shenzhen Bay Laboratory & # Peking University & # Huawei Technologies Co., Ltd @@ -26,7 +29,8 @@ from biopandas.pdb import PandasPdb from joblib import Parallel, delayed, cpu_count from scipy.spatial.transform import Rotation -from .nn_arch import * +from .nn_arch import preprocess_unbound_bound, protein_to_graph_unbound_bound +from .train_utils import log def one_hot_encoding(x, allowable_set, encode_unknown=False): @@ -79,14 +83,16 @@ def pmap_multi2(pickleable_fn, data, n_jobs=None, verbose=1, **kwargs): class UnboundBoundData(): - - def __init__(self, args, reload_mode='train', raw_data_path=None, split_files_path=None, data_fraction=1.): + """ + UnboundBoundData Class + """ + def __init__(self, args, reload_mode='train', raw_data_path=None, split_files_path=None): os.makedirs(args.processed_dataset_path + reload_mode, exist_ok=True) if args.data == 'db5': onlyfiles = [f for f in os.listdir(raw_data_path) if os.path.isfile(os.path.join(raw_data_path, f))] - code_set = set([file.split('_')[0] for file in onlyfiles]) + code_set = {file.split('_')[0] for file in onlyfiles} split_code_set = set() with open(os.path.join(split_files_path, reload_mode + '.txt'), 'r') as f: for line in f.readlines(): @@ -143,6 +149,25 @@ class UnboundBoundData(): log('Done protein_to_graph_unbound_bound') + self.save_processed_data( + args, + reload_mode, + both_proteins_to_graph_pair_list, + bound_ligand_repres_nodes_loc_array_list, + bound_receptor_repres_nodes_loc_array_list, + pocket_coors_list, + ) + + def save_processed_data( + self, + args, + reload_mode, + both_proteins_to_graph_pair_list, + bound_ligand_repres_nodes_loc_array_list, + bound_receptor_repres_nodes_loc_array_list, + pocket_coors_list, + ): + ligand_graph_list, receptor_graph_list = [], [] for result in both_proteins_to_graph_pair_list: ligand_graph, receptor_graph = result diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py index 
cedb95442..82bd2e1ee 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py @@ -1,3 +1,6 @@ +"""" +equidock_dataset +""" # Copyright 2024 @ Shenzhen Bay Laboratory & # Peking University & # Huawei Technologies Co., Ltd @@ -21,18 +24,222 @@ # limitations under the License. # ============================================================================ import os -from ...dataset import PSP, data_process_run +from datetime import datetime as dt + +import numpy as np + +from .equidock_data import UnboundBoundData +from ...dataset import PSP class EquiDockDataSet(PSP): "EquiDockDataSet" - def __init__(self, config, seed=0): + def __init__(self, config): self.config = config + self.train_data_batched = None + self.train_loader = None + self.val_data_batched = None + self.val_loader = None + self.test_data_batched = None + self.test_loader = None + super().__init__() def process(self, data, **kwargs): - return - + self.log(data, **kwargs) + input_dir = self.config.input_dir + files = [f for f in os.listdir(input_dir) \ + if os.path.isfile(os.path.join(input_dir, f)) and f.endswith('.pdb')] + return files + def set_training_data_src(self, data_source, **kwargs): - return + self.log(data_source, **kwargs) + + if not os.path.exists(self.config.processed_dataset_path): + UnboundBoundData( + self.config, + reload_mode='val', + raw_data_path=self.config.raw_data_path, + split_files_path=self.config.split_files_path, + ) + UnboundBoundData( + self.config, + reload_mode='test', + raw_data_path=self.config.raw_data_path, + split_files_path=self.config.split_files_path, + ) + UnboundBoundData( + self.config, + reload_mode='train', + raw_data_path=self.config.raw_data_path, + split_files_path=self.config.split_files_path, + ) + + self.train_data_batched, self.train_loader = self.create_dataloader( + self.config.train_dir, + bs=self.config.bs, + shuffle=True, + ) + self.val_data_batched, self.val_loader = self.create_dataloader( + self.config.val_dir, + bs=self.config.bs, + shuffle=False, + ) + self.test_data_batched, self.test_loader = self.create_dataloader( + self.config.test_dir, + bs=self.config.bs, + shuffle=False, + ) + + return ( + self.train_data_batched, + self.train_loader, + self.val_data_batched, + self.val_loader, + self.test_data_batched, + self.test_loader, + ) def create_iterator(self, num_epochs, **kwargs): + self.log(**kwargs) + if num_epochs == self.config.num_epochs: + return [_ for _ in range(num_epochs)] return [_ for _ in range(self.config.num_epochs)] + + def data_parse(self, idx): + return self.train_data_batched[idx] + + def log(self, *pargs): + print('[' + str(dt.now()) + '] ', *pargs) + + def load_data(self, files_dir): + """ + load_data + """ + npz_files_list = [ + np.load(files_dir + "/ll_connection_tensor_list.npz"), + np.load(files_dir + "/rr_connection_tensor_list.npz"), + np.load(files_dir + "/ligand_graph_edge_tensor_list.npz"), + np.load(files_dir + "/receptor_graph_edge_tensor_list.npz"), + np.load(files_dir + "/ligand_graph_num_nodes_list.npz"), + np.load(files_dir + "/receptor_graph_num_nodes_list.npz"), + np.load(files_dir + "/ligand_graph_node_tensor_list.npz"), + np.load(files_dir + "/receptor_graph_node_tensor_list.npz"), + np.load(files_dir + "/bound_ligand_repres_nodes_loc_array_list_new.npz"), + np.load(files_dir + "/bound_receptor_repres_nodes_loc_array_list_new.npz"), + np.load(files_dir + "/pocket_coors_ligand_list_new.npz"), + np.load(files_dir + 
"/pocket_coors_receptor_list_new.npz"), + ] + + extracted_files = [] + for _, npz_file in enumerate(npz_files_list): + extracted_files.append([npz_file[id] for id in npz_file.files]) + + unbatch_list = [] + for i, _ in enumerate(extracted_files[0]): + temp_length = [] + temp_length.append(len(extracted_files[0][i][0])) + temp_length.append(len(extracted_files[1][i][0])) + temp_length.append(len(extracted_files[6][i])) + temp_length.append(len(extracted_files[7][i])) + temp_length.append(len(extracted_files[11][i])) + unbatch_list.append(np.array(temp_length)) + extracted_files.append(unbatch_list) + + index_shuffle = np.arange(len(npz_files_list[0])) + + return extracted_files, index_shuffle + + + def shuffle_list(self, source_list, index_shuffle): + shuffled_list = [] + for _, idx in enumerate(index_shuffle): + shuffled_list.append(source_list[idx]) + + return shuffled_list + + + def batch(self, bs, souce_list): + output_list = [] + num = len(souce_list) // bs + for i in range(num): + output_list.append(souce_list[i * bs: (i + 1) * bs]) + if num * bs < len(souce_list): + output_list.append(souce_list[num * bs:]) + + return output_list + + + def shuffle_batch_dataset(self, input_dataset, index_shuffle, batch_size, shuffle): + """ + shuffle_batch_dataset + """ + if shuffle: + shuffled_dataset = [] + np.random.shuffle(index_shuffle) + for i, _ in enumerate(input_dataset): + shuffled_dataset.append(self.shuffle_list(input_dataset[i], index_shuffle)) + else: + shuffled_dataset = input_dataset[:] + + dataset_batched = [] + for _, data in enumerate(shuffled_dataset): + dataset_batched.append(self.batch(batch_size, data)) + unbatch_list = dataset_batched[-1] + for j, _ in enumerate(unbatch_list): + unbatch_list[j].insert(0, np.array([0, 0, 0, 0, 0])) + for i, _ in enumerate(unbatch_list[j]): + if i >= 1: + unbatch_list[j][i][0] += unbatch_list[j][i - 1][0] + unbatch_list[j][i][1] += unbatch_list[j][i - 1][1] + unbatch_list[j][i][2] += unbatch_list[j][i - 1][2] + unbatch_list[j][i][3] += unbatch_list[j][i - 1][3] + unbatch_list[j][i][4] += unbatch_list[j][i - 1][4] + dataset_batched[-1] = unbatch_list + + return dataset_batched + + + def cat_properties(self, input_dataset_batched): + """ + cat_properties + """ + dataset_batch_cat = [] + for i in range(len(input_dataset_batched[0])): + dataset_batch_cat.append( + [ + np.concatenate(input_dataset_batched[0][i], axis=1), + np.concatenate(input_dataset_batched[1][i], axis=1), + np.concatenate(input_dataset_batched[2][i], axis=0), + np.concatenate(input_dataset_batched[3][i], axis=0), + np.concatenate(input_dataset_batched[4][i], axis=0), + np.concatenate(input_dataset_batched[5][i], axis=0), + np.concatenate(input_dataset_batched[6][i], axis=0), + np.concatenate(input_dataset_batched[7][i], axis=0), + input_dataset_batched[-1][i], + ] + ) + + return dataset_batch_cat + + + def create_dataloader(self, dataset_dir, bs, shuffle): + """ + create_dataloader + """ + dataset, index_shuffle = self.load_data(dataset_dir) + dataset_batched = self.shuffle_batch_dataset( + input_dataset=dataset, + index_shuffle=index_shuffle, + batch_size=bs, + shuffle=shuffle, + ) + dataloader = self.cat_properties(dataset_batched) + + return dataset_batched, dataloader + + + def __getitem__(self, idx): + return self.data_parse(idx) + + + def __len__(self): + return len(self.train_data_batched) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py index 36516f725..37c5ae4b5 100644 --- 
a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py @@ -1,3 +1,6 @@ +""" +nn_arch +""" # Copyright 2024 @ Shenzhen Bay Laboratory & # Peking University & # Huawei Technologies Co., Ltd @@ -24,27 +27,42 @@ import os import math from datetime import datetime as dt -import ot import numpy as np from numpy import linalg as LA from biopandas.pdb import PandasPdb import scipy.spatial as spa from scipy.special import softmax import mindspore as ms -from mindspore import nn, ops, Tensor, Parameter +from mindspore import nn, ops, Tensor ATOM_NAME = 'atom_name' FLAGS = os.O_RDWR | os.O_CREAT -class MeterUnboundBound(object): +def log(*pargs): + banner = 'EQUIDOCK' + log_file_name = os.path.join('stdouterr/', banner + ".txt") + with os.fdopen(os.open(log_file_name, FLAGS, 777), 'a+') as w: + w.write('[' + str(dt.now()) + '] ') + w.write(" ".join(["{}".format(t) for t in pargs])) + w.write("\n") + print('[' + str(dt.now()) + '] ', *pargs) + + +class MeterUnboundBound(): + """ + MeterUnboundBound + """ def __init__(self): self.complex_rmsd_list = [] self.ligand_rmsd_list = [] self.receptor_rmsd_list = [] def update_rmsd(self, ligand_coors_pred, receptor_coors_pred, ligand_coors_true, receptor_coors_true): + """ + update_rmsd + """ ligand_coors_pred = ligand_coors_pred.asnumpy() receptor_coors_pred = receptor_coors_pred @@ -69,6 +87,9 @@ class MeterUnboundBound(object): return complex_rmsd def summarize(self, reduction_rmsd='median'): + """ + summarize + """ if reduction_rmsd == 'mean': complex_rmsd_array = np.array(self.complex_rmsd_list) complex_rmsd_summarized = np.mean(complex_rmsd_array) @@ -104,213 +125,6 @@ class MeterUnboundBound(object): return complex_rmsd_summarized, np.std(complex_rmsd_array) -def create_dir(path): - if os.path.exists(path): - raise FileExistsError('Path already exists. 
Please delete and restart your job.') - else: - os.makedirs(path, exist_ok=False) - - -def log(*pargs): - banner = 'EQUIDOCK' - log_file_name = os.path.join('stdouterr/', banner + ".txt") - with os.fdopen(os.open(log_file_name, FLAGS, 777), 'a+') as w: - w.write('[' + str(dt.now()) + '] ') - w.write(" ".join(["{}".format(t) for t in pargs])) - w.write("\n") - print('[' + str(dt.now()) + '] ', *pargs) - - -def g_fn(protein_coords, x, sigma): - # protein_coords: (n,3) , x: (m,3), output: (m,) - e = ops.exp(- ops.sum((protein_coords.view(1, -1, 3) - x.view(-1, 1, 3)) ** 2, dim=2) / float(sigma)) # (m, n) - - return - sigma * ops.log(1e-3 + ops.sum(e, dim=1)) - - -def compute_body_intersection_loss( - model_ligand_coors_deform, - bound_receptor_repres_nodes_loc_array, - sigma, - surface_ct, -): - loss = ops.mean( - ops.clamp(surface_ct - g_fn(Tensor(bound_receptor_repres_nodes_loc_array), model_ligand_coors_deform, sigma), - min=0)) + \ - ops.mean(ops.clamp( - surface_ct - g_fn(model_ligand_coors_deform, Tensor(bound_receptor_repres_nodes_loc_array), sigma), - min=0)) - - return loss - - -def pretty_print_stats(*args): - - split_type, epoch, total_num_epochs,\ - complex_rmsd_mean, complex_rmsd_median,\ - ligand_rmsd_mean, ligand_rmsd_median,\ - receptor_rmsd_mean, receptor_rmsd_median,\ - avg_loss, avg_loss_ligand_coors,\ - avg_loss_receptor_coors, avg_loss_ot, avg_loss_intersection, args_input = args - - log('[{:s}] --> epoch {:d}/{:d} ' - '|| mean/median complex rmsd {:.4f} / {:.4f} ' - '|| mean/median ligand rmsd {:.4f} / {:.4f} ' - '|| mean/median sqrt pocket OT loss {:.4f} ' - '|| intersection loss {:.4f} ' - '|| mean/median receptor rmsd {:.4f} / {:.4f} '. - format(split_type, - epoch, total_num_epochs, - complex_rmsd_mean, complex_rmsd_median, - ligand_rmsd_mean, ligand_rmsd_median, - math.sqrt(avg_loss_ot), - avg_loss_intersection, - receptor_rmsd_mean, receptor_rmsd_median)) - - -def load_data(files_dir): - - npz_files_list = [ - np.load(files_dir + "/ll_connection_tensor_list.npz"), - np.load(files_dir + "/rr_connection_tensor_list.npz"), - np.load(files_dir + "/ligand_graph_edge_tensor_list.npz"), - np.load(files_dir + "/receptor_graph_edge_tensor_list.npz"), - np.load(files_dir + "/ligand_graph_num_nodes_list.npz"), - np.load(files_dir + "/receptor_graph_num_nodes_list.npz"), - np.load(files_dir + "/ligand_graph_node_tensor_list.npz"), - np.load(files_dir + "/receptor_graph_node_tensor_list.npz"), - np.load(files_dir + "/bound_ligand_repres_nodes_loc_array_list_new.npz"), - np.load(files_dir + "/bound_receptor_repres_nodes_loc_array_list_new.npz"), - np.load(files_dir + "/pocket_coors_ligand_list_new.npz"), - np.load(files_dir + "/pocket_coors_receptor_list_new.npz"), - ] - - extracted_files = [] - for _, npz_file in enumerate(npz_files_list): - extracted_files.append([npz_file[id] for id in npz_file.files]) - - unbatch_list = [] - for i, _ in enumerate(extracted_files[0]): - temp_length = [] - temp_length.append(len(extracted_files[0][i][0])) - temp_length.append(len(extracted_files[1][i][0])) - temp_length.append(len(extracted_files[6][i])) - temp_length.append(len(extracted_files[7][i])) - temp_length.append(len(extracted_files[11][i])) - unbatch_list.append(np.array(temp_length)) - extracted_files.append(unbatch_list) - - index_shuffle = np.arange(len(npz_files_list[0])) - - return extracted_files, index_shuffle - - -def shuffle_list(source_list, index_shuffle): - shuffled_list = [] - for _, idx in enumerate(index_shuffle): - shuffled_list.append(source_list[idx]) - - return 
shuffled_list - - -def batch(bs, souce_list): - output_list = [] - num = len(souce_list) // bs - for i in range(num): - output_list.append(souce_list[i * bs: (i + 1) * bs]) - if num * bs < len(souce_list): - output_list.append(souce_list[num * bs:]) - - return output_list - - -def shuffle_batch_dataset(input_dataset, index_shuffle, batch_size, shuffle): - if shuffle: - shuffled_dataset = [] - np.random.shuffle(index_shuffle) - for i, _ in enumerate(input_dataset): - shuffled_dataset.append(shuffle_list(input_dataset[i], index_shuffle)) - else: - shuffled_dataset = input_dataset[:] - - dataset_batched = [] - for _, data in enumerate(shuffled_dataset): - dataset_batched.append(batch(batch_size, data)) - unbatch_list = dataset_batched[-1] - for j, _ in enumerate(unbatch_list): - unbatch_list[j].insert(0, np.array([0, 0, 0, 0, 0])) - for i, _ in enumerate(unbatch_list[j]): - if i >= 1: - unbatch_list[j][i][0] += unbatch_list[j][i - 1][0] - unbatch_list[j][i][1] += unbatch_list[j][i - 1][1] - unbatch_list[j][i][2] += unbatch_list[j][i - 1][2] - unbatch_list[j][i][3] += unbatch_list[j][i - 1][3] - unbatch_list[j][i][4] += unbatch_list[j][i - 1][4] - dataset_batched[-1] = unbatch_list - - return dataset_batched - - -def cat_properties(input_dataset_batched): - dataset_batch_cat = [] - for i in range(len(input_dataset_batched[0])): - dataset_batch_cat.append( - [ - np.concatenate(input_dataset_batched[0][i], axis=1), - np.concatenate(input_dataset_batched[1][i], axis=1), - np.concatenate(input_dataset_batched[2][i], axis=0), - np.concatenate(input_dataset_batched[3][i], axis=0), - np.concatenate(input_dataset_batched[4][i], axis=0), - np.concatenate(input_dataset_batched[5][i], axis=0), - np.concatenate(input_dataset_batched[6][i], axis=0), - np.concatenate(input_dataset_batched[7][i], axis=0), - input_dataset_batched[-1][i], - ] - ) - - return dataset_batch_cat - - -def create_dataloader(dataset_dir, bs, shuffle): - dataset, index_shuffle = load_data(dataset_dir) - dataset_batched = shuffle_batch_dataset( - input_dataset=dataset, - index_shuffle=index_shuffle, - batch_size=bs, - shuffle=shuffle, - ) - dataloader = cat_properties(dataset_batched) - - return dataset_batched, dataloader - - -def compute_sq_dist_mat(x_1, x_2): - '''Computes the l2 squared cost matrix between two point cloud inputs. 
- Args: - X_1: [n, #features] point cloud, tensor - X_2: [m, #features] point cloud, tensor - Output: - [n, m] matrix of the l2 distance between point pairs - ''' - n_1, _ = x_1.shape - n_2, _ = x_2.shape - x_1 = x_1.view(n_1, 1, -1) - x_2 = x_2.view(1, n_2, -1) - squared_dist = (x_1 - x_2) ** 2 - cost_mat = ops.sum(squared_dist, dim=2) - return cost_mat - - -def compute_ot_emd(cost_mat): - cost_mat_detach = cost_mat.asnumpy() - a = np.ones([cost_mat.shape[0]]) / cost_mat.shape[0] - b = np.ones([cost_mat.shape[1]]) / cost_mat.shape[1] - ot_mat = ot.emd(a=a, b=b, M=cost_mat_detach, numItermax=10000) - ot_mat_attached = Parameter(Tensor(ot_mat, ms.float32), requires_grad=False) - ot_dist = ops.sum(ot_mat_attached * cost_mat) - return ot_dist, ot_mat_attached - - def get_nodes_coors_numpy(filename, all_atoms=False): df = PandasPdb().read_pdb(filename).df['ATOM'] if not all_atoms: @@ -327,120 +141,10 @@ def get_residues(pdb_filename): return residues -def prepare_graphs(args, ppdb_ligand, ligand_filename, receptor_filename): - unbound_ligand_all_atoms_pre_pos = ppdb_ligand.df["ATOM"][ - ['x_coord', 'y_coord', 'z_coord']].to_numpy().squeeze().astype(np.float32) - - unbound_predic_ligand, \ - unbound_predic_receptor, \ - bound_ligand_repres_nodes_loc_clean_array, \ - bound_receptor_repres_nodes_loc_clean_array, _ = preprocess_unbound_bound( - get_residues(ligand_filename), get_residues(receptor_filename), - graph_nodes=args.graph_nodes, pos_cutoff=args.pocket_cutoff, inference=True) - - protein_to_graph_unbound_bound_input_tuple = ( - unbound_predic_ligand, unbound_predic_receptor, - bound_ligand_repres_nodes_loc_clean_array, bound_receptor_repres_nodes_loc_clean_array, - ) - - ligand_graph, receptor_graph = protein_to_graph_unbound_bound( - protein_to_graph_unbound_bound_input_tuple, - cutoff=args.graph_cutoff, - max_neighbor=args.graph_max_neighbor, - one_hot=False, - residue_loc_is_alphac=args.graph_residue_loc_is_alphaC, - ) - - return ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, bound_ligand_repres_nodes_loc_clean_array - - -def graph_to_tensor(ligand_graph, receptor_graph): - ligand_graph.ndata['new_x'] = ligand_graph.ndata['x'] - - # Create a batch of a single heterograph - ligand_graph_node_tensor = ops.cat( - (ligand_graph.ndata["res_feat"], - ligand_graph.ndata["x"], - ligand_graph.ndata["new_x"], - ligand_graph.ndata["mu_r_norm"]), - axis=1 - ) - - receptor_graph_node_tensor = ops.cat( - (receptor_graph.ndata["res_feat"], - receptor_graph.ndata["x"], - receptor_graph.ndata["mu_r_norm"]), - axis=1 - ) - - ligand_graph_num_nodes = Tensor([ligand_graph.num_nodes]) - receptor_graph_num_nodes = Tensor([receptor_graph.num_nodes]) - ligand_graph_edge_tensor = ligand_graph.edata["he"] - receptor_graph_edge_tensor = receptor_graph.edata["he"] - - ll_connection_tensor = ops.stack( - (Tensor(ligand_graph.src_list), Tensor(ligand_graph.dst_list))) - - rr_connection_tensor = ops.stack( - (Tensor(receptor_graph.src_list), Tensor(receptor_graph.dst_list))) - - unbatch_list = [ - [ - len(ll_connection_tensor[0]), - len(rr_connection_tensor[0]), - len(ligand_graph_node_tensor), - len(receptor_graph_node_tensor), - 0, - ] - ] - unbatch_list.insert(0, np.array([0, 0, 0, 0, 0])) - - unbatch_list = Tensor(unbatch_list, ms.int32) - - input_tensor_tuple = ( - ligand_graph_num_nodes, - receptor_graph_num_nodes, - ligand_graph_edge_tensor, - receptor_graph_edge_tensor, - ll_connection_tensor, - rr_connection_tensor, - ) - - return ligand_graph_node_tensor, receptor_graph_node_tensor, 
unbatch_list, input_tensor_tuple - - -def get_rot_mat(euler_angles): - roll = euler_angles[0] - yaw = euler_angles[1] - pitch = euler_angles[2] - - tensor_0 = Tensor(0.0) - tensor_1 = Tensor(1.0) - cos = ops.cos - sin = ops.sin - - rx = ops.stack([ - ops.stack([tensor_1, tensor_0, tensor_0]), - ops.stack([tensor_0, cos(roll), -sin(roll)]), - ops.stack([tensor_0, sin(roll), cos(roll)])]).reshape(3, 3) - - ry = ops.stack([ - ops.stack([cos(pitch), tensor_0, sin(pitch)]), - ops.stack([tensor_0, tensor_1, tensor_0]), - ops.stack([-sin(pitch), tensor_0, cos(pitch)])]).reshape(3, 3) - - rz = ops.stack([ - ops.stack([cos(yaw), -sin(yaw), tensor_0]), - ops.stack([sin(yaw), cos(yaw), tensor_0]), - ops.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3) - - r = ops.mm(rz, ry) - r = ops.mm(r, rx) - - return r - - def distance_list_featurizer(dist_list): + """ + distance_list_featurizer + """ length_scale_list = [1.5 ** x for x in range(15)] center_list = [0. for _ in range(15)] @@ -459,6 +163,9 @@ def distance_list_featurizer(dist_list): def rigid_transform_kabsch_3d(a, b): + """ + rigid_transform_kabsch_3d + """ if a.shape[1] != b.shape[1]: raise ValueError("a.shape[1] not equals b.shape[1]") num_rows, num_cols = a.shape @@ -479,7 +186,7 @@ def rigid_transform_kabsch_3d(a, b): h = am @ bm.T # find rotation - u, s, vt = np.linalg.svd(h) + u, _, vt = np.linalg.svd(h) r = vt.T @ u.T @@ -523,6 +230,9 @@ def residue_list_featurizer_dips_not_one_hot(predic): def residue_type_one_hot_dips(residue): + """ + residue_type_one_hot_dips + """ dit = { 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', @@ -543,6 +253,9 @@ def residue_type_one_hot_dips(residue): def residue_type_one_hot_dips_not_one_hot(residue): + """ + residue_type_one_hot_dips_not_one_hot + """ dit = { 'ALA': 'A', 'ARG': 'R', 'ASN': 'N', 'ASP': 'D', 'CYS': 'C', 'GLN': 'Q', 'GLU': 'E', 'GLY': 'G', 'HIS': 'H', 'ILE': 'I', 'LEU': 'L', 'LYS': 'K', 'MET': 'M', 'PHE': 'F', @@ -569,14 +282,15 @@ def residue_type_one_hot_dips_not_one_hot(residue): res_name = residue if res_name not in dit.keys(): return 20 - else: - res_name = dit[res_name] - return indicator.get(res_name) + res_name = dit[res_name] + return indicator.get(res_name) def preprocess_unbound_bound(bound_ligand_residues, bound_receptor_residues, graph_nodes, pos_cutoff=8.0, inference=False): - + """ + preprocess_unbound_bound + """ ####################### def filter_residues(residues): residues_filtered = [] @@ -639,7 +353,7 @@ def preprocess_unbound_bound(bound_ligand_residues, bound_receptor_residues, gra raise ValueError("The value should <= pos_cutoff") pocket_coors = 0.5 * (ligand_pocket_coors + receptor_pocket_coors) log('Num pocket nodes = ', len(active_ligand), ' total nodes = ', - bound_ligand_repres_nodes_loc_array.shape[0], ' graph_nodes = ', graph_nodes) + bound_ligand_repres_nodes_loc_array.shape[0], ' graph_nodes = ', graph_nodes) return unbound_predic_ligand_clean_list, unbound_predic_receptor_clean_list, \ bound_ligand_repres_nodes_loc_array, bound_receptor_repres_nodes_loc_array, \ @@ -672,7 +386,9 @@ def protein_to_graph_unbound_bound_residuesonly( one_hot=False, residue_loc_is_alphac=True ): - + """ + protein_to_graph_unbound_bound_residuesonly + """ unbound_ligand_predic, unbound_receptor_predic, bound_ligand_repres_nodes_loc_clean_array, \ bound_receptor_repres_nodes_loc_clean_array = unbound_bound_tuple @@ -787,9 +503,8 @@ def 
protein_to_graph_unbound_bound_residuesonly( # place n_i, u_i, v_i as lines in a 3x3 basis matrix basis_matrix = np.stack((l_or_r_n_i_feat[dst, :], l_or_r_u_i_feat[dst, :], l_or_r_v_i_feat[dst, :]), axis=0) - p_ij = np.matmul(basis_matrix, - l_or_r_residue_representatives_loc_feat[src, :] - l_or_r_residue_representatives_loc_feat[ - dst, :]) + p_ij = np.matmul(basis_matrix, l_or_r_residue_representatives_loc_feat[src, :] - + l_or_r_residue_representatives_loc_feat[dst, :]) q_ij = np.matmul(basis_matrix, l_or_r_n_i_feat[src, :]) # shape (3,) k_ij = np.matmul(basis_matrix, l_or_r_u_i_feat[src, :]) t_ij = np.matmul(basis_matrix, l_or_r_v_i_feat[src, :]) @@ -842,8 +557,8 @@ def protein_to_graph_unbound_bound_residuesonly( valid_dist_np = l_or_r_distance[i, valid_src] sigma = np.array([1., 2., 5., 10., 30.]).reshape((-1, 1)) weights = softmax(- valid_dist_np.reshape((1, -1)) ** 2 / sigma, axis=1) # (sigma_num, neigh_num) - diff_vecs = l_or_r_residue_representatives_loc_feat[valid_dst, :] - l_or_r_residue_representatives_loc_feat[ - valid_src, :] # (neigh_num, 3) + diff_vecs = l_or_r_residue_representatives_loc_feat[valid_dst, :] - \ + l_or_r_residue_representatives_loc_feat[valid_src, :] # (neigh_num, 3) mean_vec = weights.dot(diff_vecs) # (sigma_num, 3) denominator = weights.dot(np.linalg.norm(diff_vecs, axis=1)) # (sigma_num,) mean_vec_ratio_norm = np.linalg.norm(mean_vec, axis=1) / denominator # (sigma_num,) @@ -880,6 +595,9 @@ def protein_to_graph_unbound_bound_residuesonly( def unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor, h_feats_ligand, coors_receptor, coors_ligand): + """ + unbatch_hetero_graph + """ list_hetero_graph = [] for i, _ in enumerate(unbatch_list_tensor): if i < len(unbatch_list_tensor) - 1: @@ -900,34 +618,31 @@ def unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor, h_feats_ligand, def get_non_lin(input_type, negative_slope): if input_type == 'swish': return nn.SiLU() - elif input_type == 'lkyrelu': + if input_type == 'lkyrelu': return nn.LeakyReLU(alpha=negative_slope) - else: - raise NotImplementedError + raise NotImplementedError def get_layer_norm(layer_norm_type, dim): if layer_norm_type == 'BN': return nn.BatchNorm1d([dim]) - elif layer_norm_type == 'LN': + if layer_norm_type == 'LN': return nn.LayerNorm([dim], begin_norm_axis=1, begin_params_axis=1, - epsilon=1e-5) - else: - return nn.Identity() + epsilon=1e-5) + return nn.Identity() def get_final_h_layer_norm(layer_norm_type, dim): if layer_norm_type == 'BN': return nn.BatchNorm1d(dim) - elif layer_norm_type == 'LN': + if layer_norm_type == 'LN': return nn.LayerNorm([dim], begin_norm_axis=1, begin_params_axis=1, epsilon=1e-5) - elif layer_norm_type == '0': + if layer_norm_type == '0': return nn.Identity() - else: - raise NotImplementedError + raise NotImplementedError -def apply_final_h_layer_norm(h, node_type, norm_type, norm_layer): +def apply_final_h_layer_norm(h, norm_layer): return norm_layer(h) @@ -955,6 +670,9 @@ def get_mask(ligand_batch_num_nodes, receptor_batch_num_nodes): class Graph: + """ + Graph class for wrapping data + """ def __init__( self, num_nodes=0, @@ -976,8 +694,18 @@ class Graph: class IEGMNLayer(nn.Cell): - - def __init__(self, orig_h_feats_dim, h_feats_dim, out_feats_dim, fine_tune, args, log=None): + """ + IEGMNLayer class + """ + def __init__( + self, + orig_h_feats_dim, + h_feats_dim, + out_feats_dim, + fine_tune, + args, + log_input=None + ): super(IEGMNLayer, self).__init__() @@ -994,7 +722,7 @@ class IEGMNLayer(nn.Cell): self.x_connection_init = 
args.x_connection_init self.debug = args.debug self.fine_tune = fine_tune - self.log = log + self.log = log_input self.h_feats_dim = h_feats_dim self.out_feats_dim = out_feats_dim self.all_sigmas_dist = [1.5 ** x for x in range(15)] @@ -1046,7 +774,9 @@ class IEGMNLayer(nn.Cell): return "IEGMNLayer " + str(self.__dict__) def edge_mlp_func(self, dropout, nonlin, layer_norm, leakyrelu_neg_slope): - + """ + IEGMNLayer edge_mlp_func + """ edge_mlp = nn.SequentialCell( nn.Dense((self.h_feats_dim * 2) + self.input_edge_feats_dim + len(self.all_sigmas_dist), self.out_feats_dim), @@ -1079,6 +809,9 @@ class IEGMNLayer(nn.Cell): return coors_mlp def x_rel_mag(self, nodes_ligand_x_now, ll_connection_tensor, nodes_receptor_x_now, rr_connection_tensor): + """ + IEGMNLayer x_rel_mag + """ nodes_ligand_x_now_src = ops.index_select(nodes_ligand_x_now, 0, ll_connection_tensor[0]) nodes_ligand_x_now_dst = ops.index_select(nodes_ligand_x_now, 0, ll_connection_tensor[1]) ll_edge_x_rel = ops.sub(nodes_ligand_x_now_src, nodes_ligand_x_now_dst) @@ -1107,7 +840,9 @@ class IEGMNLayer(nn.Cell): return cat_input_for_msg def nodes_aggr_cross_msg(self, h_feats_ligand, h_feats_receptor, mask): - # \mu_i + """ + IEGMNLayer nodes_aggr_cross_msg + """ nodes_ligand_aggr_cross_msg = compute_cross_attention(self.att_mlp_q(h_feats_ligand), self.att_mlp_k(h_feats_receptor), self.att_mlp_v(h_feats_receptor), @@ -1128,46 +863,48 @@ class IEGMNLayer(nn.Cell): return nodes_aggr_msg, x_final def fine_tune_final_lr(self, input_fine_tune_tuple): + """ + IEGMNLayer fine_tune_final_lr + """ x_final_ligand, x_final_receptor, h_feats_ligand, nodes_ligand_x_now, h_feats_receptor, \ nodes_receptor_x_now, mask = input_fine_tune_tuple - x_final_ligand = x_final_ligand + \ - self.att_mlp_cross_coors_v(h_feats_ligand) * ( - nodes_ligand_x_now - - compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_ligand), - self.att_mlp_cross_coors_k(h_feats_receptor), - nodes_receptor_x_now, - mask, - self.cross_msgs)) - x_final_receptor = x_final_receptor + \ - self.att_mlp_cross_coors_v(h_feats_receptor) * ( - nodes_receptor_x_now - - compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_receptor), - self.att_mlp_cross_coors_k(h_feats_ligand), - nodes_ligand_x_now, - mask.transpose(1, 0), - self.cross_msgs)) + x_final_ligand = x_final_ligand + self.att_mlp_cross_coors_v(h_feats_ligand) * \ + (nodes_ligand_x_now - compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_ligand), + self.att_mlp_cross_coors_k(h_feats_receptor), + nodes_receptor_x_now, + mask, + self.cross_msgs)) + x_final_receptor = x_final_receptor + self.att_mlp_cross_coors_v(h_feats_receptor) * \ + (nodes_receptor_x_now - compute_cross_attention(self.att_mlp_cross_coors_q(h_feats_receptor), + self.att_mlp_cross_coors_k(h_feats_ligand), + nodes_ligand_x_now, + mask.transpose(1, 0), + self.cross_msgs)) return x_final_ligand, x_final_receptor def skip_connections(self, h_feats_ligand, h_feats_receptor, input_node_upd_ligand, input_node_upd_receptor): - # Skip connections + """ + IEGMNLayer skip_connections + """ if self.h_feats_dim == self.out_feats_dim: node_upd_ligand = self.skip_weight_h * self.node_mlp(input_node_upd_ligand) + ( - 1. - self.skip_weight_h) * h_feats_ligand + 1. - self.skip_weight_h) * h_feats_ligand node_upd_receptor = self.skip_weight_h * self.node_mlp(input_node_upd_receptor) + ( - 1. - self.skip_weight_h) * h_feats_receptor + 1. 
- self.skip_weight_h) * h_feats_receptor else: node_upd_ligand = self.node_mlp(input_node_upd_ligand) node_upd_receptor = self.node_mlp(input_node_upd_receptor) - node_upd_ligand = apply_final_h_layer_norm(node_upd_ligand, 'ligand', self.final_h_layer_norm, - self.final_h_layernorm_layer) - node_upd_receptor = apply_final_h_layer_norm(node_upd_receptor, 'receptor', self.final_h_layer_norm, - self.final_h_layernorm_layer) + node_upd_ligand = apply_final_h_layer_norm(node_upd_ligand, self.final_h_layernorm_layer) + node_upd_receptor = apply_final_h_layer_norm(node_upd_receptor, self.final_h_layernorm_layer) return node_upd_ligand, node_upd_receptor def log_debug_info(self, debug_variable_tuple): + """ + IEGMNLayer log_debug_info + """ nodes_ligand_x_now, nodes_ligand_feat, ll_edge_x_rel, x_rel_mag_ligand, edges_ll_msg, \ nodes_ligand_aggr_cross_msg, edge_coef_ligand, edges_ll_x_moment, nodes_ligand_aggr_msg, \ x_final_ligand = debug_variable_tuple @@ -1178,13 +915,15 @@ class IEGMNLayer(nn.Cell): 'x_rel_mag_ligand = [exp(-||x_i - x_j||^2 / sigma) for sigma = 1.5 ** x, x = [0, 15]]') self.log(ops.max(edges_ll_msg)[0], 'data[msg] = m_{i->j} = phi^e(h_i, h_j, f_{i,j}, x_rel_mag_ligand)') self.log(ops.max(nodes_ligand_aggr_cross_msg)[0], 'aggr_cross_msg(i) = sum_j a_{i,j} * h_j') - self.log(ops.max(edge_coef_ligand)[0], 'edge_coef_ligand : \phi^x(m_{i->j})') - self.log(ops.max(edges_ll_x_moment)[0], 'data[x_moment] = (x_i - x_j) * \phi^x(m_{i->j})') - self.log(ops.max(nodes_ligand_aggr_msg)[0], 'data[aggr_msg]: \sum_j m_{i->j} ') + self.log(ops.max(edge_coef_ligand)[0], 'edge_coef_ligand : ' + r'\p' + 'hi^x(m_{i->j})') + self.log(ops.max(edges_ll_x_moment)[0], 'data[x_moment] = (x_i - x_j) * ' + r'\p' + 'hi^x(m_{i->j})') + self.log(ops.max(nodes_ligand_aggr_msg)[0], 'data[aggr_msg]: ' + r'\s' + 'um_j m_{i->j}') self.log(ops.max(x_final_ligand)[0], 'x_i new = x_final_ligand : x_i + data[x_update]') def construct(self, iegmn_layer_tuple, input_tensor_tuple): - + """ + IEGMNLayer construct + """ ligand_graph_num_nodes, receptor_graph_num_nodes, ll_connection_tensor, rr_connection_tensor = \ input_tensor_tuple[0], input_tensor_tuple[1], input_tensor_tuple[4], input_tensor_tuple[5] @@ -1252,13 +991,15 @@ class IEGMNLayer(nn.Cell): class IEGMN(nn.Cell): - - def __init__(self, args, n_lays, fine_tune, log=None): + """ + IEGMN class + """ + def __init__(self, args, n_lays, fine_tune, log_input=None): super(IEGMN, self).__init__() self.debug = args.debug - self.log = log + self.log = log_input self.graph_nodes = args.graph_nodes self.rot_model = args.rot_model self.iegmn_lay_hid_dim = args.iegmn_lay_hid_dim @@ -1302,24 +1043,27 @@ class IEGMN(nn.Cell): return "IEGMN " + str(self.__dict__) def iegmn_layers_func(self, args, input_node_feats_dim, n_lays, fine_tune): + """ + IEGMN iegmn_layers_func + """ iegmn_layers = nn.CellList() iegmn_layers.append( IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, - h_feats_dim=input_node_feats_dim, - out_feats_dim=self.iegmn_lay_hid_dim, - fine_tune=fine_tune, - args=args, - log=self.log)) + h_feats_dim=input_node_feats_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + fine_tune=fine_tune, + args=args, + log_input=self.log)) if args.shared_layers: interm_lay = IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, - h_feats_dim=self.iegmn_lay_hid_dim, - out_feats_dim=self.iegmn_lay_hid_dim, - args=args, - fine_tune=fine_tune, - log=self.log) + h_feats_dim=self.iegmn_lay_hid_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + args=args, + fine_tune=fine_tune, + log_input=self.log) 
for _ in range(1, n_lays): iegmn_layers.append(interm_lay) @@ -1327,15 +1071,18 @@ class IEGMN(nn.Cell): for _ in range(1, n_lays): iegmn_layers.append( IEGMNLayer(orig_h_feats_dim=input_node_feats_dim, - h_feats_dim=self.iegmn_lay_hid_dim, - out_feats_dim=self.iegmn_lay_hid_dim, - args=args, - fine_tune=fine_tune, - log=self.log)) + h_feats_dim=self.iegmn_lay_hid_dim, + out_feats_dim=self.iegmn_lay_hid_dim, + args=args, + fine_tune=fine_tune, + log_input=self.log)) return iegmn_layers def att_weights_rot(self, h_feats, h_feats_att_mean_rot, d, z_coors, all_y_att_rot_list): + """ + IEGMN att_weights_rot + """ att_weights_rot = ops.softmax( self.att_mlp_key_rot(h_feats).view(-1, self.num_att_heads, d) \ .transpose(1, 0, 2) @ # (K_heads, m_rec, d) @@ -1350,7 +1097,9 @@ class IEGMN(nn.Cell): return y_att_rot, all_y_att_rot_list def ap_compute(self, list_hetero_graph): - + """ + IEGMN ap_compute + """ all_t_align_list, all_b_align_list, all_y_receptor_att_rot_list, all_y_ligand_att_rot_list = [], [], [], [] for _, hetero_graph in enumerate(list_hetero_graph): @@ -1387,7 +1136,7 @@ class IEGMN(nn.Cell): y_ligand_att_rot_mean = y_ligand_att_rot.mean(axis=0, keep_dims=True) # (1,3) a = (y_receptor_att_rot - y_receptor_att_rot_mean).transpose(1, 0) @ ( - y_ligand_att_rot - y_ligand_att_rot_mean) # 3, 3 + y_ligand_att_rot - y_ligand_att_rot_mean) # 3, 3 if ops.isnan(a).any(): raise ValueError("There is Nan in a") @@ -1395,8 +1144,8 @@ class IEGMN(nn.Cell): u, s, vt = Tensor(u), Tensor(s), Tensor(vt) num_it = 0 - while ops.min(s)[0] < 1e-3 or ops.min(ops.abs((s ** 2).view(1, 3) - (s ** 2).view(3, 1) + ops.eye(3)))[ - 0] < 1e-2: + while ops.min(s)[0] < 1e-3 or ops.min(ops.abs((s ** 2).view(1, 3) - + (s ** 2).view(3, 1) + ops.eye(3)))[0] < 1e-2: if self.debug: self.log('S inside loop ', num_it, ' is ', s, ' and A = ', a) @@ -1432,6 +1181,9 @@ class IEGMN(nn.Cell): unbatch_list_tensor, input_tensor_tuple, ): + """ + IEGMN construct + """ ligand_graph_edge_tensor = input_tensor_tuple[2] receptor_graph_edge_tensor = input_tensor_tuple[3] @@ -1480,7 +1232,7 @@ class IEGMN(nn.Cell): self.log(ops.max(coors_ligand)[0], 'coors_ligand before after MPNN') list_hetero_graph = unbatch_hetero_graph(unbatch_list_tensor, h_feats_receptor, - h_feats_ligand, coors_receptor, coors_ligand) + h_feats_ligand, coors_receptor, coors_ligand) ap_res = self.ap_compute(list_hetero_graph) @@ -1488,17 +1240,19 @@ class IEGMN(nn.Cell): class RigidBodyDockingNet(nn.Cell): - - def __init__(self, args, log=None): + """ + RigidBodyDockingNet + """ + def __init__(self, args, log_input=None): super(RigidBodyDockingNet, self).__init__() self.debug = args.debug - self.log = log + self.log = log_input - self.iegmn_original = IEGMN(args, n_lays=args.iegmn_n_lays, fine_tune=False, log=log) + self.iegmn_original = IEGMN(args, n_lays=args.iegmn_n_lays, fine_tune=False, log_input=log_input) if args.fine_tune: - self.iegmn_fine_tune = IEGMN(args, n_lays=2, fine_tune=True, log=log) + self.iegmn_fine_tune = IEGMN(args, n_lays=2, fine_tune=True, log_input=log_input) self.list_iegmns = [('original', self.iegmn_original), ('finetune', self.iegmn_fine_tune)] else: self.list_iegmns = [('finetune', self.iegmn_original)] # just original @@ -1514,6 +1268,9 @@ class RigidBodyDockingNet(nn.Cell): unbatch_list_tensor, input_tensor_tuple, ): + """ + construct + """ last_outputs = None all_ligand_coors_deform_list = [] @@ -1577,4 +1334,3 @@ class RigidBodyDockingNet(nn.Cell): return all_ligand_coors_deform_list, \ all_keypts_ligand_list, 
all_keypts_receptor_list, \ all_rotation_list, all_translation_list - diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py new file mode 100644 index 000000000..01c1542fb --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py @@ -0,0 +1,246 @@ +""" +train_utils +""" +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import os +import math + +import ot +import numpy as np +import mindspore as ms +from mindspore import ops, Tensor, Parameter + +from .nn_arch import ( + preprocess_unbound_bound, + protein_to_graph_unbound_bound, + get_residues, + log, +) + + +FLAGS = os.O_RDWR | os.O_CREAT + + +def create_dir(path): + if os.path.exists(path): + raise FileExistsError('Path already exists. Please delete and restart your job.') + os.makedirs(path, exist_ok=False) + + +def graph_to_tensor(ligand_graph, receptor_graph): + """ + graph_to_tensor + """ + ligand_graph.ndata['new_x'] = ligand_graph.ndata['x'] + + # Create a batch of a single heterograph + ligand_graph_node_tensor = ops.cat( + (ligand_graph.ndata["res_feat"], + ligand_graph.ndata["x"], + ligand_graph.ndata["new_x"], + ligand_graph.ndata["mu_r_norm"]), + axis=1 + ) + + receptor_graph_node_tensor = ops.cat( + (receptor_graph.ndata["res_feat"], + receptor_graph.ndata["x"], + receptor_graph.ndata["mu_r_norm"]), + axis=1 + ) + + ligand_graph_num_nodes = Tensor([ligand_graph.num_nodes]) + receptor_graph_num_nodes = Tensor([receptor_graph.num_nodes]) + ligand_graph_edge_tensor = ligand_graph.edata["he"] + receptor_graph_edge_tensor = receptor_graph.edata["he"] + + ll_connection_tensor = ops.stack( + (Tensor(ligand_graph.src_list), Tensor(ligand_graph.dst_list))) + + rr_connection_tensor = ops.stack( + (Tensor(receptor_graph.src_list), Tensor(receptor_graph.dst_list))) + + unbatch_list = [ + [ + len(ll_connection_tensor[0]), + len(rr_connection_tensor[0]), + len(ligand_graph_node_tensor), + len(receptor_graph_node_tensor), + 0, + ] + ] + unbatch_list.insert(0, np.array([0, 0, 0, 0, 0])) + + unbatch_list = Tensor(unbatch_list, ms.int32) + + input_tensor_tuple = ( + ligand_graph_num_nodes, + receptor_graph_num_nodes, + ligand_graph_edge_tensor, + receptor_graph_edge_tensor, + ll_connection_tensor, + rr_connection_tensor, + ) + + return ligand_graph_node_tensor, receptor_graph_node_tensor, unbatch_list, input_tensor_tuple + + +def prepare_graphs(args, ppdb_ligand, ligand_filename, receptor_filename): + """ + prepare_graphs + """ + unbound_ligand_all_atoms_pre_pos = ppdb_ligand.df["ATOM"][ + ['x_coord', 'y_coord', 
'z_coord']].to_numpy().squeeze().astype(np.float32) + + unbound_predic_ligand, \ + unbound_predic_receptor, \ + bound_ligand_repres_nodes_loc_clean_array, \ + bound_receptor_repres_nodes_loc_clean_array, _ = preprocess_unbound_bound( + get_residues(ligand_filename), get_residues(receptor_filename), + graph_nodes=args.graph_nodes, pos_cutoff=args.pocket_cutoff, inference=True) + + protein_to_graph_unbound_bound_input_tuple = ( + unbound_predic_ligand, unbound_predic_receptor, + bound_ligand_repres_nodes_loc_clean_array, bound_receptor_repres_nodes_loc_clean_array, + ) + + ligand_graph, receptor_graph = protein_to_graph_unbound_bound( + protein_to_graph_unbound_bound_input_tuple, + cutoff=args.graph_cutoff, + max_neighbor=args.graph_max_neighbor, + one_hot=False, + residue_loc_is_alphac=args.graph_residue_loc_is_alphaC, + ) + + return ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, bound_ligand_repres_nodes_loc_clean_array + + +def get_rot_mat(euler_angles): + """ + get_rot_mat + """ + roll = euler_angles[0] + yaw = euler_angles[1] + pitch = euler_angles[2] + + tensor_0 = Tensor(0.0) + tensor_1 = Tensor(1.0) + cos = ops.cos + sin = ops.sin + + rx = ops.stack([ + ops.stack([tensor_1, tensor_0, tensor_0]), + ops.stack([tensor_0, cos(roll), -sin(roll)]), + ops.stack([tensor_0, sin(roll), cos(roll)])]).reshape(3, 3) + + ry = ops.stack([ + ops.stack([cos(pitch), tensor_0, sin(pitch)]), + ops.stack([tensor_0, tensor_1, tensor_0]), + ops.stack([-sin(pitch), tensor_0, cos(pitch)])]).reshape(3, 3) + + rz = ops.stack([ + ops.stack([cos(yaw), -sin(yaw), tensor_0]), + ops.stack([sin(yaw), cos(yaw), tensor_0]), + ops.stack([tensor_0, tensor_0, tensor_1])]).reshape(3, 3) + + r = ops.mm(rz, ry) + r = ops.mm(r, rx) + + return r + + +def compute_ot_emd(cost_mat): + cost_mat_detach = cost_mat.asnumpy() + a = np.ones([cost_mat.shape[0]]) / cost_mat.shape[0] + b = np.ones([cost_mat.shape[1]]) / cost_mat.shape[1] + ot_mat = ot.emd(a=a, b=b, M=cost_mat_detach, numItermax=10000) + ot_mat_attached = Parameter(Tensor(ot_mat, ms.float32), requires_grad=False) + ot_dist = ops.sum(ot_mat_attached * cost_mat) + return ot_dist, ot_mat_attached + + +def g_fn(protein_coords, x, sigma): + # protein_coords: (n,3) , x: (m,3), output: (m,) + e = ops.exp(- ops.sum((protein_coords.view(1, -1, 3) - x.view(-1, 1, 3)) ** 2, dim=2) / float(sigma)) # (m, n) + + return - sigma * ops.log(1e-3 + ops.sum(e, dim=1)) + + +def compute_body_intersection_loss( + model_ligand_coors_deform, + bound_receptor_repres_nodes_loc_array, + sigma, + surface_ct, + ): + """ + compute_body_intersection_loss + """ + loss = ops.mean( + ops.clamp(surface_ct - g_fn(Tensor(bound_receptor_repres_nodes_loc_array), model_ligand_coors_deform, sigma), + min=0)) + \ + ops.mean(ops.clamp( + surface_ct - g_fn(model_ligand_coors_deform, Tensor(bound_receptor_repres_nodes_loc_array), sigma), + min=0)) + + return loss + + +def compute_sq_dist_mat(x_1, x_2): + '''Computes the l2 squared cost matrix between two point cloud inputs. 
+ Args: + X_1: [n, #features] point cloud, tensor + X_2: [m, #features] point cloud, tensor + Output: + [n, m] matrix of the l2 distance between point pairs + ''' + n_1, _ = x_1.shape + n_2, _ = x_2.shape + x_1 = x_1.view(n_1, 1, -1) + x_2 = x_2.view(1, n_2, -1) + squared_dist = (x_1 - x_2) ** 2 + cost_mat = ops.sum(squared_dist, dim=2) + return cost_mat + + +def pretty_print_stats(*args): + + split_type, epoch, total_num_epochs,\ + complex_rmsd_mean, complex_rmsd_median,\ + ligand_rmsd_mean, ligand_rmsd_median,\ + receptor_rmsd_mean, receptor_rmsd_median,\ + _, _, _, avg_loss_ot, avg_loss_intersection, _ = args + + log('[{:s}] --> epoch {:d}/{:d} ' + '|| mean/median complex rmsd {:.4f} / {:.4f} ' + '|| mean/median ligand rmsd {:.4f} / {:.4f} ' + '|| mean/median sqrt pocket OT loss {:.4f} ' + '|| intersection loss {:.4f} ' + '|| mean/median receptor rmsd {:.4f} / {:.4f} '. + format(split_type, + epoch, total_num_epochs, + complex_rmsd_mean, complex_rmsd_median, + ligand_rmsd_mean, ligand_rmsd_median, + math.sqrt(avg_loss_ot), + avg_loss_intersection, + receptor_rmsd_mean, receptor_rmsd_median)) -- Gitee From 7336b3f3bf42c4b790d2ab38ba7c2e317120c7cc Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Tue, 13 Aug 2024 21:05:31 +0800 Subject: [PATCH 03/16] slightly change --- .../pipeline/models/equidock/equidock.py | 18 +++++++++---- .../pipeline/models/equidock/equidock_data.py | 6 +++-- .../models/equidock/equidock_dataset.py | 19 ++++++++------ .../pipeline/models/equidock/nn_arch.py | 26 +++++++++++-------- .../pipeline/models/equidock/train_utils.py | 3 --- 5 files changed, 43 insertions(+), 29 deletions(-) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py index 2a74e2956..3e0a665f7 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock.py @@ -33,7 +33,9 @@ import mindspore as ms from mindspore import nn, Tensor, ops, save_checkpoint from mindspore.experimental import optim -from .train_utils import * +from .train_utils import create_dir, prepare_graphs, graph_to_tensor, \ + pretty_print_stats, get_rot_mat, compute_sq_dist_mat, compute_ot_emd, \ + compute_body_intersection_loss from .nn_arch import MeterUnboundBound, RigidBodyDockingNet, log, FLAGS from .equidock_dataset import EquiDockDataSet from ..model import Model @@ -72,14 +74,14 @@ class EquiDock(Model): with os.fdopen(os.open(log_file_name, FLAGS, 777), 'a+') as fout: fout.write('[' + str(dt.now()) + '] START\n') self.config.checkpoint_filename = os.path.join( - self.config.checkpoint_dir, + self.config.checkpoint_dir, self.config.data + '_model_best.pth' ) self.loss_fn_coors = nn.MSELoss(reduction='mean') self.optimizer = optim.Adam( - self.network.trainable_params(), - lr=self.config.lr, + self.network.trainable_params(), + lr=self.config.lr, weight_decay=self.config.w_decay ) self.scheduler = optim.lr_scheduler.LambdaLR( @@ -132,7 +134,7 @@ class EquiDock(Model): ppdb_ligand = PandasPdb().read_pdb(ligand_filename) - ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, bound_ligand_repres_nodes_loc_clean_array\ + ligand_graph, receptor_graph, unbound_ligand_all_atoms_pre_pos, _\ = prepare_graphs(self.config, ppdb_ligand, ligand_filename, receptor_filename) if self.config.input_edge_feats_dim < 0: @@ -187,6 +189,9 @@ class EquiDock(Model): def run_an_eval_epoch(self, run_epoch_tuple, data_batched, data_loader, epoch, args): + """ + 
run_an_eval_epoch + """ val_complex_rmsd_mean, val_complex_rmsd_median, \ val_ligand_rmsd_mean, val_ligand_rmsd_median, \ val_receptor_rmsd_mean, val_receptor_rmsd_median, \ @@ -207,6 +212,9 @@ class EquiDock(Model): def run_a_generic_epoch(self, ep_type, run_epoch_tuple, data_batched, data_loader): + """ + run_a_generic_epoch + """ args, self.network, loss_fn_coors, optimizer = run_epoch_tuple meter = MeterUnboundBound() diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py index a71b96d0c..0259c322a 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_data.py @@ -155,7 +155,7 @@ class UnboundBoundData(): both_proteins_to_graph_pair_list, bound_ligand_repres_nodes_loc_array_list, bound_receptor_repres_nodes_loc_array_list, - pocket_coors_list, + pocket_coors_list, ) def save_processed_data( @@ -167,7 +167,9 @@ class UnboundBoundData(): bound_receptor_repres_nodes_loc_array_list, pocket_coors_list, ): - + """ + save_processed_data + """ ligand_graph_list, receptor_graph_list = [], [] for result in both_proteins_to_graph_pair_list: ligand_graph, receptor_graph = result diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py index 82bd2e1ee..c1c7cdc66 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py @@ -51,31 +51,34 @@ class EquiDockDataSet(PSP): return files def set_training_data_src(self, data_source, **kwargs): + """ + set_training_data_src + """ self.log(data_source, **kwargs) if not os.path.exists(self.config.processed_dataset_path): UnboundBoundData( - self.config, - reload_mode='val', + self.config, + reload_mode='val', raw_data_path=self.config.raw_data_path, split_files_path=self.config.split_files_path, ) UnboundBoundData( - self.config, - reload_mode='test', + self.config, + reload_mode='test', raw_data_path=self.config.raw_data_path, split_files_path=self.config.split_files_path, ) UnboundBoundData( - self.config, - reload_mode='train', + self.config, + reload_mode='train', raw_data_path=self.config.raw_data_path, split_files_path=self.config.split_files_path, ) self.train_data_batched, self.train_loader = self.create_dataloader( - self.config.train_dir, - bs=self.config.bs, + self.config.train_dir, + bs=self.config.bs, shuffle=True, ) self.val_data_batched, self.val_loader = self.create_dataloader( diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py index 37c5ae4b5..f8fa452bf 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py @@ -32,7 +32,6 @@ from numpy import linalg as LA from biopandas.pdb import PandasPdb import scipy.spatial as spa from scipy.special import softmax -import mindspore as ms from mindspore import nn, ops, Tensor @@ -503,7 +502,7 @@ def protein_to_graph_unbound_bound_residuesonly( # place n_i, u_i, v_i as lines in a 3x3 basis matrix basis_matrix = np.stack((l_or_r_n_i_feat[dst, :], l_or_r_u_i_feat[dst, :], l_or_r_v_i_feat[dst, :]), axis=0) - p_ij = np.matmul(basis_matrix, l_or_r_residue_representatives_loc_feat[src, :] - + p_ij = np.matmul(basis_matrix, 
l_or_r_residue_representatives_loc_feat[src, :] - l_or_r_residue_representatives_loc_feat[dst, :]) q_ij = np.matmul(basis_matrix, l_or_r_n_i_feat[src, :]) # shape (3,) k_ij = np.matmul(basis_matrix, l_or_r_u_i_feat[src, :]) @@ -647,7 +646,9 @@ def apply_final_h_layer_norm(h, norm_layer): def compute_cross_attention(queries, keys, values, mask, cross_msgs): - # Compute cross attention + """ + compute_cross_attention + """ if not cross_msgs: return queries * 0. a = mask * ops.mm(queries, ops.transpose(keys, (1, 0))) - 1000. * (1. - mask) @@ -657,6 +658,9 @@ def compute_cross_attention(queries, keys, values, mask, cross_msgs): def get_mask(ligand_batch_num_nodes, receptor_batch_num_nodes): + """ + get_mask + """ rows = sum(ligand_batch_num_nodes) cols = sum(receptor_batch_num_nodes) mask = ops.zeros((int(rows), int(cols))) @@ -698,13 +702,13 @@ class IEGMNLayer(nn.Cell): IEGMNLayer class """ def __init__( - self, - orig_h_feats_dim, - h_feats_dim, - out_feats_dim, - fine_tune, - args, - log_input=None + self, + orig_h_feats_dim, + h_feats_dim, + out_feats_dim, + fine_tune, + args, + log_input=None, ): super(IEGMNLayer, self).__init__() @@ -1144,7 +1148,7 @@ class IEGMN(nn.Cell): u, s, vt = Tensor(u), Tensor(s), Tensor(vt) num_it = 0 - while ops.min(s)[0] < 1e-3 or ops.min(ops.abs((s ** 2).view(1, 3) - + while ops.min(s)[0] < 1e-3 or ops.min(ops.abs((s ** 2).view(1, 3) - (s ** 2).view(3, 1) + ops.eye(3)))[0] < 1e-2: if self.debug: self.log('S inside loop ', num_it, ' is ', s, ' and A = ', a) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py index 01c1542fb..c3e0d8fdb 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/train_utils.py @@ -39,9 +39,6 @@ from .nn_arch import ( ) -FLAGS = os.O_RDWR | os.O_CREAT - - def create_dir(path): if os.path.exists(path): raise FileExistsError('Path already exists. Please delete and restart your job.') -- Gitee From 19e835de7bdabe022c21ace938f0dbd29d0e964f Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Tue, 13 Aug 2024 21:20:19 +0800 Subject: [PATCH 04/16] disable pylint --- .jenkins/check/config/filter_pylint.txt | 1 + .../mindsponge/pipeline/models/equidock/equidock_dataset.py | 3 +++ 2 files changed, 4 insertions(+) diff --git a/.jenkins/check/config/filter_pylint.txt b/.jenkins/check/config/filter_pylint.txt index f68d58fa7..6bec150a3 100644 --- a/.jenkins/check/config/filter_pylint.txt +++ b/.jenkins/check/config/filter_pylint.txt @@ -159,3 +159,4 @@ "mindscience/MindSPONGE/tutorials/basic/tutorial_p03.py" "wrong-import-position" "mindscience/MindSPONGE/tutorials/basic/tutorial_p04.py" "wrong-import-position" "mindscience/MindSPONGE/tutorials/basic/tutorial_p05.py" "wrong-import-position" +"mindscience/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py" "arguments-differ" diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py index c1c7cdc66..a7e675178 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/equidock_dataset.py @@ -23,6 +23,9 @@ equidock_dataset # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================ + +#pylint: disable=W0221 + import os from datetime import datetime as dt -- Gitee From db64e9a87282d521898ef26abe88b2e52276fb0a Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Tue, 13 Aug 2024 21:28:46 +0800 Subject: [PATCH 05/16] small change --- MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py index f8fa452bf..f5d000309 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/equidock/nn_arch.py @@ -708,8 +708,7 @@ class IEGMNLayer(nn.Cell): out_feats_dim, fine_tune, args, - log_input=None, - ): + log_input=None): super(IEGMNLayer, self).__init__() -- Gitee From d710b4effdad1c8d695c22733e0da05e7bc842fe Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Tue, 3 Sep 2024 19:45:43 +0800 Subject: [PATCH 06/16] progen first commit --- .../model_configs/EquiDock/predict_db5.yaml | 39 + .../model_configs/EquiDock/predict_dips.yaml | 42 + .../model_configs/EquiDock/train_db5.yaml | 55 + .../model_configs/ProGen/small.yaml | 37 + .../mindsponge/pipeline/models/__init__.py | 1 + .../pipeline/models/progen/__init__.py | 28 + .../progen/module/configuration_utils.py | 3543 +++++++++++++++++ .../models/progen/module/injection.py | 933 +++++ .../models/progen/module/logits_process.py | 1427 +++++++ .../pipeline/models/progen/nn_arch.py | 692 ++++ .../pipeline/models/progen/progen.py | 281 ++ .../models/progen/progen_configuration.py | 29 + .../pipeline/models/progen/progen_dataset.py | 49 + .../pipeline/models/progen/small.yaml | 37 + .../pipeline/models/progen/tokenizer.json | 91 + .../src/mindsponge/pipeline/pipeline.py | 3 +- 16 files changed, 7286 insertions(+), 1 deletion(-) create mode 100644 MindSPONGE/applications/model_configs/EquiDock/predict_db5.yaml create mode 100644 MindSPONGE/applications/model_configs/EquiDock/predict_dips.yaml create mode 100644 MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml create mode 100644 MindSPONGE/applications/model_configs/ProGen/small.yaml create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/small.yaml create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json diff --git a/MindSPONGE/applications/model_configs/EquiDock/predict_db5.yaml b/MindSPONGE/applications/model_configs/EquiDock/predict_db5.yaml new file mode 100644 index 000000000..b6058bda3 --- /dev/null +++ b/MindSPONGE/applications/model_configs/EquiDock/predict_db5.yaml @@ -0,0 +1,39 @@ +data: db5 +is_train: False +data_fraction: 1.0 +split: 0 +dropout: 1.0 +x_connection_init: 0.0 +skip_weight_h: 0.5 
+num_att_heads: 50 +iegmn_n_lays: 5 +leakyrelu_neg_slope: 0.01 +input_edge_feats_dim: 27 +residue_emb_dim: 64 +graph_max_neighbor: 10 +pocket_cutoff: 8.0 +graph_cutoff: 30.0 +layer_norm_coors: '0' +final_h_layer_norm: '0' +nonlin: lkyrelu +layer_norm: LN +graph_nodes: 'residues' +rot_model: 'kb_att' +iegmn_lay_hid_dim: 64 +noise_decay_rate: 0.0 +noise_initial: 0.0 +use_edge_features_in_gmn: False +use_mean_node_features: False +debug: False +fine_tune: False +shared_layers: True +use_dist_in_layers: True +cross_msgs: True +use_edge_features_in_gmn: True +use_mean_node_features: True +graph_residue_loc_is_alphaC: True +input_dir: './test_sets_pdb/db5_test_random_transformed/random_transformed/' +ground_truth_dir: './test_sets_pdb/db5_test_random_transformed/complexes/' +output_dir: './test_sets_pdb/db5_equidock_results/' +ckpt_dir: './db5_pretrained.ckpt' + diff --git a/MindSPONGE/applications/model_configs/EquiDock/predict_dips.yaml b/MindSPONGE/applications/model_configs/EquiDock/predict_dips.yaml new file mode 100644 index 000000000..8e5e35699 --- /dev/null +++ b/MindSPONGE/applications/model_configs/EquiDock/predict_dips.yaml @@ -0,0 +1,42 @@ +data: dips +is_train: False +data_fraction: 1.0 +split: 0 +dropout: 1.0 +x_connection_init: 0.0 +skip_weight_h: 0.75 +num_att_heads: 50 +iegmn_n_lays: 8 +leakyrelu_neg_slope: 0.01 +input_edge_feats_dim: 27 +residue_emb_dim: 64 +graph_max_neighbor: 10 +pocket_cutoff: 8.0 +graph_cutoff: 30.0 +layer_norm_coors: '0' +final_h_layer_norm: '0' +nonlin: lkyrelu +layer_norm: LN +graph_nodes: 'residues' +rot_model: 'kb_att' +iegmn_lay_hid_dim: 64 +noise_decay_rate: 0.0 +noise_initial: 0.0 +use_edge_features_in_gmn: False +use_mean_node_features: False +debug: False +fine_tune: False +shared_layers: False +use_dist_in_layers: True +cross_msgs: True +use_edge_features_in_gmn: True +use_mean_node_features: True +graph_residue_loc_is_alphaC: True +input_dir: './test_sets_pdb/dips_test_random_transformed/random_transformed/' +ground_truth_dir: './test_sets_pdb/dips_test_random_transformed/complexes/' +output_dir: './test_sets_pdb/dips_equidock_results/' +ckpt_dir: './dips_pretrained.ckpt' + + + + diff --git a/MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml b/MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml new file mode 100644 index 000000000..72851de6d --- /dev/null +++ b/MindSPONGE/applications/model_configs/EquiDock/train_db5.yaml @@ -0,0 +1,55 @@ +data: db5 +is_train: True +data_fraction: 1.0 +split: 0 +dropout: 1.0 +x_connection_init: 0.0 +skip_weight_h: 0.5 +num_att_heads: 50 +iegmn_n_lays: 5 +leakyrelu_neg_slope: 0.01 +input_edge_feats_dim: 27 +residue_emb_dim: 64 +graph_max_neighbor: 10 +pocket_cutoff: 8.0 +graph_cutoff: 30.0 +layer_norm_coors: '0' +final_h_layer_norm: '0' +nonlin: lkyrelu +layer_norm: LN +graph_nodes: 'residues' +rot_model: 'kb_att' +translation_interval: 5.0 +n_jobs: 40 +iegmn_lay_hid_dim: 64 +noise_decay_rate: 0.0 +noise_initial: 0.0 +use_edge_features_in_gmn: False +use_mean_node_features: False +debug: False +fine_tune: False +shared_layers: True +use_dist_in_layers: True +cross_msgs: True +use_edge_features_in_gmn: True +use_mean_node_features: True +graph_residue_loc_is_alphaC: True +processed_dataset_path: './processed_datasets/' +raw_data_path: './data/benchmark5.5/structures' +split_files_path: './data/benchmark5.5/cv/cv_0' +ckpt_dir: './Train_Mode_Not_Used' +train_dir: './processed_datasets/train' +val_dir: './processed_datasets/val' +test_dir: './processed_datasets/test' +lr: 0.00005 +w_decay: 
0.0001 +scheduler: 'warmup' +warmup: 1.0 +num_epochs: 1000 +clip: 100.0 +bs: 20 +intersection_sigma: 25.0 +intersection_surface_ct: 10.0 +pocket_ot_loss_weight: 1.0 +intersection_loss_weight: 10.0 + diff --git a/MindSPONGE/applications/model_configs/ProGen/small.yaml b/MindSPONGE/applications/model_configs/ProGen/small.yaml new file mode 100644 index 000000000..7b40024c2 --- /dev/null +++ b/MindSPONGE/applications/model_configs/ProGen/small.yaml @@ -0,0 +1,37 @@ +config: "small" +model: "progen2-small" +p: 0.95 +t: 0.2 +max_length: 8 +num_samples: 1 +context: "1" +vocab_size: 32 +n_positions: 1024 +n_ctx: 2048 +n_embd: 1024 +n_layer: 12 +n_head: 16 +rotary_dim: 32 +n_inner: None +activation_function: "gelu_new" +resid_pdrop: 1.0 +embd_pdrop: 1.0 +attn_pdrop: 1.0 +layer_norm_epsilon: 0.00001 +initializer_range: 0.02 +scale_attn_weights: True +gradient_checkpointing: False +use_cache: True +bos_token_id: 1 +eos_token_id: 2 +min_length: 1 +ckpt_dir: './progen2-small.ckpt' +tokenizer_file: './tokenizer.json' +rng_seed: 42 +rng_deterministic: True +fp16: True +sanity: True +x_uniref90bfd30: '2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1' +x_oas: '1EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMHWVRQAPWKGLEYVSAISSNGGSTYYANSVKGRFTISRDNSKNTLYLQMGSLRAEDMAVYYCARDESGYSYGWGYYFDYWGQGTLVTVSS2' +x_bfd90: '1TAPRSTRASGSEGSRPPGIPAKGRRCLPSRAGSVTPRFRHARQGTATVAKEQGRKLIASNRKARHDYHIEDTFEAGLVLTGTEVKSLRMGRASLIDGYAVFYGEELWLEGVHIPEYLNGNWTNHTPRRRRKLLLNRSELTKLAHKTSESGHTIVPLALYFKDGRAKVEIAVAKGKKAYDKRHALRERQDQREV2' +x_data: '2PAQGRARLAAHYGTGRIGREVTVDERCRNLDRLEPSWELLRLLDDMGFIEGQNGLRRYVAEVFALDEPYDMTWRLRSLDEPHEVNAIEFAAPHERVYATLSERFFPDSVERDLRELVTRSLVEVDLGDPFTPPFVNSVYELRGASRRWVGVVRDVLAPDVLPCDATIRVLADAGTRAATRGLREILDTESGRVCVLGLHAALDAIADDRNEVSTSVAVADLEQCVALREAIRQITPRGAISVLVKGPLRTSGMRAQIAAVVHLRAKSSHLLPGGTDVVTFGAREFAIRSAANERKVVASMRLLALPGFAERSLCGLARPGVGRGRWEPAINVSVAADRDQIDLRVMGADVGDASVIFLKRDFRKLTEEFWRTHTDVPIEREDVSAQRTEPDNRWRWLVPCDDLVAPRLTVVPPRSVGHGM1' diff --git a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py index 8e8fff3c5..8f0d01313 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/__init__.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/__init__.py @@ -36,3 +36,4 @@ from .proteinmpnn import ProteinMpnn, ProteinMpnnDataset, proteinmpnn_configurat from .ufold import UFold, UFoldDataSet, ufold_configuration from .rasp import RASP, RASPDataSet, rasp_configuration from .equidock import EquiDock, EquiDockDataSet, equidock_configuration +from .progen import ProGen, ProGenDataSet, progen_configuration diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py new file mode 100644 index 000000000..04a85bf5e --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py @@ -0,0 +1,28 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""progen""" + +from .progen import ProGen +from .progen_dataset import ProGenDataSet +from .progen_configuration import progen_configuration + diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py new file mode 100644 index 000000000..be9376f11 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py @@ -0,0 +1,3543 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +import copy +from copy import deepcopy +import json +import os +import warnings +import inspect +from typing import Optional, List, Callable, Dict, Any, Tuple, Union, Iterable +from enum import Enum +from collections import OrderedDict, UserDict +from dataclasses import fields +from dataclasses import dataclass + +import numpy as np +import mindspore +from mindspore import nn, ops, Tensor, Parameter, jit_class + +from .logits_process import ( + EncoderNoRepeatNGramLogitsProcessor, + EncoderRepetitionPenaltyLogitsProcessor, + EpsilonLogitsWarper, + EtaLogitsWarper, + ExponentialDecayLengthPenalty, + ForcedBOSTokenLogitsProcessor, + ForcedEOSTokenLogitsProcessor, + ForceTokensLogitsProcessor, + HammingDiversityLogitsProcessor, + InfNanRemoveLogitsProcessor, + LogitNormalization, + LogitsProcessorList, + MinLengthLogitsProcessor, + MinNewTokensLengthLogitsProcessor, + NoBadWordsLogitsProcessor, + NoRepeatNGramLogitsProcessor, + PrefixConstrainedLogitsProcessor, + RepetitionPenaltyLogitsProcessor, + SequenceBiasLogitsProcessor, + SuppressTokensAtBeginLogitsProcessor, + SuppressTokensLogitsProcessor, + TemperatureLogitsWarper, + TopKLogitsWarper, + TopPLogitsWarper, + TypicalLogitsWarper, + UnbatchedClassifierFreeGuidanceLogitsProcessor, +) + +DEFAULT_DTYPE = mindspore.float32 +INIT_WEIGHTS_FLAG = True + + +def _is_mindspore(x): + + return isinstance(x, mindspore.Tensor) + +def is_mindspore_available(): + return mindspore.get_context('device_target') == 'Ascend' + +def is_mindspore_tensor(x): + """ + Tests if `x` is a torch tensor or not. Safe to call even if torch is not installed. 
+ """ + return False if not is_mindspore_available() else _is_mindspore(x) + +def no_grad(func): + """no grad wrapper""" + def wrapper(*args, **kwargs): + _pynative_executor.set_enable_grad(False) + outputs = func(*args, **kwargs) + _pynative_executor.set_enable_grad(True) + return outputs + return wrapper + + +class CellUtilMixin: + """ + A few utilities to be used as a mixin. + """ + + def get_head_mask( + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False + ) -> Tensor: + """ + Prepare the head mask if needed. + + Args: + head_mask (`mindspore.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): + The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). + num_hidden_layers (`int`): + The number of hidden layers in the model. + is_attention_chunked: (`bool`, *optional*, defaults to `False`): + Whether or not the attentions scores are computed by chunks or not. + + Returns: + `mindspore.Tensor` with shape `[num_hidden_layers x batch x + num_heads x seq_length x seq_length]` or list with + `[None]` for each layer. + """ + if head_mask is not None: + head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) + if is_attention_chunked is True: + head_mask = head_mask.expand_dims(-1) + else: + head_mask = () + for _ in range(num_hidden_layers): + head_mask += (None,) + + return head_mask + + def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers): + """-> [num_hidden_layers x batch x num_heads x seq_length x seq_length]""" + if head_mask.ndim == 1: + head_mask = head_mask.expand_dims(0).expand_dims(0).expand_dims(-1).expand_dims(-1) + head_mask = head_mask.broadcast_to(num_hidden_layers, -1, -1, -1, -1) + elif head_mask.ndim == 2: + head_mask = head_mask.expand_dims(1).expand_dims(-1)\ + .expand_dims(-1) # We can specify head_mask for each layer + assert head_mask.ndim == 5, f"head_mask.dim != 5, instead {head_mask.ndim}" + head_mask = head_mask.astype(dtype=self.dtype) # switch to float if need + fp16 compatibility + return head_mask + + @property + def dtype(self) -> mindspore.TensorType: + """ + `mindspore.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype). + """ + return get_parameter_dtype(self) + + +class SetAttribute(nn.Cell): + def __init__(self, module_name): + super().__init__() + module_name._is_initialized = True + + +class ModelOutputMindnlp(OrderedDict): + """ + Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a + tuple) or strings (like a dictionary) that will ignore the `None` attributes. Otherwise behaves like a regular + python dictionary. + + + + You can't unpack a `ModelOutput` directly. Use the [`~utils.ModelOutput.to_tuple`] method to convert it to a tuple + before. 
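+
+    A minimal usage sketch (illustrative only; `GenerateDecoderOnlyOutput` is one of the
+    subclasses defined below, the tensor value is hypothetical, and the module-level
+    `is_tensor` helper referenced in `__post_init__` is assumed to be available):
+
+    >>> seq = ops.ones((1, 4), dtype=mindspore.int64)
+    >>> out = GenerateDecoderOnlyOutput(sequences=seq)
+    >>> out["sequences"] is out.sequences      # dict-style and attribute access agree
+    True
+    >>> len(out.to_tuple())                    # only the non-None fields are returned
+    1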
+ + + """ + + def __post_init__(self): + class_fields = fields(self) + + # Safety and consistency checks + if len(class_fields) == 0: + raise ValueError(f"{self.__class__.__name__} has no fields.") + if not all(field.default is None for field in class_fields[1:]): + raise ValueError(f"{self.__class__.__name__} should not have more than one required field.") + + first_field = getattr(self, class_fields[0].name) + other_fields_are_none = all(getattr(self, field.name) is None for field in class_fields[1:]) + + if other_fields_are_none and not is_tensor(first_field): + if isinstance(first_field, dict): + iterator = first_field.items() + first_field_iterator = True + else: + try: + iterator = iter(first_field) + first_field_iterator = True + except TypeError: + first_field_iterator = False + + # if we provided an iterator as first field and the iterator is a (key, value) iterator + # set the associated fields + if first_field_iterator: + for idx, element in enumerate(iterator): + if ( + not isinstance(element, (list, tuple)) + or not len(element) == 2 + or not isinstance(element[0], str) + ): + if idx == 0: + # If we do not have an iterator of key/values, set it as attribute + self[class_fields[0].name] = first_field + else: + # If we have a mixed iterator, raise an error + raise ValueError( + f"Cannot set key/value for {element}. It needs to be a tuple (key, value)." + ) + break + setattr(self, element[0], element[1]) + if element[1] is not None: + self[element[0]] = element[1] + elif first_field is not None: + self[class_fields[0].name] = first_field + else: + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __delitem__(self, *args, **kwargs): + raise RuntimeError(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.") + + def setdefault(self, *args, **kwargs): + raise RuntimeError(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.") + + def pop(self, *args, **kwargs): + raise RuntimeError(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.") + + def update(self, *args, **kwargs): + raise RuntimeError(f"You cannot use ``update`` on a {self.__class__.__name__} instance.") + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + # Don't call self.__setitem__ to avoid recursion errors + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + # Will raise a KeyException if needed + super().__setitem__(key, value) + # Don't call self.__setattr__ to avoid recursion errors + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any]: + """ + Convert self to a tuple containing all the attributes/keys that are not `None`. + """ + return tuple(v for _, v in self.items()) + + +@dataclass +class GenerateDecoderOnlyOutput(ModelOutputMindnlp): + """ + Outputs of decoder-only generation models, when using non-beam methods. 
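+
+    Illustrative relationship between the fields (hypothetical sizes): for a prompt of
+    length 3 that generates 5 new tokens with `output_scores=True`, one would expect
+    `sequences.shape == (batch_size, 8)` and `len(scores) == 5`, i.e. one processed
+    score tensor of shape `(batch_size, vocab_size)` per generated token.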
+ """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + logits: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[mindspore.Tensor]]]] = None + + +@dataclass +class GenerateEncoderDecoderOutput(ModelOutputMindnlp): + """ + Outputs of encoder-decoder generation models, when using non-beam methods. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + logits: Optional[Tuple[mindspore.Tensor]] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + past_key_values: Optional[Tuple[Tuple[Tuple[mindspore.Tensor]]]] = None + + +@dataclass +class SampleDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using sampling. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class SampleEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of encoder-decoder generation models using sampling. Hidden states and attention weights of + the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +@dataclass +class BeamSampleDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using beam sample. + """ + + sequences: mindspore.Tensor = None + sequences_scores: Optional[mindspore.Tensor] = None + scores: Optional[Tuple[mindspore.Tensor]] = None + beam_indices: Optional[mindspore.Tensor] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class BeamSampleEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of encoder-decoder generation models using beam sampling. 
Hidden states and attention + weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + """ + + sequences: mindspore.Tensor = None + sequences_scores: Optional[mindspore.Tensor] = None + scores: Optional[Tuple[mindspore.Tensor]] = None + beam_indices: Optional[mindspore.Tensor] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class BeamSearchDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using beam search. + """ + + sequences: mindspore.Tensor = None + sequences_scores: Optional[mindspore.Tensor] = None + scores: Optional[Tuple[mindspore.Tensor]] = None + beam_indices: Optional[mindspore.Tensor] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class BeamSearchEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of encoder-decoder generation models using beam search. Hidden states and attention weights + of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the encoder_hidden_states + attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + """ + + sequences: mindspore.Tensor = None + sequences_scores: Optional[mindspore.Tensor] = None + scores: Optional[Tuple[mindspore.Tensor]] = None + beam_indices: Optional[mindspore.Tensor] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +@dataclass +class GreedySearchDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using greedy search. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +@dataclass +class GreedySearchEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of encoder-decoder generation models using greedy search. 
Hidden states and attention + weights of the decoder (respectively the encoder) can be accessed via the encoder_attentions and the + encoder_hidden_states attributes (respectively the decoder_attentions and the decoder_hidden_states attributes) + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +@dataclass +class ContrastiveSearchEncoderDecoderOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using contrastive search. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + encoder_attentions: Optional[Tuple[mindspore.Tensor]] = None + encoder_hidden_states: Optional[Tuple[mindspore.Tensor]] = None + decoder_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + cross_attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + decoder_hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +@dataclass +class ContrastiveSearchDecoderOnlyOutput(ModelOutputMindnlp): + """ + Base class for outputs of decoder-only generation models using contrastive search. + + Args: + sequences (`mindspore.Tensor` of shape `(batch_size, sequence_length)`): + The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter + if all batches finished early due to the `eos_token_id`. + scores (`tuple(mindspore.Tensor)` *optional*, returned when `output_scores=True` is passed or when + `config.output_scores=True`): + Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) + at each generation step. Tuple of `mindspore.Tensor` with up to `max_new_tokens` elements (one element for + each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. + attentions (`tuple(tuple(mindspore.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `mindspore.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. + hidden_states (`tuple(tuple(mindspore.Tensor))`, *optional*, returned when `output_hidden_states=True` is + passed or when `config.output_hidden_states=True`): + Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of + `mindspore.Tensor` of shape `(batch_size, generated_length, hidden_size)`. + """ + + sequences: mindspore.Tensor = None + scores: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + + +# @dataclass +# class GreedySearchDecoderOnlyOutput(ModelOutputMindnlp): +# """ +# Base class for outputs of decoder-only generation models using greedy search. 
+# """ + +# sequences: mindspore.Tensor = None +# scores: Optional[Tuple[mindspore.Tensor]] = None +# attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None +# hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + +GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] +SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] +BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] +BeamSampleOutput = Union[BeamSampleEncoderDecoderOutput, BeamSampleDecoderOnlyOutput] +ContrastiveSearchOutput = Union[ContrastiveSearchEncoderDecoderOutput, ContrastiveSearchDecoderOnlyOutput] +GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamSampleOutput, ContrastiveSearchOutput] + + +class StoppingCriteriaList(list): + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + return any(criteria(input_ids, scores) for criteria in self) + + @property + def max_length(self) -> Optional[int]: + for stopping_criterium in self: + if not isinstance(stopping_criterium, MaxLengthCriteria): + raise TypeError + return stopping_criterium.max_length + + +class GenerationConfig: + """ + Class that holds a configuration for a generation task. + """ + def __init__(self, **kwargs): + # Parameters that control the length of the output + self.max_length = kwargs.pop("max_length", 20) + self.max_new_tokens = kwargs.pop("max_new_tokens", None) + self.min_length = kwargs.pop("min_length", 0) + self.min_new_tokens = kwargs.pop("min_new_tokens", None) + self.early_stopping = kwargs.pop("early_stopping", False) + self.use_cache = kwargs.pop("use_cache", True) + self.bad_words_ids = kwargs.pop("bad_words_ids", None) + + # # Parameters that define the output variables of `generate` + self.output_attentions = kwargs.pop("output_attentions", False) + self.output_hidden_states = kwargs.pop("output_hidden_states", False) + self.output_scores = kwargs.pop("output_scores", False) + self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) + + # Special tokens that can be used at generation time + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self._from_model_config = kwargs.pop("_from_model_config", False) + + # Additional attributes without default values + if not self._from_model_config: + # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a + # model's default configuration file + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error("Can't set %s with value %s", key, value) + raise err + + # Validate the values of the attributes + self.validate(is_init=True) + + @classmethod + def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": + """ + Instantiates a [`GenerationConfig`] from a Python dictionary of parameters. + """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + # Those arguments may be passed along for our internal telemetry. + # We remove them so they don't appear in `return_unused_kwargs`. + kwargs.pop("_from_auto", None) + kwargs.pop("_from_pipeline", None) + # The commit hash might have been updated in the `config_dict`, we don't want the kwargs to erase that update. 
+ commit_hash_str = "_commit_hash" + if commit_hash_str in kwargs and commit_hash_str in config_dict: + kwargs[commit_hash_str] = config_dict[commit_hash_str] + + config = cls(**config_dict) + unused_kwargs = config.update(**kwargs) + + return config + + @classmethod + def from_model_config(cls, model_config) -> "GenerationConfig": + """ + Instantiates a [`GenerationConfig`] from a [`PretrainedConfig`]. This function is useful to convert legacy + [`PretrainedConfig`] objects, which may contain generation parameters, into a stand-alone [`GenerationConfig`]. + """ + config_dict = model_config.to_dict() + config = cls.from_dict(config_dict, return_unused_kwargs=False) + config.set_from_model_config(True) + + return config + + + def set_from_model_config(self, value:bool): + """set _from_model_config""" + if not isinstance(value, bool): + raise TypeError + self._from_model_config = value + + def update(self, **kwargs): + """ + Updates attributes of this class instance with attributes from `kwargs` if they match existing atributtes, + returning all the unused kwargs. + """ + to_remove = [] + for key, value in kwargs.items(): + if hasattr(self, key): + setattr(self, key, value) + to_remove.append(key) + + # remove all the attributes that were updated, without modifying the input dict + unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} + return unused_kwargs + + def validate(self, is_init=False): + """ + Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence + of parameterization that can be detected as incorrect from the configuration instance alone. + + Note that some parameters are best validated at generate runtime, as they may depend on other inputs and/or the + model, such as parameters related to the generation length. + """ + + # Validation of individual attributes + if self.early_stopping not in {True, False, "never"}: + raise ValueError(f"`early_stopping` must be a boolean or 'never', but is {self.early_stopping}.") + + +class GenerationMixin: + """ + class GenerationMixin + A class containing all functions for auto-regressive text generation, to be used as a mixin in [`PreTrainedModel`]. + """ + + @staticmethod + def _expand_inputs_for_generation( + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[mindspore.Tensor] = None, + **model_kwargs, + ) -> Tuple[mindspore.Tensor, Dict[str, Any]]: + """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" + + def _expand_dict_for_generation(dict_to_expand): + for key in dict_to_expand: + if dict_to_expand[key] is not None and isinstance(dict_to_expand[key], mindspore.Tensor): + dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0) + return dict_to_expand + + if input_ids is not None: + input_ids = input_ids.repeat_interleave(expand_size, dim=0) + + model_kwargs = _expand_dict_for_generation(model_kwargs) + + return input_ids, model_kwargs + + @staticmethod + def prepare_inputs_for_generation(self, *args, **kwargs): + """ + prepare_inputs_for_generation + """ + raise NotImplementedError( + "A model class needs to define a `prepare_inputs_for_generation` method in order to use `generate`." 
+ ) + + @classmethod + def _maybe_initialize_input_ids_for_generation( + cls, + inputs: Optional[mindspore.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, mindspore.Tensor]] = None, + ) -> mindspore.Tensor: + """Initializes input ids for generation, if necessary.""" + if inputs is not None: + return inputs + + encoder_outputs = model_kwargs.get("encoder_outputs") + + if bos_token_id is None: + raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") + + # If there is some tensor in `model_kwargs`, we can infer the batch size from it. This is helpful with + # soft-prompting or in multimodal implementations built on top of decoder-only language models. + batch_size = 1 + for value in model_kwargs.values(): + if isinstance(value, mindspore.Tensor): + batch_size = value.shape[0] + break + return ops.ones((batch_size, 1), dtype=mindspore.int64) * bos_token_id + + @classmethod + def _prepare_decoder_input_ids_for_generation( + cls, + model_input_name: str, + model_kwargs: Dict[str, mindspore.Tensor], + ) -> Tuple[mindspore.Tensor, Dict[str, mindspore.Tensor]]: + """Prepares `decoder_input_ids` for generation with encoder-decoder models""" + # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, + # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. + input_ids_str = "input_ids_str" + if model_kwargs is not None and "decoder_input_ids" in model_kwargs: + decoder_input_ids = model_kwargs.pop("decoder_input_ids") + elif input_ids_str in model_kwargs and model_input_name != input_ids_str: + decoder_input_ids = model_kwargs.pop(input_ids_str) + else: + decoder_input_ids = None + + return decoder_input_ids, model_kwargs + + @classmethod + def _merge_criteria_processor_list( + cls, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + if len(custom_list) == 0: + return default_list + for default in default_list: + for custom in custom_list: + if type(custom) is type(default): + object_type = "stopping criteria" if isinstance(custom, StoppingCriteria) else "logits processor" + raise ValueError( + f"A custom {object_type} of type {type(custom)} with values {custom} has been passed to" + f" `.generate()`, but it has already been created with the values {default}. {default} has been" + " created by passing the corresponding arguments to generate or by the model's config default" + f" values. If you just want to change the default values of {object_type} consider passing" + f" them as arguments to `.generate()` instead of using a custom {object_type}." + ) + default_list.extend(custom_list) + return default_list + + def _get_logits_warper( + self, + generation_config: GenerationConfig, + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances + used for multinomial sampling. + """ + + # instantiate warpers list + warpers = LogitsProcessorList() + + # In beam methods, we need to keep at least one non-eos token to explore continuations that might have a + # better score (i.e. 
keep len(list(generation_config.eos_token_id)) + 1) + if generation_config.num_beams > 1: + if isinstance(generation_config.eos_token_id, list): + min_tokens_to_keep = len(generation_config.eos_token_id) + 1 + else: + min_tokens_to_keep = 2 + else: + min_tokens_to_keep = 1 + + # the following idea is largely copied from this PR: https://github.com/huggingface/transformers/pull/5420/files + # all samplers can be found in `generation_utils_samplers.py` + if generation_config.temperature is not None and generation_config.temperature != 1.0: + warpers.append(TemperatureLogitsWarper(generation_config.temperature)) + if generation_config.top_k is not None and generation_config.top_k != 0: + warpers.append(TopKLogitsWarper(top_k=generation_config.top_k, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.top_p is not None and generation_config.top_p < 1.0: + warpers.append(TopPLogitsWarper(top_p=generation_config.top_p, min_tokens_to_keep=min_tokens_to_keep)) + if generation_config.typical_p is not None and generation_config.typical_p < 1.0: + warpers.append( + TypicalLogitsWarper(mass=generation_config.typical_p, min_tokens_to_keep=min_tokens_to_keep) + ) + return warpers + + def generate( + self, + inputs: Optional[mindspore.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, + ) -> Union[GreedySearchDecoderOnlyOutput, mindspore.Tensor]: + + # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call + self._validate_model_class() + generation_config = copy.deepcopy(self.generation_config) + model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs + generation_config.validate() + self._validate_model_kwargs(model_kwargs.copy()) + + # 2. Set generation parameters if not already defined + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + + # 3. Define model inputs + inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs( + inputs, generation_config.bos_token_id, model_kwargs) + + # 4. Define other model kwargs + model_kwargs["output_attentions"] = generation_config.output_attentions + model_kwargs["output_hidden_states"] = generation_config.output_hidden_states + if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds": + model_kwargs["use_cache"] = True + else: + model_kwargs["use_cache"] = generation_config.use_cache + + # 5. Prepare `input_ids` which will be used for auto-regressive generation + if not self.config.is_encoder_decoder: + input_ids = inputs_tensor if model_input_name == "input_ids" else model_kwargs.pop("input_ids") + if streamer is not None: + streamer.put(input_ids) + + # 6. Prepare `max_length` depending on other stopping criteria. + input_ids_length = input_ids.shape[-1] + has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None + self._validate_generated_length(generation_config, input_ids_length, has_default_max_length) + + # 7. prepare distribution pre_processing samplers + logits_processor = self._get_logits_processor( + generation_config=generation_config, + input_ids_seq_length=input_ids_length, + logits_processor=logits_processor) + + # 8. 
prepare stopping criteria + stopping_criteria = self._get_stopping_criteria( + generation_config=generation_config, stopping_criteria=stopping_criteria) + + # 11. prepare logits warper + logits_warper = self._get_logits_warper(generation_config) + + # 12. expand input_ids with `num_return_sequences` additional sequences per batch + input_ids, model_kwargs = self._expand_inputs_for_generation( + input_ids=input_ids, + expand_size=generation_config.num_return_sequences, + is_encoder_decoder=self.config.is_encoder_decoder, + **model_kwargs, + ) + + # 13. run sample + return self.sample( + input_ids, + logits_processor=logits_processor, + logits_warper=logits_warper, + stopping_criteria=stopping_criteria, + pad_token_id=generation_config.pad_token_id, + eos_token_id=generation_config.eos_token_id, + output_scores=generation_config.output_scores, + return_dict_in_generate=generation_config.return_dict_in_generate, + streamer=streamer, + **model_kwargs, + ) + + def sample( + self, + input_ids: mindspore.Tensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, + ) -> Union[SampleOutput, mindspore.Tensor]: + r""" + Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and + can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 
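+
+        A minimal sketch of direct use (normally `generate()` assembles these arguments;
+        `model` and `input_ids` are assumed to exist, the hyper-parameter values are
+        hypothetical, and `MaxLengthCriteria` is assumed to come from the same
+        stopping-criteria utilities used by `_get_stopping_criteria`):
+
+        >>> warpers = LogitsProcessorList(
+        ...     [TemperatureLogitsWarper(0.8), TopPLogitsWarper(top_p=0.95, min_tokens_to_keep=1)]
+        ... )
+        >>> criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=32)])
+        >>> ids = model.sample(input_ids, logits_warper=warpers, stopping_criteria=criteria,
+        ...                    pad_token_id=0, eos_token_id=2)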
+ """ + # init values + synced_gpus=None + # init values + logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() + stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() + if max_length is not None: + warnings.warn( + "`max_length` is deprecated in this function, use" + " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.", + UserWarning, + ) + stopping_criteria = validate_stopping_criteria(stopping_criteria, max_length) + logits_warper = logits_warper if logits_warper is not None else LogitsProcessorList() + pad_token_id = pad_token_id if pad_token_id is not None else self.generation_config.pad_token_id + eos_token_id = eos_token_id if eos_token_id is not None else self.generation_config.eos_token_id + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + eos_token_id_tensor = mindspore.tensor(eos_token_id) if eos_token_id is not None else None + output_scores = output_scores if output_scores is not None else self.generation_config.output_scores + output_attentions = ( + output_attentions if output_attentions is not None else self.generation_config.output_attentions + ) + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.generation_config.output_hidden_states + ) + return_dict_in_generate = ( + return_dict_in_generate + if return_dict_in_generate is not None + else self.generation_config.return_dict_in_generate + ) + + # init attention / hidden states / scores tuples + scores = () if (return_dict_in_generate and output_scores) else None + decoder_attentions = () if (return_dict_in_generate and output_attentions) else None + cross_attentions = () if (return_dict_in_generate and output_attentions) else None + decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None + + # if model is an encoder-decoder, retrieve encoder attention weights and hidden states + if return_dict_in_generate and self.config.is_encoder_decoder: + encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None + encoder_hidden_states = ( + model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None + ) + + # keep track of which sequences are already finished + unfinished_sequences = ops.ones(input_ids.shape[0], dtype=mindspore.int64) + + this_peer_finished = False # used by synced_gpus only + # auto-regressive generation + while True: + # prepare model inputs + model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + # forward pass to get next token + outputs = self( + **model_inputs, + return_dict=True, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + if synced_gpus and this_peer_finished: + continue # don't waste resources running the code we don't need + + if type(outputs) is dict: + outputs = ADDict(**outputs) + + next_token_logits = outputs.logits[:, -1, :] + # pre-process distribution + next_token_scores = logits_processor(input_ids, next_token_logits) + next_token_scores = logits_warper(input_ids, next_token_scores) + # Store scores, attentions and hidden_states when required + if return_dict_in_generate: + if output_scores: + scores += (next_token_scores,) + if output_attentions: + decoder_attentions += ( + (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,) + ) + if self.config.is_encoder_decoder: + cross_attentions += 
(outputs.cross_attentions,) + + if output_hidden_states: + decoder_hidden_states += ( + (outputs.decoder_hidden_states,) + if self.config.is_encoder_decoder + else (outputs.hidden_states,) + ) + + # sample + probs = ops.softmax(next_token_scores, axis=-1) + next_tokens = ops.multinomial(probs, num_samples=1).squeeze(1).astype(mindspore.int64) + # finished sentences should have their next token be a padding token + if eos_token_id is not None: + if pad_token_id is None: + raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.") + next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences) + # update generated ids, model inputs, and length for next step + input_ids = ops.cat([input_ids, next_tokens[:, None]], axis=-1) + if streamer is not None: + streamer.put(next_tokens) + model_kwargs = self._update_model_kwargs_for_generation( + outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder + ) + # if eos_token was found in one sentence, set sentence to finished + if eos_token_id_tensor is not None: + unfinished_sequences = unfinished_sequences.mul( + next_tokens.tile((eos_token_id_tensor.shape[0], 1)).ne(eos_token_id_tensor.unsqueeze(1)).prod(axis=0) + ) + + # stop when each sentence is finished + if unfinished_sequences.max() == 0: + this_peer_finished = True + + # stop if we exceed the maximum length + if stopping_criteria(input_ids, scores): + this_peer_finished = True + + if this_peer_finished and not synced_gpus: + break + if streamer is not None: + streamer.end() + if return_dict_in_generate: + if self.config.is_encoder_decoder: + return SampleEncoderDecoderOutput( + sequences=input_ids, + scores=scores, + encoder_attentions=encoder_attentions, + encoder_hidden_states=encoder_hidden_states, + decoder_attentions=decoder_attentions, + cross_attentions=cross_attentions, + decoder_hidden_states=decoder_hidden_states, + ) + else: + return SampleDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) + else: + return input_ids + + def _prepare_model_inputs( + self, + inputs: Optional[mindspore.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, mindspore.Tensor]] = None, + ) -> Tuple[mindspore.Tensor, Optional[str], Dict[str, mindspore.Tensor]]: + """ + This function extracts the model-specific `inputs` for generation. + """ + # 1. retrieve all kwargs that are non-None or non-model input related. + # some encoder-decoder models have different names for model and encoder + if not self.config.is_encoder_decoder: + input_name = self.main_input_name + + model_kwargs = {k: v for k, v in model_kwargs.items() if v is not None or k != input_name} + + # 2. check whether model_input_name is passed as kwarg + # if yes and `inputs` is None use kwarg inputs + inputs_kwarg = model_kwargs.pop(input_name, None) + if inputs_kwarg is not None and inputs is not None: + raise ValueError( + f"`inputs`: {inputs}` were passed alongside {input_name} which is not allowed. " + f"Make sure to either pass {inputs} or {input_name}=..." + ) + if inputs_kwarg is not None: + inputs = inputs_kwarg + + # 3. In the presence of `inputs_embeds` for text models: + # - decoder-only models should complain if the user attempts to pass `inputs_embeds`, but the model + # doesn't have its forwarding implemented. 
`inputs_embeds` is kept in `model_kwargs` and can coexist with + # input_ids (`inputs_embeds` will be used in the 1st generation step, as opposed to `input_ids`) + inputs_embeds_str = "inputs_embeds" + input_ids_str = "input_ids" + if input_name == input_ids_str and inputs_embeds_str in model_kwargs: + if not self.config.is_encoder_decoder: + has_inputs_embeds_forwarding = inputs_embeds_str in set( + inspect.signature(self.prepare_inputs_for_generation).parameters.keys() + ) + if not has_inputs_embeds_forwarding: + raise ValueError( + f"You passed `inputs_embeds` to `.generate()`, but the model class {self.__class__.__name__} " + "doesn't have its forwarding implemented. See the GPT2 implementation for an example " + "(https://github.com/huggingface/transformers/pull/21405), and feel free to open a PR with it!" + ) + # In this case, `input_ids` is moved to the `model_kwargs`, so a few automations (like the creation of + # the attention mask) can rely on the actual model input. + model_kwargs[input_ids_str] = self._maybe_initialize_input_ids_for_generation( + inputs, bos_token_id, model_kwargs=model_kwargs + ) + else: + if inputs is not None: + raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") + inputs, input_name = model_kwargs[inputs_embeds_str], + + # 4. if `inputs` is still None, try to create `input_ids` from BOS token + inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) + return inputs, input_name, model_kwargs + + def _extract_past_from_model_output(self, outputs: ModelOutputMindnlp, standardize_cache_format: bool = False): + past_key_values = None + if "past_key_values" in outputs: + past_key_values = outputs.past_key_values + elif "mems" in outputs: + past_key_values = outputs.mems + elif "past_buckets_states" in outputs: + past_key_values = outputs.past_buckets_states + + # Bloom fix: standardizes the cache format when requested + if standardize_cache_format and hasattr(self, "_convert_to_standard_cache"): + batch_size = outputs.logits.shape[0] + past_key_values = self._convert_to_standard_cache(past_key_values, batch_size=batch_size) + return past_key_values + + def _update_model_kwargs_for_generation( + self, + outputs, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, + ) -> Dict[str, Any]: + # update past_key_values + model_kwargs["past_key_values"] = self._extract_past_from_model_output( + outputs, standardize_cache_format=standardize_cache_format + ) + + # update token_type_ids with last value + token_type_str = "token_type_ids" + if token_type_str in model_kwargs: + token_type_ids = model_kwargs[token_type_str] + model_kwargs[token_type_str] = ops.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], axis=-1) + + return model_kwargs + + def _reorder_cache(self, past, beam_idx): + raise NotImplementedError( + f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" + f" enable beam search for {self.__class__}" + ) + + def _get_logits_processor( + self, + generation_config: GenerationConfig, + input_ids_seq_length: int, + logits_processor: Optional[LogitsProcessorList], + ) -> LogitsProcessorList: + """ + This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] + instances used to modify the scores of the language model head. 
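+
+        Sketch of the expected behaviour (hypothetical values; `model` is assumed to exist.
+        Note that the attributes read below, e.g. `forced_eos_token_id`, are present when
+        the configuration comes from `GenerationConfig.from_model_config`, so they are
+        supplied explicitly here):
+
+        >>> cfg = GenerationConfig(max_length=32, min_length=8, eos_token_id=2,
+        ...                        min_new_tokens=None, bad_words_ids=None,
+        ...                        forced_eos_token_id=None)
+        >>> procs = model._get_logits_processor(cfg, input_ids_seq_length=4,
+        ...                                     logits_processor=LogitsProcessorList())
+        >>> any(isinstance(p, MinLengthLogitsProcessor) for p in procs)
+        True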
+ """ + # instantiate processors list + processors = LogitsProcessorList() + + if generation_config.bad_words_ids is not None: + processors.append( + NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) + ) + if ( + generation_config.min_length is not None + and generation_config.eos_token_id is not None + and generation_config.min_length > 0 + ): + processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) + if ( + generation_config.min_new_tokens is not None + and generation_config.eos_token_id is not None + and generation_config.min_new_tokens > 0 + ): + processors.append( + MinNewTokensLengthLogitsProcessor( + input_ids_seq_length, generation_config.min_new_tokens, generation_config.eos_token_id + ) + ) + + if generation_config.forced_eos_token_id is not None: + processors.append( + ForcedEOSTokenLogitsProcessor(generation_config.max_length, generation_config.forced_eos_token_id) + ) + + processors = self._merge_criteria_processor_list(processors, logits_processor) + + return processors + + def _get_stopping_criteria( + self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] + ) -> StoppingCriteriaList: + criteria = StoppingCriteriaList() + if generation_config.max_length is not None: + criteria.append(MaxLengthCriteria(max_length=generation_config.max_length)) + criteria = self._merge_criteria_processor_list(criteria, stopping_criteria) + return criteria + + def _validate_model_class(self): + """ + Confirms that the model class is compatible with generation. If not, raises an exception that points to the + right class to use. + """ + if not self.can_generate(): + raise NotImplementedError( + "TODO: You need to implement this function." + ) + + def _validate_model_kwargs(self, model_kwargs: Dict[str, Any]): + """Validates model kwargs for generation. Generate argument typos will also be caught here.""" + # Excludes arguments that are handled before calling any model function + if self.config.is_encoder_decoder: + for key in ["decoder_input_ids"]: + model_kwargs.pop(key, None) + + def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length): + """Performs validation related to the resulting generated length""" + + # 1. Max length warnings related to poor parameterization + if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20: + # 20 is the default max_length of the generation config + warnings.warn( + f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the " + "generation length. We recommend setting `max_new_tokens` to control the maximum length of the " + "generation.", + UserWarning, + ) + if input_ids_length >= generation_config.max_length: + input_ids_string = "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids" + warnings.warn( + f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to" + f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider" + " increasing `max_new_tokens`.", + UserWarning, + ) + + # 2. Min length warnings due to unfeasible parameter combinations + min_length_error_suffix = ( + " Generation will stop at the defined maximum length. You should decrease the minimum length and/or " + "increase the maximum length." 
+ ) + if has_default_max_length: + min_length_error_suffix += ( + f" Note that `max_length` is set to {generation_config.max_length}, its default value." + ) + if generation_config.min_length is not None and generation_config.min_length > generation_config.max_length: + warnings.warn( + f"Unfeasible length constraints: `min_length` ({generation_config.min_length}) is larger than" + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + UserWarning, + ) + if generation_config.min_new_tokens is not None: + min_length = generation_config.min_new_tokens + input_ids_length + if min_length > generation_config.max_length: + warnings.warn( + f"Unfeasible length constraints: `min_new_tokens` ({generation_config.min_new_tokens}), when " + f"added to the prompt length ({input_ids_length}), is larger than" + f" the maximum possible length ({generation_config.max_length})." + min_length_error_suffix, + UserWarning, + ) + + +class ExplicitEnum(str, Enum): + """ + Enum with more explicit error message for missing values. + """ + + @classmethod + def _missing_(cls, value): + raise ValueError( + f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}" + ) + + +class GenerationMode(ExplicitEnum): + """ + Possible generation modes, downstream of the [`~generation.GenerationMixin.generate`] method. + """ + + # Non-beam methods + CONTRASTIVE_SEARCH = "contrastive_search" + GREEDY_SEARCH = "greedy_search" + SAMPLE = "sample" + ASSISTED_GENERATION = "assisted_generation" + # Beam methods + BEAM_SEARCH = "beam_search" + BEAM_SAMPLE = "beam_sample" + CONSTRAINED_BEAM_SEARCH = "constrained_beam_search" + GROUP_BEAM_SEARCH = "group_beam_search" + + +class PeftAdapterMixin: + """ + A class containing all functions for loading and using adapters weights that are supported in PEFT library. For + more details about adapters and injecting them on a transformer-based model, check out the documentation of PEFT + library: https://huggingface.co/docs/peft/index + """ + + _hf_peft_config_loaded = False + + def load_adapter( + self, + peft_model_id: Optional[str] = None, + adapter_name: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + device_map: Optional[str] = "auto", + max_memory: Optional[str] = None, + offload_folder: Optional[str] = None, + offload_index: Optional[int] = None, + peft_config: Dict[str, Any] = None, + adapter_state_dict: Optional[Dict[str, "mindspore.Tensor"]] = None, + adapter_kwargs: Optional[Dict[str, Any]] = None, + ) -> None: + """ + Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we + invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft + + Requires peft as a backend to load the adapter weights. + + Args: + peft_model_id (`str`, *optional*): + The identifier of the model to look for on the Hub, or a local path to the saved adapter config file + and adapter weights. + adapter_name (`str`, *optional*): + The adapter name to use. If not set, will use the default adapter. + revision (`str`, *optional*, defaults to `"main"`): + The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a + git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any + identifier allowed by git. + + + + To test a pull request you made on the Hub, you can pass `revision="refs/pr/". 
+ + + + token (`str`, `optional`): + Whether to use authentication token to load the remote folder. Userful to load private repositories + that are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to + cache it. + device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*): + A map that specifies where each submodule should go. It doesn't need to be refined to each + parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the + same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank + like `1`) on which the model will be allocated, the device map will map the entire model to this + device. Passing `device_map = 0` means put the whole model on GPU 0. + + To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For + more information about each option see [designing a device + map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). + max_memory (`Dict`, *optional*): + A dictionary device identifier to maximum memory. Will default to the maximum memory available for each + GPU and the available CPU RAM if unset. + offload_folder (`str` or `os.PathLike`, `optional`): + If the `device_map` contains any value `"disk"`, the folder where we will offload weights. + offload_index (`int`, `optional`): + `offload_index` argument to be passed to `accelerate.dispatch_model` method. + peft_config (`Dict[str, Any]`, *optional*): + The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts + methods. This argument is used in case users directly pass PEFT state dicts + adapter_state_dict (`Dict[str, mindspore.Tensor]`, *optional*): + The state dict of the adapter to load. This argument is used in case users directly pass PEFT state + dicts + adapter_kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and + `find_adapter_config_file` method. + """ + + adapter_name = adapter_name if adapter_name is not None else "default" + if adapter_kwargs is None: + adapter_kwargs = {} + + from ...peft import PeftConfig, inject_adapter_in_model, load_peft_weights + from ...peft.utils import set_peft_model_state_dict + + if self._hf_peft_config_loaded and adapter_name in self.peft_config: + raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + + if peft_model_id is None and (adapter_state_dict is None and peft_config is None): + raise ValueError( + "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter." + ) + + # We keep `revision` in the signature for backward compatibility + if revision is not None and "revision" not in adapter_kwargs: + adapter_kwargs["revision"] = revision + elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]: + logger.error( + "You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. " + "The one in `adapter_kwargs` will be used." 
+ ) + + # Override token with adapter_kwargs' token + if "token" in adapter_kwargs: + token = adapter_kwargs.pop("token") + + if peft_config is None: + adapter_config_file = find_adapter_config_file( + peft_model_id, + token=token, + **adapter_kwargs, + ) + + if adapter_config_file is None: + raise ValueError( + f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the " + "adapter model." + ) + + peft_config = PeftConfig.from_pretrained( + peft_model_id, + token=token, + **adapter_kwargs, + ) + + # Create and add fresh new adapters into the model. + inject_adapter_in_model(peft_config, self, adapter_name) + + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True + + if peft_model_id is not None: + adapter_state_dict = load_peft_weights(peft_model_id, token=token, **adapter_kwargs) + + # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility + processed_adapter_state_dict = {} + prefix = "base_model.model." + for key, value in adapter_state_dict.items(): + if key.startswith(prefix): + new_key = key[len(prefix) :] + else: + new_key = key + processed_adapter_state_dict[new_key] = value + + # Load state dict + incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name) + + if incompatible_keys is not None: + # check only for unexpected keys + if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0: + logger.warning( + f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: " + f" {incompatible_keys.unexpected_keys}. " + ) + + # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk. + if ( + (getattr(self, "hf_device_map", None) is not None) + and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0) + and len(self.peft_config) == 1 + ): + self._dispatch_accelerate_model( + device_map=device_map, + max_memory=max_memory, + offload_folder=offload_folder, + offload_index=offload_index, + ) + + def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None: + r""" + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Adds a fresh new adapter to the current model for training purpose. If no adapter name is passed, a default + name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the + default adapter name). + + Args: + adapter_config (`~peft.PeftConfig`): + The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts + methods + adapter_name (`str`, *optional*, defaults to `"default"`): + The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. + """ + from ...peft import PeftConfig, inject_adapter_in_model + + adapter_name = adapter_name or "default" + + if not self._hf_peft_config_loaded: + self._hf_peft_config_loaded = True + elif adapter_name in self.peft_config: + raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") + + if not isinstance(adapter_config, PeftConfig): + raise ValueError( + f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." 
+ ) + + # Retrieve the name or path of the model, one could also use self.config._name_or_path + # but to be consistent with what we do in PEFT: https://github.com/huggingface/peft/blob/6e783780ca9df3a623992cc4d1d665001232eae0/src/peft/mapping.py#L100 + adapter_config.base_model_name_or_path = self.__dict__.get("name_or_path", None) + inject_adapter_in_model(adapter_config, self, adapter_name) + + self.set_adapter(adapter_name) + + def set_adapter(self, adapter_name: Union[List[str], str]) -> None: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. + + Args: + adapter_name (`Union[List[str], str]`): + The name of the adapter to set. Can be also a list of strings to set multiple adapters. + """ + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + elif isinstance(adapter_name, list): + missing = set(adapter_name) - set(self.peft_config) + if len(missing) > 0: + raise ValueError( + f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)." + f" current loaded adapters are: {list(self.peft_config.keys())}" + ) + elif adapter_name not in self.peft_config: + raise ValueError( + f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}" + ) + + from ...peft.tuners.tuners_utils import BaseTunerLayer + from ...peft.utils import ModulesToSaveWrapper + + _adapters_has_been_set = False + + for _, module in self.named_modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + # For backward compatbility with previous PEFT versions + if hasattr(module, "set_adapter"): + module.set_adapter(adapter_name) + else: + module.active_adapter = adapter_name + _adapters_has_been_set = True + + if not _adapters_has_been_set: + raise ValueError( + "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." + ) + + def disable_adapters(self) -> None: + r""" + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Disable all adapters that are attached to the model. This leads to inferring with the base model only. + """ + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from ...peft.tuners.tuners_utils import BaseTunerLayer + from ...peft.utils import ModulesToSaveWrapper + + for _, module in self.named_modules(): + if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): + # The recent version of PEFT need to call `enable_adapters` instead + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=False) + else: + module.disable_adapters = True + + def enable_adapters(self) -> None: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Enable adapters that are attached to the model. The model will use `self.active_adapter()` + """ + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. 
Please load an adapter first.") + + from ...peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + # The recent version of PEFT need to call `enable_adapters` instead + if hasattr(module, "enable_adapters"): + module.enable_adapters(enabled=True) + else: + module.disable_adapters = False + + def active_adapters(self) -> List[str]: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters + for inference) returns the list of all active adapters so that users can deal with them accordingly. + + For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return + a single string. + """ + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from ...peft.tuners.tuners_utils import BaseTunerLayer + + for _, module in self.named_modules(): + if isinstance(module, BaseTunerLayer): + active_adapters = module.active_adapter + break + + # For previous PEFT versions + if isinstance(active_adapters, str): + active_adapters = [active_adapters] + + return active_adapters + + def active_adapter(self) -> str: + warnings.warn( + "The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning + ) + + return self.active_adapters()[0] + + def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: + """ + If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT + official documentation: https://huggingface.co/docs/peft + + Gets the adapter state dict that should only contain the weights tensors of the specified adapter_name adapter. + If no adapter_name is passed, the active adapter is used. + + Args: + adapter_name (`str`, *optional*): + The name of the adapter to get the state dict from. If no name is passed, the active adapter is used. + """ + if not self._hf_peft_config_loaded: + raise ValueError("No adapter loaded. Please load an adapter first.") + + from ...peft import get_peft_model_state_dict + + if adapter_name is None: + adapter_name = self.active_adapter() + + adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name) + return adapter_state_dict + + +@jit_class +class PretrainedConfig: + """ + Abstract class for Pretrained models config. + """ + is_composition = False + # Add for handle attribute_map + attribute_map: Dict[str, str] = {} + + def __init__(self, **kwargs): + self.ms_dtype = kwargs.pop("ms_dtype", None) + if 'torch_dtype' in kwargs: + self.ms_dtype = kwargs.pop("torch_dtype", None) + self.return_dict = kwargs.pop("return_dict", True) + self.output_hidden_states = kwargs.pop("output_hidden_states", False) + self.output_attentions = kwargs.pop("output_attentions", False) + + self.pruned_heads = kwargs.pop("pruned_heads", {}) + self.tie_word_embeddings = kwargs.pop( + "tie_word_embeddings", True + ) # Whether input and output word embeddings should be tied for all MLM, LM and Seq2Seq models. 
+ + # Is decoder is used in encoder-decoder models to differentiate encoder from decoder + self.is_encoder_decoder = kwargs.pop("is_encoder_decoder", False) + self.is_decoder = kwargs.pop("is_decoder", False) + self.cross_attention_hidden_size = kwargs.pop("cross_attention_hidden_size", None) + self.add_cross_attention = kwargs.pop("add_cross_attention", False) + self.tie_encoder_decoder = kwargs.pop("tie_encoder_decoder", False) + + # Parameters for sequence generation + self.max_length = kwargs.pop("max_length", 20) + self.min_length = kwargs.pop("min_length", 0) + self.do_sample = kwargs.pop("do_sample", False) + self.early_stopping = kwargs.pop("early_stopping", False) + self.num_beams = kwargs.pop("num_beams", 1) + self.num_beam_groups = kwargs.pop("num_beam_groups", 1) + self.diversity_penalty = kwargs.pop("diversity_penalty", 0.0) + self.temperature = kwargs.pop("temperature", 1.0) + self.top_k = kwargs.pop("top_k", 50) + self.top_p = kwargs.pop("top_p", 1.0) + self.typical_p = kwargs.pop("typical_p", 1.0) + self.repetition_penalty = kwargs.pop("repetition_penalty", 1.0) + self.length_penalty = kwargs.pop("length_penalty", 1.0) + self.no_repeat_ngram_size = kwargs.pop("no_repeat_ngram_size", 0) + self.encoder_no_repeat_ngram_size = kwargs.pop("encoder_no_repeat_ngram_size", 0) + self.bad_words_ids = kwargs.pop("bad_words_ids", None) + self.num_return_sequences = kwargs.pop("num_return_sequences", 1) + self.chunk_size_feed_forward = kwargs.pop("chunk_size_feed_forward", 0) + self.output_scores = kwargs.pop("output_scores", False) + self.return_dict_in_generate = kwargs.pop("return_dict_in_generate", False) + self.forced_bos_token_id = kwargs.pop("forced_bos_token_id", None) + self.forced_eos_token_id = kwargs.pop("forced_eos_token_id", None) + self.remove_invalid_values = kwargs.pop("remove_invalid_values", False) + self.exponential_decay_length_penalty = kwargs.pop("exponential_decay_length_penalty", None) + self.suppress_tokens = kwargs.pop("suppress_tokens", None) + self.begin_suppress_tokens = kwargs.pop("begin_suppress_tokens", None) + + # Fine-tuning task arguments + self.architectures = kwargs.pop("architectures", None) + self.finetuning_task = kwargs.pop("finetuning_task", None) + self.id2label = kwargs.pop("id2label", None) + self.label2id = kwargs.pop("label2id", None) + if self.label2id is not None and not isinstance(self.label2id, dict): + raise ValueError("Argument label2id should be a dictionary.") + if self.id2label is not None: + if not isinstance(self.id2label, dict): + raise ValueError("Argument id2label should be a dictionary.") + num_labels = kwargs.pop("num_labels", None) + + if num_labels is not None and len(self.id2label) != num_labels: + logger.warning( + f"You passed along `num_labels={num_labels}` with an incompatible id to label map: " + f"{self.id2label}. The number of labels wil be overwritten to {self.num_labels}." + ) + self.id2label = {int(key): value for key, value in self.id2label.items()} + # Keys are always strings in JSON so convert ids to int here. 
+ else: + self.num_labels = kwargs.pop("num_labels", 2) + + if self.ms_dtype is not None and isinstance(self.ms_dtype, str): + if is_mindspore_available(): + import mindspore + + self.ms_dtype = getattr(mindspore, self.ms_dtype) + + # Tokenizer arguments TODO: eventually tokenizer and models should share the same config + self.tokenizer_class = kwargs.pop("tokenizer_class", None) + self.prefix = kwargs.pop("prefix", None) + self.bos_token_id = kwargs.pop("bos_token_id", None) + self.pad_token_id = kwargs.pop("pad_token_id", None) + self.eos_token_id = kwargs.pop("eos_token_id", None) + self.sep_token_id = kwargs.pop("sep_token_id", None) + + self.decoder_start_token_id = kwargs.pop("decoder_start_token_id", None) + + # task specific arguments + self.task_specific_params = kwargs.pop("task_specific_params", None) + + # regression / multi-label classification + self.problem_type = kwargs.pop("problem_type", None) + allowed_problem_types = ("regression", "single_label_classification", "multi_label_classification") + if self.problem_type is not None and self.problem_type not in allowed_problem_types: + raise ValueError( + f"The config parameter `problem_type` was not understood: received {self.problem_type} " + "but only 'regression', 'single_label_classification' and 'multi_label_classification' are valid." + ) + + # Name or path to the pretrained checkpoint + self._name_or_path = str(kwargs.pop("name_or_path", "")) + + # Additional attributes without default values + for key, value in kwargs.items(): + try: + setattr(self, key, value) + except AttributeError as err: + logger.error(f"Can't set {key} with value {value} for {self}") + raise err + + def __setattr__(self, key, value): + if key in super().__getattribute__("attribute_map"): + key = super().__getattribute__("attribute_map")[key] + super().__setattr__(key, value) + + def __getattribute__(self, key): + if key != "attribute_map" and key in super().__getattribute__("attribute_map"): + key = super().__getattribute__("attribute_map")[key] + return super().__getattribute__(key) + + @property + def name_or_path(self) -> str: + """get name_or_path""" + return getattr(self, "_name_or_path", None) + + @name_or_path.setter + def name_or_path(self, value): + """set name_or_path""" + self._name_or_path = str(value) # Make sure that name_or_path is a string (for JSON encoding) + + @classmethod + def from_json(cls, file_path): + """load config from json.""" + with open(file_path, "r", encoding="utf-8") as file: + text = file.read() + config_map = json.loads(text) + config = cls() + for key, value in config_map.items(): + setattr(config, key, value) + return config + + @classmethod + def from_json_file(cls, json_file): + """Constructs a `Config` from a json file of parameters.""" + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + dict_obj = json.loads(text) + return cls(**dict_obj) + + @classmethod + def load(cls, pretrained_model_name_or_path): + """load config.""" + return cls.from_pretrained(pretrained_model_name_or_path) + + @property + def use_return_dict(self) -> bool: + """ + `bool`: Whether or not return [`~utils.ModelOutput`] instead of tuples. + """ + # If torchscript is set, force `return_dict=False` to avoid jit errors + return self.return_dict + + @classmethod + def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig": + """ + Constructs a `Config` from a Python dictionary of parameters. 
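+
+        Example (a minimal sketch; the dictionary values below are purely illustrative):
+
+            config = PretrainedConfig.from_dict({"num_labels": 3}, output_attentions=True)
+            # config.num_labels == 3 and config.output_attentions is True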
+ + Args: + config_dict (:obj:`Dict[str, any]`): + Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved + from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict` + method. + kwargs (:obj:`Dict[str, any]`): + Additional parameters from which to initialize the configuration object. + + Returns: + :class:`PretrainedConfig`: An instance of a configuration object + """ + return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) + + config = cls(**config_dict) + + if hasattr(config, "pruned_heads"): + config.pruned_heads = dict((int(key), value) for key, value in config.pruned_heads.items()) + + # Update config with kwargs if needed + if "num_labels" in kwargs and "id2label" in kwargs: + num_labels = kwargs["num_labels"] + id2label = kwargs["id2label"] if kwargs["id2label"] is not None else [] + if len(id2label) != num_labels: + raise ValueError( + f"You passed along `num_labels={num_labels }` with an incompatible id to label map: " + f"{kwargs['id2label']}. Since those arguments are inconsistent with each other, you should remove " + "one of them.") + + to_remove = [] + for key, value in kwargs.items(): + if hasattr(config, key): + setattr(config, key, value) + to_remove.append(key) + for key in to_remove: + kwargs.pop(key, None) + + logger.info("Model config %s", str(config)) + if return_unused_kwargs: + return config, kwargs + return config + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + mirror: str = 'huggingface', + **kwargs, + ) -> "PretrainedConfig": + r""" + Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration. + """ + kwargs["cache_dir"] = cache_dir + kwargs["force_download"] = force_download + kwargs["local_files_only"] = local_files_only + kwargs['mirror'] = mirror + + config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs) + if "model_type" in config_dict and hasattr(cls, "model_type") and config_dict["model_type"] != cls.model_type: + logger.warning( + f"You are using a model of type {config_dict['model_type']} to instantiate a model of type " + f"{cls.model_type}. This is not supported for all configurations of models and can yield errors." + ) + + return cls.from_dict(config_dict, **kwargs) + + @classmethod + def get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a + [`PretrainedConfig`] using `from_dict`. + + Parameters: + pretrained_model_name_or_path (`str` or `os.PathLike`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object. + + """ + original_kwargs = copy.deepcopy(kwargs) + # Get config dict associated with the base config file + config_dict, kwargs = cls._get_config_dict(pretrained_model_name_or_path, **kwargs) + + # That config file may point us toward another config file to use. 
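+        # When that happens, resolve the dictionary again from the file named there,
+        # passing the original (unconsumed) kwargs through a second time.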
+ if "configuration_files" in config_dict: + configuration_file = get_configuration_file(config_dict["configuration_files"]) + config_dict, kwargs = cls._get_config_dict( + pretrained_model_name_or_path, _configuration_file=configuration_file, **original_kwargs + ) + + return config_dict, kwargs + + @classmethod + def _get_config_dict( + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + ) -> Tuple[Dict[str, Any], Dict[str, Any]]: + """ + From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used + for instantiating a Config using `from_dict`. + + Parameters: + pretrained_model_name_or_path (:obj:`string`): + The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. + + Returns: + :obj:`Tuple[Dict, Dict]`: The dictionary that will be used to instantiate the configuration object. + + """ + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + subfolder = kwargs.pop("subfolder", "") + token = kwargs.pop('token', None) + revision = kwargs.pop('revision', 'main') + mirror = kwargs.pop('mirror', 'huggingface') + + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + + is_local = os.path.isdir(pretrained_model_name_or_path) + if os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + # Special case when pretrained_model_name_or_path is a local file + resolved_config_file = pretrained_model_name_or_path + is_local = True + + elif is_remote_url(pretrained_model_name_or_path): + configuration_file = pretrained_model_name_or_path + resolved_config_file = download_url(pretrained_model_name_or_path) + + else: + configuration_file = kwargs.pop("_configuration_file", CONFIG_NAME) + + try: + # Load from local folder or from cache or download from model Hub and cache + resolved_config_file = cached_file( + pretrained_model_name_or_path, + configuration_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + revision=revision, + token=token, + subfolder=subfolder, + mirror=mirror + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted to + # the original exception. + raise + except Exception as exc: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the configuration of '{pretrained_model_name_or_path}'. If you were trying to load it" + ", make sure you don't have a local directory with the same" + f" name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory" + f" containing a {configuration_file} file" + ) from exc + + try: + # Load config dict + config_dict = cls._dict_from_json_file(resolved_config_file) + except (json.JSONDecodeError, UnicodeDecodeError) as exc: + raise EnvironmentError( + f"It looks like the config file at '{resolved_config_file}' is not a valid JSON file." 
+ ) from exc + + if is_local: + logger.info(f"loading configuration file {resolved_config_file}") + else: + logger.info(f"loading configuration file {configuration_file} from cache at {resolved_config_file}") + + return config_dict, kwargs + + @classmethod + def _dict_from_json_file(cls, json_file: str): + """_dict_from_json_file""" + with open(json_file, "r", encoding="utf-8") as reader: + text = reader.read() + return json.loads(text) + + def dict_ms_dtype_to_str(self, d: Dict[str, Any]) -> None: + """ + Checks whether the passed dictionary and its nested dicts have a *torch_dtype* key and if it's not None, + converts torch.dtype to a string of just the type. For example, `torch.float32` get converted into *"float32"* + string, which can then be stored in the json format. + """ + if d.get("ms_dtype", None) is not None and not isinstance(d["ms_dtype"], str): + d["ms_dtype"] = str(d["ms_dtype"]).lower() + for value in d.values(): + if isinstance(value, dict): + self.dict_ms_dtype_to_str(value) + + def to_dict(self) -> Dict[str, Any]: + """ + Serializes this instance to a Python dictionary. + + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance. + """ + output = copy.deepcopy(self.__dict__) + if hasattr(self.__class__, "model_type"): + output["model_type"] = self.__class__.model_type + if "_auto_class" in output: + del output["_auto_class"] + if "_commit_hash" in output: + del output["_commit_hash"] + if "_attn_implementation_internal" in output: + del output["_attn_implementation_internal"] + + for key, value in output.items(): + # Deal with nested configs like CLIP + if isinstance(value, PretrainedConfig): + value = value.to_dict() + + output[key] = value + + if hasattr(self, "quantization_config"): + output["quantization_config"] = ( + self.quantization_config.to_dict() + if not isinstance(self.quantization_config, dict) + else self.quantization_config + ) + + # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable. + _ = output.pop("_pre_quantization_dtype", None) + + self.dict_ms_dtype_to_str(output) + + return output + + def to_diff_dict(self) -> Dict[str, Any]: + """ + Removes all attributes from config which correspond to the default config attributes for better readability and + serializes to a Python dictionary. 
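+        Values equal to both the base `PretrainedConfig()` defaults and the subclass defaults are
+        dropped; as an illustration, `PretrainedConfig(num_beams=4).to_diff_dict()` yields
+        `{"num_beams": 4}`.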
+ + Returns: + `Dict[str, Any]`: Dictionary of all the attributes that make up this configuration instance, + """ + config_dict = self.to_dict() + + # get the default config dict + default_config_dict = PretrainedConfig().to_dict() + + # get class specific config dict + class_config_dict = self.__class__().to_dict() if not self.is_composition else {} + + serializable_config_dict = {} + + # only serialize values that differ from the default config + for key, value in config_dict.items(): + if ( + isinstance(getattr(self, key, None), PretrainedConfig) + and key in class_config_dict + and isinstance(class_config_dict[key], dict) + ): + # For nested configs we need to clean the diff recursively + diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None)) + if "model_type" in value: + # Needs to be set even if it's not in the diff + diff["model_type"] = value["model_type"] + if len(diff) > 0: + serializable_config_dict[key] = diff + elif ( + key not in default_config_dict + or value != default_config_dict[key] + or (key in class_config_dict and value != class_config_dict[key]) + ): + serializable_config_dict[key] = value + + return serializable_config_dict + + def to_json_string(self, use_diff: bool = True) -> str: + """ + Serializes this instance to a JSON string. + """ + if use_diff is True: + config_dict = self.to_diff_dict() + else: + config_dict = self.to_dict() + + return json.dumps(config_dict, indent=2, sort_keys=True) + "\n" + + def to_file(self, save_path): + """Serializes this instance to a JSON file.""" + output_dict = self.to_dict() + with open(os.path.join(save_path, 'config.json'), encoding='utf-8') as f: + json.dump(output_dict, f, sort_keys=True, indent=2) + + def update(self, config_dict: Dict[str, Any]): + """ + Updates attributes of this class with attributes from `config_dict`. + """ + for key, value in config_dict.items(): + setattr(self, key, value) + + def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): + """ + Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the + [`~PretrainedConfig.from_pretrained`] class method. + """ + + if os.path.isfile(save_directory): + raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file") + + os.makedirs(save_directory, exist_ok=True) + + # If we save using the predefined names, we can load using `from_pretrained` + output_config_file = os.path.join(save_directory, CONFIG_NAME) + + self.to_json_file(output_config_file, use_diff=True) + logger.info(f"Configuration saved in {output_config_file}") + + def to_json_file(self, json_file_path: Union[str, os.PathLike], use_diff: bool = True): + """ + Save this instance to a JSON file. + """ + with open(json_file_path, "w", encoding="utf-8") as writer: + writer.write(self.to_json_string(use_diff=use_diff)) + + @property + def num_labels(self) -> int: + """ + `int`: The number of labels for classification models. + """ + return len(self.id2label) + + @num_labels.setter + def num_labels(self, num_labels: int): + if not hasattr(self, "id2label") or self.id2label is None or len(self.id2label) != num_labels: + self.id2label = {i: f"LABEL_{i}" for i in range(num_labels)} + self.label2id = dict(zip(self.id2label.values(), self.id2label.keys())) + + @property + def _attn_implementation(self): + # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) 
+ if hasattr(self, "_attn_implementation_internal"): + if self._attn_implementation_internal is None: + # `config.attn_implementation` should never be None, for backward compatibility. + return "eager" + else: + return self._attn_implementation_internal + else: + return "eager" + + @_attn_implementation.setter + def _attn_implementation(self, value): + self._attn_implementation_internal = value + + + +class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapterMixin): + """ + Abstract class for Pretrained models + """ + config_class = None + base_model_prefix = "" + main_input_name = "input_ids" + + # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing + # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings. + _keys_to_ignore_on_load_missing = None + # a list of `re` patterns of `state_dict` keys that should be removed from the list of + # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary + # warnings. + _keys_to_ignore_on_load_unexpected = None + _keys_to_ignore_on_save = None + + _tied_weights_keys = None + + _keep_in_fp32_modules = None + + supports_recompute = False + + def __init__(self, config): + super().__init__(config) + self._check_and_unset_acl() + # Save config in model + self.config = config + self.name_or_path = config.name_or_path + self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None + + def _check_and_unset_acl(self): + if "MS" in str(self.__class__.__name__) and \ + 'MS_DEV_FORCE_ACL' in os.environ: + del os.environ['MS_DEV_FORCE_ACL'] + + def post_init(self): + """ + A method executed at the end of each Transformer model initialization, to execute code that needs the model's + modules properly initialized (such as weight initialization). + + """ + self.init_weights() + + @classmethod + def _from_config(cls, config, **kwargs): + """ + All context managers that the model should be initialized under go here. + + Args: + torch_dtype (`torch.dtype`, *optional*): + Override the default `torch.dtype` and load the model under this dtype. + """ + model = cls(config, **kwargs) + + return model + + def init_weights(self): + """ + If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any + initialization logic in `_init_weights`. + """ + # Prune heads if needed + if self.config.pruned_heads: + self.prune_heads(self.config.pruned_heads) + + if _init_weights: + # Initialize weights + if getattr(self, 'apply', None): + self.apply(self._initialize_weights) + else: + for _, cell in self.name_cells().items(): + self._initialize_weights(cell) + + # Tie weights should be skipped when not initializing all weights + # since from_pretrained(...) calls tie weights anyways + self.tie_weights() + + def prune_heads(self, heads_to_prune: Dict[int, List[int]]): + """ + Prunes heads of the base model. + + Arguments: + heads_to_prune (`Dict[int, List[int]]`): + Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads + to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on + layer 1 and heads 2 and 3 on layer 2. 
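+
+        Example (restating the mapping above as a call; `model` stands for any loaded model
+        instance exposing this method):
+
+            model.prune_heads({1: [0, 2], 2: [2, 3]})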
+ """ + # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads + for layer, heads in heads_to_prune.items(): + union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) + self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON + + self.base_model._prune_heads(heads_to_prune) + + def _init_weights(self, cell): + """ + Initialize the weights. This method should be overridden by derived class and is + the only initialization method that will be called when loading a checkpoint + using `from_pretrained`. Any attempt to initialize outside of this function + will be useless as the torch.nn.init function are all replaced with skip. + """ + + def _initialize_weights(self, module): + """ + Initialize the weights if they are not already initialized. + """ + if getattr(module, "_is_initialized", False): + return + self._init_weights(module) + module._is_initialized = True + + @property + def base_model(self): + """ + to get base_model + """ + return getattr(self, self.base_model_prefix, self) + + def get_input_embeddings(self) -> "nn.Cell": + """ + Returns the model's input embeddings. + + Returns: + :obj:`nn.Cell`: A mindspore cell mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.get_input_embeddings() + raise NotImplementedError + + def set_input_embeddings(self, new_embeddings: nn.Cell): + """ + Set model's input embeddings. + + Args: + value (:obj:`nn.Cell`): A mindspore cell mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.set_input_embeddings(new_embeddings) + raise NotImplementedError + + def resize_position_embeddings(self, new_num_position_embeddings: int): + """ + resize the model position embeddings if necessary + """ + raise NotImplementedError( + f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__}" + ) + + def get_output_embeddings(self): + """ Get model's output embeddings + Return None if the model doesn't have output embeddings + """ + return None # Overwrite for models with output embeddings + + def set_output_embeddings(self, new_embeddings: nn.Cell): + """ + Set model's output embeddings. + + Args: + value (:obj:`nn.Cell`): A mindspore cell mapping vocabulary to hidden states. + """ + base_model = getattr(self, self.base_model_prefix, self) + if base_model is not self: + return base_model.set_output_embeddings(new_embeddings) + raise NotImplementedError + + def get_position_embeddings(self): + """ + get the model position embeddings if necessary + """ + raise NotImplementedError( + f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " + f"overwrite this method in the class {self.__class__}" + ) + + def tie_weights(self): + """ + Make sure we are sharing the input and output embeddings. 
+ If you need this feature, + you need to get it yourself output Add the output you need to add to the embeddings function_ Embedding layer, + otherwise you cannot + """ + if getattr(self.config, "tie_word_embeddings", True): + output_embeddings = self.get_output_embeddings() # pylint: disable=assignment-from-none + if output_embeddings is not None: + self._tie_or_clone_weights( + output_embeddings, self.get_input_embeddings()) + + if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): + if hasattr(self, self.base_model_prefix): + self = getattr(self, self.base_model_prefix) # pylint: disable=self-cls-assignment + self._tie_encoder_decoder_weights( + self.encoder, self.decoder, self.base_model_prefix) + + for _, cell in self.cells_and_names(): + if hasattr(cell, "_tie_weights"): + cell._tie_weights() + + @staticmethod + def _tie_encoder_decoder_weights(encoder: nn.Cell, decoder: nn.Cell, base_model_prefix: str): + """tie encoder decoder weights""" + uninitialized_encoder_weights: List[str] = [] + if decoder.__class__ != encoder.__class__: + logger.info( + f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder" + " weights are correctly initialized." + ) + + def tie_encoder_to_decoder_recursively( + decoder_pointer: nn.Cell, + encoder_pointer: nn.Cell, + module_name: str, + uninitialized_encoder_weights: List[str], + depth=0, + ): + assert isinstance(decoder_pointer, nn.Cell) and isinstance( + encoder_pointer, nn.Cell + ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" + if hasattr(decoder_pointer, "weight"): + assert hasattr(encoder_pointer, "weight") + encoder_pointer.weight = decoder_pointer.weight + encoder_pointer._params['weight'] = decoder_pointer.weight + if hasattr(decoder_pointer, "bias"): + assert hasattr(encoder_pointer, "bias") + encoder_pointer.bias = decoder_pointer.bias + encoder_pointer._params['bias'] = decoder_pointer.bias + return + + encoder_cells = encoder_pointer._cells + decoder_cells = decoder_pointer._cells + if len(decoder_cells) > 0: + assert ( + len(encoder_cells) > 0 + ), f"Encoder cell {encoder_pointer} does not match decoder cell {decoder_pointer}" + + all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_cells.keys()} + encoder_layer_pos = 0 + for name, _ in decoder_cells.items(): + if name.isdigit(): + encoder_name = str(int(name) + encoder_layer_pos) + decoder_name = name + if not isinstance(decoder_cells[decoder_name], type(encoder_cells[encoder_name])) and len( + encoder_cells + ) != len(decoder_cells): + # this can happen if the name corresponds to the position in a list module list of layers + # in this case the decoder has added a cross-attention that the encoder does not have + # thus skip this step and subtract one layer pos from encoder + encoder_layer_pos -= 1 + continue + elif name not in encoder_cells: + continue + elif depth > 500: + raise ValueError( + "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is" + " a circular dependency between two or more `nn.Cell` of your model." 
+ ) + else: + decoder_name = encoder_name = name + tie_encoder_to_decoder_recursively( + decoder_cells[decoder_name], + encoder_cells[encoder_name], + module_name + "/" + name, + uninitialized_encoder_weights, + depth=depth + 1, + ) + all_encoder_weights.remove(module_name + "/" + encoder_name) + + uninitialized_encoder_weights += list(all_encoder_weights) + + # tie weights recursively + tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) + if len(uninitialized_encoder_weights) > 0: + logger.warning( + f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" + ) + + def _tie_or_clone_weights(self, output_embeddings, input_embeddings): + """ Tie or clone module weights depending of weither we are using or not + """ + if hasattr(output_embeddings, 'weight'): + output_embeddings.weight = input_embeddings.weight + output_embeddings._params['weight'] = input_embeddings.weight + + if getattr(output_embeddings, "bias", None) is not None: + if output_embeddings.weight.shape[0] == output_embeddings.bias.shape[0]: + pass + else: + # instantial a new Parameter since mindspore.Parameter do not support assign_value with different shape + replace_references(output_embeddings.bias, Parameter(ops.pad( + output_embeddings.bias.data, + (0, output_embeddings.weight.shape[0] - + output_embeddings.bias.shape[0]), + "constant", + 0, + ), name=output_embeddings.bias.name, requires_grad=output_embeddings.bias.requires_grad)) + + if hasattr(output_embeddings, "out_channels") and hasattr(input_embeddings, "vocab_size"): + output_embeddings.out_channels = input_embeddings.vocab_size + + def resize_token_embeddings( + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + ) -> nn.Embedding: + """ + Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. + + Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. + + Arguments: + new_num_tokens (`int`, *optional*): + The number of new tokens in the embedding matrix. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just + returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. + + Return: + `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. 
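+
+        Example (a minimal sketch; `model` and `tokenizer` are assumed to be an already loaded,
+        matching model/tokenizer pair, and the added token is purely illustrative):
+
+            tokenizer.add_tokens(["<my_new_token>"])
+            model.resize_token_embeddings(len(tokenizer))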
+ """ + model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) + if new_num_tokens is None and pad_to_multiple_of is None: + return model_embeds + # Update base model and current model config + self.config.vocab_size = model_embeds.weight.shape[0] + self.vocab_size = model_embeds.weight.shape[0] + # Tie weights again if needed + self.tie_weights() + + return model_embeds + + def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) + + self.set_input_embeddings(new_embeddings) + self.get_input_embeddings().weight.data_sync(True) + # Update new_num_tokens with the actual size of new_embeddings + if pad_to_multiple_of is not None: + new_num_tokens = new_embeddings.weight.shape[0] + # if word embeddings are not tied, make sure that lm head is resized as well + if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: + old_lm_head = self.get_output_embeddings() # pylint: disable=assignment-from-none + new_lm_head = self._get_resized_lm_head( + old_lm_head, new_num_tokens) + self.set_output_embeddings(new_lm_head) + self.get_output_embeddings().weight.data_sync(True) + + + return self.get_input_embeddings() + + def resize_tokenizer_embeddings(self, new_num_tokens): + """ + Obtain a new embedding layer or use the original one without updating it. + """ + old_embeddings = self.get_input_embeddings() + new_embeddings = self._get_resized_embeddings( + old_embeddings, new_num_tokens) + self.set_input_embeddings(new_embeddings) + return self.get_input_embeddings() + + def _get_resized_embeddings( + self, + old_embeddings: nn.Embedding, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, + ) -> nn.Embedding: + """ Build a resized Embedding Module from a provided token Embedding Module. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + + Args: + new_num_tokens: (`optional`) int + New number of tokens in the embedding matrix. + Increasing the size will add newly initialized vectors at the end + Reducing the size will remove vectors from the end + If not provided or None: return the provided token Embedding Module. + Return: ``mindspore.nn.Embeddings`` + Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None + """ + if pad_to_multiple_of is not None: + if not isinstance(pad_to_multiple_of, int): + raise ValueError( + f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer" + ) + if new_num_tokens is None: + new_num_tokens = old_embeddings.weight.shape[0] + new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of + else: + logger.info( + "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding" + f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." 
+ " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" + " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" + ) + + if new_num_tokens is None: + return old_embeddings + + old_num_tokens, old_embedding_dim = old_embeddings.weight.shape + if old_num_tokens == new_num_tokens: + return old_embeddings + + # Build new embeddings + new_embeddings = nn.Embedding(new_num_tokens, old_embedding_dim) + + # initialize all new embeddings (in particular added tokens) + self._init_weights(new_embeddings) + + # Copy word embeddings from the previous weights + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[ + :num_tokens_to_copy, :] + return new_embeddings + + def _get_resized_lm_head( + self, old_lm_head: nn.Dense, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False + ) -> nn.Dense: + """ + Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized + vectors at the end. Reducing the size will remove vectors from the end + + Args: + old_lm_head (`nn.Dense`): + Old lm head liner layer to be resized. + new_num_tokens (`int`, *optional*): + New number of tokens in the linear matrix. + + Increasing the size will add newly initialized vectors at the end. Reducing the size will remove + vectors from the end. If not provided or `None`, just returns a pointer to the input tokens + `nn.Dense` module of the model without doing anything. transposed (`bool`, *optional*, defaults + to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, + vocab_size` else `vocab_size, lm_head_dim`. + + Return: + `nn.Dense`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is + `None` + """ + if new_num_tokens is None: + return old_lm_head + + old_num_tokens, old_lm_head_dim = ( + old_lm_head.weight.shape if not transposed else old_lm_head.weight.T.shape + ) + + if old_num_tokens == new_num_tokens: + return old_lm_head + + if not isinstance(old_lm_head, nn.Dense): + raise TypeError( + f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Dense}. You" + " should either use a different resize function or make sure that `old_lm_head` are an instance of" + f" {nn.Dense}." + ) + + # Build new lm head + new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim) + has_new_lm_head_bias = old_lm_head.bias is not None + + # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init + # because the shape of the new embedding layer is used across various modeling files + # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading + # to errors when training. 
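+
+        # Build the replacement head at the new size; a bias is created only if the old head
+        # had one, and the old weights (and bias) are copied over below.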
+ new_lm_head = nn.Dense( + *new_lm_head_shape, + has_bias=has_new_lm_head_bias, + ) + + # initialize new lm head (in particular added tokens) + self._init_weights(new_lm_head) + + num_tokens_to_copy = min(old_num_tokens, new_num_tokens) + + self._copy_lm_head_original_to_resized( + new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ) + + return new_lm_head + + def _copy_lm_head_original_to_resized( + self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + ): + # Copy old lm head weights to new lm head + if not transposed: + new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] + else: + new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] + + # Copy bias weights to new lm head + if has_new_lm_head_bias: + new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] + + + @classmethod + def load(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, **kwargs): + """ + Load a pre-trained checkpoint from a pre-trained model file or url, + download and cache the pre-trained model file if model name in model list. + + Params: + pretrained_model_name_or_path: + """ + return cls.from_pretrained(pretrained_model_name_or_path, args, kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *model_args, + config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, + cache_dir: Optional[Union[str, os.PathLike]] = None, + ignore_mismatched_sizes: bool = False, + force_download: bool = False, + local_files_only: bool = False, + token: Optional[Union[str, bool]] = None, + use_safetensors: bool = None, + mirror: str = 'huggingface', + **kwargs, + ): + """from_pretrained""" + state_dict = kwargs.pop("state_dict", None) + cache_dir = kwargs.pop("cache_dir", None) + _ = kwargs.pop("from_pt", True) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", False) + _fast_init = kwargs.pop("_fast_init", True) + output_loading_info = kwargs.pop("output_loading_info", False) + subfolder = kwargs.pop("subfolder", "") + variant = kwargs.pop("variant", None) + ms_dtype = kwargs.pop("ms_dtype", None) + _ = kwargs.pop('low_cpu_mem_usage', None) + revision = kwargs.pop('revision', 'main') + + if use_safetensors is None and not is_safetensors_available(): + use_safetensors = False + + is_sharded = False + # Load config if we don't provide a configuration + if not isinstance(config, PretrainedConfig): + config_path = config if config is not None else pretrained_model_name_or_path + config, model_kwargs = cls.config_class.from_pretrained( + config_path, + *model_args, + cache_dir=cache_dir, + return_unused_kwargs=True, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + mirror=mirror, + **kwargs, + ) + else: + model_kwargs = kwargs + + # Load model + if pretrained_model_name_or_path is not None: + pretrained_model_name_or_path = str(pretrained_model_name_or_path) + is_local = os.path.isdir(pretrained_model_name_or_path) + if is_local: + if os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, PT_WEIGHTS_NAME) + ): + # Load from a PyTorch checkpoint + archive_file = os.path.join(pretrained_model_name_or_path, 
subfolder, PT_WEIGHTS_NAME) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) + ): + # Load from a MindSpore checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) + ) + elif use_safetensors is not False and os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) + ): + # Load from a safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) + ) + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(PT_WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded PyTorch checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(PT_WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + elif os.path.isfile( + os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) + ): + # Load from a sharded MindSpore checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + elif use_safetensors is not False and os.path.isfile( + os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + ): + # Load from a sharded safetensors checkpoint + archive_file = os.path.join( + pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) + ) + is_sharded = True + # At this stage we don't have a weight file so we will raise an error. + elif use_safetensors: + raise EnvironmentError( + f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" + f" {pretrained_model_name_or_path}." + ) + else: + raise EnvironmentError( + f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {PT_WEIGHTS_NAME}," + f" found in directory {pretrained_model_name_or_path}." + ) + elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): + archive_file = pretrained_model_name_or_path + is_local = True + elif is_remote_url(pretrained_model_name_or_path): + filename = pretrained_model_name_or_path + resolved_archive_file = download_url(pretrained_model_name_or_path) + else: + if use_safetensors is not False: + filename = _add_variant(SAFE_WEIGHTS_NAME, variant) + else: + filename = _add_variant(WEIGHTS_NAME, variant) + + try: + # Load from URL or cache if already cached + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "proxies": proxies, + "resume_download": resume_download, + "local_files_only": local_files_only, + "subfolder": subfolder, + "_raise_exceptions_for_missing_entries": False, + 'revision': revision, + "token": token, + 'mirror': mirror + } + # try safetensors + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + use_safetensors = resolved_archive_file is not None + + # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None + # result when internet is up, the repo and revision exist, but the file does not. + if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. 
+ resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + use_safetensors = True + + if resolved_archive_file is None: + filename = _add_variant(WEIGHTS_NAME, variant) + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is None: + filename = _add_variant(PT_WEIGHTS_NAME, variant) + resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) + + if resolved_archive_file is None and filename == _add_variant(PT_WEIGHTS_NAME, variant): + # Maybe the checkpoint is sharded, we try to grab the index name in this case. + resolved_archive_file = cached_file( + pretrained_model_name_or_path, + _add_variant(PT_WEIGHTS_INDEX_NAME, variant), + **cached_file_kwargs, + ) + if resolved_archive_file is not None: + is_sharded = True + + if resolved_archive_file is None: + raise EnvironmentError( + f"{pretrained_model_name_or_path} does not appear to have a file named" + f" {_add_variant(SAFE_WEIGHTS_NAME, variant)}, {_add_variant(PT_WEIGHTS_NAME, variant)}" + ) + except EnvironmentError: + # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted + # to the original exception. + raise + except Exception as exc: + # For any other exception, we throw a generic error. + raise EnvironmentError( + f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" + ", make sure you don't have a local directory with the" + f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" + f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," + f" {_add_variant(PT_WEIGHTS_NAME, variant)}." + ) from exc + + if is_local: + logger.info(f"loading weights file {archive_file}") + resolved_archive_file = archive_file + else: + logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") + else: + resolved_archive_file = None + + if is_sharded: + # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. + resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( + pretrained_model_name_or_path, + resolved_archive_file, + cache_dir=cache_dir, + force_download=force_download, + proxies=proxies, + resume_download=resume_download, + local_files_only=local_files_only, + token=token, + subfolder=subfolder, + revision=revision, + mirror=mirror, + ) + + if pretrained_model_name_or_path is None and state_dict is None: + raise ValueError("the argument 'pretrained_model_name_or_path' should be " + "a string of model name or checkpoint path, but got 'None'.") + + config.name_or_path = pretrained_model_name_or_path + # Instantiate model. 
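+        # The lines below resolve the target MindSpore dtype before the model is built:
+        # per-sub-config `ms_dtype` entries are gathered into `dtype_group`, `ms_dtype` falls
+        # back to the config value and finally to float32, and bfloat16 checkpoints are
+        # downcast to float16 (bfloat16 is not supported on this loading path).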
+ + config_dict = config.to_dict() + + dtype_group = {key: getattr(config, key).ms_dtype for key in config_dict.keys() \ + if isinstance(config_dict[key], dict) and 'ms_dtype' in config_dict[key]} + + if ms_dtype is None or ms_dtype == 'auto': + ms_dtype = config.ms_dtype + + if ms_dtype is None: + ms_dtype = mindspore.float32 + + use_fp16 = False + usage_dtype = mindspore.dtype_to_nptype(ms_dtype) + if ms_dtype == mindspore.bfloat16: + ms_dtype = mindspore.float16 + usage_dtype = np.float16 + use_fp16 = True + + def empty_initializer(init, shape=None, dtype=mindspore.float32): + if not isinstance(shape, (tuple, list)): + shape = (shape,) + if dtype in (mindspore.float16, mindspore.float32) \ + and ms_dtype is not None: + dtype = ms_dtype + return Tensor_(shape=shape, dtype=dtype) + + with no_init_weights(empty_initializer, _fast_init): + model = cls(config, *model_args, **model_kwargs) + + if ms_dtype != mindspore.float32: + set_global_fp16(False) + + if is_sharded: + converted_filenames = resolved_archive_file + + # tie the model weights before retrieving the state_dict + model.tie_weights() + + ptrs = collections.defaultdict(list) + for name, tensor in model.parameters_dict().items(): + id_tensor = id(tensor) + ptrs[id_tensor].append(name) + + # These are all the pointers of shared tensors. + tied_params = [names for _, names in ptrs.items() if len(names) > 1] + def load_ckpt(resolved_archive_file): + if not resolved_archive_file.endswith('ckpt'): + if use_safetensors or 'safetensors' in resolved_archive_file: + from safetensors.numpy import load_file + origin_state_dict = load_file(resolved_archive_file) + if use_fp16: + logger.warning_once("MindSpore do not support bfloat16 dtype, we will automaticlly convert to float16") + state_dict = {k: Parameter(Tensor.from_numpy(v.astype(usage_dtype))) for k, v in origin_state_dict.items()} + else: + state_dict = load(resolved_archive_file) + else: + try: + state_dict = load_checkpoint(str(resolved_archive_file)) + except Exception as exc: + raise OSError( + f"Unable to load weights from mindspore checkpoint file '{resolved_archive_file}'. 
" + ) from exc + + state_keys = list(state_dict.keys()) + for key in state_keys: + new_key = key.replace('gamma', 'weight').replace('beta', 'bias').replace('embedding_table', 'weight') + if new_key != key: + state_dict[new_key] = state_dict.pop(key) + return state_dict + + keys_missing = list(model.parameters_dict().keys()) + param_id_set = set() + + use_keep_in_fp32_modules = False + if model._keep_in_fp32_modules: + use_keep_in_fp32_modules = True + + remove_prefix_from_model = None + add_prefix_to_model = None + + def fix_weight_norm_missing_keys(state_dict_keys: dict, keys_missing:List[str]) -> List[str]: + ''' if both `weight_g` and `weight_v` are loaded, key `weight` is not missing :) ''' + non_missing_keys = [] + for key in keys_missing: + if f'{key}_g' in state_dict_keys and f'{key}_v' in state_dict_keys: + non_missing_keys.append(key) + return non_missing_keys + + def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str, dtype_group: dict = None): + state_dict_keys = list(param_dict.keys()) + keep_in_fp32_modules = model._keep_in_fp32_modules + keys_unexpected = list(param_dict.keys()) + + has_prefix_module = any(s.startswith(prefix) for s in keys_unexpected) + expects_prefix_module = any(s.startswith(prefix) for s in keys_missing) + + nonlocal remove_prefix_from_model + nonlocal add_prefix_to_model + remove_prefix_from_model = not has_prefix_module and expects_prefix_module + add_prefix_to_model = has_prefix_module and not expects_prefix_module + + for pname_in_net, param in model.parameters_and_names(): + if add_prefix_to_model: + param_name = prefix + '.' + pname_in_net + elif remove_prefix_from_model: + param_name = pname_in_net.replace(f'{prefix}.', '') + else: + param_name = pname_in_net + + if param.uuid in param_id_set: + # for tied params + if param_name in keys_unexpected: + keys_unexpected.remove(param_name) + continue + + new_param = param_dict.pop(param_name, None) + + module_dtype = None + for m_name, m_dtype in dtype_group.items(): + if m_name in param_name: + module_dtype = m_dtype + break + + if new_param is not None: + use_replace = False + if new_param.shape != param.shape: + if not ignore_mismatched_sizes: + raise RuntimeError(f'The shape of parameter `{param.name} is {param.shape}, but got mismatch parameter' + f' `{param_name} with shape {new_param.shape} in checkpoint, ' + f'\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.') + logger.warning(f'The shape of parameter `{param.name} is {param.shape}, but got mismatch parameter' + f' `{param_name} with shape {new_param.shape} in checkpoint, ') + continue + + if use_keep_in_fp32_modules and \ + any(module_to_keep_in_fp32 in pname_in_net.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): + new_param = new_param.astype(mindspore.float32) + elif module_dtype and param.dtype in (mindspore.float32, mindspore.float16): + new_param = new_param.astype(module_dtype) + elif ms_dtype and param.dtype in (mindspore.float32, mindspore.float16): + new_param = new_param.astype(ms_dtype) + + if new_param.dtype != param.dtype or new_param.shape != param.shape: + use_replace = True + + if use_replace: + if isinstance(new_param, Parameter): + new_param.name = param.name + new_param.requires_grad = param.requires_grad + replace_references(param, new_param) + else: + replace_references(param, Parameter(new_param, requires_grad=param.requires_grad, name=param.name)) + else: + param.set_data(new_param) + keys_unexpected.remove(param_name) + 
keys_missing.remove(pname_in_net) + param_id_set.add(param.uuid) + else: + # fix missing value parameter dtype cast. + if ms_dtype and ms_dtype != param.dtype: + new_param = param.astype(ms_dtype) + replace_references(param, Parameter(new_param, name=param.name, requires_grad=param.requires_grad)) + + # NOTE: monkey patching weight_norm + for key in fix_weight_norm_missing_keys(state_dict_keys, keys_missing): + keys_missing.remove(key) + + return keys_unexpected, keys_missing + + all_keys_unexpected = None + if state_dict is None: + if is_sharded: + all_keys_unexpected = [] + for name in tqdm(converted_filenames, desc="Loading checkpoint shards"): + state_dict = load_ckpt(name) + keys_unexpected, keys_missing = load_param_into_net(model, state_dict, cls.base_model_prefix, dtype_group) + all_keys_unexpected.extend(keys_unexpected) + del state_dict + gc.collect() + loaded_keys = sharded_metadata["all_checkpoint_keys"] + else: + state_dict = load_ckpt(resolved_archive_file) + loaded_keys = list(state_dict.keys()) + all_keys_unexpected, keys_missing = load_param_into_net(model, state_dict, cls.base_model_prefix, dtype_group) + else: + loaded_keys = list(state_dict.keys()) + all_keys_unexpected, keys_missing = load_param_into_net(model, state_dict, cls.base_model_prefix, dtype_group) + + loaded_add_keys = [] + for group in tied_params: + missing_in_group = [k for k in keys_missing if k in group] + if len(missing_in_group) > 0 and len(missing_in_group) < len(group): + loaded_add_keys.extend([k for k in keys_missing if k in missing_in_group]) + keys_missing = [k for k in keys_missing if k not in missing_in_group] + if cls._keys_to_ignore_on_load_missing is not None: + for pat in cls._keys_to_ignore_on_load_missing: + keys_missing = [k for k in keys_missing if re.search(pat, k) is None] + + if cls._keys_to_ignore_on_load_unexpected is not None: + for pat in cls._keys_to_ignore_on_load_unexpected: + all_keys_unexpected = [k for k in all_keys_unexpected if re.search(pat, k) is None] + + # make sure token embedding weights are still tied if needed + model.tie_weights() + + # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. + if _fast_init: + if not ignore_mismatched_sizes: + if remove_prefix_from_model: + _loaded_keys = [f"{cls.base_model_prefix}.{k}" for k in loaded_keys] + elif add_prefix_to_model: + _loaded_keys = [k[len(cls.base_model_prefix) + 1 :] for k in loaded_keys] + else: + _loaded_keys = loaded_keys + + _loaded_keys += loaded_add_keys + _ = set_initialized_submodules(model, _loaded_keys) + else: + _ = dict(model.cells_and_names()) + + model.apply(model._initialize_weights) + + # Set model in evaluation mode to deactivate DropOut modules by default + model.set_train(False) + + # If it is a model with generation capabilities, attempt to load the generation config + if model.can_generate() and pretrained_model_name_or_path is not None: + try: + model.generation_config = GenerationConfig.from_pretrained( + pretrained_model_name_or_path, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + subfolder=subfolder, + revision=revision, + **kwargs, + ) + except OSError: + logger.info( + "Generation config file not found, using a generation config created from the model config." 
+ ) + + if output_loading_info: + loading_info = { + "missing_keys": keys_missing, + "unexpected_keys": all_keys_unexpected, + } + return model, loading_info + + if all_keys_unexpected: + logger.warning(f'The following parameters in checkpoint files are not loaded:\n' + f'{all_keys_unexpected}') + if keys_missing: + logger.warning(f'The following parameters in models are missing parameter:\n' + f'{keys_missing}') + return model + + def save(self, save_dir): + """ Save a model and its configuration file to a directory, so that + it can be re-loaded using the `:func:`PreTrainedModel.from_pretrained`` class method. + + Arguments: + save_dir: directory to which to save. + """ + if os.path.isfile(save_dir): + logger.error(f"Provided path ({save_dir}) should be a directory, not a file") + return + os.makedirs(save_dir, exist_ok=True) + + # Only save the model itself if we are using distributed training + model_to_save = self.cell if hasattr(self, "cell") else self + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # If we save using the predefined names, we can load using `from_pretrained` + output_model_file = os.path.join(save_dir, WEIGHTS_NAME) + save_checkpoint(model_to_save, output_model_file) + + logger.info(f"Model weights saved in {output_model_file}") + + @classmethod + def can_generate(cls) -> bool: + """ + Returns whether this model can generate sequences with `.generate()`. + + Returns: + `bool`: Whether this model can generate sequences with `.generate()`. + """ + # Detects whether `prepare_inputs_for_generation` has been overwritten, which is a requirement for generation. + # Alternativelly, the model can also have a custom `generate` function. + if "GenerationMixin" in str(cls.prepare_inputs_for_generation) and "GenerationMixin" in str(cls.generate): + return False + return True + + def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask): + """ + Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given. + """ + if (attention_mask is not None) or (self.config.pad_token_id is None): + return + + # Check only the first and last input IDs to reduce overhead. + if self.config.pad_token_id in input_ids[:, [-1, 0]]: + warn_string = ( + "We strongly recommend passing in an `attention_mask` since your input_ids may be padded." + ) + + # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an + # attention_mask or not. In this case, we should still show a warning because this is a rare case. + if ( + (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) + or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) + or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) + ): + warn_string += ( + f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical " + f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), " + f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded." 
+ ) + + logger.warning(warn_string) + + def parameters_and_names(self, name_prefix='', expand=True): + """ + fix ignore tied weights + """ + cells = [] + if expand: + cells = self.cells_and_names(name_prefix=name_prefix) + else: + cells.append((name_prefix, self)) + + for cell_name, cell in cells: + params = cell._params.items() + for par_name, par in params: + if par is not None and par.inited_param is not None: + par = par.inited_param + if par is not None: + par_new_name = par_name + if cell_name: + par_new_name = cell_name + '.' + par_new_name + + yield par_new_name, par + + def num_parameters(self, only_trainable=False): + """return parameters count""" + total = 0 + param_set = set() + for param in self.get_parameters(): + param_id = param.uuid + if param_id not in param_set and (only_trainable or param.requires_grad): + total += param.size + param_set.add(param_id) + return total + + def trainable_params(self, recurse=True): + """ + fix duplicated weights + """ + return list(set(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))) + + def check_names_and_refresh_name(self): + """ + fix ignore tied weights + """ + if not hasattr(self, "_params"): + return + all_name = dict(self.parameters_and_names()).keys() + + if len(set(all_name)) < len(all_name): + self.update_parameters_name() + self.check_names() + + def save_pretrained( + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = mindspore.save_checkpoint, + max_shard_size: Union[int, str] = "5GB", + safe_serialization: bool = True, + variant: Optional[str] = None, + **kwargs, + ): + """ + Save a model and its configuration file to a directory, so that it can be re-loaded using the + [`~PreTrainedModel.from_pretrained`] class method. + + Arguments: + save_directory (`str` or `os.PathLike`): + Directory to which to save. Will be created if it doesn't exist. + is_main_process (`bool`, *optional*, defaults to `True`): + Whether the process calling this is the main process or not. Useful when in distributed training like + TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on + the main process to avoid race conditions. + state_dict (nested dictionary of `torch.Tensor`): + The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only + save parts of the model or if special precautions need to be taken when recovering the state dictionary + of a model (like when using model parallelism). + save_function (`Callable`): + The function to use to save the state dictionary. Useful on distributed training like TPUs when one + need to replace `torch.save` by another method. + max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`): + The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size + lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). + We default it to 5GB in order for models to be able to run easily on free-tier google colab instances + without CPU OOM issues. + + + + If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard + which will be bigger than `max_shard_size`. + + + variant (`str`, *optional*): + If specified, weights are saved in the format pytorch_model..bin. 
+ save_peft_format (`bool`, *optional*, defaults to `True`): + For backward compatibility with PEFT library, in case adapter weights are attached to the model, all + keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can + disable this behaviours by setting `save_peft_format` to `False`. + kwargs (`Dict[str, Any]`, *optional*): + Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. + """ + + if os.path.isfile(save_directory): + logger.error(f"Provided path ({save_directory}) should be a directory, not a file") + return + + os.makedirs(save_directory, exist_ok=True) + + + # Only save the model itself if we are using distributed training + model_to_save = self + + # save the string version of dtype to the config, e.g. convert torch.float32 => "float32" + # we currently don't use this setting automatically, but may start to use with v5 + dtype = get_parameter_dtype(model_to_save) + model_to_save.config.ms_dtype = str(dtype).lower() + + # Attach architecture to the config + model_to_save.config.architectures = [model_to_save.__class__.__name__] + + # Save the config + if is_main_process: + model_to_save.config.save_pretrained(save_directory) + if self.can_generate(): + model_to_save.generation_config.save_pretrained(save_directory) + + + # Save the model + if state_dict is None: + state_dict = model_to_save.parameters_dict() + + # Handle the case where some state_dict keys shouldn't be saved + if self._keys_to_ignore_on_save is not None: + for ignore_key in self._keys_to_ignore_on_save: + if ignore_key in state_dict.keys(): + del state_dict[ignore_key] + + # Shard the model if it is too big. + # weights_name = _add_variant(WEIGHTS_NAME, variant) + weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME + + shards, index = shard_checkpoint(state_dict, max_shard_size=max_shard_size, weights_name=weights_name) + + # Clean the folder from a previous save + for filename in os.listdir(save_directory): + full_filename = os.path.join(save_directory, filename) + # If we have a shard file that is not going to be replaced, we delete it, but only from the main process + # in distributed settings to avoid race conditions. + weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "") + + # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005 + filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "") + reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") + + if ( + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shards + and is_main_process + and reg.fullmatch(filename_no_suffix) is not None + ): + os.remove(full_filename) + + # Save the model + for shard_file, shard in shards.items(): + if safe_serialization: + # At some point we will need to deal better with save_function (used for TPU and other distributed + # joyfulness), but for now this enough. 
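+                # `safe_save_file` writes the shard dict as a safetensors file; the
+                # metadata tag ("format": "np") conventionally records that the tensors
+                # were serialized from numpy arrays.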
+ safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "np"}) + else: + save_function(shard, os.path.join(save_directory, shard_file)) + + if index is None: + path_to_weights = os.path.join(save_directory, _add_variant(WEIGHTS_NAME, variant)) + logger.info(f"Model weights saved in {path_to_weights}") + else: + save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME + save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant)) + # Save the index as well + with open(save_index_file, "w", encoding="utf-8") as f: + content = json.dumps(index, indent=2, sort_keys=True) + "\n" + f.write(content) + logger.info( + f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be " + f"split in {len(shards)} checkpoint shards. You can find where each parameters has been saved in the " + f"index located at {save_index_file}." + ) + + def enable_recompute(self): + """Activates recompute (aka gradient checkpointing) for the current model.""" + if not self.supports_recompute: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + + for _, cell in self.cells_and_names(): + if hasattr(cell, "_set_recompute"): + cell._set_recompute() + + def check_names(self): + pass + + +class TensorType(ExplicitEnum): + """ + Possible values for the `return_tensors` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for + tab-completion in an IDE. + """ + + MINDSPORE = "ms" + NUMPY = "np" + + +class PaddingStrategy(ExplicitEnum): + """ + Possible values for the `padding` argument in [`PreTrainedTokenizerBase.__call__`]. Useful for tab-completion in an + IDE. + """ + + LONGEST = "longest" + MAX_LENGTH = "max_length" + DO_NOT_PAD = "do_not_pad" + + +class StoppingCriteria(): + """Abstract base class for all stopping criteria that can be applied during generation.""" + + @staticmethod + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + raise NotImplementedError("StoppingCriteria needs to be subclassed") + + +class MaxLengthCriteria(StoppingCriteria): + """ + This class can be used to stop generation whenever the full generated number of tokens exceeds `max_length`. Keep + in mind for decoder-only type of transformers, this will include the initial prompted tokens. + """ + + def __init__(self, max_length: int): + self.max_length = max_length + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + return input_ids.shape[-1] >= self.max_length + + +@dataclass +class BaseModelOutput(ModelOutputMindnlp): + """ + Base class for model's outputs, with potential hidden states and attentions. + """ + + last_hidden_state: mindspore.Tensor = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +class BaseModelOutputWithPast(ModelOutputMindnlp): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). + """ + + last_hidden_state: mindspore.Tensor = None + past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +class BaseModelOutputWithPastAndCrossAttentions(ModelOutputMindnlp): + """ + Base class for model's outputs that may also contain a past key/values (to speed up sequential decoding). 
+ """ + + last_hidden_state: mindspore.Tensor = None + past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + cross_attentions: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +class CausalLMOutputWithPast(ModelOutputMindnlp): + """ + Base class for causal language model (or autoregressive) outputs. + """ + + loss: Optional[mindspore.Tensor] = None + logits: mindspore.Tensor = None + past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + + +@dataclass +class CausalLMOutputWithCrossAttentions(ModelOutputMindnlp): + """ + Base class for causal language model (or autoregressive) outputs. + """ + + loss: Optional[mindspore.Tensor] = None + logits: mindspore.Tensor = None + past_key_values: Optional[Tuple[Tuple[mindspore.Tensor]]] = None + hidden_states: Optional[Tuple[mindspore.Tensor]] = None + attentions: Optional[Tuple[mindspore.Tensor]] = None + cross_attentions: Optional[Tuple[mindspore.Tensor]] = None + + +def validate_stopping_criteria(stopping_criteria: StoppingCriteriaList, max_length: int) -> StoppingCriteriaList: + """validate stopping criteria""" + stopping_max_length = stopping_criteria.max_length + new_stopping_criteria = deepcopy(stopping_criteria) + if stopping_max_length is not None and stopping_max_length != max_length: + warnings.warn("You set different `max_length` for stopping criteria and `max_length` parameter", UserWarning) + elif stopping_max_length is None: + new_stopping_criteria.append(MaxLengthCriteria(max_length=max_length)) + return new_stopping_criteria + +def _is_numpy(x): + return isinstance(x, np.ndarray) + +def is_numpy_array(x): + """ + Tests if `x` is a numpy array or not. + """ + return _is_numpy(x) + +def infer_framework_from_repr(x): + """ + Tries to guess the framework of an object `x` from its repr (brittle but will help in `is_tensor` to try the + frameworks in a smart order, without the need to import the frameworks). 
+ """ + representation = str(type(x)) + if representation.startswith("= input_rank: + raise ValueError(f"`dim` should be in range [{-input_rank}, {input_rank}), but got {input_rank, dim}") + + _sizes_mul = reduce(operator.mul, list(sizes)) + if -1 not in sizes and _sizes_mul != input_shape[_dim]: + raise ValueError(f"unflatten: Provided `sizes` {sizes} don't multiply up to the" + f"size of dim {dim} ({input_shape[_dim]}) in the input tensor") + + out_shape = input_shape[:_dim] + tuple(sizes) + input_shape[_dim + 1:] + return out_shape + + +# For all backend +# For functional api +# matmul +origin_matmul = ops.matmul +ops.matmulcus = fp16_patch_decorator(origin_matmul) +# mm +ops.mmcus = fp16_patch_decorator(origin_matmul) +# addbmm +origin_addbmm = ops.addbmm +ops.addbmmcus = fp16_patch_decorator(origin_addbmm) +# addmm +origin_addmm = ops.addmm +ops.addmmcus = fp16_patch_decorator(origin_addmm) +# addmv +origin_addmv = ops.addmv +ops.addmvcus = fp16_patch_decorator(origin_addmv) +# addr +origin_addr = ops.addr +ops.addrcus = fp16_patch_decorator(origin_addr) +# baddbmm +origin_baddbmm = ops.baddbmm +ops.baddbmmcus = fp16_patch_decorator(origin_baddbmm) +# bmm +origin_bmm = ops.bmm +ops.bmmcus = fp16_patch_decorator(origin_bmm) + + +# dense +def origin_dense(input_dense, weight, bias=None): + """patched dense""" + dense_ = _get_cache_prim(ops.Dense)() + return dense_(input_dense, weight, bias) + +ops.densecus = fp16_patch_decorator(origin_dense) + +def einsum_label_to_index(label): + if label == '.': + return 52 + NUM_OF_LETTERS = ord('z') - ord('a') + 1 + return (ord(label) - ord('A')) if (label.isupper()) else (NUM_OF_LETTERS + (ord(label) - ord('a'))) + +def einsum(equation, *operands): + assert operands, "einsum(): must provide at least one operand" + if isinstance(operands[0], tuple): + operands = operands[0] + + arrow_pos = equation.find("->") + num_ops = len(operands) + op_labels = [[] for _ in range(num_ops)] + lhs = equation[0: arrow_pos] + + curr_op = 0 + found_ell = False + ell_skip = 0 + for i, label in enumerate(lhs): + if label == ' ': + continue + if label == '.': + if ell_skip != 0: + ell_skip -= 1 + continue + assert not found_ell, f"einsum(): found {curr_op} for operand for which an ellipsis was already found" + assert i + 2 < len(lhs) and lhs[i + 1] == '.', f"einsum(): found {curr_op} for operand that is not part of any ellipsis" + ell_skip = 2 + op_labels[curr_op].append(ELLIPSIS) + found_ell = True + elif label == ',': + curr_op += 1 + assert curr_op < num_ops, "einsum(): fewer operands were provided than specified in the equation" + found_ell = False + else: + assert str.isalpha(label), f"einsum(): invalid subscript given at index {i} in the equation string, subscripts must be in [a-zA-Z]" + op_labels[curr_op].append(einsum_label_to_index(label)) + + assert curr_op == num_ops - 1, "einsum(): more operands were provided than specified in the equation" + # Labels must be within [a-zA-Z]. + TOTAL_LABELS = 52 + label_count = [0] * TOTAL_LABELS + # The maximum number of dimensions covered by any ellipsis, needed when + # unsqueezing missing dimensions from operands to permute and broadcast + ell_num_dim = 0 + + # Compute label frequency and number of dimensions covered by ellipsis + # We do this after parsing labels to make it more readable and simpler + # to compute the number of dimensions covered by ellipsis. 
+ for i, operand in enumerate(operands): + labels = op_labels[i] + ndims = operand.ndim + nlabels = len(labels) + has_ellipsis = False + + for label in labels: + if label == ELLIPSIS: + nlabels -= 1 + has_ellipsis = True + ell_num_dim = max(ell_num_dim, ndims - nlabels) + else: + label_count[label] += 1 + if has_ellipsis: + assert nlabels <= ndims, f"einsum(): the number of subscripts in the equation ({nlabels}" \ + f") is more than the number of dimensions ({ndims}) for operand {i}" + else: + assert nlabels == ndims, f"einsum(): the number of subscripts in the equation ({nlabels}" \ + f") does not match the number of dimensions (" \ + f"{ndims}) for operand {i} and no ellipsis was given" + + # We want to align the dimensions of every input tensor to have + # shape out_dims + sum_dims. For this, we create a mapping of label + # to index into the permuted shape. + label_perm_index = [-1] * TOTAL_LABELS + # Current index in the permuted shape + perm_index = 0 + # Start index of ellipsis dimensions in the permuted shape + ell_index = 0 + found_ell = False + + if arrow_pos == -1: + # Implicit output is ellipsis (...) + labels seen only once + perm_index = ell_num_dim + found_ell = True + for label, _label_count in enumerate(label_count): + if _label_count == 1: + label_perm_index[label] = perm_index + perm_index += 1 + else: + rhs = equation[arrow_pos + 2:] + ell_skip = 0 + for i, label in enumerate(rhs): + if label == ' ': + continue + if label == '.': + if ell_skip != 0: + ell_skip -= 1 + continue + assert not found_ell, "einsum(): found \'.\' for output but an ellipsis (...) was already found" + assert i + 2 < len(rhs) and rhs[i + 1] == '.', "einsum(): found \'.\' for output that is not part of any ellipsis (...)" + ell_skip = 2 + ell_index = perm_index + perm_index += ell_num_dim + found_ell = True + else: + assert str.isalpha(label), f"einsum(): invalid subscript given at index {len(lhs) + 2 + i} " \ + f"in the equation string, subscripts must be in [a-zA-Z]" + + index = einsum_label_to_index(label) + label_perm_index[index] = perm_index + perm_index += 1 + + out_size = perm_index + if not found_ell: + ell_index = perm_index + perm_index += ell_num_dim + + for label in range(TOTAL_LABELS): + if label_count[label] > 0 and label_perm_index[label] == -1: + label_perm_index[label] = perm_index + perm_index += 1 + + # Here we unsqueeze missing dimensions to make all operands have the same + # number of dimensions. We take diagonals for repeated labels within the + # same operand. Finally we permute the operands to align dimensions as + # per the perm_out_index we computed above. 
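+    # Example: in "ii->i" the repeated label is reduced with ops.diagonal; labels an
+    # operand lacks are unsqueezed as size-1 dims, then ops.transpose aligns every
+    # operand to the common permuted layout computed above.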
+ permuted_operands = [] + for i, operand in enumerate(operands): + perm_shape = [-1] * perm_index + label_dim = [-1] * TOTAL_LABELS + operand = operands[i] + labels = op_labels[i] + original_sizes = operand.shape + + j = 0 + for label in labels: + if label == ELLIPSIS: + # Add missing dimensions covered by the ellipsis + num_missing_dim = ell_num_dim - \ + (len(original_sizes) - len(labels) + 1) + for k in range(num_missing_dim): + operand = ops.unsqueeze(operand, j) + for k in range(ell_num_dim): + perm_shape[ell_index + k] = j + j += 1 + elif label_dim[label] != -1: + dim = label_dim[label] + operand = ops.diagonal(operand, offset=0, dim1=dim, dim2=j) + operand = ops.moveaxis(operand, -1, dim) + else: + label_dim[label] = j + perm_shape[label_perm_index[label]] = j + j += 1 + + # Add dimensions for missing labels + for idx, index in enumerate(perm_shape): + if index == -1: + operand = ops.unsqueeze(operand, -1) + perm_shape[idx] = j + j += 1 + + operand = ops.transpose(operand, tuple(perm_shape)) + permuted_operands.append(operand) + + # Check if operands broadcast and keep track of last operand with + # dimension size != 1 for optimizing reductions + dim_last_op = [0] * perm_index + has_zero_size_dim = False + for dim in range(perm_index): + broadcast_size = permuted_operands[0].shape[dim] + for i in range(1, len(operands)): + dim_size = permuted_operands[i].shape[dim] + if broadcast_size != dim_size and broadcast_size != 1 and dim_size != 1: + raise RuntimeError("einsum(): operands do not broadcast with remapped shapes [original->remapped]") + if dim_size != 1: + broadcast_size = dim_size + dim_last_op[dim] = i + has_zero_size_dim = has_zero_size_dim or (broadcast_size == 0) + + # Compute result + result = permuted_operands[0] + if has_zero_size_dim: + out_shape = [-1] * out_size + for i in range(out_size): + out_shape[i] = permuted_operands[dim_last_op[i]].shape[i] + return ops.zeros(out_shape) + + # Sum out or squeeze dimensions that are size 1 for all later operands + dim = out_size + for i in range(dim, perm_index): + if dim_last_op[i] == 0: + if result.shape[dim] == 1: + result = ops.squeeze(result, dim) + dim -= 1 + else: + result = ops.sum(result, dim) + dim -= 1 + dim += 1 + + for i in range(1, num_ops): + operand = permuted_operands[i] + sum_dims = [] + + # Sum out or squeeze dimensions that are size 1 for all later operands + dim = out_size + for j in range(dim, perm_index): + if dim_last_op[j] < i: + operand = ops.squeeze(operand, dim) + dim -= 1 + elif dim_last_op[j] == i: + if result.shape[dim] == 1: + operand = ops.sum(operand, dim) + result = ops.squeeze(result, dim) + dim -= 1 + else: + sum_dims.append(dim) + dim += 1 + if len(sum_dims) == 0: + result = result.mul(operand) + elif len(sum_dims) == len(result.shape): + result = result.flatten().dot(operand.flatten()) + else: + result = sumproduct_pair( + result, operand, sum_dims, False) + return result +ops.einsum = einsum + +# conv1d +ops.conv1dcus = fp16_patch_decorator(ops.conv1d) + + +def _ones(*size, dtype=None): + if dtype is None: + dtype = mindspore.float32 + if isinstance(size[0], tuple): + size = size[0] + ones_ = _get_cache_prim(ops.Ones)() + return ones_(size, dtype) + +ops.onescus = _ones + + +def _zeros(*size, dtype=None): + if dtype is None: + dtype = mindspore.float32 + if isinstance(size[0], tuple): + size = size[0] + zeros_ = _get_cache_prim(ops.Zeros)() + return zeros_(size, dtype) + + +ops.zeroscus = _zeros + + +# cross_entropy +def _cross_entropy(input_ce, target, weight=None, ignore_index=-100, 
reduction='mean', label_smoothing=0.0): + if weight is None: + weight = ops.ones(input_ce.shape[-1], input.dtype) + _nll_loss = _get_cache_prim(ops.NLLLoss)(reduction, ignore_index) + class_dim = 0 if input_ce.ndim == 1 else 1 + return _nll_loss(ops.log_softmax(input_ce, class_dim), target, weight)[0] + + +# for Tensor +# unfold +def _get_unfold_indices(input_shape, dimension, size, step): + if dimension < 0: + dimension += len(input_shape) + indices = [] + for i in range(0, input_shape[dimension] - size + 1, step): + indices.append(list(range(i, i + size))) + + return indices, dimension + + +def unfold(self, dimension, size, step): + """torch-like unfold""" + _indices, _dimension = _get_unfold_indices(self.shape, dimension, size, step) + indices = mindspore.Tensor(_indices).astype(mindspore.int32) + output = ops.gather(self, indices, axis=_dimension) + output = ops.moveaxis(output, _dimension + 1, -1) + return output + +Tensor.unfold = unfold +StubTensor.unfold = unfold + + +# var_mean +def var_mean(input_vm, axis=None, *, correction=1, keepdims=False): + """torch-like var_mean""" + axis = Validator.check_and_canonicalize_axes(axis, input.ndim) + x_mean = ops.mean(input_vm, axis, True) + x_sub = ops.sub(input_vm, x_mean) + x_pow = ops.pow(x_sub, 2) + x_sum = ops.sum(x_pow, axis, keepdims) + res_mean = ops.mean(input_vm, axis, keepdims) + nums = 1 + if not axis: + nums = input_vm.size + else: + for ax in axis: + nums *= input_vm.shape[ax] + return ops.true_divide(x_sum, nums - correction), res_mean + +ops.var_mean = var_mean + + +# std_mean +def std_mean(input_sm, axis=None, *, correction=1, keepdims=False): + """torch-like std_mean""" + output = var_mean(input_sm, axis, correction=correction, keepdims=keepdims) + return ops.pow(output[0], 0.5), output[1] + +ops.std_mean = std_mean + + +# masked_fill +def masked_fill(inputs, mask, value): + """patched masked_fill""" + masked_value = ops.fill(inputs.dtype, inputs.shape, value) + return ops.select(mask, masked_value, inputs) + + +def _masked_fill(self, mask, value): + return masked_fill(self, mask, value) + + +ops.masked_fill = masked_fill +Tensor.masked_fill = _masked_fill +StubTensor.masked_fill = _masked_fill + + +# ops.std +def std(input_std, axis=None, ddof=0, keepdims=False): + """patched std""" + # Calculate mean + mean = ops.mean(input_std, axis=axis, keep_dims=keepdims) + + # Squared differences from the mean + squared_diff = (input_std - mean)**2 + + # Sum along the specified dimension + if axis is not None: + sum_along_dim = ops.sum(squared_diff, dim=axis, keepdim=keepdims) + else: + sum_along_dim = squared_diff.sum() + + # Calculate the correction factor + factor = 1.0 / (input_std.shape[axis] - ddof) if axis is not None else 1.0 / (input_std.size - ddof) + + # Calculate the standard deviation + out = ops.sqrt(factor * sum_along_dim) + + return out + + +def _std(self, axis=None, ddof=0, keepdims=False): + return std(self, axis, ddof, keepdims) + + +ops.std = std +Tensor.std = _std +StubTensor.std = _std + + +# Tensor.__contains__ +def _contains(self, key): + eq_res = ops.equal(self, key) + res = ops.any(eq_res) + return bool(res) + + +Tensor.__contains__ = _contains +StubTensor.__contains__ = _contains + + +def unflatten(self, dim, sizes): + """Tensor.unflatten""" + out_shape = _get_unflatten_size(self.shape, dim, sizes) + return self.reshape(out_shape) + + +Tensor.unflatten = unflatten +StubTensor.unflatten = unflatten + + +def _as_strided(self, size, stride, storage_offset=None): + if len(size) != len(stride): + raise 
RuntimeError("mismatch in length of strides and shape.") + index = np.arange(0, size[0] * stride[0], stride[0]) + for i in range(1, len(size)): + tmp = np.arange(0, size[i] * stride[i], stride[i]) + index = np.expand_dims(index, -1) + index = index + tmp + if storage_offset is not None: + index = index + storage_offset + if index.size == 0: + input_indices = mindspore.numpy.empty(index.shape, dtype=mstype.int32) + else: + input_indices = Tensor(index) + out = ops.gather(self.reshape(-1), input_indices, 0) + return out + + +Tensor.as_strided = _as_strided +StubTensor.as_strided = _as_strided + + +def _nonzero(self, as_tuple=False): + if self.dtype == mstype.bool_: + self = self.astype(mstype.int64) + outs = ops.nonzero(self) + if as_tuple: + outs = ops.tensor_split(outs, self.ndim, -1) + outs = tuple(out.squeeze(-1) for out in outs) + return outs + + +Tensor.nonzero = _nonzero +StubTensor.nonzero = _nonzero + + +def _expand(self, *size): + if len(size) == 1: + size = size[0] + return ops.broadcast_to(self, size) + + +Tensor.expand = _expand +StubTensor.expand = _expand + +if LESS_MS_2_2: + mindspore.bfloat16 = None + + def eq(self, other): + """patched eq""" + return ops.equal(self, other) + + Tensor.eq = eq + StubTensor.eq = eq + + def _item(self): + return self.asnumpy().item() + Tensor.item = _item + StubTensor.item = _item + + def _tolist(self): + return self.asnumpy().tolist() + Tensor.tolist = _tolist + StubTensor.tolist = _tolist + +mindspore.tensor = mindspore.Tensor +ops.prod = bool_patch_decorator(ops.prod) + + +def _prod(self, axis=None, keep_dims=False): + return ops.prod(self, axis, keep_dims) +Tensor.prod = _prod +StubTensor.prod = _prod + + +def _eq(self, other): + if not isinstance(other, (int, float, Tensor)): + return False + if isinstance(other, Tensor) and self.shape != other.shape: + return False + if id(self) == id(other): + return True + # bool type is not supported for `Equal` operator in backend. 
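+    # Unlike the early `False`/`True` returns above, this path returns an
+    # element-wise boolean Tensor produced by ops.eq.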
+ if self.dtype == mstype.bool_ or (isinstance(other, Tensor) and other.dtype == mstype.bool_): + self = self.to(mstype.int32) + other = other.to(mstype.int32) + return ops.eq(self, other) + + +Parameter.__eq__ = _eq + +old_repeat = Tensor.repeat + + +def new_repeat_interleave(input_ri, repeats, axis=None): + """new repeat_interleave""" + if axis is None: + input_ri = input_ri.reshape(-1) + axis = 0 + if isinstance(repeats, Tensor): + repeats = repeats.asnumpy().tolist() + output = old_repeat(input_ri, repeats, axis) + return output + + +ops.repeat_interleave = bool_io_patch_decorator(new_repeat_interleave) + + +def _repeat_interleave(self, repeats, dim): + return old_repeat(self, repeats, axis=dim) + + +Tensor.repeat_interleave = _repeat_interleave +StubTensor.repeat_interleave = _repeat_interleave + + +def _repeat(self, *sizes): + return ops.tile(self, tuple(sizes)) + + +Tensor.repeat = _repeat +StubTensor.repeat = _repeat + + +# For Cells +class DenseMindnlp(nn.Cell): + """patched Dense""" + def __init__(self, + in_channels, + out_channels, + has_bias=True, + dtype=mstype.float32): + """Initialize Dense.""" + super().__init__() + self.in_channels = Validator.check_positive_int( + in_channels, "in_channels", self.cls_name) + self.out_channels = Validator.check_positive_int( + out_channels, "out_channels", self.cls_name) + self.has_bias = Validator.check_bool( + has_bias, "has_bias", self.cls_name) + + self.weight = Parameter(initializer( + HeUniform(math.sqrt(5)), [out_channels, in_channels], dtype=dtype), name="weight") + + self.bias = None + if self.has_bias: + fan_in, _ = _calculate_fan_in_and_fan_out(self.weight.shape) + bound = 1 / math.sqrt(fan_in) + self.bias = Parameter(initializer( + Uniform(bound), [out_channels], dtype=dtype), name="bias") + + def construct(self, x): + if LESS_MS_2_2: + x_shape = x.shape + if len(x_shape) != 2: + x = x.reshape(-1, x.shape[-1]) + x = ops.matmul(x, self.weight.T) + if self.has_bias: + x = ops.add(x, self.bias) + if len(x_shape) != 2: + out_shape = x_shape[:-1] + (x.shape[-1],) + x = x.reshape(out_shape) + return x + return ops.dense(x, self.weight, self.bias) + + def extend_repr(self): + s = f'input_channels={self.in_channels}, output_channels={self.out_channels}' + if self.has_bias: + s += f', has_bias={self.has_bias}' + return s + + +class EmbeddingMindnlp(nn.Cell): + """patched Embedding""" + def __init__(self, vocab_size, embedding_size, padding_idx=None, use_one_hot=False, dtype=mstype.float32): + """Initialize Embedding.""" + super().__init__() + self.vocab_size = Validator.check_value_type('vocab_size', vocab_size, [int], self.cls_name) + self.embedding_size = Validator.check_value_type('embedding_size', embedding_size, [int], self.cls_name) + Validator.check_value_type('use_one_hot', use_one_hot, [bool], self.cls_name) + Validator.check_subclass("dtype", dtype, mstype.number_type, self.cls_name) + self.use_one_hot = use_one_hot + self.dtype = dtype + self.padding_idx = padding_idx + self.weight = Parameter(initializer(Normal(1.0), [vocab_size, embedding_size]), name='weight') + if self.padding_idx and self.weight.init_flag: + self.weight[self.padding_idx] = 0 + + def construct(self, ids): + out_shape = ids.shape + (self.embedding_size,) + flat_ids = ids.reshape((-1,)) + + if self.use_one_hot: + one_hot_ids = ops.one_hot(flat_ids, self.vocab_size) + output_for_reshape = ops.matmul(one_hot_ids, self.weight) + else: + output_for_reshape = ops.gather(self.weight, flat_ids, 0) + + output = output_for_reshape.reshape(out_shape) + return output 
+ + def extend_repr(self): + return f'vocab_size={self.vocab_size}, embedding_size={self.embedding_size}, use_one_hot={self.use_one_hot}, ' \ + f'weight={self.weight}, dtype={self.dtype}, padding_idx={self.padding_idx}' + + +class LayerNormMindnlp(nn.Cell): + r""" + Applies Layer Normalization over a mini-batch of inputs. + """ + + def __init__(self, + normalized_shape, + begin_norm_axis=-1, + begin_params_axis=-1, + gamma_init='ones', + beta_init='zeros', + epsilon=1e-5, + dtype=mstype.float32, + elementwise_affine=True + ): + """Initialize LayerNorm.""" + super().__init__() + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] + if not isinstance(normalized_shape, (tuple, list)): + raise TypeError(f"For '{self.cls_name}', the type of 'normalized_shape' must be tuple[int] or list[int], " + f"but got {normalized_shape} and the type is {type(normalized_shape)}.") + if not normalized_shape: + raise ValueError( + f"Expected normalized_shape to be at least 1-dimensional, i.e., containing at " + f"least one element, but got normalized_shape = {normalized_shape}" + ) + self.normalized_shape = normalized_shape + self.begin_norm_axis = begin_norm_axis + self.begin_params_axis = begin_params_axis + self.epsilon = epsilon + self.weight = Parameter(initializer( + gamma_init, normalized_shape, dtype=dtype), name="weight") + self.bias = Parameter(initializer( + beta_init, normalized_shape, dtype=dtype), name="bias") + self.layer_norm = ops.LayerNorm(begin_norm_axis=self.begin_norm_axis, + begin_params_axis=self.begin_params_axis, + epsilon=self.epsilon) + self.elementwise_affine = elementwise_affine + + def construct(self, input_x): + if self.elementwise_affine: + y, _, _ = self.layer_norm(input_x, self.weight.astype(input_x.dtype), self.bias.astype(input_x.dtype)) + else: + y, _, _ = self.layer_norm(input_x, ops.ones(self.normalized_shape, input_x.dtype), + ops.zeros(self.normalized_shape, input_x.dtype),) + return y + + def extend_repr(self): + return f'normalized_shape={self.normalized_shape}, begin_norm_axis={self.begin_norm_axis}, ' \ + f'begin_params_axis={self.begin_params_axis}, weight={self.weight}, bias={self.bias}' + + +class BatchNorm1dMindnlp(nn.Cell): + """Batch Normalization base class.""" + def __init__(self, + num_features, + eps=1e-5, + momentum=0.9, + affine=True, + weight_init='ones', + bias_init='zeros', + moving_mean_init='zeros', + moving_var_init='ones', + use_batch_statistics=None, + dtype=mstype.float32): + """Initialize _BatchNorm.""" + super().__init__() + if num_features < 1: + raise ValueError(f"For '{self.cls_name}', the 'num_features' must be at least 1, but got {num_features}.") + + if momentum < 0 or momentum > 1: + raise ValueError(f"For '{self.cls_name}', the 'momentum' must be a number in range [0, 1], " + f"but got {momentum}.") + self.use_batch_statistics = use_batch_statistics + if self.use_batch_statistics is not None and not isinstance(self.use_batch_statistics, bool): + raise ValueError(f"For '{self.cls_name}', the 'use_batch_statistics' must be a boolean value or None," + f" but got {use_batch_statistics}.") + self.num_features = num_features + self.eps = eps + self.moving_mean_init = moving_mean_init + self.moving_var_init = moving_var_init + self.running_mean = Parameter(initializer( + moving_mean_init, num_features, dtype=dtype), name="running_mean", requires_grad=False) + self.running_var = Parameter(initializer( + moving_var_init, num_features, dtype=dtype), name="running_var", requires_grad=False) + self.weight = 
Parameter(initializer( + weight_init, num_features, dtype=dtype), name="weight", requires_grad=affine) + self.bias = Parameter(initializer( + bias_init, num_features, dtype=dtype), name="bias", requires_grad=affine) + + self.momentum = 1.0 - momentum + + self.bn_train = ops.BatchNorm(is_training=True, + epsilon=self.eps, + momentum=self.momentum) + + self.bn_infer = ops.BatchNorm(is_training=False, epsilon=self.eps) + + def construct(self, x): + if self.use_batch_statistics is None: + if self.training: + return self.bn_train(x, + self.weight, + self.bias, + self.running_mean, + self.running_var)[0] + if not self.training: + return self.bn_infer(x, + self.weight, + self.bias, + self.running_mean, + self.running_var)[0] + + if self.use_batch_statistics: + return self.bn_train(x, + self.weight, + self.bias, + self.running_mean, + self.running_var)[0] + + return self.bn_infer(x, + self.weight, + self.bias, + self.running_mean, + self.running_var)[0] + + def extend_repr(self): + return f'num_features={self.num_features}, eps={self.eps}, momentum={1.0 - self.momentum}, ' \ + f'weight={self.weight}, bias={self.bias}, running_mean={self.running_mean}, ' \ + f'running_var={self.running_var}' + diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py new file mode 100644 index 000000000..0eb859740 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py @@ -0,0 +1,1427 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" +Logits process +""" + +import math +import inspect +import logging +from typing import List, Iterable, Union, Optional, Callable, Tuple, Dict +import mindspore +from mindspore import ops +import numpy as np + + +class LogitsProcessor: + """Abstract base class for all logit processors that can be applied during generation.""" + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + """Torch method for processing logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." + ) + + +class LogitsWarper: + """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.""" + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + """Torch method for warping logits.""" + raise NotImplementedError( + f"{self.__class__} is an abstract class. Only classes inheriting this class can be called." 
+ ) + + +class LogitsProcessorList(list): + """ + This class can be used to create a list of [`LogitsProcessor`] or [`LogitsWarper`] to subsequently process a + `scores` input tensor. This class inherits from list and adds a specific *__call__* method to apply each + [`LogitsProcessor`] or [`LogitsWarper`] to the inputs. + """ + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> mindspore.Tensor: + for processor in self: + function_args = inspect.signature(processor.__call__).parameters + if len(function_args) > 2: + if not all(arg in kwargs for arg in list(function_args.keys())[2:]): + raise ValueError( + f"Make sure that all the required parameters: {list(function_args.keys())} for " + f"{processor.__class__} are passed to the logits processor." + ) + scores = processor(input_ids, scores, **kwargs) + else: + scores = processor(input_ids, scores) + return scores + + +class HammingDiversityLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces diverse beam search. Note that this logits processor is only effective for + [`PreTrainedModel.group_beam_search`]. See [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence + Models](https://arxiv.org/pdf/1610.02424.pdf) for more details. + + Args: + diversity_penalty (`float`): + This value is subtracted from a beam's score if it generates a token same as any beam from other group at a + particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. + num_beams (`int`): + Number of beams used for group beam search. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more + details. + num_beam_groups (`int`): + Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. + See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. 
+ """ + + def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int): + if not isinstance(diversity_penalty, float) or (not diversity_penalty > 0.0): + raise ValueError("`diversity_penalty` should be a float strictly larger than 0.") + self._diversity_penalty = diversity_penalty + if not isinstance(num_beams, int) or num_beams < 2: + raise ValueError("`num_beams` should be an integer strictly larger than 1.") + self._num_beams = num_beams + if not isinstance(num_beam_groups, int) or num_beam_groups < 2: + raise ValueError("`num_beam_groups` should be an integer strictly larger than 1.") + if num_beam_groups > num_beams: + raise ValueError("`beam_groups` has to be smaller or equal to `num_beams`.") + self._num_sub_beams = num_beams // num_beam_groups + + def __call__( + self, + input_ids: mindspore.Tensor, + scores: mindspore.Tensor, + current_tokens: mindspore.Tensor, + beam_group_idx: int, + ) -> mindspore.Tensor: + # hamming diversity: penalise using same token in current group which was used in previous groups at + # the same time step + batch_size = current_tokens.shape[0] // self._num_beams + group_start_idx = beam_group_idx * self._num_sub_beams + group_end_idx = min(group_start_idx + self._num_sub_beams, self._num_beams) + group_size = group_end_idx - group_start_idx + vocab_size = scores.shape[-1] + + if group_start_idx == 0: + return scores + + for batch_idx in range(batch_size): + # predicted tokens of last time step of previous groups + previous_group_tokens = current_tokens[ + batch_idx * self._num_beams: batch_idx * self._num_beams + group_start_idx + ] + token_frequency = ops.bincount(previous_group_tokens, minlength=vocab_size) + scores[batch_idx * group_size: (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency + + return scores + + +class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input. + + Args: + hallucination_penalty (`float`): + The parameter for hallucination penalty. 1.0 means no penalty. + encoder_input_ids (`mindspore.Tensor`): + The encoder_input_ids that should not be repeated within the decoder ids. + """ + + def __init__(self, penalty: float, encoder_input_ids: mindspore.Tensor): + if not isinstance(penalty, float) or (penalty <= 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = 1 / penalty + self.encoder_input_ids = encoder_input_ids + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + score = ops.gather_elements(scores, 1, self.encoder_input_ids) + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = ops.where(score < 0, score * self.penalty, score / self.penalty) + + scores.scatter_(1, self.encoder_input_ids, score) + return scores + + +class RepetitionPenaltyLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. + + Args: + repetition_penalty (`float`): + The parameter for repetition penalty. 1.0 means no penalty. See [this + paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. 
+ """ + + def __init__(self, penalty: float): + if not isinstance(penalty, float) or (penalty <= 0): + raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}") + + self.penalty = penalty + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + input_ids = ops.where(input_ids >= scores.shape[1], input_ids - scores.shape[1], input_ids) + score = ops.gather_elements(scores, 1, input_ids) + + # if score < 0 then repetition penalty has to be multiplied to reduce the previous token probability + score = ops.where(score < 0, score * self.penalty, score / self.penalty) + + scores = ops.tensor_scatter_elements(scores, input_ids, score, axis=1) + return scores + + +def _get_ngrams(ngram_size: int, prev_input_ids: mindspore.Tensor, num_hypos: int): + generated_ngrams = [{} for _ in range(num_hypos)] + for idx in range(num_hypos): + gen_tokens = prev_input_ids[idx].asnumpy().tolist() + generated_ngram = generated_ngrams[idx] + for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]): + prev_ngram_tuple = tuple(ngram[:-1]) + generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]] + return generated_ngrams + + +def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len): + # Before decoding the next token, prevent decoding of ngrams that have already appeared + start_idx = cur_len + 1 - ngram_size + ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist()) + return banned_ngrams.get(ngram_idx, []) + + +def _calc_banned_ngram_tokens( + ngram_size: int, prev_input_ids: mindspore.Tensor, num_hypos: int, cur_len: int +) -> List[Iterable[int]]: + """Copied from fairseq for no_repeat_ngram in beam_search""" + if cur_len + 1 < ngram_size: + # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet + return [[] for _ in range(num_hypos)] + + generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos) + + banned_tokens = [ + _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len) + for hypo_idx in range(num_hypos) + ] + return banned_tokens + + +class NoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces no repetition of n-grams. See + [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). + + Args: + ngram_size (`int`): + All ngrams of size `ngram_size` can only occur once. + """ + + def __init__(self, ngram_size: int): + if not isinstance(ngram_size, int) or ngram_size <= 0: + raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}") + self.ngram_size = ngram_size + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + num_batch_hypotheses = scores.shape[0] + cur_len = input_ids.shape[-1] + banned_batch_tokens = _calc_banned_ngram_tokens(self.ngram_size, input_ids, num_batch_hypotheses, cur_len) + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + +class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces no repetition of encoder input ids n-grams for the decoder ids. See + [ParlAI](https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350). + + Args: + encoder_ngram_size (`int`): + All ngrams of size `ngram_size` can only occur within the encoder input ids. 
+ encoder_input_ids (`int`): + The encoder_input_ids that should not be repeated within the decoder ids. + """ + + def __init__(self, encoder_ngram_size: int, encoder_input_ids: mindspore.Tensor): + if not isinstance(encoder_ngram_size, int) or encoder_ngram_size <= 0: + raise ValueError( + f"`encoder_ngram_size` has to be a strictly positive integer, but is {encoder_ngram_size}" + ) + self.ngram_size = encoder_ngram_size + if len(encoder_input_ids.shape) == 1: + encoder_input_ids = encoder_input_ids.unsqueeze(0) + self.batch_size = encoder_input_ids.shape[0] + self.generated_ngrams = _get_ngrams(encoder_ngram_size, encoder_input_ids, self.batch_size) + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + # B x num_beams + num_hypos = scores.shape[0] + num_beams = num_hypos // self.batch_size + cur_len = input_ids.shape[-1] + banned_batch_tokens = [ + _get_generated_ngrams( + self.generated_ngrams[hypo_idx // num_beams], input_ids[hypo_idx], self.ngram_size, cur_len + ) + for hypo_idx in range(num_hypos) + ] + + for i, banned_tokens in enumerate(banned_batch_tokens): + scores[i, banned_tokens] = -float("inf") + + return scores + + +class NoBadWordsLogitsProcessor(LogitsProcessor): + """ + [`LogitsProcessor`] that enforces that specified sequences will never be sampled. + + Args: + bad_words_ids (`List[List[int]]`): + List of list of token ids that are not allowed to be generated. In order to get the token ids of the words + that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, + add_special_tokens=False).input_ids`. + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + """ + + def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]): + if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0: + raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.") + if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids): + raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.") + if any( + any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids) + for bad_word_ids in bad_words_ids + ): + raise ValueError( + f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}." 
+ ) + + if eos_token_id is None: + eos_token_id = [] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + + bad_words_ids = list( + filter(lambda bad_token_seq: all(bad_token_seq != [i] for i in eos_token_id), bad_words_ids) + ) + self.bad_words_id_length_1 = [] + self.bad_words_id_length_greater_than_1 = [] + for word in bad_words_ids: + if len(word) == 1: + self.bad_words_id_length_1.append(word[0]) + else: + self.bad_words_id_length_greater_than_1.append(word) + + self.static_bad_words_mask: Optional[mindspore.Tensor] = None + + for banned_token_seq in self.bad_words_id_length_greater_than_1: + if len(banned_token_seq) == 0: + raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list") + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0: + self.static_bad_words_mask = self._calc_static_bad_word_mask(scores) + + dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist()) + scores = self._set_scores_to_inf_for_banned_tokens(scores, dynamic_banned_tokens) + + return scores + + def _calc_static_bad_word_mask(self, scores: mindspore.Tensor) -> mindspore.Tensor: + static_bad_words_mask = ops.zeros(scores.shape[1]) + static_bad_words_mask[self.bad_words_id_length_1] = 1 + return static_bad_words_mask.unsqueeze(0).bool() + + def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool: + if len(tokens) == 0: + # if bad word tokens is just one token always ban it + return True + if len(tokens) > len(prev_tokens): + # if bad word tokens are longer then prev input_ids they can't be equal + return False + return prev_tokens[-len(tokens):] == tokens + + def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> Iterable[int]: + banned_tokens = [] + for prev_input_ids_slice in prev_input_ids: + banned_tokens_slice = [] + for banned_token_seq in self.bad_words_id_length_greater_than_1: + if self._tokens_match(prev_input_ids_slice, banned_token_seq[:-1]): + banned_tokens_slice.append(banned_token_seq[-1]) + + banned_tokens.append(banned_tokens_slice) + + return banned_tokens + + def _set_scores_to_inf_for_banned_tokens( + self, scores: mindspore.Tensor, banned_tokens: List[List[int]] + ) -> mindspore.Tensor: + """ + Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a + list of list of banned tokens to ban in the format [[batch index, vocabulary position],... + + Args: + scores: logits distribution of shape (batch size, vocabulary size) + banned_tokens: list of list of tokens to ban of length (batch_size) + """ + banned_mask_list = [] + for idx, batch_banned_tokens in enumerate(banned_tokens): + for token in batch_banned_tokens: + # Eliminates invalid bad word IDs that are over the vocabulary size. + if token <= scores.shape[1]: + banned_mask_list.append([idx, token]) + else: + logging.error( + "An invalid bad word ID is defined: %d. This ID is not contained in the " + "vocabulary, and is therefore ignored.", token + ) + if not banned_mask_list and self.static_bad_words_mask is None: + return scores + + if banned_mask_list: + banned_mask = mindspore.Tensor(banned_mask_list) + indices = ops.ones(len(banned_mask)) + # A sparse tensor is generated from a list of coordinates: [[0, 1], [0, 2], [2, 0]]. 
A conversion to dense tensor generates: + # [ 0 1 1 ] + # [ 0 0 0 ] + # [ 1 0 0 ] + + banned_mask = ( + mindspore.COOTensor(banned_mask, + indices, scores.shape) + .to_dense() + .bool() + ) + + if self.static_bad_words_mask is not None: + banned_mask = ops.bitwise_or(banned_mask, self.static_bad_words_mask) + else: + banned_mask = self.static_bad_words_mask + + scores = scores.masked_fill(banned_mask, -float("inf")) + return scores + + +class MinLengthLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0. + + Args: + min_length (`int`): + The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`. + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + """ + + def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]): + if not isinstance(min_length, int) or min_length < 0: + raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}") + + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + if not all(isinstance(i, int) for i in eos_token_id) or any(i < 0 for i in eos_token_id): + raise ValueError(f"`eos_token_id` has to be a list of positive integers, but is {eos_token_id}") + + self.min_length = min_length + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + vocab_tensor = ops.arange(scores.shape[-1]) + eos_token_id = mindspore.tensor(self.eos_token_id) + eos_token_mask = vocab_tensor == eos_token_id + scores_processed = scores.copy() + if input_ids.shape[-1] < self.min_length: + scores_processed = ops.where(eos_token_mask, -math.inf, scores) + return scores_processed + + +class MinNewTokensLengthLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0. + + Args: + prompt_length_to_skip (`int`): + The input tokens length. + min_new_tokens (`int`): + The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`. + eos_token_id (`int`): + The id of the *end-of-sequence* token. + """ + + def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int): + for arg_name, arg_value in [ + ("prompt_length_to_skip", prompt_length_to_skip), + ("min_new_tokens", min_new_tokens), + ("eos_token_id", eos_token_id), + ]: + if not isinstance(arg_value, int) or arg_value < 0: + raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}") + + self.prompt_length_to_skip = prompt_length_to_skip + self.min_new_tokens = min_new_tokens + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + new_tokens_length = input_ids.shape[-1] - self.prompt_length_to_skip + if new_tokens_length < self.min_new_tokens: + scores[:, self.eos_token_id] = -float("inf") + + return scores + + +class PrefixConstrainedLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained + generation. See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information. + + Args: + prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`): + This function constraints the beam search to allowed tokens only at each step. 
This function takes 2 + arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the + next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID + `batch_id`. + """ + + def __init__(self, prefix_allowed_tokens_fn: Callable[[int, mindspore.Tensor], List[int]], num_beams: int): + self._prefix_allowed_tokens_fn = prefix_allowed_tokens_fn + self._num_beams = num_beams + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + mask = ops.full_like(scores, -math.inf) + for batch_id, beam_sent in enumerate(input_ids.view(-1, self._num_beams, input_ids.shape[-1])): + for beam_id, sent in enumerate(beam_sent): + mask[batch_id * self._num_beams + beam_id, self._prefix_allowed_tokens_fn(batch_id, sent)] = 0 + + return scores + mask + + +class ForcedBOSTokenLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces the specified token as the first generated token. + + Args: + bos_token_id (`int`): + The id of the token to force as the first generated token. + """ + + def __init__(self, bos_token_id: int): + self.bos_token_id = bos_token_id + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + cur_len = input_ids.shape[-1] + if cur_len == 1: + num_tokens = scores.shape[1] + scores[:, [i for i in range(num_tokens) if i != self.bos_token_id]] = -float("inf") + scores[:, self.bos_token_id] = 0 + return scores + + +class ForcedEOSTokenLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached. + + Args: + max_length (`int`): + The maximum length of the sequence to be generated. + eos_token_id (`Union[int, List[int]]`): + The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a + list to set multiple *end-of-sequence* tokens. + """ + + def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): + self.max_length = max_length + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + cur_len = input_ids.shape[-1] + if cur_len == self.max_length - 1: + num_tokens = scores.shape[1] + scores[:, [i for i in range(num_tokens) if i not in self.eos_token_id]] = \ + float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min) + for i in self.eos_token_id: + scores[:, i] = 0 + return scores + + +class InfNanRemoveLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] that removes all `nan` and `inf` values to avoid the generation method to fail. Note that using + the logits processor should only be used if necessary since it can slow down the generation method. `max_length` is + reached. + """ + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + # set all nan values to 0.0 + scores[ops.isnan(scores)] = 0.0 + + # set all inf values to max possible value + scores[scores == float("inf")] = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).max) + return scores + + +class ExponentialDecayLengthPenalty(LogitsProcessor): + r""" + [`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been + reached. 
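+
+    Concretely, once the current total length (prompt included) exceeds `regulation_start = start_index +
+    input_ids_seq_length`, the score of every `eos_token_id` is multiplied by
+    `decay_factor ** (cur_len - regulation_start)` at each step (see `__call__` below). For example, with
+    `exponential_decay_length_penalty=(5, 1.01)` and a prompt of length 3, the EOS scores start being
+    rescaled once the sequence is longer than 8 tokens.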
+ + Args: + exponential_decay_length_penalty (`tuple(int, float)`, *optional*): + This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty + starts and `decay_factor` represents the factor of exponential decay + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + input_ids_seq_length (`int`): + The length of the input sequence. + """ + + def __init__( + self, exponential_decay_length_penalty: Tuple, eos_token_id: Union[int, List[int]], + input_ids_seq_length: int + ): + self.regulation_start = exponential_decay_length_penalty[0] + input_ids_seq_length + self.regulation_factor = exponential_decay_length_penalty[1] + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id = eos_token_id + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + cur_len = input_ids.shape[-1] + if cur_len > self.regulation_start: + for i in self.eos_token_id: + scores[:, i] = scores[:, i] * pow(self.regulation_factor, cur_len - self.regulation_start) + return scores + + +class SuppressTokensLogitsProcessor(LogitsProcessor): + r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they + are not sampled.""" + + def __init__(self, suppress_tokens): + self.suppress_tokens = list(suppress_tokens) + + def __call__(self, input_ids, scores): + scores[:, self.suppress_tokens] = -float("inf") + return scores + + +class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): + r""" + [`SuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts + generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not + sampled at the begining of the generation. + """ + + def __init__(self, begin_suppress_tokens, begin_index): + self.begin_suppress_tokens = list(begin_suppress_tokens) + self.begin_index = begin_index + + def __call__(self, input_ids, scores): + if input_ids.shape[1] == self.begin_index: + scores[:, self.begin_suppress_tokens] = -float("inf") + + return scores + + +class ForceTokensLogitsProcessor(LogitsProcessor): + r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token + indices that will be forced before sampling. The processor will set their log probs to `inf` so that they are + sampled at their corresponding index.""" + + def __init__(self, force_token_map: List[List[int]]): + self.force_token_map = dict(force_token_map) + + def __call__(self, input_ids, scores): + generation_idx = input_ids.shape[-1] + current_token = self.force_token_map.get(generation_idx, None) + if current_token is not None: + scores[:, :] = -float("inf") + scores[:, current_token] = 0 + return scores + + +class LogitNormalization(LogitsProcessor, LogitsWarper): + r""" + [`LogitsWarper`] and [`LogitsProcessor`] for normalizing the scores using log-softmax. It's important to normalize + the scores during beam search, after applying the logits processors or warpers, since the search algorithm used in + this library doesn't do it (it only does it before, but they may need re-normalization) but it still supposes that + the scores are normalized when comparing the hypotheses. 
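+
+    A minimal sketch of the effect (tensor values are illustrative only):
+
+    ```python
+    >>> from mindspore import Tensor, ops
+
+    >>> scores = Tensor([[2.0, 1.0, 0.1]])
+    >>> normalized = ops.log_softmax(scores, axis=-1)  # what this processor applies
+    >>> probs_sum = ops.exp(normalized).sum(axis=-1)  # ~1.0 for every row
+    ```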
+ """ + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + scores = ops.log_softmax(scores, axis=-1) + return scores + + +class TemperatureLogitsWarper(LogitsWarper): + r""" + [`TemperatureLogitsWarper`] for temperature (exponential scaling output probability distribution). + Args: + temperature (:obj:`float`): + The value used to module the logits distribution. + """ + + def __init__(self, temperature: float): + if not isinstance(temperature, float) or not temperature > 0: + raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}") + + self.temperature = temperature + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + scores = scores / self.temperature + return scores + + +class TopPLogitsWarper(LogitsWarper): + """ + [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. + + Args: + top_p (`float`): + If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or + higher are kept for generation. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + top_p = float(top_p) + if top_p < 0 or top_p > 1.0: + raise ValueError(f"`top_p` has to be a float > 0 and < 1, but is {top_p}") + if not isinstance(min_tokens_to_keep, int) or min_tokens_to_keep < 1: + raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}") + + self.top_p = top_p + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + if self.filter_value == -float("Inf"): + self.filter_value = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min) + # scores = ops.select(ops.isneginf(scores), mindspore.tensor(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min), scores) + sorted_logits, sorted_indices = ops.sort(scores, descending=False) + cumulative_probs = ops.softmax(sorted_logits, axis=-1).cumsum(axis=-1) + + # Remove tokens with cumulative top_p above the threshold (token with 0 are kept) + sorted_indices_to_remove = cumulative_probs <= (1 - self.top_p) + # Keep at least min_tokens_to_keep + #sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0 + + # scatter sorted tensors to original indexing + #indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + #scores = scores.masked_fill(indices_to_remove, self.filter_value) + sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0 + + if type(sorted_indices_to_remove[0][0].item()) == bool: + sorted_indices_to_remove = sorted_indices_to_remove.astype("int32") + + indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) + + if type(indices_to_remove[0][0].item()) == int: + indices_to_remove = indices_to_remove.astype("bool") + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + +class TopKLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. 
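+
+    A minimal usage sketch (the tensors and their values are illustrative only):
+
+    ```python
+    >>> from mindspore import Tensor
+
+    >>> warper = TopKLogitsWarper(top_k=2)
+    >>> scores = Tensor([[2.0, 1.0, 0.5, 0.1]])  # (batch=1, vocab=4) logits
+    >>> input_ids = Tensor([[0]])  # not used by this warper
+    >>> filtered = warper(input_ids, scores)  # 2.0 and 1.0 survive, the rest is masked out
+    ```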
+ + Args: + top_k (`int`): + The number of highest probability vocabulary tokens to keep for top-k-filtering. + filter_value (`float`, *optional*, defaults to `-float("Inf")`): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + if not isinstance(top_k, int) or top_k <= 0: + raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}") + + self.top_k = max(top_k, min_tokens_to_keep) + self.filter_value = filter_value + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + if self.filter_value == -float("Inf"): + self.filter_value = float(np.finfo(mindspore.dtype_to_nptype(scores.dtype)).min) + top_k = min(self.top_k, scores.shape[-1]) # Safety check + # Remove all tokens with a probability less than the last token of the top-k + indices_to_remove = scores < ops.topk(scores, top_k)[0][..., -1, None] + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + +class TypicalLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language + Generation](https://arxiv.org/abs/2202.00666) for more information. + + Args: + mass (`float`, *optional*, defaults to 0.9): + Value of typical_p between 0 and 1 inclusive, defaults to 0.9. + filter_value (`float`, *optional*, defaults to -inf): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + """ + + def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + mass = float(mass) + if mass <= 0 or mass >= 1: + raise ValueError(f"`typical_p` has to be a float > 0 and < 1, but is {mass}") + if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1): + raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}") + + self.filter_value = filter_value + self.mass = mass + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + # calculate entropy + normalized = ops.log_softmax(scores, axis=-1) + p = ops.exp(normalized) + ent = -(normalized * p).nansum(-1, keepdim=True) + + # shift and sort + shifted_scores = ops.abs((-normalized) - ent) + sorted_scores, sorted_indices = ops.sort(shifted_scores, descending=False) + sorted_logits = scores.gather(-1, sorted_indices) + cumulative_probs = sorted_logits.softmax(axis=-1).cumsum(axis=-1) + + # Remove tokens with cumulative mass above the threshold + last_ind = (cumulative_probs < self.mass).axis(dim=1) + last_ind.clamp_(max=sorted_scores.shape[-1] - 1) + sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) + sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 + indices_to_remove = ops.tensor_scatter_elements(sorted_indices_to_remove, sorted_indices, sorted_indices_to_remove, axis=1) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + +class EpsilonLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the + largest min_tokens_to_keep tokens if no tokens satisfy this constraint. 
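+    In other words, with `epsilon=0.01` every token whose softmax probability is below 1% is masked to
+    `filter_value` before sampling, while the `min_tokens_to_keep` most likely tokens are always kept.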
See [Truncation Sampling as Language Model + Desmoothing](https://arxiv.org/abs/2210.15191) for more information. + + Args: + epsilon (`float`): + If set to > 0, only the most tokens with probabilities `epsilon` or higher are kept for generation. + filter_value (`float`, *optional*, defaults to -inf): + All filtered values will be set to this float value. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Minimum number of tokens that cannot be filtered. + + Examples: + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed + + >>> set_seed(0) + >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + + >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt") + + >>> # With sampling, the output is unexpected -- sometimes too unexpected. + >>> outputs = model.generate(**inputs, do_sample=True) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2 + + >>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to + >>> # Top P sampling, which restricts tokens based on their cumulative probability. + >>> # Pro tip: The paper recomends using `epsilon_cutoff` values between 3e-4 and 9e-4 + >>> outputs = model.generate(**inputs, do_sample=True, epsilon_cutoff=0.1) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9 + ``` + """ + + def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + epsilon = float(epsilon) + if epsilon <= 0 or epsilon >= 1: + raise ValueError(f"`epsilon_cutoff` has to be a float > 0 and < 1, but is {epsilon}") + + min_tokens_to_keep = int(min_tokens_to_keep) + if min_tokens_to_keep < 1: + raise ValueError( + f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}" + ) + + self.epsilon = epsilon + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + # Determine which indices to remove + probabilities = ops.softmax(scores, axis=-1) + indices_to_remove = probabilities < self.epsilon + + # Keep the words with the 'min_tokens_to_keep'-highest probabilities + top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check + indices_to_remove = indices_to_remove & (scores < ops.topk(scores, top_k)[0][..., -1, None]) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + + +class EtaLogitsWarper(LogitsWarper): + r""" + [`LogitsWarper`] that performs eta-sampling, a technique to filter out tokens with probabilities below a dynamic + cutoff value, `eta`, which is calculated based on a combination of the hyperparameter `epsilon` and the entropy of + the token probabilities, i.e. `eta := min(epsilon, sqrt(epsilon * e^-entropy(probabilities)))`. Takes the largest + min_tokens_to_keep tokens if no tokens satisfy this constraint. It addresses the issue of poor quality in long + samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation + Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample` + must be set to `True` for this `LogitsWarper` to work. + + + Args: + epsilon (`float`): + A float value in the range (0, 1). 
Hyperparameter used to calculate the dynamic cutoff value, `eta`. The + suggested values from the paper ranges from 3e-4 to 4e-3 depending on the size of the model. + filter_value (`float`, *optional*, defaults to -inf): + All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This + parameter is useful when logits need to be modified for very low probability tokens that should be excluded + from generation entirely. + min_tokens_to_keep (`int`, *optional*, defaults to 1): + Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities. + For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation, + even if all tokens have probabilities below the cutoff `eta`. + + Examples: + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed + + >>> set_seed(0) + >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") + + >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt") + + >>> # With sampling, the output is unexpected -- sometimes too unexpected. + >>> outputs = model.generate(**inputs, do_sample=True) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2 + + >>> # With eta sampling, the output gets restricted to high-probability tokens. You can see it as a dynamic form of + >>> # epsilon sampling that adapts its cutoff probability based on the entropy (high entropy = lower cutoff). + >>> # Pro tip: The paper recomends using `eta_cutoff` values between 3e-4 to 4e-3 + >>> outputs = model.generate(**inputs, do_sample=True, eta_cutoff=0.1) + >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) + A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9 + ``` + """ + + def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): + epsilon = float(epsilon) + if epsilon <= 0 or epsilon >= 1: + raise ValueError(f"`eta_cutoff` has to be a float > 0 and < 1, but is {epsilon}") + + min_tokens_to_keep = int(min_tokens_to_keep) + if min_tokens_to_keep < 1: + raise ValueError( + f"`min_tokens_to_keep` has to be a strictly positive integer, but is {min_tokens_to_keep}" + ) + + self.epsilon = mindspore.tensor(epsilon, mindspore.float32) + self.filter_value = filter_value + self.min_tokens_to_keep = min_tokens_to_keep + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + # Calculate the adaptive cutoff + probabilities = scores.softmax(dim=-1) + entropy = mindspore.nn.probability.distribution.Categorical().entropy(scores) + eta = ops.min(self.epsilon, ops.sqrt(self.epsilon) * ops.exp(-entropy))[..., None] + indices_to_remove = probabilities < eta + + # Keep the words with the 'min_tokens_to_keep'-highest probabilities + top_k = min(self.min_tokens_to_keep, scores.size(-1)) # Safety check + indices_to_remove = indices_to_remove & (scores < ops.topk(scores, top_k)[0][..., -1, None]) + + scores = scores.masked_fill(indices_to_remove, self.filter_value) + return scores + +class SequenceBiasLogitsProcessor(LogitsProcessor): + """ + [`LogitsProcessor`] that applies an additive bias on sequences. The bias is applied to the last token of a sequence + when the next generated token can complete it. 
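+    That is, the bias of a multi-token sequence is only added at the step where its final token could be
+    generated, i.e. once all of its preceding tokens already appear at the end of `input_ids`.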
Consequently, to take the most of biasing sequences with more than + one token, consider using beam methods (to gracefully work around partially completed sequences that have a + negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier). + + + + In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when + initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The + `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours + come from `pre tokenizers`. Read more [here](https://hf-mirror.com/docs/tokenizers/api/pre-tokenizers). + + + + Args: + sequence_bias (`Dict[Tuple[int], float]`): + Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the + sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias + will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be + completed (in the token selection step after this processor is applied). + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt") + + >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4) + >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0]) + The full name of Donald is Donald J. Trump Jr + + >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently! + >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True) + + + >>> def get_tokens_as_tuple(word): + ... return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]) + + + >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations + >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0} + >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias) + >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) + The full name of Donald is Donald J. Donald, + + >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias) + >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) + The full name of Donald is Donald Rumsfeld, + + >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations + >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0} + >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias) + >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) + The full name of Donald is Donald Duck. 
+ ``` + """ + + def __init__(self, sequence_bias: Dict[Tuple[int], float]): + self.sequence_bias = sequence_bias + self._validate_arguments() + + # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size + # is infered in the first usage, which inhibits initializing here) + self.length_1_bias = None + self.prepared_bias_variables = False + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + # 1 - Prepares the bias tensors. This is only needed the first time the logit processor is called. + if not self.prepared_bias_variables: + self._prepare_bias_variables(scores) + + # 2 - prepares an empty bias to add + bias = ops.zeros_like(scores) + + # 3 - include the bias from length = 1 + bias += self.length_1_bias + + # 4 - include the bias from length > 1, after determining which biased sequences may be completed. + for sequence_ids, sequence_bias in self.sequence_bias.items(): + if len(sequence_ids) == 1: # the sequence is of length 1, already applied + continue + if len(sequence_ids) > input_ids.shape[1]: # the sequence is longer than the context, ignore + continue + prefix_length = len(sequence_ids) - 1 + last_token = sequence_ids[-1] + matching_rows = ops.eq( + input_ids[:, -prefix_length:], + mindspore.tensor(sequence_ids[:-1], dtype=input_ids.dtype), + ).prod(dim=1) + bias[:, last_token] += ops.where( + matching_rows.bool(), + mindspore.tensor(sequence_bias), + mindspore.tensor(0.0), + ) + + # 5 - apply the bias to the scores + scores = scores + bias + return scores + +class AlternatingCodebooksLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] enforcing alternated generation between the two codebooks of [`Bark`]'s fine submodel. + + Args: + input_start_len (`int`): + The length of the initial input sequence. + semantic_vocab_size (`int`): + Vocabulary size of the semantic part, i.e number of tokens associated to the semantic vocabulary. + codebook_size (`int`): + Number of tokens associated to the codebook. + """ + + def __init__(self, input_start_len: int, semantic_vocab_size: int, codebook_size: int): + if not isinstance(input_start_len, int) or input_start_len < 0: + raise ValueError(f"`input_starting_length` has to be a non-negative integer, but is {input_start_len}") + + self.input_start_len = input_start_len + self.semantic_vocab_size = semantic_vocab_size + self.codebook_size = codebook_size + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + curr_len = input_ids.shape[-1] + + # even -> first codebook, odd -> second codebook + is_first_codebook = ((curr_len - self.input_start_len) % 2) == 0 + + if is_first_codebook: + scores[:, : self.semantic_vocab_size] = -float("inf") + scores[:, self.semantic_vocab_size + self.codebook_size :] = -float("inf") + else: + scores[:, : self.semantic_vocab_size + self.codebook_size] = -float("inf") + + return scores + +class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): + r"""Logits processor for Classifier-Free Guidance (CFG). The processors + computes a weighted average across scores from prompt conditional and prompt unconditional (or negative) logits, + parameterized by the `guidance_scale`. The unconditional scores are computed internally by prompting `model` with + the `unconditional_ids` branch. + + See [the paper](https://arxiv.org/abs/2306.17806) for more information. + + Args: + guidance_scale (`float`): + The guidance scale for classifier free guidance (CFG). 
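+            In `__call__` below, the processed scores are computed in log-probability space as
+            `guidance_scale * (cond_logprobs - uncond_logprobs) + uncond_logprobs`.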
CFG is enabled by setting `guidance_scale != 1`. + Higher guidance scale encourages the model to generate samples that are more closely linked to the input + prompt, usually at the expense of poorer quality. A value smaller than 1 has the opposite effect, while + making the negative prompt provided with negative_prompt_ids (if any) act as a positive prompt. + model (`PreTrainedModel`): + The model computing the unconditional scores. Supposedly the same as the one computing the conditional + scores. Both models must use the same tokenizer. + unconditional_ids (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to + the last token of the prompt. + unconditional_attention_mask (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): + Attention mask for unconditional_ids. + use_cache (`bool`, *optional*, defaults to `True`): + Whether to cache key/values during the negative prompt forward pass. + + + Examples: + + ```python + >>> from transformers import AutoTokenizer, AutoModelForCausalLM + + >>> model = AutoModelForCausalLM.from_pretrained("gpt2") + >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") + >>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt") + >>> out = model.generate(inputs["input_ids"], guidance_scale=1.5) + >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] + 'Today, a dragon flew over Paris, France, killing at least 50 people and injuring more than 100' + + >>> # with a negative prompt + >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt") + >>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"]) + >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] + 'Today, a dragon flew over Paris, France, killing at least 130 people. French media reported that' + + >>> # with a positive prompt + >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt") + >>> out = model.generate(inputs["input_ids"], guidance_scale=0, negative_prompt_ids=neg_inputs["input_ids"]) + >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] + "Today, a dragon flew over Paris, France, and I'm very happy to be here. 
I" + ``` + """ + + def __init__( + self, + guidance_scale: float, + model, + unconditional_ids: Optional[mindspore.Tensor] = None, + unconditional_attention_mask: Optional[mindspore.Tensor] = None, + use_cache: Optional[bool] = True, + ): + self.guidance_scale = guidance_scale + self.model = model + self.unconditional_context = { + "input_ids": unconditional_ids, + "attention_mask": unconditional_attention_mask, + "use_cache": use_cache, + "past_key_values": None, + "first_pass": True, + } + + def get_unconditional_logits(self, input_ids): + """get_unconditional_logits""" + if self.unconditional_context["first_pass"]: + if self.unconditional_context["input_ids"] is None: + self.unconditional_context["input_ids"] = input_ids[:, -1:] + if self.unconditional_context["attention_mask"] is None: + self.unconditional_context["attention_mask"] = ops.ones_like( + self.unconditional_context["input_ids"], dtype=mindspore.int64 + ) + input_ids = self.unconditional_context["input_ids"] + attention_mask = self.unconditional_context["attention_mask"] + self.unconditional_context["first_pass"] = False + else: + attention_mask = ops.cat( + [ + self.unconditional_context["attention_mask"], + ops.ones_like(input_ids[:, -1:], dtype=mindspore.int64), + ], + axis=1, + ) + if not self.unconditional_context["use_cache"]: + input_ids = ops.cat([self.unconditional_context["input_ids"], input_ids[:, -1:]], axis=1) + else: + input_ids = input_ids[:, -1:] + self.unconditional_context["input_ids"] = input_ids + self.unconditional_context["attention_mask"] = attention_mask + + out = self.model( + input_ids, + attention_mask=attention_mask, + use_cache=self.unconditional_context["use_cache"], + past_key_values=self.unconditional_context["past_key_values"], + ) + self.unconditional_context["past_key_values"] = out.get("past_key_values", None) + + return out.logits + + def __call__(self, input_ids, scores): + scores = ops.log_softmax(scores, axis=-1) + if self.guidance_scale == 1: + return scores + + logits = self.get_unconditional_logits(input_ids) + + unconditional_logits = ops.log_softmax(logits[:, -1], axis=-1) + out = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits + return out + +class WhisperTimeStampLogitsProcessor(LogitsProcessor): + r""" + + [`LogitsProcessor`] that modifies the logits for the generation of timestamps in the transcription. When the input + tokens are at a specific threshold, the processor sets the scores to negative infinity. The processor makes sure + that timestamp tokens appear in pairs, by masking out the logits that would break this pairing pattern. This is + done to maintain the consistency and structure of generated timestamps. It also ensures that when the predicted + probability of sampling any of the timestamp token is greater than any individual non-timestamp token, those + non-timestamp logits are set to negative infinity. This is done to ensure the generation of timestamps over other + potential tokens. + + + See [the paper](https://arxiv.org/abs/2212.04356) for more information. + + Args: + generate_config (`GenerateConfig`): + The generate config used to generate the output. The following parameters are required: + eos_token_id (`int`, *optional*, defaults to 50257): + The id of the *end-of-sequence* token. + no_timestamps_token_id (`int`, *optional*, defaults to 50363): + The id of the `"<|notimestamps|>"` token. + max_initial_timestamp_index (`int`, *optional*, defaults to 1): + Used to set the maximum value of the initial timestamp. 
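+            At the first decoding step, only timestamp tokens up to
+            `timestamp_begin + max_initial_timestamp_index` are allowed.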
This is used to prevent the model from + predicting timestamps that are too far in the future. + + Examples: + ``` python + >>> from transformers import AutoProcessor, WhisperForConditionalGeneration,GenerationConfig + >>> from datasets import load_dataset + + >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") + >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> inputs = processor(ds[3]["audio"]["array"], return_tensors="pt") + >>> input_features = inputs.input_features + + >>> #Displaying timestamps + >>> generated_ids = model.generate(inputs=input_features, return_timestamps=True) + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] + >>> print("Transcription:", transcription) + Transcription: <|startoftranscript|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can<|6.44|><|6.44|> discover in it but little of rocky Ithaca.<|9.44|><|endoftext|> + + + >>> #No timestamps & change EOS: + >>> #This allows the user to select a specific token to terminate the sequence on, in this case it's the word "can"(460) + >>> model.generation_config.eos_token_id = 460 + >>> generated_ids = model.generate(inputs=input_features,return_timestamps=False) + >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] + >>> print("Transcription:", transcription) + Transcription: He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can + ``` + """ + + def __init__(self, generate_config): # support for the kwargs + self.eos_token_id = generate_config.eos_token_id + self.no_timestamps_token_id = generate_config.no_timestamps_token_id + self.timestamp_begin = generate_config.no_timestamps_token_id + 1 + + self.begin_index = len(generate_config.forced_decoder_ids) + 2 + if generate_config.forced_decoder_ids[-1][1] == self.no_timestamps_token_id: + self.begin_index -= 1 + self.max_initial_timestamp_index = generate_config.max_initial_timestamp_index + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + # suppress <|notimestamps|> which is handled by without_timestamps + scores[:, self.no_timestamps_token_id] = -float("inf") + + if input_ids.shape[1] == self.begin_index - 1: + scores[:, :] = -float("inf") + scores[:, self.timestamp_begin] = 0 + return scores + + # timestamps have to appear in pairs, except directly before eos_token; mask logits accordingly + for k in range(input_ids.shape[0]): + seq = list(input_ids[k, self.begin_index :].tolist()) + last_was_timestamp = len(seq) >= 1 and seq[-1] >= self.timestamp_begin + penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= self.timestamp_begin + + if last_was_timestamp: + if penultimate_was_timestamp: # has to be non-timestamp + scores[k, self.timestamp_begin :] = -float("inf") + else: # cannot be normal text tokens + scores[k, : self.eos_token_id] = -float("inf") + + # apply the `max_initial_timestamp` option + if input_ids.shape[1] == self.begin_index and self.max_initial_timestamp_index is not None: + last_allowed = self.timestamp_begin + self.max_initial_timestamp_index + scores[:, last_allowed + 1 :] = -float("inf") + + # if sum of probability over timestamps is above any other token, sample timestamp + logprobs = ops.log_softmax(scores.float(), axis=-1) + for k in range(input_ids.shape[0]): + timestamp_logprob = 
logprobs[k, self.timestamp_begin :].logsumexp(axis=-1) + max_text_token_logprob = logprobs[k, : self.timestamp_begin].max() + if not ops.isnan(timestamp_logprob) and timestamp_logprob > max_text_token_logprob: + scores[k, : self.timestamp_begin] = -float("inf") + + return scores + +class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): + r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`. + + + + This logits processor is exclusively compatible with + [Bark](https://hf-mirror.com/docs/transformers/en/model_doc/bark). See the model documentation for examples. + + + + Args: + eos_token_id (`Union[int, List[int]]`): + The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. + min_eos_p (`float`, *optional*): + Minimum end of speech threshold. + """ + + def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): + if isinstance(eos_token_id, int): + eos_token_id = [eos_token_id] + self.eos_token_id = eos_token_id + if min_eos_p is not None and min_eos_p <= 0: + raise ValueError(f"`min_eos_p` has to be a positive float, but is {min_eos_p}") + self.min_eos_p = min_eos_p + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + if self.min_eos_p: + probs = ops.softmax(scores.float(), axis=-1) + # create scores full of -inf except for the eos_token_id + early_stop_scores = ops.ones_like(scores) * -float("inf") + early_stop_scores[:, self.eos_token_id] = scores[:, self.eos_token_id] + + do_early_stop = probs[:, self.eos_token_id] > self.min_eos_p + do_early_stop = ops.any(do_early_stop, axis=1, keep_dims=True) + scores = ops.where(do_early_stop, early_stop_scores, scores) + + return scores + + +class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): + r""" + [`LogitsProcessor`] for classifier free guidance (CFG). The scores are split over the batch dimension, + where the first half correspond to the conditional logits (predicted from the input prompt) and the second half + correspond to the unconditional logits (predicted from an empty or 'null' prompt). The processor computes a + weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`. + + See [the paper](https://arxiv.org/abs/2306.05284) for more information. + + + + This logits processor is exclusively compatible with + [MusicGen](https://hf-mirror.com/docs/transformers/main/en/model_doc/musicgen) + + + + Args: + guidance_scale (float): + The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. + Higher guidance scale encourages the model to generate samples that are more closely linked to the input + prompt, usually at the expense of poorer quality. + + Examples: + + ```python + >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration + + >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small") + >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") + + >>> inputs = processor( + ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], + ... padding=True, + ... return_tensors="pt", + ... 
) + >>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256) + ``` + """ + + def __init__(self, guidance_scale): + if guidance_scale > 1: + self.guidance_scale = guidance_scale + else: + raise ValueError( + "Require guidance scale >1 to use the classifier free guidance processor, got guidance scale " + f"{guidance_scale}." + ) + + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor: + # simple check to make sure we have compatible batch sizes between our + # logits scores (cond + uncond) and input ids (cond only) + if scores.shape[0] != 2 * input_ids.shape[0]: + raise ValueError( + f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to " + f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got " + f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids." + ) + unguided_bsz = scores.shape[0] // 2 + cond_logits, uncond_logits = scores.split(unguided_bsz, axis=0) + scores_processed = uncond_logits + (cond_logits - uncond_logits) * self.guidance_scale + return scores_processed diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py new file mode 100644 index 000000000..1cf86cf5a --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py @@ -0,0 +1,692 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""progen""" +import math +import time +from typing import Tuple +from collections import OrderedDict + +import numpy as np +import mindspore as ms +from mindspore import Tensor, nn, ops, Parameter +from mindspore.nn import CrossEntropyLoss + +from .module.injection import EmbeddingMindnlp, DenseMindnlp, LayerNormMindnlp +from .module.configuration_utils import BaseModelOutputWithPast, CausalLMOutputWithPast, PreTrainedModelMindnlp, PretrainedConfig + + +class ClassInstantier(OrderedDict): + r""" + Class Instantier + """ + + def __getitem__(self, key): + content = super().__getitem__(key) + cls, kwargs = content if isinstance(content, tuple) else (content, {}) + return cls(**kwargs) + + +ACT2CLS = { + """ + Excitation equation matrix + """ + 'relu': nn.ReLU, + 'gelu': (nn.GELU, {"approximate": False}), + 'gelu_new': nn.GELU, + 'gelu_approximate': nn.GELU, + 'gelu_pytorch_tanh': nn.GELU, + "swish": nn.SiLU, # MindSpore的SiLU激活函数是Swish函数 + "gelu_10": nn.GELU, # MindSpore的GELU激活函数不支持设置最大值和最小值 + "gelu_fast": nn.FastGelu, + "gelu_python": nn.GELU, # MindSpore的GELU激活函数不支持选择是否使用Python实现 + "linear": nn.ReLU, # MindSpore没有Linear激活函数,使用ReLU代替 + "mish": nn.Mish, + "quick_gelu": nn.FastGelu, + "relu": nn.ReLU, + "relu6": nn.ReLU6, + "sigmoid": nn.Sigmoid, + "silu": nn.SiLU, + "tanh": nn.Tanh, +} +ACT2FN = ClassInstantier(ACT2CLS) + +def fixed_pos_embedding(x, seq_dim=1, seq_len=None): + dim = x.shape[-1] + if seq_len is None: + seq_len = x.shape[seq_dim] + inv_freq = 1.0 / (10000 ** (ops.arange(0, dim, 2.0) / (dim * 1.0))) + sinusoid_inp = ops.einsum("i , j -> i j", ops.arange(seq_len), inv_freq).float() + return ops.sin(sinusoid_inp), ops.cos(sinusoid_inp) + + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = ops.stack((-x2, x1), axis=-1) + x_flatten = x.flatten(start_dim=-2) + return x_flatten # in einsum notation: rearrange(x, '... d j -> ... 
(d j)') + + +def apply_rotary_pos_emb(x, sincos, offset=0): + sin, cos = map(lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3), sincos) + # einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2) + return (x * cos) + (rotate_every_two(x) * sin) + + +class ProGenConfig(PretrainedConfig): + model_type = "progen" + + def __init__( + self, vocab_size, n_positions, n_ctx, n_embd, n_layer, n_head, rotary_dim, + n_inner, activation_function, resid_pdrop, embd_pdrop, attn_pdrop, + layer_norm_epsilon, initializer_range, scale_attn_weights, gradient_checkpointing, + use_cache, bos_token_id, eos_token_id, min_length, **kwargs, + ): + super().__init__(**kwargs) + self.vocab_size = vocab_size + self.n_ctx = n_ctx + self.n_positions = n_positions + self.n_embd = n_embd + self.n_layer = n_layer + self.n_head = n_head + self.n_inner = n_inner + self.rotary_dim = rotary_dim + self.activation_function = activation_function + self.resid_pdrop = resid_pdrop + self.embd_pdrop = embd_pdrop + self.attn_pdrop = attn_pdrop + self.layer_norm_epsilon = layer_norm_epsilon + self.initializer_range = initializer_range + self.gradient_checkpointing = gradient_checkpointing + self.scale_attn_weights = scale_attn_weights + self.use_cache = use_cache + self.bos_token_id = bos_token_id + self.eos_token_id = eos_token_id + self.min_length = min_length + + @property + def max_position_embeddings(self): + return self.n_positions + + @property + def hidden_size(self): + return self.n_embd + + @property + def num_attention_heads(self): + return self.n_head + + @property + def num_hidden_layers(self): + return self.n_layer + + +class NewGELUActivation(nn.Cell): + """ + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 + """ + def construct(self, inputs: Tensor) -> Tensor: + return 0.5 * inputs * (1.0 + ops.tanh(math.sqrt(2.0 / math.pi) * (inputs + 0.044715 * ops.pow(inputs, 3.0)))) + + +class ProGenAttention(nn.Cell): + def __init__(self, config): + super().__init__() + + max_positions = config.max_position_embeddings + temp_tri = ops.tril( + ops.ones((max_positions, max_positions), dtype=ms.int32)).view( + 1, 1, max_positions, max_positions + ) + self.bias = Parameter( + Tensor(temp_tri, dtype=ms.bool_), + name="bias", + requires_grad=False, + ) + + self.masked_bias = Parameter( + Tensor(-1e9), + name="masked_bias", + requires_grad=False, + ) + + self.attn_dropout = nn.Dropout(config.attn_pdrop) + self.resid_dropout = nn.Dropout(config.resid_pdrop) + + self.embed_dim = config.hidden_size + self.num_attention_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_attention_heads + if self.head_dim * self.num_attention_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_attention_heads (got `embed_dim`:\ + {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})." 
+ ) + self.scale_attn = ops.sqrt(Tensor(self.head_dim, dtype=ms.float32)).to(ms.float16) + self.qkv_proj = DenseMindnlp(self.embed_dim, self.embed_dim * 3, has_bias=False) + + self.out_proj = DenseMindnlp(self.embed_dim, self.embed_dim, has_bias=False) + self.rotary_dim = None + if config.rotary_dim is not None: + self.rotary_dim = config.rotary_dim + + + def construct( + self, + hidden_states, + use_cache=False, + output_attentions=False, + add_input=None, + ): + + layer_past, attention_mask, head_mask = add_input + qkv = self.qkv_proj(hidden_states) + mp_num = 8 + qkv_split = qkv.reshape(qkv.shape[:-1] + (mp_num, -1)) + + local_dim = self.head_dim * self.num_attention_heads // mp_num + query, value, key = ops.split(qkv_split, local_dim, axis=-1) + query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num) + key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num) + + value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num) + value = value.permute(0, 2, 1, 3) + + seq_len = key.shape[1] + offset = 0 + + if layer_past is not None: + offset = layer_past[0].shape[-2] + seq_len += offset + + if self.rotary_dim is not None: + k_rot = key[:, :, :, : self.rotary_dim] + k_pass = key[:, :, :, self.rotary_dim :] + + q_rot = query[:, :, :, : self.rotary_dim] + q_pass = query[:, :, :, self.rotary_dim :] + + sincos = fixed_pos_embedding(k_rot, 1, seq_len=seq_len) + k_rot = apply_rotary_pos_emb(k_rot, sincos, offset=offset) + q_rot = apply_rotary_pos_emb(q_rot, sincos, offset=offset) + + key = ops.cat((k_rot, k_pass), axis=-1) + query = ops.cat((q_rot, q_pass), axis=-1) + else: + sincos = fixed_pos_embedding(key, 1, seq_len=seq_len) + key = apply_rotary_pos_emb(key, sincos, offset=offset) + query = apply_rotary_pos_emb(query, sincos, offset=offset) + + key = key.permute(0, 2, 1, 3) + query = query.permute(0, 2, 1, 3) + + if layer_past is not None: + past_key = layer_past[0] + past_value = layer_past[1] + key = ops.cat((past_key, key), axis=-2) + value = ops.cat((past_value, value), axis=-2) + + if use_cache is True: + present = (key, value) + else: + present = None + + # compute self-attention: V x Softmax(QK^T) + attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask) + + attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim) + + attn_output = self.out_proj(attn_output) + attn_output = self.resid_dropout(attn_output) + + outputs = (attn_output, present) + if output_attentions: + outputs += (attn_weights,) + + return outputs # a, present, (attentions) + + def _split_heads(self, x, n_head, dim_head, mp_num): + reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head)) + reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:]) + return reshaped + + def _merge_heads(self, tensor, num_attention_heads, attn_head_size): + """ + Merges attn_head_size dim and num_attn_heads dim into n_ctx + """ + if len(tensor.shape) == 5: + tensor = tensor.permute(0, 1, 3, 2, 4).contiguous() + elif len(tensor.shape) == 4: + tensor = tensor.permute(0, 2, 1, 3).contiguous() + else: + raise ValueError(f"Input tensor rank should be one of [4, 5], but is: {len(tensor.shape)}") + new_shape = tensor.shape[:-2] + (num_attention_heads * attn_head_size,) + return tensor.view(new_shape) + + def _attn( + self, + query, + key, + value, + attention_mask=None, + head_mask=None, + ): + + # compute causal mask from causal mask buffer + query_length, key_length = 
query.shape[-2], key.shape[-2] + causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] + + # Keep the attention weights computation in fp32 to avoid overflow issues + query = Tensor(query, dtype=ms.float32) + key = Tensor(key, dtype=ms.float32) + + new_key = key.transpose(0, 1, 3, 2) + attn_weights = ops.matmul(query, new_key) + + attn_weights = attn_weights / self.scale_attn + attn_weights = ops.where(causal_mask, attn_weights, self.masked_bias) + + if attention_mask is not None: + # Apply the attention mask + attn_weights = attn_weights + attention_mask + + attn_weights = nn.Softmax(axis=-1)(attn_weights) + attn_weights = self.attn_dropout(attn_weights) + + # Mask heads if we want to + if head_mask is not None: + attn_weights = attn_weights * head_mask + + attn_output = ops.matmul(attn_weights, value) + + return attn_output, attn_weights + + +class ProGenMLP(nn.Cell): + def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim + super().__init__() + embed_dim = config.n_embd # n_embd=4096 + + self.fc_in = DenseMindnlp(embed_dim, intermediate_size) + self.fc_out = DenseMindnlp(intermediate_size, embed_dim) + + self.act = NewGELUActivation() + self.dropout = nn.Dropout(config.resid_pdrop) # config.resid_pdrop=0.0 + + def construct(self, hidden_states): + hidden_states = self.fc_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states = self.fc_out(hidden_states) + hidden_states = self.dropout(hidden_states) + return hidden_states + + +class ProGenBlock(nn.Cell): + def __init__(self, config): + super().__init__() + if config.n_inner is not None and config.n_inner != "None": + inner_dim = config.n_inner + else: + inner_dim = 4 * config.n_embd + self.ln_1 = LayerNormMindnlp(config.n_embd, epsilon=config.layer_norm_epsilon) + self.attn = ProGenAttention(config) + self.mlp = ProGenMLP(inner_dim, config) + + def construct( + self, + hidden_states, + use_cache=False, + output_attentions=False, + add_input=None, + ): + residual = hidden_states + hidden_states = self.ln_1(hidden_states) + attn_outputs = self.attn( + hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + add_input=add_input, + ) + attn_output = attn_outputs[0] # output_attn: a, present, (attentions) + outputs = attn_outputs[1:] + + feed_forward_hidden_states = self.mlp(hidden_states) + hidden_states = attn_output + feed_forward_hidden_states + residual + + if use_cache: + outputs = (hidden_states,) + outputs + else: + outputs = (hidden_states,) + outputs[1:] + + return outputs # hidden_states, present, (attentions) + + +class ProGenPreTrainedModel(PreTrainedModelMindnlp): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
+ """ + + config_class = ProGenConfig + base_model_prefix = "transformer" + is_parallelizable = True + + def __init__(self, *inputs, **kwargs): + super().__init__(*inputs, **kwargs) + + def _init_weights(self, module): + """Initialize the weights.""" + if isinstance(module, (DenseMindnlp,)): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, EmbeddingMindnlp): + module.weight.data.normal_(mean=0.0, std=self.config.initializer_range) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, LayerNormMindnlp): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + +class ProGenModel(ProGenPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + + self.embed_dim = config.n_embd + self.vocab_size = config.vocab_size + self.wte = EmbeddingMindnlp(config.vocab_size, self.embed_dim) + self.drop = nn.Dropout(config.embd_pdrop) + self.h = nn.CellList([ProGenBlock(config) for _ in range(config.n_layer)]) + self.ln_f = LayerNormMindnlp(self.embed_dim, epsilon=config.layer_norm_epsilon) + self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads) + + + def get_input_embeddings(self): + return self.wte + + def set_input_embeddings(self, new_embeddings): + self.wte = new_embeddings + + def construct( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") + elif input_ids is not None: + input_shape = input_ids.shape + input_ids = input_ids.view(-1, input_shape[-1]) + batch_size = input_ids.shape[0] + elif inputs_embeds is not None: + input_shape = inputs_embeds.shape[:-1] + batch_size = inputs_embeds.shape[0] + else: + raise ValueError("You have to specify either input_ids or inputs_embeds") + + # device = input_ids.device if input_ids is not None else inputs_embeds.device + + if token_type_ids is not None: + token_type_ids = token_type_ids.view(-1, input_shape[-1]) + + if position_ids is not None: + position_ids = position_ids.view(-1, input_shape[-1]) + + if past_key_values is None: + past_length = 0 + past_key_values = tuple([None] * len(self.h)) + else: + past_length = past_key_values[0][0].shape[-2] + + if position_ids is None: + position_ids = ops.arange(past_length, input_shape[-1] + past_length) + position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1]) + + # Attention mask. + if attention_mask is not None: + attention_mask = attention_mask.view(batch_size, -1) + # We create a 3D attention mask from a 2D tensor mask. 
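+            # e.g. a padding mask [[1, 1, 1, 0]] is reshaped and rescaled below into the additive
+            # bias [[[[0.0, 0.0, 0.0, -10000.0]]]] that is added to the raw attention scores.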
+ # Sizes are [batch_size, 1, 1, to_seq_length] + # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length] + # this attention mask is more simple than the triangular masking of causal attention + # used in OpenAI GPT, we just need to prepare the broadcast dimension here. + attention_mask = attention_mask[:, None, None, :] + + # Since attention_mask is 1.0 for positions we want to attend and 0.0 for + # masked positions, this operation will create a tensor which is 0.0 for + # positions we want to attend and -10000.0 for masked positions. + # Since we are adding it to the raw scores before the softmax, this is + # effectively the same as removing these entirely. + attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility + attention_mask = (1.0 - attention_mask) * -10000.0 + + # Prepare head mask if needed + # 1.0 in head_mask indicate we keep the head + # attention_probs has shape bsz x num_attention_heads x N x N + # head_mask has shape n_layer x batch x num_attention_heads x N x N + head_mask = self.get_head_mask(head_mask, self.config.n_layer) + + if inputs_embeds is None: + inputs_embeds = self.wte(input_ids) + + hidden_states = inputs_embeds + + if token_type_ids is not None: + token_type_embeds = self.wte(token_type_ids) + hidden_states = hidden_states + token_type_embeds + + hidden_states = self.drop(hidden_states) + + output_shape = input_shape + (hidden_states.shape[-1],) + + presents = () if use_cache else None + all_self_attentions = () if output_attentions else None + all_hidden_states = () if output_hidden_states else None + + for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)): + + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if getattr(self.config, "gradient_checkpointing", False) and self.training: + + if use_cache: + use_cache = False + + else: + add_input = (layer_past, attention_mask, head_mask[i]) + outputs = block( + hidden_states, + use_cache=use_cache, + output_attentions=output_attentions, + add_input=add_input + ) + + hidden_states = outputs[0] + if use_cache is True: + presents = presents + (outputs[1],) + + if output_attentions: + all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],) + + hidden_states = self.ln_f(hidden_states) + + hidden_states = hidden_states.view(*output_shape) + # Add last hidden state + if output_hidden_states: + all_hidden_states = all_hidden_states + (hidden_states,) + + if not return_dict: + return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None) + + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=presents, + hidden_states=all_hidden_states, + attentions=all_self_attentions, + ) + + +class ProGenForCausalLM(ProGenPreTrainedModel): + _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] + + def __init__(self, config): + super().__init__(config) + self.transformer = ProGenModel(config) + self.lm_head = DenseMindnlp(config.n_embd, config.vocab_size) + + @staticmethod + def _reorder_cache(past: Tuple[Tuple[Tensor]], beam_idx: Tensor) -> Tuple[Tuple[Tensor]]: + """ + This function is used to re-order the :obj:`past_key_values` cache if + :meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is + called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step. 
+ """ + return tuple( + tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past) + for layer_past in past + ) + + + def construct( + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, + ): + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + transformer_outputs = self.transformer( + input_ids, + past_key_values=past_key_values, + attention_mask=attention_mask, + token_type_ids=token_type_ids, + position_ids=position_ids, + head_mask=head_mask, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + hidden_states = transformer_outputs[0] + + lm_logits = self.lm_head(hidden_states).to(ms.float32) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = lm_logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + loss = loss_fct(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1).astype("int32")) + + loss = loss.to(hidden_states.dtype) + + if not return_dict: + output = (lm_logits,) + transformer_outputs[1:] + return ((loss,) + output) if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=lm_logits, + past_key_values=transformer_outputs.past_key_values, + hidden_states=transformer_outputs.hidden_states, + attentions=transformer_outputs.attentions, + ) + + + def get_output_embeddings(self): + return None + + def set_output_embeddings(self, new_embeddings): + return + + def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + token_type_ids = kwargs.get("token_type_ids", None) + # only last token for inputs_ids if past is defined in kwargs + if past_key_values: + input_ids = input_ids[:, -1].unsqueeze(-1) + if token_type_ids is not None: + token_type_ids = token_type_ids[:, -1].unsqueeze(-1) + + attention_mask = kwargs.get("attention_mask", None) + position_ids = kwargs.get("position_ids", None) + + if attention_mask is not None and position_ids is None: + # create position_ids on the fly for batch generation + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill(attention_mask == 0, 1) + if past_key_values: + position_ids = position_ids[:, -1].unsqueeze(-1) + else: + position_ids = None + return { + "input_ids": input_ids, + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "position_ids": position_ids, + "attention_mask": attention_mask, + "token_type_ids": token_type_ids, + } + + +class print_time: + def __init__(self, desc): + self.desc = desc + + def __enter__(self): + print(self.desc) + self.t = time.time() + + def __exit__(self, type, value, traceback): + print(f'{self.desc} took {time.time()-self.t:.02f}s') diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py new file mode 100644 index 000000000..0d982361d --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py @@ -0,0 +1,281 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation 
Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""progen""" +import os +import random +import argparse + +import mindspore as ms +from mindspore import ops, Tensor, load_checkpoint, load_param_into_net +from tokenizers import Tokenizer + +from .nn_arch import ProGenForCausalLM, ProGenConfig, print_time +from ..model import Model + + +class ProGen(Model): + """ + ProGen class + """ + name = "ProGen" + + def __init__(self, config): + self.args = config + self.config = ProGenConfig( + config.vocab_size, config.n_positions, config.n_ctx, config.n_embd, + config.n_layer, config.n_head, config.rotary_dim, config.n_inner, + config.activation_function, config.resid_pdrop, config.embd_pdrop, + config.attn_pdrop, config.layer_norm_epsilon, config.initializer_range, + config.scale_attn_weights, config.gradient_checkpointing, config.use_cache, + config.bos_token_id, config.eos_token_id, config.min_length, + ) + self.mixed_precision = False + self.network = ProGenForCausalLM(self.config) + self.checkpoint_url = "Local_Checkpoint_Used" + self.checkpoint_path = config.ckpt_dir + with print_time('loading tokenizer'): + self.tokenizer = self.create_tokenizer_custom(file=self.args.tokenizer_file) + self.set_env() + self.set_seed(self.args.rng_seed, deterministic=self.args.rng_deterministic) + + super().__init__(checkpoint_url=self.checkpoint_url, checkpoint_path=self.checkpoint_path, + network=self.network, mixed_precision=self.mixed_precision) + + def set_env(self): + os.environ['TOKENIZERS_PARALLELISM'] = 'false' + + def set_seed(self, seed, deterministic=True): + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + + def create_tokenizer_custom(self, file): + with open(file, 'r') as f: + return Tokenizer.from_str(f.read()) + + def cross_entropy(self, logits, target, reduction='mean'): + return ops.cross_entropy(input=logits, target=target, reduction=reduction) + + def log_likelihood(self, logits, target, reduction='mean'): + return -self.cross_entropy(logits.view(-1, logits.shape[-1]), target.view(-1), reduction=reduction) + + def log_likelihood_custom_1(self, logits, target, reduction='mean'): + return -ops.nll_loss(inputs=ops.log_softmax(logits, axis=1), target=target, reduction=reduction) + + def log_likelihood_custom_2(self, logits, target, reduction='mean'): + log_likelihood = 0.0 + n = logits.shape[0] + for i in range(n): + log_likelihood += ops.log_softmax(logits, axis=1)[i, target[i]] / (1. if reduction == 'sum' else n) + return log_likelihood + + def ce(self, tokens): + target = ms.tensor(self.tokenizer.encode(tokens).ids) + logits = self.network(target, labels=target).logits + # shift + logits = logits[:-1, ...] 
+ target = target[1:] + return self.cross_entropy(logits=logits, target=target).item() + + def ll(self, tokens, f, reduction): + target = Tensor(self.tokenizer.encode(tokens).ids) + logits = self.network(target, labels=target).logits + + # shift + logits = logits[:-1, ...] + target = target[1:] + + # remove terminals + bos_token, eos_token = 3, 4 + if target[-1] in [bos_token, eos_token]: + logits = logits[:-1, ...] + target = target[:-1] + + # remove unused logits + first_token, last_token = 5, 29 + logits = logits[:, first_token:(last_token+1)] + target = target - first_token + + return f(logits=logits, target=target, reduction=reduction).item() + + def from_pretrained(self, ckpt_path=None): + "from_pretrained" + self.get_checkpoint_path(ckpt_path) + if not ckpt_path: + param_dict = ms.load_checkpoint(self.checkpoint_path) + else: + param_dict = ms.load_checkpoint(ckpt_path) + param_not_load, _ = ms.load_param_into_net(self.network, param_dict) + print(f'param not load: {param_not_load}') + + def predict(self, data, **kwargs): + if self.args.sanity: + + with print_time('sanity cross-entropy'): + + x_uniref90bfd30 = self.args.x_uniref90bfd30 + x_oas = self.args.x_oas + x_bfd90 = self.args.x_bfd90 + + checkpoint_x_ce = { + 'progen2-small': (x_uniref90bfd30, 2.4), + 'progen2-medium': (x_uniref90bfd30, 1.9), + 'progen2-base': (x_uniref90bfd30, 1.9), + 'progen2-large': (x_uniref90bfd30, 1.8), + 'progen2-xlarge': (x_uniref90bfd30, 1.0), + 'progen2-oas': (x_oas, 0.3), + 'progen2-BFD90': (x_bfd90, 1.3), + } + + ce_eval = self.ce(checkpoint_x_ce[self.args.model][0]) + ce_target = checkpoint_x_ce[self.args.model][1] + + print(ce_target, ce_eval, abs(ce_eval - ce_target)) + + with print_time('sanity log-likelihood'): + + x_data = self.args.x_data + + ll_0 = self.ll(x_data, f=self.log_likelihood, reduction='mean') + ll_1 = self.ll(x_data, f=self.log_likelihood_custom_1, reduction='mean') + ll_2 = self.ll(x_data, f=self.log_likelihood_custom_2, reduction='mean') + + print(f'll_0={ll_0}') + print(f'll_1={ll_1}') + print(f'll_2={ll_2}') + + with print_time('sanity model'): + + alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] + x_data = self.args.x_data + x_random = '2' + ''.join([random.choice(alphabet) for _ in range(len(x_data)-2)]) + '1' + x_perturb = x_random[:64] + x_data[len(x_random[:64]):] + + print(x_data) + print(x_perturb) + print(x_random) + + ll_x_data = self.ll(x_data, f=self.log_likelihood, reduction='mean') + ll_x_random = self.ll(x_random, f=self.log_likelihood, reduction='mean') + ll_x_perturb = self.ll(x_perturb, f=self.log_likelihood, reduction='mean') + + print(f'll_x_data={ll_x_data}') + print(f'll_x_random={ll_x_random}') + print(f'll_x_perturb={ll_x_perturb}') + + with print_time('log-likelihood (left-to-right, right-to-left)'): + + reverse = lambda s: s[::-1] + + ll_lr_sum = self.ll(tokens=data, f=self.log_likelihood, reduction='sum') + ll_rl_sum = self.ll(tokens=reverse(data), f=self.log_likelihood, reduction='sum') + + ll_lr_mean = self.ll(tokens=data, f=self.log_likelihood, reduction='mean') + ll_rl_mean = self.ll(tokens=reverse(data), f=self.log_likelihood, reduction='mean') + + ll_sum = .5 * (ll_lr_sum + ll_rl_sum) + ll_mean = .5 * (ll_lr_mean + ll_rl_mean) + + print(f'll_sum={(ll_sum)}') + print(f'll_mean={ll_mean}') + + def sample(self, model, tokenizer, context, max_length, num_return_sequences, top_p, temp, pad_token_id): + input_ids = Tensor([self.tokenizer.encode(context).ids]) + 
tokens_batch = model.generate( + input_ids, + do_sample=True, + temperature=temp, + max_length=max_length, + top_p=top_p, + num_return_sequences=num_return_sequences, + pad_token_id=pad_token_id, + ) + as_lists = lambda batch: [batch[i, ...].asnumpy().tolist() for i in range(batch.shape[0])] + return self.tokenizer.decode_batch(as_lists(tokens_batch)) + + def truncate(self, input_sample, terminals): + pos = [] + for terminal in terminals: + find_pos = input_sample.find(terminal, 1) + if find_pos != -1: + pos.append(find_pos) + if len(pos) > 0: + return input_sample[:(min(pos) + 1)] + else: + return input_sample + + def generate(self): + if self.args.sanity: + with print_time('sanity cross-entropy'): + + x_uniref90bfd30 = self.args.x_uniref90bfd30 + x_oas = self.args.x_oas + x_bfd90 = self.args.x_bfd90 + + checkpoint_x_ce = { + 'progen2-small': (x_uniref90bfd30, 2.4), + 'progen2-medium': (x_uniref90bfd30, 1.9), + 'progen2-base': (x_uniref90bfd30, 1.9), + 'progen2-large': (x_uniref90bfd30, 1.8), + 'progen2-xlarge': (x_uniref90bfd30, 1.0), + 'progen2-oas': (x_oas, 0.3), + 'progen2-BFD90': (x_bfd90, 1.3), + } + + ce_eval = self.ce(checkpoint_x_ce.get(self.args.model)[0]) + ce_target = checkpoint_x_ce.get(self.args.model)[1] + + print(ce_target, ce_eval, abs(ce_eval - ce_target)) + + if abs(ce_eval - ce_target) >= 0.1: + raise ValueError("Difference should be within 0.1") + + with print_time('sampling'): + completions = self.sample(model=self.network, tokenizer=self.tokenizer, context=self.args.context, + pad_token_id=self.tokenizer.encode('<|pad|>').ids[0], num_return_sequences=self.args.num_samples, + temp=self.args.t, top_p=self.args.p, max_length=self.args.max_length) + truncations = [self.truncate(completion, terminals=['1', '2']) for completion in completions] + + for (i, truncation) in enumerate(truncations): + print() + print(i) + print(truncation) + + + def train_step(self, data): + return None + + + def forward(self, data): + return None + + + def backward(self, data): + return None + + + def _jit_forward(self, data): + return None + + + def _pynative_forward(self, data): + return None diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py new file mode 100644 index 000000000..53577e71d --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py @@ -0,0 +1,29 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""progen_configuration""" + +progen_configuration = { + "small": + "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/ProGen/small.yaml", +} + diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py new file mode 100644 index 000000000..aa6cf5939 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py @@ -0,0 +1,49 @@ +# Copyright 2024 @ Shenzhen Bay Laboratory & +# Peking University & +# Huawei Technologies Co., Ltd +# +# This code is a part of MindSPONGE: +# MindSpore Simulation Package tOwards Next Generation molecular modelling. +# +# MindSPONGE is open-source software based on the AI-framework: +# MindSpore (https://www.mindspore.cn/) +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""progen_dataset""" +from ...dataset import PSP + + +class ProGenDataSet(PSP): + "EquiDockDataSet" + def __init__(self, config): + self.config = config + super().__init__() + + def process(self, data, **kwargs): + return data + + def set_training_data_src(self, data_source, **kwargs): + return None + + def create_iterator(self, num_epochs, **kwargs): + return None + + def data_parse(self, idx): + return None + + def __getitem__(self, idx): + pass + + def __len__(self): + pass diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/small.yaml b/MindSPONGE/src/mindsponge/pipeline/models/progen/small.yaml new file mode 100644 index 000000000..7b40024c2 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/small.yaml @@ -0,0 +1,37 @@ +config: "small" +model: "progen2-small" +p: 0.95 +t: 0.2 +max_length: 8 +num_samples: 1 +context: "1" +vocab_size: 32 +n_positions: 1024 +n_ctx: 2048 +n_embd: 1024 +n_layer: 12 +n_head: 16 +rotary_dim: 32 +n_inner: None +activation_function: "gelu_new" +resid_pdrop: 1.0 +embd_pdrop: 1.0 +attn_pdrop: 1.0 +layer_norm_epsilon: 0.00001 +initializer_range: 0.02 +scale_attn_weights: True +gradient_checkpointing: False +use_cache: True +bos_token_id: 1 +eos_token_id: 2 +min_length: 1 +ckpt_dir: './progen2-small.ckpt' +tokenizer_file: './tokenizer.json' +rng_seed: 42 +rng_deterministic: True +fp16: True +sanity: True +x_uniref90bfd30: '2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1' +x_oas: '1EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMHWVRQAPWKGLEYVSAISSNGGSTYYANSVKGRFTISRDNSKNTLYLQMGSLRAEDMAVYYCARDESGYSYGWGYYFDYWGQGTLVTVSS2' +x_bfd90: '1TAPRSTRASGSEGSRPPGIPAKGRRCLPSRAGSVTPRFRHARQGTATVAKEQGRKLIASNRKARHDYHIEDTFEAGLVLTGTEVKSLRMGRASLIDGYAVFYGEELWLEGVHIPEYLNGNWTNHTPRRRRKLLLNRSELTKLAHKTSESGHTIVPLALYFKDGRAKVEIAVAKGKKAYDKRHALRERQDQREV2' +x_data: 
'2PAQGRARLAAHYGTGRIGREVTVDERCRNLDRLEPSWELLRLLDDMGFIEGQNGLRRYVAEVFALDEPYDMTWRLRSLDEPHEVNAIEFAAPHERVYATLSERFFPDSVERDLRELVTRSLVEVDLGDPFTPPFVNSVYELRGASRRWVGVVRDVLAPDVLPCDATIRVLADAGTRAATRGLREILDTESGRVCVLGLHAALDAIADDRNEVSTSVAVADLEQCVALREAIRQITPRGAISVLVKGPLRTSGMRAQIAAVVHLRAKSSHLLPGGTDVVTFGAREFAIRSAANERKVVASMRLLALPGFAERSLCGLARPGVGRGRWEPAINVSVAADRDQIDLRVMGADVGDASVIFLKRDFRKLTEEFWRTHTDVPIEREDVSAQRTEPDNRWRWLVPCDDLVAPRLTVVPPRSVGHGM1' diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json b/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json new file mode 100644 index 000000000..cd3d5a9d0 --- /dev/null +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json @@ -0,0 +1,91 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "special": true, + "content": "<|pad|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 1, + "special": true, + "content": "<|bos|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + }, + { + "id": 2, + "special": true, + "content": "<|eos|>", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false + } + ], + "normalizer": null, + "pre_tokenizer": { + "type": "ByteLevel", + "add_prefix_space": false, + "trim_offsets": true + }, + "post_processor": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true + }, + "decoder": { + "type": "ByteLevel", + "add_prefix_space": true, + "trim_offsets": true + }, + "model": { + "type": "BPE", + "dropout": null, + "unk_token": null, + "continuing_subword_prefix": null, + "end_of_word_suffix": null, + "fuse_unk": false, + "vocab": { + "<|pad|>": 0, + "<|bos|>": 1, + "<|eos|>": 2, + "1": 3, + "2": 4, + "A": 5, + "B": 6, + "C": 7, + "D": 8, + "E": 9, + "F": 10, + "G": 11, + "H": 12, + "I": 13, + "K": 14, + "L": 15, + "M": 16, + "N": 17, + "O": 18, + "P": 19, + "Q": 20, + "R": 21, + "S": 22, + "T": 23, + "U": 24, + "V": 25, + "W": 26, + "X": 27, + "Y": 28, + "Z": 29 + }, + "merges": [] + } +} \ No newline at end of file diff --git a/MindSPONGE/src/mindsponge/pipeline/pipeline.py b/MindSPONGE/src/mindsponge/pipeline/pipeline.py index 1de798558..631359dae 100644 --- a/MindSPONGE/src/mindsponge/pipeline/pipeline.py +++ b/MindSPONGE/src/mindsponge/pipeline/pipeline.py @@ -36,7 +36,7 @@ from .models import Multimer, MultimerDataSet, multimer_configuration from .models import ProteinMpnn, ProteinMpnnDataset, proteinmpnn_configuration from .models import UFold, UFoldDataSet, ufold_configuration from .models import EquiDock, EquiDockDataSet, equidock_configuration - +from .models import ProGen, ProGenDataSet, progen_configuration model_card = { "ColabDesign": {"model": COLABDESIGN, "dataset": ColabDesignDataSet, "config": colabdesign_configuration}, @@ -55,6 +55,7 @@ model_card = { "RASP": {"model": RASP, "dataset": RASPDataSet, "config": rasp_configuration}, "UFold": {"model": UFold, "dataset": UFoldDataSet, "config": ufold_configuration}, "EquiDock": {"model": EquiDock, "dataset": EquiDockDataSet, "config": equidock_configuration}, + "ProGen": {"model": ProGen, "dataset": ProGenDataSet, "config": progen_configuration}, } -- Gitee From c15947d6168c1e9d272667a949ae4027932e244b Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Tue, 3 Sep 2024 20:58:58 +0800 Subject: [PATCH 07/16] small change --- .../progen/module/configuration_utils.py | 77 +++++-------------- .../pipeline/models/progen/small.yaml | 37 --------- 2 files 
changed, 18 insertions(+), 96 deletions(-) delete mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen/small.yaml diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py index be9376f11..7534b640d 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py @@ -98,23 +98,10 @@ class CellUtilMixin: """ def get_head_mask( - self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False - ) -> Tensor: + self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False + ) -> Tensor: """ Prepare the head mask if needed. - - Args: - head_mask (`mindspore.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*): - The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard). - num_hidden_layers (`int`): - The number of hidden layers in the model. - is_attention_chunked: (`bool`, *optional*, defaults to `False`): - Whether or not the attentions scores are computed by chunks or not. - - Returns: - `mindspore.Tensor` with shape `[num_hidden_layers x batch x - num_heads x seq_length x seq_length]` or list with - `[None]` for each layer. """ if head_mask is not None: head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers) @@ -426,23 +413,6 @@ class ContrastiveSearchEncoderDecoderOutput(ModelOutputMindnlp): class ContrastiveSearchDecoderOnlyOutput(ModelOutputMindnlp): """ Base class for outputs of decoder-only generation models using contrastive search. - - Args: - sequences (`mindspore.Tensor` of shape `(batch_size, sequence_length)`): - The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter - if all batches finished early due to the `eos_token_id`. - scores (`tuple(mindspore.Tensor)` *optional*, returned when `output_scores=True` is passed or when - `config.output_scores=True`): - Processed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax) - at each generation step. Tuple of `mindspore.Tensor` with up to `max_new_tokens` elements (one element for - each generated token), with each tensor of shape `(batch_size, config.vocab_size)`. - attentions (`tuple(tuple(mindspore.Tensor))`, *optional*, returned when `output_attentions=True` is passed or `config.output_attentions=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `mindspore.Tensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`. - hidden_states (`tuple(tuple(mindspore.Tensor))`, *optional*, returned when `output_hidden_states=True` is - passed or when `config.output_hidden_states=True`): - Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of - `mindspore.Tensor` of shape `(batch_size, generated_length, hidden_size)`. """ sequences: mindspore.Tensor = None @@ -451,17 +421,6 @@ class ContrastiveSearchDecoderOnlyOutput(ModelOutputMindnlp): hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None -# @dataclass -# class GreedySearchDecoderOnlyOutput(ModelOutputMindnlp): -# """ -# Base class for outputs of decoder-only generation models using greedy search. 
-# """ - -# sequences: mindspore.Tensor = None -# scores: Optional[Tuple[mindspore.Tensor]] = None -# attentions: Optional[Tuple[Tuple[mindspore.Tensor]]] = None -# hidden_states: Optional[Tuple[Tuple[mindspore.Tensor]]] = None - GreedySearchOutput = Union[GreedySearchEncoderDecoderOutput, GreedySearchDecoderOnlyOutput] SampleOutput = Union[SampleEncoderDecoderOutput, SampleDecoderOnlyOutput] BeamSearchOutput = Union[BeamSearchEncoderDecoderOutput, BeamSearchDecoderOnlyOutput] @@ -507,7 +466,7 @@ class GenerationConfig: self.bos_token_id = kwargs.pop("bos_token_id", None) self.eos_token_id = kwargs.pop("eos_token_id", None) self._from_model_config = kwargs.pop("_from_model_config", False) - + # Additional attributes without default values if not self._from_model_config: # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a @@ -555,7 +514,7 @@ class GenerationConfig: return config - def set_from_model_config(self, value:bool): + def set_from_model_config(self, value: bool): """set _from_model_config""" if not isinstance(value, bool): raise TypeError @@ -598,11 +557,11 @@ class GenerationMixin: @staticmethod def _expand_inputs_for_generation( - expand_size: int = 1, - is_encoder_decoder: bool = False, - input_ids: Optional[mindspore.Tensor] = None, - **model_kwargs, - ) -> Tuple[mindspore.Tensor, Dict[str, Any]]: + expand_size: int = 1, + is_encoder_decoder: bool = False, + input_ids: Optional[mindspore.Tensor] = None, + **model_kwargs, + ) -> Tuple[mindspore.Tensor, Dict[str, Any]]: """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" def _expand_dict_for_generation(dict_to_expand): @@ -654,10 +613,10 @@ class GenerationMixin: @classmethod def _prepare_decoder_input_ids_for_generation( - cls, - model_input_name: str, - model_kwargs: Dict[str, mindspore.Tensor], - ) -> Tuple[mindspore.Tensor, Dict[str, mindspore.Tensor]]: + cls, + model_input_name: str, + model_kwargs: Dict[str, mindspore.Tensor], + ) -> Tuple[mindspore.Tensor, Dict[str, mindspore.Tensor]]: """Prepares `decoder_input_ids` for generation with encoder-decoder models""" # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. @@ -673,10 +632,10 @@ class GenerationMixin: @classmethod def _merge_criteria_processor_list( - cls, - default_list: Union[LogitsProcessorList, StoppingCriteriaList], - custom_list: Union[LogitsProcessorList, StoppingCriteriaList], - ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + cls, + default_list: Union[LogitsProcessorList, StoppingCriteriaList], + custom_list: Union[LogitsProcessorList, StoppingCriteriaList], + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: if len(custom_list) == 0: return default_list for default in default_list: @@ -1021,7 +980,7 @@ class GenerationMixin: else: if inputs is not None: raise ValueError("You passed `inputs_embeds` and `input_ids` to `.generate()`. Please pick one.") - inputs, input_name = model_kwargs[inputs_embeds_str], + inputs, input_name = model_kwargs[inputs_embeds_str][0], model_kwargs[inputs_embeds_str][1] # 4. 
if `inputs` is still None, try to create `input_ids` from BOS token inputs = self._maybe_initialize_input_ids_for_generation(inputs, bos_token_id, model_kwargs) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/small.yaml b/MindSPONGE/src/mindsponge/pipeline/models/progen/small.yaml deleted file mode 100644 index 7b40024c2..000000000 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/small.yaml +++ /dev/null @@ -1,37 +0,0 @@ -config: "small" -model: "progen2-small" -p: 0.95 -t: 0.2 -max_length: 8 -num_samples: 1 -context: "1" -vocab_size: 32 -n_positions: 1024 -n_ctx: 2048 -n_embd: 1024 -n_layer: 12 -n_head: 16 -rotary_dim: 32 -n_inner: None -activation_function: "gelu_new" -resid_pdrop: 1.0 -embd_pdrop: 1.0 -attn_pdrop: 1.0 -layer_norm_epsilon: 0.00001 -initializer_range: 0.02 -scale_attn_weights: True -gradient_checkpointing: False -use_cache: True -bos_token_id: 1 -eos_token_id: 2 -min_length: 1 -ckpt_dir: './progen2-small.ckpt' -tokenizer_file: './tokenizer.json' -rng_seed: 42 -rng_deterministic: True -fp16: True -sanity: True -x_uniref90bfd30: '2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1' -x_oas: '1EVQLVESGGGLVQPGGSLRLSCAASGFTFSSYAMHWVRQAPWKGLEYVSAISSNGGSTYYANSVKGRFTISRDNSKNTLYLQMGSLRAEDMAVYYCARDESGYSYGWGYYFDYWGQGTLVTVSS2' -x_bfd90: '1TAPRSTRASGSEGSRPPGIPAKGRRCLPSRAGSVTPRFRHARQGTATVAKEQGRKLIASNRKARHDYHIEDTFEAGLVLTGTEVKSLRMGRASLIDGYAVFYGEELWLEGVHIPEYLNGNWTNHTPRRRRKLLLNRSELTKLAHKTSESGHTIVPLALYFKDGRAKVEIAVAKGKKAYDKRHALRERQDQREV2' -x_data: '2PAQGRARLAAHYGTGRIGREVTVDERCRNLDRLEPSWELLRLLDDMGFIEGQNGLRRYVAEVFALDEPYDMTWRLRSLDEPHEVNAIEFAAPHERVYATLSERFFPDSVERDLRELVTRSLVEVDLGDPFTPPFVNSVYELRGASRRWVGVVRDVLAPDVLPCDATIRVLADAGTRAATRGLREILDTESGRVCVLGLHAALDAIADDRNEVSTSVAVADLEQCVALREAIRQITPRGAISVLVKGPLRTSGMRAQIAAVVHLRAKSSHLLPGGTDVVTFGAREFAIRSAANERKVVASMRLLALPGFAERSLCGLARPGVGRGRWEPAINVSVAADRDQIDLRVMGADVGDASVIFLKRDFRKLTEEFWRTHTDVPIEREDVSAQRTEPDNRWRWLVPCDDLVAPRLTVVPPRSVGHGM1' -- Gitee From 03941b863703a8ab55f20248b83afa5be3135b86 Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Wed, 4 Sep 2024 14:56:01 +0800 Subject: [PATCH 08/16] pr modification --- .../pipeline/models/progen/__init__.py | 1 - .../progen/module/configuration_utils.py | 918 +++--------------- .../models/progen/module/injection.py | 48 +- .../models/progen/module/logits_process.py | 373 +------ .../pipeline/models/progen/nn_arch.py | 127 ++- .../pipeline/models/progen/progen.py | 48 +- .../models/progen/progen_configuration.py | 1 - 7 files changed, 255 insertions(+), 1261 deletions(-) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py index 04a85bf5e..b22c10e4e 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py @@ -25,4 +25,3 @@ from .progen import ProGen from .progen_dataset import ProGenDataSet from .progen_configuration import progen_configuration - diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py index 7534b640d..92be9cf4a 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py @@ -26,7 +26,7 @@ import json import os import 
warnings import inspect -from typing import Optional, List, Callable, Dict, Any, Tuple, Union, Iterable +from typing import Optional, List, Callable, Dict, Any, Tuple, Union from enum import Enum from collections import OrderedDict, UserDict from dataclasses import fields @@ -37,32 +37,15 @@ import mindspore from mindspore import nn, ops, Tensor, Parameter, jit_class from .logits_process import ( - EncoderNoRepeatNGramLogitsProcessor, - EncoderRepetitionPenaltyLogitsProcessor, - EpsilonLogitsWarper, - EtaLogitsWarper, - ExponentialDecayLengthPenalty, - ForcedBOSTokenLogitsProcessor, ForcedEOSTokenLogitsProcessor, - ForceTokensLogitsProcessor, - HammingDiversityLogitsProcessor, - InfNanRemoveLogitsProcessor, - LogitNormalization, LogitsProcessorList, MinLengthLogitsProcessor, MinNewTokensLengthLogitsProcessor, NoBadWordsLogitsProcessor, - NoRepeatNGramLogitsProcessor, - PrefixConstrainedLogitsProcessor, - RepetitionPenaltyLogitsProcessor, - SequenceBiasLogitsProcessor, - SuppressTokensAtBeginLogitsProcessor, - SuppressTokensLogitsProcessor, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper, TypicalLogitsWarper, - UnbatchedClassifierFreeGuidanceLogitsProcessor, ) DEFAULT_DTYPE = mindspore.float32 @@ -99,7 +82,7 @@ class CellUtilMixin: def get_head_mask( self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False - ) -> Tensor: + ) -> Tensor: """ Prepare the head mask if needed. """ @@ -158,7 +141,7 @@ class ModelOutputMindnlp(OrderedDict): class_fields = fields(self) # Safety and consistency checks - if len(class_fields) == 0: + if not class_fields: raise ValueError(f"{self.__class__.__name__} has no fields.") if not all(field.default is None for field in class_fields[1:]): raise ValueError(f"{self.__class__.__name__} should not have more than one required field.") @@ -561,7 +544,7 @@ class GenerationMixin: is_encoder_decoder: bool = False, input_ids: Optional[mindspore.Tensor] = None, **model_kwargs, - ) -> Tuple[mindspore.Tensor, Dict[str, Any]]: + ) -> Tuple[mindspore.Tensor, Dict[str, Any]]: """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]""" def _expand_dict_for_generation(dict_to_expand): @@ -616,7 +599,7 @@ class GenerationMixin: cls, model_input_name: str, model_kwargs: Dict[str, mindspore.Tensor], - ) -> Tuple[mindspore.Tensor, Dict[str, mindspore.Tensor]]: + ) -> Tuple[mindspore.Tensor, Dict[str, mindspore.Tensor]]: """Prepares `decoder_input_ids` for generation with encoder-decoder models""" # 1. Check whether the user has defined `decoder_input_ids` manually. To facilitate in terms of input naming, # we also allow the user to pass it under `input_ids`, if the encoder does not use it as the main input. 
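The import cleanup in the hunk above keeps only the logits warpers that ProGen's sampling path actually uses (temperature, top-k, top-p, and typical decoding). As a rough, standalone illustration of what the temperature and nucleus (top-p) steps do to a logits vector before multinomial sampling, here is a minimal NumPy sketch; it is not part of this patch, and the function name warp_logits is invented for the example.

import numpy as np

def warp_logits(logits, temperature=0.2, top_p=0.95, filter_value=-1e9):
    # Temperature scaling: values below 1.0 sharpen the distribution.
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    # Softmax over the scaled logits.
    probs = np.exp(scaled - scaled.max())
    probs /= probs.sum()
    # Keep the smallest set of tokens whose cumulative probability
    # reaches top_p; mask everything else with a large negative value.
    order = np.argsort(-probs)
    cumulative = np.cumsum(probs[order])
    keep = np.searchsorted(cumulative, top_p) + 1   # always keep >= 1 token
    warped = scaled.copy()
    warped[order[keep:]] = filter_value
    return warped

# Toy 6-token vocabulary: only the high-probability head survives the filter.
print(warp_logits([2.0, 1.5, 0.3, -1.0, -2.0, -3.0]))

With settings like those used earlier in this change set (t=0.2, p=0.95), only a handful of high-probability tokens typically remain eligible at each generation step; the warpers imported above implement the same idea on MindSpore tensors inside _get_logits_warper.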
@@ -635,8 +618,8 @@ class GenerationMixin: cls, default_list: Union[LogitsProcessorList, StoppingCriteriaList], custom_list: Union[LogitsProcessorList, StoppingCriteriaList], - ) -> Union[LogitsProcessorList, StoppingCriteriaList]: - if len(custom_list) == 0: + ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + if not custom_list: return default_list for default in default_list: for custom in custom_list: @@ -653,8 +636,8 @@ class GenerationMixin: return default_list def _get_logits_warper( - self, - generation_config: GenerationConfig, + self, + generation_config: GenerationConfig, ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsWarper`] instances @@ -689,13 +672,13 @@ class GenerationMixin: return warpers def generate( - self, - inputs: Optional[mindspore.Tensor] = None, - generation_config: Optional[GenerationConfig] = None, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - streamer: Optional["BaseStreamer"] = None, - **kwargs, + self, + inputs: Optional[mindspore.Tensor] = None, + generation_config: Optional[GenerationConfig] = None, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + streamer: Optional["BaseStreamer"] = None, + **kwargs, ) -> Union[GreedySearchDecoderOnlyOutput, mindspore.Tensor]: # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call @@ -768,27 +751,27 @@ class GenerationMixin: ) def sample( - self, - input_ids: mindspore.Tensor, - logits_processor: Optional[LogitsProcessorList] = None, - stopping_criteria: Optional[StoppingCriteriaList] = None, - logits_warper: Optional[LogitsProcessorList] = None, - max_length: Optional[int] = None, - pad_token_id: Optional[int] = None, - eos_token_id: Optional[Union[int, List[int]]] = None, - output_attentions: Optional[bool] = None, - output_hidden_states: Optional[bool] = None, - output_scores: Optional[bool] = None, - return_dict_in_generate: Optional[bool] = None, - streamer: Optional["BaseStreamer"] = None, - **model_kwargs, + self, + input_ids: mindspore.Tensor, + logits_processor: Optional[LogitsProcessorList] = None, + stopping_criteria: Optional[StoppingCriteriaList] = None, + logits_warper: Optional[LogitsProcessorList] = None, + max_length: Optional[int] = None, + pad_token_id: Optional[int] = None, + eos_token_id: Optional[Union[int, List[int]]] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + output_scores: Optional[bool] = None, + return_dict_in_generate: Optional[bool] = None, + streamer: Optional["BaseStreamer"] = None, + **model_kwargs, ) -> Union[SampleOutput, mindspore.Tensor]: r""" Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models. 
""" # init values - synced_gpus=None + synced_gpus = None # init values logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList() stopping_criteria = stopping_criteria if stopping_criteria is not None else StoppingCriteriaList() @@ -929,10 +912,10 @@ class GenerationMixin: return input_ids def _prepare_model_inputs( - self, - inputs: Optional[mindspore.Tensor] = None, - bos_token_id: Optional[int] = None, - model_kwargs: Optional[Dict[str, mindspore.Tensor]] = None, + self, + inputs: Optional[mindspore.Tensor] = None, + bos_token_id: Optional[int] = None, + model_kwargs: Optional[Dict[str, mindspore.Tensor]] = None, ) -> Tuple[mindspore.Tensor, Optional[str], Dict[str, mindspore.Tensor]]: """ This function extracts the model-specific `inputs` for generation. @@ -1002,11 +985,11 @@ class GenerationMixin: return past_key_values def _update_model_kwargs_for_generation( - self, - outputs, - model_kwargs: Dict[str, Any], - is_encoder_decoder: bool = False, - standardize_cache_format: bool = False, + self, + outputs, + model_kwargs: Dict[str, Any], + is_encoder_decoder: bool = False, + standardize_cache_format: bool = False, ) -> Dict[str, Any]: # update past_key_values model_kwargs["past_key_values"] = self._extract_past_from_model_output( @@ -1028,10 +1011,10 @@ class GenerationMixin: ) def _get_logits_processor( - self, - generation_config: GenerationConfig, - input_ids_seq_length: int, - logits_processor: Optional[LogitsProcessorList], + self, + generation_config: GenerationConfig, + input_ids_seq_length: int, + logits_processor: Optional[LogitsProcessorList], ) -> LogitsProcessorList: """ This class returns a [`LogitsProcessorList`] list object that contains all relevant [`LogitsProcessor`] @@ -1045,15 +1028,15 @@ class GenerationMixin: NoBadWordsLogitsProcessor(generation_config.bad_words_ids, generation_config.eos_token_id) ) if ( - generation_config.min_length is not None - and generation_config.eos_token_id is not None - and generation_config.min_length > 0 + generation_config.min_length is not None + and generation_config.eos_token_id is not None + and generation_config.min_length > 0 ): processors.append(MinLengthLogitsProcessor(generation_config.min_length, generation_config.eos_token_id)) if ( - generation_config.min_new_tokens is not None - and generation_config.eos_token_id is not None - and generation_config.min_new_tokens > 0 + generation_config.min_new_tokens is not None + and generation_config.eos_token_id is not None + and generation_config.min_new_tokens > 0 ): processors.append( MinNewTokensLengthLogitsProcessor( @@ -1071,7 +1054,7 @@ class GenerationMixin: return processors def _get_stopping_criteria( - self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] + self, generation_config: GenerationConfig, stopping_criteria: Optional[StoppingCriteriaList] ) -> StoppingCriteriaList: criteria = StoppingCriteriaList() if generation_config.max_length is not None: @@ -1182,72 +1165,24 @@ class PeftAdapterMixin: _hf_peft_config_loaded = False def load_adapter( - self, - peft_model_id: Optional[str] = None, - adapter_name: Optional[str] = None, - revision: Optional[str] = None, - token: Optional[str] = None, - device_map: Optional[str] = "auto", - max_memory: Optional[str] = None, - offload_folder: Optional[str] = None, - offload_index: Optional[int] = None, - peft_config: Dict[str, Any] = None, - adapter_state_dict: Optional[Dict[str, "mindspore.Tensor"]] = None, - adapter_kwargs: Optional[Dict[str, Any]] = 
None, + self, + peft_model_id: Optional[str] = None, + adapter_name: Optional[str] = None, + revision: Optional[str] = None, + token: Optional[str] = None, + device_map: Optional[str] = "auto", + max_memory: Optional[str] = None, + offload_folder: Optional[str] = None, + offload_index: Optional[int] = None, + peft_config: Dict[str, Any] = None, + adapter_state_dict: Optional[Dict[str, "mindspore.Tensor"]] = None, + adapter_kwargs: Optional[Dict[str, Any]] = None, ) -> None: """ Load adapter weights from file or remote Hub folder. If you are not familiar with adapters and PEFT methods, we invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft Requires peft as a backend to load the adapter weights. - - Args: - peft_model_id (`str`, *optional*): - The identifier of the model to look for on the Hub, or a local path to the saved adapter config file - and adapter weights. - adapter_name (`str`, *optional*): - The adapter name to use. If not set, will use the default adapter. - revision (`str`, *optional*, defaults to `"main"`): - The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a - git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any - identifier allowed by git. - - - - To test a pull request you made on the Hub, you can pass `revision="refs/pr/". - - - - token (`str`, `optional`): - Whether to use authentication token to load the remote folder. Userful to load private repositories - that are on HuggingFace Hub. You might need to call `huggingface-cli login` and paste your tokens to - cache it. - device_map (`str` or `Dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*): - A map that specifies where each submodule should go. It doesn't need to be refined to each - parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the - same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank - like `1`) on which the model will be allocated, the device map will map the entire model to this - device. Passing `device_map = 0` means put the whole model on GPU 0. - - To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For - more information about each option see [designing a device - map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map). - max_memory (`Dict`, *optional*): - A dictionary device identifier to maximum memory. Will default to the maximum memory available for each - GPU and the available CPU RAM if unset. - offload_folder (`str` or `os.PathLike`, `optional`): - If the `device_map` contains any value `"disk"`, the folder where we will offload weights. - offload_index (`int`, `optional`): - `offload_index` argument to be passed to `accelerate.dispatch_model` method. - peft_config (`Dict[str, Any]`, *optional*): - The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts - methods. This argument is used in case users directly pass PEFT state dicts - adapter_state_dict (`Dict[str, mindspore.Tensor]`, *optional*): - The state dict of the adapter to load. This argument is used in case users directly pass PEFT state - dicts - adapter_kwargs (`Dict[str, Any]`, *optional*): - Additional keyword arguments passed along to the `from_pretrained` method of the adapter config and - `find_adapter_config_file` method. 
""" adapter_name = adapter_name if adapter_name is not None else "default" @@ -1329,9 +1264,9 @@ class PeftAdapterMixin: # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk. if ( - (getattr(self, "hf_device_map", None) is not None) - and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0) - and len(self.peft_config) == 1 + (getattr(self, "hf_device_map", None) is not None) + and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0) + and len(self.peft_config) == 1 ): self._dispatch_accelerate_model( device_map=device_map, @@ -1348,13 +1283,6 @@ class PeftAdapterMixin: Adds a fresh new adapter to the current model for training purpose. If no adapter name is passed, a default name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the default adapter name). - - Args: - adapter_config (`~peft.PeftConfig`): - The configuration of the adapter to add, supported adapters are non-prefix tuning and adaption prompts - methods - adapter_name (`str`, *optional*, defaults to `"default"`): - The name of the adapter to add. If no name is passed, a default name is assigned to the adapter. """ from ...peft import PeftConfig, inject_adapter_in_model @@ -1383,10 +1311,6 @@ class PeftAdapterMixin: official documentation: https://huggingface.co/docs/peft Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. - - Args: - adapter_name (`Union[List[str], str]`): - The name of the adapter to set. Can be also a list of strings to set multiple adapters. """ if not self._hf_peft_config_loaded: raise ValueError("No adapter loaded. Please load an adapter first.") @@ -1405,18 +1329,18 @@ class PeftAdapterMixin: from ...peft.tuners.tuners_utils import BaseTunerLayer from ...peft.utils import ModulesToSaveWrapper - _adapters_has_been_set = False + adapters_has_been_set = False for _, module in self.named_modules(): if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): - # For backward compatbility with previous PEFT versions + # For backward compatibility with previous PEFT versions if hasattr(module, "set_adapter"): module.set_adapter(adapter_name) else: module.active_adapter = adapter_name - _adapters_has_been_set = True + adapters_has_been_set = True - if not _adapters_has_been_set: + if not adapters_has_been_set: raise ValueError( "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." ) @@ -1503,10 +1427,6 @@ class PeftAdapterMixin: Gets the adapter state dict that should only contain the weights tensors of the specified adapter_name adapter. If no adapter_name is passed, the active adapter is used. - - Args: - adapter_name (`str`, *optional*): - The name of the adapter to get the state dict from. If no name is passed, the active adapter is used. """ if not self._hf_peft_config_loaded: raise ValueError("No adapter loaded. Please load an adapter first.") @@ -1694,17 +1614,6 @@ class PretrainedConfig: def from_dict(cls, config_dict: Dict, **kwargs) -> "PretrainedConfig": """ Constructs a `Config` from a Python dictionary of parameters. - - Args: - config_dict (:obj:`Dict[str, any]`): - Dictionary that will be used to instantiate the configuration object. Such a dictionary can be retrieved - from a pre-trained checkpoint by leveraging the :func:`~transformers.PretrainedConfig.get_config_dict` - method. - kwargs (:obj:`Dict[str, any]`): - Additional parameters from which to initialize the configuration object. 
- - Returns: - :class:`PretrainedConfig`: An instance of a configuration object """ return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) @@ -1738,13 +1647,13 @@ class PretrainedConfig: @classmethod def from_pretrained( - cls, - pretrained_model_name_or_path: Union[str, os.PathLike], - cache_dir: Optional[Union[str, os.PathLike]] = None, - force_download: bool = False, - local_files_only: bool = False, - mirror: str = 'huggingface', - **kwargs, + cls, + pretrained_model_name_or_path: Union[str, os.PathLike], + cache_dir: Optional[Union[str, os.PathLike]] = None, + force_download: bool = False, + local_files_only: bool = False, + mirror: str = 'huggingface', + **kwargs, ) -> "PretrainedConfig": r""" Instantiate a [`PretrainedConfig`] (or a derived class) from a pretrained model configuration. @@ -1765,19 +1674,11 @@ class PretrainedConfig: @classmethod def get_config_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used for instantiating a [`PretrainedConfig`] using `from_dict`. - - Parameters: - pretrained_model_name_or_path (`str` or `os.PathLike`): - The identifier of the pre-trained checkpoint from which we want the dictionary of parameters. - - Returns: - `Tuple[Dict, Dict]`: The dictionary(ies) that will be used to instantiate the configuration object. - """ original_kwargs = copy.deepcopy(kwargs) # Get config dict associated with the base config file @@ -1794,7 +1695,7 @@ class PretrainedConfig: @classmethod def _get_config_dict( - cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs + cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs ) -> Tuple[Dict[str, Any], Dict[str, Any]]: """ From a `pretrained_model_name_or_path`, resolve to a dictionary of parameters, to be used @@ -1954,9 +1855,9 @@ class PretrainedConfig: # only serialize values that differ from the default config for key, value in config_dict.items(): if ( - isinstance(getattr(self, key, None), PretrainedConfig) - and key in class_config_dict - and isinstance(class_config_dict[key], dict) + isinstance(getattr(self, key, None), PretrainedConfig) + and key in class_config_dict + and isinstance(class_config_dict[key], dict) ): # For nested configs we need to clean the diff recursively diff = recursive_diff_dict(value, class_config_dict[key], config_obj=getattr(self, key, None)) @@ -1966,9 +1867,9 @@ class PretrainedConfig: if len(diff) > 0: serializable_config_dict[key] = diff elif ( - key not in default_config_dict - or value != default_config_dict[key] - or (key in class_config_dict and value != class_config_dict[key]) + key not in default_config_dict + or value != default_config_dict[key] + or (key in class_config_dict and value != class_config_dict[key]) ): serializable_config_dict[key] = value @@ -2101,10 +2002,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte def _from_config(cls, config, **kwargs): """ All context managers that the model should be initialized under go here. - - Args: - torch_dtype (`torch.dtype`, *optional*): - Override the default `torch.dtype` and load the model under this dtype. 
""" model = cls(config, **kwargs) @@ -2134,12 +2031,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte def prune_heads(self, heads_to_prune: Dict[int, List[int]]): """ Prunes heads of the base model. - - Arguments: - heads_to_prune (`Dict[int, List[int]]`): - Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads - to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on - layer 1 and heads 2 and 3 on layer 2. """ # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads for layer, heads in heads_to_prune.items(): @@ -2266,11 +2157,11 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte ) def tie_encoder_to_decoder_recursively( - decoder_pointer: nn.Cell, - encoder_pointer: nn.Cell, - module_name: str, - uninitialized_encoder_weights: List[str], - depth=0, + decoder_pointer: nn.Cell, + encoder_pointer: nn.Cell, + module_name: str, + uninitialized_encoder_weights: List[str], + depth=0, ): assert isinstance(decoder_pointer, nn.Cell) and isinstance( encoder_pointer, nn.Cell @@ -2299,7 +2190,7 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte encoder_name = str(int(name) + encoder_layer_pos) decoder_name = name if not isinstance(decoder_cells[decoder_name], type(encoder_cells[encoder_name])) and len( - encoder_cells + encoder_cells ) != len(decoder_cells): # this can happen if the name corresponds to the position in a list module list of layers # in this case the decoder has added a cross-attention that the encoder does not have @@ -2348,7 +2239,7 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte replace_references(output_embeddings.bias, Parameter(ops.pad( output_embeddings.bias.data, (0, output_embeddings.weight.shape[0] - - output_embeddings.bias.shape[0]), + output_embeddings.bias.shape[0]), "constant", 0, ), name=output_embeddings.bias.name, requires_grad=output_embeddings.bias.requires_grad)) @@ -2357,21 +2248,11 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte output_embeddings.out_channels = input_embeddings.vocab_size def resize_token_embeddings( - self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None + self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None ) -> nn.Embedding: """ Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`. - Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method. - - Arguments: - new_num_tokens (`int`, *optional*): - The number of new tokens in the embedding matrix. Increasing the size will add newly initialized - vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just - returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything. - - Return: - `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model. 
""" model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of) if new_num_tokens is None and pad_to_multiple_of is None: @@ -2415,37 +2296,30 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte return self.get_input_embeddings() def _get_resized_embeddings( - self, - old_embeddings: nn.Embedding, - new_num_tokens: Optional[int] = None, - pad_to_multiple_of: Optional[int] = None, + self, + old_embeddings: nn.Embedding, + new_num_tokens: Optional[int] = None, + pad_to_multiple_of: Optional[int] = None, ) -> nn.Embedding: """ Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly initialized vectors at the end Reducing the size will remove vectors from the end - - Args: - new_num_tokens: (`optional`) int - New number of tokens in the embedding matrix. - Increasing the size will add newly initialized vectors at the end - Reducing the size will remove vectors from the end - If not provided or None: return the provided token Embedding Module. - Return: ``mindspore.nn.Embeddings`` - Pointer to the resized Embedding Module or the old Embedding Module if new_num_tokens is None """ if pad_to_multiple_of is not None: if not isinstance(pad_to_multiple_of, int): raise ValueError( - f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer" + f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, " + "which is not and integer. Please make sure to pass an integer" ) if new_num_tokens is None: new_num_tokens = old_embeddings.weight.shape[0] new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of else: logger.info( - "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding" - f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available." - " For more details about this, or help on choosing the correct value for resizing, refer to this guide:" + "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. " + f"This means that the new embedding dimension will be {new_num_tokens}. This might induce " + "some performance reduction as *Tensor Cores* will not be available. For more details about " + "this, or help on choosing the correct value for resizing, refer to this guide:" " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc" ) @@ -2465,31 +2339,15 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte # Copy word embeddings from the previous weights num_tokens_to_copy = min(old_num_tokens, new_num_tokens) new_embeddings.weight.data[:num_tokens_to_copy, :] = old_embeddings.weight.data[ - :num_tokens_to_copy, :] + :num_tokens_to_copy, :] return new_embeddings def _get_resized_lm_head( - self, old_lm_head: nn.Dense, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False + self, old_lm_head: nn.Dense, new_num_tokens: Optional[int] = None, transposed: Optional[bool] = False ) -> nn.Dense: """ Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove vectors from the end - - Args: - old_lm_head (`nn.Dense`): - Old lm head liner layer to be resized. 
- new_num_tokens (`int`, *optional*): - New number of tokens in the linear matrix. - - Increasing the size will add newly initialized vectors at the end. Reducing the size will remove - vectors from the end. If not provided or `None`, just returns a pointer to the input tokens - `nn.Dense` module of the model without doing anything. transposed (`bool`, *optional*, defaults - to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim, - vocab_size` else `vocab_size, lm_head_dim`. - - Return: - `nn.Dense`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is - `None` """ if new_num_tokens is None: return old_lm_head @@ -2533,7 +2391,7 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte return new_lm_head def _copy_lm_head_original_to_resized( - self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias + self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias ): # Copy old lm head weights to new lm head if not transposed: @@ -2558,500 +2416,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte """ return cls.from_pretrained(pretrained_model_name_or_path, args, kwargs) - @classmethod - def from_pretrained( - cls, - pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], - *model_args, - config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None, - cache_dir: Optional[Union[str, os.PathLike]] = None, - ignore_mismatched_sizes: bool = False, - force_download: bool = False, - local_files_only: bool = False, - token: Optional[Union[str, bool]] = None, - use_safetensors: bool = None, - mirror: str = 'huggingface', - **kwargs, - ): - """from_pretrained""" - state_dict = kwargs.pop("state_dict", None) - cache_dir = kwargs.pop("cache_dir", None) - _ = kwargs.pop("from_pt", True) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", False) - _fast_init = kwargs.pop("_fast_init", True) - output_loading_info = kwargs.pop("output_loading_info", False) - subfolder = kwargs.pop("subfolder", "") - variant = kwargs.pop("variant", None) - ms_dtype = kwargs.pop("ms_dtype", None) - _ = kwargs.pop('low_cpu_mem_usage', None) - revision = kwargs.pop('revision', 'main') - - if use_safetensors is None and not is_safetensors_available(): - use_safetensors = False - - is_sharded = False - # Load config if we don't provide a configuration - if not isinstance(config, PretrainedConfig): - config_path = config if config is not None else pretrained_model_name_or_path - config, model_kwargs = cls.config_class.from_pretrained( - config_path, - *model_args, - cache_dir=cache_dir, - return_unused_kwargs=True, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - mirror=mirror, - **kwargs, - ) - else: - model_kwargs = kwargs - - # Load model - if pretrained_model_name_or_path is not None: - pretrained_model_name_or_path = str(pretrained_model_name_or_path) - is_local = os.path.isdir(pretrained_model_name_or_path) - if is_local: - if os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, PT_WEIGHTS_NAME) - ): - # Load from a PyTorch checkpoint - archive_file = os.path.join(pretrained_model_name_or_path, subfolder, PT_WEIGHTS_NAME) - elif os.path.isfile( - 
os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)) - ): - # Load from a MindSpore checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant) - ) - elif use_safetensors is not False and os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)) - ): - # Load from a safetensors checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant) - ) - elif os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(PT_WEIGHTS_INDEX_NAME, variant)) - ): - # Load from a sharded PyTorch checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(PT_WEIGHTS_INDEX_NAME, variant) - ) - is_sharded = True - elif os.path.isfile( - os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)) - ): - # Load from a sharded MindSpore checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant) - ) - is_sharded = True - elif use_safetensors is not False and os.path.isfile( - os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - ) - ): - # Load from a sharded safetensors checkpoint - archive_file = os.path.join( - pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant) - ) - is_sharded = True - # At this stage we don't have a weight file so we will raise an error. - elif use_safetensors: - raise EnvironmentError( - f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory" - f" {pretrained_model_name_or_path}." - ) - else: - raise EnvironmentError( - f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {PT_WEIGHTS_NAME}," - f" found in directory {pretrained_model_name_or_path}." - ) - elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)): - archive_file = pretrained_model_name_or_path - is_local = True - elif is_remote_url(pretrained_model_name_or_path): - filename = pretrained_model_name_or_path - resolved_archive_file = download_url(pretrained_model_name_or_path) - else: - if use_safetensors is not False: - filename = _add_variant(SAFE_WEIGHTS_NAME, variant) - else: - filename = _add_variant(WEIGHTS_NAME, variant) - - try: - # Load from URL or cache if already cached - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "proxies": proxies, - "resume_download": resume_download, - "local_files_only": local_files_only, - "subfolder": subfolder, - "_raise_exceptions_for_missing_entries": False, - 'revision': revision, - "token": token, - 'mirror': mirror - } - # try safetensors - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - use_safetensors = resolved_archive_file is not None - - # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None - # result when internet is up, the repo and revision exist, but the file does not. - if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant): - # Maybe the checkpoint is sharded, we try to grab the index name in this case. 
- resolved_archive_file = cached_file( - pretrained_model_name_or_path, - _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant), - **cached_file_kwargs, - ) - if resolved_archive_file is not None: - is_sharded = True - use_safetensors = True - - if resolved_archive_file is None: - filename = _add_variant(WEIGHTS_NAME, variant) - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - - if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant): - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( - pretrained_model_name_or_path, - _add_variant(WEIGHTS_INDEX_NAME, variant), - **cached_file_kwargs, - ) - if resolved_archive_file is not None: - is_sharded = True - - if resolved_archive_file is None: - filename = _add_variant(PT_WEIGHTS_NAME, variant) - resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs) - - if resolved_archive_file is None and filename == _add_variant(PT_WEIGHTS_NAME, variant): - # Maybe the checkpoint is sharded, we try to grab the index name in this case. - resolved_archive_file = cached_file( - pretrained_model_name_or_path, - _add_variant(PT_WEIGHTS_INDEX_NAME, variant), - **cached_file_kwargs, - ) - if resolved_archive_file is not None: - is_sharded = True - - if resolved_archive_file is None: - raise EnvironmentError( - f"{pretrained_model_name_or_path} does not appear to have a file named" - f" {_add_variant(SAFE_WEIGHTS_NAME, variant)}, {_add_variant(PT_WEIGHTS_NAME, variant)}" - ) - except EnvironmentError: - # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted - # to the original exception. - raise - except Exception as exc: - # For any other exception, we throw a generic error. - raise EnvironmentError( - f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it" - ", make sure you don't have a local directory with the" - f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a" - f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)}," - f" {_add_variant(PT_WEIGHTS_NAME, variant)}." - ) from exc - - if is_local: - logger.info(f"loading weights file {archive_file}") - resolved_archive_file = archive_file - else: - logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}") - else: - resolved_archive_file = None - - if is_sharded: - # rsolved_archive_file becomes a list of files that point to the different checkpoint shards in this case. - resolved_archive_file, sharded_metadata = get_checkpoint_shard_files( - pretrained_model_name_or_path, - resolved_archive_file, - cache_dir=cache_dir, - force_download=force_download, - proxies=proxies, - resume_download=resume_download, - local_files_only=local_files_only, - token=token, - subfolder=subfolder, - revision=revision, - mirror=mirror, - ) - - if pretrained_model_name_or_path is None and state_dict is None: - raise ValueError("the argument 'pretrained_model_name_or_path' should be " - "a string of model name or checkpoint path, but got 'None'.") - - config.name_or_path = pretrained_model_name_or_path - # Instantiate model. 
- - config_dict = config.to_dict() - - dtype_group = {key: getattr(config, key).ms_dtype for key in config_dict.keys() \ - if isinstance(config_dict[key], dict) and 'ms_dtype' in config_dict[key]} - - if ms_dtype is None or ms_dtype == 'auto': - ms_dtype = config.ms_dtype - - if ms_dtype is None: - ms_dtype = mindspore.float32 - - use_fp16 = False - usage_dtype = mindspore.dtype_to_nptype(ms_dtype) - if ms_dtype == mindspore.bfloat16: - ms_dtype = mindspore.float16 - usage_dtype = np.float16 - use_fp16 = True - - def empty_initializer(init, shape=None, dtype=mindspore.float32): - if not isinstance(shape, (tuple, list)): - shape = (shape,) - if dtype in (mindspore.float16, mindspore.float32) \ - and ms_dtype is not None: - dtype = ms_dtype - return Tensor_(shape=shape, dtype=dtype) - - with no_init_weights(empty_initializer, _fast_init): - model = cls(config, *model_args, **model_kwargs) - - if ms_dtype != mindspore.float32: - set_global_fp16(False) - - if is_sharded: - converted_filenames = resolved_archive_file - - # tie the model weights before retrieving the state_dict - model.tie_weights() - - ptrs = collections.defaultdict(list) - for name, tensor in model.parameters_dict().items(): - id_tensor = id(tensor) - ptrs[id_tensor].append(name) - - # These are all the pointers of shared tensors. - tied_params = [names for _, names in ptrs.items() if len(names) > 1] - def load_ckpt(resolved_archive_file): - if not resolved_archive_file.endswith('ckpt'): - if use_safetensors or 'safetensors' in resolved_archive_file: - from safetensors.numpy import load_file - origin_state_dict = load_file(resolved_archive_file) - if use_fp16: - logger.warning_once("MindSpore do not support bfloat16 dtype, we will automaticlly convert to float16") - state_dict = {k: Parameter(Tensor.from_numpy(v.astype(usage_dtype))) for k, v in origin_state_dict.items()} - else: - state_dict = load(resolved_archive_file) - else: - try: - state_dict = load_checkpoint(str(resolved_archive_file)) - except Exception as exc: - raise OSError( - f"Unable to load weights from mindspore checkpoint file '{resolved_archive_file}'. 
" - ) from exc - - state_keys = list(state_dict.keys()) - for key in state_keys: - new_key = key.replace('gamma', 'weight').replace('beta', 'bias').replace('embedding_table', 'weight') - if new_key != key: - state_dict[new_key] = state_dict.pop(key) - return state_dict - - keys_missing = list(model.parameters_dict().keys()) - param_id_set = set() - - use_keep_in_fp32_modules = False - if model._keep_in_fp32_modules: - use_keep_in_fp32_modules = True - - remove_prefix_from_model = None - add_prefix_to_model = None - - def fix_weight_norm_missing_keys(state_dict_keys: dict, keys_missing:List[str]) -> List[str]: - ''' if both `weight_g` and `weight_v` are loaded, key `weight` is not missing :) ''' - non_missing_keys = [] - for key in keys_missing: - if f'{key}_g' in state_dict_keys and f'{key}_v' in state_dict_keys: - non_missing_keys.append(key) - return non_missing_keys - - def load_param_into_net(model: nn.Cell, param_dict: dict, prefix: str, dtype_group: dict = None): - state_dict_keys = list(param_dict.keys()) - keep_in_fp32_modules = model._keep_in_fp32_modules - keys_unexpected = list(param_dict.keys()) - - has_prefix_module = any(s.startswith(prefix) for s in keys_unexpected) - expects_prefix_module = any(s.startswith(prefix) for s in keys_missing) - - nonlocal remove_prefix_from_model - nonlocal add_prefix_to_model - remove_prefix_from_model = not has_prefix_module and expects_prefix_module - add_prefix_to_model = has_prefix_module and not expects_prefix_module - - for pname_in_net, param in model.parameters_and_names(): - if add_prefix_to_model: - param_name = prefix + '.' + pname_in_net - elif remove_prefix_from_model: - param_name = pname_in_net.replace(f'{prefix}.', '') - else: - param_name = pname_in_net - - if param.uuid in param_id_set: - # for tied params - if param_name in keys_unexpected: - keys_unexpected.remove(param_name) - continue - - new_param = param_dict.pop(param_name, None) - - module_dtype = None - for m_name, m_dtype in dtype_group.items(): - if m_name in param_name: - module_dtype = m_dtype - break - - if new_param is not None: - use_replace = False - if new_param.shape != param.shape: - if not ignore_mismatched_sizes: - raise RuntimeError(f'The shape of parameter `{param.name} is {param.shape}, but got mismatch parameter' - f' `{param_name} with shape {new_param.shape} in checkpoint, ' - f'\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method.') - logger.warning(f'The shape of parameter `{param.name} is {param.shape}, but got mismatch parameter' - f' `{param_name} with shape {new_param.shape} in checkpoint, ') - continue - - if use_keep_in_fp32_modules and \ - any(module_to_keep_in_fp32 in pname_in_net.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules): - new_param = new_param.astype(mindspore.float32) - elif module_dtype and param.dtype in (mindspore.float32, mindspore.float16): - new_param = new_param.astype(module_dtype) - elif ms_dtype and param.dtype in (mindspore.float32, mindspore.float16): - new_param = new_param.astype(ms_dtype) - - if new_param.dtype != param.dtype or new_param.shape != param.shape: - use_replace = True - - if use_replace: - if isinstance(new_param, Parameter): - new_param.name = param.name - new_param.requires_grad = param.requires_grad - replace_references(param, new_param) - else: - replace_references(param, Parameter(new_param, requires_grad=param.requires_grad, name=param.name)) - else: - param.set_data(new_param) - keys_unexpected.remove(param_name) - 
keys_missing.remove(pname_in_net) - param_id_set.add(param.uuid) - else: - # fix missing value parameter dtype cast. - if ms_dtype and ms_dtype != param.dtype: - new_param = param.astype(ms_dtype) - replace_references(param, Parameter(new_param, name=param.name, requires_grad=param.requires_grad)) - - # NOTE: monkey patching weight_norm - for key in fix_weight_norm_missing_keys(state_dict_keys, keys_missing): - keys_missing.remove(key) - - return keys_unexpected, keys_missing - - all_keys_unexpected = None - if state_dict is None: - if is_sharded: - all_keys_unexpected = [] - for name in tqdm(converted_filenames, desc="Loading checkpoint shards"): - state_dict = load_ckpt(name) - keys_unexpected, keys_missing = load_param_into_net(model, state_dict, cls.base_model_prefix, dtype_group) - all_keys_unexpected.extend(keys_unexpected) - del state_dict - gc.collect() - loaded_keys = sharded_metadata["all_checkpoint_keys"] - else: - state_dict = load_ckpt(resolved_archive_file) - loaded_keys = list(state_dict.keys()) - all_keys_unexpected, keys_missing = load_param_into_net(model, state_dict, cls.base_model_prefix, dtype_group) - else: - loaded_keys = list(state_dict.keys()) - all_keys_unexpected, keys_missing = load_param_into_net(model, state_dict, cls.base_model_prefix, dtype_group) - - loaded_add_keys = [] - for group in tied_params: - missing_in_group = [k for k in keys_missing if k in group] - if len(missing_in_group) > 0 and len(missing_in_group) < len(group): - loaded_add_keys.extend([k for k in keys_missing if k in missing_in_group]) - keys_missing = [k for k in keys_missing if k not in missing_in_group] - if cls._keys_to_ignore_on_load_missing is not None: - for pat in cls._keys_to_ignore_on_load_missing: - keys_missing = [k for k in keys_missing if re.search(pat, k) is None] - - if cls._keys_to_ignore_on_load_unexpected is not None: - for pat in cls._keys_to_ignore_on_load_unexpected: - all_keys_unexpected = [k for k in all_keys_unexpected if re.search(pat, k) is None] - - # make sure token embedding weights are still tied if needed - model.tie_weights() - - # retrieve unintialized modules and initialize before maybe overriding that with the pretrained weights. - if _fast_init: - if not ignore_mismatched_sizes: - if remove_prefix_from_model: - _loaded_keys = [f"{cls.base_model_prefix}.{k}" for k in loaded_keys] - elif add_prefix_to_model: - _loaded_keys = [k[len(cls.base_model_prefix) + 1 :] for k in loaded_keys] - else: - _loaded_keys = loaded_keys - - _loaded_keys += loaded_add_keys - _ = set_initialized_submodules(model, _loaded_keys) - else: - _ = dict(model.cells_and_names()) - - model.apply(model._initialize_weights) - - # Set model in evaluation mode to deactivate DropOut modules by default - model.set_train(False) - - # If it is a model with generation capabilities, attempt to load the generation config - if model.can_generate() and pretrained_model_name_or_path is not None: - try: - model.generation_config = GenerationConfig.from_pretrained( - pretrained_model_name_or_path, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - subfolder=subfolder, - revision=revision, - **kwargs, - ) - except OSError: - logger.info( - "Generation config file not found, using a generation config created from the model config." 
- ) - - if output_loading_info: - loading_info = { - "missing_keys": keys_missing, - "unexpected_keys": all_keys_unexpected, - } - return model, loading_info - - if all_keys_unexpected: - logger.warning(f'The following parameters in checkpoint files are not loaded:\n' - f'{all_keys_unexpected}') - if keys_missing: - logger.warning(f'The following parameters in models are missing parameter:\n' - f'{keys_missing}') - return model - def save(self, save_dir): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the `:func:`PreTrainedModel.from_pretrained`` class method. @@ -3106,9 +2470,9 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an # attention_mask or not. In this case, we should still show a warning because this is a rare case. if ( - (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) - or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) - or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) + (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id) + or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id) + or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id) ): warn_string += ( f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical " @@ -3170,54 +2534,19 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte self.check_names() def save_pretrained( - self, - save_directory: Union[str, os.PathLike], - is_main_process: bool = True, - state_dict: Optional[dict] = None, - save_function: Callable = mindspore.save_checkpoint, - max_shard_size: Union[int, str] = "5GB", - safe_serialization: bool = True, - variant: Optional[str] = None, - **kwargs, + self, + save_directory: Union[str, os.PathLike], + is_main_process: bool = True, + state_dict: Optional[dict] = None, + save_function: Callable = mindspore.save_checkpoint, + max_shard_size: Union[int, str] = "5GB", + safe_serialization: bool = True, + variant: Optional[str] = None, + **kwargs, ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the [`~PreTrainedModel.from_pretrained`] class method. - - Arguments: - save_directory (`str` or `os.PathLike`): - Directory to which to save. Will be created if it doesn't exist. - is_main_process (`bool`, *optional*, defaults to `True`): - Whether the process calling this is the main process or not. Useful when in distributed training like - TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on - the main process to avoid race conditions. - state_dict (nested dictionary of `torch.Tensor`): - The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only - save parts of the model or if special precautions need to be taken when recovering the state dictionary - of a model (like when using model parallelism). - save_function (`Callable`): - The function to use to save the state dictionary. Useful on distributed training like TPUs when one - need to replace `torch.save` by another method. 
- max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`): - The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size - lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`). - We default it to 5GB in order for models to be able to run easily on free-tier google colab instances - without CPU OOM issues. - - - - If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard - which will be bigger than `max_shard_size`. - - - variant (`str`, *optional*): - If specified, weights are saved in the format pytorch_model..bin. - save_peft_format (`bool`, *optional*, defaults to `True`): - For backward compatibility with PEFT library, in case adapter weights are attached to the model, all - keys of the state dict of adapters needs to be pre-pended with `base_model.model`. Advanced users can - disable this behaviours by setting `save_peft_format` to `False`. - kwargs (`Dict[str, Any]`, *optional*): - Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method. """ if os.path.isfile(save_directory): @@ -3273,11 +2602,11 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte reg = re.compile(r"(.*?)-\d{5}-of-\d{5}") if ( - filename.startswith(weights_no_suffix) - and os.path.isfile(full_filename) - and filename not in shards - and is_main_process - and reg.fullmatch(filename_no_suffix) is not None + filename.startswith(weights_no_suffix) + and os.path.isfile(full_filename) + and filename not in shards + and is_main_process + and reg.fullmatch(filename_no_suffix) is not None ): os.remove(full_filename) @@ -3499,4 +2828,3 @@ def to_py_obj(obj): if isinstance(obj, np.number): return obj.tolist() return obj - diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py index d6d8a1c0f..4901c9f55 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py @@ -24,7 +24,7 @@ Injection mindspore.nn for MindNLP """ import operator -from functools import reduce, partial +from functools import reduce import math from packaging import version import numpy as np @@ -34,12 +34,8 @@ from mindspore import nn, ops, Tensor, Parameter from mindspore.common._stub_tensor import StubTensor from mindspore.common.initializer import ( initializer, - Constant, - HeNormal, - XavierNormal, Normal, HeUniform, - XavierUniform, Uniform, _calculate_fan_in_and_fan_out ) @@ -149,25 +145,25 @@ def _get_unflatten_size(input_shape, dim, sizes): if not isinstance(sizes, (tuple, list)): raise TypeError(f"Type of `sizes` should be `Tuple` or `List`, but got {type(sizes)}") - if len(sizes) == 0: + if not sizes: raise ValueError("`sizes` must be non-empty") if isinstance(dim, str): raise TypeError("Until Now, `dim` not support type of str in `unflatten`") - _dim = dim - if _dim < 0: - _dim += input_rank + dim_new = dim + if dim_new < 0: + dim_new += input_rank - if _dim < 0 or _dim >= input_rank: + if dim_new < 0 or dim_new >= input_rank: raise ValueError(f"`dim` should be in range [{-input_rank}, {input_rank}), but got {input_rank, dim}") - _sizes_mul = reduce(operator.mul, list(sizes)) - if -1 not in sizes and _sizes_mul != input_shape[_dim]: + sizes_mul = reduce(operator.mul, list(sizes)) + if -1 not in sizes and sizes_mul != input_shape[dim_new]: raise 
ValueError(f"unflatten: Provided `sizes` {sizes} don't multiply up to the" - f"size of dim {dim} ({input_shape[_dim]}) in the input tensor") + f"size of dim {dim} ({input_shape[dim_new]}) in the input tensor") - out_shape = input_shape[:_dim] + tuple(sizes) + input_shape[_dim + 1:] + out_shape = input_shape[:dim_new] + tuple(sizes) + input_shape[dim_new + 1:] return out_shape @@ -209,8 +205,8 @@ ops.densecus = fp16_patch_decorator(origin_dense) def einsum_label_to_index(label): if label == '.': return 52 - NUM_OF_LETTERS = ord('z') - ord('a') + 1 - return (ord(label) - ord('A')) if (label.isupper()) else (NUM_OF_LETTERS + (ord(label) - ord('a'))) + num_of_letters = ord('z') - ord('a') + 1 + return (ord(label) - ord('A')) if (label.isupper()) else (num_of_letters + (ord(label) - ord('a'))) def einsum(equation, *operands): assert operands, "einsum(): must provide at least one operand" @@ -424,7 +420,7 @@ def einsum(equation, *operands): else: sum_dims.append(dim) dim += 1 - if len(sum_dims) == 0: + if not sum_dims: result = result.mul(operand) elif len(sum_dims) == len(result.shape): result = result.flatten().dot(operand.flatten()) @@ -465,9 +461,9 @@ ops.zeroscus = _zeros def _cross_entropy(input_ce, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): if weight is None: weight = ops.ones(input_ce.shape[-1], input.dtype) - _nll_loss = _get_cache_prim(ops.NLLLoss)(reduction, ignore_index) + nll_loss = _get_cache_prim(ops.NLLLoss)(reduction, ignore_index) class_dim = 0 if input_ce.ndim == 1 else 1 - return _nll_loss(ops.log_softmax(input_ce, class_dim), target, weight)[0] + return nll_loss(ops.log_softmax(input_ce, class_dim), target, weight)[0] # for Tensor @@ -483,11 +479,11 @@ def _get_unfold_indices(input_shape, dimension, size, step): def unfold(self, dimension, size, step): - """torch-like unfold""" - _indices, _dimension = _get_unfold_indices(self.shape, dimension, size, step) - indices = mindspore.Tensor(_indices).astype(mindspore.int32) - output = ops.gather(self, indices, axis=_dimension) - output = ops.moveaxis(output, _dimension + 1, -1) + """unfold""" + indices_new, _dimension_new = _get_unfold_indices(self.shape, dimension, size, step) + indices = mindspore.Tensor(indices_new).astype(mindspore.int32) + output = ops.gather(self, indices, axis=_dimension_new) + output = ops.moveaxis(output, _dimension_new + 1, -1) return output Tensor.unfold = unfold @@ -496,7 +492,7 @@ StubTensor.unfold = unfold # var_mean def var_mean(input_vm, axis=None, *, correction=1, keepdims=False): - """torch-like var_mean""" + """var_mean""" axis = Validator.check_and_canonicalize_axes(axis, input.ndim) x_mean = ops.mean(input_vm, axis, True) x_sub = ops.sub(input_vm, x_mean) @@ -516,7 +512,7 @@ ops.var_mean = var_mean # std_mean def std_mean(input_sm, axis=None, *, correction=1, keepdims=False): - """torch-like std_mean""" + """std_mean""" output = var_mean(input_sm, axis, correction=correction, keepdims=keepdims) return ops.pow(output[0], 0.5), output[1] diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py index 0eb859740..0c6ab0e74 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py @@ -80,17 +80,6 @@ class HammingDiversityLogitsProcessor(LogitsProcessor): [`LogitsProcessor`] that enforces diverse beam search. 
Note that this logits processor is only effective for [`PreTrainedModel.group_beam_search`]. See [Diverse Beam Search: Decoding Diverse Solutions from Neural Sequence Models](https://arxiv.org/pdf/1610.02424.pdf) for more details. - - Args: - diversity_penalty (`float`): - This value is subtracted from a beam's score if it generates a token same as any beam from other group at a - particular time. Note that `diversity_penalty` is only effective if `group beam search` is enabled. - num_beams (`int`): - Number of beams used for group beam search. See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more - details. - num_beam_groups (`int`): - Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams. - See [this paper](https://arxiv.org/pdf/1610.02424.pdf) for more details. """ def __init__(self, diversity_penalty: float, num_beams: int, num_beam_groups: int): @@ -138,12 +127,6 @@ class HammingDiversityLogitsProcessor(LogitsProcessor): class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] enforcing an exponential penalty on tokens that are not in the original input. - - Args: - hallucination_penalty (`float`): - The parameter for hallucination penalty. 1.0 means no penalty. - encoder_input_ids (`mindspore.Tensor`): - The encoder_input_ids that should not be repeated within the decoder ids. """ def __init__(self, penalty: float, encoder_input_ids: mindspore.Tensor): @@ -166,11 +149,6 @@ class EncoderRepetitionPenaltyLogitsProcessor(LogitsProcessor): class RepetitionPenaltyLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. - - Args: - repetition_penalty (`float`): - The parameter for repetition penalty. 1.0 means no penalty. See [this - paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. """ def __init__(self, penalty: float): @@ -229,10 +207,6 @@ class NoRepeatNGramLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] that enforces no repetition of n-grams. See [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345). - - Args: - ngram_size (`int`): - All ngrams of size `ngram_size` can only occur once. """ def __init__(self, ngram_size: int): @@ -255,12 +229,6 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] that enforces no repetition of encoder input ids n-grams for the decoder ids. See [ParlAI](https://github.com/facebookresearch/ParlAI/blob/master/parlai/core/torch_generator_agent.py#L1350). - - Args: - encoder_ngram_size (`int`): - All ngrams of size `ngram_size` can only occur within the encoder input ids. - encoder_input_ids (`int`): - The encoder_input_ids that should not be repeated within the decoder ids. """ def __init__(self, encoder_ngram_size: int, encoder_input_ids: mindspore.Tensor): @@ -295,18 +263,10 @@ class EncoderNoRepeatNGramLogitsProcessor(LogitsProcessor): class NoBadWordsLogitsProcessor(LogitsProcessor): """ [`LogitsProcessor`] that enforces that specified sequences will never be sampled. - - Args: - bad_words_ids (`List[List[int]]`): - List of list of token ids that are not allowed to be generated. In order to get the token ids of the words - that should not appear in the generated text, use `tokenizer(bad_words, add_prefix_space=True, - add_special_tokens=False).input_ids`. - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. 
Optionally, use a list to set multiple *end-of-sequence* tokens.
     """
 
     def __init__(self, bad_words_ids: List[List[int]], eos_token_id: Union[int, List[int]]):
-        if not isinstance(bad_words_ids, List) or len(bad_words_ids) == 0:
+        if not isinstance(bad_words_ids, List) or not bad_words_ids:
             raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
         if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
             raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
@@ -337,11 +297,11 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
         self.static_bad_words_mask: Optional[mindspore.Tensor] = None
 
         for banned_token_seq in self.bad_words_id_length_greater_than_1:
-            if len(banned_token_seq) == 0:
+            if not banned_token_seq:
                 raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
 
     def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor) -> mindspore.Tensor:
-        if self.static_bad_words_mask is None and len(self.bad_words_id_length_1) > 0:
+        if self.static_bad_words_mask is None and self.bad_words_id_length_1:
             self.static_bad_words_mask = self._calc_static_bad_word_mask(scores)
 
         dynamic_banned_tokens = self._calc_banned_bad_words_ids(input_ids.tolist())
@@ -355,7 +315,7 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
         return static_bad_words_mask.unsqueeze(0).bool()
 
     def _tokens_match(self, prev_tokens: List[int], tokens: List[int]) -> bool:
-        if len(tokens) == 0:
+        if not tokens:
             # if bad word tokens is just one token always ban it
             return True
         if len(tokens) > len(prev_tokens):
@@ -381,10 +341,6 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
         """
         Modifies the scores in place by setting the banned token positions to `-inf`. Banned token is expected to be a
         list of list of banned tokens to ban in the format [[batch index, vocabulary position],...
-
-        Args:
-            scores: logits distribution of shape (batch size, vocabulary size)
-            banned_tokens: list of list of tokens to ban of length (batch_size)
         """
         banned_mask_list = []
         for idx, batch_banned_tokens in enumerate(banned_tokens):
@@ -427,12 +383,6 @@ class NoBadWordsLogitsProcessor(LogitsProcessor):
 class MinLengthLogitsProcessor(LogitsProcessor):
     r"""
     [`LogitsProcessor`] enforcing a min-length by setting EOS probability to 0.
-
-    Args:
-        min_length (`int`):
-            The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
-        eos_token_id (`Union[int, List[int]]`):
-            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
     """
 
     def __init__(self, min_length: int, eos_token_id: Union[int, List[int]]):
@@ -460,14 +410,6 @@ class MinLengthLogitsProcessor(LogitsProcessor):
 class MinNewTokensLengthLogitsProcessor(LogitsProcessor):
     r"""
     [`LogitsProcessor`] enforcing a min-length of new tokens by setting EOS (End-Of-Sequence) token probability to 0.
-
-    Args:
-        prompt_length_to_skip (`int`):
-            The input tokens length.
-        min_new_tokens (`int`):
-            The minimum *new* tokens length below which the score of `eos_token_id` is set to `-float("Inf")`.
-        eos_token_id (`int`):
-            The id of the *end-of-sequence* token.
     """
 
     def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int):
@@ -495,13 +437,6 @@ class PrefixConstrainedLogitsProcessor(LogitsProcessor):
     r"""
     [`LogitsProcessor`] that enforces constrained generation and is useful for prefix-conditioned constrained
     generation.
See [Autoregressive Entity Retrieval](https://arxiv.org/abs/2010.00904) for more information. - - Args: - prefix_allowed_tokens_fn: (`Callable[[int, torch.Tensor], List[int]]`): - This function constraints the beam search to allowed tokens only at each step. This function takes 2 - arguments `inputs_ids` and the batch ID `batch_id`. It has to return a list with the allowed tokens for the - next generation step conditioned on the previously generated tokens `inputs_ids` and the batch ID - `batch_id`. """ def __init__(self, prefix_allowed_tokens_fn: Callable[[int, mindspore.Tensor], List[int]], num_beams: int): @@ -541,13 +476,6 @@ class ForcedBOSTokenLogitsProcessor(LogitsProcessor): class ForcedEOSTokenLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached. - - Args: - max_length (`int`): - The maximum length of the sequence to be generated. - eos_token_id (`Union[int, List[int]]`): - The id of the token to force as the last generated token when `max_length` is reached. Optionally, use a - list to set multiple *end-of-sequence* tokens. """ def __init__(self, max_length: int, eos_token_id: Union[int, List[int]]): @@ -587,15 +515,6 @@ class ExponentialDecayLengthPenalty(LogitsProcessor): r""" [`LogitsProcessor`] that exponentially increases the score of the eos_token_id after regulation_start has been reached. - - Args: - exponential_decay_length_penalty (`tuple(int, float)`, *optional*): - This tuple shall consist of: `(start_index, decay_factor)` where `start_index` indicates where penalty - starts and `decay_factor` represents the factor of exponential decay - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - input_ids_seq_length (`int`): - The length of the input sequence. """ def __init__( @@ -630,9 +549,9 @@ class SuppressTokensLogitsProcessor(LogitsProcessor): class SuppressTokensAtBeginLogitsProcessor(LogitsProcessor): r""" - [`SuppressTokensAtBeginLogitsProcessor`] supresses a list of tokens as soon as the `generate` function starts + [`SuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` at not - sampled at the begining of the generation. + sampled at the beginning of the generation. """ def __init__(self, begin_suppress_tokens, begin_index): @@ -698,15 +617,6 @@ class TemperatureLogitsWarper(LogitsWarper): class TopPLogitsWarper(LogitsWarper): """ [`LogitsWarper`] that performs top-p, i.e. restricting to top tokens summing to prob_cut_off <= prob_cut_off. - - Args: - top_p (`float`): - If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or - higher are kept for generation. - filter_value (`float`, *optional*, defaults to `-float("Inf")`): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. """ def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -752,14 +662,6 @@ class TopPLogitsWarper(LogitsWarper): class TopKLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements. 
- - Args: - top_k (`int`): - The number of highest probability vocabulary tokens to keep for top-k-filtering. - filter_value (`float`, *optional*, defaults to `-float("Inf")`): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. """ def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -782,14 +684,6 @@ class TypicalLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] that performs typical decoding. See [Typical Decoding for Natural Language Generation](https://arxiv.org/abs/2202.00666) for more information. - - Args: - mass (`float`, *optional*, defaults to 0.9): - Value of typical_p between 0 and 1 inclusive, defaults to 0.9. - filter_value (`float`, *optional*, defaults to -inf): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. """ def __init__(self, mass: float = 0.9, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -831,37 +725,6 @@ class EpsilonLogitsWarper(LogitsWarper): [`LogitsWarper`] that performs epsilon-sampling, i.e. restricting to tokens with `prob >= epsilon`. Takes the largest min_tokens_to_keep tokens if no tokens satisfy this constraint. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. - - Args: - epsilon (`float`): - If set to > 0, only the most tokens with probabilities `epsilon` or higher are kept for generation. - filter_value (`float`, *optional*, defaults to -inf): - All filtered values will be set to this float value. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Minimum number of tokens that cannot be filtered. - - Examples: - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed - - >>> set_seed(0) - >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") - - >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt") - - >>> # With sampling, the output is unexpected -- sometimes too unexpected. - >>> outputs = model.generate(**inputs, do_sample=True) - >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) - A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2 - - >>> # With epsilon sampling, the output gets restricted to high-probability tokens. Note that this is similar to - >>> # Top P sampling, which restricts tokens based on their cumulative probability. - >>> # Pro tip: The paper recomends using `epsilon_cutoff` values between 3e-4 and 9e-4 - >>> outputs = model.generate(**inputs, do_sample=True, epsilon_cutoff=0.1) - >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) - A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9 - ``` """ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -901,43 +764,6 @@ class EtaLogitsWarper(LogitsWarper): samples of text generated by neural language models leading to more coherent and fluent text. See [Truncation Sampling as Language Model Desmoothing](https://arxiv.org/abs/2210.15191) for more information. Note: `do_sample` must be set to `True` for this `LogitsWarper` to work. - - - Args: - epsilon (`float`): - A float value in the range (0, 1). Hyperparameter used to calculate the dynamic cutoff value, `eta`. 
The - suggested values from the paper ranges from 3e-4 to 4e-3 depending on the size of the model. - filter_value (`float`, *optional*, defaults to -inf): - All values that are found to be below the dynamic cutoff value, `eta`, are set to this float value. This - parameter is useful when logits need to be modified for very low probability tokens that should be excluded - from generation entirely. - min_tokens_to_keep (`int`, *optional*, defaults to 1): - Specifies the minimum number of tokens that must be kept for generation, regardless of their probabilities. - For example, if `min_tokens_to_keep` is set to 1, at least one token will always be kept for generation, - even if all tokens have probabilities below the cutoff `eta`. - - Examples: - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM, set_seed - - >>> set_seed(0) - >>> model = AutoModelForCausalLM.from_pretrained("distilgpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("distilgpt2") - - >>> inputs = tokenizer("A sequence: 1, 2", return_tensors="pt") - - >>> # With sampling, the output is unexpected -- sometimes too unexpected. - >>> outputs = model.generate(**inputs, do_sample=True) - >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) - A sequence: 1, 2, 0, 2, 2. 2, 2, 2, 2 - - >>> # With eta sampling, the output gets restricted to high-probability tokens. You can see it as a dynamic form of - >>> # epsilon sampling that adapts its cutoff probability based on the entropy (high entropy = lower cutoff). - >>> # Pro tip: The paper recomends using `eta_cutoff` values between 3e-4 to 4e-3 - >>> outputs = model.generate(**inputs, do_sample=True, eta_cutoff=0.1) - >>> print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]) - A sequence: 1, 2, 3, 4, 5, 6, 7, 8, 9 - ``` """ def __init__(self, epsilon: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1): @@ -975,60 +801,6 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): when the next generated token can complete it. Consequently, to take the most of biasing sequences with more than one token, consider using beam methods (to gracefully work around partially completed sequences that have a negative bias) and applying the bias to their prefixes (to ensure the bias is applied earlier). - - - - In order to get the token ids of the sequences that you want to bias, make sure to set `add_prefix_space=True` when - initializing the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The - `add_prefix_space` argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours - come from `pre tokenizers`. Read more [here](https://hf-mirror.com/docs/tokenizers/api/pre-tokenizers). - - - - Args: - sequence_bias (`Dict[Tuple[int], float]`): - Dictionary that maps a sequence of tokens to its bias term. Positive biases increase the odds of the - sequence being selected, while negative biases do the opposite. If a sequence has a length of 1, its bias - will always be applied. Otherwise, the bias will only be applied if the sequence in question is about to be - completed (in the token selection step after this processor is applied). 
- - Examples: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM - - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> inputs = tokenizer(["The full name of Donald is Donald"], return_tensors="pt") - - >>> summary_ids = model.generate(inputs["input_ids"], max_new_tokens=4) - >>> print(tokenizer.batch_decode(summary_ids, skip_special_tokens=True)[0]) - The full name of Donald is Donald J. Trump Jr - - >>> # Now let's control generation through a bias. Please note that the tokenizer is initialized differently! - >>> tokenizer_with_prefix_space = AutoTokenizer.from_pretrained("gpt2", add_prefix_space=True) - - - >>> def get_tokens_as_tuple(word): - ... return tuple(tokenizer_with_prefix_space([word], add_special_tokens=False).input_ids[0]) - - - >>> # If we add a negative bias without beam search, it may become "stuck" in a prefix without good continuations - >>> sequence_bias = {get_tokens_as_tuple("Trump"): -10.0} - >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, sequence_bias=sequence_bias) - >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) - The full name of Donald is Donald J. Donald, - - >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias) - >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) - The full name of Donald is Donald Rumsfeld, - - >>> # We can also add a positive bias to nudge the model towards specific tokens or continuations - >>> sequence_bias = {get_tokens_as_tuple("Donald Duck"): 10.0} - >>> biased_ids = model.generate(inputs["input_ids"], max_new_tokens=4, num_beams=4, sequence_bias=sequence_bias) - >>> print(tokenizer.batch_decode(biased_ids, skip_special_tokens=True)[0]) - The full name of Donald is Donald Duck. - ``` """ def __init__(self, sequence_bias: Dict[Tuple[int], float]): @@ -1036,7 +808,7 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): self._validate_arguments() # Bias variables that will be populated on the first call (for retrocompatibility purposes, the vocabulary size - # is infered in the first usage, which inhibits initializing here) + # is inferred in the first usage, which inhibits initializing here) self.length_1_bias = None self.prepared_bias_variables = False @@ -1076,14 +848,6 @@ class SequenceBiasLogitsProcessor(LogitsProcessor): class AlternatingCodebooksLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] enforcing alternated generation between the two codebooks of [`Bark`]'s fine submodel. - - Args: - input_start_len (`int`): - The length of the initial input sequence. - semantic_vocab_size (`int`): - Vocabulary size of the semantic part, i.e number of tokens associated to the semantic vocabulary. - codebook_size (`int`): - Number of tokens associated to the codebook. """ def __init__(self, input_start_len: int, semantic_vocab_size: int, codebook_size: int): @@ -1115,49 +879,6 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): the `unconditional_ids` branch. See [the paper](https://arxiv.org/abs/2306.17806) for more information. - - Args: - guidance_scale (`float`): - The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale != 1`. - Higher guidance scale encourages the model to generate samples that are more closely linked to the input - prompt, usually at the expense of poorer quality. 
A value smaller than 1 has the opposite effect, while - making the negative prompt provided with negative_prompt_ids (if any) act as a positive prompt. - model (`PreTrainedModel`): - The model computing the unconditional scores. Supposedly the same as the one computing the conditional - scores. Both models must use the same tokenizer. - unconditional_ids (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Indices of input sequence tokens in the vocabulary for the unconditional branch. If unset, will default to - the last token of the prompt. - unconditional_attention_mask (`mindspore.Tensor` of shape `(batch_size, sequence_length)`, *optional*): - Attention mask for unconditional_ids. - use_cache (`bool`, *optional*, defaults to `True`): - Whether to cache key/values during the negative prompt forward pass. - - - Examples: - - ```python - >>> from transformers import AutoTokenizer, AutoModelForCausalLM - - >>> model = AutoModelForCausalLM.from_pretrained("gpt2") - >>> tokenizer = AutoTokenizer.from_pretrained("gpt2") - >>> inputs = tokenizer(["Today, a dragon flew over Paris, France,"], return_tensors="pt") - >>> out = model.generate(inputs["input_ids"], guidance_scale=1.5) - >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] - 'Today, a dragon flew over Paris, France, killing at least 50 people and injuring more than 100' - - >>> # with a negative prompt - >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt") - >>> out = model.generate(inputs["input_ids"], guidance_scale=2, negative_prompt_ids=neg_inputs["input_ids"]) - >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] - 'Today, a dragon flew over Paris, France, killing at least 130 people. French media reported that' - - >>> # with a positive prompt - >>> neg_inputs = tokenizer(["A very happy event happened,"], return_tensors="pt") - >>> out = model.generate(inputs["input_ids"], guidance_scale=0, negative_prompt_ids=neg_inputs["input_ids"]) - >>> tokenizer.batch_decode(out, skip_special_tokens=True)[0] - "Today, a dragon flew over Paris, France, and I'm very happy to be here. I" - ``` """ def __init__( @@ -1239,44 +960,6 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): See [the paper](https://arxiv.org/abs/2212.04356) for more information. - - Args: - generate_config (`GenerateConfig`): - The generate config used to generate the output. The following parameters are required: - eos_token_id (`int`, *optional*, defaults to 50257): - The id of the *end-of-sequence* token. - no_timestamps_token_id (`int`, *optional*, defaults to 50363): - The id of the `"<|notimestamps|>"` token. - max_initial_timestamp_index (`int`, *optional*, defaults to 1): - Used to set the maximum value of the initial timestamp. This is used to prevent the model from - predicting timestamps that are too far in the future. 
- - Examples: - ``` python - >>> from transformers import AutoProcessor, WhisperForConditionalGeneration,GenerationConfig - >>> from datasets import load_dataset - - >>> processor = AutoProcessor.from_pretrained("openai/whisper-tiny.en") - >>> model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny.en") - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> inputs = processor(ds[3]["audio"]["array"], return_tensors="pt") - >>> input_features = inputs.input_features - - >>> #Displaying timestamps - >>> generated_ids = model.generate(inputs=input_features, return_timestamps=True) - >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] - >>> print("Transcription:", transcription) - Transcription: <|startoftranscript|><|0.00|> He has grave doubts whether Sir Frederick Layton's work is really Greek after all, and can<|6.44|><|6.44|> discover in it but little of rocky Ithaca.<|9.44|><|endoftext|> - - - >>> #No timestamps & change EOS: - >>> #This allows the user to select a specific token to terminate the sequence on, in this case it's the word "can"(460) - >>> model.generation_config.eos_token_id = 460 - >>> generated_ids = model.generate(inputs=input_features,return_timestamps=False) - >>> transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] - >>> print("Transcription:", transcription) - Transcription: He has grave doubts whether Sir Frederick Layton's work is really Greek after all and can - ``` """ def __init__(self, generate_config): # support for the kwargs @@ -1327,19 +1010,6 @@ class WhisperTimeStampLogitsProcessor(LogitsProcessor): class BarkEosPrioritizerLogitsProcessor(LogitsProcessor): r"""This processor ensures that the EOS token is selected if its probability is greater than the `min_eos_p`. - - - - This logits processor is exclusively compatible with - [Bark](https://hf-mirror.com/docs/transformers/en/model_doc/bark). See the model documentation for examples. - - - - Args: - eos_token_id (`Union[int, List[int]]`): - The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens. - min_eos_p (`float`, *optional*): - Minimum end of speech threshold. """ def __init__(self, eos_token_id: Union[int, List[int]], min_eos_p: float): @@ -1372,35 +1042,6 @@ class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): weighted average across the conditional and unconditional logits, parameterised by the `guidance_scale`. See [the paper](https://arxiv.org/abs/2306.05284) for more information. - - - - This logits processor is exclusively compatible with - [MusicGen](https://hf-mirror.com/docs/transformers/main/en/model_doc/musicgen) - - - - Args: - guidance_scale (float): - The guidance scale for classifier free guidance (CFG). CFG is enabled by setting `guidance_scale > 1`. - Higher guidance scale encourages the model to generate samples that are more closely linked to the input - prompt, usually at the expense of poorer quality. - - Examples: - - ```python - >>> from transformers import AutoProcessor, MusicgenForConditionalGeneration - - >>> processor = AutoProcessor.from_pretrained("facebook/musicgen-small") - >>> model = MusicgenForConditionalGeneration.from_pretrained("facebook/musicgen-small") - - >>> inputs = processor( - ... text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"], - ... padding=True, - ... return_tensors="pt", - ... 
) - >>> audio_values = model.generate(**inputs, do_sample=True, guidance_scale=3, max_new_tokens=256) - ``` """ def __init__(self, guidance_scale): diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py index 1cf86cf5a..37e24b52a 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py @@ -26,13 +26,17 @@ import time from typing import Tuple from collections import OrderedDict -import numpy as np import mindspore as ms from mindspore import Tensor, nn, ops, Parameter from mindspore.nn import CrossEntropyLoss from .module.injection import EmbeddingMindnlp, DenseMindnlp, LayerNormMindnlp -from .module.configuration_utils import BaseModelOutputWithPast, CausalLMOutputWithPast, PreTrainedModelMindnlp, PretrainedConfig +from .module.configuration_utils import ( + BaseModelOutputWithPast, + CausalLMOutputWithPast, + PreTrainedModelMindnlp, + PretrainedConfig, +) class ClassInstantier(OrderedDict): @@ -94,13 +98,16 @@ def apply_rotary_pos_emb(x, sincos, offset=0): class ProGenConfig(PretrainedConfig): + """ + ProGenConfig class + """ model_type = "progen" def __init__( - self, vocab_size, n_positions, n_ctx, n_embd, n_layer, n_head, rotary_dim, - n_inner, activation_function, resid_pdrop, embd_pdrop, attn_pdrop, - layer_norm_epsilon, initializer_range, scale_attn_weights, gradient_checkpointing, - use_cache, bos_token_id, eos_token_id, min_length, **kwargs, + self, vocab_size, n_positions, n_ctx, n_embd, n_layer, n_head, rotary_dim, + n_inner, activation_function, resid_pdrop, embd_pdrop, attn_pdrop, + layer_norm_epsilon, initializer_range, scale_attn_weights, gradient_checkpointing, + use_cache, bos_token_id, eos_token_id, min_length, **kwargs, ): super().__init__(**kwargs) self.vocab_size = vocab_size @@ -151,6 +158,9 @@ class NewGELUActivation(nn.Cell): class ProGenAttention(nn.Cell): + """ + ProGenAttention class + """ def __init__(self, config): super().__init__() @@ -192,13 +202,15 @@ class ProGenAttention(nn.Cell): def construct( - self, - hidden_states, - use_cache=False, - output_attentions=False, - add_input=None, + self, + hidden_states, + use_cache=False, + output_attentions=False, + add_input=None, ): - + """ + construct method of ProGenAttention + """ layer_past, attention_mask, head_mask = add_input qkv = self.qkv_proj(hidden_states) mp_num = 8 @@ -266,8 +278,8 @@ class ProGenAttention(nn.Cell): return outputs # a, present, (attentions) def _split_heads(self, x, n_head, dim_head, mp_num): - reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head)) - reshaped = reshaped.reshape(x.shape[:-2] + (-1, ) + reshaped.shape[-1:]) + reshaped = x.reshape(x.shape[: -1] + (n_head // mp_num, dim_head)) + reshaped = reshaped.reshape(x.shape[: -2] + (-1, ) + reshaped.shape[-1: ]) return reshaped def _merge_heads(self, tensor, num_attention_heads, attn_head_size): @@ -284,14 +296,16 @@ class ProGenAttention(nn.Cell): return tensor.view(new_shape) def _attn( - self, - query, - key, - value, - attention_mask=None, - head_mask=None, + self, + query, + key, + value, + attention_mask=None, + head_mask=None, ): - + """ + _attn method of ProGenAttention + """ # compute causal mask from causal mask buffer query_length, key_length = query.shape[-2], key.shape[-2] causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length] @@ -323,6 +337,9 @@ class ProGenAttention(nn.Cell): class ProGenMLP(nn.Cell): + """ + ProGenMLP 
class + """ def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim super().__init__() embed_dim = config.n_embd # n_embd=4096 @@ -342,6 +359,9 @@ class ProGenMLP(nn.Cell): class ProGenBlock(nn.Cell): + """ + ProGenBlock class + """ def __init__(self, config): super().__init__() if config.n_inner is not None and config.n_inner != "None": @@ -353,12 +373,15 @@ class ProGenBlock(nn.Cell): self.mlp = ProGenMLP(inner_dim, config) def construct( - self, - hidden_states, - use_cache=False, - output_attentions=False, - add_input=None, + self, + hidden_states, + use_cache=False, + output_attentions=False, + add_input=None, ): + """ + construct method of ProGenBlock + """ residual = hidden_states hidden_states = self.ln_1(hidden_states) attn_outputs = self.attn( @@ -430,18 +453,18 @@ class ProGenModel(ProGenPreTrainedModel): self.wte = new_embeddings def construct( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( @@ -588,19 +611,19 @@ class ProGenForCausalLM(ProGenPreTrainedModel): def construct( - self, - input_ids=None, - past_key_values=None, - attention_mask=None, - token_type_ids=None, - position_ids=None, - head_mask=None, - inputs_embeds=None, - labels=None, - use_cache=None, - output_attentions=None, - output_hidden_states=None, - return_dict=None, + self, + input_ids=None, + past_key_values=None, + attention_mask=None, + token_type_ids=None, + position_ids=None, + head_mask=None, + inputs_embeds=None, + labels=None, + use_cache=None, + output_attentions=None, + output_hidden_states=None, + return_dict=None, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -680,7 +703,7 @@ class ProGenForCausalLM(ProGenPreTrainedModel): } -class print_time: +class PrintTime: def __init__(self, desc): self.desc = desc diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py index 0d982361d..35fa926a1 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py @@ -23,13 +23,12 @@ """progen""" import os import random -import argparse import mindspore as ms -from mindspore import ops, Tensor, load_checkpoint, load_param_into_net +from mindspore import ops, Tensor from tokenizers import Tokenizer -from .nn_arch import ProGenForCausalLM, ProGenConfig, print_time +from .nn_arch import ProGenForCausalLM, ProGenConfig, PrintTime from ..model import Model @@ -42,18 +41,18 @@ class ProGen(Model): def __init__(self, config): self.args = config self.config = ProGenConfig( - config.vocab_size, config.n_positions, config.n_ctx, config.n_embd, - config.n_layer, config.n_head, config.rotary_dim, config.n_inner, - config.activation_function, config.resid_pdrop, config.embd_pdrop, - config.attn_pdrop, config.layer_norm_epsilon, config.initializer_range, - config.scale_attn_weights, config.gradient_checkpointing, 
config.use_cache, + config.vocab_size, config.n_positions, config.n_ctx, config.n_embd, + config.n_layer, config.n_head, config.rotary_dim, config.n_inner, + config.activation_function, config.resid_pdrop, config.embd_pdrop, + config.attn_pdrop, config.layer_norm_epsilon, config.initializer_range, + config.scale_attn_weights, config.gradient_checkpointing, config.use_cache, config.bos_token_id, config.eos_token_id, config.min_length, ) self.mixed_precision = False self.network = ProGenForCausalLM(self.config) self.checkpoint_url = "Local_Checkpoint_Used" self.checkpoint_path = config.ckpt_dir - with print_time('loading tokenizer'): + with PrintTime('loading tokenizer'): self.tokenizer = self.create_tokenizer_custom(file=self.args.tokenizer_file) self.set_env() self.set_seed(self.args.rng_seed, deterministic=self.args.rng_deterministic) @@ -65,6 +64,7 @@ class ProGen(Model): os.environ['TOKENIZERS_PARALLELISM'] = 'false' def set_seed(self, seed, deterministic=True): + print("deterministic", deterministic) random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) @@ -130,7 +130,7 @@ class ProGen(Model): def predict(self, data, **kwargs): if self.args.sanity: - with print_time('sanity cross-entropy'): + with PrintTime('sanity cross-entropy'): x_uniref90bfd30 = self.args.x_uniref90bfd30 x_oas = self.args.x_oas @@ -150,8 +150,8 @@ class ProGen(Model): ce_target = checkpoint_x_ce[self.args.model][1] print(ce_target, ce_eval, abs(ce_eval - ce_target)) - - with print_time('sanity log-likelihood'): + + with PrintTime('sanity log-likelihood'): x_data = self.args.x_data @@ -163,9 +163,10 @@ class ProGen(Model): print(f'll_1={ll_1}') print(f'll_2={ll_2}') - with print_time('sanity model'): + with PrintTime('sanity model'): - alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] + alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', + 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] x_data = self.args.x_data x_random = '2' + ''.join([random.choice(alphabet) for _ in range(len(x_data)-2)]) + '1' x_perturb = x_random[:64] + x_data[len(x_random[:64]):] @@ -182,7 +183,7 @@ class ProGen(Model): print(f'll_x_random={ll_x_random}') print(f'll_x_perturb={ll_x_perturb}') - with print_time('log-likelihood (left-to-right, right-to-left)'): + with PrintTime('log-likelihood (left-to-right, right-to-left)'): reverse = lambda s: s[::-1] @@ -213,19 +214,25 @@ class ProGen(Model): return self.tokenizer.decode_batch(as_lists(tokens_batch)) def truncate(self, input_sample, terminals): + """ + truncate method + """ pos = [] for terminal in terminals: find_pos = input_sample.find(terminal, 1) if find_pos != -1: pos.append(find_pos) - if len(pos) > 0: + if pos: return input_sample[:(min(pos) + 1)] else: return input_sample def generate(self): + """ + generate method + """ if self.args.sanity: - with print_time('sanity cross-entropy'): + with PrintTime('sanity cross-entropy'): x_uniref90bfd30 = self.args.x_uniref90bfd30 x_oas = self.args.x_oas @@ -249,10 +256,11 @@ class ProGen(Model): if abs(ce_eval - ce_target) >= 0.1: raise ValueError("Difference should be within 0.1") - with print_time('sampling'): + with PrintTime('sampling'): completions = self.sample(model=self.network, tokenizer=self.tokenizer, context=self.args.context, - pad_token_id=self.tokenizer.encode('<|pad|>').ids[0], num_return_sequences=self.args.num_samples, - temp=self.args.t, top_p=self.args.p, max_length=self.args.max_length) + 
pad_token_id=self.tokenizer.encode('<|pad|>').ids[0], + num_return_sequences=self.args.num_samples, temp=self.args.t, + top_p=self.args.p, max_length=self.args.max_length) truncations = [self.truncate(completion, terminals=['1', '2']) for completion in completions] for (i, truncation) in enumerate(truncations): diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py index 53577e71d..e4fa6993e 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py @@ -26,4 +26,3 @@ progen_configuration = { "small": "https://gitee.com/mindspore/mindscience/raw/master/MindSPONGE/applications/model_configs/ProGen/small.yaml", } - -- Gitee From 6c27622a50222ceb043581e47181758aa6e05e6b Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Thu, 5 Sep 2024 09:23:19 +0800 Subject: [PATCH 09/16] pr modification --- .../progen/module/configuration_utils.py | 599 ++---------------- .../models/progen/module/injection.py | 68 +- .../models/progen/module/logits_process.py | 39 +- .../pipeline/models/progen/nn_arch.py | 15 +- .../pipeline/models/progen/progen.py | 19 +- 5 files changed, 127 insertions(+), 613 deletions(-) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py index 92be9cf4a..3d7c2f61e 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py @@ -1,3 +1,6 @@ +""" +configuration_utils +""" # Copyright 2024 @ Shenzhen Bay Laboratory & # Peking University & # Huawei Technologies Co., Ltd @@ -117,12 +120,6 @@ class CellUtilMixin: return get_parameter_dtype(self) -class SetAttribute(nn.Cell): - def __init__(self, module_name): - super().__init__() - module_name._is_initialized = True - - class ModelOutputMindnlp(OrderedDict): """ Base class for all model outputs as dataclass. Has a `__getitem__` that allows indexing by integer or slice (like a @@ -469,7 +466,6 @@ class GenerationConfig: """ Instantiates a [`GenerationConfig`] from a Python dictionary of parameters. """ - return_unused_kwargs = kwargs.pop("return_unused_kwargs", False) # Those arguments may be passed along for our internal telemetry. # We remove them so they don't appear in `return_unused_kwargs`. kwargs.pop("_from_auto", None) @@ -619,6 +615,9 @@ class GenerationMixin: default_list: Union[LogitsProcessorList, StoppingCriteriaList], custom_list: Union[LogitsProcessorList, StoppingCriteriaList], ) -> Union[LogitsProcessorList, StoppingCriteriaList]: + """" + merge the criteria processor list + """ if not custom_list: return default_list for default in default_list: @@ -680,7 +679,9 @@ class GenerationMixin: streamer: Optional["BaseStreamer"] = None, **kwargs, ) -> Union[GreedySearchDecoderOnlyOutput, mindspore.Tensor]: - + """ + generate method + """ # 1. 
Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call self._validate_model_class() generation_config = copy.deepcopy(self.generation_config) @@ -832,7 +833,7 @@ class GenerationMixin: if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - if type(outputs) is dict: + if isinstance(outputs, dict): outputs = ADDict(**outputs) next_token_logits = outputs.logits[:, -1, :] @@ -875,7 +876,8 @@ class GenerationMixin: # if eos_token was found in one sentence, set sentence to finished if eos_token_id_tensor is not None: unfinished_sequences = unfinished_sequences.mul( - next_tokens.tile((eos_token_id_tensor.shape[0], 1)).ne(eos_token_id_tensor.unsqueeze(1)).prod(axis=0) + next_tokens.tile((eos_token_id_tensor.shape[0], 1)).ne( + eos_token_id_tensor.unsqueeze(1)).prod(axis=0) ) # stop when each sentence is finished @@ -908,8 +910,7 @@ class GenerationMixin: attentions=decoder_attentions, hidden_states=decoder_hidden_states, ) - else: - return input_ids + return input_ids def _prepare_model_inputs( self, @@ -970,6 +971,9 @@ class GenerationMixin: return inputs, input_name, model_kwargs def _extract_past_from_model_output(self, outputs: ModelOutputMindnlp, standardize_cache_format: bool = False): + """ + extract past from model output + """ past_key_values = None if "past_key_values" in outputs: past_key_values = outputs.past_key_values @@ -991,12 +995,14 @@ class GenerationMixin: is_encoder_decoder: bool = False, standardize_cache_format: bool = False, ) -> Dict[str, Any]: - # update past_key_values + """ + update past_key_values and update token_type_ids with last value + """ model_kwargs["past_key_values"] = self._extract_past_from_model_output( outputs, standardize_cache_format=standardize_cache_format ) + model_kwargs["is_encoder_decoder"] = is_encoder_decoder - # update token_type_ids with last value token_type_str = "token_type_ids" if token_type_str in model_kwargs: token_type_ids = model_kwargs[token_type_str] @@ -1004,12 +1010,6 @@ class GenerationMixin: return model_kwargs - def _reorder_cache(self, past, beam_idx): - raise NotImplementedError( - f"Make sure that a `_reorder_cache` function is correctly implemented in {self.__class__.__module__} to" - f" enable beam search for {self.__class__}" - ) - def _get_logits_processor( self, generation_config: GenerationConfig, @@ -1155,291 +1155,6 @@ class GenerationMode(ExplicitEnum): GROUP_BEAM_SEARCH = "group_beam_search" -class PeftAdapterMixin: - """ - A class containing all functions for loading and using adapters weights that are supported in PEFT library. For - more details about adapters and injecting them on a transformer-based model, check out the documentation of PEFT - library: https://huggingface.co/docs/peft/index - """ - - _hf_peft_config_loaded = False - - def load_adapter( - self, - peft_model_id: Optional[str] = None, - adapter_name: Optional[str] = None, - revision: Optional[str] = None, - token: Optional[str] = None, - device_map: Optional[str] = "auto", - max_memory: Optional[str] = None, - offload_folder: Optional[str] = None, - offload_index: Optional[int] = None, - peft_config: Dict[str, Any] = None, - adapter_state_dict: Optional[Dict[str, "mindspore.Tensor"]] = None, - adapter_kwargs: Optional[Dict[str, Any]] = None, - ) -> None: - """ - Load adapter weights from file or remote Hub folder. 
If you are not familiar with adapters and PEFT methods, we - invite you to read more about them on PEFT official documentation: https://huggingface.co/docs/peft - - Requires peft as a backend to load the adapter weights. - """ - - adapter_name = adapter_name if adapter_name is not None else "default" - if adapter_kwargs is None: - adapter_kwargs = {} - - from ...peft import PeftConfig, inject_adapter_in_model, load_peft_weights - from ...peft.utils import set_peft_model_state_dict - - if self._hf_peft_config_loaded and adapter_name in self.peft_config: - raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") - - if peft_model_id is None and (adapter_state_dict is None and peft_config is None): - raise ValueError( - "You should either pass a `peft_model_id` or a `peft_config` and `adapter_state_dict` to load an adapter." - ) - - # We keep `revision` in the signature for backward compatibility - if revision is not None and "revision" not in adapter_kwargs: - adapter_kwargs["revision"] = revision - elif revision is not None and "revision" in adapter_kwargs and revision != adapter_kwargs["revision"]: - logger.error( - "You passed a `revision` argument both in `adapter_kwargs` and as a standalone argument. " - "The one in `adapter_kwargs` will be used." - ) - - # Override token with adapter_kwargs' token - if "token" in adapter_kwargs: - token = adapter_kwargs.pop("token") - - if peft_config is None: - adapter_config_file = find_adapter_config_file( - peft_model_id, - token=token, - **adapter_kwargs, - ) - - if adapter_config_file is None: - raise ValueError( - f"adapter model file not found in {peft_model_id}. Make sure you are passing the correct path to the " - "adapter model." - ) - - peft_config = PeftConfig.from_pretrained( - peft_model_id, - token=token, - **adapter_kwargs, - ) - - # Create and add fresh new adapters into the model. - inject_adapter_in_model(peft_config, self, adapter_name) - - if not self._hf_peft_config_loaded: - self._hf_peft_config_loaded = True - - if peft_model_id is not None: - adapter_state_dict = load_peft_weights(peft_model_id, token=token, **adapter_kwargs) - - # We need to pre-process the state dict to remove unneeded prefixes - for backward compatibility - processed_adapter_state_dict = {} - prefix = "base_model.model." - for key, value in adapter_state_dict.items(): - if key.startswith(prefix): - new_key = key[len(prefix) :] - else: - new_key = key - processed_adapter_state_dict[new_key] = value - - # Load state dict - incompatible_keys = set_peft_model_state_dict(self, processed_adapter_state_dict, adapter_name) - - if incompatible_keys is not None: - # check only for unexpected keys - if hasattr(incompatible_keys, "unexpected_keys") and len(incompatible_keys.unexpected_keys) > 0: - logger.warning( - f"Loading adapter weights from {peft_model_id} led to unexpected keys not found in the model: " - f" {incompatible_keys.unexpected_keys}. " - ) - - # Re-dispatch model and hooks in case the model is offloaded to CPU / Disk. 
- if ( - (getattr(self, "hf_device_map", None) is not None) - and (len(set(self.hf_device_map.values()).intersection({"cpu", "disk"})) > 0) - and len(self.peft_config) == 1 - ): - self._dispatch_accelerate_model( - device_map=device_map, - max_memory=max_memory, - offload_folder=offload_folder, - offload_index=offload_index, - ) - - def add_adapter(self, adapter_config, adapter_name: Optional[str] = None) -> None: - r""" - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - - Adds a fresh new adapter to the current model for training purpose. If no adapter name is passed, a default - name is assigned to the adapter to follow the convention of PEFT library (in PEFT we use "default" as the - default adapter name). - """ - from ...peft import PeftConfig, inject_adapter_in_model - - adapter_name = adapter_name or "default" - - if not self._hf_peft_config_loaded: - self._hf_peft_config_loaded = True - elif adapter_name in self.peft_config: - raise ValueError(f"Adapter with name {adapter_name} already exists. Please use a different name.") - - if not isinstance(adapter_config, PeftConfig): - raise ValueError( - f"adapter_config should be an instance of PeftConfig. Got {type(adapter_config)} instead." - ) - - # Retrieve the name or path of the model, one could also use self.config._name_or_path - # but to be consistent with what we do in PEFT: https://github.com/huggingface/peft/blob/6e783780ca9df3a623992cc4d1d665001232eae0/src/peft/mapping.py#L100 - adapter_config.base_model_name_or_path = self.__dict__.get("name_or_path", None) - inject_adapter_in_model(adapter_config, self, adapter_name) - - self.set_adapter(adapter_name) - - def set_adapter(self, adapter_name: Union[List[str], str]) -> None: - """ - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - - Sets a specific adapter by forcing the model to use a that adapter and disable the other adapters. - """ - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. Please load an adapter first.") - elif isinstance(adapter_name, list): - missing = set(adapter_name) - set(self.peft_config) - if len(missing) > 0: - raise ValueError( - f"Following adapter(s) could not be found: {', '.join(missing)}. Make sure you are passing the correct adapter name(s)." - f" current loaded adapters are: {list(self.peft_config.keys())}" - ) - elif adapter_name not in self.peft_config: - raise ValueError( - f"Adapter with name {adapter_name} not found. Please pass the correct adapter name among {list(self.peft_config.keys())}" - ) - - from ...peft.tuners.tuners_utils import BaseTunerLayer - from ...peft.utils import ModulesToSaveWrapper - - adapters_has_been_set = False - - for _, module in self.named_modules(): - if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): - # For backward compatibility with previous PEFT versions - if hasattr(module, "set_adapter"): - module.set_adapter(adapter_name) - else: - module.active_adapter = adapter_name - adapters_has_been_set = True - - if not adapters_has_been_set: - raise ValueError( - "Did not succeeded in setting the adapter. Please make sure you are using a model that supports adapters." 
- ) - - def disable_adapters(self) -> None: - r""" - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - - Disable all adapters that are attached to the model. This leads to inferring with the base model only. - """ - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. Please load an adapter first.") - - from ...peft.tuners.tuners_utils import BaseTunerLayer - from ...peft.utils import ModulesToSaveWrapper - - for _, module in self.named_modules(): - if isinstance(module, (BaseTunerLayer, ModulesToSaveWrapper)): - # The recent version of PEFT need to call `enable_adapters` instead - if hasattr(module, "enable_adapters"): - module.enable_adapters(enabled=False) - else: - module.disable_adapters = True - - def enable_adapters(self) -> None: - """ - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - - Enable adapters that are attached to the model. The model will use `self.active_adapter()` - """ - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. Please load an adapter first.") - - from ...peft.tuners.tuners_utils import BaseTunerLayer - - for _, module in self.named_modules(): - if isinstance(module, BaseTunerLayer): - # The recent version of PEFT need to call `enable_adapters` instead - if hasattr(module, "enable_adapters"): - module.enable_adapters(enabled=True) - else: - module.disable_adapters = False - - def active_adapters(self) -> List[str]: - """ - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - - Gets the current active adapters of the model. In case of multi-adapter inference (combining multiple adapters - for inference) returns the list of all active adapters so that users can deal with them accordingly. - - For previous PEFT versions (that does not support multi-adapter inference), `module.active_adapter` will return - a single string. - """ - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. Please load an adapter first.") - - from ...peft.tuners.tuners_utils import BaseTunerLayer - - for _, module in self.named_modules(): - if isinstance(module, BaseTunerLayer): - active_adapters = module.active_adapter - break - - # For previous PEFT versions - if isinstance(active_adapters, str): - active_adapters = [active_adapters] - - return active_adapters - - def active_adapter(self) -> str: - warnings.warn( - "The `active_adapter` method is deprecated and will be removed in a future version.", FutureWarning - ) - - return self.active_adapters()[0] - - def get_adapter_state_dict(self, adapter_name: Optional[str] = None) -> dict: - """ - If you are not familiar with adapters and PEFT methods, we invite you to read more about them on the PEFT - official documentation: https://huggingface.co/docs/peft - - Gets the adapter state dict that should only contain the weights tensors of the specified adapter_name adapter. - If no adapter_name is passed, the active adapter is used. - """ - if not self._hf_peft_config_loaded: - raise ValueError("No adapter loaded. 
Please load an adapter first.") - - from ...peft import get_peft_model_state_dict - - if adapter_name is None: - adapter_name = self.active_adapter() - - adapter_state_dict = get_peft_model_state_dict(self, adapter_name=adapter_name) - return adapter_state_dict - - @jit_class class PretrainedConfig: """ @@ -1521,8 +1236,6 @@ class PretrainedConfig: if self.ms_dtype is not None and isinstance(self.ms_dtype, str): if is_mindspore_available(): - import mindspore - self.ms_dtype = getattr(mindspore, self.ms_dtype) # Tokenizer arguments TODO: eventually tokenizer and models should share the same config @@ -1864,7 +1577,7 @@ class PretrainedConfig: if "model_type" in value: # Needs to be set even if it's not in the diff diff["model_type"] = value["model_type"] - if len(diff) > 0: + if diff: serializable_config_dict[key] = diff elif ( key not in default_config_dict @@ -1899,7 +1612,7 @@ class PretrainedConfig: for key, value in config_dict.items(): setattr(self, key, value) - def save_pretrained(self, save_directory: Union[str, os.PathLike], **kwargs): + def save_pretrained(self, save_directory: Union[str, os.PathLike]): """ Save a configuration object to the directory `save_directory`, so that it can be re-loaded using the [`~PretrainedConfig.from_pretrained`] class method. @@ -1938,23 +1651,24 @@ class PretrainedConfig: @property def _attn_implementation(self): - # This property is made private for now (as it cannot be changed and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + """ + This property is made private for now (as it cannot be changed + and a PreTrainedModel.use_attn_implementation method needs to be implemented.) + """ if hasattr(self, "_attn_implementation_internal"): - if self._attn_implementation_internal is None: + if not self._attn_implementation_internal: # `config.attn_implementation` should never be None, for backward compatibility. return "eager" else: return self._attn_implementation_internal - else: - return "eager" + return "eager" @_attn_implementation.setter def _attn_implementation(self, value): self._attn_implementation_internal = value - -class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapterMixin): +class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin): """ Abstract class for Pretrained models """ @@ -2007,55 +1721,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte return model - def init_weights(self): - """ - If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any - initialization logic in `_init_weights`. - """ - # Prune heads if needed - if self.config.pruned_heads: - self.prune_heads(self.config.pruned_heads) - - if _init_weights: - # Initialize weights - if getattr(self, 'apply', None): - self.apply(self._initialize_weights) - else: - for _, cell in self.name_cells().items(): - self._initialize_weights(cell) - - # Tie weights should be skipped when not initializing all weights - # since from_pretrained(...) calls tie weights anyways - self.tie_weights() - - def prune_heads(self, heads_to_prune: Dict[int, List[int]]): - """ - Prunes heads of the base model. 
- """ - # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads - for layer, heads in heads_to_prune.items(): - union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads) - self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON - - self.base_model._prune_heads(heads_to_prune) - - def _init_weights(self, cell): - """ - Initialize the weights. This method should be overridden by derived class and is - the only initialization method that will be called when loading a checkpoint - using `from_pretrained`. Any attempt to initialize outside of this function - will be useless as the torch.nn.init function are all replaced with skip. - """ - - def _initialize_weights(self, module): - """ - Initialize the weights if they are not already initialized. - """ - if getattr(module, "_is_initialized", False): - return - self._init_weights(module) - module._is_initialized = True - @property def base_model(self): """ @@ -2087,15 +1752,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte return base_model.set_input_embeddings(new_embeddings) raise NotImplementedError - def resize_position_embeddings(self, new_num_position_embeddings: int): - """ - resize the model position embeddings if necessary - """ - raise NotImplementedError( - f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " - f"overwrite this method in the class {self.__class__}" - ) - def get_output_embeddings(self): """ Get model's output embeddings Return None if the model doesn't have output embeddings @@ -2114,139 +1770,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte return base_model.set_output_embeddings(new_embeddings) raise NotImplementedError - def get_position_embeddings(self): - """ - get the model position embeddings if necessary - """ - raise NotImplementedError( - f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should " - f"overwrite this method in the class {self.__class__}" - ) - - def tie_weights(self): - """ - Make sure we are sharing the input and output embeddings. - If you need this feature, - you need to get it yourself output Add the output you need to add to the embeddings function_ Embedding layer, - otherwise you cannot - """ - if getattr(self.config, "tie_word_embeddings", True): - output_embeddings = self.get_output_embeddings() # pylint: disable=assignment-from-none - if output_embeddings is not None: - self._tie_or_clone_weights( - output_embeddings, self.get_input_embeddings()) - - if getattr(self.config, "is_encoder_decoder", False) and getattr(self.config, "tie_encoder_decoder", False): - if hasattr(self, self.base_model_prefix): - self = getattr(self, self.base_model_prefix) # pylint: disable=self-cls-assignment - self._tie_encoder_decoder_weights( - self.encoder, self.decoder, self.base_model_prefix) - - for _, cell in self.cells_and_names(): - if hasattr(cell, "_tie_weights"): - cell._tie_weights() - - @staticmethod - def _tie_encoder_decoder_weights(encoder: nn.Cell, decoder: nn.Cell, base_model_prefix: str): - """tie encoder decoder weights""" - uninitialized_encoder_weights: List[str] = [] - if decoder.__class__ != encoder.__class__: - logger.info( - f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder" - " weights are correctly initialized." 
- ) - - def tie_encoder_to_decoder_recursively( - decoder_pointer: nn.Cell, - encoder_pointer: nn.Cell, - module_name: str, - uninitialized_encoder_weights: List[str], - depth=0, - ): - assert isinstance(decoder_pointer, nn.Cell) and isinstance( - encoder_pointer, nn.Cell - ), f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module" - if hasattr(decoder_pointer, "weight"): - assert hasattr(encoder_pointer, "weight") - encoder_pointer.weight = decoder_pointer.weight - encoder_pointer._params['weight'] = decoder_pointer.weight - if hasattr(decoder_pointer, "bias"): - assert hasattr(encoder_pointer, "bias") - encoder_pointer.bias = decoder_pointer.bias - encoder_pointer._params['bias'] = decoder_pointer.bias - return - - encoder_cells = encoder_pointer._cells - decoder_cells = decoder_pointer._cells - if len(decoder_cells) > 0: - assert ( - len(encoder_cells) > 0 - ), f"Encoder cell {encoder_pointer} does not match decoder cell {decoder_pointer}" - - all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_cells.keys()} - encoder_layer_pos = 0 - for name, _ in decoder_cells.items(): - if name.isdigit(): - encoder_name = str(int(name) + encoder_layer_pos) - decoder_name = name - if not isinstance(decoder_cells[decoder_name], type(encoder_cells[encoder_name])) and len( - encoder_cells - ) != len(decoder_cells): - # this can happen if the name corresponds to the position in a list module list of layers - # in this case the decoder has added a cross-attention that the encoder does not have - # thus skip this step and subtract one layer pos from encoder - encoder_layer_pos -= 1 - continue - elif name not in encoder_cells: - continue - elif depth > 500: - raise ValueError( - "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is" - " a circular dependency between two or more `nn.Cell` of your model." 
- ) - else: - decoder_name = encoder_name = name - tie_encoder_to_decoder_recursively( - decoder_cells[decoder_name], - encoder_cells[encoder_name], - module_name + "/" + name, - uninitialized_encoder_weights, - depth=depth + 1, - ) - all_encoder_weights.remove(module_name + "/" + encoder_name) - - uninitialized_encoder_weights += list(all_encoder_weights) - - # tie weights recursively - tie_encoder_to_decoder_recursively(decoder, encoder, base_model_prefix, uninitialized_encoder_weights) - if len(uninitialized_encoder_weights) > 0: - logger.warning( - f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}" - ) - - def _tie_or_clone_weights(self, output_embeddings, input_embeddings): - """ Tie or clone module weights depending of weither we are using or not - """ - if hasattr(output_embeddings, 'weight'): - output_embeddings.weight = input_embeddings.weight - output_embeddings._params['weight'] = input_embeddings.weight - - if getattr(output_embeddings, "bias", None) is not None: - if output_embeddings.weight.shape[0] == output_embeddings.bias.shape[0]: - pass - else: - # instantial a new Parameter since mindspore.Parameter do not support assign_value with different shape - replace_references(output_embeddings.bias, Parameter(ops.pad( - output_embeddings.bias.data, - (0, output_embeddings.weight.shape[0] - - output_embeddings.bias.shape[0]), - "constant", - 0, - ), name=output_embeddings.bias.name, requires_grad=output_embeddings.bias.requires_grad)) - - if hasattr(output_embeddings, "out_channels") and hasattr(input_embeddings, "vocab_size"): - output_embeddings.out_channels = input_embeddings.vocab_size - def resize_token_embeddings( self, new_num_tokens: Optional[int] = None, pad_to_multiple_of: Optional[int] = None ) -> nn.Embedding: @@ -2266,15 +1789,19 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte return model_embeds def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None): + """ + Update new_num_tokens with the actual size of new_embeddings. + If word embeddings are not tied, make sure that lm head is resized as well. + """ old_embeddings = self.get_input_embeddings() new_embeddings = self._get_resized_embeddings(old_embeddings, new_num_tokens, pad_to_multiple_of) self.set_input_embeddings(new_embeddings) self.get_input_embeddings().weight.data_sync(True) - # Update new_num_tokens with the actual size of new_embeddings + if pad_to_multiple_of is not None: new_num_tokens = new_embeddings.weight.shape[0] - # if word embeddings are not tied, make sure that lm head is resized as well + if self.get_output_embeddings() is not None and not self.config.tie_word_embeddings: old_lm_head = self.get_output_embeddings() # pylint: disable=assignment-from-none new_lm_head = self._get_resized_lm_head( @@ -2282,7 +1809,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte self.set_output_embeddings(new_lm_head) self.get_output_embeddings().weight.data_sync(True) - return self.get_input_embeddings() def resize_tokenizer_embeddings(self, new_num_tokens): @@ -2393,13 +1919,15 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte def _copy_lm_head_original_to_resized( self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias ): - # Copy old lm head weights to new lm head + """ + Copy old lm head weights to new lm head. + Copy bias weights to new lm head. 
+ """ if not transposed: new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :] else: new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy] - # Copy bias weights to new lm head if has_new_lm_head_bias: new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy] @@ -2482,28 +2010,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte logger.warning(warn_string) - def parameters_and_names(self, name_prefix='', expand=True): - """ - fix ignore tied weights - """ - cells = [] - if expand: - cells = self.cells_and_names(name_prefix=name_prefix) - else: - cells.append((name_prefix, self)) - - for cell_name, cell in cells: - params = cell._params.items() - for par_name, par in params: - if par is not None and par.inited_param is not None: - par = par.inited_param - if par is not None: - par_new_name = par_name - if cell_name: - par_new_name = cell_name + '.' + par_new_name - - yield par_new_name, par - def num_parameters(self, only_trainable=False): """return parameters count""" total = 0 @@ -2521,18 +2027,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte """ return list(set(filter(lambda x: x.requires_grad, self.get_parameters(expand=recurse)))) - def check_names_and_refresh_name(self): - """ - fix ignore tied weights - """ - if not hasattr(self, "_params"): - return - all_name = dict(self.parameters_and_names()).keys() - - if len(set(all_name)) < len(all_name): - self.update_parameters_name() - self.check_names() - def save_pretrained( self, save_directory: Union[str, os.PathLike], @@ -2542,7 +2036,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte max_shard_size: Union[int, str] = "5GB", safe_serialization: bool = True, variant: Optional[str] = None, - **kwargs, ): """ Save a model and its configuration file to a directory, so that it can be re-loaded using the @@ -2635,15 +2128,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin, PeftAdapte f"index located at {save_index_file}." 
) - def enable_recompute(self): - """Activates recompute (aka gradient checkpointing) for the current model.""" - if not self.supports_recompute: - raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") - - for _, cell in self.cells_and_names(): - if hasattr(cell, "_set_recompute"): - cell._set_recompute() - def check_names(self): pass @@ -2780,10 +2264,9 @@ def infer_framework_from_repr(x): representation = str(type(x)) if representation.startswith(" 0 and label_perm_index[label] == -1: label_perm_index[label] = perm_index perm_index += 1 @@ -332,7 +338,7 @@ def einsum(equation, *operands): permuted_operands = [] for i, operand in enumerate(operands): perm_shape = [-1] * perm_index - label_dim = [-1] * TOTAL_LABELS + label_dim = [-1] * total_labels operand = operands[i] labels = op_labels[i] original_sizes = operand.shape @@ -456,16 +462,6 @@ def _zeros(*size, dtype=None): ops.zeroscus = _zeros - -# cross_entropy -def _cross_entropy(input_ce, target, weight=None, ignore_index=-100, reduction='mean', label_smoothing=0.0): - if weight is None: - weight = ops.ones(input_ce.shape[-1], input.dtype) - nll_loss = _get_cache_prim(ops.NLLLoss)(reduction, ignore_index) - class_dim = 0 if input_ce.ndim == 1 else 1 - return nll_loss(ops.log_softmax(input_ce, class_dim), target, weight)[0] - - # for Tensor # unfold def _get_unfold_indices(input_shape, dimension, size, step): @@ -480,10 +476,10 @@ def _get_unfold_indices(input_shape, dimension, size, step): def unfold(self, dimension, size, step): """unfold""" - indices_new, _dimension_new = _get_unfold_indices(self.shape, dimension, size, step) + indices_new, dimension_new = _get_unfold_indices(self.shape, dimension, size, step) indices = mindspore.Tensor(indices_new).astype(mindspore.int32) - output = ops.gather(self, indices, axis=_dimension_new) - output = ops.moveaxis(output, _dimension_new + 1, -1) + output = ops.gather(self, indices, axis=dimension_new) + output = ops.moveaxis(output, dimension_new + 1, -1) return output Tensor.unfold = unfold @@ -590,6 +586,9 @@ StubTensor.unflatten = unflatten def _as_strided(self, size, stride, storage_offset=None): + """ + replace as_strided + """ if len(size) != len(stride): raise RuntimeError("mismatch in length of strides and shape.") index = np.arange(0, size[0] * stride[0], stride[0]) @@ -665,6 +664,9 @@ StubTensor.prod = _prod def _eq(self, other): + """ + replace __eq__ + """ if not isinstance(other, (int, float, Tensor)): return False if isinstance(other, Tensor) and self.shape != other.shape: @@ -741,6 +743,9 @@ class DenseMindnlp(nn.Cell): Uniform(bound), [out_channels], dtype=dtype), name="bias") def construct(self, x): + """ + construct method of DenseMindnlp + """ if LESS_MS_2_2: x_shape = x.shape if len(x_shape) != 2: @@ -778,6 +783,9 @@ class EmbeddingMindnlp(nn.Cell): self.weight[self.padding_idx] = 0 def construct(self, ids): + """ + construct method of EmbeddingMindnlp + """ out_shape = ids.shape + (self.embedding_size,) flat_ids = ids.reshape((-1,)) @@ -831,11 +839,14 @@ class LayerNormMindnlp(nn.Cell): self.bias = Parameter(initializer( beta_init, normalized_shape, dtype=dtype), name="bias") self.layer_norm = ops.LayerNorm(begin_norm_axis=self.begin_norm_axis, - begin_params_axis=self.begin_params_axis, - epsilon=self.epsilon) + begin_params_axis=self.begin_params_axis, + epsilon=self.epsilon) self.elementwise_affine = elementwise_affine def construct(self, input_x): + """ + construct method of LayerNormMindnlp + """ if self.elementwise_affine: y, _, _ = 
self.layer_norm(input_x, self.weight.astype(input_x.dtype), self.bias.astype(input_x.dtype)) else: @@ -889,8 +900,8 @@ class BatchNorm1dMindnlp(nn.Cell): self.momentum = 1.0 - momentum self.bn_train = ops.BatchNorm(is_training=True, - epsilon=self.eps, - momentum=self.momentum) + epsilon=self.eps, + momentum=self.momentum) self.bn_infer = ops.BatchNorm(is_training=False, epsilon=self.eps) @@ -926,4 +937,3 @@ class BatchNorm1dMindnlp(nn.Cell): return f'num_features={self.num_features}, eps={self.eps}, momentum={1.0 - self.momentum}, ' \ f'weight={self.weight}, bias={self.bias}, running_mean={self.running_mean}, ' \ f'running_var={self.running_var}' - diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py index 0c6ab0e74..86345b5b8 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py @@ -116,8 +116,8 @@ class HammingDiversityLogitsProcessor(LogitsProcessor): for batch_idx in range(batch_size): # predicted tokens of last time step of previous groups previous_group_tokens = current_tokens[ - batch_idx * self._num_beams: batch_idx * self._num_beams + group_start_idx - ] + batch_idx * self._num_beams: batch_idx * self._num_beams + group_start_idx + ] token_frequency = ops.bincount(previous_group_tokens, minlength=vocab_size) scores[batch_idx * group_size: (batch_idx + 1) * group_size] -= self._diversity_penalty * token_frequency @@ -324,6 +324,9 @@ class NoBadWordsLogitsProcessor(LogitsProcessor): return prev_tokens[-len(tokens):] == tokens def _calc_banned_bad_words_ids(self, prev_input_ids: List[List[int]]) -> Iterable[int]: + """ + calculate banned bad words ids + """ banned_tokens = [] for prev_input_ids_slice in prev_input_ids: banned_tokens_slice = [] @@ -414,9 +417,9 @@ class MinNewTokensLengthLogitsProcessor(LogitsProcessor): def __init__(self, prompt_length_to_skip: int, min_new_tokens: int, eos_token_id: int): for arg_name, arg_value in [ - ("prompt_length_to_skip", prompt_length_to_skip), - ("min_new_tokens", min_new_tokens), - ("eos_token_id", eos_token_id), + ("prompt_length_to_skip", prompt_length_to_skip), + ("min_new_tokens", min_new_tokens), + ("eos_token_id", eos_token_id), ]: if not isinstance(arg_value, int) or arg_value < 0: raise ValueError(f"`{arg_name}` has to be a positive integer, but is {arg_value}") @@ -647,12 +650,12 @@ class TopPLogitsWarper(LogitsWarper): #scores = scores.masked_fill(indices_to_remove, self.filter_value) sorted_indices_to_remove[..., -self.min_tokens_to_keep:] = 0 - if type(sorted_indices_to_remove[0][0].item()) == bool: + if isinstance(sorted_indices_to_remove[0][0].item(), bool): sorted_indices_to_remove = sorted_indices_to_remove.astype("int32") indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) - if type(indices_to_remove[0][0].item()) == int: + if isinstance(indices_to_remove[0][0].item(), int): indices_to_remove = indices_to_remove.astype("bool") scores = scores.masked_fill(indices_to_remove, self.filter_value) @@ -714,7 +717,8 @@ class TypicalLogitsWarper(LogitsWarper): last_ind.clamp_(max=sorted_scores.shape[-1] - 1) sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - indices_to_remove = ops.tensor_scatter_elements(sorted_indices_to_remove, sorted_indices, 
sorted_indices_to_remove, axis=1) + indices_to_remove = ops.tensor_scatter_elements(sorted_indices_to_remove, + sorted_indices, sorted_indices_to_remove, axis=1) scores = scores.masked_fill(indices_to_remove, self.filter_value) return scores @@ -882,12 +886,12 @@ class UnbatchedClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): """ def __init__( - self, - guidance_scale: float, - model, - unconditional_ids: Optional[mindspore.Tensor] = None, - unconditional_attention_mask: Optional[mindspore.Tensor] = None, - use_cache: Optional[bool] = True, + self, + guidance_scale: float, + model, + unconditional_ids: Optional[mindspore.Tensor] = None, + unconditional_attention_mask: Optional[mindspore.Tensor] = None, + use_cache: Optional[bool] = True, ): self.guidance_scale = guidance_scale self.model = model @@ -1058,9 +1062,10 @@ class ClassifierFreeGuidanceLogitsProcessor(LogitsProcessor): # logits scores (cond + uncond) and input ids (cond only) if scores.shape[0] != 2 * input_ids.shape[0]: raise ValueError( - f"Logits should have twice the batch size of the input ids, the first half of batches corresponding to " - f"the conditional inputs, and the second half of batches corresponding to the unconditional inputs. Got " - f"batch size {scores.shape[0]} for the logits and {input_ids.shape[0]} for the input ids." + f"Logits should have twice the batch size of the input ids, the first half of batches " + "corresponding to the conditional inputs, and the second half of batches corresponding to " + f"the unconditional inputs. Got batch size {scores.shape[0]} for the logits and " + f" {input_ids.shape[0]} for the input ids." ) unguided_bsz = scores.shape[0] // 2 cond_logits, uncond_logits = scores.split(unguided_bsz, axis=0) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py index 37e24b52a..f9c5ae24c 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py @@ -433,6 +433,9 @@ class ProGenPreTrainedModel(PreTrainedModelMindnlp): class ProGenModel(ProGenPreTrainedModel): + """ + ProGenModel class + """ def __init__(self, config): super().__init__(config) @@ -466,6 +469,9 @@ class ProGenModel(ProGenPreTrainedModel): output_hidden_states=None, return_dict=None, ): + """ + construct method of progen model + """ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = ( output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states @@ -475,7 +481,7 @@ class ProGenModel(ProGenPreTrainedModel): if input_ids is not None and inputs_embeds is not None: raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time") - elif input_ids is not None: + if input_ids is not None: input_shape = input_ids.shape input_ids = input_ids.view(-1, input_shape[-1]) batch_size = input_ids.shape[0] @@ -590,6 +596,9 @@ class ProGenModel(ProGenPreTrainedModel): class ProGenForCausalLM(ProGenPreTrainedModel): + """ + ProGenForCausalLM class + """ _keys_to_ignore_on_load_missing = [r"h\.\d+\.attn\.masked_bias", r"h\.\d+\.attn\.bias", r"lm_head\.weight"] def __init__(self, config): @@ -609,7 +618,6 @@ class ProGenForCausalLM(ProGenPreTrainedModel): for layer_past in past ) - def construct( self, input_ids=None, @@ -625,6 +633,9 @@ class ProGenForCausalLM(ProGenPreTrainedModel): output_hidden_states=None, return_dict=None, ): + """ + construct 
method of ProGenForCausalLM class + """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict transformer_outputs = self.transformer( diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py index 35fa926a1..9d6489c97 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py @@ -97,6 +97,9 @@ class ProGen(Model): return self.cross_entropy(logits=logits, target=target).item() def ll(self, tokens, f, reduction): + """ + shift, remove terminals and remove unused logits + """ target = Tensor(self.tokenizer.encode(tokens).ids) logits = self.network(target, labels=target).logits @@ -165,7 +168,7 @@ class ProGen(Model): with PrintTime('sanity model'): - alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', + alphabet = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z'] x_data = self.args.x_data x_random = '2' + ''.join([random.choice(alphabet) for _ in range(len(x_data)-2)]) + '1' @@ -200,7 +203,10 @@ class ProGen(Model): print(f'll_mean={ll_mean}') def sample(self, model, tokenizer, context, max_length, num_return_sequences, top_p, temp, pad_token_id): - input_ids = Tensor([self.tokenizer.encode(context).ids]) + """ + sample method of progen model + """ + input_ids = Tensor([tokenizer.encode(context).ids]) tokens_batch = model.generate( input_ids, do_sample=True, @@ -211,7 +217,7 @@ class ProGen(Model): pad_token_id=pad_token_id, ) as_lists = lambda batch: [batch[i, ...].asnumpy().tolist() for i in range(batch.shape[0])] - return self.tokenizer.decode_batch(as_lists(tokens_batch)) + return tokenizer.decode_batch(as_lists(tokens_batch)) def truncate(self, input_sample, terminals): """ @@ -224,8 +230,7 @@ class ProGen(Model): pos.append(find_pos) if pos: return input_sample[:(min(pos) + 1)] - else: - return input_sample + return input_sample def generate(self): """ @@ -258,8 +263,8 @@ class ProGen(Model): with PrintTime('sampling'): completions = self.sample(model=self.network, tokenizer=self.tokenizer, context=self.args.context, - pad_token_id=self.tokenizer.encode('<|pad|>').ids[0], - num_return_sequences=self.args.num_samples, temp=self.args.t, + pad_token_id=self.tokenizer.encode('<|pad|>').ids[0], + num_return_sequences=self.args.num_samples, temp=self.args.t, top_p=self.args.p, max_length=self.args.max_length) truncations = [self.truncate(completion, terminals=['1', '2']) for completion in completions] -- Gitee From b3dd72f5f6d0401ce120029c2822c325bdc16446 Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Thu, 5 Sep 2024 21:39:12 +0800 Subject: [PATCH 10/16] pr modification --- .../progen/module/configuration_utils.py | 50 ++++++++----------- .../models/progen/module/injection.py | 3 ++ .../models/progen/module/logits_process.py | 2 +- .../pipeline/models/progen/nn_arch.py | 23 ++++++--- 4 files changed, 40 insertions(+), 38 deletions(-) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py index 3d7c2f61e..dff125af1 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py @@ -37,7 +37,7 @@ from dataclasses import dataclass import numpy as np import mindspore -from 
mindspore import nn, ops, Tensor, Parameter, jit_class +from mindspore import nn, ops, Tensor, jit_class from .logits_process import ( ForcedEOSTokenLogitsProcessor, @@ -410,7 +410,7 @@ GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamS class StoppingCriteriaList(list): - def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + def __call__(input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: return any(criteria(input_ids, scores) for criteria in self) @property @@ -459,7 +459,7 @@ class GenerationConfig: raise err # Validate the values of the attributes - self.validate(is_init=True) + self.validate() @classmethod def from_dict(cls, config_dict: Dict[str, Any], **kwargs) -> "GenerationConfig": @@ -514,7 +514,7 @@ class GenerationConfig: unused_kwargs = {key: value for key, value in kwargs.items() if key not in to_remove} return unused_kwargs - def validate(self, is_init=False): + def validate(self): """ Validates the values of the attributes of the [`GenerationConfig`] instance. Raises exceptions in the presence of parameterization that can be detected as incorrect from the configuration instance alone. @@ -537,7 +537,6 @@ class GenerationMixin: @staticmethod def _expand_inputs_for_generation( expand_size: int = 1, - is_encoder_decoder: bool = False, input_ids: Optional[mindspore.Tensor] = None, **model_kwargs, ) -> Tuple[mindspore.Tensor, Dict[str, Any]]: @@ -557,7 +556,7 @@ class GenerationMixin: return input_ids, model_kwargs @staticmethod - def prepare_inputs_for_generation(self, *args, **kwargs): + def prepare_inputs_for_generation(*args, **kwargs): """ prepare_inputs_for_generation """ @@ -576,8 +575,6 @@ class GenerationMixin: if inputs is not None: return inputs - encoder_outputs = model_kwargs.get("encoder_outputs") - if bos_token_id is None: raise ValueError("`bos_token_id` has to be defined when no `input_ids` are provided.") @@ -733,7 +730,6 @@ class GenerationMixin: input_ids, model_kwargs = self._expand_inputs_for_generation( input_ids=input_ids, expand_size=generation_config.num_return_sequences, - is_encoder_decoder=self.config.is_encoder_decoder, **model_kwargs, ) @@ -822,9 +818,9 @@ class GenerationMixin: # auto-regressive generation while True: # prepare model inputs - model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) + model_inputs = self.prepare_inputs_for_generation_new(input_ids, **model_kwargs) # forward pass to get next token - outputs = self( + outputs = self.construct( **model_inputs, return_dict=True, output_attentions=output_attentions, @@ -833,9 +829,6 @@ class GenerationMixin: if synced_gpus and this_peer_finished: continue # don't waste resources running the code we don't need - if isinstance(outputs, dict): - outputs = ADDict(**outputs) - next_token_logits = outputs.logits[:, -1, :] # pre-process distribution next_token_scores = logits_processor(input_ids, next_token_logits) @@ -903,13 +896,12 @@ class GenerationMixin: cross_attentions=cross_attentions, decoder_hidden_states=decoder_hidden_states, ) - else: - return SampleDecoderOnlyOutput( - sequences=input_ids, - scores=scores, - attentions=decoder_attentions, - hidden_states=decoder_hidden_states, - ) + return SampleDecoderOnlyOutput( + sequences=input_ids, + scores=scores, + attentions=decoder_attentions, + hidden_states=decoder_hidden_states, + ) return input_ids def _prepare_model_inputs( @@ -1659,8 +1651,7 @@ class PretrainedConfig: if not self._attn_implementation_internal: # 
`config.attn_implementation` should never be None, for backward compatibility. return "eager" - else: - return self._attn_implementation_internal + return self._attn_implementation_internal return "eager" @_attn_implementation.setter @@ -1683,7 +1674,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin): # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary # warnings. _keys_to_ignore_on_load_unexpected = None - _keys_to_ignore_on_save = None _tied_weights_keys = None @@ -1712,6 +1702,12 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin): """ self.init_weights() + def prepare_inputs_for_generation(*args, **kwargs): + """ + prepare_inputs_for_generation + """ + return + @classmethod def _from_config(cls, config, **kwargs): """ @@ -2071,12 +2067,6 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin): if state_dict is None: state_dict = model_to_save.parameters_dict() - # Handle the case where some state_dict keys shouldn't be saved - if self._keys_to_ignore_on_save is not None: - for ignore_key in self._keys_to_ignore_on_save: - if ignore_key in state_dict.keys(): - del state_dict[ignore_key] - # Shard the model if it is too big. # weights_name = _add_variant(WEIGHTS_NAME, variant) weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py index f0350f4b7..7f0af1506 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py @@ -906,6 +906,9 @@ class BatchNorm1dMindnlp(nn.Cell): self.bn_infer = ops.BatchNorm(is_training=False, epsilon=self.eps) def construct(self, x): + """ + construct method of BatchNorm1dMindnlp + """ if self.use_batch_statistics is None: if self.training: return self.bn_train(x, diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py index 86345b5b8..19200f8ac 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py @@ -717,7 +717,7 @@ class TypicalLogitsWarper(LogitsWarper): last_ind.clamp_(max=sorted_scores.shape[-1] - 1) sorted_indices_to_remove = sorted_scores > sorted_scores.gather(1, last_ind.view(-1, 1)) sorted_indices_to_remove[..., : self.min_tokens_to_keep] = 0 - indices_to_remove = ops.tensor_scatter_elements(sorted_indices_to_remove, + indices_to_remove = ops.tensor_scatter_elements(sorted_indices_to_remove, sorted_indices, sorted_indices_to_remove, axis=1) scores = scores.masked_fill(indices_to_remove, self.filter_value) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py index f9c5ae24c..89afadd24 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py @@ -279,7 +279,7 @@ class ProGenAttention(nn.Cell): def _split_heads(self, x, n_head, dim_head, mp_num): reshaped = x.reshape(x.shape[: -1] + (n_head // mp_num, dim_head)) - reshaped = reshaped.reshape(x.shape[: -2] + (-1, ) + reshaped.shape[-1: ]) + reshaped = reshaped.reshape(x.shape[: -2] + (-1,) + reshaped.shape[-1: ]) return reshaped def _merge_heads(self, tensor, 
num_attention_heads, attn_head_size): @@ -409,13 +409,12 @@ class ProGenPreTrainedModel(PreTrainedModelMindnlp): An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained models. """ - config_class = ProGenConfig base_model_prefix = "transformer" is_parallelizable = True - def __init__(self, *inputs, **kwargs): - super().__init__(*inputs, **kwargs) + def __init__(self, config): + super().__init__(config) def _init_weights(self, module): """Initialize the weights.""" @@ -431,12 +430,16 @@ class ProGenPreTrainedModel(PreTrainedModelMindnlp): module.bias.data.zero_() module.weight.data.fill_(1.0) + def prepare_inputs_for_generation(*args, **kwargs): + """ + prepare_inputs_for_generation + """ + return class ProGenModel(ProGenPreTrainedModel): """ ProGenModel class """ - def __init__(self, config): super().__init__(config) @@ -455,6 +458,12 @@ class ProGenModel(ProGenPreTrainedModel): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings + def prepare_inputs_for_generation(*args, **kwargs): + """ + prepare_inputs_for_generation + """ + return + def construct( self, input_ids=None, @@ -685,7 +694,7 @@ class ProGenForCausalLM(ProGenPreTrainedModel): def set_output_embeddings(self, new_embeddings): return - def prepare_inputs_for_generation(self, input_ids, past_key_values=None, **kwargs): + def prepare_inputs_for_generation_new(self, input_ids, past_key_values=None, **kwargs): token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: @@ -722,5 +731,5 @@ class PrintTime: print(self.desc) self.t = time.time() - def __exit__(self, type, value, traceback): + def __exit__(self, input_type, value, traceback): print(f'{self.desc} took {time.time()-self.t:.02f}s') -- Gitee From cd55609574c6e753679cdb9a93691ddc7bba3b6d Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Tue, 10 Sep 2024 15:11:25 +0800 Subject: [PATCH 11/16] pr modification --- .../models/progen/module/configuration_utils.py | 8 ++++---- .../src/mindsponge/pipeline/models/progen/nn_arch.py | 10 +++++----- .../pipeline/models/progen/progen_dataset.py | 9 +++++++-- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py index dff125af1..c0be414fc 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py @@ -39,7 +39,7 @@ import numpy as np import mindspore from mindspore import nn, ops, Tensor, jit_class -from .logits_process import ( +from mindspore.pipeline.models.progen.module.logits_process import ( ForcedEOSTokenLogitsProcessor, LogitsProcessorList, MinLengthLogitsProcessor, @@ -410,7 +410,7 @@ GenerateOutput = Union[GreedySearchOutput, SampleOutput, BeamSearchOutput, BeamS class StoppingCriteriaList(list): - def __call__(input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: return any(criteria(input_ids, scores) for criteria in self) @property @@ -1702,7 +1702,7 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin): """ self.init_weights() - def prepare_inputs_for_generation(*args, **kwargs): + def prepare_inputs_for_generation(self): """ prepare_inputs_for_generation 
""" @@ -2147,7 +2147,7 @@ class StoppingCriteria(): """Abstract base class for all stopping criteria that can be applied during generation.""" @staticmethod - def __call__(self, input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: + def __call__(input_ids: mindspore.Tensor, scores: mindspore.Tensor, **kwargs) -> bool: raise NotImplementedError("StoppingCriteria needs to be subclassed") diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py index 89afadd24..a7c621bb0 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py @@ -413,9 +413,6 @@ class ProGenPreTrainedModel(PreTrainedModelMindnlp): base_model_prefix = "transformer" is_parallelizable = True - def __init__(self, config): - super().__init__(config) - def _init_weights(self, module): """Initialize the weights.""" if isinstance(module, (DenseMindnlp,)): @@ -430,7 +427,7 @@ class ProGenPreTrainedModel(PreTrainedModelMindnlp): module.bias.data.zero_() module.weight.data.fill_(1.0) - def prepare_inputs_for_generation(*args, **kwargs): + def prepare_inputs_for_generation(self): """ prepare_inputs_for_generation """ @@ -458,7 +455,7 @@ class ProGenModel(ProGenPreTrainedModel): def set_input_embeddings(self, new_embeddings): self.wte = new_embeddings - def prepare_inputs_for_generation(*args, **kwargs): + def prepare_inputs_for_generation(self): """ prepare_inputs_for_generation """ @@ -695,6 +692,9 @@ class ProGenForCausalLM(ProGenPreTrainedModel): return def prepare_inputs_for_generation_new(self, input_ids, past_key_values=None, **kwargs): + """ + new method of preparing inputs for generation + """ token_type_ids = kwargs.get("token_type_ids", None) # only last token for inputs_ids if past is defined in kwargs if past_key_values: diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py index aa6cf5939..6ccff4dbd 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py @@ -21,6 +21,7 @@ # limitations under the License. 
# ============================================================================ """progen_dataset""" +from .nn_arch import PrintTime from ...dataset import PSP @@ -34,6 +35,9 @@ class ProGenDataSet(PSP): return data def set_training_data_src(self, data_source, **kwargs): + with PrintTime('set_training_data_src'): + print(data_source) + print(**kwargs) return None def create_iterator(self, num_epochs, **kwargs): @@ -42,8 +46,9 @@ class ProGenDataSet(PSP): def data_parse(self, idx): return None - def __getitem__(self, idx): - pass + def __getitem__(self, **kwargs): + with PrintTime('get_item'): + print(**kwargs) def __len__(self): pass -- Gitee From 13bf5ff0a0c150f3dc21419159c8eb622d702285 Mon Sep 17 00:00:00 2001 From: zhang-yucheng2024 Date: Wed, 11 Sep 2024 11:46:39 +0800 Subject: [PATCH 12/16] pr modification --- .../models/progen/module/injection.py | 219 ++++++++++-------- .../pipeline/models/progen/progen_dataset.py | 1 - .../mindsponge/pipeline/models/progen_v6.zip | Bin 0 -> 57127 bytes 3 files changed, 125 insertions(+), 95 deletions(-) create mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen_v6.zip diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py index 7f0af1506..af228a1a7 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py @@ -211,22 +211,12 @@ def einsum_label_to_index(label): num_of_letters = ord('z') - ord('a') + 1 return (ord(label) - ord('A')) if (label.isupper()) else (num_of_letters + (ord(label) - ord('a'))) -def einsum(equation, *operands): - """ - einsum method - """ - assert operands, "einsum(): must provide at least one operand" - if isinstance(operands[0], tuple): - operands = operands[0] - - arrow_pos = equation.find("->") - num_ops = len(operands) - op_labels = [[] for _ in range(num_ops)] - lhs = equation[0: arrow_pos] +def enumerate_lhs(op_labels, lhs, num_ops): curr_op = 0 found_ell = False ell_skip = 0 + for i, label in enumerate(lhs): if label == ' ': continue @@ -247,17 +237,14 @@ def einsum(equation, *operands): op_labels[curr_op].append(einsum_label_to_index(label)) - assert curr_op == num_ops - 1, "einsum(): more operands were provided than specified in the equation" - # Labels must be within [a-zA-Z]. - total_labels = 52 - label_count = [0] * total_labels - # The maximum number of dimensions covered by any ellipsis, needed when - # unsqueezing missing dimensions from operands to permute and broadcast - ell_num_dim = 0 + return op_labels, lhs, num_ops, curr_op, found_ell, ell_skip +def enumerate_operands(op_labels, operands, label_count): # Compute label frequency and number of dimensions covered by ellipsis # We do this after parsing labels to make it more readable and simpler # to compute the number of dimensions covered by ellipsis. + ell_num_dim = 0 + for i, operand in enumerate(operands): labels = op_labels[i] ndims = operand.ndim @@ -279,62 +266,10 @@ def einsum(equation, *operands): f") does not match the number of dimensions (" \ f"{ndims}) for operand {i} and no ellipsis was given" - # We want to align the dimensions of every input tensor to have - # shape out_dims + sum_dims. For this, we create a mapping of label - # to index into the permuted shape. 
- label_perm_index = [-1] * total_labels - # Current index in the permuted shape - perm_index = 0 - # Start index of ellipsis dimensions in the permuted shape - ell_index = 0 - found_ell = False + return ell_num_dim, label_count - if arrow_pos == -1: - # Implicit output is ellipsis (...) + labels seen only once - perm_index = ell_num_dim - found_ell = True - for label, label_count in enumerate(label_count): - if label_count == 1: - label_perm_index[label] = perm_index - perm_index += 1 - else: - rhs = equation[arrow_pos + 2:] - ell_skip = 0 - for i, label in enumerate(rhs): - if label == ' ': - continue - if label == '.': - if ell_skip != 0: - ell_skip -= 1 - continue - assert not found_ell, "einsum(): found \'.\' for output but an ellipsis (...) was already found" - - ell_skip = 2 - ell_index = perm_index - perm_index += ell_num_dim - found_ell = True - else: - assert str.isalpha(label), f"einsum(): invalid subscript given at index {len(lhs) + 2 + i} " \ - f"in the equation string, subscripts must be in [a-zA-Z]" - - index = einsum_label_to_index(label) - label_perm_index[index] = perm_index - perm_index += 1 - - out_size = perm_index - if not found_ell: - ell_index = perm_index - perm_index += ell_num_dim - - for label in range(total_labels): - if label_count[label] > 0 and label_perm_index[label] == -1: - label_perm_index[label] = perm_index - perm_index += 1 - - # Here we unsqueeze missing dimensions to make all operands have the same - # number of dimensions. We take diagonals for repeated labels within the - # same operand. Finally we permute the operands to align dimensions as - # per the perm_out_index we computed above. +def unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels, + ell_num_dim, ell_index, label_perm_index): permuted_operands = [] for i, operand in enumerate(operands): perm_shape = [-1] * perm_index @@ -388,26 +323,49 @@ def einsum(equation, *operands): dim_last_op[dim] = i has_zero_size_dim = has_zero_size_dim or (broadcast_size == 0) - # Compute result - result = permuted_operands[0] - if has_zero_size_dim: - out_shape = [-1] * out_size - for i in range(out_size): - out_shape[i] = permuted_operands[dim_last_op[i]].shape[i] - return ops.zeros(out_shape) + return permuted_operands, has_zero_size_dim, dim_last_op - # Sum out or squeeze dimensions that are size 1 for all later operands - dim = out_size - for i in range(dim, perm_index): - if dim_last_op[i] == 0: - if result.shape[dim] == 1: - result = ops.squeeze(result, dim) - dim -= 1 +def einsum_operate(arrow_pos, ell_num_dim, label_perm_index, equation, lhs): + # Current index in the permuted shape + perm_index = 0 + # Start index of ellipsis dimensions in the permuted shape + ell_index = 0 + found_ell = False + + if arrow_pos == -1: + # Implicit output is ellipsis (...) + labels seen only once + perm_index = ell_num_dim + found_ell = True + for label, label_count in enumerate(label_count): + if label_count == 1: + label_perm_index[label] = perm_index + perm_index += 1 + else: + rhs = equation[arrow_pos + 2:] + ell_skip = 0 + for i, label in enumerate(rhs): + if label == ' ': + continue + if label == '.': + if ell_skip != 0: + ell_skip -= 1 + continue + assert not found_ell, "einsum(): found \'.\' for output but an ellipsis (...) 
was already found" + + ell_skip = 2 + ell_index = perm_index + perm_index += ell_num_dim + found_ell = True else: - result = ops.sum(result, dim) - dim -= 1 - dim += 1 + assert str.isalpha(label), f"einsum(): invalid subscript given at index {len(lhs) + 2 + i} " \ + f"in the equation string, subscripts must be in [a-zA-Z]" + index = einsum_label_to_index(label) + label_perm_index[index] = perm_index + perm_index += 1 + return perm_index, found_ell, label_perm_index, ell_index + +def sum_result(num_ops, permuted_operands, out_size, perm_index, dim_last_op, result): for i in range(1, num_ops): operand = permuted_operands[i] sum_dims = [] @@ -430,10 +388,83 @@ def einsum(equation, *operands): result = result.mul(operand) elif len(sum_dims) == len(result.shape): result = result.flatten().dot(operand.flatten()) - else: - result = sumproduct_pair( - result, operand, sum_dims, False) + return result + +def einsum(equation, *operands): + """ + einsum method + """ + assert operands, "einsum(): must provide at least one operand" + if isinstance(operands[0], tuple): + operands = operands[0] + + arrow_pos = equation.find("->") + num_ops = len(operands) + op_labels = [[] for _ in range(num_ops)] + lhs = equation[0: arrow_pos] + + op_labels, lhs, num_ops, curr_op, found_ell, ell_skip = enumerate_lhs(op_labels, lhs, num_ops) + + assert curr_op == num_ops - 1, "einsum(): more operands were provided than specified in the equation" + # Labels must be within [a-zA-Z]. + total_labels = 52 + label_count = [0] * total_labels + # The maximum number of dimensions covered by any ellipsis, needed when + # unsqueezing missing dimensions from operands to permute and broadcast + ell_num_dim, label_count = enumerate_operands(op_labels, operands, label_count) + + # We want to align the dimensions of every input tensor to have + # shape out_dims + sum_dims. For this, we create a mapping of label + # to index into the permuted shape. + label_perm_index = [-1] * total_labels + + perm_index, found_ell, label_perm_index, ell_index = \ + einsum_operate(arrow_pos, ell_num_dim, label_perm_index, equation, lhs) + + out_size = perm_index + if not found_ell: + ell_index = perm_index + perm_index += ell_num_dim + + for label in range(total_labels): + if label_count[label] > 0 and label_perm_index[label] == -1: + label_perm_index[label] = perm_index + perm_index += 1 + + # Here we unsqueeze missing dimensions to make all operands have the same + # number of dimensions. We take diagonals for repeated labels within the + # same operand. Finally we permute the operands to align dimensions as + # per the perm_out_index we computed above. 
+ permuted_operands, has_zero_size_dim, dim_last_op = \ + unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels, + ell_num_dim, ell_index, label_perm_index) + + # Compute result + result = permuted_operands[0] + if has_zero_size_dim: + out_shape = [-1] * out_size + for i in range(out_size): + out_shape[i] = permuted_operands[dim_last_op[i]].shape[i] + return ops.zeros(out_shape) + + # Sum out or squeeze dimensions that are size 1 for all later operands + dim = out_size + for i in range(dim, perm_index): + if dim_last_op[i] == 0: + if result.shape[dim] == 1: + result = ops.squeeze(result, dim) + dim -= 1 + else: + result = ops.sum(result, dim) + dim -= 1 + dim += 1 + + result = sum_result(num_ops, permuted_operands, + out_size, perm_index, dim_last_op, result) + + return result + ops.einsum = einsum # conv1d diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py index 6ccff4dbd..5645e174c 100644 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py @@ -38,7 +38,6 @@ class ProGenDataSet(PSP): with PrintTime('set_training_data_src'): print(data_source) print(**kwargs) - return None def create_iterator(self, num_epochs, **kwargs): return None diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen_v6.zip b/MindSPONGE/src/mindsponge/pipeline/models/progen_v6.zip new file mode 100644 index 0000000000000000000000000000000000000000..6ab624ef7199e8d00f39b5cec144f70b402abb98 GIT binary patch literal 57127 zcmV)UK(N11O9KQH000080Cc7!S(YJHs!aj_09XV7022TJ0B~||XJu|OFJE72ZfSI1 zUoLQYty9r%n?MkJ_gAclR3(ywqo_}*D%DBc1l7e>a2n;QI>Q31!0q+!C~?2OdziMd z(!Lc50k^~K%-r4)T(Q|%s%j7CXXkI>17tPoA9BN`+QCw7*(%R=2X7`v@akIQ6X^=> zHGN0x$o~U&+o>K2Hkj9%8?FeQG>)gR^u^>z{@B#yAm;)>o(hbz9=L?hq>F65N^cey z0vf&yTZD|-u2G&?!&>D}szUIqUfIGyiq9TyP@~<~TW&BHu$8w0n?~fuL(4zhp(h*E zr;a@pxu)hmCIwSOyNsNPa2)aXK)+<(O8O3-U-8wJ7{G)G<)crm<@52)F9wZFN zNv|~xDJ=$i^7YVWBoV{#I8yM|D?uy4n_bBG1rC%S1P4E(9W@Az*;cb=4|20*aF=n8r6A`mZ99Ju z)GuY?_a#80QW~Op21ypdg>1|H^jck^_UEHYT#!PP3gPBzIZm9LL5 zPaok^l3q^%B_?Tu&&CG*B9cOcScJSX!~+2(?*%!7IhB-)FI{zE@+y|u*Wnl#v@N-? 
zkutx`J;D`YLB>(^baB^>FWo|Mr zZEs|CY-KNFZ*FF3XLWL6bZKvHUv+e8Y;!JfdF_2^cjHEq;P?6!+0>gAS*BP=w`cch zhhEE7vaRj9UP*3GuUZd;NKk@B0&D=((R!NSzH#K836QGt(KBPbZWjgQ5g8d75g8c~ z+3WS5Ov_@PU#zRN$;%>HH~FG|(u4Oq;#Il2tMZG>M!bCf@=xMxaeSE-|A7CCZ_+#Q zCOt2qQd!-J{{e-A=R7(BEQb6I&WHveZNj`vRs^G)61e_Cl^%e)rTa+V4BHx;Y2YD75~Z}Vby{Neq(?+*8&5dF0(t4ti{%k_fBB0i+ktMnoh z&HJ0Qn$_Z6cH4;WvLfTLieee*-cqpxzO}4n`Tw-KYMm_b2D1P>l$B=rp2=X_4Fp6W(Bsi zE@oN9jjhrt{3E{%#a|G6(9z4$bJ52FJ^86O_}|dNUAY!Zn2w@s#JYyTBD&`JA`{u| zG+Q+yF93VX)gn&|7{E>5Tv9i(5g4%et8Ah?Z&E0oLiyDl;@B!D(uM%W@yOw%1n8)& zE}kv8Q2p%9(W}FE$A^0WEEWHuSY&lAs_gIUyn?As{DN2IEYe^!FcTgfU2|*y@O+Mblel)92_4V521}8k50aS z|HFy+@!;dfgLfxKhsWamNAc?YyVpl2NAKUkU*C#@cYhUsK6>|hC^A43bd%k#Duf?^ z%n^mz4DofGWex!Ik^!n$*)*T$Qy5oqu|~?fD1r71Tn4MGTIMxQBoJ`{lr0wdGH-|w z>u^xm&G=`0o&fkIEF7HKC-bUY@*K;zvn*TD7y0*3by=w2Wvza}WZ+Qr-+6%pZn(MT zZiT~?pWd%Vpg|))~}wZkwMphoU|6jO1;u+ zmz&Xlsz!=p2#lhJQNi?;)lmGDH;HW3tub2Vs| zAKsiKuTTE^;ZTe%q|O&*+Whh5lcRS>C&`b8N8f#aa-4ko=HNRhbyBUfCr_Ttvbjj| zI?wrWdC+(HO_N7;3n)AlJq*wiL(39l+}`{dyxShltlxw zr@uSPu3>#AOmCpL}bul$VC`y5EfzN{=$STz-R*8GB%3ocqeg_!hbmj4j)d|I0TBoYQ4)xGMQwg1_^g3e1p2NlvVCWCrpTB^ zI>NODGS)&z(v-=Uv%da2GTwpHW}a1#qj6MJ5a-EP(<8q#yW5*N%{kMP|{1QTd?wf%WD5QRRXLW?+mVp*8Xau?~RF6RE2qFG* zDCRx$8sGkb-l8bMm45Hmc`&mYA06QGZk6?C_}?)tX!=P|*RWSXu~rHa?uvnkFT{NH z;$LY~104GtdFR~*djA^KDPYO698y1ZL!?gZ%$N0`fMgDl>)#Vol-ii=!wxm$_SBw< zRXNcIJg%3h7B`mwKgwb2E3in|GB&tYF4Jq~mNmTLrUr2?DTjyrU)N8pv1p<57In+_ zB%E+JS}az5OKJ8rif>>uE4Us~1OiIBsg|A~jI@?K65o?-nj{H;kvG{gNhW%PZLs`4#LYsow^700a64+c(%+sIRc0BZWV$?grfSO%6ci z7ey)M$^;qR1a_rz-v?m5$1U4U4nsN55Tmt_jaAtN?((V2)m?K5J4Dm|h$-aXPV$x7 z;f(Ja;QJ3>@fX}3!|K_E?R5b=-YXosh206XRl)XTaW}%92myCG`46(AjI44aV6HUY z$(ev@yDVwLD|xWU8~iI}j8M1d*&Jz1LiY^7`%2H5;o-@u1a6cUu-QnM0u0sCY$d$; zLH&n*C^Aunc1TZb?r>9z(A2PD$cpKmfUW)XsEj4SnXVV8fRZ*q^hsaHu3n-1L@fcvK-NK` zJY9T*{G#@h9(=4aiB*-hZg#m)o5H85?!uS&SWV$iXN1uT+roj5G_awpAo_}f(9$C7 z(@u>YMDqW^F+@3m zb|n@d@tIT)?}r5_nyNbOyI*coDeM|n@erj|yw8?l;)5$_c8fhfZ!d(;jyNJJn3aYZ z*@OT(2F&s^nq<&9aXLYIph#A-u(Nq&D5o*Lbmp$q)<)th(01-+B;GL!MnfG3$1>z0yqVYsjdM~H&VI+7=M=^PRTo+Bg)GCJ57u9*pWMZ{Anb7fsXjelN zztJiN!UTjP^`M^#B-7w0LbNJZKNBd>(kH^OUd>R%{5)WAW1k3#wG)!yTv-6Py3-Yl z4^O=+3f#wmCI!TZ|4=hZS#k+o*{XGw(<`%{=0$2a2*JueI*fsk$ zoKKRH#e4LWB#Jt*$Z%ph%>!YjE4#R}Fg6nBGo9*i4#Wx6oi?Po|1a@P+VB_ssi zl6TrI7_w_Na}JSNOXAX${^ZHm<~)u~67s}8d!3;dK=r;@-0_yW9~wD4!^V4QzW|+S z#PssoFsLR1@K#N%&?pKdk+53Y(nxDlx~-1s zKP+eh6%{<50y1u3Pg25vF(QzoOFD>Qy9Tr%+Bka$e(9fL6R^)#gYiY;-Mo!~Tf%u+ACU|6|V`*H$ zP8xmiuwm*J<*;hF4|9M%IQex3GmT$l*8@86!L98S9W-G7eVwOv0c&w-TT+w*sELzu zKhkBm=yi+H6==0K`1NS-cmDWG@yBpgeFJjyw_l?ZJD zCGzR^9ewV!f9`a5&G&D`NzJTpVyD!hVX40yJ8fqg|7`5^pPk@;XWhRT>uy=gSz1jm ze=FWio&V0ee@)&^DF1DFH_!SuiZFHiOLFf2Jc{r?JGuYPxqmUv9ooq-i01CMV%-$N zbN6}7`%|d$bEo4Mord3vqq>vvyEgobQ{nG??$_gUugapS(i&;xmsw?7@w=&k;@4+- zzmwZ9Ms8cs+pjNNe^X-n+?o088L=fQBN@I?)adCZVLR;Biq4*})yiwN-IO(KJFPi; zV)jv9Yuy{HVb_LRwi<)6+vGOvGQplb2}s&&d@B|^++t;%JyEB(9TVrGHbhG}q*ee_ zuPO@SxG7f{m;9X)?XDEo9E!K;EyoBev_{-s*?mU4Mmyh?lTy(1$P7zQp~=_qiV{`0jP(_j zAvcuE65|{&d4ZMBg5z0i#@b-mk<$hfffYj$btBsmA56dwbx2&fWXBj8(H zdAOLN!6M++N9aY@Gaq_G@$&f~tX5<e54npJU~0>z&%G7{etNm<#%g@Vi4yA~Ic;Cl-aWI1<>6K%>V(%l+!oOm<* zt))IMTyi>u^^&Nudo8t;GHqH2KgN*66xI;a8LuneTyhk2IC{RwoC?U|jVmMyHi=2;R2T}U)^H)$LqF!nsU8{obD*$JQVydq zvZ^HxXi;8VWYviLI{O%vT5Njfan>nMw|~hKT;kU z+HlG6T2|>RJ%dCpJG^PpNEAXo$y0QZJKZ|D)} z8cojT8*v(v$(f`R@xhPY#1Wc6+PLp2rjRPpm$$lDWBfx+$&Su30+)7oSk~YNuWm+%j5Kg3nJZ zs3pq+=Ipxa0nQasl5uKptEHH2DMvV>q#t5{Q%~`u{LT#BbxS==UX2}NZ2P*5B z`AIBcRzKRkDS%lFTE&QJCz5$m_9Eo9OG+z}-NG)O)Cw*R<%b^X?Z5?|RmD)Fa01mx z);<=eCZgtX+6CtMNNZPzh$tZ2>lqrV5lWY0hCv^W3MgR?BQ%&{-(;YNN!CBqnjy?? 
zKELD46d9X5NT{}QzkBF*Uzk|RE{q@&;JZvQWk&GeFAb*?LtI)bAA97P-?OA5d?=(HuUs=7*wQ!}@|2tbVI(L7g?E@7VmDEeP(4vZzze6P5(r<;5(&&Sz_9yJB;& znCU>0(LJ(Pxd$4GUXg)d>OI8x+F$I14->bwiCC@yeH8PHj4{Lf4wcxY=5~Qut|`^T zeFTsP3i%7HQjx4)(IWKaKz^ILp;b!w%JeIs3hq_e1rT427Jp0~XQ6x;&k}|>BDs9R za+%H~lQaqR7=0%nq; zSqInxES~GE@8D~eFUQYaAtrgf+n7Np2FEIbX^$v1MaJg$mcY2|&4{$FX|#)mDxGD# zw+nWyuob7oYgu9K!OOPH>ue)hgZUwdlOsGe!M$%bi{w%09Vq(^F(;iFLE1i3G3o4j zqVgH#$V+FK$hPFDYHb{~k|}hV8WNCK=UJl7MzTXhbthB84OB=kk7q!()o`deP1m)vHXReG z%sjL~=JqCbr3vMFVu{7{G~+1Cq7)N7yA#rERNfmWy%Hr8LB)jv&PgF+mX72$v1fp- z7^sQm#ayOSCe|*9wUYqI8@b`3u7-04M>z@W7S#7fDVgicUD3^{; zPwvW<9!bR^a=MOx*T1k?Q~wVL=S7UT<*k zQkR_f&vXf^eMKAlMkY457Wz=`)!_t&VhJs5B2XS68Z`oew5!>4PB|H^u{eX&(`E3czayhj=WDrJCfu}e4_L0X zFMM!iVoh+mnDyvNxS4E}o2UboLI zHV5@sBH-SF^?{5_2=u)@=m+aCuNL?GmJ?pmDU3XGYRWiR+@Y9^l&K9p6ghu>_!VybSjURXx1V>vf93TMH*_aLB=fuT z#%L1ia6fRTvUpj9i^VshC&qTq0Q0$1r2gcLxR+exLc2|YaIeYoRb`9p8Yfy+|B(h} znP-iCkPKePjz}3V$|X9ExLGU|r23B?*$y3#3C3FOa#Zzl9790g-H8;!P6nltCsdYxZrSbJW|_0@OmnKn_|L?;i*$)?saGb$FgNz}whGQ1tUr zHsXFTnI%cgV1^yPl)*rJA*?*)`8*_v5#}++QPturKg+tyc)R-=3d^t3e$ZQt-Fwjf zL*S2T02^RO(4EHOi6^7YztqLTO3?~0>}f_;GcQr6#WfEMB}GxT07ijPDbsS5BO$U; zA3l8exG&T-Pe4GI>+{jHTt2&8UtBX>qak%ueLqdZcWm+OK|D-#P}lwCSeMMtYU<9-&Mci+=KGUaqGEyMgGkn+Fptr{Uykl*!WaG+m6xcED*O6G9ziV4zNMK zexjK%B9?UvzI=|i;9e3{R#pQMAr^1LcfU@pMxuL{mR3Mz1h!^X?qRgcv3%caJL(Az zdBVVeAe>MvQr>Vg2?21(17uI(=jR5kop&_wvo z;%&5zjO&7Wc>dafe!_QnpdG&Sdy2-WI0*UUNWA8w7@iPmr&k5!p-dwUk-iOS=u-vH zjk2;@#cP+}@w;JU!p|MOujR08NBl{{&7KEQaqI4XT>u%3=j_+(T%=tjL zuzaRv`Mge47^SkzMooBQWvGPm%kykj_guN?0&Joi5K8CTfru*GQ8lvm)pvW2z8vZ6 zx-2WL^x>mRSCp$k%r(Ju@;FR++YHO#=n%89o{fF5d?Egmh2xv^{Sn@X^jo}vmybQe z71$yoxA%(DEg7A`YLiR@=$quA^kg^!CTP!y_lAs37OVCRJAPc6Vdl7yqhc|ym| zUmNxvA(gom>_Dyv^6bwZ(k(S5Qqi$qITf8)xJXUsyIZv-_$|^%OUd~5YOH=diY&<8 zJTl`gJ6WWGNNjT_Rs424Plv=l1q>sRw(<3`$!weQHg#F$d6!Gqts-PpKjIaHPoU#k z&7kjUXRWrZ3mYp2IU?F?v{3L?bwAl|5E8s)J#fC6*Mst!F1qDYA^B8k-8eB?vXnP>u^*>sPJ%N}>C=p&<3*apdDgggbxcbMJ4z>9~E(Q89E!UI4iB zCv@cgy=6ysC7Vj^d*~e73y2Tz$jt-m)c*S{3cT8Ek1dADubj>Xn@8m^5kJ|0`!}@Z zz1R{(PA$JPC*Oj*Qo|Ak1pDRVz=#;220l=J@`fl}ZXpAj_x^Fv#cf&s`A00Dg9~Fz zIj0PO+}2#ozJO2H2sU-;Y8TLUWX>s-@$i|yX_b zLhfNYv_~~;s>K@jbEN%@O-B=_J*s`Drrm`-Lvv-jMn^~ArImtjt310lZlPv2Rd!13 z(V179L3V-p7*k{xEJ&3cG5chVE-CP2nt-x?C-8Lj{2XSDW zRYk55O@&upgEq_-VWS|U(MO%a%wYl6m)0@Nx=^o-K?6mAavJvyop=Xk+;7|fgnCp0 ziZEVe%xw#Ki!K67di437qt&7ZzBQgLjMp&8MMhS4mn?R)ohy0*gO7cVBBTYWGi)X) zvKu=On7g$Zc>P0Nct2yuYVy0oD@<=2-laq&iUtBGA2wU-X%kYE>QbKc1mCSyHKy>( zphQ6Bw*F^ z+3|6fnba8Q>h;8F5obW^L(%}M^qhLcFHZM|V(-OJ?7MLlL{NKjkK|as+uGN{Dll$a zG!(!y*tR)oh>f=;rJgqe*q@-mnvGyqnn;UND{ey(Kk5#3e;V7&2*KR>I%8M@A33rw zfQ1G_7={7%F;*RKI3nS{i1BIl^c5w96^F!_H=HICi`I?Um7`t8-)`(CK304Jr2|K+ z!0|s02TwXs>sq5avFj4f;}GmJk2pf`*@%?Zn(mxz)1QZz9p5^S53J>LB_A>T40g>C z+7;<-UXS-)w5@5?f@OT_Q_HZH*GuUwSdU)}Mm31!>>pVleggAD@AW|Jx-e^rTf=m- zG_Nnq^f7%};IMS`7lXm5$l5t_toFJv>}CKeaXIS~IXw?oR78gCkP(s< zh%BWYh(sn@Rb$A+2M`)Z1TfYCdXPMivybQ=gPd@DCXPVmnCFnv#$+=Yw0N5@*Gst! z!^PqoOs$0{Rjd%KEkVlXv69FW$Q#FK%3%*~eD7APTCZsIwZmUV#Fkz2G)T2k~8&Qp9HAnAT_*njIKIOl5>Wa=-c-_>OH4k`r_yOLh@5Xb2z@a!rbdjmu^qhm9{XaCtm~*j!EZ~& z5;6&voeq@JQFz8oby7^4@hqRuvkJ%#qvtd0B$-&VtZDPlPxGB*$4IV-k@J?nq!AOI zf^~njFRnQ!(GY(4`6hCc%d@U? 
z#wG+kwaSqZGXzk1-~J};Yy7P>1fHFtu_Z@hIhFkw7irIH6@;6^i03^uj(u^@zaA!> zme0;Vi*c*+JL6OuTxHLcVF74QglEU|-dnSYDMgkO*2zw&`olT1+a8ZbUd-xN7>Q7x z<+95&X-(xHM;I_^=h35#>9RQcLQk~sgIqTIRs(xV2yreIxvu2Kb6SElp(GJ=lX(JJ z{Bp&iyo3R-Hx8xNY{cpd(@>y|k5(-j@}$8PH?TV63r>I!Kj7yIGeHO0Xo!B$%VSzn zvPwVCjM<9y{WcxHYnhAiMeT${;LeWn{DDqcazwf9z-`T^-L#W9L;Itjc9)&mH9mps;e*b_Zco;fjpIPTZY-`R<^|n5HmdWBg5neVn+^D|nK7BC zud+Kklt(>Ldn1mv3;Hch8_`nu%zJ+=*NVINwB-l6drLybb#ns5ghP~)1b*RkC&@#E z0ja0A^%^?PIAH3#4<|3Zd`cW_vVT|J{9#i_zj}6<$K@EWi-{?%HHfSM^PcLonJ#}AV)pP9;JBfasb$JW*8U>D zQmH^PMoBr3blY&IQ+Y_bOzSJYmaoDX<&t>@C{37a(^6taK}*i5qO+FIv=!zP)3T|D zDR3y-!-mH89~%}?JPcfmt!}FcKB0}UW4Rbbgy+;>J|pn~BmP0g<#WpkHo65X6vxV)SMu-BXY9JEo2BipT?0De;rP84E9?p0ujQyMU$DWl|;_~qX zuz9ZIJ0}N2v3x=@^Cs;C5|}I}i8h=3WIAPO|JxUD@f&Y4-}1F=7~=pliahBbG8hU6 z0BetvZmSdJmJE!td!!Ya)t8qR_ZY20kb-bXWX=-L(-P*a=M|w`RB8+kDTQsRg9z`7 z0T|Ox^_HN?e@d%Z&ufYbpRcD^SyL+~gHLEV(u_^a-z-YN?L5ETx95xZe|#knPntPb z%s?MgfogiS5Y#e!`PzWAik=;OBI|1GOqeSamM8`if8xu6De^u%UVnl3it z@sm(z-LClf@&63yv-*D&3WF{Hr|+_x6KX1s+@K0RF|HQFvl@tv-yt*$OC-V%iUch9GZ!HNX&t2X6q zst?v&MpG4Gzb#v1nl(%pO`_g6C!*R=E>|$f^M$-AzNlRKtSEa!|kog zTqfvjtov#z%E-h$@Rb8`k?$CoQ{$B2CsW_@x&&(e(OLCta#KG`&nQiPj7nn4?E9CcZ+GZ(7JR8;uKlMjm7Cr$b7(N` zbOcBjR`Jo+ZbH)+ zBfNlHx%%D6*i^f@RasWSJykaDTXM;?C=2cl)cLkqWyWLi(sQ%gfj8KIoy*s%Ays+P zs=7SSZsVpq7ylVl>wMKkT zC|*FBQ2ln#Aa=tCL@p6QkdcZbj>x{vjV~;4EHRC{$ zeV!LIXe+;(-WMfwm2pz9=kxql_>THLOQmBWLq3yIIgoDE9?3fx0t%@zzIZmnUhSok z^5u0d@_G_)S!6Le7Z9mmt(?%M^|ld5V-tW@-s$W6;smyVn1*z8>ZCgFa@OeF9Q*Dp zorS-8#QNX3GJh}xC7Yj899oKs+&*AN_X|8Yh7kSzfijs<_GP@*B#k7n=WOnJ%0aDh zdXBblvG`F>e9<<`k6Lc(SL?LIH~nqoCZZhKZRU%Bk-L0)E_ymwC(9<(PL>%ih3uVU72B}1{|V{}j#&|plCh9J**D~lzmm6uCA0k42nuM+8U zGkCCBd*psjP&4{A4n53$Q1fhQai}=LNN43u!Jh8sBH912bT6wDyn$1@1x(Il-eC{B zJGk2S>>C8=^{e;qPCgzSpB(+=FgZRv`1tC3)TT%=OhIwI&XStaAv3(+eLOsT{a35t z1!YWj${ZiO{qW`xKXT9)zdblUJ~}=*e4TuE`0nuI!O7A4cUT0k%D~f8>p-w4!TrVu z^_#rZV;JkvJAlG#OEV3CBxv`$kMDo@;McgQ%Js@E`{c>j zKjjTuD%-Pu1lulY!y3*!PY6_m%JTK$=ZD{oxn7iR7c79N&uB5@hW*uMoP4RiImZe zt4N8Epm)RvlUE(7LB-uzqR`_pmUYMFGWr8_^YXa|ucOgA)$|9b=GU5)QgnEg^!pa< zc-M!!mEDgmEtnf|+{*fWZ)o4D+l!QlBR+QA z9BnGw*i@CFt2uyAyDQ^v=#Dc|rj3i0Q#fkTw;M16{9I7#dV7V}{$N z_ch+cd3FI)*4rb0NRMR{?NRb2XnfzyNr z*s>8^WDWV(>*_TDVcR9EL~oOy!m;A1YkCckWvOzl#Avu5CBF?Ra2KE*3{=o6b2i8@=gRklBc zOp9u(GhiUIs%dK{UMMDzNj&OU{}ySCg%VXLq2tS7h9gB{{i*Dblth;DXzek}Ai}`P zkzFz7*ji-EBw`@5KLzNPG2Mw>2mwb$O@RTGdT4uJWW44R7ZChRV8*|OBGZA@=MP~C z5}rjI(hZV5U=s5lvUpyd)*5U|^h%Uxpgq&UYty71W%?~ntL_YDX~eL$rR%9I&_oY`Cw;$x7q zN~j*@yRZ5V&$m9~ocG)t^A5~?H`=nyF)Xl2a*vYD+6hIT>`}@(cYd4dF2?s8D~`%E z<_cf-*aBG{o1-StTmW=c{gPlWG^#`LQk%Em)_X#EaHFkVISPW~Ax37E(|LoVjJaOotu{+ad~JUk)pR>|J3JvxLFaV`JgvNd+rg#&)=gOGl+7_5 zNBU#+aaE8GO130;rG^kY;-#pxpX##UBVN3hnwoS$UzF)gF>o^EV-3eolvXDO)Aw}~ zfKaqOD=5ajbu-`lU%de`04(IO8O#JSrVk^`Ep0htbH_+TP=IDo7>HFl zWOk;xIKxxCZ7LlEig6%GQi?IpHHq8Ks?49l{N6k*wqs z@hQ+fbWH&w*>|^Jf7?J9Lq<>HF(_~*C&QUt4MuB!Z`476mePl(bPtJtA zu!Wor=%#|U2(}}PwydU?^|Z<{^$dBI4Fwqs#l+6qO{P6^G)=EdP{e;yJ|VS_F5*2_ z)lK9~n4l?>D+yolQMPqRjacgWNASHk&N@l|!ySN7ScVpF2}E0z!{N(~P3EhJ?bj*N z3y#WMw(|6;HDq@Kt|H47*KiMoJr*Nr*E>=s#P*eeLh+nd4(3!>tjy64f8YpbIRxSl z5*27;%CeaoId2;nj8;gBg0U~Qc?(3W7N>TDZZL0P#XL1VoRQKJdFx++a_Y4?C*#=G zcCpPH7ksw2jET=Tjmc5(sxj=IJ&eS0P8nn`%Q~|RCAcZUUJCNzDLBa#CrYIYXoup? 
zc+|4}zop+^$u3t=Yl%ioU8un`D5^NNas9aMLu6^V8mfA>5xeTqkf|^$(iYWMD`-R) z^fXt)2mJ65eB@3fFXm;xN1R%%p?^UZQEMpO0MHmmar<@DhM@$*=f~SA6_#vsSGK} z32x=P$}V$Ux0osf-5_?Y6i~ZvK_!)}(DyFcRN12B$O;{4O(RZzpnGS%^4`n7f$-@90rfBIsQ z*& zQnno_%3UdK;tD;*X1;)AVSh2&~|x_KH0JmvV<*Gdh*Fj-Gc2z zCJa&ig2r^mQNVx?K<5p5m81pfI8He~4j9>nvjCX0`TQfMnA;V^*rHtfgS`SyqdlDs zB5iMlxM1yXz+?oxarg=`?h(tTB_D#^vBJBD@!w#H>mDLYo#W&$VvAZ=GhUf4Rdxa@ z`upYiPx})#|4kyAHJ-lYGxX7!=X}U64=47*(+O5vI0NxbCT`L~hV8X(zjb#XIWdsi zKf`7wi>K&pvQJQ1t&&mGNq_4J$VQ(bNt%`^hAar|N!Hqux9vANm47Pds!U!rPsIvc zr!a_n*;?EC8dhIb=hG{ys&m& zy?^)Z(RazagSUsCb64~Pz@b^A{@z>G$h%snAa^!#q$e z_@TlxR%C)Y_;hO=(JB#*BVJCi;IaM{FX%GevHNGdie#B%`ER_C$BxI|-|A&#;nv0ALKZzD();X3NbpLDgfc~(!&a1LO!y;KiV-)(4tfiRQ zB{QXElqp+M&+o*Q<~y(yL8U4|AJZE_bggSM zSko!=NW4eH-Q<`B-Sn_=5Go(cw5+gG?KzVgC|diSaz;tktnWj{^FRVH;vCA16Xw{k z<$xtic5LmCBQk-FPjdf;zZ37QH^z9)>XYP2N3|n7t>@d3h-`+q55m8!cY`SXZcBFa z-VxGUd49zSiP9HST4H(=lM%!RE@tk$o2{x63boc-%z5q!$i zH=zXA%C%+|WHQ4&yz92_oCxEFwAMEX(SDVW;H)hpc{Vpdo|FTezSN4HJ+&f~pVDa% zl?H-rcFj4Wk)Ts{)~@+@eZ>CE#Ehg{4Y1{()-`X2@#BcZhP7H9N-OczOct7rYQ|!j< z)Koc(IfOd}YzU7y_%9w8?&VJR-gKsW-@y<@%L2)16`H9L%sjrJMU+wOC30`y}o$Q6aHNc}Y zeV@XqQ~u0t33IFEO#&mh+^#9Rg4Y({ph(x!Wn%Vbk(?Ra}uIaqYxR7D*b?_H|!_cfaA zIQO8Gn{BidD~f%D8Vr$Hq__pMU!^4kEYgT%J3$NNWzts3DeYlwN36;fE%6B+)F)2$ z8T-4Zj;yKWC4}hl%)AIgZrqJVBGlCr()JKmMj`b{7L8b7+Z9u&$3!UFdxv$G&c5PTq=j4M+x=MPUFY<;&yftRS(*N0*_%Mno zcVEE)F5!aQA=M&+9%rATGTXUr`V65zXzzi;rE~;NtP|$>5wNv5uscxekxF<;bZDX7 zKiCv&7J=8py>sa#nZNa1%SZ>F2+l*olosNvgEL$GwfZ!*q&7Fa#&s<>UPckN9Bs;S z&kyk2TEXf+s6_lb#_1wrfo&q5W!JiD503Yy#oWh+ z!IQN$ZT#7G-?a>#O;f(-6YK*Z^+bgj;Xr$(3AyIZLN}r>&{Lm z|KW>jKdfYOCQ)x^eN;f(i`qLMb|Ld{%SPRU8kiQ2?@-1=eyGap93f?=k9)U`hk=)t zLcSkL<1HBrtyK;Hp`r@qY7ZJ%2zY66QeN@8wzaLJIcfn|-rx&Q3RO~PH1OASZbBKc z(TZ;SyI6?x7qg@|c{&*`Sqr@Q;i_T-#-JLuY&iPZ%{%DYj9)qAh24`K?mlc#n}Tf- zVlh4a_a^_QxTNV=1f9$>S}kH8mW;4>(`F;Sq3c0h^5l0T`ls(;UCK_4WA#G?r9=Wb zb?}7QROUs?9b^;uBK?MfbRo#X+^}t($uQQGlyPA$6$z+;D+u*;As*ThCJ4G4dW&n% zvySM1mVuzedkZ(44Owa7!{gGtxXoaelYKGoy*c>i@J(``KlIok4T-o)i;Jvp+VK*b zP($>Evjve*{YSoXVbKQNq1&@c%-cHOF>@Kue0<8uy;-$_=ve}w(dss zfl`^JBfjx6lS>?>u?3*wnc}|dgt^aKB#jZ0X9OO}%W1Epb_Vtx)`KXdp&6{(G#KtE z_1?I%VyF@CC=ox3)SD=&P{76pZ6(CYOj zu>F;%8e=+%RrMEME9r!>)|h+NMz6BP0%Q*S_lG85P`c-#@bW*) zz7;)elUmrRASDVHx9$l>A8dJazEVkOWEFIxOPrgf?W;7X$PS+B(^ZUf3=N&2;()^^ zwQ)j8b6G1y{Y~0Tj1IwJQ!`^eQBBF5;Ia#<3i2N0|BSYNB$Yr(BKPhA&9hPPOub?ruFpQlbpt44!meo{9v6Erk6tLzYcj7khM%+VZAV zL)XY?8M5qRyMbGP$gy*AOzxZ9WOz4LO?BPRS6Q~gEQj;eA73I>trwYFLI#kaCm*te zHKj`=&rze~csZ$=a_zOJ^Et?JQ_sf5i<3~404iWGu+Woqy0G?cJFMeyRU{iio!f-` zIQGqX^wDz$mHaX4LrCR=1piRUtZ`fS=heuC)Yy=FGfd7EV{Fwl=q2>U2JQ(!G-16x zCeu|N9MS2-`AkTEbJ+3j+vDW*;a`$(-+z2{m>j%%L%F>{)08e-uJE99EdA{?tbAtR zUX^u2GukDs50vQ4ZnNnc$y46>f>$BI0;J$nV6#PyC)p}zw9uE!6SB#lUN#`m{e@Sl zl*Vcq?ONB;NJ?fkVR2{L0|5H<8u~;{mX?Lr*Pz9BoRu1+U%8t4mYNToc$HQeGvDAs z?GE&GR*cqQVJU@V)RBZ}Qc6yeu>E*Js z8|UgAF(1J5#2cMl+dRfVVKfr33$s8mcw}J;cVN@Y`9f12$XR}GO*%%luzWRX)yH$);YxIonNRn@*!RjzabWyPyN1Q|yT;yu!fp6-vYFa+;p67Zg$flSJ9c z1!u^Q&z2dy7(WU*s)5{|CLN3jMcffc^qZoiC{itdqqa*K&9u~Q(*pal?8bVzFC^;P z_S8T=TTIPC5Yf#hv%Wh$gZ~=y6UDBJ*GJ73&bPFo_^Zg76^c((vmsquOa%vMUMsw7 z&gT~Mnw43ty<6Mci0}t$3sjKzGWDpuJcqP)X+#N{a|kj%#LuC~pu_*xBpKrOGib8P zC>uGaqO_WH;4t%th7-;rZDm3LMPFw<5S*klp?reSX;RMPC8-0F9xz3Rs^c~RQqu_u zb0!YPvaa`7rcDKV!_uaI!0^Ncm=4`olcfo#H0;)iz)5L&Z6)sXj+Z z(1ucf*9mqgR+`a3fjH|5uRn3}J?Kic*}Pcu03)}%7%}oj^<}!sPM@E-wI5$?N1WuD zp3U?k%?rkllLCkuccw0`tSLO`MbtW9v4sAw=Kdg)UVlw%{3?q_Xf_a=5r|Gcv&^zcCteK1vu81NKs=4S)qQq{(gj1) zsjp}aDMI!6ui~@3kaCLGoP_qD!UA zT1rbWEU#FatzKO&GqK7l3b#VE!!lBYq7d93MIV$QSWQhnMVpRLS7t}lhr~JK@bw0( 
z!&l|%F4#Y5(X+zT2IX2EvRGr*Bf{m5A!)c_&5OQ2hiytud>??y=jndf_)zShwN5@& za^|z81zkqYw+lmyiWJqk zWMAWXlNYe|+teKZ+AT49q;{kG4Yi!nlwlAp+V=w>FekYpvHb}qj=Tt(?kX1@6J zH&Xt@<3lck{A89C3@y~;++CxjmZH0TvEF(!{AGf*fp-qfW+Uo|YrTFlo4OKOzAu0>~JfivN& zyBru|T89l+t5~xczcd>;3{24z6ADn2M`Qu}Ya^zpl&~5yeDgA`6T9p4JXOl4qDFdU zZ3FD_{D?8=?6~_Hl!W6IhW-4%?Bn~rKaz7VPX!*YAWLXGO-vf+Oig49R_o{4G*x?Y z#)-6C$M)~Jym30E3z9((sz}lh*c&t9qL4i`HaxHS~%I$~_3-v7MiG^fVw1ce{&Ih;uZK&7~>YrWaVm z_*n{pwh)=YeX3CRy26@zP&zGTQ!1J2qRWECq5EM39$Rb4@!68Ck)2t+@!Ka26r#7*i*`!zR9DAQYNLHSw-EitT|A? z2O4Rsh3QsjxzlOw)IK9tJC|v#Df&cv3zGAupIh}Yc~#@Xlw9J`SIi6rW6|+~QejeF zY?S8>A%mXG`6A}LaYLX84{qu#9qr_dxHIxeO!iTK+Q(efj;FjIyfg8lo}4ZY)D;yj zuJKU=+_$2%0BTkY;Rk+iWj!q+RG$9oT_V>OG__^8l-;s)##4XB0;sw|r`}VbI*0D) zL;40M9tU@mSy|*~c`al3W?H$>Eiq{UL!h?g$-)#>o-hOfU-t$kjfD-*LDQf+rhRF( z3Ze5(teqa!$(er<_NiMc^jAHzJ3sox;>Z|`t}wEp`0?=QyYEkqWdxaOfW8}+)prIo z4BNM6*9{+|HKPS<5I$yF(9vU2e(pFE1HajD@04;fG%2(4T6rL(g-|U49sGTr!H!32 z>64LhSe|ThLiF#p`m{b`&||_Iq$UGDX6J5qmz|KO$$19qw{=CdwZ}w5hqg+R;nY+f zS=}BiC?cf6} zU-j*8+F2+Wdyw_g@)(Lc${t=}mtE`=hI2FJOjRu1kUT`!0m` z#+Ksj`u3^6bH%bW8AUN_vnfw~?M%N3Thdrt<>6yWCZ#mAgU| zj=~@amC)0@7en#fPGOG^&7aIj7fX}W`w=t8Rb_Dja8ek@G9zy%4wb=U@VIcM8;oBgcaI@g7Fb?e?MzHPUW z^JSa(zlc#kKJ(t+3Ot!a^5wWGo?eR?5us1)t|zD(wt)3OZbW>40;b4~pE%EZQhA!# zlk2x@e|s1V#U$$}oOy-YIxcy|Ol&;+ZQ0+jEu15>-qrL)*E&AldiSnHvgXCw*x>fn z+!G;X3i8i{bT{&lB8laS6DhV#O^^}mqOtjXQ#KUB`CPxD(s*AoDa3Xf5S0A%mBgrW zeX-#3=!{#U*k}4VTCa20FC(9uLNcn+bdPxLyMq)8f|1ln>g1wIXA%15aebkBM#EAB zHt1#q^VC;B1P(sbjF|@w2#;FL1Se$5#vO?UFyM5y<|E%s8)^T~#HWKjXej1dUP`_rkIAxdok+=WwZn{fv(MVj)$n94R2aBP&$4+}+ ze)mn!?{%IfI`t8E-8v$)tfz%N9NDntQ`MsPeV%>dQQ25J%@V&vGL|jv36Dy`Vz=)r znvcoLqAUHf)&;vC-GvmxpyHCtO*t1jWwJ3L0LY_@5&nUBkcE>XS&Y9DIw7*v1Md0R ze;!7}8S1r)hA9DPm#~JNSW4m;_iPwsjp7~_CG|oTXy%i(00?pA*g zM~0H5DubqwiIDNPNq}q(2UBdJD#7q0vz2P}aE}1@YG# z`-~2j8nwr2*mT@A!1_M+Kr&(}@iV0D;OSvCvzHns48$&)vJ~g}g>7L_0r@gbLjtlp zk<7MlTcU=F1sjBVRgVw8JxuH>G{k_(v%SwQXFw+zQge1P%qGDy#j<)Cj0`=YGj{(R zI?#>nKU~)#ohnk^X=9E8*l7r{t;?gjhFnsa@_`+qP5~3&&8Fw;1>4jlsyY)$sd6sm zDyyV4_s~eC;36r-D)L2{v+F8}tJNZ%;(;A*WU~cO3y=ZxBQg|d50)j_IC#N%onlr- ziz7=G$4P4`OGE`-;4+?3s?^%um};~G)vo89A;g}Fk%Y0|8=dC`-|A((L*jvLbX``p zyz$F}&hw#Aa{OG4cZLqTKt8ohPLG0sFqbjzqy$*qaXNllEql-5d9hc{>BpDuJj$tY z2Bcm-dN6j&RfZOm3Ypq+g<`+z^+!+t&tUIwv->X}5E}a7qo8%?uYn?|Q=_=)2eT31 zNu!SB=^zgMw%5`WE;k;$oLh@~tiH7z;AVsn&|LPzIdB4LoK_;79Y)5ET>hXob8KAw z%wd>zcn#`twkG3^D7pbyD3Ara_R@*ujr9IdDQ#z%aQ&QfN|Wy3={)II5^IQU^VOEXx5jzDasc!WH@ z@A0ba?TcQq>OHuG6`rDV=hm#zH14c(FAUwd#~)1IOM%Xu4zD?nK6eqWfMM zf{BKtj+Y#HeDv=1;lFP!l}+gleQb{wuWlP=ti{LkBR!^Ev7&PY2CcRy3yh>PM)N+N zzl`m4veeoZiDdo8<}&A?Y_xsC!jZnJ7m2+r=e#v)E*_)-ln&22+bO8ej6SkotQ+<5=ZJ5WW=O} zA@VwyBx7^k3J*m;|xrG@>DV+n9sO9*MB?-iNQOu z{`p9P?b^Cw8L7AT;!J!s_JM@C{~Gl!nnXzkxVB0Ccu(FRee1VOGU<`QH||Ozp`}y6 zpic~n=2c0DpMY{FTMOqFR0pFz{_&tX;qnQU4{;%^+$45MJ3fW)^7d<kK7 z;*)T90j}CGcJ`lSo(_uHD|FjG(B1q#bHC-xVMU6gNIt20Wa3|?>pES$dHb2`b4tzl zQAK^!B-XIi$u59%V5g8b5jpUme|oo?pU+*8zfaZAF;!Uc zvsNN%Ki7C6FPiI>s+}lRyC1yS57hWKdEJB>7^*>(Xqr4E8$M zrBh6A=6gfj-&H?yCp-qay{TmsuJJ62r4c$*7PEf1=P*r;%};3&vI3rh7MZ}3yv7JF z!2PXc3Q8X~L8Z4tQLIMAEUl^(&*iY5s+Q18U0LsRU~5X*<0iKgcDtj}f=Vk-KBtE4 z)`my+RKS!Zb42Y;SzU2H+$vjDVW(BjR{>rCm*8Bix*A>Qr!SI0sOaPd{TQ$<(LYE3 zV%X!O372oQ$7BtlcuWAIh8E|94tJ|N?X;jLs2|mlf==hB_?&H0zQMLGwj=eUJh;{`+%bNZ_LPd4QsJ)B9Lg-wvp}Hq8W*{gC-UG z@17(pJcBnwBqwCxyA7G0gNaj)?Jz-o!XBQIk;AHVYXI!Z%N_qX`k!A6>>HxZ4CCZL z7usiAxr)=_kw*^Ht&BZjYS9OIL=NQpORnODc!z_Rzp((F3e@hT6)W_PQ#K}iV+gh> zY3T6rUZW53zfXrZYL+C88IK)qa^G8N{GD&~SW$WqK$gKYbken6sM?PBQ67vCShU|q z7HW^V=oqs_^A-PI4g8lWJ|AMl<&{NJPfSJHDP 
zl_uA37kG9g%@!Z-PVi7w|NYKoqAOk#bcX@CURf_Zoy>unfr${LbmqgfviP4-S~D-k zJ3h(rZ(!YP3X~q!aLdViJagyjVkkBgefB>1MOT_BxD8zzB7hziV#e@^X{3a;opTr85MShu-p0 z^g}!mCZn5YBIR&>Br`*&riB=x)0gm26|m*DY(7~SiqH(xx_Uogs3f*{>Ur1B$N*A& zK7*CS6kE4V6e<^F(+nrcCaEXx`zQY&P)h>@6aWAK2mobUD_J?sz(GPF008fU000{R z003}uZ)at0GB0g!WOZz1FKKRSWn*+{Z*DGddF?&@bKAC(zw58SQ)bRoLaVm#wKM0d zI+wO-o4LFsQ>X2`=jUN439{Kxq(V~uh{yl??Jm9+AS6q9y}37c$C)-FuvqLa7Q2fD z2tN_u)vLRvxV-Ab$n88&=1%;};ECuxZ{>%gx)i@u z#kFkOqT2xW(>lMAg?J}tS5;lsmxXMh<8dlpb@Ra!Xz}i6*X5Na|^0pH{%1Sn>ep#1t z2EYycn#-~Ta>vxA!;O8`D_I@3^}3l!(bkI&kV!GkTR9hiq`Q*h`OCvalP~2>-F)z= zNv^tX)t-Lw#m&vlcnQTV7LRAu7bEKFRWXwlVp>;o*(hYId_POMvC7R9wFN-5;+Os+anR3N?Osx0yf7`Q3AE9yog0fxn2 zHHv!LLx08nDcZwrh&%y7L5OC-(J4^>CG?i#Ba}k{`vXqcQ0R@i8nus@7}!r{^h%u zZ(hUO590ajzl#5S`TF}*NDvcrBX3s?(ho!yC_*_$d7Vk=5wNHgLG4P;ibXL4zN*VL zYTjiHx?kZmSjlEtv^bEU#1&|^D2rv$ks`Kzrm&mIANUN0!{Oj%^}C!Y&9|CZRYX2& z?CV!=2@E)`agMDdP4B^?sh480u4Y|bm#xs%Fq-R`G{DQeyHbEFnsmsruDVwGYapxj z5|k@%MYXbx4K=-jF=+0xuEz83ZY8nNvc+GjCkthIs;U&IZc{KORa-YH7yz7=oos>z z25r`M>uJ^zv_WzPMXLRVMnwe*nwP~VY18N`zHy85G1`+uii^G z%S*IFoy-<_l@(P68bUwobvGDo;5LKFH~X+c7Ev<4LeFzIUXE;#$T(XysBK`0p)a&9 z2FIoBR0%;02CrV6on?PH%TBUmF)>3ohE7_UIFwnYYqD%4z7D_+PR7UMqk|E{b@C73 zIw4$x?_c~1rX_p#{O2EEyaRL&0kKe59T@lIU@os=0%cv^KryuY@ztAepTEj}czgU+ z_V)R^?|y=&KjdXA2ZI-{UcGz^#sNycJQ)m%g$VkXh=b=X%*Oe_DGgH$*Sls7sOoZ_ zwT=vrJXSs!4CZnnIB4?p=?5@!m+eUP0aOG$piNXDXgBoF^8n~V{3)asAMEI|tfzUI zEmp^0CB*rNssoZ89fLuXcHk-Nvl4cIe)LpK!A@v?z@7k%P%r{v(86GiF6s$&qo@(L z?iTT6tGt_CWpg=GQAmOo%JqsMfNjJ^vH0v_|TwC zL<0Ygb23bcvQOhhS?AsHS0h9MwdkVIKfkUkNuyY_Mb&obkg>5bSR?Uw55IY2iT)SR zprs~sl~7~Fp(=TkX+@cc`vSF8IBLTiXtxuOF6kLVo1 z{9{U|hwr?vQiPm@zDapd{Mb%8OBiX^^s7Io+INm|c4jeE+oWa}7V;`@GeQO==G9$- zsFFW~oY7|8b~^pxR3xD`M#XaNQ*pf;4bj=BGiAC+L#1Jlu*zQRE{CS4<+|;# zpQ^47<#N^CF+_|bCBd%Q!o~2Hsw+werf*X59^l_ALe?uJ;g=!76n-s^T+gQP*@TVGAkqY-W?sMHz*{7<7%Jl35CFN*=2D9D`$LD~Ar<%R zZN$dK6_e`Ou_2LX%XLZmj~S@MJs8y}QW2Ie6HPt%WA^MZyiuCbZGkdX> zKP&oE@wTb2i@BUz#k;3}AH=+_4mwbdvMW~Qomj6h!3i^E7_l4rSq!l)jxkUn?vZhk z+{dVU7$G}2@2FO&9F~z%ZQzB^+o@?To+%TrB?YgMPN4Mo^n%+Aix;fM6Y&F@-?9|b zeD*<7N)t;cQCQR!tnKqv0fo!F10f9RrnoGsOuYeJn3ghC_UCMkCOf*k_u7vT?@~2f z?$aor&!^Z>m8&;RjH$9eAIGcUG+qxa$Lf&oh|{6G<|wZj%4>`AdN-8U8fC*UHV$E9 z@HIQ3YY^Mi7v|KHrmb|pAICg(0_BJuQiE_OWQYj$bdCvi@^ZR{3bYfNt4)?0iThk& zfQr1GU^2sbIYJNfn;>0S2`?E?5M)pjB=38y93>+@byS1RodT)YlxiO5jlf3AbxZIt z>t$BvQ(0zR4eQFeyiMrUB^BwFQkr9yAJnvcc|s)>V5t{bDZ#KcNI~7qlY>tOBXOvH z<_BmgjyYQiR;a!V)boQ8MOLbKT%h}c;CrOAKuOTevltHOX=K;~cr@@bE3aA*uVNHJ z19(jp%`1#ah^pDTY2Y1bKce?Vy{_h&EYaYov@e#zd)9s^6qtqsmzruRfnl|sCeVwY z6LgXkOp?@j4c@z2%Q*BnfQCv-2e_akzZM@`eK^rV=TEOUeDb!HxDf*0&a$3@TTwca zj80XtxYvSuAkN8yLMi^fxhh~8l~>q%v1*G}+~nXf%SPt&JJotP@cLxZJn<^*Se@(M zF2)#gR`bN9-u-+XA1unNyB<7ArlH*Rpg#-6q@`i%zc$*E3N8?NBdrL4TiHmlGCsY# z0!ip?J$ThFSAJa6zTI#m`A*mHLA^g6>?oUU)7ei@ZkP+HVL0VFU`8pNREBXrvrL(u z5ZMgSbb1y%5#QBIuwhaU$O5-^Wi`8_s0z?dQJNIMfQ_Qf%^7Zjs=Z&MR2tQ95~3uv zW1TO+8sO%Bi(7|^5)7D@`3Hz`AdbSFFH}O>I93td5+ayJqXN*GCJcvYS3<_tAfWUx z>I!XwnH-)~FcM;>L6baTLtD7zXa%TMB{dvZCP1f3+ZGo`>cveBTm{<6Xm_tnEmCRC zmTKCwGj{5jdDpi!hJ>i82`KR9rg?l3F7w-jm86>gL(`8x@zwkYF*4mdr{5R;4A%=} zLzs`i7(&g~*Aqp7Zwq}zY1h+s))cF*4d*$;;2P1xaKjc&?^z5-xHm#-OIBo~RswP) z?P28E;``#knU}tR0{$ldBmB*4sdj&slzZFE;T~8w2}7Be+ILHr#kH)4o>BC)j;WBc zd+VwNYb8I)j1FT|mk5_Qdxxhql*LWKqI7x5Y@+?Sg%z)Ne9^cRAr8B?8(4s6R;(3< zsdIkuIo5|znLhaLm_>MJv0>5#^1JCo~h`e_NXc{6~JV=E+TQfXu>U=)S+fJRi zk#Z%VZvd)MpWDvTg~us17CRi2TSP5C)=sP|mc=I=8p>GGi9c!sEi%NWR;;*#(lgo^ z;e3)!O*TavB)7O&sN}4h*F6t|k7b6C0D_)ipzETox&sc$lvmuW2gaypsXh?7>VZbb>}Rate;LZOE?q^S`gNQcRW4o}2c zhX=w{X}xgqA+wH~F*?__Ib%e{1)k&5TymrMM7+dv!82H)*6R*4@zAxUAsLUy==pWP 
z)Jh4f_Nu&t|IOT>du%ZQvU#{6F=B%zFVqbZSgxXP|4p2+Lv@&0}={`-OQGD`b!1C1x|lmN_w%HX1(DVH8)APf>T|}gAr~Q_!-{W(V{{Id|+(2Dy zO8uzLoc3{{TH*q)r#609J(?X&1t^xXyQ=42jW)b)LNjYiR!(o1pouLdt&S>b!1W0r z?<5Rcc9^4Py_wAtM+m3*VD`u`*D&Gn>0QKfjA&jLzEN&KVvf5HpDFkT-!#Q zenrTWe_Um69C&Ehet zeuvFXc2*@Q8@WO)otx#bwl9FuXglDwy9MG0Ot-*V`Nk?Mbnn4hb*wSM6>BXe3F0^Bxvs>q_R1o)Xj9CJFWU9|LFq?5G%pn!cH ztUp{d?ScdF(EwR{I$cl2NY19ih;N`SIEzT&38P*gKE!e*2V&4Lt0ve#54S|mCU*aZ zl>xu1LBR09+78V!{YKru$*peyz%X-BmviqM3H)>#SSn_q1K~Fk%DkvACGfPrTrn}( z4WjTUzAk>Y>FWs9neYz>IP-Iqis1ORLU;jd1d`t~Tq-bP7*ewmrgQTOT>vCj2 zpRqvd+aNWsR8m!0#LFD71ko&7LEwh*#3MV>)a@)&X~-uHkJ&MGN9_S}Fi<3ql?(u# z)<<_sMXw4~_&RU0rOYcD8}p(S>w2lF#@U*Rr?6bDn?_v>G66zxw-e}VC+fBUJ8TAx zO$*%8X)t^5#F*}*psOopc~w`%jPAn7@>>`*1lC$gXWtaBw}j9LsZU}`xnTp_uBX;H zO&O%ZVxU&_jfR3>3AIYaiNmoJ&cPA1>B{hJhM|3wIKKsWH*A^(0z8)E)cSa?; zxuEzp}5GTh)$lXHJiWwPj^HwbK8QxZeQ0gZd&+(D?f(gW$x1z;h z2FH06W0t6kw8r2^H;dDQ;paSW@avxIf)Y!{78yZ(Y`R3#_ta-QIE}w{_xp4v653Tq zH4eX_YQx=(wN%PTnVF^vuvI&X7QS|9L71~IK6M*!YH%|IS7#nQrjSm}d_P$Ma!I1)GUdARdVjn#R z(WSxbjS_+YQwZec9R^{iDkxLGgODu(ptUD&Gu%#|OQQ=Wvc)p|7nY@$GSlZA{cqJY znX59N;dN;SkymYYqcFV3!wd96aBOuDJ3O^4T6NGHPw2|(vb)k>BSTaDQ7z??hc*DG z>}s_d=XRZ<*Ks(Ns2%xZ`%^WHA3Iaa5l3OSTt$%=40+*@7Z)yt!6=irD-2k1dL=rw zz>h8(NS#K-PK^xxafn;qu({vHV4_-OtNi4UGGwuZOof-|jp<|aL^szeOUrZbmV=1( z4iUR`5oQn(qK++ltkQ#sQXff*C(E#*XAQNwJ4M7v$JCNdiH$INR=|P+7wE=3wizV~ zJR^#Of>cJYZD6hoV9$%yU<(J3p94NFiC_;1*g zSLeFxYIXO)E*)xEuS->ih}EyK+g?4pC{Ss<=;vA}U|BUa4r=rwDpPOQDfrTlo?Ofw z37ZTKFLtsmMn>~A8Po&UIN;IqCcQ`+E6MH1@35}XC30`&`iB#VUzkh+V<-OW5 z!5|r|mA(*SZ2H`6mFbl7c0Ltn@#~G^Mjb3q-vQ!7sDu$#6wEUNfejeEHP^|G)yR{f z`lzn|RGXXDmyXrDRwumFP!)i`g!ZM(ugwJnO@T=mtZSLZJB`BxVXrMyH?tH9BM*gJ z*ehEK?!7IXPU0O!bj}i`M}sQ~)rg7twaK!$B&e0pCa$B&PqH38=8Y+d+zS>Jv^6dL zAy5Ns37Ol^p$s98?f}%}eVJ#A5!&yQTb>Re_V#8)`j|%tVDs<_i?QXw%^)&4jCRpaiLdHb z-MtKBns~Hvp2d*84Ry#kE8DCBqZ9PUrLk9so*2cDzdq8Xk#M9tTIGA)7+~na{OqIm z@vF;F0K2ua3A<3+M{h6FPVnAwm81D`1Z$FWpFbwh3$r#-`+uoG8YV;4>o6onGfF|F zsW6VrFm&pXq%Q7Gh00fKxz9Uc+>Y5MpLwQquphs1`a(GF6doUvQ@VrD@G)YePC|KY z5D^XD$OW3S8aza^o*7GY>(3+m+rsHHrI;+>rX$)Y7Vn$q)+TQBa@U_qYRRI3D6ha;<}#Y zQ+=jI8hCTYC0B37{8q1g)~&?T7}-_bskKkKpEI)*?d8xMjS+iK`#>3GP7WO4fzR1F z_@0aHF2!g5k&b-egTCmtceo8T6Y+s3a6pQ_ZFj8gdP*9T3~7vpdaUe4)9ScnrpmU7 zh7kdBVQ)T|IbuWq`c+Jm0D!&uV2`+6^(QQnp;#!92dgg)D!bcg?By1V}2=|;L(QM>H>Pb_y&_3QD$a(F})82x4M8m$TShYhhrUh=G|fGpBvnWlc$Z1 zmvD3*F~;2@v4?A$qY9RaSoDb3(JdCtW@#Vb3LZkm))0Q-gKw7a!-0ePgW!|Gg`OB1 zSiL+LTYm3m^MmgM?nAu?KV{s~tNczjI78!ZsWHm?9RK2y)bRSR(iET6my7tp1%5XG zGw^V#uFzhXlLyARIq5#OhO0DgIoD_6`XEz@7Z@U7Mr&x82AhCZXp-73V3+xFnG;Eq z0}y?C5JlA)5IW%(gORJYDDkw4Jp6LU#bXRhDVMV9ZVFI@e6c7hIf-7CB{uqMJD|Jz zYCmJslRgY&_|bO=1M}`YEb}hX57L?aJ~my^j8FfXW=tj=+{DLYOg=+OPM{3p}kl*#H}utvfUT0z7slx8$FnPO?57{VgIK`{3LJ}8NSGb>;lurj}ip2+~HfBDt>BpIA>N&9wfdosc2V&@gQ zHC4NIsF&tmcO}{olBo8?)~F<*>5Yv^#I9|4`ZhW~XuD%HW|)s;xMb+rSRCqJ;g-0N z7M=U#EJ80EZD^vmJ}JIUxJyN*l0?B^>AZN5WzhT6x^8Sn-|AHyU|Ybp0|7d;)F-WD zPQ{~U$Kge+>Pv#%(Dov|yZwT^umc@Q^(}Xm9nUu=9tb9g>F{rp3PNGyh@FDB^)g$?9ONfslNKiArT}SKqtdUJ zlcVv!MXFga6sc8$cr1!5=r1O1jpCN|HO9zzf^tVtJZZeWUkvIgEY4c|Tb*rJ%=DIY zA9k;_OlAjnk64e7koY?MKrHfi3bsUB2OOD;uw1~w{u%=Zr8tNaWd`5;wFvwE%|A$> zLBxnN-?y*SiN|T2#AffPQHZ2J-Z)m8~$y)@K2m5ug^ieT7orrbNm#L?*6HkGCKuK3N!2kW?a88R!Z<VW|vJq*9ThSWC788QDvCihbU;X6HWWX z&djFoLW&NYLn@Tt`229%_>5gh&|8#o0zwtj22PzK z3XuLLrx<<0_;}ZL);)XV_x3zdpBJ*(O0M342r2o4wg>wGPqgV{M%O%<=_7My(>ywpv7dBEPyV+`XD2laeDmij&wirW zM=YNslI`O7EFIXL!Cjf#m$A)E?Yxp?&cm!Cc|c6&(Cibe@ZB2Rd*vWJySr&6x!dXx zTomJ7*YKiE=$QzDR!{m;@-_Sw{yt}ZnsL26dy%w%S0S;^+V-+tf?3`r@9lihn*=R}?vpwVdb3ylVF zUpz0X>pEXtwBq>i_@VgU;@yQTeu4jsXW6y*KAV(v)|U0P_}6>)Mf7_sFY{s{{!!!~ 
zWZmR#N3p-Hvn!d4_i}nsl*@9F%LW=A4#fBE?B0E-@%|!j#I&4A0e>@5Wpyjcx%eS3 zX7Ap=TB#kzscBDv=IQVRKTDyvrH1yFBQE%U5^4qWB!1vR69K!?SD zC=lhO&7gD!<*RF?u~SZDEg_8kQQgT1(P3FH4whV~Ir#q7^OvvRz1#<6srWyNrED5e z%b(YI4c(kvi>v}%rr88ATxM6ItVOn{C46g3#PzDqTNsFeAvHjmf{_$S?^z_FcpT2(o>g7A}=11}T&FdGh z-oJYD8lJutPhbCs`1`BZF9t#anV=c@v8oY&Kr%-X${Et@os=E|^O6y2Dml&P`4sw9 zEY`?*ixPOhKryIfy~-OLNMPauI9n|9Ro)UKHgTu0nUg#G-P_ySyZ1dvSKElHE~n5b zdUX%FRf0gRvi3s%J1?*!t$qZVTtE-*&FgX{+Up8?qdt8PWFLrEEeI40=Ky8BEC%9D zg}us_1MxgtF8R~@b%lRkngC_p0LimfOo3k}k{J^iXtu7gj|-zHhD0d@ zKa+DY9!~+(c-(8`ay|fRRqJ+~&zh0#!SFqdVOawYPvI%?r7!lML@&7qsu%CedU_#N zvb`v0guX%=F{55Igq5yQXE}8Cb=kfGHkB(`fF!)E>$2{dBKVu{{fbx*p$|xx@%T4V zsvt3%X!L$4-W1Dg{x_)OyZ~7ssbuZKG*whXVt`m!RQkaP{gl@$cTl%kPZ}^3KnduH;&+izgXcUZ=R9ztCDIRFJXDW?nNXQf>!4>v zHk+Z>U{v!1%?J7m%v4fW()T@Is&Z# zSU{)0RUxoZ`}kMz9_PGUCW*UFx0sBhv0C5YXdsI@qp!Xv6My6 zgV`5P#IfI|DvI6-!0SON#EpBza?fvYzif6ta{nUEa! zH{bgqy95gjon*a&>IMhTcqX=+Bk^lDHop!H( zKYGio0^?_@_6%!jMsnE3t1p+0jA^+GL<^^uz$=A%Z8XQHHlM+<0`s=%Lxn_5cRM2a= zuE89=BaJqd)F?I7oS9icR)5StkOy7O<_AZQjt++~(>i`QgvUN{+6o<_nFPmtnN~4* zk+(+0%ry1KFlWh@?e$2^mt_V{8hB9DHQ?^NaGuB1YkuoUnG)F?{Eeo?i9;^}Jz&s% zRE<3>9F+iV0R0JMHV)n-S>XPth>Id~HF_G8mr2;hyD7Hj^Pb>yIAg)FL5N_&etT$= zr1_?BgJJipb{ToE)O20f-~)}@@)EvoD{V0c3_h}SQ-Ll10!-yuV_WGM2@($+g*8ny zz?BB=ifB+2f$OyF|(Pt(fEe4J$M z^kUrPzepGvzw^W9BCF)-;TeuoAQq-RcOImm#zFW9Kzy2r2PrNHs>@=gDg#Y=A;`~D z5C$Pe{SYPYlHb67h_9*ip`2zDU5Qym^?(29?4B!9QIAi=VIWJsjn=QxqAO-`#7jGka2Qba756h2YcVzal|*l(0efSH#Hy%OgPe~ULE{dJRgLF9&P#v zgVjZ0{X6sC5CY?Q&4wE7MOih&NnT9LbLt6La#=0&&C@x@J6R5<;rNxAA?YB^Ihh9 z=f-8R<|(s=u_|Hc$n9i=Cu?k9{HkjG2A0YXm$^$z2k9$7S@ulJiV#}27%t(r2i{wE0F}<3W-ZJk$Gy%jD@fa>RDYJh zBFQsU`Hh;p>@f52CG)?n&}o7Dl+xqMv1La4_DSFk9Gy9MfIrdw(Q$)!lhFUMLa9J< z+;FTVG$KJLp#zb9%$t*=Zb>B`S;%%=ENXB*dh|zGc$#Qqr#IOro{C>w1Jk;ba24T} zGX)qraQZ6;pT^d4GgbTS(%@&$v$F*(zX53Lh56{6Lglkz*3f!OuRm-{iiGvux-MGi z-3Evf2vZ&r;*x;UFL~8_aB7=B%|~a9vV*XpUq6ic5vu&X#nmKeoc?m1(-C;dz5^DZ zb{YaWi_ss#Q1-%d195uR2Pf^6!9Y9QLB$Ip1zfHOzk$I>Hmf4I}VI??YGlK2Tom8vn zlY(xs_)tN929;bR;^Auj6osF~!d|pQeZzi@&3C z6`6`sQA#_}Q2k2MNEV;9mh-~Kc^0YaTuVC`7!RIbWJPD z?>@Zr~o-yDBE z{YD;s`S9z>S6_UUWwY5AkEDF`D4Txu$kftapB|U89#+@)zyIRlm;LR2U?7eg9YJ@% zOt8|j@8dYCk#GfVg>BxTidtHK&{q8|j2Jp-oVZHR>GmJ)7g9unVg$&dpzRy-V?1~7 z(3)DUR!%H3PWVWdQe4e&jX-wm*?=0JilGy zcNk#ih1@&Zb#)6>+TswThA|cS@Z0ZyQk4o$wD_1s$6cO~PEbJH!?@pssV;@NaVi%p*y~mtlYO$PGk~Gc>B9vOs+r1^pdkVcb*m z#_1(>GPZx)=wfPYBTx)_?`1Z#_(^o4gXBQ_tM$ zt^?wf8w9hK6dd>2QEDFUP>#YLzN*ZZ;Fk54*=jP&s62xK%`l?$IPaDCU(i&>v`!i>H*}7JmUK$)-N4{a!b!pg zkb0?|Cc!}wDxuZ~97dU47uhN|Q)c70Ie?4@;}F@d>I7j?!f>iqDi}xuJ|DjV@Zbo{ z$G}#8gJ#_#iB7lYLp?Cz)pU<(cyBp^U1Y{3m|tXFYQ)W&Rey8U)9!&r;X8`@nF8WX zhlVh9UiMtRW|~1sa5(jvCWadM{mA-I0};IU?Xc77Z2pKCOa zh0eIjmRH$zgK{tO)&rn=`MzF9b4WGv!2bB;bRbIXN#KCV8OyTZOg_m8T~@;J-8CwS z&ssF+lz5jy>)Q+zb2uUnCa3!jqN6i33j7fAWm15gk#TfS;0n`tc;S27dv^;4HiMf5 zrH-;|iZsIYiQFntDBHJ)K4RkrH^vE-8|NVG(ai&?s(64k@H;ap_|sa}XZ#fYgQ#I% zYl~sxFVP!qY$GtVV%JTJR6PYU)@?z!9={MP?V?5_L$PT;hF3LWJp#3TqmcRc59JJJ zhYcr4vIRhXS!TcmlWV}HZS`WGVp3KQ>_lU>IVQjEL&wgC;+b0lREqqF8$5U+9$j$L ztUHszzbL zw!b|pkW^`AM#LUz1Sw|M`{E^T!YlwXw(}ud=2iuWS1+8LHQdKZ#BjQe0g_<^6y%Dr zcpRl|Y!nzn8_o5nL=u8^7SHI?-Y@>te$AoWyd!4sX_4wP?WUD;+!Zkr|1wistoS_| zv}swiv}ZzdNB35kJxi+rxobIbdCkk%`J#Y7v*DiVV-k#?NN&Ag%Sgrz+~_U(l7gBTD)s#e_r+7uR2l9p;j|U(j9u}i^=CV$ zo|e#QX5|rx>fu0OCJ_ET#=pk{ad>t%1duR*AJmR=5c8QVEPM?-@>fp9p%6#-zq8aA z`d|Fz$b8vA#h5`YU7x>s6O(ojRY?65wFjOE#!>`c_?Sr0OMFu_|VVf-O4z9)xd zx8Uloxr@jp`|RnsR`TVWcWzo4a`@r)xGKO>UXxzY`kjGP3(&4)*>@Pg3lvf!3~9^z zLA37WDO-pgXP|7O0g_muYaT;Rqk4{EHcR?ii|3O1K{rL-TC8*nf)oPTI^R%uHSZ5?xvAM_IiqL;v`MaSmjx|DYa)`{54gd#m}nil=j%)KpxMM 
zMerV{sFTt?r=vkkn;Wk{zlAAL{_#1?E43QlJzm-=uJ~S^z=G>vvQ?1eb6&u%+3-`u5ij4zgUCoe)r9TA_ z!qgE?KKYa~?Z;@)B-jYw_+Mk1RQB`K;xO*3hDsfG?qQvS!xPp_^?B60V>a&l=yy&GB$hm<{fzm zcP_im$zmq7m+|X)UN>$QIDH&4+E43+qp~Kh?(Urfo1X{de2*vD%=D>m+zgl{LSjpM zSbbyO$_C$o=f)-;I!br_J-JlX_ffb_{V#!P^SP-n57 zNyCq85_9WYT8{M@R+#5^{Dez_R-9Vb39cL4PYTn!Q2z!8VGFQ*36!L+!pJ7X}6QVkkjm%VP)M3DU8VF6cCsM68?O;EA9YkE?;4_|KH}%xA(p%uw zS*JS~nT~UWaB3p7SrW5o$o3<2T+zQ(_0{^syo-KTm9CaMeqXq{-;l;gJxfZ9@KTT~#>MC%108OrWEQ#KzU8g>qBRBqWvaUQ6r)qia_Xs&q8JShPAdc zqlrR04VW=ZHP!9y?=ntL+h=l--`tPi1qvu?jjN0Y5Z(HwERB1Au3h_ca~>)0KWF2u zZG5p%XONvwu*orHot!&N0o?mu1vVk)2v8oEkpYkjm&r@XG@GDdVssY(54BX-Y?KnCc6&+!x3=H%&&Y=ugG$=3d zRLyDI3trz&^I7ta`D}?zR@FNjPoy?B$R>3=x2g(Vj)3tS5;e?RxR!}Ft57gw3pb`j}`z*>jZ!cE* zlu6Yrp7ongYMe06O4(pe@~p&JV=r%ca*40;1&bdqq>%*CWf^$iL@oCVm*Rb1Y{R2w zNhU#Ym5mbWu(#||C{oNgJcYMxWVP<7-4g%jOhokzkYfHo^|2aM6ZNN)jrKPjZN!}U zUal$>BQUWpi}f~hV@EB;fUW&1dhQ`)4UmP5?z#nvcwUF#o*tVTLI1J6H z880cjc$~Oc{W&bG{&5f?w~cqp>g_hB)jW&8_=JD7yYNq3ltkKA`xSI`D2LD>Q<%gZ zU4csboz0q$)rGgnzvC&;rJT<(AJ~1~fvw7FT$!UTdG8?S0H~LaF?xl@i1Q#A4GiR5 zO367LQ1HlMNdjy_LG!4*~URGf9pbx74+`&U_Rw`)7j8Ut8($^u)Afqu-rG;e%q zH|fb{D#<-Rp*>O-)O5{vmVJ<8_WgBjWwHI1wBa5?HBdG^W5^x)@IrxBNmqDBj%`HA zbT*+PIAgv51TRXMOY^=gu5d_R&ZVr{sMD&KNHShT5GQdbK;iX{beJj7GWW&bk*h&7 zEG4e@C79b3?P~LMI2;bdeuu=2f+c+989Mb*^+6@Qlh`}Q5W}=duU}BDR8PAQ(kH*t z!LMGK@Y&XJND`PkF9AEMH;AQc`zYOvVty326XxL={BM|pRqo-Y^s-!rYM`5EIw=V} zfr1E=toDmzFPr%$P;NCI9sS?d@i?xXLbbc6NI1`!+!;4?`bK}><{NcSFJs5(vcoX~ z^>QgL@&!zVe1~2izPg(ubg5SE@6ZiOc;?uUbuj30m-D8S=X1Hu1&V%!{5StnQkVk%s2sT{f4%meUJ4y^Jj>F2cIjH1l|wF1e)J zPfP|28&Z=_1p=~Jhq1Un=3|P>>e9FkgmWLY4QuZsj$}D2V}F#UP>SNZ%7JU{Ojj^C zp--t*Nu;V3d#e)030~tQ4_d5bmo~VfnG@SqRRSAL846=Nm38_lHGR;H(4n5K>{5 zuTJnJ3UB(+T;y}LHy8XYOI>^os>z$ZfyDND`wj*$Wxy0%TZ>&a-4Wk&^Cnsz_)Qb8 z#+^}F&B-U@b0V|X%Eqah2h2>`e*Y|u_%ks{GXT)eAclZLFP?o_Chrfi@g9zpGyb~v z5Y+st8iOMGc;b=T+e&d15pRPhuEzI?|E$Ri)HQ9J#POs4_6FDhj=F@@UuQ|7xA6*q zjz)LeZ=Cx15WA-w_L6}=-lXeHt*XhFWpP(FT{5_Ry~f!UBeiUPj!cS%O;tS~iubfg z3mif%(v5P-z<`EB*2)B3f%TjDHLnbr9V@MNF6wnL<>lLV3Pnsu&~61^%Sg?>kj<(r z0ZOsx_AZVNheux?eRFhM??UzYj|lB5(8jw^&z_CfWwei>pLcUJRA~%jm-G3~zJ{vi znd_vTg+3)b3*SS!P2e{Y;~gB49cFMTkHl*#bWiLEsek#OclTcYl1|DCyg?TVNH?Q` zvn5zrnX-2Jlsi^@i_el$>OP>y9+h2VO94kTq^F>iblz54}E9p+2C3Z8&PcAQ%5w&izX$7xf;NSx2gG5779NsGtyO(f@R zmEs&HKj)m*fK>EZ8%uz`x%{8*6tvkt9fIHO5Lm;imMV=#uwRIxUmV{c`0Ew?YkCdO zO=tsEO*YFaqAhb&r^lcTta2J~dh7ZMS`WK)EsCKNwjFwDKF^_jOg4l5_WDDuak&u( zD(Gu=R*dmCqeLzY>TdTnUhJyFELxBE`)sHEIPNaa4K)8xpXQJEX;ccmI%efrp6!@> z96C5|<|UscSmDKp4W6$I_LWu01i+)aVQ<@{95s3k8ilJ+qFTJTkeOQ4F6AL&slvCw zTpssaV3OhzJQtCJQN<%Cs5+>$0W|>51l3=a>Rd=b0xhh2cLDGi*-BjRpjOS9wFB}Y z@)`I=oJ>mDpuAVF!Q29|sIw^^HUTOW&7#Qaat)RN{3*4+OXKHWm*Cu9N9`4Ikr~SM z$+UPmrbIa=SFVq=kleo1dj(?p%B+UJ^$s`TmDC;d@&bp*k>JVJIUaHmq@7n~r5xE(o!?2j_EC)sOxH$U2e!my z)izh_SJk?vT-O0r2+p{4+RgEt%(<*-jtFzoPUfd#&YJ6h%QVw zIi}7I0G@rA8$j@^~MWL`{{>zN~8w0pvAn?LfQetwc*0~h4OZ6KeZN6A^{GZlG2R%15A(w-i{%H`jEFn3?z=zdG ze{vQst>6zi1aYxlJwt}K2h_zlxm{JKkKP1xvW7cUI2;~^y>Nx@OBgLOr=cy?-jOaE zcwczfj%&%1F4a?Xz6#pk_Z%kbc!AxHu2mE~a(yf3Fg?|Pf6&!LwpUWQM(tH8rn+qu zz|-?*S$%na_8I6Fj9fPB356%NI|9rbw0jYGO{FL`a+QIUOvijiqH|N8Zq~lrn2r>H zw@dHAhusbobt)l~=|dC3tC@^*70qmeR^yfT5UrCCQNMVCPNWT@w|+H`h)}_M{-h28 z?(*8O4cw6_yqFnbDOXHk3^T^r3|}`A4_q_vElV_Gsxj7Ipru~VG`pW9dtdyEI3|-5 z=5hwW0uO>{K)c*JhiBLAZ=J*Yd1?qZ?5c~;HQoiyWHY;OJY`MwkAjW_le6dei02g5 z-_}xow+6=?gmZVL#a|qM51kDQbRVG#4OaA?fBRiO*dxk172TFK?@hguFqKCY2Xmqt zgPz$GJhMhCV)BqGI`&YQll-zJO_bE@b>V*K;r1gX_kFboiYN9y}DSb1|u*)eCjN8AXJ# zWCa3OcSd=?ID9lb`s(Y$TjYZIjE8E(d*f3<2X4J0)G_~6)U6$Emf8dX(-;ccj{tLe_5dkAhW|iH1^1Gs(Pbm7UhvS~nu6v`H 
zetv)Ue0oF`{LevMLkz5@{~eU6(4V`M*Bn7YjAEkhv7&{d_}!8bf4^Bq!&^Ol>{#?d zljTq!7`3aGyOB*N;U?*O>Y^7Pc)O{h9f&h@3DQT~?+U=Nn&h+}jm~a@OEA#4L@0t{ z?`PndZUk)`M5l{-4q-(RRN$kB5lG*T8UW&C*M_|MewVfY3Hu-Eyj^(6gkyBY*ltHi zA(f|D8`C#FG8%t3=|ehY%OBQ=^*Y8XF`-Jp=LD{BQsq9m5uJ{1SnZRzDaK_AO}f(^ zxnmFu+L=u{b6ZK@D&w}jS!8!a$c3|W{e-0YudG9d$pQNCK{Pr(lbNRd@$!M4Y%&yP zq)v7mPrrI#LN)>k&5Z*p2&V1|02yfA(g_*SM}0_!fvj9(tiV=g`vlAElyrOs9Vys3 zIa%xe(9wujqtkc!*vwvn(q(9^iM3Tz1rwfjy6h&dn7c7tsBu9tx`9_>{`<`=1rDaBNkkOmd)j z#YEG`S)S|Ut`e_e!=qJd?q=Dl!$Wu^VRlIGcbpy`9~}=5AAa%WEvAR8_w+!a#+~Pj z{)qNJ;Mp;CYeePpDytWpuTu!%c%S`>MLD+Q*aRY7&IcLC0^tM#{6!PuI699bH_AZirix9=`Sv7?ZE9IN$83(xc>L~BFo z?H2M5Ly-6F$%&8u3mi;Ci2r(Ga~naFXG;#>V2B0OnM1L_NSp9onyee5XgePo7oP&{ z1odtdVhotq-vENY4)5T;t4}Dg02e(&3m`vEWz~vVUQ5dPQU|#i!TdMARY5Dd>6C8Z zTP*FyEENAsVa}-z$#WRqp^A69szd**pRKQeb167FOWYrgKU!_N${IU%!mP*jDVa~* zm`Ine;)@g^J&=_w)=MaW+}{cLu{ajIhI+bx+zrvp^9b z{3!;yOf%3n_QM}dk$68IU6(%52Z3?cD;uFLZI5CLutBi!ci&l(J5MT~bMR2nVud8$ zQ6|kjnw#`szUf9o8cTJthDeXU4$N=To*`H(@?HaAeo~|5h&WCm)+lL^Lb7sYbBWQLUoTsrEN;F(L(>$DG3m z1^nRvAiFB|OT!+el`M{4a4n_^ZrB6OBeXe7EAGvc4pb6Byij@)%Ks+oJGkcivV<8= zjt30pzsP#`NmFpVsOOENgyLOo#-ncNbRs_HUM&O}ZVVl67H<$rR1MYncJFhSGscAh zm^_{C)1TDsG;UCHbqn!|{j%Eu9F}V7z2qvb-c8e{;_e1`R!*kj6N|(rb?S8R!;YJ8 zkW)uQZBsn>1Di$6X4VPU%O&c^eUKOJ?dTaxa=`G1X_D9D2D5X$u#OF_s_>6^u)&f> zP&#i?+w|GDk+X7a!3xfu?t=RypzuaOv`f9X?s@NZ0F`&ak7xO+QTyOsP#%dm?H-*W-S#fVRthnDO#(_)VhhH9led~+%{MLQ*5d=4% zL~}ct2;+9G7su-l+qdoP{V0E4!yL`BsKVq3hkV^I%XRXz8A|9T*q;DMe7pU8?^hrB z--gNkJ)g){+!swwek$$1;27hjq-ZBFBTsI1uJ=+4Cg#I5CMm95*A6%(5lXguBu+XW z&2q+)hs#0y71N#l$UPe)?b71yWyfF*VrTcJLp|5YH0tDkp1NP<>da`{IUV3Ms18EC zgL@k&zo$uz@!r%sOg*8P2rw>O#CpdC9l=TQQ>>L3`cQm_^Gn;QUjxGaW~{Z2K>!M= z7If~OOM6?{5Gi70{jYU_ERQMZ%`e6z+wna(aAJ7a4|7x{cPlz)g{(!-g{rLSq2r5D zQSWkRAKK^4e)@6W|J>KNL3U`Gt~;OqA5cpJ1QY-O00;nEbt+j|g-%?l8vp?JX8-^a z0001Ra&Kp4ZZa=!ZeL+?V`wgLdCWZhZyQH;zw57fH%oFzuC?5ycb6iSz(5KVC=N|OP+WncJrpQf^y8P@q390RrTkK)JxGIwAtYUjRd-vH7*~yG&pTPg@o%o6!#>2db ztGu{kzqNIX)jr4kBFU!ggDkn^MVVABU_V;ME1s}ZKAL5Dnokp60^?rD4y*CjExW6t1j%ocG`v3$awCfWGp_}S5u$9({zZ;QO(>?D~l)3{3Vj2*|Li+IXe^=uUv z2kB?3s zZ-cM|{y~=VvSbB+xl9UBb9lw#1;{drhah1ZuUKBNcv^7it8$ceRU{P*M5s*-RK`hJ z70GZ}Ifhb_fxb2Xm@3X#Xa9s9oOIYb`zHq{A<+2f;Pgk&J~(9`?Z5y2{?X~d;}iDm zefIF#(W8UYgJ(zZ?OnEi^keqk!O^3TaWE6G;p;_#@`I2GnvjptUMHM80!(rtP`Tiv zWRi?PuWY(R&zt7p`x#Dy1uy1Fi315voPlSvG?^zAIbvB?3fT;PgP*NVr?V*XDbL_x z3v|h07R}>ork<;0&bKB-K4;a{0tJ@sr^`huJ4Sg5^F;_Od!H2`Hy=MrM%9+4s=S=V zC7YK5*EA}+pk-c!EXzWeU}XrC6=zVz3lE`}WeVqEk(cG;tSZ1g58-i3fc9YaE>qr1 zvKK-;1^;+HAjwI6ZvY)OTF)(*kB19}~KgNA*>r?;-IIx4K(f zTcb1vFMNpq9%N+|XH~+Bz_Na~PuK;HF0m&B151tANe^Q_VNo>YRZ{VJ6a^(uCn3Aw zSCR~VK>dnm6&tW}3CJU(FXeuPxl@ko&bkWwKn+o@p>fZ5rZfrBYOo0TR5^XYa-!=&P;;9%Xzb=n#h0VPQ4{ zEUsX=9L*vy$;^l7l&hqib*S|uK~b2skbLvke~68*zxwR;zy0#-Kl<|ZU;g6tU;X?9 zQN%YvgShPObPR2u3O3qQe*N!%^7^MgfBpF%{O1?{@y(ZCy#Co|U;o2Def`<5zW$5P z@%a~j2G17l2{^H!{w_Q{fu_*v1^s4u_HDTT*JnR@{m1_TDt-Ou|M=@Kzxw8<|9nie zzW(BGzxh+C9IVegE;?oi9qO?el=W3%}0UwtkJBFe9S`S-8?<|{#NuFd@v zo}?PEOpPugRHjkOLU$M+qQ(cN^?@4CWIE52u^LhOqPr9s8fnppP9Yjk;VImDMIQg| z5jea*@#R#2S%EWolC1eSTI6NKRlp0@uzvEF(Kwk8cEw}Lv%yiGaXH1I2MDiwCbSP7istzNIk#=?sD7{GJ)Eu}1`W>*oa49ZG}B3DW>un!n& zGAVgA*pViH<`BX?2K{-w2-0{y9LKEcv+4|8N5Qkc_L$vsc1N~`Ou+Qddj((cxB{!g z>YwuXk_TXzdtJp>sC&!TjM6T{QxFNNGeN#@Sl*Eg`{yC+WS~9h9&}lT|B(LROD_hl zRf9EnCGrM&g9aafUJ%YK&BsL!8w;^y2mTr^&81c%82I2 zQo$0eCngOAt>BkAGf-B~;E$PncE}OeD1(eE;tL3ZNCdS>;uK^q zB63q0_EDVjh={<}5cZ#-!L*3S38V@UY&k|3iyWK?ZC2+mvO7aSA- z^|;vKr@eU0nlGWsQMuUE|@{pl@46#&vgJIYHI!t-Npgc&G>@(!CZ!_Zx`C9KuxcGS;w#LN_x0StoykLds z?(Tz~2fdx|-`{=PQ3NdpyTGz4mZM5;L?QiJR%Kt-2)ejli^rM`cY1HbN-W^vA;8S@ 
z4&_*wO$S)u=#?*vD%hixl)f(#T)hK{bO&hc+`s?ZyGRHT7x_x|K@LJ`slUqaD^F<| z8xN{siFejyNG&mc>&w4jRf1?5^06tNZ$-V}zS8%Zd#->~t@aGQmps@TzH% zVZk)7*vA_2(fCW)8`z8`|OehbQVst!>n zi0|`~!)pSGA0xe`{_<7YKO+?8d5M%c&w&2+@Z@b-QRQ)u@{xN^98pD0&`np2bhj*E4cLxB zfiPg4xPAU>R0DjrbFS*Gd*=?Y664N5;yLX6P}0it0&90d7ppMo;#7)!z0k5glO-j?isxr5#1~=NKvBoKBcy&(kr`xD_Bbk@i zl5xe+Q-2g!ff7IrhF=lZyyd*tSd!n%RxCVa>%W01aem)uNRqtSj1(U>FrrkX*@!ZR z-C#r^2F}{13GQ>Ezt@L?4d*rXMuNFBv+F=A3|U~i=MARR!>LG%qgXYZd#>#iRToP} zPWoeG+S-{S5h!>_fI@C;6m; zEnx6_@BRL%p-vkWk)z1+u~qz{bpjieZy5QmNg^!eE$W)yoEK9r*ByuZhO5u@jAvhZ{7fYjw1~iE)u^b3P9|pUY?|G|}By~s8QiXPlojNp=pfIkn<`oi}B;7T6 zth1a`_R7}gm;uZtpX8F=x0un&h!w}RPR`UBPM^u#OqQ~BUS&vjyG^9jV-v+aqPV>q z;u((yROg|6cHZ?Zp@1xPZ7&k)R_bC)lLwKh$%AmG+>%;9MVh4tO9!^N^OJN{ib?rE zR^{H=f=v$t8_xnEm{73G1_SoC>k^_Lj3xoCV2P2ELgxK0UW=)c=`vrIfgK#$nB3o( ziHgW=9Q^Q8+mcT@2e@HSJc>BGVC5`dreoZTLx{pX%Gv#py?q|C;Sys!?3im(zv>vS z8DB-zY9v;h1>rtTS^sTs=)#&KNm5ia(=6;+M8VF{w!RY-vCwJ4xac&980j>*@zUDp zsgDkzH*W$(!Qjm{fM#_V(G@918$NNt%xV}eCll*1ij>u-cA%#@c@Rsx&Z|Q!leg1s zL0oMwDdu7gKqjFU#chWfq_n0FK+P8`N4>?~<9tE&acy3v7BHc%D99%nn=J0_VVOO? z%o8x$B`+pvz5-t?m%McK1vyVe1xBuWwaSZz7ny{Frsc5?>v*+EpDLIrB_v<=O!t;5 zvle64F^{YHGF383V*YYyP@J!_YG=w#*D+enc))@lsA@*@)!-+)!3@{dT5p}K1FRsC;}_2I9tV#rB%6Y z8&hg&HOqDS-Zli<(jd!I+(DChr5_V5YubZE_B^w&c_Mf?GSe+ z)FF(#N&}?&bud`dK!06M)Os`=*r=|Cg1J1>O3bTmT4l_d zWZfFRDDFE18r0S_Yv&6i%I<4tk{+W?j}a*L`>lhW-i}%M{G&j(z-JuPX;;&O4Axc? zY{UQs-w0lq^o>x|bZxo+-PR9&Xvt!{}}!Ge-eulsug z*1^7x8%e0scYpV?Nj<%@k?zxFENkAl{+DJEUdHpUa|aFzxVgCAF_6vmiyxj_Q*tw? z9r{WOy0G(fp+PZ1Y3a4an|$0+xSrwt|F89v<2B;Y3|~SvH}HI00W?k)9LctsVl5EP zg6LXtBb%9vVmH{W*3xD}{bV23YA|PEwDmaKImfleu*v%En`~Fi(kHcRS4}4zBKjmR z@WrR7(bV9vo5nF){yW!^zN3fTNyN0d`wix%a)vuei<8^Lko%N$;q-iQ2V2y+M*R6z`O05gc7DcAc1eW^@r6)?QUti;!jxLE2Z@kqRD zG|pEUB=uObB{228(DzP>DP1HjdyXI?7)3$`6|lrrRO-jq+9GiS6{t-TytUY&9Re_( z^FmNb%4iW6ahh@z7tgl`R8X{bu(X&+)#TD6-m7tA%z?g1N9ngj=y~M}>pc=d1V+eC zYQk>U^$qz7k093@$5q_JKo8r9V9w(Va}L-&RO10n1-T&9ZuPl^SCB)JrMn-kd9%eI z0R?xUnMZ!Xq(QC6-YZN0t9mVBVjU&p^+xJhVw}}UoNuhOzi>Az-EPcgk$RIP1-sy@ z1;#JIslju(h}9*Av1qDa2=y=O*9kX`Mi561%@M_cqf?*l-Owh4FJQX)%rLG-Gqv7&OzhhRBB$ESfDMu>BfNgGexWDGMi48SyzHZ9LUApS;^%5Hd(kE&=^*)QhK)k{q&^w;z>^7bI ziuB7mgm1Vm;_&dOf*-lJe@lKuN_}!Q`WPC z&aw1Ix-um^V?xf##}*V)OfvgGID^y;M;6R`p{_Mxw}oT`fRue$;u{rA{OT=9Pd{O-3eT{It7P>&=pCM=rxs!48Wby5eSFtxgLSkP(l*M@3 zZFSwf4drua7!4e!X@oMlTSI|Y-wQV5*Y{KmuY8GV60%P7d+VK0{5n#X z;x}Q?@dI>udPw3e=RoRXZ)%CKOQfDr;zu96;ec)Qa?x4hZME=Y4wGp7>aX|V=bE!c z^||qkQGMR{tWkX$&KzrS-njanVGbYt;lJUEIDV{YsUN|y^nu%XgICp5Yxdy!1a4}z znC&%$1y_K!Q|3u)+`2YMiY)uK9ovjAL7G1v!^DTAParWuvP1mMv-!KvWL)H7-LT3q zE(di7l{Rsg6t^xhM4dGshT^SdpMTgiFz3km(q6VyRodd(@|$_?d_xmD3Qy?|EYVW^QOU1HF)4OlOXcexb zTK-}r?Chry-X|(rnsL{xj`iabeP*u!TA@k2Kz#9H92T0ha)k$vbmSncT zS{F`4d7Wi2r8VbRO^D~$ZGB}_oXgVn z;4-)cC%C&i!Gbfm%i!)#f;+)o65QS09fBvgJHZJAf_|JUAGz;2{ipkxHG4m^)^t}_ zt?I6QXP`YUTMZ~Z$Pp%t)g%#Q3lU~bq8p2J44W5fWZeCL(B@wEtIcC&m4!8ng!|~M zP;4Jxa`L>zeXnCNFG7lexubx{Ek*@-wzFn#LN2)T@(0&r@*61kI_=v--i0H=mHrP` z!7L+VvUQ1Dh;zP8g(qe7(q%aU1zvvfRWy7L64qDull_Qyrz*fUv|ln>pppE5`D75y z1%ymlA1{(YFyW?GO^o!|AMZJ8nyA7N(%wM>_?4wuH>QSujdV zP&z+u$h;wW_~$ug(cQ0c+c=UK;Y49c?N>o5SQWrcK8 z<*W&uqRI)u{W-Y_wKupD`mm+hUY|z4^mR9-X}7JIu`YZKn8vM#VC@bF;zPSZ)L?$n zgXb=~2v&k6g;tR8bQY1ohT4sNQ1{yoQt@yd9omgO%|~SkuY4l0E8pDQ!gOT~fF`Fb ztP*8gc6-K_TeurSv6h7z|GJ_Z(|aSyxtExnUI+}dS3n_lN_OXfBsc!*$s_|kzYE&> zbkn;-0f^|Xmf$Ru!m)yxVy9P+%=Uh#IKf!+v6R(8sI;F=RdJv!$?be}O-4x7YPNsP ze6P$rj)wfCt10IuG z;}Rcqo0CB>PzH<^ryAF@fW^BXSm&SpkxmkLSt*=DTWamsD<)UJUya+pR@6a+qSP_~ z!>@260RWpXRdvw+NbJ4T&@ngtmxK@gUlKmo?vL21PTX-({hD^II`x*^Lyq_cM9Yzx 
zZY2{#fl^YU@&OEh-+7i#Hyv>2Q2E6R3T>+=mCfvOkY~U%AZqrP*gl4?gbxjKCcgI1 zyuYc_D$;IYo`z+;-YXoCC8(NTMnAON1lcq=yA2FPw7+>smmauAspcl|>f>oKoYppx zZrqZC*fx|>a9Qz&#C>C3DU~?zsoPh^iH>t5d57`M#qq=O)p8yJAgxIt;nP=-R_Wrs zw72^lt-17)SBuC=HDpD{>IwE~hBFgITD@d|^ykS+A!HP%>;1xm6z{T{Z z%ntHU^P;L~=IBX<%@a;uo=1^RexZ+?6iDBoZ5C{qlzhs^Y8O13A3U{)jrs8Ad>&&bm+(+4Oe8FJ z)Ix6}YYb`Y!j%g3P;z=_=cF2!?`wa2)p-=~V~>jZ1{{^L>{Zk4%+YmjP(U0*p*ifM zG-?8nX>Ny%T>?qhSXAU&Jnf;5pU))HppA5IEU>a~N`*lSg_I2RC4i2L5T#g`w!1N5 zoR%Lp;nb@g5jFz`p^_aw1mvGA!{rbE6N!zgHH01KP-j1?Akt$&LNN%T!*~Wlj~!p5 zifvYPEtpl22%hOJt($UQ+}`}4Y~QY=_$?^GC}IdqJjPHp2r4%!PRN|g4F=~UOHj?m z?nY3fKyg46A@G6VnOcoLbEt`9;^2=~Cv>~1(c4m5eodPqr6@h1o9s@Lv!g8}H3&hc zs!?c-lnWj#5YsP0E-&GthxsTz(hVOgii{0&waexGyG-d=5jSEg}D6D zA2iZB)>VEmm}bQqrDzN`U$K{h-RN&#WX$-&khUhJWN!>{DoV;`Ao?q>W}h2+I1%#1 zdf3h#Mu)+ImE_%o9;+a9z{McN^x<1m)#f;r^8h_g3*{}c^#jX<8J7&h>zy!YDR6I< zX79Wfj&F@z9+{tB2}9CGdL4m12U7{GmAo?-Nfh&A{5ym#VTFa>51Y4R1(IqJ6bP8H z%T`MAj!1Z5*R&a(DhgTqmh|Yy_55|^oyVcb6$)sSPB<7B-taaFu@3K)?}T$-2#?I_ zguQZ;B6M7`6h;IQQ-3P2@piRSOonlzz#GififPHfS|y;s2kO2HT37c@D8M8x2StX1 zcHyml_mHVbJr{ntO4otUhZF#K!I*U$@3xf&!X&oI#qo3~?<45>Wd`#;c0+13D~4(R zELbZvCE3HgI(`e!)MUz-$nizD5xbgQv(*?kIZx|Hi+&OZ2h7OW1<`wZQ}oF3a;P3< z9w)R68`89Gtvte&m$Tsmy zpkR)W9L}@e))mBhZ}Gu79D?lDozuo|@v;T7yc-mG^cP)t@sDrX8n|7QsP&`nT#nCF zY4KV#8(yS=nT3yLS3RA_rpN6NO+O_kMwX7Uz!aBcE3j|cEj{9VD9eTE1edK`(rGso z1$bUm=BHaOu@=P_f}SKLgnEtUbAGsDyK>sfg~#O0lV(dm;E6#K$4~XtJNu;uqZ#Yh zsh4~5_R9%25~lLjIpnjGl4R%#Arsd5g~UmLqf@Q;C=5Qjx(1ro8KZQTEB5&3f@s#| zy`l!16j5-7ab)tyxVM-VM0(u_?4VO@ejC}>eCk@GpN-i3mVVSdUTg*KmVzVP=(3pQ zGG2z{#pyyOZ)dx0h6%JaWY+rio~}TeEfuYbl1bd^eC13CR&|*6fW&Qq>DcjnOUR}@ z{XuyJy3ZZ%(%?+2^O5k^+Q$V$@<>Efv0~rOPSnlaBMs(X3=uW>PKTm*-@pnuwbtpX z?^X61kOZne_O9^giH`5O$CcFXbwIDk8E_GN#!ds%;wBmccaxN(tt59n^KAIZkm4_k zD|_66aADoX?j0=%uw)MC&kQhOd&kRaNu% zo^raa8{NRR86S7{okF)YEqq21a+ZLwbqE3NFFXR<+!w~H+qzE$-PVn7pxmAZPodm| zQJ9(zPThKJrMy)D)q#%?o-(k38`NQ(_WCPt#g9ITVWn7|kGIg2Pf8$UUNPGG)}I;N zNf`+oY$UKFT#rNICJF2(bP04pdqfH~C1%q4*0NUvqxsE&W^E7*87RKcRhCv~LhinX zw>%M*BTG5!GR<=~@-4XD(1()}32J@_wD%s0S4s5rEXQLVJi%XKlrQGOHr%eb>R2Jn zb$`T7NbW7uF_ zC~S^aBQ55=1H``f>=W-luTI$oqdaS-Z`Y7l3WXoHerCy(9=*bceA`D3JLjMtrhw$E z2M1jf&umA&r^AgN+`mWHOFLPf(+Be%-DjRfZo|c*^S*A6yE#JMc|NjkD#E6+DZ{zh zBELS)#tLgHwZfI#_1gl4-)M$5HWyW^i{J@|$n3A^r=nBY`5~Ip?(w7!;_AcxVoTAz zJz@Rl%?b$-Y?g^L0pTe;G7q2o(&=uEH}Ow%oSL>3sW}bhg6+zH$%c6EsLGc;)`I4! 
[... base85-encoded GIT binary patch data omitted ...]

--
Gitee

Date: Fri, 13 Sep 2024 15:13:50 +0800
Subject: [PATCH 13/16] pr modification

---
 .../pipeline/models/progen/__init__.py        |  0
 .../progen/module/configuration_utils.py      |  5 +--
 .../models/progen/module/injection.py         | 29 +++++++++++++-----
.../models/progen/module/logits_process.py | 0 .../pipeline/models/progen/nn_arch.py | 0 .../pipeline/models/progen/progen.py | 0 .../models/progen/progen_configuration.py | 0 .../pipeline/models/progen/progen_dataset.py | 5 ++- .../pipeline/models/progen/tokenizer.json | 0 .../mindsponge/pipeline/models/progen_v6.zip | Bin 57127 -> 0 bytes 10 files changed, 26 insertions(+), 13 deletions(-) mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py mode change 100644 => 100755 MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json delete mode 100644 MindSPONGE/src/mindsponge/pipeline/models/progen_v6.zip diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/__init__.py old mode 100644 new mode 100755 diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py old mode 100644 new mode 100755 index c0be414fc..6c3da402b --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py @@ -1702,11 +1702,12 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin): """ self.init_weights() - def prepare_inputs_for_generation(self): + @staticmethod + def prepare_inputs_for_generation(*args, **kwargs): """ prepare_inputs_for_generation """ - return + pass @classmethod def _from_config(cls, config, **kwargs): diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py old mode 100644 new mode 100755 index af228a1a7..26e07f1b5 --- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py +++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/injection.py @@ -212,12 +212,14 @@ def einsum_label_to_index(label): return (ord(label) - ord('A')) if (label.isupper()) else (num_of_letters + (ord(label) - ord('a'))) def enumerate_lhs(op_labels, lhs, num_ops): - + """ + enumerate lhs in einsum function + """ curr_op = 0 found_ell = False ell_skip = 0 - for i, label in enumerate(lhs): + for _, label in enumerate(lhs): if label == ' ': continue if label == '.': @@ -237,9 +239,12 @@ def enumerate_lhs(op_labels, lhs, num_ops): op_labels[curr_op].append(einsum_label_to_index(label)) - return op_labels, lhs, num_ops, curr_op, found_ell, ell_skip + return op_labels, lhs, num_ops, curr_op, found_ell def enumerate_operands(op_labels, operands, label_count): + """ + enumerate operands in einsum function + """ # Compute label frequency and number of dimensions covered by ellipsis # We do this after parsing labels to make it more readable and simpler # to compute the number of dimensions covered by 
ellipsis. @@ -268,8 +273,11 @@ def enumerate_operands(op_labels, operands, label_count): return ell_num_dim, label_count -def unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels, +def unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels, ell_num_dim, ell_index, label_perm_index): + """ + unsqueeze missing dimension in einsum function + """ permuted_operands = [] for i, operand in enumerate(operands): perm_shape = [-1] * perm_index @@ -326,6 +334,9 @@ def unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels, return permuted_operands, has_zero_size_dim, dim_last_op def einsum_operate(arrow_pos, ell_num_dim, label_perm_index, equation, lhs): + """ + intermediate operation in einsum function + """ # Current index in the permuted shape perm_index = 0 # Start index of ellipsis dimensions in the permuted shape @@ -366,11 +377,13 @@ def einsum_operate(arrow_pos, ell_num_dim, label_perm_index, equation, lhs): return perm_index, found_ell, label_perm_index, ell_index def sum_result(num_ops, permuted_operands, out_size, perm_index, dim_last_op, result): + """ + sum out or squeeze dimensions that are size 1 for all later operands + """ for i in range(1, num_ops): operand = permuted_operands[i] sum_dims = [] - # Sum out or squeeze dimensions that are size 1 for all later operands dim = out_size for j in range(dim, perm_index): if dim_last_op[j] < i: @@ -404,7 +417,7 @@ def einsum(equation, *operands): op_labels = [[] for _ in range(num_ops)] lhs = equation[0: arrow_pos] - op_labels, lhs, num_ops, curr_op, found_ell, ell_skip = enumerate_lhs(op_labels, lhs, num_ops) + op_labels, lhs, num_ops, curr_op, found_ell = enumerate_lhs(op_labels, lhs, num_ops) assert curr_op == num_ops - 1, "einsum(): more operands were provided than specified in the equation" # Labels must be within [a-zA-Z]. @@ -437,7 +450,7 @@ def einsum(equation, *operands): # same operand. Finally we permute the operands to align dimensions as # per the perm_out_index we computed above. 
permuted_operands, has_zero_size_dim, dim_last_op = \
-        unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels,
+        unsqueeze_missing_dim(operands, perm_index, total_labels, op_labels,
                               ell_num_dim, ell_index, label_perm_index)

     # Compute result
@@ -460,7 +473,7 @@ def einsum(equation, *operands):
             dim -= 1
         dim += 1

-    result = sum_result(num_ops, permuted_operands,
+    result = sum_result(num_ops, permuted_operands,
                         out_size, perm_index, dim_last_op, result)

     return result
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/logits_process.py
old mode 100644
new mode 100755
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py
old mode 100644
new mode 100755
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen.py
old mode 100644
new mode 100755
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_configuration.py
old mode 100644
new mode 100755
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
old mode 100644
new mode 100755
index 5645e174c..e528c79ff
--- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
@@ -45,9 +45,8 @@ class ProGenDataSet(PSP):
     def data_parse(self, idx):
         return None

-    def __getitem__(self, **kwargs):
-        with PrintTime('get_item'):
-            print(**kwargs)
+    def __getitem__(self, idx):
+        pass

     def __len__(self):
         pass
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json b/MindSPONGE/src/mindsponge/pipeline/models/progen/tokenizer.json
old mode 100644
new mode 100755
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen_v6.zip b/MindSPONGE/src/mindsponge/pipeline/models/progen_v6.zip
deleted file mode 100644
index 6ab624ef7199e8d00f39b5cec144f70b402abb98..0000000000000000000000000000000000000000
GIT binary patch
literal 0
HcmV?d00001

literal 57127
[... base85-encoded GIT binary patch data omitted ...]
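Note on the einsum changes in this patch: the refactor only re-threads the parser state (enumerate_lhs no longer returns ell_skip) and adds docstrings; the public entry point keeps its einsum(equation, *operands) signature. A minimal usage sketch follows; the import path and the use of MindSpore tensors as operands are assumptions for illustration, not something this patch itself establishes.

    # Hypothetical usage sketch of the refactored einsum helper (not part of the patch).
    import numpy as np
    import mindspore as ms
    from mindsponge.pipeline.models.progen.module.injection import einsum

    a = ms.Tensor(np.random.rand(2, 3), ms.float32)
    b = ms.Tensor(np.random.rand(3, 4), ms.float32)

    # Contract the shared 'j' label, i.e. an ordinary matrix product.
    c = einsum('ij,jk->ik', a, b)
    print(c.shape)  # expected (2, 4)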
--
Gitee

Date: Fri, 13 Sep 2024 15:41:17 +0800
Subject: [PATCH 14/16] pr modification

---
 .../models/progen/module/configuration_utils.py      |  2 +-
 .../src/mindsponge/pipeline/models/progen/nn_arch.py | 11 -----------
 .../pipeline/models/progen/progen_dataset.py         |  2 +-
 3

diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py
index 6c3da402b..6b194b087 100755
--- a/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/module/configuration_utils.py
@@ -1707,7 +1707,7 @@ class PreTrainedModelMindnlp(nn.Cell, CellUtilMixin, GenerationMixin):
         """
         prepare_inputs_for_generation
         """
-        pass
+        return
 
     @classmethod
     def _from_config(cls, config, **kwargs):
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py
index a7c621bb0..af93ed603 100755
--- a/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/nn_arch.py
@@ -427,11 +427,6 @@ class ProGenPreTrainedModel(PreTrainedModelMindnlp):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)
 
-    def prepare_inputs_for_generation(self):
-        """
-        prepare_inputs_for_generation
-        """
-        return
 
 class ProGenModel(ProGenPreTrainedModel):
     """
@@ -455,12 +450,6 @@ class ProGenModel(ProGenPreTrainedModel):
     def set_input_embeddings(self, new_embeddings):
         self.wte = new_embeddings
 
-    def prepare_inputs_for_generation(self):
-        """
-        prepare_inputs_for_generation
-        """
-        return
-
     def construct(
         self,
         input_ids=None,
diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
index e528c79ff..195e1d23f 100755
--- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
@@ -45,7 +45,7 @@ class ProGenDataSet(PSP):
     def data_parse(self, idx):
         return None
 
-    def __getitem__(self, idx):
+    def __getitem__(self):
         pass
 
     def __len__(self):
-- 
Gitee

From 409e2a1f4ba1086e80c3f3814fb029f8a9ae0700 Mon Sep 17 00:00:00 2001
From: zhang-yucheng2024
Date: Fri, 13 Sep 2024 17:57:49 +0800
Subject: [PATCH 15/16] pr modification

---
 .../src/mindsponge/pipeline/models/progen/progen_dataset.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
index 195e1d23f..4bcd9ea87 100755
--- a/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
+++ b/MindSPONGE/src/mindsponge/pipeline/models/progen/progen_dataset.py
@@ -45,7 +45,8 @@ class ProGenDataSet(PSP):
     def data_parse(self, idx):
         return None
 
-    def __getitem__(self):
+    #pylint: disable=arguments-differ
+    def __getitem__(self, idx):
         pass
 
     def __len__(self):
-- 
Gitee

From 01f049561162727cdf04175cb3075d652eae0384 Mon Sep 17 00:00:00 2001
From: zhang-yucheng2024
Date: Mon, 23 Sep 2024 19:55:46 +0800
Subject: [PATCH 16/16] update requirements

---
 MindSPONGE/requirements.txt | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/MindSPONGE/requirements.txt b/MindSPONGE/requirements.txt
index 94d6310c3..85f8ef0e6 100644
--- a/MindSPONGE/requirements.txt
+++ b/MindSPONGE/requirements.txt
@@ -1,6 +1,7 @@
 numpy >= 1.17.0
 scipy >= 1.7.0
 biopython == 1.79
+biopandas == 0.4.1
 pyyaml >= 5.4.1
 dataclasses >= 0.6
 glob2 >= 0.6
@@ -9,5 +10,8 @@ absl-py >= 1.1.0
 biotite == 0.38
 descriptastorus == 2.6.0
 pyparsing >= 3.0.7
+POT == 0.9.3
+tokenizers
+joblib
 rdkit
-mindspore-gl
\ No newline at end of file
+mindspore-gl
-- 
Gitee