diff --git a/MindChemistry/applications/cdvae/README.md b/MindChemistry/applications/cdvae/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1672422847cc861893eaa0b280cdf1113a939b6b
--- /dev/null
+++ b/MindChemistry/applications/cdvae/README.md
@@ -0,0 +1,131 @@
+# 模型名称
+
+> CDVAE
+
+## 介绍
+
+> Crystal Diffusion Variational AutoEncoder (CDVAE)是用来生成材料的周期性结构的SOTA模型,相关论文已发表在ICLR上。模型主要由两个部分组成:首先是encoder部分,将输入的信息转化成隐变量z;其次是decoder部分,其中部分简单的特性,如原子数量和晶格常数等,直接使用MLP进行decode得到输出,其他部分如原子种类和原子在晶格中的位置等,则通过扩散模型得到。具体模型结构如下图所示:
+
+
+

+
+
+## 数据集
+
+> 提供了三个数据集:
+
+1. Perov_5 (Castelli et al., 2012): 包含接近19000个钙钛矿晶体结构,结构相似,但是组成不同,下载地址:[Perov_5](https://figshare.com/articles/dataset/Perov5/22705189)。
+2. Carbon_24 (Pickard, 2020): 包含10000个仅包含碳原子的晶体结构,因此其具有相同的组成,但是结构不同,下载地址:[Carbon_24](https://figshare.com/articles/dataset/Carbon24/22705192)。
+3. MP_20 (Jain et al., 2013): 包含45000个无机材料结构,涵盖了绝大多数单胞内原子数不超过20的实验已知材料,下载地址:[mp_20](https://figshare.com/articles/dataset/mp_20/25563693)。
+
+前两个数据集下载后直接放在./data目录下即可。MP_20数据集下载后运行`python ./src/mp_20_process.py --init_path ./data/mp_20.json --data_path ./data/mp_20`, 其中init_path是下载得到的json格式数据集的位置,而data_path是处理后数据集的存放位置。
+
+## 环境要求
+
+> 1. 安装`pip install -r requirements.txt`
+
+## 脚本说明
+
+### 代码目录结构
+
+```txt
+└─cdvae
+ │ README.md README文件
+ │ train.py 训练启动脚本
+ │ evaluation.py 推理启动脚本
+ │ compute_metrics.py 评估结果脚本
+ │ create_dataset.py 生成数据集
+ │
+ └─src
+ │ evaluate_utils.py 推理结果生成
+ │ metrics_utils.py 评估结果计算
+ │ dataloader.py 将数据集加载到网络
+ │ mp_20_process.py 对mp_20数据集预处理
+ │
+ └─conf 参数配置
+ │ configs.yaml 网络参数
+ └─data 数据集参数
+```
+
+## 训练
+
+### 快速开始
+
+> 训练命令: `python train.py --dataset 'perov_5'`
+
+### 命令行参数
+
+```txt
+dataset: 使用的数据集,perov_5, carbon_24, mp_20
+create_dataset: 是否重新对数据集进行处理
+num_samples_train: 如重新处理数据集,训练集的大小,-1为使用全部原始数据
+num_samples_val:如重新处理数据集,验证集的大小,-1为使用全部原始数据
+num_samples_test:如重新处理数据集,测试集的大小,-1为使用全部原始数据
+name_ckpt:保存权重的路径和名称
+load_ckpt:是否读取权重
+device_target:MindSpore使用的后端
+device_id:如MindSpore使用昇腾后端,使用的NPU卡号
+epoch_num:训练的epoch数
+```
+
+## 推理评估过程
+
+### 推理过程
+
+```txt
+1.将权重checkpoint文件保存至 `/loss/`目录下(默认读取目录)
+2.执行推理脚本:reconstruction任务:
+ python evaluation.py --dataset perov_5 --tasks 'recon' (指定dataset为perov_5)
+ generation任务:
+ python evaluation.py --dataset perov_5 --tasks 'gen'
+ optimization任务(如需使用optimization,在训练时请在configs.yaml中将predict_property设置为True):
+ python evaluation.py --dataset perov_5 --tasks 'opt'
+```
+
+### 命令行参数
+
+```txt
+device_target:MindSpore使用的后端
+device_id:如MindSpore使用昇腾后端,使用的NPU卡号
+model_path: 权重保存路径
+dataset: 使用的数据集,perov_5, carbon_24, mp_20
+tasks:推理执行的任务,可选:recon,gen,opt
+n_step_each:执行的denoising的步数
+step_lr:opt任务中设置的lr
+min_sigma:生成随机噪声的最小值
+save_traj:是否保存traj
+disable_bar:是否禁用进度条
+num_evals:gen任务中产生的结果数量
+start_from:随机或从头开始读取数据集,可选:random, data
+batch_size: batch_size大小
+force_num_atoms:是否限制原子数不变
+force_atom_types:是否限制原子种类不变
+label:推理结果保存时的名称
+```
+
+### 推理结果
+
+```txt
+可以在`/eval_result/`路径下找到推理的输出文件。
+reconstruction的输出文件为eval_recon.npy和gt_recon.npy,分别包含了reconstruction后的晶体结构信息以及作为ground truth的晶体结构信息;
+generation的输出文件为eval_gen.npy,包含了随机生成结果的晶体结构信息;
+optimization的输出文件为eval_opt.npy,包含了基于特定性质优化的晶体结构信息。
+```
+
+### 结果评估
+
+```txt
+运行 python compute_metrics.py --eval_path './eval_result' --dataset 'perov_5' --tasks recon, 结果会保存在eval_path指定的文件夹下的eval_metrics.json文件中(目前支持recon和gen两种任务)
+```
+
+## 引用
+
+[1] Xie T, Fu X, Ganea O E, et al. Crystal diffusion variational autoencoder for periodic material generation[J]. arXiv preprint arXiv:2110.06197, 2021.
+
+[2] Castelli I E, Landis D D, Thygesen K S, et al. New cubic perovskites for one-and two-photon water splitting using the computational materials repository[J]. Energy & Environmental Science, 2012, 5(10): 9034-9043.
+
+[3] Castelli I E, Olsen T, Datta S, et al. Computational screening of perovskite metal oxides for optimal solar light capture[J]. Energy & Environmental Science, 2012, 5(2): 5814-5819.
+
+[4] Pickard C J. AIRSS data for carbon at 10GPa and the C+ N+ H+ O system at 1GPa[J]. (No Title), 2020.
+
+[5] Jain A, Ong S P, Hautier G, et al. Commentary: The Materials Project: A materials genome approach to accelerating materials innovation[J]. APL materials, 2013, 1(1).
\ No newline at end of file
diff --git a/MindChemistry/applications/cdvae/compute_metrics.py b/MindChemistry/applications/cdvae/compute_metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..e51be922e3bbfd45212fd6f6f1ef5852f738c072
--- /dev/null
+++ b/MindChemistry/applications/cdvae/compute_metrics.py
@@ -0,0 +1,321 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Compute metrics
+"""
+from collections import Counter
+import logging
+import argparse
+import os
+import json
+
+import numpy as np
+from tqdm import tqdm
+from p_tqdm import p_map
+from scipy.stats import wasserstein_distance
+from pymatgen.core.structure import Structure
+from pymatgen.core.composition import Composition
+from pymatgen.core.lattice import Lattice
+from pymatgen.analysis.structure_matcher import StructureMatcher
+from matminer.featurizers.site.fingerprint import CrystalNNFingerprint
+from matminer.featurizers.composition.composite import ElementProperty
+from mindchemistry.cell.gemnet.data_utils import StandardScaler
+from src.metrics_utils import (
+ smact_validity, structure_validity, get_fp_pdist,
+ get_crystals_list, compute_cov)
+
+# Site-level structure fingerprint featurizer, built from the "ops" preset.
+CRYSTALNNFP = CrystalNNFingerprint.from_preset("ops")
+# Composition fingerprint featurizer, built from the "magpie" preset.
+COMPFP = ElementProperty.from_preset("magpie")
+
+# Per-dataset cutoffs, keyed by dataset name, with one value for the
+# structure ("struct") distance and one for the composition ("comp") distance.
+COV_CUTOFFS = {
+ "mp_20": {"struct": 0.4, "comp": 10.},
+ "carbon_24": {"struct": 0.2, "comp": 4.},
+ "perov_5": {"struct": 0.2, "comp": 4},
+}
+# threshold for coverage metrics, only struct distance and comp distance
+# smaller than the threshold will be counted as covered.
+
+
+class Crystal():
+    """Crystal built from raw arrays, with validity flags and fingerprints.
+
+    Args:
+        crys_array_dict (dict): Must provide the keys "frac_coords",
+            "atom_types", "lengths" and "angles". ``lengths`` and ``angles``
+            must support ``tolist()``.
+
+    Attributes:
+        structure: The pymatgen ``Structure``, or ``None`` when it could not
+            be built.
+        constructed (bool): Whether a usable structure was built.
+        invalid_reason (str or None): Why ``constructed`` is False, if it is.
+        comp_valid, struct_valid, valid (bool): Validity flags.
+        comp_fp, struct_fp: Composition / structure fingerprints, or ``None``
+            when the structure fingerprint could not be computed.
+    """
+
+    def __init__(self, crys_array_dict):
+        self.frac_coords = crys_array_dict["frac_coords"]
+        self.atom_types = crys_array_dict["atom_types"]
+        self.lengths = crys_array_dict["lengths"]
+        self.angles = crys_array_dict["angles"]
+        self.dict = crys_array_dict
+
+        # Defaults so that every later stage can rely on these attributes
+        # existing, even when the structure cannot be built.
+        self.structure = None
+        self.constructed = False
+        self.invalid_reason = None
+
+        self.get_structure()
+        self.get_composition()
+        self.get_validity()
+        self.get_fingerprints()
+
+    def get_structure(self):
+        """Build ``self.structure`` and record whether construction succeeded.
+
+        Sets ``constructed`` to False and fills ``invalid_reason`` when a
+        lattice length is negative, when pymatgen raises while building the
+        structure, or when the resulting cell volume is below 0.1.
+        """
+        self.structure = None
+        if min(self.lengths.tolist()) < 0:
+            self.constructed = False
+            self.invalid_reason = "non_positive_lattice"
+            return
+
+        try:
+            self.structure = Structure(
+                lattice=Lattice.from_parameters(
+                    *(self.lengths.tolist() + self.angles.tolist())),
+                species=self.atom_types, coords=self.frac_coords,
+                coords_are_cartesian=False)
+        except (ValueError, AttributeError, TypeError):
+            # Without this early return the volume check below would read a
+            # structure that was never created and crash with AttributeError.
+            self.structure = None
+            self.constructed = False
+            self.invalid_reason = "construction_raises_exception"
+            return
+
+        self.constructed = True
+        if self.structure.volume < 0.1:
+            self.constructed = False
+            self.invalid_reason = "unrealistically_small_lattice"
+
+    def get_composition(self):
+        """Store the sorted element tuple and the gcd-reduced atom counts."""
+        elem_counter = Counter(self.atom_types)
+        composition = [(elem, elem_counter[elem])
+                       for elem in sorted(elem_counter.keys())]
+        elems, counts = list(zip(*composition))
+        counts = np.array(counts)
+        # Divide by the greatest common divisor to get the reduced formula.
+        counts = counts / np.gcd.reduce(counts)
+        self.elems = elems
+        self.comps = tuple(counts.astype("int").tolist())
+
+    def get_validity(self):
+        """Compute composition, structure and overall validity flags."""
+        self.comp_valid = smact_validity(self.elems, self.comps)
+        if self.constructed:
+            self.struct_valid = structure_validity(self.structure)
+        else:
+            self.struct_valid = False
+        self.valid = self.comp_valid and self.struct_valid
+
+    def get_fingerprints(self):
+        """Compute composition and structure fingerprints.
+
+        If the structure is missing or the site fingerprints cannot be
+        computed, the crystal is marked invalid and both fingerprints are
+        set to ``None``.
+        """
+        elem_counter = Counter(self.atom_types)
+        comp = Composition(elem_counter)
+        self.comp_fp = COMPFP.featurize(comp)
+        try:
+            if self.structure is None:
+                raise ValueError("structure was not constructed")
+            site_fps = [CRYSTALNNFP.featurize(
+                self.structure, i) for i in range(len(self.structure))]
+        except (ValueError, AttributeError, TypeError):
+            # counts crystal as invalid if fingerprint cannot be constructed.
+            self.valid = False
+            self.comp_fp = None
+            self.struct_fp = None
+            return
+        self.struct_fp = np.array(site_fps).mean(axis=0)
+
+
+class RecEval():
+    """Reconstruction metrics for predicted crystals paired with ground truths.
+
+    Args:
+        pred_crys (list): Predicted crystals.
+        gt_crys (list): Ground-truth crystals, same length and order.
+        stol, angle_tol, ltol: Tolerances forwarded to ``StructureMatcher``.
+    """
+
+    def __init__(self, pred_crys, gt_crys, stol=0.5, angle_tol=10, ltol=0.3):
+        assert len(pred_crys) == len(gt_crys)
+        self.preds = pred_crys
+        self.gts = gt_crys
+        self.matcher = StructureMatcher(
+            stol=stol, angle_tol=angle_tol, ltol=ltol)
+
+    def get_match_rate_and_rms(self):
+        """Return the match rate and the mean RMS distance over the matches.
+
+        A prediction counts as matched when it is valid and the matcher
+        returns an RMS distance for the (prediction, ground truth) pair.
+        """
+        def rms_or_none(pred, truth):
+            # Invalid predictions are never compared.
+            if not pred.valid:
+                return None
+            try:
+                found = self.matcher.get_rms_dist(
+                    pred.structure, truth.structure)
+                return found if found is None else found[0]
+            except (ValueError, AttributeError, TypeError):
+                return None
+
+        distances = [
+            rms_or_none(pred, truth)
+            for pred, truth in tqdm(zip(self.preds, self.gts),
+                                    total=len(self.preds))]
+        matched = [dist for dist in distances if dist is not None]
+        return {"match_rate": len(matched) / len(self.preds),
+                "rms_dist": np.array(matched).mean()}
+
+    def get_metrics(self):
+        """Return all reconstruction metrics as a dict."""
+        return self.get_match_rate_and_rms()
+
+
+class GenEval():
+    """Metrics for the generation task.
+
+    Args:
+        pred_crys (list): Generated crystals.
+        gt_crys (list): Ground-truth crystals used as the reference set.
+        comp_scaler: Object with a ``transform`` method, applied to the
+            composition fingerprints.
+        n_samples (int): Number of valid generated crystals drawn at random
+            (without replacement) for the diversity and distribution metrics.
+        eval_model_name (str): Key into ``COV_CUTOFFS`` for the coverage
+            cutoffs.
+
+    Raises:
+        ValueError: If fewer than ``n_samples`` generated crystals are valid.
+    """
+
+    def __init__(self, pred_crys, gt_crys, comp_scaler, n_samples=10, eval_model_name=None):
+        self.crys = pred_crys
+        self.gt_crys = gt_crys
+        self.n_samples = n_samples
+        self.eval_model_name = eval_model_name
+        self.comp_scaler = comp_scaler
+
+        valid_crys = [c for c in pred_crys if c.valid]
+        if len(valid_crys) < n_samples:
+            # ValueError is still an Exception, so callers that caught the
+            # previous generic Exception keep working.
+            raise ValueError(
+                f"not enough valid crystals in the predicted set: {len(valid_crys)}/{n_samples}")
+        sampled_indices = np.random.choice(
+            len(valid_crys), n_samples, replace=False)
+        self.valid_samples = [valid_crys[i] for i in sampled_indices]
+
+    def get_validity(self):
+        """Return the fraction of all generated crystals that are valid.
+
+        Reports composition validity, structure validity and their
+        conjunction, each averaged over every generated crystal.
+        """
+        comp_valid = np.array([c.comp_valid for c in self.crys]).mean()
+        struct_valid = np.array([c.struct_valid for c in self.crys]).mean()
+        valid = np.array([c.valid for c in self.crys]).mean()
+        return {"comp_valid": comp_valid,
+                "struct_valid": struct_valid,
+                "valid": valid}
+
+    def get_comp_diversity(self):
+        """Return the composition diversity of the sampled valid crystals.
+
+        The composition fingerprints are scaled with ``comp_scaler`` and then
+        passed to ``get_fp_pdist``.
+        """
+        comp_fps = [c.comp_fp for c in self.valid_samples]
+        comp_fps = self.comp_scaler.transform(comp_fps)
+        comp_div = get_fp_pdist(comp_fps)
+        return {"comp_div": comp_div}
+
+    def get_struct_diversity(self):
+        """Return ``get_fp_pdist`` of the sampled structure fingerprints."""
+        return {"struct_div": get_fp_pdist([c.struct_fp for c in self.valid_samples])}
+
+    def get_density_wdist(self):
+        """Return the Wasserstein distance between generated and reference densities."""
+        pred_densities = [c.structure.density for c in self.valid_samples]
+        gt_densities = [c.structure.density for c in self.gt_crys]
+        wdist_density = wasserstein_distance(pred_densities, gt_densities)
+        return {"wdist_density": wdist_density}
+
+    def get_num_elem_wdist(self):
+        """Return the Wasserstein distance between numbers of distinct species."""
+        pred_nelems = [len(set(c.structure.species))
+                       for c in self.valid_samples]
+        gt_nelems = [len(set(c.structure.species)) for c in self.gt_crys]
+        wdist_num_elems = wasserstein_distance(pred_nelems, gt_nelems)
+        return {"wdist_num_elems": wdist_num_elems}
+
+    def get_coverage(self):
+        """Return the coverage metrics computed by ``compute_cov``.
+
+        Measures the similarity between ensembles of generated materials and
+        ground truth materials, using the dataset cutoffs in ``COV_CUTOFFS``.
+        """
+        cutoff_dict = COV_CUTOFFS[self.eval_model_name]
+        (cov_metrics_dict, _) = compute_cov(
+            self.crys, self.gt_crys, self.comp_scaler,
+            struc_cutoff=cutoff_dict["struct"],
+            comp_cutoff=cutoff_dict["comp"])
+        return cov_metrics_dict
+
+    def get_metrics(self):
+        """Return every generation metric merged into one dict."""
+        metrics = {}
+        metrics.update(self.get_validity())
+        metrics.update(self.get_comp_diversity())
+        metrics.update(self.get_struct_diversity())
+        metrics.update(self.get_density_wdist())
+        metrics.update(self.get_num_elem_wdist())
+        # Printed before coverage so partial results are visible early.
+        print(f'evaluation metrics:{metrics}')
+        metrics.update(self.get_coverage())
+        return metrics
+
+
+def get_crystal_array_list(data, gt_data=None, ground_truth=False):
+    """Split batched result arrays into per-crystal array dicts.
+
+    Args:
+        data (dict): Saved model output with the keys "frac_coords",
+            "atom_types", "lengths", "angles" and "num_atoms".
+        gt_data (dict): Saved ground truth with the same keys; only read when
+            ``ground_truth`` is True.
+        ground_truth (bool): Whether to also build the ground-truth list.
+
+    Returns:
+        tuple: ``(crys_array_list, true_crystal_array_list)``; the second item
+        is ``None`` when ``ground_truth`` is False.
+    """
+    fields = ("frac_coords", "atom_types", "lengths", "angles", "num_atoms")
+    # Predictions are joined along axis 1 and lose their leading axis.
+    crys_array_list = get_crystals_list(
+        *[np.concatenate(data[key], axis=1).squeeze(0) for key in fields])
+
+    if not ground_truth:
+        return crys_array_list, None
+
+    def merged_gt(key):
+        # Ground-truth batches are joined along axis 0 instead.
+        return np.concatenate(gt_data[key], axis=0).squeeze()
+
+    true_crystal_array_list = get_crystals_list(
+        merged_gt("frac_coords"),
+        merged_gt("atom_types"),
+        merged_gt("lengths").reshape(-1, 3),
+        merged_gt("angles").reshape(-1, 3),
+        merged_gt("num_atoms"))
+    return crys_array_list, true_crystal_array_list
+
+
+def main(args):
+    """Compute the requested metrics and write them to a JSON file.
+
+    Args:
+        args: Parsed arguments providing ``dataset``, ``eval_path``,
+            ``label`` and ``tasks``.
+
+    For "recon", ``eval_recon`` and ``gt_recon`` results are compared. For
+    "gen", ``eval_gen`` is evaluated against ``gt_recon``, so the
+    reconstruction ground truth must already exist in ``eval_path``. The
+    metrics are written to ``eval_metrics.json`` (or
+    ``eval_metrics_<label>.json``) inside ``eval_path``.
+    """
+    def load_result(stem):
+        # Results may have been saved with a label suffix; prefer that file
+        # and fall back to the unlabelled name used when no label was given.
+        path = os.path.join(args.eval_path, f"{stem}.npy")
+        if args.label != "":
+            labelled = os.path.join(
+                args.eval_path, f"{stem}_{args.label}.npy")
+            if os.path.exists(labelled):
+                path = labelled
+        # NOTE: allow_pickle is needed because the results are pickled dicts;
+        # only load result files from a trusted source.
+        return np.load(path, allow_pickle=True).item()
+
+    all_metrics = {}
+    eval_model_name = args.dataset
+
+    if "recon" in args.tasks:
+        out_data = load_result("eval_recon")
+        gt_data = load_result("gt_recon")
+        crys_array_list, true_crystal_array_list = get_crystal_array_list(
+            out_data, gt_data, ground_truth=True)
+        pred_crys = p_map(Crystal, crys_array_list)
+        gt_crys = p_map(Crystal, true_crystal_array_list)
+
+        rec_evaluator = RecEval(pred_crys, gt_crys)
+        recon_metrics = rec_evaluator.get_metrics()
+        all_metrics.update(recon_metrics)
+
+    if "gen" in args.tasks:
+        out_data = load_result("eval_gen")
+        gt_data = load_result("gt_recon")
+        crys_array_list, true_crystal_array_list = get_crystal_array_list(
+            out_data, gt_data, ground_truth=True)
+
+        gen_crys = p_map(Crystal, crys_array_list)
+        gt_crys = p_map(Crystal, true_crystal_array_list)
+        # The composition scaler is fitted on the ground-truth fingerprints.
+        gt_comp_fps = [c.comp_fp for c in gt_crys]
+        gt_fp_np = np.array(gt_comp_fps)
+        comp_scaler = StandardScaler(replace_nan_token=0.)
+        comp_scaler.fit(gt_fp_np)
+
+        gen_evaluator = GenEval(
+            gen_crys, gt_crys, comp_scaler, eval_model_name=eval_model_name)
+        gen_metrics = gen_evaluator.get_metrics()
+        all_metrics.update(gen_metrics)
+
+    logging.info(all_metrics)
+
+    if args.label == "":
+        metrics_out_file = "eval_metrics.json"
+    else:
+        metrics_out_file = f"eval_metrics_{args.label}.json"
+    metrics_out_file = os.path.join(args.eval_path, metrics_out_file)
+
+    with open(metrics_out_file, "w", encoding="utf-8") as f:
+        json.dump(all_metrics, f)
+
+
+# Command-line entry point: parse the options, set up logging and run main().
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Dataset name; also used as the key into COV_CUTOFFS for the "gen" task.
+ parser.add_argument("--dataset", default="perov_5")
+ # Directory the result .npy files are read from and the metrics are written to.
+ parser.add_argument("--eval_path", default="./eval_result")
+ # Optional suffix for the metrics file name (eval_metrics_<label>.json).
+ parser.add_argument("--label", default="")
+ # One or more tasks; main() looks for "recon" and "gen".
+ parser.add_argument("--tasks", nargs="+", default=["recon"])
+ main_args = parser.parse_args()
+ logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+ main(main_args)
diff --git a/MindChemistry/applications/cdvae/conf/configs.yaml b/MindChemistry/applications/cdvae/conf/configs.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a75a20a86e390b7428eac0cccd961ccb450656ea
--- /dev/null
+++ b/MindChemistry/applications/cdvae/conf/configs.yaml
@@ -0,0 +1,65 @@
+hidden_dim: 256
+latent_dim: 256
+fc_num_layers: 0
+max_atoms: 20
+cost_natom: 1.
+cost_coord: 10.
+cost_type: 1.
+cost_lattice: 10.
+cost_composition: 1.
+cost_edge: 10.
+cost_property: 1.
+beta: 0.01
+max_neighbors: 20
+radius: 7.
+sigma_begin: 10.
+sigma_end: 0.01
+type_sigma_begin: 5.
+type_sigma_end: 0.01
+num_noise_level: 50
+teacher_forcing_lattice: True
+predict_property: True
+
+Encoder:
+ hidden_channels: 128
+ num_blocks: 4
+ int_emb_size: 64
+ basis_emb_size: 8
+ out_emb_channels: 256
+ num_spherical: 7
+ num_radial: 6
+ cutoff: 7.0
+ max_num_neighbors: 20
+ envelope_exponent: 5
+ num_before_skip: 1
+ num_after_skip: 2
+ num_output_layers: 3
+
+Decoder:
+ hidden_dim: 128
+
+Optimizer:
+ learning_rate: 0.001
+ factor: 0.6
+ patience: 30
+ cooldown: 10
+ min_lr: 0.0001
+
+Scaler:
+ TripInteraction_1_had_rbf: 18.873615264892578
+ TripInteraction_1_sum_cbf: 7.996850490570068
+ AtomUpdate_1_sum: 1.220463752746582
+ TripInteraction_2_had_rbf: 16.10817527770996
+ TripInteraction_2_sum_cbf: 7.614634037017822
+ AtomUpdate_2_sum: 0.9690994620323181
+ TripInteraction_3_had_rbf: 15.01930046081543
+ TripInteraction_3_sum_cbf: 7.025179862976074
+ AtomUpdate_3_sum: 0.8903237581253052
+ OutBlock_0_sum: 1.6437848806381226
+ OutBlock_0_had: 16.161039352416992
+ OutBlock_1_sum: 1.1077653169631958
+ OutBlock_1_had: 13.54678726196289
+ OutBlock_2_sum: 0.9477927684783936
+ OutBlock_2_had: 12.754337310791016
+ OutBlock_3_sum: 0.9059251546859741
+ OutBlock_3_had: 13.484951972961426
diff --git a/MindChemistry/applications/cdvae/conf/data/carbon_24.yaml b/MindChemistry/applications/cdvae/conf/data/carbon_24.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..8a7c7093b586db54bbee2afc09bb05da2b5e31fa
--- /dev/null
+++ b/MindChemistry/applications/cdvae/conf/data/carbon_24.yaml
@@ -0,0 +1,12 @@
+prop: energy_per_atom
+num_targets: 1
+niggli: true
+primitive: false
+graph_method: crystalnn
+lattice_scale_method: scale_length
+preprocess_workers: 30
+readout: mean
+max_atoms: 24
+otf_graph: false
+eval_model_name: carbon
+batch_size: 50
diff --git a/MindChemistry/applications/cdvae/conf/data/mp_20.yaml b/MindChemistry/applications/cdvae/conf/data/mp_20.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f43fb448daecec40d3ce5b1f1237add9ecd5c743
--- /dev/null
+++ b/MindChemistry/applications/cdvae/conf/data/mp_20.yaml
@@ -0,0 +1,12 @@
+prop: formation_energy_per_atom
+num_targets: 1
+niggli: true
+primitive: False
+graph_method: crystalnn
+lattice_scale_method: scale_length
+preprocess_workers: 30
+readout: mean
+max_atoms: 20
+otf_graph: false
+eval_model_name: mp20
+batch_size: 50
diff --git a/MindChemistry/applications/cdvae/conf/data/perov_5.yaml b/MindChemistry/applications/cdvae/conf/data/perov_5.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f25a93abd529484492d46d02156591f150b5d656
--- /dev/null
+++ b/MindChemistry/applications/cdvae/conf/data/perov_5.yaml
@@ -0,0 +1,12 @@
+prop: heat_ref
+num_targets: 1
+niggli: true
+primitive: false
+graph_method: crystalnn
+lattice_scale_method: scale_length
+preprocess_workers: 24
+readout: mean
+max_atoms: 20
+otf_graph: false
+eval_model_name: perovskite
+batch_size: 128
diff --git a/MindChemistry/applications/cdvae/create_dataset.py b/MindChemistry/applications/cdvae/create_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..827a1f2bab210863bfa31e0aba3ee7e599c78e95
--- /dev/null
+++ b/MindChemistry/applications/cdvae/create_dataset.py
@@ -0,0 +1,341 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""create_dataset"""
+
+import os
+import logging
+import argparse
+import numpy as np
+import pandas as pd
+from p_tqdm import p_umap
+from pymatgen.core.structure import Structure
+from pymatgen.core.lattice import Lattice
+from pymatgen.analysis.graphs import StructureGraph
+from pymatgen.analysis import local_env
+
+from mindchemistry.utils.load_config import load_yaml_config_from_path
+from mindchemistry.cell.gemnet.data_utils import get_scaler_from_data_list
+from mindchemistry.cell.gemnet.data_utils import lattice_params_to_matrix
+from mindchemistry.cell.dimenet.preprocess import PreProcess
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+
+
+class CreateDataset:
+    """Dataset of preprocessed crystal structures read from a CSV file.
+
+    Args:
+        name (str): Name of the dataset
+        path (str): Path to the dataset CSV file
+        prop (str): Property to predict
+        niggli (bool): Whether to convert to Niggli reduced cell
+        primitive (bool): Whether to convert to primitive cell
+        graph_method (str): Method to create graph
+        preprocess_workers (int): Number of workers for preprocessing
+        lattice_scale_method (str): Method to scale lattice
+        config_path (str): Path to the YAML config whose "Encoder" section
+            provides the ``PreProcess`` settings
+        num_samples (int): Number of samples to use; ``None`` or a negative
+            value (e.g. -1) uses all samples
+    """
+
+    def __init__(self, name, path,
+                 prop, niggli, primitive,
+                 graph_method, preprocess_workers,
+                 lattice_scale_method, config_path,
+                 num_samples=None):
+        super().__init__()
+        # A negative count means "use everything". Slicing with -1 would
+        # silently drop the last sample, so normalise it to None first.
+        if num_samples is not None and num_samples < 0:
+            num_samples = None
+        self.path = path
+        self.name = name
+        self.num_samples = num_samples
+        self.prop = prop
+        self.niggli = niggli
+        self.primitive = primitive
+        self.graph_method = graph_method
+        self.lattice_scale_method = lattice_scale_method
+        self.config = load_yaml_config_from_path(config_path).get("Encoder")
+        self.preprocess = PreProcess(
+            num_spherical=self.config.get("num_spherical"),
+            num_radial=self.config.get("num_radial"),
+            envelope_exponent=self.config.get("envelope_exponent"),
+            otf_graph=False,
+            cutoff=self.config.get("cutoff"),
+            max_num_neighbors=self.config.get("max_num_neighbors"),)
+
+        self.cached_data = data_preprocess(
+            self.path,
+            preprocess_workers,
+            niggli=self.niggli,
+            primitive=self.primitive,
+            graph_method=self.graph_method,
+            prop_list=[prop],
+            num_samples=self.num_samples
+        )[:self.num_samples]
+        add_scaled_lattice_prop(self.cached_data, lattice_scale_method)
+        # Both scalers are assigned by the caller after construction.
+        self.lattice_scaler = None
+        self.scaler = None
+
+    def __len__(self):
+        return len(self.cached_data)
+
+    def __getitem__(self, index):
+        """Return the preprocessed arrays of sample ``index``.
+
+        ``self.scaler`` must have been assigned before items are read.
+        """
+        data = self.cached_data[index]
+
+        # scaler is set by the caller once it has been fitted
+        prop = self.scaler.transform(data[self.prop])
+        (frac_coords, atom_types, lengths, angles, edge_indices,
+         to_jimages, num_atoms) = data["graph_arrays"]
+        data_res = self.preprocess.data_process(angles.reshape(1, -1), lengths.reshape(1, -1),
+                                                np.array([num_atoms]), edge_indices.T, frac_coords,
+                                                edge_indices.shape[0], to_jimages, atom_types, prop)
+        return data_res
+
+    def __repr__(self):
+        return f"CrystDataset({self.name}, {self.path})"
+
+    def get_dataset_size(self):
+        """Return the number of cached samples."""
+        return len(self.cached_data)
+
+
+# Chemical symbols ordered so that the list index equals the atomic number;
+# index 0 holds the placeholder "X". The numbered comments mark the rows.
+chemical_symbols = [
+ # 0 (placeholder)
+ "X",
+ # 1
+ "H", "He",
+ # 2
+ "Li", "Be", "B", "C", "N", "O", "F", "Ne",
+ # 3
+ "Na", "Mg", "Al", "Si", "P", "S", "Cl", "Ar",
+ # 4
+ "K", "Ca", "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
+ "Ga", "Ge", "As", "Se", "Br", "Kr",
+ # 5
+ "Rb", "Sr", "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd",
+ "In", "Sn", "Sb", "Te", "I", "Xe",
+ # 6
+ "Cs", "Ba", "La", "Ce", "Pr", "Nd", "Pm", "Sm", "Eu", "Gd", "Tb", "Dy",
+ "Ho", "Er", "Tm", "Yb", "Lu",
+ "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg", "Tl", "Pb", "Bi",
+ "Po", "At", "Rn",
+ # 7
+ "Fr", "Ra", "Ac", "Th", "Pa", "U", "Np", "Pu", "Am", "Cm", "Bk",
+ "Cf", "Es", "Fm", "Md", "No", "Lr",
+ "Rf", "Db", "Sg", "Bh", "Hs", "Mt", "Ds", "Rg", "Cn", "Nh", "Fl", "Mc",
+ "Lv", "Ts", "Og"
+]
+
+# Local-environment strategy that build_crystal_graph passes to
+# StructureGraph.with_local_env_strategy to derive the graph edges.
+CRYSTALNN = local_env.CrystalNN(
+ distance_cutoffs=None, x_diff_weight=-1, porous_adjustment=False)
+
+
+def build_crystal(crystal_str, niggli=True, primitive=False):
+    """Parse a CIF string and return a canonically rebuilt structure.
+
+    Args:
+        crystal_str (str): Crystal description in CIF format.
+        niggli (bool): If True, apply ``get_reduced_structure()``.
+        primitive (bool): If True, apply ``get_primitive_structure()`` first.
+
+    Returns:
+        Structure: A new structure rebuilt from the lattice parameters and
+        fractional coordinates of the (optionally reduced) crystal.
+    """
+    parsed = Structure.from_str(crystal_str, fmt="cif")
+    reduced = parsed.get_primitive_structure() if primitive else parsed
+    if niggli:
+        reduced = reduced.get_reduced_structure()
+
+    # Rebuild from lattice parameters and fractional coordinates only.
+    rebuilt = Structure(
+        lattice=Lattice.from_parameters(*reduced.lattice.parameters),
+        species=reduced.species,
+        coords=reduced.frac_coords,
+        coords_are_cartesian=False,
+    )
+    # The match is guaranteed because CIF only stores these same quantities.
+    assert rebuilt.matches(reduced)
+    return rebuilt
+
+
+def build_crystal_graph(crystal, graph_method="crystalnn"):
+    """Convert a crystal into the arrays describing its periodic graph.
+
+    Args:
+        crystal: pymatgen structure to convert.
+        graph_method (str): "crystalnn" builds edges with ``CRYSTALNN``;
+            "none" returns empty edge arrays.
+
+    Returns:
+        tuple: ``(frac_coords, atom_types, lengths, angles, edge_indices,
+        to_jimages, num_atoms)``.
+
+    Raises:
+        NotImplementedError: For any other ``graph_method``.
+    """
+    if graph_method not in ("crystalnn", "none"):
+        raise NotImplementedError
+
+    graph = None
+    if graph_method == "crystalnn":
+        graph = StructureGraph.with_local_env_strategy(crystal, CRYSTALNN)
+
+    frac_coords = crystal.frac_coords
+    parameters = crystal.lattice.parameters
+    lengths, angles = parameters[:3], parameters[3:]
+
+    # Sanity check: the parameters must reproduce the lattice matrix.
+    assert np.allclose(crystal.lattice.matrix,
+                       lattice_params_to_matrix(*lengths, *angles))
+
+    edge_indices, to_jimages = [], []
+    if graph is not None:
+        for src, dst, image in graph.graph.edges(data="to_jimage"):
+            # Store each edge in both directions; the reverse direction uses
+            # the negated periodic image offset.
+            edge_indices.extend(([dst, src], [src, dst]))
+            to_jimages.extend((image, tuple(-shift for shift in image)))
+
+    atom_types = np.array(crystal.atomic_numbers)
+    return (frac_coords, atom_types, np.array(lengths), np.array(angles),
+            np.array(edge_indices), np.array(to_jimages), atom_types.shape[0])
+
+
+def save_data(dataset, is_train, dataset_name):
+    """Save a processed dataset split under ``./data/<dataset_name>/<is_train>``.
+
+    Args:
+        dataset: Indexable dataset whose items are 17-field sequences in the
+            order of ``data_parameters`` below. For the "train" split it must
+            also expose ``scaler`` and ``lattice_scaler`` with ``means`` and
+            ``stds``.
+        is_train (str): Split name, used as the output sub-directory.
+        dataset_name (str): Dataset name, used as the output directory.
+
+    Writes ``processed_data.npy`` and, for the "train" split only, the four
+    scaler CSV files.
+    """
+    data_parameters = ["atom_types", "dist", "angle", "idx_kj", "idx_ji",
+                       "edge_j", "edge_i", "pos", "batch", "lengths",
+                       "num_atoms", "angles", "frac_coords",
+                       "num_bonds", "num_triplets", "sbf", "y"]
+    out_dir = f"./data/{dataset_name}/{is_train}"
+
+    # Read every sample exactly once. Iterating the dataset separately for
+    # each of the 17 fields would repeat the per-item work 17 times.
+    samples = list(dataset)
+
+    processed_data = {}
+    for column, name in enumerate(data_parameters):
+        values = [sample[column] for sample in samples]
+        if name == "y":
+            # Here, y is mindspore.Tensor, while others are all numpy.array,
+            # so need to change the type first.
+            values = [value.astype(np.float32) for value in values]
+        elif name == "num_triplets":
+            # Only the total number of triplets per sample is stored.
+            values = [value.sum() for value in values]
+        processed_data[name] = values
+
+    if not os.path.exists(out_dir):
+        os.makedirs(out_dir)
+        logging.info("%s has been created", out_dir)
+    if is_train == "train":
+        # The scalers are saved alongside the training split only.
+        np.savetxt(f"{out_dir}/scaler_mean.csv",
+                   dataset.scaler.means.reshape(-1))
+        np.savetxt(f"{out_dir}/scaler_std.csv",
+                   dataset.scaler.stds.reshape(-1))
+        np.savetxt(f"{out_dir}/lattice_scaler_mean.csv",
+                   dataset.lattice_scaler.means)
+        np.savetxt(f"{out_dir}/lattice_scaler_std.csv",
+                   dataset.lattice_scaler.stds)
+    np.save(f"{out_dir}/processed_data.npy", processed_data)
+
+
+def process_one(row, niggli, primitive, graph_method, prop_list):
+    """Build the crystal and graph arrays for a single dataset row.
+
+    Args:
+        row: Mapping-like row providing "cif" and "material_id".
+        niggli (bool): Forwarded to ``build_crystal``.
+        primitive (bool): Forwarded to ``build_crystal``.
+        graph_method (str): Forwarded to ``build_crystal_graph``.
+        prop_list (list): Property names copied from the row when present.
+
+    Returns:
+        dict: "mp_id", "cif", "graph_arrays" plus the available properties.
+    """
+    cif_text = row["cif"]
+    structure = build_crystal(cif_text, niggli=niggli, primitive=primitive)
+    arrays = build_crystal_graph(structure, graph_method)
+    available = {}
+    for key in prop_list:
+        if key in row.keys():
+            available[key] = row[key]
+    record = {"mp_id": row["material_id"], "cif": cif_text, "graph_arrays": arrays}
+    record.update(available)
+    return record
+
+
+def data_preprocess(input_file, num_workers, niggli, primitive, graph_method, prop_list, num_samples):
+    """Read a CSV of crystals and process every row in parallel.
+
+    Args:
+        input_file (str): CSV file with at least "cif" and "material_id".
+        num_workers (int): Number of worker processes for ``p_umap``.
+        niggli (bool): Forwarded to ``process_one``.
+        primitive (bool): Forwarded to ``process_one``.
+        graph_method (str): Forwarded to ``process_one``.
+        prop_list (list): Property columns to keep.
+        num_samples (int): Number of leading rows to use; ``None`` or a
+            negative value (e.g. -1) uses every row.
+
+    Returns:
+        list: One result dict per row, in the order of the CSV rows.
+    """
+    df = pd.read_csv(input_file)
+    # A negative count means "all rows"; slicing with -1 would instead drop
+    # the last row.
+    if num_samples is not None and num_samples >= 0:
+        df = df[:num_samples]
+
+    total = len(df)
+    rows = [df.iloc[idx] for idx in range(total)]
+    unordered_results = p_umap(
+        process_one,
+        rows,
+        [niggli] * total,
+        [primitive] * total,
+        [graph_method] * total,
+        [prop_list] * total,
+        num_cpus=num_workers)
+
+    # p_umap returns results in completion order, so restore the CSV order
+    # by looking each row's material id up again.
+    mpid_to_results = {result["mp_id"]: result for result in unordered_results}
+    ordered_results = [mpid_to_results[row["material_id"]] for row in rows]
+
+    return ordered_results
+
+
+def add_scaled_lattice_prop(data_list, lattice_scale_method):
+    """Attach a "scaled_lattice" array to every entry of ``data_list`` in place.
+
+    Args:
+        data_list (list): Dicts holding a "graph_arrays" sequence with the
+            lengths at index 2, the angles at index 3 and the atom count last.
+        lattice_scale_method (str): With "scale_length" the lengths are
+            divided by the cube root of the atom count; any other value keeps
+            them unchanged.
+    """
+    scale_by_atoms = lattice_scale_method == "scale_length"
+    for entry in data_list:
+        arrays = entry["graph_arrays"]
+        # the indexes are brittle if more objects are returned
+        lengths, angles, num_atoms = arrays[2], arrays[3], arrays[-1]
+        assert lengths.shape[0] == angles.shape[0] == 3
+        assert isinstance(num_atoms, int)
+
+        scaled = lengths / float(num_atoms) ** (1 / 3) if scale_by_atoms else lengths
+        entry["scaled_lattice"] = np.concatenate([scaled, angles])
+
+
+def create_dataset(args):
+    """Create and save the train, val and test splits of a dataset.
+
+    Args:
+        args: Parsed arguments providing ``dataset``, ``num_samples_train``,
+            ``num_samples_val`` and ``num_samples_test``.
+
+    The property and lattice scalers are fitted on the training split and
+    reused for the validation and test splits.
+    """
+    config_data = load_yaml_config_from_path(f"./conf/data/{args.dataset}.yaml")
+    config_path = "./conf/configs.yaml"
+    # Settings shared by all three splits, in CreateDataset's argument order.
+    # config_path must be passed for every split; leaving it out would shift
+    # the sample count into the config_path parameter.
+    shared = (
+        config_data.get("prop"),
+        config_data.get("niggli"),
+        config_data.get("primitive"),
+        config_data.get("graph_method"),
+        config_data.get("preprocess_workers"),
+        config_data.get("lattice_scale_method"),
+        config_path,
+    )
+
+    train_dataset = CreateDataset(
+        "Formation energy train", f"./data/{args.dataset}/train.csv",
+        *shared, args.num_samples_train)
+    lattice_scaler = get_scaler_from_data_list(
+        train_dataset.cached_data,
+        key="scaled_lattice")
+    scaler = get_scaler_from_data_list(
+        train_dataset.cached_data,
+        key=train_dataset.prop)
+    train_dataset.lattice_scaler = lattice_scaler
+    train_dataset.scaler = scaler
+    save_data(train_dataset, "train", args.dataset)
+
+    for split, num_samples in (("val", args.num_samples_val),
+                               ("test", args.num_samples_test)):
+        split_dataset = CreateDataset(
+            f"Formation energy {split}", f"./data/{args.dataset}/{split}.csv",
+            *shared, num_samples)
+        split_dataset.lattice_scaler = lattice_scaler
+        split_dataset.scaler = scaler
+        save_data(split_dataset, split, args.dataset)
+
+
+def main(args):
+ """Create the dataset described by the parsed command-line arguments."""
+ create_dataset(args)
+
+
+# Command-line entry point: parse the dataset name and per-split sample counts.
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ # Dataset name; selects ./conf/data/<dataset>.yaml and ./data/<dataset>/.
+ parser.add_argument("--dataset", default="perov_5")
+ # Number of samples taken from the train, val and test CSV files.
+ parser.add_argument("--num_samples_train", default=300, type=int)
+ parser.add_argument("--num_samples_val", default=300, type=int)
+ parser.add_argument("--num_samples_test", default=300, type=int)
+ main_args = parser.parse_args()
+ main(main_args)
diff --git a/MindChemistry/applications/cdvae/evaluation.py b/MindChemistry/applications/cdvae/evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..4361fb5e4f593f3f170c6af33e5f2d2022a15910
--- /dev/null
+++ b/MindChemistry/applications/cdvae/evaluation.py
@@ -0,0 +1,192 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Evaluation
+"""
+
+import os
+import time
+import logging
+from types import SimpleNamespace
+import argparse
+import mindspore as ms
+import numpy as np
+
+from mindchemistry.cell.cdvae import CDVAE
+from src.dataloader import DataLoaderBaseCDVAE
+from src.evaluate_utils import (get_reconstructon_res, get_generation_res,
+ get_optimization_res)
+from train import get_scaler
+
+
+def task_reconstruction(model, ld_kwargs, graph_dataset, recon_args):
+    """Run the reconstruction task and save predictions and ground truth.
+
+    Predictions are written to ./eval_result/eval_recon.npy, or to
+    ./eval_result/eval_recon_<label>.npy when recon_args.label is non-empty,
+    with recon_args itself stored under "eval_setting"; the ground truth is
+    written to ./eval_result/gt_recon.npy. Both files hold a dict passed to
+    np.save.
+    """
+    logging.info("Evaluate model on the reconstruction task.")
+    # The helper returns five predicted fields followed by the five
+    # matching ground-truth fields.
+    outputs = get_reconstructon_res(
+        graph_dataset, model, ld_kwargs, recon_args.num_evals,
+        recon_args.force_num_atoms, recon_args.force_atom_types)
+    (pred_coords, pred_natoms, pred_types, pred_lengths, pred_angles,
+     true_coords, true_natoms, true_types, true_lengths, true_angles) = outputs
+
+    suffix = "" if recon_args.label == "" else f"_{recon_args.label}"
+    keys = ("frac_coords", "num_atoms", "atom_types", "lengths", "angles")
+
+    result = {"eval_setting": recon_args}
+    result.update(zip(
+        keys,
+        (pred_coords, pred_natoms, pred_types, pred_lengths, pred_angles)))
+    # save result as numpy
+    np.save(f"./eval_result/eval_recon{suffix}.npy", result)
+
+    groundtruth = dict(zip(
+        keys,
+        (true_coords, true_natoms, true_types, true_lengths, true_angles)))
+    # save ground truth as numpy
+    np.save("./eval_result/gt_recon.npy", groundtruth)
+
+
+def task_generation(model, ld_kwargs, gen_args):
+    """Run the generation task and save the sampled crystals.
+
+    The result dict (sampled fields, the two trajectory stacks and gen_args
+    under "eval_setting") is written with np.save to
+    ./eval_result/eval_gen.npy, or ./eval_result/eval_gen_<label>.npy when
+    gen_args.label is non-empty.
+    """
+    logging.info("Evaluate model on the generation task.")
+
+    sampled = get_generation_res(
+        model, ld_kwargs, gen_args.num_batches_to_samples, gen_args.num_evals,
+        gen_args.batch_size, gen_args.down_sample_traj_step)
+    # Unpack explicitly so a change in the helper's arity fails loudly here.
+    (coords, natoms, types, lengths, angles, coords_traj, types_traj) = sampled
+
+    suffix = "" if gen_args.label == "" else f"_{gen_args.label}"
+    result = {
+        "eval_setting": gen_args,
+        "frac_coords": coords, "num_atoms": natoms, "atom_types": types,
+        "lengths": lengths, "angles": angles,
+        "all_frac_coords_stack": coords_traj,
+        "all_atom_types_stack": types_traj,
+    }
+    # save result as numpy
+    np.save(f"./eval_result/eval_gen{suffix}.npy", result)
+
+
+def task_optimization(model, ld_kwargs, graph_dataset, opt_args):
+    """Run the property-optimization task and save the optimized crystals.
+
+    graph_dataset is used as the starting point only when
+    opt_args.start_from == "data"; otherwise None is passed to the helper.
+    Output goes to ./eval_result/eval_opt.npy, or eval_opt_<label>.npy when
+    opt_args.label is non-empty.
+    """
+    logging.info("Evaluate model on the property optimization task.")
+    loader = graph_dataset if opt_args.start_from == "data" else None
+    optimized_crystals = get_optimization_res(model, ld_kwargs, loader)
+    suffix = "" if opt_args.label == "" else f"_{opt_args.label}"
+    # save result as numpy
+    np.save(f"./eval_result/eval_opt{suffix}.npy", optimized_crystals)
+
+
+def main(args):
+    """Load a trained CDVAE checkpoint and run the requested evaluation tasks.
+
+    Args:
+        args: namespace produced by get_args(); args.tasks selects any of
+            "recon", "gen" and "opt", whose outputs land in ./eval_result/.
+    """
+    # A bare checkpoint file name has an empty dirname and os.makedirs("")
+    # raises FileNotFoundError, so empty paths are skipped.
+    folder_path = os.path.dirname(args.model_path)
+    for path in (folder_path, "./eval_result/"):
+        if path and not os.path.exists(path):
+            os.makedirs(path)
+            logging.info("%s has been created", path)
+    config_path = "./conf/configs.yaml"
+    data_config_path = f"./conf/data/{args.dataset}.yaml"
+    model = CDVAE(config_path, data_config_path)
+    # load mindspore check point
+    param_dict = ms.load_checkpoint(args.model_path)
+    param_not_load, _ = ms.load_param_into_net(model, param_dict)
+    logging.info("parameter not load: %s.", param_not_load)
+    model.set_train(False)
+
+    ld_kwargs = SimpleNamespace(
+        n_step_each=args.n_step_each, step_lr=args.step_lr,
+        min_sigma=args.min_sigma, save_traj=args.save_traj,
+        disable_bar=args.disable_bar)
+    # load dataset
+    graph_dataset = DataLoaderBaseCDVAE(
+        args.batch_size, args.dataset, shuffle_dataset=False, mode="test")
+    # load scaler
+    model.lattice_scaler, model.scaler = get_scaler(args)
+
+    start_time_eval = time.time()
+    if "recon" in args.tasks:
+        task_reconstruction(model, ld_kwargs, graph_dataset, args)
+    if "gen" in args.tasks:
+        task_generation(model, ld_kwargs, args)
+    if "opt" in args.tasks:
+        task_optimization(model, ld_kwargs, graph_dataset, args)
+    logging.info("end evaluation, time: %f s.", time.time() - start_time_eval)
+
+def get_args():
+    """Parse the command-line arguments used for evaluation."""
+    def str2bool(text):
+        # type=bool would turn any non-empty string, even "False", into True.
+        return str(text).strip().lower() in ("1", "true", "t", "yes", "y")
+
+    parser = argparse.ArgumentParser()
+    add = parser.add_argument
+    add("--device_target", default="Ascend", help="device target")
+    add("--device_id", default=7, type=int, help="device id")
+    add("--model_path", default="./loss/loss.ckpt", help="path to checkpoint")
+    add("--dataset", default="perov_5", help="name of dataset")
+    add("--tasks", nargs="+", default=["gen"],
+        help="tasks to evaluate, choose from 'recon, gen, opt'")
+    add("--n_step_each", default=1, type=int, help="number of steps in diffusion")
+    add("--step_lr", default=1e-3, type=float, help="learning rate")
+    add("--min_sigma", default=0, type=float, help="minimum sigma")
+    add("--save_traj", default=False, type=str2bool,
+        help="whether to save trajectory")
+    add("--disable_bar", default=False, type=str2bool,
+        help="disable progress bar")
+    add("--num_evals", default=1, type=int,
+        help="number of evaluations returned for each task")
+    add("--num_batches_to_samples", default=1, type=int,
+        help="number of batches to sample")
+    add("--start_from", default="data", type=str, help="start from data or random")
+    add("--batch_size", default=128, type=int, help="batch size")
+    add("--force_num_atoms", action="store_true", help="fixed num atoms or not")
+    add("--force_atom_types", action="store_true", help="fixed atom types or not")
+    add("--down_sample_traj_step", default=10, type=int, help="down sample")
+    add("--label", default="", help="label for output file")
+    return parser.parse_args()
+
+if __name__ == "__main__":
+    main_args = get_args()
+    logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+    ms.context.set_context(device_target=main_args.device_target,
+                           device_id=main_args.device_id, mode=1)  # NOTE(review): mode=1 is presumably PyNative mode - confirm
+    main(main_args)
diff --git a/MindChemistry/applications/cdvae/images/illustrative.png b/MindChemistry/applications/cdvae/images/illustrative.png
new file mode 100644
index 0000000000000000000000000000000000000000..a70858f7f67a881ba63606c03ec3eba13fb7ef1a
Binary files /dev/null and b/MindChemistry/applications/cdvae/images/illustrative.png differ
diff --git a/MindChemistry/applications/cdvae/mp_20_process.py b/MindChemistry/applications/cdvae/mp_20_process.py
new file mode 100644
index 0000000000000000000000000000000000000000..99dbef741c0458037ce62868ccc263b9f6c9e03d
--- /dev/null
+++ b/MindChemistry/applications/cdvae/mp_20_process.py
@@ -0,0 +1,72 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Script used to generate the mp_20 dataset from raw data."""
+import os
+import logging
+import argparse
+import pandas as pd
+from pymatgen.core.structure import Structure
+from pymatgen.core.lattice import Lattice
+from pymatgen.io.cif import CifWriter
+
+
+def mp_20_process():
+    """Convert the raw mp_20 JSON file into train/val/test CSV files.
+
+    Reads the module-level ``args``: init_path is the input JSON, data_path
+    the output folder. Rows keep their file order and are split 60/20/20
+    into train.csv, val.csv and test.csv.
+    """
+    if not os.path.exists(args.data_path):
+        os.makedirs(args.data_path)
+        logging.info("%s has been created", args.data_path)
+
+    # read json file and keep only the selected columns
+    columns = ["id", "formation_energy_per_atom", "band_gap", "pretty_formula",
+               "e_above_hull", "elements", "atoms", "spacegroup_number"]
+    df = pd.read_json(args.init_path)[columns]
+
+    def to_cif(atoms):
+        # NOTE(review): the second Lattice argument is presumably pbc; confirm
+        # that (False, False, False) is intended for crystal structures.
+        lattice = Lattice(atoms["lattice_mat"], (False, False, False))
+        structure = Structure(lattice, atoms["elements"], atoms["coords"])
+        return str(CifWriter(structure))
+
+    struct_list = [to_cif(atoms) for atoms in df["atoms"]]
+    element_list = [atoms["elements"] for atoms in df["atoms"]]
+    # add the cif column at position 7, then drop the raw "atoms" column
+    df.insert(7, "cif", struct_list)
+    df = df.drop("atoms", axis=1)
+    df["elements"] = element_list
+
+    # split the dataset to train:val:test = 6:2:2 and save to csv files
+    first_cut, second_cut = int(0.6 * len(df)), int(0.8 * len(df))
+    parts = (("train", df.iloc[:first_cut]), ("val", df.iloc[first_cut:second_cut]),
+             ("test", df.iloc[second_cut:]))
+    for name, part in parts:
+        part.to_csv(f"{args.data_path}/{name}.csv", index=False)
+    logging.info("Finished!")
+
+
+if __name__ == "__main__":
+    logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--init_path", default="./data/mp_20.json",
+                        help="path to the initial dataset file")
+    parser.add_argument("--data_path", default="./data/mp_20",
+                        help="path to save the processed dataset")
+    args = parser.parse_args()  # module-level name: mp_20_process() reads it as a global
+    mp_20_process()
diff --git a/MindChemistry/applications/cdvae/requirements.txt b/MindChemistry/applications/cdvae/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..8cac7af853e1281c606c9a5aa325e6453394feee
--- /dev/null
+++ b/MindChemistry/applications/cdvae/requirements.txt
@@ -0,0 +1,12 @@
+matminer==0.7.3
+mindchemistry_ascend==0.1.0
+mindspore==2.3.0.20240411
+numpy==1.26.4
+p_tqdm==1.4.0
+pandas==2.2.2
+pymatgen==2023.8.10
+sciai==0.1.0
+scipy==1.13.1
+SMACT==2.2.1
+sympy==1.12
+tqdm==4.66.2
diff --git a/MindChemistry/applications/cdvae/src/__init__.py b/MindChemistry/applications/cdvae/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0be6d3fcd8c0be618bfa801a8585477fd5326443
--- /dev/null
+++ b/MindChemistry/applications/cdvae/src/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""src"""
diff --git a/MindChemistry/applications/cdvae/src/dataloader.py b/MindChemistry/applications/cdvae/src/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..f02b257767810b8088405f449d7996ca92f5ba0a
--- /dev/null
+++ b/MindChemistry/applications/cdvae/src/dataloader.py
@@ -0,0 +1,233 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""dataloader
+"""
+import random
+import numpy as np
+from mindspore import Tensor
+import mindspore as ms
+
+
+class DataLoaderBaseCDVAE:
+    r"""
+    Mini-batch iterator over a preprocessed CDVAE dataset.
+
+    The constructor loads ``./data/<dataset>/<mode>/processed_data.npy``, a
+    pickled dict of per-crystal sequences. Iterating yields one tuple per
+    batch: the graphs of ``batch_size`` consecutive crystals are concatenated
+    and their indices shifted so that they address the merged arrays. Samples
+    that do not fill a complete final batch are not yielded.
+    """
+
+    # Per-crystal sequences read from the file; ``shuffle_action`` permutes
+    # all of them with the same index order.
+    _FIELDS = ("atom_types", "dist", "angle", "idx_kj", "idx_ji", "edge_j",
+               "edge_i", "pos", "batch", "lengths", "num_atoms", "angles",
+               "frac_coords", "y", "num_bonds", "num_triplets", "sbf")
+
+    def __init__(self,
+                 batch_size,
+                 dataset,
+                 shuffle_dataset=True,
+                 mode="train"):
+        """
+        Args:
+            batch_size (int): number of crystals per yielded batch.
+            dataset (str): dataset folder name under ``./data``.
+            shuffle_dataset (bool): reshuffle the samples on every ``iter()``.
+            mode (str): sub-folder to read, e.g. "train" or "test".
+        """
+        # NOTE: allow_pickle=True unpickles arbitrary objects, so only load
+        # files written by this project's own preprocessing.
+        dataset = np.load(
+            f"./data/{dataset}/{mode}/processed_data.npy", allow_pickle=True).item()
+        for name in self._FIELDS:
+            setattr(self, name, dataset[name])
+        self.edge_attr = self.edge_j
+        self.batch_size = batch_size
+        self.index = 0
+        self.step = 0
+        self.shuffle_dataset = shuffle_dataset
+        # Snapshot of the field objects as loaded; not refreshed on shuffle.
+        self.feature = [getattr(self, name) for name in self._FIELDS]
+
+        # can be customized to specific dataset
+        self.label = self.num_atoms
+        self.node_attr = self.atom_types
+        self.sample_num = len(self.node_attr)
+
+        # Smallest start index from which a full batch no longer fits.
+        self.max_start_sample = self.sample_num - self.batch_size + 1
+
+    def get_dataset_size(self):
+        """Return the number of samples (crystals), not the number of batches."""
+        return self.sample_num
+
+    def __iter__(self):
+        """Yield every full batch as the tuple built by ``np2tensor``.
+
+        The read position is reset first; the samples are reshuffled when
+        ``shuffle_dataset`` is true.
+        """
+        if self.shuffle_dataset:
+            self.shuffle()
+        else:
+            self.restart()
+        size = self.batch_size
+        while self.index < self.max_start_sample:
+            # can be customized to generate different attributes or labels according to specific dataset
+            num_bonds_step = self.gen_global_attr(
+                self.num_bonds, size).astype(np.int32)
+            # reshape(-1) rather than squeeze(): the result stays 1-D for
+            # batch_size == 1, as the cumsum/repeat/shape[0] uses below need.
+            num_atoms_step = self.gen_global_attr(
+                self.num_atoms, size).reshape(-1).astype(np.int32)
+            num_triplets_step = self.gen_global_attr(
+                self.num_triplets, size).astype(np.int32)
+            atom_types_step = self.gen_node_attr(
+                self.atom_types, size).astype(np.int32)
+            dist_step = self.gen_edge_attr(self.dist, size).astype(np.float32)
+            angle_step = self.gen_triplet_attr(
+                self.angle, size).astype(np.float32)
+            # Triplet indices point at edges: shift by preceding bond counts.
+            idx_kj_step = self.add_index_offset(
+                self.gen_triplet_attr(self.idx_kj, size),
+                num_bonds_step, num_triplets_step).astype(np.int32)
+            idx_ji_step = self.add_index_offset(
+                self.gen_triplet_attr(self.idx_ji, size),
+                num_bonds_step, num_triplets_step).astype(np.int32)
+            # Edge endpoints point at atoms: shift by preceding atom counts.
+            edge_j_step = self.add_index_offset(
+                self.gen_edge_attr(self.edge_j, size),
+                num_atoms_step, num_bonds_step).astype(np.int32)
+            # edge_i must be read from self.edge_i (not self.edge_j) so the
+            # two endpoint arrays of an edge are not identical.
+            edge_i_step = self.add_index_offset(
+                self.gen_edge_attr(self.edge_i, size),
+                num_atoms_step, num_bonds_step).astype(np.int32)
+            batch_step = np.repeat(
+                np.arange(num_atoms_step.shape[0]), num_atoms_step,
+                axis=0).astype(np.int32)
+            lengths_step = self.gen_crystal_attr(
+                self.lengths, size).astype(np.float32)
+            angles_step = self.gen_crystal_attr(
+                self.angles, size).astype(np.float32)
+            frac_coords_step = self.gen_node_attr(
+                self.frac_coords, size).astype(np.float32)
+            y_step = self.gen_global_attr(self.y, size).astype(np.float32)
+            sbf_step = self.gen_triplet_attr(self.sbf, size).astype(np.float32)
+            total_atoms = num_atoms_step.sum().item()
+            self.add_step_index(size)
+
+            ############## change to mindspore Tensor #############
+            yield self.np2tensor(atom_types_step, dist_step, angle_step, idx_kj_step,
+                                 idx_ji_step, edge_j_step, edge_i_step, batch_step,
+                                 lengths_step, num_atoms_step, angles_step, frac_coords_step,
+                                 y_step, self.batch_size, sbf_step, total_atoms)
+
+    def np2tensor(self, atom_types_step, dist_step, angle_step, idx_kj_step,
+                  idx_ji_step, edge_j_step, edge_i_step, batch_step,
+                  lengths_step, num_atoms_step, angles_step, frac_coords_step,
+                  y_step, batch_size, sbf_step, total_atoms):
+        """Wrap the numpy batch arrays in MindSpore tensors.
+
+        Index-like arrays become int32 tensors and the rest float32;
+        ``batch_size`` and ``total_atoms`` are returned unchanged.
+        """
+        def as_int(array):
+            return Tensor(array, ms.int32)
+
+        def as_float(array):
+            return Tensor(array, ms.float32)
+
+        return (as_int(atom_types_step), as_float(dist_step),
+                as_float(angle_step), as_int(idx_kj_step),
+                as_int(idx_ji_step), as_int(edge_j_step),
+                as_int(edge_i_step), as_int(batch_step),
+                as_float(lengths_step), as_int(num_atoms_step),
+                as_float(angles_step), as_float(frac_coords_step),
+                as_float(y_step), batch_size, as_float(sbf_step), total_atoms)
+
+    def add_index_offset(self, edge_index, num_atoms, num_bonds):
+        """Shift per-crystal local indices to positions in the merged batch.
+
+        Args:
+            edge_index: concatenated local indices; modified in place.
+            num_atoms: per-crystal size of the index space being addressed.
+            num_bonds: per-crystal number of entries in ``edge_index``.
+
+        Returns:
+            ``edge_index`` with each crystal's entries offset by the total
+            size of all preceding crystals.
+        """
+        index_offset = np.cumsum(num_atoms, axis=0) - num_atoms
+        edge_index += np.repeat(index_offset, num_bonds)
+        return edge_index
+
+    def shuffle_index(self):
+        """Return the sample indices in a random order."""
+        indices = list(range(self.sample_num))
+        random.shuffle(indices)
+        return indices
+
+    def shuffle(self):
+        """Shuffle the samples and rewind the read position."""
+        self.shuffle_action()
+        self.restart()
+
+    def shuffle_action(self):
+        """Apply one random permutation to every per-crystal field."""
+        indices = self.shuffle_index()
+        for name in self._FIELDS:
+            values = getattr(self, name)
+            setattr(self, name, [values[i] for i in indices])
+
+    def restart(self):
+        """Rewind the read position and step counter to the start."""
+        self.step = 0
+        self.index = 0
+
+    def _take(self, values, batch_size):
+        """Return the per-crystal entries of the current batch window."""
+        return values[self.index:self.index + batch_size]
+
+    def gen_node_attr(self, node_attr, batch_size):
+        """Concatenate the per-atom arrays of the current batch along axis 0."""
+        return np.concatenate(self._take(node_attr, batch_size), 0)
+
+    def gen_edge_attr(self, edge_attr, batch_size):
+        """Concatenate the per-edge arrays of the current batch along axis 0."""
+        return np.concatenate(self._take(edge_attr, batch_size), 0)
+
+    def gen_global_attr(self, global_attr, batch_size):
+        """Stack one per-crystal value per sample into a new leading axis."""
+        return np.stack(self._take(global_attr, batch_size), 0)
+
+    def gen_crystal_attr(self, global_attr, batch_size):
+        """Stack per-crystal arrays, then drop every size-1 axis.
+
+        NOTE(review): with batch_size == 1 squeeze() also removes the batch
+        axis; confirm whether downstream code relies on that shape.
+        """
+        return np.stack(self._take(global_attr, batch_size), 0).squeeze()
+
+    def gen_triplet_attr(self, triplet_attr, batch_size):
+        """Concatenate the per-triplet arrays of the current batch along axis 0."""
+        return np.concatenate(self._take(triplet_attr, batch_size), 0)
+
+    def add_step_index(self, batch_size):
+        """Advance the read position by one batch and count the step."""
+        self.index = self.index + batch_size
+        self.step += 1
diff --git a/MindChemistry/applications/cdvae/src/evaluate_utils.py b/MindChemistry/applications/cdvae/src/evaluate_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..236ececa4c6eb8f3f5c0d28bfef7de6b5d9b3328
--- /dev/null
+++ b/MindChemistry/applications/cdvae/src/evaluate_utils.py
@@ -0,0 +1,191 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""evaluate_utils"""
+import logging
+import mindspore as ms
+import mindspore.mint as mint
+from mindspore.nn import Adam
+from tqdm import tqdm
+
+logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)  # NOTE: runs at import time, as a module-level side effect
+
+
+def get_reconstructon_res(loader, model, ld_kwargs, num_evals,
+                          force_num_atoms=False, force_atom_types=False):
+    """
+    Reconstruct the crystals yielded by ``loader``.
+
+    Each batch is encoded once to a latent ``z``; ``model.langevin_dynamics``
+    is then run ``num_evals`` times on that ``z``.
+
+    Args:
+        loader: iterable of batch tuples laid out as unpacked below.
+        model: object providing ``encode`` and ``langevin_dynamics``.
+        ld_kwargs: settings forwarded to ``model.langevin_dynamics``.
+        num_evals (int): number of decodings per batch.
+        force_num_atoms (bool): pass the true atom counts to the decoder.
+        force_atom_types (bool): pass the true atom types to the decoder.
+
+    Returns:
+        Tuple of ten lists with one entry per batch: predicted frac_coords,
+        num_atoms, atom_types, lengths, angles (each entry a list of
+        ``num_evals`` numpy arrays) followed by the ground-truth arrays in
+        the same field order.
+    """
+    fields = ("frac_coords", "num_atoms", "atom_types", "lengths", "angles")
+    predicted = {name: [] for name in fields}
+    truth = {name: [] for name in fields}
+    for idx, data in enumerate(loader):
+        logging.info("Reconstructing %d", int(idx * data[-3]))
+
+        # only sample one z, multiple evals for stochasticity in langevin dynamics
+        (atom_types, dist, _, idx_kj, idx_ji,
+         edge_j, edge_i, batch, lengths, num_atoms,
+         angles, frac_coords, _, batch_size, sbf,
+         total_atoms) = data
+        batch_truth = (frac_coords, num_atoms, atom_types, lengths, angles)
+        for name, tensor in zip(fields, batch_truth):
+            truth[name].append(tensor.asnumpy())
+        _, _, z = model.encode(atom_types, dist,
+                               idx_kj, idx_ji, edge_j, edge_i,
+                               batch, total_atoms, batch_size, sbf)
+        # Optionally condition the decoder on the true counts / types.
+        gt_num_atoms = num_atoms if force_num_atoms else None
+        gt_atom_types = atom_types if force_atom_types else None
+        batch_pred = {name: [] for name in fields}
+        for _ in range(num_evals):
+            outputs = model.langevin_dynamics(
+                z, ld_kwargs, batch_size, total_atoms, gt_num_atoms, gt_atom_types)
+            # collect sampled crystals in this batch.
+            for name in fields:
+                batch_pred[name].append(outputs[name].asnumpy())
+        # collect sampled crystals for this z.
+        for name in fields:
+            predicted[name].append(batch_pred[name])
+
+    # Predicted lists first, then ground truth, both in ``fields`` order.
+    return (*(predicted[name] for name in fields),
+            *(truth[name] for name in fields))
+
+
+def get_generation_res(model, ld_kwargs, num_batches_to_sample, num_samples_per_z,
+                       batch_size=512, down_sample_traj_step=1):
+    """
+    Generate new crystals from randomly sampled latent vectors.
+
+    For each of ``num_batches_to_sample`` rounds one ``z`` of shape
+    (batch_size, model.hidden_dim) is drawn with ``ms.ops.randn`` and decoded
+    ``num_samples_per_z`` times by ``model.langevin_dynamics``.
+
+    Args:
+        model: object providing ``hidden_dim`` and ``langevin_dynamics``.
+        ld_kwargs: decoder settings; ``save_traj`` enables trajectory output.
+        num_batches_to_sample (int): number of latent batches to draw.
+        num_samples_per_z (int): decodings per latent batch.
+        batch_size (int): number of latent vectors per batch.
+        down_sample_traj_step (int): stride applied to saved trajectories.
+
+    Returns:
+        Tuple of seven lists with one entry per latent batch: frac_coords,
+        num_atoms, atom_types, lengths, angles, and the two trajectory
+        stacks (left empty unless ``ld_kwargs.save_traj`` is set).
+    """
+    fields = ("frac_coords", "num_atoms", "atom_types", "lengths", "angles")
+    traj_fields = ("all_frac_coords", "all_atom_types")
+    results = {name: [] for name in fields}
+    trajectories = {name: [] for name in traj_fields}
+
+    for _ in range(num_batches_to_sample):
+        per_z = {name: [] for name in fields + traj_fields}
+        z = ms.ops.randn(batch_size, model.hidden_dim)
+
+        for _ in range(num_samples_per_z):
+            samples = model.langevin_dynamics(z, ld_kwargs, batch_size)
+
+            # collect sampled crystals in this batch.
+            for name in fields:
+                per_z[name].append(samples[name].asnumpy())
+            if ld_kwargs.save_traj:
+                for name in traj_fields:
+                    per_z[name].append(
+                        samples[name][::down_sample_traj_step].asnumpy())
+
+        # collect sampled crystals for this z.
+        for name in fields:
+            results[name].append(per_z[name])
+        if ld_kwargs.save_traj:
+            for name in traj_fields:
+                trajectories[name].append(per_z[name])
+
+    return (*(results[name] for name in fields),
+            *(trajectories[name] for name in traj_fields))
+
+
+def get_optimization_res(model, ld_kwargs, data_loader,
+                         num_starting_points=128, num_gradient_steps=5000,
+                         lr=1e-3, num_saved_crys=10):
+    """
+    Optimize latent vectors against ``model.fc_property`` and decode them.
+
+    ``z`` starts from the encoded first ``data_loader`` batch (or from noise
+    when it is None) and is decoded with batch size ``num_starting_points``;
+    returns a dict of numpy arrays holding the saved snapshots.
+    """
+    model.set_train(True)
+    if data_loader is not None:
+        (atom_types, dist, _, idx_kj, idx_ji,
+         edge_j, edge_i, batch, _, num_atoms,
+         _, _, _, batch_size, sbf,
+         total_atoms) = next(iter(data_loader))
+        _, _, z = model.encode(atom_types, dist, idx_kj, idx_ji, edge_j, edge_i,
+                               batch, total_atoms, batch_size, sbf)
+        z = mint.narrow(z, 0, 0, num_starting_points)
+        kept_atoms = mint.sum(mint.narrow(
+            num_atoms, 0, 0, num_starting_points)).item()
+        decode_extra = (kept_atoms,)
+    else:
+        # Random start: no loader batch and hence no atom counts to pass on.
+        # NOTE(review): get_generation_res reads model.hidden_dim instead.
+        z = mint.randn(num_starting_points, model.hparams.hidden_dim)
+        decode_extra = ()
+    z = ms.Parameter(z, requires_grad=True)
+    opt = Adam([z], learning_rate=lr)
+    freeze_model(model)
+
+    def forward_fn(data):
+        return model.fc_property(data)
+    grad_fn = ms.value_and_grad(forward_fn, None, opt.parameters)
+
+    def train_step(data):
+        loss, grads = grad_fn(data)
+        opt(grads)
+        return loss
+    # max(1, ...) avoids a zero divisor/modulus for tiny step or snapshot counts.
+    interval = max(1, num_gradient_steps // max(1, num_saved_crys - 1))
+    all_crystals = []
+    for i in tqdm(range(num_gradient_steps)):
+        loss = mint.mean(train_step(z))
+        logging.info("Task opt step: %d, loss: %f", i, loss)
+        if i % interval == 0 or i == (num_gradient_steps - 1):
+            all_crystals.append(model.langevin_dynamics(
+                z, ld_kwargs, num_starting_points, *decode_extra))
+    return {k: mint.cat([d[k] for d in all_crystals]).unsqueeze(0).asnumpy() for k in
+            ["frac_coords", "atom_types", "num_atoms", "lengths", "angles"]}
+
+
+def freeze_model(model):
+    """Set requires_grad to False on every model parameter so that only z is optimized."""
+    for param in model.get_parameters():
+        param.requires_grad = False
diff --git a/MindChemistry/applications/cdvae/src/metrics_utils.py b/MindChemistry/applications/cdvae/src/metrics_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf179ec1053b7e2776e2ff55b7d9247f4ed30cf9
--- /dev/null
+++ b/MindChemistry/applications/cdvae/src/metrics_utils.py
@@ -0,0 +1,191 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""utils for compute metrics"""
+import itertools
+import numpy as np
+
+from scipy.spatial.distance import pdist
+from scipy.spatial.distance import cdist
+
+import smact
+from smact.screening import pauling_test
+
+from create_dataset import chemical_symbols
+
+
+def get_crystals_list(
+        frac_coords, atom_types, lengths, angles, num_atoms):
+    """
+    Split batched crystal arrays into one dict per crystal.
+
+    args:
+        frac_coords: (num_atoms, 3)
+        atom_types: (num_atoms)
+        lengths: (num_crystals)
+        angles: (num_crystals)
+        num_atoms: (num_crystals)
+
+    returns:
+        list of dicts with keys "frac_coords", "atom_types", "lengths" and
+        "angles"; atoms are assigned to the crystals in order, num_atoms[i]
+        of them to crystal i.
+    """
+    assert frac_coords.shape[0] == atom_types.shape[0] == num_atoms.sum()
+    assert lengths.shape[0] == angles.shape[0] == num_atoms.shape[0]
+
+    crystals = []
+    stop = 0
+    for batch_idx, num_atom in enumerate(num_atoms.tolist()):
+        # Atoms are stored back to back: this crystal owns [start, stop).
+        start, stop = stop, stop + num_atom
+        crystals.append({
+            "frac_coords": frac_coords[start:stop], "atom_types": atom_types[start:stop],
+            "lengths": lengths[batch_idx], "angles": angles[batch_idx]})
+    return crystals
+
+
+def smact_validity(comp, count,
+                   use_pauling_test=True,
+                   include_alloys=True):
+    """Check a composition with SMACT charge-neutrality screening.
+
+    ``comp`` holds indices into chemical_symbols and ``count`` the matching
+    stoichiometric counts. Single-element inputs, and all-metal inputs when
+    ``include_alloys`` is set, are accepted outright; otherwise at least one
+    oxidation-state combination must be neutral and pass the (optional)
+    Pauling electronegativity test.
+    """
+    elem_symbols = tuple(chemical_symbols[elem] for elem in comp)
+    smact_elems = list(smact.element_dictionary(elem_symbols).values())
+    electronegs = [elem.pauling_eneg for elem in smact_elems]
+    ox_combos = [elem.oxidation_states for elem in smact_elems]
+    if len(set(elem_symbols)) == 1:
+        return True
+    if include_alloys and all(sym in smact.metals for sym in elem_symbols):
+        return True
+
+    threshold = np.max(count)
+    stoichs = [(c,) for c in count]
+    compositions = set()
+    for ox_states in itertools.product(*ox_combos):
+        # Test for charge balance
+        cn_e, cn_r = smact.neutral_ratios(
+            ox_states, stoichs=stoichs, threshold=threshold)
+        if not cn_e:
+            continue
+        # Electronegativity test
+        electroneg_pass = True
+        if use_pauling_test:
+            try:
+                electroneg_pass = pauling_test(ox_states, electronegs)
+            except TypeError:
+                # if no electronegativity data, assume it is okay
+                electroneg_pass = True
+        if electroneg_pass:
+            compositions.update((elem_symbols, ratio) for ratio in cn_r)
+    # Valid when at least one (elements, ratio) pairing survived.
+    return bool(compositions)
+
+
+def structure_validity(crystal, cutoff=0.5):
+    """Return True if no two atoms are too close and the cell is not tiny.
+
+    The structure is rejected when any off-diagonal entry of
+    ``crystal.distance_matrix`` is below ``cutoff`` or when
+    ``crystal.volume`` is below 0.1.
+    """
+    distances = crystal.distance_matrix
+    # Pad diagonal with a large number so self-distances never trigger.
+    padding = np.diag(np.ones(distances.shape[0]) * (cutoff + 10.))
+    padded = distances + padding
+    return not (padded.min() < cutoff or crystal.volume < 0.1)
+
+
+def get_fp_pdist(fp_array):
+    """Return the mean of the pairwise distances (scipy pdist) between fingerprints."""
+    # Lists are converted to an array; any other input is handed to pdist as is.
+    points = np.array(fp_array) if isinstance(fp_array, list) else fp_array
+    return pdist(points).mean()
+
+
+def filter_fps(struc_fps, comp_fps):
+    """Keep only the pairs whose structure and composition fingerprints both exist.
+
+    Returns the surviving fingerprints as two parallel lists.
+    """
+    assert len(struc_fps) == len(comp_fps)
+    kept = [(s_fp, c_fp) for s_fp, c_fp in zip(struc_fps, comp_fps)
+            if s_fp is not None and c_fp is not None]
+    # Build the lists explicitly so an empty result is still two lists.
+    return [s_fp for s_fp, _ in kept], [c_fp for _, c_fp in kept]
+
+
+def compute_cov(crys, gt_crys, comp_scaler,
+                struc_cutoff, comp_cutoff, num_gen_crystals=None):
+    """Compute coverage (COV) recall/precision between two crystal sets.
+
+    Args:
+        crys: generated crystals exposing ``struct_fp`` and ``comp_fp``;
+            entries with a None fingerprint are dropped.
+        gt_crys: ground-truth crystals with the same two attributes.
+        comp_scaler: object whose ``transform`` scales composition fingerprints.
+        struc_cutoff, comp_cutoff: largest distances still counted as a match.
+        num_gen_crystals: precision denominator; defaults to ``len(crys)``.
+
+    Returns:
+        (metrics_dict, combined_dist_dict): summary scalars and the
+        per-crystal nearest-neighbour distances as lists.
+    """
+    struc_fps = [c.struct_fp for c in crys]
+    comp_fps = [c.comp_fp for c in crys]
+    gt_struc_fps = [c.struct_fp for c in gt_crys]
+    gt_comp_fps = [c.comp_fp for c in gt_crys]
+
+    # Use number of crystal before filtering to compute COV
+    if num_gen_crystals is None:
+        num_gen_crystals = len(struc_fps)
+
+    struc_fps, comp_fps = filter_fps(struc_fps, comp_fps)
+    struc_pdist = cdist(np.array(struc_fps), np.array(gt_struc_fps))
+    comp_pdist = cdist(np.array(comp_scaler.transform(comp_fps)),
+                       np.array(comp_scaler.transform(gt_comp_fps)))
+
+    # axis 0 -> nearest generated sample per ground truth (recall);
+    # axis 1 -> nearest ground truth per generated sample (precision).
+    struc_recall_dist, struc_precision_dist = (
+        struc_pdist.min(axis=0), struc_pdist.min(axis=1))
+    comp_recall_dist, comp_precision_dist = (
+        comp_pdist.min(axis=0), comp_pdist.min(axis=1))
+
+    recall_hits = np.logical_and(
+        struc_recall_dist <= struc_cutoff, comp_recall_dist <= comp_cutoff)
+    precision_hits = np.logical_and(
+        struc_precision_dist <= struc_cutoff, comp_precision_dist <= comp_cutoff)
+
+    metrics_dict = {
+        "cov_recall": np.mean(recall_hits),
+        "cov_precision": np.sum(precision_hits) / num_gen_crystals,
+        "amsd_recall": np.mean(struc_recall_dist),
+        "amsd_precision": np.mean(struc_precision_dist),
+        "amcd_recall": np.mean(comp_recall_dist),
+        "amcd_precision": np.mean(comp_precision_dist),
+    }
+    combined_dist_dict = {
+        "struc_recall_dist": struc_recall_dist.tolist(),
+        "struc_precision_dist": struc_precision_dist.tolist(),
+        "comp_recall_dist": comp_recall_dist.tolist(),
+        "comp_precision_dist": comp_precision_dist.tolist(),
+    }
+    return metrics_dict, combined_dist_dict
diff --git a/MindChemistry/applications/cdvae/train.py b/MindChemistry/applications/cdvae/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b73927f594f602817db32fa1986c68afebce833
--- /dev/null
+++ b/MindChemistry/applications/cdvae/train.py
@@ -0,0 +1,185 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""Train
+"""
+
+import os
+import logging
+import argparse
+import time
+import numpy as np
+import mindspore as ms
+from mindspore.experimental import optim
+from mindchemistry.utils.load_config import load_yaml_config_from_path
+from mindchemistry.cell.cdvae import CDVAE
+from mindchemistry.cell.gemnet.data_utils import StandardScalerMindspore
+from create_dataset import create_dataset
+from src.dataloader import DataLoaderBaseCDVAE
+
+
+def train_epoch(epoch, model, optimizer, scheduler, train_dataset):
+ """Train the model for one epoch"""
+ model.set_train()
+ # Define forward function
+
+ def forward_fn(data):
+ (atom_types, dist, _, idx_kj, idx_ji,
+ edge_j, edge_i, batch, lengths, num_atoms,
+ angles, frac_coords, y, batch_size, sbf, total_atoms) = data
+ loss = model(atom_types, dist, idx_kj, idx_ji, edge_j, edge_i,
+ batch, lengths, num_atoms, angles, frac_coords,
+ y, batch_size, sbf, total_atoms, True, True)
+ return loss
+ # Get gradient function
+ grad_fn = ms.value_and_grad(
+ forward_fn, None, optimizer.parameters, has_aux=False)
+
+ # Define function of one-step training
+ def train_step(data):
+ loss, grads = grad_fn(data)
+ scheduler.step(loss)
+ optimizer(grads)
+ return loss
+
+ start_time_step = time.time()
+ for batch, data in enumerate(train_dataset):
+ loss = train_step(data)
+ time_step = time.time() - start_time_step
+ start_time_step = time.time()
+ if batch % 10 == 0:
+ logging.info("Train Epoch: %d [%d]\tLoss: %4f,\t time_step: %4f",
+ epoch, batch, loss, time_step)
+
+
+def test_epoch(model, val_dataset):
+ """test for one epoch"""
+ model.set_train(False)
+ test_loss = 0
+ i = 1
+ for i, data in enumerate(val_dataset):
+ (atom_types, dist, _, idx_kj, idx_ji,
+ edge_j, edge_i, batch, lengths, num_atoms,
+ angles, frac_coords, y, batch_size, sbf, total_atoms) = data
+ output = model(atom_types, dist,
+ idx_kj, idx_ji, edge_j, edge_i,
+ batch, lengths, num_atoms,
+ angles, frac_coords, y, batch_size,
+ sbf, total_atoms, False, True)
+ test_loss += float(output)
+ test_loss /= (i+1)
+ logging.info("Val Loss: %4f", test_loss)
+ return test_loss
+
+def get_scaler(args):
+ """get scaler"""
+ lattice_scaler_mean = ms.Tensor(np.loadtxt(
+ f"./data/{args.dataset}/train/lattice_scaler_mean.csv"), ms.float32)
+ lattice_scaler_std = ms.Tensor(np.loadtxt(
+ f"./data/{args.dataset}/train/lattice_scaler_std.csv"), ms.float32)
+ scaler_std = ms.Tensor(np.loadtxt(
+ f"./data/{args.dataset}/train/scaler_std.csv"), ms.float32)
+ scaler_mean = ms.Tensor(np.loadtxt(
+ f"./data/{args.dataset}/train/scaler_mean.csv"), ms.float32)
+ lattice_scaler = StandardScalerMindspore(
+ lattice_scaler_mean, lattice_scaler_std)
+ scaler = StandardScalerMindspore(scaler_mean, scaler_std)
+ return lattice_scaler, scaler
+
+def train_net(args):
+ """training process"""
+ folder_path = os.path.dirname(args.name_ckpt)
+ if not os.path.exists(folder_path):
+ os.makedirs(folder_path)
+ logging.info("%s has been created", folder_path)
+ config_path = "./conf/configs.yaml"
+ data_config_path = f"./conf/data/{args.dataset}.yaml"
+
+ model = CDVAE(config_path, data_config_path)
+
+ # load checkpoint
+ if args.load_ckpt:
+ model_path = args.name_ckpt
+ param_dict = ms.load_checkpoint(model_path)
+ param_not_load, _ = ms.load_param_into_net(model, param_dict)
+ logging.info("%s have not been loaded", param_not_load)
+
+ # create dataset when running the model first-time or when dataset is not exist
+ if args.create_dataset or not os.path.exists(f"./data/{args.dataset}/train/processed_data.npy"):
+ logging.info("Creating dataset......")
+ create_dataset(args) # dataset created will be save to the dir based on args.dataset as npy
+
+ # read dataset from processed_data
+ batch_size = load_yaml_config_from_path(data_config_path).get("batch_size")
+ train_dataset = DataLoaderBaseCDVAE(
+ batch_size, args.dataset, shuffle_dataset=True, mode="train")
+ val_dataset = DataLoaderBaseCDVAE(
+ batch_size, args.dataset, shuffle_dataset=False, mode="val")
+ lattice_scaler, scaler = get_scaler(args)
+ model.lattice_scaler = lattice_scaler
+ model.scaler = scaler
+
+ config_opt = load_yaml_config_from_path(config_path).get("Optimizer")
+ learning_rate = config_opt.get("learning_rate")
+ min_lr = config_opt.get("min_lr")
+ factor = config_opt.get("factor")
+ patience = config_opt.get("patience")
+
+ optimizer = optim.Adam(model.trainable_params(), learning_rate)
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
+ optimizer, 'min', factor=factor, patience=patience, min_lr=min_lr)
+
+ min_test_loss = float("inf")
+ for epoch in range(args.epoch_num):
+ train_epoch(epoch, model, optimizer, scheduler, train_dataset)
+ if epoch % 10 == 0:
+ test_loss = test_epoch(model, val_dataset)
+ if test_loss < min_test_loss:
+ min_test_loss = test_loss
+ ms.save_checkpoint(model, args.name_ckpt)
+ logging.info("Updata best acc: %f", test_loss)
+
+ logging.info('Finished Training')
+
+def get_args():
+ """get args"""
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--dataset", default="perov_5", help="dataset name")
+ parser.add_argument("--create_dataset", default=False,
+ type=bool, help="whether create dataset again or not")
+ parser.add_argument("--num_samples_train", default=500, type=int,
+ help="number of samples for training,\
+ only valid when create_dataset is True")
+ parser.add_argument("--num_samples_val", default=300, type=int,
+ help="number of samples for validation,\
+ only valid when create_dataset is True")
+ parser.add_argument("--num_samples_test", default=300, type=int,
+ help="number of samples for test,\
+ only valid when create_dataset is True")
+ parser.add_argument("--name_ckpt", default="./loss/loss.ckpt",
+ help="the path to save checkpoint")
+ parser.add_argument("--load_ckpt", default=False, type=bool,
+ help="whether load checkpoint or not")
+ parser.add_argument("--device_target", default="Ascend", help="device target")
+ parser.add_argument("--device_id", default=3, type=int, help="device id")
+ parser.add_argument("--epoch_num", default=100, type=int, help="number of epoch")
+ return parser.parse_args()
+
+if __name__ == "__main__":
+ main_args = get_args()
+ logging.basicConfig(format="%(levelname)s:%(message)s", level=logging.INFO)
+ ms.context.set_context(device_target=main_args.device_target,
+ device_id=main_args.device_id,
+ mode=1)
+ train_net(main_args)
diff --git a/MindChemistry/mindchemistry/cell/gemnet/layers/base_layers.py b/MindChemistry/mindchemistry/cell/gemnet/layers/base_layers.py
index fe6193abf02b1c0a298aadfc6bc5d5211481d9cc..deae89f36339599281acb927ee7a3fd31ac7fdbf 100644
--- a/MindChemistry/mindchemistry/cell/gemnet/layers/base_layers.py
+++ b/MindChemistry/mindchemistry/cell/gemnet/layers/base_layers.py
@@ -132,7 +132,7 @@ class MLP(ms.nn.Cell):
super().__init__()
self.activation = activation
self.last_activation = last_activation
- self.in_layer = mint.nn.Linear(in_dim, hidden_dim, bias=False)
+ self.in_layer = mint.nn.Linear(in_dim, hidden_dim, bias=True)
self.dense_mlp = ms.nn.SequentialCell(
*[
mint.nn.Linear(
@@ -143,7 +143,7 @@ class MLP(ms.nn.Cell):
for _ in range(fc_num_layers)
]
)
- self.out_layer = mint.nn.Linear(hidden_dim, out_dim, bias=False)
+ self.out_layer = mint.nn.Linear(hidden_dim, out_dim, bias=True)
def construct(self, x):
"""MLP construct"""