From 0776b28503ba4c3a6a5d14941e4453593bf0d21a Mon Sep 17 00:00:00 2001 From: ixsluo Date: Wed, 4 Dec 2024 15:26:24 +0800 Subject: [PATCH 1/2] feat: add flow model Signed-off-by: ixsluo Signed-off-by: wangqingchang2024 --- .../applications/crystalflow/.gitignore | 11 + .../applications/crystalflow/README.md | 133 ++++++ .../crystalflow/compute_metric.py | 327 +++++++++++++++ .../applications/crystalflow/config.yaml | 57 +++ .../crystalflow/data/crysloader.py | 215 ++++++++++ .../crystalflow/data/data_utils.py | 380 ++++++++++++++++++ .../applications/crystalflow/data/dataset.py | 139 +++++++ .../applications/crystalflow/evaluate.py | 175 ++++++++ .../crystalflow/models/conditioning.py | 34 ++ .../applications/crystalflow/models/cspnet.py | 278 +++++++++++++ .../crystalflow/models/cspnet_condition.py | 299 ++++++++++++++ .../crystalflow/models/diff_utils.py | 175 ++++++++ .../applications/crystalflow/models/flow.py | 284 +++++++++++++ .../crystalflow/models/flow_condition.py | 295 ++++++++++++++ .../crystalflow/models/infer_utils.py | 95 +++++ .../crystalflow/models/lattice.py | 61 +++ .../crystalflow/models/train_utils.py | 117 ++++++ .../applications/crystalflow/requirement.txt | 7 + .../crystalflow/test_crystalflow.py | 191 +++++++++ .../applications/crystalflow/train.py | 218 ++++++++++ .../crystalflow/train_pressure.py | 231 +++++++++++ 21 files changed, 3722 insertions(+) create mode 100644 MindChemistry/applications/crystalflow/.gitignore create mode 100644 MindChemistry/applications/crystalflow/README.md create mode 100644 MindChemistry/applications/crystalflow/compute_metric.py create mode 100644 MindChemistry/applications/crystalflow/config.yaml create mode 100644 MindChemistry/applications/crystalflow/data/crysloader.py create mode 100644 MindChemistry/applications/crystalflow/data/data_utils.py create mode 100644 MindChemistry/applications/crystalflow/data/dataset.py create mode 100644 MindChemistry/applications/crystalflow/evaluate.py create mode 100644 MindChemistry/applications/crystalflow/models/conditioning.py create mode 100644 MindChemistry/applications/crystalflow/models/cspnet.py create mode 100644 MindChemistry/applications/crystalflow/models/cspnet_condition.py create mode 100644 MindChemistry/applications/crystalflow/models/diff_utils.py create mode 100644 MindChemistry/applications/crystalflow/models/flow.py create mode 100644 MindChemistry/applications/crystalflow/models/flow_condition.py create mode 100644 MindChemistry/applications/crystalflow/models/infer_utils.py create mode 100644 MindChemistry/applications/crystalflow/models/lattice.py create mode 100644 MindChemistry/applications/crystalflow/models/train_utils.py create mode 100644 MindChemistry/applications/crystalflow/requirement.txt create mode 100644 MindChemistry/applications/crystalflow/test_crystalflow.py create mode 100644 MindChemistry/applications/crystalflow/train.py create mode 100644 MindChemistry/applications/crystalflow/train_pressure.py diff --git a/MindChemistry/applications/crystalflow/.gitignore b/MindChemistry/applications/crystalflow/.gitignore new file mode 100644 index 000000000..208f56287 --- /dev/null +++ b/MindChemistry/applications/crystalflow/.gitignore @@ -0,0 +1,11 @@ +dataset/ +*.log +*.npy +ckpt/ +dataset.zip +rank_0/ +test_mind_cspnet.py +torch2ms_ckpt/ +*.ipynb +*.ckpt +ignore/ diff --git a/MindChemistry/applications/crystalflow/README.md b/MindChemistry/applications/crystalflow/README.md new file mode 100644 index 000000000..bead72bb0 --- /dev/null +++ 
b/MindChemistry/applications/crystalflow/README.md
@@ -0,0 +1,133 @@
+
+# Model Name
+
+> CrystalFlow
+
+## Introduction
+
+> Theoretical crystal structure prediction is an important computational means of finding the most stable structure of a material under given external conditions. Traditional structure-prediction methods rely on extensive random sampling on the potential energy surface to locate the most stable structure; however, every randomly generated structure must then be locally optimized, and local optimization usually carries an enormous first-principles computation cost that grows sharply when simulating complex multi-element systems. In recent years, crystal-structure generation based on deep generative models has attracted growing attention because it samples plausible structures on the potential energy surface more efficiently. By learning from existing stable or locally stable structures, such models generate reasonable crystal structures and, compared with random sampling, both reduce the cost of local optimization and reach the most stable structure of a system with fewer samples. Normalizing-flow models, which model the probability density with neural ordinary differential equations and continuous transformations, are simpler, more flexible, and more efficient than diffusion-based generative models. Building on the flow-model architecture, this work develops a crystal-structure generation model named CrystalFlow, which achieves excellent performance on benchmark datasets such as MP-20.
+
+## Requirements
+
+> 1. Install `mindspore (2.5.0)`
+> 2. Install the dependencies: `pip install -r requirement.txt`
+
+## Quick Start
+
+> 1. Download the Mindchemistry/mindchemistry package into the current directory
+> 2. Download the corresponding datasets from the [dataset link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)
+> 3. Install the dependencies: `pip install -r requirement.txt`
+> 4. Training command: `python train.py`
+> 5. Prediction command: `python evaluate.py`
+> 6. Evaluation command: `python compute_metric.py`
+> 7. The evaluation results are written to a JSON file under the `metric_dir` path specified in `config.yaml`
+
+### Code Structure
+
+```text
+The main modules live in the models folder, where cspnet.py implements the network layers and flow.py the flow-model module. The data folder holds the dataset-processing modules.
+
+applications
+    └── crystalflow                 # model name
+        ├── readme.md               # readme file
+        ├── config.yaml             # configuration file
+        ├── train.py                # training entry script
+        ├── evaluate.py             # inference entry script
+        ├── compute_metric.py       # evaluation entry script
+        ├── requirement.txt         # dependencies
+        ├── data                    # data-processing modules
+        |     ├── data_utils.py     # utilities
+        |     ├── dataset.py        # dataset construction
+        |     └── crysloader.py     # dataloader construction
+        └── models
+              ├── conditioning.py       # utilities for conditional generation
+              ├── cspnet.py             # GNN-based denoiser module
+              ├── cspnet_condition.py   # network layers for conditional generation
+              ├── diff_utils.py         # utilities
+              ├── flow.py               # flow-model module
+              ├── flow_condition.py     # flow model for conditional generation
+              ├── infer_utils.py        # inference utilities
+              ├── lattice.py            # lattice-matrix utilities
+              └── train_utils.py        # training utilities
+
+```
+
+## Download the Datasets
+
+Download the corresponding dataset folders and the dataset property file dataset_prop.txt from the [dataset link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/) and place them in the dataset folder under the current path (create the folder manually if it does not exist). Reference layout:
+
+```txt
+crystalflow
+  ...
+  └─dataset
+      perov_5            perovskite dataset
+      carbon_24          carbon crystal dataset
+      mp_20              MP dataset with at most 20 atoms per unit cell
+      mpts_52            MP dataset with at most 52 atoms per unit cell
+      dataset_prop.txt   dataset property file
+  ...
+```
+
+## Training Process
+
+### Training
+
+Download the Mindchemistry/mindchemistry package into the current directory;
+
+Edit the config file to set the training parameters:
+> 1. Set the training dataset; see the dataset field
+> 2. Configure the denoiser model; see the model field
+> 3. Set where the trained weights are saved by changing the train.ckpt_dir folder name and the checkpoint.last_path weight-file name
+> 4. See the train field for the other training settings
+
+```bash
+pip install -r requirement.txt
+python train.py
+```
+
+### Inference
+
+Write the path of the weights into checkpoint.last_path in the config file. Pretrained models are available from the [pretrained model link](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/pre-train).
+
+Edit the test field in the config file to change the inference parameters, in particular test.num_eval, which **determines how many samples are generated for each composition** and matters for the subsequent evaluation stage.
+
+```bash
+python evaluate.py
+```
+
+The crystals obtained by inference are saved in the file specified by test.eval_save_path.
+
+The file stores a Python dictionary with the following layout:
+
+```python
+{
+    'pred': [
+        [crystal A sample 1, crystal A sample 2, crystal A sample 3, ... crystal A sample num_eval],
+        [crystal B sample 1, crystal B sample 2, crystal B sample 3, ... crystal B sample num_eval]
+        ...
+    ],
+    'gt': [
+        crystal A ground truth,
+        crystal B ground truth,
+        ...
+    ]
+}
+```
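+
+Each sample entry is a dict of numpy arrays with the keys 'atom_types', 'frac_coords', 'lengths' and 'angles'. A minimal sketch of loading the file for inspection (assuming the default eval_save_path from config.yaml):
+
+```python
+import pickle
+
+# Load the inference results written by evaluate.py
+with open('./ckpt/mp_20/predict_crys.pkl', 'rb') as f:
+    eval_dict = pickle.load(f)
+
+pred, gt = eval_dict['pred'], eval_dict['gt']
+# pred[i] holds num_eval sampled structures for composition i; gt[i] is its ground truth
+print(len(pred), len(pred[0]), gt[0]['lengths'])
+```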
+
+### Evaluation
+
+Write the path of the crystal file produced by inference into test.eval_save_path in the config file;
+
+Make sure num_eval is no larger than the number of samples generated per composition during inference. For example, if num_eval was set to 1 for inference, it can only be set to 1 for evaluation; if it was set to 20 for inference, any value from 1 to 20 can be used for evaluation.
+
+Edit the test.metric_dir field in the config file to set where the evaluation results are saved
+
+```bash
+python compute_metric.py
+```
+
+Example of the resulting evaluation metrics file:
+
+```json
+{"match_rate": 0.6107671899181959, "rms_dist": 0.07492558322002925}
+```
diff --git a/MindChemistry/applications/crystalflow/compute_metric.py b/MindChemistry/applications/crystalflow/compute_metric.py
new file mode 100644
index 000000000..a8b92c65b
--- /dev/null
+++ b/MindChemistry/applications/crystalflow/compute_metric.py
@@ -0,0 +1,327 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""compute metric file"""
+import itertools
+import json
+import os
+import pickle
+from collections import Counter
+from pathlib import Path
+import argparse
+import yaml
+
+import numpy as np
+from matminer.featurizers.composition.composite import ElementProperty
+from matminer.featurizers.site.fingerprint import CrystalNNFingerprint
+from p_tqdm import p_map
+from pymatgen.analysis.structure_matcher import StructureMatcher
+from pymatgen.core.composition import Composition
+from pymatgen.core.lattice import Lattice
+from pymatgen.core.structure import Structure
+import smact
+from smact.screening import pauling_test
+from tqdm import trange
+
+from models.infer_utils import chemical_symbols
+
+matcher = StructureMatcher(stol=0.5, angle_tol=10, ltol=0.3)
+crystalnn_fp = CrystalNNFingerprint.from_preset("ops")
+comp_fp = ElementProperty.from_preset('magpie')
+
+
+def smact_validity(comp, count, use_pauling_test=True, include_alloys=True):
+    """Smact validity. See details in the paper Crystal Diffusion Variational Autoencoder and
+    its codebase.
+    """
+    elem_symbols = tuple([chemical_symbols[elem] for elem in comp])
+    space = smact.element_dictionary(elem_symbols)
+    smact_elems = [e[1] for e in space.items()]
+    electronegs = [e.pauling_eneg for e in smact_elems]
+    ox_combos = [e.oxidation_states for e in smact_elems]
+    if len(set(elem_symbols)) == 1:
+        return True
+    if include_alloys:
+        is_metal_list = [elem_s in smact.metals for elem_s in elem_symbols]
+        if all(is_metal_list):
+            return True
+
+    threshold = np.max(count)
+    oxn = 1
+    for oxc in ox_combos:
+        oxn *= len(oxc)
+    if oxn > 1e7:
+        return False
+    for ox_states in itertools.product(*ox_combos):
+        stoichs = [(c,) for c in count]
+        # Test for charge balance
+        cn_e, _ = smact.neutral_ratios(ox_states,
+                                       stoichs=stoichs,
+                                       threshold=threshold)
+        # Electronegativity test
+        if cn_e:
+            if use_pauling_test:
+                try:
+                    electroneg_ok = pauling_test(ox_states, electronegs)
+                except TypeError:
+                    # if no electronegativity data, assume it is okay
+                    electroneg_ok = True
+            else:
+                electroneg_ok = True
+            if electroneg_ok:
+                return True
+    return False
+
+
+def structure_validity(crystal, cutoff=0.5):
+    """Structure validity. See details in the paper Crystal Diffusion
+    Variational Autoencoder and its codebase.
+    """
+    dist_mat = crystal.distance_matrix
+    # Pad diagonal with a large number
+    dist_mat = dist_mat + np.diag(np.ones(dist_mat.shape[0]) * (cutoff + 10.))
+    if dist_mat.min() < cutoff or crystal.volume < 0.1 or max(
+            crystal.lattice.abc) > 40:
+        return False
+
+    return True
+
+class Crystal:
+    """Strict crystal validity. See details in the paper CDVAE `Crystal
+    Diffusion Variational Autoencoder` and
+    its codebase. We adopt the same evaluation metric criteria as CDVAE.
+    """
+
+    def __init__(self, crys_array_dict):
+        self.frac_coords = crys_array_dict['frac_coords']
+        self.atom_types = crys_array_dict['atom_types']
+        self.lengths = crys_array_dict['lengths']
+        self.angles = crys_array_dict['angles']
+        self.dict = crys_array_dict
+        if len(self.atom_types.shape) > 1:
+            self.dict['atom_types'] = np.argmax(self.atom_types, axis=-1) + 1
+            self.atom_types = np.argmax(self.atom_types, axis=-1) + 1
+
+        self.get_structure()
+        self.get_composition()
+        self.get_validity()
+        self.get_fingerprints()
+
+    def get_structure(self):
+        """get_structure
+        """
+        if min(self.lengths.tolist()) < 0:
+            self.constructed = False
+            self.invalid_reason = 'non_positive_lattice'
+        if np.isnan(self.lengths).any() or np.isnan(
+                self.angles).any() or np.isnan(self.frac_coords).any():
+            self.constructed = False
+            self.invalid_reason = 'nan_value'
+        else:
+            try:
+                self.structure = Structure(lattice=Lattice.from_parameters(
+                    *(self.lengths.tolist() + self.angles.tolist())),
+                                           species=self.atom_types,
+                                           coords=self.frac_coords,
+                                           coords_are_cartesian=False)
+                self.constructed = True
+            # pylint: disable=W0703
+            except Exception:
+                self.constructed = False
+                self.invalid_reason = 'construction_raises_exception'
+            if self.structure.volume < 0.1:
+                self.constructed = False
+                self.invalid_reason = 'unrealistically_small_lattice'
+
+    def get_composition(self):
+        """get_composition
+        """
+        elem_counter = Counter(self.atom_types)
+        composition = [(elem, elem_counter[elem])
+                       for elem in sorted(elem_counter.keys())]
+        elems, counts = list(zip(*composition))
+        counts = np.array(counts)
+        counts = counts / np.gcd.reduce(counts)
+        self.elems = elems
+        self.comps = tuple(counts.astype('int').tolist())
+
+    def get_validity(self):
+        """get_validity
+        """
+        self.comp_valid = smact_validity(self.elems, self.comps)
+        if self.constructed:
+            self.struct_valid = structure_validity(self.structure)
+        else:
+            self.struct_valid = False
+        self.valid = self.comp_valid and self.struct_valid
+
+    def get_fingerprints(self):
+        """get_fingerprints
+        """
+        elem_counter = Counter(self.atom_types)
+        comp = Composition(elem_counter)
+        self.comp_fp = comp_fp.featurize(comp)
+        try:
+            site_fps = [
+                crystalnn_fp.featurize(self.structure, i)
+                for i in range(len(self.structure))
+            ]
+        # pylint: disable=W0703
+        except Exception:
+            # counts crystal as invalid if fingerprint cannot be constructed.
+            self.valid = False
+            self.comp_fp = None
+            self.struct_fp = None
+            return
+        self.struct_fp = np.array(site_fps).mean(axis=0)
+
+
+def get_rms(pred_struc_list, gt_struc: Structure, num_eval, np_list):
+    """Calculate the rms distance between the ground truth and predicted crystal structures.
+
+    Args:
+        pred_struc_list (List[Structure]): The crystals generated by the diffusion model
+            in the form of Structure.
+        gt_struc (Structure): The ground truth crystal.
+        num_eval (int): Specify that the first N items in the predicted List of crystal structures
+            participate in the evaluation.
+ np_list (List[Dict]): The crystals generated by diffution model in the form of Dict. + """ + + def process_one(pred_struc: Structure): + try: + if not pred_struc.is_valid(): + return None + rms_dist = matcher.get_rms_dist(pred_struc, gt_struc) + rms_dist = None if rms_dist is None else rms_dist[0] + tune_rms = rms_dist + # pylint: disable=W0703 + except Exception: + tune_rms = None + return tune_rms + + min_rms = None + min_struc = None + for i, struct in enumerate(pred_struc_list): + if i == num_eval: + break + rms = process_one(struct) + if rms is not None and (min_rms is None or min_rms > rms): + min_rms = rms + min_struc = np_list[i] + return min_rms, min_struc + + +def get_struc_from_np_list(np_list): + """convert the crystal in the form of Dict to pymatgen.Structure + """ + result = [] + for cry_array in np_list: + try: + struct = Structure(lattice=Lattice.from_parameters( + *(cry_array['lengths'].tolist() + + cry_array['angles'].tolist())), + species=cry_array['atom_types'], + coords=cry_array['frac_coords'], + coords_are_cartesian=False) + # pylint: disable=W0703 + except Exception: + print('Warning: One anomalous crystal structure has captured and removed. ') + struct = None + + result.append(struct) + return result + +def main(args): + """main + """ + with open(args.config, 'r') as stream: + config = yaml.safe_load(stream) + + eval_file = config['test']['eval_save_path'] + num_eval = config['test']['num_eval'] + output_path = config['test']['metric_dir'] + + with open(eval_file, 'rb') as f: + eval_dict = pickle.load(f) + + pred_list = eval_dict['pred'] + gt_list = eval_dict['gt'] + gt_list = get_struc_from_np_list(gt_list) + rms = [] + + # calculate rmsd + for i in trange(len(gt_list)): + pred_struc = get_struc_from_np_list(pred_list[i]) + gt_struc = gt_list[i] + rms_single, struc_single = get_rms(pred_struc, gt_struc, num_eval, + pred_list[i]) + rms.append((rms_single, struc_single)) + + rms, struc_list = zip(*rms) + + # Remove the ones with RMSD as None, and store the valid structures in the list valid_crys. + rms_np = [] + valid_crys = [] + for i, rms_per in enumerate(rms): + if rms_per is not None: + rms_np.append(rms_per) + valid_crys.append(struc_list[i]) + + # Conduct rigorous structural verification, specifically through verification using the Crystal class. + print('Using the Crystal class for validity checks') + valid_list = p_map(lambda x: Crystal(x).valid, valid_crys) + rms_np_strict = [] + for i, is_valid in enumerate(valid_list): + if is_valid: + rms_np_strict.append(rms_np[i]) + + rms_np = np.array(rms_np_strict) + rms_valid_index = np.array([x is not None for x in rms_np_strict]) + + match_rate = rms_valid_index.sum() / len(gt_list) + rms = rms_np[rms_valid_index].mean() + + print('match_rate: ', match_rate) + print('rms: ', rms) + + all_metrics = {'match_rate': match_rate, 'rms_dist': rms} + + if Path(output_path).exists(): + metrics_out_file = f'eval_metrics_{num_eval}.json' + metrics_out_file = os.path.join(output_path, metrics_out_file) + + # only overwrite metrics computed in the new run. 
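+        # written_metrics is only assigned when the metrics file already
+        # exists, but it is inspected again further below; initialize it so
+        # the file-missing branch does not raise a NameError.
+        written_metrics = None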
+ if Path(metrics_out_file).exists(): + with open(metrics_out_file, 'r') as f: + written_metrics = json.load(f) + if isinstance(written_metrics, dict): + written_metrics.update(all_metrics) + else: + with open(metrics_out_file, 'w') as f: + json.dump(all_metrics, f) + if isinstance(written_metrics, dict): + with open(metrics_out_file, 'w') as f: + json.dump(written_metrics, f) + else: + with open(metrics_out_file, 'w') as f: + json.dump(all_metrics, f) + else: + print('Warning: The metric result file path is not specified') + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='config.yaml') + main_args = parser.parse_args() + main(main_args) diff --git a/MindChemistry/applications/crystalflow/config.yaml b/MindChemistry/applications/crystalflow/config.yaml new file mode 100644 index 000000000..514db4601 --- /dev/null +++ b/MindChemistry/applications/crystalflow/config.yaml @@ -0,0 +1,57 @@ +dataset: + data_name: 'mp_20' + train: + path: './dataset/mp_20/train.csv' + save_path: './dataset/mp_20/train.npy' + val: + path: './dataset/mp_20/val.csv' + save_path: './dataset/mp_20/val.npy' + test: + path: './dataset/mp_20/test.csv' + save_path: './dataset/mp_20/test.npy' + +model: + # For dataset carbon, mp, mpts + hidden_dim: 512 + num_layers: 6 + num_freqs: 256 + # # For dataset perov + # hidden_dim: 256 + # num_layers: 4 + # num_freqs: 10 + conditions: + pressure: + _target_: crystalflow.train_pressure + start: -2 + stop: 2 + n_out: 128 + +train: + ckpt_dir: "./ckpt/mp_20" + # 3500, 4000, 1000, 1000 epochs for Perov-5, Carbon-24, MP-20 and MPTS-52 respectively. + epoch_size: 3000 + # 512, 512, 128, 128 for Perov-5, Carbon-24, MP-20 and MPTS-52 respectively. + batch_size: 256 + seed: 1234 + cost_lattice: 1 + cost_coord: 10 + +checkpoint: + last_path: "./ckpt/mp_20/last_test.ckpt" + +test: + # 1024 for perov, 512 for carbon and mp, 256 for mpts + batch_size: 512 + num_eval: 1 + # 1e-5 for mp and mpts, 5e-7 for perov, 5e-6 for carbon num_eval=1 and 5e-7 for carbon num_eval=20 + step_lr: 1e-5 + eval_save_path: './ckpt/mp_20/predict_crys.pkl' + metric_dir: './ckpt/mp_20/' + + + + + + + + diff --git a/MindChemistry/applications/crystalflow/data/crysloader.py b/MindChemistry/applications/crystalflow/data/crysloader.py new file mode 100644 index 000000000..dccb4ddbc --- /dev/null +++ b/MindChemistry/applications/crystalflow/data/crysloader.py @@ -0,0 +1,215 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""dataloader file""" +import numpy as np +from mindspore import Tensor, ops +import mindspore as ms + +from mindchemistry.graph.dataloader import DataLoaderBase, CommonData + + +class Crysloader(DataLoaderBase): + """ + Crysloader is used to stacks a batch of graph data to fixed-size Tensors. 
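+    Samples are padded to the global maximum node and edge counts and returned
+    together with node/edge masks, so every step yields Tensors of fixed shape.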
+ + Exactly the same code logic as DataLoaderBase, with additional node attribute 'frac_coords' + and graph attributes, 'lengths' and 'angles' + """ + + def __init__(self, + batch_size, + node_attr=None, + frac_coords=None, + edge_attr=None, + edge_index=None, + lengths=None, + angles=None, + lattice_polar=None, + num_atoms=None, + label=None, + padding_std_ratio=3.5, + dynamic_batch_size=True, + shuffle_dataset=True, + max_node=None, + max_edge=None): + self.batch_size = batch_size + self.edge_index = edge_index + self.index = 0 + self.step = 0 + self.padding_std_ratio = padding_std_ratio + self.batch_change_num = 0 + self.batch_exceeding_num = 0 + self.dynamic_batch_size = dynamic_batch_size + self.shuffle_dataset = shuffle_dataset + + ## can be customized to specific dataset + self.label = label + self.node_attr = node_attr + self.frac_coords = frac_coords + self.edge_attr = edge_attr + self.lengths = lengths + self.angles = angles + self.lattice_polar = lattice_polar + self.num_atoms = num_atoms + self.sample_num = len(self.node_attr) + batch_size_div = self.batch_size + if batch_size_div != 0: + self.step_num = int(self.sample_num / batch_size_div) + else: + print('The batch size cannot be set to 0') + raise ValueError + + if dynamic_batch_size: + self.max_start_sample = self.sample_num + else: + self.max_start_sample = self.sample_num - self.batch_size + 1 + + self.set_global_max_node_edge_num(self.node_attr, self.edge_attr, + max_node, max_edge, shuffle_dataset, + dynamic_batch_size) + + def __iter__(self): + if self.shuffle_dataset: + self.shuffle() + else: + self.restart() + + while self.index < self.max_start_sample: + edge_index_step, node_batch_step, node_mask, edge_mask, \ + batch_size_mask, node_num, _, batch_size \ + = self.gen_common_data(self.node_attr, self.edge_attr) + + node_attr_step = self.gen_node_attr(self.node_attr, batch_size, + node_num) + node_attr_step = ops.reshape(node_attr_step, (-1,)) + node_attr_step = ops.Cast()(node_attr_step, ms.int32) + frac_coords_step = self.gen_node_attr(self.frac_coords, batch_size, + node_num) + label_step = self.gen_global_attr(self.label, batch_size) + lengths_step = self.gen_global_attr(self.lengths, batch_size) + angles_step = self.gen_global_attr(self.angles, batch_size) + lattice_polar_step = self.gen_global_attr(self.lattice_polar, batch_size) + num_atoms_step = self.gen_global_attr(self.num_atoms, batch_size).to(ms.int32) + + self.add_step_index(batch_size) + + ### make number to Tensor, if it is used as a Tensor in the network + node_num = Tensor(node_num, ms.int32) + batch_size = Tensor(batch_size, ms.int32) + + yield node_attr_step, frac_coords_step, label_step, lengths_step, \ + angles_step, lattice_polar_step, num_atoms_step, edge_index_step, node_batch_step, \ + node_mask, edge_mask, batch_size_mask, node_num, batch_size + + def shuffle_action(self): + """shuffle_action""" + indices = self.shuffle_index() + self.edge_index = [self.edge_index[i] for i in indices] + self.label = [self.label[i] for i in indices] + self.node_attr = [self.node_attr[i] for i in indices] + self.frac_coords = [self.frac_coords[i] for i in indices] + self.edge_attr = [self.edge_attr[i] for i in indices] + self.lengths = [self.lengths[i] for i in indices] + self.angles = [self.angles[i] for i in indices] + self.lattice_polar = [self.lattice_polar[i] for i in indices] + self.num_atoms = [self.num_atoms[i] for i in indices] + + def gen_common_data(self, node_attr, edge_attr): + """gen_common_data + + Args: + node_attr: node_attr, i.e. 
atom types + edge_attr: edge_attr + + Returns: + common_data + """ + if self.dynamic_batch_size: + if self.step >= self.step_num: + batch_size = self.get_batch_size( + node_attr, edge_attr, + min((self.sample_num - self.index), self.batch_size)) + else: + batch_size = self.get_batch_size(node_attr, edge_attr, + self.batch_size) + else: + batch_size = self.batch_size + + ######################## node_batch + node_batch_step = [] + sample_num = 0 + for i in range(self.index, self.index + batch_size): + node_batch_step.extend([sample_num] * node_attr[i].shape[0]) + sample_num += 1 + node_batch_step = np.array(node_batch_step) + node_num = node_batch_step.shape[0] + + ######################## edge_index + edge_index_step = np.array([[], []], dtype=np.int64) + max_edge_index = 0 + for i in range(self.index, self.index + batch_size): + edge_index_step = np.concatenate( + (edge_index_step, self.edge_index[i] + max_edge_index), 1) + max_edge_index = np.max(edge_index_step) + 1 + edge_num = edge_index_step.shape[1] + + ######################### padding + edge_index_step = self.pad_zero_to_end( + edge_index_step, 1, self.max_edge_num_global - edge_num) + node_batch_step = self.pad_zero_to_end( + node_batch_step, 0, self.max_node_num_global - node_num) + + ######################### mask + node_mask = self.gen_mask(self.max_node_num_global, node_num) + edge_mask = self.gen_mask(self.max_edge_num_global, edge_num) + batch_size_mask = self.gen_mask(self.batch_size, batch_size) + + ######################### make Tensor + edge_index_step = Tensor(edge_index_step, ms.int32) + node_batch_step = Tensor(node_batch_step, ms.int32) + node_mask = Tensor(node_mask, ms.int32) + edge_mask = Tensor(edge_mask, ms.int32) + batch_size_mask = Tensor(batch_size_mask, ms.int32) + + return CommonData(edge_index_step, node_batch_step, node_mask, + edge_mask, batch_size_mask, node_num, edge_num, + batch_size).get_tuple_data() + + def gen_node_attr(self, node_attr, batch_size, node_num): + """gen_node_attr""" + node_attr_step = np.concatenate( + node_attr[self.index:self.index + batch_size], 0) + node_attr_step = self.pad_zero_to_end( + node_attr_step, 0, self.max_node_num_global - node_num) + node_attr_step = Tensor(node_attr_step, ms.float32) + return node_attr_step + + def gen_edge_attr(self, edge_attr, batch_size, edge_num): + """gen_edge_attr""" + edge_attr_step = np.concatenate( + edge_attr[self.index:self.index + batch_size], 0) + edge_attr_step = self.pad_zero_to_end( + edge_attr_step, 0, self.max_edge_num_global - edge_num) + edge_attr_step = Tensor(edge_attr_step, ms.float32) + return edge_attr_step + + def gen_global_attr(self, global_attr, batch_size): + """gen_global_attr""" + global_attr_step = np.stack( + global_attr[self.index:self.index + batch_size], 0) + global_attr_step = self.pad_zero_to_end(global_attr_step, 0, + self.batch_size - batch_size) + global_attr_step = Tensor(global_attr_step, ms.float32) + return global_attr_step diff --git a/MindChemistry/applications/crystalflow/data/data_utils.py b/MindChemistry/applications/crystalflow/data/data_utils.py new file mode 100644 index 000000000..700c032a1 --- /dev/null +++ b/MindChemistry/applications/crystalflow/data/data_utils.py @@ -0,0 +1,380 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""data utils file""" +import numpy as np +import pandas as pd +import scipy + +import mindspore as ms + +from pymatgen.core.structure import Structure +from pymatgen.core.lattice import Lattice +from pymatgen.analysis.graphs import StructureGraph +from pymatgen.analysis import local_env + +from p_tqdm import p_umap + +# Tensor of unit cells. Assumes 27 cells in -1, 0, 1 offsets in the x and y dimensions +# Note that differing from OCP, we have 27 offsets here because we are in 3D +OFFSET_LIST = [ + [-1, -1, -1], + [-1, -1, 0], + [-1, -1, 1], + [-1, 0, -1], + [-1, 0, 0], + [-1, 0, 1], + [-1, 1, -1], + [-1, 1, 0], + [-1, 1, 1], + [0, -1, -1], + [0, -1, 0], + [0, -1, 1], + [0, 0, -1], + [0, 0, 0], + [0, 0, 1], + [0, 1, -1], + [0, 1, 0], + [0, 1, 1], + [1, -1, -1], + [1, -1, 0], + [1, -1, 1], + [1, 0, -1], + [1, 0, 0], + [1, 0, 1], + [1, 1, -1], + [1, 1, 0], + [1, 1, 1], +] + +EPSILON = 1e-5 + +chemical_symbols = [ + # 0 + 'X', + # 1 + 'H', 'He', + # 2 + 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', + # 3 + 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', + # 4 + 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', + 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', + # 5 + 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', + 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', + # 6 + 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', + 'Ho', 'Er', 'Tm', 'Yb', 'Lu', + 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', + 'Po', 'At', 'Rn', + # 7 + 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', + 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', + 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', + 'Lv', 'Ts', 'Og'] + +def lattice_params_to_matrix(a, b, c, alpha, beta, gamma): + r"""Converts lattice from abc, angles to matrix. + https://github.com/materialsproject/pymatgen/blob/b789d74639aa851d7e5ee427a765d9fd5a8d1079/pymatgen/core/lattice.py#L311 + """ + angles_r = np.radians([alpha, beta, gamma]) + cos_alpha, cos_beta, cos_gamma = np.cos(angles_r) + sin_alpha, sin_beta, _ = np.sin(angles_r) + + val = (cos_alpha * cos_beta - cos_gamma) / (sin_alpha * sin_beta) + # Sometimes rounding errors result in values slightly > 1. 
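+    # abs_cap (defined below) clamps the cosine into [-1, 1] so that the
+    # np.arccos call that follows stays well defined.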
+ val = abs_cap(val) + gamma_star = np.arccos(val) + + vector_a = [a * sin_beta, 0.0, a * cos_beta] + vector_b = [ + -b * sin_alpha * np.cos(gamma_star), + b * sin_alpha * np.sin(gamma_star), + b * cos_alpha, + ] + vector_c = [0.0, 0.0, float(c)] + return np.array([vector_a, vector_b, vector_c]) + +def lattice_polar_decompose(lattice: np.ndarray): + """decompose the lattice to lattice_polar""" + assert lattice.ndim == 2 + a, u = np.linalg.eigh(lattice @ lattice.T) + a, u = np.real(a), np.real(u) + a = np.diag(np.log(a)) / 2 + s = u @ a @ u.T + + k = np.array( + [ + s[0, 1], + s[0, 2], + s[1, 2], + (s[0, 0] - s[1, 1]) / 2, + (s[0, 0] + s[1, 1] - 2 * s[2, 2]) / 6, + (s[0, 0] + s[1, 1] + s[2, 2]) / 3, + ] + ) + return k + +def lattice_polar_build(k: np.ndarray): + """build lattice using lattice_polar""" + assert k.ndim == 1 + s = np.array( + [ + [k[3] + k[4] + k[5], k[0], k[1]], + [k[0], -k[3] + k[4] + k[5], k[2]], + [k[1], k[2], -2 * k[4] + k[5]], + ] + ) # (3, 3) + exp_s = scipy.linalg.expm(s) # (3, 3) + return exp_s + +class StandardScalerMS: + """Normalizes the targets of a dataset.""" + + def __init__(self, means=None, stds=None): + self.means = ms.Tensor(means, dtype=ms.float32) + self.stds = ms.Tensor(stds, dtype=ms.float32) + + def transform(self, x): + x = ms.Tensor(x, dtype=ms.float32) + return (x - self.means) / self.stds + + def inverse_transform(self, x): + x = ms.Tensor(x, dtype=ms.float32) + return x * self.stds + self.means + + def copy(self): + return StandardScalerMS( + means=self.means.copy(), + stds=self.stds.copy()) + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"means: {self.means}, " + f"stds: {self.stds})" + ) + +crystalnn = local_env.CrystalNN( + distance_cutoffs=None, x_diff_weight=-1, porous_adjustment=False) + +def build_crystal(crystal_str, niggli=True, primitive=False): + """Build crystal from cif string.""" + crystal = Structure.from_str(crystal_str, fmt='cif') + + if primitive: + crystal = crystal.get_primitive_structure() + + if niggli: + crystal = crystal.get_reduced_structure() + + canonical_crystal = Structure( + lattice=Lattice.from_parameters(*crystal.lattice.parameters), + species=crystal.species, + coords=crystal.frac_coords, + coords_are_cartesian=False, + ) + return canonical_crystal + +def build_crystal_graph(crystal, graph_method='crystalnn'): + """build crystal graph especially for edge data from Structure of Pymatgen. 
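+    Edges follow the chosen neighbor strategy and are stored in both
+    directions together with their periodic image offsets (to_jimages).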
+ Convert them to numpy arrays.""" + if graph_method == 'crystalnn': + crystal_graph = StructureGraph.with_local_env_strategy( + crystal, crystalnn) + elif graph_method == 'none': + pass + else: + raise NotImplementedError + + frac_coords = crystal.frac_coords + atom_types = crystal.atomic_numbers + lattice_parameters = crystal.lattice.parameters + lengths = lattice_parameters[:3] + angles = lattice_parameters[3:] + lattice_polar = lattice_polar_decompose(crystal.lattice.matrix) + + edge_indices, to_jimages = [], [] + + if graph_method != 'none': + for i, j, to_jimage in crystal_graph.graph.edges(data='to_jimage'): + edge_indices.append([j, i]) + to_jimages.append(to_jimage) + edge_indices.append([i, j]) + to_jimages.append(tuple(-tj for tj in to_jimage)) + + atom_types = np.array(atom_types) + lengths, angles = np.array(lengths), np.array(angles) + edge_indices = np.array(edge_indices) + to_jimages = np.array(to_jimages) + num_atoms = atom_types.shape[0] + + return frac_coords, atom_types, lengths, angles, lattice_polar, edge_indices, to_jimages, num_atoms + +def abs_cap(val, max_abs_val=1): + """ + Returns the value with its absolute value capped at max_abs_val. + Particularly useful in passing values to trignometric functions where + numerical errors may result in an argument > 1 being passed in. + https://github.com/materialsproject/pymatgen/blob/b789d74639aa851d7e5ee427a765d9fd5a8d1079/pymatgen/util/num.py#L15 + Args: + val (float): Input value. + max_abs_val (float): The maximum absolute value for val. Defaults to 1. + Returns: + val if abs(val) < 1 else sign of val * max_abs_val. + """ + return max(min(val, max_abs_val), -max_abs_val) + +def lattice_matrix_to_params(matrix): + """Converts matrix to lattice from abc, angles. + https://github.com/materialsproject/pymatgen/blob/b789d74639aa851d7e5ee427a765d9fd5a8d1079/pymatgen/core/lattice.py#L311 + """ + lengths = np.sqrt(np.sum(matrix ** 2, axis=1)).tolist() + + angles = np.zeros(3) + for i in range(3): + j = (i + 1) % 3 + k = (i + 2) % 3 + angles[i] = abs_cap(np.dot(matrix[j], matrix[k]) / + (lengths[j] * lengths[k])) + angles = np.arccos(angles) * 180.0 / np.pi + a, b, c = lengths + alpha, beta, gamma = angles + return a, b, c, alpha, beta, gamma + +def preprocess(input_file, num_workers, niggli, primitive, graph_method, + prop_list, nrows=-1): + """ + Read crystal data from a dataset CSV file and preprocess it + + Args: + input_file (str): The path of dataset csv. + num_workers (int): The numbers of cpus used for preprocessing the crystals. + niggli (bool): Whether to use niggli algorithom to preprocess the choice of lattice. + primitive (bool): Whether to represent the crystal in primitive cell. + graph_method (str): If 'crystalnn', construct the graph by crystalnn algorithm, + mainly effect the construct of edges. If 'none', don't construct any edge. + prop_list (list[str]): Read the property of crystal as specified by the element of the list. + nrows (int): If nrows > 0, read the first 'nrows' lines of csv file. If nrows = -1, read the whole csv file. + This arg is mainly for debugging to quickly load a few crystals. + + Returns: + List. Return the list of crystals. 
Each element is a Dict composed by: + { + 'mp_id': int, + 'cif': crystal string, + 'graph_arrays': numpy arrays of frac_coords, atom_types, + lengths, angles, edge_indices, to_jimages, num_atoms, + } + """ + if nrows == -1: + df = pd.read_csv(input_file) + # for debug + else: + df = pd.read_csv(input_file, nrows=nrows) + + def process_one(row, niggli, primitive, graph_method, prop_list): + crystal_str = row['cif'] + + crystal = build_crystal( + crystal_str, niggli=niggli, primitive=primitive) + # 晶体构建图,重点 + graph_arrays = build_crystal_graph(crystal, graph_method) + + properties = {k: row[k] for k in prop_list if k in row.keys()} + result_dict = { + 'mp_id': row['material_id'], + 'cif': crystal_str, + 'graph_arrays': graph_arrays, + } + result_dict.update(properties) + return result_dict + + unordered_results = p_umap( + process_one, + [df.iloc[idx] for idx in range(len(df))], + [niggli] * len(df), + [primitive] * len(df), + [graph_method] * len(df), + [prop_list] * len(df), + num_cpus=num_workers) + + mpid_to_results = {result['mp_id']: result for result in unordered_results} + ordered_results = [mpid_to_results[df.iloc[idx]['material_id']] + for idx in range(len(df))] + return ordered_results + +class StandardScaler: + """A :class:`StandardScaler` normalizes the features of a dataset. + When it is fit on a dataset, the :class:`StandardScaler` learns the + mean and standard deviation across the 0th axis. + When transforming a dataset, the :class:`StandardScaler` subtracts the + means and divides by the standard deviations. + """ + + def __init__(self, means=None, stds=None, replace_nan_token=None): + """ + :param means: An optional 1D numpy array of precomputed means. + :param stds: An optional 1D numpy array of precomputed standard deviations. + :param replace_nan_token: A token to use to replace NaN entries in the features. + """ + self.means = means + self.stds = stds + self.replace_nan_token = replace_nan_token + + def fit(self, x): + """ + Learns means and standard deviations across the 0th axis of the data :code:`x`. + :param x: A list of lists of floats (or None). + :return: The fitted :class:`StandardScaler` (self). + """ + x = np.array(x).astype(float) + self.means = np.nanmean(x, axis=0) + self.stds = np.nanstd(x, axis=0) + self.means = np.where(np.isnan(self.means), + np.zeros(self.means.shape), self.means).astype(float) + self.stds = np.where(np.isnan(self.stds), + np.ones(self.stds.shape), self.stds) + self.stds = np.where(self.stds == 0, np.ones( + self.stds.shape), self.stds).astype(float) + + return self + + def transform(self, x): + """ + Transforms the data by subtracting the means and dividing by the standard deviations. + :param x: A list of lists of floats (or None). + :return: The transformed data with NaNs replaced by :code:`self.replace_nan_token`. + """ + x = np.array(x).astype(float) + transformed_with_nan = (x - self.means) / self.stds + transformed_with_none = np.where( + np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan).astype(float) + + return transformed_with_none + + def inverse_transform(self, x): + """ + Performs the inverse transformation by multiplying by the standard deviations and adding the means. + :param x: A list of lists of floats. + :return: The inverse transformed data with NaNs replaced by :code:`self.replace_nan_token`. 
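+
+        Example (hypothetical values): with means=[1.0] and stds=[2.0],
+        inverse_transform([[0.5]]) returns [[2.0]], since 0.5 * 2.0 + 1.0 = 2.0.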
+ """ + x = np.array(x).astype(float) + transformed_with_nan = x * self.stds + self.means + transformed_with_none = np.where( + np.isnan(transformed_with_nan), self.replace_nan_token, transformed_with_nan).astype(float) + + return transformed_with_none diff --git a/MindChemistry/applications/crystalflow/data/dataset.py b/MindChemistry/applications/crystalflow/data/dataset.py new file mode 100644 index 000000000..7ba9d7901 --- /dev/null +++ b/MindChemistry/applications/crystalflow/data/dataset.py @@ -0,0 +1,139 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""dataset file""" +import os +from pathlib import Path + +import numpy as np + +from data.data_utils import StandardScaler, preprocess + +if Path('./dataset/dataset_prop.txt').exists(): + with open('./dataset/dataset_prop.txt', 'r') as file: + data = file.read() + # pylint: disable=W0123 + scalar_dict = eval(data) +else: + scalar_dict = {} + +def fullconnect_dataset(name, + path, + niggli=True, + primitive=False, + graph_method='none', + preprocess_workers=30, + save_path='', + nrows=-1): + """ + Read crystal data from a CSV file and convert each into a fully connected graph, + where the nodes represent atoms within the unit cell and the edges connect every pair of nodes. + + Args: + name (str): The name of dataset, mainly used to read the dataset + property in './dataset/dataset_prop.txt'. + It doesn't matter for crystal structure prediction task. + Choices: [perov_5, carbon_24, mp_20, mpts_52]. + Users can also create custom datasets, by modify the + './dataset/dataset_prop.txt'. + path (str): The path of csv file of dataset. + niggli (bool): Whether to use niggli algorithom to + preprocess the choice of lattice. Default: + ``True``. + primitive (bool): Whether to represent the crystal in primitive cell. Default: + ``False``. + graph_method (str): If 'crystalnn', construct the graph by crystalnn + algorithm, mainly effect the construct of edges. + If 'none', don't construct any edge. Default: ``none``. + preprocess_workers (int): The numbers of cpus used for + preprocessing the crystals. Default: ``None``. + save_path (str): The path for saving the preprocessed data, + aiming to load the dataset more quickly next time. + nrows (int): If nrows > 0, read the first 'nrows' lines of csv file. + If nrows = -1, read the whole csv file. + This arg is mainly for debugging to quickly load a few crystals. + Returns: + x (list): List of Atom types. Shape of each element i.e. numpy array: (num_atoms, 1) + frac_coord_list (list): List of Fractional Coordinates of atoms. + Shape of each element i.e. numpy array: (num_atoms, 3) + edge_attr (list): List of numpy arrays filled with ones, + just used to better construct the dataloader, + without numerical significance. Shape of each element + i.e. numpy array: (num_edges, 3) + edge_index (list): List of index of the beginning and end + of edges. 
Each element is composed as [src, dst], where + src and dst is numpy arrays with Shape (num_edges,). + lengths_list (list): List of lengths of lattice. Shape of + each element i.e. numpy array: (3,) + angles_list (list): List of angles of lattice. Shape of + each element i.e. numpy array: (3,) + labels (list): List of property of crystal. Shape of + each element i.e. numpy array: (1,) + """ + x = [] + frac_coord_list = [] + edge_index = [] + edge_attr = [] + labels = [] + lengths_list = [] + angles_list = [] + lattice_polar_list = [] + num_atoms_list = [] + + if name in scalar_dict.keys(): + prop = scalar_dict[name]['prop'] + scaler = StandardScaler(scalar_dict[name]['scaler.means'], + scalar_dict[name]['scaler.stds']) + else: + print('No dataset property is specified, so no property reading is performed') + prop = "None" + scaler = None + + if os.path.exists(save_path): + cached_data = np.load(save_path, allow_pickle=True) + else: + cached_data = preprocess(path, + preprocess_workers, + niggli=niggli, + primitive=primitive, + graph_method=graph_method, + prop_list=[prop], + nrows=nrows) + + np.save(save_path, cached_data) + + for idx in range(len(cached_data)): + data_dict = cached_data[idx] + (frac_coords, atom_types, lengths, angles, lattice_polar, _, _, num_atoms) = data_dict['graph_arrays'] + + indices = np.arange(num_atoms) + dst, src = np.meshgrid(indices, indices) + src = src.reshape(-1) + dst = dst.reshape(-1) + + x.append(atom_types.reshape(-1, 1)) + frac_coord_list.append(frac_coords) + edge_index.append(np.array([src, dst])) + edge_attr.append(np.ones((num_atoms * num_atoms, 3))) + lengths_list.append(lengths) + angles_list.append(angles) + lattice_polar_list.append(lattice_polar) + num_atoms_list.append(num_atoms) + if scaler is not None: + labels.append(scaler.transform(data_dict[prop])) + else: + labels.append(0.0) + + return x, frac_coord_list, edge_attr, edge_index, lengths_list, \ + angles_list, lattice_polar_list, num_atoms_list, labels diff --git a/MindChemistry/applications/crystalflow/evaluate.py b/MindChemistry/applications/crystalflow/evaluate.py new file mode 100644 index 000000000..1bd2d1c54 --- /dev/null +++ b/MindChemistry/applications/crystalflow/evaluate.py @@ -0,0 +1,175 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""evaluate file""" +import os +import pickle +import time +import argparse +import yaml + +import mindspore as ms + +from data.crysloader import Crysloader +from data.dataset import fullconnect_dataset +from models.cspnet import CSPNet +from models.flow import CSPFlow +from models.infer_utils import (count_consecutive_occurrences, + lattices_to_params_ms) + + +def flow(loader, model, num_evals, n, anneal_slope, anneal_offset): + """Generating "num_evals" crystals for each composition in the dataset. + + Args: + loader (Crysloader): The dataset loader. + model (nn.cell): The diffution model. 
+ num_evals (int): The number of generated crystals for each composition + step_lr (float): Langevin dynamics. Defaults to 1e-5. + + Returns: + Tuple(List[Dict], List[List[Dict]]): The ground truth and predicted crystals.The form is + as follows: + ... + ( + [ + [Crystal A sample 1, Crystal A sample 2, Crystal A sample 3, ... Crystal A sample num_eval], + [Crystal B sample 1, Crystal B sample 2, Crystal B sample 3, ... Crystal B sample num_eval] + ... + ], + + [ + Crystal A ground truth, + Crystal B ground truth, + ... + ] + ) + ... + """ + gt_struc = [] + pred_struc = [] + for atom_types_step, frac_coords_step, _, lengths_step, \ + angles_step, _, num_atoms_step, edge_index_step, node_batch_step, \ + node_mask_step, edge_mask_step, batch_mask, _, batch_size_step in loader: + num_node_list = count_consecutive_occurrences( + node_batch_step.asnumpy().tolist()) + + pred_struc_batch = [[] for _ in range(batch_size_step)] + epoch_starttime = time.time() + for eval_idx in range(num_evals): + print( + f'Batch {loader.step} / {loader.step_num+1}, sample {eval_idx} / {num_evals}' + ) + starttime = time.time() + _, frac_coords_t, lattices_t = model.sample(node_batch_step, + node_mask_step, + edge_mask_step, + batch_mask, + atom_types_step, + edge_index_step, + num_atoms_step, + n, + anneal_slope, + anneal_offset) + lengths_pred, angles_pred = lattices_to_params_ms( + lattices_t[:batch_size_step]) + + start_index = 0 + for i in range(batch_size_step): + num_node_i = num_node_list[i] + atom_types_i = atom_types_step[start_index:start_index + + num_node_i].asnumpy() + frac_coords_i = frac_coords_t[start_index:start_index + + num_node_i].asnumpy() + lengths_i = lengths_pred[i].asnumpy() + angles_i = angles_pred[i].asnumpy() + + pred_struc_batch[i].append({ + 'atom_types': atom_types_i, + 'frac_coords': frac_coords_i, + 'lengths': lengths_i, + 'angles': angles_i + }) + + if eval_idx == 0: + frac_coords_gt = frac_coords_step[start_index:start_index + + num_node_i].asnumpy() + lengths_gt = lengths_step[i].asnumpy() + angles_gt = angles_step[i].asnumpy() + gt_struc.append({ + 'atom_types': atom_types_i, + 'frac_coords': frac_coords_gt, + 'lengths': lengths_gt, + 'angles': angles_gt + }) + + start_index += num_node_i + + starttime0 = starttime + starttime = time.time() + print(f"Evaluation time: {starttime - starttime0} s") + pred_struc.extend(pred_struc_batch) + + print( + f"##########Evaluation time for one Batch : \ + {time.time() - epoch_starttime} s ################" + ) + return gt_struc, pred_struc + +def main(args): + """main + """ + with open(args.config, 'r') as stream: + config = yaml.safe_load(stream) + + test_datatset = fullconnect_dataset( + name=config['dataset']['data_name'], + path=config['dataset']['test']['path'], + save_path=config['dataset']['test']['save_path']) + test_loader = Crysloader(config['test']['batch_size'], + *test_datatset, + shuffle_dataset=False) + + decoder = CSPNet(num_layers=config['model']['num_layers'], + hidden_dim=config['model']['hidden_dim'], + num_freqs=config['model']['num_freqs']) + mindspore_ckpt = ms.load_checkpoint(config['checkpoint']['last_path']) + ms.load_param_into_net(decoder, mindspore_ckpt) + + model = CSPFlow(decoder) + + model.set_train(False) + + gt_struc, pred_struc = flow(test_loader, + model, + config['test']['num_eval'], + n=args.N, + anneal_slope=args.anneal_slope, + anneal_offset=args.anneal_offset) + + eval_save_path = config['test']['eval_save_path'] + os.makedirs(os.path.dirname(eval_save_path), exist_ok=True) + + with 
open(eval_save_path, 'wb') as f: + pickle.dump({'pred': pred_struc, 'gt': gt_struc}, f) + +if __name__ == '__main__': + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='config.yaml') + parser.add_argument('-N', type=int, default=1000) + parser.add_argument('--anneal_slope', type=float, default=0.0) + parser.add_argument('--anneal_offset', type=float, default=0.0) + main_args = parser.parse_args() + main(main_args) + \ No newline at end of file diff --git a/MindChemistry/applications/crystalflow/models/conditioning.py b/MindChemistry/applications/crystalflow/models/conditioning.py new file mode 100644 index 000000000..799b8b6ad --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/conditioning.py @@ -0,0 +1,34 @@ +"""condition utils""" +import mindspore as ms +from mindspore import ops, nn + +class GaussianExpansion(nn.Cell): + r"""Expansion layer using a set of Gaussian functions. + https://github.com/atomistic-machine-learning/cG-SchNet/blob/53d73830f9fb1158296f060c2f82be375e2bb7f9/nn_classes.py#L687) + """ + def __init__(self, start, stop, n_gaussians=50, trainable=False, width=None): + super(GaussianExpansion, self).__init__() + offset = ops.linspace(start, stop, n_gaussians) + self.n_out = n_gaussians + if width is None: + widths = (offset[1] - offset[0]) * ops.ones_like(offset) + else: + widths = width * ops.ones_like(offset) + if trainable: + self.widths = ms.Parameter(widths) + self.offsets = ms.Parameter(offset) + else: + self.widths = ms.Parameter(widths, requires_grad=False) + self.offsets = ms.Parameter(offset, requires_grad=False) + + def construct(self, prop): + """Compute expanded gaussian property values. + Args: + prop (Tensor): property values of (N_b x 1) shape. + Returns: + Tensor: layer output of (N_b x N_g) shape. + """ + prop = prop.reshape(prop.shape[0], -1) + coeff = -0.5 / ops.pow(self.widths, 2)[None, :] + diff = prop - self.offsets[None, :] + return ops.exp(coeff * ops.pow(diff, 2)) diff --git a/MindChemistry/applications/crystalflow/models/cspnet.py b/MindChemistry/applications/crystalflow/models/cspnet.py new file mode 100644 index 000000000..c5ae68956 --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/cspnet.py @@ -0,0 +1,278 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""GNN denoiser file""" +import math + +import numpy as np +import mindspore +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindchemistry.graph.graph import (AggregateEdgeToNode, + AggregateNodeToGlobal, LiftGlobalToNode) + +MAX_ATOMIC_NUM = 100 + + +class SinusoidsEmbedding(nn.Cell): + """ + + Fourier embedding for edge features to address periodic translation invariance + as described in the paper of CrystalFlow. + + Args: + n_frequencies (int): The number of frequencies for embedding. + n_space (int): The dimension of edge feature. 
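+
+    Note: the output feature dimension is n_frequencies * 2 * n_space
+    (a sin and a cos term per frequency per spatial component), e.g.
+    768 for n_frequencies=128 and n_space=3.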
+ """ + + def __init__(self, n_frequencies=10, n_space=3): + super(SinusoidsEmbedding, self).__init__() + self.n_frequencies = n_frequencies + self.n_space = n_space + self.frequencies = 2 * math.pi * np.arange(self.n_frequencies) + self.dim = self.n_frequencies * 2 * self.n_space + + def construct(self, x): + """construct + + Args: + x (Tensor): Distance + + Returns: + Tensor: Fourier embedding + """ + emb = ops.ExpandDims()(x, -1) * Tensor( + self.frequencies, + dtype=mindspore.float32).expand_dims(0).expand_dims(0) + emb = ops.Reshape()(emb, (-1, self.n_frequencies * self.n_space)) + emb = ops.Concat(axis=-1)((ops.Sin()(emb), ops.Cos()(emb))) + return emb + + +def mul_mask(features, mask): + """Make the padded dim of features to be zeros + + Args: + features (Tensor): Input tensor + mask (Tensor): Value 1 specifies the corresponding dimension of input tensor to be valid, + and value 0 specifies the corresponding dimension of input tensor to be zero. + + Returns: + Tensor: Output tensor + """ + return ops.mul(features, ops.reshape(mask, (-1, 1))) + + +class CSPLayer(nn.Cell): + r"""One layer of the GNN denoiser. For the input node feature + :math:`h_i^{(s-1)}` from last layer, the lattice matrix + :math:`L` and the fractional coordinates :math:`f_i`, the formula is defined as: + + .. math:: + + m_{ij}^{(s)} = mlp_m(h_i^{(s-1)}, h_j^{(s-1)}, L^\topL, \text{edge_emb}(f_j - f_i)), + + m_{i}^{(s)} = \sum_{j=1}^N m_{ij}^{(s)}, + + h_{i}^{(s)} = h_i^{(s-1)} + mlp_h(h_i^{(s-1)}, m_{i}^{(s)}). + + ... + + Then we can get the new node features `h_i^{(s)}`. + + """ + + def __init__(self, hidden_dim=512, act_fn=nn.SiLU(), dis_emb=None): + """Initialization + + Args: + hidden_dim (int): The dimension of hidden node features. Defaults to 512. + act_fn (nn): The activation function used in the layer. Defaults to nn.SiLU(). + dis_emb (object): The embbing method used for edge features. Defaults to None. + """ + super(CSPLayer, self).__init__() + self.dis_dim = 3 + self.dis_emb = dis_emb + if dis_emb is not None: + self.dis_dim = dis_emb.dim + self.edge_mlp = nn.SequentialCell([ + nn.Dense(hidden_dim * 2 + 6 + self.dis_dim, hidden_dim), act_fn, + nn.Dense(hidden_dim, hidden_dim), act_fn + ]) + self.node_mlp = nn.SequentialCell([ + nn.Dense(hidden_dim * 2, hidden_dim), act_fn, + nn.Dense(hidden_dim, hidden_dim), act_fn + ]) + + self.layer_norm = nn.LayerNorm([hidden_dim], epsilon=1e-5) + + self.edge_scatter = AggregateEdgeToNode(mode='mean', dim=0) + + def edge_model(self, node_features, lattice_polar, edge_index, + edge2graph, frac_diff, edge_mask): + """Edge embbding for edge feature. + """ + hi, hj = node_features[edge_index[0]], node_features[edge_index[1]] + + frac_diff = self.dis_emb(frac_diff) + + #lattice_ips = ops.BatchMatMul()(lattices, + # lattices.transpose(0, -1, -2)) + + #lattice_ips_flatten = ops.Reshape()(lattice_ips, (-1, 9)) + lattice_polar = ops.Gather()(lattice_polar, edge2graph, 0) + + edges_input = ops.Concat(axis=1)( + (hi, hj, frac_diff, lattice_polar)) + edges_input = mul_mask(edges_input, edge_mask) + + edge_features = self.edge_mlp(edges_input) + return edge_features + + def node_model(self, node_features, edge_features, edge_index, edge_mask): + """Aggregate the edge features to be the node features. 
+ """ + agg = self.edge_scatter(edge_features, + edge_index, + dim_size=node_features.shape[0], + mask=edge_mask) + agg = ops.Concat(axis=1)((node_features, agg)) + out = self.node_mlp(agg) + return out + + def construct(self, node_features, lattice_polar, edge_index, + edge2graph, frac_diff, node_mask, edge_mask): + """Apply GNN layer over node features from last layer. + + Args: + node_features (Tensor): Node features from last layer. Shape: (num_atoms, hidden_dim) + frac_coords (Tensor): Fractional coordinates for calculating edge features. + Shape: (num_atoms, 3) + lattices (Tensor): Lattice mattrix for calculating edge features. + Shape: (batchsize, 3, 3) + edge_index (Tensor): Edge index for aggregating the edge features. + Shape: (2, num_edges) + edge2graph (Tensor): Graph index to lift the lattice to edge features. + Shape: (num_edges,) + frac_diff (Tensor): Distance of fractional coordinates for calculating + edge features. Shape: (num_edges, 3) + node_mask (Tensor): Node mask for padded tensor. Shape: (num_atoms,) + edge_mask (Tensor): Edge mask for padded tensor. Shape: (num_edges,) + + Returns: + Tensor: The output tensor. Shape: (num_atoms, hidden_dim) + """ + node_input = node_features + node_features = mul_mask(node_features, node_mask) + node_features = self.layer_norm(node_features) + edge_features = self.edge_model(node_features, lattice_polar, + edge_index, edge2graph, frac_diff, + edge_mask) + node_output = self.node_model(node_features, edge_features, edge_index, + edge_mask) + resiual_output = mul_mask(node_input + node_output, node_mask) + return resiual_output + + +class CSPNet(nn.Cell): + """GNN denoiser for CrystalFlow. + """ + + def __init__(self, + hidden_dim=512, + latent_dim=256, + num_layers=6, + max_atoms=100, + num_freqs=128): + """Initialization + + Args: + hidden_dim (int): The dimension of hidden node features. Defaults to 512. + latent_dim (int): The dimension of time embedding. Defaults to 256. + num_layers (int): The number of layers used in GNN. Defaults to 6. + max_atoms (int): The number of embedding table lines for atom types. Defaults to 100. + num_freqs (int): The number of frequencies for Fourier embedding for + edge features. Defaults to 128. + """ + super(CSPNet, self).__init__() + self.node_embedding = nn.Embedding(max_atoms, hidden_dim) + self.atom_latent_emb = nn.Dense(hidden_dim + latent_dim, hidden_dim) + self.act_fn = nn.SiLU() + self.dis_emb = SinusoidsEmbedding(n_frequencies=num_freqs) + + self.csp_layers = nn.CellList([ + CSPLayer(hidden_dim, self.act_fn, self.dis_emb) + for _ in range(num_layers) + ]) + + self.num_layers = num_layers + self.coord_out = nn.Dense(hidden_dim, 3, has_bias=False) + self.lattice_out = nn.Dense(hidden_dim, 6, has_bias=False) + self.final_layer_norm = nn.LayerNorm([hidden_dim]) + + self.node_scatter = AggregateNodeToGlobal(mode='mean') + self.lift_node = LiftGlobalToNode() + + def construct(self, t, atom_types, frac_coords, lattice_polar, node2graph, + edge_index, node_mask, edge_mask): + """Apply GNN over noised fractional coordinates and lattice matrix. + + Args: + t (Tensor): Time embeddind features. Shape: (batchsize, latent_dim) + atom_types (Tensor): Atom types. Shape: (num_atoms,) + frac_coords (Tensor): Fractional coordinates. Shape: (num_atoms, 3) + lattices (Tensor): Lattice mattrix. Shape: (batchsize, 3, 3) + node2graph (Tensor): Graph index for each node. Shape: (num_atoms,) + edge_index (Tensor): Edge index for aggregating the edge features. 
Shape: (2, num_edges) + node_mask (Tensor): Node mask for padded tensor. Shape: (num_atoms,) + edge_mask (Tensor): Edge mask for padded tensor. Shape: (num_edges,) + + Returns: + Tuple(Tensor,Tensor): Node features for fractional coordinates denoising terms and + graph features for lattice matrix denoising terms. + """ + edge_src = edge_index[0] + edge_dst = edge_index[1] + frac_diff = (frac_coords[edge_dst] - frac_coords[edge_src]) % 1 + edge2graph = ops.Gather()(node2graph, edge_index[0], 0) + + node_features = self.node_embedding(atom_types - 1) + node_features = mul_mask(node_features, node_mask) + + t_per_atom = self.lift_node(t, node2graph, mask=node_mask) + + node_features = ops.Concat(axis=1)((node_features, t_per_atom)) + + node_features = self.atom_latent_emb(node_features) + + for i in range(self.num_layers): + node_features = self.csp_layers[i](node_features, + lattice_polar, edge_index, + edge2graph, frac_diff, + node_mask, edge_mask) + + node_features = mul_mask(node_features, node_mask) + node_features = self.final_layer_norm(node_features) + node_features = mul_mask(node_features, node_mask) + + coord_out = self.coord_out(node_features) + + graph_features = self.node_scatter(node_features, + node2graph, + dim_size=lattice_polar.shape[0], + mask=node_mask) + lattice_out = self.lattice_out(graph_features) + + return lattice_out, coord_out diff --git a/MindChemistry/applications/crystalflow/models/cspnet_condition.py b/MindChemistry/applications/crystalflow/models/cspnet_condition.py new file mode 100644 index 000000000..9ae1c2cdd --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/cspnet_condition.py @@ -0,0 +1,299 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""GNN denoiser file""" +import math + +import numpy as np +import mindspore +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +from mindchemistry.graph.graph import (AggregateEdgeToNode, + AggregateNodeToGlobal, LiftGlobalToNode) + +MAX_ATOMIC_NUM = 100 + + +class SinusoidsEmbedding(nn.Cell): + """ + + Fourier embedding for edge features to address periodic translation invariance + as described in the paper of CrystalFlow. + + Args: + n_frequencies (int): The number of frequencies for embedding. + n_space (int): The dimension of edge feature. 
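+
+    Example (an illustrative sketch; the output width equals
+    ``dim = n_frequencies * 2 * n_space``):
+
+        >>> emb = SinusoidsEmbedding(n_frequencies=10, n_space=3)
+        >>> emb(ops.rand((32, 3))).shape   # 32 periodic coordinate differences
+        (32, 60)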
+ """ + + def __init__(self, n_frequencies=10, n_space=3): + super(SinusoidsEmbedding, self).__init__() + self.n_frequencies = n_frequencies + self.n_space = n_space + self.frequencies = 2 * math.pi * np.arange(self.n_frequencies) + self.dim = self.n_frequencies * 2 * self.n_space + + def construct(self, x): + """construct + + Args: + x (Tensor): Distance + + Returns: + Tensor: Fourier embedding + """ + emb = ops.ExpandDims()(x, -1) * Tensor( + self.frequencies, + dtype=mindspore.float32).expand_dims(0).expand_dims(0) + emb = ops.Reshape()(emb, (-1, self.n_frequencies * self.n_space)) + emb = ops.Concat(axis=-1)((ops.Sin()(emb), ops.Cos()(emb))) + return emb + + +def mul_mask(features, mask): + """Make the padded dim of features to be zeros + + Args: + features (Tensor): Input tensor + mask (Tensor): Value 1 specifies the corresponding dimension of input tensor to be valid, + and value 0 specifies the corresponding dimension of input tensor to be zero. + + Returns: + Tensor: Output tensor + """ + return ops.mul(features, ops.reshape(mask, (-1, 1))) + + +class CSPLayer(nn.Cell): + r"""One layer of the GNN denoiser. For the input node feature + :math:`h_i^{(s-1)}` from last layer, the lattice matrix + :math:`L` in the polar-decomposition form, + and the fractional coordinates :math:`f_i`, the formula is defined as: + + .. math:: + + m_{ij}^{(s)} = mlp_m(h_i^{(s-1)}, h_j^{(s-1)}, L, \text{edge_emb}(f_j - f_i)), + + m_{i}^{(s)} = \sum_{j=1}^N m_{ij}^{(s)}, + + h_{i}^{(s)} = h_i^{(s-1)} + mlp_h(h_i^{(s-1)}, m_{i}^{(s)}). + + ... + + Then we can get the new node features `h_i^{(s)}`. + + """ + + def __init__(self, hidden_dim=512, act_fn=nn.SiLU(), dis_emb=None): + """Initialization + + Args: + hidden_dim (int): The dimension of hidden node features. Defaults to 512. + act_fn (nn): The activation function used in the layer. Defaults to nn.SiLU(). + dis_emb (object): The embbing method used for edge features. Defaults to None. + """ + super(CSPLayer, self).__init__() + self.dis_dim = 3 + self.dis_emb = dis_emb + if dis_emb is not None: + self.dis_dim = dis_emb.dim + self.edge_mlp = nn.SequentialCell([ + nn.Dense(hidden_dim * 2 + 6 + self.dis_dim, hidden_dim), act_fn, + nn.Dense(hidden_dim, hidden_dim), act_fn + ]) + self.node_mlp = nn.SequentialCell([ + nn.Dense(hidden_dim * 2, hidden_dim), act_fn, + nn.Dense(hidden_dim, hidden_dim), act_fn + ]) + + self.layer_norm = nn.LayerNorm([hidden_dim], epsilon=1e-5) + + self.edge_scatter = AggregateEdgeToNode(mode='mean', dim=0) + + def edge_model(self, node_features, lattice_polar, edge_index, + edge2graph, frac_diff, edge_mask): + """Edge embbding for edge feature. + """ + hi, hj = node_features[edge_index[0]], node_features[edge_index[1]] + + frac_diff = self.dis_emb(frac_diff) + + #lattice_ips = ops.BatchMatMul()(lattices, + # lattices.transpose(0, -1, -2)) + + #lattice_ips_flatten = ops.Reshape()(lattice_ips, (-1, 9)) + lattice_polar = ops.Gather()(lattice_polar, edge2graph, 0) + + edges_input = ops.Concat(axis=1)( + (hi, hj, frac_diff, lattice_polar)) + edges_input = mul_mask(edges_input, edge_mask) + + edge_features = self.edge_mlp(edges_input) + return edge_features + + def node_model(self, node_features, edge_features, edge_index, edge_mask): + """Aggregate the edge features to be the node features. 
+ """ + agg = self.edge_scatter(edge_features, + edge_index, + dim_size=node_features.shape[0], + mask=edge_mask) + agg = ops.Concat(axis=1)((node_features, agg)) + out = self.node_mlp(agg) + return out + + def construct(self, node_features, lattice_polar, edge_index, + edge2graph, frac_diff, node_mask, edge_mask): + """Apply GNN layer over node features from last layer. + + Args: + node_features (Tensor): Node features from last layer. Shape: (num_atoms, hidden_dim) + frac_coords (Tensor): Fractional coordinates for calculating edge features. + Shape: (num_atoms, 3) + lattices (Tensor): Lattice mattrix for calculating edge features. + Shape: (batchsize, 3, 3) + edge_index (Tensor): Edge index for aggregating the edge features. + Shape: (2, num_edges) + edge2graph (Tensor): Graph index to lift the lattice to edge features. + Shape: (num_edges,) + frac_diff (Tensor): Distance of fractional coordinates for calculating + edge features. Shape: (num_edges, 3) + node_mask (Tensor): Node mask for padded tensor. Shape: (num_atoms,) + edge_mask (Tensor): Edge mask for padded tensor. Shape: (num_edges,) + + Returns: + Tensor: The output tensor. Shape: (num_atoms, hidden_dim) + """ + node_input = node_features + node_features = mul_mask(node_features, node_mask) + node_features = self.layer_norm(node_features) + edge_features = self.edge_model(node_features, lattice_polar, + edge_index, edge2graph, frac_diff, + edge_mask) + node_output = self.node_model(node_features, edge_features, edge_index, + edge_mask) + resiual_output = mul_mask(node_input + node_output, node_mask) + return resiual_output + + +class CSPNet(nn.Cell): + """GNN denoiser for CrystalFlow. + """ + + def __init__(self, + hidden_dim=512, + latent_dim=256, + num_layers=6, + max_atoms=100, + num_freqs=128, + cemb_dim=1, + ): + """Initialization + + Args: + hidden_dim (int): The dimension of hidden node features. Defaults to 512. + latent_dim (int): The dimension of time embedding. Defaults to 256. + num_layers (int): The number of layers used in GNN. Defaults to 6. + max_atoms (int): The number of embedding table lines for atom types. Defaults to 100. + num_freqs (int): The number of frequencies for Fourier embedding for + edge features. Defaults to 128. + cemb_dim (int): The dimension of condition embedding vector. Defaults to 1. + """ + super(CSPNet, self).__init__() + self.node_embedding = nn.Embedding(max_atoms, hidden_dim) + self.atom_latent_emb = nn.Dense(hidden_dim + latent_dim, hidden_dim) + self.act_fn = nn.SiLU() + self.dis_emb = SinusoidsEmbedding(n_frequencies=num_freqs) + + self.csp_layers = nn.CellList([ + CSPLayer(hidden_dim, self.act_fn, self.dis_emb) + for _ in range(num_layers) + ]) + self.cemb_mixin = nn.CellList([ + nn.Dense(hidden_dim, hidden_dim, has_bias=False) + for _ in range(num_layers)]) + self.cemb_adapter = nn.CellList([ + nn.SequentialCell([ + nn.Dense(cemb_dim, hidden_dim), + self.act_fn, + nn.Dense(hidden_dim, hidden_dim), + self.act_fn,]) + for _ in range(num_layers) + ]) + + self.num_layers = num_layers + self.coord_out = nn.Dense(hidden_dim, 3, has_bias=False) + self.lattice_out = nn.Dense(hidden_dim, 6, has_bias=False) + self.final_layer_norm = nn.LayerNorm([hidden_dim]) + + self.node_scatter = AggregateNodeToGlobal(mode='mean') + self.lift_node = LiftGlobalToNode() + + def construct(self, t, atom_types, frac_coords, lattice_polar, node2graph, + edge_index, node_mask, edge_mask, cemb): + """Apply GNN over noised fractional coordinates and lattice matrix. + + Args: + t (Tensor): Time embeddind features. 
Shape: (batchsize, latent_dim) + atom_types (Tensor): Atom types. Shape: (num_atoms,) + frac_coords (Tensor): Fractional coordinates. Shape: (num_atoms, 3) + lattices_polar (Tensor): Lattice mattrix. Shape: (batchsize, 6) + node2graph (Tensor): Graph index for each node. Shape: (num_atoms,) + edge_index (Tensor): Edge index for aggregating the edge features. Shape: (2, num_edges) + node_mask (Tensor): Node mask for padded tensor. Shape: (num_atoms,) + edge_mask (Tensor): Edge mask for padded tensor. Shape: (num_edges,) + cemb (Tensor): Embedded condition vector. Shape: (batchsize, cemb_dim) + + Returns: + Tuple(Tensor,Tensor): Node features for fractional coordinates denoising terms and + graph features for lattice matrix denoising terms. + """ + edge_src = edge_index[0] + edge_dst = edge_index[1] + frac_diff = (frac_coords[edge_dst] - frac_coords[edge_src]) % 1 + edge2graph = ops.Gather()(node2graph, edge_index[0], 0) + + node_features = self.node_embedding(atom_types - 1) + node_features = mul_mask(node_features, node_mask) + + t_per_atom = self.lift_node(t, node2graph, mask=node_mask) + + node_features = ops.Concat(axis=1)((node_features, t_per_atom)) + + node_features = self.atom_latent_emb(node_features) + + cemb = self.lift_node(cemb, node2graph, mask=node_mask) + + for i in range(self.num_layers): + cemb_bias = self.cemb_mixin[i](self.cemb_adapter[i](cemb)) + node_features = node_features + cemb_bias + #csp layers + node_features = self.csp_layers[i](node_features, + lattice_polar, edge_index, + edge2graph, frac_diff, + node_mask, edge_mask) + + node_features = mul_mask(node_features, node_mask) + node_features = self.final_layer_norm(node_features) + node_features = mul_mask(node_features, node_mask) + + coord_out = self.coord_out(node_features) + + graph_features = self.node_scatter(node_features, + node2graph, + dim_size=lattice_polar.shape[0], + mask=node_mask) + lattice_out = self.lattice_out(graph_features) + + return lattice_out, coord_out diff --git a/MindChemistry/applications/crystalflow/models/diff_utils.py b/MindChemistry/applications/crystalflow/models/diff_utils.py new file mode 100644 index 000000000..e74e20a00 --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/diff_utils.py @@ -0,0 +1,175 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""diffution utils file""" +import math + +import mindspore +import mindspore.nn as nn +import mindspore.ops as ops +from mindspore import Tensor +import numpy as np + +def mindspore_random_choice(low, high, size): + """ Mimic np.random.choice for integers in MindSpore """ + indices = ops.UniformInt()((size,), low, high) + return indices + +def cosine_beta_schedule(timesteps, s=0.008): + """ + The beta scheduled by cosine in DDPM used for lattice diffution. + See details in the paper of DiffCSP. 
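+
+    Example (illustrative; exact values depend on ``timesteps`` and ``s``):
+
+        >>> betas = cosine_beta_schedule(1000)
+        >>> betas.shape
+        (1000,)
+        >>> bool(betas.min() >= 0.0001 and betas.max() <= 0.9999)   # clipped range
+        True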
+ """ + steps = timesteps + 1 + x = np.linspace(0, timesteps, steps) + alphas_cumprod = np.cos(((x / timesteps) + s) / (1 + s) * math.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return np.clip(betas, 0.0001, 0.9999) + +def linear_beta_schedule(timesteps, beta_start, beta_end): + """ + The beta scheduled by linear in DDPM. + """ + return np.linspace(beta_start, beta_end, timesteps) + +def quadratic_beta_schedule(timesteps, beta_start, beta_end): + """ + The beta scheduled by quadratic in DDPM. + """ + return np.linspace(beta_start**0.5, beta_end**0.5, timesteps) ** 2 + +def sigmoid_beta_schedule(timesteps, beta_start, beta_end): + """ + The beta scheduled by sigmoid in DDPM. + """ + betas = np.linspace(-6, 6, timesteps) + return 1 / (1 + np.exp(-betas)) * (beta_end - beta_start) + beta_start + +def p_wrapped_normal(x, sigma, n=10, t=1.0): + """Utils for calcatating the score of wrapped normal distribution. + """ + p_ = 0 + for i in range(-n, n + 1): + p_ += np.exp(-(x + t * i) ** 2 / 2 / sigma ** 2) + return p_ + +def d_log_p_wrapped_normal(x, sigma, n=10, t=1.0): + """The score of wrapped normal distribution, which is parameterized by sigma, + for the input value x. See details in Appendix B.1 in the paper of DiffCSP. + + Args: + x (numpy.ndarray): Input noise. + sigma (numpy.ndarray): The variance of wrapped normal distribution. + n (int): The approximate parameter of the score of wrapped normal distribution. Defaults to 10. + t (int): The period of wrapped normal distribution. Defaults to 1.0. + + Returns: + numpy.ndarray: The score for the input value x. + """ + p_ = 0 + for i in range(-n, n + 1): + p_ += (x + t * i) / sigma ** 2 * np.exp(-(x + t * i) ** 2 / 2 / sigma ** 2) + return p_ / p_wrapped_normal(x, sigma, n, t) + +def sigma_norm(sigma, t=1.0, sn=10000): + r"""Monte-Carlo sampling for :math`\lambda_t`. + See details in Appendix B.1 in the paper of DiffCSP. + """ + sigmas = np.tile(sigma[None, :], (sn, 1)) + x_sample = sigma * np.random.standard_normal(sigmas.shape) + x_sample = x_sample % t + normal_ = d_log_p_wrapped_normal(x_sample, sigmas, t=t) + return (normal_ ** 2).mean(axis=0) + +def p_wrapped_normal_ms(x, sigma, n=10, t=1.0): + """Utils for calcatating the score of wrapped normal distribution. + """ + p_ = 0 + for i in range(-n, n + 1): + p_ += ops.Exp()(-(x + t * i) ** 2 / 2 / sigma ** 2) + return p_ + +def d_log_p_wrapped_normal_ms(x, sigma, n=10, t=1.0): + """The score of wrapped normal distribution, which is parameterized by sigma, + for the input value x. See details in Appendix B.1 in the paper of DiffCSP. + + Args: + x (Tensor): Input noise. + sigma (Tensor): The variance of wrapped normal distribution. + n (int): The approximate parameter of the score of wrapped normal distribution. Defaults to 10. + t (int): The period of wrapped normal distribution. Defaults to 1.0. + + Returns: + Tensor: The score for the input value x. + """ + p_ = 0 + for i in range(-n, n + 1): + p_ += (x + t * i) / sigma ** 2 * ops.Exp()(-(x + t * i) ** 2 / 2 / sigma ** 2) + return p_ / p_wrapped_normal_ms(x, sigma, n, t) + +class BetaScheduler(nn.Cell): + """ + The alpha, alphas_cumprod and beta in DDPM used for lattice diffution. 
+ """ + def __init__(self, timesteps, scheduler_mode, beta_start=0.0001, beta_end=0.02): + super(BetaScheduler, self).__init__() + self.timesteps = Tensor(timesteps, mindspore.int32) + self.timesteps_begin = Tensor(1, mindspore.int32) + if scheduler_mode == 'cosine': + betas = cosine_beta_schedule(timesteps) + elif scheduler_mode == 'linear': + betas = linear_beta_schedule(timesteps, beta_start, beta_end) + elif scheduler_mode == 'quadratic': + betas = quadratic_beta_schedule(timesteps, beta_start, beta_end) + elif scheduler_mode == 'sigmoid': + betas = sigmoid_beta_schedule(timesteps, beta_start, beta_end) + + betas = np.concatenate([np.zeros([1]), betas], axis=0) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + + sigmas = np.zeros_like(betas) + sigmas[1:] = betas[1:] * (1. - alphas_cumprod[:-1]) / (1. - alphas_cumprod[1:]) + sigmas = np.sqrt(sigmas) + + self.betas = Tensor(betas, mindspore.float32) + self.alphas = Tensor(alphas, mindspore.float32) + self.alphas_cumprod = Tensor(alphas_cumprod, mindspore.float32) + self.sigmas = Tensor(sigmas, mindspore.float32) + + def uniform_sample_t(self, batch_size): + return mindspore_random_choice(self.timesteps_begin, self.timesteps + 1, batch_size) + +class SigmaScheduler(nn.Cell): + r""" + The sigmas and :math`\lambda_t` in SDEs used for fractional coordinates diffution. + """ + def __init__(self, timesteps, sigma_begin=0.01, sigma_end=1.0): + super(SigmaScheduler, self).__init__() + self.timesteps = Tensor(timesteps, mindspore.int32) + self.timesteps_begin = Tensor(1, mindspore.int32) + self.sigma_begin = sigma_begin + self.sigma_end = sigma_end + sigmas = np.exp(np.linspace(np.log(sigma_begin), np.log(sigma_end), timesteps)) + sigmas_norm_ = sigma_norm(sigmas) + + sigmas = np.concatenate([np.zeros([1]), sigmas], axis=0) + sigmas_norm = np.concatenate([np.ones([1]), sigmas_norm_], axis=0) + + self.sigmas = Tensor(sigmas, mindspore.float32) + self.sigmas_norm = Tensor(sigmas_norm, mindspore.float32) + + def uniform_sample_t(self, batch_size): + return mindspore_random_choice(self.timesteps_begin, self.timesteps + 1, batch_size) diff --git a/MindChemistry/applications/crystalflow/models/flow.py b/MindChemistry/applications/crystalflow/models/flow.py new file mode 100644 index 000000000..821f3520b --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/flow.py @@ -0,0 +1,284 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""flow file""" +import math + +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import nn, ops +from mindchemistry.graph.graph import (AggregateNodeToGlobal, LiftGlobalToNode) + +from models.lattice import LatticePolarDecomp + + +def replace_nan_with_zero(tensor): + """Replace nan in tensor with zero to avoid numerical errors. 
+ """ + is_nan = ops.IsNan()(tensor) + zeros = ops.Fill()(ms.float32, tensor.shape, 0.0) + result = ops.Select()(is_nan, zeros, tensor) + return result + + +class SinusoidalTimeEmbeddings(nn.Cell): + """ Embedding for the time step in flow. + Referring the implementation details in the paper Attention is all you need. """ + + def __init__(self, dim): + super(SinusoidalTimeEmbeddings, self).__init__() + self.dim = dim + + def construct(self, time): + """construct + + Args: + time (Tensor): flow time step + + Returns: + Tensor: Time embedding + """ + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1) + embeddings = ops.Exp()(mnp.arange(half_dim) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = ops.Concat(axis=-1)( + (ops.Sin()(embeddings), ops.Cos()(embeddings))) + return embeddings + + +def lattice_params_to_matrix_mindspore(lengths, angles): + """Batched MindSpore version to compute lattice matrix from params. + + Args: + lengths (Tensor): Tensor of shape (N, 3), unit A + angles (Tensor):: Tensor of shape (N, 3), unit degree + Returns: + Tensor: Tensor of shape (N, 3, 3) + """ + angles_r = ops.deg2rad(angles) + coses = ops.cos(angles_r) + sins = ops.sin(angles_r) + + val = (coses[:, 0] * coses[:, 1] - coses[:, 2]) / (sins[:, 0] * sins[:, 1]) + # Sometimes rounding errors result in values slightly > 1. + val = ops.clip_by_value(val, -1., 1.) + gamma_star = ops.acos(val) + + zero_tensor = ops.zeros((lengths.shape[0],)) + + vector_a = ops.stack( + [lengths[:, 0] * sins[:, 1], zero_tensor, lengths[:, 0] * coses[:, 1]], + axis=1) + + vector_b = ops.stack([ + -lengths[:, 1] * sins[:, 0] * ops.cos(gamma_star), lengths[:, 1] * + sins[:, 0] * ops.sin(gamma_star), lengths[:, 1] * coses[:, 0] + ], + axis=1) + + vector_c = ops.stack([zero_tensor, zero_tensor, lengths[:, 2]], axis=1) + + return ops.stack([vector_a, vector_b, vector_c], axis=1) + + +class CSPFlow(nn.Cell): + """Flow model used in CrystalFlow + """ + + def __init__(self, + decoder, + time_dim=256, + sigma=0.1): + """Initialization + + Args: + decoder (nn.cell): Nerual network as denoiser for flow. + time_dim (int): The dimension of time embedding. Defaults to 256. + sigma (float): the standard deviation of Gaussian prior where lattice_polar_0 is sampled + """ + super(CSPFlow, self).__init__() + self.time_dim = time_dim + self.time_embedding = SinusoidalTimeEmbeddings(self.time_dim) + self.lattice_model = LatticePolarDecomp() + self.lift_node = LiftGlobalToNode() + self.aggre_graph = AggregateNodeToGlobal('mean') + self.decoder = decoder + self.sigma = sigma + self.relu = nn.ReLU() + + def construct(self, batch_num_graphs, batch_atom_types, batch_lengths, + batch_angles_step, batch_lattice_polar, batch_num_atoms_step, + batch_frac_coords, batch_node2graph, + batch_edge_index, node_mask, edge_mask, batch_mask): + """Training process for diffution. + + Args: + batch_num_graphs (Tensor): Batch size with shape (1,) + batch_atom_types (Tensor): Atom types of nodes in a batch of graph. Shape: (num_atoms,) + batch_lengths (Tensor): Lattices lengths in a batch of graph. Shape: (batchsize, 3) + batch_angles (Tensor): Lattice angles in a batch of graph. Shape: (batchsize, 3) + batch_frac_coords (Tensor): Fractional coordinates of nodes in + a batch of graph. (num_atoms, 3) + batch_node2graph (Tensor): Graph index for each node. Shape: (num_atoms,) + batch_edge_index (Tensor): Beginning and ending node index for each edge. + Shape: (2, num_edges) + node_mask (Tensor): Node mask for padded tensor. 
Shape: (num_atoms,) + edge_mask (Tensor): Edge mask for padded tensor. Shape: (num_edges,) + batch_mask (Tensor): Graph mask for padded tensor. Shape: (batchsize,) + + Returns: + Tuple(Tensor, Tensor, Tensor, Tensor): Return the ground truth + and predicted flow terms of lattice polar and fractional + coordinates respectively. + """ + _, _, _, _ = batch_num_graphs, batch_angles_step, batch_num_atoms_step, batch_mask + times = ops.rand(batch_lengths.shape[0]) + time_emb = self.time_embedding(times) + + lattice_polar = batch_lattice_polar + frac_coords = batch_frac_coords + + lattice_polar_0 = self.lattice_model.sample(batch_lengths.shape[0], self.sigma) + frac_coords_0 = ops.rand_like(frac_coords) + + tar_l = lattice_polar - lattice_polar_0 + tar_f = (frac_coords - frac_coords_0 - 0.5) % 1 - 0.5 + + tar_f = ops.mul(tar_f, ops.reshape(node_mask, (-1, 1))) + + l_expand_dim = (slice(None),) + (None,) * (tar_l.dim() - 1) # in this case is (:, None, None) + input_lattice = lattice_polar_0 + times[l_expand_dim] * tar_l + input_frac_coords = frac_coords_0 + self.lift_node(times[:, None], batch_node2graph) * tar_f + + + + #flow + pred_l, pred_f = self.decoder(time_emb, batch_atom_types, + input_frac_coords, input_lattice, + batch_node2graph, batch_edge_index, + node_mask, edge_mask) + + + + return pred_l, tar_l, pred_f, tar_f + + #sample and evaluate + + def get_anneal_factor(self, t, slope: float = 0.0, offset: float = 0.0): + if not isinstance(t, ms.Tensor): + t = ms.tensor(t) + return 1 + slope * self.relu(t - offset) + + def post_decoder_on_sample( + self, pred, t, + anneal_slope=0.0, anneal_offset=0.0, + ): + """apply anneal to pred_f""" + + pred_l, pred_f = pred + anneal_factor = self.get_anneal_factor(t, anneal_slope, anneal_offset) + + pred_f *= anneal_factor + + return pred_l, pred_f + + def sample(self, + batch_node2graph, + node_mask, + edge_mask, + batch_mask, + batch_atom_types, + batch_edge_index, + batch_num_atoms, + n=1000, + anneal_slope=0.0, + anneal_offset=0.0): + """Generation process of flow. Note: For simplicity, we use x instead of frac_coords and + l instead of lattice. + + Args: + batch_atom_types (Tensor): Atom types of nodes in a batch of graph. Shape: (num_atoms,) + batch_node2graph (Tensor): Graph index for each node. Shape: (num_atoms,) + batch_edge_index (Tensor): Beginning and ending node index for each edge. + Shape: (2, num_edges) + node_mask (Tensor): Node mask for padded tensor. Shape: (num_atoms,) + edge_mask (Tensor): Edge mask for padded tensor. Shape: (num_edges,) + batch_mask (Tensor): Graph mask for padded tensor. Shape: (batchsize,) + N (int): the steps of flow + anneal_slope(float): + anneal_offset(float): + Returns: + Tuple(dict, Tensor, Tensor): Return the traj of flow process, the fractional coordinates and + generated lattice matrix for the input atom types of each crystal. 
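+
+        The sampler integrates the learned vector field with ``n`` forward-Euler
+        steps; one update (matching the loop body below) is, schematically:
+
+            l_t <- l_t + pred_l / n
+            x_t <- (x_t + pred_f / n) % 1.0   # wrap back into the unit cell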
+ """ + batch_size_pad = batch_mask.shape[0] + num_node_pad = node_mask.shape[0] #shape: (2819,) where 2819 is the largest numbers of atoms in evry batches + + l_0 = self.lattice_model.sample(batch_size_pad, self.sigma) + x_0 = ops.UniformReal()((num_node_pad, 3)) % 1.0 + + + l_t = l_0 + x_t = x_0 + l_mat_t = LatticePolarDecomp().build(l_t) + traj = { + 0: { + 'num_atoms': batch_num_atoms, + 'atom_types': batch_atom_types, + 'frac_coords': x_t, + 'lattices': l_mat_t, + } + } + for t in range(1, n+1): + t_stamp = t / n + times = ops.Fill()(ms.float32, (batch_size_pad,), t_stamp) + time_emb = self.time_embedding(times) + # ========= pred each step start ========= + + pred = self.decoder(time_emb, batch_atom_types, x_t, l_t, + batch_node2graph, batch_edge_index, + node_mask, edge_mask) + + pred = self.post_decoder_on_sample( + pred, + t=t_stamp, + anneal_slope=anneal_slope, anneal_offset=anneal_offset, + ) + + pred_l, pred_f = pred + + # ========= pred each step end ========= + + + # ========= update each step start ========= + + l_t = l_t + pred_l / n + x_t = x_t + pred_f / n + x_t = x_t % 1.0 + + # ========= update each step end ========= + + # ========= build trajectory start ========= + l_mat_t = LatticePolarDecomp().build(l_t) + traj[t] = { + t: { + 'num_atoms': batch_num_atoms, + 'atom_types': batch_atom_types, + 'frac_coords': x_t, + 'lattices': l_mat_t, + } + } + + return traj, x_t, l_mat_t diff --git a/MindChemistry/applications/crystalflow/models/flow_condition.py b/MindChemistry/applications/crystalflow/models/flow_condition.py new file mode 100644 index 000000000..938012b7c --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/flow_condition.py @@ -0,0 +1,295 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""flow file""" +import math + +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import nn, ops +from mindchemistry.graph.graph import (AggregateNodeToGlobal, LiftGlobalToNode) + +from models.lattice import LatticePolarDecomp +import numpy as np + + +def replace_nan_with_zero(tensor): + """Replace nan in tensor with zero to avoid numerical errors. + """ + is_nan = ops.IsNan()(tensor) + zeros = ops.Fill()(ms.float32, tensor.shape, 0.0) + result = ops.Select()(is_nan, zeros, tensor) + return result + + +class SinusoidalTimeEmbeddings(nn.Cell): + """ Embedding for the time step in flow. + Referring the implementation details in the paper Attention is all you need. 
""" + + def __init__(self, dim): + super(SinusoidalTimeEmbeddings, self).__init__() + self.dim = dim + + def construct(self, time): + """construct + + Args: + time (Tensor): flow time step + + Returns: + Tensor: Time embedding + """ + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1) + embeddings = ops.Exp()(mnp.arange(half_dim) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = ops.Concat(axis=-1)( + (ops.Sin()(embeddings), ops.Cos()(embeddings))) + return embeddings + + +def lattice_params_to_matrix_mindspore(lengths, angles): + """Batched MindSpore version to compute lattice matrix from params. + + Args: + lengths (Tensor): Tensor of shape (N, 3), unit A + angles (Tensor):: Tensor of shape (N, 3), unit degree + Returns: + Tensor: Tensor of shape (N, 3, 3) + """ + angles_r = ops.deg2rad(angles) + coses = ops.cos(angles_r) + sins = ops.sin(angles_r) + + val = (coses[:, 0] * coses[:, 1] - coses[:, 2]) / (sins[:, 0] * sins[:, 1]) + # Sometimes rounding errors result in values slightly > 1. + val = ops.clip_by_value(val, -1., 1.) + gamma_star = ops.acos(val) + + zero_tensor = ops.zeros((lengths.shape[0],)) + + vector_a = ops.stack( + [lengths[:, 0] * sins[:, 1], zero_tensor, lengths[:, 0] * coses[:, 1]], + axis=1) + + vector_b = ops.stack([ + -lengths[:, 1] * sins[:, 0] * ops.cos(gamma_star), lengths[:, 1] * + sins[:, 0] * ops.sin(gamma_star), lengths[:, 1] * coses[:, 0] + ], + axis=1) + + vector_c = ops.stack([zero_tensor, zero_tensor, lengths[:, 2]], axis=1) + + return ops.stack([vector_a, vector_b, vector_c], axis=1) + + +class CSPFlow(nn.Cell): + """Flow model used in CrystalFlow + """ + + def __init__(self, + decoder, + cond_emb_model, + time_dim=256, + sigma=0.1): + """Initialization + + Args: + decoder (nn.cell): Nerual network as vector field for flow. + cond_emb_model (nn.cell): Neural network for creating condition embedding vector. + time_dim (int): The dimension of time embedding. Defaults to 256. + sigma (float): the standard deviation of Gaussian prior where lattice_polar_0 is sampled + """ + super(CSPFlow, self).__init__() + + self.time_dim = time_dim + self.time_embedding = SinusoidalTimeEmbeddings(self.time_dim) + self.lattice_model = LatticePolarDecomp() + self.lift_node = LiftGlobalToNode() + self.aggre_graph = AggregateNodeToGlobal('mean') + self.decoder = decoder + self.sigma = sigma + self.relu = nn.ReLU() + self.cond_emb = cond_emb_model + + def construct(self, batch_num_graphs, batch_atom_types, batch_lengths, + batch_lattice_polar, batch_frac_coords, batch_node2graph, + batch_edge_index, node_mask, edge_mask, condition): + """Training process for diffution. + + Args: + batch_num_graphs (Tensor): Batch size with shape (1,) + batch_atom_types (Tensor): Atom types of nodes in a batch of graph. Shape: (num_atoms,) + batch_lengths (Tensor): Lattices lengths in a batch of graph. Shape: (batchsize, 3) + batch_angles (Tensor): Lattice angles in a batch of graph. Shape: (batchsize, 3) + batch_lattice_polar (Tensor): lattice of polar-decomposition representation. Shape: (batchsize, 6) + batch_frac_coords (Tensor): Fractional coordinates of nodes in + a batch of graph. (num_atoms, 3) + batch_node2graph (Tensor): Graph index for each node. Shape: (num_atoms,) + batch_edge_index (Tensor): Beginning and ending node index for each edge. + Shape: (2, num_edges) + node_mask (Tensor): Node mask for padded tensor. Shape: (num_atoms,) + edge_mask (Tensor): Edge mask for padded tensor. 
Shape: (num_edges,) + batch_mask (Tensor): Graph mask for padded tensor. Shape: (batchsize,) + condition (Tensor): Condition variable. Must match with cond_emb_model. Shape: (batchsize, 1) + + Returns: + Tuple(Tensor, Tensor, Tensor, Tensor): Return the ground truth + and predicted flow terms of lattice polar and fractional + coordinates respectively. + """ + _ = batch_num_graphs + #times = ops.rand(batch_lengths.shape[0]) + times = np.random.rand(batch_lengths.shape[0]) + times = ms.tensor(times, dtype=ms.float32) + time_emb = self.time_embedding(times) + cemb = self.cond_emb(condition) + + lattice_polar = batch_lattice_polar + frac_coords = batch_frac_coords + + #lattice_polar_0 = self.lattice_model.sample(batch_lengths.shape[0], self.sigma) + #frac_coords_0 = ops.rand_like(frac_coords) + lattice_polar_0 = self.lattice_model.sample_numpy(batch_lengths.shape[0], self.sigma) + frac_coords_0 = self.lattice_model.rand_like_numpy(frac_coords) + + tar_l = lattice_polar - lattice_polar_0 + tar_f = (frac_coords - frac_coords_0 - 0.5) % 1 - 0.5 + + tar_f = ops.mul(tar_f, ops.reshape(node_mask, (-1, 1))) + + l_expand_dim = (slice(None),) + (None,) * (tar_l.dim() - 1) # in this case is (:, None, None) + input_lattice = lattice_polar_0 + times[l_expand_dim] * tar_l + input_frac_coords = frac_coords_0 + self.lift_node(times[:, None], batch_node2graph) * tar_f + + + + #flow + pred_l, pred_f = self.decoder(time_emb, batch_atom_types, + input_frac_coords, input_lattice, + batch_node2graph, batch_edge_index, + node_mask, edge_mask, cemb) + + + + return pred_l, tar_l, pred_f, tar_f + + #sample and evaluate + + def get_anneal_factor(self, t, slope: float = 0.0, offset: float = 0.0): + if not isinstance(t, ms.Tensor): + t = ms.tensor(t) + return 1 + slope * self.relu(t - offset) + + def post_decoder_on_sample( + self, pred, t, + anneal_slope=0.0, anneal_offset=0.0, + ): + """apply anneal to pred_f""" + + pred_l, pred_f = pred + anneal_factor = self.get_anneal_factor(t, anneal_slope, anneal_offset) + + pred_f *= anneal_factor + + return pred_l, pred_f + + def sample(self, + batch_node2graph, + node_mask, + edge_mask, + batch_mask, + batch_atom_types, + batch_edge_index, + batch_num_atoms, + n=1000, + anneal_slope=0.0, + anneal_offset=0.0): + """Generation process of diffution. Note: For simplicity, we use x instead of frac_coords and + l instead of lattice. + + Args: + batch_atom_types (Tensor): Atom types of nodes in a batch of graph. Shape: (num_atoms,) + batch_node2graph (Tensor): Graph index for each node. Shape: (num_atoms,) + batch_edge_index (Tensor): Beginning and ending node index for each edge. + Shape: (2, num_edges) + node_mask (Tensor): Node mask for padded tensor. Shape: (num_atoms,) + edge_mask (Tensor): Edge mask for padded tensor. Shape: (num_edges,) + batch_mask (Tensor): Graph mask for padded tensor. Shape: (batchsize,) + N (int): the steps of flow + anneal_slope(float): + anneal_offset(float): + Returns: + Tuple(dict, Tensor, Tensor): Return the traj of flow process, the fractional coordinates and + generated lattice matrix for the input atom types of each crystal. 
+ """ + batch_size_pad = batch_mask.shape[0] + num_node_pad = node_mask.shape[0] #shape: (2819,) where 2819 is the largest numbers of atoms in evry batches + + l_0 = self.lattice_model.sample(batch_size_pad, self.sigma) + x_0 = ops.UniformReal()((num_node_pad, 3)) % 1.0 + + + l_t = l_0 + x_t = x_0 + l_mat_t = LatticePolarDecomp().build(l_t) + traj = { + 0: { + 'num_atoms': batch_num_atoms, + 'atom_types': batch_atom_types, + 'frac_coords': x_t, + 'lattices': l_mat_t, + } + } + for t in range(1, n+1): + t_stamp = t / n + times = ops.Fill()(ms.float32, (batch_size_pad,), t_stamp) + time_emb = self.time_embedding(times) + # ========= pred each step start ========= + + pred = self.decoder(time_emb, batch_atom_types, x_t, l_t, + batch_node2graph, batch_edge_index, + node_mask, edge_mask) + + pred = self.post_decoder_on_sample( + pred, + t=t_stamp, + anneal_slope=anneal_slope, anneal_offset=anneal_offset, + ) + + pred_l, pred_f = pred + + # ========= pred each step end ========= + + + # ========= update each step start ========= + + l_t = l_t + pred_l / n + x_t = x_t + pred_f / n + x_t = x_t % 1.0 + + # ========= update each step end ========= + + # ========= build trajectory start ========= + l_mat_t = LatticePolarDecomp().build(l_t) + traj[t] = { + t: { + 'num_atoms': batch_num_atoms, + 'atom_types': batch_atom_types, + 'frac_coords': x_t, + 'lattices': l_mat_t, + } + } + + return traj, x_t, l_mat_t diff --git a/MindChemistry/applications/crystalflow/models/infer_utils.py b/MindChemistry/applications/crystalflow/models/infer_utils.py new file mode 100644 index 000000000..7ff160955 --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/infer_utils.py @@ -0,0 +1,95 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""inference utils file""" +import mindspore.ops as ops +import mindspore.numpy as np + +chemical_symbols = [ + # 0 + 'X', + # 1 + 'H', 'He', + # 2 + 'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne', + # 3 + 'Na', 'Mg', 'Al', 'Si', 'P', 'S', 'Cl', 'Ar', + # 4 + 'K', 'Ca', 'Sc', 'Ti', 'V', 'Cr', 'Mn', 'Fe', 'Co', 'Ni', 'Cu', 'Zn', + 'Ga', 'Ge', 'As', 'Se', 'Br', 'Kr', + # 5 + 'Rb', 'Sr', 'Y', 'Zr', 'Nb', 'Mo', 'Tc', 'Ru', 'Rh', 'Pd', 'Ag', 'Cd', + 'In', 'Sn', 'Sb', 'Te', 'I', 'Xe', + # 6 + 'Cs', 'Ba', 'La', 'Ce', 'Pr', 'Nd', 'Pm', 'Sm', 'Eu', 'Gd', 'Tb', 'Dy', + 'Ho', 'Er', 'Tm', 'Yb', 'Lu', + 'Hf', 'Ta', 'W', 'Re', 'Os', 'Ir', 'Pt', 'Au', 'Hg', 'Tl', 'Pb', 'Bi', + 'Po', 'At', 'Rn', + # 7 + 'Fr', 'Ra', 'Ac', 'Th', 'Pa', 'U', 'Np', 'Pu', 'Am', 'Cm', 'Bk', + 'Cf', 'Es', 'Fm', 'Md', 'No', 'Lr', + 'Rf', 'Db', 'Sg', 'Bh', 'Hs', 'Mt', 'Ds', 'Rg', 'Cn', 'Nh', 'Fl', 'Mc', + 'Lv', 'Ts', 'Og'] + +def count_consecutive_occurrences(lst): + """ + Return the number of consecutive occurrences of each digit in the list. + + Args: + lst (list): The input list + + Returns: + list: List of numbers of consecutive occurrences of each digit in the list. 
+ """ + if not lst: + return [] + + counts = [] + current_count = 1 + + for i in range(1, len(lst)): + if lst[i] == lst[i - 1]: + current_count += 1 + else: + counts.append(current_count) + current_count = 1 + + counts.append(current_count) + + return counts + +def lattices_to_params_ms(lattices): + """Batched MindSpore version to compute lattice params from matrix. + + Args: + lattices (Tensor): Tensor of shape (N, 3, 3) + Returns: + lengths (Tensor): Tensor of shape (N, 3), unit A + angles (Tensor):: Tensor of shape (N, 3), unit degree + """ + + lengths = ops.sqrt(ops.reduce_sum(lattices ** 2, -1)) + + angles = ops.zeros_like(lengths) + + for i in range(3): + j = (i + 1) % 3 + k = (i + 2) % 3 + + cos_angle = ops.clamp(ops.reduce_sum(lattices[..., j, :] * lattices[..., k, :], -1) / + (lengths[..., j] * lengths[..., k]), -1.0, 1.0) + + angles[..., i] = ops.acos(cos_angle) * 180.0 / np.pi + + return lengths, angles diff --git a/MindChemistry/applications/crystalflow/models/lattice.py b/MindChemistry/applications/crystalflow/models/lattice.py new file mode 100644 index 000000000..8d237947b --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/lattice.py @@ -0,0 +1,61 @@ +"""lattice transform utils""" +import numpy as np +import mindspore as ms +from mindspore import nn, ops + + +def matrix_exp(m, n=20): # m: (B,3,3) + s = current_m = ops.eye(3, dtype=m.dtype)[None, :, :] + for i in range(1, n + 1): + current_m = (current_m @ m) / i # m^n/n! + s += current_m + return s + + +class LatticePolarDecomp(nn.Cell): + """class for transformation between lattice and lattice_polar""" + def decompose(self, matrix): # matrix as row vectors + """transform lattice to lattice_polar""" + a, u = ops.eig(matrix @ matrix.swapaxes(-1, -2)) + a, u = a.real(), u.real() + s = u @ (ops.diag_embed(a.log()) / 2) @ u.swapaxes(-1, -2) + + k0 = s[:, 0, 1] + k1 = s[:, 0, 2] + k2 = s[:, 1, 2] + k3 = (s[:, 0, 0] - s[:, 1, 1]) / 2 + k4 = (s[:, 0, 0] + s[:, 1, 1] - 2 * s[:, 2, 2]) / 6 + k5 = (s[:, 0, 0] + s[:, 1, 1] + s[:, 2, 2]) / 3 + k = ops.vstack([k0, k1, k2, k3, k4, k5]).swapaxes(-1, -2) + return k + + def build(self, vector): + k = vector + s0 = ops.stack([k[:, 3] + k[:, 4] + k[:, 5], k[:, 0], k[:, 1]], 1) # (B, 3) + s1 = ops.stack([k[:, 0], -k[:, 3] + k[:, 4] + k[:, 5], k[:, 2]], 1) # (B, 3) + s2 = ops.stack([k[:, 1], k[:, 2], -2 * k[:, 4] + k[:, 5]], 1) # (B, 3) + s = ops.stack([s0, s1, s2], 1) # (B, 3, 3) + exp_s = matrix_exp(s) # (B, 3, 3) + return exp_s + + def sample(self, batch_size, sigma, dtype=None): + v = ops.randn([batch_size, 6], dtype=dtype) * sigma + v[:, -1] = v[:, -1] + 1 + return v + + def sample_like(self, vector, sigma): + v = ops.randn_like(vector) * sigma + v[:, -1] = v[:, -1] + 1 + return v + + def sample_numpy(self, batch_size, sigma, dtype=ms.float32): + v = np.random.randn(batch_size, 6) * sigma + v[:, -1] = v[:, -1] + 1 + v = ms.Tensor(v, dtype=dtype) + return v + + def rand_like_numpy(self, vector, dtype=ms.float32): + #the numpy version of ops.rand_like + v = np.random.rand(*vector.shape) + v = ms.Tensor(v, dtype=dtype) + return v diff --git a/MindChemistry/applications/crystalflow/models/train_utils.py b/MindChemistry/applications/crystalflow/models/train_utils.py new file mode 100644 index 000000000..d0e0f3658 --- /dev/null +++ b/MindChemistry/applications/crystalflow/models/train_utils.py @@ -0,0 +1,117 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with 
the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""training utils file""" + +import math +from mindspore import ops, nn, Parameter + +class LossRecord: + """LossRecord""" + + def __init__(self): + self.last_val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + self.reset() + + def reset(self): + """reset""" + self.last_val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, num=1): + """update""" + self.last_val = val + self.sum += val * num + self.count += num + self.avg = self.sum / self.count + +class RBFExpansion(nn.Cell): + """Expand interatomic distances with radial basis functions.""" + + def __init__( + self, + vmin=0, + vmax=8, + bins=40, + lengthscale=None, + ): + """Register torch parameters for RBF expansion.""" + super(RBFExpansion, self).__init__() + self.vmin = vmin + self.vmax = vmax + self.bins = bins + self.centers = Parameter(ops.linspace(self.vmin, self.vmax, self.bins), name="centers", requires_grad=False) + if lengthscale is None: + # SchNet-style + # set lengthscales relative to granularity of RBF expansion + self.lengthscale = ops.mean(ops.diff(self.centers)) + self.gamma = -1 / self.lengthscale + else: + self.lengthscale = lengthscale + self.gamma = -1 / (lengthscale ** 2) + + def construct(self, distance): + """Apply RBF expansion to interatomic distance tensor.""" + tmp1 = ops.unsqueeze(distance, dim=1) + tmp2 = tmp1 - self.centers + tmp3 = tmp2 ** 2 + tmp4 = self.gamma * tmp3 + res = ops.exp(tmp4) + return res + +class OneCycleLr(): + """one cycle learning rate scheduler""" + + def __init__(self, max_lr, steps_per_epoch, epochs, optimizer, pct_start=0.3, anneal_strategy="cos", + div_factor=25.0, final_div_factor=10000.0): + """init""" + self.max_lr = max_lr + self.steps_per_epoch = steps_per_epoch + self.epochs = epochs + self.optimizer = optimizer + self.pct_start = pct_start + self.anneal_strategy = anneal_strategy + self.div_factor = div_factor + self.final_div_factor = final_div_factor + self.current_step = 0 + + self.initial_lr = self.max_lr / self.div_factor + self.min_lr = self.initial_lr / self.final_div_factor + self.steps = self.steps_per_epoch * self.epochs + self.step_size_up = float(self.pct_start * self.steps) - 1 + self.step_size_down = float(2 * self.pct_start * self.steps) - 2 + self.step_size_end = float(self.steps) - 1 + + self.step() + + def _annealing_cos(self, start, end, pct): + """annealing cosin""" + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + def step(self): + """step""" + if self.current_step <= self.step_size_up: + lr = self._annealing_cos(self.initial_lr, self.max_lr, self.current_step / self.step_size_up) + else: + lr = self._annealing_cos(self.max_lr, self.min_lr, + (self.current_step - self.step_size_up) / (self.step_size_end - self.step_size_up)) + self.current_step = self.current_step + 1 + ### for AdamWeightDecay + self.optimizer.learning_rate.set_data(lr) diff --git a/MindChemistry/applications/crystalflow/requirement.txt b/MindChemistry/applications/crystalflow/requirement.txt new file mode 100644 index 
000000000..ff917ee59 --- /dev/null +++ b/MindChemistry/applications/crystalflow/requirement.txt @@ -0,0 +1,7 @@ +pymatgen==2023.8.10 +pandas +scikit-learn +p_tqdm +matminer==0.7.3 +smact==2.2.1 +ruamel.yaml==0.17.19 diff --git a/MindChemistry/applications/crystalflow/test_crystalflow.py b/MindChemistry/applications/crystalflow/test_crystalflow.py new file mode 100644 index 000000000..ffce4881f --- /dev/null +++ b/MindChemistry/applications/crystalflow/test_crystalflow.py @@ -0,0 +1,191 @@ +"""model test""" +import math +import os + +import mindspore as ms +import mindspore.numpy as mnp +from mindspore import nn, ops, Tensor, mint, load_checkpoint, load_param_into_net +from mindchemistry.graph.loss import L2LossMask +import numpy as np + + +from models.cspnet import CSPNet +from models.flow import CSPFlow +from data.dataset import fullconnect_dataset +from data.crysloader import Crysloader as DataLoader + + +ms.set_seed(1234) +np.random.seed(1234) + +class SinusoidalTimeEmbeddings(nn.Cell): + """time embedding""" + def __init__(self, dim): + super(SinusoidalTimeEmbeddings, self).__init__() + self.dim = dim + + def construct(self, time): + half_dim = self.dim // 2 + embeddings = math.log(10000) / (half_dim - 1) + embeddings = ops.Exp()(mnp.arange(half_dim) * -embeddings) + embeddings = time[:, None] * embeddings[None, :] + embeddings = ops.Concat(axis=-1)( + (ops.Sin()(embeddings), ops.Cos()(embeddings))) + return embeddings + +def test_cspnet(): + """test cspnet.py""" + ms.set_seed(1234) + time_embedding = SinusoidalTimeEmbeddings(256) + cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=128) + atom_types = Tensor([61, 12, 52, 52, 46, 46], dtype=ms.int32) + frac_coords = Tensor( + [[5.00000000e-01, 5.00000000e-01, 5.00000000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [6.66666687e-01, 3.33333343e-01, 7.50000000e-01], + [3.33333343e-01, 6.66666687e-01, 2.50000000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [0.00000000e+00, 0.00000000e+00, 5.00000000e-01]], dtype=ms.float32) + lengths = Tensor( + [[3.86215806e+00, 3.86215806e+00, 3.86215806e+00], + [4.21191406e+00, 4.21191454e+00, 5.75016499e+00]], dtype=ms.float32) + lattice_polar = Tensor( + [[0.00000000e+00, 0.00000000e+00, 3.97458431e-1, 5.55111512e-16, 0.00000000e+00, 1.35122609e+00], + [-2.74653047e-01, 1.58676151e-16, 6.82046943e-17, -5.38849108e-08, -1.27743945e-01, 1.49374068e+00]], + dtype=ms.float32) + edge_index = Tensor( + [[0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5], + [0, 1, 0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5]], dtype=ms.int32) + node2graph = Tensor([0, 0, 1, 1, 1, 1], dtype=ms.int32) + node_mask = Tensor([1, 1, 1, 1, 1, 1], dtype=ms.int32) + edge_mask = Tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=ms.int32,) + tar_lat_polar = Tensor( + [[-0.5366, 0.5920, 0.2546, 0.4013, -0.0032, 0.6611], + [-0.5696, 0.6870, 0.2512, 0.4647, 0.0228, 0.5979]] + ) + tar_coord = Tensor([[-0.7573, 0.2272, -0.4823], + [-0.7647, 0.2261, -0.4763], + [-0.7841, 0.2948, -0.3861], + [-0.7872, 0.2915, -0.3810], + [-0.7789, 0.2759, -0.4070], + [-0.7785, 0.2757, -0.4070]]) + + np.random.seed(1234) + times = np.random.rand(lengths.shape[0]) + times = ms.tensor(times, dtype=ms.float32) + t = time_embedding(times) + lattices_out, coords_out = cspnet(t, atom_types, frac_coords, lattice_polar, node2graph,\ + edge_index, node_mask, edge_mask) + assert mint.isclose(lattices_out, tar_lat_polar, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be 
{tar_lat_polar}, but got {lattices_out}." + assert mint.isclose(coords_out, tar_coord, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {tar_coord}, but got {coords_out}." + +def test_flow(): + """test flow.py""" + ms.set_seed(1234) + cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=128) + cspflow = CSPFlow(cspnet) + atom_types = Tensor([61, 12, 52, 52, 46, 46], dtype=ms.int32) + frac_coords = Tensor( + [[5.00000000e-01, 5.00000000e-01, 5.00000000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [6.66666687e-01, 3.33333343e-01, 7.50000000e-01], + [3.33333343e-01, 6.66666687e-01, 2.50000000e-01], + [0.00000000e+00, 0.00000000e+00, 0.00000000e+00], + [0.00000000e+00, 0.00000000e+00, 5.00000000e-01]], dtype=ms.float32) + lengths = Tensor( + [[3.86215806e+00, 3.86215806e+00, 3.86215806e+00], + [4.21191406e+00, 4.21191454e+00, 5.75016499e+00]], dtype=ms.float32) + angles = Tensor( + [[9.00000000e+01, 9.00000000e+01, 9.00000000e+01], + [9.00000000e+01, 9.00000000e+01, 1.20000000e+02]], dtype=ms.float32) + lattice_polar = Tensor( + [[0.00000000e+00, 0.00000000e+00, 3.97458431e-1, 5.55111512e-16, 0.00000000e+00, 1.35122609e+00], + [-2.74653047e-01, 1.58676151e-16, 6.82046943e-17, -5.38849108e-08, -1.27743945e-01, 1.49374068e+00]], \ + dtype=ms.float32) + num_atoms = Tensor([2, 4], dtype=ms.int32) + edge_index = Tensor( + [[0, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5], + [0, 1, 0, 1, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5, 2, 3, 4, 5]], dtype=ms.int32) + node2graph = Tensor([0, 0, 1, 1, 1, 1], dtype=ms.int32) + node_mask = Tensor([1, 1, 1, 1, 1, 1], dtype=ms.int32) + edge_mask = Tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], dtype=ms.int32,) + batch_size_mask = Tensor([1, 1], dtype=ms.int32) + + pred_l, tar_l, pred_f, tar_f = cspflow(atom_types, atom_types, lengths, + angles, lattice_polar, num_atoms, frac_coords, node2graph, + edge_index, node_mask, edge_mask, batch_size_mask) + out_pred_l = Tensor([[-0.54417396, 0.6183988, 0.25345746, 0.41497535, -0.00219233, 0.6622897], + [-0.5647707, 0.68243337, 0.25912297, 0.45234668, 0.01847154, 0.6095263]]) + out_tar_l = Tensor([[-0.02254689, 0.04679973, 0.3856261, -0.08269336, -0.08592724, 0.45530552], + [-0.53301036, -0.21067567, -0.05119152, 0.04148455, -0.0907657, 0.3682214]]) + out_pred_f = Tensor([[-0.7662705, 0.24618103, -0.4741043], + [-0.77218896, 0.2367004, -0.4617761], + [-0.7825796, 0.28697833, -0.38660413], + [-0.7888657, 0.2943602, -0.39356205], + [-0.7792929, 0.26879176, -0.42642403], + [-0.77509487, 0.2633396, -0.41789246]]) + out_tar_f = Tensor([[0.20181239, -0.07186192, -0.40746307], + [-0.4028666, 0.18524933, 0.14020872], + [-0.31370556, 0.08878523, -0.18229586], + [-0.15636778, -0.44619012, 0.13355094], + [-0.03352255, -0.15093482, -0.13720155], + [-0.2018686, 0.07621789, -0.4946221]]) + assert mint.isclose(pred_l, out_pred_l, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {pred_l}, but got {out_pred_l}." + assert mint.isclose(pred_f, out_pred_f, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {pred_f}, but got {out_pred_f}." + assert mint.isclose(tar_l, out_tar_l, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {tar_l}, but got {out_tar_l}." + assert mint.isclose(tar_f, out_tar_f, rtol=1e-4, atol=1e-4).all(), \ + f"For `cspnet`, the output should be {tar_f}, but got {out_tar_f}." 
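+
+# A standalone sketch (not part of the original test suite) of the
+# conditional-flow-matching targets that CSPFlow.construct builds and that
+# test_flow checks above; shapes and the rng seed here are illustrative only.
+def _demo_flow_matching_targets():
+    """Illustrate how tar_l / tar_f are constructed; uses numpy only."""
+    rng = np.random.default_rng(0)
+    lattice_polar_1 = rng.standard_normal((2, 6))        # data lattice (t = 1)
+    lattice_polar_0 = 0.1 * rng.standard_normal((2, 6))  # Gaussian prior sample
+    lattice_polar_0[:, -1] += 1                          # prior mean, as in LatticePolarDecomp.sample
+    frac_1 = rng.random((6, 3))                          # data fractional coords
+    frac_0 = rng.random((6, 3))                          # uniform prior coords
+
+    tar_l = lattice_polar_1 - lattice_polar_0            # straight-line target
+    tar_f = (frac_1 - frac_0 - 0.5) % 1 - 0.5            # shortest periodic path
+
+    t = rng.random()
+    input_lattice = lattice_polar_0 + t * tar_l          # point on the path fed to the decoder
+    input_frac = frac_0 + t * tar_f
+    return input_lattice, input_frac, tar_l, tar_f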
+ +def test_loss(): + """test loss""" + ms.set_context(device_target="CPU") + ckpt_dir = "./ckpt/mp_20" + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) + + ms.set_seed(1234) + batch_size_max = 256 + + cspnet = CSPNet(num_layers=6, hidden_dim=512, num_freqs=256) + cspflow = CSPFlow(cspnet) + mindspore_ckpt = load_checkpoint("./torch2ms_ckpt/ms_flow.ckpt") + load_param_into_net(cspflow, mindspore_ckpt) + + loss_func_mse = L2LossMask(reduction='mean') + def forward(atom_types_step, frac_coords_step, _, lengths_step, angles_step, lattice_polar_step, \ + num_atoms_step, edge_index_step, batch_node2graph, \ + node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid): + pred_l, tar_l, pred_x, tar_x = cspflow(batch_size_valid, atom_types_step, lengths_step, + angles_step, lattice_polar_step, num_atoms_step, + frac_coords_step, batch_node2graph, edge_index_step, + node_mask_step, edge_mask_step, batch_mask) + mseloss_l = loss_func_mse(pred_l, tar_l, mask=batch_mask, num=batch_size_valid) + mseloss_x = loss_func_mse(pred_x, tar_x, mask=node_mask_step, num=node_num_valid) + mseloss = mseloss_l + 10 * mseloss_x + + return mseloss, mseloss_l, mseloss_x + + train_datatset = fullconnect_dataset(name="mp_20", path='./dataset/mp_20/train.csv', + save_path='./dataset/mp_20/train.npy') + train_loader = DataLoader(batch_size_max, *train_datatset, shuffle_dataset=False) + + for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, \ + angles_batch, lattice_polar_batch, num_atoms_batch,\ + edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\ + node_num_valid_, batch_size_valid_ in train_loader: + + result = forward(atom_types_batch, frac_coords_batch, property_batch, + lengths_batch, angles_batch, lattice_polar_batch, + num_atoms_batch, edge_index_batch, batch_node2graph_, + node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_, + batch_size_valid_) + + _, mseloss_l, mseloss_x = result + break + assert mseloss_l <= 0.7, "The denoising of lattice accuracy is not successful." + assert mseloss_x <= 0.7, "The denoising of fractional coordinates accuracy is not successful." diff --git a/MindChemistry/applications/crystalflow/train.py b/MindChemistry/applications/crystalflow/train.py new file mode 100644 index 000000000..8458c4b20 --- /dev/null +++ b/MindChemistry/applications/crystalflow/train.py @@ -0,0 +1,218 @@ +# Copyright 2024 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""train file""" +import os +import time +import logging +import argparse +import yaml +import numpy as np +import mindspore as ms +from mindspore import nn, set_seed +from mindspore.amp import all_finite +from mindchemistry.graph.loss import L2LossMask + +from models.cspnet import CSPNet +from models.flow import CSPFlow +from models.train_utils import LossRecord +from data.dataset import fullconnect_dataset +from data.crysloader import Crysloader as DataLoader + +logging.basicConfig(level=logging.INFO) + +def parse_args(): + '''Parse input args''' + parser = argparse.ArgumentParser() + parser.add_argument('--config', default='config.yaml', help="The config file path") + parser.add_argument('--device_id', type=int, default=0, + help="ID of the target device") + parser.add_argument('--device_target', type=str, default='CPU', choices=["GPU", "Ascend", "CPU"], + help="The target device to run, support 'Ascend', 'GPU'") + input_args = parser.parse_args() + return input_args + +if __name__ == '__main__': + args = parse_args() + ms.set_context(device_target=args.device_target, device_id=args.device_id) + + with open(args.config, 'r') as stream: + config = yaml.safe_load(stream) + + ckpt_dir = config['train']["ckpt_dir"] + + if not os.path.exists(ckpt_dir): + os.makedirs(ckpt_dir) + + set_seed(config['train']["seed"]) + + batch_size_max = config['train']['batch_size'] + + cspnet = CSPNet(num_layers=config['model']['num_layers'], hidden_dim=config['model']['hidden_dim'], + num_freqs=config['model']['num_freqs']) + cost_coord = config['train']['cost_coord'] + cost_lattice = config['train']['cost_lattice'] + + if os.path.exists(config['checkpoint']['last_path']): + logging.info("load from existing check point................") + param_dict = ms.load_checkpoint(config['checkpoint']['last_path']) + ms.load_param_into_net(cspnet, param_dict) + logging.info("finish load from existing checkpoint") + else: + logging.info("Starting new training process") + + cspflow = CSPFlow(cspnet) + + model_parameters = filter(lambda p: p.requires_grad, cspflow.get_parameters()) + params = sum(np.prod(p.shape) for p in model_parameters) + logging.info("The model you built has %s parameters.", params) + + optimizer = nn.Adam(params=cspflow.trainable_params()) + loss_func_mse = L2LossMask(reduction='mean') + + def forward(atom_types_step, frac_coords_step, _, lengths_step, angles_step, \ + lattice_polar_step, num_atoms_step, edge_index_step, batch_node2graph, \ + node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid): + """forward""" + pred_l, tar_l, pred_x, tar_x = cspflow(batch_size_valid, atom_types_step, lengths_step, \ + angles_step, lattice_polar_step, num_atoms_step, \ + frac_coords_step, batch_node2graph, edge_index_step, \ + node_mask_step, edge_mask_step, batch_mask) + mseloss_l = loss_func_mse(pred_l, tar_l, mask=batch_mask, num=batch_size_valid) + mseloss_x = loss_func_mse(pred_x, tar_x, mask=node_mask_step, num=node_num_valid) + mseloss = cost_lattice * mseloss_l + cost_coord * mseloss_x + + return mseloss, mseloss_l, mseloss_x + + backward = ms.value_and_grad(forward, None, weights=cspflow.trainable_params(), has_aux=True) + + @ms.jit + def train_step(atom_types_step, frac_coords_step, property_step, lengths_step, angles_step, + lattice_polar_step, num_atoms_step, + edge_index_step, batch_node2graph, node_mask_step, edge_mask_step, batch_mask, + node_num_valid, batch_size_valid): + """train step""" + 
+        (mseloss, mseloss_l, mseloss_x), grads = backward(atom_types_step, frac_coords_step, property_step,
+                                                          lengths_step, angles_step, lattice_polar_step, num_atoms_step,
+                                                          edge_index_step, batch_node2graph,
+                                                          node_mask_step, edge_mask_step, batch_mask, node_num_valid,
+                                                          batch_size_valid)
+
+        # Skip the parameter update if any gradient overflowed to NaN/Inf.
+        is_finite = all_finite(grads)
+        if is_finite:
+            optimizer(grads)
+
+        return mseloss, is_finite, mseloss_l, mseloss_x
+
+    @ms.jit
+    def eval_step(atom_types_step, frac_coords_step, property_step,
+                  lengths_step, angles_step, lattice_polar_step, num_atoms_step,
+                  edge_index_step, batch_node2graph, node_mask_step, edge_mask_step,
+                  batch_mask, node_num_valid, batch_size_valid):
+        """eval step"""
+        mseloss, mseloss_l, mseloss_x = forward(atom_types_step, frac_coords_step, property_step, lengths_step,
+                                                angles_step, lattice_polar_step, num_atoms_step,
+                                                edge_index_step, batch_node2graph,
+                                                node_mask_step, edge_mask_step, batch_mask, node_num_valid,
+                                                batch_size_valid)
+        return mseloss, mseloss_l, mseloss_x
+
+    epoch = 0
+    epoch_size = config['train']["epoch_size"]
+
+    logging.info("Start to initialise train_loader")
+    train_dataset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["train"]["path"],
+                                        save_path=config['dataset']["train"]["save_path"])
+    train_loader = DataLoader(batch_size_max, *train_dataset, shuffle_dataset=True)
+    logging.info("Start to initialise eval_loader")
+    val_dataset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["val"]["path"],
+                                      save_path=config['dataset']["val"]["save_path"])
+    eval_loader = DataLoader(batch_size_max, *val_dataset,
+                             dynamic_batch_size=False, shuffle_dataset=True)
+
+    while epoch < epoch_size:
+        epoch_starttime = time.time()
+
+        train_mseloss_record = LossRecord()
+        eval_mseloss_record = LossRecord()
+
+        #################################################### train ###################################################
+        logging.info("+++++++++++++++ start training +++++++++++++++++++++")
+        cspflow.set_train(True)
+
+        starttime = time.time()
+        record_iter = 0
+        for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, \
+            angles_batch, lattice_polar_batch, num_atoms_batch,\
+            edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\
+            node_num_valid_, batch_size_valid_ in train_loader:
+
+            result = train_step(atom_types_batch, frac_coords_batch, property_batch,
+                                lengths_batch, angles_batch, lattice_polar_batch,
+                                num_atoms_batch, edge_index_batch, batch_node2graph_,
+                                node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_,
+                                batch_size_valid_)
+
+            mseloss_step, _, mseloss_l_, mseloss_x_ = result
+
+            if record_iter % 50 == 0:
+                logging.info("==============================step: %s, epoch: %s", train_loader.step - 1, epoch)
+                logging.info("learning rate: %s", optimizer.learning_rate.value())
+                logging.info("train mse loss: %s", mseloss_step)
+                logging.info("train mse_lattice loss: %s", mseloss_l_)
+                logging.info("train mse_coords loss: %s", mseloss_x_)
+                starttime0 = starttime
+                starttime = time.time()
+                logging.info("training time: %s", starttime - starttime0)
+
+            record_iter += 1
+
+            train_mseloss_record.update(mseloss_step)
+
+        #################################################### finish train ########################################
+        epoch_endtime = time.time()
+        logging.info("epoch %s running time: %s", epoch, epoch_endtime - epoch_starttime)
+        logging.info("epoch %s average train mse loss: %s", epoch, train_mseloss_record.avg)
+
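+        # Persist only the denoiser weights each epoch; the startup resume check
+        # loads this same file back into cspnet via ms.load_param_into_net.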
+        ms.save_checkpoint(cspflow.decoder, config['checkpoint']['last_path'])
+
+        if epoch % 5 == 0:
+            #################################################### validation ##########################################
+            logging.info("+++++++++++++++ start validation +++++++++++++++++++++")
+            cspflow.set_train(False)
+
+            starttime = time.time()
+            for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, \
+                angles_batch, lattice_polar_batch, num_atoms_batch,\
+                edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\
+                node_num_valid_, batch_size_valid_ in eval_loader:
+
+                result_e = eval_step(atom_types_batch, frac_coords_batch, property_batch,
+                                     lengths_batch, angles_batch, lattice_polar_batch,
+                                     num_atoms_batch, edge_index_batch, batch_node2graph_,
+                                     node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_,
+                                     batch_size_valid_)
+
+                mseloss_step, mseloss_l_, mseloss_x_ = result_e
+
+                eval_mseloss_record.update(mseloss_step)
+
+            #################################################### finish validation #################################
+
+            starttime0 = starttime
+            starttime = time.time()
+            logging.info("validation time: %s", starttime - starttime0)
+            logging.info("epoch %s average validation mse loss: %s", epoch, eval_mseloss_record.avg)
+
+        epoch += 1
diff --git a/MindChemistry/applications/crystalflow/train_pressure.py b/MindChemistry/applications/crystalflow/train_pressure.py
new file mode 100644
index 000000000..e56253f2a
--- /dev/null
+++ b/MindChemistry/applications/crystalflow/train_pressure.py
@@ -0,0 +1,231 @@
+# Copyright 2024 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+# --------------
+# Variant of train.py that reads pressure as the generation condition
+# --------------
+"""train file with pressure conditioning"""
+import os
+import time
+import logging
+import argparse
+import yaml
+import numpy as np
+import mindspore as ms
+from mindspore import nn, set_seed
+from mindspore.amp import all_finite
+from mindchemistry.graph.loss import L2LossMask
+
+# =========================================
+# use the condition-aware model variants
+from models.cspnet_condition import CSPNet
+from models.flow_condition import CSPFlow
+# =========================================
+from models.conditioning import GaussianExpansion
+from models.train_utils import LossRecord
+
+from data.dataset import fullconnect_dataset
+from data.crysloader import Crysloader as DataLoader
+
+logging.basicConfig(level=logging.INFO)
+
+def parse_args():
+    '''Parse input args'''
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--config', default='config.yaml', help="The config file path")
+    parser.add_argument('--device_id', type=int, default=0,
+                        help="ID of the target device")
+    parser.add_argument('--device_target', type=str, default='CPU', choices=["GPU", "Ascend", "CPU"],
+                        help="The target device to run on; one of 'CPU', 'GPU' or 'Ascend'")
+    input_args = parser.parse_args()
+    return input_args
+
+if __name__ == '__main__':
+    args = parse_args()
+    ms.set_context(device_target=args.device_target, device_id=args.device_id)
+
+    with open(args.config, 'r') as stream:
+        config = yaml.safe_load(stream)
+
+    ckpt_dir = config['train']["ckpt_dir"]
+
+    if not os.path.exists(ckpt_dir):
+        os.makedirs(ckpt_dir)
+
+    set_seed(config['train']["seed"])
+
+    batch_size_max = config['train']['batch_size']
+
+    # Embed the scalar pressure condition with a Gaussian basis expansion.
+    cond_emb_model = GaussianExpansion(start=config['model']['conditions']['pressure']['start'],
+                                       stop=config['model']['conditions']['pressure']['stop'],
+                                       n_gaussians=config['model']['conditions']['pressure']['n_out'],
+                                       )
+    cspnet = CSPNet(num_layers=config['model']['num_layers'], hidden_dim=config['model']['hidden_dim'],
+                    num_freqs=config['model']['num_freqs'],
+                    cemb_dim=config['model']['conditions']['pressure']['n_out'])
+    cost_coord = config['train']['cost_coord']
+    cost_lattice = config['train']['cost_lattice']
+
+    if os.path.exists(config['checkpoint']['last_path']):
+        logging.info("Loading from existing checkpoint...")
+        param_dict = ms.load_checkpoint(config['checkpoint']['last_path'])
+        ms.load_param_into_net(cspnet, param_dict)
+        logging.info("Finished loading from existing checkpoint")
+    else:
+        logging.info("Starting new training process")
+
+    cspflow = CSPFlow(cspnet, cond_emb_model)
+
+    model_parameters = filter(lambda p: p.requires_grad, cspflow.get_parameters())
+    params = sum(np.prod(p.shape) for p in model_parameters)
+    logging.info("The model you built has %s parameters.", params)
+
+    optimizer = nn.Adam(params=cspflow.trainable_params())
+    loss_func_mse = L2LossMask(reduction='mean')
+
+    def forward(atom_types_step, frac_coords_step, _, lengths_step, angles_step, \
+                lattice_polar_step, num_atoms_step, edge_index_step, batch_node2graph, \
+                node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid):
+        """forward"""
+        pred_l, tar_l, pred_x, tar_x = cspflow(batch_size_valid, atom_types_step, lengths_step,
+                                               angles_step, lattice_polar_step, num_atoms_step, frac_coords_step,
+                                               batch_node2graph, edge_index_step,
+                                               node_mask_step, edge_mask_step, batch_mask)
+        mseloss_l = loss_func_mse(pred_l, tar_l, mask=batch_mask, num=batch_size_valid)
+        mseloss_x = loss_func_mse(pred_x, tar_x, mask=node_mask_step, num=node_num_valid)
+        mseloss = cost_lattice * mseloss_l + cost_coord * mseloss_x
+
+        return mseloss, mseloss_l, mseloss_x
+
+    backward = ms.value_and_grad(forward, None, weights=cspflow.trainable_params(), has_aux=True)
+
+    @ms.jit
+    def train_step(atom_types_step, frac_coords_step, property_step,
+                   lengths_step, angles_step, lattice_polar_step, num_atoms_step,
+                   edge_index_step, batch_node2graph, node_mask_step, edge_mask_step, batch_mask,
+                   node_num_valid, batch_size_valid):
+        """train step"""
+        (mseloss, mseloss_l, mseloss_x), grads = backward(atom_types_step, frac_coords_step, property_step,
+                                                          lengths_step, angles_step, lattice_polar_step, num_atoms_step,
+                                                          edge_index_step, batch_node2graph,
+                                                          node_mask_step, edge_mask_step, batch_mask, node_num_valid,
+                                                          batch_size_valid)
+
+        # Skip the parameter update if any gradient overflowed to NaN/Inf.
+        is_finite = all_finite(grads)
+        if is_finite:
+            optimizer(grads)
+
+        return mseloss, is_finite, mseloss_l, mseloss_x
+
+    @ms.jit
+    def eval_step(atom_types_step, frac_coords_step, property_step, lengths_step,
+                  angles_step, lattice_polar_step, num_atoms_step,
+                  edge_index_step, batch_node2graph,
+                  node_mask_step, edge_mask_step, batch_mask, node_num_valid, batch_size_valid):
+        """eval step"""
+        mseloss, mseloss_l, mseloss_x = forward(atom_types_step, frac_coords_step, property_step, lengths_step,
+                                                angles_step, lattice_polar_step, num_atoms_step,
+                                                edge_index_step, batch_node2graph, node_mask_step, edge_mask_step,
+                                                batch_mask, node_num_valid, batch_size_valid)
+        return mseloss, mseloss_l, mseloss_x
+
+    epoch = 0
+    epoch_size = config['train']["epoch_size"]
+
+    logging.info("Start to initialise train_loader")
+    train_dataset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["train"]["path"],
+                                        save_path=config['dataset']["train"]["save_path"])
+    train_loader = DataLoader(batch_size_max, *train_dataset, shuffle_dataset=True)
+    logging.info("Start to initialise eval_loader")
+    val_dataset = fullconnect_dataset(name=config['dataset']["data_name"], path=config['dataset']["val"]["path"],
+                                      save_path=config['dataset']["val"]["save_path"])
+    eval_loader = DataLoader(batch_size_max, *val_dataset,
+                             dynamic_batch_size=False, shuffle_dataset=True)
+
+    while epoch < epoch_size:
+        epoch_starttime = time.time()
+
+        train_mseloss_record = LossRecord()
+        eval_mseloss_record = LossRecord()
+
+        #################################################### train ###################################################
+        logging.info("+++++++++++++++ start training +++++++++++++++++++++")
+        cspflow.set_train(True)
+
+        starttime = time.time()
+        record_iter = 0
+        for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, \
+            angles_batch, lattice_polar_batch, num_atoms_batch,\
+            edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\
+            node_num_valid_, batch_size_valid_ in train_loader:
+
+            result = train_step(atom_types_batch, frac_coords_batch, property_batch,
+                                lengths_batch, angles_batch, lattice_polar_batch,
+                                num_atoms_batch, edge_index_batch, batch_node2graph_,
+                                node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_,
+                                batch_size_valid_)
+
+            mseloss_step, _, mseloss_l_, mseloss_x_ = result
+
+            if record_iter % 50 == 0:
+                logging.info("==============================step: %s, epoch: %s", train_loader.step - 1, epoch)
+                logging.info("learning rate: %s", optimizer.learning_rate.value())
+                logging.info("train mse loss: %s", mseloss_step)
+                logging.info("train mse_lattice loss: %s", mseloss_l_)
+                logging.info("train mse_coords loss: %s", mseloss_x_)
+                starttime0 = starttime
+                starttime = time.time()
+                logging.info("training time: %s", starttime - starttime0)
+
+            record_iter += 1
+
+            train_mseloss_record.update(mseloss_step)
+
+        #################################################### finish train ########################################
+        epoch_endtime = time.time()
+        logging.info("epoch %s running time: %s", epoch, epoch_endtime - epoch_starttime)
+        logging.info("epoch %s average train mse loss: %s", epoch, train_mseloss_record.avg)
+
+        ms.save_checkpoint(cspflow.decoder, config['checkpoint']['last_path'])
+
+        if epoch % 5 == 0:
+            #################################################### validation ##########################################
+            logging.info("+++++++++++++++ start validation +++++++++++++++++++++")
+            cspflow.set_train(False)
+
+            starttime = time.time()
+            for atom_types_batch, frac_coords_batch, property_batch, lengths_batch, \
+                angles_batch, lattice_polar_batch, num_atoms_batch,\
+                edge_index_batch, batch_node2graph_, node_mask_batch, edge_mask_batch, batch_mask_batch,\
+                node_num_valid_, batch_size_valid_ in eval_loader:
+
+                result_e = eval_step(atom_types_batch, frac_coords_batch, property_batch,
+                                     lengths_batch, angles_batch, lattice_polar_batch,
+                                     num_atoms_batch, edge_index_batch, batch_node2graph_,
+                                     node_mask_batch, edge_mask_batch, batch_mask_batch, node_num_valid_,
+                                     batch_size_valid_)
+
+                mseloss_step, mseloss_l_, mseloss_x_ = result_e
+
+                eval_mseloss_record.update(mseloss_step)
+
+            #################################################### finish validation #################################
+
+            starttime0 = starttime
+            starttime = time.time()
+            logging.info("validation time: %s", starttime - starttime0)
+            logging.info("epoch %s average validation mse loss: %s", epoch, eval_mseloss_record.avg)
+
+        epoch += 1
-- 
Gitee

From 473cb4dde5bcf85f14ce3159f58ac1877015f76e Mon Sep 17 00:00:00 2001
From: wangqc <1160619743@qq.com>
Date: Tue, 3 Jun 2025 13:36:00 +0800
Subject: [PATCH 2/2] fix: fix readme

---
 MindChemistry/applications/crystalflow/README.md | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/MindChemistry/applications/crystalflow/README.md b/MindChemistry/applications/crystalflow/README.md
index bead72bb0..556a98a12 100644
--- a/MindChemistry/applications/crystalflow/README.md
+++ b/MindChemistry/applications/crystalflow/README.md
@@ -15,7 +15,7 @@
 ## 快速入门
 
 > 1. 将Mindchemistry/mindchemistry文件包下载到当前目录
-> 2. 在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)下载相应的数据集
+> 2. 在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/dataset/)下载相应的数据集
 > 3. 安装依赖包:`pip install -r requirement.txt`
 > 4. 训练命令: `python train.py`
 > 5. 预测命令: `python evaluate.py`
@@ -54,7 +54,7 @@
 
 ## 下载数据集
 
-在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/)中下载相应的数据集文件夹和dataset_prop.txt数据集属性文件放置于当前路径的dataset文件夹下(如果没有需要自己手动创建),文件路径参考:
+在[数据集链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/dataset/)中下载相应的数据集文件夹和dataset_prop.txt数据集属性文件放置于当前路径的dataset文件夹下(如果没有需要自己手动创建),文件路径参考:
 
 ```txt
 crystalflow
@@ -87,7 +87,6 @@ python train.py
 
 ### 推理
 
-将权重的path写入config文件的checkpoint.last_path中。预训练模型可以从[预训练模型链接](https://download-mindspore.osinfra.cn/mindscience/mindchemistry/diffcsp/pre-train)中获取。
 
 更改config文件中的test字段来更改推理参数,特别是test.num_eval,它**决定了对于每个组分生成多少个样本**,对于后续的评估阶段很重要。
 
-- 
Gitee