diff --git a/MindChemistry/applications/bete-net/README.md b/MindChemistry/applications/bete-net/README.md new file mode 100644 index 0000000000000000000000000000000000000000..352a1881bd336890c38b53a1a18750f67d228252 --- /dev/null +++ b/MindChemistry/applications/bete-net/README.md @@ -0,0 +1,106 @@ +# BETE-NET Training Guide +## Overview +We present a deep-learning strategy tailored for electron–phonon-coupled superconductors. +The core obstacle is the prohibitive cost of computing the Eliashberg spectral function α²F(ω). +We therefore adopt a two-step workflow: +1. First-principles calculation of α²F(ω) for 818 dynamically stable materials. +2. Training a dedicated deep-learning model—BETE-NET—to predict α²F(ω) directly from crystal structure. + +BETE-NET employs a dual-branch graph neural network to encode electron–phonon spectral interactions. By integrating +spectral-function attention and a temperature-annealing schedule, it efficiently predicts the superconducting critical +temperature Tc. Its principal innovation lies in embedding the physical spectral function directly into the +graph-convolution process rather than treating it as a post hoc feature, enabling superior accuracy even with limited data. 
+ +![model_structure](./images/model_structure.png) + +## Quick Start + +```bash +export PYTHONPATH=$PYTHONPATH:../../ +python train.py +``` + +## Training Progress Display + +### Real-time Progress Bar +- **Per-epoch display**: live progress bars for both training and validation +- **Loss updates**: current sample loss and running average shown in real time +- **Time estimation**: runtime per epoch + +## Configuration Options + +### Training Parameters +| Parameter | Default | Description | +| ----------------- | ------- | ----------------------------- | +| epochs | 100 | Number of training epochs | +| display_interval | 5 | Interval for detailed reports | +| plot_interval | 10 | Interval for plotting figures | +| learning_rate | 0.0005 | Optimizer learning rate (FPD) | +| patience | 20 | Early-stopping patience | + + +## Output Files + +### Model Files +- `best_cpd_model_ms.ckpt` - [best model weights](https://download.mindspore.cn/mindscience/mindchemistry/bete-net) +- `fpd_training_state.json` - training state snapshot + +### Visualization Files +- `fpd_training_progress_epoch_10.png` - intermediate snapshot +- `fpd_training_progress_epoch_20.png` - intermediate snapshot +- `fpd_final_training_results.png` - final training curves + +### Log Files +- Terminal output contains the complete training log +- JSON state file stores loss history and configuration + +## Interpreting Results + +### Training Metrics +- **Training Loss**: training-set loss (should decrease steadily) +- **Validation Loss**: validation-set loss (used for early stopping and model selection) +- **Loss Ratio**: validation / training ratio (monitors overfitting) + +### Trend Indicators +- **Decreasing**: loss is dropping, training is healthy +- **Increasing**: loss is rising, possible overfitting +- **Stable**: loss has plateaued, likely convergence + +### Final Evaluation +- **MAE**: Mean Absolute Error +- **RMSE**: Root-Mean-Square Error +- **R²**: Coefficient of determination (closer to 1 is better) + +# BETE-NET 
Inference Guide + +## Usage + +```bash +export PYTHONPATH=$PYTHONPATH:../../ +python eval.py --output_dir fpd_inference_results +``` + +### Argument Reference + +| Argument | Type | Default | Description | +| --------------- | ---- | ----------------------- | ----------------------------------- | +| `--output_dir` | str | fpd\_inference\_results | Directory to save results | + +## Output Files + +After running, the specified directory will contain: + +### 1. Visualization +- **Filename**: `FPD_prediction_results.png` +- **Content**: three scatter plots (λ, ω_log, ω_2) +- **Format**: high-resolution PNG (300 DPI) + +### 2. Detailed Results +- **Filename**: `FPD_detailed_results.csv` +- **Content**: predicted vs. true values for every sample +- **Purpose**: further analysis and post-processing + +### 3. Summary Report +- **Filename**: `FPD_summary_metrics.txt` +- **Content**: global and per-metric evaluation results +- **Format**: plain text, human-readable \ No newline at end of file diff --git a/MindChemistry/applications/bete-net/README_CN.md b/MindChemistry/applications/bete-net/README_CN.md new file mode 100644 index 0000000000000000000000000000000000000000..ffe608834c51599d2ed95adf152e5dd70bdbcd0c --- /dev/null +++ b/MindChemistry/applications/bete-net/README_CN.md @@ -0,0 +1,98 @@ +# BETE-NET 训练指南 +## 概述 +提出了一种将深度学习用于发现电子-声子耦合超导体的策略,核心挑战是 α²F(ω) 计算成本极高。 +这里采用两步法:先对 818 个动态稳定材料计算 α²F(ω),再训练名为 BETE-NET的深度学习模型预测 α²F(ω)。 +BETE-NET 通过 双分支 GNN 分别编码电子-声子谱相互作用,利用 谱函数注意力 和 温度退火 策略,高效预测超导体的 Tc。 +其创新点在于将物理谱函数直接嵌入图卷积过程,而非作为后处理特征。 +![model_structure](./images/model_structure.png) + +## 快速启动 + +```bash +export PYTHONPATH=$PYTHONPATH:../../ +python train.py +``` + +## 训练进度显示功能 + +### 实时进度条 +- **每个epoch显示**: 训练和验证的实时进度条 +- **损失更新**: 实时显示当前样本损失和平均损失 +- **时间估算**: 每个epoch的运行时间 + +## 配置选项 + +### 训练参数 +| 参数 | 默认值 | 描述 | +| ----------------- | ------ | ----------- | +| epochs | 100 | 训练轮数 | +| display_interval | 5 | 详细报告间隔 | +| plot_interval | 10 | 绘图间隔 | +| learning_rate | 0.0005 | 
优化器学习率(FPD) | +| patience | 20 | 早停耐心值 | + +## 输出文件 + +### 模型文件 +- `best_cpd_model_ms.ckpt` - [最佳模型权重](https://download.mindspore.cn/mindscience/mindchemistry/bete-net) +- `fpd_training_state.json` - 训练状态 + +### 可视化文件 +- `fpd_training_progress_epoch_10.png` - 中间进度图 +- `fpd_training_progress_epoch_20.png` - 中间进度图 +- `fpd_final_training_results.png` - 最终结果图 + +### 日志文件 +- 终端输出包含完整训练日志 +- JSON状态文件包含损失历史和配置 + +## 结果解读 + +### 训练指标 +- **Training Loss**: 训练集损失,应该持续下降 +- **Validation Loss**: 验证集损失,用于早停和模型选择 +- **Loss Ratio**: 验证/训练损失比,监控过拟合 + +### 趋势分析 +- **Decreasing**: 损失在下降,训练正常 +- **Increasing**: 损失在上升,可能过拟合 +- **Stable**: 损失稳定,可能收敛 + +### 最终评估 +- **MAE**: 平均绝对误差 +- **RMSE**: 均方根误差 +- **R²**: 决定系数(越接近1越好) + +# BETE-NET 推理指南 + +## 使用方法 + +```bash +export PYTHONPATH=$PYTHONPATH:../../ +python eval.py --output_dir fpd_inference_results +``` + +### 参数说明 + +| 参数 | 类型 | 默认值 | 说明 | +|------|------|--------|------| +| `--output_dir` | str | fpd_inference_results | 结果输出目录 | + +## 输出文件说明 + +运行后会在指定目录生成以下文件: + +### 1. 可视化图片 +- **文件名**: `FPD_prediction_results.png` +- **内容**: 三个散点图 (λ, ω_log, ω_2) +- **格式**: 高分辨率PNG (300 DPI) + +### 2. 详细结果数据 +- **文件名**: `FPD_detailed_results.csv` +- **内容**: 每个样本的预测值和真实值 +- **用途**: 进一步分析和后处理 + +### 3. 汇总指标报告 +- **文件名**: `FPD_summary_metrics.txt` +- **内容**: 整体和分项评估指标 +- **格式**: 文本格式,易于阅读 \ No newline at end of file diff --git a/MindChemistry/applications/bete-net/eval.py b/MindChemistry/applications/bete-net/eval.py new file mode 100644 index 0000000000000000000000000000000000000000..8c12dee274d81ab66bc4c3f96aa4fb2edd96b04a --- /dev/null +++ b/MindChemistry/applications/bete-net/eval.py @@ -0,0 +1,375 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import logging
import os
import sys
import argparse
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import ase.io
import mindspore as ms
from gnn.data import Batch

ms.set_context(mode=ms.PYNATIVE_MODE)
sys.path.append('./src')

from src import utils_data_ms as data_utils
from src import utils_model_ms as model_utils


def cal_mae_rmse_r2(dft_val, pred):
    """Return ``(mae, rmse, r2)`` between reference and predicted values.

    Pairs in which either value is NaN are dropped before scoring. When no
    pair survives, all three metrics are returned as NaN.

    Args:
        dft_val: Reference values (array-like supporting boolean masking).
        pred: Predicted values, aligned element-wise with ``dft_val``.

    Returns:
        Tuple ``(mae, rmse, r2)``.
    """
    # De Morgan form of ~(isnan(a) | isnan(b)): keep pairs where both are finite numbers.
    keep = ~np.isnan(dft_val) & ~np.isnan(pred)
    reference = dft_val[keep]
    estimate = pred[keep]

    if len(reference) == 0:
        return np.nan, np.nan, np.nan

    return (
        mean_absolute_error(reference, estimate),
        mean_squared_error(reference, estimate) ** 0.5,
        r2_score(reference, estimate),
    )


def add_metrics(title, mae, r2, ax, rmse, unit='', test=True, fontsize=10.5):
    """Draw a text box with the metric summary in the top-left corner of ``ax``.

    Args:
        title: First line of the box.
        mae, rmse, r2: Metric values; NaN is rendered as the string ``NaN``.
        ax: Matplotlib axes to annotate.
        unit: Unit suffix appended after MAE and RMSE.
        test: When true the RMSE line is included, otherwise it is omitted.
        fontsize: Font size of the annotation.
    """
    def _fmt(value):
        return 'NaN' if np.isnan(value) else f'{value:.3f}'

    mae_str, rmse_str, r2_str = _fmt(mae), _fmt(rmse), _fmt(r2)

    text = (
        f'{title}\nMAE = {mae_str} {unit}\nRMSE = {rmse_str} {unit}\nR² = {r2_str}'
        if test else
        f'{title}\nMAE = {mae_str} {unit}\nR² = {r2_str}'
    )

    box_style = dict(boxstyle='round', facecolor='white', alpha=0.8)
    # (0.05, 0.95) is in axes-fraction coordinates because of transform=ax.transAxes.
    ax.text(0.05, 0.95, text, transform=ax.transAxes, fontsize=fontsize,
            verticalalignment='top', bbox=box_style)

def create_batches(df, batch_size, shuffle=True):
    """Group the rows of ``df`` into ``(graph, target)`` mini-batches.

    Args:
        df: DataFrame with a ``data`` column (graph objects) and a ``target``
            column (one target vector per row).
        batch_size: Maximum number of samples per batch.
        shuffle: Shuffle the row order with ``np.random.shuffle`` first.

    Returns:
        List of ``(data, target)`` tuples. ``target`` is a float32
        ``ms.Tensor`` with one row per sample of the batch.
    """
    indices = df.index.tolist()
    if shuffle:
        np.random.shuffle(indices)

    batches = []
    for start in range(0, len(indices), batch_size):
        chunk = indices[start:start + batch_size]
        targets = ms.Tensor([df.loc[idx, 'target'] for idx in chunk], dtype=ms.float32)
        if len(chunk) == 1:
            # A lone sample is passed through as-is instead of being wrapped in a Batch.
            data = df.loc[chunk[0], 'data']
        else:
            data = Batch.from_data_list([df.loc[idx, 'data'] for idx in chunk])
        batches.append((data, targets))
    return batches


def average_over_folds(values, num_fold):
    """Average fold-major stacked values over the folds.

    ``values`` must hold ``num_fold`` consecutive passes over the same samples
    in the same order: all samples of fold 0, then all samples of fold 1, etc.
    That is exactly the order in which ``main`` appends its predictions.

    Args:
        values: 1-D array-like of length ``num_fold * n_samples``.
        num_fold: Number of stacked passes.

    Returns:
        1-D ``np.ndarray`` of length ``n_samples`` with the per-sample mean.

    Raises:
        ValueError: If the length of ``values`` is not a multiple of ``num_fold``.
    """
    # Rows = folds, columns = samples; the mean over axis 0 is the ensemble mean.
    # (Reshaping to (-1, num_fold) would instead average neighbouring samples.)
    stacked = np.asarray(values, dtype=float).reshape((num_fold, -1))
    return stacked.mean(axis=0)


def main():
    """Run ensemble inference with the FPD model and write the results.

    The script loads ``database.json`` and the CIF files under ``structures/``,
    builds the graph data, evaluates the ``num_fold`` checkpoints stored as
    ``./fpd/best_fpd_model_ms_{k}.ckpt`` on the test split, derives
    λ, ω_log and ω_2 from the target and predicted spectra, averages them over
    the folds and finally writes a scatter plot, a CSV table and a text
    summary into ``--output_dir``.
    """
    parser = argparse.ArgumentParser(
        description='BETE-NET FPD Model Inference')
    parser.add_argument('--output_dir', type=str, default='fpd_inference_results',
                        help='Output directory for results')
    args = parser.parse_args()

    # The root logger defaults to WARNING, which would hide every progress
    # message below. basicConfig is a no-op if handlers are already installed.
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    logging.info("BETE-NET FPD Model Inference")
    logging.info("=" * 50)
    logging.info(f"Output directory: {args.output_dir}")

    os.makedirs(args.output_dir, exist_ok=True)

    logging.info("Loading database...")
    df = pd.read_json('database.json')
    df.dropna(inplace=True)

    logging.info("Loading structures...")
    structures = []
    for index in tqdm(df.index, total=len(df), desc="Loading structures"):
        try:
            structures.append(ase.io.read(f'structures/{index}.cif'))
        except Exception as err:
            # Best effort: a missing or unreadable CIF drops the sample, but
            # the reason is reported instead of being swallowed silently.
            logging.warning(f"Skipping structure {index}: {err}")
            structures.append(None)

    df['structure'] = structures
    df = df.dropna(subset=['structure']).copy()
    logging.info(f"Successfully loaded {len(df)} structures")

    logging.info("Processing FPD model data...")
    df['target'] = df.apply(data_utils.get_target, axis=1)
    df['formula'] = df['structure'].map(lambda x: x.get_chemical_formula())
    df['species'] = df['structure'].map(
        lambda x: list(set(x.get_chemical_symbols())))

    r_max = 4
    embed_ph_dos = True
    embed_e_dos = False
    fine = True

    logging.info(
        f"Building graph data (r_max={r_max}, embed_ph_dos={embed_ph_dos}, fine={fine})...")
    tqdm.pandas()
    df['data'] = df.progress_apply(
        data_utils.build_data,
        embed_ph_dos=embed_ph_dos,
        embed_e_dos=embed_e_dos,
        fine=fine,
        r_max=r_max,
        axis=1
    )

    sample_data = df.iloc[0]['data']
    out_dim = len(df.iloc[0]['target'])
    in_dim = sample_data.x.shape[1]
    em_dim = 64
    logging.info("Data dimensions:")
    logging.info(f" - Input features: {in_dim}")
    logging.info(f" - Output targets: {out_dim}")
    logging.info(f" - Total samples: {len(df)}")
    logging.info("Using original paper's data split...")
    train_df, test_df, _ = data_utils.get_original_data_split(df)

    logging.info("Creating FPD model...")
    model_params = {
        'in_dim': 118+51,
        'em_dim': em_dim,
        'irreps_in': f'{em_dim}x0e',
        'irreps_out': f'{out_dim}x0e',
        'irreps_node_attr': f'{em_dim}x0e',
        'layers': 2,
        'mul': 32,
        'lmax': 1,
        'max_radius': r_max,
        'number_of_basis': 10,
        'radial_layers': 1,
        'radial_neurons': 128,
        'num_neighbors': data_utils.get_neighbors(train_df, train_df.index).mean(),
        'num_nodes': 8.0,
        'reduce_output': True,
        'dropout': False
    }

    model = model_utils.PeriodicNetwork(**model_params)
    param_count = sum(p.size for p in model.get_parameters())
    model.set_train(False)

    predictions = []
    targets = []
    num_fold = 10

    for k in tqdm(range(num_fold)):
        # Batches are rebuilt for every fold, exactly as before.
        # NOTE(review): presumably the forward pass may modify the batch in
        # place, so the batches are not shared between folds - confirm.
        test_batch = create_batches(test_df, batch_size=128, shuffle=False)
        ms.load_checkpoint(f'./fpd/best_fpd_model_ms_{k}.ckpt', model)
        logging.info("Weights loaded successfully")

        for data, target in tqdm(test_batch, desc="Inferring"):
            predictions.append(model(data).asnumpy())
            targets.append(target.asnumpy())

    # Both arrays are fold-major: fold 0 for every test sample, then fold 1, ...
    predictions = np.concatenate(predictions)
    targets = np.concatenate(targets)

    logging.info("\nCalculating physical quantities from a2F spectra...")

    def spectra_to_properties(spectra, label):
        """Convert each spectrum to [λ, ω_log, ω_2]; log the first three rows."""
        rows = []
        for i, spectrum in enumerate(spectra):
            lamb, w_log, w_2 = data_utils.compute_physical_properties(spectrum)
            rows.append([lamb, w_log, w_2])
            if i < 3:
                logging.info(
                    f"{label} {i+1}: λ={lamb:.4f}, ω_log={w_log:.1f}K, ω_2={w_2:.1f}K")
        return np.array(rows)

    target_properties = spectra_to_properties(targets, "Target")
    pred_properties = spectra_to_properties(predictions, "Pred")

    # (column prefix, colour, unit, axis title, axis limits, tick positions)
    plot_specs = [
        ('lamb', 'C0', '', r'$\lambda$', (0, 2), [0, 0.5, 1.0, 1.5, 2.0]),
        ('wlog', 'C5', 'K', r'$\omega_{\text{log}}$', (0, 550), np.arange(0, 600, 200)),
        ('w2', 'C8', 'K', r'$\omega_{2}$', (0, 700), np.arange(0, 800, 200)),
    ]

    test_df = test_df.copy()
    for col, spec in enumerate(plot_specs):
        name = spec[0]
        # One ensemble-averaged value per test sample (mean over the folds).
        test_df[f'{name}_target'] = average_over_folds(target_properties[:, col], num_fold)
        test_df[f'{name}_pred'] = average_over_folds(pred_properties[:, col], num_fold)

    logging.info("\nFPD Model Test Results:")

    all_targets = target_properties.flatten()
    all_preds = pred_properties.flatten()
    n_valid_overall = int((~(np.isnan(all_targets) | np.isnan(all_preds))).sum())
    # cal_mae_rmse_r2 applies the same NaN mask and returns NaNs when nothing is left.
    overall_mae, overall_rmse, overall_r2 = cal_mae_rmse_r2(all_targets, all_preds)

    logging.info(f" - Overall MAE: {overall_mae:.6f}")
    logging.info(f" - Overall RMSE: {overall_rmse:.6f}")
    logging.info(f" - Overall R²: {overall_r2:.6f}")
    logging.info(f" - Valid samples: {n_valid_overall}/{len(all_targets)}")

    logging.info("Generating visualization...")

    plt.rcParams.update({'font.size': 12})
    plt.rcParams["font.family"] = 'DejaVu Sans'

    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    plt.subplots_adjust(left=0.05, bottom=0.15, right=0.99, top=0.90)

    property_metrics = {}
    for i, (ax, (name, color, unit, title, lim, ticks)) in enumerate(zip(axs, plot_specs)):
        target_col = test_df[f'{name}_target']
        pred_col = test_df[f'{name}_pred']

        mae, rmse, r2 = cal_mae_rmse_r2(dft_val=target_col, pred=pred_col)
        property_metrics[name] = (mae, rmse, r2)

        valid_mask = ~(np.isnan(target_col) | np.isnan(pred_col))
        n_valid = valid_mask.sum()
        n_total = len(test_df)

        if n_valid > 0:
            ax.scatter(target_col[valid_mask], pred_col[valid_mask],
                       marker='.', color=color, alpha=0.5)

        add_metrics(title=f'Test (n={n_valid}/{n_total})', mae=mae, r2=r2, ax=ax,
                    rmse=rmse, unit=unit, test=True, fontsize=10.5)

        # Diagonal = perfect prediction.
        ax.plot(lim, lim, 'k--', zorder=0)

        ax.set_xlim(lim)
        ax.set_ylim(lim)
        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_aspect('equal')

        ax.set_ylabel(f'Predicted {title} ({unit})' if unit else f'Predicted {title}')

        # NOTE(review): only the last panel gets an x label, as in the original.
        if i == 2:
            ax.set_xlabel('Target')

        logging.info(
            f" - {name.upper()} - MAE: {mae:.6f}, RMSE: {rmse:.6f}, R²: {r2:.6f}")

    plot_path = os.path.join(args.output_dir, 'FPD_prediction_results.png')
    plt.tight_layout()
    plt.savefig(plot_path, dpi=300, bbox_inches='tight')
    logging.info(f"Plot saved to: {plot_path}")

    results_path = os.path.join(args.output_dir, 'FPD_detailed_results.csv')
    test_df.to_csv(results_path, index=False)
    logging.info(f"Detailed results saved to: {results_path}")

    summary_path = os.path.join(args.output_dir, 'FPD_summary_metrics.txt')
    with open(summary_path, 'w', encoding='utf-8') as f:
        f.write("BETE-NET FPD Model - Prediction Results\n")
        f.write("=" * 50 + "\n\n")
        f.write("Model Configuration: FPD (Fine PhDOS)\n")
        f.write(f"Test Samples: {len(test_df)}\n")
        f.write(f"Model Parameters: {param_count:,}\n\n")

        f.write("Overall Metrics:\n")
        f.write(f" MAE: {overall_mae:.6f}\n")
        f.write(f" RMSE: {overall_rmse:.6f}\n")
        f.write(f" R2: {overall_r2:.6f}\n\n")

        f.write("Individual Property Metrics:\n")
        for name, _, unit, _, _, _ in plot_specs:
            mae, rmse, r2 = property_metrics[name]
            f.write(f" {name.upper()}:\n")
            f.write(f" MAE: {mae:.6f} {unit}\n")
            f.write(f" RMSE: {rmse:.6f} {unit}\n")
            f.write(f" R2: {r2:.6f}\n\n")

    logging.info(f"Summary metrics saved to: {summary_path}")

    plt.show()

    logging.info("\nFPD model inference completed!")
    logging.info(f"All results saved in: {args.output_dir}")


if __name__ == "__main__":
    main()

# Patch metadata that shared the original physical line with the code above
# (kept verbatim, as comments):
# \ No newline at end of file
# diff --git a/MindChemistry/applications/bete-net/images/model_structure.png b/MindChemistry/applications/bete-net/images/model_structure.png
# new file mode 100644
# index 0000000000000000000000000000000000000000..b7c32ba449c21c0b6b7bc16998d31801585f7d81
# Binary files /dev/null and b/MindChemistry/applications/bete-net/images/model_structure.png differ
# diff --git a/MindChemistry/applications/bete-net/requirements.txt
b/MindChemistry/applications/bete-net/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..da6f7ae851f44842d458692e1b3f902be3a7d09b --- /dev/null +++ b/MindChemistry/applications/bete-net/requirements.txt @@ -0,0 +1,6 @@ +ase==3.22.0 +notebook==6.4.5 +pip==21.3.1 +plotly==5.3.1 +mindspore==2.5.0 +mindchemistry==0.2.0 diff --git a/MindChemistry/applications/bete-net/src/conv_e3nn.py b/MindChemistry/applications/bete-net/src/conv_e3nn.py new file mode 100644 index 0000000000000000000000000000000000000000..6bfaf544af0a5a5515a05672bc45f7d019ff1f69 --- /dev/null +++ b/MindChemistry/applications/bete-net/src/conv_e3nn.py @@ -0,0 +1,122 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================

import math
from mindspore import nn, ops, float32
from mindchemistry.graph.graph import AggregateEdgeToNode
from mindchemistry.e3.o3 import TensorProduct, Irreps, FullyConnectedTensorProduct
from mindchemistry.e3.nn import FullyConnectedNet

softplus = ops.Softplus()


def shift_softplus(x):
    """Softplus shifted down by ln(2), so that the function is 0 at x = 0."""
    return softplus(x) - 0.6931471805599453  # 0.6931... == math.log(2)


def silu(x):
    """SiLU activation, ``x * sigmoid(x)``."""
    return x * ops.sigmoid(x)


class Convolution(nn.Cell):
    """Equivariant message-passing convolution with an optional self-connection.

    Node features are mixed with the node attributes (``lin1``), combined with
    the edge attributes through a tensor product whose weights come from a
    fully connected net applied to the edge scalars, summed onto the nodes,
    optionally normalised by ``sqrt(avg_num_neighbors)``, mixed again
    (``lin2``) and finally blended with a self-connection of the input.

    Args:
        irreps_node_input: Irreps of the incoming node features.
        irreps_node_attr: Irreps of the node attributes.
        irreps_node_output: Irreps of the produced node features.
        irreps_edge_attr: Irreps of the edge attributes.
        irreps_edge_scalars: Irreps object of the edge scalars; only its
            ``num_irreps`` is used (the scalars are re-declared as ``0e``).
        invariant_layers: Number of hidden layers of the weight network.
        invariant_neurons: Width of each hidden layer of the weight network.
        avg_num_neighbors: If given, aggregated messages are divided by its
            square root.
        use_sc: Whether to add the self-connection branch.
        nonlin_scalars: Optional mapping; its ``"e"`` entry selects the
            activation of the weight network (``"ssp"`` or ``"silu"``).
            ``None`` or a missing/unknown entry passes ``None`` as activation.
        dtype: dtype forwarded to the tensor product and the weight network.
        ncon_dtype: ``ncon_dtype`` forwarded to the tensor product.
    """

    def __init__(self,
                 irreps_node_input,
                 irreps_node_attr,
                 irreps_node_output,
                 irreps_edge_attr,
                 irreps_edge_scalars,
                 invariant_layers=1,
                 invariant_neurons=8,
                 avg_num_neighbors=None,
                 use_sc=True,
                 nonlin_scalars=None,
                 dtype=float32,
                 ncon_dtype=float32):
        super().__init__()
        self.avg_num_neighbors = avg_num_neighbors
        self.use_sc = use_sc

        self.irreps_node_input = Irreps(irreps_node_input)
        self.irreps_node_attr = Irreps(irreps_node_attr)
        self.irreps_node_output = Irreps(irreps_node_output)
        self.irreps_edge_attr = Irreps(irreps_edge_attr)
        self.irreps_edge_scalars = Irreps(
            [(irreps_edge_scalars.num_irreps, (0, 1))])
        self.lin1 = FullyConnectedTensorProduct(
            self.irreps_node_input, self.irreps_node_attr, self.irreps_node_input)

        # Enumerate every (node irrep x edge irrep -> output irrep) path whose
        # result is wanted in the output; each becomes one "uvu" instruction.
        irreps_mid = []
        instructions = []
        for i, (mul, ir_in) in enumerate(self.irreps_node_input):
            for j, (_, ir_edge) in enumerate(self.irreps_edge_attr):
                for ir_out in ir_in * ir_edge:
                    if ir_out in self.irreps_node_output:
                        k = len(irreps_mid)
                        irreps_mid.append((mul, ir_out))
                        instructions.append((i, j, k, "uvu", True))
        irreps_mid = Irreps(irreps_mid)
        # Sorting reorders the intermediate irreps; remap the output slots with p.
        irreps_mid, p, _ = irreps_mid.sort()
        instructions = [(i_1, i_2, p[i_out], mode, train)
                        for i_1, i_2, i_out, mode, train in instructions]

        tp = TensorProduct(self.irreps_node_input,
                           self.irreps_edge_attr,
                           irreps_mid,
                           instructions,
                           weight_mode='custom',
                           dtype=dtype,
                           ncon_dtype=ncon_dtype)

        # The default nonlin_scalars=None previously crashed on ``.get``;
        # treat it as "no activation configured".
        nonlin_scalars = nonlin_scalars or {}
        activations = {
            "ssp": shift_softplus,
            "silu": ops.silu,
        }
        edge_activation = activations.get(nonlin_scalars.get("e", None), None)

        # Maps edge scalars to the tp.weight_numel weights of the tensor product.
        self.fc = FullyConnectedNet(
            [self.irreps_edge_scalars.num_irreps]
            + invariant_layers * [invariant_neurons]
            + [tp.weight_numel],
            edge_activation,
            dtype=dtype)

        self.tp = tp
        self.scatter = AggregateEdgeToNode(dim=1)

        self.lin2 = FullyConnectedTensorProduct(tp.irreps_out.simplify(),
                                                self.irreps_node_attr, self.irreps_node_output)

        self.sc = None
        if self.use_sc:
            self.sc = FullyConnectedTensorProduct(
                self.irreps_node_input, self.irreps_node_attr, self.irreps_node_output)

    def construct(self, node_input, node_attr, edge_src, edge_dst, edge_attr, edge_scalars):
        """Evaluate the interaction block with a residual-style self-connection.

        Args:
            node_input: Node features.
            node_attr: Node attributes.
            edge_src: Index of the source node of every edge.
            edge_dst: Index of the destination node of every edge.
            edge_attr: Edge attributes fed to the tensor product.
            edge_scalars: Edge scalars fed to the weight network.

        Returns:
            Updated node features.
        """
        weight = self.fc(edge_scalars)
        node_features = self.lin1(node_input, node_attr)
        edge_features = self.tp(node_features[edge_src], edge_attr, weight)
        node_features = self.scatter(edge_attr=edge_features, edge_index=[edge_src, edge_dst],
                                     dim_size=node_input.shape[0])

        if self.avg_num_neighbors is not None:
            node_features = node_features.div(self.avg_num_neighbors**0.5)

        node_features = self.lin2(node_features, node_attr)

        if self.sc is not None:
            sc = self.sc(node_input, node_attr)

            # Fixed mixing angle pi/8 between self-connection and convolution.
            c_s, c_x = math.sin(math.pi / 8), math.cos(math.pi / 8)
            m = self.sc.output_mask
            # Where the mask is 0 the convolution output keeps weight 1.
            c_x = (1 - m) + c_x * m
            node_features = c_s * sc + c_x * node_features

        return node_features

# Patch metadata and the licence header of the next file shared the original
# physical line with the code above (kept verbatim, as comments):
# diff --git a/MindChemistry/applications/bete-net/src/utils_data_ms.py b/MindChemistry/applications/bete-net/src/utils_data_ms.py
# new file mode 100644
# index 0000000000000000000000000000000000000000..4fd1567ec3ad59bf59eaaee4e44d134fbbe5ade5
# --- /dev/null
# +++ b/MindChemistry/applications/bete-net/src/utils_data_ms.py
# @@ -0,0 +1,278 @@
# Copyright 2025 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import logging +import numpy as np +from scipy.signal import savgol_filter +from ase.neighborlist import neighbor_list +from ase import Atom +import mindspore as ms +from gnn.data import Graph as Data + +# Define frequency ranges +Freq_final = np.arange(0.25, 101, 2) +Freq_final_E = np.arange(-50, 50, 1) + +# Create atomic-type encodings +type_encoding = {} +specie_am = [] +for Z in range(1, 119): + specie = Atom(Z) + type_encoding[specie.symbol] = Z - 1 + specie_am.append(specie.mass) + +type_onehot = ms.ops.eye(len(type_encoding), len(type_encoding), ms.float32) +am_onehot = ms.ops.diag(ms.Tensor(specie_am, dtype=ms.float32)) + +def build_data(entry, r_max=4.0, embed_ph_dos=True, embed_e_dos=True, fine=False): + """Builds a Data object from a given entry.""" + symbols = list(entry.structure.symbols).copy() + positions = ms.Tensor(entry.structure.positions.copy(), dtype=ms.float32) + lattice = ms.Tensor(entry.structure.cell.array.copy(), dtype=ms.float32).expand_dims(0) + + # Compute edge source and target indices + edge_src, edge_dst, edge_shift = neighbor_list( + "ijS", a=entry.structure, cutoff=r_max, self_interaction=True + ) + + # Compute relative distances and periodic-boundary displacements + edge_batch = ms.ops.zeros(positions.shape[0], dtype=ms.int64)[ + ms.Tensor(edge_src) + ] + + # Replace einsum with matrix multiplication + edge_shift_tensor = ms.Tensor(edge_shift, dtype=ms.float32) + lattice_batch = lattice[edge_batch] # shape: [n_edges, 3, 3] + + # Compute 
edge_shift @ lattice + lattice_contribution = ms.mint.einsum('nj,nij->nj', edge_shift_tensor, lattice_batch) + edge_vec = ( + positions[ms.Tensor(edge_dst)] + - positions[ms.Tensor(edge_src)] + + lattice_contribution + ) + + # Compute edge lengths + edge_len = np.around(edge_vec.norm(dim=1).asnumpy(), decimals=2) + + # Build node features + x = am_onehot[[type_encoding[specie] for specie in symbols]].astype(ms.float32) + z = type_onehot[[type_encoding[specie] for specie in symbols]].astype(ms.float32) + + if embed_ph_dos and embed_e_dos: + p_ph_dos = process_phdos(entry, fine=fine) + p_e_dos = process_edos(entry, fine=fine) + + x = ms.ops.concat((x, ms.ops.ones_like(p_ph_dos), ms.ops.ones_like(p_e_dos)), 1) + z = ms.ops.concat((z, p_ph_dos, p_e_dos), 1) + + elif embed_ph_dos: + p_ph_dos = process_phdos(entry, fine=fine) + x = ms.ops.concat((x, ms.ops.ones_like(p_ph_dos)), 1) + z = ms.ops.concat((z, p_ph_dos), 1) + + elif embed_e_dos: + p_e_dos = process_edos(entry, fine=fine) + x = ms.ops.concat((x, ms.ops.ones_like(p_e_dos)), 1) + z = ms.ops.concat((z, p_e_dos), 1) + + data = Data( + x=x, + edge_index=ms.ops.stack( + [ms.Tensor(edge_src, dtype=ms.int64), ms.Tensor(edge_dst, dtype=ms.int64)], axis=0 + ), + edge_attr=None, + y=ms.Tensor(np.asarray(entry.target), dtype=ms.float32).expand_dims(0), + crd=positions, + pos=positions, + lattice=lattice, + symbol=symbols, + z=z, + edge_shift=ms.Tensor(edge_shift, dtype=ms.float32), + edge_vec=edge_vec, + edge_len=edge_len, + target=ms.Tensor(np.asarray(entry.target), dtype=ms.float32).expand_dims(0), + ) + return data + +def get_target(df): + """Interpolates and smooths the Eliashberg function (a²F) from DataFrame.""" + x = df.Freq_meV + y = df.a2F + xl = np.arange(0.25, 101, 0.1) + y = np.interp(xl, x, y) + Y = savgol_filter(y, 101, 3, mode="interp") + Y = np.interp(Freq_final, xl, Y) + Y = np.asarray([y if y > 0.0 else 0.0 for y in Y]) + return Y + +def get_neighbors(df, idx): + """Returns the number of neighbors for 
each atom in the dataset.""" + n = [] + for entry in df.itertuples(): + N = entry.data.pos.shape[0] + for i in range(N): + n.append(len((entry.data.edge_index[0] == i).nonzero())) + return np.array(n) + +def get_phdos(df): + """Interpolates and smooths the total phonon DOS from DataFrame.""" + x = df.PhFreq_meV + y = df.Tot_PhDOS + xl = np.arange(0.25, 101, 0.1) + y = np.interp(xl, x, y) + Y = savgol_filter(y, 101, 3, mode="interp") + Y = np.interp(Freq_final, xl, Y) + return np.asarray([y if y > 0.0 else 0.0 for y in Y]) + +def process_phdos(entry, fine=False): + """Processes site-projected phonon DOS into fixed-length vectors.""" + Y_proc = [] + if fine: + x = entry.PhFreq_meV_dense + ys = entry.Site_Proj_PhDOS_dense + else: + x = entry.Ph_2x2x2_interpolated_Freq_meV + ys = entry.Ph_2x2x2_interpolated_Site_Proj_DOS + + for y in ys: + xl = np.arange(0.25, 101, 0.1) + y = np.interp(xl, x, y) + window_length = min(101, len(y)) + if window_length % 2 == 0: + window_length -= 1 + window_length = max(3, window_length) + Y = savgol_filter(y, 101, 3, mode="interp") + Y = np.interp(Freq_final, xl, Y) + Y = [y if y > 0.0 else 0.0 for y in Y] + Y_proc.append(Y.copy()) + return ms.Tensor(Y_proc, dtype=ms.float32) + +def process_edos(entry, fine=False): + """Processes site-projected electronic DOS into fixed-length vectors.""" + ys = entry.Site_proj_eDOS + x = entry.Site_proj_eDOS_eng_meV + Y_proc = [] + + for y in ys: + window_length = min(101, len(y)) + if window_length % 2 == 0: + window_length -= 1 + window_length = max(3, window_length) + Y = savgol_filter(y, 101, 3, mode="interp") + Y = np.interp(Freq_final_E, x, Y) + Y = [y if y > 0.0 else 0.0 for y in Y] + Y_proc.append(Y.copy()) + return ms.Tensor(Y_proc, dtype=ms.float32) + +def load_data_splits(idx=None): + """Loads train/validation/test indices from disk.""" + if isinstance(idx, int): + idx_train = np.loadtxt(f'indices/idx_train_V2_{idx}.txt').astype(int) + idx_valid = 
np.loadtxt(f'indices/idx_valid_V2_{idx}.txt').astype(int) + idx_test = np.loadtxt(f'indices/idx_test_full.txt').astype(int) + + else: + idx_test = np.loadtxt('indices/idx_test_full.txt').astype(int) + idx_train = np.loadtxt('indices/idx_train_full.txt').astype(int) + idx_valid = None + + return idx_train, idx_test, idx_valid + +def get_original_data_split(df, idx=None): + """Retrieves original train/validation/test splits, accounting for missing samples.""" + idx_train, idx_test, idx_valid = load_data_splits(idx) + available_train_indices = [idx for idx in idx_train if idx in df.index] + available_test_indices = [idx for idx in idx_test if idx in df.index] + + missing_train = len(idx_train) - len(available_train_indices) + missing_test = len(idx_test) - len(available_test_indices) + + train_df = df.loc[available_train_indices].copy() + test_df = df.loc[available_test_indices].copy() + + if idx_valid is not None: + available_val_indices = [idx for idx in idx_valid if idx in df.index] + missing_val = len(idx_valid) - len(available_val_indices) + val_df = df.loc[available_val_indices].copy() + logging.info(f" - Validation: {len(val_df)} samples (original: {len(idx_valid)}, missing: {missing_val})") + else: + val_df = None + + + logging.info(f"Original data split (after handling missing samples):") + logging.info(f" - Train: {len(train_df)} samples (original: {len(idx_train)}, missing: {missing_train})") + logging.info(f" - Test: {len(test_df)} samples (original: {len(idx_test)}, missing: {missing_test})") + if idx_valid is None: + logging.info(f" - No overlap (train vs test): {len(set(train_df.index) & set(test_df.index)) == 0}") + else: + logging.info(f" - No overlap (train, val, test): \ + {len(set(train_df.index) & set(test_df.index)) + (len(set(train_df.index) & set(val_df.index))) == 0}") + + if missing_train > 0 or missing_test > 0: + logging.info(f" Database version discrepancy: {missing_train + missing_test} samples missing in total") + logging.info(f" Using 
# Failures that mean "this spectrum cannot be integrated": bad arithmetic
# (e.g. division by a zero frequency), mismatched lengths / missing labels,
# or inputs that are not numeric sequences. Anything else (including
# KeyboardInterrupt and SystemExit) is allowed to propagate.
_SPECTRUM_ERRORS = (ArithmeticError, LookupError, TypeError, ValueError)


def cal_lamb(freq_w, alpha_F):
    """Compute the electron-phonon coupling constant λ from the a²F spectrum.

    Evaluates ``2 * sum(alpha_F[i] / freq_w[i] * (freq_w[i] - freq_w[i-1]))``
    for ``i >= 1`` (right-endpoint rectangle rule; the first sample only
    supplies the left edge of the first interval).

    Args:
        freq_w: Frequency grid (indexable, supports ``len``).
        alpha_F: Spectrum values on the same grid.

    Returns:
        λ, or ``np.nan`` when the inputs cannot be integrated.
    """
    try:
        total = 0
        for i in range(1, len(freq_w)):
            w = freq_w[i]
            dw = w - freq_w[i - 1]
            total = total + ((alpha_F[i] / w) * dw)
        return 2 * total
    except _SPECTRUM_ERRORS:
        return np.nan


def cal_w_log(freq_w, alpha_F, lamb):
    """Compute the logarithmic average frequency ω_log from the a²F spectrum.

    The result is ``exp(2/λ * sum(alpha_F[i] * ln(w_i) * dw_i / w_i))`` and is
    expressed in the units of ``freq_w``; any conversion (e.g. to Kelvin) is
    left to the caller.

    Args:
        freq_w: Frequency grid (indexable, supports ``len``).
        alpha_F: Spectrum values on the same grid.
        lamb: Coupling constant λ for the same spectrum.

    Returns:
        ω_log, or ``np.nan`` when the inputs cannot be integrated.
    """
    try:
        total = 0
        for i in range(1, len(freq_w)):
            dw = freq_w[i] - freq_w[i - 1]
            total = total + (alpha_F[i] * np.log(freq_w[i]) * dw / freq_w[i])
        return np.exp(2 * total / lamb)
    except _SPECTRUM_ERRORS:
        return np.nan


def cal_w_sq(freq_w, alpha_F, lamb):
    """Compute the second-moment frequency ω₂ from the a²F spectrum.

    The result is ``sqrt(2/λ * sum(alpha_F[i] * w_i * dw_i))`` and is
    expressed in the units of ``freq_w``; any conversion (e.g. to Kelvin) is
    left to the caller.

    Args:
        freq_w: Frequency grid (indexable, supports ``len``).
        alpha_F: Spectrum values on the same grid.
        lamb: Coupling constant λ for the same spectrum.

    Returns:
        ω₂, or ``np.nan`` when the inputs cannot be integrated.
    """
    try:
        total = 0
        for i in range(1, len(freq_w)):
            dw = freq_w[i] - freq_w[i - 1]
            total = total + (alpha_F[i] * freq_w[i] * dw)
        return (2 * total / lamb) ** .5
    except _SPECTRUM_ERRORS:
        return np.nan
def radius_graph(pos, r, batch=None):
    """Build the directed edge list of all atom pairs closer than ``r``.

    Two distinct positions ``i`` and ``j`` are connected when
    ``0 < |pos[i] - pos[j]| <= r``. When ``batch`` is given, a pair is only
    connected if both atoms carry the same batch id, so atoms of different
    graphs packed into one batch are never linked. Only the positions passed
    in are compared; no periodic images are generated here.

    Args:
        pos: Tensor of positions, one row per atom.
        r: Cutoff radius.
        batch: Optional 1-D tensor with the graph id of every atom.

    Returns:
        int64 tensor of shape ``(2, num_edges)``; ``(2, 0)`` when no pair is
        within the cutoff.
    """
    distances = ms.ops.norm(pos.unsqueeze(1) - pos.unsqueeze(0), dim=2)

    # distance > 0 removes self-pairs (and exactly coincident positions).
    mask = ms.ops.logical_and(distances <= r, distances > 0)
    if batch is not None:
        same_graph = batch.unsqueeze(1) == batch.unsqueeze(0)
        mask = ms.ops.logical_and(mask, same_graph)

    # Cast to int64 before nonzero, as the previous implementation did.
    edge_indices = ms.ops.nonzero(mask.astype(ms.int64))
    if edge_indices.shape[0] == 0:
        return ms.ops.zeros((2, 0), dtype=ms.int64)

    return edge_indices.T
class Dropout(nn.Cell):
    """Dropout wrapper that becomes the identity for p <= 0 or p >= 1."""

    def __init__(self, irreps, p):
        super().__init__()
        self.irreps = irreps
        # A probability outside the open interval (0, 1) disables dropout.
        disabled = p <= 0 or p >= 1
        self.dropout = None if disabled else nn.Dropout(p=p)

    def construct(self, x):
        """Apply dropout to ``x`` when enabled, otherwise return ``x`` as is."""
        if self.dropout is None:
            return x
        return self.dropout(x)
def preprocess(self, data):
    """Extract ``(batch, edge_src, edge_dst, edge_vec)`` from ``data``.

    A missing or ``None`` ``data.batch`` is replaced by an all-zero int64
    vector with one entry per position. When ``data.edge_index`` is present,
    its two rows and ``data.edge_vec`` are returned as they are; otherwise the
    edges are built with ``radius_graph`` using ``self.max_radius`` and the
    edge vectors are ``pos[src] - pos[dst]``.
    """
    batch = getattr(data, 'batch', None)
    if batch is None:
        batch = ms.ops.zeros(data.pos.shape[0], dtype=ms.int64)

    stored_edges = getattr(data, 'edge_index', None)
    if stored_edges is not None:
        # Pre-computed graph: trust the stored edge vectors.
        return batch, stored_edges[0], stored_edges[1], data.edge_vec

    built_edges = radius_graph(data.pos, self.max_radius, batch)
    src = built_edges[0]
    dst = built_edges[1]
    return batch, src, dst, data.pos[src] - data.pos[dst]
def __init__(self, in_dim, em_dim, **kwargs):
    """Set up the base network plus a dense embedding of the raw inputs.

    A ``reduce_output=True`` request is intercepted: the base class is built
    with ``reduce_output=False`` and the request is remembered in
    ``self.pool`` instead. The embedding maps ``in_dim`` features to the
    dimension of ``self.irreps_in``, or to ``em_dim`` when the base network
    has no input irreps.
    """
    # Recorded before the base constructor runs, as in the original order.
    pool_requested = bool(kwargs.get("reduce_output", False) == True)
    self.pool = pool_requested
    if pool_requested:
        kwargs["reduce_output"] = False

    super().__init__(**kwargs)

    embed_width = em_dim if self.irreps_in is None else self.irreps_in.dim
    self.em = nn.Dense(in_dim, embed_width)
class EMDLoss(nn.Cell):
    """Earth mover's distance style loss between two batches of 1-D histograms.

    The loss is the absolute difference of the cumulative sums of ``p`` and
    ``q`` along axis 1, summed along that axis and averaged over the rest.
    """

    def __init__(self):
        super().__init__()

    def construct(self, p, q):
        """Return the mean over rows of ``sum(|cumsum(p) - cumsum(q)|)``.

        Args:
            p: Tensor whose axis 1 is the histogram axis.
            q: Tensor of the same shape as ``p``.
        """
        # The per-call INFO log of p.shape was removed: it was debugging
        # output emitted on every forward pass and has no effect on the loss.
        cdf_p = ms.ops.cumsum(p, axis=1)
        cdf_q = ms.ops.cumsum(q, axis=1)
        return ms.ops.abs(cdf_p - cdf_q).sum(axis=1).mean()
def train(model, dataloader, loss_fn, optimizer, max_iter, device=None):
    """Train ``model`` for ``max_iter`` epochs and return it.

    Gradients are obtained with ``mindspore.value_and_grad`` and applied by
    calling the optimizer on them, the functional MindSpore training step.
    The previous ``zero_grad()`` / ``loss.backward()`` / ``step()`` sequence
    is the PyTorch idiom and is not provided by MindSpore cells/optimizers.

    Args:
        model: Network called as ``model(batch)``.
        dataloader: Iterable of batches exposing a ``target`` attribute.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        optimizer: MindSpore optimizer exposing ``parameters`` and callable
            on the gradient tuple.
        max_iter: Number of epochs.
        device: Unused; kept for interface compatibility.

    Returns:
        The trained ``model``.
    """
    # Local import: this module has no top-level mindspore import.
    import mindspore as ms

    def forward_loss(batch):
        return loss_fn(model(batch), batch.target)

    # Differentiate w.r.t. exactly the parameters the optimizer updates so the
    # gradient tuple lines up with them.
    grad_fn = ms.value_and_grad(forward_loss, None, optimizer.parameters)

    model.set_train()

    for epoch in range(max_iter):
        total_loss = 0
        num_batches = 0
        for batch in tqdm(dataloader, desc=f'Epoch {epoch+1}/{max_iter}'):
            loss, grads = grad_fn(batch)
            optimizer(grads)

            total_loss += loss.asnumpy()
            num_batches += 1

        # Count batches instead of calling len() so plain iterables work and
        # an empty loader does not divide by zero.
        avg_loss = total_loss / num_batches if num_batches else float('nan')
        logging.info(f'Epoch {epoch+1}/{max_iter}, Average Loss: {avg_loss:.4f}')

    return model


def evaluate(model, dataloader, loss_fn, device=None):
    """Return the mean of ``loss_fn(model(batch), batch.target)`` over batches.

    The model is switched to inference mode with ``set_train(False)``
    (``nn.Cell`` has no ``set_eval`` method).

    Args:
        model: Network called as ``model(batch)``.
        dataloader: Iterable of batches exposing a ``target`` attribute.
        loss_fn: Callable returning an object with ``asnumpy()``.
        device: Unused; kept for interface compatibility.

    Returns:
        The average loss, or ``nan`` when ``dataloader`` yields no batch.
    """
    model.set_train(False)
    total_loss = 0
    num_batches = 0

    for batch in dataloader:
        loss = loss_fn(model(batch), batch.target)
        total_loss += loss.asnumpy()
        num_batches += 1

    if not num_batches:
        return float('nan')
    return total_loss / num_batches


def predict(model, dataloader, device=None):
    """Run ``model`` over ``dataloader`` and stack the outputs along axis 0.

    Args:
        model: Network called as ``model(batch)``; outputs expose ``asnumpy()``.
        dataloader: Iterable of batches.
        device: Unused; kept for interface compatibility.

    Returns:
        ``numpy.ndarray`` with all predictions concatenated along axis 0.

    Raises:
        ValueError: If ``dataloader`` yields no batch (nothing to concatenate).
    """
    model.set_train(False)
    predictions = [model(batch).asnumpy() for batch in dataloader]
    return np.concatenate(predictions, axis=0)
By integrating \n", + "spectral-function attention and a temperature-annealing schedule, it efficiently predicts the superconducting critical \n", + "temperature Tc. Its principal innovation lies in embedding the physical spectral function directly into the \n", + "graph-convolution process rather than treating it as a post hoc feature, enabling superior accuracy even with limited data." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Model Architecture\n", + "\n", + "The model architecture is shown below:\n", + "![model_structure](./images/model_structure.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Dataset Download\n", + "\n", + "The dataset can be downloaded from the following link: https://github.com/henniggroup/BETE-NET. \n", + "You need to download the indices and structures folders, as well as the database.json file." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Model Training\n", + "Import the required packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "import logging\n", + "import argparse\n", + "import sys\n", + "\n", + "sys.path.append('./src')\n", + "\n", + "from train_utils import full_training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Configuration of Model-Related Parameters and Definition of the Training Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "def main():\n", + " \"\"\"main\"\"\"\n", + " parser = argparse.ArgumentParser(description='BETE-NET Training with Configurable Batch Size')\n", + " \n", + " parser.add_argument('--config', type=str, default='FPD', \n", + " choices=['CSO', 'CPD', 'FPD'],\n", + " help='Model configuration (CSO/CPD/FPD)')\n", + " \n", + " 
parser.add_argument('--batch_size', type=int, default=256,\n", + " help='Batch size for training (default: 32)')\n", + " parser.add_argument('--max_epochs', type=int, default=100,\n", + " help='Maximum number of epochs (default: 100)')\n", + " parser.add_argument('--display_interval', type=int, default=5,\n", + " help='Display progress every N epochs (default: 5)')\n", + " parser.add_argument('--plot_interval', type=int, default=5,\n", + " help='Generate plots every N epochs (default: 10)')\n", + " \n", + " args = parser.parse_args()\n", + " \n", + " logging.info(f\"BETE-NET Training Configuration\")\n", + " logging.info(f\"=\" * 50)\n", + " logging.info(f\"Model Configuration: {args.config}\")\n", + " logging.info(f\"Batch Size: {args.batch_size}\")\n", + " logging.info(f\"Max Epochs: {args.max_epochs}\")\n", + " logging.info(f\"Display Interval: {args.display_interval}\")\n", + " logging.info(f\"Plot Interval: {args.plot_interval}\")\n", + " logging.info(f\"=\" * 50)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Code Training and Output Results Section" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + " logging.info(\"\\nStarting training...\")\n", + " \n", + " for i in range(0,10):\n", + " logging.info(f\"============ {i} ==============\")\n", + " results = full_training(\n", + " config_name=args.config,\n", + " max_epochs=args.max_epochs,\n", + " batch_size=args.batch_size,\n", + " display_interval=args.display_interval,\n", + " plot_interval=args.plot_interval,\n", + " idx = i\n", + " )\n", + " \n", + " logging.info(f\"\\nTraining completed!\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/MindChemistry/applications/bete-net/train.py b/MindChemistry/applications/bete-net/train.py new file mode 100644 index 
def main():
    """Parse the command line and run BETE-NET training for split indices 0-9.

    Options: ``--config`` (CSO/CPD/FPD), ``--batch_size``, ``--max_epochs``,
    ``--display_interval`` and ``--plot_interval``. ``full_training`` is
    invoked once per split index. A keyboard interrupt or any other exception
    stops the loop and is logged; it is not re-raised.
    """
    # Without any handler, INFO records are discarded by the logging module's
    # default WARNING threshold. basicConfig does nothing if the root logger
    # has already been configured elsewhere.
    logging.basicConfig(level=logging.INFO, format='%(message)s')

    parser = argparse.ArgumentParser(description='BETE-NET Training with Configurable Batch Size')

    parser.add_argument('--config', type=str, default='FPD',
                        choices=['CSO', 'CPD', 'FPD'],
                        help='Model configuration (CSO/CPD/FPD)')

    # %(default)s lets argparse print the real default, so the help text can
    # no longer disagree with the value (it previously said 32 and 10).
    parser.add_argument('--batch_size', type=int, default=256,
                        help='Batch size for training (default: %(default)s)')
    parser.add_argument('--max_epochs', type=int, default=100,
                        help='Maximum number of epochs (default: %(default)s)')
    parser.add_argument('--display_interval', type=int, default=5,
                        help='Display progress every N epochs (default: %(default)s)')
    parser.add_argument('--plot_interval', type=int, default=5,
                        help='Generate plots every N epochs (default: %(default)s)')

    args = parser.parse_args()

    logging.info(f"BETE-NET Training Configuration")
    logging.info(f"=" * 50)
    logging.info(f"Model Configuration: {args.config}")
    logging.info(f"Batch Size: {args.batch_size}")
    logging.info(f"Max Epochs: {args.max_epochs}")
    logging.info(f"Display Interval: {args.display_interval}")
    logging.info(f"Plot Interval: {args.plot_interval}")
    logging.info(f"=" * 50)

    logging.info("\nStarting training...")

    try:
        for i in range(10):
            logging.info(f"============ {i} ==============")
            full_training(
                config_name=args.config,
                max_epochs=args.max_epochs,
                batch_size=args.batch_size,
                display_interval=args.display_interval,
                plot_interval=args.plot_interval,
                idx=i,
            )
        logging.info(f"\nTraining completed!")

    except KeyboardInterrupt:
        logging.info(f"\nTraining interrupted by user")

    except Exception as e:
        # logging.exception records the message at ERROR level together with
        # the traceback, replacing the separate traceback.print_exc() call.
        logging.exception(f"\nTraining failed: {e}")
"下载链接:https://github.com/henniggroup/BETE-NET \n", + "需下载 indices、structures文件夹和database.json文件" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "模型训练\n", + "\n", + "引入代码包" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "import logging\n", + "import argparse\n", + "import sys\n", + "\n", + "sys.path.append('./src')\n", + "\n", + "from train_utils import full_training" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "模型相关参数的设置以及训练模型的定义" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "def main():\n", + " \"\"\"main\"\"\"\n", + " parser = argparse.ArgumentParser(description='BETE-NET Training with Configurable Batch Size')\n", + " \n", + " parser.add_argument('--config', type=str, default='FPD', \n", + " choices=['CSO', 'CPD', 'FPD'],\n", + " help='Model configuration (CSO/CPD/FPD)')\n", + " \n", + " parser.add_argument('--batch_size', type=int, default=256,\n", + " help='Batch size for training (default: 32)')\n", + " parser.add_argument('--max_epochs', type=int, default=100,\n", + " help='Maximum number of epochs (default: 100)')\n", + " parser.add_argument('--display_interval', type=int, default=5,\n", + " help='Display progress every N epochs (default: 5)')\n", + " parser.add_argument('--plot_interval', type=int, default=5,\n", + " help='Generate plots every N epochs (default: 10)')\n", + " \n", + " args = parser.parse_args()\n", + " \n", + " logging.info(f\"BETE-NET Training Configuration\")\n", + " logging.info(f\"=\" * 50)\n", + " logging.info(f\"Model Configuration: {args.config}\")\n", + " logging.info(f\"Batch Size: {args.batch_size}\")\n", + " logging.info(f\"Max Epochs: {args.max_epochs}\")\n", + " logging.info(f\"Display Interval: {args.display_interval}\")\n", + " 
logging.info(f\"Plot Interval: {args.plot_interval}\")\n", + " logging.info(f\"=\" * 50)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "代码训练与输出结果部分" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + " logging.info(f\"\\n 开始训练...\")\n", + " \n", + " for i in range(0,10):\n", + " logging.info(f\"============ {i} ==============\")\n", + " results = full_training(\n", + " config_name=args.config,\n", + " max_epochs=args.max_epochs,\n", + " batch_size=args.batch_size,\n", + " display_interval=args.display_interval,\n", + " plot_interval=args.plot_interval,\n", + " idx = i\n", + " )\n", + " \n", + " logging.info(f\"\\n训练完成!\")" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/MindChemistry/applications/bete-net/train_utils.py b/MindChemistry/applications/bete-net/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..cede1de973cb9b2775dcef7ef0cda3d3ffe044f7 --- /dev/null +++ b/MindChemistry/applications/bete-net/train_utils.py @@ -0,0 +1,538 @@ +# Copyright 2025 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
def _json_default(value):
    """Convert numpy scalars and arrays to plain Python values for JSON."""
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_training_state(epoch, train_losses, val_losses, config, start_time, save_path="training_state.json"):
    """Write the current training state to ``save_path`` as JSON.

    The file is written to ``<save_path>.tmp`` first and then moved into
    place, so an interruption while writing cannot leave a truncated state
    file behind. Loss values that are numpy scalars (not JSON-serializable by
    default) are converted to plain Python numbers.

    Args:
        epoch: Current epoch number.
        train_losses: Training loss history.
        val_losses: Validation loss history.
        config: JSON-serializable configuration mapping.
        start_time: ``datetime`` at which training started.
        save_path: Destination file.
    """
    now = datetime.now()
    state = {
        'epoch': epoch,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'config': config,
        'start_time': start_time.isoformat(),
        'current_time': now.isoformat(),
        'elapsed_hours': (now - start_time).total_seconds() / 3600
    }
    tmp_path = f"{save_path}.tmp"
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(state, f, indent=2, default=_json_default)
    os.replace(tmp_path, save_path)


def load_training_state(save_path="training_state.json"):
    """Load the training state written by ``save_training_state``.

    Args:
        save_path: File to read.

    Returns:
        The state dict with ``start_time`` parsed back into a ``datetime``,
        or ``None`` when ``save_path`` does not exist.

    Raises:
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Open directly instead of checking os.path.exists first, so the file
    # cannot disappear between the check and the read.
    try:
        with open(save_path, 'r', encoding='utf-8') as f:
            state = json.load(f)
    except FileNotFoundError:
        return None
    state['start_time'] = datetime.fromisoformat(state['start_time'])
    return state
plt.subplot(2, 2, 2) + recent_epochs = max(1, len(train_losses) - 20) + recent_range = range(recent_epochs, len(train_losses) + 1) + plt.plot(recent_range, train_losses[recent_epochs-1:], + 'b-', label='Training Loss', linewidth=2) + plt.plot(recent_range, val_losses[recent_epochs-1:], + 'r-', label='Validation Loss', linewidth=2) + plt.xlabel('Epoch') + plt.ylabel('Loss') + plt.title('Recent Progress (Last 20 Epochs)') + plt.legend() + plt.grid(True, alpha=0.3) + + # Loss distribution + plt.subplot(2, 2, 3) + plt.hist(train_losses, bins=20, alpha=0.7, + label='Training Loss', color='blue') + plt.hist(val_losses, bins=20, alpha=0.7, + label='Validation Loss', color='red') + plt.xlabel('Loss Value') + plt.ylabel('Frequency') + plt.title('Loss Distribution') + plt.legend() + + # Training statistics + plt.subplot(2, 2, 4) + stats_text = f"""Training Statistics: + +Current Epoch: {len(train_losses)} +Best Train Loss: {min(train_losses):.6f} +Best Val Loss: {min(val_losses):.6f} +Current Train Loss: {train_losses[-1]:.6f} +Current Val Loss: {val_losses[-1]:.6f} + +Improvement Rate: +Train: {((train_losses[0] - train_losses[-1]) / train_losses[0] * 100):.2f}% +Val: {((val_losses[0] - val_losses[-1]) / val_losses[0] * 100):.2f}% + """ + plt.text(0.1, 0.5, stats_text, fontsize=10, verticalalignment='center', + transform=plt.gca().transAxes, fontfamily='monospace') + plt.axis('off') + + plt.tight_layout() + + if save_path: + plt.savefig(save_path, dpi=300, bbox_inches='tight') + logging.info(f"Progress plot saved to: {save_path}") + + plt.show() + + +def display_progress_summary(epoch, total_epochs, train_loss, val_loss, train_losses, val_losses, + start_time, config_name, model_params): + """Display comprehensive progress summary""" + current_time = datetime.now() + elapsed = current_time - start_time + + # Calculate ETA + if epoch > 0: + avg_time_per_epoch = elapsed.total_seconds() / epoch + remaining_epochs = total_epochs - epoch + eta = current_time + \ + 
timedelta(seconds=avg_time_per_epoch * remaining_epochs) + else: + eta = None + + # Best losses + best_train = min(train_losses) + best_val = min(val_losses) + best_train_epoch = train_losses.index(best_train) + 1 + best_val_epoch = val_losses.index(best_val) + 1 + + logging.info(f"\n{'='*80}") + logging.info(f"TRAINING PROGRESS SUMMARY - {config_name}") + logging.info(f"{'='*80}") + logging.info( + f"Progress: Epoch {epoch}/{total_epochs} ({epoch/total_epochs*100:.1f}%)") + logging.info(f" Elapsed Time: {str(elapsed).split('.')[0]}") + if eta: + logging.info( + f" ETA: {eta.strftime('%Y-%m-%d %H:%M:%S')} \ + (≈{str(timedelta(seconds=avg_time_per_epoch * remaining_epochs)).split('.')[0]} remaining)") + + logging.info(f"\n Current Performance:") + logging.info(f" Training Loss: {train_loss:.6f}") + logging.info(f" Validation Loss: {val_loss:.6f}") + logging.info(f" Loss Ratio: {val_loss/train_loss:.3f}") + + logging.info(f"\n Best Performance:") + logging.info(f" Best Train Loss: {best_train:.6f} (Epoch {best_train_epoch})") + logging.info(f" Best Val Loss: {best_val:.6f} (Epoch {best_val_epoch})") + + if len(train_losses) >= 5: + recent_train_trend = np.mean( + train_losses[-5:]) - np.mean(train_losses[-10:-5]) if len(train_losses) >= 10 else 0 + recent_val_trend = np.mean( + val_losses[-5:]) - np.mean(val_losses[-10:-5]) if len(val_losses) >= 10 else 0 + + logging.info(f"\nRecent Trends (Last 5 epochs):") + trend_train = " Increasing" if recent_train_trend > 0 else "📉 Decreasing" \ + if recent_train_trend < 0 else "➡️ Stable" + trend_val = " Increasing" if recent_val_trend > 0 else "📉 Decreasing" \ + if recent_val_trend < 0 else "➡️ Stable" + logging.info(f" Training: {trend_train} ({recent_train_trend:+.6f})") + logging.info(f" Validation: {trend_val} ({recent_val_trend:+.6f})") + + logging.info(f"\n Model Configuration:") + # logging.info(f" Parameters: {sum(p.size for p in model_params):,}") + logging.info(f" Input Dim: {model_params.get('input_dim', 'N/A')}") 
def full_training(config_name="CPD", max_epochs=100, display_interval=5,
                  plot_interval=10, batch_size=32, idx=0, patience=20):
    """Run the complete BETE-NET training loop with progress tracking.

    Reads ``database.json`` and the matching ``structures/<index>.cif`` files,
    builds graph data for the selected configuration, trains a
    ``model_utils.PeriodicNetwork`` with AdamW and a MultiStepLR schedule, and
    saves a checkpoint every time the validation loss improves.

    Args:
        config_name: One of ``'CSO'``, ``'CPD'`` or ``'FPD'``.
        max_epochs: Upper bound on the number of training epochs.
        display_interval: Emit the detailed progress summary every N epochs.
        plot_interval: Plot the loss curves and save the training state every
            N epochs.
        batch_size: Number of samples per batch.
        idx: Forwarded to ``data_utils.get_original_data_split`` and used as
            the suffix of the checkpoint file name.
        patience: Number of consecutive epochs without validation improvement
            after which training stops early.

    Returns:
        dict with the keys ``'model'``, ``'train_losses'``, ``'val_losses'``
        and ``'config'`` (the configuration name).

    Raises:
        ValueError: If ``config_name`` is unknown, or if no crystal structure
            could be loaded.
    """
    import os  # local import: only needed to create the checkpoint directory

    logging.info(f"BETE-NET Full Training - {config_name} Configuration")
    logging.info(f"{'='*60}")
    logging.info(f" Start Time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    logging.info(f" Configuration: {config_name}")
    logging.info(f" Max Epochs: {max_epochs}")
    logging.info(f" Batch Size: {batch_size}")
    logging.info(f" Display Interval: Every {display_interval} epochs")
    logging.info(f" Plot Interval: Every {plot_interval} epochs")
    logging.info(f"{'='*60}\n")

    start_time = datetime.now()

    # Configuration settings
    # NOTE(review): the README lists a learning rate of 0.0005 for FPD while
    # every entry here uses 0.005 — confirm which value is intended.
    configs = {
        'CSO': {'in_dim': 118, 'embed_ph_dos': False, 'embed_e_dos': False,
                'fine': False, 'layers': 2, 'mul': 32, 'lr': 0.005},
        'CPD': {'in_dim': 118 + 51, 'embed_ph_dos': True, 'embed_e_dos': False,
                'fine': False, 'layers': 2, 'mul': 32, 'lr': 0.005},
        'FPD': {'in_dim': 118 + 51, 'embed_ph_dos': True, 'embed_e_dos': False,
                'fine': True, 'layers': 2, 'mul': 32, 'lr': 0.005},
    }

    if config_name not in configs:
        raise ValueError(
            f"Unknown configuration: {config_name}. Available: {list(configs.keys())}")

    config = configs[config_name]
    config_tag = config_name.lower()
    ckpt_dir = f"./{config_tag}"
    ckpt_path = f"{ckpt_dir}/best_{config_tag}_model_ms_{idx}.ckpt"

    # Load data
    logging.info("Loading database...")
    df = pd.read_json('database.json')
    df.dropna(inplace=True)
    logging.info(f"Loaded {len(df)} samples from database")

    # Load structures (best effort: entries that cannot be read are dropped)
    logging.info(" Loading crystal structures...")
    structures = []
    for index in tqdm(df.index, desc="Loading structures", ncols=100):
        try:
            structures.append(ase.io.read(f'structures/{index}.cif'))
        except Exception as e:
            # Keep going, but leave a trace of what was skipped and why.
            logging.warning(f"Skipping structure {index}: {e}")
            structures.append(None)

    df['structure'] = structures
    # .copy() so the column assignments below operate on an independent frame.
    df = df[df['structure'].notna()].copy()
    logging.info(f"Successfully loaded {len(df)} crystal structures")

    if df.empty:
        raise ValueError(
            "No crystal structures could be loaded from 'structures/'")

    # Process data
    logging.info(f" Processing data for {config_name} configuration...")
    r_max = 4
    df['target'] = df.apply(data_utils.get_target, axis=1)
    df['formula'] = df['structure'].map(lambda x: x.get_chemical_formula())

    tqdm.pandas(desc="Building graph data", ncols=100)
    df['data'] = df.progress_apply(
        data_utils.build_data,
        embed_ph_dos=config['embed_ph_dos'],
        embed_e_dos=config['embed_e_dos'],
        fine=config['fine'],
        r_max=r_max,
        axis=1
    )

    # Get dimensions
    sample_data = df.iloc[0]['data']
    out_dim = len(df.iloc[0]['target'])
    # NOTE(review): this takes the first element of ``x`` rather than a shape,
    # so "Input features" below presumably logs a feature row, not a count.
    # It is only used for logging; confirm and replace with the feature width.
    in_dim = sample_data.x[0]
    em_dim = 64

    logging.info("Data Information:")
    logging.info(f" - Total samples: {len(df)}")
    logging.info(f" - Input features: {in_dim}")
    logging.info(f" - Output targets: {out_dim}")
    logging.info(f" - Configuration: {config_name}")

    train_df, test_df, val_df = data_utils.get_original_data_split(df, idx)

    if val_df is None:
        # No validation split provided: hold out the last 20% of the
        # training rows.
        val_split_idx = int(len(train_df) * 0.8)
        val_df = train_df.iloc[val_split_idx:].copy()
        train_df = train_df.iloc[:val_split_idx].copy()

    logging.info("Data Split:")
    logging.info(
        f" - Training: {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
    logging.info(
        f" - Validation: {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)")
    logging.info(
        f" - Test: {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")

    # Create model
    model_params = {
        'in_dim': config['in_dim'],
        'em_dim': em_dim,
        'irreps_in': f'{em_dim}x0e',
        'irreps_out': f'{out_dim}x0e',
        'irreps_node_attr': f'{em_dim}x0e',
        'layers': config['layers'],
        'mul': config['mul'],
        'lmax': 1,
        'max_radius': r_max,
        'number_of_basis': 10,
        'radial_layers': 1,
        'radial_neurons': 128,
        'num_neighbors': data_utils.get_neighbors(train_df, train_df.index).mean(),
        'num_nodes': 8.0,
        'reduce_output': True,
        'dropout': False,
        'input_dim': in_dim,
        'output_dim': out_dim
    }

    logging.info("run full training")
    # 'input_dim' / 'output_dim' are reporting-only entries and are not
    # passed to the network constructor.
    network_kwargs = {k: v for k, v in model_params.items()
                      if k not in ('input_dim', 'output_dim')}
    model = model_utils.PeriodicNetwork(**network_kwargs)
    if config_name == 'CSO':
        model.pool = True
    param_count = sum(p.size for p in model.get_parameters())
    logging.info(f" Model created with {param_count:,} parameters")

    loss_fn = nn.MSELoss()
    optimizer = optimex.AdamW(model.trainable_params(), lr=config['lr'])
    scheduler = optimex.lr_scheduler.MultiStepLR(
        optimizer, milestones=[60, 120, 180], gamma=0.3)

    def forward_fn(data, targets):
        """Return (loss, prediction) for one batch."""
        pred = model(data)
        loss = loss_fn(pred, targets)
        return loss, pred

    grad_fn = ops.value_and_grad(
        forward_fn, None, optimizer.parameters, has_aux=True)

    # Training variables
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0

    logging.info("\n Starting Training...")
    logging.info(f" Learning Rate: {config['lr']}")
    logging.info(f" Early Stopping Patience: {patience}")
    logging.info(f" Model Save Path: {ckpt_path}")

    def create_batches(frame, batch_size, shuffle=True):
        """Split ``frame`` into (data, target-tensor) batches.

        Rows are always addressed by position (``iloc``): after filtering and
        splitting, the frame's index labels are no longer 0..n-1.
        """
        positions = np.arange(len(frame))
        if shuffle:
            np.random.shuffle(positions)

        batches = []
        for start in range(0, len(positions), batch_size):
            chunk = positions[start:start + batch_size]
            data_list = frame['data'].iloc[chunk].tolist()
            targets = ms.Tensor(
                frame['target'].iloc[chunk].tolist(), dtype=ms.float32)

            if len(data_list) == 1:
                # Single sample - no batching needed
                batches.append((data_list[0], targets))
            else:
                # Multiple samples - use gnn batching
                batches.append((Batch.from_data_list(data_list), targets))

        return batches

    # Training loop
    for epoch in range(max_epochs):
        epoch_start_time = time.time()

        # Create batches for this epoch
        train_batches = create_batches(train_df, batch_size, shuffle=True)
        val_batches = create_batches(val_df, batch_size, shuffle=False)

        # Training phase
        model.set_train()
        train_loss = 0.0
        train_count = 0

        train_pbar = tqdm(train_batches, desc=f"Epoch {epoch+1:3d}/{max_epochs} [Train]",
                          ncols=100, leave=False)

        for batch_data, batch_targets in train_pbar:
            if len(batch_targets.shape) == 1:
                # Single sample case
                batch_targets = batch_targets.expand_dims(0)

            # NOTE(review): the scheduler is stepped once per batch, so the
            # milestones [60, 120, 180] count batches rather than epochs.
            # Schedulers of this kind are presumably meant to be stepped per
            # epoch — confirm the intended schedule before changing it.
            scheduler.step()
            (loss, pred), grads = grad_fn(batch_data, batch_targets)
            optimizer(grads)

            current_loss = float(loss.asnumpy())
            train_loss += current_loss
            train_count += 1

            # Update progress bar
            train_pbar.set_postfix({
                'Loss': f'{current_loss:.6f}',
                'Avg': f'{train_loss/train_count:.6f}',
                'Batch': f'{batch_targets.shape[0]}'
            })

        avg_train_loss = train_loss / train_count
        train_losses.append(avg_train_loss)

        # Report the learning rate once per epoch instead of once per batch,
        # which interleaved with the progress bar output.
        logging.info(f" Learning rate: {scheduler.get_last_lr()}")

        # Validation phase
        model.set_train(False)
        val_loss = 0.0
        val_count = 0

        val_pbar = tqdm(val_batches, desc=f"Epoch {epoch+1:3d}/{max_epochs} [Val]",
                        ncols=100, leave=False)

        for batch_data, batch_targets in val_pbar:
            if len(batch_targets.shape) == 1:
                # Single sample case
                batch_targets = batch_targets.expand_dims(0)

            pred = model(batch_data)
            loss = loss_fn(pred, batch_targets)

            current_loss = float(loss.asnumpy())
            val_loss += current_loss
            val_count += 1

            val_pbar.set_postfix({
                'Loss': f'{current_loss:.6f}',
                'Avg': f'{val_loss/val_count:.6f}',
                'Batch': f'{batch_targets.shape[0]}'
            })

        avg_val_loss = val_loss / val_count
        val_losses.append(avg_val_loss)

        epoch_time = time.time() - epoch_start_time

        # Basic epoch summary
        improvement = "🟢" if avg_val_loss < best_val_loss else "🔴"
        logging.info(
            f"Epoch {epoch+1:3d}: Train={avg_train_loss:.6f}, Val={avg_val_loss:.6f} {improvement} ({epoch_time:.1f}s)")

        # Save best model
        if avg_val_loss < best_val_loss:
            best_val_loss = avg_val_loss
            # The per-configuration directory is not created anywhere else.
            os.makedirs(ckpt_dir, exist_ok=True)
            ms.save_checkpoint(model, ckpt_path)
            patience_counter = 0
            logging.info(
                f" New best model saved! (Val Loss: {best_val_loss:.6f})")
        else:
            patience_counter += 1

        # Detailed progress display
        if (epoch + 1) % display_interval == 0:
            display_progress_summary(epoch + 1, max_epochs, avg_train_loss, avg_val_loss,
                                     train_losses, val_losses, start_time, config_name, model_params)

        # Plot progress
        if (epoch + 1) % plot_interval == 0:
            plot_path = f"{config_tag}_training_progress_epoch_{epoch+1}.png"
            plot_training_progress(
                train_losses, val_losses, config_name, plot_path)

            # Save training state
            save_training_state(epoch + 1, train_losses, val_losses, config, start_time,
                                f"{config_tag}_training_state.json")

        # Early stopping
        if patience_counter >= patience:
            logging.info(
                f"\n🛑 Early stopping triggered after {patience} epochs without improvement")
            break

    # Final plot
    final_plot_path = f"{config_tag}_final_training_results.png"
    plot_training_progress(train_losses, val_losses,
                           config_name, final_plot_path)

    return {
        'model': model,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'config': config_name
    }


if __name__ == "__main__":

    CONFIG = "FPD"
    MAX_EPOCHS = 100
    BATCH_SIZE = 256
    DISPLAY_INTERVAL = 5
    PLOT_INTERVAL = 10

    logging.info("Starting BETE-NET Training")
    logging.info(f"Configuration: {CONFIG}")
    logging.info(f"Max Epochs: {MAX_EPOCHS}")
    logging.info(f"Batch Size: {BATCH_SIZE}")
    logging.info(f"Progress Display: Every {DISPLAY_INTERVAL} epochs")
    logging.info(f"Plot Generation: Every {PLOT_INTERVAL} epochs")
    logging.info("\nPress Ctrl+C to stop training gracefully...")

    try:
        results = full_training(
            config_name=CONFIG,
            max_epochs=MAX_EPOCHS,
            batch_size=BATCH_SIZE,
            display_interval=DISPLAY_INTERVAL,
            plot_interval=PLOT_INTERVAL
        )
        logging.info("\nTraining completed successfully!")

    except KeyboardInterrupt:
        logging.info("\n Training interrupted by user")
        logging.info("Training state has been saved and can be resumed later")

    except Exception as e:
        # logging.exception records the message at ERROR level together with
        # the traceback, so failures are visible even when INFO is filtered.
        logging.exception(f"\n Training failed with error: {e}")