diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/PhyMPGN-CF.ipynb b/MindFlow/applications/data_mechanism_fusion/phympgn/PhyMPGN-CF.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..027c2ea2e7eb22081f10ddcc6b0829b50f6c26bc
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/PhyMPGN-CF.ipynb
@@ -0,0 +1,419 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# PhyMPGN: Physics-encoded Message Passing Graph Network for spatiotemporal PDE systems"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Complex dynamical systems governed by partial differential equations (PDEs) exist in a wide variety of disciplines. Recent progresses have demonstrated grand benefits of data-driven neural-based models for predicting spatiotemporal dynamics.\n",
+ "\n",
+ "Physics-encoded Message Passing Graph Network (PhyMPGN) is capable to model spatiotemporal PDE systems on irregular meshes given small training datasets. Specifically:\n",
+ "\n",
+ "- A physics-encoded grapph learning model with the message-passing mechanism is proposed, where the temporal marching is realized via a second-order numerical integrator (e.g. Runge-Kutta scheme)\n",
+ "- Considering the universality of diffusion processes in physical phenomena, a learnable Laplace Block is designed, which encodes the discrete Laplace-Beltrami operator\n",
+ "- A novel padding strategy to encode different types of BCs into the learning model is proposed.\n",
+ "\n",
+ "Paper link: [https://arxiv.org/abs/2410.01337](https://gitee.com/link?target=https%3A%2F%2Farxiv.org%2Fabs%2F2410.01337)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Problem Setup"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Let's consider complex physical systems, governed by spatiotemporal PDEs in the general form:\n",
+ "\n",
+ "$$\n",
+ "\\begin{equation}\n",
+ "\\dot {\\boldsymbol {u}}(\\boldsymbol x, t) = \\boldsymbol F (t, \\boldsymbol x, \\boldsymbol u, \\nabla \\boldsymbol u, \\Delta \\boldsymbol u, \\dots)\n",
+ "\\end{equation}\n",
+ "$$\n",
+ "\n",
+ "where $\\boldsymbol u(\\boldsymbol x, y) \\in \\mathbb{R}^m$ is the vector of state variable with $m$ components,such as velocity, temperature or pressure, defined over the spatiotemporal domain $\\{ \\boldsymbol x, t \\} \\in \\Omega \\times [0, \\mathcal{T}]$. Here, $\\dot{\\boldsymbol u}$ denotes the derivative with respect to time and $\\boldsymbol F$ is a nonlinear operator that depends on the current state $\\boldsymbol u$ and its spatial derivatives.\n",
+ "\n",
+ "We focus on a spatial domain $\\Omega$ with non-uniformly and sparsely observed nodes $\\{ \\boldsymbol x_0, \\dots, \\boldsymbol x_{N-1} \\}$ (e.g., on an unstructured mesh). Observations $\\{ \\boldsymbol U(t_0), \\dots, \\boldsymbol U(t_{T-1}) \\}$ are collected at time points $t_0, ... \\dots, t_{T- 1}$, where $\\boldsymbol U(t_i) = \\{ \\boldsymbol u(\\boldsymbol x_0, t_i), \\dots, \\boldsymbol u (\\boldsymbol x_{N-1}, t_i) \\}$ denote the physical quantities. Considering that many physical phenomena involve diffusion processes, we assume the diffusion term in the PDE is known as a priori knowledge. Our goal is to develop a graph learning model with small training datasets capable of accurately predicting various spatiotemporal dynamics on coarse unstructured meshes, handling different types of BCs, and producing the trajectory of dynamics for an arbitrarily given IC."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "This case demonstrates how PhyMPGN solves the cylinder flow problem."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The dynamical system of two-dimensional cylinder flow is governed by Navier-Stokes equation\n",
+ "\n",
+ "$$\n",
+ "\\begin{equation}\n",
+ "\\dot{\\boldsymbol{u}} = - \\boldsymbol{u} \\cdot \\nabla \\boldsymbol{u} -\\frac{1}{\\rho} \\nabla p + \\frac{\\mu}{\\rho} \\Delta \\boldsymbol{u} + \\boldsymbol{f}\n",
+ "\\tag{2}\n",
+ "\\end{equation}\n",
+ "$$\n",
+ "\n",
+ "Where the fluid density $\\rho$ is 1,the fluid viscosity $\\mu$ is $5\\times10^{-3}$,and the external force $f$ is 0。The cylinder flow system has an inlet on the left boundary, an outlet on the right boundary, a no-slip boundary condition on the cylinder surface, and symmetric boundary conditions on the top and bottom boundaries. This case study focuses on generalizing the inflow velocity $U_m$ while keeping the fluid density $\\rho$, cylinder diameter $D=2$, and fluid viscosity $\\mu$ constant. Since the Reynolds number is defined as $Re=\\rho U_m D/ \\mu$, generalizing the inflow velocity $U_m$ inherently means generalizing different Reynolds numbers."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Model Architecture"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "For Equation (1), a second-order Runge-Kutta (RK2) scheme can be used for discretization:\n",
+ "\n",
+ "$$\n",
+ "\\begin{equation}\n",
+ "\\boldsymbol u^{k+1} = \\boldsymbol u^k + \\frac{1}{2}(\\boldsymbol g_1 + \\boldsymbol g_2); \\quad \\boldsymbol g_1 = \\boldsymbol F(t^k, \\boldsymbol x, \\boldsymbol u^k, \\dots); \\quad \\boldsymbol g_2 = \\boldsymbol F(t^{k+1}, \\boldsymbol x, \\boldsymbol u^k + \\delta t \\boldsymbol g_1, \\dots)\n",
+ "\\end{equation}\n",
+ "$$\n",
+ "\n",
+ "where $\\boldsymbol u^k$ is the state variable at time $t^k$,and $\\delta t$ denotes the time interval between $t^k$ and $t^{k+1}$. According to the Equation (2), we develop a GNN to learn the nonlinear operator $\\boldsymbol F$.\n",
+ "\n",
+ "As shown in Figure, the NN block aims to learn the nonlinear operator $\\boldsymbol F$ and consists of two parts: a GNN block followed the Encode-Process-Decode module and a learnable Laplace block. Due to the universality of diffusion processes in physical phenomena, we design the learnable Laplace block, which encodes the discrete Laplace-Beltrami operator, to learn the increment caused by the diffusion term in the PDE, while the GNN block is responsible to learn the increment induced by other unknown mechanisms or sources."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Preparation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- Make sure the required dependency libraries (such as MindSpore) have been installed\n",
+ "- Ensure the [cylinder flow dataset](https://download-mindspore.osinfra.cn/mindscience/mindflow/dataset/applications/data_mechanism_fusion/PhyMPGN/) has been downloaded\n",
+ "- Verify that the data and model weight storage paths have been properly configured in the `yamls/train.yaml` configuration file"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## Code Execution Steps"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The code execution flow consists of the following steps:\n",
+ "\n",
+ "1. Read configuration file\n",
+ "2. Build dataset\n",
+ "3. Construct model\n",
+ "4. Model training\n",
+ "5. Model inference\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Read configuration file"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import os.path as osp\n",
+ "from pathlib import Path\n",
+ "from easydict import EasyDict\n",
+ "from mindflow.utils import log_config, load_yaml_config, print_log\n",
+ "\n",
+ "\n",
+ "def load_config(config_file_path, train):\n",
+ " config = load_yaml_config(config_file_path)\n",
+ " config['train'] = train\n",
+ " config = EasyDict(config)\n",
+ " log_dir = './logs'\n",
+ " if train:\n",
+ " log_file = f'phympgn-{config.experiment_name}'\n",
+ " else:\n",
+ " log_file = f'phympgn-{config.experiment_name}-te'\n",
+ " if not osp.exists(osp.join(log_dir, f'{log_file}.log')):\n",
+ " Path(osp.join(log_dir, f'{log_file}.log')).touch()\n",
+ " log_config(log_dir, log_file)\n",
+ " print_log(config)\n",
+ " return config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config_file_path = 'yamls/train.yaml'\n",
+ "config = load_config(config_file_path=config_file_path, train=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mindspore as ms\n",
+ "\n",
+ "ms.set_device(device_target='Ascend', device_id=7)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Build dataset"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src import PDECFDataset, get_data_loader\n",
+ "\n",
+ "\n",
+ "print_log('Train...')\n",
+ "print_log('Loading training data...')\n",
+ "tr_dataset = PDECFDataset(\n",
+ " root=config.path.data_root_dir,\n",
+ " raw_files=config.path.tr_raw_data,\n",
+ " dataset_start=config.data.dataset_start,\n",
+ " dataset_used=config.data.dataset_used,\n",
+ " time_start=config.data.time_start,\n",
+ " time_used=config.data.time_used,\n",
+ " window_size=config.data.tr_window_size,\n",
+ " training=True\n",
+ ")\n",
+ "tr_loader = get_data_loader(\n",
+ " dataset=tr_dataset,\n",
+ " batch_size=config.optim.batch_size\n",
+ ")\n",
+ "\n",
+ "print_log('Loading validation data...')\n",
+ "val_dataset = PDECFDataset(\n",
+ " root=config.path.data_root_dir,\n",
+ " raw_files=config.path.val_raw_data,\n",
+ " dataset_start=config.data.dataset_start,\n",
+ " dataset_used=config.data.dataset_used,\n",
+ " time_start=config.data.time_start,\n",
+ " time_used=config.data.time_used,\n",
+ " window_size=config.data.val_window_size\n",
+ ")\n",
+ "val_loader = get_data_loader(\n",
+ " dataset=val_dataset,\n",
+ " batch_size=config.optim.batch_size\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Construct model"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src import PhyMPGN\n",
+ "\n",
+ "print_log('Building model...')\n",
+ "model = PhyMPGN(\n",
+ " encoder_config=config.network.encoder_config,\n",
+ " mpnn_block_config=config.network.mpnn_block_config,\n",
+ " decoder_config=config.network.decoder_config,\n",
+ " laplace_block_config=config.network.laplace_block_config,\n",
+ " integral=config.network.integral\n",
+ ")\n",
+ "print_log(f'Number of parameters: {model.num_params}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Model training"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "from mindflow import get_multi_step_lr\n",
+ "from mindspore import nn\n",
+ "\n",
+ "from src import Trainer, TwoStepLoss\n",
+ "\n",
+ "\n",
+ "lr_scheduler = get_multi_step_lr(\n",
+ " lr_init=config.optim.lr,\n",
+ " milestones=list(np.arange(0, config.optim.start_epoch+config.optim.epochs,\n",
+ " step=config.optim.steplr_size)[1:]),\n",
+ " gamma=config.optim.steplr_gamma,\n",
+ " steps_per_epoch=len(tr_loader),\n",
+ " last_epoch=config.optim.start_epoch+config.optim.epochs-1\n",
+ ")\n",
+ "optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler,\n",
+ " eps=1.0e-8, weight_decay=1.0e-2)\n",
+ "trainer = Trainer(\n",
+ " model=model, optimizer=optimizer, scheduler=lr_scheduler, config=config,\n",
+ " loss_func=TwoStepLoss()\n",
+ ")\n",
+ "trainer.train(tr_loader, val_loader)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[Epoch 1/1600] Batch Time: 2.907 (3.011) Data Time: 0.021 (0.035) Graph Time: 0.004 (0.004) Grad Time: 2.863 (2.873) Optim Time: 0.006 (0.022)\n",
+ "\n",
+ "[Epoch 1/1600] Batch Time: 1.766 (1.564) Data Time: 0.022 (0.044) Graph Time: 0.003 (0.004)\n",
+ "\n",
+ "[Epoch 1/1600] tr_loss: 1.36e-02 val_loss: 1.29e-02 [MIN]\n",
+ "\n",
+ "[Epoch 2/1600] Batch Time: 3.578 (3.181) Data Time: 0.024 (0.038) Graph Time: 0.004 (0.004) Grad Time: 3.531 (3.081) Optim Time: 0.004 (0.013)\n",
+ "\n",
+ "[Epoch 2/1600] Batch Time: 1.727 (1.664) Data Time: 0.023 (0.042) Graph Time: 0.003 (0.004)\n",
+ "\n",
+ "[Epoch 2/1600] tr_loss: 1.15e-02 val_loss: 9.55e-03 [MIN]\n",
+ "\n",
+ "...\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### Model inference"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config_file_path = 'yamls/train.yaml'\n",
+ "config = load_config(config_file_path=config_file_path, train=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mindspore as ms\n",
+ "\n",
+ "ms.set_device(device_target='Ascend', device_id=7)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mindspore import nn\n",
+ "from src import PDECFDataset, get_data_loader, Trainer, PhyMPGN\n",
+ "\n",
+ "\n",
+ "# test datasets\n",
+ "te_dataset = PDECFDataset(\n",
+ " root=config.path.data_root_dir,\n",
+ " raw_files=config.path.te_raw_data,\n",
+ " dataset_start=config.data.te_dataset_start,\n",
+ " dataset_used=config.data.te_dataset_used,\n",
+ " time_start=config.data.time_start,\n",
+ " time_used=config.data.time_used,\n",
+ " window_size=config.data.te_window_size,\n",
+ " training=False\n",
+ ")\n",
+ "te_loader = get_data_loader(\n",
+ " dataset=te_dataset,\n",
+ " batch_size=1,\n",
+ " shuffle=False,\n",
+ ")\n",
+ "print_log('Building model...')\n",
+ "model = PhyMPGN(\n",
+ " encoder_config=config.network.encoder_config,\n",
+ " mpnn_block_config=config.network.mpnn_block_config,\n",
+ " decoder_config=config.network.decoder_config,\n",
+ " laplace_block_config=config.network.laplace_block_config,\n",
+ " integral=config.network.integral\n",
+ ")\n",
+ "print_log(f'Number of parameters: {model.num_params}')\n",
+ "trainer = Trainer(\n",
+ " model=model, optimizer=None, scheduler=None, config=config,\n",
+ " loss_func=nn.MSELoss()\n",
+ ")\n",
+ "print_log('Test...')\n",
+ "trainer.test(te_loader)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "zbc_ms2.5.0",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/PhyMPGN-CF_CN.ipynb b/MindFlow/applications/data_mechanism_fusion/phympgn/PhyMPGN-CF_CN.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..ade33a5dea1b353ffd99d7aa25c9eed1600af500
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/PhyMPGN-CF_CN.ipynb
@@ -0,0 +1,421 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 用于时空PDE系统的物理编码消息传递图神经网络 (PhyMPGN: Physics-encoded Message Passing Graph Network for spatiotemporal PDE systems)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "偏微分方程(PDEs)控制的复杂动力系统广泛存在于各个学科当中。近年来,数据驱动的神经网络模型在预测时空动态上取得了极好的效果。\n",
+ "\n",
+ "物理编码的消息传递图网络(PhyMPGN),可以使用少量训练数据在不规则计算域上建模时空PDE系统。具体来说,\n",
+ "\n",
+ "- 提出了一个使用消息传递机制的物理编码图学习模型,使用二阶龙格库塔(Runge-Kutta)数值方案进行时间步进\n",
+ "- 考虑到物理现象中普遍存在扩散过程,设计了一个可学习的Laplace Block,编码了离散拉普拉斯-贝尔特拉米算子(Laplace-Beltrami Operator)\n",
+ "- 提出了一个新颖的填充策略在模型中编码不同类型的边界条件\n",
+ "\n",
+ "论文链接: [https://arxiv.org/abs/2410.01337](https://gitee.com/link?target=https%3A%2F%2Farxiv.org%2Fabs%2F2410.01337)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 问题描述"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "考虑由如下形式控制的时空PDE系统:\n",
+ "\n",
+ "$$\n",
+ "\\begin{equation}\n",
+ "\\dot {\\boldsymbol {u}}(\\boldsymbol x, t) = \\boldsymbol F (t, \\boldsymbol x, \\boldsymbol u, \\nabla \\boldsymbol u, \\Delta \\boldsymbol u, \\dots)\n",
+ "\\tag{1}\n",
+ "\\end{equation}\n",
+ "$$\n",
+ "\n",
+ "其中$\\boldsymbol u(\\boldsymbol x, y) \\in \\mathbb{R}^m$是具有$m$个分量的状态变量向量,例如速度、温度或者压力等,它的定义在时空域$\\{ \\boldsymbol x, t \\} \\in \\Omega \\times [0, \\mathcal{T}]$上;$\\dot{\\boldsymbol u}$代表$\\boldsymbol u$对时间的导数,$\\boldsymbol F$是依赖于当前状态$\\boldsymbol u$和其空间导数的非线性算子。\n",
+ "\n",
+ "假设在空间域$\\Omega$上有着非均匀且稀疏的观测结点$\\{ \\boldsymbol x_0, \\dots, \\boldsymbol x_{N-1} \\}$(即,非结构化网格),在时刻$t_0, \\dots, t_{T-1}$,这些结点上的观测为$\\{ \\boldsymbol U(t_0), \\dots, \\boldsymbol U(t_{T-1}) \\}$,其中的$\\boldsymbol U(t_i) = \\{ \\boldsymbol u(\\boldsymbol x_0, t_i), \\dots, \\boldsymbol u (\\boldsymbol x_{N-1}, t_i) \\}$代表某些物理量。考虑到很多物理现象包含扩散过程,我们假设PDE中的扩散项是已知的先验知识。我们的目标是使用少量训练数据学习一个图神经网络模型,在稀疏非结构网格上预测不同的时空动态系统,处理不同的边界条件,为任意的初始条件产生后续动态轨迹。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "本案例展示PhyMPGN如何求解圆柱绕流(Cylinder Flow)问题。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "圆柱绕流 (Cylinder Flow) 动态系统由如下的Navier-Stokes方程控制\n",
+ "\n",
+ "$$\n",
+ "\\begin{equation}\n",
+ "\\dot{\\boldsymbol{u}} = - \\boldsymbol{u} \\cdot \\nabla \\boldsymbol{u} -\\frac{1}{\\rho} \\nabla p + \\frac{\\mu}{\\rho} \\Delta \\boldsymbol{u} + \\boldsymbol{f}\n",
+ "\\tag{2}\n",
+ "\\end{equation}\n",
+ "$$\n",
+ "\n",
+ "其中流体密度$\\rho=1$,流体粘度系数$\\mu=5.0\\times10^{-3}$,外力$f=0$。该圆柱绕流系统左边界为入口,右边界为出口,圆柱表面为无滑移边界条件,上下边界为对称边界条件。本案例关注于,在保持流体密度$\\rho$,圆柱大小$D=2$和流体粘度系数$\\mu$不变的情况下,泛化入射流速度$U_m$。因为雷诺数定义为$Re=\\rho U_m D / \\mu$,所以泛化入射流速度$U_m$也意味着泛化不同的雷诺数。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 模型方法"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "
"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "对于式(1),可以使用二阶龙格库塔(Runge-Kutta, RK2)方案进行离散:\n",
+ "\n",
+ "$$\n",
+ "\\begin{equation}\n",
+ "\\boldsymbol u^{k+1} = \\boldsymbol u^k + \\frac{1}{2}(\\boldsymbol g_1 + \\boldsymbol g_2); \\quad \\boldsymbol g_1 = \\boldsymbol F(t^k, \\boldsymbol x, \\boldsymbol u^k, \\dots); \\quad \\boldsymbol g_2 = \\boldsymbol F(t^{k+1}, \\boldsymbol x, \\boldsymbol u^k + \\delta t \\boldsymbol g_1, \\dots)\n",
+ "\\tag{3}\n",
+ "\\end{equation}\n",
+ "$$\n",
+ "\n",
+ "其中$\\boldsymbol u^k$为$t^k$时刻的状态变量,$\\delta t$为时刻$t^k$和$t^{k+1}$之间的时间间隔。根据式(3),我们构建一个GNN来学习非线性算子$\\boldsymbol F$.\n",
+ "\n",
+ "如图所示,我们使用NN block来学习非线性算子$\\boldsymbol F$。NN block又可以分为两部分:采用编码器-处理器-解码器架构的GNN block和可学习的Laplace block。因为物理现象中扩散过程的普遍存在性,我们设计了可学习的Laplace block,编码离散拉普拉斯贝尔特拉米算子(Laplace-Beltrami operator),来学习由PDE中扩散项导致的增量;而GNN block来学习PDE中其他项导致的增量。"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 准备环节"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "- 确保已安装相关依赖库,如Mindspore等\n",
+ "- 确保已下载好[圆柱绕流数据](https://download-mindspore.osinfra.cn/mindscience/mindflow/dataset/applications/data_mechanism_fusion/PhyMPGN/)\n",
+ "- 确保在`yamls/train.yaml`配置文件中已配置好数据和模型权重等相关保存路径"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 代码执行步骤"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "代码执行流程如下步骤:\n",
+ "\n",
+ "1. 读取配置文件\n",
+ "2. 构建数据集\n",
+ "3. 构建模型\n",
+ "4. 模型训练\n",
+ "5. 模型推理\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "### 读取配置文件"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mindflow.utils import log_config, load_yaml_config, print_log\n",
+ "from easydict import EasyDict\n",
+ "import os.path as osp\n",
+ "from pathlib import Path\n",
+ "\n",
+ "\n",
+ "def load_config(config_file_path, train):\n",
+ " config = load_yaml_config(config_file_path)\n",
+ " config['train'] = train\n",
+ " config = EasyDict(config)\n",
+ " log_dir = './logs'\n",
+ " if train:\n",
+ " log_file = f'phympgn-{config.experiment_name}'\n",
+ " else:\n",
+ " log_file = f'phympgn-{config.experiment_name}-te'\n",
+ " if not osp.exists(osp.join(log_dir, f'{log_file}.log')):\n",
+ " Path(osp.join(log_dir, f'{log_file}.log')).touch()\n",
+ " log_config(log_dir, log_file)\n",
+ " print_log(config)\n",
+ " return config"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config_file_path = 'yamls/train.yaml'\n",
+ "config = load_config(config_file_path=config_file_path, train=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mindspore as ms\n",
+ "\n",
+ "ms.set_device(device_target='Ascend', device_id=7)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 构建数据集"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src import PDECFDataset, get_data_loader\n",
+ "\n",
+ "\n",
+ "print_log('Train...')\n",
+ "print_log('Loading training data...')\n",
+ "tr_dataset = PDECFDataset(\n",
+ " root=config.path.data_root_dir,\n",
+ " raw_files=config.path.tr_raw_data,\n",
+ " dataset_start=config.data.dataset_start,\n",
+ " dataset_used=config.data.dataset_used,\n",
+ " time_start=config.data.time_start,\n",
+ " time_used=config.data.time_used,\n",
+ " window_size=config.data.tr_window_size,\n",
+ " training=True\n",
+ ")\n",
+ "tr_loader = get_data_loader(\n",
+ " dataset=tr_dataset,\n",
+ " batch_size=config.optim.batch_size\n",
+ ")\n",
+ "\n",
+ "print_log('Loading validation data...')\n",
+ "val_dataset = PDECFDataset(\n",
+ " root=config.path.data_root_dir,\n",
+ " raw_files=config.path.val_raw_data,\n",
+ " dataset_start=config.data.dataset_start,\n",
+ " dataset_used=config.data.dataset_used,\n",
+ " time_start=config.data.time_start,\n",
+ " time_used=config.data.time_used,\n",
+ " window_size=config.data.val_window_size\n",
+ ")\n",
+ "val_loader = get_data_loader(\n",
+ " dataset=val_dataset,\n",
+ " batch_size=config.optim.batch_size\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 构建模型"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src import PhyMPGN\n",
+ "\n",
+ "print_log('Building model...')\n",
+ "model = PhyMPGN(\n",
+ " encoder_config=config.network.encoder_config,\n",
+ " mpnn_block_config=config.network.mpnn_block_config,\n",
+ " decoder_config=config.network.decoder_config,\n",
+ " laplace_block_config=config.network.laplace_block_config,\n",
+ " integral=config.network.integral\n",
+ ")\n",
+ "print_log(f'Number of parameters: {model.num_params}')"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 模型训练"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from mindflow import get_multi_step_lr\n",
+ "from mindspore import nn\n",
+ "import numpy as np\n",
+ "\n",
+ "from src import Trainer, TwoStepLoss\n",
+ "\n",
+ "\n",
+ "lr_scheduler = get_multi_step_lr(\n",
+ " lr_init=config.optim.lr,\n",
+ " milestones=list(np.arange(0, config.optim.start_epoch+config.optim.epochs,\n",
+ " step=config.optim.steplr_size)[1:]),\n",
+ " gamma=config.optim.steplr_gamma,\n",
+ " steps_per_epoch=len(tr_loader),\n",
+ " last_epoch=config.optim.start_epoch+config.optim.epochs-1\n",
+ ")\n",
+ "optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler,\n",
+ " eps=1.0e-8, weight_decay=1.0e-2)\n",
+ "trainer = Trainer(\n",
+ " model=model, optimizer=optimizer, scheduler=lr_scheduler, config=config,\n",
+ " loss_func=TwoStepLoss()\n",
+ ")\n",
+ "trainer.train(tr_loader, val_loader)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "[Epoch 1/1600] Batch Time: 2.907 (3.011) Data Time: 0.021 (0.035) Graph Time: 0.004 (0.004) Grad Time: 2.863 (2.873) Optim Time: 0.006 (0.022)\n",
+ "\n",
+ "[Epoch 1/1600] Batch Time: 1.766 (1.564) Data Time: 0.022 (0.044) Graph Time: 0.003 (0.004)\n",
+ "\n",
+ "[Epoch 1/1600] tr_loss: 1.36e-02 val_loss: 1.29e-02 [MIN]\n",
+ "\n",
+ "[Epoch 2/1600] Batch Time: 3.578 (3.181) Data Time: 0.024 (0.038) Graph Time: 0.004 (0.004) Grad Time: 3.531 (3.081) Optim Time: 0.004 (0.013)\n",
+ "\n",
+ "[Epoch 2/1600] Batch Time: 1.727 (1.664) Data Time: 0.023 (0.042) Graph Time: 0.003 (0.004)\n",
+ "\n",
+ "[Epoch 2/1600] tr_loss: 1.15e-02 val_loss: 9.55e-03 [MIN]\n",
+ "\n",
+ "...\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "## 模型推理"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "config_file_path = 'yamls/train.yaml'\n",
+ "config = load_config(config_file_path=config_file_path, train=False)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import mindspore as ms\n",
+ "\n",
+ "ms.set_device(device_target='Ascend', device_id=7)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from src import PDECFDataset, get_data_loader, Trainer, PhyMPGN\n",
+ "from mindspore import nn\n",
+ "\n",
+ "\n",
+ "# test datasets\n",
+ "te_dataset = PDECFDataset(\n",
+ " root=config.path.data_root_dir,\n",
+ " raw_files=config.path.te_raw_data,\n",
+ " dataset_start=config.data.te_dataset_start,\n",
+ " dataset_used=config.data.te_dataset_used,\n",
+ " time_start=config.data.time_start,\n",
+ " time_used=config.data.time_used,\n",
+ " window_size=config.data.te_window_size,\n",
+ " training=False\n",
+ ")\n",
+ "te_loader = get_data_loader(\n",
+ " dataset=te_dataset,\n",
+ " batch_size=1,\n",
+ " shuffle=False,\n",
+ ")\n",
+ "print_log('Building model...')\n",
+ "model = PhyMPGN(\n",
+ " encoder_config=config.network.encoder_config,\n",
+ " mpnn_block_config=config.network.mpnn_block_config,\n",
+ " decoder_config=config.network.decoder_config,\n",
+ " laplace_block_config=config.network.laplace_block_config,\n",
+ " integral=config.network.integral\n",
+ ")\n",
+ "print_log(f'Number of parameters: {model.num_params}')\n",
+ "trainer = Trainer(\n",
+ " model=model, optimizer=None, scheduler=None, config=config,\n",
+ " loss_func=nn.MSELoss()\n",
+ ")\n",
+ "print_log('Test...')\n",
+ "trainer.test(te_loader)"
+ ]
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "zbc_ms2.5.0",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.11.11"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/README.md b/MindFlow/applications/data_mechanism_fusion/phympgn/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5dfb7543f73029fddad74f03ea28e280c7a555e4
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/README.md
@@ -0,0 +1,113 @@
+# PhyMPGN: Physics-encoded Message Passing Graph Network for spatiotemporal PDE systems
+
+Complex dynamical systems governed by partial differential equations (PDEs) exist in a wide variety of disciplines. Recent progress has demonstrated the great benefit of data-driven neural models for predicting spatiotemporal dynamics.
+
+Physics-encoded Message Passing Graph Network (PhyMPGN) is capable of modeling spatiotemporal PDE systems on irregular meshes given small training datasets. Specifically:
+
+- A physics-encoded graph learning model with the message-passing mechanism is proposed, where the temporal marching is realized via a second-order numerical integrator (e.g., a Runge-Kutta scheme)
+- Considering the universality of diffusion processes in physical phenomena, a learnable Laplace Block is designed, which encodes the discrete Laplace-Beltrami operator
+- A novel padding strategy to encode different types of BCs into the learning model is proposed.
+
+Paper link: [https://arxiv.org/abs/2410.01337](https://gitee.com/link?target=https%3A%2F%2Farxiv.org%2Fabs%2F2410.01337)
+
+## Problem Setup
+
+Let's consider complex physical systems, governed by spatiotemporal PDEs in the general form:
+
+$$
+\begin{equation}
+\dot {\boldsymbol {u}}(\boldsymbol x, t) = \boldsymbol F (t, \boldsymbol x, \boldsymbol u, \nabla \boldsymbol u, \Delta \boldsymbol u, \dots)
+\tag{1}
+\end{equation}
+$$
+
+where $\boldsymbol u(\boldsymbol x, t) \in \mathbb{R}^m$ is the vector of state variables with $m$ components, such as velocity, temperature or pressure, defined over the spatiotemporal domain $\{ \boldsymbol x, t \} \in \Omega \times [0, \mathcal{T}]$. Here, $\dot{\boldsymbol u}$ denotes the derivative with respect to time and $\boldsymbol F$ is a nonlinear operator that depends on the current state $\boldsymbol u$ and its spatial derivatives.
+
+We focus on a spatial domain $\Omega$ with non-uniformly and sparsely observed nodes $\{ \boldsymbol x_0, \dots, \boldsymbol x_{N-1} \}$ (e.g., on an unstructured mesh). Observations $\{ \boldsymbol U(t_0), \dots, \boldsymbol U(t_{T-1}) \}$ are collected at time points $t_0, \dots, t_{T-1}$, where $\boldsymbol U(t_i) = \{ \boldsymbol u(\boldsymbol x_0, t_i), \dots, \boldsymbol u (\boldsymbol x_{N-1}, t_i) \}$ denotes the physical quantities. Considering that many physical phenomena involve diffusion processes, we assume the diffusion term in the PDE is known as a priori knowledge. Our goal is to develop a graph learning model with small training datasets capable of accurately predicting various spatiotemporal dynamics on coarse unstructured meshes, handling different types of BCs, and producing the trajectory of dynamics for an arbitrarily given IC.
+
+## Model Architecture
+
+![PhyMPGN](images/phympgn.png)
+
+For Equation (1), a second-order Runge-Kutta (RK2) scheme can be used for discretization:
+
+$$
+\begin{equation}
+\boldsymbol u^{k+1} = \boldsymbol u^k + \frac{\delta t}{2}(\boldsymbol g_1 + \boldsymbol g_2); \quad \boldsymbol g_1 = \boldsymbol F(t^k, \boldsymbol x, \boldsymbol u^k, \dots); \quad \boldsymbol g_2 = \boldsymbol F(t^{k+1}, \boldsymbol x, \boldsymbol u^k + \delta t \boldsymbol g_1, \dots)
+\tag{2}
+\end{equation}
+$$
+
+where $\boldsymbol u^k$ is the state variable at time $t^k$, and $\delta t$ denotes the time interval between $t^k$ and $t^{k+1}$. According to Equation (2), we develop a GNN to learn the nonlinear operator $\boldsymbol F$.
+
+As shown in the figure, the NN block aims to learn the nonlinear operator $\boldsymbol F$ and consists of two parts: a GNN block that follows the Encode-Process-Decode architecture and a learnable Laplace block. Due to the universality of diffusion processes in physical phenomena, we design the learnable Laplace block, which encodes the discrete Laplace-Beltrami operator, to learn the increment caused by the diffusion term in the PDE, while the GNN block is responsible for learning the increment induced by other unknown mechanisms or sources.
+
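+A minimal sketch of this marching scheme, with a generic callable `f_theta(t, u)` standing in for the learned operator $\boldsymbol F$ (the names and signatures here are illustrative, not the actual PhyMPGN API):
+
+```python
+import numpy as np
+
+
+def rk2_step(f_theta, t, u, dt):
+    """One explicit RK2 (Heun) step: u_next = u + dt/2 * (g1 + g2)."""
+    g1 = f_theta(t, u)                  # slope at the current state
+    g2 = f_theta(t + dt, u + dt * g1)   # slope at the Euler-predicted state
+    return u + 0.5 * dt * (g1 + g2)
+
+
+def rollout(f_theta, u0, dt, n_steps):
+    """Autoregressively unroll a trajectory from an initial state u0."""
+    traj = [u0]
+    for k in range(n_steps):
+        traj.append(rk2_step(f_theta, k * dt, traj[-1], dt))
+    return np.stack(traj)  # (n_steps + 1, ...) array of states
+```
+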
+## Requirements
+
+- python 3.11
+- mindspore 2.5.0
+- numpy 1.26
+
+## Dataset
+
+This [dataset](https://download-mindspore.osinfra.cn/mindscience/mindflow/dataset/applications/data_mechanism_fusion/PhyMPGN/) contains simulation data for cylinder flow, stored in HDF5 format, including geometric structures, fluid properties, and flow dynamics information. After downloading, please save it in the `data/2d_cf` directory. The dataset is divided into training and test sets:
+
+- **Training set** : `train_cf_4x2000x1598x2.h5` contains 4 trajectories.
+- **Test set** : `test_cf_9x2000x1598x2.h5` contains 9 trajectories.
+
+### Data Format
+
+Each HDF5 file contains the following attributes and datasets:
+
+- `f.attrs['x_c']`, `f.attrs['y_c']`: **Float**, coordinates of the cylinder center.
+- `f.attrs['r']`: **Float**, radius of the cylinder.
+- `f.attrs['x_l']`, `f.attrs['x_r']`, `f.attrs['y_b']`, `f.attrs['y_t']`: **Float**, boundaries of the computational domain.
+- `f.attrs['mu']`: **Float**, fluid viscosity.
+- `f.attrs['rho']`: **Float**, fluid density.
+- `f['pos']`: **(n, 2)**, positions of the observation nodes.
+- `f['mesh']`: **(n_tri, 3)**, triangular mesh of the observation nodes.
+- `g = f['node_type']`: Node type information.
+ - `g['inlet']`: **(n_inlet,)**, indices of inlet boundary nodes.
+ - `g['cylinder']`: **(n_cylinder,)**, indices of cylinder boundary nodes.
+ - `g['outlet']`: **(n_outlet,)**, indices of outlet boundary nodes.
+ - `g['inner']`: **(n_inner,)**, indices of inner domain nodes.
+- `g = f[i]`: The i-th trajectory.
+ - `g['U']`: **(t, n, 2)**, velocity states.
+ - `g['dt']`: **Float**, time interval between steps.
+ - `g['u_m']`: **Float**, inlet velocity.
+
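+For example, a trajectory can be inspected with `h5py` as follows (a minimal sketch, assuming the training file has been downloaded to `data/2d_cf`):
+
+```python
+import h5py
+
+with h5py.File('data/2d_cf/train_cf_4x2000x1598x2.h5', 'r') as f:
+    print('cylinder center:', f.attrs['x_c'], f.attrs['y_c'], 'radius:', f.attrs['r'])
+    print('viscosity mu:', f.attrs['mu'], 'density rho:', f.attrs['rho'])
+    print('node positions:', f['pos'].shape)          # (n, 2)
+    print('triangular mesh:', f['mesh'].shape)        # (n_tri, 3)
+    print('inlet node indices:', f['node_type']['inlet'].shape)
+    g = f['0']                                        # the first trajectory
+    print('velocity states:', g['U'].shape)           # (t, n, 2)
+    print('dt:', g.attrs['dt'], 'inlet velocity u_m:', g.attrs['u_m'])
+```
+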
+## Usage
+
+The `yamls/train.yaml` file serves as the configuration file for the project, including settings for dataset size, model parameters, and paths for logging and weight saving.
+
+**Training**
+
+```shell
+python main.py --config_file_path yamls/train.yaml --train
+```
+
+**Testing**
+
+```shell
+python main.py --config_file_path yamls/train.yaml
+```
+
+## Visualization of results
+
+$Re=480$
+
+![Cylinder flow at Re=480](images/cf.png)
+
+## Performance
+
+| Parameter | Ascend |
+| ------------------------------------- | ------------------------------ |
+| Hardware Resources | NPU 32G |
+| Framework Version | MindSpore 2.5.0 |
+| Dataset | Cylinder flow |
+| Model Parameters | 950k |
+| Training Configuration | batch_size=4, epochs=1600 |
+| Training Loss (MSE) | |
+| Inference Loss (MSE) | |
+| Training Speed (s / epoch) | 420 s |
+| Inference Speed (s / trajectory) | 174 s |
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/README_CN.md b/MindFlow/applications/data_mechanism_fusion/phympgn/README_CN.md
new file mode 100644
index 0000000000000000000000000000000000000000..55ae876f7264ef8b1710886c3453d0b57c5b4276
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/README_CN.md
@@ -0,0 +1,111 @@
+# 用于时空PDE系统的物理编码消息传递图神经网络 (PhyMPGN: Physics-encoded Message Passing Graph Network for spatiotemporal PDE systems)
+
+偏微分方程(PDEs)控制的复杂动力系统广泛存在于各个学科当中。近年来,数据驱动的神经网络模型在预测时空动态上取得了极好的效果。
+
+物理编码的消息传递图网络(PhyMPGN),可以使用少量训练数据在不规则计算域上建模时空PDE系统。具体来说,
+
+- 提出了一个使用消息传递机制的物理编码图学习模型,使用二阶龙格库塔(Runge-Kutta)数值方案进行时间步进
+- 考虑到物理现象中普遍存在扩散过程,设计了一个可学习的Laplace Block,编码了离散拉普拉斯-贝尔特拉米算子(Laplace-Beltrami Operator)
+- 提出了一个新颖的填充策略在模型中编码不同类型的边界条件
+
+论文链接: [https://arxiv.org/abs/2410.01337](https://gitee.com/link?target=https%3A%2F%2Farxiv.org%2Fabs%2F2410.01337)
+
+## 问题描述
+
+考虑由如下形式控制的时空PDE系统:
+
+$$
+\begin{equation}
+\dot {\boldsymbol {u}}(\boldsymbol x, t) = \boldsymbol F (t, \boldsymbol x, \boldsymbol u, \nabla \boldsymbol u, \Delta \boldsymbol u, \dots)
+\tag{1}
+\end{equation}
+$$
+
+其中$\boldsymbol u(\boldsymbol x, t) \in \mathbb{R}^m$是具有$m$个分量的状态变量向量,例如速度、温度或者压力等,定义在时空域$\{ \boldsymbol x, t \} \in \Omega \times [0, \mathcal{T}]$上;$\dot{\boldsymbol u}$代表$\boldsymbol u$对时间的导数,$\boldsymbol F$是依赖于当前状态$\boldsymbol u$和其空间导数的非线性算子。
+
+假设在空间域$\Omega$上有着非均匀且稀疏的观测结点$\{ \boldsymbol x_0, \dots, \boldsymbol x_{N-1} \}$(即,非结构化网格),在时刻$t_0, \dots, t_{T-1}$,这些结点上的观测为$\{ \boldsymbol U(t_0), \dots, \boldsymbol U(t_{T-1}) \}$,其中的$\boldsymbol U(t_i) = \{ \boldsymbol u(\boldsymbol x_0, t_i), \dots, \boldsymbol u (\boldsymbol x_{N-1}, t_i) \}$代表某些物理量。考虑到很多物理现象包含扩散过程,我们假设PDE中的扩散项是已知的先验知识。我们的目标是使用少量训练数据学习一个图神经网络模型,在稀疏非结构网格上预测不同的时空动态系统,处理不同的边界条件,为任意的初始条件产生后续动态轨迹。
+
+## 模型
+
+![PhyMPGN](images/phympgn.png)
+
+对于式(1),可以使用二阶龙格库塔(Runge-Kutta, RK2)方案进行离散:
+
+$$
+\begin{equation}
+\boldsymbol u^{k+1} = \boldsymbol u^k + \frac{\delta t}{2}(\boldsymbol g_1 + \boldsymbol g_2); \quad \boldsymbol g_1 = \boldsymbol F(t^k, \boldsymbol x, \boldsymbol u^k, \dots); \quad \boldsymbol g_2 = \boldsymbol F(t^{k+1}, \boldsymbol x, \boldsymbol u^k + \delta t \boldsymbol g_1, \dots)
+\tag{2}
+\end{equation}
+$$
+
+其中$\boldsymbol u^k$为$t^k$时刻的状态变量,$\delta t$为时刻$t^k$和$t^{k+1}$之间的时间间隔。根据式(2),我们构建一个GNN来学习非线性算子$\boldsymbol F$.
+
+如图所示,我们使用NN block来学习非线性算子$\boldsymbol F$。NN block又可以分为两部分:采用编码器-处理器-解码器架构的GNN block和可学习的Laplace block。因为物理现象中扩散过程的普遍存在性,我们设计了可学习的Laplace block,编码离散拉普拉斯贝尔特拉米算子(Laplace-Beltrami operator),来学习由PDE中扩散项导致的增量;而GNN block来学习PDE中其他项导致的增量。
+
+## 相关依赖库
+
+- python 3.11
+- mindspore 2.5.0
+- numpy 1.26
+
+## 数据集
+
+该[数据集](https://download-mindspore.osinfra.cn/mindscience/mindflow/dataset/applications/data_mechanism_fusion/PhyMPGN/)包含圆柱绕流的模拟数据,以HDF5格式存储,包括几何结构、流体属性和流动动力学信息,下载后请保存在 `data/2d_cf`目录下。数据集分为训练集和测试集:
+
+- **训练集**:`train_cf_4x2000x1598x2.h5` 包含4条轨迹。
+- **测试集**:`test_cf_9x2000x1598x2.h5` 包含9条轨迹。
+
+### 数据格式
+
+每个HDF5文件包含以下属性和数据集:
+
+- `f.attrs['x_c'], f.attrs['y_c']`:**浮点数**,圆柱中心的坐标。
+- `f.attrs['r']`:**浮点数**,圆柱的半径。
+- `f.attrs['x_l'], f.attrs['x_r'], f.attrs['y_b'], f.attrs['y_t']`:**浮点数**,计算域的边界。
+- `f.attrs['mu']`:**浮点数**,流体粘度。
+- `f.attrs['rho']`:**浮点数**,流体密度。
+- `f['pos']`:**(n, 2)**,观测节点的位置。
+- `f['mesh']`:**(n_tri, 3)**,观测节点的三角网格。
+- `g = f['node_type']`:节点类型信息。
+ - `g['inlet']`:**(n_inlet,)**,入口边界节点的索引。
+ - `g['cylinder']`:**(n_cylinder,)**,圆柱边界节点的索引。
+ - `g['outlet']`:**(n_outlet,)**,出口边界节点的索引。
+ - `g['inner']`:**(n_inner,)**,域内节点的索引。
+- `g = f[i]`:第i条轨迹。
+ - `g['U']`:**(t, n, 2)**,速度状态。
+ - `g['dt']`:**浮点数**,时间步长之间的间隔。
+ - `g['u_m']`:**浮点数**,入口速度。
+
+## 使用方法
+
+`yamls/train.yaml`是项目的配置文件,包括数据集的大小、模型参数和日志、权重保存路径等设置。
+
+**训练**
+
+```shell
+python main.py --config_file_path yamls/train.yaml --train
+```
+
+**测试**
+
+```shell
+python main.py --config_file_path yamls/train.yaml
+```
+
+## 结果展示
+
+$Re=480$
+
+![Re=480 圆柱绕流结果](images/cf.png)
+
+## 性能
+
+| 参数 | Ascend |
+| ------------------------------ | ------------------------------ |
+| 硬件资源 | NPU 显存32G |
+| 版本 | MindSpore 2.5.0 |
+| 数据集 | Cylinder flow |
+| 参数量 | 950k |
+| 训练参数 | batch_size=4, epochs=1600 |
+| 训练损失 (MSE) | |
+| 推理损失 (MSE) | |
+| 训练速度 (s / epoch) | 420 s |
+| 推理速度 (s / trajectory) | 174 s |
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/images/cf.png b/MindFlow/applications/data_mechanism_fusion/phympgn/images/cf.png
new file mode 100644
index 0000000000000000000000000000000000000000..d074786ef28d30d13f6c70ad46b7a47cc179772b
Binary files /dev/null and b/MindFlow/applications/data_mechanism_fusion/phympgn/images/cf.png differ
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/images/phympgn.png b/MindFlow/applications/data_mechanism_fusion/phympgn/images/phympgn.png
new file mode 100644
index 0000000000000000000000000000000000000000..c14bc1891ba6652070fad1ac7d58d9aa28e39e98
Binary files /dev/null and b/MindFlow/applications/data_mechanism_fusion/phympgn/images/phympgn.png differ
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/main.py b/MindFlow/applications/data_mechanism_fusion/phympgn/main.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ba7306c318f3fe4147150e395a4724498c6f6d0
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/main.py
@@ -0,0 +1,159 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""main"""
+import argparse
+
+import numpy as np
+from easydict import EasyDict
+from mindspore import nn
+import mindspore as ms
+from mindflow import get_multi_step_lr
+from mindflow.utils import log_config, load_yaml_config, print_log
+from src import PDECFDataset, Trainer, get_data_loader, PhyMPGN, TwoStepLoss
+
+
+ms.set_device(device_target='Ascend', device_id=7)
+
+
+def parse_args():
+ """parse input args"""
+ parser = argparse.ArgumentParser(description="cylinder flow train")
+ parser.add_argument("--config_file_path", type=str,
+ default="./yamls/train.yaml")
+ parser.add_argument('--train', action='store_true',
+ default=False)
+ input_args = parser.parse_args()
+ return input_args
+
+
+def load_config():
+ """load config"""
+ args = parse_args()
+ config = load_yaml_config(args.config_file_path)
+ config['train'] = args.train
+ config = EasyDict(config)
+ if args.train:
+ log_config('./logs', f'phympgn-{config.experiment_name}')
+ else:
+ log_config('./logs', f'phympgn-{config.experiment_name}-te')
+ print_log(config)
+ return config
+
+
+def train(config):
+ """train"""
+ print_log('Train...')
+ print_log('Loading training data...')
+ tr_dataset = PDECFDataset(
+ root=config.path.data_root_dir,
+ raw_files=config.path.tr_raw_data,
+ dataset_start=config.data.dataset_start,
+ dataset_used=config.data.dataset_used,
+ time_start=config.data.time_start,
+ time_used=config.data.time_used,
+ window_size=config.data.tr_window_size,
+ training=True
+ )
+ tr_loader = get_data_loader(
+ dataset=tr_dataset,
+ batch_size=config.optim.batch_size
+ )
+
+ print_log('Loading validation data...')
+ val_dataset = PDECFDataset(
+ root=config.path.data_root_dir,
+ raw_files=config.path.val_raw_data,
+ dataset_start=config.data.dataset_start,
+ dataset_used=config.data.dataset_used,
+ time_start=config.data.time_start,
+ time_used=config.data.time_used,
+ window_size=config.data.val_window_size
+ )
+ val_loader = get_data_loader(
+ dataset=val_dataset,
+ batch_size=config.optim.batch_size
+ )
+
+ print_log('Building model...')
+ model = PhyMPGN(
+ encoder_config=config.network.encoder_config,
+ mpnn_block_config=config.network.mpnn_block_config,
+ decoder_config=config.network.decoder_config,
+ laplace_block_config=config.network.laplace_block_config,
+ integral=config.network.integral
+ )
+ print_log(f'Number of parameters: {model.num_params}')
+ lr_scheduler = get_multi_step_lr(
+ lr_init=config.optim.lr,
+ milestones=list(np.arange(0, config.optim.start_epoch+config.optim.epochs,
+ step=config.optim.steplr_size)[1:]),
+ gamma=config.optim.steplr_gamma,
+ steps_per_epoch=len(tr_loader),
+ last_epoch=config.optim.start_epoch+config.optim.epochs-1
+ )
+ optimizer = nn.AdamWeightDecay(model.trainable_params(), learning_rate=lr_scheduler,
+ eps=1.0e-8, weight_decay=1.0e-2)
+ trainer = Trainer(
+ model=model, optimizer=optimizer, scheduler=lr_scheduler, config=config,
+ loss_func=TwoStepLoss()
+ )
+ trainer.train(tr_loader, val_loader)
+
+
+def test(config):
+ """test"""
+ te_dataset = PDECFDataset(
+ root=config.path.data_root_dir,
+ raw_files=config.path.te_raw_data,
+ dataset_start=config.data.te_dataset_start,
+ dataset_used=config.data.te_dataset_used,
+ time_start=config.data.time_start,
+ time_used=config.data.time_used,
+ window_size=config.data.te_window_size,
+ training=False
+ )
+ te_loader = get_data_loader(
+ dataset=te_dataset,
+ batch_size=1,
+ shuffle=False,
+ )
+
+ print_log('Building model...')
+ model = PhyMPGN(
+ encoder_config=config.network.encoder_config,
+ mpnn_block_config=config.network.mpnn_block_config,
+ decoder_config=config.network.decoder_config,
+ laplace_block_config=config.network.laplace_block_config,
+ integral=config.network.integral
+ )
+ print_log(f'Number of parameters: {model.num_params}')
+ trainer = Trainer(
+ model=model, optimizer=None, scheduler=None, config=config,
+ loss_func=nn.MSELoss()
+ )
+ print_log('Test...')
+ trainer.test(te_loader)
+
+
+def main():
+ config = load_config()
+ if config.train:
+ train(config)
+ else:
+ test(config)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/__init__.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..22401410f6ed88ba902e1f92d2f492f30d0e63f8
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/__init__.py
@@ -0,0 +1,20 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""init"""
+from .datasets.dataset import PDECFDataset
+from .trainers.trainer import Trainer
+from .loaders.data_loader import get_data_loader
+from .models.phympgn import PhyMPGN
+from .models.loss import TwoStepLoss
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/__init__.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82ad829162a4f515c7bc37784d5b44d01449ff83
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""init"""
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/data.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c76b1f866813f572d8ba167a01f23eaf9f23ab18
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/data.py
@@ -0,0 +1,45 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""data"""
+from mindspore import ops
+
+
+class Graph:
+ """Graph"""
+ def __init__(self, **kwargs):
+ self.pos = None
+ self.edge_index = None
+ self.edge_attr = None
+ for key, value in kwargs.items():
+ setattr(self, key, value)
+
+ def detach(self):
+ new_graph = Graph()
+ for attr, value in self.__dict__.items():
+ new_graph.__setattr__(str(attr), ops.stop_gradient(value))
+ return new_graph
+
+ def __repr__(self) -> str:
+ cls = self.__class__.__name__
+
+ infos = []
+ for attr, value in self.__dict__.items():
+ if value is None:
+ continue
+ out = str(list(value.shape))
+ key = str(attr)
+ infos.append(f'{key}={out}')
+ infos = ', '.join(infos)
+ return f'{cls}({infos})'
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/dataset.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..e47df9c9439f870e042d4989f708303edd3e0210
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/dataset.py
@@ -0,0 +1,186 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""dataset"""
+import os.path as osp
+
+import h5py
+from tqdm import tqdm
+import numpy as np
+import mindspore as ms
+from mindspore import Tensor
+import mindspore.ops as ops
+
+from .transform import Compose, Distance, Cartesian, Dirichlet, \
+ DirichletInlet, NodeTypeInfo, MaskFace, Delaunay, FaceToEdge
+from .data import Graph
+from .utils import add_noise
+from ..utils.voronoi_laplace import compute_discrete_laplace
+from ..utils.padding import graph_padding
+
+
+class PDECFDataset:
+ """PDECFDataset"""
+ def __init__(self, root, raw_files, dataset_start,
+ dataset_used, time_start, time_used, window_size,
+ training=False):
+ self.raw_files = raw_files
+ self.laplace_file = 'laplace.npy'
+ self.d_file = 'd_vector.npy'
+ self.root = root
+ self.training = training
+
+ self.dataset_start = dataset_start
+ self.dataset_used = dataset_used
+ self.time_start = time_start
+ self.time_used = time_used
+ self.window_size = window_size
+
+ self.set_transform()
+ self.data_list = self.process()
+
+ def __getitem__(self, index):
+ graph = self.data_list[index]
+ return graph.pos, graph.y, graph.edge_index, graph.edge_attr, graph.dt, \
+ graph.mu, graph.r, graph.rho, graph.L, graph.d, graph.u_m, \
+ graph.dirichlet_index, graph.inlet_index, graph.dirichlet_value, \
+ graph.inlet_value, graph.node_type, graph.truth_index
+
+ def __len__(self):
+ return len(self.data_list)
+
+ def set_transform(self):
+ """set transform"""
+ self.periodic_trans = None
+ self.dirichlet_trans = Dirichlet()
+ self.inlet_trans = DirichletInlet()
+ self.neumann_trans = None
+ self.node_type_trans = NodeTypeInfo()
+ self.mask_face_trans = MaskFace()
+ self.transform = [
+ Delaunay(),
+ self.mask_face_trans,
+ FaceToEdge(remove_faces=False),
+ Distance(norm=True),
+ Cartesian(norm=True),
+ ]
+ if self.dirichlet_trans is not None:
+ self.transform.append(self.dirichlet_trans)
+ if self.dirichlet_trans is not None:
+ self.transform.append(self.inlet_trans)
+ if self.periodic_trans is not None:
+ self.transform.append(self.periodic_trans)
+ if self.neumann_trans is not None:
+ self.transform.append(self.neumann_trans)
+ self.transform.append(self.node_type_trans)
+ self.transform = Compose(transforms=self.transform)
+
+ def process(self):
+ """process data"""
+ data_list = []
+ file_handler = h5py.File(osp.join(self.root, self.raw_files))
+ coarse_pos = file_handler['pos'][:] # (n, 2)
+ r = file_handler.attrs['r']
+ mu = file_handler.attrs['mu']
+ rho = file_handler.attrs['rho']
+ node_type = file_handler['node_type']
+ inlet_index, cylinder_index = node_type['inlet'][:], node_type['cylinder'][:]
+ self.dirichlet_trans.set_index(cylinder_index)
+ self.inlet_trans.set_index(inlet_index)
+ self.node_type_trans.set_type_dict(node_type)
+ self.mask_face_trans.set_cylinder_index(cylinder_index)
+ for i in tqdm(range(self.dataset_start, self.dataset_used)):
+ # (t, n_f, d)
+ g = file_handler[str(i)]
+ u = g['U'][:]
+ dt = g.attrs['dt']
+ u_m = g.attrs['u_m']
+
+ # dimensionless
+ u = u / u_m
+ pos = coarse_pos / (2 * r)
+ dt = dt / (2 * r / u_m)
+
+ # to tensor
+ u_t = Tensor(u, dtype=ms.float32) # (t, n, d)
+ pos_t = Tensor(pos, dtype=ms.float32)
+ # (n,)
+ truth_index = Tensor(ms.numpy.arange(pos.shape[0]), dtype=ms.int64)
+ # (n, 1)
+ u_m_t = ops.ones((pos.shape[0], 1), dtype=ms.float32) * u_m
+ dt_t = ops.ones((pos.shape[0], 1), dtype=ms.float32) * dt
+ r_t = ops.ones((pos.shape[0], 1), dtype=ms.float32) * r
+ mu_t = ops.ones((pos.shape[0], 1), dtype=ms.float32) * mu
+ rho_t = ops.ones((pos.shape[0], 1), dtype=ms.float32) * rho
+
+ for idx in ms.numpy.arange(self.time_start,
+ self.time_start + self.time_used,
+ step=self.window_size):
+ # [t, n, c] -> [n, t, c]
+ if idx + self.window_size > self.time_start + self.time_used:
+ break
+ y = u_t[idx:idx + self.window_size].permute(1, 0, 2)
+ if self.training:
+ y[:, 0, :] = add_noise(y[:, 0, :], percentage=0.03)
+
+ graph = Graph(pos=pos_t, y=y,
+ truth_index=truth_index,
+ dt=dt_t, u_m=u_m_t,
+ r=r_t, mu=mu_t,
+ rho=rho_t)
+ graph = self.transform(graph)
+ data_list.append(graph)
+
+ if osp.exists(osp.join(self.root, self.laplace_file)):
+ laplace_matrix_np = np.load(osp.join(self.root, self.laplace_file))
+ d_vector_np = np.load(osp.join(self.root, self.d_file))
+ laplace_matrix = ms.Tensor(laplace_matrix_np, dtype=ms.float32)
+ d_vector = ms.Tensor(d_vector_np, dtype=ms.float32)
+ else:
+ laplace_matrix_np, d_vector_np = compute_discrete_laplace(
+ pos=data_list[0].pos.numpy(),
+ edge_index=data_list[0].edge_index.numpy(),
+ face=data_list[0].face.numpy()
+ )
+ d_vector_np = d_vector_np[:, None]
+ np.save(osp.join(self.root, self.laplace_file), laplace_matrix_np)
+ np.save(osp.join(self.root, self.d_file), d_vector_np)
+ laplace_matrix = ms.Tensor(laplace_matrix_np, dtype=ms.float32)
+ d_vector = ms.Tensor(d_vector_np, dtype=ms.float32)
+
+ for data in data_list:
+ data.L = laplace_matrix
+ data.d = d_vector
+ data.dirichlet_value = ops.zeros((data.dirichlet_index.shape[0],
+ data.y.shape[2]))
+ data.inlet_value = self.inlet_velocity(
+ data.inlet_index, 1.)
+ graph_padding(data, clone=True)
+
+ return data_list
+
+ @staticmethod
+ def inlet_velocity(inlet_index, u_m):
+ u = u_m * ops.ones(inlet_index.shape[0])
+ v = ops.zeros_like(u)
+
+ return ops.stack((u, v), axis=-1) # (m, 2)
+
+ @staticmethod
+ def dimensional(u_pred, u_gt, pos, u_m, d):
+ u_pred = u_pred * u_m
+ u_gt = u_gt * u_m
+ pos = pos * d
+
+ return u_pred, u_gt, pos
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/transform.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/transform.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd30b64f757c8d97681fef2c24ec542dba9db720
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/transform.py
@@ -0,0 +1,229 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""transform"""
+import scipy.spatial
+from mindspore import ops, Tensor
+import mindspore as ms
+
+from .utils import NodeType
+from .utils import to_undirected
+
+
+class BaseTransform:
+ def __call__(self, data):
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ return f'{self.__class__.__name__}()'
+
+
+class Compose(BaseTransform):
+ """Composes several transforms together."""
+ def __init__(self, transforms):
+ self.transforms = transforms
+
+ def __call__(self, data):
+ for transform in self.transforms:
+ data = transform(data)
+ return data
+
+ def __repr__(self) -> str:
+ args = [f' {transform}' for transform in self.transforms]
+ return '{}([\n{}\n])'.format(self.__class__.__name__, ',\n'.join(args))
+
+
+class Cartesian(BaseTransform):
+ """Saves the relative Cartesian coordinates of linked nodes in its edge
+ attributes."""
+ def __init__(self, norm: bool = False):
+ self.norm = norm
+
+ def __call__(self, data):
+ (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
+
+ cart = pos[col] - pos[row]
+ cart = cart.view(-1, 1) if cart.dim() == 1 else cart
+ data.rel_pos = cart
+
+ if self.norm and cart.numel() > 0:
+ max_value = cart.abs().max()
+ cart = cart / (2 * max_value) + 0.5
+
+ if pseudo is not None:
+ pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
+ data.edge_attr = ops.cat([pseudo, cart.type_as(pseudo)], axis=-1)
+ else:
+ data.edge_attr = cart
+
+ return data
+
+ def __repr__(self) -> str:
+ return f'{self.__class__.__name__}(norm={self.norm})'
+
+
+class Distance(BaseTransform):
+ """Saves the Euclidean distance of linked nodes in its edge attributes."""
+ def __init__(self, norm: bool = False):
+ self.norm = norm
+
+ def __call__(self, data):
+ (row, col), pos, pseudo = data.edge_index, data.pos, data.edge_attr
+
+ dist = ms.numpy.norm(pos[col] - pos[row], axis=-1).view(-1, 1)
+ data.distance = dist
+
+ if self.norm and dist.numel() > 0:
+ dist = dist / dist.max()
+
+ if pseudo is not None:
+ pseudo = pseudo.view(-1, 1) if pseudo.dim() == 1 else pseudo
+ data.edge_attr = ops.cat([pseudo, dist.type_as(pseudo)], axis=-1)
+ else:
+ data.edge_attr = dist
+
+ return data
+
+ def __repr__(self) -> str:
+ return f'{self.__class__.__name__}(norm={self.norm})'
+
+
+class Delaunay(BaseTransform):
+ """Delaunay transform"""
+ def __call__(self, data):
+ if data.pos.shape[0] < 2:
+ data.edge_index = Tensor([], dtype=ms.int64).view(2, 0)
+ if data.pos.shape[0] == 2:
+ data.edge_index = Tensor([[0, 1], [1, 0]], dtype=ms.int64)
+ elif data.pos.shape[0] == 3:
+ data.face = Tensor([[0], [1], [2]], dtype=ms.int64)
+ if data.pos.shape[0] > 3:
+ pos = data.pos.numpy()
+ tri = scipy.spatial.Delaunay(pos, qhull_options='QJ')
+ face = ms.from_numpy(tri.simplices)
+
+ data.face = face.t().contiguous().to(ms.int64)
+
+ return data
+
+
+class FaceToEdge(BaseTransform):
+ """FaceToEdge transform"""
+ def __init__(self, remove_faces: bool = True):
+ self.remove_faces = remove_faces
+
+ def __call__(self, data):
+ if hasattr(data, 'face'):
+ face = data.face
+ edge_index = ops.cat([face[:2], face[1:], face[::2]], axis=1)
+ edge_index = to_undirected(edge_index, num_nodes=data.pos.shape[0])
+
+ data.edge_index = edge_index
+ if self.remove_faces:
+ data.face = None
+
+ return data
+
+
+class Dirichlet(BaseTransform):
+ """Dirichlet transform"""
+ def __init__(self):
+ self.index = None
+
+ def set_index(self, index):
+ self.index = Tensor(index, dtype=ms.int64)
+
+ def __call__(self, data):
+ data.dirichlet_index = self.index
+ return data
+
+
+class DirichletInlet(BaseTransform):
+ """DirichletInlet transform"""
+ def __init__(self):
+ self.index = None
+
+ def set_index(self, index):
+ self.index = Tensor(index, dtype=ms.int64)
+
+ def __call__(self, data):
+ data.inlet_index = self.index
+ return data
+
+
+class MaskFace(BaseTransform):
+ """MaskFace transform"""
+ def __init__(self):
+ self.cylinder_index = None
+ self.new_face_index = None
+
+ def is_none(self):
+ return self.new_face_index is None
+
+ def set_cylinder_index(self, cylinder_index):
+ self.cylinder_index = Tensor(cylinder_index, dtype=ms.int64)
+
+ def __call__(self, data):
+ if self.is_none():
+ self.new_face_index = self.cal_mask_face(data)
+
+ data.face = data.face[:, self.new_face_index]
+ return data
+
+ def cal_mask_face(self, graph):
+ on_circle_index = self.cylinder_index
+ new_face_index = []
+ for i in range(graph.face.shape[1]):
+ if ms.numpy.isin(graph.face[:, i], on_circle_index).all():
+ continue
+ else:
+ new_face_index.append(i)
+ return Tensor(new_face_index)
+
+
+class NodeTypeInfo(BaseTransform):
+ """NodeTypeInfo transform"""
+ def __init__(self):
+ self.type_dict = None
+ self.node_type = None
+
+ def is_none(self):
+ return self.node_type is None
+
+ def set_type_dict(self, type_dict):
+ self.type_dict = type_dict
+
+ def __call__(self, data):
+ if self.is_none():
+ self.node_type = self.cal_node_type(data)
+
+ data.node_type = self.node_type
+ return data
+
+ def cal_node_type(self, data):
+ """compute node type"""
+ node_num = data.pos.shape[0]
+ node_type = ops.ones(node_num, dtype=ms.int64) * NodeType.NORMAL
+ if hasattr(data, 'dirichlet_index'):
+ node_type[data.dirichlet_index] = NodeType.OBSTACLE
+ if hasattr(data, 'inlet_index'):
+ node_type[data.inlet_index] = NodeType.INLET
+
+ outlet_index = self.type_dict['outlet'][:]
+ outlet_index = Tensor(outlet_index, dtype=ms.int64)
+ node_type[outlet_index] = NodeType.OUTLET
+ return node_type
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/utils.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f51d250ad33c5e6fff102358f348dea4ee3f4e32
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/datasets/utils.py
@@ -0,0 +1,74 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""datasets utils"""
+import enum
+
+import mindspore.numpy as mnp
+from mindspore import ops
+
+
+class NodeType(enum.IntEnum):
+ NORMAL = 0
+ OBSTACLE = 1
+ INLET = 2
+ OUTLET = 3
+
+
+def add_noise(truth, percentage=0.05):
+ """add noise"""
+ # shape of truth must be (n, 2)
+ assert truth.shape[1] == 2
+ uv = [truth[:, 0:1], truth[:, 1:2]]
+ uv_noi = []
+ for component in uv:
+ r = ops.normal(mean=0.0, stddev=1.0, shape=component.shape)
+ std_r = ops.std(r) # std of samples
+ std_t = ops.std(component)
+ noise = r * std_t / std_r * percentage
+ uv_noi.append(component + noise)
+ return ops.cat(uv_noi, axis=1)
+
+
+def to_undirected(edge_index, num_nodes):
+ """to undirected"""
+ row, col = edge_index[0], edge_index[1]
+ row, col = ops.cat([row, col], axis=0), ops.cat([col, row], axis=0)
+ edge_index = ops.stack([row, col], axis=0)
+
+ return coalesce(edge_index, num_nodes)
+
+
+def coalesce(edge_index, num_nodes, is_sorted=False, sort_by_row=True):
+ """coalesce"""
+ nnz = edge_index.shape[1]
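+ # encode each edge as a single integer key (row * num_nodes + col) so that
+ # sorting the keys places duplicate edges next to each other for removal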
+ idx = mnp.empty(nnz + 1, dtype=edge_index.dtype)
+ idx[0] = -1
+ idx[1:] = edge_index[1 - int(sort_by_row)]
+ idx[1:] = idx[1:].mul(num_nodes).add(edge_index[int(sort_by_row)])
+
+ if not is_sorted:
+ # idx[1:], perm = index_sort(idx[1:], max_value=num_nodes * num_nodes)
+ idx[1:], perm = idx[1:].sort()
+ edge_index = edge_index[:, perm]
+
+ mask = idx[1:] > idx[:-1]
+
+ # Only perform expensive merging in case there exists duplicates:
+ if mask.all():
+ return edge_index
+
+ edge_index = edge_index[:, mask]
+
+ return edge_index
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/loaders/__init__.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/loaders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82ad829162a4f515c7bc37784d5b44d01449ff83
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/loaders/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""init"""
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/loaders/data_loader.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/loaders/data_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..b55fc9c99184188bbe1d9780bc0a74a1710b97c3
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/loaders/data_loader.py
@@ -0,0 +1,155 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""data loader"""
+import numpy as np
+import mindspore.dataset as ds
+import mindspore.ops as ops
+from mindspore import Tensor
+import mindspore as ms
+
+from ..datasets.data import Graph
+from ..datasets.dataset import PDECFDataset
+
+
+def get_data_loader(dataset: PDECFDataset, batch_size: int, shuffle: bool = True) \
+ -> ds.GeneratorDataset:
+ """get data loader"""
+ column_names = ['pos', 'y', 'edge_index', 'edge_attr', 'dt', 'mu', 'r',
+ 'rho', 'L', 'd', 'u_m', 'dirichlet_index', 'inlet_index',
+ 'dirichlet_value', 'inlet_value', 'node_type',
+ 'truth_index']
+
+ if shuffle:
+ loader = ds.GeneratorDataset(
+ source=dataset,
+ column_names=column_names
+ )
+ loader = loader.shuffle(buffer_size=5000).batch(batch_size)
+ else:
+ loader = ds.GeneratorDataset(
+ source=dataset,
+ column_names=column_names,
+ shuffle=False
+ )
+ loader = loader.batch(batch_size)
+ return loader
+
+
+def batch_graph(data: dict):
+ """batch graph"""
+ pos, y, edge_index = data['pos'], data['y'], data['edge_index']
+ edge_attr = data['edge_attr']
+ dt, mu, r, rho, node_type, u_m = data['dt'], data['mu'], data['r'], data['rho'], \
+ data['node_type'], data['u_m']
+ l, d = data['L'], data['d']
+ dirichlet_index, inlet_index = data['dirichlet_index'], data['inlet_index']
+ dirichlet_value, inlet_value = data['dirichlet_value'], data['inlet_value']
+ truth_index = data['truth_index']
+
+ batch_size = pos.shape[0]
+ node_num = pos.shape[1]
+ edge_num = edge_index.shape[2]
+ m = y.shape[2]
+ # (b, n, p_d) -> (b*n, p_d)
+ pos_batch = pos.reshape(batch_size * node_num, pos.shape[2])
+ # (b, n, m, y_d) -> (b*n, m, y_d)
+ y_batch = y.reshape(batch_size * node_num, m, y.shape[3])
+ # (b, 2, e) -> (2, b_e)
+ edge_index_batch = batched_edge_index(edge_index, node_num)
+ # (b, e, e_d) -> (b*e, e_d)
+ edge_attr_batch = edge_attr.reshape(batch_size * edge_num,
+ edge_attr.shape[2])
+ # (bn,)
+ batch = np.concatenate([i * np.ones(node_num) for i in range(batch_size)])
+ batch = Tensor(batch, dtype=ms.int64)
+
+ # (b, n, 1) -> (b*n, 1)
+ dt_batch = dt.reshape(batch_size * node_num, dt.shape[2])
+ mu_batch = mu.reshape(batch_size * node_num, mu.shape[2])
+ r_batch = r.reshape(batch_size * node_num, r.shape[2])
+ rho_batch = rho.reshape(batch_size * node_num, rho.shape[2])
+ u_m_batch = u_m.reshape(batch_size * node_num, u_m.shape[2])
+
+ # (b, n, n) -> (b*n, b*n)
+ l_batch = ops.block_diag(*ops.unbind(l))
+ # (b, n, 1) -> (b*n, 1)
+ d_batch = d.reshape(batch_size * node_num, d.shape[2])
+
+ # (b, m) -> (b*m,)
+ node_type_batch = node_type.reshape(-1)
+ dirichlet_index_batch = batched_node_index(dirichlet_index, node_num)
+ inlet_index_batch = batched_node_index(inlet_index, node_num)
+ truth_index_batch = batched_node_index(truth_index, node_num)
+
+ # (b, m, 2) -> (b*m, 2)
+ dirichlet_value_batch = dirichlet_value.reshape(
+ batch_size * dirichlet_value.shape[1], dirichlet_value.shape[2])
+ inlet_value_batch = inlet_value.reshape(
+ batch_size * inlet_value.shape[1], inlet_value.shape[2])
+
+ return Graph(pos=pos_batch, y=y_batch, edge_index=edge_index_batch,
+ edge_attr=edge_attr_batch,
+ dt=dt_batch, mu=mu_batch, r=r_batch, rho=rho_batch,
+ u_m=u_m_batch,
+ node_type=node_type_batch, L=l_batch, d=d_batch,
+ dirichlet_index=dirichlet_index_batch,
+ inlet_index=inlet_index_batch,
+ dirichlet_value=dirichlet_value_batch,
+ inlet_value=inlet_value_batch,
+ truth_index=truth_index_batch,
+ batch=batch)
+
+
+def batched_edge_index(edge_index, node_num):
+ """batch edge index
+
+ Args:
+ edge_index (Tensor): Shape (b, 2, e)
+ node_num (int): Number of nodes in each graph
+
+ Returns:
+ edge_index_batch: batched edge index
+ """
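+ # offset the node indices of graph i by i * node_num so the batch becomes one disjoint-union graph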
+ add_index = np.concatenate(
+ [node_num*i*np.ones([1, edge_index.shape[1],
+ edge_index.shape[2]], np.int64)
+ for i in range(edge_index.shape[0])], axis=0)
+ if isinstance(edge_index, np.ndarray):
+ return add_index + edge_index
+ edge_index_batch = Tensor(add_index) + edge_index
+ batch_size, edge_num = edge_index_batch.shape[0], edge_index_batch.shape[2]
+ edge_index_batch = ops.permute(edge_index_batch, (1, 0, 2))\
+ .reshape(2, edge_num * batch_size)
+ return edge_index_batch
+
+
+def batched_node_index(node_index, node_num):
+ """batched index of nodes
+
+ Args:
+ node_index (Tensor): Shape (b, m)
+ node_num (int): Number of nodes in each graph
+
+ Returns:
+ node_index_batch: batched index of nodes.
+ """
+ add_index = np.concatenate(
+ [node_num*i*np.ones([1, node_index.shape[1]], np.int64)
+ for i in range(node_index.shape[0])], axis=0)
+ if isinstance(node_index, np.ndarray):
+ return add_index + node_index
+ node_index_batch = Tensor(add_index) + node_index
+ node_index_batch = node_index_batch.reshape(-1)
+ return node_index_batch
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/__init__.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82ad829162a4f515c7bc37784d5b44d01449ff83
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""init"""
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/encoder_decoder.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/encoder_decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..8833ff52bed9c3cca9d35079a0a78b00405d2e0e
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/encoder_decoder.py
@@ -0,0 +1,40 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""encoder and decoder module"""
+from mindspore import nn
+
+from .utils import build_net
+
+
+class Encoder(nn.Cell):
+ """Encoder"""
+ def __init__(self, layers):
+ super().__init__()
+ self.layers = layers
+ self.net = build_net(layers)
+
+ def construct(self, inputs):
+ return self.net(inputs)
+
+
+class Decoder(nn.Cell):
+ """Decoder"""
+ def __init__(self, layers):
+ super().__init__()
+ self.layers = layers
+ self.net = build_net(layers)
+
+ def construct(self, h):
+ return self.net(h)
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/laplace_block.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/laplace_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1af4b069abf372097a8a360affca05b0d1e2847
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/laplace_block.py
@@ -0,0 +1,84 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Laplace block"""
+from mindspore import nn, ops
+
+from .utils import activation_func
+from .mpnn_layer import MPNNLayer
+
+
+class LaplaceBlock(nn.Cell):
+ """Laplace block"""
+ def __init__(self, enc_dim, h_dim, out_dim):
+ super().__init__()
+ self.encoder = nn.SequentialCell(
+ nn.Dense(enc_dim, h_dim),
+ activation_func(),
+ nn.Dense(h_dim, h_dim)
+ )
+ self.processor = LaplaceProcessor(
+ mpnn_layers=[
+ [h_dim * 2 + 3, h_dim, h_dim],
+ [h_dim * 2, h_dim, h_dim]
+ ],
+ mpnn_num=3
+ )
+ self.decoder = nn.SequentialCell(
+ nn.Dense(h_dim, h_dim),
+ activation_func(),
+ nn.Dense(h_dim, out_dim)
+ )
+
+ def cal_mesh_laplace(self, graph):
+ laplace = graph.L @ graph.y
+ return laplace
+
+ def construct(self, graph):
+ h = self.encoder(ops.cat((graph.y, graph.pos), axis=-1))
+ edge_attr = graph.edge_attr[:, :3]
+ h = self.processor(h, edge_attr, graph.edge_index)
+ out = self.decoder(h)
+ out = graph.d * out
+
+ out = out + self.cal_mesh_laplace(graph)
+ return out
+
+ @property
+ def num_params(self):
+ params = self.trainable_params()
+ return sum(param.size for param in params)
+
+
+class LaplaceProcessor(nn.Cell):
+ """Laplace Processor"""
+ def __init__(self, mpnn_layers, mpnn_num):
+ super().__init__()
+ self.phi_layers = mpnn_layers[0]
+ self.gamma_layers = mpnn_layers[1]
+ self.mpnn_num = mpnn_num
+ self.nets = self.build_block()
+
+ def build_block(self):
+ nets = nn.CellList()
+ for _ in range(self.mpnn_num):
+ nets.append(MPNNLayer(self.phi_layers, self.gamma_layers))
+ return nets
+
+ def construct(self, h, edge_attr, edge_index):
+ for mpnn in self.nets:
+ h = h + mpnn(edge_index=edge_index,
+ node_features=h,
+ edge_features=edge_attr)
+ return h
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/loss.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..153f47a08cb0ba680156cde692430b9c90e63a6c
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/loss.py
@@ -0,0 +1,39 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""loss"""
+from mindspore import nn, ops
+
+
+class TwoStepLoss(nn.Cell):
+ """TwoStepLoss"""
+ def __init__(self):
+ super().__init__()
+ self.loss_func = nn.MSELoss()
+
+ def construct(self, u_pred, truth, mask=None):
+ """construct"""
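+ # supervise only the first predicted step and the final step of the roll-out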
+ pred1 = u_pred[1] # [bxn, 2]
+ predn = u_pred[-1]
+ new_pred = ops.stack((pred1, predn), axis=0)
+
+ truth1 = truth[1]
+ truthn = truth[-1]
+ new_truth = ops.stack((truth1, truthn), axis=0)
+
+ if mask is None:
+ loss = self.loss_func(new_pred, new_truth)
+ else:
+ loss = self.loss_func(new_pred[:, mask], new_truth[:, mask])
+ return loss
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/mpnn_block.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/mpnn_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..1edfc67f42a0e045d7ecbdb1c779ac1991d3df02
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/mpnn_block.py
@@ -0,0 +1,60 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""mpnn block"""
+from mindspore import nn
+
+from ..utils.padding import h_padding
+from .mpnn_layer import MPNNLayer
+
+
+class MPNNBlock(nn.Cell):
+ """MPNNBlock"""
+ def __init__(self, mpnn_layers, mpnn_num):
+ super().__init__()
+ self.phi_layers = mpnn_layers[0]
+ self.gamma_layers = mpnn_layers[1]
+ self.mpnn_num = mpnn_num
+ self.nets = self.build_block()
+
+ def build_block(self):
+ nets = nn.CellList()
+ for _ in range(self.mpnn_num):
+ nets.append(MPNNLayer(self.phi_layers, self.gamma_layers))
+ return nets
+
+ def construct(self, graph):
+ """construct"""
+ h = graph.state_node
+ for i in range(len(self.nets) - 1):
+ mpnn = self.nets[i]
+ h = h + mpnn(
+ edge_index=graph.edge_index,
+ node_features=h,
+ edge_features=graph.state_edge
+ )
+ # padding
+ h_padding(h, graph)
+
+ h = self.nets[-1](
+ edge_index=graph.edge_index,
+ node_features=h,
+ edge_features=graph.state_edge
+ )
+ return h # (b_n, node_h_dim)
+
+ @property
+ def num_params(self):
+ params = self.trainable_params()
+ return sum(param.size for param in params)
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/mpnn_layer.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/mpnn_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d7e3c024150556334693ff81c80b120a31b069d
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/mpnn_layer.py
@@ -0,0 +1,79 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""mpnn layer"""
+from mindspore import nn, ops
+
+from .utils import build_net
+
+
+class MPNNLayer(nn.Cell):
+ """MPNNLayer"""
+ def __init__(self, phi_layers, gamma_layers):
+ super().__init__()
+ self.phi = build_net(phi_layers)
+ self.gamma = build_net(gamma_layers)
+
+ def construct(self, edge_index, node_features, edge_features):
+ # (n_edge, h_dim)
+ m_ij = self.message(edge_index, node_features, edge_features)
+ # (n_node, h_dim)
+ aggr = self.aggregate(edge_index, m_ij, node_features.shape[0])
+ # (n_node, h_dim)
+ node_features_new = self.update(aggr, node_features)
+ return node_features_new
+
+ def message(self, edge_index, node_features, edge_features):
+ sender, receiver = edge_index[0], edge_index[1]
+ # (n_edge, node_h_dim * 2 + edge_h_dim)
+ phi_input = ops.cat([node_features[sender],
+ node_features[receiver] - node_features[sender],
+ edge_features], axis=1)
+ return self.phi(phi_input)
+
+ def aggregate(self, edge_index, messages, node_num):
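+ # mean-aggregate messages at each node; edges are stored in both directions
+ # (see to_undirected), so grouping by edge_index[0] covers all incident edges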
+ aggr = scatter_mean(messages, edge_index[0, :].reshape(-1, 1),
+ dim_size=node_num)
+ return aggr
+
+ def update(self, aggr, node_features):
+ # (n_node, node_h_dim + h_dim)
+ gamma_input = ops.cat([node_features, aggr], axis=1)
+ return self.gamma(gamma_input) # (n_node, gamma_out_dim)
+
+
+def scatter_sum(src, index, dim_size):
+ """scatter sum"""
+ assert len(index.shape) == 2
+ assert index.shape[-1] == 1
+ assert src.shape[0] == index.shape[0]
+ assert len(src.shape) == 2
+
+ tmp_node = ops.zeros((dim_size, src.shape[1]), dtype=src.dtype)
+ out = ops.tensor_scatter_add(tmp_node, index, src)
+ return out
+
+
+def scatter_mean(src, index, dim_size):
+ """scatter mean"""
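+ # per-node mean: sum the messages with tensor_scatter_add, then divide by the per-node edge count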
+ assert len(index.shape) == 2
+ assert index.shape[-1] == 1
+ assert src.shape[0] == index.shape[0]
+ assert len(src.shape) == 2
+
+ ones = ops.ones((index.shape[0], 1), dtype=src.dtype)
+ cnt = scatter_sum(ones, index, dim_size)
+ total = scatter_sum(src, index, dim_size)
+ out = total / cnt
+ return out
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/phympgn.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/phympgn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce06c7b8ad7263675281f39c17d947adf01e3430
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/phympgn.py
@@ -0,0 +1,169 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""phympgn"""
+from mindspore import nn, ops
+import mindspore as ms
+
+from .encoder_decoder import Encoder, Decoder
+from .mpnn_block import MPNNBlock
+from .laplace_block import LaplaceBlock
+from ..utils.padding import graph_padding
+
+class PhyMPGN(nn.Cell):
+ """PhyMPGN"""
+ def __init__(self, encoder_config, mpnn_block_config, decoder_config,
+ laplace_block_config, integral):
+ super().__init__()
+ self.node_encoder = Encoder(encoder_config['node_encoder_layers'])
+ self.edge_encoder = Encoder(encoder_config['edge_encoder_layers'])
+ self.mpnn_block = MPNNBlock(
+ mpnn_layers=mpnn_block_config['mpnn_layers'],
+ mpnn_num=mpnn_block_config['mpnn_num']
+ )
+ self.decoder = Decoder(decoder_config['node_decoder_layers'])
+ self.laplace_block = LaplaceBlock(
+ enc_dim=laplace_block_config['in_dim'],
+ h_dim=laplace_block_config['h_dim'],
+ out_dim=laplace_block_config['out_dim']
+ )
+
+ update_fn = {
+ 1: self.update_euler,
+ 2: self.update_rk2,
+ 4: self.update_rk4
+ }
+ self.update = update_fn[integral]
+
+ def construct(self, graph, steps):
+ """
+ Args:
+ graph (Graph): instance of Graph, containing edge_index, pos, y, etc.
+ steps (int): steps of roll-out
+
+ Returns:
+ loss_states (Tensor): predicted states
+ """
+ loss_states = [graph.y] # [bn, 2]
+ # unroll for 1 step
+ graph_next = self.update(graph)
+ loss_states.append(graph_next.y)
+
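+ # detach the state so later roll-out steps do not back-propagate through the first step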
+ graph = graph_next.detach()
+ # unroll for steps-1
+ for _ in range(steps - 1):
+ graph_next = self.update(graph)
+ loss_states.append(graph_next.y)
+ graph = graph_next
+
+ # [t, bn, 2]
+ loss_states = ops.stack(loss_states, axis=0)
+ return ops.index_select(loss_states, 1, graph.truth_index)
+
+ def get_temporal_diff(self, graph):
+ """compute the output of the nonlinear operator F"""
+ node_type = ops.one_hot(graph.node_type, graph.node_type.max() + 1)
+ node_type = node_type.astype(ms.float32)
+ graph.state_node = self.node_encoder(
+ ops.cat((graph.y, graph.pos, node_type), axis=-1))
+ # store dirichlet value
+ if hasattr(graph, 'dirichlet_index'):
+ graph.dirichlet_h_value = ops.index_select(
+ graph.state_node, 0, graph.dirichlet_index)
+ graph.inlet_h_value = ops.index_select(
+ graph.state_node, 0, graph.inlet_index)
+
+ rel_state = graph.y[graph.edge_index[1, :]] - \
+ graph.y[graph.edge_index[0, :]] # (b_e, 2)
+ # (b_e, 5) -> (b_e, h)
+ graph.state_edge = self.edge_encoder(
+ ops.cat((rel_state, graph.edge_attr), axis=-1))
+ mpnn_out = self.mpnn_block(graph) # (b_n, h)
+ decoder_out = self.decoder(mpnn_out) # (b_n, 2)
+
+ # laplace
+ laplace = self.laplace_block(graph) # (b_n, 2)
+
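+ # Reynolds number Re = rho * u_m * d / mu, with d = 2r the cylinder diameter;
+ # the diffusion term laplace / Re is added to the learned correction from the decoder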
+ u_m, rho, d, mu = graph.u_m, graph.rho, graph.r * 2, graph.mu
+ re = rho * d * u_m / mu # (b_n, 1)
+ # (b_n, 1) * (b_n, 2) + (b_n, 2) -> (b_n, 2)
+ out = 1 / re * laplace + decoder_out
+
+ return out
+
+ def update_euler(self, graph):
+ """euler scheme"""
+ out = self.get_temporal_diff(graph)
+ graph.y = graph.y + out * graph.dt
+ # padding
+ graph_padding(graph)
+
+ return graph
+
+ def update_rk2(self, graph):
+ """rk2 scheme"""
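+ # Heun's second-order scheme: u_{t+1} = u_t + dt * (k1 + k2) / 2,
+ # with boundary padding re-applied after each stage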
+ u0 = graph.y
+ k1 = self.get_temporal_diff(graph) # (bn, 2)
+ u1 = u0 + k1 * graph.dt # (bn, 2) + (bn, 2) * (bn, 1) -> (bn, 2)
+ graph.y = u1
+ # padding
+ graph_padding(graph)
+
+ k2 = self.get_temporal_diff(graph)
+ graph.y = u0 + k1 * graph.dt / 2 + k2 * graph.dt / 2
+ # padding
+ graph_padding(graph)
+
+ return graph
+
+ def update_rk4(self, graph):
+ """rk4 scheme"""
+ # stage 1
+ u0 = graph.y
+ k1 = self.get_temporal_diff(graph)
+
+ # stage 2
+ u1 = u0 + k1 * graph.dt / 2.
+ graph.y = u1
+ # padding
+ graph_padding(graph)
+ k2 = self.get_temporal_diff(graph)
+
+ # stage 3
+ u2 = u0 + k2 * graph.dt / 2.
+ graph.y = u2
+ # padding
+ graph_padding(graph)
+ k3 = self.get_temporal_diff(graph)
+
+ # stage 4
+ u3 = u0 + k3 * graph.dt
+ graph.y = u3
+ # padding
+ graph_padding(graph)
+ k4 = self.get_temporal_diff(graph)
+
+ u4 = u0 + (k1 + 2 * k2 + 2 * k3 + k4) * graph.dt / 6.
+ graph.y = u4
+ # padding
+ graph_padding(graph)
+
+ return graph
+
+ @property
+ def num_params(self):
+ total = sum(param.size for param in self.trainable_params())
+ mpnn = self.mpnn_block.num_params
+ laplace = self.laplace_block.num_params
+ return total, mpnn, laplace
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/utils.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b8f70ce0be67761082f1d58e2e02c31fc425b0e
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/models/utils.py
@@ -0,0 +1,56 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""model utils"""
+from collections import OrderedDict
+
+from mindspore import nn
+
+
+def activation_func():
+ """activation function"""
+ return nn.ELU()
+
+
+def build_net(layers, activation_end=False):
+ """build net"""
+ net = nn.SequentialCell()
+ layer_n = len(layers)
+
+ assert layer_n >= 2
+
+ for i in range(layer_n - 2):
+ net.append(nn.Dense(layers[i], layers[i + 1]))
+ net.append(activation_func())
+ net.append(nn.Dense(layers[layer_n - 2], layers[layer_n - 1]))
+ if activation_end:
+ net.append(activation_func())
+ return net
+
+
+def build_dict_net(layers, activation_end=False):
+ """build dict net"""
+ layer_n = len(layers)
+
+ assert layer_n >= 2
+ d = OrderedDict()
+ for i in range(layer_n - 2):
+ d['dense' + str(i)] = nn.Dense(layers[i], layers[i + 1])
+ d['activation' + str(i)] = activation_func()
+ d['dense' + str(layer_n - 2)] = nn.Dense(layers[layer_n - 2], layers[layer_n - 1])
+ if activation_end:
+ d['activation' + str(layer_n - 2)] = activation_func()
+
+ net = nn.SequentialCell(d)
+ return net
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/__init__.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..82ad829162a4f515c7bc37784d5b44d01449ff83
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""init"""
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/trainer.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..23a96ebf42c2af4fe71efd56e0be81eb5b00ee63
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/trainer.py
@@ -0,0 +1,254 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""trainer"""
+import os
+import os.path as osp
+import time
+
+import numpy as np
+import mindspore as ms
+from mindspore import ops
+from mindflow.utils import print_log
+
+from .utils import AverageMeter, compute_average_correlation, compute_armse
+from ..loaders.data_loader import batch_graph
+from ..datasets.dataset import PDECFDataset
+
+
+class Trainer:
+ """
+ Args:
+ model: an instance of a mindspore Cell subclass
+ optimizer: mindspore optimizer
+ scheduler: mindspore scheduler
+ config: project configuration
+ loss_func: loss function
+ """
+
+ def __init__(self, model, optimizer, scheduler, config, loss_func):
+ self.model = model
+ self.optimizer = optimizer
+ self.scheduler = scheduler
+ self.config = config
+ self.loss_func = loss_func
+
+ def train(self, tr_loader, val_loader):
+ """train"""
+ # continuous train
+ if self.config.continuous_train:
+ self.load_checkpoint()
+
+ min_val_loss = 1.0e+6
+ tr_batch_time = AverageMeter()
+ tr_data_time = AverageMeter()
+ tr_graph_time = AverageMeter()
+ tr_grad_time = AverageMeter()
+ tr_optim_time = AverageMeter()
+
+ val_batch_time = AverageMeter()
+ val_data_time = AverageMeter()
+ val_graph_time = AverageMeter()
+
+ def forward_fn(graph):
+ target = graph.y.transpose(1, 0, 2) # (t, n, 2)
+ graph.y = target[0] # (n, 2)
+ pred = self.model(graph, steps=target.shape[0]-1) # (t, n, 2)
+ loss = self.loss_func(pred, target)
+ return loss, pred
+
+ grad_fn = ops.value_and_grad(forward_fn, None,
+ self.model.trainable_params(),
+ has_aux=True)
+ for epoch in range(self.config.optim.start_epoch,
+ self.config.optim.epochs + self.config.optim.start_epoch):
+ tr_loss = self._train_loop(tr_loader, grad_fn, tr_batch_time,
+ tr_data_time, tr_graph_time, tr_grad_time, tr_optim_time)
+
+ if epoch == self.config.optim.start_epoch or \
+ epoch % self.config.optim.val_freq == 0:
+ self.save_checkpoint()
+ val_loss = self._evaluate_loop(val_loader, val_batch_time, val_data_time,
+ val_graph_time)
+
+ tr_time_str = '[Epoch {:>4}/{}] Batch Time: {:.3f} ({:.3f}) ' \
+ 'Data Time: {:.3f} ({:.3f}) Graph Time: {:.3f} ({:.3f}) ' \
+ 'Grad Time: {:.3f} ({:.3f}) Optim Time: {:.3f} ({:.3f})'.format(
+ epoch, self.config.optim.start_epoch + self.config.optim.epochs - 1,
+ tr_batch_time.val, tr_batch_time.avg, tr_data_time.val,
+ tr_data_time.avg,
+ tr_graph_time.val, tr_graph_time.avg, tr_grad_time.val,
+ tr_grad_time.avg,
+ tr_optim_time.val, tr_optim_time.avg)
+ val_time_str = '[Epoch {:>4}/{}] Batch Time: {:.3f} ({:.3f}) ' \
+ 'Data Time: {:.3f} ({:.3f}) Graph Time: {:.3f} ({:.3f})'.format(
+ epoch, self.config.optim.start_epoch + self.config.optim.epochs - 1,
+ val_batch_time.val, val_batch_time.avg, val_data_time.val,
+ val_data_time.avg,
+ val_graph_time.val, val_graph_time.avg)
+ info_str = '[Epoch {:>4}/{}] tr_loss: {:.2e} ' \
+ '\t\tval_loss: {:.2e} {}'.format(
+ epoch, self.config.optim.start_epoch + self.config.optim.epochs - 1,
+ tr_loss, val_loss, '{}')
+ if val_loss < min_val_loss:
+ min_val_loss = val_loss
+ info_str = info_str.format('[MIN]')
+ self.save_checkpoint(val=True)
+ else:
+ info_str = info_str.format(' ')
+ print_log(tr_time_str)
+ print_log(val_time_str)
+ print_log(info_str)
+
+ # @jit
+ def _train_loop(self, tr_loader, grad_fn, batch_time, data_time, graph_time, grad_time,
+ optim_time):
+ """train loop"""
+ self.model.set_train()
+ loss_list = []
+ end = time.time()
+ for _, data in enumerate(tr_loader.create_dict_iterator()):
+ # measure time
+ data_t_end = time.time()
+ data_time.update(data_t_end - end)
+
+ graph = batch_graph(data)
+ graph_t_end = time.time()
+ graph_time.update(graph_t_end - data_t_end)
+
+ (loss, _), grads = grad_fn(graph)
+ # clip grad norm
+ grads = ops.clip_by_norm(grads, max_norm=0.15)
+ grad_t_end = time.time()
+ grad_time.update(grad_t_end - graph_t_end)
+
+ self.optimizer(grads)
+ optim_t_end = time.time()
+ optim_time.update(optim_t_end - grad_t_end)
+
+ loss_list.append(loss.asnumpy())
+
+ # measure time
+ batch_time.update(time.time() - end)
+ end = time.time()
+ return np.mean(loss_list)
+
+ # @jit
+ def _evaluate_loop(self, val_loader, batch_time, data_time, graph_time):
+ """evaluate loop"""
+ self.model.set_train(False)
+ loss_list = []
+ end = time.time()
+ for _, data in enumerate(val_loader.create_dict_iterator()):
+ # measure time
+ data_time.update(time.time() - end)
+
+ start = time.time()
+ graph = batch_graph(data)
+ graph_time.update(time.time() - start)
+
+ target = graph.y.transpose(1, 0, 2) # (t, n, 2)
+ graph.y = target[0] # (n, 2)
+ pred = self.model(graph, steps=target.shape[0]-1) # (t, n, 2)
+ loss = self.loss_func(pred, target)
+ loss_list.append(loss.asnumpy())
+
+ # measure time
+ batch_time.update(time.time() - end)
+ end = time.time()
+ return np.mean(loss_list)
+
+ def test(self, te_loader):
+ """test"""
+ self.model.set_train(False)
+ self.load_checkpoint(val=True)
+
+ inference_time_list = []
+ mse_list = []
+ armse_list = []
+ corre_list = []
+ pred_list = []
+ target_list = []
+ te_num = len(te_loader)
+ for b_i, data in enumerate(te_loader.create_dict_iterator()):
+ graph = batch_graph(data)
+ target = graph.y.transpose(1, 0, 2)
+ t = target.shape[0]
+ graph.y = target[0]
+ start_time = time.time()
+ pred = self.model(graph, steps=target.shape[0]-1)
+ inference_time = time.time() - start_time
+ inference_time_list.append(inference_time)
+ target = ops.index_select(target, 1, graph.truth_index)
+
+ # dimensional
+ pos = ops.index_select(graph.pos, 0, graph.truth_index)
+ pred, target, pos = PDECFDataset.dimensional(
+ u_pred=pred,
+ u_gt=target,
+ pos=pos,
+ u_m=graph.u_m,
+ d=graph.r * 2
+ )
+
+ te_loss = ops.mse_loss(pred, target)
+ armse = compute_armse(pred, target)
+ mse_list.append(te_loss.asnumpy())
+ armse_list.append(armse[-1].asnumpy())
+ pred_list.append(pred.asnumpy())
+ target_list.append(target.asnumpy())
+
+ info_str = '[TEST {:>2}/{}] MSE at {}t: {:.2e}, armse: {:.3f}, time: {:.2f}s' \
+ .format(b_i, te_num, t, te_loss.asnumpy().mean(), armse[-1].item(), inference_time)
+ print_log(info_str)
+
+ corre_list = compute_average_correlation(pred_list, target_list)
+ corre = np.mean(corre_list)
+ info_str = '[Test {}] Mean Loss: {:.2e}, Mean armse: {:.3f}, corre: {:.3f}, time: {:.2f}' \
+ .format(len(te_loader), np.mean(mse_list), np.mean(armse_list), corre,
+ np.mean(inference_time_list))
+ print_log(info_str)
+
+ def save_checkpoint(self, val=False):
+ """save checkpoint"""
+ if val:
+ ckpt_path = osp.join(self.config.path.ckpt_path,
+ f'ckpt-{self.config.experiment_name}-val/')
+ else:
+ ckpt_path = osp.join(self.config.path.ckpt_path,
+ f'ckpt-{self.config.experiment_name}-tr/')
+
+ if not osp.exists(ckpt_path):
+ os.makedirs(ckpt_path)
+ ms.save_checkpoint(self.model.parameters_dict(),
+ osp.join(ckpt_path, 'model.ckpt'))
+ ms.save_checkpoint(self.optimizer.parameters_dict(),
+ osp.join(ckpt_path, 'optim.ckpt'))
+
+ def load_checkpoint(self, val=False):
+ """load checkpoint"""
+ if val:
+ ckpt_path = osp.join(self.config.path.ckpt_path,
+ f'ckpt-{self.config.experiment_name}-val/')
+ else:
+ ckpt_path = osp.join(self.config.path.ckpt_path,
+ f'ckpt-{self.config.experiment_name}-tr/')
+
+ ckpt_model = ms.load_checkpoint(
+ osp.join(ckpt_path, 'model.ckpt'))
+ ms.load_param_into_net(self.model, ckpt_model)
+ if self.optimizer is not None:
+ ckpt_optim = ms.load_checkpoint(
+ osp.join(ckpt_path, 'optim.ckpt'))
+ ms.load_param_into_net(self.optimizer, ckpt_optim)
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/utils.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..29954fa7eea23ac116b236644e8b9f85ced0d66f
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/trainers/utils.py
@@ -0,0 +1,82 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""trainer utils"""
+import numpy as np
+from mindspore import ops
+
+
+class AverageMeter:
+ """Computes and stores the average and current value"""
+
+ def __init__(self):
+ self.val = 0
+ self.avg = 0
+ self.sum = 0
+ self.count = 0
+
+ def update(self, val, n=1):
+ self.val = val
+ self.sum += val * n
+ self.count += n
+ self.avg = self.sum / self.count
+
+
+def compute_armse(pred, truth):
+ """compute armse"""
+ armses = []
+ for i in range(pred.shape[0]):
+ nume = ops.norm(pred[:i+1] - truth[:i+1])
+ deno = ops.norm(truth[:i+1])
+ res = nume / deno
+ armses.append(res)
+
+ return armses
+
+
+def correlation(u, truth):
+ """compute correlation"""
+ u = u.reshape(1, -1)
+ truth = truth.reshape(1, -1)
+ u_truth = np.concatenate((u, truth), axis=0)
+ coef = np.corrcoef(u_truth)[0][1]
+ return coef
+
+
+def cal_cur_time_corre(u, truth):
+ """
+ compute correlation per time
+ """
+ coef_list = []
+ for i in range(u.shape[0]):
+ cur_truth = truth[i]
+ cur_u = u[i]
+ cur_coef = correlation(cur_u, cur_truth)
+ coef_list.append(cur_coef)
+ return coef_list
+
+
+def compute_average_correlation(pred_list, truth_list):
+ """compute average correlation
+ """
+ corr_data = []
+ for i in range(len(pred_list)):
+ pred = pred_list[i]
+ truth = truth_list[i]
+ coef_list = cal_cur_time_corre(pred, truth)
+ corr_data.append(np.array(coef_list))
+
+ corr = np.mean(corr_data, axis=0) # [b, t] -> [t,]
+
+ return corr
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/__init__.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7696e5e2dee77755ab159bae56e8c07e1aaf449
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""utils utils"""
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/padding.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..746a40df1d6b9787cb4e96a57ee9c7a8161dcc4e
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/padding.py
@@ -0,0 +1,83 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""padding"""
+
+
+def periodic_padding(features, source_index, target_index):
+ """periodic padding
+
+ Args:
+ features (Tensor): shape [n, ...], the origin features
+ source_index (Tensor): shape [m,]
+ target_index (Tensor): shape [m,]
+
+ Returns:
+ features (Tensor): shape [n, ...], the padded features
+ """
+ features[target_index] = features[source_index]
+ return features
+
+
+def dirichlet_padding(features, padding_index, padding_value):
+ """dirichlet padding"""
+ if len(features.shape) == 3:
+ # (m, t, d)
+ features[padding_index] = padding_value.unsqueeze(1)\
+ .repeat(features.shape[1], axis=1)
+ else: # == 2
+ features[padding_index] = padding_value
+ return features
+
+
+def neumann_padding(features, source_index, target_index):
+ """neumann padding"""
+ features[target_index] = features[source_index]
+ return features
+
+
+def graph_padding(graph, clone=False):
+ """graph padding"""
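+ # encode BCs by overwriting boundary nodes after each update:
+ # Dirichlet/inlet nodes get prescribed values, periodic/Neumann ghost nodes copy their source nodes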
+ if hasattr(graph, 'dirichlet_index'):
+ graph.y = dirichlet_padding(graph.y, graph.dirichlet_index,
+ graph.dirichlet_value)
+ if hasattr(graph, 'inlet_index'):
+ graph.y = dirichlet_padding(graph.y, graph.inlet_index,
+ graph.inlet_value)
+ if hasattr(graph, 'periodic_src_index'):
+ graph.y = periodic_padding(graph.y, graph.periodic_src_index,
+ graph.periodic_tgt_index)
+ if hasattr(graph, 'neumann_src_index'):
+ graph.y = neumann_padding(graph.y, graph.neumann_src_index,
+ graph.neumann_tgt_index)
+
+ if clone:
+ graph.y = graph.y.copy()
+
+
+def h_padding(h, graph):
+ """hidden state padding"""
+ if hasattr(graph, 'dirichlet_index'):
+ h = dirichlet_padding(h, graph.dirichlet_index,
+ graph.dirichlet_h_value)
+ if hasattr(graph, 'inlet_index'):
+ h = dirichlet_padding(h, graph.inlet_index,
+ graph.inlet_h_value)
+ if hasattr(graph, 'periodic_src_index'):
+ h = periodic_padding(h, graph.periodic_src_index,
+ graph.periodic_tgt_index)
+ if hasattr(graph, 'neumann_src_index'):
+ h = neumann_padding(h, graph.neumann_src_index,
+ graph.neumann_tgt_index)
+ return h
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/voronoi_laplace.py b/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/voronoi_laplace.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad16310f3f97f5e2f5c26eddbcd8433d99aca9e3
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/src/utils/voronoi_laplace.py
@@ -0,0 +1,242 @@
+# Copyright 2025 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""voronoi laplace"""
+from shapely.geometry import MultiPoint
+from tqdm import tqdm
+import numpy as np
+
+
+def compute_discrete_laplace(pos, edge_index, face):
+ """
+ Compute discrete Laplace-Beltrami operator.
+
+ Args:
+ pos: shape (n, 2), position of nodes
+ edge_index: shape (2, e)
+ face: shape (3, n_tri)
+
+ Returns:
+ l_matrix: shape (n, n), laplace matrix
+ d_inv: shape (n,), element-wise inverse of the d vector (Voronoi areas)
+ """
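+ # assemble the cotangent Laplacian: W holds cotangent weights, V = diag(row sums of W),
+ # D = diag(Voronoi areas); the returned operator is -D^{-1} (V - W)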
+ w_matrix = compute_weight_matrix(pos, edge_index, face) # (n, n)
+ v = np.sum(w_matrix, axis=1)
+ v_matrix = np.diag(v) # (n, n)
+ a_matrix = v_matrix - w_matrix
+ assert not np.isinf(a_matrix).any()
+ assert not np.isnan(a_matrix).any()
+
+ d = compute_d_vector(pos, face)
+ d_inv = 1 / d
+ d_inv[np.isinf(d_inv)] = 0
+ d_matrix_inv = np.diag(d_inv) # (n, n)
+ assert not np.isinf(d_matrix_inv).any()
+ assert not np.isnan(d_matrix_inv).any()
+
+ l_matrix_ = d_matrix_inv @ a_matrix
+ return -l_matrix_, d_inv
+
+
+# todo: optimize complexity
+def compute_weight_matrix(pos, edge_index, face):
+ """
+ Compute weight matrix of discrete Laplace-Beltrami operator proposed by
+ Pinkall and Polthier.
+
+ Args:
+ pos: shape (n, 2), position of nodes
+ edge_index: shape (2, e)
+ face: shape (3, n_tri)
+
+ Returns:
+ weights: shape (n, n), weight matrix
+ """
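+ # w_ij = (cot(alpha_ij) + cot(beta_ij)) / 2, where alpha_ij and beta_ij are the
+ # angles opposite edge (i, j) in its two adjacent triangles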
+ n = pos.shape[0]
+ e = edge_index.shape[1]
+ weights = np.zeros((n, n), dtype=np.float32)
+ eps = np.finfo(np.float32).eps
+ for e_i in tqdm(range(e)):
+ edge = edge_index[:, e_i]
+ i, j = edge
+ nodes = find_opposite_nodes(edge, face)
+ if nodes:
+ p, q = nodes
+ alpha = compute_opposite_angle([
+ pos[i], pos[j], pos[p]
+ ])
+ beta = compute_opposite_angle([
+ pos[i], pos[j], pos[q]
+ ])
+ if np.isnan(alpha) or np.isnan(beta):
+ w = 0.
+ elif alpha < eps or beta < eps:
+ w = 0.
+ else:
+ w = (cot(alpha) + cot(beta)) / 2
+ if np.isnan(w): # for debug
+ print('weights nan, e_{}, n_{}-n_{}'.format(e_i, i, j))
+ weights[i, j] = w
+ return weights
+
+
+# todo: optimize complexity
+def compute_d_vector(pos, face):
+ """
+ Compute the d vector (per-node Voronoi areas) of the discrete Laplace-Beltrami operator proposed by Meyer.
+ """
+ d_vector = []
+ n = pos.shape[0]
+ for i in tqdm(range(n)):
+ tris = find_node_triangles(i, face)
+ area = compute_all_voronoi_area(pos, tris)
+ d_vector.append(area)
+ if np.isnan(area): # for debug
+ print('d nan, n_{}'.format(i))
+ return np.array(d_vector, dtype=np.float32)
+
+
+def find_opposite_nodes(edge, triangles):
+ """
+ Find the two opposite nodes of the edge in the triangles mesh.
+ Args:
+ edge: shape (2,), (v_i, v_j)
+ triangles: shape (3, n_tri), each column is (v_p, v_q, v_r)
+
+ Returns:
+ nodes: empty list if the two opposite nodes are not found, otherwise [v_a, v_b]
+ """
+ nodes = []
+ n_tri = triangles.shape[1]
+ for i in range(n_tri):
+ tri = triangles[:, i]
+ is_subset = np.all(np.isin(edge, tri))
+ if is_subset:
+ mask = ~np.isin(tri, edge)
+ diff = tri[mask]
+ nodes.append(diff.item())
+ if len(nodes) == 1:
+ nodes = []
+ assert len(nodes) in {0, 2}
+ return nodes
+
+
+def compute_opposite_angle(triangle):
+ """
+ Compute the opposite angle of edge ij in triangle.
+ Args:
+ triangle: length-3 sequence with the positions of nodes (i, j, k)
+
+ Returns:
+ angle: scalar, the angle between vectors ki and kj
+ """
+ v_i, v_j, v_k = triangle[0], triangle[1], triangle[2]
+ e_ki = v_k - v_i
+ e_kj = v_k - v_j
+ cos = np.dot(e_ki, e_kj) / \
+ (np.linalg.norm(e_ki) * np.linalg.norm(e_kj))
+ # if cos > 1. or cos < -1., angle will be nan.
+ angle = np.arccos(cos)
+ return angle
+
+
+def cot(theta):
+ """
+ cot = 1 / tan
+ """
+ ret = 1 / np.tan(theta)
+ if ret < np.finfo(np.float32).eps:
+ ret = 0.
+ return ret
+
+
+def compute_tri_circumcenter(triangle):
+ """compute triangle circumcenter"""
+ a, b, c = triangle[0], triangle[1], triangle[2]
+ if a.shape[0] == 2:
+ d = 2 * (a[0] * (b[1] - c[1]) + b[0] * (c[1] - a[1]) + c[0] * (
+ a[1] - b[1]))
+ ux = ((a[0] ** 2 + a[1] ** 2) * (b[1] - c[1]) + (
+ b[0] ** 2 + b[1] ** 2) * (c[1] - a[1]) + (
+ c[0] ** 2 + c[1] ** 2) * (a[1] - b[1])) / d
+ uy = ((a[0] ** 2 + a[1] ** 2) * (c[0] - b[0]) + (
+ b[0] ** 2 + b[1] ** 2) * (a[0] - c[0]) + (
+ c[0] ** 2 + c[1] ** 2) * (b[0] - a[0])) / d
+ center = np.stack([ux, uy], dtype=np.float32)
+ else:
+ ab = b - a
+ ac = c - a
+ ab_magnitude = np.linalg.norm(ab)
+ ac_magnitude = np.linalg.norm(ac)
+
+ # calculate triangle normal
+ n = np.cross(ab, ac)
+ n_magnitude = np.linalg.norm(n)
+
+ # Calculate circumcenter
+ center = a + (ab_magnitude * np.cross(
+ ab_magnitude * ac - ac_magnitude * ab, n)) / (2 * n_magnitude ** 2)
+ return center
+
+
+def compute_voronoi_area(triangle):
+ """
+ Compute voronoi area.
+ """
+ a, b, c = triangle[0], triangle[1], triangle[2]
+
+ ab = b - a
+ ac = c - a
+ cos_a = np.dot(ab, ac) / \
+ (np.linalg.norm(ab) * np.linalg.norm(ac))
+
+ eps = np.finfo(np.float32).eps
+ if np.abs(cos_a - 1.0) < eps:
+ area = 0.
+ elif cos_a < 0: # A is obtuse
+ area = 0.5 * MultiPoint(triangle).convex_hull.area
+ else:
+ circumcenter = compute_tri_circumcenter(triangle)
+ mab = (a + b) / 2
+ mac = (a + c) / 2
+ area = MultiPoint([a, mab, circumcenter, mac]).convex_hull.area
+ return area
+
+
+def compute_all_voronoi_area(pos, tris):
+ """
+ Compute area.
+ """
+ areas_sum = 0
+ n_tri = tris.shape[0]
+ for tri_i in range(n_tri):
+ i, j, k = tris[tri_i]
+ area = compute_voronoi_area([pos[i], pos[j], pos[k]])
+ areas_sum += area
+ return areas_sum
+
+
+def find_node_triangles(node, triangles):
+ """
+ Find the triangles that contain the node.
+ """
+ tris = []
+ n_tri = triangles.shape[1]
+ for i in range(n_tri):
+ triangle = triangles[:, i]
+ if node in triangles[:, i]:
+ # move node to the first location in triangle
+ tri = np.concatenate([[node], triangle[triangle != node]])
+ tris.append(tri)
+ return np.stack(tris)
diff --git a/MindFlow/applications/data_mechanism_fusion/phympgn/yamls/train.yaml b/MindFlow/applications/data_mechanism_fusion/phympgn/yamls/train.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f0ba1b887a02be8b1187554fd5824de1fb63cd0d
--- /dev/null
+++ b/MindFlow/applications/data_mechanism_fusion/phympgn/yamls/train.yaml
@@ -0,0 +1,56 @@
+# hyperparameters
+data:
+ dataset_start: 0
+ dataset_used: 4
+ te_dataset_start: 0
+ te_dataset_used: 9
+ time_start: 0
+ time_used: 2000
+ tr_window_size: 20
+ val_window_size: 20
+ te_window_size: 2000
+
+optim:
+ lr: 1.25e-4
+ steplr_size: 200
+ steplr_gamma: 0.96
+ start_epoch: 1
+ epochs: 1600
+ batch_size: 4
+ num_workers: 0
+ window_shuffle: true
+ val_freq: 2
+
+# network architecture
+network:
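+ # integral selects the time integrator: 1 = Euler, 2 = RK2, 4 = RK4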
+ integral: 2
+ encoder_config:
+ node_encoder_layers: [8, 32, 64, 128]
+ edge_encoder_layers: [5, 32, 64, 128]
+
+ mpnn_block_config:
+ mpnn_layers:
+ - [384, 128, 128, 128, 128]
+ - [256, 128, 128, 128, 128]
+ mpnn_num: 5
+
+ decoder_config:
+ node_decoder_layers: [128, 64, 32, 2]
+
+ laplace_block_config:
+ in_dim: 4
+ h_dim: 24
+ out_dim: 2
+
+# experiment name
+experiment_name: "tmp"
+continuous_train: false
+
+path:
+ ckpt_path: ckpts/
+
+ # data path
+ data_root_dir: data/2d_cf
+ tr_raw_data: train_cf_4x2000x1598x2.h5
+ val_raw_data: train_cf_4x2000x1598x2.h5
+ te_raw_data: test_cf_9x2000x1598x2.h5