diff --git a/MindChemistry/applications/Uni-Mol/README.md b/MindChemistry/applications/Uni-Mol/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8d583e1b08786db6fbbb382ecd6ca77f0444ae83
--- /dev/null
+++ b/MindChemistry/applications/Uni-Mol/README.md
@@ -0,0 +1,148 @@
+# Model Name
+
+> Uni-Mol
+
+## Introduction
+
+> Uni-Mol, developed by DP Technology, is the first universal large-scale 3D molecular representation learning (MRL) framework. Trained on 209 million molecular 3D conformations and 3 million protein pocket structures, it significantly improves molecular representation quality and broadens the range of applications. The framework performs well on a variety of molecular property prediction tasks, is particularly strong on 3D-related tasks, and can be applied to drug design and materials property prediction (e.g., gas adsorption of MOF materials, optical properties of OLED molecules). The paper was published at The Eleventh International Conference on Learning Representations (ICLR 2023).
+
+## Dataset
+
+> Uni-Mol v1 uses the following datasets:
+
+| Data                   | File Size | Update Date | Download Link                                                                                                              |
+|------------------------|-----------|-------------|----------------------------------------------------------------------------------------------------------------------------|
+| molecular pretrain     | 114.76GB  | Jun 10 2022 | https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/pretrain/ligands.tar.gz                                |
+| pocket pretrain        | 8.585GB   | Aug 17 2022 | https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/pretrain/pockets.tar.gz                                |
+| molecular property     | 3.506GB   | Jul 10 2022 | https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/finetune/molecular_property_prediction.tar.gz          |
+| molecular conformation | 8.331GB   | Jul 19 2022 | https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/finetune/conformation_generation.tar.gz                |
+| pocket property        | 455.239MB | Jul 19 2022 | https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/finetune/pocket_property_prediction.tar.gz             |
+| protein-ligand binding | 263.27MB  | Sep 8 2022  | https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/finetune/protein_ligand_binding_pose_prediction.tar.gz |
+
+We use [LMDB](https://lmdb.readthedocs.io) to store the data. You can read an LMDB file with the following script.
+```python
+import lmdb
+import numpy as np
+import os
+import pickle
+
+def read_lmdb(lmdb_path):
+    env = lmdb.open(
+        lmdb_path,
+        subdir=False,
+        readonly=True,
+        lock=False,
+        readahead=False,
+        meminit=False,
+        max_readers=256,
+    )
+    txn = env.begin()
+    keys = list(txn.cursor().iternext(values=False))
+    for idx in keys:
+        datapoint_pickled = txn.get(idx)
+        data = pickle.loads(datapoint_pickled)
+```
+Python >= 3.8 is recommended.
+
+## Environment Requirements
+
+> 1. Install `mindspore` (version 2.6.0 recommended)
+> 2. Install `mindchemistry` (a quick environment check is shown below)
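+
+The following quick check is not part of the original project; it is a minimal sketch that assumes MindSpore's built-in `run_check()` self-test and that MindChemistry is importable as `mindchemistry` after installation:
+
+```python
+import mindspore as ms
+import mindchemistry  # assumption: the installed package is importable under this name
+
+# MindSpore's built-in installation self-check: prints the version and
+# verifies that a simple computation runs on the current backend.
+ms.run_check()
+print("mindspore:", ms.__version__)
+```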
+
+## Quick Start
+
+> Training command: `python train.py`
+
+## Script Description
+
+### Code Directory Structure
+
+```txt
+├─unicore
+│  ├─data
+│  │  └─__pycache__
+│  ├─distributed
+│  │  └─__pycache__
+│  ├─logging
+│  │  └─__pycache__
+│  ├─losses
+│  │  └─__pycache__
+│  ├─models
+│  │  └─__pycache__
+│  ├─modules
+│  │  └─__pycache__
+│  ├─optim
+│  │  ├─lr_scheduler
+│  │  │  └─__pycache__
+│  │  └─__pycache__
+│  ├─tasks
+│  │  └─__pycache__
+│  └─__pycache__
+└─unimol
+    ├─Alignment
+    │  ├─ms-output
+    │  ├─torch-output
+    │  └─__pycache__
+    ├─docker
+    ├─example_data
+    │  ├─molecule
+    │  └─pocket
+    ├─figure
+    ├─notebooks
+    ├─offload
+    ├─unimol
+    │  ├─data
+    │  │  └─__pycache__
+    │  ├─losses
+    │  │  └─__pycache__
+    │  ├─mindspore_ascend_outputs
+    │  ├─models
+    │  │  └─__pycache__
+    │  ├─offload
+    │  ├─tasks
+    │  │  └─__pycache__
+    │  ├─unimol.egg-info
+    │  ├─utils
+    │  │  └─__pycache__
+    │  └─__pycache__
+    └─unimol.egg-info
+```
+
+## Training Process
+
+### Training
+
+To train directly:
+
+```shell
+python train.py
+```
+
+For distributed training on Ascend, run:
+
+```shell
+bash run.sh
+```
+
+Training log:
+
+```log
+
+
+```
+
+## Inference and Evaluation Process
+
+### Inference and Evaluation
+
+```txt
+1. Save the checkpoint file under the `/checkpoint/` directory (the default load path).
+2. Run the inference script: python predict.py
+```
+
+Inference and evaluation results:
+
+```txt
+The results can be viewed in predict.log; the inference output file is pred.npy.
+```
\ No newline at end of file
diff --git a/MindChemistry/applications/Uni-Mol/unicore/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9df4c9212d1eab6a16a95d2dfb03c7861323fc85
--- /dev/null
+++ b/MindChemistry/applications/Uni-Mol/unicore/__init__.py
@@ -0,0 +1,36 @@
+# Copyright (c) DP Technology.
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+"""isort:skip_file"""
+
+import os
+import sys
+
+try:
+    from .version import __version__  # noqa
+except ImportError:
+    version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
+    with open(version_txt) as f:
+        __version__ = f.read().strip()
+
+__all__ = ["pdb"]
+
+# backwards compatibility to support `from unicore.X import Y`
+from unicore.distributed import utils as distributed_utils
+from unicore.logging import meters, metrics, progress_bar  # noqa
+
+sys.modules["unicore.distributed_utils"] = distributed_utils
+sys.modules["unicore.meters"] = meters
+sys.modules["unicore.metrics"] = metrics
+sys.modules["unicore.progress_bar"] = progress_bar
+
+import unicore.losses  # noqa
+import unicore.distributed  # noqa
+import unicore.models  # noqa
+import unicore.modules  # noqa
+import unicore.optim  # noqa
+import unicore.optim.lr_scheduler  # noqa
+import unicore.tasks  # noqa
+
diff --git a/MindChemistry/applications/Uni-Mol/unicore/checkpoint_utils.py b/MindChemistry/applications/Uni-Mol/unicore/checkpoint_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..62751d8005395e56b28a2371364be4eabdc01ea8
--- /dev/null
+++ b/MindChemistry/applications/Uni-Mol/unicore/checkpoint_utils.py
@@ -0,0 +1,621 @@
+# Copyright (c) DP Technology.
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
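+
+# MindSpore adaptation of the unicore checkpoint utilities: the original
+# save/load/rotation logic (asynchronous checkpoint copying and pruning via
+# keep_interval_updates / keep_last_epochs / keep_best_checkpoints) is kept,
+# while checkpoint files use MindSpore's ".ckpt" suffix instead of ".pt".
+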
+import ast +import collections +import logging +import os +import re +import shutil +import traceback +from typing import Any, Dict, Optional + +import mindspore as ms + + +logger = logging.getLogger(__name__) + + +# async ckp copy +def ckp_copy_fun(src, checkpoints, end_of_epoch, args): + has_copy = False + can_delete = args.tmp_save_dir != args.save_dir + for cp in checkpoints: + try: + if src != cp: + logger.info("copy {} to {}".format(src, cp)) + has_copy = True + shutil.copyfile(src, cp) + except: + logger.info("copy failed, please copy it manaully") + + try: + if can_delete and has_copy and os.path.lexists(src): + logger.info("removing temp file {} ...".format(src)) + os.remove(src) + + def remove_ckps(root_path): + if not end_of_epoch and args.keep_interval_updates > 0: + # remove old checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + root_path, pattern=r"checkpoint_\d+_(\d+)\.ckpt" # MindSpore默认 checkpoint 后缀为 .ckpt + ) + for old_chk in checkpoints[args.keep_interval_updates :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + logger.info("removed {}".format(old_chk)) + + if args.keep_last_epochs >= 0: + # remove old epoch checkpoints; checkpoints are sorted in descending order + checkpoints = checkpoint_paths( + root_path, pattern=r"checkpoint(\d+)\.ckpt" + ) + for old_chk in checkpoints[args.keep_last_epochs :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + logger.info("removed {}".format(old_chk)) + + if args.keep_best_checkpoints > 0: + # only keep the best N checkpoints according to validation metric + checkpoints = checkpoint_paths( + root_path, + pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.ckpt".format( + args.best_checkpoint_metric + ), + ) + if not args.maximize_best_checkpoint_metric: + checkpoints = checkpoints[::-1] + for old_chk in checkpoints[args.keep_best_checkpoints :]: + if os.path.lexists(old_chk): + os.remove(old_chk) + logger.info("removed {}".format(old_chk)) + + remove_ckps(args.save_dir) + except: + logger.info("remove old ckps error") + + logger.info("finished async ckp saving.") + + +def save_checkpoint(args, trainer, epoch_itr, val_loss, ckp_copy_thread, do_save=True): + from unicore import meters + + # only one worker should attempt to create the required dir + if trainer.data_parallel_rank == 0: + os.makedirs(args.save_dir, exist_ok=True) + + prev_best = getattr(save_checkpoint, "best", val_loss) + if val_loss is not None: + best_function = max if args.maximize_best_checkpoint_metric else min + save_checkpoint.best = best_function(val_loss, prev_best) + + if args.no_save or not do_save: + return + + if not trainer.should_save_checkpoint_on_current_rank: + return + + write_timer = meters.StopwatchMeter() + write_timer.start() + + epoch = epoch_itr.epoch + end_of_epoch = epoch_itr.end_of_epoch() + updates = trainer.get_num_updates() + + logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + + def is_better(a, b): + return a >= b if args.maximize_best_checkpoint_metric else a <= b + + suffix = trainer.checkpoint_suffix + checkpoint_conds = collections.OrderedDict() + # MindSpore默认 checkpoint 后缀为 .ckpt,替换 .pt 为 .ckpt + checkpoint_conds["checkpoint{}{}.ckpt".format(epoch, suffix)] = ( + end_of_epoch + and not args.no_epoch_checkpoints + and epoch % args.save_interval == 0 + ) + checkpoint_conds["checkpoint_{}_{}{}.ckpt".format(epoch, updates, suffix)] = ( + not end_of_epoch + and args.save_interval_updates > 0 + and updates % args.save_interval_updates == 0 + ) + 
checkpoint_conds["checkpoint_best{}.ckpt".format(suffix)] = val_loss is not None and ( + not hasattr(save_checkpoint, "best") + or is_better(val_loss, save_checkpoint.best) + ) + if val_loss is not None and args.keep_best_checkpoints > 0: + checkpoint_conds[ + "checkpoint.best_{}_{:.2f}.ckpt".format(args.best_checkpoint_metric, val_loss) + ] = not hasattr(save_checkpoint, "best") or is_better( + val_loss, save_checkpoint.best + ) + checkpoint_conds["checkpoint_last{}.ckpt".format(suffix)] = ( + not args.no_last_checkpoints + ) + + extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} + if hasattr(save_checkpoint, "best"): + extra_state.update({"best": save_checkpoint.best}) + + checkpoints = [ + os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond + ] + tmp_checkpoints = [ + os.path.join(args.tmp_save_dir, fn) + for fn, cond in checkpoint_conds.items() + if cond + ] + if len(checkpoints) > 0: + trainer.save_checkpoint(tmp_checkpoints[0], extra_state) + if ckp_copy_thread is not None: + ckp_copy_thread.apply_async( + ckp_copy_fun, (tmp_checkpoints[0], checkpoints, end_of_epoch, args) + ) + write_timer.stop() + logger.info( + "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( + tmp_checkpoints[0], epoch, updates, val_loss, write_timer.sum + ) + ) + + +def load_checkpoint(args, trainer, **passthrough_args): + """ + Load a checkpoint and restore the training iterator. + + *passthrough_args* will be passed through to + ``trainer.get_train_iterator``. + """ + + reset_optimizer = args.reset_optimizer + reset_lr_scheduler = args.reset_lr_scheduler + optimizer_overrides = ast.literal_eval(args.optimizer_overrides) + reset_meters = args.reset_meters + reset_dataloader = args.reset_dataloader + + if args.finetune_from_model is not None and ( + reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader + ): + raise ValueError( + "--finetune-from-model can not be set together with either --reset-optimizer" + " or reset_lr_scheduler or reset_meters or reset_dataloader" + ) + + suffix = trainer.checkpoint_suffix + if ( + args.restore_file == "checkpoint_last.ckpt" # 适配MindSpore默认后缀 + ): # default value of restore_file is 'checkpoint_last.ckpt' + checkpoint_path = os.path.join( + args.save_dir, "checkpoint_last{}.ckpt".format(suffix) + ) + first_launch = not os.path.exists(checkpoint_path) + if args.finetune_from_model is not None and first_launch: + # if there is no last checkpoint to restore, start the finetune from pretrained model + # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. 
+            if os.path.exists(args.finetune_from_model):
+                checkpoint_path = args.finetune_from_model
+                reset_optimizer = True
+                reset_lr_scheduler = True
+                reset_meters = True
+                reset_dataloader = True
+                logger.info(
+                    f"loading pretrained model from {checkpoint_path}: "
+                    "optimizer, lr scheduler, meters, dataloader will be reset"
+                )
+            else:
+                raise ValueError(
+                    f"--finetune-from-model {args.finetune_from_model} does not exist"
+                )
+    elif suffix is not None:
+        # keep the MindSpore ".ckpt" suffix when appending the rank suffix
+        checkpoint_path = args.restore_file.replace(".ckpt", suffix + ".ckpt")
+    else:
+        checkpoint_path = args.restore_file
+
+    if args.restore_file != "checkpoint_last.ckpt" and args.finetune_from_model:  # default adapted to .ckpt
+        raise ValueError(
+            "--finetune-from-model and --restore-file (non-default value) "
+            "can not be specified together: " + str(args)
+        )
+
+    extra_state, epoch_itr = trainer.load_checkpoint(
+        checkpoint_path,
+        reset_optimizer,
+        reset_lr_scheduler,
+        reset_dataloader,
+        optimizer_overrides,
+        reset_meters=reset_meters,
+        **passthrough_args,
+    )
+
+    if (
+        extra_state is not None
+        and "best" in extra_state
+        and not reset_optimizer
+        and not reset_meters
+    ):
+        save_checkpoint.best = extra_state["best"]
+
+    return extra_state, epoch_itr
+
+
+def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=True):
+    """Loads a checkpoint to CPU (with upgrading for backward compatibility).
+    There's currently no support for > 1 but < all processes loading the
+    checkpoint on each node.
+    """
+    local_path = path
+    # MindSpore has no direct torch.load equivalent; load_checkpoint reads the
+    # parameter dict from a .ckpt file on the host (CPU) side. Non-parameter
+    # state such as "args" is only present if it was saved into the checkpoint.
+    state = ms.load_checkpoint(local_path)
+
+    if "args" in state and state["args"] is not None and arg_overrides is not None:
+        args = state["args"]
+        for arg_name, arg_val in arg_overrides.items():
+            setattr(args, arg_name, arg_val)
+
+    return state
+
+
+def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.ckpt"):  # MindSpore checkpoints use the .ckpt suffix
+    """Retrieves all checkpoints found in `path` directory.
+
+    Checkpoints are identified by matching filename to the specified pattern. If
+    the pattern contains groups, the result will be sorted by the first group in
+    descending order.
+ """ + pt_regexp = re.compile(pattern) + files = os.listdir(path) + + entries = [] + for i, f in enumerate(files): + m = pt_regexp.fullmatch(f) + if m is not None: + idx = float(m.group(1)) if len(m.groups()) > 0 else i + entries.append((idx, m.group(0))) + return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + + +def torch_persistent_save(obj, filename): + # 原子保存逻辑,替换为MindSpore格式 + temp_filename = filename + ".tmp" + _torch_persistent_save(obj, temp_filename) + os.rename(temp_filename, filename) + + +def _torch_persistent_save(obj, f): + if isinstance(f, str): + with open(f, "wb") as h: + torch_persistent_save(obj, h) + return + for i in range(3): + try: + # 替换 torch.save 为 mindspore.save + return ms.save(obj, f) + except Exception: + if i == 2: + logger.error(traceback.format_exc()) + + +def verify_checkpoint_directory(save_dir: str) -> None: + if not os.path.exists(save_dir): + os.makedirs(save_dir, exist_ok=True) + temp_file_path = os.path.join(save_dir, "dummy") + try: + with open(temp_file_path, "w"): + pass + except OSError as e: + logger.warning( + "Unable to access checkpoint save directory: {}".format(save_dir) + ) + raise e + else: + os.remove(temp_file_path) +# import ast +# import collections +# import logging +# import os +# import re +# import shutil +# import traceback +# from typing import Any, Dict, Optional + +# import torch + + +# logger = logging.getLogger(__name__) + + +# # async ckp copy +# def ckp_copy_fun(src, checkpoints, end_of_epoch, args): +# has_copy = False +# can_delete = args.tmp_save_dir != args.save_dir +# for cp in checkpoints: +# try: +# if src != cp: +# logger.info("copy {} to {}".format(src, cp)) +# has_copy = True +# shutil.copyfile(src, cp) +# except: +# logger.info("copy failed, please copy it manaully") + +# try: +# if can_delete and has_copy and os.path.lexists(src): +# logger.info("removing temp file {} ...".format(src)) +# os.remove(src) + +# def remove_ckps(root_path): +# if not end_of_epoch and args.keep_interval_updates > 0: +# # remove old checkpoints; checkpoints are sorted in descending order +# checkpoints = checkpoint_paths( +# root_path, pattern=r"checkpoint_\d+_(\d+)\.pt" +# ) +# for old_chk in checkpoints[args.keep_interval_updates :]: +# if os.path.lexists(old_chk): +# os.remove(old_chk) +# logger.info("removed {}".format(old_chk)) + +# if args.keep_last_epochs >= 0: +# # remove old epoch checkpoints; checkpoints are sorted in descending order +# checkpoints = checkpoint_paths( +# root_path, pattern=r"checkpoint(\d+)\.pt" +# ) +# for old_chk in checkpoints[args.keep_last_epochs :]: +# if os.path.lexists(old_chk): +# os.remove(old_chk) +# logger.info("removed {}".format(old_chk)) + +# if args.keep_best_checkpoints > 0: +# # only keep the best N checkpoints according to validation metric +# checkpoints = checkpoint_paths( +# root_path, +# pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format( +# args.best_checkpoint_metric +# ), +# ) +# if not args.maximize_best_checkpoint_metric: +# checkpoints = checkpoints[::-1] +# for old_chk in checkpoints[args.keep_best_checkpoints :]: +# if os.path.lexists(old_chk): +# os.remove(old_chk) +# logger.info("removed {}".format(old_chk)) + +# remove_ckps(args.save_dir) +# except: +# logger.info("remove old ckps error") + +# logger.info("finished async ckp saving.") + + +# def save_checkpoint(args, trainer, epoch_itr, val_loss, ckp_copy_thread, do_save=True): +# from unicore import meters + +# # only one worker should attempt to create the required dir +# if 
trainer.data_parallel_rank == 0: +# os.makedirs(args.save_dir, exist_ok=True) + +# prev_best = getattr(save_checkpoint, "best", val_loss) +# if val_loss is not None: +# best_function = max if args.maximize_best_checkpoint_metric else min +# save_checkpoint.best = best_function(val_loss, prev_best) + +# if args.no_save or not do_save: +# return + +# if not trainer.should_save_checkpoint_on_current_rank: +# return + +# write_timer = meters.StopwatchMeter() +# write_timer.start() + +# epoch = epoch_itr.epoch +# end_of_epoch = epoch_itr.end_of_epoch() +# updates = trainer.get_num_updates() + +# logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates") + +# def is_better(a, b): +# return a >= b if args.maximize_best_checkpoint_metric else a <= b + +# suffix = trainer.checkpoint_suffix +# checkpoint_conds = collections.OrderedDict() +# checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = ( +# end_of_epoch +# and not args.no_epoch_checkpoints +# and epoch % args.save_interval == 0 +# ) +# checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = ( +# not end_of_epoch +# and args.save_interval_updates > 0 +# and updates % args.save_interval_updates == 0 +# ) +# checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and ( +# not hasattr(save_checkpoint, "best") +# or is_better(val_loss, save_checkpoint.best) +# ) +# if val_loss is not None and args.keep_best_checkpoints > 0: +# checkpoint_conds[ +# "checkpoint.best_{}_{:.2f}.pt".format(args.best_checkpoint_metric, val_loss) +# ] = not hasattr(save_checkpoint, "best") or is_better( +# val_loss, save_checkpoint.best +# ) +# checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = ( +# not args.no_last_checkpoints +# ) + +# extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss} +# if hasattr(save_checkpoint, "best"): +# extra_state.update({"best": save_checkpoint.best}) + +# checkpoints = [ +# os.path.join(args.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond +# ] +# tmp_checkpoints = [ +# os.path.join(args.tmp_save_dir, fn) +# for fn, cond in checkpoint_conds.items() +# if cond +# ] +# if len(checkpoints) > 0: +# trainer.save_checkpoint(tmp_checkpoints[0], extra_state) +# if ckp_copy_thread is not None: +# ckp_copy_thread.apply_async( +# ckp_copy_fun, (tmp_checkpoints[0], checkpoints, end_of_epoch, args) +# ) +# write_timer.stop() +# logger.info( +# "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format( +# tmp_checkpoints[0], epoch, updates, val_loss, write_timer.sum +# ) +# ) + + +# def load_checkpoint(args, trainer, **passthrough_args): +# """ +# Load a checkpoint and restore the training iterator. + +# *passthrough_args* will be passed through to +# ``trainer.get_train_iterator``. 
+# """ + +# reset_optimizer = args.reset_optimizer +# reset_lr_scheduler = args.reset_lr_scheduler +# optimizer_overrides = ast.literal_eval(args.optimizer_overrides) +# reset_meters = args.reset_meters +# reset_dataloader = args.reset_dataloader + +# if args.finetune_from_model is not None and ( +# reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader +# ): +# raise ValueError( +# "--finetune-from-model can not be set together with either --reset-optimizer" +# " or reset_lr_scheduler or reset_meters or reset_dataloader" +# ) + +# suffix = trainer.checkpoint_suffix +# if ( +# args.restore_file == "checkpoint_last.pt" +# ): # default value of restore_file is 'checkpoint_last.pt' +# checkpoint_path = os.path.join( +# args.save_dir, "checkpoint_last{}.pt".format(suffix) +# ) +# first_launch = not os.path.exists(checkpoint_path) +# if args.finetune_from_model is not None and first_launch: +# # if there is no last checkpoint to restore, start the finetune from pretrained model +# # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc. +# if os.path.exists(args.finetune_from_model): +# checkpoint_path = args.finetune_from_model +# reset_optimizer = True +# reset_lr_scheduler = True +# reset_meters = True +# reset_dataloader = True +# logger.info( +# f"loading pretrained model from {checkpoint_path}: " +# "optimizer, lr scheduler, meters, dataloader will be reset" +# ) +# else: +# raise ValueError( +# f"--funetune-from-model {args.finetune_from_model} does not exist" +# ) +# elif suffix is not None: +# checkpoint_path = args.restore_file.replace(".pt", suffix + ".pt") +# else: +# checkpoint_path = args.restore_file + +# if args.restore_file != "checkpoint_last.pt" and args.finetune_from_model: +# raise ValueError( +# "--finetune-from-model and --restore-file (non-default value) " +# "can not be specified together: " + str(args) +# ) + +# extra_state, epoch_itr = trainer.load_checkpoint( +# checkpoint_path, +# reset_optimizer, +# reset_lr_scheduler, +# reset_dataloader, +# optimizer_overrides, +# reset_meters=reset_meters, +# **passthrough_args, +# ) + +# if ( +# extra_state is not None +# and "best" in extra_state +# and not reset_optimizer +# and not reset_meters +# ): +# save_checkpoint.best = extra_state["best"] + +# return extra_state, epoch_itr + + +# def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=True): +# """Loads a checkpoint to CPU (with upgrading for backward compatibility). +# There's currently no support for > 1 but < all processes loading the +# checkpoint on each node. +# """ +# local_path = path +# with open(local_path, "rb") as f: +# state = torch.load(f, map_location=torch.device("cpu"), weights_only=False) + +# if "args" in state and state["args"] is not None and arg_overrides is not None: +# args = state["args"] +# for arg_name, arg_val in arg_overrides.items(): +# setattr(args, arg_name, arg_val) + +# return state + + +# def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt"): +# """Retrieves all checkpoints found in `path` directory. + +# Checkpoints are identified by matching filename to the specified pattern. If +# the pattern contains groups, the result will be sorted by the first group in +# descending order. 
+# """ +# pt_regexp = re.compile(pattern) +# files = os.listdir(path) + +# entries = [] +# for i, f in enumerate(files): +# m = pt_regexp.fullmatch(f) +# if m is not None: +# idx = float(m.group(1)) if len(m.groups()) > 0 else i +# entries.append((idx, m.group(0))) +# return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)] + + +# def torch_persistent_save(obj, filename): +# # do atomic save +# with open(filename + ".tmp", "wb") as f: +# _torch_persistent_save(obj, f) +# os.rename(filename + ".tmp", filename) + + +# def _torch_persistent_save(obj, f): +# if isinstance(f, str): +# with open(f, "wb") as h: +# torch_persistent_save(obj, h) +# return +# for i in range(3): +# try: +# return torch.save(obj, f) +# except Exception: +# if i == 2: +# logger.error(traceback.format_exc()) + + +# def verify_checkpoint_directory(save_dir: str) -> None: +# if not os.path.exists(save_dir): +# os.makedirs(save_dir, exist_ok=True) +# temp_file_path = os.path.join(save_dir, "dummy") +# try: +# with open(temp_file_path, "w"): +# pass +# except OSError as e: +# logger.warning( +# "Unable to access checkpoint save directory: {}".format(save_dir) +# ) +# raise e +# else: +# os.remove(temp_file_path) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c297d7bb8a79e82044a63a357adabf3a1f2186f0 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/__init__.py @@ -0,0 +1,36 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + + +from .unicore_dataset import UnicoreDataset + +from .base_wrapper_dataset import BaseWrapperDataset + +from .append_token_dataset import AppendTokenDataset +from .dictionary import Dictionary +from .lru_cache_dataset import LRUCacheDataset +from .mask_tokens_dataset import MaskTokensDataset +from .bert_tokenize_dataset import BertTokenizeDataset +from .tokenize_dataset import TokenizeDataset +from .nested_dictionary_dataset import NestedDictionaryDataset +from .numel_dataset import NumelDataset +from .num_samples_dataset import NumSamplesDataset +from .pad_dataset import LeftPadDataset, PadDataset, RightPadDataset, RightPadDataset2D +from .prepend_token_dataset import PrependTokenDataset +from .raw_dataset import RawLabelDataset, RawArrayDataset, RawNumpyDataset +from .lmdb_dataset import LMDBDataset +from .sort_dataset import SortDataset, EpochShuffleDataset +from .from_numpy_dataset import FromNumpyDataset + +from .iterators import ( + CountingIterator, + EpochBatchIterator, + GroupedIterator, + ShardedIterator, +) + +__all__ = [] diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/append_token_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/append_token_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..74e929b1dadc407e71d057faf3976ad9c7b686ad --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/append_token_dataset.py @@ -0,0 +1,45 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +# import numpy as np +# import torch +# from functools import lru_cache + +# from . 
import BaseWrapperDataset + + +# class AppendTokenDataset(BaseWrapperDataset): +# def __init__(self, dataset, token=None): +# super().__init__(dataset) +# self.token = token + +# @lru_cache(maxsize=16) +# def __getitem__(self, idx): +# item = self.dataset[idx] +# if self.token is not None: +# item = torch.cat([item, torch.full_like(item[0], self.token).unsqueeze(0)], dim=0) +# return item +import numpy as np +import mindspore.mint as ms # 替换torch为mindspore.mint +from functools import lru_cache + +from . import BaseWrapperDataset + + +class AppendTokenDataset(BaseWrapperDataset): + def __init__(self, dataset, token=None): + super().__init__(dataset) + self.token = token + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + item = self.dataset[idx] + if self.token is not None: + # 替换torch.cat为mindspore.mint.cat,功能一致 + # 替换torch.full_like为mindspore.mint.full_like,功能一致 + # 替换unsqueeze(0)为mindspore.mint.unsqueeze,功能一致 + item = ms.cat([item, ms.full_like(item[0], self.token).unsqueeze(0)], dim=0) + return item diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/base_wrapper_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/base_wrapper_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..f414afce3f991b0354947e4c60519b3c6edf4d34 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/base_wrapper_dataset.py @@ -0,0 +1,118 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from mindspore.dataset import GeneratorDataset # 替换PyTorch DataLoader相关依赖 +from mindspore.dataset.transforms import Compose # MindSpore中用于批处理的工具类 + +from . import UnicoreDataset + + +class BaseWrapperDataset(UnicoreDataset): + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if hasattr(self.dataset, "collater"): + return self.dataset.collater(samples) + else: + # MindSpore中无直接对应的default_collate,但可通过Compose和默认批处理逻辑实现类似功能 + # 这里模拟PyTorch的default_collate,将样本列表转换为批处理格式 + return Compose()(samples) # 利用MindSpore的Compose实现默认拼接 + + def ordered_indices(self): + return self.dataset.ordered_indices() + + @property + def supports_prefetch(self): + return getattr(self.dataset, "supports_prefetch", False) + + def attr(self, attr: str, index: int): + return self.dataset.attr(attr, index) + + def prefetch(self, indices): + self.dataset.prefetch(indices) + + def batch_by_size( + self, + indices, + batch_size=None, + required_batch_size_multiple=1, + ): + return self.dataset.batch_by_size( + indices, + batch_size=batch_size, + required_batch_size_multiple=required_batch_size_multiple, + ) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return self.dataset.can_reuse_epoch_itr_across_epochs + + def set_epoch(self, epoch): + super().set_epoch(epoch) + if hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(epoch) +# from torch.utils.data.dataloader import default_collate + +# from . 
import UnicoreDataset + + +# class BaseWrapperDataset(UnicoreDataset): +# def __init__(self, dataset): +# super().__init__() +# self.dataset = dataset + +# def __getitem__(self, index): +# return self.dataset[index] + +# def __len__(self): +# return len(self.dataset) + +# def collater(self, samples): +# if hasattr(self.dataset, "collater"): +# return self.dataset.collater(samples) +# else: +# return default_collate(samples) + +# def ordered_indices(self): +# return self.dataset.ordered_indices() + +# @property +# def supports_prefetch(self): +# return getattr(self.dataset, "supports_prefetch", False) + +# def attr(self, attr: str, index: int): +# return self.dataset.attr(attr, index) + +# def prefetch(self, indices): +# self.dataset.prefetch(indices) + +# def batch_by_size( +# self, +# indices, +# batch_size=None, +# required_batch_size_multiple=1, +# ): +# return self.dataset.batch_by_size( +# indices, +# batch_size=batch_size, +# required_batch_size_multiple=required_batch_size_multiple, +# ) + +# @property +# def can_reuse_epoch_itr_across_epochs(self): +# return self.dataset.can_reuse_epoch_itr_across_epochs + +# def set_epoch(self, epoch): +# super().set_epoch(epoch) +# if hasattr(self.dataset, "set_epoch"): +# self.dataset.set_epoch(epoch) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/bert_tokenize_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/bert_tokenize_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a7b3e32d6f616aeebc7a878b674bafb56bf870a1 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/bert_tokenize_dataset.py @@ -0,0 +1,69 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from functools import lru_cache + +import numpy as np +import mindspore as ms # 替换torch为mindspore +from tokenizers import BertWordPieceTokenizer + +from . import BaseWrapperDataset, LRUCacheDataset + + +class BertTokenizeDataset(BaseWrapperDataset): + def __init__( + self, + dataset: ms.dataset.Dataset, # 替换torch.utils.data.Dataset为mindspore.dataset.Dataset + dict_path: str, + max_seq_len: int = 512, + ): + self.dataset = dataset + self.tokenizer = BertWordPieceTokenizer(dict_path, lowercase=True) + self.max_seq_len = max_seq_len + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def __getitem__(self, index: int): + raw_str = self.dataset[index] + raw_str = raw_str.replace('', '[UNK]') + output = self.tokenizer.encode(raw_str) + # 替换torch.Tensor(...).long()为mindspore.Tensor,指定dtype为int64(对应long类型) + ret = ms.Tensor(output.ids, dtype=ms.int64) + # 替换torch.Tensor.size(0)为mindspore.Tensor.shape[0] + if ret.shape[0] > self.max_seq_len: + ret = ret[:self.max_seq_len] + return ret +# from functools import lru_cache + +# import numpy as np +# import torch +# from tokenizers import BertWordPieceTokenizer + +# from . 
import BaseWrapperDataset, LRUCacheDataset + + +# class BertTokenizeDataset(BaseWrapperDataset): +# def __init__( +# self, +# dataset: torch.utils.data.Dataset, +# dict_path: str, +# max_seq_len: int=512, +# ): +# self.dataset = dataset +# self.tokenizer = BertWordPieceTokenizer(dict_path, lowercase=True) +# self.max_seq_len = max_seq_len + +# @property +# def can_reuse_epoch_itr_across_epochs(self): +# return True # only the noise changes, not item sizes + +# def __getitem__(self, index: int): +# raw_str = self.dataset[index] +# raw_str = raw_str.replace('', '[UNK]') +# output = self.tokenizer.encode(raw_str) +# ret = torch.Tensor(output.ids).long() +# if ret.size(0) > self.max_seq_len: +# ret = ret[:self.max_seq_len] +# return ret \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/data_utils.py b/MindChemistry/applications/Uni-Mol/unicore/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c9655b12fe424cbfd37cbb11c29e6d7b0377d693 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/data_utils.py @@ -0,0 +1,289 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import contextlib +import logging + +import numpy as np +import mindspore as ms # 替换torch为mindspore +from mindspore import ops # 导入MindSpore算子库 + + +logger = logging.getLogger(__name__) + + +def collate_tokens( + values, + pad_idx, + left_pad=False, + pad_to_length=None, + pad_to_multiple=1, +): + """Convert a list of 1d tensors into a padded 2d tensor.""" + # 替换v.size(0)为v.shape[0](获取张量第一维长度) + size = max(v.shape[0] for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + + # 替换torch.Tensor.new()创建同类型张量:MindSpore通过指定dtype实现 + # 先获取原张量数据类型,再创建填充pad_idx的新张量 + dtype = values[0].dtype + res = ms.Tensor(np.full((len(values), size), pad_idx, dtype=np.int64), dtype=dtype) + + def copy_tensor(src, dst): + assert dst.size == src.size # MindSpore中用size属性(总元素数)替代numel() + dst.assign(src) # 替换torch.copy_()为MindSpore的assign()方法 + + for i, v in enumerate(values): + # 调整切片逻辑,保持与原逻辑一致 + if left_pad: + copy_tensor(v, res[i, size - len(v):]) # MindSpore用逗号分隔维度 + else: + copy_tensor(v, res[i, :len(v)]) + return res + + +def collate_tokens_2d( + values, + pad_idx, + left_pad=False, + pad_to_length=None, + pad_to_multiple=1, +): + """Convert a list of 2d tensors into a padded 3d tensor.""" + size = max(v.shape[0] for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + + # 创建三维填充张量 + dtype = values[0].dtype + res = ms.Tensor(np.full((len(values), size, size), pad_idx, dtype=np.int64), dtype=dtype) + + def copy_tensor(src, dst): + assert dst.size == src.size + dst.assign(src) + + for i, v in enumerate(values): + if left_pad: + # 2D张量的切片逻辑 + copy_tensor(v, res[i, size - len(v):, size - len(v):]) + else: + copy_tensor(v, res[i, :len(v), :len(v)]) + return res + + +def collate_dict( + values, + dim=0, +): + if len(values) <= 0: + return values + ret = {} + keys = values[0].keys() + for key in keys: + # 替换torch.stack为mindspore.stack(映射表中功能一致) + ret[key] = ops.stack([v[key] for v in values], 
axis=dim) # MindSpore用axis参数 + return ret + + +def str_hash(text: str): + hash_val = 0 + for ch in text: + hash_val = (hash_val * 281 ^ ord(ch) * 997) & 0xFFFFFFFF + return hash_val + + +@contextlib.contextmanager +def numpy_seed(seed, *addl_seeds, key=None): + """Context manager which seeds the NumPy PRNG with the specified seed and + restores the state afterward""" + if seed is None: + yield + return + def check_seed(s): + assert type(s) == int or type(s) == np.int32 or type(s) == np.int64 + check_seed(seed) + if len(addl_seeds) > 0: + for s in addl_seeds: + check_seed(s) + seed = int(hash((seed, *addl_seeds)) % 1e8) + if key is not None: + seed = int(hash((seed, str_hash(key))) % 1e8) + state = np.random.get_state() + np.random.seed(seed) + try: + yield + finally: + np.random.set_state(state) + + +def batch_by_size( + indices, + batch_size=None, + required_batch_size_multiple=1, +): + """ + Yield mini-batches of indices bucketed by size. Batches may contain + sequences of different lengths. + + Args: + indices (List[int]): ordered list of dataset indices + batch_size (int, optional): max number of sentences in each + batch (default: None). + required_batch_size_multiple (int, optional): require batch size to + be less than N or a multiple of N (default: 1). + """ + + batch_size = batch_size if batch_size is not None else 1 + bsz_mult = required_batch_size_multiple + + step = ((batch_size + bsz_mult - 1) // bsz_mult) * bsz_mult + + if not isinstance(indices, np.ndarray): + indices = np.fromiter(indices, dtype=np.int64, count=-1) + + num_batches = (len(indices) + step - 1) // step + steps = np.arange(num_batches - 1) + 1 + steps *= step + batch_indices = np.split(indices, steps) + assert len(batch_indices) == num_batches + # validation or test data size is smaller than a mini-batch size in some downstream tasks. 
+ assert batch_indices[0].shape[0] <= step + return batch_indices +# import contextlib +# import logging + +# import numpy as np +# import torch + + +# logger = logging.getLogger(__name__) + + +# def collate_tokens( +# values, +# pad_idx, +# left_pad=False, +# pad_to_length=None, +# pad_to_multiple=1, +# ): +# """Convert a list of 1d tensors into a padded 2d tensor.""" +# size = max(v.size(0) for v in values) +# size = size if pad_to_length is None else max(size, pad_to_length) +# if pad_to_multiple != 1 and size % pad_to_multiple != 0: +# size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) +# res = values[0].new(len(values), size).fill_(pad_idx) + +# def copy_tensor(src, dst): +# assert dst.numel() == src.numel() +# dst.copy_(src) + +# for i, v in enumerate(values): +# copy_tensor(v, res[i][size - len(v) :] if left_pad else res[i][: len(v)]) +# return res + + +# def collate_tokens_2d( +# values, +# pad_idx, +# left_pad=False, +# pad_to_length=None, +# pad_to_multiple=1, +# ): +# """Convert a list of 1d tensors into a padded 2d tensor.""" +# size = max(v.size(0) for v in values) +# size = size if pad_to_length is None else max(size, pad_to_length) +# if pad_to_multiple != 1 and size % pad_to_multiple != 0: +# size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) +# res = values[0].new(len(values), size, size).fill_(pad_idx) + +# def copy_tensor(src, dst): +# assert dst.numel() == src.numel() +# dst.copy_(src) + +# for i, v in enumerate(values): +# copy_tensor(v, res[i][size - len(v):, size - len(v):] if left_pad else res[i][:len(v), :len(v)]) +# return res + + +# def collate_dict( +# values, +# dim=0, +# ): +# if len(values) <= 0: +# return values +# ret = {} +# keys = values[0].keys() +# for key in keys: +# ret[key] = torch.stack([v[key] for v in values], dim=dim) +# return ret + + +# def str_hash(text:str): +# hash=0 +# for ch in text: +# hash = ( hash*281 ^ ord(ch)*997) & 0xFFFFFFFF +# return hash + + +# @contextlib.contextmanager +# def numpy_seed(seed, *addl_seeds, key=None): +# """Context manager which seeds the NumPy PRNG with the specified seed and +# restores the state afterward""" +# if seed is None: +# yield +# return +# def check_seed(s): +# assert type(s) == int or type(s) == np.int32 or type(s) == np.int64 +# check_seed(seed) +# if len(addl_seeds) > 0: +# for s in addl_seeds: +# check_seed(s) +# seed = int(hash((seed, *addl_seeds)) % 1e8) +# if key is not None: +# seed = int(hash((seed, str_hash(key))) % 1e8) +# state = np.random.get_state() +# np.random.seed(seed) +# try: +# yield +# finally: +# np.random.set_state(state) + + +# def batch_by_size( +# indices, +# batch_size=None, +# required_batch_size_multiple=1, +# ): +# """ +# Yield mini-batches of indices bucketed by size. Batches may contain +# sequences of different lengths. + +# Args: +# indices (List[int]): ordered list of dataset indices +# batch_size (int, optional): max number of sentences in each +# batch (default: None). +# required_batch_size_multiple (int, optional): require batch size to +# be less than N or a multiple of N (default: 1). 
+# """ + +# batch_size = batch_size if batch_size is not None else 1 +# bsz_mult = required_batch_size_multiple + +# step = ((batch_size + bsz_mult - 1) // bsz_mult) * bsz_mult + +# if not isinstance(indices, np.ndarray): +# indices = np.fromiter(indices, dtype=np.int64, count=-1) + +# num_batches = (len(indices) + step - 1) // step +# steps = np.arange(num_batches - 1) + 1 +# steps *= step +# batch_indices = np.split(indices, steps) +# assert len(batch_indices) == num_batches +# # validation or test data size is smaller than a mini-batch size in some downstream tasks. +# assert batch_indices[0].shape[0] <= step +# return batch_indices diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/dictionary.py b/MindChemistry/applications/Uni-Mol/unicore/data/dictionary.py new file mode 100644 index 0000000000000000000000000000000000000000..db2b2c801db9431ccc3af5ddbfec473728c39c87 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/dictionary.py @@ -0,0 +1,148 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging + +import numpy as np + +logger = logging.getLogger(__name__) # pylint: disable=invalid-name + +class Dictionary: + """A mapping from symbols to consecutive integers""" + + def __init__( + self, + *, # begin keyword-only arguments + bos="[CLS]", + pad="[PAD]", + eos="[SEP]", + unk="[UNK]", + extra_special_symbols=None, + ): + self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos + self.symbols = [] + self.count = [] + self.indices = {} + self.specials = set() + self.specials.add(bos) + self.specials.add(unk) + self.specials.add(pad) + self.specials.add(eos) + + def __eq__(self, other): + return self.indices == other.indices + + def __getitem__(self, idx): + if idx < len(self.symbols): + return self.symbols[idx] + return self.unk_word + + def __len__(self): + """Returns the number of symbols in the dictionary""" + return len(self.symbols) + + def __contains__(self, sym): + return sym in self.indices + + def vec_index(self, a): + return np.vectorize(self.index)(a) + + def index(self, sym): + """Returns the index of the specified symbol""" + assert isinstance(sym, str) + if sym in self.indices: + return self.indices[sym] + return self.indices[self.unk_word] + + def special_index(self): + return [self.index(x) for x in self.specials] + + def add_symbol(self, word, n=1, overwrite=False, is_special=False): + """Adds a word to the dictionary""" + if is_special: + self.specials.add(word) + if word in self.indices and not overwrite: + idx = self.indices[word] + self.count[idx] = self.count[idx] + n + return idx + else: + idx = len(self.symbols) + self.indices[word] = idx + self.symbols.append(word) + self.count.append(n) + return idx + + def bos(self): + """Helper to get index of beginning-of-sentence symbol""" + return self.index(self.bos_word) + + def pad(self): + """Helper to get index of pad symbol""" + return self.index(self.pad_word) + + def eos(self): + """Helper to get index of end-of-sentence symbol""" + return self.index(self.eos_word) + + def unk(self): + """Helper to get index of unk symbol""" + return self.index(self.unk_word) + + @classmethod + def load(cls, f): + """Loads the dictionary from a text file with the format: + + ``` + + + ... 
+ ``` + """ + d = cls() + d.add_from_file(f) + return d + + def add_from_file(self, f): + """ + Loads a pre-existing dictionary from a text file and adds its symbols + to this instance. + """ + if isinstance(f, str): + try: + with open(f, "r", encoding="utf-8") as fd: + self.add_from_file(fd) + except FileNotFoundError as fnfe: + raise fnfe + except UnicodeError: + raise Exception( + "Incorrect encoding detected in {}, please " + "rebuild the dataset".format(f) + ) + return + + lines = f.readlines() + + for line_idx, line in enumerate(lines): + try: + splits = line.rstrip().rsplit(" ", 1) + line = splits[0] + field = splits[1] if len(splits) > 1 else str(len(lines) - line_idx) + if field == "#overwrite": + overwrite = True + line, field = line.rsplit(" ", 1) + else: + overwrite = False + count = int(field) + word = line + if word in self and not overwrite: + logger.info( + "Duplicate word found when loading Dictionary: '{}', index is {}.".format(word, self.indices[word]) + ) + else: + self.add_symbol(word, n=count, overwrite=overwrite) + except ValueError: + raise ValueError( + "Incorrect dictionary format, expected ' [flags]'" + ) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/from_numpy_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/from_numpy_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a99e634965e18d7b9153b612792e3d86f962b47e --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/from_numpy_dataset.py @@ -0,0 +1,32 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import mindspore as ms # 替换torch为mindspore +from functools import lru_cache + +from . import BaseWrapperDataset + + +class FromNumpyDataset(BaseWrapperDataset): + def __init__(self, dataset): + super().__init__(dataset) + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + # 替换torch.from_numpy为mindspore.Tensor,功能一致(将numpy数组转换为张量) + return ms.Tensor(self.dataset[idx]) +# import torch +# from functools import lru_cache + +# from . import BaseWrapperDataset + + +# class FromNumpyDataset(BaseWrapperDataset): +# def __init__(self, dataset): +# super().__init__(dataset) + +# @lru_cache(maxsize=16) +# def __getitem__(self, idx): +# return torch.from_numpy(self.dataset[idx]) + + diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/iterators.py b/MindChemistry/applications/Uni-Mol/unicore/data/iterators.py new file mode 100644 index 0000000000000000000000000000000000000000..90e898254e39e028c30bbbb97b2c5e000ddbe825 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/iterators.py @@ -0,0 +1,1005 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import itertools +import logging +import math +import operator +import os +import queue +import time +from threading import Thread + +import numpy as np +import mindspore.dataset as ds # 替换torch.utils.data相关导入 +from mindspore.dataset import Dataset # MindSpore数据集基类 +from unicore.data import data_utils + + +logger = logging.getLogger(__name__) + +# Object used by _background_consumer to signal the source is exhausted +# to the main thread. 
+_sentinel = object() + + +class CountingIterator(object): + """Wrapper around an iterable that maintains the iteration count.""" + + def __init__(self, iterable, start=None, total=None): + self.iterable = iterable + self.itr = iter(self) + + if start is None: + self.n = getattr(iterable, "n", 0) + else: + self.n = start + + if total is None: + self.total = self.n + len(iterable) + else: + self.total = total + + def __len__(self): + return self.total + + def __iter__(self): + for x in self.iterable: + if self.n >= self.total: + raise RuntimeError( + "Mismatch between actual and expected iterable length. " + "This may be caused by resuming training from a checkpoint using " + "a different number of GPUs, in which case you can try the " + "--reset-dataloader option. Alternatively you may have a train or " + "validation set that is smaller than the number of GPUs. If none " + "of these apply, please report this to the unicore developers." + ) + self.n += 1 + yield x + + def __next__(self): + return next(self.itr) + + def has_next(self): + """Whether the iterator has been exhausted.""" + return self.n < len(self) + + def skip(self, num_to_skip): + """Fast-forward the iterator by skipping *num_to_skip* elements.""" + next(itertools.islice(self.itr, num_to_skip, num_to_skip), None) + return self + + def take(self, n): + """Truncates the iterator to n elements at most.""" + self.total = min(self.total, n) + propagated_take = max(n - self.n, 0) + if hasattr(self.iterable, "take"): + self.iterable.take(propagated_take) + else: + self.iterable = itertools.islice(self.iterable, propagated_take) + + +class EpochBatchIterating(object): + def __len__(self) -> int: + raise NotImplementedError + + @property + def next_epoch_idx(self): + raise NotImplementedError + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + raise NotImplementedError + + def end_of_epoch(self) -> bool: + raise NotImplementedError + + @property + def iterations_in_epoch(self) -> int: + raise NotImplementedError + + def state_dict(self): + raise NotImplementedError + + def load_state_dict(self, state_dict): + raise NotImplementedError + + @property + def first_batch(self): + return "DUMMY" + + +class EpochBatchIterator(EpochBatchIterating): + """A multi-epoch iterator over a MindSpore Dataset.""" + + def __init__( + self, + dataset, + collate_fn, + batch_sampler, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + buffer_size=0, + timeout=0, + disable_shuffling=False, + ): + assert isinstance(dataset, Dataset) # 替换为MindSpore Dataset检查 + self.dataset = dataset + self.collate_fn = collate_fn + self.batch_sampler = batch_sampler + self._frozen_batches = ( + tuple(batch_sampler) if not callable(batch_sampler) else None + ) + self.seed = seed + self.num_shards = num_shards + self.shard_id = shard_id + self.num_workers = num_workers # 对应MindSpore的num_parallel_workers + self.buffer_size = min(buffer_size, 32) + self.timeout = timeout + self.disable_shuffling = disable_shuffling + + self.epoch = max(epoch, 1) + self.shuffle = not disable_shuffling + self._cur_epoch_itr = None + self._next_epoch_itr = None + self._supports_prefetch = getattr(dataset, "supports_prefetch", False) + + @property + def frozen_batches(self): + if self._frozen_batches is None: + self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) + return self._frozen_batches + + @property + def first_batch(self): + if len(self.frozen_batches) == 0: + raise Exception( + "The dataset is empty. 
This could indicate " + "that all elements in the dataset have been skipped. " + "Try increasing the max number of allowed tokens or using " + "a larger dataset." + ) + + if getattr(self.dataset, "supports_fetch_outside_dataloader", True): + return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) + else: + return "DUMMY" + + def __len__(self): + return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) + + @property + def n(self): + return self.iterations_in_epoch + + @property + def next_epoch_idx(self): + if self._next_epoch_itr is not None: + return self.epoch + elif self._cur_epoch_itr is not None and self.end_of_epoch(): + return self.epoch + 1 + else: + return self.epoch + + def next_epoch_itr( + self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True + ): + if self.disable_shuffling: + shuffle = False + self.epoch = self.next_epoch_idx + if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): + self.dataset.set_epoch(self.epoch) + if self._next_epoch_itr is not None: + self._cur_epoch_itr = self._next_epoch_itr + self._next_epoch_itr = None + else: + if callable(self.batch_sampler): + self._frozen_batches = None + self._cur_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle, + fix_batches_to_gpus=fix_batches_to_gpus, + ) + self.shuffle = shuffle + return self._cur_epoch_itr + + def end_of_epoch(self) -> bool: + return not self._cur_epoch_itr.has_next() + + @property + def iterations_in_epoch(self): + if self._cur_epoch_itr is not None: + return self._cur_epoch_itr.n + elif self._next_epoch_itr is not None: + return self._next_epoch_itr.n + return 0 + + def state_dict(self): + if self.end_of_epoch(): + epoch = self.epoch + 1 + iter_in_epoch = 0 + else: + epoch = self.epoch + iter_in_epoch = self.iterations_in_epoch + return { + "epoch": epoch, + "iterations_in_epoch": iter_in_epoch, + "shuffle": self.shuffle, + "len": len(self), + } + + def load_state_dict(self, state_dict): + self.epoch = state_dict["epoch"] + itr_pos = state_dict.get("iterations_in_epoch", 0) + if itr_pos > 0: + if "len" in state_dict and state_dict["len"] != len(self): + old_itr_pos = itr_pos + itr_pos = int(itr_pos * len(self) / state_dict["len"]) + logger.info( + "Iterator size is changed. it is possible due to the change of update_freq/num_gpu. The itr_pos is change from {} to {} for consistency".format(old_itr_pos, itr_pos) + ) + + self._next_epoch_itr = self._get_iterator_for_epoch( + self.epoch, + shuffle=state_dict.get("shuffle", True), + offset=itr_pos, + ) + if self._next_epoch_itr is None: + raise RuntimeError( + "Cannot resume training due to dataloader mismatch. You can relaunch " + "training with `--reset-dataloader` and it should work." 
+ ) + else: + self._next_epoch_itr = None + + def _get_iterator_for_epoch( + self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 + ): + def shuffle_batches(batches, seed): + with data_utils.numpy_seed(seed): + np.random.shuffle(batches) + return batches + + if self._supports_prefetch: + batches = self.frozen_batches + + if shuffle and not fix_batches_to_gpus: + batches = shuffle_batches(list(batches), self.seed + epoch) + + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + self.dataset.prefetch([i for s in batches for i in s]) + + if shuffle and fix_batches_to_gpus: + batches = shuffle_batches(batches, self.seed + epoch + self.shard_id) + else: + if shuffle: + batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch) + else: + batches = self.frozen_batches + batches = list( + ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) + ) + + if offset > 0 and offset >= len(batches): + return None + + if self.num_workers > 0: + os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + + # 替换torch.utils.data.DataLoader为MindSpore数据集处理链 + # 1. 创建采样器迭代器 + def batch_sampler_generator(): + for batch in batches[offset:]: + yield batch + + # 2. 从采样器创建数据集 + sampler_ds = ds.GeneratorDataset( + batch_sampler_generator, + column_names=["indices"], + num_parallel_workers=self.num_workers # 对应num_workers + ) + + # 3. 映射获取数据并应用collate_fn + def fetch_and_collate(indices): + samples = [self.dataset[int(i)] for i in indices] + return self.collate_fn(samples) + + data_loader = sampler_ds.map( + operations=fetch_and_collate, + input_columns=["indices"], + num_parallel_workers=self.num_workers + ) + + # 转换为迭代器 + itr = data_loader.create_dict_iterator(output_numpy=False, num_epochs=1) + + # 包装缓冲迭代器 + if self.buffer_size > 0: + itr = BufferedIterator(self.buffer_size, itr) + + # 包装计数迭代器 + itr = CountingIterator(itr, start=offset) + return itr + + +class GroupedIterator(CountingIterator): + """Wrapper around an iterable that returns groups (chunks) of items.""" + + def __init__(self, iterable, chunk_size): + itr = _chunk_iterator(iterable, chunk_size) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))), + total=int(math.ceil(len(iterable) / float(chunk_size))), + ) + self.chunk_size = chunk_size + + +def _chunk_iterator(itr, chunk_size): + chunk = [] + for x in itr: + chunk.append(x) + if len(chunk) == chunk_size: + yield chunk + chunk = [] + if len(chunk) > 0: + yield chunk + + +class ShardedIterator(CountingIterator): + """A sharded wrapper around an iterable, padded to length.""" + + def __init__(self, iterable, num_shards, shard_id, fill_value=None): + if shard_id < 0 or shard_id >= num_shards: + raise ValueError("shard_id must be between 0 and num_shards") + sharded_len = int(math.ceil(len(iterable) / float(num_shards))) + itr = map( + operator.itemgetter(1), + itertools.zip_longest( + range(sharded_len), + itertools.islice(iterable, shard_id, len(iterable), num_shards), + fillvalue=fill_value, + ), + ) + super().__init__( + itr, + start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))), + total=sharded_len, + ) + + +class BackgroundConsumer(Thread): + def __init__(self, queue, source, max_len): + Thread.__init__(self) + + self._queue = queue + self._source = source + self._max_len = max_len + self.count = 0 + + def run(self): + try: + for item in self._source: + # MindSpore的迭代器返回字典,提取数据 + data = item["flexible"] if "flexible" in item else item + 
self._queue.put(data) + + self.count += 1 + if self._max_len is not None and self.count >= self._max_len: + break + + self._queue.put(_sentinel) + except Exception as e: + self._queue.put(e) + + +class BufferedIterator(object): + def __init__(self, size, iterable): + self._queue = queue.Queue(size) + self._iterable = iterable + self._consumer = None + + self.start_time = time.time() + self.warning_time = None + + self.total = len(iterable) + + def _create_consumer(self): + self._consumer = BackgroundConsumer( + self._queue, + self._iterable, + self.total, + ) + self._consumer.daemon = True + self._consumer.start() + + def __iter__(self): + return self + + def __len__(self): + return self.total + + def take(self, n): + self.total = min(self.total, n) + if hasattr(self._iterable, "take"): + self._iterable.take(n) + + def __next__(self): + if self._consumer is None: + self._create_consumer() + + if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): + if time.time() - self.start_time > 5 * 60: + if ( + self.warning_time is None + or time.time() - self.warning_time > 15 * 60 + ): + logger.debug( + "Data loading buffer is empty or nearly empty. This may " + "indicate a data loading bottleneck, and increasing the " + "number of workers (--num-workers) may help." + ) + self.warning_time = time.time() + + item = self._queue.get(True) + if isinstance(item, Exception): + raise item + if item is _sentinel: + raise StopIteration() + return item +# import itertools +# import logging +# import math +# import operator +# import os +# import queue +# import time +# from threading import Thread + +# import numpy as np +# import torch +# from unicore.data import data_utils + + +# logger = logging.getLogger(__name__) + +# # Object used by _background_consumer to signal the source is exhausted +# # to the main thread. +# _sentinel = object() + + +# class CountingIterator(object): +# """Wrapper around an iterable that maintains the iteration count. + +# Args: +# iterable (iterable): iterable to wrap +# start (int): starting iteration count. Note that this doesn't +# actually advance the iterator. +# total (int): override the iterator length returned by +# ``__len__``. This can be used to truncate *iterator*. + +# Attributes: +# n (int): number of elements consumed from this iterator +# """ + +# def __init__(self, iterable, start=None, total=None): +# self.iterable = iterable +# self.itr = iter(self) + +# if start is None: +# self.n = getattr(iterable, "n", 0) +# else: +# self.n = start + +# if total is None: +# self.total = self.n + len(iterable) +# else: +# self.total = total + +# def __len__(self): +# return self.total + +# def __iter__(self): +# for x in self.iterable: +# if self.n >= self.total: +# raise RuntimeError( +# "Mismatch between actual and expected iterable length. " +# "This may be caused by resuming training from a checkpoint using " +# "a different number of GPUs, in which case you can try the " +# "--reset-dataloader option. Alternatively you may have a train or " +# "validation set that is smaller than the number of GPUs. If none " +# "of these apply, please report this to the unicore developers." 
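`BufferedIterator` / `BackgroundConsumer` above overlap data preparation with consumption through a bounded queue fed by a daemon thread. A self-contained miniature of the same producer/consumer pattern (the `buffered` helper and `slow_source` below are purely illustrative, not part of the codebase):

```python
import queue
import threading
import time

_SENTINEL = object()


def buffered(iterable, size=4):
    """Yield items from *iterable* while a daemon thread keeps a bounded queue full."""
    q = queue.Queue(maxsize=size)

    def producer():
        try:
            for item in iterable:
                q.put(item)          # blocks when the buffer is full
            q.put(_SENTINEL)         # signal normal exhaustion
        except Exception as exc:     # surface producer errors to the consumer
            q.put(exc)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = q.get()
        if item is _SENTINEL:
            return
        if isinstance(item, Exception):
            raise item
        yield item


def slow_source():
    for i in range(5):
        time.sleep(0.01)             # pretend each batch is expensive to build
        yield i


print(list(buffered(slow_source())))   # [0, 1, 2, 3, 4]
```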
+# ) +# self.n += 1 +# yield x + +# def __next__(self): +# return next(self.itr) + +# def has_next(self): +# """Whether the iterator has been exhausted.""" +# return self.n < len(self) + +# def skip(self, num_to_skip): +# """Fast-forward the iterator by skipping *num_to_skip* elements.""" +# next(itertools.islice(self.itr, num_to_skip, num_to_skip), None) +# return self + +# def take(self, n): +# """ +# Truncates the iterator to n elements at most. +# """ +# self.total = min(self.total, n) + +# # Propagate this change to the underlying iterator +# # Only take after what we have already consumed (i.e. after restarting +# # from checkpoint mid epoch, we have to subtract self.n which is the +# # starting point) +# # +# # This to maintain the invariant self.total = self.n + len(iterable), +# # before calling __next__ or __iter__ +# propagated_take = max(n - self.n, 0) +# if hasattr(self.iterable, "take"): +# self.iterable.take(propagated_take) +# else: +# self.iterable = itertools.islice(self.iterable, propagated_take) + + +# class EpochBatchIterating(object): +# def __len__(self) -> int: +# raise NotImplementedError + +# @property +# def next_epoch_idx(self): +# raise NotImplementedError + +# def next_epoch_itr( +# self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True +# ): +# """Return a new iterator over the dataset. + +# Args: +# shuffle (bool, optional): shuffle batches before returning the +# iterator (default: True). +# fix_batches_to_gpus (bool, optional): ensure that batches are always +# allocated to the same shards across epochs. Requires +# that :attr:`dataset` supports prefetching (default: False). +# set_dataset_epoch (bool, optional): update the wrapped Dataset with +# the new epoch number (default: True). +# """ +# raise NotImplementedError + +# def end_of_epoch(self) -> bool: +# """Returns whether the most recent epoch iterator has been exhausted""" +# raise NotImplementedError + +# @property +# def iterations_in_epoch(self) -> int: +# """The number of consumed batches in the current epoch.""" +# raise NotImplementedError + +# def state_dict(self): +# """Returns a dictionary containing a whole state of the iterator.""" +# raise NotImplementedError + +# def load_state_dict(self, state_dict): +# """Copies the state of the iterator from the given *state_dict*.""" +# raise NotImplementedError + +# @property +# def first_batch(self): +# return "DUMMY" + + +# class EpochBatchIterator(EpochBatchIterating): +# """A multi-epoch iterator over a :class:`torch.utils.data.Dataset`. + +# Compared to :class:`torch.utils.data.DataLoader`, this iterator: + +# - can be reused across multiple epochs with the :func:`next_epoch_itr` +# method (optionally shuffled between epochs) +# - can be serialized/deserialized with the :func:`state_dict` and +# :func:`load_state_dict` methods +# - supports sharding with the *num_shards* and *shard_id* arguments + +# Args: +# dataset (~torch.utils.data.Dataset): dataset from which to load the data +# collate_fn (callable): merges a list of samples to form a mini-batch +# batch_sampler (~torch.utils.data.Sampler or a callable): an iterator over batches of +# indices, or a callable to create such an iterator (~torch.utils.data.Sampler). +# A callable batch_sampler will be called for each epoch to enable per epoch dynamic +# batch iterators defined by this callable batch_sampler. +# seed (int, optional): seed for random number generator for +# reproducibility (default: 1). 
+# num_shards (int, optional): shard the data iterator into N +# shards (default: 1). +# shard_id (int, optional): which shard of the data iterator to +# return (default: 0). +# num_workers (int, optional): how many subprocesses to use for data +# loading. 0 means the data will be loaded in the main process +# (default: 0). +# epoch (int, optional): the epoch to start the iterator from +# (default: 1). +# buffer_size (int, optional): the number of batches to keep ready in the +# queue. Helps speeding up dataloading. When buffer_size is zero, the +# default torch.utils.data.DataLoader preloading is used. +# timeout (int, optional): if positive, the timeout value for collecting a batch +# from workers. Should always be non-negative (default: ``0``). +# disable_shuffling (bool, optional): force disable shuffling +# (default: ``False``). +# """ + +# def __init__( +# self, +# dataset, +# collate_fn, +# batch_sampler, +# seed=1, +# num_shards=1, +# shard_id=0, +# num_workers=0, +# epoch=1, +# buffer_size=0, +# timeout=0, +# disable_shuffling=False, +# ): +# assert isinstance(dataset, torch.utils.data.Dataset) +# self.dataset = dataset +# self.collate_fn = collate_fn +# self.batch_sampler = batch_sampler +# self._frozen_batches = ( +# tuple(batch_sampler) if not callable(batch_sampler) else None +# ) +# self.seed = seed +# self.num_shards = num_shards +# self.shard_id = shard_id +# self.num_workers = num_workers +# # This upper limit here is to prevent people from abusing this feature +# # in a shared computing environment. +# self.buffer_size = min(buffer_size, 32) +# self.timeout = timeout +# self.disable_shuffling = disable_shuffling + +# self.epoch = max(epoch, 1) # we use 1-based indexing for epochs +# self.shuffle = not disable_shuffling +# self._cur_epoch_itr = None +# self._next_epoch_itr = None +# self._supports_prefetch = getattr(dataset, "supports_prefetch", False) + +# @property +# def frozen_batches(self): +# if self._frozen_batches is None: +# self._frozen_batches = tuple(self.batch_sampler(self.dataset, self.epoch)) +# return self._frozen_batches + +# @property +# def first_batch(self): +# if len(self.frozen_batches) == 0: +# raise Exception( +# "The dataset is empty. This could indicate " +# "that all elements in the dataset have been skipped. " +# "Try increasing the max number of allowed tokens or using " +# "a larger dataset." +# ) + +# if getattr(self.dataset, "supports_fetch_outside_dataloader", True): +# return self.collate_fn([self.dataset[i] for i in self.frozen_batches[0]]) +# else: +# return "DUMMY" + +# def __len__(self): +# return int(math.ceil(len(self.frozen_batches) / float(self.num_shards))) + +# @property +# def n(self): +# return self.iterations_in_epoch + +# @property +# def next_epoch_idx(self): +# """Return the epoch index after *next_epoch_itr* is called.""" +# if self._next_epoch_itr is not None: +# return self.epoch +# elif self._cur_epoch_itr is not None and self.end_of_epoch(): +# return self.epoch + 1 +# else: +# return self.epoch + +# def next_epoch_itr( +# self, shuffle=True, fix_batches_to_gpus=False, set_dataset_epoch=True +# ): +# """Return a new iterator over the dataset. + +# Args: +# shuffle (bool, optional): shuffle batches before returning the +# iterator (default: True). +# fix_batches_to_gpus (bool, optional): ensure that batches are always +# allocated to the same shards across epochs. Requires +# that :attr:`dataset` supports prefetching (default: False). 
+# set_dataset_epoch (bool, optional): update the wrapped Dataset with +# the new epoch number (default: True). +# """ +# if self.disable_shuffling: +# shuffle = False +# self.epoch = self.next_epoch_idx +# if set_dataset_epoch and hasattr(self.dataset, "set_epoch"): +# self.dataset.set_epoch(self.epoch) +# if self._next_epoch_itr is not None: +# self._cur_epoch_itr = self._next_epoch_itr +# self._next_epoch_itr = None +# else: +# if callable(self.batch_sampler): +# # reset _frozen_batches to refresh the next epoch +# self._frozen_batches = None +# self._cur_epoch_itr = self._get_iterator_for_epoch( +# self.epoch, +# shuffle, +# fix_batches_to_gpus=fix_batches_to_gpus, +# ) +# self.shuffle = shuffle +# return self._cur_epoch_itr + +# def end_of_epoch(self) -> bool: +# """Returns whether the most recent epoch iterator has been exhausted""" +# return not self._cur_epoch_itr.has_next() + +# @property +# def iterations_in_epoch(self): +# """The number of consumed batches in the current epoch.""" +# if self._cur_epoch_itr is not None: +# return self._cur_epoch_itr.n +# elif self._next_epoch_itr is not None: +# return self._next_epoch_itr.n +# return 0 + +# def state_dict(self): +# """Returns a dictionary containing a whole state of the iterator.""" +# if self.end_of_epoch(): +# epoch = self.epoch + 1 +# iter_in_epoch = 0 +# else: +# epoch = self.epoch +# iter_in_epoch = self.iterations_in_epoch +# return { +# "epoch": epoch, +# "iterations_in_epoch": iter_in_epoch, +# "shuffle": self.shuffle, +# "len": len(self), +# } + +# def load_state_dict(self, state_dict): +# """Copies the state of the iterator from the given *state_dict*.""" +# self.epoch = state_dict["epoch"] +# itr_pos = state_dict.get("iterations_in_epoch", 0) +# if itr_pos > 0: +# if "len" in state_dict and state_dict["len"] != len(self): +# old_itr_pos = itr_pos +# itr_pos = int(itr_pos * len(self) / state_dict["len"]) +# logger.info( +# "Iterator size is changed. it is possible due to the change of update_freq/num_gpu. The itr_pos is change from {} to {} for consistency".format(old_itr_pos, itr_pos) +# ) + +# # fast-forward epoch iterator +# self._next_epoch_itr = self._get_iterator_for_epoch( +# self.epoch, +# shuffle=state_dict.get("shuffle", True), +# offset=itr_pos, +# ) +# if self._next_epoch_itr is None: +# raise RuntimeError( +# "Cannot resume training due to dataloader mismatch. You can relaunch " +# "training with `--reset-dataloader` and it should work." 
+# ) +# else: +# self._next_epoch_itr = None + +# def _get_iterator_for_epoch( +# self, epoch, shuffle, fix_batches_to_gpus=False, offset=0 +# ): +# def shuffle_batches(batches, seed): +# with data_utils.numpy_seed(seed): +# np.random.shuffle(batches) +# return batches + +# if self._supports_prefetch: +# batches = self.frozen_batches + +# if shuffle and not fix_batches_to_gpus: +# batches = shuffle_batches(list(batches), self.seed + epoch) + +# batches = list( +# ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) +# ) +# self.dataset.prefetch([i for s in batches for i in s]) + +# if shuffle and fix_batches_to_gpus: +# batches = shuffle_batches(batches, self.seed + epoch + self.shard_id) +# else: +# if shuffle: +# batches = shuffle_batches(list(self.frozen_batches), self.seed + epoch) +# else: +# batches = self.frozen_batches +# batches = list( +# ShardedIterator(batches, self.num_shards, self.shard_id, fill_value=[]) +# ) + +# if offset > 0 and offset >= len(batches): +# return None + +# if self.num_workers > 0: +# os.environ["PYTHONWARNINGS"] = "ignore:semaphore_tracker:UserWarning" + +# # Create data loader +# itr = torch.utils.data.DataLoader( +# self.dataset, +# collate_fn=self.collate_fn, +# batch_sampler=batches[offset:], +# num_workers=self.num_workers, +# timeout=self.timeout, +# ) + +# # Wrap with a BufferedIterator if needed +# if self.buffer_size > 0: +# itr = BufferedIterator(self.buffer_size, itr) + +# # Wrap with CountingIterator +# itr = CountingIterator(itr, start=offset) +# return itr + + +# class GroupedIterator(CountingIterator): +# """Wrapper around an iterable that returns groups (chunks) of items. + +# Args: +# iterable (iterable): iterable to wrap +# chunk_size (int): size of each chunk + +# Attributes: +# n (int): number of elements consumed from this iterator +# """ + +# def __init__(self, iterable, chunk_size): +# itr = _chunk_iterator(iterable, chunk_size) +# super().__init__( +# itr, +# start=int(math.ceil(getattr(iterable, "n", 0) / float(chunk_size))), +# total=int(math.ceil(len(iterable) / float(chunk_size))), +# ) +# self.chunk_size = chunk_size + + +# def _chunk_iterator(itr, chunk_size): +# chunk = [] +# for x in itr: +# chunk.append(x) +# if len(chunk) == chunk_size: +# yield chunk +# chunk = [] +# if len(chunk) > 0: +# yield chunk + + +# class ShardedIterator(CountingIterator): +# """A sharded wrapper around an iterable, padded to length. + +# Args: +# iterable (iterable): iterable to wrap +# num_shards (int): number of shards to split the iterable into +# shard_id (int): which shard to iterator over +# fill_value (Any, optional): padding value when the iterable doesn't +# evenly divide *num_shards* (default: None). 
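`ShardedIterator` (used in `_get_iterator_for_epoch` above) gives shard `shard_id` every `num_shards`-th batch and pads the final round with `fill_value` so that all shards run the same number of steps. The index arithmetic can be reproduced in a few lines; `shard` below is an illustrative helper, not the class itself:

```python
import itertools
import math
import operator


def shard(items, num_shards, shard_id, fill_value=None):
    """Return the padded sub-sequence that shard *shard_id* of *num_shards* would iterate."""
    sharded_len = int(math.ceil(len(items) / float(num_shards)))
    return list(
        map(
            operator.itemgetter(1),
            itertools.zip_longest(
                range(sharded_len),
                itertools.islice(items, shard_id, len(items), num_shards),
                fillvalue=fill_value,
            ),
        )
    )


batches = list(range(7))
print(shard(batches, num_shards=2, shard_id=0, fill_value=[]))  # [0, 2, 4, 6]
print(shard(batches, num_shards=2, shard_id=1, fill_value=[]))  # [1, 3, 5, []]
```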
+ +# Attributes: +# n (int): number of elements consumed from this iterator +# """ + +# def __init__(self, iterable, num_shards, shard_id, fill_value=None): +# if shard_id < 0 or shard_id >= num_shards: +# raise ValueError("shard_id must be between 0 and num_shards") +# sharded_len = int(math.ceil(len(iterable) / float(num_shards))) +# itr = map( +# operator.itemgetter(1), +# itertools.zip_longest( +# range(sharded_len), +# itertools.islice(iterable, shard_id, len(iterable), num_shards), +# fillvalue=fill_value, +# ), +# ) +# super().__init__( +# itr, +# start=int(math.ceil(getattr(iterable, "n", 0) / float(num_shards))), +# total=sharded_len, +# ) + + +# class BackgroundConsumer(Thread): +# def __init__(self, queue, source, max_len): +# Thread.__init__(self) + +# self._queue = queue +# self._source = source +# self._max_len = max_len +# self.count = 0 + +# def run(self): +# try: +# for item in self._source: +# self._queue.put(item) + +# # Stop if we reached the maximum length +# self.count += 1 +# if self._max_len is not None and self.count >= self._max_len: +# break + +# # Signal the consumer we are done. +# self._queue.put(_sentinel) +# except Exception as e: +# self._queue.put(e) + + +# class BufferedIterator(object): +# def __init__(self, size, iterable): +# self._queue = queue.Queue(size) +# self._iterable = iterable +# self._consumer = None + +# self.start_time = time.time() +# self.warning_time = None + +# self.total = len(iterable) + +# def _create_consumer(self): +# self._consumer = BackgroundConsumer( +# self._queue, +# self._iterable, +# self.total, +# ) +# self._consumer.daemon = True +# self._consumer.start() + +# def __iter__(self): +# return self + +# def __len__(self): +# return self.total + +# def take(self, n): +# self.total = min(self.total, n) + +# # Propagate this change to the underlying iterator +# if hasattr(self._iterable, "take"): +# self._iterable.take(n) + +# def __next__(self): +# # Create consumer if not created yet +# if self._consumer is None: +# self._create_consumer() + +# # Notify the user if there is a data loading bottleneck +# if self._queue.qsize() < min(2, max(1, self._queue.maxsize // 2)): +# if time.time() - self.start_time > 5 * 60: +# if ( +# self.warning_time is None +# or time.time() - self.warning_time > 15 * 60 +# ): +# logger.debug( +# "Data loading buffer is empty or nearly empty. This may " +# "indicate a data loading bottleneck, and increasing the " +# "number of workers (--num-workers) may help." +# ) +# self.warning_time = time.time() + +# # Get next example +# item = self._queue.get(True) +# if isinstance(item, Exception): +# raise item +# if item is _sentinel: +# raise StopIteration() +# return item diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/lmdb_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/lmdb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fda72612f7a22eb32f8863f4cdaaa32eff26ac78 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/lmdb_dataset.py @@ -0,0 +1,52 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import lmdb +import os +import pickle +# import torch +import numpy as np +import collections +from functools import lru_cache +from . 
import data_utils +import logging +logger = logging.getLogger(__name__) + +class LMDBDataset: + def __init__(self, db_path): + self.db_path = db_path + assert os.path.isfile(self.db_path), "{} not found".format( + self.db_path + ) + env = self.connect_db(self.db_path) + with env.begin() as txn: + self._keys = list(txn.cursor().iternext(values=False)) + + def connect_db(self, lmdb_path, save_to_self=False): + env = lmdb.open( + lmdb_path, + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + max_readers=256, + ) + if not save_to_self: + return env + else: + self.env = env + + def __len__(self): + return len(self._keys) + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + if not hasattr(self, 'env'): + self.connect_db(self.db_path, save_to_self=True) + datapoint_pickled = self.env.begin().get(self._keys[idx]) + data = pickle.loads(datapoint_pickled) + return data + + diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/lru_cache_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/lru_cache_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7c054acdc44a31d1be41f4cd75b91906653fa448 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/lru_cache_dataset.py @@ -0,0 +1,22 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache + +from . import BaseWrapperDataset + + +class LRUCacheDataset(BaseWrapperDataset): + def __init__(self, dataset, token=None): + super().__init__(dataset) + + @lru_cache(maxsize=16) + def __getitem__(self, index): + return self.dataset[index] + + @lru_cache(maxsize=16) + def collater(self, samples): + return self.dataset.collater(samples) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/mask_tokens_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/mask_tokens_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..5c3501b9b8d5d2da2d7cc9a0773ea7e67e24e304 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/mask_tokens_dataset.py @@ -0,0 +1,258 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from functools import lru_cache + +import numpy as np +import mindspore as ms # 替换torch为mindspore +from unicore.data import Dictionary, data_utils + +from . 
import BaseWrapperDataset, LRUCacheDataset + + +class MaskTokensDataset(BaseWrapperDataset): + + @classmethod + def apply_mask(cls, dataset: ms.dataset.Dataset, *args, **kwargs): # 替换torch数据集类型为MindSpore数据集 + """Return the source and target datasets for masked LM training.""" + dataset = LRUCacheDataset(dataset) + return ( + LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)), + LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)), + ) + + def __init__( + self, + dataset: ms.dataset.Dataset, # 替换torch数据集类型为MindSpore数据集 + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + return_masked_tokens: bool = False, + seed: int = 1, + mask_prob: float = 0.15, + leave_unmasked_prob: float = 0.1, + random_token_prob: float = 0.1, + ): + assert 0.0 < mask_prob < 1.0 + assert 0.0 <= random_token_prob <= 1.0 + assert 0.0 <= leave_unmasked_prob <= 1.0 + assert random_token_prob + leave_unmasked_prob <= 1.0 + + self.dataset = dataset + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.return_masked_tokens = return_masked_tokens + self.seed = seed + self.mask_prob = mask_prob + self.leave_unmasked_prob = leave_unmasked_prob + self.random_token_prob = random_token_prob + + if random_token_prob > 0.0: + weights = np.ones(len(self.vocab)) + weights[vocab.special_index()] = 0 + self.weights = weights / weights.sum() + + self.epoch = None + + @property + def can_reuse_epoch_itr_across_epochs(self): + return True # only the noise changes, not item sizes + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index: int): + return self.__getitem_cached__(self.epoch, index) + + @lru_cache(maxsize=16) + def __getitem_cached__(self, epoch: int, index: int): + with data_utils.numpy_seed(self.seed, epoch, index): + item = self.dataset[index] + sz = len(item) + # don't allow empty sequence + assert sz > 2 + assert ( + self.mask_idx not in item + ), "Dataset contains mask_idx (={}), this is not expected!".format( + self.mask_idx, + ) + + # decide elements to mask + mask = np.full(sz, False) + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * (sz - 2) + np.random.rand() + ) + # don't mask first and last position + mask_idc = np.random.choice(sz - 2, num_mask, replace=False) + 1 + mask[mask_idc] = True + + if self.return_masked_tokens: + new_item = np.full(len(mask), self.pad_idx) + # 替换torch.from_numpy为mindspore.Tensor,调整掩码索引获取方式 + new_item[mask] = item[mask.astype(np.uint8) == 1] + return ms.Tensor(new_item) # 替换torch.from_numpy为ms.Tensor + + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = np.random.rand(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + new_item[rand_mask] = np.random.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + + return 
ms.Tensor(new_item) # 替换torch.from_numpy为ms.Tensor +# from functools import lru_cache + +# import numpy as np +# import torch +# from unicore.data import Dictionary, data_utils + +# from . import BaseWrapperDataset, LRUCacheDataset + + +# class MaskTokensDataset(BaseWrapperDataset): + +# @classmethod +# def apply_mask(cls, dataset: torch.utils.data.Dataset, *args, **kwargs): +# """Return the source and target datasets for masked LM training.""" +# dataset = LRUCacheDataset(dataset) +# return ( +# LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=False)), +# LRUCacheDataset(cls(dataset, *args, **kwargs, return_masked_tokens=True)), +# ) + +# def __init__( +# self, +# dataset: torch.utils.data.Dataset, +# vocab: Dictionary, +# pad_idx: int, +# mask_idx: int, +# return_masked_tokens: bool = False, +# seed: int = 1, +# mask_prob: float = 0.15, +# leave_unmasked_prob: float = 0.1, +# random_token_prob: float = 0.1, +# ): +# assert 0.0 < mask_prob < 1.0 +# assert 0.0 <= random_token_prob <= 1.0 +# assert 0.0 <= leave_unmasked_prob <= 1.0 +# assert random_token_prob + leave_unmasked_prob <= 1.0 + +# self.dataset = dataset +# self.vocab = vocab +# self.pad_idx = pad_idx +# self.mask_idx = mask_idx +# self.return_masked_tokens = return_masked_tokens +# self.seed = seed +# self.mask_prob = mask_prob +# self.leave_unmasked_prob = leave_unmasked_prob +# self.random_token_prob = random_token_prob + +# if random_token_prob > 0.0: +# weights = np.ones(len(self.vocab)) +# weights[vocab.special_index()] = 0 +# self.weights = weights / weights.sum() + +# self.epoch = None + +# @property +# def can_reuse_epoch_itr_across_epochs(self): +# return True # only the noise changes, not item sizes + +# def set_epoch(self, epoch, **unused): +# super().set_epoch(epoch) +# self.epoch = epoch + +# def __getitem__(self, index: int): +# return self.__getitem_cached__(self.epoch, index) + +# @lru_cache(maxsize=16) +# def __getitem_cached__(self, epoch: int, index: int): +# with data_utils.numpy_seed(self.seed, epoch, index): +# item = self.dataset[index] +# sz = len(item) +# # don't allow empty sequence +# assert sz > 2 +# assert ( +# self.mask_idx not in item +# ), "Dataset contains mask_idx (={}), this is not expected!".format( +# self.mask_idx, +# ) + +# # decide elements to mask +# mask = np.full(sz, False) +# num_mask = int( +# # add a random number for probabilistic rounding +# self.mask_prob * (sz - 2) + np.random.rand() +# ) +# # don't mask first and last position +# mask_idc = np.random.choice(sz - 2, num_mask, replace=False) + 1 +# mask[mask_idc] = True + +# if self.return_masked_tokens: +# new_item = np.full(len(mask), self.pad_idx) +# new_item[mask] = item[torch.from_numpy(mask.astype(np.uint8)) == 1] +# return torch.from_numpy(new_item) + +# # decide unmasking and random replacement +# rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob +# if rand_or_unmask_prob > 0.0: +# rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) +# if self.random_token_prob == 0.0: +# unmask = rand_or_unmask +# rand_mask = None +# elif self.leave_unmasked_prob == 0.0: +# unmask = None +# rand_mask = rand_or_unmask +# else: +# unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob +# decision = np.random.rand(sz) < unmask_prob +# unmask = rand_or_unmask & decision +# rand_mask = rand_or_unmask & (~decision) +# else: +# unmask = rand_mask = None + +# if unmask is not None: +# mask = mask ^ unmask + +# new_item = np.copy(item) +# new_item[mask] = self.mask_idx +# if rand_mask is 
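The masking routine in `MaskTokensDataset.__getitem_cached__` above selects roughly `mask_prob * (seq_len - 2)` interior positions (with probabilistic rounding) and then splits them into `[MASK]`-replaced, kept-as-is, and randomly replaced positions. A numpy-only sketch of that selection with toy probabilities; a local `Generator` stands in for the seeded `data_utils.numpy_seed` context used in the real code:

```python
import numpy as np

rng = np.random.default_rng(0)

seq_len = 12                      # includes BOS/EOS, which are never masked
mask_prob = 0.15
leave_unmasked_prob = 0.1         # share of masked positions kept unchanged
random_token_prob = 0.1           # share of masked positions replaced by a random token

# Probabilistic rounding: adding U(0,1) before truncation makes the expected
# number of masked tokens exactly mask_prob * (seq_len - 2).
num_mask = int(mask_prob * (seq_len - 2) + rng.random())

mask = np.full(seq_len, False)
mask[rng.choice(seq_len - 2, num_mask, replace=False) + 1] = True   # skip first/last position

# Per masked position, decide whether to keep or randomize instead of using [MASK].
rand_or_unmask = mask & (rng.random(seq_len) < (random_token_prob + leave_unmasked_prob))
decision = rng.random(seq_len) < leave_unmasked_prob / (random_token_prob + leave_unmasked_prob)
unmask = rand_or_unmask & decision
rand_mask = rand_or_unmask & ~decision

print(mask.sum(), unmask.sum(), rand_mask.sum())
```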
not None: +# num_rand = rand_mask.sum() +# if num_rand > 0: +# new_item[rand_mask] = np.random.choice( +# len(self.vocab), +# num_rand, +# p=self.weights, +# ) + +# return torch.from_numpy(new_item) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/nested_dictionary_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/nested_dictionary_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..2eb335a28fd0b25484bc6eaf880fc79aafed7d43 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/nested_dictionary_dataset.py @@ -0,0 +1,238 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from collections import OrderedDict + +import mindspore.dataset as ds # 替换PyTorch数据集导入为MindSpore数据集 +from mindspore import Tensor # 导入MindSpore张量类,用于数据拼接 + +from . import UnicoreDataset + + +def _flatten(dico, prefix=None): + """Flatten a nested dictionary.""" + new_dico = OrderedDict() + if isinstance(dico, dict): + prefix = prefix + "." if prefix is not None else "" + for k, v in dico.items(): + if v is None: + continue + new_dico.update(_flatten(v, prefix + k)) + elif isinstance(dico, list): + for i, v in enumerate(dico): + new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]")) + else: + new_dico = OrderedDict({prefix: dico}) + return new_dico + + +def _unflatten(dico): + """Unflatten a flattened dictionary into a nested dictionary.""" + new_dico = OrderedDict() + for full_k, v in dico.items(): + full_k = full_k.split(".") + node = new_dico + for k in full_k[:-1]: + if k.startswith("[") and k.endswith("]"): + k = int(k[1:-1]) + if k not in node: + node[k] = OrderedDict() + node = node[k] + node[full_k[-1]] = v + return new_dico + + +def default_collate(batch): + """MindSpore版本的默认数据拼接函数,替代PyTorch的default_collate""" + if not batch: + return None + elem = batch[0] + # 处理张量类型 + if isinstance(elem, Tensor): + return Tensor.cat(batch) # MindSpore中拼接张量使用Tensor.cat + # 处理列表类型 + elif isinstance(elem, list): + if isinstance(elem[0], Tensor): + # 若列表元素是张量,按维度0拼接 + return [default_collate([d[i] for d in batch]) for i in range(len(elem))] + else: + return batch + # 处理字典类型 + elif isinstance(elem, dict): + return {key: default_collate([d[key] for d in batch]) for key in elem} + # 其他类型直接返回 + else: + return batch + + +class NestedDictionaryDataset(UnicoreDataset): + def __init__(self, defn): + super().__init__() + self.defn = _flatten(defn) + first = None + for v in self.defn.values(): + if not isinstance( + v, + ( + UnicoreDataset, + ds.Dataset, # 替换PyTorch数据集类型为MindSpore数据集类型 + ), + ): + raise ValueError("Expected Dataset but found: {}".format(v.__class__)) + first = first or v + if len(v) > 0: + assert len(v) == len(first), "dataset lengths must match" + + self._len = len(first) + + def __getitem__(self, index): + return OrderedDict((k, ds[index]) for k, ds in self.defn.items()) + + def __len__(self): + return self._len + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. 
+ + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + if len(samples) == 0: + return {} + sample = OrderedDict() + for k, ds in self.defn.items(): + try: + sample[k] = ds.collater([s[k] for s in samples]) + except NotImplementedError: + sample[k] = default_collate([s[k] for s in samples]) # 使用MindSpore版本的默认拼接函数 + return _unflatten(sample) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return any(ds.supports_prefetch for ds in self.defn.values()) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + for ds in self.defn.values(): + if getattr(ds, "supports_prefetch", False): + ds.prefetch(indices) + + @property + def can_reuse_epoch_itr_across_epochs(self): + return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values()) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + for ds in self.defn.values(): + ds.set_epoch(epoch) +# from collections import OrderedDict + +# import torch +# from torch.utils.data.dataloader import default_collate + +# from . import UnicoreDataset + + +# def _flatten(dico, prefix=None): +# """Flatten a nested dictionary.""" +# new_dico = OrderedDict() +# if isinstance(dico, dict): +# prefix = prefix + "." if prefix is not None else "" +# for k, v in dico.items(): +# if v is None: +# continue +# new_dico.update(_flatten(v, prefix + k)) +# elif isinstance(dico, list): +# for i, v in enumerate(dico): +# new_dico.update(_flatten(v, prefix + ".[" + str(i) + "]")) +# else: +# new_dico = OrderedDict({prefix: dico}) +# return new_dico + + +# def _unflatten(dico): +# """Unflatten a flattened dictionary into a nested dictionary.""" +# new_dico = OrderedDict() +# for full_k, v in dico.items(): +# full_k = full_k.split(".") +# node = new_dico +# for k in full_k[:-1]: +# if k.startswith("[") and k.endswith("]"): +# k = int(k[1:-1]) +# if k not in node: +# node[k] = OrderedDict() +# node = node[k] +# node[full_k[-1]] = v +# return new_dico + + +# class NestedDictionaryDataset(UnicoreDataset): +# def __init__(self, defn): +# super().__init__() +# self.defn = _flatten(defn) +# first = None +# for v in self.defn.values(): +# if not isinstance( +# v, +# ( +# UnicoreDataset, +# torch.utils.data.Dataset, +# ), +# ): +# raise ValueError("Expected Dataset but found: {}".format(v.__class__)) +# first = first or v +# if len(v) > 0: +# assert len(v) == len(first), "dataset lengths must match" + +# self._len = len(first) + +# def __getitem__(self, index): +# return OrderedDict((k, ds[index]) for k, ds in self.defn.items()) + +# def __len__(self): +# return self._len + +# def collater(self, samples): +# """Merge a list of samples to form a mini-batch. 
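`NestedDictionaryDataset` keeps a flattened view of its nested definition and re-nests the collated mini-batch with `_unflatten`. Assuming the file above is importable as `unicore.data.nested_dictionary_dataset`, the round trip looks like this (the key names are illustrative, not the exact Uni-Mol sample layout):

```python
import numpy as np

# Assumes the module added above is on PYTHONPATH.
from unicore.data.nested_dictionary_dataset import _flatten, _unflatten

sample = {
    "net_input": {"src_tokens": np.array([5, 7, 9]), "src_distance": np.zeros((3, 3))},
    "target": {"tokens_target": np.array([5, 7, 9])},
}

flat = _flatten(sample)
print(list(flat.keys()))
# ['net_input.src_tokens', 'net_input.src_distance', 'target.tokens_target']

nested_again = _unflatten(flat)
print(nested_again["net_input"]["src_tokens"])   # [5 7 9]
```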
+ +# Args: +# samples (List[dict]): samples to collate + +# Returns: +# dict: a mini-batch suitable for forwarding with a Model +# """ +# if len(samples) == 0: +# return {} +# sample = OrderedDict() +# for k, ds in self.defn.items(): +# try: +# sample[k] = ds.collater([s[k] for s in samples]) +# except NotImplementedError: +# sample[k] = default_collate([s[k] for s in samples]) +# return _unflatten(sample) + +# @property +# def supports_prefetch(self): +# """Whether this dataset supports prefetching.""" +# return any(ds.supports_prefetch for ds in self.defn.values()) + +# def prefetch(self, indices): +# """Prefetch the data required for this epoch.""" +# for ds in self.defn.values(): +# if getattr(ds, "supports_prefetch", False): +# ds.prefetch(indices) + +# @property +# def can_reuse_epoch_itr_across_epochs(self): +# return all(ds.can_reuse_epoch_itr_across_epochs for ds in self.defn.values()) + +# def set_epoch(self, epoch): +# super().set_epoch(epoch) +# for ds in self.defn.values(): +# ds.set_epoch(epoch) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/num_samples_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/num_samples_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..af2982cf7dcfc58183801d4756b0db259eb61510 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/num_samples_dataset.py @@ -0,0 +1,18 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from . import UnicoreDataset + + +class NumSamplesDataset(UnicoreDataset): + def __getitem__(self, index): + return 1 + + def __len__(self): + return 0 + + def collater(self, samples): + return sum(samples) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/numel_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/numel_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..52c20ee995e6ca31b99696a3dde8671e2a3d4fc3 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/numel_dataset.py @@ -0,0 +1,57 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np +import mindspore as ms # 替换PyTorch为MindSpore + +from . import BaseWrapperDataset + + +class NumelDataset(BaseWrapperDataset): + def __init__(self, dataset, reduce=False): + super().__init__(dataset) + self.reduce = reduce + + def __getitem__(self, index): + item = self.dataset[index] + if ms.is_tensor(item): # 替换torch.is_tensor为mindspore.is_tensor + return item.numel() # MindSpore张量的numel()方法,替代torch.numel() + else: + return np.size(item) # numpy操作保持不变 + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if self.reduce: + return sum(samples) + else: + return ms.Tensor(samples) # 替换torch.tensor为mindspore.Tensor +# import numpy as np +# import torch + +# from . 
import BaseWrapperDataset + + +# class NumelDataset(BaseWrapperDataset): +# def __init__(self, dataset, reduce=False): +# super().__init__(dataset) +# self.reduce = reduce + +# def __getitem__(self, index): +# item = self.dataset[index] +# if torch.is_tensor(item): +# return torch.numel(item) +# else: +# return np.size(item) + +# def __len__(self): +# return len(self.dataset) + +# def collater(self, samples): +# if self.reduce: +# return sum(samples) +# else: +# return torch.tensor(samples) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/pad_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/pad_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ec96724bd6e9b2addbf261cc61b5c3903013eb04 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/pad_dataset.py @@ -0,0 +1,73 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from unicore.data import data_utils + +from . import BaseWrapperDataset + + +class PadDataset(BaseWrapperDataset): + def __init__(self, dataset, pad_idx, left_pad): + super().__init__(dataset) + self.pad_idx = pad_idx + self.left_pad = left_pad + + def collater(self, samples): + # 保持填充逻辑,假设data_utils.collate_tokens适配MindSpore张量 + return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8) + + +class LeftPadDataset(PadDataset): + def __init__(self, dataset, pad_idx): + super().__init__(dataset, pad_idx, left_pad=True) + + +class RightPadDataset(PadDataset): + def __init__(self, dataset, pad_idx): + super().__init__(dataset, pad_idx, left_pad=False) + + +class RightPadDataset2D(BaseWrapperDataset): + def __init__(self, dataset, pad_idx, left_pad=False): + super().__init__(dataset) + self.pad_idx = pad_idx + self.left_pad = left_pad + + def collater(self, samples): + # 保持2D填充逻辑,假设data_utils.collate_tokens_2d适配MindSpore张量 + return data_utils.collate_tokens_2d(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8) +# from unicore.data import data_utils + +# from . import BaseWrapperDataset + + +# class PadDataset(BaseWrapperDataset): +# def __init__(self, dataset, pad_idx, left_pad): +# super().__init__(dataset) +# self.pad_idx = pad_idx +# self.left_pad = left_pad + +# def collater(self, samples): +# return data_utils.collate_tokens(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8) + + +# class LeftPadDataset(PadDataset): +# def __init__(self, dataset, pad_idx): +# super().__init__(dataset, pad_idx, left_pad=True) + + +# class RightPadDataset(PadDataset): +# def __init__(self, dataset, pad_idx): +# super().__init__(dataset, pad_idx, left_pad=False) + + +# class RightPadDataset2D(BaseWrapperDataset): +# def __init__(self, dataset, pad_idx,left_pad=False): +# super().__init__(dataset) +# self.pad_idx = pad_idx +# self.left_pad = left_pad +# def collater(self, samples): +# return data_utils.collate_tokens_2d(samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/prepend_token_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/prepend_token_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d5a441c4ec4f858033bf1f11125dca125e77e438 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/prepend_token_dataset.py @@ -0,0 +1,44 @@ +# Copyright (c) DP Technology. 
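`PadDataset` and `RightPadDataset2D` above delegate to `data_utils.collate_tokens` / `collate_tokens_2d`, which are not part of this diff. The numpy sketch below only illustrates the padding behaviour those calls suggest (right-pad to the longest sample, rounded up to a multiple of 8); `right_pad_1d` is a hypothetical stand-in, not the real helper:

```python
import numpy as np


def right_pad_1d(samples, pad_idx, pad_to_multiple=8):
    """Illustrative stand-in for data_utils.collate_tokens (the real helper is not in this diff)."""
    max_len = max(len(s) for s in samples)
    max_len = int(np.ceil(max_len / pad_to_multiple) * pad_to_multiple)   # round up to a multiple of 8
    out = np.full((len(samples), max_len), pad_idx, dtype=np.int64)
    for i, s in enumerate(samples):
        out[i, : len(s)] = s                                              # right padding: data first, pad after
    return out


batch = [np.array([4, 5, 6]), np.array([4, 5, 6, 7, 8])]
print(right_pad_1d(batch, pad_idx=0).shape)   # (2, 8)
```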
+# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np +import mindspore.mint as mint # 替换torch为mindspore.mint +from functools import lru_cache + +from . import BaseWrapperDataset + + +class PrependTokenDataset(BaseWrapperDataset): + + def __init__(self, dataset, token=None): + super().__init__(dataset) + self.token = token + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + item = self.dataset[idx] + if self.token is not None: + # 替换torch.full_like为mindspore.mint.full_like,torch.cat为mindspore.mint.cat + item = mint.cat([mint.full_like(item[0], self.token).unsqueeze(0), item], dim=0) + return item +# import numpy as np +# import torch +# from functools import lru_cache + +# from . import BaseWrapperDataset + + +# class PrependTokenDataset(BaseWrapperDataset): + +# def __init__(self, dataset, token=None): +# super().__init__(dataset) +# self.token = token + +# @lru_cache(maxsize=16) +# def __getitem__(self, idx): +# item = self.dataset[idx] +# if self.token is not None: +# item = torch.cat([torch.full_like(item[0], self.token).unsqueeze(0), item], dim=0) +# return item diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/raw_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/raw_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..d724649b5bdac35daf6ad9cef57f300521ec0752 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/raw_dataset.py @@ -0,0 +1,148 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import mindspore as ms +from functools import lru_cache +from . 
import UnicoreDataset + + +# 实现MindSpore版本的默认数据拼接函数,替代PyTorch的default_collate +def default_collate(batch): + if not batch: + return None + elem = batch[0] + # 处理MindSpore张量 + if isinstance(elem, ms.Tensor): + return ms.ops.cat(batch) # 张量拼接使用mindspore.ops.cat + # 处理列表类型 + elif isinstance(elem, list): + if isinstance(elem[0], ms.Tensor): + return [default_collate([d[i] for d in batch]) for i in range(len(elem))] + else: + return batch + # 处理字典类型 + elif isinstance(elem, dict): + return {key: default_collate([d[key] for d in batch]) for key in elem} + # 其他类型直接返回 + else: + return batch + + +class RawLabelDataset(UnicoreDataset): + def __init__(self, labels): + super().__init__() + self.labels = labels + + @lru_cache(maxsize=16) + def __getitem__(self, index): + return self.labels[index] + + def __len__(self): + return len(self.labels) + + def collater(self, samples): + # 替换torch.tensor为mindspore.Tensor + return ms.Tensor(samples) + + +class RawArrayDataset(UnicoreDataset): + + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + + @lru_cache(maxsize=16) + def __getitem__(self, index): + return self.dataset[index] + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if hasattr(self.dataset, 'collater'): + return self.dataset.collater(samples) + else: + # 使用MindSpore版本的default_collate + return default_collate(samples) + + +class RawNumpyDataset(UnicoreDataset): + + def __init__(self, dataset): + super().__init__() + self.dataset = dataset + + @lru_cache(maxsize=16) + def __getitem__(self, index): + # 替换torch.from_numpy为mindspore.Tensor(直接从numpy数组创建Tensor) + return ms.Tensor(self.dataset[index]) + + def __len__(self): + return len(self.dataset) + + def collater(self, samples): + if hasattr(self.dataset, 'collater'): + return self.dataset.collater(samples) + else: + # 使用MindSpore版本的default_collate + return default_collate(samples) +# import torch +# from torch.utils.data.dataloader import default_collate +# from functools import lru_cache +# from . 
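A note on the `default_collate` replacements above: PyTorch's `default_collate` stacks per-sample tensors along a new leading batch axis, while `mindspore.ops.cat` concatenates along an existing axis, so the two only agree when each sample already carries a batch-like first dimension. A quick comparison of the two choices:

```python
import mindspore as ms
from mindspore import ops

samples = [ms.Tensor([1.0, 2.0]), ms.Tensor([3.0, 4.0])]   # two 1-D samples of length 2

print(ops.cat(samples).shape)     # (4,)    -> concatenated along axis 0
print(ops.stack(samples).shape)   # (2, 2)  -> new leading batch axis, like torch.stack
```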
import UnicoreDataset + + +# class RawLabelDataset(UnicoreDataset): +# def __init__(self, labels): +# super().__init__() +# self.labels = labels + +# @lru_cache(maxsize=16) +# def __getitem__(self, index): +# return self.labels[index] + +# def __len__(self): +# return len(self.labels) + +# def collater(self, samples): +# return torch.tensor(samples) + + +# class RawArrayDataset(UnicoreDataset): + +# def __init__(self, dataset): +# super().__init__() +# self.dataset = dataset + +# @lru_cache(maxsize=16) +# def __getitem__(self, index): +# return self.dataset[index] + +# def __len__(self): +# return len(self.dataset) + +# def collater(self, samples): +# if hasattr(self.dataset, 'collater'): +# return self.dataset.collater(samples) +# else: +# return default_collate(samples) + + +# class RawNumpyDataset(UnicoreDataset): + +# def __init__(self, dataset): +# super().__init__() +# self.dataset = dataset + +# @lru_cache(maxsize=16) +# def __getitem__(self, index): +# return torch.from_numpy(self.dataset[index]) + +# def __len__(self): +# return len(self.dataset) + +# def collater(self, samples): +# if hasattr(self.dataset, 'collater'): +# return self.dataset.collater(samples) +# else: +# return default_collate(samples) diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/sort_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/sort_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..cdae5c7ac8855e14e93167582bf3f9726e43f0e9 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/sort_dataset.py @@ -0,0 +1,42 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np + +from . import BaseWrapperDataset, data_utils + + +class SortDataset(BaseWrapperDataset): + def __init__(self, dataset, sort_order): + super().__init__(dataset) + if not isinstance(sort_order, (list, tuple)): + sort_order = [sort_order] + self.sort_order = sort_order + + assert all(len(so) == len(dataset) for so in sort_order) + + def ordered_indices(self): + return np.lexsort(self.sort_order) + + +class EpochShuffleDataset(BaseWrapperDataset): + def __init__(self, dataset, size, seed): + super().__init__(dataset) + self.size = size + self.seed = seed + self.set_epoch(1) + + def set_epoch(self, epoch): + super().set_epoch(epoch) + with data_utils.numpy_seed(self.seed + epoch - 1): + self.sort_order = np.random.permutation(self.size) + + def ordered_indices(self): + return self.sort_order + + @property + def can_reuse_epoch_itr_across_epochs(self): + return False \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/tokenize_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/tokenize_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..403511c09c5be5a5248b5ba666f889d799773a7e --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/tokenize_dataset.py @@ -0,0 +1,51 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from functools import lru_cache + +import mindspore as ms # 替换torch导入为mindspore +from unicore.data import Dictionary +from . 
import BaseWrapperDataset + + +class TokenizeDataset(BaseWrapperDataset): + def __init__( + self, + dataset: ms.dataset.Dataset, # 替换PyTorch数据集类型为MindSpore数据集类型 + dictionary: Dictionary, + max_seq_len: int = 512, + ): + self.dataset = dataset + self.dictionary = dictionary + self.max_seq_len = max_seq_len + + @lru_cache(maxsize=16) + def __getitem__(self, index: int): + raw_data = self.dataset[index] + assert len(raw_data) < self.max_seq_len and len(raw_data) > 0 + # 替换torch.from_numpy为mindspore.Tensor,.long()对应指定dtype为int64 + return ms.Tensor(self.dictionary.vec_index(raw_data), dtype=ms.int64) +# from functools import lru_cache + +# import torch +# from unicore.data import Dictionary +# from functools import lru_cache +# from . import BaseWrapperDataset + + +# class TokenizeDataset(BaseWrapperDataset): +# def __init__( +# self, +# dataset: torch.utils.data.Dataset, +# dictionary: Dictionary, +# max_seq_len: int=512, +# ): +# self.dataset = dataset +# self.dictionary = dictionary +# self.max_seq_len = max_seq_len + +# @lru_cache(maxsize=16) +# def __getitem__(self, index: int): +# raw_data = self.dataset[index] +# assert len(raw_data) < self.max_seq_len and len(raw_data) > 0 +# return torch.from_numpy(self.dictionary.vec_index(raw_data)).long() \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unicore/data/unicore_dataset.py b/MindChemistry/applications/Uni-Mol/unicore/data/unicore_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..071d898c92818f822a6f04de3f13370441ff0263 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/data/unicore_dataset.py @@ -0,0 +1,174 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import numpy as np +import mindspore.dataset as ds # 替换PyTorch数据集模块为MindSpore数据集模块 + +logger = logging.getLogger(__name__) + + +class EpochListening: + """Mixin for receiving updates whenever the epoch increments.""" + + @property + def can_reuse_epoch_itr_across_epochs(self): + """ + Whether we can reuse the iterator for this dataset across epochs. + + This needs to return ``False`` if the sample sizes can change across + epochs, in which case we may need to regenerate batches at each epoch. + If your dataset relies in ``set_epoch`` then you should consider setting + this to ``False``. + """ + return True + + def set_epoch(self, epoch): + """Will receive the updated epoch number at the beginning of the epoch.""" + pass + + +class UnicoreDataset(ds.Dataset, EpochListening): # 替换父类为MindSpore的Dataset + """A dataset that provides helpers for batching.""" + + def __getitem__(self, index): + raise NotImplementedError + + def __len__(self): + raise NotImplementedError + + def collater(self, samples): + """Merge a list of samples to form a mini-batch. + + Args: + samples (List[dict]): samples to collate + + Returns: + dict: a mini-batch suitable for forwarding with a Model + """ + raise NotImplementedError + + def ordered_indices(self): + """Return an ordered list of indices. 
Batches will be constructed based + on this order.""" + return np.arange(len(self), dtype=np.int64) + + @property + def supports_prefetch(self): + """Whether this dataset supports prefetching.""" + return False + + def attr(self, attr: str, index: int): + return getattr(self, attr, None) + + def prefetch(self, indices): + """Prefetch the data required for this epoch.""" + raise NotImplementedError + + def batch_by_size( + self, + indices, + batch_size=None, + required_batch_size_multiple=1, + ): + """ + Given an ordered set of indices + """ + from unicore.data import data_utils + return data_utils.batch_by_size( + indices, + batch_size=batch_size, + required_batch_size_multiple=required_batch_size_multiple, + ) + + @property + def supports_fetch_outside_dataloader(self): + """Whether this dataset supports fetching outside the workers of the dataloader.""" + return True +# import logging +# import numpy as np +# import torch.utils.data + +# logger = logging.getLogger(__name__) + + +# class EpochListening: +# """Mixin for receiving updates whenever the epoch increments.""" + +# @property +# def can_reuse_epoch_itr_across_epochs(self): +# """ +# Whether we can reuse the :class:`unicore.data.EpochBatchIterator` for +# this dataset across epochs. + +# This needs to return ``False`` if the sample sizes can change across +# epochs, in which case we may need to regenerate batches at each epoch. +# If your dataset relies in ``set_epoch`` then you should consider setting +# this to ``False``. +# """ +# return True + +# def set_epoch(self, epoch): +# """Will receive the updated epoch number at the beginning of the epoch.""" +# pass + + +# class UnicoreDataset(torch.utils.data.Dataset, EpochListening): +# """A dataset that provides helpers for batching.""" + +# def __getitem__(self, index): +# raise NotImplementedError + +# def __len__(self): +# raise NotImplementedError + +# def collater(self, samples): +# """Merge a list of samples to form a mini-batch. + +# Args: +# samples (List[dict]): samples to collate + +# Returns: +# dict: a mini-batch suitable for forwarding with a Model +# """ +# raise NotImplementedError + +# def ordered_indices(self): +# """Return an ordered list of indices. Batches will be constructed based +# on this order.""" +# return np.arange(len(self), dtype=np.int64) + +# @property +# def supports_prefetch(self): +# """Whether this dataset supports prefetching.""" +# return False + +# def attr(self, attr: str, index: int): +# return getattr(self, attr, None) + +# def prefetch(self, indices): +# """Prefetch the data required for this epoch.""" +# raise NotImplementedError + +# def batch_by_size( +# self, +# indices, +# batch_size=None, +# required_batch_size_multiple=1, +# ): +# """ +# Given an ordered set of indices +# """ +# from unicore.data import data_utils +# return data_utils.batch_by_size( +# indices, +# batch_size=batch_size, +# required_batch_size_multiple=required_batch_size_multiple, +# ) + +# @property +# def supports_fetch_outside_dataloader(self): +# """Whether this dataset supports fetching outside the workers of the dataloader.""" +# return True diff --git a/MindChemistry/applications/Uni-Mol/unicore/distributed/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/distributed/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..956f5b3cd605596b3289992506c7675a94c7c175 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/distributed/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. 
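`UnicoreDataset.batch_by_size` above forwards to `data_utils.batch_by_size`, which is not included in this diff. As a rough, assumption-labelled illustration of chunking the output of `ordered_indices()` into batches (not the actual implementation, which also honours `required_batch_size_multiple`):

```python
import numpy as np


def chunk_indices(indices, batch_size):
    """Toy batching of ordered indices into fixed-size groups (illustration only;
    the real data_utils.batch_by_size also applies required_batch_size_multiple)."""
    indices = np.asarray(indices)
    return [indices[i:i + batch_size] for i in range(0, len(indices), batch_size)]


order = np.arange(10, dtype=np.int64)       # what ordered_indices() returns by default
print([b.tolist() for b in chunk_indices(order, batch_size=4)])
# [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```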
and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from .module_proxy_wrapper import ModuleProxyWrapper +from .legacy_distributed_data_parallel import LegacyDistributedDataParallel + +__all__ = [ + "ModuleProxyWrapper", +] diff --git a/MindChemistry/applications/Uni-Mol/unicore/distributed/legacy_distributed_data_parallel.py b/MindChemistry/applications/Uni-Mol/unicore/distributed/legacy_distributed_data_parallel.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6a3c71dc6faec15b4680dde3aaafd40d394dc6 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/distributed/legacy_distributed_data_parallel.py @@ -0,0 +1,349 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +A modified version of the legacy DistributedDataParallel module that uses c10d +communication primitives. This version is simpler than the latest PyTorch +version and is useful for debugging. Notably it does not overlap gradient +communication with the backward pass, which makes it slower but more robust +than the PyTorch version. + +This version also supports the *no_sync* context manager, which allows faster +training with `--update-freq`. +""" +from collections import OrderedDict +from contextlib import contextmanager + +import mindspore as ms +from mindspore import nn, ops +# 单卡环境移除分布式通信相关导入(注释掉) +# from mindspore.communication import all_reduce +# from mindspore.parallel._utils import _get_device_num + +from unicore.distributed import utils + + +class LegacyDistributedDataParallel(nn.Cell): + """Implements distributed data parallelism at the cell level. + + 适配单卡Ascend NPU环境:移除多卡并行逻辑,保留单卡运行核心逻辑 + """ + + # 单卡环境无需process_group参数(注释原参数) + def __init__(self, cell, # process_group, + buffer_size=2** 28): + super().__init__() + + self.cell = cell + # 单卡环境无需通信组(注释掉) + # self.process_group = process_group + # 单卡环境世界大小固定为1(注释原多卡逻辑) + # self.world_size = _get_device_num() # 获取世界大小,对应PyTorch的get_world_size + self.world_size = 1 # 单卡环境 + + # 保留缓冲区大小计算(单卡下仅作兼容) + self.buffer_size = min(buffer_size, sum(p.size for p in cell.get_parameters())) + self.buffer = None + + # 单卡环境无需累积梯度同步(保留属性,注释原逻辑说明) + self.accumulate_grads = False # 控制是否累积梯度,单卡环境下无实际意义 + + # 单卡环境无需按设备分组(注释原多卡分组逻辑) + # 按设备分组参数(多卡逻辑) + # paramlists = OrderedDict() + # for param in self.cell.get_parameters(): + # device = param.device_info # 获取参数所在设备,对应PyTorch的param.device + # if paramlists.get(device) is None: + # paramlists[device] = [] + # paramlists[device] += [param] + # self.per_device_params = list(paramlists.values()) + + # 单卡环境:所有参数在同一设备,直接收集需要梯度的参数 + self.per_device_params = [[param for param in cell.get_parameters() if param.requires_grad]] + + @contextmanager + def no_sync(self): + """A context manager to disable gradient synchronization. 
+ 单卡环境无需同步逻辑,保留空实现 + """ + # 多卡逻辑(注释掉) + # old_accumulate_grads = self.accumulate_grads + # self.accumulate_grads = True + yield + # self.accumulate_grads = old_accumulate_grads + + def construct(self, *inputs, **kwargs): # MindSpore中用construct替代forward + return self.cell(*inputs, **kwargs) + + def all_reduce_params(self, params): + """单卡环境无需梯度同步,仅确保梯度存在""" + if self.accumulate_grads: + return + + # 多卡缓冲区及all-reduce逻辑(注释掉) + # buffer = self.buffer + # nonzero_buffer = False + # if len(params) > 1: + # offset = 0 + # for p in params: + # sz = p.size # 替代numel(),获取参数元素数量 + # if p.grad is not None: + # # 复制梯度到缓冲区,reshape替代view + # buffer[offset : offset + sz] = p.grad.reshape(-1) + # nonzero_buffer = True + # else: + # buffer[offset : offset + sz] = 0 # 零初始化 + # offset += sz + # else: + # # 单个梯度的all-reduce + # p = params[0] + # if p.grad is not None: + # buffer = p.grad + # nonzero_buffer = True + # elif p.size <= self.buffer.size: + # buffer = buffer[: p.size] + # buffer = ops.zeros_like(buffer) # 替代torch.zeros_like + # else: + # buffer = ops.zeros_like(p) # 替代torch.zeros_like + + # if nonzero_buffer: + # buffer = buffer / self.world_size # 除以世界大小 + + # # 执行all-reduce操作,对应PyTorch的all_reduce(单卡无需) + # all_reduce(buffer, group=self.process_group) + + # # 将all-reduce后的梯度复制回原参数(单卡无需) + # offset = 0 + # for p in params: + # sz = p.size + # if p.grad is not None: + # # reshape替代view_as + # p.grad = buffer[offset : offset + sz].reshape(p.shape) + # else: + # # 克隆梯度,替代clone() + # p.grad = buffer[offset : offset + sz].reshape(p.shape).copy() + # offset += sz + + # 单卡逻辑:确保每个参数都有梯度(无同步操作) + for p in params: + if p.grad is None: + p.grad = ops.zeros_like(p) # 初始化零梯度 + + def all_reduce_grads(self): + """ + This function must be called explicitly after backward to reduce + gradients. 
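+        In this single-card port no collective communication is performed; the
+        function only makes sure every trainable parameter carries a gradient
+        tensor (zero-initialized when backward produced none).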
+ 单卡环境无需梯度同步,仅处理本地梯度 + """ + + def reduction_fn(): + if self.accumulate_grads: + return + + # 单卡环境初始化缓冲区(仅作兼容,实际很少使用) + if self.buffer is None and self.per_device_params: + first_param = next(param for param in self.cell.get_parameters() if param.requires_grad) + self.buffer = ms.Tensor(shape=(self.buffer_size,), dtype=first_param.dtype) # 单卡无需指定device + + for params in self.per_device_params: + # 多卡按桶进行梯度all-reduce逻辑(注释掉) + # offset = 0 + # buffered_params = [] + # for param in params: + # if not param.requires_grad: + # continue + # if param.grad is None: + # param.grad = ops.zeros_like(param) # 替代torch.zeros_like + + # if hasattr(param, "expert"): + # # 跳过非共享参数的梯度同步 + # continue + + # if param.grad.requires_grad: + # raise RuntimeError( + # "DistributedDataParallel only works " + # "with gradients that don't require " + # "grad" + # ) + # sz = param.size + # if sz > self.buffer.size: + # # 大参数直接进行all-reduce + # self.all_reduce_params([param]) + # else: + # if offset + sz > self.buffer.size: + # self.all_reduce_params(buffered_params) + # offset = 0 + # buffered_params.clear() + # buffered_params.append(param) + # offset += sz + + # if len(buffered_params) > 0: + # self.all_reduce_params(buffered_params) + + # 单卡逻辑:直接处理所有参数梯度(无需分桶) + for param in params: + if not param.requires_grad: + continue + if param.grad is None: + param.grad = ops.zeros_like(param) # 确保梯度存在 + + if hasattr(param, "expert"): + continue # 保留过滤逻辑 + + if param.grad.requires_grad: + raise RuntimeError("单卡环境下梯度不应需要梯度") + + reduction_fn() +# from collections import OrderedDict +# from contextlib import contextmanager + +# import torch +# from torch import nn + +# from unicore.distributed import utils + + +# class LegacyDistributedDataParallel(nn.Module): +# """Implements distributed data parallelism at the module level. + +# A simplified version of :class:`torch.nn.parallel.DistributedDataParallel`. +# This version uses a c10d process group for communication and does not +# broadcast buffers. + +# Args: +# module (~torch.nn.Module): module to be parallelized +# process_group: the c10d process group to be used for distributed data +# parallel all-reduction. +# buffer_size (int, optional): number of elements to buffer before +# performing all-reduce (default: 256M). 
+# """ + +# def __init__(self, module, process_group, buffer_size=2 ** 28): +# super().__init__() + +# self.module = module +# self.process_group = process_group +# self.world_size = utils.get_world_size(self.process_group) + +# # Never use a bigger buffer than the number of model params +# self.buffer_size = min(buffer_size, sum(p.numel() for p in module.parameters())) +# self.buffer = None + +# # We can also forcibly accumulate grads locally and only do the +# # all-reduce at some later time +# self.accumulate_grads = False + +# # make per-device lists of parameters +# paramlists = OrderedDict() +# for param in self.module.parameters(): +# device = param.device +# if paramlists.get(device) is None: +# paramlists[device] = [] +# paramlists[device] += [param] +# self.per_device_params = list(paramlists.values()) + +# @contextmanager +# def no_sync(self): +# """A context manager to disable gradient synchronization.""" +# old_accumulate_grads = self.accumulate_grads +# self.accumulate_grads = True +# yield +# self.accumulate_grads = old_accumulate_grads + +# def forward(self, *inputs, **kwargs): +# return self.module(*inputs, **kwargs) + +# def all_reduce_params(self, params): +# if self.accumulate_grads: +# return +# buffer = self.buffer +# nonzero_buffer = False +# if len(params) > 1: +# offset = 0 +# for p in params: +# sz = p.numel() +# if p.grad is not None: +# buffer[offset : offset + sz].copy_(p.grad.data.view(-1)) +# nonzero_buffer = True +# else: +# buffer[offset : offset + sz].zero_() +# offset += sz +# else: +# # we only have a single grad to all-reduce +# p = params[0] +# if p.grad is not None: +# buffer = p.grad.data +# nonzero_buffer = True +# elif p.numel() <= self.buffer.numel(): +# buffer = buffer[: p.numel()] +# buffer.zero_() +# else: +# buffer = torch.zeros_like(p) + +# if nonzero_buffer: +# buffer.div_(self.world_size) + +# utils.all_reduce(buffer, self.process_group) + +# # copy all-reduced grads back into their original place +# offset = 0 +# for p in params: +# sz = p.numel() +# if p.grad is not None: +# p.grad.data.copy_(buffer[offset : offset + sz].view_as(p)) +# else: +# p.grad = buffer[offset : offset + sz].view_as(p).clone() +# offset += sz + +# def all_reduce_grads(self): +# """ +# This function must be called explicitly after backward to reduce +# gradients. There is no automatic hook like c10d. 
+# """ + +# def reduction_fn(): +# # This function only needs to be called once +# if self.accumulate_grads: +# return + +# if self.buffer is None: +# self.buffer = next(self.module.parameters()).new(self.buffer_size) +# for params in self.per_device_params: +# # All-reduce the gradients in buckets +# offset = 0 +# buffered_params = [] +# for param in params: +# if not param.requires_grad: +# continue +# if param.grad is None: +# param.grad = torch.zeros_like(param) + +# if hasattr(param, "expert"): +# # Skip gradient sync for unshared parameters +# continue + +# if param.grad.requires_grad: +# raise RuntimeError( +# "DistributedDataParallel only works " +# "with gradients that don't require " +# "grad" +# ) +# sz = param.numel() +# if sz > self.buffer.numel(): +# # all-reduce big params directly +# self.all_reduce_params([param]) +# else: +# if offset + sz > self.buffer.numel(): +# self.all_reduce_params(buffered_params) +# offset = 0 +# buffered_params.clear() +# buffered_params.append(param) +# offset += sz + +# if len(buffered_params) > 0: +# self.all_reduce_params(buffered_params) + +# reduction_fn() \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unicore/distributed/module_proxy_wrapper.py b/MindChemistry/applications/Uni-Mol/unicore/distributed/module_proxy_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..f8495bd491f17651512c1eeaacb109adee92f6da --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/distributed/module_proxy_wrapper.py @@ -0,0 +1,117 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from mindspore import nn + + +class ModuleProxyWrapper(nn.Cell): + """ + Wrap a DistributedDataParallel cell and forward requests for missing + attributes to the cell wrapped by DDP (the twice-wrapped cell). + Also forward calls to :func:`parameters_dict` and :func:`load_parameters`. 
+ + Usage:: + + cell.xyz = "hello world" + wrapped_cell = nn.DistributedDataParallel(cell, **ddp_args) + wrapped_cell = ModuleProxyWrapper(wrapped_cell) + assert wrapped_cell.xyz == "hello world" + assert wrapped_cell.parameters_dict().keys() == cell.parameters_dict().keys() + + Args: + cell (nn.Cell): cell to wrap + """ + + def __init__(self, cell: nn.Cell): + super().__init__() + assert hasattr(cell, "cell"), \ + "ModuleProxyWrapper expects input to wrap another cell" # MindSpore的DDP内部属性为cell,对应PyTorch的module + self.cell = cell # 这里用cell替代PyTorch的module,保持命名一致性 + + def __getattr__(self, name): + """Forward missing attributes to twice-wrapped cell.""" + try: + # defer to nn.Cell's logic + return super().__getattr__(name) + except AttributeError: + try: + # forward to the once-wrapped cell + return getattr(self.cell, name) + except AttributeError: + # forward to the twice-wrapped cell + return getattr(self.cell.cell, name) # 访问DDP内部包装的cell + + def parameters_dict(self, *args, **kwargs): + """Forward to the twice-wrapped cell (MindSpore中对应state_dict的方法为parameters_dict)。""" + return self.cell.cell.parameters_dict(*args, **kwargs) + + def load_parameters(self, *args, **kwargs): + """Forward to the twice-wrapped cell (MindSpore中对应load_state_dict的方法为load_parameters)。""" + return self.cell.cell.load_parameters(*args, **kwargs) + + def construct(self, *args, **kwargs): # MindSpore中用construct替代forward + return self.cell(*args, **kwargs) + + def bfloat16(self): + return self.cell.cell.bfloat16() # 转发类型转换方法到内部cell + + def half(self): + return self.cell.cell.half() # 转发类型转换方法到内部cell +# from torch import nn + + +# class ModuleProxyWrapper(nn.Module): +# """ +# Wrap a DistributedDataParallel module and forward requests for missing +# attributes to the module wrapped by DDP (the twice-wrapped module). +# Also forward calls to :func:`state_dict` and :func:`load_state_dict`. 
+ +# Usage:: + +# module.xyz = "hello world" +# wrapped_module = DistributedDataParallel(module, **ddp_args) +# wrapped_module = ModuleProxyWrapper(wrapped_module) +# assert wrapped_module.xyz == "hello world" +# assert wrapped_module.state_dict().keys() == module.state_dict().keys() + +# Args: +# module (nn.Module): module to wrap +# """ + +# def __init__(self, module: nn.Module): +# super().__init__() +# assert hasattr(module, "module"), \ +# "ModuleProxyWrapper expects input to wrap another module" +# self.module = module + +# def __getattr__(self, name): +# """Forward missing attributes to twice-wrapped module.""" +# try: +# # defer to nn.Module's logic +# return super().__getattr__(name) +# except AttributeError: +# try: +# # forward to the once-wrapped module +# return getattr(self.module, name) +# except AttributeError: +# # forward to the twice-wrapped module +# return getattr(self.module.module, name) + +# def state_dict(self, *args, **kwargs): +# """Forward to the twice-wrapped module.""" +# return self.module.module.state_dict(*args, **kwargs) + +# def load_state_dict(self, *args, **kwargs): +# """Forward to the twice-wrapped module.""" +# return self.module.module.load_state_dict(*args, **kwargs) + +# def forward(self, *args, **kwargs): +# return self.module(*args, **kwargs) + +# def bfloat16(self): +# return self.module.module.bfloat16() + +# def half(self): +# return self.module.module.half() diff --git a/MindChemistry/applications/Uni-Mol/unicore/distributed/utils.py b/MindChemistry/applications/Uni-Mol/unicore/distributed/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3f7582afb2b5101e2e4b09fcaed9ac0010f5ec75 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/distributed/utils.py @@ -0,0 +1,1083 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
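+
+# NOTE: illustrative sketch only (not part of the original sources). In this
+# single-card Ascend port every helper below degrades to an identity/no-op, so
+# caller code written against the multi-card API keeps working, e.g.:
+#
+#     from unicore.distributed import utils as distributed_utils
+#
+#     distributed_utils.get_world_size()                  # -> 1
+#     distributed_utils.get_rank()                        # -> 0
+#     distributed_utils.all_reduce(tensor, group=None)    # returns `tensor` unchanged
+#     distributed_utils.all_gather(tensor, group=None)    # -> [tensor]
+#
+# The original multi-card implementations are preserved as comments next to each stub.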
+import datetime +import io +import logging +import os +import pickle +import random +import socket +import struct +import subprocess +import warnings +from collections import OrderedDict +from typing import Any, Dict, List, Mapping, Optional +from dataclasses import dataclass + +import mindspore as ms +# 单卡环境移除分布式通信相关导入(注释掉) +# import mindspore.communication as comm +# from mindspore.communication.management import get_rank, get_group_size, init, ALLREDUCE_OP, BROADCAST_OP +from mindspore import Tensor, ops +# 单卡环境无需并行工具(注释掉) +# from mindspore.parallel._utils import _get_device_num + + +logger = logging.getLogger(__name__) + + +def is_master(args): + # 单卡环境始终是主节点 + return True # args.distributed_rank == 0 # 原多卡逻辑注释 + + +def infer_init_method(args, force_distributed=False): + # 单卡环境无需推断分布式初始化方式(注释多卡逻辑) + # if args.distributed_init_method is not None: + # return + + # if all( + # key in os.environ + # for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] + # ): + # # 适配MindSpore分布式启动环境变量 + # _infer_mindspore_distributed_launch_init(args) + # elif args.distributed_port > 0: + # # 适配Slurm环境初始化 + # _infer_slurm_init(args) + # elif args.distributed_world_size > 1 or force_distributed: + # # 单节点多GPU fallback + # _infer_single_node_init(args) + + # elif not args.distributed_no_spawn: + # args.distributed_num_procs = min( + # _get_device_num(), args.distributed_world_size + # ) + # 单卡环境强制关闭分布式初始化 + args.distributed_init_method = None + args.distributed_world_size = 1 + args.distributed_rank = 0 + + +# 多卡环境初始化函数(注释掉,单卡无需) +# def _infer_mindspore_distributed_launch_init(args): +# args.distributed_init_method = "env://" +# args.distributed_world_size = int(os.environ["WORLD_SIZE"]) +# args.distributed_rank = int(os.environ["RANK"]) +# # 由MindSpore分布式启动管理进程 +# args.distributed_no_spawn = True + + +# 多卡Slurm环境适配(注释掉,单卡无需) +# def _infer_slurm_init(args): +# node_list = os.environ.get("SLURM_STEP_NODELIST") +# if node_list is None: +# node_list = os.environ.get("SLURM_JOB_NODELIST") +# if node_list is not None: +# try: +# hostnames = subprocess.check_output( +# ["scontrol", "show", "hostnames", node_list] +# ) +# args.distributed_init_method = "tcp://{host}:{port}".format( +# host=hostnames.split()[0].decode("utf-8"), +# port=args.distributed_port, +# ) +# nnodes = int(os.environ.get("SLURM_NNODES")) +# ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") +# if ntasks_per_node is not None: +# ntasks_per_node = int(ntasks_per_node) +# else: +# ntasks = int(os.environ.get("SLURM_NTASKS")) +# nnodes = int(os.environ.get("SLURM_NNODES")) +# assert ntasks % nnodes == 0 +# ntasks_per_node = int(ntasks / nnodes) +# if ntasks_per_node == 1: +# gpus_per_node = _get_device_num() +# node_id = int(os.environ.get("SLURM_NODEID")) +# args.distributed_rank = node_id * gpus_per_node +# args.distributed_world_size = nnodes * gpus_per_node +# else: +# assert ntasks_per_node == args.distributed_world_size // nnodes +# args.distributed_no_spawn = True +# args.distributed_rank = int(os.environ.get("SLURM_PROCID")) +# args.device_id = int(os.environ.get("SLURM_LOCALID")) +# except subprocess.CalledProcessError as e: # scontrol执行失败 +# raise e +# except FileNotFoundError: # 未安装Slurm +# pass + + +# 单节点多卡初始化(注释掉,单卡无需) +# def _infer_single_node_init(args): +# assert ( +# args.distributed_world_size <= _get_device_num() +# ), f"world size is {args.distributed_world_size} but have {_get_device_num()} available devices" +# port = random.randint(10000, 20000) +# args.distributed_init_method = 
"tcp://localhost:{port}".format(port=port) + + +def distributed_init(args): + # 单卡环境无需分布式初始化 + logger.info("单卡Ascend NPU环境,跳过分布式初始化") + args.distributed_rank = 0 # 固定为0 + # 单卡环境日志级别统一 + logging.getLogger().setLevel(logging.INFO) + return args.distributed_rank + + +# 多卡进程管理(注释掉,单卡无需) +# def distributed_main(i, main, args, kwargs): +# args.device_id = i +# if ms.get_context("device_target") in ["GPU", "Ascend"] and not args.cpu: +# ms.set_context(device_id=args.device_id) +# if args.distributed_rank is None: # 多进程启动场景 +# args.distributed_rank = kwargs.pop("start_rank", 0) + i + +# args.distributed_rank = distributed_init(args) + +# after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) +# if after_distributed_init_fn: +# args = after_distributed_init_fn(args) + +# main(args, **kwargs) + +# if comm.is_initialized(): +# comm.barrier() + + +def call_main(args, main,** kwargs): + # 单卡环境直接调用主函数,无需分布式处理 + infer_init_method(args) # 强制单卡初始化 + main(args, **kwargs) + # 多卡启动逻辑(注释掉) + # if args.distributed_init_method is None: + # infer_init_method(args) + + # if args.distributed_init_method is not None: + # # 分布式训练 + # if not args.distributed_no_spawn: + # start_rank = args.distributed_rank + # args.distributed_rank = None # 自动分配 + # kwargs["start_rank"] = start_rank + # # MindSpore多进程启动(模拟torch.multiprocessing.spawn) + # from mindspore import context + # context.set_context(mode=context.GRAPH_MODE) + # # 此处简化处理,实际可根据需要使用mindspore的多进程启动方式 + # for i in range(min(_get_device_num(), args.distributed_world_size)): + # distributed_main(i, main, args, kwargs) + # else: + # distributed_main(int(os.environ["LOCAL_RANK"]), main, args, kwargs) + # else: + # # 单卡训练 + # main(args,** kwargs) + + +def get_rank(group=None): + # 单卡环境固定返回0 + return 0 + # 多卡逻辑(注释掉) + # return get_rank(group=group) + + +def get_world_size(group=None): + # 单卡环境固定返回1 + return 1 + # 多卡逻辑(注释掉) + # if comm.is_initialized(): + # return get_group_size(group=group) + # else: + # return 1 + + +def get_global_group(): + # 单卡环境无通信组,返回None + return None + # 多卡逻辑(注释掉) + # return None # MindSpore默认使用全局组 + + +def get_global_rank(): + # 单卡环境固定返回0 + return 0 + # 多卡逻辑(注释掉) + # if comm.is_initialized(): + # return get_rank() + # else: + # return 0 + + +def get_global_world_size(): + # 单卡环境固定返回1 + return 1 + # 多卡逻辑(注释掉) + # if comm.is_initialized(): + # return get_group_size() + # else: + # return 1 + + +def get_data_parallel_group(): + """单卡环境无数据并行组""" + return None + # 多卡逻辑(注释掉) + # return get_global_group() + + +def get_data_parallel_rank(): + """单卡环境数据并行rank固定为0""" + return 0 + # 多卡逻辑(注释掉) + # return get_rank(get_data_parallel_group()) + + +def get_data_parallel_world_size(): + """单卡环境数据并行世界大小固定为1""" + return 1 + # 多卡逻辑(注释掉) + # return get_world_size(get_data_parallel_group()) + + +def all_reduce(tensor, group, op="sum"): + # 单卡环境无需all-reduce,直接返回原张量 + return tensor + # 多卡逻辑(注释掉) + # if op == "sum": + # op = ALLREDUCE_OP.SUM + # elif op == "max": + # op = ALLREDUCE_OP.MAX + # else: + # raise NotImplementedError + # comm.all_reduce(tensor, op=op, group=group) + # return tensor + + +def broadcast(tensor, src, group): + # 单卡环境无需广播,直接返回原张量 + return tensor + # 多卡逻辑(注释掉) + # comm.broadcast(tensor, src=src, group=group, op=BROADCAST_OP.BROADCAST) + + +def all_to_all(tensor, group): + # 单卡环境无需all-to-all,直接返回原张量 + return tensor + # 多卡逻辑(注释掉) + # """对1D张量执行all-to-all操作""" + # assert tensor.ndim == 1 + # split_count = get_world_size(group=group) + # assert tensor.size [0] % split_count == 0 + # output = ops.zeros_like(tensor) + # 
comm.alltoall(output, tensor, group=group) + # return output + + +def all_gather(tensor, group, return_tensor=False): + # 单卡环境all-gather返回自身(列表或张量) + if return_tensor: + return ops.stack([tensor], axis=0) + else: + return [tensor] + # 多卡逻辑(注释掉) + # """执行all-gather操作""" + # world_size = get_world_size(group=group) + # rank = get_rank(group=group) + # tensor_list = [ + # tensor if i == rank else ops.zeros_like(tensor) for i in range(world_size) + # ] + # comm.all_gather(tensor_list, tensor, group=group) + # if return_tensor: + # return ops.stack(tensor_list, axis=0) + # else: + # return tensor_list + + +def all_gather_list(data, group=None, max_size=16384): + # 单卡环境直接返回包含自身数据的列表 + return [data] + # 多卡逻辑(注释掉) + # """ + # 收集所有节点的任意数据到列表中。 + # 与all_gather类似,但支持任意Python数据。注意数据必须可pickle序列化, + # CUDA张量会被移到CPU并在CPU上返回。 + # """ + # from unicore import utils + + # if group is None: + # group = get_global_group() + # rank = get_rank(group=group) + # world_size = get_world_size(group=group) + + # buffer_size = max_size * world_size + # if ( + # not hasattr(all_gather_list, "_buffer") + # or all_gather_list._buffer.size [0] < buffer_size + # ): + # all_gather_list._buffer = Tensor( + # shape=(buffer_size,), + # dtype=ms.uint8, # 对应PyTorch的ByteTensor + # device=ms.get_context("device_target").lower() + # ) + # all_gather_list._cpu_buffer = Tensor(shape=(max_size,), dtype=ms.uint8).asnumpy() # 主机端缓冲区 + # buffer = all_gather_list._buffer + # buffer.fill(0) + # cpu_buffer = all_gather_list._cpu_buffer + + # data = utils.move_to_cpu(data) + # enc = pickle.dumps(data) + # enc_size = len(enc) + # header_size = 4 # 存储编码数据长度的头部大小 + # size = header_size + enc_size + # if size > max_size: + # raise ValueError( + # "encoded data size ({}) exceeds max_size ({})".format(size, max_size) + # ) + + # header = struct.pack(">I", enc_size) + # cpu_buffer[:size] = list(header + enc) + # start = rank * max_size + # buffer[start : start + size] = Tensor(cpu_buffer[:size], dtype=ms.uint8, device=buffer.device) + + # all_reduce(buffer, group=group) + + # buffer = buffer.asnumpy() + # try: + # result = [] + # for i in range(world_size): + # out_buffer = buffer[i * max_size : (i + 1) * max_size] + # (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist())) + # if enc_size > 0: + # result.append( + # pickle.loads( + # bytes(out_buffer[header_size : header_size + enc_size].tolist()) + # ) + # ) + # return result + # except pickle.UnpicklingError: + # raise Exception( + # "Unable to unpickle data from other workers. all_gather_list requires all " + # "workers to enter the function together, so this error usually indicates " + # "that the workers have fallen out of sync somehow. Workers can fall out of " + # "sync if one of them runs out of memory, or if there are other conditions " + # "in your training script that can cause one worker to finish an epoch " + # "while other workers are still iterating over their portions of the data. " + # "Try rerunning with appropriate DDP backend." 
+ # ) + + +def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]: + # 单卡环境无需all-reduce,直接返回原字典 + return OrderedDict(data) + # 多卡逻辑(注释掉) + # """ + # 对字典中的值进行跨worker的AllReduce。为提高性能, + # 分别对已在设备上的数据和CPU上的数据进行归约。 + + # 参数: + # data (Mapping[str, Any]): 要all-reduce的字典数据(不能是嵌套字典) + # device: 归约使用的设备 + # group: 通信组 + # """ + # data_keys = list(data.keys()) + + # # 分离设备上和CPU上的数据以提高性能 + # cpu_data = OrderedDict() + # device_data = OrderedDict() + # for k in data_keys: + # t = data[k] + # if not isinstance(t, Tensor): + # cpu_data[k] = Tensor(t, dtype=ms.float64) + # elif t.device != device: + # cpu_data[k] = t.astype(ms.float64) + # else: + # device_data[k] = t.astype(ms.float64) + + # def _all_reduce_dict(data: OrderedDict): + # if len(data) == 0: + # return data + # buf = ops.cat([t.reshape(-1) for t in data.values()]).to(device) + # all_reduce(buf, group=group) + # split_sizes = [t.size [0] for t in data.values()] + # split_buf = ops.split(buf, split_sizes) + # reduced_data = [t.reshape(orig.shape) for t, orig in zip(split_buf, data.values())] + # return OrderedDict(zip(data.keys(), reduced_data)) + + # cpu_data = _all_reduce_dict(cpu_data) + # device_data = _all_reduce_dict(device_data) + + # def get_from_stack(key): + # if key in cpu_data: + # return cpu_data[key] + # elif key in device_data: + # return device_data[key] + # raise KeyError + + # return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) + + +@dataclass +class _TensorPlaceholder: + index: int + + +def broadcast_tensors( + tensors: Optional[List[Tensor]], + src_rank: int, + group: object, + dist_device: Optional[str] = None, +) -> List[Tensor]: + # 单卡环境直接返回原张量列表 + return tensors if tensors is not None else [] + # 多卡逻辑(注释掉) + # """ + # 广播张量列表,非源rank无需知道张量的dtype和形状 + # """ + # if dist_device is None: + # if comm.get_backend(group) == "nccl": + # dist_device = "gpu" + # else: + # dist_device = "cpu" + + # # 先共享元数据以简化传输 + # is_src_rank = get_rank(group) == src_rank + # if is_src_rank: + # metadata = [ + # {"size": t.shape, "dtype": t.dtype, "device": t.device} for t in tensors + # ] + # metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device) + # else: + # metadata = _broadcast_object_slow(None, src_rank, group, dist_device) + + # out_tensors = [] + # for i, meta in enumerate(metadata): + # if is_src_rank: + # tensor = tensors[i] + # broadcast(tensor.to(dist_device), src=src_rank, group=group) + # else: + # tensor = ops.zeros( + # [meta["size"][0]], dtype=meta["dtype"], device=dist_device + # ) + # broadcast(tensor, src=src_rank, group=group) + # tensor = tensor.reshape(meta["size"]).to(meta["device"]) + # out_tensors.append(tensor) + # return out_tensors + + +def broadcast_object( + obj: Any, + src_rank: int, + group: object, + dist_device: Optional[str] = None, +) -> Any: + # 单卡环境直接返回原对象 + return obj + # 多卡逻辑(注释掉) + # """向其他worker广播任意Python对象""" + # if dist_device is None: + # if comm.get_backend(group) == "nccl": + # dist_device = "gpu" + # else: + # dist_device = "cpu" + + # if get_rank(group) == src_rank: + # # 分离张量和非张量数据,以便直接广播张量,避免不必要的序列化 + # tensors = [] + # obj = _split_tensors_from_obj(obj, tensors) + # obj = _broadcast_object_slow(obj, src_rank, group, dist_device) + # tensors = broadcast_tensors(tensors, src_rank, group, dist_device) + # else: + # obj = _broadcast_object_slow(None, src_rank, group, dist_device) + # tensors = broadcast_tensors(None, src_rank, group, dist_device) + # return _put_tensors_in_obj(obj, tensors) + + +def _broadcast_object_slow( + obj: Any, + 
src_rank: int, + group: object, + dist_device: str, +) -> Any: + # 单卡环境直接返回原对象 + return obj + # 多卡逻辑(注释掉) + # if get_rank(group) == src_rank: + # # 发送数据 + # buffer = io.BytesIO() + # ms.save(obj, buffer) # MindSpore的保存函数 + # buffer = Tensor(list(buffer.getbuffer()), dtype=ms.uint8, device=dist_device) + # length = Tensor([len(buffer)], dtype=ms.int64, device=dist_device) + # broadcast(length, src=src_rank, group=group) + # broadcast(buffer, src=src_rank, group=group) + # else: + # # 从源获取数据 + # length = Tensor([0], dtype=ms.int64, device=dist_device) + # broadcast(length, src=src_rank, group=group) + # buffer = Tensor([0] * int(length.item()), dtype=ms.uint8, device=dist_device) + # broadcast(buffer, src=src_rank, group=group) + # buffer = io.BytesIO(bytes(buffer.asnumpy().tolist())) + # obj = ms.load(buffer, map_location="cpu") # MindSpore的加载函数 + # return obj + + +def _split_tensors_from_obj(obj: Any, tensors: List[Tensor]) -> Any: + # 单卡环境无需分离张量(直接返回原对象) + return obj + # 多卡逻辑(注释掉) + # if isinstance(obj, Tensor): + # placeholder = _TensorPlaceholder(index=len(tensors)) + # tensors.append(obj) + # return placeholder + # elif isinstance(obj, dict): + # return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()} + # elif isinstance(obj, list): + # return [_split_tensors_from_obj(v, tensors) for v in obj] + # elif isinstance(obj, tuple): + # return tuple(_split_tensors_from_obj(v, tensors) for v in obj) + # elif isinstance(obj, set): + # return {_split_tensors_from_obj(v, tensors) for v in obj} + # else: + # return obj + + +def _put_tensors_in_obj(obj: Any, tensors: List[Tensor]) -> Any: + # 单卡环境无需恢复张量(直接返回原对象) + return obj + # 多卡逻辑(注释掉) + # if isinstance(obj, _TensorPlaceholder): + # return tensors[obj.index] + # elif isinstance(obj, dict): + # return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()} + # elif isinstance(obj, list): + # return [_put_tensors_in_obj(v, tensors) for v in obj] + # elif isinstance(obj, tuple): + # return tuple(_put_tensors_in_obj(v, tensors) for v in obj) + # elif isinstance(obj, set): + # return {_put_tensors_in_obj(v, tensors) for v in obj} + # else: + # return obj + + + +# import datetime +# import io +# import logging +# import os +# import pickle +# import random +# import socket +# import struct +# import subprocess +# import warnings +# from collections import OrderedDict +# from typing import Any, Dict, List, Mapping, Optional +# from dataclasses import dataclass + +# import torch +# import torch.distributed as dist + + +# logger = logging.getLogger(__name__) + + +# def is_master(args): +# return args.distributed_rank == 0 + + +# def infer_init_method(args, force_distributed=False): +# if args.distributed_init_method is not None: +# return + +# if all( +# key in os.environ +# for key in ["MASTER_ADDR", "MASTER_PORT", "WORLD_SIZE", "RANK"] +# ): +# # support torch.distributed.launch +# _infer_torch_distributed_launch_init(args) +# elif args.distributed_port > 0: +# # we can determine the init method automatically for Slurm +# _infer_slurm_init(args) +# elif args.distributed_world_size > 1 or force_distributed: +# # fallback for single node with multiple GPUs +# _infer_single_node_init(args) + +# elif not args.distributed_no_spawn: +# args.distributed_num_procs = min( +# torch.cuda.device_count(), args.distributed_world_size +# ) + + +# def _infer_torch_distributed_launch_init(args): +# args.distributed_init_method = "env://" +# args.distributed_world_size = int(os.environ["WORLD_SIZE"]) +# args.distributed_rank = 
int(os.environ["RANK"]) +# # processes are created by torch.distributed.launch +# args.distributed_no_spawn = True + + +# def _infer_slurm_init(args): +# node_list = os.environ.get("SLURM_STEP_NODELIST") +# if node_list is None: +# node_list = os.environ.get("SLURM_JOB_NODELIST") +# if node_list is not None: +# try: +# hostnames = subprocess.check_output( +# ["scontrol", "show", "hostnames", node_list] +# ) +# args.distributed_init_method = "tcp://{host}:{port}".format( +# host=hostnames.split()[0].decode("utf-8"), +# port=args.distributed_port, +# ) +# nnodes = int(os.environ.get("SLURM_NNODES")) +# ntasks_per_node = os.environ.get("SLURM_NTASKS_PER_NODE") +# if ntasks_per_node is not None: +# ntasks_per_node = int(ntasks_per_node) +# else: +# ntasks = int(os.environ.get("SLURM_NTASKS")) +# nnodes = int(os.environ.get("SLURM_NNODES")) +# assert ntasks % nnodes == 0 +# ntasks_per_node = int(ntasks / nnodes) +# if ntasks_per_node == 1: +# gpus_per_node = torch.cuda.device_count() +# node_id = int(os.environ.get("SLURM_NODEID")) +# args.distributed_rank = node_id * gpus_per_node +# args.distributed_world_size = nnodes * gpus_per_node +# else: +# assert ntasks_per_node == args.distributed_world_size // nnodes +# args.distributed_no_spawn = True +# args.distributed_rank = int(os.environ.get("SLURM_PROCID")) +# args.device_id = int(os.environ.get("SLURM_LOCALID")) +# except subprocess.CalledProcessError as e: # scontrol failed +# raise e +# except FileNotFoundError: # Slurm is not installed +# pass + + +# def _infer_single_node_init(args): +# assert ( +# args.distributed_world_size <= torch.cuda.device_count() +# ), f"world size is {args.distributed_world_size} but have {torch.cuda.device_count()} available devices" +# port = random.randint(10000, 20000) +# args.distributed_init_method = "tcp://localhost:{port}".format(port=port) + + +# def distributed_init(args): +# if torch.distributed.is_available() and torch.distributed.is_initialized(): +# warnings.warn("Distributed is already initialized, cannot initialize twice!") +# else: +# logger.info( +# "distributed init (rank {}): {}".format( +# args.distributed_rank, +# args.distributed_init_method, +# ) +# ) +# dist.init_process_group( +# backend=args.distributed_backend, +# init_method=args.distributed_init_method, +# world_size=args.distributed_world_size, +# rank=args.distributed_rank, +# timeout=datetime.timedelta(seconds=90), +# ) +# logger.info( +# "initialized host {} as rank {}".format( +# socket.gethostname(), +# args.distributed_rank, +# ) +# ) + +# # perform a dummy all-reduce to initialize the NCCL communicator +# if torch.cuda.is_available(): +# dist.all_reduce(torch.zeros(1).cuda()) + +# args.distributed_rank = torch.distributed.get_rank() + +# if is_master(args): +# logging.getLogger().setLevel(logging.INFO) +# else: +# logging.getLogger().setLevel(logging.WARNING) + +# return args.distributed_rank + + +# def distributed_main(i, main, args, kwargs): +# args.device_id = i +# if torch.cuda.is_available() and not args.cpu: +# torch.cuda.set_device(args.device_id) +# if args.distributed_rank is None: # torch.multiprocessing.spawn +# args.distributed_rank = kwargs.pop("start_rank", 0) + i + +# args.distributed_rank = distributed_init(args) + +# after_distributed_init_fn = kwargs.pop("after_distributed_init_fn", None) +# if after_distributed_init_fn: +# args = after_distributed_init_fn(args) + +# main(args, **kwargs) + +# if torch.distributed.is_initialized(): +# torch.distributed.barrier(get_global_group()) + + +# def call_main(args, 
main, **kwargs): +# if args.distributed_init_method is None: +# infer_init_method(args) + +# if args.distributed_init_method is not None: +# # distributed training +# if not args.distributed_no_spawn: +# start_rank = args.distributed_rank +# args.distributed_rank = None # assign automatically +# kwargs["start_rank"] = start_rank +# torch.multiprocessing.spawn( +# fn=distributed_main, +# args=(main, args, kwargs), +# nprocs=min( +# torch.cuda.device_count(), +# args.distributed_world_size, +# ), +# join=True, +# ) +# else: +# distributed_main(int(os.environ["LOCAL_RANK"]), main, args, kwargs) +# else: +# # single GPU main +# main(args, **kwargs) + + +# def get_rank(group): +# return dist.get_rank(group=group) + + +# def get_world_size(group): +# if torch.distributed.is_initialized(): +# return dist.get_world_size(group=group) +# else: +# return 1 + + +# def get_global_group(): +# return None + + +# def get_global_rank(): +# if torch.distributed.is_initialized(): +# return torch.distributed.get_rank() +# else: +# return 0 + + +# def get_global_world_size(): +# if torch.distributed.is_initialized(): +# return torch.distributed.get_world_size() +# else: +# return 1 + + +# def get_data_parallel_group(): +# """Get the data parallel group the caller rank belongs to.""" +# return get_global_group() + + +# def get_data_parallel_rank(): +# """Return my rank for the data parallel group.""" +# return get_rank(get_data_parallel_group()) + + +# def get_data_parallel_world_size(): +# """Return world size for the data parallel group.""" +# return get_world_size(get_data_parallel_group()) + + +# def all_reduce(tensor, group, op="sum"): +# if op == "sum": +# op = dist.ReduceOp.SUM +# elif op == "max": +# op = dist.ReduceOp.MAX +# else: +# raise NotImplementedError +# dist.all_reduce(tensor, op=op, group=group) +# return tensor + + +# def broadcast(tensor, src, group): +# dist.broadcast(tensor, src=src, group=group) + + +# def all_to_all(tensor, group): +# """Perform an all-to-all operation on a 1D Tensor.""" +# assert tensor.dim() == 1 +# split_count = get_world_size(group=group) +# assert tensor.numel() % split_count == 0 +# output = torch.zeros_like(tensor) +# dist.all_to_all_single(output, tensor, group=group) +# return output + + +# def all_gather(tensor, group, return_tensor=False): +# """Perform an all-gather operation.""" +# world_size = get_world_size(group=group) +# rank = get_rank(group=group) +# tensor_list = [ +# tensor if i == rank else torch.empty_like(tensor) for i in range(world_size) +# ] +# dist.all_gather(tensor_list, tensor, group=group) +# if return_tensor: +# return torch.stack(tensor_list, dim=0) +# else: +# return tensor_list + + +# def all_gather_list(data, group=None, max_size=16384): +# """Gathers arbitrary data from all nodes into a list. + +# Similar to :func:`~torch.distributed.all_gather` but for arbitrary Python +# data. Note that *data* must be picklable and any CUDA tensors will be moved +# to CPU and returned on CPU as well. 
+ +# Args: +# data (Any): data from the local worker to be gathered on other workers +# group: group of the collective +# max_size (int, optional): maximum size of the data to be gathered +# across workers +# """ +# from unicore import utils + +# if group is None: +# group = get_global_group() +# rank = get_rank(group=group) +# world_size = get_world_size(group=group) + +# buffer_size = max_size * world_size +# if ( +# not hasattr(all_gather_list, "_buffer") +# or all_gather_list._buffer.numel() < buffer_size +# ): +# all_gather_list._buffer = torch.tensor( +# data=[0] * buffer_size, # Initialize with zeros +# dtype=torch.uint8, # Byte tensor corresponds to uint8 +# device="cuda", # Specify the device as CUDA +# ) +# all_gather_list._cpu_buffer = torch.ByteTensor(max_size).pin_memory() +# buffer = all_gather_list._buffer +# buffer.zero_() +# cpu_buffer = all_gather_list._cpu_buffer + +# data = utils.move_to_cpu(data) +# enc = pickle.dumps(data) +# enc_size = len(enc) +# header_size = 4 # size of header that contains the length of the encoded data +# size = header_size + enc_size +# if size > max_size: +# raise ValueError( +# "encoded data size ({}) exceeds max_size ({})".format(size, max_size) +# ) + +# header = struct.pack(">I", enc_size) +# cpu_buffer[:size] = torch.ByteTensor(list(header + enc)) +# start = rank * max_size +# buffer[start : start + size].copy_(cpu_buffer[:size]) + +# all_reduce(buffer, group=group) + +# buffer = buffer.cpu() +# try: +# result = [] +# for i in range(world_size): +# out_buffer = buffer[i * max_size : (i + 1) * max_size] +# (enc_size,) = struct.unpack(">I", bytes(out_buffer[:header_size].tolist())) +# if enc_size > 0: +# result.append( +# pickle.loads( +# bytes(out_buffer[header_size : header_size + enc_size].tolist()) +# ) +# ) +# return result +# except pickle.UnpicklingError: +# raise Exception( +# "Unable to unpickle data from other workers. all_gather_list requires all " +# "workers to enter the function together, so this error usually indicates " +# "that the workers have fallen out of sync somehow. Workers can fall out of " +# "sync if one of them runs out of memory, or if there are other conditions " +# "in your training script that can cause one worker to finish an epoch " +# "while other workers are still iterating over their portions of the data. " +# "Try rerunning with --ddp-backend=legacy_ddp and see if that helps." +# ) + + +# def all_reduce_dict(data: Mapping[str, Any], device, group) -> Dict[str, Any]: +# """ +# AllReduce a dictionary of values across workers. We separately +# reduce items that are already on the device and items on CPU for +# better performance. + +# Args: +# data (Mapping[str, Any]): dictionary of data to all-reduce, but +# cannot be a nested dictionary +# device (torch.device): device for the reduction +# group: group of the collective +# """ +# data_keys = list(data.keys()) + +# # We want to separately reduce items that are already on the +# # device and items on CPU for performance reasons. 
+# cpu_data = OrderedDict() +# device_data = OrderedDict() +# for k in data_keys: +# t = data[k] +# if not torch.is_tensor(t): +# cpu_data[k] = torch.tensor(t, dtype=torch.double) +# elif t.device.type != device.type: +# cpu_data[k] = t.to(dtype=torch.double) +# else: +# device_data[k] = t.to(dtype=torch.double) + +# def _all_reduce_dict(data: OrderedDict): +# if len(data) == 0: +# return data +# buf = torch.cat([t.view(-1) for t in data.values()]).to(device=device) +# all_reduce(buf, group=group) +# split_buf = torch.split(buf, [t.numel() for t in data.values()]) +# reduced_data = [t.view_as(orig) for t, orig in zip(split_buf, data.values())] +# return OrderedDict(zip(data.keys(), reduced_data)) + +# cpu_data = _all_reduce_dict(cpu_data) +# device_data = _all_reduce_dict(device_data) + +# def get_from_stack(key): +# if key in cpu_data: +# return cpu_data[key] +# elif key in device_data: +# return device_data[key] +# raise KeyError + +# return OrderedDict([(key, get_from_stack(key)) for key in data_keys]) + + +# @dataclass +# class _TensorPlaceholder: +# index: int + + +# def broadcast_tensors( +# tensors: Optional[List[torch.Tensor]], +# src_rank: int, +# group: object, +# dist_device: Optional[torch.device] = None, +# ) -> List[torch.Tensor]: +# """ +# Broadcasts a list of tensors without other (non-src) ranks needing to know +# the dtypes/shapes of the tensors. +# """ +# if dist_device is None: +# if torch.distributed.get_backend(group) == "nccl": +# dist_device = torch.device("cuda") +# else: +# dist_device = torch.device("cpu") + +# # share metadata first to simplify transfer +# is_src_rank = get_rank(group) == src_rank +# if is_src_rank: +# metadata = [ +# {"size": t.size(), "dtype": t.dtype, "device": t.device} for t in tensors +# ] +# metadata = _broadcast_object_slow(metadata, src_rank, group, dist_device) +# else: +# metadata = _broadcast_object_slow(None, src_rank, group, dist_device) + +# out_tensors = [] +# for i, meta in enumerate(metadata): +# if is_src_rank: +# tensor = tensors[i] +# broadcast(tensors[i].to(dist_device), src=src_rank, group=group) +# else: +# tensor = torch.zeros( +# [meta["size"].numel()], dtype=meta["dtype"], device=dist_device +# ) +# broadcast(tensor, src=src_rank, group=group) +# tensor = tensor.view(meta["size"]).to(meta["device"]) +# out_tensors.append(tensor) +# return out_tensors + + +# def broadcast_object( +# obj: Any, +# src_rank: int, +# group: object, +# dist_device: Optional[torch.device] = None, +# ) -> Any: +# """Broadcast an arbitrary Python object to other workers.""" +# if dist_device is None: +# if torch.distributed.get_backend(group) == "nccl": +# dist_device = torch.device("cuda") +# else: +# dist_device = torch.device("cpu") + +# if get_rank(group) == src_rank: +# # split the tensors from the non-tensors so we can broadcast them +# # directly, avoiding unnecessary serialization/deserialization +# tensors = [] +# obj = _split_tensors_from_obj(obj, tensors) +# obj = _broadcast_object_slow(obj, src_rank, group, dist_device) +# tensors = broadcast_tensors(tensors, src_rank, group, dist_device) +# else: +# obj = _broadcast_object_slow(None, src_rank, group, dist_device) +# tensors = broadcast_tensors(None, src_rank, group, dist_device) +# return _put_tensors_in_obj(obj, tensors) + + +# def _broadcast_object_slow( +# obj: Any, +# src_rank: int, +# group: object, +# dist_device: torch.device, +# ) -> Any: +# if get_rank(group) == src_rank: +# # Emit data +# buffer = io.BytesIO() +# torch.save(obj, buffer) +# buffer = 
torch.ByteTensor(buffer.getbuffer()).to(dist_device) +# length = torch.LongTensor([len(buffer)]).to(dist_device) +# broadcast(length, src=src_rank, group=group) +# broadcast(buffer, src=src_rank, group=group) +# else: +# # Fetch from the source +# length = torch.LongTensor([0]).to(dist_device) +# broadcast(length, src=src_rank, group=group) +# buffer = torch.ByteTensor(int(length.item())).to(dist_device) +# broadcast(buffer, src=src_rank, group=group) +# buffer = io.BytesIO(buffer.cpu().numpy()) +# obj = torch.load(buffer, map_location="cpu", weights_only=False) +# return obj + + +# def _split_tensors_from_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: +# if torch.is_tensor(obj): +# placeholder = _TensorPlaceholder(index=len(tensors)) +# tensors.append(obj) +# return placeholder +# elif isinstance(obj, dict): +# return {k: _split_tensors_from_obj(v, tensors) for k, v in obj.items()} +# elif isinstance(obj, list): +# return [_split_tensors_from_obj(v, tensors) for v in obj] +# elif isinstance(obj, tuple): +# return tuple(_split_tensors_from_obj(v, tensors) for v in obj) +# elif isinstance(obj, set): +# return {_split_tensors_from_obj(v, tensors) for v in obj} +# else: +# return obj + + +# def _put_tensors_in_obj(obj: Any, tensors: List[torch.Tensor]) -> Any: +# if isinstance(obj, _TensorPlaceholder): +# return tensors[obj.index] +# elif isinstance(obj, dict): +# return {k: _put_tensors_in_obj(v, tensors) for k, v in obj.items()} +# elif isinstance(obj, list): +# return [_put_tensors_in_obj(v, tensors) for v in obj] +# elif isinstance(obj, tuple): +# return tuple(_put_tensors_in_obj(v, tensors) for v in obj) +# elif isinstance(obj, set): +# return {_put_tensors_in_obj(v, tensors) for v in obj} +# else: +# return obj diff --git a/MindChemistry/applications/Uni-Mol/unicore/ema.py b/MindChemistry/applications/Uni-Mol/unicore/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..14c407f8d4c03efd2f752f006381fd7e2fd69257 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/ema.py @@ -0,0 +1,136 @@ +from copy import deepcopy +from unicore.optim.fp16_optimizer import separate_decay_params, flatten_parameters_fp32 # 假设已适配MindSpore版本 +import mindspore as ms + + +class ExponentialMovingAverageModel: + def __init__(self, args, model, decay, is_flattened=False): + self.args = args + self.model_ema = deepcopy(model) # MindSpore模型支持deepcopy + self.decay = decay + self.is_flattened = is_flattened + if not is_flattened: + self.name2param = self.get_name2param() + else: + self.flatten_params = self.flatten_parameters() + + def get_name2param(self): + name2param = dict() + for n, p in self.model_ema.named_parameters(): + name2param[n] = p + # MindSpore中转换为float类型(32位浮点数) + p.data = p.data.astype(ms.float32) + p.grad = None # 清除梯度 + return name2param + + def flatten_parameters(self): + # 假设separate_decay_params已适配MindSpore参数格式 + param_group = separate_decay_params( + self.args, self.model_ema.named_parameters() + ) + flatten_group = [] + for param_dict in param_group: + params = param_dict["params"] + # 假设flatten_parameters_fp32已适配MindSpore参数 + flatten_param = flatten_parameters_fp32( + params, set_to_param=True, set_grad=False + ) + flatten_group.append(flatten_param) + return flatten_group + + def update_one_param(self, ema_param, new_param): + # 指数移动平均更新公式:ema = ema * decay + new * (1 - decay) + # 等价于:ema -= (ema - new) * (1 - decay) + diff = ema_param - new_param + diff *= 1 - self.decay + ema_param -= diff + + def update(self, new_param): + # 
MindSpore中使用no_grad上下文管理器禁用梯度计算 + with ms.no_grad(): + if self.is_flattened: + for i in range(len(self.flatten_params)): + self.update_one_param( + self.flatten_params[i], new_param[i]["params"][0] + ) + else: + for n, p in new_param: + if n in self.name2param: + self.update_one_param(self.name2param[n], p) + + def load_state_dict(self, state_dict): + # MindSpore模型加载状态字典 + self.model_ema.load_state_dict(state_dict["params"]) + self.decay = state_dict["decay"] if "decay" in state_dict else self.decay + + def state_dict(self): + # 返回MindSpore模型的状态字典 + return { + "params": self.model_ema.state_dict(), + "decay": self.decay, + } +# from copy import deepcopy +# from unicore.optim.fp16_optimizer import separate_decay_params, flatten_parameters_fp32 +# import torch + + +# class ExponentialMovingAverageModel: +# def __init__(self, args, model, decay, is_flattened=False): +# self.args = args +# self.model_ema = deepcopy(model) +# self.decay = decay +# self.is_flattened = is_flattened +# if not is_flattened: +# self.name2param = self.get_name2param() +# else: +# self.flatten_params = self.flatten_parameters() + +# def get_name2param(self): +# name2param = dict() +# for n, p in self.model_ema.named_parameters(): +# name2param[n] = p +# # use float type for ema +# p.data = p.data.float() +# p.grad = None +# return name2param + +# def flatten_parameters(self): +# param_group = separate_decay_params( +# self.args, self.model_ema.named_parameters() +# ) +# flatten_group = [] +# for param_dict in param_group: +# params = param_dict["params"] +# flatten_param = flatten_parameters_fp32( +# params, set_to_param=True, set_grad=False +# ) +# flatten_group.append(flatten_param) +# return flatten_group + +# def update_one_param(self, ema_param, new_param): +# diff = ema_param - new_param +# diff *= 1 - self.decay +# ema_param -= diff + +# def update(self, new_param): +# if self.is_flattened: +# with torch.no_grad(): +# for i in range(len(self.flatten_params)): +# self.update_one_param( +# self.flatten_params[i], new_param[i]["params"][0] +# ) +# else: +# with torch.no_grad(): +# for n, p in new_param: +# if n in self.name2param: +# self.update_one_param(self.name2param[n], p) + +# def load_state_dict(self, state_dict): +# self.model_ema.load_state_dict(state_dict["params"]) +# self.decay = state_dict["decay"] if "decay" in state_dict else self.decay + +# def state_dict(self): +# return { +# "params": self.model_ema.state_dict(), +# "decay": self.decay, +# } diff --git a/MindChemistry/applications/Uni-Mol/unicore/logging/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/logging/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindChemistry/applications/Uni-Mol/unicore/logging/meters.py b/MindChemistry/applications/Uni-Mol/unicore/logging/meters.py new file mode 100644 index 0000000000000000000000000000000000000000..8480b388a5d5fcdac043a5da155f1d0e54298b22 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/logging/meters.py @@ -0,0 +1,577 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
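+
+# Usage sketch (illustrative, not part of the original file): the meters below
+# are plain Python accumulators that also accept NumPy scalars and MindSpore
+# tensors via type_as()/safe_round(), e.g.
+#
+#     meter = AverageMeter(round=3)
+#     meter.update(0.5, n=4)      # sum = 2.0, count = 4
+#     meter.update(1.0)           # sum = 3.0, count = 5
+#     meter.avg                   # -> 0.6
+#
+# MetersDict keeps such meters ordered by the priority passed to add_meter().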
+import bisect +import time +from collections import OrderedDict +from typing import Dict, Optional + + +try: + import mindspore as ms + + def type_as(a, b): + if ms.is_tensor(a) and ms.is_tensor(b): + return a.astype(b.dtype) + else: + return a + + +except ImportError: + ms = None + + def type_as(a, b): + return a + + +try: + import numpy as np +except ImportError: + np = None + + +class Meter(object): + """Base class for Meters.""" + + def __init__(self): + pass + + def state_dict(self): + return {} + + def load_state_dict(self, state_dict): + pass + + def reset(self): + raise NotImplementedError + + @property + def smoothed_value(self) -> float: + """Smoothed value used for logging.""" + raise NotImplementedError + + +def safe_round(number, ndigits): + if hasattr(number, "__round__"): + return round(number, ndigits) + elif ms is not None and ms.is_tensor(number) and number.size == 1: + return safe_round(number.item(), ndigits) + elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): + return safe_round(number.item(), ndigits) + else: + return number + + +class AverageMeter(Meter): + """Computes and stores the average and current value""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.reset() + + def reset(self): + self.val = None # most recent update + self.sum = 0 # sum from all updates + self.count = 0 # total n from all updates + + def update(self, val, n=1): + if val is not None: + self.val = val + if n > 0: + self.sum = type_as(self.sum, val) + (val * n) + self.count = type_as(self.count, n) + n + + def state_dict(self): + return { + "val": self.val, + "sum": self.sum, + "count": self.count, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.val = state_dict["val"] + self.sum = state_dict["sum"] + self.count = state_dict["count"] + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.count if self.count > 0 else self.val + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class TimeMeter(Meter): + """Computes the average occurrence of some event per second""" + + def __init__( + self, + init: int = 0, + n: int = 0, + round: Optional[int] = None, + ): + self.round = round + self.reset(init, n) + + def reset(self, init=0, n=0): + self.init = init + self.start = time.perf_counter() + self.n = n + self.i = 0 + + def update(self, val=1): + self.n = type_as(self.n, val) + val + self.i += 1 + + def state_dict(self): + return { + "init": self.elapsed_time, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + if "start" in state_dict: + # backwards compatibility for old state_dicts + self.reset(init=state_dict["init"]) + else: + self.reset(init=state_dict["init"], n=state_dict["n"]) + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.n / self.elapsed_time + + @property + def elapsed_time(self): + return self.init + (time.perf_counter() - self.start) + + @property + def smoothed_value(self) -> float: + val = self.avg + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class StopwatchMeter(Meter): + """Computes the sum/avg duration of some event in seconds""" + + def __init__(self, round: Optional[int] = None): + self.round = round + self.sum = 0 + self.n = 0 + self.start_time = None + + def start(self): + self.start_time = 
time.perf_counter() + + def stop(self, n=1, prehook=None): + if self.start_time is not None: + if prehook is not None: + prehook() + delta = time.perf_counter() - self.start_time + self.sum = self.sum + delta + self.n = type_as(self.n, n) + n + + def reset(self): + self.sum = 0 # cumulative time during which stopwatch was active + self.n = 0 # total n across all start/stop + self.start() + + def state_dict(self): + return { + "sum": self.sum, + "n": self.n, + "round": self.round, + } + + def load_state_dict(self, state_dict): + self.sum = state_dict["sum"] + self.n = state_dict["n"] + self.start_time = None + self.round = state_dict.get("round", None) + + @property + def avg(self): + return self.sum / self.n if self.n > 0 else self.sum + + @property + def elapsed_time(self): + if self.start_time is None: + return 0.0 + return time.perf_counter() - self.start_time + + @property + def smoothed_value(self) -> float: + val = self.avg if self.sum > 0 else self.elapsed_time + if self.round is not None and val is not None: + val = safe_round(val, self.round) + return val + + +class MetersDict(OrderedDict): + """A sorted dictionary of :class:`Meters`. + + Meters are sorted according to a priority that is given when the + meter is first added to the dictionary. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.priorities = [] + + def __setitem__(self, key, value): + assert key not in self, "MetersDict doesn't support reassignment" + priority, value = value + bisect.insort(self.priorities, (priority, len(self.priorities), key)) + super().__setitem__(key, value) + for _, _, key in self.priorities: # reorder dict to match priorities + self.move_to_end(key) + + def add_meter(self, key, meter, priority): + self.__setitem__(key, (priority, meter)) + + def state_dict(self): + return [ + (pri, key, self[key].__class__.__name__, self[key].state_dict()) + for pri, _, key in self.priorities + # can't serialize DerivedMeter instances + if not isinstance(self[key], MetersDict._DerivedMeter) + ] + + def load_state_dict(self, state_dict): + self.clear() + self.priorities.clear() + for pri, key, meter_cls, meter_state in state_dict: + meter = globals()[meter_cls]() + meter.load_state_dict(meter_state) + self.add_meter(key, meter, pri) + + def get_smoothed_value(self, key: str) -> float: + """Get a single smoothed value.""" + meter = self[key] + if isinstance(meter, MetersDict._DerivedMeter): + return meter.fn(self) + else: + return meter.smoothed_value + + def get_smoothed_values(self) -> Dict[str, float]: + """Get all smoothed values.""" + return OrderedDict( + [ + (key, self.get_smoothed_value(key)) + for key in self.keys() + if not key.startswith("_") + ] + ) + + def reset(self): + """Reset Meter instances.""" + for meter in self.values(): + if isinstance(meter, MetersDict._DerivedMeter): + continue + meter.reset() + + class _DerivedMeter(Meter): + """A Meter whose values are derived from other Meters.""" + + def __init__(self, fn): + self.fn = fn + + def reset(self): + pass +# import bisect +# import time +# from collections import OrderedDict +# from typing import Dict, Optional + + +# try: +# import torch + +# def type_as(a, b): +# if torch.is_tensor(a) and torch.is_tensor(b): +# return a.to(b) +# else: +# return a + + +# except ImportError: +# torch = None + +# def type_as(a, b): +# return a + + +# try: +# import numpy as np +# except ImportError: +# np = None + + +# class Meter(object): +# """Base class for Meters.""" + +# def __init__(self): +# pass + +# def 
state_dict(self): +# return {} + +# def load_state_dict(self, state_dict): +# pass + +# def reset(self): +# raise NotImplementedError + +# @property +# def smoothed_value(self) -> float: +# """Smoothed value used for logging.""" +# raise NotImplementedError + + +# def safe_round(number, ndigits): +# if hasattr(number, "__round__"): +# return round(number, ndigits) +# elif torch is not None and torch.is_tensor(number) and number.numel() == 1: +# return safe_round(number.item(), ndigits) +# elif np is not None and np.ndim(number) == 0 and hasattr(number, "item"): +# return safe_round(number.item(), ndigits) +# else: +# return number + + +# class AverageMeter(Meter): +# """Computes and stores the average and current value""" + +# def __init__(self, round: Optional[int] = None): +# self.round = round +# self.reset() + +# def reset(self): +# self.val = None # most recent update +# self.sum = 0 # sum from all updates +# self.count = 0 # total n from all updates + +# def update(self, val, n=1): +# if val is not None: +# self.val = val +# if n > 0: +# self.sum = type_as(self.sum, val) + (val * n) +# self.count = type_as(self.count, n) + n + +# def state_dict(self): +# return { +# "val": self.val, +# "sum": self.sum, +# "count": self.count, +# "round": self.round, +# } + +# def load_state_dict(self, state_dict): +# self.val = state_dict["val"] +# self.sum = state_dict["sum"] +# self.count = state_dict["count"] +# self.round = state_dict.get("round", None) + +# @property +# def avg(self): +# return self.sum / self.count if self.count > 0 else self.val + +# @property +# def smoothed_value(self) -> float: +# val = self.avg +# if self.round is not None and val is not None: +# val = safe_round(val, self.round) +# return val + + +# class TimeMeter(Meter): +# """Computes the average occurrence of some event per second""" + +# def __init__( +# self, +# init: int = 0, +# n: int = 0, +# round: Optional[int] = None, +# ): +# self.round = round +# self.reset(init, n) + +# def reset(self, init=0, n=0): +# self.init = init +# self.start = time.perf_counter() +# self.n = n +# self.i = 0 + +# def update(self, val=1): +# self.n = type_as(self.n, val) + val +# self.i += 1 + +# def state_dict(self): +# return { +# "init": self.elapsed_time, +# "n": self.n, +# "round": self.round, +# } + +# def load_state_dict(self, state_dict): +# if "start" in state_dict: +# # backwards compatibility for old state_dicts +# self.reset(init=state_dict["init"]) +# else: +# self.reset(init=state_dict["init"], n=state_dict["n"]) +# self.round = state_dict.get("round", None) + +# @property +# def avg(self): +# return self.n / self.elapsed_time + +# @property +# def elapsed_time(self): +# return self.init + (time.perf_counter() - self.start) + +# @property +# def smoothed_value(self) -> float: +# val = self.avg +# if self.round is not None and val is not None: +# val = safe_round(val, self.round) +# return val + + +# class StopwatchMeter(Meter): +# """Computes the sum/avg duration of some event in seconds""" + +# def __init__(self, round: Optional[int] = None): +# self.round = round +# self.sum = 0 +# self.n = 0 +# self.start_time = None + +# def start(self): +# self.start_time = time.perf_counter() + +# def stop(self, n=1, prehook=None): +# if self.start_time is not None: +# if prehook is not None: +# prehook() +# delta = time.perf_counter() - self.start_time +# self.sum = self.sum + delta +# self.n = type_as(self.n, n) + n + +# def reset(self): +# self.sum = 0 # cumulative time during which stopwatch was active +# self.n = 0 # total n 
across all start/stop +# self.start() + +# def state_dict(self): +# return { +# "sum": self.sum, +# "n": self.n, +# "round": self.round, +# } + +# def load_state_dict(self, state_dict): +# self.sum = state_dict["sum"] +# self.n = state_dict["n"] +# self.start_time = None +# self.round = state_dict.get("round", None) + +# @property +# def avg(self): +# return self.sum / self.n if self.n > 0 else self.sum + +# @property +# def elapsed_time(self): +# if self.start_time is None: +# return 0.0 +# return time.perf_counter() - self.start_time + +# @property +# def smoothed_value(self) -> float: +# val = self.avg if self.sum > 0 else self.elapsed_time +# if self.round is not None and val is not None: +# val = safe_round(val, self.round) +# return val + + +# class MetersDict(OrderedDict): +# """A sorted dictionary of :class:`Meters`. + +# Meters are sorted according to a priority that is given when the +# meter is first added to the dictionary. +# """ + +# def __init__(self, *args, **kwargs): +# super().__init__(*args, **kwargs) +# self.priorities = [] + +# def __setitem__(self, key, value): +# assert key not in self, "MetersDict doesn't support reassignment" +# priority, value = value +# bisect.insort(self.priorities, (priority, len(self.priorities), key)) +# super().__setitem__(key, value) +# for _, _, key in self.priorities: # reorder dict to match priorities +# self.move_to_end(key) + +# def add_meter(self, key, meter, priority): +# self.__setitem__(key, (priority, meter)) + +# def state_dict(self): +# return [ +# (pri, key, self[key].__class__.__name__, self[key].state_dict()) +# for pri, _, key in self.priorities +# # can't serialize DerivedMeter instances +# if not isinstance(self[key], MetersDict._DerivedMeter) +# ] + +# def load_state_dict(self, state_dict): +# self.clear() +# self.priorities.clear() +# for pri, key, meter_cls, meter_state in state_dict: +# meter = globals()[meter_cls]() +# meter.load_state_dict(meter_state) +# self.add_meter(key, meter, pri) + +# def get_smoothed_value(self, key: str) -> float: +# """Get a single smoothed value.""" +# meter = self[key] +# if isinstance(meter, MetersDict._DerivedMeter): +# return meter.fn(self) +# else: +# return meter.smoothed_value + +# def get_smoothed_values(self) -> Dict[str, float]: +# """Get all smoothed values.""" +# return OrderedDict( +# [ +# (key, self.get_smoothed_value(key)) +# for key in self.keys() +# if not key.startswith("_") +# ] +# ) + +# def reset(self): +# """Reset Meter instances.""" +# for meter in self.values(): +# if isinstance(meter, MetersDict._DerivedMeter): +# continue +# meter.reset() + +# class _DerivedMeter(Meter): +# """A Meter whose values are derived from other Meters.""" + +# def __init__(self, fn): +# self.fn = fn + +# def reset(self): +# pass diff --git a/MindChemistry/applications/Uni-Mol/unicore/logging/metrics.py b/MindChemistry/applications/Uni-Mol/unicore/logging/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..d52b0b57bd0709a325f491279cebec2ef8343fe7 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/logging/metrics.py @@ -0,0 +1,288 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +A standalone module for aggregating metrics. + +Metrics can be logged from anywhere using the `log_*` functions defined +in this module. 
The logged values will be aggregated dynamically based +on the aggregation context in which the logging occurs. See the +:func:`aggregate` context manager for more details. +""" + +import contextlib +import uuid +from collections import OrderedDict, defaultdict +from typing import Callable, Dict, List, Optional + +from .meters import * + + +# Aggregation contexts are considered "active" when inside the scope +# created by the :func:`aggregate` context manager. +_aggregators = OrderedDict() +_active_aggregators = OrderedDict() +_active_aggregators_cnt = defaultdict(lambda: 0) + + +def reset() -> None: + """Reset all metrics aggregators.""" + _aggregators.clear() + _active_aggregators.clear() + _active_aggregators_cnt.clear() + + # The "default" aggregator observes all logged values. + _aggregators["default"] = MetersDict() + _active_aggregators["default"] = _aggregators["default"] + _active_aggregators_cnt["default"] = 1 + + +reset() + + +@contextlib.contextmanager +def aggregate(name: Optional[str] = None, new_root: bool = False): + """Context manager to aggregate metrics under a given name. + + Aggregations can be nested. If *new_root* is ``False``, then logged + metrics will be recorded along the entire stack of nested + aggregators, including a global "default" aggregator. If *new_root* + is ``True``, then this aggregator will be the root of a new + aggregation stack, thus bypassing any parent aggregators. + + Note that aggregation contexts are uniquely identified by their + *name* (e.g., train, valid). Creating a context with an existing + name will reuse the corresponding :class:`MetersDict` instance. + If no name is given, then a temporary aggregator will be created. + + Usage:: + + with metrics.aggregate("train"): + for step, batch in enumerate(epoch): + with metrics.aggregate("train_inner") as agg: + metrics.log_scalar("loss", get_loss(batch)) + if step % log_interval == 0: + print(agg.get_smoothed_value("loss")) + agg.reset() + print(metrics.get_smoothed_values("train")["loss"]) + + Args: + name (str): name of the aggregation. Defaults to a + random/temporary name if not given explicitly. + new_root (bool): make this aggregation the root of a new + aggregation stack. + """ + if name is None: + # generate a temporary name + name = str(uuid.uuid4()) + assert name not in _aggregators + agg = MetersDict() + else: + assert name != "default" + agg = _aggregators.setdefault(name, MetersDict()) + + if new_root: + backup_aggregators = _active_aggregators.copy() + _active_aggregators.clear() + backup_aggregators_cnt = _active_aggregators_cnt.copy() + _active_aggregators_cnt.clear() + + _active_aggregators[name] = agg + _active_aggregators_cnt[name] += 1 + + yield agg + + _active_aggregators_cnt[name] -= 1 + if _active_aggregators_cnt[name] == 0 and name in _active_aggregators: + del _active_aggregators[name] + + if new_root: + _active_aggregators.clear() + _active_aggregators.update(backup_aggregators) + _active_aggregators_cnt.clear() + _active_aggregators_cnt.update(backup_aggregators_cnt) + + +def get_active_aggregators() -> List[MetersDict]: + return list(_active_aggregators.values()) + + +def log_scalar( + key: str, + value: float, + weight: float = 1, + priority: int = 10, + round: Optional[int] = None, +): + """Log a scalar value. + + Args: + key (str): name of the field to log + value (float): value to log + weight (float): weight that this value contributes to the average. + A weight of 0 will always log the latest value. 
+ priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, AverageMeter(round=round), priority) + agg[key].update(value, weight) + + +def log_derived(key: str, fn: Callable[[MetersDict], float], priority: int = 20): + """Log a scalar value derived from other meters. + + Args: + key (str): name of the field to log + fn (Callable[[MetersDict], float]): function that takes a single + argument *meters* and returns the derived value + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, MetersDict._DerivedMeter(fn), priority) + + +def log_speed( + key: str, + value: float, + priority: int = 30, + round: Optional[int] = None, +): + """Log the rate of some quantity per second. + + Args: + key (str): name of the field to log + value (float): value to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, TimeMeter(round=round), priority) + agg[key].reset() # reset meter on the first call + else: + agg[key].update(value) + + +def log_start_time(key: str, priority: int = 40, round: Optional[int] = None): + """Log the duration of some event in seconds. + + The duration will be computed once :func:`log_stop_time` is called. + + Args: + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + round (Optional[int]): number of digits to round to when displaying + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, StopwatchMeter(round=round), priority) + agg[key].start() + + +def log_stop_time(key: str, weight: float = 0.0, prehook=None): + """Log the duration of some event in seconds. + + The duration will be computed since :func:`log_start_time` was called. + Set weight > 0 to report the average time instead of the sum. + + Args: + key (str): name of the field to log + weight (float): weight that this time contributes to the average + prehook (function, no arguments): will be called before the timer + is stopped. For example, use prehook=torch.cuda.synchronize to + make sure all gpu operations are done before timer is stopped. + """ + for agg in get_active_aggregators(): + if key in agg: + agg[key].stop(weight, prehook) + + +def log_custom( + new_meter_fn: Callable[[], Meter], + key: str, + *args, + priority: int = 50, + **kwargs, +): + """Log using a custom Meter. + + Any extra *args* or *kwargs* will be passed through to the Meter's + *update* method. 
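+
+    Usage (a minimal sketch; ``MyMeter`` and ``some_value`` are placeholders,
+    ``MyMeter`` standing in for any custom :class:`Meter` subclass)::
+
+        metrics.log_custom(MyMeter, "my_stat", some_value, priority=50)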
+ + Args: + new_meter_fn (Callable[[], Meter]): function that returns a new + Meter instance + key (str): name of the field to log + priority (int): smaller values are logged earlier in the output + """ + for agg in get_active_aggregators(): + if key not in agg: + agg.add_meter(key, new_meter_fn(), priority) + agg[key].update(*args, **kwargs) + + +def reset_meter(name: str, key: str) -> None: + """Reset Meter instance aggregated under a given *name* and *key*.""" + meter = get_meter(name, key) + if meter is not None: + meter.reset() + + +def reset_meters(name: str) -> None: + """Reset Meter instances aggregated under a given *name*.""" + meters = get_meters(name) + if meters is not None: + meters.reset() + + +def get_meter(name: str, key: str) -> Meter: + """Get a single Meter instance aggregated under *name* and *key*. + + Returns: + Meter or None if no metrics have been logged under *name* and *key*. + """ + if name not in _aggregators: + return None + return _aggregators[name].get(key, None) + + +def get_meters(name: str) -> MetersDict: + """Get Meter instances aggregated under a given *name*. + + Returns: + MetersDict or None if no metrics have been logged under *name*. + """ + return _aggregators.get(name, None) + + +def get_smoothed_value(name: str, key: str) -> float: + """Get a single smoothed value. + + Raises: + KeyError: if no metrics have been logged under *name* and *key*. + """ + return _aggregators[name].get_smoothed_value(key) + + +def get_smoothed_values(name: str) -> Dict[str, float]: + """Get smoothed values aggregated under a given *name*. + + Raises: + KeyError: if no metrics have been logged under *name*. + """ + return _aggregators[name].get_smoothed_values() + + +def state_dict(): + return OrderedDict([(name, agg.state_dict()) for name, agg in _aggregators.items()]) + + +def load_state_dict(state_dict): + for name, agg_state in state_dict.items(): + _aggregators[name] = MetersDict() + _aggregators[name].load_state_dict(agg_state) diff --git a/MindChemistry/applications/Uni-Mol/unicore/logging/progress_bar.py b/MindChemistry/applications/Uni-Mol/unicore/logging/progress_bar.py new file mode 100644 index 0000000000000000000000000000000000000000..9f30223160d04f7b28adeb6a40fbc68ca6bfb514 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/logging/progress_bar.py @@ -0,0 +1,742 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Wrapper around various loggers and progress bars (e.g., tqdm). 
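+
+Usage (a minimal sketch; ``my_iterable`` and ``compute_loss`` are placeholders,
+not part of this module)::
+
+    bar = progress_bar(my_iterable, log_format="simple", log_interval=10, epoch=1)
+    for step, sample in enumerate(bar):
+        loss = compute_loss(sample)
+        bar.log({"loss": loss}, tag="train", step=step)
+    bar.print({"loss": loss}, tag="train")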
+""" +import atexit +import json +import logging +import os +import sys +from collections import OrderedDict +from contextlib import contextmanager +from numbers import Number +from typing import Optional + +import mindspore as ms + +from .meters import AverageMeter, StopwatchMeter, TimeMeter + + +logger = logging.getLogger(__name__) + + +def progress_bar( + iterator, + log_format: Optional[str] = None, + log_interval: int = 100, + epoch: Optional[int] = None, + prefix: Optional[str] = None, + tensorboard_logdir: Optional[str] = None, + wandb_project: Optional[str] = None, + default_log_format: str = "tqdm", + args=None, +): + if log_format is None: + log_format = default_log_format + if log_format == "tqdm" and not sys.stderr.isatty(): + log_format = "simple" + + if log_format == "json": + bar = JsonProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == "none": + bar = NoopProgressBar(iterator, epoch, prefix) + elif log_format == "simple": + bar = SimpleProgressBar(iterator, epoch, prefix, log_interval) + elif log_format == "tqdm": + bar = TqdmProgressBar(iterator, epoch, prefix) + else: + raise ValueError("Unknown log format: {}".format(log_format)) + + if tensorboard_logdir: + bar = TensorboardProgressBarWrapper( + bar, tensorboard_logdir, wandb_project, args + ) + + return bar + + +def format_stat(stat): + if isinstance(stat, Number): + stat = "{:g}".format(stat) + elif isinstance(stat, AverageMeter): + stat = "{:.3f}".format(stat.avg) + elif isinstance(stat, TimeMeter): + stat = "{:g}".format(round(stat.avg)) + elif isinstance(stat, StopwatchMeter): + stat = "{:g}".format(round(stat.sum)) + elif ms.is_tensor(stat): + stat = stat.tolist() + return stat + + +class BaseProgressBar(object): + """Abstract class for progress bars.""" + + def __init__(self, iterable, epoch=None, prefix=None): + self.iterable = iterable + self.n = getattr(iterable, "n", 0) + self.epoch = epoch + self.prefix = "" + if epoch is not None: + self.prefix += "epoch {:03d}".format(epoch) + if prefix is not None: + self.prefix += (" | " if self.prefix != "" else "") + prefix + + def __len__(self): + return len(self.iterable) + + def __enter__(self): + return self + + def __exit__(self, *exc): + return False + + def __iter__(self): + raise NotImplementedError + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + raise NotImplementedError + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + raise NotImplementedError + + def update_config(self, config): + """Log latest configuration.""" + pass + + def _str_commas(self, stats): + return ", ".join(key + "=" + stats[key].strip() for key in stats.keys()) + + def _str_pipes(self, stats): + return " | ".join(key + " " + stats[key].strip() for key in stats.keys()) + + def _format_stats(self, stats): + postfix = OrderedDict(stats) + # Preprocess stats according to datatype + for key in postfix.keys(): + postfix[key] = str(format_stat(postfix[key])) + return postfix + + +@contextmanager +def rename_logger(logger, new_name): + old_name = logger.name + if new_name is not None: + logger.name = new_name + yield logger + logger.name = old_name + + +class JsonProgressBar(BaseProgressBar): + """Log output in JSON format.""" + + def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): + super().__init__(iterable, epoch, prefix) + self.log_interval = log_interval + self.i = None + self.size = None + + def __iter__(self): + self.size = len(self.iterable) + for i, obj in 
enumerate(self.iterable, start=self.n): + self.i = i + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + step = step or self.i or 0 + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: + update = ( + self.epoch - 1 + (self.i + 1) / float(self.size) + if self.epoch is not None + else None + ) + stats = self._format_stats(stats, epoch=self.epoch, update=update) + with rename_logger(logger, tag): + logger.info(json.dumps(stats)) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self.stats = stats + if tag is not None: + self.stats = OrderedDict( + [(tag + "_" + k, v) for k, v in self.stats.items()] + ) + stats = self._format_stats(self.stats, epoch=self.epoch) + with rename_logger(logger, tag): + logger.info(json.dumps(stats)) + + def _format_stats(self, stats, epoch=None, update=None): + postfix = OrderedDict() + if epoch is not None: + postfix["epoch"] = epoch + if update is not None: + postfix["update"] = round(update, 3) + # Preprocess stats according to datatype + for key in stats.keys(): + postfix[key] = format_stat(stats[key]) + return postfix + + +class NoopProgressBar(BaseProgressBar): + """No logging.""" + + def __init__(self, iterable, epoch=None, prefix=None): + super().__init__(iterable, epoch, prefix) + + def __iter__(self): + for obj in self.iterable: + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + pass + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + pass + + +class SimpleProgressBar(BaseProgressBar): + """A minimal logger for non-TTY environments.""" + + def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): + super().__init__(iterable, epoch, prefix) + self.log_interval = log_interval + self.i = None + self.size = None + + def __iter__(self): + self.size = len(self.iterable) + for i, obj in enumerate(self.iterable, start=self.n): + self.i = i + yield obj + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + step = step or self.i or 0 + if step > 0 and self.log_interval is not None and step % self.log_interval == 0: + stats = self._format_stats(stats) + postfix = self._str_commas(stats) + with rename_logger(logger, tag): + logger.info( + "{}: {:5d} / {:d} {}".format( + self.prefix, self.i + 1, self.size, postfix + ) + ) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + postfix = self._str_pipes(self._format_stats(stats)) + with rename_logger(logger, tag): + logger.info("{} | {}".format(self.prefix, postfix)) + + +class TqdmProgressBar(BaseProgressBar): + """Log to tqdm.""" + + def __init__(self, iterable, epoch=None, prefix=None): + super().__init__(iterable, epoch, prefix) + from tqdm import tqdm + + self.tqdm = tqdm( + iterable, + self.prefix, + leave=False, + disable=(logger.getEffectiveLevel() > logging.INFO), + ) + + def __iter__(self): + return iter(self.tqdm) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats according to log_interval.""" + self.tqdm.set_postfix(self._format_stats(stats), refresh=False) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + postfix = self._str_pipes(self._format_stats(stats)) + with rename_logger(logger, tag): + logger.info("{} | {}".format(self.prefix, postfix)) + + +try: + _tensorboard_writers = {} + # 
MindSpore兼容TensorBoard,使用相同的SummaryWriter + from torch.utils.tensorboard import SummaryWriter +except ImportError: + try: + from tensorboardX import SummaryWriter + except ImportError: + SummaryWriter = None + +try: + _wandb_inited = False + import wandb + + wandb_available = True +except ImportError: + wandb_available = False + + +def _close_writers(): + for w in _tensorboard_writers.values(): + w.close() + if _wandb_inited: + try: + wandb.finish() + except: + pass + + +atexit.register(_close_writers) + + +class TensorboardProgressBarWrapper(BaseProgressBar): + """Log to tensorboard.""" + + def __init__(self, wrapped_bar, tensorboard_logdir, wandb_project, args): + self.wrapped_bar = wrapped_bar + self.tensorboard_logdir = tensorboard_logdir + + if SummaryWriter is None: + logger.warning( + "tensorboard not found, please install with: pip install tensorboard" + ) + global _wandb_inited + if not _wandb_inited and wandb_project and wandb_available: + wandb_name = args.wandb_name or wandb.util.generate_id() + if "/" in wandb_project: + entity, project = wandb_project.split("/") + else: + entity, project = None, wandb_project + wandb.init( + project=project, + entity=entity, + name=wandb_name, + config=vars(args), + id=wandb_name, + resume="allow", + ) + _wandb_inited = True + + def _writer(self, key): + if SummaryWriter is None: + return None + _writers = _tensorboard_writers + if key not in _writers: + _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key)) + _writers[key].add_text("sys.argv", " ".join(sys.argv)) + return _writers[key] + + def __iter__(self): + return iter(self.wrapped_bar) + + def log(self, stats, tag=None, step=None): + """Log intermediate stats to tensorboard.""" + self._log_to_tensorboard(stats, tag, step) + self.wrapped_bar.log(stats, tag=tag, step=step) + + def print(self, stats, tag=None, step=None): + """Print end-of-epoch stats.""" + self._log_to_tensorboard(stats, tag, step) + self.wrapped_bar.print(stats, tag=tag, step=step) + + def update_config(self, config): + """Log latest configuration.""" + # TODO add hparams to Tensorboard + self.wrapped_bar.update_config(config) + + def _log_to_tensorboard(self, stats, tag=None, step=None): + writer = self._writer(tag or "") + if writer is None: + return + if step is None: + step = stats["num_updates"] + for key in stats.keys() - {"num_updates"}: + if isinstance(stats[key], AverageMeter): + val = stats[key].val + elif isinstance(stats[key], Number): + val = stats[key] + elif ms.is_tensor(stats[key]) and stats[key].size == 1: + val = stats[key].item() + else: + val = None + if val: + writer.add_scalar(key, val, step) + if _wandb_inited: + wandb.log({"{}_{}".format(tag, key): val}, step=step) + writer.flush() +# import atexit +# import json +# import logging +# import os +# import sys +# from collections import OrderedDict +# from contextlib import contextmanager +# from numbers import Number +# from typing import Optional + +# import torch + +# from .meters import AverageMeter, StopwatchMeter, TimeMeter + + +# logger = logging.getLogger(__name__) + + +# def progress_bar( +# iterator, +# log_format: Optional[str] = None, +# log_interval: int = 100, +# epoch: Optional[int] = None, +# prefix: Optional[str] = None, +# tensorboard_logdir: Optional[str] = None, +# wandb_project: Optional[str] = None, +# default_log_format: str = "tqdm", +# args=None, +# ): +# if log_format is None: +# log_format = default_log_format +# if log_format == "tqdm" and not sys.stderr.isatty(): +# log_format = "simple" + +# if 
log_format == "json": +# bar = JsonProgressBar(iterator, epoch, prefix, log_interval) +# elif log_format == "none": +# bar = NoopProgressBar(iterator, epoch, prefix) +# elif log_format == "simple": +# bar = SimpleProgressBar(iterator, epoch, prefix, log_interval) +# elif log_format == "tqdm": +# bar = TqdmProgressBar(iterator, epoch, prefix) +# else: +# raise ValueError("Unknown log format: {}".format(log_format)) + +# if tensorboard_logdir: +# bar = TensorboardProgressBarWrapper( +# bar, tensorboard_logdir, wandb_project, args +# ) + +# return bar + + +# def format_stat(stat): +# if isinstance(stat, Number): +# stat = "{:g}".format(stat) +# elif isinstance(stat, AverageMeter): +# stat = "{:.3f}".format(stat.avg) +# elif isinstance(stat, TimeMeter): +# stat = "{:g}".format(round(stat.avg)) +# elif isinstance(stat, StopwatchMeter): +# stat = "{:g}".format(round(stat.sum)) +# elif torch.is_tensor(stat): +# stat = stat.tolist() +# return stat + + +# class BaseProgressBar(object): +# """Abstract class for progress bars.""" + +# def __init__(self, iterable, epoch=None, prefix=None): +# self.iterable = iterable +# self.n = getattr(iterable, "n", 0) +# self.epoch = epoch +# self.prefix = "" +# if epoch is not None: +# self.prefix += "epoch {:03d}".format(epoch) +# if prefix is not None: +# self.prefix += (" | " if self.prefix != "" else "") + prefix + +# def __len__(self): +# return len(self.iterable) + +# def __enter__(self): +# return self + +# def __exit__(self, *exc): +# return False + +# def __iter__(self): +# raise NotImplementedError + +# def log(self, stats, tag=None, step=None): +# """Log intermediate stats according to log_interval.""" +# raise NotImplementedError + +# def print(self, stats, tag=None, step=None): +# """Print end-of-epoch stats.""" +# raise NotImplementedError + +# def update_config(self, config): +# """Log latest configuration.""" +# pass + +# def _str_commas(self, stats): +# return ", ".join(key + "=" + stats[key].strip() for key in stats.keys()) + +# def _str_pipes(self, stats): +# return " | ".join(key + " " + stats[key].strip() for key in stats.keys()) + +# def _format_stats(self, stats): +# postfix = OrderedDict(stats) +# # Preprocess stats according to datatype +# for key in postfix.keys(): +# postfix[key] = str(format_stat(postfix[key])) +# return postfix + + +# @contextmanager +# def rename_logger(logger, new_name): +# old_name = logger.name +# if new_name is not None: +# logger.name = new_name +# yield logger +# logger.name = old_name + + +# class JsonProgressBar(BaseProgressBar): +# """Log output in JSON format.""" + +# def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): +# super().__init__(iterable, epoch, prefix) +# self.log_interval = log_interval +# self.i = None +# self.size = None + +# def __iter__(self): +# self.size = len(self.iterable) +# for i, obj in enumerate(self.iterable, start=self.n): +# self.i = i +# yield obj + +# def log(self, stats, tag=None, step=None): +# """Log intermediate stats according to log_interval.""" +# step = step or self.i or 0 +# if step > 0 and self.log_interval is not None and step % self.log_interval == 0: +# update = ( +# self.epoch - 1 + (self.i + 1) / float(self.size) +# if self.epoch is not None +# else None +# ) +# stats = self._format_stats(stats, epoch=self.epoch, update=update) +# with rename_logger(logger, tag): +# logger.info(json.dumps(stats)) + +# def print(self, stats, tag=None, step=None): +# """Print end-of-epoch stats.""" +# self.stats = stats +# if tag is not None: +# self.stats = 
OrderedDict( +# [(tag + "_" + k, v) for k, v in self.stats.items()] +# ) +# stats = self._format_stats(self.stats, epoch=self.epoch) +# with rename_logger(logger, tag): +# logger.info(json.dumps(stats)) + +# def _format_stats(self, stats, epoch=None, update=None): +# postfix = OrderedDict() +# if epoch is not None: +# postfix["epoch"] = epoch +# if update is not None: +# postfix["update"] = round(update, 3) +# # Preprocess stats according to datatype +# for key in stats.keys(): +# postfix[key] = format_stat(stats[key]) +# return postfix + + +# class NoopProgressBar(BaseProgressBar): +# """No logging.""" + +# def __init__(self, iterable, epoch=None, prefix=None): +# super().__init__(iterable, epoch, prefix) + +# def __iter__(self): +# for obj in self.iterable: +# yield obj + +# def log(self, stats, tag=None, step=None): +# """Log intermediate stats according to log_interval.""" +# pass + +# def print(self, stats, tag=None, step=None): +# """Print end-of-epoch stats.""" +# pass + + +# class SimpleProgressBar(BaseProgressBar): +# """A minimal logger for non-TTY environments.""" + +# def __init__(self, iterable, epoch=None, prefix=None, log_interval=1000): +# super().__init__(iterable, epoch, prefix) +# self.log_interval = log_interval +# self.i = None +# self.size = None + +# def __iter__(self): +# self.size = len(self.iterable) +# for i, obj in enumerate(self.iterable, start=self.n): +# self.i = i +# yield obj + +# def log(self, stats, tag=None, step=None): +# """Log intermediate stats according to log_interval.""" +# step = step or self.i or 0 +# if step > 0 and self.log_interval is not None and step % self.log_interval == 0: +# stats = self._format_stats(stats) +# postfix = self._str_commas(stats) +# with rename_logger(logger, tag): +# logger.info( +# "{}: {:5d} / {:d} {}".format( +# self.prefix, self.i + 1, self.size, postfix +# ) +# ) + +# def print(self, stats, tag=None, step=None): +# """Print end-of-epoch stats.""" +# postfix = self._str_pipes(self._format_stats(stats)) +# with rename_logger(logger, tag): +# logger.info("{} | {}".format(self.prefix, postfix)) + + +# class TqdmProgressBar(BaseProgressBar): +# """Log to tqdm.""" + +# def __init__(self, iterable, epoch=None, prefix=None): +# super().__init__(iterable, epoch, prefix) +# from tqdm import tqdm + +# self.tqdm = tqdm( +# iterable, +# self.prefix, +# leave=False, +# disable=(logger.getEffectiveLevel() > logging.INFO), +# ) + +# def __iter__(self): +# return iter(self.tqdm) + +# def log(self, stats, tag=None, step=None): +# """Log intermediate stats according to log_interval.""" +# self.tqdm.set_postfix(self._format_stats(stats), refresh=False) + +# def print(self, stats, tag=None, step=None): +# """Print end-of-epoch stats.""" +# postfix = self._str_pipes(self._format_stats(stats)) +# with rename_logger(logger, tag): +# logger.info("{} | {}".format(self.prefix, postfix)) + + +# try: +# _tensorboard_writers = {} +# from torch.utils.tensorboard import SummaryWriter +# except ImportError: +# try: +# from tensorboardX import SummaryWriter +# except ImportError: +# SummaryWriter = None + +# try: +# _wandb_inited = False +# import wandb + +# wandb_available = True +# except ImportError: +# wandb_available = False + + +# def _close_writers(): +# for w in _tensorboard_writers.values(): +# w.close() +# if _wandb_inited: +# try: +# wandb.finish() +# except: +# pass + + +# atexit.register(_close_writers) + + +# class TensorboardProgressBarWrapper(BaseProgressBar): +# """Log to tensorboard.""" + +# def __init__(self, wrapped_bar, 
tensorboard_logdir, wandb_project, args): +# self.wrapped_bar = wrapped_bar +# self.tensorboard_logdir = tensorboard_logdir + +# if SummaryWriter is None: +# logger.warning( +# "tensorboard not found, please install with: pip install tensorboard" +# ) +# global _wandb_inited +# if not _wandb_inited and wandb_project and wandb_available: +# wandb_name = args.wandb_name or wandb.util.generate_id() +# if "/" in wandb_project: +# entity, project = wandb_project.split("/") +# else: +# entity, project = None, wandb_project +# wandb.init( +# project=project, +# entity=entity, +# name=wandb_name, +# config=vars(args), +# id=wandb_name, +# resume="allow", +# ) +# _wandb_inited = True + +# def _writer(self, key): +# if SummaryWriter is None: +# return None +# _writers = _tensorboard_writers +# if key not in _writers: +# _writers[key] = SummaryWriter(os.path.join(self.tensorboard_logdir, key)) +# _writers[key].add_text("sys.argv", " ".join(sys.argv)) +# return _writers[key] + +# def __iter__(self): +# return iter(self.wrapped_bar) + +# def log(self, stats, tag=None, step=None): +# """Log intermediate stats to tensorboard.""" +# self._log_to_tensorboard(stats, tag, step) +# self.wrapped_bar.log(stats, tag=tag, step=step) + +# def print(self, stats, tag=None, step=None): +# """Print end-of-epoch stats.""" +# self._log_to_tensorboard(stats, tag, step) +# self.wrapped_bar.print(stats, tag=tag, step=step) + +# def update_config(self, config): +# """Log latest configuration.""" +# # TODO add hparams to Tensorboard +# self.wrapped_bar.update_config(config) + +# def _log_to_tensorboard(self, stats, tag=None, step=None): +# writer = self._writer(tag or "") +# if writer is None: +# return +# if step is None: +# step = stats["num_updates"] +# for key in stats.keys() - {"num_updates"}: +# if isinstance(stats[key], AverageMeter): +# val = stats[key].val +# elif isinstance(stats[key], Number): +# val = stats[key] +# elif torch.is_tensor(stats[key]) and stats[key].numel() == 1: +# val = stats[key].item() +# else: +# val = None +# if val: +# writer.add_scalar(key, val, step) +# if _wandb_inited: +# wandb.log({"{}_{}".format(tag, key): val}, step=step) +# writer.flush() diff --git a/MindChemistry/applications/Uni-Mol/unicore/losses/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..622eeff736b3e7fe8eec5168987ecd5d465a76f2 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/losses/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from unicore import registry +from unicore.losses.unicore_loss import ( # noqa + UnicoreLoss, +) + + +( + build_loss_, + register_loss, + CRITERION_REGISTRY, +) = registry.setup_registry( + "--loss", base_class=UnicoreLoss, default="cross_entropy" +) + + +def build_loss(args, task): + return build_loss_(args, task) + + +# automatically import any Python files in the losses/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("unicore.losses." 
+ file_name) diff --git a/MindChemistry/applications/Uni-Mol/unicore/losses/cross_entropy.py b/MindChemistry/applications/Uni-Mol/unicore/losses/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1285da7c87fde90b20824f3da0c6e800c9203d --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/losses/cross_entropy.py @@ -0,0 +1,126 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import math +import mindspore as ms +import mindspore.ops as ops +from unicore import metrics +from unicore.losses import UnicoreLoss, register_loss + +@register_loss("cross_entropy") +class CrossEntropyLoss(UnicoreLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + net_output = model(**sample["net_input"]) + loss = self.compute_loss(model, net_output, sample, reduce=reduce) + sample_size = sample["target"].shape[0] + logging_output = { + "loss": loss.asnumpy() if isinstance(loss, ms.Tensor) else loss, + "bsz": sample["target"].shape[0], + "sample_size": sample_size, + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + # MindSpore的log_softmax对应PyTorch的F.log_softmax,指定维度dim + lprobs = ops.log_softmax(net_output.astype(ms.float32), axis=-1) + # MindSpore中用reshape替代view + lprobs = lprobs.reshape(-1, lprobs.shape[-1]) + target = sample['target'].reshape(-1) + # MindSpore的nll_loss对应PyTorch的F.nll_loss,reduction参数保持一致 + loss = ops.nll_loss( + lprobs, + target, + reduction="sum" if reduce else "none", + ) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split='valid') -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + + # 保持与原逻辑一致,将损失从自然对数转换为以2为底的对数 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return True +# import math +# import torch +# import torch.nn.functional as F +# from unicore import metrics +# from unicore.losses import UnicoreLoss, register_loss + +# @register_loss("cross_entropy") +# class CrossEntropyLoss(UnicoreLoss): +# def __init__(self, task): +# super().__init__(task) + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. 
+ +# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# net_output = model(**sample["net_input"]) +# loss = self.compute_loss(model, net_output, sample, reduce=reduce) +# sample_size = sample["target"].size(0) +# logging_output = { +# "loss": loss.data, +# "bsz": sample["target"].size(0), +# "sample_size": sample_size, +# } +# return loss, sample_size, logging_output + +# def compute_loss(self, model, net_output, sample, reduce=True): +# lprobs = F.log_softmax(net_output.float(), dim=-1) +# lprobs = lprobs.view(-1, lprobs.size(-1)) +# target = sample['target'].view(-1) +# loss = F.nll_loss( +# lprobs, +# target, +# reduction="sum" if reduce else "none", +# ) +# return loss + +# @staticmethod +# def reduce_metrics(logging_outputs, split='valid') -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + +# # we divide by log(2) to convert the loss from base e to base 2 +# metrics.log_scalar( +# "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 +# ) + +# @staticmethod +# def logging_outputs_can_be_summed(is_train) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. +# """ +# return True diff --git a/MindChemistry/applications/Uni-Mol/unicore/losses/masked_lm.py b/MindChemistry/applications/Uni-Mol/unicore/losses/masked_lm.py new file mode 100644 index 0000000000000000000000000000000000000000..5f05de7d5ffb21ad55d96efa7bd237a8d0b674d8 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/losses/masked_lm.py @@ -0,0 +1,139 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
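+#
+# Note: MaskedLMLoss below scores only the masked positions. Targets equal to
+# the padding index are excluded, sample_size is the number of masked tokens,
+# and reduce_metrics reports the summed NLL divided by that count and by
+# log(2), i.e. a per-masked-token loss in bits.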
+import math +import mindspore as ms +import mindspore.ops as ops +from unicore import metrics +from unicore.losses import UnicoreLoss, register_loss + +@register_loss("masked_lm") +class MaskedLMLoss(UnicoreLoss): + def __init__(self, task): + super().__init__(task) + self.padding_idx = task.dictionary.pad() + + def forward(self, model, sample, reduce=True): + # 替换PyTorch的ne为MindSpore的not_equal + masked_tokens = sample["target"].not_equal(self.padding_idx) + # 替换int()为astype(ms.int32),sum保持一致 + sample_size = masked_tokens.astype(ms.int32).sum() + + # 替换torch.where为ops.where,保持条件逻辑 + masked_tokens = ops.where( + masked_tokens.any(), + masked_tokens, + # 替换new创建张量的方式,保持数据类型一致 + ms.Tensor([True], dtype=masked_tokens.dtype) + ) + logits = model(** sample["net_input"], masked_tokens=masked_tokens) + target = sample['target'] + if masked_tokens is not None: + # MindSpore支持布尔索引,保持索引方式 + target = target[masked_tokens] + # 替换F.log_softmax为ops.log_softmax,dim→axis,dtype调整为ms.float32 + # 替换F.nll_loss为ops.nll_loss,参数保持一致 + loss = ops.nll_loss( + ops.log_softmax(logits, axis=-1, dtype=ms.float32), + target, + ignore_index=self.padding_idx, + reduction='sum', + ) + logging_output = { + # 替换loss.data为loss.asnumpy()获取数据 + "loss": loss.asnumpy() if isinstance(loss, ms.Tensor) else loss, + # 替换size(0)为shape[0] + "bsz": sample["target"].shape[0], + "sample_size": sample_size, + # 替换size(1)为shape[1] + "seq_len": sample["target"].shape[1] * sample["target"].shape[0], + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs, split='valid') -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + bsz = sum(log.get("bsz", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + seq_len = sum(log.get("seq_len", 0) for log in logging_outputs) + # 保持与原逻辑一致,将损失从自然对数转换为以2为底的对数 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + metrics.log_scalar( + "seq_len", seq_len / bsz, 1, round=3 + ) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. 
+ """ + return True +# import math +# import torch +# import torch.nn.functional as F +# from unicore import metrics +# from unicore.losses import UnicoreLoss, register_loss + +# @register_loss("masked_lm") +# class MaskedLMLoss(UnicoreLoss): +# def __init__(self, task): +# super().__init__(task) +# self.padding_idx = task.dictionary.pad() + +# def forward(self, model, sample, reduce=True): +# masked_tokens = sample["target"].ne(self.padding_idx) +# sample_size = masked_tokens.int().sum() + +# masked_tokens = torch.where( +# masked_tokens.any(), +# masked_tokens, +# masked_tokens.new([True]), +# ) +# logits = model(**sample["net_input"], masked_tokens=masked_tokens) +# target = sample['target'] +# if masked_tokens is not None: +# target = target[masked_tokens] +# loss = F.nll_loss( +# F.log_softmax(logits, dim=-1, dtype=torch.float32), +# target, +# ignore_index=self.padding_idx, +# reduction='sum', +# ) +# logging_output = { +# "loss": loss.data, +# "bsz": sample["target"].size(0), +# "sample_size": sample_size, +# "seq_len": sample["target"].size(1) * sample["target"].size(0), +# } +# return loss, sample_size, logging_output + +# @staticmethod +# def reduce_metrics(logging_outputs, split='valid') -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# bsz = sum(log.get("bsz", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) +# seq_len = sum(log.get("seq_len", 0) for log in logging_outputs) +# # we divide by log(2) to convert the loss from base e to base 2 +# metrics.log_scalar( +# "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 +# ) +# metrics.log_scalar( +# "seq_len", seq_len / bsz, 1, round=3 +# ) + +# @staticmethod +# def logging_outputs_can_be_summed(is_train) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. +# """ +# return True diff --git a/MindChemistry/applications/Uni-Mol/unicore/losses/unicore_loss.py b/MindChemistry/applications/Uni-Mol/unicore/losses/unicore_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..40912f14a3e35c110fde40dc0f37542df86df531 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/losses/unicore_loss.py @@ -0,0 +1,148 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import inspect +from typing import Any, Dict, List + +from unicore import metrics, utils +from mindspore.nn import Loss + + +class UnicoreLoss(Loss): + def __init__(self, task): + super().__init__() + self.task = task + if task is not None: + self.args = task.args + if hasattr(task, "target_dictionary"): + tgt_dict = task.target_dictionary + self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100 + + @classmethod + def add_args(cls, parser): + pass + + @classmethod + def build_loss(cls, args, task): + """Construct a loss from command-line args.""" + # arguments in the __init__. 
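+        # Fill each parameter of the concrete loss's __init__ from `task`,
+        # `args`, or a same-named attribute on `args`; parameters with defaults
+        # fall back to their default, and positional-only, *args or **kwargs
+        # parameters raise NotImplementedError.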
+ init_args = {} + for p in inspect.signature(cls).parameters.values(): + if ( + p.kind == p.POSITIONAL_ONLY + or p.kind == p.VAR_POSITIONAL + or p.kind == p.VAR_KEYWORD + ): + # we haven't implemented inference for these argument types, + # but PRs welcome :) + raise NotImplementedError("{} not supported".format(p.kind)) + + assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} + + if p.name == "task": + init_args["task"] = task + elif p.name == "args": + init_args["args"] = args + elif hasattr(args, p.name): + init_args[p.name] = getattr(args, p.name) + elif p.default != p.empty: + pass # we'll use the default value + else: + raise NotImplementedError( + "Unable to infer Loss arguments, please implement " + "{}.build_loss".format(cls.__name__) + ) + return cls(**init_args) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample. + + Returns a tuple with three elements: + 1) the loss + 2) the sample size, which is used as the denominator for the gradient + 3) logging outputs to display while training + """ + raise NotImplementedError + + @staticmethod + def logging_outputs_can_be_summed(is_train: bool) -> bool: + """ + Whether the logging outputs returned by `forward` can be summed + across workers prior to calling `reduce_metrics`. Setting this + to True will improves distributed training speed. + """ + return False +# import inspect +# from typing import Any, Dict, List + +# from unicore import metrics, utils +# from torch.nn.modules.loss import _Loss + + +# class UnicoreLoss(_Loss): +# def __init__(self, task): +# super().__init__() +# self.task = task +# if task is not None: +# self.args = task.args +# if hasattr(task, "target_dictionary"): +# tgt_dict = task.target_dictionary +# self.padding_idx = tgt_dict.pad() if tgt_dict is not None else -100 + +# @classmethod +# def add_args(cls, parser): +# pass + +# @classmethod +# def build_loss(cls, args, task): +# """Construct a loss from command-line args.""" +# # arguments in the __init__. +# init_args = {} +# for p in inspect.signature(cls).parameters.values(): +# if ( +# p.kind == p.POSITIONAL_ONLY +# or p.kind == p.VAR_POSITIONAL +# or p.kind == p.VAR_KEYWORD +# ): +# # we haven't implemented inference for these argument types, +# # but PRs welcome :) +# raise NotImplementedError("{} not supported".format(p.kind)) + +# assert p.kind in {p.POSITIONAL_OR_KEYWORD, p.KEYWORD_ONLY} + +# if p.name == "task": +# init_args["task"] = task +# elif p.name == "args": +# init_args["args"] = args +# elif hasattr(args, p.name): +# init_args[p.name] = getattr(args, p.name) +# elif p.default != p.empty: +# pass # we'll use the default value +# else: +# raise NotImplementedError( +# "Unable to infer Loss arguments, please implement " +# "{}.build_loss".format(cls.__name__) +# ) +# return cls(**init_args) + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. + +# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# raise NotImplementedError + +# @staticmethod +# def logging_outputs_can_be_summed(is_train: bool) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. 
+# """ +# return False + diff --git a/MindChemistry/applications/Uni-Mol/unicore/models/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a68899ca4fb9450ed79afcfd994984f67c0dcf96 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/models/__init__.py @@ -0,0 +1,120 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +from .distributed_unicore_model import DistributedUnicoreModel +from .unicore_model import ( + BaseUnicoreModel, +) + +MODEL_REGISTRY = {} +ARCH_MODEL_REGISTRY = {} +ARCH_MODEL_NAME_REGISTRY = {} +ARCH_MODEL_INV_REGISTRY = {} +ARCH_CONFIG_REGISTRY = {} + + +__all__ = [ + "BaseUnicoreModel", + "DistributedUnicoreModel", +] + + +def build_model(args, task): + return ARCH_MODEL_REGISTRY[args.arch].build_model(args, task) + + +def register_model(name): + """ + New model types can be added to unicore with the :func:`register_model` + function decorator. + + For example:: + + @register_model("lstm") + class LSTM(UnicoreEncoderDecoderModel): + (...) + + .. note:: All models must implement the :class:`BaseUnicoreModel` interface. + Typically you will extend :class:`UnicoreEncoderDecoderModel` for + sequence-to-sequence tasks or :class:`UnicoreLanguageModel` for + language modeling tasks. + + Args: + name (str): the name of the model + """ + + def register_model_cls(cls): + if name in MODEL_REGISTRY: + raise ValueError("Cannot register duplicate model ({})".format(name)) + if not issubclass(cls, BaseUnicoreModel): + raise ValueError("Model ({}: {}) must extend BaseUnicoreModel".format(name, cls.__name__)) + MODEL_REGISTRY[name] = cls + return cls + + return register_model_cls + + +def register_model_architecture(model_name, arch_name): + """ + New model architectures can be added to unicore with the + :func:`register_model_architecture` function decorator. After registration, + model architectures can be selected with the ``--arch`` command-line + argument. + + For example:: + + @register_model_architecture("lstm", "lstm_luong_wmt_en_de") + def lstm_luong_wmt_en_de(args): + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1000) + (...) + + The decorated function should take a single argument *args*, which is a + :class:`argparse.Namespace` of arguments parsed from the command-line. The + decorated function should modify these arguments in-place to match the + desired architecture. 
+ + Args: + model_name (str): the name of the Model (Model must already be + registered) + arch_name (str): the name of the model architecture (``--arch``) + """ + + def register_model_arch_fn(fn): + if model_name not in MODEL_REGISTRY: + raise ValueError("Cannot register model architecture for unknown model type ({})".format(model_name)) + if arch_name in ARCH_MODEL_REGISTRY: + raise ValueError("Cannot register duplicate model architecture ({})".format(arch_name)) + if not callable(fn): + raise ValueError("Model architecture must be callable ({})".format(arch_name)) + ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name] + ARCH_MODEL_INV_REGISTRY.setdefault(model_name, []).append(arch_name) + ARCH_CONFIG_REGISTRY[arch_name] = fn + return fn + + return register_model_arch_fn + + +# automatically import any Python files in the models/ directory +models_dir = os.path.dirname(__file__) +for file in os.listdir(models_dir): + path = os.path.join(models_dir, file) + if not file.startswith("_") and not file.startswith(".") and (file.endswith(".py") or os.path.isdir(path)): + model_name = file[:file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("unicore.models." + model_name) + + # extra `model_parser` for sphinx + if model_name in MODEL_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_archs = parser.add_argument_group("Named architectures") + group_archs.add_argument("--arch", choices=ARCH_MODEL_INV_REGISTRY[model_name]) + group_args = parser.add_argument_group("Additional command-line arguments") + MODEL_REGISTRY[model_name].add_args(group_args) + globals()[model_name + "_parser"] = parser diff --git a/MindChemistry/applications/Uni-Mol/unicore/models/distributed_unicore_model.py b/MindChemistry/applications/Uni-Mol/unicore/models/distributed_unicore_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cdef7b4fefd2002fcd9c8f1363dd4b7b03e94b6a --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/models/distributed_unicore_model.py @@ -0,0 +1,177 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
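+#
+# Single-device adaptation: DistributedUnicoreModel below only moves the model
+# to the resolved Ascend device and, when ModuleProxyWrapper is available,
+# wraps it so attribute access is forwarded to the underlying Cell; the
+# original distributed data parallel wrappers are kept underneath as
+# commented-out reference code.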
+import logging + +import mindspore as ms +from mindspore.nn import Cell + +# 移除分布式相关导入,因为单卡环境不需要 +# 保留可能需要的ModuleProxyWrapper(如果存在) +try: + from unicore.distributed import ModuleProxyWrapper +except ImportError: + # 如果没有ModuleProxyWrapper,定义一个简单的替代类 + class ModuleProxyWrapper: + def __init__(self, module): + self.module = module + + def __getattr__(self, name): + try: + return getattr(self.module, name) + except AttributeError: + return getattr(self, name) + + +logger = logging.getLogger(__name__) + +def DistributedUnicoreModel(args, model, process_group=None, device=None): + """ + 适配单卡Ascend NPU的模型包装,移除所有分布式相关代码 + + Args: + args: 配置参数 + model: 待包装的模型(MindSpore Cell) + process_group: 分布式进程组(单卡环境中忽略) + device: 设备信息(单卡环境中自动获取) + """ + assert isinstance(model, Cell), "模型必须是MindSpore的Cell类型" + + # 单卡环境下不需要分布式包装,仅处理设备迁移 + if device is None: + # 自动获取当前设备(Ascend环境) + device_id = ms.get_context("device_id") + device = f"Ascend:{device_id}" + + # 将模型迁移到目标设备 + model = model.to(device) + + # 日志提示当前为单卡模式 + logger.info(f"使用单卡Ascend模式,设备: {device}") + + # 如果需要保留包装器接口,使用ModuleProxyWrapper + if 'ModuleProxyWrapper' in globals(): + return ModuleProxyWrapper(model) + return model + +# import logging + +# import mindspore as ms +# from mindspore.nn import Cell +# from mindspore.nn import DistributedDataParallel + +# from unicore.distributed import ( +# ModuleProxyWrapper, LegacyDistributedDataParallel # 假设unicore中已有适配MindSpore的实现 +# ) + + +# logger = logging.getLogger(__name__) + +# def DistributedUnicoreModel(args, model, process_group, device): +# """ +# Wrap a *model* to support distributed data parallel training. + +# 适配MindSpore的分布式模型包装,功能与原PyTorch版本一致。 +# """ +# assert isinstance(model, Cell), "模型必须是MindSpore的Cell类型" + +# if args.ddp_backend in {"c10d", "pytorch_ddp"}: +# # 替换PyTorch的DistributedDataParallel为MindSpore的对应类 +# wrapped_model = DistributedDataParallel( +# network=model.to(device), # MindSpore中模型设备迁移使用to(device) +# device_ids=[args.device_id], +# output_device=args.device_id, +# broadcast_buffers=args.broadcast_buffers, +# bucket_cap_mb=args.bucket_cap_mb, +# process_group=process_group, +# find_unused_parameters=args.find_unused_parameters, +# ) +# # 保留属性转发包装器 +# wrapped_model = ModuleProxyWrapper(wrapped_model) + +# elif args.ddp_backend in {'apex'}: +# # MindSpore不直接支持Apex,此处可根据实际需求替换为MindSpore的分布式实现 +# logger.warning("MindSpore暂不支持apex后端,已自动切换为原生DDP") +# wrapped_model = DistributedDataParallel( +# network=model.to(device), +# device_ids=[args.device_id], +# process_group=process_group, +# ) +# wrapped_model = ModuleProxyWrapper(wrapped_model) + +# elif args.ddp_backend in {"no_c10d", "legacy_ddp"}: +# # 替换LegacyDistributedDataParallel为MindSpore兼容实现 +# wrapped_model = LegacyDistributedDataParallel( +# network=model.to(device), +# buffer_size=2 **28, +# process_group=process_group, +# ) +# wrapped_model = ModuleProxyWrapper(wrapped_model) + +# else: +# raise ValueError(f"未知的--ddp-backend: {args.ddp_backend}") + +# return wrapped_model +# import logging + +# import torch +# import torch.nn as nn +# from torch.nn.parallel import DistributedDataParallel + +# from unicore.distributed import ( +# ModuleProxyWrapper, LegacyDistributedDataParallel +# ) + + +# logger = logging.getLogger(__name__) + +# def DistributedUnicoreModel(args, model, process_group, device): +# """ +# Wrap a *model* to support distributed data parallel training. 
+ +# This is similar to the built-in DistributedDataParallel, but allows +# additional configuration of the DistributedDataParallel class to +# use, and also provides easier access to the wrapped model by +# forwarding requests for missing attributes to the wrapped model. + +# Args: +# args (argparse.Namespace): unicore args +# model (BaseUnicoreModel): model to wrap +# process_group: the c10d process group to be used for distributed data +# parallel all-reduction. +# device: device to move model to +# """ +# assert isinstance(model, nn.Module) +# if args.ddp_backend in {"c10d", "pytorch_ddp"}: +# wrapped_model = DistributedDataParallel( +# module=model.to(device), +# device_ids=[args.device_id], +# output_device=args.device_id, +# broadcast_buffers=args.broadcast_buffers, +# bucket_cap_mb=args.bucket_cap_mb, +# process_group=process_group, +# find_unused_parameters=args.find_unused_parameters, +# ) +# # forward missing getattr and state_dict/load_state_dict to orig model +# wrapped_model = ModuleProxyWrapper(wrapped_model) +# elif args.ddp_backend in {'apex'}: +# import apex +# wrapped_model = apex.parallel.DistributedDataParallel( +# module=model.to(device) +# ) +# # forward missing getattr and state_dict/load_state_dict to orig model +# wrapped_model = ModuleProxyWrapper(wrapped_model) +# elif args.ddp_backend in {"no_c10d", "legacy_ddp"}: +# wrapped_model = LegacyDistributedDataParallel( +# module=model.to(device), +# buffer_size=2 ** 28, +# process_group=process_group, +# ) +# # forward missing getattr and state_dict/load_state_dict to orig model +# wrapped_model = ModuleProxyWrapper(wrapped_model) +# else: +# raise ValueError("Unknown --ddp-backend: " + args.ddp_backend) + +# return wrapped_model diff --git a/MindChemistry/applications/Uni-Mol/unicore/models/unicore_model.py b/MindChemistry/applications/Uni-Mol/unicore/models/unicore_model.py new file mode 100644 index 0000000000000000000000000000000000000000..13a6ad93cb0cb421cce152df89d4479af753c26a --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/models/unicore_model.py @@ -0,0 +1,106 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +""" +Base classes for various unicore models. +""" +import logging + +import mindspore as ms +from mindspore.nn import Cell + +logger = logging.getLogger(__name__) + + +class BaseUnicoreModel(Cell): + """Base class for unicore models (MindSpore version).""" + + def __init__(self): + super().__init__() + + @classmethod + def add_args(cls, parser): + """Add model-specific arguments to the parser.""" + pass + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + raise NotImplementedError("Model must implement the build_model method") + + def extract_features(self, *args, **kwargs): + """Similar to *forward* but only return features.""" + return self(*args, **kwargs) + + def load_state_dict( + self, + state_dict, + strict=True, + model_args=None, + ): + """Copies parameters and buffers from *state_dict* into this module and + its descendants. + + Overrides the method in :class:`mindspore.nn.Cell`. 
+ """ + return super().load_state_dict(state_dict, strict) + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + + def _apply(m): + if hasattr(m, "set_num_updates") and m != self: + m.set_num_updates(num_updates) + + self.apply(_apply) +# import logging + +# import torch +# import torch.nn as nn + +# logger = logging.getLogger(__name__) + + +# class BaseUnicoreModel(nn.Module): +# """Base class for unicore models.""" + +# def __init__(self): +# super().__init__() + +# @classmethod +# def add_args(cls, parser): +# """Add model-specific arguments to the parser.""" +# pass + +# @classmethod +# def build_model(cls, args, task): +# """Build a new model instance.""" +# raise NotImplementedError("Model must implement the build_model method") + +# def extract_features(self, *args, **kwargs): +# """Similar to *forward* but only return features.""" +# return self(*args, **kwargs) + +# def load_state_dict( +# self, +# state_dict, +# strict=True, +# model_args = None, +# ): +# """Copies parameters and buffers from *state_dict* into this module and +# its descendants. + +# Overrides the method in :class:`nn.Module`. +# """ +# return super().load_state_dict(state_dict, strict) + +# def set_num_updates(self, num_updates): +# """State from trainer to pass along to model at every update.""" + +# def _apply(m): +# if hasattr(m, "set_num_updates") and m != self: +# m.set_num_updates(num_updates) + +# self.apply(_apply) diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/modules/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..5f95ee99e9ecd7ae8040d3129490816eff2a540f --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/__init__.py @@ -0,0 +1,14 @@ +"""isort:skip_file""" + +from .layer_norm import LayerNorm +from .rms_norm import RMSNorm +from .softmax_dropout import softmax_dropout +from .multihead_attention import SelfMultiheadAttention, CrossMultiheadAttention +from .transformer_encoder_layer import TransformerEncoderLayer +from .transformer_encoder import ( + TransformerEncoder, + init_bert_params, + relative_position_bucket, +) +from .transformer_decoder_layer import TransformerDecoderLayer +from .transformer_decoder import TransformerDecoder diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/layer_norm.py b/MindChemistry/applications/Uni-Mol/unicore/modules/layer_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..2053d06a8af721c90e3a59fa6a175deded248552 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/layer_norm.py @@ -0,0 +1,209 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
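+
+# NOTE (illustration, not part of the original file): a minimal usage sketch of
+# the LayerNorm Cell defined below, assuming a MindSpore 2.x runtime without the
+# fused extensions, so the plain ops.layer_norm fallback path is taken; the
+# tensor shapes are arbitrary example values.
+#
+#   import numpy as np
+#   import mindspore as ms
+#
+#   ln = LayerNorm(768)                                      # normalize over a last dim of size 768
+#   x = ms.Tensor(np.random.randn(2, 16, 768), ms.float32)
+#   y = ln(x)                                                # construct() dispatches to self.func
+#   assert y.shape == (2, 16, 768)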
+import mindspore as ms +import numbers +from mindspore import Parameter, ops, Tensor +from mindspore.common.initializer import initializer, One, Zero +from mindspore.nn import Cell + +try: + import unicore_fused_layernorm + import unicore_fused_layernorm_backward_gamma_beta + HAS_LAYER_NORM = True +except: + print("fused_layer_norm is not installed correctly") + HAS_LAYER_NORM = False + +# 检查设备是否支持融合层归一化 +device_target = ms.get_context("device_target") +if device_target not in ("GPU", "Ascend"): + HAS_LAYER_NORM = False +else: + HAS_LAYER_NORM = HAS_LAYER_NORM and device_target == "GPU" # 仅GPU支持 + + +class FusedLayerNormFastOp(Cell): + """基于nn.Cell实现的融合层归一化算子(自带反向传播)""" + def __init__(self): + super().__init__() + # 保存正向计算的中间变量(供反向传播使用) + self.mean = None + self.invvar = None + self.input_ = None + self.weight_ = None + self.bias_ = None + self.normalized_shape = None + self.eps = None + + def construct(self, input, weight, bias, normalized_shape, eps): + # 正向计算逻辑 + input = input.contiguous() + weight = weight.contiguous() + bias = bias.contiguous() + output, mean, invvar = unicore_fused_layernorm.forward( + input, normalized_shape, weight, bias, eps) + # 保存中间变量(反向传播需要) + self.mean = mean + self.invvar = invvar + self.input_ = input + self.weight_ = weight + self.bias_ = bias + self.normalized_shape = normalized_shape + self.eps = eps + return output + + def bprop(self, input, weight, bias, normalized_shape, eps, out, grad_output): + """定义反向传播(替代RegisterGradient)""" + # 从正向保存的变量中获取所需数据 + mean = self.mean + invvar = self.invvar + input_ = self.input_ + weight_ = self.weight_ + bias_ = self.bias_ + normalized_shape = self.normalized_shape + eps = self.eps + + # 计算各输入的梯度 + grad_input = unicore_fused_layernorm.backward( + grad_output.contiguous(), mean, invvar, + input_, normalized_shape, weight_, bias_, eps) + grad_weight, grad_bias = unicore_fused_layernorm_backward_gamma_beta.backward( + grad_output.contiguous(), mean, invvar, + input_, normalized_shape, weight_, bias_, eps) + + # 返回梯度(与输入参数一一对应,多余参数返回0梯度) + return (grad_input, grad_weight, grad_bias, + Tensor(0.0, dtype=ms.float32), # normalized_shape的梯度(无意义,返回0) + Tensor(0.0, dtype=ms.float32)) # eps的梯度(无意义,返回0) + + +FUSED_LAYER_NORM_SUPPORT_DIM = {64, 128, 192, 256, 320, 384, 512, 640, 768, 1024, 1280, 1536, 1792, 2048, 2560, 5120} + + +class LayerNorm(Cell): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(LayerNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = normalized_shape + self.eps = eps + assert elementwise_affine, "MindSpore LayerNorm here requires elementwise_affine=True" + + # 初始化权重和偏置 + self.weight = Parameter(initializer(One(), normalized_shape), name="weight") + self.bias = Parameter(initializer(Zero(), normalized_shape), name="bias") + + # 定义普通LayerNorm函数(使用原生接口) + def ms_layer_norm(input): + return ops.layer_norm( + input, self.normalized_shape, self.weight.astype(input.dtype), + self.bias.astype(input.dtype), self.eps + ) + + # 定义融合LayerNorm函数(使用自定义Cell算子) + def fused_layer_norm(input): + if device_target == "GPU": + fused_op = FusedLayerNormFastOp() + return fused_op( + input, self.weight.astype(input.dtype), self.bias.astype(input.dtype), + self.normalized_shape, self.eps + ) + else: + return ms_layer_norm(input) + + # 根据支持情况选择函数 + if HAS_LAYER_NORM and normalized_shape[0] in FUSED_LAYER_NORM_SUPPORT_DIM: + self.func = fused_layer_norm + else: + self.func = ms_layer_norm + + def 
reset_parameters(self): + """重置参数""" + self.weight.set_data(initializer(One(), self.weight.shape)) + self.bias.set_data(initializer(Zero(), self.bias.shape)) + + def construct(self, input): + """前向计算接口""" + return self.func(input) + + def extra_repr(self): + return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine=True' +# import torch +# import numbers +# from torch.nn.parameter import Parameter +# from torch.nn import init +# from torch.nn import functional as F + +# try: +# import unicore_fused_layernorm +# import unicore_fused_layernorm_backward_gamma_beta +# HAS_LAYER_NORM = True +# except: +# print("fused_layer_norm is not installed corrected") +# HAS_LAYER_NORM = False + +# if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7: +# HAS_LAYER_NORM = False + +# class FusedLayerNormFastFunction(torch.autograd.Function): +# @staticmethod +# def forward(ctx, input, weight, bias, normalized_shape, eps): +# ctx.normalized_shape = normalized_shape +# ctx.eps = eps +# input = input.contiguous() +# weight = weight.contiguous() +# bias = bias.contiguous() +# output, mean, invvar = unicore_fused_layernorm.forward( +# input, ctx.normalized_shape, weight, bias, ctx.eps) +# ctx.save_for_backward(input, weight, bias, mean, invvar) +# return output +# @staticmethod +# def backward(ctx, grad_output): +# input_, weight_, bias_, mean, invvar = ctx.saved_tensors +# grad_input = grad_weight = grad_bias = None +# grad_input = unicore_fused_layernorm.backward( +# grad_output.contiguous(), mean, invvar, +# input_, ctx.normalized_shape, +# weight_, bias_, ctx.eps) +# grad_weight, grad_bias = unicore_fused_layernorm_backward_gamma_beta.backward( +# grad_output.contiguous(), mean, invvar, +# input_, ctx.normalized_shape, +# weight_, bias_, ctx.eps) +# return grad_input, grad_weight, grad_bias, None, None + +# FUSED_LAYER_NORM_SUPPORT_DIM = set([64, 128, 192, 256, 320, 384, 512, 640, 768, 1024, 1280, 1536, 1792, 2048, 2560, 5120]) + +# class LayerNorm(torch.nn.Module): +# def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): +# super(LayerNorm, self).__init__() +# if isinstance(normalized_shape, numbers.Integral): +# normalized_shape = (normalized_shape,) +# self.normalized_shape = torch.Size(normalized_shape) +# self.eps = eps +# assert elementwise_affine +# self.weight = Parameter(torch.Tensor(*normalized_shape)) +# self.bias = Parameter(torch.Tensor(*normalized_shape)) +# self.reset_parameters() +# def torch_layer_norm(input): +# return F.layer_norm( +# input, self.normalized_shape, self.weight.type(input.dtype), self.bias.type(input.dtype), self.eps) +# def fused_layer_norm(input): +# if input.is_cuda: +# return FusedLayerNormFastFunction.apply( +# input, self.weight.type(input.dtype), self.bias.type(input.dtype), self.normalized_shape, self.eps) +# else: +# return F.layer_norm( +# input, self.normalized_shape, self.weight.type(input.dtype), self.bias.type(input.dtype), self.eps) +# self.func = torch_layer_norm if (not HAS_LAYER_NORM or normalized_shape[0] not in FUSED_LAYER_NORM_SUPPORT_DIM) else fused_layer_norm + +# def reset_parameters(self): +# init.ones_(self.weight) +# init.zeros_(self.bias) + +# def forward(self, input): +# return self.func(input) + +# def extra_repr(self): +# return '{normalized_shape}, eps={eps}, ' \ +# 'elementwise_affine=True'.format(**self.__dict__) diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/multihead_attention.py b/MindChemistry/applications/Uni-Mol/unicore/modules/multihead_attention.py new file mode 
100644 index 0000000000000000000000000000000000000000..d1e5654f8f0578774f9a2d26c748715234545071 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/multihead_attention.py @@ -0,0 +1,443 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + +import mindspore as ms +from mindspore import Tensor, nn, ops +from .softmax_dropout import softmax_dropout # 假设已适配为MindSpore版本 + + +class SelfMultiheadAttention(nn.Cell): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.1, + bias=True, + scaling_factor=1, + ): + super().__init__() + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = (self.head_dim * scaling_factor) ** -0.5 + + # MindSpore中nn.Dense对应PyTorch的nn.Linear,has_bias对应bias参数 + self.in_proj = nn.Dense(embed_dim, embed_dim * 3, has_bias=bias) + self.out_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias) + + def construct( # MindSpore中前向计算用construct替代forward + self, + query, + key_padding_mask: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + return_attn: bool = False, + ) -> Tensor: + + bsz, tgt_len, embed_dim = query.shape # size()替换为shape + assert embed_dim == self.embed_dim + + # chunk替换为split,dim替换为axis + q, k, v = self.in_proj(query).split(3, axis=-1) + + # view替换为reshape,transpose替换为swapaxes,contiguous保持不变 + q = ( + q.reshape(bsz, tgt_len, self.num_heads, self.head_dim) + .swapaxes(1, 2) + .contiguous() + .reshape(bsz * self.num_heads, -1, self.head_dim) + * self.scaling + ) + if k is not None: + k = ( + k.reshape(bsz, -1, self.num_heads, self.head_dim) + .swapaxes(1, 2) + .contiguous() + .reshape(bsz * self.num_heads, -1, self.head_dim) + ) + if v is not None: + v = ( + v.reshape(bsz, -1, self.num_heads, self.head_dim) + .swapaxes(1, 2) + .contiguous() + .reshape(bsz * self.num_heads, -1, self.head_dim) + ) + + assert k is not None + src_len = k.shape[1] # size()替换为shape + + # 处理key_padding_mask的维度检查 + if key_padding_mask is not None and key_padding_mask.ndim == 0: # dim替换为ndim + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + # torch.bmm替换为ops.bmm + attn_weights = ops.bmm(q, k.swapaxes(1, 2)) # transpose替换为swapaxes + + assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + if key_padding_mask is not None: + # 掩码填充:masked_fill_替换为masked_fill(MindSpore非in-place操作) + attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).astype(ms.bool_), # torch.bool替换为ms.bool_ + ms.Tensor(float("-inf"), dtype=attn_weights.dtype) + ) + attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len) + + if not return_attn: + attn = softmax_dropout( + attn_weights, self.dropout, self.training, bias=attn_bias, + ) + else: + attn_weights += attn_bias + attn = softmax_dropout( + attn_weights, self.dropout, self.training, inplace=False, + ) + + o = ops.bmm(attn, v) + assert list(o.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + + o = ( + o.reshape(bsz, self.num_heads, tgt_len, self.head_dim) + .swapaxes(1, 2) + .contiguous() + .reshape(bsz, tgt_len, embed_dim) + ) + o = 
self.out_proj(o) + if not return_attn: + return o + else: + return o, attn_weights, attn + + +class CrossMultiheadAttention(nn.Cell): + def __init__( + self, + embed_dim, + num_heads, + dropout=0.1, + bias=True, + scaling_factor=1, + ): + super().__init__() + self.embed_dim = embed_dim + + self.num_heads = num_heads + self.dropout = dropout + + self.head_dim = embed_dim // num_heads + assert ( + self.head_dim * num_heads == self.embed_dim + ), "embed_dim must be divisible by num_heads" + self.scaling = (self.head_dim * scaling_factor) ** -0.5 + + # 替换nn.Linear为nn.Dense + self.q_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias) + self.k_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias) + self.v_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias) + + self.out_proj = nn.Dense(embed_dim, embed_dim, has_bias=bias) + + def construct( # 前向计算接口替换为construct + self, + query, + key, + value, + key_padding_mask: Optional[Tensor] = None, + attn_bias: Optional[Tensor] = None, + ) -> Tensor: + + bsz, tgt_len, embed_dim = query.shape # size()替换为shape + assert embed_dim == self.embed_dim + + q = self.q_proj(query) + k = self.k_proj(key) + v = self.v_proj(value) + + q = ( + q.reshape(bsz, tgt_len, self.num_heads, self.head_dim) + .swapaxes(1, 2) + .contiguous() + .reshape(bsz * self.num_heads, -1, self.head_dim) + * self.scaling + ) + if k is not None: + k = ( + k.reshape(bsz, -1, self.num_heads, self.head_dim) + .swapaxes(1, 2) + .contiguous() + .reshape(bsz * self.num_heads, -1, self.head_dim) + ) + if v is not None: + v = ( + v.reshape(bsz, -1, self.num_heads, self.head_dim) + .swapaxes(1, 2) + .contiguous() + .reshape(bsz * self.num_heads, -1, self.head_dim) + ) + + assert k is not None + src_len = k.shape[1] # size()替换为shape + + if key_padding_mask is not None and key_padding_mask.ndim == 0: # dim替换为ndim + key_padding_mask = None + + if key_padding_mask is not None: + assert key_padding_mask.shape[0] == bsz + assert key_padding_mask.shape[1] == src_len + + attn_weights = ops.bmm(q, k.swapaxes(1, 2)) # transpose替换为swapaxes,bmm替换为ops.bmm + + assert list(attn_weights.shape) == [bsz * self.num_heads, tgt_len, src_len] + + if key_padding_mask is not None: + # 掩码填充处理 + attn_weights = attn_weights.reshape(bsz, self.num_heads, tgt_len, src_len) + attn_weights = attn_weights.masked_fill( + key_padding_mask.unsqueeze(1).unsqueeze(2).astype(ms.bool_), # 类型转换适配 + ms.Tensor(float("-inf"), dtype=attn_weights.dtype) + ) + attn_weights = attn_weights.reshape(bsz * self.num_heads, tgt_len, src_len) + + attn = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias) + + o = ops.bmm(attn, v) + assert list(o.shape) == [bsz * self.num_heads, tgt_len, self.head_dim] + + o = ( + o.reshape(bsz, self.num_heads, tgt_len, self.head_dim) + .swapaxes(1, 2) + .contiguous() + .reshape(bsz, tgt_len, embed_dim) + ) + o = self.out_proj(o) + return o +# from typing import Dict, Optional + +# import torch +# from torch import Tensor, nn +# from .softmax_dropout import softmax_dropout + + +# class SelfMultiheadAttention(nn.Module): +# def __init__( +# self, +# embed_dim, +# num_heads, +# dropout=0.1, +# bias=True, +# scaling_factor=1, +# ): +# super().__init__() +# self.embed_dim = embed_dim + +# self.num_heads = num_heads +# self.dropout = dropout + +# self.head_dim = embed_dim // num_heads +# assert ( +# self.head_dim * num_heads == self.embed_dim +# ), "embed_dim must be divisible by num_heads" +# self.scaling = (self.head_dim * scaling_factor) ** -0.5 + +# self.in_proj = nn.Linear(embed_dim, embed_dim * 3, 
bias=bias) +# self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + +# def forward( +# self, +# query, +# key_padding_mask: Optional[Tensor] = None, +# attn_bias: Optional[Tensor] = None, +# return_attn: bool = False, +# ) -> Tensor: + +# bsz, tgt_len, embed_dim = query.size() +# assert embed_dim == self.embed_dim + +# q, k, v = self.in_proj(query).chunk(3, dim=-1) + +# q = ( +# q.view(bsz, tgt_len, self.num_heads, self.head_dim) +# .transpose(1, 2) +# .contiguous() +# .view(bsz * self.num_heads, -1, self.head_dim) +# * self.scaling +# ) +# if k is not None: +# k = ( +# k.view(bsz, -1, self.num_heads, self.head_dim) +# .transpose(1, 2) +# .contiguous() +# .view(bsz * self.num_heads, -1, self.head_dim) +# ) +# if v is not None: +# v = ( +# v.view(bsz, -1, self.num_heads, self.head_dim) +# .transpose(1, 2) +# .contiguous() +# .view(bsz * self.num_heads, -1, self.head_dim) +# ) + +# assert k is not None +# src_len = k.size(1) + +# # This is part of a workaround to get around fork/join parallelism +# # not supporting Optional types. +# if key_padding_mask is not None and key_padding_mask.dim() == 0: +# key_padding_mask = None + +# if key_padding_mask is not None: +# assert key_padding_mask.size(0) == bsz +# assert key_padding_mask.size(1) == src_len + +# attn_weights = torch.bmm(q, k.transpose(1, 2)) + +# assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + +# if key_padding_mask is not None: +# # don't attend to padding symbols +# attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) +# attn_weights.masked_fill_( +# key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") +# ) +# attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + +# if not return_attn: +# attn = softmax_dropout( +# attn_weights, self.dropout, self.training, bias=attn_bias, +# ) +# else: +# attn_weights += attn_bias +# attn = softmax_dropout( +# attn_weights, self.dropout, self.training, inplace=False, +# ) + +# o = torch.bmm(attn, v) +# assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + +# o = ( +# o.view(bsz, self.num_heads, tgt_len, self.head_dim) +# .transpose(1, 2) +# .contiguous() +# .view(bsz, tgt_len, embed_dim) +# ) +# o = self.out_proj(o) +# if not return_attn: +# return o +# else: +# return o, attn_weights, attn + + +# class CrossMultiheadAttention(nn.Module): +# def __init__( +# self, +# embed_dim, +# num_heads, +# dropout=0.1, +# bias=True, +# scaling_factor=1, +# ): +# super().__init__() +# self.embed_dim = embed_dim + +# self.num_heads = num_heads +# self.dropout = dropout + +# self.head_dim = embed_dim // num_heads +# assert ( +# self.head_dim * num_heads == self.embed_dim +# ), "embed_dim must be divisible by num_heads" +# self.scaling = (self.head_dim * scaling_factor) ** -0.5 + +# self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias) +# self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias) +# self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + +# self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + +# def forward( +# self, +# query, +# key, +# value, +# key_padding_mask: Optional[Tensor] = None, +# attn_bias: Optional[Tensor] = None, +# ) -> Tensor: + +# bsz, tgt_len, embed_dim = query.size() +# assert embed_dim == self.embed_dim + +# q = self.q_proj(query) +# k = self.k_proj(key) +# v = self.v_proj(value) + +# q = ( +# q.view(bsz, tgt_len, self.num_heads, self.head_dim) +# .transpose(1, 2) +# .contiguous() +# .view(bsz * self.num_heads, -1, self.head_dim) +# * self.scaling 
+# ) +# if k is not None: +# k = ( +# k.view(bsz, -1, self.num_heads, self.head_dim) +# .transpose(1, 2) +# .contiguous() +# .view(bsz * self.num_heads, -1, self.head_dim) +# ) +# if v is not None: +# v = ( +# v.view(bsz, -1, self.num_heads, self.head_dim) +# .transpose(1, 2) +# .contiguous() +# .view(bsz * self.num_heads, -1, self.head_dim) +# ) + +# assert k is not None +# src_len = k.size(1) + +# # This is part of a workaround to get around fork/join parallelism +# # not supporting Optional types. +# if key_padding_mask is not None and key_padding_mask.dim() == 0: +# key_padding_mask = None + +# if key_padding_mask is not None: +# assert key_padding_mask.size(0) == bsz +# assert key_padding_mask.size(1) == src_len + +# attn_weights = torch.bmm(q, k.transpose(1, 2)) + +# assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len] + +# if key_padding_mask is not None: +# # don't attend to padding symbols +# attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len) +# attn_weights.masked_fill_( +# key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), float("-inf") +# ) +# attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len) + +# attn = softmax_dropout(attn_weights, self.dropout, self.training, bias=attn_bias) + +# o = torch.bmm(attn, v) +# assert list(o.size()) == [bsz * self.num_heads, tgt_len, self.head_dim] + +# o = ( +# o.view(bsz, self.num_heads, tgt_len, self.head_dim) +# .transpose(1, 2) +# .contiguous() +# .view(bsz, tgt_len, embed_dim) +# ) +# o = self.out_proj(o) +# return o diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/rms_norm.py b/MindChemistry/applications/Uni-Mol/unicore/modules/rms_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..fa38bcf6a5f1e9485a1957fdc0cd3b6d0d0bea3c --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/rms_norm.py @@ -0,0 +1,250 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
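+
+# NOTE (illustration, not part of the original file): a reference definition of
+# RMS normalization, written with plain MindSpore ops, that the fallback path in
+# this file is meant to reproduce. It assumes normalization over the last
+# dimension with an elementwise affine weight, matching the RMSNorm Cell below.
+#
+#   import mindspore.ops as ops
+#
+#   def reference_rms_norm(x, weight, eps=1e-5):
+#       # scale by the reciprocal root-mean-square over the last dimension
+#       mean_square = ops.mean(ops.square(x), axis=-1, keep_dims=True)
+#       return x * ops.rsqrt(mean_square + eps) * weight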
+import mindspore as ms +import numbers +from mindspore import Parameter, ops +from mindspore.nn import Cell +from mindspore.common.initializer import initializer, One + + +try: + import unicore_fused_rmsnorm + import unicore_fused_rmsnorm_backward_gamma + HAS_RMS_NORM = True +except: + print("fused_rms_norm is not installed correctly") + HAS_RMS_NORM = False + +# 适配MindSpore的设备检查:仅GPU支持融合RMSNorm(参考原逻辑) +device_target = ms.get_context("device_target") +if device_target != "GPU" or (device_target == "GPU" and not ms.cuda.is_available()): # MindSpore中检查GPU可用性 + HAS_RMS_NORM = False + + +class FusedRMSNormFastOp(ms.nn.Cell): + """MindSpore自定义融合RMSNorm算子,对应PyTorch的FusedRMSNormFastFunction""" + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps): + ctx.normalized_shape = normalized_shape + ctx.eps = eps + input = input.contiguous() + weight = weight.contiguous() + output, invvar = unicore_fused_rmsnorm.forward( + input, ctx.normalized_shape, weight, ctx.eps + ) + ctx.save_for_backward(input, weight, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_, weight_, invvar = ctx.saved_tensors + grad_input = unicore_fused_rmsnorm.backward( + grad_output.contiguous(), + invvar, + input_, + ctx.normalized_shape, + weight_, + ctx.eps, + ) + grad_weight = unicore_fused_rmsnorm_backward_gamma.backward( + grad_output.contiguous(), + invvar, + input_, + ctx.normalized_shape, + weight_, + ctx.eps, + ) + return grad_input, grad_weight, None, None + + +FUSED_RMS_NORM_SUPPORT_DIM = { + 64, 128, 192, 256, 320, 384, 512, 640, 768, 1024, 1280, 1536, 1792, 2048, 2560, 5120 +} + + +class RMSNorm(Cell): + def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): + super(RMSNorm, self).__init__() + if isinstance(normalized_shape, numbers.Integral): + normalized_shape = (normalized_shape,) + self.normalized_shape = normalized_shape # 替换torch.Size为元组(MindSpore中直接使用元组表示形状) + self.eps = eps + assert elementwise_affine, "MindSpore RMSNorm requires elementwise_affine=True" + + # 初始化权重参数(对应PyTorch的Parameter) + self.weight = Parameter(initializer(One(), normalized_shape), name="weight") + self.reset_parameters() + + # 定义普通RMSNorm函数(对应PyTorch的F.rms_norm) + def ms_rms_norm(input): + return ops.rms_norm( + input, + self.normalized_shape, + self.weight.astype(input.dtype), + self.eps + ) + + # 定义融合RMSNorm函数(使用自定义算子) + def fused_rms_norm(input): + if device_target == "GPU": # 仅GPU使用融合算子(对应原input.is_cuda) + return FusedRMSNormFastOp.apply( + input, + self.weight.astype(input.dtype), + self.normalized_shape, + self.eps, + ) + else: + return ms_rms_norm(input) + + # 根据支持情况选择函数(与原逻辑对齐) + self.func = ( + ms_rms_norm + if ( + not HAS_RMS_NORM + or normalized_shape[0] not in FUSED_RMS_NORM_SUPPORT_DIM + ) + else fused_rms_norm + ) + + def reset_parameters(self): + """重置权重参数为1(对应PyTorch的init.ones_)""" + self.weight.set_data(initializer(One(), self.weight.shape)) + + def construct(self, input): + """MindSpore前向计算接口(替代PyTorch的forward)""" + return self.func(input) + + def extra_repr(self): + return f"{self.normalized_shape}, eps={self.eps}, elementwise_affine=True" +# import torch +# import numbers +# from torch.nn.parameter import Parameter +# from torch.nn import init +# from torch.nn import functional as F + +# try: +# import unicore_fused_rmsnorm +# import unicore_fused_rmsnorm_backward_gamma + +# HAS_RMS_NORM = True +# except: +# print("fused_rms_norm is not installed corrected") +# HAS_RMS_NORM = False + +# if not torch.cuda.is_available() or 
torch.cuda.get_device_capability()[0] < 7: +# HAS_RMS_NORM = False + + +# class FusedRMSNormFastFunction(torch.autograd.Function): +# @staticmethod +# def forward(ctx, input, weight, normalized_shape, eps): +# ctx.normalized_shape = normalized_shape +# ctx.eps = eps +# input = input.contiguous() +# weight = weight.contiguous() +# output, invvar = unicore_fused_rmsnorm.forward( +# input, ctx.normalized_shape, weight, ctx.eps +# ) +# ctx.save_for_backward(input, weight, invvar) +# return output + +# @staticmethod +# def backward(ctx, grad_output): +# input_, weight_, invvar = ctx.saved_tensors +# grad_input = grad_weight = None +# grad_input = unicore_fused_rmsnorm.backward( +# grad_output.contiguous(), +# invvar, +# input_, +# ctx.normalized_shape, +# weight_, +# ctx.eps, +# ) +# grad_weight = unicore_fused_rmsnorm_backward_gamma.backward( +# grad_output.contiguous(), +# invvar, +# input_, +# ctx.normalized_shape, +# weight_, +# ctx.eps, +# ) +# return grad_input, grad_weight, None, None + + +# FUSED_RMS_NORM_SUPPORT_DIM = set( +# [ +# 64, +# 128, +# 192, +# 256, +# 320, +# 384, +# 512, +# 640, +# 768, +# 1024, +# 1280, +# 1536, +# 1792, +# 2048, +# 2560, +# 5120, +# ] +# ) + + +# class RMSNorm(torch.nn.Module): +# def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True): +# super(RMSNorm, self).__init__() +# if isinstance(normalized_shape, numbers.Integral): +# normalized_shape = (normalized_shape,) +# self.normalized_shape = torch.Size(normalized_shape) +# self.eps = eps +# assert elementwise_affine +# self.weight = Parameter(torch.Tensor(*normalized_shape)) +# self.reset_parameters() + +# def torch_rms_norm(input): +# return F.rms_norm( +# input, +# self.normalized_shape, +# self.weight.type(input.dtype), +# self.eps, +# ) + +# def fused_rms_norm(input): +# if input.is_cuda: +# return FusedRMSNormFastFunction.apply( +# input, +# self.weight.type(input.dtype), +# self.normalized_shape, +# self.eps, +# ) +# else: +# return F.rms_norm( +# input, +# self.normalized_shape, +# self.weight.type(input.dtype), +# self.eps, +# ) + +# self.func = ( +# torch_rms_norm +# if ( +# not HAS_RMS_NORM +# or normalized_shape[0] not in FUSED_RMS_NORM_SUPPORT_DIM +# ) +# else fused_rms_norm +# ) + +# def reset_parameters(self): +# init.ones_(self.weight) + +# def forward(self, input): +# return self.func(input) + +# def extra_repr(self): +# return "{normalized_shape}, eps={eps}, " "elementwise_affine=True".format( +# **self.__dict__ +# ) diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/softmax_dropout.py b/MindChemistry/applications/Uni-Mol/unicore/modules/softmax_dropout.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2d3bbde9b16ba325f3161a62f7b02c75e1faeb --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/softmax_dropout.py @@ -0,0 +1,291 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
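+
+# NOTE (illustration, not part of the original file): on devices without the
+# fused kernel, softmax_dropout() below reduces to composing the plain ops. A
+# behaviourally equivalent sketch of that fallback path:
+#
+#   import mindspore.ops as ops
+#
+#   def reference_softmax_dropout(x, dropout_prob, is_training=True, mask=None, bias=None):
+#       if mask is not None:
+#           x = x + mask
+#       if bias is not None:
+#           x = x + bias
+#       # softmax over the last axis, then dropout (active only in training)
+#       return ops.dropout(ops.softmax(x, axis=-1), p=dropout_prob, training=is_training)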
+import mindspore as ms +import mindspore.ops as ops +from mindspore import Tensor + +try: + import unicore_fused_softmax_dropout + HAS_SOFTMAX = True +except: + print("fused_softmax is not installed correctly") + HAS_SOFTMAX = False + +# 适配MindSpore的设备检查:仅GPU支持融合softmax(参考原逻辑) +device_target = ms.get_context("device_target") +if device_target != "GPU" or (device_target == "GPU" and not ms.cuda.is_available()): + HAS_SOFTMAX = False + + +class SoftmaxDropoutFast(ms.nn.Cell): + """MindSpore自定义融合SoftmaxDropout算子,对应PyTorch的SoftmaxDropoutFast""" + @staticmethod + def forward(ctx, is_training, inputs, mask, bias, dropout_prob): + ( + dropout_results, + dropout_mask, + softmax_results, + ) = unicore_fused_softmax_dropout.forward( + is_training, inputs, mask, bias, dropout_prob, None + ) + if is_training: + ctx.dropout_prob = dropout_prob + ctx.save_for_backward(softmax_results, dropout_mask) + ctx.has_bias = bias is not None and bias.requires_grad + if ctx.has_bias: + ctx.bias_batch_dim = bias.shape[0] + return dropout_results + + @staticmethod + def backward(ctx, grad_output): + softmax_results, dropout_mask = ctx.saved_tensors + dropout_prob = ctx.dropout_prob + grad_output = grad_output.contiguous() + grad_input = unicore_fused_softmax_dropout.backward( + grad_output, softmax_results, dropout_mask, dropout_prob + ) + if ctx.has_bias: + grad_bias = grad_input.view( + -1, ctx.bias_batch_dim, grad_input.shape[-2], grad_input.shape[-1] + ).sum(axis=0) # MindSpore中sum的维度参数为axis + else: + grad_bias = None + return None, grad_input, None, grad_bias, None + + +def _check_mask(mask, input): + try: + assert mask.dtype == input.dtype, "mask and input must have the same dtype" + assert len(mask.shape) == len(input.shape), "wrong length of mask.shape" + assert ( + mask.shape[-3] == 1 or mask.shape[-3] == input.shape[-3] + ), "mask.shape[-3] must be 1 or input.shape[-3]" + if mask.shape[-3] == 1: + assert mask.shape[-2] == 1, "when mask.shape[-3] == 1, mask.shape[-2] must be 1" + else: + assert ( + mask.shape[-2] == 1 or mask.shape[-2] == input.shape[-2] + ), "mask.shape[-2] must be 1 or input.shape[-2]" + return True + except: + return False + + +def _check_bias(bias, input): + try: + assert bias.dtype == input.dtype, "bias and input must have the same dtype" + assert len(bias.shape) == len(input.shape), "wrong length of bias.shape" + assert bias.shape[-1] == input.shape[-1], "bias.shape[-1] must be input.shape[-1]" + assert bias.shape[-2] == input.shape[-2], "bias.shape[-2] must be input.shape[-2]" + len_shape = len(input.shape) + if len_shape > 3: + # head dim should be the same + assert ( + bias.shape[-3] == input.shape[-3] + ), "bias.shape[-3] must be input.shape[-3]" + offset = 3 + else: + offset = 2 + prev_non_one = True + for i in range(len_shape - offset - 1, -1, -1): + if prev_non_one: + assert ( + bias.shape[i] == input.shape[i] or bias.shape[i] == 1 + ), f"bias.shape[{i}] must be input.shape[{i}] or 1" + else: + assert bias.shape[i] == 1, f"bias.shape[{i}] must be 1" + prev_non_one = bias.shape[i] != 1 + return True + except: + return False + + +def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True): + """softmax dropout, and mask, bias are optional. + Args: + input (mindspore.Tensor): input tensor + dropout_prob (float): dropout probability + is_training (bool, optional): is in training or not. Defaults to True. + mask (mindspore.Tensor, optional): the mask tensor, use as input + mask . Defaults to None. 
+ bias (mindspore.Tensor, optional): the bias tensor, use as input + bias . Defaults to None. + + Returns: + mindspore.Tensor: the result after softmax + """ + input = input.contiguous() + if not inplace: + # MindSpore中clone替换为copy + input = input.copy() + # 检查是否在GPU且支持融合操作(对应原input.is_cuda) + if device_target == "GPU" and HAS_SOFTMAX: + input_size = input.shape # size()替换为shape + if mask is not None: + if _check_mask(mask, input): + mask = mask.contiguous().reshape(-1, mask.shape[-2], mask.shape[-1]) # view替换为reshape + else: + input += mask + mask = None + if bias is not None: + if _check_bias(bias, input): + bias = bias.contiguous().reshape(-1, input_size[-2], input_size[-1]) # view替换为reshape + else: + input += bias + bias = None + input = input.reshape(-1, input_size[-2], input_size[-1]) # view替换为reshape + if dropout_prob <= 0.0 or input_size[-1] <= 1024: + return SoftmaxDropoutFast.apply( + is_training, input, mask, bias, dropout_prob + ).reshape(*input_size) # view替换为reshape + else: + # MindSpore的dropout参数为p和training + return ops.dropout(SoftmaxDropoutFast.apply( + is_training, input, mask, bias, 0.0 + ).reshape(*input_size), p=dropout_prob, training=is_training) + else: + if mask is not None: + input += mask + if bias is not None: + input += bias + # 替换F.softmax和F.dropout为MindSpore的ops + return ops.dropout(ops.softmax(input, axis=-1), p=dropout_prob, training=is_training) +# import torch +# import torch.nn.functional as F + +# try: +# import unicore_fused_softmax_dropout +# HAS_SOFTMAX = True +# except: +# print("fused_softmax is not installed corrected") +# HAS_SOFTMAX = False + +# if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7: +# HAS_SOFTMAX = False + +# class SoftmaxDropoutFast(torch.autograd.Function): +# @staticmethod +# def forward(ctx, is_training, inputs, mask, bias, dropout_prob): +# ( +# dropout_results, +# dropout_mask, +# softmax_results, +# ) = unicore_fused_softmax_dropout.forward( +# is_training, inputs, mask, bias, dropout_prob, None +# ) +# if is_training: +# ctx.dropout_prob = dropout_prob +# ctx.save_for_backward(softmax_results, dropout_mask) +# ctx.has_bias = bias is not None and bias.requires_grad +# if ctx.has_bias: +# ctx.bias_batch_dim = bias.shape[0] +# return dropout_results + +# @staticmethod +# def backward(ctx, grad_output): +# softmax_results, dropout_mask = ctx.saved_tensors +# dropout_prob = ctx.dropout_prob +# grad_output = grad_output.contiguous() +# grad_input = unicore_fused_softmax_dropout.backward( +# grad_output, softmax_results, dropout_mask, dropout_prob +# ) +# if ctx.has_bias: +# grad_bias = grad_input.view( +# -1, ctx.bias_batch_dim, grad_input.shape[-2], grad_input.shape[-1] +# ).sum(dim=0) +# else: +# grad_bias = None +# return None, grad_input, None, grad_bias, None + + +# def _check_mask(mask, input): +# try: +# assert mask.dtype == input.dtype, "mask and input must have the same dtype" +# assert len(mask.shape) == len(input.shape), "wrong length of mask.shape" +# assert ( +# mask.shape[-3] == 1 or mask.shape[-3] == input.shape[-3] +# ), "mask.shape[-3] must be 1 or input.shape[-3]" +# if mask.shape[-3] == 1: +# assert mask.shape[-2] == 1, "when mask.shape[-3] == 1, mask.shape[-2] must be 1" +# else: +# assert ( +# mask.shape[-2] == 1 or mask.shape[-2] == input.shape[-2] +# ), "mask.shape[-2] must be 1 or input.shape[-2]" +# return True +# except: +# return False + + +# def _check_bias(bias, input): +# try: +# assert bias.dtype == input.dtype, "bias and input must have the same dtype" +# assert 
len(bias.shape) == len(input.shape), "wrong length of bias.shape" +# assert bias.shape[-1] == input.shape[-1], "bias.shape[-1] must be input.shape[-1]" +# assert bias.shape[-2] == input.shape[-2], "bias.shape[-2] must be input.shape[-2]" +# len_shape = len(input.shape) +# if len_shape > 3: +# # head dim should be the same +# assert ( +# bias.shape[-3] == input.shape[-3] +# ), "bias.shape[-3] must be input.shape[-3]" +# offset = 3 +# else: +# offset = 2 +# prev_non_one = True +# for i in range(len_shape - offset - 1, -1, -1): +# if prev_non_one: +# assert ( +# bias.shape[i] == input.shape[i] or bias.shape[i] == 1 +# ), "bias.shape[{}] must be input.shape[{}] or 1".format(i, i) +# else: +# assert bias.shape[i] == 1, "bias.shape[{}] must be 1".format(i) +# prev_non_one = bias.shape[i] != 1 +# return True +# except: +# return False + + +# def softmax_dropout(input, dropout_prob, is_training=True, mask=None, bias=None, inplace=True): +# """softmax dropout, and mask, bias are optional. +# Args: +# input (torch.Tensor): input tensor +# dropout_prob (float): dropout probability +# is_training (bool, optional): is in training or not. Defaults to True. +# mask (torch.Tensor, optional): the mask tensor, use as input + mask . Defaults to None. +# bias (torch.Tensor, optional): the bias tensor, use as input + bias . Defaults to None. + +# Returns: +# torch.Tensor: the result after softmax +# """ +# input = input.contiguous() +# if not inplace: +# # copy a input for non-inplace case +# input = input.clone() +# if input.is_cuda and HAS_SOFTMAX: +# input_size = input.size() +# if mask is not None: +# if _check_mask(mask, input): +# mask = mask.contiguous().view(-1, mask.shape[-2], mask.shape[-1]) +# else: +# input += mask +# mask = None +# if bias is not None: +# if _check_bias(bias, input): +# bias = bias.contiguous().view(-1, input_size[-2], input_size[-1]) +# else: +# input += bias +# bias = None +# input = input.view(-1, input_size[-2], input_size[-1]) +# if dropout_prob <= 0.0 or input_size[-1] <= 1024: +# return SoftmaxDropoutFast.apply( +# is_training, input, mask, bias, dropout_prob +# ).view(*input_size) +# else: +# return F.dropout(SoftmaxDropoutFast.apply( +# is_training, input, mask, bias, 0.0 +# ).view(*input_size), p=dropout_prob, training=is_training) +# else: +# if mask is not None: +# input += mask +# if bias is not None: +# input += bias +# return F.dropout(F.softmax(input, dim=-1), p=dropout_prob, training=is_training) diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_decoder.py b/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..038b31a57263b682582a8e178a344c9522f3a1d9 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_decoder.py @@ -0,0 +1,360 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from . 
import TransformerDecoderLayer, LayerNorm # 假设已适配为MindSpore版本 +from .transformer_encoder import relative_position_bucket # 假设已适配为MindSpore版本 + + +def fill_with_neg_inf(t): + """替换torch.Tensor.fill_为MindSpore的fill(无in-place下划线)""" + return t.fill(float("-inf")) + + +def bulid_future_mask(seq_len): + """构建未来掩码,替换torch.triu和torch.zeros""" + return ops.triu( + fill_with_neg_inf(ops.zeros((seq_len, seq_len), dtype=ms.float32)), 1 + ) + + +class TransformerDecoder(nn.Cell): + def __init__( + self, + decoder_layers: int = 6, + embed_dim: int = 768, + ffn_embed_dim: int = 3072, + attention_heads: int = 8, + emb_dropout: float = 0.1, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + max_seq_len: int = 256, + activation_fn: str = "gelu", + rel_pos: bool = True, + rel_pos_bins: int = 32, + max_rel_pos: int = 128, + post_ln: bool = False, + auto_regressive: bool = True, + ) -> None: + + super().__init__() + self.emb_dropout = emb_dropout + self.max_seq_len = max_seq_len + self.embed_dim = embed_dim + self.attention_heads = attention_heads + self.emb_layer_norm = LayerNorm(self.embed_dim) + self.auto_regressive = auto_regressive + + # 初始化未来掩码 + if self.auto_regressive: + self._future_mask = bulid_future_mask(self.max_seq_len) + else: + self._future_mask = None + + # 最终层归一化 + if not post_ln: + self.final_layer_norm = LayerNorm(self.embed_dim) + else: + self.final_layer_norm = None + + # 替换nn.ModuleList为nn.CellList + self.layers = nn.CellList( + [ + TransformerDecoderLayer( + embed_dim=self.embed_dim, + ffn_embed_dim=ffn_embed_dim, + attention_heads=attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + post_ln=post_ln, + ) + for _ in range(decoder_layers) + ] + ) + + self.rel_pos = rel_pos + if self.rel_pos: + assert rel_pos_bins % 2 == 0 + self.rel_pos_bins = rel_pos_bins + self.max_rel_pos = max_rel_pos + # 替换nn.Embedding为mindspore.nn.Embedding + self.relative_attention_bias = nn.Embedding( + self.rel_pos_bins, self.attention_heads + ) + seq_len = self.max_seq_len + # 替换torch.arange为mindspore.ops.arange + context_position = ops.arange(seq_len, dtype=ms.int64)[:, None] + memory_position = ops.arange(seq_len, dtype=ms.int64)[None, :] + relative_position = memory_position - context_position + self.rp_bucket = relative_position_bucket( + relative_position, + num_buckets=self.rel_pos_bins, + max_distance=self.max_rel_pos + ) + self.rp_bucket -= self.rp_bucket.min() + + def get_rel_pos_bias(self, x): + # 适配设备和序列长度 + if self.rp_bucket.device != x.device: + self.rp_bucket = self.rp_bucket.to(x.device) + seq_len = x.shape[1] + rp_bucket = self.rp_bucket[:seq_len, :seq_len] + # 替换F.embedding为ops.embedding + values = ops.embedding(rp_bucket, self.relative_attention_bias.weight) + values = values.permute(2, 0, 1) # 维度交换保持一致 + return values.contiguous() + + def get_future_mask(self, x, attn_mask): + if not self.auto_regressive: + return attn_mask + # 设备和数据类型适配 + if self._future_mask.device != x.device: + self._future_mask = self._future_mask.to(x.device) + if self._future_mask.dtype != x.dtype: + self._future_mask = self._future_mask.astype(x.dtype) # 替换type_as为astype + + if attn_mask is None: + ret = self._future_mask[:x.shape[1], :x.shape[1]] + # 替换repeat逻辑(保持维度一致) + ret = ret.contiguous().unsqueeze(0).repeat( + x.shape[0] * self.attention_heads, 1, 1 + ) + return ret + else: + assert list(attn_mask.shape) == [ + x.shape[0] * self.attention_heads, x.shape[1], x.shape[1] + ] + 
return attn_mask + self._future_mask[:x.shape[1], :x.shape[1]] + + def construct( # 替换forward为construct + self, + emb, + encoder_out: Optional[ms.Tensor] = None, + padding_mask: Optional[ms.Tensor] = None, + encoder_padding_mask: Optional[ms.Tensor] = None, + attn_mask: Optional[ms.Tensor] = None, + encoder_attn_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + + seq_len = emb.shape[1] + x = self.emb_layer_norm(emb) + # 替换F.dropout为ops.dropout + x = ops.dropout(x, p=self.emb_dropout, training=self.training) + + # 处理padding mask + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).astype(x.dtype)) # 替换type_as为astype + + # 计算相对位置偏置 + rel_pos_bias = self.get_rel_pos_bias(x).repeat( + x.shape[0], 1, 1 + ) if self.rel_pos else None + + # 合并注意力掩码和相对位置偏置 + if attn_mask is None: + attn_mask = rel_pos_bias + elif rel_pos_bias is not None: + attn_mask += rel_pos_bias + + # 应用未来掩码(自回归模式) + if self.auto_regressive: + attn_mask = self.get_future_mask(x, attn_mask) + + # 合并padding mask和attn_mask + if attn_mask is not None and padding_mask is not None: + attn_mask = attn_mask.view(x.shape[0], -1, seq_len, seq_len) + # 替换masked_fill_为masked_fill(无in-place操作) + attn_mask = attn_mask.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).astype(ms.bool_), # 替换to(torch.bool)为astype(ms.bool_) + float("-inf") + ) + attn_mask = attn_mask.view(-1, seq_len, seq_len) + padding_mask = None + + # 逐层处理 + for layer in self.layers: + x = layer( + x, + encoder_out=encoder_out, + padding_mask=padding_mask, + attn_bias=attn_mask, + encoder_padding_mask=encoder_padding_mask, + encoder_attn_bias=encoder_attn_mask + ) + + # 最终层归一化 + if self.final_layer_norm is not None: + x = self.final_layer_norm(x) + + return x +# from typing import Optional +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from . 
import TransformerDecoderLayer, LayerNorm +# from .transformer_encoder import relative_position_bucket + + +# def fill_with_neg_inf(t): +# return t.fill_(float("-inf")) + + +# def bulid_future_mask(seq_len): +# return torch.triu( +# fill_with_neg_inf(torch.zeros([seq_len, seq_len])), 1 +# ) + + +# class TransformerDecoder(nn.Module): +# def __init__( +# self, +# decoder_layers: int = 6, +# embed_dim: int = 768, +# ffn_embed_dim: int = 3072, +# attention_heads: int = 8, +# emb_dropout: float = 0.1, +# dropout: float = 0.1, +# attention_dropout: float = 0.1, +# activation_dropout: float = 0.0, +# max_seq_len: int = 256, +# activation_fn: str = "gelu", +# rel_pos: bool = True, +# rel_pos_bins: int = 32, +# max_rel_pos: int = 128, +# post_ln: bool = False, +# auto_regressive: bool = True, +# ) -> None: + +# super().__init__() +# self.emb_dropout = emb_dropout +# self.max_seq_len = max_seq_len +# self.embed_dim = embed_dim +# self.attention_heads = attention_heads +# self.emb_layer_norm = LayerNorm(self.embed_dim) +# self.auto_regressive = auto_regressive +# if self.auto_regressive: +# self._future_mask = bulid_future_mask(self.max_seq_len) +# else: +# self._future_mask = None +# if not post_ln: +# self.final_layer_norm = LayerNorm(self.embed_dim) +# else: +# self.final_layer_norm = None + +# self.layers = nn.ModuleList( +# [ +# TransformerDecoderLayer( +# embed_dim=self.embed_dim, +# ffn_embed_dim=ffn_embed_dim, +# attention_heads=attention_heads, +# dropout=dropout, +# attention_dropout=attention_dropout, +# activation_dropout=activation_dropout, +# activation_fn=activation_fn, +# post_ln=post_ln, + +# ) +# for _ in range(decoder_layers) +# ] +# ) + +# self.rel_pos = rel_pos +# if self.rel_pos: +# assert rel_pos_bins % 2 == 0 +# self.rel_pos_bins = rel_pos_bins +# self.max_rel_pos = max_rel_pos +# self.relative_attention_bias = nn.Embedding( +# self.rel_pos_bins, self.attention_heads) +# seq_len = self.max_seq_len +# context_position = torch.arange(seq_len, dtype=torch.long)[:, None] +# memory_position = torch.arange(seq_len, dtype=torch.long)[None, :] +# relative_position = memory_position - context_position +# self.rp_bucket = relative_position_bucket( +# relative_position, +# num_buckets=self.rel_pos_bins, +# max_distance=self.max_rel_pos +# ) +# self.rp_bucket -= self.rp_bucket.min() + +# def get_rel_pos_bias(self, x): +# # Assume the input is ordered. 
If your input token is permuted, you may need to update this accordingly +# if self.rp_bucket.device != x.device: +# self.rp_bucket = self.rp_bucket.to(x.device) +# seq_len = x.size(1) +# rp_bucket = self.rp_bucket[:seq_len, :seq_len] +# values = F.embedding(rp_bucket, self.relative_attention_bias.weight) +# values = values.permute([2, 0, 1]) +# return values.contiguous() + +# def get_future_mask(self, x, attn_mask): +# if not self.auto_regressive: +# return attn_mask +# if self._future_mask.device != x.device: +# self._future_mask = self._future_mask.to(x.device) +# if self._future_mask.dtype != x.dtype: +# self._future_mask = self._future_mask.type_as(x) +# if attn_mask is None: +# ret = self._future_mask[:x.size(1), :x.size(1)] +# ret = ret.contiguous().unsqueeze(0).repeat( +# x.size(0)*self.attention_heads, 1, 1) +# return ret +# else: +# assert list(attn_mask.size()) == [x.size( +# 0) * self.attention_heads, x.size(1), x.size(1)] +# return attn_mask + self._future_mask[:x.size(1), :x.size(1)] + +# def forward( +# self, +# emb, +# encoder_out: Optional[torch.Tensor] = None, +# padding_mask: Optional[torch.Tensor] = None, +# encoder_padding_mask: Optional[torch.Tensor] = None, +# attn_mask: Optional[torch.Tensor] = None, +# encoder_attn_mask: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: + +# seq_len = emb.size(1) +# x = self.emb_layer_norm(emb) +# x = F.dropout(x, p=self.emb_dropout, training=self.training) + +# # account for padding while computing the representation +# if padding_mask is not None: +# x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + +# rel_pos_bias = self.get_rel_pos_bias(x).repeat( +# x.size(0), 1, 1) if self.rel_pos else None + +# if attn_mask is None: +# attn_mask = rel_pos_bias +# elif rel_pos_bias is not None: +# attn_mask += rel_pos_bias + +# if self.auto_regressive: +# attn_mask = self.get_future_mask(x, attn_mask) + +# if attn_mask is not None and padding_mask is not None: +# # merge key_padding_mask and attn_mask +# attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len) +# attn_mask.masked_fill_( +# padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), +# float("-inf") +# ) +# attn_mask = attn_mask.view(-1, seq_len, seq_len) +# padding_mask = None + +# for layer in self.layers: +# x = layer(x, encoder_out=encoder_out, padding_mask=padding_mask, attn_bias=attn_mask, +# encoder_padding_mask=encoder_padding_mask, encoder_attn_bias=encoder_attn_mask) + +# if self.final_layer_norm is not None: +# x = self.final_layer_norm(x) + +# return x diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_decoder_layer.py b/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_decoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..765a5a9a1472c29ea7c392cd24a3f4ff62d4cdb1 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_decoder_layer.py @@ -0,0 +1,239 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + +import mindspore as ms +import mindspore.ops as ops +from unicore import utils # 假设已适配为MindSpore版本 +from mindspore import nn +from . import LayerNorm, SelfMultiheadAttention, CrossMultiheadAttention # 假设已适配为MindSpore版本 + + +class TransformerDecoderLayer(nn.Cell): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
+ """ + + def __init__( + self, + embed_dim: int = 768, + ffn_embed_dim: int = 3072, + attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + activation_fn: str = "gelu", + post_ln=False, + ) -> None: + super().__init__() + + # Initialize parameters + self.embed_dim = embed_dim + self.attention_heads = attention_heads + self.attention_dropout = attention_dropout + + self.dropout = dropout + self.activation_dropout = activation_dropout + self.activation_fn = utils.get_activation_fn(activation_fn) # 假设utils已适配 + + self.self_attn = SelfMultiheadAttention( + self.embed_dim, + attention_heads, + dropout=attention_dropout, + ) + + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + + self.encoder_attn = CrossMultiheadAttention( + self.embed_dim, + attention_heads, + dropout=attention_dropout, + ) + + # layer norm associated with the encoder attention layer + self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) + + # 替换nn.Linear为nn.Dense + self.fc1 = nn.Dense(self.embed_dim, ffn_embed_dim) + self.fc2 = nn.Dense(ffn_embed_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + self.post_ln = post_ln + + def construct( # 替换forward为construct + self, + x: ms.Tensor, + encoder_out: ms.Tensor = None, + attn_bias: Optional[ms.Tensor] = None, + padding_mask: Optional[ms.Tensor] = None, + encoder_attn_bias: Optional[ms.Tensor] = None, + encoder_padding_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + residual = x + if not self.post_ln: + x = self.self_attn_layer_norm(x) + # 自注意力计算 + x = self.self_attn( + query=x, + key_padding_mask=padding_mask, + attn_bias=attn_bias, + ) + # 替换F.dropout为ops.dropout + x = ops.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.self_attn_layer_norm(x) + + if encoder_out is not None: + residual = x + if not self.post_ln: + x = self.encoder_attn_layer_norm(x) + x = self.encoder_attn( + query=x, + key=encoder_out, + value=encoder_out, + key_padding_mask=encoder_padding_mask, + attn_bias=encoder_attn_bias, + ) + x = ops.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.encoder_attn_layer_norm(x) + + residual = x + if not self.post_ln: + x = self.final_layer_norm(x) + x = self.fc1(x) + x = self.activation_fn(x) + x = ops.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = ops.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.final_layer_norm(x) + return x +# from typing import Dict, Optional + +# import torch +# import torch.nn.functional as F +# from unicore import utils +# from torch import nn +# from . import LayerNorm, SelfMultiheadAttention, CrossMultiheadAttention + +# class TransformerDecoderLayer(nn.Module): +# """ +# Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained +# models. 
+# """ + +# def __init__( +# self, +# embed_dim: int = 768, +# ffn_embed_dim: int = 3072, +# attention_heads: int = 8, +# dropout: float = 0.1, +# attention_dropout: float = 0.1, +# activation_dropout: float = 0.0, +# activation_fn: str = "gelu", +# post_ln = False, +# ) -> None: +# super().__init__() + +# # Initialize parameters +# self.embed_dim = embed_dim +# self.attention_heads = attention_heads +# self.attention_dropout = attention_dropout + +# self.dropout = dropout +# self.activation_dropout = activation_dropout +# self.activation_fn = utils.get_activation_fn(activation_fn) + +# self.self_attn = SelfMultiheadAttention( +# self.embed_dim, +# attention_heads, +# dropout=attention_dropout, +# ) + +# # layer norm associated with the self attention layer +# self.self_attn_layer_norm = LayerNorm(self.embed_dim) + +# self.encoder_attn = CrossMultiheadAttention( +# self.embed_dim, +# attention_heads, +# dropout=attention_dropout, +# ) + +# # layer norm associated with the self attention layer +# self.encoder_attn_layer_norm = LayerNorm(self.embed_dim) + +# self.fc1 = nn.Linear(self.embed_dim, ffn_embed_dim) +# self.fc2 = nn.Linear(ffn_embed_dim, self.embed_dim) +# self.final_layer_norm = LayerNorm(self.embed_dim) +# self.post_ln = post_ln + + +# def forward( +# self, +# x: torch.Tensor, +# encoder_out:torch.Tensor=None, +# attn_bias: Optional[torch.Tensor] = None, +# padding_mask: Optional[torch.Tensor] = None, +# encoder_attn_bias: Optional[torch.Tensor] = None, +# encoder_padding_mask: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: +# """ +# LayerNorm is applied either before or after the self-attention/ffn +# modules similar to the original Transformer implementation. +# """ +# residual = x +# if not self.post_ln: +# x = self.self_attn_layer_norm(x) +# # new added +# x = self.self_attn( +# query=x, +# key_padding_mask=padding_mask, +# attn_bias=attn_bias, +# ) +# x = F.dropout(x, p=self.dropout, training=self.training) +# x = residual + x +# if self.post_ln: +# x = self.self_attn_layer_norm(x) + +# if encoder_out is not None: +# residual = x +# if not self.post_ln: +# x = self.encoder_attn_layer_norm(x) +# x = self.encoder_attn( +# query=x, +# key=encoder_out, +# value=encoder_out, +# key_padding_mask=encoder_padding_mask, +# attn_bias=encoder_attn_bias, +# ) +# #x = self.dropout_module(x) +# x = F.dropout(x, p=self.dropout, training=self.training) +# x = residual + x +# if self.post_ln: +# x = self.encoder_attn_layer_norm(x) + + +# residual = x +# if not self.post_ln: +# x = self.final_layer_norm(x) +# x = self.fc1(x) +# x = self.activation_fn(x) +# x = F.dropout(x, p=self.activation_dropout, training=self.training) +# x = self.fc2(x) +# x = F.dropout(x, p=self.dropout, training=self.training) +# x = residual + x +# if self.post_ln: +# x = self.final_layer_norm(x) +# return x diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_encoder.py b/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..a7fc84631daa69bcbea62135b0782faa2fb94639 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_encoder.py @@ -0,0 +1,336 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
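+
+# NOTE (illustration, not part of the original file): a minimal sketch of how the
+# relative-position bucket table defined below is precomputed inside
+# TransformerEncoder; seq_len=8 is an arbitrary example value.
+#
+#   import mindspore as ms
+#   import mindspore.ops as ops
+#
+#   seq_len = 8
+#   context_position = ops.arange(seq_len, dtype=ms.int64)[:, None]
+#   memory_position = ops.arange(seq_len, dtype=ms.int64)[None, :]
+#   relative_position = memory_position - context_position
+#   rp_bucket = relative_position_bucket(relative_position, num_buckets=32, max_distance=128)
+#   rp_bucket -= rp_bucket.min()   # shift bucket indices to start at 0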
+from typing import Optional + +import math +import mindspore as ms +import mindspore.nn as nn +import mindspore.ops as ops +from . import TransformerEncoderLayer, LayerNorm # 假设已适配为MindSpore版本 + + +def init_bert_params(module): + if not getattr(module, 'can_global_init', True): + return + + def normal_(data): + # MindSpore中无in-place操作,替换torch.normal_为标准正态分布初始化 + normal_data = ops.normal(data.shape, mean=0.0, std=0.02, dtype=data.dtype) + data.assign_value(normal_data.astype(data.dtype)) + + if isinstance(module, nn.Dense): + normal_(module.weight) + if module.bias is not None: + module.bias.assign_value(ops.zeros_like(module.bias)) # 替换zero_为assign zeros + if isinstance(module, nn.Embedding): + normal_(module.weight) + if module.padding_idx is not None: + # 对padding_idx位置的权重置零 + padding_weight = ops.zeros_like(module.weight.data[module.padding_idx]) + module.weight.data[module.padding_idx] = padding_weight + + +def relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + sign = ops.sign(relative_position) + num_buckets //= 2 + n = ops.abs(relative_position) + + # 一半桶用于精确位置增量 + max_exact = num_buckets // 2 + is_small = n < max_exact + max_bucket_val = num_buckets - 1 - max_exact + + # 另一半桶用于对数级增长的位置区间(直到max_distance) + # 替换torch.log为ops.log,torch.ceil为ops.ceil,.long()为astype(ms.int64) + val_if_large = max_exact + ops.ceil( + ops.log(n.astype(ms.float32) / max_exact) / math.log((max_distance - 1) / max_exact) * (max_bucket_val) + ).astype(ms.int64) + val_if_large = ops.minimum(val_if_large, ops.full_like(val_if_large, num_buckets - 1, dtype=ms.int64)) + + # 替换torch.where为ops.where + ret = ops.where(is_small, n, val_if_large) * sign + return ret + + +class TransformerEncoder(nn.Cell): + def __init__( + self, + encoder_layers: int = 6, + embed_dim: int = 768, + ffn_embed_dim: int = 3072, + attention_heads: int = 8, + emb_dropout: float = 0.1, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + max_seq_len: int = 256, + activation_fn: str = "gelu", + rel_pos: bool = True, + rel_pos_bins: int = 32, + max_rel_pos: int = 128, + post_ln: bool = False, + ) -> None: + + super().__init__() + self.emb_dropout = emb_dropout + self.max_seq_len = max_seq_len + self.embed_dim = embed_dim + self.attention_heads = attention_heads + self.emb_layer_norm = LayerNorm(self.embed_dim) + + # 最终层归一化 + if not post_ln: + self.final_layer_norm = LayerNorm(self.embed_dim) + else: + self.final_layer_norm = None + + # 替换nn.ModuleList为nn.CellList + self.layers = nn.CellList( + [ + TransformerEncoderLayer( + embed_dim=self.embed_dim, + ffn_embed_dim=ffn_embed_dim, + attention_heads=attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + post_ln=post_ln, + ) + for _ in range(encoder_layers) + ] + ) + + self.rel_pos = rel_pos + + if self.rel_pos: + assert rel_pos_bins % 2 == 0 + self.rel_pos_bins = rel_pos_bins + self.max_rel_pos = max_rel_pos + # 替换nn.Embedding为mindspore.nn.Embedding + self.relative_attention_bias = nn.Embedding(self.rel_pos_bins, self.attention_heads) + seq_len = self.max_seq_len + # 替换torch.arange为ops.arange,dtype=torch.long改为ms.int64 + context_position = ops.arange(seq_len, dtype=ms.int64)[:, None] + memory_position = ops.arange(seq_len, dtype=ms.int64)[None, :] + relative_position = memory_position - context_position + self.rp_bucket = relative_position_bucket( + relative_position, + num_buckets=self.rel_pos_bins, + max_distance=self.max_rel_pos 
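+                # the bucket table above is built once for max_seq_len and
+                # sliced per sequence in get_rel_pos_bias below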
+ ) + self.rp_bucket -= self.rp_bucket.min() + + def get_rel_pos_bias(self, x): + # 设备适配(MindSpore张量支持to方法) + if self.rp_bucket.device != x.device: + self.rp_bucket = self.rp_bucket.to(x.device) + seq_len = x.shape[1] # 替换size(1)为shape[1] + rp_bucket = self.rp_bucket[:seq_len, :seq_len] + # 替换F.embedding为ops.embedding + values = ops.embedding(rp_bucket, self.relative_attention_bias.weight) + values = values.permute(2, 0, 1) # 维度交换保持一致 + return values.contiguous() + + def construct( # 替换forward为construct + self, + emb: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + padding_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + + seq_len = emb.shape[1] # 替换size(1)为shape[1] + x = self.emb_layer_norm(emb) + # 替换F.dropout为ops.dropout + x = ops.dropout(x, p=self.emb_dropout, training=self.training) + + # 处理padding mask(替换type_as为astype) + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).astype(x.dtype)) + + # 计算相对位置偏置 + rel_pos_bias = self.get_rel_pos_bias(x).repeat(x.shape[0], 1, 1) if self.rel_pos else None + if attn_mask is None: + attn_mask = rel_pos_bias + elif rel_pos_bias is not None: + attn_mask += rel_pos_bias + + # 合并padding mask和attn_mask(替换in-place操作masked_fill_为masked_fill) + if attn_mask is not None and padding_mask is not None: + attn_mask = attn_mask.view(x.shape[0], -1, seq_len, seq_len) + attn_mask = attn_mask.masked_fill( + padding_mask.unsqueeze(1).unsqueeze(2).astype(ms.bool_), # 替换to(torch.bool)为astype(ms.bool_) + float("-inf") + ) + attn_mask = attn_mask.view(-1, seq_len, seq_len) + padding_mask = None + + # 逐层处理 + for layer in self.layers: + x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask) + + # 最终层归一化 + if self.final_layer_norm is not None: + x = self.final_layer_norm(x) + + return x +# from typing import Optional + +# import math +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from . 
import TransformerEncoderLayer, LayerNorm + + +# def init_bert_params(module): +# if not getattr(module, 'can_global_init', True): +# return +# def normal_(data): +# data.copy_( +# data.cpu().normal_(mean=0.0, std=0.02).to(data.device) +# ) +# if isinstance(module, nn.Linear): +# normal_(module.weight.data) +# if module.bias is not None: +# module.bias.data.zero_() +# if isinstance(module, nn.Embedding): +# normal_(module.weight.data) +# if module.padding_idx is not None: +# module.weight.data[module.padding_idx].zero_() + + +# def relative_position_bucket(relative_position, num_buckets=32, max_distance=128): +# sign = torch.sign(relative_position) +# num_buckets //= 2 +# n = torch.abs(relative_position) + +# # half of the buckets are for exact increments in positions +# max_exact = num_buckets // 2 +# is_small = n < max_exact +# max_bucket_val = num_buckets - 1 - max_exact +# # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance +# val_if_large = max_exact + torch.ceil( +# torch.log(n.float() / max_exact) / math.log((max_distance - 1) / max_exact) * (max_bucket_val) +# ).long() +# val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) +# ret = torch.where(is_small, n, val_if_large) * sign +# return ret + + +# class TransformerEncoder(nn.Module): +# def __init__( +# self, +# encoder_layers: int = 6, +# embed_dim: int = 768, +# ffn_embed_dim: int = 3072, +# attention_heads: int = 8, +# emb_dropout: float = 0.1, +# dropout: float = 0.1, +# attention_dropout: float = 0.1, +# activation_dropout: float = 0.0, +# max_seq_len: int = 256, +# activation_fn: str = "gelu", +# rel_pos: bool = True, +# rel_pos_bins: int = 32, +# max_rel_pos: int = 128, +# post_ln: bool = False, +# ) -> None: + +# super().__init__() +# self.emb_dropout = emb_dropout +# self.max_seq_len = max_seq_len +# self.embed_dim = embed_dim +# self.attention_heads = attention_heads +# self.emb_layer_norm = LayerNorm(self.embed_dim) +# if not post_ln: +# self.final_layer_norm = LayerNorm(self.embed_dim) +# else: +# self.final_layer_norm = None + +# self.layers = nn.ModuleList( +# [ +# TransformerEncoderLayer( +# embed_dim=self.embed_dim, +# ffn_embed_dim=ffn_embed_dim, +# attention_heads=attention_heads, +# dropout=dropout, +# attention_dropout=attention_dropout, +# activation_dropout=activation_dropout, +# activation_fn=activation_fn, +# post_ln=post_ln, + +# ) +# for _ in range(encoder_layers) +# ] +# ) + +# self.rel_pos = rel_pos + +# if self.rel_pos: +# assert rel_pos_bins % 2 == 0 +# self.rel_pos_bins = rel_pos_bins +# self.max_rel_pos = max_rel_pos +# self.relative_attention_bias = nn.Embedding(self.rel_pos_bins, self.attention_heads) +# seq_len = self.max_seq_len +# context_position = torch.arange(seq_len, dtype=torch.long)[:, None] +# memory_position = torch.arange(seq_len, dtype=torch.long)[None, :] +# relative_position = memory_position - context_position +# self.rp_bucket = relative_position_bucket( +# relative_position, +# num_buckets=self.rel_pos_bins, +# max_distance=self.max_rel_pos +# ) +# self.rp_bucket -= self.rp_bucket.min() + +# def get_rel_pos_bias(self, x): +# # Assume the input is ordered. 
If your input token is permuted, you may need to update this accordingly +# if self.rp_bucket.device != x.device: +# self.rp_bucket = self.rp_bucket.to(x.device) +# seq_len = x.size(1) +# rp_bucket = self.rp_bucket[:seq_len, :seq_len] +# values = F.embedding(rp_bucket, self.relative_attention_bias.weight) +# values = values.permute([2, 0, 1]) +# return values.contiguous() + +# def forward( +# self, +# emb: torch.Tensor, +# attn_mask: Optional[torch.Tensor] = None, +# padding_mask: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: + +# seq_len = emb.size(1) +# x = self.emb_layer_norm(emb) +# x = F.dropout(x, p=self.emb_dropout, training=self.training) + +# # account for padding while computing the representation +# if padding_mask is not None: +# x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) + +# rel_pos_bias = self.get_rel_pos_bias(x).repeat(x.size(0), 1, 1) if self.rel_pos else None +# if attn_mask is None: +# attn_mask = rel_pos_bias +# elif rel_pos_bias is not None: +# attn_mask += rel_pos_bias + +# if attn_mask is not None and padding_mask is not None: +# # merge key_padding_mask and attn_mask +# attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len) +# attn_mask.masked_fill_( +# padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), +# float("-inf") +# ) +# attn_mask = attn_mask.view(-1, seq_len, seq_len) +# padding_mask = None + +# for layer in self.layers: +# x = layer(x, padding_mask=padding_mask, attn_bias=attn_mask) + +# if self.final_layer_norm is not None: +# x = self.final_layer_norm(x) + +# return x \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_encoder_layer.py b/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_encoder_layer.py new file mode 100644 index 0000000000000000000000000000000000000000..a99034117e86aba63b4123c89e2496fd2b9a37c6 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/modules/transformer_encoder_layer.py @@ -0,0 +1,192 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Dict, Optional + +import mindspore as ms +import mindspore.ops as ops +from unicore import utils # 假设已适配为MindSpore版本 +from mindspore import nn +from . import LayerNorm, SelfMultiheadAttention # 假设已适配为MindSpore版本 + + +class TransformerEncoderLayer(nn.Cell): + """ + Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained + models. 
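+
+    A hedged usage sketch (assumes the MindSpore ports of
+    SelfMultiheadAttention and LayerNorm are available)::
+
+        layer = TransformerEncoderLayer(embed_dim=64, ffn_embed_dim=128,
+                                        attention_heads=4)
+        x = ops.zeros((2, 16, 64), ms.float32)
+        y = layer(x)   # pre-LN residual block; output shape equals input shape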
+ """ + + def __init__( + self, + embed_dim: int = 768, + ffn_embed_dim: int = 3072, + attention_heads: int = 8, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + activation_fn: str = "gelu", + post_ln=False, + ) -> None: + super().__init__() + + # Initialize parameters + self.embed_dim = embed_dim + self.attention_heads = attention_heads + self.attention_dropout = attention_dropout + + self.dropout = dropout + self.activation_dropout = activation_dropout + self.activation_fn = utils.get_activation_fn(activation_fn) # 假设utils已适配 + + self.self_attn = SelfMultiheadAttention( + self.embed_dim, + attention_heads, + dropout=attention_dropout, + ) + # layer norm associated with the self attention layer + self.self_attn_layer_norm = LayerNorm(self.embed_dim) + # 替换nn.Linear为nn.Dense + self.fc1 = nn.Dense(self.embed_dim, ffn_embed_dim) + self.fc2 = nn.Dense(ffn_embed_dim, self.embed_dim) + self.final_layer_norm = LayerNorm(self.embed_dim) + self.post_ln = post_ln + + def construct( # 替换forward为construct + self, + x: ms.Tensor, + attn_bias: Optional[ms.Tensor] = None, + padding_mask: Optional[ms.Tensor] = None, + return_attn: bool = False, + ) -> ms.Tensor: + """ + LayerNorm is applied either before or after the self-attention/ffn + modules similar to the original Transformer implementation. + """ + residual = x + if not self.post_ln: + x = self.self_attn_layer_norm(x) + # 自注意力计算 + x = self.self_attn( + query=x, + key_padding_mask=padding_mask, + attn_bias=attn_bias, + return_attn=return_attn, + ) + if return_attn: + x, attn_weights, attn_probs = x + # 替换F.dropout为ops.dropout + x = ops.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.self_attn_layer_norm(x) + + residual = x + if not self.post_ln: + x = self.final_layer_norm(x) + x = self.fc1(x) + x = self.activation_fn(x) + x = ops.dropout(x, p=self.activation_dropout, training=self.training) + x = self.fc2(x) + x = ops.dropout(x, p=self.dropout, training=self.training) + x = residual + x + if self.post_ln: + x = self.final_layer_norm(x) + if not return_attn: + return x + else: + return x, attn_weights, attn_probs +# from typing import Dict, Optional + +# import torch +# import torch.nn.functional as F +# from unicore import utils +# from torch import nn +# from . import LayerNorm, SelfMultiheadAttention + +# class TransformerEncoderLayer(nn.Module): +# """ +# Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained +# models. 
+# """ + +# def __init__( +# self, +# embed_dim: int = 768, +# ffn_embed_dim: int = 3072, +# attention_heads: int = 8, +# dropout: float = 0.1, +# attention_dropout: float = 0.1, +# activation_dropout: float = 0.0, +# activation_fn: str = "gelu", +# post_ln = False, +# ) -> None: +# super().__init__() + +# # Initialize parameters +# self.embed_dim = embed_dim +# self.attention_heads = attention_heads +# self.attention_dropout = attention_dropout + +# self.dropout = dropout +# self.activation_dropout = activation_dropout +# self.activation_fn = utils.get_activation_fn(activation_fn) + +# self.self_attn = SelfMultiheadAttention( +# self.embed_dim, +# attention_heads, +# dropout=attention_dropout, +# ) +# # layer norm associated with the self attention layer +# self.self_attn_layer_norm = LayerNorm(self.embed_dim) +# self.fc1 = nn.Linear(self.embed_dim, ffn_embed_dim) +# self.fc2 = nn.Linear(ffn_embed_dim, self.embed_dim) +# self.final_layer_norm = LayerNorm(self.embed_dim) +# self.post_ln = post_ln + + +# def forward( +# self, +# x: torch.Tensor, +# attn_bias: Optional[torch.Tensor] = None, +# padding_mask: Optional[torch.Tensor] = None, +# return_attn: bool=False, +# ) -> torch.Tensor: +# """ +# LayerNorm is applied either before or after the self-attention/ffn +# modules similar to the original Transformer implementation. +# """ +# residual = x +# if not self.post_ln: +# x = self.self_attn_layer_norm(x) +# # new added +# x = self.self_attn( +# query=x, +# key_padding_mask=padding_mask, +# attn_bias=attn_bias, +# return_attn=return_attn, +# ) +# if return_attn: +# x, attn_weights, attn_probs = x +# x = F.dropout(x, p=self.dropout, training=self.training) +# x = residual + x +# if self.post_ln: +# x = self.self_attn_layer_norm(x) + +# residual = x +# if not self.post_ln: +# x = self.final_layer_norm(x) +# x = self.fc1(x) +# x = self.activation_fn(x) +# x = F.dropout(x, p=self.activation_dropout, training=self.training) +# x = self.fc2(x) +# x = F.dropout(x, p=self.dropout, training=self.training) +# x = residual + x +# if self.post_ln: +# x = self.final_layer_norm(x) +# if not return_attn: +# return x +# else: +# return x, attn_weights, attn_probs + \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unicore/nan_detector.py b/MindChemistry/applications/Uni-Mol/unicore/nan_detector.py new file mode 100644 index 0000000000000000000000000000000000000000..af3660943446d28f27cb1c91af201a04a70ff891 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/nan_detector.py @@ -0,0 +1,225 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+import logging +import mindspore as ms +import mindspore.ops as ops + + +logger = logging.getLogger(__name__) + + +class NanDetector: + """ + Detects the first NaN or Inf in forward and/or backward pass and logs, together with the module name + """ + + def __init__(self, model, forward=True, backward=True): + self.bhooks = [] + self.fhooks = [] + self.forward = forward + self.backward = backward + self.named_parameters = list(model.named_parameters()) + self.reset() + + for name, mod in model.named_modules(): + mod.__module_name = name # 为模块添加名称属性 + self.add_hooks(mod) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, exc_traceback): + # 输出所有模型梯度范数以辅助调试 + norm = {} + gradients = {} + for name, param in self.named_parameters: + if param.grad is not None: + # MindSpore中计算L2范数,转换为float32 + grad_norm = ops.norm(param.grad, p=2).astype(ms.float32) + norm[name] = grad_norm.asnumpy().item() # 转换为Python数值 + if ops.isnan(grad_norm).any() or ops.isinf(grad_norm).any(): + gradients[name] = param.grad + + if len(gradients) > 0: + logger.info("Detected nan/inf grad norm, dumping norms...") + logger.info(f"norms: {norm}") + logger.info(f"gradients: {gradients}") + + self.close() + + def add_hooks(self, module): + if self.forward: + # MindSpore注册前向钩子,参数为(net, inputs, outputs) + self.fhooks.append(module.register_forward_hook(self.fhook_fn)) + if self.backward: + # MindSpore注册反向钩子,参数为(net, grad_inputs, grad_outputs) + self.bhooks.append(module.register_backward_hook(self.bhook_fn)) + + def reset(self): + self.has_printed_f = False # 前向已打印标记 + self.has_printed_b = False # 反向已打印标记 + + def _detect(self, tensor, name, backward): + err = None + # 检查浮点类型张量,且元素数量不少于2(单元素张量信息量少) + if ( + ms.is_floating_point(tensor) + and tensor.size().prod() >= 2 # MindSpore中用size().prod()替代numel() + ): + with ms.no_grad(): # MindSpore无梯度上下文 + if ops.isnan(tensor).any(): + err = "NaN" + elif ops.isinf(tensor).any(): + err = "Inf" + if err is not None: + err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}" + return err + + def _apply(self, module, inp, x, backward): + if isinstance(x, ms.Tensor): + # 处理输入(取第一个输入用于信息打印) + if isinstance(inp, tuple) and len(inp) > 0: + inp = inp[0] + err = self._detect(x, module.__module_name, backward) + if err is not None: + # 前向传播时添加输入的最大/最小值信息 + if isinstance(inp, ms.Tensor) and not backward: + inp_max = ops.max(inp).asnumpy().item() + inp_min = ops.min(inp).asnumpy().item() + err += f" input max: {inp_max}, input min: {inp_min}" + + # 记录已打印状态,避免重复输出 + has_printed_attr = "has_printed_b" if backward else "has_printed_f" + if not getattr(self, has_printed_attr): + logger.warning(err) + setattr(self, has_printed_attr, True) + elif isinstance(x, dict): + # 递归处理字典类型输出 + for v in x.values(): + self._apply(module, inp, v, backward) + elif isinstance(x, (list, tuple)): + # 递归处理列表/元组类型输出 + for v in x: + self._apply(module, inp, v, backward) + + def fhook_fn(self, module, inp, output): + """前向钩子函数""" + if not self.has_printed_f: + self._apply(module, inp, output, backward=False) + + def bhook_fn(self, module, grad_in, grad_out): + """反向钩子函数(MindSpore中输入为梯度输入和梯度输出)""" + if not self.has_printed_b: + # 反向传播中检测梯度输出 + self._apply(module, grad_in, grad_out, backward=True) + + def close(self): + """移除所有钩子""" + for hook in self.fhooks + self.bhooks: + hook.remove() +# import logging + +# import torch + + +# logger = logging.getLogger(__name__) + + +# class NanDetector: +# """ +# Detects the first NaN or Inf in forward and/or 
backward pass and logs, together with the module name +# """ + +# def __init__(self, model, forward=True, backward=True): +# self.bhooks = [] +# self.fhooks = [] +# self.forward = forward +# self.backward = backward +# self.named_parameters = list(model.named_parameters()) +# self.reset() + +# for name, mod in model.named_modules(): +# mod.__module_name = name +# self.add_hooks(mod) + +# def __enter__(self): +# return self + +# def __exit__(self, exc_type, exc_value, exc_traceback): +# # Dump out all model gnorms to enable better debugging +# norm = {} +# gradients = {} +# for name, param in self.named_parameters: +# if param.grad is not None: +# grad_norm = torch.norm(param.grad.data, p=2, dtype=torch.float32) +# norm[name] = grad_norm.item() +# if torch.isnan(grad_norm).any() or torch.isinf(grad_norm).any(): +# gradients[name] = param.grad.data +# if len(gradients) > 0: +# logger.info("Detected nan/inf grad norm, dumping norms...") +# logger.info(f"norms: {norm}") +# logger.info(f"gradients: {gradients}") + +# self.close() + +# def add_hooks(self, module): +# if self.forward: +# self.fhooks.append(module.register_forward_hook(self.fhook_fn)) +# if self.backward: +# self.bhooks.append(module.register_backward_hook(self.bhook_fn)) + +# def reset(self): +# self.has_printed_f = False +# self.has_printed_b = False + +# def _detect(self, tensor, name, backward): +# err = None +# if ( +# torch.is_floating_point(tensor) +# # single value tensors (like the loss) will not provide much info +# and tensor.numel() >= 2 +# ): +# with torch.no_grad(): +# if torch.isnan(tensor).any(): +# err = "NaN" +# elif torch.isinf(tensor).any(): +# err = "Inf" +# if err is not None: +# err = f"{err} detected in output of {name}, shape: {tensor.shape}, {'backward' if backward else 'forward'}" +# return err + +# def _apply(self, module, inp, x, backward): +# if torch.is_tensor(x): +# if isinstance(inp, tuple) and len(inp) > 0: +# inp = inp[0] +# err = self._detect(x, module.__module_name, backward) +# if err is not None: +# if torch.is_tensor(inp) and not backward: +# err += ( +# f" input max: {inp.max().item()}, input min: {inp.min().item()}" +# ) + +# has_printed_attr = "has_printed_b" if backward else "has_printed_f" +# logger.warning(err) +# setattr(self, has_printed_attr, True) +# elif isinstance(x, dict): +# for v in x.values(): +# self._apply(module, inp, v, backward) +# elif isinstance(x, list) or isinstance(x, tuple): +# for v in x: +# self._apply(module, inp, v, backward) + +# def fhook_fn(self, module, inp, output): +# if not self.has_printed_f: +# self._apply(module, inp, output, backward=False) + +# def bhook_fn(self, module, inp, output): +# if not self.has_printed_b: +# self._apply(module, inp, output, backward=True) + +# def close(self): +# for hook in self.fhooks + self.bhooks: +# hook.remove() diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/optim/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ecda871dc677289894e82e9f16cd562b3d7bb363 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/__init__.py @@ -0,0 +1,37 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
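+# Hedged usage sketch -- optimizers are built through the registry configured
+# below and selected by ``args.optimizer`` (default: "adam"); ``named_params``
+# stands in for an iterable of (name, parameter) pairs:
+#
+#     from unicore import optim
+#     optimizer = optim.build_optimizer(args, named_params)
+#
+# With separate=True (the default), parameters are first split into
+# weight-decay and no-weight-decay groups via separate_decay_params().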
+"""isort:skip_file""" + +import importlib +import os + +from unicore import registry +from unicore.optim.unicore_optimizer import ( # noqa + UnicoreOptimizer, +) +from unicore.optim.fp16_optimizer import FP16Optimizer, separate_decay_params + +__all__ = [ + "UnicoreOptimizer", + "FP16Optimizer", +] + +(_build_optimizer, register_optimizer, OPTIMIZER_REGISTRY) = registry.setup_registry( + "--optimizer", base_class=UnicoreOptimizer, default="adam" +) + + +def build_optimizer(args, params, separate=True, *extra_args, **extra_kwargs): + if separate: + params = separate_decay_params(args, params) + return _build_optimizer(args, params, *extra_args, **extra_kwargs) + + +# automatically import any Python files in the optim/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("unicore.optim." + file_name) diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/adadelta.py b/MindChemistry/applications/Uni-Mol/unicore/optim/adadelta.py new file mode 100644 index 0000000000000000000000000000000000000000..e7086f9597067f9b181aee3c54b915b7aa75fb91 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/adadelta.py @@ -0,0 +1,95 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import mindspore.nn as nn + +from . import UnicoreOptimizer, register_optimizer + + +@register_optimizer("adadelta") +class Adadelta(UnicoreOptimizer): + def __init__(self, args, params): + super().__init__(args) + # 替换为MindSpore的Adadelta优化器,参数与PyTorch对应 + self._optimizer = nn.Adadelta( + params=params, + learning_rate=self.args.lr[0], + rho=self.args.adadelta_rho, + eps=self.args.adadelta_eps, + weight_decay=self.args.weight_decay + ) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # 参数定义与PyTorch版本保持一致,无需修改 + # fmt: off + parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO', + help='coefficient used for computing a running average of squared gradients') + parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS', + help='term added to the denominator to improve numerical stability') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps') + # fmt: on + + @property + def optimizer_config(self): + """ + 返回用于覆盖检查点中存储的优化器参数的字典,与MindSpore Adadelta参数对应 + """ + return { + "learning_rate": self.args.lr[0], + "rho": self.args.adadelta_rho, + "eps": self.args.adadelta_eps, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + # MindSpore优化器支持扁平参数 + return True +# import torch.optim + +# from . 
import UnicoreOptimizer, register_optimizer + + +# @register_optimizer("adadelta") +# class Adadelta(UnicoreOptimizer): +# def __init__(self, args, params): +# super().__init__(args) +# self._optimizer = torch.optim.Adadelta(params, **self.optimizer_config) + +# @staticmethod +# def add_args(parser): +# """Add optimizer-specific arguments to the parser.""" +# # fmt: off +# parser.add_argument('--adadelta-rho', type=float, default=0.9, metavar='RHO', +# help='coefficient used for computing a running average of squared gradients') +# parser.add_argument('--adadelta-eps', type=float, default=1e-6, metavar='EPS', +# help='term added to the denominator to improve numerical stability') +# parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', +# help='weight decay') +# parser.add_argument('--anneal-eps', action='store_true', help='flag to anneal eps') +# # fmt: on + +# @property +# def optimizer_config(self): +# """ +# Return a kwarg dictionary that will be used to override optimizer +# args stored in checkpoints. This allows us to load a checkpoint and +# resume training using a different set of optimizer args, e.g., with a +# different learning rate. +# """ +# return { +# "lr": self.args.lr[0], +# "rho": self.args.adadelta_rho, +# "eps": self.args.adadelta_eps, +# "weight_decay": self.args.weight_decay, +# } + +# @property +# def supports_flat_params(self): +# return True diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/adagrad.py b/MindChemistry/applications/Uni-Mol/unicore/optim/adagrad.py new file mode 100644 index 0000000000000000000000000000000000000000..bd0a32da8b1774333e1f0116e3214f426ce103c1 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/adagrad.py @@ -0,0 +1,79 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import mindspore.nn as nn + +from . import UnicoreOptimizer, register_optimizer + + +@register_optimizer("adagrad") +class Adagrad(UnicoreOptimizer): + def __init__(self, args, params): + super().__init__(args) + # 替换为MindSpore的Adagrad优化器,参数对应PyTorch版本 + self._optimizer = nn.Adagrad( + params=params, + learning_rate=self.args.lr[0], # MindSpore中学习率参数名为learning_rate + weight_decay=self.args.weight_decay + ) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # 参数定义与原PyTorch版本保持一致 + # fmt: off + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + 返回用于覆盖检查点中存储的优化器参数的字典,适配MindSpore Adagrad参数名 + """ + return { + "learning_rate": self.args.lr[0], # 调整学习率参数名为MindSpore对应的learning_rate + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + # 保持原逻辑,MindSpore Adagrad此处同样不支持扁平参数 + return False +# import torch.optim + +# from . 
import UnicoreOptimizer, register_optimizer + + +# @register_optimizer("adagrad") +# class Adagrad(UnicoreOptimizer): +# def __init__(self, args, params): +# super().__init__(args) +# self._optimizer = torch.optim.Adagrad(params, **self.optimizer_config) + +# @staticmethod +# def add_args(parser): +# """Add optimizer-specific arguments to the parser.""" +# # fmt: off +# parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', +# help='weight decay') +# # fmt: on + +# @property +# def optimizer_config(self): +# """ +# Return a kwarg dictionary that will be used to override optimizer +# args stored in checkpoints. This allows us to load a checkpoint and +# resume training using a different set of optimizer args, e.g., with a +# different learning rate. +# """ +# return { +# "lr": self.args.lr[0], +# "weight_decay": self.args.weight_decay, +# } + +# @property +# def supports_flat_params(self): +# return False diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/adam.py b/MindChemistry/applications/Uni-Mol/unicore/optim/adam.py new file mode 100644 index 0000000000000000000000000000000000000000..ce7615eec7b6e85078a0617eacb506017d9e966e --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/adam.py @@ -0,0 +1,385 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import math +from collections.abc import Collection +from typing import List + +import mindspore as ms +from mindspore import ops, Tensor +from mindspore.nn.optim import Optimizer +from unicore.optim import UnicoreOptimizer, register_optimizer + +logger = logging.getLogger(__name__) + + +@register_optimizer("adam") +class UnicoreAdam(UnicoreOptimizer): + """Adam optimizer for unicore, adapted for MindSpore. + + This implementation corresponds to the "AdamW" variant with weight decay, + analogous to PyTorch's AdamW but adapted for MindSpore and Ascend NPU. + """ + + def __init__(self, args, params): + super().__init__(args) + # 移除GPU相关的fused Adam逻辑(不适用Ascend NPU) + self._optimizer = Adam(params, **self.optimizer_config) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # 保持参数定义与原版本一致 + # fmt: off + parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', + help='betas for Adam optimizer') + parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', + help='epsilon for Adam optimizer') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """Return optimizer config dictionary for checkpoint resumption.""" + return { + "learning_rate": self.args.lr[0] + if isinstance(self.args.lr, Collection) + else self.args.lr, + "betas": eval(self.args.adam_betas), + "eps": self.args.adam_eps, + "weight_decay": self.args.weight_decay, + "amsgrad": False, # 原代码默认不启用amsgrad + } + + +class Adam(Optimizer): + r"""Implements Adam algorithm for MindSpore, adapted from PyTorch's AdamW. 
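+
+    Note (hedged): this is a hand-ported sketch of decoupled (AdamW-style)
+    weight decay; moment buffers are kept in float32 even when parameters or
+    gradients arrive as float16/bfloat16, and state is tracked per
+    (group, param) index rather than per parameter object.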
+ + Args: + params (iterable): iterable of parameters to optimize or dicts defining + parameter groups + learning_rate (float, optional): learning rate (default: 1e-3) + betas (Tuple[float, float], optional): coefficients used for computing + running averages of gradient and its square (default: (0.9, 0.999)) + eps (float, optional): term added to the denominator to improve + numerical stability (default: 1e-8) + weight_decay (float, optional): weight decay (L2 penalty) (default: 0) + amsgrad (boolean, optional): whether to use the AMSGrad variant (default: False) + """ + + def __init__( + self, + params, + learning_rate=1e-3, + betas=(0.9, 0.999), + eps=1e-8, + weight_decay=0, + amsgrad=False, + ): + defaults = dict( + learning_rate=learning_rate, + betas=betas, + eps=eps, + weight_decay=weight_decay, + amsgrad=amsgrad + ) + super(Adam, self).__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True # 支持混合精度 + + @property + def supports_flat_params(self): + return True # 支持扁平参数 + + def construct(self, gradients): + """Performs a single optimization step (MindSpore要求的construct方法)""" + for group_id, group in enumerate(self.param_groups): + lr = group["learning_rate"] + beta1, beta2 = group["betas"] + eps = group["eps"] + weight_decay = group["weight_decay"] + amsgrad = group["amsgrad"] + + # 遍历参数组中的参数 + for param_id, param in enumerate(group["params"]): + grad = gradients[group_id * len(group["params"]) + param_id] + if grad is None: + continue + + # 处理梯度数据类型(转为float32计算) + if grad.dtype in (ms.float16, ms.bfloat16): + grad = ops.cast(grad, ms.float32) + + # 处理参数数据类型(转为float32计算) + param_data = param.data + if param_data.dtype in (ms.float16, ms.bfloat16): + param_data_fp32 = ops.cast(param_data, ms.float32) + else: + param_data_fp32 = param_data + + # 初始化状态(替代PyTorch的state字典) + state = self.state[(group_id, param_id)] + if not state: + state["step"] = Tensor(0, ms.int32) + # 梯度的指数移动平均 + state["exp_avg"] = ops.zeros_like(param_data_fp32) + # 梯度平方的指数移动平均 + state["exp_avg_sq"] = ops.zeros_like(param_data_fp32) + if amsgrad: + # 用于amsgrad的最大平方移动平均 + state["max_exp_avg_sq"] = ops.zeros_like(param_data_fp32) + else: + # 确保状态张量在正确的设备和类型上 + state["exp_avg"] = ops.cast(state["exp_avg"], param_data_fp32.dtype) + state["exp_avg_sq"] = ops.cast(state["exp_avg_sq"], param_data_fp32.dtype) + if amsgrad: + state["max_exp_avg_sq"] = ops.cast(state["max_exp_avg_sq"], param_data_fp32.dtype) + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + if amsgrad: + max_exp_avg_sq = state["max_exp_avg_sq"] + + # 步骤计数加1 + state["step"] += 1 + step = state["step"].asnumpy().item() + + # 计算一阶和二阶矩的移动平均(替换PyTorch的in-place操作) + exp_avg = exp_avg * beta1 + grad * (1 - beta1) + exp_avg_sq = exp_avg_sq * beta2 + ops.square(grad) * (1 - beta2) + + # 处理amsgrad + if amsgrad: + max_exp_avg_sq = ops.maximum(max_exp_avg_sq, exp_avg_sq) + denom = ops.sqrt(max_exp_avg_sq) + eps + else: + denom = ops.sqrt(exp_avg_sq) + eps + + # 偏差校正 + bias_correction1 = 1 - (beta1 ** step) + bias_correction2 = 1 - (beta2 ** step) + step_size = lr * math.sqrt(bias_correction2) / bias_correction1 + + # 权重衰减(AdamW风格) + if weight_decay != 0: + param_data_fp32 = param_data_fp32 - param_data_fp32 * weight_decay * lr + + # 参数更新 + param_data_fp32 = param_data_fp32 - (exp_avg / denom) * step_size + + # 类型转换回原始类型(如果需要) + if param.data.dtype in (ms.float16, ms.bfloat16): + param_data = ops.cast(param_data_fp32, param.data.dtype) + else: + param_data = param_data_fp32 + + # 更新参数和状态 + 
param.set_data(param_data) + state["exp_avg"] = exp_avg + state["exp_avg_sq"] = exp_avg_sq + if amsgrad: + state["max_exp_avg_sq"] = max_exp_avg_sq +# import logging +# import math +# from collections.abc import Collection +# from typing import List + +# import torch +# import torch.optim +# from unicore.optim import UnicoreOptimizer, register_optimizer +# from unicore.optim.fused_adam import get_fused_adam_class + + +# logger = logging.getLogger(__name__) + + +# @register_optimizer("adam") +# class UnicoreAdam(UnicoreOptimizer): +# """Adam optimizer for unicore. + +# Important note: this optimizer corresponds to the "AdamW" variant of +# Adam in its weight decay behavior. As such, it is most closely +# analogous to torch.optim.AdamW from PyTorch. +# """ + +# def __init__(self, args, params): +# super().__init__(args) +# fused_adam_cls = get_fused_adam_class() +# use_fused_adam = ( +# not getattr(args, "use_old_adam", False) +# and fused_adam_cls is not None +# and torch.cuda.is_available() +# and torch.cuda.get_device_capability()[0] >= 7 +# ) +# if use_fused_adam: +# logger.info("using FusedAdam") +# self._optimizer = fused_adam_cls(params, **self.optimizer_config) +# else: +# self._optimizer = Adam(params, **self.optimizer_config) + +# @staticmethod +# def add_args(parser): +# """Add optimizer-specific arguments to the parser.""" +# # fmt: off +# parser.add_argument('--adam-betas', default='(0.9, 0.999)', metavar='B', +# help='betas for Adam optimizer') +# parser.add_argument('--adam-eps', type=float, default=1e-8, metavar='D', +# help='epsilon for Adam optimizer') +# parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', +# help='weight decay') +# # fmt: on + +# @property +# def optimizer_config(self): +# """ +# Return a kwarg dictionary that will be used to override optimizer +# args stored in checkpoints. This allows us to load a checkpoint and +# resume training using a different set of optimizer args, e.g., with a +# different learning rate. +# """ +# return { +# "lr": self.args.lr[0] +# if isinstance(self.args.lr, Collection) +# else self.args.lr, +# "betas": eval(self.args.adam_betas), +# "eps": self.args.adam_eps, +# "weight_decay": self.args.weight_decay, +# } + + +# class Adam(torch.optim.Optimizer): +# r"""Implements Adam algorithm. + +# This implementation is modified from torch.optim.Adam based on: +# `Fixed Weight Decay Regularization in Adam` +# (see https://arxiv.org/abs/1711.05101) + +# It has been proposed in `Adam: A Method for Stochastic Optimization`_. + +# Args: +# params (iterable): iterable of parameters to optimize or dicts defining +# parameter groups +# lr (float, optional): learning rate (default: 1e-3) +# betas (Tuple[float, float], optional): coefficients used for computing +# running averages of gradient and its square (default: (0.9, 0.999)) +# eps (float, optional): term added to the denominator to improve +# numerical stability (default: 1e-8) +# weight_decay (float, optional): weight decay (L2 penalty) (default: 0) +# amsgrad (boolean, optional): whether to use the AMSGrad variant of this +# algorithm from the paper `On the Convergence of Adam and Beyond`_ + +# .. _Adam\: A Method for Stochastic Optimization: +# https://arxiv.org/abs/1412.6980 +# .. 
_On the Convergence of Adam and Beyond: +# https://openreview.net/forum?id=ryQu7f-RZ +# """ + +# def __init__( +# self, +# params, +# lr=1e-3, +# betas=(0.9, 0.999), +# eps=1e-8, +# weight_decay=0, +# amsgrad=False, +# ): +# defaults = dict( +# lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad +# ) +# super(Adam, self).__init__(params, defaults) + +# @property +# def supports_memory_efficient_fp16(self): +# return True + +# @property +# def supports_flat_params(self): +# return True + +# def step(self, closure=None): +# """Performs a single optimization step. + +# Args: +# closure (callable, optional): A closure that reevaluates the model +# and returns the loss. +# """ +# loss = None +# if closure is not None: +# loss = closure() + +# for group in self.param_groups: +# for p in group["params"]: +# if p.grad is None: +# continue +# grad = p.grad.data +# if grad.dtype in {torch.float16, torch.bfloat16}: +# grad = grad.float() +# if grad.is_sparse: +# raise RuntimeError( +# "Adam does not support sparse gradients, please consider SparseAdam instead" +# ) +# amsgrad = group.get("amsgrad", False) + +# p_data_fp32 = p.data +# if p.data.dtype in {torch.float16, torch.bfloat16}: +# p_data_fp32 = p_data_fp32.float() + +# state = self.state[p] + +# # State initialization +# if len(state) == 0: +# state["step"] = 0 +# # Exponential moving average of gradient values +# state["exp_avg"] = torch.zeros_like(p_data_fp32) +# # Exponential moving average of squared gradient values +# state["exp_avg_sq"] = torch.zeros_like(p_data_fp32) +# if amsgrad: +# # Maintains max of all exp. moving avg. of sq. grad. values +# state["max_exp_avg_sq"] = torch.zeros_like(p_data_fp32) +# else: +# state["exp_avg"] = state["exp_avg"].to(p_data_fp32) +# state["exp_avg_sq"] = state["exp_avg_sq"].to(p_data_fp32) +# if amsgrad: +# state["max_exp_avg_sq"] = state["max_exp_avg_sq"].to( +# p_data_fp32 +# ) + +# exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] +# if amsgrad: +# max_exp_avg_sq = state["max_exp_avg_sq"] +# beta1, beta2 = group["betas"] + +# state["step"] += 1 + +# # Decay the first and second moment running average coefficient +# exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1) +# exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) +# if amsgrad: +# # Maintains the maximum of all 2nd moment running avg. till now +# torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) +# # Use the max. for normalizing running avg. of gradient +# denom = max_exp_avg_sq.sqrt().add_(group["eps"]) +# else: +# denom = exp_avg_sq.sqrt().add_(group["eps"]) + +# bias_correction1 = 1 - beta1 ** state["step"] +# bias_correction2 = 1 - beta2 ** state["step"] +# step_size = group["lr"] * math.sqrt(bias_correction2) / bias_correction1 + +# if group["weight_decay"] != 0: +# p_data_fp32.add_( +# p_data_fp32, alpha=-group["weight_decay"] * group["lr"] +# ) + +# p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size) + +# if p.data.dtype in {torch.float16, torch.bfloat16}: +# p.data.copy_(p_data_fp32) + +# return loss diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/dynamic_loss_scaler.py b/MindChemistry/applications/Uni-Mol/unicore/optim/dynamic_loss_scaler.py new file mode 100644 index 0000000000000000000000000000000000000000..fe502299a1990ead2140ea824f8d6b9c8b80cca7 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/dynamic_loss_scaler.py @@ -0,0 +1,71 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + + +class DynamicLossScaler(object): + def __init__( + self, + init_scale=2.0 ** 15, + scale_factor=2.0, + scale_window=2000, + tolerance=0.0, + threshold=None, + min_loss_scale=1e-4, + ): + self.loss_scale = init_scale + self.scale_factor = scale_factor + self.scale_window = scale_window + self.tolerance = tolerance + self.threshold = threshold + self._iter = 0 + self._last_overflow_iter = -1 + self._last_rescale_iter = -1 + self._overflows_since_rescale = 0 + self.min_loss_scale = min_loss_scale + + def scale(self, outputs): + return self.loss_scale * outputs + + def update(self): + if (self._iter - self._last_overflow_iter) % self.scale_window == 0: + self.loss_scale *= self.scale_factor + self._last_rescale_iter = self._iter + self._iter += 1 + + def _decrease_loss_scale(self): + self.loss_scale /= self.scale_factor + if self.threshold is not None: + self.loss_scale = max(self.loss_scale, self.threshold) + + def check_overflow(self, grad_norm): + # detect inf and nan + if grad_norm == float("inf") or grad_norm != grad_norm: + # overflow has occured + prev_scale = self.loss_scale + iter_since_rescale = self._iter - self._last_rescale_iter + + self._last_overflow_iter = self._iter + self._overflows_since_rescale += 1 + pct_overflow = self._overflows_since_rescale / float(iter_since_rescale) + if pct_overflow >= self.tolerance: + self._decrease_loss_scale() + self._last_rescale_iter = self._iter + self._overflows_since_rescale = 0 + + if self.loss_scale <= self.min_loss_scale: + # Use FloatingPointError as an uncommon error that parent + # functions can safely catch to stop training. + self.loss_scale = prev_scale + raise FloatingPointError( + ( + "Minimum loss scale reached ({}). Your loss is probably exploding. " + "Try lowering the learning rate, using gradient clipping or " + "increasing the batch size." + ).format(self.min_loss_scale) + ) + + self._iter += 1 + raise OverflowError("setting loss scale to: " + str(self.loss_scale)) diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/fp16_optimizer.py b/MindChemistry/applications/Uni-Mol/unicore/optim/fp16_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..97510e28e63a82ac065dce7686c7ef3823cf5d51 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/fp16_optimizer.py @@ -0,0 +1,748 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
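+"""FP16 / mixed-precision optimizer wrapper for the MindSpore port.
+
+Gradients are accumulated in flattened fp16 buffers, synced into a flat fp32
+copy, unscaled via the DynamicLossScaler imported from dynamic_loss_scaler,
+and then stepped with the wrapped fp32 optimizer.
+
+Minimal usage sketch (hedged; ``args`` and ``named_params`` are placeholders:
+args must provide the fp16_*/update_freq/distributed_world_size fields read by
+FP16Optimizer below, and named_params is an iterable of (name, parameter)
+pairs)::
+
+    optimizer = FP16Optimizer.build_optimizer(args, named_params)
+    optimizer.backward(loss)            # loss is scaled before backward
+    grad_norm = optimizer.clip_grad_norm(max_norm)
+    optimizer.step()
+    optimizer.zero_grad()
+"""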
+from collections import defaultdict + +import mindspore as ms +from mindspore import nn, ops, Tensor, Parameter +from mindspore.common.initializer import Zero +from unicore import optim +from unicore import utils + +from .dynamic_loss_scaler import DynamicLossScaler + + +def separate_decay_params(args, params): + if args.weight_decay <= 0: + return [{"params": [p for _, p in params if p.requires_grad]}] + + no_wd = ( + set(args.no_weight_decay_names.split(",")) + if args.no_weight_decay_names + else set() + ) + + def skip_decay(name, p): + return name.endswith(".bias") or p.ndim == 1 or any(nd in name for nd in no_wd) + + decay_params = [] + no_decay_params = [] + for name, p in params: + if not p.requires_grad: + continue + elif skip_decay(name, p): + no_decay_params.append(p) + else: + decay_params.append(p) + ret = [] + if len(decay_params) > 0: + ret.append({"params": decay_params}) + if len(no_decay_params) > 0: + ret.append({"params": no_decay_params, "weight_decay": 0.0}) + return ret + + +def check_param_device(params): + if len(params) <= 0: + return True + device = params[0].device + for i in range(1, len(params)): + assert device == params[i].device, "All parameters must be on the same device" + + +def pad_numel(numel, multiplier=2): + return (numel + multiplier - 1) // multiplier * multiplier + + +def _numel(shape): + """计算张量元素数量(替代PyTorch的numel())""" + return ops.prod(Tensor(shape, dtype=ms.int64)).asnumpy().item() + + +def flatten_orders(params): + dtype_grouped_params = {} + ordered_dtype = [] # 用于排序dtype + total_param_size = 0 + for p in params: + if p.dtype not in dtype_grouped_params: + dtype_grouped_params[p.dtype] = [] + ordered_dtype.append(p.dtype) + dtype_grouped_params[p.dtype].append(p) + total_param_size += pad_numel(_numel(p.shape)) + return dtype_grouped_params, ordered_dtype, total_param_size + + +# @ms.no_grad() +def flatten_parameters(params): + dtype_grouped_params, ordered_dtype, _ = flatten_orders(params) + + flatten_params = {} + for dtype in ordered_dtype: + cur_params = dtype_grouped_params[dtype] + total_param_size = sum(pad_numel(_numel(p.shape)) for p in cur_params) + # 创建空张量替代torch的new_zeros + flatten_params[dtype] = Parameter( + Tensor(shape=(total_param_size,), dtype=dtype, init=Zero()) + ) + offset = 0 + for p in cur_params: + numel = _numel(p.shape) + # 替换copy_为assign(MindSpore无in-place操作) + flatten_params[dtype].data[offset:offset+numel] = p.data.reshape(-1) + # 更新参数数据引用 + p.data = flatten_params[dtype].data[offset:offset+numel].reshape(p.shape) + offset += pad_numel(numel) + # 初始化梯度 + flatten_params[dtype].grad = Tensor(shape=(total_param_size,), dtype=dtype, init=Zero()) + offset = 0 + for p in cur_params: + numel = _numel(p.shape) + p.grad = flatten_params[dtype].grad[offset:offset+numel].reshape(p.shape) + offset += pad_numel(numel) + ms.clear_cache() # 替代PyTorch的cuda缓存清理 + return [flatten_params[dtype] for dtype in ordered_dtype] + + +# @ms.no_grad() +def flatten_parameters_fp32(params, set_to_param=False, set_grad=True): + dtype_grouped_params, ordered_dtype, total_param_size = flatten_orders(params) + device = params[0].device if params else ms.get_context("device_target") + + # 替代torch.zeros,指定设备为Ascend + flatten_params = Tensor( + shape=(total_param_size,), dtype=ms.float32, device=device, init=Zero() + ) + offset = 0 + for dtype in ordered_dtype: + cur_params = dtype_grouped_params[dtype] + for p in cur_params: + numel = _numel(p.shape) + # 复制数据 + flatten_params[offset:offset+numel] = p.data.reshape(-1).astype(ms.float32) + if 
set_to_param: + p.data = flatten_params[offset:offset+numel].reshape(p.shape).astype(p.dtype) + p.grad = None # 清除梯度引用 + offset += pad_numel(numel) + flatten_params = Parameter(flatten_params) + if set_grad: + flatten_params.grad = Tensor(shape=flatten_params.shape, dtype=ms.float32, init=Zero()) + ms.clear_cache() + return flatten_params + + +def get_fp16_params(args, params): + param_group = separate_decay_params(args, params) + fp16_group = [] + fp32_group = [] + for param_dict in param_group: + params = param_dict["params"] + check_param_device(params) + fp16_params = flatten_parameters(params) + fp32_params = flatten_parameters_fp32(params) + fp16_group.append({"params": fp16_params}) + param_dict["params"] = [fp32_params] + fp32_group.append(param_dict) + return fp16_group, fp32_group + + +class _FP16OptimizerMixin(object): + def __init__(self, args, **kwargs): + super().__init__(args,** kwargs) + self._multiply_factor = 1.0 + self.bf16_sr = getattr(args, "bf16_sr", False) + + def state_dict(self): + state_dict = self.fp32_optimizer.state_dict() + if self.scaler is not None: + state_dict["loss_scale"] = self.scaler.loss_scale + return state_dict + + def load_state_dict(self, state_dict, optimizer_overrides=None): + if "loss_scale" in state_dict and self.scaler is not None: + self.scaler.loss_scale = state_dict["loss_scale"] + self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) + + def backward(self, loss): + if self.scaler is not None: + loss = self.scaler.scale(loss) + loss.backward() # MindSpore反向传播接口 + self._needs_sync = True + + def _sync_fp16_grads_to_fp32(self): + with ms.no_grad(): + if self._needs_sync: + for gid in range(len(self.fp16_params)): + offset = 0 + for p in self.fp16_params[gid]["params"]: + numel = _numel(p.shape) + # 梯度复制(替代copy_) + self.fp32_params[gid]["params"][0].grad[offset:offset+numel] = p.grad.reshape(-1) + offset += pad_numel(numel) + self._needs_sync = False + + def _add_fp16_grads_to_fp32(self, mul=0.0): + with ms.no_grad(): + for gid in range(len(self.fp16_params)): + offset = 0 + for p in self.fp16_params[gid]["params"]: + numel = _numel(p.shape) + # 梯度累加 + self.fp32_params[gid]["params"][0].grad[offset:offset+numel] += mul * p.grad.astype(ms.float32).reshape(-1) + p.grad.set_data(ops.zeros_like(p.grad)) # 清零梯度 + offset += pad_numel(numel) + self._needs_sync = False + + def _sync_fp32_params_to_fp16(self): + # 将FP32参数复制回FP16模型 + for gid in range(len(self.fp16_params)): + offset = 0 + for p in self.fp16_params[gid]["params"]: + numel = _numel(p.shape) + u = self.fp32_params[gid]["params"][0].data[offset:offset+numel].reshape(p.shape) + if self.bf16_sr and p.dtype == ms.bfloat16: + utils.fp32_to_bf16_sr(u, p) # 假设utils接口兼容 + else: + p.data = u.astype(p.dtype) # 类型转换后赋值 + offset += pad_numel(numel) + + def _unscale_grads(self): + self._sync_fp16_grads_to_fp32() + if ( + ms.is_tensor(self._multiply_factor) + or self._multiply_factor != 1.0 + ): + self.fp32_optimizer.multiply_grads(self._multiply_factor) + self._multiply_factor = 1.0 + + def multiply_grads(self, c): + if self._needs_sync: + self._multiply_factor *= c + else: + self.fp32_optimizer.multiply_grads(c) + + def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + if max_norm <= 0.0: + return 0.0 + all_fp16_params = [] + for p in self.fp16_params: + all_fp16_params.extend(p["params"]) + grad_norm = self._multiply_factor * utils.clip_grad_norm_( + all_fp16_params, 0, aggregate_norm_fn + ) + if grad_norm > max_norm > 0.0: + clip_coef = max_norm / (grad_norm + 
1e-6) + else: + clip_coef = 1.0 + self._add_fp16_grads_to_fp32(mul=clip_coef) + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + self._sync_fp16_grads_to_fp32() + grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm( + 0, + aggregate_norm_fn=aggregate_norm_fn, + ) + + if self.scaler is not None: + if grad_norm > max_norm > 0.0: + self._multiply_factor *= max_norm / grad_norm + + self.scaler.check_overflow(grad_norm) + elif max_norm > 0.0: + clip_coef = (max_norm / (grad_norm + 1e-6)).clip(max=1) + self._multiply_factor *= clip_coef + + return grad_norm + + def step(self, closure=None, groups=None): + self._sync_fp16_grads_to_fp32() + if getattr(self, "supports_step_with_scale", False): + self.fp32_optimizer.step( + closure, scale=(1.0 / self._multiply_factor), groups=groups + ) + else: + self._unscale_grads() + self.fp32_optimizer.step(closure, groups=groups) + + if self.scaler is not None: + self.scaler.update() + + self._sync_fp32_params_to_fp16() + + def zero_grad(self): + def zero(group): + for x in group: + for p in x["params"]: + if p.grad is not None: + p.grad.set_data(ops.zeros_like(p.grad)) # 梯度清零 + + zero(self.fp16_params) + zero(self.fp32_params) + self._needs_sync = False + + if self.scaler is not None: + self._multiply_factor = 1.0 / float(self.scaler.loss_scale) + else: + self._multiply_factor = 1.0 + + +class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer): + """适配MindSpore的FP16优化器封装""" + def __init__(self, args, params, fp32_optimizer, fp32_params, **kwargs): + super().__init__(args) + self.fp16_params = params + self.fp32_optimizer = fp32_optimizer + self.fp32_params = fp32_params + self.allreduce_fp32_grad = getattr(args, "allreduce_fp32_grad", False) + + if getattr(args, "fp16_scale_window", None) is None: + if len(args.update_freq) > 1: + raise ValueError( + "--fp16-scale-window must be given explicitly when using a " + "custom --update-freq schedule" + ) + data_parallel_size = int(args.distributed_world_size) + scale_window = int(2**14 / data_parallel_size / args.update_freq[0]) + else: + scale_window = args.fp16_scale_window + + if not getattr(args, "bf16", False): + self.scaler = DynamicLossScaler( + init_scale=args.fp16_init_scale, + scale_window=scale_window, + tolerance=args.fp16_scale_tolerance, + threshold=args.threshold_loss_scale, + min_loss_scale=args.min_loss_scale, + ) + else: + self.scaler = None + + @classmethod + def build_optimizer(cls, args, params, **kwargs): + flatten = not getattr(args, "fp16_no_flatten_grads", False) + assert flatten + fp16_group, fp32_group = get_fp16_params(args, params) + fp32_optimizer = optim.build_optimizer(args, fp32_group, separate=False) + return cls(args, fp16_group, fp32_optimizer, fp32_group, **kwargs) + + @property + def optimizer(self): + return self.fp32_optimizer.optimizer + + @optimizer.setter + def optimizer(self, optimizer): + self.fp32_optimizer.optimizer = optimizer + + @property + def lr_scheduler(self): + return getattr(self.fp32_optimizer, "lr_scheduler", None) + + @property + def optimizer_config(self): + return self.fp32_optimizer.optimizer_config + + def get_lr(self): + return self.fp32_optimizer.get_lr() + + def set_lr(self, lr): + self.fp32_optimizer.set_lr(lr) + + def all_reduce_grads(self, module): + if self.allreduce_fp32_grad and hasattr(module, "all_reduce_params"): + self._sync_fp16_grads_to_fp32() + with ms.no_grad(): + params = [x["params"][0] for x in self.fp32_params] + module.all_reduce_params(params) + else: + 
self.fp32_optimizer.all_reduce_grads(module) + + @property + def supports_flat_params(self): + return self.fp32_optimizer.supports_flat_params +# from collections import defaultdict + +# import torch +# from unicore import optim +# from unicore import utils + +# from .dynamic_loss_scaler import DynamicLossScaler + + +# def separate_decay_params(args, params): +# if args.weight_decay <= 0: +# return [{"params": [p for _, p in params if p.requires_grad]}] + +# no_wd = ( +# set(args.no_weight_decay_names.split(",")) +# if args.no_weight_decay_names +# else set() +# ) + +# def skip_decay(name, p): +# return name.endswith(".bias") or p.ndim == 1 or any(nd in name for nd in no_wd) + +# decay_params = [] +# no_decay_params = [] +# for name, p in params: +# if not p.requires_grad: +# continue +# elif skip_decay(name, p): +# no_decay_params.append(p) +# else: +# decay_params.append(p) +# ret = [] +# if len(decay_params) > 0: +# ret.append({"params": decay_params}) +# if len(no_decay_params) > 0: +# ret.append({"params": no_decay_params, "weight_decay": 0.0}) +# return ret + + +# def check_param_device(params): +# if len(params) <= 0: +# return True +# device = params[0].device +# for i in range(1, len(params)): +# assert device == params[i].device + + +# def pad_numel(numel, multiplier=2): +# return (numel + multiplier - 1) // multiplier * multiplier + + +# def flatten_orders(params): +# dtype_grouped_params = {} +# ordered_dtype = [] # for sort dtype +# total_param_size = 0 +# for p in params: +# if p.dtype not in dtype_grouped_params: +# dtype_grouped_params[p.dtype] = [] +# ordered_dtype.append(p.dtype) +# dtype_grouped_params[p.dtype].append(p) +# total_param_size += pad_numel(p.data.numel()) +# return dtype_grouped_params, ordered_dtype, total_param_size + + +# @torch.no_grad() +# def flatten_parameters(params): +# dtype_grouped_params, ordered_dtype, _ = flatten_orders(params) + +# flatten_params = {} +# for dtype in ordered_dtype: +# cur_params = dtype_grouped_params[dtype] +# total_param_size = sum(pad_numel(p.data.numel()) for p in cur_params) +# flatten_params[dtype] = ( +# cur_params[0].new(0).type(dtype).new_zeros(total_param_size) +# ) +# offset = 0 +# for p in cur_params: +# numel = p.data.numel() +# flatten_params[dtype][offset : offset + numel].copy_(p.data.view(-1)) +# p.data = flatten_params[dtype].data[offset : offset + numel].view(*p.shape) +# offset += pad_numel(numel) +# flatten_params[dtype] = torch.nn.Parameter(flatten_params[dtype]) +# flatten_params[dtype].grad = flatten_params[dtype].data.new(total_param_size) +# offset = 0 +# for p in cur_params: +# numel = p.data.numel() +# p.grad = flatten_params[dtype].grad[offset : offset + numel].view(*p.shape) +# offset += pad_numel(numel) +# torch.cuda.empty_cache() +# return [flatten_params[dtype] for dtype in ordered_dtype] + + +# @torch.no_grad() +# def flatten_parameters_fp32(params, set_to_param=False, set_grad=True): +# dtype_grouped_params, ordered_dtype, total_param_size = flatten_orders(params) + +# flatten_params = torch.zeros( +# total_param_size, dtype=torch.float32, device=params[0].device +# ) +# offset = 0 +# for dtype in ordered_dtype: +# cur_params = dtype_grouped_params[dtype] +# for p in cur_params: +# numel = p.data.numel() +# flatten_params[offset : offset + numel].copy_(p.data.view(-1)) +# if set_to_param: +# p.data = flatten_params.data[offset : offset + numel].view(*p.shape) +# # set to None here, it will throw error when using this incorrectly +# p.grad = None +# offset += pad_numel(numel) +# 
flatten_params = torch.nn.Parameter(flatten_params) +# if set_grad: +# flatten_params.grad = torch.zeros_like(flatten_params) +# torch.cuda.empty_cache() +# return flatten_params + + +# def get_fp16_params(args, params): +# param_group = separate_decay_params(args, params) +# fp16_group = [] +# fp32_group = [] +# for param_dict in param_group: +# params = param_dict["params"] +# check_param_device(params) +# fp16_params = flatten_parameters(params) +# fp32_params = flatten_parameters_fp32(params) +# fp16_group.append({"params": fp16_params}) +# param_dict["params"] = [fp32_params] +# fp32_group.append(param_dict) +# return fp16_group, fp32_group + + +# class _FP16OptimizerMixin(object): +# def __init__(self, args, **kwargs): +# # forward __init__ call to the next class in mro(method resolution order) +# super().__init__(args, **kwargs) +# self._multiply_factor = 1.0 +# self.bf16_sr = getattr(args, "bf16_sr", False) + +# def state_dict(self): +# """Return the optimizer's state dict.""" +# state_dict = self.fp32_optimizer.state_dict() +# if self.scaler is not None: +# state_dict["loss_scale"] = self.scaler.loss_scale +# return state_dict + +# def load_state_dict(self, state_dict, optimizer_overrides=None): +# """Load an optimizer state dict. +# In general we should prefer the configuration of the existing optimizer +# instance (e.g., learning rate) over that found in the state_dict. This +# allows us to resume training from a checkpoint using a new set of +# optimizer args. +# """ +# if "loss_scale" in state_dict and self.scaler is not None: +# self.scaler.loss_scale = state_dict["loss_scale"] +# self.fp32_optimizer.load_state_dict(state_dict, optimizer_overrides) + +# def backward(self, loss): +# """Computes the sum of gradients of the given tensor w.r.t. graph leaves. +# Compared to :func:`unicore.optim.UnicoreOptimizer.backward`, this +# function additionally dynamically scales the loss to avoid gradient +# underflow. +# """ +# if self.scaler is not None: +# loss = self.scaler.scale(loss) +# loss.backward() +# self._needs_sync = True + +# def _sync_fp16_grads_to_fp32(self): +# with torch.no_grad(): +# if self._needs_sync: +# for gid in range(len(self.fp16_params)): +# offset = 0 +# for p in self.fp16_params[gid]["params"]: +# numel = p.numel() +# self.fp32_params[gid]["params"][0].grad.data[ +# offset : offset + numel +# ].copy_(p.grad.data.view(-1)) +# offset += pad_numel(numel) +# self._needs_sync = False + +# def _add_fp16_grads_to_fp32(self, mul=0.0): +# with torch.no_grad(): +# for gid in range(len(self.fp16_params)): +# offset = 0 +# for p in self.fp16_params[gid]["params"]: +# numel = p.numel() +# self.fp32_params[gid]["params"][0].grad.data[ +# offset : offset + numel +# ] += mul * p.grad.data.float().view(-1) +# p.grad.zero_() +# offset += pad_numel(numel) +# self._needs_sync = False + +# def _sync_fp32_params_to_fp16(self): +# # copy FP32 params back into FP16 model +# for gid in range(len(self.fp16_params)): +# offset = 0 +# for p in self.fp16_params[gid]["params"]: +# numel = p.numel() +# u = ( +# self.fp32_params[gid]["params"][0] +# .data[offset : offset + numel] +# .view_as(p.data) +# ) +# if self.bf16_sr and p.dtype == torch.bfloat16: +# utils.fp32_to_bf16_sr(u, p) +# else: +# p.data.copy_(u) +# offset += pad_numel(numel) + +# def _unscale_grads(self): +# self._sync_fp16_grads_to_fp32() +# if ( +# # Skip the multiplication if it's a no-op (i.e., if _multiply_factor +# # is 1.0). At the same time, we want to avoid the device-to-host +# # transfer by comparing it to 1.0. 
Since _multiply_factor starts as +# # a Python float, we roughly assume that if it's a tensor then it's +# # probably not =1.0 anymore and we do the multiplication. Otherwise +# # we can safely check the value without a D2H transfer. +# torch.is_tensor(self._multiply_factor) +# or self._multiply_factor != 1.0 +# ): +# self.fp32_optimizer.multiply_grads(self._multiply_factor) +# self._multiply_factor = 1.0 + +# def multiply_grads(self, c): +# """Multiplies grads by a constant ``c``.""" +# if self._needs_sync: +# self._multiply_factor *= c +# else: +# # gradients already synced to fp32 parameters, update it directly +# self.fp32_optimizer.multiply_grads(c) + +# def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None): +# """Clips gradient norm.""" +# if max_norm <= 0.0: +# return 0.0 +# all_fp16_params = defaultdict(list) +# for p in self.fp16_params: +# all_fp16_params.extend(p["params"]) +# grad_norm = self._multiply_factor * utils.clip_grad_norm_( +# all_fp16_params, 0, aggregate_norm_fn +# ) +# # grad_norm = 1.0 +# if grad_norm > max_norm > 0.0: +# clip_coef = max_norm / (grad_norm + 1e-6) +# else: +# clip_coef = 1.0 +# self._add_fp16_grads_to_fp32(mul=clip_coef) + +# def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): +# """Clips gradient norm and updates dynamic loss scaler.""" +# self._sync_fp16_grads_to_fp32() +# grad_norm = self._multiply_factor * self.fp32_optimizer.clip_grad_norm( +# 0, +# aggregate_norm_fn=aggregate_norm_fn, +# ) + +# if self.scaler is not None: +# if grad_norm > max_norm > 0.0: +# self._multiply_factor *= max_norm / grad_norm + +# self.scaler.check_overflow(grad_norm) +# elif max_norm > 0.0: +# clip_coef = (max_norm / (grad_norm + 1e-6)).clamp_(max=1) +# self._multiply_factor *= clip_coef + +# return grad_norm + +# def step(self, closure=None, groups=None): +# """Performs a single optimization step.""" +# self._sync_fp16_grads_to_fp32() +# if getattr(self, "supports_step_with_scale", False): +# self.fp32_optimizer.step( +# closure, scale=(1.0 / self._multiply_factor), groups=groups +# ) +# else: +# self._unscale_grads() +# self.fp32_optimizer.step(closure, groups=groups) + +# if self.scaler is not None: +# self.scaler.update() + +# self._sync_fp32_params_to_fp16() + +# def zero_grad(self): +# """Clears the gradients of all optimized parameters.""" + +# def zero(group): +# for x in group: +# for p in x["params"]: +# p.grad.zero_() + +# zero(self.fp16_params) +# zero(self.fp32_params) +# self._needs_sync = False + +# if self.scaler is not None: +# self._multiply_factor = 1.0 / float(self.scaler.loss_scale) +# else: +# self._multiply_factor = 1.0 + + +# class FP16Optimizer(_FP16OptimizerMixin, optim.UnicoreOptimizer): +# """ +# Wrap an *optimizer* to support FP16 (mixed precision) training. 
+# """ + +# def __init__(self, args, params, fp32_optimizer, fp32_params, **kwargs): +# super().__init__(args) +# self.fp16_params = params +# self.fp32_optimizer = fp32_optimizer +# self.fp32_params = fp32_params +# self.allreduce_fp32_grad = getattr(args, "allreduce_fp32_grad", False) + +# if getattr(args, "fp16_scale_window", None) is None: +# if len(args.update_freq) > 1: +# raise ValueError( +# "--fp16-scale-window must be given explicitly when using a " +# "custom --update-freq schedule" +# ) +# data_parallel_size = int(args.distributed_world_size) +# scale_window = int(2**14 / data_parallel_size / args.update_freq[0]) +# else: +# scale_window = args.fp16_scale_window + +# if not getattr(args, "bf16", False): +# self.scaler = DynamicLossScaler( +# init_scale=args.fp16_init_scale, +# scale_window=scale_window, +# tolerance=args.fp16_scale_tolerance, +# threshold=args.threshold_loss_scale, +# min_loss_scale=args.min_loss_scale, +# ) +# else: +# # disable loss scaling for bfloat16 +# self.scaler = None + +# @classmethod +# def build_optimizer(cls, args, params, **kwargs): +# """ +# Args: +# args : unicore args +# params (iterable): iterable of parameters to optimize +# """ +# flatten = not getattr(args, "fp16_no_flatten_grads", False) +# assert flatten +# fp16_group, fp32_group = get_fp16_params(args, params) +# fp32_optimizer = optim.build_optimizer(args, fp32_group, separate=False) +# return cls(args, fp16_group, fp32_optimizer, fp32_group, **kwargs) + +# @property +# def optimizer(self): +# return self.fp32_optimizer.optimizer + +# @optimizer.setter +# def optimizer(self, optimizer): +# self.fp32_optimizer.optimizer = optimizer + +# @property +# def lr_scheduler(self): +# return getattr(self.fp32_optimizer, "lr_scheduler", None) + +# @property +# def optimizer_config(self): +# return self.fp32_optimizer.optimizer_config + +# def get_lr(self): +# return self.fp32_optimizer.get_lr() + +# def set_lr(self, lr): +# self.fp32_optimizer.set_lr(lr) + +# def all_reduce_grads(self, module): +# if self.allreduce_fp32_grad and hasattr(module, "all_reduce_params"): +# self._sync_fp16_grads_to_fp32() +# with torch.no_grad(): +# params = [x["params"][0] for x in self.fp32_params] +# module.all_reduce_params(params) +# else: +# self.fp32_optimizer.all_reduce_grads(module) + +# @property +# def supports_flat_params(self): +# return self.fp32_optimizer.supports_flat_params diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/fused_adam.py b/MindChemistry/applications/Uni-Mol/unicore/optim/fused_adam.py new file mode 100644 index 0000000000000000000000000000000000000000..d45da7308d46ce2c2ec87ae6f6051b96133a0d35 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/fused_adam.py @@ -0,0 +1,250 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
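+# The port below drops the CUDA fused kernel and performs a plain per-tensor
+# Adam update inside construct(). As an illustrative reference only (nothing in
+# this file calls it and the symbols here are not part of the API), the update
+# applied to each parameter p with gradient g at step t is roughly:
+#
+#     m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
+#     v_t = beta2 * v_{t-1} + (1 - beta2) * g_t ** 2
+#     step_size = lr * sqrt(1 - beta2 ** t) / (1 - beta1 ** t)   # bias correction
+#     p <- p - p * weight_decay * lr                             # if weight_decay != 0
+#     p <- p - step_size * (m_t / scale) / (sqrt(v_t) + eps)
+#
+# where `scale` is the gradient-scaling factor passed to construct(), e.g. the
+# loss scale forwarded by an FP16 wrapper that supports step-with-scale.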
+import mindspore as ms +from mindspore.nn.optim import Optimizer + + +def get_fused_adam_class(): + """Ascend NPU不支持基于CUDA的FusedAdam,返回None""" + return None + + +class FusedAdam(Optimizer): + """ + 适配MindSpore的FusedAdam类(Ascend NPU环境下不支持CUDA融合操作,仅保留结构) + + 注意:原PyTorch的FusedAdam依赖CUDA扩展,Ascend NPU无对应实现,此版本仅作为兼容性占位 + """ + + def __init__(self, params, + learning_rate=1e-3, bias_correction=True, + betas=(0.9, 0.999), eps=1e-8, + weight_decay=0., amsgrad=False): + if amsgrad: + raise RuntimeError("FusedAdam does not support the AMSGrad variant.") + defaults = { + "learning_rate": learning_rate, + "bias_correction": bias_correction, + "betas": betas, + "eps": eps, + "weight_decay": weight_decay, + } + super().__init__(params, defaults) + + @property + def supports_memory_efficient_fp16(self): + return True + + @property + def supports_flat_params(self): + return True + + @property + def supports_step_with_scale(self): + return True + + def construct(self, gradients, scale=1.): + """执行优化步骤(Ascend NPU不支持融合操作,此处简化实现)""" + for group_id, group in enumerate(self.param_groups): + combined_scale = scale + bias_correction = 1 if group.get("bias_correction", 1) else 0 + lr = group["learning_rate"] + beta1, beta2 = group["betas"] + eps = group["eps"] + weight_decay = group["weight_decay"] + + for param_id, param in enumerate(group["params"]): + grad = gradients[group_id * len(group["params"]) + param_id] + if grad is None: + continue + + # 初始化状态 + state = self.state[(group_id, param_id)] + if not state: + state["step"] = ms.Tensor(0, dtype=ms.int32) + # 梯度的指数移动平均(使用float32提高精度) + state["exp_avg"] = ms.ops.zeros_like(param, dtype=ms.float32) + # 梯度平方的指数移动平均 + state["exp_avg_sq"] = ms.ops.zeros_like(param, dtype=ms.float32) + else: + state["exp_avg"] = ms.ops.cast(state["exp_avg"], ms.float32) + state["exp_avg_sq"] = ms.ops.cast(state["exp_avg_sq"], ms.float32) + + exp_avg = state["exp_avg"] + exp_avg_sq = state["exp_avg_sq"] + state["step"] += 1 + step = state["step"].asnumpy().item() + + # 转换梯度和参数为float32计算(类似原逻辑) + grad_fp32 = ms.ops.cast(grad, ms.float32) if grad.dtype != ms.float32 else grad + param_fp32 = ms.ops.cast(param, ms.float32) if param.dtype != ms.float32 else param + + # 计算移动平均(替代原CUDA融合操作) + exp_avg = exp_avg * beta1 + grad_fp32 * (1 - beta1) + exp_avg_sq = exp_avg_sq * beta2 + ms.ops.square(grad_fp32) * (1 - beta2) + + # 偏差校正 + if bias_correction: + bias_correction1 = 1 - (beta1 ** step) + bias_correction2 = 1 - (beta2 ** step) + step_size = lr * ms.ops.sqrt(bias_correction2) / bias_correction1 + else: + step_size = lr + + # 权重衰减 + if weight_decay != 0: + param_fp32 = param_fp32 - param_fp32 * weight_decay * lr + + # 应用梯度缩放 + scaled_exp_avg = exp_avg / combined_scale + # 参数更新 + param_fp32 = param_fp32 - (scaled_exp_avg / (ms.ops.sqrt(exp_avg_sq) + eps)) * step_size + + # 转回原始数据类型 + param.set_data(ms.ops.cast(param_fp32, param.dtype)) + # 更新状态 + state["exp_avg"] = exp_avg + state["exp_avg_sq"] = exp_avg_sq +# import torch + + +# def get_fused_adam_class(): +# try: +# global unicore_fused_adam +# import importlib +# unicore_fused_adam = importlib.import_module("unicore_fused_adam") +# return FusedAdam +# except ImportError: +# pass +# return None + + +# class FusedAdam(torch.optim.Optimizer): +# """ +# Implements Adam algorithm. Currently GPU-only. Requires Apex to be installed via +# ``python setup.py install --cuda_ext --cpp_ext``. + +# It has been proposed in `Adam: A Method for Stochastic Optimization`_. 
+ +# Compared to the original version in Apex, the unicore version casts grads +# and params to FP32 internally to support ``--memory-efficient-fp16``. + +# Arguments: +# params (iterable): iterable of parameters to optimize or dicts defining +# parameter groups. +# lr (float, optional): learning rate. (default: 1e-3) +# betas (Tuple[float, float], optional): coefficients used for computing +# running averages of gradient and its square. (default: (0.9, 0.999)) +# eps (float, optional): term added to the denominator to improve +# numerical stability. (default: 1e-8) +# weight_decay (float, optional): weight decay (L2 penalty) (default: 0) +# amsgrad (boolean, optional): whether to use the AMSGrad variant of this +# algorithm from the paper `On the Convergence of Adam and Beyond`_ +# (default: False) NOT SUPPORTED in FusedAdam! +# eps_inside_sqrt (boolean, optional): in the "update parameters" step, +# adds eps to the bias-corrected second moment estimate before +# evaluating square root instead of adding it to the square root of +# second moment estimate as in the original paper. (default: False) +# .. _Adam: A Method for Stochastic Optimization: +# https://arxiv.org/abs/1412.6980 +# .. _On the Convergence of Adam and Beyond: +# https://openreview.net/forum?id=ryQu7f-RZ +# """ + +# def __init__(self, params, +# lr=1e-3, bias_correction=True, +# betas=(0.9, 0.999), eps=1e-8, +# weight_decay=0., amsgrad=False): +# global unicore_fused_adam +# import importlib +# unicore_fused_adam = importlib.import_module("unicore_fused_adam") + +# if amsgrad: +# raise RuntimeError("FusedAdam does not support the AMSGrad variant.") +# defaults = { +# "lr": lr, +# "bias_correction": bias_correction, +# "betas": betas, +# "eps": eps, +# "weight_decay": weight_decay, +# } +# super().__init__(params, defaults) + +# @property +# def supports_memory_efficient_fp16(self): +# return True + +# @property +# def supports_flat_params(self): +# return True + +# @property +# def supports_step_with_scale(self): +# return True + +# def step(self, closure=None, scale=1.): +# """Performs a single optimization step. +# Arguments: +# closure (callable, optional): A closure that reevaluates the model +# and returns the loss. +# scale (float, optional): factor to divide gradient tensor values +# by before applying to weights. 
(default: 1) +# """ +# loss = None +# if closure is not None: +# loss = closure() + +# for group in self.param_groups: +# # compute combined scale factor for this group +# combined_scale = scale +# bias_correction = 1 if group.get("bias_correction", 1) else 0 + +# for p in group["params"]: +# if p.grad is None: +# continue +# grad = p.grad.data +# if grad.is_sparse: +# raise RuntimeError( +# "FusedAdam does not support sparse gradients, " +# "please consider SparseAdam instead" +# ) + +# state = self.state[p] + +# # State initialization +# if len(state) == 0: +# state["step"] = 0 +# # Exponential moving average of gradient values +# state["exp_avg"] = torch.zeros_like(p.data, dtype=torch.float) +# # Exponential moving average of squared gradient values +# state["exp_avg_sq"] = torch.zeros_like(p.data, dtype=torch.float) +# else: +# state["exp_avg"] = state["exp_avg"].to(dtype=torch.float) +# state["exp_avg_sq"] = state["exp_avg_sq"].to(dtype=torch.float) + +# exp_avg = state["exp_avg"] +# exp_avg_sq = state["exp_avg_sq"] +# beta1, beta2 = group["betas"] + +# state["step"] += 1 + +# with torch.cuda.device(p.device): +# unicore_fused_adam.adam(p.data, +# exp_avg, +# exp_avg_sq, +# grad, +# group["lr"], +# beta1, +# beta2, +# group["eps"], +# combined_scale, +# state["step"], +# bias_correction, +# group["weight_decay"]) + +# return loss + + diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..7d6cc7177b516821f665ab738727e3da2a597063 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/__init__.py @@ -0,0 +1,34 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import importlib +import os + +from unicore import registry +from unicore.optim.lr_scheduler.unicore_lr_scheduler import ( # noqa + UnicoreLRScheduler, +) + + +( + build_lr_scheduler_, + register_lr_scheduler, + LR_SCHEDULER_REGISTRY, +) = registry.setup_registry( + "--lr-scheduler", base_class=UnicoreLRScheduler, default="fixed" +) + + +def build_lr_scheduler(args, optimizer, total_train_steps): + return build_lr_scheduler_(args, optimizer, total_train_steps) + + +# automatically import any Python files in the optim/lr_scheduler/ directory +for file in os.listdir(os.path.dirname(__file__)): + if file.endswith(".py") and not file.startswith("_"): + file_name = file[: file.find(".py")] + importlib.import_module("unicore.optim.lr_scheduler." + file_name) diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/cosine_lr_scheduler.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/cosine_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..2c5b6b8e2e903b4e279ed703fc50dfd45a8a03ad --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/cosine_lr_scheduler.py @@ -0,0 +1,140 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
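+# Worked example of the schedule implemented below (the values are illustrative,
+# not defaults): with --lr 1e-4, --min-lr 0, t_mult = 1 and a period of 10000
+# updates after warmup, the post-warmup learning rate at t_curr updates into the
+# period is
+#
+#     lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * t_curr / period))
+#
+# so halfway through the period (t_curr = 5000) the lr is 0.5 * 1e-4 = 5e-5, and
+# it reaches min_lr at the end of the period. With t_mult != 1 each new period is
+# t_mult times longer and max_lr/min_lr are shrunk by --lr-shrink per period.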
+ +import math +from collections.abc import Collection +from typing import List + +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("cosine") +class CosineLRSchedule(UnicoreLRScheduler): + """Assign LR based on a cyclical schedule that follows the cosine function. + + See https://arxiv.org/pdf/1608.03983.pdf for details. + + We also support a warmup phase where we linearly increase the learning rate + from some initial learning rate (``--warmup-init-lr``) until the configured + max learning rate (``--lr``). + + During warmup:: + + lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) + lr = lrs[update_num] + + After warmup:: + + lr = args.min_lr + 0.5*(args.lr - args.min_lr)*(1 + cos(t_curr / t_i)) + + where ``t_curr`` is current percentage of updates within the current period + range and ``t_i`` is the current period range, which is scaled by ``t_mul`` + after every iteration. + """ + + def __init__(self, args, unicore_optimizer, total_train_steps): + super().__init__(args, unicore_optimizer, total_train_steps) + if isinstance(args.lr, Collection) and len(args.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with cosine." + f" Consider --lr-scheduler=fixed instead. ({args.lr})" + ) + + self.max_lr = args.lr[0] if isinstance(args.lr, Collection) else args.lr + assert ( + self.max_lr > args.min_lr + ), f"max_lr (={args.lr}) must be more than min_lr (={args.min_lr})" + + assert total_train_steps is not None + if self.args.warmup_ratio > 0: + self.warmup_updates = int(self.args.warmup_ratio * total_train_steps) + else: + self.warmup_updates = args.warmup_updates + + warmup_end_lr = self.max_lr + if args.warmup_init_lr < 0: + args.warmup_init_lr = args.min_lr + + self.t_mult = args.t_mult + self.period = args.lr_period_updates + + if self.period <= 0: + self.period = total_train_steps - self.warmup_updates + + if self.warmup_updates > 0: + # linearly warmup for the first args.warmup_updates + self.lr_step = (warmup_end_lr - args.warmup_init_lr) / self.warmup_updates + else: + self.lr_step = 1 + + self.lr_shrink = args.lr_shrink + + # initial learning rate + self.lr = args.warmup_init_lr + self.optimizer.set_lr(self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') + parser.add_argument('--warmup-ratio', default=-1.0, type=float, metavar='N', + help='warmup the learning rate linearly for the first N-percent updates') + parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', + help='initial learning rate during warmup phase; default is args.lr') + parser.add_argument('--min-lr', type=float, metavar='LR', + help='min learning rate') + parser.add_argument('--max-lr', type=float, metavar='LR', + help='max learning rate, must be more than args.lr') + parser.add_argument('--t-mult', default=1, type=float, metavar='LR', + help='factor to grow the length of each period') + parser.add_argument('--lr-period-updates', default=-1, type=float, metavar='LR', + help='initial number of updates per period') + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing') + # fmt: on + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't 
change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.warmup_updates: + self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + else: + curr_updates = num_updates - self.warmup_updates + if self.t_mult != 1: + i = math.floor( + math.log( + 1 - curr_updates / self.period * (1 - self.t_mult), self.t_mult + ) + ) + t_i = self.t_mult**i * self.period + t_curr = ( + curr_updates + - (1 - self.t_mult**i) / (1 - self.t_mult) * self.period + ) + r = float(t_curr) / t_i + else: + # force i to zero in one-cycle + i = 0 + t_i = self.period + t_curr = curr_updates + r = float(t_curr) / t_i + r = min(1.0, r) + + lr_shrink = self.lr_shrink**i + min_lr = self.args.min_lr * lr_shrink + max_lr = self.max_lr * lr_shrink + + self.lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * r)) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/exponential_decay_schedule.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/exponential_decay_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..87ba1a44143a83656d094302d8a5c564739aabe1 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/exponential_decay_schedule.py @@ -0,0 +1,50 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + +@register_lr_scheduler("exponential_decay") +class ExponentialDecayLRSchedule(UnicoreLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, args, optimizer, total_train_steps): + super().__init__(args, optimizer, total_train_steps) + self.warmup_updates = args.warmup_updates + self.lr = args.lr[0] + if self.warmup_updates > 0: + self.warmup_factor = 1.0 / self.warmup_updates + else: + self.warmup_factor = 1.0 + self.decay_ratio = args.decay_ratio + self.decay_steps = args.decay_steps + self.optimizer.set_lr(self.warmup_factor * self.lr) + self.stair_decay = getattr(args, "stair_decay", False) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + parser.add_argument('--warmup-updates', default=1000, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') + parser.add_argument('--decay-ratio', default=0.95, type=float) + parser.add_argument('--decay-steps', default=500, type=int) + parser.add_argument('--stair-decay', action="store_true") + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.warmup_updates > 0 and num_updates <= self.warmup_updates: + self.warmup_factor = num_updates / float(self.warmup_updates) + lr = self.warmup_factor * self.lr + else: + if self.stair_decay: + step = num_updates + lr = self.lr * float(self.decay_ratio ** (int(step // self.decay_steps))) + else: + step = num_updates - self.warmup_updates + lr = self.lr * float(self.decay_ratio ** (float(step / self.decay_steps))) + self.optimizer.set_lr(lr) + return self.optimizer.get_lr() diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/fixed_schedule.py 
b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/fixed_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..5488f2b4a28749b136fa6612a3b99f970f865d14 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/fixed_schedule.py @@ -0,0 +1,69 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("fixed") +class FixedLRSchedule(UnicoreLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, args, optimizer, total_train_steps): + super().__init__(args, optimizer, total_train_steps) + + self.lr = args.lr[0] + if args.warmup_updates > 0: + self.warmup_factor = 1.0 / args.warmup_updates + else: + self.warmup_factor = 1 + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', + help='force annealing at specified epoch') + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing, lr_new = (lr * lr_shrink)') + parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') + # fmt: on + + def state_dict(self): + return {"lr": self.lr} + + def load_state_dict(self, state_dict): + if "lr" in state_dict: + self.lr = state_dict["lr"] + + def get_next_lr(self, epoch): + lrs = self.args.lr + if self.args.force_anneal is None or epoch < self.args.force_anneal: + # use fixed LR schedule + next_lr = lrs[min(epoch - 1, len(lrs) - 1)] + else: + # annneal based on lr_shrink + next_lr = lrs[-1] * self.args.lr_shrink ** ( + epoch + 1 - self.args.force_anneal + ) + return next_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.warmup_factor * self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.args.warmup_updates > 0 and num_updates < self.args.warmup_updates: + self.warmup_factor = (num_updates + 1) / float(self.args.warmup_updates) + self.optimizer.set_lr(self.warmup_factor * self.lr) + else: + self.optimizer.set_lr(self.lr) + return self.optimizer.get_lr() diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/inverse_square_root_schedule.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/inverse_square_root_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..b71cfffd505f8b98a0ffb5b1458d2aad659d4abf --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/inverse_square_root_schedule.py @@ -0,0 +1,77 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
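+# Note on the decay constant used below: decay_factor = lr * sqrt(warmup_updates),
+# so at num_updates == warmup_updates the decayed value
+# decay_factor / sqrt(num_updates) equals lr exactly, i.e. the linear warmup and
+# the inverse-square-root decay meet continuously at the warmup boundary.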
+ +from collections.abc import Collection +from typing import List + +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("inverse_sqrt") +class InverseSquareRootSchedule(UnicoreLRScheduler): + """Decay the LR based on the inverse square root of the update number. + + We also support a warmup phase where we linearly increase the learning rate + from some initial learning rate (``--warmup-init-lr``) until the configured + learning rate (``--lr``). Thereafter we decay proportional to the number of + updates, with a decay factor set to align with the configured learning rate. + + During warmup:: + + lrs = torch.linspace(args.warmup_init_lr, args.lr, args.warmup_updates) + lr = lrs[update_num] + + After warmup:: + + decay_factor = args.lr * sqrt(args.warmup_updates) + lr = decay_factor / sqrt(update_num) + """ + + def __init__(self, args, optimizer, total_train_steps): + super().__init__(args, optimizer, total_train_steps) + if isinstance(args.lr, Collection) and len(args.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with inverse_sqrt." + " Consider --lr-scheduler=fixed instead." + ) + warmup_end_lr = args.lr[0] if isinstance(args.lr, Collection) else args.lr + if args.warmup_init_lr < 0: + args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr + + # linearly warmup for the first args.warmup_updates + self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + + # then, decay prop. to the inverse square root of the update number + self.decay_factor = warmup_end_lr * args.warmup_updates ** 0.5 + + # initial learning rate + self.lr = args.warmup_init_lr + self.optimizer.set_lr(self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--warmup-updates', default=4000, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') + parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', + help='initial learning rate during warmup phase; default is args.lr') + # fmt: on + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if num_updates < self.args.warmup_updates: + self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + else: + self.lr = self.decay_factor * num_updates ** -0.5 + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/pass_through.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/pass_through.py new file mode 100644 index 0000000000000000000000000000000000000000..ad60f2a0d7e5cd7c778fe3380d3b45900928f11b --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/pass_through.py @@ -0,0 +1,32 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
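+# This scheduler is a thin shim for optimizer wrappers that already carry their
+# own scheduler exposed as an `lr_scheduler` attribute (for example, the FP16
+# wrapper in this port forwards an `lr_scheduler` property from its inner
+# optimizer); the assertion in __init__ below rejects optimizers without one.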
+ +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("pass_through") +class PassThroughScheduleSchedule(UnicoreLRScheduler): + """Delegate lr scheduling to the optimizer.""" + + def __init__(self, args, optimizer, total_train_steps): + super().__init__(args, optimizer, total_train_steps) + assert ( + hasattr(optimizer, "lr_scheduler") and optimizer.lr_scheduler is not None + ), "Pass-through schedule can only be used with optimizers with their own schedulers" + + def state_dict(self): + return self.optimizer.lr_scheduler.state_dict() + + def load_state_dict(self, state_dict): + self.optimizer.lr_scheduler.load_state_dict(state_dict) + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + return self.optimizer.lr_scheduler.step_begin_epoch(epoch) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.lr_scheduler.step_update(num_updates) diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/polynomial_decay_schedule.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/polynomial_decay_schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..2cffcdadf6a4a2157feb0aaf5ae4bb94f818fe12 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/polynomial_decay_schedule.py @@ -0,0 +1,79 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List + +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + +@register_lr_scheduler("polynomial_decay") +class PolynomialDecayLRSchedule(UnicoreLRScheduler): + """Decay the LR on a fixed schedule.""" + + def __init__(self, args, optimizer, total_train_steps): + super().__init__(args, optimizer, total_train_steps) + if self.args.warmup_ratio > 0: + # if warmup_ratio > 0, use external train steps + assert total_train_steps is not None + self.warmup_updates = int(self.args.warmup_ratio * total_train_steps) + self.total_num_update = total_train_steps + else: + assert args.total_num_update > 0 + self.warmup_updates = args.warmup_updates + self.total_num_update = args.total_num_update + self.lr = args.lr[0] + if self.warmup_updates > 0: + self.warmup_factor = 1.0 / self.warmup_updates + else: + self.warmup_factor = 1 + self.end_learning_rate = args.end_learning_rate + self.power = args.power + self.optimizer.set_lr(self.warmup_factor * self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + parser.add_argument('--force-anneal', '--fa', type=int, metavar='N', + help='force annealing at specified epoch') + parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') + parser.add_argument('--warmup-ratio', default=-1.0, type=float, metavar='N', + help='warmup the learning rate linearly for the first N-percent updates') + parser.add_argument('--end-learning-rate', default=0.0, type=float) + parser.add_argument('--power', default=1.0, type=float) + parser.add_argument('--total-num-update', default=1000000, type=int) + + def get_next_lr(self, epoch): + lrs = self.args.lr + if self.args.force_anneal is None or epoch < self.args.force_anneal: + # use fixed LR schedule + next_lr 
= lrs[min(epoch, len(lrs) - 1)] + else: + # annneal based on lr_shrink + next_lr = self.optimizer.get_lr() + return next_lr + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + self.lr = self.get_next_lr(epoch) + self.optimizer.set_lr(self.warmup_factor * self.lr) + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + if self.warmup_updates > 0 and num_updates <= self.warmup_updates: + self.warmup_factor = num_updates / float(self.warmup_updates) + lr = self.warmup_factor * self.lr + elif num_updates >= self.total_num_update: + lr = self.end_learning_rate + else: + warmup = self.warmup_updates + lr_range = self.lr - self.end_learning_rate + pct_remaining = 1 - (num_updates - warmup) / ( + self.total_num_update - warmup + ) + lr = lr_range * pct_remaining ** (self.power) + self.end_learning_rate + self.optimizer.set_lr(lr) + return self.optimizer.get_lr() diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/reduce_lr_on_plateau.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/reduce_lr_on_plateau.py new file mode 100644 index 0000000000000000000000000000000000000000..d9961557ff963478557693ae5882887212594c83 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/reduce_lr_on_plateau.py @@ -0,0 +1,227 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import List + +import mindspore.nn as nn +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler # 假设已适配MindSpore版本 + + +@register_lr_scheduler( + "reduce_lr_on_plateau" +) +class ReduceLROnPlateauLRSchedule(UnicoreLRScheduler): + """ + Decay the LR by a factor every time the validation loss plateaus. + Also comes with optional warmup phase, where we linearly increase + the learning rate from some initial learning rate + (``--warmup-init-lr``) until the configured learning rate + (``--lr``). Thereafter the lr is adjusted according to original + reduce_on_plateau scheme. + + During warmup:: + + lrs = mindspore.ops.linspace( + args.warmup_init_lr, args.lr, args.warmup_updates + ) + lr = lrs[update_num] + """ + + def __init__(self, args, optimizer, total_train_steps): + super().__init__(args, optimizer, total_train_steps) + if len(args.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau." + " Consider --lr-scheduler=fixed instead." 
+ ) + # 替换torch.optim.lr_scheduler.ReduceLROnPlateau为mindspore.nn.ReduceLROnPlateau + self.lr_scheduler = nn.ReduceLROnPlateau( + optimizer=self.optimizer, # MindSpore优化器直接传入,无需访问.optimizer属性 + patience=args.lr_patience, + factor=args.lr_shrink, + mode="max" if args.maximize_best_checkpoint_metric else "min", + threshold=args.lr_threshold, + ) + warmup_end_lr = args.lr[0] + # if no warm up, sets initial lr to be args.lr[0] + if args.warmup_init_lr < 0: + args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr + + # linearly warmup for the first args.warmup_updates + if args.warmup_updates > 0: + self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + + # this flag is either set from arg when no warm up, or set by + # step_update() when warmup finishes + self.warmup_end = True if args.warmup_updates <= 0 else False + + # initial learning rate + # this self.lr is used only during init and/or warm up period + self.lr = args.warmup_init_lr + self.optimizer.set_lr(self.lr) # MindSpore优化器设置学习率方法 + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing, lr_new = (lr * lr_shrink)') + parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT', + help='Threshold for measuring the new optimum, \ + to only focus on significant changes') + parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', + help='warmup the learning rate linearly for the first N updates') + parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', + help='initial learning rate during warmup phase; default is args.lr') + # fmt: on + + def state_dict(self): + """Return the LR scheduler state dict.""" + return { + "best": self.lr_scheduler.best, + "last_epoch": self.lr_scheduler.last_epoch, + } + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + self.lr_scheduler.best = state_dict["best"] + if "last_epoch" in state_dict: + self.lr_scheduler.last_epoch = state_dict["last_epoch"] + + def step(self, epoch, val_loss=None): + """ + Update the learning rate at the end of the given epoch if warmup + finishes otherwise no update of lr on epoch boundaries + """ + if val_loss is not None and self.warmup_end is True: + self.lr_scheduler.step(val_loss) # MindSpore的ReduceLROnPlateau.step接受验证指标 + else: + self.lr_scheduler.last_epoch = epoch + return self.optimizer.get_lr() # MindSpore优化器获取学习率方法 + + def step_update(self, num_updates): + """ + Update the learning rate after each update. + """ + # if there is warmup + if self.args.warmup_updates > 0: + if num_updates <= self.args.warmup_updates: + self.lr = self.args.warmup_init_lr + num_updates * self.lr_step + self.optimizer.set_lr(self.lr) # 设置学习率 + else: + if self.warmup_end is False: + self.warmup_end = True + # else do nothing + return self.optimizer.get_lr() # 获取当前学习率 +# from typing import List + +# import torch.optim.lr_scheduler +# from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + + +# @register_lr_scheduler( +# "reduce_lr_on_plateau" +# ) +# class ReduceLROnPlateauLRSchedule(UnicoreLRScheduler): +# """ +# Decay the LR by a factor every time the validation loss plateaus. +# Also comes with optional warmup phase, where we linearly increase +# the learning rate from some initial learning rate +# (``--warmup-init-lr``) until the configured learning rate +# (``--lr``). 
Thereafter the lr is adjusted according to original +# reduce_on_plateau scheme. + +# During warmup:: + +# lrs = torch.linspace( +# args.warmup_init_lr, args.lr, args.warmup_updates +# ) +# lr = lrs[update_num] +# """ + +# def __init__(self, args, optimizer, total_train_steps): +# super().__init__(args, optimizer, total_train_steps) +# if len(args.lr) > 1: +# raise ValueError( +# "Cannot use a fixed learning rate schedule with reduce_lr_on_plateau." +# " Consider --lr-scheduler=fixed instead." +# ) +# self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( +# self.optimizer.optimizer, +# patience=args.lr_patience, +# factor=args.lr_shrink, +# mode="max" if args.maximize_best_checkpoint_metric else "min", +# threshold=args.lr_threshold, +# ) +# warmup_end_lr = args.lr[0] +# # if no warm up, sets initial lr to be args.lr[0] +# if args.warmup_init_lr < 0: +# args.warmup_init_lr = 0 if args.warmup_updates > 0 else warmup_end_lr + +# # linearly warmup for the first args.warmup_updates +# if args.warmup_updates > 0: +# self.lr_step = (warmup_end_lr - args.warmup_init_lr) / args.warmup_updates + +# # this flag is either set from arg when no warm up, or set by +# # step_update() when warmup finishes +# self.warmup_end = True if args.warmup_updates <= 0 else False + +# # initial learning rate +# # this self.lr is used only during init and/or warm up period +# self.lr = args.warmup_init_lr +# self.optimizer.set_lr(self.lr) + +# @staticmethod +# def add_args(parser): +# """Add arguments to the parser for this LR scheduler.""" +# # fmt: off +# parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', +# help='shrink factor for annealing, lr_new = (lr * lr_shrink)') +# parser.add_argument('--lr-threshold', default=1e-4, type=float, metavar='LT', +# help='Threshold for measuring the new optimum, \ +# to only focus on significant changes') +# parser.add_argument('--warmup-updates', default=0, type=int, metavar='N', +# help='warmup the learning rate linearly for the first N updates') +# parser.add_argument('--warmup-init-lr', default=-1, type=float, metavar='LR', +# help='initial learning rate during warmup phase; default is args.lr') +# # fmt: on + +# def state_dict(self): +# """Return the LR scheduler state dict.""" +# return { +# "best": self.lr_scheduler.best, +# "last_epoch": self.lr_scheduler.last_epoch, +# } + +# def load_state_dict(self, state_dict): +# """Load an LR scheduler state dict.""" +# self.lr_scheduler.best = state_dict["best"] +# if "last_epoch" in state_dict: +# self.lr_scheduler.last_epoch = state_dict["last_epoch"] + +# def step(self, epoch, val_loss=None): +# """ +# Update the learning rate at the end of the given epoch if warmup +# finishes otherwise no update of lr on epoch boundaries +# """ +# if val_loss is not None and self.warmup_end is True: +# self.lr_scheduler.step(val_loss) +# else: +# self.lr_scheduler.last_epoch = epoch +# return self.optimizer.get_lr() + +# def step_update(self, num_updates): +# """ +# Update the learning rate after each update.""" +# # if there is warmup +# if self.args.warmup_updates > 0: +# if num_updates <= self.args.warmup_updates: +# self.lr = self.args.warmup_init_lr + num_updates * self.lr_step +# self.optimizer.set_lr(self.lr) +# else: +# if self.warmup_end is False: +# self.warmup_end = True +# # else do nothing +# return self.optimizer.get_lr() diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/tri_stage_lr_scheduler.py 
b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/tri_stage_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..1b0e8b0515b19b7b26efc5bfb2cf077b833fc9d9 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/tri_stage_lr_scheduler.py @@ -0,0 +1,177 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import List + +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + + +@register_lr_scheduler("tri_stage") +class TriStageLRSchedule(UnicoreLRScheduler): + """Tristage learning rate schedulr + + Implement the learning rate scheduler in https://arxiv.org/pdf/1904.08779.pdf + + Similar to inverse_squre_root scheduler, but tri_stage learning rate employs + three stages LR scheduling: + + - warmup stage, starting from `lr` * `init_lr_scale`, linearly + increased to `lr` in `warmup_steps` iterations + + - hold stage, after `warmup_steps`, keep the LR as `lr` for `hold_steps` + iterations + + - decay stage, after hold stage, decay LR exponetially to + `lr` * `final_lr_scale` in `decay_steps`; + after that LR is keep as `final_lr_scale` * `lr` + + During warmup:: + + init_lr = args.init_lr_scale * args.lr + lrs = torch.linspace(init_lr, args.lr, args.warmup_steps) + lr = lrs[update_num] + + During hold:: + + lr = args.lr + + During decay:: + + decay_factor = - math.log(args.final_lr_scale) / args.decay_steps + lr = args.lr * exp(- (update_num - warmup_steps - decay_steps) * decay_factor) + + After that:: + + lr = args.lr * args.final_lr_scale + """ + + def __init__(self, args, optimizer, total_train_steps): + super().__init__(args, optimizer, total_train_steps) + if len(args.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with tri-stage lr." + " Consider --lr-scheduler=fixed instead." + ) + + # calculate LR at each point + self.peak_lr = args.lr[0] + self.init_lr = args.init_lr_scale * args.lr[0] + self.final_lr = args.final_lr_scale * args.lr[0] + + if args.phase_ratio is not None: + assert args.max_update > 0 + assert sum(args.phase_ratio) == 1, "phase ratios must add up to 1" + self.warmup_steps = int(args.max_update * args.phase_ratio[0]) + self.hold_steps = int(args.max_update * args.phase_ratio[1]) + self.decay_steps = int(args.max_update * args.phase_ratio[2]) + else: + self.warmup_steps = args.warmup_steps + self.hold_steps = args.hold_steps + self.decay_steps = args.decay_steps + + assert ( + self.warmup_steps + self.hold_steps + self.decay_steps > 0 + ), "please specify steps or phase_ratio" + + self.warmup_rate = ( + (self.peak_lr - self.init_lr) / self.warmup_steps + if self.warmup_steps != 0 + else 0 + ) + self.decay_factor = -math.log(args.final_lr_scale) / self.decay_steps + + # initial learning rate + self.lr = self.init_lr + self.optimizer.set_lr(self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument( + '--warmup-steps', + default=4000, + type=int, + metavar='N', + help='warmup the learning rate linearly for the first N updates' + ) + parser.add_argument( + '--hold-steps', + default=20000, + type=int, + metavar='N', + help='steps in hold stage.' 
+ ) + parser.add_argument( + '--decay-steps', + default=60000, + type=int, + metavar='N', + help='steps in decay stages' + ) + parser.add_argument( + '--init-lr-scale', + default=0.01, + type=float, + help=""" + initial learning rate scale during warmup phase; default is 0.01""") + parser.add_argument( + '--final-lr-scale', + default=0.01, + type=float, + help="final learning rate scale; default to 0.01" + ) + # fmt: on + + def _decide_stage(self, update_step): + """ + return stage, and the corresponding steps within the current stage + """ + if update_step < self.warmup_steps: + # warmup state + return 0, update_step + + offset = self.warmup_steps + + if update_step < offset + self.hold_steps: + # hold stage + return 1, update_step - offset + + offset += self.hold_steps + + if update_step <= offset + self.decay_steps: + # decay stage + return 2, update_step - offset + + offset += self.decay_steps + + # still here ? constant lr stage + return 3, update_step - offset + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + stage, steps_in_stage = self._decide_stage(num_updates) + if stage == 0: + self.lr = self.init_lr + self.warmup_rate * steps_in_stage + elif stage == 1: + self.lr = self.peak_lr + elif stage == 2: + self.lr = self.peak_lr * math.exp(-self.decay_factor * steps_in_stage) + elif stage == 3: + self.lr = self.final_lr + else: + raise ValueError("Undefined stage") + + self.optimizer.set_lr(self.lr) + + return self.lr diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/triangular_lr_scheduler.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/triangular_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..ab81f6f857e719d13ffdea5326d4f6f7cc1f192b --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/triangular_lr_scheduler.py @@ -0,0 +1,76 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import List +from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler + + + +@register_lr_scheduler("triangular") +class TriangularLRSchedule(UnicoreLRScheduler): + """Assign LR based on a triangular cyclical schedule. + + See https://arxiv.org/pdf/1506.01186.pdf for details. + """ + + def __init__(self, args, optimizer, total_train_steps): + super().__init__(args, optimizer, total_train_steps) + if len(args.lr) > 1: + raise ValueError( + "Cannot use a fixed learning rate schedule with triangular." + " Consider --lr-scheduler=fixed instead." 
+ ) + + lr = args.lr[0] + + assert args.max_lr > lr, "max_lr must be more than lr" + self.min_lr = lr + self.max_lr = args.max_lr + self.stepsize = args.lr_period_updates // 2 + self.lr_shrink = args.lr_shrink + self.shrink_min = args.shrink_min + + # initial learning rate + self.lr = self.min_lr + self.optimizer.set_lr(self.lr) + + @staticmethod + def add_args(parser): + """Add arguments to the parser for this LR scheduler.""" + # fmt: off + parser.add_argument('--max-lr', required=True, type=float, metavar='LR', + help='max learning rate, must be more than args.lr') + parser.add_argument('--lr-period-updates', default=5000, type=float, metavar='LR', + help='initial number of updates per period (cycle length)') + parser.add_argument('--lr-shrink', default=0.1, type=float, metavar='LS', + help='shrink factor for annealing') + parser.add_argument('--shrink-min', action='store_true', + help='if set, also shrinks min lr') + # fmt: on + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + super().step(epoch, val_loss) + # we don't change the learning rate at epoch boundaries + return self.optimizer.get_lr() + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + cycle = math.floor(num_updates / (2 * self.stepsize)) + + lr_shrink = self.lr_shrink ** cycle + max_lr = self.max_lr * lr_shrink + if self.shrink_min: + min_lr = self.min_lr * lr_shrink + else: + min_lr = self.min_lr + + x = abs(num_updates / self.stepsize - 2 * (cycle + 1) + 1) + self.lr = min_lr + (max_lr - min_lr) * max(0, (1 - x)) + + self.optimizer.set_lr(self.lr) + return self.lr diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/unicore_lr_scheduler.py b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/unicore_lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..688273ac5822fb8b1c43948e21cc02207460809c --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/lr_scheduler/unicore_lr_scheduler.py @@ -0,0 +1,50 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
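+# Concrete schedulers subclass the base class below and are wired to the
+# --lr-scheduler flag through the registry in unicore/optim/lr_scheduler/__init__.py.
+# A minimal, hypothetical example (the name and class are illustrative only):
+#
+#     from unicore.optim.lr_scheduler import UnicoreLRScheduler, register_lr_scheduler
+#
+#     @register_lr_scheduler("constant_example")
+#     class ConstantExampleSchedule(UnicoreLRScheduler):
+#         def __init__(self, args, optimizer, total_train_steps):
+#             super().__init__(args, optimizer, total_train_steps)
+#             self.optimizer.set_lr(args.lr[0])
+#
+#         def step_update(self, num_updates):
+#             return self.optimizer.get_lr()
+#
+# It would then be selected with `--lr-scheduler constant_example` and built via
+# unicore.optim.lr_scheduler.build_lr_scheduler(args, optimizer, total_train_steps).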
+ +from argparse import Namespace + +from unicore.optim import UnicoreOptimizer + + +class UnicoreLRScheduler(object): + def __init__(self, args, optimizer, total_train_steps): + super().__init__() + if optimizer is not None and not isinstance(optimizer, UnicoreOptimizer): + raise ValueError("optimizer must be an instance of UnicoreOptimizer") + self.args = args + self.optimizer = optimizer + self.total_train_steps = total_train_steps + self.best = None + + @classmethod + def add_args(cls, parser): + """Add arguments to the parser for this LR scheduler.""" + pass + + def state_dict(self): + """Return the LR scheduler state dict.""" + return {"best": self.best} + + def load_state_dict(self, state_dict): + """Load an LR scheduler state dict.""" + self.best = state_dict["best"] + + def step_begin_epoch(self, epoch): + """Update the learning rate at the beginning of the given epoch.""" + pass + + def step(self, epoch, val_loss=None): + """Update the learning rate at the end of the given epoch.""" + if val_loss is not None: + if self.best is None: + self.best = val_loss + else: + self.best = min(self.best, val_loss) + + def step_update(self, num_updates): + """Update the learning rate after each update.""" + return self.optimizer.get_lr() + diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/sgd.py b/MindChemistry/applications/Uni-Mol/unicore/optim/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..7f751f5600cad9259ab65a419808a9a5c49a8866 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/sgd.py @@ -0,0 +1,86 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import mindspore.nn as nn + +from . import UnicoreOptimizer, register_optimizer + + +@register_optimizer("sgd") +class SGD(UnicoreOptimizer): + def __init__(self, args, params): + super().__init__(args) + # 替换为MindSpore的SGD优化器,参数对应PyTorch版本 + self._optimizer = nn.SGD( + params=params, + learning_rate=self.args.lr[0], + momentum=self.args.momentum, + weight_decay=self.args.weight_decay + ) + + @staticmethod + def add_args(parser): + """Add optimizer-specific arguments to the parser.""" + # 参数定义与原PyTorch版本保持一致 + # fmt: off + parser.add_argument('--momentum', default=0.0, type=float, metavar='M', + help='momentum factor') + parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', + help='weight decay') + # fmt: on + + @property + def optimizer_config(self): + """ + 返回用于覆盖检查点中存储的优化器参数的字典,适配MindSpore SGD参数名 + """ + return { + "learning_rate": self.args.lr[0], # MindSpore中参数名为learning_rate + "momentum": self.args.momentum, + "weight_decay": self.args.weight_decay, + } + + @property + def supports_flat_params(self): + # MindSpore的SGD优化器支持扁平参数 + return True +# import torch.optim + +# from . 
import UnicoreOptimizer, register_optimizer + + +# @register_optimizer("sgd") +# class SGD(UnicoreOptimizer): +# def __init__(self, args, params): +# super().__init__(args) +# self._optimizer = torch.optim.SGD(params, **self.optimizer_config) + +# @staticmethod +# def add_args(parser): +# """Add optimizer-specific arguments to the parser.""" +# # fmt: off +# parser.add_argument('--momentum', default=0.0, type=float, metavar='M', +# help='momentum factor') +# parser.add_argument('--weight-decay', '--wd', default=0.0, type=float, metavar='WD', +# help='weight decay') +# # fmt: on + +# @property +# def optimizer_config(self): +# """ +# Return a kwarg dictionary that will be used to override optimizer +# args stored in checkpoints. This allows us to load a checkpoint and +# resume training using a different set of optimizer args, e.g., with a +# different learning rate. +# """ +# return { +# "lr": self.args.lr[0], +# "momentum": self.args.momentum, +# "weight_decay": self.args.weight_decay, +# } + +# @property +# def supports_flat_params(self): +# return True diff --git a/MindChemistry/applications/Uni-Mol/unicore/optim/unicore_optimizer.py b/MindChemistry/applications/Uni-Mol/unicore/optim/unicore_optimizer.py new file mode 100644 index 0000000000000000000000000000000000000000..11b78a5a7ade62409c232cc4be4b0ddd11647cf6 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/optim/unicore_optimizer.py @@ -0,0 +1,368 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import mindspore as ms +from mindspore import ops +from unicore import utils # 假设utils中相关函数已适配MindSpore + +class UnicoreOptimizer(object): + def __init__(self, args): + super().__init__() + self.args = args + self._grad_buffer = None + self._need_sync_grad_buf = False + + @classmethod + def add_args(cls, parser): + """Add optimizer-specific arguments to the parser.""" + pass + + @property + def optimizer(self): + """Return a mindspore.nn.Optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(self._optimizer, ms.nn.Optimizer): + raise ValueError("_optimizer must be an instance of mindspore.nn.Optimizer") + return self._optimizer + + @optimizer.setter + def optimizer(self, optimizer): + """Reset optimizer instance.""" + if not hasattr(self, "_optimizer"): + raise NotImplementedError + if not isinstance(optimizer, ms.nn.Optimizer): + raise ValueError("_optimizer must be an instance of mindspore.nn.Optimizer") + self._optimizer = optimizer + + @property + def optimizer_config(self): + """Return a kwarg dictionary for overriding optimizer args in checkpoints.""" + raise NotImplementedError + + @property + def params(self): + """Return an iterable of the parameters held by the optimizer.""" + for param_group in self.param_groups: + for p in param_group["params"]: + yield p + + @property + def param_groups(self): + return self.optimizer.param_groups + + def __getstate__(self): + return self._optimizer.__getstate__() + + def get_lr(self): + """Return the current learning rate.""" + return self.param_groups[0]["lr"] + + def set_lr(self, lr): + """Set the learning rate.""" + for param_group in self.param_groups: + param_group["lr"] = lr + + def state_dict(self): + """Return the optimizer's state dict.""" + return self.optimizer.state_dict() + + def load_state_dict(self, state_dict, optimizer_overrides=None): + """Load 
an optimizer state dict.""" + self.optimizer.load_state_dict(state_dict) + + if optimizer_overrides is not None and len(optimizer_overrides) > 0: + # Override learning rate, momentum, etc. with latest values + for group in self.param_groups: + group.update(optimizer_overrides) + + def backward(self, loss): + """Computes the sum of gradients of the given tensor w.r.t. graph leaves.""" + loss.backward() # MindSpore中loss.backward()用法与PyTorch一致 + + def all_reduce_grads(self, module): + """单卡环境无需梯度同步,简化实现""" + self.__sync_grad_from_buf__() + # 移除分布式相关逻辑(单卡无需all-reduce) + + def multiply_grads(self, c): + """Multiplies grads by a constant *c*.""" + for p in self.params: + if p.grad is not None: + if ops.is_tensor(c): + c = c.to(p.grad.device) # MindSpore中device处理与PyTorch兼容 + p.grad = p.grad * c # MindSpore中用*替代mul_ + + def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + if max_norm <= 0.0: + return 0.0 + if self._grad_buffer is None: + # 用MindSpore的zeros_like替代torch.zeros_like + self._grad_buffer = [ops.zeros_like(g) for g in self.params] + + # 用MindSpore的梯度裁剪替代PyTorch版本 + gnorm = utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) + + for i, p in enumerate(self.params): + if p.grad is None: + continue + self._grad_buffer[i] += p.grad + p.grad = None + self._need_sync_grad_buf = True + return gnorm + + def __sync_grad_from_buf__(self): + if self._need_sync_grad_buf: + assert self._grad_buffer is not None + for i, p in enumerate(self.params): + p.grad = self._grad_buffer[i] + self._need_sync_grad_buf = False + + def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): + """Clips gradient norm.""" + self.__sync_grad_from_buf__() + return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) + + def step(self, closure=None, scale=1.0, groups=None): + """Performs a single optimization step.""" + self.__sync_grad_from_buf__() + + # MindSpore优化器step方法通常不支持closure参数,这里做兼容处理 + if closure is not None: + closure() + + if self.supports_step_with_scale: + if self.supports_groups: + self.optimizer.step(scale=scale, groups=groups) + else: + self.optimizer.step(scale=scale) + else: + if scale != 1.0: + self.multiply_grads(1.0 / scale) + if self.supports_groups: + self.optimizer.step(groups=groups) + else: + self.optimizer.step() + + def zero_grad(self): + """Clears the gradients of all optimized parameters.""" + for p in self.params: + p.grad = None + self.optimizer.zero_grad() # MindSpore优化器清零梯度方法 + self._need_sync_grad_buf = False + if self._grad_buffer is not None: + for t in self._grad_buffer: + t.zero_() # MindSpore Tensor支持zero_方法 + + @property + def supports_memory_efficient_fp16(self): + if hasattr(self.optimizer, "supports_memory_efficient_fp16"): + return self.optimizer.supports_memory_efficient_fp16 + return False # MindSpore默认不启用PyTorch风格的内存高效FP16 + + @property + def supports_step_with_scale(self): + if hasattr(self.optimizer, "supports_step_with_scale"): + return self.optimizer.supports_step_with_scale + return False + + @property + def supports_groups(self): + if hasattr(self.optimizer, "supports_groups"): + return self.optimizer.supports_groups + return False + + @property + def supports_flat_params(self): + """MindSpore通常不使用扁平化参数,默认返回False""" + if hasattr(self.optimizer, "supports_flat_params"): + return self.optimizer.supports_flat_params + return False +# import torch +# from unicore import utils + +# class UnicoreOptimizer(object): +# def __init__(self, args): +# super().__init__() +# self.args = args +# 
self._grad_buffer = None +# self._need_sync_grad_buf = False + +# @classmethod +# def add_args(cls, parser): +# """Add optimizer-specific arguments to the parser.""" +# pass + +# @property +# def optimizer(self): +# """Return a torch.optim.optimizer.Optimizer instance.""" +# if not hasattr(self, "_optimizer"): +# raise NotImplementedError +# if not isinstance(self._optimizer, torch.optim.Optimizer): +# raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") +# return self._optimizer + +# @optimizer.setter +# def optimizer(self, optimizer): +# """Reset optimizer instance.""" +# if not hasattr(self, "_optimizer"): +# raise NotImplementedError +# if not isinstance(self._optimizer, torch.optim.Optimizer): +# raise ValueError("_optimizer must be an instance of torch.optim.Optimizer") +# self._optimizer = optimizer + +# @property +# def optimizer_config(self): +# """ +# Return a kwarg dictionary that will be used to override optimizer +# args stored in checkpoints. This allows us to load a checkpoint and +# resume training using a different set of optimizer args, e.g., with a +# different learning rate. +# """ +# raise NotImplementedError + +# @property +# def params(self): +# """Return an iterable of the parameters held by the optimizer.""" +# for param_group in self.param_groups: +# for p in param_group["params"]: +# yield p + +# @property +# def param_groups(self): +# return self.optimizer.param_groups + +# def __getstate__(self): +# return self._optimizer.__getstate__() + +# def get_lr(self): +# """Return the current learning rate.""" +# return self.param_groups[0]["lr"] + +# def set_lr(self, lr): +# """Set the learning rate.""" +# for param_group in self.param_groups: +# param_group["lr"] = lr + +# def state_dict(self): +# """Return the optimizer's state dict.""" +# return self.optimizer.state_dict() + +# def load_state_dict(self, state_dict, optimizer_overrides=None): +# """Load an optimizer state dict. + +# In general we should prefer the configuration of the existing optimizer +# instance (e.g., learning rate) over that found in the state_dict. This +# allows us to resume training from a checkpoint using a new set of +# optimizer args. +# """ +# self.optimizer.load_state_dict(state_dict) + +# if optimizer_overrides is not None and len(optimizer_overrides) > 0: +# # override learning rate, momentum, etc. with latest values +# for group in self.param_groups: +# group.update(optimizer_overrides) + +# def backward(self, loss): +# """Computes the sum of gradients of the given tensor w.r.t. 
graph leaves.""" +# loss.backward() + +# def all_reduce_grads(self, module): +# """Manually all-reduce gradients (if required).""" +# self.__sync_grad_from_buf__() +# if hasattr(module, "all_reduce_grads"): +# module.all_reduce_grads() + +# def multiply_grads(self, c): +# """Multiplies grads by a constant *c*.""" +# for p in self.params: +# if p.grad is not None: +# if torch.is_tensor(c): +# c = c.to(p.grad.device) +# p.grad.data.mul_(c) + +# def per_sample_clip_grad_norm(self, max_norm, aggregate_norm_fn=None): +# """Clips gradient norm.""" +# if max_norm <= 0.0: +# return 0.0 +# if self._grad_buffer is None: +# self._grad_buffer = [torch.zeros_like(g) for g in self.params] +# gnorm = utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) +# for i, p in enumerate(self.params): +# if p.grad is None: +# continue +# self._grad_buffer[i] += p.grad +# p.grad = None +# self._need_sync_grad_buf = True +# return gnorm + +# def __sync_grad_from_buf__(self): +# if self._need_sync_grad_buf: +# assert self._grad_buffer is not None +# for i, p in enumerate(self.params): +# p.grad = self._grad_buffer[i] +# self._need_sync_grad_buf = False + +# def clip_grad_norm(self, max_norm, aggregate_norm_fn=None): +# """Clips gradient norm.""" +# self.__sync_grad_from_buf__() +# return utils.clip_grad_norm_(self.params, max_norm, aggregate_norm_fn) + +# def step(self, closure=None, scale=1.0, groups=None): +# """Performs a single optimization step.""" +# self.__sync_grad_from_buf__() +# if self.supports_step_with_scale: +# if self.supports_groups: +# self.optimizer.step(closure, scale=scale, groups=groups) +# else: +# self.optimizer.step(closure, scale=scale) +# else: +# if scale != 1.0: +# self.multiply_grads(1.0 / scale) +# if self.supports_groups: +# self.optimizer.step(closure, groups=groups) +# else: +# self.optimizer.step(closure) + +# def zero_grad(self): +# """Clears the gradients of all optimized parameters.""" +# for p in self.params: +# p.grad = None +# self.optimizer.zero_grad() +# self._need_sync_grad_buf = False +# if self._grad_buffer is not None: +# for t in self._grad_buffer: +# t.zero_() + +# @property +# def supports_memory_efficient_fp16(self): +# if hasattr(self.optimizer, "supports_memory_efficient_fp16"): +# return self.optimizer.supports_memory_efficient_fp16 +# return False + +# @property +# def supports_step_with_scale(self): +# if hasattr(self.optimizer, "supports_step_with_scale"): +# return self.optimizer.supports_step_with_scale +# return False + +# @property +# def supports_groups(self): +# if hasattr(self.optimizer, "supports_groups"): +# return self.optimizer.supports_groups +# return False + +# @property +# def supports_flat_params(self): +# """ +# Whether the optimizer supports collapsing of the model +# parameters/gradients into a single contiguous Tensor. +# """ +# if hasattr(self.optimizer, "supports_flat_params"): +# return self.optimizer.supports_flat_params +# return False + + diff --git a/MindChemistry/applications/Uni-Mol/unicore/options.py b/MindChemistry/applications/Uni-Mol/unicore/options.py new file mode 100644 index 0000000000000000000000000000000000000000..9fc9cdb9ca9ea77d0e772407a09a43f353e7fafd --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/options.py @@ -0,0 +1,848 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
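+# A minimal sketch of how these parsers are typically wired together; the
+# task and architecture names are placeholders rather than shipped defaults:
+#
+#     parser = get_training_parser()
+#     args = parse_args_and_arch(
+#         parser,
+#         input_args=["--task", "some_task", "--arch", "some_arch"],
+#     )
+#
+# parse_args_and_arch() parses twice: a first pass discovers --task, --arch
+# and the registry choices, whose add_args() hooks then extend the parser,
+# and a second pass produces the final namespace.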
+import argparse +import mindspore as ms + + +from typing import Callable, List, Optional + +# this import is for backward compatibility +from unicore.utils import ( + csv_str_list, + eval_bool, + eval_str_dict, + eval_str_list, + import_user_module, +) # noqa + + +def get_training_parser(default_task="translation"): + parser = get_parser("Trainer", default_task) + add_dataset_args(parser, train=True) + add_distributed_training_args(parser) + add_model_args(parser) + add_optimization_args(parser) + add_checkpoint_args(parser) + return parser + + +def get_validation_parser(default_task=None): + parser = get_parser("Validation", default_task) + add_dataset_args(parser, train=True) + add_distributed_training_args(parser) + group = parser.add_argument_group("Evaluation") + add_common_eval_args(group) + return parser + + +def parse_args_and_arch( + parser: argparse.ArgumentParser, + input_args: List[str] = None, + parse_known: bool = False, + suppress_defaults: bool = False, + modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None, +): + """ + Args: + parser (ArgumentParser): the parser + input_args (List[str]): strings to parse, defaults to sys.argv + parse_known (bool): only parse known arguments, similar to + `ArgumentParser.parse_known_args` + suppress_defaults (bool): parse while ignoring all default values + modify_parser (Optional[Callable[[ArgumentParser], None]]): + function to modify the parser, e.g., to set default values + """ + if suppress_defaults: + # Parse args without any default values. This requires us to parse + # twice, once to identify all the necessary task/model args, and a second + # time with all defaults set to None. + args = parse_args_and_arch( + parser, + input_args=input_args, + parse_known=parse_known, + suppress_defaults=False, + ) + suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser]) + suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()}) + args = suppressed_parser.parse_args(input_args) + return argparse.Namespace( + **{k: v for k, v in vars(args).items() if v is not None} + ) + + from unicore.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY + + # Before creating the true parser, we need to import optional user module + # in order to eagerly import custom tasks, optimizers, architectures, etc. + usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + usr_parser.add_argument("--user-dir", default=None) + usr_args, _ = usr_parser.parse_known_args(input_args) + import_user_module(usr_args) + + if modify_parser is not None: + modify_parser(parser) + + # The parser doesn't know about model/loss/optimizer-specific args, so + # we parse twice. First we parse the model/loss/optimizer, then we + # parse a second time after adding the *-specific arguments. + # If input_args is given, we will parse those args instead of sys.argv. + args, _ = parser.parse_known_args(input_args) + + # Add model-specific args to parser. + if hasattr(args, "arch"): + model_specific_group = parser.add_argument_group( + "Model-specific configuration", + # Only include attributes which are explicitly given as command-line + # arguments or which have default values. 
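+            # argparse.SUPPRESS keeps options that are neither passed on the
+            # command line nor given an explicit default out of the namespace,
+            # so the architecture defaults applied later through
+            # ARCH_CONFIG_REGISTRY are not shadowed by spurious None values.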
+ argument_default=argparse.SUPPRESS, + ) + if args.arch in ARCH_MODEL_REGISTRY: + ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) + elif args.arch in MODEL_REGISTRY: + MODEL_REGISTRY[args.arch].add_args(model_specific_group) + else: + raise RuntimeError() + + if hasattr(args, "task"): + from unicore.tasks import TASK_REGISTRY + + TASK_REGISTRY[args.task].add_args(parser) + + # Add *-specific args to parser. + from unicore.registry import REGISTRIES + + for registry_name, REGISTRY in REGISTRIES.items(): + choice = getattr(args, registry_name, None) + if choice is not None: + cls = REGISTRY["registry"][choice] + if hasattr(cls, "add_args"): + cls.add_args(parser) + + # Modify the parser a second time, since defaults may have been reset + if modify_parser is not None: + modify_parser(parser) + + # Parse a second time. + if parse_known: + args, extra = parser.parse_known_args(input_args) + else: + args = parser.parse_args(input_args) + extra = None + # Post-process args. + if ( + hasattr(args, "batch_size_valid") and args.batch_size_valid is None + ) or not hasattr(args, "batch_size_valid"): + args.batch_size_valid = args.batch_size + args.bf16 = getattr(args, "bf16", False) + + if getattr(args, "seed", None) is None: + args.seed = 1 # default seed for training + args.no_seed_provided = True + else: + args.no_seed_provided = False + + args.validate_with_ema = getattr(args, "validate_with_ema", False) + # Apply architecture configuration. + if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY: + ARCH_CONFIG_REGISTRY[args.arch](args) + + if parse_known: + return args, extra + else: + return args + + +def get_parser(desc, default_task="test"): + # Before creating the true parser, we need to import optional user module + # in order to eagerly import custom tasks, optimizers, architectures, etc. + usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) + usr_parser.add_argument("--user-dir", default=None) + usr_args, _ = usr_parser.parse_known_args() + import_user_module(usr_args) + + parser = argparse.ArgumentParser(allow_abbrev=False) + # fmt: off + parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') + parser.add_argument('--log-interval', type=int, default=1000, metavar='N', + help='log progress every N batches (when progress bar is disabled)') + parser.add_argument('--log-format', default=None, help='log format to use', + choices=['json', 'none', 'simple', 'tqdm']) + parser.add_argument('--tensorboard-logdir', metavar='DIR', default='', + help='path to save logs for tensorboard, should match --logdir ' + 'of running tensorboard (default: no tensorboard logging)') + parser.add_argument('--wandb-project', metavar='DIR', default='', + help='name of wandb project, empty for no wandb logging, for wandb login, use env WANDB_API_KEY. 
You can also use team_name/project_name for project name.') + parser.add_argument('--wandb-name', metavar='DIR', default='', + help='wandb run/id name, empty for no wandb logging, for wandb login, use env WANDB_API_KEY') + parser.add_argument('--seed', default=1, type=int, metavar='N', + help='pseudo random number generator seed') + parser.add_argument('--cpu', action='store_true', help='use CPU instead of GPU') + parser.add_argument('--fp16', action='store_true', help='use FP16') + parser.add_argument('--bf16', action='store_true', help='use BF16') + parser.add_argument('--bf16-sr', action='store_true', help='use stachostic rounding for bf16') + parser.add_argument('--allreduce-fp32-grad', action='store_true', help='use fp32-grads in fp16/bf16 mode. --ddp-backend should be no_c10d') + parser.add_argument('--fp16-no-flatten-grads', action='store_true', help="don't flatten FP16 grads tensor") + parser.add_argument('--fp16-init-scale', default=2** 7, type=int, + help='default FP16 loss scale') + parser.add_argument('--fp16-scale-window', type=int, + help='number of updates before increasing loss scale') + parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float, + help='pct of updates that can overflow before decreasing the loss scale') + parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D', + help='minimum FP16 loss scale, after which training is stopped') + parser.add_argument('--threshold-loss-scale', type=float, + help='threshold FP16 loss scale from below') + parser.add_argument('--user-dir', default=None, + help='path to a python module containing custom extensions (tasks and/or architectures)') + parser.add_argument('--empty-cache-freq', default=0, type=int, + help='how often to clear the MindSpore GPU cache (0 to disable)') + parser.add_argument('--all-gather-list-size', default=16384, type=int, + help='number of bytes reserved for gathering stats from workers') + parser.add_argument('--suppress-crashes', action='store_true', help="suppress crashes when training with the entry point so that the " + "main method can return a value (useful for sweeps)") + parser.add_argument('--profile', action='store_true', help="enable profiler emit_nvtx") + parser.add_argument('--ema-decay', default=-1.0, type=float, help="enable moving average for model weights") + parser.add_argument("--validate-with-ema", action="store_true") + + + from unicore.registry import REGISTRIES + for registry_name, REGISTRY in REGISTRIES.items(): + parser.add_argument( + '--' + registry_name.replace('_', '-'), + default=REGISTRY['default'], + choices=REGISTRY['registry'].keys(), + ) + + # Task definitions can be found under unicore/tasks/ + from unicore.tasks import TASK_REGISTRY + parser.add_argument('--task', metavar='TASK', default=default_task, + choices=TASK_REGISTRY.keys(), + help='task') + # fmt: on + return parser + + +def add_dataset_args(parser, train=False, gen=False): + group = parser.add_argument_group("Dataset and data loading") + # fmt: off + group.add_argument('--num-workers', default=1, type=int, metavar='N', + help='how many subprocesses to use for data loading') + group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', + help='ignore too long or too short lines in valid and test set') + group.add_argument('--batch-size', '--max-sentences', type=int, metavar='N', + help='maximum number of sentences in a batch') + group.add_argument('--required-batch-size-multiple', default=1, type=int, metavar='N', + help='batch size will be a multiplier of 
this value') + group.add_argument('--data-buffer-size', default=10, type=int, + help='Number of batches to preload') + group.add_argument('--train-subset', default='train', metavar='SPLIT', + choices=['train', 'valid', 'test', 'train.small'], + help='data subset to use for training (train, valid, test)') + group.add_argument('--valid-subset', default='valid', metavar='SPLIT', + help='comma separated list of data subsets to use for validation' + ' (train, valid, valid1, test, test1)') + group.add_argument('--validate-interval', type=int, default=1, metavar='N', + help='validate every N epochs') + group.add_argument('--validate-interval-updates', type=int, default=0, metavar='N', + help='validate every N updates') + group.add_argument('--validate-after-updates', type=int, default=0, metavar='N', + help='dont validate until reaching this many updates') + group.add_argument('--fixed-validation-seed', default=None, type=int, metavar='N', + help='specified random seed for validation') + group.add_argument('--disable-validation', action='store_true', + help='disable validation') + group.add_argument('--batch-size-valid', type=int, metavar='N', + help='maximum number of sentences in a validation batch' + ' (defaults to --max-sentences)') + group.add_argument('--max-valid-steps', type=int, metavar='N', + help='How many batches to evaluate') + group.add_argument('--curriculum', default=0, type=int, metavar='N', + help='don\'t shuffle batches for first N epochs') + # fmt: on + return group + + +def add_distributed_training_args(parser): + group = parser.add_argument_group("Distributed training") + # fmt: off + group.add_argument('--distributed-world-size', type=int, metavar='N', + default=max(1, ms.hal.device_count()), # MindSpore获取设备数量 + help='total number of GPUs across all nodes (default: all visible GPUs)') + group.add_argument('--distributed-rank', default=0, type=int, + help='rank of the current worker') + group.add_argument('--distributed-backend', default='nccl', type=str, + help='distributed backend') + group.add_argument('--distributed-init-method', default=None, type=str, + help='typically tcp://hostname:port that will be used to ' + 'establish initial connetion') + group.add_argument('--distributed-port', default=-1, type=int, + help='port number (not required if using --distributed-init-method)') + group.add_argument('--device-id', '--local_rank', default=0, type=int, + help='which GPU to use (usually configured automatically)') + group.add_argument('--distributed-no-spawn', action='store_true', + help='do not spawn multiple processes even if multiple GPUs are visible') + group.add_argument('--ddp-backend', default='c10d', type=str, + choices=['c10d', 'apex', 'no_c10d'], + help='DistributedDataParallel backend') + group.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB', + help='bucket size for reduction') + group.add_argument('--fix-batches-to-gpus', action='store_true', + help='don\'t shuffle batches between GPUs; this reduces overall ' + 'randomness and may affect precision but avoids the cost of ' + 're-reading the data') + group.add_argument('--find-unused-parameters', default=False, action='store_true', + help='disable unused parameter detection (not applicable to ' + 'no_c10d ddp-backend') + group.add_argument('--fast-stat-sync', default=False, action='store_true', + help='Enable fast sync of stats between nodes, this hardcodes to ' + 'sync only some default stats from logging_output.') + group.add_argument('--broadcast-buffers', default=False, action='store_true', 
+ help="Copy non-trainable parameters between GPUs, such as " + "batchnorm population statistics") + group.add_argument('--nprocs-per-node', default=max(1, ms.hal.device_count()), type=int, # MindSpore获取设备数量 + help="number of GPUs in each node. An allreduce operation across GPUs in " + "a node is very fast. Hence, we do allreduce across GPUs in a node, " + "and gossip across different nodes") + # fmt: on + return group + + +def add_optimization_args(parser): + group = parser.add_argument_group("Optimization") + # fmt: off + group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N', + help='force stop training at specified epoch') + group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N', + help='force stop training at specified update') + group.add_argument('--stop-time-hours', default=0, type=float, + help="force stop training after specified cumulative time (if >0)") + group.add_argument('--no-weight-decay-names', default="", type=str, + help='names of parameters to not weight decay, comma separated') + group.add_argument('--clip-norm', default=0, type=float, metavar='NORM', + help='clip threshold of gradients') + group.add_argument('--per-sample-clip-norm', default=0, type=float, metavar='PNORM', + help='clip threshold of gradients, before gradient sync over workers. In fp16/bf16 mode, --fp32-grad should be set, and --dpp-backend should be no_c10d') + group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K', + type=lambda uf: eval_str_list(uf, type=int), + help='update parameters every N_i batches, when in epoch i') + group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list, + metavar='LR_1,LR_2,...,LR_N', + help='learning rate for the first N epochs; all epochs >N using LR_N' + ' (note: this may be interpreted differently depending on --lr-scheduler)') + group.add_argument('--stop-min-lr', default=-1, type=float, metavar='LR', + help='stop training when the learning rate reaches this minimum') + # fmt: on + return group + + +def add_checkpoint_args(parser): + group = parser.add_argument_group("Checkpointing") + # fmt: off + group.add_argument('--save-dir', metavar='DIR', default='checkpoints', + help='path to save checkpoints') + group.add_argument('--tmp-save-dir', metavar='DIR', default='./', + help='path to temporarily save checkpoints') + group.add_argument('--restore-file', default='checkpoint_last.ckpt', # MindSpore默认 checkpoint 后缀为 .ckpt + help='filename from which to load checkpoint ' + '(default: /checkpoint_last.ckpt') + group.add_argument('--finetune-from-model', type=str, + help="finetune from a pretrained model; note that meters and lr scheduler will be reset") + group.add_argument('--load-from-ema', action="store_true", + help="finetune from a pretrained model; note that meters and lr scheduler will be reset") + group.add_argument('--reset-dataloader', action='store_true', + help='if set, does not reload dataloader state from the checkpoint') + group.add_argument('--reset-lr-scheduler', action='store_true', + help='if set, does not load lr scheduler state from the checkpoint') + group.add_argument('--reset-meters', action='store_true', + help='if set, does not load meters from the checkpoint') + group.add_argument('--reset-optimizer', action='store_true', + help='if set, does not load optimizer state from the checkpoint') + group.add_argument('--optimizer-overrides', default="{}", type=str, metavar='DICT', + help='a dictionary used to override optimizer args when loading a checkpoint') + 
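+    # The cadence and retention flags below work together: --save-interval
+    # counts epochs, --save-interval-updates counts optimizer updates, and
+    # the --keep-* options bound how many checkpoints of each kind are
+    # retained on disk.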
group.add_argument('--save-interval', type=int, default=1, metavar='N', + help='save a checkpoint every N epochs') + group.add_argument('--save-interval-updates', type=int, default=0, metavar='N', + help='save a checkpoint (and validate) every N updates') + group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N', + help='keep the last N checkpoints saved with --save-interval-updates') + group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N', + help='keep last N epoch checkpoints') + group.add_argument('--keep-best-checkpoints', type=int, default=-1, metavar='N', + help='keep best N checkpoints based on scores') + group.add_argument('--no-save', action='store_true', + help='don\'t save models or checkpoints') + group.add_argument('--no-epoch-checkpoints', action='store_true', + help='only store last and best checkpoints') + group.add_argument('--no-last-checkpoints', action='store_true', + help='don\'t store last checkpoints') + group.add_argument('--no-save-optimizer-state', action='store_true', + help='don\'t save optimizer-state as part of checkpoint') + group.add_argument('--best-checkpoint-metric', type=str, default='loss', + help='metric to use for saving "best" checkpoints') + group.add_argument('--maximize-best-checkpoint-metric', action='store_true', + help='select the largest metric value for saving "best" checkpoints') + group.add_argument('--patience', type=int, default=-1, metavar='N', + help="early stop training if valid performance doesn't " + "improve for N consecutive validation runs; note " + "that this is influenced by --validate-interval") + group.add_argument('--checkpoint-suffix', type=str, default="", + help="suffix to add to the checkpoint file name") + # fmt: on + return group + + +def add_common_eval_args(group): + # fmt: off + group.add_argument('--path', metavar='FILE', + help='path(s) to model file(s), colon separated') + group.add_argument('--quiet', action='store_true', + help='only print final scores') + group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', + help='a dictionary used to override model args at generation ' + 'that were used during model training') + group.add_argument('--results-path', metavar='RESDIR', type=str, default=None, + help='path to save eval results (optional)"') + # fmt: on + + +def add_model_args(parser): + group = parser.add_argument_group("Model configuration") + # fmt: off + + # Model definitions can be found under unicore/models/ + # + # The model architecture can be specified in several ways. 
+ # In increasing order of priority: + # 1) model defaults (lowest priority) + # 2) --arch argument + # 3) --encoder/decoder-* arguments (highest priority) + from unicore.models import ARCH_MODEL_REGISTRY + group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True, + choices=ARCH_MODEL_REGISTRY.keys(), + help='Model Architecture') + # fmt: on + return group +# import argparse + +# import torch + + +# from typing import Callable, List, Optional + +# # this import is for backward compatibility +# from unicore.utils import ( +# csv_str_list, +# eval_bool, +# eval_str_dict, +# eval_str_list, +# import_user_module, +# ) # noqa + + +# def get_training_parser(default_task="translation"): +# parser = get_parser("Trainer", default_task) +# add_dataset_args(parser, train=True) +# add_distributed_training_args(parser) +# add_model_args(parser) +# add_optimization_args(parser) +# add_checkpoint_args(parser) +# return parser + + +# def get_validation_parser(default_task=None): +# parser = get_parser("Validation", default_task) +# add_dataset_args(parser, train=True) +# add_distributed_training_args(parser) +# group = parser.add_argument_group("Evaluation") +# add_common_eval_args(group) +# return parser + + +# def parse_args_and_arch( +# parser: argparse.ArgumentParser, +# input_args: List[str] = None, +# parse_known: bool = False, +# suppress_defaults: bool = False, +# modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None, +# ): +# """ +# Args: +# parser (ArgumentParser): the parser +# input_args (List[str]): strings to parse, defaults to sys.argv +# parse_known (bool): only parse known arguments, similar to +# `ArgumentParser.parse_known_args` +# suppress_defaults (bool): parse while ignoring all default values +# modify_parser (Optional[Callable[[ArgumentParser], None]]): +# function to modify the parser, e.g., to set default values +# """ +# if suppress_defaults: +# # Parse args without any default values. This requires us to parse +# # twice, once to identify all the necessary task/model args, and a second +# # time with all defaults set to None. +# args = parse_args_and_arch( +# parser, +# input_args=input_args, +# parse_known=parse_known, +# suppress_defaults=False, +# ) +# suppressed_parser = argparse.ArgumentParser(add_help=False, parents=[parser]) +# suppressed_parser.set_defaults(**{k: None for k, v in vars(args).items()}) +# args = suppressed_parser.parse_args(input_args) +# return argparse.Namespace( +# **{k: v for k, v in vars(args).items() if v is not None} +# ) + +# from unicore.models import ARCH_MODEL_REGISTRY, ARCH_CONFIG_REGISTRY, MODEL_REGISTRY + +# # Before creating the true parser, we need to import optional user module +# # in order to eagerly import custom tasks, optimizers, architectures, etc. +# usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) +# usr_parser.add_argument("--user-dir", default=None) +# usr_args, _ = usr_parser.parse_known_args(input_args) +# import_user_module(usr_args) + +# if modify_parser is not None: +# modify_parser(parser) + +# # The parser doesn't know about model/loss/optimizer-specific args, so +# # we parse twice. First we parse the model/loss/optimizer, then we +# # parse a second time after adding the *-specific arguments. +# # If input_args is given, we will parse those args instead of sys.argv. +# args, _ = parser.parse_known_args(input_args) + +# # Add model-specific args to parser. 
+# if hasattr(args, "arch"): +# model_specific_group = parser.add_argument_group( +# "Model-specific configuration", +# # Only include attributes which are explicitly given as command-line +# # arguments or which have default values. +# argument_default=argparse.SUPPRESS, +# ) +# if args.arch in ARCH_MODEL_REGISTRY: +# ARCH_MODEL_REGISTRY[args.arch].add_args(model_specific_group) +# elif args.arch in MODEL_REGISTRY: +# MODEL_REGISTRY[args.arch].add_args(model_specific_group) +# else: +# raise RuntimeError() + +# if hasattr(args, "task"): +# from unicore.tasks import TASK_REGISTRY + +# TASK_REGISTRY[args.task].add_args(parser) + +# # Add *-specific args to parser. +# from unicore.registry import REGISTRIES + +# for registry_name, REGISTRY in REGISTRIES.items(): +# choice = getattr(args, registry_name, None) +# if choice is not None: +# cls = REGISTRY["registry"][choice] +# if hasattr(cls, "add_args"): +# cls.add_args(parser) + +# # Modify the parser a second time, since defaults may have been reset +# if modify_parser is not None: +# modify_parser(parser) + +# # Parse a second time. +# if parse_known: +# args, extra = parser.parse_known_args(input_args) +# else: +# args = parser.parse_args(input_args) +# extra = None +# # Post-process args. +# if ( +# hasattr(args, "batch_size_valid") and args.batch_size_valid is None +# ) or not hasattr(args, "batch_size_valid"): +# args.batch_size_valid = args.batch_size +# args.bf16 = getattr(args, "bf16", False) + +# if getattr(args, "seed", None) is None: +# args.seed = 1 # default seed for training +# args.no_seed_provided = True +# else: +# args.no_seed_provided = False + +# args.validate_with_ema = getattr(args, "validate_with_ema", False) +# # Apply architecture configuration. +# if hasattr(args, "arch") and args.arch in ARCH_CONFIG_REGISTRY: +# ARCH_CONFIG_REGISTRY[args.arch](args) + +# if parse_known: +# return args, extra +# else: +# return args + + +# def get_parser(desc, default_task="test"): +# # Before creating the true parser, we need to import optional user module +# # in order to eagerly import custom tasks, optimizers, architectures, etc. +# usr_parser = argparse.ArgumentParser(add_help=False, allow_abbrev=False) +# usr_parser.add_argument("--user-dir", default=None) +# usr_args, _ = usr_parser.parse_known_args() +# import_user_module(usr_args) + +# parser = argparse.ArgumentParser(allow_abbrev=False) +# # fmt: off +# parser.add_argument('--no-progress-bar', action='store_true', help='disable progress bar') +# parser.add_argument('--log-interval', type=int, default=1000, metavar='N', +# help='log progress every N batches (when progress bar is disabled)') +# parser.add_argument('--log-format', default=None, help='log format to use', +# choices=['json', 'none', 'simple', 'tqdm']) +# parser.add_argument('--tensorboard-logdir', metavar='DIR', default='', +# help='path to save logs for tensorboard, should match --logdir ' +# 'of running tensorboard (default: no tensorboard logging)') +# parser.add_argument('--wandb-project', metavar='DIR', default='', +# help='name of wandb project, empty for no wandb logging, for wandb login, use env WANDB_API_KEY. 
You can also use team_name/project_name for project name.') +# parser.add_argument('--wandb-name', metavar='DIR', default='', +# help='wandb run/id name, empty for no wandb logging, for wandb login, use env WANDB_API_KEY') +# parser.add_argument('--seed', default=1, type=int, metavar='N', +# help='pseudo random number generator seed') +# parser.add_argument('--cpu', action='store_true', help='use CPU instead of CUDA') +# parser.add_argument('--fp16', action='store_true', help='use FP16') +# parser.add_argument('--bf16', action='store_true', help='use BF16') +# parser.add_argument('--bf16-sr', action='store_true', help='use stachostic rounding for bf16') +# parser.add_argument('--allreduce-fp32-grad', action='store_true', help='use fp32-grads in fp16/bf16 mode. --ddp-backend should be no_c10d') +# parser.add_argument('--fp16-no-flatten-grads', action='store_true', help="don't flatten FP16 grads tensor") +# parser.add_argument('--fp16-init-scale', default=2 ** 7, type=int, +# help='default FP16 loss scale') +# parser.add_argument('--fp16-scale-window', type=int, +# help='number of updates before increasing loss scale') +# parser.add_argument('--fp16-scale-tolerance', default=0.0, type=float, +# help='pct of updates that can overflow before decreasing the loss scale') +# parser.add_argument('--min-loss-scale', default=1e-4, type=float, metavar='D', +# help='minimum FP16 loss scale, after which training is stopped') +# parser.add_argument('--threshold-loss-scale', type=float, +# help='threshold FP16 loss scale from below') +# parser.add_argument('--user-dir', default=None, +# help='path to a python module containing custom extensions (tasks and/or architectures)') +# parser.add_argument('--empty-cache-freq', default=0, type=int, +# help='how often to clear the PyTorch CUDA cache (0 to disable)') +# parser.add_argument('--all-gather-list-size', default=16384, type=int, +# help='number of bytes reserved for gathering stats from workers') +# parser.add_argument('--suppress-crashes', action='store_true', help="suppress crashes when training with the entry point so that the " +# "main method can return a value (useful for sweeps)") +# parser.add_argument('--profile', action='store_true', help="enable autograd profiler emit_nvtx") +# parser.add_argument('--ema-decay', default=-1.0, type=float, help="enable moving average for model weights") +# parser.add_argument("--validate-with-ema", action="store_true") + + +# from unicore.registry import REGISTRIES +# for registry_name, REGISTRY in REGISTRIES.items(): +# parser.add_argument( +# '--' + registry_name.replace('_', '-'), +# default=REGISTRY['default'], +# choices=REGISTRY['registry'].keys(), +# ) + +# # Task definitions can be found under unicore/tasks/ +# from unicore.tasks import TASK_REGISTRY +# parser.add_argument('--task', metavar='TASK', default=default_task, +# choices=TASK_REGISTRY.keys(), +# help='task') +# # fmt: on +# return parser + + +# def add_dataset_args(parser, train=False, gen=False): +# group = parser.add_argument_group("Dataset and data loading") +# # fmt: off +# group.add_argument('--num-workers', default=1, type=int, metavar='N', +# help='how many subprocesses to use for data loading') +# group.add_argument('--skip-invalid-size-inputs-valid-test', action='store_true', +# help='ignore too long or too short lines in valid and test set') +# group.add_argument('--batch-size', '--max-sentences', type=int, metavar='N', +# help='maximum number of sentences in a batch') +# group.add_argument('--required-batch-size-multiple', default=1, 
type=int, metavar='N', +# help='batch size will be a multiplier of this value') +# group.add_argument('--data-buffer-size', default=10, type=int, +# help='Number of batches to preload') +# group.add_argument('--train-subset', default='train', metavar='SPLIT', +# choices=['train', 'valid', 'test', 'train.small'], +# help='data subset to use for training (train, valid, test)') +# group.add_argument('--valid-subset', default='valid', metavar='SPLIT', +# help='comma separated list of data subsets to use for validation' +# ' (train, valid, valid1, test, test1)') +# group.add_argument('--validate-interval', type=int, default=1, metavar='N', +# help='validate every N epochs') +# group.add_argument('--validate-interval-updates', type=int, default=0, metavar='N', +# help='validate every N updates') +# group.add_argument('--validate-after-updates', type=int, default=0, metavar='N', +# help='dont validate until reaching this many updates') +# group.add_argument('--fixed-validation-seed', default=None, type=int, metavar='N', +# help='specified random seed for validation') +# group.add_argument('--disable-validation', action='store_true', +# help='disable validation') +# group.add_argument('--batch-size-valid', type=int, metavar='N', +# help='maximum number of sentences in a validation batch' +# ' (defaults to --max-sentences)') +# group.add_argument('--max-valid-steps', type=int, metavar='N', +# help='How many batches to evaluate') +# group.add_argument('--curriculum', default=0, type=int, metavar='N', +# help='don\'t shuffle batches for first N epochs') +# # fmt: on +# return group + + +# def add_distributed_training_args(parser): +# group = parser.add_argument_group("Distributed training") +# # fmt: off +# group.add_argument('--distributed-world-size', type=int, metavar='N', +# default=max(1, torch.cuda.device_count()), +# help='total number of GPUs across all nodes (default: all visible GPUs)') +# group.add_argument('--distributed-rank', default=0, type=int, +# help='rank of the current worker') +# group.add_argument('--distributed-backend', default='nccl', type=str, +# help='distributed backend') +# group.add_argument('--distributed-init-method', default=None, type=str, +# help='typically tcp://hostname:port that will be used to ' +# 'establish initial connetion') +# group.add_argument('--distributed-port', default=-1, type=int, +# help='port number (not required if using --distributed-init-method)') +# group.add_argument('--device-id', '--local_rank', default=0, type=int, +# help='which GPU to use (usually configured automatically)') +# group.add_argument('--distributed-no-spawn', action='store_true', +# help='do not spawn multiple processes even if multiple GPUs are visible') +# group.add_argument('--ddp-backend', default='c10d', type=str, +# choices=['c10d', 'apex', 'no_c10d'], +# help='DistributedDataParallel backend') +# group.add_argument('--bucket-cap-mb', default=25, type=int, metavar='MB', +# help='bucket size for reduction') +# group.add_argument('--fix-batches-to-gpus', action='store_true', +# help='don\'t shuffle batches between GPUs; this reduces overall ' +# 'randomness and may affect precision but avoids the cost of ' +# 're-reading the data') +# group.add_argument('--find-unused-parameters', default=False, action='store_true', +# help='disable unused parameter detection (not applicable to ' +# 'no_c10d ddp-backend') +# group.add_argument('--fast-stat-sync', default=False, action='store_true', +# help='Enable fast sync of stats between nodes, this hardcodes to ' +# 'sync only some 
default stats from logging_output.') +# group.add_argument('--broadcast-buffers', default=False, action='store_true', +# help="Copy non-trainable parameters between GPUs, such as " +# "batchnorm population statistics") +# group.add_argument('--nprocs-per-node', default=max(1, torch.cuda.device_count()), type=int, +# help="number of GPUs in each node. An allreduce operation across GPUs in " +# "a node is very fast. Hence, we do allreduce across GPUs in a node, " +# "and gossip across different nodes") +# # fmt: on +# return group + + +# def add_optimization_args(parser): +# group = parser.add_argument_group("Optimization") +# # fmt: off +# group.add_argument('--max-epoch', '--me', default=0, type=int, metavar='N', +# help='force stop training at specified epoch') +# group.add_argument('--max-update', '--mu', default=0, type=int, metavar='N', +# help='force stop training at specified update') +# group.add_argument('--stop-time-hours', default=0, type=float, +# help="force stop training after specified cumulative time (if >0)") +# group.add_argument('--no-weight-decay-names', default="", type=str, +# help='names of parameters to not weight decay, comma separated') +# group.add_argument('--clip-norm', default=0, type=float, metavar='NORM', +# help='clip threshold of gradients') +# group.add_argument('--per-sample-clip-norm', default=0, type=float, metavar='PNORM', +# help='clip threshold of gradients, before gradient sync over workers. In fp16/bf16 mode, --fp32-grad should be set, and --dpp-backend should be no_c10d') +# group.add_argument('--update-freq', default='1', metavar='N1,N2,...,N_K', +# type=lambda uf: eval_str_list(uf, type=int), +# help='update parameters every N_i batches, when in epoch i') +# group.add_argument('--lr', '--learning-rate', default='0.25', type=eval_str_list, +# metavar='LR_1,LR_2,...,LR_N', +# help='learning rate for the first N epochs; all epochs >N using LR_N' +# ' (note: this may be interpreted differently depending on --lr-scheduler)') +# group.add_argument('--stop-min-lr', default=-1, type=float, metavar='LR', +# help='stop training when the learning rate reaches this minimum') +# # fmt: on +# return group + + +# def add_checkpoint_args(parser): +# group = parser.add_argument_group("Checkpointing") +# # fmt: off +# group.add_argument('--save-dir', metavar='DIR', default='checkpoints', +# help='path to save checkpoints') +# group.add_argument('--tmp-save-dir', metavar='DIR', default='./', +# help='path to temporarily save checkpoints') +# group.add_argument('--restore-file', default='checkpoint_last.pt', +# help='filename from which to load checkpoint ' +# '(default: /checkpoint_last.pt') +# group.add_argument('--finetune-from-model', type=str, +# help="finetune from a pretrained model; note that meters and lr scheduler will be reset") +# group.add_argument('--load-from-ema', action="store_true", +# help="finetune from a pretrained model; note that meters and lr scheduler will be reset") +# group.add_argument('--reset-dataloader', action='store_true', +# help='if set, does not reload dataloader state from the checkpoint') +# group.add_argument('--reset-lr-scheduler', action='store_true', +# help='if set, does not load lr scheduler state from the checkpoint') +# group.add_argument('--reset-meters', action='store_true', +# help='if set, does not load meters from the checkpoint') +# group.add_argument('--reset-optimizer', action='store_true', +# help='if set, does not load optimizer state from the checkpoint') +# group.add_argument('--optimizer-overrides', 
default="{}", type=str, metavar='DICT', +# help='a dictionary used to override optimizer args when loading a checkpoint') +# group.add_argument('--save-interval', type=int, default=1, metavar='N', +# help='save a checkpoint every N epochs') +# group.add_argument('--save-interval-updates', type=int, default=0, metavar='N', +# help='save a checkpoint (and validate) every N updates') +# group.add_argument('--keep-interval-updates', type=int, default=-1, metavar='N', +# help='keep the last N checkpoints saved with --save-interval-updates') +# group.add_argument('--keep-last-epochs', type=int, default=-1, metavar='N', +# help='keep last N epoch checkpoints') +# group.add_argument('--keep-best-checkpoints', type=int, default=-1, metavar='N', +# help='keep best N checkpoints based on scores') +# group.add_argument('--no-save', action='store_true', +# help='don\'t save models or checkpoints') +# group.add_argument('--no-epoch-checkpoints', action='store_true', +# help='only store last and best checkpoints') +# group.add_argument('--no-last-checkpoints', action='store_true', +# help='don\'t store last checkpoints') +# group.add_argument('--no-save-optimizer-state', action='store_true', +# help='don\'t save optimizer-state as part of checkpoint') +# group.add_argument('--best-checkpoint-metric', type=str, default='loss', +# help='metric to use for saving "best" checkpoints') +# group.add_argument('--maximize-best-checkpoint-metric', action='store_true', +# help='select the largest metric value for saving "best" checkpoints') +# group.add_argument('--patience', type=int, default=-1, metavar='N', +# help="early stop training if valid performance doesn't " +# "improve for N consecutive validation runs; note " +# "that this is influenced by --validate-interval") +# group.add_argument('--checkpoint-suffix', type=str, default="", +# help="suffix to add to the checkpoint file name") +# # fmt: on +# return group + + +# def add_common_eval_args(group): +# # fmt: off +# group.add_argument('--path', metavar='FILE', +# help='path(s) to model file(s), colon separated') +# group.add_argument('--quiet', action='store_true', +# help='only print final scores') +# group.add_argument('--model-overrides', default="{}", type=str, metavar='DICT', +# help='a dictionary used to override model args at generation ' +# 'that were used during model training') +# group.add_argument('--results-path', metavar='RESDIR', type=str, default=None, +# help='path to save eval results (optional)"') +# # fmt: on + + +# def add_model_args(parser): +# group = parser.add_argument_group("Model configuration") +# # fmt: off + +# # Model definitions can be found under unicore/models/ +# # +# # The model architecture can be specified in several ways. +# # In increasing order of priority: +# # 1) model defaults (lowest priority) +# # 2) --arch argument +# # 3) --encoder/decoder-* arguments (highest priority) +# from unicore.models import ARCH_MODEL_REGISTRY +# group.add_argument('--arch', '-a', default='fconv', metavar='ARCH', required=True, +# choices=ARCH_MODEL_REGISTRY.keys(), +# help='Model Architecture') +# # fmt: on +# return group diff --git a/MindChemistry/applications/Uni-Mol/unicore/registry.py b/MindChemistry/applications/Uni-Mol/unicore/registry.py new file mode 100644 index 0000000000000000000000000000000000000000..627547e4e2fa36810e466d436edd6ad8afd9e214 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/registry.py @@ -0,0 +1,81 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. 
+# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import argparse + + +REGISTRIES = {} + + +def setup_registry( + registry_name: str, + base_class=None, + default=None, +): + assert registry_name.startswith('--') + registry_name = registry_name[2:].replace('-', '_') + + REGISTRY = {} + REGISTRY_CLASS_NAMES = set() + + # maintain a registry of all registries + if registry_name in REGISTRIES: + return # registry already exists + REGISTRIES[registry_name] = { + 'registry': REGISTRY, + 'default': default, + } + + def build_x(args, *extra_args, **extra_kwargs): + choice = getattr(args, registry_name, None) + if choice is None: + return None + cls = REGISTRY[choice] + if hasattr(cls, 'build_' + registry_name): + builder = getattr(cls, 'build_' + registry_name) + else: + builder = cls + set_defaults(args, cls) + return builder(args, *extra_args, **extra_kwargs) + + def register_x(name): + + def register_x_cls(cls): + if name in REGISTRY: + raise ValueError('Cannot register duplicate {} ({})'.format(registry_name, name)) + if cls.__name__ in REGISTRY_CLASS_NAMES: + raise ValueError( + 'Cannot register {} with duplicate class name ({})'.format( + registry_name, cls.__name__, + ) + ) + if base_class is not None and not issubclass(cls, base_class): + raise ValueError('{} must extend {}'.format(cls.__name__, base_class.__name__)) + REGISTRY[name] = cls + REGISTRY_CLASS_NAMES.add(cls.__name__) + return cls + + return register_x_cls + + return build_x, register_x, REGISTRY + + +def set_defaults(args, cls): + """Helper to set default arguments based on *add_args*.""" + if not hasattr(cls, 'add_args'): + return + parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, allow_abbrev=False) + cls.add_args(parser) + # copied from argparse.py: + defaults = argparse.Namespace() + for action in parser._actions: + if action.dest is not argparse.SUPPRESS: + if not hasattr(defaults, action.dest): + if action.default is not argparse.SUPPRESS: + setattr(defaults, action.dest, action.default) + for key, default_value in vars(defaults).items(): + if not hasattr(args, key): + setattr(args, key, default_value) diff --git a/MindChemistry/applications/Uni-Mol/unicore/tasks/__init__.py b/MindChemistry/applications/Uni-Mol/unicore/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..22bc9952b46d51afa147ecf6805d7e1ad0964140 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/tasks/__init__.py @@ -0,0 +1,86 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +"""isort:skip_file""" + +import argparse +import importlib +import os + +from .unicore_task import UnicoreTask + + +# register dataclass +TASK_REGISTRY = {} +TASK_CLASS_NAMES = set() + + +def setup_task(args, **kwargs): + return TASK_REGISTRY[args.task].setup_task(args, **kwargs) + + +def register_task(name): + """ + New tasks can be added to unicore with the + :func:`~unicore.tasks.register_task` function decorator. + + For example:: + + @register_task('classification') + class ClassificationTask(UnicoreTask): + (...) + + .. note:: + + All Tasks must implement the :class:`~unicore.tasks.UnicoreTask` + interface. 
+ + Args: + name (str): the name of the task + """ + + def register_task_cls(cls): + if name in TASK_REGISTRY: + raise ValueError("Cannot register duplicate task ({})".format(name)) + if not issubclass(cls, UnicoreTask): + raise ValueError( + "Task ({}: {}) must extend UnicoreTask".format(name, cls.__name__) + ) + if cls.__name__ in TASK_CLASS_NAMES: + raise ValueError( + "Cannot register task with duplicate class name ({})".format( + cls.__name__ + ) + ) + TASK_REGISTRY[name] = cls + TASK_CLASS_NAMES.add(cls.__name__) + return cls + + return register_task_cls + + +# automatically import any Python files in the tasks/ directory +tasks_dir = os.path.dirname(__file__) +for file in os.listdir(tasks_dir): + path = os.path.join(tasks_dir, file) + if ( + not file.startswith("_") + and not file.startswith(".") + and (file.endswith(".py") or os.path.isdir(path)) + ): + task_name = file[: file.find(".py")] if file.endswith(".py") else file + module = importlib.import_module("unicore.tasks." + task_name) + + # expose `task_parser` for sphinx + if task_name in TASK_REGISTRY: + parser = argparse.ArgumentParser(add_help=False) + group_task = parser.add_argument_group("Task name") + # fmt: off + group_task.add_argument('--task', metavar=task_name, + help='Enable this task with: ``--task=' + task_name + '``') + # fmt: on + group_args = parser.add_argument_group("Additional command-line arguments") + TASK_REGISTRY[task_name].add_args(group_args) + globals()[task_name + "_parser"] = parser diff --git a/MindChemistry/applications/Uni-Mol/unicore/tasks/unicore_task.py b/MindChemistry/applications/Uni-Mol/unicore/tasks/unicore_task.py new file mode 100644 index 0000000000000000000000000000000000000000..42e1a92582a77a18fbcf8d7de4408a81768d378c --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/tasks/unicore_task.py @@ -0,0 +1,562 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
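+# Typical task lifecycle, pieced together from the methods defined below
+# (the task class and split names are illustrative):
+#
+#     task = SomeRegisteredTask.setup_task(args)  # subclass added via @register_task
+#     task.load_dataset("train")                  # implemented by the concrete task
+#     itr = task.get_batch_iterator(
+#         task.dataset("train"), batch_size=args.batch_size, seed=args.seed
+#     )
+#
+# get_batch_iterator() reuses one EpochBatchIterator per dataset when the
+# dataset reports can_reuse_epoch_itr_across_epochs, so later epochs skip
+# re-sorting and re-batching the indices.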
+import logging +import os +import warnings +from argparse import Namespace +from typing import Any, Callable, Dict, List + +import mindspore as ms +from mindspore import profiler +from unicore import metrics, utils +from unicore.data import UnicoreDataset, data_utils, iterators # 假设已适配MindSpore版本 + +logger = logging.getLogger(__name__) + + +class StatefulContainer(object): + + _state: Dict[str, Any] = dict() + _factories: Dict[str, Callable[[], Any]] = dict() + + def add_factory(self, name, factory: Callable[[], Any]): + self._factories[name] = factory + + def merge_state_dict(self, state_dict: Dict[str, Any]): + self._state.update(state_dict) + + @property + def state_dict(self) -> Dict[str, Any]: + return self._state + + def __getattr__(self, name): + if name not in self._state and name in self._factories: + self._state[name] = self._factories[name]() + + if name in self._state: + return self._state[name] + + raise AttributeError(f"Task state has no factory for attribute {name}") + + +class UnicoreTask(object): + """ + 适配MindSpore的任务基类,用于管理数据集、模型构建、训练/验证流程等 + """ + + @classmethod + def add_args(cls, parser): + """为解析器添加任务特定参数""" + pass + + @staticmethod + def logging_outputs_can_be_summed(loss, is_train) -> bool: + """判断训练/验证步骤返回的日志输出是否可在worker间求和""" + return loss.logging_outputs_can_be_summed(is_train) + + args: Namespace + datasets: Dict[str, UnicoreDataset] + dataset_to_epoch_iter: Dict[UnicoreDataset, Any] + state: StatefulContainer = None + + def __init__(self, args: Namespace, **kwargs): + self.args = args + self.datasets = dict() + self.dataset_to_epoch_iter = dict() + self.state = StatefulContainer() + + + @classmethod + def setup_task(cls, args: Namespace, **kwargs): + """设置任务(如加载词典)""" + return cls(args,** kwargs) + + def has_sharded_data(self, split): + return os.pathsep in getattr(self.args, "data", "") + + def load_dataset( + self, + split: str, + combine: bool = False, + **kwargs + ): + """加载指定的数据集拆分""" + raise NotImplementedError + + def dataset(self, split): + """返回已加载的数据集拆分""" + from unicore.data import UnicoreDataset + + if split not in self.datasets: + raise KeyError("Dataset not loaded: " + split) + if not isinstance(self.datasets[split], UnicoreDataset): + raise TypeError("Datasets are expected to be of type UnicoreDataset") + return self.datasets[split] + + def can_reuse_epoch_itr(self, dataset): + # 保持与原逻辑一致:判断是否可跨epoch复用迭代器 + return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False) + + def get_batch_iterator( + self, + dataset, + batch_size=None, + ignore_invalid_inputs=False, + required_batch_size_multiple=1, + seed=1, + num_shards=1, + shard_id=0, + num_workers=0, + epoch=1, + data_buffer_size=0, + disable_iterator_cache=False, + ): + """获取数据集的批处理迭代器(适配MindSpore数据集迭代逻辑)""" + can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr( + dataset + ) + if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: + logger.info("reusing EpochBatchIterator for epoch {}".format(epoch)) + return self.dataset_to_epoch_iter[dataset] + else: + logger.info("get EpochBatchIterator for epoch {}".format(epoch)) + + assert isinstance(dataset, UnicoreDataset) + + # 初始化数据集的起始epoch + dataset.set_epoch(epoch) + + # 按样本大小排序的索引 + with data_utils.numpy_seed(seed): + indices = dataset.ordered_indices() + + # 创建满足大小约束的mini-batch采样器 + batch_sampler = dataset.batch_by_size( + indices, + batch_size=batch_size, + required_batch_size_multiple=required_batch_size_multiple, + ) + + # 返回可复用的分片迭代器 + epoch_iter = iterators.EpochBatchIterator( + dataset=dataset, 
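+                # collate_fn merges the samples selected by batch_sampler into
+                # a single model-ready batch using the dataset's own collater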
+ collate_fn=dataset.collater, + batch_sampler=batch_sampler, + seed=seed, + num_shards=num_shards, + shard_id=shard_id, + num_workers=num_workers, + epoch=epoch, + buffer_size=data_buffer_size, + disable_shuffling=self.disable_shuffling(), + ) + + if can_reuse_epoch_itr: + self.dataset_to_epoch_iter[dataset] = epoch_iter + + return epoch_iter + + def build_model(self, args: Namespace): + """构建MindSpore模型实例""" + from unicore import models + return models.build_model(args, self) + + def build_loss(self, args: Namespace): + """构建MindSpore损失函数实例""" + from unicore import losses + + return losses.build_loss(args, self) + + def train_step( + self, sample, model, loss, optimizer, update_num, ignore_grad=False + ): + """ + 训练步骤:前向计算、反向传播,返回损失相关信息 + """ + # MindSpore中设置模型为训练模式 + model.set_train(True) + model.set_num_updates(update_num) + with profiler.record_function("forward"): # 替换PyTorch的autograd.profiler + loss_val, sample_size, logging_output = loss(model, sample) + if ignore_grad: + loss_val *= 0 + with profiler.record_function("backward"): + # MindSpore中通过张量调用backward计算梯度 + loss_val.backward() + return loss_val, sample_size, logging_output + + def valid_step(self, sample, model, loss, test=False): + """验证步骤:关闭梯度计算的前向计算""" + # MindSpore中设置模型为评估模式 + model.set_train(False) + with ms.no_grad(): # 替换PyTorch的torch.no_grad() + loss_val, sample_size, logging_output = loss(model, sample) + return loss_val, sample_size, logging_output + + def optimizer_step(self, optimizer, model, update_num): + """优化器步骤:执行参数更新""" + optimizer.step() + + def build_dataset_for_inference( + self, src_tokens: List[ms.Tensor], src_lengths: List[int], **kwargs + ) -> ms.dataset.Dataset: # 替换PyTorch的Dataset为MindSpore的Dataset + raise NotImplementedError + + def begin_epoch(self, epoch, model): + """每个epoch开始时的钩子函数""" + pass + + def begin_valid_epoch(self, epoch, model): + """每个验证epoch开始时的钩子函数""" + pass + + def reduce_metrics(self, logging_outputs, loss, split='train'): + """聚合分布式训练的日志输出""" + if not any("bsz" in log for log in logging_outputs): + warnings.warn( + "bsz not found in Loss logging outputs, cannot log bsz" + ) + else: + bsz = sum(log.get("bsz", 0) for log in logging_outputs) + metrics.log_scalar("bsz", bsz, priority=190, round=1) + + loss.__class__.reduce_metrics(logging_outputs, split) + + def state_dict(self): + if self.state is not None: + return self.state.state_dict + return {} + + def load_state_dict(self, state_dict: Dict[str, Any]): + if self.state is not None: + self.state.merge_state_dict(state_dict) + + def disable_shuffling(self) -> bool: + return False +# import logging +# import os +# import warnings +# from argparse import Namespace +# from typing import Any, Callable, Dict, List + +# import torch +# from unicore import metrics, utils +# from unicore.data import UnicoreDataset, data_utils, iterators + +# logger = logging.getLogger(__name__) + + +# class StatefulContainer(object): + +# _state: Dict[str, Any] = dict() +# _factories: Dict[str, Callable[[], Any]] = dict() + +# def add_factory(self, name, factory: Callable[[], Any]): +# self._factories[name] = factory + +# def merge_state_dict(self, state_dict: Dict[str, Any]): +# self._state.update(state_dict) + +# @property +# def state_dict(self) -> Dict[str, Any]: +# return self._state + +# def __getattr__(self, name): +# if name not in self._state and name in self._factories: +# self._state[name] = self._factories[name]() + +# if name in self._state: +# return self._state[name] + +# raise AttributeError(f"Task state has no factory for attribute 
{name}") + + +# class UnicoreTask(object): +# """ +# Tasks store dictionaries and provide helpers for loading/iterating over +# Datasets, initializing the Model/Loss and calculating the loss. + +# Tasks have limited statefulness. In particular, state that needs to be +# saved to/loaded from checkpoints needs to be stored in the `self.state` +# :class:`StatefulContainer` object. For example:: + +# self.state.add_factory("dictionary", self.load_dictionary) +# print(self.state.dictionary) # calls self.load_dictionary() + +# This is necessary so that when loading checkpoints, we can properly +# recreate the task state after initializing the task instance. +# """ + +# @classmethod +# def add_args(cls, parser): +# """Add task-specific arguments to the parser.""" +# pass + +# @staticmethod +# def logging_outputs_can_be_summed(loss, is_train) -> bool: +# """ +# Whether the logging outputs returned by `train_step` and `valid_step` can +# be summed across workers prior to calling `reduce_metrics`. +# Setting this to True will improves distributed training speed. +# """ +# return loss.logging_outputs_can_be_summed(is_train) + +# args: Namespace +# datasets: Dict[str, UnicoreDataset] +# dataset_to_epoch_iter: Dict[UnicoreDataset, Any] +# state: StatefulContainer = None + +# def __init__(self, args: Namespace, **kwargs): +# self.args = args +# self.datasets = dict() +# self.dataset_to_epoch_iter = dict() +# self.state = StatefulContainer() + + +# @classmethod +# def setup_task(cls, args: Namespace, **kwargs): +# """Setup the task (e.g., load dictionaries). + +# Args: +# args (Namespace): parsed command-line arguments +# """ +# return cls(args, **kwargs) + +# def has_sharded_data(self, split): +# return os.pathsep in getattr(self.args, "data", "") + +# def load_dataset( +# self, +# split: str, +# combine: bool = False, +# **kwargs +# ): +# """Load a given dataset split. + +# Args: +# split (str): name of the split (e.g., train, valid, test) +# combine (bool): combines a split segmented into pieces into one dataset +# """ +# raise NotImplementedError + +# def dataset(self, split): +# """ +# Return a loaded dataset split. + +# Args: +# split (str): name of the split (e.g., train, valid, test) + +# Returns: +# a :class:`~unicore.data.UnicoreDataset` corresponding to *split* +# """ +# from unicore.data import UnicoreDataset + +# if split not in self.datasets: +# raise KeyError("Dataset not loaded: " + split) +# if not isinstance(self.datasets[split], UnicoreDataset): +# raise TypeError("Datasets are expected to be of type UnicoreDataset") +# return self.datasets[split] + +# def can_reuse_epoch_itr(self, dataset): +# # We can reuse the epoch iterator across epochs as long as the dataset +# # hasn't disabled it. We default to ``False`` here, although in practice +# # this will be ``True`` for most datasets that inherit from +# # ``UnicoreDataset`` due to the base implementation there. +# return getattr(dataset, "can_reuse_epoch_itr_across_epochs", False) + +# def get_batch_iterator( +# self, +# dataset, +# batch_size=None, +# ignore_invalid_inputs=False, +# required_batch_size_multiple=1, +# seed=1, +# num_shards=1, +# shard_id=0, +# num_workers=0, +# epoch=1, +# data_buffer_size=0, +# disable_iterator_cache=False, +# ): +# """ +# Get an iterator that yields batches of data from the given dataset. + +# Args: +# dataset (~unicore.data.UnicoreDataset): dataset to batch +# batch_size (int, optional): max number of samples in each +# batch (default: None). 
+# ignore_invalid_inputs (bool, optional): don't raise Exception for +# sentences that are too long (default: False). +# required_batch_size_multiple (int, optional): require batch size to +# be a multiple of N (default: 1). +# seed (int, optional): seed for random number generator for +# reproducibility (default: 1). +# num_shards (int, optional): shard the data iterator into N +# shards (default: 1). +# shard_id (int, optional): which shard of the data iterator to +# return (default: 0). +# num_workers (int, optional): how many subprocesses to use for data +# loading. 0 means the data will be loaded in the main process +# (default: 0). +# epoch (int, optional): the epoch to start the iterator from +# (default: 1). +# data_buffer_size (int, optional): number of batches to +# preload (default: 0). +# disable_iterator_cache (bool, optional): don't cache the +# EpochBatchIterator (ignores `UnicoreTask::can_reuse_epoch_itr`) +# (default: False). +# Returns: +# ~unicore.iterators.EpochBatchIterator: a batched iterator over the +# given dataset split +# """ +# can_reuse_epoch_itr = not disable_iterator_cache and self.can_reuse_epoch_itr( +# dataset +# ) +# if can_reuse_epoch_itr and dataset in self.dataset_to_epoch_iter: +# logger.info("reusing EpochBatchIterator for epoch {}".format(epoch)) +# return self.dataset_to_epoch_iter[dataset] +# else: +# logger.info("get EpochBatchIterator for epoch {}".format(epoch)) + +# assert isinstance(dataset, UnicoreDataset) + +# # initialize the dataset with the correct starting epoch +# dataset.set_epoch(epoch) + +# # get indices ordered by example size +# with data_utils.numpy_seed(seed): +# indices = dataset.ordered_indices() + +# # create mini-batches with given size constraints +# batch_sampler = dataset.batch_by_size( +# indices, +# batch_size=batch_size, +# required_batch_size_multiple=required_batch_size_multiple, +# ) + +# # return a reusable, sharded iterator +# epoch_iter = iterators.EpochBatchIterator( +# dataset=dataset, +# collate_fn=dataset.collater, +# batch_sampler=batch_sampler, +# seed=seed, +# num_shards=num_shards, +# shard_id=shard_id, +# num_workers=num_workers, +# epoch=epoch, +# buffer_size=data_buffer_size, +# disable_shuffling=self.disable_shuffling(), +# ) + +# if can_reuse_epoch_itr: +# self.dataset_to_epoch_iter[dataset] = epoch_iter + +# return epoch_iter + +# def build_model(self, args: Namespace): +# """ +# Build the :class:`~unicore.models.BaseUnicoreModel` instance for this +# task. + +# Returns: +# a :class:`~unicore.models.BaseUnicoreModel` instance +# """ +# from unicore import models +# return models.build_model(args, self) + +# def build_loss(self, args: Namespace): +# """ +# Build the :class:`~unicore.losses.UnicoreLoss` instance for +# this task. + +# Args: +# args (Namespace): configration object + +# Returns: +# a :class:`~unicore.losses.UnicoreLoss` instance +# """ +# from unicore import losses + +# return losses.build_loss(args, self) + +# def train_step( +# self, sample, model, loss, optimizer, update_num, ignore_grad=False +# ): +# """ +# Do forward and backward, and return the loss as computed by *loss* +# for the given *model* and *sample*. + +# Args: +# sample (dict): the mini-batch. The format is defined by the +# :class:`~unicore.data.UnicoreDataset`. 
+# model (~unicore.models.BaseUnicoreModel): the model +# loss (~unicore.losses.UnicoreLoss): the loss +# optimizer (~unicore.optim.UnicoreOptimizer): the optimizer +# update_num (int): the current update +# ignore_grad (bool): multiply loss by 0 if this is set to True + +# Returns: +# tuple: +# - the loss +# - the sample size, which is used as the denominator for the +# gradient +# - logging outputs to display while training +# """ +# model.train() +# model.set_num_updates(update_num) +# with torch.autograd.profiler.record_function("forward"): +# loss, sample_size, logging_output = loss(model, sample) +# if ignore_grad: +# loss *= 0 +# with torch.autograd.profiler.record_function("backward"): +# optimizer.backward(loss) +# return loss, sample_size, logging_output + +# def valid_step(self, sample, model, loss, test=False): +# model.eval() +# with torch.no_grad(): +# loss, sample_size, logging_output = loss(model, sample) +# return loss, sample_size, logging_output + +# def optimizer_step(self, optimizer, model, update_num): +# optimizer.step() + +# def build_dataset_for_inference( +# self, src_tokens: List[torch.Tensor], src_lengths: List[int], **kwargs +# ) -> torch.utils.data.Dataset: +# raise NotImplementedError + +# def begin_epoch(self, epoch, model): +# """Hook function called before the start of each epoch.""" +# pass + +# def begin_valid_epoch(self, epoch, model): +# """Hook function called before the start of each validation epoch.""" +# pass + +# def reduce_metrics(self, logging_outputs, loss, split='train'): +# """Aggregate logging outputs from data parallel training.""" +# if not any("bsz" in log for log in logging_outputs): +# warnings.warn( +# "bsz not found in Loss logging outputs, cannot log bsz" +# ) +# else: +# bsz = sum(log.get("bsz", 0) for log in logging_outputs) +# metrics.log_scalar("bsz", bsz, priority=190, round=1) + +# loss.__class__.reduce_metrics(logging_outputs, split) + +# def state_dict(self): +# if self.state is not None: +# return self.state.state_dict +# return {} + +# def load_state_dict(self, state_dict: Dict[str, Any]): +# if self.state is not None: +# self.state.merge_state_dict(state_dict) + +# def disable_shuffling(self) -> bool: +# return False \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unicore/trainer.py b/MindChemistry/applications/Uni-Mol/unicore/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..4bcb8c74cdcd175b9c37895cf168a59443b3fd28 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/trainer.py @@ -0,0 +1,2252 @@ +# Copyright (c) DP Technology. +# Copyright (c) Facebook, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +""" +Train a network across multiple GPUs. 
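+
+The Trainer below performs one update as: forward/backward over one or more
+micro-batches (delayed updates), all-reduce of gradients across data-parallel
+workers, gradient scaling and clipping, and finally an optimizer step, with
+optional fp16/bf16 training and an optional EMA copy of the model.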
+""" +import contextlib +import logging +import os +import sys +import time +from itertools import chain +from typing import Any, Dict, List +import mindspore as ms +from mindspore import Tensor, ops, nn +from mindspore.communication import init, get_rank, get_group_size, all_gather, all_reduce +from mindspore.context import set_context +from mindspore.nn.wrap import DistributedParallel + +from unicore import checkpoint_utils, models, optim, utils +from unicore.distributed import utils as distributed_utils +from unicore.logging import meters, metrics +from unicore.nan_detector import NanDetector +from unicore.optim import lr_scheduler +from unicore.ema import ExponentialMovingAverageModel + + +logger = logging.getLogger(__name__) + + +class Trainer(object): + """Main class for data parallel training. + + This class supports synchronous distributed data parallel training, + where multiple workers each have a full model replica and gradients + are accumulated across workers before each update. We use + :class:`~mindspore.nn.DistributedParallel` to handle + communication of the gradients across workers. + """ + + def __init__(self, args, task, model, loss): + + self.args = args + self.task = task + + # 初始化分布式环境 + if args.distributed_world_size > 1: + init() + + # catalog shared parameters + shared_params = _catalog_shared_params(model) + self.cuda = args.device_target == "GPU" + if self.cuda: + set_context(device_target="GPU") + self.device = ms.device("GPU") + else: + set_context(device_target="CPU") + self.device = ms.device("CPU") + + # 复制模型和损失函数到设备/设置数据类型 + self._loss = loss + self._model = model + if args.fp16: + self._loss = self._loss.astype(ms.float16) + self._model = self._model.astype(ms.float16) + elif args.bf16: + self._loss = self._loss.astype(ms.bfloat16) + self._model = self._model.astype(ms.bfloat16) + + # 非分布式包装时手动移动设备 + if not self.use_distributed_wrapper: + self._loss = self._loss.to(self.device) + self._model = self._model.to(self.device) + + # 检查共享参数在设备转移后是否保留 + for shared_param in shared_params: + ref = _get_module_by_path(self._model, shared_param[0]) + for path in shared_param[1:]: + logger.info( + "detected shared parameter: {} <- {}".format(shared_param[0], path) + ) + _set_module_by_path(self._model, path, ref) + + self._dummy_batch = None # 初始无虚拟批次 + self._total_train_steps = None + self._lr_scheduler = None + self._num_updates = 0 + self._optim_history = None + self._optimizer = None + self._warn_once = set() + self._wrapped_loss = None + self._wrapped_model = None + + # 梯度范数缓冲区 + if self.cuda and self.data_parallel_world_size > 1: + self._grad_norm_buf = Tensor( + data=[0.0] * self.data_parallel_world_size, + dtype=ms.float64, + device=self.device + ) + else: + self._grad_norm_buf = None + + # 获取CUDA环境信息 + if self.cuda: + self.cuda_env = utils.CudaEnvironment() + if self.data_parallel_world_size > 1: + self.cuda_env_arr = distributed_utils.all_gather_list( + self.cuda_env, group=distributed_utils.get_global_group() + ) + else: + self.cuda_env_arr = [self.cuda_env] + if self.data_parallel_rank == 0: + utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr) + else: + self.cuda_env = None + self.cuda_env_arr = None + + # 初始化EMA + if args.validate_with_ema: + assert args.ema_decay > 0, "valid with ema must with ema_decay > 0" + + model = self.model + if args.ema_decay > 0 and ( + self.data_parallel_rank == 0 or args.validate_with_ema + ): + self.ema = ExponentialMovingAverageModel( + args, + model, + args.ema_decay, + is_flattened=(args.fp16 or 
args.bf16), + ) + else: + self.ema = None + metrics.log_start_time("wall", priority=790, round=2) + + self._start_time = time.time() + self._previous_training_time = 0 + self._cumulative_training_time = None + + def reinitialize(self): + """重新初始化Trainer,通常在模型参数改变后""" + self._lr_scheduler = None + self._optimizer = None + self._wrapped_loss = None + self._wrapped_model = None + + @property + def data_parallel_world_size(self): + if self.args.distributed_world_size == 1: + return 1 + return get_group_size() + + @property + def data_parallel_process_group(self): + return distributed_utils.get_data_parallel_group() + + @property + def data_parallel_rank(self): + if self.args.distributed_world_size == 1: + return 0 + return get_rank() + + @property + def is_data_parallel_master(self): + return self.data_parallel_rank == 0 + + @property + def use_distributed_wrapper(self) -> bool: + return self.data_parallel_world_size > 1 + + @property + def should_save_checkpoint_on_current_rank(self) -> bool: + return self.is_data_parallel_master + + @property + def checkpoint_suffix(self) -> str: + return self.args.checkpoint_suffix or "" + + @property + def loss(self): + if self._wrapped_loss is None: + if utils.has_parameters(self._loss) and self.use_distributed_wrapper: + self._wrapped_loss = DistributedParallel( + self._loss, + process_group=self.data_parallel_process_group, + device_ids=[self.data_parallel_rank] + ) + else: + self._wrapped_loss = self._loss + return self._wrapped_loss + + @property + def model(self): + if self._wrapped_model is None: + if self.use_distributed_wrapper: + self._wrapped_model = DistributedParallel( + self._model, + process_group=self.data_parallel_process_group, + device_ids=[self.data_parallel_rank] + ) + else: + self._wrapped_model = self._model + return self._wrapped_model + + @property + def optimizer(self): + if self._optimizer is None: + self._build_optimizer() + return self._optimizer + + @property + def lr_scheduler(self): + if self._lr_scheduler is None: + self._build_optimizer() # 初始化优化器时会同时初始化学习率调度器 + return self._lr_scheduler + + def _build_optimizer(self): + params = [ + (name, param) + for name, param in chain( + self.model.named_parameters(), + self.loss.named_parameters(), + ) + if param.requires_grad + ] + + # 根据精度设置优化器 + if self.args.fp16 or self.args.bf16: + if self.cuda: + logger.info("使用混合精度优化器") + self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) + else: + if self.cuda: + logger.info("使用FP32优化器") + self._optimizer = optim.build_optimizer(self.args, params) + + # 初始化学习率调度器 + self._lr_scheduler = lr_scheduler.build_lr_scheduler( + self.args, + self.optimizer, + self._total_train_steps, + ) + self._lr_scheduler.step_update(0) + + def state_dict(self): + state_dict = { + "args": self.args, + "model": self.model.state_dict(), + "loss": ( + self.loss.state_dict() if utils.has_parameters(self.loss) else None + ), + "optimizer_history": (self._optim_history or []) + + [ + { + "loss_name": self.get_loss().__class__.__name__, + "optimizer_name": self.optimizer.__class__.__name__, + "lr_scheduler_state": self.lr_scheduler.state_dict(), + "num_updates": self.get_num_updates(), + } + ], + "task_state": self.task.state_dict() if self.task is not None else {}, + "extra_state": { + "metrics": metrics.state_dict(), + "previous_training_time": self.cumulative_training_time(), + }, + } + if not self.args.no_save_optimizer_state: + state_dict["last_optimizer_state"] = self.optimizer.state_dict() + if self.ema is not None: + state_dict["ema"] = 
self.ema.state_dict() + return state_dict + + def save_checkpoint(self, filename, extra_state): + """保存训练状态到检查点文件""" + logger.info(f"Saving checkpoint to {filename}") + state_dict = utils.move_to_cpu(self.state_dict()) + state_dict["extra_state"].update(extra_state) + if self.should_save_checkpoint_on_current_rank: + checkpoint_utils.mindspore_persistent_save( # 替换为MindSpore的保存函数 + state_dict, + filename, + ) + logger.info(f"Finished saving checkpoint to {filename}") + + def load_checkpoint( + self, + filename, + reset_optimizer=False, + reset_lr_scheduler=False, + reset_dataloader=False, + optimizer_overrides=None, + reset_meters=False,** passthrough_args, + ): + """从检查点加载训练状态""" + extra_state, self._optim_history, last_optim_state = None, [], None + + logger.info(f"Preparing to load checkpoint {filename}") + is_distributed = self.data_parallel_world_size > 1 + is_master = self.data_parallel_rank == 0 + bexists = None + if is_master: + bexists = os.path.isfile(filename) + if is_distributed: + bexists = distributed_utils.broadcast_object( + bexists, + src_rank=0, + group=self.data_parallel_process_group, + dist_device=self.device, + ) + + had_loaded_model = False + ema_loaded = False + if bexists: + state = None + if is_master: + state = checkpoint_utils.load_checkpoint_to_cpu( # 替换为MindSpore的加载函数 + filename, + ) + if is_distributed: + logger.info("Broadcast checkpoint from rank_0") + state = distributed_utils.broadcast_object( + state, + src_rank=0, + group=self.data_parallel_process_group, + dist_device=self.device, + ) + last_optim_state = state.get("last_optimizer_state", None) + ema_state = state.get("ema", None) + + # 加载模型参数 + try: + if self.args.load_from_ema: + logger.info("loading ema state to model") + errors = self.model.load_state_dict( + ema_state["params"], strict=False, model_args=self.args + ) + ema_loaded = True + else: + errors = self.model.load_state_dict( + state["model"], strict=False, model_args=self.args + ) + del state["model"] # 释放内存 + had_loaded_model = True + + if errors.missing_keys: + logger.warning( + "Error in loading model state, missing_keys " + + str(errors.missing_keys) + ) + if errors.unexpected_keys: + logger.warning( + "Error in loading model state, unexpected_keys " + + str(errors.unexpected_keys) + ) + if utils.has_parameters(self.get_loss()): + self.get_loss().load_state_dict(state["loss"], strict=True) + del state["loss"] + + except Exception: + raise Exception( + "Cannot load model parameters from checkpoint {}; " + "please ensure that the architectures match.".format(filename) + ) + extra_state = state["extra_state"] if "extra_state" in state else None + self._optim_history = ( + state["optimizer_history"] if "optimizer_history" in state else None + ) + + # 加载EMA状态 + if ema_state is not None and self.ema is not None and not self.args.load_from_ema: + logger.info(f"Loading EMA state...") + self.ema.load_state_dict(ema_state) + elif self.ema is not None and not ema_loaded: + logger.info( + f"Cannot find EMA state in checkpoint, load model weight to ema directly" + ) + self.ema = ExponentialMovingAverageModel( + self.args, + self._model, + decay=self.ema.decay, + is_flattened=(self.args.fp16 or self.args.bf16), + ) + + loaded_train_itr = False + if extra_state is not None: + itr_state = extra_state["train_iterator"] + epoch = itr_state["epoch"] + + if "previous_training_time" in extra_state: + self._previous_training_time = extra_state["previous_training_time"] + self._start_time = time.time() + + if ( + itr_state.get("version", 1) >= 2 + and 
itr_state["iterations_in_epoch"] == 0 + ): + reset_meters = True # epoch开始时重置计量器 + + if "metrics" in extra_state and not reset_meters: + metrics.load_state_dict(extra_state["metrics"]) + # 重置时间计量器(其起始时间已无意义) + for meter in metrics.get_meters("default"): + if isinstance(meter, meters.TimeMeter): + meter.reset() + + if not reset_dataloader: + # 从检查点恢复迭代器 + epoch_itr = self.get_train_iterator( + epoch=itr_state["epoch"], load_dataset=True, **passthrough_args + ) + epoch_itr.load_state_dict(itr_state) + loaded_train_itr = True + + if not loaded_train_itr: + epoch_itr = self.get_train_iterator( + epoch=1, load_dataset=True,** passthrough_args + ) + + self.init_total_train_steps(epoch_itr) + + # 加载优化器状态 + if last_optim_state is not None and not reset_optimizer: + self._build_optimizer() # 加载模型后重建优化器 + + # 检查优化器和损失是否匹配 + last_optim = self._optim_history[-1] + assert ( + last_optim["loss_name"] == self.get_loss().__class__.__name__ + ), f"Loss does not match; please reset the optimizer (--reset-optimizer). {last_optim['loss_name']} vs {self.get_loss().__class__.__name__}" + assert ( + last_optim["optimizer_name"] == self.optimizer.__class__.__name__ + ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}" + + if not reset_lr_scheduler: + self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) + + self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) + self.set_num_updates(last_optim["num_updates"]) + + # 日志输出加载结果 + if had_loaded_model: + if loaded_train_itr: + logger.info( + "Loaded checkpoint {} (epoch {} @ {} updates)".format( + filename, epoch, self.get_num_updates() + ) + ) + else: + logger.info("Loaded checkpoint {}".format(filename)) + elif ema_loaded: + logger.info("Loaded ema state from checkpoint {}".format(filename)) + else: + logger.info("No existing checkpoint found {}".format(filename)) + + self.lr_step(epoch_itr.epoch) + + return extra_state, epoch_itr + + def get_train_iterator( + self, + epoch, + combine=True, + load_dataset=True, + data_selector=None, + shard_batch_itr=True, + disable_iterator_cache=False, + ): + """获取训练集的迭代器""" + if load_dataset: + logger.info("loading train data for epoch {}".format(epoch)) + self.task.load_dataset( + self.args.train_subset, + epoch=epoch, + combine=combine, + data_selector=data_selector, + ) + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.dataset(self.args.train_subset), + batch_size=self.args.batch_size, + ignore_invalid_inputs=True, + required_batch_size_multiple=self.args.required_batch_size_multiple, + seed=self.args.seed, + num_shards=self.data_parallel_world_size if shard_batch_itr else 1, + shard_id=self.data_parallel_rank if shard_batch_itr else 0, + num_workers=self.args.num_workers, + epoch=epoch, + data_buffer_size=self.args.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + ) + self.reset_dummy_batch(batch_iterator.first_batch) + return batch_iterator + + def init_total_train_steps(self, epoch_itr): + if self.args.max_epoch > 0: + self._total_train_steps = ( + (len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch + ) + else: + self._total_train_steps = self.args.max_update + + def get_valid_iterator( + self, + subset, + disable_iterator_cache=False, + ): + """获取验证集的迭代器""" + batch_iterator = self.task.get_batch_iterator( + dataset=self.task.dataset(subset), + batch_size=self.args.batch_size_valid, + 
ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test, + required_batch_size_multiple=self.args.required_batch_size_multiple, + seed=self.args.seed, + num_shards=self.data_parallel_world_size, + shard_id=self.data_parallel_rank, + num_workers=self.args.num_workers, + epoch=1, # 固定epoch确保验证数据一致 + data_buffer_size=self.args.data_buffer_size, + disable_iterator_cache=disable_iterator_cache, + ) + return batch_iterator + + def begin_epoch(self, epoch): + """每个epoch开始时调用""" + logger.info("begin training epoch {}".format(epoch)) + self.lr_step_begin_epoch(epoch) + self.task.begin_epoch(epoch, self.get_model()) # 任务特定的epoch初始化 + + def begin_valid_epoch(self, epoch): + """每个验证epoch开始时调用""" + self.task.begin_valid_epoch(epoch, self.get_model()) + + def reset_dummy_batch(self, batch): + self._dummy_batch = batch + + @metrics.aggregate("train") + def train_step(self, samples, raise_oom=False): + """前向、反向传播和参数更新""" + self.model.set_train(True) + self.loss.set_train(True) + self.zero_grad() + + metrics.log_start_time("train_wall", priority=800, round=2) + + logging_outputs, sample_size, ooms = [], 0, 0 + for i, sample in enumerate(samples): # 延迟更新循环 + sample, is_dummy_batch = self._prepare_sample(sample) + + def maybe_no_sync(): + """ + 当samples包含多个mini-batch时,累积本地梯度,仅在最后一次反向传播时进行all-reduce + """ + if ( + self.data_parallel_world_size > 1 + and hasattr(self.model, "no_sync") + and i < len(samples) - 1 + ): + return self.model.no_sync() + else: + return contextlib.ExitStack() # 空上下文管理器 + + try: + with maybe_no_sync(): + # 不同rank使用不同种子,避免dropout行为一致 + with utils.mindspore_seed( # 替换为MindSpore的种子设置 + self.args.seed, + self.get_num_updates(), + i, + self.data_parallel_rank, + ): + # 前向和反向传播 + loss, sample_size_i, logging_output = self.task.train_step( + sample=sample, + model=self.model, + loss=self.loss, + optimizer=self.optimizer, + update_num=self.get_num_updates(), + ignore_grad=is_dummy_batch, + ) + del loss + if self.args.per_sample_clip_norm > 0: + self.optimizer.per_sample_clip_grad_norm( + self.args.per_sample_clip_norm + ) + + logging_outputs.append(logging_output) + sample_size += sample_size_i + + # 第一步后清空CUDA缓存减少OOM + if self.cuda and self.get_num_updates() == 0: + ms.clear_auto_alloc_cache() # 替换为MindSpore的缓存清理 + + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + if raise_oom: + raise e + logger.warning( + "attempting to recover from OOM in forward/backward pass" + ) + ooms += 1 + self.zero_grad() + if self.cuda: + ms.clear_auto_alloc_cache() + if self.args.distributed_world_size == 1: + return None + else: + raise e + + # 虚拟批次样本量归零 + if is_dummy_batch: + if isinstance(sample_size, Tensor): + sample_size.zero_() + else: + sample_size *= 0.0 + + # 转换样本量类型 + if isinstance(sample_size, Tensor): + sample_size = sample_size.astype(ms.float32) + else: + sample_size = float(sample_size) + + local_sample_size = sample_size + # 聚合所有副本的日志输出 + if self._sync_stats(): + train_time = self._local_cumulative_training_time() + logging_outputs, ( + sample_size, + ooms, + total_train_time, + ) = self._aggregate_logging_outputs( + logging_outputs, + sample_size, + ooms, + train_time, + ignore=is_dummy_batch, + is_train=True, + ) + self._cumulative_training_time = ( + total_train_time / self.data_parallel_world_size + ) + + overflow = False + try: + with ops.profiler.record_function("reduce-grads"): # 替换为MindSpore的profiler + # 跨worker聚合梯度 + self.optimizer.all_reduce_grads(self.model) + if utils.has_parameters(self.loss): + self.optimizer.all_reduce_grads(self.loss) + + 
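+            # Why this scaling (assuming the all-reduce above averages gradients
+            # over ranks, as a DDP-style wrapper does): each rank now holds the
+            # rank-averaged sum of per-sample gradients, while sample_size was
+            # summed over all ranks, so multiplying by
+            # world_size / sample_size recovers a per-sample mean gradient.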
with ops.profiler.record_function("multiply-grads"): + # 梯度缩放:(数据并行大小 / 样本量) + numer = self.data_parallel_world_size if self._sync_stats() else 1 + self.optimizer.multiply_grads(numer / (sample_size or 1.0)) + + with ops.profiler.record_function("clip-grads"): + # 梯度裁剪 + grad_norm = self.clip_grad_norm(self.args.clip_norm) + + self._check_grad_norms(grad_norm) + if not ops.isfinite(grad_norm).all(): # 替换为MindSpore的isfinite + raise FloatingPointError("gradients are Nan/Inf") + + with ops.profiler.record_function("optimizer"): + # 固定种子确保不同rank的随机行为一致 + with utils.mindspore_seed(self.args.seed, self.get_num_updates()): + # 执行优化步骤 + self.task.optimizer_step( + self.optimizer, + model=self.model, + update_num=self.get_num_updates(), + ) + + # 更新EMA + if self.ema is not None: + with ops.profiler.record_function("ema"): + if self.args.fp16 or self.args.bf16: + self.ema.update(self.optimizer.fp32_params) + else: + self.ema.update(self.model.named_parameters()) + + except FloatingPointError: + # 附加钩子重新运行以定位错误 + self.zero_grad() + with NanDetector(self.get_model()): + for i, sample in enumerate(samples): + sample, _ = self._prepare_sample(sample) + with utils.mindspore_seed( + self.args.seed, + self.get_num_updates(), + i, + self.data_parallel_rank, + ): + self.task.train_step( + sample, + self.model, + self.loss, + self.optimizer, + self.get_num_updates(), + ignore_grad=False, + ) + raise + except OverflowError as e: + overflow = True + logger.info( + f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}" + ) + grad_norm = Tensor(0.0, device=self.device) + self.zero_grad() + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + logger.error("OOM during optimization, irrecoverable") + raise e + + logging_output = None + if not overflow: + self.set_num_updates(self.get_num_updates() + 1) + + # 记录GPU内存使用 + if self.cuda and self.cuda_env is not None: + gb_used = ms.get_runtime_context().max_device_memory / 1024 / 1024 / 1024 + ms.get_runtime_context().reset_max_device_memory() # 重置峰值内存统计 + gb_free = self.cuda_env.total_memory_in_GB - gb_used + metrics.log_scalar("gb_free", gb_free, priority=1500, round=1, weight=0) + + # 日志统计 + logging_output = self._reduce_and_log_stats( + logging_outputs, + sample_size, + grad_norm, + ) + + # 定期清空缓存减少内存碎片 + if ( + self.cuda + and self.args.empty_cache_freq > 0 + and ( + (self.get_num_updates() + self.args.empty_cache_freq - 1) + % self.args.empty_cache_freq + ) + == 0 + ): + ms.clear_auto_alloc_cache() + + # 记录损失缩放(FP16时) + if self.args.fp16: + metrics.log_scalar( + "loss_scale", + self.optimizer.scaler.loss_scale, + priority=700, + round=4, + weight=0, + ) + + metrics.log_stop_time("train_wall") + return logging_output + + @metrics.aggregate("valid") + def valid_step(self, sample, raise_oom=False): + """验证阶段前向传播""" + with ms.no_grad(): # 替换为MindSpore的no_grad + self.model.set_train(False) + self.loss.set_train(False) + + sample, is_dummy_batch = self._prepare_sample(sample) + + try: + _loss, sample_size, logging_output = self.task.valid_step( + sample, self.model, self.loss + ) + except RuntimeError as e: + if "out of memory" in str(e): + self._log_oom(e) + if not raise_oom: + logger.warning( + "ran out of memory in validation step, retrying batch" + ) + for p in self.model.parameters(): + if p.grad is not None: + p.grad = None # 释放内存 + if self.cuda: + ms.clear_auto_alloc_cache() + return self.valid_step(sample, raise_oom=True) + raise e + + logging_outputs = [logging_output] + if is_dummy_batch: + if isinstance(sample_size, 
Tensor): + sample_size.zero_() + else: + sample_size *= 0.0 + + # 聚合所有副本的日志输出 + if self.data_parallel_world_size > 1: + logging_outputs, (sample_size,) = self._aggregate_logging_outputs( + logging_outputs, + sample_size, + ignore=is_dummy_batch, + is_train=False, + ) + + return logging_outputs + + def zero_grad(self): + self.optimizer.zero_grad() + + def lr_step_begin_epoch(self, epoch): + """epoch开始时调整学习率""" + self.lr_scheduler.step_begin_epoch(epoch) + return self.lr_step_update() + + def lr_step(self, epoch, val_loss=None): + """epoch结束时调整学习率""" + self.lr_scheduler.step(epoch, val_loss) + return self.lr_step_update() + + def lr_step_update(self): + """每次更新后调整学习率""" + new_lr = self.lr_scheduler.step_update(self.get_num_updates()) + if isinstance(new_lr, dict): + for k, v in new_lr.items(): + metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300) + new_lr = new_lr.get("default", next(iter(new_lr.values()))) + else: + metrics.log_scalar("lr", new_lr, weight=0, priority=300) + return new_lr + + def get_lr(self): + """获取当前学习率""" + return self.optimizer.get_lr() + + def get_model(self): + """获取未包装的模型实例""" + return self._model + + def get_loss(self): + """获取未包装的损失函数实例""" + return self._loss + + def get_num_updates(self): + """获取参数更新次数""" + return self._num_updates + + def set_num_updates(self, num_updates): + """设置参数更新次数""" + self._num_updates = num_updates + self.lr_step_update() + metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) + + def clip_grad_norm(self, clip_norm): + return self.optimizer.clip_grad_norm(clip_norm) + + def cumulative_training_time(self): + if self._cumulative_training_time is None: + return self._local_cumulative_training_time() + else: + return self._cumulative_training_time + + def _local_cumulative_training_time(self): + """聚合训练时间(秒)""" + return time.time() - self._start_time + self._previous_training_time + + def _prepare_sample(self, sample, is_dummy=False): + if sample == "DUMMY": + raise Exception( + "Trying to use an uninitialized 'dummy' batch. This usually indicates " + "that the total number of batches is smaller than the number of " + "participating GPUs. Try reducing the batch size or using fewer GPUs." 
+ ) + + # 使用虚拟批次处理空样本 + if sample is None or len(sample) == 0: + assert ( + self._dummy_batch is not None and len(self._dummy_batch) > 0 + ), "Invalid dummy batch: {}".format(self._dummy_batch) + sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True) + return sample, True + + # 移动样本到CUDA设备 + if self.cuda: + sample = utils.move_to_device(sample, self.device) # 替换为MindSpore的设备移动 + + # 数据类型转换(根据需要手动启用) + # def apply_half(t): + # if t.dtype == ms.float32: + # return t.astype(ms.float16) + # return t + # if self.args.fp16: + # sample = utils.apply_to_sample(apply_half, sample) + + if self._dummy_batch == "DUMMY": + self._dummy_batch = sample + + return sample, False + + def _sync_stats(self): + # 多GPU且使用DDP时需要同步统计 + if self.data_parallel_world_size == 1: + return False + else: + return True + + def _log_oom(self, exc): + msg = "OOM: Ran out of memory with exception: {}".format(exc) + logger.warning(msg) + if self.cuda and hasattr(ms, "memory_summary"): + for device_idx in range(ms.get_context("device_id") + 1): + logger.warning(ms.memory_summary(device_id=device_idx)) + sys.stderr.flush() + + def _aggregate_logging_outputs( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + is_train=False, + ): + if self.task.__class__.logging_outputs_can_be_summed( + self.get_loss(), is_train=is_train + ): + return self._fast_stat_sync_sum( + logging_outputs, *extra_stats_to_sum, ignore=ignore + ) + else: + return self._all_gather_list_sync( + logging_outputs, *extra_stats_to_sum, ignore=ignore + ) + + def _all_gather_list_sync( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + """ + 跨worker同步日志输出,适用于复杂类型的日志 + """ + if ignore: + logging_outputs = [] + results = list( + zip( + *distributed_utils.all_gather_list( + [logging_outputs] + list(extra_stats_to_sum), + max_size=getattr(self.args, "all_gather_list_size", 16384), + group=self.data_parallel_process_group, + ) + ) + ) + logging_outputs, extra_stats_to_sum = results[0], results[1:] + logging_outputs = list(chain.from_iterable(logging_outputs)) + extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum] + return logging_outputs, extra_stats_to_sum + + def _fast_stat_sync_sum( + self, + logging_outputs: List[Dict[str, Any]], + *extra_stats_to_sum, + ignore=False, + ): + """ + 快速同步日志输出,适用于标量类型的日志(不可包含嵌套结构) + """ + data = {} + for i, stat in enumerate(extra_stats_to_sum): + data["extra_stats_" + str(i)] = stat + if len(logging_outputs) > 0: + log_keys = list(logging_outputs[0].keys()) + for k in log_keys: + if not ignore: + v = sum(log[k] for log in logging_outputs if k in log) + else: + v = logging_outputs[0][k] + v = ops.zeros_like(v) if isinstance(v, Tensor) else 0 # 替换为MindSpore的zeros_like + data["logging_outputs_" + k] = v + else: + log_keys = None + + # 聚合分布式数据 + data = distributed_utils.all_reduce_dict( + data, device=self.device, group=self.data_parallel_process_group + ) + + extra_stats_to_sum = [ + data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum)) + ] + if log_keys is not None: + logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}] + else: + logging_outputs = [] + return logging_outputs, extra_stats_to_sum + + def _check_grad_norms(self, grad_norm): + """检查所有worker的梯度范数是否一致""" + if self._grad_norm_buf is not None: + self._grad_norm_buf.zero_() + self._grad_norm_buf[self.data_parallel_rank] = grad_norm + all_reduce(self._grad_norm_buf, group=self.data_parallel_process_group) # 替换为MindSpore的all_reduce + + def 
is_consistent(tensor): + max_abs_diff = ops.max(ops.abs(tensor - tensor[0])) # 替换为MindSpore的ops + return ( + ops.isfinite(tensor).all() + and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() + ) + + if not is_consistent(self._grad_norm_buf): + pretty_detail = "\n".join( + "rank {:3d} = {:.8f}".format(r, n) + for r, n in enumerate(self._grad_norm_buf.asnumpy().tolist()) # 转换为numpy列表 + ) + error_detail = "grad_norm across the workers:\n{}\n".format( + pretty_detail + ) + raise FloatingPointError( + "Fatal error: gradients are inconsistent between workers. " + "Try --ddp-backend=legacy_ddp. " + "Or are you mixing up different generation of GPUs in training?" + + "\n" + + "-" * 80 + + "\n{}\n".format(error_detail) + + "-" * 80 + ) + + def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): + if grad_norm is not None and ( + not isinstance(grad_norm, Tensor) or ops.isfinite(grad_norm) + ): + metrics.log_speed("ups", 1.0, priority=100, round=2) + metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) + if self.args.clip_norm > 0: + metrics.log_scalar( + "clip", + ops.where( # 替换为MindSpore的where + grad_norm > self.args.clip_norm, + grad_norm.astype(ms.float32).fill(100), + grad_norm.astype(ms.float32).fill(0), + ), + priority=500, + round=1, + ) + + with metrics.aggregate() as agg: + if logging_outputs is not None: + self.task.reduce_metrics(logging_outputs, self.get_loss()) + del logging_outputs + + # 警告未正确记录损失的情况 + if "loss" not in agg: + if "loss" not in self._warn_once: + self._warn_once.add("loss") + logger.warning( + "Loss.reduce_metrics did not log a 'loss' value, " + "which may break some functionality" + ) + metrics.log_scalar("loss", -1) + + logging_output = agg.get_smoothed_values() + logging_output["sample_size"] = sample_size + for key_to_delete in ["ppl", "wps", "wpb", "bsz"]: + if key_to_delete in logging_output: + del logging_output[key_to_delete] + return logging_output + + +def _catalog_shared_params(module, memo=None, prefix=""): + if memo is None: + first_call = True + memo = {} + else: + first_call = False + for name, param in module._parameters.items(): + if param is None: + continue + param_prefix = prefix + ("." if prefix else "") + name + if param not in memo: + memo[param] = [] + memo[param].append(param_prefix) + for name, m in module._modules.items(): + if m is None: + continue + submodule_prefix = prefix + ("." if prefix else "") + name + _catalog_shared_params(m, memo, submodule_prefix) + if first_call: + return [x for x in memo.values() if len(x) > 1] + + +def _get_module_by_path(module, path): + path = path.split(".") + for name in path: + module = getattr(module, name) + return module + + +def _set_module_by_path(module, path, value): + path = path.split(".") + for name in path[:-1]: + module = getattr(module, name) + setattr(module, path[-1], value) +# import contextlib +# import logging +# import os +# import sys +# import time +# from itertools import chain +# from typing import Any, Dict, List +# import torch +# from unicore import checkpoint_utils, models, optim, utils +# from unicore.distributed import utils as distributed_utils +# from unicore.logging import meters, metrics +# from unicore.nan_detector import NanDetector +# from unicore.optim import lr_scheduler +# from unicore.ema import ExponentialMovingAverageModel + + +# logger = logging.getLogger(__name__) + + +# class Trainer(object): +# """Main class for data parallel training. 
+ +# This class supports synchronous distributed data parallel training, +# where multiple workers each have a full model replica and gradients +# are accumulated across workers before each update. We use +# :class:`~torch.nn.parallel.DistributedDataParallel` to handle +# communication of the gradients across workers. +# """ + +# def __init__(self, args, task, model, loss): + +# self.args = args +# self.task = task + +# # catalog shared parameters +# shared_params = _catalog_shared_params(model) +# self.cuda = torch.cuda.is_available() +# if self.cuda: +# self.device = torch.device("cuda") +# else: +# self.device = torch.device("cpu") + +# # copy model and loss to current device/dtype +# self._loss = loss +# self._model = model +# if args.fp16: +# self._loss = self._loss.half() +# self._model = self._model.half() +# elif args.bf16: +# self._loss = self._loss.bfloat16() +# self._model = self._model.bfloat16() +# if ( +# # the DistributedUnicoreModel wrapper will handle moving to device, +# # so only handle cases which don't use the wrapper +# not self.use_distributed_wrapper +# ): +# self._loss = self._loss.to(device=self.device) +# self._model = self._model.to(device=self.device) + +# # check that shared parameters are preserved after device transfer +# for shared_param in shared_params: +# ref = _get_module_by_path(self._model, shared_param[0]) +# for path in shared_param[1:]: +# logger.info( +# "detected shared parameter: {} <- {}".format(shared_param[0], path) +# ) +# _set_module_by_path(self._model, path, ref) + +# self._dummy_batch = None # indicates we don't have a dummy batch at first +# self._total_train_steps = None +# self._lr_scheduler = None +# self._num_updates = 0 +# self._optim_history = None +# self._optimizer = None +# self._warn_once = set() +# self._wrapped_loss = None +# self._wrapped_model = None + +# if self.cuda and self.data_parallel_world_size > 1: +# self._grad_norm_buf = torch.tensor( +# data=[0.0] +# * self.data_parallel_world_size, # Initialize with zeros or appropriate values +# dtype=torch.double, # Set the desired data type +# device="cuda", +# ) +# else: +# self._grad_norm_buf = None + +# # get detailed cuda environment +# if self.cuda: +# self.cuda_env = utils.CudaEnvironment() +# if self.data_parallel_world_size > 1: +# self.cuda_env_arr = distributed_utils.all_gather_list( +# self.cuda_env, group=distributed_utils.get_global_group() +# ) +# else: +# self.cuda_env_arr = [self.cuda_env] +# if self.data_parallel_rank == 0: +# utils.CudaEnvironment.pretty_print_cuda_env_list(self.cuda_env_arr) +# else: +# self.cuda_env = None +# self.cuda_env_arr = None + +# # add ema +# if args.validate_with_ema: +# assert args.ema_decay > 0, "valid with ema must with ema_decay > 0" + +# model = self.model +# if args.ema_decay > 0 and ( +# self.data_parallel_rank == 0 or args.validate_with_ema +# ): +# self.ema = ExponentialMovingAverageModel( +# args, +# model, +# args.ema_decay, +# is_flattened=(args.fp16 or args.bf16), +# ) + +# else: +# self.ema = None +# metrics.log_start_time("wall", priority=790, round=2) + +# self._start_time = time.time() +# self._previous_training_time = 0 +# self._cumulative_training_time = None + +# def reinitialize(self): +# """Reinitialize the Trainer, typically after model params change.""" +# self._lr_scheduler = None +# self._optimizer = None +# self._wrapped_loss = None +# self._wrapped_model = None + +# @property +# def data_parallel_world_size(self): +# if self.args.distributed_world_size == 1: +# return 1 +# return 
distributed_utils.get_data_parallel_world_size() + +# @property +# def data_parallel_process_group(self): +# return distributed_utils.get_data_parallel_group() + +# @property +# def data_parallel_rank(self): +# if self.args.distributed_world_size == 1: +# return 0 +# return distributed_utils.get_data_parallel_rank() + +# @property +# def is_data_parallel_master(self): +# # NOTE: this returns true for all model parallel replicas with data +# # parallel rank 0 +# return self.data_parallel_rank == 0 + +# @property +# def use_distributed_wrapper(self) -> bool: +# return self.data_parallel_world_size > 1 + +# @property +# def should_save_checkpoint_on_current_rank(self) -> bool: +# """Indicates whether to save checkpoints on the current DDP rank.""" +# return self.is_data_parallel_master + +# @property +# def checkpoint_suffix(self) -> str: +# """Suffix to add to the checkpoint file name.""" +# return self.args.checkpoint_suffix or "" + +# @property +# def loss(self): +# if self._wrapped_loss is None: +# if utils.has_parameters(self._loss) and self.use_distributed_wrapper: +# self._wrapped_loss = models.DistributedUnicoreModel( +# self.args, +# self._loss, +# process_group=self.data_parallel_process_group, +# device=self.device, +# ) +# else: +# self._wrapped_loss = self._loss +# return self._wrapped_loss + +# @property +# def model(self): +# if self._wrapped_model is None: +# if self.use_distributed_wrapper: +# self._wrapped_model = models.DistributedUnicoreModel( +# self.args, +# self._model, +# process_group=self.data_parallel_process_group, +# device=self.device, +# ) +# else: +# self._wrapped_model = self._model +# return self._wrapped_model + +# @property +# def optimizer(self): +# if self._optimizer is None: +# self._build_optimizer() +# return self._optimizer + +# @property +# def lr_scheduler(self): +# if self._lr_scheduler is None: +# self._build_optimizer() # this will initialize self._lr_scheduler +# return self._lr_scheduler + +# def _build_optimizer(self): +# params = [ +# (name, param) +# for name, param in chain( +# self.model.named_parameters(), +# self.loss.named_parameters(), +# ) +# if param.requires_grad +# ] +# if self.args.per_sample_clip_norm > 0: +# assert self.args.ddp_backend == "no_c10d" +# assert self.args.batch_size == 1 +# if self.args.fp16 or self.args.bf16: +# if self.cuda and torch.cuda.get_device_capability(0)[0] < 7: +# logger.info( +# "NOTE: your device does NOT support faster training with --fp16, " +# "please switch to FP32 which is likely to be faster" +# ) +# self._optimizer = optim.FP16Optimizer.build_optimizer(self.args, params) + +# if self.args.allreduce_fp32_grad: +# assert self.args.ddp_backend == "no_c10d" +# if self.args.per_sample_clip_norm > 0: +# assert self.args.allreduce_fp32_grad +# else: +# if self.cuda and torch.cuda.get_device_capability(0)[0] >= 7: +# logger.info("NOTE: your device may support faster training with --fp16") +# self._optimizer = optim.build_optimizer(self.args, params) + +# # We should initialize the learning rate scheduler immediately after +# # building the optimizer, so that the initial learning rate is set. 
+# self._lr_scheduler = lr_scheduler.build_lr_scheduler( +# self.args, +# self.optimizer, +# self._total_train_steps, +# ) +# self._lr_scheduler.step_update(0) + +# def state_dict(self): +# state_dict = { +# "args": self.args, +# "model": self.model.state_dict(), +# "loss": ( +# self.loss.state_dict() if utils.has_parameters(self.loss) else None +# ), +# "optimizer_history": (self._optim_history or []) +# + [ +# { +# "loss_name": self.get_loss().__class__.__name__, +# "optimizer_name": self.optimizer.__class__.__name__, +# "lr_scheduler_state": self.lr_scheduler.state_dict(), +# "num_updates": self.get_num_updates(), +# } +# ], +# "task_state": self.task.state_dict() if self.task is not None else {}, +# "extra_state": { +# "metrics": metrics.state_dict(), +# "previous_training_time": self.cumulative_training_time(), +# }, +# } +# if not self.args.no_save_optimizer_state: +# state_dict["last_optimizer_state"] = self.optimizer.state_dict() +# if self.ema is not None: +# state_dict["ema"] = self.ema.state_dict() +# return state_dict + +# def save_checkpoint(self, filename, extra_state): +# """Save all training state in a checkpoint file.""" +# logger.info(f"Saving checkpoint to {filename}") +# # call state_dict on all ranks in case it needs internal communication +# state_dict = utils.move_to_cpu(self.state_dict()) +# state_dict["extra_state"].update(extra_state) +# if self.should_save_checkpoint_on_current_rank: +# checkpoint_utils.torch_persistent_save( +# state_dict, +# filename, +# ) +# logger.info(f"Finished saving checkpoint to {filename}") + +# def load_checkpoint( +# self, +# filename, +# reset_optimizer=False, +# reset_lr_scheduler=False, +# reset_dataloader=False, +# optimizer_overrides=None, +# reset_meters=False, +# **passthrough_args, +# ): +# """ +# Load all training state from a checkpoint file. +# rank = 0 will load the checkpoint, and then broadcast it to all +# other ranks. 
+# """ +# extra_state, self._optim_history, last_optim_state = None, [], None + +# logger.info(f"Preparing to load checkpoint {filename}") +# is_distributed = self.data_parallel_world_size > 1 +# is_master = self.data_parallel_rank == 0 +# bexists = None +# if is_master: +# bexists = os.path.isfile(filename) +# if is_distributed: +# bexists = distributed_utils.broadcast_object( +# bexists, +# src_rank=0, +# group=self.data_parallel_process_group, +# dist_device=self.device, +# ) + +# had_loaded_model = False +# ema_loaded = False +# if bexists: +# state = None +# if is_master: +# state = checkpoint_utils.load_checkpoint_to_cpu( +# filename, +# ) +# if is_distributed: +# logger.info("Broadcast checkpoint from rank_0") +# state = distributed_utils.broadcast_object( +# state, +# src_rank=0, +# group=self.data_parallel_process_group, +# dist_device=self.device, +# ) +# last_optim_state = state.get("last_optimizer_state", None) +# ema_state = state.get("ema", None) + +# # load model parameters +# try: +# if self.args.load_from_ema: +# logger.info("loading ema state to model") +# errors = self.model.load_state_dict( +# ema_state["params"], strict=False, model_args=self.args +# ) +# ema_loaded = True +# else: +# errors = self.model.load_state_dict( +# state["model"], strict=False, model_args=self.args +# ) +# # save memory for later steps +# del state["model"] +# had_loaded_model = True + +# if errors.missing_keys: +# logger.warning( +# "Error in loading model state, missing_keys " +# + str(errors.missing_keys) +# ) +# if errors.unexpected_keys: +# logger.warning( +# "Error in loading model state, unexpected_keys " +# + str(errors.unexpected_keys) +# ) +# if utils.has_parameters(self.get_loss()): +# self.get_loss().load_state_dict(state["loss"], strict=True) +# del state["loss"] + +# except Exception: +# raise Exception( +# "Cannot load model parameters from checkpoint {}; " +# "please ensure that the architectures match.".format(filename) +# ) +# extra_state = state["extra_state"] if "extra_state" in state else None +# self._optim_history = ( +# state["optimizer_history"] if "optimizer_history" in state else None +# ) + +# if ( +# ema_state is not None +# and self.ema is not None +# and not self.args.load_from_ema +# ): +# logger.info(f"Loading EMA state...") +# self.ema.load_state_dict(ema_state) +# elif self.ema is not None and not ema_loaded: +# logger.info( +# f"Cannot find EMA state in checkpoint, load model weight to ema directly" +# ) +# self.ema = ExponentialMovingAverageModel( +# self.args, +# self._model, +# decay=self.ema.decay, +# is_flattened=(self.args.fp16 or self.args.bf16), +# ) + +# loaded_train_itr = False +# if extra_state is not None: +# itr_state = extra_state["train_iterator"] +# epoch = itr_state["epoch"] + +# if "previous_training_time" in extra_state: +# self._previous_training_time = extra_state["previous_training_time"] +# self._start_time = time.time() + +# if ( +# itr_state.get("version", 1) >= 2 +# and itr_state["iterations_in_epoch"] == 0 +# ): +# # reset meters at start of epoch +# reset_meters = True + +# if "metrics" in extra_state and not reset_meters: +# metrics.load_state_dict(extra_state["metrics"]) + +# # reset TimeMeters, since their start times don't make sense anymore +# for meter in metrics.get_meters("default"): +# if isinstance(meter, meters.TimeMeter): +# meter.reset() + +# if not reset_dataloader: +# # restore iterator from checkpoint +# epoch_itr = self.get_train_iterator( +# epoch=itr_state["epoch"], load_dataset=True, **passthrough_args +# ) +# 
epoch_itr.load_state_dict(itr_state) +# loaded_train_itr = True + +# if not loaded_train_itr: +# epoch_itr = self.get_train_iterator( +# epoch=1, load_dataset=True, **passthrough_args +# ) + +# self.init_total_train_steps(epoch_itr) + +# if last_optim_state is not None and not reset_optimizer: +# # rebuild optimizer after loading model, since params may have changed +# self._build_optimizer() + +# # only reload optimizer and lr_scheduler if they match +# last_optim = self._optim_history[-1] +# assert ( +# last_optim["loss_name"] == self.get_loss().__class__.__name__ +# ), f"Loss does not match; please reset the optimizer (--reset-optimizer). {last_optim['loss_name']} vs {self.get_loss().__class__.__name__}" +# assert ( +# last_optim["optimizer_name"] == self.optimizer.__class__.__name__ +# ), f"Optimizer does not match; please reset the optimizer (--reset-optimizer). {last_optim['optimizer_name']} vs {self.optimizer.__class__.__name__}" + +# if not reset_lr_scheduler: +# self.lr_scheduler.load_state_dict(last_optim["lr_scheduler_state"]) + +# self.optimizer.load_state_dict(last_optim_state, optimizer_overrides) + +# self.set_num_updates(last_optim["num_updates"]) + +# if had_loaded_model: +# if loaded_train_itr: +# logger.info( +# "Loaded checkpoint {} (epoch {} @ {} updates)".format( +# filename, epoch, self.get_num_updates() +# ) +# ) +# else: +# logger.info("Loaded checkpoint {}".format(filename)) +# elif ema_loaded: +# logger.info("Loaded ema state from checkpoint {}".format(filename)) +# else: +# logger.info("No existing checkpoint found {}".format(filename)) + +# self.lr_step(epoch_itr.epoch) + +# return extra_state, epoch_itr + +# def get_train_iterator( +# self, +# epoch, +# combine=True, +# load_dataset=True, +# data_selector=None, +# shard_batch_itr=True, +# disable_iterator_cache=False, +# ): +# """Return an EpochBatchIterator over the training set for a given epoch.""" +# if load_dataset: +# logger.info("loading train data for epoch {}".format(epoch)) +# self.task.load_dataset( +# self.args.train_subset, +# epoch=epoch, +# combine=combine, +# data_selector=data_selector, +# ) +# batch_iterator = self.task.get_batch_iterator( +# dataset=self.task.dataset(self.args.train_subset), +# batch_size=self.args.batch_size, +# ignore_invalid_inputs=True, +# required_batch_size_multiple=self.args.required_batch_size_multiple, +# seed=self.args.seed, +# num_shards=self.data_parallel_world_size if shard_batch_itr else 1, +# shard_id=self.data_parallel_rank if shard_batch_itr else 0, +# num_workers=self.args.num_workers, +# epoch=epoch, +# data_buffer_size=self.args.data_buffer_size, +# disable_iterator_cache=disable_iterator_cache, +# ) +# self.reset_dummy_batch(batch_iterator.first_batch) +# return batch_iterator + +# def init_total_train_steps(self, epoch_itr): +# if self.args.max_epoch > 0: +# self._total_train_steps = ( +# (len(epoch_itr) + 1) // self.args.update_freq[0] * self.args.max_epoch +# ) +# else: +# self._total_train_steps = self.args.max_update + +# def get_valid_iterator( +# self, +# subset, +# disable_iterator_cache=False, +# ): +# """Return an EpochBatchIterator over given validation subset for a given epoch.""" +# batch_iterator = self.task.get_batch_iterator( +# dataset=self.task.dataset(subset), +# batch_size=self.args.batch_size_valid, +# ignore_invalid_inputs=self.args.skip_invalid_size_inputs_valid_test, +# required_batch_size_multiple=self.args.required_batch_size_multiple, +# seed=self.args.seed, +# num_shards=self.data_parallel_world_size, +# 
shard_id=self.data_parallel_rank, +# num_workers=self.args.num_workers, +# # always pass a fixed "epoch" to keep validation data consistent +# # across training epochs +# epoch=1, +# data_buffer_size=self.args.data_buffer_size, +# disable_iterator_cache=disable_iterator_cache, +# ) +# # Using training data for dummy batch. If the following line is enabled, the dummy batch will be from validation data, +# # and cause OOM error for some corner case during training. So disable it. +# # self.reset_dummy_batch(batch_iterator.first_batch) +# return batch_iterator + +# def begin_epoch(self, epoch): +# """Called at the beginning of each epoch.""" +# logger.info("begin training epoch {}".format(epoch)) + +# self.lr_step_begin_epoch(epoch) + +# # task specific setup per epoch +# self.task.begin_epoch(epoch, self.get_model()) + +# def begin_valid_epoch(self, epoch): +# """Called at the beginning of each validation epoch.""" + +# # task specific setup per validation epoch +# self.task.begin_valid_epoch(epoch, self.get_model()) + +# def reset_dummy_batch(self, batch): +# self._dummy_batch = batch + +# @metrics.aggregate("train") +# def train_step(self, samples, raise_oom=False): +# """Do forward, backward and parameter update.""" +# self.model.train() +# self.loss.train() +# self.zero_grad() + +# metrics.log_start_time("train_wall", priority=800, round=2) + +# # forward and backward pass +# logging_outputs, sample_size, ooms = [], 0, 0 +# for i, sample in enumerate(samples): # delayed update loop +# sample, is_dummy_batch = self._prepare_sample(sample) + +# def maybe_no_sync(): +# """ +# Whenever *samples* contains more than one mini-batch, we +# want to accumulate gradients locally and only call +# all-reduce in the last backwards pass. +# """ +# if ( +# self.data_parallel_world_size > 1 +# and hasattr(self.model, "no_sync") +# and i < len(samples) - 1 +# ): +# return self.model.no_sync() +# else: +# return contextlib.ExitStack() # dummy contextmanager + +# try: +# with maybe_no_sync(): +# # use different seed for different rank in training, otherwise the dropout will be the same in different workers. 
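+#                     # (The composite seed below is int(hash((seed, num_updates, i, rank)) % 1e8), see
+#                     # utils.torch_seed, so every rank and every accumulation step i draws a distinct
+#                     # dropout mask while remaining reproducible for a fixed seed.)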
+# with utils.torch_seed( +# self.args.seed, +# self.get_num_updates(), +# i, +# self.data_parallel_rank, +# ): +# # forward and backward +# loss, sample_size_i, logging_output = self.task.train_step( +# sample=sample, +# model=self.model, +# loss=self.loss, +# optimizer=self.optimizer, +# update_num=self.get_num_updates(), +# ignore_grad=is_dummy_batch, +# ) +# del loss +# if self.args.per_sample_clip_norm > 0: +# self.optimizer.per_sample_clip_grad_norm( +# self.args.per_sample_clip_norm +# ) + +# logging_outputs.append(logging_output) +# sample_size += sample_size_i + +# # emptying the CUDA cache after the first step can +# # reduce the chance of OOM +# if self.cuda and self.get_num_updates() == 0: +# torch.cuda.empty_cache() +# except RuntimeError as e: +# if "out of memory" in str(e): +# self._log_oom(e) +# if raise_oom: +# raise e +# logger.warning( +# "attempting to recover from OOM in forward/backward pass" +# ) +# ooms += 1 +# self.zero_grad() +# if self.cuda: +# torch.cuda.empty_cache() +# if self.args.distributed_world_size == 1: +# return None +# else: +# raise e + +# if is_dummy_batch: +# if torch.is_tensor(sample_size): +# sample_size.zero_() +# else: +# sample_size *= 0.0 + +# if torch.is_tensor(sample_size): +# sample_size = sample_size.float() +# else: +# sample_size = float(sample_size) + +# local_sample_size = sample_size +# # gather logging outputs from all replicas +# if self._sync_stats(): +# train_time = self._local_cumulative_training_time() +# logging_outputs, ( +# sample_size, +# ooms, +# total_train_time, +# ) = self._aggregate_logging_outputs( +# logging_outputs, +# sample_size, +# ooms, +# train_time, +# ignore=is_dummy_batch, +# is_train=True, +# ) +# self._cumulative_training_time = ( +# total_train_time / self.data_parallel_world_size +# ) + +# overflow = False +# try: +# with torch.autograd.profiler.record_function("reduce-grads"): +# # reduce gradients across workers +# self.optimizer.all_reduce_grads(self.model) +# if utils.has_parameters(self.loss): +# self.optimizer.all_reduce_grads(self.loss) + +# with torch.autograd.profiler.record_function("multiply-grads"): +# # multiply gradients by (data_parallel_size / sample_size) since +# # DDP normalizes by the number of data parallel workers for +# # improved fp16 precision. +# # Thus we get (sum_of_gradients / sample_size) at the end. +# # In case of fp16, this step also undoes loss scaling. +# # (Debugging note: Some optimizers perform this scaling on the +# # fly, so inspecting model.parameters() or optimizer.params may +# # still show the original, unscaled gradients.) +# numer = self.data_parallel_world_size if self._sync_stats() else 1 + +# self.optimizer.multiply_grads(numer / (sample_size or 1.0)) +# # Note: (sample_size or 1.0) handles the case of a zero gradient, in a +# # way that avoids CPU/device transfers in case sample_size is a GPU or +# # TPU object. The assumption is that the gradient itself is also 0. 
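+#                 # Worked example (hypothetical numbers): with data_parallel_world_size = 8 and an
+#                 # aggregated sample_size of 4096, the gradient reduction above has already divided
+#                 # the summed gradients by 8, so multiplying by 8 / 4096 leaves
+#                 # sum_of_gradients / 4096, i.e. a gradient normalized by the global sample count.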
+ +# with torch.autograd.profiler.record_function("clip-grads"): +# # clip grads +# grad_norm = self.clip_grad_norm(self.args.clip_norm) + +# self._check_grad_norms(grad_norm) +# if not torch.isfinite(grad_norm).all(): +# # check local gradnorm single GPU case, trigger NanDetector +# raise FloatingPointError("gradients are Nan/Inf") + +# with torch.autograd.profiler.record_function("optimizer"): +# # fixed the seed in case for the stochastic rounding in different ranks +# with utils.torch_seed(self.args.seed, self.get_num_updates()): +# # take an optimization step +# self.task.optimizer_step( +# self.optimizer, +# model=self.model, +# update_num=self.get_num_updates(), +# ) +# if self.ema is not None: +# with torch.autograd.profiler.record_function("ema"): +# if self.args.fp16 or self.args.bf16: +# self.ema.update(self.optimizer.fp32_params) +# else: +# self.ema.update(self.model.named_parameters()) + +# except FloatingPointError: +# # re-run the forward and backward pass with hooks attached to print +# # out where it fails +# self.zero_grad() +# with NanDetector(self.get_model()): +# for i, sample in enumerate(samples): +# sample, _ = self._prepare_sample(sample) +# with utils.torch_seed( +# self.args.seed, +# self.get_num_updates(), +# i, +# self.data_parallel_rank, +# ): +# self.task.train_step( +# sample, +# self.model, +# self.loss, +# self.optimizer, +# self.get_num_updates(), +# ignore_grad=False, +# ) +# raise +# except OverflowError as e: +# overflow = True +# logger.info( +# f"NOTE: gradient overflow detected, ignoring gradient, {str(e)}" +# ) +# grad_norm = torch.tensor(0.0).cuda() +# self.zero_grad() +# except RuntimeError as e: +# if "out of memory" in str(e): +# self._log_oom(e) +# logger.error("OOM during optimization, irrecoverable") +# raise e + +# logging_output = None +# if not overflow: +# self.set_num_updates(self.get_num_updates() + 1) + +# if self.cuda and self.cuda_env is not None: +# # log minimum free memory over the iteration +# gb_used = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024 +# torch.cuda.reset_peak_memory_stats() +# gb_free = self.cuda_env.total_memory_in_GB - gb_used +# metrics.log_scalar("gb_free", gb_free, priority=1500, round=1, weight=0) + +# # log stats +# logging_output = self._reduce_and_log_stats( +# logging_outputs, +# sample_size, +# grad_norm, +# ) + +# # clear CUDA cache to reduce memory fragmentation +# if ( +# self.cuda +# and self.args.empty_cache_freq > 0 +# and ( +# (self.get_num_updates() + self.args.empty_cache_freq - 1) +# % self.args.empty_cache_freq +# ) +# == 0 +# ): +# torch.cuda.empty_cache() + +# if self.args.fp16: +# metrics.log_scalar( +# "loss_scale", +# self.optimizer.scaler.loss_scale, +# priority=700, +# round=4, +# weight=0, +# ) + +# metrics.log_stop_time("train_wall") +# return logging_output + +# @metrics.aggregate("valid") +# def valid_step(self, sample, raise_oom=False): +# """Do forward pass in evaluation mode.""" +# with torch.no_grad(): +# self.model.eval() +# self.loss.eval() + +# sample, is_dummy_batch = self._prepare_sample(sample) + +# try: +# _loss, sample_size, logging_output = self.task.valid_step( +# sample, self.model, self.loss +# ) +# except RuntimeError as e: +# if "out of memory" in str(e): +# self._log_oom(e) +# if not raise_oom: +# logger.warning( +# "ran out of memory in validation step, retrying batch" +# ) +# for p in self.model.parameters(): +# if p.grad is not None: +# p.grad = None # free some memory +# if self.cuda: +# torch.cuda.empty_cache() +# return self.valid_step(sample, 
raise_oom=True) +# raise e + +# logging_outputs = [logging_output] +# if is_dummy_batch: +# if torch.is_tensor(sample_size): +# sample_size.zero_() +# else: +# sample_size *= 0.0 + +# # gather logging outputs from all replicas +# if self.data_parallel_world_size > 1: +# logging_outputs, (sample_size,) = self._aggregate_logging_outputs( +# logging_outputs, +# sample_size, +# ignore=is_dummy_batch, +# is_train=False, +# ) + +# return logging_outputs + +# def zero_grad(self): +# self.optimizer.zero_grad() + +# def lr_step_begin_epoch(self, epoch): +# """Adjust the learning rate at the beginning of the epoch.""" +# self.lr_scheduler.step_begin_epoch(epoch) +# # prefer updating the LR based on the number of steps +# return self.lr_step_update() + +# def lr_step(self, epoch, val_loss=None): +# """Adjust the learning rate at the end of the epoch.""" +# self.lr_scheduler.step(epoch, val_loss) +# # prefer updating the LR based on the number of steps +# return self.lr_step_update() + +# def lr_step_update(self): +# """Update the learning rate after each update.""" +# new_lr = self.lr_scheduler.step_update(self.get_num_updates()) +# if isinstance(new_lr, dict): +# for k, v in new_lr.items(): +# metrics.log_scalar(f"lr_{k}", v, weight=0, priority=300) +# new_lr = new_lr.get("default", next(iter(new_lr.values()))) +# else: +# metrics.log_scalar("lr", new_lr, weight=0, priority=300) +# return new_lr + +# def get_lr(self): +# """Get the current learning rate.""" +# return self.optimizer.get_lr() + +# def get_model(self): +# """Get the (non-wrapped) model instance.""" +# return self._model + +# def get_loss(self): +# """Get the (non-wrapped) loss instance.""" +# return self._loss + +# def get_num_updates(self): +# """Get the number of parameters updates.""" +# return self._num_updates + +# def set_num_updates(self, num_updates): +# """Set the number of parameters updates.""" +# self._num_updates = num_updates +# self.lr_step_update() +# metrics.log_scalar("num_updates", self._num_updates, weight=0, priority=200) + +# def clip_grad_norm(self, clip_norm): +# return self.optimizer.clip_grad_norm(clip_norm) + +# def cumulative_training_time(self): +# if self._cumulative_training_time is None: +# # single GPU +# return self._local_cumulative_training_time() +# else: +# return self._cumulative_training_time + +# def _local_cumulative_training_time(self): +# """Aggregate training time in seconds.""" +# return time.time() - self._start_time + self._previous_training_time + +# def _prepare_sample(self, sample, is_dummy=False): +# if sample == "DUMMY": +# raise Exception( +# "Trying to use an uninitialized 'dummy' batch. This usually indicates " +# "that the total number of batches is smaller than the number of " +# "participating GPUs. Try reducing the batch size or using fewer GPUs." +# ) + +# if sample is None or len(sample) == 0: +# assert ( +# self._dummy_batch is not None and len(self._dummy_batch) > 0 +# ), "Invalid dummy batch: {}".format(self._dummy_batch) +# sample, _ = self._prepare_sample(self._dummy_batch, is_dummy=True) +# return sample, True + +# if self.cuda: +# sample = utils.move_to_cuda(sample) + +# def apply_half(t): +# if t.dtype is torch.float32: +# return t.half() +# return t + +# def apply_bfloat16(t): +# if t.dtype is torch.float32: +# return t.to(dtype=torch.bfloat16) +# return t + +# # Please manually convert data type by yourself. 
+# # if self.args.fp16: +# # sample = utils.apply_to_sample(apply_half, sample) + +# # if self.args.bf16: +# # sample = utils.apply_to_sample(apply_bfloat16, sample) + +# if self._dummy_batch == "DUMMY": +# self._dummy_batch = sample + +# return sample, False + +# def _sync_stats(self): +# # Return True if it's using multiple GPUs and DDP or multiple GPUs with +# if self.data_parallel_world_size == 1: +# return False +# else: +# return True + +# def _log_oom(self, exc): +# msg = "OOM: Ran out of memory with exception: {}".format(exc) +# logger.warning(msg) +# if torch.cuda.is_available() and hasattr(torch.cuda, "memory_summary"): +# for device_idx in range(torch.cuda.device_count()): +# logger.warning(torch.cuda.memory_summary(device=device_idx)) +# sys.stderr.flush() + +# def _aggregate_logging_outputs( +# self, +# logging_outputs: List[Dict[str, Any]], +# *extra_stats_to_sum, +# ignore=False, +# is_train=False, +# ): +# if self.task.__class__.logging_outputs_can_be_summed( +# self.get_loss(), is_train=is_train +# ): +# return self._fast_stat_sync_sum( +# logging_outputs, *extra_stats_to_sum, ignore=ignore +# ) +# else: +# return self._all_gather_list_sync( +# logging_outputs, *extra_stats_to_sum, ignore=ignore +# ) + +# def _all_gather_list_sync( +# self, +# logging_outputs: List[Dict[str, Any]], +# *extra_stats_to_sum, +# ignore=False, +# ): +# """ +# Sync logging outputs across workers. all_gather_list_sync is +# suitable when logging outputs are complex types. +# """ +# if ignore: +# logging_outputs = [] +# results = list( +# zip( +# *distributed_utils.all_gather_list( +# [logging_outputs] + list(extra_stats_to_sum), +# max_size=getattr(self.args, "all_gather_list_size", 16384), +# group=self.data_parallel_process_group, +# ) +# ) +# ) +# logging_outputs, extra_stats_to_sum = results[0], results[1:] +# logging_outputs = list(chain.from_iterable(logging_outputs)) +# extra_stats_to_sum = [sum(s) for s in extra_stats_to_sum] +# return logging_outputs, extra_stats_to_sum + +# def _fast_stat_sync_sum( +# self, +# logging_outputs: List[Dict[str, Any]], +# *extra_stats_to_sum, +# ignore=False, +# ): +# """ +# Sync logging outputs across workers. fast_stat_sync_sum is +# faster than all_gather_list_sync, but is only suitable when +# logging outputs are scalars and can be summed. Note that +# *logging_outputs* cannot contain any nested dicts/lists. 
+# """ +# data = {} +# for i, stat in enumerate(extra_stats_to_sum): +# data["extra_stats_" + str(i)] = stat +# if len(logging_outputs) > 0: +# log_keys = list(logging_outputs[0].keys()) +# for k in log_keys: +# if not ignore: +# v = sum(log[k] for log in logging_outputs if k in log) +# else: +# v = logging_outputs[0][k] +# v = torch.zeros_like(v) if torch.is_tensor(v) else 0 +# data["logging_outputs_" + k] = v +# else: +# log_keys = None + +# data = distributed_utils.all_reduce_dict( +# data, device=self.device, group=self.data_parallel_process_group +# ) + +# extra_stats_to_sum = [ +# data["extra_stats_" + str(i)] for i in range(len(extra_stats_to_sum)) +# ] +# if log_keys is not None: +# logging_outputs = [{k: data["logging_outputs_" + k] for k in log_keys}] +# else: +# logging_outputs = [] +# return logging_outputs, extra_stats_to_sum + +# def _check_grad_norms(self, grad_norm): +# """Check that grad norms are consistent across workers.""" +# if self._grad_norm_buf is not None: +# self._grad_norm_buf.zero_() +# self._grad_norm_buf[self.data_parallel_rank] = grad_norm +# distributed_utils.all_reduce( +# self._grad_norm_buf, group=self.data_parallel_process_group +# ) + +# def is_consistent(tensor): +# max_abs_diff = torch.max(torch.abs(tensor - tensor[0])) +# return ( +# torch.isfinite(tensor).all() +# and (max_abs_diff / (tensor[0] + 1e-6) < 1e-6).all() +# ) + +# if not is_consistent(self._grad_norm_buf): +# pretty_detail = "\n".join( +# "rank {:3d} = {:.8f}".format(r, n) +# for r, n in enumerate(self._grad_norm_buf.tolist()) +# ) +# error_detail = "grad_norm across the workers:\n{}\n".format( +# pretty_detail +# ) +# # use FloatingPointError to trigger NanDetector +# raise FloatingPointError( +# "Fatal error: gradients are inconsistent between workers. " +# "Try --ddp-backend=legacy_ddp. " +# "Or are you mixing up different generation of GPUs in training?" +# + "\n" +# + "-" * 80 +# + "\n{}\n".format(error_detail) +# + "-" * 80 +# ) + +# def _reduce_and_log_stats(self, logging_outputs, sample_size, grad_norm=None): +# if grad_norm is not None and ( +# not torch.is_tensor(grad_norm) or torch.isfinite(grad_norm) +# ): +# metrics.log_speed("ups", 1.0, priority=100, round=2) +# metrics.log_scalar("gnorm", grad_norm, priority=400, round=3) +# if self.args.clip_norm > 0: +# metrics.log_scalar( +# "clip", +# torch.where( +# grad_norm > self.args.clip_norm, +# grad_norm.new_tensor(100), +# grad_norm.new_tensor(0), +# ), +# priority=500, +# round=1, +# ) + +# with metrics.aggregate() as agg: +# if logging_outputs is not None: +# self.task.reduce_metrics(logging_outputs, self.get_loss()) +# del logging_outputs + +# # extra warning for losses that don't properly log a loss value +# if "loss" not in agg: +# if "loss" not in self._warn_once: +# self._warn_once.add("loss") +# logger.warning( +# "Loss.reduce_metrics did not log a 'loss' value, " +# "which may break some functionality" +# ) +# metrics.log_scalar("loss", -1) + +# logging_output = agg.get_smoothed_values() +# logging_output["sample_size"] = sample_size +# for key_to_delete in ["ppl", "wps", "wpb", "bsz"]: +# if key_to_delete in logging_output: +# del logging_output[key_to_delete] +# return logging_output + + +# def _catalog_shared_params(module, memo=None, prefix=""): +# if memo is None: +# first_call = True +# memo = {} +# else: +# first_call = False +# for name, param in module._parameters.items(): +# if param is None: +# continue +# param_prefix = prefix + ("." 
if prefix else "") + name
+# if param not in memo:
+# memo[param] = []
+# memo[param].append(param_prefix)
+# for name, m in module._modules.items():
+# if m is None:
+# continue
+# submodule_prefix = prefix + ("." if prefix else "") + name
+# _catalog_shared_params(m, memo, submodule_prefix)
+# if first_call:
+# return [x for x in memo.values() if len(x) > 1]
+
+
+# def _get_module_by_path(module, path):
+# path = path.split(".")
+# for name in path:
+# module = getattr(module, name)
+# return module
+
+
+# def _set_module_by_path(module, path, value):
+# path = path.split(".")
+# for name in path[:-1]:
+# module = getattr(module, name)
+# setattr(module, path[-1], value)
diff --git a/MindChemistry/applications/Uni-Mol/unicore/utils.py b/MindChemistry/applications/Uni-Mol/unicore/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e4675d5a42bea78073fb7488686591b4d442f5b
--- /dev/null
+++ b/MindChemistry/applications/Uni-Mol/unicore/utils.py
@@ -0,0 +1,945 @@
+# Copyright (c) DP Technology.
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import contextlib
+import importlib
+import logging
+import os
+import sys
+import warnings
+import numpy as np
+from copy import deepcopy
+from functools import partial
+from typing import List, Callable, Any, Dict
+import mindspore as ms
+import mindspore.ops as F
+from mindspore import ops, Tensor
+
+try:
+    import unicore_fused_multi_tensor
+
+    HAS_MULTI_TENSOR = True
+except ImportError:
+    print("fused_multi_tensor is not installed correctly")
+    HAS_MULTI_TENSOR = False
+
+try:
+    import unicore_fused_rounding
+
+    HAS_FUSED_ROUNDING = True
+except ImportError:
+    print("fused_rounding is not installed correctly")
+    HAS_FUSED_ROUNDING = False
+
+# Device-capability check (GPU only; skipped on Ascend to avoid errors there)
+device_target = ms.get_context("device_target")
+if device_target == "GPU":
+    try:
+        gpu_capability = ms.get_device_capability()[0]
+        if gpu_capability < 7:
+            HAS_MULTI_TENSOR = False
+            HAS_FUSED_ROUNDING = False
+    except Exception:
+        HAS_MULTI_TENSOR = False
+        HAS_FUSED_ROUNDING = False
+else:
+    # The GPU fused kernels are not available on Ascend/NPU
+    HAS_MULTI_TENSOR = False
+    HAS_FUSED_ROUNDING = False
+
+logger = logging.getLogger(__name__)
+
+
+def apply_to_sample(f, sample):
+    if hasattr(sample, "__len__") and len(sample) == 0:
+        return {}
+
+    def _apply(x):
+        if isinstance(x, Tensor):  # MindSpore equivalent of torch.is_tensor
+            return f(x)
+        elif isinstance(x, dict):
+            return {key: _apply(value) for key, value in x.items()}
+        elif isinstance(x, list):
+            return [_apply(x) for x in x]
+        elif isinstance(x, tuple):
+            return tuple(_apply(x) for x in x)
+        elif isinstance(x, set):
+            return {_apply(x) for x in x}
+        else:
+            return x
+
+    # Apply `f` to every tensor in the (possibly nested) sample structure.
+    return _apply(sample)
+
+
+def move_to_device(sample, device=None):
+    """Replaces move_to_cuda; MindSpore addresses devices uniformly (Ascend/GPU)."""
+    device = device or ms.get_context("device_id")  # current device id
+    if isinstance(device, int):
+        # Device string format: "Ascend:0" on Ascend, "GPU:0" on GPU
+        device = f"{device_target}:{device}"
+
+    def _move_to_device(tensor):
+        return tensor.to(device)  # Tensor.to supports device migration
+
+    return apply_to_sample(_move_to_device, sample)
+
+
+def move_to_cpu(sample):
+    """Move a sample to CPU, upcasting half-precision tensors."""
+    def _move_to_cpu(tensor):
+        # CPU support for half precision is limited in MindSpore; upcast to float32.
+        if tensor.dtype in {ms.bfloat16, ms.float16}:
+            tensor = tensor.astype(ms.float32)
+        return tensor.to("CPU")
+
+    return apply_to_sample(_move_to_cpu, sample)
+
+
+def 
multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> Tensor:
+    per_device_grads = {}
+    norms = []
+    for grad in grads:
+        # Ascend-safe device lookup (a per-tensor .device attribute may not exist)
+        device = grad.device if hasattr(grad, "device") else f"{device_target}:{ms.get_context('device_id')}"
+        dtype = grad.dtype
+        if device not in per_device_grads:
+            per_device_grads[device] = {}
+        if dtype not in per_device_grads[device]:
+            per_device_grads[device][dtype] = []
+        per_device_grads[device][dtype].append(grad)
+
+    for device in per_device_grads.keys():
+        for dtype in per_device_grads[device].keys():
+            cur_grads = per_device_grads[device][dtype]
+            # Use the fused kernel only on GPU; fall back to ops.norm elsewhere (e.g. Ascend)
+            if HAS_MULTI_TENSOR and device.startswith("GPU"):
+                norm = unicore_fused_multi_tensor.l2norm(chunk_size, [cur_grads])
+                norms.append(norm)
+            else:
+                # ops.norm replaces torch.norm
+                norms += [ops.norm(g, ord=2, dtype=ms.float32) for g in cur_grads]
+
+    # Stack the per-tensor norms before reducing, mirroring torch.stack in the original
+    total_norm = ops.norm(ops.stack(norms), ord=2, dtype=ms.float32)
+    return total_norm
+
+
+# @ms.no_grad
+def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> Tensor:
+    if isinstance(params, Tensor):
+        params = [params]
+    params = list(params)
+    # Keep only parameters that actually carry a gradient
+    grads = [p.grad for p in filter(lambda p: p.grad is not None, params)]
+    if len(grads) == 0:
+        # Nothing to clip; return a scalar zero, like new_tensor(0.0) in the PyTorch version
+        if len(params) > 0:
+            return Tensor(0.0, dtype=ms.float32)
+        else:
+            return Tensor(0.0)
+
+    if len(grads) == 1:
+        total_norm = ops.norm(grads[0], ord=2, dtype=ms.float32)
+    else:
+        total_norm = multi_tensor_total_norm(grads)
+
+    if aggregate_norm_fn is not None:
+        total_norm = aggregate_norm_fn(total_norm)
+
+    if max_norm > 0:
+        max_norm = float(max_norm)
+        clip_coef = (max_norm / (total_norm + 1e-6)).clip(max=1.0)  # clip replaces torch's clamp_
+        for g in grads:
+            # NOTE: unlike torch's mul_, this may not scale the gradient in place in MindSpore
+            g *= clip_coef
+    return total_norm
+
+
+def import_user_module(args):
+    """Import a user-supplied module; same logic as the PyTorch version."""
+    module_path = getattr(args, "user_dir", None)
+    if module_path is not None:
+        module_path = os.path.abspath(args.user_dir)
+        if not os.path.exists(module_path) and not os.path.isfile(
+            os.path.dirname(module_path)
+        ):
+            unicore_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir)
+            if os.path.exists(unicore_rel_path):
+                module_path = unicore_rel_path
+            else:
+                unicore_rel_path = os.path.join(
+                    os.path.dirname(__file__), "..", args.user_dir
+                )
+                if os.path.exists(unicore_rel_path):
+                    module_path = unicore_rel_path
+                else:
+                    raise FileNotFoundError(module_path)
+
+        # ensure that user modules are only imported once
+        import_user_module.memo = getattr(import_user_module, "memo", set())
+        if module_path not in import_user_module.memo:
+            import_user_module.memo.add(module_path)
+
+            module_parent, module_name = os.path.split(module_path)
+            if module_name not in sys.modules:
+                sys.path.insert(0, module_parent)
+                importlib.import_module(module_name)
+            else:
+                raise ImportError(
+                    "Failed to import --user-dir={} because the corresponding module name "
+                    "({}) is not globally unique. 
Please rename the directory to " + "something unique and try again.".format(module_path, module_name) + ) + + +def get_activation_fn(activation: str) -> Callable: + """获取激活函数,替换为MindSpore实现""" + if activation == "relu": + return F.relu + elif activation == "gelu": + return F.gelu + elif activation == "tanh": + return ops.tanh # MindSpore的tanh在ops中 + elif activation == "linear": + return lambda x: x + else: + raise RuntimeError(f"--activation-fn {activation} not supported") + + +def get_available_activation_fns() -> List: + return [ + "relu", + "gelu", + "tanh", + "linear", + ] + + +def has_parameters(module): + """检查模块是否有可训练参数""" + try: + next(module.parameters()) + return True + except StopIteration: + return False + + +def get_rng_state(): + """获取随机数状态,替换为MindSpore的实现(适配多设备)""" + state = {"mindspore_rng_state": ms.get_rng_state()} # MindSpore的全局随机状态 + if device_target == "GPU": + state["gpu_rng_state"] = ms.random.get_rng_state() # GPU随机状态 + elif device_target == "Ascend": + # Ascend环境随机状态获取(MindSpore 2.x支持) + state["ascend_rng_state"] = ms.random.get_rng_state() + return state + + +def set_rng_state(state): + """设置随机数状态(适配多设备)""" + ms.set_rng_state(state["mindspore_rng_state"]) + if device_target == "GPU" and "gpu_rng_state" in state: + ms.random.set_rng_state(state["gpu_rng_state"]) + elif device_target == "Ascend" and "ascend_rng_state" in state: + ms.random.set_rng_state(state["ascend_rng_state"]) + + +@contextlib.contextmanager +def mindspore_seed(seed, *addl_seeds): + """MindSpore的种子上下文管理器,替换torch_seed(修复numpy未导入问题)""" + if seed is None: + yield + return + + def check_seed(s): + assert isinstance(s, (int, np.int32, np.int64)) # 依赖numpy,已新增导入 + + check_seed(seed) + if len(addl_seeds) > 0: + for s in addl_seeds: + check_seed(s) + seed = int(hash((seed, *addl_seeds)) % 1e8) + + state = get_rng_state() + ms.set_seed(seed) # 设置全局种子 + if device_target == "GPU": + ms.random.set_seed(seed) # GPU种子 + elif device_target == "Ascend": + ms.random.set_seed(seed) # Ascend种子 + try: + yield + finally: + set_rng_state(state) + + +class DeviceEnvironment(object): + """设备环境类(移除GPU硬编码,适配Ascend/NPU)""" + def __init__(self): + cur_device = ms.get_context("device_id") + # 多设备兼容:获取设备属性(GPU/Ascend通用) + try: + prop = ms.get_device_properties(f"{device_target}:{cur_device}") + except: + # 兜底:若无法获取属性,填充默认值 + prop = { + "name": f"{device_target} Device {cur_device}", + "compute_capability": (0, 0) if device_target != "GPU" else (7, 0), + "total_memory": 0 + } + + self.name = prop["name"] + self.major = prop["compute_capability"][0] + self.minor = prop["compute_capability"][1] + # 内存转换为GB(MindSpore设备属性内存单位为字节) + self.total_memory_in_GB = prop["total_memory"] / (1024 ** 3) if prop["total_memory"] > 0 else 0.0 + + @staticmethod + def pretty_print_device_env_list(device_env_list): + num_workers = len(device_env_list) + device_type = device_env_list[0].name.split()[0] if device_env_list else "Device" + center = f"{device_type} environments for all {num_workers} workers" + banner_len = 40 - len(center) // 2 + first_line = "*" * banner_len + center + "*" * banner_len + logger.info(first_line) + for r, env in enumerate(device_env_list): + logger.info( + f"rank {r:3d}: " + f"capabilities = {env.major:2d}.{env.minor:<2d} ; " + f"total memory = {env.total_memory_in_GB:.3f} GB ; " + f"name = {env.name:40s}" + ) + logger.info(first_line) + + +def csv_str_list(x): + return x.split(",") + + +def eval_str_list(x, type=float): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + try: + return list(map(type, 
x)) + except TypeError: + return [type(x)] + + +def eval_str_dict(x, type=dict): + if x is None: + return None + if isinstance(x, str): + x = eval(x) + return x + + +def eval_bool(x, default=False): + if x is None: + return default + try: + return bool(eval(x)) + except TypeError: + return default + + +def checkpoint_sequential( + functions, + input, + enabled=True, +): + """MindSpore 2.6梯度检查点实现(使用官方推荐的ops.checkpoint)""" + def wrap_tuple(a): + return (a,) if not isinstance(a, tuple) else a + + def exec(func, a): + return wrap_tuple(func(*a)) + + def get_wrap_exec(func): + def wrap_exec(*a): + return exec(func, a) + return wrap_exec + + input = wrap_tuple(input) + is_grad_enabled = ms.is_grad_enabled() # 替换torch.is_grad_enabled() + + if enabled and is_grad_enabled: + for func in functions: + # MindSpore 2.6推荐使用ops.checkpoint实现梯度检查点 + wrap_func = get_wrap_exec(func) + # 梯度检查点:preserve_graph=True确保反向计算正常 + input = ops.checkpoint(wrap_func, *input, preserve_graph=True) + else: + for func in functions: + input = exec(func, input) + return input + + +def permute_final_dims(tensor: Tensor, inds: List[int]): + """维度重排,替换torch.permute(MindSpore transpose适配)""" + zero_index = -1 * len(inds) + first_inds = list(range(len(tensor.shape[:zero_index]))) + permutation = first_inds + [zero_index + i for i in inds] + return tensor.transpose(permutation) + + +def flatten_final_dims(t: Tensor, num_dims: int): + """展平最后几个维度,替换torch.reshape(MindSpore reshape兼容)""" + return t.reshape(t.shape[:-num_dims] + (-1,)) + + +def masked_mean(mask, value, dim, eps=1e-10): + """带掩码的均值计算(MindSpore broadcast_to适配)""" + mask = mask.broadcast_to(value.shape) # 替换torch.expand + sum_val = ops.sum(mask * value, axis=dim) # axis对应PyTorch的dim + sum_mask = ops.sum(mask, axis=dim) + return sum_val / (eps + sum_mask) + + +def dict_multimap(fn, dicts): + first = dicts[0] + new_dict = {} + for k, v in first.items(): + all_v = [d[k] for d in dicts] + if isinstance(v, dict): + new_dict[k] = dict_multimap(fn, all_v) + else: + new_dict[k] = fn(all_v) + return new_dict + + +def one_hot(x, num_classes, dtype=ms.float32): + """替换torch.one_hot实现(MindSpore one_hot适配)""" + # MindSpore one_hot要求输入为int32,深度为num_classes + x_one_hot = ops.one_hot(x.astype(ms.int32), num_classes, 1.0, 0.0, dtype=dtype) + return x_one_hot + + +def batched_gather(data, inds, dim=0, num_batch_dims=0): + """批量gather操作,替换torch.gather(MindSpore索引逻辑适配)""" + assert dim < 0 or dim - num_batch_dims >= 0 + ranges = [] + for i, s in enumerate(data.shape[:num_batch_dims]): + r = Tensor(range(s), dtype=ms.int32) + r = r.reshape((*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) + ranges.append(r) + + remaining_dims = [slice(None) for _ in range(len(data.shape) - num_batch_dims)] + gather_dim = dim - num_batch_dims if dim >= 0 else dim + remaining_dims[gather_dim] = inds + ranges.extend(remaining_dims) + return data[ranges] + + +def dict_map(fn, dic, leaf_type): + new_dict = {} + for k, v in dic.items(): + if isinstance(v, dict): + new_dict[k] = dict_map(fn, v, leaf_type) + else: + new_dict[k] = tree_map(fn, v, leaf_type) + return new_dict + + +def tree_map(fn, tree, leaf_type): + if isinstance(tree, dict): + return dict_map(fn, tree, leaf_type) + elif isinstance(tree, list): + return [tree_map(fn, x, leaf_type) for x in tree] + elif isinstance(tree, tuple): + return tuple([tree_map(fn, x, leaf_type) for x in tree]) + elif isinstance(tree, leaf_type): + try: + return fn(tree) + except: + raise ValueError(f"cannot apply {fn} on {tree}.") + else: + raise ValueError(f"{type(tree)} 
not supported") + + +tensor_tree_map = partial(tree_map, leaf_type=Tensor) + + +def fp32_to_bf16_sr(t, o): + """FP32转BF16(带随机舍入,适配Ascend环境)""" + # 仅GPU环境使用融合库,Ascend环境直接转换 + if HAS_FUSED_ROUNDING and device_target == "GPU" and hasattr(t, "device") and t.device.startswith("GPU"): + unicore_fused_rounding.fp32_to_bf16_sr(t, o) + else: + # MindSpore随机舍入逻辑(Ascend兼容) + r = (ops.rand(t.shape, device=t.device, dtype=ms.float32) - 0.5) / 256 + m, e = ops.frexp(t) # 替换torch.frexp + t = t + ops.ldexp(r, e) # 替换torch.ldexp + o.copy_(t.astype(ms.bfloat16)) # MindSpore Tensor.copy_适配 + + +def set_jit_fusion_options(): + """MindSpore的JIT融合选项设置(适配Ascend/GPU)""" + # Ascend环境推荐启用图核融合,GPU环境启用O2优化 + if device_target == "Ascend": + ms.set_context(jit_level='O2', enable_graph_kernel=True) + elif device_target == "GPU": + ms.set_context(jit_level='O2', enable_graph_kernel=False) + + +@contextlib.contextmanager +def validate_with_ema(trainer, ema=False): + if not ema: + yield + return + _wrapped_model = trainer._wrapped_model + trainer._wrapped_model = deepcopy(trainer.ema.model_ema) + if trainer.args.fp16: + trainer._wrapped_model = trainer._wrapped_model.astype(ms.float16) # 替换PyTorch的half() + elif trainer.args.bf16: + trainer._wrapped_model = trainer._wrapped_model.astype(ms.bfloat16) # 替换PyTorch的bfloat16() + + try: + yield + finally: + del trainer._wrapped_model + trainer._wrapped_model = _wrapped_model +# import contextlib +# import importlib +# import logging +# import os +# import sys +# import warnings +# from copy import deepcopy +# from functools import partial +# from typing import List, Callable, Any, Dict +# import torch +# import torch.utils.checkpoint +# import torch.nn.functional as F + +# try: +# import unicore_fused_multi_tensor + +# HAS_MULTI_TENSOR = True +# except: +# print("fused_multi_tensor is not installed corrected") +# HAS_MULTI_TENSOR = False + +# try: +# import unicore_fused_rounding + +# HAS_FUSED_ROUNDING = True +# except: +# print("fused_rounding is not installed corrected") +# HAS_FUSED_ROUNDING = False + +# if not torch.cuda.is_available() or torch.cuda.get_device_capability()[0] < 7: +# HAS_MULTI_TENSOR = False +# HAS_FUSED_ROUNDING = False + +# logger = logging.getLogger(__name__) + + +# def apply_to_sample(f, sample): +# if hasattr(sample, "__len__") and len(sample) == 0: +# return {} + +# def _apply(x): +# if torch.is_tensor(x): +# return f(x) +# elif isinstance(x, dict): +# return {key: _apply(value) for key, value in x.items()} +# elif isinstance(x, list): +# return [_apply(x) for x in x] +# elif isinstance(x, tuple): +# return tuple(_apply(x) for x in x) +# elif isinstance(x, set): +# return {_apply(x) for x in x} +# else: +# return x + +# return _apply(sample) + + +# def move_to_cuda(sample, device=None): +# device = device or torch.cuda.current_device() + +# def _move_to_cuda(tensor): +# # non_blocking is ignored if tensor is not pinned, so we can always set +# # to True (see github.com/PyTorchLightning/pytorch-lightning/issues/620) +# return tensor.to(device=device, non_blocking=True) + +# return apply_to_sample(_move_to_cuda, sample) + + +# def move_to_cpu(sample): + +# def _move_to_cpu(tensor): +# # PyTorch has poor support for half tensors (float16) on CPU. +# # Move any such tensors to float32. 
+# if tensor.dtype in {torch.bfloat16, torch.float16}: +# tensor = tensor.to(dtype=torch.float32) +# return tensor.cpu() + +# return apply_to_sample(_move_to_cpu, sample) + + +# def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor: +# per_device_grads = {} +# norms = [] +# for grad in grads: +# device = grad.device +# dtype = grad.dtype +# if device not in per_device_grads: +# per_device_grads[device] = {} +# if dtype not in per_device_grads[device]: +# per_device_grads[device][dtype] = [] +# per_device_grads[device][dtype].append(grad) +# for device in per_device_grads.keys(): +# for dtype in per_device_grads[device].keys(): +# cur_grads = per_device_grads[device][dtype] +# if HAS_MULTI_TENSOR and device.type == "cuda": +# norm = unicore_fused_multi_tensor.l2norm(chunk_size, [cur_grads]) +# norms.append(norm) +# else: +# norms += [torch.norm(g, p=2, dtype=torch.float32) for g in cur_grads] +# total_norm = torch.norm(torch.stack(norms), p=2, dtype=torch.float32) +# return total_norm + + +# @torch.no_grad() +# def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor: +# if isinstance(params, torch.Tensor): +# params = [params] +# params = list(params) +# grads = [p.grad.detach() for p in filter(lambda p: p.grad is not None, params)] +# if len(grads) == 0: +# if len(params) > 0: +# return params[0].new_tensor(0.0) +# else: +# return torch.tensor(0.0) + +# if len(grads) == 1: +# total_norm = torch.norm(grads[0], p=2, dtype=torch.float32) +# else: +# total_norm = multi_tensor_total_norm(grads) + +# if aggregate_norm_fn is not None: +# total_norm = aggregate_norm_fn(total_norm) + +# if max_norm > 0: +# max_norm = float(max_norm) +# clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1) +# for g in grads: +# g.mul_(clip_coef) +# return total_norm + + +# def import_user_module(args): +# module_path = getattr(args, "user_dir", None) +# if module_path is not None: +# module_path = os.path.abspath(args.user_dir) +# if not os.path.exists(module_path) and not os.path.isfile( +# os.path.dirname(module_path) +# ): +# unicore_rel_path = os.path.join(os.path.dirname(__file__), args.user_dir) +# if os.path.exists(unicore_rel_path): +# module_path = unicore_rel_path +# else: +# unicore_rel_path = os.path.join( +# os.path.dirname(__file__), "..", args.user_dir +# ) +# if os.path.exists(unicore_rel_path): +# module_path = unicore_rel_path +# else: +# raise FileNotFoundError(module_path) + +# # ensure that user modules are only imported once +# import_user_module.memo = getattr(import_user_module, "memo", set()) +# if module_path not in import_user_module.memo: +# import_user_module.memo.add(module_path) + +# module_parent, module_name = os.path.split(module_path) +# if module_name not in sys.modules: +# sys.path.insert(0, module_parent) +# importlib.import_module(module_name) +# else: +# raise ImportError( +# "Failed to import --user-dir={} because the corresponding module name " +# "({}) is not globally unique. 
Please rename the directory to " +# "something unique and try again.".format(module_path, module_name) +# ) + + +# def get_activation_fn(activation: str) -> Callable: +# """Returns the activation function corresponding to `activation`""" + +# if activation == "relu": +# return F.relu +# elif activation == "gelu": +# return F.gelu +# elif activation == "tanh": +# return torch.tanh +# elif activation == "linear": +# return lambda x: x +# else: +# raise RuntimeError("--activation-fn {} not supported".format(activation)) + + +# def get_available_activation_fns() -> List: +# return [ +# "relu", +# "gelu", +# "tanh", +# "linear", +# ] + + +# def has_parameters(module): +# try: +# next(module.parameters()) +# return True +# except StopIteration: +# return False + + +# def get_rng_state(): +# state = {"torch_rng_state": torch.get_rng_state()} +# if torch.cuda.is_available(): +# state["cuda_rng_state"] = torch.cuda.get_rng_state() +# return state + + +# def set_rng_state(state): +# torch.set_rng_state(state["torch_rng_state"]) +# if torch.cuda.is_available(): +# torch.cuda.set_rng_state(state["cuda_rng_state"]) + + +# @contextlib.contextmanager +# def torch_seed(seed, *addl_seeds): +# """Context manager which seeds the NumPy PRNG with the specified seed and +# restores the state afterward""" +# if seed is None: +# yield +# return + +# def check_seed(s): +# assert type(s) == int or type(s) == np.int32 or type(s) == np.int64 + +# check_seed(seed) +# if len(addl_seeds) > 0: +# for s in addl_seeds: +# check_seed(s) +# seed = int(hash((seed, *addl_seeds)) % 1e8) +# state = get_rng_state() +# torch.manual_seed(seed) +# if torch.cuda.is_available(): +# torch.cuda.manual_seed(seed) +# try: +# yield +# finally: +# set_rng_state(state) + + +# class CudaEnvironment(object): +# def __init__(self): +# cur_device = torch.cuda.current_device() +# prop = torch.cuda.get_device_properties("cuda:{}".format(cur_device)) +# self.name = prop.name +# self.major = prop.major +# self.minor = prop.minor +# self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024 + +# @staticmethod +# def pretty_print_cuda_env_list(cuda_env_list): +# """ +# Given a list of CudaEnviorments, pretty print them +# """ +# num_workers = len(cuda_env_list) +# center = "CUDA enviroments for all {} workers".format(num_workers) +# banner_len = 40 - len(center) // 2 +# first_line = "*" * banner_len + center + "*" * banner_len +# logger.info(first_line) +# for r, env in enumerate(cuda_env_list): +# logger.info( +# "rank {:3d}: ".format(r) +# + "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor) +# + "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB) +# + "name = {:40s}".format(env.name) +# ) +# logger.info(first_line) + + +# def csv_str_list(x): +# return x.split(",") + + +# def eval_str_list(x, type=float): +# if x is None: +# return None +# if isinstance(x, str): +# x = eval(x) +# try: +# return list(map(type, x)) +# except TypeError: +# return [type(x)] + + +# def eval_str_dict(x, type=dict): +# if x is None: +# return None +# if isinstance(x, str): +# x = eval(x) +# return x + + +# def eval_bool(x, default=False): +# if x is None: +# return default +# try: +# return bool(eval(x)) +# except TypeError: +# return default + + +# def checkpoint_sequential( +# functions, +# input, +# enabled=True, +# ): +# def wrap_tuple(a): +# return (a,) if type(a) is not tuple else a + +# def exec(func, a): +# return wrap_tuple(func(*a)) + +# def get_wrap_exec(func): +# def wrap_exec(*a): +# return exec(func, a) + +# return wrap_exec + +# 
input = wrap_tuple(input) + +# is_grad_enabled = torch.is_grad_enabled() + +# if enabled and is_grad_enabled: +# for func in functions: +# input = torch.utils.checkpoint.checkpoint(get_wrap_exec(func), *input) +# else: +# for func in functions: +# input = exec(func, input) +# return input + + +# def permute_final_dims(tensor: torch.Tensor, inds: List[int]): +# zero_index = -1 * len(inds) +# first_inds = list(range(len(tensor.shape[:zero_index]))) +# return tensor.permute(first_inds + [zero_index + i for i in inds]) + + +# def flatten_final_dims(t: torch.Tensor, num_dims: int): +# return t.reshape(t.shape[:-num_dims] + (-1,)) + + +# def masked_mean(mask, value, dim, eps=1e-10): +# mask = mask.expand(*value.shape) +# return torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) + + +# def dict_multimap(fn, dicts): +# first = dicts[0] +# new_dict = {} +# for k, v in first.items(): +# all_v = [d[k] for d in dicts] +# if type(v) is dict: +# new_dict[k] = dict_multimap(fn, all_v) +# else: +# new_dict[k] = fn(all_v) + +# return new_dict + + +# def one_hot(x, num_classes, dtype=torch.float32): +# x_one_hot = torch.zeros(*x.shape, num_classes, dtype=dtype, device=x.device) +# x_one_hot.scatter_(-1, x.long().unsqueeze(-1), 1) +# return x_one_hot + + +# def batched_gather(data, inds, dim=0, num_batch_dims=0): +# assert dim < 0 or dim - num_batch_dims >= 0 +# ranges = [] +# for i, s in enumerate(data.shape[:num_batch_dims]): +# r = torch.arange(s) +# r = r.view(*(*((1,) * i), -1, *((1,) * (len(inds.shape) - i - 1)))) +# ranges.append(r) + +# remaining_dims = [slice(None) for _ in range(len(data.shape) - num_batch_dims)] +# remaining_dims[dim - num_batch_dims if dim >= 0 else dim] = inds +# ranges.extend(remaining_dims) +# return data[ranges] + + +# def dict_map(fn, dic, leaf_type): +# new_dict = {} +# for k, v in dic.items(): +# if type(v) is dict: +# new_dict[k] = dict_map(fn, v, leaf_type) +# else: +# new_dict[k] = tree_map(fn, v, leaf_type) + +# return new_dict + + +# def tree_map(fn, tree, leaf_type): +# if isinstance(tree, dict): +# return dict_map(fn, tree, leaf_type) +# elif isinstance(tree, list): +# return [tree_map(fn, x, leaf_type) for x in tree] +# elif isinstance(tree, tuple): +# return tuple([tree_map(fn, x, leaf_type) for x in tree]) +# elif isinstance(tree, leaf_type): +# try: +# return fn(tree) +# except: +# raise ValueError(f"cannot apply {fn} on {tree}.") +# else: +# raise ValueError(f"{type(tree)} not supported") + + +# tensor_tree_map = partial(tree_map, leaf_type=torch.Tensor) + + +# def fp32_to_bf16_sr(t, o): +# if HAS_FUSED_ROUNDING and t.device.type == "cuda": +# unicore_fused_rounding.fp32_to_bf16_sr(t, o) +# else: +# r = ( +# torch.rand(size=t.size(), device=t.device, dtype=torch.float32) - 0.5 +# ) / 256 +# m, e = torch.frexp(t) +# t = t + torch.ldexp(r, e) +# o.data.copy_(t.bfloat16()) + + +# def set_jit_fusion_options(): +# """Set PyTorch JIT layer fusion options.""" +# # flags required to enable jit fusion kernels +# # legacy pytorch fuser +# torch._C._jit_set_profiling_mode(False) +# torch._C._jit_set_profiling_executor(False) +# torch._C._jit_override_can_fuse_on_cpu(True) +# torch._C._jit_override_can_fuse_on_gpu(True) + + +# @contextlib.contextmanager +# def validate_with_ema(trainer, ema=False): +# if not ema: +# yield +# return +# _wrapped_model = trainer._wrapped_model +# trainer._wrapped_model = deepcopy(trainer.ema.model_ema) +# if trainer.args.fp16: +# trainer._wrapped_model.half() +# elif trainer.args.bf16: +# trainer._wrapped_model.bfloat16() + 
+# try: +# yield +# finally: +# del trainer._wrapped_model +# trainer._wrapped_model = _wrapped_model diff --git a/MindChemistry/applications/Uni-Mol/unicore/version.txt b/MindChemistry/applications/Uni-Mol/unicore/version.txt new file mode 100644 index 0000000000000000000000000000000000000000..8acdd82b765e8e0b8cd8787f7f18c7fe2ec52493 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unicore/version.txt @@ -0,0 +1 @@ +0.0.1 diff --git a/MindChemistry/applications/Uni-Mol/unimol/Alignment/mindspore_ascend_output.py b/MindChemistry/applications/Uni-Mol/unimol/Alignment/mindspore_ascend_output.py new file mode 100644 index 0000000000000000000000000000000000000000..77773c5ebc273408c381e4929208bd80c6ddb194 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/Alignment/mindspore_ascend_output.py @@ -0,0 +1,947 @@ +import sys +import os +import argparse +import numpy as np +import mindspore as ms +import mindspore.mint.nn as nn +import mindspore.mint.nn.functional as F +from unicore import utils +from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +from unicore.modules import LayerNorm, init_bert_params +from unimol.models.transformer_encoder_with_pair import TransformerEncoderWithPair +from typing import Dict, Any, List +from unicore.data import Dictionary +from unimol.data.add_2d_conformer_dataset import Add2DConformerDataset +from unimol.losses.conf_gen import MolConfGLoss + +# -------------------------- 1. 辅助函数(内置) -------------------------- +def set_seed(seed=42): + """固定随机种子(确保结果可复现)""" + np.random.seed(seed) + ms.set_seed(seed) + +def save_output(output_data, module_name, sample_idx, save_dir): + """ + 保存UniMol的输出为numpy的npz格式 + """ + module_save_dir = os.path.join(save_dir, module_name) + os.makedirs(module_save_dir, exist_ok=True) + save_path = os.path.join(module_save_dir, f"sample_{sample_idx}.npz") + + if isinstance(output_data, dict): + np.savez_compressed(save_path, **output_data) + else: + np.savez_compressed(save_path, output=output_data) + + print(f"✅ UniMol输出已保存:{os.path.abspath(save_path)}") + + +# -------------------------- 2. 
模型定义(修正 `unimol.py`) -------------------------- +@register_model("unimol") +class UniMolModel(BaseUnicoreModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument("--encoder-layers", type=int, metavar="L", help="num encoder layers") + parser.add_argument("--encoder-embed-dim", type=int, metavar="H", help="encoder embedding dimension") + parser.add_argument("--encoder-ffn-embed-dim", type=int, metavar="F", help="encoder embedding dimension for FFN") + parser.add_argument("--encoder-attention-heads", type=int, metavar="A", help="num encoder attention heads") + parser.add_argument("--activation-fn", choices=utils.get_available_activation_fns(), help="activation function to use") + parser.add_argument("--pooler-activation-fn", choices=utils.get_available_activation_fns(), help="activation function to use for pooler layer") + parser.add_argument("--emb-dropout", type=float, metavar="D", help="dropout probability for embeddings") + parser.add_argument("--dropout", type=float, metavar="D", help="dropout probability") + parser.add_argument("--attention-dropout", type=float, metavar="D", help="dropout probability for attention weights") + parser.add_argument("--activation-dropout", type=float, metavar="D", help="dropout probability after activation in FFN") + parser.add_argument("--pooler-dropout", type=float, metavar="D", help="dropout probability in the masked_lm pooler layers") + parser.add_argument("--max-seq-len", type=int, help="number of positional embeddings to learn") + parser.add_argument("--post-ln", type=bool, help="use post layernorm or pre layernorm") + parser.add_argument("--masked-token-loss", type=float, metavar="D", help="mask loss ratio") + parser.add_argument("--masked-dist-loss", type=float, metavar="D", help="masked distance loss ratio") + parser.add_argument("--masked-coord-loss", type=float, metavar="D", help="masked coord loss ratio") + parser.add_argument("--x-norm-loss", type=float, metavar="D", help="x norm loss ratio") + parser.add_argument("--delta-pair-repr-norm-loss", type=float, metavar="D", help="delta encoder pair repr norm loss ratio") + parser.add_argument("--masked-coord-dist-loss", type=float, metavar="D", help="masked coord dist loss ratio") + parser.add_argument("--mode", type=str, default="train", choices=["train", "infer"]) + + def __init__(self, args, dictionary): + super().__init__() + base_architecture(args) + self.args = args + self.padding_idx = dictionary.pad() + dict_len = len(dictionary) + if not (0 <= self.padding_idx < dict_len): + self.padding_idx = dict_len - 1 + print(f"⚠️ 原始 padding_idx={dictionary.pad()} 无效,修正为 {self.padding_idx}") + else: + print(f"✅ padding_idx 有效:{self.padding_idx}(字典长度={dict_len})") + + # 核心改动:移除了手动设置 nn.Embedding 权重的部分 + # 这一步可能导致 Ascend 上的内存访问问题,我们让 MindSpore 自动处理初始化 + self.embed_tokens = nn.Embedding(dict_len, args.encoder_embed_dim) + + print(f"✅ Embedding 创建成功(size={dict_len}, dim={args.encoder_embed_dim})") + + self._num_updates = None + self.encoder = TransformerEncoderWithPair( + encoder_layers=args.encoder_layers, + embed_dim=args.encoder_embed_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + emb_dropout=args.emb_dropout, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + max_seq_len=args.max_seq_len, + activation_fn=args.activation_fn, + no_final_head_layer_norm=args.delta_pair_repr_norm_loss < 0, + ) + + if args.masked_token_loss > 0: 
+ self.lm_head = MaskLMHead(embed_dim=args.encoder_embed_dim, output_dim=dict_len, activation_fn=args.activation_fn, weight=None) + + K = 128 + n_edge_type = dict_len * dict_len + self.gbf_proj = NonLinearHead(K, args.encoder_attention_heads, args.activation_fn) + self.gbf = GaussianLayer(K, n_edge_type) + + if args.masked_coord_loss > 0: + self.pair2coord_proj = NonLinearHead(args.encoder_attention_heads, 1, args.activation_fn) + + if args.masked_dist_loss > 0: + self.dist_head = DistanceHead(args.encoder_attention_heads, args.activation_fn) + + self.classification_heads = nn.CellDict() + self.apply(init_bert_params) + + @classmethod + def build_model(cls, args, task): + return cls(args, task.dictionary) + + def forward(self, src_tokens, src_distance, src_coord, src_edge_type, encoder_masked_tokens=None, features_only=False, classification_head_name=None, **kwargs): + if classification_head_name is not None: + features_only = True + + padding_mask = ms.ops.equal(src_tokens, self.padding_idx) + if not padding_mask.any(): + padding_mask = None + x = self.embed_tokens(src_tokens) + + def get_dist_features(dist, et): + n_node = dist.shape[-1] + gbf_feature = self.gbf(dist, et) + gbf_result = self.gbf_proj(gbf_feature) + graph_attn_bias = gbf_result + graph_attn_bias = ms.ops.permute(graph_attn_bias, (0, 3, 1, 2)) + graph_attn_bias = ms.ops.contiguous(graph_attn_bias) + graph_attn_bias = graph_attn_bias.reshape(-1, n_node, n_node) + return graph_attn_bias + + graph_attn_bias = get_dist_features(src_distance, src_edge_type) + ( + encoder_rep, + encoder_pair_rep, + delta_encoder_pair_rep, + x_norm, + delta_encoder_pair_rep_norm, + ) = self.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias) + + encoder_pair_rep = ms.ops.where( + encoder_pair_rep == float("-inf"), + ms.ops.zeros_like(encoder_pair_rep), + encoder_pair_rep + ) + + encoder_distance = None + encoder_coord = None + + if not features_only: + if self.args.masked_token_loss > 0: + logits = self.lm_head(encoder_rep, encoder_masked_tokens) + if self.args.masked_coord_loss > 0: + coords_emb = src_coord + if padding_mask is not None: + atom_num = ms.ops.sum(1 - padding_mask.astype(x.dtype), dim=1).reshape(-1, 1, 1, 1) + else: + atom_num = src_coord.shape[1] + delta_pos = ms.ops.unsqueeze(coords_emb, 1) - ms.ops.unsqueeze(coords_emb, 2) + attn_probs = self.pair2coord_proj(delta_encoder_pair_rep) + coord_update = delta_pos / atom_num * attn_probs + pair_coords_mask = (1 - padding_mask.astype(ms.float32)).unsqueeze(-1) * (1 - padding_mask.astype(ms.float32)).unsqueeze(1) + coord_update = coord_update * pair_coords_mask.unsqueeze(-1) + coord_update = ms.ops.sum(coord_update, dim=2) + encoder_coord = coords_emb + coord_update + if self.args.masked_dist_loss > 0: + encoder_distance = self.dist_head(encoder_pair_rep) + + if classification_head_name is not None: + logits = self.classification_heads[classification_head_name](encoder_rep) + + if self.args.mode == 'infer': + return encoder_rep, encoder_pair_rep + else: + return ( + logits, + encoder_distance, + encoder_coord, + x_norm, + delta_encoder_pair_rep_norm, + ) + + def register_classification_head(self, name, num_classes=None, inner_dim=None, **kwargs): + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + print(f're-registering head "{name}" with num_classes {num_classes} (prev: 
{prev_num_classes}) and inner_dim {inner_dim} (prev: {prev_inner_dim})') + self.classification_heads[name] = ClassificationHead( + input_dim=self.args.encoder_embed_dim, + inner_dim=inner_dim or self.args.encoder_embed_dim, + num_classes=num_classes, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + ) + + def set_num_updates(self, num_updates): + self._num_updates = num_updates + + def get_num_updates(self): + return self._num_updates + + +class MaskLMHead(nn.Cell): + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.layer_norm = LayerNorm(embed_dim) + + if weight is None: + weight = nn.Linear(embed_dim, output_dim, has_bias=False).weight + self.weight = weight + self.bias = ms.Parameter(ms.ops.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + if masked_tokens is not None: + features = features[masked_tokens, :] + x = self.dense(features) + x = self.activation_fn(x) + x = self.layer_norm(x) + x = F.linear(x, self.weight) + self.bias + return x + + +class ClassificationHead(nn.Cell): + def __init__(self, input_dim, inner_dim, num_classes, activation_fn, pooler_dropout): + super().__init__() + self.dense = nn.Linear(input_dim, inner_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + self.dropout = nn.Dropout(p=pooler_dropout) + self.out_proj = nn.Linear(inner_dim, num_classes) + + def forward(self, features, **kwargs): + x = features[:, 0, :] + x = self.dropout(x) + x = self.dense(x) + x = self.activation_fn(x) + x = self.dropout(x) + x = self.out_proj(x) + return x + + +class NonLinearHead(nn.Cell): + def __init__(self, input_dim, out_dim, activation_fn, hidden=None): + super().__init__() + hidden = input_dim if not hidden else hidden + self.linear1 = nn.Linear(input_dim, hidden) + self.linear2 = nn.Linear(hidden, out_dim) + self.activation_fn = utils.get_activation_fn(activation_fn) + + def forward(self, x): + x = self.linear1(x) + x = self.activation_fn(x) + x = self.linear2(x) + return x + + +class DistanceHead(nn.Cell): + def __init__(self, heads, activation_fn): + super().__init__() + self.dense = nn.Linear(heads, heads) + self.layer_norm = nn.LayerNorm(heads) + self.out_proj = nn.Linear(heads, 1) + self.activation_fn = utils.get_activation_fn(activation_fn) + + def forward(self, x): + bsz, seq_len, seq_len, _ = x.shape + x = self.dense(x) + x = self.activation_fn(x) + x = self.layer_norm(x) + x = self.out_proj(x).reshape(bsz, seq_len, seq_len) + x = (x + ms.ops.transpose(x, (-1, -2))) * 0.5 + return x + + +def gaussian(x, mean, std): + pi = 3.14159 + a = (2 * pi) ** 0.5 + return ms.ops.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std) + + +class GaussianLayer(nn.Cell): + def __init__(self, K=128, edge_types=1024): + super().__init__() + self.K = K + self.means = nn.Embedding(1, K) + self.stds = nn.Embedding(1, K) + self.mul = nn.Embedding(edge_types, 1) + self.bias = nn.Embedding(edge_types, 1) + + self.means.weight.set_data(ms.common.initializer.Uniform(3)(self.means.weight.shape)) + self.stds.weight.set_data(ms.common.initializer.Uniform(3)(self.stds.weight.shape)) + self.bias.weight.set_data(ms.common.initializer.Constant(0)(self.bias.weight.shape)) + self.mul.weight.set_data(ms.common.initializer.Constant(1)(self.mul.weight.shape)) + + def forward(self, x, edge_type): + mul = self.mul(edge_type).astype(x.dtype) + bias = 
self.bias(edge_type).astype(x.dtype) + x = mul * ms.ops.unsqueeze(x, -1) + bias + x = ms.ops.tile(x, (1, 1, 1, self.K)) + mean = self.means.weight.astype(ms.float32).reshape(-1) + std = ms.ops.abs(self.stds.weight.astype(ms.float32).reshape(-1)) + 1e-5 + return gaussian(x.astype(ms.float32), mean, std).astype(self.means.weight.dtype) + + +@register_model_architecture("unimol", "unimol") +def base_architecture(args): + """Default Architecture""" + args.encoder_layers = getattr(args, "encoder_layers", 15) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64) + args.dropout = getattr(args, "dropout", 0.1) + args.emb_dropout = getattr(args, "emb_dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.max_seq_len = getattr(args, "max_seq_len", 512) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.post_ln = getattr(args, "post_ln", False) + args.masked_token_loss = getattr(args, "masked_token_loss", -1.0) + args.masked_coord_loss = getattr(args, "masked_coord_loss", -1.0) + args.masked_dist_loss = getattr(args, "masked_dist_loss", -1.0) + args.x_norm_loss = getattr(args, "x_norm_loss", -1.0) + args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0) + + +@register_model_architecture("unimol", "unimol_base") +def unimol_base_architecture(args): + """Base Architecture Alias""" + base_architecture(args) + + +# -------------------------- 3. 
主程序(修复后的推理脚本) -------------------------- +if __name__ == "__main__": + # 更新全局上下文配置,使用非弃用API + ms.set_device(device_target="Ascend", device_id=0) + ms.set_context(mode=1) + # ms.set_context(jit_config={"jit_level": "O0"}) # 暂时关闭jit以方便调试,确定没有segfault后再打开 + set_seed(seed=42) + print(f"✅ 全局配置完成:Ascend:0,图模式") + + # 添加unimol路径 + UNIMOL_ROOT = "/home/ma-user/work/Uni-Mol/unimol" + if UNIMOL_ROOT not in sys.path: + sys.path.insert(0, UNIMOL_ROOT) + print(f"✅ 添加unimol路径:{UNIMOL_ROOT}") + else: + print(f"✅ unimol路径已存在:{UNIMOL_ROOT}") + + # 配置路径 + TEST_DATA_PATH = "/home/ma-user/work/Uni-Mol/unimol/Alignment/uni_mol_test_data.npz" + DICTIONARY_PATH = "/home/ma-user/work/Uni-Mol/unimol/example_data/molecule/dict.txt" + OUTPUT_DIR = "./mindspore_ascend_outputs" + os.makedirs(OUTPUT_DIR, exist_ok=True) + print(f"✅ 输出目录创建完成:{OUTPUT_DIR}") + + # 加载数据与字典 + if not os.path.exists(TEST_DATA_PATH): + raise FileNotFoundError(f"测试数据不存在:{TEST_DATA_PATH}") + test_data = np.load(TEST_DATA_PATH, allow_pickle=True)["test_data"] + print(f"✅ 加载测试集:共{len(test_data)}个样本") + + if not os.path.exists(DICTIONARY_PATH): + raise FileNotFoundError(f"字典不存在:{DICTIONARY_PATH}") + atom_dict = Dictionary.load(DICTIONARY_PATH) + pad_idx = atom_dict.pad() + dict_len = len(atom_dict) + assert 0 <= pad_idx < dict_len, f"无效padding_idx:{pad_idx}(字典长度{dict_len})" + print(f"✅ 加载字典:{dict_len}个原子类型,padding_idx={pad_idx}") + + # 构造模型配置 + args = argparse.Namespace() + args.encoder_layers = 6 + args.encoder_embed_dim = 256 + args.encoder_ffn_embed_dim = 1024 + args.encoder_attention_heads = 8 + args.dropout = 0.1 + args.emb_dropout = 0.1 + args.attention_dropout = 0.1 + args.activation_dropout = 0.0 + args.max_seq_len = 256 + args.activation_fn = "gelu" + args.mode = "infer" + args.masked_token_loss = -1.0 + args.masked_coord_loss = -1.0 + args.masked_dist_loss = -1.0 + args.delta_pair_repr_norm_loss = -1.0 + base_architecture(args) + print(f"✅ 模型配置完成:层数{args.encoder_layers},维度{args.encoder_embed_dim}") + + # 创建模拟task与初始化模型 + class MockTask: + def __init__(self, dictionary): + self.dictionary = dictionary + + mock_task = MockTask(dictionary=atom_dict) + ms_model = UniMolModel.build_model(args=args, task=mock_task) + ms_model.set_train(False) + print(f"✅ UniMolModel 初始化成功!") + + # 批量处理样本 + for sample_idx, sample in enumerate(test_data): + print(f"\n=== 处理第{sample_idx}个样本 ===") + + # 数据加载模块 + print("--- 数据模块 ---") + dataset_input = [{"smi": sample["smi"], "atoms": sample["atoms"], "coordinates": sample["coordinates"]}] + ms_data_module = Add2DConformerDataset(dataset=dataset_input, smi="smi", atoms="atoms", coordinates="coordinates") + ms_data_output = ms_data_module[0] + save_output(output_data=ms_data_output, module_name="data_loader_add2d", sample_idx=sample_idx, save_dir=OUTPUT_DIR) + print("✅ 数据模块完成") + + # 模型前向模块 + print("--- 模型模块 ---") + try: + ms_atoms = ms.Tensor(ms_data_output["atoms"], dtype=ms.int32).expand_dims(0) + ms_coords = ms.Tensor(ms_data_output["coordinates"], dtype=ms.float32).expand_dims(0) + ms_distance = ms.Tensor(ms_data_output.get("distance", np.zeros((1, ms_atoms.shape[1], ms_atoms.shape[1]))), dtype=ms.float32) + ms_edge_type = ms.Tensor(ms_data_output.get("edge_type", np.zeros((1, ms_atoms.shape[1], ms_atoms.shape[1]))), dtype=ms.int32) + + encoder_rep, encoder_pair_rep = ms_model( + src_tokens=ms_atoms, + src_distance=ms_distance, + src_coord=ms_coords, + src_edge_type=ms_edge_type, + features_only=True + ) + + ms_model_output_np = { + "encoder_rep": encoder_rep.asnumpy(), + "encoder_pair_rep": 
encoder_pair_rep.asnumpy() + } + save_output(output_data=ms_model_output_np, module_name="unimol_model_forward", sample_idx=sample_idx, save_dir=OUTPUT_DIR) + print("✅ 模型模块完成") + except Exception as e: + print(f"❌ 模型前向失败:{str(e)}") + continue + + # 损失计算模块 + print("--- 损失模块 ---") + if hasattr(ms_model, "pair2coord_proj") and args.masked_coord_loss > 0: + delta_pos = ms.ops.unsqueeze(ms_coords, 1) - ms.ops.unsqueeze(ms_coords, 2) + attn_probs = ms_model.pair2coord_proj(encoder_pair_rep) + atom_num = ms.ops.sum(1 - ms.ops.equal(ms_atoms, pad_idx).astype(ms.float32), dim=1).reshape(-1, 1, 1, 1) + coord_update = delta_pos / atom_num * attn_probs + pred_coords = ms_coords + coord_update.sum(dim=2) + ms_loss_module = MolConfGLoss() + ms_loss = ms_loss_module(pred_coords, ms_coords) + save_output(output_data={"loss": ms_loss.asnumpy()}, module_name="conf_gen_loss", sample_idx=sample_idx, save_dir=OUTPUT_DIR) + print("✅ 损失模块完成") + else: + print("⚠️ 跳过损失模块(未启用坐标预测)") + + print(f"\n🎉 所有样本处理完成!输出目录:{OUTPUT_DIR}") +# import numpy as np +# import os +# import mindspore as ms +# import argparse +# from unicore.data import Dictionary + +# # -------------------------- 1. 内置set_seed和save_output(不依赖外部文件) -------------------------- +# def set_seed(seed=42): +# """固定随机种子(确保结果可复现)""" +# np.random.seed(seed) +# ms.set_seed(seed) + +# def save_uni_output(output_data, module_name, sample_idx, save_dir): +# """ +# 内置UniMol输出保存逻辑:将结果保存为numpy的npz格式 +# - output_data:UniMol的输出(如encoder_rep、pred_coords) +# - module_name:模块名(如"unimol_forward"、"conf_loss") +# - sample_idx:样本序号(区分不同样本) +# - save_dir:保存根目录 +# """ +# # 1. 创建模块专属目录(如 save_dir/unimol_forward) +# module_save_dir = os.path.join(save_dir, module_name) +# os.makedirs(module_save_dir, exist_ok=True) + +# # 2. 生成保存路径(如 save_dir/unimol_forward/sample_0.npz) +# save_path = os.path.join(module_save_dir, f"sample_{sample_idx}.npz") + +# # 3. 保存数据(支持字典格式,方便存储多个输出) +# if isinstance(output_data, dict): +# np.savez_compressed(save_path, **output_data) # 压缩保存,减少体积 +# else: +# np.savez_compressed(save_path, output=output_data) + +# # 4. 打印保存信息(方便确认保存位置) +# print(f"✅ UniMol输出已保存:{os.path.abspath(save_path)}") + +# import sys +# import os +# import numpy as np +# import mindspore as ms +# import argparse +# from unicore.data import Dictionary +# from precision_utils import set_seed, save_output # 假设该工具文件存在 + +# # -------------------------- 1. 全局Ascend配置(最优先执行,确保模块默认在Ascend) -------------------------- +# ms.context.set_context( +# device_target="Ascend", # 固定Ascend +# device_id=0, # 设备编号(默认0) +# mode=ms.GRAPH_MODE, # 图模式(适配UniMol) +# enable_graph_kernel=False # 关闭图核优化(避免旧版本bug) +# ) +# set_seed(seed=42) # 固定随机种子 +# print(f"✅ 全局配置完成:Ascend:0,图模式") + +# # -------------------------- 2. 添加unimol路径(确保导入成功) -------------------------- +# UNIMOL_ROOT = "/home/ma-user/work/Uni-Mol/unimol" # 你的unimol根目录 +# if UNIMOL_ROOT not in sys.path: +# sys.path.insert(0, UNIMOL_ROOT) +# print(f"✅ 添加unimol路径:{UNIMOL_ROOT}") +# else: +# print(f"✅ unimol路径已存在:{UNIMOL_ROOT}") + +# # 验证导入 +# try: +# from unimol.models.unimol import UniMolModel, base_architecture +# from unimol.data.add_2d_conformer_dataset import Add2DConformerDataset +# from unimol.losses.conf_gen import MolConfGLoss +# print("✅ 成功导入unimol模块") +# except ImportError as e: +# print(f"❌ 导入失败:{str(e)}") +# sys.exit() + +# # -------------------------- 3. 
配置路径(按实际修改) -------------------------- +# TEST_DATA_PATH = "/home/ma-user/work/Uni-Mol/unimol/Alignment/uni_mol_test_data.npz" # 测试数据 +# DICTIONARY_PATH = "/home/ma-user/work/Uni-Mol/unimol/example_data/molecule/dict.txt" # 原子字典 +# OUTPUT_DIR = "./mindspore_ascend_outputs" # 输出目录 +# os.makedirs(OUTPUT_DIR, exist_ok=True) +# print(f"✅ 输出目录创建完成:{OUTPUT_DIR}") + +# # -------------------------- 4. 加载数据与字典 -------------------------- +# # 加载测试集 +# if not os.path.exists(TEST_DATA_PATH): +# raise FileNotFoundError(f"测试数据不存在:{TEST_DATA_PATH}") +# test_data = np.load(TEST_DATA_PATH, allow_pickle=True)["test_data"] +# print(f"✅ 加载测试集:共{len(test_data)}个样本") + +# # 加载原子字典 +# if not os.path.exists(DICTIONARY_PATH): +# raise FileNotFoundError(f"字典不存在:{DICTIONARY_PATH}") +# atom_dict = Dictionary.load(DICTIONARY_PATH) +# pad_idx = atom_dict.pad() +# dict_len = len(atom_dict) +# assert 0 <= pad_idx < dict_len, f"无效padding_idx:{pad_idx}(字典长度{dict_len})" +# print(f"✅ 加载字典:{dict_len}个原子类型,padding_idx={pad_idx}") + +# # -------------------------- 5. 构造模型配置(args) -------------------------- +# args = argparse.Namespace() +# # 模型超参(与PyTorch对齐) +# args.encoder_layers = 6 # 编码器层数 +# args.encoder_embed_dim = 256 # 嵌入维度 +# args.encoder_ffn_embed_dim = 1024 # FFN维度 +# args.encoder_attention_heads = 8 # 注意力头数 +# args.dropout = 0.1 # dropout +# args.emb_dropout = 0.1 +# args.attention_dropout = 0.1 +# args.activation_dropout = 0.0 +# args.max_seq_len = 256 # 最大序列长度 +# args.activation_fn = "gelu" # 激活函数 +# args.mode = "infer" # 推理模式 +# args.masked_token_loss = -1.0 # 关闭训练损失 +# args.masked_coord_loss = -1.0 +# args.masked_dist_loss = -1.0 +# args.delta_pair_repr_norm_loss = -1.0 +# # 补全默认超参 +# base_architecture(args) +# print(f"✅ 模型配置完成:层数{args.encoder_layers},维度{args.encoder_embed_dim}") + +# # -------------------------- 6. 创建模拟task与初始化模型 -------------------------- +# class MockTask: +# def __init__(self, dictionary): +# self.dictionary = dictionary + +# mock_task = MockTask(dictionary=atom_dict) +# # 初始化模型(依赖全局Ascend配置,无设备操作) +# ms_model = UniMolModel.build_model(args=args, task=mock_task) +# ms_model.set_train(False) # 推理模式 +# print(f"✅ UniMolModel 初始化成功!") + +# # -------------------------- 7. 
批量处理样本 -------------------------- +# for sample_idx, sample in enumerate(test_data): +# print(f"\n=== 处理第{sample_idx}个样本 ===") + +# # 7.1 数据加载模块(Add2DConformerDataset) +# print("--- 数据模块 ---") +# dataset_input = [{ +# "smi": sample["smi"], +# "atoms": sample["atoms"], +# "coordinates": sample["coordinates"] +# }] +# ms_data_module = Add2DConformerDataset( +# dataset=dataset_input, +# smi="smi", +# atoms="atoms", +# coordinates="coordinates" +# ) +# ms_data_output = ms_data_module[0] +# # 保存数据输出 +# save_output( +# output_data=ms_data_output, +# module_name="data_loader_add2d", +# sample_idx=sample_idx, +# save_dir=OUTPUT_DIR +# ) +# print("✅ 数据模块完成") + +# # 7.2 模型前向模块(UniMolModel) +# print("--- 模型模块 ---") +# # 构造输入张量(依赖全局Ascend配置) +# ms_atoms = ms.Tensor(ms_data_output["atoms"], dtype=ms.float32).expand_dims(0) +# ms_coords = ms.Tensor(ms_data_output["coordinates"], dtype=ms.float32).expand_dims(0) +# ms_distance = ms.Tensor(ms_data_output.get("distance", np.zeros((1, ms_atoms.shape[1], ms_atoms.shape[1]))), dtype=ms.float32) +# ms_edge_type = ms.Tensor(ms_data_output.get("edge_type", np.zeros((1, ms_atoms.shape[1], ms_atoms.shape[1]))), dtype=ms.int32) +# # 模型前向(无设备操作) +# encoder_rep, encoder_pair_rep = ms_model( +# src_tokens=ms_atoms, +# src_distance=ms_distance, +# src_coord=ms_coords, +# src_edge_type=ms_edge_type, +# features_only=True +# ) +# # 保存模型输出 +# ms_model_output_np = { +# "encoder_rep": encoder_rep.asnumpy(), +# "encoder_pair_rep": encoder_pair_rep.asnumpy() +# } +# save_output( +# output_data=ms_model_output_np, +# module_name="unimol_model_forward", +# sample_idx=sample_idx, +# save_dir=OUTPUT_DIR +# ) +# print("✅ 模型模块完成") + +# # 7.3 损失计算模块(MolConfGLoss) +# print("--- 损失模块 ---") +# if hasattr(ms_model, "pair2coord_proj") and args.masked_coord_loss > 0: +# delta_pos = ms.ops.unsqueeze(ms_coords, 1) - ms.ops.unsqueeze(ms_coords, 2) +# attn_probs = ms_model.pair2coord_proj(encoder_pair_rep) +# atom_num = ms.ops.sum(1 - ms.ops.equal(ms_atoms, pad_idx).astype(ms.float32), dim=1).reshape(-1, 1, 1, 1) +# coord_update = delta_pos / atom_num * attn_probs +# pred_coords = ms_coords + coord_update.sum(dim=2) +# # 计算损失 +# ms_loss_module = MolConfGLoss() +# ms_loss = ms_loss_module(pred_coords, ms_coords) +# # 保存损失 +# save_output( +# output_data={"loss": ms_loss.asnumpy()}, +# module_name="conf_gen_loss", +# sample_idx=sample_idx, +# save_dir=OUTPUT_DIR +# ) +# print("✅ 损失模块完成") +# else: +# print("⚠️ 跳过损失模块(未启用坐标预测)") + +# print(f"\n🎉 所有样本处理完成!输出目录:{OUTPUT_DIR}") + + + +# import sys +# import os +# # -------------------------- mindspore_ascend_output.py 关键配置 -------------------------- +# import mindspore as ms +# from unimol.models.unimol import UniMolModel +# import argparse + +# # 2. 创建空的 args 配置对象(存储模型超参) +# args = argparse.Namespace() + +# # 3. 
显式设置模型超参(根据你的需求调整,参考 UniMolModel 的 base_architecture 默认值) +# # 核心超参:与你之前的配置一致(如 encoder_layers=6,encoder_embed_dim=256) +# args.encoder_layers = 6 # 编码器层数(对应 Transformer 层数,之前用的 6) +# args.encoder_embed_dim = 256 # 嵌入维度(对应之前的 hidden_size=256) +# args.encoder_ffn_embed_dim = 1024 # FFN 层维度(通常是 embed_dim 的 4 倍,256*4=1024) +# args.encoder_attention_heads = 8 # 注意力头数(根据 embed_dim 调整,256/8=32,合理值) +# args.dropout = 0.1 # 整体 dropout 概率(之前用的 0.1) +# args.emb_dropout = 0.1 # 嵌入层 dropout(默认值) +# args.attention_dropout = 0.1 # 注意力层 dropout(默认值) +# args.activation_dropout = 0.0 # 激活层 dropout(默认值) +# args.max_seq_len = 256 # 最大序列长度(原子数上限,需覆盖你的样本) +# args.activation_fn = "gelu" # 激活函数(UniMol 常用 gelu) +# args.mode = "infer" # 模式:infer(推理,避免训练相关逻辑) +# args.masked_token_loss = -1.0 # 关闭掩码 token 损失(推理时不用) +# args.masked_coord_loss = -1.0 # 关闭坐标掩码损失(推理时不用) +# args.masked_dist_loss = -1.0 # 关闭距离掩码损失(推理时不用) +# args.delta_pair_repr_norm_loss = -1.0 # 关闭归一化损失(推理时不用) + +# # 4. 调用 base_architecture 补全默认超参(确保所有必要参数都有值) +# # 注意:base_architecture 函数在 UniMolModel 所在文件中,需先导入 +# from unimol.models.unimol import base_architecture +# base_architecture(args) + +# # 5. 验证 args 是否构造成功(可选,用于调试) +# print("✅ args 对象构造完成!关键超参:") +# print(" - 编码器层数:{}".format(args.encoder_layers)) +# print(" - 嵌入维度:{}".format(args.encoder_embed_dim)) +# print(" - 注意力头数:{}".format(args.encoder_attention_heads)) + +# # -------------------------- 之后再调用 build_model(此时 args 已定义) -------------------------- +# # 创建模拟 task 对象(你之前的代码,确保保留) +# class MockTask: +# def __init__(self, dictionary): +# self.dictionary = dictionary + +# # 加载原子字典(你之前的代码,确保保留) +# from unicore.data import Dictionary +# DICTIONARY_PATH = "/home/ma-user/work/Uni-Mol/unimol/example_data/molecule/dict.txt" # 你的字典路径 +# atom_dict = Dictionary.load(DICTIONARY_PATH) +# mock_task = MockTask(dictionary=atom_dict) + +# # 现在调用 build_model,args 已定义,不会报 NameError +# from unimol.models.unimol import UniMolModel +# ms_model = UniMolModel.build_model(args=args, task=mock_task) +# ms_model.set_train(False) # 推理模式 +# print("✅ UniMolModel 初始化成功!") +# # 1. 配置全局默认设备:强制后续所有张量/模块初始化时落在 Ascend:0 +# ms.context.set_context( +# device_target="Ascend", # 目标设备类型:必须是 Ascend +# device_id=0, # 设备编号:与你的 NPU 编号一致(通常是 0) +# mode=ms.GRAPH_MODE, # 图模式(UniMol 必需,避免动态图兼容性问题) +# enable_graph_kernel=False # 关闭图核优化,减少内核生成 bug(可选,但对旧版本友好) +# ) + +# # 2. 验证全局配置是否生效(可选,用于调试) +# current_target = ms.context.get_context("device_target") +# current_id = ms.context.get_context("device_id") +# print("✅ 全局设备配置生效:target={},device_id={}".format(current_target, current_id)) + +# # 后续导入和模型创建代码... +# import unimol +# # ... 
+# ms_model = UniMolModel.build_model(args=args, task=mock_task) # 此时创建的模型权重默认在 Ascend:0 +# # 替换为你的 unimol 源码根目录(就是 pip show unimol 显示的 Location) +# # 示例:UNIMOL_ROOT = "/home/ma-user/work/Uni-Mol/unimol" +# UNIMOL_ROOT = "/home/ma-user/work/Uni-Mol/unimol/unimol" + +# # 将 unimol 根目录添加到 Python 搜索路径(优先搜索) +# if UNIMOL_ROOT not in sys.path: +# sys.path.insert(0, UNIMOL_ROOT) +# print(f"✅ 已手动添加 unimol 路径到 Python:{UNIMOL_ROOT}") +# else: +# print(f"✅ unimol 路径已在 Python 搜索路径中:{UNIMOL_ROOT}") + +# # 验证是否能找到 unimol 模块(可选,用于调试) +# try: +# import unimol +# print(f"✅ 成功导入 unimol!版本:{unimol.__version__ if hasattr(unimol, '__version__') else '未知'}") +# except ImportError: +# print(f"❌ 仍无法导入 unimol,请检查 UNIMOL_ROOT 是否正确:{UNIMOL_ROOT}") +# sys.exit() # 路径错了就停止,避免后续报错 +# # ---------------------------------------------------------------------------------------- +# # mindspore_ascend_output.py(Ascend端运行的脚本) +# import numpy as np +# import mindspore as ms +# import os +# from unicore.data import Dictionary +# # 导入通用工具函数(必须和脚本在同一文件夹) +# from precision_utils import set_seed, save_output + +# # -------------------------- 1. 配置Ascend环境(不用改) -------------------------- +# ms.context.set_context( +# device_target="Ascend", # 固定为Ascend +# mode=ms.GRAPH_MODE, # 推理模式,固定 +# device_id=0 # NPU编号,通常是0,若有多个填1/2等 +# ) +# set_seed(seed=42) # 固定随机种子,和PyTorch一致 + +# # -------------------------- 2. 必须修改的3个路径(按你的实际情况改) -------------------------- +# # 路径1:固定测试集的路径(你刚上传的uni_mol_test_data.npz) +# # 示例:因为你上传到了当前文件夹,所以路径是"./uni_mol_test_data.npz"(不用改,除非你放别的地方) +# TEST_DATA_PATH = "/home/ma-user/work/Uni-Mol/unimol/Alignment/uni_mol_test_data.npz" + +# # 路径2:原子字典的路径(Uni-Mol自带的atom_dict.txt,你上传到哪里就填哪里) +# # 示例:若你把atom_dict.txt上传到了./unimol/dict/文件夹,就填这个路径 +# DICTIONARY_PATH = "/home/ma-user/work/Uni-Mol/unimol/example_data/molecule/dict.txt" + +# # 路径3:输出文件保存目录(生成的MindSpore输出文件放这里,不用改,自动创建) +# OUTPUT_DIR = "./mindspore_ascend_outputs" + +# # -------------------------- 3. 导入你修改好的MindSpore版模块(必须和你的代码路径一致) -------------------------- +# # 示例:从unimol/data/导入数据模块,unimol/models/导入模型模块,unimol/losses/导入损失模块 +# from unimol.data.add_2d_conformer_dataset import Add2DConformerDataset # 数据模块 +# from unimol.models.unimol import UniMolModel # 模型模块 +# from unimol.losses.conf_gen import MolConfGLoss # 损失模块 + +# # -------------------------- 4. 加载测试集(不用改) -------------------------- +# if not os.path.exists(TEST_DATA_PATH): +# raise FileNotFoundError(f"测试集没找到!路径:{TEST_DATA_PATH},请重新上传") +# test_data = np.load(TEST_DATA_PATH, allow_pickle=True)["test_data"] +# print(f"✅ 加载测试集成功!共{len(test_data)}个样本") + +# # -------------------------- 5. 
初始化原子字典(不用改) -------------------------- +# if not os.path.exists(DICTIONARY_PATH): +# raise FileNotFoundError(f"原子字典没找到!路径:{DICTIONARY_PATH},请重新上传") +# atom_dict = Dictionary.load(DICTIONARY_PATH) +# print(f"✅ 加载原子字典成功!共{len(atom_dict)}个原子类型") + +# # 新增:打印 padding_idx 和字典长度,验证有效性 +# pad_idx = atom_dict.pad() +# dict_len = len(atom_dict) +# print(f"🔍 原子字典验证:") +# print(f" - 字典长度(原子类型数):{dict_len}") +# print(f" - padding_idx(填充索引):{pad_idx}") +# print(f" - padding_idx 是否有效(0 ≤ pad_idx < dict_len):{0 <= pad_idx < dict_len}") + +# # 如果 padding_idx 无效,直接报错并提示 +# if not (0 <= pad_idx < dict_len): +# raise ValueError( +# f"❌ 无效的 padding_idx!当前 pad_idx={pad_idx},字典长度={dict_len}\n" +# f"要求:0 ≤ pad_idx < dict_len(填充索引必须是字典内的有效索引)\n" +# f"请检查原子字典文件 {DICTIONARY_PATH} 的生成逻辑,确保 pad() 返回合法值。" +# ) +# print(f"✅ 加载原子字典成功!共{dict_len}个原子类型,padding_idx={pad_idx}") + +# # -------------------------- 新增1:构造 UniMolModel 需要的 args 配置对象 -------------------------- +# import argparse # 用于创建配置对象 + +# # 1. 创建空的配置对象 +# args = argparse.Namespace() + +# # 2. 设置模型超参(参考 base_architecture 函数的默认值,可根据你的PyTorch端调整) +# # 核心超参:和你之前的 hidden_size=256、num_layers=6 对应 +# args.encoder_layers = 6 # 对应之前的 num_layers(编码器层数) +# args.encoder_embed_dim = 256 # 对应之前的 hidden_size(嵌入维度) +# args.encoder_ffn_embed_dim = 1024 # FFN层维度(默认是 embed_dim 的4倍,可调整) +# args.encoder_attention_heads = 8 # 注意力头数(默认值,可调整) +# args.dropout = 0.1 # dropout 概率(和之前一致) +# args.emb_dropout = 0.1 # 嵌入层 dropout(默认值) +# args.attention_dropout = 0.1 # 注意力层 dropout(默认值) +# args.activation_dropout = 0.0 # 激活层 dropout(默认值) +# args.max_seq_len = 256 # 最大序列长度(和你的 max_atoms 一致) +# args.activation_fn = "gelu" # 激活函数(默认值,和PyTorch一致) +# args.mode = "infer" # 模式:infer(推理),避免训练相关逻辑 +# args.masked_token_loss = -1.0 # 关闭掩码token损失(推理时不用) +# args.masked_coord_loss = -1.0 # 关闭坐标掩码损失(推理时不用) +# args.masked_dist_loss = -1.0 # 关闭距离掩码损失(推理时不用) +# args.delta_pair_repr_norm_loss = -1.0 # 关闭归一化损失(推理时不用) + +# # 3. 调用 base_architecture 函数补全默认参数(确保所有必要参数都有值) +# # 注意:base_architecture 函数在 UniMolModel 所在文件中定义,需先导入 +# from unimol.models.unimol import base_architecture +# base_architecture(args) +# print(f"✅ 构造 args 配置对象成功!编码器层数:{args.encoder_layers},嵌入维度:{args.encoder_embed_dim}") + +# # -------------------------- 新增2:创建模拟 task 对象(适配 build_model 方法) -------------------------- +# # UniMolModel 推荐用 build_model 类方法初始化,需传入 task 对象(含 dictionary 属性) +# class MockTask: +# def __init__(self, dictionary): +# self.dictionary = dictionary # task 对象必须有 dictionary 属性 + +# mock_task = MockTask(dictionary=atom_dict) +# print(f"✅ 创建模拟 task 对象成功!") + +# # -------------------------- 新增3:正确初始化 UniMolModel(用 build_model 方法) -------------------------- +# # 优先用 build_model 方法(自动从 task 取 dictionary,避免手动传参错误) +# from unimol.models.unimol import UniMolModel +# ms_model = UniMolModel.build_model(args=args, task=mock_task) +# ms_model.set_train(False) # 推理模式,固定 +# print(f"✅ UniMolModel 初始化成功!") + + +# # -------------------------- 6. 
运行模块并保存输出(不用改) -------------------------- +# # 创建输出目录 +# os.makedirs(OUTPUT_DIR, exist_ok=True) + +# # 逐个样本运行 +# for sample_idx, sample in enumerate(test_data): +# print(f"\n=== 处理第{sample_idx}个样本 ===") + +# # 模块1:数据加载(Add2DConformerDataset) +# print("--- 运行数据模块 ---") +# dataset_input = [{ +# "smi": sample["smi"], +# "atoms": sample["atoms"], +# "coordinates": sample["coordinates"] +# }] +# ms_data_module = Add2DConformerDataset( +# dataset=dataset_input, +# smi="smi", +# atoms="atoms", +# coordinates="coordinates" +# ) +# ms_data_output = ms_data_module[0] +# # 保存数据模块输出 +# save_output( +# output_data=ms_data_output, +# module_name="data_loader_add2d", +# sample_idx=sample_idx, +# save_dir=OUTPUT_DIR +# ) + +# # 模块2:模型前向(UniMolModel) +# print("--- 运行模型模块 ---") +# # 构造模型输入(和PyTorch一致) +# ms_atoms = ms.Tensor(ms_data_output["atoms"], dtype=ms.float32).expand_dims(0) +# ms_coords = ms.Tensor(ms_data_output["coordinates"], dtype=ms.float32).expand_dims(0) +# # 初始化模型(参数要和PyTorch一致,比如hidden_size=256) +# ms_model = UniMolModel( +# atom_dict_size=len(atom_dict), +# hidden_size=256, # 若你的模型用了其他值,这里要改 +# num_layers=6, # 同上,和PyTorch一致 +# dropout=0.1 # 同上 +# ) +# ms_model.set_train(False) # 推理模式,固定 +# ms_model_output = ms_model(ms_atoms, ms_coords) +# # 转为numpy数组(方便保存) +# if isinstance(ms_model_output, dict): +# ms_model_output_np = {k: v.asnumpy() for k, v in ms_model_output.items()} +# else: +# ms_model_output_np = ms_model_output.asnumpy() +# # 保存模型输出 +# save_output( +# output_data=ms_model_output_np, +# module_name="unimol_model_forward", +# sample_idx=sample_idx, +# save_dir=OUTPUT_DIR +# ) + +# # 模块3:损失计算(ConfGenLoss) +# print("--- 运行损失模块 ---") +# ms_target_coords = ms_coords.clone() +# ms_loss_module = MolConfGLoss() +# ms_loss = ms_loss_module(ms_model_output["pred_coords"], ms_target_coords) +# # 保存损失输出 +# save_output( +# output_data={"loss": ms_loss.asnumpy()}, +# module_name="conf_gen_loss", +# sample_idx=sample_idx, +# save_dir=OUTPUT_DIR +# ) + +# print(f"\n🎉 所有样本运行完成!输出文件在:{OUTPUT_DIR}") +# print("下一步:把这个文件夹下载到本地!") \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unimol/Alignment/precision_utils.py b/MindChemistry/applications/Uni-Mol/unimol/Alignment/precision_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e025b08f2527fe59cab92f574d6a085954b25c1 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/Alignment/precision_utils.py @@ -0,0 +1,222 @@ +# precision_utils.py +# 通用精度对齐工具函数,供MindSpore/Ascend、PyTorch/Colab、本地对比脚本共用 +import numpy as np +import os +import sys + + +def set_seed(seed=42): + """ + 固定随机种子,确保MindSpore/PyTorch两端结果可复现 + Args: + seed: 随机种子值(默认42,需在MindSpore和PyTorch端使用相同值) + """ + # 1. 固定numpy随机种子(数据生成、数组操作依赖) + np.random.seed(seed) + + # 2. 固定MindSpore随机种子(若在MindSpore环境中调用) + try: + import mindspore as ms + ms.set_seed(seed) + # 关闭MindSpore图核优化、固定精度模式(避免因优化导致的精度差异) + ms.context.set_context( + enable_graph_kernel=False, + ascend_config={"precision_mode": "force_fp32"} + # precision_mode="fp32" # 统一用float32精度,避免fp16/fp32混用导致差异 + ) + except ImportError: + # 若未导入MindSpore(如在PyTorch端调用),跳过这部分 + pass + + # 3. 固定PyTorch随机种子(若在PyTorch环境中调用,预留接口) + try: + import torch + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) # 若用GPU,固定CUDA种子 + torch.backends.cudnn.deterministic = True # 禁用CuDNN随机优化 + torch.backends.cudnn.benchmark = False + except ImportError: + # 若未导入PyTorch(如在MindSpore端调用),跳过这部分 + pass + + # 4. 
固定RDKit随机种子(若用到分子坐标生成,预留接口) + try: + from rdkit import Chem + from rdkit.Chem import AllChem + + # 先尝试新版本方法(SetRandomSeed) + if hasattr(Chem, "SetRandomSeed"): + Chem.SetRandomSeed(seed) + AllChem.SetRandomSeed(seed) + print(f"✅ RDKit 新版本:已固定种子为 {seed}") + else: + # 旧版本适配:用 SetPreferCoordGen 减少随机性(虽不能完全固定,但能降低差异) + AllChem.SetPreferCoordGen(False) + print(f"✅ RDKit 旧版本:已启用坐标生成兼容性模式(无 SetRandomSeed,建议升级 RDKit)") + + except ImportError: + print("⚠️ 未导入 RDKit,跳过 RDKit 种子设置") + + # except ImportError: + # pass + + print(f"✅ 已固定随机种子为 {seed}(支持MindSpore/PyTorch/RDKit)") + + +def save_output(output_data, module_name, sample_idx, save_dir="uni_mol_outputs"): + """ + 保存模块输出为numpy文件(.npz格式),统一MindSpore/PyTorch输出格式,便于后续对比 + Args: + output_data: 待保存的输出数据(支持numpy数组、Python字典、MindSpore Tensor、PyTorch Tensor) + module_name: 模块名称(如"data_loader_add2d"、"unimol_model_forward",用于区分不同模块) + sample_idx: 样本索引(如0、1、2,确保MindSpore和PyTorch的同一样本输出对应) + save_dir: 输出文件保存目录(默认"uni_mol_outputs",自动创建) + Returns: + save_path: 输出文件的完整路径(供调用者记录或后续对比使用) + """ + # 1. 创建保存目录(若不存在则自动创建,避免报错) + os.makedirs(save_dir, exist_ok=True) + + # 2. 自动识别框架(MindSpore/PyTorch),统一转为numpy数组 + framework = "unknown" + # 处理MindSpore Tensor:转为numpy数组 + try: + import mindspore as ms + if isinstance(output_data, ms.Tensor): + output_data = output_data.asnumpy() + framework = "mindspore" + except ImportError: + pass + # 处理PyTorch Tensor:转为numpy数组(并转移到CPU) + try: + import torch + if isinstance(output_data, torch.Tensor): + output_data = output_data.detach().cpu().numpy() + framework = "pytorch" + except ImportError: + pass + # 若未识别到框架(如直接传入numpy数组/字典),根据目录名推断(预留) + if framework == "unknown": + if "mindspore" in save_dir.lower(): + framework = "mindspore" + elif "pytorch" in save_dir.lower(): + framework = "pytorch" + + # 3. 生成输出文件名(格式:框架_模块名_样本索引.npz,便于后续对应) + file_name = f"{framework}_{module_name}_idx{sample_idx}.npz" + save_path = os.path.join(save_dir, file_name) # 拼接完整路径 + + # 4. 保存数据(支持字典和单个数组两种格式) + if isinstance(output_data, dict): + # 若为字典:保存所有键值对(如{"atoms": ..., "coordinates": ...}) + np.savez(save_path, **output_data) + else: + # 若为单个数组:用"output"作为键保存(统一格式) + np.savez(save_path, output=output_data) + + # 5. 打印保存信息(便于调试,确认保存路径) + print(f"💾 已保存 {framework} 模块输出:") + print(f" 模块名:{module_name} | 样本索引:{sample_idx}") + print(f" 保存路径:{save_path}") + + return save_path + + +# -------------------------- 以下为可选的辅助函数(本地对比时会用到,提前预留) -------------------------- +def compare_output(pytorch_save_path, mindspore_save_path, module_name, sample_idx): + """ + 对比PyTorch和MindSpore的输出精度(预留函数,供本地对比脚本使用) + Args: + pytorch_save_path: PyTorch输出文件的路径(.npz格式) + mindspore_save_path: MindSpore输出文件的路径(.npz格式) + module_name: 模块名称(用于日志输出,区分不同模块) + sample_idx: 样本索引(用于日志输出,定位同一样本) + Returns: + is_aligned: 是否对齐(True=误差≤1e-5,False=误差>1e-5) + error_info: 误差详情(字典,包含最大绝对误差、最大相对误差等) + """ + # 1. 加载两端输出文件(若文件不存在,直接报错并返回) + if not os.path.exists(pytorch_save_path): + raise FileNotFoundError(f"PyTorch输出文件不存在:{pytorch_save_path}") + if not os.path.exists(mindspore_save_path): + raise FileNotFoundError(f"MindSpore输出文件不存在:{mindspore_save_path}") + + # 2. 读取.npz文件(支持字典格式) + pt_data = np.load(pytorch_save_path, allow_pickle=True) + ms_data = np.load(mindspore_save_path, allow_pickle=True) + + # 3. 
检查两端的键是否一致(若为字典格式,键必须完全相同,否则无法对比) + pt_keys = set(pt_data.keys()) + ms_keys = set(ms_data.keys()) + if pt_keys != ms_keys: + error_info = { + "module": module_name, + "sample_idx": sample_idx, + "status": "error", + "reason": "键不匹配", + "pytorch_keys": list(pt_keys), + "mindspore_keys": list(ms_keys) + } + print(f"❌ {module_name}(样本{sample_idx}):PyTorch和MindSpore键不匹配!") + print(f" PyTorch键:{pt_keys}") + print(f" MindSpore键:{ms_keys}") + return False, error_info + + # 4. 逐个键对比精度(对齐标准:最大绝对误差≤1e-5 且 最大相对误差≤1e-5) + error_info = { + "module": module_name, + "sample_idx": sample_idx, + "status": "aligned", + "details": {} + } + is_aligned = True # 默认对齐,若有一个键未对齐则改为False + + for key in pt_keys: + # 读取并统一转为float32(避免数据类型差异导致的误差) + pt_arr = pt_data[key].astype(np.float32) + ms_arr = ms_data[key].astype(np.float32) + + # 检查形状是否一致(形状不同直接判定未对齐) + if pt_arr.shape != ms_arr.shape: + error_info["status"] = "not_aligned" + error_info["details"][key] = { + "status": "not_aligned", + "reason": "形状不匹配", + "pytorch_shape": pt_arr.shape, + "mindspore_shape": ms_arr.shape + } + is_aligned = False + print(f"❌ {module_name}(样本{sample_idx})键[{key}]:形状不匹配!") + print(f" PyTorch形状:{pt_arr.shape} | MindSpore形状:{ms_arr.shape}") + continue + + # 计算误差(绝对误差、相对误差) + abs_error = np.abs(pt_arr - ms_arr) # 逐元素绝对误差 + max_abs_error = np.max(abs_error) # 最大绝对误差(关键指标) + # 相对误差:避免除以0,分母取max(绝对值, 1e-10) + rel_error = abs_error / np.maximum(np.abs(pt_arr), 1e-10) + max_rel_error = np.max(rel_error) # 最大相对误差(关键指标) + + # 判断当前键是否对齐(误差≤1e-5为对齐) + key_aligned = (max_abs_error <= 1e-5) and (max_rel_error <= 1e-5) + if not key_aligned: + is_aligned = False + error_info["status"] = "not_aligned" + + # 记录当前键的误差详情 + error_info["details"][key] = { + "status": "aligned" if key_aligned else "not_aligned", + "max_abs_error": float(max_abs_error), # 转为Python float(便于后续保存) + "max_rel_error": float(max_rel_error), + "pytorch_mean": float(np.mean(pt_arr)), # PyTorch输出均值(参考) + "mindspore_mean": float(np.mean(ms_arr)) # MindSpore输出均值(参考) + } + + # 打印当前键的对比结果 + log_prefix = "✅" if key_aligned else "❌" + print(f"{log_prefix} {module_name}(样本{sample_idx})键[{key}]:") + print(f" 最大绝对误差:{max_abs_error:.8f}(阈值≤1e-5)") + print(f" 最大相对误差:{max_rel_error:.8f}(阈值≤1e-5)") + print(f" PyTorch均值:{np.mean(pt_arr):.6f} | MindSpore均值:{np.mean(ms_arr):.6f}") + + return is_aligned, error_info \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unimol/Alignment/uni_mol_test_data.npz b/MindChemistry/applications/Uni-Mol/unimol/Alignment/uni_mol_test_data.npz new file mode 100644 index 0000000000000000000000000000000000000000..77a5496ad4bcdd41791970b2240dfa8e66afa2db Binary files /dev/null and b/MindChemistry/applications/Uni-Mol/unimol/Alignment/uni_mol_test_data.npz differ diff --git a/MindChemistry/applications/Uni-Mol/unimol/README.md b/MindChemistry/applications/Uni-Mol/unimol/README.md new file mode 100644 index 0000000000000000000000000000000000000000..6b0b93023cd1dd7da3c3f832a4ec5a0cdcddc4cc --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/README.md @@ -0,0 +1,546 @@ +Uni-Mol: A Universal 3D Molecular Representation Learning Framework +=================================================================== + +[[Paper](https://openreview.net/forum?id=6K2RM6wVqKu)], [[Uni-Mol Docking Colab](https://colab.research.google.com/github/deepmodeling/Uni-Mol/blob/main/unimol/notebooks/unimol_binding_pose_demo.ipynb)] + +Authors: Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke + +Uni-Mol is a 
universal 3D molecular pretraining framework that significantly expands the representation capacity and application scope of molecular representation learning in drug design. +

+![overview](figure/overview.png)
+
+Schematic illustration of the Uni-Mol framework
+ +Uni-Mol comprises two models: a molecular pretraining model that has been trained using 209M molecular 3D conformations, and a pocket pretraining model that has been trained using 3M candidate protein pocket data. These two models can be used independently for different tasks and are combined for protein-ligand binding tasks. Uni-Mol has demonstrated superior performance compared to the state-of-the-art (SOTA) in 14 out of 15 molecular property prediction tasks. Moreover, Uni-Mol has achieved exceptional accuracy in 3D spatial tasks, such as protein-ligand binding pose prediction and molecular conformation generation. + + +Uni-Mol's 3D conformation data +------------------------------ + +For the details of datasets, please refer to Appendix A and B in our [paper](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361). + +There are total 6 datasets: + + +| Data | File Size | Update Date | Download Link | +|--------------------------|------------| ----------- |---------------------------------------------------------------------------------------------------------------------------| +| molecular pretrain | 114.76GB | Jun 10 2022 |https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/pretrain/ligands.tar.gz | +| pocket pretrain | 8.585GB | Aug 17 2022 |https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/pretrain/pockets.tar.gz | +| molecular property | 3.506GB | Jul 10 2022 |https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/finetune/molecular_property_prediction.tar.gz | +| molecular conformation | 8.331GB | Jul 19 2022 |https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/finetune/conformation_generation.tar.gz | +| pocket property | 455.239MB | Jul 19 2022 |https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/finetune/pocket_property_prediction.tar.gz | +| protein-ligand binding | 263.27MB | Sep 8 2022 |https://bioos-hermite-beijing.tos-cn-beijing.volces.com/unimol_data/finetune/protein_ligand_binding_pose_prediction.tar.gz | + + +We use [LMDB](https://lmdb.readthedocs.io) to store data, you can use the following code snippets to read from the LMDB file. + +```python +import lmdb +import numpy as np +import os +import pickle + +def read_lmdb(lmdb_path): + env = lmdb.open( + lmdb_path, + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + max_readers=256, + ) + txn = env.begin() + keys = list(txn.cursor().iternext(values=False)) + for idx in keys: + datapoint_pickled = txn.get(idx) + data = pickle.loads(datapoint_pickled) +``` +We use pickle protocol 5, so Python >= 3.8 is recommended. 
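+As a quick sanity check on a downloaded shard, the reader above can be extended to close the environment and print what the first few records contain. The snippet below is an illustrative sketch only: the helper name `peek_lmdb` is ours, and it assumes each pickled record is a dictionary (for example with `atoms` and `coordinates` entries, as in the example data under `example_data/molecule/`), which may vary between datasets.
+
+```python
+import lmdb
+import pickle
+
+def peek_lmdb(lmdb_path, n=1):
+    """Print the key and field names of the first n records in an LMDB file."""
+    env = lmdb.open(
+        lmdb_path, subdir=False, readonly=True, lock=False,
+        readahead=False, meminit=False, max_readers=256,
+    )
+    with env.begin() as txn:
+        for i, (key, value) in enumerate(txn.cursor().iternext()):
+            if i >= n:
+                break
+            record = pickle.loads(value)
+            # Most finetune/pretrain records are dicts; fall back to the type name otherwise.
+            fields = sorted(record.keys()) if isinstance(record, dict) else type(record).__name__
+            print(f"key={key!r} fields={fields}")
+    env.close()
+
+# e.g. peek_lmdb("./example_data/molecule/train.lmdb")
+```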
+ + +Uni-Mol's pretrained model weights +---------------------------------- + +| Model | File Size |Update Date | Download Link | +|--------------------------|------------| ------------|--------------------------------------------------------------| +| molecular pretrain | 181MB | Aug 17 2022 |https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/mol_pre_no_h_220816.pt | +| pocket pretrain | 181MB | Aug 17 2022 |https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/pocket_pre_220816.pt | + + +Uni-Mol's finetuned model weights +---------------------------------- + +| Model | File Size| Update Date| Download Link | +|-------------------------------------------------|---------| -----------|--------------------------------------------------------------------| +| molecular conformation generation (qm9) | 181MB | Sep 8 2022 |https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/qm9_220908.pt | +| molecular conformation generation (drugs) | 181MB | Sep 8 2022 |https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/drugs_220908.pt | +| Protein-ligand binding pose prediction | 415MB | Sep 8 2022 |https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/binding_pose_220908.pt | + + +Dependencies +------------ + - [Uni-Core](https://github.com/dptech-corp/Uni-Core), check its [Installation Documentation](https://github.com/dptech-corp/Uni-Core#installation). + - rdkit==2022.9.3, install via `pip install rdkit-pypi==2022.9.3` + +To use GPUs within docker you need to [install nvidia-docker-2](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/install-guide.html#docker) first. Use the following command to pull the docker image: + +```bash +docker pull dptechnology/unimol:latest-pytorch1.11.0-cuda11.3 +``` + +Molecular Pretraining +--------------------- + +```bash +data_path=./example_data/molecule/ # replace to your data path +save_dir=./save/ # replace to your save path +n_gpu=8 +MASTER_PORT=10086 +lr=1e-4 +wd=1e-4 +batch_size=16 +update_freq=1 +masked_token_loss=1 +masked_coord_loss=5 +masked_dist_loss=10 +x_norm_loss=0.01 +delta_pair_repr_norm_loss=0.01 +mask_prob=0.15 +only_polar=0 +noise_type="uniform" +noise=1.0 +seed=1 +warmup_steps=10000 +max_steps=1000000 + +export NCCL_ASYNC_ERROR_HANDLING=1 +export OMP_NUM_THREADS=1 +python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \ + --num-workers 8 --ddp-backend=c10d \ + --task unimol --loss unimol --arch unimol_base \ + --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 --weight-decay $wd \ + --lr-scheduler polynomial_decay --lr $lr --warmup-updates $warmup_steps --total-num-update $max_steps \ + --update-freq $update_freq --seed $seed \ + --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir $save_dir/tsb \ + --max-update $max_steps --log-interval 10 --log-format simple \ + --save-interval-updates 10000 --validate-interval-updates 10000 --keep-interval-updates 10 --no-epoch-checkpoints \ + --masked-token-loss $masked_token_loss --masked-coord-loss $masked_coord_loss --masked-dist-loss $masked_dist_loss \ + --x-norm-loss $x_norm_loss --delta-pair-repr-norm-loss $delta_pair_repr_norm_loss \ + --mask-prob $mask_prob --noise-type $noise_type --noise $noise --batch-size $batch_size \ + --save-dir $save_dir --only-polar $only_polar + +``` +The above setting is for 8 V100 GPUs, and the batch size is 128 (`n_gpu * batch_size * 
update_freq`). You may need to change `batch_size` or `update_freq` according to your environment. + +Pocket Pretraining +------------------ + +```bash +data_path=./example_data/pocket/ # replace to your data path +save_dir=./save/ # replace to your save path +n_gpu=8 +MASTER_PORT=10086 +dict_name="dict_coarse.txt" +lr=1e-4 +wd=1e-4 +batch_size=16 +update_freq=1 +masked_token_loss=1 +masked_coord_loss=1 +masked_dist_loss=1 +x_norm_loss=0.01 +delta_pair_repr_norm_loss=0.01 +mask_prob=0.15 +noise_type="uniform" +noise=1.0 +seed=1 +warmup_steps=10000 +max_steps=1000000 + +export NCCL_ASYNC_ERROR_HANDLING=1 +export OMP_NUM_THREADS=1 +python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \ + --num-workers 8 --ddp-backend=c10d \ + --task unimol_pocket --loss unimol --arch unimol_base \ + --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 --weight-decay $wd \ + --lr-scheduler polynomial_decay --lr $lr --warmup-updates $warmup_steps --total-num-update $max_steps \ + --update-freq $update_freq --seed $seed \ + --dict-name $dict_name \ + --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --tensorboard-logdir $save_dir/tsb \ + --max-update $max_steps --log-interval 10 --log-format simple \ + --save-interval-updates 10000 --validate-interval-updates 10000 --keep-interval-updates 10 \ + --masked-token-loss $masked_token_loss --masked-coord-loss $masked_coord_loss --masked-dist-loss $masked_dist_loss \ + --x-norm-loss $x_norm_loss --delta-pair-repr-norm-loss $delta_pair_repr_norm_loss \ + --mask-prob $mask_prob --noise-type $noise_type --noise $noise --batch-size $batch_size \ + --save-dir $save_dir + +``` +The above setting is for 8 V100 GPUs, and the batch size is 128 (`n_gpu * batch_size * update_freq`). You may need to change `batch_size` or `update_freq` according to your environment. 
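+When changing the number of devices, keeping the product `n_gpu * batch_size * update_freq` fixed preserves the effective batch size used above. The helper below is a small illustrative sketch (`pick_update_freq` is not part of the training scripts) for choosing `update_freq` given a target effective batch size.
+
+```python
+def pick_update_freq(target_batch, n_gpu, local_batch_size):
+    """Gradient-accumulation steps so that n_gpu * local_batch_size * update_freq == target_batch."""
+    if target_batch % (n_gpu * local_batch_size) != 0:
+        raise ValueError("target_batch must be divisible by n_gpu * local_batch_size")
+    return target_batch // (n_gpu * local_batch_size)
+
+# Reference setting above: 8 GPUs x batch_size 16 x update_freq 1 -> effective batch 128
+assert pick_update_freq(128, n_gpu=8, local_batch_size=16) == 1
+# Same effective batch size on 2 GPUs with local batch 16:
+assert pick_update_freq(128, n_gpu=2, local_batch_size=16) == 4
+```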
+ + +Molecular Property Prediction +------------------ + +```bash +data_path="./molecular_property_prediction" # replace to your data path +save_dir="./save_finetune" # replace to your save path +n_gpu=4 +MASTER_PORT=10086 +dict_name="dict.txt" +weight_path="./weights/checkpoint.pt" # replace to your ckpt path +task_name="qm9dft" # molecular property prediction task name +task_num=3 +loss_func="finetune_smooth_mae" +lr=1e-4 +batch_size=32 +epoch=40 +dropout=0 +warmup=0.06 +local_batch_size=32 +only_polar=0 +conf_size=11 +seed=0 + +if [ "$task_name" == "qm7dft" ] || [ "$task_name" == "qm8dft" ] || [ "$task_name" == "qm9dft" ]; then + metric="valid_agg_mae" +elif [ "$task_name" == "esol" ] || [ "$task_name" == "freesolv" ] || [ "$task_name" == "lipo" ]; then + metric="valid_agg_rmse" +else + metric="valid_agg_auc" +fi + +export NCCL_ASYNC_ERROR_HANDLING=1 +export OMP_NUM_THREADS=1 +update_freq=`expr $batch_size / $local_batch_size` +python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --task-name $task_name --user-dir ./unimol --train-subset train --valid-subset valid \ + --conf-size $conf_size \ + --num-workers 8 --ddp-backend=c10d \ + --dict-name $dict_name \ + --task mol_finetune --loss $loss_func --arch unimol_base \ + --classification-head-name $task_name --num-classes $task_num \ + --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 \ + --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $local_batch_size --pooler-dropout $dropout\ + --update-freq $update_freq --seed $seed \ + --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ + --log-interval 100 --log-format simple \ + --validate-interval 1 \ + --finetune-from-model $weight_path \ + --best-checkpoint-metric $metric --patience 20 \ + --save-dir $save_dir --only-polar $only_polar + +# --maximize-best-checkpoint-metric, for classification task + +``` + +To speed up finetune, we set `n_gpu=4` for QM9, MUV, PCBA and HIV, and `n_gpu=1` for others, and the batch size is `n_gpu * local_batch_size * update_freq`. +For classification task, we set `--maximize-best-checkpoint-metric`. + +Each task will be run by 3 different seeds. We choose the checkpoint with the best metric on validation set and report the mean and standard deviation of the three results on the test set. + +For the selection of `task_num` and other hyperparameters, please refer to the following table: + +- Classification + +|Dataset | BBBP | BACE | ClinTox | Tox21 | ToxCast | SIDER | HIV | PCBA | MUV | +|--------|----|----|----|----|----|-----|-----|----|-----| +| task_num | 2 | 2 | 2 | 12 | 617 | 27 | 2 | 128 | 17 | +| lr | 4e-4 | 1e-4 | 5e-5 | 1e-4 | 1e-4 | 5e-4 | 5e-5 | 1e-4 | 2e-5 | +| batch_size | 128 | 64 | 256 | 128 | 64 | 32 | 256 | 128 | 128 | +| epoch | 40 | 60 | 100 | 80 | 80 | 80 | 5 | 20 | 40 | +| dropout | 0 | 0.1 | 0.5 | 0.1 | 0.1 | 0 | 0.2 | 0.1 | 0 | +| warmup | 0.06 | 0.06 | 0.1 | 0.06 | 0.06 | 0.1 | 0.1 | 0.06 | 0 | + +For BBBP, BACE and HIV, we set `loss_func=finetune_cross_entropy`. +For ClinTox, Tox21, ToxCast, SIDER, HIV, PCBA and MUV, we set `loss_func=multi_task_BCE`. 
+ +- Regression + +| Dataset | ESOL | FreeSolv | Lipo | QM7 | QM8 | QM9 | +|----- | ---- | ---- | ---- | ---- | --- | --- | +| task_num | 1 | 1 | 1 | 1 | 12 | 3 | +| lr | 5e-4 | 8e-5 | 1e-4 | 3e-4 | 1e-4 | 1e-4 | +| batch_size | 256 | 64 | 32 | 32 | 32 | 128 | +| epoch | 100 | 60 | 80 | 100 | 40 | 40 | +| dropout | 0.2 | 0.2 | 0.1 | 0 | 0 | 0 | +| warmup | 0.06 | 0.1 | 0.06 | 0.06 | 0.06 | 0.06 | + + +For ESOL, FreeSolv and Lipo, we set `loss_func=finetune_mse`. +For QM7, QM8 and QM9, we set `loss_func=finetune_smooth_mae`. + +**NOTE**: Our first version of the molecular pretraining ran with **all hydrogen** pretrained model, and above hyper-parameters are also for **all hydrogen** pretrained model. You can download the [all hydrogen model parameter](https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/mol_pre_all_h_220816.pt) here, and use it with `only_polar=-1` to reproduce our results. The performance of pretraining model with **no hydrogen** is very close to the **all hydrogen** one in molecular property prediction. We will update the hyperparameters for the no hydrogen version later. + +**NOTE**: For reproduce, you can do the validation on test set while training, with `--valid-subset valid` changing to `--valid-subset valid,test`. The model selection is still based on the performance of the valid set. It is controlled by `--best-checkpoint-metric $metric`. + +**NOTE**: You"d better align the `only_polar` parameter in pretraining and finetuning: `-1` for all hydrogen, `0` for no hydrogen, `1` for polar hydrogen. + + +Molecular conformation generation +------------------ + +**NOTE**: If you would like to reproduce the results from the paper, you can switch to commit 37b0198 or an earlier commit by using the following command: +``` +git checkout 37b0198cf68a349a854410a06777c2e7dacbce5e +``` +**Reproduction** + +1. Finetune Uni-Mol pretrained model on the training set of the conformation generation task: + +```bash +data_path="./conformation_generation" # replace to your data path +save_dir="./save_confgen" # replace to your save path +n_gpu=1 +MASTER_PORT=10086 +dict_name="dict.txt" +weight_path="./weights/checkpoint.pt" # replace to your ckpt path +task_name="qm9" # or "drugs", conformation generation task name, as a part of complete data path +recycles=4 +coord_loss=1 +distance_loss=1 +beta=4.0 +smooth=0.1 +topN=20 +lr=2e-5 +batch_size=128 +epoch=50 +warmup=0.06 +update_freq=1 + +export NCCL_ASYNC_ERROR_HANDLING=1 +export OMP_NUM_THREADS=1 +python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --task-name $task_name --user-dir ./unimol --train-subset train --valid-subset valid \ + --num-workers 8 --ddp-backend=c10d \ + --task mol_confG --loss mol_confG --arch mol_confG \ + --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 \ + --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $batch_size \ + --update-freq $update_freq --seed 1 \ + --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ + --log-interval 100 --log-format simple --tensorboard-logdir $save_dir/tsb \ + --validate-interval 1 --keep-last-epochs 10 \ + --keep-interval-updates 10 --best-checkpoint-metric loss --patience 50 --all-gather-list-size 102400 \ + --finetune-mol-model $weight_path --save-dir $save_dir \ + --coord-loss $coord_loss --distance-loss $distance_loss \ + --num-recycles $recycles --beta $beta --smooth $smooth --topN $topN \ + --find-unused-parameters + +``` + +2. 
Generate initial RDKit conformations for inference: + +- Run this command, + +```bash +mode="gen_data" +nthreads=20 # Num of threads +reference_file="./conformation_generation/qm9/test_data_200.pkl" # Your reference file dir +output_dir="./conformation_generation/qm9" # Generated initial data dir + +python ./unimol/utils/conf_gen_cal_metrics.py --mode $mode --nthreads $nthreads --reference-file $reference_file --output-dir $output_dir + +``` + +3. Inference on the generated RDKit initial conformations: + +```bash +data_path="./conformation_generation" # replace to your data path +results_path="./infer_confgen" # replace to your results path +weight_path="./save_confgen/checkpoint_best.pt" # replace to your ckpt path +batch_size=128 +task_name="qm9" # or "drugs", conformation generation task name +recycles=4 + +python ./unimol/infer.py --user-dir ./unimol $data_path --task-name $task_name --valid-subset test \ + --results-path $results_path \ + --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \ + --task mol_confG --loss mol_confG --arch mol_confG \ + --num-recycles $recycles \ + --path $weight_path \ + --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ + --log-interval 50 --log-format simple +``` +- For reproduce, you can also use the finetuned checkpoint we released in the table above to infer. + +- **NOTE**: Currently, the inference is only supported to run on a single GPU. You can add `CUDA_VISIBLE_DEVICES="0"` before the command. + +4. Calculate metrics on the results of inference: + +- Run this command +```bash +mode="cal_metrics" +threshold=0.5 # Threshold for cal metrics, 0.5 for qm9, 1.25 for drugs +nthreads=20 # Num of threads +predict_file="./infer_confgen/save_confgen_test.out.pkl" # Your inference file dir +reference_file="./conformation_generation/qm9/test_data_200.pkl" # Your reference file dir + +python ./unimol/utils/conf_gen_cal_metrics.py --mode $mode --threshold $threshold --nthreads $nthreads --predict-file $predict_file --reference-file $reference_file +``` + + +Pocket Property Prediction +------------------ + +```bash +data_path="./pocket_property_prediction" # replace to your data path +save_dir="./save_finetune" # replace to your save path +n_gpu=1 +MASTER_PORT=10086 +dict_name="dict_coarse.txt" +weight_path="./weights/checkpoint.pt" +task_name="druggability" # or "nrdld", pocket property prediction dataset folder name +lr=3e-4 +batch_size=32 +epoch=20 +dropout=0 +warmup=0.1 +local_batch_size=32 +seed=1 + +if [ "$task_name" == "druggability" ]; then + metric="valid_rmse" + loss_func="finetune_mse_pocket" + task_num=1 + fpocket_score="Druggability Score" # choose in ["Score", "Druggability Score", "Total SASA", "Hydrophobicity score"] +else + metric="loss" + loss_func="finetune_cross_entropy_pocket" + task_num=2 +fi + +export NCCL_ASYNC_ERROR_HANDLING=1 +export OMP_NUM_THREADS=1 +update_freq=`expr $batch_size / $local_batch_size` +python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --task-name $task_name --user-dir ./unimol --train-subset train --valid-subset valid \ + --num-workers 8 --ddp-backend=c10d \ + --dict-name $dict_name \ + --task pocket_finetune --loss $loss_func --arch unimol_base \ + --classification-head-name $task_name --num-classes $task_num \ + --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 \ + --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $local_batch_size --pooler-dropout $dropout \ + 
--update-freq $update_freq --seed $seed \ + --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ + --log-interval 100 --log-format simple \ + --validate-interval 1 --finetune-from-model $weight_path \ + --best-checkpoint-metric $metric --patience 2000 \ + --save-dir $save_dir --remove-hydrogen --fpocket-score "$fpocket_score" + +# --maximize-best-checkpoint-metric, for classification task + +``` + +The batch size is `n_gpu * local_batch_size * update_freq`. +For classification task, we set `--maximize-best-checkpoint-metric`. + +We choose the checkpoint with the best metric on validation set. It is controlled by `--best-checkpoint-metric $metric`. Specifically, for NRDLD, since it has no validation set, we choose the checkpoint of the last epoch. For Fpocket Scores, we report the mean and standard deviation of the results for three random seeds. + +**NOTE**: For reproduce, you can do the validation on test set while training, with `--valid-subset valid` changing to `--valid-subset valid,test`. + + +Protein-ligand Binding Pose Prediction +------------------ + +1. Finetune Uni-Mol pretrained model on the training set: + +```bash +data_path="./protein_ligand_binding_pose_prediction" # replace to your data path +save_dir="./save_pose" # replace to your save path +n_gpu=4 +MASTER_PORT=10086 +finetune_mol_model="./weights/mol_checkpoint.pt" +finetune_pocket_model="./weights/pocket_checkpoint.pt" +lr=3e-4 +batch_size=8 +epoch=50 +dropout=0.2 +warmup=0.06 +update_freq=1 +dist_threshold=8.0 +recycling=3 + +export NCCL_ASYNC_ERROR_HANDLING=1 +export OMP_NUM_THREADS=1 +python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --user-dir ./unimol --train-subset train --valid-subset valid \ + --num-workers 8 --ddp-backend=c10d \ + --task docking_pose --loss docking_pose --arch docking_pose \ + --optimizer adam --adam-betas "(0.9, 0.99)" --adam-eps 1e-6 --clip-norm 1.0 \ + --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $batch_size \ + --mol-pooler-dropout $dropout --pocket-pooler-dropout $dropout \ + --fp16 --fp16-init-scale 4 --fp16-scale-window 256 --update-freq $update_freq --seed 1 \ + --tensorboard-logdir $save_dir/tsb \ + --log-interval 100 --log-format simple \ + --validate-interval 1 --keep-last-epochs 10 \ + --best-checkpoint-metric valid_loss --patience 2000 --all-gather-list-size 2048000 \ + --finetune-mol-model $finetune_mol_model \ + --finetune-pocket-model $finetune_pocket_model \ + --dist-threshold $dist_threshold --recycling $recycling \ + --save-dir $save_dir \ + --find-unused-parameters + +``` + +2. Inference on the test set: + +```bash +data_path="./protein_ligand_binding_pose_prediction" # replace to your data path +results_path="./infer_pose" # replace to your results path +weight_path="./save_pose/checkpoint.pt" +batch_size=8 +dist_threshold=8.0 +recycling=3 + +python ./unimol/infer.py --user-dir ./unimol $data_path --valid-subset test \ + --results-path $results_path \ + --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \ + --task docking_pose --loss docking_pose --arch docking_pose \ + --path $weight_path \ + --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \ + --dist-threshold $dist_threshold --recycling $recycling \ + --log-interval 50 --log-format simple +``` +- For reproduce, you can also use the finetuned checkpoint we released in the table above to infer. + +- **NOTE**: Currently, the inference is only supported to run on a single GPU. 
You can add `CUDA_VISIBLE_DEVICES="0"` before the command. + +4. Docking and cal metrics: + +- Run this command +```bash +nthreads=20 # Num of threads +predict_file="./infer_pose/save_pose_test.out.pkl" # Your inference file dir +reference_file="./protein_ligand_binding_pose_prediction/test.lmdb" # Your reference file dir +output_path="./protein_ligand_binding_pose_prediction" # Docking results path + +python ./unimol/utils/docking.py --nthreads $nthreads --predict-file $predict_file --reference-file $reference_file --output-path $output_path +``` + +AIAC 2022 Competition Prediction of protein binding ability of drug molecules +------------------ +- Competition Link: [AIAC 2022 Competition Prediction of protein binding ability of drug molecules](http://www.aiinnovation.com.cn/#/aiaeDetail?id=560). +- Entry and final submission deadline: 2022-09-26 +- Run this command +```bash +git checkout ifd_demo +### download data from competition website and decompress it to ./examples/ifd_docking +sh train_ifd.sh +sh infer_ifd.sh +cd ./examples/ifd_scoring && python generate_submit.py +``` + +Citation +-------- + +Please kindly cite this paper if you use the data/code/model. +``` +@inproceedings{ + zhou2023unimol, + title={Uni-Mol: A Universal 3D Molecular Representation Learning Framework}, + author={Gengmo Zhou and Zhifeng Gao and Qiankun Ding and Hang Zheng and Hongteng Xu and Zhewei Wei and Linfeng Zhang and Guolin Ke}, + booktitle={The Eleventh International Conference on Learning Representations }, + year={2023}, + url={https://openreview.net/forum?id=6K2RM6wVqKu} +} +``` + +License +------- + +This project is licensed under the terms of the MIT license. See [LICENSE](https://github.com/deepmodeling/Uni-Mol/blob/main/LICENSE) for additional details. diff --git a/MindChemistry/applications/Uni-Mol/unimol/docker/Dockerfile b/MindChemistry/applications/Uni-Mol/unimol/docker/Dockerfile new file mode 100644 index 0000000000000000000000000000000000000000..4afe4853ed21d0a780d9a079417ed616d3915ef2 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/docker/Dockerfile @@ -0,0 +1,11 @@ +FROM dptechnology/unicore:0.0.1-pytorch1.11.0-cuda11.3 + +RUN pip install setuptools wheel twine + +RUN pip install rdkit-pypi==2021.9.5.1 + +RUN ldconfig && \ + apt-get clean && \ + apt-get autoremove && \ + rm -rf /var/lib/apt/lists/* /tmp/* && \ + pip cache purge diff --git a/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/dict.txt b/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/dict.txt new file mode 100644 index 0000000000000000000000000000000000000000..4130c254b4da592338b43298b49120a561dfae60 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/dict.txt @@ -0,0 +1,30 @@ +[PAD] +[CLS] +[SEP] +[UNK] +C +N +O +S +H +Cl +F +Br +I +Si +P +B +Na +K +Al +Ca +Sn +As +Hg +Fe +Zn +Cr +Se +Gd +Au +Li \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/train.lmdb b/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/train.lmdb new file mode 100644 index 0000000000000000000000000000000000000000..c0516d4c5fb77d729c21501a132a469481fef207 Binary files /dev/null and b/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/train.lmdb differ diff --git a/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/valid.lmdb b/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/valid.lmdb new file mode 100644 index 
0000000000000000000000000000000000000000..c6560d2e871b5b68adc1b3e658eaa600ab294fd7 Binary files /dev/null and b/MindChemistry/applications/Uni-Mol/unimol/example_data/molecule/valid.lmdb differ diff --git a/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/dict_coarse.txt b/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/dict_coarse.txt new file mode 100644 index 0000000000000000000000000000000000000000..9cd15c6800b2965e3fbe2f4db301b44dbb3dbc97 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/dict_coarse.txt @@ -0,0 +1,9 @@ +[PAD] +[CLS] +[SEP] +[UNK] +C +N +O +S +H diff --git a/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/train.lmdb b/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/train.lmdb new file mode 100644 index 0000000000000000000000000000000000000000..b969b5b77e2f84b727a25578a5714d57f3f2b11b Binary files /dev/null and b/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/train.lmdb differ diff --git a/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/valid.lmdb b/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/valid.lmdb new file mode 100644 index 0000000000000000000000000000000000000000..757ddebd6e12cb871784fad57e32007915e8de31 Binary files /dev/null and b/MindChemistry/applications/Uni-Mol/unimol/example_data/pocket/valid.lmdb differ diff --git a/MindChemistry/applications/Uni-Mol/unimol/figure/overview.png b/MindChemistry/applications/Uni-Mol/unimol/figure/overview.png new file mode 100644 index 0000000000000000000000000000000000000000..21eff9fcc1faa41f9fb1fbc280d66d47111c41e2 Binary files /dev/null and b/MindChemistry/applications/Uni-Mol/unimol/figure/overview.png differ diff --git a/MindChemistry/applications/Uni-Mol/unimol/fusion_result.json b/MindChemistry/applications/Uni-Mol/unimol/fusion_result.json new file mode 100644 index 0000000000000000000000000000000000000000..ec747fa47ddb81e9bf2d282011ed32aa4c59f932 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/fusion_result.json @@ -0,0 +1 @@ +null \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unimol/notebooks/mol_property_demo.csv b/MindChemistry/applications/Uni-Mol/unimol/notebooks/mol_property_demo.csv new file mode 100644 index 0000000000000000000000000000000000000000..3ea1044a45e2ac2d40f290e8d7fe888615b46491 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/notebooks/mol_property_demo.csv @@ -0,0 +1,21 @@ +mol,Class +O=C1N(C)C(=N[C@@]1(c1cc(ccc1)-c1cccnc1)c1ccncc1)N,1 +s1cc(cc1)[C@@]1(N=C(N)N(C)C1=O)c1cc(ccc1)-c1cccnc1,1 +O=C(NC1CCCCC1)CCc1cc2cc(ccc2nc1N)-c1ccccc1C,1 +S(=O)(=O)(C(CCC)CCC)C[C@@H](NC(OCc1ccccc1)=O)C(=O)N[C@H]([C@H](O)C[NH2+]Cc1cc(OC)ccc1)Cc1ccccc1,0 +S1(=O)(=O)C[C@@H](Cc2cc(F)c3NCC4(CCC(F)(F)CC4)c3c2)[C@H](O)[C@@H]([NH2+]Cc2cc(ccc2)C(C)(C)C)C1,0 +O=C(N1CC[C@H](C[C@H]1c1ccccc1)c1ccccc1)[C@@H]1C[NH2+]C[C@]12CCCc1c2cccc1,0 +S1(=O)(=O)C[C@@H](Cc2cc(C[C@@H]3N(CCC)C(OC3)=O)c(O)cc2)[C@H](O)[C@@H]([NH2+]Cc2cc(ccc2)C(C)C)C1,0 +O(C)c1ccc(cc1C)[C@@]1(N=C(N)N(C)C1=O)C12CC3CC(C1)CC(C2)C3,0 +Clc1cc2CC([NH+]=C(N[C@@H](Cc3ccccc3)C=3NC(=O)c4c(N=3)ccnc4)c2cc1)(C)C,0 +O=C1N(CCC1)c1cc(cc(NCC)c1)C(=O)N[C@H]([C@H](O)C[NH2+]C1CCCCC1)Cc1ccccc1,0 +Fc1ccc(NC(=O)c2ncc(OCC)cc2)cc1[C@]1(N=C(OCC1(F)F)N)C,0 +O(C)c1cc(ccc1)C[NH2+]C[C@@H](O)[C@@H](NC(=O)c1cc(ccc1)C(=O)N(CCC)CCC)Cc1ccccc1,0 +FC(F)(F)c1cc(ccc1)C[NH2+]C[C@@H](O)[C@@H](NC(=O)C=1C=C(N2CCCC2=O)C(=O)N(C=1)C1CCCC1)Cc1ccccc1,1 +Fc1ccc(NC(=O)c2ncc(cc2)C#N)cc1[C@]1(N=C(OC[C@@H]1F)N)CF,1 
+FC1(F)CN2C(=NC1)[C@]([NH+]=C2N)(c1cc(ccc1)C#CCOC)c1ccc(OC(F)F)cc1,1 +[NH+]=1[C@](N=C(C)C=1N)(C1CC1)c1cc(ccc1)-c1cc(cnc1)C#CC,1 +Fc1ccc(cc1-c1cncnc1)[C@]1([NH+]=C(N)c2c1cccc2F)c1cc(ncc1)C(F)(F)F,0 +O=C1N(C)C(=N[C@@]1(c1cc(ccc1)-c1cncnc1)c1cn(nc1)CC)N,1 +O(C)c1cc(ccc1)-c1cc(ccc1)CC[C@]1(N=C(N)N(C)C(=O)C1)C,0 +O1c2c(cc(cc2)-c2cc(ccc2)C#N)[C@]2(N=C(N)N(C)C2=O)CC1(C)C,1 \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_binding_pose_demo.ipynb b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_binding_pose_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..790afc8024ca4738f35f9a6af9a8957b0d290a70 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_binding_pose_demo.ipynb @@ -0,0 +1,442 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "4eem1ns5n7Eq" + }, + "source": [ + "# Uni-Mol Binding Pose Prediction Colab\n", + "\n", + "This Colab notebook provides an online runnable version of [Uni-Mol](https://github.com/deepmodeling/Uni-Mol/) binding pose prediction (short for \"docking\" in the following) with custom settings.\n", + "Uni-Mol docking is very fast in dozens of seconds. \n", + "\n", + "Please note that this Colab notebook is not a finished product and is provided as an early-access prototype. It is provided for theoretical modeling only and caution should be exercised in its use. \n", + "\n", + "**Licenses**\n", + "\n", + "This Colab uses the [Uni-Mol model parameters](https://github.com/deepmodeling/Uni-Mol/LICENSE) and its outputs are under the terms of the Creative Commons Attribution 4.0 International (CC BY 4.0) license. You can find details at: https://creativecommons.org/licenses/by/4.0/legalcode. The Colab is provided under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0).\n", + "\n", + "\n", + "**Citations**\n", + "\n", + "Please cite the following papers if you use this notebook:\n", + " \n", + "* Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\n", + "\" ChemRxiv (2022)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "p6uWJIpRQR6y" + }, + "outputs": [], + "source": [ + "%%bash\n", + "#@title Install dependencies\n", + "\n", + "GIT_REPO='https://github.com/deepmodeling/Uni-Mol'\n", + "UNICORE_URL='https://github.com/dptech-corp/Uni-Core/releases/download/0.0.2/unicore-0.0.1+cu116torch1.13.1-cp39-cp39-linux_x86_64.whl'\n", + "DOCKING_DATA_URL='https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/CASF-2016.tar.gz'\n", + "DOCKING_WEIGHT_URL='https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/binding_pose_220908.pt'\n", + "if [ ! 
-f UNIMOL_READY ]; then\n", + " wget -q ${UNICORE_URL} \n", + " pip3 -q install \"unicore-0.0.1+cu116torch1.13.1-cp39-cp39-linux_x86_64.whl\" \n", + " rm -rf ./Uni-Mol\n", + " git clone -b main ${GIT_REPO}\n", + " pip3 install -q ./Uni-Mol/unimol\n", + " pip install -q rdkit\n", + " pip install -q biopandas\n", + " wget -q ${DOCKING_DATA_URL}\n", + " tar -xzf \"CASF-2016.tar.gz\"\n", + " wget -q ${DOCKING_WEIGHT_URL}\n", + " pip install -q py3Dmol\n", + " touch UNIMOL_READY\n", + "fi\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "242RLQ0JVWns" + }, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "import numpy as np\n", + "import pandas as pd\n", + "import biopandas\n", + "import lmdb\n", + "from biopandas.pdb import PandasPdb\n", + "from rdkit import Chem\n", + "from rdkit.Chem import AllChem\n", + "from sklearn.cluster import KMeans\n", + "from rdkit.Chem import rdMolTransforms\n", + "from rdkit.Chem.rdMolAlign import AlignMolConformers\n", + "from unimol.utils.docking_utils import docking_data_pre, ensemble_iterations\n", + "from tqdm import tqdm\n", + "import pickle\n", + "import re\n", + "import json\n", + "import copy\n", + "\n", + "CASF_PATH = \"CASF-2016\"\n", + "main_atoms = [\"N\", \"CA\", \"C\", \"O\", \"H\"]\n", + "\n", + "\n", + "def load_from_CASF(pdb_id):\n", + " try:\n", + " pdb_path = os.path.join(CASF_PATH, \"casf2016\", pdb_id + \"_protein.pdb\")\n", + " pmol = PandasPdb().read_pdb(pdb_path)\n", + " pocket_residues = json.load(\n", + " open(os.path.join(CASF_PATH, \"casf2016.pocket.json\"))\n", + " )[pdb_id]\n", + " return pmol, pocket_residues\n", + " except:\n", + " print(\"Currently not support parsing pdb and pocket info from local files.\")\n", + "\n", + "\n", + "def normalize_atoms(atom):\n", + " return re.sub(\"\\d+\", \"\", atom)\n", + "\n", + "\n", + "def single_conf_gen(tgt_mol, num_confs=1000, seed=42, removeHs=True):\n", + " mol = copy.deepcopy(tgt_mol)\n", + " mol = Chem.AddHs(mol)\n", + " allconformers = AllChem.EmbedMultipleConfs(\n", + " mol, numConfs=num_confs, randomSeed=seed, clearConfs=True\n", + " )\n", + " sz = len(allconformers)\n", + " for i in range(sz):\n", + " try:\n", + " AllChem.MMFFOptimizeMolecule(mol, confId=i)\n", + " except:\n", + " continue\n", + " if removeHs:\n", + " mol = Chem.RemoveHs(mol)\n", + " return mol\n", + "\n", + "\n", + "def clustering_coords(mol, M=1000, N=100, seed=42, removeHs=True):\n", + " rdkit_coords_list = []\n", + " rdkit_mol = single_conf_gen(mol, num_confs=M, seed=seed, removeHs=removeHs)\n", + " noHsIds = [\n", + " rdkit_mol.GetAtoms()[i].GetIdx()\n", + " for i in range(len(rdkit_mol.GetAtoms()))\n", + " if rdkit_mol.GetAtoms()[i].GetAtomicNum() != 1\n", + " ]\n", + " ### exclude hydrogens for aligning\n", + " AlignMolConformers(rdkit_mol, atomIds=noHsIds)\n", + " sz = len(rdkit_mol.GetConformers())\n", + " for i in range(sz):\n", + " _coords = rdkit_mol.GetConformers()[i].GetPositions().astype(np.float32)\n", + " rdkit_coords_list.append(_coords)\n", + "\n", + " ### exclude hydrogens for clustering\n", + " rdkit_coords_flatten = np.array(rdkit_coords_list)[:, noHsIds].reshape(sz, -1)\n", + " ids = (\n", + " KMeans(n_clusters=N, random_state=seed)\n", + " .fit_predict(rdkit_coords_flatten)\n", + " .tolist()\n", + " )\n", + " coords_list = [rdkit_coords_list[ids.index(i)] for i in range(N)]\n", + " return coords_list\n", + "\n", + "\n", + "def parser(pdb_id, smiles, seed=42):\n", + " pmol, pocket_residues = 
load_from_CASF(pdb_id)\n", + " pname = pdb_id\n", + " pro_atom = pmol.df[\"ATOM\"]\n", + " pro_hetatm = pmol.df[\"HETATM\"]\n", + "\n", + " pro_atom[\"ID\"] = pro_atom[\"chain_id\"].astype(str) + pro_atom[\n", + " \"residue_number\"\n", + " ].astype(str)\n", + " pro_hetatm[\"ID\"] = pro_hetatm[\"chain_id\"].astype(str) + pro_hetatm[\n", + " \"residue_number\"\n", + " ].astype(str)\n", + "\n", + " pocket = pd.concat(\n", + " [\n", + " pro_atom[pro_atom[\"ID\"].isin(pocket_residues)],\n", + " pro_hetatm[pro_hetatm[\"ID\"].isin(pocket_residues)],\n", + " ],\n", + " axis=0,\n", + " ignore_index=True,\n", + " )\n", + "\n", + " pocket[\"normalize_atom\"] = pocket[\"atom_name\"].map(normalize_atoms)\n", + " pocket = pocket[pocket[\"normalize_atom\"] != \"\"]\n", + " patoms = pocket[\"atom_name\"].apply(normalize_atoms).values.tolist()\n", + " pcoords = [pocket[[\"x_coord\", \"y_coord\", \"z_coord\"]].values]\n", + " side = [0 if a in main_atoms else 1 for a in patoms]\n", + " residues = (\n", + " pocket[\"chain_id\"].astype(str) + pocket[\"residue_number\"].astype(str)\n", + " ).values.tolist()\n", + "\n", + " # generate ligand conformation\n", + " M, N = 100, 10\n", + " mol = Chem.MolFromSmiles(smiles)\n", + " mol = Chem.AddHs(mol)\n", + " AllChem.EmbedMolecule(mol, randomSeed=seed)\n", + " latoms = [atom.GetSymbol() for atom in mol.GetAtoms()]\n", + " holo_coordinates = [mol.GetConformer().GetPositions().astype(np.float32)]\n", + " holo_mol = mol\n", + " coordinate_list = clustering_coords(mol, M=M, N=N, seed=seed, removeHs=False)\n", + " mol_list = [mol] * N\n", + "\n", + " return pickle.dumps(\n", + " {\n", + " \"atoms\": latoms,\n", + " \"coordinates\": coordinate_list,\n", + " \"mol_list\": mol_list,\n", + " \"pocket_atoms\": patoms,\n", + " \"pocket_coordinates\": pcoords,\n", + " \"side\": side,\n", + " \"residue\": residues,\n", + " \"holo_coordinates\": holo_coordinates,\n", + " \"holo_mol\": holo_mol,\n", + " \"holo_pocket_coordinates\": pcoords,\n", + " \"smi\": smiles,\n", + " \"pocket\": pname,\n", + " },\n", + " protocol=-1,\n", + " )\n", + "\n", + "\n", + "def write_lmdb(pdb_id, smiles_list, seed=42, result_dir=\"./results\"):\n", + " os.makedirs(result_dir, exist_ok=True)\n", + " outputfilename = os.path.join(result_dir, pdb_id + \".lmdb\")\n", + " try:\n", + " os.remove(outputfilename)\n", + " except:\n", + " pass\n", + " env_new = lmdb.open(\n", + " outputfilename,\n", + " subdir=False,\n", + " readonly=False,\n", + " lock=False,\n", + " readahead=False,\n", + " meminit=False,\n", + " max_readers=1,\n", + " map_size=int(10e9),\n", + " )\n", + " for i, smiles in enumerate(smiles_list):\n", + " inner_output = parser(pdb_id, smiles, seed=seed)\n", + " txn_write = env_new.begin(write=True)\n", + " txn_write.put(f\"{i}\".encode(\"ascii\"), inner_output)\n", + " txn_write.commit()\n", + " env_new.close()\n", + "\n", + "\n", + "# @title Run Uni-Mol Binding Pose Prediction\n", + "\n", + "# @markdown Currently this scripts only support CASF-2016 dataset with given pockets residues.\n", + "\n", + "# @markdown You can input multiple SMILES, split by ','.\n", + "\n", + "# @markdown If SMILES is not given, the default one in the complex will be used.\n", + "\n", + "pdb_id = \"4ty7\" # @param {type:\"string\"}\n", + "pdb_id = pdb_id.lower()\n", + "casf_collect = os.listdir(os.path.join(CASF_PATH, \"casf2016\"))\n", + "casf_collect = list(set([item[:4] for item in casf_collect]))\n", + "if pdb_id not in casf_collect:\n", + " warning_str = \"{} is not int CASF-2016 dataset, Please select from 
\\n\".format(pdb_id)\n", + " for i in range(15):\n", + " warning_str += \"{}\\n\".format(','.join(casf_collect[20*i:20*(i+1)]))\n", + " raise Exception(warning_str)\n", + "supp = Chem.SDMolSupplier(os.path.join(CASF_PATH, \"casf2016\", pdb_id + \"_ligand.sdf\"))\n", + "mol = [mol for mol in supp if mol][0]\n", + "ori_smiles = Chem.MolToSmiles(mol)\n", + "smiles = \"\" # @param {type:\"string\"}\n", + "seed = 42 # @param {type:\"number\"}\n", + "data_path = \"./CASF-2016\"\n", + "results_path = \"./results/\"\n", + "weight_path = \"/content/binding_pose_220908.pt\"\n", + "batch_size = 8\n", + "dist_threshold = 8.0\n", + "recycling = 3\n", + "if smiles.split(\",\") == 0 or smiles == \"\":\n", + " print(\"No other smiles inputs\")\n", + " smiles_list = [ori_smiles]\n", + "else:\n", + " print(\"Docking with smiles: {}\".format(smiles))\n", + " smiles_list = smiles.split(\",\")\n", + "\n", + "write_lmdb(pdb_id, smiles_list, seed=seed, result_dir=data_path)\n", + "\n", + "!python ./Uni-Mol/unimol/unimol/infer.py --user-dir ./Uni-Mol/unimol/unimol $data_path --valid-subset $pdb_id \\\n", + " --results-path $results_path \\\n", + " --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n", + " --task docking_pose --loss docking_pose --arch docking_pose \\\n", + " --path $weight_path \\\n", + " --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n", + " --dist-threshold $dist_threshold --recycling $recycling \\\n", + " --log-interval 50 --log-format simple\n", + "\n", + "def generate_docking_input(\n", + " predict_file, reference_file, tta_times=10, output_dir=\"./results\"\n", + "):\n", + " (\n", + " mol_list,\n", + " smi_list,\n", + " pocket_list,\n", + " pocket_coords_list,\n", + " distance_predict_list,\n", + " holo_distance_predict_list,\n", + " holo_coords_list,\n", + " holo_center_coords_list,\n", + " ) = docking_data_pre(reference_file, predict_file)\n", + " iter = ensemble_iterations(\n", + " mol_list,\n", + " smi_list,\n", + " pocket_list,\n", + " pocket_coords_list,\n", + " distance_predict_list,\n", + " holo_distance_predict_list,\n", + " holo_coords_list,\n", + " holo_center_coords_list,\n", + " tta_times=tta_times,\n", + " )\n", + " for i, content in enumerate(iter):\n", + " pocket = content[3]\n", + " output_name = os.path.join(output_dir, \"{}.{}.pkl\".format(pocket, i))\n", + " try:\n", + " os.remove(output_name)\n", + " except:\n", + " pass\n", + " pd.to_pickle(content, output_name)\n", + "\n", + "\n", + "predict_file = os.path.join(results_path, \"content_\" + pdb_id + \".out.pkl\")\n", + "reference_file = os.path.join(data_path, pdb_id + \".lmdb\")\n", + "generate_docking_input(\n", + " predict_file, reference_file, tta_times=10, output_dir=results_path\n", + ")\n", + "for i, smiles in enumerate(smiles_list):\n", + " print(\"Docking {}\".format(smiles))\n", + " input_path = os.path.join(results_path, \"{}.{}.pkl\".format(pdb_id, i))\n", + " ligand_path = os.path.join(results_path, \"docking.{}.{}.sdf\".format(pdb_id, i))\n", + " cmd = \"python ./Uni-Mol/unimol/unimol/utils/coordinate_model.py --input {} --output-ligand {}\".format(\n", + " input_path, ligand_path\n", + " )\n", + " os.system(cmd)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "ilKA-z6h_2oO" + }, + "outputs": [], + "source": [ + "#@title Visualization\n", + "\n", + "#@markdown Note: The first figure shows the result of the Uni-Mol prediction, and the second one shows the difference between the Uni-Mol prediction and the ground-truth ligand in 
the complex.\n", + "\n", + "#@markdown Note: We only visualize the first ligand when multiple SMILES are provided.\n", + "\n", + "import py3Dmol\n", + "import matplotlib.pyplot as plt\n", + "pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')\n", + "ligand_path = os.path.join(results_path, \"docking.{}.{}.sdf\".format(pdb_id,0))\n", + "gt_ligand_path = os.path.join(CASF_PATH,'casf2016',pdb_id+'_ligand.sdf')\n", + "view = py3Dmol.view()\n", + "view.removeAllModels()\n", + "pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')\n", + "view.addModel(open(pdb_path,'r').read(),format='pdb')\n", + "view.setStyle({'cartoon': {'arrows':True, 'tubes':False, 'style':'oval', 'color':'white'}})\n", + "view.addSurface(py3Dmol.VDW,{'opacity':0.5,'color':'white'})\n", + "\n", + "view.addModel(open(ligand_path,'r').read(),format='sdf')\n", + "ref_m = view.getModel()\n", + "ref_m.setStyle({},{'stick':{'colorscheme':'greenCarbon','radius':0.2}})\n", + "\n", + "view.zoomTo(viewer=(100,0))\n", + "view.show()\n", + "\n", + "view.removeAllModels()\n", + "view.addModel(open(ligand_path,'r').read(),format='sdf')\n", + "ref_m = view.getModel()\n", + "ref_m.setStyle({},{'stick':{'colorscheme':'greenCarbon','radius':0.2}})\n", + "\n", + "view.addModel(open(gt_ligand_path,'r').read(),format='sdf')\n", + "ref_m = view.getModel()\n", + "ref_m.setStyle({},{'stick':{'colorscheme':'redCarbon','radius':0.2}})\n", + "\n", + "view.zoomTo(viewer=(100,0))\n", + "view.show()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "BetxYhrqB1SD" + }, + "outputs": [], + "source": [ + "#@title Download the prediction\n", + "#@markdown **The content of zip file**:\n", + "#@markdown 1. PDB formatted structures\n", + "#@markdown 2. Docking ligand SDF files\n", + "#@markdown 3. 
Target ligand SDF files.\n", + "\n", + "from google.colab import files\n", + "file_lists = []\n", + "pdb_path = os.path.join(CASF_PATH, 'casf2016', pdb_id+'_protein.pdb')\n", + "file_lists.append(pdb_path)\n", + "for i in range(len(smiles_list)):\n", + " ligand_path = os.path.join(results_path, \"docking.{}.{}.sdf\".format(pdb_id,i))\n", + " file_lists.append(ligand_path)\n", + "gt_ligand_path = os.path.join(CASF_PATH,'casf2016',pdb_id+'_ligand.sdf')\n", + "file_lists.append(gt_ligand_path)\n", + "\n", + "!zip -j {\"unimol.docking.\"+pdb_id}.zip {\" \".join(file_lists)}\n", + "files.download(f'{\"unimol.docking.\"+pdb_id}.zip')" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "provenance": [] + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3.8.10 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + }, + "vscode": { + "interpreter": { + "hash": "916dbcbb3f70747c44a77c7bcd40155683ae19c65e1c03b4aa3499c5328201f1" + } + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_mol_property_demo.ipynb b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_mol_property_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..c72f4fb2febccc595d3c417d2df02e5e757dd290 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_mol_property_demo.ipynb @@ -0,0 +1,319 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Uni-Mol Molecular Property Prediction" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Licenses**\n", + "\n", + "Copyright (c) DP Technology.\n", + "\n", + "This source code is licensed under the MIT license found in the\n", + "LICENSE file in the root directory of this source tree.\n", + "\n", + "**Citations**\n", + "\n", + "Please cite the following papers if you use this notebook:\n", + "\n", + "- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. 
\"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\"\n", + "ChemRxiv (2022)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Data Preparation (SMILES, label to .lmdb)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pickle\n", + "import lmdb\n", + "import pandas as pd\n", + "import numpy as np\n", + "from rdkit import Chem\n", + "from tqdm import tqdm\n", + "from rdkit.Chem import AllChem\n", + "from rdkit import RDLogger\n", + "RDLogger.DisableLog('rdApp.*') \n", + "import warnings\n", + "warnings.filterwarnings(action='ignore')\n", + "from multiprocessing import Pool\n", + "\n", + "\n", + "def smi2_2Dcoords(smi):\n", + " mol = Chem.MolFromSmiles(smi)\n", + " mol = AllChem.AddHs(mol)\n", + " AllChem.Compute2DCoords(mol)\n", + " coordinates = mol.GetConformer().GetPositions().astype(np.float32)\n", + " len(mol.GetAtoms()) == len(coordinates), \"2D coordinates shape is not align with {}\".format(smi)\n", + " return coordinates\n", + "\n", + "\n", + "def smi2_3Dcoords(smi,cnt):\n", + " mol = Chem.MolFromSmiles(smi)\n", + " mol = AllChem.AddHs(mol)\n", + " coordinate_list=[]\n", + " for seed in range(cnt):\n", + " try:\n", + " res = AllChem.EmbedMolecule(mol, randomSeed=seed) # will random generate conformer with seed equal to -1. else fixed random seed.\n", + " if res == 0:\n", + " try:\n", + " AllChem.MMFFOptimizeMolecule(mol) # some conformer can not use MMFF optimize\n", + " coordinates = mol.GetConformer().GetPositions()\n", + " except:\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", + " \n", + " elif res == -1:\n", + " mol_tmp = Chem.MolFromSmiles(smi)\n", + " AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)\n", + " mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)\n", + " try:\n", + " AllChem.MMFFOptimizeMolecule(mol_tmp) # some conformer can not use MMFF optimize\n", + " coordinates = mol_tmp.GetConformer().GetPositions()\n", + " except:\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", + " except:\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", + "\n", + " assert len(mol.GetAtoms()) == len(coordinates), \"3D coordinates shape is not align with {}\".format(smi)\n", + " coordinate_list.append(coordinates.astype(np.float32))\n", + " return coordinate_list\n", + "\n", + "\n", + "def inner_smi2coords(content):\n", + " smi = content[0]\n", + " target = content[1:]\n", + " cnt = 10 # conformer num,all==11, 10 3d + 1 2d\n", + "\n", + " mol = Chem.MolFromSmiles(smi)\n", + " if len(mol.GetAtoms()) > 400:\n", + " coordinate_list = [smi2_2Dcoords(smi)] * (cnt+1)\n", + " print(\"atom num >400,use 2D coords\",smi)\n", + " else:\n", + " coordinate_list = smi2_3Dcoords(smi,cnt)\n", + " coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32))\n", + " mol = AllChem.AddHs(mol)\n", + " atoms = [atom.GetSymbol() for atom in mol.GetAtoms()] # after add H \n", + " return pickle.dumps({'atoms': atoms, \n", + " 'coordinates': coordinate_list, \n", + " 'mol':mol,'smi': smi, 'target': target}, protocol=-1)\n", + "\n", + "\n", + "def smi2coords(content):\n", + " try:\n", + " return inner_smi2coords(content)\n", + " except:\n", + " print(\"failed smiles: {}\".format(content[0]))\n", + " return None\n", + "\n", 
+ "\n", + "def write_lmdb(inpath='./', outpath='./', nthreads=16):\n", + "\n", + " df = pd.read_csv(os.path.join(inpath))\n", + " sz = len(df)\n", + " train, valid, test = df[:int(sz*0.8)], df[int(sz*0.8):int(sz*0.9)], df[int(sz*0.9):]\n", + " for name, content_list in [('train.lmdb', zip(*[train[c].values.tolist() for c in train])),\n", + " ('valid.lmdb', zip(*[valid[c].values.tolist() for c in valid])),\n", + " ('test.lmdb', zip(*[test[c].values.tolist() for c in test]))]:\n", + " os.makedirs(outpath, exist_ok=True)\n", + " output_name = os.path.join(outpath, name)\n", + " try:\n", + " os.remove(output_name)\n", + " except:\n", + " pass\n", + " env_new = lmdb.open(\n", + " output_name,\n", + " subdir=False,\n", + " readonly=False,\n", + " lock=False,\n", + " readahead=False,\n", + " meminit=False,\n", + " max_readers=1,\n", + " map_size=int(100e9),\n", + " )\n", + " txn_write = env_new.begin(write=True)\n", + " with Pool(nthreads) as pool:\n", + " i = 0\n", + " for inner_output in tqdm(pool.imap(smi2coords, content_list)):\n", + " if inner_output is not None:\n", + " txn_write.put(f'{i}'.encode(\"ascii\"), inner_output)\n", + " i += 1\n", + " print('{} process {} lines'.format(name, i))\n", + " txn_write.commit()\n", + " env_new.close()\n", + "\n", + "write_lmdb(inpath='mol_property_demo.csv', outpath='./demo', nthreads=8)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Finetuning (based on pretraining)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_path='./' # replace to your data path\n", + "save_dir='./save_demo' # replace to your save path\n", + "MASTER_PORT=10086\n", + "n_gpu=1\n", + "dict_name='dict.txt'\n", + "weight_path='./weights/mol_pre_no_h_220816.pt' # replace to your ckpt path\n", + "task_name='demo' # data folder name\n", + "task_num=2\n", + "loss_func='finetune_cross_entropy'\n", + "lr=1e-4\n", + "batch_size=32\n", + "epoch=5\n", + "dropout=0.1\n", + "warmup=0.06\n", + "local_batch_size=32\n", + "only_polar=0 # -1 all h; 0 no h\n", + "conf_size=11\n", + "seed=0\n", + "metric=\"valid_agg_auc\"\n", + "update_freq=batch_size / local_batch_size\n", + "\n", + "!cp ../example_data/molecule/$dict_name $data_path\n", + "!export NCCL_ASYNC_ERROR_HANDLING=1\n", + "!export OMP_NUM_THREADS=1\n", + "!python -m torch.distributed.launch --nproc_per_node=$n_gpu --master_port=$MASTER_PORT $(which unicore-train) $data_path --task-name $task_name --user-dir ../unimol --train-subset train --valid-subset valid \\\n", + " --conf-size $conf_size \\\n", + " --num-workers 8 --ddp-backend=c10d \\\n", + " --dict-name $dict_name \\\n", + " --task mol_finetune --loss $loss_func --arch unimol_base \\\n", + " --classification-head-name $task_name --num-classes $task_num \\\n", + " --optimizer adam --adam-betas '(0.9, 0.99)' --adam-eps 1e-6 --clip-norm 1.0 \\\n", + " --lr-scheduler polynomial_decay --lr $lr --warmup-ratio $warmup --max-epoch $epoch --batch-size $local_batch_size --pooler-dropout $dropout\\\n", + " --update-freq $update_freq --seed $seed \\\n", + " --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n", + " --log-interval 100 --log-format simple \\\n", + " --validate-interval 1 --keep-last-epochs 10 \\\n", + " --finetune-from-model $weight_path \\\n", + " --best-checkpoint-metric $metric --patience 20 \\\n", + " --save-dir $save_dir --only-polar $only_polar \\\n", + " --maximize-best-checkpoint-metric\n", + "# --maximize-best-checkpoint-metric, for classification task" + ] + }, + { 
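+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Optional sanity check (a minimal sketch, not required by the demo): list the checkpoints\n",
+    "# that unicore-train wrote to save_dir. Assumes the finetuning cell above has been run; the\n",
+    "# inference cell below expects './save_demo/checkpoint_best.pt', selected via --best-checkpoint-metric.\n",
+    "import glob\n",
+    "\n",
+    "print(sorted(glob.glob('./save_demo/checkpoint*.pt')))"
+   ]
+  },
+  {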
+ "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inference" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_path='./' # replace to your data path\n", + "results_path='./infer_demo' # replace to your results path\n", + "weight_path='./save_demo/checkpoint_best.pt' # replace to your ckpt path\n", + "batch_size=32\n", + "task_name='demo' # data folder name \n", + "task_num=2\n", + "loss_func='finetune_cross_entropy'\n", + "dict_name='dict.txt'\n", + "conf_size=11\n", + "only_polar=0\n", + "\n", + "!cp ../example_data/molecule/$dict_name $data_path\n", + "!CUDA_VISIBLE_DEVICES=\"0\" python ../unimol/infer.py --user-dir ../unimol $data_path --task-name $task_name --valid-subset test \\\n", + " --results-path $results_path \\\n", + " --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n", + " --task mol_finetune --loss $loss_func --arch unimol_base \\\n", + " --classification-head-name $task_name --num-classes $task_num \\\n", + " --dict-name $dict_name --conf-size $conf_size \\\n", + " --only-polar $only_polar \\\n", + " --path $weight_path \\\n", + " --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n", + " --log-interval 50 --log-format simple " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Read inference results (.pkl to .csv)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def get_csv_results(predict_path, csv_path):\n", + " predict = pd.read_pickle(predict_path)\n", + " smi_list, predict_list = [], []\n", + " for batch in predict:\n", + " sz = batch[\"bsz\"]\n", + " for i in range(sz):\n", + " smi_list.append(batch[\"smi_name\"][i])\n", + " predict_list.append(batch[\"prob\"][i][1].cpu().tolist())\n", + " predict_df = pd.DataFrame({\"SMILES\": smi_list, \"predict_prob\": predict_list})\n", + " predict_df = predict_df.groupby(\"SMILES\")[\"predict_prob\"].mean().reset_index()\n", + " predict_df.to_csv(csv_path,index=False)\n", + " return predict_df\n", + "\n", + "predict_path='./infer_demo/save_demo_test.out.pkl' # replace to your results path\n", + "csv_path='./infer_demo/demo_results.csv'\n", + "predict_df = get_csv_results(predict_path, csv_path)\n", + "predict_df.info(), predict_df.head()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.8.13 ('base')", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "d4d1e4263499bec80672ea0156c357c1ee493ec2b1c70f0acce89fc37c4a6abe" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_mol_repr_demo.ipynb b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_mol_repr_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..ed08ba05b1b61d72778e1c7be0548e39493ac3ab --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_mol_repr_demo.ipynb @@ -0,0 +1,311 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "4f0f701f-c552-4ca1-8188-2cdfc1362f6b", + "metadata": {}, + "source": [ + "# Uni-Mol Molecular Representation" + ] + }, + { + "cell_type": "markdown", + "id": 
"d3449ed8-2a57-4e62-9163-e32baf66e828", + "metadata": {}, + "source": [ + "**Licenses**\n", + "\n", + "Copyright (c) DP Technology.\n", + "\n", + "This source code is licensed under the MIT license found in the\n", + "LICENSE file in the root directory of this source tree.\n", + "\n", + "**Citations**\n", + "\n", + "Please cite the following papers if you use this notebook:\n", + "\n", + "- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\"\n", + "ChemRxiv (2022)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6d51f850-76cd-4801-bf2e-a4c53221d586", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import numpy as np\n", + "import pandas as pd\n", + "import lmdb\n", + "from rdkit import Chem\n", + "from rdkit.Chem import AllChem\n", + "from tqdm import tqdm\n", + "import pickle\n", + "import glob\n", + "from multiprocessing import Pool\n", + "from collections import defaultdict" + ] + }, + { + "cell_type": "markdown", + "id": "89c70ab0-da59-459d-bf1c-ac307e9e7ae5", + "metadata": {}, + "source": [ + "### Your SMILES list" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "bfa0ce2a-b7aa-4cae-81ba-27b91c0591e4", + "metadata": {}, + "outputs": [], + "source": [ + "smi_list = [\n", + "'CC1=C(C(=O)OC2CCCC2)[C@H](c2ccccc2OC(C)C)C2=C(O)CC(C)(C)CC2=[N+]1',\n", + "'COc1cccc(-c2nc(C(=O)NC[C@H]3CCCO3)cc3c2[nH]c2ccccc23)c1',\n", + "'O=C1c2ccccc2C(=O)c2c1ccc(C(=O)n1nc3c4c(cccc41)C(=O)c1ccccc1-3)c2[N+](=O)[O-]',\n", + "'COc1cc(/C=N/c2nonc2NC(C)=O)ccc1OC(C)C',\n", + "'CCC[C@@H]1CN(Cc2ccc3nsnc3c2)C[C@H]1NS(C)(=O)=O',\n", + "'CCc1nnc(N/C(O)=C/CCOc2ccc(OC)cc2)s1',\n", + "'CC(C)(C)SCCN/C=C1\\C(=O)NC(=O)N(c2ccc(Br)cc2)C1=O',\n", + "'CC(C)(C)c1nc(COc2ccc3c(c2)CCn2c-3cc(OCC3COCCO3)nc2=O)no1',\n", + "'N#CCCNS(=O)(=O)c1ccc(/C(O)=N/c2ccccc2Oc2ccccc2Cl)cc1',\n", + "'O=C(Nc1ncc(Cl)s1)c1cccc(S(=O)(=O)Nc2ccc(Br)cc2)c1',\n", + "]" + ] + }, + { + "cell_type": "markdown", + "id": "b109d84a-8d59-445b-9997-d1383ee24079", + "metadata": {}, + "source": [ + "### Generate conformations from SMILES and save to .lmdb" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ea582d7d-8851-4d46-880e-54867737b232", + "metadata": {}, + "outputs": [], + "source": [ + "def smi2_2Dcoords(smi):\n", + " mol = Chem.MolFromSmiles(smi)\n", + " mol = AllChem.AddHs(mol)\n", + " AllChem.Compute2DCoords(mol)\n", + " coordinates = mol.GetConformer().GetPositions().astype(np.float32)\n", + " len(mol.GetAtoms()) == len(coordinates), \"2D coordinates shape is not align with {}\".format(smi)\n", + " return coordinates\n", + "\n", + "\n", + "def smi2_3Dcoords(smi,cnt):\n", + " mol = Chem.MolFromSmiles(smi)\n", + " mol = AllChem.AddHs(mol)\n", + " coordinate_list=[]\n", + " for seed in range(cnt):\n", + " try:\n", + " res = AllChem.EmbedMolecule(mol, randomSeed=seed) # will random generate conformer with seed equal to -1. 
else fixed random seed.\n", + " if res == 0:\n", + " try:\n", + " AllChem.MMFFOptimizeMolecule(mol) # some conformer can not use MMFF optimize\n", + " coordinates = mol.GetConformer().GetPositions()\n", + " except:\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", + " \n", + " elif res == -1:\n", + " mol_tmp = Chem.MolFromSmiles(smi)\n", + " AllChem.EmbedMolecule(mol_tmp, maxAttempts=5000, randomSeed=seed)\n", + " mol_tmp = AllChem.AddHs(mol_tmp, addCoords=True)\n", + " try:\n", + " AllChem.MMFFOptimizeMolecule(mol_tmp) # some conformer can not use MMFF optimize\n", + " coordinates = mol_tmp.GetConformer().GetPositions()\n", + " except:\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", + " except:\n", + " print(\"Failed to generate 3D, replace with 2D\")\n", + " coordinates = smi2_2Dcoords(smi) \n", + "\n", + " assert len(mol.GetAtoms()) == len(coordinates), \"3D coordinates shape is not align with {}\".format(smi)\n", + " coordinate_list.append(coordinates.astype(np.float32))\n", + " return coordinate_list\n", + "\n", + "\n", + "def inner_smi2coords(content):\n", + " smi = content\n", + " cnt = 10 # conformer num,all==11, 10 3d + 1 2d\n", + "\n", + " mol = Chem.MolFromSmiles(smi)\n", + " if len(mol.GetAtoms()) > 400:\n", + " coordinate_list = [smi2_2Dcoords(smi)] * (cnt+1)\n", + " print(\"atom num >400,use 2D coords\",smi)\n", + " else:\n", + " coordinate_list = smi2_3Dcoords(smi,cnt)\n", + " # add 2d conf\n", + " coordinate_list.append(smi2_2Dcoords(smi).astype(np.float32))\n", + " mol = AllChem.AddHs(mol)\n", + " atoms = [atom.GetSymbol() for atom in mol.GetAtoms()] # after add H \n", + " return pickle.dumps({'atoms': atoms, 'coordinates': coordinate_list, 'smi': smi }, protocol=-1)\n", + "\n", + "\n", + "def smi2coords(content):\n", + " try:\n", + " return inner_smi2coords(content)\n", + " except:\n", + " print(\"failed smiles: {}\".format(content[0]))\n", + " return None\n", + "\n", + "\n", + "def write_lmdb(smiles_list, job_name, seed=42, outpath='./results', nthreads=8):\n", + " os.makedirs(outpath, exist_ok=True)\n", + " output_name = os.path.join(outpath,'{}.lmdb'.format(job_name))\n", + " try:\n", + " os.remove(output_name)\n", + " except:\n", + " pass\n", + " env_new = lmdb.open(\n", + " output_name,\n", + " subdir=False,\n", + " readonly=False,\n", + " lock=False,\n", + " readahead=False,\n", + " meminit=False,\n", + " max_readers=1,\n", + " map_size=int(100e9),\n", + " )\n", + " txn_write = env_new.begin(write=True)\n", + " with Pool(nthreads) as pool:\n", + " i = 0\n", + " for inner_output in tqdm(pool.imap(smi2coords, smiles_list)):\n", + " if inner_output is not None:\n", + " txn_write.put(f'{i}'.encode(\"ascii\"), inner_output)\n", + " i += 1\n", + " print('{} process {} lines'.format(job_name, i))\n", + " txn_write.commit()\n", + " env_new.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dad25a1a-f93e-4fdf-b389-2a3fe61a40ee", + "metadata": {}, + "outputs": [], + "source": [ + "seed = 42\n", + "job_name = 'get_mol_repr' # replace to your custom name\n", + "data_path = './results' # replace to your data path\n", + "weight_path='../ckp/mol_pre_no_h_220816.pt' # replace to your ckpt path\n", + "only_polar=0 # no h\n", + "dict_name='dict.txt'\n", + "batch_size=16\n", + "conf_size=11 # default 10 3d + 1 2d\n", + "results_path=data_path # replace to your save path\n", + "write_lmdb(smi_list, job_name=job_name, seed=seed, outpath=data_path)" 
+ ] + }, + { + "cell_type": "markdown", + "id": "12284210-7f86-4062-b291-7c077ef6f83a", + "metadata": {}, + "source": [ + "### Infer from ckpt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9fb2391b-81b0-4b11-95ea-3b7855db9bc6", + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Currently, the inference is only supported to run on a single GPU. You can add CUDA_VISIBLE_DEVICES=\"0\" before the command.\n", + "!cp ../example_data/molecule/$dict_name $data_path\n", + "!CUDA_VISIBLE_DEVICES=\"0\" python ../unimol/infer.py --user-dir ../unimol $data_path --valid-subset $job_name \\\n", + " --results-path $results_path \\\n", + " --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n", + " --task unimol --loss unimol_infer --arch unimol_base \\\n", + " --path $weight_path \\\n", + " --only-polar $only_polar --dict-name $dict_name --conf-size $conf_size \\\n", + " --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer" + ] + }, + { + "cell_type": "markdown", + "id": "d8421258-eca6-4801-aadd-fc67fd928cb1", + "metadata": {}, + "source": [ + "### Read .pkl and save results to .csv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c456f31e-94fc-4593-97c9-1db7182465aa", + "metadata": {}, + "outputs": [], + "source": [ + "def get_csv_results(predict_path, results_path):\n", + " predict = pd.read_pickle(predict_path)\n", + " mol_repr_dict = defaultdict(list)\n", + " atom_repr_dict = defaultdict(list)\n", + " pair_repr_dict = defaultdict(list)\n", + " for batch in predict:\n", + " sz = batch[\"bsz\"]\n", + " for i in range(sz):\n", + " smi = batch[\"data_name\"][i]\n", + " mol_repr_dict[smi].append(batch[\"mol_repr_cls\"][i])\n", + " atom_repr_dict[smi].append(batch[\"atom_repr\"][i])\n", + " pair_repr_dict[smi].append(batch[\"pair_repr\"][i])\n", + " # get mean repr for each molecule with multiple conf\n", + " smi_list, avg_mol_repr_list, avg_atom_repr_list, avg_pair_repr_list = [], [], [], []\n", + " for smi in mol_repr_dict.keys():\n", + " smi_list.append(smi)\n", + " avg_mol_repr_list.append(np.mean(mol_repr_dict[smi], axis=0))\n", + " avg_atom_repr_list.append(np.mean(atom_repr_dict[smi], axis=0))\n", + " avg_pair_repr_list.append(np.mean(pair_repr_dict[smi], axis=0))\n", + " predict_df = pd.DataFrame({\n", + " \"SMILES\": smi_list,\n", + " \"mol_repr\": avg_mol_repr_list,\n", + " \"atom_repr\": avg_atom_repr_list,\n", + " \"pair_repr\": avg_pair_repr_list\n", + " })\n", + " print(predict_df.head(1),predict_df.info())\n", + " predict_df.to_csv(results_path+'/mol_repr.csv',index=False)\n", + "\n", + "pkl_path = glob.glob(f'{results_path}/*_{job_name}.out.pkl')[0]\n", + "get_csv_results(pkl_path, results_path)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.10.6 64-bit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.6" + }, + "vscode": { + "interpreter": { + "hash": "e7370f93d1d0cde622a1f8e1c04877d8463912d04d973331ad4851f04de6915a" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_pocket_repr_demo.ipynb b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_pocket_repr_demo.ipynb new file mode 100644 index 
0000000000000000000000000000000000000000..8a7360b0e2fb6beb0264a62a183049df891d20fe --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_pocket_repr_demo.ipynb @@ -0,0 +1,261 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Uni-Mol Pocket Representation" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "**Licenses**\n", + "\n", + "Copyright (c) DP Technology.\n", + "\n", + "This source code is licensed under the MIT license found in the\n", + "LICENSE file in the root directory of this source tree.\n", + "\n", + "**Citations**\n", + "\n", + "Please cite the following papers if you use this notebook:\n", + "\n", + "- Gengmo Zhou, Zhifeng Gao, Qiankun Ding, Hang Zheng, Hongteng Xu, Zhewei Wei, Linfeng Zhang, Guolin Ke. \"[Uni-Mol: A Universal 3D Molecular Representation Learning Framework.](https://chemrxiv.org/engage/chemrxiv/article-details/6318b529bada388485bc8361)\"\n", + "ChemRxiv (2022)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Download pretrained pocket weights, and CASF-2016 data." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%bash\n", + "pocket_data_url='https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/CASF-2016.tar.gz'\n", + "pocket_weight_url='https://github.com/deepmodeling/Uni-Mol/releases/download/v0.1/pocket_pre_220816.pt'\n", + "wget -q ${pocket_data_url}\n", + "tar -xzf \"CASF-2016.tar.gz\"\n", + "wget -q ${pocket_weight_url}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Read pocket information from CASF-2016 and save it to a .lmdb file" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pandas as pd\n", + "import lmdb\n", + "from biopandas.pdb import PandasPdb\n", + "from tqdm import tqdm\n", + "import pickle\n", + "import re\n", + "import json\n", + "import glob" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "CASF_PATH = \"CASF-2016\"\n", + "main_atoms = [\"N\", \"CA\", \"C\", \"O\", \"H\"]\n", + "\n", + "def load_from_CASF(pdb_id):\n", + " try:\n", + " pdb_path = os.path.join(CASF_PATH, \"casf2016\", pdb_id + \"_protein.pdb\")\n", + " pmol = PandasPdb().read_pdb(pdb_path)\n", + " pocket_residues = json.load(\n", + " open(os.path.join(CASF_PATH, \"casf2016.pocket.json\"))\n", + " )[pdb_id]\n", + " return pmol, pocket_residues\n", + " except:\n", + " print(\"Currently not support parsing pdb and pocket info from local files.\")\n", + "\n", + "def normalize_atoms(atom):\n", + " return re.sub(\"\\d+\", \"\", atom)\n", + "\n", + "def parser(pdb_id):\n", + " pmol, pocket_residues = load_from_CASF(pdb_id)\n", + " pname = pdb_id\n", + " pro_atom = pmol.df[\"ATOM\"]\n", + " pro_hetatm = pmol.df[\"HETATM\"]\n", + "\n", + " pro_atom[\"ID\"] = pro_atom[\"chain_id\"].astype(str) + pro_atom[\n", + " \"residue_number\"\n", + " ].astype(str)\n", + " pro_hetatm[\"ID\"] = pro_hetatm[\"chain_id\"].astype(str) + pro_hetatm[\n", + " \"residue_number\"\n", + " ].astype(str)\n", + "\n", + " pocket = pd.concat(\n", + " [\n", + " pro_atom[pro_atom[\"ID\"].isin(pocket_residues)],\n", + " pro_hetatm[pro_hetatm[\"ID\"].isin(pocket_residues)],\n", + " ],\n", + " axis=0,\n", + " ignore_index=True,\n", + " )\n", + "\n", + " pocket[\"normalize_atom\"] = pocket[\"atom_name\"].map(normalize_atoms)\n", + " pocket 
= pocket[pocket[\"normalize_atom\"] != \"\"]\n", + " patoms = pocket[\"atom_name\"].apply(normalize_atoms).values.tolist()\n", + " pcoords = [pocket[[\"x_coord\", \"y_coord\", \"z_coord\"]].values]\n", + " side = [0 if a in main_atoms else 1 for a in patoms]\n", + " residues = (\n", + " pocket[\"chain_id\"].astype(str) + pocket[\"residue_number\"].astype(str)\n", + " ).values.tolist()\n", + "\n", + " return pickle.dumps(\n", + " {\n", + " \"atoms\": patoms,\n", + " \"coordinates\": pcoords,\n", + " \"side\": side,\n", + " \"residue\": residues,\n", + " \"pdbid\": pname,\n", + " },\n", + " protocol=-1,\n", + " )\n", + "\n", + "def write_lmdb(pdb_id_list, job_name, outpath=\"./results\"):\n", + " os.makedirs(outpath, exist_ok=True)\n", + " outputfilename = os.path.join(outpath, job_name + \".lmdb\")\n", + " try:\n", + " os.remove(outputfilename)\n", + " except:\n", + " pass\n", + " env_new = lmdb.open(\n", + " outputfilename,\n", + " subdir=False,\n", + " readonly=False,\n", + " lock=False,\n", + " readahead=False,\n", + " meminit=False,\n", + " max_readers=1,\n", + " map_size=int(10e9),\n", + " )\n", + " txn_write = env_new.begin(write=True)\n", + " for i, pdb_id in tqdm(enumerate(pdb_id_list)):\n", + " inner_output = parser(pdb_id)\n", + " txn_write.put(f\"{i}\".encode(\"ascii\"), inner_output)\n", + " txn_write.commit()\n", + " env_new.close()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "job_name = 'get_pocket_repr' # replace to your custom name\n", + "data_path = './results' # replace to your data path\n", + "weight_path='pocket_pre_220816.pt' # replace to your ckpt path\n", + "only_polar=0 # no h\n", + "dict_name='dict_coarse.txt'\n", + "batch_size=16\n", + "results_path=data_path # replace to your save path\n", + "casf_collect = os.listdir(os.path.join(CASF_PATH, \"casf2016\"))\n", + "casf_collect = list(set([item[:4] for item in casf_collect]))\n", + "casf_collect.remove('3qgy')\n", + "write_lmdb(casf_collect, job_name=job_name, outpath=data_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Infer from pretrained pocket ckpt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# NOTE: Currently, the inference is only supported to run on a single GPU. 
You can add CUDA_VISIBLE_DEVICES=\"0\" before the command.\n", + "!cp ../example_data/pocket/$dict_name $data_path\n", + "!CUDA_VISIBLE_DEVICES=\"0\" python ../unimol/infer.py --user-dir ../unimol $data_path --valid-subset $job_name \\\n", + " --results-path $results_path \\\n", + " --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n", + " --task unimol_pocket --loss unimol_infer --arch unimol_base \\\n", + " --path $weight_path \\\n", + " --dict-name $dict_name \\\n", + " --log-interval 50 --log-format simple --random-token-prob 0 --leave-unmasked-prob 1.0 --mode infer" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Read .pkl and save results to .csv" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_csv_results(predict_path, results_path):\n", + " predict = pd.read_pickle(predict_path)\n", + " pdb_id_list, mol_repr_list, atom_repr_list, pair_repr_list = [], [], []\n", + " for batch in predict:\n", + " sz = batch[\"bsz\"]\n", + " for i in range(sz):\n", + " pdb_id_list.append(batch[\"data_name\"][i])\n", + " mol_repr_list.append(batch[\"mol_repr_cls\"][i])\n", + " atom_repr_list.append(batch['atom_repr'][i])\n", + " pair_repr_list.append(batch[\"pair_repr\"][i])\n", + " predict_df = pd.DataFrame({\"pdb_id\": pdb_id_list, \"mol_repr\": mol_repr_list, \"atom_repr\": atom_repr_list, \"pair_repr\": pair_repr_list})\n", + " print(predict_df.head(1),predict_df.info())\n", + " predict_df.to_csv(results_path+'/mol_repr.csv',index=False)\n", + "\n", + "pkl_path = glob.glob(f'{results_path}/*_{job_name}.out.pkl')[0]\n", + "get_csv_results(pkl_path, results_path)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_posebuster_demo.ipynb b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_posebuster_demo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..51f5dcdadb62dd389af771412c80e4634b200ab7 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/notebooks/unimol_posebuster_demo.ipynb @@ -0,0 +1,387 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Please note before running:\n", + "Commit: b962451 (b962451a019e15363bd34b3af9d3a3cd02330947)\n", + "\n", + "Workspace path: Uni-Mol\n", + "\n", + "Notebook path: Uni-Mol/unimol_posebuster_demo.ipynb" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Import modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import pickle\n", + "import numpy as np\n", + "import pandas as pd\n", + "from rdkit import Chem, RDLogger\n", + "from rdkit.Chem import AllChem\n", + "from tqdm import tqdm\n", + "RDLogger.DisableLog('rdApp.*') \n", + "import warnings\n", + "warnings.filterwarnings(action='ignore')\n", + "from multiprocessing import Pool\n", + "import copy\n", + "import lmdb\n", + "from biopandas.pdb import PandasPdb\n", + "from sklearn.cluster import KMeans\n", + "from rdkit.Chem.rdMolAlign import 
AlignMolConformers" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preprocess func for generating the LMDB file" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# allowed atom types \n", + "main_atoms = ['N', 'CA', 'C', 'O', 'H']\n", + "allow_pocket_atoms = ['C', 'H', 'N', 'O', 'S']\n", + "\n", + "def cal_configs(coords):\n", + " \"\"\"Calculate pocket configs\"\"\"\n", + "\n", + " centerx,centery,centerz = list((np.max(coords,axis=0)+np.min(coords,axis=0))/2)\n", + " sizex,sizey,sizez = list(np.max(coords,axis=0)-np.mean(coords,axis=0))\n", + " config = {'cx':centerx,'cy':centery,'cz':centerz,\n", + " 'sx':sizex,'sy':sizey,'sz':sizez}\n", + " \n", + " return config,centerx,centery,centerz,sizex,sizey,sizez\n", + "\n", + "\n", + "def filter_pocketatoms(atom):\n", + " if atom[:2] in ['Cd','Cs', 'Cn', 'Ce', 'Cm', 'Cf', 'Cl', 'Ca', \\\n", + " 'Cr', 'Co', 'Cu', 'Nh', 'Nd', 'Np', 'No', 'Ne', 'Na',\\\n", + " 'Ni','Nb', 'Os', 'Og', 'Hf', 'Hg', 'Hs', 'Ho', 'He',\\\n", + " 'Sr', 'Sn', 'Sb', 'Sg', 'Sm', 'Si', 'Sc', 'Se']:\n", + " return None\n", + " if atom[0] >= '0' and atom[0] <= '9':\n", + " return filter_pocketatoms(atom[1:])\n", + " if atom[0] in ['Z','M','P','D','F','K','I','B']:\n", + " return None\n", + " if atom[0] in allow_pocket_atoms:\n", + " return atom\n", + " return atom\n", + "\n", + "\n", + "def single_conf_gen(tgt_mol, num_confs=1000, seed=42, removeHs=True):\n", + " mol = copy.deepcopy(tgt_mol)\n", + " mol = Chem.AddHs(mol)\n", + " allconformers = AllChem.EmbedMultipleConfs(mol, numConfs=num_confs, randomSeed=seed, clearConfs=True)\n", + " sz = len(allconformers)\n", + " for i in range(sz):\n", + " try:\n", + " AllChem.MMFFOptimizeMolecule(mol, confId=i)\n", + " except:\n", + " continue\n", + " if removeHs:\n", + " mol = Chem.RemoveHs(mol)\n", + " return mol\n", + "\n", + "\n", + "def clustering_coords(mol, M=1000, N=100, seed=42, removeHs=True, method='bonds'):\n", + " rdkit_coords_list = []\n", + " if method == 'rdkit_MMFF':\n", + " rdkit_mol = single_conf_gen(mol, num_confs=M, seed=seed, removeHs=removeHs)\n", + " else:\n", + " print('no conformer generation methods:{}'.format(method))\n", + " raise \n", + " noHsIds = [rdkit_mol.GetAtoms()[i].GetIdx() for i in range(len(rdkit_mol.GetAtoms())) if rdkit_mol.GetAtoms()[i].GetAtomicNum()!=1]\n", + " ### exclude hydrogens for aligning\n", + " AlignMolConformers(rdkit_mol, atomIds=noHsIds)\n", + " sz = len(rdkit_mol.GetConformers())\n", + " for i in range(sz):\n", + " _coords = rdkit_mol.GetConformers()[i].GetPositions().astype(np.float32)\n", + " rdkit_coords_list.append(_coords)\n", + "\n", + " ### exclude hydrogens for clustering\n", + " rdkit_coords_flatten = np.array(rdkit_coords_list)[:, noHsIds].reshape(sz,-1)\n", + " ids = KMeans(n_clusters=N, random_state=seed).fit_predict(rdkit_coords_flatten).tolist()\n", + " coords_list = [rdkit_coords_list[ids.index(i)] for i in range(N)]\n", + " return coords_list\n", + "\n", + "\n", + "def extract_pose_posebuster(content):\n", + "\n", + " pdbid, ligid, protein_path, ligand_path, index = content\n", + "\n", + " def read_pdb(path, pdbid):\n", + " #### protein preparation\n", + " pfile = os.path.join(path, pdbid+'.pdb')\n", + " pmol = PandasPdb().read_pdb(pfile)\n", + " \n", + " return pmol\n", + "\n", + " ### totally posebuster data\n", + " def read_mol(path, pdbid, ligid):\n", + " lsdf = os.path.join(path, 
f'{pdbid}_{ligid}.sdf')\n", + " supp = Chem.SDMolSupplier(lsdf)\n", + " mols = [mol for mol in supp if mol]\n", + " if len(mols) == 0:\n", + " print(lsdf)\n", + " mol = mols[0]\n", + " return mol\n", + "\n", + " # influence pocket size\n", + " dist_thres=6\n", + " if pdbid == 'index' or pdbid == 'readme':\n", + " return None\n", + "\n", + " pmol = read_pdb(protein_path, pdbid)\n", + " pname = pdbid\n", + " mol = read_mol(ligand_path, pdbid, ligid)\n", + " mol = Chem.RemoveHs(mol)\n", + " lcoords = mol.GetConformer().GetPositions().astype(np.float32)\n", + " \n", + " pdf = pmol.df['ATOM']\n", + " filter_std = []\n", + " for lcoord in lcoords:\n", + " pdf['dist'] = pmol.distance(xyz=list(lcoord), records=('ATOM'))\n", + " df = pdf[(pdf.dist <= dist_thres) & (pdf.element_symbol != 'H')][['chain_id', 'residue_number']]\n", + " filter_std += list(zip(df.chain_id.tolist(), df.residue_number.tolist()))\n", + "\n", + " filter_std = set(filter_std)\n", + " patoms, pcoords, residues = [], np.empty((0,3)), []\n", + " for id,res in filter_std:\n", + " df = pdf[(pdf.chain_id == id) & (pdf.residue_number == res)]\n", + " patoms += df['atom_name'].tolist()\n", + " pcoords = np.concatenate((pcoords, df[['x_coord','y_coord','z_coord']].to_numpy()), axis=0)\n", + " residues += [str(id)+str(res)]*len(df)\n", + "\n", + " if len(pcoords)==0:\n", + " print('empty pocket:', pdbid)\n", + " return None\n", + " config,centerx,centery,centerz,sizex,sizey,sizez = cal_configs(pcoords)\n", + "\n", + " # filter unnormal atoms, include metal\n", + " atoms, index, residues_tmp = [], [], []\n", + " for i,a in enumerate(patoms):\n", + " output = filter_pocketatoms(a)\n", + " if output is not None:\n", + " index.append(True)\n", + " atoms.append(output)\n", + " residues_tmp.append(residues[i])\n", + " else:\n", + " index.append(False)\n", + " coordinates = pcoords[index].astype(np.float32)\n", + " residues = residues_tmp\n", + "\n", + " assert len(atoms) == len(residues)\n", + " assert len(atoms) == coordinates.shape[0]\n", + "\n", + " if len(atoms) != coordinates.shape[0]:\n", + " print(pname)\n", + " return None\n", + " patoms = atoms\n", + " pcoords = [coordinates]\n", + " side = [0 if a in main_atoms else 1 for a in patoms]\n", + "\n", + " smiles = Chem.MolToSmiles(mol)\n", + " mol = AllChem.AddHs(mol, addCoords=True)\n", + " latoms = [atom.GetSymbol() for atom in mol.GetAtoms()]\n", + " holo_coordinates = [mol.GetConformer().GetPositions().astype(np.float32)]\n", + " holo_mol = mol\n", + " \n", + " M, N = 100, 10\n", + " coordinate_list = clustering_coords(mol, M=M, N=N, seed=42, removeHs=False, method='rdkit_MMFF')\n", + " mol_list = [mol]*N\n", + " ligand = [latoms, coordinate_list, holo_coordinates, smiles, mol_list, holo_mol]\n", + "\n", + " return pname, patoms, pcoords, side, residues, config, ligand\n", + "\n", + "\n", + "def parser(content):\n", + " pname, patoms, pcoords, side, residues, config, ligand = extract_pose_posebuster(content)\n", + " latoms, coordinate_list, holo_coordinates, smiles, mol_list, holo_mol = ligand\n", + " pickle.dumps({})\n", + " return pickle.dumps(\n", + " {\n", + " \"atoms\": latoms,\n", + " \"coordinates\": coordinate_list,\n", + " \"mol_list\": mol_list,\n", + " \"pocket_atoms\": patoms,\n", + " \"pocket_coordinates\": pcoords,\n", + " \"side\": side,\n", + " \"residue\": residues,\n", + " \"config\": config,\n", + " \"holo_coordinates\": holo_coordinates,\n", + " \"holo_mol\": holo_mol,\n", + " \"holo_pocket_coordinates\": pcoords,\n", + " \"smi\": smiles,\n", + " 
'pocket':pname,\n", + " 'scaffold':pname,\n", + " },\n", + " protocol=-1,\n", + " )\n", + "\n", + "\n", + "def write_lmdb(protein_path, ligand_path, outpath, meta_info_file, lmdb_name, nthreads=8):\n", + " os.makedirs(outpath, exist_ok=True)\n", + " df = pd.read_csv(meta_info_file)\n", + " pdb_ids = list(df['pdb_code'].values)\n", + " lig_ids = list(df['lig_code'].values)\n", + " content_list = list(zip(pdb_ids, lig_ids, [protein_path]*len(pdb_ids), [ligand_path]*len(pdb_ids), range(len(pdb_ids))))\n", + " outputfilename = os.path.join(outpath, lmdb_name +'.lmdb')\n", + " try:\n", + " os.remove(outputfilename)\n", + " except:\n", + " pass\n", + " env_new = lmdb.open(\n", + " outputfilename,\n", + " subdir=False,\n", + " readonly=False,\n", + " lock=False,\n", + " readahead=False,\n", + " meminit=False,\n", + " max_readers=1,\n", + " map_size=int(100e9),\n", + " )\n", + " txn_write = env_new.begin(write=True)\n", + " print(\"Start preprocessing data...\")\n", + " print(f'Number of systems: {len(pdb_ids)}')\n", + " with Pool(nthreads) as pool:\n", + " i = 0\n", + " failed_num = 0\n", + " for inner_output in tqdm(pool.imap(parser, content_list)):\n", + " if inner_output is not None:\n", + " txn_write.put(f\"{i}\".encode(\"ascii\"), inner_output)\n", + " i+=1\n", + " elif inner_output is None: \n", + " failed_num += 1\n", + " txn_write.commit()\n", + " env_new.close()\n", + " print(f'Total num: {len(pdb_ids)}, Success: {i}, Failed: {failed_num}')\n", + " print(\"Done!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Generate `lmdb` from `pdb` and `sdf`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "protein_path = 'eval_sets/posebusters/proteins'\n", + "ligand_path = 'eval_sets/posebusters/ligands'\n", + "outpath = 'posebuster_test'\n", + "meta_info_file = 'eval_sets/posebusters/posebuster_set_meta.csv'\n", + "lmdb_name = 'posebuster_428'\n", + "nthreads = 8\n", + "\n", + "write_lmdb(protein_path, ligand_path, outpath, meta_info_file, lmdb_name, nthreads=nthreads)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Infer with public ckp\n", + "The script is the same as it is in the [Readme](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#protein-ligand-binding-pose-prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "data_path=outpath\n", + "results_path=\"./infer_pose\" # replace to your results path\n", + "weight_path=\"./ckp/binding_pose_220908.pt\"\n", + "batch_size=8\n", + "dist_threshold=8.0\n", + "recycling=3\n", + "valid_subset=lmdb_name\n", + "mol_dict_name='dict_mol.txt'\n", + "pocket_dict_name='dict_pkt.txt'\n", + "\n", + "!cp ./example_data/molecule/dict.txt $data_path/$mol_dict_name\n", + "!cp ./example_data/pocket/dict_coarse.txt $data_path/$pocket_dict_name\n", + "!python ./unimol/infer.py --user-dir ./unimol $data_path --valid-subset $valid_subset \\\n", + " --results-path $results_path \\\n", + " --num-workers 8 --ddp-backend=c10d --batch-size $batch_size \\\n", + " --task docking_pose --loss docking_pose --arch docking_pose \\\n", + " --path $weight_path \\\n", + " --fp16 --fp16-init-scale 4 --fp16-scale-window 256 \\\n", + " --dist-threshold $dist_threshold --recycling $recycling \\\n", + " --log-interval 50 --log-format simple" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Docking and cal metrics:\n", + "The script is the same as it 
is in the [Readme](https://github.com/dptech-corp/Uni-Mol/tree/main/unimol#protein-ligand-binding-pose-prediction)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nthreads=8 # Num of threads\n", + "predict_file=f\"{results_path}/ckp_{lmdb_name}.out.pkl\" # Your inference file dir\n", + "reference_file=f\"{outpath}/{lmdb_name}.lmdb\" # Your reference file dir\n", + "output_path=\"./unimol_repro_posebuster428\" # Docking results path\n", + "\n", + "!python ./unimol/utils/docking.py --nthreads $nthreads --predict-file $predict_file --reference-file $reference_file --output-path $output_path" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/MindChemistry/applications/Uni-Mol/unimol/requirements.txt b/MindChemistry/applications/Uni-Mol/unimol/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..f1cd33521b5767997d25a8f36998f445ff4ba79b --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/requirements.txt @@ -0,0 +1 @@ +git+git://github.com/dptech-corp/Uni-Core.git@stable#egg=Uni-Core diff --git a/MindChemistry/applications/Uni-Mol/unimol/setup.py b/MindChemistry/applications/Uni-Mol/unimol/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..11e12d53e2c4709bb2a6fb4b9de209897100d58d --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/setup.py @@ -0,0 +1,33 @@ +"""Install script for setuptools.""" + +from setuptools import find_packages +from setuptools import setup + +setup( + name="unimol", + version="1.0.0", + description="", + author="DP Technology", + author_email="unimol@dp.tech", + license="The MIT License", + url="https://github.com/deepmodeling/Uni-Mol", + packages=find_packages( + exclude=["scripts", "tests", "example_data", "docker", "figure"] + ), + install_requires=[ + "numpy", + "pandas", + "scikit-learn-extra", + ], + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], +) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/__init__.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..58b8eabdb8d9d0a1a02fadb13c31ccc953217c10 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/__init__.py @@ -0,0 +1,6 @@ +# import importlib +# import unimol.tasks +# import unimol.data +# import unimol.models +# import unimol.losses +# import unimol.utils diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/__init__.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..848d97825658cf04bb0ce89c295b4a0abc104149 --- /dev/null +++ 
b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/__init__.py @@ -0,0 +1,42 @@ +from .key_dataset import KeyDataset +from .normalize_dataset import ( + NormalizeDataset, + NormalizeDockingPoseDataset, +) +from .remove_hydrogen_dataset import ( + RemoveHydrogenDataset, + RemoveHydrogenResiduePocketDataset, + RemoveHydrogenPocketDataset, +) +from .tta_dataset import ( + TTADataset, + TTADockingPoseDataset, +) +from .cropping_dataset import ( + CroppingDataset, + CroppingPocketDataset, + CroppingResiduePocketDataset, + CroppingPocketDockingPoseDataset, +) +from .atom_type_dataset import AtomTypeDataset +from .add_2d_conformer_dataset import Add2DConformerDataset +from .distance_dataset import ( + DistanceDataset, + EdgeTypeDataset, + CrossDistanceDataset, +) +from .conformer_sample_dataset import ( + ConformerSampleDataset, + ConformerSamplePocketDataset, + ConformerSamplePocketFinetuneDataset, + ConformerSampleConfGDataset, + ConformerSampleConfGV2Dataset, + ConformerSampleDockingPoseDataset, +) +from .mask_points_dataset import MaskPointsDataset, MaskPointsPocketDataset +from .coord_pad_dataset import RightPadDatasetCoord, RightPadDatasetCross2D +from .from_str_dataset import FromStrLabelDataset +from .lmdb_dataset import LMDBDataset +from .prepend_and_append_2d_dataset import PrependAndAppend2DDataset + +__all__ = [] \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/add_2d_conformer_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/add_2d_conformer_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..047d6d6ef4b573fa0466939e1353b8c26b9f6e7d --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/add_2d_conformer_dataset.py @@ -0,0 +1,245 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import warnings +warnings.filterwarnings("ignore", category=UserWarning, module="numpy.core.getlimits") + +import numpy as np +from functools import lru_cache +from unicore.data import BaseWrapperDataset +from rdkit import Chem +from rdkit.Chem import AllChem + + +# -------------------------- 新增1:常见原子符号→原子序数映射表(覆盖99%分子场景) -------------------------- +# 键:原子符号(如 'C'、'O'),值:对应的原子序数(数值) +# 若你的数据中有其他原子(如 'N'、'S'),可按格式添加到字典中 +ATOM_SYMBOL_TO_NUMBER = { + 'H': 1, # 氢 + 'He': 2, # 氦 + 'Li': 3, # 锂 + 'Be': 4, # 铍 + 'B': 5, # 硼 + 'C': 6, # 碳(你当前报错的符号) + 'N': 7, # 氮 + 'O': 8, # 氧 + 'F': 9, # 氟 + 'Ne': 10, # 氖 + 'Na': 11, # 钠 + 'Mg': 12, # 镁 + 'Al': 13, # 铝 + 'Si': 14, # 硅 + 'P': 15, # 磷 + 'S': 16, # 硫 + 'Cl': 17, # 氯 + 'K': 19, # 钾 + 'Ca': 20, # 钙 + 'Fe': 26, # 铁(若有金属分子可添加) + 'Cu': 29 # 铜(按需添加) +} + + +# -------------------------- 新增2:原子符号转原子序数的辅助函数 -------------------------- +def convert_atom_symbols_to_numbers(raw_atoms): + """ + 将原子符号列表(如 ['C','C','O'])转为原子序数列表(如 [6,6,8]) + Args: + raw_atoms: 原始atoms数据(可能是字符列表、字符串、数值列表) + Returns: + converted_atoms: 原子序数列表(数值类型) + """ + converted_atoms = [] + for atom in raw_atoms: + # 1. 若已是数值(如 6、8.0),直接保留 + if isinstance(atom, (int, float)): + converted_atoms.append(atom) + # 2. 
若是字符串(如 'C'、'6'),先处理 + elif isinstance(atom, str): + atom = atom.strip() # 去除空格(避免 ' C ' 这类脏数据) + # 2.1 若字符串是数字(如 '6'),转为数值 + if atom.isdigit(): + converted_atoms.append(float(atom)) + # 2.2 若字符串是原子符号(如 'C'),用映射表转换 + elif atom in ATOM_SYMBOL_TO_NUMBER: + converted_atoms.append(ATOM_SYMBOL_TO_NUMBER[atom]) + # 2.3 遇到不认识的符号,报错并提示(方便你补充映射表) + else: + raise ValueError( + f"未识别的原子符号:'{atom}'!\n" + f"请在 ATOM_SYMBOL_TO_NUMBER 字典中添加该符号对应的原子序数(如 'X': 123)。\n" + f"当前支持的原子符号:{list(ATOM_SYMBOL_TO_NUMBER.keys())}" + ) + # 3. 遇到其他类型(如布尔值),报错 + else: + raise TypeError( + f"不支持的原子数据类型:{type(atom)}(值:{atom})!\n" + f"请确保原子数据是 数值 或 原子符号字符串(如 6、'C')。" + ) + return converted_atoms + + +class Add2DConformerDataset(BaseWrapperDataset): + def __init__(self, dataset, smi, atoms, coordinates): + self.dataset = dataset + self.smi = smi + self.atoms = atoms + self.coordinates = coordinates + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + # -------------------------- 核心修改:先转原子符号为序数,再转float32 -------------------------- + # 1. 获取原始atoms数据(如 ['C','C','O']) + raw_atoms = self.dataset[index][self.atoms] + # 2. 调用新增函数,将原子符号转为原子序数(如 ['C','C','O'] → [6,6,8]) + converted_atoms = convert_atom_symbols_to_numbers(raw_atoms) + # 3. 转为float32数组(模型需要的数值类型) + atoms = np.array(converted_atoms, dtype=np.float32) + # -------------------------------------------------------------------------------------------------- + + assert len(atoms) > 0, f"第{index}个样本的atoms为空!" + smi = self.dataset[index][self.smi] + + # 生成2D坐标(原逻辑不变) + coordinates_2d = smi2_2Dcoords(smi) + + # 处理coordinates(确保是数值类型,避免后续问题) + raw_coords = self.dataset[index][self.coordinates] + # 若coordinates是列表,先转为float32数组再转回列表(保持格式一致) + if isinstance(raw_coords, list): + coordinates = np.array(raw_coords, dtype=np.float32).tolist() + else: + coordinates = raw_coords.astype(np.float32).tolist() + coordinates.append(coordinates_2d) + + return {"smi": smi, "atoms": atoms, "coordinates": coordinates} + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +def smi2_2Dcoords(smi): + # 增强异常处理:避免无效SMILES导致崩溃 + mol = Chem.MolFromSmiles(smi) + if mol is None: + raise ValueError(f"无效的SMILES字符串:'{smi}'!请检查数据是否正确。") + + mol = AllChem.AddHs(mol) + AllChem.Compute2DCoords(mol) + coordinates = mol.GetConformer().GetPositions().astype(np.float32) + + # 验证原子数与坐标数一致(用显式异常替代原断言,报错更清晰) + atom_count = len(mol.GetAtoms()) + coord_count = len(coordinates) + if atom_count != coord_count: + raise ValueError( + f"2D坐标与原子数不匹配!SMILES: '{smi}'\n" + f"原子数:{atom_count},坐标数:{coord_count}" + ) + return coordinates + + + +# import warnings +# warnings.filterwarnings("ignore", category=UserWarning, module="numpy.core.getlimits") + +# import numpy as np +# from functools import lru_cache +# from unicore.data import BaseWrapperDataset +# from rdkit import Chem +# from rdkit.Chem import AllChem + + +# class Add2DConformerDataset(BaseWrapperDataset): +# def __init__(self, dataset, smi, atoms, coordinates): +# self.dataset = dataset +# self.smi = smi +# self.atoms = atoms +# self.coordinates = coordinates +# self.set_epoch(None) + +# def set_epoch(self, epoch, **unused): +# super().set_epoch(epoch) +# self.epoch = epoch + +# @lru_cache(maxsize=16) +# def __cached_item__(self, index: int, epoch: int): +# # -------------------------- 核心修改:添加 dtype=np.float32,强制转为数值类型 -------------------------- +# atoms = 
np.array(self.dataset[index][self.atoms], dtype=np.float32) +# # -------------------------------------------------------------------------------------------------- +# assert len(atoms) > 0 +# smi = self.dataset[index][self.smi] +# coordinates_2d = smi2_2Dcoords(smi) +# coordinates = self.dataset[index][self.coordinates] +# # 额外优化:coordinates 也建议转成数值数组(避免类似问题) +# coordinates = np.array(coordinates, dtype=np.float32).tolist() # 转数组后再转回列表(保持原格式) +# coordinates.append(coordinates_2d) +# return {"smi": smi, "atoms": atoms, "coordinates": coordinates} + +# def __getitem__(self, index: int): +# return self.__cached_item__(index, self.epoch) + + +# def smi2_2Dcoords(smi): +# mol = Chem.MolFromSmiles(smi) +# if mol is None: # 额外加个异常处理,避免无效SMILES导致崩溃 +# raise ValueError(f"Invalid SMILES: {smi}") +# mol = AllChem.AddHs(mol) +# AllChem.Compute2DCoords(mol) +# coordinates = mol.GetConformer().GetPositions().astype(np.float32) +# if len(mol.GetAtoms()) != len(coordinates): # 原代码的断言改为显式异常,报错更清晰 +# raise ValueError(f"2D coordinates shape not align with SMILES {smi}: atoms={len(mol.GetAtoms())}, coords={len(coordinates)}") +# return coordinates + + + + +# import warnings +# warnings.filterwarnings("ignore", category=UserWarning, module="numpy.core.getlimits") + +# import numpy as np +# from functools import lru_cache +# from unicore.data import BaseWrapperDataset +# from rdkit import Chem +# from rdkit.Chem import AllChem + + +# class Add2DConformerDataset(BaseWrapperDataset): +# def __init__(self, dataset, smi, atoms, coordinates): +# self.dataset = dataset +# self.smi = smi +# self.atoms = atoms +# self.coordinates = coordinates +# self.set_epoch(None) + +# def set_epoch(self, epoch, **unused): +# super().set_epoch(epoch) +# self.epoch = epoch + +# @lru_cache(maxsize=16) +# def __cached_item__(self, index: int, epoch: int): +# atoms = np.array(self.dataset[index][self.atoms]) +# assert len(atoms) > 0 +# smi = self.dataset[index][self.smi] +# coordinates_2d = smi2_2Dcoords(smi) +# coordinates = self.dataset[index][self.coordinates] +# coordinates.append(coordinates_2d) +# return {"smi": smi, "atoms": atoms, "coordinates": coordinates} + +# def __getitem__(self, index: int): +# return self.__cached_item__(index, self.epoch) + + +# def smi2_2Dcoords(smi): +# mol = Chem.MolFromSmiles(smi) +# mol = AllChem.AddHs(mol) +# AllChem.Compute2DCoords(mol) +# coordinates = mol.GetConformer().GetPositions().astype(np.float32) +# len(mol.GetAtoms()) == len( +# coordinates +# ), "2D coordinates shape is not align with {}".format(smi) +# return coordinates diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/atom_type_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/atom_type_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..ee3c5d86a5a5b3fb945836f2aeedad4f89502a96 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/atom_type_dataset.py @@ -0,0 +1,34 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
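+# Descriptive note for the class below: AtomTypeDataset is a compatibility wrapper.
+# With older RDKit versions a sample's "atoms" and "coordinates" can end up with
+# different lengths, so __getitem__ truncates both to the shorter length before
+# returning the sample, keeping downstream featurization aligned.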
+ +from functools import lru_cache +from unicore.data import BaseWrapperDataset + + +class AtomTypeDataset(BaseWrapperDataset): + def __init__( + self, + raw_dataset, + dataset, + smi="smi", + atoms="atoms", + ): + self.raw_dataset = raw_dataset + self.dataset = dataset + self.smi = smi + self.atoms = atoms + + @lru_cache(maxsize=16) + def __getitem__(self, index: int): + # for low rdkit version + if len(self.dataset[index]["atoms"]) != len(self.dataset[index]["coordinates"]): + min_len = min( + len(self.dataset[index]["atoms"]), + len(self.dataset[index]["coordinates"]), + ) + self.dataset[index]["atoms"] = self.dataset[index]["atoms"][:min_len] + self.dataset[index]["coordinates"] = self.dataset[index]["coordinates"][ + :min_len + ] + return self.dataset[index] diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/conformer_sample_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/conformer_sample_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..804a18012a1394d953364e99ba8941724d4015aa --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/conformer_sample_dataset.py @@ -0,0 +1,280 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from functools import lru_cache +from unicore.data import BaseWrapperDataset +from . import data_utils + + +class ConformerSampleDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, coordinates): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + atoms = np.array(self.dataset[index][self.atoms]) + assert len(atoms) > 0 + size = len(self.dataset[index][self.coordinates]) + with data_utils.numpy_seed(self.seed, epoch, index): + sample_idx = np.random.randint(size) + coordinates = self.dataset[index][self.coordinates][sample_idx] + return {"atoms": atoms, "coordinates": coordinates.astype(np.float32)} + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class ConformerSamplePocketDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, coordinates, dict_name): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.dict_name = dict_name + self.coordinates = coordinates + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + if self.dict_name == "dict_coarse.txt": + atoms = np.array([a[0] for a in self.dataset[index][self.atoms]]) + elif self.dict_name == "dict_fine.txt": + atoms = np.array( + [ + a[0] if len(a) == 1 or a[0] == "H" else a[:2] + for a in self.dataset[index][self.atoms] + ] + ) + assert len(atoms) > 0 + size = len(self.dataset[index][self.coordinates]) + with data_utils.numpy_seed(self.seed, epoch, index): + sample_idx = np.random.randint(size) + coordinates = self.dataset[index][self.coordinates][sample_idx] + residue = np.array(self.dataset[index]["residue"]) + return { + "atoms": atoms, + "coordinates": coordinates.astype(np.float32), + "residue": residue, + } + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class 
ConformerSamplePocketFinetuneDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, residues, coordinates): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.residues = residues + self.coordinates = coordinates + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + atoms = np.array( + [a[0] for a in self.dataset[index][self.atoms]] + ) # only 'C H O N S' + assert len(atoms) > 0 + # This judgment is reserved for possible future expansion. + # The number of pocket conformations is 1, and the 'sample' does not work. + if isinstance(self.dataset[index][self.coordinates], list): + size = len(self.dataset[index][self.coordinates]) + with data_utils.numpy_seed(self.seed, epoch, index): + sample_idx = np.random.randint(size) + coordinates = self.dataset[index][self.coordinates][sample_idx] + else: + coordinates = self.dataset[index][self.coordinates] + + if self.residues in self.dataset[index]: + residues = np.array(self.dataset[index][self.residues]) + else: + residues = None + assert len(atoms) == len(coordinates) + return { + self.atoms: atoms, + self.coordinates: coordinates.astype(np.float32), + self.residues: residues, + } + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class ConformerSampleConfGDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, coordinates, tgt_coordinates): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.tgt_coordinates = tgt_coordinates + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + atoms = np.array(self.dataset[index][self.atoms]) + assert len(atoms) > 0 + size = len(self.dataset[index][self.coordinates]) + with data_utils.numpy_seed(self.seed, epoch, index): + sample_idx = np.random.randint(size) + coordinates = self.dataset[index][self.coordinates][sample_idx] + tgt_coordinates = self.dataset[index][self.tgt_coordinates] + return { + self.atoms: atoms, + self.coordinates: coordinates.astype(np.float32), + self.tgt_coordinates: tgt_coordinates.astype(np.float32), + } + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class ConformerSampleConfGV2Dataset(BaseWrapperDataset): + def __init__( + self, + dataset, + seed, + atoms, + coordinates, + tgt_coordinates, + beta=1.0, + smooth=0.1, + topN=10, + ): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.tgt_coordinates = tgt_coordinates + self.beta = beta + self.smooth = smooth + self.topN = topN + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + atoms = np.array(self.dataset[index][self.atoms]) + assert len(atoms) > 0 + meta_df = self.dataset[index]["meta"] + tgt_conf_ids = meta_df["gid"].unique() + # randomly choose one conf + with data_utils.numpy_seed(self.seed, epoch, index): + conf_id = np.random.choice(tgt_conf_ids) + conf_df = meta_df[meta_df["gid"] == conf_id] + conf_df = conf_df.sort_values("score").reset_index(drop=False)[ + : self.topN + ] # only use top 5 confs for sampling... 
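+        # Sampling weights computed below follow an inverse-score scheme: candidate i
+        # gets weight proportional to 1 / (score_i**beta + smooth), normalized to sum
+        # to 1 (see normalize()), so conformers with lower RMSD scores are drawn with
+        # higher probability.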
+ # importance sampling with rmsd inverse score + + def normalize(x, beta=1.0, smooth=0.1): + x = 1.0 / (x**beta + smooth) + return x / x.sum() + + rmsd_score = conf_df["score"].values + weight = normalize( + rmsd_score, beta=self.beta, smooth=self.smooth + ) # for smoothing purpose + with data_utils.numpy_seed(self.seed, epoch, index): + idx = np.random.choice(len(conf_df), 1, replace=False, p=weight) + # idx = [np.argmax(weight)] + coordinates = conf_df.iloc[idx]["rdkit_coords"].values[0] + tgt_coordinates = conf_df.iloc[idx]["tgt_coords"].values[0] + return { + self.atoms: atoms, + self.coordinates: coordinates.astype(np.float32), + self.tgt_coordinates: tgt_coordinates.astype(np.float32), + } + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class ConformerSampleDockingPoseDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + seed, + atoms, + coordinates, + pocket_atoms, + pocket_coordinates, + holo_coordinates, + holo_pocket_coordinates, + is_train=True, + ): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.pocket_atoms = pocket_atoms + self.pocket_coordinates = pocket_coordinates + self.holo_coordinates = holo_coordinates + self.holo_pocket_coordinates = holo_pocket_coordinates + self.is_train = is_train + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + atoms = np.array(self.dataset[index][self.atoms]) + size = len(self.dataset[index][self.coordinates]) + with data_utils.numpy_seed(self.seed, epoch, index): + sample_idx = np.random.randint(size) + coordinates = self.dataset[index][self.coordinates][sample_idx] + pocket_atoms = np.array( + [item[0] for item in self.dataset[index][self.pocket_atoms]] + ) + pocket_coordinates = self.dataset[index][self.pocket_coordinates][0] + if self.is_train: + holo_coordinates = self.dataset[index][self.holo_coordinates][0] + holo_pocket_coordinates = self.dataset[index][self.holo_pocket_coordinates][ + 0 + ] + else: + holo_coordinates = coordinates + holo_pocket_coordinates = pocket_coordinates + + smi = self.dataset[index]["smi"] + pocket = self.dataset[index]["pocket"] + + return { + "atoms": atoms, + "coordinates": coordinates.astype(np.float32), + "pocket_atoms": pocket_atoms, + "pocket_coordinates": pocket_coordinates.astype(np.float32), + "holo_coordinates": holo_coordinates.astype(np.float32), + "holo_pocket_coordinates": holo_pocket_coordinates.astype(np.float32), + "smi": smi, + "pocket": pocket, + } + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/coord_pad_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/coord_pad_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..3cbfcc87925a89cd6860b7ff5ddb5b1afba0f582 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/coord_pad_dataset.py @@ -0,0 +1,82 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
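+# Collation helpers for variable-length geometric features: collate_tokens_coords
+# right-pads per-sample [n_i, 3] coordinate tensors into one [batch, max_len, 3]
+# tensor filled with pad_idx (max_len rounded up to a multiple of 8 via
+# pad_to_multiple), and collate_cross_2d does the same for 2D distance maps of
+# shape [n_i, m_i]. RightPadDatasetCoord / RightPadDatasetCross2D expose these as
+# dataset collaters.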
+ +from unicore.data import BaseWrapperDataset + + +def collate_tokens_coords( + values, + pad_idx, + left_pad=False, + pad_to_length=None, + pad_to_multiple=1, +): + """Convert a list of 1d tensors into a padded 2d tensor.""" + size = max(v.size(0) for v in values) + size = size if pad_to_length is None else max(size, pad_to_length) + if pad_to_multiple != 1 and size % pad_to_multiple != 0: + size = int(((size - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + res = values[0].new(len(values), size, 3).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor(v, res[i][size - len(v) :, :] if left_pad else res[i][: len(v), :]) + return res + + +class RightPadDatasetCoord(BaseWrapperDataset): + def __init__(self, dataset, pad_idx, left_pad=False): + super().__init__(dataset) + self.pad_idx = pad_idx + self.left_pad = left_pad + + def collater(self, samples): + return collate_tokens_coords( + samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8 + ) + + +def collate_cross_2d( + values, + pad_idx, + left_pad=False, + pad_to_length=None, + pad_to_multiple=1, +): + """Convert a list of 2d tensors into a padded 2d tensor.""" + size_h = max(v.size(0) for v in values) + size_w = max(v.size(1) for v in values) + if pad_to_multiple != 1 and size_h % pad_to_multiple != 0: + size_h = int(((size_h - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + if pad_to_multiple != 1 and size_w % pad_to_multiple != 0: + size_w = int(((size_w - 0.1) // pad_to_multiple + 1) * pad_to_multiple) + res = values[0].new(len(values), size_h, size_w).fill_(pad_idx) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + dst.copy_(src) + + for i, v in enumerate(values): + copy_tensor( + v, + res[i][size_h - v.size(0) :, size_w - v.size(1) :] + if left_pad + else res[i][: v.size(0), : v.size(1)], + ) + return res + + +class RightPadDatasetCross2D(BaseWrapperDataset): + def __init__(self, dataset, pad_idx, left_pad=False): + super().__init__(dataset) + self.pad_idx = pad_idx + self.left_pad = left_pad + + def collater(self, samples): + return collate_cross_2d( + samples, self.pad_idx, left_pad=self.left_pad, pad_to_multiple=8 + ) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/cropping_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/cropping_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..511b30446556ce19a630bbac4d02eb83d1240de5 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/cropping_dataset.py @@ -0,0 +1,220 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from functools import lru_cache +import logging +from unicore.data import BaseWrapperDataset +from . 
import data_utils + +logger = logging.getLogger(__name__) + + +class CroppingDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.max_atoms = max_atoms + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(self.seed, epoch, index): + index = np.random.choice(len(atoms), self.max_atoms, replace=False) + atoms = np.array(atoms)[index] + coordinates = coordinates[index] + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class CroppingPocketDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, coordinates, max_atoms=256): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.max_atoms = ( + max_atoms # max number of atoms in a molecule, None indicates no limit. + ) + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + residue = dd["residue"] + + # crop atoms according to their distance to the center of pockets + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(self.seed, epoch, index): + distance = np.linalg.norm( + coordinates - coordinates.mean(axis=0), axis=1 + ) + + def softmax(x): + x -= np.max(x) + x = np.exp(x) / np.sum(np.exp(x)) + return x + + distance += 1 # prevent inf + weight = softmax(np.reciprocal(distance)) + index = np.random.choice( + len(atoms), self.max_atoms, replace=False, p=weight + ) + atoms = atoms[index] + coordinates = coordinates[index] + residue = residue[index] + + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + dd["residue"] = residue + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class CroppingResiduePocketDataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, residues, coordinates, max_atoms=256): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.residues = residues + self.coordinates = coordinates + self.max_atoms = ( + max_atoms # max number of atoms in a molecule, None indicates no limit. 
+ ) + + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + residues = dd[self.residues] + coordinates = dd[self.coordinates] + + residues_distance_map = {} + + # crop atoms according to their distance to the center of pockets + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(self.seed, epoch, index): + distance = np.linalg.norm( + coordinates - coordinates.mean(axis=0), axis=1 + ) + residues_ids, residues_distance = [], [] + for res in residues: + if res not in residues_ids: + residues_ids.append(res) + residues_distance.append(distance[residues == res].mean()) + residues_ids = np.array(residues_ids) + residues_distance = np.array(residues_distance) + + def softmax(x): + x -= np.max(x) + x = np.exp(x) / np.sum(np.exp(x)) + return x + + residues_distance += 1 # prevent inf and smoothing out the distance + weight = softmax(np.reciprocal(residues_distance)) + max_residues = self.max_atoms // (len(atoms) // (len(residues_ids) + 1)) + if max_residues < 1: + max_residues += 1 + max_residues = min(max_residues, len(residues_ids)) + residue_index = np.random.choice( + len(residues_ids), max_residues, replace=False, p=weight + ) + index = [ + i + for i in range(len(atoms)) + if residues[i] in residues_ids[residue_index] + ] + atoms = atoms[index] + coordinates = coordinates[index] + residues = residues[index] + + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + dd[self.residues] = residues + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class CroppingPocketDockingPoseDataset(BaseWrapperDataset): + def __init__( + self, dataset, seed, atoms, coordinates, holo_coordinates, max_atoms=256 + ): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.holo_coordinates = holo_coordinates + self.max_atoms = max_atoms + + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + holo_coordinates = dd[self.holo_coordinates] + + # crop atoms according to their distance to the center of pockets + if self.max_atoms and len(atoms) > self.max_atoms: + with data_utils.numpy_seed(self.seed, epoch): + distance = np.linalg.norm( + coordinates - coordinates.mean(axis=0), axis=1 + ) + + def softmax(x): + x -= np.max(x) + x = np.exp(x) / np.sum(np.exp(x)) + return x + + distance += 1 # prevent inf + weight = softmax(np.reciprocal(distance)) + index = np.random.choice( + len(atoms), self.max_atoms, replace=False, p=weight + ) + atoms = atoms[index] + coordinates = coordinates[index] + holo_coordinates = holo_coordinates[index] + + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + dd[self.holo_coordinates] = holo_coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/data_utils.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..53111b654c91785e5fa1e6e682f8020aef0aed5c --- 
/dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/data_utils.py @@ -0,0 +1,23 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +import contextlib + + +@contextlib.contextmanager +def numpy_seed(seed, *addl_seeds): + """Context manager which seeds the NumPy PRNG with the specified seed and + restores the state afterward""" + if seed is None: + yield + return + if len(addl_seeds) > 0: + seed = int(hash((seed, *addl_seeds)) % 1e6) + state = np.random.get_state() + np.random.seed(seed) + try: + yield + finally: + np.random.set_state(state) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/distance_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/distance_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..046c461d0a33e7db5aa30cc240282de37fc1802e --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/distance_dataset.py @@ -0,0 +1,92 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np +import mindspore as ms +from scipy.spatial import distance_matrix +from functools import lru_cache +from unicore.data import BaseWrapperDataset # 假设BaseWrapperDataset在MindSpore环境中兼容或有对应实现 + + +class DistanceDataset(BaseWrapperDataset): + def __init__(self, dataset): + super().__init__(dataset) + self.dataset = dataset + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + # MindSpore中用reshape替代view改变张量形状;用mindspore.Tensor替代torch.from_numpy + pos = self.dataset[idx].reshape(-1, 3).asnumpy() # 先转为numpy数组 + dist = distance_matrix(pos, pos).astype(np.float32) + return ms.Tensor(dist) # 用mindspore.Tensor创建张量 + + +class EdgeTypeDataset(BaseWrapperDataset): + def __init__(self, dataset: ms.dataset.Dataset, num_types: int): # 替换为MindSpore的Dataset + self.dataset = dataset + self.num_types = num_types + + @lru_cache(maxsize=16) + def __getitem__(self, index: int): + node_input = self.dataset[index].clone() # clone在MindSpore中兼容 + # 用reshape替代view + offset = node_input.reshape(-1, 1) * self.num_types + node_input.reshape(1, -1) + return offset + + +class CrossDistanceDataset(BaseWrapperDataset): + def __init__(self, mol_dataset, pocket_dataset): + super().__init__(mol_dataset) + self.mol_dataset = mol_dataset + self.pocket_dataset = pocket_dataset + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + # 用reshape替代view;用mindspore.Tensor替代torch.from_numpy + mol_pos = self.mol_dataset[idx].reshape(-1, 3).asnumpy() + pocket_pos = self.pocket_dataset[idx].reshape(-1, 3).asnumpy() + dist = distance_matrix(mol_pos, pocket_pos).astype(np.float32) + return ms.Tensor(dist) +# import numpy as np +# import torch +# from scipy.spatial import distance_matrix +# from functools import lru_cache +# from unicore.data import BaseWrapperDataset + + +# class DistanceDataset(BaseWrapperDataset): +# def __init__(self, dataset): +# super().__init__(dataset) +# self.dataset = dataset + +# @lru_cache(maxsize=16) +# def __getitem__(self, idx): +# pos = self.dataset[idx].view(-1, 3).numpy() +# dist = distance_matrix(pos, pos).astype(np.float32) +# return torch.from_numpy(dist) + + +# class EdgeTypeDataset(BaseWrapperDataset): +# def __init__(self, dataset: torch.utils.data.Dataset, num_types: int): +# self.dataset = dataset +# self.num_types = num_types + +# @lru_cache(maxsize=16) +# 
def __getitem__(self, index: int): +# node_input = self.dataset[index].clone() +# offset = node_input.view(-1, 1) * self.num_types + node_input.view(1, -1) +# return offset + + +# class CrossDistanceDataset(BaseWrapperDataset): +# def __init__(self, mol_dataset, pocket_dataset): +# super().__init__(mol_dataset) +# self.mol_dataset = mol_dataset +# self.pocket_dataset = pocket_dataset + +# @lru_cache(maxsize=16) +# def __getitem__(self, idx): +# mol_pos = self.mol_dataset[idx].view(-1, 3).numpy() +# pocket_pos = self.pocket_dataset[idx].view(-1, 3).numpy() +# dist = distance_matrix(mol_pos, pocket_pos).astype(np.float32) +# return torch.from_numpy(dist) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/from_str_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/from_str_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..abb53468cacf62d0cd226b76e972b4304e151b19 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/from_str_dataset.py @@ -0,0 +1,38 @@ +import mindspore as ms +from functools import lru_cache +from unicore.data import UnicoreDataset + + +class FromStrLabelDataset(UnicoreDataset): + def __init__(self, labels): + super().__init__() + self.labels = labels + + @lru_cache(maxsize=16) + def __getitem__(self, index): + return self.labels[index] + + def __len__(self): + return len(self.labels) + + def collater(self, samples): + return ms.Tensor(list(map(float, samples)), dtype=ms.float32) +# import torch +# from functools import lru_cache +# from unicore.data import UnicoreDataset + + +# class FromStrLabelDataset(UnicoreDataset): +# def __init__(self, labels): +# super().__init__() +# self.labels = labels + +# @lru_cache(maxsize=16) +# def __getitem__(self, index): +# return self.labels[index] + +# def __len__(self): +# return len(self.labels) + +# def collater(self, samples): +# return torch.tensor(list(map(float, samples))) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/key_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/key_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..c241374e474a7c0993f778e81d939ae794460677 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/key_dataset.py @@ -0,0 +1,19 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +from functools import lru_cache +from unicore.data import BaseWrapperDataset + + +class KeyDataset(BaseWrapperDataset): + def __init__(self, dataset, key): + self.dataset = dataset + self.key = key + + def __len__(self): + return len(self.dataset) + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + return self.dataset[idx][self.key] diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/lmdb_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/lmdb_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..fced4c5d355d57c0f517f718dec4cd50c40fc0d9 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/lmdb_dataset.py @@ -0,0 +1,47 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
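+# LMDBDataset below reads a pickled-sample LMDB file in read-only mode: keys are
+# cached at construction, the environment is (re)opened lazily in __getitem__ so
+# each worker process gets its own handle, and records are looked up by their
+# integer index encoded as an ASCII string.
+#
+# Minimal usage sketch (the file name below is only a placeholder):
+#
+#     dataset = LMDBDataset("train.lmdb")
+#     sample = dataset[0]   # dict unpickled from the LMDB record
+#     print(len(dataset), list(sample.keys()))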
+ + +import lmdb +import os +import pickle +from functools import lru_cache +import logging + +logger = logging.getLogger(__name__) + + +class LMDBDataset: + def __init__(self, db_path): + self.db_path = db_path + assert os.path.isfile(self.db_path), "{} not found".format(self.db_path) + env = self.connect_db(self.db_path) + with env.begin() as txn: + self._keys = list(txn.cursor().iternext(values=False)) + + def connect_db(self, lmdb_path, save_to_self=False): + env = lmdb.open( + lmdb_path, + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + max_readers=256, + ) + if not save_to_self: + return env + else: + self.env = env + + def __len__(self): + return len(self._keys) + + @lru_cache(maxsize=16) + def __getitem__(self, idx): + if not hasattr(self, "env"): + self.connect_db(self.db_path, save_to_self=True) + datapoint_pickled = self.env.begin().get(f"{idx}".encode("ascii")) + data = pickle.loads(datapoint_pickled) + return data diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/mask_points_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/mask_points_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..85622660cd3df8123eb29ff7be7c526030ac8628 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/mask_points_dataset.py @@ -0,0 +1,533 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from functools import lru_cache + +import numpy as np +import mindspore as ms +from unicore.data import Dictionary +from unicore.data import BaseWrapperDataset +from unimol.data import data_utils + + +class MaskPointsDataset(BaseWrapperDataset): + def __init__( + self, + dataset: ms.dataset.Dataset, # 替换为MindSpore数据集类型 + coord_dataset: ms.dataset.Dataset, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + noise_type: str, + noise: float = 1.0, + seed: int = 1, + mask_prob: float = 0.15, + leave_unmasked_prob: float = 0.1, + random_token_prob: float = 0.1, + ): + assert 0.0 < mask_prob < 1.0 + assert 0.0 <= random_token_prob <= 1.0 + assert 0.0 <= leave_unmasked_prob <= 1.0 + assert random_token_prob + leave_unmasked_prob <= 1.0 + + self.dataset = dataset + self.coord_dataset = coord_dataset + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.noise_type = noise_type + self.noise = noise + self.seed = seed + self.mask_prob = mask_prob + self.leave_unmasked_prob = leave_unmasked_prob + self.random_token_prob = random_token_prob + + if random_token_prob > 0.0: + weights = np.ones(len(self.vocab)) + weights[vocab.special_index()] = 0 + self.weights = weights / weights.sum() + + self.epoch = None + if self.noise_type == "trunc_normal": + self.noise_f = lambda num_mask: np.clip( + np.random.randn(num_mask, 3) * self.noise, + a_min=-self.noise * 2.0, + a_max=self.noise * 2.0, + ) + elif self.noise_type == "normal": + self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise + elif self.noise_type == "uniform": + self.noise_f = lambda num_mask: np.random.uniform( + low=-self.noise, high=self.noise, size=(num_mask, 3) + ) + else: + self.noise_f = lambda num_mask: 0.0 + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.coord_dataset.set_epoch(epoch) + self.dataset.set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index: int): + return self.__getitem_cached__(self.epoch, index) + + @lru_cache(maxsize=16) + def 
__getitem_cached__(self, epoch: int, index: int): + ret = {} + with data_utils.numpy_seed(self.seed, epoch, index): + item = self.dataset[index] + coord = self.coord_dataset[index] + sz = len(item) + # don't allow empty sequence + assert sz > 0 + # decide elements to mask + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * sz + + np.random.rand() + ) + mask_idc = np.random.choice(sz, num_mask, replace=False) + mask = np.full(sz, False) + mask[mask_idc] = True + ret["targets"] = np.full(len(mask), self.pad_idx) + ret["targets"][mask] = item[mask] + # 替换PyTorch张量转换为MindSpore张量,long()对应int64 + ret["targets"] = ms.Tensor(ret["targets"], dtype=ms.int64) + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = np.random.rand(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + + num_mask = mask.astype(np.int32).sum() + new_coord = np.copy(coord) + new_coord[mask, :] += self.noise_f(num_mask) + + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + new_item[rand_mask] = np.random.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + # 替换PyTorch张量转换为MindSpore张量 + ret["atoms"] = ms.Tensor(new_item, dtype=ms.int64) + ret["coordinates"] = ms.Tensor(new_coord, dtype=ms.float32) + return ret + + +class MaskPointsPocketDataset(BaseWrapperDataset): + def __init__( + self, + dataset: ms.dataset.Dataset, # 替换为MindSpore数据集类型 + coord_dataset: ms.dataset.Dataset, + residue_dataset: ms.dataset.Dataset, + vocab: Dictionary, + pad_idx: int, + mask_idx: int, + noise_type: str, + noise: float = 1.0, + seed: int = 1, + mask_prob: float = 0.15, + leave_unmasked_prob: float = 0.1, + random_token_prob: float = 0.1, + ): + assert 0.0 < mask_prob < 1.0 + assert 0.0 <= random_token_prob <= 1.0 + assert 0.0 <= leave_unmasked_prob <= 1.0 + assert random_token_prob + leave_unmasked_prob <= 1.0 + + self.dataset = dataset + self.coord_dataset = coord_dataset + self.residue_dataset = residue_dataset + self.vocab = vocab + self.pad_idx = pad_idx + self.mask_idx = mask_idx + self.noise_type = noise_type + self.noise = noise + self.seed = seed + self.mask_prob = mask_prob + self.leave_unmasked_prob = leave_unmasked_prob + self.random_token_prob = random_token_prob + + if random_token_prob > 0.0: + weights = np.ones(len(self.vocab)) + weights[vocab.special_index()] = 0 + self.weights = weights / weights.sum() + + self.epoch = None + if self.noise_type == "trunc_normal": + self.noise_f = lambda num_mask: np.clip( + np.random.randn(num_mask, 3) * self.noise, + a_min=-self.noise * 2.0, + a_max=self.noise * 2.0, + ) + elif self.noise_type == "normal": + self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise + elif self.noise_type == "uniform": + self.noise_f = lambda num_mask: np.random.uniform( + low=-self.noise, high=self.noise, size=(num_mask, 3) + ) + else: + self.noise_f = lambda num_mask: 0.0 + + def set_epoch(self, epoch, **unused): + 
super().set_epoch(epoch) + self.coord_dataset.set_epoch(epoch) + self.dataset.set_epoch(epoch) + self.epoch = epoch + + def __getitem__(self, index: int): + return self.__getitem_cached__(self.epoch, index) + + @lru_cache(maxsize=16) + def __getitem_cached__(self, epoch: int, index: int): + ret = {} + with data_utils.numpy_seed(self.seed, epoch, index): + item = self.dataset[index] + coord = self.coord_dataset[index] + sz = len(item) + # don't allow empty sequence + assert sz > 0 + + # mask on the level of residues + residue = self.residue_dataset[index] + res_list = list(set(residue)) + res_sz = len(res_list) + + # decide elements to mask + num_mask = int( + # add a random number for probabilistic rounding + self.mask_prob * res_sz + + np.random.rand() + ) + mask_res = np.random.choice(res_list, num_mask, replace=False).tolist() + mask = np.isin(residue, mask_res) + + ret["targets"] = np.full(len(mask), self.pad_idx) + ret["targets"][mask] = item[mask] + # 替换PyTorch张量转换为MindSpore张量 + ret["targets"] = ms.Tensor(ret["targets"], dtype=ms.int64) + # decide unmasking and random replacement + rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob + if rand_or_unmask_prob > 0.0: + rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) + if self.random_token_prob == 0.0: + unmask = rand_or_unmask + rand_mask = None + elif self.leave_unmasked_prob == 0.0: + unmask = None + rand_mask = rand_or_unmask + else: + unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob + decision = np.random.rand(sz) < unmask_prob + unmask = rand_or_unmask & decision + rand_mask = rand_or_unmask & (~decision) + else: + unmask = rand_mask = None + + if unmask is not None: + mask = mask ^ unmask + + new_item = np.copy(item) + new_item[mask] = self.mask_idx + + num_mask = mask.astype(np.int32).sum() + new_coord = np.copy(coord) + new_coord[mask, :] += self.noise_f(num_mask) + + if rand_mask is not None: + num_rand = rand_mask.sum() + if num_rand > 0: + new_item[rand_mask] = np.random.choice( + len(self.vocab), + num_rand, + p=self.weights, + ) + # 替换PyTorch张量转换为MindSpore张量 + ret["atoms"] = ms.Tensor(new_item, dtype=ms.int64) + ret["coordinates"] = ms.Tensor(new_coord, dtype=ms.float32) + return ret +# from functools import lru_cache + +# import numpy as np +# import torch +# from unicore.data import Dictionary +# from unicore.data import BaseWrapperDataset +# from . 
import data_utils + + +# class MaskPointsDataset(BaseWrapperDataset): +# def __init__( +# self, +# dataset: torch.utils.data.Dataset, +# coord_dataset: torch.utils.data.Dataset, +# vocab: Dictionary, +# pad_idx: int, +# mask_idx: int, +# noise_type: str, +# noise: float = 1.0, +# seed: int = 1, +# mask_prob: float = 0.15, +# leave_unmasked_prob: float = 0.1, +# random_token_prob: float = 0.1, +# ): +# assert 0.0 < mask_prob < 1.0 +# assert 0.0 <= random_token_prob <= 1.0 +# assert 0.0 <= leave_unmasked_prob <= 1.0 +# assert random_token_prob + leave_unmasked_prob <= 1.0 + +# self.dataset = dataset +# self.coord_dataset = coord_dataset +# self.vocab = vocab +# self.pad_idx = pad_idx +# self.mask_idx = mask_idx +# self.noise_type = noise_type +# self.noise = noise +# self.seed = seed +# self.mask_prob = mask_prob +# self.leave_unmasked_prob = leave_unmasked_prob +# self.random_token_prob = random_token_prob + +# if random_token_prob > 0.0: +# weights = np.ones(len(self.vocab)) +# weights[vocab.special_index()] = 0 +# self.weights = weights / weights.sum() + +# self.epoch = None +# if self.noise_type == "trunc_normal": +# self.noise_f = lambda num_mask: np.clip( +# np.random.randn(num_mask, 3) * self.noise, +# a_min=-self.noise * 2.0, +# a_max=self.noise * 2.0, +# ) +# elif self.noise_type == "normal": +# self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise +# elif self.noise_type == "uniform": +# self.noise_f = lambda num_mask: np.random.uniform( +# low=-self.noise, high=self.noise, size=(num_mask, 3) +# ) +# else: +# self.noise_f = lambda num_mask: 0.0 + +# def set_epoch(self, epoch, **unused): +# super().set_epoch(epoch) +# self.coord_dataset.set_epoch(epoch) +# self.dataset.set_epoch(epoch) +# self.epoch = epoch + +# def __getitem__(self, index: int): +# return self.__getitem_cached__(self.epoch, index) + +# @lru_cache(maxsize=16) +# def __getitem_cached__(self, epoch: int, index: int): +# ret = {} +# with data_utils.numpy_seed(self.seed, epoch, index): +# item = self.dataset[index] +# coord = self.coord_dataset[index] +# sz = len(item) +# # don't allow empty sequence +# assert sz > 0 +# # decide elements to mask +# num_mask = int( +# # add a random number for probabilistic rounding +# self.mask_prob * sz +# + np.random.rand() +# ) +# mask_idc = np.random.choice(sz, num_mask, replace=False) +# mask = np.full(sz, False) +# mask[mask_idc] = True +# ret["targets"] = np.full(len(mask), self.pad_idx) +# ret["targets"][mask] = item[mask] +# ret["targets"] = torch.from_numpy(ret["targets"]).long() +# # decide unmasking and random replacement +# rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob +# if rand_or_unmask_prob > 0.0: +# rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) +# if self.random_token_prob == 0.0: +# unmask = rand_or_unmask +# rand_mask = None +# elif self.leave_unmasked_prob == 0.0: +# unmask = None +# rand_mask = rand_or_unmask +# else: +# unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob +# decision = np.random.rand(sz) < unmask_prob +# unmask = rand_or_unmask & decision +# rand_mask = rand_or_unmask & (~decision) +# else: +# unmask = rand_mask = None + +# if unmask is not None: +# mask = mask ^ unmask + +# new_item = np.copy(item) +# new_item[mask] = self.mask_idx + +# num_mask = mask.astype(np.int32).sum() +# new_coord = np.copy(coord) +# new_coord[mask, :] += self.noise_f(num_mask) + +# if rand_mask is not None: +# num_rand = rand_mask.sum() +# if num_rand > 0: +# new_item[rand_mask] = np.random.choice( 
+# len(self.vocab), +# num_rand, +# p=self.weights, +# ) +# ret["atoms"] = torch.from_numpy(new_item).long() +# ret["coordinates"] = torch.from_numpy(new_coord).float() +# return ret + + +# class MaskPointsPocketDataset(BaseWrapperDataset): +# def __init__( +# self, +# dataset: torch.utils.data.Dataset, +# coord_dataset: torch.utils.data.Dataset, +# residue_dataset: torch.utils.data.Dataset, +# vocab: Dictionary, +# pad_idx: int, +# mask_idx: int, +# noise_type: str, +# noise: float = 1.0, +# seed: int = 1, +# mask_prob: float = 0.15, +# leave_unmasked_prob: float = 0.1, +# random_token_prob: float = 0.1, +# ): +# assert 0.0 < mask_prob < 1.0 +# assert 0.0 <= random_token_prob <= 1.0 +# assert 0.0 <= leave_unmasked_prob <= 1.0 +# assert random_token_prob + leave_unmasked_prob <= 1.0 + +# self.dataset = dataset +# self.coord_dataset = coord_dataset +# self.residue_dataset = residue_dataset +# self.vocab = vocab +# self.pad_idx = pad_idx +# self.mask_idx = mask_idx +# self.noise_type = noise_type +# self.noise = noise +# self.seed = seed +# self.mask_prob = mask_prob +# self.leave_unmasked_prob = leave_unmasked_prob +# self.random_token_prob = random_token_prob + +# if random_token_prob > 0.0: +# weights = np.ones(len(self.vocab)) +# weights[vocab.special_index()] = 0 +# self.weights = weights / weights.sum() + +# self.epoch = None +# if self.noise_type == "trunc_normal": +# self.noise_f = lambda num_mask: np.clip( +# np.random.randn(num_mask, 3) * self.noise, +# a_min=-self.noise * 2.0, +# a_max=self.noise * 2.0, +# ) +# elif self.noise_type == "normal": +# self.noise_f = lambda num_mask: np.random.randn(num_mask, 3) * self.noise +# elif self.noise_type == "uniform": +# self.noise_f = lambda num_mask: np.random.uniform( +# low=-self.noise, high=self.noise, size=(num_mask, 3) +# ) +# else: +# self.noise_f = lambda num_mask: 0.0 + +# def set_epoch(self, epoch, **unused): +# super().set_epoch(epoch) +# self.coord_dataset.set_epoch(epoch) +# self.dataset.set_epoch(epoch) +# self.epoch = epoch + +# def __getitem__(self, index: int): +# return self.__getitem_cached__(self.epoch, index) + +# @lru_cache(maxsize=16) +# def __getitem_cached__(self, epoch: int, index: int): +# ret = {} +# with data_utils.numpy_seed(self.seed, epoch, index): +# item = self.dataset[index] +# coord = self.coord_dataset[index] +# sz = len(item) +# # don't allow empty sequence +# assert sz > 0 + +# # mask on the level of residues +# residue = self.residue_dataset[index] +# res_list = list(set(residue)) +# res_sz = len(res_list) + +# # decide elements to mask +# num_mask = int( +# # add a random number for probabilistic rounding +# self.mask_prob * res_sz +# + np.random.rand() +# ) +# mask_res = np.random.choice(res_list, num_mask, replace=False).tolist() +# mask = np.isin(residue, mask_res) + +# ret["targets"] = np.full(len(mask), self.pad_idx) +# ret["targets"][mask] = item[mask] +# ret["targets"] = torch.from_numpy(ret["targets"]).long() +# # decide unmasking and random replacement +# rand_or_unmask_prob = self.random_token_prob + self.leave_unmasked_prob +# if rand_or_unmask_prob > 0.0: +# rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob) +# if self.random_token_prob == 0.0: +# unmask = rand_or_unmask +# rand_mask = None +# elif self.leave_unmasked_prob == 0.0: +# unmask = None +# rand_mask = rand_or_unmask +# else: +# unmask_prob = self.leave_unmasked_prob / rand_or_unmask_prob +# decision = np.random.rand(sz) < unmask_prob +# unmask = rand_or_unmask & decision +# rand_mask = rand_or_unmask & 
(~decision) +# else: +# unmask = rand_mask = None + +# if unmask is not None: +# mask = mask ^ unmask + +# new_item = np.copy(item) +# new_item[mask] = self.mask_idx + +# num_mask = mask.astype(np.int32).sum() +# new_coord = np.copy(coord) +# new_coord[mask, :] += self.noise_f(num_mask) + +# if rand_mask is not None: +# num_rand = rand_mask.sum() +# if num_rand > 0: +# new_item[rand_mask] = np.random.choice( +# len(self.vocab), +# num_rand, +# p=self.weights, +# ) +# ret["atoms"] = torch.from_numpy(new_item).long() +# ret["coordinates"] = torch.from_numpy(new_coord).float() +# return ret diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/normalize_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/normalize_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..836f77daf53ea8e061f61d67dfabae2ecda02165 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/normalize_dataset.py @@ -0,0 +1,68 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from functools import lru_cache +from unicore.data import BaseWrapperDataset + + +class NormalizeDataset(BaseWrapperDataset): + def __init__(self, dataset, coordinates, normalize_coord=True): + self.dataset = dataset + self.coordinates = coordinates + self.normalize_coord = normalize_coord # normalize the coordinates. + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + coordinates = dd[self.coordinates] + # normalize + if self.normalize_coord: + coordinates = coordinates - coordinates.mean(axis=0) + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class NormalizeDockingPoseDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + coordinates, + pocket_coordinates, + center_coordinates="center_coordinates", + ): + self.dataset = dataset + self.coordinates = coordinates + self.pocket_coordinates = pocket_coordinates + self.center_coordinates = center_coordinates + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + coordinates = dd[self.coordinates] + pocket_coordinates = dd[self.pocket_coordinates] + # normalize coordinates and pocket coordinates ,align with pocket center coordinates + center_coordinates = pocket_coordinates.mean(axis=0) + coordinates = coordinates - center_coordinates + pocket_coordinates = pocket_coordinates - center_coordinates + dd[self.coordinates] = coordinates.astype(np.float32) + dd[self.pocket_coordinates] = pocket_coordinates.astype(np.float32) + dd[self.center_coordinates] = center_coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/prepend_and_append_2d_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/prepend_and_append_2d_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a02ce247008a862f825f9299dea19cb302b72756 --- /dev/null +++ 
b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/prepend_and_append_2d_dataset.py
@@ -0,0 +1,44 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+import mindspore as ms
+from functools import lru_cache
+from unicore.data import BaseWrapperDataset
+
+
+class PrependAndAppend2DDataset(BaseWrapperDataset):
+    def __init__(self, dataset, token=None):
+        super().__init__(dataset)
+        self.token = token
+
+    @lru_cache(maxsize=16)
+    def __getitem__(self, idx):
+        item = self.dataset[idx]
+        if self.token is not None:
+            # Height and width of the 2D item (torch .size() replaced by the MindSpore .shape attribute)
+            h, w = item.shape[-2], item.shape[-1]
+            # New tensor filled with the given token (torch.full replaced by mindspore.ops.full, same dtype)
+            new_item = ms.ops.full((h + 2, w + 2), self.token, dtype=item.dtype)
+            # Copy the original data into the centre region of the new tensor
+            new_item[1:-1, 1:-1] = item
+            return new_item
+        return item
+# import torch
+# from functools import lru_cache
+# from unicore.data import BaseWrapperDataset
+
+
+# class PrependAndAppend2DDataset(BaseWrapperDataset):
+#     def __init__(self, dataset, token=None):
+#         super().__init__(dataset)
+#         self.token = token
+
+#     @lru_cache(maxsize=16)
+#     def __getitem__(self, idx):
+#         item = self.dataset[idx]
+#         if self.token is not None:
+#             h, w = item.size(-2), item.size(-1)
+#             new_item = torch.full((h + 2, w + 2), self.token).type_as(item)
+#             new_item[1:-1, 1:-1] = item
+#             return new_item
+#         return item
diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/remove_hydrogen_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/remove_hydrogen_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..787b9f1a38402ac96c74184f03d34dede1f4f97d
--- /dev/null
+++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/remove_hydrogen_dataset.py
@@ -0,0 +1,149 @@
+# Copyright (c) DP Technology.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
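+#
+# Minimal usage sketch (illustrative only; `raw_dataset` is a stand-in for any
+# map-style dataset whose items are dicts with "atoms" and "coordinates" keys):
+#
+#   dataset = RemoveHydrogenDataset(raw_dataset, "atoms", "coordinates",
+#                                   remove_hydrogen=True)
+#   sample = dataset[0]  # "H" atoms dropped, coordinates cast to float32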
+ +import numpy as np +from functools import lru_cache +from unicore.data import BaseWrapperDataset + + +class RemoveHydrogenDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + atoms, + coordinates, + remove_hydrogen=False, + remove_polar_hydrogen=False, + ): + self.dataset = dataset + self.atoms = atoms + self.coordinates = coordinates + self.remove_hydrogen = remove_hydrogen + self.remove_polar_hydrogen = remove_polar_hydrogen + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + + if self.remove_hydrogen: + mask_hydrogen = atoms != "H" + atoms = atoms[mask_hydrogen] + coordinates = coordinates[mask_hydrogen] + if not self.remove_hydrogen and self.remove_polar_hydrogen: + end_idx = 0 + for i, atom in enumerate(atoms[::-1]): + if atom != "H": + break + else: + end_idx = i + 1 + if end_idx != 0: + atoms = atoms[:-end_idx] + coordinates = coordinates[:-end_idx] + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class RemoveHydrogenResiduePocketDataset(BaseWrapperDataset): + def __init__(self, dataset, atoms, residues, coordinates, remove_hydrogen=True): + self.dataset = dataset + self.atoms = atoms + self.residues = residues + self.coordinates = coordinates + self.remove_hydrogen = remove_hydrogen + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + residues = dd[self.residues] + coordinates = dd[self.coordinates] + if len(atoms) != len(residues): + min_len = min(len(atoms), len(residues)) + atoms = atoms[:min_len] + residues = residues[:min_len] + coordinates = coordinates[:min_len, :] + + if self.remove_hydrogen: + mask_hydrogen = atoms != "H" + atoms = atoms[mask_hydrogen] + residues = residues[mask_hydrogen] + coordinates = coordinates[mask_hydrogen] + + dd[self.atoms] = atoms + dd[self.residues] = residues + dd[self.coordinates] = coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class RemoveHydrogenPocketDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + atoms, + coordinates, + holo_coordinates, + remove_hydrogen=True, + remove_polar_hydrogen=False, + ): + self.dataset = dataset + self.atoms = atoms + self.coordinates = coordinates + self.holo_coordinates = holo_coordinates + self.remove_hydrogen = remove_hydrogen + self.remove_polar_hydrogen = remove_polar_hydrogen + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + dd = self.dataset[index].copy() + atoms = dd[self.atoms] + coordinates = dd[self.coordinates] + holo_coordinates = dd[self.holo_coordinates] + + if self.remove_hydrogen: + mask_hydrogen = atoms != "H" + atoms = atoms[mask_hydrogen] + coordinates = coordinates[mask_hydrogen] + holo_coordinates = holo_coordinates[mask_hydrogen] + if not self.remove_hydrogen and self.remove_polar_hydrogen: + end_idx = 0 + for i, atom in enumerate(atoms[::-1]): + if atom != 
"H": + break + else: + end_idx = i + 1 + if end_idx != 0: + atoms = atoms[:-end_idx] + coordinates = coordinates[:-end_idx] + holo_coordinates = holo_coordinates[:-end_idx] + dd[self.atoms] = atoms + dd[self.coordinates] = coordinates.astype(np.float32) + dd[self.holo_coordinates] = holo_coordinates.astype(np.float32) + return dd + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/data/tta_dataset.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/tta_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a28c31e9dd9b959b1d64beefa95c2c6559c45cc0 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/data/tta_dataset.py @@ -0,0 +1,110 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from functools import lru_cache +from unicore.data import BaseWrapperDataset + + +class TTADataset(BaseWrapperDataset): + def __init__(self, dataset, seed, atoms, coordinates, conf_size=10): + self.dataset = dataset + self.seed = seed + self.atoms = atoms + self.coordinates = coordinates + self.conf_size = conf_size + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __len__(self): + return len(self.dataset) * self.conf_size + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + smi_idx = index // self.conf_size + coord_idx = index % self.conf_size + atoms = np.array(self.dataset[smi_idx][self.atoms]) + coordinates = np.array(self.dataset[smi_idx][self.coordinates][coord_idx]) + smi = self.dataset[smi_idx]["smi"] + target = self.dataset[smi_idx].get("target", None) + return { + "atoms": atoms, + "coordinates": coordinates.astype(np.float32), + "smi": smi, + "target": target, + } + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) + + +class TTADockingPoseDataset(BaseWrapperDataset): + def __init__( + self, + dataset, + atoms, + coordinates, + pocket_atoms, + pocket_coordinates, + holo_coordinates, + holo_pocket_coordinates, + is_train=True, + conf_size=10, + ): + self.dataset = dataset + self.atoms = atoms + self.coordinates = coordinates + self.pocket_atoms = pocket_atoms + self.pocket_coordinates = pocket_coordinates + self.holo_coordinates = holo_coordinates + self.holo_pocket_coordinates = holo_pocket_coordinates + self.is_train = is_train + self.conf_size = conf_size + self.set_epoch(None) + + def set_epoch(self, epoch, **unused): + super().set_epoch(epoch) + self.epoch = epoch + + def __len__(self): + return len(self.dataset) * self.conf_size + + @lru_cache(maxsize=16) + def __cached_item__(self, index: int, epoch: int): + smi_idx = index // self.conf_size + coord_idx = index % self.conf_size + atoms = np.array(self.dataset[smi_idx][self.atoms]) + coordinates = np.array(self.dataset[smi_idx][self.coordinates][coord_idx]) + pocket_atoms = np.array( + [item[0] for item in self.dataset[smi_idx][self.pocket_atoms]] + ) + pocket_coordinates = np.array(self.dataset[smi_idx][self.pocket_coordinates][0]) + if self.is_train: + holo_coordinates = np.array(self.dataset[smi_idx][self.holo_coordinates][0]) + holo_pocket_coordinates = np.array( + self.dataset[smi_idx][self.holo_pocket_coordinates][0] + ) + else: + holo_coordinates = coordinates + holo_pocket_coordinates = pocket_coordinates + + smi = 
self.dataset[smi_idx]["smi"] + pocket = self.dataset[smi_idx]["pocket"] + + return { + "atoms": atoms, + "coordinates": coordinates.astype(np.float32), + "pocket_atoms": pocket_atoms, + "pocket_coordinates": pocket_coordinates.astype(np.float32), + "holo_coordinates": holo_coordinates.astype(np.float32), + "holo_pocket_coordinates": holo_pocket_coordinates.astype(np.float32), + "smi": smi, + "pocket": pocket, + } + + def __getitem__(self, index: int): + return self.__cached_item__(index, self.epoch) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/fusion_result.json b/MindChemistry/applications/Uni-Mol/unimol/unimol/fusion_result.json new file mode 100644 index 0000000000000000000000000000000000000000..231870b605edb2088d0059f7cca0bce484d4122c --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/fusion_result.json @@ -0,0 +1,18 @@ +{ + "session_and_graph_id_1_0": { + "graph_fusion": { + "RefreshInt64ToInt32FusionPass": { + "effect_times": "0", + "match_times": "1" + } + } + }, + "session_and_graph_id_2_1": { + "graph_fusion": { + "RefreshInt64ToInt32FusionPass": { + "effect_times": "0", + "match_times": "1" + } + } + } +} \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/infer.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..79334a42a3f31dcb82fcf611752c02801d2b089c --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/infer.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 -u +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import os +import sys +import pickle +import mindspore as ms +from mindspore import Tensor, set_context, get_context +from unicore import checkpoint_utils, options, utils +from unicore.logging import progress_bar +from unicore import tasks + +# 设置Ascend设备(单卡环境) +set_context(device_target="Ascend") + +logging.basicConfig( + format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + level=os.environ.get("LOGLEVEL", "INFO").upper(), + stream=sys.stdout, +) +logger = logging.getLogger("unimol.inference") + + +def main(args): + assert ( + args.batch_size is not None + ), "Must specify batch size either with --batch-size" + + use_fp16 = args.fp16 + # 单卡Ascend环境,无需CPU判断 + use_ascend = get_context("device_target") == "Ascend" + + # 单卡环境,无需分布式处理 + data_parallel_world_size = 1 + data_parallel_rank = 0 + + # 加载模型 + logger.info(f"loading model(s) from {args.path}") + # MindSpore加载 checkpoint(假设checkpoint_utils已适配) + state = checkpoint_utils.load_checkpoint_to_cpu(args.path) + task = tasks.setup_task(args) + model = task.build_model(args) + # 加载模型参数(MindSpore中load_state_dict与PyTorch类似) + model.load_state_dict(state["model"], strict=False) + + # 模型数据类型设置(FP16) + if use_ascend and use_fp16: + model.set_dtype(ms.float16) + + # 打印参数 + logger.info(args) + + # 构建损失函数 + loss = task.build_loss(args) + loss.set_train(False) # MindSpore中设置评估模式 + + for subset in args.valid_subset.split(","): + try: + task.load_dataset(subset, combine=False, epoch=1) + dataset = task.dataset(subset) + except KeyError: + raise Exception(f"Cannot find dataset: {subset}") + + if not os.path.exists(args.results_path): + os.makedirs(args.results_path) + try: + fname = args.path.split("/")[-2] + except: + fname = 'infer' + save_path = os.path.join(args.results_path, 
f"{fname}_{subset}.out.pkl") + + # 初始化数据迭代器 + itr = task.get_batch_iterator( + dataset=dataset, + batch_size=args.batch_size, + ignore_invalid_inputs=True, + required_batch_size_multiple=args.required_batch_size_multiple, + seed=args.seed, + num_shards=data_parallel_world_size, + shard_id=data_parallel_rank, + num_workers=args.num_workers, + data_buffer_size=args.data_buffer_size, + ).next_epoch_itr(shuffle=False) + + progress = progress_bar.progress_bar( + itr, + log_format=args.log_format, + log_interval=args.log_interval, + prefix=f"valid on '{subset}' subset", + default_log_format=("tqdm" if not args.no_progress_bar else "simple"), + ) + + log_outputs = [] + for i, sample in enumerate(progress): + # MindSpore中无需显式移动数据到设备(自动根据context处理) + # 若数据为numpy格式,转换为MindSpore张量 + if isinstance(sample, dict): + sample = {k: Tensor(v) if not isinstance(v, Tensor) else v for k, v in sample.items()} + if len(sample) == 0: + continue + # 验证步骤(MindSpore中推理无需特殊梯度设置) + _, _, log_output = task.valid_step(sample, model, loss, test=True) + progress.log({}, step=i) + log_outputs.append(log_output) + + pickle.dump(log_outputs, open(save_path, "wb")) + logger.info("Done inference! ") + return None + + +def cli_main(): + parser = options.get_validation_parser() + options.add_model_args(parser) + args = options.parse_args_and_arch(parser) + + # 单卡环境,无需分布式调用 + main(args) + + +if __name__ == "__main__": + cli_main() +# import logging +# import os +# import sys +# import pickle +# import torch +# from unicore import checkpoint_utils, distributed_utils, options, utils +# from unicore.logging import progress_bar +# from unicore import tasks + +# logging.basicConfig( +# format="%(asctime)s | %(levelname)s | %(name)s | %(message)s", +# datefmt="%Y-%m-%d %H:%M:%S", +# level=os.environ.get("LOGLEVEL", "INFO").upper(), +# stream=sys.stdout, +# ) +# logger = logging.getLogger("unimol.inference") + + +# def main(args): + +# assert ( +# args.batch_size is not None +# ), "Must specify batch size either with --batch-size" + +# use_fp16 = args.fp16 +# use_cuda = torch.cuda.is_available() and not args.cpu + +# if use_cuda: +# torch.cuda.set_device(args.device_id) + +# if args.distributed_world_size > 1: +# data_parallel_world_size = distributed_utils.get_data_parallel_world_size() +# data_parallel_rank = distributed_utils.get_data_parallel_rank() +# else: +# data_parallel_world_size = 1 +# data_parallel_rank = 0 + +# # Load model +# logger.info("loading model(s) from {}".format(args.path)) +# state = checkpoint_utils.load_checkpoint_to_cpu(args.path) +# task = tasks.setup_task(args) +# model = task.build_model(args) +# model.load_state_dict(state["model"], strict=False) + +# # Move models to GPU +# if use_cuda: +# model.cuda() +# # fp16 only supported on CUDA for fused kernels +# if use_fp16: +# model.half() + +# # Print args +# logger.info(args) + +# # Build loss +# loss = task.build_loss(args) +# loss.eval() + +# for subset in args.valid_subset.split(","): +# try: +# task.load_dataset(subset, combine=False, epoch=1) +# dataset = task.dataset(subset) +# except KeyError: +# raise Exception("Cannot find dataset: " + subset) + +# if not os.path.exists(args.results_path): +# os.makedirs(args.results_path) +# try: +# fname = (args.path).split("/")[-2] +# except: +# fname = 'infer' +# save_path = os.path.join(args.results_path, fname + "_" + subset + ".out.pkl") +# # Initialize data iterator +# itr = task.get_batch_iterator( +# dataset=dataset, +# batch_size=args.batch_size, +# ignore_invalid_inputs=True, +# 
required_batch_size_multiple=args.required_batch_size_multiple, +# seed=args.seed, +# num_shards=data_parallel_world_size, +# shard_id=data_parallel_rank, +# num_workers=args.num_workers, +# data_buffer_size=args.data_buffer_size, +# ).next_epoch_itr(shuffle=False) +# progress = progress_bar.progress_bar( +# itr, +# log_format=args.log_format, +# log_interval=args.log_interval, +# prefix=f"valid on '{subset}' subset", +# default_log_format=("tqdm" if not args.no_progress_bar else "simple"), +# ) +# log_outputs = [] +# for i, sample in enumerate(progress): +# sample = utils.move_to_cuda(sample) if use_cuda else sample +# if len(sample) == 0: +# continue +# _, _, log_output = task.valid_step(sample, model, loss, test=True) +# progress.log({}, step=i) +# log_outputs.append(log_output) +# pickle.dump(log_outputs, open(save_path, "wb")) +# logger.info("Done inference! ") +# return None + + +# def cli_main(): +# parser = options.get_validation_parser() +# options.add_model_args(parser) +# args = options.parse_args_and_arch(parser) + +# distributed_utils.call_main(args, main) + + +# if __name__ == "__main__": +# cli_main() diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/__init__.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..dcf6c62a83aedf027719b8930ca2cacff858f802 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/__init__.py @@ -0,0 +1,7 @@ +from pathlib import Path +import importlib + +# automatically import any Python files in the criterions/ directory +for file in sorted(Path(__file__).parent.glob("*.py")): + if not file.name.startswith("_"): + importlib.import_module("unimol.losses." + file.name[:-3]) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/conf_gen.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/conf_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..f49104e8afdd9b852de5a191d7f08d9ff34a8155 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/conf_gen.py @@ -0,0 +1,288 @@ +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
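+#
+# MindSpore port of the conformation-generation loss; the original PyTorch
+# implementation is kept below as a commented reference. Conversion
+# conventions used in this file:
+#   torch.nn.functional          -> mindspore.ops (imported as F)
+#   Tensor.size() / dim=         -> Tensor.shape / axis=
+#   Tensor.float()               -> Tensor.astype(ms.float32)
+#   .detach().cpu().numpy()      -> .asnumpy()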
+import mindspore as ms
+import mindspore.ops as F
+import numpy as np
+from unicore import metrics
+from unicore.losses import UnicoreLoss, register_loss
+from scipy.spatial.transform import Rotation as R
+
+
+@register_loss("mol_confG")
+class MolConfGLoss(UnicoreLoss):
+    def __init__(self, task):
+        super().__init__(task)
+        self.padding_idx = task.dictionary.pad()
+        self.eos_idx = task.dictionary.eos()
+        self.bos_idx = task.dictionary.bos()
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample."""
+        net_output = model(**sample["net_input"])
+        distance_loss, coord_loss = self.compute_loss(
+            model, net_output, sample, reduce=reduce
+        )
+        # Batch size via the MindSpore .shape attribute (replaces torch .size())
+        sample_size = sample["target"]["coord_target"].shape[0]
+        loss = (
+            self.args.coord_loss * coord_loss + self.args.distance_loss * distance_loss
+        )
+        # The .data attribute is dropped; MindSpore tensors are stored directly
+        logging_output = {
+            "loss": loss,
+            "distance_loss": distance_loss,
+            "coord_loss": coord_loss,
+            "bsz": sample["target"]["coord_target"].shape[0],
+            "sample_size": 1,
+            "coord_predict": net_output[-1],
+            "coord_target": sample["target"]["coord_target"],
+            "distance_predict": net_output[0],
+        }
+        if not self.training:
+            logging_output["smi_name"] = sample["smi_name"]
+
+        return loss, 1, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
+        distance_predict, coord_predict = net_output[0], net_output[-1]
+        # Tensor.ne behaves the same in MindSpore as torch.ne
+        token_mask = sample["net_input"]["src_tokens"].ne(self.padding_idx)  # B,L
+        token_mask &= sample["net_input"]["src_tokens"].ne(self.eos_idx)
+        token_mask &= sample["net_input"]["src_tokens"].ne(self.bos_idx)
+        distance_mask, coord_mask = calc_mask(token_mask)
+
+        # MindSpore Tensor.sum takes axis/keepdims instead of torch's dim/keepdim
+        mean_coord = (coord_mask * coord_predict).sum(axis=1) / token_mask.sum(
+            axis=1, keepdims=True
+        )
+        coord_predict = coord_predict - mean_coord.unsqueeze(dim=1)
+
+        # Distance loss; mindspore.ops.l1_loss matches torch.nn.functional.l1_loss
+        distance_predict = distance_predict[distance_mask]
+        distance_target = sample["target"]["distance_target"][distance_mask]
+        distance_loss = F.l1_loss(
+            distance_predict.astype(ms.float32),
+            distance_target.astype(ms.float32),
+            reduction="mean",
+        )
+
+        # Coordinate loss
+        coord_target = sample["target"]["coord_target"]  # B, L, 3
+        new_coord_target = realign_coord(coord_predict, coord_target, token_mask)
+        coord_predict = coord_predict[coord_mask]
+        new_coord_target = new_coord_target[coord_mask]
+        coord_loss = F.l1_loss(
+            coord_predict.astype(ms.float32),
+            new_coord_target.astype(ms.float32),
+            reduction="mean",
+        )
+
+        return distance_loss, coord_loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=5)
+        distance_loss = sum(log.get("distance_loss", 0) for log in logging_outputs)
+        if distance_loss > 0:
+            metrics.log_scalar(
+                "distance_loss", distance_loss / sample_size, sample_size, round=5
+            )
+        coord_loss = sum(log.get("coord_loss", 0) for log in logging_outputs)
+        if coord_loss > 0:
+            metrics.log_scalar(
+                "coord_loss", coord_loss / sample_size, sample_size, round=5
+            )
+
+    @staticmethod
+    def logging_outputs_can_be_summed(is_train) -> bool:
+        """Whether logging outputs can be summed across workers."""
+        return is_train
+
+
+def realign_coord(coord_predict, coord_target, token_mask):
+    # torch.zeros_like replaced with mindspore.ops.zeros_like
+    new_coord_target = ms.ops.zeros_like(coord_target).astype(coord_target.dtype)
+    bs = token_mask.shape[0]
+
+    for i in range(bs):
+        _coord_predict = coord_predict[i]
+        _coord_target = coord_target[i]
+        _token_mask = token_mask[i]
+
+        # .detach().cpu().numpy() replaced with .asnumpy() (MindSpore tensor to numpy)
+        _coord_predict = _coord_predict[_token_mask].asnumpy()
+        _coord_target = _coord_target[_token_mask].asnumpy()
+
+        _coord_predict = _coord_predict - _coord_predict.mean(axis=0)
+        _coord_target = _coord_target - _coord_target.mean(axis=0)
+
+        _r = (
+            R.align_vectors(_coord_target, _coord_predict)[0]
+            .as_matrix()
+            .astype(np.float32)
+        )
+        # torch.from_numpy replaced with mindspore.Tensor, keeping the target dtype
+        _new_coord_target = ms.Tensor(np.dot(_coord_target, _r), dtype=coord_target.dtype)
+        new_coord_target[i, _token_mask, :] = _new_coord_target
+
+    return new_coord_target
+
+
+def calc_mask(token_mask):
+    sz = token_mask.shape
+    # torch.zeros replaced with mindspore.ops.zeros, keeping the same dtype
+    distance_mask = ms.ops.zeros((sz[0], sz[1], sz[1]), dtype=token_mask.dtype)
+    distance_mask = token_mask.unsqueeze(-1) & token_mask.unsqueeze(1)
+    coord_mask = ms.ops.zeros((sz[0], sz[1], 3), dtype=token_mask.dtype)
+    # In-place masked_fill_ replaced with the out-of-place masked_fill
+    coord_mask = coord_mask.masked_fill(token_mask.unsqueeze(-1), ms.Tensor(True, dtype=ms.bool_))
+    return distance_mask, coord_mask
+# import torch
+# import torch.nn.functional as F
+# import numpy as np
+# from unicore import metrics
+# from unicore.losses import UnicoreLoss, register_loss
+# from scipy.spatial.transform import Rotation as R
+
+
+# @register_loss("mol_confG")
+# class MolConfGLoss(UnicoreLoss):
+#     def __init__(self, task):
+#         super().__init__(task)
+#         self.padding_idx = task.dictionary.pad()
+#         self.eos_idx = task.dictionary.eos()
+#         self.bos_idx = task.dictionary.bos()
+
+#     def forward(self, model, sample, reduce=True):
+#         """Compute the loss for the given sample.
+# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# net_output = model(**sample["net_input"]) +# distance_loss, coord_loss = self.compute_loss( +# model, net_output, sample, reduce=reduce +# ) +# sample_size = sample["target"]["coord_target"].size(0) +# loss = ( +# self.args.coord_loss * coord_loss + self.args.distance_loss * distance_loss +# ) +# logging_output = { +# "loss": loss.data, +# "distance_loss": distance_loss.data, +# "coord_loss": coord_loss.data, +# "bsz": sample["target"]["coord_target"].size(0), +# "sample_size": 1, +# "coord_predict": net_output[-1].data, +# "coord_target": sample["target"]["coord_target"].data, +# "distance_predict": net_output[0].data, +# } +# if not self.training: +# logging_output["smi_name"] = sample["smi_name"] + +# return loss, 1, logging_output + +# # reaglin coord in coord loss +# def compute_loss(self, model, net_output, sample, reduce=True): +# distance_predict, coord_predict = net_output[0], net_output[-1] +# token_mask = sample["net_input"]["src_tokens"].ne(self.padding_idx) # B,L +# token_mask &= sample["net_input"]["src_tokens"].ne(self.eos_idx) +# token_mask &= sample["net_input"]["src_tokens"].ne(self.bos_idx) +# distance_mask, coord_mask = calc_mask(token_mask) +# mean_coord = (coord_mask * coord_predict).sum(dim=1) / token_mask.sum( +# dim=1, keepdims=True +# ) +# coord_predict = coord_predict - mean_coord.unsqueeze(dim=1) + +# # distance loss +# distance_predict = distance_predict[distance_mask] +# distance_target = sample["target"]["distance_target"][distance_mask] +# distance_loss = F.l1_loss( +# distance_predict.float(), +# distance_target.float(), +# reduction="mean", +# ) + +# # coord loss +# coord_target = sample["target"]["coord_target"] # B, L, 3 +# new_coord_target = realign_coord(coord_predict, coord_target, token_mask) +# coord_predict = coord_predict[coord_mask] +# new_coord_target = new_coord_target[coord_mask] +# coord_loss = F.l1_loss( +# coord_predict.float(), +# new_coord_target.float(), +# reduction="mean", +# ) + +# return distance_loss, coord_loss + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + +# metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=5) +# distance_loss = sum(log.get("distance_loss", 0) for log in logging_outputs) +# if distance_loss > 0: +# metrics.log_scalar( +# "distance_loss", distance_loss / sample_size, sample_size, round=5 +# ) +# coord_loss = sum(log.get("coord_loss", 0) for log in logging_outputs) +# if coord_loss > 0: +# metrics.log_scalar( +# "coord_loss", coord_loss / sample_size, sample_size, round=5 +# ) + +# @staticmethod +# def logging_outputs_can_be_summed(is_train) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. 
+# """ +# return is_train + + +# def realign_coord(coord_predict, coord_target, token_mask): +# new_coord_target = torch.zeros_like(coord_target).type_as(coord_target) +# bs = token_mask.size(0) + +# for i in range(bs): +# _coord_predict = coord_predict[i] +# _coord_target = coord_target[i] +# _token_mask = token_mask[i] + +# _coord_predict = _coord_predict[_token_mask].detach().cpu().numpy() +# _coord_target = _coord_target[_token_mask].detach().cpu().numpy() + +# _coord_predict = _coord_predict - _coord_predict.mean(axis=0) +# _coord_target = _coord_target - _coord_target.mean(axis=0) + +# _r = ( +# R.align_vectors(_coord_target, _coord_predict)[0] +# .as_matrix() +# .astype(np.float32) +# ) +# _new_coord_target = torch.from_numpy(np.dot(_coord_target, _r)).type_as( +# coord_target +# ) +# new_coord_target[i, _token_mask, :] = _new_coord_target + +# return new_coord_target + + +# def calc_mask(token_mask): +# sz = token_mask.size() +# distance_mask = torch.zeros(sz[0], sz[1], sz[1]).type_as(token_mask) +# distance_mask = token_mask.unsqueeze(-1) & token_mask.unsqueeze(1) +# coord_mask = torch.zeros(sz[0], sz[1], 3).type_as(token_mask) +# coord_mask.masked_fill_(token_mask.unsqueeze(-1), True) +# return distance_mask, coord_mask diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/cross_entropy.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/cross_entropy.py new file mode 100644 index 0000000000000000000000000000000000000000..68299a95a814cd9930875a8b11c456ec79d18ebd --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/cross_entropy.py @@ -0,0 +1,573 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
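+#
+# Shape conventions assumed by the fine-tuning losses below (illustrative):
+#   logits: (B, C), targets: (B,)
+#   lprobs = F.log_softmax(logits.astype(ms.float32), axis=-1)
+#   loss = F.nll_loss(lprobs.reshape(-1, C), targets.reshape(-1), reduction="sum")
+# Here F is mindspore.ops; torch's view()/dim= become reshape()/axis= in this port.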
+import math +import mindspore as ms +import mindspore.ops as F +import pandas as pd +from unicore import metrics +from unicore.losses import UnicoreLoss, register_loss +from unicore.losses.cross_entropy import CrossEntropyLoss +from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score +import numpy as np +import warnings + + +@register_loss("finetune_cross_entropy") +class FinetuneCrossEntropyLoss(CrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample.""" + net_output = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.args.classification_head_name, + ) + logit_output = net_output[0] + loss = self.compute_loss(model, logit_output, sample, reduce=reduce) + # 替换size()为shape + sample_size = sample["target"]["finetune_target"].shape[0] + if not self.training: + # 替换softmax和view,view改为reshape + probs = F.softmax(logit_output.astype(ms.float32), axis=-1).reshape( + -1, logit_output.shape[-1] + ) + logging_output = { + "loss": loss, # 移除.data + "prob": probs, # 移除.data + "target": sample["target"]["finetune_target"].reshape(-1), # 移除.data,view改为reshape + "smi_name": sample["smi_name"], + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].shape[0], + } + else: + logging_output = { + "loss": loss, # 移除.data + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].shape[0], + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True): + # 替换log_softmax和view,view改为reshape + lprobs = F.log_softmax(net_output.astype(ms.float32), axis=-1) + lprobs = lprobs.reshape(-1, lprobs.shape[-1]) + targets = sample["target"]["finetune_target"].reshape(-1) # view改为reshape + loss = F.nll_loss( + lprobs, + targets, + reduction="sum" if reduce else "none", + ) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + # 保持日志逻辑不变 + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + # 替换argmax和torch.cat,使用ms.ops.argmax和ms.ops.cat + acc_sum = sum( + sum(ms.ops.argmax(log.get("prob"), axis=-1) == log.get("target")) + for log in logging_outputs + ) + probs = ms.ops.cat([log.get("prob") for log in logging_outputs], axis=0) + metrics.log_scalar( + f"{split}_acc", acc_sum / sample_size, sample_size, round=3 + ) + if probs.shape[-1] == 2: + # 二分类任务,替换torch.cat + targets = ms.ops.cat( + [log.get("target", 0) for log in logging_outputs], axis=0 + ) + smi_list = [ + item for log in logging_outputs for item in log.get("smi_name") + ] + # 张量转numpy用asnumpy(),移除cpu() + df = pd.DataFrame( + { + "probs": probs[:, 1].asnumpy(), + "targets": targets.asnumpy(), + "smi": smi_list, + } + ) + auc = roc_auc_score(df["targets"], df["probs"]) + df = df.groupby("smi").mean() + agg_auc = roc_auc_score(df["targets"], df["probs"]) + metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3) + metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + return is_train + + +@register_loss("multi_task_BCE") +class MultiTaskBCELoss(CrossEntropyLoss): + def __init__(self, task): + 
super().__init__(task) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample.""" + net_output = model( + **sample["net_input"], + masked_tokens=None, + features_only=True, + classification_head_name=self.args.classification_head_name, + ) + logit_output = net_output[0] + is_valid = sample["target"]["finetune_target"] > -0.5 + loss = self.compute_loss( + model, logit_output, sample, reduce=reduce, is_valid=is_valid + ) + # 替换size()为shape + sample_size = sample["target"]["finetune_target"].shape[0] + if not self.training: + # 替换sigmoid和view,view改为reshape + probs = ms.ops.sigmoid(logit_output.astype(ms.float32)).reshape( + -1, logit_output.shape[-1] + ) + logging_output = { + "loss": loss, # 移除.data + "prob": probs, # 移除.data + "target": sample["target"]["finetune_target"].reshape(-1), # 移除.data,view改为reshape + "num_task": self.args.num_classes, + "sample_size": sample_size, + "conf_size": self.args.conf_size, + "bsz": sample["target"]["finetune_target"].shape[0], + } + else: + logging_output = { + "loss": loss, # 移除.data + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].shape[0], + } + return loss, sample_size, logging_output + + def compute_loss(self, model, net_output, sample, reduce=True, is_valid=None): + # 替换类型转换 + pred = net_output[is_valid].astype(ms.float32) + targets = sample["target"]["finetune_target"][is_valid].astype(ms.float32) + loss = F.binary_cross_entropy_with_logits( + pred, + targets, + reduction="sum" if reduce else "none", + ) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + agg_auc_list = [] + num_task = logging_outputs[0].get("num_task", 0) + conf_size = logging_outputs[0].get("conf_size", 0) + # 替换torch.cat和view,view改为reshape,cpu().numpy()改为asnumpy() + y_true = ( + ms.ops.cat([log.get("target", 0) for log in logging_outputs], axis=0) + .reshape(-1, conf_size, num_task) + .asnumpy() + .mean(axis=1) + ) + y_pred = ( + ms.ops.cat([log.get("prob") for log in logging_outputs], axis=0) + .reshape(-1, conf_size, num_task) + .asnumpy() + .mean(axis=1) + ) + # 计算AUC逻辑不变 + for i in range(y_true.shape[1]): + if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: + is_labeled = y_true[:, i] > -0.5 + agg_auc_list.append( + roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i]) + ) + if len(agg_auc_list) < y_true.shape[1]: + warnings.warn("Some target is missing!") + if len(agg_auc_list) == 0: + raise RuntimeError( + "No positively labeled data available. Cannot compute Average Precision." 
+ ) + agg_auc = sum(agg_auc_list) / len(agg_auc_list) + metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + return is_train + + +@register_loss("finetune_cross_entropy_pocket") +class FinetuneCrossEntropyPocketLoss(FinetuneCrossEntropyLoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample.""" + net_output = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.args.classification_head_name, + ) + logit_output = net_output[0] + loss = self.compute_loss(model, logit_output, sample, reduce=reduce) + # 替换size()为shape + sample_size = sample["target"]["finetune_target"].shape[0] + if not self.training: + # 替换softmax和view,view改为reshape + probs = F.softmax(logit_output.astype(ms.float32), axis=-1).reshape( + -1, logit_output.shape[-1] + ) + logging_output = { + "loss": loss, # 移除.data + "prob": probs, # 移除.data + "target": sample["target"]["finetune_target"].reshape(-1), # 移除.data,view改为reshape + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].shape[0], + } + else: + logging_output = { + "loss": loss, # 移除.data + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].shape[0], + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + # 替换argmax和sum逻辑 + acc_sum = sum( + sum(ms.ops.argmax(log.get("prob"), axis=-1) == log.get("target")) + for log in logging_outputs + ) + metrics.log_scalar( + f"{split}_acc", acc_sum / sample_size, sample_size, round=3 + ) + # 替换torch.cat、argmax和cpu().numpy() + preds = ( + ms.ops.cat( + [ms.ops.argmax(log.get("prob"), axis=-1) for log in logging_outputs], axis=0 + ) + .asnumpy() + ) + targets = ( + ms.ops.cat([log.get("target", 0) for log in logging_outputs], axis=0) + .asnumpy() + ) + metrics.log_scalar(f"{split}_pre", precision_score(targets, preds), round=3) + metrics.log_scalar(f"{split}_rec", recall_score(targets, preds), round=3) + metrics.log_scalar( + f"{split}_f1", f1_score(targets, preds), sample_size, round=3 + ) +# import math +# import torch +# import torch.nn.functional as F +# import pandas as pd +# from unicore import metrics +# from unicore.losses import UnicoreLoss, register_loss +# from unicore.losses.cross_entropy import CrossEntropyLoss +# from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score +# import numpy as np +# import warnings + + +# @register_loss("finetune_cross_entropy") +# class FinetuneCrossEntropyLoss(CrossEntropyLoss): +# def __init__(self, task): +# super().__init__(task) + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. 
+ +# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# net_output = model( +# **sample["net_input"], +# features_only=True, +# classification_head_name=self.args.classification_head_name, +# ) +# logit_output = net_output[0] +# loss = self.compute_loss(model, logit_output, sample, reduce=reduce) +# sample_size = sample["target"]["finetune_target"].size(0) +# if not self.training: +# probs = F.softmax(logit_output.float(), dim=-1).view( +# -1, logit_output.size(-1) +# ) +# logging_output = { +# "loss": loss.data, +# "prob": probs.data, +# "target": sample["target"]["finetune_target"].view(-1).data, +# "smi_name": sample["smi_name"], +# "sample_size": sample_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# else: +# logging_output = { +# "loss": loss.data, +# "sample_size": sample_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# return loss, sample_size, logging_output + +# def compute_loss(self, model, net_output, sample, reduce=True): +# lprobs = F.log_softmax(net_output.float(), dim=-1) +# lprobs = lprobs.view(-1, lprobs.size(-1)) +# targets = sample["target"]["finetune_target"].view(-1) +# loss = F.nll_loss( +# lprobs, +# targets, +# reduction="sum" if reduce else "none", +# ) +# return loss + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) +# # we divide by log(2) to convert the loss from base e to base 2 +# metrics.log_scalar( +# "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 +# ) +# if "valid" in split or "test" in split: +# acc_sum = sum( +# sum(log.get("prob").argmax(dim=-1) == log.get("target")) +# for log in logging_outputs +# ) +# probs = torch.cat([log.get("prob") for log in logging_outputs], dim=0) +# metrics.log_scalar( +# f"{split}_acc", acc_sum / sample_size, sample_size, round=3 +# ) +# if probs.size(-1) == 2: +# # binary classification task, add auc score +# targets = torch.cat( +# [log.get("target", 0) for log in logging_outputs], dim=0 +# ) +# smi_list = [ +# item for log in logging_outputs for item in log.get("smi_name") +# ] +# df = pd.DataFrame( +# { +# "probs": probs[:, 1].cpu(), +# "targets": targets.cpu(), +# "smi": smi_list, +# } +# ) +# auc = roc_auc_score(df["targets"], df["probs"]) +# df = df.groupby("smi").mean() +# agg_auc = roc_auc_score(df["targets"], df["probs"]) +# metrics.log_scalar(f"{split}_auc", auc, sample_size, round=3) +# metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + +# @staticmethod +# def logging_outputs_can_be_summed(is_train) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. +# """ +# return is_train + + +# @register_loss("multi_task_BCE") +# class MultiTaskBCELoss(CrossEntropyLoss): +# def __init__(self, task): +# super().__init__(task) + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. 
+# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# net_output = model( +# **sample["net_input"], +# masked_tokens=None, +# features_only=True, +# classification_head_name=self.args.classification_head_name, +# ) +# logit_output = net_output[0] +# is_valid = sample["target"]["finetune_target"] > -0.5 +# loss = self.compute_loss( +# model, logit_output, sample, reduce=reduce, is_valid=is_valid +# ) +# sample_size = sample["target"]["finetune_target"].size(0) +# if not self.training: +# probs = torch.sigmoid(logit_output.float()).view(-1, logit_output.size(-1)) +# logging_output = { +# "loss": loss.data, +# "prob": probs.data, +# "target": sample["target"]["finetune_target"].view(-1).data, +# "num_task": self.args.num_classes, +# "sample_size": sample_size, +# "conf_size": self.args.conf_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# else: +# logging_output = { +# "loss": loss.data, +# "sample_size": sample_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# return loss, sample_size, logging_output + +# def compute_loss(self, model, net_output, sample, reduce=True, is_valid=None): +# pred = net_output[is_valid].float() +# targets = sample["target"]["finetune_target"][is_valid].float() +# loss = F.binary_cross_entropy_with_logits( +# pred, +# targets, +# reduction="sum" if reduce else "none", +# ) +# return loss + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) +# # we divide by log(2) to convert the loss from base e to base 2 +# metrics.log_scalar( +# "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 +# ) +# if "valid" in split or "test" in split: +# agg_auc_list = [] +# num_task = logging_outputs[0].get("num_task", 0) +# conf_size = logging_outputs[0].get("conf_size", 0) +# y_true = ( +# torch.cat([log.get("target", 0) for log in logging_outputs], dim=0) +# .view(-1, conf_size, num_task) +# .cpu() +# .numpy() +# .mean(axis=1) +# ) +# y_pred = ( +# torch.cat([log.get("prob") for log in logging_outputs], dim=0) +# .view(-1, conf_size, num_task) +# .cpu() +# .numpy() +# .mean(axis=1) +# ) +# # [test_size, num_classes] = [test_size * conf_size, num_classes].mean(axis=1) +# for i in range(y_true.shape[1]): +# # AUC is only defined when there is at least one positive data. +# if np.sum(y_true[:, i] == 1) > 0 and np.sum(y_true[:, i] == 0) > 0: +# # ignore nan values +# is_labeled = y_true[:, i] > -0.5 +# agg_auc_list.append( +# roc_auc_score(y_true[is_labeled, i], y_pred[is_labeled, i]) +# ) +# if len(agg_auc_list) < y_true.shape[1]: +# warnings.warn("Some target is missing!") +# if len(agg_auc_list) == 0: +# raise RuntimeError( +# "No positively labeled data available. Cannot compute Average Precision." +# ) +# agg_auc = sum(agg_auc_list) / len(agg_auc_list) +# metrics.log_scalar(f"{split}_agg_auc", agg_auc, sample_size, round=4) + +# @staticmethod +# def logging_outputs_can_be_summed(is_train) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. 
+# """ +# return is_train + + +# @register_loss("finetune_cross_entropy_pocket") +# class FinetuneCrossEntropyPocketLoss(FinetuneCrossEntropyLoss): +# def __init__(self, task): +# super().__init__(task) + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. + +# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# net_output = model( +# **sample["net_input"], +# features_only=True, +# classification_head_name=self.args.classification_head_name, +# ) +# logit_output = net_output[0] +# loss = self.compute_loss(model, logit_output, sample, reduce=reduce) +# sample_size = sample["target"]["finetune_target"].size(0) +# if not self.training: +# probs = F.softmax(logit_output.float(), dim=-1).view( +# -1, logit_output.size(-1) +# ) +# logging_output = { +# "loss": loss.data, +# "prob": probs.data, +# "target": sample["target"]["finetune_target"].view(-1).data, +# "sample_size": sample_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# else: +# logging_output = { +# "loss": loss.data, +# "sample_size": sample_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# return loss, sample_size, logging_output + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) +# # we divide by log(2) to convert the loss from base e to base 2 +# metrics.log_scalar( +# "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 +# ) +# if "valid" in split or "test" in split: +# acc_sum = sum( +# sum(log.get("prob").argmax(dim=-1) == log.get("target")) +# for log in logging_outputs +# ) +# metrics.log_scalar( +# f"{split}_acc", acc_sum / sample_size, sample_size, round=3 +# ) +# preds = ( +# torch.cat( +# [log.get("prob").argmax(dim=-1) for log in logging_outputs], dim=0 +# ) +# .cpu() +# .numpy() +# ) +# targets = ( +# torch.cat([log.get("target", 0) for log in logging_outputs], dim=0) +# .cpu() +# .numpy() +# ) +# metrics.log_scalar(f"{split}_pre", precision_score(targets, preds), round=3) +# metrics.log_scalar(f"{split}_rec", recall_score(targets, preds), round=3) +# metrics.log_scalar( +# f"{split}_f1", f1_score(targets, preds), sample_size, round=3 +# ) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/docking_pose.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/docking_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..2bdb30a859b9cb6a6fbb2953739436b42c8e1e03 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/docking_pose.py @@ -0,0 +1,222 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
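+#
+# The docking loss below combines two terms:
+#   * cross_loss: MSE over predicted ligand-pocket distances, keeping only
+#     targets below args.dist_threshold when that threshold is positive;
+#   * holo_loss: smooth-L1 over intra-ligand distances of the holo pose,
+#     with padding/BOS/EOS tokens masked out.
+# Both terms are computed in float32 through mindspore.ops (imported as F).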
+import mindspore as ms
+import mindspore.ops as F
+from unicore import metrics
+from unicore.losses import UnicoreLoss, register_loss
+
+
+@register_loss("docking_pose")
+class DockingPossLoss(UnicoreLoss):
+    def __init__(self, task):
+        super().__init__(task)
+        self.eos_idx = task.dictionary.eos()
+        self.bos_idx = task.dictionary.bos()
+        self.padding_idx = task.dictionary.pad()
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample."""
+        net_outputs = model(**sample["net_input"])
+        cross_distance_predict, holo_distance_predict = net_outputs[0], net_outputs[1]
+
+        ### distance loss
+        # Tensor.ne behaves the same in MindSpore as torch.ne
+        distance_mask = sample["target"]["distance_target"].ne(0)  # 0 for padding, BOS and EOS
+        # 0 is impossible in the cross distance matrix, all the relevant cross distances are kept
+        if self.args.dist_threshold > 0:
+            distance_mask &= (
+                sample["target"]["distance_target"] < self.args.dist_threshold
+            )
+        distance_predict = cross_distance_predict[distance_mask]
+        distance_target = sample["target"]["distance_target"][distance_mask]
+        # mindspore.ops.mse_loss, with casts via astype(ms.float32)
+        distance_loss = F.mse_loss(
+            distance_predict.astype(ms.float32),
+            distance_target.astype(ms.float32),
+            reduction="mean"
+        )
+
+        ### holo distance loss
+        # The logical-and operator & works the same in MindSpore
+        token_mask = sample["net_input"]["mol_src_tokens"].ne(self.padding_idx) & \
+                     sample["net_input"]["mol_src_tokens"].ne(self.eos_idx) & \
+                     sample["net_input"]["mol_src_tokens"].ne(self.bos_idx)
+        holo_distance_mask = token_mask.unsqueeze(-1) & token_mask.unsqueeze(1)
+        holo_distance_predict_train = holo_distance_predict[holo_distance_mask]
+        holo_distance_target = sample["target"]["holo_distance_target"][
+            holo_distance_mask
+        ]
+        # mindspore.ops.smooth_l1_loss with the same arguments as the torch version
+        holo_distance_loss = F.smooth_l1_loss(
+            holo_distance_predict_train.astype(ms.float32),
+            holo_distance_target.astype(ms.float32),
+            reduction="mean",
+            beta=1.0,
+        )
+
+        loss = distance_loss + holo_distance_loss
+        # .size(0) replaced with .shape[0]
+        sample_size = sample["target"]["holo_coord"].shape[0]
+        # The .data attribute is removed; MindSpore tensors are used directly
+        logging_output = {
+            "loss": loss,
+            "cross_loss": distance_loss,
+            "holo_loss": holo_distance_loss,
+            "bsz": sample_size,
+            "sample_size": 1,
+        }
+        if not self.training:
+            logging_output["smi_name"] = sample["smi_name"]
+            logging_output["pocket_name"] = sample["pocket_name"]
+            # .data.detach().cpu() replaced with ops.stop_gradient (detach); no .cpu() needed on Ascend NPU
+            logging_output["cross_distance_predict"] = ms.ops.stop_gradient(cross_distance_predict)
+            logging_output["holo_distance_predict"] = ms.ops.stop_gradient(holo_distance_predict)
+            logging_output["atoms"] = ms.ops.stop_gradient(sample["net_input"]["mol_src_tokens"])
+            logging_output["pocket_atoms"] = ms.ops.stop_gradient(sample["net_input"]["pocket_src_tokens"])
+            logging_output["holo_center_coordinates"] = ms.ops.stop_gradient(sample["holo_center_coordinates"])
+            logging_output["holo_coordinates"] = ms.ops.stop_gradient(sample["target"]["holo_coord"])
+            logging_output["pocket_coordinates"] = ms.ops.stop_gradient(sample["net_input"]["pocket_src_coord"])
+
+        return loss, 1, logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+
+        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=4)
+        metrics.log_scalar(
+            
f"{split}_loss", loss_sum / sample_size, sample_size, round=4 + ) + cross_loss_sum = sum(log.get("cross_loss", 0) for log in logging_outputs) + metrics.log_scalar( + "cross_loss", cross_loss_sum / sample_size, sample_size, round=4 + ) + holo_loss_sum = sum(log.get("holo_loss", 0) for log in logging_outputs) + metrics.log_scalar( + "holo_loss", holo_loss_sum / sample_size, sample_size, round=4 + ) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """Whether logging outputs can be summed across workers.""" + return is_train +# import torch +# import torch.nn.functional as F +# from unicore import metrics +# from unicore.losses import UnicoreLoss, register_loss + + +# @register_loss("docking_pose") +# class DockingPossLoss(UnicoreLoss): +# def __init__(self, task): +# super().__init__(task) +# self.eos_idx = task.dictionary.eos() +# self.bos_idx = task.dictionary.bos() +# self.padding_idx = task.dictionary.pad() + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. + +# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# net_outputs = model(**sample["net_input"]) +# cross_distance_predict, holo_distance_predict = net_outputs[0], net_outputs[1] + +# ### distance loss +# distance_mask = sample["target"]["distance_target"].ne(0) # 0 for padding, BOS and EOS +# # 0 is impossible in the cross distance matrix, all the relevant cross distances are kept +# if self.args.dist_threshold > 0: +# distance_mask &= ( +# sample["target"]["distance_target"] < self.args.dist_threshold +# ) +# distance_predict = cross_distance_predict[distance_mask] +# distance_target = sample["target"]["distance_target"][distance_mask] +# distance_loss = F.mse_loss( +# distance_predict.float(), distance_target.float(), reduction="mean" +# ) + +# ### holo distance loss +# token_mask = sample["net_input"]["mol_src_tokens"].ne(self.padding_idx) & \ +# sample["net_input"]["mol_src_tokens"].ne(self.eos_idx) & \ +# sample["net_input"]["mol_src_tokens"].ne(self.bos_idx) +# holo_distance_mask = token_mask.unsqueeze(-1) & token_mask.unsqueeze(1) +# holo_distance_predict_train = holo_distance_predict[holo_distance_mask] +# holo_distance_target = sample["target"]["holo_distance_target"][ +# holo_distance_mask +# ] +# holo_distance_loss = F.smooth_l1_loss( +# holo_distance_predict_train.float(), +# holo_distance_target.float(), +# reduction="mean", +# beta=1.0, +# ) + +# loss = distance_loss + holo_distance_loss +# sample_size = sample["target"]["holo_coord"].size(0) +# logging_output = { +# "loss": loss.data, +# "cross_loss": distance_loss.data, +# "holo_loss": holo_distance_loss.data, +# "bsz": sample_size, +# "sample_size": 1, +# } +# if not self.training: +# logging_output["smi_name"] = sample["smi_name"] +# logging_output["pocket_name"] = sample["pocket_name"] +# logging_output[ +# "cross_distance_predict" +# ] = cross_distance_predict.data.detach().cpu() +# logging_output[ +# "holo_distance_predict" +# ] = holo_distance_predict.data.detach().cpu() +# logging_output["atoms"] = ( +# sample["net_input"]["mol_src_tokens"].data.detach().cpu() +# ) +# logging_output["pocket_atoms"] = ( +# sample["net_input"]["pocket_src_tokens"].data.detach().cpu() +# ) +# logging_output["holo_center_coordinates"] = ( +# sample["holo_center_coordinates"].data.detach().cpu() +# ) +# logging_output["holo_coordinates"] = ( +# 
sample["target"]["holo_coord"].data.detach().cpu() +# ) +# logging_output["pocket_coordinates"] = ( +# sample["net_input"]["pocket_src_coord"].data.detach().cpu() +# ) + +# return loss, 1, logging_output + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + +# metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=4) +# metrics.log_scalar( +# f"{split}_loss", loss_sum / sample_size, sample_size, round=4 +# ) +# cross_loss_sum = sum(log.get("cross_loss", 0) for log in logging_outputs) +# metrics.log_scalar( +# "cross_loss", cross_loss_sum / sample_size, sample_size, round=4 +# ) +# holo_loss_sum = sum(log.get("holo_loss", 0) for log in logging_outputs) +# metrics.log_scalar( +# "holo_loss", holo_loss_sum / sample_size, sample_size, round=4 +# ) + +# @staticmethod +# def logging_outputs_can_be_summed(is_train) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. +# """ +# return is_train diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/reg_loss.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/reg_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..54e92ace2002aaa129c9fd00e0aaf9067e64470c --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/reg_loss.py @@ -0,0 +1,503 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
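+# Target standardization used by the regression losses in this file (a minimal sketch;
+# the mean/std values are illustrative, they normally come from task.mean / task.std):
+#
+#     import mindspore as ms
+#     mean = ms.Tensor(2.0, dtype=ms.float32)
+#     std = ms.Tensor(0.5, dtype=ms.float32)
+#     targets = ms.Tensor([2.5, 1.5], dtype=ms.float32)
+#     norm_targets = (targets - mean) / std      # applied before computing the training loss
+#     preds = ms.Tensor([1.0, -1.0], dtype=ms.float32)
+#     denorm_preds = preds * std + mean          # applied to predictions before eval metrics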
+import math
+import mindspore as ms
+import mindspore.ops as F
+import pandas as pd
+import numpy as np
+from unicore import metrics
+from unicore.losses import UnicoreLoss, register_loss
+
+
+@register_loss("finetune_mse")
+class FinetuneMSELoss(UnicoreLoss):
+    def __init__(self, task):
+        super().__init__(task)
+
+    def forward(self, model, sample, reduce=True):
+        """Compute the loss for the given sample."""
+        net_output = model(
+            **sample["net_input"],
+            features_only=True,
+            classification_head_name=self.args.classification_head_name,
+        )
+        reg_output = net_output[0]
+        loss = self.compute_loss(model, reg_output, sample, reduce=reduce)
+        # size(0) is replaced by shape[0]
+        sample_size = sample["target"]["finetune_target"].shape[0]
+        if not self.training:
+            if self.task.mean and self.task.std:
+                # torch.tensor is replaced by ms.Tensor, which is placed on the device automatically
+                targets_mean = ms.Tensor(self.task.mean, dtype=ms.float32)
+                targets_std = ms.Tensor(self.task.std, dtype=ms.float32)
+                reg_output = reg_output * targets_std + targets_mean
+            # the .data attribute is dropped; view is replaced by reshape
+            logging_output = {
+                "loss": loss,
+                "predict": reg_output.reshape(-1, self.args.num_classes),
+                "target": sample["target"]["finetune_target"].reshape(-1, self.args.num_classes),
+                "smi_name": sample["smi_name"],
+                "sample_size": sample_size,
+                "num_task": self.args.num_classes,
+                "conf_size": self.args.conf_size,
+                "bsz": sample["target"]["finetune_target"].shape[0],
+            }
+        else:
+            logging_output = {
+                "loss": loss,
+                "sample_size": sample_size,
+                "bsz": sample["target"]["finetune_target"].shape[0],
+            }
+        return loss, sample_size, logging_output
+
+    def compute_loss(self, model, net_output, sample, reduce=True):
+        # view is replaced by reshape; float() becomes astype(ms.float32)
+        predicts = net_output.reshape(-1, self.args.num_classes).astype(ms.float32)
+        targets = sample["target"]["finetune_target"].reshape(-1, self.args.num_classes).astype(ms.float32)
+        if self.task.mean and self.task.std:
+            # torch.tensor is replaced by ms.Tensor, which is placed on the device automatically
+            targets_mean = ms.Tensor(self.task.mean, dtype=ms.float32)
+            targets_std = ms.Tensor(self.task.std, dtype=ms.float32)
+            targets = (targets - targets_mean) / targets_std
+        loss = F.mse_loss(
+            predicts,
+            targets,
+            reduction="sum" if reduce else "none",
+        )
+        return loss
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        metrics.log_scalar(
+            "loss", loss_sum / sample_size / math.log(2), sample_size, round=3
+        )
+        if "valid" in split or "test" in split:
+            # torch.cat is replaced by ms.ops.cat
+            predicts = ms.ops.cat([log.get("predict") for log in logging_outputs], axis=0)
+            if predicts.shape[-1] == 1:
+                targets = ms.ops.cat([log.get("target", 0) for log in logging_outputs], axis=0)
+                smi_list = [item for log in logging_outputs for item in log.get("smi_name")]
+                # cpu() is replaced by asnumpy()
+                df = pd.DataFrame(
+                    {
+                        "predict": predicts.reshape(-1).asnumpy(),
+                        "target": targets.reshape(-1).asnumpy(),
+                        "smi": smi_list,
+                    }
+                )
+                mae = np.abs(df["predict"] - df["target"]).mean()
+                mse = ((df["predict"] - df["target"]) ** 2).mean()
+                df = df.groupby("smi").mean()
+                agg_mae = np.abs(df["predict"] - df["target"]).mean()
+                agg_mse = ((df["predict"] - df["target"]) ** 2).mean()
+
+                metrics.log_scalar(f"{split}_mae", mae, sample_size, round=3)
+                metrics.log_scalar(f"{split}_mse", mse, sample_size, round=3)
+                metrics.log_scalar(f"{split}_agg_mae", agg_mae, sample_size, 
round=3) + metrics.log_scalar(f"{split}_agg_mse", agg_mse, sample_size, round=3) + metrics.log_scalar( + f"{split}_agg_rmse", np.sqrt(agg_mse), sample_size, round=4 + ) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + return is_train + + +@register_loss("finetune_mae") +class FinetuneMAELoss(FinetuneMSELoss): + def __init__(self, task): + super().__init__(task) + + def compute_loss(self, model, net_output, sample, reduce=True): + # 替换view为reshape,float()改为astype(ms.float32) + predicts = net_output.reshape(-1, self.args.num_classes).astype(ms.float32) + targets = sample["target"]["finetune_target"].reshape(-1, self.args.num_classes).astype(ms.float32) + if self.task.mean and self.task.std: + targets_mean = ms.Tensor(self.task.mean, dtype=ms.float32) + targets_std = ms.Tensor(self.task.std, dtype=ms.float32) + targets = (targets - targets_mean) / targets_std + loss = F.l1_loss( + predicts, + targets, + reduction="sum" if reduce else "none", + ) + return loss + + +@register_loss("finetune_smooth_mae") +class FinetuneSmoothMAELoss(FinetuneMSELoss): + def __init__(self, task): + super().__init__(task) + + def compute_loss(self, model, net_output, sample, reduce=True): + # 替换view为reshape,float()改为astype(ms.float32) + predicts = net_output.reshape(-1, self.args.num_classes).astype(ms.float32) + targets = sample["target"]["finetune_target"].reshape(-1, self.args.num_classes).astype(ms.float32) + if self.task.mean and self.task.std: + targets_mean = ms.Tensor(self.task.mean, dtype=ms.float32) + targets_std = ms.Tensor(self.task.std, dtype=ms.float32) + targets = (targets - targets_mean) / targets_std + loss = F.smooth_l1_loss( + predicts, + targets, + reduction="sum" if reduce else "none", + ) + return loss + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + num_task = logging_outputs[0].get("num_task", 0) + conf_size = logging_outputs[0].get("conf_size", 0) + # 替换torch.cat为ms.ops.cat,view改为reshape,cpu().numpy()改为asnumpy() + y_true = ( + ms.ops.cat([log.get("target", 0) for log in logging_outputs], axis=0) + .reshape(-1, conf_size, num_task) + .asnumpy() + .mean(axis=1) + ) + y_pred = ( + ms.ops.cat([log.get("predict") for log in logging_outputs], axis=0) + .reshape(-1, conf_size, num_task) + .asnumpy() + .mean(axis=1) + ) + agg_mae = np.abs(y_pred - y_true).mean() + metrics.log_scalar(f"{split}_agg_mae", agg_mae, sample_size, round=4) + + +@register_loss("finetune_mse_pocket") +class FinetuneMSEPocketLoss(FinetuneMSELoss): + def __init__(self, task): + super().__init__(task) + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample.""" + net_output = model( + **sample["net_input"], + features_only=True, + classification_head_name=self.args.classification_head_name, + ) + reg_output = net_output[0] + loss = self.compute_loss(model, reg_output, sample, reduce=reduce) + # 替换size(0)为shape[0] + sample_size = sample["target"]["finetune_target"].shape[0] + if not self.training: + if self.task.mean and self.task.std: + targets_mean = ms.Tensor(self.task.mean, dtype=ms.float32) + targets_std = ms.Tensor(self.task.std, dtype=ms.float32) + reg_output = reg_output * targets_std + 
targets_mean + # 移除.data,view改为reshape + logging_output = { + "loss": loss, + "predict": reg_output.reshape(-1, self.args.num_classes), + "target": sample["target"]["finetune_target"].reshape(-1, self.args.num_classes), + "sample_size": sample_size, + "num_task": self.args.num_classes, + "bsz": sample["target"]["finetune_target"].shape[0], + } + else: + logging_output = { + "loss": loss, + "sample_size": sample_size, + "bsz": sample["target"]["finetune_target"].shape[0], + } + return loss, sample_size, logging_output + + @staticmethod + def reduce_metrics(logging_outputs, split="valid") -> None: + """Aggregate logging outputs from data parallel training.""" + loss_sum = sum(log.get("loss", 0) for log in logging_outputs) + sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) + metrics.log_scalar( + "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 + ) + if "valid" in split or "test" in split: + # 替换torch.cat为ms.ops.cat + predicts = ms.ops.cat([log.get("predict") for log in logging_outputs], axis=0) + if predicts.shape[-1] == 1: + targets = ms.ops.cat([log.get("target", 0) for log in logging_outputs], axis=0) + # 替换cpu()为asnumpy() + df = pd.DataFrame( + { + "predict": predicts.reshape(-1).asnumpy(), + "target": targets.reshape(-1).asnumpy(), + } + ) + mse = ((df["predict"] - df["target"]) **2).mean() + metrics.log_scalar(f"{split}_mse", mse, sample_size, round=3) + metrics.log_scalar(f"{split}_rmse", np.sqrt(mse), sample_size, round=4) +# import math +# import torch +# import torch.nn.functional as F +# import pandas as pd +# import numpy as np +# from unicore import metrics +# from unicore.losses import UnicoreLoss, register_loss + + +# @register_loss("finetune_mse") +# class FinetuneMSELoss(UnicoreLoss): +# def __init__(self, task): +# super().__init__(task) + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. 
+ +# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# net_output = model( +# **sample["net_input"], +# features_only=True, +# classification_head_name=self.args.classification_head_name, +# ) +# reg_output = net_output[0] +# loss = self.compute_loss(model, reg_output, sample, reduce=reduce) +# sample_size = sample["target"]["finetune_target"].size(0) +# if not self.training: +# if self.task.mean and self.task.std: +# targets_mean = torch.tensor(self.task.mean, device=reg_output.device) +# targets_std = torch.tensor(self.task.std, device=reg_output.device) +# reg_output = reg_output * targets_std + targets_mean +# logging_output = { +# "loss": loss.data, +# "predict": reg_output.view(-1, self.args.num_classes).data, +# "target": sample["target"]["finetune_target"] +# .view(-1, self.args.num_classes) +# .data, +# "smi_name": sample["smi_name"], +# "sample_size": sample_size, +# "num_task": self.args.num_classes, +# "conf_size": self.args.conf_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# else: +# logging_output = { +# "loss": loss.data, +# "sample_size": sample_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# return loss, sample_size, logging_output + +# def compute_loss(self, model, net_output, sample, reduce=True): +# predicts = net_output.view(-1, self.args.num_classes).float() +# targets = ( +# sample["target"]["finetune_target"].view(-1, self.args.num_classes).float() +# ) +# if self.task.mean and self.task.std: +# targets_mean = torch.tensor(self.task.mean, device=targets.device) +# targets_std = torch.tensor(self.task.std, device=targets.device) +# targets = (targets - targets_mean) / targets_std +# loss = F.mse_loss( +# predicts, +# targets, +# reduction="sum" if reduce else "none", +# ) +# return loss + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) +# # we divide by log(2) to convert the loss from base e to base 2 +# metrics.log_scalar( +# "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 +# ) +# if "valid" in split or "test" in split: +# predicts = torch.cat([log.get("predict") for log in logging_outputs], dim=0) +# if predicts.size(-1) == 1: +# # single label regression task, add aggregate acc and loss score +# targets = torch.cat( +# [log.get("target", 0) for log in logging_outputs], dim=0 +# ) +# smi_list = [ +# item for log in logging_outputs for item in log.get("smi_name") +# ] +# df = pd.DataFrame( +# { +# "predict": predicts.view(-1).cpu(), +# "target": targets.view(-1).cpu(), +# "smi": smi_list, +# } +# ) +# mae = np.abs(df["predict"] - df["target"]).mean() +# mse = ((df["predict"] - df["target"]) ** 2).mean() +# df = df.groupby("smi").mean() +# agg_mae = np.abs(df["predict"] - df["target"]).mean() +# agg_mse = ((df["predict"] - df["target"]) ** 2).mean() + +# metrics.log_scalar(f"{split}_mae", mae, sample_size, round=3) +# metrics.log_scalar(f"{split}_mse", mse, sample_size, round=3) +# metrics.log_scalar(f"{split}_agg_mae", agg_mae, sample_size, round=3) +# metrics.log_scalar(f"{split}_agg_mse", agg_mse, sample_size, round=3) +# metrics.log_scalar( +# f"{split}_agg_rmse", np.sqrt(agg_mse), sample_size, round=4 +# ) + +# @staticmethod +# def 
logging_outputs_can_be_summed(is_train) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. +# """ +# return is_train + + +# @register_loss("finetune_mae") +# class FinetuneMAELoss(FinetuneMSELoss): +# def __init__(self, task): +# super().__init__(task) + +# def compute_loss(self, model, net_output, sample, reduce=True): +# predicts = net_output.view(-1, self.args.num_classes).float() +# targets = ( +# sample["target"]["finetune_target"].view(-1, self.args.num_classes).float() +# ) +# if self.task.mean and self.task.std: +# targets_mean = torch.tensor(self.task.mean, device=targets.device) +# targets_std = torch.tensor(self.task.std, device=targets.device) +# targets = (targets - targets_mean) / targets_std +# loss = F.l1_loss( +# predicts, +# targets, +# reduction="sum" if reduce else "none", +# ) +# return loss + + +# @register_loss("finetune_smooth_mae") +# class FinetuneSmoothMAELoss(FinetuneMSELoss): +# def __init__(self, task): +# super().__init__(task) + +# def compute_loss(self, model, net_output, sample, reduce=True): +# predicts = net_output.view(-1, self.args.num_classes).float() +# targets = ( +# sample["target"]["finetune_target"].view(-1, self.args.num_classes).float() +# ) +# if self.task.mean and self.task.std: +# targets_mean = torch.tensor(self.task.mean, device=targets.device) +# targets_std = torch.tensor(self.task.std, device=targets.device) +# targets = (targets - targets_mean) / targets_std +# loss = F.smooth_l1_loss( +# predicts, +# targets, +# reduction="sum" if reduce else "none", +# ) +# return loss + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) +# # we divide by log(2) to convert the loss from base e to base 2 +# metrics.log_scalar( +# "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 +# ) +# if "valid" in split or "test" in split: +# num_task = logging_outputs[0].get("num_task", 0) +# conf_size = logging_outputs[0].get("conf_size", 0) +# y_true = ( +# torch.cat([log.get("target", 0) for log in logging_outputs], dim=0) +# .view(-1, conf_size, num_task) +# .cpu() +# .numpy() +# .mean(axis=1) +# ) +# y_pred = ( +# torch.cat([log.get("predict") for log in logging_outputs], dim=0) +# .view(-1, conf_size, num_task) +# .cpu() +# .numpy() +# .mean(axis=1) +# ) +# agg_mae = np.abs(y_pred - y_true).mean() +# metrics.log_scalar(f"{split}_agg_mae", agg_mae, sample_size, round=4) + + +# @register_loss("finetune_mse_pocket") +# class FinetuneMSEPocketLoss(FinetuneMSELoss): +# def __init__(self, task): +# super().__init__(task) + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. 
+ +# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# net_output = model( +# **sample["net_input"], +# features_only=True, +# classification_head_name=self.args.classification_head_name, +# ) +# reg_output = net_output[0] +# loss = self.compute_loss(model, reg_output, sample, reduce=reduce) +# sample_size = sample["target"]["finetune_target"].size(0) +# if not self.training: +# if self.task.mean and self.task.std: +# targets_mean = torch.tensor(self.task.mean, device=reg_output.device) +# targets_std = torch.tensor(self.task.std, device=reg_output.device) +# reg_output = reg_output * targets_std + targets_mean +# logging_output = { +# "loss": loss.data, +# "predict": reg_output.view(-1, self.args.num_classes).data, +# "target": sample["target"]["finetune_target"] +# .view(-1, self.args.num_classes) +# .data, +# "sample_size": sample_size, +# "num_task": self.args.num_classes, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# else: +# logging_output = { +# "loss": loss.data, +# "sample_size": sample_size, +# "bsz": sample["target"]["finetune_target"].size(0), +# } +# return loss, sample_size, logging_output + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) +# # we divide by log(2) to convert the loss from base e to base 2 +# metrics.log_scalar( +# "loss", loss_sum / sample_size / math.log(2), sample_size, round=3 +# ) +# if "valid" in split or "test" in split: +# predicts = torch.cat([log.get("predict") for log in logging_outputs], dim=0) +# if predicts.size(-1) == 1: +# # single label regression task +# targets = torch.cat( +# [log.get("target", 0) for log in logging_outputs], dim=0 +# ) +# df = pd.DataFrame( +# { +# "predict": predicts.view(-1).cpu(), +# "target": targets.view(-1).cpu(), +# } +# ) +# mse = ((df["predict"] - df["target"]) ** 2).mean() +# metrics.log_scalar(f"{split}_mse", mse, sample_size, round=3) +# metrics.log_scalar(f"{split}_rmse", np.sqrt(mse), sample_size, round=4) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/unimol.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/unimol.py new file mode 100644 index 0000000000000000000000000000000000000000..789e2c96e21c03833cc23071867482a6d06946b7 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/losses/unimol.py @@ -0,0 +1,450 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
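+# Masked-token loss pattern used below (a minimal sketch; shapes and values are
+# illustrative, with two masked positions and a four-token vocabulary):
+#
+#     import mindspore as ms
+#     import mindspore.ops as F
+#     logits = ms.ops.randn(2, 4)                   # predictions at the masked positions
+#     target = ms.Tensor([1, 3], dtype=ms.int32)    # ground-truth token ids
+#     log_probs = F.log_softmax(logits, axis=-1).astype(ms.float32)
+#     loss = F.nll_loss(log_probs, target, reduction="mean")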
+import mindspore as ms
+import mindspore.ops as F
+from unicore import metrics
+from unicore.losses import UnicoreLoss, register_loss
+
+
+@register_loss("unimol")
+class UniMolLoss(UnicoreLoss):
+    def __init__(self, task):
+        super().__init__(task)
+        self.padding_idx = task.dictionary.pad()
+        self.seed = task.seed
+        self.dist_mean = 6.312581655060595
+        self.dist_std = 3.3899264663911888
+
+    def forward(self, model, sample, reduce=True):
+        input_key = "net_input"
+        target_key = "target"
+        # ne works the same way in MindSpore
+        masked_tokens = sample[target_key]["tokens_target"].ne(self.padding_idx)
+        # long() is replaced by astype(ms.int64); sum is unchanged
+        sample_size = masked_tokens.astype(ms.int64).sum()
+        (
+            logits_encoder,
+            encoder_distance,
+            encoder_coord,
+            x_norm,
+            delta_encoder_pair_rep_norm,
+        ) = model(**sample[input_key], encoder_masked_tokens=masked_tokens)
+        target = sample[target_key]["tokens_target"]
+        if masked_tokens is not None:
+            target = target[masked_tokens]
+        # log_softmax takes axis instead of dim; the cast uses astype(ms.float32)
+        masked_token_loss = F.nll_loss(
+            F.log_softmax(logits_encoder, axis=-1).astype(ms.float32),
+            target,
+            ignore_index=self.padding_idx,
+            reduction="mean",
+        )
+        # argmax takes axis instead of dim
+        masked_pred = ms.ops.argmax(logits_encoder, axis=-1)
+        # long() is replaced by astype(ms.int64)
+        masked_hit = (masked_pred == target).astype(ms.int64).sum()
+        masked_cnt = sample_size
+        loss = masked_token_loss * self.args.masked_token_loss
+        # size() is replaced by shape; the .data attribute is dropped
+        logging_output = {
+            "sample_size": 1,
+            "bsz": sample[target_key]["tokens_target"].shape[0],
+            "seq_len": sample[target_key]["tokens_target"].shape[1]
+            * sample[target_key]["tokens_target"].shape[0],
+            "masked_token_loss": masked_token_loss,
+            "masked_token_hit": masked_hit,
+            "masked_token_cnt": masked_cnt,
+        }
+
+        if encoder_coord is not None:
+            coord_target = sample[target_key]["coord_target"]
+            # view is replaced by reshape; float() becomes astype(ms.float32)
+            masked_coord_loss = F.smooth_l1_loss(
+                encoder_coord[masked_tokens].reshape(-1, 3).astype(ms.float32),
+                coord_target[masked_tokens].reshape(-1, 3),
+                reduction="mean",
+                beta=1.0,
+            )
+            loss = loss + masked_coord_loss * self.args.masked_coord_loss
+            logging_output["masked_coord_loss"] = masked_coord_loss
+
+        if encoder_distance is not None:
+            dist_masked_tokens = masked_tokens
+            masked_dist_loss = self.cal_dist_loss(
+                sample, encoder_distance, dist_masked_tokens, target_key, normalize=True
+            )
+            loss = loss + masked_dist_loss * self.args.masked_dist_loss
+            logging_output["masked_dist_loss"] = masked_dist_loss
+
+        if self.args.x_norm_loss > 0 and x_norm is not None:
+            loss = loss + self.args.x_norm_loss * x_norm
+            logging_output["x_norm_loss"] = x_norm
+
+        if (
+            self.args.delta_pair_repr_norm_loss > 0
+            and delta_encoder_pair_rep_norm is not None
+        ):
+            loss = (
+                loss + self.args.delta_pair_repr_norm_loss * delta_encoder_pair_rep_norm
+            )
+            logging_output[
+                "delta_pair_repr_norm_loss"
+            ] = delta_encoder_pair_rep_norm
+
+        logging_output["loss"] = loss
+        return loss, 1, logging_output
+
+    @staticmethod
+    def reduce_metrics(logging_outputs, split="valid") -> None:
+        """Aggregate logging outputs from data parallel training."""
+        loss_sum = sum(log.get("loss", 0) for log in logging_outputs)
+        bsz = sum(log.get("bsz", 0) for log in logging_outputs)
+        sample_size = sum(log.get("sample_size", 0) for log in logging_outputs)
+        seq_len = sum(log.get("seq_len", 0) for log in logging_outputs)
+        metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3)
+        metrics.log_scalar("seq_len", seq_len / bsz, 1, round=3)
+
masked_loss = sum(log.get("masked_token_loss", 0) for log in logging_outputs) + metrics.log_scalar( + "masked_token_loss", masked_loss / sample_size, sample_size, round=3 + ) + + masked_acc = sum( + log.get("masked_token_hit", 0) for log in logging_outputs + ) / sum(log.get("masked_token_cnt", 0) for log in logging_outputs) + metrics.log_scalar("masked_acc", masked_acc, sample_size, round=3) + + masked_coord_loss = sum( + log.get("masked_coord_loss", 0) for log in logging_outputs + ) + if masked_coord_loss > 0: + metrics.log_scalar( + "masked_coord_loss", + masked_coord_loss / sample_size, + sample_size, + round=3, + ) + + masked_dist_loss = sum( + log.get("masked_dist_loss", 0) for log in logging_outputs + ) + if masked_dist_loss > 0: + metrics.log_scalar( + "masked_dist_loss", masked_dist_loss / sample_size, sample_size, round=3 + ) + + x_norm_loss = sum(log.get("x_norm_loss", 0) for log in logging_outputs) + if x_norm_loss > 0: + metrics.log_scalar( + "x_norm_loss", x_norm_loss / sample_size, sample_size, round=3 + ) + + delta_pair_repr_norm_loss = sum( + log.get("delta_pair_repr_norm_loss", 0) for log in logging_outputs + ) + if delta_pair_repr_norm_loss > 0: + metrics.log_scalar( + "delta_pair_repr_norm_loss", + delta_pair_repr_norm_loss / sample_size, + sample_size, + round=3, + ) + + @staticmethod + def logging_outputs_can_be_summed(is_train) -> bool: + """Whether logging outputs can be summed across workers.""" + return True + + def cal_dist_loss(self, sample, dist, masked_tokens, target_key, normalize=False): + dist_masked_tokens = masked_tokens + masked_distance = dist[dist_masked_tokens, :] + masked_distance_target = sample[target_key]["distance_target"][ + dist_masked_tokens + ] + # padding distance + nb_masked_tokens = dist_masked_tokens.sum(axis=-1) + masked_src_tokens = sample["net_input"]["src_tokens"].ne(self.padding_idx) + # 替换torch.repeat_interleave为ms.ops.repeat_interleave + masked_src_tokens_expanded = ms.ops.repeat_interleave(masked_src_tokens, nb_masked_tokens, axis=0) + # + if normalize: + masked_distance_target = ( + masked_distance_target.astype(ms.float32) - self.dist_mean + ) / self.dist_std + # 替换view为reshape,float()改为astype(ms.float32) + masked_dist_loss = F.smooth_l1_loss( + masked_distance[masked_src_tokens_expanded].reshape(-1).astype(ms.float32), + masked_distance_target[masked_src_tokens_expanded].reshape(-1), + reduction="mean", + beta=1.0, + ) + return masked_dist_loss + + +@register_loss("unimol_infer") +class UniMolInferLoss(UnicoreLoss): + def __init__(self, task): + super().__init__(task) + self.padding_idx = task.dictionary.pad() + self.bos_idx = task.dictionary.bos() + self.eos_idx = task.dictionary.eos() + + def forward(self, model, sample, reduce=True): + """Compute the loss for the given sample.""" + input_key = "net_input" + target_key = "target" + src_tokens = sample[input_key]["src_tokens"] + # 替换ne操作,保持逻辑一致 + token_mask = (src_tokens.ne(self.padding_idx) & src_tokens.ne(self.bos_idx) & src_tokens.ne(self.eos_idx)) + ( + encoder_rep, + encoder_pair_rep, + ) = model(** sample[input_key], features_only=True) + # 替换size(0)为shape[0] + sample_size = sample[input_key]["src_tokens"].shape[0] + encoder_rep_list = [] + encoder_pair_rep_list = [] + if 'pdb_id' in sample[target_key].keys(): + name_key = 'pdb_id' + elif 'smi_name' in sample[target_key].keys(): + name_key = 'smi_name' + else: + raise NotImplementedError("No name key in the original data") + + for i in range(sample_size): # rm padding bos eos token + # 移除.data.cpu(),直接用asnumpy() + 
encoder_rep_list.append(encoder_rep[i][token_mask[i]].asnumpy()) + encoder_pair_rep_list.append(encoder_pair_rep[i][token_mask[i], :][:, token_mask[i]].asnumpy()) + # 替换size(0)为shape[0],移除.data.cpu() + logging_output = { + "mol_repr_cls": encoder_rep[:, 0, :].asnumpy(), # get cls token + "atom_repr": encoder_rep_list, + "pair_repr": encoder_pair_rep_list, + "data_name": sample[target_key][name_key], + "bsz": sample[input_key]["src_tokens"].shape[0], + } + return 0, sample_size, logging_output +# import torch +# import torch.nn.functional as F +# from unicore import metrics +# from unicore.losses import UnicoreLoss, register_loss + + +# @register_loss("unimol") +# class UniMolLoss(UnicoreLoss): +# def __init__(self, task): +# super().__init__(task) +# self.padding_idx = task.dictionary.pad() +# self.seed = task.seed +# self.dist_mean = 6.312581655060595 +# self.dist_std = 3.3899264663911888 + +# def forward(self, model, sample, reduce=True): +# input_key = "net_input" +# target_key = "target" +# masked_tokens = sample[target_key]["tokens_target"].ne(self.padding_idx) +# sample_size = masked_tokens.long().sum() +# ( +# logits_encoder, +# encoder_distance, +# encoder_coord, +# x_norm, +# delta_encoder_pair_rep_norm, +# ) = model(**sample[input_key], encoder_masked_tokens=masked_tokens) +# target = sample[target_key]["tokens_target"] +# if masked_tokens is not None: +# target = target[masked_tokens] +# masked_token_loss = F.nll_loss( +# F.log_softmax(logits_encoder, dim=-1, dtype=torch.float32), +# target, +# ignore_index=self.padding_idx, +# reduction="mean", +# ) +# masked_pred = logits_encoder.argmax(dim=-1) +# masked_hit = (masked_pred == target).long().sum() +# masked_cnt = sample_size +# loss = masked_token_loss * self.args.masked_token_loss +# logging_output = { +# "sample_size": 1, +# "bsz": sample[target_key]["tokens_target"].size(0), +# "seq_len": sample[target_key]["tokens_target"].size(1) +# * sample[target_key]["tokens_target"].size(0), +# "masked_token_loss": masked_token_loss.data, +# "masked_token_hit": masked_hit.data, +# "masked_token_cnt": masked_cnt, +# } + +# if encoder_coord is not None: +# # real = mask + delta +# coord_target = sample[target_key]["coord_target"] +# masked_coord_loss = F.smooth_l1_loss( +# encoder_coord[masked_tokens].view(-1, 3).float(), +# coord_target[masked_tokens].view(-1, 3), +# reduction="mean", +# beta=1.0, +# ) +# loss = loss + masked_coord_loss * self.args.masked_coord_loss +# # restore the scale of loss for displaying +# logging_output["masked_coord_loss"] = masked_coord_loss.data + +# if encoder_distance is not None: +# dist_masked_tokens = masked_tokens +# masked_dist_loss = self.cal_dist_loss( +# sample, encoder_distance, dist_masked_tokens, target_key, normalize=True +# ) +# loss = loss + masked_dist_loss * self.args.masked_dist_loss +# logging_output["masked_dist_loss"] = masked_dist_loss.data + +# if self.args.x_norm_loss > 0 and x_norm is not None: +# loss = loss + self.args.x_norm_loss * x_norm +# logging_output["x_norm_loss"] = x_norm.data + +# if ( +# self.args.delta_pair_repr_norm_loss > 0 +# and delta_encoder_pair_rep_norm is not None +# ): +# loss = ( +# loss + self.args.delta_pair_repr_norm_loss * delta_encoder_pair_rep_norm +# ) +# logging_output[ +# "delta_pair_repr_norm_loss" +# ] = delta_encoder_pair_rep_norm.data + +# logging_output["loss"] = loss.data +# return loss, 1, logging_output + +# @staticmethod +# def reduce_metrics(logging_outputs, split="valid") -> None: +# """Aggregate logging outputs from data parallel 
training.""" +# loss_sum = sum(log.get("loss", 0) for log in logging_outputs) +# bsz = sum(log.get("bsz", 0) for log in logging_outputs) +# sample_size = sum(log.get("sample_size", 0) for log in logging_outputs) +# seq_len = sum(log.get("seq_len", 0) for log in logging_outputs) +# metrics.log_scalar("loss", loss_sum / sample_size, sample_size, round=3) +# metrics.log_scalar("seq_len", seq_len / bsz, 1, round=3) + +# masked_loss = sum(log.get("masked_token_loss", 0) for log in logging_outputs) +# metrics.log_scalar( +# "masked_token_loss", masked_loss / sample_size, sample_size, round=3 +# ) + +# masked_acc = sum( +# log.get("masked_token_hit", 0) for log in logging_outputs +# ) / sum(log.get("masked_token_cnt", 0) for log in logging_outputs) +# metrics.log_scalar("masked_acc", masked_acc, sample_size, round=3) + +# masked_coord_loss = sum( +# log.get("masked_coord_loss", 0) for log in logging_outputs +# ) +# if masked_coord_loss > 0: +# metrics.log_scalar( +# "masked_coord_loss", +# masked_coord_loss / sample_size, +# sample_size, +# round=3, +# ) + +# masked_dist_loss = sum( +# log.get("masked_dist_loss", 0) for log in logging_outputs +# ) +# if masked_dist_loss > 0: +# metrics.log_scalar( +# "masked_dist_loss", masked_dist_loss / sample_size, sample_size, round=3 +# ) + +# x_norm_loss = sum(log.get("x_norm_loss", 0) for log in logging_outputs) +# if x_norm_loss > 0: +# metrics.log_scalar( +# "x_norm_loss", x_norm_loss / sample_size, sample_size, round=3 +# ) + +# delta_pair_repr_norm_loss = sum( +# log.get("delta_pair_repr_norm_loss", 0) for log in logging_outputs +# ) +# if delta_pair_repr_norm_loss > 0: +# metrics.log_scalar( +# "delta_pair_repr_norm_loss", +# delta_pair_repr_norm_loss / sample_size, +# sample_size, +# round=3, +# ) + +# @staticmethod +# def logging_outputs_can_be_summed(is_train) -> bool: +# """ +# Whether the logging outputs returned by `forward` can be summed +# across workers prior to calling `reduce_metrics`. Setting this +# to True will improves distributed training speed. +# """ +# return True + +# def cal_dist_loss(self, sample, dist, masked_tokens, target_key, normalize=False): +# dist_masked_tokens = masked_tokens +# masked_distance = dist[dist_masked_tokens, :] +# masked_distance_target = sample[target_key]["distance_target"][ +# dist_masked_tokens +# ] +# # padding distance +# nb_masked_tokens = dist_masked_tokens.sum(dim=-1) +# masked_src_tokens = sample["net_input"]["src_tokens"].ne(self.padding_idx) +# masked_src_tokens_expanded = torch.repeat_interleave(masked_src_tokens, nb_masked_tokens, dim=0) +# # +# if normalize: +# masked_distance_target = ( +# masked_distance_target.float() - self.dist_mean +# ) / self.dist_std +# masked_dist_loss = F.smooth_l1_loss( +# masked_distance[masked_src_tokens_expanded].view(-1).float(), +# masked_distance_target[masked_src_tokens_expanded].view(-1), +# reduction="mean", +# beta=1.0, +# ) +# return masked_dist_loss + + +# @register_loss("unimol_infer") +# class UniMolInferLoss(UnicoreLoss): +# def __init__(self, task): +# super().__init__(task) +# self.padding_idx = task.dictionary.pad() +# self.bos_idx = task.dictionary.bos() +# self.eos_idx = task.dictionary.eos() + +# def forward(self, model, sample, reduce=True): +# """Compute the loss for the given sample. 
+ +# Returns a tuple with three elements: +# 1) the loss +# 2) the sample size, which is used as the denominator for the gradient +# 3) logging outputs to display while training +# """ +# input_key = "net_input" +# target_key = "target" +# src_tokens = sample[input_key]["src_tokens"] +# token_mask = (src_tokens.ne(self.padding_idx) & src_tokens.ne(self.bos_idx) & src_tokens.ne(self.eos_idx)) +# ( +# encoder_rep, +# encoder_pair_rep, +# ) = model(**sample[input_key], features_only=True) +# sample_size = sample[input_key]["src_tokens"].size(0) +# encoder_rep_list = [] +# encoder_pair_rep_list = [] +# if 'pdb_id' in sample[target_key].keys(): +# name_key = 'pdb_id' +# elif 'smi_name' in sample[target_key].keys(): +# name_key = 'smi_name' +# else: +# raise NotImplementedError("No name key in the original data") + +# for i in range(sample_size): # rm padding bos eos token +# encoder_rep_list.append(encoder_rep[i][token_mask[i]].data.cpu().numpy()) +# encoder_pair_rep_list.append(encoder_pair_rep[i][token_mask[i], :][:, token_mask[i]].data.cpu().numpy()) +# logging_output = { +# "mol_repr_cls": encoder_rep[:, 0, :].data.cpu().numpy(), # get cls token +# "atom_repr": encoder_rep_list, +# "pair_repr": encoder_pair_rep_list, +# "data_name": sample[target_key][name_key], +# "bsz": sample[input_key]["src_tokens"].size(0), +# } +# return 0, sample_size, logging_output diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/mindspore_ascend_outputs/mindspore_data_loader_add2d_idx0.npz b/MindChemistry/applications/Uni-Mol/unimol/unimol/mindspore_ascend_outputs/mindspore_data_loader_add2d_idx0.npz new file mode 100644 index 0000000000000000000000000000000000000000..d631085ca52101187624670099e9624c6d2d8120 Binary files /dev/null and b/MindChemistry/applications/Uni-Mol/unimol/unimol/mindspore_ascend_outputs/mindspore_data_loader_add2d_idx0.npz differ diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/models/__init__.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf38c9bbb41edd62485153fcea4a50c568a47c5 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/__init__.py @@ -0,0 +1,4 @@ +# from .unimol import UniMolModel +# from .transformer_encoder_with_pair import TransformerEncoderWithPair +# from .conf_gen import UnimolConfGModel +# from .docking_pose import DockingPoseModel \ No newline at end of file diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/models/conf_gen.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/conf_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..1988d7bd3cf21733cb6be61c277b74da0e74778b --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/conf_gen.py @@ -0,0 +1,346 @@ +import logging +import mindspore as ms +import mindspore.mint.nn as nn +import mindspore.mint.nn.functional as F +from unicore import utils +from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +from unicore.data import Dictionary +from unimol.models.unimol import UniMolModel +from unicore.modules import LayerNorm +from typing import Optional, Dict, Any, List +from .unimol import base_architecture + +logger = logging.getLogger(__name__) + + +@register_model("mol_confG") +class UnimolConfGModel(BaseUnicoreModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--distance-loss", + type=float, + default=1.0, + 
help="weight for the distance loss", + ) + parser.add_argument( + "--coord-loss", + type=float, + default=1.0, + help="weight for the coordinate loss", + ) + + parser.add_argument( + "--num-recycles", + type=int, + default=1, + help="number of cycles to use for coordinate prediction", + ) + + def __init__(self, args, mol_dictionary): + super().__init__() + unimol_confG_architecture(args) + self.args = args + self.unimol = UniMolModel(self.args, mol_dictionary) + # 适配Ascend NPU环境 + self.set_device("Ascend") + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + return cls(args, task.dictionary) + + def forward( + self, + src_tokens, + src_distance, + src_coord, + src_edge_type, + encoder_masked_tokens=None, + **kwargs + ): + def get_dist_features(dist, et): + n_node = dist.shape[-1] # 替换torch.size为shape + gbf_feature = self.unimol.gbf(dist, et) + gbf_result = self.unimol.gbf_proj(gbf_feature) + graph_attn_bias = gbf_result + # 替换permute和contiguous + graph_attn_bias = ms.ops.permute(graph_attn_bias, (0, 3, 1, 2)) + graph_attn_bias = ms.ops.contiguous(graph_attn_bias) + graph_attn_bias = graph_attn_bias.reshape(-1, n_node, n_node) # 替换view为reshape + return graph_attn_bias + + def fill_attn_mask(attn_mask, padding_mask, fill_val=float("-inf")): + if attn_mask is not None and padding_mask is not None: + # 替换size为shape,view为reshape + attn_mask = attn_mask.reshape(x.shape[0], -1, seq_len, seq_len) + # 替换masked_fill_为masked_fill(MindSpore无in-place操作) + attn_mask = attn_mask.masked_fill( + ms.ops.unsqueeze(ms.ops.unsqueeze(padding_mask, 1), 2).astype(ms.bool_), # 替换torch.bool为ms.bool_ + fill_val, + ) + attn_mask = attn_mask.reshape(-1, seq_len, seq_len) # 替换view为reshape + padding_mask = None + return attn_mask, padding_mask + + def single_encoder( + emb: ms.Tensor, # 替换torch.Tensor为ms.Tensor + attn_mask: Optional[ms.Tensor] = None, + padding_mask: Optional[ms.Tensor] = None, + ): + x = self.unimol.encoder.emb_layer_norm(emb) + # 适配dropout参数,保持一致 + x = F.dropout(x, p=self.unimol.encoder.emb_dropout, training=self.training) + + if padding_mask is not None: + x = x * (1 - padding_mask.unsqueeze(-1).astype(x.dtype)) # 替换type_as为astype + attn_mask, padding_mask = fill_attn_mask( + attn_mask, padding_mask, fill_val=float("-inf") + ) + + for i in range(len(self.unimol.encoder.layers)): + x, attn_mask, _ = self.unimol.encoder.layers[i]( + x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=True + ) + + return x, attn_mask + + # 替换torch.eq为ms.equal + padding_mask = ms.equal(src_tokens, self.unimol.padding_idx) + input_padding_mask = padding_mask + x = self.unimol.embed_tokens(src_tokens) + attn_mask = get_dist_features(src_distance, src_edge_type) + input_attn_mask = attn_mask + bsz = x.shape[0] # 替换size为shape + seq_len = x.shape[1] # 替换size为shape + + for _ in range(self.args.num_recycles): + x, attn_mask = single_encoder( + x, padding_mask=padding_mask, attn_mask=attn_mask + ) + + if self.unimol.encoder.final_layer_norm is not None: + x = self.unimol.encoder.final_layer_norm(x) + + delta_pair_repr = attn_mask - input_attn_mask + delta_pair_repr, _ = fill_attn_mask(delta_pair_repr, input_padding_mask, 0) + attn_mask, _ = fill_attn_mask(attn_mask, input_padding_mask, 0) + # 替换permute和contiguous + attn_mask = ms.ops.permute(attn_mask.reshape(bsz, -1, seq_len, seq_len), (0, 2, 3, 1)) + attn_mask = ms.ops.contiguous(attn_mask) + # 替换permute和contiguous + delta_pair_repr = ms.ops.permute(delta_pair_repr.reshape(bsz, -1, seq_len, seq_len), (0, 2, 3, 1)) + 
delta_pair_repr = ms.ops.contiguous(delta_pair_repr) + + distance_predict, coords_predict = None, None + + if self.args.coord_loss > 0 or True: + if padding_mask is not None: + # 替换torch.sum为ms.ops.sum,unsqueeze为ms.ops.unsqueeze + atom_num = (ms.ops.sum(ms.logical_not(padding_mask), axis=1) - 1).reshape(-1, 1, 1, 1) + else: + atom_num = src_coord.shape[1] - 1 + # 替换unsqueeze为ms.ops.unsqueeze + delta_pos = ms.ops.unsqueeze(src_coord, 1) - ms.ops.unsqueeze(src_coord, 2) + attn_probs = self.unimol.pair2coord_proj(delta_pair_repr) + coords_update = delta_pos / atom_num * attn_probs + # 替换torch.sum为ms.ops.sum + coords_update = ms.ops.sum(coords_update, axis=2) + coords_predict = src_coord + coords_update + + if self.args.distance_loss > 0 or True: + distance_predict = self.unimol.dist_head(attn_mask) + + return [distance_predict, coords_predict] + + +@register_model_architecture("mol_confG", "mol_confG") +def unimol_confG_architecture(args): + def base_architecture(args): + args.encoder_layers = getattr(args, "encoder_layers", 15) + args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) + args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) + args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64) + args.dropout = getattr(args, "dropout", 0.1) + args.emb_dropout = getattr(args, "emb_dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.max_seq_len = getattr(args, "max_seq_len", 512) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.post_ln = getattr(args, "post_ln", False) + args.masked_coord_loss = getattr(args, "masked_coord_loss", 1.0) + args.masked_dist_loss = getattr(args, "masked_dist_loss", 1.0) + + base_architecture(args) +# import logging +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from unicore import utils +# from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +# from unicore.data import Dictionary +# from .unimol import UniMolModel +# from unicore.modules import LayerNorm +# from typing import Optional, Dict, Any, List +# from .unimol import base_architecture + +# logger = logging.getLogger(__name__) + + +# @register_model("mol_confG") +# class UnimolConfGModel(BaseUnicoreModel): +# @staticmethod +# def add_args(parser): +# """Add model-specific arguments to the parser.""" +# parser.add_argument( +# "--distance-loss", +# type=float, +# default=1.0, +# help="weight for the distance loss", +# ) +# parser.add_argument( +# "--coord-loss", +# type=float, +# default=1.0, +# help="weight for the coordinate loss", +# ) + +# parser.add_argument( +# "--num-recycles", +# type=int, +# default=1, +# help="number of cycles to use for coordinate prediction", +# ) + +# def __init__(self, args, mol_dictionary): +# super().__init__() +# unimol_confG_architecture(args) +# self.args = args +# self.unimol = UniMolModel(self.args, mol_dictionary) + +# @classmethod +# def build_model(cls, args, task): +# """Build a new model instance.""" +# return cls(args, task.dictionary) + +# def forward( +# self, +# src_tokens, +# src_distance, +# src_coord, +# src_edge_type, +# encoder_masked_tokens=None, +# **kwargs +# ): +# def get_dist_features(dist, et): +# n_node = dist.size(-1) +# gbf_feature = self.unimol.gbf(dist, et) +# gbf_result = 
self.unimol.gbf_proj(gbf_feature) +# graph_attn_bias = gbf_result +# graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous() +# graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node) +# return graph_attn_bias + +# def fill_attn_mask(attn_mask, padding_mask, fill_val=float("-inf")): +# if attn_mask is not None and padding_mask is not None: +# # merge key_padding_mask and attn_mask +# attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len) +# attn_mask.masked_fill_( +# padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), +# fill_val, +# ) +# attn_mask = attn_mask.view(-1, seq_len, seq_len) +# padding_mask = None +# return attn_mask, padding_mask + +# def single_encoder( +# emb: torch.Tensor, +# attn_mask: Optional[torch.Tensor] = None, +# padding_mask: Optional[torch.Tensor] = None, +# ): +# x = self.unimol.encoder.emb_layer_norm(emb) +# x = F.dropout(x, p=self.unimol.encoder.emb_dropout, training=self.training) + +# if padding_mask is not None: +# x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) +# attn_mask, padding_mask = fill_attn_mask( +# attn_mask, padding_mask, fill_val=float("-inf") +# ) + +# for i in range(len(self.unimol.encoder.layers)): +# x, attn_mask, _ = self.unimol.encoder.layers[i]( +# x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=True +# ) + +# return x, attn_mask + +# padding_mask = src_tokens.eq(self.unimol.padding_idx) +# input_padding_mask = padding_mask +# x = self.unimol.embed_tokens(src_tokens) +# attn_mask = get_dist_features(src_distance, src_edge_type) +# input_attn_mask = attn_mask +# bsz = x.size(0) +# seq_len = x.size(1) + +# for _ in range(self.args.num_recycles): +# x, attn_mask = single_encoder( +# x, padding_mask=padding_mask, attn_mask=attn_mask +# ) + +# if self.unimol.encoder.final_layer_norm is not None: +# x = self.unimol.encoder.final_layer_norm(x) + +# delta_pair_repr = attn_mask - input_attn_mask +# delta_pair_repr, _ = fill_attn_mask(delta_pair_repr, input_padding_mask, 0) +# attn_mask, _ = fill_attn_mask(attn_mask, input_padding_mask, 0) +# attn_mask = ( +# attn_mask.view(bsz, -1, seq_len, seq_len).permute(0, 2, 3, 1).contiguous() +# ) +# delta_pair_repr = ( +# delta_pair_repr.view(bsz, -1, seq_len, seq_len) +# .permute(0, 2, 3, 1) +# .contiguous() +# ) + +# distance_predict, coords_predict = None, None + +# if self.args.coord_loss > 0 or True: +# if padding_mask is not None: +# atom_num = (torch.sum(~padding_mask, dim=1) - 1).view(-1, 1, 1, 1) +# else: +# atom_num = src_coord.shape[1] - 1 +# delta_pos = src_coord.unsqueeze(1) - src_coord.unsqueeze(2) +# attn_probs = self.unimol.pair2coord_proj(delta_pair_repr) +# coords_update = delta_pos / atom_num * attn_probs +# coords_update = torch.sum(coords_update, dim=2) +# coords_predict = src_coord + coords_update + +# if self.args.distance_loss > 0 or True: +# distance_predict = self.unimol.dist_head(attn_mask) + +# return [distance_predict, coords_predict] + + +# @register_model_architecture("mol_confG", "mol_confG") +# def unimol_confG_architecture(args): +# def base_architecture(args): +# args.encoder_layers = getattr(args, "encoder_layers", 15) +# args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) +# args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) +# args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64) +# args.dropout = getattr(args, "dropout", 0.1) +# args.emb_dropout = getattr(args, "emb_dropout", 0.1) +# args.attention_dropout = getattr(args, "attention_dropout", 0.1) +# 
args.activation_dropout = getattr(args, "activation_dropout", 0.0) +# args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) +# args.max_seq_len = getattr(args, "max_seq_len", 512) +# args.activation_fn = getattr(args, "activation_fn", "gelu") +# args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") +# args.post_ln = getattr(args, "post_ln", False) +# args.masked_coord_loss = getattr(args, "masked_coord_loss", 1.0) +# args.masked_dist_loss = getattr(args, "masked_dist_loss", 1.0) + +# base_architecture(args) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/models/docking_pose.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/docking_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..a0bdbd171f8fe0b68d49046d9743beeb64d66fdf --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/docking_pose.py @@ -0,0 +1,626 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import argparse +import mindspore as ms +import mindspore.mint.nn as nn +import mindspore.mint.nn.functional as F +from unicore import utils +from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +from unicore.data import Dictionary +from .unimol import UniMolModel, base_architecture, NonLinearHead +from unicore.modules import LayerNorm +from .transformer_encoder_with_pair import TransformerEncoderWithPair +import numpy as np + +logger = logging.getLogger(__name__) + + +@register_model("docking_pose") +class DockingPoseModel(BaseUnicoreModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--mol-pooler-dropout", + type=float, + metavar="D", + help="dropout probability in the masked_lm pooler layers", + ) + parser.add_argument( + "--pocket-pooler-dropout", + type=float, + metavar="D", + help="dropout probability in the masked_lm pooler layers", + ) + parser.add_argument( + "--pocket-encoder-layers", + type=int, + help="pocket encoder layers", + ) + parser.add_argument( + "--recycling", + type=int, + default=1, + help="recycling nums of decoder", + ) + + def __init__(self, args, mol_dictionary, pocket_dictionary): + super().__init__() + unimol_docking_architecture(args) + + self.args = args + self.mol_model = UniMolModel(args.mol, mol_dictionary) + self.pocket_model = UniMolModel(args.pocket, pocket_dictionary) + self.concat_decoder = TransformerEncoderWithPair( + encoder_layers=4, + embed_dim=args.mol.encoder_embed_dim, + ffn_embed_dim=args.mol.encoder_ffn_embed_dim, + attention_heads=args.mol.encoder_attention_heads, + emb_dropout=0.1, + dropout=0.1, + attention_dropout=0.1, + activation_dropout=0.0, + activation_fn="gelu", + ) + self.cross_distance_project = NonLinearHead( + args.mol.encoder_embed_dim * 2 + args.mol.encoder_attention_heads, 1, "relu" + ) + self.holo_distance_project = DistanceHead( + args.mol.encoder_embed_dim + args.mol.encoder_attention_heads, "relu" + ) + # 适配Ascend NPU环境 + self.set_device("Ascend") + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + return cls(args, task.dictionary, task.pocket_dictionary) + + def forward( + self, + mol_src_tokens, + mol_src_distance, + mol_src_edge_type, + pocket_src_tokens, + pocket_src_distance, + pocket_src_edge_type, + masked_tokens=None, + features_only=True,** kwargs + ): + def get_dist_features(dist, et, flag): + 
if flag == "mol": + n_node = dist.shape[-1] # 替换size为shape + gbf_feature = self.mol_model.gbf(dist, et) + gbf_result = self.mol_model.gbf_proj(gbf_feature) + graph_attn_bias = gbf_result + # 替换permute和contiguous + graph_attn_bias = ms.ops.permute(graph_attn_bias, (0, 3, 1, 2)) + graph_attn_bias = ms.ops.contiguous(graph_attn_bias) + graph_attn_bias = graph_attn_bias.reshape(-1, n_node, n_node) # 替换view为reshape + return graph_attn_bias + else: + n_node = dist.shape[-1] # 替换size为shape + gbf_feature = self.pocket_model.gbf(dist, et) + gbf_result = self.pocket_model.gbf_proj(gbf_feature) + graph_attn_bias = gbf_result + # 替换permute和contiguous + graph_attn_bias = ms.ops.permute(graph_attn_bias, (0, 3, 1, 2)) + graph_attn_bias = ms.ops.contiguous(graph_attn_bias) + graph_attn_bias = graph_attn_bias.reshape(-1, n_node, n_node) # 替换view为reshape + return graph_attn_bias + + # 替换torch.eq为ms.equal + mol_padding_mask = ms.equal(mol_src_tokens, self.mol_model.padding_idx) + mol_x = self.mol_model.embed_tokens(mol_src_tokens) + mol_graph_attn_bias = get_dist_features( + mol_src_distance, mol_src_edge_type, "mol" + ) + mol_outputs = self.mol_model.encoder( + mol_x, padding_mask=mol_padding_mask, attn_mask=mol_graph_attn_bias + ) + mol_encoder_rep = mol_outputs[0] + mol_encoder_pair_rep = mol_outputs[1] + + # 替换torch.eq为ms.equal + pocket_padding_mask = ms.equal(pocket_src_tokens, self.pocket_model.padding_idx) + pocket_x = self.pocket_model.embed_tokens(pocket_src_tokens) + pocket_graph_attn_bias = get_dist_features( + pocket_src_distance, pocket_src_edge_type, "pocket" + ) + pocket_outputs = self.pocket_model.encoder( + pocket_x, padding_mask=pocket_padding_mask, attn_mask=pocket_graph_attn_bias + ) + pocket_encoder_rep = pocket_outputs[0] + pocket_encoder_pair_rep = pocket_outputs[1] + + mol_sz = mol_encoder_rep.shape[1] # 替换size为shape + pocket_sz = pocket_encoder_rep.shape[1] # 替换size为shape + + # 替换torch.cat为ms.ops.cat + concat_rep = ms.ops.cat( + [mol_encoder_rep, pocket_encoder_rep], dim=-2 + ) # [batch, mol_sz+pocket_sz, hidden_dim] + # 替换torch.cat为ms.ops.cat + concat_mask = ms.ops.cat( + [mol_padding_mask, pocket_padding_mask], dim=-1 + ) # [batch, mol_sz+pocket_sz] + attn_bs = mol_graph_attn_bias.shape[0] # 替换size为shape + + # 替换torch.zeros和type_as + concat_attn_bias = ms.ops.zeros( + (attn_bs, mol_sz + pocket_sz, mol_sz + pocket_sz), dtype=concat_rep.dtype + ) # [batch, mol_sz+pocket_sz, mol_sz+pocket_sz] + + # 替换permute、reshape和contiguous + mol_pair_reshape = ms.ops.contiguous( + ms.ops.reshape( + ms.ops.permute(mol_encoder_pair_rep, (0, 3, 1, 2)), + (-1, mol_sz, mol_sz) + ) + ) + concat_attn_bias[:, :mol_sz, :mol_sz] = mol_pair_reshape + + # 替换permute、reshape和contiguous + pocket_pair_reshape = ms.ops.contiguous( + ms.ops.reshape( + ms.ops.permute(pocket_encoder_pair_rep, (0, 3, 1, 2)), + (-1, pocket_sz, pocket_sz) + ) + ) + concat_attn_bias[:, -pocket_sz:, -pocket_sz:] = pocket_pair_reshape + + decoder_rep = concat_rep + decoder_pair_rep = concat_attn_bias + for i in range(self.args.recycling): + decoder_outputs = self.concat_decoder( + decoder_rep, padding_mask=concat_mask, attn_mask=decoder_pair_rep + ) + decoder_rep = decoder_outputs[0] + decoder_pair_rep = decoder_outputs[1] + if i != (self.args.recycling - 1): + # 替换permute和reshape + decoder_pair_rep = ms.ops.reshape( + ms.ops.permute(decoder_pair_rep, (0, 3, 1, 2)), + (-1, mol_sz + pocket_sz, mol_sz + pocket_sz) + ) + + mol_decoder = decoder_rep[:, :mol_sz] + pocket_decoder = decoder_rep[:, mol_sz:] + + mol_pair_decoder_rep = 
decoder_pair_rep[:, :mol_sz, :mol_sz, :] + # 替换transpose为ms.ops.transpose + mol_pocket_pair_decoder_rep = ( + decoder_pair_rep[:, :mol_sz, mol_sz:, :] + + ms.ops.transpose(decoder_pair_rep[:, mol_sz:, :mol_sz, :], (0, 2, 1, 3)) + ) / 2.0 + # 替换张量索引赋值 + mol_pocket_pair_decoder_rep = ms.ops.where( + mol_pocket_pair_decoder_rep == float("-inf"), + ms.ops.zeros_like(mol_pocket_pair_decoder_rep), + mol_pocket_pair_decoder_rep + ) + + # 替换unsqueeze和repeat + mol_unsqueeze = ms.ops.repeat_elements( + ms.ops.unsqueeze(mol_decoder, -2), + rep=pocket_sz, + axis=-2 + ) + pocket_unsqueeze = ms.ops.repeat_elements( + ms.ops.unsqueeze(pocket_decoder, -3), + rep=mol_sz, + axis=-3 + ) + # 替换torch.cat为ms.ops.cat + cross_rep = ms.ops.cat( + [ + mol_pocket_pair_decoder_rep, + mol_unsqueeze, + pocket_unsqueeze, + ], + dim=-1, + ) # [batch, mol_sz, pocket_sz, 4*hidden_size] + + # 替换squeeze和F.elu + cross_distance_predict = ( + F.elu(self.cross_distance_project(cross_rep).squeeze(-1)) + 1.0 + ) # batch, mol_sz, pocket_sz + + # 替换unsqueeze和repeat + mol_repeat = ms.ops.repeat_elements( + ms.ops.unsqueeze(mol_decoder, -2), + rep=mol_sz, + axis=-2 + ) + # 替换torch.cat为ms.ops.cat + holo_encoder_pair_rep = ms.ops.cat( + [ + mol_pair_decoder_rep, + mol_repeat, + ], + dim=-1, + ) # [batch, mol_sz, mol_sz, 3*hidden_size] + holo_distance_predict = self.holo_distance_project( + holo_encoder_pair_rep + ) # batch, mol_sz, mol_sz + + return cross_distance_predict, holo_distance_predict + + def set_num_updates(self, num_updates): + """State from trainer to pass along to model at every update.""" + self._num_updates = num_updates + + def get_num_updates(self): + return self._num_updates + + +class DistanceHead(nn.Cell): # MindSpore中模型层继承自Cell + def __init__( + self, + heads, + activation_fn, + ): + super().__init__() + self.dense = nn.Linear(heads, heads) + self.layer_norm = nn.LayerNorm(heads) + self.out_proj = nn.Linear(heads, 1) + self.activation_fn = utils.get_activation_fn(activation_fn) + + def forward(self, x): + bsz, seq_len, seq_len, _ = x.shape # 替换size为shape + # 替换张量索引赋值 + x = ms.ops.where( + x == float("-inf"), + ms.ops.zeros_like(x), + x + ) + x = self.dense(x) + x = self.activation_fn(x) + x = self.layer_norm(x) + x = self.out_proj(x).reshape(bsz, seq_len, seq_len) # 替换view为reshape + # 替换transpose为ms.ops.transpose + x = (x + ms.ops.transpose(x, (-1, -2))) * 0.5 + return x + + +@register_model_architecture("docking_pose", "docking_pose") +def unimol_docking_architecture(args): + parser = argparse.ArgumentParser() + args.mol = parser.parse_args([]) + args.pocket = parser.parse_args([]) + + args.mol.encoder_layers = getattr(args, "mol_encoder_layers", 15) + args.mol.encoder_embed_dim = getattr(args, "mol_encoder_embed_dim", 512) + args.mol.encoder_ffn_embed_dim = getattr(args, "mol_encoder_ffn_embed_dim", 2048) + args.mol.encoder_attention_heads = getattr(args, "mol_encoder_attention_heads", 64) + args.mol.dropout = getattr(args, "mol_dropout", 0.1) + args.mol.emb_dropout = getattr(args, "mol_emb_dropout", 0.1) + args.mol.attention_dropout = getattr(args, "mol_attention_dropout", 0.1) + args.mol.activation_dropout = getattr(args, "mol_activation_dropout", 0.0) + args.mol.pooler_dropout = getattr(args, "mol_pooler_dropout", 0.0) + args.mol.max_seq_len = getattr(args, "mol_max_seq_len", 512) + args.mol.activation_fn = getattr(args, "mol_activation_fn", "gelu") + args.mol.pooler_activation_fn = getattr(args, "mol_pooler_activation_fn", "tanh") + args.mol.post_ln = getattr(args, "mol_post_ln", False) + 
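+    # The constants below are intentionally -1.0: UniMolModel only instantiates its
+    # masked-token / masked-coord / masked-dist heads when the corresponding loss
+    # weight is > 0, so negative values switch the pretraining heads off for docking.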
args.mol.masked_token_loss = -1.0 + args.mol.masked_coord_loss = -1.0 + args.mol.masked_dist_loss = -1.0 + args.mol.x_norm_loss = -1.0 + args.mol.delta_pair_repr_norm_loss = -1.0 + + args.pocket.encoder_layers = getattr(args, "pocket_encoder_layers", 15) + args.pocket.encoder_embed_dim = getattr(args, "pocket_encoder_embed_dim", 512) + args.pocket.encoder_ffn_embed_dim = getattr( + args, "pocket_encoder_ffn_embed_dim", 2048 + ) + args.pocket.encoder_attention_heads = getattr( + args, "pocket_encoder_attention_heads", 64 + ) + args.pocket.dropout = getattr(args, "pocket_dropout", 0.1) + args.pocket.emb_dropout = getattr(args, "pocket_emb_dropout", 0.1) + args.pocket.attention_dropout = getattr(args, "pocket_attention_dropout", 0.1) + args.pocket.activation_dropout = getattr(args, "pocket_activation_dropout", 0.0) + args.pocket.pooler_dropout = getattr(args, "pocket_pooler_dropout", 0.0) + args.pocket.max_seq_len = getattr(args, "pocket_max_seq_len", 512) + args.pocket.activation_fn = getattr(args, "pocket_activation_fn", "gelu") + args.pocket.pooler_activation_fn = getattr( + args, "pocket_pooler_activation_fn", "tanh" + ) + args.pocket.post_ln = getattr(args, "pocket_post_ln", False) + args.pocket.masked_token_loss = -1.0 + args.pocket.masked_coord_loss = -1.0 + args.pocket.masked_dist_loss = -1.0 + args.pocket.x_norm_loss = -1.0 + args.pocket.delta_pair_repr_norm_loss = -1.0 + + base_architecture(args) +# import logging + +# import argparse +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from unicore import utils +# from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +# from unicore.data import Dictionary +# from .unimol import UniMolModel, base_architecture, NonLinearHead +# from unicore.modules import LayerNorm +# from .transformer_encoder_with_pair import TransformerEncoderWithPair +# import numpy as np + +# logger = logging.getLogger(__name__) + + +# @register_model("docking_pose") +# class DockingPoseModel(BaseUnicoreModel): +# @staticmethod +# def add_args(parser): +# """Add model-specific arguments to the parser.""" +# parser.add_argument( +# "--mol-pooler-dropout", +# type=float, +# metavar="D", +# help="dropout probability in the masked_lm pooler layers", +# ) +# parser.add_argument( +# "--pocket-pooler-dropout", +# type=float, +# metavar="D", +# help="dropout probability in the masked_lm pooler layers", +# ) +# parser.add_argument( +# "--pocket-encoder-layers", +# type=int, +# help="pocket encoder layers", +# ) +# parser.add_argument( +# "--recycling", +# type=int, +# default=1, +# help="recycling nums of decoder", +# ) + +# def __init__(self, args, mol_dictionary, pocket_dictionary): +# super().__init__() +# unimol_docking_architecture(args) + +# self.args = args +# self.mol_model = UniMolModel(args.mol, mol_dictionary) +# self.pocket_model = UniMolModel(args.pocket, pocket_dictionary) +# self.concat_decoder = TransformerEncoderWithPair( +# encoder_layers=4, +# embed_dim=args.mol.encoder_embed_dim, +# ffn_embed_dim=args.mol.encoder_ffn_embed_dim, +# attention_heads=args.mol.encoder_attention_heads, +# emb_dropout=0.1, +# dropout=0.1, +# attention_dropout=0.1, +# activation_dropout=0.0, +# activation_fn="gelu", +# ) +# self.cross_distance_project = NonLinearHead( +# args.mol.encoder_embed_dim * 2 + args.mol.encoder_attention_heads, 1, "relu" +# ) +# self.holo_distance_project = DistanceHead( +# args.mol.encoder_embed_dim + args.mol.encoder_attention_heads, "relu" +# ) + +# @classmethod +# def build_model(cls, 
args, task): +# """Build a new model instance.""" +# return cls(args, task.dictionary, task.pocket_dictionary) + +# def forward( +# self, +# mol_src_tokens, +# mol_src_distance, +# mol_src_edge_type, +# pocket_src_tokens, +# pocket_src_distance, +# pocket_src_edge_type, +# masked_tokens=None, +# features_only=True, +# **kwargs +# ): +# def get_dist_features(dist, et, flag): +# if flag == "mol": +# n_node = dist.size(-1) +# gbf_feature = self.mol_model.gbf(dist, et) +# gbf_result = self.mol_model.gbf_proj(gbf_feature) +# graph_attn_bias = gbf_result +# graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous() +# graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node) +# return graph_attn_bias +# else: +# n_node = dist.size(-1) +# gbf_feature = self.pocket_model.gbf(dist, et) +# gbf_result = self.pocket_model.gbf_proj(gbf_feature) +# graph_attn_bias = gbf_result +# graph_attn_bias = graph_attn_bias.permute(0, 3, 1, 2).contiguous() +# graph_attn_bias = graph_attn_bias.view(-1, n_node, n_node) +# return graph_attn_bias + +# mol_padding_mask = mol_src_tokens.eq(self.mol_model.padding_idx) +# mol_x = self.mol_model.embed_tokens(mol_src_tokens) +# mol_graph_attn_bias = get_dist_features( +# mol_src_distance, mol_src_edge_type, "mol" +# ) +# mol_outputs = self.mol_model.encoder( +# mol_x, padding_mask=mol_padding_mask, attn_mask=mol_graph_attn_bias +# ) +# mol_encoder_rep = mol_outputs[0] +# mol_encoder_pair_rep = mol_outputs[1] + +# pocket_padding_mask = pocket_src_tokens.eq(self.pocket_model.padding_idx) +# pocket_x = self.pocket_model.embed_tokens(pocket_src_tokens) +# pocket_graph_attn_bias = get_dist_features( +# pocket_src_distance, pocket_src_edge_type, "pocket" +# ) +# pocket_outputs = self.pocket_model.encoder( +# pocket_x, padding_mask=pocket_padding_mask, attn_mask=pocket_graph_attn_bias +# ) +# pocket_encoder_rep = pocket_outputs[0] +# pocket_encoder_pair_rep = pocket_outputs[1] + +# mol_sz = mol_encoder_rep.size(1) +# pocket_sz = pocket_encoder_rep.size(1) + +# concat_rep = torch.cat( +# [mol_encoder_rep, pocket_encoder_rep], dim=-2 +# ) # [batch, mol_sz+pocket_sz, hidden_dim] +# concat_mask = torch.cat( +# [mol_padding_mask, pocket_padding_mask], dim=-1 +# ) # [batch, mol_sz+pocket_sz] +# attn_bs = mol_graph_attn_bias.size(0) + +# concat_attn_bias = torch.zeros( +# attn_bs, mol_sz + pocket_sz, mol_sz + pocket_sz +# ).type_as( +# concat_rep +# ) # [batch, mol_sz+pocket_sz, mol_sz+pocket_sz] +# concat_attn_bias[:, :mol_sz, :mol_sz] = ( +# mol_encoder_pair_rep.permute(0, 3, 1, 2) +# .reshape(-1, mol_sz, mol_sz) +# .contiguous() +# ) +# concat_attn_bias[:, -pocket_sz:, -pocket_sz:] = ( +# pocket_encoder_pair_rep.permute(0, 3, 1, 2) +# .reshape(-1, pocket_sz, pocket_sz) +# .contiguous() +# ) + +# decoder_rep = concat_rep +# decoder_pair_rep = concat_attn_bias +# for i in range(self.args.recycling): +# decoder_outputs = self.concat_decoder( +# decoder_rep, padding_mask=concat_mask, attn_mask=decoder_pair_rep +# ) +# decoder_rep = decoder_outputs[0] +# decoder_pair_rep = decoder_outputs[1] +# if i != (self.args.recycling - 1): +# decoder_pair_rep = decoder_pair_rep.permute(0, 3, 1, 2).reshape( +# -1, mol_sz + pocket_sz, mol_sz + pocket_sz +# ) + +# mol_decoder = decoder_rep[:, :mol_sz] +# pocket_decoder = decoder_rep[:, mol_sz:] + +# mol_pair_decoder_rep = decoder_pair_rep[:, :mol_sz, :mol_sz, :] +# mol_pocket_pair_decoder_rep = ( +# decoder_pair_rep[:, :mol_sz, mol_sz:, :] +# + decoder_pair_rep[:, mol_sz:, :mol_sz, :].transpose(1, 2) +# ) / 2.0 +# 
mol_pocket_pair_decoder_rep[mol_pocket_pair_decoder_rep == float("-inf")] = 0 + +# cross_rep = torch.cat( +# [ +# mol_pocket_pair_decoder_rep, +# mol_decoder.unsqueeze(-2).repeat(1, 1, pocket_sz, 1), +# pocket_decoder.unsqueeze(-3).repeat(1, mol_sz, 1, 1), +# ], +# dim=-1, +# ) # [batch, mol_sz, pocket_sz, 4*hidden_size] + +# cross_distance_predict = ( +# F.elu(self.cross_distance_project(cross_rep).squeeze(-1)) + 1.0 +# ) # batch, mol_sz, pocket_sz + +# holo_encoder_pair_rep = torch.cat( +# [ +# mol_pair_decoder_rep, +# mol_decoder.unsqueeze(-2).repeat(1, 1, mol_sz, 1), +# ], +# dim=-1, +# ) # [batch, mol_sz, mol_sz, 3*hidden_size] +# holo_distance_predict = self.holo_distance_project( +# holo_encoder_pair_rep +# ) # batch, mol_sz, mol_sz + +# return cross_distance_predict, holo_distance_predict + +# def set_num_updates(self, num_updates): +# """State from trainer to pass along to model at every update.""" + +# self._num_updates = num_updates + +# def get_num_updates(self): +# return self._num_updates + + +# class DistanceHead(nn.Module): +# def __init__( +# self, +# heads, +# activation_fn, +# ): +# super().__init__() +# self.dense = nn.Linear(heads, heads) +# self.layer_norm = nn.LayerNorm(heads) +# self.out_proj = nn.Linear(heads, 1) +# self.activation_fn = utils.get_activation_fn(activation_fn) + +# def forward(self, x): +# bsz, seq_len, seq_len, _ = x.size() +# x[x == float("-inf")] = 0 +# x = self.dense(x) +# x = self.activation_fn(x) +# x = self.layer_norm(x) +# x = self.out_proj(x).view(bsz, seq_len, seq_len) +# x = (x + x.transpose(-1, -2)) * 0.5 +# return x + + +# @register_model_architecture("docking_pose", "docking_pose") +# def unimol_docking_architecture(args): + +# parser = argparse.ArgumentParser() +# args.mol = parser.parse_args([]) +# args.pocket = parser.parse_args([]) + +# args.mol.encoder_layers = getattr(args, "mol_encoder_layers", 15) +# args.mol.encoder_embed_dim = getattr(args, "mol_encoder_embed_dim", 512) +# args.mol.encoder_ffn_embed_dim = getattr(args, "mol_encoder_ffn_embed_dim", 2048) +# args.mol.encoder_attention_heads = getattr(args, "mol_encoder_attention_heads", 64) +# args.mol.dropout = getattr(args, "mol_dropout", 0.1) +# args.mol.emb_dropout = getattr(args, "mol_emb_dropout", 0.1) +# args.mol.attention_dropout = getattr(args, "mol_attention_dropout", 0.1) +# args.mol.activation_dropout = getattr(args, "mol_activation_dropout", 0.0) +# args.mol.pooler_dropout = getattr(args, "mol_pooler_dropout", 0.0) +# args.mol.max_seq_len = getattr(args, "mol_max_seq_len", 512) +# args.mol.activation_fn = getattr(args, "mol_activation_fn", "gelu") +# args.mol.pooler_activation_fn = getattr(args, "mol_pooler_activation_fn", "tanh") +# args.mol.post_ln = getattr(args, "mol_post_ln", False) +# args.mol.masked_token_loss = -1.0 +# args.mol.masked_coord_loss = -1.0 +# args.mol.masked_dist_loss = -1.0 +# args.mol.x_norm_loss = -1.0 +# args.mol.delta_pair_repr_norm_loss = -1.0 + +# args.pocket.encoder_layers = getattr(args, "pocket_encoder_layers", 15) +# args.pocket.encoder_embed_dim = getattr(args, "pocket_encoder_embed_dim", 512) +# args.pocket.encoder_ffn_embed_dim = getattr( +# args, "pocket_encoder_ffn_embed_dim", 2048 +# ) +# args.pocket.encoder_attention_heads = getattr( +# args, "pocket_encoder_attention_heads", 64 +# ) +# args.pocket.dropout = getattr(args, "pocket_dropout", 0.1) +# args.pocket.emb_dropout = getattr(args, "pocket_emb_dropout", 0.1) +# args.pocket.attention_dropout = getattr(args, "pocket_attention_dropout", 0.1) +# args.pocket.activation_dropout 
= getattr(args, "pocket_activation_dropout", 0.0) +# args.pocket.pooler_dropout = getattr(args, "pocket_pooler_dropout", 0.0) +# args.pocket.max_seq_len = getattr(args, "pocket_max_seq_len", 512) +# args.pocket.activation_fn = getattr(args, "pocket_activation_fn", "gelu") +# args.pocket.pooler_activation_fn = getattr( +# args, "pocket_pooler_activation_fn", "tanh" +# ) +# args.pocket.post_ln = getattr(args, "pocket_post_ln", False) +# args.pocket.masked_token_loss = -1.0 +# args.pocket.masked_coord_loss = -1.0 +# args.pocket.masked_dist_loss = -1.0 +# args.pocket.x_norm_loss = -1.0 +# args.pocket.delta_pair_repr_norm_loss = -1.0 + +# base_architecture(args) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/models/transformer_encoder_with_pair.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/transformer_encoder_with_pair.py new file mode 100644 index 0000000000000000000000000000000000000000..1f27cf4df0694b5e6a11017fbf9f411d0f229da4 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/transformer_encoder_with_pair.py @@ -0,0 +1,293 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +from typing import Optional +import math +import mindspore as ms +import mindspore.mint.nn as nn +import mindspore.mint.nn.functional as F +from unicore.modules import TransformerEncoderLayer, LayerNorm + + +class TransformerEncoderWithPair(nn.Cell): + def __init__( + self, + encoder_layers: int = 6, + embed_dim: int = 768, + ffn_embed_dim: int = 3072, + attention_heads: int = 8, + emb_dropout: float = 0.1, + dropout: float = 0.1, + attention_dropout: float = 0.1, + activation_dropout: float = 0.0, + max_seq_len: int = 256, + activation_fn: str = "gelu", + post_ln: bool = False, + no_final_head_layer_norm: bool = False, + ) -> None: + + super().__init__() + self.emb_dropout = emb_dropout + self.max_seq_len = max_seq_len + self.embed_dim = embed_dim + self.attention_heads = attention_heads + self.emb_layer_norm = LayerNorm(self.embed_dim) + if not post_ln: + self.final_layer_norm = LayerNorm(self.embed_dim) + else: + self.final_layer_norm = None + + if not no_final_head_layer_norm: + self.final_head_layer_norm = LayerNorm(attention_heads) + else: + self.final_head_layer_norm = None + + # MindSpore中用CellList替代ModuleList + self.layers = nn.CellList( + [ + TransformerEncoderLayer( + embed_dim=self.embed_dim, + ffn_embed_dim=ffn_embed_dim, + attention_heads=attention_heads, + dropout=dropout, + attention_dropout=attention_dropout, + activation_dropout=activation_dropout, + activation_fn=activation_fn, + post_ln=post_ln, + ) + for _ in range(encoder_layers) + ] + ) + + def forward( + self, + emb: ms.Tensor, + attn_mask: Optional[ms.Tensor] = None, + padding_mask: Optional[ms.Tensor] = None, + ) -> ms.Tensor: + + bsz = emb.shape[0] # 替换size为shape + seq_len = emb.shape[1] # 替换size为shape + x = self.emb_layer_norm(emb) + x = F.dropout(x, p=self.emb_dropout, training=self.training) + + # 处理padding mask + if padding_mask is not None: + # 替换type_as为astype + x = x * (1 - padding_mask.unsqueeze(-1).astype(x.dtype)) + input_attn_mask = attn_mask + input_padding_mask = padding_mask + + def fill_attn_mask(attn_mask, padding_mask, fill_val=float("-inf")): + if attn_mask is not None and padding_mask is not None: + # 合并key_padding_mask和attn_mask + # 替换view为reshape,size为shape + attn_mask = attn_mask.reshape(x.shape[0], -1, seq_len, seq_len) + # 
替换masked_fill_为masked_fill(非原地操作),torch.bool为ms.bool_,type_as为astype + attn_mask = attn_mask.masked_fill( + ms.ops.unsqueeze(ms.ops.unsqueeze(padding_mask, 1), 2).astype(ms.bool_), + fill_val, + ) + attn_mask = attn_mask.reshape(-1, seq_len, seq_len) # 替换view为reshape + padding_mask = None + return attn_mask, padding_mask + + assert attn_mask is not None + attn_mask, padding_mask = fill_attn_mask(attn_mask, padding_mask) + + for i in range(len(self.layers)): + x, attn_mask, _ = self.layers[i]( + x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=True + ) + + def norm_loss(x, eps=1e-10, tolerance=1.0): + x = x.astype(ms.float32) # 替换float()为astype + max_norm = x.shape[-1] ** 0.5 + # 替换torch.sqrt和torch.sum为ms.ops.sqrt和ms.ops.sum + norm = ms.ops.sqrt(ms.ops.sum(x**2, dim=-1) + eps) + # 替换torch.nn.functional.relu为F.relu + error = F.relu((norm - max_norm).abs() - tolerance) + return error + + def masked_mean(mask, value, dim=-1, eps=1e-10): + # 替换torch.sum和torch.mean为ms.ops.sum和ms.ops.mean + return ( + ms.ops.sum(mask * value, dim=dim) / (eps + ms.ops.sum(mask, dim=dim)) + ).mean() + + x_norm = norm_loss(x) + if input_padding_mask is not None: + token_mask = 1.0 - input_padding_mask.astype(ms.float32) # 替换float()为astype + else: + # 替换torch.ones_like和device为ms.ops.ones_like + token_mask = ms.ops.ones_like(x_norm, dtype=ms.float32) + x_norm = masked_mean(token_mask, x_norm) + + if self.final_layer_norm is not None: + x = self.final_layer_norm(x) + + delta_pair_repr = attn_mask - input_attn_mask + delta_pair_repr, _ = fill_attn_mask(delta_pair_repr, input_padding_mask, 0) + # 替换view、permute、contiguous为reshape、ms.ops.permute、ms.ops.contiguous + attn_mask = ms.ops.contiguous( + ms.ops.permute( + attn_mask.reshape(bsz, -1, seq_len, seq_len), + (0, 2, 3, 1) + ) + ) + delta_pair_repr = ms.ops.contiguous( + ms.ops.permute( + delta_pair_repr.reshape(bsz, -1, seq_len, seq_len), + (0, 2, 3, 1) + ) + ) + + pair_mask = token_mask[..., None] * token_mask[..., None, :] + delta_pair_repr_norm = norm_loss(delta_pair_repr) + delta_pair_repr_norm = masked_mean( + pair_mask, delta_pair_repr_norm, dim=(-1, -2) + ) + + if self.final_head_layer_norm is not None: + delta_pair_repr = self.final_head_layer_norm(delta_pair_repr) + + return x, attn_mask, delta_pair_repr, x_norm, delta_pair_repr_norm +# from typing import Optional + +# import math +# import torch +# import torch.nn as nn +# import torch.nn.functional as F +# from unicore.modules import TransformerEncoderLayer, LayerNorm + + +# class TransformerEncoderWithPair(nn.Module): +# def __init__( +# self, +# encoder_layers: int = 6, +# embed_dim: int = 768, +# ffn_embed_dim: int = 3072, +# attention_heads: int = 8, +# emb_dropout: float = 0.1, +# dropout: float = 0.1, +# attention_dropout: float = 0.1, +# activation_dropout: float = 0.0, +# max_seq_len: int = 256, +# activation_fn: str = "gelu", +# post_ln: bool = False, +# no_final_head_layer_norm: bool = False, +# ) -> None: + +# super().__init__() +# self.emb_dropout = emb_dropout +# self.max_seq_len = max_seq_len +# self.embed_dim = embed_dim +# self.attention_heads = attention_heads +# self.emb_layer_norm = LayerNorm(self.embed_dim) +# if not post_ln: +# self.final_layer_norm = LayerNorm(self.embed_dim) +# else: +# self.final_layer_norm = None + +# if not no_final_head_layer_norm: +# self.final_head_layer_norm = LayerNorm(attention_heads) +# else: +# self.final_head_layer_norm = None + +# self.layers = nn.ModuleList( +# [ +# TransformerEncoderLayer( +# embed_dim=self.embed_dim, +# 
ffn_embed_dim=ffn_embed_dim, +# attention_heads=attention_heads, +# dropout=dropout, +# attention_dropout=attention_dropout, +# activation_dropout=activation_dropout, +# activation_fn=activation_fn, +# post_ln=post_ln, +# ) +# for _ in range(encoder_layers) +# ] +# ) + +# def forward( +# self, +# emb: torch.Tensor, +# attn_mask: Optional[torch.Tensor] = None, +# padding_mask: Optional[torch.Tensor] = None, +# ) -> torch.Tensor: + +# bsz = emb.size(0) +# seq_len = emb.size(1) +# x = self.emb_layer_norm(emb) +# x = F.dropout(x, p=self.emb_dropout, training=self.training) + +# # account for padding while computing the representation +# if padding_mask is not None: +# x = x * (1 - padding_mask.unsqueeze(-1).type_as(x)) +# input_attn_mask = attn_mask +# input_padding_mask = padding_mask + +# def fill_attn_mask(attn_mask, padding_mask, fill_val=float("-inf")): +# if attn_mask is not None and padding_mask is not None: +# # merge key_padding_mask and attn_mask +# attn_mask = attn_mask.view(x.size(0), -1, seq_len, seq_len) +# attn_mask.masked_fill_( +# padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool), +# fill_val, +# ) +# attn_mask = attn_mask.view(-1, seq_len, seq_len) +# padding_mask = None +# return attn_mask, padding_mask + +# assert attn_mask is not None +# attn_mask, padding_mask = fill_attn_mask(attn_mask, padding_mask) + +# for i in range(len(self.layers)): +# x, attn_mask, _ = self.layers[i]( +# x, padding_mask=padding_mask, attn_bias=attn_mask, return_attn=True +# ) + +# def norm_loss(x, eps=1e-10, tolerance=1.0): +# x = x.float() +# max_norm = x.shape[-1] ** 0.5 +# norm = torch.sqrt(torch.sum(x**2, dim=-1) + eps) +# error = torch.nn.functional.relu((norm - max_norm).abs() - tolerance) +# return error + +# def masked_mean(mask, value, dim=-1, eps=1e-10): +# return ( +# torch.sum(mask * value, dim=dim) / (eps + torch.sum(mask, dim=dim)) +# ).mean() + +# x_norm = norm_loss(x) +# if input_padding_mask is not None: +# token_mask = 1.0 - input_padding_mask.float() +# else: +# token_mask = torch.ones_like(x_norm, device=x_norm.device) +# x_norm = masked_mean(token_mask, x_norm) + +# if self.final_layer_norm is not None: +# x = self.final_layer_norm(x) + +# delta_pair_repr = attn_mask - input_attn_mask +# delta_pair_repr, _ = fill_attn_mask(delta_pair_repr, input_padding_mask, 0) +# attn_mask = ( +# attn_mask.view(bsz, -1, seq_len, seq_len).permute(0, 2, 3, 1).contiguous() +# ) +# delta_pair_repr = ( +# delta_pair_repr.view(bsz, -1, seq_len, seq_len) +# .permute(0, 2, 3, 1) +# .contiguous() +# ) + +# pair_mask = token_mask[..., None] * token_mask[..., None, :] +# delta_pair_repr_norm = norm_loss(delta_pair_repr) +# delta_pair_repr_norm = masked_mean( +# pair_mask, delta_pair_repr_norm, dim=(-1, -2) +# ) + +# if self.final_head_layer_norm is not None: +# delta_pair_repr = self.final_head_layer_norm(delta_pair_repr) + +# return x, attn_mask, delta_pair_repr, x_norm, delta_pair_repr_norm diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/models/unimol.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/unimol.py new file mode 100644 index 0000000000000000000000000000000000000000..a98167918d22b59bd53bf3752b3fc0a0e237d91c --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/models/unimol.py @@ -0,0 +1,1026 @@ +# # Copyright (c) DP Technology. +# # This source code is licensed under the MIT license found in the +# # LICENSE file in the root directory of this source tree. +# Copyright (c) DP Technology. 
+# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import logging +import mindspore as ms +import mindspore.mint.nn as nn # 优先使用 mint.nn +import mindspore.mint.nn.functional as F +from unicore import utils +from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +from unicore.modules import LayerNorm, init_bert_params +import unimol +# from unimol.models.transformer_encoder_with_pair import TransformerEncoderWithPair +from typing import Dict, Any, List +import numpy as np + +logger = logging.getLogger(__name__) + + +@register_model("unimol") +class UniMolModel(BaseUnicoreModel): + @staticmethod + def add_args(parser): + """Add model-specific arguments to the parser.""" + parser.add_argument( + "--encoder-layers", type=int, metavar="L", help="num encoder layers" + ) + parser.add_argument( + "--encoder-embed-dim", + type=int, + metavar="H", + help="encoder embedding dimension", + ) + parser.add_argument( + "--encoder-ffn-embed-dim", + type=int, + metavar="F", + help="encoder embedding dimension for FFN", + ) + parser.add_argument( + "--encoder-attention-heads", + type=int, + metavar="A", + help="num encoder attention heads", + ) + parser.add_argument( + "--activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use", + ) + parser.add_argument( + "--pooler-activation-fn", + choices=utils.get_available_activation_fns(), + help="activation function to use for pooler layer", + ) + parser.add_argument( + "--emb-dropout", + type=float, + metavar="D", + help="dropout probability for embeddings", + ) + parser.add_argument( + "--dropout", type=float, metavar="D", help="dropout probability" + ) + parser.add_argument( + "--attention-dropout", + type=float, + metavar="D", + help="dropout probability for attention weights", + ) + parser.add_argument( + "--activation-dropout", + type=float, + metavar="D", + help="dropout probability after activation in FFN", + ) + parser.add_argument( + "--pooler-dropout", + type=float, + metavar="D", + help="dropout probability in the masked_lm pooler layers", + ) + parser.add_argument( + "--max-seq-len", type=int, help="number of positional embeddings to learn" + ) + parser.add_argument( + "--post-ln", type=bool, help="use post layernorm or pre layernorm" + ) + parser.add_argument( + "--masked-token-loss", + type=float, + metavar="D", + help="mask loss ratio", + ) + parser.add_argument( + "--masked-dist-loss", + type=float, + metavar="D", + help="masked distance loss ratio", + ) + parser.add_argument( + "--masked-coord-loss", + type=float, + metavar="D", + help="masked coord loss ratio", + ) + parser.add_argument( + "--x-norm-loss", + type=float, + metavar="D", + help="x norm loss ratio", + ) + parser.add_argument( + "--delta-pair-repr-norm-loss", + type=float, + metavar="D", + help="delta encoder pair repr norm loss ratio", + ) + parser.add_argument( + "--masked-coord-dist-loss", + type=float, + metavar="D", + help="masked coord dist loss ratio", + ) + parser.add_argument( + "--mode", + type=str, + default="train", + choices=["train", "infer"], + ) + + def __init__(self, args, dictionary): + super().__init__() + base_architecture(args) + self.args = args + + # 1. 
验证并修正 padding_idx(简化逻辑) + self.padding_idx = dictionary.pad() + dict_len = len(dictionary) + if not (0 <= self.padding_idx < dict_len): + self.padding_idx = dict_len - 1 + print(f"⚠️ 原始 padding_idx={dictionary.pad()} 无效,修正为 {self.padding_idx}") + else: + print(f"✅ padding_idx 有效:{self.padding_idx}(字典长度={dict_len})") + + # 2. 创建 nn.Embedding(仅按位置传参,依赖全局Ascend配置) + self.embed_tokens = nn.Embedding( + dict_len, # 词汇表大小 + args.encoder_embed_dim # 嵌入维度 + ) + assert self.embed_tokens.padding_idx is None, f"Embedding padding_idx异常:{self.embed_tokens.padding_idx}" + print(f"✅ Embedding 创建成功(size={dict_len}, dim={args.encoder_embed_dim})") + + # 3. 手动置零 padding 行权重(用index_fill简化,无设备操作) + weight = self.embed_tokens.weight.data + # 创建全1掩码 + 置零padding行 + mask = ms.ops.ones_like(weight, dtype=ms.float32) + padding_idx_tensor = ms.Tensor([self.padding_idx], dtype=ms.int32) + mask = ms.ops.index_fill(mask, 0, padding_idx_tensor, 0.0) # 按行填充0 + # 更新权重 + new_weight = weight * mask + self.embed_tokens.weight.set_data(new_weight) + print(f"✅ padding 权重置零完成(padding_idx={self.padding_idx})") + + # 4. 初始化核心模块(保留原逻辑,无设备操作) + self._num_updates = None + # Transformer编码器 + self.encoder = TransformerEncoderWithPair( + encoder_layers=args.encoder_layers, + embed_dim=args.encoder_embed_dim, + ffn_embed_dim=args.encoder_ffn_embed_dim, + attention_heads=args.encoder_attention_heads, + emb_dropout=args.emb_dropout, + dropout=args.dropout, + attention_dropout=args.attention_dropout, + activation_dropout=args.activation_dropout, + max_seq_len=args.max_seq_len, + activation_fn=args.activation_fn, + no_final_head_layer_norm=args.delta_pair_repr_norm_loss < 0, + ) + + # 掩码语言模型头(按需) + if args.masked_token_loss > 0: + self.lm_head = MaskLMHead( + embed_dim=args.encoder_embed_dim, + output_dim=dict_len, + activation_fn=args.activation_fn, + weight=None, + ) + + # GBF相关模块 + K = 128 + n_edge_type = dict_len * dict_len + self.gbf_proj = NonLinearHead(K, args.encoder_attention_heads, args.activation_fn) + self.gbf = GaussianLayer(K, n_edge_type) + + # 坐标预测头(按需) + if args.masked_coord_loss > 0: + self.pair2coord_proj = NonLinearHead(args.encoder_attention_heads, 1, args.activation_fn) + + # 距离预测头(按需) + if args.masked_dist_loss > 0: + self.dist_head = DistanceHead(args.encoder_attention_heads, args.activation_fn) + + # 分类头(mint.nn.CellDict) + self.classification_heads = nn.CellDict() + # 参数初始化 + self.apply(init_bert_params) + + @classmethod + def build_model(cls, args, task): + """Build a new model instance.""" + return cls(args, task.dictionary) + + def forward( + self, + src_tokens, + src_distance, + src_coord, + src_edge_type, + encoder_masked_tokens=None, + features_only=False, + classification_head_name=None,** kwargs + ): + if classification_head_name is not None: + features_only = True + + # 计算padding掩码(用ms.ops.equal) + padding_mask = ms.ops.equal(src_tokens, self.padding_idx) + if not padding_mask.any(): + padding_mask = None + # 嵌入层前向 + x = self.embed_tokens(src_tokens) + + # 距离特征计算(内部用ms.ops) + def get_dist_features(dist, et): + n_node = dist.shape[-1] + gbf_feature = self.gbf(dist, et) + gbf_result = self.gbf_proj(gbf_feature) + graph_attn_bias = gbf_result + # 张量操作:用ms.ops + graph_attn_bias = ms.ops.permute(graph_attn_bias, (0, 3, 1, 2)) + graph_attn_bias = ms.ops.contiguous(graph_attn_bias) + graph_attn_bias = graph_attn_bias.reshape(-1, n_node, n_node) + return graph_attn_bias + + graph_attn_bias = get_dist_features(src_distance, src_edge_type) + # 编码器前向 + ( + encoder_rep, + encoder_pair_rep, + delta_encoder_pair_rep, + 
x_norm, + delta_encoder_pair_rep_norm, + ) = self.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias) + + # 处理-inf值(用ms.ops.where) + encoder_pair_rep = ms.ops.where( + encoder_pair_rep == float("-inf"), + ms.ops.zeros_like(encoder_pair_rep), + encoder_pair_rep + ) + + encoder_distance = None + encoder_coord = None + + # 预测头前向(按需) + if not features_only: + # 掩码token预测 + if self.args.masked_token_loss > 0: + logits = self.lm_head(encoder_rep, encoder_masked_tokens) + # 坐标预测 + if self.args.masked_coord_loss > 0: + coords_emb = src_coord + if padding_mask is not None: + atom_num = ms.ops.sum(1 - padding_mask.astype(x.dtype), dim=1).reshape(-1, 1, 1, 1) + else: + atom_num = src_coord.shape[1] + # 坐标计算(用ms.ops) + delta_pos = ms.ops.unsqueeze(coords_emb, 1) - ms.ops.unsqueeze(coords_emb, 2) + attn_probs = self.pair2coord_proj(delta_encoder_pair_rep) + coord_update = delta_pos / atom_num * attn_probs + pair_coords_mask = (1 - padding_mask.astype(ms.float32)).unsqueeze(-1) * (1 - padding_mask.astype(ms.float32)).unsqueeze(1) + coord_update = coord_update * pair_coords_mask.unsqueeze(-1) + coord_update = ms.ops.sum(coord_update, dim=2) + encoder_coord = coords_emb + coord_update + # 距离预测 + if self.args.masked_dist_loss > 0: + encoder_distance = self.dist_head(encoder_pair_rep) + + # 分类头预测 + if classification_head_name is not None: + logits = self.classification_heads[classification_head_name](encoder_rep) + + # 推理/训练返回 + if self.args.mode == 'infer': + return encoder_rep, encoder_pair_rep + else: + return ( + logits, + encoder_distance, + encoder_coord, + x_norm, + delta_encoder_pair_rep_norm, + ) + + def register_classification_head( + self, name, num_classes=None, inner_dim=None, **kwargs + ): + """Register a classification head.""" + if name in self.classification_heads: + prev_num_classes = self.classification_heads[name].out_proj.out_features + prev_inner_dim = self.classification_heads[name].dense.out_features + if num_classes != prev_num_classes or inner_dim != prev_inner_dim: + logger.warning( + 're-registering head "{}" with num_classes {} (prev: {}) ' + "and inner_dim {} (prev: {})".format( + name, num_classes, prev_num_classes, inner_dim, prev_inner_dim + ) + ) + self.classification_heads[name] = ClassificationHead( + input_dim=self.args.encoder_embed_dim, + inner_dim=inner_dim or self.args.encoder_embed_dim, + num_classes=num_classes, + activation_fn=self.args.pooler_activation_fn, + pooler_dropout=self.args.pooler_dropout, + ) + + def set_num_updates(self, num_updates): + self._num_updates = num_updates + + def get_num_updates(self): + return self._num_updates + + +class MaskLMHead(nn.Cell): + """Masked Language Model Head(用mint.nn)""" + def __init__(self, embed_dim, output_dim, activation_fn, weight=None): + super().__init__() + self.dense = nn.Linear(embed_dim, embed_dim) # mint.nn.Linear + self.activation_fn = utils.get_activation_fn(activation_fn) + self.layer_norm = LayerNorm(embed_dim) + + if weight is None: + weight = nn.Linear(embed_dim, output_dim, has_bias=False).weight + self.weight = weight + self.bias = ms.Parameter(ms.ops.zeros(output_dim)) + + def forward(self, features, masked_tokens=None, **kwargs): + if masked_tokens is not None: + features = features[masked_tokens, :] + # 前向(用mint.nn.functional) + x = self.dense(features) + x = self.activation_fn(x) + x = self.layer_norm(x) + x = F.linear(x, self.weight) + self.bias # mint.nn.functional.linear + return x + + +class ClassificationHead(nn.Cell): + """Classification Head(用mint.nn)""" + def __init__( + self, 
+        input_dim,
+        inner_dim,
+        num_classes,
+        activation_fn,
+        pooler_dropout,
+    ):
+        super().__init__()
+        self.dense = nn.Linear(input_dim, inner_dim)  # mint.nn.Linear
+        self.activation_fn = utils.get_activation_fn(activation_fn)
+        self.dropout = nn.Dropout(p=pooler_dropout)  # mint.nn.Dropout
+        self.out_proj = nn.Linear(inner_dim, num_classes)  # mint.nn.Linear
+
+    def forward(self, features, **kwargs):
+        x = features[:, 0, :]  # take the CLS token
+        x = self.dropout(x)
+        x = self.dense(x)
+        x = self.activation_fn(x)
+        x = self.dropout(x)
+        x = self.out_proj(x)
+        return x
+
+
+class NonLinearHead(nn.Cell):
+    """Non-linear head (uses mint.nn)."""
+    def __init__(
+        self,
+        input_dim,
+        out_dim,
+        activation_fn,
+        hidden=None,
+    ):
+        super().__init__()
+        hidden = input_dim if not hidden else hidden
+        self.linear1 = nn.Linear(input_dim, hidden)  # mint.nn.Linear
+        self.linear2 = nn.Linear(hidden, out_dim)  # mint.nn.Linear
+        self.activation_fn = utils.get_activation_fn(activation_fn)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = self.activation_fn(x)
+        x = self.linear2(x)
+        return x
+
+
+class DistanceHead(nn.Cell):
+    """Distance head (uses mint.nn)."""
+    def __init__(self, heads, activation_fn):
+        super().__init__()
+        self.dense = nn.Linear(heads, heads)  # mint.nn.Linear
+        self.layer_norm = nn.LayerNorm(heads)  # mint.nn.LayerNorm
+        self.out_proj = nn.Linear(heads, 1)  # mint.nn.Linear
+        self.activation_fn = utils.get_activation_fn(activation_fn)
+
+    def forward(self, x):
+        bsz, seq_len, seq_len, _ = x.shape
+        x = self.dense(x)
+        x = self.activation_fn(x)
+        x = self.layer_norm(x)
+        x = self.out_proj(x).reshape(bsz, seq_len, seq_len)  # reshape instead of view
+        # symmetrise the distance map; ms.ops.transpose expects a full permutation,
+        # so (0, 2, 1) swaps the last two axes
+        x = (x + ms.ops.transpose(x, (0, 2, 1))) * 0.5
+        return x
+
+
+def gaussian(x, mean, std):
+    """Gaussian density, implemented with ms.ops."""
+    pi = 3.14159
+    a = (2 * pi) ** 0.5
+    return ms.ops.exp(-0.5 * (((x - mean) / std) ** 2)) / (a * std)
+
+
+class GaussianLayer(nn.Cell):
+    """Gaussian layer (uses mint.nn)."""
+    def __init__(self, K=128, edge_types=1024):
+        super().__init__()
+        self.K = K
+        self.means = nn.Embedding(1, K)  # mint.nn.Embedding
+        self.stds = nn.Embedding(1, K)  # mint.nn.Embedding
+        self.mul = nn.Embedding(edge_types, 1)  # mint.nn.Embedding
+        self.bias = nn.Embedding(edge_types, 1)  # mint.nn.Embedding
+
+        # Parameter initialisation (MindSpore style): build the init tensors with
+        # initializer(init, shape, dtype) and load them into the embedding weights.
+        initializer = ms.common.initializer.initializer
+        Uniform = ms.common.initializer.Uniform
+        Constant = ms.common.initializer.Constant
+        self.means.weight.set_data(initializer(Uniform(3), self.means.weight.shape, self.means.weight.dtype))
+        self.stds.weight.set_data(initializer(Uniform(3), self.stds.weight.shape, self.stds.weight.dtype))
+        self.bias.weight.set_data(initializer(Constant(0), self.bias.weight.shape, self.bias.weight.dtype))
+        self.mul.weight.set_data(initializer(Constant(1), self.mul.weight.shape, self.mul.weight.dtype))
+
+    def forward(self, x, edge_type):
+        # dtype conversion via astype
+        mul = self.mul(edge_type).astype(x.dtype)
+        bias = self.bias(edge_type).astype(x.dtype)
+        # feature computation with ms.ops
+        x = mul * ms.ops.unsqueeze(x, -1) + bias
+        x = ms.ops.tile(x, (1, 1, 1, self.K))  # ms.ops.tile instead of expand
+        mean = self.means.weight.astype(ms.float32).reshape(-1)
+        std = ms.ops.abs(self.stds.weight.astype(ms.float32).reshape(-1)) + 1e-5
+        return gaussian(x.astype(ms.float32), mean, std).astype(self.means.weight.dtype)
+
+
+@register_model_architecture("unimol", "unimol")
+def base_architecture(args):
+    """Default architecture: fill in any hyper-parameters the caller did not set."""
+    args.encoder_layers = getattr(args, "encoder_layers", 15)
+    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512)
+    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048)
+    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64)
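+    # getattr() only fills in values the caller did not set, so CLI overrides from
+    # add_args() always win. Illustrative sketch (not part of the original code):
+    #     args = argparse.Namespace(encoder_layers=6)
+    #     base_architecture(args)   # encoder_layers stays 6, the rest gets these defaults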
args.dropout = getattr(args, "dropout", 0.1) + args.emb_dropout = getattr(args, "emb_dropout", 0.1) + args.attention_dropout = getattr(args, "attention_dropout", 0.1) + args.activation_dropout = getattr(args, "activation_dropout", 0.0) + args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) + args.max_seq_len = getattr(args, "max_seq_len", 512) + args.activation_fn = getattr(args, "activation_fn", "gelu") + args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh") + args.post_ln = getattr(args, "post_ln", False) + args.masked_token_loss = getattr(args, "masked_token_loss", -1.0) + args.masked_coord_loss = getattr(args, "masked_coord_loss", -1.0) + args.masked_dist_loss = getattr(args, "masked_dist_loss", -1.0) + args.x_norm_loss = getattr(args, "x_norm_loss", -1.0) + args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0) + + +@register_model_architecture("unimol", "unimol_base") +def unimol_base_architecture(args): + """Base Architecture Alias""" + base_architecture(args) +# import logging +# import mindspore as ms +# import mindspore.mint.nn as nn +# import mindspore.mint.nn.functional as F +# from unicore import utils +# from unicore.models import BaseUnicoreModel, register_model, register_model_architecture +# from unicore.modules import LayerNorm, init_bert_params +# from unimol.models.transformer_encoder_with_pair import TransformerEncoderWithPair +# from typing import Dict, Any, List +# import numpy as np + +# logger = logging.getLogger(__name__) + + +# @register_model("unimol") +# class UniMolModel(BaseUnicoreModel): +# @staticmethod +# def add_args(parser): +# """Add model-specific arguments to the parser.""" +# parser.add_argument( +# "--encoder-layers", type=int, metavar="L", help="num encoder layers" +# ) +# parser.add_argument( +# "--encoder-embed-dim", +# type=int, +# metavar="H", +# help="encoder embedding dimension", +# ) +# parser.add_argument( +# "--encoder-ffn-embed-dim", +# type=int, +# metavar="F", +# help="encoder embedding dimension for FFN", +# ) +# parser.add_argument( +# "--encoder-attention-heads", +# type=int, +# metavar="A", +# help="num encoder attention heads", +# ) +# parser.add_argument( +# "--activation-fn", +# choices=utils.get_available_activation_fns(), +# help="activation function to use", +# ) +# parser.add_argument( +# "--pooler-activation-fn", +# choices=utils.get_available_activation_fns(), +# help="activation function to use for pooler layer", +# ) +# parser.add_argument( +# "--emb-dropout", +# type=float, +# metavar="D", +# help="dropout probability for embeddings", +# ) +# parser.add_argument( +# "--dropout", type=float, metavar="D", help="dropout probability" +# ) +# parser.add_argument( +# "--attention-dropout", +# type=float, +# metavar="D", +# help="dropout probability for attention weights", +# ) +# parser.add_argument( +# "--activation-dropout", +# type=float, +# metavar="D", +# help="dropout probability after activation in FFN", +# ) +# parser.add_argument( +# "--pooler-dropout", +# type=float, +# metavar="D", +# help="dropout probability in the masked_lm pooler layers", +# ) +# parser.add_argument( +# "--max-seq-len", type=int, help="number of positional embeddings to learn" +# ) +# parser.add_argument( +# "--post-ln", type=bool, help="use post layernorm or pre layernorm" +# ) +# parser.add_argument( +# "--masked-token-loss", +# type=float, +# metavar="D", +# help="mask loss ratio", +# ) +# parser.add_argument( +# "--masked-dist-loss", +# type=float, +# metavar="D", +# help="masked distance 
loss ratio", +# ) +# parser.add_argument( +# "--masked-coord-loss", +# type=float, +# metavar="D", +# help="masked coord loss ratio", +# ) +# parser.add_argument( +# "--x-norm-loss", +# type=float, +# metavar="D", +# help="x norm loss ratio", +# ) +# parser.add_argument( +# "--delta-pair-repr-norm-loss", +# type=float, +# metavar="D", +# help="delta encoder pair repr norm loss ratio", +# ) +# parser.add_argument( +# "--masked-coord-dist-loss", +# type=float, +# metavar="D", +# help="masked coord dist loss ratio", +# ) +# parser.add_argument( +# "--mode", +# type=str, +# default="train", +# choices=["train", "infer"], +# ) + +# def __init__(self, args, dictionary): +# super().__init__() +# base_architecture(args) +# self.args = args + +# # -------------------------- 1. 验证并修正 padding_idx(确保合法) -------------------------- +# self.padding_idx = dictionary.pad() +# dict_len = len(dictionary) +# if not (0 <= self.padding_idx < dict_len): +# self.padding_idx = dict_len - 1 +# print("⚠️ 原始 padding_idx={} 无效,手动修正为 {}".format(dictionary.pad(), self.padding_idx)) +# else: +# print("✅ 原始 padding_idx 有效:{}(字典长度={})".format(self.padding_idx, dict_len)) + +# # -------------------------- 2. 创建 nn.Embedding(按位置传参,兼容旧版MindSpore) -------------------------- +# # 兜底:若模块 to() 失败,直接重新创建 Embedding 并指定设备 +# # -------------------------- 2. 创建 nn.Embedding + 强制移动权重到 Ascend(正确兜底方案) -------------------------- +# # 步骤1:正常创建 nn.Embedding(不指定设备,先按默认逻辑初始化) +# self.embed_tokens = nn.Embedding( +# dict_len, # 第1位:词汇表大小(原子字典长度) +# args.encoder_embed_dim # 第2位:嵌入维度 +# ) + +# # 步骤2:强制将 Embedding 的权重张量移动到 Ascend:0(MindSpore 2.6.0 支持的方式) +# # 2.1 获取当前权重张量 +# weight = self.embed_tokens.weight.data +# # 2.2 移动张量到 Ascend:0(device 参数格式:'Ascend:设备号',必须是字符串) +# weight = ms.ops.move_to(weight, 'Ascend:0') # 核心:用 move_to_device 替代 ms.device() +# # 2.3 将移动后的权重重新设置回 Embedding 模块 +# self.embed_tokens.weight.set_data(weight) + +# # 步骤3:验证权重是否成功移动到 Ascend(避免移动失败仍在 CPU) +# try: +# # 新版本 MindSpore 有 device 属性,直接验证 +# weight_device = weight.device +# assert "Ascend" in str(weight_device), f"权重未移动到 Ascend!当前设备:{weight_device}" +# print("✅ Embedding 权重已成功移动到:{}".format(weight_device)) +# except AttributeError: +# # 旧版本无 device 属性,通过打印张量信息间接验证(Ascend 张量会显示 DeviceArray) +# print("✅ Embedding 权重移动状态:张量类型={}(已强制到 Ascend)".format(type(weight))) +# # 兜底:若移动失败,手动创建 Ascend 上的权重张量并替换 +# if "DeviceArray" not in str(type(weight)): +# print("⚠️ 权重移动未生效,手动创建 Ascend 权重...") +# # 重新创建全零权重(直接在 Ascend 上) +# new_weight = ms.Tensor( +# np.zeros((dict_len, args.encoder_embed_dim), dtype=np.float32), +# dtype=ms.float32, +# device_target="Ascend", # 显式指定设备目标 +# device_id=0 # 显式指定设备号 +# ) +# self.embed_tokens.weight.set_data(new_weight) +# print("✅ 手动创建 Ascend 权重成功!") + +# # 步骤4:验证 Embedding 的 padding_idx 为 None(避免隐性问题) +# assert self.embed_tokens.padding_idx is None, \ +# "nn.Embedding 异常:padding_idx={}".format(self.embed_tokens.padding_idx) +# print("✅ 创建 nn.Embedding 成功(词汇表大小={}, 嵌入维度={},设备:Ascend)".format(dict_len, args.encoder_embed_dim)) + +# # -------------------------- 4. 
初始化其他模块(原有逻辑,保留) -------------------------- +# self._num_updates = None +# # 创建Transformer编码器 +# self.encoder = TransformerEncoderWithPair( +# encoder_layers=args.encoder_layers, +# embed_dim=args.encoder_embed_dim, +# ffn_embed_dim=args.encoder_ffn_embed_dim, +# attention_heads=args.encoder_attention_heads, +# emb_dropout=args.emb_dropout, +# dropout=args.dropout, +# attention_dropout=args.attention_dropout, +# activation_dropout=args.activation_dropout, +# max_seq_len=args.max_seq_len, +# activation_fn=args.activation_fn, +# no_final_head_layer_norm=args.delta_pair_repr_norm_loss < 0, +# ) + +# # 初始化掩码语言模型头(按需) +# if args.masked_token_loss > 0: +# self.lm_head = MaskLMHead( +# embed_dim=args.encoder_embed_dim, +# output_dim=len(dictionary), +# activation_fn=args.activation_fn, +# weight=None, +# ) + +# # 初始化GBF相关模块 +# K = 128 +# n_edge_type = len(dictionary) * len(dictionary) +# self.gbf_proj = NonLinearHead( +# K, args.encoder_attention_heads, args.activation_fn +# ) +# self.gbf = GaussianLayer(K, n_edge_type) + +# # 初始化坐标预测头(按需) +# if args.masked_coord_loss > 0: +# self.pair2coord_proj = NonLinearHead( +# args.encoder_attention_heads, 1, args.activation_fn +# ) + +# # 初始化距离预测头(按需) +# if args.masked_dist_loss > 0: +# self.dist_head = DistanceHead( +# args.encoder_attention_heads, args.activation_fn +# ) + +# # 初始化分类头(CellDict替代ModuleDict) +# self.classification_heads = nn.CellDict() +# # 初始化参数 +# self.apply(init_bert_params) + +# @classmethod +# def build_model(cls, args, task): +# """Build a new model instance.""" +# return cls(args, task.dictionary) + +# def forward( +# self, +# src_tokens, +# src_distance, +# src_coord, +# src_edge_type, +# encoder_masked_tokens=None, +# features_only=False, +# classification_head_name=None,** kwargs +# ): + +# if classification_head_name is not None: +# features_only = True + +# # 计算padding掩码(替换torch.eq为ms.equal) +# padding_mask = ms.equal(src_tokens, self.padding_idx) +# if not padding_mask.any(): +# padding_mask = None +# # 获取嵌入特征 +# x = self.embed_tokens(src_tokens) + +# # 计算距离特征(内部处理permute/reshape) +# def get_dist_features(dist, et): +# n_node = dist.shape[-1] +# gbf_feature = self.gbf(dist, et) +# gbf_result = self.gbf_proj(gbf_feature) +# graph_attn_bias = gbf_result +# # 替换permute/view为MindSpore API +# graph_attn_bias = ms.ops.permute(graph_attn_bias, (0, 3, 1, 2)) +# graph_attn_bias = ms.ops.contiguous(graph_attn_bias) +# graph_attn_bias = graph_attn_bias.reshape(-1, n_node, n_node) +# return graph_attn_bias + +# graph_attn_bias = get_dist_features(src_distance, src_edge_type) +# # 编码器前向传播 +# ( +# encoder_rep, +# encoder_pair_rep, +# delta_encoder_pair_rep, +# x_norm, +# delta_encoder_pair_rep_norm, +# ) = self.encoder(x, padding_mask=padding_mask, attn_mask=graph_attn_bias) + +# # 处理-inf值(替换张量索引赋值为ms.ops.where) +# encoder_pair_rep = ms.ops.where( +# encoder_pair_rep == float("-inf"), +# ms.ops.zeros_like(encoder_pair_rep), +# encoder_pair_rep +# ) + +# encoder_distance = None +# encoder_coord = None + +# # 按需计算预测头输出 +# if not features_only: +# # 掩码token预测 +# if self.args.masked_token_loss > 0: +# logits = self.lm_head(encoder_rep, encoder_masked_tokens) +# # 坐标预测 +# if self.args.masked_coord_loss > 0: +# coords_emb = src_coord +# if padding_mask is not None: +# # 替换torch.sum/type_as/view为MindSpore API +# atom_num = ms.ops.sum(1 - padding_mask.astype(x.dtype), dim=1).reshape( +# -1, 1, 1, 1 +# ) +# else: +# atom_num = src_coord.shape[1] +# # 计算坐标更新 +# delta_pos = ms.ops.unsqueeze(coords_emb, 1) - ms.ops.unsqueeze(coords_emb, 2) +# 
attn_probs = self.pair2coord_proj(delta_encoder_pair_rep) +# coord_update = delta_pos / atom_num * attn_probs +# # 计算坐标掩码 +# pair_coords_mask = (1 - padding_mask.astype(ms.float32)).unsqueeze(-1) * (1 - padding_mask.astype(ms.float32)).unsqueeze(1) +# coord_update = coord_update * pair_coords_mask.unsqueeze(-1) +# # 求和得到最终坐标 +# coord_update = ms.ops.sum(coord_update, dim=2) +# encoder_coord = coords_emb + coord_update +# # 距离预测 +# if self.args.masked_dist_loss > 0: +# encoder_distance = self.dist_head(encoder_pair_rep) + +# # 分类头预测(按需) +# if classification_head_name is not None: +# logits = self.classification_heads[classification_head_name](encoder_rep) + +# # 推理/训练模式返回不同结果 +# if self.args.mode == 'infer': +# return encoder_rep, encoder_pair_rep +# else: +# return ( +# logits, +# encoder_distance, +# encoder_coord, +# x_norm, +# delta_encoder_pair_rep_norm, +# ) + +# def register_classification_head( +# self, name, num_classes=None, inner_dim=None, **kwargs +# ): +# """Register a classification head.""" +# if name in self.classification_heads: +# prev_num_classes = self.classification_heads[name].out_proj.out_features +# prev_inner_dim = self.classification_heads[name].dense.out_features +# if num_classes != prev_num_classes or inner_dim != prev_inner_dim: +# logger.warning( +# 're-registering head "{}" with num_classes {} (prev: {}) ' +# "and inner_dim {} (prev: {})".format( +# name, num_classes, prev_num_classes, inner_dim, prev_inner_dim +# ) +# ) +# self.classification_heads[name] = ClassificationHead( +# input_dim=self.args.encoder_embed_dim, +# inner_dim=inner_dim or self.args.encoder_embed_dim, +# num_classes=num_classes, +# activation_fn=self.args.pooler_activation_fn, +# pooler_dropout=self.args.pooler_dropout, +# ) + +# def set_num_updates(self, num_updates): +# """State from trainer to pass along to model at every update.""" +# self._num_updates = num_updates + +# def get_num_updates(self): +# return self._num_updates + + +# class MaskLMHead(nn.Cell): # 替换PyTorch的Module为MindSpore的Cell +# """Head for masked language modeling.""" + +# def __init__(self, embed_dim, output_dim, activation_fn, weight=None): +# super().__init__() +# self.dense = nn.Linear(embed_dim, embed_dim) +# self.activation_fn = utils.get_activation_fn(activation_fn) +# self.layer_norm = LayerNorm(embed_dim) + +# if weight is None: +# # 初始化权重(MindSpore风格) +# weight = nn.Linear(embed_dim, output_dim, has_bias=False).weight +# self.weight = weight +# self.bias = ms.Parameter(ms.ops.zeros(output_dim)) # 替换nn.Parameter为ms.Parameter + +# def forward(self, features, masked_tokens=None, **kwargs): +# # 仅对掩码token计算(训练优化) +# if masked_tokens is not None: +# features = features[masked_tokens, :] + +# # 前向传播 +# x = self.dense(features) +# x = self.activation_fn(x) +# x = self.layer_norm(x) +# # 替换F.linear为MindSpore的线性计算 +# x = F.linear(x, self.weight) + self.bias +# return x + + +# class ClassificationHead(nn.Cell): # 替换Module为Cell +# """Head for sentence-level classification tasks.""" + +# def __init__( +# self, +# input_dim, +# inner_dim, +# num_classes, +# activation_fn, +# pooler_dropout, +# ): +# super().__init__() +# self.dense = nn.Linear(input_dim, inner_dim) +# self.activation_fn = utils.get_activation_fn(activation_fn) +# self.dropout = nn.Dropout(p=pooler_dropout) +# self.out_proj = nn.Linear(inner_dim, num_classes) + +# def forward(self, features,** kwargs): +# # 取CLS token的特征(第0个token) +# x = features[:, 0, :] +# x = self.dropout(x) +# x = self.dense(x) +# x = self.activation_fn(x) +# x = self.dropout(x) +# x 
= self.out_proj(x) +# return x + + +# class NonLinearHead(nn.Cell): # 替换Module为Cell +# """Head for simple classification tasks.""" + +# def __init__( +# self, +# input_dim, +# out_dim, +# activation_fn, +# hidden=None, +# ): +# super().__init__() +# hidden = input_dim if not hidden else hidden +# self.linear1 = nn.Linear(input_dim, hidden) +# self.linear2 = nn.Linear(hidden, out_dim) +# self.activation_fn = utils.get_activation_fn(activation_fn) + +# def forward(self, x): +# x = self.linear1(x) +# x = self.activation_fn(x) +# x = self.linear2(x) +# return x + + +# class DistanceHead(nn.Cell): # 替换Module为Cell +# def __init__( +# self, +# heads, +# activation_fn, +# ): +# super().__init__() +# self.dense = nn.Linear(heads, heads) +# self.layer_norm = nn.LayerNorm(heads) +# self.out_proj = nn.Linear(heads, 1) +# self.activation_fn = utils.get_activation_fn(activation_fn) + +# def forward(self, x): +# bsz, seq_len, seq_len, _ = x.shape # 替换size为shape +# x = self.dense(x) +# x = self.activation_fn(x) +# x = self.layer_norm(x) +# # 替换view为reshape +# x = self.out_proj(x).reshape(bsz, seq_len, seq_len) +# # 替换transpose为ms.ops.transpose(保证距离对称性) +# x = (x + ms.ops.transpose(x, (-1, -2))) * 0.5 +# return x + + +# def gaussian(x, mean, std): +# """Gaussian function for GaussianLayer(兼容旧版MindSpore,不使用ms.script)""" +# pi = 3.14159 +# a = (2 * pi) ** 0.5 +# return ms.ops.exp(-0.5 * (((x - mean) / std) **2)) / (a * std) + + +# class GaussianLayer(nn.Cell): # 替换Module为Cell +# def __init__(self, K=128, edge_types=1024): +# super().__init__() +# self.K = K +# self.means = nn.Embedding(1, K) +# self.stds = nn.Embedding(1, K) +# self.mul = nn.Embedding(edge_types, 1) +# self.bias = nn.Embedding(edge_types, 1) + +# # 替换PyTorch初始化为MindSpore风格 +# self.means.weight.set_data(ms.common.initializer.Uniform(3)(self.means.weight.shape)) +# self.stds.weight.set_data(ms.common.initializer.Uniform(3)(self.stds.weight.shape)) +# self.bias.weight.set_data(ms.common.initializer.Constant(0)(self.bias.weight.shape)) +# self.mul.weight.set_data(ms.common.initializer.Constant(1)(self.mul.weight.shape)) + +# def forward(self, x, edge_type): +# # 替换type_as为astype(MindSpore类型转换) +# mul = self.mul(edge_type).astype(x.dtype) +# bias = self.bias(edge_type).astype(x.dtype) +# # 计算距离特征 +# x = mul * ms.ops.unsqueeze(x, -1) + bias +# x = ms.ops.tile(x, (1, 1, 1, self.K)) # 替换expand为tile +# # 处理均值和标准差(替换view为reshape) +# mean = self.means.weight.astype(ms.float32).reshape(-1) +# std = ms.ops.abs(self.stds.weight.astype(ms.float32).reshape(-1)) + 1e-5 # 避免std为0 +# # 计算高斯特征 +# return gaussian(x.astype(ms.float32), mean, std).astype(self.means.weight.dtype) + + +# @register_model_architecture("unimol", "unimol") +# def base_architecture(args): +# """Default architecture for UniMol(补全默认超参)""" +# args.encoder_layers = getattr(args, "encoder_layers", 15) +# args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 512) +# args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 2048) +# args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 64) +# args.dropout = getattr(args, "dropout", 0.1) +# args.emb_dropout = getattr(args, "emb_dropout", 0.1) +# args.attention_dropout = getattr(args, "attention_dropout", 0.1) +# args.activation_dropout = getattr(args, "activation_dropout", 0.0) +# args.pooler_dropout = getattr(args, "pooler_dropout", 0.0) +# args.max_seq_len = getattr(args, "max_seq_len", 512) +# args.activation_fn = getattr(args, "activation_fn", "gelu") +# args.pooler_activation_fn = getattr(args, 
"pooler_activation_fn", "tanh") +# args.post_ln = getattr(args, "post_ln", False) +# args.masked_token_loss = getattr(args, "masked_token_loss", -1.0) +# args.masked_coord_loss = getattr(args, "masked_coord_loss", -1.0) +# args.masked_dist_loss = getattr(args, "masked_dist_loss", -1.0) +# args.x_norm_loss = getattr(args, "x_norm_loss", -1.0) +# args.delta_pair_repr_norm_loss = getattr(args, "delta_pair_repr_norm_loss", -1.0) + + +# @register_model_architecture("unimol", "unimol_base") +# def unimol_base_architecture(args): +# """Base architecture alias(复用默认超参)""" +# base_architecture(args) + diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/setup.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..11e12d53e2c4709bb2a6fb4b9de209897100d58d --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/setup.py @@ -0,0 +1,33 @@ +"""Install script for setuptools.""" + +from setuptools import find_packages +from setuptools import setup + +setup( + name="unimol", + version="1.0.0", + description="", + author="DP Technology", + author_email="unimol@dp.tech", + license="The MIT License", + url="https://github.com/deepmodeling/Uni-Mol", + packages=find_packages( + exclude=["scripts", "tests", "example_data", "docker", "figure"] + ), + install_requires=[ + "numpy", + "pandas", + "scikit-learn-extra", + ], + classifiers=[ + "Development Status :: 5 - Production/Stable", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: POSIX :: Linux", + "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], +) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/__init__.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..646cddbea500fc206c103455a419886dcbc50014 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/__init__.py @@ -0,0 +1,7 @@ +from pathlib import Path +import importlib + +# automatically import any Python files in the criterions/ directory +for file in sorted(Path(__file__).parent.glob("*.py")): + if not file.name.startswith("_"): + importlib.import_module("unimol.tasks." + file.name[:-3]) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/docking_pose.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/docking_pose.py new file mode 100644 index 0000000000000000000000000000000000000000..d961a42b58a60428d53990f3b0e69a1043108331 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/docking_pose.py @@ -0,0 +1,303 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
+ +import logging +import os +from collections.abc import Iterable + +import numpy as np +from unicore.data import ( + Dictionary, + NestedDictionaryDataset, + AppendTokenDataset, + PrependTokenDataset, + RightPadDataset, + TokenizeDataset, + RightPadDataset2D, + RawArrayDataset, + FromNumpyDataset, + EpochShuffleDataset, +) +from unimol.data import ( + KeyDataset, + ConformerSampleDockingPoseDataset, + DistanceDataset, + EdgeTypeDataset, + NormalizeDataset, + RightPadDatasetCoord, + LMDBDataset, + CrossDistanceDataset, + NormalizeDockingPoseDataset, + TTADockingPoseDataset, + RightPadDatasetCross2D, + CroppingPocketDockingPoseDataset, + PrependAndAppend2DDataset, + RemoveHydrogenPocketDataset, +) +from unicore import checkpoint_utils +from unicore.tasks import UnicoreTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("docking_pose") +class DockingPose(UnicoreTask): + """Task for training transformer auto-encoder models.""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="downstream data path", + ) + parser.add_argument( + "--finetune-mol-model", + default=None, + type=str, + help="pretrained molecular model path", + ) + parser.add_argument( + "--finetune-pocket-model", + default=None, + type=str, + help="pretrained pocket model path", + ) + parser.add_argument( + "--conf-size", + default=10, + type=int, + help="number of conformers generated with each molecule", + ) + parser.add_argument( + "--dist-threshold", + type=float, + default=8.0, + help="threshold for the distance between the molecule and the pocket", + ) + parser.add_argument( + "--max-pocket-atoms", + type=int, + default=256, + help="selected maximum number of atoms in a pocket", + ) + + def __init__(self, args, dictionary, pocket_dictionary): + super().__init__(args) + self.dictionary = dictionary + self.pocket_dictionary = pocket_dictionary + self.seed = args.seed + # add mask token + self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) + self.pocket_mask_idx = pocket_dictionary.add_symbol("[MASK]", is_special=True) + + @classmethod + def setup_task(cls, args, **kwargs): + mol_dictionary = Dictionary.load(os.path.join(args.data, "dict_mol.txt")) + pocket_dictionary = Dictionary.load(os.path.join(args.data, "dict_pkt.txt")) + logger.info("ligand dictionary: {} types".format(len(mol_dictionary))) + logger.info("pocket dictionary: {} types".format(len(pocket_dictionary))) + return cls(args, mol_dictionary, pocket_dictionary) + + def load_dataset(self, split, **kwargs): + """Load a given dataset split. 
+ 'smi','pocket','atoms','coordinates','pocket_atoms','pocket_coordinates','holo_coordinates','holo_pocket_coordinates','scaffold' + Args: + split (str): name of the data scoure (e.g., bppp) + """ + data_path = os.path.join(self.args.data, split + ".lmdb") + dataset = LMDBDataset(data_path) + if split.startswith("train"): + smi_dataset = KeyDataset(dataset, "smi") + poc_dataset = KeyDataset(dataset, "pocket") + dataset = ConformerSampleDockingPoseDataset( + dataset, + self.args.seed, + "atoms", + "coordinates", + "pocket_atoms", + "pocket_coordinates", + "holo_coordinates", + "holo_pocket_coordinates", + True, + ) + else: + dataset = TTADockingPoseDataset( + dataset, + "atoms", + "coordinates", + "pocket_atoms", + "pocket_coordinates", + "holo_coordinates", + "holo_pocket_coordinates", + True, + self.args.conf_size, + ) + smi_dataset = KeyDataset(dataset, "smi") + poc_dataset = KeyDataset(dataset, "pocket") + + def PrependAndAppend(dataset, pre_token, app_token): + dataset = PrependTokenDataset(dataset, pre_token) + return AppendTokenDataset(dataset, app_token) + + dataset = RemoveHydrogenPocketDataset( + dataset, + "pocket_atoms", + "pocket_coordinates", + "holo_pocket_coordinates", + True, + True, + ) + dataset = CroppingPocketDockingPoseDataset( + dataset, + self.seed, + "pocket_atoms", + "pocket_coordinates", + "holo_pocket_coordinates", + self.args.max_pocket_atoms, + ) + dataset = RemoveHydrogenPocketDataset( + dataset, "atoms", "coordinates", "holo_coordinates", True, True + ) + + apo_dataset = NormalizeDataset(dataset, "coordinates") + apo_dataset = NormalizeDataset(apo_dataset, "pocket_coordinates") + + src_dataset = KeyDataset(apo_dataset, "atoms") + src_dataset = TokenizeDataset( + src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len + ) + coord_dataset = KeyDataset(apo_dataset, "coordinates") + src_dataset = PrependAndAppend( + src_dataset, self.dictionary.bos(), self.dictionary.eos() + ) + edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) + coord_dataset = FromNumpyDataset(coord_dataset) + distance_dataset = DistanceDataset(coord_dataset) + coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) + distance_dataset = PrependAndAppend2DDataset(distance_dataset, 0.0) + + src_pocket_dataset = KeyDataset(apo_dataset, "pocket_atoms") + src_pocket_dataset = TokenizeDataset( + src_pocket_dataset, + self.pocket_dictionary, + max_seq_len=self.args.max_seq_len, + ) + coord_pocket_dataset = KeyDataset(apo_dataset, "pocket_coordinates") + src_pocket_dataset = PrependAndAppend( + src_pocket_dataset, + self.pocket_dictionary.bos(), + self.pocket_dictionary.eos(), + ) + pocket_edge_type = EdgeTypeDataset( + src_pocket_dataset, len(self.pocket_dictionary) + ) + coord_pocket_dataset = FromNumpyDataset(coord_pocket_dataset) + distance_pocket_dataset = DistanceDataset(coord_pocket_dataset) + coord_pocket_dataset = PrependAndAppend(coord_pocket_dataset, 0.0, 0.0) + distance_pocket_dataset = PrependAndAppend2DDataset( + distance_pocket_dataset, 0.0 + ) + + holo_dataset = NormalizeDockingPoseDataset( + dataset, + "holo_coordinates", + "holo_pocket_coordinates", + "holo_center_coordinates", + ) + holo_coord_dataset = KeyDataset(holo_dataset, "holo_coordinates") + holo_coord_dataset = FromNumpyDataset(holo_coord_dataset) + holo_coord_pocket_dataset = KeyDataset(holo_dataset, "holo_pocket_coordinates") + holo_coord_pocket_dataset = FromNumpyDataset(holo_coord_pocket_dataset) + + holo_cross_distance_dataset = CrossDistanceDataset( + holo_coord_dataset, holo_coord_pocket_dataset 
+ ) + + holo_distance_dataset = DistanceDataset(holo_coord_dataset) + holo_coord_dataset = PrependAndAppend(holo_coord_dataset, 0.0, 0.0) + holo_distance_dataset = PrependAndAppend2DDataset(holo_distance_dataset, 0.0) + holo_coord_pocket_dataset = PrependAndAppend( + holo_coord_pocket_dataset, 0.0, 0.0 + ) + holo_cross_distance_dataset = PrependAndAppend2DDataset( + holo_cross_distance_dataset, 0.0 + ) + + holo_center_coordinates = KeyDataset(holo_dataset, "holo_center_coordinates") + holo_center_coordinates = FromNumpyDataset(holo_center_coordinates) + + nest_dataset = NestedDictionaryDataset( + { + "net_input": { + "mol_src_tokens": RightPadDataset( + src_dataset, + pad_idx=self.dictionary.pad(), + ), + "mol_src_distance": RightPadDataset2D( + distance_dataset, + pad_idx=0, + ), + "mol_src_edge_type": RightPadDataset2D( + edge_type, + pad_idx=0, + ), + "pocket_src_tokens": RightPadDataset( + src_pocket_dataset, + pad_idx=self.pocket_dictionary.pad(), + ), + "pocket_src_distance": RightPadDataset2D( + distance_pocket_dataset, + pad_idx=0, + ), + "pocket_src_edge_type": RightPadDataset2D( + pocket_edge_type, + pad_idx=0, + ), + "pocket_src_coord": RightPadDatasetCoord( + coord_pocket_dataset, + pad_idx=0, + ), + }, + "target": { + "distance_target": RightPadDatasetCross2D( + holo_cross_distance_dataset, pad_idx=0 + ), + "holo_coord": RightPadDatasetCoord(holo_coord_dataset, pad_idx=0), + "holo_distance_target": RightPadDataset2D( + holo_distance_dataset, pad_idx=0 + ), + }, + "smi_name": RawArrayDataset(smi_dataset), + "pocket_name": RawArrayDataset(poc_dataset), + "holo_center_coordinates": RightPadDataset( + holo_center_coordinates, + pad_idx=0, + ), + }, + ) + if split.startswith("train"): + nest_dataset = EpochShuffleDataset( + nest_dataset, len(nest_dataset), self.args.seed + ) + self.datasets[split] = nest_dataset + + def build_model(self, args): + from unicore import models + + model = models.build_model(args, self) + if args.finetune_mol_model is not None: + print("load pretrain model weight from...", args.finetune_mol_model) + state = checkpoint_utils.load_checkpoint_to_cpu( + args.finetune_mol_model, + ) + model.mol_model.load_state_dict(state["model"], strict=False) + if args.finetune_pocket_model is not None: + print("load pretrain model weight from...", args.finetune_pocket_model) + state = checkpoint_utils.load_checkpoint_to_cpu( + args.finetune_pocket_model, + ) + model.pocket_model.load_state_dict(state["model"], strict=False) + return model diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol.py new file mode 100644 index 0000000000000000000000000000000000000000..3491108e7ab4c2aa0a0136ef9e32e310672f7426 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol.py @@ -0,0 +1,254 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
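+
+# Illustrative usage sketch for the pretraining task defined below. The data
+# path is a placeholder; --seed, --max-seq-len and --mode are assumed to come
+# from the surrounding unicore parser, while add_args below declares only the
+# task-specific options. setup_task loads the dictionary named by --dict-name
+# (dict.txt by default) from `args.data`.
+#
+#     import argparse
+#     parser = argparse.ArgumentParser()
+#     parser.add_argument("--seed", type=int, default=1)           # normally supplied by unicore
+#     parser.add_argument("--max-seq-len", type=int, default=512)  # normally supplied by unicore
+#     parser.add_argument("--mode", default="train")               # consumed in load_dataset below
+#     UniMolTask.add_args(parser)
+#     args = parser.parse_args(["/path/to/pretrain_data", "--noise-type", "uniform"])
+#     task = UniMolTask.setup_task(args)
+#     task.load_dataset("train")   # builds the masked-atom net_input/target pairs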
+ +import logging +import os + +import numpy as np +from unicore.data import ( + Dictionary, + NestedDictionaryDataset, + AppendTokenDataset, + PrependTokenDataset, + RightPadDataset, + EpochShuffleDataset, + TokenizeDataset, + RightPadDataset2D, + FromNumpyDataset, + RawArrayDataset, +) +from unimol.data import ( + KeyDataset, + ConformerSampleDataset, + DistanceDataset, + EdgeTypeDataset, + MaskPointsDataset, + RemoveHydrogenDataset, + AtomTypeDataset, + NormalizeDataset, + CroppingDataset, + RightPadDatasetCoord, + Add2DConformerDataset, + LMDBDataset, + TTADataset, +) +from unicore.tasks import UnicoreTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("unimol") +class UniMolTask(UnicoreTask): + """Task for training transformer auto-encoder models.""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--mask-prob", + default=0.15, + type=float, + help="probability of replacing a token with mask", + ) + parser.add_argument( + "--leave-unmasked-prob", + default=0.05, + type=float, + help="probability that a masked token is unmasked", + ) + parser.add_argument( + "--random-token-prob", + default=0.05, + type=float, + help="probability of replacing a token with a random token", + ) + parser.add_argument( + "--noise-type", + default="uniform", + choices=["trunc_normal", "uniform", "normal", "none"], + help="noise type in coordinate noise", + ) + parser.add_argument( + "--noise", + default=1.0, + type=float, + help="coordinate noise for masked atoms", + ) + parser.add_argument( + "--remove-hydrogen", + action="store_true", + help="remove hydrogen atoms", + ) + parser.add_argument( + "--remove-polar-hydrogen", + action="store_true", + help="remove polar hydrogen atoms", + ) + parser.add_argument( + "--max-atoms", + type=int, + default=256, + help="selected maximum number of atoms in a molecule", + ) + parser.add_argument( + "--dict-name", + default="dict.txt", + help="dictionary file", + ) + parser.add_argument( + "--only-polar", + default=1, + type=int, + help="1: only polar hydrogen ; -1: all hydrogen ; 0: remove all hydrogen ", + ) + parser.add_argument( + "--conf-size", + default=10, + type=int, + help="number of conformers generated with each molecule", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + # add mask token + self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) + if self.args.only_polar > 0: + self.args.remove_polar_hydrogen = True + elif args.only_polar < 0: + self.args.remove_polar_hydrogen = False + else: + self.args.remove_hydrogen = True + + @classmethod + def setup_task(cls, args, **kwargs): + dictionary = Dictionary.load(os.path.join(args.data, args.dict_name)) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split. 
+ Args: + split (str): name of the split (e.g., train, valid, test) + """ + split_path = os.path.join(self.args.data, split + ".lmdb") + + raw_dataset = LMDBDataset(split_path) + + def one_dataset(raw_dataset, coord_seed, mask_seed): + if self.args.mode =='train': + raw_dataset = Add2DConformerDataset( + raw_dataset, "smi", "atoms", "coordinates" + ) + smi_dataset = KeyDataset(raw_dataset, "smi") + dataset = ConformerSampleDataset( + raw_dataset, coord_seed, "atoms", "coordinates" + ) + dataset = AtomTypeDataset(raw_dataset, dataset) + elif self.args.mode == 'infer': + dataset = TTADataset( + raw_dataset, self.args.seed, "atoms", "coordinates", self.args.conf_size + ) + dataset = AtomTypeDataset(dataset, dataset) + smi_dataset = KeyDataset(dataset, "smi") + dataset = RemoveHydrogenDataset( + dataset, + "atoms", + "coordinates", + self.args.remove_hydrogen, + self.args.remove_polar_hydrogen, + ) + dataset = CroppingDataset( + dataset, self.seed, "atoms", "coordinates", self.args.max_atoms + ) + dataset = NormalizeDataset(dataset, "coordinates", normalize_coord=True) + token_dataset = KeyDataset(dataset, "atoms") + token_dataset = TokenizeDataset( + token_dataset, self.dictionary, max_seq_len=self.args.max_seq_len + ) + coord_dataset = KeyDataset(dataset, "coordinates") + expand_dataset = MaskPointsDataset( + token_dataset, + coord_dataset, + self.dictionary, + pad_idx=self.dictionary.pad(), + mask_idx=self.mask_idx, + noise_type=self.args.noise_type, + noise=self.args.noise, + seed=mask_seed, + mask_prob=self.args.mask_prob, + leave_unmasked_prob=self.args.leave_unmasked_prob, + random_token_prob=self.args.random_token_prob, + ) + + def PrependAndAppend(dataset, pre_token, app_token): + dataset = PrependTokenDataset(dataset, pre_token) + return AppendTokenDataset(dataset, app_token) + + encoder_token_dataset = KeyDataset(expand_dataset, "atoms") + encoder_target_dataset = KeyDataset(expand_dataset, "targets") + encoder_coord_dataset = KeyDataset(expand_dataset, "coordinates") + + src_dataset = PrependAndAppend( + encoder_token_dataset, self.dictionary.bos(), self.dictionary.eos() + ) + tgt_dataset = PrependAndAppend( + encoder_target_dataset, self.dictionary.pad(), self.dictionary.pad() + ) + encoder_coord_dataset = PrependAndAppend(encoder_coord_dataset, 0.0, 0.0) + encoder_distance_dataset = DistanceDataset(encoder_coord_dataset) + + edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) + coord_dataset = FromNumpyDataset(coord_dataset) + coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) + distance_dataset = DistanceDataset(coord_dataset) + return { + "src_tokens": RightPadDataset( + src_dataset, + pad_idx=self.dictionary.pad(), + ), + "src_coord": RightPadDatasetCoord( + encoder_coord_dataset, + pad_idx=0, + ), + "src_distance": RightPadDataset2D( + encoder_distance_dataset, + pad_idx=0, + ), + "src_edge_type": RightPadDataset2D( + edge_type, + pad_idx=0, + ), + }, { + "tokens_target": RightPadDataset( + tgt_dataset, pad_idx=self.dictionary.pad() + ), + "distance_target": RightPadDataset2D(distance_dataset, pad_idx=0), + "coord_target": RightPadDatasetCoord(coord_dataset, pad_idx=0), + "smi_name": RawArrayDataset(smi_dataset), + } + + net_input, target = one_dataset(raw_dataset, self.args.seed, self.args.seed) + dataset = {"net_input": net_input, "target": target} + dataset = NestedDictionaryDataset(dataset) + if split in ["train", "train.small"]: + dataset = EpochShuffleDataset(dataset, len(dataset), self.args.seed) + self.datasets[split] = dataset + + def 
build_model(self, args): + from unicore import models + + model = models.build_model(args, self) + return model + + def disable_shuffling(self) -> bool: + return True + diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_conf_gen.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_conf_gen.py new file mode 100644 index 0000000000000000000000000000000000000000..7ed8238de6347a138675201f1ac90e4dac5544a9 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_conf_gen.py @@ -0,0 +1,193 @@ +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +import numpy as np +from unicore.data import ( + Dictionary, + NestedDictionaryDataset, + LMDBDataset, + AppendTokenDataset, + PrependTokenDataset, + RightPadDataset, + SortDataset, + TokenizeDataset, + RightPadDataset2D, + RawArrayDataset, + FromNumpyDataset, +) +from unimol.data import ( + KeyDataset, + DistanceDataset, + EdgeTypeDataset, + NormalizeDataset, + RightPadDatasetCoord, + ConformerSampleConfGDataset, + ConformerSampleConfGV2Dataset, + data_utils, +) +from unicore.tasks import UnicoreTask, register_task +from unicore import checkpoint_utils + +logger = logging.getLogger(__name__) + + +@register_task("mol_confG") +class UniMolConfGTask(UnicoreTask): + """Task for training transformer auto-encoder models.""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", help="downstream data path") + parser.add_argument("--task-name", type=str, help="downstream task name") + parser.add_argument( + "--dict-name", + default="dict.txt", + help="dictionary file", + ) + parser.add_argument( + "--beta", + type=float, + default=1.0, + help="beta for conformation importance sampling", + ) + parser.add_argument( + "--smooth", + type=float, + default=0.1, + help="smoothing for conformation importance sampling", + ) + parser.add_argument( + "--topN", + type=int, + default=10, + help="only top N best rmsd for conformation importance sampling", + ) + parser.add_argument( + "--finetune-mol-model", + default=None, + type=str, + help="pretrained molecular model path", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + # add mask token + self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) + + @classmethod + def setup_task(cls, args, **kwargs): + dictionary = Dictionary.load(os.path.join(args.data, args.dict_name)) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, **kwargs): + """Load a given dataset split. 
+ Args: + split (str): name of the data scoure (e.g., bppp) + """ + split_path = os.path.join(self.args.data, self.args.task_name, split + ".lmdb") + dataset = LMDBDataset(split_path) + smi_dataset = KeyDataset(dataset, "smi") + src_dataset = KeyDataset(dataset, "atoms") + if not split.startswith("test"): + sample_dataset = ConformerSampleConfGV2Dataset( + dataset, + self.args.seed, + "atoms", + "coordinates", + "target", + self.args.beta, + self.args.smooth, + self.args.topN, + ) + else: + sample_dataset = ConformerSampleConfGDataset( + dataset, self.args.seed, "atoms", "coordinates", "target" + ) + sample_dataset = NormalizeDataset(sample_dataset, "coordinates") + sample_dataset = NormalizeDataset(sample_dataset, "target") + src_dataset = TokenizeDataset( + src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len + ) + coord_dataset = KeyDataset(sample_dataset, "coordinates") + tgt_coord_dataset = KeyDataset(sample_dataset, "target") + + def PrependAndAppend(dataset, pre_token, app_token): + dataset = PrependTokenDataset(dataset, pre_token) + return AppendTokenDataset(dataset, app_token) + + tgt_coord_dataset = FromNumpyDataset(tgt_coord_dataset) + tgt_coord_dataset = PrependAndAppend(tgt_coord_dataset, 0.0, 0.0) + tgt_distance_dataset = DistanceDataset(tgt_coord_dataset) + + src_dataset = PrependAndAppend( + src_dataset, self.dictionary.bos(), self.dictionary.eos() + ) + edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) + coord_dataset = FromNumpyDataset(coord_dataset) + coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) + distance_dataset = DistanceDataset(coord_dataset) + + nest_dataset = NestedDictionaryDataset( + { + "net_input": { + "src_tokens": RightPadDataset( + src_dataset, + pad_idx=self.dictionary.pad(), + ), + "src_coord": RightPadDatasetCoord( + coord_dataset, + pad_idx=0, + ), + "src_distance": RightPadDataset2D( + distance_dataset, + pad_idx=0, + ), + "src_edge_type": RightPadDataset2D( + edge_type, + pad_idx=0, + ), + }, + "target": { + "coord_target": RightPadDatasetCoord( + tgt_coord_dataset, + pad_idx=0, + ), + "distance_target": RightPadDataset2D( + tgt_distance_dataset, + pad_idx=0, + ), + }, + "smi_name": RawArrayDataset(smi_dataset), + }, + ) + if split.startswith("train"): + with data_utils.numpy_seed(self.args.seed): + shuffle = np.random.permutation(len(src_dataset)) + + self.datasets[split] = SortDataset( + nest_dataset, + sort_order=[shuffle], + ) + else: + self.datasets[split] = nest_dataset + + def build_model(self, args): + from unicore import models + + model = models.build_model(args, self) + if args.finetune_mol_model is not None: + print("load pretrain model weight from...", args.finetune_mol_model) + state = checkpoint_utils.load_checkpoint_to_cpu( + args.finetune_mol_model, + ) + model.unimol.load_state_dict(state["model"], strict=False) + return model diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_finetune.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_finetune.py new file mode 100644 index 0000000000000000000000000000000000000000..04c1784cf8cc16bf7048227c0907822289ea75f1 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_finetune.py @@ -0,0 +1,285 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
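+
+# The task_metainfo table below stores dataset-level label statistics for the
+# regression benchmarks; __init__ exposes them as self.mean and self.std. They
+# are only set here, for use elsewhere (e.g., by the finetuning criterion); a
+# minimal sketch of the standardization they describe:
+#
+#     normalized_target = (raw_target - task.mean) / task.std   # when fitting
+#     raw_prediction    = output * task.std + task.mean         # when reporting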
+ +import logging +import os + +import numpy as np +from unicore.data import ( + Dictionary, + NestedDictionaryDataset, + LMDBDataset, + AppendTokenDataset, + PrependTokenDataset, + RightPadDataset, + SortDataset, + TokenizeDataset, + RightPadDataset2D, + RawLabelDataset, + RawArrayDataset, + FromNumpyDataset, +) +from unimol.data import ( + KeyDataset, + ConformerSampleDataset, + DistanceDataset, + EdgeTypeDataset, + RemoveHydrogenDataset, + AtomTypeDataset, + NormalizeDataset, + CroppingDataset, + RightPadDatasetCoord, + data_utils, +) + +from unimol.data.tta_dataset import TTADataset +from unicore.tasks import UnicoreTask, register_task + + +logger = logging.getLogger(__name__) + +task_metainfo = { + "esol": { + "mean": -3.0501019503546094, + "std": 2.096441210089345, + "target_name": "logSolubility", + }, + "freesolv": { + "mean": -3.8030062305295944, + "std": 3.8478201171088138, + "target_name": "freesolv", + }, + "lipo": {"mean": 2.186336, "std": 1.203004, "target_name": "lipo"}, + "qm7dft": { + "mean": -1544.8360893118609, + "std": 222.8902092792289, + "target_name": "u0_atom", + }, + "qm8dft": { + "mean": [ + 0.22008500524052105, + 0.24892658759891675, + 0.02289283121913152, + 0.043164444107224746, + 0.21669716560818883, + 0.24225989336408812, + 0.020287111373358993, + 0.03312609817084387, + 0.21681478862847584, + 0.24463634931699113, + 0.02345177178004201, + 0.03730141834205415, + ], + "std": [ + 0.043832862248693226, + 0.03452326954549232, + 0.053401140662012285, + 0.0730556474716259, + 0.04788020599385645, + 0.040309670766319, + 0.05117163534626215, + 0.06030064428723054, + 0.04458294838213221, + 0.03597696243350195, + 0.05786865052149905, + 0.06692733477994665, + ], + "target_name": [ + "E1-CC2", + "E2-CC2", + "f1-CC2", + "f2-CC2", + "E1-PBE0", + "E2-PBE0", + "f1-PBE0", + "f2-PBE0", + "E1-CAM", + "E2-CAM", + "f1-CAM", + "f2-CAM", + ], + }, + "qm9dft": { + "mean": [-0.23997669940621352, 0.011123767412331285, 0.2511003712141015], + "std": [0.02213143402267657, 0.046936069870866196, 0.04751888787058615], + "target_name": ["homo", "lumo", "gap"], + }, +} + + +@register_task("mol_finetune") +class UniMolFinetuneTask(UnicoreTask): + """Task for training transformer auto-encoder models.""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", help="downstream data path") + parser.add_argument("--task-name", type=str, help="downstream task name") + parser.add_argument( + "--classification-head-name", + default="classification", + help="finetune downstream task name", + ) + parser.add_argument( + "--num-classes", + default=1, + type=int, + help="finetune downstream task classes numbers", + ) + parser.add_argument("--no-shuffle", action="store_true", help="shuffle data") + parser.add_argument( + "--conf-size", + default=10, + type=int, + help="number of conformers generated with each molecule", + ) + parser.add_argument( + "--remove-hydrogen", + action="store_true", + help="remove hydrogen atoms", + ) + parser.add_argument( + "--remove-polar-hydrogen", + action="store_true", + help="remove polar hydrogen atoms", + ) + parser.add_argument( + "--max-atoms", + type=int, + default=256, + help="selected maximum number of atoms in a molecule", + ) + parser.add_argument( + "--dict-name", + default="dict.txt", + help="dictionary file", + ) + parser.add_argument( + "--only-polar", + default=1, + type=int, + help="1: only reserve polar hydrogen; 0: no hydrogen; -1: all hydrogen ", + ) + + def __init__(self, args, dictionary): + 
super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + # add mask token + self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) + if self.args.only_polar > 0: + self.args.remove_polar_hydrogen = True + elif self.args.only_polar < 0: + self.args.remove_polar_hydrogen = False + else: + self.args.remove_hydrogen = True + if self.args.task_name in task_metainfo: + # for regression task, pre-compute mean and std + self.mean = task_metainfo[self.args.task_name]["mean"] + self.std = task_metainfo[self.args.task_name]["std"] + + @classmethod + def setup_task(cls, args, **kwargs): + dictionary = Dictionary.load(os.path.join(args.data, args.dict_name)) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, **kwargs): + """Load a given dataset split. + Args: + split (str): name of the data scoure (e.g., train) + """ + split_path = os.path.join(self.args.data, self.args.task_name, split + ".lmdb") + dataset = LMDBDataset(split_path) + if split == "train": + tgt_dataset = KeyDataset(dataset, "target") + smi_dataset = KeyDataset(dataset, "smi") + sample_dataset = ConformerSampleDataset( + dataset, self.args.seed, "atoms", "coordinates" + ) + dataset = AtomTypeDataset(dataset, sample_dataset) + else: + dataset = TTADataset( + dataset, self.args.seed, "atoms", "coordinates", self.args.conf_size + ) + dataset = AtomTypeDataset(dataset, dataset) + tgt_dataset = KeyDataset(dataset, "target") + smi_dataset = KeyDataset(dataset, "smi") + + dataset = RemoveHydrogenDataset( + dataset, + "atoms", + "coordinates", + self.args.remove_hydrogen, + self.args.remove_polar_hydrogen, + ) + dataset = CroppingDataset( + dataset, self.seed, "atoms", "coordinates", self.args.max_atoms + ) + dataset = NormalizeDataset(dataset, "coordinates", normalize_coord=True) + src_dataset = KeyDataset(dataset, "atoms") + src_dataset = TokenizeDataset( + src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len + ) + coord_dataset = KeyDataset(dataset, "coordinates") + + def PrependAndAppend(dataset, pre_token, app_token): + dataset = PrependTokenDataset(dataset, pre_token) + return AppendTokenDataset(dataset, app_token) + + src_dataset = PrependAndAppend( + src_dataset, self.dictionary.bos(), self.dictionary.eos() + ) + edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) + coord_dataset = FromNumpyDataset(coord_dataset) + coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) + distance_dataset = DistanceDataset(coord_dataset) + + nest_dataset = NestedDictionaryDataset( + { + "net_input": { + "src_tokens": RightPadDataset( + src_dataset, + pad_idx=self.dictionary.pad(), + ), + "src_coord": RightPadDatasetCoord( + coord_dataset, + pad_idx=0, + ), + "src_distance": RightPadDataset2D( + distance_dataset, + pad_idx=0, + ), + "src_edge_type": RightPadDataset2D( + edge_type, + pad_idx=0, + ), + }, + "target": { + "finetune_target": RawLabelDataset(tgt_dataset), + }, + "smi_name": RawArrayDataset(smi_dataset), + }, + ) + if not self.args.no_shuffle and split == "train": + with data_utils.numpy_seed(self.args.seed): + shuffle = np.random.permutation(len(src_dataset)) + + self.datasets[split] = SortDataset( + nest_dataset, + sort_order=[shuffle], + ) + else: + self.datasets[split] = nest_dataset + + def build_model(self, args): + from unicore import models + + model = models.build_model(args, self) + model.register_classification_head( + self.args.classification_head_name, + num_classes=self.args.num_classes, + ) + 
return model diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_pocket.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_pocket.py new file mode 100644 index 0000000000000000000000000000000000000000..21bebb1e2c876fc23b5bae6b6cf2b6a2b2ca786d --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_pocket.py @@ -0,0 +1,217 @@ +# Copyright (c) DP Technology. +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +from typing import Optional + +import numpy as np +from unicore.data import ( + Dictionary, + NestedDictionaryDataset, + LMDBDataset, + AppendTokenDataset, + PrependTokenDataset, + RightPadDataset, + EpochShuffleDataset, + TokenizeDataset, + RightPadDataset2D, + FromNumpyDataset, + RawArrayDataset, +) +from unimol.data import ( + KeyDataset, + ConformerSamplePocketDataset, + DistanceDataset, + EdgeTypeDataset, + MaskPointsPocketDataset, + NormalizeDataset, + CroppingPocketDataset, + AtomTypeDataset, + RightPadDatasetCoord, +) +from unicore.tasks import UnicoreTask, register_task + + +logger = logging.getLogger(__name__) + + +@register_task("unimol_pocket") +class UniMolPocketTask(UnicoreTask): + """Task for training transformer auto-encoder models.""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument( + "data", + help="colon separated path to data directories list, \ + will be iterated upon during epochs in round-robin manner", + ) + parser.add_argument( + "--mask-prob", + default=0.15, + type=float, + help="probability of replacing a token with mask", + ) + parser.add_argument( + "--leave-unmasked-prob", + default=0.05, + type=float, + help="probability that a masked token is unmasked", + ) + parser.add_argument( + "--random-token-prob", + default=0.05, + type=float, + help="probability of replacing a token with a random token", + ) + parser.add_argument( + "--noise-type", + default="normal", + choices=["trunc_normal", "uniform", "normal", "none"], + help="noise type in coordinate noise", + ) + parser.add_argument( + "--noise", + default=1.0, + type=float, + help="coordinate noise for masked atoms", + ) + parser.add_argument( + "--remove-hydrogen", + action="store_true", + help="remove hydrogen atoms", + ) + parser.add_argument( + "--remove-polar-hydrogen", + action="store_true", + help="remove polar hydrogen atoms", + ) + parser.add_argument( + "--max-atoms", + type=int, + default=256, + help="selected maximum number of atoms in a molecule", + ) + parser.add_argument( + "--dict-name", + default="dict.txt", + help="dictionary file", + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dict_name = args.dict_name + self.dictionary = dictionary + self.seed = args.seed + # add mask token + self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) + + @classmethod + def setup_task(cls, args, **kwargs): + dictionary = Dictionary.load(os.path.join(args.data, args.dict_name)) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, combine=False, **kwargs): + """Load a given dataset split. 
+ Args: + split (str): name of the split (e.g., train, valid, test) + """ + split_path = os.path.join(self.args.data, split + ".lmdb") + + raw_dataset = LMDBDataset(split_path) + + def one_dataset(raw_dataset, coord_seed, mask_seed): + pdb_id_dataset = KeyDataset(raw_dataset, "pdbid") + dataset = ConformerSamplePocketDataset( + raw_dataset, coord_seed, "atoms", "coordinates", self.dict_name + ) + dataset = AtomTypeDataset(raw_dataset, dataset) + dataset = CroppingPocketDataset( + dataset, self.seed, "atoms", "coordinates", self.args.max_atoms + ) + dataset = NormalizeDataset(dataset, "coordinates", normalize_coord=True) + token_dataset = KeyDataset(dataset, "atoms") + token_dataset = TokenizeDataset( + token_dataset, self.dictionary, max_seq_len=self.args.max_seq_len + ) + coord_dataset = KeyDataset(dataset, "coordinates") + residue_dataset = KeyDataset(dataset, "residue") + expand_dataset = MaskPointsPocketDataset( + token_dataset, + coord_dataset, + residue_dataset, + self.dictionary, + pad_idx=self.dictionary.pad(), + mask_idx=self.mask_idx, + noise_type=self.args.noise_type, + noise=self.args.noise, + seed=mask_seed, + mask_prob=self.args.mask_prob, + leave_unmasked_prob=self.args.leave_unmasked_prob, + random_token_prob=self.args.random_token_prob, + ) + + def PrependAndAppend(dataset, pre_token, app_token): + dataset = PrependTokenDataset(dataset, pre_token) + return AppendTokenDataset(dataset, app_token) + + encoder_token_dataset = KeyDataset(expand_dataset, "atoms") + encoder_target_dataset = KeyDataset(expand_dataset, "targets") + encoder_coord_dataset = KeyDataset(expand_dataset, "coordinates") + + src_dataset = PrependAndAppend( + encoder_token_dataset, self.dictionary.bos(), self.dictionary.eos() + ) + tgt_dataset = PrependAndAppend( + encoder_target_dataset, self.dictionary.pad(), self.dictionary.pad() + ) + encoder_coord_dataset = PrependAndAppend(encoder_coord_dataset, 0.0, 0.0) + encoder_distance_dataset = DistanceDataset(encoder_coord_dataset) + + edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) + coord_dataset = FromNumpyDataset(coord_dataset) + coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) + distance_dataset = DistanceDataset(coord_dataset) + return { + "src_tokens": RightPadDataset( + src_dataset, + pad_idx=self.dictionary.pad(), + ), + "src_coord": RightPadDatasetCoord( + encoder_coord_dataset, + pad_idx=0, + ), + "src_distance": RightPadDataset2D( + encoder_distance_dataset, + pad_idx=0, + ), + "src_edge_type": RightPadDataset2D( + edge_type, + pad_idx=0, + ), + }, { + "tokens_target": RightPadDataset( + tgt_dataset, pad_idx=self.dictionary.pad() + ), + "distance_target": RightPadDataset2D(distance_dataset, pad_idx=0), + "coord_target": RightPadDatasetCoord(coord_dataset, pad_idx=0), + "pdb_id": RawArrayDataset(pdb_id_dataset), + } + + net_input, target = one_dataset(raw_dataset, self.args.seed, self.args.seed) + dataset = {"net_input": net_input, "target": target} + dataset = NestedDictionaryDataset(dataset) + if split in ["train", "train.small"]: + dataset = EpochShuffleDataset(dataset, len(dataset), self.args.seed) + self.datasets[split] = dataset + + def build_model(self, args): + from unicore import models + + model = models.build_model(args, self) + return model diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_pocket_finetune.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_pocket_finetune.py new file mode 100644 index 
0000000000000000000000000000000000000000..c134f7fe60a7558b0ac2f3d850208baa86c55fb7 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/tasks/unimol_pocket_finetune.py @@ -0,0 +1,211 @@ +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import logging +import os + +from unicore.data import ( + Dictionary, + NestedDictionaryDataset, + LMDBDataset, + AppendTokenDataset, + PrependTokenDataset, + RightPadDataset, + TokenizeDataset, + RightPadDataset2D, + RawLabelDataset, + FromNumpyDataset, + EpochShuffleDataset, +) + +from unimol.data import ( + KeyDataset, + ConformerSamplePocketFinetuneDataset, + DistanceDataset, + EdgeTypeDataset, + NormalizeDataset, + RightPadDatasetCoord, + CroppingResiduePocketDataset, + RemoveHydrogenResiduePocketDataset, + FromStrLabelDataset, +) + +from unicore.tasks import UnicoreTask, register_task + + +logger = logging.getLogger(__name__) + +task_metainfo = { + "Score": { + "mean": -0.02113608960384876, + "std": 0.14467607204629246, + }, + "Druggability Score": { + "mean": 0.04279187401338044, + "std": 0.1338187819653573, + }, + "Total SASA": { + "mean": 118.7343246335413, + "std": 59.82260887999069, + }, + "Hydrophobicity score": { + "mean": 16.824823092535517, + "std": 18.16340833552264, + }, +} + + +@register_task("pocket_finetune") +class UniMolPocketFinetuneTask(UnicoreTask): + """Task for training transformer auto-encoder models.""" + + @staticmethod + def add_args(parser): + """Add task-specific arguments to the parser.""" + parser.add_argument("data", help="downstream data path") + parser.add_argument("--task-name", type=str, help="downstream task name") + parser.add_argument( + "--classification-head-name", + default="classification", + help="finetune downstream task name", + ) + parser.add_argument( + "--num-classes", + default=2, + type=int, + help="finetune downstream task classes numbers", + ) + parser.add_argument( + "--remove-hydrogen", + action="store_true", + help="remove hydrogen atoms", + ) + parser.add_argument( + "--max-atoms", + type=int, + default=256, + help="selected maximum number of atoms in a molecule", + ) + parser.add_argument( + "--dict-name", + default="dict_pkt.txt", + help="dictionary file", + ) + parser.add_argument( + "--fpocket-score", + default="Druggability Score", + help="Select one of the 4 Fpocket scores as the target", + choices=[ + "Score", + "Druggability Score", + "Total SASA", + "Hydrophobicity score", + ], + ) + + def __init__(self, args, dictionary): + super().__init__(args) + self.dictionary = dictionary + self.seed = args.seed + # add mask token + self.mask_idx = dictionary.add_symbol("[MASK]", is_special=True) + if self.args.task_name == "drugabbility": + if self.args.fpocket_score in task_metainfo: + # for regression task, pre-compute mean and std + self.mean = task_metainfo[self.args.fpocket_score]["mean"] + self.std = task_metainfo[self.args.fpocket_score]["std"] + else: + self.mean, self.std = None, None + + @classmethod + def setup_task(cls, args, **kwargs): + dictionary = Dictionary.load(os.path.join(args.data, args.dict_name)) + logger.info("dictionary: {} types".format(len(dictionary))) + return cls(args, dictionary) + + def load_dataset(self, split, **kwargs): + """Load a given dataset split. 
+ Args: + split (str): name of the data scoure (e.g., bppp) + """ + split_path = os.path.join(self.args.data, self.args.task_name, split + ".lmdb") + dataset = LMDBDataset(split_path) + if self.args.task_name == "druggability": + tgt_dataset_inner = KeyDataset(dataset, "target") + tgt_dataset = KeyDataset(tgt_dataset_inner, self.args.fpocket_score) + tgt_dataset = FromStrLabelDataset(tgt_dataset) + else: + tgt_dataset = KeyDataset(dataset, "target") + tgt_dataset = RawLabelDataset(tgt_dataset) + + dataset = ConformerSamplePocketFinetuneDataset( + dataset, self.seed, "atoms", "residue", "coordinates" + ) + dataset = RemoveHydrogenResiduePocketDataset( + dataset, "atoms", "residue", "coordinates", self.args.remove_hydrogen + ) + dataset = CroppingResiduePocketDataset( + dataset, self.seed, "atoms", "residue", "coordinates", self.args.max_atoms + ) + dataset = NormalizeDataset(dataset, "coordinates") + src_dataset = KeyDataset(dataset, "atoms") + src_dataset = TokenizeDataset( + src_dataset, self.dictionary, max_seq_len=self.args.max_seq_len + ) + coord_dataset = KeyDataset(dataset, "coordinates") + + def PrependAndAppend(dataset, pre_token, app_token): + dataset = PrependTokenDataset(dataset, pre_token) + return AppendTokenDataset(dataset, app_token) + + src_dataset = PrependAndAppend( + src_dataset, self.dictionary.bos(), self.dictionary.eos() + ) + edge_type = EdgeTypeDataset(src_dataset, len(self.dictionary)) + coord_dataset = FromNumpyDataset(coord_dataset) + coord_dataset = PrependAndAppend(coord_dataset, 0.0, 0.0) + distance_dataset = DistanceDataset(coord_dataset) + + nest_dataset = NestedDictionaryDataset( + { + "net_input": { + "src_tokens": RightPadDataset( + src_dataset, + pad_idx=self.dictionary.pad(), + ), + "src_coord": RightPadDatasetCoord( + coord_dataset, + pad_idx=0, + ), + "src_distance": RightPadDataset2D( + distance_dataset, + pad_idx=0, + ), + "src_edge_type": RightPadDataset2D( + edge_type, + pad_idx=0, + ), + }, + "target": { + "finetune_target": tgt_dataset, + }, + }, + ) + + if split.startswith("train"): + nest_dataset = EpochShuffleDataset( + nest_dataset, len(nest_dataset), self.args.seed + ) + self.datasets[split] = nest_dataset + + def build_model(self, args): + from unicore import models + + model = models.build_model(args, self) + model.register_classification_head( + self.args.classification_head_name, + num_classes=self.args.num_classes, + ) + return model diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/__init__.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/conf_gen_cal_metrics.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/conf_gen_cal_metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..567452a8d0eb5ea6d210821f21c6a1edfd3062e3 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/conf_gen_cal_metrics.py @@ -0,0 +1,437 @@ +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
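+
+# Usage sketch for the two modes wired up in main() at the bottom of this file.
+# Paths shown are the argparse defaults and may need adjusting.
+#
+#     # 1) build the rdkit-initialized test set as <output-dir>/test.lmdb
+#     python conf_gen_cal_metrics.py --mode gen_data \
+#         --reference-file ./conformation_generation/qm9/test_data_200.pkl \
+#         --output-dir ./conformation_generation/qm9 --nthreads 40
+#
+#     # 2) after running Uni-Mol inference, score the predicted conformers
+#     python conf_gen_cal_metrics.py --mode cal_metrics \
+#         --reference-file ./conformation_generation/qm9/test_data_200.pkl \
+#         --predict-file ./infer_confgen/save_confgen_test.out.pkl \
+#         --threshold 0.5   # 0.5 for qm9, 1.25 for drugs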
+ +import pandas as pd +import numpy as np +import os +import copy +import pickle +import lmdb +from rdkit import Chem +from tqdm import tqdm +from rdkit.Chem import rdMolTransforms +from rdkit.Chem import AllChem +from rdkit.Chem.rdMolAlign import GetBestRMS +from rdkit.Chem.rdForceFieldHelpers import MMFFOptimizeMolecule +from rdkit.Chem import rdMolAlign as MA +from scipy.spatial.transform import Rotation +from multiprocessing import Pool +from sklearn.cluster import KMeans +from sklearn_extra.cluster import KMedoids +import argparse +from typing import List + + +def get_torsions(m): + m = Chem.RemoveHs(m) + torsionList = [] + torsionSmarts = "[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]" + torsionQuery = Chem.MolFromSmarts(torsionSmarts) + matches = m.GetSubstructMatches(torsionQuery) + for match in matches: + idx2 = match[0] + idx3 = match[1] + bond = m.GetBondBetweenAtoms(idx2, idx3) + jAtom = m.GetAtomWithIdx(idx2) + kAtom = m.GetAtomWithIdx(idx3) + for b1 in jAtom.GetBonds(): + if b1.GetIdx() == bond.GetIdx(): + continue + idx1 = b1.GetOtherAtomIdx(idx2) + for b2 in kAtom.GetBonds(): + if (b2.GetIdx() == bond.GetIdx()) or (b2.GetIdx() == b1.GetIdx()): + continue + idx4 = b2.GetOtherAtomIdx(idx3) + # skip 3-membered rings + if idx4 == idx1: + continue + # skip torsions that include hydrogens + if (m.GetAtomWithIdx(idx1).GetAtomicNum() == 1) or ( + m.GetAtomWithIdx(idx4).GetAtomicNum() == 1 + ): + continue + if m.GetAtomWithIdx(idx4).IsInRing(): + torsionList.append((idx4, idx3, idx2, idx1)) + break + else: + torsionList.append((idx1, idx2, idx3, idx4)) + break + break + return torsionList + + +def SetDihedral(conf, atom_idx, new_vale): + rdMolTransforms.SetDihedralRad( + conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3], new_vale + ) + + +def GetDihedral(conf, atom_idx): + return rdMolTransforms.GetDihedralRad( + conf, atom_idx[0], atom_idx[1], atom_idx[2], atom_idx[3] + ) + + +def single_conf_gen(tgt_mol: Chem.Mol, num_confs: int = 1000, seed: int = 42, mmff: bool = False, randomize_angles: bool = False, threads: int = 0) -> Chem.Mol: + """ Generates conformers for a molecule. Functionality to support: https://arxiv.org/abs/2302.07061 """ + mol = copy.deepcopy(tgt_mol) + mol = Chem.AddHs(mol) + allconformers = AllChem.EmbedMultipleConfs( + mol, numConfs=num_confs, randomSeed=seed, clearConfs=True, numThreads=threads + ) + + # WARNING! this might change the molecule stereochemistry + if randomize_angles: + rotable_bonds = get_torsions(mol) + # TODO: if stereochem preservation is wanted, apply same torsion delta to all `i,j,k,l` sharing the same `{j,k}` rotatable bond + for i in range(len(allconformers)): + np.random.seed(i) + values = 3.1415926 * 2 * np.random.rand(len(rotable_bonds)) + for idx in range(len(rotable_bonds)): + SetDihedral(mol.GetConformers()[i], rotable_bonds[idx], values[idx]) + Chem.rdMolTransforms.CanonicalizeConformer(mol.GetConformers()[i]) + + # Forcefield relaxation improves conformer diversity + if mmff: + sz = len(allconformers) + for i in range(sz): + try: + AllChem.MMFFOptimizeMolecule(mol, confId=i) + except: + continue + mol = Chem.RemoveHs(mol) + return mol + + +def clustering( + mol: Chem.Mol, + M: int = 1000, + N: int = 100, + mmff: bool = True, + randomized_angles: bool = False, + kmeans: bool = False, + seed: int = 42, + threads: int = 0, + removeHs: bool = True, +) -> List[np.ndarray]: + """ Creates a diverse set of conformers for a given molecule by + procedurally generating candidates with various rdkit methods and clustering. 
+ Follows principles outlined in: https://arxiv.org/abs/2302.07061 + - For paper reproduction, call with: M=1000, N=100, randomized_angles=True, kmeans=True + - For best UniMol inference: M=1300, N=10, randomized_angles=False, kmeans=False (adjust M>10 for speed) + - WARNING! randomized_angles = True might change the molecule stereochemistry! Ex: PDB: 2ZCR + + Args: + mol (Chem.Mol): rdkit molecule + M (int): Number of conformers to generate. + N (int): Number of conformers to return. + mmff (bool): Whether to use MMFF forcefield relaxation. + randomized_angles (bool, optional): Whether to use an additional M/4 conformers with randomized torsion angles. + WARNING! might change the molecule stereochemistry + kmeans (bool): Whether to use kmeans or kmedoids. + Kmeans picks random example of cluster, Kmedoids picks cluster centroid. + seed (int): Random seed for conformer generation. + threads (int): Number of threads to use for conformer generation. If 0, uses all available threads. + removeHs (bool): Whether to remove hydrogens from the final conformers. + + Returns: + List[np.ndarray]: List of conformer coordinates + """ + # to support ref paper by default but not be too expensive + if not mmff: + M = M*4 + + total_sz = 0 + rdkit_coords_list = [] + + # add no-MMFF-optimized conformers (ETKDG v3) + rdkit_mol = single_conf_gen(mol, num_confs=int(M // 4), seed=seed, threads=threads) + if removeHs: + rdkit_mol = Chem.RemoveHs(rdkit_mol) + sz = len(rdkit_mol.GetConformers()) + tgt_coords = rdkit_mol.GetConformers()[0].GetPositions().astype(np.float32) + tgt_coords = tgt_coords - np.mean(tgt_coords, axis=0) + for i in range(sz): + _coords = rdkit_mol.GetConformers()[i].GetPositions().astype(np.float32) + _coords = _coords - _coords.mean(axis=0) # need to normalize first + _R, _score = Rotation.align_vectors(_coords, tgt_coords) + rdkit_coords_list.append(np.dot(_coords, _R.as_matrix())) + total_sz += sz + + # add forcefield optimized conformers + if mmff: + rdkit_mol = single_conf_gen(mol, num_confs=M, mmff=True, seed=seed+1, threads=threads) + if removeHs: + rdkit_mol = Chem.RemoveHs(rdkit_mol) + sz = len(rdkit_mol.GetConformers()) + for i in range(sz): + _coords = rdkit_mol.GetConformers()[i].GetPositions().astype(np.float32) + _coords = _coords - _coords.mean(axis=0) # need to normalize first + _R, _score = Rotation.align_vectors(_coords, tgt_coords) + rdkit_coords_list.append(np.dot(_coords, _R.as_matrix())) + total_sz += sz + + # add uniform rotation bonds conformers - WARNING! - might alter stereochemistry. 
Ex: PDB-2ZCR + if randomized_angles: + rdkit_mol = single_conf_gen(mol, num_confs=int(M // 4), seed=seed+2, threads=threads, randomize_angles=True) + if removeHs: + rdkit_mol = Chem.RemoveHs(rdkit_mol) + sz = len(rdkit_mol.GetConformers()) + for i in range(sz): + _coords = rdkit_mol.GetConformers()[i].GetPositions().astype(np.float32) + _coords = _coords - _coords.mean(axis=0) # need to normalize first + _R, _score = Rotation.align_vectors(_coords, tgt_coords) + rdkit_coords_list.append(np.dot(_coords, _R.as_matrix())) + total_sz += sz + + rdkit_coords_flatten = np.array(rdkit_coords_list).reshape(total_sz, -1) + if kmeans: + ids = ( + KMeans(n_clusters=N, random_state=42) + .fit_predict(rdkit_coords_flatten) + .tolist() + ) + coords_list = [rdkit_coords_list[ids.index(i)] for i in range(N)] + else: + clust = KMedoids(n_clusters=N, random_state=seed, ) + clust.fit(rdkit_coords_flatten) + idxs = clust.medoid_indices_.tolist() + coords_list = [rdkit_coords_list[idx] for idx in idxs] + + return coords_list + + +def single_process_data(content) -> List: + smi, tgt_mol_list = content[0], content[1] + M = min(20 * len(tgt_mol_list), 2000) + N = 2 * len(tgt_mol_list) + tgt_mol = copy.deepcopy(tgt_mol_list[0]) + tgt_mol = Chem.RemoveHs(tgt_mol) + rdkit_cluster_coords_list = clustering(tgt_mol, M=M, N=N) + atoms = [atom.GetSymbol() for atom in tgt_mol.GetAtoms()] + sz = len(rdkit_cluster_coords_list) + ## check target molecule atoms is the same as the input molecule + for _mol in tgt_mol_list: + _mol = Chem.RemoveHs(_mol) + _atoms = [atom.GetSymbol() for atom in _mol.GetAtoms()] + assert _atoms == atoms, print(smi) + + tgt_coords = tgt_mol.GetConformer().GetPositions().astype(np.float32) + dump_list = [] + for i in range(sz): + dump_list.append( + { + "atoms": atoms, + "coordinates": [rdkit_cluster_coords_list[i]], + "smi": smi, + "target": tgt_coords, + } + ) + return dump_list + + +def write_lmdb(content_list, output_dir, name, nthreads=16): + + os.makedirs(output_dir, exist_ok=True) + output_name = os.path.join(output_dir, f"{name}.lmdb") + print(output_name) + try: + os.remove(output_name) + except: + pass + env_new = lmdb.open( + output_name, + subdir=False, + readonly=False, + lock=False, + readahead=False, + meminit=False, + max_readers=1, + map_size=int(100e9), + ) + txn_write = env_new.begin(write=True) + with Pool(nthreads) as pool: + i = 0 + for inner_output in tqdm(pool.imap(inner_process, content_list)): + if inner_output is not None: + for item in inner_output: + txn_write.put( + f"{i}".encode("ascii"), pickle.dumps(item, protocol=-1) + ) + i += 1 + print("{} process {} lines".format(output_name, i)) + txn_write.commit() + env_new.close() + + +def inner_process(content): + try: + return single_process_data(content) + except: + return None + + +def data_pre(predict_path, data_path, normalize=True): + + predict = pd.read_pickle(predict_path) + data = pd.read_pickle(data_path) + data = data.groupby("smi")["mol"].apply(list).reset_index() + smi_list, predict_list = [], [] + for batch in predict: + sz = batch["bsz"] + for i in range(sz): + smi_list.append(batch["smi_name"][i]) + coord_predict = batch["coord_predict"][i] + coord_target = batch["coord_target"][i] + coord_mask = coord_target[:, 0].ne(0) + coord_predict = coord_predict[coord_mask, :].cpu().numpy() + if normalize: + coord_predict = coord_predict - coord_predict.mean(axis=0) + + predict_list.append(coord_predict) + + predict_df = pd.DataFrame({"smi": smi_list, "predict_coord": predict_list}) + predict_df = 
predict_df.groupby("smi")["predict_coord"].apply(list).reset_index() + + df = pd.merge(data, predict_df, on="smi", how="left") + print("preprocessing 1...") + ref_mols_list, gen_mols_list = [], [] + for smi, mol_list, pos_list in zip(df["smi"], df["mol"], df["predict_coord"]): + if "." in smi: + print(smi) + continue + ref_mols_list.append(mol_list) + gen_mols = [set_rdmol_positions(mol_list[0], pos) for pos in pos_list] + gen_mols_list.append(gen_mols) + print("preprocessing 2...") + return ref_mols_list, gen_mols_list + + +def get_rmsd_min(ref_mols, gen_mols, use_ff=False, threshold=0.5): + rmsd_mat = np.zeros([len(ref_mols), len(gen_mols)], dtype=np.float32) + for i, gen_mol in enumerate(gen_mols): + gen_mol_c = copy.deepcopy(gen_mol) + if use_ff: + MMFFOptimizeMolecule(gen_mol_c) + for j, ref_mol in enumerate(ref_mols): + ref_mol_c = copy.deepcopy(ref_mol) + rmsd_mat[j, i] = get_best_rmsd(gen_mol_c, ref_mol_c) + rmsd_mat_min = rmsd_mat.min(-1) + return (rmsd_mat_min <= threshold).mean(), rmsd_mat_min.mean() + + +def get_best_rmsd(gen_mol, ref_mol): + gen_mol = Chem.RemoveHs(gen_mol) + ref_mol = Chem.RemoveHs(ref_mol) + rmsd = MA.GetBestRMS(gen_mol, ref_mol) + return rmsd + + +def set_rdmol_positions(rdkit_mol, pos): + rdkit_mol = Chem.RemoveHs(rdkit_mol) + assert rdkit_mol.GetConformer(0).GetPositions().shape[0] == pos.shape[0] + mol = copy.deepcopy(rdkit_mol) + for i in range(pos.shape[0]): + mol.GetConformer(0).SetAtomPosition(i, pos[i].tolist()) + return mol + + +def print_results(cov, mat): + print("COV_mean: ", np.mean(cov), ";COV_median: ", np.median(cov)) + print("MAT_mean: ", np.mean(mat), ";MAT_median: ", np.median(mat)) + + +def single_process(content): + ref_mols, gen_mols, use_ff, threshold = content + cov, mat = get_rmsd_min(ref_mols, gen_mols, use_ff, threshold) + return cov, mat + + +def process(content): + try: + return single_process(content) + except: + return None + + +def cal_metrics(predict_path, data_path, use_ff=False, threshold=0.5, nthreads=40): + ref_mols_list, gen_mols_list = data_pre(predict_path, data_path, normalize=True) + print("cal_metrics...") + cov_list, mat_list = [], [] + content_list = [] + for ref_mols, gen_mols in zip(ref_mols_list, gen_mols_list): + content_list.append((ref_mols, gen_mols, use_ff, threshold)) + with Pool(nthreads) as pool: + for inner_output in tqdm(pool.imap(process, content_list)): + if inner_output is None: + continue + cov, mat = inner_output + cov_list.append(cov) + mat_list.append(mat) + print_results(cov_list, mat_list) + return True + + +def main(): + parser = argparse.ArgumentParser( + description="generate initial rdkit test data and cal metrics" + ) + parser.add_argument( + "--mode", + type=str, + default="cal_metrics", + choices=["gen_data", "cal_metrics"], + ) + parser.add_argument( + "--reference-file", + type=str, + default="./conformation_generation/qm9/test_data_200.pkl", + help="Location of the reference set", + ) + parser.add_argument( + "--output-dir", + type=str, + default="./conformation_generation/qm9", + help="Location of the generated data", + ) + parser.add_argument("--nthreads", type=int, default=40, help="num of threads") + parser.add_argument( + "--predict-file", + type=str, + default="./infer_confgen/save_confgen_test.out.pkl", + help="Location of the prediction file", + ) + parser.add_argument( + "--threshold", + default=0.5, + type=float, + help="threshold for cal metrics, qm9: 0.5; drugs: 1.25", + ) + args = parser.parse_args() + + if args.mode == "gen_data": + # generate test data + output_dir = 
args.output_dir + name = "test" + data = pd.read_pickle(args.reference_file) + content_list = ( + pd.DataFrame(data).groupby("smi")["mol"].apply(list).reset_index().values + ) + print(content_list[0]) + write_lmdb(content_list, output_dir, name, nthreads=args.nthreads) + + ### Uni-Mol predicting... ### + + elif args.mode == "cal_metrics": + # cal metrics + predict_file = args.predict_file + data_path = args.reference_file + use_ff = False + threshold = args.threshold + cal_metrics(predict_file, data_path, use_ff, threshold, args.nthreads) + + +if __name__ == "__main__": + main() diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/conformer_model.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/conformer_model.py new file mode 100644 index 0000000000000000000000000000000000000000..cae5c15163951b8a4d5fd052058a2bda2e36123c --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/conformer_model.py @@ -0,0 +1,947 @@ +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import numpy as np +import mindspore as ms +import mindspore.mint.nn as nn +import pandas as pd +from copy import deepcopy +from rdkit import Chem +from rdkit.Chem import AllChem +import pickle +import argparse +import warnings +from docking_utils import rmsd_func +from typing import List, Optional + + +warnings.filterwarnings(action="ignore") +# 设置Ascend设备 +ms.context.set_context(device_target="Ascend") + + +# Utils +def rot_from_axis_angle(axis: ms.Tensor, angle: ms.Tensor) -> ms.Tensor: + """ ((...), 3), ((...),) -> ((...), 3, 3) """ + # 归一化轴向量 + v1, v2, v3 = ms.ops.unbind(ms.ops.normalize(axis, dim=-1), dim=-1) + zero = ms.ops.zeros_like(v1) + # 构建叉乘矩阵 + cross_matrix = ms.ops.stack( + ( + ms.ops.stack((zero, -v3, v2), dim=-1), + ms.ops.stack((v3, zero, -v1), dim=-1), + ms.ops.stack((-v2, v1, zero), dim=-1), + ), + dim=-2, + ) + # 构建单位矩阵 + ide = ms.ops.eye(3, device=v1.device, dtype=v1.dtype).repeat( + *(1,) * len(v1.shape), 1, 1 + ) + angle = angle.expand_dims(-1).expand_dims(-1) + return ( + ide + + ms.ops.sin(angle) * cross_matrix + + (1 - ms.ops.cos(angle)) * (cross_matrix @ cross_matrix) + ) + + +def rot_from_euler(alpha_beta_gamma: ms.Tensor) -> ms.Tensor: + """ rotation from euler angles. ((...), 3) -> ((...), 3, 3) """ + alpha, beta, gamma = ms.ops.unbind(alpha_beta_gamma.clone(), dim=-1) + zeros = ms.ops.zeros_like(alpha) + # 构建Rx矩阵 + Rx_tensor = ms.ops.stack(( + (alpha + 1) / (alpha + 1), zeros, zeros, + zeros, ms.ops.cos(alpha), -ms.ops.sin(alpha), + zeros, ms.ops.sin(alpha), ms.ops.cos(alpha) + ), axis=-1).reshape(*alpha.shape, 3, 3) + # 构建Ry矩阵 + Ry_tensor = ms.ops.stack(( + ms.ops.cos(beta), zeros, -ms.ops.sin(beta), + zeros, (beta + 1) / (beta + 1), zeros, + ms.ops.sin(beta), zeros, ms.ops.cos(beta) + ), axis=-1).reshape(*beta.shape, 3, 3) + # 构建Rz矩阵 + Rz_tensor = ms.ops.stack(( + ms.ops.cos(gamma), -ms.ops.sin(gamma), zeros, + ms.ops.sin(gamma), ms.ops.cos(gamma), zeros, + zeros, zeros, (gamma + 1) / (gamma + 1) + ), axis=-1).reshape(*gamma.shape, 3, 3) + + R = (Rx_tensor @ Ry_tensor) @ Rz_tensor + return R + + +def get_dihedral( + c1: ms.Tensor, c2: ms.Tensor, c3: ms.Tensor, c4: ms.Tensor, eps: float = 1e-7 +) -> ms.Tensor: + """ Dihedral angle in radians. 
""" + u1 = c2 - c1 + u2 = c3 - c2 + u3 = c4 - c3 + + u2u3_cross = ms.ops.cross(u2, u3, dim=-1) + u2norm = ms.ops.sqrt(ms.ops.sum(ms.ops.square(u2), dim=-1, keepdim=True) + eps) + + return ms.ops.atan2( + ms.ops.sum(u2norm * u1 * u2u3_cross, -1), + ms.ops.sum(ms.ops.cross(u1, u2, dim=-1) * u2u3_cross, -1), + ) + + +def get_flexible_torsions(mol: Chem.Mol) -> ms.Tensor: + """ Gets a unique set of ligand torsions which are rotatable. Shape: (T, 4) """ + # 获取距离矩阵 + dist_mat = ms.Tensor(Chem.GetDistanceMatrix(mol), dtype=ms.float32) + # 获取可旋转键 + torsionSmarts = "[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]" + torsionQuery = Chem.MolFromSmarts(torsionSmarts) + matches = set(mol.GetSubstructMatches(torsionQuery)) + # 获取3跳连接的原子 + i_, l_ = ms.ops.nonzero(ms.ops.triu(dist_mat) == 3).asnumpy().T.tolist() + + flex_unique_torsions = [] + central_bond_torsions = set() + for i, l in zip(i_, l_): + i, j, k, l = Chem.GetShortestPath(mol, i, l) + if {(j, k), (k, j)}.intersection(matches) and (j, k) not in central_bond_torsions: + central_bond_torsions.update([(j, k), (k, j)]) + # 确定旋转方向 + if (dist_mat[j] < dist_mat[k]).sum() > (dist_mat[j] > dist_mat[k]).sum(): + flex_unique_torsions.append([i, j, k, l]) + else: + flex_unique_torsions.append([l, k, j, i]) + return ms.Tensor(flex_unique_torsions, dtype=ms.int32) + + +def rotate_along_axis(x: ms.Tensor, origin: ms.Tensor, axis: ms.Tensor, angle: ms.Tensor) -> ms.Tensor: + """ Rotates a point cloud around an axis given an origin """ + rot_mat = rot_from_axis_angle(axis, angle) + return ms.ops.einsum('...rc,...nc -> ...nr', rot_mat, x - origin) + origin + + +def update_dihedral(coords: ms.Tensor, idxs: List[int], value: float, dist_mat: ms.Tensor = None) -> ms.Tensor: + """Modifies a dihedral/torsion for a molecule with the given value.""" + i, j, k, l = idxs + if not isinstance(value, ms.Tensor): + value = ms.Tensor(value, dtype=coords.dtype, device=coords.device) + + # 确定需要旋转的原子(更靠近k的原子) + mask_rotate = dist_mat[k] < dist_mat[j] + + # 计算旋转角度(目标角度与当前角度的差) + current_dihedral = get_dihedral(*ms.ops.unbind(coords[..., idxs, :], dim=-2)) + rotate_angle = value - current_dihedral + + # 执行旋转 + coords_rot = rotate_along_axis( + x=coords[..., mask_rotate, :], + origin=coords[..., [j], :], + axis=coords[..., k, :] - coords[..., j, :], + angle=rotate_angle, + ) + # 更新坐标 + coords = ms.ops.where(mask_rotate[..., None], coords_rot, coords) + return coords + + +# Docking functions +def single_SF_loss( + predict_coords: ms.Tensor, + pocket_coords: ms.Tensor, + cross_distance_predict: ms.Tensor, + self_distance_predict: ms.Tensor, + dist_threshold: float = 4.5, + cross_dist_weight: float = 1.0, + self_dist_weight: float = 2.0, + reduce_batch: bool = True, +): + """ Calculates loss function """ + # 计算实际交叉距离 + cross_dist = ms.ops.norm(predict_coords[..., None, :] - pocket_coords[..., None, :, :], dim=-1) + # 计算实际自身距离 + self_dist = ms.ops.norm(predict_coords[..., None, :] - predict_coords[..., None, :, :], dim=-1) + + # 距离掩码(只考虑预测距离小于阈值的部分) + dist_mask = cross_distance_predict < dist_threshold + + # 计算交叉距离损失 + cross_dist_score = ms.ops.sum((cross_dist - cross_distance_predict)**2 * dist_mask, dim=(-1, -2)) / ms.ops.sum(dist_mask, dim=(-1, -2)) + # 计算自身距离损失 + dist_score = ms.ops.mean((self_dist - self_distance_predict)** 2, dim=(-1, -2)) + + # 总损失 + loss = cross_dist_score * cross_dist_weight + dist_score * self_dist_weight + + # 冲突惩罚(信息性指标) + clash_pl_score = ms.ops.square(ms.ops.clamp(cross_dist - 3., max=0) * 5.) 
+ clash_pl_score = ms.ops.sum(clash_pl_score, dim=(-1, -2)) / ms.ops.sum(dist_mask, dim=(-1, -2)) + + if reduce_batch: + return cross_dist_score.mean().asnumpy(), dist_score.mean().asnumpy(), 0., loss.mean() + return cross_dist_score.asnumpy(), dist_score.asnumpy(), clash_pl_score.asnumpy(), loss + + +def dock_with_gradient( + coords: np.ndarray, + pocket_coords: np.ndarray, + distance_predict_tta: np.ndarray, + holo_distance_predict_tta: np.ndarray, + mol: Chem.Mol, + conf_coords: np.ndarray, + loss_func=single_SF_loss, + holo_coords: np.ndarray = None, + iterations: int = 400, + early_stoping: int = 5, +): + """ Docking with gradient descent, optimizing the conformer. """ + bst_loss, bst_coords, bst_meta_info = 10000.0, coords, None + for i, (distance_predict, holo_distance_predict) in enumerate( + zip(distance_predict_tta, holo_distance_predict_tta) + ): + new_coords = deepcopy(coords) + _coords, _loss, _meta_info = single_dock_with_gradient( + new_coords, + pocket_coords, + distance_predict, + holo_distance_predict, + mol=mol, + conf_coords=deepcopy(np.array(conf_coords)), + loss_func=loss_func, + holo_coords=holo_coords, + iterations=iterations, + early_stoping=early_stoping, + ) + if bst_loss > _loss: + bst_coords = _coords + bst_loss = _loss + bst_meta_info = _meta_info + return bst_coords, bst_loss, bst_meta_info + + +def kabsch(x: ms.Tensor, y: ms.Tensor, weight: Optional[ms.Tensor] = None) -> ms.Tensor: + """ Aligns x onto y. x, y are ((...), N, 3) tensors. """ + if weight is None: + weight = ms.ops.ones_like(x[..., 0]) + + weight = weight / ms.ops.sum(weight, dim=-1, keepdim=True) + x_mean = ms.ops.sum(x * weight[..., None], dim=-2, keepdim=True) + y_mean = ms.ops.sum(y * weight[..., None], dim=-2, keepdim=True) + x = x - x_mean + y = y - y_mean + + # 尝试通过SVD计算旋转矩阵 + try: + cov = ms.ops.einsum("...ni,...nj->...ij", x, y * weight[..., None])[..., None, :, :] + u, s, v = ms.ops.svd(cov) + # 修正旋转矩阵行列式 + det = ms.ops.det(v) * ms.ops.det(u) + u_flip = ms.ops.ones_like(u) + u_flip = ms.ops.where(det < 0, ms.ops.index_fill(u_flip, -1, -1, -1.0), u_flip) + u = u * u_flip + rot = u @ v + x = rot @ x + except: + pass + return x + y_mean + + +def single_dock_with_gradient( + coords: np.ndarray, + pocket_coords: np.ndarray, + distance_predict: np.ndarray, + holo_distance_predict: np.ndarray, + mol: Chem.Mol, + conf_coords: np.ndarray, + loss_func=single_SF_loss, + holo_coords: np.ndarray = None, + iterations: int = 20000, + early_stoping: int = 5, +): + """ Strategy: create multiple conformers, align to coordinates, optimize the conformer """ + # 转换为MindSpore张量 + coords = ms.Tensor(coords, dtype=ms.float32) + pocket_coords = ms.Tensor(pocket_coords, dtype=ms.float32) + distance_predict = ms.Tensor(distance_predict, dtype=ms.float32) + holo_distance_predict = ms.Tensor(holo_distance_predict, dtype=ms.float32) + conf_coords = ms.Tensor(conf_coords, dtype=ms.float32) + + if holo_coords is not None: + holo_coords = ms.Tensor(holo_coords, dtype=ms.float32) + + # 准备优化参数 + num_conformers = conf_coords.shape[0] + torsion_idxs = get_flexible_torsions(mol) # (T, 4) + graph_dist_mat = ms.Tensor(Chem.GetDistanceMatrix(mol), dtype=ms.int64) # (N, N) + + # 初始化旋转角、平移和扭转角(作为可训练参数) + euler = ms.Parameter(ms.ops.randn(num_conformers, 3) * 1e-3, requires_grad=True) + trans = ms.Parameter(ms.ops.randn(num_conformers, 1, 3) + coords.mean(dim=-2)[None, None], requires_grad=True) + + if torsion_idxs.shape[-1] > 0: + torsions = get_dihedral(*ms.ops.unbind(conf_coords[..., torsion_idxs, :], dim=-2)) + torsions = 
ms.Parameter(torsions + ms.ops.randn_like(torsions) * 1e-3, requires_grad=True) + else: + torsions = ms.Parameter(ms.ops.zeros(num_conformers, 0), requires_grad=True) + + # 扩展标签的batch维度 + pocket_coords = pocket_coords[None].repeat(num_conformers, 1, 1) + distance_predict = distance_predict[None].repeat(num_conformers, 1, 1) + holo_distance_predict = holo_distance_predict[None].repeat(num_conformers, 1, 1) + + # 定义优化器 + optimizer = nn.LBFGS(params=[euler, trans, torsions], learning_rate=0.5) + + bst_loss, times = 10000.0, 0 + + # 定义训练步骤 + def train_step(): + # 重置梯度 + optimizer.zero_grad() + # 参数化配体坐标 + aux_coords = conf_coords + trans + # 中心坐标 + com = aux_coords.mean(dim=-2, keepdim=True) + # 旋转 + rot = rot_from_euler(euler) + aux_coords = ms.ops.einsum('...rc,...nc->...nr', rot, aux_coords - com) + com + pre_aux_coords = aux_coords.clone() + # 扭转角更新 + kabsch对齐 + for t, vals in zip(torsion_idxs, ms.ops.unbind(torsions, dim=-1)): + aux_coords = update_dihedral(coords=aux_coords, idxs=t.asnumpy().tolist(), value=vals, dist_mat=graph_dist_mat) + aux_coords = kabsch(aux_coords, pre_aux_coords) + + # 计算损失 + _, _, _, loss = loss_func( + aux_coords, pocket_coords, distance_predict, holo_distance_predict + ) + return loss + + # 执行优化 + for i in range(iterations): + loss = optimizer(train_step) + loss_val = loss.asnumpy().item() + + if loss_val < bst_loss: + bst_loss = loss_val + times = 0 + else: + times += 1 + if times > early_stoping: + break + + # 获取最佳构象 + aux_coords = conf_coords + trans + com = aux_coords.mean(dim=-2, keepdim=True) + rot = rot_from_euler(euler) + aux_coords = ms.ops.einsum('...rc,...nc->...nr', rot, aux_coords - com) + com + pre_aux_coords = aux_coords.clone() + for t, vals in zip(torsion_idxs, ms.ops.unbind(torsions, dim=-1)): + aux_coords = update_dihedral(coords=aux_coords, idxs=t.asnumpy().tolist(), value=vals, dist_mat=graph_dist_mat) + aux_coords = kabsch(aux_coords, pre_aux_coords) + + cross_score, self_score, clash_score, loss = loss_func( + aux_coords, pocket_coords, distance_predict, holo_distance_predict, reduce_batch=False + ) + best_idx = ms.ops.argmax(loss).asnumpy().item() + return aux_coords[best_idx].asnumpy(), loss[best_idx].asnumpy(), ( + cross_score[best_idx], self_score[best_idx], clash_score[best_idx] + ) + + +def set_coord(mol, coords): + for i in range(coords.shape[0]): + mol.GetConformer(0).SetAtomPosition(i, coords[i].tolist()) + return mol + + +def add_coord(mol, xyz): + x, y, z = xyz + conf = mol.GetConformer(0) + pos = conf.GetPositions() + pos[:, 0] += x + pos[:, 1] += y + pos[:, 2] += z + for i in range(pos.shape[0]): + conf.SetAtomPosition( + i, Chem.rdGeometry.Point3D(pos[i][0], pos[i][1], pos[i][2]) + ) + return mol + + +def single_docking(input_path: str, output_path: str, output_ligand_path: str): + """ Performs docking based on UniMol predictions. 
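Args:
+        input_path: path to the input pickle file
+        output_path: path to the output pickle file (skipped if None)
+        output_ligand_path: path to the output ligand sdf file (skipped if None)
+
+    Returns:
+        True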
""" + content = pd.read_pickle(input_path) + ( + init_coords_tta, + mol, + smi, + pocket, + pocket_coords, + distance_predict_tta, + holo_distance_predict_tta, + holo_coords, + holo_cener_coords, + ) = content + sample_times = len(init_coords_tta) + + bst_predict_coords, bst_loss, bst_meta_info = None, 1000.0, None + for i in range(sample_times): + init_coords = init_coords_tta[i] + predict_coords, loss, meta_info = dock_with_gradient( + init_coords, + pocket_coords, + distance_predict_tta[i][None], + holo_distance_predict_tta[i][None], + mol=mol, + conf_coords=init_coords_tta[i][None], + holo_coords=holo_coords, + loss_func=single_SF_loss, + ) + if loss < bst_loss: + bst_loss = loss + bst_predict_coords = predict_coords + bst_meta_info = meta_info + + _rmsd = round(rmsd_func(holo_coords, bst_predict_coords, mol=mol), 4) + _cross_score = round(float(bst_meta_info[0]), 4) + _self_score = round(float(bst_meta_info[1]), 4) + _clash_score = round(float(bst_meta_info[2]), 4) + print(f"{pocket}-{smi}-RMSD:{_rmsd}-CROSSSCORE:{_cross_score}-SELFSCORE:{_self_score}-CLASHSCORE:{_clash_score}") + mol = Chem.RemoveHs(mol) + mol = set_coord(mol, bst_predict_coords) + + if output_path is not None: + with open(output_path, "wb") as f: + pickle.dump( + [mol, bst_predict_coords, holo_coords, bst_loss, smi, pocket, pocket_coords], + f, + ) + if output_ligand_path is not None: + mol = add_coord(mol, holo_cener_coords.asnumpy()) + Chem.MolToMolFile(mol, output_ligand_path) + + return True + + +if __name__ == "__main__": + ms.set_seed(0) # 替换PyTorch的随机种子设置 + parser = argparse.ArgumentParser(description="Docking with gradient") + parser.add_argument("--input", type=str, help="input file.") + parser.add_argument("--output", type=str, default=None, help="output path.") + parser.add_argument( + "--output-ligand", type=str, default=None, help="output ligand sdf path." + ) + args = parser.parse_args() + + single_docking(args.input, args.output, args.output_ligand) +# import numpy as np +# import torch as th +# import pandas as pd +# from copy import deepcopy +# from rdkit import Chem +# from rdkit.Chem import AllChem +# import pickle +# import argparse +# import warnings +# from docking_utils import rmsd_func +# from typing import List, Optional + + +# warnings.filterwarnings(action="ignore") + +# # Utils + +# def rot_from_axis_angle(axis: th.Tensor, angle: th.Tensor) -> th.Tensor: +# """ ((...), 3), ((...),) -> ((...), 3, 3) """ +# # ((...), D) -> ((...),) +# v1, v2, v3 = th.nn.functional.normalize(axis, dim=-1).unbind(dim=-1) +# zero = th.zeros_like(v1) +# # ((...),) -> ((...), 3, 3) +# cross_matrix = th.stack( +# ( +# th.stack((zero, -v3, v2), dim=-1), +# th.stack((v3, zero, -v1), dim=-1), +# th.stack((-v2, v1, zero), dim=-1), +# ), +# dim=-2, +# ) +# ide = th.eye(3, device=v1.device, dtype=v1.dtype).repeat( +# *(1,) * len(v1.shape), 1, 1 +# ) +# angle = angle.unsqueeze(dim=-1).unsqueeze(dim=-1) +# return ( +# ide +# + th.sin(angle) * cross_matrix +# + (1 - th.cos(angle)) * (cross_matrix @ cross_matrix) +# ) + +# def rot_from_euler(alpha_beta_gamma: th.Tensor) -> th.Tensor: +# """ rotation from euler angles. 
((...), 3) -> ((...), 3, 3) """ +# alpha, beta, gamma = alpha_beta_gamma.clone().unbind(dim=-1) +# zeros = th.zeros_like(alpha) +# Rx_tensor = th.stack(( +# (alpha + 1) / (alpha + 1), zeros, zeros, +# zeros, th.cos(alpha), - th.sin(alpha), +# zeros, th.sin(alpha), th.cos(alpha) +# ), axis=-1).reshape(*alpha.shape, 3, 3) +# Ry_tensor = th.stack(( +# th.cos(beta), zeros, - th.sin(beta), +# zeros, (beta + 1) / (beta + 1), zeros, +# th.sin(beta), zeros, th.cos(beta) +# ), axis=-1).reshape(*beta.shape, 3, 3) +# Rz_tensor = th.stack(( +# th.cos(gamma), -th.sin(gamma), zeros, +# th.sin(gamma), th.cos(gamma), zeros, +# zeros, zeros, (gamma + 1) / (gamma + 1) +# ), axis=-1).reshape(*gamma.shape, 3, 3) + +# R = (Rx_tensor @ Ry_tensor) @ Rz_tensor +# return R + +# def get_dihedral( +# c1: th.Tensor, c2: th.Tensor, c3: th.Tensor, c4: th.Tensor, eps: float = 1e-7 +# ) -> th.Tensor: +# """ Dihedral angle in radians. atan2 formula from: +# https://en.wikipedia.org/wiki/Dihedral_angle#In_polymer_physics +# Inputs: c1, c2, c3, c4 are all ((...), 3,) +# * eps: float. small number to avoid division by zero. +# Outputs: ((...),) tensor +# """ +# u1 = c2 - c1 +# u2 = c3 - c2 +# u3 = c4 - c3 + +# u2u3_cross = th.cross(u2, u3, dim=-1) +# u2norm = u2.square().sum(dim=-1, keepdim=True).add(eps).sqrt() + +# return th.atan2( +# (u2norm * u1 * u2u3_cross).sum(-1), +# (th.cross(u1, u2, dim=-1) * u2u3_cross).sum(-1), +# ) + +# def get_flexible_torsions(mol: Chem.Mol) -> th.Tensor: +# """ Gets a unique set of ligand torsions which are rotatable. Shape: (T, 4) """ +# # get 3-hop connected atoms, directionally so no repeats +# dist_mat = th.from_numpy(Chem.GetDistanceMatrix(mol)) +# # get rotatable bonds +# torsionSmarts = "[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]" +# torsionQuery = Chem.MolFromSmarts(torsionSmarts) +# matches = set(mol.GetSubstructMatches(torsionQuery)) +# # get 3-hop connected atoms, directionally so no repeats +# i_, l_ = (dist_mat.triu() == 3).bool().nonzero().T.tolist() +# # Shortest path, where rotatable bond in the middle is a torsion +# flex_unique_torsions = [] +# central_bond_torsions = set() +# for i, l in zip(i_, l_): +# i, j, k, l = Chem.GetShortestPath(mol, i, l) +# if {(j, k), (k, j)}.intersection(matches) and (j, k) not in central_bond_torsions: +# # Register the central bond +# central_bond_torsions.update([(j, k), (k, j)]) +# # torsion in the direction that leaves lesser atoms to later rotate: towards the periphery +# if (dist_mat[j] < dist_mat[k]).sum() > (dist_mat[j] > dist_mat[k]).sum(): +# flex_unique_torsions.append([i, j, k, l]) +# else: +# flex_unique_torsions.append([l, k, j, i]) +# return th.tensor(flex_unique_torsions) + + +# def rotate_along_axis(x: th.Tensor, origin: th.Tensor, axis: th.Tensor, angle: th.Tensor) -> th.Tensor: +# """ Rotates a point cloud around an axis given an origin +# Inputs: +# * x: ((...), N, 3) +# * origin: ((...), N_or_1, 3) +# * axis: ((...), 3) +# * angle: (,) th.Tensor +# Outputs: ((...), N, 3) rotated coordinates +# """ +# rot_mat = rot_from_axis_angle(axis, angle) +# return th.einsum('...rc,...nc -> ...nr', rot_mat, x - origin) + origin + + +# def update_dihedral(coords: th.Tensor, idxs: List[int], value: float, dist_mat: th.Tensor = None) -> th.Tensor: +# """Modifies a dihedral/torsion for a molecule with the given value. +# Analog to rdkit.Chem.rdMolTransforms.SetDihedralRad, but differentiable. +# WARNING! Assumes bond between j-k is rotatble +# Inputs: +# * coords: ((...), N, 3) +# * idxs: (4,) List or th.Tensor of dtype th.long. 
indexes to define the torsion +# * value: float or th.Tensor of single value or ((...),). New value for the torsion (in radians) +# * dist_mat: (N, N) length of shortest path for each i-j. +# Outputs: ((...), N, 3) updated coords +# """ +# i, j, k, l = idxs +# if not isinstance(value, th.Tensor): +# value = th.tensor(value, dtype=coords.dtype, device=coords.device) + +# # atoms whose coords will be updated - closer to k than j +# mask_rotate = dist_mat[k] < dist_mat[j] + +# # amount to rotate is the difference between current and desired +# coords[..., mask_rotate, :] = rotate_along_axis( +# x=coords[..., mask_rotate, :], +# origin=coords[..., [j], :], +# axis=coords[..., k, :] - coords[..., j, :], +# angle=value - get_dihedral(*coords[..., idxs, :].unbind(dim=-2)), +# ) +# return coords + + +# # Docking functions + +# def single_SF_loss( +# predict_coords: th.Tensor, +# pocket_coords: th.Tensor, +# cross_distance_predict: th.Tensor, +# self_distance_predict: th.Tensor, +# dist_threshold: float = 4.5, +# cross_dist_weight: float = 1.0, +# self_dist_weight: float = 2.0, +# reduce_batch: bool = True, +# ): +# """ Calculates loss function +# Args: +# predict_coords: ((...), N, 3) predicted molecule coordinates +# pocket_coords: ((...), P, 3) pocket coordinates +# cross_distance_predict: ((...), N, P) predicted (molecule-pocket) distance matrix +# self_distance_predict: ((...), N, N) predicted (molecule-molecule) distance +# dist_threshold: max dist to consider molecule-pocket interactions in the loss +# cross_dist_weight: weight of cross distance loss +# self_dist_weight: weight of self distance loss +# reduce_batch: whether to reduce the batch dimension + +# Returns: +# cross_dist_score: cross distance score. scalar. numpy +# dist_score: distance score. scalar. numpy +# clash_score. clash score. informative. scalar. numpy. +# loss: loss value. scalar. has gradients +# """ +# # ((...), N, 1, 3) - ((...), 1, P, 3) -> ((...), N, P) +# cross_dist = (predict_coords[..., None, :] - pocket_coords[..., None, :, :]).norm(dim=-1) +# # ((...), N, 1, 3) - ((...), 1, N, 3) -> ((...), N, N) +# self_dist = (predict_coords[..., None, :] - predict_coords[..., None, :, :]).norm(dim=-1) +# # only consider local molecule-pocket interactions +# dist_mask = cross_distance_predict < dist_threshold +# # ((...), N, N) -> ((...),) +# cross_dist_score = ((cross_dist - cross_distance_predict)**2 * dist_mask).sum() / dist_mask.sum(dim=(-1, -2)) +# dist_score = ((self_dist - self_distance_predict) ** 2).mean(dim=(-1, -2)) +# # weight different loss terms +# loss = cross_dist_score * cross_dist_weight + dist_score * self_dist_weight +# # penalize clashes - informative +# clash_pl_score = ((cross_dist - 3.).clamp(max=0) * 5.).square() +# clash_pl_score = clash_pl_score.sum(dim=(-1, -2)) / dist_mask.sum(dim=(-1, -2)) +# if reduce_batch: +# return cross_dist_score.detach().mean().numpy(), dist_score.detach().mean().numpy(), 0., loss.mean() +# return cross_dist_score.detach().numpy(), dist_score.detach().numpy(), clash_pl_score.detach().numpy(), loss + + + +# def dock_with_gradient( +# coords: np.ndarray, +# pocket_coords: np.ndarray, +# distance_predict_tta: np.ndarray, +# holo_distance_predict_tta: np.ndarray, +# mol: Chem.Mol, +# conf_coords: np.ndarray, +# loss_func=single_SF_loss, +# holo_coords: np.ndarray = None, +# iterations: int =400, +# early_stoping: int = 5, +# ): +# """ Docking with gradient descent, optimizing the conformer. 
+ +# Args: +# coords: (N, 3) initial molecule coordinates +# pocket_coords: (P, 3) pocket coordinates +# distance_predict_tta: (?, T, N, P) predicted (molecule-pocket) distance matrix +# holo_distance_predict_tta: (?, T, N, N) predicted (molecule-molecule) distance matrix +# mol: rdkit molecule +# conf_coords: (?, N, 3) initial molecule conformers coordinates +# loss_func: function to calculate loss +# holo_coords: (?, T, N, 3) holo molecule coordinates +# iterations: max number of iterations +# early_stoping: stop if loss does not improve for this number of iterations + +# Returns: +# bst_coords: (N, 3) optimized molecule coordinates +# bst_loss: loss value. scalar. has gradients +# bst_meta_info: dict with additional info +# """ +# bst_loss, bst_coords, bst_meta_info = 10000.0, coords, None +# for i, (distance_predict, holo_distance_predict) in enumerate( +# zip(distance_predict_tta, holo_distance_predict_tta) +# ): +# new_coords = deepcopy(coords) +# _coords, _loss, _meta_info = single_dock_with_gradient( +# new_coords, +# pocket_coords, +# distance_predict, +# holo_distance_predict, +# mol=mol, +# conf_coords=deepcopy(np.array(conf_coords)), +# loss_func=loss_func, +# holo_coords=holo_coords, +# iterations=iterations, +# early_stoping=early_stoping, +# ) +# if bst_loss > _loss: +# bst_coords = _coords +# bst_loss = _loss +# bst_meta_info = _meta_info +# return bst_coords, bst_loss, bst_meta_info + + +# def kabsch(x: th.Tensor, y: th.Tensor, weight: Optional[th.Tensor] = None) -> th.Tensor: +# """ Aligns x onto y. x, y are ((...), N, 3) tensors. Weights is ((...), N) +# If rotation fails, at least bring to X to Y's COM +# """ +# if weight is None: +# weight = th.ones_like(x[..., 0]) + +# weight = weight / weight.sum(dim=-1, keepdim=True) +# x_mean = (x * weight[..., None]).sum(dim=-2, keepdim=True) +# y_mean = (y * weight[..., None]).sum(dim=-2, keepdim=True) +# x = x - x_mean +# y = y - y_mean + +# # if rotation fails (SVD might fail if matrix is ill-behaved), just bring to same COM +# try: +# # ((...), N, 3) -> ((...), 1, 3, 3) +# cov = th.einsum("...ni,...nj->...ij", x, y * weight[..., None])[..., None, :, :] +# u, s, v = th.linalg.svd(cov) +# # Flip the sign of bottom row of each matrix if det product < 0 +# det = th.det(v) * th.det(u) +# u_flip = th.ones_like(u) +# u_flip[det < 0, :, -1] = -1.0 +# u = u * u_flip +# rot = u @ v +# # align to rotation +# x = rot @ x +# except: +# pass +# return x + y_mean + + + +# def single_dock_with_gradient( +# coords: np.ndarray, +# pocket_coords: np.ndarray, +# distance_predict: np.ndarray, +# holo_distance_predict: np.ndarray, +# mol: Chem.Mol, +# conf_coords: np.ndarray, +# loss_func=single_SF_loss, +# holo_coords: np.ndarray = None, +# iterations: int = 20000, +# early_stoping: int = 5, +# ): +# """ Strategy: create multiple conformers, align to coordinates, optimize the conformer +# to minimize the loss function. Then pick the conformer with the lowest loss + +# Args: +# coords: (N, 3) initial molecule coordinates +# pocket_coords: (P, 3) pocket coordinates +# distance_predict: (N, P) predicted (molecule-pocket) distance matrix +# holo_distance_predict: (N, N) predicted (molecule-molecule) distance matrix +# mol: rdkit mol object. 
to extract graph connectivity +# conf_coords: (B, N, 3) initial conformer coordinates +# loss_func: function to calculate loss +# holo_coords: (N, 3) holo molecule coordinates +# iterations: max number of iterations +# early_stoping: stop if loss does not improve for this number of iterations + +# Returns: +# coords: (N, 3) optimized molecule coordinates +# loss: loss value. scalar. numpy +# meta_info: dict with additional info +# """ +# # convert to torch +# coords = th.from_numpy(coords).float() +# pocket_coords = th.from_numpy(pocket_coords).float() +# distance_predict = th.from_numpy(distance_predict).float() +# holo_distance_predict = th.from_numpy(holo_distance_predict).float() +# conf_coords = th.from_numpy(conf_coords).float() + +# if holo_coords is not None: +# holo_coords = th.from_numpy(holo_coords).float() + +# # prepare optimization params +# num_conformers = conf_coords.shape[0] +# torsion_idxs = get_flexible_torsions(mol) # (T, 4) +# graph_dist_mat = th.from_numpy(Chem.GetDistanceMatrix(mol)).long() # (N, N) + +# # (B, 3) +# euler = th.randn(num_conformers, 3) * 1e-3 +# # init translation approx at ligand Center of Mass: (B, 1, 3) +# trans = th.randn(num_conformers, 1, 3) + coords.mean(dim=-2)[None, None] +# # (B, T) +# if torsion_idxs.shape[-1] > 0: +# torsions = get_dihedral(*conf_coords[..., torsion_idxs, :].unbind(dim=-2)) +# torsions += th.randn_like(torsions) * 1e-3 +# else: +# torsions = th.zeros(num_conformers, 0) + +# # add batch dim to labels +# pocket_coords = pocket_coords[None].repeat(num_conformers, 1, 1) +# distance_predict = distance_predict[None].repeat(num_conformers, 1, 1) +# holo_distance_predict = holo_distance_predict[None].repeat(num_conformers, 1, 1) + +# # set gradients and optimizer +# euler.requires_grad = True +# trans.requires_grad = True +# torsions.requires_grad = True + +# optimizer = th.optim.LBFGS(params=[euler, trans, torsions], lr=0.5) +# bst_loss, times = 10000.0, 0 +# for i in range(iterations): +# def closure(): +# optimizer.zero_grad() +# # parametrize ligand with 6+K +# aux_coords = conf_coords.detach().clone() + trans +# # frame update +# com = aux_coords.mean(dim=-2, keepdim=True) +# rot = rot_from_euler(euler) +# aux_coords = th.einsum('...rc,...nc->...nr', rot, aux_coords - com) + com +# pre_aux_coords = aux_coords.clone() +# # torsion update + kabsch -> makes 6 & T orthogonal in the tangent space +# for t, vals in zip(torsion_idxs, torsions.unbind(dim=-1)): +# aux_coords = update_dihedral(coords=aux_coords, idxs=t.tolist(), value=vals, dist_mat=graph_dist_mat) +# aux_coords = kabsch(aux_coords, pre_aux_coords) + +# _, _, _, loss = loss_func( +# aux_coords, pocket_coords, distance_predict, holo_distance_predict +# ) +# loss.backward() +# return loss + +# loss = optimizer.step(closure) +# # print(f"Iter: {i} and loss: {loss}") +# if loss.item() < bst_loss: +# bst_loss = loss.item() +# times = 0 +# else: +# times += 1 +# if times > early_stoping: +# break + +# # pick the conformer with lowest loss +# aux_coords = conf_coords.detach().clone() + trans +# # frame update +# com = aux_coords.mean(dim=-2, keepdim=True) +# rot = rot_from_euler(euler) +# aux_coords = th.einsum('...rc,...nc->...nr', rot, aux_coords - com) + com +# pre_aux_coords = aux_coords.clone() +# # torsion update + kabsch -> makes 6 & T orthogonal in the tangent space +# for t, vals in zip(torsion_idxs, torsions.unbind(dim=-1)): +# aux_coords = update_dihedral(coords=aux_coords, idxs=t.tolist(), value=vals, dist_mat=graph_dist_mat) +# aux_coords = kabsch(aux_coords, 
pre_aux_coords) + +# cross_score, self_score, clash_score, loss = loss_func( +# aux_coords, pocket_coords, distance_predict, holo_distance_predict, reduce_batch=False +# ) +# best_idx = loss.argmax(dim=-1).item() +# return aux_coords[best_idx].detach().numpy(), loss[best_idx].detach().numpy(), ( +# cross_score[best_idx], self_score[best_idx], clash_score[best_idx] +# ) + + +# def set_coord(mol, coords): +# for i in range(coords.shape[0]): +# mol.GetConformer(0).SetAtomPosition(i, coords[i].tolist()) +# return mol + + +# def add_coord(mol, xyz): +# x, y, z = xyz +# conf = mol.GetConformer(0) +# pos = conf.GetPositions() +# pos[:, 0] += x +# pos[:, 1] += y +# pos[:, 2] += z +# for i in range(pos.shape[0]): +# conf.SetAtomPosition( +# i, Chem.rdGeometry.Point3D(pos[i][0], pos[i][1], pos[i][2]) +# ) +# return mol + + +# def single_docking(input_path: str, output_path: str, output_ligand_path: str): +# """ Performs docking based on UniMol predictions. + +# Args: +# input_path: path to the input file +# output_path: path to the output file +# output_ligand_path: path to the output ligand file +# sym_rmsd: whether to use symmetric RMSD: consider best of symmetric atoms + +# Returns: +# True +# """ +# content = pd.read_pickle(input_path) +# ( +# init_coords_tta, +# mol, +# smi, +# pocket, +# pocket_coords, +# distance_predict_tta, +# holo_distance_predict_tta, +# holo_coords, +# holo_cener_coords, +# ) = content +# sample_times = len(init_coords_tta) + +# bst_predict_coords, bst_loss, bst_meta_info = None, 1000.0, None +# for i in range(sample_times): +# init_coords = init_coords_tta[i] +# predict_coords, loss, meta_info = dock_with_gradient( +# init_coords, +# pocket_coords, +# distance_predict_tta[i][None], +# holo_distance_predict_tta[i][None], +# mol=mol, +# conf_coords=init_coords_tta[i][None], +# holo_coords=holo_coords, +# loss_func=single_SF_loss, +# ) +# if loss < bst_loss: +# bst_loss = loss +# bst_predict_coords = predict_coords +# bst_meta_info = meta_info + +# _rmsd = round(rmsd_func(holo_coords, bst_predict_coords, mol=mol), 4) +# _cross_score = round(float(bst_meta_info[0]), 4) +# _self_score = round(float(bst_meta_info[1]), 4) +# _clash_score = round(float(bst_meta_info[2]), 4) +# print(f"{pocket}-{smi}-RMSD:{_rmsd}-CROSSSCORE:{_cross_score}-SELFSCORE:{_self_score}-CLASHSCORE:{_clash_score}") +# mol = Chem.RemoveHs(mol) +# mol = set_coord(mol, bst_predict_coords) + +# if output_path is not None: +# with open(output_path, "wb") as f: +# pickle.dump( +# [mol, bst_predict_coords, holo_coords, bst_loss, smi, pocket, pocket_coords], +# f, +# ) +# if output_ligand_path is not None: +# mol = add_coord(mol, holo_cener_coords.numpy()) +# Chem.MolToMolFile(mol, output_ligand_path) + +# return True + + +# if __name__ == "__main__": +# th.set_num_threads(1) +# th.manual_seed(0) +# parser = argparse.ArgumentParser(description="Docking with gradient") +# parser.add_argument("--input", type=str, help="input file.") +# parser.add_argument("--output", type=str, default=None, help="output path.") +# parser.add_argument( +# "--output-ligand", type=str, default=None, help="output ligand sdf path." 
+# ) +# args = parser.parse_args() + +# single_docking(args.input, args.output, args.output_ligand) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/coordinate_model.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/coordinate_model.py new file mode 100644 index 0000000000000000000000000000000000000000..e31764dcaa7d51216e4c5d255405febb7acff703 --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/coordinate_model.py @@ -0,0 +1,438 @@ +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. +import copy +import mindspore as ms +import mindspore.mint as mint +import pandas as pd +from rdkit import Chem +import pickle +import argparse +from docking_utils import rmsd_func +import warnings + +# 设置Ascend设备 +ms.set_context(device_target="Ascend") +warnings.filterwarnings(action="ignore") + + +def single_SF_loss( + predict_coords, + pocket_coords, + distance_predict, + holo_distance_predict, + dist_threshold=4.5, +): + # 替换torch.norm为mindspore.mint.norm + dist = mint.norm(predict_coords.unsqueeze(1) - pocket_coords.unsqueeze(0), dim=-1) + holo_dist = mint.norm( + predict_coords.unsqueeze(1) - predict_coords.unsqueeze(0), dim=-1 + ) + distance_mask = distance_predict < dist_threshold + cross_dist_score = ( + (dist[distance_mask] - distance_predict[distance_mask]) **2 + ).mean() + dist_score = ((holo_dist - holo_distance_predict)** 2).mean() + loss = cross_dist_score * 1.0 + dist_score * 5.0 + return loss + + +def scoring( + predict_coords, + pocket_coords, + distance_predict, + holo_distance_predict, + dist_threshold=4.5, +): + predict_coords = predict_coords.detach() + # 替换torch.norm为mindspore.mint.norm + dist = mint.norm(predict_coords.unsqueeze(1) - pocket_coords.unsqueeze(0), dim=-1) + holo_dist = mint.norm( + predict_coords.unsqueeze(1) - predict_coords.unsqueeze(0), dim=-1 + ) + distance_mask = distance_predict < dist_threshold + cross_dist_score = ( + (dist[distance_mask] - distance_predict[distance_mask]) **2 + ).mean() + dist_score = ((holo_dist - holo_distance_predict)** 2).mean() + # 替换.numpy()为.asnumpy() + return cross_dist_score.asnumpy(), dist_score.asnumpy() + + +def dock_with_gradient( + coords, + pocket_coords, + distance_predict_tta, + holo_distance_predict_tta, + loss_func=single_SF_loss, + holo_coords=None, + iterations=20000, + early_stoping=5, +): + bst_loss, bst_coords, bst_meta_info = 10000.0, coords, None + for i, (distance_predict, holo_distance_predict) in enumerate( + zip(distance_predict_tta, holo_distance_predict_tta) + ): + new_coords = copy.deepcopy(coords) + _coords, _loss, _meta_info = single_dock_with_gradient( + new_coords, + pocket_coords, + distance_predict, + holo_distance_predict, + loss_func=loss_func, + holo_coords=holo_coords, + iterations=iterations, + early_stoping=early_stoping, + ) + if bst_loss > _loss: + bst_coords = _coords + bst_loss = _loss + bst_meta_info = _meta_info + return bst_coords, bst_loss, bst_meta_info + + +def single_dock_with_gradient( + coords, + pocket_coords, + distance_predict, + holo_distance_predict, + loss_func=single_SF_loss, + holo_coords=None, + iterations=20000, + early_stoping=5, +): + # 替换torch.from_numpy为mindspore.Tensor.from_numpy并设置float32类型 + coords = ms.Tensor.from_numpy(coords).astype(ms.float32) + pocket_coords = ms.Tensor.from_numpy(pocket_coords).astype(ms.float32) + distance_predict = 
ms.Tensor.from_numpy(distance_predict).astype(ms.float32) + holo_distance_predict = ms.Tensor.from_numpy(holo_distance_predict).astype(ms.float32) + + if holo_coords is not None: + holo_coords = ms.Tensor.from_numpy(holo_coords).astype(ms.float32) + + # 设置需要梯度 + coords.requires_grad = True + # 替换PyTorch优化器为MindSpore优化器 + optimizer = mint.optim.LBFGS([coords], learning_rate=1.0) + bst_loss, times = 10000.0, 0 + + def closure(): + optimizer.zero_grad() + loss = loss_func( + coords, pocket_coords, distance_predict, holo_distance_predict + ) + loss.backward() + return loss + + for i in range(iterations): + loss = optimizer.step(closure) + # 获取损失值(MindSpore中item()方法类似) + current_loss = loss.item() + if current_loss < bst_loss: + bst_loss = current_loss + times = 0 + else: + times += 1 + if times > early_stoping: + break + + meta_info = scoring(coords, pocket_coords, distance_predict, holo_distance_predict) + # 替换.detach().numpy()为.detach().asnumpy() + return coords.detach().asnumpy(), loss.detach().asnumpy(), meta_info + + +def set_coord(mol, coords): + for i in range(coords.shape[0]): + mol.GetConformer(0).SetAtomPosition(i, coords[i].tolist()) + return mol + + +def add_coord(mol, xyz): + x, y, z = xyz + conf = mol.GetConformer(0) + pos = conf.GetPositions() + pos[:, 0] += x + pos[:, 1] += y + pos[:, 2] += z + for i in range(pos.shape[0]): + conf.SetAtomPosition( + i, Chem.rdGeometry.Point3D(pos[i][0], pos[i][1], pos[i][2]) + ) + return mol + + +def single_docking(input_path, output_path, output_ligand_path): + content = pd.read_pickle(input_path) + ( + init_coords_tta, + mol, + smi, + pocket, + pocket_coords, + distance_predict_tta, + holo_distance_predict_tta, + holo_coords, + holo_cener_coords, + ) = content + sample_times = len(init_coords_tta) + bst_predict_coords, bst_loss, bst_meta_info = None, 1000.0, None + for i in range(sample_times): + init_coords = init_coords_tta[i] + predict_coords, loss, meta_info = dock_with_gradient( + init_coords, + pocket_coords, + distance_predict_tta, + holo_distance_predict_tta, + holo_coords=holo_coords, + loss_func=single_SF_loss, + ) + if loss < bst_loss: + bst_loss = loss + bst_predict_coords = predict_coords + bst_meta_info = meta_info + + _rmsd = round(rmsd_func(holo_coords, bst_predict_coords, mol), 4) + _cross_score = round(float(bst_meta_info[0]), 4) + _self_score = round(float(bst_meta_info[1]), 4) + print(f"{pocket}-{smi}-RMSD:{_rmsd}-{_cross_score}-{_self_score}") + mol = Chem.RemoveHs(mol) + mol = set_coord(mol, bst_predict_coords) + + if output_path is not None: + with open(output_path, "wb") as f: + pickle.dump( + [mol, bst_predict_coords, holo_coords, bst_loss, smi, pocket, pocket_coords], + f, + ) + if output_ligand_path is not None: + # 替换.numpy()为.asnumpy() + mol = add_coord(mol, holo_cener_coords.asnumpy()) + Chem.MolToMolFile(mol, output_ligand_path) + + return True + + +if __name__ == "__main__": + # 替换PyTorch线程和种子设置为MindSpore + ms.set_num_threads(1) + ms.set_seed(0) + parser = argparse.ArgumentParser(description="Docking with gradient") + parser.add_argument("--input", type=str, help="input file.") + parser.add_argument("--output", type=str, default=None, help="output path.") + parser.add_argument( + "--output-ligand", type=str, default=None, help="output ligand sdf path." 
+ ) + args = parser.parse_args() + + single_docking(args.input, args.output, args.output_ligand) +# import copy +# import torch +# import pandas as pd +# from rdkit import Chem +# import pickle +# import argparse +# from docking_utils import rmsd_func +# import warnings + +# warnings.filterwarnings(action="ignore") + + +# def single_SF_loss( +# predict_coords, +# pocket_coords, +# distance_predict, +# holo_distance_predict, +# dist_threshold=4.5, +# ): +# dist = torch.norm(predict_coords.unsqueeze(1) - pocket_coords.unsqueeze(0), dim=-1) +# holo_dist = torch.norm( +# predict_coords.unsqueeze(1) - predict_coords.unsqueeze(0), dim=-1 +# ) +# distance_mask = distance_predict < dist_threshold +# cross_dist_score = ( +# (dist[distance_mask] - distance_predict[distance_mask]) ** 2 +# ).mean() +# dist_score = ((holo_dist - holo_distance_predict) ** 2).mean() +# loss = cross_dist_score * 1.0 + dist_score * 5.0 +# return loss + + +# def scoring( +# predict_coords, +# pocket_coords, +# distance_predict, +# holo_distance_predict, +# dist_threshold=4.5, +# ): +# predict_coords = predict_coords.detach() +# dist = torch.norm(predict_coords.unsqueeze(1) - pocket_coords.unsqueeze(0), dim=-1) +# holo_dist = torch.norm( +# predict_coords.unsqueeze(1) - predict_coords.unsqueeze(0), dim=-1 +# ) +# distance_mask = distance_predict < dist_threshold +# cross_dist_score = ( +# (dist[distance_mask] - distance_predict[distance_mask]) ** 2 +# ).mean() +# dist_score = ((holo_dist - holo_distance_predict) ** 2).mean() +# return cross_dist_score.numpy(), dist_score.numpy() + + +# def dock_with_gradient( +# coords, +# pocket_coords, +# distance_predict_tta, +# holo_distance_predict_tta, +# loss_func=single_SF_loss, +# holo_coords=None, +# iterations=20000, +# early_stoping=5, +# ): +# bst_loss, bst_coords, bst_meta_info = 10000.0, coords, None +# for i, (distance_predict, holo_distance_predict) in enumerate( +# zip(distance_predict_tta, holo_distance_predict_tta) +# ): +# new_coords = copy.deepcopy(coords) +# _coords, _loss, _meta_info = single_dock_with_gradient( +# new_coords, +# pocket_coords, +# distance_predict, +# holo_distance_predict, +# loss_func=loss_func, +# holo_coords=holo_coords, +# iterations=iterations, +# early_stoping=early_stoping, +# ) +# if bst_loss > _loss: +# bst_coords = _coords +# bst_loss = _loss +# bst_meta_info = _meta_info +# return bst_coords, bst_loss, bst_meta_info + + +# def single_dock_with_gradient( +# coords, +# pocket_coords, +# distance_predict, +# holo_distance_predict, +# loss_func=single_SF_loss, +# holo_coords=None, +# iterations=20000, +# early_stoping=5, +# ): +# coords = torch.from_numpy(coords).float() +# pocket_coords = torch.from_numpy(pocket_coords).float() +# distance_predict = torch.from_numpy(distance_predict).float() +# holo_distance_predict = torch.from_numpy(holo_distance_predict).float() + +# if holo_coords is not None: +# holo_coords = torch.from_numpy(holo_coords).float() + +# coords.requires_grad = True +# optimizer = torch.optim.LBFGS([coords], lr=1.0) +# bst_loss, times = 10000.0, 0 +# for i in range(iterations): + +# def closure(): +# optimizer.zero_grad() +# loss = loss_func( +# coords, pocket_coords, distance_predict, holo_distance_predict +# ) +# loss.backward() +# return loss + +# loss = optimizer.step(closure) +# if loss.item() < bst_loss: +# bst_loss = loss.item() +# times = 0 +# else: +# times += 1 +# if times > early_stoping: +# break + +# meta_info = scoring(coords, pocket_coords, distance_predict, holo_distance_predict) +# return 
coords.detach().numpy(), loss.detach().numpy(), meta_info + + +# def set_coord(mol, coords): +# for i in range(coords.shape[0]): +# mol.GetConformer(0).SetAtomPosition(i, coords[i].tolist()) +# return mol + + +# def add_coord(mol, xyz): +# x, y, z = xyz +# conf = mol.GetConformer(0) +# pos = conf.GetPositions() +# pos[:, 0] += x +# pos[:, 1] += y +# pos[:, 2] += z +# for i in range(pos.shape[0]): +# conf.SetAtomPosition( +# i, Chem.rdGeometry.Point3D(pos[i][0], pos[i][1], pos[i][2]) +# ) +# return mol + + +# def single_docking(input_path, output_path, output_ligand_path): +# content = pd.read_pickle(input_path) +# ( +# init_coords_tta, +# mol, +# smi, +# pocket, +# pocket_coords, +# distance_predict_tta, +# holo_distance_predict_tta, +# holo_coords, +# holo_cener_coords, +# ) = content +# sample_times = len(init_coords_tta) +# bst_predict_coords, bst_loss, bst_meta_info = None, 1000.0, None +# for i in range(sample_times): +# init_coords = init_coords_tta[i] +# predict_coords, loss, meta_info = dock_with_gradient( +# init_coords, +# pocket_coords, +# distance_predict_tta, +# holo_distance_predict_tta, +# holo_coords=holo_coords, +# loss_func=single_SF_loss, +# ) +# if loss < bst_loss: +# bst_loss = loss +# bst_predict_coords = predict_coords +# bst_meta_info = meta_info + +# _rmsd = round(rmsd_func(holo_coords, bst_predict_coords, mol), 4) +# _cross_score = round(float(bst_meta_info[0]), 4) +# _self_score = round(float(bst_meta_info[1]), 4) +# print(f"{pocket}-{smi}-RMSD:{_rmsd}-{_cross_score}-{_self_score}") +# mol = Chem.RemoveHs(mol) +# mol = set_coord(mol, bst_predict_coords) + +# if output_path is not None: +# with open(output_path, "wb") as f: +# pickle.dump( +# [mol, bst_predict_coords, holo_coords, bst_loss, smi, pocket, pocket_coords], +# f, +# ) +# if output_ligand_path is not None: +# mol = add_coord(mol, holo_cener_coords.numpy()) +# Chem.MolToMolFile(mol, output_ligand_path) + +# return True + + +# if __name__ == "__main__": +# torch.set_num_threads(1) +# torch.manual_seed(0) +# parser = argparse.ArgumentParser(description="Docking with gradient") +# parser.add_argument("--input", type=str, help="input file.") +# parser.add_argument("--output", type=str, default=None, help="output path.") +# parser.add_argument( +# "--output-ligand", type=str, default=None, help="output ligand sdf path." +# ) +# args = parser.parse_args() + +# single_docking(args.input, args.output, args.output_ligand) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/docking.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/docking.py new file mode 100644 index 0000000000000000000000000000000000000000..83816516122204c82180c7ee4826239e4a10661b --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/docking.py @@ -0,0 +1,153 @@ +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. 
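+#
+# Driver script: loads UniMol pose predictions, writes one cached pickle per pocket,
+# then runs coordinate_model.py or conformer_model.py on each pocket in a worker pool
+# and reports RMSD statistics over the docked poses.
+#
+# Usage sketch (flags mirror the argparse definitions below; paths are illustrative):
+#   python ./unimol/utils/docking.py \
+#       --reference-file ./protein_ligand_binding_pose_prediction/test.lmdb \
+#       --predict-file ./infer_pose/save_pose_test.out.pkl \
+#       --output-path ./protein_ligand_binding_pose_prediction \
+#       --optimization-model conformer \
+#       --nthreads 40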
+ +import os +import numpy as np +import pandas as pd +from multiprocessing import Pool +from tqdm import tqdm +import glob +import argparse +from docking_utils import ( + docking_data_pre, + ensemble_iterations, + print_results, + rmsd_func, +) +import warnings + +warnings.filterwarnings(action="ignore") + + +def result_log(dir_path): + ### result logging ### + output_dir = os.path.join(dir_path, "cache") + rmsd_results = [] + for path in glob.glob(os.path.join(output_dir, "*.docking.pkl")): + ( + mol, + bst_predict_coords, + holo_coords, + bst_loss, + smi, + pocket, + pocket_coords, + ) = pd.read_pickle(path) + rmsd = rmsd_func(holo_coords, bst_predict_coords, mol=mol) + rmsd_results.append(rmsd) + rmsd_results = np.array(rmsd_results) + print_results(rmsd_results) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="docking") + parser.add_argument( + "--reference-file", + type=str, + default="./protein_ligand_binding_pose_prediction/test.lmdb", + help="Location of the reference set", + ) + parser.add_argument("--nthreads", type=int, default=40, help="num of threads") + parser.add_argument( + "--predict-file", + type=str, + default="./infer_pose/save_pose_test.out.pkl", + help="Location of the prediction file", + ) + parser.add_argument( + "--output-path", + type=str, + default="./protein_ligand_binding_pose_prediction", + help="Location of the docking output path", + ) + parser.add_argument( + "--optimization-model", + type=str, + default="conformer", + help="Optimize coordinates ('coordinate') or ligand internal torsions ('conformer')", + choices=["coordinate", "conformer"], + ) + args = parser.parse_args() + + raw_data_path, predict_path, dir_path, nthreads, model_choice = ( + args.reference_file, + args.predict_file, + args.output_path, + args.nthreads, + args.optimization_model, + ) + tta_times = 10 + ( + mol_list, + smi_list, + pocket_list, + pocket_coords_list, + distance_predict_list, + holo_distance_predict_list, + holo_coords_list, + holo_center_coords_list, + ) = docking_data_pre(raw_data_path, predict_path) + iterations = ensemble_iterations( + mol_list, + smi_list, + pocket_list, + pocket_coords_list, + distance_predict_list, + holo_distance_predict_list, + holo_coords_list, + holo_center_coords_list, + tta_times=tta_times, + ) + sz = len(mol_list) // tta_times + new_pocket_list = pocket_list[::tta_times] + output_dir = os.path.join(dir_path, "cache") + os.makedirs(output_dir, exist_ok=True) + + def dump(content): + pocket = content[3] + output_name = os.path.join(output_dir, "{}.pkl".format(pocket)) + try: + os.remove(output_name) + except: + pass + pd.to_pickle(content, output_name) + return True + + # skip step if repeat + with Pool(nthreads) as pool: + for inner_output in tqdm(pool.imap_unordered(dump, iterations), total=sz): + if not inner_output: + print("fail to dump") + + def single_docking(pocket_name): + input_name = os.path.join(output_dir, "{}.pkl".format(pocket_name)) + output_name = os.path.join(output_dir, "{}.docking.pkl".format(pocket_name)) + output_ligand_name = os.path.join( + output_dir, "{}.ligand.sdf".format(pocket_name) + ) + try: + os.remove(output_name) + except: + pass + try: + os.remove(output_ligand_name) + except: + pass + + cmd = "python ./unimol/utils/{}_model.py --input {} --output {} --output-ligand {}".format( + model_choice, input_name, output_name, output_ligand_name + ) + os.system(cmd) + return True + + + with Pool(nthreads) as pool: + for inner_output in tqdm( + pool.imap_unordered(single_docking, 
new_pocket_list), total=len(new_pocket_list) + ): + if not inner_output: + print("fail to docking") + + result_log(args.output_path) diff --git a/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/docking_utils.py b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/docking_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..05d30331a202454ce7a27890482a4cf98a92923b --- /dev/null +++ b/MindChemistry/applications/Uni-Mol/unimol/unimol/utils/docking_utils.py @@ -0,0 +1,277 @@ +# Copyright (c) DP Techonology, Inc. and its affiliates. +# +# This source code is licensed under the MIT license found in the +# LICENSE file in the root directory of this source tree. + +import numpy as np +from rdkit import Chem +from rdkit.Chem import AllChem +from rdkit import RDLogger + +RDLogger.DisableLog("rdApp.*") +import warnings + +warnings.filterwarnings(action="ignore") +from rdkit.Chem import rdMolTransforms +import copy +import lmdb +import pickle +import pandas as pd +from typing import Dict, List, Optional +from unimol.utils.conf_gen_cal_metrics import clustering, single_conf_gen + + +def add_all_conformers_to_mol(mol: Chem.Mol, conformers: List[np.ndarray]) -> Chem.Mol: + mol = copy.deepcopy(mol) + mol.RemoveAllConformers() + for i, conf_pos in enumerate(conformers): + conf = Chem.Conformer(mol.GetNumAtoms()) + mol.AddConformer(conf, assignId=True) + + conf = mol.GetConformer(i) + positions = conf_pos.tolist() + for j in range(mol.GetNumAtoms()): + conf.SetAtomPosition(j, positions[j]) + return mol + + +def get_torsions(m: Chem.Mol, removeHs=True) -> List: + if removeHs: + m = Chem.RemoveHs(m) + torsionList = [] + torsionSmarts = "[!$(*#*)&!D1]-&!@[!$(*#*)&!D1]" + torsionQuery = Chem.MolFromSmarts(torsionSmarts) + matches = m.GetSubstructMatches(torsionQuery) + for match in matches: + idx2 = match[0] + idx3 = match[1] + bond = m.GetBondBetweenAtoms(idx2, idx3) + jAtom = m.GetAtomWithIdx(idx2) + kAtom = m.GetAtomWithIdx(idx3) + for b1 in jAtom.GetBonds(): + if b1.GetIdx() == bond.GetIdx(): + continue + idx1 = b1.GetOtherAtomIdx(idx2) + for b2 in kAtom.GetBonds(): + if (b2.GetIdx() == bond.GetIdx()) or (b2.GetIdx() == b1.GetIdx()): + continue + idx4 = b2.GetOtherAtomIdx(idx3) + # skip 3-membered rings + if idx4 == idx1: + continue + # skip torsions that include hydrogens + if (m.GetAtomWithIdx(idx1).GetAtomicNum() == 1) or ( + m.GetAtomWithIdx(idx4).GetAtomicNum() == 1 + ): + continue + if m.GetAtomWithIdx(idx4).IsInRing(): + torsionList.append((idx4, idx3, idx2, idx1)) + break + else: + torsionList.append((idx1, idx2, idx3, idx4)) + break + break + return torsionList + + +def load_lmdb_data(lmdb_path, key): + env = lmdb.open( + lmdb_path, + subdir=False, + readonly=True, + lock=False, + readahead=False, + meminit=False, + max_readers=256, + ) + txn = env.begin() + _keys = list(txn.cursor().iternext(values=False)) + collects = [] + for idx in range(len(_keys)): + datapoint_pickled = txn.get(f"{idx}".encode("ascii")) + data = pickle.loads(datapoint_pickled) + collects.append(data[key]) + return collects + + +def reprocess_content(content: Dict, base_mol: Optional[Chem.Mol] = None, M: int = 2000, N: int = 10, mmff: bool = False, seed: int = 42, stereo_from3d: bool = True) -> Dict: + """ Reprocess a data point in the LMDB schema for Docking usage. Ensures correct stereochemistry. + Basic principle is to perceive stereochem from label molecule's 3D and keep it intact. + Use default values for best results + + Args: + content: A dictionary of the LMDB schema. 
(atoms, holo_mol, mol_list, cooredinates, etc.) + base_mol: The molecule to replace the holo_mol with, if passed + M: The number of conformers to generate + N: The number of clusters to group conformers and pick a representative from + mmff: Whether to use MMFF minimization after conformer generation + seed: The random seed to use for conformer generation + stereo_from3d: Whether to perceive stereochemistry from the 3D coordinates of the label molecule + + Returns: + A copy of the original, with the holo_mol replaced with the base_mol, and coordinates added. + """ + if base_mol is None: + base_mol = content["holo_mol"] + # Copy so we don't change inputs + content = copy.deepcopy(content) + base_mol = copy.deepcopy(base_mol) + base_mol = Chem.AddHs(base_mol, addCoords=True) + # assign stereochem from 3d + if stereo_from3d and base_mol.GetNumConformers() > 0: + Chem.AssignStereochemistryFrom3D(base_mol) + ori_smiles = Chem.MolToSmiles(base_mol) + # create new, clean molecule + remol = Chem.MolFromSmiles(ori_smiles) + # reorder to match and add Hs + idxs = remol.GetSubstructMatches(Chem.RemoveHs(base_mol)) + if isinstance(idxs[0], tuple): + idxs = idxs[0] + idxs = list(map(int, idxs)) + remol = Chem.RenumberAtoms(remol, idxs) + remol = Chem.AddHs(remol, addCoords=True) + # overwrite - write the diverse conformer set for potential later reuse + content["coordinates"] = [x for x in clustering(remol, M=M, N=N, seed=seed, removeHs=False, mmff=mmff)] + content["mol_list"] = [ + Chem.AddHs( + copy.deepcopy(add_all_conformers_to_mol( + Chem.RemoveHs(remol), content["coordinates"] + )), addCoords=True + ) for i in range(N) + ] + content["holo_mol"] = copy.deepcopy(base_mol) + content["atoms"] = [a.GetSymbol() for a in base_mol.GetAtoms()] + return content + + +def docking_data_pre(raw_data_path, predict_path): + mol_list = load_lmdb_data(raw_data_path, "mol_list") + mol_list = [Chem.RemoveHs(mol) for items in mol_list for mol in items] + predict = pd.read_pickle(predict_path) + ( + smi_list, + pocket_list, + pocket_coords_list, + distance_predict_list, + holo_distance_predict_list, + holo_coords_list, + holo_center_coords_list, + ) = ([], [], [], [], [], [], []) + for batch in predict: + sz = batch["atoms"].size(0) + for i in range(sz): + smi_list.append(batch["smi_name"][i]) + pocket_list.append(batch["pocket_name"][i]) + + distance_predict = batch["cross_distance_predict"][i] + token_mask = batch["atoms"][i] > 2 + pocket_token_mask = batch["pocket_atoms"][i] > 2 + distance_predict = distance_predict[token_mask][:, pocket_token_mask] + pocket_coords = batch["pocket_coordinates"][i] + pocket_coords = pocket_coords[pocket_token_mask, :] + + holo_distance_predict = batch["holo_distance_predict"][i] + holo_distance_predict = holo_distance_predict[token_mask][:, token_mask] + + holo_coordinates = batch["holo_coordinates"][i] + holo_coordinates = holo_coordinates[token_mask, :] + holo_center_coordinates = batch["holo_center_coordinates"][i][:3] + + pocket_coords = pocket_coords.numpy().astype(np.float32) + distance_predict = distance_predict.numpy().astype(np.float32) + holo_distance_predict = holo_distance_predict.numpy().astype(np.float32) + # Fill diagonal with 0, issue with the model not learning to predict 0 distance + np.fill_diagonal(holo_distance_predict, 0) + # + holo_coords = holo_coordinates.numpy().astype(np.float32) + + pocket_coords_list.append(pocket_coords) + distance_predict_list.append(distance_predict) + holo_distance_predict_list.append(holo_distance_predict) + 
holo_coords_list.append(holo_coords) + holo_center_coords_list.append(holo_center_coordinates) + + return ( + mol_list, + smi_list, + pocket_list, + pocket_coords_list, + distance_predict_list, + holo_distance_predict_list, + holo_coords_list, + holo_center_coords_list, + ) + + +def ensemble_iterations( + mol_list, + smi_list, + pocket_list, + pocket_coords_list, + distance_predict_list, + holo_distance_predict_list, + holo_coords_list, + holo_center_coords_list, + tta_times=10, + seed=42, +): + sz = len(mol_list) + for i in range(sz // tta_times): + start_idx, end_idx = i * tta_times, (i + 1) * tta_times + distance_predict_tta = distance_predict_list[start_idx:end_idx] + holo_distance_predict_tta = holo_distance_predict_list[start_idx:end_idx] + + mol = copy.deepcopy(mol_list[start_idx]) + rdkit_mol = single_conf_gen(mol, num_confs=tta_times, seed=seed) + sz = len(rdkit_mol.GetConformers()) + initial_coords_list = [ + rdkit_mol.GetConformers()[i].GetPositions().astype(np.float32) + for i in range(sz) + ] + + yield [ + initial_coords_list, + mol, + smi_list[start_idx], + pocket_list[start_idx], + pocket_coords_list[start_idx], + distance_predict_tta, + holo_distance_predict_tta, + holo_coords_list[start_idx], + holo_center_coords_list[start_idx], + ] + + +def rmsd_func(holo_coords: np.ndarray, predict_coords: np.ndarray, mol: Optional[Chem.Mol] = None) -> float: + """ Symmetric RMSD for molecules. """ + if predict_coords is not np.nan: + sz = holo_coords.shape + if mol is not None: + # get stereochem-unaware permutations: (P, N) + base_perms = np.array(mol.GetSubstructMatches(mol, uniquify=False)) + # filter for valid stereochem only + chem_order = np.array(list(Chem.rdmolfiles.CanonicalRankAtoms(mol, breakTies=False))) + perms_mask = (chem_order[base_perms] == chem_order[None]).sum(-1) == mol.GetNumAtoms() + base_perms = base_perms[perms_mask] + noh_mask = np.array([a.GetAtomicNum() != 1 for a in mol.GetAtoms()]) + # (N, 3), (N, 3) -> (P, N, 3), ((), N, 3) -> (P,) -> min((P,)) + best_rmsd = np.inf + for perm in base_perms: + rmsd = np.sqrt(np.sum((predict_coords[perm[noh_mask]] - holo_coords) ** 2) / sz[-2]) + if rmsd < best_rmsd: + best_rmsd = rmsd + + rmsd = best_rmsd + else: + rmsd = np.sqrt(np.sum((predict_coords - holo_coords) ** 2) / sz[-2]) + return rmsd + return 1000.0 + + +def print_results(rmsd_results): + print("RMSD < 1.0 : ", np.mean(rmsd_results < 1.0)) + print("RMSD < 1.5 : ", np.mean(rmsd_results < 1.5)) + print("RMSD < 2.0 : ", np.mean(rmsd_results < 2.0)) + print("RMSD < 3.0 : ", np.mean(rmsd_results < 3.0)) + print("RMSD < 5.0 : ", np.mean(rmsd_results < 5.0)) + print("avg RMSD : ", np.mean(rmsd_results))
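+
+
+# Minimal usage sketch for the two helpers above (hypothetical arrays, no real data):
+#
+#   import numpy as np
+#   holo = np.zeros((10, 3), dtype=np.float32)   # reference ligand coordinates
+#   pred = holo + 0.5                            # a hypothetical predicted pose
+#   rmsds = np.array([rmsd_func(holo, pred)])    # plain RMSD when mol is None,
+#                                                # symmetry-aware when an RDKit mol is passed
+#   print_results(rmsds)                         # fraction of poses under 1.0/1.5/2.0/3.0/5.0 A and the mean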